github.com/trustbloc/kms-go@v1.1.2/doc/jose/decrypter.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package jose
     8  
     9  import (
    10  	"crypto/ecdsa"
    11  	"encoding/base64"
    12  	"encoding/json"
    13  	"fmt"
    14  	"strings"
    15  
    16  	"github.com/google/tink/go/keyset"
    17  
    18  	"github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite"
    19  	"github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite/api"
    20  	"github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite/ecdh"
    21  	"github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite/keyio"
    22  	ecdhpb "github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/proto/ecdh_aead_go_proto"
    23  	"github.com/trustbloc/kms-go/doc/jose/jwk"
    24  	resolver "github.com/trustbloc/kms-go/doc/jose/kidresolver"
    25  
    26  	cryptoapi "github.com/trustbloc/kms-go/spi/crypto"
    27  	"github.com/trustbloc/kms-go/spi/kms"
    28  )
    29  
// Decrypter decrypts JWE messages.
type Decrypter interface {
	// Decrypt takes a deserialized JWE, extracts the matching recipient key, uses it to decrypt the
	// protected content and returns the resulting plaintext.
	Decrypt(jwe *JSONWebEncryption) ([]byte, error)
}
    35  
// JWEDecrypt is responsible for decrypting a JWE message and returns its protected plaintext.
//
// kidResolvers are tried in order to resolve key IDs (recipient did:key / KeyAgreement IDs and the
// sender skid) into public keys, crypto performs the CEK unwrapping, and kms supplies recipient key
// handles looked up by kms key ID.
type JWEDecrypt struct {
	kidResolvers []resolver.KIDResolver
	crypto       cryptoapi.Crypto
	kms          kms.KeyManager
}
    42  
    43  // NewJWEDecrypt creates a new JWEDecrypt instance to parse and decrypt a JWE message for a given recipient
    44  // store is needed for Authcrypt only (to fetch sender's pre agreed upon public key), it is not needed for Anoncrypt.
    45  func NewJWEDecrypt(kidResolvers []resolver.KIDResolver, c cryptoapi.Crypto, k kms.KeyManager) *JWEDecrypt {
    46  	return &JWEDecrypt{
    47  		kidResolvers: kidResolvers,
    48  		crypto:       c,
    49  		kms:          k,
    50  	}
    51  }
    52  
    53  func getECDHDecPrimitive(cek []byte, encAlg EncAlg, nistpKW bool) (api.CompositeDecrypt, error) {
    54  	ceAlg := aeadAlg[encAlg]
    55  
    56  	if ceAlg <= 0 {
    57  		return nil, fmt.Errorf("invalid content encAlg: '%s'", encAlg)
    58  	}
    59  
    60  	kt := ecdh.KeyTemplateForECDHPrimitiveWithCEK(cek, nistpKW, ceAlg)
    61  
    62  	kh, err := keyset.NewHandle(kt)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	return ecdh.NewECDHDecrypt(kh)
    68  }
    69  
    70  // Decrypt a deserialized JWE, decrypts its protected content and returns plaintext.
    71  func (jd *JWEDecrypt) Decrypt(jwe *JSONWebEncryption) ([]byte, error) {
    72  	encAlg, err := jd.validateAndExtractProtectedHeaders(jwe)
    73  	if err != nil {
    74  		return nil, fmt.Errorf("jwedecrypt: %w", err)
    75  	}
    76  
    77  	var wkOpts []cryptoapi.WrapKeyOpts
    78  
    79  	skid, ok := jwe.ProtectedHeaders.SenderKeyID()
    80  	if !ok {
    81  		skid, ok = fetchSKIDFromAPU(jwe)
    82  	}
    83  
    84  	if ok && skid != "" {
    85  		senderKH, e := jd.fetchSenderPubKey(skid, EncAlg(encAlg))
    86  		if e != nil {
    87  			return nil, fmt.Errorf("jwedecrypt: failed to add sender public key for skid: %w", e)
    88  		}
    89  
    90  		wkOpts = append(wkOpts, cryptoapi.WithSender(senderKH), cryptoapi.WithTag([]byte(jwe.Tag)))
    91  	}
    92  
    93  	recWK, err := buildRecipientsWrappedKey(jwe)
    94  	if err != nil {
    95  		return nil, fmt.Errorf("jwedecrypt: failed to build recipients WK: %w", err)
    96  	}
    97  
    98  	cek, err := jd.unwrapCEK(recWK, wkOpts...)
    99  	if err != nil {
   100  		return nil, fmt.Errorf("jwedecrypt: %w", err)
   101  	}
   102  
   103  	if len(recWK) == 1 {
   104  		// ensure EPK is marshalled the same way as during encryption since it is merged into ProtectHeaders.
   105  		marshalledEPK, err := convertRecEPKToMarshalledJWK(&recWK[0].EPK)
   106  		if err != nil {
   107  			return nil, fmt.Errorf("jwedecrypt: %w", err)
   108  		}
   109  
   110  		jwe.ProtectedHeaders["epk"] = json.RawMessage(marshalledEPK)
   111  	}
   112  
   113  	return jd.decryptJWE(jwe, cek)
   114  }
   115  
   116  func fetchSKIDFromAPU(jwe *JSONWebEncryption) (string, bool) {
   117  	// for multi-recipients only: check apu in protectedHeaders if it's found for ECDH-1PU, if skid header is empty then
   118  	// use apu as skid instead.
   119  	if len(jwe.Recipients) > 1 {
   120  		if a, apuOK := jwe.ProtectedHeaders["apu"]; apuOK {
   121  			skidBytes, err := base64.RawURLEncoding.DecodeString(a.(string))
   122  			if err != nil {
   123  				return "", false
   124  			}
   125  
   126  			return string(skidBytes), true
   127  		}
   128  	}
   129  
   130  	return "", false
   131  }
   132  
   133  //nolint:gocyclo
   134  func (jd *JWEDecrypt) unwrapCEK(recWK []*cryptoapi.RecipientWrappedKey,
   135  	senderOpt ...cryptoapi.WrapKeyOpts) ([]byte, error) {
   136  	var (
   137  		cek  []byte
   138  		errs []error
   139  	)
   140  
   141  	for _, rec := range recWK {
   142  		var unwrapOpts []cryptoapi.WrapKeyOpts
   143  
   144  		if strings.HasPrefix(rec.KID, "did:key") || strings.Index(rec.KID, "#") > 0 {
   145  			// resolve and use kms KID if did:key or KeyAgreement.ID.
   146  			resolvedRec, err := jd.resolveKID(rec.KID)
   147  			if err != nil {
   148  				errs = append(errs, err)
   149  				continue
   150  			}
   151  
   152  			// Need to get the kms KID in order to do kms.Get() since original rec.KID is a did:key/KeyAgreement.ID.
   153  			// This is necessary to ensure recipient is the owner of the key.
   154  			rec.KID = resolvedRec.KID
   155  		}
   156  
   157  		recKH, err := jd.kms.Get(rec.KID)
   158  		if err != nil {
   159  			continue
   160  		}
   161  
   162  		if rec.EPK.Type == ecdhpb.KeyType_OKP.String() {
   163  			unwrapOpts = append(unwrapOpts, cryptoapi.WithXC20PKW())
   164  		}
   165  
   166  		if senderOpt != nil {
   167  			unwrapOpts = append(unwrapOpts, senderOpt...)
   168  		}
   169  
   170  		if len(unwrapOpts) > 0 {
   171  			cek, err = jd.crypto.UnwrapKey(rec, recKH, unwrapOpts...)
   172  		} else {
   173  			cek, err = jd.crypto.UnwrapKey(rec, recKH)
   174  		}
   175  
   176  		if err == nil {
   177  			break
   178  		}
   179  
   180  		errs = append(errs, err)
   181  	}
   182  
   183  	if len(cek) == 0 {
   184  		return nil, fmt.Errorf("failed to unwrap cek: %v", errs)
   185  	}
   186  
   187  	return cek, nil
   188  }
   189  
   190  func (jd *JWEDecrypt) resolveKID(kid string) (*cryptoapi.PublicKey, error) {
   191  	var errs []error
   192  
   193  	for _, resolver := range jd.kidResolvers {
   194  		rKID, err := resolver.Resolve(kid)
   195  		if err == nil {
   196  			return rKID, nil
   197  		}
   198  
   199  		errs = append(errs, err)
   200  	}
   201  
   202  	return nil, fmt.Errorf("resolveKID: %v", errs)
   203  }
   204  
   205  func (jd *JWEDecrypt) decryptJWE(jwe *JSONWebEncryption, cek []byte) ([]byte, error) {
   206  	encAlg, ok := jwe.ProtectedHeaders.Encryption()
   207  	if !ok {
   208  		return nil, fmt.Errorf("jwedecrypt: JWE 'enc' protected header is missing")
   209  	}
   210  
   211  	decPrimitive, err := getECDHDecPrimitive(cek, EncAlg(encAlg), true)
   212  	if err != nil {
   213  		return nil, fmt.Errorf("jwedecrypt: failed to get decryption primitive: %w", err)
   214  	}
   215  
   216  	encryptedData, err := buildEncryptedData(jwe)
   217  	if err != nil {
   218  		return nil, fmt.Errorf("jwedecrypt: failed to build encryptedData for Decrypt(): %w", err)
   219  	}
   220  
   221  	aadBytes := []byte(jwe.AAD)
   222  
   223  	authData, err := computeAuthData(jwe.ProtectedHeaders, jwe.OrigProtectedHders, aadBytes)
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  
   228  	return decPrimitive.Decrypt(encryptedData, authData)
   229  }
   230  
   231  func (jd *JWEDecrypt) fetchSenderPubKey(skid string, encAlg EncAlg) (*keyset.Handle, error) {
   232  	senderKey, err := jd.resolveKID(skid)
   233  	if err != nil {
   234  		return nil, fmt.Errorf("fetchSenderPubKey: %w", err)
   235  	}
   236  
   237  	ceAlg := aeadAlg[encAlg]
   238  
   239  	if ceAlg <= 0 {
   240  		return nil, fmt.Errorf("fetchSenderPubKey: invalid content encAlg: '%s'", encAlg)
   241  	}
   242  
   243  	return keyio.PublicKeyToKeysetHandle(senderKey, ceAlg)
   244  }
   245  
   246  func (jd *JWEDecrypt) validateAndExtractProtectedHeaders(jwe *JSONWebEncryption) (string, error) {
   247  	if jwe == nil {
   248  		return "", fmt.Errorf("jwe is nil")
   249  	}
   250  
   251  	if len(jwe.ProtectedHeaders) == 0 {
   252  		return "", fmt.Errorf("jwe is missing protected headers")
   253  	}
   254  
   255  	protectedHeaders := jwe.ProtectedHeaders
   256  
   257  	encAlg, ok := protectedHeaders.Encryption()
   258  	if !ok {
   259  		return "", fmt.Errorf("jwe is missing encryption algorithm 'enc' header")
   260  	}
   261  
   262  	switch encAlg {
   263  	case string(A256GCM), string(XC20P), string(A128CBCHS256),
   264  		string(A192CBCHS384), string(A256CBCHS384), string(A256CBCHS512):
   265  	default:
   266  		return "", fmt.Errorf("encryption algorithm '%s' not supported", encAlg)
   267  	}
   268  
   269  	return encAlg, nil
   270  }
   271  
   272  func buildRecipientsWrappedKey(jwe *JSONWebEncryption) ([]*cryptoapi.RecipientWrappedKey, error) {
   273  	var (
   274  		recipients []*cryptoapi.RecipientWrappedKey
   275  		err        error
   276  	)
   277  
   278  	for _, recJWE := range jwe.Recipients {
   279  		headers := recJWE.Header
   280  		alg, ok := jwe.ProtectedHeaders.Algorithm()
   281  		is1PU := ok && strings.Contains(strings.ToUpper(alg), "1PU")
   282  
   283  		if len(jwe.Recipients) == 1 || is1PU {
   284  			// compact serialization: it has only 1 recipient with no headers or 1pu, extract from protectedHeaders.
   285  			headers, err = extractRecipientHeaders(jwe.ProtectedHeaders)
   286  			if err != nil {
   287  				return nil, err
   288  			}
   289  		}
   290  
   291  		var recWK *cryptoapi.RecipientWrappedKey
   292  		// set kid if 1PU (authcrypt) with multi recipients since common protected headers don't have the recipient kid.
   293  		if is1PU && len(jwe.Recipients) > 1 {
   294  			headers.KID = recJWE.Header.KID
   295  		}
   296  
   297  		recWK, err = createRecWK(headers, []byte(recJWE.EncryptedKey))
   298  		if err != nil {
   299  			return nil, err
   300  		}
   301  
   302  		recipients = append(recipients, recWK)
   303  	}
   304  
   305  	return recipients, nil
   306  }
   307  
   308  func createRecWK(headers *RecipientHeaders, encryptedKey []byte) (*cryptoapi.RecipientWrappedKey, error) {
   309  	recWK, err := convertMarshalledJWKToRecKey(headers.EPK)
   310  	if err != nil {
   311  		return nil, err
   312  	}
   313  
   314  	recWK.KID = headers.KID
   315  	recWK.Alg = headers.Alg
   316  
   317  	err = updateAPUAPVInRecWK(recWK, headers)
   318  	if err != nil {
   319  		return nil, err
   320  	}
   321  
   322  	recWK.EncryptedCEK = encryptedKey
   323  
   324  	return recWK, nil
   325  }
   326  
   327  func updateAPUAPVInRecWK(recWK *cryptoapi.RecipientWrappedKey, headers *RecipientHeaders) error {
   328  	decodedAPU, decodedAPV, err := decodeAPUAPV(headers)
   329  	if err != nil {
   330  		return fmt.Errorf("updateAPUAPVInRecWK: %w", err)
   331  	}
   332  
   333  	recWK.APU = decodedAPU
   334  	recWK.APV = decodedAPV
   335  
   336  	return nil
   337  }
   338  
   339  func buildEncryptedData(jwe *JSONWebEncryption) ([]byte, error) {
   340  	encData := new(composite.EncryptedData)
   341  	encData.Tag = []byte(jwe.Tag)
   342  	encData.IV = []byte(jwe.IV)
   343  	encData.Ciphertext = []byte(jwe.Ciphertext)
   344  
   345  	return json.Marshal(encData)
   346  }
   347  
   348  // extractRecipientHeaders will extract RecipientHeaders from headers argument.
   349  func extractRecipientHeaders(headers map[string]interface{}) (*RecipientHeaders, error) {
   350  	// Since headers is a generic map, epk value is converted to a generic map by Serialize(), ie we lose RawMessage
   351  	// type of epk. We need to convert epk value (generic map) to marshaled json so we can call RawMessage.Unmarshal()
   352  	// to get the original epk value (RawMessage type).
   353  	mapData, ok := headers[HeaderEPK].(map[string]interface{})
   354  	if !ok {
   355  		return nil, fmt.Errorf("JSON value is not a map (%#v)", headers[HeaderEPK])
   356  	}
   357  
   358  	epkBytes, err := json.Marshal(mapData)
   359  	if err != nil {
   360  		return nil, err
   361  	}
   362  
   363  	epk := json.RawMessage{}
   364  
   365  	err = epk.UnmarshalJSON(epkBytes)
   366  	if err != nil {
   367  		return nil, err
   368  	}
   369  
   370  	alg := ""
   371  	if headers[HeaderAlgorithm] != nil {
   372  		alg = fmt.Sprintf("%v", headers[HeaderAlgorithm])
   373  	}
   374  
   375  	kid := ""
   376  	if headers[HeaderKeyID] != nil {
   377  		kid = fmt.Sprintf("%v", headers[HeaderKeyID])
   378  	}
   379  
   380  	apu := ""
   381  	if headers["apu"] != nil {
   382  		apu = fmt.Sprintf("%v", headers["apu"])
   383  	}
   384  
   385  	apv := ""
   386  	if headers["apv"] != nil {
   387  		apv = fmt.Sprintf("%v", headers["apv"])
   388  	}
   389  
   390  	recHeaders := &RecipientHeaders{
   391  		Alg: alg,
   392  		KID: kid,
   393  		EPK: epk,
   394  		APU: apu,
   395  		APV: apv,
   396  	}
   397  
   398  	// original headers should remain untouched to avoid modifying the original JWE content.
   399  	return recHeaders, nil
   400  }
   401  
   402  func convertMarshalledJWKToRecKey(marshalledJWK []byte) (*cryptoapi.RecipientWrappedKey, error) {
   403  	j := &jwk.JWK{}
   404  
   405  	err := j.UnmarshalJSON(marshalledJWK)
   406  	if err != nil {
   407  		return nil, err
   408  	}
   409  
   410  	epk := cryptoapi.PublicKey{
   411  		Curve: j.Crv,
   412  		Type:  j.Kty,
   413  	}
   414  
   415  	switch key := j.Key.(type) {
   416  	case *ecdsa.PublicKey:
   417  		epk.X = key.X.Bytes()
   418  		epk.Y = key.Y.Bytes()
   419  	case []byte:
   420  		epk.X = key
   421  	default:
   422  		return nil, fmt.Errorf("unsupported recipient key type")
   423  	}
   424  
   425  	return &cryptoapi.RecipientWrappedKey{
   426  		KID: j.KeyID,
   427  		EPK: epk,
   428  	}, nil
   429  }