github.com/hyperion-hyn/go-ethereum@v2.4.0+incompatible/core/state/statedb_test.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package state
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/binary"
    22  	"fmt"
    23  	"math"
    24  	"math/big"
    25  	"math/rand"
    26  	"reflect"
    27  	"strings"
    28  	"testing"
    29  	"testing/quick"
    30  
    31  	check "gopkg.in/check.v1"
    32  
    33  	"github.com/ethereum/go-ethereum/common"
    34  	"github.com/ethereum/go-ethereum/core/types"
    35  	"github.com/ethereum/go-ethereum/ethdb"
    36  )
    37  
    38  // Tests that updating a state trie does not leak any database writes prior to
    39  // actually committing the state.
    40  func TestUpdateLeaks(t *testing.T) {
    41  	// Create an empty state database
    42  	db := ethdb.NewMemDatabase()
    43  	state, _ := New(common.Hash{}, NewDatabase(db))
    44  
    45  	// Update it with some accounts
    46  	for i := byte(0); i < 255; i++ {
    47  		addr := common.BytesToAddress([]byte{i})
    48  		state.AddBalance(addr, big.NewInt(int64(11*i)))
    49  		state.SetNonce(addr, uint64(42*i))
    50  		if i%2 == 0 {
    51  			state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
    52  		}
    53  		if i%3 == 0 {
    54  			state.SetCode(addr, []byte{i, i, i, i, i})
    55  		}
    56  		state.IntermediateRoot(false)
    57  	}
    58  	// Ensure that no data was leaked into the database
    59  	for _, key := range db.Keys() {
    60  		value, _ := db.Get(key)
    61  		t.Errorf("State leaked into database: %x -> %x", key, value)
    62  	}
    63  }
    64  
    65  // Tests that no intermediate state of an object is stored into the database,
    66  // only the one right before the commit.
    67  func TestIntermediateLeaks(t *testing.T) {
    68  	// Create two state databases, one transitioning to the final state, the other final from the beginning
    69  	transDb := ethdb.NewMemDatabase()
    70  	finalDb := ethdb.NewMemDatabase()
    71  	transState, _ := New(common.Hash{}, NewDatabase(transDb))
    72  	finalState, _ := New(common.Hash{}, NewDatabase(finalDb))
    73  
    74  	modify := func(state *StateDB, addr common.Address, i, tweak byte) {
    75  		state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
    76  		state.SetNonce(addr, uint64(42*i+tweak))
    77  		if i%2 == 0 {
    78  			state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
    79  			state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
    80  		}
    81  		if i%3 == 0 {
    82  			state.SetCode(addr, []byte{i, i, i, i, i, tweak})
    83  		}
    84  	}
    85  
    86  	// Modify the transient state.
    87  	for i := byte(0); i < 255; i++ {
    88  		modify(transState, common.Address{byte(i)}, i, 0)
    89  	}
    90  	// Write modifications to trie.
    91  	transState.IntermediateRoot(false)
    92  
    93  	// Overwrite all the data with new values in the transient database.
    94  	for i := byte(0); i < 255; i++ {
    95  		modify(transState, common.Address{byte(i)}, i, 99)
    96  		modify(finalState, common.Address{byte(i)}, i, 99)
    97  	}
    98  
    99  	// Commit and cross check the databases.
   100  	if _, err := transState.Commit(false); err != nil {
   101  		t.Fatalf("failed to commit transition state: %v", err)
   102  	}
   103  	if _, err := finalState.Commit(false); err != nil {
   104  		t.Fatalf("failed to commit final state: %v", err)
   105  	}
   106  	for _, key := range finalDb.Keys() {
   107  		if _, err := transDb.Get(key); err != nil {
   108  			val, _ := finalDb.Get(key)
   109  			t.Errorf("entry missing from the transition database: %x -> %x", key, val)
   110  		}
   111  	}
   112  	for _, key := range transDb.Keys() {
   113  		if _, err := finalDb.Get(key); err != nil {
   114  			val, _ := transDb.Get(key)
   115  			t.Errorf("extra entry in the transition database: %x -> %x", key, val)
   116  		}
   117  	}
   118  }
   119  
   120  func TestStorageRoot(t *testing.T) {
   121  	var (
   122  		mem      = ethdb.NewMemDatabase()
   123  		db       = NewDatabase(mem)
   124  		state, _ = New(common.Hash{}, db)
   125  		addr     = common.Address{1}
   126  		key      = common.Hash{1}
   127  		value    = common.Hash{42}
   128  
   129  		empty = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
   130  	)
   131  
   132  	so := state.GetOrNewStateObject(addr)
   133  
   134  	emptyRoot := so.storageRoot(db)
   135  	if emptyRoot != empty {
   136  		t.Errorf("Invalid empty storage root, expected %x, got %x", empty, emptyRoot)
   137  	}
   138  
   139  	// add a bit of state
   140  	so.SetState(db, key, value)
   141  	state.Commit(false)
   142  
   143  	root := so.storageRoot(db)
   144  	expected := common.HexToHash("63511abd258fa907afa30cb118b53744a4f49055bb3f531da512c6b866fc2ffb")
   145  
   146  	if expected != root {
   147  		t.Errorf("Invalid storage root, expected %x, got %x", expected, root)
   148  	}
   149  }
   150  
   151  // TestCopy tests that copying a statedb object indeed makes the original and
   152  // the copy independent of each other. This test is a regression test against
   153  // https://github.com/ethereum/go-ethereum/pull/15549.
   154  func TestCopy(t *testing.T) {
   155  	// Create a random state test to copy and modify "independently"
   156  	orig, _ := New(common.Hash{}, NewDatabase(ethdb.NewMemDatabase()))
   157  
   158  	for i := byte(0); i < 255; i++ {
   159  		obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   160  		obj.AddBalance(big.NewInt(int64(i)))
   161  		orig.updateStateObject(obj)
   162  	}
   163  	orig.Finalise(false)
   164  
   165  	// Copy the state, modify both in-memory
   166  	copy := orig.Copy()
   167  
   168  	for i := byte(0); i < 255; i++ {
   169  		origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   170  		copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   171  
   172  		origObj.AddBalance(big.NewInt(2 * int64(i)))
   173  		copyObj.AddBalance(big.NewInt(3 * int64(i)))
   174  
   175  		orig.updateStateObject(origObj)
   176  		copy.updateStateObject(copyObj)
   177  	}
   178  	// Finalise the changes on both concurrently
   179  	done := make(chan struct{})
   180  	go func() {
   181  		orig.Finalise(true)
   182  		close(done)
   183  	}()
   184  	copy.Finalise(true)
   185  	<-done
   186  
   187  	// Verify that the two states have been updated independently
   188  	for i := byte(0); i < 255; i++ {
   189  		origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   190  		copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   191  
   192  		if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 {
   193  			t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want)
   194  		}
   195  		if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 {
   196  			t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want)
   197  		}
   198  	}
   199  }
   200  
   201  func TestSnapshotRandom(t *testing.T) {
   202  	config := &quick.Config{MaxCount: 1000}
   203  	err := quick.Check((*snapshotTest).run, config)
   204  	if cerr, ok := err.(*quick.CheckError); ok {
   205  		test := cerr.In[0].(*snapshotTest)
   206  		t.Errorf("%v:\n%s", test.err, test)
   207  	} else if err != nil {
   208  		t.Error(err)
   209  	}
   210  }
   211  
   212  // A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
   213  // captured by the snapshot. Instances of this test with pseudorandom content are created
   214  // by Generate.
   215  //
   216  // The test works as follows:
   217  //
   218  // A new state is created and all actions are applied to it. Several snapshots are taken
   219  // in between actions. The test then reverts each snapshot. For each snapshot the actions
   220  // leading up to it are replayed on a fresh, empty state. The behaviour of all public
   221  // accessor methods on the reverted state must match the return value of the equivalent
   222  // methods on the replayed state.
   223  type snapshotTest struct {
   224  	addrs     []common.Address // all account addresses
   225  	actions   []testAction     // modifications to the state
   226  	snapshots []int            // actions indexes at which snapshot is taken
   227  	err       error            // failure details are reported through this field
   228  }
   229  
   230  type testAction struct {
   231  	name   string
   232  	fn     func(testAction, *StateDB)
   233  	args   []int64
   234  	noAddr bool
   235  }
   236  
   237  // newTestAction creates a random action that changes state.
   238  func newTestAction(addr common.Address, r *rand.Rand) testAction {
   239  	actions := []testAction{
   240  		{
   241  			name: "SetBalance",
   242  			fn: func(a testAction, s *StateDB) {
   243  				s.SetBalance(addr, big.NewInt(a.args[0]))
   244  			},
   245  			args: make([]int64, 1),
   246  		},
   247  		{
   248  			name: "AddBalance",
   249  			fn: func(a testAction, s *StateDB) {
   250  				s.AddBalance(addr, big.NewInt(a.args[0]))
   251  			},
   252  			args: make([]int64, 1),
   253  		},
   254  		{
   255  			name: "SetNonce",
   256  			fn: func(a testAction, s *StateDB) {
   257  				s.SetNonce(addr, uint64(a.args[0]))
   258  			},
   259  			args: make([]int64, 1),
   260  		},
   261  		{
   262  			name: "SetState",
   263  			fn: func(a testAction, s *StateDB) {
   264  				var key, val common.Hash
   265  				binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
   266  				binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
   267  				s.SetState(addr, key, val)
   268  			},
   269  			args: make([]int64, 2),
   270  		},
   271  		{
   272  			name: "SetCode",
   273  			fn: func(a testAction, s *StateDB) {
   274  				code := make([]byte, 16)
   275  				binary.BigEndian.PutUint64(code, uint64(a.args[0]))
   276  				binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
   277  				s.SetCode(addr, code)
   278  			},
   279  			args: make([]int64, 2),
   280  		},
   281  		{
   282  			name: "CreateAccount",
   283  			fn: func(a testAction, s *StateDB) {
   284  				s.CreateAccount(addr)
   285  			},
   286  		},
   287  		{
   288  			name: "Suicide",
   289  			fn: func(a testAction, s *StateDB) {
   290  				s.Suicide(addr)
   291  			},
   292  		},
   293  		{
   294  			name: "AddRefund",
   295  			fn: func(a testAction, s *StateDB) {
   296  				s.AddRefund(uint64(a.args[0]))
   297  			},
   298  			args:   make([]int64, 1),
   299  			noAddr: true,
   300  		},
   301  		{
   302  			name: "AddLog",
   303  			fn: func(a testAction, s *StateDB) {
   304  				data := make([]byte, 2)
   305  				binary.BigEndian.PutUint16(data, uint16(a.args[0]))
   306  				s.AddLog(&types.Log{Address: addr, Data: data})
   307  			},
   308  			args: make([]int64, 1),
   309  		},
   310  	}
   311  	action := actions[r.Intn(len(actions))]
   312  	var nameargs []string
   313  	if !action.noAddr {
   314  		nameargs = append(nameargs, addr.Hex())
   315  	}
   316  	for _, i := range action.args {
   317  		action.args[i] = rand.Int63n(100)
   318  		nameargs = append(nameargs, fmt.Sprint(action.args[i]))
   319  	}
   320  	action.name += strings.Join(nameargs, ", ")
   321  	return action
   322  }
   323  
   324  // Generate returns a new snapshot test of the given size. All randomness is
   325  // derived from r.
   326  func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
   327  	// Generate random actions.
   328  	addrs := make([]common.Address, 50)
   329  	for i := range addrs {
   330  		addrs[i][0] = byte(i)
   331  	}
   332  	actions := make([]testAction, size)
   333  	for i := range actions {
   334  		addr := addrs[r.Intn(len(addrs))]
   335  		actions[i] = newTestAction(addr, r)
   336  	}
   337  	// Generate snapshot indexes.
   338  	nsnapshots := int(math.Sqrt(float64(size)))
   339  	if size > 0 && nsnapshots == 0 {
   340  		nsnapshots = 1
   341  	}
   342  	snapshots := make([]int, nsnapshots)
   343  	snaplen := len(actions) / nsnapshots
   344  	for i := range snapshots {
   345  		// Try to place the snapshots some number of actions apart from each other.
   346  		snapshots[i] = (i * snaplen) + r.Intn(snaplen)
   347  	}
   348  	return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
   349  }
   350  
   351  func (test *snapshotTest) String() string {
   352  	out := new(bytes.Buffer)
   353  	sindex := 0
   354  	for i, action := range test.actions {
   355  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   356  			fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
   357  			sindex++
   358  		}
   359  		fmt.Fprintf(out, "%4d: %s\n", i, action.name)
   360  	}
   361  	return out.String()
   362  }
   363  
   364  func (test *snapshotTest) run() bool {
   365  	// Run all actions and create snapshots.
   366  	var (
   367  		state, _     = New(common.Hash{}, NewDatabase(ethdb.NewMemDatabase()))
   368  		snapshotRevs = make([]int, len(test.snapshots))
   369  		sindex       = 0
   370  	)
   371  	for i, action := range test.actions {
   372  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   373  			snapshotRevs[sindex] = state.Snapshot()
   374  			sindex++
   375  		}
   376  		action.fn(action, state)
   377  	}
   378  	// Revert all snapshots in reverse order. Each revert must yield a state
   379  	// that is equivalent to fresh state with all actions up the snapshot applied.
   380  	for sindex--; sindex >= 0; sindex-- {
   381  		checkstate, _ := New(common.Hash{}, state.Database())
   382  		for _, action := range test.actions[:test.snapshots[sindex]] {
   383  			action.fn(action, checkstate)
   384  		}
   385  		state.RevertToSnapshot(snapshotRevs[sindex])
   386  		if err := test.checkEqual(state, checkstate); err != nil {
   387  			test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
   388  			return false
   389  		}
   390  	}
   391  	return true
   392  }
   393  
   394  // checkEqual checks that methods of state and checkstate return the same values.
   395  func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
   396  	for _, addr := range test.addrs {
   397  		var err error
   398  		checkeq := func(op string, a, b interface{}) bool {
   399  			if err == nil && !reflect.DeepEqual(a, b) {
   400  				err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
   401  				return false
   402  			}
   403  			return true
   404  		}
   405  		// Check basic accessor methods.
   406  		checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
   407  		checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
   408  		checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
   409  		checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
   410  		checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
   411  		checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
   412  		checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
   413  		// Check storage.
   414  		if obj := state.getStateObject(addr); obj != nil {
   415  			state.ForEachStorage(addr, func(key, value common.Hash) bool {
   416  				return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
   417  			})
   418  			checkstate.ForEachStorage(addr, func(key, value common.Hash) bool {
   419  				return checkeq("GetState("+key.Hex()+")", checkstate.GetState(addr, key), value)
   420  			})
   421  		}
   422  		if err != nil {
   423  			return err
   424  		}
   425  	}
   426  
   427  	if state.GetRefund() != checkstate.GetRefund() {
   428  		return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
   429  			state.GetRefund(), checkstate.GetRefund())
   430  	}
   431  	if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
   432  		return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
   433  			state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
   434  	}
   435  	return nil
   436  }
   437  
   438  func (s *StateSuite) TestTouchDelete(c *check.C) {
   439  	s.state.GetOrNewStateObject(common.Address{})
   440  	root, _ := s.state.Commit(false)
   441  	s.state.Reset(root)
   442  
   443  	snapshot := s.state.Snapshot()
   444  	s.state.AddBalance(common.Address{}, new(big.Int))
   445  
   446  	if len(s.state.journal.dirties) != 1 {
   447  		c.Fatal("expected one dirty state object")
   448  	}
   449  	s.state.RevertToSnapshot(snapshot)
   450  	if len(s.state.journal.dirties) != 0 {
   451  		c.Fatal("expected no dirty state object")
   452  	}
   453  }
   454  
   455  // TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy.
   456  // See https://github.com/ethereum/go-ethereum/pull/15225#issuecomment-380191512
   457  func TestCopyOfCopy(t *testing.T) {
   458  	sdb, _ := New(common.Hash{}, NewDatabase(ethdb.NewMemDatabase()))
   459  	addr := common.HexToAddress("aaaa")
   460  	sdb.SetBalance(addr, big.NewInt(42))
   461  
   462  	if got := sdb.Copy().GetBalance(addr).Uint64(); got != 42 {
   463  		t.Fatalf("1st copy fail, expected 42, got %v", got)
   464  	}
   465  	if got := sdb.Copy().Copy().GetBalance(addr).Uint64(); got != 42 {
   466  		t.Fatalf("2nd copy fail, expected 42, got %v", got)
   467  	}
   468  }