github.com/consensys/gnark-crypto@v0.14.0/ecc/bls12-377/fr/gkr/gkr.go (about) 1 // Copyright 2020 Consensys Software Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Code generated by consensys/gnark-crypto DO NOT EDIT 16 17 package gkr 18 19 import ( 20 "fmt" 21 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" 22 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" 23 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/sumcheck" 24 fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" 25 "github.com/consensys/gnark-crypto/internal/parallel" 26 "github.com/consensys/gnark-crypto/utils" 27 "math/big" 28 "strconv" 29 "sync" 30 ) 31 32 // The goal is to prove/verify evaluations of many instances of the same circuit 33 34 // Gate must be a low-degree polynomial 35 type Gate interface { 36 Evaluate(...fr.Element) fr.Element 37 Degree() int 38 } 39 40 type Wire struct { 41 Gate Gate 42 Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire 43 nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) 44 } 45 46 type Circuit []Wire 47 48 func (w Wire) IsInput() bool { 49 return len(w.Inputs) == 0 50 } 51 52 func (w Wire) IsOutput() bool { 53 return w.nbUniqueOutputs == 0 54 } 55 56 func (w Wire) NbClaims() int { 57 if w.IsOutput() { 58 return 1 59 } 60 return w.nbUniqueOutputs 61 } 62 63 func (w Wire) noProof() bool { 64 return w.IsInput() && w.NbClaims() == 1 65 } 66 67 func (c Circuit) maxGateDegree() int { 68 res := 1 69 for i := range c { 70 if !c[i].IsInput() { 71 res = utils.Max(res, c[i].Gate.Degree()) 72 } 73 } 74 return res 75 } 76 77 // WireAssignment is assignment of values to the same wire across many instances of the circuit 78 type WireAssignment map[*Wire]polynomial.MultiLin 79 80 type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) 81 82 type eqTimesGateEvalSumcheckLazyClaims struct { 83 wire *Wire 84 evaluationPoints [][]fr.Element 85 claimedEvaluations []fr.Element 86 manager *claimsManager // WARNING: Circular references 87 } 88 89 func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { 90 return len(e.evaluationPoints) 91 } 92 93 func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { 94 return len(e.evaluationPoints[0]) 95 } 96 97 func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { 98 evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) 99 return evalsAsPoly.Eval(&a) 100 } 101 102 func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { 103 return 1 + e.wire.Gate.Degree() 104 } 105 106 func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { 107 inputEvaluationsNoRedundancy := proof.([]fr.Element) 108 109 // the eq terms 110 numClaims := len(e.evaluationPoints) 111 evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) 112 for i := numClaims - 2; i >= 0; i-- { 113 evaluation.Mul(&evaluation, &combinationCoeff) 114 eq := polynomial.EvalEq(e.evaluationPoints[i], r) 115 evaluation.Add(&evaluation, &eq) 116 } 117 118 // the g(...) term 119 var gateEvaluation fr.Element 120 if e.wire.IsInput() { 121 gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) 122 } else { 123 inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) 124 indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) 125 126 proofI := 0 127 for inI, in := range e.wire.Inputs { 128 indexInProof, found := indexesInProof[in] 129 if !found { 130 indexInProof = proofI 131 indexesInProof[in] = indexInProof 132 133 // defer verification, store new claim 134 e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) 135 proofI++ 136 } 137 inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] 138 } 139 if proofI != len(inputEvaluationsNoRedundancy) { 140 return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) 141 } 142 gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) 143 } 144 145 evaluation.Mul(&evaluation, &gateEvaluation) 146 147 if evaluation.Equal(&purportedValue) { 148 return nil 149 } 150 return fmt.Errorf("incompatible evaluations") 151 } 152 153 type eqTimesGateEvalSumcheckClaims struct { 154 wire *Wire 155 evaluationPoints [][]fr.Element // x in the paper 156 claimedEvaluations []fr.Element // y in the paper 157 manager *claimsManager 158 159 inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations 160 161 eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) 162 } 163 164 func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { 165 varsNum := c.VarsNum() 166 eqLength := 1 << varsNum 167 claimsNum := c.ClaimsNum() 168 // initialize the eq tables 169 c.eq = c.manager.memPool.Make(eqLength) 170 171 c.eq[0].SetOne() 172 c.eq.Eq(c.evaluationPoints[0]) 173 174 newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) 175 aI := combinationCoeff 176 177 for k := 1; k < claimsNum; k++ { //TODO: parallelizable? 178 // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points 179 newEq[0].Set(&aI) 180 181 c.eqAcc(c.eq, newEq, c.evaluationPoints[k]) 182 183 // newEq.Eq(c.evaluationPoints[k]) 184 // eqAsPoly := polynomial.Polynomial(c.eq) //just semantics 185 // eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) 186 187 if k+1 < claimsNum { 188 aI.Mul(&aI, &combinationCoeff) 189 } 190 } 191 192 c.manager.memPool.Dump(newEq) 193 194 // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree 195 196 return c.computeGJ() 197 } 198 199 // eqAcc sets m to an eq table at q and then adds it to e 200 func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) { 201 n := len(q) 202 203 //At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁) 204 for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁ 205 // go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ 206 const threshold = 1 << 6 207 k := 1 << i 208 if k < threshold { 209 for j := 0; j < k; j++ { 210 j0 := j << (n - i) // bᵢ₊₁ = 0 211 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 212 213 m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ 214 m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) 215 } 216 } else { 217 c.manager.workers.Submit(k, func(start, end int) { 218 for j := start; j < end; j++ { 219 j0 := j << (n - i) // bᵢ₊₁ = 0 220 j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1 221 222 m[j1].Mul(&q[i], &m[j0]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁ 223 m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁) 224 } 225 }, 1024).Wait() 226 } 227 228 } 229 c.manager.workers.Submit(len(e), func(start, end int) { 230 for i := start; i < end; i++ { 231 e[i].Add(&e[i], &m[i]) 232 } 233 }, 512).Wait() 234 235 // e.Add(e, polynomial.Polynomial(m)) 236 } 237 238 // computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k 239 // the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). 240 // The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum. 241 func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial { 242 243 degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) 244 nbGateIn := len(c.inputPreprocessors) 245 246 // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables 247 s := make([]polynomial.MultiLin, nbGateIn+1) 248 s[0] = c.eq 249 copy(s[1:], c.inputPreprocessors) 250 251 // Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called 252 nbInner := len(s) // wrt output, which has high nbOuter and low nbInner 253 nbOuter := len(s[0]) / 2 254 255 gJ := make([]fr.Element, degGJ) 256 var mu sync.Mutex 257 computeAll := func(start, end int) { 258 var step fr.Element 259 260 res := make([]fr.Element, degGJ) 261 operands := make([]fr.Element, degGJ*nbInner) 262 263 for i := start; i < end; i++ { 264 265 block := nbOuter + i 266 for j := 0; j < nbInner; j++ { 267 step.Set(&s[j][i]) 268 operands[j].Set(&s[j][block]) 269 step.Sub(&operands[j], &step) 270 for d := 1; d < degGJ; d++ { 271 operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step) 272 } 273 } 274 275 _s := 0 276 _e := nbInner 277 for d := 0; d < degGJ; d++ { 278 summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...) 279 summand.Mul(&summand, &operands[_s]) 280 res[d].Add(&res[d], &summand) 281 _s, _e = _e, _e+nbInner 282 } 283 } 284 mu.Lock() 285 for i := 0; i < len(gJ); i++ { 286 gJ[i].Add(&gJ[i], &res[i]) 287 } 288 mu.Unlock() 289 } 290 291 const minBlockSize = 64 292 293 if nbOuter < minBlockSize { 294 // no parallelization 295 computeAll(0, nbOuter) 296 } else { 297 c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait() 298 } 299 300 // Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though 301 302 return gJ 303 } 304 305 // Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j 306 func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { 307 const minBlockSize = 512 308 n := len(c.eq) / 2 309 if n < minBlockSize { 310 // no parallelization 311 for i := 0; i < len(c.inputPreprocessors); i++ { 312 c.inputPreprocessors[i].Fold(element) 313 } 314 c.eq.Fold(element) 315 } else { 316 wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors)) 317 for i := 0; i < len(c.inputPreprocessors); i++ { 318 wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize) 319 } 320 c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait() 321 for _, wg := range wgs { 322 wg.Wait() 323 } 324 } 325 326 return c.computeGJ() 327 } 328 329 func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { 330 return len(c.evaluationPoints[0]) 331 } 332 333 func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { 334 return len(c.claimedEvaluations) 335 } 336 337 func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { 338 339 //defer the proof, return list of claims 340 evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) 341 noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) 342 noMoreClaimsAllowed[c.wire] = struct{}{} 343 344 for inI, in := range c.wire.Inputs { 345 puI := c.inputPreprocessors[inI] 346 if _, found := noMoreClaimsAllowed[in]; !found { 347 noMoreClaimsAllowed[in] = struct{}{} 348 puI.Fold(r[len(r)-1]) 349 c.manager.add(in, r, puI[0]) 350 evaluations = append(evaluations, puI[0]) 351 } 352 c.manager.memPool.Dump(puI) 353 } 354 355 c.manager.memPool.Dump(c.claimedEvaluations, c.eq) 356 357 return evaluations 358 } 359 360 type claimsManager struct { 361 claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims 362 assignment WireAssignment 363 memPool *polynomial.Pool 364 workers *utils.WorkerPool 365 } 366 367 func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) { 368 claims.assignment = assignment 369 claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) 370 claims.memPool = o.pool 371 claims.workers = o.workers 372 373 for i := range c { 374 wire := &c[i] 375 376 claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ 377 wire: wire, 378 evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), 379 claimedEvaluations: claims.memPool.Make(wire.NbClaims()), 380 manager: &claims, 381 } 382 } 383 return 384 } 385 386 func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { 387 claim := m.claimsMap[wire] 388 i := len(claim.evaluationPoints) 389 claim.claimedEvaluations[i] = evaluation 390 claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) 391 } 392 393 func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { 394 return m.claimsMap[wire] 395 } 396 397 func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { 398 lazy := m.claimsMap[wire] 399 res := &eqTimesGateEvalSumcheckClaims{ 400 wire: wire, 401 evaluationPoints: lazy.evaluationPoints, 402 claimedEvaluations: lazy.claimedEvaluations, 403 manager: m, 404 } 405 406 if wire.IsInput() { 407 res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} 408 } else { 409 res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) 410 411 for inputI, inputW := range wire.Inputs { 412 res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied 413 } 414 } 415 return res 416 } 417 418 func (m *claimsManager) deleteClaim(wire *Wire) { 419 delete(m.claimsMap, wire) 420 } 421 422 type settings struct { 423 pool *polynomial.Pool 424 sorted []*Wire 425 transcript *fiatshamir.Transcript 426 transcriptPrefix string 427 nbVars int 428 workers *utils.WorkerPool 429 } 430 431 type Option func(*settings) 432 433 func WithPool(pool *polynomial.Pool) Option { 434 return func(options *settings) { 435 options.pool = pool 436 } 437 } 438 439 func WithSortedCircuit(sorted []*Wire) Option { 440 return func(options *settings) { 441 options.sorted = sorted 442 } 443 } 444 445 func WithWorkers(workers *utils.WorkerPool) Option { 446 return func(options *settings) { 447 options.workers = workers 448 } 449 } 450 451 // MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement 452 func (c Circuit) MemoryRequirements(nbInstances int) []int { 453 res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)} 454 455 if res[0] > res[1] { // make sure it's sorted 456 res[0], res[1] = res[1], res[0] 457 if res[1] > res[2] { 458 res[1], res[2] = res[2], res[1] 459 } 460 } 461 462 return res 463 } 464 465 func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { 466 var o settings 467 var err error 468 for _, option := range options { 469 option(&o) 470 } 471 472 o.nbVars = assignment.NumVars() 473 nbInstances := assignment.NumInstances() 474 if 1<<o.nbVars != nbInstances { 475 return o, fmt.Errorf("number of instances must be power of 2") 476 } 477 478 if o.pool == nil { 479 pool := polynomial.NewPool(c.MemoryRequirements(nbInstances)...) 480 o.pool = &pool 481 } 482 483 if o.workers == nil { 484 o.workers = utils.NewWorkerPool() 485 } 486 487 if o.sorted == nil { 488 o.sorted = topologicalSort(c) 489 } 490 491 if transcriptSettings.Transcript == nil { 492 challengeNames := ChallengeNames(o.sorted, o.nbVars, transcriptSettings.Prefix) 493 o.transcript = fiatshamir.NewTranscript(transcriptSettings.Hash, challengeNames...) 494 for i := range transcriptSettings.BaseChallenges { 495 if err = o.transcript.Bind(challengeNames[0], transcriptSettings.BaseChallenges[i]); err != nil { 496 return o, err 497 } 498 } 499 } else { 500 o.transcript, o.transcriptPrefix = transcriptSettings.Transcript, transcriptSettings.Prefix 501 } 502 503 return o, err 504 } 505 506 // ProofSize computes how large the proof for a circuit would be. It needs nbUniqueOutputs to be set 507 func ProofSize(c Circuit, logNbInstances int) int { 508 nbUniqueInputs := 0 509 nbPartialEvalPolys := 0 510 for i := range c { 511 nbUniqueInputs += c[i].nbUniqueOutputs // each unique output is manifest in a finalEvalProof entry 512 if !c[i].noProof() { 513 nbPartialEvalPolys += c[i].Gate.Degree() + 1 514 } 515 } 516 return nbUniqueInputs + nbPartialEvalPolys*logNbInstances 517 } 518 519 func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { 520 521 // Pre-compute the size TODO: Consider not doing this and just grow the list by appending 522 size := logNbInstances // first challenge 523 524 for _, w := range sorted { 525 if w.noProof() { // no proof, no challenge 526 continue 527 } 528 if w.NbClaims() > 1 { //combine the claims 529 size++ 530 } 531 size += logNbInstances // full run of sumcheck on logNbInstances variables 532 } 533 534 nums := make([]string, utils.Max(len(sorted), logNbInstances)) 535 for i := range nums { 536 nums[i] = strconv.Itoa(i) 537 } 538 539 challenges := make([]string, size) 540 541 // output wire claims 542 firstChallengePrefix := prefix + "fC." 543 for j := 0; j < logNbInstances; j++ { 544 challenges[j] = firstChallengePrefix + nums[j] 545 } 546 j := logNbInstances 547 for i := len(sorted) - 1; i >= 0; i-- { 548 if sorted[i].noProof() { 549 continue 550 } 551 wirePrefix := prefix + "w" + nums[i] + "." 552 553 if sorted[i].NbClaims() > 1 { 554 challenges[j] = wirePrefix + "comb" 555 j++ 556 } 557 558 partialSumPrefix := wirePrefix + "pSP." 559 for k := 0; k < logNbInstances; k++ { 560 challenges[j] = partialSumPrefix + nums[k] 561 j++ 562 } 563 } 564 return challenges 565 } 566 567 func getFirstChallengeNames(logNbInstances int, prefix string) []string { 568 res := make([]string, logNbInstances) 569 firstChallengePrefix := prefix + "fC." 570 for i := 0; i < logNbInstances; i++ { 571 res[i] = firstChallengePrefix + strconv.Itoa(i) 572 } 573 return res 574 } 575 576 func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { 577 res := make([]fr.Element, len(names)) 578 for i, name := range names { 579 if bytes, err := transcript.ComputeChallenge(name); err == nil { 580 res[i].SetBytes(bytes) 581 } else { 582 return nil, err 583 } 584 } 585 return res, nil 586 } 587 588 // Prove consistency of the claimed assignment 589 func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { 590 o, err := setup(c, assignment, transcriptSettings, options...) 591 if err != nil { 592 return nil, err 593 } 594 defer o.workers.Stop() 595 596 claims := newClaimsManager(c, assignment, o) 597 598 proof := make(Proof, len(c)) 599 // firstChallenge called rho in the paper 600 var firstChallenge []fr.Element 601 firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) 602 if err != nil { 603 return nil, err 604 } 605 606 wirePrefix := o.transcriptPrefix + "w" 607 var baseChallenge [][]byte 608 for i := len(c) - 1; i >= 0; i-- { 609 610 wire := o.sorted[i] 611 612 if wire.IsOutput() { 613 claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) 614 } 615 616 claim := claims.getClaim(wire) 617 if wire.noProof() { // input wires with one claim only 618 proof[i] = sumcheck.Proof{ 619 PartialSumPolys: []polynomial.Polynomial{}, 620 FinalEvalProof: []fr.Element{}, 621 } 622 } else { 623 if proof[i], err = sumcheck.Prove( 624 claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), 625 ); err != nil { 626 return proof, err 627 } 628 629 finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) 630 baseChallenge = make([][]byte, len(finalEvalProof)) 631 for j := range finalEvalProof { 632 bytes := finalEvalProof[j].Bytes() 633 baseChallenge[j] = bytes[:] 634 } 635 } 636 // the verifier checks a single claim about input wires itself 637 claims.deleteClaim(wire) 638 } 639 640 return proof, nil 641 } 642 643 // Verify the consistency of the claimed output with the claimed input 644 // Unlike in Prove, the assignment argument need not be complete 645 func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { 646 o, err := setup(c, assignment, transcriptSettings, options...) 647 if err != nil { 648 return err 649 } 650 defer o.workers.Stop() 651 652 claims := newClaimsManager(c, assignment, o) 653 654 var firstChallenge []fr.Element 655 firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) 656 if err != nil { 657 return err 658 } 659 660 wirePrefix := o.transcriptPrefix + "w" 661 var baseChallenge [][]byte 662 for i := len(c) - 1; i >= 0; i-- { 663 wire := o.sorted[i] 664 665 if wire.IsOutput() { 666 claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) 667 } 668 669 proofW := proof[i] 670 finalEvalProof := proofW.FinalEvalProof.([]fr.Element) 671 claim := claims.getLazyClaim(wire) 672 if wire.noProof() { // input wires with one claim only 673 // make sure the proof is empty 674 if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { 675 return fmt.Errorf("no proof allowed for input wire with a single claim") 676 } 677 678 if wire.NbClaims() == 1 { // input wire 679 // simply evaluate and see if it matches 680 evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) 681 if !claim.claimedEvaluations[0].Equal(&evaluation) { 682 return fmt.Errorf("incorrect input wire claim") 683 } 684 } 685 } else if err = sumcheck.Verify( 686 claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), 687 ); err == nil { 688 baseChallenge = make([][]byte, len(finalEvalProof)) 689 for j := range finalEvalProof { 690 bytes := finalEvalProof[j].Bytes() 691 baseChallenge[j] = bytes[:] 692 } 693 } else { 694 return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? 695 } 696 claims.deleteClaim(wire) 697 } 698 return nil 699 } 700 701 // outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. 702 func outputsList(c Circuit, indexes map[*Wire]int) [][]int { 703 res := make([][]int, len(c)) 704 for i := range c { 705 res[i] = make([]int, 0) 706 c[i].nbUniqueOutputs = 0 707 if c[i].IsInput() { 708 c[i].Gate = IdentityGate{} 709 } 710 } 711 ins := make(map[int]struct{}, len(c)) 712 for i := range c { 713 for k := range ins { // clear map 714 delete(ins, k) 715 } 716 for _, in := range c[i].Inputs { 717 inI := indexes[in] 718 res[inI] = append(res[inI], i) 719 if _, ok := ins[inI]; !ok { 720 in.nbUniqueOutputs++ 721 ins[inI] = struct{}{} 722 } 723 } 724 } 725 return res 726 } 727 728 type topSortData struct { 729 outputs [][]int 730 status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done 731 index map[*Wire]int 732 leastReady int 733 } 734 735 func (d *topSortData) markDone(i int) { 736 737 d.status[i] = -1 738 739 for _, outI := range d.outputs[i] { 740 d.status[outI]-- 741 if d.status[outI] == 0 && outI < d.leastReady { 742 d.leastReady = outI 743 } 744 } 745 746 for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { 747 d.leastReady++ 748 } 749 } 750 751 func indexMap(c Circuit) map[*Wire]int { 752 res := make(map[*Wire]int, len(c)) 753 for i := range c { 754 res[&c[i]] = i 755 } 756 return res 757 } 758 759 func statusList(c Circuit) []int { 760 res := make([]int, len(c)) 761 for i := range c { 762 res[i] = len(c[i].Inputs) 763 } 764 return res 765 } 766 767 // topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on 768 // occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. 769 // It also sets the nbOutput flags, and a dummy IdentityGate for input wires. 770 // Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. 771 // Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input 772 func topologicalSort(c Circuit) []*Wire { 773 var data topSortData 774 data.index = indexMap(c) 775 data.outputs = outputsList(c, data.index) 776 data.status = statusList(c) 777 sorted := make([]*Wire, len(c)) 778 779 for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { 780 } 781 782 for i := range c { 783 sorted[i] = &c[data.leastReady] 784 data.markDone(data.leastReady) 785 } 786 787 return sorted 788 } 789 790 // Complete the circuit evaluation from input values 791 func (a WireAssignment) Complete(c Circuit) WireAssignment { 792 793 sortedWires := topologicalSort(c) 794 nbInstances := a.NumInstances() 795 maxNbIns := 0 796 797 for _, w := range sortedWires { 798 maxNbIns = utils.Max(maxNbIns, len(w.Inputs)) 799 if a[w] == nil { 800 a[w] = make([]fr.Element, nbInstances) 801 } 802 } 803 804 parallel.Execute(nbInstances, func(start, end int) { 805 ins := make([]fr.Element, maxNbIns) 806 for i := start; i < end; i++ { 807 for _, w := range sortedWires { 808 if !w.IsInput() { 809 for inI, in := range w.Inputs { 810 ins[inI] = a[in][i] 811 } 812 a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...) 813 } 814 } 815 } 816 }) 817 818 return a 819 } 820 821 func (a WireAssignment) NumInstances() int { 822 for _, aW := range a { 823 return len(aW) 824 } 825 panic("empty assignment") 826 } 827 828 func (a WireAssignment) NumVars() int { 829 for _, aW := range a { 830 return aW.NumVars() 831 } 832 panic("empty assignment") 833 } 834 835 // SerializeToBigInts flattens a proof object into the given slice of big.Ints 836 // useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this 837 func (p Proof) SerializeToBigInts(outs []*big.Int) { 838 offset := 0 839 for i := range p { 840 for _, poly := range p[i].PartialSumPolys { 841 frToBigInts(outs[offset:], poly) 842 offset += len(poly) 843 } 844 if p[i].FinalEvalProof != nil { 845 finalEvalProof := p[i].FinalEvalProof.([]fr.Element) 846 frToBigInts(outs[offset:], finalEvalProof) 847 offset += len(finalEvalProof) 848 } 849 } 850 } 851 852 func frToBigInts(dst []*big.Int, src []fr.Element) { 853 for i := range src { 854 src[i].BigInt(dst[i]) 855 } 856 } 857 858 // Gates defined by name 859 var Gates = map[string]Gate{ 860 "identity": IdentityGate{}, 861 "add": AddGate{}, 862 "sub": SubGate{}, 863 "neg": NegGate{}, 864 "mul": MulGate(2), 865 } 866 867 type IdentityGate struct{} 868 type AddGate struct{} 869 type MulGate int 870 type SubGate struct{} 871 type NegGate struct{} 872 873 func (IdentityGate) Evaluate(input ...fr.Element) fr.Element { 874 return input[0] 875 } 876 877 func (IdentityGate) Degree() int { 878 return 1 879 } 880 881 func (g AddGate) Evaluate(x ...fr.Element) (res fr.Element) { 882 switch len(x) { 883 case 0: 884 // set zero 885 case 1: 886 res.Set(&x[0]) 887 default: 888 res.Add(&x[0], &x[1]) 889 for i := 2; i < len(x); i++ { 890 res.Add(&res, &x[i]) 891 } 892 } 893 return 894 } 895 896 func (g AddGate) Degree() int { 897 return 1 898 } 899 900 func (g MulGate) Evaluate(x ...fr.Element) (res fr.Element) { 901 if len(x) != int(g) { 902 panic("wrong input count") 903 } 904 switch len(x) { 905 case 0: 906 res.SetOne() 907 case 1: 908 res.Set(&x[0]) 909 default: 910 res.Mul(&x[0], &x[1]) 911 for i := 2; i < len(x); i++ { 912 res.Mul(&res, &x[i]) 913 } 914 } 915 return 916 } 917 918 func (g MulGate) Degree() int { 919 return int(g) 920 } 921 922 func (g SubGate) Evaluate(element ...fr.Element) (diff fr.Element) { 923 if len(element) > 2 { 924 panic("not implemented") //TODO 925 } 926 diff.Sub(&element[0], &element[1]) 927 return 928 } 929 930 func (g SubGate) Degree() int { 931 return 1 932 } 933 934 func (g NegGate) Evaluate(element ...fr.Element) (neg fr.Element) { 935 if len(element) != 1 { 936 panic("univariate gate") 937 } 938 neg.Neg(&element[0]) 939 return 940 } 941 942 func (g NegGate) Degree() int { 943 return 1 944 }