github.com/consensys/gnark@v0.11.0/internal/generator/backend/template/representations/gkr.go.tmpl (about)

     1  import (
     2  	"fmt"
     3  	{{- template "import_fr" .}}
     4  	{{- template "import_gkr" .}}
     5  	{{- template "import_polynomial" .}}
     6  	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
     7  	"github.com/consensys/gnark-crypto/utils"
     8  	"github.com/consensys/gnark/constraint"
     9  	hint "github.com/consensys/gnark/constraint/solver"
    10  	"github.com/consensys/gnark/internal/algo_utils"
    11  	"hash"
    12  	"math/big"
    13  	"sync"
    14  )
    15  
// GkrSolvingData is the state shared by the two GKR hints: GkrSolveHint fills it
// in (through init) and GkrProveHint reads it to produce the proof.
type GkrSolvingData struct {
	assignments gkr.WireAssignment // per-wire value buffers keyed by pointers into circuit; the same buffers init returns as a gkrAssignment
	circuit     gkr.Circuit        // circuit converted from constraint.GkrCircuit by convertCircuit
	memoryPool  polynomial.Pool    // pool the assignment buffers and the solver's scratch slices are drawn from
	workers     *utils.WorkerPool  // used to run the solving tasks; also handed to gkr.Prove
}
    22  
    23  func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) {
    24  	resCircuit := make(gkr.Circuit, len(noPtr))
    25  	var found bool
    26  	for i := range noPtr {
    27  		if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" {
    28  			return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate)
    29  		}
    30  		resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit))
    31  	}
    32  	return resCircuit, nil
    33  }
    34  
    35  func (d *GkrSolvingData) init(info constraint.GkrInfo) (assignment gkrAssignment, err error) {
    36  	if d.circuit, err = convertCircuit(info.Circuit); err != nil {
    37  		return
    38  	}
    39  	d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...)
    40  	d.workers = utils.NewWorkerPool()
    41  
    42  	assignment = make(gkrAssignment, len(d.circuit))
    43  	d.assignments = make(gkr.WireAssignment, len(d.circuit))
    44  	for i := range assignment {
    45  		assignment[i] = d.memoryPool.Make(info.NbInstances)
    46  		d.assignments[&d.circuit[i]] = assignment[i]
    47  	}
    48  	return
    49  }
    50  
    51  func (d *GkrSolvingData) dumpAssignments() {
    52  	for _, p := range d.assignments {
    53  		d.memoryPool.Dump(p)
    54  	}
    55  }
    56  
    57  // this module assumes that wire and instance indexes respect dependencies
    58  
    59  type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second
    60  
    61  func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) {
    62  	outsI := 0
    63  	for i := range circuit {
    64  		if circuit[i].IsOutput() {
    65  			for j := range a[i] {
    66  				a[i][j].BigInt(outs[outsI])
    67  				outsI++
    68  			}
    69  		}
    70  	}
    71  	// Check if outsI == len(outs)?
    72  }
    73  
    74  func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint {
    75  	return func(_ *big.Int, ins, outs []*big.Int) error {
    76  		// assumes assignmentVector is arranged wire first, instance second in order of solution
    77  		circuit := info.Circuit
    78  		nbInstances := info.NbInstances
    79  		offsets := info.AssignmentOffsets()
    80  		assignment, err := solvingData.init(info)
    81  		if err != nil {
    82  			return err
    83  		}
    84  		chunks := circuit.Chunks(nbInstances)
    85  
    86  		solveTask := func(chunkOffset int) utils.Task {
    87  			return func(startInChunk, endInChunk int) {
    88  				start := startInChunk + chunkOffset
    89  				end := endInChunk + chunkOffset
    90  				inputs := solvingData.memoryPool.Make(info.MaxNIns)
    91  				dependencyHeads := make([]int, len(circuit))
    92  				for wI, w := range circuit {
    93  					dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int {
    94  						return w.Dependencies[i].InputInstance
    95  					}, len(w.Dependencies), start)
    96  				}
    97  
    98  				for instanceI := start; instanceI < end; instanceI++ {
    99  					for wireI, wire := range circuit {
   100  						if wire.IsInput() {
   101  							if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance {
   102  								dep := wire.Dependencies[dependencyHeads[wireI]]
   103  								assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance])
   104  								dependencyHeads[wireI]++
   105  							} else {
   106  								assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]])
   107  							}
   108  						} else {
   109  							// assemble the inputs
   110  							inputIndexes := info.Circuit[wireI].Inputs
   111  							for i, inputI := range inputIndexes {
   112  								inputs[i].Set(&assignment[inputI][instanceI])
   113  							}
   114  							gate := solvingData.circuit[wireI].Gate
   115  							assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...)
   116  						}
   117  					}
   118  				}
   119  				solvingData.memoryPool.Dump(inputs)
   120  			}
   121  		}
   122  
   123  		start := 0
   124  		for _, end := range chunks {
   125  			solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait()
   126  			start = end
   127  		}
   128  
   129  		assignment.setOuts(info.Circuit, outs)
   130  
   131  		return nil
   132  	}
   133  }
   134  
   135  func frToBigInts(dst []*big.Int, src []fr.Element) {
   136  	for i := range src {
   137  		src[i].BigInt(dst[i])
   138  	}
   139  }
   140  
   141  func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint {
   142  
   143  	return func(_ *big.Int, ins, outs []*big.Int) error {
   144  		insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called
   145  			b := make([]byte, fr.Bytes)
   146  			i.FillBytes(b)
   147  			return b[:]
   148  		})
   149  
   150  		hsh, err := GetHashBuilder(hashName)
   151  		if err != nil {
   152  			return err
   153  		}
   154  
   155  		proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh(), insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers))
   156  		if err != nil {
   157  			return err
   158  		}
   159  
   160  		// serialize proof: TODO: In gnark-crypto?
   161  		offset := 0
   162  		for i := range proof {
   163  			for _, poly := range proof[i].PartialSumPolys {
   164  				frToBigInts(outs[offset:], poly)
   165  				offset += len(poly)
   166  			}
   167  			if proof[i].FinalEvalProof != nil {
   168  				finalEvalProof := proof[i].FinalEvalProof.([]fr.Element)
   169  				frToBigInts(outs[offset:], finalEvalProof)
   170  				offset += len(finalEvalProof)
   171  			}
   172  		}
   173  
   174  		data.dumpAssignments()
   175  
   176  		return nil
   177  
   178  	}
   179  }
   180  
   181  // TODO: Move to gnark-crypto
   182  var (
   183  	hashBuilderRegistry = make(map[string]func() hash.Hash)
   184  	hasBuilderLock sync.RWMutex
   185  )
   186  
   187  func RegisterHashBuilder(name string, builder func() hash.Hash) {
   188  	hasBuilderLock.Lock()
   189  	defer hasBuilderLock.Unlock()
   190  	hashBuilderRegistry[name] = builder
   191  }
   192  
   193  func GetHashBuilder(name string) (func() hash.Hash, error) {
   194  	hasBuilderLock.RLock()
   195  	defer hasBuilderLock.RUnlock()
   196  	builder, ok := hashBuilderRegistry[name]
   197  	if !ok {
   198  		return nil, fmt.Errorf("hash function not found")
   199  	}
   200  	return builder, nil
   201  }
   202  
   203  
   204  // For testing purposes
   205  type ConstPseudoHash int
   206  
   207  func (c ConstPseudoHash) Write(p []byte) (int, error) {
   208  	return len(p), nil
   209  }
   210  
   211  func (c ConstPseudoHash) Sum([]byte) []byte {
   212  	var x fr.Element
   213  	x.SetInt64(int64(c))
   214  	res := x.Bytes()
   215  	return res[:]
   216  }
   217  
   218  func (c ConstPseudoHash) Reset() {}
   219  
   220  func (c ConstPseudoHash) Size() int {
   221  	return fr.Bytes
   222  }
   223  
   224  func (c ConstPseudoHash) BlockSize() int {
   225  	return fr.Bytes
   226  }