github.com/crewjam/saml@v0.4.14/identity_provider.go (about)

     1  package saml
     2  
import (
	"bytes"
	"crypto"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"text/template"
	"time"

	"github.com/beevik/etree"
	xrv "github.com/mattermost/xml-roundtrip-validator"
	dsig "github.com/russellhaering/goxmldsig"

	"github.com/crewjam/saml/logger"
	"github.com/crewjam/saml/xmlenc"
)
    27  
// Session represents a user session. It is returned by the
// SessionProvider implementation's GetSession method. Fields here
// are used to set fields in the SAML assertion.
type Session struct {
	// Session bookkeeping.
	// NOTE(review): these fields are not read by the attribute-building part
	// of DefaultAssertionMaker.MakeAssertion; presumably they populate the
	// assertion's AuthnStatement (session index / expiry) — confirm.
	ID         string
	CreateTime time.Time
	ExpireTime time.Time
	Index      string

	// NameID is the value of the assertion subject's NameID. NameIDFormat,
	// when non-empty, replaces the default
	// "urn:oasis:names:tc:SAML:2.0:nameid-format:transient" format.
	// SubjectID, when non-empty, is sent as the
	// "urn:oasis:names:tc:SAML:attribute:subject-id" attribute.
	NameID       string
	NameIDFormat string
	SubjectID    string

	// User attributes. Each non-empty User* field is sent as a URI-named
	// attribute: UserName as uid, UserEmail as eduPersonPrincipalName,
	// UserSurname as sn, UserGivenName as givenName, UserCommonName as cn and
	// UserScopedAffiliation as urn:oid:1.3.6.1.4.1.5923.1.1.1.9.
	// UserName, UserEmail, UserCommonName, UserGivenName and UserSurname also
	// answer matching attributes requested through the service provider's
	// AttributeConsumingService. Groups, when non-empty, become the values of
	// the eduPersonAffiliation attribute.
	Groups                []string
	UserName              string
	UserEmail             string
	UserCommonName        string
	UserSurname           string
	UserGivenName         string
	UserScopedAffiliation string

	// CustomAttributes are appended to the assertion's attributes as-is.
	CustomAttributes []Attribute
}
    51  
// SessionProvider is an interface used by IdentityProvider to determine the
// Session associated with a request. For an example implementation, see
// GetSession in the samlidp package.
type SessionProvider interface {
	// GetSession returns the remote user session associated with the http.Request.
	//
	// If (and only if) the request is not associated with a session then GetSession
	// must complete the HTTP request and return nil.
	//
	// ServeSSO calls it after the AuthnRequest has been decoded and validated,
	// so req.Request, req.ServiceProviderMetadata, req.SPSSODescriptor and
	// req.ACSEndpoint are already populated. ServeIDPInitiated calls it before
	// looking up the service provider, so there req carries only IDP,
	// HTTPRequest, RelayState and Now.
	GetSession(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session
}
    62  
// ServiceProviderProvider is an interface used by IdentityProvider to look up
// service provider metadata for a request.
type ServiceProviderProvider interface {
	// GetServiceProvider returns the Service Provider metadata for the
	// service provider ID, which is typically the service provider's
	// metadata URL. If an appropriate service provider cannot be found then
	// the returned error must be os.ErrNotExist.
	//
	// For SP-initiated requests the ID is the Issuer value of the
	// AuthnRequest; for IDP-initiated requests it is the ID passed to
	// ServeIDPInitiated. In ServeIDPInitiated, os.ErrNotExist produces a 404
	// response and any other error a 500. In ServeSSO either kind of error
	// fails validation and the request is rejected with 400.
	GetServiceProvider(r *http.Request, serviceProviderID string) (*EntityDescriptor, error)
}
    72  
// AssertionMaker is an interface used by IdentityProvider to construct the
// assertion for a request. The default implementation is DefaultAssertionMaker,
// which is used if no AssertionMaker is specified.
type AssertionMaker interface {
	// MakeAssertion constructs an assertion from session and the request and
	// assigns it to req.Assertion.
	//
	// ServeSSO and ServeIDPInitiated call it only after
	// req.ServiceProviderMetadata, req.SPSSODescriptor and req.ACSEndpoint
	// have been resolved. A non-nil error makes them log the error and
	// respond with 500 Internal Server Error.
	MakeAssertion(req *IdpAuthnRequest, session *Session) error
}
    81  
// IdentityProvider implements the SAML Identity Provider role (IDP).
//
// An identity provider receives SAML assertion requests and responds
// with SAML Assertions.
//
// You must provide a keypair that is used to
// sign assertions.
//
// You must provide an implementation of ServiceProviderProvider which
// returns the metadata of the service providers this IDP is willing to
// issue assertions for.
//
// You must provide an implementation of the SessionProvider which
// handles the actual authentication (i.e. prompting for a username
// and password).
type IdentityProvider struct {
	// Key and Signer hold the IDP's private key material.
	// NOTE(review): neither is read in this part of the file; presumably
	// they sign responses/assertions (and Key may decrypt) — confirm which
	// one takes precedence when both are set.
	Key    crypto.PrivateKey
	Signer crypto.Signer

	// Logger receives the errors reported by the HTTP handlers. The handlers
	// call it without a nil check, so it must be set before serving.
	Logger logger.Interface

	// Certificate is published by Metadata as both the signing and the
	// encryption key.
	// NOTE(review): Intermediates is not read in this part of the file;
	// presumably it carries the rest of the certificate chain — confirm.
	Certificate   *x509.Certificate
	Intermediates []*x509.Certificate

	// MetadataURL is the IDP's entity ID and the path Handler routes to
	// ServeMetadata. SSOURL is advertised for the HTTP-Redirect and HTTP-POST
	// bindings, is the path Handler routes to ServeSSO, and is the value a
	// request's Destination must match. LogoutURL, when non-empty, is
	// advertised as an HTTP-Redirect SingleLogoutService; Handler does not
	// route it.
	MetadataURL url.URL
	SSOURL      url.URL
	LogoutURL   url.URL

	// ServiceProviderProvider and SessionProvider are required.
	// AssertionMaker is optional; DefaultAssertionMaker is used when it is nil.
	ServiceProviderProvider ServiceProviderProvider
	SessionProvider         SessionProvider
	AssertionMaker          AssertionMaker

	// NOTE(review): SignatureMethod is not read in this part of the file;
	// presumably it is the XML-DSig signature algorithm URI used when
	// signing — confirm the default applied when it is empty.
	SignatureMethod string

	// ValidDuration, when non-nil, sets how long the published metadata is
	// valid (validUntil) and cacheable (cacheDuration); DefaultValidDuration
	// is used when it is nil.
	ValidDuration *time.Duration
}
   111  
   112  // Metadata returns the metadata structure for this identity provider.
   113  func (idp *IdentityProvider) Metadata() *EntityDescriptor {
   114  	certStr := base64.StdEncoding.EncodeToString(idp.Certificate.Raw)
   115  
   116  	var validDuration time.Duration
   117  	if idp.ValidDuration != nil {
   118  		validDuration = *idp.ValidDuration
   119  	} else {
   120  		validDuration = DefaultValidDuration
   121  	}
   122  
   123  	ed := &EntityDescriptor{
   124  		EntityID:      idp.MetadataURL.String(),
   125  		ValidUntil:    TimeNow().Add(validDuration),
   126  		CacheDuration: validDuration,
   127  		IDPSSODescriptors: []IDPSSODescriptor{
   128  			{
   129  				SSODescriptor: SSODescriptor{
   130  					RoleDescriptor: RoleDescriptor{
   131  						ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol",
   132  						KeyDescriptors: []KeyDescriptor{
   133  							{
   134  								Use: "signing",
   135  								KeyInfo: KeyInfo{
   136  									X509Data: X509Data{
   137  										X509Certificates: []X509Certificate{
   138  											{Data: certStr},
   139  										},
   140  									},
   141  								},
   142  							},
   143  							{
   144  								Use: "encryption",
   145  								KeyInfo: KeyInfo{
   146  									X509Data: X509Data{
   147  										X509Certificates: []X509Certificate{
   148  											{Data: certStr},
   149  										},
   150  									},
   151  								},
   152  								EncryptionMethods: []EncryptionMethod{
   153  									{Algorithm: "http://www.w3.org/2001/04/xmlenc#aes128-cbc"},
   154  									{Algorithm: "http://www.w3.org/2001/04/xmlenc#aes192-cbc"},
   155  									{Algorithm: "http://www.w3.org/2001/04/xmlenc#aes256-cbc"},
   156  									{Algorithm: "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p"},
   157  								},
   158  							},
   159  						},
   160  					},
   161  					NameIDFormats: []NameIDFormat{NameIDFormat("urn:oasis:names:tc:SAML:2.0:nameid-format:transient")},
   162  				},
   163  				SingleSignOnServices: []Endpoint{
   164  					{
   165  						Binding:  HTTPRedirectBinding,
   166  						Location: idp.SSOURL.String(),
   167  					},
   168  					{
   169  						Binding:  HTTPPostBinding,
   170  						Location: idp.SSOURL.String(),
   171  					},
   172  				},
   173  			},
   174  		},
   175  	}
   176  
   177  	if idp.LogoutURL.String() != "" {
   178  		ed.IDPSSODescriptors[0].SSODescriptor.SingleLogoutServices = []Endpoint{
   179  			{
   180  				Binding:  HTTPRedirectBinding,
   181  				Location: idp.LogoutURL.String(),
   182  			},
   183  		}
   184  	}
   185  
   186  	return ed
   187  }
   188  
   189  // Handler returns an http.Handler that serves the metadata and SSO
   190  // URLs
   191  func (idp *IdentityProvider) Handler() http.Handler {
   192  	mux := http.NewServeMux()
   193  	mux.HandleFunc(idp.MetadataURL.Path, idp.ServeMetadata)
   194  	mux.HandleFunc(idp.SSOURL.Path, idp.ServeSSO)
   195  	return mux
   196  }
   197  
   198  // ServeMetadata is an http.HandlerFunc that serves the IDP metadata
   199  func (idp *IdentityProvider) ServeMetadata(w http.ResponseWriter, _ *http.Request) {
   200  	buf, _ := xml.MarshalIndent(idp.Metadata(), "", "  ")
   201  	w.Header().Set("Content-Type", "application/samlmetadata+xml")
   202  	if _, err := w.Write(buf); err != nil {
   203  		idp.Logger.Printf("ERROR: %s", err)
   204  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   205  	}
   206  }
   207  
   208  // ServeSSO handles SAML auth requests.
   209  //
   210  // When it gets a request for a user that does not have a valid session,
   211  // then it prompts the user via XXX.
   212  //
   213  // If the session already exists, then it produces a SAML assertion and
   214  // returns an HTTP response according to the specified binding. The
   215  // only supported binding right now is the HTTP-POST binding which returns
   216  // an HTML form in the appropriate format with Javascript to automatically
   217  // submit that form the to service provider's Assertion Customer Service
   218  // endpoint.
   219  //
   220  // If the SAML request is invalid or cannot be verified a simple StatusBadRequest
   221  // response is sent.
   222  //
   223  // If the assertion cannot be created or returned, a StatusInternalServerError
   224  // response is sent.
   225  func (idp *IdentityProvider) ServeSSO(w http.ResponseWriter, r *http.Request) {
   226  	req, err := NewIdpAuthnRequest(idp, r)
   227  	if err != nil {
   228  		idp.Logger.Printf("failed to parse request: %s", err)
   229  		http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
   230  		return
   231  	}
   232  
   233  	if err := req.Validate(); err != nil {
   234  		idp.Logger.Printf("failed to validate request: %s", err)
   235  		http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
   236  		return
   237  	}
   238  
   239  	// TODO(ross): we must check that the request ID has not been previously
   240  	//   issued.
   241  
   242  	session := idp.SessionProvider.GetSession(w, r, req)
   243  	if session == nil {
   244  		return
   245  	}
   246  
   247  	assertionMaker := idp.AssertionMaker
   248  	if assertionMaker == nil {
   249  		assertionMaker = DefaultAssertionMaker{}
   250  	}
   251  	if err := assertionMaker.MakeAssertion(req, session); err != nil {
   252  		idp.Logger.Printf("failed to make assertion: %s", err)
   253  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   254  		return
   255  	}
   256  	if err := req.WriteResponse(w); err != nil {
   257  		idp.Logger.Printf("failed to write response: %s", err)
   258  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   259  		return
   260  	}
   261  }
   262  
   263  // ServeIDPInitiated handes an IDP-initiated authorization request. Requests of this
   264  // type require us to know a registered service provider and (optionally) the RelayState
   265  // that will be passed to the application.
   266  func (idp *IdentityProvider) ServeIDPInitiated(w http.ResponseWriter, r *http.Request, serviceProviderID string, relayState string) {
   267  	req := &IdpAuthnRequest{
   268  		IDP:         idp,
   269  		HTTPRequest: r,
   270  		RelayState:  relayState,
   271  		Now:         TimeNow(),
   272  	}
   273  
   274  	session := idp.SessionProvider.GetSession(w, r, req)
   275  	if session == nil {
   276  		// If GetSession returns nil, it must have written an HTTP response, per the interface
   277  		// (this is probably because it drew a login form or something)
   278  		return
   279  	}
   280  
   281  	var err error
   282  	req.ServiceProviderMetadata, err = idp.ServiceProviderProvider.GetServiceProvider(r, serviceProviderID)
   283  	if err == os.ErrNotExist {
   284  		idp.Logger.Printf("cannot find service provider: %s", serviceProviderID)
   285  		http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
   286  		return
   287  	} else if err != nil {
   288  		idp.Logger.Printf("cannot find service provider %s: %v", serviceProviderID, err)
   289  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   290  		return
   291  	}
   292  
   293  	// find an ACS endpoint that we can use
   294  	for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors {
   295  		for _, endpoint := range spssoDescriptor.AssertionConsumerServices {
   296  			if endpoint.Binding == HTTPPostBinding {
   297  				// explicitly copy loop iterator variables
   298  				//
   299  				// c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable
   300  				//
   301  				// (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately,
   302  				// but it certainly doesn't hurt anything and may prevent bugs in the future.)
   303  				endpoint, spssoDescriptor := endpoint, spssoDescriptor
   304  
   305  				req.ACSEndpoint = &endpoint
   306  				req.SPSSODescriptor = &spssoDescriptor
   307  				break
   308  			}
   309  		}
   310  		if req.ACSEndpoint != nil {
   311  			break
   312  		}
   313  	}
   314  	if req.ACSEndpoint == nil {
   315  		idp.Logger.Printf("saml metadata does not contain an Assertion Customer Service url")
   316  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   317  		return
   318  	}
   319  
   320  	assertionMaker := idp.AssertionMaker
   321  	if assertionMaker == nil {
   322  		assertionMaker = DefaultAssertionMaker{}
   323  	}
   324  	if err := assertionMaker.MakeAssertion(req, session); err != nil {
   325  		idp.Logger.Printf("failed to make assertion: %s", err)
   326  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   327  		return
   328  	}
   329  
   330  	if err := req.WriteResponse(w); err != nil {
   331  		idp.Logger.Printf("failed to write response: %s", err)
   332  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   333  		return
   334  	}
   335  }
   336  
// IdpAuthnRequest is used by IdentityProvider to handle a single authentication request.
//
// NewIdpAuthnRequest fills in the transport fields. Validate parses the
// request and resolves the service provider and ACS endpoint. An
// AssertionMaker then sets Assertion.
type IdpAuthnRequest struct {
	// IDP is the identity provider handling the request; HTTPRequest is the
	// incoming HTTP request it arrived on.
	IDP         *IdentityProvider
	HTTPRequest *http.Request

	// RelayState is copied from the RelayState query or form parameter, or
	// supplied by the caller of ServeIDPInitiated. RequestBuffer holds the raw
	// AuthnRequest XML decoded (and, for GET, inflated) from SAMLRequest.
	RelayState    string
	RequestBuffer []byte

	// Request is the AuthnRequest parsed from RequestBuffer by Validate. It
	// is left at its zero value for IDP-initiated requests.
	Request AuthnRequest

	// ServiceProviderMetadata, SPSSODescriptor and ACSEndpoint identify where
	// the response goes. Validate resolves them for SP-initiated requests;
	// ServeIDPInitiated resolves them itself. SPSSODescriptor and ACSEndpoint
	// point to copies, not into ServiceProviderMetadata.
	ServiceProviderMetadata *EntityDescriptor
	SPSSODescriptor         *SPSSODescriptor
	ACSEndpoint             *IndexedEndpoint

	// Assertion is set by the AssertionMaker.
	Assertion *Assertion

	// NOTE(review): AssertionEl and ResponseEl are not set in this part of the
	// file; presumably they hold the XML elements built while signing and
	// writing the response — confirm.
	AssertionEl *etree.Element
	ResponseEl  *etree.Element

	// Now is captured when the request object is created. It is the reference
	// time for the IssueInstant expiry check and the assertion validity window.
	Now time.Time
}
   352  
   353  // NewIdpAuthnRequest returns a new IdpAuthnRequest for the given HTTP request to the authorization
   354  // service.
   355  func NewIdpAuthnRequest(idp *IdentityProvider, r *http.Request) (*IdpAuthnRequest, error) {
   356  	req := &IdpAuthnRequest{
   357  		IDP:         idp,
   358  		HTTPRequest: r,
   359  		Now:         TimeNow(),
   360  	}
   361  
   362  	switch r.Method {
   363  	case "GET":
   364  		compressedRequest, err := base64.StdEncoding.DecodeString(r.URL.Query().Get("SAMLRequest"))
   365  		if err != nil {
   366  			return nil, fmt.Errorf("cannot decode request: %s", err)
   367  		}
   368  		req.RequestBuffer, err = io.ReadAll(newSaferFlateReader(bytes.NewReader(compressedRequest)))
   369  		if err != nil {
   370  			return nil, fmt.Errorf("cannot decompress request: %s", err)
   371  		}
   372  		req.RelayState = r.URL.Query().Get("RelayState")
   373  	case "POST":
   374  		if err := r.ParseForm(); err != nil {
   375  			return nil, err
   376  		}
   377  		var err error
   378  		req.RequestBuffer, err = base64.StdEncoding.DecodeString(r.PostForm.Get("SAMLRequest"))
   379  		if err != nil {
   380  			return nil, err
   381  		}
   382  		req.RelayState = r.PostForm.Get("RelayState")
   383  	default:
   384  		return nil, fmt.Errorf("method not allowed")
   385  	}
   386  
   387  	return req, nil
   388  }
   389  
   390  // Validate checks that the authentication request is valid and assigns
   391  // the AuthnRequest and Metadata properties. Returns a non-nil error if the
   392  // request is not valid.
   393  func (req *IdpAuthnRequest) Validate() error {
   394  	if err := xrv.Validate(bytes.NewReader(req.RequestBuffer)); err != nil {
   395  		return err
   396  	}
   397  
   398  	if err := xml.Unmarshal(req.RequestBuffer, &req.Request); err != nil {
   399  		return err
   400  	}
   401  
   402  	// We always have exactly one IDP SSO descriptor
   403  	if len(req.IDP.Metadata().IDPSSODescriptors) != 1 {
   404  		panic("expected exactly one IDP SSO descriptor in IDP metadata")
   405  	}
   406  	idpSsoDescriptor := req.IDP.Metadata().IDPSSODescriptors[0]
   407  
   408  	// TODO(ross): support signed authn requests
   409  	// For now we do the safe thing and fail in the case where we think
   410  	// requests might be signed.
   411  	if idpSsoDescriptor.WantAuthnRequestsSigned != nil && *idpSsoDescriptor.WantAuthnRequestsSigned {
   412  		return fmt.Errorf("authn request signature checking is not currently supported")
   413  	}
   414  
   415  	// In http://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf ยง3.4.5.2
   416  	// we get a description of the Destination attribute:
   417  	//
   418  	//   If the message is signed, the Destination XML attribute in the root SAML
   419  	//   element of the protocol message MUST contain the URL to which the sender
   420  	//   has instructed the user agent to deliver the message. The recipient MUST
   421  	//   then verify that the value matches the location at which the message has
   422  	//   been received.
   423  	//
   424  	// We require the destination be correct either (a) if signing is enabled or
   425  	// (b) if it was provided.
   426  	mustHaveDestination := idpSsoDescriptor.WantAuthnRequestsSigned != nil && *idpSsoDescriptor.WantAuthnRequestsSigned
   427  	mustHaveDestination = mustHaveDestination || req.Request.Destination != ""
   428  	if mustHaveDestination {
   429  		if req.Request.Destination != req.IDP.SSOURL.String() {
   430  			return fmt.Errorf("expected destination to be %q, not %q", req.IDP.SSOURL.String(), req.Request.Destination)
   431  		}
   432  	}
   433  
   434  	if req.Request.IssueInstant.Add(MaxIssueDelay).Before(req.Now) {
   435  		return fmt.Errorf("request expired at %s",
   436  			req.Request.IssueInstant.Add(MaxIssueDelay))
   437  	}
   438  	if req.Request.Version != "2.0" {
   439  		return fmt.Errorf("expected SAML request version 2.0 got %v", req.Request.Version)
   440  	}
   441  
   442  	// find the service provider
   443  	serviceProviderID := req.Request.Issuer.Value
   444  	serviceProvider, err := req.IDP.ServiceProviderProvider.GetServiceProvider(req.HTTPRequest, serviceProviderID)
   445  	if err == os.ErrNotExist {
   446  		return fmt.Errorf("cannot handle request from unknown service provider %s", serviceProviderID)
   447  	} else if err != nil {
   448  		return fmt.Errorf("cannot find service provider %s: %v", serviceProviderID, err)
   449  	}
   450  	req.ServiceProviderMetadata = serviceProvider
   451  
   452  	// Check that the ACS URL matches an ACS endpoint in the SP metadata.
   453  	if err := req.getACSEndpoint(); err != nil {
   454  		return fmt.Errorf("cannot find assertion consumer service: %v", err)
   455  	}
   456  
   457  	return nil
   458  }
   459  
   460  func (req *IdpAuthnRequest) getACSEndpoint() error {
   461  	if req.Request.AssertionConsumerServiceIndex != "" {
   462  		for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors {
   463  			for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices {
   464  				if strconv.Itoa(spAssertionConsumerService.Index) == req.Request.AssertionConsumerServiceIndex {
   465  					// explicitly copy loop iterator variables
   466  					//
   467  					// c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable
   468  					//
   469  					// (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately,
   470  					// but it certainly doesn't hurt anything and may prevent bugs in the future.)
   471  					spssoDescriptor, spAssertionConsumerService := spssoDescriptor, spAssertionConsumerService
   472  
   473  					req.SPSSODescriptor = &spssoDescriptor
   474  					req.ACSEndpoint = &spAssertionConsumerService
   475  					return nil
   476  				}
   477  			}
   478  		}
   479  	}
   480  
   481  	if req.Request.AssertionConsumerServiceURL != "" {
   482  		for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors {
   483  			for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices {
   484  				if spAssertionConsumerService.Location == req.Request.AssertionConsumerServiceURL {
   485  					// explicitly copy loop iterator variables
   486  					//
   487  					// c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable
   488  					//
   489  					// (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately,
   490  					// but it certainly doesn't hurt anything and may prevent bugs in the future.)
   491  					spssoDescriptor, spAssertionConsumerService := spssoDescriptor, spAssertionConsumerService
   492  
   493  					req.SPSSODescriptor = &spssoDescriptor
   494  					req.ACSEndpoint = &spAssertionConsumerService
   495  					return nil
   496  				}
   497  			}
   498  		}
   499  	}
   500  
   501  	// Some service providers, like the Microsoft Azure AD service provider, issue
   502  	// assertion requests that don't specify an ACS url at all.
   503  	if req.Request.AssertionConsumerServiceURL == "" && req.Request.AssertionConsumerServiceIndex == "" {
   504  		// find a default ACS binding in the metadata that we can use
   505  		for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors {
   506  			for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices {
   507  				if spAssertionConsumerService.IsDefault != nil && *spAssertionConsumerService.IsDefault {
   508  					switch spAssertionConsumerService.Binding {
   509  					case HTTPPostBinding, HTTPRedirectBinding:
   510  						// explicitly copy loop iterator variables
   511  						//
   512  						// c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable
   513  						//
   514  						// (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately,
   515  						// but it certainly doesn't hurt anything and may prevent bugs in the future.)
   516  						spssoDescriptor, spAssertionConsumerService := spssoDescriptor, spAssertionConsumerService
   517  
   518  						req.SPSSODescriptor = &spssoDescriptor
   519  						req.ACSEndpoint = &spAssertionConsumerService
   520  						return nil
   521  					}
   522  				}
   523  			}
   524  		}
   525  
   526  		// if we can't find a default, use *any* ACS binding
   527  		for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors {
   528  			for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices {
   529  				switch spAssertionConsumerService.Binding {
   530  				case HTTPPostBinding, HTTPRedirectBinding:
   531  					// explicitly copy loop iterator variables
   532  					//
   533  					// c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable
   534  					//
   535  					// (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately,
   536  					// but it certainly doesn't hurt anything and may prevent bugs in the future.)
   537  					spssoDescriptor, spAssertionConsumerService := spssoDescriptor, spAssertionConsumerService
   538  
   539  					req.SPSSODescriptor = &spssoDescriptor
   540  					req.ACSEndpoint = &spAssertionConsumerService
   541  					return nil
   542  				}
   543  			}
   544  		}
   545  	}
   546  
   547  	return os.ErrNotExist // no ACS url found or specified
   548  }
   549  
   550  // DefaultAssertionMaker produces a SAML assertion for the
   551  // given request and assigns it to req.Assertion.
   552  type DefaultAssertionMaker struct {
   553  }
   554  
   555  // MakeAssertion implements AssertionMaker. It produces a SAML assertion from the
   556  // given request and assigns it to req.Assertion.
   557  func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Session) error {
   558  	attributes := []Attribute{}
   559  
   560  	var attributeConsumingService *AttributeConsumingService
   561  	for _, acs := range req.SPSSODescriptor.AttributeConsumingServices {
   562  		if acs.IsDefault != nil && *acs.IsDefault {
   563  			// explicitly copy loop iterator variables
   564  			//
   565  			// c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable
   566  			//
   567  			// (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately,
   568  			// but it certainly doesn't hurt anything and may prevent bugs in the future.)
   569  			acs := acs
   570  
   571  			attributeConsumingService = &acs
   572  			break
   573  		}
   574  	}
   575  	if attributeConsumingService == nil {
   576  		for _, acs := range req.SPSSODescriptor.AttributeConsumingServices {
   577  			// explicitly copy loop iterator variables
   578  			//
   579  			// c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable
   580  			//
   581  			// (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately,
   582  			// but it certainly doesn't hurt anything and may prevent bugs in the future.)
   583  			acs := acs
   584  
   585  			attributeConsumingService = &acs
   586  			break
   587  		}
   588  	}
   589  	if attributeConsumingService == nil {
   590  		attributeConsumingService = &AttributeConsumingService{}
   591  	}
   592  
   593  	for _, requestedAttribute := range attributeConsumingService.RequestedAttributes {
   594  		if requestedAttribute.NameFormat == "urn:oasis:names:tc:SAML:2.0:attrname-format:basic" || requestedAttribute.NameFormat == "urn:oasis:names:tc:SAML:2.0:attrname-format:unspecified" {
   595  			attrName := requestedAttribute.Name
   596  			attrName = regexp.MustCompile("[^A-Za-z0-9]+").ReplaceAllString(attrName, "")
   597  			switch attrName {
   598  			case "email", "emailaddress":
   599  				attributes = append(attributes, Attribute{
   600  					FriendlyName: requestedAttribute.FriendlyName,
   601  					Name:         requestedAttribute.Name,
   602  					NameFormat:   requestedAttribute.NameFormat,
   603  					Values: []AttributeValue{{
   604  						Type:  "xs:string",
   605  						Value: session.UserEmail,
   606  					}},
   607  				})
   608  			case "name", "fullname", "cn", "commonname":
   609  				attributes = append(attributes, Attribute{
   610  					FriendlyName: requestedAttribute.FriendlyName,
   611  					Name:         requestedAttribute.Name,
   612  					NameFormat:   requestedAttribute.NameFormat,
   613  					Values: []AttributeValue{{
   614  						Type:  "xs:string",
   615  						Value: session.UserCommonName,
   616  					}},
   617  				})
   618  			case "givenname", "firstname":
   619  				attributes = append(attributes, Attribute{
   620  					FriendlyName: requestedAttribute.FriendlyName,
   621  					Name:         requestedAttribute.Name,
   622  					NameFormat:   requestedAttribute.NameFormat,
   623  					Values: []AttributeValue{{
   624  						Type:  "xs:string",
   625  						Value: session.UserGivenName,
   626  					}},
   627  				})
   628  			case "surname", "lastname", "familyname":
   629  				attributes = append(attributes, Attribute{
   630  					FriendlyName: requestedAttribute.FriendlyName,
   631  					Name:         requestedAttribute.Name,
   632  					NameFormat:   requestedAttribute.NameFormat,
   633  					Values: []AttributeValue{{
   634  						Type:  "xs:string",
   635  						Value: session.UserSurname,
   636  					}},
   637  				})
   638  			case "uid", "user", "userid":
   639  				attributes = append(attributes, Attribute{
   640  					FriendlyName: requestedAttribute.FriendlyName,
   641  					Name:         requestedAttribute.Name,
   642  					NameFormat:   requestedAttribute.NameFormat,
   643  					Values: []AttributeValue{{
   644  						Type:  "xs:string",
   645  						Value: session.UserName,
   646  					}},
   647  				})
   648  			}
   649  		}
   650  	}
   651  
   652  	if session.UserName != "" {
   653  		attributes = append(attributes, Attribute{
   654  			FriendlyName: "uid",
   655  			Name:         "urn:oid:0.9.2342.19200300.100.1.1",
   656  			NameFormat:   "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
   657  			Values: []AttributeValue{{
   658  				Type:  "xs:string",
   659  				Value: session.UserName,
   660  			}},
   661  		})
   662  	}
   663  
   664  	if session.UserEmail != "" {
   665  		attributes = append(attributes, Attribute{
   666  			FriendlyName: "eduPersonPrincipalName",
   667  			Name:         "urn:oid:1.3.6.1.4.1.5923.1.1.1.6",
   668  			NameFormat:   "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
   669  			Values: []AttributeValue{{
   670  				Type:  "xs:string",
   671  				Value: session.UserEmail,
   672  			}},
   673  		})
   674  	}
   675  	if session.UserSurname != "" {
   676  		attributes = append(attributes, Attribute{
   677  			FriendlyName: "sn",
   678  			Name:         "urn:oid:2.5.4.4",
   679  			NameFormat:   "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
   680  			Values: []AttributeValue{{
   681  				Type:  "xs:string",
   682  				Value: session.UserSurname,
   683  			}},
   684  		})
   685  	}
   686  	if session.UserGivenName != "" {
   687  		attributes = append(attributes, Attribute{
   688  			FriendlyName: "givenName",
   689  			Name:         "urn:oid:2.5.4.42",
   690  			NameFormat:   "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
   691  			Values: []AttributeValue{{
   692  				Type:  "xs:string",
   693  				Value: session.UserGivenName,
   694  			}},
   695  		})
   696  	}
   697  
   698  	if session.UserCommonName != "" {
   699  		attributes = append(attributes, Attribute{
   700  			FriendlyName: "cn",
   701  			Name:         "urn:oid:2.5.4.3",
   702  			NameFormat:   "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
   703  			Values: []AttributeValue{{
   704  				Type:  "xs:string",
   705  				Value: session.UserCommonName,
   706  			}},
   707  		})
   708  	}
   709  
   710  	if session.UserScopedAffiliation != "" {
   711  		attributes = append(attributes, Attribute{
   712  			FriendlyName: "uid",
   713  			Name:         "urn:oid:1.3.6.1.4.1.5923.1.1.1.9",
   714  			NameFormat:   "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
   715  			Values: []AttributeValue{{
   716  				Type:  "xs:string",
   717  				Value: session.UserScopedAffiliation,
   718  			}},
   719  		})
   720  	}
   721  
   722  	attributes = append(attributes, session.CustomAttributes...)
   723  
   724  	if len(session.Groups) != 0 {
   725  		groupMemberAttributeValues := []AttributeValue{}
   726  		for _, group := range session.Groups {
   727  			groupMemberAttributeValues = append(groupMemberAttributeValues, AttributeValue{
   728  				Type:  "xs:string",
   729  				Value: group,
   730  			})
   731  		}
   732  		attributes = append(attributes, Attribute{
   733  			FriendlyName: "eduPersonAffiliation",
   734  			Name:         "urn:oid:1.3.6.1.4.1.5923.1.1.1.1",
   735  			NameFormat:   "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
   736  			Values:       groupMemberAttributeValues,
   737  		})
   738  	}
   739  
   740  	if session.SubjectID != "" {
   741  		attributes = append(attributes, Attribute{
   742  			Name:       "urn:oasis:names:tc:SAML:attribute:subject-id",
   743  			NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
   744  			Values: []AttributeValue{
   745  				{
   746  					Type:  "xs:string",
   747  					Value: session.SubjectID,
   748  				},
   749  			},
   750  		})
   751  	}
   752  
   753  	// allow for some clock skew in the validity period using the
   754  	// issuer's apparent clock.
   755  	notBefore := req.Now.Add(-1 * MaxClockSkew)
   756  	notOnOrAfterAfter := req.Now.Add(MaxIssueDelay)
   757  	if notBefore.Before(req.Request.IssueInstant) {
   758  		notBefore = req.Request.IssueInstant
   759  		notOnOrAfterAfter = notBefore.Add(MaxIssueDelay)
   760  	}
   761  
   762  	nameIDFormat := "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"
   763  
   764  	if session.NameIDFormat != "" {
   765  		nameIDFormat = session.NameIDFormat
   766  	}
   767  
   768  	req.Assertion = &Assertion{
   769  		ID:           fmt.Sprintf("id-%x", randomBytes(20)),
   770  		IssueInstant: TimeNow(),
   771  		Version:      "2.0",
   772  		Issuer: Issuer{
   773  			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
   774  			Value:  req.IDP.Metadata().EntityID,
   775  		},
   776  		Subject: &Subject{
   777  			NameID: &NameID{
   778  				Format:          nameIDFormat,
   779  				NameQualifier:   req.IDP.Metadata().EntityID,
   780  				SPNameQualifier: req.ServiceProviderMetadata.EntityID,
   781  				Value:           session.NameID,
   782  			},
   783  			SubjectConfirmations: []SubjectConfirmation{
   784  				{
   785  					Method: "urn:oasis:names:tc:SAML:2.0:cm:bearer",
   786  					SubjectConfirmationData: &SubjectConfirmationData{
   787  						Address:      req.HTTPRequest.RemoteAddr,
   788  						InResponseTo: req.Request.ID,
   789  						NotOnOrAfter: req.Now.Add(MaxIssueDelay),
   790  						Recipient:    req.ACSEndpoint.Location,
   791  					},
   792  				},
   793  			},
   794  		},
   795  		Conditions: &Conditions{
   796  			NotBefore:    notBefore,
   797  			NotOnOrAfter: notOnOrAfterAfter,
   798  			AudienceRestrictions: []AudienceRestriction{
   799  				{
   800  					Audience: Audience{Value: req.ServiceProviderMetadata.EntityID},
   801  				},
   802  			},
   803  		},
   804  		AuthnStatements: []AuthnStatement{
   805  			{
   806  				AuthnInstant: session.CreateTime,
   807  				SessionIndex: session.Index,
   808  				SubjectLocality: &SubjectLocality{
   809  					Address: req.HTTPRequest.RemoteAddr,
   810  				},
   811  				AuthnContext: AuthnContext{
   812  					AuthnContextClassRef: &AuthnContextClassRef{
   813  						Value: "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport",
   814  					},
   815  				},
   816  			},
   817  		},
   818  		AttributeStatements: []AttributeStatement{
   819  			{
   820  				Attributes: attributes,
   821  			},
   822  		},
   823  	}
   824  
   825  	return nil
   826  }
   827  
// canonicalizerPrefixList is the (InclusiveNamespaces) prefix list handed to
// the exclusive XML C14N 1.0 canonicalizer that signingContext installs on
// every signing context.
//
// The prefix list MUST be empty. Various implementations (possibly including
// the ones this package relies on) do not appear to support non-empty prefix
// lists in XML C14N.
const canonicalizerPrefixList = ""
   831  
   832  // MakeAssertionEl sets `AssertionEl` to a signed, possibly encrypted, version of `Assertion`.
   833  func (req *IdpAuthnRequest) MakeAssertionEl() error {
   834  	signingContext, err := req.signingContext()
   835  	if err != nil {
   836  		return err
   837  	}
   838  
   839  	assertionEl := req.Assertion.Element()
   840  
   841  	signedAssertionEl, err := signingContext.SignEnveloped(assertionEl)
   842  	if err != nil {
   843  		return err
   844  	}
   845  
   846  	sigEl := signedAssertionEl.Child[len(signedAssertionEl.Child)-1]
   847  	req.Assertion.Signature = sigEl.(*etree.Element)
   848  	signedAssertionEl = req.Assertion.Element()
   849  
   850  	certBuf, err := req.getSPEncryptionCert()
   851  	if err == os.ErrNotExist {
   852  		req.AssertionEl = signedAssertionEl
   853  		return nil
   854  	} else if err != nil {
   855  		return err
   856  	}
   857  
   858  	var signedAssertionBuf []byte
   859  	{
   860  		doc := etree.NewDocument()
   861  		doc.SetRoot(signedAssertionEl)
   862  		signedAssertionBuf, err = doc.WriteToBytes()
   863  		if err != nil {
   864  			return err
   865  		}
   866  	}
   867  
   868  	encryptor := xmlenc.OAEP()
   869  	encryptor.BlockCipher = xmlenc.AES128CBC
   870  	encryptor.DigestMethod = &xmlenc.SHA1
   871  	encryptedDataEl, err := encryptor.Encrypt(certBuf, signedAssertionBuf, nil)
   872  	if err != nil {
   873  		return err
   874  	}
   875  	encryptedDataEl.CreateAttr("Type", "http://www.w3.org/2001/04/xmlenc#Element")
   876  
   877  	encryptedAssertionEl := etree.NewElement("saml:EncryptedAssertion")
   878  	encryptedAssertionEl.AddChild(encryptedDataEl)
   879  	req.AssertionEl = encryptedAssertionEl
   880  
   881  	return nil
   882  }
   883  
// IdpAuthnRequestForm contains the HTML form information to be submitted to the
// service provider's ACS endpoint using the SAML HTTP POST binding.
type IdpAuthnRequestForm struct {
	URL          string // form action: the ACS endpoint location
	SAMLResponse string // base64 (standard encoding) of the serialized SAML Response
	RelayState   string // copied from the request's RelayState
}
   891  
   892  // PostBinding creates the HTTP POST form information for this
   893  // `IdpAuthnRequest`. If `Response` is not already set, it calls MakeResponse
   894  // to produce it.
   895  func (req *IdpAuthnRequest) PostBinding() (IdpAuthnRequestForm, error) {
   896  	var form IdpAuthnRequestForm
   897  
   898  	if req.ResponseEl == nil {
   899  		if err := req.MakeResponse(); err != nil {
   900  			return form, err
   901  		}
   902  	}
   903  
   904  	doc := etree.NewDocument()
   905  	doc.SetRoot(req.ResponseEl)
   906  	responseBuf, err := doc.WriteToBytes()
   907  	if err != nil {
   908  		return form, err
   909  	}
   910  
   911  	if req.ACSEndpoint.Binding != HTTPPostBinding {
   912  		return form, fmt.Errorf("%s: unsupported binding %s",
   913  			req.ServiceProviderMetadata.EntityID,
   914  			req.ACSEndpoint.Binding)
   915  	}
   916  
   917  	form.URL = req.ACSEndpoint.Location
   918  	form.SAMLResponse = base64.StdEncoding.EncodeToString(responseBuf)
   919  	form.RelayState = req.RelayState
   920  
   921  	return form, nil
   922  }
   923  
   924  // WriteResponse writes the `Response` to the http.ResponseWriter. If
   925  // `Response` is not already set, it calls MakeResponse to produce it.
   926  func (req *IdpAuthnRequest) WriteResponse(w http.ResponseWriter) error {
   927  	form, err := req.PostBinding()
   928  	if err != nil {
   929  		return err
   930  	}
   931  
   932  	tmpl := template.Must(template.New("saml-post-form").Parse(`<html>` +
   933  		`<form method="post" action="{{.URL}}" id="SAMLResponseForm">` +
   934  		`<input type="hidden" name="SAMLResponse" value="{{.SAMLResponse}}" />` +
   935  		`<input type="hidden" name="RelayState" value="{{.RelayState}}" />` +
   936  		`<input id="SAMLSubmitButton" type="submit" value="Continue" />` +
   937  		`</form>` +
   938  		`<script>document.getElementById('SAMLSubmitButton').style.visibility='hidden';</script>` +
   939  		`<script>document.getElementById('SAMLResponseForm').submit();</script>` +
   940  		`</html>`))
   941  
   942  	buf := bytes.NewBuffer(nil)
   943  	if err := tmpl.Execute(buf, form); err != nil {
   944  		return err
   945  	}
   946  	if _, err := io.Copy(w, buf); err != nil {
   947  		return err
   948  	}
   949  	return nil
   950  }
   951  
   952  // getSPEncryptionCert returns the certificate which we can use to encrypt things
   953  // to the SP in PEM format, or nil if no such certificate is found.
   954  func (req *IdpAuthnRequest) getSPEncryptionCert() (*x509.Certificate, error) {
   955  	certStr := ""
   956  	for _, keyDescriptor := range req.SPSSODescriptor.KeyDescriptors {
   957  		if keyDescriptor.Use == "encryption" {
   958  			certStr = keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data
   959  			break
   960  		}
   961  	}
   962  
   963  	// If there are no certs explicitly labeled for encryption, return the first
   964  	// non-empty cert we find.
   965  	if certStr == "" {
   966  		for _, keyDescriptor := range req.SPSSODescriptor.KeyDescriptors {
   967  			if keyDescriptor.Use == "" && len(keyDescriptor.KeyInfo.X509Data.X509Certificates) != 0 && keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data != "" {
   968  				certStr = keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data
   969  				break
   970  			}
   971  		}
   972  	}
   973  
   974  	if certStr == "" {
   975  		return nil, os.ErrNotExist
   976  	}
   977  
   978  	// cleanup whitespace and re-encode a PEM
   979  	certStr = regexp.MustCompile(`\s+`).ReplaceAllString(certStr, "")
   980  	certBytes, err := base64.StdEncoding.DecodeString(certStr)
   981  	if err != nil {
   982  		return nil, fmt.Errorf("cannot decode certificate base64: %v", err)
   983  	}
   984  	cert, err := x509.ParseCertificate(certBytes)
   985  	if err != nil {
   986  		return nil, fmt.Errorf("cannot parse certificate: %v", err)
   987  	}
   988  	return cert, nil
   989  }
   990  
   991  // unmarshalEtreeHack parses `el` and sets values in the structure `v`.
   992  //
   993  // This is a hack -- it first serializes the element, then uses xml.Unmarshal.
   994  func unmarshalEtreeHack(el *etree.Element, v interface{}) error {
   995  	doc := etree.NewDocument()
   996  	doc.SetRoot(el)
   997  	buf, err := doc.WriteToBytes()
   998  	if err != nil {
   999  		return err
  1000  	}
  1001  	return xml.Unmarshal(buf, v)
  1002  }
  1003  
  1004  // MakeResponse creates and assigns a new SAML response in ResponseEl. `Assertion` must
  1005  // be non-nil. If MakeAssertionEl() has not been called, this function calls it for
  1006  // you.
  1007  func (req *IdpAuthnRequest) MakeResponse() error {
  1008  	if req.AssertionEl == nil {
  1009  		if err := req.MakeAssertionEl(); err != nil {
  1010  			return err
  1011  		}
  1012  	}
  1013  
  1014  	response := &Response{
  1015  		Destination:  req.ACSEndpoint.Location,
  1016  		ID:           fmt.Sprintf("id-%x", randomBytes(20)),
  1017  		InResponseTo: req.Request.ID,
  1018  		IssueInstant: req.Now,
  1019  		Version:      "2.0",
  1020  		Issuer: &Issuer{
  1021  			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
  1022  			Value:  req.IDP.MetadataURL.String(),
  1023  		},
  1024  		Status: Status{
  1025  			StatusCode: StatusCode{
  1026  				Value: StatusSuccess,
  1027  			},
  1028  		},
  1029  	}
  1030  
  1031  	responseEl := response.Element()
  1032  	responseEl.AddChild(req.AssertionEl) // AssertionEl either an EncryptedAssertion or Assertion element
  1033  
  1034  	// Sign the response element (we've already signed the Assertion element)
  1035  	{
  1036  		signingContext, err := req.signingContext()
  1037  		if err != nil {
  1038  			return err
  1039  		}
  1040  
  1041  		signedResponseEl, err := signingContext.SignEnveloped(responseEl)
  1042  		if err != nil {
  1043  			return err
  1044  		}
  1045  
  1046  		sigEl := signedResponseEl.ChildElements()[len(signedResponseEl.ChildElements())-1]
  1047  		response.Signature = sigEl
  1048  		responseEl = response.Element()
  1049  		responseEl.AddChild(req.AssertionEl)
  1050  	}
  1051  
  1052  	req.ResponseEl = responseEl
  1053  	return nil
  1054  }
  1055  
  1056  // signingContext will create a signing context for the request.
  1057  func (req *IdpAuthnRequest) signingContext() (*dsig.SigningContext, error) {
  1058  	// Create a cert chain based off of the IDP cert and its intermediates.
  1059  	certificates := [][]byte{req.IDP.Certificate.Raw}
  1060  	for _, cert := range req.IDP.Intermediates {
  1061  		certificates = append(certificates, cert.Raw)
  1062  	}
  1063  
  1064  	var signingContext *dsig.SigningContext
  1065  	var err error
  1066  	// If signer is set, use it instead of the private key.
  1067  	if req.IDP.Signer != nil {
  1068  		signingContext, err = dsig.NewSigningContext(req.IDP.Signer, certificates)
  1069  		if err != nil {
  1070  			return nil, err
  1071  		}
  1072  	} else {
  1073  		keyPair := tls.Certificate{
  1074  			Certificate: certificates,
  1075  			PrivateKey:  req.IDP.Key,
  1076  			Leaf:        req.IDP.Certificate,
  1077  		}
  1078  		keyStore := dsig.TLSCertKeyStore(keyPair)
  1079  
  1080  		signingContext = dsig.NewDefaultSigningContext(keyStore)
  1081  	}
  1082  
  1083  	// Default to using SHA1 if the signature method isn't set.
  1084  	signatureMethod := req.IDP.SignatureMethod
  1085  	if signatureMethod == "" {
  1086  		signatureMethod = dsig.RSASHA1SignatureMethod
  1087  	}
  1088  
  1089  	signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
  1090  	if err := signingContext.SetSignatureMethod(signatureMethod); err != nil {
  1091  		return nil, err
  1092  	}
  1093  
  1094  	return signingContext, nil
  1095  }