github.com/cloudflare/circl@v1.5.0/kem/hybrid/xkem.go (about)

     1  package hybrid
     2  
     3  import (
     4  	"bytes"
     5  	cryptoRand "crypto/rand"
     6  	"crypto/subtle"
     7  
     8  	"github.com/cloudflare/circl/dh/x25519"
     9  	"github.com/cloudflare/circl/dh/x448"
    10  	"github.com/cloudflare/circl/internal/sha3"
    11  	"github.com/cloudflare/circl/kem"
    12  )
    13  
    14  type xPublicKey struct {
    15  	scheme *xScheme
    16  	key    []byte
    17  }
    18  type xPrivateKey struct {
    19  	scheme *xScheme
    20  	key    []byte
    21  }
    22  type xScheme struct {
    23  	size int
    24  }
    25  
    26  var (
    27  	x25519Kem = &xScheme{x25519.Size}
    28  	x448Kem   = &xScheme{x448.Size}
    29  )
    30  
    31  func (sch *xScheme) Name() string {
    32  	switch sch.size {
    33  	case x25519.Size:
    34  		return "X25519"
    35  	case x448.Size:
    36  		return "X448"
    37  	}
    38  	panic(kem.ErrTypeMismatch)
    39  }
    40  
    41  func (sch *xScheme) PublicKeySize() int         { return sch.size }
    42  func (sch *xScheme) PrivateKeySize() int        { return sch.size }
    43  func (sch *xScheme) SeedSize() int              { return sch.size }
    44  func (sch *xScheme) SharedKeySize() int         { return sch.size }
    45  func (sch *xScheme) CiphertextSize() int        { return sch.size }
    46  func (sch *xScheme) EncapsulationSeedSize() int { return sch.size }
    47  
    48  func (sk *xPrivateKey) Scheme() kem.Scheme { return sk.scheme }
    49  func (pk *xPublicKey) Scheme() kem.Scheme  { return pk.scheme }
    50  
    51  func (sk *xPrivateKey) MarshalBinary() ([]byte, error) {
    52  	ret := make([]byte, len(sk.key))
    53  	copy(ret, sk.key)
    54  	return ret, nil
    55  }
    56  
    57  func (sk *xPrivateKey) Equal(other kem.PrivateKey) bool {
    58  	oth, ok := other.(*xPrivateKey)
    59  	if !ok {
    60  		return false
    61  	}
    62  	if oth.scheme != sk.scheme {
    63  		return false
    64  	}
    65  	return subtle.ConstantTimeCompare(oth.key, sk.key) == 1
    66  }
    67  
    68  func (sk *xPrivateKey) Public() kem.PublicKey {
    69  	pk := xPublicKey{sk.scheme, make([]byte, sk.scheme.size)}
    70  	switch sk.scheme.size {
    71  	case x25519.Size:
    72  		var sk2, pk2 x25519.Key
    73  		copy(sk2[:], sk.key)
    74  		x25519.KeyGen(&pk2, &sk2)
    75  		copy(pk.key, pk2[:])
    76  	case x448.Size:
    77  		var sk2, pk2 x448.Key
    78  		copy(sk2[:], sk.key)
    79  		x448.KeyGen(&pk2, &sk2)
    80  		copy(pk.key, pk2[:])
    81  	}
    82  	return &pk
    83  }
    84  
    85  func (pk *xPublicKey) Equal(other kem.PublicKey) bool {
    86  	oth, ok := other.(*xPublicKey)
    87  	if !ok {
    88  		return false
    89  	}
    90  	if oth.scheme != pk.scheme {
    91  		return false
    92  	}
    93  	return bytes.Equal(oth.key, pk.key)
    94  }
    95  
    96  func (pk *xPublicKey) MarshalBinary() ([]byte, error) {
    97  	ret := make([]byte, pk.scheme.size)
    98  	copy(ret, pk.key)
    99  	return ret, nil
   100  }
   101  
   102  func (sch *xScheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
   103  	seed := make([]byte, sch.SeedSize())
   104  	_, err := cryptoRand.Read(seed)
   105  	if err != nil {
   106  		return nil, nil, err
   107  	}
   108  	pk, sk := sch.DeriveKeyPair(seed)
   109  	return pk, sk, nil
   110  }
   111  
   112  func (sch *xScheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
   113  	if len(seed) != sch.SeedSize() {
   114  		panic(kem.ErrSeedSize)
   115  	}
   116  	sk := xPrivateKey{scheme: sch, key: make([]byte, sch.size)}
   117  
   118  	h := sha3.NewShake256()
   119  	_, _ = h.Write(seed)
   120  	_, _ = h.Read(sk.key)
   121  
   122  	return sk.Public(), &sk
   123  }
   124  
   125  func (sch *xScheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) {
   126  	seed := make([]byte, sch.EncapsulationSeedSize())
   127  	_, err = cryptoRand.Read(seed)
   128  	if err != nil {
   129  		return
   130  	}
   131  	return sch.EncapsulateDeterministically(pk, seed)
   132  }
   133  
   134  func (pk *xPublicKey) X(sk *xPrivateKey) []byte {
   135  	if pk.scheme != sk.scheme {
   136  		panic(kem.ErrTypeMismatch)
   137  	}
   138  
   139  	switch pk.scheme.size {
   140  	case x25519.Size:
   141  		var ss2, pk2, sk2 x25519.Key
   142  		copy(pk2[:], pk.key)
   143  		copy(sk2[:], sk.key)
   144  		x25519.Shared(&ss2, &sk2, &pk2)
   145  		return ss2[:]
   146  	case x448.Size:
   147  		var ss2, pk2, sk2 x448.Key
   148  		copy(pk2[:], pk.key)
   149  		copy(sk2[:], sk.key)
   150  		x448.Shared(&ss2, &sk2, &pk2)
   151  		return ss2[:]
   152  	}
   153  	panic(kem.ErrTypeMismatch)
   154  }
   155  
   156  func (sch *xScheme) EncapsulateDeterministically(
   157  	pk kem.PublicKey, seed []byte,
   158  ) (ct, ss []byte, err error) {
   159  	if len(seed) != sch.EncapsulationSeedSize() {
   160  		return nil, nil, kem.ErrSeedSize
   161  	}
   162  	pub, ok := pk.(*xPublicKey)
   163  	if !ok || pub.scheme != sch {
   164  		return nil, nil, kem.ErrTypeMismatch
   165  	}
   166  
   167  	pk2, sk2 := sch.DeriveKeyPair(seed)
   168  	ss = pub.X(sk2.(*xPrivateKey))
   169  	ct, _ = pk2.MarshalBinary()
   170  	return
   171  }
   172  
   173  func (sch *xScheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) {
   174  	if len(ct) != sch.CiphertextSize() {
   175  		return nil, kem.ErrCiphertextSize
   176  	}
   177  
   178  	priv, ok := sk.(*xPrivateKey)
   179  	if !ok || priv.scheme != sch {
   180  		return nil, kem.ErrTypeMismatch
   181  	}
   182  
   183  	pk, err := sch.UnmarshalBinaryPublicKey(ct)
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  
   188  	ss := pk.(*xPublicKey).X(priv)
   189  	return ss, nil
   190  }
   191  
   192  func (sch *xScheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) {
   193  	if len(buf) != sch.PublicKeySize() {
   194  		return nil, kem.ErrPubKeySize
   195  	}
   196  	ret := xPublicKey{sch, make([]byte, sch.size)}
   197  	copy(ret.key, buf)
   198  	return &ret, nil
   199  }
   200  
   201  func (sch *xScheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) {
   202  	if len(buf) != sch.PrivateKeySize() {
   203  		return nil, kem.ErrPrivKeySize
   204  	}
   205  	ret := xPrivateKey{sch, make([]byte, sch.size)}
   206  	copy(ret.key, buf)
   207  	return &ret, nil
   208  }