github.com/crewjam/saml@v0.4.14/samlsp/middleware.go (about) 1 package samlsp 2 3 import ( 4 "bytes" 5 "encoding/xml" 6 "net/http" 7 8 "github.com/crewjam/saml" 9 ) 10 11 // Middleware implements middleware than allows a web application 12 // to support SAML. 13 // 14 // It implements http.Handler so that it can provide the metadata and ACS endpoints, 15 // typically /saml/metadata and /saml/acs, respectively. 16 // 17 // It also provides middleware RequireAccount which redirects users to 18 // the auth process if they do not have session credentials. 19 // 20 // When redirecting the user through the SAML auth flow, the middleware assigns 21 // a temporary cookie with a random name beginning with "saml_". The value of 22 // the cookie is a signed JSON Web Token containing the original URL requested 23 // and the SAML request ID. The random part of the name corresponds to the 24 // RelayState parameter passed through the SAML flow. 25 // 26 // When validating the SAML response, the RelayState is used to look up the 27 // correct cookie, validate that the SAML request ID, and redirect the user 28 // back to their original URL. 29 // 30 // Sessions are established by issuing a JSON Web Token (JWT) as a session 31 // cookie once the SAML flow has succeeded. The JWT token contains the 32 // authenticated attributes from the SAML assertion. 33 // 34 // When the middleware receives a request with a valid session JWT it extracts 35 // the SAML attributes and modifies the http.Request object adding a Context 36 // object to the request context that contains attributes from the initial 37 // SAML assertion. 38 // 39 // When issuing JSON Web Tokens, a signing key is required. Because the 40 // SAML service provider already has a private key, we borrow that key 41 // to sign the JWTs as well. 42 type Middleware struct { 43 ServiceProvider saml.ServiceProvider 44 OnError func(w http.ResponseWriter, r *http.Request, err error) 45 Binding string // either saml.HTTPPostBinding or saml.HTTPRedirectBinding 46 ResponseBinding string // either saml.HTTPPostBinding or saml.HTTPArtifactBinding 47 RequestTracker RequestTracker 48 Session SessionProvider 49 } 50 51 // ServeHTTP implements http.Handler and serves the SAML-specific HTTP endpoints 52 // on the URIs specified by m.ServiceProvider.MetadataURL and 53 // m.ServiceProvider.AcsURL. 54 func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { 55 if r.URL.Path == m.ServiceProvider.MetadataURL.Path { 56 m.ServeMetadata(w, r) 57 return 58 } 59 60 if r.URL.Path == m.ServiceProvider.AcsURL.Path { 61 m.ServeACS(w, r) 62 return 63 } 64 65 http.NotFoundHandler().ServeHTTP(w, r) 66 } 67 68 // ServeMetadata handles requests for the SAML metadata endpoint. 69 func (m *Middleware) ServeMetadata(w http.ResponseWriter, _ *http.Request) { 70 buf, _ := xml.MarshalIndent(m.ServiceProvider.Metadata(), "", " ") 71 w.Header().Set("Content-Type", "application/samlmetadata+xml") 72 if _, err := w.Write(buf); err != nil { 73 http.Error(w, err.Error(), http.StatusInternalServerError) 74 return 75 } 76 } 77 78 // ServeACS handles requests for the SAML ACS endpoint. 79 func (m *Middleware) ServeACS(w http.ResponseWriter, r *http.Request) { 80 err := r.ParseForm() 81 if err != nil { 82 m.OnError(w, r, err) 83 return 84 } 85 86 possibleRequestIDs := []string{} 87 if m.ServiceProvider.AllowIDPInitiated { 88 possibleRequestIDs = append(possibleRequestIDs, "") 89 } 90 91 trackedRequests := m.RequestTracker.GetTrackedRequests(r) 92 for _, tr := range trackedRequests { 93 possibleRequestIDs = append(possibleRequestIDs, tr.SAMLRequestID) 94 } 95 96 assertion, err := m.ServiceProvider.ParseResponse(r, possibleRequestIDs) 97 if err != nil { 98 m.OnError(w, r, err) 99 return 100 } 101 102 m.CreateSessionFromAssertion(w, r, assertion, m.ServiceProvider.DefaultRedirectURI) 103 } 104 105 // RequireAccount is HTTP middleware that requires that each request be 106 // associated with a valid session. If the request is not associated with a valid 107 // session, then rather than serve the request, the middleware redirects the user 108 // to start the SAML auth flow. 109 func (m *Middleware) RequireAccount(handler http.Handler) http.Handler { 110 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 111 session, err := m.Session.GetSession(r) 112 if session != nil { 113 r = r.WithContext(ContextWithSession(r.Context(), session)) 114 handler.ServeHTTP(w, r) 115 return 116 } 117 if err == ErrNoSession { 118 m.HandleStartAuthFlow(w, r) 119 return 120 } 121 122 m.OnError(w, r, err) 123 }) 124 } 125 126 // HandleStartAuthFlow is called to start the SAML authentication process. 127 func (m *Middleware) HandleStartAuthFlow(w http.ResponseWriter, r *http.Request) { 128 // If we try to redirect when the original request is the ACS URL we'll 129 // end up in a loop. This is a programming error, so we panic here. In 130 // general this means a 500 to the user, which is preferable to a 131 // redirect loop. 132 if r.URL.Path == m.ServiceProvider.AcsURL.Path { 133 panic("don't wrap Middleware with RequireAccount") 134 } 135 136 var binding, bindingLocation string 137 if m.Binding != "" { 138 binding = m.Binding 139 bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding) 140 } else { 141 binding = saml.HTTPRedirectBinding 142 bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding) 143 if bindingLocation == "" { 144 binding = saml.HTTPPostBinding 145 bindingLocation = m.ServiceProvider.GetSSOBindingLocation(binding) 146 } 147 } 148 149 authReq, err := m.ServiceProvider.MakeAuthenticationRequest(bindingLocation, binding, m.ResponseBinding) 150 if err != nil { 151 http.Error(w, err.Error(), http.StatusInternalServerError) 152 return 153 } 154 155 // relayState is limited to 80 bytes but also must be integrity protected. 156 // this means that we cannot use a JWT because it is way to long. Instead 157 // we set a signed cookie that encodes the original URL which we'll check 158 // against the SAML response when we get it. 159 relayState, err := m.RequestTracker.TrackRequest(w, r, authReq.ID) 160 if err != nil { 161 http.Error(w, err.Error(), http.StatusInternalServerError) 162 return 163 } 164 165 if binding == saml.HTTPRedirectBinding { 166 redirectURL, err := authReq.Redirect(relayState, &m.ServiceProvider) 167 if err != nil { 168 http.Error(w, err.Error(), http.StatusInternalServerError) 169 return 170 } 171 w.Header().Add("Location", redirectURL.String()) 172 w.WriteHeader(http.StatusFound) 173 return 174 } 175 if binding == saml.HTTPPostBinding { 176 w.Header().Add("Content-Security-Policy", ""+ 177 "default-src; "+ 178 "script-src 'sha256-AjPdJSbZmeWHnEc5ykvJFay8FTWeTeRbs9dutfZ0HqE='; "+ 179 "reflected-xss block; referrer no-referrer;") 180 w.Header().Add("Content-type", "text/html") 181 var buf bytes.Buffer 182 buf.WriteString(`<!DOCTYPE html><html><body>`) 183 buf.Write(authReq.Post(relayState)) 184 buf.WriteString(`</body></html>`) 185 if _, err := w.Write(buf.Bytes()); err != nil { 186 http.Error(w, err.Error(), http.StatusInternalServerError) 187 return 188 } 189 return 190 } 191 panic("not reached") 192 } 193 194 // CreateSessionFromAssertion is invoked by ServeHTTP when we have a new, valid SAML assertion. 195 func (m *Middleware) CreateSessionFromAssertion(w http.ResponseWriter, r *http.Request, assertion *saml.Assertion, redirectURI string) { 196 if trackedRequestIndex := r.Form.Get("RelayState"); trackedRequestIndex != "" { 197 trackedRequest, err := m.RequestTracker.GetTrackedRequest(r, trackedRequestIndex) 198 if err != nil { 199 if err == http.ErrNoCookie && m.ServiceProvider.AllowIDPInitiated { 200 if uri := r.Form.Get("RelayState"); uri != "" { 201 redirectURI = uri 202 } 203 } else { 204 m.OnError(w, r, err) 205 return 206 } 207 } else { 208 if err := m.RequestTracker.StopTrackingRequest(w, r, trackedRequestIndex); err != nil { 209 m.OnError(w, r, err) 210 return 211 } 212 213 redirectURI = trackedRequest.URI 214 } 215 } 216 217 if err := m.Session.CreateSession(w, r, assertion); err != nil { 218 m.OnError(w, r, err) 219 return 220 } 221 222 http.Redirect(w, r, redirectURI, http.StatusFound) 223 } 224 225 // RequireAttribute returns a middleware function that requires that the 226 // SAML attribute `name` be set to `value`. This can be used to require 227 // that a remote user be a member of a group. It relies on the Claims assigned 228 // to to the context in RequireAccount. 229 // 230 // For example: 231 // 232 // goji.Use(m.RequireAccount) 233 // goji.Use(RequireAttributeMiddleware("eduPersonAffiliation", "Staff")) 234 func RequireAttribute(name, value string) func(http.Handler) http.Handler { 235 return func(handler http.Handler) http.Handler { 236 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 237 if session := SessionFromContext(r.Context()); session != nil { 238 // this will panic if we have the wrong type of Session, and that is OK. 239 sessionWithAttributes := session.(SessionWithAttributes) 240 attributes := sessionWithAttributes.GetAttributes() 241 if values, ok := attributes[name]; ok { 242 for _, v := range values { 243 if v == value { 244 handler.ServeHTTP(w, r) 245 return 246 } 247 } 248 } 249 } 250 http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) 251 }) 252 } 253 }