github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/core/state/statedb_test.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package state
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/binary"
    22  	"fmt"
    23  
    24  	"math"
    25  	"math/big"
    26  	"math/rand"
    27  	"reflect"
    28  	"strings"
    29  	"testing"
    30  	"testing/quick"
    31  
    32  	"gopkg.in/check.v1"
    33  
    34  	"github.com/neatlab/neatio/chain/core/rawdb"
    35  	"github.com/neatlab/neatio/chain/core/types"
    36  	"github.com/neatlab/neatio/utilities/common"
    37  )
    38  
    39  // Tests that updating a state trie does not leak any database writes prior to
    40  // actually committing the state.
    41  func TestUpdateLeaks(t *testing.T) {
    42  	// Create an empty state database
    43  	db := rawdb.NewMemoryDatabase()
    44  	state, _ := New(common.Hash{}, NewDatabase(db))
    45  
    46  	// Update it with some accounts
    47  	for i := byte(0); i < 255; i++ {
    48  		addr := common.BytesToAddress([]byte{i})
    49  		state.AddBalance(addr, big.NewInt(int64(11*i)))
    50  		state.SetNonce(addr, uint64(42*i))
    51  		if i%2 == 0 {
    52  			state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
    53  		}
    54  		if i%3 == 0 {
    55  			state.SetCode(addr, []byte{i, i, i, i, i})
    56  		}
    57  		state.IntermediateRoot(false)
    58  	}
    59  	// Ensure that no data was leaked into the database
    60  	it := db.NewIterator()
    61  	for it.Next() {
    62  		t.Errorf("State leaked into database: %x -> %x", it.Key(), it.Value())
    63  	}
    64  	it.Release()
    65  }
    66  
    67  // Tests that no intermediate state of an object is stored into the database,
    68  // only the one right before the commit.
    69  func TestIntermediateLeaks(t *testing.T) {
    70  	// Create two state databases, one transitioning to the final state, the other final from the beginning
    71  	transDb := rawdb.NewMemoryDatabase()
    72  	finalDb := rawdb.NewMemoryDatabase()
    73  	transState, _ := New(common.Hash{}, NewDatabase(transDb))
    74  	finalState, _ := New(common.Hash{}, NewDatabase(finalDb))
    75  
    76  	modify := func(state *StateDB, addr common.Address, i, tweak byte) {
    77  		state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
    78  		state.SetNonce(addr, uint64(42*i+tweak))
    79  		if i%2 == 0 {
    80  			state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
    81  			state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
    82  		}
    83  		if i%3 == 0 {
    84  			state.SetCode(addr, []byte{i, i, i, i, i, tweak})
    85  		}
    86  	}
    87  
    88  	// Modify the transient state.
    89  	for i := byte(0); i < 255; i++ {
    90  		modify(transState, common.Address{byte(i)}, i, 0)
    91  	}
    92  	// Write modifications to trie.
    93  	transState.IntermediateRoot(false)
    94  
    95  	// Overwrite all the data with new values in the transient database.
    96  	for i := byte(0); i < 255; i++ {
    97  		modify(transState, common.Address{byte(i)}, i, 99)
    98  		modify(finalState, common.Address{byte(i)}, i, 99)
    99  	}
   100  
   101  	// Commit and cross check the databases.
   102  	if _, err := transState.Commit(false); err != nil {
   103  		t.Fatalf("failed to commit transition state: %v", err)
   104  	}
   105  	if _, err := finalState.Commit(false); err != nil {
   106  		t.Fatalf("failed to commit final state: %v", err)
   107  	}
   108  	it := finalDb.NewIterator()
   109  	for it.Next() {
   110  		key := it.Key()
   111  		if _, err := transDb.Get(key); err != nil {
   112  			t.Errorf("entry missing from the transition database: %x -> %x", key, it.Value())
   113  		}
   114  	}
   115  	it.Release()
   116  
   117  	it = transDb.NewIterator()
   118  	for it.Next() {
   119  		key := it.Key()
   120  		if _, err := finalDb.Get(key); err != nil {
   121  			t.Errorf("extra entry in the transition database: %x -> %x", key, it.Value())
   122  		}
   123  	}
   124  }
   125  
   126  // TestCopy tests that copying a statedb object indeed makes the original and
   127  // the copy independent of each other. This test is a regression test against
   128  // https://github.com/neatlab/neatio/pull/15549.
   129  func TestCopy(t *testing.T) {
   130  	// Create a random state test to copy and modify "independently"
   131  	orig, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
   132  
   133  	for i := byte(0); i < 255; i++ {
   134  		obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   135  		obj.AddBalance(big.NewInt(int64(i)))
   136  		orig.updateStateObject(obj)
   137  	}
   138  	orig.Finalise(false)
   139  
   140  	// Copy the state, modify both in-memory
   141  	copy := orig.Copy()
   142  
   143  	for i := byte(0); i < 255; i++ {
   144  		origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   145  		copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   146  
   147  		origObj.AddBalance(big.NewInt(2 * int64(i)))
   148  		copyObj.AddBalance(big.NewInt(3 * int64(i)))
   149  
   150  		orig.updateStateObject(origObj)
   151  		copy.updateStateObject(copyObj)
   152  	}
   153  	// Finalise the changes on both concurrently
   154  	done := make(chan struct{})
   155  	go func() {
   156  		orig.Finalise(true)
   157  		close(done)
   158  	}()
   159  	copy.Finalise(true)
   160  	<-done
   161  
   162  	// Verify that the two states have been updated independently
   163  	for i := byte(0); i < 255; i++ {
   164  		origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   165  		copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   166  
   167  		if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 {
   168  			t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want)
   169  		}
   170  		if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 {
   171  			t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want)
   172  		}
   173  	}
   174  }
   175  
   176  func TestSnapshotRandom(t *testing.T) {
   177  	config := &quick.Config{MaxCount: 1000}
   178  	err := quick.Check((*snapshotTest).run, config)
   179  	if cerr, ok := err.(*quick.CheckError); ok {
   180  		test := cerr.In[0].(*snapshotTest)
   181  		t.Errorf("%v:\n%s", test.err, test)
   182  	} else if err != nil {
   183  		t.Error(err)
   184  	}
   185  }
   186  
// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
// captured by the snapshot. Instances of this test with pseudorandom content are created
// by Generate.
//
// The test works as follows:
//
// A new state is created and all actions are applied to it. Several snapshots are taken
// in between actions. The test then reverts each snapshot, newest first. For each snapshot
// the actions leading up to it are replayed on a fresh, empty state. The behaviour of all
// public accessor methods on the reverted state must match the return value of the
// equivalent methods on the replayed state.
//
// run and String walk snapshots in step with actions, so snapshots must be in
// ascending order (Generate builds them that way).
type snapshotTest struct {
	addrs     []common.Address // all account addresses
	actions   []testAction     // modifications to the state
	snapshots []int            // ascending action indexes at which a snapshot is taken
	err       error            // failure details are reported through this field (set by run)
}
   204  
// testAction is a single named mutation that snapshotTest applies to a
// StateDB; instances are built by newTestAction.
//
//   - name is a human-readable label. newTestAction appends the target
//     address (unless noAddr is set) and the argument values to it.
//   - fn applies the action to a state. It receives the action itself so it
//     can read args.
//   - args holds the random arguments consumed by fn.
//   - noAddr, when set, makes newTestAction leave the address out of name.
//     AddRefund sets it because it takes no address.
type testAction struct {
	name   string
	fn     func(testAction, *StateDB)
	args   []int64
	noAddr bool
}
   211  
   212  // newTestAction creates a random action that changes state.
   213  func newTestAction(addr common.Address, r *rand.Rand) testAction {
   214  	actions := []testAction{
   215  		{
   216  			name: "SetBalance",
   217  			fn: func(a testAction, s *StateDB) {
   218  				s.SetBalance(addr, big.NewInt(a.args[0]))
   219  			},
   220  			args: make([]int64, 1),
   221  		},
   222  		{
   223  			name: "AddBalance",
   224  			fn: func(a testAction, s *StateDB) {
   225  				s.AddBalance(addr, big.NewInt(a.args[0]))
   226  			},
   227  			args: make([]int64, 1),
   228  		},
   229  		{
   230  			name: "SetNonce",
   231  			fn: func(a testAction, s *StateDB) {
   232  				s.SetNonce(addr, uint64(a.args[0]))
   233  			},
   234  			args: make([]int64, 1),
   235  		},
   236  		{
   237  			name: "SetState",
   238  			fn: func(a testAction, s *StateDB) {
   239  				var key, val common.Hash
   240  				binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
   241  				binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
   242  				s.SetState(addr, key, val)
   243  			},
   244  			args: make([]int64, 2),
   245  		},
   246  		{
   247  			name: "SetCode",
   248  			fn: func(a testAction, s *StateDB) {
   249  				code := make([]byte, 16)
   250  				binary.BigEndian.PutUint64(code, uint64(a.args[0]))
   251  				binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
   252  				s.SetCode(addr, code)
   253  			},
   254  			args: make([]int64, 2),
   255  		},
   256  		{
   257  			name: "CreateAccount",
   258  			fn: func(a testAction, s *StateDB) {
   259  				s.CreateAccount(addr)
   260  			},
   261  		},
   262  		{
   263  			name: "Suicide",
   264  			fn: func(a testAction, s *StateDB) {
   265  				s.Suicide(addr)
   266  			},
   267  		},
   268  		{
   269  			name: "AddRefund",
   270  			fn: func(a testAction, s *StateDB) {
   271  				s.AddRefund(uint64(a.args[0]))
   272  			},
   273  			args:   make([]int64, 1),
   274  			noAddr: true,
   275  		},
   276  		{
   277  			name: "AddLog",
   278  			fn: func(a testAction, s *StateDB) {
   279  				data := make([]byte, 2)
   280  				binary.BigEndian.PutUint16(data, uint16(a.args[0]))
   281  				s.AddLog(&types.Log{Address: addr, Data: data})
   282  			},
   283  			args: make([]int64, 1),
   284  		},
   285  	}
   286  	action := actions[r.Intn(len(actions))]
   287  	var nameargs []string
   288  	if !action.noAddr {
   289  		//nameargs = append(nameargs, addr.Hex())
   290  		nameargs = append(nameargs, addr.String())
   291  	}
   292  	for _, i := range action.args {
   293  		action.args[i] = rand.Int63n(100)
   294  		nameargs = append(nameargs, fmt.Sprint(action.args[i]))
   295  	}
   296  	action.name += strings.Join(nameargs, ", ")
   297  	return action
   298  }
   299  
   300  // Generate returns a new snapshot test of the given size. All randomness is
   301  // derived from r.
   302  func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
   303  	// Generate random actions.
   304  	addrs := make([]common.Address, 50)
   305  	for i := range addrs {
   306  		addrs[i][0] = byte(i)
   307  	}
   308  	actions := make([]testAction, size)
   309  	for i := range actions {
   310  		addr := addrs[r.Intn(len(addrs))]
   311  		actions[i] = newTestAction(addr, r)
   312  	}
   313  	// Generate snapshot indexes.
   314  	nsnapshots := int(math.Sqrt(float64(size)))
   315  	if size > 0 && nsnapshots == 0 {
   316  		nsnapshots = 1
   317  	}
   318  	snapshots := make([]int, nsnapshots)
   319  	snaplen := len(actions) / nsnapshots
   320  	for i := range snapshots {
   321  		// Try to place the snapshots some number of actions apart from each other.
   322  		snapshots[i] = (i * snaplen) + r.Intn(snaplen)
   323  	}
   324  	return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
   325  }
   326  
   327  func (test *snapshotTest) String() string {
   328  	out := new(bytes.Buffer)
   329  	sindex := 0
   330  	for i, action := range test.actions {
   331  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   332  			fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
   333  			sindex++
   334  		}
   335  		fmt.Fprintf(out, "%4d: %s\n", i, action.name)
   336  	}
   337  	return out.String()
   338  }
   339  
   340  func (test *snapshotTest) run() bool {
   341  	// Run all actions and create snapshots.
   342  	var (
   343  		state, _     = New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
   344  		snapshotRevs = make([]int, len(test.snapshots))
   345  		sindex       = 0
   346  	)
   347  	for i, action := range test.actions {
   348  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   349  			snapshotRevs[sindex] = state.Snapshot()
   350  			sindex++
   351  		}
   352  		action.fn(action, state)
   353  	}
   354  	// Revert all snapshots in reverse order. Each revert must yield a state
   355  	// that is equivalent to fresh state with all actions up the snapshot applied.
   356  	for sindex--; sindex >= 0; sindex-- {
   357  		checkstate, _ := New(common.Hash{}, state.Database())
   358  		for _, action := range test.actions[:test.snapshots[sindex]] {
   359  			action.fn(action, checkstate)
   360  		}
   361  		state.RevertToSnapshot(snapshotRevs[sindex])
   362  		if err := test.checkEqual(state, checkstate); err != nil {
   363  			test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
   364  			return false
   365  		}
   366  	}
   367  	return true
   368  }
   369  
   370  // checkEqual checks that methods of state and checkstate return the same values.
   371  func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
   372  	for _, addr := range test.addrs {
   373  		var err error
   374  		checkeq := func(op string, a, b interface{}) bool {
   375  			if err == nil && !reflect.DeepEqual(a, b) {
   376  				//err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
   377  				err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.String(), a, b)
   378  				return false
   379  			}
   380  			return true
   381  		}
   382  		// Check basic accessor methods.
   383  		checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
   384  		checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
   385  		checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
   386  		checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
   387  		checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
   388  		checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
   389  		checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
   390  		// Check storage.
   391  		if obj := state.getStateObject(addr); obj != nil {
   392  			state.ForEachStorage(addr, func(key, val common.Hash) bool {
   393  				return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
   394  			})
   395  			checkstate.ForEachStorage(addr, func(key, checkval common.Hash) bool {
   396  				return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
   397  			})
   398  		}
   399  		if err != nil {
   400  			return err
   401  		}
   402  	}
   403  
   404  	if state.GetRefund() != checkstate.GetRefund() {
   405  		return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
   406  			state.GetRefund(), checkstate.GetRefund())
   407  	}
   408  	if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
   409  		return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
   410  			state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
   411  	}
   412  	return nil
   413  }
   414  
   415  func (s *StateSuite) TestTouchDelete(c *check.C) {
   416  	s.state.GetOrNewStateObject(common.Address{})
   417  	root, _ := s.state.Commit(false)
   418  	s.state.Reset(root)
   419  
   420  	snapshot := s.state.Snapshot()
   421  	s.state.AddBalance(common.Address{}, new(big.Int))
   422  	if len(s.state.stateObjectsDirty) != 1 {
   423  		c.Fatal("expected one dirty state object")
   424  	}
   425  	s.state.RevertToSnapshot(snapshot)
   426  	if len(s.state.stateObjectsDirty) != 0 {
   427  		c.Fatal("expected no dirty state object")
   428  	}
   429  }
   430  
   431  // TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy.
   432  // See https://github.com/neatlab/neatio/pull/15225#issuecomment-380191512
   433  func TestCopyOfCopy(t *testing.T) {
   434  	sdb, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
   435  	addr := common.HexToAddress("aaaa")
   436  	sdb.SetBalance(addr, big.NewInt(42))
   437  
   438  	if got := sdb.Copy().GetBalance(addr).Uint64(); got != 42 {
   439  		t.Fatalf("1st copy fail, expected 42, got %v", got)
   440  	}
   441  	if got := sdb.Copy().Copy().GetBalance(addr).Uint64(); got != 42 {
   442  		t.Fatalf("2nd copy fail, expected 42, got %v", got)
   443  	}
   444  }