github.com/crewjam/saml@v0.4.14/samlidp/session.go (about)

     1  package samlidp
     2  
import (
	"bytes"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"text/template"
	"time"

	"golang.org/x/crypto/bcrypt"

	"github.com/zenazn/goji/web"

	"github.com/crewjam/saml"
)
    18  
    19  var sessionMaxAge = time.Hour
    20  
    21  // GetSession returns the *Session for this request.
    22  //
    23  // If the remote user has specified a username and password in the request
    24  // then it is validated against the user database. If valid it sets a
    25  // cookie and returns the newly created session object.
    26  //
    27  // If the remote user has specified invalid credentials then a login form
    28  // is returned with an English-language toast telling the user their
    29  // password was invalid.
    30  //
    31  // If a session cookie already exists and represents a valid session,
    32  // then the session is returned
    33  //
    34  // If neither credentials nor a valid session cookie exist, this function
    35  // sends a login form and returns nil.
    36  func (s *Server) GetSession(w http.ResponseWriter, r *http.Request, req *saml.IdpAuthnRequest) *saml.Session {
    37  	// if we received login credentials then maybe we can create a session
    38  	if r.Method == "POST" && r.PostForm.Get("user") != "" {
    39  		user := User{}
    40  		if err := s.Store.Get(fmt.Sprintf("/users/%s", r.PostForm.Get("user")), &user); err != nil {
    41  			s.sendLoginForm(w, r, req, "Invalid username or password")
    42  			return nil
    43  		}
    44  
    45  		if err := bcrypt.CompareHashAndPassword(user.HashedPassword, []byte(r.PostForm.Get("password"))); err != nil {
    46  			s.sendLoginForm(w, r, req, "Invalid username or password")
    47  			return nil
    48  		}
    49  
    50  		session := &saml.Session{
    51  			ID:         base64.StdEncoding.EncodeToString(randomBytes(32)),
    52  			NameID:     user.Email,
    53  			CreateTime: saml.TimeNow(),
    54  			ExpireTime: saml.TimeNow().Add(sessionMaxAge),
    55  			Index:      hex.EncodeToString(randomBytes(32)),
    56  			UserName:   user.Name,
    57  			// nolint:gocritic // Groups should be a slice here.
    58  			Groups:                user.Groups[:],
    59  			UserEmail:             user.Email,
    60  			UserCommonName:        user.CommonName,
    61  			UserSurname:           user.Surname,
    62  			UserGivenName:         user.GivenName,
    63  			UserScopedAffiliation: user.ScopedAffiliation,
    64  		}
    65  		if err := s.Store.Put(fmt.Sprintf("/sessions/%s", session.ID), &session); err != nil {
    66  			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
    67  			return nil
    68  		}
    69  
    70  		http.SetCookie(w, &http.Cookie{
    71  			Name:     "session",
    72  			Value:    session.ID,
    73  			MaxAge:   int(sessionMaxAge.Seconds()),
    74  			HttpOnly: true,
    75  			Secure:   r.URL.Scheme == "https",
    76  			Path:     "/",
    77  		})
    78  		return session
    79  	}
    80  
    81  	if sessionCookie, err := r.Cookie("session"); err == nil {
    82  		session := &saml.Session{}
    83  		if err := s.Store.Get(fmt.Sprintf("/sessions/%s", sessionCookie.Value), session); err != nil {
    84  			if err == ErrNotFound {
    85  				s.sendLoginForm(w, r, req, "")
    86  				return nil
    87  			}
    88  			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
    89  			return nil
    90  		}
    91  
    92  		if saml.TimeNow().After(session.ExpireTime) {
    93  			s.sendLoginForm(w, r, req, "")
    94  			return nil
    95  		}
    96  		return session
    97  	}
    98  
    99  	s.sendLoginForm(w, r, req, "")
   100  	return nil
   101  }
   102  
   103  // sendLoginForm produces a form which requests a username and password and directs the user
   104  // back to the IDP authorize URL to restart the SAML login flow, this time establishing a
   105  // session based on the credentials that were provided.
   106  func (s *Server) sendLoginForm(w http.ResponseWriter, _ *http.Request, req *saml.IdpAuthnRequest, toast string) {
   107  	tmpl := template.Must(template.New("saml-post-form").Parse(`` +
   108  		`<html>` +
   109  		`<p>{{.Toast}}</p>` +
   110  		`<form method="post" action="{{.URL}}">` +
   111  		`<input type="text" name="user" placeholder="user" value="" />` +
   112  		`<input type="password" name="password" placeholder="password" value="" />` +
   113  		`<input type="hidden" name="SAMLRequest" value="{{.SAMLRequest}}" />` +
   114  		`<input type="hidden" name="RelayState" value="{{.RelayState}}" />` +
   115  		`<input type="submit" value="Log In" />` +
   116  		`</form>` +
   117  		`</html>`))
   118  	data := struct {
   119  		Toast       string
   120  		URL         string
   121  		SAMLRequest string
   122  		RelayState  string
   123  	}{
   124  		Toast:       toast,
   125  		URL:         req.IDP.SSOURL.String(),
   126  		SAMLRequest: base64.StdEncoding.EncodeToString(req.RequestBuffer),
   127  		RelayState:  req.RelayState,
   128  	}
   129  
   130  	if err := tmpl.Execute(w, data); err != nil {
   131  		panic(err)
   132  	}
   133  }
   134  
   135  // HandleLogin handles the `POST /login` and `GET /login` forms. If credentials are present
   136  // in the request body, then they are validated. For valid credentials, the response is a
   137  // 200 OK and the JSON session object. For invalid credentials, the HTML login prompt form
   138  // is sent.
   139  func (s *Server) HandleLogin(_ web.C, w http.ResponseWriter, r *http.Request) {
   140  	if err := r.ParseForm(); err != nil {
   141  		http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
   142  		return
   143  	}
   144  	session := s.GetSession(w, r, &saml.IdpAuthnRequest{IDP: &s.IDP})
   145  	if session == nil {
   146  		return
   147  	}
   148  	if err := json.NewEncoder(w).Encode(session); err != nil {
   149  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   150  		return
   151  	}
   152  }
   153  
   154  // HandleListSessions handles the `GET /sessions/` request and responds with a JSON formatted list
   155  // of session names.
   156  func (s *Server) HandleListSessions(_ web.C, w http.ResponseWriter, _ *http.Request) {
   157  	sessions, err := s.Store.List("/sessions/")
   158  	if err != nil {
   159  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   160  		return
   161  	}
   162  
   163  	err = json.NewEncoder(w).Encode(struct {
   164  		Sessions []string `json:"sessions"`
   165  	}{Sessions: sessions})
   166  	if err != nil {
   167  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   168  		return
   169  	}
   170  }
   171  
   172  // HandleGetSession handles the `GET /sessions/:id` request and responds with the session
   173  // object in JSON format.
   174  func (s *Server) HandleGetSession(c web.C, w http.ResponseWriter, _ *http.Request) {
   175  	session := saml.Session{}
   176  	err := s.Store.Get(fmt.Sprintf("/sessions/%s", c.URLParams["id"]), &session)
   177  	if err != nil {
   178  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   179  		return
   180  	}
   181  	if err := json.NewEncoder(w).Encode(session); err != nil {
   182  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   183  		return
   184  	}
   185  }
   186  
   187  // HandleDeleteSession handles the `DELETE /sessions/:id` request. It invalidates the
   188  // specified session.
   189  func (s *Server) HandleDeleteSession(c web.C, w http.ResponseWriter, _ *http.Request) {
   190  	err := s.Store.Delete(fmt.Sprintf("/sessions/%s", c.URLParams["id"]))
   191  	if err != nil {
   192  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   193  		return
   194  	}
   195  	w.WriteHeader(http.StatusNoContent)
   196  }