github.com/cloudflare/circl@v1.5.0/pke/kyber/kyber1024/internal/cpapke.go (about)

     1  // Code generated from kyber512/internal/cpapke.go by gen.go
     2  
     3  package internal
     4  
     5  import (
     6  	"bytes"
     7  
     8  	"github.com/cloudflare/circl/internal/sha3"
     9  	"github.com/cloudflare/circl/kem"
    10  	"github.com/cloudflare/circl/pke/kyber/internal/common"
    11  )
    12  
    13  // PrivateKey is a Kyber.CPAPKE private key.
        //
        // Only the secret vector s is stored, kept in the NTT domain with
        // normalized coefficients.
    14  type PrivateKey struct {
    15  	sh Vec // NTT(s), normalized
    16  }
    17  
    18  // PublicKey is a Kyber.CPAPKE public key.
        //
        // Its packed form is NTT(t) followed by the 32-byte seed ρ. The matrix
        // Aᵀ is not part of the packed form: it is expanded from ρ when the key
        // is generated or unpacked, and cached in aT.
    19  type PublicKey struct {
    20  	rho [32]byte // ρ, the seed for the matrix A
    21  	th  Vec      // NTT(t), normalized
    22
    23  	// cached values
    24  	aT Mat // the matrix Aᵀ, derived from rho
    25  }
    26  
    27  // Packs the private key to buf.
    28  func (sk *PrivateKey) Pack(buf []byte) {
    29  	sk.sh.Pack(buf)
    30  }
    31  
    32  // Unpacks the private key from buf.
    33  func (sk *PrivateKey) Unpack(buf []byte) {
    34  	sk.sh.Unpack(buf)
    35  	sk.sh.Normalize()
    36  }
    37  
    38  // Packs the public key to buf.
    39  func (pk *PublicKey) Pack(buf []byte) {
    40  	pk.th.Pack(buf)
    41  	copy(buf[K*common.PolySize:], pk.rho[:])
    42  }
    43  
    44  // Unpacks the public key from buf. Checks if the public key is normalized.
    45  func (pk *PublicKey) UnpackMLKEM(buf []byte) error {
    46  	pk.Unpack(buf)
    47  
    48  	// FIPS 203 §7.2 "encapsulation key check" (2).
    49  	var buf2 [K * common.PolySize]byte
    50  	pk.th.Pack(buf2[:])
    51  	if !bytes.Equal(buf[:len(buf2)], buf2[:]) {
    52  		return kem.ErrPubKey
    53  	}
    54  	return nil
    55  }
    56  
    57  // Unpacks the public key from buf.
    58  func (pk *PublicKey) Unpack(buf []byte) {
    59  	pk.th.Unpack(buf)
    60  	pk.th.Normalize()
    61  	copy(pk.rho[:], buf[K*common.PolySize:])
    62  	pk.aT.Derive(&pk.rho, true)
    63  }
    64  
    65  // NewKeyFromSeed derives a new Kyber.CPAPKE keypair from the given seed.
        //
        // seed is hashed in full with SHA3-512; no length check is made here.
        // The first 32 bytes of the digest become the matrix seed ρ and the last
        // 32 bytes become the noise seed σ. The returned public key holds ρ, the
        // normalized NTT(t) and the cached matrix Aᵀ. The private key holds the
        // normalized NTT(s).
    66  func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) {
    67  	var pk PublicKey
    68  	var sk PrivateKey
    69
    70  	var expandedSeed [64]byte
    71
    72  	h := sha3.New512()
    73  	_, _ = h.Write(seed)
    74
    75  	// Read writes the SHA3-512 digest of seed into expandedSeed.
        	// NOTE(review): the Write/Read errors are discarded; presumably this
        	// in-memory hash state never reports one — confirm.
    76  	_, _ = h.Read(expandedSeed[:])
    77
    78  	copy(pk.rho[:], expandedSeed[:32])
    79  	sigma := expandedSeed[32:] // σ, the noise seed
    80
    81  	pk.aT.Derive(&pk.rho, false) // Expand ρ to matrix A; we'll transpose later
    82
    83  	var eh Vec
    84  	sk.sh.DeriveNoise(sigma, 0, Eta1) // Sample secret vector s
    85  	sk.sh.NTT()
    86  	sk.sh.Normalize()
    87
        	// The second DeriveNoise argument is presumably the starting nonce.
        	// s used 0, so e starts at K to avoid reusing s's samples — confirm.
    88  	eh.DeriveNoise(sigma, K, Eta1) // Sample noise (error) vector e
    89  	eh.NTT()
    90
    91  	// Next, we compute t = A s + e.
    92  	for i := 0; i < K; i++ {
    93  		// Note that coefficients of s are bounded by q and those of A
    94  		// are bounded by 4.5q and so their product is bounded by 2¹⁵q
    95  		// as required for multiplication.
    96  		PolyDotHat(&pk.th[i], &pk.aT[i], &sk.sh)
    97
    98  		// A and s were not in Montgomery form, so the Montgomery
    99  		// multiplications in the inner product added a factor R⁻¹ which
   100  		// we'll cancel out now.  This will also ensure the coefficients of
   101  		// t are bounded in absolute value by q.
   102  		pk.th[i].ToMont()
   103  	}
   104
   105  	pk.th.Add(&pk.th, &eh) // bounded by 8q.
   106  	pk.th.Normalize()
        	// aT was derived untransposed (row i of A was used above). Transpose it
        	// now so that the cached aT holds Aᵀ.
   107  	pk.aT.Transpose()
   108
   109  	return &pk, &sk
   110  }
   111  
   112  // Decrypts ciphertext ct meant for private key sk to plaintext pt.
   113  func (sk *PrivateKey) DecryptTo(pt, ct []byte) {
   114  	var u Vec
   115  	var v, m common.Poly
   116  
   117  	u.Decompress(ct, DU)
   118  	v.Decompress(ct[K*compressedPolySize(DU):], DV)
   119  
   120  	// Compute m = v - <s, u>
   121  	u.NTT()
   122  	PolyDotHat(&m, &sk.sh, &u)
   123  	m.BarrettReduce()
   124  	m.InvNTT()
   125  	m.Sub(&v, &m)
   126  	m.Normalize()
   127  
   128  	// Compress polynomial m to original message
   129  	m.CompressMessageTo(pt)
   130  }
   131  
   132  // EncryptTo encrypts message pt for the public key to ciphertext ct using
   133  // randomness from seed.
   134  //
   135  // seed has to be of length SeedSize, pt of PlaintextSize and ct of
   136  // CiphertextSize. ct receives Compress(u, DU) for the K components of u,
        // followed by Compress(v, DV).
   137  func (pk *PublicKey) EncryptTo(ct, pt, seed []byte) {
   138  	var rh, e1, u Vec
   139  	var e2, v, m common.Poly
   140
   141  	// Sample r, e₁ and e₂ from B_η. Only r is moved to the NTT domain;
        	// e₁ and e₂ are added after the inverse NTTs below.
   142  	rh.DeriveNoise(seed, 0, Eta1)
   143  	rh.NTT()
   144  	rh.BarrettReduce()
   145
   146  	e1.DeriveNoise(seed, K, common.Eta2)
   147  	e2.DeriveNoise(seed, 2*K, common.Eta2)
   148
   149  	// Next we compute u = Aᵀ r + e₁, starting with the NTT-domain
        	// products of the rows of Aᵀ with r.
   150  	for i := 0; i < K; i++ {
   151  		// Note that coefficients of r are bounded by q and those of Aᵀ
   152  		// are bounded by 4.5q and so their product is bounded by 2¹⁵q
   153  		// as required for multiplication.
   154  		PolyDotHat(&u[i], &pk.aT[i], &rh)
   155  	}
   156
   157  	u.BarrettReduce()
   158
   159  	// Aᵀ and r were not in Montgomery form, so the Montgomery
   160  	// multiplications in the inner product added a factor R⁻¹ which
   161  	// the InvNTT cancels out.
   162  	u.InvNTT()
   163
   164  	u.Add(&u, &e1) // u = Aᵀ r + e₁
   165
   166  	// Next compute v = <t, r> + e₂ + Decompress_q(m, 1).
   167  	PolyDotHat(&v, &pk.th, &rh)
   168  	v.BarrettReduce()
   169  	v.InvNTT()
   170
   171  	m.DecompressMessage(pt)
   172  	v.Add(&v, &m)
   173  	v.Add(&v, &e2) // v = <t, r> + e₂ + Decompress_q(m, 1)
   174
   175  	// Pack ciphertext: normalize u and v, then compress them back to
        	// back into ct.
   176  	u.Normalize()
   177  	v.Normalize()
   178
   179  	u.CompressTo(ct, DU)
   180  	v.CompressTo(ct[K*compressedPolySize(DU):], DV)
   181  }
   182  
   183  // Returns whether sk equals other.
   184  func (sk *PrivateKey) Equal(other *PrivateKey) bool {
   185  	ret := int16(0)
   186  	for i := 0; i < K; i++ {
   187  		for j := 0; j < common.N; j++ {
   188  			ret |= sk.sh[i][j] ^ other.sh[i][j]
   189  		}
   190  	}
   191  	return ret == 0
   192  }