github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/core/state/statedb_test.go (about)

     1  package state
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"math"
     8  	"math/big"
     9  	"math/rand"
    10  	"reflect"
    11  	"strings"
    12  	"testing"
    13  	"testing/quick"
    14  
    15  	check "gopkg.in/check.v1"
    16  
    17  	"github.com/quickchainproject/quickchain/common"
    18  	"github.com/quickchainproject/quickchain/core/types"
    19  	"github.com/quickchainproject/quickchain/qctdb"
    20  )
    21  
    22  // Tests that updating a state trie does not leak any database writes prior to
    23  // actually committing the state.
    24  func TestUpdateLeaks(t *testing.T) {
    25  	// Create an empty state database
    26  	db, _ := qctdb.NewMemDatabase()
    27  	state, _ := New(common.Hash{}, NewDatabase(db))
    28  
    29  	// Update it with some accounts
    30  	for i := byte(0); i < 255; i++ {
    31  		addr := common.BytesToAddress([]byte{i})
    32  		state.AddBalance(addr, big.NewInt(int64(11*i)))
    33  		state.SetNonce(addr, uint64(42*i))
    34  		if i%2 == 0 {
    35  			state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
    36  		}
    37  		if i%3 == 0 {
    38  			state.SetCode(addr, []byte{i, i, i, i, i})
    39  		}
    40  		state.IntermediateRoot(false)
    41  	}
    42  	// Ensure that no data was leaked into the database
    43  	for _, key := range db.Keys() {
    44  		value, _ := db.Get(key)
    45  		t.Errorf("State leaked into database: %x -> %x", key, value)
    46  	}
    47  }
    48  
    49  // Tests that no intermediate state of an object is stored into the database,
    50  // only the one right before the commit.
    51  func TestIntermediateLeaks(t *testing.T) {
    52  	// Create two state databases, one transitioning to the final state, the other final from the beginning
    53  	transDb, _ := qctdb.NewMemDatabase()
    54  	finalDb, _ := qctdb.NewMemDatabase()
    55  	transState, _ := New(common.Hash{}, NewDatabase(transDb))
    56  	finalState, _ := New(common.Hash{}, NewDatabase(finalDb))
    57  
    58  	modify := func(state *StateDB, addr common.Address, i, tweak byte) {
    59  		state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
    60  		state.SetNonce(addr, uint64(42*i+tweak))
    61  		if i%2 == 0 {
    62  			state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
    63  			state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
    64  		}
    65  		if i%3 == 0 {
    66  			state.SetCode(addr, []byte{i, i, i, i, i, tweak})
    67  		}
    68  	}
    69  
    70  	// Modify the transient state.
    71  	for i := byte(0); i < 255; i++ {
    72  		modify(transState, common.Address{byte(i)}, i, 0)
    73  	}
    74  	// Write modifications to trie.
    75  	transState.IntermediateRoot(false)
    76  
    77  	// Overwrite all the data with new values in the transient database.
    78  	for i := byte(0); i < 255; i++ {
    79  		modify(transState, common.Address{byte(i)}, i, 99)
    80  		modify(finalState, common.Address{byte(i)}, i, 99)
    81  	}
    82  
    83  	// Commit and cross check the databases.
    84  	if _, err := transState.Commit(false); err != nil {
    85  		t.Fatalf("failed to commit transition state: %v", err)
    86  	}
    87  	if _, err := finalState.Commit(false); err != nil {
    88  		t.Fatalf("failed to commit final state: %v", err)
    89  	}
    90  	for _, key := range finalDb.Keys() {
    91  		if _, err := transDb.Get(key); err != nil {
    92  			val, _ := finalDb.Get(key)
    93  			t.Errorf("entry missing from the transition database: %x -> %x", key, val)
    94  		}
    95  	}
    96  	for _, key := range transDb.Keys() {
    97  		if _, err := finalDb.Get(key); err != nil {
    98  			val, _ := transDb.Get(key)
    99  			t.Errorf("extra entry in the transition database: %x -> %x", key, val)
   100  		}
   101  	}
   102  }
   103  
   104  // TestCopy tests that copying a statedb object indeed makes the original and
   105  // the copy independent of each other. This test is a regression test against
   106  // https://github.com/quickchainproject/quickchain/pull/15549.
   107  func TestCopy(t *testing.T) {
   108  	// Create a random state test to copy and modify "independently"
   109  	db, _ := qctdb.NewMemDatabase()
   110  	orig, _ := New(common.Hash{}, NewDatabase(db))
   111  
   112  	for i := byte(0); i < 255; i++ {
   113  		obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   114  		obj.AddBalance(big.NewInt(int64(i)))
   115  		orig.updateStateObject(obj)
   116  	}
   117  	orig.Finalise(false)
   118  
   119  	// Copy the state, modify both in-memory
   120  	copy := orig.Copy()
   121  
   122  	for i := byte(0); i < 255; i++ {
   123  		origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   124  		copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   125  
   126  		origObj.AddBalance(big.NewInt(2 * int64(i)))
   127  		copyObj.AddBalance(big.NewInt(3 * int64(i)))
   128  
   129  		orig.updateStateObject(origObj)
   130  		copy.updateStateObject(copyObj)
   131  	}
   132  	// Finalise the changes on both concurrently
   133  	done := make(chan struct{})
   134  	go func() {
   135  		orig.Finalise(true)
   136  		close(done)
   137  	}()
   138  	copy.Finalise(true)
   139  	<-done
   140  
   141  	// Verify that the two states have been updated independently
   142  	for i := byte(0); i < 255; i++ {
   143  		origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   144  		copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   145  
   146  		if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 {
   147  			t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want)
   148  		}
   149  		if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 {
   150  			t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want)
   151  		}
   152  	}
   153  }
   154  
   155  func TestSnapshotRandom(t *testing.T) {
   156  	config := &quick.Config{MaxCount: 1000}
   157  	err := quick.Check((*snapshotTest).run, config)
   158  	if cerr, ok := err.(*quick.CheckError); ok {
   159  		test := cerr.In[0].(*snapshotTest)
   160  		t.Errorf("%v:\n%s", test.err, test)
   161  	} else if err != nil {
   162  		t.Error(err)
   163  	}
   164  }
   165  
// A snapshotTest checks that reverting StateDB snapshots properly undoes all
// changes captured by the snapshot. Instances with pseudorandom content are
// created by Generate and executed by run.
//
// The test works as follows:
//
// A new state is created and all actions are applied to it, with snapshots
// taken just before the actions at the indexes listed in snapshots. The test
// then reverts the snapshots from last to first. For each snapshot, the actions
// preceding it are replayed on a new state created from the empty root, and
// the accessors compared by checkEqual must return the same values on the
// reverted state and on the replayed state.
//
// Field order matters: Generate may construct this type with a positional
// composite literal.
type snapshotTest struct {
	addrs     []common.Address // all account addresses
	actions   []testAction     // modifications to the state
	snapshots []int            // action indexes at which a snapshot is taken (ascending)
	err       error            // failure details are reported through this field
}
   183  
// testAction is a single named mutation that a snapshotTest applies to a
// StateDB. Values are produced by newTestAction.
type testAction struct {
	// name labels the action in snapshotTest.String output; newTestAction
	// appends the address (unless noAddr is set) and the arguments to it.
	name   string
	// fn applies the mutation; it receives the action itself so that it can
	// read args.
	fn     func(testAction, *StateDB)
	// args holds the integer arguments consumed by fn.
	args   []int64
	// noAddr marks actions whose label should omit the address (AddRefund).
	noAddr bool
}
   190  
   191  // newTestAction creates a random action that changes state.
   192  func newTestAction(addr common.Address, r *rand.Rand) testAction {
   193  	actions := []testAction{
   194  		{
   195  			name: "SetBalance",
   196  			fn: func(a testAction, s *StateDB) {
   197  				s.SetBalance(addr, big.NewInt(a.args[0]))
   198  			},
   199  			args: make([]int64, 1),
   200  		},
   201  		{
   202  			name: "AddBalance",
   203  			fn: func(a testAction, s *StateDB) {
   204  				s.AddBalance(addr, big.NewInt(a.args[0]))
   205  			},
   206  			args: make([]int64, 1),
   207  		},
   208  		{
   209  			name: "SetNonce",
   210  			fn: func(a testAction, s *StateDB) {
   211  				s.SetNonce(addr, uint64(a.args[0]))
   212  			},
   213  			args: make([]int64, 1),
   214  		},
   215  		{
   216  			name: "SetState",
   217  			fn: func(a testAction, s *StateDB) {
   218  				var key, val common.Hash
   219  				binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
   220  				binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
   221  				s.SetState(addr, key, val)
   222  			},
   223  			args: make([]int64, 2),
   224  		},
   225  		{
   226  			name: "SetCode",
   227  			fn: func(a testAction, s *StateDB) {
   228  				code := make([]byte, 16)
   229  				binary.BigEndian.PutUint64(code, uint64(a.args[0]))
   230  				binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
   231  				s.SetCode(addr, code)
   232  			},
   233  			args: make([]int64, 2),
   234  		},
   235  		{
   236  			name: "CreateAccount",
   237  			fn: func(a testAction, s *StateDB) {
   238  				s.CreateAccount(addr)
   239  			},
   240  		},
   241  		{
   242  			name: "Suicide",
   243  			fn: func(a testAction, s *StateDB) {
   244  				s.Suicide(addr)
   245  			},
   246  		},
   247  		{
   248  			name: "AddRefund",
   249  			fn: func(a testAction, s *StateDB) {
   250  				s.AddRefund(uint64(a.args[0]))
   251  			},
   252  			args:   make([]int64, 1),
   253  			noAddr: true,
   254  		},
   255  		{
   256  			name: "AddLog",
   257  			fn: func(a testAction, s *StateDB) {
   258  				data := make([]byte, 2)
   259  				binary.BigEndian.PutUint16(data, uint16(a.args[0]))
   260  				s.AddLog(&types.Log{Address: addr, Data: data})
   261  			},
   262  			args: make([]int64, 1),
   263  		},
   264  	}
   265  	action := actions[r.Intn(len(actions))]
   266  	var nameargs []string
   267  	if !action.noAddr {
   268  		nameargs = append(nameargs, addr.Hex())
   269  	}
   270  	for _, i := range action.args {
   271  		action.args[i] = rand.Int63n(100)
   272  		nameargs = append(nameargs, fmt.Sprint(action.args[i]))
   273  	}
   274  	action.name += strings.Join(nameargs, ", ")
   275  	return action
   276  }
   277  
   278  // Generate returns a new snapshot test of the given size. All randomness is
   279  // derived from r.
   280  func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
   281  	// Generate random actions.
   282  	addrs := make([]common.Address, 50)
   283  	for i := range addrs {
   284  		addrs[i][0] = byte(i)
   285  	}
   286  	actions := make([]testAction, size)
   287  	for i := range actions {
   288  		addr := addrs[r.Intn(len(addrs))]
   289  		actions[i] = newTestAction(addr, r)
   290  	}
   291  	// Generate snapshot indexes.
   292  	nsnapshots := int(math.Sqrt(float64(size)))
   293  	if size > 0 && nsnapshots == 0 {
   294  		nsnapshots = 1
   295  	}
   296  	snapshots := make([]int, nsnapshots)
   297  	snaplen := len(actions) / nsnapshots
   298  	for i := range snapshots {
   299  		// Try to place the snapshots some number of actions apart from each other.
   300  		snapshots[i] = (i * snaplen) + r.Intn(snaplen)
   301  	}
   302  	return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
   303  }
   304  
   305  func (test *snapshotTest) String() string {
   306  	out := new(bytes.Buffer)
   307  	sindex := 0
   308  	for i, action := range test.actions {
   309  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   310  			fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
   311  			sindex++
   312  		}
   313  		fmt.Fprintf(out, "%4d: %s\n", i, action.name)
   314  	}
   315  	return out.String()
   316  }
   317  
   318  func (test *snapshotTest) run() bool {
   319  	// Run all actions and create snapshots.
   320  	var (
   321  		db, _        = qctdb.NewMemDatabase()
   322  		state, _     = New(common.Hash{}, NewDatabase(db))
   323  		snapshotRevs = make([]int, len(test.snapshots))
   324  		sindex       = 0
   325  	)
   326  	for i, action := range test.actions {
   327  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   328  			snapshotRevs[sindex] = state.Snapshot()
   329  			sindex++
   330  		}
   331  		action.fn(action, state)
   332  	}
   333  	// Revert all snapshots in reverse order. Each revert must yield a state
   334  	// that is equivalent to fresh state with all actions up the snapshot applied.
   335  	for sindex--; sindex >= 0; sindex-- {
   336  		checkstate, _ := New(common.Hash{}, state.Database())
   337  		for _, action := range test.actions[:test.snapshots[sindex]] {
   338  			action.fn(action, checkstate)
   339  		}
   340  		state.RevertToSnapshot(snapshotRevs[sindex])
   341  		if err := test.checkEqual(state, checkstate); err != nil {
   342  			test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
   343  			return false
   344  		}
   345  	}
   346  	return true
   347  }
   348  
   349  // checkEqual checks that methods of state and checkstate return the same values.
   350  func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
   351  	for _, addr := range test.addrs {
   352  		var err error
   353  		checkeq := func(op string, a, b interface{}) bool {
   354  			if err == nil && !reflect.DeepEqual(a, b) {
   355  				err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
   356  				return false
   357  			}
   358  			return true
   359  		}
   360  		// Check basic accessor methods.
   361  		checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
   362  		checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
   363  		checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
   364  		checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
   365  		checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
   366  		checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
   367  		checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
   368  		// Check storage.
   369  		if obj := state.getStateObject(addr); obj != nil {
   370  			state.ForEachStorage(addr, func(key, val common.Hash) bool {
   371  				return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
   372  			})
   373  			checkstate.ForEachStorage(addr, func(key, checkval common.Hash) bool {
   374  				return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
   375  			})
   376  		}
   377  		if err != nil {
   378  			return err
   379  		}
   380  	}
   381  
   382  	if state.GetRefund() != checkstate.GetRefund() {
   383  		return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
   384  			state.GetRefund(), checkstate.GetRefund())
   385  	}
   386  	if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
   387  		return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
   388  			state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
   389  	}
   390  	return nil
   391  }
   392  
   393  func (s *StateSuite) TestTouchDelete(c *check.C) {
   394  	s.state.GetOrNewStateObject(common.Address{})
   395  	root, _ := s.state.Commit(false)
   396  	s.state.Reset(root)
   397  
   398  	snapshot := s.state.Snapshot()
   399  	s.state.AddBalance(common.Address{}, new(big.Int))
   400  
   401  	if len(s.state.journal.dirties) != 1 {
   402  		c.Fatal("expected one dirty state object")
   403  	}
   404  	s.state.RevertToSnapshot(snapshot)
   405  	if len(s.state.journal.dirties) != 0 {
   406  		c.Fatal("expected no dirty state object")
   407  	}
   408  }
   409  
   410  // TestCopyOfCopy tests that modified objects are carried over to the copy, and the copy of the copy.
   411  // See https://github.com/quickchainproject/quickchain/pull/15225#issuecomment-380191512
   412  func TestCopyOfCopy(t *testing.T) {
   413  	db, _ := qctdb.NewMemDatabase()
   414  	sdb, _ := New(common.Hash{}, NewDatabase(db))
   415  	addr := common.HexToAddress("aaaa")
   416  	sdb.SetBalance(addr, big.NewInt(42))
   417  
   418  	if got := sdb.Copy().GetBalance(addr).Uint64(); got != 42 {
   419  		t.Fatalf("1st copy fail, expected 42, got %v", got)
   420  	}
   421  	if got := sdb.Copy().Copy().GetBalance(addr).Uint64(); got != 42 {
   422  		t.Fatalf("2nd copy fail, expected 42, got %v", got)
   423  	}
   424  }