github.com/crewjam/saml@v0.4.14/samlsp/request_tracker_jwt.go (about) 1 package samlsp 2 3 import ( 4 "crypto/rsa" 5 "fmt" 6 "time" 7 8 "github.com/golang-jwt/jwt/v4" 9 10 "github.com/crewjam/saml" 11 ) 12 13 var defaultJWTSigningMethod = jwt.SigningMethodRS256 14 15 // JWTTrackedRequestCodec encodes TrackedRequests as signed JWTs 16 type JWTTrackedRequestCodec struct { 17 SigningMethod jwt.SigningMethod 18 Audience string 19 Issuer string 20 MaxAge time.Duration 21 Key *rsa.PrivateKey 22 } 23 24 var _ TrackedRequestCodec = JWTTrackedRequestCodec{} 25 26 // JWTTrackedRequestClaims represents the JWT claims for a tracked request. 27 type JWTTrackedRequestClaims struct { 28 jwt.RegisteredClaims 29 TrackedRequest 30 SAMLAuthnRequest bool `json:"saml-authn-request"` 31 } 32 33 // Encode returns an encoded string representing the TrackedRequest. 34 func (s JWTTrackedRequestCodec) Encode(value TrackedRequest) (string, error) { 35 now := saml.TimeNow() 36 claims := JWTTrackedRequestClaims{ 37 RegisteredClaims: jwt.RegisteredClaims{ 38 Audience: jwt.ClaimStrings{s.Audience}, 39 ExpiresAt: jwt.NewNumericDate(now.Add(s.MaxAge)), 40 IssuedAt: jwt.NewNumericDate(now), 41 Issuer: s.Issuer, 42 NotBefore: jwt.NewNumericDate(now), // TODO(ross): correct for clock skew 43 Subject: value.Index, 44 }, 45 TrackedRequest: value, 46 SAMLAuthnRequest: true, 47 } 48 token := jwt.NewWithClaims(s.SigningMethod, claims) 49 return token.SignedString(s.Key) 50 } 51 52 // Decode returns a Tracked request from an encoded string. 53 func (s JWTTrackedRequestCodec) Decode(signed string) (*TrackedRequest, error) { 54 parser := jwt.Parser{ 55 ValidMethods: []string{s.SigningMethod.Alg()}, 56 } 57 claims := JWTTrackedRequestClaims{} 58 _, err := parser.ParseWithClaims(signed, &claims, func(*jwt.Token) (interface{}, error) { 59 return s.Key.Public(), nil 60 }) 61 if err != nil { 62 return nil, err 63 } 64 if !claims.VerifyAudience(s.Audience, true) { 65 return nil, fmt.Errorf("expected audience %q, got %q", s.Audience, claims.Audience) 66 } 67 if !claims.VerifyIssuer(s.Issuer, true) { 68 return nil, fmt.Errorf("expected issuer %q, got %q", s.Issuer, claims.Issuer) 69 } 70 if !claims.SAMLAuthnRequest { 71 return nil, fmt.Errorf("expected saml-authn-request") 72 } 73 claims.TrackedRequest.Index = claims.Subject 74 return &claims.TrackedRequest, nil 75 }