github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/chain/core/state/statedb_test.go (about)

     1  package state
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  
     8  	"math"
     9  	"math/big"
    10  	"math/rand"
    11  	"reflect"
    12  	"strings"
    13  	"testing"
    14  	"testing/quick"
    15  
    16  	"gopkg.in/check.v1"
    17  
    18  	"github.com/neatio-net/neatio/chain/core/rawdb"
    19  	"github.com/neatio-net/neatio/chain/core/types"
    20  	"github.com/neatio-net/neatio/utilities/common"
    21  )
    22  
    23  func TestUpdateLeaks(t *testing.T) {
    24  
    25  	db := rawdb.NewMemoryDatabase()
    26  	state, _ := New(common.Hash{}, NewDatabase(db))
    27  
    28  	for i := byte(0); i < 255; i++ {
    29  		addr := common.BytesToAddress([]byte{i})
    30  		state.AddBalance(addr, big.NewInt(int64(11*i)))
    31  		state.SetNonce(addr, uint64(42*i))
    32  		if i%2 == 0 {
    33  			state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i}))
    34  		}
    35  		if i%3 == 0 {
    36  			state.SetCode(addr, []byte{i, i, i, i, i})
    37  		}
    38  		state.IntermediateRoot(false)
    39  	}
    40  
    41  	it := db.NewIterator()
    42  	for it.Next() {
    43  		t.Errorf("State leaked into database: %x -> %x", it.Key(), it.Value())
    44  	}
    45  	it.Release()
    46  }
    47  
    48  func TestIntermediateLeaks(t *testing.T) {
    49  
    50  	transDb := rawdb.NewMemoryDatabase()
    51  	finalDb := rawdb.NewMemoryDatabase()
    52  	transState, _ := New(common.Hash{}, NewDatabase(transDb))
    53  	finalState, _ := New(common.Hash{}, NewDatabase(finalDb))
    54  
    55  	modify := func(state *StateDB, addr common.Address, i, tweak byte) {
    56  		state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak)))
    57  		state.SetNonce(addr, uint64(42*i+tweak))
    58  		if i%2 == 0 {
    59  			state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{})
    60  			state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak})
    61  		}
    62  		if i%3 == 0 {
    63  			state.SetCode(addr, []byte{i, i, i, i, i, tweak})
    64  		}
    65  	}
    66  
    67  	for i := byte(0); i < 255; i++ {
    68  		modify(transState, common.Address{byte(i)}, i, 0)
    69  	}
    70  
    71  	transState.IntermediateRoot(false)
    72  
    73  	for i := byte(0); i < 255; i++ {
    74  		modify(transState, common.Address{byte(i)}, i, 99)
    75  		modify(finalState, common.Address{byte(i)}, i, 99)
    76  	}
    77  
    78  	if _, err := transState.Commit(false); err != nil {
    79  		t.Fatalf("failed to commit transition state: %v", err)
    80  	}
    81  	if _, err := finalState.Commit(false); err != nil {
    82  		t.Fatalf("failed to commit final state: %v", err)
    83  	}
    84  	it := finalDb.NewIterator()
    85  	for it.Next() {
    86  		key := it.Key()
    87  		if _, err := transDb.Get(key); err != nil {
    88  			t.Errorf("entry missing from the transition database: %x -> %x", key, it.Value())
    89  		}
    90  	}
    91  	it.Release()
    92  
    93  	it = transDb.NewIterator()
    94  	for it.Next() {
    95  		key := it.Key()
    96  		if _, err := finalDb.Get(key); err != nil {
    97  			t.Errorf("extra entry in the transition database: %x -> %x", key, it.Value())
    98  		}
    99  	}
   100  }
   101  
   102  func TestCopy(t *testing.T) {
   103  
   104  	orig, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
   105  
   106  	for i := byte(0); i < 255; i++ {
   107  		obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   108  		obj.AddBalance(big.NewInt(int64(i)))
   109  		orig.updateStateObject(obj)
   110  	}
   111  	orig.Finalise(false)
   112  
   113  	copy := orig.Copy()
   114  
   115  	for i := byte(0); i < 255; i++ {
   116  		origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   117  		copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   118  
   119  		origObj.AddBalance(big.NewInt(2 * int64(i)))
   120  		copyObj.AddBalance(big.NewInt(3 * int64(i)))
   121  
   122  		orig.updateStateObject(origObj)
   123  		copy.updateStateObject(copyObj)
   124  	}
   125  
   126  	done := make(chan struct{})
   127  	go func() {
   128  		orig.Finalise(true)
   129  		close(done)
   130  	}()
   131  	copy.Finalise(true)
   132  	<-done
   133  
   134  	for i := byte(0); i < 255; i++ {
   135  		origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   136  		copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
   137  
   138  		if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 {
   139  			t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want)
   140  		}
   141  		if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 {
   142  			t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want)
   143  		}
   144  	}
   145  }
   146  
   147  func TestSnapshotRandom(t *testing.T) {
   148  	config := &quick.Config{MaxCount: 1000}
   149  	err := quick.Check((*snapshotTest).run, config)
   150  	if cerr, ok := err.(*quick.CheckError); ok {
   151  		test := cerr.In[0].(*snapshotTest)
   152  		t.Errorf("%v:\n%s", test.err, test)
   153  	} else if err != nil {
   154  		t.Error(err)
   155  	}
   156  }
   157  
// snapshotTest is one randomly generated scenario for TestSnapshotRandom. It
// holds a sequence of actions over a fixed set of addresses, with snapshots
// taken just before the actions at the indices listed in snapshots.
//
// NOTE: Generate constructs this struct with a positional composite literal,
// so the field order is significant.
type snapshotTest struct {
	addrs     []common.Address // accounts compared field by field in checkEqual
	actions   []testAction     // mutations applied in order by run
	snapshots []int            // action indices at which run takes a snapshot, in increasing order as produced by Generate
	err       error            // failure reason recorded by run, printed by TestSnapshotRandom
}
   164  
   165  type testAction struct {
   166  	name   string
   167  	fn     func(testAction, *StateDB)
   168  	args   []int64
   169  	noAddr bool
   170  }
   171  
   172  func newTestAction(addr common.Address, r *rand.Rand) testAction {
   173  	actions := []testAction{
   174  		{
   175  			name: "SetBalance",
   176  			fn: func(a testAction, s *StateDB) {
   177  				s.SetBalance(addr, big.NewInt(a.args[0]))
   178  			},
   179  			args: make([]int64, 1),
   180  		},
   181  		{
   182  			name: "AddBalance",
   183  			fn: func(a testAction, s *StateDB) {
   184  				s.AddBalance(addr, big.NewInt(a.args[0]))
   185  			},
   186  			args: make([]int64, 1),
   187  		},
   188  		{
   189  			name: "SetNonce",
   190  			fn: func(a testAction, s *StateDB) {
   191  				s.SetNonce(addr, uint64(a.args[0]))
   192  			},
   193  			args: make([]int64, 1),
   194  		},
   195  		{
   196  			name: "SetState",
   197  			fn: func(a testAction, s *StateDB) {
   198  				var key, val common.Hash
   199  				binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
   200  				binary.BigEndian.PutUint16(val[:], uint16(a.args[1]))
   201  				s.SetState(addr, key, val)
   202  			},
   203  			args: make([]int64, 2),
   204  		},
   205  		{
   206  			name: "SetCode",
   207  			fn: func(a testAction, s *StateDB) {
   208  				code := make([]byte, 16)
   209  				binary.BigEndian.PutUint64(code, uint64(a.args[0]))
   210  				binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
   211  				s.SetCode(addr, code)
   212  			},
   213  			args: make([]int64, 2),
   214  		},
   215  		{
   216  			name: "CreateAccount",
   217  			fn: func(a testAction, s *StateDB) {
   218  				s.CreateAccount(addr)
   219  			},
   220  		},
   221  		{
   222  			name: "Suicide",
   223  			fn: func(a testAction, s *StateDB) {
   224  				s.Suicide(addr)
   225  			},
   226  		},
   227  		{
   228  			name: "AddRefund",
   229  			fn: func(a testAction, s *StateDB) {
   230  				s.AddRefund(uint64(a.args[0]))
   231  			},
   232  			args:   make([]int64, 1),
   233  			noAddr: true,
   234  		},
   235  		{
   236  			name: "AddLog",
   237  			fn: func(a testAction, s *StateDB) {
   238  				data := make([]byte, 2)
   239  				binary.BigEndian.PutUint16(data, uint16(a.args[0]))
   240  				s.AddLog(&types.Log{Address: addr, Data: data})
   241  			},
   242  			args: make([]int64, 1),
   243  		},
   244  	}
   245  	action := actions[r.Intn(len(actions))]
   246  	var nameargs []string
   247  	if !action.noAddr {
   248  
   249  		nameargs = append(nameargs, addr.String())
   250  	}
   251  	for _, i := range action.args {
   252  		action.args[i] = rand.Int63n(100)
   253  		nameargs = append(nameargs, fmt.Sprint(action.args[i]))
   254  	}
   255  	action.name += strings.Join(nameargs, ", ")
   256  	return action
   257  }
   258  
   259  func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
   260  
   261  	addrs := make([]common.Address, 50)
   262  	for i := range addrs {
   263  		addrs[i][0] = byte(i)
   264  	}
   265  	actions := make([]testAction, size)
   266  	for i := range actions {
   267  		addr := addrs[r.Intn(len(addrs))]
   268  		actions[i] = newTestAction(addr, r)
   269  	}
   270  
   271  	nsnapshots := int(math.Sqrt(float64(size)))
   272  	if size > 0 && nsnapshots == 0 {
   273  		nsnapshots = 1
   274  	}
   275  	snapshots := make([]int, nsnapshots)
   276  	snaplen := len(actions) / nsnapshots
   277  	for i := range snapshots {
   278  
   279  		snapshots[i] = (i * snaplen) + r.Intn(snaplen)
   280  	}
   281  	return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
   282  }
   283  
   284  func (test *snapshotTest) String() string {
   285  	out := new(bytes.Buffer)
   286  	sindex := 0
   287  	for i, action := range test.actions {
   288  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   289  			fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
   290  			sindex++
   291  		}
   292  		fmt.Fprintf(out, "%4d: %s\n", i, action.name)
   293  	}
   294  	return out.String()
   295  }
   296  
   297  func (test *snapshotTest) run() bool {
   298  
   299  	var (
   300  		state, _     = New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
   301  		snapshotRevs = make([]int, len(test.snapshots))
   302  		sindex       = 0
   303  	)
   304  	for i, action := range test.actions {
   305  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   306  			snapshotRevs[sindex] = state.Snapshot()
   307  			sindex++
   308  		}
   309  		action.fn(action, state)
   310  	}
   311  
   312  	for sindex--; sindex >= 0; sindex-- {
   313  		checkstate, _ := New(common.Hash{}, state.Database())
   314  		for _, action := range test.actions[:test.snapshots[sindex]] {
   315  			action.fn(action, checkstate)
   316  		}
   317  		state.RevertToSnapshot(snapshotRevs[sindex])
   318  		if err := test.checkEqual(state, checkstate); err != nil {
   319  			test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err)
   320  			return false
   321  		}
   322  	}
   323  	return true
   324  }
   325  
   326  func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
   327  	for _, addr := range test.addrs {
   328  		var err error
   329  		checkeq := func(op string, a, b interface{}) bool {
   330  			if err == nil && !reflect.DeepEqual(a, b) {
   331  
   332  				err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.String(), a, b)
   333  				return false
   334  			}
   335  			return true
   336  		}
   337  
   338  		checkeq("Exist", state.Exist(addr), checkstate.Exist(addr))
   339  		checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
   340  		checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr))
   341  		checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
   342  		checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
   343  		checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
   344  		checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
   345  
   346  		if obj := state.getStateObject(addr); obj != nil {
   347  			state.ForEachStorage(addr, func(key, val common.Hash) bool {
   348  				return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key))
   349  			})
   350  			checkstate.ForEachStorage(addr, func(key, checkval common.Hash) bool {
   351  				return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval)
   352  			})
   353  		}
   354  		if err != nil {
   355  			return err
   356  		}
   357  	}
   358  
   359  	if state.GetRefund() != checkstate.GetRefund() {
   360  		return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
   361  			state.GetRefund(), checkstate.GetRefund())
   362  	}
   363  	if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) {
   364  		return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v",
   365  			state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{}))
   366  	}
   367  	return nil
   368  }
   369  
   370  func (s *StateSuite) TestTouchDelete(c *check.C) {
   371  	s.state.GetOrNewStateObject(common.Address{})
   372  	root, _ := s.state.Commit(false)
   373  	s.state.Reset(root)
   374  
   375  	snapshot := s.state.Snapshot()
   376  	s.state.AddBalance(common.Address{}, new(big.Int))
   377  	if len(s.state.stateObjectsDirty) != 1 {
   378  		c.Fatal("expected one dirty state object")
   379  	}
   380  	s.state.RevertToSnapshot(snapshot)
   381  	if len(s.state.stateObjectsDirty) != 0 {
   382  		c.Fatal("expected no dirty state object")
   383  	}
   384  }
   385  
   386  func TestCopyOfCopy(t *testing.T) {
   387  	sdb, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()))
   388  	addr := common.HexToAddress("aaaa")
   389  	sdb.SetBalance(addr, big.NewInt(42))
   390  
   391  	if got := sdb.Copy().GetBalance(addr).Uint64(); got != 42 {
   392  		t.Fatalf("1st copy fail, expected 42, got %v", got)
   393  	}
   394  	if got := sdb.Copy().Copy().GetBalance(addr).Uint64(); got != 42 {
   395  		t.Fatalf("2nd copy fail, expected 42, got %v", got)
   396  	}
   397  }