github.com/iden3/go-circom-witnesscalc@v0.2.1-0.20230314155733-dd1f248a91b6/witnesscalc_test.go (about) 1 package witnesscalc 2 3 import ( 4 "encoding/hex" 5 "encoding/json" 6 "fmt" 7 "io/ioutil" 8 "log" 9 "math" 10 "math/big" 11 "os" 12 "os/exec" 13 "path" 14 "strings" 15 "testing" 16 "time" 17 18 wasm3 "github.com/iden3/go-wasm3" 19 "github.com/stretchr/testify/assert" 20 "github.com/stretchr/testify/require" 21 ) 22 23 type TestParams struct { 24 wasmFilename string 25 inputsFilename string 26 prime string 27 nVars int32 28 r string 29 rInv string 30 witness string 31 } 32 33 func TestWitnessCalcMyCircuit1(t *testing.T) { 34 testWitnessCalc(t, TestParams{ 35 wasmFilename: "test_files/mycircuit.wasm", 36 inputsFilename: "test_files/mycircuit-input1.json", 37 prime: "21888242871839275222246405745257275088548364400416034343698204186575808495617", 38 nVars: 4, 39 r: "115792089237316195423570985008687907853269984665640564039457584007913129639936", 40 rInv: "9915499612839321149637521777990102151350674507940716049588462388200839649614", 41 witness: `["1","33","3","11"]`, 42 }, true) 43 } 44 45 func TestWitnessCalcMyCircuit2(t *testing.T) { 46 testWitnessCalc(t, TestParams{ 47 wasmFilename: "test_files/mycircuit.wasm", 48 inputsFilename: "test_files/mycircuit-input2.json", 49 prime: "21888242871839275222246405745257275088548364400416034343698204186575808495617", 50 nVars: 4, 51 r: "115792089237316195423570985008687907853269984665640564039457584007913129639936", 52 rInv: "9915499612839321149637521777990102151350674507940716049588462388200839649614", 53 witness: `["1","21888242871839275222246405745257275088548364400416034343698204186575672693159","21888242871839275222246405745257275088548364400416034343698204186575796149939","11"]`, 54 }, true) 55 } 56 57 func TestWitnessCalcMyCircuit3(t *testing.T) { 58 testWitnessCalc(t, TestParams{ 59 wasmFilename: "test_files/mycircuit.wasm", 60 inputsFilename: "test_files/mycircuit-input3.json", 61 prime: "21888242871839275222246405745257275088548364400416034343698204186575808495617", 62 nVars: 4, 63 r: "115792089237316195423570985008687907853269984665640564039457584007913129639936", 64 rInv: "9915499612839321149637521777990102151350674507940716049588462388200839649614", 65 witness: `["1","21888242871839275222246405745257275088548364400416034343698204186575808493616","10944121435919637611123202872628637544274182200208017171849102093287904246808","2"]`, 66 }, true) 67 } 68 69 func TestWitnessCalcSmtVerifier10(t *testing.T) { 70 witnessJSON, err := ioutil.ReadFile("test_files/smtverifier10-witness.json") 71 if err != nil { 72 panic(err) 73 } 74 testWitnessCalc(t, TestParams{ 75 wasmFilename: "test_files/smtverifier10.wasm", 76 inputsFilename: "test_files/smtverifier10-input.json", 77 prime: "21888242871839275222246405745257275088548364400416034343698204186575808495617", 78 nVars: 4794, 79 r: "115792089237316195423570985008687907853269984665640564039457584007913129639936", 80 rInv: "9915499612839321149637521777990102151350674507940716049588462388200839649614", 81 witness: string(witnessJSON), 82 }, false) 83 } 84 85 var testNConstraints = false 86 87 func TestWitnessCalcNConstraints(t *testing.T) { 88 if !testNConstraints { 89 return 90 } 91 oldWd, err := os.Getwd() 92 require.Nil(t, err) 93 defer func() { 94 err := os.Chdir(oldWd) 95 require.Nil(t, err) 96 }() 97 err = os.Chdir(path.Join(oldWd, "test_files")) 98 require.Nil(t, err) 99 100 for i := 1; i < 8; i++ { 101 // for i := 1; i < 3; i++ { 102 n := int(math.Pow10(i)) 103 log.Printf("WitnessCalc with %v constraints\n", n) 104 err := exec.Command("cp", "nconstraints.circom", "nconstraints.circom.tmp").Run() 105 require.Nil(t, err) 106 err = exec.Command("sed", "-i", fmt.Sprintf("s/{{N}}/%v/g", n), "nconstraints.circom.tmp").Run() 107 require.Nil(t, err) 108 start := time.Now() 109 err = exec.Command("./node_modules/.bin/circom", "nconstraints.circom.tmp", "-w", fmt.Sprintf("nconstraints-%v.wasm", n)).Run() 110 if err != nil { 111 fmt.Println(err) 112 } 113 elapsed := time.Since(start) 114 require.Nil(t, err) 115 log.Printf("Circuit compilation took %v\n", elapsed) 116 117 wasmFilename := fmt.Sprintf("nconstraints-%v.wasm", n) 118 var inputs = map[string]interface{}{"in": new(big.Int).SetInt64(2)} 119 120 runtime := wasm3.NewRuntime(&wasm3.Config{ 121 Environment: wasm3.NewEnvironment(), 122 StackSize: 64 * 1024, 123 }) 124 wasmBytes, err := ioutil.ReadFile(wasmFilename) 125 require.Nil(t, err) 126 module, err := runtime.ParseModule(wasmBytes) 127 require.Nil(t, err) 128 module, err = runtime.LoadModule(module) 129 require.Nil(t, err) 130 witnessCalculator, err := NewWitnessCalculator(runtime, module) 131 require.Nil(t, err) 132 p := witnessCalculator.prime 133 start = time.Now() 134 w, err := witnessCalculator.CalculateWitness(inputs, false) 135 elapsed = time.Since(start) 136 require.Nil(t, err) 137 log.Printf("Witness calculation took %v\n", elapsed) 138 139 runtime.Destroy() 140 141 out := new(big.Int).SetInt64(2) 142 for i := 1; i < n; i++ { 143 out.Mul(out, out) 144 out.Add(out, new(big.Int).SetInt64(int64(i))) 145 out.Mod(out, p) 146 } 147 148 assert.Equal(t, out, w[1]) 149 150 err = os.Remove("nconstraints.circom.tmp") 151 require.Nil(t, err) 152 err = os.Remove(fmt.Sprintf("nconstraints-%v.wasm", n)) 153 require.Nil(t, err) 154 } 155 } 156 157 func testWitnessCalc(t *testing.T, p TestParams, logWitness bool) { 158 log.Print("Initializing WASM3") 159 160 runtime := wasm3.NewRuntime(&wasm3.Config{ 161 Environment: wasm3.NewEnvironment(), 162 StackSize: 64 * 1024, 163 }) 164 log.Println("Runtime ok") 165 // err := runtime.ResizeMemory(16) 166 // if err != nil { 167 // panic(err) 168 // } 169 170 // log.Println("Runtime Memory len: ", len(runtime.Memory())) 171 172 wasmBytes, err := ioutil.ReadFile(p.wasmFilename) 173 require.Nil(t, err) 174 log.Printf("Read WASM module (%d bytes)\n", len(wasmBytes)) 175 176 module, err := runtime.ParseModule(wasmBytes) 177 require.Nil(t, err) 178 module, err = runtime.LoadModule(module) 179 require.Nil(t, err) 180 log.Print("Loaded module") 181 182 // fmt.Printf("NumImports: %v\n", module.NumImports()) 183 // fns, err := NewWitnessCalcFns(runtime, module) 184 // if err != nil { 185 // panic(err) 186 // } 187 188 inputsBytes, err := ioutil.ReadFile(p.inputsFilename) 189 require.Nil(t, err) 190 inputs, err := ParseInputs(inputsBytes) 191 require.Nil(t, err) 192 log.Print("Inputs: ", inputs) 193 194 witnessCalculator, err := NewWitnessCalculator(runtime, module) 195 require.Nil(t, err) 196 log.Print("n32: ", witnessCalculator.n32) 197 log.Print("prime: ", witnessCalculator.prime) 198 log.Print("mask32: ", witnessCalculator.mask32) 199 log.Print("nVars: ", witnessCalculator.nVars) 200 log.Print("n64: ", witnessCalculator.n64) 201 log.Print("r: ", witnessCalculator.r) 202 log.Print("rInv: ", witnessCalculator.rInv) 203 204 assert.Equal(t, p.prime, witnessCalculator.prime.String()) 205 assert.Equal(t, p.r, witnessCalculator.r.String()) 206 assert.Equal(t, p.rInv, witnessCalculator.rInv.String()) 207 assert.Equal(t, p.nVars, witnessCalculator.nVars) 208 209 start := time.Now() 210 w, err := witnessCalculator.CalculateWitness(inputs, false) 211 elapsed := time.Since(start) 212 require.Nil(t, err) 213 log.Printf("Witness calculation took %v\n", elapsed) 214 if logWitness { 215 log.Print("Witness: ", w) 216 } 217 wJSON, err := json.Marshal(WitnessJSON(w)) 218 require.Nil(t, err) 219 if logWitness { 220 log.Print("Witness JSON: ", string(wJSON)) 221 } 222 pWitness := strings.ReplaceAll(p.witness, " ", "") 223 pWitness = strings.ReplaceAll(pWitness, "\n", "") 224 assert.Equal(t, pWitness, string(wJSON)) 225 226 // DEBUG 227 // { 228 // elemsCalc := strings.Split(string(wJSON), ",") 229 // elems := strings.Split(p.witness, ",") 230 // if len(elemsCalc) != len(elems) { 231 // panic(fmt.Errorf("Witness length differs: %v, %v", len(elemsCalc), len(elems))) 232 // } 233 // for i := 0; i < len(elems); i++ { 234 // fmt.Printf("exp %v\ngot %v\n\n", elems[i], elemsCalc[i]) 235 // } 236 // } 237 wb, err := witnessCalculator.CalculateBinWitness(inputs, false) 238 require.Nil(t, err) 239 if logWitness { 240 log.Print("WitnessBin: ", hex.EncodeToString(wb)) 241 } 242 }