github.com/crewjam/saml@v0.4.14/samlsp/request_tracker_jwt.go (about)

     1  package samlsp
     2  
     3  import (
     4  	"crypto/rsa"
     5  	"fmt"
     6  	"time"
     7  
     8  	"github.com/golang-jwt/jwt/v4"
     9  
    10  	"github.com/crewjam/saml"
    11  )
    12  
// defaultJWTSigningMethod is RS256 (RSASSA-PKCS1-v1_5 with SHA-256), an
// RSA-based algorithm matching the *rsa.PrivateKey held by
// JWTTrackedRequestCodec.
// NOTE(review): presumably the default for JWTTrackedRequestCodec.SigningMethod
// applied by a constructor elsewhere in the package; confirm against callers.
var defaultJWTSigningMethod = jwt.SigningMethodRS256
    14  
// JWTTrackedRequestCodec encodes TrackedRequests as signed JWTs and decodes
// and verifies them again.
type JWTTrackedRequestCodec struct {
	// SigningMethod is the algorithm Encode signs with. It is also the only
	// algorithm Decode accepts.
	SigningMethod jwt.SigningMethod
	// Audience is written to the "aud" claim, and Decode requires it there.
	Audience string
	// Issuer is written to the "iss" claim, and Decode requires it there.
	Issuer string
	// MaxAge is added to the issue time to compute the "exp" claim.
	MaxAge time.Duration
	// Key signs tokens in Encode. Its public half verifies signatures in Decode.
	Key *rsa.PrivateKey
}
    23  
// Compile-time check that the value type JWTTrackedRequestCodec (not only a
// pointer to it) satisfies TrackedRequestCodec.
var _ TrackedRequestCodec = JWTTrackedRequestCodec{}
    25  
// JWTTrackedRequestClaims represents the JWT claims for a tracked request.
// It is the payload JWTTrackedRequestCodec signs and verifies.
type JWTTrackedRequestClaims struct {
	// RegisteredClaims holds aud/exp/iat/iss/nbf/sub. The "sub" claim carries
	// TrackedRequest.Index: Encode writes it there and Decode copies it back.
	jwt.RegisteredClaims
	// TrackedRequest is embedded, so encoding/json promotes its exported
	// fields into the payload.
	TrackedRequest
	// SAMLAuthnRequest is set to true by Encode. Decode rejects tokens where it
	// is false or missing.
	SAMLAuthnRequest bool `json:"saml-authn-request"`
}
    32  
    33  // Encode returns an encoded string representing the TrackedRequest.
    34  func (s JWTTrackedRequestCodec) Encode(value TrackedRequest) (string, error) {
    35  	now := saml.TimeNow()
    36  	claims := JWTTrackedRequestClaims{
    37  		RegisteredClaims: jwt.RegisteredClaims{
    38  			Audience:  jwt.ClaimStrings{s.Audience},
    39  			ExpiresAt: jwt.NewNumericDate(now.Add(s.MaxAge)),
    40  			IssuedAt:  jwt.NewNumericDate(now),
    41  			Issuer:    s.Issuer,
    42  			NotBefore: jwt.NewNumericDate(now), // TODO(ross): correct for clock skew
    43  			Subject:   value.Index,
    44  		},
    45  		TrackedRequest:   value,
    46  		SAMLAuthnRequest: true,
    47  	}
    48  	token := jwt.NewWithClaims(s.SigningMethod, claims)
    49  	return token.SignedString(s.Key)
    50  }
    51  
    52  // Decode returns a Tracked request from an encoded string.
    53  func (s JWTTrackedRequestCodec) Decode(signed string) (*TrackedRequest, error) {
    54  	parser := jwt.Parser{
    55  		ValidMethods: []string{s.SigningMethod.Alg()},
    56  	}
    57  	claims := JWTTrackedRequestClaims{}
    58  	_, err := parser.ParseWithClaims(signed, &claims, func(*jwt.Token) (interface{}, error) {
    59  		return s.Key.Public(), nil
    60  	})
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	if !claims.VerifyAudience(s.Audience, true) {
    65  		return nil, fmt.Errorf("expected audience %q, got %q", s.Audience, claims.Audience)
    66  	}
    67  	if !claims.VerifyIssuer(s.Issuer, true) {
    68  		return nil, fmt.Errorf("expected issuer %q, got %q", s.Issuer, claims.Issuer)
    69  	}
    70  	if !claims.SAMLAuthnRequest {
    71  		return nil, fmt.Errorf("expected saml-authn-request")
    72  	}
    73  	claims.TrackedRequest.Index = claims.Subject
    74  	return &claims.TrackedRequest, nil
    75  }