github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/core/vm/logger_test.go (about)

     1  package vm
     2  
     3  import (
     4  	"math/big"
     5  	"testing"
     6  
     7  	"github.com/neatlab/neatio/chain/core/state"
     8  	"github.com/neatlab/neatio/params"
     9  	"github.com/neatlab/neatio/utilities/common"
    10  )
    11  
    12  type dummyContractRef struct {
    13  	calledForEach bool
    14  }
    15  
    16  func (dummyContractRef) ReturnGas(*big.Int)          {}
    17  func (dummyContractRef) Address() common.Address     { return common.Address{} }
    18  func (dummyContractRef) Value() *big.Int             { return new(big.Int) }
    19  func (dummyContractRef) SetCode(common.Hash, []byte) {}
    20  func (d *dummyContractRef) ForEachStorage(callback func(key, value common.Hash) bool) {
    21  	d.calledForEach = true
    22  }
    23  func (d *dummyContractRef) SubBalance(amount *big.Int) {}
    24  func (d *dummyContractRef) AddBalance(amount *big.Int) {}
    25  func (d *dummyContractRef) SetBalance(*big.Int)        {}
    26  func (d *dummyContractRef) SetNonce(uint64)            {}
    27  func (d *dummyContractRef) Balance() *big.Int          { return new(big.Int) }
    28  
    29  type dummyStatedb struct {
    30  	state.StateDB
    31  }
    32  
    33  func (*dummyStatedb) GetRefund() uint64 { return 1337 }
    34  
    35  func TestStoreCapture(t *testing.T) {
    36  	var (
    37  		env      = NewEVM(Context{}, &dummyStatedb{}, params.TestChainConfig, Config{})
    38  		logger   = NewStructLogger(nil)
    39  		mem      = NewMemory()
    40  		stack    = newstack()
    41  		contract = NewContract(&dummyContractRef{}, &dummyContractRef{}, new(big.Int), 0)
    42  	)
    43  	stack.push(big.NewInt(1))
    44  	stack.push(big.NewInt(0))
    45  	var index common.Hash
    46  	logger.CaptureState(env, 0, SSTORE, 0, 0, mem, stack, contract, 0, nil)
    47  	if len(logger.changedValues[contract.Address()]) == 0 {
    48  		t.Fatalf("expected exactly 1 changed value on address %x, got %d", contract.Address(), len(logger.changedValues[contract.Address()]))
    49  	}
    50  	exp := common.BigToHash(big.NewInt(1))
    51  	if logger.changedValues[contract.Address()][index] != exp {
    52  		t.Errorf("expected %x, got %x", exp, logger.changedValues[contract.Address()][index])
    53  	}
    54  }