github.com/consensys/gnark-crypto@v0.14.0/internal/generator/gkr/template/gkr.go.tmpl (about)

     1  import (
     2  	"fmt"
     3  	"{{.FieldPackagePath}}"
     4  	"{{.FieldPackagePath}}/polynomial"
     5  	"{{.FieldPackagePath}}/sumcheck"
     6  	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
     7  	"github.com/consensys/gnark-crypto/internal/parallel"
     8  	"github.com/consensys/gnark-crypto/utils"
     9  	"math/big"
    10  	"strconv"
    11  	"sync"
    12  )
    13  
    14  {{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}}
    15  
    16  // The goal is to prove/verify evaluations of many instances of the same circuit
    17  
    18  // Gate must be a low-degree polynomial
    19  type Gate interface {
    20  	Evaluate(...{{.ElementType}}) {{.ElementType}}
    21  	Degree() int
    22  }
    23  
    24  type Wire struct {
    25  	Gate            Gate
    26  	Inputs          []*Wire // if there are no Inputs, the wire is assumed an input wire
    27  	nbUniqueOutputs int     // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one)
    28  }
    29  
    30  type Circuit []Wire
    31  
    32  func (w Wire) IsInput() bool {
    33  	return len(w.Inputs) == 0
    34  }
    35  
    36  func (w Wire) IsOutput() bool {
    37  	return w.nbUniqueOutputs == 0
    38  }
    39  
    40  func (w Wire) NbClaims() int {
    41  	if w.IsOutput() {
    42  		return 1
    43  	}
    44  	return w.nbUniqueOutputs
    45  }
    46  
    47  func (w Wire) noProof() bool {
    48  	return w.IsInput() && w.NbClaims() == 1
    49  }
    50  
    51  func (c Circuit) maxGateDegree() int {
    52  	res := 1
    53  	for i := range c {
    54  		if !c[i].IsInput() {
    55  			res = utils.Max(res, c[i].Gate.Degree())
    56  		}
    57  	}
    58  	return res
    59  }
    60  
// WireAssignment is assignment of values to the same wire across many instances of the circuit.
// Each wire's values are held as a polynomial.MultiLin table; setup requires the
// number of instances to be a power of 2.
type WireAssignment map[*Wire]polynomial.MultiLin

// Proof holds one sumcheck proof per wire. Prove stores the proof for the wire
// sorted[i] (in topological order) at index i; for each variable, a sumcheck
// proof carries one partial-sum polynomial.
type Proof []sumcheck.Proof
    65  
    66  type eqTimesGateEvalSumcheckLazyClaims struct {
    67  	wire               *Wire
    68  	evaluationPoints   [][]{{.ElementType}}
    69  	claimedEvaluations []{{.ElementType}}
    70  	manager            *claimsManager // WARNING: Circular references
    71  }
    72  
    73  func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int {
    74  	return len(e.evaluationPoints)
    75  }
    76  
    77  func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int {
    78  	return len(e.evaluationPoints[0])
    79  }
    80  
    81  func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a {{.ElementType}}) {{.ElementType}} {
    82  	evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations)
    83  	return evalsAsPoly.Eval(&a)
    84  }
    85  
    86  func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int {
    87  	return 1 + e.wire.Gate.Degree()
    88  }
    89  
    90  func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error {
    91  	inputEvaluationsNoRedundancy := proof.([]{{.ElementType}})
    92  
    93  	// the eq terms
    94  	numClaims := len(e.evaluationPoints)
    95  	evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r)
    96  	for i := numClaims - 2; i >= 0; i-- {
    97  		evaluation.Mul(&evaluation, &combinationCoeff)
    98  		eq := polynomial.EvalEq(e.evaluationPoints[i], r)
    99  		evaluation.Add(&evaluation, &eq)
   100  	}
   101  
   102  	// the g(...) term
   103  	var gateEvaluation {{.ElementType}}
   104  	if e.wire.IsInput() {
   105  		gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool)
   106  	} else {
   107  		inputEvaluations := make([]{{.ElementType}}, len(e.wire.Inputs))
   108  		indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy))
   109  
   110  		proofI := 0
   111  		for inI, in := range e.wire.Inputs {
   112  			indexInProof, found := indexesInProof[in]
   113  			if !found {
   114  				indexInProof = proofI
   115  				indexesInProof[in] = indexInProof
   116  
   117  				// defer verification, store new claim
   118  				e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof])
   119  				proofI++
   120  			}
   121  			inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof]
   122  		}
   123  		if proofI != len(inputEvaluationsNoRedundancy) {
   124  			return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI)
   125  		}
   126  		gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...)
   127  	}
   128  
   129  	evaluation.Mul(&evaluation, &gateEvaluation)
   130  
   131  	if evaluation.Equal(&purportedValue) {
   132  		return nil
   133  	}
   134  	return fmt.Errorf("incompatible evaluations")
   135  }
   136  
// eqTimesGateEvalSumcheckClaims is the prover's view of the claims about one
// wire. It drives the prover side of a sumcheck on ∑ₕ E(h) × g(P_u0(h), ...),
// where E = ∑ₖ aᵏ eq(xₖ, h) combines the individual claims (built by Combine).
type eqTimesGateEvalSumcheckClaims struct {
	wire               *Wire
	evaluationPoints   [][]{{.ElementType}} // x in the paper
	claimedEvaluations []{{.ElementType}}   // y in the paper
	manager            *claimsManager

	// P_u in the paper, so that we don't need to pass along all the circuit's evaluations.
	// These are deep copies of the input wires' assignments (or of the wire's own
	// assignment for an input wire); Next folds them in place.
	inputPreprocessors []polynomial.MultiLin

	// ∑_i τ_i eq(x_i, -): filled by Combine, folded by Next
	eq polynomial.MultiLin
}
   147  
   148  func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff {{.ElementType}}) polynomial.Polynomial {
   149  	varsNum := c.VarsNum()
   150  	eqLength := 1 << varsNum
   151  	claimsNum := c.ClaimsNum()
   152  	// initialize the eq tables
   153  	c.eq = c.manager.memPool.Make(eqLength)
   154  
   155  	c.eq[0].SetOne()
   156  	c.eq.Eq(c.evaluationPoints[0])
   157  
   158  	newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength))
   159  	aI := combinationCoeff
   160  
   161  	for k := 1; k < claimsNum; k++ { //TODO: parallelizable?
   162  		// define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points
   163  		newEq[0].Set(&aI)
   164  
   165  		c.eqAcc(c.eq, newEq,c.evaluationPoints[k])
   166  
   167  		// newEq.Eq(c.evaluationPoints[k])
   168  		// eqAsPoly := polynomial.Polynomial(c.eq) //just semantics
   169  		// eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq))
   170  
   171  		if k+1 < claimsNum {
   172  			aI.Mul(&aI, &combinationCoeff)
   173  		}
   174  	}
   175  
   176  	c.manager.memPool.Dump(newEq)
   177  
   178  	// from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree
   179  
   180  	return c.computeGJ()
   181  }
   182  
   183  // eqAcc sets m to an eq table at q and then adds it to e
   184  func (c *eqTimesGateEvalSumcheckClaims) eqAcc(e, m polynomial.MultiLin, q []{{.ElementType}}) {
   185  	n := len(q)
   186  
   187  	//At the end of each iteration, m(h₁, ..., hₙ) = Eq(q₁, ..., qᵢ₊₁, h₁, ..., hᵢ₊₁)
   188  	for i := range q { // In the comments we use a 1-based index so q[i] = qᵢ₊₁
   189  		// go through all assignments of (b₁, ..., bᵢ) ∈ {0,1}ⁱ
   190  		const threshold = 1 << 6
   191  		k := 1 << i
   192  		if k < threshold {
   193  			for j := 0; j < k; j++ {
   194  				j0 := j << (n - i)    // bᵢ₊₁ = 0
   195  				j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1
   196  
   197  				m[j1].Mul(&q[i], &m[j0])    // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁
   198  				m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁)
   199  			}
   200  		} else {
   201  			c.manager.workers.Submit(k, func(start, end int) {
   202  				for j := start; j < end; j++ {
   203  					j0 := j << (n - i)    // bᵢ₊₁ = 0
   204  					j1 := j0 + 1<<(n-1-i) // bᵢ₊₁ = 1
   205  
   206  					m[j1].Mul(&q[i], &m[j0])    // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 1) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) qᵢ₊₁
   207  					m[j0].Sub(&m[j0], &m[j1]) // Eq(q₁, ..., qᵢ₊₁, b₁, ..., bᵢ, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) Eq(qᵢ₊₁, 0) = Eq(q₁, ..., qᵢ, b₁, ..., bᵢ) (1-qᵢ₊₁)
   208  				}
   209  			}, 1024).Wait()
   210  		}
   211  
   212  	}
   213  	c.manager.workers.Submit(len(e), func(start, end int) {
   214  		for i := start; i < end; i++ {
   215  			e[i].Add(&e[i], &m[i])
   216  		}
   217  	}, 512).Wait()
   218  
   219  	// e.Add(e, polynomial.Polynomial(m))
   220  }
   221  
   222  
// computeGJ computes the prover's message for the current sumcheck round j:
// gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where  E = ∑ eq_k
// The earlier challenges r₁, ..., rⱼ₋₁ have already been folded into c.eq and c.inputPreprocessors (see Next),
// so X_j is the first remaining variable: index i of a table is X_j = 0 and index len/2 + i is X_j = 1.
// The polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)).
// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = gⱼ₋₁(rⱼ₋₁). By convention, g₀ is a constant polynomial equal to the claimed sum.
func (c *eqTimesGateEvalSumcheckClaims) computeGJ() polynomial.Polynomial {

	degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j)
	nbGateIn := len(c.inputPreprocessors)

	// Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables
	// s[0] is the eq table; s[1:] are the gate inputs' tables.
	s := make([]polynomial.MultiLin, nbGateIn+1)
	s[0] = c.eq
	copy(s[1:], c.inputPreprocessors)

	// Perf-TODO: Collate once at claim "combination" time and not again. then, even folding can be done in one operation every time "next" is called
	nbInner := len(s) // wrt output, which has high nbOuter and low nbInner
	nbOuter := len(s[0]) / 2

	gJ := make([]{{.ElementType}}, degGJ)
	var mu sync.Mutex // guards gJ while workers merge their partial sums
	// computeAll adds, for every i in [start, end), the contribution of index i to gJ(1), ..., gJ(degGJ).
	computeAll := func(start, end int) {
		var step {{.ElementType}}

		// res: this worker's partial sums. operands[d*nbInner+j]: table s[j] at X_j = d+1 for the current i.
		res := make([]{{.ElementType}}, degGJ)
		operands := make([]{{.ElementType}}, degGJ*nbInner)

		for i := start; i < end; i++ {

			block := nbOuter + i
			for j := 0; j < nbInner; j++ {
				step.Set(&s[j][i])                // f(0)
				operands[j].Set(&s[j][block])     // f(1)
				step.Sub(&operands[j], &step)     // f(1) - f(0)
				for d := 1; d < degGJ; d++ {
					// f(d+1) = f(d) + (f(1) - f(0))
					operands[d*nbInner+j].Add(&operands[(d-1)*nbInner+j], &step)
				}
			}

			// [_s, _e) is the window of operands for one evaluation point: eq factor first, then the gate inputs
			_s := 0
			_e := nbInner
			for d := 0; d < degGJ; d++ {
				summand := c.wire.Gate.Evaluate(operands[_s+1 : _e]...)
				summand.Mul(&summand, &operands[_s])
				res[d].Add(&res[d], &summand)
				_s, _e = _e, _e+nbInner
			}
		}
		mu.Lock()
		for i := 0; i < len(gJ); i++ {
			gJ[i].Add(&gJ[i], &res[i])
		}
		mu.Unlock()
	}

	const minBlockSize = 64

	if nbOuter < minBlockSize {
		// no parallelization
		computeAll(0, nbOuter)
	} else {
		c.manager.workers.Submit(nbOuter, computeAll, minBlockSize).Wait()
	}

	// Perf-TODO: Separate functions Gate.TotalDegree and Gate.Degree(i) so that we get to use possibly smaller values for degGJ. Won't help with MiMC though

	return gJ
}
   289  
   290  // Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j
   291  func (c *eqTimesGateEvalSumcheckClaims) Next(element {{.ElementType}}) polynomial.Polynomial {
   292  	const minBlockSize = 512
   293  	n := len(c.eq) / 2
   294  	if n < minBlockSize {
   295  		// no parallelization
   296  		for i := 0; i < len(c.inputPreprocessors); i++ {
   297  			c.inputPreprocessors[i].Fold(element)
   298  		}
   299  		c.eq.Fold(element)
   300  	} else {
   301  		wgs := make([]*sync.WaitGroup, len(c.inputPreprocessors))
   302  		for i := 0; i < len(c.inputPreprocessors); i++ {
   303  			wgs[i] = c.manager.workers.Submit(n, c.inputPreprocessors[i].FoldParallel(element), minBlockSize)
   304  		}
   305  		c.manager.workers.Submit(n, c.eq.FoldParallel(element), minBlockSize).Wait()
   306  		for _, wg := range wgs {
   307  			wg.Wait()
   308  		}
   309  	}
   310  
   311  	return c.computeGJ()
   312  }
   313  
   314  func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int {
   315  	return len(c.evaluationPoints[0])
   316  }
   317  
   318  func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int {
   319  	return len(c.claimedEvaluations)
   320  }
   321  
   322  func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []{{.ElementType}}) interface{} {
   323  
   324  	//defer the proof, return list of claims
   325  	evaluations := make([]{{.ElementType}}, 0, len(c.wire.Inputs))
   326  	noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors))
   327  	noMoreClaimsAllowed[c.wire] = struct{}{}
   328  
   329  	for inI, in := range c.wire.Inputs {
   330  		puI := c.inputPreprocessors[inI]
   331  		if _, found := noMoreClaimsAllowed[in]; !found {
   332  			noMoreClaimsAllowed[in] = struct{}{}
   333  			puI.Fold(r[len(r)-1])
   334  			c.manager.add(in, r, puI[0])
   335  			evaluations = append(evaluations, puI[0])
   336  		}
   337  		c.manager.memPool.Dump(puI)
   338  	}
   339  
   340  	c.manager.memPool.Dump(c.claimedEvaluations, c.eq)
   341  
   342  	return evaluations
   343  }
   344  
   345  type claimsManager struct {
   346  	claimsMap  map[*Wire]*eqTimesGateEvalSumcheckLazyClaims
   347  	assignment WireAssignment
   348  	memPool    *polynomial.Pool
   349  	workers    *utils.WorkerPool
   350  }
   351  
   352  func newClaimsManager(c Circuit, assignment WireAssignment, o settings) (claims claimsManager) {
   353  	claims.assignment = assignment
   354  	claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c))
   355  	claims.memPool = o.pool
   356  	claims.workers = o.workers
   357  
   358  	for i := range c {
   359  		wire := &c[i]
   360  
   361  		claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{
   362  			wire:               wire,
   363  			evaluationPoints:   make([][]{{.ElementType}}, 0, wire.NbClaims()),
   364  			claimedEvaluations: claims.memPool.Make(wire.NbClaims()),
   365  			manager:            &claims,
   366  		}
   367  	}
   368  	return
   369  }
   370  
   371  func (m *claimsManager) add(wire *Wire, evaluationPoint []{{.ElementType}}, evaluation {{.ElementType}}) {
   372  	claim := m.claimsMap[wire]
   373  	i := len(claim.evaluationPoints)
   374  	claim.claimedEvaluations[i] = evaluation
   375  	claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint)
   376  }
   377  
   378  func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims {
   379  	return m.claimsMap[wire]
   380  }
   381  
   382  func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims {
   383  	lazy := m.claimsMap[wire]
   384  	res := &eqTimesGateEvalSumcheckClaims{
   385  		wire:               wire,
   386  		evaluationPoints:   lazy.evaluationPoints,
   387  		claimedEvaluations: lazy.claimedEvaluations,
   388  		manager:            m,
   389  	}
   390  
   391  	if wire.IsInput() {
   392  		res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])}
   393  	} else {
   394  		res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs))
   395  
   396  		for inputI, inputW := range wire.Inputs {
   397  			res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied
   398  		}
   399  	}
   400  	return res
   401  }
   402  
   403  func (m *claimsManager) deleteClaim(wire *Wire) {
   404  	delete(m.claimsMap, wire)
   405  }
   406  
// settings gathers the configuration used by Prove and Verify. The Option
// functions fill it, and setup completes it with defaults.
type settings struct {
	pool             *polynomial.Pool       // source of scratch polynomials
	sorted           []*Wire                // wires in topological order; proof[i] refers to sorted[i]
	transcript       *fiatshamir.Transcript // Fiat-Shamir transcript the challenges are computed from
	transcriptPrefix string                 // prefix prepended to the challenge names Prove/Verify request
	nbVars           int                    // log₂ of the number of instances
	workers          *utils.WorkerPool      // runs parallel tasks
}
   415  
   416  type Option func(*settings)
   417  
   418  func WithPool(pool *polynomial.Pool) Option {
   419  	return func (options *settings) {
   420  		options.pool = pool
   421  	}
   422  }
   423  
   424  func WithSortedCircuit(sorted []*Wire) Option {
   425  	return func(options *settings) {
   426  		options.sorted = sorted
   427  	}
   428  }
   429  
   430  func WithWorkers(workers *utils.WorkerPool) Option {
   431  	return func(options *settings) {
   432  		options.workers = workers
   433      }
   434  }
   435  
   436  // MemoryRequirements returns an increasing vector of memory allocation sizes required for proving a GKR statement
   437  func (c Circuit) MemoryRequirements(nbInstances int) []int {
   438  	res := []int{256, nbInstances, nbInstances * (c.maxGateDegree() + 1)}
   439  
   440  	if res[0] > res[1] { // make sure it's sorted
   441  		res[0], res[1] = res[1], res[0]
   442  		if res[1] > res[2] {
   443  			res[1], res[2] = res[2], res[1]
   444  		}
   445  	}
   446  
   447  	return res
   448  }
   449  
   450  func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) {
   451  	var o settings
   452  	var err error
   453   	for _, option := range options {
   454  		option(&o)
   455  	}
   456  
   457  	o.nbVars = assignment.NumVars()
   458  	nbInstances := assignment.NumInstances()
   459  	if 1<<o.nbVars != nbInstances {
   460  		return o, fmt.Errorf("number of instances must be power of 2")
   461  	}
   462  
   463  	if o.pool == nil {
   464  		pool := polynomial.NewPool(c.MemoryRequirements(nbInstances)...)
   465  		o.pool = &pool
   466  	}
   467  
   468  	if o.workers == nil {
   469  		o.workers = utils.NewWorkerPool()
   470  	}
   471  
   472  	if o.sorted == nil {
   473  		o.sorted = {{$topologicalSort}}(c)
   474  	}
   475  
   476  	if transcriptSettings.Transcript == nil {
   477  		challengeNames := ChallengeNames(o.sorted, o.nbVars, transcriptSettings.Prefix)
   478  		o.transcript = fiatshamir.NewTranscript(transcriptSettings.Hash, challengeNames...)
   479  		for i := range transcriptSettings.BaseChallenges {
   480  			if err = o.transcript.Bind(challengeNames[0], transcriptSettings.BaseChallenges[i]); err != nil {
   481  				return o, err
   482  			}
   483  		}
   484  	} else {
   485  		o.transcript, o.transcriptPrefix = transcriptSettings.Transcript, transcriptSettings.Prefix
   486  	}
   487  
   488  	return o, err
   489  }
   490  
   491  // ProofSize computes how large the proof for a circuit would be. It needs nbUniqueOutputs to be set
   492  func ProofSize(c Circuit, logNbInstances int) int {
   493  	nbUniqueInputs := 0
   494  	nbPartialEvalPolys := 0
   495  	for i := range c {
   496  		nbUniqueInputs += c[i].nbUniqueOutputs // each unique output is manifest in a finalEvalProof entry
   497  		if !c[i].noProof() {
   498  			nbPartialEvalPolys += c[i].Gate.Degree() + 1
   499  		}
   500  	}
   501  	return nbUniqueInputs + nbPartialEvalPolys*logNbInstances
   502  }
   503  
   504  func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string {
   505  
   506  	// Pre-compute the size TODO: Consider not doing this and just grow the list by appending
   507  	size := logNbInstances // first challenge
   508  
   509  	for _, w := range sorted {
   510  		if w.noProof() { // no proof, no challenge
   511  			continue
   512  		}
   513  		if w.NbClaims() > 1 { //combine the claims
   514  			size++
   515  		}
   516  		size += logNbInstances // full run of sumcheck on logNbInstances variables
   517  	}
   518  
   519  	nums := make([]string, utils.Max(len(sorted), logNbInstances))
   520  	for i := range nums {
   521  		nums[i] = strconv.Itoa(i)
   522  	}
   523  
   524  	challenges := make([]string, size)
   525  
   526  	// output wire claims
   527  	firstChallengePrefix := prefix + "fC."
   528  	for j := 0; j < logNbInstances; j++ {
   529  		challenges[j] = firstChallengePrefix + nums[j]
   530  	}
   531  	j := logNbInstances
   532  	for i := len(sorted) - 1; i >= 0; i-- {
   533  		if sorted[i].noProof() {
   534  			continue
   535  		}
   536  		wirePrefix := prefix + "w" + nums[i] + "."
   537  
   538  		if sorted[i].NbClaims() > 1 {
   539  			challenges[j] = wirePrefix + "comb"
   540  			j++
   541  		}
   542  
   543  		partialSumPrefix := wirePrefix + "pSP."
   544  		for k := 0; k < logNbInstances; k++ {
   545  			challenges[j] = partialSumPrefix + nums[k]
   546  			j++
   547  		}
   548  	}
   549  	return challenges
   550  }
   551  
   552  func getFirstChallengeNames(logNbInstances int, prefix string) []string {
   553  	res := make([]string, logNbInstances)
   554  	firstChallengePrefix := prefix + "fC."
   555  	for i := 0; i < logNbInstances; i++ {
   556  		res[i] = firstChallengePrefix + strconv.Itoa(i)
   557  	}
   558  	return res
   559  }
   560  
   561  func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]{{.ElementType}}, error) {
   562  	res := make([]{{.ElementType}}, len(names))
   563  	for i, name := range names {
   564  		if bytes, err := transcript.ComputeChallenge(name); err == nil {
   565  			res[i].SetBytes(bytes)
   566  		} else {
   567  			return nil, err
   568  		}
   569  	}
   570  	return res, nil
   571  }
   572  
   573  // Prove consistency of the claimed assignment
   574  func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) {
   575  	o, err := setup(c, assignment, transcriptSettings, options...)
   576  	if err != nil {
   577  		return nil, err
   578  	}
   579  	defer o.workers.Stop()
   580  
   581  	claims := newClaimsManager(c, assignment, o)
   582  
   583  	proof := make(Proof, len(c))
   584  	// firstChallenge called rho in the paper
   585  	var firstChallenge []{{.ElementType}}
   586  	firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix))
   587  	if err != nil {
   588  		return nil, err
   589  	}
   590  
   591  	wirePrefix := o.transcriptPrefix + "w"
   592  	var baseChallenge [][]byte
   593  	for i := len(c) - 1; i >= 0; i-- {
   594  
   595  		wire := o.sorted[i]
   596  
   597  		if wire.IsOutput() {
   598  			claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool))
   599  		}
   600  
   601  		claim := claims.getClaim(wire)
   602  		if wire.noProof() { // input wires with one claim only
   603  			proof[i] = sumcheck.Proof{
   604  				PartialSumPolys: []polynomial.Polynomial{},
   605  				FinalEvalProof:  []{{.ElementType}}{},
   606  			}
   607  		} else {
   608  			if proof[i], err = sumcheck.Prove(
   609  				claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...),
   610  			); err != nil {
   611  				return proof, err
   612  			}
   613  
   614  			finalEvalProof := proof[i].FinalEvalProof.([]{{.ElementType}})
   615  			baseChallenge = make([][]byte, len(finalEvalProof))
   616  			for j := range finalEvalProof {
   617  				bytes := finalEvalProof[j].Bytes()
   618  				baseChallenge[j] = bytes[:]
   619  			}
   620  		}
   621  		// the verifier checks a single claim about input wires itself
   622  		claims.deleteClaim(wire)
   623  	}
   624  
   625  	return proof, nil
   626  }
   627  
   628  // Verify the consistency of the claimed output with the claimed input
   629  // Unlike in Prove, the assignment argument need not be complete
   630  func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error {
   631  	o, err := setup(c, assignment, transcriptSettings, options...)
   632  	if err != nil {
   633  		return err
   634  	}
   635  	defer o.workers.Stop()
   636  
   637  	claims := newClaimsManager(c, assignment, o)
   638  
   639  	var firstChallenge []{{.ElementType}}
   640  	firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix))
   641  	if err != nil {
   642  		return err
   643  	}
   644  
   645  	wirePrefix := o.transcriptPrefix + "w"
   646  	var baseChallenge [][]byte
   647  	for i := len(c) - 1; i >= 0; i-- {
   648  		wire := o.sorted[i]
   649  
   650  		if wire.IsOutput() {
   651  			claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool))
   652  		}
   653  
   654  		proofW := proof[i]
   655  		finalEvalProof := proofW.FinalEvalProof.([]{{.ElementType}})
   656  		claim := claims.getLazyClaim(wire)
   657  		if wire.noProof() { // input wires with one claim only
   658  			// make sure the proof is empty
   659  			if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 {
   660  				return fmt.Errorf("no proof allowed for input wire with a single claim")
   661  			}
   662  
   663  			if wire.NbClaims() == 1 { // input wire
   664  				// simply evaluate and see if it matches
   665  				evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool)
   666  				if !claim.claimedEvaluations[0].Equal(&evaluation) {
   667  					return fmt.Errorf("incorrect input wire claim")
   668  				}
   669  			}
   670  		} else if err = sumcheck.Verify(
   671  			claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...),
   672  		); err == nil {
   673  			baseChallenge = make([][]byte, len(finalEvalProof))
   674  			for j := range finalEvalProof {
   675  				bytes := finalEvalProof[j].Bytes()
   676  				baseChallenge[j] = bytes[:]
   677  			}
   678  		} else {
   679  			return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump?
   680  		}
   681  		claims.deleteClaim(wire)
   682  	}
   683  	return nil
   684  }
   685  
   686  // outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata.
   687  func outputsList(c Circuit, indexes map[*Wire]int) [][]int {
   688  	res := make([][]int, len(c))
   689  	for i := range c {
   690  		res[i] = make([]int, 0)
   691  		c[i].nbUniqueOutputs = 0
   692  		if c[i].IsInput() {
   693  			c[i].Gate = IdentityGate{}
   694  		}
   695  	}
   696  	ins := make(map[int]struct{}, len(c))
   697  	for i := range c {
   698  		for k := range ins {	// clear map
   699  			delete(ins, k)
   700  		}
   701  		for _, in := range c[i].Inputs {
   702  			inI := indexes[in]
   703  			res[inI] = append(res[inI], i)
   704  			if _, ok := ins[inI]; !ok {
   705  				in.nbUniqueOutputs++
   706  				ins[inI] = struct{}{}
   707  			}
   708  		}
   709  	}
   710  	return res
   711  }
   712  
   713  type topSortData struct {
   714  	outputs    [][]int
   715  	status     []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done
   716  	index      map[*Wire]int
   717  	leastReady int
   718  }
   719  
   720  func (d *topSortData) markDone(i int) {
   721  
   722  	d.status[i] = -1
   723  
   724  	for _, outI := range d.outputs[i] {
   725  		d.status[outI]--
   726  		if d.status[outI] == 0 && outI < d.leastReady {
   727  			d.leastReady = outI
   728  		}
   729  	}
   730  
   731  	for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 {
   732  		d.leastReady++
   733  	}
   734  }
   735  
   736  func indexMap(c Circuit) map[*Wire]int {
   737  	res := make(map[*Wire]int, len(c))
   738  	for i := range c {
   739  		res[&c[i]] = i
   740  	}
   741  	return res
   742  }
   743  
   744  func statusList(c Circuit) []int {
   745  	res := make([]int, len(c))
   746  	for i := range c {
   747  		res[i] = len(c[i].Inputs)
   748  	}
   749  	return res
   750  }
   751  
   752  // {{$topologicalSort}} sorts the wires in order of dependence. Such that for any wire, any one it depends on
   753  // occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged.
   754  // It also sets the nbOutput flags, and a dummy IdentityGate for input wires.
   755  // Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small.
   756  // Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input
   757  func {{$topologicalSort}}(c Circuit) []*Wire {
   758  	var data topSortData
   759  	data.index = indexMap(c)
   760  	data.outputs = outputsList(c, data.index)
   761  	data.status = statusList(c)
   762  	sorted := make([]*Wire, len(c))
   763  
   764  	for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ {
   765  	}
   766  
   767  	for i := range c {
   768  		sorted[i] = &c[data.leastReady]
   769  		data.markDone(data.leastReady)
   770  	}
   771  
   772  	return sorted
   773  }
   774  
   775  // Complete the circuit evaluation from input values
   776  func (a WireAssignment) Complete(c Circuit) WireAssignment {
   777  
   778  	sortedWires := {{$topologicalSort}}(c)
   779  	nbInstances := a.NumInstances()
   780  	maxNbIns := 0
   781  
   782  	for _, w := range sortedWires {
   783  		maxNbIns = utils.Max(maxNbIns, len(w.Inputs))
   784  		if a[w] == nil {
   785  			a[w] = make([]{{.ElementType}}, nbInstances)
   786  		}
   787  	}
   788  
   789  	parallel.Execute(nbInstances, func(start, end int) {
   790  		ins := make([]{{.ElementType}}, maxNbIns)
   791  		for i := start; i < end; i++ {
   792  			for _, w := range sortedWires {
   793  				if !w.IsInput() {
   794  					for inI, in := range w.Inputs {
   795  						ins[inI] = a[in][i]
   796  					}
   797  					a[w][i] = w.Gate.Evaluate(ins[:len(w.Inputs)]...)
   798  				}
   799  			}
   800  		}
   801  	})
   802  
   803  	return a
   804  }
   805  
   806  func (a WireAssignment) NumInstances() int {
   807  	for _, aW := range a {
   808  		return len(aW)
   809  	}
   810  	panic("empty assignment")
   811  }
   812  
   813  func (a WireAssignment) NumVars() int {
   814  	for _, aW := range a {
   815  		return aW.NumVars()
   816  	}
   817  	panic("empty assignment")
   818  }
   819  
   820  // SerializeToBigInts flattens a proof object into the given slice of big.Ints
   821  // useful in gnark hints. TODO: Change propagation: Once this is merged, it will duplicate some code in std/gkr/bn254Prover.go. Remove that in favor of this
   822  func (p Proof) SerializeToBigInts(outs []*big.Int) {
   823  	offset := 0
   824  	for i := range p {
   825  		for _, poly := range p[i].PartialSumPolys {
   826  			frToBigInts(outs[offset:], poly)
   827  			offset += len(poly)
   828  		}
   829  		if p[i].FinalEvalProof != nil {
   830  			finalEvalProof := p[i].FinalEvalProof.([]{{.ElementType}})
   831  			frToBigInts(outs[offset:], finalEvalProof)
   832  			offset += len(finalEvalProof)
   833  		}
   834  	}
   835  }
   836  
   837  func frToBigInts(dst []*big.Int, src []{{.ElementType}}) {
   838  	for i := range src {
   839  		src[i].BigInt(dst[i])
   840  	}
   841  }
   842  
   843  // Gates defined by name
   844  var Gates = map[string]Gate{
   845  	"identity": IdentityGate{},
   846  	"add":		AddGate{},
   847  	"sub":		SubGate{},
   848  	"neg":		NegGate{},
   849  	"mul":      MulGate(2),
   850  }
   851  
   852  type IdentityGate struct{}
   853  type AddGate struct{}
   854  type MulGate int
   855  type SubGate struct{}
   856  type NegGate struct{}
   857  
   858  func (IdentityGate) Evaluate(input ...{{.ElementType}}) {{.ElementType}} {
   859  	return input[0]
   860  }
   861  
   862  func (IdentityGate) Degree() int {
   863  	return 1
   864  }
   865  
   866  func (g AddGate) Evaluate(x ...{{.ElementType}}) (res {{.ElementType}}) {
   867  	switch len(x) {
   868  	case 0:
   869  	// set zero
   870  	case 1:
   871  		res.Set(&x[0])
   872  	default:
   873  		res.Add(&x[0], &x[1])
   874  		for i := 2; i < len(x); i++ {
   875  			res.Add(&res, &x[i])
   876  		}
   877  	}
   878  	return
   879  }
   880  
   881  func (g AddGate) Degree() int {
   882  	return 1
   883  }
   884  
   885  func (g MulGate) Evaluate(x ...{{.ElementType}}) (res {{.ElementType}}) {
   886  	if len(x) != int(g) {
   887  		panic("wrong input count")
   888  	}
   889  	switch len(x) {
   890  	case 0:
   891  		res.SetOne()
   892  	case 1:
   893  		res.Set(&x[0])
   894  	default:
   895  		res.Mul(&x[0], &x[1])
   896  		for i := 2; i < len(x); i++ {
   897  			res.Mul(&res, &x[i])
   898  		}
   899  	}
   900  	return
   901  }
   902  
   903  func (g MulGate) Degree() int {
   904  	return int(g)
   905  }
   906  
   907  func (g SubGate) Evaluate(element ...{{.ElementType}}) (diff {{.ElementType}}) {
   908  	if len(element) > 2 {
   909  		panic("not implemented") //TODO
   910  	}
   911  	diff.Sub(&element[0], &element[1])
   912  	return
   913  }
   914  
   915  func (g SubGate) Degree() int {
   916  	return 1
   917  }
   918  
   919  func (g NegGate) Evaluate(element ...{{.ElementType}}) (neg {{.ElementType}}) {
   920  	if len(element) != 1 {
   921  		panic("univariate gate")
   922  	}
   923  	neg.Neg(&element[0])
   924  	return
   925  }
   926  
   927  func (g NegGate) Degree() int {
   928  	return 1
   929  }