github.com/consensys/gnark@v0.11.0/backend/groth16/bn254/icicle/icicle.go (about) 1 //go:build icicle 2 3 package icicle_bn254 4 5 import ( 6 "fmt" 7 "math/big" 8 "math/bits" 9 "time" 10 "unsafe" 11 12 curve "github.com/consensys/gnark-crypto/ecc/bn254" 13 "github.com/consensys/gnark-crypto/ecc/bn254/fp" 14 "github.com/consensys/gnark-crypto/ecc/bn254/fr" 15 "github.com/consensys/gnark-crypto/ecc/bn254/fr/hash_to_field" 16 "github.com/consensys/gnark-crypto/ecc/bn254/fr/pedersen" 17 "github.com/consensys/gnark/backend" 18 groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" 19 "github.com/consensys/gnark/backend/groth16/internal" 20 "github.com/consensys/gnark/backend/witness" 21 "github.com/consensys/gnark/constraint" 22 cs "github.com/consensys/gnark/constraint/bn254" 23 "github.com/consensys/gnark/constraint/solver" 24 "github.com/consensys/gnark/internal/utils" 25 "github.com/consensys/gnark/logger" 26 iciclegnark "github.com/ingonyama-zk/iciclegnark/curves/bn254" 27 ) 28 29 const HasIcicle = true 30 31 func (pk *ProvingKey) setupDevicePointers() error { 32 if pk.deviceInfo != nil { 33 return nil 34 } 35 pk.deviceInfo = &deviceInfo{} 36 n := int(pk.Domain.Cardinality) 37 sizeBytes := n * fr.Bytes 38 39 /************************* Start Domain Device Setup ***************************/ 40 copyCosetInvDone := make(chan unsafe.Pointer, 1) 41 copyCosetDone := make(chan unsafe.Pointer, 1) 42 copyDenDone := make(chan unsafe.Pointer, 1) 43 /************************* CosetTableInv ***************************/ 44 go iciclegnark.CopyToDevice(pk.Domain.CosetTableInv, sizeBytes, copyCosetInvDone) 45 46 /************************* CosetTable ***************************/ 47 go iciclegnark.CopyToDevice(pk.Domain.CosetTable, sizeBytes, copyCosetDone) 48 49 /************************* Den ***************************/ 50 var denI, oneI fr.Element 51 oneI.SetOne() 52 denI.Exp(pk.Domain.FrMultiplicativeGen, big.NewInt(int64(pk.Domain.Cardinality))) 53 denI.Sub(&denI, &oneI).Inverse(&denI) 54 55 log2SizeFloor := bits.Len(uint(n)) - 1 56 denIcicleArr := []fr.Element{denI} 57 for i := 0; i < log2SizeFloor; i++ { 58 denIcicleArr = append(denIcicleArr, denIcicleArr...) 59 } 60 pow2Remainder := n - 1<<log2SizeFloor 61 for i := 0; i < pow2Remainder; i++ { 62 denIcicleArr = append(denIcicleArr, denI) 63 } 64 65 go iciclegnark.CopyToDevice(denIcicleArr, sizeBytes, copyDenDone) 66 67 /************************* Twiddles and Twiddles Inv ***************************/ 68 twiddlesInv_d_gen, twddles_err := iciclegnark.GenerateTwiddleFactors(n, true) 69 if twddles_err != nil { 70 return twddles_err 71 } 72 73 twiddles_d_gen, twddles_err := iciclegnark.GenerateTwiddleFactors(n, false) 74 if twddles_err != nil { 75 return twddles_err 76 } 77 78 /************************* End Domain Device Setup ***************************/ 79 pk.DomainDevice.Twiddles = twiddles_d_gen 80 pk.DomainDevice.TwiddlesInv = twiddlesInv_d_gen 81 82 pk.DomainDevice.CosetTableInv = <-copyCosetInvDone 83 pk.DomainDevice.CosetTable = <-copyCosetDone 84 pk.DenDevice = <-copyDenDone 85 86 /************************* Start G1 Device Setup ***************************/ 87 /************************* A ***************************/ 88 pointsBytesA := len(pk.G1.A) * fp.Bytes * 2 89 copyADone := make(chan unsafe.Pointer, 1) 90 go iciclegnark.CopyPointsToDevice(pk.G1.A, pointsBytesA, copyADone) // Make a function for points 91 92 /************************* B ***************************/ 93 pointsBytesB := len(pk.G1.B) * fp.Bytes * 2 94 copyBDone := make(chan unsafe.Pointer, 1) 95 go iciclegnark.CopyPointsToDevice(pk.G1.B, pointsBytesB, copyBDone) // Make a function for points 96 97 /************************* K ***************************/ 98 var pointsNoInfinity []curve.G1Affine 99 for i, gnarkPoint := range pk.G1.K { 100 if gnarkPoint.IsInfinity() { 101 pk.InfinityPointIndicesK = append(pk.InfinityPointIndicesK, i) 102 } else { 103 pointsNoInfinity = append(pointsNoInfinity, gnarkPoint) 104 } 105 } 106 107 pointsBytesK := len(pointsNoInfinity) * fp.Bytes * 2 108 copyKDone := make(chan unsafe.Pointer, 1) 109 go iciclegnark.CopyPointsToDevice(pointsNoInfinity, pointsBytesK, copyKDone) // Make a function for points 110 111 /************************* Z ***************************/ 112 pointsBytesZ := len(pk.G1.Z) * fp.Bytes * 2 113 copyZDone := make(chan unsafe.Pointer, 1) 114 go iciclegnark.CopyPointsToDevice(pk.G1.Z, pointsBytesZ, copyZDone) // Make a function for points 115 116 /************************* End G1 Device Setup ***************************/ 117 pk.G1Device.A = <-copyADone 118 pk.G1Device.B = <-copyBDone 119 pk.G1Device.K = <-copyKDone 120 pk.G1Device.Z = <-copyZDone 121 122 /************************* Start G2 Device Setup ***************************/ 123 pointsBytesB2 := len(pk.G2.B) * fp.Bytes * 4 124 copyG2BDone := make(chan unsafe.Pointer, 1) 125 go iciclegnark.CopyG2PointsToDevice(pk.G2.B, pointsBytesB2, copyG2BDone) // Make a function for points 126 pk.G2Device.B = <-copyG2BDone 127 128 /************************* End G2 Device Setup ***************************/ 129 return nil 130 } 131 132 // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 133 func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*groth16_bn254.Proof, error) { 134 opt, err := backend.NewProverConfig(opts...) 135 if err != nil { 136 return nil, fmt.Errorf("new prover config: %w", err) 137 } 138 if opt.HashToFieldFn == nil { 139 opt.HashToFieldFn = hash_to_field.New([]byte(constraint.CommitmentDst)) 140 } 141 if opt.Accelerator != "icicle" { 142 return groth16_bn254.Prove(r1cs, &pk.ProvingKey, fullWitness, opts...) 143 } 144 log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "icicle").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() 145 if pk.deviceInfo == nil { 146 log.Debug().Msg("precomputing proving key in GPU") 147 if err := pk.setupDevicePointers(); err != nil { 148 return nil, fmt.Errorf("setup device pointers: %w", err) 149 } 150 } 151 152 commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) 153 154 proof := &groth16_bn254.Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} 155 156 solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] 157 158 privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) 159 for i := range commitmentInfo { 160 solverOpts = append(solverOpts, solver.OverrideHint(commitmentInfo[i].HintID, func(i int) solver.Hint { 161 return func(_ *big.Int, in []*big.Int, out []*big.Int) error { 162 privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) 163 hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] 164 committed := in[len(hashed):] 165 for j, inJ := range committed { 166 privateCommittedValues[i][j].SetBigInt(inJ) 167 } 168 169 var err error 170 if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { 171 return err 172 } 173 174 opt.HashToFieldFn.Write(constraint.SerializeCommitment(proof.Commitments[i].Marshal(), hashed, (fr.Bits-1)/8+1)) 175 hashBts := opt.HashToFieldFn.Sum(nil) 176 opt.HashToFieldFn.Reset() 177 nbBuf := fr.Bytes 178 if opt.HashToFieldFn.Size() < fr.Bytes { 179 nbBuf = opt.HashToFieldFn.Size() 180 } 181 var res fr.Element 182 res.SetBytes(hashBts[:nbBuf]) 183 res.BigInt(out[0]) 184 return err 185 } 186 }(i))) 187 } 188 189 if r1cs.GkrInfo.Is() { 190 var gkrData cs.GkrSolvingData 191 solverOpts = append(solverOpts, 192 solver.OverrideHint(r1cs.GkrInfo.SolveHintID, cs.GkrSolveHint(r1cs.GkrInfo, &gkrData)), 193 solver.OverrideHint(r1cs.GkrInfo.ProveHintID, cs.GkrProveHint(r1cs.GkrInfo.HashName, &gkrData))) 194 } 195 196 _solution, err := r1cs.Solve(fullWitness, solverOpts...) 197 if err != nil { 198 return nil, err 199 } 200 201 solution := _solution.(*cs.R1CSSolution) 202 wireValues := []fr.Element(solution.W) 203 204 start := time.Now() 205 206 commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) 207 for i := range commitmentInfo { 208 copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) 209 } 210 211 if proof.CommitmentPok, err = pedersen.BatchProve(pk.CommitmentKeys, privateCommittedValues, commitmentsSerialized); err != nil { 212 return nil, err 213 } 214 215 // H (witness reduction / FFT part) 216 var h unsafe.Pointer 217 chHDone := make(chan struct{}, 1) 218 go func() { 219 h = computeH(solution.A, solution.B, solution.C, pk) 220 solution.A = nil 221 solution.B = nil 222 solution.C = nil 223 chHDone <- struct{}{} 224 }() 225 226 // we need to copy and filter the wireValues for each multi exp 227 // as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity 228 var wireValuesADevice, wireValuesBDevice iciclegnark.OnDeviceData 229 chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan struct{}, 1) 230 231 go func() { 232 wireValuesA := make([]fr.Element, len(wireValues)-int(pk.NbInfinityA)) 233 for i, j := 0, 0; j < len(wireValuesA); i++ { 234 if pk.InfinityA[i] { 235 continue 236 } 237 wireValuesA[j] = wireValues[i] 238 j++ 239 } 240 wireValuesASize := len(wireValuesA) 241 scalarBytes := wireValuesASize * fr.Bytes 242 243 // Copy scalars to the device and retain ptr to them 244 copyDone := make(chan unsafe.Pointer, 1) 245 iciclegnark.CopyToDevice(wireValuesA, scalarBytes, copyDone) 246 wireValuesADevicePtr := <-copyDone 247 248 wireValuesADevice = iciclegnark.OnDeviceData{ 249 P: wireValuesADevicePtr, 250 Size: wireValuesASize, 251 } 252 253 close(chWireValuesA) 254 }() 255 go func() { 256 wireValuesB := make([]fr.Element, len(wireValues)-int(pk.NbInfinityB)) 257 for i, j := 0, 0; j < len(wireValuesB); i++ { 258 if pk.InfinityB[i] { 259 continue 260 } 261 wireValuesB[j] = wireValues[i] 262 j++ 263 } 264 wireValuesBSize := len(wireValuesB) 265 scalarBytes := wireValuesBSize * fr.Bytes 266 267 // Copy scalars to the device and retain ptr to them 268 copyDone := make(chan unsafe.Pointer, 1) 269 iciclegnark.CopyToDevice(wireValuesB, scalarBytes, copyDone) 270 wireValuesBDevicePtr := <-copyDone 271 272 wireValuesBDevice = iciclegnark.OnDeviceData{ 273 P: wireValuesBDevicePtr, 274 Size: wireValuesBSize, 275 } 276 277 close(chWireValuesB) 278 }() 279 280 // sample random r and s 281 var r, s big.Int 282 var _r, _s, _kr fr.Element 283 if _, err := _r.SetRandom(); err != nil { 284 return nil, err 285 } 286 if _, err := _s.SetRandom(); err != nil { 287 return nil, err 288 } 289 _kr.Mul(&_r, &_s).Neg(&_kr) 290 291 _r.BigInt(&r) 292 _s.BigInt(&s) 293 294 // computes r[δ], s[δ], kr[δ] 295 deltas := curve.BatchScalarMultiplicationG1(&pk.G1.Delta, []fr.Element{_r, _s, _kr}) 296 297 var bs1, ar curve.G1Jac 298 299 computeBS1 := func() error { 300 <-chWireValuesB 301 302 if bs1, _, err = iciclegnark.MsmOnDevice(wireValuesBDevice.P, pk.G1Device.B, wireValuesBDevice.Size, true); err != nil { 303 return err 304 } 305 306 bs1.AddMixed(&pk.G1.Beta) 307 bs1.AddMixed(&deltas[1]) 308 309 return nil 310 } 311 312 computeAR1 := func() error { 313 <-chWireValuesA 314 315 if ar, _, err = iciclegnark.MsmOnDevice(wireValuesADevice.P, pk.G1Device.A, wireValuesADevice.Size, true); err != nil { 316 return err 317 } 318 319 ar.AddMixed(&pk.G1.Alpha) 320 ar.AddMixed(&deltas[0]) 321 proof.Ar.FromJacobian(&ar) 322 323 return nil 324 } 325 326 computeKRS := func() error { 327 var krs, krs2, p1 curve.G1Jac 328 sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 329 330 // check for small circuits as iciclegnark doesn't handle zero sizes well 331 if len(pk.G1.Z) > 0 { 332 if krs2, _, err = iciclegnark.MsmOnDevice(h, pk.G1Device.Z, sizeH, true); err != nil { 333 return err 334 } 335 } 336 337 // filter the wire values if needed 338 // TODO Perf @Tabaie worst memory allocation offender 339 toRemove := commitmentInfo.GetPrivateCommitted() 340 toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) 341 scalars := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) 342 343 // filter zero/infinity points since icicle doesn't handle them 344 // See https://github.com/ingonyama-zk/icicle/issues/169 for more info 345 for _, indexToRemove := range pk.InfinityPointIndicesK { 346 scalars = append(scalars[:indexToRemove], scalars[indexToRemove+1:]...) 347 } 348 349 scalarBytes := len(scalars) * fr.Bytes 350 351 copyDone := make(chan unsafe.Pointer, 1) 352 iciclegnark.CopyToDevice(scalars, scalarBytes, copyDone) 353 scalars_d := <-copyDone 354 355 krs, _, err = iciclegnark.MsmOnDevice(scalars_d, pk.G1Device.K, len(scalars), true) 356 iciclegnark.FreeDevicePointer(scalars_d) 357 358 if err != nil { 359 return err 360 } 361 362 krs.AddMixed(&deltas[2]) 363 364 krs.AddAssign(&krs2) 365 366 p1.ScalarMultiplication(&ar, &s) 367 krs.AddAssign(&p1) 368 369 p1.ScalarMultiplication(&bs1, &r) 370 krs.AddAssign(&p1) 371 372 proof.Krs.FromJacobian(&krs) 373 374 return nil 375 } 376 377 computeBS2 := func() error { 378 // Bs2 (1 multi exp G2 - size = len(wires)) 379 var Bs, deltaS curve.G2Jac 380 381 <-chWireValuesB 382 if Bs, _, err = iciclegnark.MsmG2OnDevice(wireValuesBDevice.P, pk.G2Device.B, wireValuesBDevice.Size, true); err != nil { 383 return err 384 } 385 386 deltaS.FromAffine(&pk.G2.Delta) 387 deltaS.ScalarMultiplication(&deltaS, &s) 388 Bs.AddAssign(&deltaS) 389 Bs.AddMixed(&pk.G2.Beta) 390 391 proof.Bs.FromJacobian(&Bs) 392 return nil 393 } 394 395 // wait for FFT to end 396 <-chHDone 397 398 // schedule our proof part computations 399 if err := computeAR1(); err != nil { 400 return nil, err 401 } 402 if err := computeBS1(); err != nil { 403 return nil, err 404 } 405 if err := computeKRS(); err != nil { 406 return nil, err 407 } 408 if err := computeBS2(); err != nil { 409 return nil, err 410 } 411 412 log.Debug().Dur("took", time.Since(start)).Msg("prover done") 413 414 // free device/GPU memory that is not needed for future proofs (scalars/hpoly) 415 go func() { 416 iciclegnark.FreeDevicePointer(wireValuesADevice.P) 417 iciclegnark.FreeDevicePointer(wireValuesBDevice.P) 418 iciclegnark.FreeDevicePointer(h) 419 }() 420 421 return proof, nil 422 } 423 424 // if len(toRemove) == 0, returns slice 425 // else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex 426 // this assumes len(slice) > len(toRemove) 427 // filterHeap modifies toRemove 428 func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { 429 430 if len(toRemove) == 0 { 431 return slice 432 } 433 434 heap := utils.IntHeap(toRemove) 435 heap.Heapify() 436 437 r = make([]fr.Element, 0, len(slice)) 438 439 // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) 440 for i := 0; i < len(slice); i++ { 441 if len(heap) > 0 && i+sliceFirstIndex == heap[0] { 442 for len(heap) > 0 && i+sliceFirstIndex == heap[0] { 443 heap.Pop() 444 } 445 continue 446 } 447 r = append(r, slice[i]) 448 } 449 450 return 451 } 452 453 func computeH(a, b, c []fr.Element, pk *ProvingKey) unsafe.Pointer { 454 // H part of Krs 455 // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1)) 456 // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c) 457 // 2 - ca = fft_coset(_a), ba = fft_coset(_b), cc = fft_coset(_c) 458 // 3 - h = ifft_coset(ca o cb - cc) 459 460 n := len(a) 461 462 // add padding to ensure input length is domain cardinality 463 padding := make([]fr.Element, int(pk.Domain.Cardinality)-n) 464 a = append(a, padding...) 465 b = append(b, padding...) 466 c = append(c, padding...) 467 n = len(a) 468 469 sizeBytes := n * fr.Bytes 470 471 /*********** Copy a,b,c to Device Start ************/ 472 // Individual channels are necessary to know which device pointers 473 // point to which vector 474 copyADone := make(chan unsafe.Pointer, 1) 475 copyBDone := make(chan unsafe.Pointer, 1) 476 copyCDone := make(chan unsafe.Pointer, 1) 477 478 go iciclegnark.CopyToDevice(a, sizeBytes, copyADone) 479 go iciclegnark.CopyToDevice(b, sizeBytes, copyBDone) 480 go iciclegnark.CopyToDevice(c, sizeBytes, copyCDone) 481 482 a_device := <-copyADone 483 b_device := <-copyBDone 484 c_device := <-copyCDone 485 /*********** Copy a,b,c to Device End ************/ 486 487 computeInttNttDone := make(chan error, 1) 488 computeInttNttOnDevice := func(devicePointer unsafe.Pointer) { 489 a_intt_d := iciclegnark.INttOnDevice(devicePointer, pk.DomainDevice.TwiddlesInv, nil, n, sizeBytes, false) 490 iciclegnark.NttOnDevice(devicePointer, a_intt_d, pk.DomainDevice.Twiddles, pk.DomainDevice.CosetTable, n, n, sizeBytes, true) 491 computeInttNttDone <- nil 492 iciclegnark.FreeDevicePointer(a_intt_d) 493 } 494 495 go computeInttNttOnDevice(a_device) 496 go computeInttNttOnDevice(b_device) 497 go computeInttNttOnDevice(c_device) 498 _, _, _ = <-computeInttNttDone, <-computeInttNttDone, <-computeInttNttDone 499 500 iciclegnark.PolyOps(a_device, b_device, c_device, pk.DenDevice, n) 501 502 h := iciclegnark.INttOnDevice(a_device, pk.DomainDevice.TwiddlesInv, pk.DomainDevice.CosetTableInv, n, sizeBytes, true) 503 504 go func() { 505 iciclegnark.FreeDevicePointer(a_device) 506 iciclegnark.FreeDevicePointer(b_device) 507 iciclegnark.FreeDevicePointer(c_device) 508 }() 509 510 iciclegnark.ReverseScalars(h, n) 511 512 return h 513 }