github.com/lestrrat-go/jwx/v2@v2.0.21/jwt/jwt.go (about)

     1  //go:generate ../tools/cmd/genjwt.sh
     2  //go:generate stringer -type=TokenOption -output=token_options_gen.go
     3  
     4  // Package jwt implements JSON Web Tokens as described in https://tools.ietf.org/html/rfc7519
     5  package jwt
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"sync/atomic"
    13  
    14  	"github.com/lestrrat-go/jwx/v2"
    15  	"github.com/lestrrat-go/jwx/v2/internal/json"
    16  	"github.com/lestrrat-go/jwx/v2/jws"
    17  	"github.com/lestrrat-go/jwx/v2/jwt/internal/types"
    18  )
    19  
    20  var compactOnly uint32
    21  var errInvalidJWT = errors.New(`invalid JWT`)
    22  
    23  // ErrInvalidJWT returns the opaque error value that is returned when
    24  // `jwt.Parse` fails due to not being able to deduce the format of
    25  // the incoming buffer
    26  func ErrInvalidJWT() error {
    27  	return errInvalidJWT
    28  }
    29  
    30  // Settings controls global settings that are specific to JWTs.
    31  func Settings(options ...GlobalOption) {
    32  	var flattenAudience bool
    33  	var compactOnlyBool bool
    34  	var parsePedantic bool
    35  	var parsePrecision = types.MaxPrecision + 1  // illegal value, so we can detect nothing was set
    36  	var formatPrecision = types.MaxPrecision + 1 // illegal value, so we can detect nothing was set
    37  
    38  	//nolint:forcetypeassert
    39  	for _, option := range options {
    40  		switch option.Ident() {
    41  		case identFlattenAudience{}:
    42  			flattenAudience = option.Value().(bool)
    43  		case identCompactOnly{}:
    44  			compactOnlyBool = option.Value().(bool)
    45  		case identNumericDateParsePedantic{}:
    46  			parsePedantic = option.Value().(bool)
    47  		case identNumericDateParsePrecision{}:
    48  			v := option.Value().(int)
    49  			// only accept this value if it's in our desired range
    50  			if v >= 0 && v <= int(types.MaxPrecision) {
    51  				parsePrecision = uint32(v)
    52  			}
    53  		case identNumericDateFormatPrecision{}:
    54  			v := option.Value().(int)
    55  			// only accept this value if it's in our desired range
    56  			if v >= 0 && v <= int(types.MaxPrecision) {
    57  				formatPrecision = uint32(v)
    58  			}
    59  		}
    60  	}
    61  
    62  	if parsePrecision <= types.MaxPrecision { // remember we set default to max + 1
    63  		v := atomic.LoadUint32(&types.ParsePrecision)
    64  		if v != parsePrecision {
    65  			atomic.CompareAndSwapUint32(&types.ParsePrecision, v, parsePrecision)
    66  		}
    67  	}
    68  
    69  	if formatPrecision <= types.MaxPrecision { // remember we set default to max + 1
    70  		v := atomic.LoadUint32(&types.FormatPrecision)
    71  		if v != formatPrecision {
    72  			atomic.CompareAndSwapUint32(&types.FormatPrecision, v, formatPrecision)
    73  		}
    74  	}
    75  
    76  	{
    77  		v := atomic.LoadUint32(&types.Pedantic)
    78  		if (v == 1) != parsePedantic {
    79  			var newVal uint32
    80  			if parsePedantic {
    81  				newVal = 1
    82  			}
    83  			atomic.CompareAndSwapUint32(&types.Pedantic, v, newVal)
    84  		}
    85  	}
    86  
    87  	{
    88  		v := atomic.LoadUint32(&compactOnly)
    89  		if (v == 1) != compactOnlyBool {
    90  			var newVal uint32
    91  			if compactOnlyBool {
    92  				newVal = 1
    93  			}
    94  			atomic.CompareAndSwapUint32(&compactOnly, v, newVal)
    95  		}
    96  	}
    97  
    98  	{
    99  		defaultOptionsMu.Lock()
   100  		if flattenAudience {
   101  			defaultOptions.Enable(FlattenAudience)
   102  		} else {
   103  			defaultOptions.Disable(FlattenAudience)
   104  		}
   105  		defaultOptionsMu.Unlock()
   106  	}
   107  }
   108  
// registry is the package-level registry that RegisterCustomField
// writes custom field registrations into.
var registry = json.NewRegistry()
   110  
// ParseString is the string counterpart of Parse: s is converted to a
// byte slice and parsed with the same options and the same semantics.
func ParseString(s string, options ...ParseOption) (Token, error) {
	return parseBytes([]byte(s), options...)
}
   115  
// Parse parses the JWT token payload and creates a new `jwt.Token` object.
// The token must be encoded in either JSON format or compact format.
//
// This function can only work with either raw JWT (JSON) or JWS (Compact or JSON).
// If you need JWE support on top of it, you will need to roll out your
// own workaround.
//
// Signature verification is required by default: you must supply the key
// material through an option such as jwt.WithKey(alg, key) or
// jwt.WithKeySet(jwk.Set). If no such option is given, Parse returns an
// error unless verification has been explicitly disabled with
// jwt.WithVerify(false).
//
// During verification, if the JWS headers specify a key ID (`kid`), the
// key used for verification must match the specified ID. If you are somehow
// using a key without a `kid` (which is highly unlikely if you are working
// with a JWT from a well-known provider), you can work around this by
// modifying the `jwk.Key` and setting the `kid` header.
//
// The validity of the JWT itself (i.e. expiration and such) is also
// asserted by default by calling `Validate()` on the parsed token. Pass
// `jwt.WithValidate(false)` to skip this step.
//
// This function takes both ParseOption and ValidateOption types:
// ParseOptions control the parsing behavior, and ValidateOptions are
// passed to `Validate()` when validation is enabled.
func Parse(s []byte, options ...ParseOption) (Token, error) {
	return parseBytes(s, options...)
}
   144  
   145  // ParseInsecure is exactly the same as Parse(), but it disables
   146  // signature verification and token validation.
   147  //
   148  // You cannot override `jwt.WithVerify()` or `jwt.WithValidate()`
   149  // using this function. Providing these options would result in
   150  // an error
   151  func ParseInsecure(s []byte, options ...ParseOption) (Token, error) {
   152  	for _, option := range options {
   153  		switch option.Ident() {
   154  		case identVerify{}, identValidate{}:
   155  			return nil, fmt.Errorf(`jwt.ParseInsecure: jwt.WithVerify() and jwt.WithValidate() may not be specified`)
   156  		}
   157  	}
   158  
   159  	options = append(options, WithVerify(false), WithValidate(false))
   160  	return Parse(s, options...)
   161  }
   162  
   163  // ParseReader calls Parse against an io.Reader
   164  func ParseReader(src io.Reader, options ...ParseOption) (Token, error) {
   165  	// We're going to need the raw bytes regardless. Read it.
   166  	data, err := io.ReadAll(src)
   167  	if err != nil {
   168  		return nil, fmt.Errorf(`failed to read from token data source: %w`, err)
   169  	}
   170  	return parseBytes(data, options...)
   171  }
   172  
// parseCtx carries the per-call state that parseBytes assembles from the
// parse options and hands over to parse.
type parseCtx struct {
	// token is the destination token (from the WithToken option); parse
	// creates one with New() when it is nil.
	token            Token
	// validateOpts are the options forwarded to Validate.
	validateOpts     []ValidateOption
	// verifyOpts are the converted options given to jws.Verify; empty
	// when no key-related option was supplied.
	verifyOpts       []jws.VerifyOption
	// localReg holds typed-claim registrations that apply to this parse
	// only; nil when none were requested.
	localReg         *json.Registry
	// pedantic enables the stricter format checks in parse.
	pedantic         bool
	// skipVerification is true when verification was explicitly disabled.
	skipVerification bool
	// validate controls whether Validate is run on the parsed token.
	validate         bool
}
   182  
   183  func parseBytes(data []byte, options ...ParseOption) (Token, error) {
   184  	var ctx parseCtx
   185  
   186  	// Validation is turned on by default. You need to specify
   187  	// jwt.WithValidate(false) if you want to disable it
   188  	ctx.validate = true
   189  
   190  	// Verification is required (i.e., it is assumed that the incoming
   191  	// data is in JWS format) unless the user explicitly asks for
   192  	// it to be skipped.
   193  	verification := true
   194  
   195  	var verifyOpts []Option
   196  	for _, o := range options {
   197  		if v, ok := o.(ValidateOption); ok {
   198  			ctx.validateOpts = append(ctx.validateOpts, v)
   199  			continue
   200  		}
   201  
   202  		//nolint:forcetypeassert
   203  		switch o.Ident() {
   204  		case identKey{}, identKeySet{}, identVerifyAuto{}, identKeyProvider{}:
   205  			verifyOpts = append(verifyOpts, o)
   206  		case identToken{}:
   207  			token, ok := o.Value().(Token)
   208  			if !ok {
   209  				return nil, fmt.Errorf(`invalid token passed via WithToken() option (%T)`, o.Value())
   210  			}
   211  			ctx.token = token
   212  		case identPedantic{}:
   213  			ctx.pedantic = o.Value().(bool)
   214  		case identValidate{}:
   215  			ctx.validate = o.Value().(bool)
   216  		case identVerify{}:
   217  			verification = o.Value().(bool)
   218  		case identTypedClaim{}:
   219  			pair := o.Value().(claimPair)
   220  			if ctx.localReg == nil {
   221  				ctx.localReg = json.NewRegistry()
   222  			}
   223  			ctx.localReg.Register(pair.Name, pair.Value)
   224  		}
   225  	}
   226  
   227  	if !verification {
   228  		ctx.skipVerification = true
   229  	}
   230  
   231  	lvo := len(verifyOpts)
   232  	if lvo == 0 && verification {
   233  		return nil, fmt.Errorf(`jwt.Parse: no keys for verification are provided (use jwt.WithVerify(false) to explicitly skip)`)
   234  	}
   235  
   236  	if lvo > 0 {
   237  		converted, err := toVerifyOptions(verifyOpts...)
   238  		if err != nil {
   239  			return nil, fmt.Errorf(`jwt.Parse: failed to convert options into jws.VerifyOption: %w`, err)
   240  		}
   241  		ctx.verifyOpts = converted
   242  	}
   243  
   244  	data = bytes.TrimSpace(data)
   245  	return parse(&ctx, data)
   246  }
   247  
// Status codes returned as the second value of verifyJWS.
const (
	// _JwsVerifyInvalid is the zero value; it is not returned by verifyJWS.
	_JwsVerifyInvalid = iota
	// _JwsVerifyDone means jws.Verify was invoked on the payload.
	_JwsVerifyDone
	// _JwsVerifyExpectNested is checked for by parse (in pedantic mode) to
	// decide whether another envelope must follow.
	_JwsVerifyExpectNested
	// _JwsVerifySkipped means no verification options were available, so
	// nothing was verified.
	_JwsVerifySkipped
)

// Reference _JwsVerifyInvalid so that it is not reported as unused.
var _ = _JwsVerifyInvalid
   256  
   257  func verifyJWS(ctx *parseCtx, payload []byte) ([]byte, int, error) {
   258  	if len(ctx.verifyOpts) == 0 {
   259  		return nil, _JwsVerifySkipped, nil
   260  	}
   261  
   262  	verifyOpts := ctx.verifyOpts
   263  	if atomic.LoadUint32(&compactOnly) == 1 {
   264  		verifyOpts = append(verifyOpts, jws.WithCompact())
   265  	}
   266  	verified, err := jws.Verify(payload, verifyOpts...)
   267  	return verified, _JwsVerifyDone, err
   268  }
   269  
   270  // verify parameter exists to make sure that we don't accidentally skip
   271  // over verification just because alg == ""  or key == nil or something.
   272  func parse(ctx *parseCtx, data []byte) (Token, error) {
   273  	payload := data
   274  	const maxDecodeLevels = 2
   275  
   276  	// If cty = `JWT`, we expect this to be a nested structure
   277  	var expectNested bool
   278  
   279  OUTER:
   280  	for i := 0; i < maxDecodeLevels; i++ {
   281  		switch kind := jwx.GuessFormat(payload); kind {
   282  		case jwx.JWT:
   283  			if ctx.pedantic {
   284  				if expectNested {
   285  					return nil, fmt.Errorf(`expected nested encrypted/signed payload, got raw JWT`)
   286  				}
   287  			}
   288  
   289  			if i == 0 {
   290  				// We were NOT enveloped in other formats
   291  				if !ctx.skipVerification {
   292  					if _, _, err := verifyJWS(ctx, payload); err != nil {
   293  						return nil, err
   294  					}
   295  				}
   296  			}
   297  
   298  			break OUTER
   299  		case jwx.InvalidFormat:
   300  			return nil, ErrInvalidJWT()
   301  		case jwx.UnknownFormat:
   302  			// "Unknown" may include invalid JWTs, for example, those who lack "aud"
   303  			// claim. We could be pedantic and reject these
   304  			if ctx.pedantic {
   305  				return nil, fmt.Errorf(`unknown JWT format (pedantic)`)
   306  			}
   307  
   308  			if i == 0 {
   309  				// We were NOT enveloped in other formats
   310  				if !ctx.skipVerification {
   311  					if _, _, err := verifyJWS(ctx, payload); err != nil {
   312  						return nil, err
   313  					}
   314  				}
   315  			}
   316  			break OUTER
   317  		case jwx.JWS:
   318  			// Food for thought: This is going to break if you have multiple layers of
   319  			// JWS enveloping using different keys. It is highly unlikely use case,
   320  			// but it might happen.
   321  
   322  			// skipVerification should only be set to true by us. It's used
   323  			// when we just want to parse the JWT out of a payload
   324  			if !ctx.skipVerification {
   325  				// nested return value means:
   326  				// false (next envelope _may_ need to be processed)
   327  				// true (next envelope MUST be processed)
   328  				v, state, err := verifyJWS(ctx, payload)
   329  				if err != nil {
   330  					return nil, err
   331  				}
   332  
   333  				if state != _JwsVerifySkipped {
   334  					payload = v
   335  
   336  					// We only check for cty and typ if the pedantic flag is enabled
   337  					if !ctx.pedantic {
   338  						continue
   339  					}
   340  
   341  					if state == _JwsVerifyExpectNested {
   342  						expectNested = true
   343  						continue OUTER
   344  					}
   345  
   346  					// if we're not nested, we found our target. bail out of this loop
   347  					break OUTER
   348  				}
   349  			}
   350  
   351  			// No verification.
   352  			var parseOptions []jws.ParseOption
   353  			if atomic.LoadUint32(&compactOnly) == 1 {
   354  				parseOptions = append(parseOptions, jws.WithCompact())
   355  			}
   356  			m, err := jws.Parse(data, parseOptions...)
   357  			if err != nil {
   358  				return nil, fmt.Errorf(`invalid jws message: %w`, err)
   359  			}
   360  			payload = m.Payload()
   361  		default:
   362  			return nil, fmt.Errorf(`unsupported format (layer: #%d)`, i+1)
   363  		}
   364  		expectNested = false
   365  	}
   366  
   367  	if ctx.token == nil {
   368  		ctx.token = New()
   369  	}
   370  
   371  	if ctx.localReg != nil {
   372  		dcToken, ok := ctx.token.(TokenWithDecodeCtx)
   373  		if !ok {
   374  			return nil, fmt.Errorf(`typed claim was requested, but the token (%T) does not support DecodeCtx`, ctx.token)
   375  		}
   376  		dc := json.NewDecodeCtx(ctx.localReg)
   377  		dcToken.SetDecodeCtx(dc)
   378  		defer func() { dcToken.SetDecodeCtx(nil) }()
   379  	}
   380  
   381  	if err := json.Unmarshal(payload, ctx.token); err != nil {
   382  		return nil, fmt.Errorf(`failed to parse token: %w`, err)
   383  	}
   384  
   385  	if ctx.validate {
   386  		if err := Validate(ctx.token, ctx.validateOpts...); err != nil {
   387  			return nil, err
   388  		}
   389  	}
   390  	return ctx.token, nil
   391  }
   392  
   393  // Sign is a convenience function to create a signed JWT token serialized in
   394  // compact form.
   395  //
   396  // It accepts either a raw key (e.g. rsa.PrivateKey, ecdsa.PrivateKey, etc)
   397  // or a jwk.Key, and the name of the algorithm that should be used to sign
   398  // the token.
   399  //
   400  // If the key is a jwk.Key and the key contains a key ID (`kid` field),
   401  // then it is added to the protected header generated by the signature
   402  //
   403  // The algorithm specified in the `alg` parameter must be able to support
   404  // the type of key you provided, otherwise an error is returned.
   405  // For convenience `alg` is of type jwa.KeyAlgorithm so you can pass
   406  // the return value of `(jwk.Key).Algorithm()` directly, but in practice
   407  // it must be an instance of jwa.SignatureAlgorithm, otherwise an error
   408  // is returned.
   409  //
   410  // The protected header will also automatically have the `typ` field set
   411  // to the literal value `JWT`, unless you provide a custom value for it
   412  // by jwt.WithHeaders option.
   413  func Sign(t Token, options ...SignOption) ([]byte, error) {
   414  	var soptions []jws.SignOption
   415  	if l := len(options); l > 0 {
   416  		// we need to from SignOption to Option because ... reasons
   417  		// (todo: when go1.18 prevails, use type parameters
   418  		rawoptions := make([]Option, l)
   419  		for i, option := range options {
   420  			rawoptions[i] = option
   421  		}
   422  
   423  		converted, err := toSignOptions(rawoptions...)
   424  		if err != nil {
   425  			return nil, fmt.Errorf(`jwt.Sign: failed to convert options into jws.SignOption: %w`, err)
   426  		}
   427  		soptions = converted
   428  	}
   429  	return NewSerializer().sign(soptions...).Serialize(t)
   430  }
   431  
   432  // Equal compares two JWT tokens. Do not use `reflect.Equal` or the like
   433  // to compare tokens as they will also compare extra detail such as
   434  // sync.Mutex objects used to control concurrent access.
   435  //
   436  // The comparison for values is currently done using a simple equality ("=="),
   437  // except for time.Time, which uses time.Equal after dropping the monotonic
   438  // clock and truncating the values to 1 second accuracy.
   439  //
   440  // if both t1 and t2 are nil, returns true
   441  func Equal(t1, t2 Token) bool {
   442  	if t1 == nil && t2 == nil {
   443  		return true
   444  	}
   445  
   446  	// we already checked for t1 == t2 == nil, so safe to do this
   447  	if t1 == nil || t2 == nil {
   448  		return false
   449  	}
   450  
   451  	j1, err := json.Marshal(t1)
   452  	if err != nil {
   453  		return false
   454  	}
   455  
   456  	j2, err := json.Marshal(t2)
   457  	if err != nil {
   458  		return false
   459  	}
   460  
   461  	return bytes.Equal(j1, j2)
   462  }
   463  
   464  func (t *stdToken) Clone() (Token, error) {
   465  	dst := New()
   466  
   467  	dst.Options().Set(*(t.Options()))
   468  	for _, pair := range t.makePairs() {
   469  		//nolint:forcetypeassert
   470  		key := pair.Key.(string)
   471  		if err := dst.Set(key, pair.Value); err != nil {
   472  			return nil, fmt.Errorf(`failed to set %s: %w`, key, err)
   473  		}
   474  	}
   475  	return dst, nil
   476  }
   477  
// RegisterCustomField allows users to specify that a private field
// be decoded as an instance of the specified type. The registration is
// stored in the package-level registry, so it has a global effect.
//
// For example, suppose you have a custom field `x-birthday`, which
// you want to represent as a string formatted in RFC3339 in JSON,
// but want it back as `time.Time`.
//
// In that case you would register a custom field as follows, where
// `timeT` is a value of type `time.Time`:
//
//	jwt.RegisterCustomField(`x-birthday`, timeT)
//
// Then `token.Get("x-birthday")` will still return an `interface{}`,
// but you can convert its type to `time.Time`:
//
//	bdayif, _ := token.Get(`x-birthday`)
//	bday := bdayif.(time.Time)
func RegisterCustomField(name string, object interface{}) {
	registry.Register(name, object)
}