github.com/lestrrat-go/jwx/v2@v2.0.21/jwt/validate.go (about) 1 package jwt 2 3 import ( 4 "context" 5 "fmt" 6 "strconv" 7 "time" 8 ) 9 10 type Clock interface { 11 Now() time.Time 12 } 13 type ClockFunc func() time.Time 14 15 func (f ClockFunc) Now() time.Time { 16 return f() 17 } 18 19 func isSupportedTimeClaim(c string) error { 20 switch c { 21 case ExpirationKey, IssuedAtKey, NotBeforeKey: 22 return nil 23 } 24 return NewValidationError(fmt.Errorf(`unsupported time claim %s`, strconv.Quote(c))) 25 } 26 27 func timeClaim(t Token, clock Clock, c string) time.Time { 28 switch c { 29 case ExpirationKey: 30 return t.Expiration() 31 case IssuedAtKey: 32 return t.IssuedAt() 33 case NotBeforeKey: 34 return t.NotBefore() 35 case "": 36 return clock.Now() 37 } 38 return time.Time{} // should *NEVER* reach here, but... 39 } 40 41 // Validate makes sure that the essential claims stand. 42 // 43 // See the various `WithXXX` functions for optional parameters 44 // that can control the behavior of this method. 45 func Validate(t Token, options ...ValidateOption) error { 46 ctx := context.Background() 47 trunc := time.Second 48 49 var clock Clock = ClockFunc(time.Now) 50 var skew time.Duration 51 var validators = []Validator{ 52 IsIssuedAtValid(), 53 IsExpirationValid(), 54 IsNbfValid(), 55 } 56 for _, o := range options { 57 //nolint:forcetypeassert 58 switch o.Ident() { 59 case identClock{}: 60 clock = o.Value().(Clock) 61 case identAcceptableSkew{}: 62 skew = o.Value().(time.Duration) 63 case identTruncation{}: 64 trunc = o.Value().(time.Duration) 65 case identContext{}: 66 ctx = o.Value().(context.Context) 67 case identValidator{}: 68 v := o.Value().(Validator) 69 switch v := v.(type) { 70 case *isInTimeRange: 71 if v.c1 != "" { 72 if err := isSupportedTimeClaim(v.c1); err != nil { 73 return err 74 } 75 validators = append(validators, IsRequired(v.c1)) 76 } 77 if v.c2 != "" { 78 if err := isSupportedTimeClaim(v.c2); err != nil { 79 return err 80 } 81 validators = append(validators, IsRequired(v.c2)) 82 } 83 } 84 validators = append(validators, v) 85 } 86 } 87 88 ctx = SetValidationCtxSkew(ctx, skew) 89 ctx = SetValidationCtxClock(ctx, clock) 90 ctx = SetValidationCtxTruncation(ctx, trunc) 91 for _, v := range validators { 92 if err := v.Validate(ctx, t); err != nil { 93 return err 94 } 95 } 96 97 return nil 98 } 99 100 type isInTimeRange struct { 101 c1 string 102 c2 string 103 dur time.Duration 104 less bool // if true, d =< c1 - c2. otherwise d >= c1 - c2 105 } 106 107 // MaxDeltaIs implements the logic behind `WithMaxDelta()` option 108 func MaxDeltaIs(c1, c2 string, dur time.Duration) Validator { 109 return &isInTimeRange{ 110 c1: c1, 111 c2: c2, 112 dur: dur, 113 less: true, 114 } 115 } 116 117 // MinDeltaIs implements the logic behind `WithMinDelta()` option 118 func MinDeltaIs(c1, c2 string, dur time.Duration) Validator { 119 return &isInTimeRange{ 120 c1: c1, 121 c2: c2, 122 dur: dur, 123 less: false, 124 } 125 } 126 127 func (iitr *isInTimeRange) Validate(ctx context.Context, t Token) ValidationError { 128 clock := ValidationCtxClock(ctx) // MUST be populated 129 skew := ValidationCtxSkew(ctx) // MUST be populated 130 // We don't check if the claims already exist, because we already did that 131 // by piggybacking on `required` check. 132 t1 := timeClaim(t, clock, iitr.c1) 133 t2 := timeClaim(t, clock, iitr.c2) 134 if iitr.less { // t1 - t2 <= iitr.dur 135 // t1 - t2 < iitr.dur + skew 136 if t1.Sub(t2) > iitr.dur+skew { 137 return NewValidationError(fmt.Errorf(`iitr between %s and %s exceeds %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew)) 138 } 139 } else { 140 if t1.Sub(t2) < iitr.dur-skew { 141 return NewValidationError(fmt.Errorf(`iitr between %s and %s is less than %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew)) 142 } 143 } 144 return nil 145 } 146 147 type ValidationError interface { 148 error 149 isValidationError() 150 Unwrap() error 151 } 152 153 func NewValidationError(err error) ValidationError { 154 return &validationError{error: err} 155 } 156 157 // This is a generic validation error. 158 type validationError struct { 159 error 160 } 161 162 func (validationError) isValidationError() {} 163 func (err *validationError) Unwrap() error { 164 return err.error 165 } 166 167 type missingRequiredClaimError struct { 168 claim string 169 } 170 171 func (err *missingRequiredClaimError) Error() string { 172 return fmt.Sprintf("%q not satisfied: required claim not found", err.claim) 173 } 174 175 func (err *missingRequiredClaimError) Is(target error) bool { 176 _, ok := target.(*missingRequiredClaimError) 177 return ok 178 } 179 180 func (err *missingRequiredClaimError) isValidationError() {} 181 func (*missingRequiredClaimError) Unwrap() error { return nil } 182 183 type invalidAudienceError struct { 184 error 185 } 186 187 func (err *invalidAudienceError) Is(target error) bool { 188 _, ok := target.(*invalidAudienceError) 189 return ok 190 } 191 192 func (err *invalidAudienceError) isValidationError() {} 193 func (err *invalidAudienceError) Unwrap() error { 194 return err.error 195 } 196 197 func (err *invalidAudienceError) Error() string { 198 if err.error == nil { 199 return `"aud" not satisfied` 200 } 201 return err.error.Error() 202 } 203 204 type invalidIssuerError struct { 205 error 206 } 207 208 func (err *invalidIssuerError) Is(target error) bool { 209 _, ok := target.(*invalidIssuerError) 210 return ok 211 } 212 213 func (err *invalidIssuerError) isValidationError() {} 214 func (err *invalidIssuerError) Unwrap() error { 215 return err.error 216 } 217 218 func (err *invalidIssuerError) Error() string { 219 if err.error == nil { 220 return `"iss" not satisfied` 221 } 222 return err.error.Error() 223 } 224 225 var errTokenExpired = NewValidationError(fmt.Errorf(`"exp" not satisfied`)) 226 var errInvalidIssuedAt = NewValidationError(fmt.Errorf(`"iat" not satisfied`)) 227 var errTokenNotYetValid = NewValidationError(fmt.Errorf(`"nbf" not satisfied`)) 228 var errInvalidAudience = &invalidAudienceError{} 229 var errInvalidIssuer = &invalidIssuerError{} 230 var errRequiredClaim = &missingRequiredClaimError{} 231 232 // ErrTokenExpired returns the immutable error used when `exp` claim 233 // is not satisfied. 234 // 235 // The return value should only be used for comparison using `errors.Is()` 236 func ErrTokenExpired() ValidationError { 237 return errTokenExpired 238 } 239 240 // ErrInvalidIssuedAt returns the immutable error used when `iat` claim 241 // is not satisfied 242 // 243 // The return value should only be used for comparison using `errors.Is()` 244 func ErrInvalidIssuedAt() ValidationError { 245 return errInvalidIssuedAt 246 } 247 248 // ErrTokenNotYetValid returns the immutable error used when `nbf` claim 249 // is not satisfied 250 // 251 // The return value should only be used for comparison using `errors.Is()` 252 func ErrTokenNotYetValid() ValidationError { 253 return errTokenNotYetValid 254 } 255 256 // ErrInvalidAudience returns the immutable error used when `aud` claim 257 // is not satisfied 258 // 259 // The return value should only be used for comparison using `errors.Is()` 260 func ErrInvalidAudience() ValidationError { 261 return errInvalidAudience 262 } 263 264 // ErrInvalidIssuer returns the immutable error used when `iss` claim 265 // is not satisfied 266 // 267 // The return value should only be used for comparison using `errors.Is()` 268 func ErrInvalidIssuer() ValidationError { 269 return errInvalidIssuer 270 } 271 272 // ErrMissingRequiredClaim should not have been exported, and will be 273 // removed in a future release. Use `ErrRequiredClaim()` instead to get 274 // an error to be used in `errors.Is()` 275 // 276 // This function should not have been implemented as a constructor. 277 // but rather a means to retrieve an opaque and immutable error value 278 // that could be passed to `errors.Is()`. 279 func ErrMissingRequiredClaim(name string) ValidationError { 280 return &missingRequiredClaimError{claim: name} 281 } 282 283 // ErrRequiredClaim returns the immutable error used when the claim 284 // specified by `jwt.IsRequired()` is not present. 285 // 286 // The return value should only be used for comparison using `errors.Is()` 287 func ErrRequiredClaim() ValidationError { 288 return errRequiredClaim 289 } 290 291 // Validator describes interface to validate a Token. 292 type Validator interface { 293 // Validate should return an error if a required conditions is not met. 294 Validate(context.Context, Token) ValidationError 295 } 296 297 // ValidatorFunc is a type of Validator that does not have any 298 // state, that is implemented as a function 299 type ValidatorFunc func(context.Context, Token) ValidationError 300 301 func (vf ValidatorFunc) Validate(ctx context.Context, tok Token) ValidationError { 302 return vf(ctx, tok) 303 } 304 305 type identValidationCtxClock struct{} 306 type identValidationCtxSkew struct{} 307 type identValidationCtxTruncation struct{} 308 309 func SetValidationCtxClock(ctx context.Context, cl Clock) context.Context { 310 return context.WithValue(ctx, identValidationCtxClock{}, cl) 311 } 312 313 func SetValidationCtxTruncation(ctx context.Context, dur time.Duration) context.Context { 314 return context.WithValue(ctx, identValidationCtxTruncation{}, dur) 315 } 316 317 func SetValidationCtxSkew(ctx context.Context, dur time.Duration) context.Context { 318 return context.WithValue(ctx, identValidationCtxSkew{}, dur) 319 } 320 321 // ValidationCtxClock returns the Clock object associated with 322 // the current validation context. This value will always be available 323 // during validation of tokens. 324 func ValidationCtxClock(ctx context.Context) Clock { 325 //nolint:forcetypeassert 326 return ctx.Value(identValidationCtxClock{}).(Clock) 327 } 328 329 func ValidationCtxSkew(ctx context.Context) time.Duration { 330 //nolint:forcetypeassert 331 return ctx.Value(identValidationCtxSkew{}).(time.Duration) 332 } 333 334 func ValidationCtxTruncation(ctx context.Context) time.Duration { 335 //nolint:forcetypeassert 336 return ctx.Value(identValidationCtxTruncation{}).(time.Duration) 337 } 338 339 // IsExpirationValid is one of the default validators that will be executed. 340 // It does not need to be specified by users, but it exists as an 341 // exported field so that you can check what it does. 342 // 343 // The supplied context.Context object must have the "clock" and "skew" 344 // populated with appropriate values using SetValidationCtxClock() and 345 // SetValidationCtxSkew() 346 func IsExpirationValid() Validator { 347 return ValidatorFunc(isExpirationValid) 348 } 349 350 func isExpirationValid(ctx context.Context, t Token) ValidationError { 351 tv := t.Expiration() 352 if tv.IsZero() || tv.Unix() == 0 { 353 return nil 354 } 355 356 clock := ValidationCtxClock(ctx) // MUST be populated 357 skew := ValidationCtxSkew(ctx) // MUST be populated 358 trunc := ValidationCtxTruncation(ctx) // MUST be populated 359 360 now := clock.Now().Truncate(trunc) 361 ttv := tv.Truncate(trunc) 362 363 // expiration date must be after NOW 364 if !now.Before(ttv.Add(skew)) { 365 return ErrTokenExpired() 366 } 367 return nil 368 } 369 370 // IsIssuedAtValid is one of the default validators that will be executed. 371 // It does not need to be specified by users, but it exists as an 372 // exported field so that you can check what it does. 373 // 374 // The supplied context.Context object must have the "clock" and "skew" 375 // populated with appropriate values using SetValidationCtxClock() and 376 // SetValidationCtxSkew() 377 func IsIssuedAtValid() Validator { 378 return ValidatorFunc(isIssuedAtValid) 379 } 380 381 func isIssuedAtValid(ctx context.Context, t Token) ValidationError { 382 tv := t.IssuedAt() 383 if tv.IsZero() || tv.Unix() == 0 { 384 return nil 385 } 386 387 clock := ValidationCtxClock(ctx) // MUST be populated 388 skew := ValidationCtxSkew(ctx) // MUST be populated 389 trunc := ValidationCtxTruncation(ctx) // MUST be populated 390 391 now := clock.Now().Truncate(trunc) 392 ttv := tv.Truncate(trunc) 393 394 if now.Before(ttv.Add(-1 * skew)) { 395 return ErrInvalidIssuedAt() 396 } 397 return nil 398 } 399 400 // IsNbfValid is one of the default validators that will be executed. 401 // It does not need to be specified by users, but it exists as an 402 // exported field so that you can check what it does. 403 // 404 // The supplied context.Context object must have the "clock" and "skew" 405 // populated with appropriate values using SetValidationCtxClock() and 406 // SetValidationCtxSkew() 407 func IsNbfValid() Validator { 408 return ValidatorFunc(isNbfValid) 409 } 410 411 func isNbfValid(ctx context.Context, t Token) ValidationError { 412 tv := t.NotBefore() 413 if tv.IsZero() || tv.Unix() == 0 { 414 return nil 415 } 416 417 clock := ValidationCtxClock(ctx) // MUST be populated 418 skew := ValidationCtxSkew(ctx) // MUST be populated 419 trunc := ValidationCtxTruncation(ctx) // MUST be populated 420 421 // Truncation always happens even for trunc = 0 because 422 // we also use this to strip monotonic clocks 423 now := clock.Now().Truncate(trunc) 424 ttv := tv.Truncate(trunc) 425 426 // "now" cannot be before t - skew, so we check for now > t - skew 427 ttv = ttv.Add(-1 * skew) 428 if now.Before(ttv) { 429 return ErrTokenNotYetValid() 430 } 431 return nil 432 } 433 434 type claimContainsString struct { 435 name string 436 value string 437 makeErr func(error) ValidationError 438 } 439 440 // ClaimContainsString can be used to check if the claim called `name`, which is 441 // expected to be a list of strings, contains `value`. Currently because of the 442 // implementation this will probably only work for `aud` fields. 443 func ClaimContainsString(name, value string) Validator { 444 return claimContainsString{ 445 name: name, 446 value: value, 447 makeErr: NewValidationError, 448 } 449 } 450 451 // IsValidationError returns true if the error is a validation error 452 func IsValidationError(err error) bool { 453 switch err { 454 case errTokenExpired, errTokenNotYetValid, errInvalidIssuedAt: 455 return true 456 default: 457 switch err.(type) { 458 case *validationError, *invalidAudienceError, *invalidIssuerError, *missingRequiredClaimError: 459 return true 460 default: 461 return false 462 } 463 } 464 } 465 466 func (ccs claimContainsString) Validate(_ context.Context, t Token) ValidationError { 467 v, ok := t.Get(ccs.name) 468 if !ok { 469 return ccs.makeErr(fmt.Errorf(`claim %q not found`, ccs.name)) 470 } 471 472 list, ok := v.([]string) 473 if !ok { 474 return ccs.makeErr(fmt.Errorf(`claim %q must be a []string (got %T)`, ccs.name, v)) 475 } 476 477 for _, v := range list { 478 if v == ccs.value { 479 return nil 480 } 481 } 482 return ccs.makeErr(fmt.Errorf(`%q not satisfied`, ccs.name)) 483 } 484 485 func makeInvalidAudienceError(err error) ValidationError { 486 return &invalidAudienceError{error: err} 487 } 488 489 // audienceClaimContainsString can be used to check if the audience claim, which is 490 // expected to be a list of strings, contains `value`. 491 func audienceClaimContainsString(value string) Validator { 492 return claimContainsString{ 493 name: AudienceKey, 494 value: value, 495 makeErr: makeInvalidAudienceError, 496 } 497 } 498 499 type claimValueIs struct { 500 name string 501 value interface{} 502 makeErr func(error) ValidationError 503 } 504 505 // ClaimValueIs creates a Validator that checks if the value of claim `name` 506 // matches `value`. The comparison is done using a simple `==` comparison, 507 // and therefore complex comparisons may fail using this code. If you 508 // need to do more, use a custom Validator. 509 func ClaimValueIs(name string, value interface{}) Validator { 510 return &claimValueIs{ 511 name: name, 512 value: value, 513 makeErr: NewValidationError, 514 } 515 } 516 517 func (cv *claimValueIs) Validate(_ context.Context, t Token) ValidationError { 518 v, ok := t.Get(cv.name) 519 if !ok { 520 return cv.makeErr(fmt.Errorf(`%q not satisfied: claim %q does not exist`, cv.name, cv.name)) 521 } 522 if v != cv.value { 523 return cv.makeErr(fmt.Errorf(`%q not satisfied: values do not match`, cv.name)) 524 } 525 return nil 526 } 527 528 func makeIssuerClaimError(err error) ValidationError { 529 return &invalidIssuerError{error: err} 530 } 531 532 // issuerClaimValueIs creates a Validator that checks if the issuer claim 533 // matches `value`. 534 func issuerClaimValueIs(value string) Validator { 535 return &claimValueIs{ 536 name: IssuerKey, 537 value: value, 538 makeErr: makeIssuerClaimError, 539 } 540 } 541 542 // IsRequired creates a Validator that checks if the required claim `name` 543 // exists in the token 544 func IsRequired(name string) Validator { 545 return isRequired(name) 546 } 547 548 type isRequired string 549 550 func (ir isRequired) Validate(_ context.Context, t Token) ValidationError { 551 name := string(ir) 552 _, ok := t.Get(name) 553 if !ok { 554 return &missingRequiredClaimError{claim: name} 555 } 556 return nil 557 }