github.com/cloudflare/circl@v1.5.0/sign/mldsa/mldsa44/internal/dilithium.go (about)

     1  // Code generated from mode3/internal/dilithium.go by gen.go
     2  
     3  package internal
     4  
     5  import (
     6  	cryptoRand "crypto/rand"
     7  	"crypto/subtle"
     8  	"io"
     9  
    10  	"github.com/cloudflare/circl/internal/sha3"
    11  	common "github.com/cloudflare/circl/sign/internal/dilithium"
    12  )
    13  
    14  const (
    15  	// Size of a packed polynomial of norm ≤η.
    16  	// (Note that the  formula is not valid in general.)
    17  	PolyLeqEtaSize = (common.N * DoubleEtaBits) / 8
    18  
    19  	// β = τη, the maximum size of c s₂.
    20  	Beta = Tau * Eta
    21  
    22  	// γ₁ range of y
    23  	Gamma1 = 1 << Gamma1Bits
    24  
    25  	// Size of packed polynomial of norm <γ₁ such as z
    26  	PolyLeGamma1Size = (Gamma1Bits + 1) * common.N / 8
    27  
    28  	// α = 2γ₂ parameter for decompose
    29  	Alpha = 2 * Gamma2
    30  
    31  	// Size of a packed private key
    32  	PrivateKeySize = 32 + 32 + TRSize + PolyLeqEtaSize*(L+K) + common.PolyT0Size*K
    33  
    34  	// Size of a packed public key
    35  	PublicKeySize = 32 + common.PolyT1Size*K
    36  
    37  	// Size of a packed signature
    38  	SignatureSize = L*PolyLeGamma1Size + Omega + K + CTildeSize
    39  
    40  	// Size of packed w₁
    41  	PolyW1Size = (common.N * (common.QBits - Gamma1Bits)) / 8
    42  )
    43  
// PublicKey is the type of Dilithium public keys.
//
// Only rho and t1 carry key material. The remaining fields are caches,
// filled in by Unpack, NewKeyFromSeed or (*PrivateKey).Public.
type PublicKey struct {
	// ρ, the seed from which the matrix A is derived.
	rho [32]byte
	// t₁, the high part of t = A s₁ + s₂ (see computeT0andT1).
	t1  VecK

	// Cached values

	// t₁ in its packed (PackT1) form; this is pk[32:] on the wire.
	t1p [common.PolyT1Size * K]byte
	// ExpandA(ρ); may point into the PrivateKey this key was derived from.
	A   *Mat
	// tr = CRH(ρ ‖ t1) = CRH(packed public key); may point into a PrivateKey.
	tr  *[TRSize]byte
}
    54  
// PrivateKey is the type of Dilithium private keys.
//
// rho, key, tr, s1, s2 and t0 are the serialised fields (see Pack). The
// remaining fields are caches recomputed by Unpack and NewKeyFromSeed.
type PrivateKey struct {
	// ρ, the seed from which A is derived; equal to the public key's ρ.
	rho [32]byte
	// Secret seed absorbed into ρ' when signing (see SignTo).
	key [32]byte
	// Secret vectors s₁ and s₂, sampled with PolyDeriveUniformLeqEta.
	s1  VecL
	s2  VecK
	// t₀, the low part of t = A s₁ + s₂ (see computeT0andT1).
	t0  VecK
	// tr = CRH(packed public key), absorbed before the message to form μ.
	tr  [TRSize]byte

	// Cached values
	A   Mat  // ExpandA(ρ)
	s1h VecL // NTT(s₁)
	s2h VecK // NTT(s₂)
	t0h VecK // NTT(t₀)
}
    70  
// unpackedSignature holds a signature in decoded form. Pack writes it as
// c̃ ‖ z ‖ hint.
type unpackedSignature struct {
	// z = y + c·s₁, the response vector.
	z    VecL
	// Hint produced by MakeHint and consumed by UseHint during verification.
	hint VecK
	// c̃ = H(μ ‖ w₁), the commitment hash.
	c    [CTildeSize]byte
}
    76  
    77  // Packs the signature into buf.
    78  func (sig *unpackedSignature) Pack(buf []byte) {
    79  	copy(buf[:], sig.c[:])
    80  	sig.z.PackLeGamma1(buf[CTildeSize:])
    81  	sig.hint.PackHint(buf[CTildeSize+L*PolyLeGamma1Size:])
    82  }
    83  
    84  // Sets sig to the signature encoded in the buffer.
    85  //
    86  // Returns whether buf contains a properly packed signature.
    87  func (sig *unpackedSignature) Unpack(buf []byte) bool {
    88  	if len(buf) < SignatureSize {
    89  		return false
    90  	}
    91  	copy(sig.c[:], buf[:])
    92  	sig.z.UnpackLeGamma1(buf[CTildeSize:])
    93  	if sig.z.Exceeds(Gamma1 - Beta) {
    94  		return false
    95  	}
    96  	if !sig.hint.UnpackHint(buf[CTildeSize+L*PolyLeGamma1Size:]) {
    97  		return false
    98  	}
    99  	return true
   100  }
   101  
   102  // Packs the public key into buf.
   103  func (pk *PublicKey) Pack(buf *[PublicKeySize]byte) {
   104  	copy(buf[:32], pk.rho[:])
   105  	copy(buf[32:], pk.t1p[:])
   106  }
   107  
   108  // Sets pk to the public key encoded in buf.
   109  func (pk *PublicKey) Unpack(buf *[PublicKeySize]byte) {
   110  	copy(pk.rho[:], buf[:32])
   111  	copy(pk.t1p[:], buf[32:])
   112  
   113  	pk.t1.UnpackT1(pk.t1p[:])
   114  	pk.A = new(Mat)
   115  	pk.A.Derive(&pk.rho)
   116  
   117  	// tr = CRH(ρ ‖ t1) = CRH(pk)
   118  	pk.tr = new([TRSize]byte)
   119  	h := sha3.NewShake256()
   120  	_, _ = h.Write(buf[:])
   121  	_, _ = h.Read(pk.tr[:])
   122  }
   123  
   124  // Packs the private key into buf.
   125  func (sk *PrivateKey) Pack(buf *[PrivateKeySize]byte) {
   126  	copy(buf[:32], sk.rho[:])
   127  	copy(buf[32:64], sk.key[:])
   128  	copy(buf[64:64+TRSize], sk.tr[:])
   129  	offset := 64 + TRSize
   130  	sk.s1.PackLeqEta(buf[offset:])
   131  	offset += PolyLeqEtaSize * L
   132  	sk.s2.PackLeqEta(buf[offset:])
   133  	offset += PolyLeqEtaSize * K
   134  	sk.t0.PackT0(buf[offset:])
   135  }
   136  
   137  // Sets sk to the private key encoded in buf.
   138  func (sk *PrivateKey) Unpack(buf *[PrivateKeySize]byte) {
   139  	copy(sk.rho[:], buf[:32])
   140  	copy(sk.key[:], buf[32:64])
   141  	copy(sk.tr[:], buf[64:64+TRSize])
   142  	offset := 64 + TRSize
   143  	sk.s1.UnpackLeqEta(buf[offset:])
   144  	offset += PolyLeqEtaSize * L
   145  	sk.s2.UnpackLeqEta(buf[offset:])
   146  	offset += PolyLeqEtaSize * K
   147  	sk.t0.UnpackT0(buf[offset:])
   148  
   149  	// Cached values
   150  	sk.A.Derive(&sk.rho)
   151  	sk.t0h = sk.t0
   152  	sk.t0h.NTT()
   153  	sk.s1h = sk.s1
   154  	sk.s1h.NTT()
   155  	sk.s2h = sk.s2
   156  	sk.s2h.NTT()
   157  }
   158  
   159  // GenerateKey generates a public/private key pair using entropy from rand.
   160  // If rand is nil, crypto/rand.Reader will be used.
   161  func GenerateKey(rand io.Reader) (*PublicKey, *PrivateKey, error) {
   162  	var seed [32]byte
   163  	if rand == nil {
   164  		rand = cryptoRand.Reader
   165  	}
   166  	_, err := io.ReadFull(rand, seed[:])
   167  	if err != nil {
   168  		return nil, nil, err
   169  	}
   170  	pk, sk := NewKeyFromSeed(&seed)
   171  	return pk, sk, nil
   172  }
   173  
   174  // NewKeyFromSeed derives a public/private key pair using the given seed.
   175  func NewKeyFromSeed(seed *[common.SeedSize]byte) (*PublicKey, *PrivateKey) {
   176  	var eSeed [128]byte // expanded seed
   177  	var pk PublicKey
   178  	var sk PrivateKey
   179  	var sSeed [64]byte
   180  
   181  	h := sha3.NewShake256()
   182  	_, _ = h.Write(seed[:])
   183  
   184  	if NIST {
   185  		_, _ = h.Write([]byte{byte(K), byte(L)})
   186  	}
   187  
   188  	_, _ = h.Read(eSeed[:])
   189  
   190  	copy(pk.rho[:], eSeed[:32])
   191  	copy(sSeed[:], eSeed[32:96])
   192  	copy(sk.key[:], eSeed[96:])
   193  	copy(sk.rho[:], pk.rho[:])
   194  
   195  	sk.A.Derive(&pk.rho)
   196  
   197  	for i := uint16(0); i < L; i++ {
   198  		PolyDeriveUniformLeqEta(&sk.s1[i], &sSeed, i)
   199  	}
   200  
   201  	for i := uint16(0); i < K; i++ {
   202  		PolyDeriveUniformLeqEta(&sk.s2[i], &sSeed, i+L)
   203  	}
   204  
   205  	sk.s1h = sk.s1
   206  	sk.s1h.NTT()
   207  	sk.s2h = sk.s2
   208  	sk.s2h.NTT()
   209  
   210  	sk.computeT0andT1(&sk.t0, &pk.t1)
   211  
   212  	sk.t0h = sk.t0
   213  	sk.t0h.NTT()
   214  
   215  	// Complete public key far enough to be packed
   216  	pk.t1.PackT1(pk.t1p[:])
   217  	pk.A = &sk.A
   218  
   219  	// Finish private key
   220  	var packedPk [PublicKeySize]byte
   221  	pk.Pack(&packedPk)
   222  
   223  	// tr = CRH(ρ ‖ t1) = CRH(pk)
   224  	h.Reset()
   225  	_, _ = h.Write(packedPk[:])
   226  	_, _ = h.Read(sk.tr[:])
   227  
   228  	// Finish cache of public key
   229  	pk.tr = &sk.tr
   230  
   231  	return &pk, &sk
   232  }
   233  
   234  // Computes t0 and t1 from sk.s1h, sk.s2 and sk.A.
   235  func (sk *PrivateKey) computeT0andT1(t0, t1 *VecK) {
   236  	var t VecK
   237  
   238  	// Set t to A s₁ + s₂
   239  	for i := 0; i < K; i++ {
   240  		PolyDotHat(&t[i], &sk.A[i], &sk.s1h)
   241  		t[i].ReduceLe2Q()
   242  		t[i].InvNTT()
   243  	}
   244  	t.Add(&t, &sk.s2)
   245  	t.Normalize()
   246  
   247  	// Compute t₀, t₁ = Power2Round(t)
   248  	t.Power2Round(t0, t1)
   249  }
   250  
// Verify checks whether the given signature by pk on msg is valid.
//
// For Dilithium this is the top-level verification function.
// In ML-DSA, this is ML-DSA.Verify_internal.
//
// msg is called once to stream the message into the hash computing μ.
// Verify returns false if signature cannot be unpacked; otherwise it
// recomputes c' = H(μ ‖ w₁) and reports whether it equals the c̃ carried
// in the signature.
func Verify(pk *PublicKey, msg func(io.Writer), signature []byte) bool {
	var sig unpackedSignature
	var mu [64]byte
	var zh VecL
	var Az, Az2dct1, w1 VecK
	var ch common.Poly
	var cp [CTildeSize]byte
	var w1Packed [PolyW1Size * K]byte

	// Note that Unpack() checked whether ‖z‖_∞ < γ₁ - β
	// and ensured that there are at most ω ones in sig.hint.
	if !sig.Unpack(signature) {
		return false
	}

	// μ = CRH(tr ‖ msg)
	h := sha3.NewShake256()
	_, _ = h.Write(pk.tr[:])
	msg(&h)
	_, _ = h.Read(mu[:])

	// Compute Az (kept in the NTT domain)
	zh = sig.z
	zh.NTT()

	for i := 0; i < K; i++ {
		PolyDotHat(&Az[i], &pk.A[i], &zh)
	}

	// Next, we compute Az - 2ᵈ·c·t₁.
	// The coefficients of 2ᵈ·t₁ must be small enough for NTT(). Assuming the
	// usual d = 13 with t₁ coefficients below 2¹⁰, they are below
	// 2¹⁰⁺ᵈ = 2²³ < 2q. NOTE(review): the original comment gave the t₁ bound
	// as "256 = 2⁹", which is self-inconsistent — confirm against common.D
	// and PolyT1Size.
	Az2dct1.MulBy2toD(&pk.t1)
	Az2dct1.NTT()
	PolyDeriveUniformBall(&ch, sig.c[:])
	ch.NTT()
	for i := 0; i < K; i++ {
		Az2dct1[i].MulHat(&Az2dct1[i], &ch)
	}
	Az2dct1.Sub(&Az, &Az2dct1)
	Az2dct1.ReduceLe2Q()
	Az2dct1.InvNTT()
	Az2dct1.NormalizeAssumingLe2Q()

	// UseHint(sig.hint, Az - 2ᵈ·c·t₁)
	//    = UseHint(sig.hint, w - c·s₂ + c·t₀)
	//    = UseHint(sig.hint, r + c·t₀)
	//    = r₁ = w₁.
	w1.UseHint(&Az2dct1, &sig.hint)
	w1.PackW1(w1Packed[:])

	// c' = H(μ, w₁)
	h.Reset()
	_, _ = h.Write(mu[:])
	_, _ = h.Write(w1Packed[:])
	_, _ = h.Read(cp[:])

	// c̃ is public data, so a plain (non-constant-time) comparison is used.
	return sig.c == cp
}
   315  
   316  // SignTo signs the given message and writes the signature into signature.
   317  //
   318  // For Dilithium this is the top-level signing function. For ML-DSA
   319  // this is ML-DSA.Sign_internal.
   320  //
   321  //nolint:funlen
   322  func SignTo(sk *PrivateKey, msg func(io.Writer), rnd [32]byte, signature []byte) {
   323  	var mu, rhop [64]byte
   324  	var w1Packed [PolyW1Size * K]byte
   325  	var y, yh VecL
   326  	var w, w0, w1, w0mcs2, ct0, w0mcs2pct0 VecK
   327  	var ch common.Poly
   328  	var yNonce uint16
   329  	var sig unpackedSignature
   330  
   331  	if len(signature) < SignatureSize {
   332  		panic("Signature does not fit in that byteslice")
   333  	}
   334  
   335  	//  μ = CRH(tr ‖ msg)
   336  	h := sha3.NewShake256()
   337  	_, _ = h.Write(sk.tr[:])
   338  	msg(&h)
   339  	_, _ = h.Read(mu[:])
   340  
   341  	// ρ' = CRH(key ‖ μ)
   342  	h.Reset()
   343  	_, _ = h.Write(sk.key[:])
   344  	if NIST {
   345  		_, _ = h.Write(rnd[:])
   346  	}
   347  	_, _ = h.Write(mu[:])
   348  	_, _ = h.Read(rhop[:])
   349  
   350  	// Main rejection loop
   351  	attempt := 0
   352  	for {
   353  		attempt++
   354  		if attempt >= 576 {
   355  			// Depending on the mode, one try has a chance between 1/7 and 1/4
   356  			// of succeeding.  Thus it is safe to say that 576 iterations
   357  			// are enough as (6/7)⁵⁷⁶ < 2⁻¹²⁸.
   358  			panic("This should only happen 1 in  2^{128}: something is wrong.")
   359  		}
   360  
   361  		// y = ExpandMask(ρ', key)
   362  		VecLDeriveUniformLeGamma1(&y, &rhop, yNonce)
   363  		yNonce += uint16(L)
   364  
   365  		// Set w to A y
   366  		yh = y
   367  		yh.NTT()
   368  		for i := 0; i < K; i++ {
   369  			PolyDotHat(&w[i], &sk.A[i], &yh)
   370  			w[i].ReduceLe2Q()
   371  			w[i].InvNTT()
   372  		}
   373  
   374  		// Decompose w into w₀ and w₁
   375  		w.NormalizeAssumingLe2Q()
   376  		w.Decompose(&w0, &w1)
   377  
   378  		// c~ = H(μ ‖ w₁)
   379  		w1.PackW1(w1Packed[:])
   380  		h.Reset()
   381  		_, _ = h.Write(mu[:])
   382  		_, _ = h.Write(w1Packed[:])
   383  		_, _ = h.Read(sig.c[:])
   384  
   385  		PolyDeriveUniformBall(&ch, sig.c[:])
   386  		ch.NTT()
   387  
   388  		// Ensure ‖ w₀ - c·s2 ‖_∞ < γ₂ - β.
   389  		//
   390  		// By Lemma 3 of the specification this is equivalent to checking that
   391  		// both ‖ r₀ ‖_∞ < γ₂ - β and r₁ = w₁, for the decomposition
   392  		// w - c·s₂	 = r₁ α + r₀ as computed by decompose().
   393  		// See also §4.1 of the specification.
   394  		for i := 0; i < K; i++ {
   395  			w0mcs2[i].MulHat(&ch, &sk.s2h[i])
   396  			w0mcs2[i].InvNTT()
   397  		}
   398  		w0mcs2.Sub(&w0, &w0mcs2)
   399  		w0mcs2.Normalize()
   400  
   401  		if w0mcs2.Exceeds(Gamma2 - Beta) {
   402  			continue
   403  		}
   404  
   405  		// z = y + c·s₁
   406  		for i := 0; i < L; i++ {
   407  			sig.z[i].MulHat(&ch, &sk.s1h[i])
   408  			sig.z[i].InvNTT()
   409  		}
   410  		sig.z.Add(&sig.z, &y)
   411  		sig.z.Normalize()
   412  
   413  		// Ensure  ‖z‖_∞ < γ₁ - β
   414  		if sig.z.Exceeds(Gamma1 - Beta) {
   415  			continue
   416  		}
   417  
   418  		// Compute c·t₀
   419  		for i := 0; i < K; i++ {
   420  			ct0[i].MulHat(&ch, &sk.t0h[i])
   421  			ct0[i].InvNTT()
   422  		}
   423  		ct0.NormalizeAssumingLe2Q()
   424  
   425  		// Ensure ‖c·t₀‖_∞ < γ₂.
   426  		if ct0.Exceeds(Gamma2) {
   427  			continue
   428  		}
   429  
   430  		// Create the hint to be able to reconstruct w₁ from w - c·s₂ + c·t0.
   431  		// Note that we're not using makeHint() in the obvious way as we
   432  		// do not know whether ‖ sc·s₂ - c·t₀ ‖_∞ < γ₂.  Instead we note
   433  		// that our makeHint() is actually the same as a makeHint for a
   434  		// different decomposition:
   435  		//
   436  		// Earlier we ensured indirectly with a check that r₁ = w₁ where
   437  		// r = w - c·s₂.  Hence r₀ = r - r₁ α = w - c·s₂ - w₁ α = w₀ - c·s₂.
   438  		// Thus  MakeHint(w₀ - c·s₂ + c·t₀, w₁) = MakeHint(r0 + c·t₀, r₁)
   439  		// and UseHint(w - c·s₂ + c·t₀, w₁) = UseHint(r + c·t₀, r₁).
   440  		// As we just ensured that ‖ c·t₀ ‖_∞ < γ₂ our usage is correct.
   441  		w0mcs2pct0.Add(&w0mcs2, &ct0)
   442  		w0mcs2pct0.NormalizeAssumingLe2Q()
   443  		hintPop := sig.hint.MakeHint(&w0mcs2pct0, &w1)
   444  		if hintPop > Omega {
   445  			continue
   446  		}
   447  
   448  		break
   449  	}
   450  
   451  	sig.Pack(signature[:])
   452  }
   453  
   454  // Computes the public key corresponding to this private key.
   455  func (sk *PrivateKey) Public() *PublicKey {
   456  	var t0 VecK
   457  	pk := &PublicKey{
   458  		rho: sk.rho,
   459  		A:   &sk.A,
   460  		tr:  &sk.tr,
   461  	}
   462  	sk.computeT0andT1(&t0, &pk.t1)
   463  	pk.t1.PackT1(pk.t1p[:])
   464  	return pk
   465  }
   466  
   467  // Equal returns whether the two public keys are equal
   468  func (pk *PublicKey) Equal(other *PublicKey) bool {
   469  	return pk.rho == other.rho && pk.t1 == other.t1
   470  }
   471  
   472  // Equal returns whether the two private keys are equal
   473  func (sk *PrivateKey) Equal(other *PrivateKey) bool {
   474  	ret := (subtle.ConstantTimeCompare(sk.rho[:], other.rho[:]) &
   475  		subtle.ConstantTimeCompare(sk.key[:], other.key[:]) &
   476  		subtle.ConstantTimeCompare(sk.tr[:], other.tr[:]))
   477  
   478  	acc := uint32(0)
   479  	for i := 0; i < L; i++ {
   480  		for j := 0; j < common.N; j++ {
   481  			acc |= sk.s1[i][j] ^ other.s1[i][j]
   482  		}
   483  	}
   484  	for i := 0; i < K; i++ {
   485  		for j := 0; j < common.N; j++ {
   486  			acc |= sk.s2[i][j] ^ other.s2[i][j]
   487  			acc |= sk.t0[i][j] ^ other.t0[i][j]
   488  		}
   489  	}
   490  	return (ret & subtle.ConstantTimeEq(int32(acc), 0)) == 1
   491  }