github.com/trustbloc/kms-go@v1.1.2/crypto/tinkcrypto/primitive/composite/ecdh/ecdh_encrypt_factory.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package ecdh
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  
    13  	"github.com/google/tink/go/core/primitiveset"
    14  	"github.com/google/tink/go/core/registry"
    15  	"github.com/google/tink/go/keyset"
    16  
    17  	"github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite/api"
    18  )
    19  
    20  // NewECDHEncrypt returns an CompositeEncrypt primitive from the given keyset handle.
    21  func NewECDHEncrypt(h *keyset.Handle) (api.CompositeEncrypt, error) {
    22  	return NewECDHEncryptWithKeyManager(h, nil /*keyManager*/)
    23  }
    24  
    25  // NewECDHEncryptWithKeyManager returns an CompositeEncrypt primitive from the given h keyset handle and
    26  // custom km key manager.
    27  func NewECDHEncryptWithKeyManager(h *keyset.Handle, km registry.KeyManager) (api.CompositeEncrypt, error) {
    28  	ps, err := h.PrimitivesWithKeyManager(km)
    29  	if err != nil {
    30  		return nil, fmt.Errorf("ecdh_factory: cannot obtain primitive set: %w", err)
    31  	}
    32  
    33  	return newEncryptPrimitiveSet(ps)
    34  }
    35  
// encryptPrimitiveSet is a CompositeEncrypt implementation that uses the underlying primitive set for encryption.
type encryptPrimitiveSet struct {
	// ps is the primitive set whose primary entry performs the encryption.
	ps *primitiveset.PrimitiveSet
}

// Asserts at compile time that *encryptPrimitiveSet implements the CompositeEncrypt interface.
var _ api.CompositeEncrypt = (*encryptPrimitiveSet)(nil)
    43  
    44  func newEncryptPrimitiveSet(ps *primitiveset.PrimitiveSet) (*encryptPrimitiveSet, error) {
    45  	if _, ok := (ps.Primary.Primitive).(api.CompositeEncrypt); !ok {
    46  		return nil, errors.New("ecdh_factory: not a CompositeEncrypt primitive")
    47  	}
    48  
    49  	for _, primitives := range ps.Entries {
    50  		for _, p := range primitives {
    51  			if _, ok := (p.Primitive).(api.CompositeEncrypt); !ok {
    52  				return nil, errors.New("ecdh_factory: not a CompositeEncrypt primitive")
    53  			}
    54  		}
    55  	}
    56  
    57  	ret := new(encryptPrimitiveSet)
    58  	ret.ps = ps
    59  
    60  	return ret, nil
    61  }
    62  
    63  // Encrypt encrypts the given plaintext using the recipient public key found in the enclosed primitive.
    64  // It returns the ciphertext being a serialized JWE []byte.
    65  func (a *encryptPrimitiveSet) Encrypt(pt, aad []byte) ([]byte, error) {
    66  	primary := a.ps.Primary
    67  
    68  	p, ok := (primary.Primitive).(api.CompositeEncrypt)
    69  	if !ok {
    70  		return nil, errors.New("ecdh_factory: not a CompositeEncrypt primitive")
    71  	}
    72  
    73  	return p.Encrypt(pt, aad)
    74  }