github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/fvm/evm/emulator/state/stateDB_test.go (about)

     1  package state_test
     2  
     3  import (
     4  	"fmt"
     5  	"math/big"
     6  	"testing"
     7  
     8  	"github.com/onflow/atree"
     9  	gethCommon "github.com/onflow/go-ethereum/common"
    10  	gethTypes "github.com/onflow/go-ethereum/core/types"
    11  	gethParams "github.com/onflow/go-ethereum/params"
    12  	"github.com/stretchr/testify/require"
    13  
    14  	"github.com/onflow/flow-go/fvm/evm/emulator/state"
    15  	"github.com/onflow/flow-go/fvm/evm/testutils"
    16  	"github.com/onflow/flow-go/fvm/evm/types"
    17  	"github.com/onflow/flow-go/model/flow"
    18  )
    19  
    20  var rootAddr = flow.Address{1, 2, 3, 4, 5, 6, 7, 8}
    21  
    22  func TestStateDB(t *testing.T) {
    23  	t.Parallel()
    24  
    25  	t.Run("test Empty method", func(t *testing.T) {
    26  		ledger := testutils.GetSimpleValueStore()
    27  		db, err := state.NewStateDB(ledger, rootAddr)
    28  		require.NoError(t, err)
    29  
    30  		addr1 := testutils.RandomCommonAddress(t)
    31  		// non-existent account
    32  		require.True(t, db.Empty(addr1))
    33  		require.NoError(t, db.Error())
    34  
    35  		db.CreateAccount(addr1)
    36  		require.NoError(t, db.Error())
    37  
    38  		require.True(t, db.Empty(addr1))
    39  		require.NoError(t, db.Error())
    40  
    41  		db.AddBalance(addr1, big.NewInt(10))
    42  		require.NoError(t, db.Error())
    43  
    44  		require.False(t, db.Empty(addr1))
    45  	})
    46  
    47  	t.Run("test commit functionality", func(t *testing.T) {
    48  		ledger := testutils.GetSimpleValueStore()
    49  		db, err := state.NewStateDB(ledger, rootAddr)
    50  		require.NoError(t, err)
    51  
    52  		addr1 := testutils.RandomCommonAddress(t)
    53  		key1 := testutils.RandomCommonHash(t)
    54  		value1 := testutils.RandomCommonHash(t)
    55  
    56  		db.CreateAccount(addr1)
    57  		require.NoError(t, db.Error())
    58  
    59  		db.AddBalance(addr1, big.NewInt(5))
    60  		require.NoError(t, db.Error())
    61  
    62  		// should have code to be able to set state
    63  		db.SetCode(addr1, []byte{1, 2, 3})
    64  		require.NoError(t, db.Error())
    65  
    66  		db.SetState(addr1, key1, value1)
    67  
    68  		ret := db.GetState(addr1, key1)
    69  		require.Equal(t, value1, ret)
    70  
    71  		ret = db.GetCommittedState(addr1, key1)
    72  		require.Equal(t, gethCommon.Hash{}, ret)
    73  
    74  		err = db.Commit(true)
    75  		require.NoError(t, err)
    76  
    77  		ret = db.GetCommittedState(addr1, key1)
    78  		require.Equal(t, value1, ret)
    79  
    80  		// create a new db
    81  		db, err = state.NewStateDB(ledger, rootAddr)
    82  		require.NoError(t, err)
    83  
    84  		bal := db.GetBalance(addr1)
    85  		require.NoError(t, db.Error())
    86  		require.Equal(t, big.NewInt(5), bal)
    87  
    88  		val := db.GetState(addr1, key1)
    89  		require.NoError(t, db.Error())
    90  		require.Equal(t, value1, val)
    91  	})
    92  
    93  	t.Run("test snapshot and revert functionality", func(t *testing.T) {
    94  		ledger := testutils.GetSimpleValueStore()
    95  		db, err := state.NewStateDB(ledger, rootAddr)
    96  		require.NoError(t, err)
    97  
    98  		addr1 := testutils.RandomCommonAddress(t)
    99  		require.False(t, db.Exist(addr1))
   100  		require.NoError(t, db.Error())
   101  
   102  		snapshot1 := db.Snapshot()
   103  		require.Equal(t, 1, snapshot1)
   104  
   105  		db.CreateAccount(addr1)
   106  		require.NoError(t, db.Error())
   107  
   108  		require.True(t, db.Exist(addr1))
   109  		require.NoError(t, db.Error())
   110  
   111  		db.AddBalance(addr1, big.NewInt(5))
   112  		require.NoError(t, db.Error())
   113  
   114  		bal := db.GetBalance(addr1)
   115  		require.NoError(t, db.Error())
   116  		require.Equal(t, big.NewInt(5), bal)
   117  
   118  		snapshot2 := db.Snapshot()
   119  		require.Equal(t, 2, snapshot2)
   120  
   121  		db.AddBalance(addr1, big.NewInt(5))
   122  		require.NoError(t, db.Error())
   123  
   124  		bal = db.GetBalance(addr1)
   125  		require.NoError(t, db.Error())
   126  		require.Equal(t, big.NewInt(10), bal)
   127  
   128  		// revert to snapshot 2
   129  		db.RevertToSnapshot(snapshot2)
   130  		require.NoError(t, db.Error())
   131  
   132  		bal = db.GetBalance(addr1)
   133  		require.NoError(t, db.Error())
   134  		require.Equal(t, big.NewInt(5), bal)
   135  
   136  		// revert to snapshot 1
   137  		db.RevertToSnapshot(snapshot1)
   138  		require.NoError(t, db.Error())
   139  
   140  		bal = db.GetBalance(addr1)
   141  		require.NoError(t, db.Error())
   142  		require.Equal(t, big.NewInt(0), bal)
   143  
   144  		// revert to an invalid snapshot
   145  		db.RevertToSnapshot(10)
   146  		require.Error(t, db.Error())
   147  	})
   148  
   149  	t.Run("test log functionality", func(t *testing.T) {
   150  		ledger := testutils.GetSimpleValueStore()
   151  		db, err := state.NewStateDB(ledger, rootAddr)
   152  		require.NoError(t, err)
   153  
   154  		logs := []*gethTypes.Log{
   155  			testutils.GetRandomLogFixture(t),
   156  			testutils.GetRandomLogFixture(t),
   157  			testutils.GetRandomLogFixture(t),
   158  			testutils.GetRandomLogFixture(t),
   159  		}
   160  
   161  		db.AddLog(logs[0])
   162  		db.AddLog(logs[1])
   163  
   164  		_ = db.Snapshot()
   165  
   166  		db.AddLog(logs[2])
   167  		db.AddLog(logs[3])
   168  
   169  		snapshot := db.Snapshot()
   170  		db.AddLog(testutils.GetRandomLogFixture(t))
   171  		db.RevertToSnapshot(snapshot)
   172  
   173  		ret := db.Logs(1, gethCommon.Hash{}, 1)
   174  		require.Equal(t, ret, logs)
   175  	})
   176  
   177  	t.Run("test refund functionality", func(t *testing.T) {
   178  		ledger := testutils.GetSimpleValueStore()
   179  		db, err := state.NewStateDB(ledger, rootAddr)
   180  		require.NoError(t, err)
   181  
   182  		require.Equal(t, uint64(0), db.GetRefund())
   183  		db.AddRefund(10)
   184  		require.Equal(t, uint64(10), db.GetRefund())
   185  		db.SubRefund(3)
   186  		require.Equal(t, uint64(7), db.GetRefund())
   187  
   188  		snap1 := db.Snapshot()
   189  		db.AddRefund(10)
   190  		require.Equal(t, uint64(17), db.GetRefund())
   191  
   192  		db.RevertToSnapshot(snap1)
   193  		require.Equal(t, uint64(7), db.GetRefund())
   194  	})
   195  
   196  	t.Run("test Prepare functionality", func(t *testing.T) {
   197  		ledger := testutils.GetSimpleValueStore()
   198  		db, err := state.NewStateDB(ledger, rootAddr)
   199  
   200  		sender := testutils.RandomCommonAddress(t)
   201  		coinbase := testutils.RandomCommonAddress(t)
   202  		dest := testutils.RandomCommonAddress(t)
   203  		precompiles := []gethCommon.Address{
   204  			testutils.RandomCommonAddress(t),
   205  			testutils.RandomCommonAddress(t),
   206  		}
   207  
   208  		txAccesses := gethTypes.AccessList([]gethTypes.AccessTuple{
   209  			{Address: testutils.RandomCommonAddress(t),
   210  				StorageKeys: []gethCommon.Hash{
   211  					testutils.RandomCommonHash(t),
   212  					testutils.RandomCommonHash(t),
   213  				},
   214  			},
   215  		})
   216  
   217  		rules := gethParams.Rules{
   218  			IsBerlin:   true,
   219  			IsShanghai: true,
   220  		}
   221  
   222  		require.NoError(t, err)
   223  		db.Prepare(rules, sender, coinbase, &dest, precompiles, txAccesses)
   224  
   225  		require.True(t, db.AddressInAccessList(sender))
   226  		require.True(t, db.AddressInAccessList(coinbase))
   227  		require.True(t, db.AddressInAccessList(dest))
   228  
   229  		for _, add := range precompiles {
   230  			require.True(t, db.AddressInAccessList(add))
   231  		}
   232  
   233  		for _, el := range txAccesses {
   234  			for _, key := range el.StorageKeys {
   235  				addrFound, slotFound := db.SlotInAccessList(el.Address, key)
   236  				require.True(t, addrFound)
   237  				require.True(t, slotFound)
   238  			}
   239  		}
   240  	})
   241  
   242  	t.Run("test non-fatal error handling", func(t *testing.T) {
   243  		ledger := &testutils.TestValueStore{
   244  			GetValueFunc: func(owner, key []byte) ([]byte, error) {
   245  				return nil, nil
   246  			},
   247  			SetValueFunc: func(owner, key, value []byte) error {
   248  				return atree.NewUserError(fmt.Errorf("key not found"))
   249  			},
   250  			AllocateStorageIndexFunc: func(owner []byte) (atree.SlabIndex, error) {
   251  				return atree.SlabIndex{}, nil
   252  			},
   253  		}
   254  		db, err := state.NewStateDB(ledger, rootAddr)
   255  		require.NoError(t, err)
   256  
   257  		db.CreateAccount(testutils.RandomCommonAddress(t))
   258  
   259  		err = db.Commit(true)
   260  		// ret := db.Error()
   261  		require.Error(t, err)
   262  		// check wrapping
   263  		require.True(t, types.IsAStateError(err))
   264  	})
   265  
   266  	t.Run("test fatal error handling", func(t *testing.T) {
   267  		ledger := &testutils.TestValueStore{
   268  			GetValueFunc: func(owner, key []byte) ([]byte, error) {
   269  				return nil, nil
   270  			},
   271  			SetValueFunc: func(owner, key, value []byte) error {
   272  				return atree.NewFatalError(fmt.Errorf("key not found"))
   273  			},
   274  			AllocateStorageIndexFunc: func(owner []byte) (atree.SlabIndex, error) {
   275  				return atree.SlabIndex{}, nil
   276  			},
   277  		}
   278  		db, err := state.NewStateDB(ledger, rootAddr)
   279  		require.NoError(t, err)
   280  
   281  		db.CreateAccount(testutils.RandomCommonAddress(t))
   282  
   283  		err = db.Commit(true)
   284  		// ret := db.Error()
   285  		require.Error(t, err)
   286  		// check wrapping
   287  		require.True(t, types.IsAFatalError(err))
   288  	})
   289  
   290  }