github.com/iden3/go-circom-witnesscalc@v0.2.1-0.20230314155733-dd1f248a91b6/witnesscalc.go (about) 1 package witnesscalc 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "fmt" 7 "math" 8 "math/big" 9 "reflect" 10 "unsafe" 11 12 log "github.com/sirupsen/logrus" 13 14 wasm3 "github.com/iden3/go-wasm3" 15 ) 16 17 // witnessCalcFns are wrapper functions to the WitnessCalc WASM module 18 type witnessCalcFns struct { 19 getFrLen func() (int32, error) 20 getPRawPrime func() (int32, error) 21 getNVars func() (int32, error) 22 init func(sanityCheck int32) error 23 getSignalOffset32 func(pR, component, hashMSB, hashLSB int32) error 24 setSignal func(cIdx, component, signal, pVal int32) error 25 getPWitness func(w int32) (int32, error) 26 getWitnessBuffer func() (int32, error) 27 } 28 29 func getStack(sp unsafe.Pointer, length int) []uint64 { 30 var data = (*uint64)(sp) 31 var header reflect.SliceHeader 32 header = *(*reflect.SliceHeader)(unsafe.Pointer(&header)) 33 header.Data = uintptr(unsafe.Pointer(data)) 34 header.Len = int(length) 35 header.Cap = int(length) 36 return *(*[]uint64)(unsafe.Pointer(&header)) 37 } 38 39 func getMem(r *wasm3.Runtime, _mem unsafe.Pointer) []byte { 40 var data = (*uint8)(_mem) 41 length := r.GetAllocatedMemoryLength() 42 var header reflect.SliceHeader 43 header = *(*reflect.SliceHeader)(unsafe.Pointer(&header)) 44 header.Data = uintptr(unsafe.Pointer(data)) 45 header.Len = int(length) 46 header.Cap = int(length) 47 return *(*[]byte)(unsafe.Pointer(&header)) 48 } 49 50 func getStr(mem []byte, p uint64) string { 51 var buf bytes.Buffer 52 for ; mem[p] != 0; p++ { 53 buf.WriteByte(mem[p]) 54 } 55 return buf.String() 56 } 57 58 // newWitnessCalcFns builds the witnessCalcFns from the loaded WitnessCalc WASM 59 // module in the runtime. Imported functions (logging) are binded to dummy functions. 60 func newWitnessCalcFns(r *wasm3.Runtime, m *wasm3.Module, wc *WitnessCalculator) (*witnessCalcFns, error) { 61 r.AttachFunction("runtime", "error", "v(iiiiii)", wasm3.CallbackFunction( 62 func(runtime wasm3.RuntimeT, sp unsafe.Pointer, _mem unsafe.Pointer) int { 63 // func(code, pstr, a, b, c, d) 64 65 stack := getStack(sp, 6) 66 mem := getMem(r, _mem) 67 68 code := stack[0] 69 pstr := stack[1] 70 a := stack[2] 71 b := stack[3] 72 c := stack[4] 73 d := stack[5] 74 75 var errStr string 76 if code == 7 { 77 errStr = fmt.Sprintf("%s %v != %v %s", 78 getStr(mem, pstr), 79 wc.loadFr(int32(b)), wc.loadFr(int32(c)), getStr(mem, d)) 80 } else { 81 errStr = fmt.Sprintf("%s %v %v %v %v", 82 getStr(mem, pstr), a, b, c, getStr(mem, d)) 83 } 84 log.Errorf("WitnessCalculator WASM Error (%v): %v", code, errStr) 85 return 0 86 }, 87 )) 88 r.AttachFunction("runtime", "logSetSignal", "v(ii)", wasm3.CallbackFunction( 89 func(runtime wasm3.RuntimeT, sp unsafe.Pointer, mem unsafe.Pointer) int { 90 return 0 91 }, 92 )) 93 r.AttachFunction("runtime", "logGetSignal", "v(ii)", wasm3.CallbackFunction( 94 func(runtime wasm3.RuntimeT, sp unsafe.Pointer, mem unsafe.Pointer) int { 95 return 0 96 }, 97 )) 98 r.AttachFunction("runtime", "logFinishComponent", "v(i)", wasm3.CallbackFunction( 99 func(runtime wasm3.RuntimeT, sp unsafe.Pointer, mem unsafe.Pointer) int { 100 return 0 101 }, 102 )) 103 r.AttachFunction("runtime", "logStartComponent", "v(i)", wasm3.CallbackFunction( 104 func(runtime wasm3.RuntimeT, sp unsafe.Pointer, mem unsafe.Pointer) int { 105 return 0 106 }, 107 )) 108 r.AttachFunction("runtime", "log", "v(i)", wasm3.CallbackFunction( 109 func(runtime wasm3.RuntimeT, sp unsafe.Pointer, mem unsafe.Pointer) int { 110 return 0 111 }, 112 )) 113 114 _getFrLen, err := r.FindFunction("getFrLen") 115 if err != nil { 116 return nil, err 117 } 118 getFrLen := func() (int32, error) { 119 res, err := _getFrLen() 120 if err != nil { 121 return 0, err 122 } 123 return res.(int32), nil 124 } 125 _getPRawPrime, err := r.FindFunction("getPRawPrime") 126 if err != nil { 127 return nil, err 128 } 129 getPRawPrime := func() (int32, error) { 130 res, err := _getPRawPrime() 131 if err != nil { 132 return 0, err 133 } 134 return res.(int32), nil 135 } 136 _getNVars, err := r.FindFunction("getNVars") 137 if err != nil { 138 return nil, err 139 } 140 getNVars := func() (int32, error) { 141 res, err := _getNVars() 142 if err != nil { 143 return 0, err 144 } 145 return res.(int32), nil 146 } 147 _init, err := r.FindFunction("init") 148 if err != nil { 149 return nil, err 150 } 151 init := func(sanityCheck int32) error { 152 _, err := _init(sanityCheck) 153 if err != nil { 154 return err 155 } 156 return nil 157 } 158 _getSignalOffset32, err := r.FindFunction("getSignalOffset32") 159 if err != nil { 160 return nil, err 161 } 162 getSignalOffset32 := func(pR, component, hashMSB, hashLSB int32) error { 163 _, err := _getSignalOffset32(pR, component, hashMSB, hashLSB) 164 if err != nil { 165 return err 166 } 167 return nil 168 } 169 _setSignal, err := r.FindFunction("setSignal") 170 if err != nil { 171 return nil, err 172 } 173 setSignal := func(cIdx, component, signal, pVal int32) error { 174 _, err := _setSignal(cIdx, component, signal, pVal) 175 if err != nil { 176 return err 177 } 178 return nil 179 } 180 _getPWitness, err := r.FindFunction("getPWitness") 181 if err != nil { 182 return nil, err 183 } 184 getPWitness := func(w int32) (int32, error) { 185 res, err := _getPWitness(w) 186 if err != nil { 187 return 0, err 188 } 189 return res.(int32), nil 190 } 191 _getWitnessBuffer, err := r.FindFunction("getWitnessBuffer") 192 if err != nil { 193 return nil, err 194 } 195 getWitnessBuffer := func() (int32, error) { 196 res, err := _getWitnessBuffer() 197 if err != nil { 198 return 0, err 199 } 200 return res.(int32), nil 201 } 202 203 return &witnessCalcFns{ 204 getFrLen: getFrLen, 205 getPRawPrime: getPRawPrime, 206 getNVars: getNVars, 207 init: init, 208 getSignalOffset32: getSignalOffset32, 209 setSignal: setSignal, 210 getPWitness: getPWitness, 211 getWitnessBuffer: getWitnessBuffer, 212 }, nil 213 } 214 215 // WitnessJSON is a wrapper type to Marshal the Witness in JSON format 216 type WitnessJSON []*big.Int 217 218 // MarshalJSON marshals the WitnessJSON where each value is encoded in base 10 219 // as a string in an array. 220 func (w WitnessJSON) MarshalJSON() ([]byte, error) { 221 var buffer bytes.Buffer 222 buffer.WriteString("[") 223 for i, bi := range w { 224 buffer.WriteString(`"` + bi.String() + `"`) 225 if i != len(w)-1 { 226 buffer.WriteString(",") 227 } 228 } 229 buffer.WriteString("]") 230 return buffer.Bytes(), nil 231 } 232 233 // loadBigInt loads a *big.Int from the runtime memory at position p. 234 func loadBigInt(runtime *wasm3.Runtime, p int32, n int32) *big.Int { 235 bigIntBytes := make([]byte, n) 236 copy(bigIntBytes, runtime.Memory()[p:p+n]) 237 return new(big.Int).SetBytes(swap(bigIntBytes)) 238 } 239 240 // WitnessCalculator is the object that allows performing witness calculation 241 // from signal inputs using the WitnessCalc WASM module. 242 type WitnessCalculator struct { 243 n32 int32 244 prime *big.Int 245 mask32 *big.Int 246 nVars int32 247 n64 uint 248 r *big.Int 249 rInv *big.Int 250 251 shortMax *big.Int 252 shortMin *big.Int 253 254 runtime *wasm3.Runtime 255 fns *witnessCalcFns 256 } 257 258 // NewWitnessCalculator creates a new WitnessCalculator from the WitnessCalc 259 // loaded WASM module in the runtime. 260 func NewWitnessCalculator(runtime *wasm3.Runtime, module *wasm3.Module) (*WitnessCalculator, error) { 261 var wc WitnessCalculator 262 fns, err := newWitnessCalcFns(runtime, module, &wc) 263 if err != nil { 264 return nil, err 265 } 266 267 n32, err := fns.getFrLen() 268 if err != nil { 269 return nil, err 270 } 271 // n32 = (n32 >> 2) - 2 272 n32 = n32 - 8 273 274 pRawPrime, err := fns.getPRawPrime() 275 if err != nil { 276 return nil, err 277 } 278 279 prime := loadBigInt(runtime, pRawPrime, n32) 280 281 mask32 := new(big.Int).SetUint64(0xFFFFFFFF) 282 nVars, err := fns.getNVars() 283 if err != nil { 284 return nil, err 285 } 286 287 n64 := uint(((prime.BitLen() - 1) / 64) + 1) 288 r := new(big.Int).SetInt64(1) 289 r.Lsh(r, n64*64) 290 rInv := new(big.Int).ModInverse(r, prime) 291 292 shortMax, ok := new(big.Int).SetString("0x80000000", 0) 293 if !ok { 294 return nil, fmt.Errorf("unable to set shortMax from string") 295 } 296 shortMin := new(big.Int).Set(prime) 297 shortMin.Sub(shortMin, shortMax) 298 299 wc.n32 = n32 300 wc.prime = prime 301 wc.mask32 = mask32 302 wc.nVars = nVars 303 wc.n64 = n64 304 wc.r = r 305 wc.rInv = rInv 306 wc.shortMin = shortMin 307 wc.shortMax = shortMax 308 wc.runtime = runtime 309 wc.fns = fns 310 return &wc, nil 311 } 312 313 // loadBigInt loads a *big.Int from the runtime memory at position p. 314 func (wc *WitnessCalculator) loadBigInt(p int32, n int32) *big.Int { 315 return loadBigInt(wc.runtime, p, n) 316 } 317 318 var zero32 [32]byte 319 320 // storeBigInt stores a *big.Int into the runtime memory at position p. 321 func (wc *WitnessCalculator) storeBigInt(p int32, v *big.Int) { 322 bigIntBytes := swap(v.Bytes()) 323 copy(wc.runtime.Memory()[p:p+32], zero32[:]) 324 copy(wc.runtime.Memory()[p:p+int32(len(bigIntBytes))], bigIntBytes) 325 } 326 327 // memFreePos gives the next free runtime memory position. 328 func (wc *WitnessCalculator) memFreePos() int32 { 329 return int32(binary.LittleEndian.Uint32(wc.runtime.Memory()[:4])) 330 } 331 332 // setMemFreePos sets the next free runtime memory position. 333 func (wc *WitnessCalculator) setMemFreePos(p int32) { 334 binary.LittleEndian.PutUint32(wc.runtime.Memory()[:4], uint32(p)) 335 } 336 337 // allocInt reserves space in the runtime memory and returns its position. 338 func (wc *WitnessCalculator) allocInt() int32 { 339 p := wc.memFreePos() 340 wc.setMemFreePos(p + 8) 341 return p 342 } 343 344 // allocFr reserves space in the runtime memory for a Field element and returns its position. 345 func (wc *WitnessCalculator) allocFr() int32 { 346 p := wc.memFreePos() 347 wc.setMemFreePos(p + wc.n32*4 + 8) 348 return p 349 } 350 351 // getInt loads an int32 from the runtime memory at position p. 352 func (wc *WitnessCalculator) getInt(p int32) int32 { 353 return int32(binary.LittleEndian.Uint32(wc.runtime.Memory()[p : p+4])) 354 } 355 356 // setInt stores an int32 in the runtime memory at position p. 357 func (wc *WitnessCalculator) setInt(p, v int32) { 358 binary.LittleEndian.PutUint32(wc.runtime.Memory()[p:p+4], uint32(v)) 359 } 360 361 // setShortPositive stores a small positive Field element in the runtime memory at position p. 362 func (wc *WitnessCalculator) setShortPositive(p int32, v *big.Int) { 363 if !v.IsInt64() || v.Int64() >= 0x80000000 { 364 panic(fmt.Errorf("v should be < 0x80000000")) 365 } 366 wc.setInt(p, int32(v.Int64())) 367 wc.setInt(p+4, 0) 368 } 369 370 // setShortPositive stores a small negative *big.Int in the runtime memory at position p. 371 func (wc *WitnessCalculator) setShortNegative(p int32, v *big.Int) { 372 vNeg := new(big.Int).Set(wc.prime) // prime 373 vNeg.Sub(vNeg, wc.shortMax) // prime - max 374 vNeg.Sub(v, vNeg) // v - (prime - max) 375 vNeg.Add(wc.shortMax, vNeg) // max + (v - (prime - max)) 376 if !vNeg.IsInt64() || vNeg.Int64() < 0x80000000 || vNeg.Int64() >= 0x80000000*2 { 377 panic(fmt.Errorf("v should be < 0x80000000")) 378 } 379 wc.setInt(p, int32(vNeg.Int64())) 380 wc.setInt(p+4, 0) 381 } 382 383 // setShortPositive stores a normal Field element in the runtime memory at position p. 384 func (wc *WitnessCalculator) setLongNormal(p int32, v *big.Int) { 385 wc.setInt(p, 0) 386 wc.setInt(p+4, math.MinInt32) // math.MinInt32 = 0x80000000 387 wc.storeBigInt(p+8, v) 388 } 389 390 // storeFr stores a Field element in the runtime memory at position p. 391 func (wc *WitnessCalculator) storeFr(p int32, v *big.Int) { 392 if v.Cmp(wc.shortMax) == -1 { 393 wc.setShortPositive(p, v) 394 } else if v.Cmp(wc.shortMin) >= 0 { 395 wc.setShortNegative(p, v) 396 } else { 397 wc.setLongNormal(p, v) 398 } 399 } 400 401 // fromMontgomery transforms a Field element from Montgomery form to regular form. 402 func (wc *WitnessCalculator) fromMontgomery(v *big.Int) *big.Int { 403 res := new(big.Int).Set(v) 404 res.Mul(res, wc.rInv) 405 res.Mod(res, wc.prime) 406 return res 407 } 408 409 // loadFr loads a Field element from the runtime memory at position p. 410 func (wc *WitnessCalculator) loadFr(p int32) *big.Int { 411 m := wc.runtime.Memory() 412 if (m[p+4+3] & 0x80) != 0 { 413 res := wc.loadBigInt(p+8, wc.n32) 414 if (m[p+4+3] & 0x40) != 0 { 415 return wc.fromMontgomery(res) 416 } else { 417 return res 418 } 419 } else { 420 if (m[p+3] & 0x40) != 0 { 421 res := wc.loadBigInt(p, 4) // res 422 res.Sub(res, wc.shortMax) // res - max 423 res.Add(wc.prime, res) // res - max + prime 424 res.Sub(res, wc.shortMax) // res - max + (prime - max) 425 return res 426 } else { 427 return wc.loadBigInt(p, 4) 428 } 429 } 430 } 431 432 // doCalculateWitness is an internal function that calculates the witness. 433 func (wc *WitnessCalculator) doCalculateWitness(inputs map[string]interface{}, sanityCheck bool) error { 434 sanityCheckVal := int32(0) 435 if sanityCheck { 436 sanityCheckVal = 1 437 } 438 if err := wc.fns.init(sanityCheckVal); err != nil { 439 return err 440 } 441 pSigOffset := wc.allocInt() 442 pFr := wc.allocFr() 443 444 for inputName, inputValue := range inputs { 445 hMSB, hLSB := fnvHash(inputName) 446 wc.fns.getSignalOffset32(pSigOffset, 0, hMSB, hLSB) 447 sigOffset := wc.getInt(pSigOffset) 448 fSlice := flatSlice(inputValue) 449 for i, value := range fSlice { 450 wc.storeFr(pFr, value) 451 wc.fns.setSignal(0, 0, sigOffset+int32(i), pFr) 452 } 453 } 454 455 return nil 456 } 457 458 // CalculateWitness calculates the witness given the inputs. 459 func (wc *WitnessCalculator) CalculateWitness(inputs map[string]interface{}, sanityCheck bool) ([]*big.Int, error) { 460 oldMemFreePos := wc.memFreePos() 461 462 if err := wc.doCalculateWitness(inputs, sanityCheck); err != nil { 463 return nil, err 464 } 465 466 w := make([]*big.Int, wc.nVars) 467 for i := int32(0); i < wc.nVars; i++ { 468 pWitness, err := wc.fns.getPWitness(i) 469 if err != nil { 470 return nil, err 471 } 472 w[i] = wc.loadFr(pWitness) 473 } 474 475 wc.setMemFreePos(oldMemFreePos) 476 return w, nil 477 } 478 479 // CalculateWitness calculates the witness in binary given the inputs. 480 func (wc *WitnessCalculator) CalculateBinWitness(inputs map[string]interface{}, sanityCheck bool) ([]byte, error) { 481 oldMemFreePos := wc.memFreePos() 482 483 if err := wc.doCalculateWitness(inputs, sanityCheck); err != nil { 484 return nil, err 485 } 486 pWitnessBuff, err := wc.fns.getWitnessBuffer() 487 if err != nil { 488 return nil, err 489 } 490 witnessBuff := make([]byte, uint(wc.nVars)*wc.n64*8) 491 copy(witnessBuff, wc.runtime.Memory()[pWitnessBuff:int(pWitnessBuff)+len(witnessBuff)]) 492 493 wc.setMemFreePos(oldMemFreePos) 494 return witnessBuff, nil 495 }