github.com/trustbloc/kms-go@v1.1.2/crypto/tinkcrypto/primitive/composite/ecdh/ecdh_decrypt_factory.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package ecdh
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  
    13  	"github.com/google/tink/go/core/cryptofmt"
    14  	"github.com/google/tink/go/core/primitiveset"
    15  	"github.com/google/tink/go/core/registry"
    16  	"github.com/google/tink/go/keyset"
    17  
    18  	"github.com/trustbloc/kms-go/crypto/tinkcrypto/primitive/composite/api"
    19  )
    20  
    21  // NewECDHDecrypt returns an CompositeDecrypt primitive from the given keyset handle.
    22  func NewECDHDecrypt(h *keyset.Handle) (api.CompositeDecrypt, error) {
    23  	return NewECDHDecryptWithKeyManager(h, nil /*keyManager*/)
    24  }
    25  
    26  // NewECDHDecryptWithKeyManager returns an CompositeDecrypt primitive from the given keyset handle and custom key
    27  // manager.
    28  func NewECDHDecryptWithKeyManager(h *keyset.Handle, km registry.KeyManager) (api.CompositeDecrypt, error) {
    29  	ps, err := h.PrimitivesWithKeyManager(km)
    30  	if err != nil {
    31  		return nil, fmt.Errorf("ecdh_factory: cannot obtain primitive set: %w", err)
    32  	}
    33  
    34  	return newDecryptPrimitiveSet(ps)
    35  }
    36  
    37  // decryptPrimitiveSet is an CompositeDecrypt implementation that uses the underlying primitive set for
    38  // decryption.
    39  type decryptPrimitiveSet struct {
    40  	ps *primitiveset.PrimitiveSet
    41  }
    42  
    43  // Asserts that primitiveSet implements the CompositeDecrypt interface.
    44  var _ api.CompositeDecrypt = (*decryptPrimitiveSet)(nil)
    45  
    46  func newDecryptPrimitiveSet(ps *primitiveset.PrimitiveSet) (*decryptPrimitiveSet, error) {
    47  	if _, ok := (ps.Primary.Primitive).(api.CompositeDecrypt); !ok {
    48  		return nil, errors.New("ecdh_factory: not a CompositeDecrypt primitive")
    49  	}
    50  
    51  	for _, primitives := range ps.Entries {
    52  		for _, p := range primitives {
    53  			if _, ok := (p.Primitive).(api.CompositeDecrypt); !ok {
    54  				return nil, errors.New("ecdh_factory: not a CompositeDecrypt primitive")
    55  			}
    56  		}
    57  	}
    58  
    59  	ret := new(decryptPrimitiveSet)
    60  	ret.ps = ps
    61  
    62  	return ret, nil
    63  }
    64  
    65  func (a *decryptPrimitiveSet) entries(ct []byte) map[string][]*primitiveset.Entry {
    66  	cipherEntries := make(map[string][]*primitiveset.Entry)
    67  
    68  	prefixSize := cryptofmt.NonRawPrefixSize
    69  	if len(ct) > prefixSize {
    70  		if entries, err := a.ps.EntriesForPrefix(string(ct[:prefixSize])); err == nil {
    71  			cipherEntries[string(ct[prefixSize:])] = entries
    72  		}
    73  	}
    74  
    75  	if entries, err := a.ps.RawEntries(); err == nil {
    76  		cipherEntries[string(ct)] = entries
    77  	}
    78  
    79  	return cipherEntries
    80  }
    81  
    82  // Decrypt decrypts the given ciphertext and authenticates it with the given
    83  // additional authenticated data. It returns the corresponding plaintext if the
    84  // ciphertext is authenticated.
    85  func (a *decryptPrimitiveSet) Decrypt(ct, aad []byte) ([]byte, error) {
    86  	for cipher, entries := range a.entries(ct) {
    87  		for _, e := range entries {
    88  			p, ok := (e.Primitive).(api.CompositeDecrypt)
    89  			if !ok {
    90  				return nil, errors.New("ecdh_factory: not a CompositeDecrypt primitive")
    91  			}
    92  
    93  			pt, e := p.Decrypt([]byte(cipher), aad)
    94  			if e == nil {
    95  				return pt, nil
    96  			}
    97  		}
    98  	}
    99  
   100  	// nothing worked
   101  	return nil, errors.New("ecdh_factory: decryption failed")
   102  }