github.com/iden3/go-circom-witnesscalc@v0.2.1-0.20230314155733-dd1f248a91b6/witnesscalc.go (about)

     1  package witnesscalc
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"math"
     8  	"math/big"
     9  	"reflect"
    10  	"unsafe"
    11  
    12  	log "github.com/sirupsen/logrus"
    13  
    14  	wasm3 "github.com/iden3/go-wasm3"
    15  )
    16  
    17  // witnessCalcFns are wrapper functions to the WitnessCalc WASM module
    18  type witnessCalcFns struct {
    19  	getFrLen          func() (int32, error)
    20  	getPRawPrime      func() (int32, error)
    21  	getNVars          func() (int32, error)
    22  	init              func(sanityCheck int32) error
    23  	getSignalOffset32 func(pR, component, hashMSB, hashLSB int32) error
    24  	setSignal         func(cIdx, component, signal, pVal int32) error
    25  	getPWitness       func(w int32) (int32, error)
    26  	getWitnessBuffer  func() (int32, error)
    27  }
    28  
    29  func getStack(sp unsafe.Pointer, length int) []uint64 {
    30  	var data = (*uint64)(sp)
    31  	var header reflect.SliceHeader
    32  	header = *(*reflect.SliceHeader)(unsafe.Pointer(&header))
    33  	header.Data = uintptr(unsafe.Pointer(data))
    34  	header.Len = int(length)
    35  	header.Cap = int(length)
    36  	return *(*[]uint64)(unsafe.Pointer(&header))
    37  }
    38  
    39  func getMem(r *wasm3.Runtime, _mem unsafe.Pointer) []byte {
    40  	var data = (*uint8)(_mem)
    41  	length := r.GetAllocatedMemoryLength()
    42  	var header reflect.SliceHeader
    43  	header = *(*reflect.SliceHeader)(unsafe.Pointer(&header))
    44  	header.Data = uintptr(unsafe.Pointer(data))
    45  	header.Len = int(length)
    46  	header.Cap = int(length)
    47  	return *(*[]byte)(unsafe.Pointer(&header))
    48  }
    49  
    50  func getStr(mem []byte, p uint64) string {
    51  	var buf bytes.Buffer
    52  	for ; mem[p] != 0; p++ {
    53  		buf.WriteByte(mem[p])
    54  	}
    55  	return buf.String()
    56  }
    57  
    58  // newWitnessCalcFns builds the witnessCalcFns from the loaded WitnessCalc WASM
    59  // module in the runtime.  Imported functions (logging) are binded to dummy functions.
    60  func newWitnessCalcFns(r *wasm3.Runtime, m *wasm3.Module, wc *WitnessCalculator) (*witnessCalcFns, error) {
    61  	r.AttachFunction("runtime", "error", "v(iiiiii)", wasm3.CallbackFunction(
    62  		func(runtime wasm3.RuntimeT, sp unsafe.Pointer, _mem unsafe.Pointer) int {
    63  			// func(code, pstr, a, b, c, d)
    64  
    65  			stack := getStack(sp, 6)
    66  			mem := getMem(r, _mem)
    67  
    68  			code := stack[0]
    69  			pstr := stack[1]
    70  			a := stack[2]
    71  			b := stack[3]
    72  			c := stack[4]
    73  			d := stack[5]
    74  
    75  			var errStr string
    76  			if code == 7 {
    77  				errStr = fmt.Sprintf("%s %v != %v %s",
    78  					getStr(mem, pstr),
    79  					wc.loadFr(int32(b)), wc.loadFr(int32(c)), getStr(mem, d))
    80  			} else {
    81  				errStr = fmt.Sprintf("%s %v %v %v %v",
    82  					getStr(mem, pstr), a, b, c, getStr(mem, d))
    83  			}
    84  			log.Errorf("WitnessCalculator WASM Error (%v): %v", code, errStr)
    85  			return 0
    86  		},
    87  	))
    88  	r.AttachFunction("runtime", "logSetSignal", "v(ii)", wasm3.CallbackFunction(
    89  		func(runtime wasm3.RuntimeT, sp unsafe.Pointer, mem unsafe.Pointer) int {
    90  			return 0
    91  		},
    92  	))
    93  	r.AttachFunction("runtime", "logGetSignal", "v(ii)", wasm3.CallbackFunction(
    94  		func(runtime wasm3.RuntimeT, sp unsafe.Pointer, mem unsafe.Pointer) int {
    95  			return 0
    96  		},
    97  	))
    98  	r.AttachFunction("runtime", "logFinishComponent", "v(i)", wasm3.CallbackFunction(
    99  		func(runtime wasm3.RuntimeT, sp unsafe.Pointer, mem unsafe.Pointer) int {
   100  			return 0
   101  		},
   102  	))
   103  	r.AttachFunction("runtime", "logStartComponent", "v(i)", wasm3.CallbackFunction(
   104  		func(runtime wasm3.RuntimeT, sp unsafe.Pointer, mem unsafe.Pointer) int {
   105  			return 0
   106  		},
   107  	))
   108  	r.AttachFunction("runtime", "log", "v(i)", wasm3.CallbackFunction(
   109  		func(runtime wasm3.RuntimeT, sp unsafe.Pointer, mem unsafe.Pointer) int {
   110  			return 0
   111  		},
   112  	))
   113  
   114  	_getFrLen, err := r.FindFunction("getFrLen")
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	getFrLen := func() (int32, error) {
   119  		res, err := _getFrLen()
   120  		if err != nil {
   121  			return 0, err
   122  		}
   123  		return res.(int32), nil
   124  	}
   125  	_getPRawPrime, err := r.FindFunction("getPRawPrime")
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	getPRawPrime := func() (int32, error) {
   130  		res, err := _getPRawPrime()
   131  		if err != nil {
   132  			return 0, err
   133  		}
   134  		return res.(int32), nil
   135  	}
   136  	_getNVars, err := r.FindFunction("getNVars")
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	getNVars := func() (int32, error) {
   141  		res, err := _getNVars()
   142  		if err != nil {
   143  			return 0, err
   144  		}
   145  		return res.(int32), nil
   146  	}
   147  	_init, err := r.FindFunction("init")
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  	init := func(sanityCheck int32) error {
   152  		_, err := _init(sanityCheck)
   153  		if err != nil {
   154  			return err
   155  		}
   156  		return nil
   157  	}
   158  	_getSignalOffset32, err := r.FindFunction("getSignalOffset32")
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	getSignalOffset32 := func(pR, component, hashMSB, hashLSB int32) error {
   163  		_, err := _getSignalOffset32(pR, component, hashMSB, hashLSB)
   164  		if err != nil {
   165  			return err
   166  		}
   167  		return nil
   168  	}
   169  	_setSignal, err := r.FindFunction("setSignal")
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  	setSignal := func(cIdx, component, signal, pVal int32) error {
   174  		_, err := _setSignal(cIdx, component, signal, pVal)
   175  		if err != nil {
   176  			return err
   177  		}
   178  		return nil
   179  	}
   180  	_getPWitness, err := r.FindFunction("getPWitness")
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  	getPWitness := func(w int32) (int32, error) {
   185  		res, err := _getPWitness(w)
   186  		if err != nil {
   187  			return 0, err
   188  		}
   189  		return res.(int32), nil
   190  	}
   191  	_getWitnessBuffer, err := r.FindFunction("getWitnessBuffer")
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  	getWitnessBuffer := func() (int32, error) {
   196  		res, err := _getWitnessBuffer()
   197  		if err != nil {
   198  			return 0, err
   199  		}
   200  		return res.(int32), nil
   201  	}
   202  
   203  	return &witnessCalcFns{
   204  		getFrLen:          getFrLen,
   205  		getPRawPrime:      getPRawPrime,
   206  		getNVars:          getNVars,
   207  		init:              init,
   208  		getSignalOffset32: getSignalOffset32,
   209  		setSignal:         setSignal,
   210  		getPWitness:       getPWitness,
   211  		getWitnessBuffer:  getWitnessBuffer,
   212  	}, nil
   213  }
   214  
   215  // WitnessJSON is a wrapper type to Marshal the Witness in JSON format
   216  type WitnessJSON []*big.Int
   217  
   218  // MarshalJSON marshals the WitnessJSON where each value is encoded in base 10
   219  // as a string in an array.
   220  func (w WitnessJSON) MarshalJSON() ([]byte, error) {
   221  	var buffer bytes.Buffer
   222  	buffer.WriteString("[")
   223  	for i, bi := range w {
   224  		buffer.WriteString(`"` + bi.String() + `"`)
   225  		if i != len(w)-1 {
   226  			buffer.WriteString(",")
   227  		}
   228  	}
   229  	buffer.WriteString("]")
   230  	return buffer.Bytes(), nil
   231  }
   232  
   233  // loadBigInt loads a *big.Int from the runtime memory at position p.
   234  func loadBigInt(runtime *wasm3.Runtime, p int32, n int32) *big.Int {
   235  	bigIntBytes := make([]byte, n)
   236  	copy(bigIntBytes, runtime.Memory()[p:p+n])
   237  	return new(big.Int).SetBytes(swap(bigIntBytes))
   238  }
   239  
   240  // WitnessCalculator is the object that allows performing witness calculation
   241  // from signal inputs using the WitnessCalc WASM module.
   242  type WitnessCalculator struct {
   243  	n32    int32
   244  	prime  *big.Int
   245  	mask32 *big.Int
   246  	nVars  int32
   247  	n64    uint
   248  	r      *big.Int
   249  	rInv   *big.Int
   250  
   251  	shortMax *big.Int
   252  	shortMin *big.Int
   253  
   254  	runtime *wasm3.Runtime
   255  	fns     *witnessCalcFns
   256  }
   257  
   258  // NewWitnessCalculator creates a new WitnessCalculator from the WitnessCalc
   259  // loaded WASM module in the runtime.
   260  func NewWitnessCalculator(runtime *wasm3.Runtime, module *wasm3.Module) (*WitnessCalculator, error) {
   261  	var wc WitnessCalculator
   262  	fns, err := newWitnessCalcFns(runtime, module, &wc)
   263  	if err != nil {
   264  		return nil, err
   265  	}
   266  
   267  	n32, err := fns.getFrLen()
   268  	if err != nil {
   269  		return nil, err
   270  	}
   271  	// n32 = (n32 >> 2) - 2
   272  	n32 = n32 - 8
   273  
   274  	pRawPrime, err := fns.getPRawPrime()
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  
   279  	prime := loadBigInt(runtime, pRawPrime, n32)
   280  
   281  	mask32 := new(big.Int).SetUint64(0xFFFFFFFF)
   282  	nVars, err := fns.getNVars()
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  
   287  	n64 := uint(((prime.BitLen() - 1) / 64) + 1)
   288  	r := new(big.Int).SetInt64(1)
   289  	r.Lsh(r, n64*64)
   290  	rInv := new(big.Int).ModInverse(r, prime)
   291  
   292  	shortMax, ok := new(big.Int).SetString("0x80000000", 0)
   293  	if !ok {
   294  		return nil, fmt.Errorf("unable to set shortMax from string")
   295  	}
   296  	shortMin := new(big.Int).Set(prime)
   297  	shortMin.Sub(shortMin, shortMax)
   298  
   299  	wc.n32 = n32
   300  	wc.prime = prime
   301  	wc.mask32 = mask32
   302  	wc.nVars = nVars
   303  	wc.n64 = n64
   304  	wc.r = r
   305  	wc.rInv = rInv
   306  	wc.shortMin = shortMin
   307  	wc.shortMax = shortMax
   308  	wc.runtime = runtime
   309  	wc.fns = fns
   310  	return &wc, nil
   311  }
   312  
   313  // loadBigInt loads a *big.Int from the runtime memory at position p.
   314  func (wc *WitnessCalculator) loadBigInt(p int32, n int32) *big.Int {
   315  	return loadBigInt(wc.runtime, p, n)
   316  }
   317  
   318  var zero32 [32]byte
   319  
   320  // storeBigInt stores a *big.Int into the runtime memory at position p.
   321  func (wc *WitnessCalculator) storeBigInt(p int32, v *big.Int) {
   322  	bigIntBytes := swap(v.Bytes())
   323  	copy(wc.runtime.Memory()[p:p+32], zero32[:])
   324  	copy(wc.runtime.Memory()[p:p+int32(len(bigIntBytes))], bigIntBytes)
   325  }
   326  
   327  // memFreePos gives the next free runtime memory position.
   328  func (wc *WitnessCalculator) memFreePos() int32 {
   329  	return int32(binary.LittleEndian.Uint32(wc.runtime.Memory()[:4]))
   330  }
   331  
   332  // setMemFreePos sets the next free runtime memory position.
   333  func (wc *WitnessCalculator) setMemFreePos(p int32) {
   334  	binary.LittleEndian.PutUint32(wc.runtime.Memory()[:4], uint32(p))
   335  }
   336  
   337  // allocInt reserves space in the runtime memory and returns its position.
   338  func (wc *WitnessCalculator) allocInt() int32 {
   339  	p := wc.memFreePos()
   340  	wc.setMemFreePos(p + 8)
   341  	return p
   342  }
   343  
   344  // allocFr reserves space in the runtime memory for a Field element and returns its position.
   345  func (wc *WitnessCalculator) allocFr() int32 {
   346  	p := wc.memFreePos()
   347  	wc.setMemFreePos(p + wc.n32*4 + 8)
   348  	return p
   349  }
   350  
   351  // getInt loads an int32 from the runtime memory at position p.
   352  func (wc *WitnessCalculator) getInt(p int32) int32 {
   353  	return int32(binary.LittleEndian.Uint32(wc.runtime.Memory()[p : p+4]))
   354  }
   355  
   356  // setInt stores an int32 in the runtime memory at position p.
   357  func (wc *WitnessCalculator) setInt(p, v int32) {
   358  	binary.LittleEndian.PutUint32(wc.runtime.Memory()[p:p+4], uint32(v))
   359  }
   360  
   361  // setShortPositive stores a small positive Field element in the runtime memory at position p.
   362  func (wc *WitnessCalculator) setShortPositive(p int32, v *big.Int) {
   363  	if !v.IsInt64() || v.Int64() >= 0x80000000 {
   364  		panic(fmt.Errorf("v should be < 0x80000000"))
   365  	}
   366  	wc.setInt(p, int32(v.Int64()))
   367  	wc.setInt(p+4, 0)
   368  }
   369  
   370  // setShortPositive stores a small negative *big.Int in the runtime memory at position p.
   371  func (wc *WitnessCalculator) setShortNegative(p int32, v *big.Int) {
   372  	vNeg := new(big.Int).Set(wc.prime) // prime
   373  	vNeg.Sub(vNeg, wc.shortMax)        // prime - max
   374  	vNeg.Sub(v, vNeg)                  // v - (prime - max)
   375  	vNeg.Add(wc.shortMax, vNeg)        // max + (v - (prime - max))
   376  	if !vNeg.IsInt64() || vNeg.Int64() < 0x80000000 || vNeg.Int64() >= 0x80000000*2 {
   377  		panic(fmt.Errorf("v should be < 0x80000000"))
   378  	}
   379  	wc.setInt(p, int32(vNeg.Int64()))
   380  	wc.setInt(p+4, 0)
   381  }
   382  
   383  // setShortPositive stores a normal Field element in the runtime memory at position p.
   384  func (wc *WitnessCalculator) setLongNormal(p int32, v *big.Int) {
   385  	wc.setInt(p, 0)
   386  	wc.setInt(p+4, math.MinInt32) // math.MinInt32 = 0x80000000
   387  	wc.storeBigInt(p+8, v)
   388  }
   389  
   390  // storeFr stores a Field element in the runtime memory at position p.
   391  func (wc *WitnessCalculator) storeFr(p int32, v *big.Int) {
   392  	if v.Cmp(wc.shortMax) == -1 {
   393  		wc.setShortPositive(p, v)
   394  	} else if v.Cmp(wc.shortMin) >= 0 {
   395  		wc.setShortNegative(p, v)
   396  	} else {
   397  		wc.setLongNormal(p, v)
   398  	}
   399  }
   400  
   401  // fromMontgomery transforms a Field element from Montgomery form to regular form.
   402  func (wc *WitnessCalculator) fromMontgomery(v *big.Int) *big.Int {
   403  	res := new(big.Int).Set(v)
   404  	res.Mul(res, wc.rInv)
   405  	res.Mod(res, wc.prime)
   406  	return res
   407  }
   408  
   409  // loadFr loads a Field element from the runtime memory at position p.
   410  func (wc *WitnessCalculator) loadFr(p int32) *big.Int {
   411  	m := wc.runtime.Memory()
   412  	if (m[p+4+3] & 0x80) != 0 {
   413  		res := wc.loadBigInt(p+8, wc.n32)
   414  		if (m[p+4+3] & 0x40) != 0 {
   415  			return wc.fromMontgomery(res)
   416  		} else {
   417  			return res
   418  		}
   419  	} else {
   420  		if (m[p+3] & 0x40) != 0 {
   421  			res := wc.loadBigInt(p, 4) // res
   422  			res.Sub(res, wc.shortMax)  // res - max
   423  			res.Add(wc.prime, res)     // res - max + prime
   424  			res.Sub(res, wc.shortMax)  // res - max + (prime - max)
   425  			return res
   426  		} else {
   427  			return wc.loadBigInt(p, 4)
   428  		}
   429  	}
   430  }
   431  
   432  // doCalculateWitness is an internal function that calculates the witness.
   433  func (wc *WitnessCalculator) doCalculateWitness(inputs map[string]interface{}, sanityCheck bool) error {
   434  	sanityCheckVal := int32(0)
   435  	if sanityCheck {
   436  		sanityCheckVal = 1
   437  	}
   438  	if err := wc.fns.init(sanityCheckVal); err != nil {
   439  		return err
   440  	}
   441  	pSigOffset := wc.allocInt()
   442  	pFr := wc.allocFr()
   443  
   444  	for inputName, inputValue := range inputs {
   445  		hMSB, hLSB := fnvHash(inputName)
   446  		wc.fns.getSignalOffset32(pSigOffset, 0, hMSB, hLSB)
   447  		sigOffset := wc.getInt(pSigOffset)
   448  		fSlice := flatSlice(inputValue)
   449  		for i, value := range fSlice {
   450  			wc.storeFr(pFr, value)
   451  			wc.fns.setSignal(0, 0, sigOffset+int32(i), pFr)
   452  		}
   453  	}
   454  
   455  	return nil
   456  }
   457  
   458  // CalculateWitness calculates the witness given the inputs.
   459  func (wc *WitnessCalculator) CalculateWitness(inputs map[string]interface{}, sanityCheck bool) ([]*big.Int, error) {
   460  	oldMemFreePos := wc.memFreePos()
   461  
   462  	if err := wc.doCalculateWitness(inputs, sanityCheck); err != nil {
   463  		return nil, err
   464  	}
   465  
   466  	w := make([]*big.Int, wc.nVars)
   467  	for i := int32(0); i < wc.nVars; i++ {
   468  		pWitness, err := wc.fns.getPWitness(i)
   469  		if err != nil {
   470  			return nil, err
   471  		}
   472  		w[i] = wc.loadFr(pWitness)
   473  	}
   474  
   475  	wc.setMemFreePos(oldMemFreePos)
   476  	return w, nil
   477  }
   478  
   479  // CalculateWitness calculates the witness in binary given the inputs.
   480  func (wc *WitnessCalculator) CalculateBinWitness(inputs map[string]interface{}, sanityCheck bool) ([]byte, error) {
   481  	oldMemFreePos := wc.memFreePos()
   482  
   483  	if err := wc.doCalculateWitness(inputs, sanityCheck); err != nil {
   484  		return nil, err
   485  	}
   486  	pWitnessBuff, err := wc.fns.getWitnessBuffer()
   487  	if err != nil {
   488  		return nil, err
   489  	}
   490  	witnessBuff := make([]byte, uint(wc.nVars)*wc.n64*8)
   491  	copy(witnessBuff, wc.runtime.Memory()[pWitnessBuff:int(pWitnessBuff)+len(witnessBuff)])
   492  
   493  	wc.setMemFreePos(oldMemFreePos)
   494  	return witnessBuff, nil
   495  }