// Source: github.com/cloudflare/circl@v1.5.0/sign/dilithium/mode2/internal/dilithium.go

// Code generated from mode3/internal/dilithium.go by gen.go

package internal

import (
	cryptoRand "crypto/rand"
	"crypto/subtle"
	"io"

	"github.com/cloudflare/circl/internal/sha3"
	common "github.com/cloudflare/circl/sign/internal/dilithium"
)

// Derived sizes and bounds. Everything here is computed from the
// mode-specific parameters (K, L, Eta, Tau, Gamma1Bits, Gamma2, Omega,
// CTildeSize, TRSize, DoubleEtaBits) and from the shared constants in the
// common package.
const (
	// Size of a packed polynomial of norm ≤η: common.N coefficients at
	// DoubleEtaBits bits each, expressed in bytes.
	// (Note that the formula is not valid in general.)
	PolyLeqEtaSize = (common.N * DoubleEtaBits) / 8

	// β = τη, the maximum size of c s₂.
	Beta = Tau * Eta

	// γ₁ = 2^Gamma1Bits, the range of y.
	Gamma1 = 1 << Gamma1Bits

	// Size of packed polynomial of norm <γ₁ such as z:
	// Gamma1Bits+1 bits for each of the common.N coefficients.
	PolyLeGamma1Size = (Gamma1Bits + 1) * common.N / 8

	// α = 2γ₂ parameter for decompose
	Alpha = 2 * Gamma2

	// Size of a packed private key. The terms follow the order in which
	// PrivateKey.Pack writes the fields: ρ (32 bytes), key (32 bytes),
	// tr (TRSize bytes), s₁ (L polynomials), s₂ (K polynomials) and
	// t₀ (K polynomials).
	PrivateKeySize = 32 + 32 + TRSize + PolyLeqEtaSize*(L+K) + common.PolyT0Size*K

	// Size of a packed public key: ρ (32 bytes) followed by the packed t₁
	// (K polynomials).
	PublicKeySize = 32 + common.PolyT1Size*K

	// Size of a packed signature: c̃ (CTildeSize bytes), z (L polynomials
	// of PolyLeGamma1Size bytes each) and Omega+K bytes for the hint.
	SignatureSize = L*PolyLeGamma1Size + Omega + K + CTildeSize

	// Size of packed w₁.
	// NOTE(review): the per-coefficient width is written here as
	// common.QBits - Gamma1Bits; like PolyLeqEtaSize this is presumably only
	// valid for the supported parameter sets — confirm before reusing it.
	PolyW1Size = (common.N * (common.QBits - Gamma1Bits)) / 8
)

// PublicKey is the type of Dilithium public keys.
type PublicKey struct {
	rho [32]byte // seed ρ from which the matrix A is derived
	t1  VecK     // t₁, unpacked

	// Cached values
	t1p [common.PolyT1Size * K]byte // packed form of t1
	A   *Mat                        // matrix derived from rho; may alias PrivateKey.A
	tr  *[TRSize]byte               // SHAKE-256 hash of the packed public key; may alias PrivateKey.tr
}

// PrivateKey is the type of Dilithium private keys.
type PrivateKey struct {
	rho [32]byte       // seed ρ from which the matrix A is derived
	key [32]byte       // secret key material hashed into ρ' when signing
	s1  VecL           // secret vector s₁
	s2  VecK           // secret vector s₂
	t0  VecK           // t₀ as produced by Power2Round
	tr  [TRSize]byte   // SHAKE-256 hash of the packed public key

	// Cached values
	A   Mat  // ExpandA(ρ)
	s1h VecL // NTT(s₁)
	s2h VecK // NTT(s₂)
	t0h VecK // NTT(t₀)
}

// unpackedSignature holds the three components of a signature in the form
// used for computation; Pack and Unpack convert to and from the byte
// encoding of SignatureSize bytes.
type unpackedSignature struct {
	z    VecL             // response vector z
	hint VecK             // hint vector
	c    [CTildeSize]byte // challenge seed c̃
}

// Packs the signature into buf.
78 func (sig *unpackedSignature) Pack(buf []byte) { 79 copy(buf[:], sig.c[:]) 80 sig.z.PackLeGamma1(buf[CTildeSize:]) 81 sig.hint.PackHint(buf[CTildeSize+L*PolyLeGamma1Size:]) 82 } 83 84 // Sets sig to the signature encoded in the buffer. 85 // 86 // Returns whether buf contains a properly packed signature. 87 func (sig *unpackedSignature) Unpack(buf []byte) bool { 88 if len(buf) < SignatureSize { 89 return false 90 } 91 copy(sig.c[:], buf[:]) 92 sig.z.UnpackLeGamma1(buf[CTildeSize:]) 93 if sig.z.Exceeds(Gamma1 - Beta) { 94 return false 95 } 96 if !sig.hint.UnpackHint(buf[CTildeSize+L*PolyLeGamma1Size:]) { 97 return false 98 } 99 return true 100 } 101 102 // Packs the public key into buf. 103 func (pk *PublicKey) Pack(buf *[PublicKeySize]byte) { 104 copy(buf[:32], pk.rho[:]) 105 copy(buf[32:], pk.t1p[:]) 106 } 107 108 // Sets pk to the public key encoded in buf. 109 func (pk *PublicKey) Unpack(buf *[PublicKeySize]byte) { 110 copy(pk.rho[:], buf[:32]) 111 copy(pk.t1p[:], buf[32:]) 112 113 pk.t1.UnpackT1(pk.t1p[:]) 114 pk.A = new(Mat) 115 pk.A.Derive(&pk.rho) 116 117 // tr = CRH(ρ ‖ t1) = CRH(pk) 118 pk.tr = new([TRSize]byte) 119 h := sha3.NewShake256() 120 _, _ = h.Write(buf[:]) 121 _, _ = h.Read(pk.tr[:]) 122 } 123 124 // Packs the private key into buf. 125 func (sk *PrivateKey) Pack(buf *[PrivateKeySize]byte) { 126 copy(buf[:32], sk.rho[:]) 127 copy(buf[32:64], sk.key[:]) 128 copy(buf[64:64+TRSize], sk.tr[:]) 129 offset := 64 + TRSize 130 sk.s1.PackLeqEta(buf[offset:]) 131 offset += PolyLeqEtaSize * L 132 sk.s2.PackLeqEta(buf[offset:]) 133 offset += PolyLeqEtaSize * K 134 sk.t0.PackT0(buf[offset:]) 135 } 136 137 // Sets sk to the private key encoded in buf. 
138 func (sk *PrivateKey) Unpack(buf *[PrivateKeySize]byte) { 139 copy(sk.rho[:], buf[:32]) 140 copy(sk.key[:], buf[32:64]) 141 copy(sk.tr[:], buf[64:64+TRSize]) 142 offset := 64 + TRSize 143 sk.s1.UnpackLeqEta(buf[offset:]) 144 offset += PolyLeqEtaSize * L 145 sk.s2.UnpackLeqEta(buf[offset:]) 146 offset += PolyLeqEtaSize * K 147 sk.t0.UnpackT0(buf[offset:]) 148 149 // Cached values 150 sk.A.Derive(&sk.rho) 151 sk.t0h = sk.t0 152 sk.t0h.NTT() 153 sk.s1h = sk.s1 154 sk.s1h.NTT() 155 sk.s2h = sk.s2 156 sk.s2h.NTT() 157 } 158 159 // GenerateKey generates a public/private key pair using entropy from rand. 160 // If rand is nil, crypto/rand.Reader will be used. 161 func GenerateKey(rand io.Reader) (*PublicKey, *PrivateKey, error) { 162 var seed [32]byte 163 if rand == nil { 164 rand = cryptoRand.Reader 165 } 166 _, err := io.ReadFull(rand, seed[:]) 167 if err != nil { 168 return nil, nil, err 169 } 170 pk, sk := NewKeyFromSeed(&seed) 171 return pk, sk, nil 172 } 173 174 // NewKeyFromSeed derives a public/private key pair using the given seed. 
175 func NewKeyFromSeed(seed *[common.SeedSize]byte) (*PublicKey, *PrivateKey) { 176 var eSeed [128]byte // expanded seed 177 var pk PublicKey 178 var sk PrivateKey 179 var sSeed [64]byte 180 181 h := sha3.NewShake256() 182 _, _ = h.Write(seed[:]) 183 184 if NIST { 185 _, _ = h.Write([]byte{byte(K), byte(L)}) 186 } 187 188 _, _ = h.Read(eSeed[:]) 189 190 copy(pk.rho[:], eSeed[:32]) 191 copy(sSeed[:], eSeed[32:96]) 192 copy(sk.key[:], eSeed[96:]) 193 copy(sk.rho[:], pk.rho[:]) 194 195 sk.A.Derive(&pk.rho) 196 197 for i := uint16(0); i < L; i++ { 198 PolyDeriveUniformLeqEta(&sk.s1[i], &sSeed, i) 199 } 200 201 for i := uint16(0); i < K; i++ { 202 PolyDeriveUniformLeqEta(&sk.s2[i], &sSeed, i+L) 203 } 204 205 sk.s1h = sk.s1 206 sk.s1h.NTT() 207 sk.s2h = sk.s2 208 sk.s2h.NTT() 209 210 sk.computeT0andT1(&sk.t0, &pk.t1) 211 212 sk.t0h = sk.t0 213 sk.t0h.NTT() 214 215 // Complete public key far enough to be packed 216 pk.t1.PackT1(pk.t1p[:]) 217 pk.A = &sk.A 218 219 // Finish private key 220 var packedPk [PublicKeySize]byte 221 pk.Pack(&packedPk) 222 223 // tr = CRH(ρ ‖ t1) = CRH(pk) 224 h.Reset() 225 _, _ = h.Write(packedPk[:]) 226 _, _ = h.Read(sk.tr[:]) 227 228 // Finish cache of public key 229 pk.tr = &sk.tr 230 231 return &pk, &sk 232 } 233 234 // Computes t0 and t1 from sk.s1h, sk.s2 and sk.A. 235 func (sk *PrivateKey) computeT0andT1(t0, t1 *VecK) { 236 var t VecK 237 238 // Set t to A s₁ + s₂ 239 for i := 0; i < K; i++ { 240 PolyDotHat(&t[i], &sk.A[i], &sk.s1h) 241 t[i].ReduceLe2Q() 242 t[i].InvNTT() 243 } 244 t.Add(&t, &sk.s2) 245 t.Normalize() 246 247 // Compute t₀, t₁ = Power2Round(t) 248 t.Power2Round(t0, t1) 249 } 250 251 // Verify checks whether the given signature by pk on msg is valid. 252 // 253 // For Dilithium this is the top-level verification function. 254 // In ML-DSA, this is ML-DSA.Verify_internal. 
255 func Verify(pk *PublicKey, msg func(io.Writer), signature []byte) bool { 256 var sig unpackedSignature 257 var mu [64]byte 258 var zh VecL 259 var Az, Az2dct1, w1 VecK 260 var ch common.Poly 261 var cp [CTildeSize]byte 262 var w1Packed [PolyW1Size * K]byte 263 264 // Note that Unpack() checked whether ‖z‖_∞ < γ₁ - β 265 // and ensured that there at most ω ones in pk.hint. 266 if !sig.Unpack(signature) { 267 return false 268 } 269 270 // μ = CRH(tr ‖ msg) 271 h := sha3.NewShake256() 272 _, _ = h.Write(pk.tr[:]) 273 msg(&h) 274 _, _ = h.Read(mu[:]) 275 276 // Compute Az 277 zh = sig.z 278 zh.NTT() 279 280 for i := 0; i < K; i++ { 281 PolyDotHat(&Az[i], &pk.A[i], &zh) 282 } 283 284 // Next, we compute Az - 2ᵈ·c·t₁. 285 // Note that the coefficients of t₁ are bounded by 256 = 2⁹, 286 // so the coefficients of Az2dct1 will bounded by 2⁹⁺ᵈ = 2²³ < 2q, 287 // which is small enough for NTT(). 288 Az2dct1.MulBy2toD(&pk.t1) 289 Az2dct1.NTT() 290 PolyDeriveUniformBall(&ch, sig.c[:]) 291 ch.NTT() 292 for i := 0; i < K; i++ { 293 Az2dct1[i].MulHat(&Az2dct1[i], &ch) 294 } 295 Az2dct1.Sub(&Az, &Az2dct1) 296 Az2dct1.ReduceLe2Q() 297 Az2dct1.InvNTT() 298 Az2dct1.NormalizeAssumingLe2Q() 299 300 // UseHint(pk.hint, Az - 2ᵈ·c·t₁) 301 // = UseHint(pk.hint, w - c·s₂ + c·t₀) 302 // = UseHint(pk.hint, r + c·t₀) 303 // = r₁ = w₁. 304 w1.UseHint(&Az2dct1, &sig.hint) 305 w1.PackW1(w1Packed[:]) 306 307 // c' = H(μ, w₁) 308 h.Reset() 309 _, _ = h.Write(mu[:]) 310 _, _ = h.Write(w1Packed[:]) 311 _, _ = h.Read(cp[:]) 312 313 return sig.c == cp 314 } 315 316 // SignTo signs the given message and writes the signature into signature. 317 // 318 // For Dilithium this is the top-level signing function. For ML-DSA 319 // this is ML-DSA.Sign_internal. 
320 // 321 //nolint:funlen 322 func SignTo(sk *PrivateKey, msg func(io.Writer), rnd [32]byte, signature []byte) { 323 var mu, rhop [64]byte 324 var w1Packed [PolyW1Size * K]byte 325 var y, yh VecL 326 var w, w0, w1, w0mcs2, ct0, w0mcs2pct0 VecK 327 var ch common.Poly 328 var yNonce uint16 329 var sig unpackedSignature 330 331 if len(signature) < SignatureSize { 332 panic("Signature does not fit in that byteslice") 333 } 334 335 // μ = CRH(tr ‖ msg) 336 h := sha3.NewShake256() 337 _, _ = h.Write(sk.tr[:]) 338 msg(&h) 339 _, _ = h.Read(mu[:]) 340 341 // ρ' = CRH(key ‖ μ) 342 h.Reset() 343 _, _ = h.Write(sk.key[:]) 344 if NIST { 345 _, _ = h.Write(rnd[:]) 346 } 347 _, _ = h.Write(mu[:]) 348 _, _ = h.Read(rhop[:]) 349 350 // Main rejection loop 351 attempt := 0 352 for { 353 attempt++ 354 if attempt >= 576 { 355 // Depending on the mode, one try has a chance between 1/7 and 1/4 356 // of succeeding. Thus it is safe to say that 576 iterations 357 // are enough as (6/7)⁵⁷⁶ < 2⁻¹²⁸. 358 panic("This should only happen 1 in 2^{128}: something is wrong.") 359 } 360 361 // y = ExpandMask(ρ', key) 362 VecLDeriveUniformLeGamma1(&y, &rhop, yNonce) 363 yNonce += uint16(L) 364 365 // Set w to A y 366 yh = y 367 yh.NTT() 368 for i := 0; i < K; i++ { 369 PolyDotHat(&w[i], &sk.A[i], &yh) 370 w[i].ReduceLe2Q() 371 w[i].InvNTT() 372 } 373 374 // Decompose w into w₀ and w₁ 375 w.NormalizeAssumingLe2Q() 376 w.Decompose(&w0, &w1) 377 378 // c~ = H(μ ‖ w₁) 379 w1.PackW1(w1Packed[:]) 380 h.Reset() 381 _, _ = h.Write(mu[:]) 382 _, _ = h.Write(w1Packed[:]) 383 _, _ = h.Read(sig.c[:]) 384 385 PolyDeriveUniformBall(&ch, sig.c[:]) 386 ch.NTT() 387 388 // Ensure ‖ w₀ - c·s2 ‖_∞ < γ₂ - β. 389 // 390 // By Lemma 3 of the specification this is equivalent to checking that 391 // both ‖ r₀ ‖_∞ < γ₂ - β and r₁ = w₁, for the decomposition 392 // w - c·s₂ = r₁ α + r₀ as computed by decompose(). 393 // See also §4.1 of the specification. 
394 for i := 0; i < K; i++ { 395 w0mcs2[i].MulHat(&ch, &sk.s2h[i]) 396 w0mcs2[i].InvNTT() 397 } 398 w0mcs2.Sub(&w0, &w0mcs2) 399 w0mcs2.Normalize() 400 401 if w0mcs2.Exceeds(Gamma2 - Beta) { 402 continue 403 } 404 405 // z = y + c·s₁ 406 for i := 0; i < L; i++ { 407 sig.z[i].MulHat(&ch, &sk.s1h[i]) 408 sig.z[i].InvNTT() 409 } 410 sig.z.Add(&sig.z, &y) 411 sig.z.Normalize() 412 413 // Ensure ‖z‖_∞ < γ₁ - β 414 if sig.z.Exceeds(Gamma1 - Beta) { 415 continue 416 } 417 418 // Compute c·t₀ 419 for i := 0; i < K; i++ { 420 ct0[i].MulHat(&ch, &sk.t0h[i]) 421 ct0[i].InvNTT() 422 } 423 ct0.NormalizeAssumingLe2Q() 424 425 // Ensure ‖c·t₀‖_∞ < γ₂. 426 if ct0.Exceeds(Gamma2) { 427 continue 428 } 429 430 // Create the hint to be able to reconstruct w₁ from w - c·s₂ + c·t0. 431 // Note that we're not using makeHint() in the obvious way as we 432 // do not know whether ‖ sc·s₂ - c·t₀ ‖_∞ < γ₂. Instead we note 433 // that our makeHint() is actually the same as a makeHint for a 434 // different decomposition: 435 // 436 // Earlier we ensured indirectly with a check that r₁ = w₁ where 437 // r = w - c·s₂. Hence r₀ = r - r₁ α = w - c·s₂ - w₁ α = w₀ - c·s₂. 438 // Thus MakeHint(w₀ - c·s₂ + c·t₀, w₁) = MakeHint(r0 + c·t₀, r₁) 439 // and UseHint(w - c·s₂ + c·t₀, w₁) = UseHint(r + c·t₀, r₁). 440 // As we just ensured that ‖ c·t₀ ‖_∞ < γ₂ our usage is correct. 441 w0mcs2pct0.Add(&w0mcs2, &ct0) 442 w0mcs2pct0.NormalizeAssumingLe2Q() 443 hintPop := sig.hint.MakeHint(&w0mcs2pct0, &w1) 444 if hintPop > Omega { 445 continue 446 } 447 448 break 449 } 450 451 sig.Pack(signature[:]) 452 } 453 454 // Computes the public key corresponding to this private key. 
455 func (sk *PrivateKey) Public() *PublicKey { 456 var t0 VecK 457 pk := &PublicKey{ 458 rho: sk.rho, 459 A: &sk.A, 460 tr: &sk.tr, 461 } 462 sk.computeT0andT1(&t0, &pk.t1) 463 pk.t1.PackT1(pk.t1p[:]) 464 return pk 465 } 466 467 // Equal returns whether the two public keys are equal 468 func (pk *PublicKey) Equal(other *PublicKey) bool { 469 return pk.rho == other.rho && pk.t1 == other.t1 470 } 471 472 // Equal returns whether the two private keys are equal 473 func (sk *PrivateKey) Equal(other *PrivateKey) bool { 474 ret := (subtle.ConstantTimeCompare(sk.rho[:], other.rho[:]) & 475 subtle.ConstantTimeCompare(sk.key[:], other.key[:]) & 476 subtle.ConstantTimeCompare(sk.tr[:], other.tr[:])) 477 478 acc := uint32(0) 479 for i := 0; i < L; i++ { 480 for j := 0; j < common.N; j++ { 481 acc |= sk.s1[i][j] ^ other.s1[i][j] 482 } 483 } 484 for i := 0; i < K; i++ { 485 for j := 0; j < common.N; j++ { 486 acc |= sk.s2[i][j] ^ other.s2[i][j] 487 acc |= sk.t0[i][j] ^ other.t0[i][j] 488 } 489 } 490 return (ret & subtle.ConstantTimeEq(int32(acc), 0)) == 1 491 }