github.com/crewjam/saml@v0.4.14/identity_provider.go (about) 1 package saml 2 3 import ( 4 "bytes" 5 "crypto" 6 "crypto/tls" 7 "crypto/x509" 8 "encoding/base64" 9 "encoding/xml" 10 "fmt" 11 "io" 12 "net/http" 13 "net/url" 14 "os" 15 "regexp" 16 "strconv" 17 "text/template" 18 "time" 19 20 "github.com/beevik/etree" 21 xrv "github.com/mattermost/xml-roundtrip-validator" 22 dsig "github.com/russellhaering/goxmldsig" 23 24 "github.com/crewjam/saml/logger" 25 "github.com/crewjam/saml/xmlenc" 26 ) 27 28 // Session represents a user session. It is returned by the 29 // SessionProvider implementation's GetSession method. Fields here 30 // are used to set fields in the SAML assertion. 31 type Session struct { 32 ID string 33 CreateTime time.Time 34 ExpireTime time.Time 35 Index string 36 37 NameID string 38 NameIDFormat string 39 SubjectID string 40 41 Groups []string 42 UserName string 43 UserEmail string 44 UserCommonName string 45 UserSurname string 46 UserGivenName string 47 UserScopedAffiliation string 48 49 CustomAttributes []Attribute 50 } 51 52 // SessionProvider is an interface used by IdentityProvider to determine the 53 // Session associated with a request. For an example implementation, see 54 // GetSession in the samlidp package. 55 type SessionProvider interface { 56 // GetSession returns the remote user session associated with the http.Request. 57 // 58 // If (and only if) the request is not associated with a session then GetSession 59 // must complete the HTTP request and return nil. 60 GetSession(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session 61 } 62 63 // ServiceProviderProvider is an interface used by IdentityProvider to look up 64 // service provider metadata for a request. 65 type ServiceProviderProvider interface { 66 // GetServiceProvider returns the Service Provider metadata for the 67 // service provider ID, which is typically the service provider's 68 // metadata URL. If an appropriate service provider cannot be found then 69 // the returned error must be os.ErrNotExist. 70 GetServiceProvider(r *http.Request, serviceProviderID string) (*EntityDescriptor, error) 71 } 72 73 // AssertionMaker is an interface used by IdentityProvider to construct the 74 // assertion for a request. The default implementation is DefaultAssertionMaker, 75 // which is used if not AssertionMaker is specified. 76 type AssertionMaker interface { 77 // MakeAssertion constructs an assertion from session and the request and 78 // assigns it to req.Assertion. 79 MakeAssertion(req *IdpAuthnRequest, session *Session) error 80 } 81 82 // IdentityProvider implements the SAML Identity Provider role (IDP). 83 // 84 // An identity provider receives SAML assertion requests and responds 85 // with SAML Assertions. 86 // 87 // You must provide a keypair that is used to 88 // sign assertions. 89 // 90 // You must provide an implementation of ServiceProviderProvider which 91 // returns 92 // 93 // You must provide an implementation of the SessionProvider which 94 // handles the actual authentication (i.e. prompting for a username 95 // and password). 96 type IdentityProvider struct { 97 Key crypto.PrivateKey 98 Signer crypto.Signer 99 Logger logger.Interface 100 Certificate *x509.Certificate 101 Intermediates []*x509.Certificate 102 MetadataURL url.URL 103 SSOURL url.URL 104 LogoutURL url.URL 105 ServiceProviderProvider ServiceProviderProvider 106 SessionProvider SessionProvider 107 AssertionMaker AssertionMaker 108 SignatureMethod string 109 ValidDuration *time.Duration 110 } 111 112 // Metadata returns the metadata structure for this identity provider. 113 func (idp *IdentityProvider) Metadata() *EntityDescriptor { 114 certStr := base64.StdEncoding.EncodeToString(idp.Certificate.Raw) 115 116 var validDuration time.Duration 117 if idp.ValidDuration != nil { 118 validDuration = *idp.ValidDuration 119 } else { 120 validDuration = DefaultValidDuration 121 } 122 123 ed := &EntityDescriptor{ 124 EntityID: idp.MetadataURL.String(), 125 ValidUntil: TimeNow().Add(validDuration), 126 CacheDuration: validDuration, 127 IDPSSODescriptors: []IDPSSODescriptor{ 128 { 129 SSODescriptor: SSODescriptor{ 130 RoleDescriptor: RoleDescriptor{ 131 ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol", 132 KeyDescriptors: []KeyDescriptor{ 133 { 134 Use: "signing", 135 KeyInfo: KeyInfo{ 136 X509Data: X509Data{ 137 X509Certificates: []X509Certificate{ 138 {Data: certStr}, 139 }, 140 }, 141 }, 142 }, 143 { 144 Use: "encryption", 145 KeyInfo: KeyInfo{ 146 X509Data: X509Data{ 147 X509Certificates: []X509Certificate{ 148 {Data: certStr}, 149 }, 150 }, 151 }, 152 EncryptionMethods: []EncryptionMethod{ 153 {Algorithm: "http://www.w3.org/2001/04/xmlenc#aes128-cbc"}, 154 {Algorithm: "http://www.w3.org/2001/04/xmlenc#aes192-cbc"}, 155 {Algorithm: "http://www.w3.org/2001/04/xmlenc#aes256-cbc"}, 156 {Algorithm: "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p"}, 157 }, 158 }, 159 }, 160 }, 161 NameIDFormats: []NameIDFormat{NameIDFormat("urn:oasis:names:tc:SAML:2.0:nameid-format:transient")}, 162 }, 163 SingleSignOnServices: []Endpoint{ 164 { 165 Binding: HTTPRedirectBinding, 166 Location: idp.SSOURL.String(), 167 }, 168 { 169 Binding: HTTPPostBinding, 170 Location: idp.SSOURL.String(), 171 }, 172 }, 173 }, 174 }, 175 } 176 177 if idp.LogoutURL.String() != "" { 178 ed.IDPSSODescriptors[0].SSODescriptor.SingleLogoutServices = []Endpoint{ 179 { 180 Binding: HTTPRedirectBinding, 181 Location: idp.LogoutURL.String(), 182 }, 183 } 184 } 185 186 return ed 187 } 188 189 // Handler returns an http.Handler that serves the metadata and SSO 190 // URLs 191 func (idp *IdentityProvider) Handler() http.Handler { 192 mux := http.NewServeMux() 193 mux.HandleFunc(idp.MetadataURL.Path, idp.ServeMetadata) 194 mux.HandleFunc(idp.SSOURL.Path, idp.ServeSSO) 195 return mux 196 } 197 198 // ServeMetadata is an http.HandlerFunc that serves the IDP metadata 199 func (idp *IdentityProvider) ServeMetadata(w http.ResponseWriter, _ *http.Request) { 200 buf, _ := xml.MarshalIndent(idp.Metadata(), "", " ") 201 w.Header().Set("Content-Type", "application/samlmetadata+xml") 202 if _, err := w.Write(buf); err != nil { 203 idp.Logger.Printf("ERROR: %s", err) 204 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 205 } 206 } 207 208 // ServeSSO handles SAML auth requests. 209 // 210 // When it gets a request for a user that does not have a valid session, 211 // then it prompts the user via XXX. 212 // 213 // If the session already exists, then it produces a SAML assertion and 214 // returns an HTTP response according to the specified binding. The 215 // only supported binding right now is the HTTP-POST binding which returns 216 // an HTML form in the appropriate format with Javascript to automatically 217 // submit that form the to service provider's Assertion Customer Service 218 // endpoint. 219 // 220 // If the SAML request is invalid or cannot be verified a simple StatusBadRequest 221 // response is sent. 222 // 223 // If the assertion cannot be created or returned, a StatusInternalServerError 224 // response is sent. 225 func (idp *IdentityProvider) ServeSSO(w http.ResponseWriter, r *http.Request) { 226 req, err := NewIdpAuthnRequest(idp, r) 227 if err != nil { 228 idp.Logger.Printf("failed to parse request: %s", err) 229 http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) 230 return 231 } 232 233 if err := req.Validate(); err != nil { 234 idp.Logger.Printf("failed to validate request: %s", err) 235 http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) 236 return 237 } 238 239 // TODO(ross): we must check that the request ID has not been previously 240 // issued. 241 242 session := idp.SessionProvider.GetSession(w, r, req) 243 if session == nil { 244 return 245 } 246 247 assertionMaker := idp.AssertionMaker 248 if assertionMaker == nil { 249 assertionMaker = DefaultAssertionMaker{} 250 } 251 if err := assertionMaker.MakeAssertion(req, session); err != nil { 252 idp.Logger.Printf("failed to make assertion: %s", err) 253 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 254 return 255 } 256 if err := req.WriteResponse(w); err != nil { 257 idp.Logger.Printf("failed to write response: %s", err) 258 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 259 return 260 } 261 } 262 263 // ServeIDPInitiated handes an IDP-initiated authorization request. Requests of this 264 // type require us to know a registered service provider and (optionally) the RelayState 265 // that will be passed to the application. 266 func (idp *IdentityProvider) ServeIDPInitiated(w http.ResponseWriter, r *http.Request, serviceProviderID string, relayState string) { 267 req := &IdpAuthnRequest{ 268 IDP: idp, 269 HTTPRequest: r, 270 RelayState: relayState, 271 Now: TimeNow(), 272 } 273 274 session := idp.SessionProvider.GetSession(w, r, req) 275 if session == nil { 276 // If GetSession returns nil, it must have written an HTTP response, per the interface 277 // (this is probably because it drew a login form or something) 278 return 279 } 280 281 var err error 282 req.ServiceProviderMetadata, err = idp.ServiceProviderProvider.GetServiceProvider(r, serviceProviderID) 283 if err == os.ErrNotExist { 284 idp.Logger.Printf("cannot find service provider: %s", serviceProviderID) 285 http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) 286 return 287 } else if err != nil { 288 idp.Logger.Printf("cannot find service provider %s: %v", serviceProviderID, err) 289 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 290 return 291 } 292 293 // find an ACS endpoint that we can use 294 for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors { 295 for _, endpoint := range spssoDescriptor.AssertionConsumerServices { 296 if endpoint.Binding == HTTPPostBinding { 297 // explicitly copy loop iterator variables 298 // 299 // c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable 300 // 301 // (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately, 302 // but it certainly doesn't hurt anything and may prevent bugs in the future.) 303 endpoint, spssoDescriptor := endpoint, spssoDescriptor 304 305 req.ACSEndpoint = &endpoint 306 req.SPSSODescriptor = &spssoDescriptor 307 break 308 } 309 } 310 if req.ACSEndpoint != nil { 311 break 312 } 313 } 314 if req.ACSEndpoint == nil { 315 idp.Logger.Printf("saml metadata does not contain an Assertion Customer Service url") 316 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 317 return 318 } 319 320 assertionMaker := idp.AssertionMaker 321 if assertionMaker == nil { 322 assertionMaker = DefaultAssertionMaker{} 323 } 324 if err := assertionMaker.MakeAssertion(req, session); err != nil { 325 idp.Logger.Printf("failed to make assertion: %s", err) 326 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 327 return 328 } 329 330 if err := req.WriteResponse(w); err != nil { 331 idp.Logger.Printf("failed to write response: %s", err) 332 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 333 return 334 } 335 } 336 337 // IdpAuthnRequest is used by IdentityProvider to handle a single authentication request. 338 type IdpAuthnRequest struct { 339 IDP *IdentityProvider 340 HTTPRequest *http.Request 341 RelayState string 342 RequestBuffer []byte 343 Request AuthnRequest 344 ServiceProviderMetadata *EntityDescriptor 345 SPSSODescriptor *SPSSODescriptor 346 ACSEndpoint *IndexedEndpoint 347 Assertion *Assertion 348 AssertionEl *etree.Element 349 ResponseEl *etree.Element 350 Now time.Time 351 } 352 353 // NewIdpAuthnRequest returns a new IdpAuthnRequest for the given HTTP request to the authorization 354 // service. 355 func NewIdpAuthnRequest(idp *IdentityProvider, r *http.Request) (*IdpAuthnRequest, error) { 356 req := &IdpAuthnRequest{ 357 IDP: idp, 358 HTTPRequest: r, 359 Now: TimeNow(), 360 } 361 362 switch r.Method { 363 case "GET": 364 compressedRequest, err := base64.StdEncoding.DecodeString(r.URL.Query().Get("SAMLRequest")) 365 if err != nil { 366 return nil, fmt.Errorf("cannot decode request: %s", err) 367 } 368 req.RequestBuffer, err = io.ReadAll(newSaferFlateReader(bytes.NewReader(compressedRequest))) 369 if err != nil { 370 return nil, fmt.Errorf("cannot decompress request: %s", err) 371 } 372 req.RelayState = r.URL.Query().Get("RelayState") 373 case "POST": 374 if err := r.ParseForm(); err != nil { 375 return nil, err 376 } 377 var err error 378 req.RequestBuffer, err = base64.StdEncoding.DecodeString(r.PostForm.Get("SAMLRequest")) 379 if err != nil { 380 return nil, err 381 } 382 req.RelayState = r.PostForm.Get("RelayState") 383 default: 384 return nil, fmt.Errorf("method not allowed") 385 } 386 387 return req, nil 388 } 389 390 // Validate checks that the authentication request is valid and assigns 391 // the AuthnRequest and Metadata properties. Returns a non-nil error if the 392 // request is not valid. 393 func (req *IdpAuthnRequest) Validate() error { 394 if err := xrv.Validate(bytes.NewReader(req.RequestBuffer)); err != nil { 395 return err 396 } 397 398 if err := xml.Unmarshal(req.RequestBuffer, &req.Request); err != nil { 399 return err 400 } 401 402 // We always have exactly one IDP SSO descriptor 403 if len(req.IDP.Metadata().IDPSSODescriptors) != 1 { 404 panic("expected exactly one IDP SSO descriptor in IDP metadata") 405 } 406 idpSsoDescriptor := req.IDP.Metadata().IDPSSODescriptors[0] 407 408 // TODO(ross): support signed authn requests 409 // For now we do the safe thing and fail in the case where we think 410 // requests might be signed. 411 if idpSsoDescriptor.WantAuthnRequestsSigned != nil && *idpSsoDescriptor.WantAuthnRequestsSigned { 412 return fmt.Errorf("authn request signature checking is not currently supported") 413 } 414 415 // In http://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf ยง3.4.5.2 416 // we get a description of the Destination attribute: 417 // 418 // If the message is signed, the Destination XML attribute in the root SAML 419 // element of the protocol message MUST contain the URL to which the sender 420 // has instructed the user agent to deliver the message. The recipient MUST 421 // then verify that the value matches the location at which the message has 422 // been received. 423 // 424 // We require the destination be correct either (a) if signing is enabled or 425 // (b) if it was provided. 426 mustHaveDestination := idpSsoDescriptor.WantAuthnRequestsSigned != nil && *idpSsoDescriptor.WantAuthnRequestsSigned 427 mustHaveDestination = mustHaveDestination || req.Request.Destination != "" 428 if mustHaveDestination { 429 if req.Request.Destination != req.IDP.SSOURL.String() { 430 return fmt.Errorf("expected destination to be %q, not %q", req.IDP.SSOURL.String(), req.Request.Destination) 431 } 432 } 433 434 if req.Request.IssueInstant.Add(MaxIssueDelay).Before(req.Now) { 435 return fmt.Errorf("request expired at %s", 436 req.Request.IssueInstant.Add(MaxIssueDelay)) 437 } 438 if req.Request.Version != "2.0" { 439 return fmt.Errorf("expected SAML request version 2.0 got %v", req.Request.Version) 440 } 441 442 // find the service provider 443 serviceProviderID := req.Request.Issuer.Value 444 serviceProvider, err := req.IDP.ServiceProviderProvider.GetServiceProvider(req.HTTPRequest, serviceProviderID) 445 if err == os.ErrNotExist { 446 return fmt.Errorf("cannot handle request from unknown service provider %s", serviceProviderID) 447 } else if err != nil { 448 return fmt.Errorf("cannot find service provider %s: %v", serviceProviderID, err) 449 } 450 req.ServiceProviderMetadata = serviceProvider 451 452 // Check that the ACS URL matches an ACS endpoint in the SP metadata. 453 if err := req.getACSEndpoint(); err != nil { 454 return fmt.Errorf("cannot find assertion consumer service: %v", err) 455 } 456 457 return nil 458 } 459 460 func (req *IdpAuthnRequest) getACSEndpoint() error { 461 if req.Request.AssertionConsumerServiceIndex != "" { 462 for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors { 463 for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices { 464 if strconv.Itoa(spAssertionConsumerService.Index) == req.Request.AssertionConsumerServiceIndex { 465 // explicitly copy loop iterator variables 466 // 467 // c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable 468 // 469 // (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately, 470 // but it certainly doesn't hurt anything and may prevent bugs in the future.) 471 spssoDescriptor, spAssertionConsumerService := spssoDescriptor, spAssertionConsumerService 472 473 req.SPSSODescriptor = &spssoDescriptor 474 req.ACSEndpoint = &spAssertionConsumerService 475 return nil 476 } 477 } 478 } 479 } 480 481 if req.Request.AssertionConsumerServiceURL != "" { 482 for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors { 483 for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices { 484 if spAssertionConsumerService.Location == req.Request.AssertionConsumerServiceURL { 485 // explicitly copy loop iterator variables 486 // 487 // c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable 488 // 489 // (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately, 490 // but it certainly doesn't hurt anything and may prevent bugs in the future.) 491 spssoDescriptor, spAssertionConsumerService := spssoDescriptor, spAssertionConsumerService 492 493 req.SPSSODescriptor = &spssoDescriptor 494 req.ACSEndpoint = &spAssertionConsumerService 495 return nil 496 } 497 } 498 } 499 } 500 501 // Some service providers, like the Microsoft Azure AD service provider, issue 502 // assertion requests that don't specify an ACS url at all. 503 if req.Request.AssertionConsumerServiceURL == "" && req.Request.AssertionConsumerServiceIndex == "" { 504 // find a default ACS binding in the metadata that we can use 505 for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors { 506 for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices { 507 if spAssertionConsumerService.IsDefault != nil && *spAssertionConsumerService.IsDefault { 508 switch spAssertionConsumerService.Binding { 509 case HTTPPostBinding, HTTPRedirectBinding: 510 // explicitly copy loop iterator variables 511 // 512 // c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable 513 // 514 // (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately, 515 // but it certainly doesn't hurt anything and may prevent bugs in the future.) 516 spssoDescriptor, spAssertionConsumerService := spssoDescriptor, spAssertionConsumerService 517 518 req.SPSSODescriptor = &spssoDescriptor 519 req.ACSEndpoint = &spAssertionConsumerService 520 return nil 521 } 522 } 523 } 524 } 525 526 // if we can't find a default, use *any* ACS binding 527 for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors { 528 for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices { 529 switch spAssertionConsumerService.Binding { 530 case HTTPPostBinding, HTTPRedirectBinding: 531 // explicitly copy loop iterator variables 532 // 533 // c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable 534 // 535 // (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately, 536 // but it certainly doesn't hurt anything and may prevent bugs in the future.) 537 spssoDescriptor, spAssertionConsumerService := spssoDescriptor, spAssertionConsumerService 538 539 req.SPSSODescriptor = &spssoDescriptor 540 req.ACSEndpoint = &spAssertionConsumerService 541 return nil 542 } 543 } 544 } 545 } 546 547 return os.ErrNotExist // no ACS url found or specified 548 } 549 550 // DefaultAssertionMaker produces a SAML assertion for the 551 // given request and assigns it to req.Assertion. 552 type DefaultAssertionMaker struct { 553 } 554 555 // MakeAssertion implements AssertionMaker. It produces a SAML assertion from the 556 // given request and assigns it to req.Assertion. 557 func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Session) error { 558 attributes := []Attribute{} 559 560 var attributeConsumingService *AttributeConsumingService 561 for _, acs := range req.SPSSODescriptor.AttributeConsumingServices { 562 if acs.IsDefault != nil && *acs.IsDefault { 563 // explicitly copy loop iterator variables 564 // 565 // c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable 566 // 567 // (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately, 568 // but it certainly doesn't hurt anything and may prevent bugs in the future.) 569 acs := acs 570 571 attributeConsumingService = &acs 572 break 573 } 574 } 575 if attributeConsumingService == nil { 576 for _, acs := range req.SPSSODescriptor.AttributeConsumingServices { 577 // explicitly copy loop iterator variables 578 // 579 // c.f. https://github.com/golang/go/wiki/CommonMistakes#using-reference-to-loop-iterator-variable 580 // 581 // (note that I'm pretty sure this isn't strictly necessary because we break out of the loop immediately, 582 // but it certainly doesn't hurt anything and may prevent bugs in the future.) 583 acs := acs 584 585 attributeConsumingService = &acs 586 break 587 } 588 } 589 if attributeConsumingService == nil { 590 attributeConsumingService = &AttributeConsumingService{} 591 } 592 593 for _, requestedAttribute := range attributeConsumingService.RequestedAttributes { 594 if requestedAttribute.NameFormat == "urn:oasis:names:tc:SAML:2.0:attrname-format:basic" || requestedAttribute.NameFormat == "urn:oasis:names:tc:SAML:2.0:attrname-format:unspecified" { 595 attrName := requestedAttribute.Name 596 attrName = regexp.MustCompile("[^A-Za-z0-9]+").ReplaceAllString(attrName, "") 597 switch attrName { 598 case "email", "emailaddress": 599 attributes = append(attributes, Attribute{ 600 FriendlyName: requestedAttribute.FriendlyName, 601 Name: requestedAttribute.Name, 602 NameFormat: requestedAttribute.NameFormat, 603 Values: []AttributeValue{{ 604 Type: "xs:string", 605 Value: session.UserEmail, 606 }}, 607 }) 608 case "name", "fullname", "cn", "commonname": 609 attributes = append(attributes, Attribute{ 610 FriendlyName: requestedAttribute.FriendlyName, 611 Name: requestedAttribute.Name, 612 NameFormat: requestedAttribute.NameFormat, 613 Values: []AttributeValue{{ 614 Type: "xs:string", 615 Value: session.UserCommonName, 616 }}, 617 }) 618 case "givenname", "firstname": 619 attributes = append(attributes, Attribute{ 620 FriendlyName: requestedAttribute.FriendlyName, 621 Name: requestedAttribute.Name, 622 NameFormat: requestedAttribute.NameFormat, 623 Values: []AttributeValue{{ 624 Type: "xs:string", 625 Value: session.UserGivenName, 626 }}, 627 }) 628 case "surname", "lastname", "familyname": 629 attributes = append(attributes, Attribute{ 630 FriendlyName: requestedAttribute.FriendlyName, 631 Name: requestedAttribute.Name, 632 NameFormat: requestedAttribute.NameFormat, 633 Values: []AttributeValue{{ 634 Type: "xs:string", 635 Value: session.UserSurname, 636 }}, 637 }) 638 case "uid", "user", "userid": 639 attributes = append(attributes, Attribute{ 640 FriendlyName: requestedAttribute.FriendlyName, 641 Name: requestedAttribute.Name, 642 NameFormat: requestedAttribute.NameFormat, 643 Values: []AttributeValue{{ 644 Type: "xs:string", 645 Value: session.UserName, 646 }}, 647 }) 648 } 649 } 650 } 651 652 if session.UserName != "" { 653 attributes = append(attributes, Attribute{ 654 FriendlyName: "uid", 655 Name: "urn:oid:0.9.2342.19200300.100.1.1", 656 NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", 657 Values: []AttributeValue{{ 658 Type: "xs:string", 659 Value: session.UserName, 660 }}, 661 }) 662 } 663 664 if session.UserEmail != "" { 665 attributes = append(attributes, Attribute{ 666 FriendlyName: "eduPersonPrincipalName", 667 Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.6", 668 NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", 669 Values: []AttributeValue{{ 670 Type: "xs:string", 671 Value: session.UserEmail, 672 }}, 673 }) 674 } 675 if session.UserSurname != "" { 676 attributes = append(attributes, Attribute{ 677 FriendlyName: "sn", 678 Name: "urn:oid:2.5.4.4", 679 NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", 680 Values: []AttributeValue{{ 681 Type: "xs:string", 682 Value: session.UserSurname, 683 }}, 684 }) 685 } 686 if session.UserGivenName != "" { 687 attributes = append(attributes, Attribute{ 688 FriendlyName: "givenName", 689 Name: "urn:oid:2.5.4.42", 690 NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", 691 Values: []AttributeValue{{ 692 Type: "xs:string", 693 Value: session.UserGivenName, 694 }}, 695 }) 696 } 697 698 if session.UserCommonName != "" { 699 attributes = append(attributes, Attribute{ 700 FriendlyName: "cn", 701 Name: "urn:oid:2.5.4.3", 702 NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", 703 Values: []AttributeValue{{ 704 Type: "xs:string", 705 Value: session.UserCommonName, 706 }}, 707 }) 708 } 709 710 if session.UserScopedAffiliation != "" { 711 attributes = append(attributes, Attribute{ 712 FriendlyName: "uid", 713 Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.9", 714 NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", 715 Values: []AttributeValue{{ 716 Type: "xs:string", 717 Value: session.UserScopedAffiliation, 718 }}, 719 }) 720 } 721 722 attributes = append(attributes, session.CustomAttributes...) 723 724 if len(session.Groups) != 0 { 725 groupMemberAttributeValues := []AttributeValue{} 726 for _, group := range session.Groups { 727 groupMemberAttributeValues = append(groupMemberAttributeValues, AttributeValue{ 728 Type: "xs:string", 729 Value: group, 730 }) 731 } 732 attributes = append(attributes, Attribute{ 733 FriendlyName: "eduPersonAffiliation", 734 Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.1", 735 NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", 736 Values: groupMemberAttributeValues, 737 }) 738 } 739 740 if session.SubjectID != "" { 741 attributes = append(attributes, Attribute{ 742 Name: "urn:oasis:names:tc:SAML:attribute:subject-id", 743 NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", 744 Values: []AttributeValue{ 745 { 746 Type: "xs:string", 747 Value: session.SubjectID, 748 }, 749 }, 750 }) 751 } 752 753 // allow for some clock skew in the validity period using the 754 // issuer's apparent clock. 755 notBefore := req.Now.Add(-1 * MaxClockSkew) 756 notOnOrAfterAfter := req.Now.Add(MaxIssueDelay) 757 if notBefore.Before(req.Request.IssueInstant) { 758 notBefore = req.Request.IssueInstant 759 notOnOrAfterAfter = notBefore.Add(MaxIssueDelay) 760 } 761 762 nameIDFormat := "urn:oasis:names:tc:SAML:2.0:nameid-format:transient" 763 764 if session.NameIDFormat != "" { 765 nameIDFormat = session.NameIDFormat 766 } 767 768 req.Assertion = &Assertion{ 769 ID: fmt.Sprintf("id-%x", randomBytes(20)), 770 IssueInstant: TimeNow(), 771 Version: "2.0", 772 Issuer: Issuer{ 773 Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity", 774 Value: req.IDP.Metadata().EntityID, 775 }, 776 Subject: &Subject{ 777 NameID: &NameID{ 778 Format: nameIDFormat, 779 NameQualifier: req.IDP.Metadata().EntityID, 780 SPNameQualifier: req.ServiceProviderMetadata.EntityID, 781 Value: session.NameID, 782 }, 783 SubjectConfirmations: []SubjectConfirmation{ 784 { 785 Method: "urn:oasis:names:tc:SAML:2.0:cm:bearer", 786 SubjectConfirmationData: &SubjectConfirmationData{ 787 Address: req.HTTPRequest.RemoteAddr, 788 InResponseTo: req.Request.ID, 789 NotOnOrAfter: req.Now.Add(MaxIssueDelay), 790 Recipient: req.ACSEndpoint.Location, 791 }, 792 }, 793 }, 794 }, 795 Conditions: &Conditions{ 796 NotBefore: notBefore, 797 NotOnOrAfter: notOnOrAfterAfter, 798 AudienceRestrictions: []AudienceRestriction{ 799 { 800 Audience: Audience{Value: req.ServiceProviderMetadata.EntityID}, 801 }, 802 }, 803 }, 804 AuthnStatements: []AuthnStatement{ 805 { 806 AuthnInstant: session.CreateTime, 807 SessionIndex: session.Index, 808 SubjectLocality: &SubjectLocality{ 809 Address: req.HTTPRequest.RemoteAddr, 810 }, 811 AuthnContext: AuthnContext{ 812 AuthnContextClassRef: &AuthnContextClassRef{ 813 Value: "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport", 814 }, 815 }, 816 }, 817 }, 818 AttributeStatements: []AttributeStatement{ 819 { 820 Attributes: attributes, 821 }, 822 }, 823 } 824 825 return nil 826 } 827 828 // The Canonicalizer prefix list MUST be empty. Various implementations 829 // (maybe ours?) do not appear to support non-empty prefix lists in XML C14N. 830 const canonicalizerPrefixList = "" 831 832 // MakeAssertionEl sets `AssertionEl` to a signed, possibly encrypted, version of `Assertion`. 833 func (req *IdpAuthnRequest) MakeAssertionEl() error { 834 signingContext, err := req.signingContext() 835 if err != nil { 836 return err 837 } 838 839 assertionEl := req.Assertion.Element() 840 841 signedAssertionEl, err := signingContext.SignEnveloped(assertionEl) 842 if err != nil { 843 return err 844 } 845 846 sigEl := signedAssertionEl.Child[len(signedAssertionEl.Child)-1] 847 req.Assertion.Signature = sigEl.(*etree.Element) 848 signedAssertionEl = req.Assertion.Element() 849 850 certBuf, err := req.getSPEncryptionCert() 851 if err == os.ErrNotExist { 852 req.AssertionEl = signedAssertionEl 853 return nil 854 } else if err != nil { 855 return err 856 } 857 858 var signedAssertionBuf []byte 859 { 860 doc := etree.NewDocument() 861 doc.SetRoot(signedAssertionEl) 862 signedAssertionBuf, err = doc.WriteToBytes() 863 if err != nil { 864 return err 865 } 866 } 867 868 encryptor := xmlenc.OAEP() 869 encryptor.BlockCipher = xmlenc.AES128CBC 870 encryptor.DigestMethod = &xmlenc.SHA1 871 encryptedDataEl, err := encryptor.Encrypt(certBuf, signedAssertionBuf, nil) 872 if err != nil { 873 return err 874 } 875 encryptedDataEl.CreateAttr("Type", "http://www.w3.org/2001/04/xmlenc#Element") 876 877 encryptedAssertionEl := etree.NewElement("saml:EncryptedAssertion") 878 encryptedAssertionEl.AddChild(encryptedDataEl) 879 req.AssertionEl = encryptedAssertionEl 880 881 return nil 882 } 883 884 // IdpAuthnRequestForm contans HTML form information to be submitted to the 885 // SAML HTTP POST binding ACS. 886 type IdpAuthnRequestForm struct { 887 URL string 888 SAMLResponse string 889 RelayState string 890 } 891 892 // PostBinding creates the HTTP POST form information for this 893 // `IdpAuthnRequest`. If `Response` is not already set, it calls MakeResponse 894 // to produce it. 895 func (req *IdpAuthnRequest) PostBinding() (IdpAuthnRequestForm, error) { 896 var form IdpAuthnRequestForm 897 898 if req.ResponseEl == nil { 899 if err := req.MakeResponse(); err != nil { 900 return form, err 901 } 902 } 903 904 doc := etree.NewDocument() 905 doc.SetRoot(req.ResponseEl) 906 responseBuf, err := doc.WriteToBytes() 907 if err != nil { 908 return form, err 909 } 910 911 if req.ACSEndpoint.Binding != HTTPPostBinding { 912 return form, fmt.Errorf("%s: unsupported binding %s", 913 req.ServiceProviderMetadata.EntityID, 914 req.ACSEndpoint.Binding) 915 } 916 917 form.URL = req.ACSEndpoint.Location 918 form.SAMLResponse = base64.StdEncoding.EncodeToString(responseBuf) 919 form.RelayState = req.RelayState 920 921 return form, nil 922 } 923 924 // WriteResponse writes the `Response` to the http.ResponseWriter. If 925 // `Response` is not already set, it calls MakeResponse to produce it. 926 func (req *IdpAuthnRequest) WriteResponse(w http.ResponseWriter) error { 927 form, err := req.PostBinding() 928 if err != nil { 929 return err 930 } 931 932 tmpl := template.Must(template.New("saml-post-form").Parse(`<html>` + 933 `<form method="post" action="{{.URL}}" id="SAMLResponseForm">` + 934 `<input type="hidden" name="SAMLResponse" value="{{.SAMLResponse}}" />` + 935 `<input type="hidden" name="RelayState" value="{{.RelayState}}" />` + 936 `<input id="SAMLSubmitButton" type="submit" value="Continue" />` + 937 `</form>` + 938 `<script>document.getElementById('SAMLSubmitButton').style.visibility='hidden';</script>` + 939 `<script>document.getElementById('SAMLResponseForm').submit();</script>` + 940 `</html>`)) 941 942 buf := bytes.NewBuffer(nil) 943 if err := tmpl.Execute(buf, form); err != nil { 944 return err 945 } 946 if _, err := io.Copy(w, buf); err != nil { 947 return err 948 } 949 return nil 950 } 951 952 // getSPEncryptionCert returns the certificate which we can use to encrypt things 953 // to the SP in PEM format, or nil if no such certificate is found. 954 func (req *IdpAuthnRequest) getSPEncryptionCert() (*x509.Certificate, error) { 955 certStr := "" 956 for _, keyDescriptor := range req.SPSSODescriptor.KeyDescriptors { 957 if keyDescriptor.Use == "encryption" { 958 certStr = keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data 959 break 960 } 961 } 962 963 // If there are no certs explicitly labeled for encryption, return the first 964 // non-empty cert we find. 965 if certStr == "" { 966 for _, keyDescriptor := range req.SPSSODescriptor.KeyDescriptors { 967 if keyDescriptor.Use == "" && len(keyDescriptor.KeyInfo.X509Data.X509Certificates) != 0 && keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data != "" { 968 certStr = keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data 969 break 970 } 971 } 972 } 973 974 if certStr == "" { 975 return nil, os.ErrNotExist 976 } 977 978 // cleanup whitespace and re-encode a PEM 979 certStr = regexp.MustCompile(`\s+`).ReplaceAllString(certStr, "") 980 certBytes, err := base64.StdEncoding.DecodeString(certStr) 981 if err != nil { 982 return nil, fmt.Errorf("cannot decode certificate base64: %v", err) 983 } 984 cert, err := x509.ParseCertificate(certBytes) 985 if err != nil { 986 return nil, fmt.Errorf("cannot parse certificate: %v", err) 987 } 988 return cert, nil 989 } 990 991 // unmarshalEtreeHack parses `el` and sets values in the structure `v`. 992 // 993 // This is a hack -- it first serializes the element, then uses xml.Unmarshal. 994 func unmarshalEtreeHack(el *etree.Element, v interface{}) error { 995 doc := etree.NewDocument() 996 doc.SetRoot(el) 997 buf, err := doc.WriteToBytes() 998 if err != nil { 999 return err 1000 } 1001 return xml.Unmarshal(buf, v) 1002 } 1003 1004 // MakeResponse creates and assigns a new SAML response in ResponseEl. `Assertion` must 1005 // be non-nil. If MakeAssertionEl() has not been called, this function calls it for 1006 // you. 1007 func (req *IdpAuthnRequest) MakeResponse() error { 1008 if req.AssertionEl == nil { 1009 if err := req.MakeAssertionEl(); err != nil { 1010 return err 1011 } 1012 } 1013 1014 response := &Response{ 1015 Destination: req.ACSEndpoint.Location, 1016 ID: fmt.Sprintf("id-%x", randomBytes(20)), 1017 InResponseTo: req.Request.ID, 1018 IssueInstant: req.Now, 1019 Version: "2.0", 1020 Issuer: &Issuer{ 1021 Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity", 1022 Value: req.IDP.MetadataURL.String(), 1023 }, 1024 Status: Status{ 1025 StatusCode: StatusCode{ 1026 Value: StatusSuccess, 1027 }, 1028 }, 1029 } 1030 1031 responseEl := response.Element() 1032 responseEl.AddChild(req.AssertionEl) // AssertionEl either an EncryptedAssertion or Assertion element 1033 1034 // Sign the response element (we've already signed the Assertion element) 1035 { 1036 signingContext, err := req.signingContext() 1037 if err != nil { 1038 return err 1039 } 1040 1041 signedResponseEl, err := signingContext.SignEnveloped(responseEl) 1042 if err != nil { 1043 return err 1044 } 1045 1046 sigEl := signedResponseEl.ChildElements()[len(signedResponseEl.ChildElements())-1] 1047 response.Signature = sigEl 1048 responseEl = response.Element() 1049 responseEl.AddChild(req.AssertionEl) 1050 } 1051 1052 req.ResponseEl = responseEl 1053 return nil 1054 } 1055 1056 // signingContext will create a signing context for the request. 1057 func (req *IdpAuthnRequest) signingContext() (*dsig.SigningContext, error) { 1058 // Create a cert chain based off of the IDP cert and its intermediates. 1059 certificates := [][]byte{req.IDP.Certificate.Raw} 1060 for _, cert := range req.IDP.Intermediates { 1061 certificates = append(certificates, cert.Raw) 1062 } 1063 1064 var signingContext *dsig.SigningContext 1065 var err error 1066 // If signer is set, use it instead of the private key. 1067 if req.IDP.Signer != nil { 1068 signingContext, err = dsig.NewSigningContext(req.IDP.Signer, certificates) 1069 if err != nil { 1070 return nil, err 1071 } 1072 } else { 1073 keyPair := tls.Certificate{ 1074 Certificate: certificates, 1075 PrivateKey: req.IDP.Key, 1076 Leaf: req.IDP.Certificate, 1077 } 1078 keyStore := dsig.TLSCertKeyStore(keyPair) 1079 1080 signingContext = dsig.NewDefaultSigningContext(keyStore) 1081 } 1082 1083 // Default to using SHA1 if the signature method isn't set. 1084 signatureMethod := req.IDP.SignatureMethod 1085 if signatureMethod == "" { 1086 signatureMethod = dsig.RSASHA1SignatureMethod 1087 } 1088 1089 signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) 1090 if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { 1091 return nil, err 1092 } 1093 1094 return signingContext, nil 1095 }