github.com/cloudflare/circl@v1.5.0/sign/dilithium/mode3/internal/dilithium.go (about)

     1  package internal
     2  
     3  import (
     4  	cryptoRand "crypto/rand"
     5  	"crypto/subtle"
     6  	"io"
     7  
     8  	"github.com/cloudflare/circl/internal/sha3"
     9  	common "github.com/cloudflare/circl/sign/internal/dilithium"
    10  )
    11  
// Derived parameters and packed sizes for this mode.  Sizes are in bytes.
// Everything here is computed from constants declared elsewhere in this
// package (K, L, Eta, Tau, Omega, Gamma1Bits, Gamma2, TRSize, CTildeSize,
// DoubleEtaBits) and from the shared constants of the common package.
const (
	// Size of a packed polynomial of norm ≤η.
	// (Note that the formula is not valid in general.)
	PolyLeqEtaSize = (common.N * DoubleEtaBits) / 8

	// β = τη, the maximum size of c s₂.
	Beta = Tau * Eta

	// γ₁ range of y
	Gamma1 = 1 << Gamma1Bits

	// Size of packed polynomial of norm <γ₁ such as z
	PolyLeGamma1Size = (Gamma1Bits + 1) * common.N / 8

	// α = 2γ₂ parameter for decompose
	Alpha = 2 * Gamma2

	// Size of a packed private key.
	// Layout, as written by PrivateKey.Pack:
	// ρ (32) ‖ key (32) ‖ tr (TRSize) ‖ s₁ (L polys) ‖ s₂ (K polys) ‖ t₀ (K polys).
	PrivateKeySize = 32 + 32 + TRSize + PolyLeqEtaSize*(L+K) + common.PolyT0Size*K

	// Size of a packed public key.
	// Layout, as written by PublicKey.Pack: ρ (32) ‖ packed t₁ (K polys).
	PublicKeySize = 32 + common.PolyT1Size*K

	// Size of a packed signature.
	// Layout, as written by unpackedSignature.Pack:
	// c̃ (CTildeSize) ‖ z (L polys) ‖ hint (Omega + K).
	SignatureSize = L*PolyLeGamma1Size + Omega + K + CTildeSize

	// Size of packed w₁
	PolyW1Size = (common.N * (common.QBits - Gamma1Bits)) / 8
)
    41  
// PublicKey is the type of Dilithium public keys.
//
// Note that A and tr are pointers: keys built by NewKeyFromSeed or
// PrivateKey.Public point into the originating private key rather than
// holding their own copies.
type PublicKey struct {
	rho [32]byte // seed ρ from which the matrix A is derived
	t1  VecK     // unpacked t₁

	// Cached values
	t1p [common.PolyT1Size * K]byte // packed form of t1, as emitted by Pack
	A   *Mat                        // matrix derived from rho
	tr  *[TRSize]byte               // SHAKE-256 hash of the packed public key
}
    52  
// PrivateKey is the type of Dilithium private keys.
//
// The first group of fields is what Pack serializes.  The second group is
// recomputed from the first by Unpack and NewKeyFromSeed and is read by
// SignTo and computeT0andT1.
type PrivateKey struct {
	rho [32]byte     // seed ρ from which the matrix A is derived
	key [32]byte     // secret bytes hashed into ρ' when signing
	s1  VecL         // secret vector s₁
	s2  VecK         // secret vector s₂
	t0  VecK         // t₀ half of Power2Round(A s₁ + s₂)
	tr  [TRSize]byte // hash of the packed public key (see NewKeyFromSeed)

	// Cached values
	A   Mat  // ExpandA(ρ)
	s1h VecL // NTT(s₁)
	s2h VecK // NTT(s₂)
	t0h VecK // NTT(t₀)
}
    68  
// unpackedSignature holds the three components of a signature in unpacked
// form.  Pack and Unpack convert to and from the wire layout c̃ ‖ z ‖ hint.
type unpackedSignature struct {
	z    VecL             // response vector z
	hint VecK             // hint vector handed to UseHint during verification
	c    [CTildeSize]byte // challenge hash c̃
}
    74  
    75  // Packs the signature into buf.
    76  func (sig *unpackedSignature) Pack(buf []byte) {
    77  	copy(buf[:], sig.c[:])
    78  	sig.z.PackLeGamma1(buf[CTildeSize:])
    79  	sig.hint.PackHint(buf[CTildeSize+L*PolyLeGamma1Size:])
    80  }
    81  
    82  // Sets sig to the signature encoded in the buffer.
    83  //
    84  // Returns whether buf contains a properly packed signature.
    85  func (sig *unpackedSignature) Unpack(buf []byte) bool {
    86  	if len(buf) < SignatureSize {
    87  		return false
    88  	}
    89  	copy(sig.c[:], buf[:])
    90  	sig.z.UnpackLeGamma1(buf[CTildeSize:])
    91  	if sig.z.Exceeds(Gamma1 - Beta) {
    92  		return false
    93  	}
    94  	if !sig.hint.UnpackHint(buf[CTildeSize+L*PolyLeGamma1Size:]) {
    95  		return false
    96  	}
    97  	return true
    98  }
    99  
   100  // Packs the public key into buf.
   101  func (pk *PublicKey) Pack(buf *[PublicKeySize]byte) {
   102  	copy(buf[:32], pk.rho[:])
   103  	copy(buf[32:], pk.t1p[:])
   104  }
   105  
   106  // Sets pk to the public key encoded in buf.
   107  func (pk *PublicKey) Unpack(buf *[PublicKeySize]byte) {
   108  	copy(pk.rho[:], buf[:32])
   109  	copy(pk.t1p[:], buf[32:])
   110  
   111  	pk.t1.UnpackT1(pk.t1p[:])
   112  	pk.A = new(Mat)
   113  	pk.A.Derive(&pk.rho)
   114  
   115  	// tr = CRH(ρ ‖ t1) = CRH(pk)
   116  	pk.tr = new([TRSize]byte)
   117  	h := sha3.NewShake256()
   118  	_, _ = h.Write(buf[:])
   119  	_, _ = h.Read(pk.tr[:])
   120  }
   121  
   122  // Packs the private key into buf.
   123  func (sk *PrivateKey) Pack(buf *[PrivateKeySize]byte) {
   124  	copy(buf[:32], sk.rho[:])
   125  	copy(buf[32:64], sk.key[:])
   126  	copy(buf[64:64+TRSize], sk.tr[:])
   127  	offset := 64 + TRSize
   128  	sk.s1.PackLeqEta(buf[offset:])
   129  	offset += PolyLeqEtaSize * L
   130  	sk.s2.PackLeqEta(buf[offset:])
   131  	offset += PolyLeqEtaSize * K
   132  	sk.t0.PackT0(buf[offset:])
   133  }
   134  
   135  // Sets sk to the private key encoded in buf.
   136  func (sk *PrivateKey) Unpack(buf *[PrivateKeySize]byte) {
   137  	copy(sk.rho[:], buf[:32])
   138  	copy(sk.key[:], buf[32:64])
   139  	copy(sk.tr[:], buf[64:64+TRSize])
   140  	offset := 64 + TRSize
   141  	sk.s1.UnpackLeqEta(buf[offset:])
   142  	offset += PolyLeqEtaSize * L
   143  	sk.s2.UnpackLeqEta(buf[offset:])
   144  	offset += PolyLeqEtaSize * K
   145  	sk.t0.UnpackT0(buf[offset:])
   146  
   147  	// Cached values
   148  	sk.A.Derive(&sk.rho)
   149  	sk.t0h = sk.t0
   150  	sk.t0h.NTT()
   151  	sk.s1h = sk.s1
   152  	sk.s1h.NTT()
   153  	sk.s2h = sk.s2
   154  	sk.s2h.NTT()
   155  }
   156  
   157  // GenerateKey generates a public/private key pair using entropy from rand.
   158  // If rand is nil, crypto/rand.Reader will be used.
   159  func GenerateKey(rand io.Reader) (*PublicKey, *PrivateKey, error) {
   160  	var seed [32]byte
   161  	if rand == nil {
   162  		rand = cryptoRand.Reader
   163  	}
   164  	_, err := io.ReadFull(rand, seed[:])
   165  	if err != nil {
   166  		return nil, nil, err
   167  	}
   168  	pk, sk := NewKeyFromSeed(&seed)
   169  	return pk, sk, nil
   170  }
   171  
   172  // NewKeyFromSeed derives a public/private key pair using the given seed.
   173  func NewKeyFromSeed(seed *[common.SeedSize]byte) (*PublicKey, *PrivateKey) {
   174  	var eSeed [128]byte // expanded seed
   175  	var pk PublicKey
   176  	var sk PrivateKey
   177  	var sSeed [64]byte
   178  
   179  	h := sha3.NewShake256()
   180  	_, _ = h.Write(seed[:])
   181  
   182  	if NIST {
   183  		_, _ = h.Write([]byte{byte(K), byte(L)})
   184  	}
   185  
   186  	_, _ = h.Read(eSeed[:])
   187  
   188  	copy(pk.rho[:], eSeed[:32])
   189  	copy(sSeed[:], eSeed[32:96])
   190  	copy(sk.key[:], eSeed[96:])
   191  	copy(sk.rho[:], pk.rho[:])
   192  
   193  	sk.A.Derive(&pk.rho)
   194  
   195  	for i := uint16(0); i < L; i++ {
   196  		PolyDeriveUniformLeqEta(&sk.s1[i], &sSeed, i)
   197  	}
   198  
   199  	for i := uint16(0); i < K; i++ {
   200  		PolyDeriveUniformLeqEta(&sk.s2[i], &sSeed, i+L)
   201  	}
   202  
   203  	sk.s1h = sk.s1
   204  	sk.s1h.NTT()
   205  	sk.s2h = sk.s2
   206  	sk.s2h.NTT()
   207  
   208  	sk.computeT0andT1(&sk.t0, &pk.t1)
   209  
   210  	sk.t0h = sk.t0
   211  	sk.t0h.NTT()
   212  
   213  	// Complete public key far enough to be packed
   214  	pk.t1.PackT1(pk.t1p[:])
   215  	pk.A = &sk.A
   216  
   217  	// Finish private key
   218  	var packedPk [PublicKeySize]byte
   219  	pk.Pack(&packedPk)
   220  
   221  	// tr = CRH(ρ ‖ t1) = CRH(pk)
   222  	h.Reset()
   223  	_, _ = h.Write(packedPk[:])
   224  	_, _ = h.Read(sk.tr[:])
   225  
   226  	// Finish cache of public key
   227  	pk.tr = &sk.tr
   228  
   229  	return &pk, &sk
   230  }
   231  
   232  // Computes t0 and t1 from sk.s1h, sk.s2 and sk.A.
   233  func (sk *PrivateKey) computeT0andT1(t0, t1 *VecK) {
   234  	var t VecK
   235  
   236  	// Set t to A s₁ + s₂
   237  	for i := 0; i < K; i++ {
   238  		PolyDotHat(&t[i], &sk.A[i], &sk.s1h)
   239  		t[i].ReduceLe2Q()
   240  		t[i].InvNTT()
   241  	}
   242  	t.Add(&t, &sk.s2)
   243  	t.Normalize()
   244  
   245  	// Compute t₀, t₁ = Power2Round(t)
   246  	t.Power2Round(t0, t1)
   247  }
   248  
// Verify checks whether the given signature by pk on msg is valid.
//
// For Dilithium this is the top-level verification function.
// In ML-DSA, this is ML-DSA.Verify_internal.
//
// msg is invoked once with the hash state as its io.Writer and is expected
// to write the message to it.  Verify returns false if the signature fails
// to unpack, and otherwise reports whether the recomputed challenge hash
// equals the one carried in the signature.
func Verify(pk *PublicKey, msg func(io.Writer), signature []byte) bool {
	var sig unpackedSignature
	var mu [64]byte
	var zh VecL
	var Az, Az2dct1, w1 VecK
	var ch common.Poly
	var cp [CTildeSize]byte
	var w1Packed [PolyW1Size * K]byte

	// Note that Unpack() checks whether ‖z‖_∞ < γ₁ - β
	// and validates the encoding of sig.hint (presumably ensuring that
	// there are at most ω ones in it — see UnpackHint).
	if !sig.Unpack(signature) {
		return false
	}

	// μ = CRH(tr ‖ msg)
	h := sha3.NewShake256()
	_, _ = h.Write(pk.tr[:])
	msg(&h)
	_, _ = h.Read(mu[:])

	// Compute Az
	zh = sig.z
	zh.NTT()

	for i := 0; i < K; i++ {
		PolyDotHat(&Az[i], &pk.A[i], &zh)
	}

	// Next, we compute Az - 2ᵈ·c·t₁.
	// Note that the coefficients of t₁ are bounded by 256 = 2⁹,
	// so the coefficients of Az2dct1 will be bounded by 2⁹⁺ᵈ = 2²³ < 2q,
	// which is small enough for NTT().
	// NOTE(review): 256 is 2⁸, not 2⁹, so the exponents above do not add up
	// as written; the intended bound should be re-derived from the packed
	// width of t₁ and the value of d — confirm against the common package.
	Az2dct1.MulBy2toD(&pk.t1)
	Az2dct1.NTT()
	PolyDeriveUniformBall(&ch, sig.c[:])
	ch.NTT()
	for i := 0; i < K; i++ {
		Az2dct1[i].MulHat(&Az2dct1[i], &ch)
	}
	Az2dct1.Sub(&Az, &Az2dct1)
	Az2dct1.ReduceLe2Q()
	Az2dct1.InvNTT()
	Az2dct1.NormalizeAssumingLe2Q()

	// UseHint(sig.hint, Az - 2ᵈ·c·t₁)
	//    = UseHint(sig.hint, w - c·s₂ + c·t₀)
	//    = UseHint(sig.hint, r + c·t₀)
	//    = r₁ = w₁.
	w1.UseHint(&Az2dct1, &sig.hint)
	w1.PackW1(w1Packed[:])

	// c' = H(μ, w₁)
	h.Reset()
	_, _ = h.Write(mu[:])
	_, _ = h.Write(w1Packed[:])
	_, _ = h.Read(cp[:])

	// The signature is valid iff the recomputed challenge hash matches.
	return sig.c == cp
}
   313  
// SignTo signs the given message and writes the signature into signature.
//
// For Dilithium this is the top-level signing function. For ML-DSA
// this is ML-DSA.Sign_internal.
//
// msg is invoked once with the hash state as its io.Writer and is expected
// to write the message to it.  rnd is only absorbed into ρ' when NIST is
// set.  signature must have room for at least SignatureSize bytes, otherwise
// SignTo panics.  SignTo also panics if the rejection loop reaches its
// attempt limit without producing a signature.
//
//nolint:funlen
func SignTo(sk *PrivateKey, msg func(io.Writer), rnd [32]byte, signature []byte) {
	var mu, rhop [64]byte
	var w1Packed [PolyW1Size * K]byte
	var y, yh VecL
	var w, w0, w1, w0mcs2, ct0, w0mcs2pct0 VecK
	var ch common.Poly
	var yNonce uint16
	var sig unpackedSignature

	if len(signature) < SignatureSize {
		panic("Signature does not fit in that byteslice")
	}

	//  μ = CRH(tr ‖ msg)
	h := sha3.NewShake256()
	_, _ = h.Write(sk.tr[:])
	msg(&h)
	_, _ = h.Read(mu[:])

	// ρ' = CRH(key ‖ μ), or CRH(key ‖ rnd ‖ μ) when NIST is set.
	h.Reset()
	_, _ = h.Write(sk.key[:])
	if NIST {
		_, _ = h.Write(rnd[:])
	}
	_, _ = h.Write(mu[:])
	_, _ = h.Read(rhop[:])

	// Main rejection loop
	attempt := 0
	for {
		attempt++
		if attempt >= 576 {
			// Depending on the mode, one try has a chance between 1/7 and 1/4
			// of succeeding.  Thus it is safe to say that 576 iterations
			// are enough as (6/7)⁵⁷⁶ < 2⁻¹²⁸.
			//
			// Note that the check fires on entry to the 576th iteration,
			// so at most 575 attempts are actually carried out.
			panic("This should only happen 1 in  2^{128}: something is wrong.")
		}

		// y = ExpandMask(ρ', nonce).  The nonce advances by L per attempt
		// so that every attempt uses a fresh y.
		VecLDeriveUniformLeGamma1(&y, &rhop, yNonce)
		yNonce += uint16(L)

		// Set w to A y
		yh = y
		yh.NTT()
		for i := 0; i < K; i++ {
			PolyDotHat(&w[i], &sk.A[i], &yh)
			w[i].ReduceLe2Q()
			w[i].InvNTT()
		}

		// Decompose w into w₀ and w₁
		w.NormalizeAssumingLe2Q()
		w.Decompose(&w0, &w1)

		// c~ = H(μ ‖ w₁)
		w1.PackW1(w1Packed[:])
		h.Reset()
		_, _ = h.Write(mu[:])
		_, _ = h.Write(w1Packed[:])
		_, _ = h.Read(sig.c[:])

		// Expand c~ into the challenge polynomial c and move it to the
		// NTT domain for the multiplications below.
		PolyDeriveUniformBall(&ch, sig.c[:])
		ch.NTT()

		// Ensure ‖ w₀ - c·s2 ‖_∞ < γ₂ - β.
		//
		// By Lemma 3 of the specification this is equivalent to checking that
		// both ‖ r₀ ‖_∞ < γ₂ - β and r₁ = w₁, for the decomposition
		// w - c·s₂ = r₁ α + r₀ as computed by decompose().
		// See also §4.1 of the specification.
		for i := 0; i < K; i++ {
			w0mcs2[i].MulHat(&ch, &sk.s2h[i])
			w0mcs2[i].InvNTT()
		}
		w0mcs2.Sub(&w0, &w0mcs2)
		w0mcs2.Normalize()

		if w0mcs2.Exceeds(Gamma2 - Beta) {
			continue
		}

		// z = y + c·s₁
		for i := 0; i < L; i++ {
			sig.z[i].MulHat(&ch, &sk.s1h[i])
			sig.z[i].InvNTT()
		}
		sig.z.Add(&sig.z, &y)
		sig.z.Normalize()

		// Ensure  ‖z‖_∞ < γ₁ - β
		if sig.z.Exceeds(Gamma1 - Beta) {
			continue
		}

		// Compute c·t₀
		for i := 0; i < K; i++ {
			ct0[i].MulHat(&ch, &sk.t0h[i])
			ct0[i].InvNTT()
		}
		ct0.NormalizeAssumingLe2Q()

		// Ensure ‖c·t₀‖_∞ < γ₂.
		if ct0.Exceeds(Gamma2) {
			continue
		}

		// Create the hint to be able to reconstruct w₁ from w - c·s₂ + c·t0.
		// Note that we're not using makeHint() in the obvious way as we
		// do not know whether ‖ c·s₂ - c·t₀ ‖_∞ < γ₂.  Instead we note
		// that our makeHint() is actually the same as a makeHint for a
		// different decomposition:
		//
		// Earlier we ensured indirectly with a check that r₁ = w₁ where
		// r = w - c·s₂.  Hence r₀ = r - r₁ α = w - c·s₂ - w₁ α = w₀ - c·s₂.
		// Thus  MakeHint(w₀ - c·s₂ + c·t₀, w₁) = MakeHint(r0 + c·t₀, r₁)
		// and UseHint(w - c·s₂ + c·t₀, w₁) = UseHint(r + c·t₀, r₁).
		// As we just ensured that ‖ c·t₀ ‖_∞ < γ₂ our usage is correct.
		w0mcs2pct0.Add(&w0mcs2, &ct0)
		w0mcs2pct0.NormalizeAssumingLe2Q()
		hintPop := sig.hint.MakeHint(&w0mcs2pct0, &w1)
		// Reject if MakeHint reports more than ω for the hint.
		if hintPop > Omega {
			continue
		}

		break
	}

	// All checks passed: serialize c~ ‖ z ‖ hint into the caller's buffer.
	sig.Pack(signature[:])
}
   451  
   452  // Computes the public key corresponding to this private key.
   453  func (sk *PrivateKey) Public() *PublicKey {
   454  	var t0 VecK
   455  	pk := &PublicKey{
   456  		rho: sk.rho,
   457  		A:   &sk.A,
   458  		tr:  &sk.tr,
   459  	}
   460  	sk.computeT0andT1(&t0, &pk.t1)
   461  	pk.t1.PackT1(pk.t1p[:])
   462  	return pk
   463  }
   464  
   465  // Equal returns whether the two public keys are equal
   466  func (pk *PublicKey) Equal(other *PublicKey) bool {
   467  	return pk.rho == other.rho && pk.t1 == other.t1
   468  }
   469  
   470  // Equal returns whether the two private keys are equal
   471  func (sk *PrivateKey) Equal(other *PrivateKey) bool {
   472  	ret := (subtle.ConstantTimeCompare(sk.rho[:], other.rho[:]) &
   473  		subtle.ConstantTimeCompare(sk.key[:], other.key[:]) &
   474  		subtle.ConstantTimeCompare(sk.tr[:], other.tr[:]))
   475  
   476  	acc := uint32(0)
   477  	for i := 0; i < L; i++ {
   478  		for j := 0; j < common.N; j++ {
   479  			acc |= sk.s1[i][j] ^ other.s1[i][j]
   480  		}
   481  	}
   482  	for i := 0; i < K; i++ {
   483  		for j := 0; j < common.N; j++ {
   484  			acc |= sk.s2[i][j] ^ other.s2[i][j]
   485  			acc |= sk.t0[i][j] ^ other.t0[i][j]
   486  		}
   487  	}
   488  	return (ret & subtle.ConstantTimeEq(int32(acc), 0)) == 1
   489  }