github.com/n1ghtfa1l/go-vnt@v0.6.4-alpha.6/core/wavm/tests/env_test.go (about)

     1  package tests
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"math/big"
     9  	"path/filepath"
    10  	"strconv"
    11  	"strings"
    12  	"testing"
    13  	T "time"
    14  
    15  	"github.com/vntchain/go-vnt/accounts/abi"
    16  	"github.com/vntchain/go-vnt/common"
    17  	"github.com/vntchain/go-vnt/core"
    18  	"github.com/vntchain/go-vnt/core/state"
    19  	"github.com/vntchain/go-vnt/core/vm"
    20  	errorsmsg "github.com/vntchain/go-vnt/core/vm"
    21  	inter "github.com/vntchain/go-vnt/core/vm/interface"
    22  	"github.com/vntchain/go-vnt/core/wavm"
    23  	wasmContract "github.com/vntchain/go-vnt/core/wavm/contract"
    24  	"github.com/vntchain/go-vnt/log"
    25  	"github.com/vntchain/go-vnt/params"
    26  	"github.com/vntchain/go-vnt/vntdb"
    27  )
    28  
    29  var envJsonPath = filepath.Join("", "env.json")
    30  
    31  type ENVTest struct {
    32  	json          vmJSON
    33  	statedb       *state.StateDB
    34  	createCost    float64
    35  	callCost      float64
    36  	compileCost   float64
    37  	nocompileCost float64
    38  	createRunCost float64
    39  	callRunCost   float64
    40  }
    41  
    42  func (t *ENVTest) UnmarshalJSON(data []byte) error {
    43  	err := json.Unmarshal(data, &t.json)
    44  	if err != nil {
    45  		return err
    46  	}
    47  	return nil
    48  }
    49  
    50  func parseInput(args []argument) []interface{} {
    51  	var input []interface{}
    52  	for _, v := range args {
    53  		input = append(input, parseData(v.Data, v.DataType))
    54  	}
    55  	return input
    56  }
    57  
    58  func parseData(data string, dataType string) interface{} {
    59  	var parse interface{}
    60  	var err error
    61  	switch dataType {
    62  	case "uint32":
    63  		var res uint64
    64  		res, err = strconv.ParseUint(data, 10, 32)
    65  		if err == nil {
    66  			parse = uint32(res)
    67  		} else {
    68  			fmt.Printf("err %s\n", err.Error())
    69  		}
    70  	case "int32":
    71  		var res int64
    72  		res, err = strconv.ParseInt(data, 10, 32)
    73  		if err == nil {
    74  			parse = int32(res)
    75  		}
    76  	case "uint64":
    77  		var res uint64
    78  		res, err = strconv.ParseUint(data, 10, 64)
    79  		if err == nil {
    80  			parse = uint64(res)
    81  		}
    82  	case "int64":
    83  		var res int64
    84  		res, err = strconv.ParseInt(data, 10, 64)
    85  		if err == nil {
    86  			parse = int64(res)
    87  		}
    88  	case "uint256":
    89  		bigint := new(big.Int)
    90  		_, flag := bigint.SetString(data, 10)
    91  		if flag == false {
    92  			panic("Illegal uint256 input " + data)
    93  		}
    94  		parse = bigint
    95  	case "string":
    96  		parse = data
    97  	case "address":
    98  		if data[0:2] != "0x" {
    99  			parse = common.BytesToAddress([]byte(data))
   100  		} else {
   101  			parse = common.HexToAddress(data)
   102  		}
   103  	case "bool":
   104  		if data == "true" {
   105  			parse = true
   106  		} else {
   107  			parse = false
   108  		}
   109  	default:
   110  		err = errors.New(fmt.Sprintf("unsupport data type %s", dataType))
   111  	}
   112  	if err != nil {
   113  		panic(err)
   114  	}
   115  	return parse
   116  }
   117  
   118  func packInput(abiobj abi.ABI, name string, args ...interface{}) []byte {
   119  	if name == "testfunctionnoexist" {
   120  		return []byte("")
   121  	}
   122  	abires := abiobj
   123  	var res []byte
   124  	var err error
   125  	if len(args) == 0 {
   126  		res, err = abires.Pack(name)
   127  	} else {
   128  		res, err = abires.Pack(name, args...)
   129  	}
   130  	if err != nil {
   131  		panic(err)
   132  	}
   133  	return res
   134  }
   135  
   136  func unpackOutput(abiobj abi.ABI, v interface{}, name string, output []byte) interface{} {
   137  	abires := abiobj
   138  	err := abires.Unpack(v, name, output)
   139  	if err != nil {
   140  		panic(err)
   141  	}
   142  	return v
   143  }
   144  
   145  func (t *ENVTest) newWAVM(statedb *state.StateDB, vmconfig vm.Config) vm.VM {
   146  	canTransfer := func(db inter.StateDB, address common.Address, amount *big.Int) bool {
   147  		return core.CanTransfer(db, address, amount)
   148  	}
   149  	transfer := func(db inter.StateDB, sender, recipient common.Address, amount *big.Int) {
   150  		core.Transfer(db, sender, recipient, amount)
   151  	}
   152  	context := vm.Context{
   153  		CanTransfer: canTransfer,
   154  		Transfer:    transfer,
   155  		GetHash:     vmTestBlockHash,
   156  		Origin:      t.json.Exec.Origin,
   157  		Coinbase:    t.json.Env.Coinbase,
   158  		BlockNumber: new(big.Int).SetUint64(t.json.Env.Number),
   159  		Time:        new(big.Int).SetUint64(t.json.Env.Timestamp),
   160  		GasLimit:    t.json.Env.GasLimit,
   161  		Difficulty:  t.json.Env.Difficulty,
   162  		GasPrice:    t.json.Exec.GasPrice,
   163  	}
   164  	return wavm.NewWAVM(context, statedb, params.AllCliqueProtocolChanges, vmconfig)
   165  }
   166  
   167  func (t *ENVTest) getStateDb() {
   168  	if t.statedb == nil {
   169  		db := vntdb.NewMemDatabase()
   170  		statedb := MakePreState(db, t.json.Pre)
   171  		t.statedb = statedb
   172  	}
   173  }
   174  
   175  func (t *ENVTest) Run(vmconfig vm.Config, data []byte, iscreate bool, needinit bool, test *testing.T) ([]byte, error) {
   176  	t.getStateDb()
   177  	// now := T.Now()
   178  	ret, _, err := t.exec(t.statedb, vmconfig, data, iscreate, needinit)
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  	// duration := T.Since(now)
   183  	// test.Logf("time duration %f", duration.Seconds())
   184  	// t.timeUsed += duration.Seconds()
   185  
   186  	// if t.json.GasRemaining == nil {
   187  	// 	if err == nil {
   188  	// 		return fmt.Errorf("gas unspecified (indicating an error), but VM returned no error")
   189  	// 	}
   190  	// 	if gasRemaining > 0 {
   191  	// 		return fmt.Errorf("gas unspecified (indicating an error), but VM returned gas remaining > 0")
   192  	// 	}
   193  	// 	return nil
   194  	// }
   195  	// Test declares gas, expecting outputs to match.
   196  	// if !bytes.Equal(ret, t.json.Out) {
   197  	// 	return fmt.Errorf("return data mismatch: got %x, want %x", ret, t.json.Out)
   198  	// }
   199  	// if gasRemaining != uint64(*t.json.GasRemaining) {
   200  	// 	return fmt.Errorf("remaining gas %v, want %v", gasRemaining, *t.json.GasRemaining)
   201  	// }
   202  	// for addr, account := range t.json.Post {
   203  	// 	for k, wantV := range account.Storage {
   204  	// 		if haveV := statedb.GetState(addr, k); haveV != wantV {
   205  	// 			return fmt.Errorf("wrong storage value at %x:\n  got  %x\n  want %x", k, haveV, wantV)
   206  	// 		}
   207  	// 	}
   208  	// }
   209  	// if root := statedb.IntermediateRoot(false); root != t.json.PostStateRoot {
   210  	// 	return fmt.Errorf("post state root mismatch, got %x, want %x", root, t.json.PostStateRoot)
   211  	// }
   212  	// if logs := rlpHash(statedb.Logs()); logs != common.Hash(t.json.Logs) {
   213  	// 	return fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, t.json.Logs)
   214  	// }
   215  	return ret, nil
   216  }
   217  
   218  func (t *ENVTest) exec(statedb *state.StateDB, vmconfig vm.Config, data []byte, isCreated bool, needinit bool) ([]byte, uint64, error) {
   219  	wavmobj := t.newWAVM(statedb, vmconfig)
   220  	e := t.json.Exec
   221  	if isCreated {
   222  		now := T.Now()
   223  		res, addr, gas, err := wavmobj.Create(vm.AccountRef(e.Caller), data, e.GasLimit, e.Value)
   224  		duration := T.Since(now)
   225  		t.createCost += duration.Seconds()
   226  		if needinit == true {
   227  			t.json.Exec.Address = addr
   228  		}
   229  		// t.compileCost += wavmobj.(*wavm.WAVM).Wavm.VM.CompileTimeCost
   230  		// t.createRunCost += wavmobj.(*wavm.WAVM).CreateTimeCost
   231  		fmt.Printf("create gas cost %d\n", gas)
   232  		return res, gas, err
   233  	} else {
   234  		now := T.Now()
   235  		res, gas, err := wavmobj.Call(vm.AccountRef(e.Caller), e.Address, data, e.GasLimit, e.Value)
   236  		duration := T.Since(now)
   237  		t.callCost += duration.Seconds()
   238  		// t.nocompileCost += wavmobj.(*wavm.WAVM).Wavm.VM.NoCompileTimeCost
   239  		// t.callRunCost += wavmobj.(*wavm.WAVM).CallTimeCost
   240  		return res, gas, err
   241  	}
   242  
   243  }
   244  
   245  func run(t *testing.T, jspath string) {
   246  	jsonfile, err := ioutil.ReadFile(jspath)
   247  	if err != nil {
   248  		t.Fatalf(err.Error())
   249  	}
   250  	vmconfig := vm.Config{Debug: true, Tracer: wavm.NewWasmLogger(&vm.LogConfig{Debug: true})}
   251  	envtest := new(ENVTest)
   252  	envtest.callCost = 0
   253  	envtest.createCost = 0
   254  	err = envtest.UnmarshalJSON(jsonfile)
   255  	if err != nil {
   256  		t.Fatalf(err.Error())
   257  	}
   258  
   259  	for i := 0; i < 1; i++ {
   260  		for _, v := range envtest.json.TestCase {
   261  
   262  			//init
   263  			code := wasmContract.WasmCode{}
   264  			code.Code = readFile(filepath.Join(v.Code))
   265  			code.Abi = readFile(filepath.Join(v.Abi))
   266  			parseinput := parseInput(v.InitCase.Input)
   267  			input := packInput(getABI(filepath.Join(v.Abi)), "", parseinput...)
   268  			c := append(code.Code, input...)
   269  			// fmt.Printf(hex.EncodeToString(c))
   270  			// pre := envtest.json.Pre[envtest.json.Exec.Address]
   271  			// pre.Code = c //[]byte(hexutil.Encode(c))
   272  			// envtest.json.Pre[envtest.json.Exec.Address] = pre
   273  			ret, err := envtest.Run(vmconfig, c, true, v.InitCase.NeedInit, t)
   274  			if err != nil {
   275  				t.Fatalf(err.Error())
   276  			}
   277  			if v.InitCase.NeedInit == false {
   278  				account := envtest.json.Pre[envtest.json.Exec.Address]
   279  				account.Code = ret
   280  				envtest.json.Pre[envtest.json.Exec.Address] = account
   281  				envtest.statedb = nil
   282  			}
   283  			// if v.InitCase.NeedInit == true {
   284  			// 	code := wasmContract.WasmCode{}
   285  			// 	code.Code = readFile(filepath.Join(v.Code))
   286  			// 	code.Abi = readFile(filepath.Join(v.Abi))
   287  			// 	parseinput := parseInput(v.InitCase.Input)
   288  			// 	input := packInput(getABI(filepath.Join(v.Abi)), "", parseinput...)
   289  			// 	c := append(code.Code, input...)
   290  			// 	// pre := envtest.json.Pre[envtest.json.Exec.Address]
   291  			// 	// pre.Code = c //[]byte(hexutil.Encode(c))
   292  			// 	// envtest.json.Pre[envtest.json.Exec.Address] = pre
   293  			// 	_, err = envtest.Run(vmconfig, c, true, t)
   294  			// 	if err != nil {
   295  			// 		t.Fatalf(err.Error())
   296  			// 	}
   297  			// } else {
   298  
   299  			// }
   300  
   301  			for _, testcase := range v.Tests {
   302  				var pack []byte
   303  				abiobj := getABI(filepath.Join(v.Abi))
   304  				if testcase.RawInput == nil {
   305  					input := parseInput(testcase.Input)
   306  					pack = packInput(abiobj, testcase.Function, input...)
   307  				} else {
   308  					pack = testcase.RawInput
   309  				}
   310  
   311  				ret, err := envtest.Run(vmconfig, pack, false, v.InitCase.NeedInit, t)
   312  				if err != nil {
   313  					if testcase.Error == err.Error() && testcase.Error != "" {
   314  						t.Logf("funcName %s\n", testcase.Function)
   315  						t.Logf("wavm err match, got %s, want %s", err, testcase.Error)
   316  						continue
   317  					} else {
   318  						if strings.HasPrefix(err.Error(), errorsmsg.ErrExecutionAssert.Error()) {
   319  							t.Logf("%s", err.Error())
   320  						} else if err.Error() == errorsmsg.ErrExecutionReverted.Error() {
   321  							t.Logf("%s", errorsmsg.ErrExecutionReverted)
   322  						} else {
   323  							t.Fatal(err)
   324  						}
   325  					}
   326  
   327  				}
   328  				verify(t, ret, testcase.Wanted, abiobj, testcase.Function)
   329  				// if testcase.Event != nil {
   330  				// 	fmt.Printf("logs %s\n", rlpHash(envtest.statedb.Logs()).Hex())
   331  				// 	fmt.Printf("logs %+v\n", envtest.statedb.Logs())
   332  				// 	res := envtest.statedb.Logs()[0].Data
   333  				// 	fmt.Printf("data %v\n", res)
   334  				// 	type testevent struct {
   335  				// 		Str  string
   336  				// 		Addr common.Address
   337  				// 		U64  uint64
   338  				// 		U32  uint32
   339  				// 		I64  int64
   340  				// 		I32  int32
   341  				// 		U256 *big.Int
   342  				// 		B    bool
   343  				// 	}
   344  				// 	var test1 testevent
   345  				// 	err := abiobj.Unpack(&test1, "TESTEVENT", res)
   346  				// 	if err != nil {
   347  				// 		panic(err)
   348  				// 	}
   349  				// 	fmt.Printf("test1 %+v\n", test1)
   350  				// 	// verifyEvent(t, ret, testcase.Event, abiobj, testcase.Function)
   351  				// }
   352  			}
   353  
   354  		}
   355  	}
   356  
   357  	t.Logf("create cost %f", envtest.createCost/10000.0)
   358  	t.Logf("call cost %f", envtest.callCost/10000.0)
   359  	t.Logf("compile cost %f", envtest.compileCost/10000.0)
   360  	t.Logf("nocompile cost %f", envtest.nocompileCost/10000.0)
   361  	t.Logf("create run cost %f", envtest.compileCost/10000.0)
   362  	t.Logf("call run cost %f", envtest.compileCost/10000.0)
   363  
   364  }
   365  
   366  func verify(t *testing.T, ret []byte, wanted argument, abiobj abi.ABI, funcName string) {
   367  	t.Logf("funcName %s\n", funcName)
   368  	data := wanted.Data
   369  	dataType := wanted.DataType
   370  	if wanted.DataType == "" {
   371  		return
   372  	}
   373  	parse := parseData(data, dataType)
   374  	switch dataType {
   375  	case "uint32":
   376  		want := parse.(uint32)
   377  		var got uint32
   378  		unpackOutput(abiobj, &got, funcName, ret)
   379  		if got != want {
   380  			t.Fatalf("wavm result mismatch, got %d, want %d", got, want)
   381  		} else {
   382  			t.Logf("wavm result match, got %d, want %d", got, want)
   383  		}
   384  	case "int32":
   385  		want := parse.(int32)
   386  		var got int32
   387  		unpackOutput(abiobj, &got, funcName, ret)
   388  		if got != want {
   389  			t.Fatalf("wavm result mismatch, got %d, want %d", got, want)
   390  		} else {
   391  			t.Logf("wavm result match, got %d, want %d", got, want)
   392  		}
   393  	case "uint64":
   394  		want := parse.(uint64)
   395  		var got uint64
   396  		log.Debug("111", "funcName", funcName, "ret", ret)
   397  		unpackOutput(abiobj, &got, funcName, ret)
   398  		if got != want {
   399  			t.Fatalf("wavm result mismatch, got %d, want %d", got, want)
   400  		} else {
   401  			t.Logf("wavm result match, got %d, want %d", got, want)
   402  		}
   403  	case "int64":
   404  		want := parse.(int64)
   405  		var got int64
   406  		unpackOutput(abiobj, &got, funcName, ret)
   407  		if got != want {
   408  			t.Fatalf("wavm result mismatch, got %d, want %d", got, want)
   409  		} else {
   410  			t.Logf("wavm result match, got %d, want %d", got, want)
   411  		}
   412  	case "uint256":
   413  		want := parse.(*big.Int)
   414  		var got *big.Int
   415  		unpackOutput(abiobj, &got, funcName, ret)
   416  		if got.Cmp(want) != 0 {
   417  			t.Fatalf("wavm result mismatch, got %d, want %d", got, want)
   418  		} else {
   419  			t.Logf("wavm result match, got %d, want %d", got, want)
   420  		}
   421  	case "string":
   422  		want := parse.(string)
   423  		var got string
   424  		unpackOutput(abiobj, &got, funcName, ret)
   425  		if got != want {
   426  			t.Fatalf("wavm result mismatch, got %s, want %s", got, want)
   427  		} else {
   428  			t.Logf("wavm result match, got %s, want %s", got, want)
   429  		}
   430  	case "address":
   431  		want := parse.(common.Address)
   432  		var got common.Address
   433  		unpackOutput(abiobj, &got, funcName, ret)
   434  		if got != want {
   435  			t.Fatalf("wavm result mismatch, got %s, want %s", got.Hex(), want.Hex())
   436  		} else {
   437  			t.Logf("wavm result match, got %s, want %s", got.Hex(), want.Hex())
   438  		}
   439  	case "bool":
   440  		want := parse.(bool)
   441  		var got bool
   442  		unpackOutput(abiobj, &got, funcName, ret)
   443  		if got != want {
   444  			t.Fatalf("wavm result mismatch, got %t, want %t", got, want)
   445  		} else {
   446  			t.Logf("wavm result match, got %t, want %t", got, want)
   447  		}
   448  	}
   449  
   450  }
   451  
   452  func verifyEvent(t *testing.T, ret []byte, wanted argument, abiobj abi.ABI, funcName string) {
   453  	t.Logf("funcName %s\n", funcName)
   454  	data := wanted.Data
   455  	dataType := wanted.DataType
   456  	if wanted.DataType == "" {
   457  		return
   458  	}
   459  	parse := parseData(data, dataType)
   460  	switch dataType {
   461  	case "uint32":
   462  		want := parse.(uint32)
   463  		var got uint32
   464  		unpackOutput(abiobj, &got, funcName, ret)
   465  		if got != want {
   466  			t.Fatalf("wavm result mismatch, got %d, want %d", got, want)
   467  		} else {
   468  			t.Logf("wavm result match, got %d, want %d", got, want)
   469  		}
   470  	case "int32":
   471  		want := parse.(int32)
   472  		var got int32
   473  		unpackOutput(abiobj, &got, funcName, ret)
   474  		if got != want {
   475  			t.Fatalf("wavm result mismatch, got %d, want %d", got, want)
   476  		} else {
   477  			t.Logf("wavm result match, got %d, want %d", got, want)
   478  		}
   479  	case "uint64":
   480  		want := parse.(uint64)
   481  		var got uint64
   482  		unpackOutput(abiobj, &got, funcName, ret)
   483  		if got != want {
   484  			t.Fatalf("wavm result mismatch, got %d, want %d", got, want)
   485  		} else {
   486  			t.Logf("wavm result match, got %d, want %d", got, want)
   487  		}
   488  	case "int64":
   489  		want := parse.(int64)
   490  		var got int64
   491  		unpackOutput(abiobj, &got, funcName, ret)
   492  		if got != want {
   493  			t.Fatalf("wavm result mismatch, got %d, want %d", got, want)
   494  		} else {
   495  			t.Logf("wavm result match, got %d, want %d", got, want)
   496  		}
   497  	case "uint256":
   498  		want := parse.(*big.Int)
   499  		var got *big.Int
   500  		unpackOutput(abiobj, &got, funcName, ret)
   501  		if got.Cmp(want) != 0 {
   502  			t.Fatalf("wavm result mismatch, got %d, want %d", got, want)
   503  		} else {
   504  			t.Logf("wavm result match, got %d, want %d", got, want)
   505  		}
   506  	case "string":
   507  		want := parse.(string)
   508  		var got string
   509  		unpackOutput(abiobj, &got, funcName, ret)
   510  		if got != want {
   511  			t.Fatalf("wavm result mismatch, got %s, want %s", got, want)
   512  		} else {
   513  			t.Logf("wavm result match, got %s, want %s", got, want)
   514  		}
   515  	case "address":
   516  		want := parse.(common.Address)
   517  		var got common.Address
   518  		unpackOutput(abiobj, &got, funcName, ret)
   519  		if got != want {
   520  			t.Fatalf("wavm result mismatch, got %s, want %s", got.Hex(), want.Hex())
   521  		} else {
   522  			t.Logf("wavm result match, got %s, want %s", got.Hex(), want.Hex())
   523  		}
   524  	case "bool":
   525  		want := parse.(bool)
   526  		var got bool
   527  		unpackOutput(abiobj, &got, funcName, ret)
   528  		if got != want {
   529  			t.Fatalf("wavm result mismatch, got %t, want %t", got, want)
   530  		} else {
   531  			t.Logf("wavm result match, got %t, want %t", got, want)
   532  		}
   533  	}
   534  
   535  }
   536  
   537  func TestEnv(t *testing.T) {
   538  	run(t, envJsonPath)
   539  }