github.com/lestrrat-go/jwx/v2@v2.0.21/jwt/validate.go (about)

     1  package jwt
     2  
import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"time"
)
     9  
    10  type Clock interface {
    11  	Now() time.Time
    12  }
    13  type ClockFunc func() time.Time
    14  
    15  func (f ClockFunc) Now() time.Time {
    16  	return f()
    17  }
    18  
    19  func isSupportedTimeClaim(c string) error {
    20  	switch c {
    21  	case ExpirationKey, IssuedAtKey, NotBeforeKey:
    22  		return nil
    23  	}
    24  	return NewValidationError(fmt.Errorf(`unsupported time claim %s`, strconv.Quote(c)))
    25  }
    26  
    27  func timeClaim(t Token, clock Clock, c string) time.Time {
    28  	switch c {
    29  	case ExpirationKey:
    30  		return t.Expiration()
    31  	case IssuedAtKey:
    32  		return t.IssuedAt()
    33  	case NotBeforeKey:
    34  		return t.NotBefore()
    35  	case "":
    36  		return clock.Now()
    37  	}
    38  	return time.Time{} // should *NEVER* reach here, but...
    39  }
    40  
    41  // Validate makes sure that the essential claims stand.
    42  //
    43  // See the various `WithXXX` functions for optional parameters
    44  // that can control the behavior of this method.
    45  func Validate(t Token, options ...ValidateOption) error {
    46  	ctx := context.Background()
    47  	trunc := time.Second
    48  
    49  	var clock Clock = ClockFunc(time.Now)
    50  	var skew time.Duration
    51  	var validators = []Validator{
    52  		IsIssuedAtValid(),
    53  		IsExpirationValid(),
    54  		IsNbfValid(),
    55  	}
    56  	for _, o := range options {
    57  		//nolint:forcetypeassert
    58  		switch o.Ident() {
    59  		case identClock{}:
    60  			clock = o.Value().(Clock)
    61  		case identAcceptableSkew{}:
    62  			skew = o.Value().(time.Duration)
    63  		case identTruncation{}:
    64  			trunc = o.Value().(time.Duration)
    65  		case identContext{}:
    66  			ctx = o.Value().(context.Context)
    67  		case identValidator{}:
    68  			v := o.Value().(Validator)
    69  			switch v := v.(type) {
    70  			case *isInTimeRange:
    71  				if v.c1 != "" {
    72  					if err := isSupportedTimeClaim(v.c1); err != nil {
    73  						return err
    74  					}
    75  					validators = append(validators, IsRequired(v.c1))
    76  				}
    77  				if v.c2 != "" {
    78  					if err := isSupportedTimeClaim(v.c2); err != nil {
    79  						return err
    80  					}
    81  					validators = append(validators, IsRequired(v.c2))
    82  				}
    83  			}
    84  			validators = append(validators, v)
    85  		}
    86  	}
    87  
    88  	ctx = SetValidationCtxSkew(ctx, skew)
    89  	ctx = SetValidationCtxClock(ctx, clock)
    90  	ctx = SetValidationCtxTruncation(ctx, trunc)
    91  	for _, v := range validators {
    92  		if err := v.Validate(ctx, t); err != nil {
    93  			return err
    94  		}
    95  	}
    96  
    97  	return nil
    98  }
    99  
   100  type isInTimeRange struct {
   101  	c1   string
   102  	c2   string
   103  	dur  time.Duration
   104  	less bool // if true, d =< c1 - c2. otherwise d >= c1 - c2
   105  }
   106  
   107  // MaxDeltaIs implements the logic behind `WithMaxDelta()` option
   108  func MaxDeltaIs(c1, c2 string, dur time.Duration) Validator {
   109  	return &isInTimeRange{
   110  		c1:   c1,
   111  		c2:   c2,
   112  		dur:  dur,
   113  		less: true,
   114  	}
   115  }
   116  
   117  // MinDeltaIs implements the logic behind `WithMinDelta()` option
   118  func MinDeltaIs(c1, c2 string, dur time.Duration) Validator {
   119  	return &isInTimeRange{
   120  		c1:   c1,
   121  		c2:   c2,
   122  		dur:  dur,
   123  		less: false,
   124  	}
   125  }
   126  
   127  func (iitr *isInTimeRange) Validate(ctx context.Context, t Token) ValidationError {
   128  	clock := ValidationCtxClock(ctx) // MUST be populated
   129  	skew := ValidationCtxSkew(ctx)   // MUST be populated
   130  	// We don't check if the claims already exist, because we already did that
   131  	// by piggybacking on `required` check.
   132  	t1 := timeClaim(t, clock, iitr.c1)
   133  	t2 := timeClaim(t, clock, iitr.c2)
   134  	if iitr.less { // t1 - t2 <= iitr.dur
   135  		// t1 - t2 < iitr.dur + skew
   136  		if t1.Sub(t2) > iitr.dur+skew {
   137  			return NewValidationError(fmt.Errorf(`iitr between %s and %s exceeds %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
   138  		}
   139  	} else {
   140  		if t1.Sub(t2) < iitr.dur-skew {
   141  			return NewValidationError(fmt.Errorf(`iitr between %s and %s is less than %s (skew %s)`, iitr.c1, iitr.c2, iitr.dur, skew))
   142  		}
   143  	}
   144  	return nil
   145  }
   146  
   147  type ValidationError interface {
   148  	error
   149  	isValidationError()
   150  	Unwrap() error
   151  }
   152  
   153  func NewValidationError(err error) ValidationError {
   154  	return &validationError{error: err}
   155  }
   156  
   157  // This is a generic validation error.
   158  type validationError struct {
   159  	error
   160  }
   161  
   162  func (validationError) isValidationError() {}
   163  func (err *validationError) Unwrap() error {
   164  	return err.error
   165  }
   166  
   167  type missingRequiredClaimError struct {
   168  	claim string
   169  }
   170  
   171  func (err *missingRequiredClaimError) Error() string {
   172  	return fmt.Sprintf("%q not satisfied: required claim not found", err.claim)
   173  }
   174  
   175  func (err *missingRequiredClaimError) Is(target error) bool {
   176  	_, ok := target.(*missingRequiredClaimError)
   177  	return ok
   178  }
   179  
   180  func (err *missingRequiredClaimError) isValidationError() {}
   181  func (*missingRequiredClaimError) Unwrap() error          { return nil }
   182  
   183  type invalidAudienceError struct {
   184  	error
   185  }
   186  
   187  func (err *invalidAudienceError) Is(target error) bool {
   188  	_, ok := target.(*invalidAudienceError)
   189  	return ok
   190  }
   191  
   192  func (err *invalidAudienceError) isValidationError() {}
   193  func (err *invalidAudienceError) Unwrap() error {
   194  	return err.error
   195  }
   196  
   197  func (err *invalidAudienceError) Error() string {
   198  	if err.error == nil {
   199  		return `"aud" not satisfied`
   200  	}
   201  	return err.error.Error()
   202  }
   203  
   204  type invalidIssuerError struct {
   205  	error
   206  }
   207  
   208  func (err *invalidIssuerError) Is(target error) bool {
   209  	_, ok := target.(*invalidIssuerError)
   210  	return ok
   211  }
   212  
   213  func (err *invalidIssuerError) isValidationError() {}
   214  func (err *invalidIssuerError) Unwrap() error {
   215  	return err.error
   216  }
   217  
   218  func (err *invalidIssuerError) Error() string {
   219  	if err.error == nil {
   220  		return `"iss" not satisfied`
   221  	}
   222  	return err.error.Error()
   223  }
   224  
   225  var errTokenExpired = NewValidationError(fmt.Errorf(`"exp" not satisfied`))
   226  var errInvalidIssuedAt = NewValidationError(fmt.Errorf(`"iat" not satisfied`))
   227  var errTokenNotYetValid = NewValidationError(fmt.Errorf(`"nbf" not satisfied`))
   228  var errInvalidAudience = &invalidAudienceError{}
   229  var errInvalidIssuer = &invalidIssuerError{}
   230  var errRequiredClaim = &missingRequiredClaimError{}
   231  
   232  // ErrTokenExpired returns the immutable error used when `exp` claim
   233  // is not satisfied.
   234  //
   235  // The return value should only be used for comparison using `errors.Is()`
   236  func ErrTokenExpired() ValidationError {
   237  	return errTokenExpired
   238  }
   239  
   240  // ErrInvalidIssuedAt returns the immutable error used when `iat` claim
   241  // is not satisfied
   242  //
   243  // The return value should only be used for comparison using `errors.Is()`
   244  func ErrInvalidIssuedAt() ValidationError {
   245  	return errInvalidIssuedAt
   246  }
   247  
   248  // ErrTokenNotYetValid returns the immutable error used when `nbf` claim
   249  // is not satisfied
   250  //
   251  // The return value should only be used for comparison using `errors.Is()`
   252  func ErrTokenNotYetValid() ValidationError {
   253  	return errTokenNotYetValid
   254  }
   255  
   256  // ErrInvalidAudience returns the immutable error used when `aud` claim
   257  // is not satisfied
   258  //
   259  // The return value should only be used for comparison using `errors.Is()`
   260  func ErrInvalidAudience() ValidationError {
   261  	return errInvalidAudience
   262  }
   263  
   264  // ErrInvalidIssuer returns the immutable error used when `iss` claim
   265  // is not satisfied
   266  //
   267  // The return value should only be used for comparison using `errors.Is()`
   268  func ErrInvalidIssuer() ValidationError {
   269  	return errInvalidIssuer
   270  }
   271  
   272  // ErrMissingRequiredClaim should not have been exported, and will be
   273  // removed in a future release. Use `ErrRequiredClaim()` instead to get
   274  // an error to be used in `errors.Is()`
   275  //
   276  // This function should not have been implemented as a constructor.
   277  // but rather a means to retrieve an opaque and immutable error value
   278  // that could be passed to `errors.Is()`.
   279  func ErrMissingRequiredClaim(name string) ValidationError {
   280  	return &missingRequiredClaimError{claim: name}
   281  }
   282  
   283  // ErrRequiredClaim returns the immutable error used when the claim
   284  // specified by `jwt.IsRequired()` is not present.
   285  //
   286  // The return value should only be used for comparison using `errors.Is()`
   287  func ErrRequiredClaim() ValidationError {
   288  	return errRequiredClaim
   289  }
   290  
   291  // Validator describes interface to validate a Token.
   292  type Validator interface {
   293  	// Validate should return an error if a required conditions is not met.
   294  	Validate(context.Context, Token) ValidationError
   295  }
   296  
   297  // ValidatorFunc is a type of Validator that does not have any
   298  // state, that is implemented as a function
   299  type ValidatorFunc func(context.Context, Token) ValidationError
   300  
   301  func (vf ValidatorFunc) Validate(ctx context.Context, tok Token) ValidationError {
   302  	return vf(ctx, tok)
   303  }
   304  
   305  type identValidationCtxClock struct{}
   306  type identValidationCtxSkew struct{}
   307  type identValidationCtxTruncation struct{}
   308  
   309  func SetValidationCtxClock(ctx context.Context, cl Clock) context.Context {
   310  	return context.WithValue(ctx, identValidationCtxClock{}, cl)
   311  }
   312  
   313  func SetValidationCtxTruncation(ctx context.Context, dur time.Duration) context.Context {
   314  	return context.WithValue(ctx, identValidationCtxTruncation{}, dur)
   315  }
   316  
   317  func SetValidationCtxSkew(ctx context.Context, dur time.Duration) context.Context {
   318  	return context.WithValue(ctx, identValidationCtxSkew{}, dur)
   319  }
   320  
   321  // ValidationCtxClock returns the Clock object associated with
   322  // the current validation context. This value will always be available
   323  // during validation of tokens.
   324  func ValidationCtxClock(ctx context.Context) Clock {
   325  	//nolint:forcetypeassert
   326  	return ctx.Value(identValidationCtxClock{}).(Clock)
   327  }
   328  
   329  func ValidationCtxSkew(ctx context.Context) time.Duration {
   330  	//nolint:forcetypeassert
   331  	return ctx.Value(identValidationCtxSkew{}).(time.Duration)
   332  }
   333  
   334  func ValidationCtxTruncation(ctx context.Context) time.Duration {
   335  	//nolint:forcetypeassert
   336  	return ctx.Value(identValidationCtxTruncation{}).(time.Duration)
   337  }
   338  
   339  // IsExpirationValid is one of the default validators that will be executed.
   340  // It does not need to be specified by users, but it exists as an
   341  // exported field so that you can check what it does.
   342  //
   343  // The supplied context.Context object must have the "clock" and "skew"
   344  // populated with appropriate values using SetValidationCtxClock() and
   345  // SetValidationCtxSkew()
   346  func IsExpirationValid() Validator {
   347  	return ValidatorFunc(isExpirationValid)
   348  }
   349  
   350  func isExpirationValid(ctx context.Context, t Token) ValidationError {
   351  	tv := t.Expiration()
   352  	if tv.IsZero() || tv.Unix() == 0 {
   353  		return nil
   354  	}
   355  
   356  	clock := ValidationCtxClock(ctx)      // MUST be populated
   357  	skew := ValidationCtxSkew(ctx)        // MUST be populated
   358  	trunc := ValidationCtxTruncation(ctx) // MUST be populated
   359  
   360  	now := clock.Now().Truncate(trunc)
   361  	ttv := tv.Truncate(trunc)
   362  
   363  	// expiration date must be after NOW
   364  	if !now.Before(ttv.Add(skew)) {
   365  		return ErrTokenExpired()
   366  	}
   367  	return nil
   368  }
   369  
   370  // IsIssuedAtValid is one of the default validators that will be executed.
   371  // It does not need to be specified by users, but it exists as an
   372  // exported field so that you can check what it does.
   373  //
   374  // The supplied context.Context object must have the "clock" and "skew"
   375  // populated with appropriate values using SetValidationCtxClock() and
   376  // SetValidationCtxSkew()
   377  func IsIssuedAtValid() Validator {
   378  	return ValidatorFunc(isIssuedAtValid)
   379  }
   380  
   381  func isIssuedAtValid(ctx context.Context, t Token) ValidationError {
   382  	tv := t.IssuedAt()
   383  	if tv.IsZero() || tv.Unix() == 0 {
   384  		return nil
   385  	}
   386  
   387  	clock := ValidationCtxClock(ctx)      // MUST be populated
   388  	skew := ValidationCtxSkew(ctx)        // MUST be populated
   389  	trunc := ValidationCtxTruncation(ctx) // MUST be populated
   390  
   391  	now := clock.Now().Truncate(trunc)
   392  	ttv := tv.Truncate(trunc)
   393  
   394  	if now.Before(ttv.Add(-1 * skew)) {
   395  		return ErrInvalidIssuedAt()
   396  	}
   397  	return nil
   398  }
   399  
   400  // IsNbfValid is one of the default validators that will be executed.
   401  // It does not need to be specified by users, but it exists as an
   402  // exported field so that you can check what it does.
   403  //
   404  // The supplied context.Context object must have the "clock" and "skew"
   405  // populated with appropriate values using SetValidationCtxClock() and
   406  // SetValidationCtxSkew()
   407  func IsNbfValid() Validator {
   408  	return ValidatorFunc(isNbfValid)
   409  }
   410  
   411  func isNbfValid(ctx context.Context, t Token) ValidationError {
   412  	tv := t.NotBefore()
   413  	if tv.IsZero() || tv.Unix() == 0 {
   414  		return nil
   415  	}
   416  
   417  	clock := ValidationCtxClock(ctx)      // MUST be populated
   418  	skew := ValidationCtxSkew(ctx)        // MUST be populated
   419  	trunc := ValidationCtxTruncation(ctx) // MUST be populated
   420  
   421  	// Truncation always happens even for trunc = 0 because
   422  	// we also use this to strip monotonic clocks
   423  	now := clock.Now().Truncate(trunc)
   424  	ttv := tv.Truncate(trunc)
   425  
   426  	// "now" cannot be before t - skew, so we check for now > t - skew
   427  	ttv = ttv.Add(-1 * skew)
   428  	if now.Before(ttv) {
   429  		return ErrTokenNotYetValid()
   430  	}
   431  	return nil
   432  }
   433  
   434  type claimContainsString struct {
   435  	name    string
   436  	value   string
   437  	makeErr func(error) ValidationError
   438  }
   439  
   440  // ClaimContainsString can be used to check if the claim called `name`, which is
   441  // expected to be a list of strings, contains `value`. Currently because of the
   442  // implementation this will probably only work for `aud` fields.
   443  func ClaimContainsString(name, value string) Validator {
   444  	return claimContainsString{
   445  		name:    name,
   446  		value:   value,
   447  		makeErr: NewValidationError,
   448  	}
   449  }
   450  
   451  // IsValidationError returns true if the error is a validation error
   452  func IsValidationError(err error) bool {
   453  	switch err {
   454  	case errTokenExpired, errTokenNotYetValid, errInvalidIssuedAt:
   455  		return true
   456  	default:
   457  		switch err.(type) {
   458  		case *validationError, *invalidAudienceError, *invalidIssuerError, *missingRequiredClaimError:
   459  			return true
   460  		default:
   461  			return false
   462  		}
   463  	}
   464  }
   465  
   466  func (ccs claimContainsString) Validate(_ context.Context, t Token) ValidationError {
   467  	v, ok := t.Get(ccs.name)
   468  	if !ok {
   469  		return ccs.makeErr(fmt.Errorf(`claim %q not found`, ccs.name))
   470  	}
   471  
   472  	list, ok := v.([]string)
   473  	if !ok {
   474  		return ccs.makeErr(fmt.Errorf(`claim %q must be a []string (got %T)`, ccs.name, v))
   475  	}
   476  
   477  	for _, v := range list {
   478  		if v == ccs.value {
   479  			return nil
   480  		}
   481  	}
   482  	return ccs.makeErr(fmt.Errorf(`%q not satisfied`, ccs.name))
   483  }
   484  
   485  func makeInvalidAudienceError(err error) ValidationError {
   486  	return &invalidAudienceError{error: err}
   487  }
   488  
   489  // audienceClaimContainsString can be used to check if the audience claim, which is
   490  // expected to be a list of strings, contains `value`.
   491  func audienceClaimContainsString(value string) Validator {
   492  	return claimContainsString{
   493  		name:    AudienceKey,
   494  		value:   value,
   495  		makeErr: makeInvalidAudienceError,
   496  	}
   497  }
   498  
   499  type claimValueIs struct {
   500  	name    string
   501  	value   interface{}
   502  	makeErr func(error) ValidationError
   503  }
   504  
   505  // ClaimValueIs creates a Validator that checks if the value of claim `name`
   506  // matches `value`. The comparison is done using a simple `==` comparison,
   507  // and therefore complex comparisons may fail using this code. If you
   508  // need to do more, use a custom Validator.
   509  func ClaimValueIs(name string, value interface{}) Validator {
   510  	return &claimValueIs{
   511  		name:    name,
   512  		value:   value,
   513  		makeErr: NewValidationError,
   514  	}
   515  }
   516  
   517  func (cv *claimValueIs) Validate(_ context.Context, t Token) ValidationError {
   518  	v, ok := t.Get(cv.name)
   519  	if !ok {
   520  		return cv.makeErr(fmt.Errorf(`%q not satisfied: claim %q does not exist`, cv.name, cv.name))
   521  	}
   522  	if v != cv.value {
   523  		return cv.makeErr(fmt.Errorf(`%q not satisfied: values do not match`, cv.name))
   524  	}
   525  	return nil
   526  }
   527  
   528  func makeIssuerClaimError(err error) ValidationError {
   529  	return &invalidIssuerError{error: err}
   530  }
   531  
   532  // issuerClaimValueIs creates a Validator that checks if the issuer claim
   533  // matches `value`.
   534  func issuerClaimValueIs(value string) Validator {
   535  	return &claimValueIs{
   536  		name:    IssuerKey,
   537  		value:   value,
   538  		makeErr: makeIssuerClaimError,
   539  	}
   540  }
   541  
   542  // IsRequired creates a Validator that checks if the required claim `name`
   543  // exists in the token
   544  func IsRequired(name string) Validator {
   545  	return isRequired(name)
   546  }
   547  
   548  type isRequired string
   549  
   550  func (ir isRequired) Validate(_ context.Context, t Token) ValidationError {
   551  	name := string(ir)
   552  	_, ok := t.Get(name)
   553  	if !ok {
   554  		return &missingRequiredClaimError{claim: name}
   555  	}
   556  	return nil
   557  }