github.com/MetalBlockchain/subnet-evm@v0.4.9/sync/client/client_test.go (about)

     1  // (c) 2021-2022, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package statesyncclient
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"fmt"
    10  	"math/rand"
    11  	"strings"
    12  	"testing"
    13  
    14  	"github.com/stretchr/testify/assert"
    15  
    16  	"github.com/MetalBlockchain/metalgo/ids"
    17  
    18  	"github.com/MetalBlockchain/subnet-evm/consensus/dummy"
    19  	"github.com/MetalBlockchain/subnet-evm/core"
    20  	"github.com/MetalBlockchain/subnet-evm/core/types"
    21  	"github.com/MetalBlockchain/subnet-evm/ethdb/memorydb"
    22  	"github.com/MetalBlockchain/subnet-evm/params"
    23  	"github.com/MetalBlockchain/subnet-evm/plugin/evm/message"
    24  	clientstats "github.com/MetalBlockchain/subnet-evm/sync/client/stats"
    25  	"github.com/MetalBlockchain/subnet-evm/sync/handlers"
    26  	handlerstats "github.com/MetalBlockchain/subnet-evm/sync/handlers/stats"
    27  	"github.com/MetalBlockchain/subnet-evm/trie"
    28  	"github.com/ethereum/go-ethereum/common"
    29  	"github.com/ethereum/go-ethereum/crypto"
    30  )
    31  
    32  func TestGetCode(t *testing.T) {
    33  	mockNetClient := &mockNetwork{}
    34  
    35  	tests := map[string]struct {
    36  		setupRequest func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte)
    37  		expectedErr  error
    38  	}{
    39  		"normal": {
    40  			setupRequest: func() ([]common.Hash, message.CodeResponse, [][]byte) {
    41  				code := []byte("this is the code")
    42  				codeHash := crypto.Keccak256Hash(code)
    43  				codeSlices := [][]byte{code}
    44  				return []common.Hash{codeHash}, message.CodeResponse{
    45  					Data: codeSlices,
    46  				}, codeSlices
    47  			},
    48  			expectedErr: nil,
    49  		},
    50  		"unexpected code bytes": {
    51  			setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) {
    52  				return []common.Hash{{1}}, message.CodeResponse{
    53  					Data: [][]byte{{1}},
    54  				}, nil
    55  			},
    56  			expectedErr: errHashMismatch,
    57  		},
    58  		"too many code elements returned": {
    59  			setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) {
    60  				return []common.Hash{{1}}, message.CodeResponse{
    61  					Data: [][]byte{{1}, {2}},
    62  				}, nil
    63  			},
    64  			expectedErr: errInvalidCodeResponseLen,
    65  		},
    66  		"too few code elements returned": {
    67  			setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) {
    68  				return []common.Hash{{1}}, message.CodeResponse{
    69  					Data: [][]byte{},
    70  				}, nil
    71  			},
    72  			expectedErr: errInvalidCodeResponseLen,
    73  		},
    74  		"code size is too large": {
    75  			setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) {
    76  				oversizedCode := make([]byte, params.MaxCodeSize+1)
    77  				codeHash := crypto.Keccak256Hash(oversizedCode)
    78  				return []common.Hash{codeHash}, message.CodeResponse{
    79  					Data: [][]byte{oversizedCode},
    80  				}, nil
    81  			},
    82  			expectedErr: errMaxCodeSizeExceeded,
    83  		},
    84  	}
    85  
    86  	stateSyncClient := NewClient(&ClientConfig{
    87  		NetworkClient:    mockNetClient,
    88  		Codec:            message.Codec,
    89  		Stats:            clientstats.NewNoOpStats(),
    90  		StateSyncNodeIDs: nil,
    91  		BlockParser:      mockBlockParser,
    92  	})
    93  
    94  	for name, test := range tests {
    95  		t.Run(name, func(t *testing.T) {
    96  			ctx, cancel := context.WithCancel(context.Background())
    97  			defer cancel()
    98  			codeHashes, res, expectedCode := test.setupRequest()
    99  
   100  			responseBytes, err := message.Codec.Marshal(message.Version, res)
   101  			if err != nil {
   102  				t.Fatal(err)
   103  			}
   104  			// Dirty hack required because the client will re-request if it encounters
   105  			// an error.
   106  			attempted := false
   107  			if test.expectedErr == nil {
   108  				mockNetClient.mockResponse(1, nil, responseBytes)
   109  			} else {
   110  				mockNetClient.mockResponse(2, func() {
   111  					// Cancel before the second attempt is processed.
   112  					if attempted {
   113  						cancel()
   114  					}
   115  					attempted = true
   116  				}, responseBytes)
   117  			}
   118  
   119  			codeBytes, err := stateSyncClient.GetCode(ctx, codeHashes)
   120  			// If we expect an error, assert that one occurred and return
   121  			if test.expectedErr != nil {
   122  				assert.ErrorIs(t, err, test.expectedErr)
   123  				assert.EqualValues(t, 2, mockNetClient.numCalls)
   124  				return
   125  			}
   126  			// Otherwise, assert there was no error and that the result is as expected
   127  			assert.NoError(t, err)
   128  			assert.Equal(t, len(codeBytes), len(expectedCode))
   129  			for i, code := range codeBytes {
   130  				assert.Equal(t, expectedCode[i], code)
   131  			}
   132  			assert.Equal(t, uint(1), mockNetClient.numCalls)
   133  		})
   134  	}
   135  }
   136  
   137  func TestGetBlocks(t *testing.T) {
   138  	// set random seed for deterministic tests
   139  	rand.Seed(1)
   140  
   141  	var gspec = &core.Genesis{
   142  		Config: params.TestChainConfig,
   143  	}
   144  	memdb := memorydb.New()
   145  	genesis := gspec.MustCommit(memdb)
   146  	engine := dummy.NewETHFaker()
   147  	numBlocks := 110
   148  	blocks, _, err := core.GenerateChain(params.TestChainConfig, genesis, engine, memdb, numBlocks, 0, func(i int, b *core.BlockGen) {})
   149  	if err != nil {
   150  		t.Fatal("unexpected error when generating test blockchain", err)
   151  	}
   152  	assert.Equal(t, numBlocks, len(blocks))
   153  
   154  	// Construct client
   155  	mockNetClient := &mockNetwork{}
   156  	stateSyncClient := NewClient(&ClientConfig{
   157  		NetworkClient:    mockNetClient,
   158  		Codec:            message.Codec,
   159  		Stats:            clientstats.NewNoOpStats(),
   160  		StateSyncNodeIDs: nil,
   161  		BlockParser:      mockBlockParser,
   162  	})
   163  
   164  	blocksRequestHandler := handlers.NewBlockRequestHandler(buildGetter(blocks), message.Codec, handlerstats.NewNoopHandlerStats())
   165  
   166  	// encodeBlockSlice takes a slice of blocks that are ordered in increasing height order
   167  	// and returns a slice of byte slices with those blocks encoded in reverse order
   168  	encodeBlockSlice := func(blocks []*types.Block) [][]byte {
   169  		blockBytes := make([][]byte, 0, len(blocks))
   170  		for i := len(blocks) - 1; i >= 0; i-- {
   171  			buf := new(bytes.Buffer)
   172  			if err := blocks[i].EncodeRLP(buf); err != nil {
   173  				t.Fatalf("failed to generate expected response %s", err)
   174  			}
   175  			blockBytes = append(blockBytes, buf.Bytes())
   176  		}
   177  
   178  		return blockBytes
   179  	}
   180  	tests := map[string]struct {
   181  		request        message.BlockRequest
   182  		getResponse    func(t *testing.T, request message.BlockRequest) []byte
   183  		assertResponse func(t *testing.T, response []*types.Block)
   184  		expectedErr    string
   185  	}{
   186  		"normal resonse": {
   187  			request: message.BlockRequest{
   188  				Hash:    blocks[100].Hash(),
   189  				Height:  100,
   190  				Parents: 16,
   191  			},
   192  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   193  				response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   194  				if err != nil {
   195  					t.Fatal(err)
   196  				}
   197  
   198  				if len(response) == 0 {
   199  					t.Fatal("Failed to generate valid response")
   200  				}
   201  
   202  				return response
   203  			},
   204  			assertResponse: func(t *testing.T, response []*types.Block) {
   205  				assert.Equal(t, 16, len(response))
   206  			},
   207  		},
   208  		"fewer than requested blocks": {
   209  			request: message.BlockRequest{
   210  				Hash:    blocks[100].Hash(),
   211  				Height:  100,
   212  				Parents: 16,
   213  			},
   214  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   215  				request.Parents -= 5
   216  				response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   217  				if err != nil {
   218  					t.Fatal(err)
   219  				}
   220  
   221  				if len(response) == 0 {
   222  					t.Fatal("Failed to generate valid response")
   223  				}
   224  
   225  				return response
   226  			},
   227  			// If the server returns fewer than requested blocks, we should consider it valid
   228  			assertResponse: func(t *testing.T, response []*types.Block) {
   229  				assert.Equal(t, 11, len(response))
   230  			},
   231  		},
   232  		"gibberish response": {
   233  			request: message.BlockRequest{
   234  				Hash:    blocks[100].Hash(),
   235  				Height:  100,
   236  				Parents: 16,
   237  			},
   238  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   239  				return []byte("gibberish")
   240  			},
   241  			expectedErr: errUnmarshalResponse.Error(),
   242  		},
   243  		"invalid value replacing block": {
   244  			request: message.BlockRequest{
   245  				Hash:    blocks[100].Hash(),
   246  				Height:  100,
   247  				Parents: 16,
   248  			},
   249  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   250  				response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   251  				if err != nil {
   252  					t.Fatalf("failed to get block response: %s", err)
   253  				}
   254  				var blockResponse message.BlockResponse
   255  				if _, err = message.Codec.Unmarshal(response, &blockResponse); err != nil {
   256  					t.Fatalf("failed to marshal block response: %s", err)
   257  				}
   258  				// Replace middle value with garbage data
   259  				blockResponse.Blocks[10] = []byte("invalid value replacing block bytes")
   260  				responseBytes, err := message.Codec.Marshal(message.Version, blockResponse)
   261  				if err != nil {
   262  					t.Fatalf("failed to marshal block response: %s", err)
   263  				}
   264  
   265  				return responseBytes
   266  			},
   267  			expectedErr: "failed to unmarshal response: rlp: expected input list for types.extblock",
   268  		},
   269  		"incorrect starting point": {
   270  			request: message.BlockRequest{
   271  				Hash:    blocks[100].Hash(),
   272  				Height:  100,
   273  				Parents: 16,
   274  			},
   275  			getResponse: func(t *testing.T, _ message.BlockRequest) []byte {
   276  				response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, message.BlockRequest{
   277  					Hash:    blocks[99].Hash(),
   278  					Height:  99,
   279  					Parents: 16,
   280  				})
   281  				if err != nil {
   282  					t.Fatal(err)
   283  				}
   284  
   285  				if len(response) == 0 {
   286  					t.Fatal("Failed to generate valid response")
   287  				}
   288  
   289  				return response
   290  			},
   291  			expectedErr: errHashMismatch.Error(),
   292  		},
   293  		"missing link in between blocks": {
   294  			request: message.BlockRequest{
   295  				Hash:    blocks[100].Hash(),
   296  				Height:  100,
   297  				Parents: 16,
   298  			},
   299  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   300  				// Encode blocks with a missing link
   301  				blks := make([]*types.Block, 0)
   302  				blks = append(blks, blocks[84:89]...)
   303  				blks = append(blks, blocks[90:101]...)
   304  				blockBytes := encodeBlockSlice(blks)
   305  
   306  				blockResponse := message.BlockResponse{
   307  					Blocks: blockBytes,
   308  				}
   309  				responseBytes, err := message.Codec.Marshal(message.Version, blockResponse)
   310  				if err != nil {
   311  					t.Fatalf("failed to marshal block response: %s", err)
   312  				}
   313  
   314  				return responseBytes
   315  			},
   316  			expectedErr: errHashMismatch.Error(),
   317  		},
   318  		"no blocks": {
   319  			request: message.BlockRequest{
   320  				Hash:    blocks[100].Hash(),
   321  				Height:  100,
   322  				Parents: 16,
   323  			},
   324  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   325  				blockResponse := message.BlockResponse{
   326  					Blocks: nil,
   327  				}
   328  				responseBytes, err := message.Codec.Marshal(message.Version, blockResponse)
   329  				if err != nil {
   330  					t.Fatalf("failed to marshal block response: %s", err)
   331  				}
   332  
   333  				return responseBytes
   334  			},
   335  			expectedErr: errEmptyResponse.Error(),
   336  		},
   337  		"more than requested blocks": {
   338  			request: message.BlockRequest{
   339  				Hash:    blocks[100].Hash(),
   340  				Height:  100,
   341  				Parents: 16,
   342  			},
   343  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   344  				blockBytes := encodeBlockSlice(blocks[80:100])
   345  
   346  				blockResponse := message.BlockResponse{
   347  					Blocks: blockBytes,
   348  				}
   349  				responseBytes, err := message.Codec.Marshal(message.Version, blockResponse)
   350  				if err != nil {
   351  					t.Fatalf("failed to marshal block response: %s", err)
   352  				}
   353  
   354  				return responseBytes
   355  			},
   356  			expectedErr: errTooManyBlocks.Error(),
   357  		},
   358  	}
   359  	for name, test := range tests {
   360  		t.Run(name, func(t *testing.T) {
   361  			ctx, cancel := context.WithCancel(context.Background())
   362  			defer cancel()
   363  
   364  			responseBytes := test.getResponse(t, test.request)
   365  			if len(test.expectedErr) == 0 {
   366  				mockNetClient.mockResponse(1, nil, responseBytes)
   367  			} else {
   368  				attempted := false
   369  				mockNetClient.mockResponse(2, func() {
   370  					if attempted {
   371  						cancel()
   372  					}
   373  					attempted = true
   374  				}, responseBytes)
   375  			}
   376  
   377  			blockResponse, err := stateSyncClient.GetBlocks(ctx, test.request.Hash, test.request.Height, test.request.Parents)
   378  			if len(test.expectedErr) != 0 {
   379  				if err == nil {
   380  					t.Fatalf("Expected error: %s, but found no error", test.expectedErr)
   381  				}
   382  				assert.True(t, strings.Contains(err.Error(), test.expectedErr), "expected error to contain [%s], but found [%s]", test.expectedErr, err)
   383  				return
   384  			}
   385  			if err != nil {
   386  				t.Fatal(err)
   387  			}
   388  
   389  			test.assertResponse(t, blockResponse)
   390  		})
   391  	}
   392  }
   393  
   394  func buildGetter(blocks []*types.Block) handlers.BlockProvider {
   395  	return &handlers.TestBlockProvider{
   396  		GetBlockFn: func(blockHash common.Hash, blockHeight uint64) *types.Block {
   397  			requestedBlock := blocks[blockHeight]
   398  			if requestedBlock.Hash() != blockHash {
   399  				fmt.Printf("ERROR height=%d, hash=%s, parentHash=%s, reqHash=%s\n", blockHeight, blockHash, requestedBlock.ParentHash(), requestedBlock.Hash())
   400  				return nil
   401  			}
   402  			return requestedBlock
   403  		},
   404  	}
   405  }
   406  
   407  func TestGetLeafs(t *testing.T) {
   408  	rand.Seed(1)
   409  
   410  	const leafsLimit = 1024
   411  
   412  	trieDB := trie.NewDatabase(memorydb.New())
   413  	largeTrieRoot, largeTrieKeys, _ := trie.GenerateTrie(t, trieDB, 100_000, common.HashLength)
   414  	smallTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, leafsLimit, common.HashLength)
   415  
   416  	handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats())
   417  	client := NewClient(&ClientConfig{
   418  		NetworkClient:    &mockNetwork{},
   419  		Codec:            message.Codec,
   420  		Stats:            clientstats.NewNoOpStats(),
   421  		StateSyncNodeIDs: nil,
   422  		BlockParser:      mockBlockParser,
   423  	})
   424  
   425  	tests := map[string]struct {
   426  		request        message.LeafsRequest
   427  		getResponse    func(t *testing.T, request message.LeafsRequest) []byte
   428  		assertResponse func(t *testing.T, response message.LeafsResponse)
   429  		expectedErr    error
   430  	}{
   431  		"full response for small (single request) trie": {
   432  			request: message.LeafsRequest{
   433  				Root:  smallTrieRoot,
   434  				Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   435  				End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   436  				Limit: leafsLimit,
   437  			},
   438  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   439  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   440  				if err != nil {
   441  					t.Fatal("unexpected error in calling leafs request handler", err)
   442  				}
   443  				if len(response) == 0 {
   444  					t.Fatal("Failed to create valid response")
   445  				}
   446  
   447  				return response
   448  			},
   449  			assertResponse: func(t *testing.T, response message.LeafsResponse) {
   450  				assert.False(t, response.More)
   451  				assert.Equal(t, leafsLimit, len(response.Keys))
   452  				assert.Equal(t, leafsLimit, len(response.Vals))
   453  			},
   454  		},
   455  		"too many leaves in response": {
   456  			request: message.LeafsRequest{
   457  				Root:  smallTrieRoot,
   458  				Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   459  				End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   460  				Limit: leafsLimit / 2,
   461  			},
   462  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   463  				modifiedRequest := request
   464  				modifiedRequest.Limit = leafsLimit
   465  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, modifiedRequest)
   466  				if err != nil {
   467  					t.Fatal("unexpected error in calling leafs request handler", err)
   468  				}
   469  				if len(response) == 0 {
   470  					t.Fatal("Failed to create valid response")
   471  				}
   472  
   473  				return response
   474  			},
   475  			expectedErr: errTooManyLeaves,
   476  		},
   477  		"partial response to request for entire trie (full leaf limit)": {
   478  			request: message.LeafsRequest{
   479  				Root:  largeTrieRoot,
   480  				Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   481  				End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   482  				Limit: leafsLimit,
   483  			},
   484  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   485  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   486  				if err != nil {
   487  					t.Fatal("unexpected error in calling leafs request handler", err)
   488  				}
   489  				if len(response) == 0 {
   490  					t.Fatal("Failed to create valid response")
   491  				}
   492  
   493  				return response
   494  			},
   495  			assertResponse: func(t *testing.T, response message.LeafsResponse) {
   496  				assert.True(t, response.More)
   497  				assert.Equal(t, leafsLimit, len(response.Keys))
   498  				assert.Equal(t, leafsLimit, len(response.Vals))
   499  			},
   500  		},
   501  		"partial response to request for middle range of trie (full leaf limit)": {
   502  			request: message.LeafsRequest{
   503  				Root:  largeTrieRoot,
   504  				Start: largeTrieKeys[1000],
   505  				End:   largeTrieKeys[99000],
   506  				Limit: leafsLimit,
   507  			},
   508  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   509  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   510  				if err != nil {
   511  					t.Fatal("unexpected error in calling leafs request handler", err)
   512  				}
   513  
   514  				if len(response) == 0 {
   515  					t.Fatal("Failed to create valid response")
   516  				}
   517  				return response
   518  			},
   519  			assertResponse: func(t *testing.T, response message.LeafsResponse) {
   520  				assert.True(t, response.More)
   521  				assert.Equal(t, leafsLimit, len(response.Keys))
   522  				assert.Equal(t, leafsLimit, len(response.Vals))
   523  			},
   524  		},
   525  		"full response from near end of trie to end of trie (less than leaf limit)": {
   526  			request: message.LeafsRequest{
   527  				Root:  largeTrieRoot,
   528  				Start: largeTrieKeys[len(largeTrieKeys)-30], // Set start 30 keys from the end of the large trie
   529  				End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   530  				Limit: leafsLimit,
   531  			},
   532  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   533  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   534  				if err != nil {
   535  					t.Fatal("unexpected error in calling leafs request handler", err)
   536  				}
   537  				if len(response) == 0 {
   538  					t.Fatal("Failed to create valid response")
   539  				}
   540  				return response
   541  			},
   542  			assertResponse: func(t *testing.T, response message.LeafsResponse) {
   543  				assert.False(t, response.More)
   544  				assert.Equal(t, 30, len(response.Keys))
   545  				assert.Equal(t, 30, len(response.Vals))
   546  			},
   547  		},
   548  		"full response for intermediate range of trie (less than leaf limit)": {
   549  			request: message.LeafsRequest{
   550  				Root:  largeTrieRoot,
   551  				Start: largeTrieKeys[1000], // Set the range for 1000 leafs in an intermediate range of the trie
   552  				End:   largeTrieKeys[1099], // (inclusive range)
   553  				Limit: leafsLimit,
   554  			},
   555  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   556  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   557  				if err != nil {
   558  					t.Fatal("unexpected error in calling leafs request handler", err)
   559  				}
   560  				if len(response) == 0 {
   561  					t.Fatal("Failed to create valid response")
   562  				}
   563  
   564  				return response
   565  			},
   566  			assertResponse: func(t *testing.T, response message.LeafsResponse) {
   567  				assert.True(t, response.More)
   568  				assert.Equal(t, 100, len(response.Keys))
   569  				assert.Equal(t, 100, len(response.Vals))
   570  			},
   571  		},
   572  		"removed first key in response": {
   573  			request: message.LeafsRequest{
   574  				Root:  largeTrieRoot,
   575  				Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   576  				End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   577  				Limit: leafsLimit,
   578  			},
   579  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   580  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   581  				if err != nil {
   582  					t.Fatal("unexpected error in calling leafs request handler", err)
   583  				}
   584  				if len(response) == 0 {
   585  					t.Fatal("Failed to create valid response")
   586  				}
   587  				var leafResponse message.LeafsResponse
   588  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   589  					t.Fatal(err)
   590  				}
   591  				leafResponse.Keys = leafResponse.Keys[1:]
   592  				leafResponse.Vals = leafResponse.Vals[1:]
   593  
   594  				modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse)
   595  				if err != nil {
   596  					t.Fatal(err)
   597  				}
   598  				return modifiedResponse
   599  			},
   600  			expectedErr: errInvalidRangeProof,
   601  		},
   602  		"removed first key in response and replaced proof": {
   603  			request: message.LeafsRequest{
   604  				Root:  largeTrieRoot,
   605  				Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   606  				End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   607  				Limit: leafsLimit,
   608  			},
   609  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   610  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   611  				if err != nil {
   612  					t.Fatal("unexpected error in calling leafs request handler", err)
   613  				}
   614  				if len(response) == 0 {
   615  					t.Fatal("Failed to create valid response")
   616  				}
   617  				var leafResponse message.LeafsResponse
   618  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   619  					t.Fatal(err)
   620  				}
   621  				modifiedRequest := request
   622  				modifiedRequest.Start = leafResponse.Keys[1]
   623  				modifiedResponse, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 2, modifiedRequest)
   624  				if err != nil {
   625  					t.Fatal("unexpected error in calling leafs request handler", err)
   626  				}
   627  				return modifiedResponse
   628  			},
   629  			expectedErr: errInvalidRangeProof,
   630  		},
   631  		"removed last key in response": {
   632  			request: message.LeafsRequest{
   633  				Root:  largeTrieRoot,
   634  				Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   635  				End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   636  				Limit: leafsLimit,
   637  			},
   638  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   639  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   640  				if err != nil {
   641  					t.Fatal("unexpected error in calling leafs request handler", err)
   642  				}
   643  				if len(response) == 0 {
   644  					t.Fatal("Failed to create valid response")
   645  				}
   646  				var leafResponse message.LeafsResponse
   647  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   648  					t.Fatal(err)
   649  				}
   650  				leafResponse.Keys = leafResponse.Keys[:len(leafResponse.Keys)-2]
   651  				leafResponse.Vals = leafResponse.Vals[:len(leafResponse.Vals)-2]
   652  
   653  				modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse)
   654  				if err != nil {
   655  					t.Fatal(err)
   656  				}
   657  				return modifiedResponse
   658  			},
   659  			expectedErr: errInvalidRangeProof,
   660  		},
   661  		"removed key from middle of response": {
   662  			request: message.LeafsRequest{
   663  				Root:  largeTrieRoot,
   664  				Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   665  				End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   666  				Limit: leafsLimit,
   667  			},
   668  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   669  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   670  				if err != nil {
   671  					t.Fatal("unexpected error in calling leafs request handler", err)
   672  				}
   673  				if len(response) == 0 {
   674  					t.Fatal("Failed to create valid response")
   675  				}
   676  				var leafResponse message.LeafsResponse
   677  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   678  					t.Fatal(err)
   679  				}
   680  				// Remove middle key-value pair response
   681  				leafResponse.Keys = append(leafResponse.Keys[:100], leafResponse.Keys[101:]...)
   682  				leafResponse.Vals = append(leafResponse.Vals[:100], leafResponse.Vals[101:]...)
   683  
   684  				modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse)
   685  				if err != nil {
   686  					t.Fatal(err)
   687  				}
   688  				return modifiedResponse
   689  			},
   690  			expectedErr: errInvalidRangeProof,
   691  		},
   692  		"corrupted value in middle of response": {
   693  			request: message.LeafsRequest{
   694  				Root:  largeTrieRoot,
   695  				Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   696  				End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   697  				Limit: leafsLimit,
   698  			},
   699  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   700  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   701  				if err != nil {
   702  					t.Fatal("unexpected error in calling leafs request handler", err)
   703  				}
   704  				if len(response) == 0 {
   705  					t.Fatal("Failed to create valid response")
   706  				}
   707  				var leafResponse message.LeafsResponse
   708  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   709  					t.Fatal(err)
   710  				}
   711  				// Remove middle key-value pair response
   712  				leafResponse.Vals[100] = []byte("garbage value data")
   713  
   714  				modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse)
   715  				if err != nil {
   716  					t.Fatal(err)
   717  				}
   718  				return modifiedResponse
   719  			},
   720  			expectedErr: errInvalidRangeProof,
   721  		},
   722  		"all proof keys removed from response": {
   723  			request: message.LeafsRequest{
   724  				Root:  largeTrieRoot,
   725  				Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   726  				End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   727  				Limit: leafsLimit,
   728  			},
   729  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   730  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   731  				if err != nil {
   732  					t.Fatal("unexpected error in calling leafs request handler", err)
   733  				}
   734  				if len(response) == 0 {
   735  					t.Fatal("Failed to create valid response")
   736  				}
   737  
   738  				var leafResponse message.LeafsResponse
   739  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   740  					t.Fatal(err)
   741  				}
   742  				// Remove the proof
   743  				leafResponse.ProofVals = nil
   744  
   745  				modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse)
   746  				if err != nil {
   747  					t.Fatal(err)
   748  				}
   749  				return modifiedResponse
   750  			},
   751  			expectedErr: errInvalidRangeProof,
   752  		},
   753  	}
   754  	for name, test := range tests {
   755  		t.Run(name, func(t *testing.T) {
   756  			responseBytes := test.getResponse(t, test.request)
   757  
   758  			response, _, err := parseLeafsResponse(client.codec, test.request, responseBytes)
   759  			if test.expectedErr != nil {
   760  				if err == nil {
   761  					t.Fatalf("Expected error: %s, but found no error", test.expectedErr)
   762  				}
   763  				assert.True(t, strings.Contains(err.Error(), test.expectedErr.Error()))
   764  				return
   765  			}
   766  
   767  			if err != nil {
   768  				t.Fatal(err)
   769  			}
   770  
   771  			leafsResponse, ok := response.(message.LeafsResponse)
   772  			if !ok {
   773  				t.Fatalf("parseLeafsResponse returned incorrect type %T", response)
   774  			}
   775  			test.assertResponse(t, leafsResponse)
   776  		})
   777  	}
   778  }
   779  
   780  func TestGetLeafsRetries(t *testing.T) {
   781  	rand.Seed(1)
   782  
   783  	trieDB := trie.NewDatabase(memorydb.New())
   784  	root, _, _ := trie.GenerateTrie(t, trieDB, 100_000, common.HashLength)
   785  
   786  	handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats())
   787  	mockNetClient := &mockNetwork{}
   788  
   789  	const maxAttempts = 8
   790  	client := NewClient(&ClientConfig{
   791  		NetworkClient:    mockNetClient,
   792  		Codec:            message.Codec,
   793  		Stats:            clientstats.NewNoOpStats(),
   794  		StateSyncNodeIDs: nil,
   795  		BlockParser:      mockBlockParser,
   796  	})
   797  
   798  	request := message.LeafsRequest{
   799  		Root:  root,
   800  		Start: bytes.Repeat([]byte{0x00}, common.HashLength),
   801  		End:   bytes.Repeat([]byte{0xff}, common.HashLength),
   802  		Limit: defaultLeafRequestLimit,
   803  	}
   804  
   805  	ctx, cancel := context.WithCancel(context.Background())
   806  	defer cancel()
   807  	goodResponse, responseErr := handler.OnLeafsRequest(ctx, ids.GenerateTestNodeID(), 1, request)
   808  	assert.NoError(t, responseErr)
   809  	mockNetClient.mockResponse(1, nil, goodResponse)
   810  
   811  	res, err := client.GetLeafs(ctx, request)
   812  	if err != nil {
   813  		t.Fatal(err)
   814  	}
   815  	assert.Equal(t, 1024, len(res.Keys))
   816  	assert.Equal(t, 1024, len(res.Vals))
   817  
   818  	// Succeeds within the allotted number of attempts
   819  	invalidResponse := []byte("invalid response")
   820  	mockNetClient.mockResponses(nil, invalidResponse, invalidResponse, goodResponse)
   821  
   822  	res, err = client.GetLeafs(ctx, request)
   823  	if err != nil {
   824  		t.Fatal(err)
   825  	}
   826  	assert.Equal(t, 1024, len(res.Keys))
   827  	assert.Equal(t, 1024, len(res.Vals))
   828  
   829  	// Test that GetLeafs stops after the context is cancelled
   830  	numAttempts := 0
   831  	mockNetClient.mockResponse(maxAttempts, func() {
   832  		numAttempts++
   833  		if numAttempts >= maxAttempts {
   834  			cancel()
   835  		}
   836  	}, invalidResponse)
   837  	_, err = client.GetLeafs(ctx, request)
   838  	assert.Error(t, err)
   839  	assert.True(t, strings.Contains(err.Error(), context.Canceled.Error()))
   840  }
   841  
   842  func TestStateSyncNodes(t *testing.T) {
   843  	mockNetClient := &mockNetwork{}
   844  
   845  	stateSyncNodes := []ids.NodeID{
   846  		ids.GenerateTestNodeID(),
   847  		ids.GenerateTestNodeID(),
   848  		ids.GenerateTestNodeID(),
   849  		ids.GenerateTestNodeID(),
   850  	}
   851  	client := NewClient(&ClientConfig{
   852  		NetworkClient:    mockNetClient,
   853  		Codec:            message.Codec,
   854  		Stats:            clientstats.NewNoOpStats(),
   855  		StateSyncNodeIDs: stateSyncNodes,
   856  		BlockParser:      mockBlockParser,
   857  	})
   858  	ctx, cancel := context.WithCancel(context.Background())
   859  	defer cancel()
   860  	attempt := 0
   861  	responses := [][]byte{{1}, {2}, {3}, {4}}
   862  	mockNetClient.mockResponses(func() {
   863  		attempt++
   864  		if attempt >= 4 {
   865  			cancel()
   866  		}
   867  	}, responses...)
   868  
   869  	// send some request, doesn't matter what it is because we're testing the interaction with state sync nodes here
   870  	response, err := client.GetLeafs(ctx, message.LeafsRequest{})
   871  	assert.Error(t, err)
   872  	assert.Empty(t, response)
   873  
   874  	// assert all nodes were called
   875  	assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[0])
   876  	assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[1])
   877  	assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[2])
   878  	assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[3])
   879  }