github.com/crewjam/saml@v0.4.14/samlsp/middleware.go (about)

     1  package samlsp
     2  
import (
	"bytes"
	"encoding/xml"
	"errors"
	"net/http"

	"github.com/crewjam/saml"
)
    10  
// Middleware implements middleware that allows a web application
// to support SAML.
//
// It implements http.Handler so that it can provide the metadata and ACS endpoints,
// typically /saml/metadata and /saml/acs, respectively.
//
// It also provides middleware RequireAccount which redirects users to
// the auth process if they do not have session credentials.
//
// When redirecting the user through the SAML auth flow, the middleware assigns
// a temporary cookie with a random name beginning with "saml_". The value of
// the cookie is a signed JSON Web Token containing the original URL requested
// and the SAML request ID. The random part of the name corresponds to the
// RelayState parameter passed through the SAML flow.
//
// When validating the SAML response, the RelayState is used to look up the
// correct cookie, validate the SAML request ID, and redirect the user
// back to their original URL.
//
// Sessions are established by issuing a JSON Web Token (JWT) as a session
// cookie once the SAML flow has succeeded. The JWT token contains the
// authenticated attributes from the SAML assertion.
//
// When the middleware receives a request with a valid session JWT it extracts
// the SAML attributes and modifies the http.Request object adding a Context
// object to the request context that contains attributes from the initial
// SAML assertion.
//
// When issuing JSON Web Tokens, a signing key is required. Because the
// SAML service provider already has a private key, we borrow that key
// to sign the JWTs as well.
type Middleware struct {
	// ServiceProvider is the SAML service provider configuration. The paths
	// of its MetadataURL and AcsURL select the endpoints served by ServeHTTP,
	// and its AllowIDPInitiated and DefaultRedirectURI settings are consulted
	// while handling ACS requests.
	ServiceProvider saml.ServiceProvider

	// OnError is called when handling a request fails (for example parsing or
	// validating the SAML response, or looking up or creating a session).
	// It is invoked without a nil check, so it must be set.
	OnError func(w http.ResponseWriter, r *http.Request, err error)

	// Binding is the binding used to send the authentication request to the
	// IdP: either saml.HTTPPostBinding or saml.HTTPRedirectBinding. When
	// empty, HandleStartAuthFlow prefers HTTP-Redirect and falls back to
	// HTTP-POST if no redirect location is available.
	Binding string // either saml.HTTPPostBinding or saml.HTTPRedirectBinding

	// ResponseBinding is the binding requested for the IdP's response: either
	// saml.HTTPPostBinding or saml.HTTPArtifactBinding. It is passed through
	// to ServiceProvider.MakeAuthenticationRequest.
	ResponseBinding string // either saml.HTTPPostBinding or saml.HTTPArtifactBinding

	// RequestTracker records outgoing authentication requests so that the ACS
	// handler can find them again (by RelayState) when the response arrives.
	RequestTracker RequestTracker

	// Session creates a session from a validated assertion and retrieves it
	// on later requests.
	Session SessionProvider
}
    50  
    51  // ServeHTTP implements http.Handler and serves the SAML-specific HTTP endpoints
    52  // on the URIs specified by m.ServiceProvider.MetadataURL and
    53  // m.ServiceProvider.AcsURL.
    54  func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    55  	if r.URL.Path == m.ServiceProvider.MetadataURL.Path {
    56  		m.ServeMetadata(w, r)
    57  		return
    58  	}
    59  
    60  	if r.URL.Path == m.ServiceProvider.AcsURL.Path {
    61  		m.ServeACS(w, r)
    62  		return
    63  	}
    64  
    65  	http.NotFoundHandler().ServeHTTP(w, r)
    66  }
    67  
    68  // ServeMetadata handles requests for the SAML metadata endpoint.
    69  func (m *Middleware) ServeMetadata(w http.ResponseWriter, _ *http.Request) {
    70  	buf, _ := xml.MarshalIndent(m.ServiceProvider.Metadata(), "", "  ")
    71  	w.Header().Set("Content-Type", "application/samlmetadata+xml")
    72  	if _, err := w.Write(buf); err != nil {
    73  		http.Error(w, err.Error(), http.StatusInternalServerError)
    74  		return
    75  	}
    76  }
    77  
    78  // ServeACS handles requests for the SAML ACS endpoint.
    79  func (m *Middleware) ServeACS(w http.ResponseWriter, r *http.Request) {
    80  	err := r.ParseForm()
    81  	if err != nil {
    82  		m.OnError(w, r, err)
    83  		return
    84  	}
    85  
    86  	possibleRequestIDs := []string{}
    87  	if m.ServiceProvider.AllowIDPInitiated {
    88  		possibleRequestIDs = append(possibleRequestIDs, "")
    89  	}
    90  
    91  	trackedRequests := m.RequestTracker.GetTrackedRequests(r)
    92  	for _, tr := range trackedRequests {
    93  		possibleRequestIDs = append(possibleRequestIDs, tr.SAMLRequestID)
    94  	}
    95  
    96  	assertion, err := m.ServiceProvider.ParseResponse(r, possibleRequestIDs)
    97  	if err != nil {
    98  		m.OnError(w, r, err)
    99  		return
   100  	}
   101  
   102  	m.CreateSessionFromAssertion(w, r, assertion, m.ServiceProvider.DefaultRedirectURI)
   103  }
   104  
   105  // RequireAccount is HTTP middleware that requires that each request be
   106  // associated with a valid session. If the request is not associated with a valid
   107  // session, then rather than serve the request, the middleware redirects the user
   108  // to start the SAML auth flow.
   109  func (m *Middleware) RequireAccount(handler http.Handler) http.Handler {
   110  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   111  		session, err := m.Session.GetSession(r)
   112  		if session != nil {
   113  			r = r.WithContext(ContextWithSession(r.Context(), session))
   114  			handler.ServeHTTP(w, r)
   115  			return
   116  		}
   117  		if err == ErrNoSession {
   118  			m.HandleStartAuthFlow(w, r)
   119  			return
   120  		}
   121  
   122  		m.OnError(w, r, err)
   123  	})
   124  }
   125  
   126  // HandleStartAuthFlow is called to start the SAML authentication process.
   127  func (m *Middleware) HandleStartAuthFlow(w http.ResponseWriter, r *http.Request) {
   128  	// If we try to redirect when the original request is the ACS URL we'll
   129  	// end up in a loop. This is a programming error, so we panic here. In
   130  	// general this means a 500 to the user, which is preferable to a
   131  	// redirect loop.
   132  	if r.URL.Path == m.ServiceProvider.AcsURL.Path {
   133  		panic("don't wrap Middleware with RequireAccount")
   134  	}
   135  
   136  	var binding, bindingLocation string
   137  	if m.Binding != "" {
   138  		binding = m.Binding
   139  		bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding)
   140  	} else {
   141  		binding = saml.HTTPRedirectBinding
   142  		bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding)
   143  		if bindingLocation == "" {
   144  			binding = saml.HTTPPostBinding
   145  			bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding)
   146  		}
   147  	}
   148  
   149  	authReq, err := m.ServiceProvider.MakeAuthenticationRequest(bindingLocation, binding, m.ResponseBinding)
   150  	if err != nil {
   151  		http.Error(w, err.Error(), http.StatusInternalServerError)
   152  		return
   153  	}
   154  
   155  	// relayState is limited to 80 bytes but also must be integrity protected.
   156  	// this means that we cannot use a JWT because it is way to long. Instead
   157  	// we set a signed cookie that encodes the original URL which we'll check
   158  	// against the SAML response when we get it.
   159  	relayState, err := m.RequestTracker.TrackRequest(w, r, authReq.ID)
   160  	if err != nil {
   161  		http.Error(w, err.Error(), http.StatusInternalServerError)
   162  		return
   163  	}
   164  
   165  	if binding == saml.HTTPRedirectBinding {
   166  		redirectURL, err := authReq.Redirect(relayState, &m.ServiceProvider)
   167  		if err != nil {
   168  			http.Error(w, err.Error(), http.StatusInternalServerError)
   169  			return
   170  		}
   171  		w.Header().Add("Location", redirectURL.String())
   172  		w.WriteHeader(http.StatusFound)
   173  		return
   174  	}
   175  	if binding == saml.HTTPPostBinding {
   176  		w.Header().Add("Content-Security-Policy", ""+
   177  			"default-src; "+
   178  			"script-src 'sha256-AjPdJSbZmeWHnEc5ykvJFay8FTWeTeRbs9dutfZ0HqE='; "+
   179  			"reflected-xss block; referrer no-referrer;")
   180  		w.Header().Add("Content-type", "text/html")
   181  		var buf bytes.Buffer
   182  		buf.WriteString(`<!DOCTYPE html><html><body>`)
   183  		buf.Write(authReq.Post(relayState))
   184  		buf.WriteString(`</body></html>`)
   185  		if _, err := w.Write(buf.Bytes()); err != nil {
   186  			http.Error(w, err.Error(), http.StatusInternalServerError)
   187  			return
   188  		}
   189  		return
   190  	}
   191  	panic("not reached")
   192  }
   193  
   194  // CreateSessionFromAssertion is invoked by ServeHTTP when we have a new, valid SAML assertion.
   195  func (m *Middleware) CreateSessionFromAssertion(w http.ResponseWriter, r *http.Request, assertion *saml.Assertion, redirectURI string) {
   196  	if trackedRequestIndex := r.Form.Get("RelayState"); trackedRequestIndex != "" {
   197  		trackedRequest, err := m.RequestTracker.GetTrackedRequest(r, trackedRequestIndex)
   198  		if err != nil {
   199  			if err == http.ErrNoCookie && m.ServiceProvider.AllowIDPInitiated {
   200  				if uri := r.Form.Get("RelayState"); uri != "" {
   201  					redirectURI = uri
   202  				}
   203  			} else {
   204  				m.OnError(w, r, err)
   205  				return
   206  			}
   207  		} else {
   208  			if err := m.RequestTracker.StopTrackingRequest(w, r, trackedRequestIndex); err != nil {
   209  				m.OnError(w, r, err)
   210  				return
   211  			}
   212  
   213  			redirectURI = trackedRequest.URI
   214  		}
   215  	}
   216  
   217  	if err := m.Session.CreateSession(w, r, assertion); err != nil {
   218  		m.OnError(w, r, err)
   219  		return
   220  	}
   221  
   222  	http.Redirect(w, r, redirectURI, http.StatusFound)
   223  }
   224  
   225  // RequireAttribute returns a middleware function that requires that the
   226  // SAML attribute `name` be set to `value`. This can be used to require
   227  // that a remote user be a member of a group. It relies on the Claims assigned
   228  // to to the context in RequireAccount.
   229  //
   230  // For example:
   231  //
   232  //	goji.Use(m.RequireAccount)
   233  //	goji.Use(RequireAttributeMiddleware("eduPersonAffiliation", "Staff"))
   234  func RequireAttribute(name, value string) func(http.Handler) http.Handler {
   235  	return func(handler http.Handler) http.Handler {
   236  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   237  			if session := SessionFromContext(r.Context()); session != nil {
   238  				// this will panic if we have the wrong type of Session, and that is OK.
   239  				sessionWithAttributes := session.(SessionWithAttributes)
   240  				attributes := sessionWithAttributes.GetAttributes()
   241  				if values, ok := attributes[name]; ok {
   242  					for _, v := range values {
   243  						if v == value {
   244  							handler.ServeHTTP(w, r)
   245  							return
   246  						}
   247  					}
   248  				}
   249  			}
   250  			http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
   251  		})
   252  	}
   253  }