github.com/cloudflare/circl@v1.5.0/kem/mlkem/mlkem1024/kyber.go (about)

     1  // Code generated from pkg.templ.go. DO NOT EDIT.
     2  
     3  // Package mlkem1024 implements the IND-CCA2 secure key encapsulation mechanism
     4  // ML-KEM-1024 as defined in FIPS203.
     5  package mlkem1024
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/subtle"
    10  	"io"
    11  
    12  	cryptoRand "crypto/rand"
    13  	"github.com/cloudflare/circl/internal/sha3"
    14  	"github.com/cloudflare/circl/kem"
    15  	cpapke "github.com/cloudflare/circl/pke/kyber/kyber1024"
    16  )
    17  
    18  const (
    19  	// Size of seed for NewKeyFromSeed
    20  	KeySeedSize = cpapke.KeySeedSize + 32
    21  
    22  	// Size of seed for EncapsulateTo.
    23  	EncapsulationSeedSize = 32
    24  
    25  	// Size of the established shared key.
    26  	SharedKeySize = 32
    27  
    28  	// Size of the encapsulated shared key.
    29  	CiphertextSize = cpapke.CiphertextSize
    30  
    31  	// Size of a packed public key.
    32  	PublicKeySize = cpapke.PublicKeySize
    33  
    34  	// Size of a packed private key.
    35  	PrivateKeySize = cpapke.PrivateKeySize + cpapke.PublicKeySize + 64
    36  )
    37  
    38  // Type of a ML-KEM-1024 public key
    39  type PublicKey struct {
    40  	pk *cpapke.PublicKey
    41  
    42  	hpk [32]byte // H(pk)
    43  }
    44  
    45  // Type of a ML-KEM-1024 private key
    46  type PrivateKey struct {
    47  	sk  *cpapke.PrivateKey
    48  	pk  *cpapke.PublicKey
    49  	hpk [32]byte // H(pk)
    50  	z   [32]byte
    51  }
    52  
    53  // NewKeyFromSeed derives a public/private keypair deterministically
    54  // from the given seed.
    55  //
    56  // Panics if seed is not of length KeySeedSize.
    57  func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) {
    58  	var sk PrivateKey
    59  	var pk PublicKey
    60  
    61  	if len(seed) != KeySeedSize {
    62  		panic("seed must be of length KeySeedSize")
    63  	}
    64  
    65  	pk.pk, sk.sk = cpapke.NewKeyFromSeedMLKEM(seed[:cpapke.KeySeedSize])
    66  	sk.pk = pk.pk
    67  	copy(sk.z[:], seed[cpapke.KeySeedSize:])
    68  
    69  	// Compute H(pk)
    70  	var ppk [cpapke.PublicKeySize]byte
    71  	sk.pk.Pack(ppk[:])
    72  	h := sha3.New256()
    73  	h.Write(ppk[:])
    74  	h.Read(sk.hpk[:])
    75  	copy(pk.hpk[:], sk.hpk[:])
    76  
    77  	return &pk, &sk
    78  }
    79  
    80  // GenerateKeyPair generates public and private keys using entropy from rand.
    81  // If rand is nil, crypto/rand.Reader will be used.
    82  func GenerateKeyPair(rand io.Reader) (*PublicKey, *PrivateKey, error) {
    83  	var seed [KeySeedSize]byte
    84  	if rand == nil {
    85  		rand = cryptoRand.Reader
    86  	}
    87  	_, err := io.ReadFull(rand, seed[:])
    88  	if err != nil {
    89  		return nil, nil, err
    90  	}
    91  	pk, sk := NewKeyFromSeed(seed[:])
    92  	return pk, sk, nil
    93  }
    94  
    95  // EncapsulateTo generates a shared key and ciphertext that contains it
    96  // for the public key using randomness from seed and writes the shared key
    97  // to ss and ciphertext to ct.
    98  //
    99  // Panics if ss, ct or seed are not of length SharedKeySize, CiphertextSize
   100  // and EncapsulationSeedSize respectively.
   101  //
   102  // seed may be nil, in which case crypto/rand.Reader is used to generate one.
   103  func (pk *PublicKey) EncapsulateTo(ct, ss []byte, seed []byte) {
   104  	if seed == nil {
   105  		seed = make([]byte, EncapsulationSeedSize)
   106  		if _, err := cryptoRand.Read(seed[:]); err != nil {
   107  			panic(err)
   108  		}
   109  	} else {
   110  		if len(seed) != EncapsulationSeedSize {
   111  			panic("seed must be of length EncapsulationSeedSize")
   112  		}
   113  	}
   114  
   115  	if len(ct) != CiphertextSize {
   116  		panic("ct must be of length CiphertextSize")
   117  	}
   118  
   119  	if len(ss) != SharedKeySize {
   120  		panic("ss must be of length SharedKeySize")
   121  	}
   122  
   123  	var m [32]byte
   124  	copy(m[:], seed)
   125  
   126  	// (K', r) = G(m ‖ H(pk))
   127  	var kr [64]byte
   128  	g := sha3.New512()
   129  	g.Write(m[:])
   130  	g.Write(pk.hpk[:])
   131  	g.Read(kr[:])
   132  
   133  	// c = Kyber.CPAPKE.Enc(pk, m, r)
   134  	pk.pk.EncryptTo(ct, m[:], kr[32:])
   135  
   136  	copy(ss, kr[:SharedKeySize])
   137  }
   138  
   139  // DecapsulateTo computes the shared key which is encapsulated in ct
   140  // for the private key.
   141  //
   142  // Panics if ct or ss are not of length CiphertextSize and SharedKeySize
   143  // respectively.
   144  func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) {
   145  	if len(ct) != CiphertextSize {
   146  		panic("ct must be of length CiphertextSize")
   147  	}
   148  
   149  	if len(ss) != SharedKeySize {
   150  		panic("ss must be of length SharedKeySize")
   151  	}
   152  
   153  	// m' = Kyber.CPAPKE.Dec(sk, ct)
   154  	var m2 [32]byte
   155  	sk.sk.DecryptTo(m2[:], ct)
   156  
   157  	// (K'', r') = G(m' ‖ H(pk))
   158  	var kr2 [64]byte
   159  	g := sha3.New512()
   160  	g.Write(m2[:])
   161  	g.Write(sk.hpk[:])
   162  	g.Read(kr2[:])
   163  
   164  	// c' = Kyber.CPAPKE.Enc(pk, m', r')
   165  	var ct2 [CiphertextSize]byte
   166  	sk.pk.EncryptTo(ct2[:], m2[:], kr2[32:])
   167  
   168  	var ss2 [SharedKeySize]byte
   169  
   170  	// Compute shared secret in case of rejection: ss₂ = PRF(z ‖ c)
   171  	prf := sha3.NewShake256()
   172  	prf.Write(sk.z[:])
   173  	prf.Write(ct[:CiphertextSize])
   174  	prf.Read(ss2[:])
   175  
   176  	// Set ss2 to the real shared secret if c = c'.
   177  	subtle.ConstantTimeCopy(
   178  		subtle.ConstantTimeCompare(ct, ct2[:]),
   179  		ss2[:],
   180  		kr2[:SharedKeySize],
   181  	)
   182  
   183  	copy(ss, ss2[:])
   184  }
   185  
   186  // Packs sk to buf.
   187  //
   188  // Panics if buf is not of size PrivateKeySize.
   189  func (sk *PrivateKey) Pack(buf []byte) {
   190  	if len(buf) != PrivateKeySize {
   191  		panic("buf must be of length PrivateKeySize")
   192  	}
   193  
   194  	sk.sk.Pack(buf[:cpapke.PrivateKeySize])
   195  	buf = buf[cpapke.PrivateKeySize:]
   196  	sk.pk.Pack(buf[:cpapke.PublicKeySize])
   197  	buf = buf[cpapke.PublicKeySize:]
   198  	copy(buf, sk.hpk[:])
   199  	buf = buf[32:]
   200  	copy(buf, sk.z[:])
   201  }
   202  
   203  // Unpacks sk from buf.
   204  //
   205  // Panics if buf is not of size PrivateKeySize.
   206  //
   207  // Returns an error if buf is not of size PrivateKeySize, or private key
   208  // doesn't pass the ML-KEM decapsulation key check.
   209  func (sk *PrivateKey) Unpack(buf []byte) error {
   210  	if len(buf) != PrivateKeySize {
   211  		return kem.ErrPrivKeySize
   212  	}
   213  
   214  	sk.sk = new(cpapke.PrivateKey)
   215  	sk.sk.Unpack(buf[:cpapke.PrivateKeySize])
   216  	buf = buf[cpapke.PrivateKeySize:]
   217  	sk.pk = new(cpapke.PublicKey)
   218  	sk.pk.Unpack(buf[:cpapke.PublicKeySize])
   219  	var hpk [32]byte
   220  	h := sha3.New256()
   221  	h.Write(buf[:cpapke.PublicKeySize])
   222  	h.Read(hpk[:])
   223  	buf = buf[cpapke.PublicKeySize:]
   224  	copy(sk.hpk[:], buf[:32])
   225  	copy(sk.z[:], buf[32:])
   226  	if !bytes.Equal(hpk[:], sk.hpk[:]) {
   227  		return kem.ErrPrivKey
   228  	}
   229  	return nil
   230  }
   231  
   232  // Packs pk to buf.
   233  //
   234  // Panics if buf is not of size PublicKeySize.
   235  func (pk *PublicKey) Pack(buf []byte) {
   236  	if len(buf) != PublicKeySize {
   237  		panic("buf must be of length PublicKeySize")
   238  	}
   239  
   240  	pk.pk.Pack(buf)
   241  }
   242  
   243  // Unpacks pk from buf.
   244  //
   245  // Returns an error if buf is not of size PublicKeySize, or the public key
   246  // is not normalized.
   247  func (pk *PublicKey) Unpack(buf []byte) error {
   248  	if len(buf) != PublicKeySize {
   249  		return kem.ErrPubKeySize
   250  	}
   251  
   252  	pk.pk = new(cpapke.PublicKey)
   253  	if err := pk.pk.UnpackMLKEM(buf); err != nil {
   254  		return err
   255  	}
   256  
   257  	// Compute cached H(pk)
   258  	h := sha3.New256()
   259  	h.Write(buf)
   260  	h.Read(pk.hpk[:])
   261  
   262  	return nil
   263  }
   264  
   265  // Boilerplate down below for the KEM scheme API.
   266  
   267  type scheme struct{}
   268  
   269  var sch kem.Scheme = &scheme{}
   270  
   271  // Scheme returns a KEM interface.
   272  func Scheme() kem.Scheme { return sch }
   273  
   274  func (*scheme) Name() string               { return "ML-KEM-1024" }
   275  func (*scheme) PublicKeySize() int         { return PublicKeySize }
   276  func (*scheme) PrivateKeySize() int        { return PrivateKeySize }
   277  func (*scheme) SeedSize() int              { return KeySeedSize }
   278  func (*scheme) SharedKeySize() int         { return SharedKeySize }
   279  func (*scheme) CiphertextSize() int        { return CiphertextSize }
   280  func (*scheme) EncapsulationSeedSize() int { return EncapsulationSeedSize }
   281  
   282  func (sk *PrivateKey) Scheme() kem.Scheme { return sch }
   283  func (pk *PublicKey) Scheme() kem.Scheme  { return sch }
   284  
   285  func (sk *PrivateKey) MarshalBinary() ([]byte, error) {
   286  	var ret [PrivateKeySize]byte
   287  	sk.Pack(ret[:])
   288  	return ret[:], nil
   289  }
   290  
   291  func (sk *PrivateKey) Equal(other kem.PrivateKey) bool {
   292  	oth, ok := other.(*PrivateKey)
   293  	if !ok {
   294  		return false
   295  	}
   296  	if sk.pk == nil && oth.pk == nil {
   297  		return true
   298  	}
   299  	if sk.pk == nil || oth.pk == nil {
   300  		return false
   301  	}
   302  	if !bytes.Equal(sk.hpk[:], oth.hpk[:]) ||
   303  		subtle.ConstantTimeCompare(sk.z[:], oth.z[:]) != 1 {
   304  		return false
   305  	}
   306  	return sk.sk.Equal(oth.sk)
   307  }
   308  
   309  func (pk *PublicKey) Equal(other kem.PublicKey) bool {
   310  	oth, ok := other.(*PublicKey)
   311  	if !ok {
   312  		return false
   313  	}
   314  	if pk.pk == nil && oth.pk == nil {
   315  		return true
   316  	}
   317  	if pk.pk == nil || oth.pk == nil {
   318  		return false
   319  	}
   320  	return bytes.Equal(pk.hpk[:], oth.hpk[:])
   321  }
   322  
   323  func (sk *PrivateKey) Public() kem.PublicKey {
   324  	pk := new(PublicKey)
   325  	pk.pk = sk.pk
   326  	copy(pk.hpk[:], sk.hpk[:])
   327  	return pk
   328  }
   329  
   330  func (pk *PublicKey) MarshalBinary() ([]byte, error) {
   331  	var ret [PublicKeySize]byte
   332  	pk.Pack(ret[:])
   333  	return ret[:], nil
   334  }
   335  
   336  func (*scheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
   337  	return GenerateKeyPair(cryptoRand.Reader)
   338  }
   339  
   340  func (*scheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
   341  	if len(seed) != KeySeedSize {
   342  		panic(kem.ErrSeedSize)
   343  	}
   344  	return NewKeyFromSeed(seed[:])
   345  }
   346  
   347  func (*scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) {
   348  	ct = make([]byte, CiphertextSize)
   349  	ss = make([]byte, SharedKeySize)
   350  
   351  	pub, ok := pk.(*PublicKey)
   352  	if !ok {
   353  		return nil, nil, kem.ErrTypeMismatch
   354  	}
   355  	pub.EncapsulateTo(ct, ss, nil)
   356  	return
   357  }
   358  
   359  func (*scheme) EncapsulateDeterministically(pk kem.PublicKey, seed []byte) (
   360  	ct, ss []byte, err error) {
   361  	if len(seed) != EncapsulationSeedSize {
   362  		return nil, nil, kem.ErrSeedSize
   363  	}
   364  
   365  	ct = make([]byte, CiphertextSize)
   366  	ss = make([]byte, SharedKeySize)
   367  
   368  	pub, ok := pk.(*PublicKey)
   369  	if !ok {
   370  		return nil, nil, kem.ErrTypeMismatch
   371  	}
   372  	pub.EncapsulateTo(ct, ss, seed)
   373  	return
   374  }
   375  
   376  func (*scheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) {
   377  	if len(ct) != CiphertextSize {
   378  		return nil, kem.ErrCiphertextSize
   379  	}
   380  
   381  	priv, ok := sk.(*PrivateKey)
   382  	if !ok {
   383  		return nil, kem.ErrTypeMismatch
   384  	}
   385  	ss := make([]byte, SharedKeySize)
   386  	priv.DecapsulateTo(ss, ct)
   387  	return ss, nil
   388  }
   389  
   390  func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) {
   391  	var ret PublicKey
   392  	if err := ret.Unpack(buf); err != nil {
   393  		return nil, err
   394  	}
   395  	return &ret, nil
   396  }
   397  
   398  func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) {
   399  	if len(buf) != PrivateKeySize {
   400  		return nil, kem.ErrPrivKeySize
   401  	}
   402  	var ret PrivateKey
   403  	if err := ret.Unpack(buf); err != nil {
   404  		return nil, err
   405  	}
   406  	return &ret, nil
   407  }