github.com/cloudflare/circl@v1.5.0/kem/kyber/templates/pkg.templ.go (about)

     1  // +build ignore
     2  // The previous line (and this one up to the warning below) is removed by the
     3  // template generator.
     4  
     5  // Code generated from pkg.templ.go. DO NOT EDIT.
     6  
     7  // Package {{.Pkg}} implements the IND-CCA2 secure key encapsulation mechanism
     8  {{ if .NIST -}}
     9  // {{.KemName}} as defined in FIPS203.
    10  {{- else -}}
    11  // {{.KemName}} as submitted to round 3 of the NIST PQC competition and
    12  // described in
    13  //
    14  // https://pq-crystals.org/kyber/data/kyber-specification-round3.pdf
    15  {{- end }}
    16  package {{.Pkg}}
    17  
    18  import (
    19  	"bytes"
    20  	"crypto/subtle"
    21  	"io"
    22  
    23  	"github.com/cloudflare/circl/internal/sha3"
    24  	"github.com/cloudflare/circl/kem"
    25  	cpapke "github.com/cloudflare/circl/pke/kyber/{{.PkePkg}}"
    26  	cryptoRand "crypto/rand"
    27  )
    28  
// Sizes, in bytes, of the byte slices consumed and produced by this package.
const (
	// KeySeedSize is the size of the seed for NewKeyFromSeed: the seed of
	// the underlying CPAPKE followed by the 32 bytes that become z.
	KeySeedSize = cpapke.KeySeedSize + 32

	// EncapsulationSeedSize is the size of the seed for EncapsulateTo.
	EncapsulationSeedSize = 32

	// SharedKeySize is the size of the established shared key.
	SharedKeySize = 32

	// CiphertextSize is the size of the encapsulated shared key.
	CiphertextSize = cpapke.CiphertextSize

	// PublicKeySize is the size of a packed public key.
	PublicKeySize = cpapke.PublicKeySize

	// PrivateKeySize is the size of a packed private key: the packed CPAPKE
	// private key, the packed CPAPKE public key, H(pk) (32 bytes) and
	// z (32 bytes).
	PrivateKeySize = cpapke.PrivateKeySize + cpapke.PublicKeySize + 64
)
    48  
// PublicKey is the type of a {{.KemName}} public key.
//
// It wraps the underlying CPAPKE public key together with a cached hash of
// its packed encoding, which is used on every encapsulation.
type PublicKey struct {
	pk *cpapke.PublicKey // underlying CPAPKE public key

	hpk [32]byte // H(pk), hash of the packed public key
}
    55  
// PrivateKey is the type of a {{.KemName}} private key.
//
// Besides the CPAPKE private key it carries the matching public key and its
// hash, both of which are needed to re-encrypt during decapsulation.
type PrivateKey struct {
	sk  *cpapke.PrivateKey // underlying CPAPKE private key
	pk  *cpapke.PublicKey  // matching CPAPKE public key
	hpk [32]byte           // H(pk)
	z   [32]byte           // trailing 32 bytes of the key seed; used when decapsulation rejects a ciphertext
}
    63  
    64  // NewKeyFromSeed derives a public/private keypair deterministically
    65  // from the given seed.
    66  //
    67  // Panics if seed is not of length KeySeedSize.
    68  func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) {
    69  	var sk PrivateKey
    70  	var pk PublicKey
    71  
    72  	if len(seed) != KeySeedSize {
    73  		panic("seed must be of length KeySeedSize")
    74  	}
    75  
    76  	{{ if .NIST -}}
    77  	pk.pk, sk.sk = cpapke.NewKeyFromSeedMLKEM(seed[:cpapke.KeySeedSize])
    78  	{{- else -}}
    79  	pk.pk, sk.sk = cpapke.NewKeyFromSeed(seed[:cpapke.KeySeedSize])
    80  	{{- end }}
    81  	sk.pk = pk.pk
    82  	copy(sk.z[:], seed[cpapke.KeySeedSize:])
    83  
    84  	// Compute H(pk)
    85  	var ppk [cpapke.PublicKeySize]byte
    86  	sk.pk.Pack(ppk[:])
    87  	h := sha3.New256()
    88  	h.Write(ppk[:])
    89  	h.Read(sk.hpk[:])
    90  	copy(pk.hpk[:], sk.hpk[:])
    91  
    92  	return &pk, &sk
    93  }
    94  
    95  // GenerateKeyPair generates public and private keys using entropy from rand.
    96  // If rand is nil, crypto/rand.Reader will be used.
    97  func GenerateKeyPair(rand io.Reader) (*PublicKey, *PrivateKey, error) {
    98  	var seed [KeySeedSize]byte
    99  	if rand == nil {
   100  		rand = cryptoRand.Reader
   101  	}
   102  	_, err := io.ReadFull(rand, seed[:])
   103  	if err != nil {
   104  		return nil, nil, err
   105  	}
   106  	pk, sk := NewKeyFromSeed(seed[:])
   107  	return pk, sk, nil
   108  }
   109  
   110  // EncapsulateTo generates a shared key and ciphertext that contains it
   111  // for the public key using randomness from seed and writes the shared key
   112  // to ss and ciphertext to ct.
   113  //
   114  // Panics if ss, ct or seed are not of length SharedKeySize, CiphertextSize
   115  // and EncapsulationSeedSize respectively.
   116  //
   117  // seed may be nil, in which case crypto/rand.Reader is used to generate one.
   118  func (pk *PublicKey) EncapsulateTo(ct, ss []byte, seed []byte) {
   119  	if seed == nil {
   120  		seed = make([]byte, EncapsulationSeedSize)
   121  		if _, err := cryptoRand.Read(seed[:]); err != nil {
   122  			panic(err)
   123  		}
   124  	} else {
   125  		if len(seed) != EncapsulationSeedSize {
   126  			panic("seed must be of length EncapsulationSeedSize")
   127  		}
   128  	}
   129  
   130  	if len(ct) != CiphertextSize {
   131  		panic("ct must be of length CiphertextSize")
   132  	}
   133  
   134  	if len(ss) != SharedKeySize {
   135  		panic("ss must be of length SharedKeySize")
   136  	}
   137  
   138  	var m [32]byte
   139  	{{ if .NIST -}}
   140  	copy(m[:], seed)
   141  	{{- else -}}
   142  	// m = H(seed), the hash of shame
   143  	h := sha3.New256()
   144  	h.Write(seed)
   145  	h.Read(m[:])
   146  	{{- end }}
   147  
   148  	// (K', r) = G(m ‖ H(pk))
   149  	var kr [64]byte
   150  	g := sha3.New512()
   151  	g.Write(m[:])
   152  	g.Write(pk.hpk[:])
   153  	g.Read(kr[:])
   154  
   155  	// c = Kyber.CPAPKE.Enc(pk, m, r)
   156  	pk.pk.EncryptTo(ct, m[:], kr[32:])
   157  
   158  	{{ if .NIST -}}
   159  	copy(ss, kr[:SharedKeySize])
   160  	{{- else -}}
   161  	// Compute H(c) and put in second slot of kr, which will be (K', H(c)).
   162  	h.Reset()
   163  	h.Write(ct[:CiphertextSize])
   164  	h.Read(kr[32:])
   165  
   166  	// K = KDF(K' ‖ H(c))
   167  	kdf := sha3.NewShake256()
   168  	kdf.Write(kr[:])
   169  	kdf.Read(ss[:SharedKeySize])
   170  	{{- end }}
   171  }
   172  
   173  // DecapsulateTo computes the shared key which is encapsulated in ct
   174  // for the private key.
   175  //
   176  // Panics if ct or ss are not of length CiphertextSize and SharedKeySize
   177  // respectively.
   178  func (sk *PrivateKey) DecapsulateTo(ss, ct []byte) {
   179  	if len(ct) != CiphertextSize {
   180  		panic("ct must be of length CiphertextSize")
   181  	}
   182  
   183  	if len(ss) != SharedKeySize {
   184  		panic("ss must be of length SharedKeySize")
   185  	}
   186  
   187  	// m' = Kyber.CPAPKE.Dec(sk, ct)
   188  	var m2 [32]byte
   189  	sk.sk.DecryptTo(m2[:], ct)
   190  
   191  	// (K'', r') = G(m' ‖ H(pk))
   192  	var kr2 [64]byte
   193  	g := sha3.New512()
   194  	g.Write(m2[:])
   195  	g.Write(sk.hpk[:])
   196  	g.Read(kr2[:])
   197  
   198  	// c' = Kyber.CPAPKE.Enc(pk, m', r')
   199  	var ct2 [CiphertextSize]byte
   200  	sk.pk.EncryptTo(ct2[:], m2[:], kr2[32:])
   201  
   202  	{{ if .NIST -}}
   203  	var ss2 [SharedKeySize]byte
   204  
   205  	// Compute shared secret in case of rejection: ss₂ = PRF(z ‖ c)
   206  	prf := sha3.NewShake256()
   207  	prf.Write(sk.z[:])
   208  	prf.Write(ct[:CiphertextSize])
   209  	prf.Read(ss2[:])
   210  
   211  	// Set ss2 to the real shared secret if c = c'.
   212  	subtle.ConstantTimeCopy(
   213  		subtle.ConstantTimeCompare(ct, ct2[:]),
   214  		ss2[:],
   215  		kr2[:SharedKeySize],
   216  	)
   217  
   218  	copy(ss, ss2[:])
   219  	{{- else -}}
   220  	// Compute H(c) and put in second slot of kr2, which will be (K'', H(c)).
   221  	h := sha3.New256()
   222  	h.Write(ct[:CiphertextSize])
   223  	h.Read(kr2[32:])
   224  
   225  	// Replace K'' by  z in the first slot of kr2 if c ≠ c'.
   226  	subtle.ConstantTimeCopy(
   227  		1-subtle.ConstantTimeCompare(ct, ct2[:]),
   228  		kr2[:32],
   229  		sk.z[:],
   230  	)
   231  
   232  	// K = KDF(K''/z, H(c))
   233  	kdf := sha3.NewShake256()
   234  	kdf.Write(kr2[:])
   235  	kdf.Read(ss)
   236  	{{- end }}
   237  }
   238  
   239  // Packs sk to buf.
   240  //
   241  // Panics if buf is not of size PrivateKeySize.
   242  func (sk *PrivateKey) Pack(buf []byte) {
   243  	if len(buf) != PrivateKeySize {
   244  		panic("buf must be of length PrivateKeySize")
   245  	}
   246  
   247  	sk.sk.Pack(buf[:cpapke.PrivateKeySize])
   248  	buf = buf[cpapke.PrivateKeySize:]
   249  	sk.pk.Pack(buf[:cpapke.PublicKeySize])
   250  	buf = buf[cpapke.PublicKeySize:]
   251  	copy(buf, sk.hpk[:])
   252  	buf = buf[32:]
   253  	copy(buf, sk.z[:])
   254  }
   255  
   256  // Unpacks sk from buf.
   257  //
   258  // Panics if buf is not of size PrivateKeySize.
   259  {{ if .NIST -}}
   260  //
   261  // Returns an error if buf is not of size PrivateKeySize, or private key
   262  // doesn't pass the ML-KEM decapsulation key check.
   263  func (sk *PrivateKey) Unpack(buf []byte) error {
   264  	if len(buf) != PrivateKeySize {
   265  		return kem.ErrPrivKeySize
   266  	}
   267  {{- else -}}
   268  func (sk *PrivateKey) Unpack(buf []byte) {
   269  	if len(buf) != PrivateKeySize {
   270  		panic("buf must be of length PrivateKeySize")
   271  	}
   272  {{- end }}
   273  
   274  	sk.sk = new(cpapke.PrivateKey)
   275  	sk.sk.Unpack(buf[:cpapke.PrivateKeySize])
   276  	buf = buf[cpapke.PrivateKeySize:]
   277  	sk.pk = new(cpapke.PublicKey)
   278  	sk.pk.Unpack(buf[:cpapke.PublicKeySize])
   279  {{ if .NIST -}}
   280  	var hpk [32]byte
   281  	h := sha3.New256()
   282  	h.Write(buf[:cpapke.PublicKeySize])
   283  	h.Read(hpk[:])
   284  {{ end -}}
   285  	buf = buf[cpapke.PublicKeySize:]
   286  	copy(sk.hpk[:], buf[:32])
   287  	copy(sk.z[:], buf[32:])
   288  {{ if .NIST -}}
   289  	if !bytes.Equal(hpk[:], sk.hpk[:]) {
   290  		return kem.ErrPrivKey
   291  	}
   292  	return nil
   293  {{ end -}}
   294  }
   295  
   296  // Packs pk to buf.
   297  //
   298  // Panics if buf is not of size PublicKeySize.
   299  func (pk *PublicKey) Pack(buf []byte) {
   300  	if len(buf) != PublicKeySize {
   301  		panic("buf must be of length PublicKeySize")
   302  	}
   303  
   304  	pk.pk.Pack(buf)
   305  }
   306  
   307  // Unpacks pk from buf.
   308  //
   309  {{ if .NIST -}}
   310  // Returns an error if buf is not of size PublicKeySize, or the public key
   311  // is not normalized.
   312  func (pk *PublicKey) Unpack(buf []byte) error {
   313  	if len(buf) != PublicKeySize {
   314  		return kem.ErrPubKeySize
   315  	}
   316  {{- else -}}
   317  // Panics if buf is not of size PublicKeySize.
   318  func (pk *PublicKey) Unpack(buf []byte) {
   319  	if len(buf) != PublicKeySize {
   320  		panic("buf must be of length PublicKeySize")
   321  	}
   322  {{- end }}
   323  
   324  	pk.pk = new(cpapke.PublicKey)
   325  	{{ if .NIST -}}
   326  	if err := pk.pk.UnpackMLKEM(buf); err != nil {
   327  		return err
   328  	}
   329  	{{- else -}}
   330  	pk.pk.Unpack(buf)
   331  	{{- end }}
   332  
   333  	// Compute cached H(pk)
   334  	h := sha3.New256()
   335  	h.Write(buf)
   336  	h.Read(pk.hpk[:])
   337  
   338  	{{ if .NIST -}}
   339  	return nil
   340  	{{- end }}
   341  }
   342  
   343  // Boilerplate down below for the KEM scheme API.
   344  
   345  type scheme struct{}
   346  
   347  var sch kem.Scheme = &scheme{}
   348  
   349  // Scheme returns a KEM interface.
   350  func Scheme() kem.Scheme { return sch }
   351  
   352  func (*scheme) Name() string               { return "{{.Name}}" }
   353  func (*scheme) PublicKeySize() int         { return PublicKeySize }
   354  func (*scheme) PrivateKeySize() int        { return PrivateKeySize }
   355  func (*scheme) SeedSize() int              { return KeySeedSize }
   356  func (*scheme) SharedKeySize() int         { return SharedKeySize }
   357  func (*scheme) CiphertextSize() int        { return CiphertextSize }
   358  func (*scheme) EncapsulationSeedSize() int { return EncapsulationSeedSize }
   359  
   360  func (sk *PrivateKey) Scheme() kem.Scheme { return sch }
   361  func (pk *PublicKey) Scheme() kem.Scheme  { return sch }
   362  
   363  func (sk *PrivateKey) MarshalBinary() ([]byte, error) {
   364  	var ret [PrivateKeySize]byte
   365  	sk.Pack(ret[:])
   366  	return ret[:], nil
   367  }
   368  
   369  func (sk *PrivateKey) Equal(other kem.PrivateKey) bool {
   370  	oth, ok := other.(*PrivateKey)
   371  	if !ok {
   372  		return false
   373  	}
   374  	if sk.pk == nil && oth.pk == nil {
   375  		return true
   376  	}
   377  	if sk.pk == nil || oth.pk == nil {
   378  		return false
   379  	}
   380  	if !bytes.Equal(sk.hpk[:], oth.hpk[:]) ||
   381  		subtle.ConstantTimeCompare(sk.z[:], oth.z[:]) != 1 {
   382  		return false
   383  	}
   384  	return sk.sk.Equal(oth.sk)
   385  }
   386  
   387  func (pk *PublicKey) Equal(other kem.PublicKey) bool {
   388  	oth, ok := other.(*PublicKey)
   389  	if !ok {
   390  		return false
   391  	}
   392  	if pk.pk == nil && oth.pk == nil {
   393  		return true
   394  	}
   395  	if pk.pk == nil || oth.pk == nil {
   396  		return false
   397  	}
   398  	return bytes.Equal(pk.hpk[:], oth.hpk[:])
   399  }
   400  
   401  func (sk *PrivateKey) Public() kem.PublicKey {
   402  	pk := new(PublicKey)
   403  	pk.pk = sk.pk
   404  	copy(pk.hpk[:], sk.hpk[:])
   405  	return pk
   406  }
   407  
   408  func (pk *PublicKey) MarshalBinary() ([]byte, error) {
   409  	var ret [PublicKeySize]byte
   410  	pk.Pack(ret[:])
   411  	return ret[:], nil
   412  }
   413  
   414  func (*scheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
   415  	return GenerateKeyPair(cryptoRand.Reader)
   416  }
   417  
   418  func (*scheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
   419  	if len(seed) != KeySeedSize {
   420  		panic(kem.ErrSeedSize)
   421  	}
   422  	return NewKeyFromSeed(seed[:])
   423  }
   424  
   425  func (*scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) {
   426  	ct = make([]byte, CiphertextSize)
   427  	ss = make([]byte, SharedKeySize)
   428  
   429  	pub, ok := pk.(*PublicKey)
   430  	if !ok {
   431  		return nil, nil, kem.ErrTypeMismatch
   432  	}
   433  	pub.EncapsulateTo(ct, ss, nil)
   434  	return
   435  }
   436  
   437  func (*scheme) EncapsulateDeterministically(pk kem.PublicKey, seed []byte) (
   438  	ct, ss []byte, err error) {
   439  	if len(seed) != EncapsulationSeedSize {
   440  		return nil, nil, kem.ErrSeedSize
   441  	}
   442  
   443  	ct = make([]byte, CiphertextSize)
   444  	ss = make([]byte, SharedKeySize)
   445  
   446  	pub, ok := pk.(*PublicKey)
   447  	if !ok {
   448  		return nil, nil, kem.ErrTypeMismatch
   449  	}
   450  	pub.EncapsulateTo(ct, ss, seed)
   451  	return
   452  }
   453  
   454  func (*scheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) {
   455  	if len(ct) != CiphertextSize {
   456  		return nil, kem.ErrCiphertextSize
   457  	}
   458  
   459  	priv, ok := sk.(*PrivateKey)
   460  	if !ok {
   461  		return nil, kem.ErrTypeMismatch
   462  	}
   463  	ss := make([]byte, SharedKeySize)
   464  	priv.DecapsulateTo(ss, ct)
   465  	return ss, nil
   466  }
   467  
   468  func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) {
   469  	var ret PublicKey
   470  	{{ if .NIST -}}
   471  	if err := ret.Unpack(buf); err != nil {
   472  		return nil, err
   473  	}
   474  	{{- else -}}
   475  	if len(buf) != PublicKeySize {
   476  		return nil, kem.ErrPubKeySize
   477  	}
   478  	ret.Unpack(buf)
   479  	{{- end }}
   480  	return &ret, nil
   481  }
   482  
   483  func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) {
   484  	if len(buf) != PrivateKeySize {
   485  		return nil, kem.ErrPrivKeySize
   486  	}
   487  	var ret PrivateKey
   488  	{{ if .NIST -}}
   489  	if err := ret.Unpack(buf); err != nil {
   490  		return nil, err
   491  	}
   492  	{{- else -}}
   493  	ret.Unpack(buf)
   494  	{{- end }}
   495  	return &ret, nil
   496  }