github.com/crewjam/saml@v0.4.14/samlidp/samlidp_test.go (about)

     1  package samlidp
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/rsa"
     6  	"crypto/x509"
     7  	"encoding/pem"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"gotest.tools/assert"
    16  	is "gotest.tools/assert/cmp"
    17  	"gotest.tools/golden"
    18  
    19  	"github.com/golang-jwt/jwt/v4"
    20  
    21  	"github.com/crewjam/saml"
    22  	"github.com/crewjam/saml/logger"
    23  )
    24  
    25  type testRandomReader struct {
    26  	Next byte
    27  }
    28  
    29  func (tr *testRandomReader) Read(p []byte) (n int, err error) {
    30  	for i := 0; i < len(p); i++ {
    31  		p[i] = tr.Next
    32  		tr.Next += 2
    33  	}
    34  	return len(p), nil
    35  }
    36  
    37  func mustParseURL(s string) url.URL {
    38  	rv, err := url.Parse(s)
    39  	if err != nil {
    40  		panic(err)
    41  	}
    42  	return *rv
    43  }
    44  
    45  func mustParsePrivateKey(pemStr []byte) crypto.PrivateKey {
    46  	b, _ := pem.Decode(pemStr)
    47  	if b == nil {
    48  		panic("cannot parse PEM")
    49  	}
    50  	k, err := x509.ParsePKCS1PrivateKey(b.Bytes)
    51  	if err != nil {
    52  		panic(err)
    53  	}
    54  	return k
    55  }
    56  
    57  func mustParseCertificate(pemStr []byte) *x509.Certificate {
    58  	b, _ := pem.Decode(pemStr)
    59  	if b == nil {
    60  		panic("cannot parse PEM")
    61  	}
    62  	cert, err := x509.ParseCertificate(b.Bytes)
    63  	if err != nil {
    64  		panic(err)
    65  	}
    66  	return cert
    67  }
    68  
// ServerTest bundles the fixtures shared by the samlidp HTTP tests. It holds a
// SAML service provider with its key and certificate, and an IDP Server with
// its key, certificate and in-memory store. Values are populated by
// NewServerTest.
type ServerTest struct {
	// SPKey and SPCertificate are the service provider's key material, loaded
	// by NewServerTest from the sp_key.pem / sp_cert.pem golden fixtures.
	SPKey         *rsa.PrivateKey
	SPCertificate *x509.Certificate
	// SP is the service provider under test; NewServerTest points its
	// IDPMetadata at Server's IDP metadata.
	SP            saml.ServiceProvider

	// Key and Certificate are the IDP's key material, loaded from the
	// idp_key.pem / idp_cert.pem golden fixtures.
	Key         crypto.PrivateKey
	Certificate *x509.Certificate
	// Server is the IDP under test.
	Server      *Server
	// Store is handed to New as &Store. NOTE(review): Server presumably keeps
	// that pointer, so a ServerTest should not be copied after construction.
	Store       MemoryStore
}
    79  
    80  func NewServerTest(t *testing.T) *ServerTest {
    81  	test := ServerTest{}
    82  	saml.TimeNow = func() time.Time {
    83  		rv, _ := time.Parse("Mon Jan 2 15:04:05 MST 2006", "Mon Dec 1 01:57:09 UTC 2015")
    84  		return rv
    85  	}
    86  	jwt.TimeFunc = saml.TimeNow
    87  	saml.RandReader = &testRandomReader{}
    88  
    89  	test.SPKey = mustParsePrivateKey(golden.Get(t, "sp_key.pem")).(*rsa.PrivateKey)
    90  	test.SPCertificate = mustParseCertificate(golden.Get(t, "sp_cert.pem"))
    91  	test.SP = saml.ServiceProvider{
    92  		Key:         test.SPKey,
    93  		Certificate: test.SPCertificate,
    94  		MetadataURL: mustParseURL("https://sp.example.com/saml2/metadata"),
    95  		AcsURL:      mustParseURL("https://sp.example.com/saml2/acs"),
    96  		IDPMetadata: &saml.EntityDescriptor{},
    97  	}
    98  	test.Key = mustParsePrivateKey(golden.Get(t, "idp_key.pem")).(*rsa.PrivateKey)
    99  	test.Certificate = mustParseCertificate(golden.Get(t, "idp_cert.pem"))
   100  
   101  	test.Store = MemoryStore{}
   102  
   103  	var err error
   104  	test.Server, err = New(Options{
   105  		Certificate: test.Certificate,
   106  		Key:         test.Key,
   107  		Logger:      logger.DefaultLogger,
   108  		Store:       &test.Store,
   109  		URL:         url.URL{Scheme: "https", Host: "idp.example.com"},
   110  	})
   111  	if err != nil {
   112  		panic(err)
   113  	}
   114  
   115  	test.SP.IDPMetadata = test.Server.IDP.Metadata()
   116  	test.Server.serviceProviders["https://sp.example.com/saml2/metadata"] = test.SP.Metadata()
   117  	return &test
   118  }
   119  
   120  func TestHTTPCanHandleMetadataRequest(t *testing.T) {
   121  	test := NewServerTest(t)
   122  	w := httptest.NewRecorder()
   123  	r, _ := http.NewRequest("GET", "https://idp.example.com/metadata", nil)
   124  	test.Server.ServeHTTP(w, r)
   125  	assert.Check(t, is.Equal(http.StatusOK, w.Code))
   126  	assert.Check(t,
   127  		strings.HasPrefix(w.Body.String(), "<EntityDescriptor"),
   128  		w.Body.String())
   129  	golden.Assert(t, w.Body.String(), "http_metadata_response.html")
   130  }
   131  
   132  func TestHTTPCanSSORequest(t *testing.T) {
   133  	test := NewServerTest(t)
   134  	u, err := test.SP.MakeRedirectAuthenticationRequest("frob")
   135  	assert.Check(t, err)
   136  
   137  	w := httptest.NewRecorder()
   138  	r, _ := http.NewRequest("GET", u.String(), nil)
   139  	test.Server.ServeHTTP(w, r)
   140  	assert.Check(t, is.Equal(http.StatusOK, w.Code))
   141  	assert.Check(t,
   142  		strings.HasPrefix(w.Body.String(), "<html><p></p><form method=\"post\" action=\"https://idp.example.com/sso\">"),
   143  		w.Body.String())
   144  	golden.Assert(t, w.Body.String(), "http_sso_response.html")
   145  }