github.com/dim4egster/coreth@v0.10.2/sync/handlers/leafs_request_test.go (about)

     1  // (c) 2021-2022, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package handlers
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/dim4egster/qmallgo/ids"
    13  	"github.com/dim4egster/coreth/core/rawdb"
    14  	"github.com/dim4egster/coreth/core/state/snapshot"
    15  	"github.com/dim4egster/coreth/core/types"
    16  	"github.com/dim4egster/coreth/ethdb"
    17  	"github.com/dim4egster/coreth/ethdb/memorydb"
    18  	"github.com/dim4egster/coreth/plugin/evm/message"
    19  	"github.com/dim4egster/coreth/sync/handlers/stats"
    20  	"github.com/dim4egster/coreth/trie"
    21  	"github.com/ethereum/go-ethereum/common"
    22  	"github.com/ethereum/go-ethereum/crypto"
    23  	"github.com/ethereum/go-ethereum/rlp"
    24  	"github.com/stretchr/testify/assert"
    25  )
    26  
    27  func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) {
    28  	rand.Seed(1)
    29  	mockHandlerStats := &stats.MockHandlerStats{}
    30  	memdb := memorydb.New()
    31  	trieDB := trie.NewDatabase(memdb)
    32  
    33  	corruptedTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, 100, common.HashLength)
    34  	// Corrupt [corruptedTrieRoot]
    35  	trie.CorruptTrie(t, trieDB, corruptedTrieRoot, 5)
    36  
    37  	largeTrieRoot, largeTrieKeys, _ := trie.GenerateTrie(t, trieDB, 10_000, common.HashLength)
    38  	smallTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, 500, common.HashLength)
    39  	accountTrieRoot, accounts := trie.FillAccounts(
    40  		t,
    41  		trieDB,
    42  		common.Hash{},
    43  		10_000,
    44  		func(t *testing.T, i int, acc types.StateAccount) types.StateAccount {
    45  			// set the storage trie root for two accounts
    46  			if i == 0 {
    47  				acc.Root = largeTrieRoot
    48  			} else if i == 1 {
    49  				acc.Root = smallTrieRoot
    50  			}
    51  
    52  			return acc
    53  		})
    54  
    55  	// find the hash of the account we set to have a storage
    56  	var (
    57  		largeStorageAccount common.Hash
    58  		smallStorageAccount common.Hash
    59  	)
    60  	for key, account := range accounts {
    61  		if account.Root == largeTrieRoot {
    62  			largeStorageAccount = crypto.Keccak256Hash(key.Address[:])
    63  		}
    64  		if account.Root == smallTrieRoot {
    65  			smallStorageAccount = crypto.Keccak256Hash(key.Address[:])
    66  		}
    67  		if (largeStorageAccount != common.Hash{}) && (smallStorageAccount != common.Hash{}) {
    68  			// we can break if we found both accounts of interest to the test
    69  			break
    70  		}
    71  	}
    72  	snapshotProvider := &TestSnapshotProvider{}
    73  	leafsHandler := NewLeafsRequestHandler(trieDB, snapshotProvider, message.Codec, mockHandlerStats)
    74  
    75  	tests := map[string]struct {
    76  		prepareTestFn    func() (context.Context, message.LeafsRequest)
    77  		assertResponseFn func(*testing.T, message.LeafsRequest, []byte, error)
    78  	}{
    79  		"zero limit dropped": {
    80  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
    81  				return context.Background(), message.LeafsRequest{
    82  					Root:     largeTrieRoot,
    83  					Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
    84  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
    85  					Limit:    0,
    86  					NodeType: message.StateTrieNode,
    87  				}
    88  			},
    89  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
    90  				assert.Nil(t, response)
    91  				assert.Nil(t, err)
    92  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
    93  			},
    94  		},
    95  		"empty root dropped": {
    96  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
    97  				return context.Background(), message.LeafsRequest{
    98  					Root:     common.Hash{},
    99  					Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   100  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   101  					Limit:    maxLeavesLimit,
   102  					NodeType: message.StateTrieNode,
   103  				}
   104  			},
   105  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   106  				assert.Nil(t, response)
   107  				assert.Nil(t, err)
   108  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
   109  			},
   110  		},
   111  		"bad start len dropped": {
   112  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   113  				return context.Background(), message.LeafsRequest{
   114  					Root:     common.Hash{},
   115  					Start:    bytes.Repeat([]byte{0x00}, common.HashLength+2),
   116  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   117  					Limit:    maxLeavesLimit,
   118  					NodeType: message.StateTrieNode,
   119  				}
   120  			},
   121  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   122  				assert.Nil(t, response)
   123  				assert.Nil(t, err)
   124  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
   125  			},
   126  		},
   127  		"bad end len dropped": {
   128  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   129  				return context.Background(), message.LeafsRequest{
   130  					Root:     common.Hash{},
   131  					Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   132  					End:      bytes.Repeat([]byte{0xff}, common.HashLength-1),
   133  					Limit:    maxLeavesLimit,
   134  					NodeType: message.StateTrieNode,
   135  				}
   136  			},
   137  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   138  				assert.Nil(t, response)
   139  				assert.Nil(t, err)
   140  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
   141  			},
   142  		},
   143  		"empty storage root dropped": {
   144  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   145  				return context.Background(), message.LeafsRequest{
   146  					Root:     types.EmptyRootHash,
   147  					Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   148  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   149  					Limit:    maxLeavesLimit,
   150  					NodeType: message.StateTrieNode,
   151  				}
   152  			},
   153  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   154  				assert.Nil(t, response)
   155  				assert.Nil(t, err)
   156  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
   157  			},
   158  		},
   159  		"missing root dropped": {
   160  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   161  				return context.Background(), message.LeafsRequest{
   162  					Root:     common.BytesToHash([]byte("something is missing here...")),
   163  					Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   164  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   165  					Limit:    maxLeavesLimit,
   166  					NodeType: message.StateTrieNode,
   167  				}
   168  			},
   169  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   170  				assert.Nil(t, response)
   171  				assert.Nil(t, err)
   172  				assert.EqualValues(t, 1, mockHandlerStats.MissingRootCount)
   173  			},
   174  		},
   175  		"corrupted trie drops request": {
   176  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   177  				return context.Background(), message.LeafsRequest{
   178  					Root:     corruptedTrieRoot,
   179  					Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   180  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   181  					Limit:    maxLeavesLimit,
   182  					NodeType: message.StateTrieNode,
   183  				}
   184  			},
   185  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   186  				assert.Nil(t, response)
   187  				assert.Nil(t, err)
   188  				assert.EqualValues(t, uint32(1), mockHandlerStats.TrieErrorCount)
   189  			},
   190  		},
   191  		"cancelled context dropped": {
   192  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   193  				ctx, cancel := context.WithCancel(context.Background())
   194  				defer cancel()
   195  				return ctx, message.LeafsRequest{
   196  					Root:     largeTrieRoot,
   197  					Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   198  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   199  					Limit:    maxLeavesLimit,
   200  					NodeType: message.StateTrieNode,
   201  				}
   202  			},
   203  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   204  				assert.Nil(t, response)
   205  				assert.Nil(t, err)
   206  			},
   207  		},
   208  		"nil start and end range returns entire trie": {
   209  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   210  				return context.Background(), message.LeafsRequest{
   211  					Root:     smallTrieRoot,
   212  					Start:    nil,
   213  					End:      nil,
   214  					Limit:    maxLeavesLimit,
   215  					NodeType: message.StateTrieNode,
   216  				}
   217  			},
   218  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   219  				assert.NoError(t, err)
   220  				var leafsResponse message.LeafsResponse
   221  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   222  				assert.NoError(t, err)
   223  				assert.Len(t, leafsResponse.Keys, 500)
   224  				assert.Len(t, leafsResponse.Vals, 500)
   225  				assert.Len(t, leafsResponse.ProofVals, 0)
   226  			},
   227  		},
   228  		"nil end range treated like greatest possible value": {
   229  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   230  				return context.Background(), message.LeafsRequest{
   231  					Root:     smallTrieRoot,
   232  					Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   233  					End:      nil,
   234  					Limit:    maxLeavesLimit,
   235  					NodeType: message.StateTrieNode,
   236  				}
   237  			},
   238  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   239  				assert.NoError(t, err)
   240  				var leafsResponse message.LeafsResponse
   241  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   242  				assert.NoError(t, err)
   243  				assert.Len(t, leafsResponse.Keys, 500)
   244  				assert.Len(t, leafsResponse.Vals, 500)
   245  			},
   246  		},
   247  		"end greater than start dropped": {
   248  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   249  				ctx, cancel := context.WithCancel(context.Background())
   250  				defer cancel()
   251  				return ctx, message.LeafsRequest{
   252  					Root:     largeTrieRoot,
   253  					Start:    bytes.Repeat([]byte{0xbb}, common.HashLength),
   254  					End:      bytes.Repeat([]byte{0xaa}, common.HashLength),
   255  					Limit:    maxLeavesLimit,
   256  					NodeType: message.StateTrieNode,
   257  				}
   258  			},
   259  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   260  				assert.Nil(t, response)
   261  				assert.Nil(t, err)
   262  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
   263  			},
   264  		},
   265  		"invalid node type dropped": {
   266  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   267  				ctx, cancel := context.WithCancel(context.Background())
   268  				defer cancel()
   269  				return ctx, message.LeafsRequest{
   270  					Root:     largeTrieRoot,
   271  					Start:    bytes.Repeat([]byte{0xbb}, common.HashLength),
   272  					End:      bytes.Repeat([]byte{0xaa}, common.HashLength),
   273  					Limit:    maxLeavesLimit,
   274  					NodeType: message.NodeType(11),
   275  				}
   276  			},
   277  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   278  				assert.Nil(t, response)
   279  				assert.Nil(t, err)
   280  			},
   281  		},
   282  		"max leaves overridden": {
   283  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   284  				return context.Background(), message.LeafsRequest{
   285  					Root:     largeTrieRoot,
   286  					Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   287  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   288  					Limit:    maxLeavesLimit * 10,
   289  					NodeType: message.StateTrieNode,
   290  				}
   291  			},
   292  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   293  				assert.NoError(t, err)
   294  				var leafsResponse message.LeafsResponse
   295  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   296  				assert.NoError(t, err)
   297  				assert.EqualValues(t, len(leafsResponse.Keys), maxLeavesLimit)
   298  				assert.EqualValues(t, len(leafsResponse.Vals), maxLeavesLimit)
   299  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   300  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   301  			},
   302  		},
   303  		"full range with nil start": {
   304  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   305  				return context.Background(), message.LeafsRequest{
   306  					Root:     largeTrieRoot,
   307  					Start:    nil,
   308  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   309  					Limit:    maxLeavesLimit,
   310  					NodeType: message.StateTrieNode,
   311  				}
   312  			},
   313  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   314  				assert.NoError(t, err)
   315  				var leafsResponse message.LeafsResponse
   316  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   317  				assert.NoError(t, err)
   318  				assert.EqualValues(t, len(leafsResponse.Keys), maxLeavesLimit)
   319  				assert.EqualValues(t, len(leafsResponse.Vals), maxLeavesLimit)
   320  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   321  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   322  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   323  			},
   324  		},
   325  		"full range with 0x00 start": {
   326  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   327  				return context.Background(), message.LeafsRequest{
   328  					Root:     largeTrieRoot,
   329  					Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   330  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   331  					Limit:    maxLeavesLimit,
   332  					NodeType: message.StateTrieNode,
   333  				}
   334  			},
   335  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   336  				assert.NoError(t, err)
   337  				var leafsResponse message.LeafsResponse
   338  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   339  				assert.NoError(t, err)
   340  				assert.EqualValues(t, len(leafsResponse.Keys), maxLeavesLimit)
   341  				assert.EqualValues(t, len(leafsResponse.Vals), maxLeavesLimit)
   342  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   343  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   344  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   345  			},
   346  		},
   347  		"partial mid range": {
   348  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   349  				startKey := largeTrieKeys[1_000]
   350  				startKey[31] = startKey[31] + 1 // exclude start key from response
   351  				endKey := largeTrieKeys[1_040]  // include end key in response
   352  				return context.Background(), message.LeafsRequest{
   353  					Root:     largeTrieRoot,
   354  					Start:    startKey,
   355  					End:      endKey,
   356  					Limit:    maxLeavesLimit,
   357  					NodeType: message.StateTrieNode,
   358  				}
   359  			},
   360  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   361  				assert.NoError(t, err)
   362  				var leafsResponse message.LeafsResponse
   363  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   364  				assert.NoError(t, err)
   365  				assert.EqualValues(t, 40, len(leafsResponse.Keys))
   366  				assert.EqualValues(t, 40, len(leafsResponse.Vals))
   367  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   368  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   369  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   370  			},
   371  		},
   372  		"partial end range": {
   373  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   374  				return context.Background(), message.LeafsRequest{
   375  					Root:     largeTrieRoot,
   376  					Start:    largeTrieKeys[9_400],
   377  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   378  					Limit:    maxLeavesLimit,
   379  					NodeType: message.StateTrieNode,
   380  				}
   381  			},
   382  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   383  				assert.NoError(t, err)
   384  				var leafsResponse message.LeafsResponse
   385  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   386  				assert.NoError(t, err)
   387  				assert.EqualValues(t, 600, len(leafsResponse.Keys))
   388  				assert.EqualValues(t, 600, len(leafsResponse.Vals))
   389  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   390  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   391  				assertRangeProofIsValid(t, &request, &leafsResponse, false)
   392  			},
   393  		},
   394  		"final end range": {
   395  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   396  				return context.Background(), message.LeafsRequest{
   397  					Root:     largeTrieRoot,
   398  					Start:    bytes.Repeat([]byte{0xff}, common.HashLength),
   399  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   400  					Limit:    maxLeavesLimit,
   401  					NodeType: message.StateTrieNode,
   402  				}
   403  			},
   404  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   405  				assert.NoError(t, err)
   406  				var leafsResponse message.LeafsResponse
   407  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   408  				assert.NoError(t, err)
   409  				assert.EqualValues(t, len(leafsResponse.Keys), 0)
   410  				assert.EqualValues(t, len(leafsResponse.Vals), 0)
   411  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   412  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   413  				assertRangeProofIsValid(t, &request, &leafsResponse, false)
   414  			},
   415  		},
   416  		"small trie root": {
   417  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   418  				return context.Background(), message.LeafsRequest{
   419  					Root:     smallTrieRoot,
   420  					Start:    nil,
   421  					End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   422  					Limit:    maxLeavesLimit,
   423  					NodeType: message.StateTrieNode,
   424  				}
   425  			},
   426  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   427  				assert.NotEmpty(t, response)
   428  
   429  				var leafsResponse message.LeafsResponse
   430  				if _, err = message.Codec.Unmarshal(response, &leafsResponse); err != nil {
   431  					t.Fatalf("unexpected error when unmarshalling LeafsResponse: %v", err)
   432  				}
   433  
   434  				assert.EqualValues(t, 500, len(leafsResponse.Keys))
   435  				assert.EqualValues(t, 500, len(leafsResponse.Vals))
   436  				assert.Empty(t, leafsResponse.ProofVals)
   437  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   438  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   439  				assertRangeProofIsValid(t, &request, &leafsResponse, false)
   440  			},
   441  		},
   442  		"account data served from snapshot": {
   443  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   444  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   445  				if err != nil {
   446  					t.Fatal(err)
   447  				}
   448  				snapshotProvider.Snapshot = snap
   449  				return context.Background(), message.LeafsRequest{
   450  					Root:     accountTrieRoot,
   451  					Limit:    maxLeavesLimit,
   452  					NodeType: message.StateTrieNode,
   453  				}
   454  			},
   455  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   456  				assert.NoError(t, err)
   457  				var leafsResponse message.LeafsResponse
   458  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   459  				assert.NoError(t, err)
   460  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys))
   461  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals))
   462  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   463  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   464  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   465  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadSuccessCount)
   466  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   467  			},
   468  		},
   469  		"partial account data served from snapshot": {
   470  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   471  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   472  				if err != nil {
   473  					t.Fatal(err)
   474  				}
   475  				snapshotProvider.Snapshot = snap
   476  				it := snap.DiskAccountIterator(common.Hash{})
   477  				defer it.Release()
   478  				i := 0
   479  				for it.Next() {
   480  					if i > int(maxLeavesLimit) {
   481  						// no need to modify beyond the request limit
   482  						break
   483  					}
   484  					// modify one entry of 1 in 4 segments
   485  					if i%(segmentLen*4) == 0 {
   486  						var acc snapshot.Account
   487  						if err := rlp.DecodeBytes(it.Account(), &acc); err != nil {
   488  							t.Fatalf("could not parse snapshot account: %v", err)
   489  						}
   490  						acc.Nonce++
   491  						bytes, err := rlp.EncodeToBytes(acc)
   492  						if err != nil {
   493  							t.Fatalf("coult not encode snapshot account to bytes: %v", err)
   494  						}
   495  						rawdb.WriteAccountSnapshot(memdb, it.Hash(), bytes)
   496  					}
   497  					i++
   498  				}
   499  
   500  				return context.Background(), message.LeafsRequest{
   501  					Root:     accountTrieRoot,
   502  					Limit:    maxLeavesLimit,
   503  					NodeType: message.StateTrieNode,
   504  				}
   505  			},
   506  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   507  				assert.NoError(t, err)
   508  				var leafsResponse message.LeafsResponse
   509  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   510  				assert.NoError(t, err)
   511  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys))
   512  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals))
   513  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   514  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   515  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   516  				assert.EqualValues(t, 0, mockHandlerStats.SnapshotReadSuccessCount)
   517  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   518  
   519  				// expect 1/4th of segments to be invalid
   520  				numSegments := maxLeavesLimit / segmentLen
   521  				assert.EqualValues(t, numSegments/4, mockHandlerStats.SnapshotSegmentInvalidCount)
   522  				assert.EqualValues(t, 3*numSegments/4, mockHandlerStats.SnapshotSegmentValidCount)
   523  			},
   524  		},
   525  		"storage data served from snapshot": {
   526  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   527  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   528  				if err != nil {
   529  					t.Fatal(err)
   530  				}
   531  				snapshotProvider.Snapshot = snap
   532  				return context.Background(), message.LeafsRequest{
   533  					Root:     largeTrieRoot,
   534  					Account:  largeStorageAccount,
   535  					Limit:    maxLeavesLimit,
   536  					NodeType: message.StateTrieNode,
   537  				}
   538  			},
   539  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   540  				assert.NoError(t, err)
   541  				var leafsResponse message.LeafsResponse
   542  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   543  				assert.NoError(t, err)
   544  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys))
   545  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals))
   546  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   547  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   548  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   549  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadSuccessCount)
   550  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   551  			},
   552  		},
   553  		"partial storage data served from snapshot": {
   554  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   555  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   556  				if err != nil {
   557  					t.Fatal(err)
   558  				}
   559  				snapshotProvider.Snapshot = snap
   560  				it := snap.DiskStorageIterator(largeStorageAccount, common.Hash{})
   561  				defer it.Release()
   562  				i := 0
   563  				for it.Next() {
   564  					if i > int(maxLeavesLimit) {
   565  						// no need to modify beyond the request limit
   566  						break
   567  					}
   568  					// modify one entry of 1 in 4 segments
   569  					if i%(segmentLen*4) == 0 {
   570  						randomBytes := make([]byte, 5)
   571  						_, err := rand.Read(randomBytes)
   572  						assert.NoError(t, err)
   573  						rawdb.WriteStorageSnapshot(memdb, largeStorageAccount, it.Hash(), randomBytes)
   574  					}
   575  					i++
   576  				}
   577  
   578  				return context.Background(), message.LeafsRequest{
   579  					Root:     largeTrieRoot,
   580  					Account:  largeStorageAccount,
   581  					Limit:    maxLeavesLimit,
   582  					NodeType: message.StateTrieNode,
   583  				}
   584  			},
   585  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   586  				assert.NoError(t, err)
   587  				var leafsResponse message.LeafsResponse
   588  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   589  				assert.NoError(t, err)
   590  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys))
   591  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals))
   592  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   593  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   594  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   595  				assert.EqualValues(t, 0, mockHandlerStats.SnapshotReadSuccessCount)
   596  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   597  
   598  				// expect 1/4th of segments to be invalid
   599  				numSegments := maxLeavesLimit / segmentLen
   600  				assert.EqualValues(t, numSegments/4, mockHandlerStats.SnapshotSegmentInvalidCount)
   601  				assert.EqualValues(t, 3*numSegments/4, mockHandlerStats.SnapshotSegmentValidCount)
   602  			},
   603  		},
   604  		"last snapshot key removed": {
   605  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   606  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   607  				if err != nil {
   608  					t.Fatal(err)
   609  				}
   610  				snapshotProvider.Snapshot = snap
   611  				it := snap.DiskStorageIterator(smallStorageAccount, common.Hash{})
   612  				defer it.Release()
   613  				var lastKey common.Hash
   614  				for it.Next() {
   615  					lastKey = it.Hash()
   616  				}
   617  				rawdb.DeleteStorageSnapshot(memdb, smallStorageAccount, lastKey)
   618  
   619  				return context.Background(), message.LeafsRequest{
   620  					Root:     smallTrieRoot,
   621  					Account:  smallStorageAccount,
   622  					Limit:    maxLeavesLimit,
   623  					NodeType: message.StateTrieNode,
   624  				}
   625  			},
   626  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   627  				assert.NoError(t, err)
   628  				var leafsResponse message.LeafsResponse
   629  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   630  				assert.NoError(t, err)
   631  				assert.EqualValues(t, 500, len(leafsResponse.Keys))
   632  				assert.EqualValues(t, 500, len(leafsResponse.Vals))
   633  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   634  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   635  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   636  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadSuccessCount)
   637  				assertRangeProofIsValid(t, &request, &leafsResponse, false)
   638  			},
   639  		},
   640  		"request last key when removed from snapshot": {
   641  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   642  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   643  				if err != nil {
   644  					t.Fatal(err)
   645  				}
   646  				snapshotProvider.Snapshot = snap
   647  				it := snap.DiskStorageIterator(smallStorageAccount, common.Hash{})
   648  				defer it.Release()
   649  				var lastKey common.Hash
   650  				for it.Next() {
   651  					lastKey = it.Hash()
   652  				}
   653  				rawdb.DeleteStorageSnapshot(memdb, smallStorageAccount, lastKey)
   654  
   655  				return context.Background(), message.LeafsRequest{
   656  					Root:     smallTrieRoot,
   657  					Account:  smallStorageAccount,
   658  					Start:    lastKey[:],
   659  					Limit:    maxLeavesLimit,
   660  					NodeType: message.StateTrieNode,
   661  				}
   662  			},
   663  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   664  				assert.NoError(t, err)
   665  				var leafsResponse message.LeafsResponse
   666  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   667  				assert.NoError(t, err)
   668  				assert.EqualValues(t, 1, len(leafsResponse.Keys))
   669  				assert.EqualValues(t, 1, len(leafsResponse.Vals))
   670  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   671  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   672  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   673  				assert.EqualValues(t, 0, mockHandlerStats.SnapshotReadSuccessCount)
   674  				assertRangeProofIsValid(t, &request, &leafsResponse, false)
   675  			},
   676  		},
   677  	}
   678  	for name, test := range tests {
   679  		t.Run(name, func(t *testing.T) {
   680  			ctx, request := test.prepareTestFn()
   681  			t.Cleanup(func() {
   682  				<-snapshot.WipeSnapshot(memdb, true)
   683  				mockHandlerStats.Reset()
   684  				snapshotProvider.Snapshot = nil // reset the snapshot to nil
   685  			})
   686  
   687  			response, err := leafsHandler.OnLeafsRequest(ctx, ids.GenerateTestNodeID(), 1, request)
   688  			test.assertResponseFn(t, request, response, err)
   689  		})
   690  	}
   691  }
   692  
   693  func assertRangeProofIsValid(t *testing.T, request *message.LeafsRequest, response *message.LeafsResponse, expectMore bool) {
   694  	t.Helper()
   695  
   696  	var start, end []byte
   697  	if len(request.Start) == 0 {
   698  		start = bytes.Repeat([]byte{0x00}, common.HashLength)
   699  	} else {
   700  		start = request.Start
   701  	}
   702  	if len(response.Keys) > 0 {
   703  		end = response.Keys[len(response.Vals)-1]
   704  	}
   705  
   706  	var proof ethdb.Database
   707  	if len(response.ProofVals) > 0 {
   708  		proof = memorydb.New()
   709  		defer proof.Close()
   710  		for _, proofVal := range response.ProofVals {
   711  			proofKey := crypto.Keccak256(proofVal)
   712  			if err := proof.Put(proofKey, proofVal); err != nil {
   713  				t.Fatal(err)
   714  			}
   715  		}
   716  	}
   717  
   718  	more, err := trie.VerifyRangeProof(request.Root, start, end, response.Keys, response.Vals, proof)
   719  	assert.NoError(t, err)
   720  	assert.Equal(t, expectMore, more)
   721  }