github.com/crewjam/saml@v0.4.14/samlsp/session_jwt.go (about) 1 package samlsp 2 3 import ( 4 "crypto/rsa" 5 "errors" 6 "fmt" 7 "time" 8 9 "github.com/golang-jwt/jwt/v4" 10 11 "github.com/crewjam/saml" 12 ) 13 14 const ( 15 defaultSessionMaxAge = time.Hour 16 claimNameSessionIndex = "SessionIndex" 17 ) 18 19 // JWTSessionCodec implements SessionCoded to encode and decode Sessions from 20 // the corresponding JWT. 21 type JWTSessionCodec struct { 22 SigningMethod jwt.SigningMethod 23 Audience string 24 Issuer string 25 MaxAge time.Duration 26 Key *rsa.PrivateKey 27 } 28 29 var _ SessionCodec = JWTSessionCodec{} 30 31 // New creates a Session from the SAML assertion. 32 // 33 // The returned Session is a JWTSessionClaims. 34 func (c JWTSessionCodec) New(assertion *saml.Assertion) (Session, error) { 35 now := saml.TimeNow() 36 claims := JWTSessionClaims{} 37 claims.SAMLSession = true 38 claims.Audience = c.Audience 39 claims.Issuer = c.Issuer 40 claims.IssuedAt = now.Unix() 41 claims.ExpiresAt = now.Add(c.MaxAge).Unix() 42 claims.NotBefore = now.Unix() 43 44 if sub := assertion.Subject; sub != nil { 45 if nameID := sub.NameID; nameID != nil { 46 claims.Subject = nameID.Value 47 } 48 } 49 50 claims.Attributes = map[string][]string{} 51 52 for _, attributeStatement := range assertion.AttributeStatements { 53 for _, attr := range attributeStatement.Attributes { 54 claimName := attr.FriendlyName 55 if claimName == "" { 56 claimName = attr.Name 57 } 58 for _, value := range attr.Values { 59 claims.Attributes[claimName] = append(claims.Attributes[claimName], value.Value) 60 } 61 } 62 } 63 64 // add SessionIndex to claims Attributes 65 for _, authnStatement := range assertion.AuthnStatements { 66 claims.Attributes[claimNameSessionIndex] = append(claims.Attributes[claimNameSessionIndex], 67 authnStatement.SessionIndex) 68 } 69 70 return claims, nil 71 } 72 73 // Encode returns a serialized version of the Session. 74 // 75 // The provided session must be a JWTSessionClaims, otherwise this 76 // function will panic. 77 func (c JWTSessionCodec) Encode(s Session) (string, error) { 78 claims := s.(JWTSessionClaims) // this will panic if you pass the wrong kind of session 79 80 token := jwt.NewWithClaims(c.SigningMethod, claims) 81 signedString, err := token.SignedString(c.Key) 82 if err != nil { 83 return "", err 84 } 85 86 return signedString, nil 87 } 88 89 // Decode parses the serialized session that may have been returned by Encode 90 // and returns a Session. 91 func (c JWTSessionCodec) Decode(signed string) (Session, error) { 92 parser := jwt.Parser{ 93 ValidMethods: []string{c.SigningMethod.Alg()}, 94 } 95 claims := JWTSessionClaims{} 96 _, err := parser.ParseWithClaims(signed, &claims, func(*jwt.Token) (interface{}, error) { 97 return c.Key.Public(), nil 98 }) 99 // TODO(ross): check for errors due to bad time and return ErrNoSession 100 if err != nil { 101 return nil, err 102 } 103 if !claims.VerifyAudience(c.Audience, true) { 104 return nil, fmt.Errorf("expected audience %q, got %q", c.Audience, claims.Audience) 105 } 106 if !claims.VerifyIssuer(c.Issuer, true) { 107 return nil, fmt.Errorf("expected issuer %q, got %q", c.Issuer, claims.Issuer) 108 } 109 if !claims.SAMLSession { 110 return nil, errors.New("expected saml-session") 111 } 112 return claims, nil 113 } 114 115 // JWTSessionClaims represents the JWT claims in the encoded session 116 type JWTSessionClaims struct { 117 jwt.StandardClaims 118 Attributes Attributes `json:"attr"` 119 SAMLSession bool `json:"saml-session"` 120 } 121 122 var _ Session = JWTSessionClaims{} 123 124 // GetAttributes implements SessionWithAttributes. It returns the SAMl attributes. 125 func (c JWTSessionClaims) GetAttributes() Attributes { 126 return c.Attributes 127 } 128 129 // Attributes is a map of attributes provided in the SAML assertion 130 type Attributes map[string][]string 131 132 // Get returns the first attribute named `key` or an empty string if 133 // no such attributes is present. 134 func (a Attributes) Get(key string) string { 135 if a == nil { 136 return "" 137 } 138 v := a[key] 139 if len(v) == 0 { 140 return "" 141 } 142 return v[0] 143 }