github.com/cloudflare/circl@v1.5.0/hpke/xkem.go (about)

     1  package hpke
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"crypto/subtle"
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/cloudflare/circl/dh/x25519"
    11  	"github.com/cloudflare/circl/dh/x448"
    12  	"github.com/cloudflare/circl/kem"
    13  )
    14  
    15  type xKEM struct {
    16  	dhKemBase
    17  	size int
    18  }
    19  
    20  func (x xKEM) PrivateKeySize() int        { return x.size }
    21  func (x xKEM) SeedSize() int              { return x.size }
    22  func (x xKEM) CiphertextSize() int        { return x.size }
    23  func (x xKEM) PublicKeySize() int         { return x.size }
    24  func (x xKEM) EncapsulationSeedSize() int { return x.size }
    25  
    26  func (x xKEM) sizeDH() int { return x.size }
    27  func (x xKEM) calcDH(dh []byte, sk kem.PrivateKey, pk kem.PublicKey) error {
    28  	PK := pk.(*xKEMPubKey)
    29  	SK := sk.(*xKEMPrivKey)
    30  	switch x.size {
    31  	case x25519.Size:
    32  		var ss, sKey, pKey x25519.Key
    33  		copy(sKey[:], SK.priv)
    34  		copy(pKey[:], PK.pub)
    35  		if !x25519.Shared(&ss, &sKey, &pKey) {
    36  			return ErrInvalidKEMSharedSecret
    37  		}
    38  		copy(dh, ss[:])
    39  	case x448.Size:
    40  		var ss, sKey, pKey x448.Key
    41  		copy(sKey[:], SK.priv)
    42  		copy(pKey[:], PK.pub)
    43  		if !x448.Shared(&ss, &sKey, &pKey) {
    44  			return ErrInvalidKEMSharedSecret
    45  		}
    46  		copy(dh, ss[:])
    47  	}
    48  	return nil
    49  }
    50  
    51  // Deterministically derives a keypair from a seed. If you're unsure,
    52  // you're better off using GenerateKey().
    53  //
    54  // Panics if seed is not of length SeedSize().
    55  func (x xKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
    56  	// Implementation based on
    57  	// https://www.ietf.org/archive/id/draft-irtf-cfrg-hpke-07.html#name-derivekeypair
    58  	if len(seed) != x.SeedSize() {
    59  		panic(kem.ErrSeedSize)
    60  	}
    61  	sk := &xKEMPrivKey{scheme: x, priv: make([]byte, x.size)}
    62  	dkpPrk := x.labeledExtract([]byte(""), []byte("dkp_prk"), seed)
    63  	bytes := x.labeledExpand(
    64  		dkpPrk,
    65  		[]byte("sk"),
    66  		nil,
    67  		uint16(x.PrivateKeySize()),
    68  	)
    69  	copy(sk.priv, bytes)
    70  	return sk.Public(), sk
    71  }
    72  
    73  func (x xKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
    74  	sk := &xKEMPrivKey{scheme: x, priv: make([]byte, x.PrivateKeySize())}
    75  	_, err := io.ReadFull(rand.Reader, sk.priv)
    76  	if err != nil {
    77  		return nil, nil, err
    78  	}
    79  	return sk.Public(), sk, nil
    80  }
    81  
    82  func (x xKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
    83  	l := x.PrivateKeySize()
    84  	if len(data) < l {
    85  		return nil, ErrInvalidKEMPrivateKey
    86  	}
    87  	sk := &xKEMPrivKey{x, make([]byte, l), nil}
    88  	copy(sk.priv, data[:l])
    89  	if !sk.validate() {
    90  		return nil, ErrInvalidKEMPrivateKey
    91  	}
    92  	return sk, nil
    93  }
    94  
    95  func (x xKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
    96  	l := x.PublicKeySize()
    97  	if len(data) < l {
    98  		return nil, ErrInvalidKEMPublicKey
    99  	}
   100  	pk := &xKEMPubKey{x, make([]byte, l)}
   101  	copy(pk.pub, data[:l])
   102  	if !pk.validate() {
   103  		return nil, ErrInvalidKEMPublicKey
   104  	}
   105  	return pk, nil
   106  }
   107  
   108  type xKEMPubKey struct {
   109  	scheme xKEM
   110  	pub    []byte
   111  }
   112  
   113  func (k *xKEMPubKey) String() string     { return fmt.Sprintf("%x", k.pub) }
   114  func (k *xKEMPubKey) Scheme() kem.Scheme { return k.scheme }
   115  func (k *xKEMPubKey) MarshalBinary() ([]byte, error) {
   116  	return append(make([]byte, 0, k.scheme.PublicKeySize()), k.pub...), nil
   117  }
   118  
   119  func (k *xKEMPubKey) Equal(pk kem.PublicKey) bool {
   120  	k1, ok := pk.(*xKEMPubKey)
   121  	return ok &&
   122  		k.scheme.id == k1.scheme.id &&
   123  		bytes.Equal(k.pub, k1.pub)
   124  }
   125  func (k *xKEMPubKey) validate() bool { return len(k.pub) == k.scheme.PublicKeySize() }
   126  
   127  type xKEMPrivKey struct {
   128  	scheme xKEM
   129  	priv   []byte
   130  	pub    *xKEMPubKey
   131  }
   132  
   133  func (k *xKEMPrivKey) String() string     { return fmt.Sprintf("%x", k.priv) }
   134  func (k *xKEMPrivKey) Scheme() kem.Scheme { return k.scheme }
   135  func (k *xKEMPrivKey) MarshalBinary() ([]byte, error) {
   136  	return append(make([]byte, 0, k.scheme.PrivateKeySize()), k.priv...), nil
   137  }
   138  
   139  func (k *xKEMPrivKey) Equal(pk kem.PrivateKey) bool {
   140  	k1, ok := pk.(*xKEMPrivKey)
   141  	return ok &&
   142  		k.scheme.id == k1.scheme.id &&
   143  		subtle.ConstantTimeCompare(k.priv, k1.priv) == 1
   144  }
   145  
   146  func (k *xKEMPrivKey) Public() kem.PublicKey {
   147  	if k.pub == nil {
   148  		k.pub = &xKEMPubKey{scheme: k.scheme, pub: make([]byte, k.scheme.size)}
   149  		switch k.scheme.size {
   150  		case x25519.Size:
   151  			var sk, pk x25519.Key
   152  			copy(sk[:], k.priv)
   153  			x25519.KeyGen(&pk, &sk)
   154  			copy(k.pub.pub, pk[:])
   155  		case x448.Size:
   156  			var sk, pk x448.Key
   157  			copy(sk[:], k.priv)
   158  			x448.KeyGen(&pk, &sk)
   159  			copy(k.pub.pub, pk[:])
   160  		}
   161  	}
   162  	return k.pub
   163  }
   164  func (k *xKEMPrivKey) validate() bool { return len(k.priv) == k.scheme.PrivateKeySize() }