github.com/iden3/go-circom-witnesscalc@v0.2.1-0.20230314155733-dd1f248a91b6/witnesscalc_test.go (about)

     1  package witnesscalc
     2  
     3  import (
     4  	"encoding/hex"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"log"
     9  	"math"
    10  	"math/big"
    11  	"os"
    12  	"os/exec"
    13  	"path"
    14  	"strings"
    15  	"testing"
    16  	"time"
    17  
    18  	wasm3 "github.com/iden3/go-wasm3"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  type TestParams struct {
    24  	wasmFilename   string
    25  	inputsFilename string
    26  	prime          string
    27  	nVars          int32
    28  	r              string
    29  	rInv           string
    30  	witness        string
    31  }
    32  
    33  func TestWitnessCalcMyCircuit1(t *testing.T) {
    34  	testWitnessCalc(t, TestParams{
    35  		wasmFilename:   "test_files/mycircuit.wasm",
    36  		inputsFilename: "test_files/mycircuit-input1.json",
    37  		prime:          "21888242871839275222246405745257275088548364400416034343698204186575808495617",
    38  		nVars:          4,
    39  		r:              "115792089237316195423570985008687907853269984665640564039457584007913129639936",
    40  		rInv:           "9915499612839321149637521777990102151350674507940716049588462388200839649614",
    41  		witness:        `["1","33","3","11"]`,
    42  	}, true)
    43  }
    44  
    45  func TestWitnessCalcMyCircuit2(t *testing.T) {
    46  	testWitnessCalc(t, TestParams{
    47  		wasmFilename:   "test_files/mycircuit.wasm",
    48  		inputsFilename: "test_files/mycircuit-input2.json",
    49  		prime:          "21888242871839275222246405745257275088548364400416034343698204186575808495617",
    50  		nVars:          4,
    51  		r:              "115792089237316195423570985008687907853269984665640564039457584007913129639936",
    52  		rInv:           "9915499612839321149637521777990102151350674507940716049588462388200839649614",
    53  		witness:        `["1","21888242871839275222246405745257275088548364400416034343698204186575672693159","21888242871839275222246405745257275088548364400416034343698204186575796149939","11"]`,
    54  	}, true)
    55  }
    56  
    57  func TestWitnessCalcMyCircuit3(t *testing.T) {
    58  	testWitnessCalc(t, TestParams{
    59  		wasmFilename:   "test_files/mycircuit.wasm",
    60  		inputsFilename: "test_files/mycircuit-input3.json",
    61  		prime:          "21888242871839275222246405745257275088548364400416034343698204186575808495617",
    62  		nVars:          4,
    63  		r:              "115792089237316195423570985008687907853269984665640564039457584007913129639936",
    64  		rInv:           "9915499612839321149637521777990102151350674507940716049588462388200839649614",
    65  		witness:        `["1","21888242871839275222246405745257275088548364400416034343698204186575808493616","10944121435919637611123202872628637544274182200208017171849102093287904246808","2"]`,
    66  	}, true)
    67  }
    68  
    69  func TestWitnessCalcSmtVerifier10(t *testing.T) {
    70  	witnessJSON, err := ioutil.ReadFile("test_files/smtverifier10-witness.json")
    71  	if err != nil {
    72  		panic(err)
    73  	}
    74  	testWitnessCalc(t, TestParams{
    75  		wasmFilename:   "test_files/smtverifier10.wasm",
    76  		inputsFilename: "test_files/smtverifier10-input.json",
    77  		prime:          "21888242871839275222246405745257275088548364400416034343698204186575808495617",
    78  		nVars:          4794,
    79  		r:              "115792089237316195423570985008687907853269984665640564039457584007913129639936",
    80  		rInv:           "9915499612839321149637521777990102151350674507940716049588462388200839649614",
    81  		witness:        string(witnessJSON),
    82  	}, false)
    83  }
    84  
    85  var testNConstraints = false
    86  
    87  func TestWitnessCalcNConstraints(t *testing.T) {
    88  	if !testNConstraints {
    89  		return
    90  	}
    91  	oldWd, err := os.Getwd()
    92  	require.Nil(t, err)
    93  	defer func() {
    94  		err := os.Chdir(oldWd)
    95  		require.Nil(t, err)
    96  	}()
    97  	err = os.Chdir(path.Join(oldWd, "test_files"))
    98  	require.Nil(t, err)
    99  
   100  	for i := 1; i < 8; i++ {
   101  		// for i := 1; i < 3; i++ {
   102  		n := int(math.Pow10(i))
   103  		log.Printf("WitnessCalc with %v constraints\n", n)
   104  		err := exec.Command("cp", "nconstraints.circom", "nconstraints.circom.tmp").Run()
   105  		require.Nil(t, err)
   106  		err = exec.Command("sed", "-i", fmt.Sprintf("s/{{N}}/%v/g", n), "nconstraints.circom.tmp").Run()
   107  		require.Nil(t, err)
   108  		start := time.Now()
   109  		err = exec.Command("./node_modules/.bin/circom", "nconstraints.circom.tmp", "-w", fmt.Sprintf("nconstraints-%v.wasm", n)).Run()
   110  		if err != nil {
   111  			fmt.Println(err)
   112  		}
   113  		elapsed := time.Since(start)
   114  		require.Nil(t, err)
   115  		log.Printf("Circuit compilation took %v\n", elapsed)
   116  
   117  		wasmFilename := fmt.Sprintf("nconstraints-%v.wasm", n)
   118  		var inputs = map[string]interface{}{"in": new(big.Int).SetInt64(2)}
   119  
   120  		runtime := wasm3.NewRuntime(&wasm3.Config{
   121  			Environment: wasm3.NewEnvironment(),
   122  			StackSize:   64 * 1024,
   123  		})
   124  		wasmBytes, err := ioutil.ReadFile(wasmFilename)
   125  		require.Nil(t, err)
   126  		module, err := runtime.ParseModule(wasmBytes)
   127  		require.Nil(t, err)
   128  		module, err = runtime.LoadModule(module)
   129  		require.Nil(t, err)
   130  		witnessCalculator, err := NewWitnessCalculator(runtime, module)
   131  		require.Nil(t, err)
   132  		p := witnessCalculator.prime
   133  		start = time.Now()
   134  		w, err := witnessCalculator.CalculateWitness(inputs, false)
   135  		elapsed = time.Since(start)
   136  		require.Nil(t, err)
   137  		log.Printf("Witness calculation took %v\n", elapsed)
   138  
   139  		runtime.Destroy()
   140  
   141  		out := new(big.Int).SetInt64(2)
   142  		for i := 1; i < n; i++ {
   143  			out.Mul(out, out)
   144  			out.Add(out, new(big.Int).SetInt64(int64(i)))
   145  			out.Mod(out, p)
   146  		}
   147  
   148  		assert.Equal(t, out, w[1])
   149  
   150  		err = os.Remove("nconstraints.circom.tmp")
   151  		require.Nil(t, err)
   152  		err = os.Remove(fmt.Sprintf("nconstraints-%v.wasm", n))
   153  		require.Nil(t, err)
   154  	}
   155  }
   156  
   157  func testWitnessCalc(t *testing.T, p TestParams, logWitness bool) {
   158  	log.Print("Initializing WASM3")
   159  
   160  	runtime := wasm3.NewRuntime(&wasm3.Config{
   161  		Environment: wasm3.NewEnvironment(),
   162  		StackSize:   64 * 1024,
   163  	})
   164  	log.Println("Runtime ok")
   165  	// err := runtime.ResizeMemory(16)
   166  	// if err != nil {
   167  	// 	panic(err)
   168  	// }
   169  
   170  	// log.Println("Runtime Memory len: ", len(runtime.Memory()))
   171  
   172  	wasmBytes, err := ioutil.ReadFile(p.wasmFilename)
   173  	require.Nil(t, err)
   174  	log.Printf("Read WASM module (%d bytes)\n", len(wasmBytes))
   175  
   176  	module, err := runtime.ParseModule(wasmBytes)
   177  	require.Nil(t, err)
   178  	module, err = runtime.LoadModule(module)
   179  	require.Nil(t, err)
   180  	log.Print("Loaded module")
   181  
   182  	// fmt.Printf("NumImports: %v\n", module.NumImports())
   183  	// fns, err := NewWitnessCalcFns(runtime, module)
   184  	// if err != nil {
   185  	// 	panic(err)
   186  	// }
   187  
   188  	inputsBytes, err := ioutil.ReadFile(p.inputsFilename)
   189  	require.Nil(t, err)
   190  	inputs, err := ParseInputs(inputsBytes)
   191  	require.Nil(t, err)
   192  	log.Print("Inputs: ", inputs)
   193  
   194  	witnessCalculator, err := NewWitnessCalculator(runtime, module)
   195  	require.Nil(t, err)
   196  	log.Print("n32: ", witnessCalculator.n32)
   197  	log.Print("prime: ", witnessCalculator.prime)
   198  	log.Print("mask32: ", witnessCalculator.mask32)
   199  	log.Print("nVars: ", witnessCalculator.nVars)
   200  	log.Print("n64: ", witnessCalculator.n64)
   201  	log.Print("r: ", witnessCalculator.r)
   202  	log.Print("rInv: ", witnessCalculator.rInv)
   203  
   204  	assert.Equal(t, p.prime, witnessCalculator.prime.String())
   205  	assert.Equal(t, p.r, witnessCalculator.r.String())
   206  	assert.Equal(t, p.rInv, witnessCalculator.rInv.String())
   207  	assert.Equal(t, p.nVars, witnessCalculator.nVars)
   208  
   209  	start := time.Now()
   210  	w, err := witnessCalculator.CalculateWitness(inputs, false)
   211  	elapsed := time.Since(start)
   212  	require.Nil(t, err)
   213  	log.Printf("Witness calculation took %v\n", elapsed)
   214  	if logWitness {
   215  		log.Print("Witness: ", w)
   216  	}
   217  	wJSON, err := json.Marshal(WitnessJSON(w))
   218  	require.Nil(t, err)
   219  	if logWitness {
   220  		log.Print("Witness JSON: ", string(wJSON))
   221  	}
   222  	pWitness := strings.ReplaceAll(p.witness, " ", "")
   223  	pWitness = strings.ReplaceAll(pWitness, "\n", "")
   224  	assert.Equal(t, pWitness, string(wJSON))
   225  
   226  	// DEBUG
   227  	// {
   228  	// 	elemsCalc := strings.Split(string(wJSON), ",")
   229  	// 	elems := strings.Split(p.witness, ",")
   230  	// 	if len(elemsCalc) != len(elems) {
   231  	// 		panic(fmt.Errorf("Witness length differs: %v, %v", len(elemsCalc), len(elems)))
   232  	// 	}
   233  	// 	for i := 0; i < len(elems); i++ {
   234  	// 		fmt.Printf("exp %v\ngot %v\n\n", elems[i], elemsCalc[i])
   235  	// 	}
   236  	// }
   237  	wb, err := witnessCalculator.CalculateBinWitness(inputs, false)
   238  	require.Nil(t, err)
   239  	if logWitness {
   240  		log.Print("WitnessBin: ", hex.EncodeToString(wb))
   241  	}
   242  }