github.com/aergoio/aergo@v1.3.1/contract/vm_dummy.go (about)

     1  package contract
     2  
     3  // helper functions
     4  import (
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"math/big"
    10  	"os"
    11  	"path"
    12  	"regexp"
    13  	"strconv"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/aergoio/aergo-lib/db"
    18  	luac_util "github.com/aergoio/aergo/cmd/aergoluac/util"
    19  	"github.com/aergoio/aergo/contract/system"
    20  	"github.com/aergoio/aergo/state"
    21  	"github.com/aergoio/aergo/types"
    22  	"github.com/minio/sha256-simd"
    23  )
    24  
    25  type DummyChain struct {
    26  	sdb           *state.ChainStateDB
    27  	bestBlock     *types.Block
    28  	cBlock        *types.Block
    29  	bestBlockNo   types.BlockNo
    30  	bestBlockId   types.BlockID
    31  	blockIds      []types.BlockID
    32  	blocks        []*types.Block
    33  	testReceiptDB db.DB
    34  	tmpDir        string
    35  }
    36  
    37  var addressRegexp *regexp.Regexp
    38  var traceState bool
    39  
    40  func init() {
    41  	addressRegexp, _ = regexp.Compile("^[a-zA-Z0-9]+$")
    42  	//	traceState = true
    43  }
    44  
    45  func LoadDummyChain() (*DummyChain, error) {
    46  	dataPath, err := ioutil.TempDir("", "data")
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  	bc := &DummyChain{
    51  		sdb:    state.NewChainStateDB(),
    52  		tmpDir: dataPath,
    53  	}
    54  	defer func() {
    55  		if err != nil {
    56  			bc.Release()
    57  		}
    58  	}()
    59  
    60  	err = bc.sdb.Init(string(db.BadgerImpl), dataPath, nil, false)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	genesis := types.GetTestGenesis()
    65  	bc.sdb.SetGenesis(genesis, nil)
    66  	bc.bestBlockNo = genesis.Block().BlockNo()
    67  	bc.bestBlockId = genesis.Block().BlockID()
    68  	bc.blockIds = append(bc.blockIds, bc.bestBlockId)
    69  	bc.blocks = append(bc.blocks, genesis.Block())
    70  	bc.testReceiptDB = db.NewDB(db.BadgerImpl, path.Join(dataPath, "receiptDB"))
    71  	LoadTestDatabase(dataPath) // sql database
    72  	SetStateSQLMaxDBSize(1024)
    73  	StartLStateFactory()
    74  
    75  	// To pass the governance tests.
    76  	types.InitGovernance("dpos", true)
    77  	system.InitGovernance("dpos")
    78  
    79  	return bc, nil
    80  }
    81  
    82  func (bc *DummyChain) Release() {
    83  	bc.testReceiptDB.Close()
    84  	_ = os.RemoveAll(bc.tmpDir)
    85  }
    86  
    87  func (bc *DummyChain) BestBlockNo() uint64 {
    88  	return bc.bestBlockNo
    89  }
    90  
    91  func (bc *DummyChain) newBState() *state.BlockState {
    92  	b := types.Block{
    93  		Header: &types.BlockHeader{
    94  			PrevBlockHash: bc.bestBlockId[:],
    95  			BlockNo:       bc.bestBlockNo + 1,
    96  			Timestamp:     time.Now().UnixNano(),
    97  		},
    98  	}
    99  	bc.cBlock = &b
   100  	// blockInfo := types.NewBlockInfo(b.BlockNo(), b.BlockID(), bc.bestBlockId)
   101  	return state.NewBlockState(bc.sdb.OpenNewStateDB(bc.sdb.GetRoot()))
   102  }
   103  
   104  func (bc *DummyChain) BeginReceiptTx() db.Transaction {
   105  	return bc.testReceiptDB.NewTx()
   106  }
   107  
   108  func (bc *DummyChain) GetABI(contract string) (*types.ABI, error) {
   109  	cState, err := bc.sdb.GetStateDB().OpenContractStateAccount(types.ToAccountID(strHash(contract)))
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	return GetABI(cState)
   114  }
   115  
   116  func (bc *DummyChain) GetEvents(tx *luaTxCall) []*types.Event {
   117  	h := sha256.New()
   118  	h.Write([]byte(strconv.FormatUint(tx.id, 10)))
   119  	b := h.Sum(nil)
   120  
   121  	receipt := bc.getReceipt(b)
   122  	if receipt != nil {
   123  		return receipt.Events
   124  	}
   125  
   126  	return nil
   127  }
   128  
   129  func (bc *DummyChain) getReceipt(txHash []byte) *types.Receipt {
   130  	r := new(types.Receipt)
   131  	r.UnmarshalBinary(bc.testReceiptDB.Get(txHash))
   132  	return r
   133  }
   134  
   135  func (bc *DummyChain) GetAccountState(name string) (*types.State, error) {
   136  	return bc.sdb.GetStateDB().GetAccountState(types.ToAccountID(strHash(name)))
   137  }
   138  
   139  func (bc *DummyChain) GetStaking(name string) (*types.Staking, error) {
   140  	scs, err := bc.sdb.GetStateDB().OpenContractStateAccount(types.ToAccountID([]byte(types.AergoSystem)))
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  	return system.GetStaking(scs, strHash(name))
   145  }
   146  
   147  func (bc *DummyChain) GetBlockByNo(blockNo types.BlockNo) (*types.Block, error) {
   148  	return bc.blocks[blockNo], nil
   149  }
   150  
   151  func (bc *DummyChain) GetBestBlock() (*types.Block, error) {
   152  	return bc.bestBlock, nil
   153  }
   154  
   155  type luaTx interface {
   156  	run(bs *state.BlockState, bc *DummyChain, blockNo uint64, ts int64, prevBlockHash []byte, receiptTx db.Transaction) error
   157  }
   158  
   159  type luaTxAccount struct {
   160  	name    []byte
   161  	balance *big.Int
   162  }
   163  
   164  func NewLuaTxAccount(name string, balance uint64) *luaTxAccount {
   165  	return &luaTxAccount{
   166  		name:    strHash(name),
   167  		balance: new(big.Int).SetUint64(balance),
   168  	}
   169  }
   170  
   171  func NewLuaTxAccountBig(name string, balance *big.Int) *luaTxAccount {
   172  	return &luaTxAccount{
   173  		name:    strHash(name),
   174  		balance: balance,
   175  	}
   176  }
   177  
   178  func (l *luaTxAccount) run(bs *state.BlockState, bc *DummyChain, blockNo uint64, ts int64, prevBlockHash []byte,
   179  	receiptTx db.Transaction) error {
   180  
   181  	id := types.ToAccountID(l.name)
   182  	accountState, err := bs.GetAccountState(id)
   183  	if err != nil {
   184  		return err
   185  	}
   186  	updatedAccountState := types.State(*accountState)
   187  	updatedAccountState.Balance = l.balance.Bytes()
   188  	bs.PutState(id, &updatedAccountState)
   189  	return nil
   190  }
   191  
   192  type luaTxSend struct {
   193  	sender   []byte
   194  	receiver []byte
   195  	balance  *big.Int
   196  }
   197  
   198  func NewLuaTxSendBig(sender, receiver string, balance *big.Int) *luaTxSend {
   199  	return &luaTxSend{
   200  		sender:   strHash(sender),
   201  		receiver: strHash(receiver),
   202  		balance:  balance,
   203  	}
   204  }
   205  
   206  func (l *luaTxSend) run(bs *state.BlockState, bc *DummyChain, blockNo uint64, ts int64, prevBlockHash []byte,
   207  	receiptTx db.Transaction) error {
   208  
   209  	senderID := types.ToAccountID(l.sender)
   210  	receiverID := types.ToAccountID(l.receiver)
   211  
   212  	if senderID == receiverID {
   213  		return fmt.Errorf("sender and receiever cannot be same")
   214  	}
   215  
   216  	senderState, err := bs.GetAccountState(senderID)
   217  	if err != nil {
   218  		return err
   219  	} else if senderState.GetBalanceBigInt().Cmp(l.balance) < 0 {
   220  		return fmt.Errorf("insufficient balance to sender")
   221  	}
   222  	receiverState, err := bs.GetAccountState(receiverID)
   223  	if err != nil {
   224  		return err
   225  	}
   226  
   227  	updatedSenderState := types.State(*senderState)
   228  	updatedSenderState.Balance = new(big.Int).Sub(updatedSenderState.GetBalanceBigInt(), l.balance).Bytes()
   229  	bs.PutState(senderID, &updatedSenderState)
   230  
   231  	updatedReceiverState := types.State(*receiverState)
   232  	updatedReceiverState.Balance = new(big.Int).Add(updatedReceiverState.GetBalanceBigInt(), l.balance).Bytes()
   233  	bs.PutState(receiverID, &updatedReceiverState)
   234  
   235  	return nil
   236  }
   237  
   238  type luaTxCommon struct {
   239  	sender   []byte
   240  	contract []byte
   241  	amount   *big.Int
   242  	code     []byte
   243  	id       uint64
   244  }
   245  
   246  type luaTxDef struct {
   247  	luaTxCommon
   248  	cErr error
   249  }
   250  
   251  func NewLuaTxDef(sender, contract string, amount uint64, code string) *luaTxDef {
   252  	L := luac_util.NewLState()
   253  	if L == nil {
   254  		return &luaTxDef{cErr: newVmStartError()}
   255  	}
   256  	defer luac_util.CloseLState(L)
   257  	b, err := luac_util.Compile(L, code)
   258  	if err != nil {
   259  		return &luaTxDef{cErr: err}
   260  	}
   261  	codeWithInit := make([]byte, 4+len(b))
   262  	binary.LittleEndian.PutUint32(codeWithInit, uint32(4+len(b)))
   263  	copy(codeWithInit[4:], b)
   264  	return &luaTxDef{
   265  		luaTxCommon: luaTxCommon{
   266  			sender:   strHash(sender),
   267  			contract: strHash(contract),
   268  			code:     codeWithInit,
   269  			amount:   new(big.Int).SetUint64(amount),
   270  			id:       newTxId(),
   271  		},
   272  		cErr: nil,
   273  	}
   274  }
   275  
   276  func getCompiledABI(code string) ([]byte, error) {
   277  
   278  	L := luac_util.NewLState()
   279  	if L == nil {
   280  		return nil, newVmStartError()
   281  	}
   282  	defer luac_util.CloseLState(L)
   283  	b, err := luac_util.Compile(L, code)
   284  	if err != nil {
   285  		return nil, err
   286  	}
   287  
   288  	codeLen := binary.LittleEndian.Uint32(b[:4])
   289  
   290  	return b[4+codeLen:], nil
   291  }
   292  
   293  func NewRawLuaTxDefBig(sender, contract string, amount *big.Int, code string) *luaTxDef {
   294  
   295  	byteAbi, err := getCompiledABI(code)
   296  	if err != nil {
   297  		return &luaTxDef{cErr: err}
   298  	}
   299  
   300  	byteCode := []byte(code)
   301  	payload := make([]byte, 8+len(byteCode)+len(byteAbi))
   302  	binary.LittleEndian.PutUint32(payload[0:], uint32(len(byteCode)+len(byteAbi)+8))
   303  	binary.LittleEndian.PutUint32(payload[4:], uint32(len(byteCode)))
   304  	codeLen := copy(payload[8:], byteCode)
   305  	copy(payload[8+codeLen:], byteAbi)
   306  
   307  	return &luaTxDef{
   308  		luaTxCommon: luaTxCommon{
   309  			sender:   strHash(sender),
   310  			contract: strHash(contract),
   311  			code:     payload,
   312  			amount:   amount,
   313  			id:       newTxId(),
   314  		},
   315  		cErr: nil,
   316  	}
   317  }
   318  
   319  func strHash(d string) []byte {
   320  	// using real address
   321  	if len(d) == types.EncodedAddressLength && addressRegexp.MatchString(d) {
   322  		return types.ToAddress(d)
   323  	} else {
   324  		// using alias
   325  		h := sha256.New()
   326  		h.Write([]byte(d))
   327  		b := h.Sum(nil)
   328  		b = append([]byte{0x0C}, b...)
   329  		return b
   330  	}
   331  }
   332  
   333  var luaTxId uint64 = 0
   334  
   335  func newTxId() uint64 {
   336  	luaTxId++
   337  	return luaTxId
   338  }
   339  
   340  func (l *luaTxDef) hash() []byte {
   341  	h := sha256.New()
   342  	h.Write([]byte(strconv.FormatUint(l.id, 10)))
   343  	b := h.Sum(nil)
   344  	return b
   345  }
   346  
   347  func (l *luaTxDef) Constructor(args string) *luaTxDef {
   348  	argsLen := len([]byte(args))
   349  	if argsLen == 0 || l.cErr != nil {
   350  		return l
   351  	}
   352  
   353  	code := make([]byte, len(l.code)+argsLen)
   354  	codeLen := copy(code[0:], l.code)
   355  	binary.LittleEndian.PutUint32(code[0:], uint32(codeLen))
   356  	copy(code[codeLen:], []byte(args))
   357  
   358  	l.code = code
   359  
   360  	return l
   361  }
   362  
   363  func contractFrame(l *luaTxCommon, bs *state.BlockState,
   364  	run func(s, c *state.V, id types.AccountID, cs *state.ContractState) error) error {
   365  
   366  	creatorId := types.ToAccountID(l.sender)
   367  	creatorState, err := bs.GetAccountStateV(l.sender)
   368  	if err != nil {
   369  		return err
   370  	}
   371  
   372  	contractId := types.ToAccountID(l.contract)
   373  	contractState, err := bs.GetAccountStateV(l.contract)
   374  	if err != nil {
   375  		return err
   376  	}
   377  
   378  	eContractState, err := bs.OpenContractState(contractId, contractState.State())
   379  	if err != nil {
   380  		return err
   381  	}
   382  
   383  	creatorState.SubBalance(l.amount)
   384  	contractState.AddBalance(l.amount)
   385  	err = run(creatorState, contractState, contractId, eContractState)
   386  	if err != nil {
   387  		return err
   388  	}
   389  
   390  	bs.PutState(creatorId, creatorState.State())
   391  	bs.PutState(contractId, contractState.State())
   392  	return nil
   393  
   394  }
   395  
   396  func (l *luaTxDef) run(bs *state.BlockState, bc *DummyChain, blockNo uint64, ts int64, prevBlockHash []byte,
   397  	receiptTx db.Transaction) error {
   398  
   399  	if l.cErr != nil {
   400  		return l.cErr
   401  	}
   402  
   403  	return contractFrame(&l.luaTxCommon, bs,
   404  		func(sender, contract *state.V, contractId types.AccountID, eContractState *state.ContractState) error {
   405  			contract.State().SqlRecoveryPoint = 1
   406  
   407  			stateSet := NewContext(bs, nil, sender, contract, eContractState, sender.ID(),
   408  				l.hash(), blockNo, ts, prevBlockHash, "", true,
   409  				false, contract.State().SqlRecoveryPoint, ChainService, l.luaTxCommon.amount)
   410  
   411  			if traceState {
   412  				stateSet.traceFile, _ =
   413  					os.OpenFile("test.trace", os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644)
   414  				defer stateSet.traceFile.Close()
   415  			}
   416  
   417  			_, _, _, err := Create(eContractState, l.code, l.contract, stateSet)
   418  			if err != nil {
   419  				return err
   420  			}
   421  			err = bs.StageContractState(eContractState)
   422  			if err != nil {
   423  				return err
   424  			}
   425  			return nil
   426  		},
   427  	)
   428  }
   429  
   430  type luaTxCall struct {
   431  	luaTxCommon
   432  	expectedErr string
   433  }
   434  
   435  func NewLuaTxCall(sender, contract string, amount uint64, code string) *luaTxCall {
   436  	return &luaTxCall{
   437  		luaTxCommon: luaTxCommon{
   438  			sender:   strHash(sender),
   439  			contract: strHash(contract),
   440  			amount:   new(big.Int).SetUint64(amount),
   441  			code:     []byte(code),
   442  			id:       newTxId(),
   443  		},
   444  	}
   445  }
   446  
   447  func NewLuaTxCallBig(sender, contract string, amount *big.Int, code string) *luaTxCall {
   448  	return &luaTxCall{
   449  		luaTxCommon: luaTxCommon{
   450  			sender:   strHash(sender),
   451  			contract: strHash(contract),
   452  			amount:   amount,
   453  			code:     []byte(code),
   454  			id:       newTxId(),
   455  		},
   456  	}
   457  }
   458  
   459  func (l *luaTxCall) hash() []byte {
   460  	h := sha256.New()
   461  	h.Write([]byte(strconv.FormatUint(l.id, 10)))
   462  	b := h.Sum(nil)
   463  	return b
   464  }
   465  
   466  func (l *luaTxCall) Fail(expectedErr string) *luaTxCall {
   467  	l.expectedErr = expectedErr
   468  	return l
   469  }
   470  
   471  func (l *luaTxCall) run(bs *state.BlockState, bc *DummyChain, blockNo uint64, ts int64, prevBlockHash []byte,
   472  	receiptTx db.Transaction) error {
   473  	err := contractFrame(&l.luaTxCommon, bs,
   474  		func(sender, contract *state.V, contractId types.AccountID, eContractState *state.ContractState) error {
   475  			stateSet := NewContext(bs, bc, sender, contract, eContractState, sender.ID(),
   476  				l.hash(), blockNo, ts, prevBlockHash, "", true,
   477  				false, contract.State().SqlRecoveryPoint, ChainService, l.luaTxCommon.amount)
   478  			if traceState {
   479  				stateSet.traceFile, _ =
   480  					os.OpenFile("test.trace", os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644)
   481  				defer stateSet.traceFile.Close()
   482  			}
   483  			rv, evs, _, err := Call(eContractState, l.code, l.contract, stateSet)
   484  			if err != nil {
   485  				r := types.NewReceipt(l.contract, err.Error(), "")
   486  				r.TxHash = l.hash()
   487  				b, _ := r.MarshalBinary()
   488  				receiptTx.Set(l.hash(), b)
   489  				return err
   490  			}
   491  			_ = bs.StageContractState(eContractState)
   492  			r := types.NewReceipt(l.contract, "SUCCESS", rv)
   493  			r.Events = evs
   494  			r.TxHash = l.hash()
   495  			blockHash := make([]byte, 32)
   496  			for _, ev := range evs {
   497  				ev.TxHash = r.TxHash
   498  				ev.BlockHash = blockHash
   499  			}
   500  			b, _ := r.MarshalBinary()
   501  			receiptTx.Set(l.hash(), b)
   502  			return nil
   503  		},
   504  	)
   505  	if l.expectedErr != "" {
   506  		if err == nil {
   507  			return fmt.Errorf("no error, expected: %s", l.expectedErr)
   508  		}
   509  		if !strings.Contains(err.Error(), l.expectedErr) {
   510  			return err
   511  		}
   512  		return nil
   513  	}
   514  	return err
   515  }
   516  
   517  func (bc *DummyChain) ConnectBlock(txs ...luaTx) error {
   518  	blockState := bc.newBState()
   519  	tx := bc.BeginReceiptTx()
   520  	defer tx.Commit()
   521  	defer CloseDatabase()
   522  
   523  	for _, x := range txs {
   524  		if err := x.run(blockState, bc, bc.cBlock.Header.BlockNo, bc.cBlock.Header.Timestamp,
   525  			bc.cBlock.Header.PrevBlockHash, tx); err != nil {
   526  			return err
   527  		}
   528  	}
   529  	err := SaveRecoveryPoint(blockState)
   530  	if err != nil {
   531  		return err
   532  	}
   533  	err = bc.sdb.Apply(blockState)
   534  	if err != nil {
   535  		return err
   536  	}
   537  	//FIXME newblock must be created after sdb.apply()
   538  	bc.cBlock.SetBlocksRootHash(bc.sdb.GetRoot())
   539  	bc.bestBlockNo = bc.bestBlockNo + 1
   540  	bc.bestBlock = bc.cBlock
   541  	bc.bestBlockId = types.ToBlockID(bc.cBlock.BlockHash())
   542  	bc.blockIds = append(bc.blockIds, bc.bestBlockId)
   543  	bc.blocks = append(bc.blocks, bc.cBlock)
   544  
   545  	return nil
   546  }
   547  
   548  func (bc *DummyChain) DisConnectBlock() error {
   549  	if len(bc.blockIds) == 1 {
   550  		return errors.New("genesis block")
   551  	}
   552  	bc.bestBlockNo--
   553  	bc.blockIds = bc.blockIds[0 : len(bc.blockIds)-1]
   554  	bc.blocks = bc.blocks[0 : len(bc.blocks)-1]
   555  	bc.bestBlockId = bc.blockIds[len(bc.blockIds)-1]
   556  
   557  	bestBlock := bc.blocks[len(bc.blocks)-1]
   558  
   559  	var sroot []byte
   560  	if bestBlock != nil {
   561  		sroot = bestBlock.GetHeader().GetBlocksRootHash()
   562  	}
   563  	return bc.sdb.SetRoot(sroot)
   564  }
   565  
   566  func (bc *DummyChain) Query(contract, queryInfo, expectedErr string, expectedRvs ...string) error {
   567  	cState, err := bc.sdb.GetStateDB().OpenContractStateAccount(types.ToAccountID(strHash(contract)))
   568  	if err != nil {
   569  		return err
   570  	}
   571  	rv, err := Query(strHash(contract), bc.newBState(), bc, cState, []byte(queryInfo))
   572  	if expectedErr != "" {
   573  		if err == nil {
   574  			return fmt.Errorf("no error, expected: %s", expectedErr)
   575  		}
   576  		if !strings.Contains(err.Error(), expectedErr) {
   577  			return err
   578  		}
   579  		return nil
   580  	}
   581  	if err != nil {
   582  		return err
   583  	}
   584  
   585  	for _, ev := range expectedRvs {
   586  		if ev != string(rv) {
   587  			err = fmt.Errorf("expected: %s, but got: %s", ev, string(rv))
   588  		} else {
   589  			return nil
   590  		}
   591  	}
   592  	return err
   593  }
   594  
   595  func (bc *DummyChain) QueryOnly(contract, queryInfo string, expectedErr string) (bool, string, error) {
   596  	cState, err := bc.sdb.GetStateDB().OpenContractStateAccount(types.ToAccountID(strHash(contract)))
   597  	if err != nil {
   598  		return false, "", err
   599  	}
   600  	rv, err := Query(strHash(contract), bc.newBState(), nil, cState, []byte(queryInfo))
   601  
   602  	if expectedErr != "" {
   603  		if err == nil {
   604  			return false, "", fmt.Errorf("no error, expected: %s", expectedErr)
   605  		}
   606  		if !strings.Contains(err.Error(), expectedErr) {
   607  			return false, "", err
   608  		}
   609  		return true, "", nil
   610  	}
   611  
   612  	if err != nil {
   613  		return false, "", err
   614  	}
   615  
   616  	return false, string(rv), nil
   617  }
   618  
   619  func StrToAddress(name string) string {
   620  	return types.EncodeAddress(strHash(name))
   621  }