github.com/hyperledger/aries-framework-go@v0.3.2/pkg/doc/jose/decrypter.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package jose
     8  
     9  import (
    10  	"crypto/ecdsa"
    11  	"encoding/base64"
    12  	"encoding/json"
    13  	"fmt"
    14  	"strings"
    15  
    16  	"github.com/google/tink/go/keyset"
    17  
    18  	ecdhpb "github.com/hyperledger/aries-framework-go/component/kmscrypto/crypto/tinkcrypto/primitive/proto/ecdh_aead_go_proto"
    19  	"github.com/hyperledger/aries-framework-go/component/kmscrypto/doc/jose/jwk"
    20  	cryptoapi "github.com/hyperledger/aries-framework-go/pkg/crypto"
    21  	"github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/composite"
    22  	"github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/composite/api"
    23  	"github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/composite/ecdh"
    24  	"github.com/hyperledger/aries-framework-go/pkg/crypto/tinkcrypto/primitive/composite/keyio"
    25  	"github.com/hyperledger/aries-framework-go/pkg/doc/jose/kid/resolver"
    26  	"github.com/hyperledger/aries-framework-go/pkg/kms"
    27  )
    28  
    29  // Decrypter interface to Decrypt JWE messages.
    30  type Decrypter interface {
    31  	// Decrypt a deserialized JWE, extracts the corresponding recipient key to decrypt plaintext and returns it
    32  	Decrypt(jwe *JSONWebEncryption) ([]byte, error)
    33  }
    34  
    35  // JWEDecrypt is responsible for decrypting a JWE message and returns its protected plaintext.
    36  type JWEDecrypt struct {
    37  	kidResolvers []resolver.KIDResolver
    38  	crypto       cryptoapi.Crypto
    39  	kms          kms.KeyManager
    40  }
    41  
    42  // NewJWEDecrypt creates a new JWEDecrypt instance to parse and decrypt a JWE message for a given recipient
    43  // store is needed for Authcrypt only (to fetch sender's pre agreed upon public key), it is not needed for Anoncrypt.
    44  func NewJWEDecrypt(kidResolvers []resolver.KIDResolver, c cryptoapi.Crypto, k kms.KeyManager) *JWEDecrypt {
    45  	return &JWEDecrypt{
    46  		kidResolvers: kidResolvers,
    47  		crypto:       c,
    48  		kms:          k,
    49  	}
    50  }
    51  
    52  func getECDHDecPrimitive(cek []byte, encAlg EncAlg, nistpKW bool) (api.CompositeDecrypt, error) {
    53  	ceAlg := aeadAlg[encAlg]
    54  
    55  	if ceAlg <= 0 {
    56  		return nil, fmt.Errorf("invalid content encAlg: '%s'", encAlg)
    57  	}
    58  
    59  	kt := ecdh.KeyTemplateForECDHPrimitiveWithCEK(cek, nistpKW, ceAlg)
    60  
    61  	kh, err := keyset.NewHandle(kt)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	return ecdh.NewECDHDecrypt(kh)
    67  }
    68  
    69  // Decrypt a deserialized JWE, decrypts its protected content and returns plaintext.
    70  func (jd *JWEDecrypt) Decrypt(jwe *JSONWebEncryption) ([]byte, error) {
    71  	encAlg, err := jd.validateAndExtractProtectedHeaders(jwe)
    72  	if err != nil {
    73  		return nil, fmt.Errorf("jwedecrypt: %w", err)
    74  	}
    75  
    76  	var wkOpts []cryptoapi.WrapKeyOpts
    77  
    78  	skid, ok := jwe.ProtectedHeaders.SenderKeyID()
    79  	if !ok {
    80  		skid, ok = fetchSKIDFromAPU(jwe)
    81  	}
    82  
    83  	if ok && skid != "" {
    84  		senderKH, e := jd.fetchSenderPubKey(skid, EncAlg(encAlg))
    85  		if e != nil {
    86  			return nil, fmt.Errorf("jwedecrypt: failed to add sender public key for skid: %w", e)
    87  		}
    88  
    89  		wkOpts = append(wkOpts, cryptoapi.WithSender(senderKH), cryptoapi.WithTag([]byte(jwe.Tag)))
    90  	}
    91  
    92  	recWK, err := buildRecipientsWrappedKey(jwe)
    93  	if err != nil {
    94  		return nil, fmt.Errorf("jwedecrypt: failed to build recipients WK: %w", err)
    95  	}
    96  
    97  	cek, err := jd.unwrapCEK(recWK, wkOpts...)
    98  	if err != nil {
    99  		return nil, fmt.Errorf("jwedecrypt: %w", err)
   100  	}
   101  
   102  	if len(recWK) == 1 {
   103  		// ensure EPK is marshalled the same way as during encryption since it is merged into ProtectHeaders.
   104  		marshalledEPK, err := convertRecEPKToMarshalledJWK(&recWK[0].EPK)
   105  		if err != nil {
   106  			return nil, fmt.Errorf("jwedecrypt: %w", err)
   107  		}
   108  
   109  		jwe.ProtectedHeaders["epk"] = json.RawMessage(marshalledEPK)
   110  	}
   111  
   112  	return jd.decryptJWE(jwe, cek)
   113  }
   114  
   115  func fetchSKIDFromAPU(jwe *JSONWebEncryption) (string, bool) {
   116  	// for multi-recipients only: check apu in protectedHeaders if it's found for ECDH-1PU, if skid header is empty then
   117  	// use apu as skid instead.
   118  	if len(jwe.Recipients) > 1 {
   119  		if a, apuOK := jwe.ProtectedHeaders["apu"]; apuOK {
   120  			skidBytes, err := base64.RawURLEncoding.DecodeString(a.(string))
   121  			if err != nil {
   122  				return "", false
   123  			}
   124  
   125  			return string(skidBytes), true
   126  		}
   127  	}
   128  
   129  	return "", false
   130  }
   131  
   132  //nolint:gocyclo
   133  func (jd *JWEDecrypt) unwrapCEK(recWK []*cryptoapi.RecipientWrappedKey,
   134  	senderOpt ...cryptoapi.WrapKeyOpts) ([]byte, error) {
   135  	var (
   136  		cek  []byte
   137  		errs []error
   138  	)
   139  
   140  	for _, rec := range recWK {
   141  		var unwrapOpts []cryptoapi.WrapKeyOpts
   142  
   143  		if strings.HasPrefix(rec.KID, "did:key") || strings.Index(rec.KID, "#") > 0 {
   144  			// resolve and use kms KID if did:key or KeyAgreement.ID.
   145  			resolvedRec, err := jd.resolveKID(rec.KID)
   146  			if err != nil {
   147  				errs = append(errs, err)
   148  				continue
   149  			}
   150  
   151  			// Need to get the kms KID in order to do kms.Get() since original rec.KID is a did:key/KeyAgreement.ID.
   152  			// This is necessary to ensure recipient is the owner of the key.
   153  			rec.KID = resolvedRec.KID
   154  		}
   155  
   156  		recKH, err := jd.kms.Get(rec.KID)
   157  		if err != nil {
   158  			continue
   159  		}
   160  
   161  		if rec.EPK.Type == ecdhpb.KeyType_OKP.String() {
   162  			unwrapOpts = append(unwrapOpts, cryptoapi.WithXC20PKW())
   163  		}
   164  
   165  		if senderOpt != nil {
   166  			unwrapOpts = append(unwrapOpts, senderOpt...)
   167  		}
   168  
   169  		if len(unwrapOpts) > 0 {
   170  			cek, err = jd.crypto.UnwrapKey(rec, recKH, unwrapOpts...)
   171  		} else {
   172  			cek, err = jd.crypto.UnwrapKey(rec, recKH)
   173  		}
   174  
   175  		if err == nil {
   176  			break
   177  		}
   178  
   179  		errs = append(errs, err)
   180  	}
   181  
   182  	if len(cek) == 0 {
   183  		return nil, fmt.Errorf("failed to unwrap cek: %v", errs)
   184  	}
   185  
   186  	return cek, nil
   187  }
   188  
   189  func (jd *JWEDecrypt) resolveKID(kid string) (*cryptoapi.PublicKey, error) {
   190  	var errs []error
   191  
   192  	for _, resolver := range jd.kidResolvers {
   193  		rKID, err := resolver.Resolve(kid)
   194  		if err == nil {
   195  			return rKID, nil
   196  		}
   197  
   198  		errs = append(errs, err)
   199  	}
   200  
   201  	return nil, fmt.Errorf("resolveKID: %v", errs)
   202  }
   203  
   204  func (jd *JWEDecrypt) decryptJWE(jwe *JSONWebEncryption, cek []byte) ([]byte, error) {
   205  	encAlg, ok := jwe.ProtectedHeaders.Encryption()
   206  	if !ok {
   207  		return nil, fmt.Errorf("jwedecrypt: JWE 'enc' protected header is missing")
   208  	}
   209  
   210  	decPrimitive, err := getECDHDecPrimitive(cek, EncAlg(encAlg), true)
   211  	if err != nil {
   212  		return nil, fmt.Errorf("jwedecrypt: failed to get decryption primitive: %w", err)
   213  	}
   214  
   215  	encryptedData, err := buildEncryptedData(jwe)
   216  	if err != nil {
   217  		return nil, fmt.Errorf("jwedecrypt: failed to build encryptedData for Decrypt(): %w", err)
   218  	}
   219  
   220  	aadBytes := []byte(jwe.AAD)
   221  
   222  	authData, err := computeAuthData(jwe.ProtectedHeaders, jwe.OrigProtectedHders, aadBytes)
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  
   227  	return decPrimitive.Decrypt(encryptedData, authData)
   228  }
   229  
   230  func (jd *JWEDecrypt) fetchSenderPubKey(skid string, encAlg EncAlg) (*keyset.Handle, error) {
   231  	senderKey, err := jd.resolveKID(skid)
   232  	if err != nil {
   233  		return nil, fmt.Errorf("fetchSenderPubKey: %w", err)
   234  	}
   235  
   236  	ceAlg := aeadAlg[encAlg]
   237  
   238  	if ceAlg <= 0 {
   239  		return nil, fmt.Errorf("fetchSenderPubKey: invalid content encAlg: '%s'", encAlg)
   240  	}
   241  
   242  	return keyio.PublicKeyToKeysetHandle(senderKey, ceAlg)
   243  }
   244  
   245  func (jd *JWEDecrypt) validateAndExtractProtectedHeaders(jwe *JSONWebEncryption) (string, error) {
   246  	if jwe == nil {
   247  		return "", fmt.Errorf("jwe is nil")
   248  	}
   249  
   250  	if len(jwe.ProtectedHeaders) == 0 {
   251  		return "", fmt.Errorf("jwe is missing protected headers")
   252  	}
   253  
   254  	protectedHeaders := jwe.ProtectedHeaders
   255  
   256  	encAlg, ok := protectedHeaders.Encryption()
   257  	if !ok {
   258  		return "", fmt.Errorf("jwe is missing encryption algorithm 'enc' header")
   259  	}
   260  
   261  	switch encAlg {
   262  	case string(A256GCM), string(XC20P), string(A128CBCHS256),
   263  		string(A192CBCHS384), string(A256CBCHS384), string(A256CBCHS512):
   264  	default:
   265  		return "", fmt.Errorf("encryption algorithm '%s' not supported", encAlg)
   266  	}
   267  
   268  	return encAlg, nil
   269  }
   270  
   271  func buildRecipientsWrappedKey(jwe *JSONWebEncryption) ([]*cryptoapi.RecipientWrappedKey, error) {
   272  	var (
   273  		recipients []*cryptoapi.RecipientWrappedKey
   274  		err        error
   275  	)
   276  
   277  	for _, recJWE := range jwe.Recipients {
   278  		headers := recJWE.Header
   279  		alg, ok := jwe.ProtectedHeaders.Algorithm()
   280  		is1PU := ok && strings.Contains(strings.ToUpper(alg), "1PU")
   281  
   282  		if len(jwe.Recipients) == 1 || is1PU {
   283  			// compact serialization: it has only 1 recipient with no headers or 1pu, extract from protectedHeaders.
   284  			headers, err = extractRecipientHeaders(jwe.ProtectedHeaders)
   285  			if err != nil {
   286  				return nil, err
   287  			}
   288  		}
   289  
   290  		var recWK *cryptoapi.RecipientWrappedKey
   291  		// set kid if 1PU (authcrypt) with multi recipients since common protected headers don't have the recipient kid.
   292  		if is1PU && len(jwe.Recipients) > 1 {
   293  			headers.KID = recJWE.Header.KID
   294  		}
   295  
   296  		recWK, err = createRecWK(headers, []byte(recJWE.EncryptedKey))
   297  		if err != nil {
   298  			return nil, err
   299  		}
   300  
   301  		recipients = append(recipients, recWK)
   302  	}
   303  
   304  	return recipients, nil
   305  }
   306  
   307  func createRecWK(headers *RecipientHeaders, encryptedKey []byte) (*cryptoapi.RecipientWrappedKey, error) {
   308  	recWK, err := convertMarshalledJWKToRecKey(headers.EPK)
   309  	if err != nil {
   310  		return nil, err
   311  	}
   312  
   313  	recWK.KID = headers.KID
   314  	recWK.Alg = headers.Alg
   315  
   316  	err = updateAPUAPVInRecWK(recWK, headers)
   317  	if err != nil {
   318  		return nil, err
   319  	}
   320  
   321  	recWK.EncryptedCEK = encryptedKey
   322  
   323  	return recWK, nil
   324  }
   325  
   326  func updateAPUAPVInRecWK(recWK *cryptoapi.RecipientWrappedKey, headers *RecipientHeaders) error {
   327  	decodedAPU, decodedAPV, err := decodeAPUAPV(headers)
   328  	if err != nil {
   329  		return fmt.Errorf("updateAPUAPVInRecWK: %w", err)
   330  	}
   331  
   332  	recWK.APU = decodedAPU
   333  	recWK.APV = decodedAPV
   334  
   335  	return nil
   336  }
   337  
   338  func buildEncryptedData(jwe *JSONWebEncryption) ([]byte, error) {
   339  	encData := new(composite.EncryptedData)
   340  	encData.Tag = []byte(jwe.Tag)
   341  	encData.IV = []byte(jwe.IV)
   342  	encData.Ciphertext = []byte(jwe.Ciphertext)
   343  
   344  	return json.Marshal(encData)
   345  }
   346  
   347  // extractRecipientHeaders will extract RecipientHeaders from headers argument.
   348  func extractRecipientHeaders(headers map[string]interface{}) (*RecipientHeaders, error) {
   349  	// Since headers is a generic map, epk value is converted to a generic map by Serialize(), ie we lose RawMessage
   350  	// type of epk. We need to convert epk value (generic map) to marshaled json so we can call RawMessage.Unmarshal()
   351  	// to get the original epk value (RawMessage type).
   352  	mapData, ok := headers[HeaderEPK].(map[string]interface{})
   353  	if !ok {
   354  		return nil, fmt.Errorf("JSON value is not a map (%#v)", headers[HeaderEPK])
   355  	}
   356  
   357  	epkBytes, err := json.Marshal(mapData)
   358  	if err != nil {
   359  		return nil, err
   360  	}
   361  
   362  	epk := json.RawMessage{}
   363  
   364  	err = epk.UnmarshalJSON(epkBytes)
   365  	if err != nil {
   366  		return nil, err
   367  	}
   368  
   369  	alg := ""
   370  	if headers[HeaderAlgorithm] != nil {
   371  		alg = fmt.Sprintf("%v", headers[HeaderAlgorithm])
   372  	}
   373  
   374  	kid := ""
   375  	if headers[HeaderKeyID] != nil {
   376  		kid = fmt.Sprintf("%v", headers[HeaderKeyID])
   377  	}
   378  
   379  	apu := ""
   380  	if headers["apu"] != nil {
   381  		apu = fmt.Sprintf("%v", headers["apu"])
   382  	}
   383  
   384  	apv := ""
   385  	if headers["apv"] != nil {
   386  		apv = fmt.Sprintf("%v", headers["apv"])
   387  	}
   388  
   389  	recHeaders := &RecipientHeaders{
   390  		Alg: alg,
   391  		KID: kid,
   392  		EPK: epk,
   393  		APU: apu,
   394  		APV: apv,
   395  	}
   396  
   397  	// original headers should remain untouched to avoid modifying the original JWE content.
   398  	return recHeaders, nil
   399  }
   400  
   401  func convertMarshalledJWKToRecKey(marshalledJWK []byte) (*cryptoapi.RecipientWrappedKey, error) {
   402  	j := &jwk.JWK{}
   403  
   404  	err := j.UnmarshalJSON(marshalledJWK)
   405  	if err != nil {
   406  		return nil, err
   407  	}
   408  
   409  	epk := cryptoapi.PublicKey{
   410  		Curve: j.Crv,
   411  		Type:  j.Kty,
   412  	}
   413  
   414  	switch key := j.Key.(type) {
   415  	case *ecdsa.PublicKey:
   416  		epk.X = key.X.Bytes()
   417  		epk.Y = key.Y.Bytes()
   418  	case []byte:
   419  		epk.X = key
   420  	default:
   421  		return nil, fmt.Errorf("unsupported recipient key type")
   422  	}
   423  
   424  	return &cryptoapi.RecipientWrappedKey{
   425  		KID: j.KeyID,
   426  		EPK: epk,
   427  	}, nil
   428  }