github.com/cloudflare/circl@v1.5.0/pke/kyber/kyber512/internal/cpapke.go (about)

     1  package internal
     2  
     3  import (
     4  	"bytes"
     5  
     6  	"github.com/cloudflare/circl/internal/sha3"
     7  	"github.com/cloudflare/circl/kem"
     8  	"github.com/cloudflare/circl/pke/kyber/internal/common"
     9  )
    10  
// A Kyber.CPAPKE private key.
//
// The only component is the secret vector s, stored in the NTT domain
// (ŝ = NTT(s)) with normalized coefficients. DecryptTo feeds it straight
// into PolyDotHat, so it is never converted back out of the NTT domain.
type PrivateKey struct {
	sh Vec // NTT(s), normalized
}
    15  
// A Kyber.CPAPKE public key.
//
// rho and th are the serialized components; Pack writes th first and rho
// after it. aT is not serialized. It is rebuilt from rho by Unpack, and
// computed directly by NewKeyFromSeed, so EncryptTo can use it without
// re-expanding the matrix on each call.
type PublicKey struct {
	rho [32]byte // ρ, the seed for the matrix A
	th  Vec      // NTT(t), normalized

	// cached values
	aT Mat // the matrix Aᵀ, derived from rho
}
    24  
    25  // Packs the private key to buf.
    26  func (sk *PrivateKey) Pack(buf []byte) {
    27  	sk.sh.Pack(buf)
    28  }
    29  
    30  // Unpacks the private key from buf.
    31  func (sk *PrivateKey) Unpack(buf []byte) {
    32  	sk.sh.Unpack(buf)
    33  	sk.sh.Normalize()
    34  }
    35  
    36  // Packs the public key to buf.
    37  func (pk *PublicKey) Pack(buf []byte) {
    38  	pk.th.Pack(buf)
    39  	copy(buf[K*common.PolySize:], pk.rho[:])
    40  }
    41  
    42  // Unpacks the public key from buf. Checks if the public key is normalized.
    43  func (pk *PublicKey) UnpackMLKEM(buf []byte) error {
    44  	pk.Unpack(buf)
    45  
    46  	// FIPS 203 §7.2 "encapsulation key check" (2).
    47  	var buf2 [K * common.PolySize]byte
    48  	pk.th.Pack(buf2[:])
    49  	if !bytes.Equal(buf[:len(buf2)], buf2[:]) {
    50  		return kem.ErrPubKey
    51  	}
    52  	return nil
    53  }
    54  
    55  // Unpacks the public key from buf.
    56  func (pk *PublicKey) Unpack(buf []byte) {
    57  	pk.th.Unpack(buf)
    58  	pk.th.Normalize()
    59  	copy(pk.rho[:], buf[K*common.PolySize:])
    60  	pk.aT.Derive(&pk.rho, true)
    61  }
    62  
    63  // Derives a new Kyber.CPAPKE keypair from the given seed.
    64  func NewKeyFromSeed(seed []byte) (*PublicKey, *PrivateKey) {
    65  	var pk PublicKey
    66  	var sk PrivateKey
    67  
    68  	var expandedSeed [64]byte
    69  
    70  	h := sha3.New512()
    71  	_, _ = h.Write(seed)
    72  
    73  	// This writes hash into expandedSeed.  Yes, this is idiomatic Go.
    74  	_, _ = h.Read(expandedSeed[:])
    75  
    76  	copy(pk.rho[:], expandedSeed[:32])
    77  	sigma := expandedSeed[32:] // σ, the noise seed
    78  
    79  	pk.aT.Derive(&pk.rho, false) // Expand ρ to matrix A; we'll transpose later
    80  
    81  	var eh Vec
    82  	sk.sh.DeriveNoise(sigma, 0, Eta1) // Sample secret vector s
    83  	sk.sh.NTT()
    84  	sk.sh.Normalize()
    85  
    86  	eh.DeriveNoise(sigma, K, Eta1) // Sample blind e
    87  	eh.NTT()
    88  
    89  	// Next, we compute t = A s + e.
    90  	for i := 0; i < K; i++ {
    91  		// Note that coefficients of s are bounded by q and those of A
    92  		// are bounded by 4.5q and so their product is bounded by 2¹⁵q
    93  		// as required for multiplication.
    94  		PolyDotHat(&pk.th[i], &pk.aT[i], &sk.sh)
    95  
    96  		// A and s were not in Montgomery form, so the Montgomery
    97  		// multiplications in the inner product added a factor R⁻¹ which
    98  		// we'll cancel out now.  This will also ensure the coefficients of
    99  		// t are bounded in absolute value by q.
   100  		pk.th[i].ToMont()
   101  	}
   102  
   103  	pk.th.Add(&pk.th, &eh) // bounded by 8q.
   104  	pk.th.Normalize()
   105  	pk.aT.Transpose()
   106  
   107  	return &pk, &sk
   108  }
   109  
   110  // Decrypts ciphertext ct meant for private key sk to plaintext pt.
   111  func (sk *PrivateKey) DecryptTo(pt, ct []byte) {
   112  	var u Vec
   113  	var v, m common.Poly
   114  
   115  	u.Decompress(ct, DU)
   116  	v.Decompress(ct[K*compressedPolySize(DU):], DV)
   117  
   118  	// Compute m = v - <s, u>
   119  	u.NTT()
   120  	PolyDotHat(&m, &sk.sh, &u)
   121  	m.BarrettReduce()
   122  	m.InvNTT()
   123  	m.Sub(&v, &m)
   124  	m.Normalize()
   125  
   126  	// Compress polynomial m to original message
   127  	m.CompressMessageTo(pt)
   128  }
   129  
   130  // Encrypts message pt for the public key to ciphertext ct using randomness
   131  // from seed.
   132  //
   133  // seed has to be of length SeedSize, pt of PlaintextSize and ct of
   134  // CiphertextSize.
   135  func (pk *PublicKey) EncryptTo(ct, pt, seed []byte) {
   136  	var rh, e1, u Vec
   137  	var e2, v, m common.Poly
   138  
   139  	// Sample r, e₁ and e₂ from B_η
   140  	rh.DeriveNoise(seed, 0, Eta1)
   141  	rh.NTT()
   142  	rh.BarrettReduce()
   143  
   144  	e1.DeriveNoise(seed, K, common.Eta2)
   145  	e2.DeriveNoise(seed, 2*K, common.Eta2)
   146  
   147  	// Next we compute u = Aᵀ r + e₁.  First Aᵀ.
   148  	for i := 0; i < K; i++ {
   149  		// Note that coefficients of r are bounded by q and those of Aᵀ
   150  		// are bounded by 4.5q and so their product is bounded by 2¹⁵q
   151  		// as required for multiplication.
   152  		PolyDotHat(&u[i], &pk.aT[i], &rh)
   153  	}
   154  
   155  	u.BarrettReduce()
   156  
   157  	// Aᵀ and r were not in Montgomery form, so the Montgomery
   158  	// multiplications in the inner product added a factor R⁻¹ which
   159  	// the InvNTT cancels out.
   160  	u.InvNTT()
   161  
   162  	u.Add(&u, &e1) // u = Aᵀ r + e₁
   163  
   164  	// Next compute v = <t, r> + e₂ + Decompress_q(m, 1).
   165  	PolyDotHat(&v, &pk.th, &rh)
   166  	v.BarrettReduce()
   167  	v.InvNTT()
   168  
   169  	m.DecompressMessage(pt)
   170  	v.Add(&v, &m)
   171  	v.Add(&v, &e2) // v = <t, r> + e₂ + Decompress_q(m, 1)
   172  
   173  	// Pack ciphertext
   174  	u.Normalize()
   175  	v.Normalize()
   176  
   177  	u.CompressTo(ct, DU)
   178  	v.CompressTo(ct[K*compressedPolySize(DU):], DV)
   179  }
   180  
   181  // Returns whether sk equals other.
   182  func (sk *PrivateKey) Equal(other *PrivateKey) bool {
   183  	ret := int16(0)
   184  	for i := 0; i < K; i++ {
   185  		for j := 0; j < common.N; j++ {
   186  			ret |= sk.sh[i][j] ^ other.sh[i][j]
   187  		}
   188  	}
   189  	return ret == 0
   190  }