github.com/trustbloc/kms-go@v1.1.2/doc/jose/encrypter.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package jose
     8  
     9  import (
    10  	"bytes"
    11  	"crypto/ecdsa"
    12  	"crypto/elliptic"
    13  	"crypto/rand"
    14  	"crypto/sha256"
    15  	"encoding/base64"
    16  	"encoding/json"
    17  	"errors"
    18  	"fmt"
    19  	"math/big"
    20  	"sort"
    21  	"strings"
    22  
    23  	"github.com/go-jose/go-jose/v3"
    24  	hybrid "github.com/google/tink/go/hybrid/subtle"
    25  	"github.com/google/tink/go/keyset"
    26  	"github.com/google/tink/go/subtle/random"
    27  	"golang.org/x/crypto/curve25519"
    28  
    29  	"github.com/trustbloc/kms-go/crypto/tinkcrypto"
    30  	"github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/aead/subtle"
    31  	"github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite"
    32  	"github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite/api"
    33  	"github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite/ecdh"
    34  	ecdhpb "github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/proto/ecdh_aead_go_proto"
    35  	"github.com/trustbloc/kms-go/doc/jose/jwk"
    36  	"github.com/trustbloc/kms-go/util/cryptoutil"
    37  
    38  	cryptoapi "github.com/trustbloc/kms-go/spi/crypto"
    39  )
    40  
    41  // EncAlg represents the JWE content encryption algorithm.
    42  type EncAlg string
    43  
    44  const (
    45  	// A256GCM for AES256GCM content encryption.
    46  	A256GCM = EncAlg(A256GCMALG)
    47  	// XC20P for XChacha20Poly1305 content encryption.
    48  	XC20P = EncAlg(XC20PALG)
    49  	// A128CBCHS256 for A128CBC-HS256 (AES128-CBC+HMAC-SHA256) content encryption.
    50  	A128CBCHS256 = EncAlg(A128CBCHS256ALG)
    51  	// A192CBCHS384 for A192CBC-HS384 (AES192-CBC+HMAC-SHA384) content encryption.
    52  	A192CBCHS384 = EncAlg(A192CBCHS384ALG)
    53  	// A256CBCHS384 for A256CBC-HS384 (AES256-CBC+HMAC-SHA384) content encryption.
    54  	A256CBCHS384 = EncAlg(A256CBCHS384ALG)
    55  	// A256CBCHS512 for A256CBC-HS512 (AES256-CBC+HMAC-SHA512) content encryption.
    56  	A256CBCHS512 = EncAlg(A256CBCHS512ALG)
    57  )
    58  
    59  // Encrypter interface to Encrypt/Decrypt JWE messages.
    60  type Encrypter interface {
    61  	// EncryptWithAuthData encrypt plaintext and aad sent to more than 1 recipients and returns a valid
    62  	// JSONWebEncryption instance
    63  	EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error)
    64  
    65  	// Encrypt plaintext with empty aad sent to 1 or more recipients and returns a valid JSONWebEncryption instance
    66  	Encrypt(plaintext []byte) (*JSONWebEncryption, error)
    67  }
    68  
    69  // JWEEncrypt is responsible for encrypting a plaintext and its AAD into a protected JWE and decrypting it.
    70  type JWEEncrypt struct {
    71  	recipientsKeys []*cryptoapi.PublicKey
    72  	skid           string
    73  	senderKH       *keyset.Handle
    74  	encAlg         EncAlg
    75  	encTyp         string
    76  	cty            string
    77  	crypto         cryptoapi.Crypto
    78  }
    79  
    80  // NewJWEEncrypt creates a new JWEEncrypt instance to build JWE with recipientsPubKeys
    81  // senderKID and senderKH are used for Authcrypt (to authenticate the sender), if not set JWEEncrypt assumes Anoncrypt.
    82  func NewJWEEncrypt(encAlg EncAlg, envelopMediaType, cty, senderKID string, senderKH *keyset.Handle,
    83  	recipientsPubKeys []*cryptoapi.PublicKey, crypto cryptoapi.Crypto) (*JWEEncrypt, error) {
    84  	if len(recipientsPubKeys) == 0 {
    85  		return nil, fmt.Errorf("empty recipientsPubKeys list")
    86  	}
    87  
    88  	switch encAlg {
    89  	case A256GCM, XC20P, A128CBCHS256, A192CBCHS384, A256CBCHS384, A256CBCHS512:
    90  	default:
    91  		return nil, fmt.Errorf("encryption algorithm '%s' not supported", encAlg)
    92  	}
    93  
    94  	if crypto == nil {
    95  		return nil, errors.New("crypto service is required to create a JWEEncrypt instance")
    96  	}
    97  
    98  	if senderKH != nil {
    99  		// senderKID is required with non empty senderKH
   100  		if senderKID == "" {
   101  			return nil, errors.New("senderKID is required with senderKH")
   102  		}
   103  	}
   104  
   105  	return &JWEEncrypt{
   106  		recipientsKeys: recipientsPubKeys,
   107  		skid:           senderKID,
   108  		senderKH:       senderKH,
   109  		encAlg:         encAlg,
   110  		encTyp:         envelopMediaType,
   111  		cty:            cty,
   112  		crypto:         crypto,
   113  	}, nil
   114  }
   115  
   116  func (je *JWEEncrypt) getECDHEncPrimitive(cek []byte) (api.CompositeEncrypt, error) {
   117  	nistpKW := je.useNISTPKW()
   118  
   119  	encAlg, ok := aeadAlg[je.encAlg]
   120  	if !ok {
   121  		return nil, fmt.Errorf("getECDHEncPrimitive: encAlg not supported: '%v'", je.encAlg)
   122  	}
   123  
   124  	kt := ecdh.KeyTemplateForECDHPrimitiveWithCEK(cek, nistpKW, encAlg)
   125  
   126  	kh, err := keyset.NewHandle(kt)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	pubKH, err := kh.Public()
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	return ecdh.NewECDHEncrypt(pubKH)
   137  }
   138  
   139  // Encrypt encrypt plaintext with AAD and returns a JSONWebEncryption instance to serialize a JWE instance.
   140  func (je *JWEEncrypt) Encrypt(plaintext []byte) (*JSONWebEncryption, error) {
   141  	return je.EncryptWithAuthData(plaintext, nil)
   142  }
   143  
   144  // EncryptWithAuthData encrypt plaintext with AAD and returns a JSONWebEncryption instance to serialize a JWE instance.
   145  func (je *JWEEncrypt) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) {
   146  	protectedHeaders := map[string]interface{}{
   147  		HeaderEncryption: je.encAlg,
   148  		HeaderType:       je.encTyp,
   149  	}
   150  
   151  	je.addExtraProtectedHeaders(protectedHeaders)
   152  
   153  	cek := je.newCEK()
   154  
   155  	// creating the crypto primitive requires a pre-built cek
   156  	encPrimitive, err := je.getECDHEncPrimitive(cek)
   157  	if err != nil {
   158  		return nil, fmt.Errorf("jweencrypt: failed to get encryption primitive: %w", err)
   159  	}
   160  
   161  	authData, err := computeAuthData(protectedHeaders, "", aad)
   162  	if err != nil {
   163  		return nil, fmt.Errorf("jweencrypt: computeAuthData: marshal error %w", err)
   164  	}
   165  
   166  	if je.senderKH != nil && je.skid != "" {
   167  		// ecdh-1pu encryption requires CBC+HMAC encAlg types.
   168  		return je.encryptWithSender(encPrimitive, plaintext, authData, cek, aad)
   169  	}
   170  
   171  	return je.encrypt(protectedHeaders, encPrimitive, plaintext, authData, cek, aad)
   172  }
   173  
   174  func (je *JWEEncrypt) encrypt(protectedHeaders map[string]interface{}, encPrimitive api.CompositeEncrypt,
   175  	plaintext, authData, cek, aad []byte) (*JSONWebEncryption, error) {
   176  	recipients, singleRecipientHeaderADDs, err := je.wrapCEKForRecipients(cek, []byte{}, []byte{}, authData, json.Marshal)
   177  	if err != nil {
   178  		return nil, fmt.Errorf("jweencrypt: failed to wrap cek: %w", err)
   179  	}
   180  
   181  	if len(singleRecipientHeaderADDs) > 0 {
   182  		authData = singleRecipientHeaderADDs
   183  	}
   184  
   185  	recipientsHeaders, singleRecipientHeaders, err := je.buildRecs(recipients, false)
   186  	if err != nil {
   187  		return nil, fmt.Errorf("jweencrypt: failed to build recipients: %w", err)
   188  	}
   189  
   190  	serializedEncData, err := encPrimitive.Encrypt(plaintext, authData)
   191  	if err != nil {
   192  		return nil, fmt.Errorf("jweencrypt: failed to Encrypt: %w", err)
   193  	}
   194  
   195  	encData := new(composite.EncryptedData)
   196  
   197  	err = json.Unmarshal(serializedEncData, encData)
   198  	if err != nil {
   199  		return nil, fmt.Errorf("jweencrypt: unmarshal encrypted data failed: %w", err)
   200  	}
   201  
   202  	if singleRecipientHeaders != nil {
   203  		mergeRecipientHeaders(protectedHeaders, singleRecipientHeaders)
   204  	}
   205  
   206  	return getJSONWebEncryption(encData, recipientsHeaders, protectedHeaders, aad), nil
   207  }
   208  
   209  func (je *JWEEncrypt) encryptWithSender(primitive api.CompositeEncrypt,
   210  	plaintext, authData, cek, aad []byte) (*JSONWebEncryption, error) {
   211  	// pre-generate an EPK + compute apu and apv to be added to authData
   212  	apu, apv, err := je.buildAPUAPV()
   213  	if err != nil {
   214  		return nil, fmt.Errorf("jweencryptWithSender: %w", err)
   215  	}
   216  
   217  	// protectHeaders must be replaced for 1PU to match authData for encryption/decryption
   218  	// (including EPK, APU, APV always + alg and kid if single recipient).
   219  	epk, authData, newProtectedHeaders, err := je.generateEPKAndUpdateAuthDataFor1PU(authData, cek, apu, apv)
   220  	if err != nil {
   221  		return nil, fmt.Errorf("jweencryptWithSender: %w", err)
   222  	}
   223  
   224  	serializedEncData, err := primitive.Encrypt(plaintext, authData)
   225  	if err != nil {
   226  		return nil, fmt.Errorf("jweencryptWithSender: failed to Encrypt: %w", err)
   227  	}
   228  
   229  	encData := new(composite.EncryptedData)
   230  
   231  	err = json.Unmarshal(serializedEncData, encData)
   232  	if err != nil {
   233  		return nil, fmt.Errorf("jweencryptWithSender: unmarshal encrypted data failed: %w", err)
   234  	}
   235  
   236  	recipients, _, err := je.wrapCEKForRecipientsWithTagAndEPK(cek, apu, apv, authData,
   237  		encData.Tag, json.Marshal, epk)
   238  	if err != nil {
   239  		return nil, fmt.Errorf("jweencryptWithSender: failed to wrap cek: %w", err)
   240  	}
   241  
   242  	recipientsHeaders, _, err := je.buildRecs(recipients, true)
   243  	if err != nil {
   244  		return nil, fmt.Errorf("jweencryptWithSender: failed to build recipients: %w", err)
   245  	}
   246  
   247  	return getJSONWebEncryption(encData, recipientsHeaders, newProtectedHeaders, aad), nil
   248  }
   249  
   250  func getJSONWebEncryption(encData *composite.EncryptedData, recipientsHeaders []*Recipient,
   251  	protectedHeaders map[string]interface{}, aad []byte) *JSONWebEncryption {
   252  	return &JSONWebEncryption{
   253  		IV:               string(encData.IV),
   254  		Tag:              string(encData.Tag),
   255  		Ciphertext:       string(encData.Ciphertext),
   256  		Recipients:       recipientsHeaders,
   257  		ProtectedHeaders: protectedHeaders,
   258  		AAD:              string(aad),
   259  	}
   260  }
   261  
   262  func (je *JWEEncrypt) wrapCEKForRecipients(cek, apu, apv, aad []byte,
   263  	marshaller marshalFunc) ([]*cryptoapi.RecipientWrappedKey, []byte, error) {
   264  	return je.wrapCEKForRecipientsWithTagAndEPK(cek, apu, apv, aad, nil, marshaller, nil)
   265  }
   266  
   267  func (je *JWEEncrypt) wrapCEKForRecipientsWithTagAndEPK(cek, apu, apv, aad, tag []byte,
   268  	marshaller marshalFunc, epk *cryptoapi.PrivateKey) ([]*cryptoapi.RecipientWrappedKey, []byte, error) {
   269  	var (
   270  		computedAPU []byte
   271  		computedAPV []byte
   272  		err         error
   273  	)
   274  
   275  	if len(tag) > 0 {
   276  		// build apu/apv prior to key wrapping for 1PU only (tag not empty).
   277  		computedAPU, computedAPV, err = je.buildAPUAPV()
   278  		if err != nil {
   279  			return nil, nil, fmt.Errorf("wrapCEKForRecipientsWithTagAndEPK: %w", err)
   280  		}
   281  	}
   282  
   283  	if len(apv) == 0 {
   284  		apv = make([]byte, len(computedAPV))
   285  		copy(apv, computedAPV)
   286  	}
   287  
   288  	if len(apu) == 0 && je.skid != "" {
   289  		apu = make([]byte, len(computedAPU))
   290  		copy(apu, computedAPU)
   291  	}
   292  
   293  	wrapOpts := je.getWrapKeyOpts(tag, epk)
   294  
   295  	rw, kek, err := je.wrapKey(cek, apu, apv, aad, wrapOpts, marshaller)
   296  	if err != nil {
   297  		return nil, nil, fmt.Errorf("wrapCEKForRecipientsWithTagAndEPK: %w", err)
   298  	}
   299  
   300  	return rw, kek, nil
   301  }
   302  
   303  func (je *JWEEncrypt) wrapKey(cek, apu, apv, aad []byte, wrapOpts []cryptoapi.WrapKeyOpts,
   304  	marshaller marshalFunc) ([]*cryptoapi.RecipientWrappedKey, []byte, error) {
   305  	var (
   306  		recipientsWK       []*cryptoapi.RecipientWrappedKey
   307  		singleRecipientAAD []byte
   308  	)
   309  
   310  	for i, recPubKey := range je.recipientsKeys {
   311  		var (
   312  			kek *cryptoapi.RecipientWrappedKey
   313  			err error
   314  		)
   315  
   316  		if len(wrapOpts) > 0 {
   317  			kek, err = je.crypto.WrapKey(cek, apu, apv, recPubKey, wrapOpts...)
   318  		} else {
   319  			kek, err = je.crypto.WrapKey(cek, apu, apv, recPubKey)
   320  		}
   321  
   322  		if err != nil {
   323  			return nil, nil, fmt.Errorf("wrapKey: %d failed: %w", i+1, err)
   324  		}
   325  
   326  		je.encodeAPUAPV(kek)
   327  
   328  		recipientsWK = append(recipientsWK, kek)
   329  
   330  		if len(je.recipientsKeys) == 1 {
   331  			singleRecipientAAD, err = mergeSingleRecipientHeaders(kek, aad, marshaller)
   332  			if err != nil {
   333  				return nil, nil, fmt.Errorf("wrapKey: merge recipent headers failed for %d: %w", i+1, err)
   334  			}
   335  		}
   336  	}
   337  
   338  	return recipientsWK, singleRecipientAAD, nil
   339  }
   340  
   341  func (je *JWEEncrypt) encodeAPUAPV(kek *cryptoapi.RecipientWrappedKey) {
   342  	// APU and APV must be base64URL encoded.
   343  	if len(kek.APU) > 0 {
   344  		apuBytes := make([]byte, len(kek.APU))
   345  		copy(apuBytes, kek.APU)
   346  		kek.APU = make([]byte, base64.RawURLEncoding.EncodedLen(len(apuBytes)))
   347  		base64.RawURLEncoding.Encode(kek.APU, apuBytes)
   348  	}
   349  
   350  	if len(kek.APV) > 0 {
   351  		apvBytes := make([]byte, len(kek.APV))
   352  		copy(apvBytes, kek.APV)
   353  		kek.APV = make([]byte, base64.RawURLEncoding.EncodedLen(len(apvBytes)))
   354  		base64.RawURLEncoding.Encode(kek.APV, apvBytes)
   355  	}
   356  }
   357  
   358  func (je *JWEEncrypt) getWrapKeyOpts(tag []byte, epk *cryptoapi.PrivateKey) []cryptoapi.WrapKeyOpts {
   359  	var wrapOpts []cryptoapi.WrapKeyOpts
   360  
   361  	if je.recipientsKeys[0].Type == "OKP" {
   362  		wrapOpts = append(wrapOpts, cryptoapi.WithXC20PKW())
   363  	}
   364  
   365  	if je.skid != "" && je.senderKH != nil {
   366  		wrapOpts = append(wrapOpts, cryptoapi.WithSender(je.senderKH))
   367  	}
   368  
   369  	if len(tag) > 0 {
   370  		wrapOpts = append(wrapOpts, cryptoapi.WithTag(tag))
   371  	}
   372  
   373  	if epk != nil {
   374  		wrapOpts = append(wrapOpts, cryptoapi.WithEPK(epk))
   375  	}
   376  
   377  	return wrapOpts
   378  }
   379  
   380  // mergeSingleRecipientHeaders for single recipient encryption, recipient header info is available in the key, update
   381  // AAD with this info and return the marshalled merged result.
   382  func mergeSingleRecipientHeaders(recipientWK *cryptoapi.RecipientWrappedKey,
   383  	aad []byte, marshaller marshalFunc) ([]byte, error) {
   384  	var externalAAD []byte
   385  
   386  	aadIdx := len(aad)
   387  
   388  	if i := bytes.Index(aad, []byte(".")); i > 0 {
   389  		aadIdx = i
   390  		externalAAD = append(externalAAD, aad[aadIdx+1:]...)
   391  	}
   392  
   393  	newAAD, err := base64.RawURLEncoding.DecodeString(string(aad[:aadIdx]))
   394  	if err != nil {
   395  		return nil, err
   396  	}
   397  
   398  	rawHeaders := map[string]json.RawMessage{}
   399  
   400  	err = json.Unmarshal(newAAD, &rawHeaders)
   401  	if err != nil {
   402  		return nil, err
   403  	}
   404  
   405  	if recipientWK.KID != "" {
   406  		var kid []byte
   407  
   408  		kid, err = marshaller(recipientWK.KID)
   409  		if err != nil {
   410  			return nil, err
   411  		}
   412  
   413  		rawHeaders["kid"] = kid
   414  	}
   415  
   416  	alg, err := marshaller(recipientWK.Alg)
   417  	if err != nil {
   418  		return nil, err
   419  	}
   420  
   421  	rawHeaders["alg"] = alg
   422  
   423  	err = addKDFHeaders(rawHeaders, recipientWK, marshaller)
   424  	if err != nil {
   425  		return nil, err
   426  	}
   427  
   428  	mAAD, err := marshaller(rawHeaders)
   429  	if err != nil {
   430  		return nil, err
   431  	}
   432  
   433  	mAADStr := []byte(base64.RawURLEncoding.EncodeToString(mAAD))
   434  
   435  	if len(externalAAD) > 0 {
   436  		mAADStr = append(mAADStr, byte('.'))
   437  		mAADStr = append(mAADStr, externalAAD...)
   438  	}
   439  
   440  	return mAADStr, nil
   441  }
   442  
   443  func addKDFHeaders(rawHeaders map[string]json.RawMessage, recipientWK *cryptoapi.RecipientWrappedKey,
   444  	marshaller marshalFunc) error {
   445  	var err error
   446  
   447  	mEPK, err := convertRecEPKToMarshalledJWK(&recipientWK.EPK)
   448  	if err != nil {
   449  		return err
   450  	}
   451  
   452  	rawHeaders["epk"] = mEPK
   453  
   454  	if len(recipientWK.APU) != 0 {
   455  		rawHeaders["apu"], err = marshaller(fmt.Sprintf("%s", recipientWK.APU))
   456  		if err != nil {
   457  			return err
   458  		}
   459  	}
   460  
   461  	if len(recipientWK.APV) != 0 {
   462  		rawHeaders["apv"], err = marshaller(fmt.Sprintf("%s", recipientWK.APV))
   463  		if err != nil {
   464  			return err
   465  		}
   466  	}
   467  
   468  	return nil
   469  }
   470  
   471  func mergeRecipientHeaders(headers map[string]interface{}, recHeaders *RecipientHeaders) {
   472  	headers[HeaderAlgorithm] = recHeaders.Alg
   473  	if recHeaders.KID != "" {
   474  		headers[HeaderKeyID] = recHeaders.KID
   475  	}
   476  
   477  	// EPK, APU, APV will be marshalled by Serialize
   478  	headers[HeaderEPK] = recHeaders.EPK
   479  	if recHeaders.APU != "" {
   480  		headers["apu"] = base64.RawURLEncoding.EncodeToString([]byte(recHeaders.APU))
   481  	}
   482  
   483  	if recHeaders.APV != "" {
   484  		headers["apv"] = base64.RawURLEncoding.EncodeToString([]byte(recHeaders.APV))
   485  	}
   486  }
   487  
   488  func (je *JWEEncrypt) buildRecs(recWKs []*cryptoapi.RecipientWrappedKey,
   489  	forAuthcrypt bool) ([]*Recipient, *RecipientHeaders, error) {
   490  	var (
   491  		recipients             []*Recipient
   492  		singleRecipientHeaders *RecipientHeaders
   493  	)
   494  
   495  	for _, rec := range recWKs {
   496  		recHeaders, err := buildRecipientHeaders(rec, forAuthcrypt)
   497  		if err != nil {
   498  			return nil, nil, err
   499  		}
   500  
   501  		recipients = append(recipients, &Recipient{
   502  			EncryptedKey: string(rec.EncryptedCEK),
   503  			Header:       recHeaders,
   504  		})
   505  	}
   506  
   507  	// if we have only 1 recipient, then assume compact JWE serialization format. This means recipient header should
   508  	// be merged with the JWE envelope's protected headers and not added to the recipients
   509  	if len(recWKs) == 1 {
   510  		var (
   511  			decodedAPU []byte
   512  			decodedAPV []byte
   513  			err        error
   514  		)
   515  
   516  		decodedAPU, decodedAPV, err = decodeAPUAPV(recipients[0].Header)
   517  		if err != nil {
   518  			return nil, nil, err
   519  		}
   520  
   521  		singleRecipientHeaders = &RecipientHeaders{
   522  			Alg: recipients[0].Header.Alg,
   523  			KID: recipients[0].Header.KID,
   524  			EPK: recipients[0].Header.EPK,
   525  			APU: string(decodedAPU),
   526  			APV: string(decodedAPV),
   527  		}
   528  
   529  		recipients[0].Header = nil
   530  	}
   531  
   532  	return recipients, singleRecipientHeaders, nil
   533  }
   534  
   535  func (je *JWEEncrypt) addExtraProtectedHeaders(protectedHeaders map[string]interface{}) {
   536  	// set cty if it's not empty
   537  	if je.cty != "" {
   538  		protectedHeaders[HeaderContentType] = je.cty
   539  	}
   540  
   541  	// set skid if it's not empty
   542  	if je.skid != "" {
   543  		protectedHeaders[HeaderSenderKeyID] = je.skid
   544  	}
   545  }
   546  
   547  func (je *JWEEncrypt) useNISTPKW() bool {
   548  	if je.senderKH == nil {
   549  		return true
   550  	}
   551  
   552  	for _, ki := range je.senderKH.KeysetInfo().KeyInfo {
   553  		switch ki.TypeUrl {
   554  		case "type.hyperledger.org/hyperledger.aries.crypto.tink.NistPEcdhKwPublicKey",
   555  			"type.hyperledger.org/hyperledger.aries.crypto.tink.NistPEcdhKwPrivateKey":
   556  			return true
   557  		case "type.hyperledger.org/hyperledger.aries.crypto.tink.X25519EcdhKwPublicKey",
   558  			"type.hyperledger.org/hyperledger.aries.crypto.tink.X25519EcdhKwPrivateKey":
   559  			return false
   560  		}
   561  	}
   562  
   563  	return true
   564  }
   565  
   566  func (je *JWEEncrypt) newCEK() []byte {
   567  	twoKeys := 2
   568  	defKeySize := 32
   569  
   570  	switch je.encAlg {
   571  	case A256GCM, XC20P:
   572  		return random.GetRandomBytes(uint32(defKeySize))
   573  	case A128CBCHS256:
   574  		return random.GetRandomBytes(uint32(subtle.AES128Size * twoKeys)) // cek: 32 bytes.
   575  	case A192CBCHS384:
   576  		return random.GetRandomBytes(uint32(subtle.AES192Size * twoKeys)) // cek: 48 bytes.
   577  	case A256CBCHS384:
   578  		return random.GetRandomBytes(uint32(subtle.AES256Size + subtle.AES192Size)) // cek: 56 bytes.
   579  	case A256CBCHS512:
   580  		return random.GetRandomBytes(uint32(subtle.AES256Size * twoKeys)) // cek: 64 bytes.
   581  	default:
   582  		return random.GetRandomBytes(uint32(defKeySize)) // default cek: 32 bytes.
   583  	}
   584  }
   585  
   586  func (je *JWEEncrypt) buildAPUAPV() ([]byte, []byte, error) {
   587  	if je.skid == "" {
   588  		return nil, nil, fmt.Errorf("cannot create APU/APV with empty sender skid")
   589  	}
   590  
   591  	if len(je.recipientsKeys) == 0 {
   592  		return nil, nil, fmt.Errorf("cannot create APU/APV with empty recipient keys")
   593  	}
   594  
   595  	var recKIDs []string
   596  
   597  	apu := make([]byte, len(je.skid))
   598  	copy(apu, je.skid)
   599  
   600  	for _, r := range je.recipientsKeys {
   601  		recKIDs = append(recKIDs, r.KID)
   602  	}
   603  
   604  	// set recipients' sorted kids list then SHA256 hashed in apv.
   605  	sort.Strings(recKIDs)
   606  
   607  	apvList := []byte(strings.Join(recKIDs, "."))
   608  	apv32 := sha256.Sum256(apvList)
   609  	apv := make([]byte, 32)
   610  	copy(apv, apv32[:])
   611  
   612  	return apu, apv, nil
   613  }
   614  
   615  func (je *JWEEncrypt) generateEPKAndUpdateAuthDataFor1PU(auth,
   616  	cek, apu, apv []byte) (*cryptoapi.PrivateKey, []byte, map[string]interface{}, error) {
   617  	var epk *cryptoapi.PrivateKey
   618  
   619  	// generate an EPK based on the first recipient.
   620  	epk, kwAlg, err := je.newEPK(cek)
   621  	if err != nil {
   622  		return nil, nil, nil, fmt.Errorf("generateEPKAndUpdateAuthDataFor1PU: %w", err)
   623  	}
   624  
   625  	aadIndex := bytes.Index(auth, []byte("."))
   626  	lastIndex := aadIndex
   627  
   628  	if lastIndex < 0 {
   629  		lastIndex = len(auth)
   630  	}
   631  
   632  	return je.buildCommonAuthData(aadIndex, kwAlg, string(auth[:lastIndex]), auth, apu, apv, epk)
   633  }
   634  
   635  func (je *JWEEncrypt) buildCommonAuthData(aadIndex int, kwAlg, authData string, auth, apu, apv []byte,
   636  	epk *cryptoapi.PrivateKey) (*cryptoapi.PrivateKey, []byte, map[string]interface{}, error) {
   637  	authDataBytes, err := base64.RawURLEncoding.DecodeString(authData)
   638  	if err != nil {
   639  		return nil, nil, nil, fmt.Errorf("buildCommonAuthData: authdata decode: %w", err)
   640  	}
   641  
   642  	authDataJSON := map[string]interface{}{}
   643  
   644  	err = json.Unmarshal(authDataBytes, &authDataJSON)
   645  	if err != nil {
   646  		return nil, nil, nil, fmt.Errorf("buildCommonAuthData: authData unmarshall: %w", err)
   647  	}
   648  
   649  	if len(je.recipientsKeys) == 1 {
   650  		// kid is part of the protected headers for single recipient JWEs.
   651  		authDataJSON["kid"] = je.recipientsKeys[0].KID
   652  	}
   653  
   654  	authDataJSON["alg"] = kwAlg
   655  
   656  	marshalledEPK, err := convertRecEPKToMarshalledJWK(&epk.PublicKey)
   657  	if err != nil {
   658  		return nil, nil, nil, fmt.Errorf("buildCommonAuthData: epk marshall: %w", err)
   659  	}
   660  
   661  	authDataJSON["epk"] = json.RawMessage(marshalledEPK)
   662  
   663  	encodedAPU := []byte(base64.RawURLEncoding.EncodeToString(apu))
   664  	authDataJSON["apu"] = string(encodedAPU)
   665  
   666  	encodedAPV := []byte(base64.RawURLEncoding.EncodeToString(apv))
   667  	authDataJSON["apv"] = string(encodedAPV)
   668  
   669  	newAuth, err := json.Marshal(authDataJSON)
   670  	if err != nil {
   671  		return nil, nil, nil, fmt.Errorf("buildCommonAuthData: authData marshall: %w", err)
   672  	}
   673  
   674  	authData = base64.RawURLEncoding.EncodeToString(newAuth)
   675  
   676  	if aadIndex > 0 {
   677  		authData += string(auth[aadIndex:])
   678  	}
   679  
   680  	return epk, []byte(authData), authDataJSON, nil
   681  }
   682  
   683  func (je *JWEEncrypt) newEPK(cek []byte) (*cryptoapi.PrivateKey, string, error) {
   684  	var (
   685  		kwAlg string
   686  		epk   *cryptoapi.PrivateKey
   687  		err   error
   688  	)
   689  
   690  	switch je.recipientsKeys[0].Type {
   691  	case "EC":
   692  		epk, kwAlg, err = je.ecEPKAndAlg(cek)
   693  		if err != nil {
   694  			return nil, "", fmt.Errorf("newEPK: %w", err)
   695  		}
   696  	case "OKP":
   697  		epk, kwAlg, err = je.okpEPKAndAlg()
   698  		if err != nil {
   699  			return nil, "", fmt.Errorf("newEPK: %w", err)
   700  		}
   701  	default:
   702  		return nil, "", fmt.Errorf("newEPK: invalid key type '%v'", je.recipientsKeys[0].Type)
   703  	}
   704  
   705  	return epk, kwAlg, nil
   706  }
   707  
   708  func (je *JWEEncrypt) ecEPKAndAlg(cek []byte) (*cryptoapi.PrivateKey, string, error) {
   709  	var kwAlg string
   710  
   711  	curve, err := hybrid.GetCurve(je.recipientsKeys[0].Curve)
   712  	if err != nil {
   713  		return nil, "", fmt.Errorf("ecEPKAndAlg: getCurve: %w", err)
   714  	}
   715  
   716  	pk, err := ecdsa.GenerateKey(curve, rand.Reader)
   717  	if err != nil {
   718  		return nil, "", fmt.Errorf("ecEPKAndAlg: generate ec key: %w", err)
   719  	}
   720  
   721  	epk := &cryptoapi.PrivateKey{
   722  		PublicKey: cryptoapi.PublicKey{
   723  			Type:  "EC",
   724  			Curve: pk.Curve.Params().Name,
   725  			X:     pk.X.Bytes(),
   726  			Y:     pk.Y.Bytes(),
   727  		},
   728  		D: pk.D.Bytes(),
   729  	}
   730  
   731  	two := 2
   732  
   733  	switch len(cek) {
   734  	case subtle.AES128Size * two:
   735  		kwAlg = tinkcrypto.ECDH1PUA128KWAlg
   736  	case subtle.AES192Size * two:
   737  		kwAlg = tinkcrypto.ECDH1PUA192KWAlg
   738  	case subtle.AES256Size * two:
   739  		kwAlg = tinkcrypto.ECDH1PUA256KWAlg
   740  	}
   741  
   742  	return epk, kwAlg, nil
   743  }
   744  
   745  func (je *JWEEncrypt) okpEPKAndAlg() (*cryptoapi.PrivateKey, string, error) {
   746  	ephemeralPrivKey := make([]byte, cryptoutil.Curve25519KeySize)
   747  
   748  	_, err := rand.Read(ephemeralPrivKey)
   749  	if err != nil {
   750  		return nil, "", fmt.Errorf("okpEPKAndAlg: generate random key for OKP: %w", err)
   751  	}
   752  
   753  	ephemeralPubKey, err := curve25519.X25519(ephemeralPrivKey, curve25519.Basepoint)
   754  	if err != nil {
   755  		return nil, "", fmt.Errorf("okpEPKAndAlg: get public epk for OKP: %w", err)
   756  	}
   757  
   758  	kwAlg := tinkcrypto.ECDH1PUXC20PKWAlg
   759  
   760  	epk := &cryptoapi.PrivateKey{
   761  		PublicKey: cryptoapi.PublicKey{
   762  			Type:  "OKP",
   763  			Curve: "X25519",
   764  			X:     ephemeralPubKey,
   765  		},
   766  		D: ephemeralPrivKey,
   767  	}
   768  
   769  	return epk, kwAlg, nil
   770  }
   771  
   772  func decodeAPUAPV(headers *RecipientHeaders) ([]byte, []byte, error) {
   773  	var (
   774  		decodedAPU []byte
   775  		decodedAPV []byte
   776  		err        error
   777  	)
   778  
   779  	if len(headers.APU) > 0 {
   780  		decodedAPU, err = base64.RawURLEncoding.DecodeString(headers.APU)
   781  		if err != nil {
   782  			return nil, nil, err
   783  		}
   784  	}
   785  
   786  	if len(headers.APV) > 0 {
   787  		decodedAPV, err = base64.RawURLEncoding.DecodeString(headers.APV)
   788  		if err != nil {
   789  			return nil, nil, err
   790  		}
   791  	}
   792  
   793  	return decodedAPU, decodedAPV, nil
   794  }
   795  
   796  func buildRecipientHeaders(rec *cryptoapi.RecipientWrappedKey, forAuthcrypt bool) (*RecipientHeaders, error) {
   797  	mRecJWK, err := convertRecEPKToMarshalledJWK(&rec.EPK)
   798  	if err != nil {
   799  		return nil, fmt.Errorf("failed to convert recipient key to marshalled JWK: %w", err)
   800  	}
   801  
   802  	rh := &RecipientHeaders{
   803  		KID: rec.KID,
   804  	}
   805  
   806  	// authcrypt envelopes have these headers shared for all recipients in the protected headers.
   807  	if !forAuthcrypt {
   808  		rh.Alg = rec.Alg
   809  		rh.EPK = mRecJWK
   810  		rh.APU = string(rec.APU)
   811  		rh.APV = string(rec.APV)
   812  	}
   813  
   814  	return rh, nil
   815  }
   816  
   817  func convertRecEPKToMarshalledJWK(recEPK *cryptoapi.PublicKey) ([]byte, error) {
   818  	var (
   819  		c   elliptic.Curve
   820  		err error
   821  		key interface{}
   822  	)
   823  
   824  	switch recEPK.Type {
   825  	case ecdhpb.KeyType_EC.String():
   826  		c, err = hybrid.GetCurve(recEPK.Curve)
   827  		if err != nil {
   828  			return nil, err
   829  		}
   830  
   831  		key = &ecdsa.PublicKey{
   832  			Curve: c,
   833  			X:     new(big.Int).SetBytes(recEPK.X),
   834  			Y:     new(big.Int).SetBytes(recEPK.Y),
   835  		}
   836  	case ecdhpb.KeyType_OKP.String():
   837  		key = recEPK.X
   838  	default:
   839  		return nil, errors.New("invalid key type")
   840  	}
   841  
   842  	recJWK := jwk.JWK{
   843  		JSONWebKey: jose.JSONWebKey{
   844  			Key: key,
   845  		},
   846  		Kty: recEPK.Type,
   847  		Crv: recEPK.Curve,
   848  	}
   849  
   850  	return recJWK.MarshalJSON()
   851  }
   852  
   853  // Get the additional authenticated data from a JWE object.
   854  func computeAuthData(protectedHeaders map[string]interface{}, origProtectedHeader string, aad []byte) ([]byte, error) {
   855  	var protected string
   856  
   857  	if len(origProtectedHeader) > 0 {
   858  		// use origProtectedheader if set instead of marshal/unmarshal existing headers. This is important especially
   859  		// for ECDH-1PU because protectHeaders are used in the sender authentication mechanism. JSON keys order must
   860  		// remain untouched. This is critical for successful verification.
   861  		protected = origProtectedHeader
   862  	} else if protectedHeaders != nil {
   863  		protectedHeadersJSON := map[string]json.RawMessage{}
   864  
   865  		for k, v := range protectedHeaders {
   866  			mV, err := json.Marshal(v)
   867  			if err != nil {
   868  				return nil, fmt.Errorf("computeAuthData: %w", err)
   869  			}
   870  
   871  			rawMsg := json.RawMessage(mV) // need to explicitly convert []byte to RawMessage (same as go-jose)
   872  			protectedHeadersJSON[k] = rawMsg
   873  		}
   874  
   875  		err := jwkMarshalEPK(protectedHeadersJSON)
   876  		if err != nil {
   877  			return nil, fmt.Errorf("computeAuthData: %w", err)
   878  		}
   879  
   880  		mProtected, err := json.Marshal(protectedHeadersJSON)
   881  		if err != nil {
   882  			return nil, fmt.Errorf("computeAuthData: %w", err)
   883  		}
   884  
   885  		protected = base64.RawURLEncoding.EncodeToString(mProtected)
   886  	}
   887  
   888  	output := []byte(protected)
   889  	if len(aad) > 0 {
   890  		output = append(output, '.')
   891  
   892  		encLen := base64.RawURLEncoding.EncodedLen(len(aad))
   893  		aadEncoded := make([]byte, encLen)
   894  
   895  		base64.RawURLEncoding.Encode(aadEncoded, aad)
   896  		output = append(output, aadEncoded...)
   897  	}
   898  
   899  	return output, nil
   900  }
   901  
   902  func jwkMarshalEPK(protectedHeadersJSON map[string]json.RawMessage) error {
   903  	// must use jwk.MarshalJSON() to marshal EPK to maintain headers order.
   904  	if protectedHeadersJSON[HeaderEPK] != nil {
   905  		epk := &jwk.JWK{}
   906  
   907  		err := epk.UnmarshalJSON(protectedHeadersJSON[HeaderEPK])
   908  		if err != nil {
   909  			return err
   910  		}
   911  
   912  		mEPK, err := epk.MarshalJSON()
   913  		if err != nil {
   914  			return fmt.Errorf("jwkMarshalEPK: %w", err)
   915  		}
   916  
   917  		protectedHeadersJSON[HeaderEPK] = mEPK
   918  	}
   919  
   920  	return nil
   921  }