github.com/iden3/go-circom-witnesscalc@v0.2.1-0.20230314155733-dd1f248a91b6/circom2witnesscalc.go (about)

     1  package witnesscalc
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"math/big"
     8  
     9  	"github.com/wasmerio/wasmer-go/wasmer"
    10  )
    11  
// Circom2WitnessCalculator is the object that allows performing witness calculation
// from signal inputs using the WitnessCalc WASM module.
//
// Every field is filled in by NewCircom2WitnessCalculator; the function-typed
// fields are exports looked up by name on the WASM instance.
type Circom2WitnessCalculator struct {
	// instance is the instantiated WASM module.
	instance            *wasmer.Instance
	// sanityCheck stores the flag passed to the constructor.
	// NOTE(review): the methods in this file take their own sanityCheck
	// argument and do not appear to read this field - confirm before relying on it.
	sanityCheck         bool
	// n32 is the value returned by the "getFieldNumLen32" export; the methods
	// below use it as the number of 32-bit words transferred per field element.
	n32                 int32
	// version is the value returned by the "getVersion" export.
	version             int32
	// witnessSize is the value returned by the "getWitnessSize" export; it is
	// used as the number of witness entries to read back.
	witnessSize         int32
	// Exported WASM functions, each named after the export it was resolved from.
	init                wasmer.NativeFunction
	getFieldNumLen32    wasmer.NativeFunction
	// getInputSignalSize is nil when the module does not export it (see the
	// constructor); callers must check for nil before invoking it.
	getInputSignalSize  wasmer.NativeFunction
	getInputSize        wasmer.NativeFunction
	getRawPrime         wasmer.NativeFunction
	getVersion          wasmer.NativeFunction
	getWitness          wasmer.NativeFunction
	readSharedRWMemory  wasmer.NativeFunction
	setInputSignal      wasmer.NativeFunction
	writeSharedRWMemory wasmer.NativeFunction
}
    31  
    32  // NewCircom2WitnessCalculator creates a new WitnessCalculator from the WitnessCalc
    33  // loaded WASM module in the runtime.
    34  func NewCircom2WitnessCalculator(wasmBytes []byte, sanityCheck bool) (*Circom2WitnessCalculator, error) {
    35  	engine := wasmer.NewEngine()
    36  	store := wasmer.NewStore(engine)
    37  
    38  	// Compiles the module
    39  	module, _ := wasmer.NewModule(store, wasmBytes)
    40  
    41  	limits, err := wasmer.NewLimits(2000, 100000)
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  
    46  	memType := wasmer.NewMemoryType(limits)
    47  
    48  	memory := wasmer.NewMemory(store, memType)
    49  
    50  	// Instantiates the module
    51  	importObject := wasmer.NewImportObject()
    52  
    53  	importObject.Register("env", map[string]wasmer.IntoExtern{
    54  		"memory": memory,
    55  	})
    56  
    57  	importObject.Register("runtime", map[string]wasmer.IntoExtern{
    58  		"exceptionHandler":   getExceptionHandler(store),
    59  		"showSharedRWMemory": getShowSharedRWMemory(store),
    60  		"log":                getLog(store),
    61  	})
    62  
    63  	instance, err := wasmer.NewInstance(module, importObject)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	// Gets the `init` exported function from the WebAssembly instance.
    69  	init, err := instance.Exports.GetFunction("init")
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	// Calls that exported function with Go standard values. The WebAssembly
    75  	// types are inferred and values are casted automatically.
    76  	_, err = init(1)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	getFieldNumLen32, err := instance.Exports.GetFunction("getFieldNumLen32")
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	n32, err := getFieldNumLen32()
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	// this function is missing in wasm files generated with circom version prior to v2.0.4
    91  	getInputSignalSize, _ := instance.Exports.GetFunction("getInputSignalSize")
    92  
    93  	getInputSize, err := instance.Exports.GetFunction("getInputSize")
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	getRawPrime, err := instance.Exports.GetFunction("getRawPrime")
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	getVersion, err := instance.Exports.GetFunction("getVersion")
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	version, err := getVersion()
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	getWitness, err := instance.Exports.GetFunction("getWitness")
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	getWitnessSize, err := instance.Exports.GetFunction("getWitnessSize")
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	witnessSize, err := getWitnessSize()
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	setInputSignal, err := instance.Exports.GetFunction("setInputSignal")
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	readSharedRWMemory, err := instance.Exports.GetFunction("readSharedRWMemory")
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  
   138  	writeSharedRWMemory, err := instance.Exports.GetFunction("writeSharedRWMemory")
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	return &Circom2WitnessCalculator{
   144  		instance:            instance,
   145  		sanityCheck:         sanityCheck,
   146  		n32:                 n32.(int32),
   147  		version:             version.(int32),
   148  		witnessSize:         witnessSize.(int32),
   149  		init:                init,
   150  		getFieldNumLen32:    getFieldNumLen32,
   151  		getInputSignalSize:  getInputSignalSize,
   152  		getInputSize:        getInputSize,
   153  		getRawPrime:         getRawPrime,
   154  		getWitness:          getWitness,
   155  		getVersion:          getVersion,
   156  		setInputSignal:      setInputSignal,
   157  		readSharedRWMemory:  readSharedRWMemory,
   158  		writeSharedRWMemory: writeSharedRWMemory,
   159  	}, nil
   160  }
   161  
   162  // CalculateWitness calculates the witness given the inputs.
   163  func (wc *Circom2WitnessCalculator) CalculateWitness(inputs map[string]interface{}, sanityCheck bool) ([]*big.Int, error) {
   164  
   165  	w := make([]*big.Int, wc.witnessSize)
   166  
   167  	err := wc.doCalculateWitness(inputs, sanityCheck)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	for i := 0; i < int(wc.witnessSize); i++ {
   173  		_, err := wc.getWitness(i)
   174  		if err != nil {
   175  			return nil, err
   176  		}
   177  		arr := make([]uint32, wc.n32)
   178  		for j := 0; j < int(wc.n32); j++ {
   179  			val, err := wc.readSharedRWMemory(int32(j))
   180  			if err != nil {
   181  				return nil, err
   182  			}
   183  			arr[int(wc.n32)-1-j] = uint32(val.(int32))
   184  		}
   185  		w[i] = fromArray32(arr)
   186  	}
   187  
   188  	return w, nil
   189  }
   190  
   191  // CalculateBinWitness calculates the witness in binary given the inputs.
   192  func (wc *Circom2WitnessCalculator) CalculateBinWitness(inputs map[string]interface{}, sanityCheck bool) ([]byte, error) {
   193  	buff := new(bytes.Buffer)
   194  
   195  	err := wc.doCalculateWitness(inputs, sanityCheck)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  
   200  	for i := 0; i < int(wc.witnessSize); i++ {
   201  		_, err := wc.getWitness(i)
   202  		if err != nil {
   203  			return nil, err
   204  		}
   205  
   206  		for j := 0; j < int(wc.n32); j++ {
   207  			val, err := wc.readSharedRWMemory(j)
   208  			if err != nil {
   209  				return nil, err
   210  			}
   211  			_ = binary.Write(buff, binary.LittleEndian, uint32(val.(int32)))
   212  		}
   213  	}
   214  
   215  	return buff.Bytes(), nil
   216  }
   217  
   218  // CalculateWTNSBin calculates the witness in binary given the inputs.
   219  func (wc *Circom2WitnessCalculator) CalculateWTNSBin(inputs map[string]interface{}, sanityCheck bool) ([]byte, error) {
   220  	buff := new(bytes.Buffer)
   221  
   222  	err := wc.doCalculateWitness(inputs, sanityCheck)
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  
   227  	buff.Grow(int(wc.witnessSize*wc.n32 + wc.n32 + 11))
   228  
   229  	// wtns
   230  	_ = buff.WriteByte('w')
   231  	_ = buff.WriteByte('t')
   232  	_ = buff.WriteByte('n')
   233  	_ = buff.WriteByte('s')
   234  
   235  	//version 2
   236  	_ = binary.Write(buff, binary.LittleEndian, uint32(2))
   237  
   238  	//number of sections: 2
   239  	_ = binary.Write(buff, binary.LittleEndian, uint32(2))
   240  
   241  	//id section 1
   242  	_ = binary.Write(buff, binary.LittleEndian, uint32(1))
   243  
   244  	n8 := wc.n32 * 4
   245  	//id section 1 length in 64bytes
   246  	idSection1length := 8 + n8
   247  	_ = binary.Write(buff, binary.LittleEndian, uint64(idSection1length))
   248  
   249  	//this.n32
   250  	_ = binary.Write(buff, binary.LittleEndian, uint32(n8))
   251  
   252  	//prime number
   253  	_, err = wc.getRawPrime()
   254  	if err != nil {
   255  		return nil, err
   256  	}
   257  
   258  	for j := 0; j < int(wc.n32); j++ {
   259  		val, err := wc.readSharedRWMemory(int32(j))
   260  		if err != nil {
   261  			return nil, err
   262  		}
   263  		_ = binary.Write(buff, binary.LittleEndian, uint32(val.(int32)))
   264  	}
   265  
   266  	// witness size
   267  	_ = binary.Write(buff, binary.LittleEndian, uint32(wc.witnessSize))
   268  
   269  	//id section 2
   270  	_ = binary.Write(buff, binary.LittleEndian, uint32(2))
   271  
   272  	// section 2 length
   273  	idSection2length := n8 * wc.witnessSize
   274  	_ = binary.Write(buff, binary.LittleEndian, uint64(idSection2length))
   275  
   276  	for i := 0; i < int(wc.witnessSize); i++ {
   277  		_, err := wc.getWitness(i)
   278  		if err != nil {
   279  			return nil, err
   280  		}
   281  
   282  		for j := 0; j < int(wc.n32); j++ {
   283  			val, err := wc.readSharedRWMemory(j)
   284  			if err != nil {
   285  				return nil, err
   286  			}
   287  			_ = binary.Write(buff, binary.LittleEndian, uint32(val.(int32)))
   288  		}
   289  	}
   290  
   291  	return buff.Bytes(), nil
   292  }
   293  
   294  // CalculateWitness calculates the witness given the inputs.
   295  func (wc *Circom2WitnessCalculator) doCalculateWitness(inputs map[string]interface{}, sanityCheck bool) error {
   296  	//input is assumed to be a map from signals to arrays of bigInts
   297  	sanityCheckVal := int32(0)
   298  	if sanityCheck {
   299  		sanityCheckVal = 1
   300  	}
   301  	_, err := wc.init(sanityCheckVal)
   302  	if err != nil {
   303  		return err
   304  	}
   305  
   306  	inputCounter := 0
   307  	for inputName, inputValue := range inputs {
   308  		hMSB, hLSB := fnvHash(inputName)
   309  		fSlice := flatSlice(inputValue)
   310  
   311  		if wc.getInputSignalSize != nil {
   312  			signalSize, err := wc.getInputSignalSize(hMSB, hLSB)
   313  			if err != nil {
   314  				return err
   315  			}
   316  
   317  			if signalSize.(int32) < 0 {
   318  				return fmt.Errorf("signal %s not found", inputName)
   319  			}
   320  			if len(fSlice) < int(signalSize.(int32)) {
   321  				return fmt.Errorf("not enough values for input signal %s", inputName)
   322  			}
   323  			if len(fSlice) > int(signalSize.(int32)) {
   324  				return fmt.Errorf("too many values for input signal %s", inputName)
   325  			}
   326  		}
   327  
   328  		for i := 0; i < len(fSlice); i++ {
   329  			arrFr, err := toArray32(fSlice[i], int(wc.n32))
   330  			if err != nil {
   331  				return err
   332  			}
   333  			for j := 0; j < int(wc.n32); j++ {
   334  				_, err := wc.writeSharedRWMemory(j, int32(arrFr[int(wc.n32)-1-j]))
   335  				if err != nil {
   336  					return err
   337  				}
   338  			}
   339  			_, err = wc.setInputSignal(hMSB, hLSB, i)
   340  			if err != nil {
   341  				return err
   342  			}
   343  			inputCounter++
   344  		}
   345  	}
   346  	inputSize, err := wc.getInputSize()
   347  	if inputCounter < int(inputSize.(int32)) {
   348  		return fmt.Errorf("not all inputs have been set: only %d out of %d", inputCounter, inputSize)
   349  	}
   350  	return nil
   351  }
   352  
   353  func getExceptionHandler(store *wasmer.Store) wasmer.IntoExtern {
   354  	function := wasmer.NewFunction(
   355  		store,
   356  		wasmer.NewFunctionType(
   357  			wasmer.NewValueTypes(wasmer.I32), // one i32 argument
   358  			wasmer.NewValueTypes(),           // zero results
   359  		),
   360  		func(args []wasmer.Value) ([]wasmer.Value, error) {
   361  			if len(args) > 0 {
   362  				code := args[0].I32()
   363  				var errStr string
   364  				if code == 1 {
   365  					errStr = "Signal not found. "
   366  				} else if code == 2 {
   367  					errStr = "Too many signals set. "
   368  				} else if code == 3 {
   369  					errStr = "Signal already set. "
   370  				} else if code == 4 {
   371  					errStr = "Assert Failed. "
   372  				} else if code == 5 {
   373  					errStr = "Not enough memory. "
   374  				} else if code == 6 {
   375  					errStr = "Input signal array access exceeds the size"
   376  				} else {
   377  					errStr = "Unknown error"
   378  				}
   379  				fmt.Println(errStr)
   380  			}
   381  			return []wasmer.Value{}, nil
   382  		},
   383  	)
   384  	return function
   385  }
   386  
   387  func getShowSharedRWMemory(store *wasmer.Store) wasmer.IntoExtern {
   388  	function := wasmer.NewFunction(
   389  		store,
   390  		wasmer.NewFunctionType(
   391  			wasmer.NewValueTypes(),
   392  			wasmer.NewValueTypes(),
   393  		),
   394  		func(args []wasmer.Value) ([]wasmer.Value, error) {
   395  			return []wasmer.Value{}, nil
   396  		},
   397  	)
   398  	return function
   399  }
   400  
   401  func getLog(store *wasmer.Store) wasmer.IntoExtern {
   402  	function := wasmer.NewFunction(
   403  		store,
   404  		wasmer.NewFunctionType(
   405  			wasmer.NewValueTypes(),
   406  			wasmer.NewValueTypes(),
   407  		),
   408  		func(args []wasmer.Value) ([]wasmer.Value, error) {
   409  			return []wasmer.Value{}, nil
   410  		},
   411  	)
   412  	return function
   413  }
   414  
   415  func toArray32(s *big.Int, size int) ([]uint32, error) {
   416  	res := make([]uint32, size)
   417  	rem := s
   418  
   419  	radix := big.NewInt(0x100000000)
   420  	zero := big.NewInt(0)
   421  	i := size - 1
   422  	// while not zero rem
   423  	for rem.Cmp(zero) != 0 {
   424  		res[i] = uint32(new(big.Int).Mod(rem, radix).Uint64())
   425  		rem.Div(rem, radix)
   426  		i--
   427  	}
   428  	return res, nil
   429  }
   430  
   431  func fromArray32(arr []uint32) *big.Int {
   432  	res := new(big.Int)
   433  	radix := big.NewInt(0x100000000)
   434  	for i := 0; i < len(arr); i++ {
   435  		res.Mul(res, radix)
   436  		res.Add(res, big.NewInt(int64(arr[i])))
   437  	}
   438  	return res
   439  }