github.com/crewjam/saml@v0.4.14/samlsp/session_jwt.go (about)

     1  package samlsp
     2  
     3  import (
     4  	"crypto/rsa"
     5  	"errors"
     6  	"fmt"
     7  	"time"
     8  
     9  	"github.com/golang-jwt/jwt/v4"
    10  
    11  	"github.com/crewjam/saml"
    12  )
    13  
    14  const (
    15  	defaultSessionMaxAge  = time.Hour
    16  	claimNameSessionIndex = "SessionIndex"
    17  )
    18  
    19  // JWTSessionCodec implements SessionCoded to encode and decode Sessions from
    20  // the corresponding JWT.
    21  type JWTSessionCodec struct {
    22  	SigningMethod jwt.SigningMethod
    23  	Audience      string
    24  	Issuer        string
    25  	MaxAge        time.Duration
    26  	Key           *rsa.PrivateKey
    27  }
    28  
    29  var _ SessionCodec = JWTSessionCodec{}
    30  
    31  // New creates a Session from the SAML assertion.
    32  //
    33  // The returned Session is a JWTSessionClaims.
    34  func (c JWTSessionCodec) New(assertion *saml.Assertion) (Session, error) {
    35  	now := saml.TimeNow()
    36  	claims := JWTSessionClaims{}
    37  	claims.SAMLSession = true
    38  	claims.Audience = c.Audience
    39  	claims.Issuer = c.Issuer
    40  	claims.IssuedAt = now.Unix()
    41  	claims.ExpiresAt = now.Add(c.MaxAge).Unix()
    42  	claims.NotBefore = now.Unix()
    43  
    44  	if sub := assertion.Subject; sub != nil {
    45  		if nameID := sub.NameID; nameID != nil {
    46  			claims.Subject = nameID.Value
    47  		}
    48  	}
    49  
    50  	claims.Attributes = map[string][]string{}
    51  
    52  	for _, attributeStatement := range assertion.AttributeStatements {
    53  		for _, attr := range attributeStatement.Attributes {
    54  			claimName := attr.FriendlyName
    55  			if claimName == "" {
    56  				claimName = attr.Name
    57  			}
    58  			for _, value := range attr.Values {
    59  				claims.Attributes[claimName] = append(claims.Attributes[claimName], value.Value)
    60  			}
    61  		}
    62  	}
    63  
    64  	// add SessionIndex to claims Attributes
    65  	for _, authnStatement := range assertion.AuthnStatements {
    66  		claims.Attributes[claimNameSessionIndex] = append(claims.Attributes[claimNameSessionIndex],
    67  			authnStatement.SessionIndex)
    68  	}
    69  
    70  	return claims, nil
    71  }
    72  
    73  // Encode returns a serialized version of the Session.
    74  //
    75  // The provided session must be a JWTSessionClaims, otherwise this
    76  // function will panic.
    77  func (c JWTSessionCodec) Encode(s Session) (string, error) {
    78  	claims := s.(JWTSessionClaims) // this will panic if you pass the wrong kind of session
    79  
    80  	token := jwt.NewWithClaims(c.SigningMethod, claims)
    81  	signedString, err := token.SignedString(c.Key)
    82  	if err != nil {
    83  		return "", err
    84  	}
    85  
    86  	return signedString, nil
    87  }
    88  
    89  // Decode parses the serialized session that may have been returned by Encode
    90  // and returns a Session.
    91  func (c JWTSessionCodec) Decode(signed string) (Session, error) {
    92  	parser := jwt.Parser{
    93  		ValidMethods: []string{c.SigningMethod.Alg()},
    94  	}
    95  	claims := JWTSessionClaims{}
    96  	_, err := parser.ParseWithClaims(signed, &claims, func(*jwt.Token) (interface{}, error) {
    97  		return c.Key.Public(), nil
    98  	})
    99  	// TODO(ross): check for errors due to bad time and return ErrNoSession
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	if !claims.VerifyAudience(c.Audience, true) {
   104  		return nil, fmt.Errorf("expected audience %q, got %q", c.Audience, claims.Audience)
   105  	}
   106  	if !claims.VerifyIssuer(c.Issuer, true) {
   107  		return nil, fmt.Errorf("expected issuer %q, got %q", c.Issuer, claims.Issuer)
   108  	}
   109  	if !claims.SAMLSession {
   110  		return nil, errors.New("expected saml-session")
   111  	}
   112  	return claims, nil
   113  }
   114  
// JWTSessionClaims represents the JWT claims in the encoded session.
//
// The embedded jwt.StandardClaims carry the standard fields that
// JWTSessionCodec.New fills in: Audience, Issuer, Subject, IssuedAt, NotBefore
// and ExpiresAt. Field order and JSON tags define the token payload, so
// changing them changes the encoded session format.
type JWTSessionClaims struct {
	jwt.StandardClaims
	// Attributes holds the SAML attribute values, serialized under "attr".
	Attributes  Attributes `json:"attr"`
	// SAMLSession marks the token as a SAML session, serialized under
	// "saml-session". JWTSessionCodec.Decode rejects tokens where it is false.
	SAMLSession bool       `json:"saml-session"`
}

// Compile-time check that JWTSessionClaims satisfies Session.
var _ Session = JWTSessionClaims{}
   123  
   124  // GetAttributes implements SessionWithAttributes. It returns the SAMl attributes.
   125  func (c JWTSessionClaims) GetAttributes() Attributes {
   126  	return c.Attributes
   127  }
   128  
   129  // Attributes is a map of attributes provided in the SAML assertion
   130  type Attributes map[string][]string
   131  
   132  // Get returns the first attribute named `key` or an empty string if
   133  // no such attributes is present.
   134  func (a Attributes) Get(key string) string {
   135  	if a == nil {
   136  		return ""
   137  	}
   138  	v := a[key]
   139  	if len(v) == 0 {
   140  		return ""
   141  	}
   142  	return v[0]
   143  }