github.com/avenga/couper@v1.12.2/accesscontrol/jwt.go (about) 1 package accesscontrol 2 3 import ( 4 "context" 5 "crypto/ecdsa" 6 "crypto/rsa" 7 "crypto/x509" 8 "encoding/pem" 9 goerrors "errors" 10 "fmt" 11 "net/http" 12 "strings" 13 14 "github.com/golang-jwt/jwt/v4" 15 "github.com/hashicorp/hcl/v2" 16 "github.com/sirupsen/logrus" 17 18 "github.com/avenga/couper/accesscontrol/jwk" 19 acjwt "github.com/avenga/couper/accesscontrol/jwt" 20 "github.com/avenga/couper/config/request" 21 "github.com/avenga/couper/errors" 22 "github.com/avenga/couper/eval" 23 "github.com/avenga/couper/internal/seetie" 24 ) 25 26 const ( 27 Invalid JWTSourceType = iota 28 Cookie 29 Header 30 Value 31 ) 32 33 var ( 34 _ AccessControl = &JWT{} 35 _ DisablePrivateCaching = &JWT{} 36 ) 37 38 type ( 39 JWTSourceType uint8 40 JWTSource struct { 41 Expr hcl.Expression 42 Name string 43 Type JWTSourceType 44 } 45 ) 46 47 type JWT struct { 48 algorithm acjwt.Algorithm 49 claims hcl.Expression 50 claimsRequired []string 51 disablePrivateCaching bool 52 source JWTSource 53 hmacSecret []byte 54 name string 55 parser *jwt.Parser 56 pubKey interface{} 57 rolesClaim string 58 rolesMap map[string][]string 59 permissionsClaim string 60 permissionsMap map[string][]string 61 jwks *jwk.JWKS 62 } 63 64 type JWTOptions struct { 65 Algorithm string 66 Claims hcl.Expression 67 ClaimsRequired []string 68 DisablePrivateCaching bool 69 Name string // TODO: more generic (validate) 70 RolesClaim string 71 RolesMap map[string][]string 72 PermissionsClaim string 73 PermissionsMap map[string][]string 74 Source JWTSource 75 Key []byte 76 JWKS *jwk.JWKS 77 } 78 79 func NewJWTSource(cookie, header string, value hcl.Expression) JWTSource { 80 c, h := strings.TrimSpace(cookie), strings.TrimSpace(header) 81 82 if value != nil { 83 v, _ := value.Value(nil) 84 if !v.IsNull() { 85 if h != "" || c != "" { 86 return JWTSource{} 87 } 88 89 return JWTSource{ 90 Name: "", 91 Type: Value, 92 Expr: value, 93 } 94 } 95 } 96 if c != "" && h == "" { 97 return JWTSource{ 98 Name: c, 99 Type: Cookie, 100 } 101 } 102 if h != "" && c == "" { 103 return JWTSource{ 104 Name: h, 105 Type: Header, 106 } 107 } 108 if h == "" && c == "" { 109 return JWTSource{ 110 Name: "Authorization", 111 Type: Header, 112 } 113 } 114 return JWTSource{} 115 } 116 117 // NewJWT parses the key and creates Validation obj which can be referenced in related handlers. 118 func NewJWT(options *JWTOptions) (*JWT, error) { 119 jwtAC, err := newJWT(options) 120 if err != nil { 121 return nil, err 122 } 123 124 jwtAC.algorithm = acjwt.NewAlgorithm(options.Algorithm) 125 if jwtAC.algorithm == acjwt.AlgorithmUnknown { 126 return nil, fmt.Errorf("algorithm %q is not supported", options.Algorithm) 127 } 128 129 jwtAC.parser = newParser([]acjwt.Algorithm{jwtAC.algorithm}) 130 131 if jwtAC.algorithm.IsHMAC() { 132 jwtAC.hmacSecret = options.Key 133 return jwtAC, nil 134 } 135 136 pubKey, err := parsePublicPEMKey(options.Key) 137 if err != nil { 138 return nil, err 139 } 140 141 jwtAC.pubKey = pubKey 142 return jwtAC, nil 143 } 144 145 func NewJWTFromJWKS(options *JWTOptions) (*JWT, error) { 146 jwtAC, err := newJWT(options) 147 if err != nil { 148 return nil, err 149 } 150 151 if options.JWKS == nil { 152 return nil, fmt.Errorf("invalid JWKS") 153 } 154 155 algorithms := append(acjwt.RSAAlgorithms, acjwt.ECDSAlgorithms...) 156 jwtAC.parser = newParser(algorithms) 157 jwtAC.jwks = options.JWKS 158 159 return jwtAC, nil 160 } 161 162 func newJWT(options *JWTOptions) (*JWT, error) { 163 if options.Source.Type == Invalid { 164 return nil, fmt.Errorf("token source is invalid") 165 } 166 167 if options.RolesClaim != "" && options.RolesMap == nil { 168 return nil, fmt.Errorf("missing roles_map") 169 } 170 171 jwtAC := &JWT{ 172 claims: options.Claims, 173 claimsRequired: options.ClaimsRequired, 174 disablePrivateCaching: options.DisablePrivateCaching, 175 name: options.Name, 176 rolesClaim: options.RolesClaim, 177 rolesMap: options.RolesMap, 178 permissionsClaim: options.PermissionsClaim, 179 permissionsMap: options.PermissionsMap, 180 source: options.Source, 181 } 182 return jwtAC, nil 183 } 184 185 func (j *JWT) DisablePrivateCaching() bool { 186 return j.disablePrivateCaching 187 } 188 189 // Validate reading the token from configured source and validates against the key. 190 func (j *JWT) Validate(req *http.Request) error { 191 var tokenValue string 192 var err error 193 194 switch j.source.Type { 195 case Cookie: 196 cookie, cerr := req.Cookie(j.source.Name) 197 if cerr != http.ErrNoCookie && cookie != nil { 198 tokenValue = cookie.Value 199 } 200 case Header: 201 if strings.ToLower(j.source.Name) == "authorization" { 202 if tokenValue = req.Header.Get(j.source.Name); tokenValue != "" { 203 if tokenValue, err = getBearer(tokenValue); err != nil { 204 return errors.JwtTokenMissing.With(err) 205 } 206 } 207 } else { 208 tokenValue = req.Header.Get(j.source.Name) 209 } 210 case Value: 211 requestContext := eval.ContextFromRequest(req).HCLContext() 212 value, diags := eval.Value(requestContext, j.source.Expr) 213 if diags != nil { 214 return diags 215 } 216 217 tokenValue = seetie.ValueToString(value) 218 } 219 220 if tokenValue == "" { 221 return errors.JwtTokenMissing.Message("token required") 222 } 223 224 expectedClaims, err := j.getConfiguredClaims(req) 225 if err != nil { 226 return err 227 } 228 229 if j.jwks != nil { 230 // load JWKS if needed 231 j.jwks.Data() 232 } 233 234 tokenClaims := jwt.MapClaims{} 235 _, err = j.parser.ParseWithClaims(tokenValue, tokenClaims, j.getValidationKey) 236 if err != nil { 237 if goerrors.Is(err, jwt.ErrTokenExpired) { 238 return errors.JwtTokenExpired.With(err) 239 } 240 return errors.JwtTokenInvalid.With(err) 241 } 242 243 err = j.validateClaims(tokenClaims, expectedClaims) 244 if err != nil { 245 return errors.JwtTokenInvalid.With(err) 246 } 247 248 ctx := req.Context() 249 acMap, ok := ctx.Value(request.AccessControls).(map[string]interface{}) 250 if !ok { 251 acMap = make(map[string]interface{}) 252 } 253 // treat token claims as map for context 254 acMap[j.name] = map[string]interface{}(tokenClaims) 255 ctx = context.WithValue(ctx, request.AccessControls, acMap) 256 257 log := req.Context().Value(request.LogEntry).(*logrus.Entry).WithContext(req.Context()) 258 jwtGrantedPermissions := j.getGrantedPermissions(tokenClaims, log) 259 260 grantedPermissions, _ := ctx.Value(request.GrantedPermissions).([]string) 261 262 grantedPermissions = append(grantedPermissions, jwtGrantedPermissions...) 263 264 ctx = context.WithValue(ctx, request.GrantedPermissions, grantedPermissions) 265 266 *req = *req.WithContext(ctx) 267 268 return nil 269 } 270 271 func (j *JWT) getValidationKey(token *jwt.Token) (interface{}, error) { 272 if j.jwks != nil { 273 return j.jwks.GetSigKeyForToken(token) 274 } 275 276 switch j.algorithm { 277 case acjwt.AlgorithmRSA256, acjwt.AlgorithmRSA384, acjwt.AlgorithmRSA512: 278 return j.pubKey, nil 279 case acjwt.AlgorithmECDSA256, acjwt.AlgorithmECDSA384, acjwt.AlgorithmECDSA512: 280 return j.pubKey, nil 281 case acjwt.AlgorithmHMAC256, acjwt.AlgorithmHMAC384, acjwt.AlgorithmHMAC512: 282 return j.hmacSecret, nil 283 default: // this error case gets normally caught on configuration level 284 return nil, errors.Configuration.Message("algorithm is not supported") 285 } 286 } 287 288 // getConfiguredClaims evaluates the expected claim values from the configuration, and especially iss and aud 289 func (j *JWT) getConfiguredClaims(req *http.Request) (map[string]interface{}, error) { 290 claims := make(map[string]interface{}) 291 if j.claims != nil { 292 val, verr := eval.Value(eval.ContextFromRequest(req).HCLContext(), j.claims) 293 if verr != nil { 294 return nil, verr 295 } 296 claims = seetie.ValueToMap(val) 297 298 var ok bool 299 if issVal, exists := claims["iss"]; exists { 300 _, ok = issVal.(string) 301 if !ok { 302 return nil, errors.Configuration.Message("invalid value type, string expected (claims / iss)") 303 } 304 } 305 306 if audVal, exists := claims["aud"]; exists { 307 _, ok = audVal.(string) 308 if !ok { 309 return nil, errors.Configuration.Message("invalid value type, string expected (claims / aud)") 310 } 311 } 312 } 313 314 return claims, nil 315 } 316 317 // validateClaims validates the token claims against the list of required claims and the expected claims values 318 func (j *JWT) validateClaims(tokenClaims jwt.MapClaims, expectedClaims map[string]interface{}) error { 319 for _, key := range j.claimsRequired { 320 if _, ok := tokenClaims[key]; !ok { 321 return fmt.Errorf("required claim is missing: " + key) 322 } 323 } 324 325 for k, v := range expectedClaims { 326 val, exist := tokenClaims[k] 327 if !exist { 328 return fmt.Errorf("required claim is missing: " + k) 329 } 330 331 if k == "iss" { 332 if !tokenClaims.VerifyIssuer(v.(string), true) { 333 return errors.JwtTokenInvalid.Message("invalid issuer") 334 } 335 continue 336 } 337 if k == "aud" { 338 if !tokenClaims.VerifyAudience(v.(string), true) { 339 return errors.JwtTokenInvalid.Message("invalid audience") 340 } 341 continue 342 } 343 344 if val != v { 345 return fmt.Errorf("unexpected value for claim %s, got %q, expected %q", k, val, v) 346 } 347 } 348 return nil 349 } 350 351 func (j *JWT) getGrantedPermissions(tokenClaims jwt.MapClaims, log *logrus.Entry) []string { 352 var grantedPermissions []string 353 354 grantedPermissions = j.addPermissionsFromPermissionsClaim(tokenClaims, grantedPermissions, log) 355 356 grantedPermissions = j.addPermissionsFromRoles(tokenClaims, grantedPermissions, log) 357 358 grantedPermissions = j.addMappedPermissions(grantedPermissions, grantedPermissions) 359 360 return grantedPermissions 361 } 362 363 const warnInvalidValueMsg = "invalid %s claim value type, ignoring claim, value %#v" 364 365 func (j *JWT) addPermissionsFromPermissionsClaim(tokenClaims jwt.MapClaims, permissions []string, log *logrus.Entry) []string { 366 if j.permissionsClaim == "" { 367 return permissions 368 } 369 370 permissionsFromClaim, exists := tokenClaims[j.permissionsClaim] 371 if !exists { 372 return permissions 373 } 374 375 // ["foo", "bar"] is stored as []interface{}, not []string, unfortunately 376 permissionsArray, ok := permissionsFromClaim.([]interface{}) 377 if ok { 378 var vals []string 379 for _, v := range permissionsArray { 380 p, ok := v.(string) 381 if !ok { 382 log.Warn(fmt.Sprintf(warnInvalidValueMsg, "permissions", permissionsFromClaim)) 383 return permissions 384 } 385 vals = append(vals, p) 386 } 387 for _, val := range vals { 388 permissions, _ = addPermission(permissions, val) 389 } 390 } else { 391 permissionsString, ok := permissionsFromClaim.(string) 392 if !ok { 393 log.Warn(fmt.Sprintf(warnInvalidValueMsg, "permissions", permissionsFromClaim)) 394 return permissions 395 } 396 for _, p := range strings.Split(permissionsString, " ") { 397 permissions, _ = addPermission(permissions, p) 398 } 399 } 400 return permissions 401 } 402 403 func (j *JWT) getRoleValues(rolesClaimValue interface{}, log *logrus.Entry) []string { 404 var roleValues []string 405 // ["foo", "bar"] is stored as []interface{}, not []string, unfortunately 406 rolesArray, ok := rolesClaimValue.([]interface{}) 407 if ok { 408 var vals []string 409 for _, v := range rolesArray { 410 r, ok := v.(string) 411 if !ok { 412 log.Warn(fmt.Sprintf(warnInvalidValueMsg, "roles", rolesClaimValue)) 413 return roleValues 414 } 415 vals = append(vals, r) 416 } 417 return vals 418 } 419 420 rolesString, ok := rolesClaimValue.(string) 421 if !ok { 422 log.Warn(fmt.Sprintf(warnInvalidValueMsg, "roles", rolesClaimValue)) 423 return roleValues 424 } 425 return strings.Split(rolesString, " ") 426 } 427 428 func (j *JWT) addPermissionsFromRoles(tokenClaims jwt.MapClaims, permissions []string, log *logrus.Entry) []string { 429 if j.rolesClaim == "" || j.rolesMap == nil { 430 return permissions 431 } 432 433 rolesClaimValue, exists := tokenClaims[j.rolesClaim] 434 if !exists { 435 return permissions 436 } 437 438 roleValues := j.getRoleValues(rolesClaimValue, log) 439 for _, r := range roleValues { 440 if perms, exist := j.rolesMap[r]; exist { 441 for _, p := range perms { 442 permissions, _ = addPermission(permissions, p) 443 } 444 } 445 } 446 447 if perms, exist := j.rolesMap["*"]; exist { 448 for _, p := range perms { 449 permissions, _ = addPermission(permissions, p) 450 } 451 } 452 return permissions 453 } 454 455 func (j *JWT) addMappedPermissions(source, target []string) []string { 456 if j.permissionsMap == nil { 457 return target 458 } 459 460 for _, val := range source { 461 mappedValues, exist := j.permissionsMap[val] 462 if !exist { 463 // no mapping for value 464 continue 465 } 466 467 var l []string 468 for _, mv := range mappedValues { 469 var added bool 470 // add value from mapping? 471 target, added = addPermission(target, mv) 472 if !added { 473 continue 474 } 475 l = append(l, mv) 476 } 477 // recursion: call only with values not already in target 478 target = j.addMappedPermissions(l, target) 479 } 480 return target 481 } 482 483 func addPermission(permissions []string, permission string) ([]string, bool) { 484 permission = strings.TrimSpace(permission) 485 if permission == "" { 486 return permissions, false 487 } 488 for _, p := range permissions { 489 if p == permission { 490 return permissions, false 491 } 492 } 493 return append(permissions, permission), true 494 } 495 496 func getBearer(val string) (string, error) { 497 const bearer = "bearer " 498 if strings.HasPrefix(strings.ToLower(val), bearer) { 499 return strings.Trim(val[len(bearer):], " "), nil 500 } 501 return "", fmt.Errorf("bearer required with authorization header") 502 } 503 504 // newParser creates a new parser 505 func newParser(algos []acjwt.Algorithm) *jwt.Parser { 506 var algorithms []string 507 for _, a := range algos { 508 algorithms = append(algorithms, a.String()) 509 } 510 options := []jwt.ParserOption{ 511 jwt.WithValidMethods(algorithms), 512 // no equivalent in new lib 513 // jwt.WithLeeway(time.Second), 514 } 515 516 return jwt.NewParser(options...) 517 } 518 519 // parsePublicPEMKey tries to parse all supported publicKey variations which 520 // must be given in PEM encoded format. 521 func parsePublicPEMKey(key []byte) (pub interface{}, err error) { 522 pemBlock, _ := pem.Decode(key) 523 if pemBlock == nil { 524 return nil, jwt.ErrKeyMustBePEMEncoded 525 } 526 pubKey, pubErr := x509.ParsePKCS1PublicKey(pemBlock.Bytes) 527 if pubErr != nil { 528 pkixKey, pkerr := x509.ParsePKIXPublicKey(pemBlock.Bytes) 529 if pkerr != nil { 530 cert, cerr := x509.ParseCertificate(pemBlock.Bytes) 531 if cerr != nil { 532 return nil, jwt.ErrNotRSAPublicKey 533 } 534 if k, ok := cert.PublicKey.(*rsa.PublicKey); ok { 535 return k, nil 536 } 537 if k, ok := cert.PublicKey.(*ecdsa.PublicKey); ok { 538 return k, nil 539 } 540 541 return nil, fmt.Errorf("invalid RSA/ECDSA public key") 542 } 543 544 if k, ok := pkixKey.(*rsa.PublicKey); ok { 545 return k, nil 546 } 547 548 if k, ok := pkixKey.(*ecdsa.PublicKey); ok { 549 return k, nil 550 } 551 552 return nil, fmt.Errorf("invalid RSA/ECDSA public key") 553 } 554 return pubKey, nil 555 }