github.com/iden3/go-circom-witnesscalc@v0.2.1-0.20230314155733-dd1f248a91b6/circom2witnesscalc.go (about) 1 package witnesscalc 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "fmt" 7 "math/big" 8 9 "github.com/wasmerio/wasmer-go/wasmer" 10 ) 11 12 // Circom2WitnessCalculator is the object that allows performing witness calculation 13 // from signal inputs using the WitnessCalc WASM module. 14 type Circom2WitnessCalculator struct { 15 instance *wasmer.Instance 16 sanityCheck bool 17 n32 int32 18 version int32 19 witnessSize int32 20 init wasmer.NativeFunction 21 getFieldNumLen32 wasmer.NativeFunction 22 getInputSignalSize wasmer.NativeFunction 23 getInputSize wasmer.NativeFunction 24 getRawPrime wasmer.NativeFunction 25 getVersion wasmer.NativeFunction 26 getWitness wasmer.NativeFunction 27 readSharedRWMemory wasmer.NativeFunction 28 setInputSignal wasmer.NativeFunction 29 writeSharedRWMemory wasmer.NativeFunction 30 } 31 32 // NewCircom2WitnessCalculator creates a new WitnessCalculator from the WitnessCalc 33 // loaded WASM module in the runtime. 34 func NewCircom2WitnessCalculator(wasmBytes []byte, sanityCheck bool) (*Circom2WitnessCalculator, error) { 35 engine := wasmer.NewEngine() 36 store := wasmer.NewStore(engine) 37 38 // Compiles the module 39 module, _ := wasmer.NewModule(store, wasmBytes) 40 41 limits, err := wasmer.NewLimits(2000, 100000) 42 if err != nil { 43 return nil, err 44 } 45 46 memType := wasmer.NewMemoryType(limits) 47 48 memory := wasmer.NewMemory(store, memType) 49 50 // Instantiates the module 51 importObject := wasmer.NewImportObject() 52 53 importObject.Register("env", map[string]wasmer.IntoExtern{ 54 "memory": memory, 55 }) 56 57 importObject.Register("runtime", map[string]wasmer.IntoExtern{ 58 "exceptionHandler": getExceptionHandler(store), 59 "showSharedRWMemory": getShowSharedRWMemory(store), 60 "log": getLog(store), 61 }) 62 63 instance, err := wasmer.NewInstance(module, importObject) 64 if err != nil { 65 return nil, err 66 } 67 68 // Gets the `init` exported function from the WebAssembly instance. 69 init, err := instance.Exports.GetFunction("init") 70 if err != nil { 71 return nil, err 72 } 73 74 // Calls that exported function with Go standard values. The WebAssembly 75 // types are inferred and values are casted automatically. 76 _, err = init(1) 77 if err != nil { 78 return nil, err 79 } 80 81 getFieldNumLen32, err := instance.Exports.GetFunction("getFieldNumLen32") 82 if err != nil { 83 return nil, err 84 } 85 n32, err := getFieldNumLen32() 86 if err != nil { 87 return nil, err 88 } 89 90 // this function is missing in wasm files generated with circom version prior to v2.0.4 91 getInputSignalSize, _ := instance.Exports.GetFunction("getInputSignalSize") 92 93 getInputSize, err := instance.Exports.GetFunction("getInputSize") 94 if err != nil { 95 return nil, err 96 } 97 98 getRawPrime, err := instance.Exports.GetFunction("getRawPrime") 99 if err != nil { 100 return nil, err 101 } 102 103 getVersion, err := instance.Exports.GetFunction("getVersion") 104 if err != nil { 105 return nil, err 106 } 107 108 version, err := getVersion() 109 if err != nil { 110 return nil, err 111 } 112 113 getWitness, err := instance.Exports.GetFunction("getWitness") 114 if err != nil { 115 return nil, err 116 } 117 118 getWitnessSize, err := instance.Exports.GetFunction("getWitnessSize") 119 if err != nil { 120 return nil, err 121 } 122 123 witnessSize, err := getWitnessSize() 124 if err != nil { 125 return nil, err 126 } 127 128 setInputSignal, err := instance.Exports.GetFunction("setInputSignal") 129 if err != nil { 130 return nil, err 131 } 132 133 readSharedRWMemory, err := instance.Exports.GetFunction("readSharedRWMemory") 134 if err != nil { 135 return nil, err 136 } 137 138 writeSharedRWMemory, err := instance.Exports.GetFunction("writeSharedRWMemory") 139 if err != nil { 140 return nil, err 141 } 142 143 return &Circom2WitnessCalculator{ 144 instance: instance, 145 sanityCheck: sanityCheck, 146 n32: n32.(int32), 147 version: version.(int32), 148 witnessSize: witnessSize.(int32), 149 init: init, 150 getFieldNumLen32: getFieldNumLen32, 151 getInputSignalSize: getInputSignalSize, 152 getInputSize: getInputSize, 153 getRawPrime: getRawPrime, 154 getWitness: getWitness, 155 getVersion: getVersion, 156 setInputSignal: setInputSignal, 157 readSharedRWMemory: readSharedRWMemory, 158 writeSharedRWMemory: writeSharedRWMemory, 159 }, nil 160 } 161 162 // CalculateWitness calculates the witness given the inputs. 163 func (wc *Circom2WitnessCalculator) CalculateWitness(inputs map[string]interface{}, sanityCheck bool) ([]*big.Int, error) { 164 165 w := make([]*big.Int, wc.witnessSize) 166 167 err := wc.doCalculateWitness(inputs, sanityCheck) 168 if err != nil { 169 return nil, err 170 } 171 172 for i := 0; i < int(wc.witnessSize); i++ { 173 _, err := wc.getWitness(i) 174 if err != nil { 175 return nil, err 176 } 177 arr := make([]uint32, wc.n32) 178 for j := 0; j < int(wc.n32); j++ { 179 val, err := wc.readSharedRWMemory(int32(j)) 180 if err != nil { 181 return nil, err 182 } 183 arr[int(wc.n32)-1-j] = uint32(val.(int32)) 184 } 185 w[i] = fromArray32(arr) 186 } 187 188 return w, nil 189 } 190 191 // CalculateBinWitness calculates the witness in binary given the inputs. 192 func (wc *Circom2WitnessCalculator) CalculateBinWitness(inputs map[string]interface{}, sanityCheck bool) ([]byte, error) { 193 buff := new(bytes.Buffer) 194 195 err := wc.doCalculateWitness(inputs, sanityCheck) 196 if err != nil { 197 return nil, err 198 } 199 200 for i := 0; i < int(wc.witnessSize); i++ { 201 _, err := wc.getWitness(i) 202 if err != nil { 203 return nil, err 204 } 205 206 for j := 0; j < int(wc.n32); j++ { 207 val, err := wc.readSharedRWMemory(j) 208 if err != nil { 209 return nil, err 210 } 211 _ = binary.Write(buff, binary.LittleEndian, uint32(val.(int32))) 212 } 213 } 214 215 return buff.Bytes(), nil 216 } 217 218 // CalculateWTNSBin calculates the witness in binary given the inputs. 219 func (wc *Circom2WitnessCalculator) CalculateWTNSBin(inputs map[string]interface{}, sanityCheck bool) ([]byte, error) { 220 buff := new(bytes.Buffer) 221 222 err := wc.doCalculateWitness(inputs, sanityCheck) 223 if err != nil { 224 return nil, err 225 } 226 227 buff.Grow(int(wc.witnessSize*wc.n32 + wc.n32 + 11)) 228 229 // wtns 230 _ = buff.WriteByte('w') 231 _ = buff.WriteByte('t') 232 _ = buff.WriteByte('n') 233 _ = buff.WriteByte('s') 234 235 //version 2 236 _ = binary.Write(buff, binary.LittleEndian, uint32(2)) 237 238 //number of sections: 2 239 _ = binary.Write(buff, binary.LittleEndian, uint32(2)) 240 241 //id section 1 242 _ = binary.Write(buff, binary.LittleEndian, uint32(1)) 243 244 n8 := wc.n32 * 4 245 //id section 1 length in 64bytes 246 idSection1length := 8 + n8 247 _ = binary.Write(buff, binary.LittleEndian, uint64(idSection1length)) 248 249 //this.n32 250 _ = binary.Write(buff, binary.LittleEndian, uint32(n8)) 251 252 //prime number 253 _, err = wc.getRawPrime() 254 if err != nil { 255 return nil, err 256 } 257 258 for j := 0; j < int(wc.n32); j++ { 259 val, err := wc.readSharedRWMemory(int32(j)) 260 if err != nil { 261 return nil, err 262 } 263 _ = binary.Write(buff, binary.LittleEndian, uint32(val.(int32))) 264 } 265 266 // witness size 267 _ = binary.Write(buff, binary.LittleEndian, uint32(wc.witnessSize)) 268 269 //id section 2 270 _ = binary.Write(buff, binary.LittleEndian, uint32(2)) 271 272 // section 2 length 273 idSection2length := n8 * wc.witnessSize 274 _ = binary.Write(buff, binary.LittleEndian, uint64(idSection2length)) 275 276 for i := 0; i < int(wc.witnessSize); i++ { 277 _, err := wc.getWitness(i) 278 if err != nil { 279 return nil, err 280 } 281 282 for j := 0; j < int(wc.n32); j++ { 283 val, err := wc.readSharedRWMemory(j) 284 if err != nil { 285 return nil, err 286 } 287 _ = binary.Write(buff, binary.LittleEndian, uint32(val.(int32))) 288 } 289 } 290 291 return buff.Bytes(), nil 292 } 293 294 // CalculateWitness calculates the witness given the inputs. 295 func (wc *Circom2WitnessCalculator) doCalculateWitness(inputs map[string]interface{}, sanityCheck bool) error { 296 //input is assumed to be a map from signals to arrays of bigInts 297 sanityCheckVal := int32(0) 298 if sanityCheck { 299 sanityCheckVal = 1 300 } 301 _, err := wc.init(sanityCheckVal) 302 if err != nil { 303 return err 304 } 305 306 inputCounter := 0 307 for inputName, inputValue := range inputs { 308 hMSB, hLSB := fnvHash(inputName) 309 fSlice := flatSlice(inputValue) 310 311 if wc.getInputSignalSize != nil { 312 signalSize, err := wc.getInputSignalSize(hMSB, hLSB) 313 if err != nil { 314 return err 315 } 316 317 if signalSize.(int32) < 0 { 318 return fmt.Errorf("signal %s not found", inputName) 319 } 320 if len(fSlice) < int(signalSize.(int32)) { 321 return fmt.Errorf("not enough values for input signal %s", inputName) 322 } 323 if len(fSlice) > int(signalSize.(int32)) { 324 return fmt.Errorf("too many values for input signal %s", inputName) 325 } 326 } 327 328 for i := 0; i < len(fSlice); i++ { 329 arrFr, err := toArray32(fSlice[i], int(wc.n32)) 330 if err != nil { 331 return err 332 } 333 for j := 0; j < int(wc.n32); j++ { 334 _, err := wc.writeSharedRWMemory(j, int32(arrFr[int(wc.n32)-1-j])) 335 if err != nil { 336 return err 337 } 338 } 339 _, err = wc.setInputSignal(hMSB, hLSB, i) 340 if err != nil { 341 return err 342 } 343 inputCounter++ 344 } 345 } 346 inputSize, err := wc.getInputSize() 347 if inputCounter < int(inputSize.(int32)) { 348 return fmt.Errorf("not all inputs have been set: only %d out of %d", inputCounter, inputSize) 349 } 350 return nil 351 } 352 353 func getExceptionHandler(store *wasmer.Store) wasmer.IntoExtern { 354 function := wasmer.NewFunction( 355 store, 356 wasmer.NewFunctionType( 357 wasmer.NewValueTypes(wasmer.I32), // one i32 argument 358 wasmer.NewValueTypes(), // zero results 359 ), 360 func(args []wasmer.Value) ([]wasmer.Value, error) { 361 if len(args) > 0 { 362 code := args[0].I32() 363 var errStr string 364 if code == 1 { 365 errStr = "Signal not found. " 366 } else if code == 2 { 367 errStr = "Too many signals set. " 368 } else if code == 3 { 369 errStr = "Signal already set. " 370 } else if code == 4 { 371 errStr = "Assert Failed. " 372 } else if code == 5 { 373 errStr = "Not enough memory. " 374 } else if code == 6 { 375 errStr = "Input signal array access exceeds the size" 376 } else { 377 errStr = "Unknown error" 378 } 379 fmt.Println(errStr) 380 } 381 return []wasmer.Value{}, nil 382 }, 383 ) 384 return function 385 } 386 387 func getShowSharedRWMemory(store *wasmer.Store) wasmer.IntoExtern { 388 function := wasmer.NewFunction( 389 store, 390 wasmer.NewFunctionType( 391 wasmer.NewValueTypes(), 392 wasmer.NewValueTypes(), 393 ), 394 func(args []wasmer.Value) ([]wasmer.Value, error) { 395 return []wasmer.Value{}, nil 396 }, 397 ) 398 return function 399 } 400 401 func getLog(store *wasmer.Store) wasmer.IntoExtern { 402 function := wasmer.NewFunction( 403 store, 404 wasmer.NewFunctionType( 405 wasmer.NewValueTypes(), 406 wasmer.NewValueTypes(), 407 ), 408 func(args []wasmer.Value) ([]wasmer.Value, error) { 409 return []wasmer.Value{}, nil 410 }, 411 ) 412 return function 413 } 414 415 func toArray32(s *big.Int, size int) ([]uint32, error) { 416 res := make([]uint32, size) 417 rem := s 418 419 radix := big.NewInt(0x100000000) 420 zero := big.NewInt(0) 421 i := size - 1 422 // while not zero rem 423 for rem.Cmp(zero) != 0 { 424 res[i] = uint32(new(big.Int).Mod(rem, radix).Uint64()) 425 rem.Div(rem, radix) 426 i-- 427 } 428 return res, nil 429 } 430 431 func fromArray32(arr []uint32) *big.Int { 432 res := new(big.Int) 433 radix := big.NewInt(0x100000000) 434 for i := 0; i < len(arr); i++ { 435 res.Mul(res, radix) 436 res.Add(res, big.NewInt(int64(arr[i]))) 437 } 438 return res 439 }