github.com/ylsGit/go-ethereum@v1.6.5/core/state/statedb_test.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package state
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/binary"
    22  	"fmt"
    23  	"math"
    24  	"math/big"
    25  	"math/rand"
    26  	"reflect"
    27  	"strings"
    28  	"testing"
    29  	"testing/quick"
    30  
    31  	"github.com/ethereum/go-ethereum/common"
    32  	"github.com/ethereum/go-ethereum/core/types"
    33  	"github.com/ethereum/go-ethereum/ethdb"
    34  )
    35  
    36  // Tests that updating a state trie does not leak any database writes prior to
    37  // actually committing the state.
    38  func TestUpdateLeaks(t *testing.T) {
    39  	// Create an empty state database
    40  	db, _ := ethdb.NewMemDatabase()
    41  	state, _ := New(common.Hash{}, db)
    42  
    43  	// Update it with some accounts
    44  	for i := byte(0); i < 255; i++ {
    45  		addr := common.BytesToAddress([]byte{i})
    46  		state.AddBalance(addr, big.NewInt(int64(11*i)))
    47  		state.SetNonce(addr, uint64(42*i))
    48  		if i%2 == 0 {
    49  			state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
    50  		}
    51  		if i%3 == 0 {
    52  			state.SetCode(addr, []byte{i, i, i, i, i})
    53  		}
    54  		state.IntermediateRoot(false)
    55  	}
    56  	// Ensure that no data was leaked into the database
    57  	for _, key := range db.Keys() {
    58  		value, _ := db.Get(key)
    59  		t.Errorf("State leaked into database: %x -> %x", key, value)
    60  	}
    61  }
    62  
    63  // Tests that no intermediate state of an object is stored into the database,
    64  // only the one right before the commit.
    65  func TestIntermediateLeaks(t *testing.T) {
    66  	// Create two state databases, one transitioning to the final state, the other final from the beginning
    67  	transDb, _ := ethdb.NewMemDatabase()
    68  	finalDb, _ := ethdb.NewMemDatabase()
    69  	transState, _ := New(common.Hash{}, transDb)
    70  	finalState, _ := New(common.Hash{}, finalDb)
    71  
    72  	modify := func(state *StateDB, addr common.Address, i, tweak byte) {
    73  		state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
    74  		state.SetNonce(addr, uint64(42*i+tweak))
    75  		if i%2 == 0 {
    76  			state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
    77  			state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
    78  		}
    79  		if i%3 == 0 {
    80  			state.SetCode(addr, []byte{i, i, i, i, i, tweak})
    81  		}
    82  	}
    83  
    84  	// Modify the transient state.
    85  	for i := byte(0); i < 255; i++ {
    86  		modify(transState, common.Address{byte(i)}, i, 0)
    87  	}
    88  	// Write modifications to trie.
    89  	transState.IntermediateRoot(false)
    90  
    91  	// Overwrite all the data with new values in the transient database.
    92  	for i := byte(0); i < 255; i++ {
    93  		modify(transState, common.Address{byte(i)}, i, 99)
    94  		modify(finalState, common.Address{byte(i)}, i, 99)
    95  	}
    96  
    97  	// Commit and cross check the databases.
    98  	if _, err := transState.Commit(false); err != nil {
    99  		t.Fatalf("failed to commit transition state: %v", err)
   100  	}
   101  	if _, err := finalState.Commit(false); err != nil {
   102  		t.Fatalf("failed to commit final state: %v", err)
   103  	}
   104  	for _, key := range finalDb.Keys() {
   105  		if _, err := transDb.Get(key); err != nil {
   106  			val, _ := finalDb.Get(key)
   107  			t.Errorf("entry missing from the transition database: %x -> %x", key, val)
   108  		}
   109  	}
   110  	for _, key := range transDb.Keys() {
   111  		if _, err := finalDb.Get(key); err != nil {
   112  			val, _ := transDb.Get(key)
   113  			t.Errorf("extra entry in the transition database: %x -> %x", key, val)
   114  		}
   115  	}
   116  }
   117  
   118  func TestSnapshotRandom(t *testing.T) {
   119  	config := &quick.Config{MaxCount: 1000}
   120  	err := quick.Check((*snapshotTest).run, config)
   121  	if cerr, ok := err.(*quick.CheckError); ok {
   122  		test := cerr.In[0].(*snapshotTest)
   123  		t.Errorf("%v:\n%s", test.err, test)
   124  	} else if err != nil {
   125  		t.Error(err)
   126  	}
   127  }
   128  
// A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
// captured by the snapshot. Instances of this test with pseudorandom content are created
// by Generate.
//
// The test works as follows:
//
// A new state is created and all actions are applied to it. Several snapshots are taken
// in between actions. The test then reverts each snapshot. For each snapshot the actions
// leading up to it are replayed on a fresh, empty state. The behaviour of all public
// accessor methods on the reverted state must match the return value of the equivalent
// methods on the replayed state.
//
// NOTE: Generate builds this struct with a positional composite literal, so the
// field order must not change.
type snapshotTest struct {
	addrs     []common.Address // addresses actions may target; checkEqual inspects each of them
	actions   []testAction     // modifications to the state, applied in order
	snapshots []int            // action indexes at which snapshots are taken; run and String expect increasing order
	err       error            // failure details are reported through this field
}
   146  
// testAction is a single state modification that a snapshotTest can apply.
type testAction struct {
	name   string                     // description printed by snapshotTest.String, including the arguments
	fn     func(testAction, *StateDB) // applies the action; receives the action itself so it can read args
	args   []int64                    // argument values consumed by fn
	noAddr bool                       // if set, addr is left out of name; used by actions that ignore addr (AddRefund)
}
   153  
   154  // newTestAction creates a random action that changes state.
   155  func newTestAction(addr common.Address, r *rand.Rand) testAction {
   156  	actions := []testAction{
   157  		{
   158  			name: "SetBalance",
   159  			fn: func(a testAction, s *StateDB) {
   160  				s.SetBalance(addr, big.NewInt(a.args[0]))
   161  			},
   162  			args: make([]int64, 1),
   163  		},
   164  		{
   165  			name: "AddBalance",
   166  			fn: func(a testAction, s *StateDB) {
   167  				s.AddBalance(addr, big.NewInt(a.args[0]))
   168  			},
   169  			args: make([]int64, 1),
   170  		},
   171  		{
   172  			name: "SetNonce",
   173  			fn: func(a testAction, s *StateDB) {
   174  				s.SetNonce(addr, uint64(a.args[0]))
   175  			},
   176  			args: make([]int64, 1),
   177  		},
   178  		{
   179  			name: "SetState",
   180  			fn: func(a testAction, s *StateDB) {
   181  				var key, val common.Hash
   182  				binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
   183  				binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
   184  				s.SetState(addr, key, val)
   185  			},
   186  			args: make([]int64, 2),
   187  		},
   188  		{
   189  			name: "SetCode",
   190  			fn: func(a testAction, s *StateDB) {
   191  				code := make([]byte, 16)
   192  				binary.BigEndian.PutUint64(code, uint64(a.args[0]))
   193  				binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
   194  				s.SetCode(addr, code)
   195  			},
   196  			args: make([]int64, 2),
   197  		},
   198  		{
   199  			name: "CreateAccount",
   200  			fn: func(a testAction, s *StateDB) {
   201  				s.CreateAccount(addr)
   202  			},
   203  		},
   204  		{
   205  			name: "Suicide",
   206  			fn: func(a testAction, s *StateDB) {
   207  				s.Suicide(addr)
   208  			},
   209  		},
   210  		{
   211  			name: "AddRefund",
   212  			fn: func(a testAction, s *StateDB) {
   213  				s.AddRefund(big.NewInt(a.args[0]))
   214  			},
   215  			args:   make([]int64, 1),
   216  			noAddr: true,
   217  		},
   218  		{
   219  			name: "AddLog",
   220  			fn: func(a testAction, s *StateDB) {
   221  				data := make([]byte, 2)
   222  				binary.BigEndian.PutUint16(data, uint16(a.args[0]))
   223  				s.AddLog(&types.Log{Address: addr, Data: data})
   224  			},
   225  			args: make([]int64, 1),
   226  		},
   227  	}
   228  	action := actions[r.Intn(len(actions))]
   229  	var nameargs []string
   230  	if !action.noAddr {
   231  		nameargs = append(nameargs, addr.Hex())
   232  	}
   233  	for _, i := range action.args {
   234  		action.args[i] = rand.Int63n(100)
   235  		nameargs = append(nameargs, fmt.Sprint(action.args[i]))
   236  	}
   237  	action.name += strings.Join(nameargs, ", ")
   238  	return action
   239  }
   240  
   241  // Generate returns a new snapshot test of the given size. All randomness is
   242  // derived from r.
   243  func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
   244  	// Generate random actions.
   245  	addrs := make([]common.Address, 50)
   246  	for i := range addrs {
   247  		addrs[i][0] = byte(i)
   248  	}
   249  	actions := make([]testAction, size)
   250  	for i := range actions {
   251  		addr := addrs[r.Intn(len(addrs))]
   252  		actions[i] = newTestAction(addr, r)
   253  	}
   254  	// Generate snapshot indexes.
   255  	nsnapshots := int(math.Sqrt(float64(size)))
   256  	if size > 0 && nsnapshots == 0 {
   257  		nsnapshots = 1
   258  	}
   259  	snapshots := make([]int, nsnapshots)
   260  	snaplen := len(actions) / nsnapshots
   261  	for i := range snapshots {
   262  		// Try to place the snapshots some number of actions apart from each other.
   263  		snapshots[i] = (i * snaplen) + r.Intn(snaplen)
   264  	}
   265  	return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
   266  }
   267  
   268  func (test *snapshotTest) String() string {
   269  	out := new(bytes.Buffer)
   270  	sindex := 0
   271  	for i, action := range test.actions {
   272  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   273  			fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
   274  			sindex++
   275  		}
   276  		fmt.Fprintf(out, "%4d: %s\n", i, action.name)
   277  	}
   278  	return out.String()
   279  }
   280  
   281  func (test *snapshotTest) run() bool {
   282  	// Run all actions and create snapshots.
   283  	var (
   284  		db, _        = ethdb.NewMemDatabase()
   285  		state, _     = New(common.Hash{}, db)
   286  		snapshotRevs = make([]int, len(test.snapshots))
   287  		sindex       = 0
   288  	)
   289  	for i, action := range test.actions {
   290  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   291  			snapshotRevs[sindex] = state.Snapshot()
   292  			sindex++
   293  		}
   294  		action.fn(action, state)
   295  	}
   296  
   297  	// Revert all snapshots in reverse order. Each revert must yield a state
   298  	// that is equivalent to fresh state with all actions up the snapshot applied.
   299  	for sindex--; sindex >= 0; sindex-- {
   300  		checkstate, _ := New(common.Hash{}, db)
   301  		for _, action := range test.actions[:test.snapshots[sindex]] {
   302  			action.fn(action, checkstate)
   303  		}
   304  		state.RevertToSnapshot(snapshotRevs[sindex])
   305  		if err := test.checkEqual(state, checkstate); err != nil {
   306  			test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
   307  			return false
   308  		}
   309  	}
   310  	return true
   311  }
   312  
   313  // checkEqual checks that methods of state and checkstate return the same values.
   314  func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
   315  	for _, addr := range test.addrs {
   316  		var err error
   317  		checkeq := func(op string, a, b interface{}) bool {
   318  			if err == nil && !reflect.DeepEqual(a, b) {
   319  				err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
   320  				return false
   321  			}
   322  			return true
   323  		}
   324  		// Check basic accessor methods.
   325  		checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
   326  		checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
   327  		checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
   328  		checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
   329  		checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
   330  		checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
   331  		checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
   332  		// Check storage.
   333  		if obj := state.getStateObject(addr); obj != nil {
   334  			state.ForEachStorage(addr, func(key, val common.Hash) bool {
   335  				return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
   336  			})
   337  			checkstate.ForEachStorage(addr, func(key, checkval common.Hash) bool {
   338  				return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
   339  			})
   340  		}
   341  		if err != nil {
   342  			return err
   343  		}
   344  	}
   345  
   346  	if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 {
   347  		return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
   348  			state.GetRefund(), checkstate.GetRefund())
   349  	}
   350  	if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
   351  		return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
   352  			state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
   353  	}
   354  	return nil
   355  }
   356  
   357  func TestTouchDelete(t *testing.T) {
   358  	db, _ := ethdb.NewMemDatabase()
   359  	state, _ := New(common.Hash{}, db)
   360  	state.GetOrNewStateObject(common.Address{})
   361  	root, _ := state.Commit(false)
   362  	state.Reset(root)
   363  
   364  	snapshot := state.Snapshot()
   365  	state.AddBalance(common.Address{}, new(big.Int))
   366  	if len(state.stateObjectsDirty) != 1 {
   367  		t.Fatal("expected one dirty state object")
   368  	}
   369  
   370  	state.RevertToSnapshot(snapshot)
   371  	if len(state.stateObjectsDirty) != 0 {
   372  		t.Fatal("expected no dirty state object")
   373  	}
   374  }