github.com/consensys/gnark-crypto@v0.14.0/internal/generator/gkr/template/gkr.go.tmpl (about) 1 import ( 2 "fmt" 3 "{{.FieldPackagePath}}" 4 "{{.FieldPackagePath}}/polynomial" 5 "{{.FieldPackagePath}}/sumcheck" 6 fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" 7 "github.com/consensys/gnark-crypto/internal/parallel" 8 "github.com/consensys/gnark-crypto/utils" 9 "math/big" 10 "strconv" 11 "sync" 12 ) 13 14 {{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} 15 16 // The goal is to prove/verify evaluations of many instances of the same circuit 17 18 // Gate must be a low-degree polynomial 19 type Gate interface { 20 Evaluate(...{{.ElementType}}) {{.ElementType}} 21 Degree() int 22 } 23 24 type Wire struct { 25 Gate Gate 26 Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire 27 nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) 28 } 29 30 type Circuit []Wire 31 32 func (w Wire) IsInput() bool { 33 return len(w.Inputs) == 0 34 } 35 36 func (w Wire) IsOutput() bool { 37 return w.nbUniqueOutputs == 0 38 } 39 40 func (w Wire) NbClaims() int { 41 if w.IsOutput() { 42 return 1 43 } 44 return w.nbUniqueOutputs 45 } 46 47 func (w Wire) noProof() bool { 48 return w.IsInput() && w.NbClaims() == 1 49 } 50 51 func (c Circuit) maxGateDegree() int { 52 res := 1 53 for i := range c { 54 if !c[i].IsInput() { 55 res = utils.Max(res, c[i].Gate.Degree()) 56 } 57 } 58 return res 59 } 60 61 // WireAssignment is assignment of values to the same wire across many instances of the circuit 62 type WireAssignment map[*Wire]polynomial.MultiLin 63 64 type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) 65 66 type eqTimesGateEvalSumcheckLazyClaims struct { 67 wire *Wire 68 evaluationPoints [][]{{.ElementType}} 69 claimedEvaluations []{{.ElementType}} 70 manager *claimsManager // WARNING: Circular references 71 } 72 73 func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { 74 return len(e.evaluationPoints) 75 } 76 77 func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { 78 return len(e.evaluationPoints[0]) 79 } 80 81 func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a {{.ElementType}}) {{.ElementType}} { 82 evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) 83 return evalsAsPoly.Eval(&a) 84 } 85 86 func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { 87 return 1 + e.wire.Gate.Degree() 88 } 89 90 func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error { 91 inputEvaluationsNoRedundancy := proof.([]{{.ElementType}}) 92 93 // the eq terms 94 numClaims := len(e.evaluationPoints) 95 evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) 96 for i := numClaims - 2; i >= 0; i-- { 97 evaluation.Mul(&evaluation, &combinationCoeff) 98 eq := polynomial.EvalEq(e.evaluationPoints[i], r) 99 evaluation.Add(&evaluation, &eq) 100 } 101 102 // the g(...) term 103 var gateEvaluation {{.ElementType}} 104 if e.wire.IsInput() { 105 gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) 106 } else { 107 inputEvaluations := make([]{{.ElementType}}, len(e.wire.Inputs)) 108 indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) 109 110 proofI := 0 111 for inI, in := range e.wire.Inputs { 112 indexInProof, found := indexesInProof[in] 113 if !found { 114 indexInProof = proofI 115 indexesInProof[in] = indexInProof 116 117 // defer verification, store new claim 118 e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) 119 proofI++ 120 } 121 inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] 122 } 123 if proofI != len(inputEvaluationsNoRedundancy) { 124 return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) 125 } 126 gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) 127 } 128 129 evaluation.Mul(&evaluation, &gateEvaluation) 130 131 if evaluation.Equal(&purportedValue) { 132 return nil 133 } 134 return fmt.Errorf("incompatible evaluations") 135 } 136 137 type eqTimesGateEvalSumcheckClaims struct { 138 wire *Wire 139 evaluationPoints [][]{{.ElementType}} // x in the paper 140 claimedEvaluations []{{.ElementType}} // y in the paper 141 manager *claimsManager 142 143 inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations 144 145 eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) 146 } 147 148 func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff {{.ElementType}}) polynomial.Polynomial { 149 varsNum := c.VarsNum() 150 eqLength := 1 << varsNum 151 claimsNum := c.ClaimsNum() 152 // initialize the eq tables 153 c.eq = c.manager.memPool.Make(eqLength) 154 155 c.eq[0].SetOne() 156 c.eq.Eq(c.evaluationPoints[0]) 157 158 newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) 159 aI := combinationCoeff 160 161 for k := 1; k < claimsNum; k++ { //TODO: parallelizable? 162 // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points 163 newEq[0].Set(&aI) 164 165 c.eqAcc(c.eq, newEq,c.evaluationPoints[k]) 166 167 // newEq.Eq(c.evaluationPoints[k]) 168 // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics 169 // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) 170 171 if k+1 < claimsNum { 172 aI.Mul(&aI, &combinationCoeff) 173 } 174 } 175 176 c.manager.memPool.Dump(newEq) 177 178 // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree 179 180 return c.computeGJ() 181 } 182 183 // eqAcc sets m to an eq table at q and then adds it to e 184 func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []{{.ElementType}}) { 185 n := len(q) 186 187 //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) 188 for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ 189 // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ 190 const threshold = 1 << 6 191 k := 1 << i 192 if k < threshold { 193 for j := 0; j < k; j++ { 194 j0 := j << (n - i) // bᵢ₊₁ = 0 195 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 196 197 m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ 198 m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) 199 } 200 } else { 201 c.manager.workers.Submit(k, func(start, end int) { 202 for j := start; j < end; j++ { 203 j0 := j << (n - i) // bᵢ₊₁ = 0 204 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 205 206 m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ 207 m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) 208 } 209 }, 1024).Wait() 210 } 211 212 } 213 c.manager.workers.Submit(len(e), func(start, end int) { 214 for i := start; i < end; i++ { 215 e[i].Add(&e[i], &m[i]) 216 } 217 }, 512).Wait() 218 219 // e.Add(e, polynomial.Polynomial(m)) 220 } 221 222 223 // computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k 224 // the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). 225 // The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. 226 func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { 227 228 degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) 229 nbGateIn := len(c.inputPreprocessors) 230 231 // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables 232 s := make([]polynomial.MultiLin, nbGateIn+1) 233 s[0] = c.eq 234 copy(s[1:], c.inputPreprocessors) 235 236 // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called 237 nbInner := len(s) // wrt output, which has high nbOuter and low nbInner 238 nbOuter := len(s[0]) / 2 239 240 gJ := make([]{{.ElementType}}, degGJ) 241 var mu sync.Mutex 242 computeAll := func(start, end int) { 243 var step {{.ElementType}} 244 245 res := make([]{{.ElementType}}, degGJ) 246 operands := make([]{{.ElementType}}, degGJ*nbInner) 247 248 for i := start; i < end; i++ { 249 250 block := nbOuter + i 251 for j := 0; j < nbInner; j++ { 252 step.Set(&s[j][i]) 253 operands[j].Set(&s[j][block]) 254 step.Sub(&operands[j], &step) 255 for d := 1; d < degGJ; d++ { 256 operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) 257 } 258 } 259 260 _s := 0 261 _e := nbInner 262 for d := 0; d < degGJ; d++ { 263 summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) 264 summand.Mul(&summand, &operands[_s]) 265 res[d].Add(&res[d], &summand) 266 _s, _e = _e, _e+nbInner 267 } 268 } 269 mu.Lock() 270 for i := 0; i < len(gJ); i++ { 271 gJ[i].Add(&gJ[i], &res[i]) 272 } 273 mu.Unlock() 274 } 275 276 const minBlockSize = 64 277 278 if nbOuter < minBlockSize { 279 // no parallelization 280 computeAll(0, nbOuter) 281 } else { 282 c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() 283 } 284 285 // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though 286 287 return gJ 288 } 289 290 // Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j 291 func (c *eqTimesGateEvalSumcheckClaims) Next(element {{.ElementType}}) polynomial.Polynomial { 292 const minBlockSize = 512 293 n := len(c.eq) / 2 294 if n < minBlockSize { 295 // no parallelization 296 for i := 0; i < len(c.inputPreprocessors); i++ { 297 c.inputPreprocessors[i].Fold(element) 298 } 299 c.eq.Fold(element) 300 } else { 301 wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) 302 for i := 0; i < len(c.inputPreprocessors); i++ { 303 wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) 304 } 305 c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() 306 for _, wg := range wgs { 307 wg.Wait() 308 } 309 } 310 311 return c.computeGJ() 312 } 313 314 func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { 315 return len(c.evaluationPoints[0]) 316 } 317 318 func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { 319 return len(c.claimedEvaluations) 320 } 321 322 func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []{{.ElementType}}) interface{} { 323 324 //defer the proof, return list of claims 325 evaluations := make([]{{.ElementType}}, 0, len(c.wire.Inputs)) 326 noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) 327 noMoreClaimsAllowed[c.wire] = struct{}{} 328 329 for inI, in := range c.wire.Inputs { 330 puI := c.inputPreprocessors[inI] 331 if _, found := noMoreClaimsAllowed[in]; !found { 332 noMoreClaimsAllowed[in] = struct{}{} 333 puI.Fold(r[len(r)-1]) 334 c.manager.add(in, r, puI[0]) 335 evaluations = append(evaluations, puI[0]) 336 } 337 c.manager.memPool.Dump(puI) 338 } 339 340 c.manager.memPool.Dump(c.claimedEvaluations, c.eq) 341 342 return evaluations 343 } 344 345 type claimsManager struct { 346 claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims 347 assignment WireAssignment 348 memPool *polynomial.Pool 349 workers *utils.WorkerPool 350 } 351 352 func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { 353 claims.assignment = assignment 354 claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) 355 claims.memPool = o.pool 356 claims.workers = o.workers 357 358 for i := range c { 359 wire := &c[i] 360 361 claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ 362 wire: wire, 363 evaluationPoints: make([][]{{.ElementType}}, 0, wire.NbClaims()), 364 claimedEvaluations: claims.memPool.Make(wire.NbClaims()), 365 manager: &claims, 366 } 367 } 368 return 369 } 370 371 func (m *claimsManager) add(wire *Wire, evaluationPoint []{{.ElementType}}, evaluation {{.ElementType}}) { 372 claim := m.claimsMap[wire] 373 i := len(claim.evaluationPoints) 374 claim.claimedEvaluations[i] = evaluation 375 claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) 376 } 377 378 func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { 379 return m.claimsMap[wire] 380 } 381 382 func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { 383 lazy := m.claimsMap[wire] 384 res := &eqTimesGateEvalSumcheckClaims{ 385 wire: wire, 386 evaluationPoints: lazy.evaluationPoints, 387 claimedEvaluations: lazy.claimedEvaluations, 388 manager: m, 389 } 390 391 if wire.IsInput() { 392 res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} 393 } else { 394 res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) 395 396 for inputI, inputW := range wire.Inputs { 397 res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied 398 } 399 } 400 return res 401 } 402 403 func (m *claimsManager) deleteClaim(wire *Wire) { 404 delete(m.claimsMap, wire) 405 } 406 407 type settings struct { 408 pool *polynomial.Pool 409 sorted []*Wire 410 transcript *fiatshamir.Transcript 411 transcriptPrefix string 412 nbVars int 413 workers *utils.WorkerPool 414 } 415 416 type Option func(*settings) 417 418 func WithPool(pool *polynomial.Pool) Option { 419 return func (options *settings) { 420 options.pool = pool 421 } 422 } 423 424 func WithSortedCircuit(sorted []*Wire) Option { 425 return func(options *settings) { 426 options.sorted = sorted 427 } 428 } 429 430 func WithWorkers(workers *utils.WorkerPool) Option { 431 return func(options *settings) { 432 options.workers = workers 433 } 434 } 435 436 // MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement 437 func (c Circuit) MemoryRequirements(nbInstances int) []int { 438 res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} 439 440 if res[0] > res[1] { // make sure it's sorted 441 res[0], res[1] = res[1], res[0] 442 if res[1] > res[2] { 443 res[1], res[2] = res[2], res[1] 444 } 445 } 446 447 return res 448 } 449 450 func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { 451 var o settings 452 var err error 453 for _, option := range options { 454 option(&o) 455 } 456 457 o.nbVars = assignment.NumVars() 458 nbInstances := assignment.NumInstances() 459 if 1<<o.nbVars != nbInstances { 460 return o, fmt.Errorf("number of instances must be power of 2") 461 } 462 463 if o.pool == nil { 464 pool := polynomial.NewPool(c.MemoryRequirements(nbInstances)...) 465 o.pool = &pool 466 } 467 468 if o.workers == nil { 469 o.workers = utils.NewWorkerPool() 470 } 471 472 if o.sorted == nil { 473 o.sorted = {{$topologicalSort}}(c) 474 } 475 476 if transcriptSettings.Transcript == nil { 477 challengeNames := ChallengeNames(o.sorted, o.nbVars, transcriptSettings.Prefix) 478 o.transcript = fiatshamir.NewTranscript(transcriptSettings.Hash, challengeNames...) 479 for i := range transcriptSettings.BaseChallenges { 480 if err = o.transcript.Bind(challengeNames[0], transcriptSettings.BaseChallenges[i]); err != nil { 481 return o, err 482 } 483 } 484 } else { 485 o.transcript, o.transcriptPrefix = transcriptSettings.Transcript, transcriptSettings.Prefix 486 } 487 488 return o, err 489 } 490 491 // ProofSize computes how large the proof for a circuit would be. It needs nbUniqueOutputs to be set 492 func ProofSize(c Circuit, logNbInstances int) int { 493 nbUniqueInputs := 0 494 nbPartialEvalPolys := 0 495 for i := range c { 496 nbUniqueInputs += c[i].nbUniqueOutputs // each unique output is manifest in a finalEvalProof entry 497 if !c[i].noProof() { 498 nbPartialEvalPolys += c[i].Gate.Degree() + 1 499 } 500 } 501 return nbUniqueInputs + nbPartialEvalPolys*logNbInstances 502 } 503 504 func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { 505 506 // Pre-compute the size TODO: Consider not doing this and just grow the list by appending 507 size := logNbInstances // first challenge 508 509 for _, w := range sorted { 510 if w.noProof() { // no proof, no challenge 511 continue 512 } 513 if w.NbClaims() > 1 { //combine the claims 514 size++ 515 } 516 size += logNbInstances // full run of sumcheck on logNbInstances variables 517 } 518 519 nums := make([]string, utils.Max(len(sorted), logNbInstances)) 520 for i := range nums { 521 nums[i] = strconv.Itoa(i) 522 } 523 524 challenges := make([]string, size) 525 526 // output wire claims 527 firstChallengePrefix := prefix + "fC." 528 for j := 0; j < logNbInstances; j++ { 529 challenges[j] = firstChallengePrefix + nums[j] 530 } 531 j := logNbInstances 532 for i := len(sorted) - 1; i >= 0; i-- { 533 if sorted[i].noProof() { 534 continue 535 } 536 wirePrefix := prefix + "w" + nums[i] + "." 537 538 if sorted[i].NbClaims() > 1 { 539 challenges[j] = wirePrefix + "comb" 540 j++ 541 } 542 543 partialSumPrefix := wirePrefix + "pSP." 544 for k := 0; k < logNbInstances; k++ { 545 challenges[j] = partialSumPrefix + nums[k] 546 j++ 547 } 548 } 549 return challenges 550 } 551 552 func getFirstChallengeNames(logNbInstances int, prefix string) []string { 553 res := make([]string, logNbInstances) 554 firstChallengePrefix := prefix + "fC." 555 for i := 0; i < logNbInstances; i++ { 556 res[i] = firstChallengePrefix + strconv.Itoa(i) 557 } 558 return res 559 } 560 561 func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]{{.ElementType}}, error) { 562 res := make([]{{.ElementType}}, len(names)) 563 for i, name := range names { 564 if bytes, err := transcript.ComputeChallenge(name); err == nil { 565 res[i].SetBytes(bytes) 566 } else { 567 return nil, err 568 } 569 } 570 return res, nil 571 } 572 573 // Prove consistency of the claimed assignment 574 func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { 575 o, err := setup(c, assignment, transcriptSettings, options...) 576 if err != nil { 577 return nil, err 578 } 579 defer o.workers.Stop() 580 581 claims := newClaimsManager(c, assignment, o) 582 583 proof := make(Proof, len(c)) 584 // firstChallenge called rho in the paper 585 var firstChallenge []{{.ElementType}} 586 firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) 587 if err != nil { 588 return nil, err 589 } 590 591 wirePrefix := o.transcriptPrefix + "w" 592 var baseChallenge [][]byte 593 for i := len(c) - 1; i >= 0; i-- { 594 595 wire := o.sorted[i] 596 597 if wire.IsOutput() { 598 claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) 599 } 600 601 claim := claims.getClaim(wire) 602 if wire.noProof() { // input wires with one claim only 603 proof[i] = sumcheck.Proof{ 604 PartialSumPolys: []polynomial.Polynomial{}, 605 FinalEvalProof: []{{.ElementType}}{}, 606 } 607 } else { 608 if proof[i], err = sumcheck.Prove( 609 claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), 610 ); err != nil { 611 return proof, err 612 } 613 614 finalEvalProof := proof[i].FinalEvalProof.([]{{.ElementType}}) 615 baseChallenge = make([][]byte, len(finalEvalProof)) 616 for j := range finalEvalProof { 617 bytes := finalEvalProof[j].Bytes() 618 baseChallenge[j] = bytes[:] 619 } 620 } 621 // the verifier checks a single claim about input wires itself 622 claims.deleteClaim(wire) 623 } 624 625 return proof, nil 626 } 627 628 // Verify the consistency of the claimed output with the claimed input 629 // Unlike in Prove, the assignment argument need not be complete 630 func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { 631 o, err := setup(c, assignment, transcriptSettings, options...) 632 if err != nil { 633 return err 634 } 635 defer o.workers.Stop() 636 637 claims := newClaimsManager(c, assignment, o) 638 639 var firstChallenge []{{.ElementType}} 640 firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) 641 if err != nil { 642 return err 643 } 644 645 wirePrefix := o.transcriptPrefix + "w" 646 var baseChallenge [][]byte 647 for i := len(c) - 1; i >= 0; i-- { 648 wire := o.sorted[i] 649 650 if wire.IsOutput() { 651 claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) 652 } 653 654 proofW := proof[i] 655 finalEvalProof := proofW.FinalEvalProof.([]{{.ElementType}}) 656 claim := claims.getLazyClaim(wire) 657 if wire.noProof() { // input wires with one claim only 658 // make sure the proof is empty 659 if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { 660 return fmt.Errorf("no proof allowed for input wire with a single claim") 661 } 662 663 if wire.NbClaims() == 1 { // input wire 664 // simply evaluate and see if it matches 665 evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) 666 if !claim.claimedEvaluations[0].Equal(&evaluation) { 667 return fmt.Errorf("incorrect input wire claim") 668 } 669 } 670 } else if err = sumcheck.Verify( 671 claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), 672 ); err == nil { 673 baseChallenge = make([][]byte, len(finalEvalProof)) 674 for j := range finalEvalProof { 675 bytes := finalEvalProof[j].Bytes() 676 baseChallenge[j] = bytes[:] 677 } 678 } else { 679 return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? 680 } 681 claims.deleteClaim(wire) 682 } 683 return nil 684 } 685 686 // outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. 687 func outputsList(c Circuit, indexes map[*Wire]int) [][]int { 688 res := make([][]int, len(c)) 689 for i := range c { 690 res[i] = make([]int, 0) 691 c[i].nbUniqueOutputs = 0 692 if c[i].IsInput() { 693 c[i].Gate = IdentityGate{} 694 } 695 } 696 ins := make(map[int]struct{}, len(c)) 697 for i := range c { 698 for k := range ins { // clear map 699 delete(ins, k) 700 } 701 for _, in := range c[i].Inputs { 702 inI := indexes[in] 703 res[inI] = append(res[inI], i) 704 if _, ok := ins[inI]; !ok { 705 in.nbUniqueOutputs++ 706 ins[inI] = struct{}{} 707 } 708 } 709 } 710 return res 711 } 712 713 type topSortData struct { 714 outputs [][]int 715 status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done 716 index map[*Wire]int 717 leastReady int 718 } 719 720 func (d *topSortData) markDone(i int) { 721 722 d.status[i] = -1 723 724 for _, outI := range d.outputs[i] { 725 d.status[outI]-- 726 if d.status[outI] == 0 && outI < d.leastReady { 727 d.leastReady = outI 728 } 729 } 730 731 for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { 732 d.leastReady++ 733 } 734 } 735 736 func indexMap(c Circuit) map[*Wire]int { 737 res := make(map[*Wire]int, len(c)) 738 for i := range c { 739 res[&c[i]] = i 740 } 741 return res 742 } 743 744 func statusList(c Circuit) []int { 745 res := make([]int, len(c)) 746 for i := range c { 747 res[i] = len(c[i].Inputs) 748 } 749 return res 750 } 751 752 // {{$topologicalSort}} sorts the wires in order of dependence. Such that for any wire, any one it depends on 753 // occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. 754 // It also sets the nbOutput flags, and a dummy IdentityGate for input wires. 755 // Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. 756 // Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input 757 func {{$topologicalSort}}(c Circuit) []*Wire { 758 var data topSortData 759 data.index = indexMap(c) 760 data.outputs = outputsList(c, data.index) 761 data.status = statusList(c) 762 sorted := make([]*Wire, len(c)) 763 764 for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { 765 } 766 767 for i := range c { 768 sorted[i] = &c[data.leastReady] 769 data.markDone(data.leastReady) 770 } 771 772 return sorted 773 } 774 775 // Complete the circuit evaluation from input values 776 func (a WireAssignment) Complete(c Circuit) WireAssignment { 777 778 sortedWires := {{$topologicalSort}}(c) 779 nbInstances := a.NumInstances() 780 maxNbIns := 0 781 782 for _, w := range sortedWires { 783 maxNbIns = utils.Max(maxNbIns, len(w.Inputs)) 784 if a[w] == nil { 785 a[w] = make([]{{.ElementType}}, nbInstances) 786 } 787 } 788 789 parallel.Execute(nbInstances, func(start, end int) { 790 ins := make([]{{.ElementType}}, maxNbIns) 791 for i := start; i < end; i++ { 792 for _, w := range sortedWires { 793 if !w.IsInput() { 794 for inI, in := range w.Inputs { 795 ins[inI] = a[in][i] 796 } 797 a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) 798 } 799 } 800 } 801 }) 802 803 return a 804 } 805 806 func (a WireAssignment) NumInstances() int { 807 for _, aW := range a { 808 return len(aW) 809 } 810 panic("empty assignment") 811 } 812 813 func (a WireAssignment) NumVars() int { 814 for _, aW := range a { 815 return aW.NumVars() 816 } 817 panic("empty assignment") 818 } 819 820 // SerializeToBigInts flattens a proof object into the given slice of big.Ints 821 // useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this 822 func (p Proof) SerializeToBigInts(outs []*big.Int) { 823 offset := 0 824 for i := range p { 825 for _, poly := range p[i].PartialSumPolys { 826 frToBigInts(outs[offset:], poly) 827 offset += len(poly) 828 } 829 if p[i].FinalEvalProof != nil { 830 finalEvalProof := p[i].FinalEvalProof.([]{{.ElementType}}) 831 frToBigInts(outs[offset:], finalEvalProof) 832 offset += len(finalEvalProof) 833 } 834 } 835 } 836 837 func frToBigInts(dst []*big.Int, src []{{.ElementType}}) { 838 for i := range src { 839 src[i].BigInt(dst[i]) 840 } 841 } 842 843 // Gates defined by name 844 var Gates = map[string]Gate{ 845 "identity": IdentityGate{}, 846 "add": AddGate{}, 847 "sub": SubGate{}, 848 "neg": NegGate{}, 849 "mul": MulGate(2), 850 } 851 852 type IdentityGate struct{} 853 type AddGate struct{} 854 type MulGate int 855 type SubGate struct{} 856 type NegGate struct{} 857 858 func (IdentityGate) Evaluate(input ...{{.ElementType}}) {{.ElementType}} { 859 return input[0] 860 } 861 862 func (IdentityGate) Degree() int { 863 return 1 864 } 865 866 func (g AddGate) Evaluate(x ...{{.ElementType}}) (res {{.ElementType}}) { 867 switch len(x) { 868 case 0: 869 // set zero 870 case 1: 871 res.Set(&x[0]) 872 default: 873 res.Add(&x[0], &x[1]) 874 for i := 2; i < len(x); i++ { 875 res.Add(&res, &x[i]) 876 } 877 } 878 return 879 } 880 881 func (g AddGate) Degree() int { 882 return 1 883 } 884 885 func (g MulGate) Evaluate(x ...{{.ElementType}}) (res {{.ElementType}}) { 886 if len(x) != int(g) { 887 panic("wrong input count") 888 } 889 switch len(x) { 890 case 0: 891 res.SetOne() 892 case 1: 893 res.Set(&x[0]) 894 default: 895 res.Mul(&x[0], &x[1]) 896 for i := 2; i < len(x); i++ { 897 res.Mul(&res, &x[i]) 898 } 899 } 900 return 901 } 902 903 func (g MulGate) Degree() int { 904 return int(g) 905 } 906 907 func (g SubGate) Evaluate(element ...{{.ElementType}}) (diff {{.ElementType}}) { 908 if len(element) > 2 { 909 panic("not implemented") //TODO 910 } 911 diff.Sub(&element[0], &element[1]) 912 return 913 } 914 915 func (g SubGate) Degree() int { 916 return 1 917 } 918 919 func (g NegGate) Evaluate(element ...{{.ElementType}}) (neg {{.ElementType}}) { 920 if len(element) != 1 { 921 panic("univariate gate") 922 } 923 neg.Neg(&element[0]) 924 return 925 } 926 927 func (g NegGate) Degree() int { 928 return 1 929 }