github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/jwe.go (about)

     1  //go:generate ../tools/cmd/genjwe.sh
     2  
     3  // Package jwe implements JWE as described in https://tools.ietf.org/html/rfc7516
     4  package jwe
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto/ecdsa"
    10  	"crypto/rsa"
    11  	"fmt"
    12  	"io"
    13  	"sync"
    14  
    15  	"github.com/lestrrat-go/blackmagic"
    16  	"github.com/lestrrat-go/jwx/v2/internal/base64"
    17  	"github.com/lestrrat-go/jwx/v2/internal/json"
    18  	"github.com/lestrrat-go/jwx/v2/internal/keyconv"
    19  	"github.com/lestrrat-go/jwx/v2/jwk"
    20  
    21  	"github.com/lestrrat-go/jwx/v2/jwa"
    22  	"github.com/lestrrat-go/jwx/v2/jwe/internal/aescbc"
    23  	"github.com/lestrrat-go/jwx/v2/jwe/internal/content_crypt"
    24  	"github.com/lestrrat-go/jwx/v2/jwe/internal/keyenc"
    25  	"github.com/lestrrat-go/jwx/v2/jwe/internal/keygen"
    26  	"github.com/lestrrat-go/jwx/v2/x25519"
    27  )
    28  
    29  var muSettings sync.RWMutex
    30  var maxPBES2Count = 10000
    31  var maxDecompressBufferSize int64 = 10 * 1024 * 1024 // 10MB
    32  
    33  func Settings(options ...GlobalOption) {
    34  	muSettings.Lock()
    35  	defer muSettings.Unlock()
    36  	//nolint:forcetypeassert
    37  	for _, option := range options {
    38  		switch option.Ident() {
    39  		case identMaxPBES2Count{}:
    40  			maxPBES2Count = option.Value().(int)
    41  		case identMaxDecompressBufferSize{}:
    42  			maxDecompressBufferSize = option.Value().(int64)
    43  		case identMaxBufferSize{}:
    44  			aescbc.SetMaxBufferSize(option.Value().(int64))
    45  		}
    46  	}
    47  }
    48  
    49  const (
    50  	fmtInvalid = iota
    51  	fmtCompact
    52  	fmtJSON
    53  	fmtJSONPretty
    54  	fmtMax
    55  )
    56  
    57  var _ = fmtInvalid
    58  var _ = fmtMax
    59  
    60  var registry = json.NewRegistry()
    61  
    62  type keyEncrypterWrapper struct {
    63  	encrypter KeyEncrypter
    64  }
    65  
    66  func (w *keyEncrypterWrapper) Algorithm() jwa.KeyEncryptionAlgorithm {
    67  	return w.encrypter.Algorithm()
    68  }
    69  
    70  func (w *keyEncrypterWrapper) EncryptKey(cek []byte) (keygen.ByteSource, error) {
    71  	encrypted, err := w.encrypter.EncryptKey(cek)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	return keygen.ByteKey(encrypted), nil
    76  }
    77  
    78  type recipientBuilder struct {
    79  	alg     jwa.KeyEncryptionAlgorithm
    80  	key     interface{}
    81  	headers Headers
    82  }
    83  
    84  func (b *recipientBuilder) Build(cek []byte, calg jwa.ContentEncryptionAlgorithm, cc *content_crypt.Generic) (Recipient, []byte, error) {
    85  	var enc keyenc.Encrypter
    86  
    87  	// we need the raw key for later use
    88  	rawKey := b.key
    89  
    90  	var keyID string
    91  	if ke, ok := b.key.(KeyEncrypter); ok {
    92  		enc = &keyEncrypterWrapper{encrypter: ke}
    93  		if kider, ok := enc.(KeyIDer); ok {
    94  			keyID = kider.KeyID()
    95  		}
    96  	} else if jwkKey, ok := b.key.(jwk.Key); ok {
    97  		// Meanwhile, grab the kid as well
    98  		keyID = jwkKey.KeyID()
    99  
   100  		var raw interface{}
   101  		if err := jwkKey.Raw(&raw); err != nil {
   102  			return nil, nil, fmt.Errorf(`failed to retrieve raw key out of %T: %w`, b.key, err)
   103  		}
   104  
   105  		rawKey = raw
   106  	}
   107  
   108  	if enc == nil {
   109  		switch b.alg {
   110  		case jwa.RSA1_5:
   111  			var pubkey rsa.PublicKey
   112  			if err := keyconv.RSAPublicKey(&pubkey, rawKey); err != nil {
   113  				return nil, nil, fmt.Errorf(`failed to generate public key from key (%T): %w`, rawKey, err)
   114  			}
   115  
   116  			v, err := keyenc.NewRSAPKCSEncrypt(b.alg, &pubkey)
   117  			if err != nil {
   118  				return nil, nil, fmt.Errorf(`failed to create RSA PKCS encrypter: %w`, err)
   119  			}
   120  			enc = v
   121  		case jwa.RSA_OAEP, jwa.RSA_OAEP_256:
   122  			var pubkey rsa.PublicKey
   123  			if err := keyconv.RSAPublicKey(&pubkey, rawKey); err != nil {
   124  				return nil, nil, fmt.Errorf(`failed to generate public key from key (%T): %w`, rawKey, err)
   125  			}
   126  
   127  			v, err := keyenc.NewRSAOAEPEncrypt(b.alg, &pubkey)
   128  			if err != nil {
   129  				return nil, nil, fmt.Errorf(`failed to create RSA OAEP encrypter: %w`, err)
   130  			}
   131  			enc = v
   132  		case jwa.A128KW, jwa.A192KW, jwa.A256KW,
   133  			jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW,
   134  			jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
   135  			sharedkey, ok := rawKey.([]byte)
   136  			if !ok {
   137  				return nil, nil, fmt.Errorf(`invalid key: []byte required (%T)`, rawKey)
   138  			}
   139  
   140  			var err error
   141  			switch b.alg {
   142  			case jwa.A128KW, jwa.A192KW, jwa.A256KW:
   143  				enc, err = keyenc.NewAES(b.alg, sharedkey)
   144  			case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
   145  				enc, err = keyenc.NewPBES2Encrypt(b.alg, sharedkey)
   146  			default:
   147  				enc, err = keyenc.NewAESGCMEncrypt(b.alg, sharedkey)
   148  			}
   149  			if err != nil {
   150  				return nil, nil, fmt.Errorf(`failed to create key wrap encrypter: %w`, err)
   151  			}
   152  			// NOTE: there was formerly a restriction, introduced
   153  			// in PR #26, which disallowed certain key/content
   154  			// algorithm combinations. This seemed bogus, and
   155  			// interop with the jose tool demonstrates it.
   156  		case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
   157  			var keysize int
   158  			switch b.alg {
   159  			case jwa.ECDH_ES:
   160  				// https://tools.ietf.org/html/rfc7518#page-15
   161  				// In Direct Key Agreement mode, the output of the Concat KDF MUST be a
   162  				// key of the same length as that used by the "enc" algorithm.
   163  				keysize = cc.KeySize()
   164  			case jwa.ECDH_ES_A128KW:
   165  				keysize = 16
   166  			case jwa.ECDH_ES_A192KW:
   167  				keysize = 24
   168  			case jwa.ECDH_ES_A256KW:
   169  				keysize = 32
   170  			}
   171  
   172  			switch key := rawKey.(type) {
   173  			case x25519.PublicKey:
   174  				var apu, apv []byte
   175  				if hdrs := b.headers; hdrs != nil {
   176  					apu = hdrs.AgreementPartyUInfo()
   177  					apv = hdrs.AgreementPartyVInfo()
   178  				}
   179  
   180  				v, err := keyenc.NewECDHESEncrypt(b.alg, calg, keysize, rawKey, apu, apv)
   181  				if err != nil {
   182  					return nil, nil, fmt.Errorf(`failed to create ECDHS key wrap encrypter: %w`, err)
   183  				}
   184  				enc = v
   185  			default:
   186  				var pubkey ecdsa.PublicKey
   187  				if err := keyconv.ECDSAPublicKey(&pubkey, rawKey); err != nil {
   188  					return nil, nil, fmt.Errorf(`failed to generate public key from key (%T): %w`, key, err)
   189  				}
   190  
   191  				var apu, apv []byte
   192  				if hdrs := b.headers; hdrs != nil {
   193  					apu = hdrs.AgreementPartyUInfo()
   194  					apv = hdrs.AgreementPartyVInfo()
   195  				}
   196  
   197  				v, err := keyenc.NewECDHESEncrypt(b.alg, calg, keysize, &pubkey, apu, apv)
   198  				if err != nil {
   199  					return nil, nil, fmt.Errorf(`failed to create ECDHS key wrap encrypter: %w`, err)
   200  				}
   201  				enc = v
   202  			}
   203  		case jwa.DIRECT:
   204  			sharedkey, ok := rawKey.([]byte)
   205  			if !ok {
   206  				return nil, nil, fmt.Errorf("invalid key: []byte required")
   207  			}
   208  			enc, _ = keyenc.NewNoop(b.alg, sharedkey)
   209  		default:
   210  			return nil, nil, fmt.Errorf(`invalid key encryption algorithm (%s)`, b.alg)
   211  		}
   212  	}
   213  
   214  	r := NewRecipient()
   215  	if hdrs := b.headers; hdrs != nil {
   216  		_ = r.SetHeaders(hdrs)
   217  	}
   218  
   219  	if err := r.Headers().Set(AlgorithmKey, b.alg); err != nil {
   220  		return nil, nil, fmt.Errorf(`failed to set header: %w`, err)
   221  	}
   222  
   223  	if keyID != "" {
   224  		if err := r.Headers().Set(KeyIDKey, keyID); err != nil {
   225  			return nil, nil, fmt.Errorf(`failed to set header: %w`, err)
   226  		}
   227  	}
   228  
   229  	var rawCEK []byte
   230  	enckey, err := enc.EncryptKey(cek)
   231  	if err != nil {
   232  		return nil, nil, fmt.Errorf(`failed to encrypt key: %w`, err)
   233  	}
   234  	if enc.Algorithm() == jwa.ECDH_ES || enc.Algorithm() == jwa.DIRECT {
   235  		rawCEK = enckey.Bytes()
   236  	} else {
   237  		if err := r.SetEncryptedKey(enckey.Bytes()); err != nil {
   238  			return nil, nil, fmt.Errorf(`failed to set encrypted key: %w`, err)
   239  		}
   240  	}
   241  
   242  	if hp, ok := enckey.(populater); ok {
   243  		if err := hp.Populate(r.Headers()); err != nil {
   244  			return nil, nil, fmt.Errorf(`failed to populate: %w`, err)
   245  		}
   246  	}
   247  
   248  	return r, rawCEK, nil
   249  }
   250  
   251  // Encrypt generates a JWE message for the given payload and returns
   252  // it in serialized form, which can be in either compact or
   253  // JSON format. Default is compact.
   254  //
   255  // You must pass at least one key to `jwe.Encrypt()` by using `jwe.WithKey()`
   256  // option.
   257  //
   258  //	jwe.Encrypt(payload, jwe.WithKey(alg, key))
   259  //	jwe.Encrypt(payload, jws.WithJSON(), jws.WithKey(alg1, key1), jws.WithKey(alg2, key2))
   260  //
   261  // Note that in the second example the `jws.WithJSON()` option is
   262  // specified as well. This is because the compact serialization
   263  // format does not support multiple recipients, and users must
   264  // specifically ask for the JSON serialization format.
   265  //
   266  // Read the documentation for `jwe.WithKey()` to learn more about the
   267  // possible values that can be used for `alg` and `key`.
   268  //
   269  // Look for options that return `jwe.EncryptOption` or `jws.EncryptDecryptOption`
   270  // for a complete list of options that can be passed to this function.
   271  func Encrypt(payload []byte, options ...EncryptOption) ([]byte, error) {
   272  	return encrypt(payload, nil, options...)
   273  }
   274  
   275  // Encryptstatic is exactly like Encrypt, except it accepts a static
   276  // content encryption key (CEK). It is separated out from the main
   277  // Encrypt function such that the latter does not accidentally use a static
   278  // CEK.
   279  //
   280  // DO NOT attempt to use this function unless you completely understand the
   281  // security implications to using static CEKs. You have been warned.
   282  //
   283  // This function is currently considered EXPERIMENTAL, and is subject to
   284  // future changes across minor/micro versions.
   285  func EncryptStatic(payload, cek []byte, options ...EncryptOption) ([]byte, error) {
   286  	if len(cek) <= 0 {
   287  		return nil, fmt.Errorf(`jwe.EncryptStatic: empty CEK`)
   288  	}
   289  	return encrypt(payload, cek, options...)
   290  }
   291  
   292  // encrypt is separate so it can receive cek from outside.
   293  // (but we don't want to receive it in the options slice)
   294  func encrypt(payload, cek []byte, options ...EncryptOption) ([]byte, error) {
   295  	// default content encryption algorithm
   296  	calg := jwa.A256GCM
   297  
   298  	// default compression is "none"
   299  	compression := jwa.NoCompress
   300  
   301  	// default format is compact serialization
   302  	format := fmtCompact
   303  
   304  	// builds each "recipient" with encrypted_key and headers
   305  	var builders []*recipientBuilder
   306  
   307  	var protected Headers
   308  	var mergeProtected bool
   309  	var useRawCEK bool
   310  	for _, option := range options {
   311  		//nolint:forcetypeassert
   312  		switch option.Ident() {
   313  		case identKey{}:
   314  			data := option.Value().(*withKey)
   315  			v, ok := data.alg.(jwa.KeyEncryptionAlgorithm)
   316  			if !ok {
   317  				return nil, fmt.Errorf(`jwe.Encrypt: expected alg to be jwa.KeyEncryptionAlgorithm, but got %T`, data.alg)
   318  			}
   319  
   320  			switch v {
   321  			case jwa.DIRECT, jwa.ECDH_ES:
   322  				useRawCEK = true
   323  			}
   324  
   325  			builders = append(builders, &recipientBuilder{
   326  				alg:     v,
   327  				key:     data.key,
   328  				headers: data.headers,
   329  			})
   330  		case identContentEncryptionAlgorithm{}:
   331  			calg = option.Value().(jwa.ContentEncryptionAlgorithm)
   332  		case identCompress{}:
   333  			compression = option.Value().(jwa.CompressionAlgorithm)
   334  		case identMergeProtectedHeaders{}:
   335  			mergeProtected = option.Value().(bool)
   336  		case identProtectedHeaders{}:
   337  			v := option.Value().(Headers)
   338  			if !mergeProtected || protected == nil {
   339  				protected = v
   340  			} else {
   341  				ctx := context.TODO()
   342  				merged, err := protected.Merge(ctx, v)
   343  				if err != nil {
   344  					return nil, fmt.Errorf(`jwe.Encrypt: failed to merge headers: %w`, err)
   345  				}
   346  				protected = merged
   347  			}
   348  		case identSerialization{}:
   349  			format = option.Value().(int)
   350  		}
   351  	}
   352  
   353  	// We need to have at least one builder
   354  	switch l := len(builders); {
   355  	case l == 0:
   356  		return nil, fmt.Errorf(`jwe.Encrypt: missing key encryption builders: use jwe.WithKey() to specify one`)
   357  	case l > 1:
   358  		if format == fmtCompact {
   359  			return nil, fmt.Errorf(`jwe.Encrypt: cannot use compact serialization when multiple recipients exist (check the number of WithKey() argument, or use WithJSON())`)
   360  		}
   361  	}
   362  
   363  	if useRawCEK {
   364  		if len(builders) != 1 {
   365  			return nil, fmt.Errorf(`jwe.Encrypt: multiple recipients for ECDH-ES/DIRECT mode supported`)
   366  		}
   367  	}
   368  
   369  	// There is exactly one content encrypter.
   370  	contentcrypt, err := content_crypt.NewGeneric(calg)
   371  	if err != nil {
   372  		return nil, fmt.Errorf(`jwe.Encrypt: failed to create AES encrypter: %w`, err)
   373  	}
   374  
   375  	if len(cek) <= 0 {
   376  		generator := keygen.NewRandom(contentcrypt.KeySize())
   377  		bk, err := generator.Generate()
   378  		if err != nil {
   379  			return nil, fmt.Errorf(`jwe.Encrypt: failed to generate key: %w`, err)
   380  		}
   381  		cek = bk.Bytes()
   382  	}
   383  
   384  	recipients := make([]Recipient, len(builders))
   385  	for i, builder := range builders {
   386  		// some builders require hint from the contentcrypt object
   387  		r, rawCEK, err := builder.Build(cek, calg, contentcrypt)
   388  		if err != nil {
   389  			return nil, fmt.Errorf(`jwe.Encrypt: failed to create recipient #%d: %w`, i, err)
   390  		}
   391  		recipients[i] = r
   392  
   393  		// Kinda feels weird, but if useRawCEK == true, we asserted earlier
   394  		// that len(builders) == 1, so this is OK
   395  		if useRawCEK {
   396  			cek = rawCEK
   397  		}
   398  	}
   399  
   400  	if protected == nil {
   401  		protected = NewHeaders()
   402  	}
   403  
   404  	if err := protected.Set(ContentEncryptionKey, calg); err != nil {
   405  		return nil, fmt.Errorf(`jwe.Encrypt: failed to set "enc" in protected header: %w`, err)
   406  	}
   407  
   408  	if compression != jwa.NoCompress {
   409  		payload, err = compress(payload)
   410  		if err != nil {
   411  			return nil, fmt.Errorf(`jwe.Encrypt: failed to compress payload before encryption: %w`, err)
   412  		}
   413  		if err := protected.Set(CompressionKey, compression); err != nil {
   414  			return nil, fmt.Errorf(`jwe.Encrypt: failed to set "zip" in protected header: %w`, err)
   415  		}
   416  	}
   417  
   418  	// If there's only one recipient, you want to include that in the
   419  	// protected header
   420  	if len(recipients) == 1 {
   421  		h, err := protected.Merge(context.TODO(), recipients[0].Headers())
   422  		if err != nil {
   423  			return nil, fmt.Errorf(`jwe.Encrypt: failed to merge protected headers: %w`, err)
   424  		}
   425  		protected = h
   426  	}
   427  
   428  	aad, err := protected.Encode()
   429  	if err != nil {
   430  		return nil, fmt.Errorf(`failed to base64 encode protected headers: %w`, err)
   431  	}
   432  
   433  	iv, ciphertext, tag, err := contentcrypt.Encrypt(cek, payload, aad)
   434  	if err != nil {
   435  		return nil, fmt.Errorf(`failed to encrypt payload: %w`, err)
   436  	}
   437  
   438  	msg := NewMessage()
   439  
   440  	if err := msg.Set(CipherTextKey, ciphertext); err != nil {
   441  		return nil, fmt.Errorf(`failed to set %s: %w`, CipherTextKey, err)
   442  	}
   443  	if err := msg.Set(InitializationVectorKey, iv); err != nil {
   444  		return nil, fmt.Errorf(`failed to set %s: %w`, InitializationVectorKey, err)
   445  	}
   446  	if err := msg.Set(ProtectedHeadersKey, protected); err != nil {
   447  		return nil, fmt.Errorf(`failed to set %s: %w`, ProtectedHeadersKey, err)
   448  	}
   449  	if err := msg.Set(RecipientsKey, recipients); err != nil {
   450  		return nil, fmt.Errorf(`failed to set %s: %w`, RecipientsKey, err)
   451  	}
   452  	if err := msg.Set(TagKey, tag); err != nil {
   453  		return nil, fmt.Errorf(`failed to set %s: %w`, TagKey, err)
   454  	}
   455  
   456  	switch format {
   457  	case fmtCompact:
   458  		return Compact(msg)
   459  	case fmtJSON:
   460  		return json.Marshal(msg)
   461  	case fmtJSONPretty:
   462  		return json.MarshalIndent(msg, "", "  ")
   463  	default:
   464  		return nil, fmt.Errorf(`jwe.Encrypt: invalid serialization`)
   465  	}
   466  }
   467  
   468  type decryptCtx struct {
   469  	msg                     *Message
   470  	aad                     []byte
   471  	cek                     *[]byte
   472  	computedAad             []byte
   473  	keyProviders            []KeyProvider
   474  	protectedHeaders        Headers
   475  	maxDecompressBufferSize int64
   476  }
   477  
   478  // Decrypt takes encrypted payload, and information required to decrypt the
   479  // payload (e.g. the key encryption algorithm and the corresponding
   480  // key to decrypt the JWE message) in its optional arguments. See
   481  // the examples and list of options that return a DecryptOption for possible
   482  // values. Upon successful decryptiond returns the decrypted payload.
   483  //
   484  // The JWE message can be either compact or full JSON format.
   485  //
   486  // When using `jwe.WithKeyEncryptionAlgorithm()`, you can pass a `jwa.KeyAlgorithm`
   487  // for convenience: this is mainly to allow you to directly pass the result of `(jwk.Key).Algorithm()`.
   488  // However, do note that while `(jwk.Key).Algorithm()` could very well contain key encryption
   489  // algorithms, it could also contain other types of values, such as _signature algorithms_.
   490  // In order for `jwe.Decrypt` to work properly, the `alg` parameter must be of type
   491  // `jwa.KeyEncryptionAlgorithm` or otherwise it will cause an error.
   492  //
   493  // When using `jwe.WithKey()`, the value must be a private key.
   494  // It can be either in its raw format (e.g. *rsa.PrivateKey) or a jwk.Key
   495  //
   496  // When the encrypted message is also compressed, the decompressed payload must be
   497  // smaller than the size specified by the `jwe.WithMaxDecompressBufferSize` setting,
   498  // which defaults to 10MB. If the decompressed payload is larger than this size,
   499  // an error is returned.
   500  //
   501  // You can opt to change the MaxDecompressBufferSize setting globally, or on a
   502  // per-call basis by passing the `jwe.WithMaxDecompressBufferSize` option to
   503  // either `jwe.Settings()` or `jwe.Decrypt()`:
   504  //
   505  //	jwe.Settings(jwe.WithMaxDecompressBufferSize(10*1024*1024)) // changes value globally
   506  //	jwe.Decrypt(..., jwe.WithMaxDecompressBufferSize(250*1024)) // changes just for this call
   507  func Decrypt(buf []byte, options ...DecryptOption) ([]byte, error) {
   508  	var keyProviders []KeyProvider
   509  	var keyUsed interface{}
   510  	var cek *[]byte
   511  	var dst *Message
   512  	perCallMaxDecompressBufferSize := maxDecompressBufferSize
   513  	//nolint:forcetypeassert
   514  	for _, option := range options {
   515  		switch option.Ident() {
   516  		case identMessage{}:
   517  			dst = option.Value().(*Message)
   518  		case identKeyProvider{}:
   519  			keyProviders = append(keyProviders, option.Value().(KeyProvider))
   520  		case identKeyUsed{}:
   521  			keyUsed = option.Value()
   522  		case identKey{}:
   523  			pair := option.Value().(*withKey)
   524  			alg, ok := pair.alg.(jwa.KeyEncryptionAlgorithm)
   525  			if !ok {
   526  				return nil, fmt.Errorf(`WithKey() option must be specified using jwa.KeyEncryptionAlgorithm (got %T)`, pair.alg)
   527  			}
   528  			keyProviders = append(keyProviders, &staticKeyProvider{
   529  				alg: alg,
   530  				key: pair.key,
   531  			})
   532  		case identCEK{}:
   533  			cek = option.Value().(*[]byte)
   534  		case identMaxDecompressBufferSize{}:
   535  			perCallMaxDecompressBufferSize = option.Value().(int64)
   536  		}
   537  	}
   538  
   539  	if len(keyProviders) < 1 {
   540  		return nil, fmt.Errorf(`jwe.Decrypt: no key providers have been provided (see jwe.WithKey(), jwe.WithKeySet(), and jwe.WithKeyProvider()`)
   541  	}
   542  
   543  	msg, err := parseJSONOrCompact(buf, true)
   544  	if err != nil {
   545  		return nil, fmt.Errorf(`failed to parse buffer for Decrypt: %w`, err)
   546  	}
   547  
   548  	// Process things that are common to the message
   549  	ctx := context.TODO()
   550  	h, err := msg.protectedHeaders.Clone(ctx)
   551  	if err != nil {
   552  		return nil, fmt.Errorf(`failed to copy protected headers: %w`, err)
   553  	}
   554  	h, err = h.Merge(ctx, msg.unprotectedHeaders)
   555  	if err != nil {
   556  		return nil, fmt.Errorf(`failed to merge headers for message decryption: %w`, err)
   557  	}
   558  
   559  	var aad []byte
   560  	if aadContainer := msg.authenticatedData; aadContainer != nil {
   561  		aad = base64.Encode(aadContainer)
   562  	}
   563  
   564  	var computedAad []byte
   565  	if len(msg.rawProtectedHeaders) > 0 {
   566  		computedAad = msg.rawProtectedHeaders
   567  	} else {
   568  		// this is probably not required once msg.Decrypt is deprecated
   569  		var err error
   570  		computedAad, err = msg.protectedHeaders.Encode()
   571  		if err != nil {
   572  			return nil, fmt.Errorf(`failed to encode protected headers: %w`, err)
   573  		}
   574  	}
   575  
   576  	// for each recipient, attempt to match the key providers
   577  	// if we have no recipients, pretend like we only have one
   578  	recipients := msg.recipients
   579  	if len(recipients) == 0 {
   580  		r := NewRecipient()
   581  		if err := r.SetHeaders(msg.protectedHeaders); err != nil {
   582  			return nil, fmt.Errorf(`failed to set headers to recipient: %w`, err)
   583  		}
   584  		recipients = append(recipients, r)
   585  	}
   586  
   587  	var dctx decryptCtx
   588  
   589  	dctx.aad = aad
   590  	dctx.computedAad = computedAad
   591  	dctx.msg = msg
   592  	dctx.keyProviders = keyProviders
   593  	dctx.protectedHeaders = h
   594  	dctx.cek = cek
   595  	dctx.maxDecompressBufferSize = perCallMaxDecompressBufferSize
   596  
   597  	var lastError error
   598  	for _, recipient := range recipients {
   599  		decrypted, err := dctx.try(ctx, recipient, keyUsed)
   600  		if err != nil {
   601  			lastError = err
   602  			continue
   603  		}
   604  		if dst != nil {
   605  			*dst = *msg
   606  			dst.rawProtectedHeaders = nil
   607  			dst.storeProtectedHeaders = false
   608  		}
   609  		return decrypted, nil
   610  	}
   611  	return nil, fmt.Errorf(`jwe.Decrypt: failed to decrypt any of the recipients (last error = %w)`, lastError)
   612  }
   613  
   614  func (dctx *decryptCtx) try(ctx context.Context, recipient Recipient, keyUsed interface{}) ([]byte, error) {
   615  	var tried int
   616  	var lastError error
   617  	for i, kp := range dctx.keyProviders {
   618  		var sink algKeySink
   619  		if err := kp.FetchKeys(ctx, &sink, recipient, dctx.msg); err != nil {
   620  			return nil, fmt.Errorf(`key provider %d failed: %w`, i, err)
   621  		}
   622  
   623  		for _, pair := range sink.list {
   624  			tried++
   625  			// alg is converted here because pair.alg is of type jwa.KeyAlgorithm.
   626  			// this may seem ugly, but we're trying to avoid declaring separate
   627  			// structs for `alg jwa.KeyAlgorithm` and `alg jwa.SignatureAlgorithm`
   628  			//nolint:forcetypeassert
   629  			alg := pair.alg.(jwa.KeyEncryptionAlgorithm)
   630  			key := pair.key
   631  
   632  			decrypted, err := dctx.decryptContent(ctx, alg, key, recipient)
   633  			if err != nil {
   634  				lastError = err
   635  				continue
   636  			}
   637  
   638  			if keyUsed != nil {
   639  				if err := blackmagic.AssignIfCompatible(keyUsed, key); err != nil {
   640  					return nil, fmt.Errorf(`failed to assign used key (%T) to %T: %w`, key, keyUsed, err)
   641  				}
   642  			}
   643  			return decrypted, nil
   644  		}
   645  	}
   646  	return nil, fmt.Errorf(`jwe.Decrypt: tried %d keys, but failed to match any of the keys with recipient (last error = %s)`, tried, lastError)
   647  }
   648  
   649  func (dctx *decryptCtx) decryptContent(ctx context.Context, alg jwa.KeyEncryptionAlgorithm, key interface{}, recipient Recipient) ([]byte, error) {
   650  	if jwkKey, ok := key.(jwk.Key); ok {
   651  		var raw interface{}
   652  		if err := jwkKey.Raw(&raw); err != nil {
   653  			return nil, fmt.Errorf(`failed to retrieve raw key from %T: %w`, key, err)
   654  		}
   655  		key = raw
   656  	}
   657  
   658  	dec := newDecrypter(alg, dctx.msg.protectedHeaders.ContentEncryption(), key).
   659  		AuthenticatedData(dctx.aad).
   660  		ComputedAuthenticatedData(dctx.computedAad).
   661  		InitializationVector(dctx.msg.initializationVector).
   662  		Tag(dctx.msg.tag).
   663  		CEK(dctx.cek)
   664  
   665  	if recipient.Headers().Algorithm() != alg {
   666  		// algorithms don't match
   667  		return nil, fmt.Errorf(`jwe.Decrypt: key and recipient algorithms do not match`)
   668  	}
   669  
   670  	h2, err := dctx.protectedHeaders.Clone(ctx)
   671  	if err != nil {
   672  		return nil, fmt.Errorf(`jwe.Decrypt: failed to copy headers (1): %w`, err)
   673  	}
   674  
   675  	h2, err = h2.Merge(ctx, recipient.Headers())
   676  	if err != nil {
   677  		return nil, fmt.Errorf(`failed to copy headers (2): %w`, err)
   678  	}
   679  
   680  	switch alg {
   681  	case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
   682  		epkif, ok := h2.Get(EphemeralPublicKeyKey)
   683  		if !ok {
   684  			return nil, fmt.Errorf(`failed to get 'epk' field`)
   685  		}
   686  		switch epk := epkif.(type) {
   687  		case jwk.ECDSAPublicKey:
   688  			var pubkey ecdsa.PublicKey
   689  			if err := epk.Raw(&pubkey); err != nil {
   690  				return nil, fmt.Errorf(`failed to get public key: %w`, err)
   691  			}
   692  			dec.PublicKey(&pubkey)
   693  		case jwk.OKPPublicKey:
   694  			var pubkey interface{}
   695  			if err := epk.Raw(&pubkey); err != nil {
   696  				return nil, fmt.Errorf(`failed to get public key: %w`, err)
   697  			}
   698  			dec.PublicKey(pubkey)
   699  		default:
   700  			return nil, fmt.Errorf("unexpected 'epk' type %T for alg %s", epkif, alg)
   701  		}
   702  
   703  		if apu := h2.AgreementPartyUInfo(); len(apu) > 0 {
   704  			dec.AgreementPartyUInfo(apu)
   705  		}
   706  		if apv := h2.AgreementPartyVInfo(); len(apv) > 0 {
   707  			dec.AgreementPartyVInfo(apv)
   708  		}
   709  	case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW:
   710  		ivB64, ok := h2.Get(InitializationVectorKey)
   711  		if ok {
   712  			ivB64Str, ok := ivB64.(string)
   713  			if !ok {
   714  				return nil, fmt.Errorf("unexpected type for 'iv': %T", ivB64)
   715  			}
   716  			iv, err := base64.DecodeString(ivB64Str)
   717  			if err != nil {
   718  				return nil, fmt.Errorf(`failed to b64-decode 'iv': %w`, err)
   719  			}
   720  			dec.KeyInitializationVector(iv)
   721  		}
   722  		tagB64, ok := h2.Get(TagKey)
   723  		if ok {
   724  			tagB64Str, ok := tagB64.(string)
   725  			if !ok {
   726  				return nil, fmt.Errorf("unexpected type for 'tag': %T", tagB64)
   727  			}
   728  			tag, err := base64.DecodeString(tagB64Str)
   729  			if err != nil {
   730  				return nil, fmt.Errorf(`failed to b64-decode 'tag': %w`, err)
   731  			}
   732  			dec.KeyTag(tag)
   733  		}
   734  	case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
   735  		saltB64, ok := h2.Get(SaltKey)
   736  		if !ok {
   737  			return nil, fmt.Errorf(`failed to get 'p2s' field`)
   738  		}
   739  		saltB64Str, ok := saltB64.(string)
   740  		if !ok {
   741  			return nil, fmt.Errorf("unexpected type for 'p2s': %T", saltB64)
   742  		}
   743  
   744  		count, ok := h2.Get(CountKey)
   745  		if !ok {
   746  			return nil, fmt.Errorf(`failed to get 'p2c' field`)
   747  		}
   748  		countFlt, ok := count.(float64)
   749  		if !ok {
   750  			return nil, fmt.Errorf("unexpected type for 'p2c': %T", count)
   751  		}
   752  		muSettings.RLock()
   753  		maxCount := maxPBES2Count
   754  		muSettings.RUnlock()
   755  		if countFlt > float64(maxCount) {
   756  			return nil, fmt.Errorf("invalid 'p2c' value")
   757  		}
   758  		salt, err := base64.DecodeString(saltB64Str)
   759  		if err != nil {
   760  			return nil, fmt.Errorf(`failed to b64-decode 'salt': %w`, err)
   761  		}
   762  		dec.KeySalt(salt)
   763  		dec.KeyCount(int(countFlt))
   764  	}
   765  
   766  	plaintext, err := dec.Decrypt(recipient, dctx.msg.cipherText, dctx.msg)
   767  	if err != nil {
   768  		return nil, fmt.Errorf(`jwe.Decrypt: decryption failed: %w`, err)
   769  	}
   770  
   771  	if h2.Compression() == jwa.Deflate {
   772  		buf, err := uncompress(plaintext, dctx.maxDecompressBufferSize)
   773  		if err != nil {
   774  			return nil, fmt.Errorf(`jwe.Derypt: failed to uncompress payload: %w`, err)
   775  		}
   776  		plaintext = buf
   777  	}
   778  
   779  	if plaintext == nil {
   780  		return nil, fmt.Errorf(`failed to find matching recipient`)
   781  	}
   782  
   783  	return plaintext, nil
   784  }
   785  
   786  // Parse parses the JWE message into a Message object. The JWE message
   787  // can be either compact or full JSON format.
   788  //
   789  // Parse() currently does not take any options, but the API accepts it
   790  // in anticipation of future addition.
   791  func Parse(buf []byte, _ ...ParseOption) (*Message, error) {
   792  	return parseJSONOrCompact(buf, false)
   793  }
   794  
   795  func parseJSONOrCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) {
   796  	buf = bytes.TrimSpace(buf)
   797  	if len(buf) == 0 {
   798  		return nil, fmt.Errorf(`empty buffer`)
   799  	}
   800  
   801  	if buf[0] == '{' {
   802  		return parseJSON(buf, storeProtectedHeaders)
   803  	}
   804  	return parseCompact(buf, storeProtectedHeaders)
   805  }
   806  
   807  // ParseString is the same as Parse, but takes a string.
   808  func ParseString(s string) (*Message, error) {
   809  	return Parse([]byte(s))
   810  }
   811  
   812  // ParseReader is the same as Parse, but takes an io.Reader.
   813  func ParseReader(src io.Reader) (*Message, error) {
   814  	buf, err := io.ReadAll(src)
   815  	if err != nil {
   816  		return nil, fmt.Errorf(`failed to read from io.Reader: %w`, err)
   817  	}
   818  	return Parse(buf)
   819  }
   820  
   821  func parseJSON(buf []byte, storeProtectedHeaders bool) (*Message, error) {
   822  	m := NewMessage()
   823  	m.storeProtectedHeaders = storeProtectedHeaders
   824  	if err := json.Unmarshal(buf, &m); err != nil {
   825  		return nil, fmt.Errorf(`failed to parse JSON: %w`, err)
   826  	}
   827  	return m, nil
   828  }
   829  
   830  func parseCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) {
   831  	parts := bytes.Split(buf, []byte{'.'})
   832  	if len(parts) != 5 {
   833  		return nil, fmt.Errorf(`compact JWE format must have five parts (%d)`, len(parts))
   834  	}
   835  
   836  	hdrbuf, err := base64.Decode(parts[0])
   837  	if err != nil {
   838  		return nil, fmt.Errorf(`failed to parse first part of compact form: %w`, err)
   839  	}
   840  
   841  	protected := NewHeaders()
   842  	if err := json.Unmarshal(hdrbuf, protected); err != nil {
   843  		return nil, fmt.Errorf(`failed to parse header JSON: %w`, err)
   844  	}
   845  
   846  	ivbuf, err := base64.Decode(parts[2])
   847  	if err != nil {
   848  		return nil, fmt.Errorf(`failed to base64 decode iv: %w`, err)
   849  	}
   850  
   851  	ctbuf, err := base64.Decode(parts[3])
   852  	if err != nil {
   853  		return nil, fmt.Errorf(`failed to base64 decode content: %w`, err)
   854  	}
   855  
   856  	tagbuf, err := base64.Decode(parts[4])
   857  	if err != nil {
   858  		return nil, fmt.Errorf(`failed to base64 decode tag: %w`, err)
   859  	}
   860  
   861  	m := NewMessage()
   862  	if err := m.Set(CipherTextKey, ctbuf); err != nil {
   863  		return nil, fmt.Errorf(`failed to set %s: %w`, CipherTextKey, err)
   864  	}
   865  	if err := m.Set(InitializationVectorKey, ivbuf); err != nil {
   866  		return nil, fmt.Errorf(`failed to set %s: %w`, InitializationVectorKey, err)
   867  	}
   868  	if err := m.Set(ProtectedHeadersKey, protected); err != nil {
   869  		return nil, fmt.Errorf(`failed to set %s: %w`, ProtectedHeadersKey, err)
   870  	}
   871  
   872  	if err := m.makeDummyRecipient(string(parts[1]), protected); err != nil {
   873  		return nil, fmt.Errorf(`failed to setup recipient: %w`, err)
   874  	}
   875  
   876  	if err := m.Set(TagKey, tagbuf); err != nil {
   877  		return nil, fmt.Errorf(`failed to set %s: %w`, TagKey, err)
   878  	}
   879  
   880  	if storeProtectedHeaders {
   881  		// This is later used for decryption.
   882  		m.rawProtectedHeaders = parts[0]
   883  	}
   884  
   885  	return m, nil
   886  }
   887  
   888  // RegisterCustomField allows users to specify that a private field
   889  // be decoded as an instance of the specified type. This option has
   890  // a global effect.
   891  //
   892  // For example, suppose you have a custom field `x-birthday`, which
   893  // you want to represent as a string formatted in RFC3339 in JSON,
   894  // but want it back as `time.Time`.
   895  //
   896  // In that case you would register a custom field as follows
   897  //
   898  //	jwe.RegisterCustomField(`x-birthday`, timeT)
   899  //
   900  // Then `hdr.Get("x-birthday")` will still return an `interface{}`,
   901  // but you can convert its type to `time.Time`
   902  //
   903  //	bdayif, _ := hdr.Get(`x-birthday`)
   904  //	bday := bdayif.(time.Time)
   905  func RegisterCustomField(name string, object interface{}) {
   906  	registry.Register(name, object)
   907  }