github.com/crewjam/saml@v0.4.14/service_provider.go (about)

     1  package saml
     2  
     3  import (
     4  	"bytes"
     5  	"compress/flate"
     6  	"context"
     7  	"crypto/rsa"
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"encoding/base64"
    11  	"encoding/xml"
    12  	"errors"
    13  	"fmt"
    14  	"html/template"
    15  	"io"
    16  	"net/http"
    17  	"net/url"
    18  	"regexp"
    19  	"time"
    20  
    21  	"github.com/beevik/etree"
    22  	xrv "github.com/mattermost/xml-roundtrip-validator"
    23  	dsig "github.com/russellhaering/goxmldsig"
    24  	"github.com/russellhaering/goxmldsig/etreeutils"
    25  
    26  	"github.com/crewjam/saml/logger"
    27  	"github.com/crewjam/saml/xmlenc"
    28  )
    29  
    30  // NameIDFormat is the format of the id
    31  type NameIDFormat string
    32  
    33  // Element returns an XML element representation of n.
    34  func (n NameIDFormat) Element() *etree.Element {
    35  	el := etree.NewElement("")
    36  	el.SetText(string(n))
    37  	return el
    38  }
    39  
// Name ID formats. Each constant is the URN string that identifies one
// name identifier format; the SAML version embedded in the URN (1.1 or 2.0)
// is part of the identifier and differs between formats.
const (
	UnspecifiedNameIDFormat  NameIDFormat = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified"
	TransientNameIDFormat    NameIDFormat = "urn:oasis:names:tc:SAML:2.0:nameid-format:transient"
	EmailAddressNameIDFormat NameIDFormat = "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"
	PersistentNameIDFormat   NameIDFormat = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"
)
    47  
// SignatureVerifier verifies a signature.
//
// It can be implemented in order to override ServiceProvider's default
// way of verifying signatures; an implementation is installed through the
// ServiceProvider.SignatureVerifier field.
type SignatureVerifier interface {
	// VerifySignature checks the signature on el using validationContext
	// and returns a non-nil error when verification fails.
	VerifySignature(validationContext *dsig.ValidationContext, el *etree.Element) error
}
    55  
// ServiceProvider implements a SAML service provider.
//
// In SAML, service providers delegate responsibility for identifying
// clients to an identity provider. If you are writing an application
// that uses passwords (or whatever) stored somewhere else, then you
// are a service provider.
//
// See the example directory for an example of a web application using
// the service provider interface.
type ServiceProvider struct {
	// EntityID is optional - if it is empty, the string form of MetadataURL
	// is used as the entity ID in metadata and as the Issuer of requests.
	EntityID string

	// Key is the RSA private key we use to sign requests.
	Key *rsa.PrivateKey

	// Certificate is the RSA public part of Key. Intermediates are
	// additional certificates whose raw bytes are appended after
	// Certificate when the metadata key descriptors are built; they are
	// not currently added to the signing key store.
	Certificate   *x509.Certificate
	Intermediates []*x509.Certificate

	// HTTPClient is used during SAML artifact resolution. When nil,
	// http.DefaultClient is used.
	HTTPClient *http.Client

	// MetadataURL is the full URL to the metadata endpoint on this host,
	// i.e. https://example.com/saml/metadata
	MetadataURL url.URL

	// AcsURL is the full URL to the SAML Assertion Consumer Service endpoint
	// on this host, i.e. https://example.com/saml/acs
	AcsURL url.URL

	// SloURL is the full URL to the SAML Single Logout endpoint on this host.
	// i.e. https://example.com/saml/slo
	SloURL url.URL

	// IDPMetadata is the metadata from the identity provider. The binding
	// lookups and signing-certificate lookup in this file dereference it
	// without a nil check, so it must be set before they are called.
	IDPMetadata *EntityDescriptor

	// AuthnNameIDFormat is the format used in the NameIDPolicy for
	// authentication requests. It is also advertised as the only
	// NameIDFormat in the service provider metadata.
	AuthnNameIDFormat NameIDFormat

	// MetadataValidDuration is a duration used to calculate the validUntil
	// attribute in the metadata endpoint. Values that are not positive
	// fall back to DefaultValidDuration.
	MetadataValidDuration time.Duration

	// ForceAuthn allows you to force re-authentication of users even if the user
	// has a SSO session at the IdP. It is copied as-is into each
	// authentication request.
	ForceAuthn *bool

	// RequestedAuthnContext allows you to specify the requested authentication
	// context in authentication requests.
	RequestedAuthnContext *RequestedAuthnContext

	// AllowIDPInitiated presumably enables acceptance of responses that were
	// not solicited by a request from this service provider; the field is
	// not read in this part of the file - TODO confirm against its users.
	AllowIDPInitiated bool

	// DefaultRedirectURI where untracked requests (as of IDPInitiated) are redirected to
	DefaultRedirectURI string

	// SignatureVerifier, if non-nil, allows you to implement an alternative way
	// to verify signatures.
	SignatureVerifier SignatureVerifier

	// SignatureMethod, if non-empty, causes authentication requests to be
	// signed. GetSigningContext accepts only the RSA-SHA1, RSA-SHA256 and
	// RSA-SHA512 method identifiers from the goxmldsig package.
	SignatureMethod string

	// LogoutBindings specify the bindings available for SLO endpoint. If empty,
	// HTTP-POST binding is used.
	// NOTE(review): Metadata emits one SingleLogoutService endpoint per
	// entry and none when this is empty, so the HTTP-POST default is
	// presumably applied by whoever constructs the ServiceProvider - confirm.
	LogoutBindings []string
}
   127  
// MaxIssueDelay is the longest allowed time between when a SAML assertion is
// issued by the IDP and the time it is received by ParseResponse. This is used
// to prevent old responses from being replayed (while allowing for some clock
// drift between the SP and IDP). It is a package-level variable, so changing
// it affects every ServiceProvider in the process.
var MaxIssueDelay = time.Second * 90

// MaxClockSkew allows for leeway for clock skew between the IDP and SP when
// validating assertions. It defaults to 180 seconds (matches shibboleth).
var MaxClockSkew = time.Second * 180

// DefaultValidDuration is how long we assert that the SP metadata is valid
// when ServiceProvider.MetadataValidDuration is not set to a positive value.
const DefaultValidDuration = time.Hour * 24 * 2

// DefaultCacheDuration is how long we ask the IDP to cache the SP metadata.
const DefaultCacheDuration = time.Hour * 24 * 1
   143  
   144  // Metadata returns the service provider metadata
   145  func (sp *ServiceProvider) Metadata() *EntityDescriptor {
   146  	validDuration := DefaultValidDuration
   147  	if sp.MetadataValidDuration > 0 {
   148  		validDuration = sp.MetadataValidDuration
   149  	}
   150  
   151  	authnRequestsSigned := len(sp.SignatureMethod) > 0
   152  	wantAssertionsSigned := true
   153  	validUntil := TimeNow().Add(validDuration)
   154  
   155  	var keyDescriptors []KeyDescriptor
   156  	if sp.Certificate != nil {
   157  		certBytes := sp.Certificate.Raw
   158  		for _, intermediate := range sp.Intermediates {
   159  			certBytes = append(certBytes, intermediate.Raw...)
   160  		}
   161  		keyDescriptors = []KeyDescriptor{
   162  			{
   163  				Use: "encryption",
   164  				KeyInfo: KeyInfo{
   165  					X509Data: X509Data{
   166  						X509Certificates: []X509Certificate{
   167  							{Data: base64.StdEncoding.EncodeToString(certBytes)},
   168  						},
   169  					},
   170  				},
   171  				EncryptionMethods: []EncryptionMethod{
   172  					{Algorithm: "http://www.w3.org/2001/04/xmlenc#aes128-cbc"},
   173  					{Algorithm: "http://www.w3.org/2001/04/xmlenc#aes192-cbc"},
   174  					{Algorithm: "http://www.w3.org/2001/04/xmlenc#aes256-cbc"},
   175  					{Algorithm: "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p"},
   176  				},
   177  			},
   178  		}
   179  		if len(sp.SignatureMethod) > 0 {
   180  			keyDescriptors = append(keyDescriptors, KeyDescriptor{
   181  				Use: "signing",
   182  				KeyInfo: KeyInfo{
   183  					X509Data: X509Data{
   184  						X509Certificates: []X509Certificate{
   185  							{Data: base64.StdEncoding.EncodeToString(certBytes)},
   186  						},
   187  					},
   188  				},
   189  			})
   190  		}
   191  	}
   192  
   193  	sloEndpoints := make([]Endpoint, len(sp.LogoutBindings))
   194  	for i, binding := range sp.LogoutBindings {
   195  		sloEndpoints[i] = Endpoint{
   196  			Binding:          binding,
   197  			Location:         sp.SloURL.String(),
   198  			ResponseLocation: sp.SloURL.String(),
   199  		}
   200  	}
   201  
   202  	return &EntityDescriptor{
   203  		EntityID:   firstSet(sp.EntityID, sp.MetadataURL.String()),
   204  		ValidUntil: validUntil,
   205  
   206  		SPSSODescriptors: []SPSSODescriptor{
   207  			{
   208  				SSODescriptor: SSODescriptor{
   209  					RoleDescriptor: RoleDescriptor{
   210  						ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol",
   211  						KeyDescriptors:             keyDescriptors,
   212  						ValidUntil:                 &validUntil,
   213  					},
   214  					SingleLogoutServices: sloEndpoints,
   215  					NameIDFormats:        []NameIDFormat{sp.AuthnNameIDFormat},
   216  				},
   217  				AuthnRequestsSigned:  &authnRequestsSigned,
   218  				WantAssertionsSigned: &wantAssertionsSigned,
   219  
   220  				AssertionConsumerServices: []IndexedEndpoint{
   221  					{
   222  						Binding:  HTTPPostBinding,
   223  						Location: sp.AcsURL.String(),
   224  						Index:    1,
   225  					},
   226  					{
   227  						Binding:  HTTPArtifactBinding,
   228  						Location: sp.AcsURL.String(),
   229  						Index:    2,
   230  					},
   231  				},
   232  			},
   233  		},
   234  	}
   235  }
   236  
   237  // MakeRedirectAuthenticationRequest creates a SAML authentication request using
   238  // the HTTP-Redirect binding. It returns a URL that we will redirect the user to
   239  // in order to start the auth process.
   240  func (sp *ServiceProvider) MakeRedirectAuthenticationRequest(relayState string) (*url.URL, error) {
   241  	req, err := sp.MakeAuthenticationRequest(sp.GetSSOBindingLocation(HTTPRedirectBinding), HTTPRedirectBinding, HTTPPostBinding)
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  	return req.Redirect(relayState, sp)
   246  }
   247  
   248  // Redirect returns a URL suitable for using the redirect binding with the request
   249  func (r *AuthnRequest) Redirect(relayState string, sp *ServiceProvider) (*url.URL, error) {
   250  	w := &bytes.Buffer{}
   251  	w1 := base64.NewEncoder(base64.StdEncoding, w)
   252  	w2, _ := flate.NewWriter(w1, 9)
   253  	doc := etree.NewDocument()
   254  	doc.SetRoot(r.Element())
   255  	if _, err := doc.WriteTo(w2); err != nil {
   256  		panic(err)
   257  	}
   258  	if err := w2.Close(); err != nil {
   259  		panic(err)
   260  	}
   261  	if err := w1.Close(); err != nil {
   262  		panic(err)
   263  	}
   264  
   265  	rv, _ := url.Parse(r.Destination)
   266  	// We can't depend on Query().set() as order matters for signing
   267  	query := rv.RawQuery
   268  	if len(query) > 0 {
   269  		query += "&SAMLRequest=" + url.QueryEscape(w.String())
   270  	} else {
   271  		query += "SAMLRequest=" + url.QueryEscape(w.String())
   272  	}
   273  
   274  	if relayState != "" {
   275  		query += "&RelayState=" + relayState
   276  	}
   277  	if len(sp.SignatureMethod) > 0 {
   278  		query += "&SigAlg=" + url.QueryEscape(sp.SignatureMethod)
   279  		signingContext, err := GetSigningContext(sp)
   280  
   281  		if err != nil {
   282  			return nil, err
   283  		}
   284  
   285  		sig, err := signingContext.SignString(query)
   286  		if err != nil {
   287  			return nil, err
   288  		}
   289  		query += "&Signature=" + url.QueryEscape(base64.StdEncoding.EncodeToString(sig))
   290  	}
   291  
   292  	rv.RawQuery = query
   293  
   294  	return rv, nil
   295  }
   296  
   297  // GetSSOBindingLocation returns URL for the IDP's Single Sign On Service binding
   298  // of the specified type (HTTPRedirectBinding or HTTPPostBinding)
   299  func (sp *ServiceProvider) GetSSOBindingLocation(binding string) string {
   300  	for _, idpSSODescriptor := range sp.IDPMetadata.IDPSSODescriptors {
   301  		for _, singleSignOnService := range idpSSODescriptor.SingleSignOnServices {
   302  			if singleSignOnService.Binding == binding {
   303  				return singleSignOnService.Location
   304  			}
   305  		}
   306  	}
   307  	return ""
   308  }
   309  
   310  // GetArtifactBindingLocation returns URL for the IDP's Artifact binding of the
   311  // specified type
   312  func (sp *ServiceProvider) GetArtifactBindingLocation(binding string) string {
   313  	for _, idpSSODescriptor := range sp.IDPMetadata.IDPSSODescriptors {
   314  		for _, artifactResolutionService := range idpSSODescriptor.ArtifactResolutionServices {
   315  			if artifactResolutionService.Binding == binding {
   316  				return artifactResolutionService.Location
   317  			}
   318  		}
   319  	}
   320  	return ""
   321  }
   322  
   323  // GetSLOBindingLocation returns URL for the IDP's Single Log Out Service binding
   324  // of the specified type (HTTPRedirectBinding or HTTPPostBinding)
   325  func (sp *ServiceProvider) GetSLOBindingLocation(binding string) string {
   326  	for _, idpSSODescriptor := range sp.IDPMetadata.IDPSSODescriptors {
   327  		for _, singleLogoutService := range idpSSODescriptor.SingleLogoutServices {
   328  			if singleLogoutService.Binding == binding {
   329  				return singleLogoutService.Location
   330  			}
   331  		}
   332  	}
   333  	return ""
   334  }
   335  
   336  // getIDPSigningCerts returns the certificates which we can use to verify things
   337  // signed by the IDP in PEM format, or nil if no such certificate is found.
   338  func (sp *ServiceProvider) getIDPSigningCerts() ([]*x509.Certificate, error) {
   339  	var certStrs []string
   340  
   341  	// We need to include non-empty certs where the "use" attribute is
   342  	// either set to "signing" or is missing
   343  	for _, idpSSODescriptor := range sp.IDPMetadata.IDPSSODescriptors {
   344  		for _, keyDescriptor := range idpSSODescriptor.KeyDescriptors {
   345  			if len(keyDescriptor.KeyInfo.X509Data.X509Certificates) != 0 {
   346  				switch keyDescriptor.Use {
   347  				case "", "signing":
   348  					for _, certificate := range keyDescriptor.KeyInfo.X509Data.X509Certificates {
   349  						certStrs = append(certStrs, certificate.Data)
   350  					}
   351  				}
   352  			}
   353  		}
   354  	}
   355  
   356  	if len(certStrs) == 0 {
   357  		return nil, errors.New("cannot find any signing certificate in the IDP SSO descriptor")
   358  	}
   359  
   360  	certs := make([]*x509.Certificate, len(certStrs))
   361  
   362  	// cleanup whitespace
   363  	regex := regexp.MustCompile(`\s+`)
   364  	for i, certStr := range certStrs {
   365  		certStr = regex.ReplaceAllString(certStr, "")
   366  		certBytes, err := base64.StdEncoding.DecodeString(certStr)
   367  		if err != nil {
   368  			return nil, fmt.Errorf("cannot parse certificate: %s", err)
   369  		}
   370  
   371  		parsedCert, err := x509.ParseCertificate(certBytes)
   372  		if err != nil {
   373  			return nil, err
   374  		}
   375  		certs[i] = parsedCert
   376  	}
   377  
   378  	return certs, nil
   379  }
   380  
   381  // MakeArtifactResolveRequest produces a new ArtifactResolve object to send to the idp's Artifact resolver
   382  func (sp *ServiceProvider) MakeArtifactResolveRequest(artifactID string) (*ArtifactResolve, error) {
   383  	req := ArtifactResolve{
   384  		ID:           fmt.Sprintf("id-%x", randomBytes(20)),
   385  		IssueInstant: TimeNow(),
   386  		Version:      "2.0",
   387  		Issuer: &Issuer{
   388  			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
   389  			Value:  firstSet(sp.EntityID, sp.MetadataURL.String()),
   390  		},
   391  		Artifact: artifactID,
   392  	}
   393  
   394  	if len(sp.SignatureMethod) > 0 {
   395  		if err := sp.SignArtifactResolve(&req); err != nil {
   396  			return nil, err
   397  		}
   398  	}
   399  
   400  	return &req, nil
   401  }
   402  
   403  // MakeAuthenticationRequest produces a new AuthnRequest object to send to the idpURL
   404  // that uses the specified binding (HTTPRedirectBinding or HTTPPostBinding)
   405  func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string, binding string, resultBinding string) (*AuthnRequest, error) {
   406  
   407  	allowCreate := true
   408  	nameIDFormat := sp.nameIDFormat()
   409  	req := AuthnRequest{
   410  		AssertionConsumerServiceURL: sp.AcsURL.String(),
   411  		Destination:                 idpURL,
   412  		ProtocolBinding:             resultBinding, // default binding for the response
   413  		ID:                          fmt.Sprintf("id-%x", randomBytes(20)),
   414  		IssueInstant:                TimeNow(),
   415  		Version:                     "2.0",
   416  		Issuer: &Issuer{
   417  			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
   418  			Value:  firstSet(sp.EntityID, sp.MetadataURL.String()),
   419  		},
   420  		NameIDPolicy: &NameIDPolicy{
   421  			AllowCreate: &allowCreate,
   422  			// TODO(ross): figure out exactly policy we need
   423  			// urn:mace:shibboleth:1.0:nameIdentifier
   424  			// urn:oasis:names:tc:SAML:2.0:nameid-format:transient
   425  			Format: &nameIDFormat,
   426  		},
   427  		ForceAuthn:            sp.ForceAuthn,
   428  		RequestedAuthnContext: sp.RequestedAuthnContext,
   429  	}
   430  	// We don't need to sign the XML document if the IDP uses HTTP-Redirect binding
   431  	if len(sp.SignatureMethod) > 0 && binding == HTTPPostBinding {
   432  		if err := sp.SignAuthnRequest(&req); err != nil {
   433  			return nil, err
   434  		}
   435  	}
   436  	return &req, nil
   437  }
   438  
   439  // GetSigningContext returns a dsig.SigningContext initialized based on the Service Provider's configuration
   440  func GetSigningContext(sp *ServiceProvider) (*dsig.SigningContext, error) {
   441  	keyPair := tls.Certificate{
   442  		Certificate: [][]byte{sp.Certificate.Raw},
   443  		PrivateKey:  sp.Key,
   444  		Leaf:        sp.Certificate,
   445  	}
   446  	// TODO: add intermediates for SP
   447  	// for _, cert := range sp.Intermediates {
   448  	// 	keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
   449  	// }
   450  	keyStore := dsig.TLSCertKeyStore(keyPair)
   451  
   452  	if sp.SignatureMethod != dsig.RSASHA1SignatureMethod &&
   453  		sp.SignatureMethod != dsig.RSASHA256SignatureMethod &&
   454  		sp.SignatureMethod != dsig.RSASHA512SignatureMethod {
   455  		return nil, fmt.Errorf("invalid signing method %s", sp.SignatureMethod)
   456  	}
   457  	signatureMethod := sp.SignatureMethod
   458  	signingContext := dsig.NewDefaultSigningContext(keyStore)
   459  	signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
   460  	if err := signingContext.SetSignatureMethod(signatureMethod); err != nil {
   461  		return nil, err
   462  	}
   463  
   464  	return signingContext, nil
   465  }
   466  
   467  // SignArtifactResolve adds the `Signature` element to the `ArtifactResolve`.
   468  func (sp *ServiceProvider) SignArtifactResolve(req *ArtifactResolve) error {
   469  	signingContext, err := GetSigningContext(sp)
   470  	if err != nil {
   471  		return err
   472  	}
   473  	assertionEl := req.Element()
   474  
   475  	signedRequestEl, err := signingContext.SignEnveloped(assertionEl)
   476  	if err != nil {
   477  		return err
   478  	}
   479  
   480  	sigEl := signedRequestEl.Child[len(signedRequestEl.Child)-1]
   481  	req.Signature = sigEl.(*etree.Element)
   482  	return nil
   483  }
   484  
   485  // SignAuthnRequest adds the `Signature` element to the `AuthnRequest`.
   486  func (sp *ServiceProvider) SignAuthnRequest(req *AuthnRequest) error {
   487  
   488  	signingContext, err := GetSigningContext(sp)
   489  	if err != nil {
   490  		return err
   491  	}
   492  	assertionEl := req.Element()
   493  
   494  	signedRequestEl, err := signingContext.SignEnveloped(assertionEl)
   495  	if err != nil {
   496  		return err
   497  	}
   498  
   499  	sigEl := signedRequestEl.Child[len(signedRequestEl.Child)-1]
   500  	req.Signature = sigEl.(*etree.Element)
   501  	return nil
   502  }
   503  
   504  // MakePostAuthenticationRequest creates a SAML authentication request using
   505  // the HTTP-POST binding. It returns HTML text representing an HTML form that
   506  // can be sent presented to a browser to initiate the login process.
   507  func (sp *ServiceProvider) MakePostAuthenticationRequest(relayState string) ([]byte, error) {
   508  	req, err := sp.MakeAuthenticationRequest(sp.GetSSOBindingLocation(HTTPPostBinding), HTTPPostBinding, HTTPPostBinding)
   509  	if err != nil {
   510  		return nil, err
   511  	}
   512  	return req.Post(relayState), nil
   513  }
   514  
   515  // Post returns an HTML form suitable for using the HTTP-POST binding with the request
   516  func (r *AuthnRequest) Post(relayState string) []byte {
   517  	doc := etree.NewDocument()
   518  	doc.SetRoot(r.Element())
   519  	reqBuf, err := doc.WriteToBytes()
   520  	if err != nil {
   521  		panic(err)
   522  	}
   523  	encodedReqBuf := base64.StdEncoding.EncodeToString(reqBuf)
   524  
   525  	tmpl := template.Must(template.New("saml-post-form").Parse(`` +
   526  		`<form method="post" action="{{.URL}}" id="SAMLRequestForm">` +
   527  		`<input type="hidden" name="SAMLRequest" value="{{.SAMLRequest}}" />` +
   528  		`<input type="hidden" name="RelayState" value="{{.RelayState}}" />` +
   529  		`<input id="SAMLSubmitButton" type="submit" value="Submit" />` +
   530  		`</form>` +
   531  		`<script>document.getElementById('SAMLSubmitButton').style.visibility="hidden";` +
   532  		`document.getElementById('SAMLRequestForm').submit();</script>`))
   533  	data := struct {
   534  		URL         string
   535  		SAMLRequest string
   536  		RelayState  string
   537  	}{
   538  		URL:         r.Destination,
   539  		SAMLRequest: encodedReqBuf,
   540  		RelayState:  relayState,
   541  	}
   542  
   543  	rv := bytes.Buffer{}
   544  	if err := tmpl.Execute(&rv, data); err != nil {
   545  		panic(err)
   546  	}
   547  
   548  	return rv.Bytes()
   549  }
   550  
   551  // AssertionAttributes is a list of AssertionAttribute
   552  type AssertionAttributes []AssertionAttribute
   553  
   554  // Get returns the assertion attribute whose Name or FriendlyName
   555  // matches name, or nil if no matching attribute is found.
   556  func (aa AssertionAttributes) Get(name string) *AssertionAttribute {
   557  	for _, attr := range aa {
   558  		if attr.Name == name {
   559  			return &attr
   560  		}
   561  		if attr.FriendlyName == name {
   562  			return &attr
   563  		}
   564  	}
   565  	return nil
   566  }
   567  
   568  // AssertionAttribute represents an attribute of the user extracted from
   569  // a SAML Assertion.
   570  type AssertionAttribute struct {
   571  	FriendlyName string
   572  	Name         string
   573  	Value        string
   574  }
   575  
   576  // InvalidResponseError is the error produced by ParseResponse when it fails.
   577  // The underlying error is in PrivateErr. Response is the response as it was
   578  // known at the time validation failed. Now is the time that was used to validate
   579  // time-dependent parts of the assertion.
   580  type InvalidResponseError struct {
   581  	PrivateErr error
   582  	Response   string
   583  	Now        time.Time
   584  }
   585  
   586  func (ivr *InvalidResponseError) Error() string {
   587  	return "Authentication failed"
   588  }
   589  
   590  // ErrBadStatus is returned when the assertion provided is valid but the
   591  // status code is not "urn:oasis:names:tc:SAML:2.0:status:Success".
   592  type ErrBadStatus struct {
   593  	Status string
   594  }
   595  
   596  func (e ErrBadStatus) Error() string {
   597  	return e.Status
   598  }
   599  
   600  // ParseResponse extracts the SAML IDP response received in req, resolves
   601  // artifacts when necessary, validates it, and returns the verified assertion.
   602  func (sp *ServiceProvider) ParseResponse(req *http.Request, possibleRequestIDs []string) (*Assertion, error) {
   603  	if artifactID := req.Form.Get("SAMLart"); artifactID != "" {
   604  		return sp.handleArtifactRequest(req.Context(), artifactID, possibleRequestIDs)
   605  	}
   606  	return sp.parseResponseHTTP(req, possibleRequestIDs)
   607  }
   608  
   609  func (sp *ServiceProvider) handleArtifactRequest(ctx context.Context, artifactID string, possibleRequestIDs []string) (*Assertion, error) {
   610  	retErr := &InvalidResponseError{Now: TimeNow()}
   611  
   612  	artifactResolveRequest, err := sp.MakeArtifactResolveRequest(artifactID)
   613  	if err != nil {
   614  		retErr.PrivateErr = fmt.Errorf("cannot generate artifact resolution request: %s", err)
   615  		return nil, retErr
   616  	}
   617  
   618  	requestBody, err := elementToBytes(artifactResolveRequest.SoapRequest())
   619  	if err != nil {
   620  		retErr.PrivateErr = err
   621  		return nil, retErr
   622  	}
   623  
   624  	req, err := http.NewRequestWithContext(ctx, "POST", sp.GetArtifactBindingLocation(SOAPBinding),
   625  		bytes.NewReader(requestBody))
   626  	if err != nil {
   627  		retErr.PrivateErr = err
   628  		return nil, retErr
   629  	}
   630  	req.Header.Set("Content-Type", "text/xml")
   631  
   632  	httpClient := sp.HTTPClient
   633  	if httpClient == nil {
   634  		httpClient = http.DefaultClient
   635  	}
   636  	response, err := httpClient.Do(req)
   637  	if err != nil {
   638  		retErr.PrivateErr = fmt.Errorf("cannot resolve artifact: %s", err)
   639  		return nil, retErr
   640  	}
   641  	defer func() {
   642  		if err := response.Body.Close(); err != nil {
   643  			logger.DefaultLogger.Printf("Error while closing response body during artifact resolution: %v", err)
   644  		}
   645  	}()
   646  	if response.StatusCode != 200 {
   647  		retErr.PrivateErr = fmt.Errorf("Error during artifact resolution: HTTP status %d (%s)", response.StatusCode, response.Status)
   648  		return nil, retErr
   649  	}
   650  	responseBody, err := io.ReadAll(response.Body)
   651  	if err != nil {
   652  		retErr.PrivateErr = fmt.Errorf("Error during artifact resolution: %s", err)
   653  		return nil, retErr
   654  	}
   655  	assertion, err := sp.ParseXMLArtifactResponse(responseBody, possibleRequestIDs, artifactResolveRequest.ID)
   656  	if err != nil {
   657  		return nil, err
   658  	}
   659  	return assertion, nil
   660  }
   661  
   662  func (sp *ServiceProvider) parseResponseHTTP(req *http.Request, possibleRequestIDs []string) (*Assertion, error) {
   663  	retErr := &InvalidResponseError{
   664  		Now: TimeNow(),
   665  	}
   666  
   667  	rawResponseBuf, err := base64.StdEncoding.DecodeString(req.PostForm.Get("SAMLResponse"))
   668  	if err != nil {
   669  		retErr.PrivateErr = fmt.Errorf("cannot parse base64: %s", err)
   670  		return nil, retErr
   671  	}
   672  
   673  	assertion, err := sp.ParseXMLResponse(rawResponseBuf, possibleRequestIDs)
   674  	if err != nil {
   675  		return nil, err
   676  	}
   677  	return assertion, nil
   678  }
   679  
   680  // ParseXMLArtifactResponse validates the SAML Artifact resolver response
   681  // and returns the verified assertion.
   682  //
   683  // This function handles verifying the digital signature, and verifying
   684  // that the specified conditions and properties are met.
   685  //
   686  // If the function fails it will return an InvalidResponseError whose
   687  // properties are useful in describing which part of the parsing process
   688  // failed. However, to discourage inadvertent disclosure the diagnostic
   689  // information, the Error() method returns a static string.
   690  func (sp *ServiceProvider) ParseXMLArtifactResponse(soapResponseXML []byte, possibleRequestIDs []string, artifactRequestID string) (*Assertion, error) {
   691  	now := TimeNow()
   692  	retErr := &InvalidResponseError{
   693  		Response: string(soapResponseXML),
   694  		Now:      now,
   695  	}
   696  
   697  	// ensure that the response XML is well-formed before we parse it
   698  	if err := xrv.Validate(bytes.NewReader(soapResponseXML)); err != nil {
   699  		retErr.PrivateErr = fmt.Errorf("invalid xml: %s", err)
   700  		return nil, retErr
   701  	}
   702  
   703  	doc := etree.NewDocument()
   704  	if err := doc.ReadFromBytes(soapResponseXML); err != nil {
   705  		retErr.PrivateErr = fmt.Errorf("cannot unmarshal response: %s", err)
   706  		return nil, retErr
   707  	}
   708  	if doc.Root() == nil {
   709  		retErr.PrivateErr = errors.New("invalid xml: no root")
   710  		return nil, retErr
   711  	}
   712  	if doc.Root().NamespaceURI() != "http://schemas.xmlsoap.org/soap/envelope/" ||
   713  		doc.Root().Tag != "Envelope" {
   714  		retErr.PrivateErr = fmt.Errorf("expected a SOAP Envelope")
   715  		return nil, retErr
   716  	}
   717  
   718  	soapBodyEl, err := findOneChild(doc.Root(), "http://schemas.xmlsoap.org/soap/envelope/", "Body")
   719  	if err != nil {
   720  		retErr.PrivateErr = err
   721  		return nil, retErr
   722  	}
   723  
   724  	artifactResponseEl, err := findOneChild(soapBodyEl, "urn:oasis:names:tc:SAML:2.0:protocol", "ArtifactResponse")
   725  	if err != nil {
   726  		retErr.PrivateErr = err
   727  		return nil, retErr
   728  	}
   729  
   730  	return sp.parseArtifactResponse(artifactResponseEl, possibleRequestIDs, artifactRequestID, now)
   731  }
   732  
// parseArtifactResponse validates an ArtifactResponse element and the
// Response it wraps, returning the assertion produced by parseResponse.
//
// artifactRequestID is the ID of the ArtifactResolve request this response
// must answer, possibleRequestIDs is forwarded to parseResponse, and now is
// the reference time for the freshness check. Every failure is returned as
// an *InvalidResponseError carrying the serialized element and the cause.
func (sp *ServiceProvider) parseArtifactResponse(artifactResponseEl *etree.Element, possibleRequestIDs []string, artifactRequestID string, now time.Time) (*Assertion, error) {
	retErr := &InvalidResponseError{
		Now:      now,
		Response: elementToString(artifactResponseEl),
	}

	// Envelope-level checks. The decoded struct is scoped to this block
	// because only the raw element is needed afterwards.
	{
		var artifactResponse ArtifactResponse
		if err := unmarshalElement(artifactResponseEl, &artifactResponse); err != nil {
			retErr.PrivateErr = err
			return nil, retErr
		}
		// The response must answer the request we just sent.
		if artifactResponse.InResponseTo != artifactRequestID {
			retErr.PrivateErr = fmt.Errorf("`InResponseTo` does not match the artifact request ID (expected %s)", artifactRequestID)
			return nil, retErr
		}
		// Reject responses issued more than MaxIssueDelay before now.
		if artifactResponse.IssueInstant.Add(MaxIssueDelay).Before(now) {
			retErr.PrivateErr = fmt.Errorf("response IssueInstant expired at %s", artifactResponse.IssueInstant.Add(MaxIssueDelay))
			return nil, retErr
		}
		// The Issuer is only compared when present; a response without an
		// Issuer element passes this check.
		if artifactResponse.Issuer != nil && artifactResponse.Issuer.Value != sp.IDPMetadata.EntityID {
			retErr.PrivateErr = fmt.Errorf("response Issuer does not match the IDP metadata (expected %q)", sp.IDPMetadata.EntityID)
			return nil, retErr
		}
		if artifactResponse.Status.StatusCode.Value != StatusSuccess {
			retErr.PrivateErr = ErrBadStatus{Status: artifactResponse.Status.StatusCode.Value}
			return nil, retErr
		}
	}

	// Decide what to demand of the inner Response based on the outer
	// signature: a valid signature here relaxes the requirement, an absent
	// signature element keeps it, and any other error is fatal.
	// NOTE(review): how parseResponse enforces this value is outside this
	// block - confirm there.
	var signatureRequirement signatureRequirement
	sigErr := sp.validateSignature(artifactResponseEl)
	switch sigErr {
	case nil:
		signatureRequirement = signatureNotRequired
	case errSignatureElementNotPresent:
		signatureRequirement = signatureRequired
	default:
		retErr.PrivateErr = sigErr
		return nil, retErr
	}

	responseEl, err := findOneChild(artifactResponseEl, "urn:oasis:names:tc:SAML:2.0:protocol", "Response")
	if err != nil {
		retErr.PrivateErr = err
		return nil, retErr
	}

	assertion, err := sp.parseResponse(responseEl, possibleRequestIDs, now, signatureRequirement)
	if err != nil {
		retErr.PrivateErr = err
		return nil, retErr
	}

	return assertion, nil
}
   789  
   790  // ParseXMLResponse parses and validates the SAML IDP response and
   791  // returns the verified assertion.
   792  //
   793  // This function handles decrypting the message, verifying the digital
   794  // signature on the assertion, and verifying that the specified conditions
   795  // and properties are met.
   796  //
   797  // If the function fails it will return an InvalidResponseError whose
   798  // properties are useful in describing which part of the parsing process
   799  // failed. However, to discourage inadvertent disclosure the diagnostic
   800  // information, the Error() method returns a static string.
   801  func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleRequestIDs []string) (*Assertion, error) {
   802  	now := TimeNow()
   803  	var err error
   804  	retErr := &InvalidResponseError{
   805  		Now:      now,
   806  		Response: string(decodedResponseXML),
   807  	}
   808  
   809  	// ensure that the response XML is well-formed before we parse it
   810  	if err := xrv.Validate(bytes.NewReader(decodedResponseXML)); err != nil {
   811  		retErr.PrivateErr = fmt.Errorf("invalid xml: %s", err)
   812  		return nil, retErr
   813  	}
   814  
   815  	doc := etree.NewDocument()
   816  	if err := doc.ReadFromBytes(decodedResponseXML); err != nil {
   817  		retErr.PrivateErr = err
   818  		return nil, retErr
   819  	}
   820  	if doc.Root() == nil {
   821  		retErr.PrivateErr = errors.New("invalid xml: no root")
   822  		return nil, retErr
   823  	}
   824  
   825  	assertion, err := sp.parseResponse(doc.Root(), possibleRequestIDs, now, signatureRequired)
   826  	if err != nil {
   827  		retErr.PrivateErr = err
   828  		return nil, retErr
   829  	}
   830  
   831  	return assertion, nil
   832  }
   833  
   834  type signatureRequirement int
   835  
   836  const (
   837  	signatureRequired signatureRequirement = iota
   838  	signatureNotRequired
   839  )
   840  
// parseResponse validates a SAML Response element and returns the first
// assertion inside it that passes validation.
//
// The steps run in this order:
//
//  1. When signatureRequirement is signatureRequired, the signature on the
//     Response element itself is checked, but the result is only recorded.
//  2. The Response is unmarshalled and its Destination, InResponseTo,
//     IssueInstant, Issuer and Status are checked.
//  3. The recorded signature result is acted upon: a valid Response signature
//     lifts the signature requirement from the assertions, a missing one keeps
//     it, and any other result is returned as the error.
//  4. Every EncryptedAssertion child and then every Assertion child is parsed
//     with the resulting requirement. Failures are collected rather than
//     returned immediately.
//
// If no assertion validates, the first collected error is returned, or a
// generic error when the Response contained no assertions at all.
//
// now is the instant that IssueInstant and the assertion time bounds are
// compared against. possibleRequestIDs lists the request IDs that InResponseTo
// may match; it is ignored for the Response when sp.AllowIDPInitiated is set.
func (sp *ServiceProvider) parseResponse(responseEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
	var responseSignatureErr error
	var responseHasSignature bool
	if signatureRequirement == signatureRequired {
		responseSignatureErr = sp.validateSignature(responseEl)
		// Any outcome other than "no Signature element" -- including a
		// validation failure -- is treated as the Response having a signature.
		if responseSignatureErr != errSignatureElementNotPresent {
			responseHasSignature = true
		}

		// Note: we're deferring taking action on the signature validation until after we've
		// processed the response attributes, because certain test cases seem to require this mis-feature.
		// TODO(ross): adjust the test cases so that we can abort here if the Response signature is invalid.
	}

	// validate response attributes
	{
		var response Response
		if err := unmarshalElement(responseEl, &response); err != nil {
			return nil, fmt.Errorf("cannot unmarshal response: %v", err)
		}

		// If the response is *not* signed, the Destination may be omitted.
		// A Destination that is present is always checked.
		if responseHasSignature || response.Destination != "" {
			if response.Destination != sp.AcsURL.String() {
				return nil, fmt.Errorf("`Destination` does not match AcsURL (expected %q, actual %q)", sp.AcsURL.String(), response.Destination)
			}
		}

		// InResponseTo must equal one of the IDs we issued, unless
		// IDP-initiated responses are allowed, in which case it is not checked.
		requestIDvalid := false
		if sp.AllowIDPInitiated {
			requestIDvalid = true
		} else {
			for _, possibleRequestID := range possibleRequestIDs {
				if response.InResponseTo == possibleRequestID {
					requestIDvalid = true
				}
			}
		}
		if !requestIDvalid {
			return nil, fmt.Errorf("`InResponseTo` does not match any of the possible request IDs (expected %v)", possibleRequestIDs)
		}

		// Reject responses issued more than MaxIssueDelay before now.
		if response.IssueInstant.Add(MaxIssueDelay).Before(now) {
			return nil, fmt.Errorf("response IssueInstant expired at %s", response.IssueInstant.Add(MaxIssueDelay))
		}
		// The Issuer is only compared when the Response carries one.
		if response.Issuer != nil && response.Issuer.Value != sp.IDPMetadata.EntityID {
			return nil, fmt.Errorf("response Issuer does not match the IDP metadata (expected %q)", sp.IDPMetadata.EntityID)
		}
		if response.Status.StatusCode.Value != StatusSuccess {
			return nil, ErrBadStatus{Status: response.Status.StatusCode.Value}
		}
	}

	// Now act on the Response signature result recorded at the top.
	if signatureRequirement == signatureRequired {
		switch responseSignatureErr {
		case nil:
			// since the response has a valid signature, none of the Assertions need one
			signatureRequirement = signatureNotRequired
		case errSignatureElementNotPresent:
			// the response has no signature, so assertions must be signed
			signatureRequirement = signatureRequired // nop
		default:
			// a signature was present but could not be validated
			return nil, responseSignatureErr
		}
	}

	// errs collects per-assertion failures; assertions collects the ones that validated.
	var errs []error
	var assertions []Assertion

	// look for encrypted assertions
	{
		encryptedAssertionEls, err := findChildren(responseEl, "urn:oasis:names:tc:SAML:2.0:assertion", "EncryptedAssertion")
		if err != nil {
			return nil, err
		}
		for _, encryptedAssertionEl := range encryptedAssertionEls {
			assertion, err := sp.parseEncryptedAssertion(encryptedAssertionEl, possibleRequestIDs, now, signatureRequirement)
			if err != nil {
				errs = append(errs, err)
				continue
			}
			assertions = append(assertions, *assertion)
		}
	}

	// look for plaintext assertions
	{
		assertionEls, err := findChildren(responseEl, "urn:oasis:names:tc:SAML:2.0:assertion", "Assertion")
		if err != nil {
			return nil, err
		}
		for _, assertionEl := range assertionEls {
			assertion, err := sp.parseAssertion(assertionEl, possibleRequestIDs, now, signatureRequirement)
			if err != nil {
				errs = append(errs, err)
				continue
			}
			assertions = append(assertions, *assertion)
		}
	}

	if len(assertions) == 0 {
		// Report the first failure if there was one; otherwise the Response
		// simply had no assertions.
		if len(errs) > 0 {
			return nil, errs[0]
		}
		return nil, fmt.Errorf("expected at least one valid Assertion, none found")
	}

	// If we have at least one assertion, return the first one. It is almost universally true that valid responses
	// contain only one assertion. This is less than fully correct, but we didn't realize that there could be more
	// than one assertion at the time of establishing the public interface of ParseXMLResponse(), so for compatibility
	// we return the first one.
	return &assertions[0], nil
}
   961  
   962  func (sp *ServiceProvider) parseEncryptedAssertion(encryptedAssertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
   963  	assertionEl, err := sp.decryptElement(encryptedAssertionEl)
   964  	if err != nil {
   965  		return nil, fmt.Errorf("failed to decrypt EncryptedAssertion: %v", err)
   966  	}
   967  	return sp.parseAssertion(assertionEl, possibleRequestIDs, now, signatureRequirement)
   968  }
   969  
   970  func (sp *ServiceProvider) decryptElement(encryptedEl *etree.Element) (*etree.Element, error) {
   971  	encryptedDataEl, err := findOneChild(encryptedEl, "http://www.w3.org/2001/04/xmlenc#", "EncryptedData")
   972  	if err != nil {
   973  		return nil, err
   974  	}
   975  
   976  	var key interface{} = sp.Key
   977  	keyEl := encryptedEl.FindElement("./EncryptedKey")
   978  	if keyEl != nil {
   979  		var err error
   980  		key, err = xmlenc.Decrypt(sp.Key, keyEl)
   981  		if err != nil {
   982  			return nil, fmt.Errorf("failed to decrypt key from response: %s", err)
   983  		}
   984  	}
   985  
   986  	plaintextEl, err := xmlenc.Decrypt(key, encryptedDataEl)
   987  	if err != nil {
   988  		return nil, err
   989  	}
   990  
   991  	if err := xrv.Validate(bytes.NewReader(plaintextEl)); err != nil {
   992  		return nil, fmt.Errorf("plaintext response contains invalid XML: %s", err)
   993  	}
   994  
   995  	doc := etree.NewDocument()
   996  	if err := doc.ReadFromBytes(plaintextEl); err != nil {
   997  		return nil, fmt.Errorf("cannot parse plaintext response %v", err)
   998  	}
   999  	return doc.Root(), nil
  1000  }
  1001  
  1002  func (sp *ServiceProvider) parseAssertion(assertionEl *etree.Element, possibleRequestIDs []string, now time.Time, signatureRequirement signatureRequirement) (*Assertion, error) {
  1003  	if signatureRequirement == signatureRequired {
  1004  		sigErr := sp.validateSignature(assertionEl)
  1005  		if sigErr != nil {
  1006  			return nil, sigErr
  1007  		}
  1008  	}
  1009  
  1010  	// parse the assertion we just validated
  1011  	var assertion Assertion
  1012  	if err := unmarshalElement(assertionEl, &assertion); err != nil {
  1013  		return nil, err
  1014  	}
  1015  
  1016  	if err := sp.validateAssertion(&assertion, possibleRequestIDs, now); err != nil {
  1017  		return nil, err
  1018  	}
  1019  
  1020  	return &assertion, nil
  1021  }
  1022  
  1023  // validateAssertion checks that the conditions specified in assertion match
  1024  // the requirements to accept. If validation fails, it returns an error describing
  1025  // the failure. (The digital signature on the assertion is not checked -- this
  1026  // should be done before calling this function).
  1027  func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleRequestIDs []string, now time.Time) error {
  1028  	if assertion.IssueInstant.Add(MaxIssueDelay).Before(now) {
  1029  		return fmt.Errorf("expired on %s", assertion.IssueInstant.Add(MaxIssueDelay))
  1030  	}
  1031  	if assertion.Issuer.Value != sp.IDPMetadata.EntityID {
  1032  		return fmt.Errorf("issuer is not %q", sp.IDPMetadata.EntityID)
  1033  	}
  1034  	for _, subjectConfirmation := range assertion.Subject.SubjectConfirmations {
  1035  		requestIDvalid := false
  1036  
  1037  		// We *DO NOT* validate InResponseTo when AllowIDPInitiated is set. Here's why:
  1038  		//
  1039  		// The SAML specification does not provide clear guidance for handling InResponseTo for IDP-initiated
  1040  		// requests where there is no request to be in response to. The specification says:
  1041  		//
  1042  		//   InResponseTo [Optional]
  1043  		//       The ID of a SAML protocol message in response to which an attesting entity can present the
  1044  		//       assertion. For example, this attribute might be used to correlate the assertion to a SAML
  1045  		//       request that resulted in its presentation.
  1046  		//
  1047  		// The initial thought was that we should specify a single empty string in possibleRequestIDs for IDP-initiated
  1048  		// requests so that we would ensure that an InResponseTo was *not* provided in those cases where it wasn't
  1049  		// expected. Even that turns out to be frustrating for users. And in practice some IDPs (e.g. Rippling)
  1050  		// set a specific non-empty value for InResponseTo in IDP-initiated requests.
  1051  		//
  1052  		// Finally, it is unclear that there is significant security value in checking InResponseTo when we allow
  1053  		// IDP initiated assertions.
  1054  		if !sp.AllowIDPInitiated {
  1055  			for _, possibleRequestID := range possibleRequestIDs {
  1056  				if subjectConfirmation.SubjectConfirmationData.InResponseTo == possibleRequestID {
  1057  					requestIDvalid = true
  1058  					break
  1059  				}
  1060  			}
  1061  			if !requestIDvalid {
  1062  				return fmt.Errorf("assertion SubjectConfirmation one of the possible request IDs (%v)", possibleRequestIDs)
  1063  			}
  1064  		}
  1065  		if subjectConfirmation.SubjectConfirmationData.Recipient != sp.AcsURL.String() {
  1066  			return fmt.Errorf("assertion SubjectConfirmation Recipient is not %s", sp.AcsURL.String())
  1067  		}
  1068  		if subjectConfirmation.SubjectConfirmationData.NotOnOrAfter.Add(MaxClockSkew).Before(now) {
  1069  			return fmt.Errorf("assertion SubjectConfirmationData is expired")
  1070  		}
  1071  	}
  1072  	if assertion.Conditions.NotBefore.Add(-MaxClockSkew).After(now) {
  1073  		return fmt.Errorf("assertion Conditions is not yet valid")
  1074  	}
  1075  	if assertion.Conditions.NotOnOrAfter.Add(MaxClockSkew).Before(now) {
  1076  		return fmt.Errorf("assertion Conditions is expired")
  1077  	}
  1078  
  1079  	audienceRestrictionsValid := len(assertion.Conditions.AudienceRestrictions) == 0
  1080  	audience := firstSet(sp.EntityID, sp.MetadataURL.String())
  1081  	for _, audienceRestriction := range assertion.Conditions.AudienceRestrictions {
  1082  		if audienceRestriction.Audience.Value == audience {
  1083  			audienceRestrictionsValid = true
  1084  		}
  1085  	}
  1086  	if !audienceRestrictionsValid {
  1087  		return fmt.Errorf("assertion Conditions AudienceRestriction does not contain %q", audience)
  1088  	}
  1089  	return nil
  1090  }
  1091  
// errSignatureElementNotPresent is the sentinel returned by validateSignature
// when the element has no Signature child at all. Callers compare against it
// directly to tell "unsigned" apart from "signed but not valid".
var errSignatureElementNotPresent = errors.New("signature element not present")
  1093  
  1094  // validateSignature returns nil iff the Signature embedded in the element is valid
  1095  func (sp *ServiceProvider) validateSignature(el *etree.Element) error {
  1096  	sigEl, err := findChild(el, "http://www.w3.org/2000/09/xmldsig#", "Signature")
  1097  	if err != nil {
  1098  		return err
  1099  	}
  1100  	if sigEl == nil {
  1101  		return errSignatureElementNotPresent
  1102  	}
  1103  
  1104  	certs, err := sp.getIDPSigningCerts()
  1105  	if err != nil {
  1106  		return fmt.Errorf("cannot validate signature on %s: %v", el.Tag, err)
  1107  	}
  1108  
  1109  	certificateStore := dsig.MemoryX509CertificateStore{
  1110  		Roots: certs,
  1111  	}
  1112  
  1113  	validationContext := dsig.NewDefaultValidationContext(&certificateStore)
  1114  	validationContext.IdAttribute = "ID"
  1115  	if Clock != nil {
  1116  		validationContext.Clock = Clock
  1117  	}
  1118  
  1119  	// Some SAML responses contain a RSAKeyValue element. One of two things is happening here:
  1120  	//
  1121  	// (1) We're getting something signed by a key we already know about -- the public key
  1122  	//     of the signing cert provided in the metadata.
  1123  	// (2) We're getting something signed by a key we *don't* know about, and which we have
  1124  	//     no ability to verify.
  1125  	//
  1126  	// The best course of action is to just remove the KeyInfo so that dsig falls back to
  1127  	// verifying against the public key provided in the metadata.
  1128  	if el.FindElement("./Signature/KeyInfo/X509Data/X509Certificate") == nil {
  1129  		if sigEl := el.FindElement("./Signature"); sigEl != nil {
  1130  			if keyInfo := sigEl.FindElement("KeyInfo"); keyInfo != nil {
  1131  				sigEl.RemoveChild(keyInfo)
  1132  			}
  1133  		}
  1134  	}
  1135  
  1136  	ctx, err := etreeutils.NSBuildParentContext(el)
  1137  	if err != nil {
  1138  		return fmt.Errorf("cannot validate signature on %s: %v", el.Tag, err)
  1139  	}
  1140  	ctx, err = ctx.SubContext(el)
  1141  	if err != nil {
  1142  		return fmt.Errorf("cannot validate signature on %s: %v", el.Tag, err)
  1143  	}
  1144  	el, err = etreeutils.NSDetatch(ctx, el)
  1145  	if err != nil {
  1146  		return fmt.Errorf("cannot validate signature on %s: %v", el.Tag, err)
  1147  	}
  1148  
  1149  	if sp.SignatureVerifier != nil {
  1150  		return sp.SignatureVerifier.VerifySignature(validationContext, el)
  1151  	}
  1152  
  1153  	if _, err := validationContext.Validate(el); err != nil {
  1154  		return fmt.Errorf("cannot validate signature on %s: %v", el.Tag, err)
  1155  	}
  1156  
  1157  	return nil
  1158  }
  1159  
  1160  // SignLogoutRequest adds the `Signature` element to the `LogoutRequest`.
  1161  func (sp *ServiceProvider) SignLogoutRequest(req *LogoutRequest) error {
  1162  	keyPair := tls.Certificate{
  1163  		Certificate: [][]byte{sp.Certificate.Raw},
  1164  		PrivateKey:  sp.Key,
  1165  		Leaf:        sp.Certificate,
  1166  	}
  1167  	// TODO: add intermediates for SP
  1168  	// for _, cert := range sp.Intermediates {
  1169  	// 	keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
  1170  	// }
  1171  	keyStore := dsig.TLSCertKeyStore(keyPair)
  1172  
  1173  	if sp.SignatureMethod != dsig.RSASHA1SignatureMethod &&
  1174  		sp.SignatureMethod != dsig.RSASHA256SignatureMethod &&
  1175  		sp.SignatureMethod != dsig.RSASHA512SignatureMethod {
  1176  		return fmt.Errorf("invalid signing method %s", sp.SignatureMethod)
  1177  	}
  1178  	signatureMethod := sp.SignatureMethod
  1179  	signingContext := dsig.NewDefaultSigningContext(keyStore)
  1180  	signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
  1181  	if err := signingContext.SetSignatureMethod(signatureMethod); err != nil {
  1182  		return err
  1183  	}
  1184  
  1185  	assertionEl := req.Element()
  1186  
  1187  	signedRequestEl, err := signingContext.SignEnveloped(assertionEl)
  1188  	if err != nil {
  1189  		return err
  1190  	}
  1191  
  1192  	sigEl := signedRequestEl.Child[len(signedRequestEl.Child)-1]
  1193  	req.Signature = sigEl.(*etree.Element)
  1194  	return nil
  1195  }
  1196  
  1197  // MakeLogoutRequest produces a new LogoutRequest object for idpURL.
  1198  func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequest, error) {
  1199  
  1200  	req := LogoutRequest{
  1201  		ID:           fmt.Sprintf("id-%x", randomBytes(20)),
  1202  		IssueInstant: TimeNow(),
  1203  		Version:      "2.0",
  1204  		Destination:  idpURL,
  1205  		Issuer: &Issuer{
  1206  			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
  1207  			Value:  firstSet(sp.EntityID, sp.MetadataURL.String()),
  1208  		},
  1209  		NameID: &NameID{
  1210  			Format:          sp.nameIDFormat(),
  1211  			Value:           nameID,
  1212  			NameQualifier:   sp.IDPMetadata.EntityID,
  1213  			SPNameQualifier: sp.Metadata().EntityID,
  1214  		},
  1215  	}
  1216  	if len(sp.SignatureMethod) > 0 {
  1217  		if err := sp.SignLogoutRequest(&req); err != nil {
  1218  			return nil, err
  1219  		}
  1220  	}
  1221  	return &req, nil
  1222  }
  1223  
  1224  // MakeRedirectLogoutRequest creates a SAML authentication request using
  1225  // the HTTP-Redirect binding. It returns a URL that we will redirect the user to
  1226  // in order to start the auth process.
  1227  func (sp *ServiceProvider) MakeRedirectLogoutRequest(nameID, relayState string) (*url.URL, error) {
  1228  	req, err := sp.MakeLogoutRequest(sp.GetSLOBindingLocation(HTTPRedirectBinding), nameID)
  1229  	if err != nil {
  1230  		return nil, err
  1231  	}
  1232  	return req.Redirect(relayState), nil
  1233  }
  1234  
  1235  // Redirect returns a URL suitable for using the redirect binding with the request
  1236  func (r *LogoutRequest) Redirect(relayState string) *url.URL {
  1237  	w := &bytes.Buffer{}
  1238  	w1 := base64.NewEncoder(base64.StdEncoding, w)
  1239  	w2, _ := flate.NewWriter(w1, 9)
  1240  	doc := etree.NewDocument()
  1241  	doc.SetRoot(r.Element())
  1242  	if _, err := doc.WriteTo(w2); err != nil {
  1243  		panic(err)
  1244  	}
  1245  	if err := w2.Close(); err != nil {
  1246  		panic(err)
  1247  	}
  1248  	if err := w1.Close(); err != nil {
  1249  		panic(err)
  1250  	}
  1251  
  1252  	rv, _ := url.Parse(r.Destination)
  1253  
  1254  	query := rv.Query()
  1255  	query.Set("SAMLRequest", w.String())
  1256  	if relayState != "" {
  1257  		query.Set("RelayState", relayState)
  1258  	}
  1259  	rv.RawQuery = query.Encode()
  1260  
  1261  	return rv
  1262  }
  1263  
  1264  // MakePostLogoutRequest creates a SAML authentication request using
  1265  // the HTTP-POST binding. It returns HTML text representing an HTML form that
  1266  // can be sent presented to a browser to initiate the logout process.
  1267  func (sp *ServiceProvider) MakePostLogoutRequest(nameID, relayState string) ([]byte, error) {
  1268  	req, err := sp.MakeLogoutRequest(sp.GetSLOBindingLocation(HTTPPostBinding), nameID)
  1269  	if err != nil {
  1270  		return nil, err
  1271  	}
  1272  	return req.Post(relayState), nil
  1273  }
  1274  
  1275  // Post returns an HTML form suitable for using the HTTP-POST binding with the request
  1276  func (r *LogoutRequest) Post(relayState string) []byte {
  1277  	doc := etree.NewDocument()
  1278  	doc.SetRoot(r.Element())
  1279  	reqBuf, err := doc.WriteToBytes()
  1280  	if err != nil {
  1281  		panic(err)
  1282  	}
  1283  	encodedReqBuf := base64.StdEncoding.EncodeToString(reqBuf)
  1284  
  1285  	tmpl := template.Must(template.New("saml-post-form").Parse(`` +
  1286  		`<form method="post" action="{{.URL}}" id="SAMLRequestForm">` +
  1287  		`<input type="hidden" name="SAMLRequest" value="{{.SAMLRequest}}" />` +
  1288  		`<input type="hidden" name="RelayState" value="{{.RelayState}}" />` +
  1289  		`<input id="SAMLSubmitButton" type="submit" value="Submit" />` +
  1290  		`</form>` +
  1291  		`<script>document.getElementById('SAMLSubmitButton').style.visibility="hidden";` +
  1292  		`document.getElementById('SAMLRequestForm').submit();</script>`))
  1293  	data := struct {
  1294  		URL         string
  1295  		SAMLRequest string
  1296  		RelayState  string
  1297  	}{
  1298  		URL:         r.Destination,
  1299  		SAMLRequest: encodedReqBuf,
  1300  		RelayState:  relayState,
  1301  	}
  1302  
  1303  	rv := bytes.Buffer{}
  1304  	if err := tmpl.Execute(&rv, data); err != nil {
  1305  		panic(err)
  1306  	}
  1307  
  1308  	return rv.Bytes()
  1309  }
  1310  
  1311  // MakeLogoutResponse produces a new LogoutResponse object for idpURL and logoutRequestID.
  1312  func (sp *ServiceProvider) MakeLogoutResponse(idpURL, logoutRequestID string) (*LogoutResponse, error) {
  1313  	response := LogoutResponse{
  1314  		ID:           fmt.Sprintf("id-%x", randomBytes(20)),
  1315  		InResponseTo: logoutRequestID,
  1316  		Version:      "2.0",
  1317  		IssueInstant: TimeNow(),
  1318  		Destination:  idpURL,
  1319  		Issuer: &Issuer{
  1320  			Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
  1321  			Value:  firstSet(sp.EntityID, sp.MetadataURL.String()),
  1322  		},
  1323  		Status: Status{
  1324  			StatusCode: StatusCode{
  1325  				Value: StatusSuccess,
  1326  			},
  1327  		},
  1328  	}
  1329  
  1330  	if len(sp.SignatureMethod) > 0 {
  1331  		if err := sp.SignLogoutResponse(&response); err != nil {
  1332  			return nil, err
  1333  		}
  1334  	}
  1335  	return &response, nil
  1336  }
  1337  
  1338  // MakeRedirectLogoutResponse creates a SAML LogoutResponse using
  1339  // the HTTP-Redirect binding. It returns a URL that we will redirect the user to
  1340  // for LogoutResponse.
  1341  func (sp *ServiceProvider) MakeRedirectLogoutResponse(logoutRequestID, relayState string) (*url.URL, error) {
  1342  	resp, err := sp.MakeLogoutResponse(sp.GetSLOBindingLocation(HTTPRedirectBinding), logoutRequestID)
  1343  	if err != nil {
  1344  		return nil, err
  1345  	}
  1346  	return resp.Redirect(relayState), nil
  1347  }
  1348  
  1349  // Redirect returns a URL suitable for using the redirect binding with the LogoutResponse.
  1350  func (r *LogoutResponse) Redirect(relayState string) *url.URL {
  1351  	w := &bytes.Buffer{}
  1352  	w1 := base64.NewEncoder(base64.StdEncoding, w)
  1353  	w2, _ := flate.NewWriter(w1, 9)
  1354  	doc := etree.NewDocument()
  1355  	doc.SetRoot(r.Element())
  1356  	if _, err := doc.WriteTo(w2); err != nil {
  1357  		panic(err)
  1358  	}
  1359  	if err := w2.Close(); err != nil {
  1360  		panic(err)
  1361  	}
  1362  	if err := w1.Close(); err != nil {
  1363  		panic(err)
  1364  	}
  1365  
  1366  	rv, _ := url.Parse(r.Destination)
  1367  
  1368  	query := rv.Query()
  1369  	query.Set("SAMLResponse", w.String())
  1370  	if relayState != "" {
  1371  		query.Set("RelayState", relayState)
  1372  	}
  1373  	rv.RawQuery = query.Encode()
  1374  
  1375  	return rv
  1376  }
  1377  
  1378  // MakePostLogoutResponse creates a SAML LogoutResponse using
  1379  // the HTTP-POST binding. It returns HTML text representing an HTML form that
  1380  // can be sent presented to a browser for LogoutResponse.
  1381  func (sp *ServiceProvider) MakePostLogoutResponse(logoutRequestID, relayState string) ([]byte, error) {
  1382  	resp, err := sp.MakeLogoutResponse(sp.GetSLOBindingLocation(HTTPPostBinding), logoutRequestID)
  1383  	if err != nil {
  1384  		return nil, err
  1385  	}
  1386  	return resp.Post(relayState), nil
  1387  }
  1388  
  1389  // Post returns an HTML form suitable for using the HTTP-POST binding with the LogoutResponse.
  1390  func (r *LogoutResponse) Post(relayState string) []byte {
  1391  	doc := etree.NewDocument()
  1392  	doc.SetRoot(r.Element())
  1393  	reqBuf, err := doc.WriteToBytes()
  1394  	if err != nil {
  1395  		panic(err)
  1396  	}
  1397  	encodedReqBuf := base64.StdEncoding.EncodeToString(reqBuf)
  1398  
  1399  	tmpl := template.Must(template.New("saml-post-form").Parse(`` +
  1400  		`<form method="post" action="{{.URL}}" id="SAMLResponseForm">` +
  1401  		`<input type="hidden" name="SAMLResponse" value="{{.SAMLResponse}}" />` +
  1402  		`<input type="hidden" name="RelayState" value="{{.RelayState}}" />` +
  1403  		`<input id="SAMLSubmitButton" type="submit" value="Submit" />` +
  1404  		`</form>` +
  1405  		`<script>document.getElementById('SAMLSubmitButton').style.visibility="hidden";` +
  1406  		`document.getElementById('SAMLResponseForm').submit();</script>`))
  1407  	data := struct {
  1408  		URL          string
  1409  		SAMLResponse string
  1410  		RelayState   string
  1411  	}{
  1412  		URL:          r.Destination,
  1413  		SAMLResponse: encodedReqBuf,
  1414  		RelayState:   relayState,
  1415  	}
  1416  
  1417  	rv := bytes.Buffer{}
  1418  	if err := tmpl.Execute(&rv, data); err != nil {
  1419  		panic(err)
  1420  	}
  1421  
  1422  	return rv.Bytes()
  1423  }
  1424  
  1425  // SignLogoutResponse adds the `Signature` element to the `LogoutResponse`.
  1426  func (sp *ServiceProvider) SignLogoutResponse(resp *LogoutResponse) error {
  1427  	keyPair := tls.Certificate{
  1428  		Certificate: [][]byte{sp.Certificate.Raw},
  1429  		PrivateKey:  sp.Key,
  1430  		Leaf:        sp.Certificate,
  1431  	}
  1432  	// TODO: add intermediates for SP
  1433  	// for _, cert := range sp.Intermediates {
  1434  	// 	keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
  1435  	// }
  1436  	keyStore := dsig.TLSCertKeyStore(keyPair)
  1437  
  1438  	if sp.SignatureMethod != dsig.RSASHA1SignatureMethod &&
  1439  		sp.SignatureMethod != dsig.RSASHA256SignatureMethod &&
  1440  		sp.SignatureMethod != dsig.RSASHA512SignatureMethod {
  1441  		return fmt.Errorf("invalid signing method %s", sp.SignatureMethod)
  1442  	}
  1443  	signatureMethod := sp.SignatureMethod
  1444  	signingContext := dsig.NewDefaultSigningContext(keyStore)
  1445  	signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
  1446  	if err := signingContext.SetSignatureMethod(signatureMethod); err != nil {
  1447  		return err
  1448  	}
  1449  
  1450  	assertionEl := resp.Element()
  1451  
  1452  	signedRequestEl, err := signingContext.SignEnveloped(assertionEl)
  1453  	if err != nil {
  1454  		return err
  1455  	}
  1456  
  1457  	sigEl := signedRequestEl.Child[len(signedRequestEl.Child)-1]
  1458  	resp.Signature = sigEl.(*etree.Element)
  1459  	return nil
  1460  }
  1461  
  1462  func (sp *ServiceProvider) nameIDFormat() string {
  1463  	var nameIDFormat string
  1464  	switch sp.AuthnNameIDFormat {
  1465  	case "":
  1466  		// To maintain library back-compat, use "transient" if unset.
  1467  		nameIDFormat = string(TransientNameIDFormat)
  1468  	case UnspecifiedNameIDFormat:
  1469  		// Spec defines an empty value as "unspecified" so don't set one.
  1470  	default:
  1471  		nameIDFormat = string(sp.AuthnNameIDFormat)
  1472  	}
  1473  	return nameIDFormat
  1474  }
  1475  
  1476  // ValidateLogoutResponseRequest validates the LogoutResponse content from the request
  1477  func (sp *ServiceProvider) ValidateLogoutResponseRequest(req *http.Request) error {
  1478  	if data := req.URL.Query().Get("SAMLResponse"); data != "" {
  1479  		return sp.ValidateLogoutResponseRedirect(data)
  1480  	}
  1481  
  1482  	err := req.ParseForm()
  1483  	if err != nil {
  1484  		return fmt.Errorf("unable to parse form: %v", err)
  1485  	}
  1486  
  1487  	return sp.ValidateLogoutResponseForm(req.PostForm.Get("SAMLResponse"))
  1488  }
  1489  
  1490  // ValidateLogoutResponseForm returns a nil error if the logout response is valid.
  1491  func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error {
  1492  	retErr := &InvalidResponseError{
  1493  		Now: TimeNow(),
  1494  	}
  1495  
  1496  	rawResponseBuf, err := base64.StdEncoding.DecodeString(postFormData)
  1497  	if err != nil {
  1498  		retErr.PrivateErr = fmt.Errorf("unable to parse base64: %s", err)
  1499  		return retErr
  1500  	}
  1501  	retErr.Response = string(rawResponseBuf)
  1502  
  1503  	// TODO(ross): add test case for this (SLO does not have tests right now)
  1504  	if err := xrv.Validate(bytes.NewReader(rawResponseBuf)); err != nil {
  1505  		return fmt.Errorf("response contains invalid XML: %s", err)
  1506  	}
  1507  
  1508  	doc := etree.NewDocument()
  1509  	if err := doc.ReadFromBytes(rawResponseBuf); err != nil {
  1510  		retErr.PrivateErr = err
  1511  		return retErr
  1512  	}
  1513  
  1514  	if err := sp.validateSignature(doc.Root()); err != nil {
  1515  		retErr.PrivateErr = err
  1516  		return retErr
  1517  	}
  1518  
  1519  	var resp LogoutResponse
  1520  	if err := unmarshalElement(doc.Root(), &resp); err != nil {
  1521  		retErr.PrivateErr = err
  1522  		return retErr
  1523  	}
  1524  	return sp.validateLogoutResponse(&resp)
  1525  }
  1526  
  1527  // ValidateLogoutResponseRedirect returns a nil error if the logout response is valid.
  1528  //
  1529  // URL Binding appears to be gzip / flate encoded
  1530  // See https://www.oasis-open.org/committees/download.php/20645/sstc-saml-tech-overview-2%200-draft-10.pdf  6.6
  1531  func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData string) error {
  1532  	retErr := &InvalidResponseError{
  1533  		Now: TimeNow(),
  1534  	}
  1535  
  1536  	rawResponseBuf, err := base64.StdEncoding.DecodeString(queryParameterData)
  1537  	if err != nil {
  1538  		retErr.PrivateErr = fmt.Errorf("unable to parse base64: %s", err)
  1539  		return retErr
  1540  	}
  1541  	retErr.Response = string(rawResponseBuf)
  1542  
  1543  	gr, err := io.ReadAll(newSaferFlateReader(bytes.NewBuffer(rawResponseBuf)))
  1544  	if err != nil {
  1545  		retErr.PrivateErr = err
  1546  		return retErr
  1547  	}
  1548  
  1549  	if err := xrv.Validate(bytes.NewReader(gr)); err != nil {
  1550  		return err
  1551  	}
  1552  
  1553  	doc := etree.NewDocument()
  1554  	if err := doc.ReadFromBytes(rawResponseBuf); err != nil {
  1555  		retErr.PrivateErr = err
  1556  		return retErr
  1557  	}
  1558  
  1559  	if err := sp.validateSignature(doc.Root()); err != nil {
  1560  		retErr.PrivateErr = err
  1561  		return retErr
  1562  	}
  1563  
  1564  	var resp LogoutResponse
  1565  	if err := unmarshalElement(doc.Root(), &resp); err != nil {
  1566  		retErr.PrivateErr = err
  1567  		return retErr
  1568  	}
  1569  	return sp.validateLogoutResponse(&resp)
  1570  }
  1571  
  1572  // validateLogoutResponse validates the LogoutResponse fields. Returns a nil error if the LogoutResponse is valid.
  1573  func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error {
  1574  	if resp.Destination != sp.SloURL.String() {
  1575  		return fmt.Errorf("`Destination` does not match SloURL (expected %q)", sp.SloURL.String())
  1576  	}
  1577  
  1578  	now := time.Now()
  1579  	if resp.IssueInstant.Add(MaxIssueDelay).Before(now) {
  1580  		return fmt.Errorf("issueInstant expired at %s", resp.IssueInstant.Add(MaxIssueDelay))
  1581  	}
  1582  	if resp.Issuer.Value != sp.IDPMetadata.EntityID {
  1583  		return fmt.Errorf("issuer does not match the IDP metadata (expected %q)", sp.IDPMetadata.EntityID)
  1584  	}
  1585  	if resp.Status.StatusCode.Value != StatusSuccess {
  1586  		return fmt.Errorf("status code was not %s", StatusSuccess)
  1587  	}
  1588  
  1589  	return nil
  1590  }
  1591  
  1592  func firstSet(a, b string) string {
  1593  	if a == "" {
  1594  		return b
  1595  	}
  1596  	return a
  1597  }
  1598  
  1599  // findChildren returns all the elements matching childNS/childTag that are direct children of parentEl.
  1600  func findChildren(parentEl *etree.Element, childNS string, childTag string) ([]*etree.Element, error) {
  1601  	//nolint:prealloc // We don't know how many child elements we'll actually put into this array.
  1602  	var rv []*etree.Element
  1603  	for _, childEl := range parentEl.ChildElements() {
  1604  		if childEl.Tag != childTag {
  1605  			continue
  1606  		}
  1607  
  1608  		ctx, err := etreeutils.NSBuildParentContext(childEl)
  1609  		if err != nil {
  1610  			return nil, err
  1611  		}
  1612  		ctx, err = ctx.SubContext(childEl)
  1613  		if err != nil {
  1614  			return nil, err
  1615  		}
  1616  
  1617  		ns, err := ctx.LookupPrefix(childEl.Space)
  1618  		if err != nil {
  1619  			return nil, fmt.Errorf("[%s]:%s cannot find prefix %s: %v", childNS, childTag, childEl.Space, err)
  1620  		}
  1621  		if ns != childNS {
  1622  			continue
  1623  		}
  1624  
  1625  		rv = append(rv, childEl)
  1626  	}
  1627  
  1628  	return rv, nil
  1629  }
  1630  
  1631  // findOneChild finds the specified child element. Returns an error if the element doesn't exist.
  1632  func findOneChild(parentEl *etree.Element, childNS string, childTag string) (*etree.Element, error) {
  1633  	children, err := findChildren(parentEl, childNS, childTag)
  1634  	if err != nil {
  1635  		return nil, err
  1636  	}
  1637  	switch len(children) {
  1638  	case 0:
  1639  		return nil, fmt.Errorf("cannot find %s:%s element", childNS, childTag)
  1640  	case 1:
  1641  		return children[0], nil
  1642  	default:
  1643  		return nil, fmt.Errorf("expected exactly one %s:%s element", childNS, childTag)
  1644  	}
  1645  }
  1646  
  1647  // findChild finds the specified child element. Returns (nil, nil) of the element doesn't exist.
  1648  func findChild(parentEl *etree.Element, childNS string, childTag string) (*etree.Element, error) {
  1649  	children, err := findChildren(parentEl, childNS, childTag)
  1650  	if err != nil {
  1651  		return nil, err
  1652  	}
  1653  	switch len(children) {
  1654  	case 0:
  1655  		return nil, nil
  1656  	case 1:
  1657  		return children[0], nil
  1658  	default:
  1659  		return nil, fmt.Errorf("expected at most one %s:%s element", childNS, childTag)
  1660  	}
  1661  }
  1662  
  1663  func elementToBytes(el *etree.Element) ([]byte, error) {
  1664  	namespaces := map[string]string{}
  1665  	for _, childEl := range el.FindElements("//*") {
  1666  		ns := childEl.NamespaceURI()
  1667  		if ns != "" {
  1668  			namespaces[childEl.Space] = ns
  1669  		}
  1670  	}
  1671  
  1672  	doc := etree.NewDocument()
  1673  	doc.SetRoot(el.Copy())
  1674  	for space, uri := range namespaces {
  1675  		doc.Root().CreateAttr("xmlns:"+space, uri)
  1676  	}
  1677  
  1678  	return doc.WriteToBytes()
  1679  }
  1680  
  1681  // unmarshalElement serializes el into v by serializing el and then parsing it with xml.Unmarshal.
  1682  func unmarshalElement(el *etree.Element, v interface{}) error {
  1683  	buf, err := elementToBytes(el)
  1684  	if err != nil {
  1685  		return err
  1686  	}
  1687  	return xml.Unmarshal(buf, v)
  1688  }
  1689  
  1690  func elementToString(el *etree.Element) string {
  1691  	buf, err := elementToBytes(el)
  1692  	if err != nil {
  1693  		return ""
  1694  	}
  1695  	return string(buf)
  1696  }