github.com/crewjam/saml@v0.4.14/samlsp/middleware_test.go (about)

     1  package samlsp
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rsa"
     6  	"crypto/sha256"
     7  	"crypto/x509"
     8  	"encoding/base64"
     9  	"encoding/json"
    10  	"encoding/xml"
    11  	"io"
    12  	"net"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"net/url"
    16  	"strings"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/golang-jwt/jwt/v4"
    21  	dsig "github.com/russellhaering/goxmldsig"
    22  	"gotest.tools/assert"
    23  	is "gotest.tools/assert/cmp"
    24  	"gotest.tools/golden"
    25  
    26  	"github.com/crewjam/saml"
    27  	"github.com/crewjam/saml/testsaml"
    28  )
    29  
    30  type MiddlewareTest struct {
    31  	AuthnRequest          []byte
    32  	SamlResponse          []byte
    33  	Key                   *rsa.PrivateKey
    34  	Certificate           *x509.Certificate
    35  	IDPMetadata           []byte
    36  	Middleware            *Middleware
    37  	expectedSessionCookie string
    38  }
    39  
    40  type testRandomReader struct {
    41  	Next byte
    42  }
    43  
    44  func (tr *testRandomReader) Read(p []byte) (n int, err error) {
    45  	for i := 0; i < len(p); i++ {
    46  		p[i] = tr.Next
    47  		tr.Next += 2
    48  	}
    49  	return len(p), nil
    50  }
    51  
    52  func NewMiddlewareTest(t *testing.T) *MiddlewareTest {
    53  	test := MiddlewareTest{}
    54  	saml.TimeNow = func() time.Time {
    55  		rv, _ := time.Parse("Mon Jan 2 15:04:05.999999999 MST 2006", "Mon Dec 1 01:57:09.123456789 UTC 2015")
    56  		return rv
    57  	}
    58  	jwt.TimeFunc = saml.TimeNow
    59  	saml.Clock = dsig.NewFakeClockAt(saml.TimeNow())
    60  	saml.RandReader = &testRandomReader{}
    61  
    62  	test.AuthnRequest = golden.Get(t, "authn_request.url")
    63  	test.SamlResponse = golden.Get(t, "saml_response.xml")
    64  	test.Key = mustParsePrivateKey(golden.Get(t, "key.pem")).(*rsa.PrivateKey)
    65  	test.Certificate = mustParseCertificate(golden.Get(t, "cert.pem"))
    66  	test.IDPMetadata = golden.Get(t, "idp_metadata.xml")
    67  
    68  	var metadata saml.EntityDescriptor
    69  	if err := xml.Unmarshal(test.IDPMetadata, &metadata); err != nil {
    70  		panic(err)
    71  	}
    72  
    73  	opts := Options{
    74  		URL:         mustParseURL("https://15661444.ngrok.io/"),
    75  		Key:         test.Key,
    76  		Certificate: test.Certificate,
    77  		IDPMetadata: &metadata,
    78  	}
    79  
    80  	var err error
    81  	test.Middleware, err = New(opts)
    82  	if err != nil {
    83  		panic(err)
    84  	}
    85  
    86  	sessionProvider := DefaultSessionProvider(opts)
    87  	sessionProvider.Name = "ttt"
    88  	sessionProvider.MaxAge = 7200 * time.Second
    89  
    90  	sessionCodec := sessionProvider.Codec.(JWTSessionCodec)
    91  	sessionCodec.MaxAge = 7200 * time.Second
    92  	sessionProvider.Codec = sessionCodec
    93  
    94  	test.Middleware.Session = sessionProvider
    95  
    96  	test.Middleware.ServiceProvider.MetadataURL.Path = "/saml2/metadata"
    97  	test.Middleware.ServiceProvider.AcsURL.Path = "/saml2/acs"
    98  	test.Middleware.ServiceProvider.SloURL.Path = "/saml2/slo"
    99  
   100  	var tc JWTSessionClaims
   101  	if err := json.Unmarshal(golden.Get(t, "token.json"), &tc); err != nil {
   102  		panic(err)
   103  	}
   104  	test.expectedSessionCookie, err = sessionProvider.Codec.Encode(tc)
   105  	if err != nil {
   106  		panic(err)
   107  	}
   108  
   109  	return &test
   110  }
   111  
   112  func (test *MiddlewareTest) makeTrackedRequest(id string) string {
   113  	codec := test.Middleware.RequestTracker.(CookieRequestTracker).Codec
   114  	token, err := codec.Encode(TrackedRequest{
   115  		Index:         "KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6",
   116  		SAMLRequestID: id,
   117  		URI:           "/frob",
   118  	})
   119  	if err != nil {
   120  		panic(err)
   121  	}
   122  	return token
   123  }
   124  
   125  func TestMiddlewareCanProduceMetadata(t *testing.T) {
   126  	test := NewMiddlewareTest(t)
   127  	req, _ := http.NewRequest("GET", "/saml2/metadata", nil)
   128  
   129  	resp := httptest.NewRecorder()
   130  	test.Middleware.ServeHTTP(resp, req)
   131  	assert.Check(t, is.Equal(http.StatusOK, resp.Code))
   132  	assert.Check(t, is.Equal("application/samlmetadata+xml",
   133  		resp.Header().Get("Content-type")))
   134  	golden.Assert(t, resp.Body.String(), "expected_middleware_metadata.xml")
   135  }
   136  
   137  func TestMiddlewareFourOhFour(t *testing.T) {
   138  	test := NewMiddlewareTest(t)
   139  	req, _ := http.NewRequest("GET", "/this/is/not/a/supported/uri", nil)
   140  
   141  	resp := httptest.NewRecorder()
   142  	test.Middleware.ServeHTTP(resp, req)
   143  	assert.Check(t, is.Equal(http.StatusNotFound, resp.Code))
   144  	respBuf, _ := io.ReadAll(resp.Body)
   145  	assert.Check(t, is.Equal("404 page not found\n", string(respBuf)))
   146  }
   147  
   148  func TestMiddlewareRequireAccountNoCreds(t *testing.T) {
   149  	test := NewMiddlewareTest(t)
   150  	test.Middleware.ServiceProvider.AcsURL.Scheme = "http"
   151  
   152  	handler := test.Middleware.RequireAccount(
   153  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   154  			panic("not reached")
   155  		}))
   156  
   157  	req, _ := http.NewRequest("GET", "/frob", nil)
   158  	resp := httptest.NewRecorder()
   159  	handler.ServeHTTP(resp, req)
   160  
   161  	assert.Check(t, is.Equal(http.StatusFound, resp.Code))
   162  	assert.Check(t, is.Equal("saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+
   163  		test.makeTrackedRequest("id-00020406080a0c0e10121416181a1c1e20222426")+"; Path=/saml2/acs; Max-Age=90; HttpOnly",
   164  		resp.Header().Get("Set-Cookie")))
   165  
   166  	redirectURL, err := url.Parse(resp.Header().Get("Location"))
   167  	assert.Check(t, err)
   168  	decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL)
   169  	assert.Check(t, err)
   170  	golden.Assert(t, string(decodedRequest), "expected_authn_request.xml")
   171  }
   172  
   173  func TestMiddlewareRequireAccountNoCredsSecure(t *testing.T) {
   174  	test := NewMiddlewareTest(t)
   175  
   176  	handler := test.Middleware.RequireAccount(
   177  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   178  			panic("not reached")
   179  		}))
   180  
   181  	req, _ := http.NewRequest("GET", "/frob", nil)
   182  	resp := httptest.NewRecorder()
   183  	handler.ServeHTTP(resp, req)
   184  
   185  	assert.Check(t, is.Equal(http.StatusFound, resp.Code))
   186  	assert.Check(t, is.Equal("saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-00020406080a0c0e10121416181a1c1e20222426")+"; Path=/saml2/acs; Max-Age=90; HttpOnly; Secure",
   187  		resp.Header().Get("Set-Cookie")))
   188  
   189  	redirectURL, err := url.Parse(resp.Header().Get("Location"))
   190  	assert.Check(t, err)
   191  	decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL)
   192  	assert.Check(t, err)
   193  	golden.Assert(t, string(decodedRequest), "expected_authn_request_secure.xml")
   194  }
   195  
   196  func TestMiddlewareRequireAccountNoCredsPostBinding(t *testing.T) {
   197  	test := NewMiddlewareTest(t)
   198  	test.Middleware.ServiceProvider.IDPMetadata.IDPSSODescriptors[0].SingleSignOnServices = test.Middleware.ServiceProvider.IDPMetadata.IDPSSODescriptors[0].SingleSignOnServices[1:2]
   199  	assert.Check(t, is.Equal("",
   200  		test.Middleware.ServiceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding)))
   201  
   202  	handler := test.Middleware.RequireAccount(
   203  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   204  			panic("not reached")
   205  		}))
   206  
   207  	req, _ := http.NewRequest("GET", "/frob", nil)
   208  	resp := httptest.NewRecorder()
   209  	handler.ServeHTTP(resp, req)
   210  
   211  	assert.Check(t, is.Equal(http.StatusOK, resp.Code))
   212  	assert.Check(t, is.Equal("saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-00020406080a0c0e10121416181a1c1e20222426")+"; Path=/saml2/acs; Max-Age=90; HttpOnly; Secure",
   213  		resp.Header().Get("Set-Cookie")))
   214  
   215  	golden.Assert(t, resp.Body.String(), "expected_post_binding_response.html")
   216  
   217  	// check that the CSP script hash is set correctly
   218  	scriptContent := "document.getElementById('SAMLSubmitButton').style.visibility=\"hidden\";document.getElementById('SAMLRequestForm').submit();"
   219  	scriptSum := sha256.Sum256([]byte(scriptContent))
   220  	scriptHash := base64.StdEncoding.EncodeToString(scriptSum[:])
   221  	assert.Check(t, is.Equal("default-src; script-src 'sha256-"+scriptHash+"'; reflected-xss block; referrer no-referrer;",
   222  		resp.Header().Get("Content-Security-Policy")))
   223  
   224  	assert.Check(t, is.Equal("text/html", resp.Header().Get("Content-type")))
   225  }
   226  
   227  func TestMiddlewareRequireAccountCreds(t *testing.T) {
   228  	test := NewMiddlewareTest(t)
   229  	handler := test.Middleware.RequireAccount(
   230  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   231  			genericSession := SessionFromContext(r.Context())
   232  			jwtSession := genericSession.(JWTSessionClaims)
   233  			assert.Check(t, is.Equal("555-5555", jwtSession.Attributes.Get("telephoneNumber")))
   234  			assert.Check(t, is.Equal("And I", jwtSession.Attributes.Get("sn")))
   235  			assert.Check(t, is.Equal("urn:mace:dir:entitlement:common-lib-terms", jwtSession.Attributes.Get("eduPersonEntitlement")))
   236  			assert.Check(t, is.Equal("", jwtSession.Attributes.Get("eduPersonTargetedID")))
   237  			assert.Check(t, is.Equal("Me Myself", jwtSession.Attributes.Get("givenName")))
   238  			assert.Check(t, is.Equal("Me Myself And I", jwtSession.Attributes.Get("cn")))
   239  			assert.Check(t, is.Equal("myself", jwtSession.Attributes.Get("uid")))
   240  			assert.Check(t, is.Equal("myself@testshib.org", jwtSession.Attributes.Get("eduPersonPrincipalName")))
   241  			assert.Check(t, is.DeepEqual([]string{"Member@testshib.org", "Staff@testshib.org"}, jwtSession.Attributes["eduPersonScopedAffiliation"]))
   242  			assert.Check(t, is.DeepEqual([]string{"Member", "Staff"}, jwtSession.Attributes["eduPersonAffiliation"]))
   243  			w.WriteHeader(http.StatusTeapot)
   244  		}))
   245  
   246  	req, _ := http.NewRequest("GET", "/frob", nil)
   247  	req.Header.Set("Cookie", ""+
   248  		"ttt="+test.expectedSessionCookie+"; "+
   249  		"Path=/; Max-Age=7200")
   250  	resp := httptest.NewRecorder()
   251  	handler.ServeHTTP(resp, req)
   252  
   253  	assert.Check(t, is.Equal(http.StatusTeapot, resp.Code))
   254  }
   255  
   256  func TestMiddlewareRequireAccountBadCreds(t *testing.T) {
   257  	test := NewMiddlewareTest(t)
   258  	handler := test.Middleware.RequireAccount(
   259  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   260  			panic("not reached")
   261  		}))
   262  
   263  	req, _ := http.NewRequest("GET", "/frob", nil)
   264  	req.Header.Set("Cookie", ""+
   265  		"ttt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.yejJbiI6Ik1lIE15c2VsZiBBbmQgSSIsImVkdVBlcnNvbkFmZmlsaWF0aW9uIjoiU3RhZmYiLCJlZHVQZXJzb25FbnRpdGxlbWVudCI6InVybjptYWNlOmRpcjplbnRpdGxlbWVudDpjb21tb24tbGliLXRlcm1zIiwiZWR1UGVyc29uUHJpbmNpcGFsTmFtZSI6Im15c2VsZkB0ZXN0c2hpYi5vcmciLCJlZHVQZXJzb25TY29wZWRBZmZpbGlhdGlvbiI6IlN0YWZmQHRlc3RzaGliLm9yZyIsImVkdVBlcnNvblRhcmdldGVkSUQiOiIiLCJleHAiOjE0NDg5Mzg2MjksImdpdmVuTmFtZSI6Ik1lIE15c2VsZiIsInNuIjoiQW5kIEkiLCJ0ZWxlcGhvbmVOdW1iZXIiOiI1NTUtNTU1NSIsInVpZCI6Im15c2VsZiJ9.SqeTkbGG35oFj_9H-d9oVdV-Hb7Vqam6LvZLcmia7FY; "+
   266  		"Path=/; Max-Age=7200; Secure")
   267  	resp := httptest.NewRecorder()
   268  	handler.ServeHTTP(resp, req)
   269  
   270  	assert.Check(t, is.Equal(http.StatusFound, resp.Code))
   271  
   272  	assert.Check(t, is.Equal("saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-00020406080a0c0e10121416181a1c1e20222426")+"; Path=/saml2/acs; Max-Age=90; HttpOnly; Secure",
   273  		resp.Header().Get("Set-Cookie")))
   274  
   275  	redirectURL, err := url.Parse(resp.Header().Get("Location"))
   276  	assert.Check(t, err)
   277  	decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL)
   278  	assert.Check(t, err)
   279  	golden.Assert(t, string(decodedRequest), "expected_authn_request_secure.xml")
   280  }
   281  
   282  func TestMiddlewareRequireAccountExpiredCreds(t *testing.T) {
   283  	test := NewMiddlewareTest(t)
   284  	jwt.TimeFunc = func() time.Time {
   285  		rv, _ := time.Parse("Mon Jan 2 15:04:05 UTC 2006", "Mon Dec 1 01:31:21 UTC 2115")
   286  		return rv
   287  	}
   288  
   289  	handler := test.Middleware.RequireAccount(
   290  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   291  			panic("not reached")
   292  		}))
   293  
   294  	req, _ := http.NewRequest("GET", "/frob", nil)
   295  	req.Header.Set("Cookie", ""+
   296  		"ttt="+test.expectedSessionCookie+"; "+
   297  		"Path=/; Max-Age=7200")
   298  	resp := httptest.NewRecorder()
   299  	handler.ServeHTTP(resp, req)
   300  
   301  	assert.Check(t, is.Equal(http.StatusFound, resp.Code))
   302  	assert.Check(t, is.Equal("saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-00020406080a0c0e10121416181a1c1e20222426")+"; Path=/saml2/acs; Max-Age=90; HttpOnly; Secure",
   303  		resp.Header().Get("Set-Cookie")))
   304  
   305  	redirectURL, err := url.Parse(resp.Header().Get("Location"))
   306  	assert.Check(t, err)
   307  	decodedRequest, err := testsaml.ParseRedirectRequest(redirectURL)
   308  	assert.Check(t, err)
   309  	golden.Assert(t, string(decodedRequest), "expected_authn_request_secure.xml")
   310  }
   311  
   312  func TestMiddlewareRequireAccountPanicOnRequestToACS(t *testing.T) {
   313  	test := NewMiddlewareTest(t)
   314  	handler := test.Middleware.RequireAccount(
   315  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   316  			panic("not reached")
   317  		}))
   318  
   319  	req, _ := http.NewRequest("POST", "https://15661444.ngrok.io/saml2/acs", nil)
   320  	resp := httptest.NewRecorder()
   321  
   322  	assert.Check(t, is.Panics(func() { handler.ServeHTTP(resp, req) }))
   323  }
   324  
   325  func TestMiddlewareRequireAttribute(t *testing.T) {
   326  	test := NewMiddlewareTest(t)
   327  	handler := test.Middleware.RequireAccount(
   328  		RequireAttribute("eduPersonAffiliation", "Staff")(
   329  			http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   330  				w.WriteHeader(http.StatusTeapot)
   331  			})))
   332  
   333  	req, _ := http.NewRequest("GET", "/frob", nil)
   334  	req.Header.Set("Cookie", ""+
   335  		"ttt="+test.expectedSessionCookie+"; "+
   336  		"Path=/; Max-Age=7200")
   337  	resp := httptest.NewRecorder()
   338  	handler.ServeHTTP(resp, req)
   339  
   340  	assert.Check(t, is.Equal(http.StatusTeapot, resp.Code))
   341  }
   342  
   343  func TestMiddlewareRequireAttributeWrongValue(t *testing.T) {
   344  	test := NewMiddlewareTest(t)
   345  	handler := test.Middleware.RequireAccount(
   346  		RequireAttribute("eduPersonAffiliation", "DomainAdmins")(
   347  			http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   348  				panic("not reached")
   349  			})))
   350  
   351  	req, _ := http.NewRequest("GET", "/frob", nil)
   352  	req.Header.Set("Cookie", ""+
   353  		"ttt="+test.expectedSessionCookie+"; "+
   354  		"Path=/; Max-Age=7200")
   355  	resp := httptest.NewRecorder()
   356  	handler.ServeHTTP(resp, req)
   357  
   358  	assert.Check(t, is.Equal(http.StatusForbidden, resp.Code))
   359  }
   360  
   361  func TestMiddlewareRequireAttributeNotPresent(t *testing.T) {
   362  	test := NewMiddlewareTest(t)
   363  	handler := test.Middleware.RequireAccount(
   364  		RequireAttribute("valueThatDoesntExist", "doesntMatter")(
   365  			http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   366  				panic("not reached")
   367  			})))
   368  
   369  	req, _ := http.NewRequest("GET", "/frob", nil)
   370  	req.Header.Set("Cookie", ""+
   371  		"ttt="+test.expectedSessionCookie+"; "+
   372  		"Path=/; Max-Age=7200")
   373  	resp := httptest.NewRecorder()
   374  	handler.ServeHTTP(resp, req)
   375  
   376  	assert.Check(t, is.Equal(http.StatusForbidden, resp.Code))
   377  }
   378  
   379  func TestMiddlewareRequireAttributeMissingAccount(t *testing.T) {
   380  	test := NewMiddlewareTest(t)
   381  	handler := RequireAttribute("eduPersonAffiliation", "DomainAdmins")(
   382  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   383  			panic("not reached")
   384  		}))
   385  
   386  	req, _ := http.NewRequest("GET", "/frob", nil)
   387  	req.Header.Set("Cookie", ""+
   388  		"ttt="+test.expectedSessionCookie+"; "+
   389  		"Path=/; Max-Age=7200")
   390  	resp := httptest.NewRecorder()
   391  	handler.ServeHTTP(resp, req)
   392  
   393  	assert.Check(t, is.Equal(http.StatusForbidden, resp.Code))
   394  }
   395  
   396  func TestMiddlewareCanParseResponse(t *testing.T) {
   397  	test := NewMiddlewareTest(t)
   398  	v := &url.Values{}
   399  	v.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse))
   400  	v.Set("RelayState", "KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6")
   401  	req, _ := http.NewRequest("POST", "/saml2/acs", bytes.NewReader([]byte(v.Encode())))
   402  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   403  	req.Header.Set("Cookie", ""+
   404  		"saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-9e61753d64e928af5a7a341a97f420c9"))
   405  
   406  	resp := httptest.NewRecorder()
   407  	test.Middleware.ServeHTTP(resp, req)
   408  	assert.Check(t, is.Equal(http.StatusFound, resp.Code))
   409  
   410  	assert.Check(t, is.Equal("/frob", resp.Header().Get("Location")))
   411  	assert.Check(t, is.DeepEqual([]string{
   412  		"saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6=; Domain=15661444.ngrok.io; Expires=Thu, 01 Jan 1970 00:00:01 GMT",
   413  		"ttt=" + test.expectedSessionCookie + "; " +
   414  			"Path=/; Domain=15661444.ngrok.io; Max-Age=7200; HttpOnly; Secure"},
   415  		resp.Header()["Set-Cookie"]))
   416  }
   417  
   418  func TestMiddlewareDefaultCookieDomainIPv4(t *testing.T) {
   419  	test := NewMiddlewareTest(t)
   420  	ipv4Loopback := net.IP{127, 0, 0, 1}
   421  
   422  	sp := DefaultSessionProvider(Options{
   423  		URL: mustParseURL("https://" + net.JoinHostPort(ipv4Loopback.String(), "54321")),
   424  		Key: test.Key,
   425  	})
   426  
   427  	req, _ := http.NewRequest("GET", "/", nil)
   428  	resp := httptest.NewRecorder()
   429  	assert.Check(t, sp.CreateSession(resp, req, &saml.Assertion{}))
   430  
   431  	assert.Check(t,
   432  		strings.Contains(resp.Header().Get("Set-Cookie"), "Domain=127.0.0.1;"),
   433  		"Cookie domain must not contain a port or the cookie cannot be set properly: %v", resp.Header().Get("Set-Cookie"))
   434  }
   435  
   436  func TestMiddlewareDefaultCookieDomainIPv6(t *testing.T) {
   437  	t.Skip("fails") // TODO(ross): fix this test
   438  
   439  	test := NewMiddlewareTest(t)
   440  
   441  	sp := DefaultSessionProvider(Options{
   442  		URL: mustParseURL("https://" + net.JoinHostPort(net.IPv6loopback.String(), "54321")),
   443  		Key: test.Key,
   444  	})
   445  
   446  	req, _ := http.NewRequest("GET", "/", nil)
   447  	resp := httptest.NewRecorder()
   448  	assert.Check(t, sp.CreateSession(resp, req, &saml.Assertion{}))
   449  
   450  	assert.Check(t,
   451  		strings.Contains(resp.Header().Get("Set-Cookie"), "Domain=::1;"),
   452  		"Cookie domain must not contain a port or the cookie cannot be set properly: %v", resp.Header().Get("Set-Cookie"))
   453  }
   454  
   455  func TestMiddlewareRejectsInvalidRelayState(t *testing.T) {
   456  	test := NewMiddlewareTest(t)
   457  
   458  	test.Middleware.OnError = func(w http.ResponseWriter, r *http.Request, err error) {
   459  		assert.Check(t, is.Error(err, http.ErrNoCookie.Error()))
   460  		http.Error(w, "forbidden", http.StatusTeapot)
   461  	}
   462  
   463  	v := &url.Values{}
   464  	v.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse))
   465  	v.Set("RelayState", "ICIkJigqLC4wMjQ2ODo8PkBCREZISkxOUFJUVlhaXF5gYmRmaGpsbnBy")
   466  	req, _ := http.NewRequest("POST", "/saml2/acs", bytes.NewReader([]byte(v.Encode())))
   467  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   468  	req.Header.Set("Cookie", ""+
   469  		"saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("id-9e61753d64e928af5a7a341a97f420c9"))
   470  
   471  	resp := httptest.NewRecorder()
   472  	test.Middleware.ServeHTTP(resp, req)
   473  	assert.Check(t, is.Equal(http.StatusTeapot, resp.Code))
   474  	assert.Check(t, is.Equal("", resp.Header().Get("Location")))
   475  	assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie")))
   476  }
   477  
   478  func TestMiddlewareRejectsInvalidCookie(t *testing.T) {
   479  	test := NewMiddlewareTest(t)
   480  
   481  	test.Middleware.OnError = func(w http.ResponseWriter, r *http.Request, err error) {
   482  		assert.Check(t, is.Error(err, "Authentication failed"))
   483  		http.Error(w, "forbidden", http.StatusTeapot)
   484  	}
   485  
   486  	v := &url.Values{}
   487  	v.Set("SAMLResponse", base64.StdEncoding.EncodeToString(test.SamlResponse))
   488  	v.Set("RelayState", "KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6")
   489  	req, _ := http.NewRequest("POST", "/saml2/acs", bytes.NewReader([]byte(v.Encode())))
   490  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   491  	req.Header.Set("Cookie", ""+
   492  		"saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("wrong"))
   493  
   494  	resp := httptest.NewRecorder()
   495  	test.Middleware.ServeHTTP(resp, req)
   496  	assert.Check(t, is.Equal(http.StatusTeapot, resp.Code))
   497  	assert.Check(t, is.Equal("", resp.Header().Get("Location")))
   498  	assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie")))
   499  }
   500  
   501  func TestMiddlewareHandlesInvalidResponse(t *testing.T) {
   502  	test := NewMiddlewareTest(t)
   503  	v := &url.Values{}
   504  	v.Set("SAMLResponse", "this is not a valid saml response")
   505  	v.Set("RelayState", "KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6")
   506  
   507  	req, _ := http.NewRequest("POST", "/saml2/acs", bytes.NewReader([]byte(v.Encode())))
   508  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   509  	req.Header.Set("Cookie", ""+
   510  		"saml_KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6="+test.makeTrackedRequest("wrong"))
   511  
   512  	resp := httptest.NewRecorder()
   513  	test.Middleware.ServeHTTP(resp, req)
   514  
   515  	// note: it is important that when presented with an invalid request,
   516  	// the ACS handles DOES NOT reveal detailed error information in the
   517  	// HTTP response.
   518  	assert.Check(t, is.Equal(http.StatusForbidden, resp.Code))
   519  	respBody, _ := io.ReadAll(resp.Body)
   520  	assert.Check(t, is.Equal("Forbidden\n", string(respBody)))
   521  	assert.Check(t, is.Equal("", resp.Header().Get("Location")))
   522  	assert.Check(t, is.Equal("", resp.Header().Get("Set-Cookie")))
   523  }