github.com/MetalBlockchain/subnet-evm@v0.4.9/sync/handlers/leafs_request_test.go (about)

     1  // (c) 2021-2022, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package handlers
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"math/rand"
    10  	"testing"
    11  
    12  	"github.com/MetalBlockchain/metalgo/ids"
    13  	"github.com/MetalBlockchain/subnet-evm/core/rawdb"
    14  	"github.com/MetalBlockchain/subnet-evm/core/state/snapshot"
    15  	"github.com/MetalBlockchain/subnet-evm/core/types"
    16  	"github.com/MetalBlockchain/subnet-evm/ethdb"
    17  	"github.com/MetalBlockchain/subnet-evm/ethdb/memorydb"
    18  	"github.com/MetalBlockchain/subnet-evm/plugin/evm/message"
    19  	"github.com/MetalBlockchain/subnet-evm/sync/handlers/stats"
    20  	"github.com/MetalBlockchain/subnet-evm/trie"
    21  	"github.com/ethereum/go-ethereum/common"
    22  	"github.com/ethereum/go-ethereum/crypto"
    23  	"github.com/ethereum/go-ethereum/rlp"
    24  	"github.com/stretchr/testify/assert"
    25  )
    26  
    27  func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) {
    28  	rand.Seed(1)
    29  	mockHandlerStats := &stats.MockHandlerStats{}
    30  	memdb := memorydb.New()
    31  	trieDB := trie.NewDatabase(memdb)
    32  
    33  	corruptedTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, 100, common.HashLength)
    34  	// Corrupt [corruptedTrieRoot]
    35  	trie.CorruptTrie(t, trieDB, corruptedTrieRoot, 5)
    36  
    37  	largeTrieRoot, largeTrieKeys, _ := trie.GenerateTrie(t, trieDB, 10_000, common.HashLength)
    38  	smallTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, 500, common.HashLength)
    39  	accountTrieRoot, accounts := trie.FillAccounts(
    40  		t,
    41  		trieDB,
    42  		common.Hash{},
    43  		10_000,
    44  		func(t *testing.T, i int, acc types.StateAccount) types.StateAccount {
    45  			// set the storage trie root for two accounts
    46  			if i == 0 {
    47  				acc.Root = largeTrieRoot
    48  			} else if i == 1 {
    49  				acc.Root = smallTrieRoot
    50  			}
    51  
    52  			return acc
    53  		})
    54  
    55  	// find the hash of the account we set to have a storage
    56  	var (
    57  		largeStorageAccount common.Hash
    58  		smallStorageAccount common.Hash
    59  	)
    60  	for key, account := range accounts {
    61  		if account.Root == largeTrieRoot {
    62  			largeStorageAccount = crypto.Keccak256Hash(key.Address[:])
    63  		}
    64  		if account.Root == smallTrieRoot {
    65  			smallStorageAccount = crypto.Keccak256Hash(key.Address[:])
    66  		}
    67  		if (largeStorageAccount != common.Hash{}) && (smallStorageAccount != common.Hash{}) {
    68  			// we can break if we found both accounts of interest to the test
    69  			break
    70  		}
    71  	}
    72  	snapshotProvider := &TestSnapshotProvider{}
    73  	leafsHandler := NewLeafsRequestHandler(trieDB, snapshotProvider, message.Codec, mockHandlerStats)
    74  
    75  	tests := map[string]struct {
    76  		prepareTestFn    func() (context.Context, message.LeafsRequest)
    77  		assertResponseFn func(*testing.T, message.LeafsRequest, []byte, error)
    78  	}{
    79  		"zero limit dropped": {
    80  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
    81  				return context.Background(), message.LeafsRequest{
    82  					Root:  largeTrieRoot,
    83  					Start: bytes.Repeat([]byte{0x00}, common.HashLength),
    84  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
    85  					Limit: 0,
    86  				}
    87  			},
    88  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
    89  				assert.Nil(t, response)
    90  				assert.Nil(t, err)
    91  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
    92  			},
    93  		},
    94  		"empty root dropped": {
    95  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
    96  				return context.Background(), message.LeafsRequest{
    97  					Root:  common.Hash{},
    98  					Start: bytes.Repeat([]byte{0x00}, common.HashLength),
    99  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   100  					Limit: maxLeavesLimit,
   101  				}
   102  			},
   103  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   104  				assert.Nil(t, response)
   105  				assert.Nil(t, err)
   106  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
   107  			},
   108  		},
   109  		"bad start len dropped": {
   110  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   111  				return context.Background(), message.LeafsRequest{
   112  					Root:  common.Hash{},
   113  					Start: bytes.Repeat([]byte{0x00}, common.HashLength+2),
   114  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   115  					Limit: maxLeavesLimit,
   116  				}
   117  			},
   118  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   119  				assert.Nil(t, response)
   120  				assert.Nil(t, err)
   121  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
   122  			},
   123  		},
   124  		"bad end len dropped": {
   125  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   126  				return context.Background(), message.LeafsRequest{
   127  					Root:  common.Hash{},
   128  					Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   129  					End:   bytes.Repeat([]byte{0xff}, common.HashLength-1),
   130  					Limit: maxLeavesLimit,
   131  				}
   132  			},
   133  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   134  				assert.Nil(t, response)
   135  				assert.Nil(t, err)
   136  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
   137  			},
   138  		},
   139  		"empty storage root dropped": {
   140  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   141  				return context.Background(), message.LeafsRequest{
   142  					Root:  types.EmptyRootHash,
   143  					Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   144  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   145  					Limit: maxLeavesLimit,
   146  				}
   147  			},
   148  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   149  				assert.Nil(t, response)
   150  				assert.Nil(t, err)
   151  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
   152  			},
   153  		},
   154  		"missing root dropped": {
   155  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   156  				return context.Background(), message.LeafsRequest{
   157  					Root:  common.BytesToHash([]byte("something is missing here...")),
   158  					Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   159  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   160  					Limit: maxLeavesLimit,
   161  				}
   162  			},
   163  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   164  				assert.Nil(t, response)
   165  				assert.Nil(t, err)
   166  				assert.EqualValues(t, 1, mockHandlerStats.MissingRootCount)
   167  			},
   168  		},
   169  		"corrupted trie drops request": {
   170  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   171  				return context.Background(), message.LeafsRequest{
   172  					Root:  corruptedTrieRoot,
   173  					Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   174  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   175  					Limit: maxLeavesLimit,
   176  				}
   177  			},
   178  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   179  				assert.Nil(t, response)
   180  				assert.Nil(t, err)
   181  				assert.EqualValues(t, uint32(1), mockHandlerStats.TrieErrorCount)
   182  			},
   183  		},
   184  		"cancelled context dropped": {
   185  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   186  				ctx, cancel := context.WithCancel(context.Background())
   187  				defer cancel()
   188  				return ctx, message.LeafsRequest{
   189  					Root:  largeTrieRoot,
   190  					Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   191  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   192  					Limit: maxLeavesLimit,
   193  				}
   194  			},
   195  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   196  				assert.Nil(t, response)
   197  				assert.Nil(t, err)
   198  			},
   199  		},
   200  		"nil start and end range returns entire trie": {
   201  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   202  				return context.Background(), message.LeafsRequest{
   203  					Root:  smallTrieRoot,
   204  					Start: nil,
   205  					End:   nil,
   206  					Limit: maxLeavesLimit,
   207  				}
   208  			},
   209  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   210  				assert.NoError(t, err)
   211  				var leafsResponse message.LeafsResponse
   212  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   213  				assert.NoError(t, err)
   214  				assert.Len(t, leafsResponse.Keys, 500)
   215  				assert.Len(t, leafsResponse.Vals, 500)
   216  				assert.Len(t, leafsResponse.ProofVals, 0)
   217  			},
   218  		},
   219  		"nil end range treated like greatest possible value": {
   220  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   221  				return context.Background(), message.LeafsRequest{
   222  					Root:  smallTrieRoot,
   223  					Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   224  					End:   nil,
   225  					Limit: maxLeavesLimit,
   226  				}
   227  			},
   228  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   229  				assert.NoError(t, err)
   230  				var leafsResponse message.LeafsResponse
   231  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   232  				assert.NoError(t, err)
   233  				assert.Len(t, leafsResponse.Keys, 500)
   234  				assert.Len(t, leafsResponse.Vals, 500)
   235  			},
   236  		},
   237  		"end greater than start dropped": {
   238  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   239  				ctx, cancel := context.WithCancel(context.Background())
   240  				defer cancel()
   241  				return ctx, message.LeafsRequest{
   242  					Root:  largeTrieRoot,
   243  					Start: bytes.Repeat([]byte{0xbb}, common.HashLength),
   244  					End:   bytes.Repeat([]byte{0xaa}, common.HashLength),
   245  					Limit: maxLeavesLimit,
   246  				}
   247  			},
   248  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   249  				assert.Nil(t, response)
   250  				assert.Nil(t, err)
   251  				assert.EqualValues(t, 1, mockHandlerStats.InvalidLeafsRequestCount)
   252  			},
   253  		},
   254  		"invalid node type dropped": {
   255  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   256  				ctx, cancel := context.WithCancel(context.Background())
   257  				defer cancel()
   258  				return ctx, message.LeafsRequest{
   259  					Root:  largeTrieRoot,
   260  					Start: bytes.Repeat([]byte{0xbb}, common.HashLength),
   261  					End:   bytes.Repeat([]byte{0xaa}, common.HashLength),
   262  					Limit: maxLeavesLimit,
   263  				}
   264  			},
   265  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   266  				assert.Nil(t, response)
   267  				assert.Nil(t, err)
   268  			},
   269  		},
   270  		"max leaves overridden": {
   271  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   272  				return context.Background(), message.LeafsRequest{
   273  					Root:  largeTrieRoot,
   274  					Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   275  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   276  					Limit: maxLeavesLimit * 10,
   277  				}
   278  			},
   279  			assertResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) {
   280  				assert.NoError(t, err)
   281  				var leafsResponse message.LeafsResponse
   282  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   283  				assert.NoError(t, err)
   284  				assert.EqualValues(t, len(leafsResponse.Keys), maxLeavesLimit)
   285  				assert.EqualValues(t, len(leafsResponse.Vals), maxLeavesLimit)
   286  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   287  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   288  			},
   289  		},
   290  		"full range with nil start": {
   291  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   292  				return context.Background(), message.LeafsRequest{
   293  					Root:  largeTrieRoot,
   294  					Start: nil,
   295  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   296  					Limit: maxLeavesLimit,
   297  				}
   298  			},
   299  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   300  				assert.NoError(t, err)
   301  				var leafsResponse message.LeafsResponse
   302  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   303  				assert.NoError(t, err)
   304  				assert.EqualValues(t, len(leafsResponse.Keys), maxLeavesLimit)
   305  				assert.EqualValues(t, len(leafsResponse.Vals), maxLeavesLimit)
   306  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   307  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   308  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   309  			},
   310  		},
   311  		"full range with 0x00 start": {
   312  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   313  				return context.Background(), message.LeafsRequest{
   314  					Root:  largeTrieRoot,
   315  					Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   316  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   317  					Limit: maxLeavesLimit,
   318  				}
   319  			},
   320  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   321  				assert.NoError(t, err)
   322  				var leafsResponse message.LeafsResponse
   323  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   324  				assert.NoError(t, err)
   325  				assert.EqualValues(t, len(leafsResponse.Keys), maxLeavesLimit)
   326  				assert.EqualValues(t, len(leafsResponse.Vals), maxLeavesLimit)
   327  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   328  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   329  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   330  			},
   331  		},
   332  		"partial mid range": {
   333  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   334  				startKey := largeTrieKeys[1_000]
   335  				startKey[31] = startKey[31] + 1 // exclude start key from response
   336  				endKey := largeTrieKeys[1_040]  // include end key in response
   337  				return context.Background(), message.LeafsRequest{
   338  					Root:  largeTrieRoot,
   339  					Start: startKey,
   340  					End:   endKey,
   341  					Limit: maxLeavesLimit,
   342  				}
   343  			},
   344  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   345  				assert.NoError(t, err)
   346  				var leafsResponse message.LeafsResponse
   347  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   348  				assert.NoError(t, err)
   349  				assert.EqualValues(t, 40, len(leafsResponse.Keys))
   350  				assert.EqualValues(t, 40, len(leafsResponse.Vals))
   351  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   352  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   353  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   354  			},
   355  		},
   356  		"partial end range": {
   357  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   358  				return context.Background(), message.LeafsRequest{
   359  					Root:  largeTrieRoot,
   360  					Start: largeTrieKeys[9_400],
   361  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   362  					Limit: maxLeavesLimit,
   363  				}
   364  			},
   365  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   366  				assert.NoError(t, err)
   367  				var leafsResponse message.LeafsResponse
   368  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   369  				assert.NoError(t, err)
   370  				assert.EqualValues(t, 600, len(leafsResponse.Keys))
   371  				assert.EqualValues(t, 600, len(leafsResponse.Vals))
   372  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   373  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   374  				assertRangeProofIsValid(t, &request, &leafsResponse, false)
   375  			},
   376  		},
   377  		"final end range": {
   378  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   379  				return context.Background(), message.LeafsRequest{
   380  					Root:  largeTrieRoot,
   381  					Start: bytes.Repeat([]byte{0xff}, common.HashLength),
   382  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   383  					Limit: maxLeavesLimit,
   384  				}
   385  			},
   386  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   387  				assert.NoError(t, err)
   388  				var leafsResponse message.LeafsResponse
   389  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   390  				assert.NoError(t, err)
   391  				assert.EqualValues(t, len(leafsResponse.Keys), 0)
   392  				assert.EqualValues(t, len(leafsResponse.Vals), 0)
   393  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   394  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   395  				assertRangeProofIsValid(t, &request, &leafsResponse, false)
   396  			},
   397  		},
   398  		"small trie root": {
   399  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   400  				return context.Background(), message.LeafsRequest{
   401  					Root:  smallTrieRoot,
   402  					Start: nil,
   403  					End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   404  					Limit: maxLeavesLimit,
   405  				}
   406  			},
   407  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   408  				assert.NotEmpty(t, response)
   409  
   410  				var leafsResponse message.LeafsResponse
   411  				if _, err = message.Codec.Unmarshal(response, &leafsResponse); err != nil {
   412  					t.Fatalf("unexpected error when unmarshalling LeafsResponse: %v", err)
   413  				}
   414  
   415  				assert.EqualValues(t, 500, len(leafsResponse.Keys))
   416  				assert.EqualValues(t, 500, len(leafsResponse.Vals))
   417  				assert.Empty(t, leafsResponse.ProofVals)
   418  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   419  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   420  				assertRangeProofIsValid(t, &request, &leafsResponse, false)
   421  			},
   422  		},
   423  		"account data served from snapshot": {
   424  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   425  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   426  				if err != nil {
   427  					t.Fatal(err)
   428  				}
   429  				snapshotProvider.Snapshot = snap
   430  				return context.Background(), message.LeafsRequest{
   431  					Root:  accountTrieRoot,
   432  					Limit: maxLeavesLimit,
   433  				}
   434  			},
   435  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   436  				assert.NoError(t, err)
   437  				var leafsResponse message.LeafsResponse
   438  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   439  				assert.NoError(t, err)
   440  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys))
   441  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals))
   442  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   443  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   444  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   445  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadSuccessCount)
   446  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   447  			},
   448  		},
   449  		"partial account data served from snapshot": {
   450  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   451  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   452  				if err != nil {
   453  					t.Fatal(err)
   454  				}
   455  				snapshotProvider.Snapshot = snap
   456  				it := snap.DiskAccountIterator(common.Hash{})
   457  				defer it.Release()
   458  				i := 0
   459  				for it.Next() {
   460  					if i > int(maxLeavesLimit) {
   461  						// no need to modify beyond the request limit
   462  						break
   463  					}
   464  					// modify one entry of 1 in 4 segments
   465  					if i%(segmentLen*4) == 0 {
   466  						var acc snapshot.Account
   467  						if err := rlp.DecodeBytes(it.Account(), &acc); err != nil {
   468  							t.Fatalf("could not parse snapshot account: %v", err)
   469  						}
   470  						acc.Nonce++
   471  						bytes, err := rlp.EncodeToBytes(acc)
   472  						if err != nil {
   473  							t.Fatalf("coult not encode snapshot account to bytes: %v", err)
   474  						}
   475  						rawdb.WriteAccountSnapshot(memdb, it.Hash(), bytes)
   476  					}
   477  					i++
   478  				}
   479  
   480  				return context.Background(), message.LeafsRequest{
   481  					Root:  accountTrieRoot,
   482  					Limit: maxLeavesLimit,
   483  				}
   484  			},
   485  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   486  				assert.NoError(t, err)
   487  				var leafsResponse message.LeafsResponse
   488  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   489  				assert.NoError(t, err)
   490  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys))
   491  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals))
   492  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   493  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   494  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   495  				assert.EqualValues(t, 0, mockHandlerStats.SnapshotReadSuccessCount)
   496  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   497  
   498  				// expect 1/4th of segments to be invalid
   499  				numSegments := maxLeavesLimit / segmentLen
   500  				assert.EqualValues(t, numSegments/4, mockHandlerStats.SnapshotSegmentInvalidCount)
   501  				assert.EqualValues(t, 3*numSegments/4, mockHandlerStats.SnapshotSegmentValidCount)
   502  			},
   503  		},
   504  		"storage data served from snapshot": {
   505  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   506  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   507  				if err != nil {
   508  					t.Fatal(err)
   509  				}
   510  				snapshotProvider.Snapshot = snap
   511  				return context.Background(), message.LeafsRequest{
   512  					Root:    largeTrieRoot,
   513  					Account: largeStorageAccount,
   514  					Limit:   maxLeavesLimit,
   515  				}
   516  			},
   517  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   518  				assert.NoError(t, err)
   519  				var leafsResponse message.LeafsResponse
   520  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   521  				assert.NoError(t, err)
   522  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys))
   523  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals))
   524  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   525  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   526  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   527  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadSuccessCount)
   528  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   529  			},
   530  		},
   531  		"partial storage data served from snapshot": {
   532  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   533  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   534  				if err != nil {
   535  					t.Fatal(err)
   536  				}
   537  				snapshotProvider.Snapshot = snap
   538  				it := snap.DiskStorageIterator(largeStorageAccount, common.Hash{})
   539  				defer it.Release()
   540  				i := 0
   541  				for it.Next() {
   542  					if i > int(maxLeavesLimit) {
   543  						// no need to modify beyond the request limit
   544  						break
   545  					}
   546  					// modify one entry of 1 in 4 segments
   547  					if i%(segmentLen*4) == 0 {
   548  						randomBytes := make([]byte, 5)
   549  						_, err := rand.Read(randomBytes)
   550  						assert.NoError(t, err)
   551  						rawdb.WriteStorageSnapshot(memdb, largeStorageAccount, it.Hash(), randomBytes)
   552  					}
   553  					i++
   554  				}
   555  
   556  				return context.Background(), message.LeafsRequest{
   557  					Root:    largeTrieRoot,
   558  					Account: largeStorageAccount,
   559  					Limit:   maxLeavesLimit,
   560  				}
   561  			},
   562  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   563  				assert.NoError(t, err)
   564  				var leafsResponse message.LeafsResponse
   565  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   566  				assert.NoError(t, err)
   567  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys))
   568  				assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals))
   569  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   570  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   571  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   572  				assert.EqualValues(t, 0, mockHandlerStats.SnapshotReadSuccessCount)
   573  				assertRangeProofIsValid(t, &request, &leafsResponse, true)
   574  
   575  				// expect 1/4th of segments to be invalid
   576  				numSegments := maxLeavesLimit / segmentLen
   577  				assert.EqualValues(t, numSegments/4, mockHandlerStats.SnapshotSegmentInvalidCount)
   578  				assert.EqualValues(t, 3*numSegments/4, mockHandlerStats.SnapshotSegmentValidCount)
   579  			},
   580  		},
   581  		"last snapshot key removed": {
   582  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   583  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   584  				if err != nil {
   585  					t.Fatal(err)
   586  				}
   587  				snapshotProvider.Snapshot = snap
   588  				it := snap.DiskStorageIterator(smallStorageAccount, common.Hash{})
   589  				defer it.Release()
   590  				var lastKey common.Hash
   591  				for it.Next() {
   592  					lastKey = it.Hash()
   593  				}
   594  				rawdb.DeleteStorageSnapshot(memdb, smallStorageAccount, lastKey)
   595  
   596  				return context.Background(), message.LeafsRequest{
   597  					Root:    smallTrieRoot,
   598  					Account: smallStorageAccount,
   599  					Limit:   maxLeavesLimit,
   600  				}
   601  			},
   602  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   603  				assert.NoError(t, err)
   604  				var leafsResponse message.LeafsResponse
   605  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   606  				assert.NoError(t, err)
   607  				assert.EqualValues(t, 500, len(leafsResponse.Keys))
   608  				assert.EqualValues(t, 500, len(leafsResponse.Vals))
   609  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   610  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   611  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   612  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadSuccessCount)
   613  				assertRangeProofIsValid(t, &request, &leafsResponse, false)
   614  			},
   615  		},
   616  		"request last key when removed from snapshot": {
   617  			prepareTestFn: func() (context.Context, message.LeafsRequest) {
   618  				snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false)
   619  				if err != nil {
   620  					t.Fatal(err)
   621  				}
   622  				snapshotProvider.Snapshot = snap
   623  				it := snap.DiskStorageIterator(smallStorageAccount, common.Hash{})
   624  				defer it.Release()
   625  				var lastKey common.Hash
   626  				for it.Next() {
   627  					lastKey = it.Hash()
   628  				}
   629  				rawdb.DeleteStorageSnapshot(memdb, smallStorageAccount, lastKey)
   630  
   631  				return context.Background(), message.LeafsRequest{
   632  					Root:    smallTrieRoot,
   633  					Account: smallStorageAccount,
   634  					Start:   lastKey[:],
   635  					Limit:   maxLeavesLimit,
   636  				}
   637  			},
   638  			assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) {
   639  				assert.NoError(t, err)
   640  				var leafsResponse message.LeafsResponse
   641  				_, err = message.Codec.Unmarshal(response, &leafsResponse)
   642  				assert.NoError(t, err)
   643  				assert.EqualValues(t, 1, len(leafsResponse.Keys))
   644  				assert.EqualValues(t, 1, len(leafsResponse.Vals))
   645  				assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount)
   646  				assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum)
   647  				assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount)
   648  				assert.EqualValues(t, 0, mockHandlerStats.SnapshotReadSuccessCount)
   649  				assertRangeProofIsValid(t, &request, &leafsResponse, false)
   650  			},
   651  		},
   652  	}
   653  	for name, test := range tests {
   654  		t.Run(name, func(t *testing.T) {
   655  			ctx, request := test.prepareTestFn()
   656  			t.Cleanup(func() {
   657  				<-snapshot.WipeSnapshot(memdb, true)
   658  				mockHandlerStats.Reset()
   659  				snapshotProvider.Snapshot = nil // reset the snapshot to nil
   660  			})
   661  
   662  			response, err := leafsHandler.OnLeafsRequest(ctx, ids.GenerateTestNodeID(), 1, request)
   663  			test.assertResponseFn(t, request, response, err)
   664  		})
   665  	}
   666  }
   667  
   668  func assertRangeProofIsValid(t *testing.T, request *message.LeafsRequest, response *message.LeafsResponse, expectMore bool) {
   669  	t.Helper()
   670  
   671  	var start, end []byte
   672  	if len(request.Start) == 0 {
   673  		start = bytes.Repeat([]byte{0x00}, common.HashLength)
   674  	} else {
   675  		start = request.Start
   676  	}
   677  	if len(response.Keys) > 0 {
   678  		end = response.Keys[len(response.Vals)-1]
   679  	}
   680  
   681  	var proof ethdb.Database
   682  	if len(response.ProofVals) > 0 {
   683  		proof = memorydb.New()
   684  		defer proof.Close()
   685  		for _, proofVal := range response.ProofVals {
   686  			proofKey := crypto.Keccak256(proofVal)
   687  			if err := proof.Put(proofKey, proofVal); err != nil {
   688  				t.Fatal(err)
   689  			}
   690  		}
   691  	}
   692  
   693  	more, err := trie.VerifyRangeProof(request.Root, start, end, response.Keys, response.Vals, proof)
   694  	assert.NoError(t, err)
   695  	assert.Equal(t, expectMore, more)
   696  }