github.com/dim4egster/coreth@v0.10.2/sync/client/client_test.go (about)

     1  // (c) 2021-2022, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package statesyncclient
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"fmt"
    10  	"math/rand"
    11  	"strings"
    12  	"testing"
    13  
    14  	"github.com/stretchr/testify/assert"
    15  
    16  	"github.com/dim4egster/qmallgo/ids"
    17  
    18  	"github.com/dim4egster/coreth/consensus/dummy"
    19  	"github.com/dim4egster/coreth/core"
    20  	"github.com/dim4egster/coreth/core/types"
    21  	"github.com/dim4egster/coreth/ethdb/memorydb"
    22  	"github.com/dim4egster/coreth/params"
    23  	"github.com/dim4egster/coreth/plugin/evm/message"
    24  	clientstats "github.com/dim4egster/coreth/sync/client/stats"
    25  	"github.com/dim4egster/coreth/sync/handlers"
    26  	handlerstats "github.com/dim4egster/coreth/sync/handlers/stats"
    27  	"github.com/dim4egster/coreth/trie"
    28  	"github.com/ethereum/go-ethereum/common"
    29  	"github.com/ethereum/go-ethereum/crypto"
    30  )
    31  
    32  func TestGetCode(t *testing.T) {
    33  	mockNetClient := &mockNetwork{}
    34  
    35  	tests := map[string]struct {
    36  		setupRequest func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte)
    37  		expectedErr  error
    38  	}{
    39  		"normal": {
    40  			setupRequest: func() ([]common.Hash, message.CodeResponse, [][]byte) {
    41  				code := []byte("this is the code")
    42  				codeHash := crypto.Keccak256Hash(code)
    43  				codeSlices := [][]byte{code}
    44  				return []common.Hash{codeHash}, message.CodeResponse{
    45  					Data: codeSlices,
    46  				}, codeSlices
    47  			},
    48  			expectedErr: nil,
    49  		},
    50  		"unexpected code bytes": {
    51  			setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) {
    52  				return []common.Hash{{1}}, message.CodeResponse{
    53  					Data: [][]byte{{1}},
    54  				}, nil
    55  			},
    56  			expectedErr: errHashMismatch,
    57  		},
    58  		"too many code elements returned": {
    59  			setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) {
    60  				return []common.Hash{{1}}, message.CodeResponse{
    61  					Data: [][]byte{{1}, {2}},
    62  				}, nil
    63  			},
    64  			expectedErr: errInvalidCodeResponseLen,
    65  		},
    66  		"too few code elements returned": {
    67  			setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) {
    68  				return []common.Hash{{1}}, message.CodeResponse{
    69  					Data: [][]byte{},
    70  				}, nil
    71  			},
    72  			expectedErr: errInvalidCodeResponseLen,
    73  		},
    74  		"code size is too large": {
    75  			setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) {
    76  				oversizedCode := make([]byte, params.MaxCodeSize+1)
    77  				codeHash := crypto.Keccak256Hash(oversizedCode)
    78  				return []common.Hash{codeHash}, message.CodeResponse{
    79  					Data: [][]byte{oversizedCode},
    80  				}, nil
    81  			},
    82  			expectedErr: errMaxCodeSizeExceeded,
    83  		},
    84  	}
    85  
    86  	stateSyncClient := NewClient(&ClientConfig{
    87  		NetworkClient:    mockNetClient,
    88  		Codec:            message.Codec,
    89  		Stats:            clientstats.NewNoOpStats(),
    90  		StateSyncNodeIDs: nil,
    91  		BlockParser:      mockBlockParser,
    92  	})
    93  
    94  	for name, test := range tests {
    95  		t.Run(name, func(t *testing.T) {
    96  			ctx, cancel := context.WithCancel(context.Background())
    97  			defer cancel()
    98  			codeHashes, res, expectedCode := test.setupRequest()
    99  
   100  			responseBytes, err := message.Codec.Marshal(message.Version, res)
   101  			if err != nil {
   102  				t.Fatal(err)
   103  			}
   104  			// Dirty hack required because the client will re-request if it encounters
   105  			// an error.
   106  			attempted := false
   107  			if test.expectedErr == nil {
   108  				mockNetClient.mockResponse(1, nil, responseBytes)
   109  			} else {
   110  				mockNetClient.mockResponse(2, func() {
   111  					// Cancel before the second attempt is processed.
   112  					if attempted {
   113  						cancel()
   114  					}
   115  					attempted = true
   116  				}, responseBytes)
   117  			}
   118  
   119  			codeBytes, err := stateSyncClient.GetCode(ctx, codeHashes)
   120  			// If we expect an error, assert that one occurred and return
   121  			if test.expectedErr != nil {
   122  				assert.ErrorIs(t, err, test.expectedErr)
   123  				assert.EqualValues(t, 2, mockNetClient.numCalls)
   124  				return
   125  			}
   126  			// Otherwise, assert there was no error and that the result is as expected
   127  			assert.NoError(t, err)
   128  			assert.Equal(t, len(codeBytes), len(expectedCode))
   129  			for i, code := range codeBytes {
   130  				assert.Equal(t, expectedCode[i], code)
   131  			}
   132  			assert.Equal(t, uint(1), mockNetClient.numCalls)
   133  		})
   134  	}
   135  }
   136  
   137  func TestGetBlocks(t *testing.T) {
   138  	// set random seed for deterministic tests
   139  	rand.Seed(1)
   140  
   141  	var gspec = &core.Genesis{
   142  		Config: params.TestChainConfig,
   143  	}
   144  	memdb := memorydb.New()
   145  	genesis := gspec.MustCommit(memdb)
   146  	engine := dummy.NewETHFaker()
   147  	numBlocks := 110
   148  	blocks, _, err := core.GenerateChain(params.TestChainConfig, genesis, engine, memdb, numBlocks, 0, func(i int, b *core.BlockGen) {})
   149  	if err != nil {
   150  		t.Fatal("unexpected error when generating test blockchain", err)
   151  	}
   152  	assert.Equal(t, numBlocks, len(blocks))
   153  
   154  	// Construct client
   155  	mockNetClient := &mockNetwork{}
   156  	stateSyncClient := NewClient(&ClientConfig{
   157  		NetworkClient:    mockNetClient,
   158  		Codec:            message.Codec,
   159  		Stats:            clientstats.NewNoOpStats(),
   160  		StateSyncNodeIDs: nil,
   161  		BlockParser:      mockBlockParser,
   162  	})
   163  
   164  	blocksRequestHandler := handlers.NewBlockRequestHandler(buildGetter(blocks), message.Codec, handlerstats.NewNoopHandlerStats())
   165  
   166  	// encodeBlockSlice takes a slice of blocks that are ordered in increasing height order
   167  	// and returns a slice of byte slices with those blocks encoded in reverse order
   168  	encodeBlockSlice := func(blocks []*types.Block) [][]byte {
   169  		blockBytes := make([][]byte, 0, len(blocks))
   170  		for i := len(blocks) - 1; i >= 0; i-- {
   171  			buf := new(bytes.Buffer)
   172  			if err := blocks[i].EncodeRLP(buf); err != nil {
   173  				t.Fatalf("failed to generate expected response %s", err)
   174  			}
   175  			blockBytes = append(blockBytes, buf.Bytes())
   176  		}
   177  
   178  		return blockBytes
   179  	}
   180  	tests := map[string]struct {
   181  		request        message.BlockRequest
   182  		getResponse    func(t *testing.T, request message.BlockRequest) []byte
   183  		assertResponse func(t *testing.T, response []*types.Block)
   184  		expectedErr    string
   185  	}{
   186  		"normal resonse": {
   187  			request: message.BlockRequest{
   188  				Hash:    blocks[100].Hash(),
   189  				Height:  100,
   190  				Parents: 16,
   191  			},
   192  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   193  				response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   194  				if err != nil {
   195  					t.Fatal(err)
   196  				}
   197  
   198  				if len(response) == 0 {
   199  					t.Fatal("Failed to generate valid response")
   200  				}
   201  
   202  				return response
   203  			},
   204  			assertResponse: func(t *testing.T, response []*types.Block) {
   205  				assert.Equal(t, 16, len(response))
   206  			},
   207  		},
   208  		"fewer than requested blocks": {
   209  			request: message.BlockRequest{
   210  				Hash:    blocks[100].Hash(),
   211  				Height:  100,
   212  				Parents: 16,
   213  			},
   214  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   215  				request.Parents -= 5
   216  				response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   217  				if err != nil {
   218  					t.Fatal(err)
   219  				}
   220  
   221  				if len(response) == 0 {
   222  					t.Fatal("Failed to generate valid response")
   223  				}
   224  
   225  				return response
   226  			},
   227  			// If the server returns fewer than requested blocks, we should consider it valid
   228  			assertResponse: func(t *testing.T, response []*types.Block) {
   229  				assert.Equal(t, 11, len(response))
   230  			},
   231  		},
   232  		"gibberish response": {
   233  			request: message.BlockRequest{
   234  				Hash:    blocks[100].Hash(),
   235  				Height:  100,
   236  				Parents: 16,
   237  			},
   238  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   239  				return []byte("gibberish")
   240  			},
   241  			expectedErr: errUnmarshalResponse.Error(),
   242  		},
   243  		"invalid value replacing block": {
   244  			request: message.BlockRequest{
   245  				Hash:    blocks[100].Hash(),
   246  				Height:  100,
   247  				Parents: 16,
   248  			},
   249  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   250  				response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   251  				if err != nil {
   252  					t.Fatalf("failed to get block response: %s", err)
   253  				}
   254  				var blockResponse message.BlockResponse
   255  				if _, err = message.Codec.Unmarshal(response, &blockResponse); err != nil {
   256  					t.Fatalf("failed to marshal block response: %s", err)
   257  				}
   258  				// Replace middle value with garbage data
   259  				blockResponse.Blocks[10] = []byte("invalid value replacing block bytes")
   260  				responseBytes, err := message.Codec.Marshal(message.Version, blockResponse)
   261  				if err != nil {
   262  					t.Fatalf("failed to marshal block response: %s", err)
   263  				}
   264  
   265  				return responseBytes
   266  			},
   267  			expectedErr: "failed to unmarshal response: rlp: expected input list for types.extblock",
   268  		},
   269  		"incorrect starting point": {
   270  			request: message.BlockRequest{
   271  				Hash:    blocks[100].Hash(),
   272  				Height:  100,
   273  				Parents: 16,
   274  			},
   275  			getResponse: func(t *testing.T, _ message.BlockRequest) []byte {
   276  				response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, message.BlockRequest{
   277  					Hash:    blocks[99].Hash(),
   278  					Height:  99,
   279  					Parents: 16,
   280  				})
   281  				if err != nil {
   282  					t.Fatal(err)
   283  				}
   284  
   285  				if len(response) == 0 {
   286  					t.Fatal("Failed to generate valid response")
   287  				}
   288  
   289  				return response
   290  			},
   291  			expectedErr: errHashMismatch.Error(),
   292  		},
   293  		"missing link in between blocks": {
   294  			request: message.BlockRequest{
   295  				Hash:    blocks[100].Hash(),
   296  				Height:  100,
   297  				Parents: 16,
   298  			},
   299  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   300  				// Encode blocks with a missing link
   301  				blks := make([]*types.Block, 0)
   302  				blks = append(blks, blocks[84:89]...)
   303  				blks = append(blks, blocks[90:101]...)
   304  				blockBytes := encodeBlockSlice(blks)
   305  
   306  				blockResponse := message.BlockResponse{
   307  					Blocks: blockBytes,
   308  				}
   309  				responseBytes, err := message.Codec.Marshal(message.Version, blockResponse)
   310  				if err != nil {
   311  					t.Fatalf("failed to marshal block response: %s", err)
   312  				}
   313  
   314  				return responseBytes
   315  			},
   316  			expectedErr: errHashMismatch.Error(),
   317  		},
   318  		"no blocks": {
   319  			request: message.BlockRequest{
   320  				Hash:    blocks[100].Hash(),
   321  				Height:  100,
   322  				Parents: 16,
   323  			},
   324  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   325  				blockResponse := message.BlockResponse{
   326  					Blocks: nil,
   327  				}
   328  				responseBytes, err := message.Codec.Marshal(message.Version, blockResponse)
   329  				if err != nil {
   330  					t.Fatalf("failed to marshal block response: %s", err)
   331  				}
   332  
   333  				return responseBytes
   334  			},
   335  			expectedErr: errEmptyResponse.Error(),
   336  		},
   337  		"more than requested blocks": {
   338  			request: message.BlockRequest{
   339  				Hash:    blocks[100].Hash(),
   340  				Height:  100,
   341  				Parents: 16,
   342  			},
   343  			getResponse: func(t *testing.T, request message.BlockRequest) []byte {
   344  				blockBytes := encodeBlockSlice(blocks[80:100])
   345  
   346  				blockResponse := message.BlockResponse{
   347  					Blocks: blockBytes,
   348  				}
   349  				responseBytes, err := message.Codec.Marshal(message.Version, blockResponse)
   350  				if err != nil {
   351  					t.Fatalf("failed to marshal block response: %s", err)
   352  				}
   353  
   354  				return responseBytes
   355  			},
   356  			expectedErr: errTooManyBlocks.Error(),
   357  		},
   358  	}
   359  	for name, test := range tests {
   360  		t.Run(name, func(t *testing.T) {
   361  			ctx, cancel := context.WithCancel(context.Background())
   362  			defer cancel()
   363  
   364  			responseBytes := test.getResponse(t, test.request)
   365  			if len(test.expectedErr) == 0 {
   366  				mockNetClient.mockResponse(1, nil, responseBytes)
   367  			} else {
   368  				attempted := false
   369  				mockNetClient.mockResponse(2, func() {
   370  					if attempted {
   371  						cancel()
   372  					}
   373  					attempted = true
   374  				}, responseBytes)
   375  			}
   376  
   377  			blockResponse, err := stateSyncClient.GetBlocks(ctx, test.request.Hash, test.request.Height, test.request.Parents)
   378  			if len(test.expectedErr) != 0 {
   379  				if err == nil {
   380  					t.Fatalf("Expected error: %s, but found no error", test.expectedErr)
   381  				}
   382  				assert.True(t, strings.Contains(err.Error(), test.expectedErr), "expected error to contain [%s], but found [%s]", test.expectedErr, err)
   383  				return
   384  			}
   385  			if err != nil {
   386  				t.Fatal(err)
   387  			}
   388  
   389  			test.assertResponse(t, blockResponse)
   390  		})
   391  	}
   392  }
   393  
   394  func buildGetter(blocks []*types.Block) handlers.BlockProvider {
   395  	return &handlers.TestBlockProvider{
   396  		GetBlockFn: func(blockHash common.Hash, blockHeight uint64) *types.Block {
   397  			requestedBlock := blocks[blockHeight]
   398  			if requestedBlock.Hash() != blockHash {
   399  				fmt.Printf("ERROR height=%d, hash=%s, parentHash=%s, reqHash=%s\n", blockHeight, blockHash, requestedBlock.ParentHash(), requestedBlock.Hash())
   400  				return nil
   401  			}
   402  			return requestedBlock
   403  		},
   404  	}
   405  }
   406  
   407  func TestGetLeafs(t *testing.T) {
   408  	rand.Seed(1)
   409  
   410  	const leafsLimit = 1024
   411  
   412  	trieDB := trie.NewDatabase(memorydb.New())
   413  	largeTrieRoot, largeTrieKeys, _ := trie.GenerateTrie(t, trieDB, 100_000, common.HashLength)
   414  	smallTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, leafsLimit, common.HashLength)
   415  
   416  	handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats())
   417  	client := NewClient(&ClientConfig{
   418  		NetworkClient:    &mockNetwork{},
   419  		Codec:            message.Codec,
   420  		Stats:            clientstats.NewNoOpStats(),
   421  		StateSyncNodeIDs: nil,
   422  		BlockParser:      mockBlockParser,
   423  	})
   424  
   425  	tests := map[string]struct {
   426  		request        message.LeafsRequest
   427  		getResponse    func(t *testing.T, request message.LeafsRequest) []byte
   428  		assertResponse func(t *testing.T, response message.LeafsResponse)
   429  		expectedErr    error
   430  	}{
   431  		"full response for small (single request) trie": {
   432  			request: message.LeafsRequest{
   433  				Root:     smallTrieRoot,
   434  				Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   435  				End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   436  				Limit:    leafsLimit,
   437  				NodeType: message.StateTrieNode,
   438  			},
   439  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   440  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   441  				if err != nil {
   442  					t.Fatal("unexpected error in calling leafs request handler", err)
   443  				}
   444  				if len(response) == 0 {
   445  					t.Fatal("Failed to create valid response")
   446  				}
   447  
   448  				return response
   449  			},
   450  			assertResponse: func(t *testing.T, response message.LeafsResponse) {
   451  				assert.False(t, response.More)
   452  				assert.Equal(t, leafsLimit, len(response.Keys))
   453  				assert.Equal(t, leafsLimit, len(response.Vals))
   454  			},
   455  		},
   456  		"too many leaves in response": {
   457  			request: message.LeafsRequest{
   458  				Root:     smallTrieRoot,
   459  				Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   460  				End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   461  				Limit:    leafsLimit / 2,
   462  				NodeType: message.StateTrieNode,
   463  			},
   464  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   465  				modifiedRequest := request
   466  				modifiedRequest.Limit = leafsLimit
   467  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, modifiedRequest)
   468  				if err != nil {
   469  					t.Fatal("unexpected error in calling leafs request handler", err)
   470  				}
   471  				if len(response) == 0 {
   472  					t.Fatal("Failed to create valid response")
   473  				}
   474  
   475  				return response
   476  			},
   477  			expectedErr: errTooManyLeaves,
   478  		},
   479  		"partial response to request for entire trie (full leaf limit)": {
   480  			request: message.LeafsRequest{
   481  				Root:     largeTrieRoot,
   482  				Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   483  				End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   484  				Limit:    leafsLimit,
   485  				NodeType: message.StateTrieNode,
   486  			},
   487  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   488  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   489  				if err != nil {
   490  					t.Fatal("unexpected error in calling leafs request handler", err)
   491  				}
   492  				if len(response) == 0 {
   493  					t.Fatal("Failed to create valid response")
   494  				}
   495  
   496  				return response
   497  			},
   498  			assertResponse: func(t *testing.T, response message.LeafsResponse) {
   499  				assert.True(t, response.More)
   500  				assert.Equal(t, leafsLimit, len(response.Keys))
   501  				assert.Equal(t, leafsLimit, len(response.Vals))
   502  			},
   503  		},
   504  		"partial response to request for middle range of trie (full leaf limit)": {
   505  			request: message.LeafsRequest{
   506  				Root:     largeTrieRoot,
   507  				Start:    largeTrieKeys[1000],
   508  				End:      largeTrieKeys[99000],
   509  				Limit:    leafsLimit,
   510  				NodeType: message.StateTrieNode,
   511  			},
   512  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   513  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   514  				if err != nil {
   515  					t.Fatal("unexpected error in calling leafs request handler", err)
   516  				}
   517  
   518  				if len(response) == 0 {
   519  					t.Fatal("Failed to create valid response")
   520  				}
   521  				return response
   522  			},
   523  			assertResponse: func(t *testing.T, response message.LeafsResponse) {
   524  				assert.True(t, response.More)
   525  				assert.Equal(t, leafsLimit, len(response.Keys))
   526  				assert.Equal(t, leafsLimit, len(response.Vals))
   527  			},
   528  		},
   529  		"full response from near end of trie to end of trie (less than leaf limit)": {
   530  			request: message.LeafsRequest{
   531  				Root:     largeTrieRoot,
   532  				Start:    largeTrieKeys[len(largeTrieKeys)-30], // Set start 30 keys from the end of the large trie
   533  				End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   534  				Limit:    leafsLimit,
   535  				NodeType: message.StateTrieNode,
   536  			},
   537  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   538  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   539  				if err != nil {
   540  					t.Fatal("unexpected error in calling leafs request handler", err)
   541  				}
   542  				if len(response) == 0 {
   543  					t.Fatal("Failed to create valid response")
   544  				}
   545  				return response
   546  			},
   547  			assertResponse: func(t *testing.T, response message.LeafsResponse) {
   548  				assert.False(t, response.More)
   549  				assert.Equal(t, 30, len(response.Keys))
   550  				assert.Equal(t, 30, len(response.Vals))
   551  			},
   552  		},
   553  		"full response for intermediate range of trie (less than leaf limit)": {
   554  			request: message.LeafsRequest{
   555  				Root:     largeTrieRoot,
   556  				Start:    largeTrieKeys[1000], // Set the range for 1000 leafs in an intermediate range of the trie
   557  				End:      largeTrieKeys[1099], // (inclusive range)
   558  				Limit:    leafsLimit,
   559  				NodeType: message.StateTrieNode,
   560  			},
   561  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   562  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   563  				if err != nil {
   564  					t.Fatal("unexpected error in calling leafs request handler", err)
   565  				}
   566  				if len(response) == 0 {
   567  					t.Fatal("Failed to create valid response")
   568  				}
   569  
   570  				return response
   571  			},
   572  			assertResponse: func(t *testing.T, response message.LeafsResponse) {
   573  				assert.True(t, response.More)
   574  				assert.Equal(t, 100, len(response.Keys))
   575  				assert.Equal(t, 100, len(response.Vals))
   576  			},
   577  		},
   578  		"removed first key in response": {
   579  			request: message.LeafsRequest{
   580  				Root:     largeTrieRoot,
   581  				Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   582  				End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   583  				Limit:    leafsLimit,
   584  				NodeType: message.StateTrieNode,
   585  			},
   586  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   587  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   588  				if err != nil {
   589  					t.Fatal("unexpected error in calling leafs request handler", err)
   590  				}
   591  				if len(response) == 0 {
   592  					t.Fatal("Failed to create valid response")
   593  				}
   594  				var leafResponse message.LeafsResponse
   595  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   596  					t.Fatal(err)
   597  				}
   598  				leafResponse.Keys = leafResponse.Keys[1:]
   599  				leafResponse.Vals = leafResponse.Vals[1:]
   600  
   601  				modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse)
   602  				if err != nil {
   603  					t.Fatal(err)
   604  				}
   605  				return modifiedResponse
   606  			},
   607  			expectedErr: errInvalidRangeProof,
   608  		},
   609  		"removed first key in response and replaced proof": {
   610  			request: message.LeafsRequest{
   611  				Root:     largeTrieRoot,
   612  				Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   613  				End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   614  				Limit:    leafsLimit,
   615  				NodeType: message.StateTrieNode,
   616  			},
   617  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   618  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   619  				if err != nil {
   620  					t.Fatal("unexpected error in calling leafs request handler", err)
   621  				}
   622  				if len(response) == 0 {
   623  					t.Fatal("Failed to create valid response")
   624  				}
   625  				var leafResponse message.LeafsResponse
   626  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   627  					t.Fatal(err)
   628  				}
   629  				modifiedRequest := request
   630  				modifiedRequest.Start = leafResponse.Keys[1]
   631  				modifiedResponse, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 2, modifiedRequest)
   632  				if err != nil {
   633  					t.Fatal("unexpected error in calling leafs request handler", err)
   634  				}
   635  				return modifiedResponse
   636  			},
   637  			expectedErr: errInvalidRangeProof,
   638  		},
   639  		"removed last key in response": {
   640  			request: message.LeafsRequest{
   641  				Root:     largeTrieRoot,
   642  				Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   643  				End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   644  				Limit:    leafsLimit,
   645  				NodeType: message.StateTrieNode,
   646  			},
   647  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   648  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   649  				if err != nil {
   650  					t.Fatal("unexpected error in calling leafs request handler", err)
   651  				}
   652  				if len(response) == 0 {
   653  					t.Fatal("Failed to create valid response")
   654  				}
   655  				var leafResponse message.LeafsResponse
   656  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   657  					t.Fatal(err)
   658  				}
   659  				leafResponse.Keys = leafResponse.Keys[:len(leafResponse.Keys)-2]
   660  				leafResponse.Vals = leafResponse.Vals[:len(leafResponse.Vals)-2]
   661  
   662  				modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse)
   663  				if err != nil {
   664  					t.Fatal(err)
   665  				}
   666  				return modifiedResponse
   667  			},
   668  			expectedErr: errInvalidRangeProof,
   669  		},
   670  		"removed key from middle of response": {
   671  			request: message.LeafsRequest{
   672  				Root:     largeTrieRoot,
   673  				Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   674  				End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   675  				Limit:    leafsLimit,
   676  				NodeType: message.StateTrieNode,
   677  			},
   678  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   679  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   680  				if err != nil {
   681  					t.Fatal("unexpected error in calling leafs request handler", err)
   682  				}
   683  				if len(response) == 0 {
   684  					t.Fatal("Failed to create valid response")
   685  				}
   686  				var leafResponse message.LeafsResponse
   687  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   688  					t.Fatal(err)
   689  				}
   690  				// Remove middle key-value pair response
   691  				leafResponse.Keys = append(leafResponse.Keys[:100], leafResponse.Keys[101:]...)
   692  				leafResponse.Vals = append(leafResponse.Vals[:100], leafResponse.Vals[101:]...)
   693  
   694  				modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse)
   695  				if err != nil {
   696  					t.Fatal(err)
   697  				}
   698  				return modifiedResponse
   699  			},
   700  			expectedErr: errInvalidRangeProof,
   701  		},
   702  		"corrupted value in middle of response": {
   703  			request: message.LeafsRequest{
   704  				Root:     largeTrieRoot,
   705  				Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   706  				End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   707  				Limit:    leafsLimit,
   708  				NodeType: message.StateTrieNode,
   709  			},
   710  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   711  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   712  				if err != nil {
   713  					t.Fatal("unexpected error in calling leafs request handler", err)
   714  				}
   715  				if len(response) == 0 {
   716  					t.Fatal("Failed to create valid response")
   717  				}
   718  				var leafResponse message.LeafsResponse
   719  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   720  					t.Fatal(err)
   721  				}
   722  				// Remove middle key-value pair response
   723  				leafResponse.Vals[100] = []byte("garbage value data")
   724  
   725  				modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse)
   726  				if err != nil {
   727  					t.Fatal(err)
   728  				}
   729  				return modifiedResponse
   730  			},
   731  			expectedErr: errInvalidRangeProof,
   732  		},
   733  		"all proof keys removed from response": {
   734  			request: message.LeafsRequest{
   735  				Root:     largeTrieRoot,
   736  				Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   737  				End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   738  				Limit:    leafsLimit,
   739  				NodeType: message.StateTrieNode,
   740  			},
   741  			getResponse: func(t *testing.T, request message.LeafsRequest) []byte {
   742  				response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request)
   743  				if err != nil {
   744  					t.Fatal("unexpected error in calling leafs request handler", err)
   745  				}
   746  				if len(response) == 0 {
   747  					t.Fatal("Failed to create valid response")
   748  				}
   749  
   750  				var leafResponse message.LeafsResponse
   751  				if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil {
   752  					t.Fatal(err)
   753  				}
   754  				// Remove the proof
   755  				leafResponse.ProofVals = nil
   756  
   757  				modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse)
   758  				if err != nil {
   759  					t.Fatal(err)
   760  				}
   761  				return modifiedResponse
   762  			},
   763  			expectedErr: errInvalidRangeProof,
   764  		},
   765  	}
   766  	for name, test := range tests {
   767  		t.Run(name, func(t *testing.T) {
   768  			responseBytes := test.getResponse(t, test.request)
   769  
   770  			response, _, err := parseLeafsResponse(client.codec, test.request, responseBytes)
   771  			if test.expectedErr != nil {
   772  				if err == nil {
   773  					t.Fatalf("Expected error: %s, but found no error", test.expectedErr)
   774  				}
   775  				assert.True(t, strings.Contains(err.Error(), test.expectedErr.Error()))
   776  				return
   777  			}
   778  
   779  			if err != nil {
   780  				t.Fatal(err)
   781  			}
   782  
   783  			leafsResponse, ok := response.(message.LeafsResponse)
   784  			if !ok {
   785  				t.Fatalf("parseLeafsResponse returned incorrect type %T", response)
   786  			}
   787  			test.assertResponse(t, leafsResponse)
   788  		})
   789  	}
   790  }
   791  
   792  func TestGetLeafsRetries(t *testing.T) {
   793  	rand.Seed(1)
   794  
   795  	trieDB := trie.NewDatabase(memorydb.New())
   796  	root, _, _ := trie.GenerateTrie(t, trieDB, 100_000, common.HashLength)
   797  
   798  	handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats())
   799  	mockNetClient := &mockNetwork{}
   800  
   801  	const maxAttempts = 8
   802  	client := NewClient(&ClientConfig{
   803  		NetworkClient:    mockNetClient,
   804  		Codec:            message.Codec,
   805  		Stats:            clientstats.NewNoOpStats(),
   806  		StateSyncNodeIDs: nil,
   807  		BlockParser:      mockBlockParser,
   808  	})
   809  
   810  	request := message.LeafsRequest{
   811  		Root:     root,
   812  		Start:    bytes.Repeat([]byte{0x00}, common.HashLength),
   813  		End:      bytes.Repeat([]byte{0xff}, common.HashLength),
   814  		Limit:    defaultLeafRequestLimit,
   815  		NodeType: message.StateTrieNode,
   816  	}
   817  
   818  	ctx, cancel := context.WithCancel(context.Background())
   819  	defer cancel()
   820  	goodResponse, responseErr := handler.OnLeafsRequest(ctx, ids.GenerateTestNodeID(), 1, request)
   821  	assert.NoError(t, responseErr)
   822  	mockNetClient.mockResponse(1, nil, goodResponse)
   823  
   824  	res, err := client.GetLeafs(ctx, request)
   825  	if err != nil {
   826  		t.Fatal(err)
   827  	}
   828  	assert.Equal(t, 1024, len(res.Keys))
   829  	assert.Equal(t, 1024, len(res.Vals))
   830  
   831  	// Succeeds within the allotted number of attempts
   832  	invalidResponse := []byte("invalid response")
   833  	mockNetClient.mockResponses(nil, invalidResponse, invalidResponse, goodResponse)
   834  
   835  	res, err = client.GetLeafs(ctx, request)
   836  	if err != nil {
   837  		t.Fatal(err)
   838  	}
   839  	assert.Equal(t, 1024, len(res.Keys))
   840  	assert.Equal(t, 1024, len(res.Vals))
   841  
   842  	// Test that GetLeafs stops after the context is cancelled
   843  	numAttempts := 0
   844  	mockNetClient.mockResponse(maxAttempts, func() {
   845  		numAttempts++
   846  		if numAttempts >= maxAttempts {
   847  			cancel()
   848  		}
   849  	}, invalidResponse)
   850  	_, err = client.GetLeafs(ctx, request)
   851  	assert.Error(t, err)
   852  	assert.True(t, strings.Contains(err.Error(), context.Canceled.Error()))
   853  }
   854  
   855  func TestStateSyncNodes(t *testing.T) {
   856  	mockNetClient := &mockNetwork{}
   857  
   858  	stateSyncNodes := []ids.NodeID{
   859  		ids.GenerateTestNodeID(),
   860  		ids.GenerateTestNodeID(),
   861  		ids.GenerateTestNodeID(),
   862  		ids.GenerateTestNodeID(),
   863  	}
   864  	client := NewClient(&ClientConfig{
   865  		NetworkClient:    mockNetClient,
   866  		Codec:            message.Codec,
   867  		Stats:            clientstats.NewNoOpStats(),
   868  		StateSyncNodeIDs: stateSyncNodes,
   869  		BlockParser:      mockBlockParser,
   870  	})
   871  	ctx, cancel := context.WithCancel(context.Background())
   872  	defer cancel()
   873  	attempt := 0
   874  	responses := [][]byte{{1}, {2}, {3}, {4}}
   875  	mockNetClient.mockResponses(func() {
   876  		attempt++
   877  		if attempt >= 4 {
   878  			cancel()
   879  		}
   880  	}, responses...)
   881  
   882  	// send some request, doesn't matter what it is because we're testing the interaction with state sync nodes here
   883  	response, err := client.GetLeafs(ctx, message.LeafsRequest{})
   884  	assert.Error(t, err)
   885  	assert.Empty(t, response)
   886  
   887  	// assert all nodes were called
   888  	assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[0])
   889  	assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[1])
   890  	assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[2])
   891  	assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[3])
   892  }