github.com/waltonchain/waltonchain_gwtc_src@v1.1.4-0.20201225072101-8a298c95a819/core/state/statedb_test.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-wtc library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-wtc library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package state
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/binary"
    22  	"fmt"
    23  	"math"
    24  	"math/big"
    25  	"math/rand"
    26  	"reflect"
    27  	"strings"
    28  	"testing"
    29  	"testing/quick"
    30  
    31  	check "gopkg.in/check.v1"
    32  
    33  	"github.com/wtc/go-wtc/common"
    34  	"github.com/wtc/go-wtc/core/types"
    35  	"github.com/wtc/go-wtc/wtcdb"
    36  )
    37  
    38  // Tests that updating a state trie does not leak any database writes prior to
    39  // actually committing the state.
    40  func TestUpdateLeaks(t *testing.T) {
    41  	// Create an empty state database
    42  	db, _ := wtcdb.NewMemDatabase()
    43  	state, _ := New(common.Hash{}, NewDatabase(db))
    44  
    45  	// Update it with some accounts
    46  	for i := byte(0); i < 255; i++ {
    47  		addr := common.BytesToAddress([]byte{i})
    48  		state.AddBalance(addr, big.NewInt(int64(11*i)))
    49  		state.SetNonce(addr, uint64(42*i))
    50  		if i%2 == 0 {
    51  			state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
    52  		}
    53  		if i%3 == 0 {
    54  			state.SetCode(addr, []byte{i, i, i, i, i})
    55  		}
    56  		state.IntermediateRoot(false)
    57  	}
    58  	// Ensure that no data was leaked into the database
    59  	for _, key := range db.Keys() {
    60  		value, _ := db.Get(key)
    61  		t.Errorf("State leaked into database: %x -> %x", key, value)
    62  	}
    63  }
    64  
    65  // Tests that no intermediate state of an object is stored into the database,
    66  // only the one right before the commit.
    67  func TestIntermediateLeaks(t *testing.T) {
    68  	// Create two state databases, one transitioning to the final state, the other final from the beginning
    69  	transDb, _ := wtcdb.NewMemDatabase()
    70  	finalDb, _ := wtcdb.NewMemDatabase()
    71  	transState, _ := New(common.Hash{}, NewDatabase(transDb))
    72  	finalState, _ := New(common.Hash{}, NewDatabase(finalDb))
    73  
    74  	modify := func(state *StateDB, addr common.Address, i, tweak byte) {
    75  		state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
    76  		state.SetNonce(addr, uint64(42*i+tweak))
    77  		if i%2 == 0 {
    78  			state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
    79  			state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
    80  		}
    81  		if i%3 == 0 {
    82  			state.SetCode(addr, []byte{i, i, i, i, i, tweak})
    83  		}
    84  	}
    85  
    86  	// Modify the transient state.
    87  	for i := byte(0); i < 255; i++ {
    88  		modify(transState, common.Address{byte(i)}, i, 0)
    89  	}
    90  	// Write modifications to trie.
    91  	transState.IntermediateRoot(false)
    92  
    93  	// Overwrite all the data with new values in the transient database.
    94  	for i := byte(0); i < 255; i++ {
    95  		modify(transState, common.Address{byte(i)}, i, 99)
    96  		modify(finalState, common.Address{byte(i)}, i, 99)
    97  	}
    98  
    99  	// Commit and cross check the databases.
   100  	if _, err := transState.CommitTo(transDb, false); err != nil {
   101  		t.Fatalf("failed to commit transition state: %v", err)
   102  	}
   103  	if _, err := finalState.CommitTo(finalDb, false); err != nil {
   104  		t.Fatalf("failed to commit final state: %v", err)
   105  	}
   106  	for _, key := range finalDb.Keys() {
   107  		if _, err := transDb.Get(key); err != nil {
   108  			val, _ := finalDb.Get(key)
   109  			t.Errorf("entry missing from the transition database: %x -> %x", key, val)
   110  		}
   111  	}
   112  	for _, key := range transDb.Keys() {
   113  		if _, err := finalDb.Get(key); err != nil {
   114  			val, _ := transDb.Get(key)
   115  			t.Errorf("extra entry in the transition database: %x -> %x", key, val)
   116  		}
   117  	}
   118  }
   119  
   120  func TestSnapshotRandom(t *testing.T) {
   121  	config := &quick.Config{MaxCount: 1000}
   122  	err := quick.Check((*snapshotTest).run, config)
   123  	if cerr, ok := err.(*quick.CheckError); ok {
   124  		test := cerr.In[0].(*snapshotTest)
   125  		t.Errorf("%v:\n%s", test.err, test)
   126  	} else if err != nil {
   127  		t.Error(err)
   128  	}
   129  }
   130  
   131  // A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
   132  // captured by the snapshot. Instances of this test with pseudorandom content are created
   133  // by Generate.
   134  //
   135  // The test works as follows:
   136  //
   137  // A new state is created and all actions are applied to it. Several snapshots are taken
   138  // in between actions. The test then reverts each snapshot. For each snapshot the actions
   139  // leading up to it are replayed on a fresh, empty state. The behaviour of all public
   140  // accessor methods on the reverted state must match the return value of the equivalent
   141  // methods on the replayed state.
   142  type snapshotTest struct {
   143  	addrs     []common.Address // all account addresses
   144  	actions   []testAction     // modifications to the state
   145  	snapshots []int            // actions indexes at which snapshot is taken
   146  	err       error            // failure details are reported through this field
   147  }
   148  
   149  type testAction struct {
   150  	name   string
   151  	fn     func(testAction, *StateDB)
   152  	args   []int64
   153  	noAddr bool
   154  }
   155  
   156  // newTestAction creates a random action that changes state.
   157  func newTestAction(addr common.Address, r *rand.Rand) testAction {
   158  	actions := []testAction{
   159  		{
   160  			name: "SetBalance",
   161  			fn: func(a testAction, s *StateDB) {
   162  				s.SetBalance(addr, big.NewInt(a.args[0]))
   163  			},
   164  			args: make([]int64, 1),
   165  		},
   166  		{
   167  			name: "AddBalance",
   168  			fn: func(a testAction, s *StateDB) {
   169  				s.AddBalance(addr, big.NewInt(a.args[0]))
   170  			},
   171  			args: make([]int64, 1),
   172  		},
   173  		{
   174  			name: "SetNonce",
   175  			fn: func(a testAction, s *StateDB) {
   176  				s.SetNonce(addr, uint64(a.args[0]))
   177  			},
   178  			args: make([]int64, 1),
   179  		},
   180  		{
   181  			name: "SetState",
   182  			fn: func(a testAction, s *StateDB) {
   183  				var key, val common.Hash
   184  				binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
   185  				binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
   186  				s.SetState(addr, key, val)
   187  			},
   188  			args: make([]int64, 2),
   189  		},
   190  		{
   191  			name: "SetCode",
   192  			fn: func(a testAction, s *StateDB) {
   193  				code := make([]byte, 16)
   194  				binary.BigEndian.PutUint64(code, uint64(a.args[0]))
   195  				binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
   196  				s.SetCode(addr, code)
   197  			},
   198  			args: make([]int64, 2),
   199  		},
   200  		{
   201  			name: "CreateAccount",
   202  			fn: func(a testAction, s *StateDB) {
   203  				s.CreateAccount(addr)
   204  			},
   205  		},
   206  		{
   207  			name: "Suicide",
   208  			fn: func(a testAction, s *StateDB) {
   209  				s.Suicide(addr)
   210  			},
   211  		},
   212  		{
   213  			name: "AddRefund",
   214  			fn: func(a testAction, s *StateDB) {
   215  				s.AddRefund(big.NewInt(a.args[0]))
   216  			},
   217  			args:   make([]int64, 1),
   218  			noAddr: true,
   219  		},
   220  		{
   221  			name: "AddLog",
   222  			fn: func(a testAction, s *StateDB) {
   223  				data := make([]byte, 2)
   224  				binary.BigEndian.PutUint16(data, uint16(a.args[0]))
   225  				s.AddLog(&types.Log{Address: addr, Data: data})
   226  			},
   227  			args: make([]int64, 1),
   228  		},
   229  	}
   230  	action := actions[r.Intn(len(actions))]
   231  	var nameargs []string
   232  	if !action.noAddr {
   233  		nameargs = append(nameargs, addr.Hex())
   234  	}
   235  	for _, i := range action.args {
   236  		action.args[i] = rand.Int63n(100)
   237  		nameargs = append(nameargs, fmt.Sprint(action.args[i]))
   238  	}
   239  	action.name += strings.Join(nameargs, ", ")
   240  	return action
   241  }
   242  
   243  // Generate returns a new snapshot test of the given size. All randomness is
   244  // derived from r.
   245  func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
   246  	// Generate random actions.
   247  	addrs := make([]common.Address, 50)
   248  	for i := range addrs {
   249  		addrs[i][0] = byte(i)
   250  	}
   251  	actions := make([]testAction, size)
   252  	for i := range actions {
   253  		addr := addrs[r.Intn(len(addrs))]
   254  		actions[i] = newTestAction(addr, r)
   255  	}
   256  	// Generate snapshot indexes.
   257  	nsnapshots := int(math.Sqrt(float64(size)))
   258  	if size > 0 && nsnapshots == 0 {
   259  		nsnapshots = 1
   260  	}
   261  	snapshots := make([]int, nsnapshots)
   262  	snaplen := len(actions) / nsnapshots
   263  	for i := range snapshots {
   264  		// Try to place the snapshots some number of actions apart from each other.
   265  		snapshots[i] = (i * snaplen) + r.Intn(snaplen)
   266  	}
   267  	return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
   268  }
   269  
   270  func (test *snapshotTest) String() string {
   271  	out := new(bytes.Buffer)
   272  	sindex := 0
   273  	for i, action := range test.actions {
   274  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   275  			fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
   276  			sindex++
   277  		}
   278  		fmt.Fprintf(out, "%4d: %s\n", i, action.name)
   279  	}
   280  	return out.String()
   281  }
   282  
   283  func (test *snapshotTest) run() bool {
   284  	// Run all actions and create snapshots.
   285  	var (
   286  		db, _        = wtcdb.NewMemDatabase()
   287  		state, _     = New(common.Hash{}, NewDatabase(db))
   288  		snapshotRevs = make([]int, len(test.snapshots))
   289  		sindex       = 0
   290  	)
   291  	for i, action := range test.actions {
   292  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   293  			snapshotRevs[sindex] = state.Snapshot()
   294  			sindex++
   295  		}
   296  		action.fn(action, state)
   297  	}
   298  
   299  	// Revert all snapshots in reverse order. Each revert must yield a state
   300  	// that is equivalent to fresh state with all actions up the snapshot applied.
   301  	for sindex--; sindex >= 0; sindex-- {
   302  		checkstate, _ := New(common.Hash{}, NewDatabase(db))
   303  		for _, action := range test.actions[:test.snapshots[sindex]] {
   304  			action.fn(action, checkstate)
   305  		}
   306  		state.RevertToSnapshot(snapshotRevs[sindex])
   307  		if err := test.checkEqual(state, checkstate); err != nil {
   308  			test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
   309  			return false
   310  		}
   311  	}
   312  	return true
   313  }
   314  
   315  // checkEqual checks that methods of state and checkstate return the same values.
   316  func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
   317  	for _, addr := range test.addrs {
   318  		var err error
   319  		checkeq := func(op string, a, b interface{}) bool {
   320  			if err == nil && !reflect.DeepEqual(a, b) {
   321  				err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
   322  				return false
   323  			}
   324  			return true
   325  		}
   326  		// Check basic accessor methods.
   327  		checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
   328  		checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
   329  		checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
   330  		checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
   331  		checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
   332  		checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
   333  		checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
   334  		// Check storage.
   335  		if obj := state.getStateObject(addr); obj != nil {
   336  			state.ForEachStorage(addr, func(key, val common.Hash) bool {
   337  				return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
   338  			})
   339  			checkstate.ForEachStorage(addr, func(key, checkval common.Hash) bool {
   340  				return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
   341  			})
   342  		}
   343  		if err != nil {
   344  			return err
   345  		}
   346  	}
   347  
   348  	if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 {
   349  		return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
   350  			state.GetRefund(), checkstate.GetRefund())
   351  	}
   352  	if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
   353  		return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
   354  			state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
   355  	}
   356  	return nil
   357  }
   358  
   359  func (s *StateSuite) TestTouchDelete(c *check.C) {
   360  	s.state.GetOrNewStateObject(common.Address{})
   361  	root, _ := s.state.CommitTo(s.db, false)
   362  	s.state.Reset(root)
   363  
   364  	snapshot := s.state.Snapshot()
   365  	s.state.AddBalance(common.Address{}, new(big.Int))
   366  	if len(s.state.stateObjectsDirty) != 1 {
   367  		c.Fatal("expected one dirty state object")
   368  	}
   369  
   370  	s.state.RevertToSnapshot(snapshot)
   371  	if len(s.state.stateObjectsDirty) != 0 {
   372  		c.Fatal("expected no dirty state object")
   373  	}
   374  }