github.com/avenga/couper@v1.12.2/accesscontrol/jwt.go (about)

     1  package accesscontrol
     2  
import (
	"context"
	"crypto/ecdsa"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	goerrors "errors"
	"fmt"
	"net/http"
	"reflect"
	"strings"

	"github.com/golang-jwt/jwt/v4"
	"github.com/hashicorp/hcl/v2"
	"github.com/sirupsen/logrus"

	"github.com/avenga/couper/accesscontrol/jwk"
	acjwt "github.com/avenga/couper/accesscontrol/jwt"
	"github.com/avenga/couper/config/request"
	"github.com/avenga/couper/errors"
	"github.com/avenga/couper/eval"
	"github.com/avenga/couper/internal/seetie"
)
    25  
// Token source types. Invalid is the zero value, so a zero JWTSource (as
// returned by NewJWTSource for conflicting settings) is rejected by newJWT.
const (
	// Invalid marks a source that could not be determined from the settings.
	Invalid JWTSourceType = iota
	// Cookie reads the token from the cookie named by JWTSource.Name.
	Cookie
	// Header reads the token from the request header named by JWTSource.Name;
	// for the Authorization header the Bearer scheme is required.
	Header
	// Value evaluates JWTSource.Expr in the request's HCL context.
	Value
)
    32  
    33  var (
    34  	_ AccessControl         = &JWT{}
    35  	_ DisablePrivateCaching = &JWT{}
    36  )
    37  
    38  type (
    39  	JWTSourceType uint8
    40  	JWTSource     struct {
    41  		Expr hcl.Expression
    42  		Name string
    43  		Type JWTSourceType
    44  	}
    45  )
    46  
// JWT is an access control that validates JSON Web Tokens read from the
// configured source. Instances are built by NewJWT (static HMAC secret or PEM
// public key) or NewJWTFromJWKS (keys looked up in a JWKS).
//
// Field notes:
//   - algorithm, hmacSecret and pubKey are only set by NewJWT; jwks is only
//     set by NewJWTFromJWKS and takes precedence in getValidationKey.
//   - claims is evaluated per request to obtain expected claim values;
//     claimsRequired lists claim names that must be present in the token.
//   - name is the key under which the token claims are stored in the access
//     control map of the request context.
//   - rolesMap maps role names to permissions (the "*" entry is applied
//     whenever the roles claim is present); permissionsMap expands
//     permissions into further permissions recursively.
type JWT struct {
	algorithm             acjwt.Algorithm
	claims                hcl.Expression
	claimsRequired        []string
	disablePrivateCaching bool
	source                JWTSource
	hmacSecret            []byte
	name                  string
	parser                *jwt.Parser
	pubKey                interface{}
	rolesClaim            string
	rolesMap              map[string][]string
	permissionsClaim      string
	permissionsMap        map[string][]string
	jwks                  *jwk.JWKS
}
    63  
// JWTOptions carries the configuration for NewJWT and NewJWTFromJWKS.
//
//   - Algorithm and Key are read by NewJWT only: Key is the HMAC secret for
//     HMAC algorithms, otherwise a PEM encoded RSA or ECDSA public key (or a
//     certificate holding one).
//   - JWKS is required by NewJWTFromJWKS.
//   - Source must not be of type Invalid, and a non-empty RolesClaim requires
//     RolesMap; both constructors check this.
//   - Claims is evaluated per request; configured iss and aud values must be
//     strings.
type JWTOptions struct {
	Algorithm             string
	Claims                hcl.Expression
	ClaimsRequired        []string
	DisablePrivateCaching bool
	Name                  string // TODO: more generic (validate)
	RolesClaim            string
	RolesMap              map[string][]string
	PermissionsClaim      string
	PermissionsMap        map[string][]string
	Source                JWTSource
	Key                   []byte
	JWKS                  *jwk.JWKS
}
    78  
    79  func NewJWTSource(cookie, header string, value hcl.Expression) JWTSource {
    80  	c, h := strings.TrimSpace(cookie), strings.TrimSpace(header)
    81  
    82  	if value != nil {
    83  		v, _ := value.Value(nil)
    84  		if !v.IsNull() {
    85  			if h != "" || c != "" {
    86  				return JWTSource{}
    87  			}
    88  
    89  			return JWTSource{
    90  				Name: "",
    91  				Type: Value,
    92  				Expr: value,
    93  			}
    94  		}
    95  	}
    96  	if c != "" && h == "" {
    97  		return JWTSource{
    98  			Name: c,
    99  			Type: Cookie,
   100  		}
   101  	}
   102  	if h != "" && c == "" {
   103  		return JWTSource{
   104  			Name: h,
   105  			Type: Header,
   106  		}
   107  	}
   108  	if h == "" && c == "" {
   109  		return JWTSource{
   110  			Name: "Authorization",
   111  			Type: Header,
   112  		}
   113  	}
   114  	return JWTSource{}
   115  }
   116  
   117  // NewJWT parses the key and creates Validation obj which can be referenced in related handlers.
   118  func NewJWT(options *JWTOptions) (*JWT, error) {
   119  	jwtAC, err := newJWT(options)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	jwtAC.algorithm = acjwt.NewAlgorithm(options.Algorithm)
   125  	if jwtAC.algorithm == acjwt.AlgorithmUnknown {
   126  		return nil, fmt.Errorf("algorithm %q is not supported", options.Algorithm)
   127  	}
   128  
   129  	jwtAC.parser = newParser([]acjwt.Algorithm{jwtAC.algorithm})
   130  
   131  	if jwtAC.algorithm.IsHMAC() {
   132  		jwtAC.hmacSecret = options.Key
   133  		return jwtAC, nil
   134  	}
   135  
   136  	pubKey, err := parsePublicPEMKey(options.Key)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  
   141  	jwtAC.pubKey = pubKey
   142  	return jwtAC, nil
   143  }
   144  
   145  func NewJWTFromJWKS(options *JWTOptions) (*JWT, error) {
   146  	jwtAC, err := newJWT(options)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  
   151  	if options.JWKS == nil {
   152  		return nil, fmt.Errorf("invalid JWKS")
   153  	}
   154  
   155  	algorithms := append(acjwt.RSAAlgorithms, acjwt.ECDSAlgorithms...)
   156  	jwtAC.parser = newParser(algorithms)
   157  	jwtAC.jwks = options.JWKS
   158  
   159  	return jwtAC, nil
   160  }
   161  
   162  func newJWT(options *JWTOptions) (*JWT, error) {
   163  	if options.Source.Type == Invalid {
   164  		return nil, fmt.Errorf("token source is invalid")
   165  	}
   166  
   167  	if options.RolesClaim != "" && options.RolesMap == nil {
   168  		return nil, fmt.Errorf("missing roles_map")
   169  	}
   170  
   171  	jwtAC := &JWT{
   172  		claims:                options.Claims,
   173  		claimsRequired:        options.ClaimsRequired,
   174  		disablePrivateCaching: options.DisablePrivateCaching,
   175  		name:                  options.Name,
   176  		rolesClaim:            options.RolesClaim,
   177  		rolesMap:              options.RolesMap,
   178  		permissionsClaim:      options.PermissionsClaim,
   179  		permissionsMap:        options.PermissionsMap,
   180  		source:                options.Source,
   181  	}
   182  	return jwtAC, nil
   183  }
   184  
   185  func (j *JWT) DisablePrivateCaching() bool {
   186  	return j.disablePrivateCaching
   187  }
   188  
   189  // Validate reading the token from configured source and validates against the key.
   190  func (j *JWT) Validate(req *http.Request) error {
   191  	var tokenValue string
   192  	var err error
   193  
   194  	switch j.source.Type {
   195  	case Cookie:
   196  		cookie, cerr := req.Cookie(j.source.Name)
   197  		if cerr != http.ErrNoCookie && cookie != nil {
   198  			tokenValue = cookie.Value
   199  		}
   200  	case Header:
   201  		if strings.ToLower(j.source.Name) == "authorization" {
   202  			if tokenValue = req.Header.Get(j.source.Name); tokenValue != "" {
   203  				if tokenValue, err = getBearer(tokenValue); err != nil {
   204  					return errors.JwtTokenMissing.With(err)
   205  				}
   206  			}
   207  		} else {
   208  			tokenValue = req.Header.Get(j.source.Name)
   209  		}
   210  	case Value:
   211  		requestContext := eval.ContextFromRequest(req).HCLContext()
   212  		value, diags := eval.Value(requestContext, j.source.Expr)
   213  		if diags != nil {
   214  			return diags
   215  		}
   216  
   217  		tokenValue = seetie.ValueToString(value)
   218  	}
   219  
   220  	if tokenValue == "" {
   221  		return errors.JwtTokenMissing.Message("token required")
   222  	}
   223  
   224  	expectedClaims, err := j.getConfiguredClaims(req)
   225  	if err != nil {
   226  		return err
   227  	}
   228  
   229  	if j.jwks != nil {
   230  		// load JWKS if needed
   231  		j.jwks.Data()
   232  	}
   233  
   234  	tokenClaims := jwt.MapClaims{}
   235  	_, err = j.parser.ParseWithClaims(tokenValue, tokenClaims, j.getValidationKey)
   236  	if err != nil {
   237  		if goerrors.Is(err, jwt.ErrTokenExpired) {
   238  			return errors.JwtTokenExpired.With(err)
   239  		}
   240  		return errors.JwtTokenInvalid.With(err)
   241  	}
   242  
   243  	err = j.validateClaims(tokenClaims, expectedClaims)
   244  	if err != nil {
   245  		return errors.JwtTokenInvalid.With(err)
   246  	}
   247  
   248  	ctx := req.Context()
   249  	acMap, ok := ctx.Value(request.AccessControls).(map[string]interface{})
   250  	if !ok {
   251  		acMap = make(map[string]interface{})
   252  	}
   253  	// treat token claims as map for context
   254  	acMap[j.name] = map[string]interface{}(tokenClaims)
   255  	ctx = context.WithValue(ctx, request.AccessControls, acMap)
   256  
   257  	log := req.Context().Value(request.LogEntry).(*logrus.Entry).WithContext(req.Context())
   258  	jwtGrantedPermissions := j.getGrantedPermissions(tokenClaims, log)
   259  
   260  	grantedPermissions, _ := ctx.Value(request.GrantedPermissions).([]string)
   261  
   262  	grantedPermissions = append(grantedPermissions, jwtGrantedPermissions...)
   263  
   264  	ctx = context.WithValue(ctx, request.GrantedPermissions, grantedPermissions)
   265  
   266  	*req = *req.WithContext(ctx)
   267  
   268  	return nil
   269  }
   270  
   271  func (j *JWT) getValidationKey(token *jwt.Token) (interface{}, error) {
   272  	if j.jwks != nil {
   273  		return j.jwks.GetSigKeyForToken(token)
   274  	}
   275  
   276  	switch j.algorithm {
   277  	case acjwt.AlgorithmRSA256, acjwt.AlgorithmRSA384, acjwt.AlgorithmRSA512:
   278  		return j.pubKey, nil
   279  	case acjwt.AlgorithmECDSA256, acjwt.AlgorithmECDSA384, acjwt.AlgorithmECDSA512:
   280  		return j.pubKey, nil
   281  	case acjwt.AlgorithmHMAC256, acjwt.AlgorithmHMAC384, acjwt.AlgorithmHMAC512:
   282  		return j.hmacSecret, nil
   283  	default: // this error case gets normally caught on configuration level
   284  		return nil, errors.Configuration.Message("algorithm is not supported")
   285  	}
   286  }
   287  
   288  // getConfiguredClaims evaluates the expected claim values from the configuration, and especially iss and aud
   289  func (j *JWT) getConfiguredClaims(req *http.Request) (map[string]interface{}, error) {
   290  	claims := make(map[string]interface{})
   291  	if j.claims != nil {
   292  		val, verr := eval.Value(eval.ContextFromRequest(req).HCLContext(), j.claims)
   293  		if verr != nil {
   294  			return nil, verr
   295  		}
   296  		claims = seetie.ValueToMap(val)
   297  
   298  		var ok bool
   299  		if issVal, exists := claims["iss"]; exists {
   300  			_, ok = issVal.(string)
   301  			if !ok {
   302  				return nil, errors.Configuration.Message("invalid value type, string expected (claims / iss)")
   303  			}
   304  		}
   305  
   306  		if audVal, exists := claims["aud"]; exists {
   307  			_, ok = audVal.(string)
   308  			if !ok {
   309  				return nil, errors.Configuration.Message("invalid value type, string expected (claims / aud)")
   310  			}
   311  		}
   312  	}
   313  
   314  	return claims, nil
   315  }
   316  
   317  // validateClaims validates the token claims against the list of required claims and the expected claims values
   318  func (j *JWT) validateClaims(tokenClaims jwt.MapClaims, expectedClaims map[string]interface{}) error {
   319  	for _, key := range j.claimsRequired {
   320  		if _, ok := tokenClaims[key]; !ok {
   321  			return fmt.Errorf("required claim is missing: " + key)
   322  		}
   323  	}
   324  
   325  	for k, v := range expectedClaims {
   326  		val, exist := tokenClaims[k]
   327  		if !exist {
   328  			return fmt.Errorf("required claim is missing: " + k)
   329  		}
   330  
   331  		if k == "iss" {
   332  			if !tokenClaims.VerifyIssuer(v.(string), true) {
   333  				return errors.JwtTokenInvalid.Message("invalid issuer")
   334  			}
   335  			continue
   336  		}
   337  		if k == "aud" {
   338  			if !tokenClaims.VerifyAudience(v.(string), true) {
   339  				return errors.JwtTokenInvalid.Message("invalid audience")
   340  			}
   341  			continue
   342  		}
   343  
   344  		if val != v {
   345  			return fmt.Errorf("unexpected value for claim %s, got %q, expected %q", k, val, v)
   346  		}
   347  	}
   348  	return nil
   349  }
   350  
   351  func (j *JWT) getGrantedPermissions(tokenClaims jwt.MapClaims, log *logrus.Entry) []string {
   352  	var grantedPermissions []string
   353  
   354  	grantedPermissions = j.addPermissionsFromPermissionsClaim(tokenClaims, grantedPermissions, log)
   355  
   356  	grantedPermissions = j.addPermissionsFromRoles(tokenClaims, grantedPermissions, log)
   357  
   358  	grantedPermissions = j.addMappedPermissions(grantedPermissions, grantedPermissions)
   359  
   360  	return grantedPermissions
   361  }
   362  
// warnInvalidValueMsg is the format for warnings about claim values of an
// unexpected type; arguments are the claim kind ("roles"/"permissions") and the
// raw claim value.
const warnInvalidValueMsg = "invalid %s claim value type, ignoring claim, value %#v"
   364  
   365  func (j *JWT) addPermissionsFromPermissionsClaim(tokenClaims jwt.MapClaims, permissions []string, log *logrus.Entry) []string {
   366  	if j.permissionsClaim == "" {
   367  		return permissions
   368  	}
   369  
   370  	permissionsFromClaim, exists := tokenClaims[j.permissionsClaim]
   371  	if !exists {
   372  		return permissions
   373  	}
   374  
   375  	// ["foo", "bar"] is stored as []interface{}, not []string, unfortunately
   376  	permissionsArray, ok := permissionsFromClaim.([]interface{})
   377  	if ok {
   378  		var vals []string
   379  		for _, v := range permissionsArray {
   380  			p, ok := v.(string)
   381  			if !ok {
   382  				log.Warn(fmt.Sprintf(warnInvalidValueMsg, "permissions", permissionsFromClaim))
   383  				return permissions
   384  			}
   385  			vals = append(vals, p)
   386  		}
   387  		for _, val := range vals {
   388  			permissions, _ = addPermission(permissions, val)
   389  		}
   390  	} else {
   391  		permissionsString, ok := permissionsFromClaim.(string)
   392  		if !ok {
   393  			log.Warn(fmt.Sprintf(warnInvalidValueMsg, "permissions", permissionsFromClaim))
   394  			return permissions
   395  		}
   396  		for _, p := range strings.Split(permissionsString, " ") {
   397  			permissions, _ = addPermission(permissions, p)
   398  		}
   399  	}
   400  	return permissions
   401  }
   402  
   403  func (j *JWT) getRoleValues(rolesClaimValue interface{}, log *logrus.Entry) []string {
   404  	var roleValues []string
   405  	// ["foo", "bar"] is stored as []interface{}, not []string, unfortunately
   406  	rolesArray, ok := rolesClaimValue.([]interface{})
   407  	if ok {
   408  		var vals []string
   409  		for _, v := range rolesArray {
   410  			r, ok := v.(string)
   411  			if !ok {
   412  				log.Warn(fmt.Sprintf(warnInvalidValueMsg, "roles", rolesClaimValue))
   413  				return roleValues
   414  			}
   415  			vals = append(vals, r)
   416  		}
   417  		return vals
   418  	}
   419  
   420  	rolesString, ok := rolesClaimValue.(string)
   421  	if !ok {
   422  		log.Warn(fmt.Sprintf(warnInvalidValueMsg, "roles", rolesClaimValue))
   423  		return roleValues
   424  	}
   425  	return strings.Split(rolesString, " ")
   426  }
   427  
   428  func (j *JWT) addPermissionsFromRoles(tokenClaims jwt.MapClaims, permissions []string, log *logrus.Entry) []string {
   429  	if j.rolesClaim == "" || j.rolesMap == nil {
   430  		return permissions
   431  	}
   432  
   433  	rolesClaimValue, exists := tokenClaims[j.rolesClaim]
   434  	if !exists {
   435  		return permissions
   436  	}
   437  
   438  	roleValues := j.getRoleValues(rolesClaimValue, log)
   439  	for _, r := range roleValues {
   440  		if perms, exist := j.rolesMap[r]; exist {
   441  			for _, p := range perms {
   442  				permissions, _ = addPermission(permissions, p)
   443  			}
   444  		}
   445  	}
   446  
   447  	if perms, exist := j.rolesMap["*"]; exist {
   448  		for _, p := range perms {
   449  			permissions, _ = addPermission(permissions, p)
   450  		}
   451  	}
   452  	return permissions
   453  }
   454  
   455  func (j *JWT) addMappedPermissions(source, target []string) []string {
   456  	if j.permissionsMap == nil {
   457  		return target
   458  	}
   459  
   460  	for _, val := range source {
   461  		mappedValues, exist := j.permissionsMap[val]
   462  		if !exist {
   463  			// no mapping for value
   464  			continue
   465  		}
   466  
   467  		var l []string
   468  		for _, mv := range mappedValues {
   469  			var added bool
   470  			// add value from mapping?
   471  			target, added = addPermission(target, mv)
   472  			if !added {
   473  				continue
   474  			}
   475  			l = append(l, mv)
   476  		}
   477  		// recursion: call only with values not already in target
   478  		target = j.addMappedPermissions(l, target)
   479  	}
   480  	return target
   481  }
   482  
   483  func addPermission(permissions []string, permission string) ([]string, bool) {
   484  	permission = strings.TrimSpace(permission)
   485  	if permission == "" {
   486  		return permissions, false
   487  	}
   488  	for _, p := range permissions {
   489  		if p == permission {
   490  			return permissions, false
   491  		}
   492  	}
   493  	return append(permissions, permission), true
   494  }
   495  
   496  func getBearer(val string) (string, error) {
   497  	const bearer = "bearer "
   498  	if strings.HasPrefix(strings.ToLower(val), bearer) {
   499  		return strings.Trim(val[len(bearer):], " "), nil
   500  	}
   501  	return "", fmt.Errorf("bearer required with authorization header")
   502  }
   503  
   504  // newParser creates a new parser
   505  func newParser(algos []acjwt.Algorithm) *jwt.Parser {
   506  	var algorithms []string
   507  	for _, a := range algos {
   508  		algorithms = append(algorithms, a.String())
   509  	}
   510  	options := []jwt.ParserOption{
   511  		jwt.WithValidMethods(algorithms),
   512  		// no equivalent in new lib
   513  		// jwt.WithLeeway(time.Second),
   514  	}
   515  
   516  	return jwt.NewParser(options...)
   517  }
   518  
   519  // parsePublicPEMKey tries to parse all supported publicKey variations which
   520  // must be given in PEM encoded format.
   521  func parsePublicPEMKey(key []byte) (pub interface{}, err error) {
   522  	pemBlock, _ := pem.Decode(key)
   523  	if pemBlock == nil {
   524  		return nil, jwt.ErrKeyMustBePEMEncoded
   525  	}
   526  	pubKey, pubErr := x509.ParsePKCS1PublicKey(pemBlock.Bytes)
   527  	if pubErr != nil {
   528  		pkixKey, pkerr := x509.ParsePKIXPublicKey(pemBlock.Bytes)
   529  		if pkerr != nil {
   530  			cert, cerr := x509.ParseCertificate(pemBlock.Bytes)
   531  			if cerr != nil {
   532  				return nil, jwt.ErrNotRSAPublicKey
   533  			}
   534  			if k, ok := cert.PublicKey.(*rsa.PublicKey); ok {
   535  				return k, nil
   536  			}
   537  			if k, ok := cert.PublicKey.(*ecdsa.PublicKey); ok {
   538  				return k, nil
   539  			}
   540  
   541  			return nil, fmt.Errorf("invalid RSA/ECDSA public key")
   542  		}
   543  
   544  		if k, ok := pkixKey.(*rsa.PublicKey); ok {
   545  			return k, nil
   546  		}
   547  
   548  		if k, ok := pkixKey.(*ecdsa.PublicKey); ok {
   549  			return k, nil
   550  		}
   551  
   552  		return nil, fmt.Errorf("invalid RSA/ECDSA public key")
   553  	}
   554  	return pubKey, nil
   555  }