github.com/crewjam/saml@v0.4.14/samlidp/samlidp_test.go (about) 1 package samlidp 2 3 import ( 4 "crypto" 5 "crypto/rsa" 6 "crypto/x509" 7 "encoding/pem" 8 "net/http" 9 "net/http/httptest" 10 "net/url" 11 "strings" 12 "testing" 13 "time" 14 15 "gotest.tools/assert" 16 is "gotest.tools/assert/cmp" 17 "gotest.tools/golden" 18 19 "github.com/golang-jwt/jwt/v4" 20 21 "github.com/crewjam/saml" 22 "github.com/crewjam/saml/logger" 23 ) 24 25 type testRandomReader struct { 26 Next byte 27 } 28 29 func (tr *testRandomReader) Read(p []byte) (n int, err error) { 30 for i := 0; i < len(p); i++ { 31 p[i] = tr.Next 32 tr.Next += 2 33 } 34 return len(p), nil 35 } 36 37 func mustParseURL(s string) url.URL { 38 rv, err := url.Parse(s) 39 if err != nil { 40 panic(err) 41 } 42 return *rv 43 } 44 45 func mustParsePrivateKey(pemStr []byte) crypto.PrivateKey { 46 b, _ := pem.Decode(pemStr) 47 if b == nil { 48 panic("cannot parse PEM") 49 } 50 k, err := x509.ParsePKCS1PrivateKey(b.Bytes) 51 if err != nil { 52 panic(err) 53 } 54 return k 55 } 56 57 func mustParseCertificate(pemStr []byte) *x509.Certificate { 58 b, _ := pem.Decode(pemStr) 59 if b == nil { 60 panic("cannot parse PEM") 61 } 62 cert, err := x509.ParseCertificate(b.Bytes) 63 if err != nil { 64 panic(err) 65 } 66 return cert 67 } 68 69 type ServerTest struct { 70 SPKey *rsa.PrivateKey 71 SPCertificate *x509.Certificate 72 SP saml.ServiceProvider 73 74 Key crypto.PrivateKey 75 Certificate *x509.Certificate 76 Server *Server 77 Store MemoryStore 78 } 79 80 func NewServerTest(t *testing.T) *ServerTest { 81 test := ServerTest{} 82 saml.TimeNow = func() time.Time { 83 rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015") 84 return rv 85 } 86 jwt.TimeFunc = saml.TimeNow 87 saml.RandReader = &testRandomReader{} 88 89 test.SPKey = mustParsePrivateKey(golden.Get(t, "sp_key.pem")).(*rsa.PrivateKey) 90 test.SPCertificate = mustParseCertificate(golden.Get(t, "sp_cert.pem")) 91 test.SP = saml.ServiceProvider{ 92 Key: test.SPKey, 93 Certificate: test.SPCertificate, 94 MetadataURL: mustParseURL("https://sp.example.com/saml2/metadata"), 95 AcsURL: mustParseURL("https://sp.example.com/saml2/acs"), 96 IDPMetadata: &saml.EntityDescriptor{}, 97 } 98 test.Key = mustParsePrivateKey(golden.Get(t, "idp_key.pem")).(*rsa.PrivateKey) 99 test.Certificate = mustParseCertificate(golden.Get(t, "idp_cert.pem")) 100 101 test.Store = MemoryStore{} 102 103 var err error 104 test.Server, err = New(Options{ 105 Certificate: test.Certificate, 106 Key: test.Key, 107 Logger: logger.DefaultLogger, 108 Store: &test.Store, 109 URL: url.URL{Scheme: "https", Host: "idp.example.com"}, 110 }) 111 if err != nil { 112 panic(err) 113 } 114 115 test.SP.IDPMetadata = test.Server.IDP.Metadata() 116 test.Server.serviceProviders["https://sp.example.com/saml2/metadata"] = test.SP.Metadata() 117 return &test 118 } 119 120 func TestHTTPCanHandleMetadataRequest(t *testing.T) { 121 test := NewServerTest(t) 122 w := httptest.NewRecorder() 123 r, _ := http.NewRequest("GET", "https://idp.example.com/metadata", nil) 124 test.Server.ServeHTTP(w, r) 125 assert.Check(t, is.Equal(http.StatusOK, w.Code)) 126 assert.Check(t, 127 strings.HasPrefix(w.Body.String(), "<EntityDescriptor"), 128 w.Body.String()) 129 golden.Assert(t, w.Body.String(), "http_metadata_response.html") 130 } 131 132 func TestHTTPCanSSORequest(t *testing.T) { 133 test := NewServerTest(t) 134 u, err := test.SP.MakeRedirectAuthenticationRequest("frob") 135 assert.Check(t, err) 136 137 w := httptest.NewRecorder() 138 r, _ := http.NewRequest("GET", u.String(), nil) 139 test.Server.ServeHTTP(w, r) 140 assert.Check(t, is.Equal(http.StatusOK, w.Code)) 141 assert.Check(t, 142 strings.HasPrefix(w.Body.String(), "<html><p></p><form method=\"post\" action=\"https://idp.example.com/sso\">"), 143 w.Body.String()) 144 golden.Assert(t, w.Body.String(), "http_sso_response.html") 145 }