github.com/lestrrat-go/jwx/v2@v2.0.21/jws/jws.go (about)

     1  //go:generate ../tools/cmd/genjws.sh
     2  
     3  // Package jws implements the digital signature on JSON based data
     4  // structures as described in https://tools.ietf.org/html/rfc7515
     5  //
     6  // If you do not care about the details, the only things that you
     7  // would need to use are the following functions:
     8  //
     9  //	jws.Sign(payload, jws.WithKey(algorithm, key))
    10  //	jws.Verify(serialized, jws.WithKey(algorithm, key))
    11  //
    12  // To sign, simply use `jws.Sign`. `payload` is a []byte buffer that
    13  // contains whatever data you want to sign. `alg` is one of the
    14  // jwa.SignatureAlgorithm constants from package jwa. For RSA and
    15  // ECDSA family of algorithms, you will need to prepare a private key.
    16  // For HMAC family, you just need a []byte value. The `jws.Sign`
    17  // function will return the encoded JWS message on success.
    18  //
    19  // To verify, use `jws.Verify`. It will parse the `encodedjws` buffer
    20  // and verify the result using `algorithm` and `key`. Upon successful
    21  // verification, the original payload is returned, so you can work on it.
    22  package jws
    23  
    24  import (
    25  	"bufio"
    26  	"bytes"
    27  	"context"
    28  	"crypto/ecdsa"
    29  	"crypto/ed25519"
    30  	"crypto/rsa"
    31  	"errors"
    32  	"fmt"
    33  	"io"
    34  	"reflect"
    35  	"strings"
    36  	"sync"
    37  	"unicode"
    38  	"unicode/utf8"
    39  
    40  	"github.com/lestrrat-go/blackmagic"
    41  	"github.com/lestrrat-go/jwx/v2/internal/base64"
    42  	"github.com/lestrrat-go/jwx/v2/internal/json"
    43  	"github.com/lestrrat-go/jwx/v2/internal/pool"
    44  	"github.com/lestrrat-go/jwx/v2/jwa"
    45  	"github.com/lestrrat-go/jwx/v2/jwk"
    46  	"github.com/lestrrat-go/jwx/v2/x25519"
    47  )
    48  
    49  var registry = json.NewRegistry()
    50  
    51  type payloadSigner struct {
    52  	signer    Signer
    53  	key       interface{}
    54  	protected Headers
    55  	public    Headers
    56  }
    57  
    58  func (s *payloadSigner) Sign(payload []byte) ([]byte, error) {
    59  	return s.signer.Sign(payload, s.key)
    60  }
    61  
    62  func (s *payloadSigner) Algorithm() jwa.SignatureAlgorithm {
    63  	return s.signer.Algorithm()
    64  }
    65  
    66  func (s *payloadSigner) ProtectedHeader() Headers {
    67  	return s.protected
    68  }
    69  
    70  func (s *payloadSigner) PublicHeader() Headers {
    71  	return s.public
    72  }
    73  
    74  var signers = make(map[jwa.SignatureAlgorithm]Signer)
    75  var muSigner = &sync.Mutex{}
    76  
    77  func removeSigner(alg jwa.SignatureAlgorithm) {
    78  	muSigner.Lock()
    79  	defer muSigner.Unlock()
    80  	delete(signers, alg)
    81  }
    82  
    83  func makeSigner(alg jwa.SignatureAlgorithm, key interface{}, public, protected Headers) (*payloadSigner, error) {
    84  	muSigner.Lock()
    85  	signer, ok := signers[alg]
    86  	if !ok {
    87  		v, err := NewSigner(alg)
    88  		if err != nil {
    89  			muSigner.Unlock()
    90  			return nil, fmt.Errorf(`failed to create payload signer: %w`, err)
    91  		}
    92  		signers[alg] = v
    93  		signer = v
    94  	}
    95  	muSigner.Unlock()
    96  
    97  	return &payloadSigner{
    98  		signer:    signer,
    99  		key:       key,
   100  		public:    public,
   101  		protected: protected,
   102  	}, nil
   103  }
   104  
   105  const (
   106  	fmtInvalid = 1 << iota
   107  	fmtCompact
   108  	fmtJSON
   109  	fmtJSONPretty
   110  	fmtMax
   111  )
   112  
   113  // silence linters
   114  var _ = fmtInvalid
   115  var _ = fmtMax
   116  
   117  func validateKeyBeforeUse(key interface{}) error {
   118  	jwkKey, ok := key.(jwk.Key)
   119  	if !ok {
   120  		converted, err := jwk.FromRaw(key)
   121  		if err != nil {
   122  			return fmt.Errorf(`could not convert key of type %T to jwk.Key for validation: %w`, key, err)
   123  		}
   124  		jwkKey = converted
   125  	}
   126  	return jwkKey.Validate()
   127  }
   128  
   129  // Sign generates a JWS message for the given payload and returns
   130  // it in serialized form, which can be in either compact or
   131  // JSON format. Default is compact.
   132  //
   133  // You must pass at least one key to `jws.Sign()` by using `jws.WithKey()`
   134  // option.
   135  //
   136  //	jws.Sign(payload, jws.WithKey(alg, key))
   137  //	jws.Sign(payload, jws.WithJSON(), jws.WithKey(alg1, key1), jws.WithKey(alg2, key2))
   138  //
   139  // Note that in the second example the `jws.WithJSON()` option is
   140  // specified as well. This is because the compact serialization
   141  // format does not support multiple signatures, and users must
   142  // specifically ask for the JSON serialization format.
   143  //
   144  // Read the documentation for `jws.WithKey()` to learn more about the
   145  // possible values that can be used for `alg` and `key`.
   146  //
   147  // You may create JWS messages with the "none" (jwa.NoSignature) algorithm
   148  // if you use the `jws.WithInsecureNoSignature()` option. This option
   149  // can be combined with one or more signature keys, as well as the
   150  // `jws.WithJSON()` option to generate multiple signatures (though
   151  // the usefulness of such constructs is highly debatable)
   152  //
   153  // Note that this library does not allow you to successfully call `jws.Verify()` on
   154  // signatures with the "none" algorithm. To parse these, use `jws.Parse()` instead.
   155  //
   156  // If you want to use a detached payload, use `jws.WithDetachedPayload()` as
   157  // one of the options. When you use this option, you must always set the
   158  // first parameter (`payload`) to `nil`, or the function will return an error
   159  //
   160  // You may also want to look at how to pass protected headers to the
   161  // signing process, as you will likely be required to set the `b64` field
   162  // when using detached payload.
   163  //
   164  // Look for options that return `jws.SignOption` or `jws.SignVerifyOption`
   165  // for a complete list of options that can be passed to this function.
   166  func Sign(payload []byte, options ...SignOption) ([]byte, error) {
   167  	format := fmtCompact
   168  	var signers []*payloadSigner
   169  	var detached bool
   170  	var noneSignature *payloadSigner
   171  	var validateKey bool
   172  	for _, option := range options {
   173  		//nolint:forcetypeassert
   174  		switch option.Ident() {
   175  		case identSerialization{}:
   176  			format = option.Value().(int)
   177  		case identInsecureNoSignature{}:
   178  			data := option.Value().(*withInsecureNoSignature)
   179  			// only the last one is used (we overwrite previous values)
   180  			noneSignature = &payloadSigner{
   181  				signer:    noneSigner{},
   182  				protected: data.protected,
   183  			}
   184  		case identKey{}:
   185  			data := option.Value().(*withKey)
   186  
   187  			alg, ok := data.alg.(jwa.SignatureAlgorithm)
   188  			if !ok {
   189  				return nil, fmt.Errorf(`jws.Sign: expected algorithm to be of type jwa.SignatureAlgorithm but got (%[1]q, %[1]T)`, data.alg)
   190  			}
   191  
   192  			// No, we don't accept "none" here.
   193  			if alg == jwa.NoSignature {
   194  				return nil, fmt.Errorf(`jws.Sign: "none" (jwa.NoSignature) cannot be used with jws.WithKey`)
   195  			}
   196  
   197  			signer, err := makeSigner(alg, data.key, data.public, data.protected)
   198  			if err != nil {
   199  				return nil, fmt.Errorf(`jws.Sign: failed to create signer: %w`, err)
   200  			}
   201  			signers = append(signers, signer)
   202  		case identDetachedPayload{}:
   203  			detached = true
   204  			if payload != nil {
   205  				return nil, fmt.Errorf(`jws.Sign: payload must be nil when jws.WithDetachedPayload() is specified`)
   206  			}
   207  			payload = option.Value().([]byte)
   208  		case identValidateKey{}:
   209  			validateKey = option.Value().(bool)
   210  		}
   211  	}
   212  
   213  	if noneSignature != nil {
   214  		signers = append(signers, noneSignature)
   215  	}
   216  
   217  	lsigner := len(signers)
   218  	if lsigner == 0 {
   219  		return nil, fmt.Errorf(`jws.Sign: no signers available. Specify an alogirthm and akey using jws.WithKey()`)
   220  	}
   221  
   222  	// Design note: while we could have easily set format = fmtJSON when
   223  	// lsigner > 1, I believe the decision to change serialization formats
   224  	// must be explicitly stated by the caller. Otherwise I'm pretty sure
   225  	// there would be people filing issues saying "I get JSON when I expcted
   226  	// compact serialization".
   227  	//
   228  	// Therefore, instead of making implicit format conversions, we force the
   229  	// user to spell it out as `jws.Sign(..., jws.WithJSON(), jws.WithKey(...), jws.WithKey(...))`
   230  	if format == fmtCompact && lsigner != 1 {
   231  		return nil, fmt.Errorf(`jws.Sign: cannot have multiple signers (keys) specified for compact serialization. Use only one jws.WithKey()`)
   232  	}
   233  
   234  	// Create a Message object with all the bits and bobs, and we'll
   235  	// serialize it in the end
   236  	var result Message
   237  
   238  	result.payload = payload
   239  
   240  	result.signatures = make([]*Signature, 0, len(signers))
   241  	for i, signer := range signers {
   242  		protected := signer.ProtectedHeader()
   243  		if protected == nil {
   244  			protected = NewHeaders()
   245  		}
   246  
   247  		if err := protected.Set(AlgorithmKey, signer.Algorithm()); err != nil {
   248  			return nil, fmt.Errorf(`failed to set "alg" header: %w`, err)
   249  		}
   250  
   251  		if key, ok := signer.key.(jwk.Key); ok {
   252  			if kid := key.KeyID(); kid != "" {
   253  				if err := protected.Set(KeyIDKey, kid); err != nil {
   254  					return nil, fmt.Errorf(`failed to set "kid" header: %w`, err)
   255  				}
   256  			}
   257  		}
   258  		sig := &Signature{
   259  			headers:   signer.PublicHeader(),
   260  			protected: protected,
   261  			// cheat. FIXXXXXXMEEEEEE
   262  			detached: detached,
   263  		}
   264  
   265  		if validateKey {
   266  			if err := validateKeyBeforeUse(signer.key); err != nil {
   267  				return nil, fmt.Errorf(`jws.Verify: %w`, err)
   268  			}
   269  		}
   270  		_, _, err := sig.Sign(payload, signer.signer, signer.key)
   271  		if err != nil {
   272  			return nil, fmt.Errorf(`failed to generate signature for signer #%d (alg=%s): %w`, i, signer.Algorithm(), err)
   273  		}
   274  
   275  		result.signatures = append(result.signatures, sig)
   276  	}
   277  
   278  	switch format {
   279  	case fmtJSON:
   280  		return json.Marshal(result)
   281  	case fmtJSONPretty:
   282  		return json.MarshalIndent(result, "", "  ")
   283  	case fmtCompact:
   284  		// Take the only signature object, and convert it into a Compact
   285  		// serialization format
   286  		var compactOpts []CompactOption
   287  		if detached {
   288  			compactOpts = append(compactOpts, WithDetached(detached))
   289  		}
   290  		return Compact(&result, compactOpts...)
   291  	default:
   292  		return nil, fmt.Errorf(`jws.Sign: invalid serialization format`)
   293  	}
   294  }
   295  
   296  var allowNoneWhitelist = jwk.WhitelistFunc(func(string) bool {
   297  	return false
   298  })
   299  
   300  // Verify checks if the given JWS message is verifiable using `alg` and `key`.
   301  // `key` may be a "raw" key (e.g. rsa.PublicKey) or a jwk.Key
   302  //
   303  // If the verification is successful, `err` is nil, and the content of the
   304  // payload that was signed is returned. If you need more fine-grained
   305  // control of the verification process, manually generate a
   306  // `Verifier` in `verify` subpackage, and call `Verify` method on it.
   307  // If you need to access signatures and JOSE headers in a JWS message,
   308  // use `Parse` function to get `Message` object.
   309  //
   310  // Because the use of "none" (jwa.NoSignature) algorithm is strongly discouraged,
   311  // this function DOES NOT consider it a success when `{"alg":"none"}` is
   312  // encountered in the message (it would also be counter intuitive when the code says
   313  // you _verified_ something when in fact it did no such thing). If you want to
   314  // accept messages with "none" signature algorithm, use `jws.Parse` to get the
   315  // raw JWS message.
   316  func Verify(buf []byte, options ...VerifyOption) ([]byte, error) {
   317  	var parseOptions []ParseOption
   318  	var dst *Message
   319  	var detachedPayload []byte
   320  	var keyProviders []KeyProvider
   321  	var keyUsed interface{}
   322  	var validateKey bool
   323  
   324  	ctx := context.Background()
   325  
   326  	//nolint:forcetypeassert
   327  	for _, option := range options {
   328  		switch option.Ident() {
   329  		case identMessage{}:
   330  			dst = option.Value().(*Message)
   331  		case identDetachedPayload{}:
   332  			detachedPayload = option.Value().([]byte)
   333  		case identKey{}:
   334  			pair := option.Value().(*withKey)
   335  			alg, ok := pair.alg.(jwa.SignatureAlgorithm)
   336  			if !ok {
   337  				return nil, fmt.Errorf(`WithKey() option must be specified using jwa.SignatureAlgorithm (got %T)`, pair.alg)
   338  			}
   339  			keyProviders = append(keyProviders, &staticKeyProvider{
   340  				alg: alg,
   341  				key: pair.key,
   342  			})
   343  		case identKeyProvider{}:
   344  			keyProviders = append(keyProviders, option.Value().(KeyProvider))
   345  		case identKeyUsed{}:
   346  			keyUsed = option.Value()
   347  		case identContext{}:
   348  			ctx = option.Value().(context.Context)
   349  		case identValidateKey{}:
   350  			validateKey = option.Value().(bool)
   351  		case identSerialization{}:
   352  			parseOptions = append(parseOptions, option.(ParseOption))
   353  		default:
   354  			return nil, fmt.Errorf(`invalid jws.VerifyOption %q passed`, `With`+strings.TrimPrefix(fmt.Sprintf(`%T`, option.Ident()), `jws.ident`))
   355  		}
   356  	}
   357  
   358  	if len(keyProviders) < 1 {
   359  		return nil, fmt.Errorf(`jws.Verify: no key providers have been provided (see jws.WithKey(), jws.WithKeySet(), jws.WithVerifyAuto(), and jws.WithKeyProvider()`)
   360  	}
   361  
   362  	msg, err := Parse(buf, parseOptions...)
   363  	if err != nil {
   364  		return nil, fmt.Errorf(`failed to parse jws: %w`, err)
   365  	}
   366  	defer msg.clearRaw()
   367  
   368  	if detachedPayload != nil {
   369  		if len(msg.payload) != 0 {
   370  			return nil, fmt.Errorf(`can't specify detached payload for JWS with payload`)
   371  		}
   372  
   373  		msg.payload = detachedPayload
   374  	}
   375  
   376  	// Pre-compute the base64 encoded version of payload
   377  	var payload string
   378  	if msg.b64 {
   379  		payload = base64.EncodeToString(msg.payload)
   380  	} else {
   381  		payload = string(msg.payload)
   382  	}
   383  
   384  	verifyBuf := pool.GetBytesBuffer()
   385  	defer pool.ReleaseBytesBuffer(verifyBuf)
   386  
   387  	var errs []error
   388  	for i, sig := range msg.signatures {
   389  		verifyBuf.Reset()
   390  
   391  		var encodedProtectedHeader string
   392  		if rbp, ok := sig.protected.(interface{ rawBuffer() []byte }); ok {
   393  			if raw := rbp.rawBuffer(); raw != nil {
   394  				encodedProtectedHeader = base64.EncodeToString(raw)
   395  			}
   396  		}
   397  
   398  		if encodedProtectedHeader == "" {
   399  			protected, err := json.Marshal(sig.protected)
   400  			if err != nil {
   401  				return nil, fmt.Errorf(`failed to marshal "protected" for signature #%d: %w`, i+1, err)
   402  			}
   403  
   404  			encodedProtectedHeader = base64.EncodeToString(protected)
   405  		}
   406  
   407  		verifyBuf.WriteString(encodedProtectedHeader)
   408  		verifyBuf.WriteByte('.')
   409  		verifyBuf.WriteString(payload)
   410  
   411  		for i, kp := range keyProviders {
   412  			var sink algKeySink
   413  			if err := kp.FetchKeys(ctx, &sink, sig, msg); err != nil {
   414  				return nil, fmt.Errorf(`key provider %d failed: %w`, i, err)
   415  			}
   416  
   417  			for _, pair := range sink.list {
   418  				// alg is converted here because pair.alg is of type jwa.KeyAlgorithm.
   419  				// this may seem ugly, but we're trying to avoid declaring separate
   420  				// structs for `alg jwa.KeyAlgorithm` and `alg jwa.SignatureAlgorithm`
   421  				//nolint:forcetypeassert
   422  				alg := pair.alg.(jwa.SignatureAlgorithm)
   423  				key := pair.key
   424  
   425  				if validateKey {
   426  					if err := validateKeyBeforeUse(key); err != nil {
   427  						return nil, fmt.Errorf(`jws.Verify: %w`, err)
   428  					}
   429  				}
   430  				verifier, err := NewVerifier(alg)
   431  				if err != nil {
   432  					return nil, fmt.Errorf(`failed to create verifier for algorithm %q: %w`, alg, err)
   433  				}
   434  
   435  				if err := verifier.Verify(verifyBuf.Bytes(), sig.signature, key); err != nil {
   436  					errs = append(errs, err)
   437  					continue
   438  				}
   439  
   440  				if keyUsed != nil {
   441  					if err := blackmagic.AssignIfCompatible(keyUsed, key); err != nil {
   442  						return nil, fmt.Errorf(`failed to assign used key (%T) to %T: %w`, key, keyUsed, err)
   443  					}
   444  				}
   445  
   446  				if dst != nil {
   447  					*(dst) = *msg
   448  				}
   449  
   450  				return msg.payload, nil
   451  			}
   452  		}
   453  	}
   454  	return nil, &verifyError{errs: errs}
   455  }
   456  
   457  type verifyError struct {
   458  	// Note: when/if we can ditch Go < 1.20, we can change this to a simple
   459  	// `err error`, where the value is the result of `errors.Join()`
   460  	//
   461  	// We also need to implement Unwrap:
   462  	// func (e *verifyError) Unwrap() error {
   463  	//	return e.err
   464  	//}
   465  	//
   466  	// And finally, As() can go away
   467  	errs []error
   468  }
   469  
   470  func (e *verifyError) Error() string {
   471  	return `could not verify message using any of the signatures or keys`
   472  }
   473  
   474  func (e *verifyError) As(target interface{}) bool {
   475  	for _, err := range e.errs {
   476  		if errors.As(err, target) {
   477  			return true
   478  		}
   479  	}
   480  	return false
   481  }
   482  
   483  // IsVerificationError returns true if the error came from the verification part of the
   484  // jws.Verify function, allowing you to check if the error is a result of actual
   485  // verification failure.
   486  //
   487  // For example, if the error happened while fetching a key
   488  // from a datasource, feeding that error should to this function return false, whereas
   489  // a failure to compute the signature for whatever reason would be a verification error
   490  // and returns true.
   491  func IsVerificationError(err error) bool {
   492  	var ve *verifyError
   493  	return errors.As(err, &ve)
   494  }
   495  
   496  // get the value of b64 header field.
   497  // If the field does not exist, returns true (default)
   498  // Otherwise return the value specified by the header field.
   499  func getB64Value(hdr Headers) bool {
   500  	b64raw, ok := hdr.Get("b64")
   501  	if !ok {
   502  		return true // default
   503  	}
   504  
   505  	b64, ok := b64raw.(bool) // default
   506  	if !ok {
   507  		return false
   508  	}
   509  	return b64
   510  }
   511  
   512  // This is an "optimized" io.ReadAll(). It will attempt to read
   513  // all of the contents from the reader IF the reader is of a certain
   514  // concrete type.
   515  func readAll(rdr io.Reader) ([]byte, bool) {
   516  	switch rdr.(type) {
   517  	case *bytes.Reader, *bytes.Buffer, *strings.Reader:
   518  		data, err := io.ReadAll(rdr)
   519  		if err != nil {
   520  			return nil, false
   521  		}
   522  		return data, true
   523  	default:
   524  		return nil, false
   525  	}
   526  }
   527  
   528  // Parse parses contents from the given source and creates a jws.Message
   529  // struct. By default the input can be in either compact or full JSON serialization.
   530  //
   531  // You may pass `jws.WithJSON()` and/or `jws.WithCompact()` to specify
   532  // explicitly which format to use. If neither or both is specified, the function
   533  // will attempt to autodetect the format. If one or the other is specified,
   534  // only the specified format will be attempted.
   535  func Parse(src []byte, options ...ParseOption) (*Message, error) {
   536  	var formats int
   537  	for _, option := range options {
   538  		//nolint:forcetypeassert
   539  		switch option.Ident() {
   540  		case identSerialization{}:
   541  			switch option.Value().(int) {
   542  			case fmtJSON:
   543  				formats |= fmtJSON
   544  			case fmtCompact:
   545  				formats |= fmtCompact
   546  			}
   547  		}
   548  	}
   549  
   550  	// if format is 0 or both JSON/Compact, auto detect
   551  	if v := formats & (fmtJSON | fmtCompact); v == 0 || v == fmtJSON|fmtCompact {
   552  		for i := 0; i < len(src); i++ {
   553  			r := rune(src[i])
   554  			if r >= utf8.RuneSelf {
   555  				r, _ = utf8.DecodeRune(src)
   556  			}
   557  			if !unicode.IsSpace(r) {
   558  				if r == '{' {
   559  					return parseJSON(src)
   560  				}
   561  				return parseCompact(src)
   562  			}
   563  		}
   564  	} else if formats&fmtCompact == fmtCompact {
   565  		return parseCompact(src)
   566  	} else if formats&fmtJSON == fmtJSON {
   567  		return parseJSON(src)
   568  	}
   569  
   570  	return nil, fmt.Errorf(`invalid byte sequence`)
   571  }
   572  
   573  // Parse parses contents from the given source and creates a jws.Message
   574  // struct. The input can be in either compact or full JSON serialization.
   575  func ParseString(src string) (*Message, error) {
   576  	return Parse([]byte(src))
   577  }
   578  
   579  // Parse parses contents from the given source and creates a jws.Message
   580  // struct. The input can be in either compact or full JSON serialization.
   581  func ParseReader(src io.Reader) (*Message, error) {
   582  	if data, ok := readAll(src); ok {
   583  		return Parse(data)
   584  	}
   585  
   586  	rdr := bufio.NewReader(src)
   587  	var first rune
   588  	for {
   589  		r, _, err := rdr.ReadRune()
   590  		if err != nil {
   591  			return nil, fmt.Errorf(`failed to read rune: %w`, err)
   592  		}
   593  		if !unicode.IsSpace(r) {
   594  			first = r
   595  			if err := rdr.UnreadRune(); err != nil {
   596  				return nil, fmt.Errorf(`failed to unread rune: %w`, err)
   597  			}
   598  
   599  			break
   600  		}
   601  	}
   602  
   603  	var parser func(io.Reader) (*Message, error)
   604  	if first == '{' {
   605  		parser = parseJSONReader
   606  	} else {
   607  		parser = parseCompactReader
   608  	}
   609  
   610  	m, err := parser(rdr)
   611  	if err != nil {
   612  		return nil, fmt.Errorf(`failed to parse jws message: %w`, err)
   613  	}
   614  
   615  	return m, nil
   616  }
   617  
   618  func parseJSONReader(src io.Reader) (result *Message, err error) {
   619  	var m Message
   620  	if err := json.NewDecoder(src).Decode(&m); err != nil {
   621  		return nil, fmt.Errorf(`failed to unmarshal jws message: %w`, err)
   622  	}
   623  	return &m, nil
   624  }
   625  
   626  func parseJSON(data []byte) (result *Message, err error) {
   627  	var m Message
   628  	if err := json.Unmarshal(data, &m); err != nil {
   629  		return nil, fmt.Errorf(`failed to unmarshal jws message: %w`, err)
   630  	}
   631  	return &m, nil
   632  }
   633  
   634  // SplitCompact splits a JWT and returns its three parts
   635  // separately: protected headers, payload and signature.
   636  func SplitCompact(src []byte) ([]byte, []byte, []byte, error) {
   637  	parts := bytes.Split(src, []byte("."))
   638  	if len(parts) < 3 {
   639  		return nil, nil, nil, fmt.Errorf(`invalid number of segments`)
   640  	}
   641  	return parts[0], parts[1], parts[2], nil
   642  }
   643  
   644  // SplitCompactString splits a JWT and returns its three parts
   645  // separately: protected headers, payload and signature.
   646  func SplitCompactString(src string) ([]byte, []byte, []byte, error) {
   647  	parts := strings.Split(src, ".")
   648  	if len(parts) < 3 {
   649  		return nil, nil, nil, fmt.Errorf(`invalid number of segments`)
   650  	}
   651  	return []byte(parts[0]), []byte(parts[1]), []byte(parts[2]), nil
   652  }
   653  
// SplitCompactReader splits a JWT and returns its three parts
// separately: protected headers, payload and signature.
//
// When rdr is one of the in-memory readers accepted by readAll, the whole
// input is read at once and handed to SplitCompact. Otherwise the input is
// consumed in chunks of up to 4096 bytes, and segments are cut at each '.'
// as soon as it is seen. An error is returned if a read fails with anything
// other than io.EOF, or if the input does not contain exactly two '.'
// separators. The returned slices may share one backing array.
func SplitCompactReader(rdr io.Reader) ([]byte, []byte, []byte, error) {
	if data, ok := readAll(rdr); ok {
		return SplitCompact(data)
	}

	var protected []byte
	var payload []byte
	var signature []byte
	var periods int // number of '.' separators seen so far
	var state int   // segment being filled: 0 protected, 1 payload, 2 signature

	buf := make([]byte, 4096)
	var sofar []byte // bytes read but not yet assigned to a segment

	for {
		// read next bytes
		n, err := rdr.Read(buf)
		// return on unexpected read error
		// NOTE(review): the message says "unexpected end of input" although
		// this branch handles non-EOF read errors.
		if err != nil && err != io.EOF {
			return nil, nil, nil, fmt.Errorf(`unexpected end of input: %w`, err)
		}

		// append to current buffer
		sofar = append(sofar, buf[:n]...)
		// loop to capture multiple '.' in current buffer
		for loop := true; loop; {
			var i = bytes.IndexByte(sofar, '.')
			if i == -1 && err != io.EOF {
				// no '.' found -> exit and read next bytes (outer loop)
				loop = false
				continue
			} else if i == -1 && err == io.EOF {
				// no '.' found -> process rest and exit
				i = len(sofar)
				loop = false
			} else {
				// '.' found
				periods++
			}

			// Reaching this point means we have found a '.' or EOF and process the rest of the buffer
			// Once in state 2, any further segment overwrites signature; the
			// periods check after the loop then rejects the input.
			switch state {
			case 0:
				protected = sofar[:i]
				state++
			case 1:
				payload = sofar[:i]
				state++
			case 2:
				signature = sofar[:i]
			}
			// Shorten current buffer (skip past the consumed '.')
			if len(sofar) > i {
				sofar = sofar[i+1:]
			}
		}
		// Exit on EOF
		if err == io.EOF {
			break
		}
	}
	if periods != 2 {
		return nil, nil, nil, fmt.Errorf(`invalid number of segments`)
	}

	return protected, payload, signature, nil
}
   723  
   724  // parseCompactReader parses a JWS value serialized via compact serialization.
   725  func parseCompactReader(rdr io.Reader) (m *Message, err error) {
   726  	protected, payload, signature, err := SplitCompactReader(rdr)
   727  	if err != nil {
   728  		return nil, fmt.Errorf(`invalid compact serialization format: %w`, err)
   729  	}
   730  	return parse(protected, payload, signature)
   731  }
   732  
   733  func parseCompact(data []byte) (m *Message, err error) {
   734  	protected, payload, signature, err := SplitCompact(data)
   735  	if err != nil {
   736  		return nil, fmt.Errorf(`invalid compact serialization format: %w`, err)
   737  	}
   738  	return parse(protected, payload, signature)
   739  }
   740  
   741  func parse(protected, payload, signature []byte) (*Message, error) {
   742  	decodedHeader, err := base64.Decode(protected)
   743  	if err != nil {
   744  		return nil, fmt.Errorf(`failed to decode protected headers: %w`, err)
   745  	}
   746  
   747  	hdr := NewHeaders()
   748  	if err := json.Unmarshal(decodedHeader, hdr); err != nil {
   749  		return nil, fmt.Errorf(`failed to parse JOSE headers: %w`, err)
   750  	}
   751  
   752  	var decodedPayload []byte
   753  	b64 := getB64Value(hdr)
   754  	if !b64 {
   755  		decodedPayload = payload
   756  	} else {
   757  		v, err := base64.Decode(payload)
   758  		if err != nil {
   759  			return nil, fmt.Errorf(`failed to decode payload: %w`, err)
   760  		}
   761  		decodedPayload = v
   762  	}
   763  
   764  	decodedSignature, err := base64.Decode(signature)
   765  	if err != nil {
   766  		return nil, fmt.Errorf(`failed to decode signature: %w`, err)
   767  	}
   768  
   769  	var msg Message
   770  	msg.payload = decodedPayload
   771  	msg.signatures = append(msg.signatures, &Signature{
   772  		protected: hdr,
   773  		signature: decodedSignature,
   774  	})
   775  	msg.b64 = b64
   776  	return &msg, nil
   777  }
   778  
   779  // RegisterCustomField allows users to specify that a private field
   780  // be decoded as an instance of the specified type. This option has
   781  // a global effect.
   782  //
   783  // For example, suppose you have a custom field `x-birthday`, which
   784  // you want to represent as a string formatted in RFC3339 in JSON,
   785  // but want it back as `time.Time`.
   786  //
   787  // In that case you would register a custom field as follows
   788  //
   789  //	jwe.RegisterCustomField(`x-birthday`, timeT)
   790  //
   791  // Then `hdr.Get("x-birthday")` will still return an `interface{}`,
   792  // but you can convert its type to `time.Time`
   793  //
   794  //	bdayif, _ := hdr.Get(`x-birthday`)
   795  //	bday := bdayif.(time.Time)
   796  func RegisterCustomField(name string, object interface{}) {
   797  	registry.Register(name, object)
   798  }
   799  
   800  // Helpers for signature verification
   801  var rawKeyToKeyType = make(map[reflect.Type]jwa.KeyType)
   802  var keyTypeToAlgorithms = make(map[jwa.KeyType][]jwa.SignatureAlgorithm)
   803  
   804  func init() {
   805  	rawKeyToKeyType[reflect.TypeOf([]byte(nil))] = jwa.OctetSeq
   806  	rawKeyToKeyType[reflect.TypeOf(ed25519.PublicKey(nil))] = jwa.OKP
   807  	rawKeyToKeyType[reflect.TypeOf(rsa.PublicKey{})] = jwa.RSA
   808  	rawKeyToKeyType[reflect.TypeOf((*rsa.PublicKey)(nil))] = jwa.RSA
   809  	rawKeyToKeyType[reflect.TypeOf(ecdsa.PublicKey{})] = jwa.EC
   810  	rawKeyToKeyType[reflect.TypeOf((*ecdsa.PublicKey)(nil))] = jwa.EC
   811  
   812  	addAlgorithmForKeyType(jwa.OKP, jwa.EdDSA)
   813  	for _, alg := range []jwa.SignatureAlgorithm{jwa.HS256, jwa.HS384, jwa.HS512} {
   814  		addAlgorithmForKeyType(jwa.OctetSeq, alg)
   815  	}
   816  	for _, alg := range []jwa.SignatureAlgorithm{jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512} {
   817  		addAlgorithmForKeyType(jwa.RSA, alg)
   818  	}
   819  	for _, alg := range []jwa.SignatureAlgorithm{jwa.ES256, jwa.ES384, jwa.ES512} {
   820  		addAlgorithmForKeyType(jwa.EC, alg)
   821  	}
   822  }
   823  
   824  func addAlgorithmForKeyType(kty jwa.KeyType, alg jwa.SignatureAlgorithm) {
   825  	keyTypeToAlgorithms[kty] = append(keyTypeToAlgorithms[kty], alg)
   826  }
   827  
   828  // AlgorithmsForKey returns the possible signature algorithms that can
   829  // be used for a given key. It only takes in consideration keys/algorithms
   830  // for verification purposes, as this is the only usage where one may need
   831  // dynamically figure out which method to use.
   832  func AlgorithmsForKey(key interface{}) ([]jwa.SignatureAlgorithm, error) {
   833  	var kty jwa.KeyType
   834  	switch key := key.(type) {
   835  	case jwk.Key:
   836  		kty = key.KeyType()
   837  	case rsa.PublicKey, *rsa.PublicKey, rsa.PrivateKey, *rsa.PrivateKey:
   838  		kty = jwa.RSA
   839  	case ecdsa.PublicKey, *ecdsa.PublicKey, ecdsa.PrivateKey, *ecdsa.PrivateKey:
   840  		kty = jwa.EC
   841  	case ed25519.PublicKey, ed25519.PrivateKey, x25519.PublicKey, x25519.PrivateKey:
   842  		kty = jwa.OKP
   843  	case []byte:
   844  		kty = jwa.OctetSeq
   845  	default:
   846  		return nil, fmt.Errorf(`invalid key %T`, key)
   847  	}
   848  
   849  	algs, ok := keyTypeToAlgorithms[kty]
   850  	if !ok {
   851  		return nil, fmt.Errorf(`invalid key type %q`, kty)
   852  	}
   853  	return algs, nil
   854  }
   855  
   856  // Because the keys defined in github.com/lestrrat-go/jwx/jwk may also implement
   857  // crypto.Signer, it would be possible for to mix up key types when signing/verifying
   858  // for example, when we specify jws.WithKey(jwa.RSA256, cryptoSigner), the cryptoSigner
   859  // can be for RSA, or any other type that implements crypto.Signer... even if it's for the
   860  // wrong algorithm.
   861  //
   862  // These functions are there to differentiate between the valid KNOWN key types.
   863  // For any other key type that is outside of the Go std library and our own code,
   864  // we must rely on the user to be vigilant.
   865  //
   866  // Notes: symmetric keys are obviously not part of this. for v2 OKP keys,
   867  // x25519 does not implement Sign()
   868  func isValidRSAKey(key interface{}) bool {
   869  	switch key.(type) {
   870  	case
   871  		ecdsa.PrivateKey, *ecdsa.PrivateKey,
   872  		ed25519.PrivateKey,
   873  		jwk.ECDSAPrivateKey, jwk.OKPPrivateKey:
   874  		// these are NOT ok
   875  		return false
   876  	}
   877  	return true
   878  }
   879  
   880  func isValidECDSAKey(key interface{}) bool {
   881  	switch key.(type) {
   882  	case
   883  		ed25519.PrivateKey,
   884  		rsa.PrivateKey, *rsa.PrivateKey,
   885  		jwk.RSAPrivateKey, jwk.OKPPrivateKey:
   886  		// these are NOT ok
   887  		return false
   888  	}
   889  	return true
   890  }
   891  
   892  func isValidEDDSAKey(key interface{}) bool {
   893  	switch key.(type) {
   894  	case
   895  		ecdsa.PrivateKey, *ecdsa.PrivateKey,
   896  		rsa.PrivateKey, *rsa.PrivateKey,
   897  		jwk.RSAPrivateKey, jwk.ECDSAPrivateKey:
   898  		// these are NOT ok
   899  		return false
   900  	}
   901  	return true
   902  }