github.com/aquanetwork/aquachain@v1.7.8/core/state/statedb_test.go (about)

     1  // Copyright 2016 The aquachain Authors
     2  // This file is part of the aquachain library.
     3  //
     4  // The aquachain library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The aquachain library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package state
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/binary"
    22  	"fmt"
    23  	"math"
    24  	"math/big"
    25  	"math/rand"
    26  	"reflect"
    27  	"strings"
    28  	"testing"
    29  	"testing/quick"
    30  
    31  	check "gopkg.in/check.v1"
    32  
    33  	"gitlab.com/aquachain/aquachain/aquadb"
    34  	"gitlab.com/aquachain/aquachain/common"
    35  	"gitlab.com/aquachain/aquachain/core/types"
    36  )
    37  
    38  // Tests that updating a state trie does not leak any database writes prior to
    39  // actually committing the state.
    40  func TestUpdateLeaks(t *testing.T) {
    41  	// Create an empty state database
    42  	db := aquadb.NewMemDatabase()
    43  	state, _ := New(common.Hash{}, NewDatabase(db))
    44  
    45  	// Update it with some accounts
    46  	for i := byte(0); i < 255; i++ {
    47  		addr := common.BytesToAddress([]byte{i})
    48  		state.AddBalance(addr, big.NewInt(int64(11*i)))
    49  		state.SetNonce(addr, uint64(42*i))
    50  		if i%2 == 0 {
    51  			state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
    52  		}
    53  		if i%3 == 0 {
    54  			state.SetCode(addr, []byte{i, i, i, i, i})
    55  		}
    56  		state.IntermediateRoot(false)
    57  	}
    58  	// Ensure that no data was leaked into the database
    59  	for _, key := range db.Keys() {
    60  		value, _ := db.Get(key)
    61  		t.Errorf("State leaked into database: %x -> %x", key, value)
    62  	}
    63  }
    64  
    65  // Tests that no intermediate state of an object is stored into the database,
    66  // only the one right before the commit.
    67  func TestIntermediateLeaks(t *testing.T) {
    68  	// Create two state databases, one transitioning to the final state, the other final from the beginning
    69  	transDb := aquadb.NewMemDatabase()
    70  	finalDb := aquadb.NewMemDatabase()
    71  	transState, _ := New(common.Hash{}, NewDatabase(transDb))
    72  	finalState, _ := New(common.Hash{}, NewDatabase(finalDb))
    73  
    74  	modify := func(state *StateDB, addr common.Address, i, tweak byte) {
    75  		state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
    76  		state.SetNonce(addr, uint64(42*i+tweak))
    77  		if i%2 == 0 {
    78  			state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
    79  			state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
    80  		}
    81  		if i%3 == 0 {
    82  			state.SetCode(addr, []byte{i, i, i, i, i, tweak})
    83  		}
    84  	}
    85  
    86  	// Modify the transient state.
    87  	for i := byte(0); i < 255; i++ {
    88  		modify(transState, common.Address{byte(i)}, i, 0)
    89  	}
    90  	// Write modifications to trie.
    91  	transState.IntermediateRoot(false)
    92  
    93  	// Overwrite all the data with new values in the transient database.
    94  	for i := byte(0); i < 255; i++ {
    95  		modify(transState, common.Address{byte(i)}, i, 99)
    96  		modify(finalState, common.Address{byte(i)}, i, 99)
    97  	}
    98  
    99  	// Commit and cross check the databases.
   100  	if _, err := transState.Commit(false); err != nil {
   101  		t.Fatalf("failed to commit transition state: %v", err)
   102  	}
   103  	if _, err := finalState.Commit(false); err != nil {
   104  		t.Fatalf("failed to commit final state: %v", err)
   105  	}
   106  	for _, key := range finalDb.Keys() {
   107  		if _, err := transDb.Get(key); err != nil {
   108  			val, _ := finalDb.Get(key)
   109  			t.Errorf("entry missing from the transition database: %x -> %x", key, val)
   110  		}
   111  	}
   112  	for _, key := range transDb.Keys() {
   113  		if _, err := finalDb.Get(key); err != nil {
   114  			val, _ := transDb.Get(key)
   115  			t.Errorf("extra entry in the transition database: %x -> %x", key, val)
   116  		}
   117  	}
   118  }
   119  
   120  // TestCopy tests that copying a statedb object indeed makes the original and
   121  // the copy independent of each other. This test is a regression test against
   122  // https://gitlab.com/aquachain/aquachain/pull/15549.
   123  func TestCopy(t *testing.T) {
   124  	// Create a random state test to copy and modify "independently"
   125  	db := aquadb.NewMemDatabase()
   126  	orig, _ := New(common.Hash{}, NewDatabase(db))
   127  
   128  	for i := byte(0); i < 255; i++ {
   129  		obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   130  		obj.AddBalance(big.NewInt(int64(i)))
   131  		orig.updateStateObject(obj)
   132  	}
   133  	orig.Finalise(false)
   134  
   135  	// Copy the state, modify both in-memory
   136  	copy := orig.Copy()
   137  
   138  	for i := byte(0); i < 255; i++ {
   139  		origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   140  		copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   141  
   142  		origObj.AddBalance(big.NewInt(2 * int64(i)))
   143  		copyObj.AddBalance(big.NewInt(3 * int64(i)))
   144  
   145  		orig.updateStateObject(origObj)
   146  		copy.updateStateObject(copyObj)
   147  	}
   148  	// Finalise the changes on both concurrently
   149  	done := make(chan struct{})
   150  	go func() {
   151  		orig.Finalise(true)
   152  		close(done)
   153  	}()
   154  	copy.Finalise(true)
   155  	<-done
   156  
   157  	// Verify that the two states have been updated independently
   158  	for i := byte(0); i < 255; i++ {
   159  		origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   160  		copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   161  
   162  		if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 {
   163  			t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want)
   164  		}
   165  		if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 {
   166  			t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want)
   167  		}
   168  	}
   169  }
   170  
   171  func TestSnapshotRandom(t *testing.T) {
   172  	config := &quick.Config{MaxCount: 1000}
   173  	err := quick.Check((*snapshotTest).run, config)
   174  	if cerr, ok := err.(*quick.CheckError); ok {
   175  		test := cerr.In[0].(*snapshotTest)
   176  		t.Errorf("%v:\n%s", test.err, test)
   177  	} else if err != nil {
   178  		t.Error(err)
   179  	}
   180  }
   181  
   182  // A snapshotTest checks that reverting StateDB snapshots properly undoes all changes
   183  // captured by the snapshot. Instances of this test with pseudorandom content are created
   184  // by Generate.
   185  //
   186  // The test works as follows:
   187  //
   188  // A new state is created and all actions are applied to it. Several snapshots are taken
   189  // in between actions. The test then reverts each snapshot. For each snapshot the actions
   190  // leading up to it are replayed on a fresh, empty state. The behaviour of all public
   191  // accessor methods on the reverted state must match the return value of the equivalent
   192  // methods on the replayed state.
   193  type snapshotTest struct {
   194  	addrs     []common.Address // all account addresses
   195  	actions   []testAction     // modifications to the state
   196  	snapshots []int            // actions indexes at which snapshot is taken
   197  	err       error            // failure details are reported through this field
   198  }
   199  
   200  type testAction struct {
   201  	name   string
   202  	fn     func(testAction, *StateDB)
   203  	args   []int64
   204  	noAddr bool
   205  }
   206  
   207  // newTestAction creates a random action that changes state.
   208  func newTestAction(addr common.Address, r *rand.Rand) testAction {
   209  	actions := []testAction{
   210  		{
   211  			name: "SetBalance",
   212  			fn: func(a testAction, s *StateDB) {
   213  				s.SetBalance(addr, big.NewInt(a.args[0]))
   214  			},
   215  			args: make([]int64, 1),
   216  		},
   217  		{
   218  			name: "AddBalance",
   219  			fn: func(a testAction, s *StateDB) {
   220  				s.AddBalance(addr, big.NewInt(a.args[0]))
   221  			},
   222  			args: make([]int64, 1),
   223  		},
   224  		{
   225  			name: "SetNonce",
   226  			fn: func(a testAction, s *StateDB) {
   227  				s.SetNonce(addr, uint64(a.args[0]))
   228  			},
   229  			args: make([]int64, 1),
   230  		},
   231  		{
   232  			name: "SetState",
   233  			fn: func(a testAction, s *StateDB) {
   234  				var key, val common.Hash
   235  				binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
   236  				binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
   237  				s.SetState(addr, key, val)
   238  			},
   239  			args: make([]int64, 2),
   240  		},
   241  		{
   242  			name: "SetCode",
   243  			fn: func(a testAction, s *StateDB) {
   244  				code := make([]byte, 16)
   245  				binary.BigEndian.PutUint64(code, uint64(a.args[0]))
   246  				binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
   247  				s.SetCode(addr, code)
   248  			},
   249  			args: make([]int64, 2),
   250  		},
   251  		{
   252  			name: "CreateAccount",
   253  			fn: func(a testAction, s *StateDB) {
   254  				s.CreateAccount(addr)
   255  			},
   256  		},
   257  		{
   258  			name: "Suicide",
   259  			fn: func(a testAction, s *StateDB) {
   260  				s.Suicide(addr)
   261  			},
   262  		},
   263  		{
   264  			name: "AddRefund",
   265  			fn: func(a testAction, s *StateDB) {
   266  				s.AddRefund(uint64(a.args[0]))
   267  			},
   268  			args:   make([]int64, 1),
   269  			noAddr: true,
   270  		},
   271  		{
   272  			name: "AddLog",
   273  			fn: func(a testAction, s *StateDB) {
   274  				data := make([]byte, 2)
   275  				binary.BigEndian.PutUint16(data, uint16(a.args[0]))
   276  				s.AddLog(&types.Log{Address: addr, Data: data})
   277  			},
   278  			args: make([]int64, 1),
   279  		},
   280  	}
   281  	action := actions[r.Intn(len(actions))]
   282  	var nameargs []string
   283  	if !action.noAddr {
   284  		nameargs = append(nameargs, addr.Hex())
   285  	}
   286  	for _, i := range action.args {
   287  		action.args[i] = rand.Int63n(100)
   288  		nameargs = append(nameargs, fmt.Sprint(action.args[i]))
   289  	}
   290  	action.name += strings.Join(nameargs, ", ")
   291  	return action
   292  }
   293  
   294  // Generate returns a new snapshot test of the given size. All randomness is
   295  // derived from r.
   296  func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
   297  	// Generate random actions.
   298  	addrs := make([]common.Address, 50)
   299  	for i := range addrs {
   300  		addrs[i][0] = byte(i)
   301  	}
   302  	actions := make([]testAction, size)
   303  	for i := range actions {
   304  		addr := addrs[r.Intn(len(addrs))]
   305  		actions[i] = newTestAction(addr, r)
   306  	}
   307  	// Generate snapshot indexes.
   308  	nsnapshots := int(math.Sqrt(float64(size)))
   309  	if size > 0 && nsnapshots == 0 {
   310  		nsnapshots = 1
   311  	}
   312  	snapshots := make([]int, nsnapshots)
   313  	snaplen := len(actions) / nsnapshots
   314  	for i := range snapshots {
   315  		// Try to place the snapshots some number of actions apart from each other.
   316  		snapshots[i] = (i * snaplen) + r.Intn(snaplen)
   317  	}
   318  	return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
   319  }
   320  
   321  func (test *snapshotTest) String() string {
   322  	out := new(bytes.Buffer)
   323  	sindex := 0
   324  	for i, action := range test.actions {
   325  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   326  			fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
   327  			sindex++
   328  		}
   329  		fmt.Fprintf(out, "%4d: %s\n", i, action.name)
   330  	}
   331  	return out.String()
   332  }
   333  
   334  func (test *snapshotTest) run() bool {
   335  	// Run all actions and create snapshots.
   336  	var (
   337  		db           = aquadb.NewMemDatabase()
   338  		state, _     = New(common.Hash{}, NewDatabase(db))
   339  		snapshotRevs = make([]int, len(test.snapshots))
   340  		sindex       = 0
   341  	)
   342  	for i, action := range test.actions {
   343  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   344  			snapshotRevs[sindex] = state.Snapshot()
   345  			sindex++
   346  		}
   347  		action.fn(action, state)
   348  	}
   349  	// Revert all snapshots in reverse order. Each revert must yield a state
   350  	// that is equivalent to fresh state with all actions up the snapshot applied.
   351  	for sindex--; sindex >= 0; sindex-- {
   352  		checkstate, _ := New(common.Hash{}, state.Database())
   353  		for _, action := range test.actions[:test.snapshots[sindex]] {
   354  			action.fn(action, checkstate)
   355  		}
   356  		state.RevertToSnapshot(snapshotRevs[sindex])
   357  		if err := test.checkEqual(state, checkstate); err != nil {
   358  			test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
   359  			return false
   360  		}
   361  	}
   362  	return true
   363  }
   364  
   365  // checkEqual checks that methods of state and checkstate return the same values.
   366  func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
   367  	for _, addr := range test.addrs {
   368  		var err error
   369  		checkeq := func(op string, a, b interface{}) bool {
   370  			if err == nil && !reflect.DeepEqual(a, b) {
   371  				err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
   372  				return false
   373  			}
   374  			return true
   375  		}
   376  		// Check basic accessor methods.
   377  		checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
   378  		checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
   379  		checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
   380  		checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
   381  		checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
   382  		checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
   383  		checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
   384  		// Check storage.
   385  		if obj := state.getStateObject(addr); obj != nil {
   386  			state.ForEachStorage(addr, func(key, val common.Hash) bool {
   387  				return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
   388  			})
   389  			checkstate.ForEachStorage(addr, func(key, checkval common.Hash) bool {
   390  				return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
   391  			})
   392  		}
   393  		if err != nil {
   394  			return err
   395  		}
   396  	}
   397  
   398  	if state.GetRefund() != checkstate.GetRefund() {
   399  		return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
   400  			state.GetRefund(), checkstate.GetRefund())
   401  	}
   402  	if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
   403  		return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
   404  			state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
   405  	}
   406  	return nil
   407  }
   408  
   409  func (s *StateSuite) TestTouchDelete(c *check.C) {
   410  	s.state.GetOrNewStateObject(common.Address{})
   411  	root, _ := s.state.Commit(false)
   412  	s.state.Reset(root)
   413  
   414  	snapshot := s.state.Snapshot()
   415  	s.state.AddBalance(common.Address{}, new(big.Int))
   416  	if len(s.state.stateObjectsDirty) != 1 {
   417  		c.Fatal("expected one dirty state object")
   418  	}
   419  	s.state.RevertToSnapshot(snapshot)
   420  	if len(s.state.stateObjectsDirty) != 0 {
   421  		c.Fatal("expected no dirty state object")
   422  	}
   423  }