github.com/consensys/gnark-crypto@v0.14.0/ecc/bn254/fr/gkr/gkr.go (about)

     1  // Copyright 2020 Consensys Software Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Code generated by consensys/gnark-crypto DO NOT EDIT
    16  
    17  package gkr
    18  
    19  import (
    20  	"fmt"
    21  	"github.com/consensys/gnark-crypto/ecc/bn254/fr"
    22  	"github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial"
    23  	"github.com/consensys/gnark-crypto/ecc/bn254/fr/sumcheck"
    24  	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
    25  	"github.com/consensys/gnark-crypto/internal/parallel"
    26  	"github.com/consensys/gnark-crypto/utils"
    27  	"math/big"
    28  	"strconv"
    29  	"sync"
    30  )
    31  
    32  // The goal is to prove/verify evaluations of many instances of the same circuit
    33  
// Gate must be a low-degree polynomial.
//
// A Gate is the function a non-input wire computes from the values of its input wires.
// The prover's sumcheck round polynomials for a wire have degree at most 1 + Degree(),
// so Degree must bound the true total degree of Evaluate.
type Gate interface {
	// Evaluate computes the gate output from the values of the wire's inputs, one per entry of
	// Wire.Inputs, in the same order.
	Evaluate(...fr.Element) fr.Element
	// Degree returns an upper bound on the total degree of Evaluate seen as a polynomial.
	Degree() int
}
    39  
// Wire is a node of a GKR circuit. In every circuit instance its value is Gate.Evaluate applied
// to the values of its Inputs. nbUniqueOutputs is computed by outputsList (run by
// topologicalSort), which also gives input wires an IdentityGate.
type Wire struct {
	Gate            Gate
	Inputs          []*Wire // if there are no Inputs, the wire is assumed an input wire
	nbUniqueOutputs int     // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one)
}

// Circuit is a list of wires. Wires are identified by address (see indexMap), so Inputs are
// expected to point at elements of the same Circuit slice. The list need not be in topological
// order; topologicalSort computes one.
type Circuit []Wire
    47  
    48  func (w Wire) IsInput() bool {
    49  	return len(w.Inputs) == 0
    50  }
    51  
    52  func (w Wire) IsOutput() bool {
    53  	return w.nbUniqueOutputs == 0
    54  }
    55  
    56  func (w Wire) NbClaims() int {
    57  	if w.IsOutput() {
    58  		return 1
    59  	}
    60  	return w.nbUniqueOutputs
    61  }
    62  
    63  func (w Wire) noProof() bool {
    64  	return w.IsInput() && w.NbClaims() == 1
    65  }
    66  
    67  func (c Circuit) maxGateDegree() int {
    68  	res := 1
    69  	for i := range c {
    70  		if !c[i].IsInput() {
    71  			res = utils.Max(res, c[i].Gate.Degree())
    72  		}
    73  	}
    74  	return res
    75  }
    76  
// WireAssignment is assignment of values to the same wire across many instances of the circuit.
// Each wire maps to a multilinear table with one entry per instance.
type WireAssignment map[*Wire]polynomial.MultiLin

// Proof holds one sumcheck proof per wire; entry i belongs to the i-th wire of the topological
// order used by Prove/Verify. Wires that need no proof (see Wire.noProof) have an empty entry.
type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial)
    81  
// eqTimesGateEvalSumcheckLazyClaims is the verifier-side record of the claims about one wire:
// claimed values of the wire's multilinear extension at several points. Claims are appended
// through claimsManager.add and checked through a sumcheck ending in VerifyFinalEval.
type eqTimesGateEvalSumcheckLazyClaims struct {
	wire               *Wire
	evaluationPoints   [][]fr.Element // one point per claim
	claimedEvaluations []fr.Element   // claimedEvaluations[i] is the claimed value at evaluationPoints[i]
	manager            *claimsManager // WARNING: Circular references
}
    88  
    89  func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int {
    90  	return len(e.evaluationPoints)
    91  }
    92  
    93  func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int {
    94  	return len(e.evaluationPoints[0])
    95  }
    96  
    97  func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element {
    98  	evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations)
    99  	return evalsAsPoly.Eval(&a)
   100  }
   101  
   102  func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int {
   103  	return 1 + e.wire.Gate.Degree()
   104  }
   105  
   106  func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error {
   107  	inputEvaluationsNoRedundancy := proof.([]fr.Element)
   108  
   109  	// the eq terms
   110  	numClaims := len(e.evaluationPoints)
   111  	evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r)
   112  	for i := numClaims - 2; i >= 0; i-- {
   113  		evaluation.Mul(&evaluation, &combinationCoeff)
   114  		eq := polynomial.EvalEq(e.evaluationPoints[i], r)
   115  		evaluation.Add(&evaluation, &eq)
   116  	}
   117  
   118  	// the g(...) term
   119  	var gateEvaluation fr.Element
   120  	if e.wire.IsInput() {
   121  		gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool)
   122  	} else {
   123  		inputEvaluations := make([]fr.Element, len(e.wire.Inputs))
   124  		indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy))
   125  
   126  		proofI := 0
   127  		for inI, in := range e.wire.Inputs {
   128  			indexInProof, found := indexesInProof[in]
   129  			if !found {
   130  				indexInProof = proofI
   131  				indexesInProof[in] = indexInProof
   132  
   133  				// defer verification, store new claim
   134  				e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof])
   135  				proofI++
   136  			}
   137  			inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof]
   138  		}
   139  		if proofI != len(inputEvaluationsNoRedundancy) {
   140  			return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI)
   141  		}
   142  		gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...)
   143  	}
   144  
   145  	evaluation.Mul(&evaluation, &gateEvaluation)
   146  
   147  	if evaluation.Equal(&purportedValue) {
   148  		return nil
   149  	}
   150  	return fmt.Errorf("incompatible evaluations")
   151  }
   152  
// eqTimesGateEvalSumcheckClaims is the prover-side view of the claims about one wire, together
// with the bookkeeping tables that the sumcheck folds in place. It is built by
// claimsManager.getClaim.
type eqTimesGateEvalSumcheckClaims struct {
	wire               *Wire
	evaluationPoints   [][]fr.Element // x in the paper
	claimedEvaluations []fr.Element   // y in the paper
	manager            *claimsManager

	// inputPreprocessors holds one pool copy per input slot (for an input wire, a single copy of its own values).
	inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations

	eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -)
}
   163  
// Combine prepares the prover's sumcheck on the random linear combination of this wire's claims,
// and returns the first round polynomial.
//
// With a = combinationCoeff and xₖ the claimed evaluation points, it fills c.eq with the table of
// E(h) = ∑ₖ aᵏ eq(xₖ, h) over h ∈ {0,1}ⁿ, where n = VarsNum(). c.eq is taken from the memory pool
// and is given back to it in ProveFinalEval.
func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial {
	varsNum := c.VarsNum()
	eqLength := 1 << varsNum
	claimsNum := c.ClaimsNum()
	// initialize the eq tables
	c.eq = c.manager.memPool.Make(eqLength)

	// k = 0 term: entry 0 seeds the in-place expansion into eq(x₀, ·)
	c.eq[0].SetOne()
	c.eq.Eq(c.evaluationPoints[0])

	// scratch table for the k ≥ 1 terms, given back to the pool below
	newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength))
	aI := combinationCoeff

	for k := 1; k < claimsNum; k++ { //TODO: parallelizable?
		// define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points
		newEq[0].Set(&aI)

		c.eqAcc(c.eq, newEq, c.evaluationPoints[k])

		// newEq.Eq(c.evaluationPoints[k])
		// eqAsPoly := polynomial.Polynomial(c.eq) //just semantics
		// eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq))

		// advance aᵏ to aᵏ⁺¹; skipped after the last claim, where it would be unused
		if k+1 < claimsNum {
			aI.Mul(&aI, &combinationCoeff)
		}
	}

	c.manager.memPool.Dump(newEq)

	// from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree

	return c.computeGJ()
}
   198  
// eqAcc sets m to an eq table at q and then adds it to e.
//
// On entry only m[0] matters: it is a scaling factor s (Combine stores aᵏ there). The table is
// expanded in place one variable at a time, so that afterwards m[h] = s·eq(q, h) for every
// h ∈ {0,1}ⁿ. q[0] corresponds to the most significant bit of the index h. Finally e += m,
// entry by entry. Expansion steps with fewer than 64 pairs run sequentially. Larger steps, and
// the final addition, are split across the worker pool and awaited before continuing.
func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []fr.Element) {
	n := len(q)

	//At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁)
	for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁
		// go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ
		// each j touches its own pair (j0, j1), so the pairs can be processed independently
		const threshold = 1 << 6
		k := 1 << i
		if k < threshold {
			for j := 0; j < k; j++ {
				j0 := j << (n - i)    // bᵢ₊₁ = 0
				j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1

				m[j1].Mul(&q[i], &m[j0])  // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁
				m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁)
			}
		} else {
			c.manager.workers.Submit(k, func(start, end int) {
				for j := start; j < end; j++ {
					j0 := j << (n - i)    // bᵢ₊₁ = 0
					j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1

					m[j1].Mul(&q[i], &m[j0])  // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁
					m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁)
				}
			}, 1024).Wait()
		}

	}
	// accumulate: e += m
	c.manager.workers.Submit(len(e), func(start, end int) {
		for i := start; i < end; i++ {
			e[i].Add(&e[i], &m[i])
		}
	}, 512).Wait()

	// e.Add(e, polynomial.Polynomial(m))
}
   237  
// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where  E = ∑ eq_k
// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)).
// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum.
//
// The current bookkeeping tables are c.eq (E) and c.inputPreprocessors (the P_u). In each table,
// entry i of the first half has Xⱼ = 0 and the matching entry of the second half has Xⱼ = 1.
// With at least minBlockSize terms, the sum over i is split across the worker pool. Each worker
// accumulates its own partial sums and merges them into gJ under mu.
func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial {

	degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j)
	nbGateIn := len(c.inputPreprocessors)

	// Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables
	s := make([]polynomial.MultiLin, nbGateIn+1)
	s[0] = c.eq
	copy(s[1:], c.inputPreprocessors)

	// Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called
	nbInner := len(s) // wrt output, which has high nbOuter and low nbInner
	nbOuter := len(s[0]) / 2

	gJ := make([]fr.Element, degGJ)
	var mu sync.Mutex
	computeAll := func(start, end int) {
		var step fr.Element

		res := make([]fr.Element, degGJ)
		// operands[d*nbInner+j] holds table s[j] evaluated at Xⱼ = d+1
		operands := make([]fr.Element, degGJ*nbInner)

		for i := start; i < end; i++ {

			block := nbOuter + i // entry i has Xⱼ = 0, entry block has Xⱼ = 1
			for j := 0; j < nbInner; j++ {
				step.Set(&s[j][i])
				operands[j].Set(&s[j][block])
				step.Sub(&operands[j], &step)
				for d := 1; d < degGJ; d++ {
					operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step)
				}
			}

			// at each evaluation point, operands[_s] is E and operands[_s+1:_e] are the gate inputs
			_s := 0
			_e := nbInner
			for d := 0; d < degGJ; d++ {
				summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...)
				summand.Mul(&summand, &operands[_s])
				res[d].Add(&res[d], &summand)
				_s, _e = _e, _e+nbInner
			}
		}
		mu.Lock()
		for i := 0; i < len(gJ); i++ {
			gJ[i].Add(&gJ[i], &res[i])
		}
		mu.Unlock()
	}

	const minBlockSize = 64

	if nbOuter < minBlockSize {
		// no parallelization
		computeAll(0, nbOuter)
	} else {
		c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait()
	}

	// Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though

	return gJ
}
   304  
// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j
//
// Folding fixes the current sumcheck variable to element, presumably the verifier's challenge
// for the previous round polynomial. When the half-table size n is below minBlockSize,
// everything runs sequentially. Otherwise each table's fold is submitted to the worker pool,
// and all folds are awaited before g_j is computed from the folded tables.
func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial {
	const minBlockSize = 512
	n := len(c.eq) / 2
	if n < minBlockSize {
		// no parallelization
		for i := 0; i < len(c.inputPreprocessors); i++ {
			c.inputPreprocessors[i].Fold(element)
		}
		c.eq.Fold(element)
	} else {
		wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors))
		for i := 0; i < len(c.inputPreprocessors); i++ {
			wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize)
		}
		// the eq fold runs concurrently with the input folds; wait for all of them
		c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait()
		for _, wg := range wgs {
			wg.Wait()
		}
	}

	return c.computeGJ()
}
   328  
   329  func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int {
   330  	return len(c.evaluationPoints[0])
   331  }
   332  
   333  func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int {
   334  	return len(c.claimedEvaluations)
   335  }
   336  
   337  func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} {
   338  
   339  	//defer the proof, return list of claims
   340  	evaluations := make([]fr.Element, 0, len(c.wire.Inputs))
   341  	noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors))
   342  	noMoreClaimsAllowed[c.wire] = struct{}{}
   343  
   344  	for inI, in := range c.wire.Inputs {
   345  		puI := c.inputPreprocessors[inI]
   346  		if _, found := noMoreClaimsAllowed[in]; !found {
   347  			noMoreClaimsAllowed[in] = struct{}{}
   348  			puI.Fold(r[len(r)-1])
   349  			c.manager.add(in, r, puI[0])
   350  			evaluations = append(evaluations, puI[0])
   351  		}
   352  		c.manager.memPool.Dump(puI)
   353  	}
   354  
   355  	c.manager.memPool.Dump(c.claimedEvaluations, c.eq)
   356  
   357  	return evaluations
   358  }
   359  
// claimsManager tracks, for every wire, the claims made about it so far. It also holds the
// resources used to prove or verify them: the assignment, the memory pool and the worker pool.
type claimsManager struct {
	claimsMap  map[*Wire]*eqTimesGateEvalSumcheckLazyClaims
	assignment WireAssignment
	memPool    *polynomial.Pool
	workers    *utils.WorkerPool
}
   366  
// newClaimsManager creates an empty claim record for every wire of c. Each record has room for
// wire.NbClaims() claimed values, taken from the memory pool o.pool, and the manager uses the
// worker pool o.workers.
//
// NOTE(review): the lazy claims point to the local result variable `claims`, while the caller
// receives a copy of it. This only works because every field of claimsManager is a map or a
// pointer, so both copies share state. Adding a plain value field would silently break this.
func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) {
	claims.assignment = assignment
	claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c))
	claims.memPool = o.pool
	claims.workers = o.workers

	for i := range c {
		wire := &c[i]

		claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{
			wire:               wire,
			evaluationPoints:   make([][]fr.Element, 0, wire.NbClaims()),
			claimedEvaluations: claims.memPool.Make(wire.NbClaims()),
			manager:            &claims,
		}
	}
	return
}
   385  
   386  func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) {
   387  	claim := m.claimsMap[wire]
   388  	i := len(claim.evaluationPoints)
   389  	claim.claimedEvaluations[i] = evaluation
   390  	claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint)
   391  }
   392  
   393  func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims {
   394  	return m.claimsMap[wire]
   395  }
   396  
   397  func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims {
   398  	lazy := m.claimsMap[wire]
   399  	res := &eqTimesGateEvalSumcheckClaims{
   400  		wire:               wire,
   401  		evaluationPoints:   lazy.evaluationPoints,
   402  		claimedEvaluations: lazy.claimedEvaluations,
   403  		manager:            m,
   404  	}
   405  
   406  	if wire.IsInput() {
   407  		res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])}
   408  	} else {
   409  		res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs))
   410  
   411  		for inputI, inputW := range wire.Inputs {
   412  			res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied
   413  		}
   414  	}
   415  	return res
   416  }
   417  
   418  func (m *claimsManager) deleteClaim(wire *Wire) {
   419  	delete(m.claimsMap, wire)
   420  }
   421  
// settings is the resolved configuration of a Prove or Verify run (see setup).
type settings struct {
	// memory pool for temporary tables
	pool *polynomial.Pool
	// wires in topological order; proof entries are indexed accordingly
	sorted []*Wire
	// Fiat-Shamir transcript providing all challenges
	transcript *fiatshamir.Transcript
	// prefix of every challenge name
	transcriptPrefix string
	// log₂ of the number of instances
	nbVars int
	// pool running parallel work; stopped when Prove/Verify return
	workers *utils.WorkerPool
}

// Option customizes a Prove or Verify run.
type Option func(*settings)
   432  
   433  func WithPool(pool *polynomial.Pool) Option {
   434  	return func(options *settings) {
   435  		options.pool = pool
   436  	}
   437  }
   438  
   439  func WithSortedCircuit(sorted []*Wire) Option {
   440  	return func(options *settings) {
   441  		options.sorted = sorted
   442  	}
   443  }
   444  
   445  func WithWorkers(workers *utils.WorkerPool) Option {
   446  	return func(options *settings) {
   447  		options.workers = workers
   448  	}
   449  }
   450  
   451  // MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement
   452  func (c Circuit) MemoryRequirements(nbInstances int) []int {
   453  	res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)}
   454  
   455  	if res[0] > res[1] { // make sure it's sorted
   456  		res[0], res[1] = res[1], res[0]
   457  		if res[1] > res[2] {
   458  			res[1], res[2] = res[2], res[1]
   459  		}
   460  	}
   461  
   462  	return res
   463  }
   464  
   465  func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) {
   466  	var o settings
   467  	var err error
   468  	for _, option := range options {
   469  		option(&o)
   470  	}
   471  
   472  	o.nbVars = assignment.NumVars()
   473  	nbInstances := assignment.NumInstances()
   474  	if 1<<o.nbVars != nbInstances {
   475  		return o, fmt.Errorf("number of instances must be power of 2")
   476  	}
   477  
   478  	if o.pool == nil {
   479  		pool := polynomial.NewPool(c.MemoryRequirements(nbInstances)...)
   480  		o.pool = &pool
   481  	}
   482  
   483  	if o.workers == nil {
   484  		o.workers = utils.NewWorkerPool()
   485  	}
   486  
   487  	if o.sorted == nil {
   488  		o.sorted = topologicalSort(c)
   489  	}
   490  
   491  	if transcriptSettings.Transcript == nil {
   492  		challengeNames := ChallengeNames(o.sorted, o.nbVars, transcriptSettings.Prefix)
   493  		o.transcript = fiatshamir.NewTranscript(transcriptSettings.Hash, challengeNames...)
   494  		for i := range transcriptSettings.BaseChallenges {
   495  			if err = o.transcript.Bind(challengeNames[0], transcriptSettings.BaseChallenges[i]); err != nil {
   496  				return o, err
   497  			}
   498  		}
   499  	} else {
   500  		o.transcript, o.transcriptPrefix = transcriptSettings.Transcript, transcriptSettings.Prefix
   501  	}
   502  
   503  	return o, err
   504  }
   505  
   506  // ProofSize computes how large the proof for a circuit would be. It needs nbUniqueOutputs to be set
   507  func ProofSize(c Circuit, logNbInstances int) int {
   508  	nbUniqueInputs := 0
   509  	nbPartialEvalPolys := 0
   510  	for i := range c {
   511  		nbUniqueInputs += c[i].nbUniqueOutputs // each unique output is manifest in a finalEvalProof entry
   512  		if !c[i].noProof() {
   513  			nbPartialEvalPolys += c[i].Gate.Degree() + 1
   514  		}
   515  	}
   516  	return nbUniqueInputs + nbPartialEvalPolys*logNbInstances
   517  }
   518  
   519  func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string {
   520  
   521  	// Pre-compute the size TODO: Consider not doing this and just grow the list by appending
   522  	size := logNbInstances // first challenge
   523  
   524  	for _, w := range sorted {
   525  		if w.noProof() { // no proof, no challenge
   526  			continue
   527  		}
   528  		if w.NbClaims() > 1 { //combine the claims
   529  			size++
   530  		}
   531  		size += logNbInstances // full run of sumcheck on logNbInstances variables
   532  	}
   533  
   534  	nums := make([]string, utils.Max(len(sorted), logNbInstances))
   535  	for i := range nums {
   536  		nums[i] = strconv.Itoa(i)
   537  	}
   538  
   539  	challenges := make([]string, size)
   540  
   541  	// output wire claims
   542  	firstChallengePrefix := prefix + "fC."
   543  	for j := 0; j < logNbInstances; j++ {
   544  		challenges[j] = firstChallengePrefix + nums[j]
   545  	}
   546  	j := logNbInstances
   547  	for i := len(sorted) - 1; i >= 0; i-- {
   548  		if sorted[i].noProof() {
   549  			continue
   550  		}
   551  		wirePrefix := prefix + "w" + nums[i] + "."
   552  
   553  		if sorted[i].NbClaims() > 1 {
   554  			challenges[j] = wirePrefix + "comb"
   555  			j++
   556  		}
   557  
   558  		partialSumPrefix := wirePrefix + "pSP."
   559  		for k := 0; k < logNbInstances; k++ {
   560  			challenges[j] = partialSumPrefix + nums[k]
   561  			j++
   562  		}
   563  	}
   564  	return challenges
   565  }
   566  
   567  func getFirstChallengeNames(logNbInstances int, prefix string) []string {
   568  	res := make([]string, logNbInstances)
   569  	firstChallengePrefix := prefix + "fC."
   570  	for i := 0; i < logNbInstances; i++ {
   571  		res[i] = firstChallengePrefix + strconv.Itoa(i)
   572  	}
   573  	return res
   574  }
   575  
   576  func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) {
   577  	res := make([]fr.Element, len(names))
   578  	for i, name := range names {
   579  		if bytes, err := transcript.ComputeChallenge(name); err == nil {
   580  			res[i].SetBytes(bytes)
   581  		} else {
   582  			return nil, err
   583  		}
   584  	}
   585  	return res, nil
   586  }
   587  
   588  // Prove consistency of the claimed assignment
   589  func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) {
   590  	o, err := setup(c, assignment, transcriptSettings, options...)
   591  	if err != nil {
   592  		return nil, err
   593  	}
   594  	defer o.workers.Stop()
   595  
   596  	claims := newClaimsManager(c, assignment, o)
   597  
   598  	proof := make(Proof, len(c))
   599  	// firstChallenge called rho in the paper
   600  	var firstChallenge []fr.Element
   601  	firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix))
   602  	if err != nil {
   603  		return nil, err
   604  	}
   605  
   606  	wirePrefix := o.transcriptPrefix + "w"
   607  	var baseChallenge [][]byte
   608  	for i := len(c) - 1; i >= 0; i-- {
   609  
   610  		wire := o.sorted[i]
   611  
   612  		if wire.IsOutput() {
   613  			claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool))
   614  		}
   615  
   616  		claim := claims.getClaim(wire)
   617  		if wire.noProof() { // input wires with one claim only
   618  			proof[i] = sumcheck.Proof{
   619  				PartialSumPolys: []polynomial.Polynomial{},
   620  				FinalEvalProof:  []fr.Element{},
   621  			}
   622  		} else {
   623  			if proof[i], err = sumcheck.Prove(
   624  				claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...),
   625  			); err != nil {
   626  				return proof, err
   627  			}
   628  
   629  			finalEvalProof := proof[i].FinalEvalProof.([]fr.Element)
   630  			baseChallenge = make([][]byte, len(finalEvalProof))
   631  			for j := range finalEvalProof {
   632  				bytes := finalEvalProof[j].Bytes()
   633  				baseChallenge[j] = bytes[:]
   634  			}
   635  		}
   636  		// the verifier checks a single claim about input wires itself
   637  		claims.deleteClaim(wire)
   638  	}
   639  
   640  	return proof, nil
   641  }
   642  
   643  // Verify the consistency of the claimed output with the claimed input
   644  // Unlike in Prove, the assignment argument need not be complete
   645  func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error {
   646  	o, err := setup(c, assignment, transcriptSettings, options...)
   647  	if err != nil {
   648  		return err
   649  	}
   650  	defer o.workers.Stop()
   651  
   652  	claims := newClaimsManager(c, assignment, o)
   653  
   654  	var firstChallenge []fr.Element
   655  	firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix))
   656  	if err != nil {
   657  		return err
   658  	}
   659  
   660  	wirePrefix := o.transcriptPrefix + "w"
   661  	var baseChallenge [][]byte
   662  	for i := len(c) - 1; i >= 0; i-- {
   663  		wire := o.sorted[i]
   664  
   665  		if wire.IsOutput() {
   666  			claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool))
   667  		}
   668  
   669  		proofW := proof[i]
   670  		finalEvalProof := proofW.FinalEvalProof.([]fr.Element)
   671  		claim := claims.getLazyClaim(wire)
   672  		if wire.noProof() { // input wires with one claim only
   673  			// make sure the proof is empty
   674  			if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 {
   675  				return fmt.Errorf("no proof allowed for input wire with a single claim")
   676  			}
   677  
   678  			if wire.NbClaims() == 1 { // input wire
   679  				// simply evaluate and see if it matches
   680  				evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool)
   681  				if !claim.claimedEvaluations[0].Equal(&evaluation) {
   682  					return fmt.Errorf("incorrect input wire claim")
   683  				}
   684  			}
   685  		} else if err = sumcheck.Verify(
   686  			claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...),
   687  		); err == nil {
   688  			baseChallenge = make([][]byte, len(finalEvalProof))
   689  			for j := range finalEvalProof {
   690  				bytes := finalEvalProof[j].Bytes()
   691  				baseChallenge[j] = bytes[:]
   692  			}
   693  		} else {
   694  			return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump?
   695  		}
   696  		claims.deleteClaim(wire)
   697  	}
   698  	return nil
   699  }
   700  
   701  // outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata.
   702  func outputsList(c Circuit, indexes map[*Wire]int) [][]int {
   703  	res := make([][]int, len(c))
   704  	for i := range c {
   705  		res[i] = make([]int, 0)
   706  		c[i].nbUniqueOutputs = 0
   707  		if c[i].IsInput() {
   708  			c[i].Gate = IdentityGate{}
   709  		}
   710  	}
   711  	ins := make(map[int]struct{}, len(c))
   712  	for i := range c {
   713  		for k := range ins { // clear map
   714  			delete(ins, k)
   715  		}
   716  		for _, in := range c[i].Inputs {
   717  			inI := indexes[in]
   718  			res[inI] = append(res[inI], i)
   719  			if _, ok := ins[inI]; !ok {
   720  				in.nbUniqueOutputs++
   721  				ins[inI] = struct{}{}
   722  			}
   723  		}
   724  	}
   725  	return res
   726  }
   727  
// topSortData is the working state of topologicalSort.
type topSortData struct {
	outputs    [][]int       // outputs[i] lists the wires using wire i as input, once per use
	status     []int         // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done
	index      map[*Wire]int // position of each wire in the circuit
	leastReady int           // index of the next wire to emit; len(status) when none is ready
}
   734  
   735  func (d *topSortData) markDone(i int) {
   736  
   737  	d.status[i] = -1
   738  
   739  	for _, outI := range d.outputs[i] {
   740  		d.status[outI]--
   741  		if d.status[outI] == 0 && outI < d.leastReady {
   742  			d.leastReady = outI
   743  		}
   744  	}
   745  
   746  	for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 {
   747  		d.leastReady++
   748  	}
   749  }
   750  
   751  func indexMap(c Circuit) map[*Wire]int {
   752  	res := make(map[*Wire]int, len(c))
   753  	for i := range c {
   754  		res[&c[i]] = i
   755  	}
   756  	return res
   757  }
   758  
   759  func statusList(c Circuit) []int {
   760  	res := make([]int, len(c))
   761  	for i := range c {
   762  		res[i] = len(c[i].Inputs)
   763  	}
   764  	return res
   765  }
   766  
   767  // topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on
   768  // occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged.
   769  // It also sets the nbOutput flags, and a dummy IdentityGate for input wires.
   770  // Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small.
   771  // Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input
   772  func topologicalSort(c Circuit) []*Wire {
   773  	var data topSortData
   774  	data.index = indexMap(c)
   775  	data.outputs = outputsList(c, data.index)
   776  	data.status = statusList(c)
   777  	sorted := make([]*Wire, len(c))
   778  
   779  	for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ {
   780  	}
   781  
   782  	for i := range c {
   783  		sorted[i] = &c[data.leastReady]
   784  		data.markDone(data.leastReady)
   785  	}
   786  
   787  	return sorted
   788  }
   789  
// Complete the circuit evaluation from input values.
//
// a must already contain at least one wire, because NumInstances panics otherwise. Every wire of
// c missing from a gets a zero-initialised slice of a.NumInstances() values.
// NOTE(review): this includes input wires, so a missing input is silently treated as all zeros.
// The value of every non-input wire is then (re)computed for each instance, evaluating gates in
// topological order. Instances are split across goroutines by parallel.Execute. Each goroutine
// writes only the entries of its own instances, and the map itself is not modified during that
// phase. a is modified in place and returned. Through topologicalSort, this also sets wire
// metadata (nbUniqueOutputs, and IdentityGate on input wires).
func (a WireAssignment) Complete(c Circuit) WireAssignment {

	sortedWires := topologicalSort(c)
	nbInstances := a.NumInstances()
	maxNbIns := 0

	for _, w := range sortedWires {
		maxNbIns = utils.Max(maxNbIns, len(w.Inputs))
		if a[w] == nil {
			a[w] = make([]fr.Element, nbInstances)
		}
	}

	parallel.Execute(nbInstances, func(start, end int) {
		// per-goroutine buffer for gate inputs, sized for the widest gate
		ins := make([]fr.Element, maxNbIns)
		for i := start; i < end; i++ {
			for _, w := range sortedWires {
				if !w.IsInput() {
					for inI, in := range w.Inputs {
						ins[inI] = a[in][i]
					}
					a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...)
				}
			}
		}
	})

	return a
}
   820  
   821  func (a WireAssignment) NumInstances() int {
   822  	for _, aW := range a {
   823  		return len(aW)
   824  	}
   825  	panic("empty assignment")
   826  }
   827  
   828  func (a WireAssignment) NumVars() int {
   829  	for _, aW := range a {
   830  		return aW.NumVars()
   831  	}
   832  	panic("empty assignment")
   833  }
   834  
   835  // SerializeToBigInts flattens a proof object into the given slice of big.Ints
   836  // useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this
   837  func (p Proof) SerializeToBigInts(outs []*big.Int) {
   838  	offset := 0
   839  	for i := range p {
   840  		for _, poly := range p[i].PartialSumPolys {
   841  			frToBigInts(outs[offset:], poly)
   842  			offset += len(poly)
   843  		}
   844  		if p[i].FinalEvalProof != nil {
   845  			finalEvalProof := p[i].FinalEvalProof.([]fr.Element)
   846  			frToBigInts(outs[offset:], finalEvalProof)
   847  			offset += len(finalEvalProof)
   848  		}
   849  	}
   850  }
   851  
   852  func frToBigInts(dst []*big.Int, src []fr.Element) {
   853  	for i := range src {
   854  		src[i].BigInt(dst[i])
   855  	}
   856  }
   857  
// Gates defined by name.
// "mul" is the two-input product MulGate(2); products of other arities must use MulGate(n) directly.
var Gates = map[string]Gate{
	"identity": IdentityGate{},
	"add":      AddGate{},
	"sub":      SubGate{},
	"neg":      NegGate{},
	"mul":      MulGate(2),
}
   866  
// IdentityGate returns its first input. outputsList assigns it to every input wire.
type IdentityGate struct{}

// AddGate sums any number of inputs.
type AddGate struct{}

// MulGate multiplies exactly int(g) inputs; its degree is int(g).
type MulGate int

// SubGate subtracts; see its Evaluate method for the supported number of inputs.
type SubGate struct{}

// NegGate negates its single input.
type NegGate struct{}
   872  
   873  func (IdentityGate) Evaluate(input ...fr.Element) fr.Element {
   874  	return input[0]
   875  }
   876  
   877  func (IdentityGate) Degree() int {
   878  	return 1
   879  }
   880  
   881  func (g AddGate) Evaluate(x ...fr.Element) (res fr.Element) {
   882  	switch len(x) {
   883  	case 0:
   884  	// set zero
   885  	case 1:
   886  		res.Set(&x[0])
   887  	default:
   888  		res.Add(&x[0], &x[1])
   889  		for i := 2; i < len(x); i++ {
   890  			res.Add(&res, &x[i])
   891  		}
   892  	}
   893  	return
   894  }
   895  
   896  func (g AddGate) Degree() int {
   897  	return 1
   898  }
   899  
   900  func (g MulGate) Evaluate(x ...fr.Element) (res fr.Element) {
   901  	if len(x) != int(g) {
   902  		panic("wrong input count")
   903  	}
   904  	switch len(x) {
   905  	case 0:
   906  		res.SetOne()
   907  	case 1:
   908  		res.Set(&x[0])
   909  	default:
   910  		res.Mul(&x[0], &x[1])
   911  		for i := 2; i < len(x); i++ {
   912  			res.Mul(&res, &x[i])
   913  		}
   914  	}
   915  	return
   916  }
   917  
   918  func (g MulGate) Degree() int {
   919  	return int(g)
   920  }
   921  
   922  func (g SubGate) Evaluate(element ...fr.Element) (diff fr.Element) {
   923  	if len(element) > 2 {
   924  		panic("not implemented") //TODO
   925  	}
   926  	diff.Sub(&element[0], &element[1])
   927  	return
   928  }
   929  
   930  func (g SubGate) Degree() int {
   931  	return 1
   932  }
   933  
   934  func (g NegGate) Evaluate(element ...fr.Element) (neg fr.Element) {
   935  	if len(element) != 1 {
   936  		panic("univariate gate")
   937  	}
   938  	neg.Neg(&element[0])
   939  	return
   940  }
   941  
   942  func (g NegGate) Degree() int {
   943  	return 1
   944  }