github.com/consensys/gnark@v0.11.0/internal/generator/backend/template/representations/gkr.go.tmpl (about) 1 import ( 2 "fmt" 3 {{- template "import_fr" .}} 4 {{- template "import_gkr" .}} 5 {{- template "import_polynomial" .}} 6 fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" 7 "github.com/consensys/gnark-crypto/utils" 8 "github.com/consensys/gnark/constraint" 9 hint "github.com/consensys/gnark/constraint/solver" 10 "github.com/consensys/gnark/internal/algo_utils" 11 "hash" 12 "math/big" 13 "sync" 14 ) 15 16 type GkrSolvingData struct { 17 assignments gkr.WireAssignment 18 circuit gkr.Circuit 19 memoryPool polynomial.Pool 20 workers *utils.WorkerPool 21 } 22 23 func convertCircuit(noPtr constraint.GkrCircuit) (gkr.Circuit, error) { 24 resCircuit := make(gkr.Circuit, len(noPtr)) 25 var found bool 26 for i := range noPtr { 27 if resCircuit[i].Gate, found = gkr.Gates[noPtr[i].Gate]; !found && noPtr[i].Gate != "" { 28 return nil, fmt.Errorf("gate \"%s\" not found", noPtr[i].Gate) 29 } 30 resCircuit[i].Inputs = algo_utils.Map(noPtr[i].Inputs, algo_utils.SlicePtrAt(resCircuit)) 31 } 32 return resCircuit, nil 33 } 34 35 func (d *GkrSolvingData) init(info constraint.GkrInfo) (assignment gkrAssignment, err error) { 36 if d.circuit, err = convertCircuit(info.Circuit); err != nil { 37 return 38 } 39 d.memoryPool = polynomial.NewPool(d.circuit.MemoryRequirements(info.NbInstances)...) 40 d.workers = utils.NewWorkerPool() 41 42 assignment = make(gkrAssignment, len(d.circuit)) 43 d.assignments = make(gkr.WireAssignment, len(d.circuit)) 44 for i := range assignment { 45 assignment[i] = d.memoryPool.Make(info.NbInstances) 46 d.assignments[&d.circuit[i]] = assignment[i] 47 } 48 return 49 } 50 51 func (d *GkrSolvingData) dumpAssignments() { 52 for _, p := range d.assignments { 53 d.memoryPool.Dump(p) 54 } 55 } 56 57 // this module assumes that wire and instance indexes respect dependencies 58 59 type gkrAssignment [][]fr.Element //gkrAssignment is indexed wire first, instance second 60 61 func (a gkrAssignment) setOuts(circuit constraint.GkrCircuit, outs []*big.Int) { 62 outsI := 0 63 for i := range circuit { 64 if circuit[i].IsOutput() { 65 for j := range a[i] { 66 a[i][j].BigInt(outs[outsI]) 67 outsI++ 68 } 69 } 70 } 71 // Check if outsI == len(outs)? 72 } 73 74 func GkrSolveHint(info constraint.GkrInfo, solvingData *GkrSolvingData) hint.Hint { 75 return func(_ *big.Int, ins, outs []*big.Int) error { 76 // assumes assignmentVector is arranged wire first, instance second in order of solution 77 circuit := info.Circuit 78 nbInstances := info.NbInstances 79 offsets := info.AssignmentOffsets() 80 assignment, err := solvingData.init(info) 81 if err != nil { 82 return err 83 } 84 chunks := circuit.Chunks(nbInstances) 85 86 solveTask := func(chunkOffset int) utils.Task { 87 return func(startInChunk, endInChunk int) { 88 start := startInChunk + chunkOffset 89 end := endInChunk + chunkOffset 90 inputs := solvingData.memoryPool.Make(info.MaxNIns) 91 dependencyHeads := make([]int, len(circuit)) 92 for wI, w := range circuit { 93 dependencyHeads[wI] = algo_utils.BinarySearchFunc(func(i int) int { 94 return w.Dependencies[i].InputInstance 95 }, len(w.Dependencies), start) 96 } 97 98 for instanceI := start; instanceI < end; instanceI++ { 99 for wireI, wire := range circuit { 100 if wire.IsInput() { 101 if dependencyHeads[wireI] < len(wire.Dependencies) && instanceI == wire.Dependencies[dependencyHeads[wireI]].InputInstance { 102 dep := wire.Dependencies[dependencyHeads[wireI]] 103 assignment[wireI][instanceI].Set(&assignment[dep.OutputWire][dep.OutputInstance]) 104 dependencyHeads[wireI]++ 105 } else { 106 assignment[wireI][instanceI].SetBigInt(ins[offsets[wireI]+instanceI-dependencyHeads[wireI]]) 107 } 108 } else { 109 // assemble the inputs 110 inputIndexes := info.Circuit[wireI].Inputs 111 for i, inputI := range inputIndexes { 112 inputs[i].Set(&assignment[inputI][instanceI]) 113 } 114 gate := solvingData.circuit[wireI].Gate 115 assignment[wireI][instanceI] = gate.Evaluate(inputs[:len(inputIndexes)]...) 116 } 117 } 118 } 119 solvingData.memoryPool.Dump(inputs) 120 } 121 } 122 123 start := 0 124 for _, end := range chunks { 125 solvingData.workers.Submit(end-start, solveTask(start), 1024).Wait() 126 start = end 127 } 128 129 assignment.setOuts(info.Circuit, outs) 130 131 return nil 132 } 133 } 134 135 func frToBigInts(dst []*big.Int, src []fr.Element) { 136 for i := range src { 137 src[i].BigInt(dst[i]) 138 } 139 } 140 141 func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { 142 143 return func(_ *big.Int, ins, outs []*big.Int) error { 144 insBytes := algo_utils.Map(ins[1:], func(i *big.Int) []byte { // the first input is dummy, just to ensure the solver's work is done before the prover is called 145 b := make([]byte, fr.Bytes) 146 i.FillBytes(b) 147 return b[:] 148 }) 149 150 hsh, err := GetHashBuilder(hashName) 151 if err != nil { 152 return err 153 } 154 155 proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh(), insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) 156 if err != nil { 157 return err 158 } 159 160 // serialize proof: TODO: In gnark-crypto? 161 offset := 0 162 for i := range proof { 163 for _, poly := range proof[i].PartialSumPolys { 164 frToBigInts(outs[offset:], poly) 165 offset += len(poly) 166 } 167 if proof[i].FinalEvalProof != nil { 168 finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) 169 frToBigInts(outs[offset:], finalEvalProof) 170 offset += len(finalEvalProof) 171 } 172 } 173 174 data.dumpAssignments() 175 176 return nil 177 178 } 179 } 180 181 // TODO: Move to gnark-crypto 182 var ( 183 hashBuilderRegistry = make(map[string]func() hash.Hash) 184 hasBuilderLock sync.RWMutex 185 ) 186 187 func RegisterHashBuilder(name string, builder func() hash.Hash) { 188 hasBuilderLock.Lock() 189 defer hasBuilderLock.Unlock() 190 hashBuilderRegistry[name] = builder 191 } 192 193 func GetHashBuilder(name string) (func() hash.Hash, error) { 194 hasBuilderLock.RLock() 195 defer hasBuilderLock.RUnlock() 196 builder, ok := hashBuilderRegistry[name] 197 if !ok { 198 return nil, fmt.Errorf("hash function not found") 199 } 200 return builder, nil 201 } 202 203 204 // For testing purposes 205 type ConstPseudoHash int 206 207 func (c ConstPseudoHash) Write(p []byte) (int, error) { 208 return len(p), nil 209 } 210 211 func (c ConstPseudoHash) Sum([]byte) []byte { 212 var x fr.Element 213 x.SetInt64(int64(c)) 214 res := x.Bytes() 215 return res[:] 216 } 217 218 func (c ConstPseudoHash) Reset() {} 219 220 func (c ConstPseudoHash) Size() int { 221 return fr.Bytes 222 } 223 224 func (c ConstPseudoHash) BlockSize() int { 225 return fr.Bytes 226 }