github.com/trustbloc/kms-go@v1.1.2/crypto/tinkcrypto/wrap_support.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package tinkcrypto
     8  
     9  import (
    10  	"bytes"
    11  	"crypto"
    12  	"crypto/aes"
    13  	"crypto/cipher"
    14  	"crypto/ecdsa"
    15  	"crypto/elliptic"
    16  	"crypto/rand"
    17  	"encoding/binary"
    18  	"errors"
    19  	"fmt"
    20  
    21  	josecipher "github.com/go-jose/go-jose/v3/cipher"
    22  	hybrid "github.com/google/tink/go/hybrid/subtle"
    23  	"golang.org/x/crypto/chacha20poly1305"
    24  
    25  	"github.com/trustbloc/kms-go/util/cryptoutil"
    26  
    27  	"github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/aead/subtle"
    28  )
    29  
// keyWrapper groups the key generation, key wrapping and ECDH-1PU key derivation steps for one key family.
// The private/public key values are passed as interface{} because their concrete type depends on the
// implementation.
type keyWrapper interface {
	// getCurve resolves a curve name into an elliptic.Curve.
	getCurve(curve string) (elliptic.Curve, error)
	// generateKey creates a new private key; the concrete type returned is implementation-specific.
	generateKey(curve elliptic.Curve) (interface{}, error)
	// createPrimitive builds, from raw key bytes, the primitive that wrap and unwrap operate with.
	createPrimitive(key []byte) (interface{}, error)
	// wrap encrypts cek using blockPrimitive (a value returned by createPrimitive).
	wrap(blockPrimitive interface{}, cek []byte) ([]byte, error)
	// unwrap decrypts encryptedKey using blockPrimitive (a value returned by createPrimitive).
	unwrap(blockPrimitive interface{}, encryptedKey []byte) ([]byte, error)
	// deriveSender1Pu derives the sender-side 1PU key from the ephemeral and sender private keys and the
	// recipient public key.
	deriveSender1Pu(kwAlg string, apu, apv, tag []byte, ephemeralPriv, senderPrivKey, recPubKey interface{},
		keySize int) ([]byte, error)
	// deriveRecipient1Pu derives the recipient-side 1PU key from the ephemeral and sender public keys and the
	// recipient private key.
	deriveRecipient1Pu(kwAlg string, apu, apv, tag []byte, ephemeralPub, senderPubKey, recPrivKey interface{},
		keySize int) ([]byte, error)
}
    41  
// ecKWSupport provides key wrapping support for ECDSA (elliptic curve) keys; its wrapping primitive is an
// AES block cipher. It is stateless.
type ecKWSupport struct{}
    43  
    44  func (w *ecKWSupport) getCurve(curve string) (elliptic.Curve, error) {
    45  	return hybrid.GetCurve(curve)
    46  }
    47  
    48  func (w *ecKWSupport) generateKey(curve elliptic.Curve) (interface{}, error) {
    49  	return ecdsa.GenerateKey(curve, rand.Reader)
    50  }
    51  
    52  func (w *ecKWSupport) createPrimitive(kek []byte) (interface{}, error) {
    53  	return aes.NewCipher(kek)
    54  }
    55  
    56  func (w *ecKWSupport) wrap(block interface{}, cek []byte) ([]byte, error) {
    57  	blockCipher, ok := block.(cipher.Block)
    58  	if !ok {
    59  		return nil, errors.New("wrap support: EC wrap with invalid cipher block type")
    60  	}
    61  
    62  	return josecipher.KeyWrap(blockCipher, cek)
    63  }
    64  
    65  func (w *ecKWSupport) unwrap(block interface{}, encryptedKey []byte) ([]byte, error) {
    66  	blockCipher, ok := block.(cipher.Block)
    67  	if !ok {
    68  		return nil, errors.New("unwrap support: EC wrap with invalid cipher block type")
    69  	}
    70  
    71  	return josecipher.KeyUnwrap(blockCipher, encryptedKey)
    72  }
    73  
    74  func (w *ecKWSupport) deriveSender1Pu(alg string, apu, apv, tag []byte, ephemeralPriv, senderPrivKey interface{},
    75  	recPubKey interface{}, keySize int) ([]byte, error) {
    76  	ephemeralPrivEC, ok := ephemeralPriv.(*ecdsa.PrivateKey)
    77  	if !ok {
    78  		return nil, errors.New("deriveSender1Pu: ephemeral key not ECDSA type")
    79  	}
    80  
    81  	senderPrivKeyEC, ok := senderPrivKey.(*ecdsa.PrivateKey)
    82  	if !ok {
    83  		return nil, errors.New("deriveSender1Pu: sender key not ECDSA type")
    84  	}
    85  
    86  	recPubKeyEC, ok := recPubKey.(*ecdsa.PublicKey)
    87  	if !ok {
    88  		return nil, errors.New("deriveSender1Pu: recipient key not ECDSA type")
    89  	}
    90  
    91  	if recPubKeyEC.Curve != ephemeralPrivEC.Curve || recPubKeyEC.Curve != senderPrivKeyEC.Curve {
    92  		return nil, errors.New("deriveSender1Pu: recipient, sender and ephemeral key are not on the same curve")
    93  	}
    94  
    95  	ze := deriveECDH(ephemeralPrivEC, recPubKeyEC, keySize)
    96  	zs := deriveECDH(senderPrivKeyEC, recPubKeyEC, keySize)
    97  
    98  	return derive1Pu(alg, ze, zs, apu, apv, tag, keySize), nil
    99  }
   100  
   101  func (w *ecKWSupport) deriveRecipient1Pu(alg string, apu, apv, tag []byte, ephemeralPub, senderPubKey interface{},
   102  	recPrivKey interface{}, keySize int) ([]byte, error) {
   103  	ephemeralPubEC, ok := ephemeralPub.(*ecdsa.PublicKey)
   104  	if !ok {
   105  		return nil, errors.New("deriveRecipient1Pu: ephemeral key not ECDSA type")
   106  	}
   107  
   108  	senderPubKeyEC, ok := senderPubKey.(*ecdsa.PublicKey)
   109  	if !ok {
   110  		return nil, errors.New("deriveRecipient1Pu: sender key not ECDSA type")
   111  	}
   112  
   113  	recPrivKeyEC, ok := recPrivKey.(*ecdsa.PrivateKey)
   114  	if !ok {
   115  		return nil, errors.New("deriveRecipient1Pu: recipient key not ECDSA type")
   116  	}
   117  
   118  	if recPrivKeyEC.Curve != ephemeralPubEC.Curve || recPrivKeyEC.Curve != senderPubKeyEC.Curve {
   119  		return nil, errors.New("deriveRecipient1Pu: recipient, sender and ephemeral key are not on the same curve")
   120  	}
   121  
   122  	// DeriveECDHES checks if keys are on the same curve
   123  	ze := deriveECDH(recPrivKeyEC, ephemeralPubEC, keySize)
   124  	zs := deriveECDH(recPrivKeyEC, senderPubKeyEC, keySize)
   125  
   126  	return derive1Pu(alg, ze, zs, apu, apv, tag, keySize), nil
   127  }
   128  
   129  const byteSize = 8
   130  
   131  // deriveECDH does key derivation using ECDH only (without KDF).
   132  func deriveECDH(priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, size int) []byte {
   133  	if size > 1<<16 {
   134  		panic("ECDH-ES output size too large, must be less than or equal to 1<<16")
   135  	}
   136  
   137  	// suppPubInfo is the encoded length of the output size in bits
   138  	supPubInfo := make([]byte, 4)
   139  	binary.BigEndian.PutUint32(supPubInfo, uint32(size)*byteSize)
   140  
   141  	if !priv.PublicKey.Curve.IsOnCurve(pub.X, pub.Y) {
   142  		panic("public key not on same curve as private key")
   143  	}
   144  
   145  	z, _ := priv.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
   146  	zBytes := z.Bytes()
   147  
   148  	// Note that calling z.Bytes() on a big.Int may strip leading zero bytes from
   149  	// the returned byte array. This can lead to a problem where zBytes will be
   150  	// shorter than expected which breaks the key derivation. Therefore we must pad
   151  	// to the full length of the expected coordinate here before calling the KDF.
   152  	octSize := dSize(priv.Curve)
   153  	if len(zBytes) != octSize {
   154  		zBytes = append(bytes.Repeat([]byte{0}, octSize-len(zBytes)), zBytes...)
   155  	}
   156  
   157  	return zBytes
   158  }
   159  
   160  func dSize(curve elliptic.Curve) int {
   161  	order := curve.Params().P
   162  	bitLen := order.BitLen()
   163  	size := bitLen / byteSize
   164  
   165  	if bitLen%byteSize != 0 {
   166  		size++
   167  	}
   168  
   169  	return size
   170  }
   171  
// okpKWSupport provides key wrapping support for OKP (X25519) keys, which are handled as raw byte slices;
// its wrapping primitive is an XChaCha20-Poly1305 AEAD. It is stateless.
type okpKWSupport struct{}
   173  
   174  func (o *okpKWSupport) getCurve(curve string) (elliptic.Curve, error) {
   175  	return nil, errors.New("getCurve: not implemented for OKP KW support")
   176  }
   177  
   178  func (o *okpKWSupport) generateKey(_ elliptic.Curve) (interface{}, error) {
   179  	newKey := make([]byte, cryptoutil.Curve25519KeySize)
   180  
   181  	_, err := rand.Read(newKey)
   182  	if err != nil {
   183  		return nil, fmt.Errorf("generateKey: failed to create X25519 random key: %w", err)
   184  	}
   185  
   186  	return newKey, nil
   187  }
   188  
   189  func (o *okpKWSupport) createPrimitive(kek []byte) (interface{}, error) {
   190  	p, err := chacha20poly1305.NewX(kek)
   191  	if err != nil {
   192  		return nil, fmt.Errorf("createPrimitive: failed to create OKP primitive: %w", err)
   193  	}
   194  
   195  	return p, nil
   196  }
   197  
   198  func (o *okpKWSupport) wrap(aead interface{}, cek []byte) ([]byte, error) {
   199  	aeadPrimitive, ok := aead.(cipher.AEAD)
   200  	if !ok {
   201  		return nil, errors.New("wrap support: OKP wrap with invalid primitive type")
   202  	}
   203  
   204  	nonceSize := aeadPrimitive.NonceSize()
   205  	nonce := make([]byte, nonceSize)
   206  
   207  	_, err := rand.Read(nonce)
   208  	if err != nil {
   209  		return nil, fmt.Errorf("wrap support: failed to generate random nonce: %w", err)
   210  	}
   211  
   212  	cipherText := aeadPrimitive.Seal(nil, nonce, cek, nil)
   213  
   214  	return append(nonce, cipherText...), nil
   215  }
   216  
   217  func (o *okpKWSupport) unwrap(aead interface{}, encryptedKey []byte) ([]byte, error) {
   218  	aeadPrimitive, ok := aead.(cipher.AEAD)
   219  	if !ok {
   220  		return nil, errors.New("unwrap support: OKP unwrap with invalid primitive type")
   221  	}
   222  
   223  	if len(encryptedKey) < aeadPrimitive.NonceSize() {
   224  		return nil, errors.New("unwrap support: OKP unwrap invalid key")
   225  	}
   226  
   227  	nonce := encryptedKey[:aeadPrimitive.NonceSize()]
   228  
   229  	cek, err := aeadPrimitive.Open(nil, nonce, encryptedKey[aeadPrimitive.NonceSize():], nil)
   230  	if err != nil {
   231  		return nil, fmt.Errorf("unwrap support: OKP failed to unwrap key: %w", err)
   232  	}
   233  
   234  	return cek, nil
   235  }
   236  
   237  func (o *okpKWSupport) deriveSender1Pu(kwAlg string, apu, apv, tag []byte, ephemeralPriv, senderPrivKey interface{},
   238  	recPubKey interface{}, _ int) ([]byte, error) {
   239  	ephemeralPrivOKP, ok := ephemeralPriv.([]byte)
   240  	if !ok {
   241  		return nil, errors.New("deriveSender1Pu: ephemeral key not OKP type")
   242  	}
   243  
   244  	ephemeralPrivOKPChacha := new([chacha20poly1305.KeySize]byte)
   245  	copy(ephemeralPrivOKPChacha[:], ephemeralPrivOKP)
   246  
   247  	senderPrivKeyOKP, ok := senderPrivKey.([]byte)
   248  	if !ok {
   249  		return nil, errors.New("deriveSender1Pu: sender key not OKP type")
   250  	}
   251  
   252  	senderPrivKeyOKPChacha := new([chacha20poly1305.KeySize]byte)
   253  	copy(senderPrivKeyOKPChacha[:], senderPrivKeyOKP)
   254  
   255  	recPubKeyOKP, ok := recPubKey.([]byte)
   256  	if !ok {
   257  		return nil, errors.New("deriveSender1Pu: recipient key not OKP type")
   258  	}
   259  
   260  	recPubKeyOKPChacha := new([chacha20poly1305.KeySize]byte)
   261  	copy(recPubKeyOKPChacha[:], recPubKeyOKP)
   262  
   263  	ze, err := cryptoutil.DeriveECDHX25519(ephemeralPrivOKPChacha, recPubKeyOKPChacha)
   264  	if err != nil {
   265  		return nil, fmt.Errorf("deriveSender1Pu: %w", err)
   266  	}
   267  
   268  	zs, err := cryptoutil.DeriveECDHX25519(senderPrivKeyOKPChacha, recPubKeyOKPChacha)
   269  	if err != nil {
   270  		return nil, fmt.Errorf("deriveSender1Pu: %w", err)
   271  	}
   272  
   273  	return derive1Pu(kwAlg, ze, zs, apu, apv, tag, chacha20poly1305.KeySize), nil
   274  }
   275  
   276  func (o *okpKWSupport) deriveRecipient1Pu(kwAlg string, apu, apv, tag []byte, ephemeralPub, senderPubKey interface{},
   277  	recPrivKey interface{}, _ int) ([]byte, error) {
   278  	ephemeralPubOKP, ok := ephemeralPub.([]byte)
   279  	if !ok {
   280  		return nil, errors.New("deriveRecipient1Pu: ephemeral key not OKP type")
   281  	}
   282  
   283  	ephemeralPubOKPChacha := new([chacha20poly1305.KeySize]byte)
   284  	copy(ephemeralPubOKPChacha[:], ephemeralPubOKP)
   285  
   286  	senderPubKeyOKP, ok := senderPubKey.([]byte)
   287  	if !ok {
   288  		return nil, errors.New("deriveRecipient1Pu: sender key not OKP type")
   289  	}
   290  
   291  	senderPubKeyOKPChacha := new([chacha20poly1305.KeySize]byte)
   292  	copy(senderPubKeyOKPChacha[:], senderPubKeyOKP)
   293  
   294  	recPrivKeyOKP, ok := recPrivKey.([]byte)
   295  	if !ok {
   296  		return nil, errors.New("deriveRecipient1Pu: recipient key not OKP type")
   297  	}
   298  
   299  	recPrivKeyOKPChacha := new([chacha20poly1305.KeySize]byte)
   300  	copy(recPrivKeyOKPChacha[:], recPrivKeyOKP)
   301  
   302  	ze, err := cryptoutil.DeriveECDHX25519(recPrivKeyOKPChacha, ephemeralPubOKPChacha)
   303  	if err != nil {
   304  		return nil, fmt.Errorf("deriveRecipient1Pu: %w", err)
   305  	}
   306  
   307  	zs, err := cryptoutil.DeriveECDHX25519(recPrivKeyOKPChacha, senderPubKeyOKPChacha)
   308  	if err != nil {
   309  		return nil, fmt.Errorf("deriveRecipient1Pu: %w", err)
   310  	}
   311  
   312  	return derive1Pu(kwAlg, ze, zs, apu, apv, tag, chacha20poly1305.KeySize), nil
   313  }
   314  
   315  func derive1Pu(kwAlg string, ze, zs, apu, apv, tag []byte, keySize int) []byte {
   316  	z := append([]byte{}, ze...)
   317  	z = append(z, zs...)
   318  
   319  	return kdfWithTag(kwAlg, z, apu, apv, tag, keySize, true)
   320  }
   321  
   322  func kdf(kwAlg string, z, apu, apv []byte, keySize int) []byte {
   323  	return kdfWithTag(kwAlg, z, apu, apv, nil, keySize, false)
   324  }
   325  
   326  func kdfWithTag(kwAlg string, z, apu, apv, tag []byte, keySize int, useTag bool) []byte {
   327  	algID := cryptoutil.LengthPrefix([]byte(kwAlg))
   328  	ptyUInfo := cryptoutil.LengthPrefix(apu)
   329  	ptyVInfo := cryptoutil.LengthPrefix(apv)
   330  
   331  	supPubLen := 4
   332  	supPubInfo := make([]byte, supPubLen)
   333  
   334  	byteLen := 8
   335  	kdfKeySize := keySize
   336  
   337  	switch kwAlg {
   338  	case ECDH1PUA128KWAlg:
   339  		kdfKeySize = subtle.AES128Size
   340  	case ECDH1PUA192KWAlg:
   341  		kdfKeySize = subtle.AES192Size
   342  	case ECDH1PUA256KWAlg:
   343  		kdfKeySize = subtle.AES256Size
   344  	}
   345  
   346  	binary.BigEndian.PutUint32(supPubInfo, uint32(kdfKeySize)*uint32(byteLen))
   347  
   348  	if useTag {
   349  		// append Tag to SuppPubInfo as described here:
   350  		// https://datatracker.ietf.org/doc/html/draft-madden-jose-ecdh-1pu-04#section-2.3
   351  		tagInfo := cryptoutil.LengthPrefix(tag)
   352  		supPubInfo = append(supPubInfo, tagInfo...)
   353  	}
   354  
   355  	reader := josecipher.NewConcatKDF(crypto.SHA256, z, algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{})
   356  
   357  	kek := make([]byte, kdfKeySize)
   358  
   359  	_, _ = reader.Read(kek) // nolint:errcheck // ConcatKDF's Read() never returns an error
   360  
   361  	return kek
   362  }