github.com/consensys/gnark@v0.11.0/internal/generator/backend/template/zkpschemes/groth16/groth16.prove.go.tmpl (about) 1 import ( 2 "fmt" 3 "runtime" 4 "math/big" 5 "time" 6 7 {{- template "import_fr" . }} 8 {{- template "import_curve" . }} 9 {{- template "import_backend_cs" . }} 10 {{- template "import_fft" . }} 11 {{- template "import_hash_to_field" . }} 12 "github.com/consensys/gnark/constraint" 13 "github.com/consensys/gnark-crypto/ecc" 14 "github.com/consensys/gnark/internal/utils" 15 "github.com/consensys/gnark/backend" 16 "github.com/consensys/gnark/backend/groth16/internal" 17 "github.com/consensys/gnark/constraint/solver" 18 "github.com/consensys/gnark/backend/witness" 19 "github.com/consensys/gnark/logger" 20 21 fcs "github.com/consensys/gnark/frontend/cs" 22 ) 23 24 25 // Proof represents a Groth16 proof that was encoded with a ProvingKey and can be verified 26 // with a valid statement and a VerifyingKey 27 // Notation follows Figure 4. in DIZK paper https://eprint.iacr.org/2018/691.pdf 28 type Proof struct { 29 Ar, Krs curve.G1Affine 30 Bs curve.G2Affine 31 Commitments []curve.G1Affine // Pedersen commitments a la https://eprint.iacr.org/2022/1072 32 CommitmentPok curve.G1Affine // Batched proof of knowledge of the above commitments 33 } 34 35 // isValid ensures proof elements are in the correct subgroup 36 func (proof *Proof) isValid() bool { 37 return proof.Ar.IsInSubGroup() && proof.Krs.IsInSubGroup() && proof.Bs.IsInSubGroup() 38 } 39 40 // CurveID returns the curveID 41 func (proof *Proof) CurveID() ecc.ID { 42 return curve.ID 43 } 44 45 // Prove generates the proof of knowledge of a r1cs with full witness (secret + public part). 46 func Prove(r1cs *cs.R1CS, pk *ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (*Proof, error) { 47 opt, err := backend.NewProverConfig(opts...) 48 if err != nil { 49 return nil, fmt.Errorf("new prover config: %w", err) 50 } 51 if opt.HashToFieldFn == nil { 52 opt.HashToFieldFn = hash_to_field.New([]byte(constraint.CommitmentDst)) 53 } 54 55 log := logger.Logger().With().Str("curve", r1cs.CurveID().String()).Str("acceleration", "none").Int("nbConstraints", r1cs.GetNbConstraints()).Str("backend", "groth16").Logger() 56 57 commitmentInfo := r1cs.CommitmentInfo.(constraint.Groth16Commitments) 58 59 proof := &Proof{Commitments: make([]curve.G1Affine, len(commitmentInfo))} 60 61 solverOpts := opt.SolverOpts[:len(opt.SolverOpts):len(opt.SolverOpts)] 62 63 privateCommittedValues := make([][]fr.Element, len(commitmentInfo)) 64 65 // override hints 66 bsb22ID := solver.GetHintID(fcs.Bsb22CommitmentComputePlaceholder) 67 solverOpts = append(solverOpts, solver.OverrideHint(bsb22ID, func(_ *big.Int, in []*big.Int, out []*big.Int) error { 68 i := int(in[0].Int64()) 69 in = in[1:] 70 privateCommittedValues[i] = make([]fr.Element, len(commitmentInfo[i].PrivateCommitted)) 71 hashed := in[:len(commitmentInfo[i].PublicAndCommitmentCommitted)] 72 committed := in[+len(hashed):] 73 for j, inJ := range committed { 74 privateCommittedValues[i][j].SetBigInt(inJ) 75 } 76 77 var err error 78 if proof.Commitments[i], err = pk.CommitmentKeys[i].Commit(privateCommittedValues[i]); err != nil { 79 return err 80 } 81 82 opt.HashToFieldFn.Write(constraint.SerializeCommitment(proof.Commitments[i].Marshal(), hashed, (fr.Bits-1)/8+1)) 83 hashBts := opt.HashToFieldFn.Sum(nil) 84 opt.HashToFieldFn.Reset() 85 nbBuf := fr.Bytes 86 if opt.HashToFieldFn.Size() < fr.Bytes { 87 nbBuf = opt.HashToFieldFn.Size() 88 } 89 var res fr.Element 90 res.SetBytes(hashBts[:nbBuf]) 91 res.BigInt(out[0]) 92 return nil 93 })) 94 95 _solution, err := r1cs.Solve(fullWitness, solverOpts...) 96 if err != nil { 97 return nil, err 98 } 99 100 solution := _solution.(*cs.R1CSSolution) 101 wireValues := []fr.Element(solution.W) 102 103 start := time.Now() 104 poks := make([]curve.G1Affine, len(pk.CommitmentKeys)) 105 106 for i := range pk.CommitmentKeys { 107 var err error 108 if poks[i], err = pk.CommitmentKeys[i].ProveKnowledge(privateCommittedValues[i]); err != nil { 109 return nil, err 110 } 111 } 112 // compute challenge for folding the PoKs from the commitments 113 commitmentsSerialized := make([]byte, fr.Bytes*len(commitmentInfo)) 114 for i := range commitmentInfo { 115 copy(commitmentsSerialized[fr.Bytes*i:], wireValues[commitmentInfo[i].CommitmentIndex].Marshal()) 116 } 117 challenge, err := fr.Hash(commitmentsSerialized, []byte("G16-BSB22"), 1) 118 if err != nil { 119 return nil, err 120 } 121 if _, err = proof.CommitmentPok.Fold(poks, challenge[0], ecc.MultiExpConfig{NbTasks: 1}); err != nil { 122 return nil, err 123 } 124 125 // H (witness reduction / FFT part) 126 var h []fr.Element 127 chHDone := make(chan struct{}, 1) 128 go func() { 129 h = computeH(solution.A, solution.B, solution.C, &pk.Domain) 130 solution.A = nil 131 solution.B = nil 132 solution.C = nil 133 chHDone <- struct{}{} 134 }() 135 136 // we need to copy and filter the wireValues for each multi exp 137 // as pk.G1.A, pk.G1.B and pk.G2.B may have (a significant) number of point at infinity 138 var wireValuesA, wireValuesB []fr.Element 139 chWireValuesA, chWireValuesB := make(chan struct{}, 1), make(chan struct{}, 1) 140 141 go func() { 142 wireValuesA = make([]fr.Element, len(wireValues)-int(pk.NbInfinityA)) 143 for i, j := 0, 0; j < len(wireValuesA); i++ { 144 if pk.InfinityA[i] { 145 continue 146 } 147 wireValuesA[j] = wireValues[i] 148 j++ 149 } 150 close(chWireValuesA) 151 }() 152 go func() { 153 wireValuesB = make([]fr.Element, len(wireValues)-int(pk.NbInfinityB)) 154 for i, j := 0, 0; j < len(wireValuesB); i++ { 155 if pk.InfinityB[i] { 156 continue 157 } 158 wireValuesB[j] = wireValues[i] 159 j++ 160 } 161 close(chWireValuesB) 162 }() 163 164 // sample random r and s 165 var r, s big.Int 166 var _r, _s, _kr fr.Element 167 if _, err := _r.SetRandom(); err != nil { 168 return nil, err 169 } 170 if _, err := _s.SetRandom(); err != nil { 171 return nil, err 172 } 173 _kr.Mul(&_r, &_s).Neg(&_kr) 174 175 _r.BigInt(&r) 176 _s.BigInt(&s) 177 178 // computes r[δ], s[δ], kr[δ] 179 deltas := curve.BatchScalarMultiplicationG1(&pk.G1.Delta, []fr.Element{_r, _s, _kr}) 180 181 var bs1, ar curve.G1Jac 182 183 n := runtime.NumCPU() 184 185 chBs1Done := make(chan error, 1) 186 computeBS1 := func() { 187 <-chWireValuesB 188 if _, err := bs1.MultiExp(pk.G1.B, wireValuesB, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { 189 chBs1Done <- err 190 close(chBs1Done) 191 return 192 } 193 bs1.AddMixed(&pk.G1.Beta) 194 bs1.AddMixed(&deltas[1]) 195 chBs1Done <- nil 196 } 197 198 chArDone := make(chan error, 1) 199 computeAR1 := func() { 200 <-chWireValuesA 201 if _, err := ar.MultiExp(pk.G1.A, wireValuesA, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { 202 chArDone <- err 203 close(chArDone) 204 return 205 } 206 ar.AddMixed(&pk.G1.Alpha) 207 ar.AddMixed(&deltas[0]) 208 proof.Ar.FromJacobian(&ar) 209 chArDone <- nil 210 } 211 212 chKrsDone := make(chan error, 1) 213 computeKRS := func() { 214 // we could NOT split the Krs multiExp in 2, and just append pk.G1.K and pk.G1.Z 215 // however, having similar lengths for our tasks helps with parallelism 216 217 var krs, krs2, p1 curve.G1Jac 218 chKrs2Done := make(chan error, 1) 219 sizeH := int(pk.Domain.Cardinality - 1) // comes from the fact the deg(H)=(n-1)+(n-1)-n=n-2 220 go func() { 221 _, err := krs2.MultiExp(pk.G1.Z, h[:sizeH], ecc.MultiExpConfig{NbTasks: n / 2}) 222 chKrs2Done <- err 223 }() 224 225 // filter the wire values if needed 226 // TODO Perf @Tabaie worst memory allocation offender 227 toRemove := commitmentInfo.GetPrivateCommitted() 228 toRemove = append(toRemove, commitmentInfo.CommitmentIndexes()) 229 _wireValues := filterHeap(wireValues[r1cs.GetNbPublicVariables():], r1cs.GetNbPublicVariables(), internal.ConcatAll(toRemove...)) 230 231 if _, err := krs.MultiExp(pk.G1.K, _wireValues, ecc.MultiExpConfig{NbTasks: n / 2}); err != nil { 232 chKrsDone <- err 233 return 234 } 235 krs.AddMixed(&deltas[2]) 236 n := 3 237 for n != 0 { 238 select { 239 case err := <-chKrs2Done: 240 if err != nil { 241 chKrsDone <- err 242 return 243 } 244 krs.AddAssign(&krs2) 245 case err := <-chArDone: 246 if err != nil { 247 chKrsDone <- err 248 return 249 } 250 p1.ScalarMultiplication(&ar, &s) 251 krs.AddAssign(&p1) 252 case err := <-chBs1Done: 253 if err != nil { 254 chKrsDone <- err 255 return 256 } 257 p1.ScalarMultiplication(&bs1, &r) 258 krs.AddAssign(&p1) 259 } 260 n-- 261 } 262 263 proof.Krs.FromJacobian(&krs) 264 chKrsDone <- nil 265 } 266 267 computeBS2 := func() error { 268 // Bs2 (1 multi exp G2 - size = len(wires)) 269 var Bs, deltaS curve.G2Jac 270 271 nbTasks := n 272 if nbTasks <= 16 { 273 // if we don't have a lot of CPUs, this may artificially split the MSM 274 nbTasks *= 2 275 } 276 <-chWireValuesB 277 if _, err := Bs.MultiExp(pk.G2.B, wireValuesB, ecc.MultiExpConfig{NbTasks: nbTasks}); err != nil { 278 return err 279 } 280 281 deltaS.FromAffine(&pk.G2.Delta) 282 deltaS.ScalarMultiplication(&deltaS, &s) 283 Bs.AddAssign(&deltaS) 284 Bs.AddMixed(&pk.G2.Beta) 285 286 proof.Bs.FromJacobian(&Bs) 287 return nil 288 } 289 290 // wait for FFT to end, as it uses all our CPUs 291 <-chHDone 292 293 // schedule our proof part computations 294 go computeKRS() 295 go computeAR1() 296 go computeBS1() 297 if err := computeBS2(); err != nil { 298 return nil, err 299 } 300 301 // wait for all parts of the proof to be computed. 302 if err := <-chKrsDone; err != nil { 303 return nil, err 304 } 305 306 log.Debug().Dur("took", time.Since(start)).Msg("prover done") 307 308 return proof, nil 309 } 310 311 // if len(toRemove) == 0, returns slice 312 // else, returns a new slice without the indexes in toRemove. The first value in the slice is taken as indexes as sliceFirstIndex 313 // this assumes len(slice) > len(toRemove) 314 // filterHeap modifies toRemove 315 func filterHeap(slice []fr.Element, sliceFirstIndex int, toRemove []int) (r []fr.Element) { 316 317 if len(toRemove) == 0 { 318 return slice 319 } 320 321 heap := utils.IntHeap(toRemove) 322 heap.Heapify() 323 324 r = make([]fr.Element, 0, len(slice)) 325 326 // note: we can optimize that for the likely case where len(slice) >>> len(toRemove) 327 for i:=0; i < len(slice);i++ { 328 if len(heap) > 0 && i+sliceFirstIndex == heap[0] { 329 for len(heap) > 0 && i+sliceFirstIndex == heap[0] { 330 heap.Pop() 331 } 332 continue 333 } 334 r = append(r, slice[i]) 335 } 336 337 return 338 } 339 340 func computeH(a, b, c []fr.Element, domain *fft.Domain) []fr.Element { 341 // H part of Krs 342 // Compute H (hz=ab-c, where z=-2 on ker X^n+1 (z(x)=x^n-1)) 343 // 1 - _a = ifft(a), _b = ifft(b), _c = ifft(c) 344 // 2 - ca = fft_coset(_a), ba = fft_coset(_b), cc = fft_coset(_c) 345 // 3 - h = ifft_coset(ca o cb - cc) 346 347 n := len(a) 348 349 // add padding to ensure input length is domain cardinality 350 padding := make([]fr.Element, int(domain.Cardinality)-n) 351 a = append(a, padding...) 352 b = append(b, padding...) 353 c = append(c, padding...) 354 n = len(a) 355 356 domain.FFTInverse(a, fft.DIF) 357 domain.FFTInverse(b, fft.DIF) 358 domain.FFTInverse(c, fft.DIF) 359 360 domain.FFT(a, fft.DIT, fft.OnCoset()) 361 domain.FFT(b, fft.DIT, fft.OnCoset()) 362 domain.FFT(c, fft.DIT, fft.OnCoset()) 363 364 var den, one fr.Element 365 one.SetOne() 366 den.Exp(domain.FrMultiplicativeGen, big.NewInt(int64(domain.Cardinality))) 367 den.Sub(&den, &one).Inverse(&den) 368 369 // h = ifft_coset(ca o cb - cc) 370 // reusing a to avoid unnecessary memory allocation 371 utils.Parallelize(n, func(start, end int) { 372 for i := start; i < end; i++ { 373 a[i].Mul(&a[i], &b[i]). 374 Sub(&a[i], &c[i]). 375 Mul(&a[i], &den) 376 } 377 }) 378 379 // ifft_coset 380 domain.FFTInverse(a, fft.DIF, fft.OnCoset()) 381 382 return a 383 }