github.com/crewjam/saml@v0.4.14/samlsp/middleware_test.go (about) 1 package samlsp 2 3 import ( 4 "bytes" 5 "crypto/rsa" 6 "crypto/sha256" 7 "crypto/x509" 8 "encoding/base64" 9 "encoding/json" 10 "encoding/xml" 11 "io" 12 "net" 13 "net/http" 14 "net/http/httptest" 15 "net/url" 16 "strings" 17 "testing" 18 "time" 19 20 "github.com/golang-jwt/jwt/v4" 21 dsig "github.com/russellhaering/goxmldsig" 22 "gotest.tools/assert" 23 is "gotest.tools/assert/cmp" 24 "gotest.tools/golden" 25 26 "github.com/crewjam/saml" 27 "github.com/crewjam/saml/testsaml" 28 ) 29 30 type MiddlewareTest struct { 31 AuthnRequest []byte 32 SamlResponse []byte 33 Key *rsa.PrivateKey 34 Certificate *x509.Certificate 35 IDPMetadata []byte 36 Middleware *Middleware 37 expectedSessionCookie string 38 } 39 40 type testRandomReader struct { 41 Next byte 42 } 43 44 func (tr *testRandomReader) Read(p []byte) (n int, err error) { 45 for i := 0; i < len(p); i++ { 46 p[i] = tr.Next 47 tr.Next += 2 48 } 49 return len(p), nil 50 } 51 52 func NewMiddlewareTest(t *testing.T) *MiddlewareTest { 53 test := MiddlewareTest{} 54 saml.TimeNow = func() time.Time { 55 rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 MST 2006", "Mon Dec 1 01:57:09.123456789 UTC 2015") 56 return rv 57 } 58 jwt.TimeFunc = saml.TimeNow 59 saml.Clock = dsig.NewFakeClockAt(saml.TimeNow()) 60 saml.RandReader = &testRandomReader{} 61 62 test.AuthnRequest = golden.Get(t, "authn_request.url") 63 test.SamlResponse = golden.Get(t, "saml_response.xml") 64 test.Key = mustParsePrivateKey(golden.Get(t, "key.pem")).(*rsa.PrivateKey) 65 test.Certificate = mustParseCertificate(golden.Get(t, "cert.pem")) 66 test.IDPMetadata = golden.Get(t, "idp_metadata.xml") 67 68 var metadata saml.EntityDescriptor 69 if err := xml.Unmarshal(test.IDPMetadata, &metadata); err != nil { 70 panic(err) 71 } 72 73 opts := Options{ 74 URL: mustParseURL("https://15661444.ngrok.io/"), 75 Key: test.Key, 76 Certificate: test.Certificate, 77 IDPMetadata: &metadata, 78 } 79 80 var err error 81 test.Middleware, err = New(opts) 82 if err != nil { 83 panic(err) 84 } 85 86 sessionProvider := DefaultSessionProvider(opts) 87 sessionProvider.Name = "ttt" 88 sessionProvider.MaxAge = 7200 * time.Second 89 90 sessionCodec := sessionProvider.Codec.(JWTSessionCodec) 91 sessionCodec.MaxAge = 7200 * time.Second 92 sessionProvider.Codec = sessionCodec 93 94 test.Middleware.Session = sessionProvider 95 96 test.Middleware.ServiceProvider.MetadataURL.Path = "/saml2/metadata" 97 test.Middleware.ServiceProvider.AcsURL.Path = "/saml2/acs" 98 test.Middleware.ServiceProvider.SloURL.Path = "/saml2/slo" 99 100 var tc JWTSessionClaims 101 if err := json.Unmarshal(golden.Get(t, "token.json"), &tc); err != nil { 102 panic(err) 103 } 104 test.expectedSessionCookie, err = sessionProvider.Codec.Encode(tc) 105 if err != nil { 106 panic(err) 107 } 108 109 return &test 110 } 111 112 func (test *MiddlewareTest) makeTrackedRequest(id string) string { 113 codec := test.Middleware.RequestTracker.(CookieRequestTracker).Codec 114 token, err := codec.Encode(TrackedRequest{ 115 Index: "KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6", 116 SAMLRequestID: id, 117 URI: "/frob", 118 }) 119 if err != nil { 120 panic(err) 121 } 122 return token 123 } 124 125 func TestMiddlewareCanProduceMetadata(t *testing.T) { 126 test := NewMiddlewareTest(t) 127 req, _ := http.NewRequest("GET", "/saml2/metadata", nil) 128 129 resp := httptest.NewRecorder() 130 test.Middleware.ServeHTTP(resp, req) 131 assert.Check(t, is.Equal(http.StatusOK, resp.Code)) 132 assert.Check(t, is.Equal("application/samlmetadata+xml", 133 resp.Header().Get("Content-type"))) 134 golden.Assert(t, resp.Body.String(), "expected_middleware_metadata.xml") 135 } 136 137 func TestMiddlewareFourOhFour(t *testing.T) { 138 test := NewMiddlewareTest(t) 139 req, _ := http.NewRequest("GET", "/this/is/not/a/supported/uri", nil) 140 141 resp := httptest.NewRecorder() 142 test.Middleware.ServeHTTP(resp, req) 143 assert.Check(t, is.Equal(http.StatusNotFound, resp.Code)) 144 respBuf, _ := io.ReadAll(resp.Body) 145 assert.Check(t, is.Equal("404 page not found\n", string(respBuf))) 146 } 147 148 func TestMiddlewareRequireAccountNoCreds(t *testing.T) { 149 test := NewMiddlewareTest(t) 150 test.Middleware.ServiceProvider.AcsURL.Scheme = "http" 151 152 handler := test.Middleware.RequireAccount( 153 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 154 panic("not reached") 155 })) 156 157 req, _ := http.NewRequest("GET", "/frob", nil) 158 resp := httptest.NewRecorder() 159 handler.ServeHTTP(resp, req) 160 161 assert.Check(t, is.Equal(http.StatusFound, resp.Code)) 162 assert.Check(t, is.Equal("saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+ 163 test.makeTrackedRequest("id-00020406080a0c0e10121416181a1c1e20222426")+"; Path=/saml2/acs; Max-Age=90; HttpOnly", 164 resp.Header().Get("Set-Cookie"))) 165 166 redirectURL, err := url.Parse(resp.Header().Get("Location")) 167 assert.Check(t, err) 168 decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) 169 assert.Check(t, err) 170 golden.Assert(t, string(decodedRequest), "expected_authn_request.xml") 171 } 172 173 func TestMiddlewareRequireAccountNoCredsSecure(t *testing.T) { 174 test := NewMiddlewareTest(t) 175 176 handler := test.Middleware.RequireAccount( 177 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 178 panic("not reached") 179 })) 180 181 req, _ := http.NewRequest("GET", "/frob", nil) 182 resp := httptest.NewRecorder() 183 handler.ServeHTTP(resp, req) 184 185 assert.Check(t, is.Equal(http.StatusFound, resp.Code)) 186 assert.Check(t, is.Equal("saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-00020406080a0c0e10121416181a1c1e20222426")+"; Path=/saml2/acs; Max-Age=90; HttpOnly; Secure", 187 resp.Header().Get("Set-Cookie"))) 188 189 redirectURL, err := url.Parse(resp.Header().Get("Location")) 190 assert.Check(t, err) 191 decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) 192 assert.Check(t, err) 193 golden.Assert(t, string(decodedRequest), "expected_authn_request_secure.xml") 194 } 195 196 func TestMiddlewareRequireAccountNoCredsPostBinding(t *testing.T) { 197 test := NewMiddlewareTest(t) 198 test.Middleware.ServiceProvider.IDPMetadata.IDPSSODescriptors[0].SingleSignOnServices = test.Middleware.ServiceProvider.IDPMetadata.IDPSSODescriptors[0].SingleSignOnServices[1:2] 199 assert.Check(t, is.Equal("", 200 test.Middleware.ServiceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding))) 201 202 handler := test.Middleware.RequireAccount( 203 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 204 panic("not reached") 205 })) 206 207 req, _ := http.NewRequest("GET", "/frob", nil) 208 resp := httptest.NewRecorder() 209 handler.ServeHTTP(resp, req) 210 211 assert.Check(t, is.Equal(http.StatusOK, resp.Code)) 212 assert.Check(t, is.Equal("saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-00020406080a0c0e10121416181a1c1e20222426")+"; Path=/saml2/acs; Max-Age=90; HttpOnly; Secure", 213 resp.Header().Get("Set-Cookie"))) 214 215 golden.Assert(t, resp.Body.String(), "expected_post_binding_response.html") 216 217 // check that the CSP script hash is set correctly 218 scriptContent := "document.getElementById('SAMLSubmitButton').style.visibility=\"hidden\";document.getElementById('SAMLRequestForm').submit();" 219 scriptSum := sha256.Sum256([]byte(scriptContent)) 220 scriptHash := base64.StdEncoding.EncodeToString(scriptSum[:]) 221 assert.Check(t, is.Equal("default-src; script-src 'sha256-"+scriptHash+"'; reflected-xss block; referrer no-referrer;", 222 resp.Header().Get("Content-Security-Policy"))) 223 224 assert.Check(t, is.Equal("text/html", resp.Header().Get("Content-type"))) 225 } 226 227 func TestMiddlewareRequireAccountCreds(t *testing.T) { 228 test := NewMiddlewareTest(t) 229 handler := test.Middleware.RequireAccount( 230 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 231 genericSession := SessionFromContext(r.Context()) 232 jwtSession := genericSession.(JWTSessionClaims) 233 assert.Check(t, is.Equal("555-5555", jwtSession.Attributes.Get("telephoneNumber"))) 234 assert.Check(t, is.Equal("And I", jwtSession.Attributes.Get("sn"))) 235 assert.Check(t, is.Equal("urn:mace:dir:entitlement:common-lib-terms", jwtSession.Attributes.Get("eduPersonEntitlement"))) 236 assert.Check(t, is.Equal("", jwtSession.Attributes.Get("eduPersonTargetedID"))) 237 assert.Check(t, is.Equal("Me Myself", jwtSession.Attributes.Get("givenName"))) 238 assert.Check(t, is.Equal("Me Myself And I", jwtSession.Attributes.Get("cn"))) 239 assert.Check(t, is.Equal("myself", jwtSession.Attributes.Get("uid"))) 240 assert.Check(t, is.Equal("myself@testshib.org", jwtSession.Attributes.Get("eduPersonPrincipalName"))) 241 assert.Check(t, is.DeepEqual([]string{"Member@testshib.org", "Staff@testshib.org"}, jwtSession.Attributes["eduPersonScopedAffiliation"])) 242 assert.Check(t, is.DeepEqual([]string{"Member", "Staff"}, jwtSession.Attributes["eduPersonAffiliation"])) 243 w.WriteHeader(http.StatusTeapot) 244 })) 245 246 req, _ := http.NewRequest("GET", "/frob", nil) 247 req.Header.Set("Cookie", ""+ 248 "ttt="+test.expectedSessionCookie+"; "+ 249 "Path=/; Max-Age=7200") 250 resp := httptest.NewRecorder() 251 handler.ServeHTTP(resp, req) 252 253 assert.Check(t, is.Equal(http.StatusTeapot, resp.Code)) 254 } 255 256 func TestMiddlewareRequireAccountBadCreds(t *testing.T) { 257 test := NewMiddlewareTest(t) 258 handler := test.Middleware.RequireAccount( 259 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 260 panic("not reached") 261 })) 262 263 req, _ := http.NewRequest("GET", "/frob", nil) 264 req.Header.Set("Cookie", ""+ 265 "ttt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.yejJbiI6Ik1lIE15c2VsZiBBbmQgSSIsImVkdVBlcnNvbkFmZmlsaWF0aW9uIjoiU3RhZmYiLCJlZHVQZXJzb25FbnRpdGxlbWVudCI6InVybjptYWNlOmRpcjplbnRpdGxlbWVudDpjb21tb24tbGliLXRlcm1zIiwiZWR1UGVyc29uUHJpbmNpcGFsTmFtZSI6Im15c2VsZkB0ZXN0c2hpYi5vcmciLCJlZHVQZXJzb25TY29wZWRBZmZpbGlhdGlvbiI6IlN0YWZmQHRlc3RzaGliLm9yZyIsImVkdVBlcnNvblRhcmdldGVkSUQiOiIiLCJleHAiOjE0NDg5Mzg2MjksImdpdmVuTmFtZSI6Ik1lIE15c2VsZiIsInNuIjoiQW5kIEkiLCJ0ZWxlcGhvbmVOdW1iZXIiOiI1NTUtNTU1NSIsInVpZCI6Im15c2VsZiJ9.SqeTkbGG35oFj_9H-d9oVdV-Hb7Vqam6LvZLcmia7FY; "+ 266 "Path=/; Max-Age=7200; Secure") 267 resp := httptest.NewRecorder() 268 handler.ServeHTTP(resp, req) 269 270 assert.Check(t, is.Equal(http.StatusFound, resp.Code)) 271 272 assert.Check(t, is.Equal("saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-00020406080a0c0e10121416181a1c1e20222426")+"; Path=/saml2/acs; Max-Age=90; HttpOnly; Secure", 273 resp.Header().Get("Set-Cookie"))) 274 275 redirectURL, err := url.Parse(resp.Header().Get("Location")) 276 assert.Check(t, err) 277 decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) 278 assert.Check(t, err) 279 golden.Assert(t, string(decodedRequest), "expected_authn_request_secure.xml") 280 } 281 282 func TestMiddlewareRequireAccountExpiredCreds(t *testing.T) { 283 test := NewMiddlewareTest(t) 284 jwt.TimeFunc = func() time.Time { 285 rv, _ := time.Parse("Mon Jan 2 15:04:05 UTC 2006", "Mon Dec 1 01:31:21 UTC 2115") 286 return rv 287 } 288 289 handler := test.Middleware.RequireAccount( 290 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 291 panic("not reached") 292 })) 293 294 req, _ := http.NewRequest("GET", "/frob", nil) 295 req.Header.Set("Cookie", ""+ 296 "ttt="+test.expectedSessionCookie+"; "+ 297 "Path=/; Max-Age=7200") 298 resp := httptest.NewRecorder() 299 handler.ServeHTTP(resp, req) 300 301 assert.Check(t, is.Equal(http.StatusFound, resp.Code)) 302 assert.Check(t, is.Equal("saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-00020406080a0c0e10121416181a1c1e20222426")+"; Path=/saml2/acs; Max-Age=90; HttpOnly; Secure", 303 resp.Header().Get("Set-Cookie"))) 304 305 redirectURL, err := url.Parse(resp.Header().Get("Location")) 306 assert.Check(t, err) 307 decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL) 308 assert.Check(t, err) 309 golden.Assert(t, string(decodedRequest), "expected_authn_request_secure.xml") 310 } 311 312 func TestMiddlewareRequireAccountPanicOnRequestToACS(t *testing.T) { 313 test := NewMiddlewareTest(t) 314 handler := test.Middleware.RequireAccount( 315 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 316 panic("not reached") 317 })) 318 319 req, _ := http.NewRequest("POST", "https://15661444.ngrok.io/saml2/acs", nil) 320 resp := httptest.NewRecorder() 321 322 assert.Check(t, is.Panics(func() { handler.ServeHTTP(resp, req) })) 323 } 324 325 func TestMiddlewareRequireAttribute(t *testing.T) { 326 test := NewMiddlewareTest(t) 327 handler := test.Middleware.RequireAccount( 328 RequireAttribute("eduPersonAffiliation", "Staff")( 329 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 330 w.WriteHeader(http.StatusTeapot) 331 }))) 332 333 req, _ := http.NewRequest("GET", "/frob", nil) 334 req.Header.Set("Cookie", ""+ 335 "ttt="+test.expectedSessionCookie+"; "+ 336 "Path=/; Max-Age=7200") 337 resp := httptest.NewRecorder() 338 handler.ServeHTTP(resp, req) 339 340 assert.Check(t, is.Equal(http.StatusTeapot, resp.Code)) 341 } 342 343 func TestMiddlewareRequireAttributeWrongValue(t *testing.T) { 344 test := NewMiddlewareTest(t) 345 handler := test.Middleware.RequireAccount( 346 RequireAttribute("eduPersonAffiliation", "DomainAdmins")( 347 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 348 panic("not reached") 349 }))) 350 351 req, _ := http.NewRequest("GET", "/frob", nil) 352 req.Header.Set("Cookie", ""+ 353 "ttt="+test.expectedSessionCookie+"; "+ 354 "Path=/; Max-Age=7200") 355 resp := httptest.NewRecorder() 356 handler.ServeHTTP(resp, req) 357 358 assert.Check(t, is.Equal(http.StatusForbidden, resp.Code)) 359 } 360 361 func TestMiddlewareRequireAttributeNotPresent(t *testing.T) { 362 test := NewMiddlewareTest(t) 363 handler := test.Middleware.RequireAccount( 364 RequireAttribute("valueThatDoesntExist", "doesntMatter")( 365 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 366 panic("not reached") 367 }))) 368 369 req, _ := http.NewRequest("GET", "/frob", nil) 370 req.Header.Set("Cookie", ""+ 371 "ttt="+test.expectedSessionCookie+"; "+ 372 "Path=/; Max-Age=7200") 373 resp := httptest.NewRecorder() 374 handler.ServeHTTP(resp, req) 375 376 assert.Check(t, is.Equal(http.StatusForbidden, resp.Code)) 377 } 378 379 func TestMiddlewareRequireAttributeMissingAccount(t *testing.T) { 380 test := NewMiddlewareTest(t) 381 handler := RequireAttribute("eduPersonAffiliation", "DomainAdmins")( 382 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 383 panic("not reached") 384 })) 385 386 req, _ := http.NewRequest("GET", "/frob", nil) 387 req.Header.Set("Cookie", ""+ 388 "ttt="+test.expectedSessionCookie+"; "+ 389 "Path=/; Max-Age=7200") 390 resp := httptest.NewRecorder() 391 handler.ServeHTTP(resp, req) 392 393 assert.Check(t, is.Equal(http.StatusForbidden, resp.Code)) 394 } 395 396 func TestMiddlewareCanParseResponse(t *testing.T) { 397 test := NewMiddlewareTest(t) 398 v := &url.Values{} 399 v.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) 400 v.Set("RelayState", "KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6") 401 req, _ := http.NewRequest("POST", "/saml2/acs", bytes.NewReader([]byte(v.Encode()))) 402 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 403 req.Header.Set("Cookie", ""+ 404 "saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-9e61753d64e928af5a7a341a97f420c9")) 405 406 resp := httptest.NewRecorder() 407 test.Middleware.ServeHTTP(resp, req) 408 assert.Check(t, is.Equal(http.StatusFound, resp.Code)) 409 410 assert.Check(t, is.Equal("/frob", resp.Header().Get("Location"))) 411 assert.Check(t, is.DeepEqual([]string{ 412 "saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6=; Domain=15661444.ngrok.io; Expires=Thu, 01 Jan 1970 00:00:01 GMT", 413 "ttt=" + test.expectedSessionCookie + "; " + 414 "Path=/; Domain=15661444.ngrok.io; Max-Age=7200; HttpOnly; Secure"}, 415 resp.Header()["Set-Cookie"])) 416 } 417 418 func TestMiddlewareDefaultCookieDomainIPv4(t *testing.T) { 419 test := NewMiddlewareTest(t) 420 ipv4Loopback := net.IP{127, 0, 0, 1} 421 422 sp := DefaultSessionProvider(Options{ 423 URL: mustParseURL("https://" + net.JoinHostPort(ipv4Loopback.String(), "54321")), 424 Key: test.Key, 425 }) 426 427 req, _ := http.NewRequest("GET", "/", nil) 428 resp := httptest.NewRecorder() 429 assert.Check(t, sp.CreateSession(resp, req, &saml.Assertion{})) 430 431 assert.Check(t, 432 strings.Contains(resp.Header().Get("Set-Cookie"), "Domain=127.0.0.1;"), 433 "Cookie domain must not contain a port or the cookie cannot be set properly: %v", resp.Header().Get("Set-Cookie")) 434 } 435 436 func TestMiddlewareDefaultCookieDomainIPv6(t *testing.T) { 437 t.Skip("fails") // TODO(ross): fix this test 438 439 test := NewMiddlewareTest(t) 440 441 sp := DefaultSessionProvider(Options{ 442 URL: mustParseURL("https://" + net.JoinHostPort(net.IPv6loopback.String(), "54321")), 443 Key: test.Key, 444 }) 445 446 req, _ := http.NewRequest("GET", "/", nil) 447 resp := httptest.NewRecorder() 448 assert.Check(t, sp.CreateSession(resp, req, &saml.Assertion{})) 449 450 assert.Check(t, 451 strings.Contains(resp.Header().Get("Set-Cookie"), "Domain=::1;"), 452 "Cookie domain must not contain a port or the cookie cannot be set properly: %v", resp.Header().Get("Set-Cookie")) 453 } 454 455 func TestMiddlewareRejectsInvalidRelayState(t *testing.T) { 456 test := NewMiddlewareTest(t) 457 458 test.Middleware.OnError = func(w http.ResponseWriter, r *http.Request, err error) { 459 assert.Check(t, is.Error(err, http.ErrNoCookie.Error())) 460 http.Error(w, "forbidden", http.StatusTeapot) 461 } 462 463 v := &url.Values{} 464 v.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) 465 v.Set("RelayState", "ICIkJigqLC4wMjQ2ODo8PkBCREZISkxOUFJUVlhaXF5gYmRmaGpsbnBy") 466 req, _ := http.NewRequest("POST", "/saml2/acs", bytes.NewReader([]byte(v.Encode()))) 467 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 468 req.Header.Set("Cookie", ""+ 469 "saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-9e61753d64e928af5a7a341a97f420c9")) 470 471 resp := httptest.NewRecorder() 472 test.Middleware.ServeHTTP(resp, req) 473 assert.Check(t, is.Equal(http.StatusTeapot, resp.Code)) 474 assert.Check(t, is.Equal("", resp.Header().Get("Location"))) 475 assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie"))) 476 } 477 478 func TestMiddlewareRejectsInvalidCookie(t *testing.T) { 479 test := NewMiddlewareTest(t) 480 481 test.Middleware.OnError = func(w http.ResponseWriter, r *http.Request, err error) { 482 assert.Check(t, is.Error(err, "Authentication failed")) 483 http.Error(w, "forbidden", http.StatusTeapot) 484 } 485 486 v := &url.Values{} 487 v.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse)) 488 v.Set("RelayState", "KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6") 489 req, _ := http.NewRequest("POST", "/saml2/acs", bytes.NewReader([]byte(v.Encode()))) 490 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 491 req.Header.Set("Cookie", ""+ 492 "saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("wrong")) 493 494 resp := httptest.NewRecorder() 495 test.Middleware.ServeHTTP(resp, req) 496 assert.Check(t, is.Equal(http.StatusTeapot, resp.Code)) 497 assert.Check(t, is.Equal("", resp.Header().Get("Location"))) 498 assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie"))) 499 } 500 501 func TestMiddlewareHandlesInvalidResponse(t *testing.T) { 502 test := NewMiddlewareTest(t) 503 v := &url.Values{} 504 v.Set("SAMLResponse", "this is not a valid saml response") 505 v.Set("RelayState", "KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6") 506 507 req, _ := http.NewRequest("POST", "/saml2/acs", bytes.NewReader([]byte(v.Encode()))) 508 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 509 req.Header.Set("Cookie", ""+ 510 "saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("wrong")) 511 512 resp := httptest.NewRecorder() 513 test.Middleware.ServeHTTP(resp, req) 514 515 // note: it is important that when presented with an invalid request, 516 // the ACS handles DOES NOT reveal detailed error information in the 517 // HTTP response. 518 assert.Check(t, is.Equal(http.StatusForbidden, resp.Code)) 519 respBody, _ := io.ReadAll(resp.Body) 520 assert.Check(t, is.Equal("Forbidden\n", string(respBody))) 521 assert.Check(t, is.Equal("", resp.Header().Get("Location"))) 522 assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie"))) 523 }