github.com/amazechain/amc@v0.1.3/modules/state/intra_block_state_test.go (about)

     1  // Copyright 2023 The AmazeChain Authors
     2  // This file is part of the AmazeChain library.
     3  //
     4  // The AmazeChain library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The AmazeChain library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the AmazeChain library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  //go:build integration
    18  
    19  package state
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"encoding/binary"
    25  	"fmt"
    26  	"math"
    27  	"math/big"
    28  	"math/rand"
    29  	"reflect"
    30  	"strings"
    31  	"testing"
    32  	"testing/quick"
    33  
    34  	"github.com/amazechain/amc/common/types"
    35  	"github.com/holiman/uint256"
    36  	"github.com/ledgerwatch/erigon-lib/kv/memdb"
    37  )
    38  
    39  func TestSnapshotRandom(t *testing.T) {
    40  	config := &quick.Config{MaxCount: 1000}
    41  	err := quick.Check((*snapshotTest).run, config)
    42  	if cerr, ok := err.(*quick.CheckError); ok {
    43  		test := cerr.In[0].(*snapshotTest)
    44  		t.Errorf("%v:\n%s", test.err, test)
    45  	} else if err != nil {
    46  		t.Error(err)
    47  	}
    48  }
    49  
    50  // A snapshotTest checks that reverting IntraBlockState snapshots properly undoes all changes
    51  // captured by the snapshot. Instances of this test with pseudorandom content are created
    52  // by Generate.
    53  //
    54  // The test works as follows:
    55  //
    56  // A new state is created and all actions are applied to it. Several snapshots are taken
    57  // in between actions. The test then reverts each snapshot. For each snapshot the actions
    58  // leading up to it are replayed on a fresh, empty state. The behaviour of all public
    59  // accessor methods on the reverted state must match the return value of the equivalent
    60  // methods on the replayed state.
    61  type snapshotTest struct {
    62  	addrs     []types.Address // all account addresses
    63  	actions   []testAction    // modifications to the state
    64  	snapshots []int           // actions indexes at which snapshot is taken
    65  	err       error           // failure details are reported through this field
    66  }
    67  
    68  type testAction struct {
    69  	name   string
    70  	fn     func(testAction, *IntraBlockState)
    71  	args   []int64
    72  	noAddr bool
    73  }
    74  
    75  // newTestAction creates a random action that changes state.
    76  func newTestAction(addr types.Address, r *rand.Rand) testAction {
    77  	actions := []testAction{
    78  		{
    79  			name: "SetBalance",
    80  			fn: func(a testAction, s *IntraBlockState) {
    81  				s.SetBalance(addr, uint256.NewInt(uint64(a.args[0])))
    82  			},
    83  			args: make([]int64, 1),
    84  		},
    85  		{
    86  			name: "AddBalance",
    87  			fn: func(a testAction, s *IntraBlockState) {
    88  				s.AddBalance(addr, uint256.NewInt(uint64(a.args[0])))
    89  			},
    90  			args: make([]int64, 1),
    91  		},
    92  		{
    93  			name: "SetNonce",
    94  			fn: func(a testAction, s *IntraBlockState) {
    95  				s.SetNonce(addr, uint64(a.args[0]))
    96  			},
    97  			args: make([]int64, 1),
    98  		},
    99  		{
   100  			name: "SetState",
   101  			fn: func(a testAction, s *IntraBlockState) {
   102  				var key types.Hash
   103  				binary.BigEndian.PutUint16(key[:], uint16(a.args[0]))
   104  				val := uint256.NewInt(uint64(a.args[1]))
   105  				s.SetState(addr, &key, *val)
   106  			},
   107  			args: make([]int64, 2),
   108  		},
   109  		{
   110  			name: "SetCode",
   111  			fn: func(a testAction, s *IntraBlockState) {
   112  				code := make([]byte, 16)
   113  				binary.BigEndian.PutUint64(code, uint64(a.args[0]))
   114  				binary.BigEndian.PutUint64(code[8:], uint64(a.args[1]))
   115  				s.SetCode(addr, code)
   116  			},
   117  			args: make([]int64, 2),
   118  		},
   119  		{
   120  			name: "CreateAccount",
   121  			fn: func(a testAction, s *IntraBlockState) {
   122  				s.CreateAccount(addr, true)
   123  			},
   124  		},
   125  		{
   126  			name: "Suicide",
   127  			fn: func(a testAction, s *IntraBlockState) {
   128  				s.Suicide(addr)
   129  			},
   130  		},
   131  		{
   132  			name: "AddRefund",
   133  			fn: func(a testAction, s *IntraBlockState) {
   134  				s.AddRefund(uint64(a.args[0]))
   135  			},
   136  			args:   make([]int64, 1),
   137  			noAddr: true,
   138  		},
   139  		{
   140  			name: "AddLog",
   141  			fn: func(a testAction, s *IntraBlockState) {
   142  				data := make([]byte, 2)
   143  				binary.BigEndian.PutUint16(data, uint16(a.args[0]))
   144  				s.AddLog(&types.Log{Address: addr, Data: data})
   145  			},
   146  			args: make([]int64, 1),
   147  		},
   148  		{
   149  			name: "AddAddressToAccessList",
   150  			fn: func(a testAction, s *IntraBlockState) {
   151  				s.AddAddressToAccessList(addr)
   152  			},
   153  		},
   154  		{
   155  			name: "AddSlotToAccessList",
   156  			fn: func(a testAction, s *IntraBlockState) {
   157  				s.AddSlotToAccessList(addr,
   158  					types.Hash{byte(a.args[0])})
   159  			},
   160  			args: make([]int64, 1),
   161  		},
   162  	}
   163  	action := actions[r.Intn(len(actions))]
   164  	var nameargs []string
   165  	if !action.noAddr {
   166  		nameargs = append(nameargs, addr.Hex())
   167  	}
   168  	for i := range action.args {
   169  		action.args[i] = rand.Int63n(100)
   170  		nameargs = append(nameargs, fmt.Sprint(action.args[i]))
   171  	}
   172  	action.name += strings.Join(nameargs, ", ")
   173  	return action
   174  }
   175  
   176  // Generate returns a new snapshot test of the given size. All randomness is
   177  // derived from r.
   178  func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value {
   179  	// Generate random actions.
   180  	addrs := make([]types.Address, 50)
   181  	for i := range addrs {
   182  		addrs[i][0] = byte(i)
   183  	}
   184  	actions := make([]testAction, size)
   185  	for i := range actions {
   186  		addr := addrs[r.Intn(len(addrs))]
   187  		actions[i] = newTestAction(addr, r)
   188  	}
   189  	// Generate snapshot indexes.
   190  	nsnapshots := int(math.Sqrt(float64(size)))
   191  	if size > 0 && nsnapshots == 0 {
   192  		nsnapshots = 1
   193  	}
   194  	snapshots := make([]int, nsnapshots)
   195  	snaplen := len(actions) / nsnapshots
   196  	for i := range snapshots {
   197  		// Try to place the snapshots some number of actions apart from each other.
   198  		snapshots[i] = (i * snaplen) + r.Intn(snaplen)
   199  	}
   200  	return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil})
   201  }
   202  
   203  func (test *snapshotTest) String() string {
   204  	out := new(bytes.Buffer)
   205  	sindex := 0
   206  	for i, action := range test.actions {
   207  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   208  			fmt.Fprintf(out, "---- snapshot %d ----\n", sindex)
   209  			sindex++
   210  		}
   211  		fmt.Fprintf(out, "%4d: %s\n", i, action.name)
   212  	}
   213  	return out.String()
   214  }
   215  
   216  func (test *snapshotTest) run() bool {
   217  	// Run all actions and create snapshots.
   218  	db := memdb.New()
   219  	defer db.Close()
   220  	tx, err := db.BeginRw(context.Background())
   221  	if err != nil {
   222  		test.err = err
   223  		return false
   224  	}
   225  	defer tx.Rollback()
   226  	var (
   227  		ds           = NewPlainState(tx, 1)
   228  		state        = New(ds)
   229  		snapshotRevs = make([]int, len(test.snapshots))
   230  		sindex       = 0
   231  	)
   232  	for i, action := range test.actions {
   233  		if len(test.snapshots) > sindex && i == test.snapshots[sindex] {
   234  			snapshotRevs[sindex] = state.Snapshot()
   235  			sindex++
   236  		}
   237  		action.fn(action, state)
   238  	}
   239  	// Revert all snapshots in reverse order. Each revert must yield a state
   240  	// that is equivalent to fresh state with all actions up the snapshot applied.
   241  	for sindex--; sindex >= 0; sindex-- {
   242  		checkds := NewPlainState(tx, 1)
   243  		checkstate := New(checkds)
   244  		for _, action := range test.actions[:test.snapshots[sindex]] {
   245  			action.fn(action, checkstate)
   246  		}
   247  		state.RevertToSnapshot(snapshotRevs[sindex])
   248  		if err := test.checkEqual(state, checkstate); err != nil {
   249  			test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%w", sindex, err)
   250  			return false
   251  		}
   252  	}
   253  	return true
   254  }
   255  
   256  // checkEqual checks that methods of state and checkstate return the same values.
   257  func (test *snapshotTest) checkEqual(state, checkstate *IntraBlockState) error {
   258  	for _, addr := range test.addrs {
   259  		addr := addr // pin
   260  		var err error
   261  		checkeq := func(op string, a, b interface{}) bool {
   262  			if err == nil && !reflect.DeepEqual(a, b) {
   263  				err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b)
   264  				return false
   265  			}
   266  			return true
   267  		}
   268  		checkeqBigInt := func(op string, a, b *big.Int) bool {
   269  			if err == nil && a.Cmp(b) != 0 {
   270  				err = fmt.Errorf("got %s(%s) == %d, want %d", op, addr.Hex(), a, b)
   271  				return false
   272  			}
   273  			return true
   274  		}
   275  		// Check basic accessor methods.
   276  		if !checkeq("Exist", state.Exist(addr), checkstate.Exist(addr)) {
   277  			return err
   278  		}
   279  		checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr))
   280  		checkeqBigInt("GetBalance", state.GetBalance(addr).ToBig(), checkstate.GetBalance(addr).ToBig())
   281  		checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr))
   282  		checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr))
   283  		checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr))
   284  		checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr))
   285  		// Check storage.
   286  		if obj := state.getStateObject(addr); obj != nil {
   287  			for key, value := range obj.dirtyStorage {
   288  				var out uint256.Int
   289  				checkstate.GetState(addr, &key, &out)
   290  				if !checkeq("GetState("+key.Hex()+")", out, value) {
   291  					return err
   292  				}
   293  			}
   294  		}
   295  		if obj := checkstate.getStateObject(addr); obj != nil {
   296  			for key, value := range obj.dirtyStorage {
   297  				var out uint256.Int
   298  				state.GetState(addr, &key, &out)
   299  				if !checkeq("GetState("+key.Hex()+")", out, value) {
   300  					return err
   301  				}
   302  			}
   303  		}
   304  	}
   305  
   306  	if state.GetRefund() != checkstate.GetRefund() {
   307  		return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d",
   308  			state.GetRefund(), checkstate.GetRefund())
   309  	}
   310  	if !reflect.DeepEqual(state.GetLogs(types.Hash{}), checkstate.GetLogs(types.Hash{})) {
   311  		return fmt.Errorf("got GetLogs(types.Hash{}) == %v, want GetLogs(types.Hash{}) == %v",
   312  			state.GetLogs(types.Hash{}), checkstate.GetLogs(types.Hash{}))
   313  	}
   314  	return nil
   315  }