github.com/MetalBlockchain/subnet-evm@v0.4.9/sync/client/client_test.go (about) 1 // (c) 2021-2022, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package statesyncclient 5 6 import ( 7 "bytes" 8 "context" 9 "fmt" 10 "math/rand" 11 "strings" 12 "testing" 13 14 "github.com/stretchr/testify/assert" 15 16 "github.com/MetalBlockchain/metalgo/ids" 17 18 "github.com/MetalBlockchain/subnet-evm/consensus/dummy" 19 "github.com/MetalBlockchain/subnet-evm/core" 20 "github.com/MetalBlockchain/subnet-evm/core/types" 21 "github.com/MetalBlockchain/subnet-evm/ethdb/memorydb" 22 "github.com/MetalBlockchain/subnet-evm/params" 23 "github.com/MetalBlockchain/subnet-evm/plugin/evm/message" 24 clientstats "github.com/MetalBlockchain/subnet-evm/sync/client/stats" 25 "github.com/MetalBlockchain/subnet-evm/sync/handlers" 26 handlerstats "github.com/MetalBlockchain/subnet-evm/sync/handlers/stats" 27 "github.com/MetalBlockchain/subnet-evm/trie" 28 "github.com/ethereum/go-ethereum/common" 29 "github.com/ethereum/go-ethereum/crypto" 30 ) 31 32 func TestGetCode(t *testing.T) { 33 mockNetClient := &mockNetwork{} 34 35 tests := map[string]struct { 36 setupRequest func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) 37 expectedErr error 38 }{ 39 "normal": { 40 setupRequest: func() ([]common.Hash, message.CodeResponse, [][]byte) { 41 code := []byte("this is the code") 42 codeHash := crypto.Keccak256Hash(code) 43 codeSlices := [][]byte{code} 44 return []common.Hash{codeHash}, message.CodeResponse{ 45 Data: codeSlices, 46 }, codeSlices 47 }, 48 expectedErr: nil, 49 }, 50 "unexpected code bytes": { 51 setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) { 52 return []common.Hash{{1}}, message.CodeResponse{ 53 Data: [][]byte{{1}}, 54 }, nil 55 }, 56 expectedErr: errHashMismatch, 57 }, 58 "too many code elements returned": { 59 setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) { 60 return []common.Hash{{1}}, message.CodeResponse{ 61 Data: [][]byte{{1}, {2}}, 62 }, nil 63 }, 64 expectedErr: errInvalidCodeResponseLen, 65 }, 66 "too few code elements returned": { 67 setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) { 68 return []common.Hash{{1}}, message.CodeResponse{ 69 Data: [][]byte{}, 70 }, nil 71 }, 72 expectedErr: errInvalidCodeResponseLen, 73 }, 74 "code size is too large": { 75 setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) { 76 oversizedCode := make([]byte, params.MaxCodeSize+1) 77 codeHash := crypto.Keccak256Hash(oversizedCode) 78 return []common.Hash{codeHash}, message.CodeResponse{ 79 Data: [][]byte{oversizedCode}, 80 }, nil 81 }, 82 expectedErr: errMaxCodeSizeExceeded, 83 }, 84 } 85 86 stateSyncClient := NewClient(&ClientConfig{ 87 NetworkClient: mockNetClient, 88 Codec: message.Codec, 89 Stats: clientstats.NewNoOpStats(), 90 StateSyncNodeIDs: nil, 91 BlockParser: mockBlockParser, 92 }) 93 94 for name, test := range tests { 95 t.Run(name, func(t *testing.T) { 96 ctx, cancel := context.WithCancel(context.Background()) 97 defer cancel() 98 codeHashes, res, expectedCode := test.setupRequest() 99 100 responseBytes, err := message.Codec.Marshal(message.Version, res) 101 if err != nil { 102 t.Fatal(err) 103 } 104 // Dirty hack required because the client will re-request if it encounters 105 // an error. 106 attempted := false 107 if test.expectedErr == nil { 108 mockNetClient.mockResponse(1, nil, responseBytes) 109 } else { 110 mockNetClient.mockResponse(2, func() { 111 // Cancel before the second attempt is processed. 112 if attempted { 113 cancel() 114 } 115 attempted = true 116 }, responseBytes) 117 } 118 119 codeBytes, err := stateSyncClient.GetCode(ctx, codeHashes) 120 // If we expect an error, assert that one occurred and return 121 if test.expectedErr != nil { 122 assert.ErrorIs(t, err, test.expectedErr) 123 assert.EqualValues(t, 2, mockNetClient.numCalls) 124 return 125 } 126 // Otherwise, assert there was no error and that the result is as expected 127 assert.NoError(t, err) 128 assert.Equal(t, len(codeBytes), len(expectedCode)) 129 for i, code := range codeBytes { 130 assert.Equal(t, expectedCode[i], code) 131 } 132 assert.Equal(t, uint(1), mockNetClient.numCalls) 133 }) 134 } 135 } 136 137 func TestGetBlocks(t *testing.T) { 138 // set random seed for deterministic tests 139 rand.Seed(1) 140 141 var gspec = &core.Genesis{ 142 Config: params.TestChainConfig, 143 } 144 memdb := memorydb.New() 145 genesis := gspec.MustCommit(memdb) 146 engine := dummy.NewETHFaker() 147 numBlocks := 110 148 blocks, _, err := core.GenerateChain(params.TestChainConfig, genesis, engine, memdb, numBlocks, 0, func(i int, b *core.BlockGen) {}) 149 if err != nil { 150 t.Fatal("unexpected error when generating test blockchain", err) 151 } 152 assert.Equal(t, numBlocks, len(blocks)) 153 154 // Construct client 155 mockNetClient := &mockNetwork{} 156 stateSyncClient := NewClient(&ClientConfig{ 157 NetworkClient: mockNetClient, 158 Codec: message.Codec, 159 Stats: clientstats.NewNoOpStats(), 160 StateSyncNodeIDs: nil, 161 BlockParser: mockBlockParser, 162 }) 163 164 blocksRequestHandler := handlers.NewBlockRequestHandler(buildGetter(blocks), message.Codec, handlerstats.NewNoopHandlerStats()) 165 166 // encodeBlockSlice takes a slice of blocks that are ordered in increasing height order 167 // and returns a slice of byte slices with those blocks encoded in reverse order 168 encodeBlockSlice := func(blocks []*types.Block) [][]byte { 169 blockBytes := make([][]byte, 0, len(blocks)) 170 for i := len(blocks) - 1; i >= 0; i-- { 171 buf := new(bytes.Buffer) 172 if err := blocks[i].EncodeRLP(buf); err != nil { 173 t.Fatalf("failed to generate expected response %s", err) 174 } 175 blockBytes = append(blockBytes, buf.Bytes()) 176 } 177 178 return blockBytes 179 } 180 tests := map[string]struct { 181 request message.BlockRequest 182 getResponse func(t *testing.T, request message.BlockRequest) []byte 183 assertResponse func(t *testing.T, response []*types.Block) 184 expectedErr string 185 }{ 186 "normal resonse": { 187 request: message.BlockRequest{ 188 Hash: blocks[100].Hash(), 189 Height: 100, 190 Parents: 16, 191 }, 192 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 193 response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 194 if err != nil { 195 t.Fatal(err) 196 } 197 198 if len(response) == 0 { 199 t.Fatal("Failed to generate valid response") 200 } 201 202 return response 203 }, 204 assertResponse: func(t *testing.T, response []*types.Block) { 205 assert.Equal(t, 16, len(response)) 206 }, 207 }, 208 "fewer than requested blocks": { 209 request: message.BlockRequest{ 210 Hash: blocks[100].Hash(), 211 Height: 100, 212 Parents: 16, 213 }, 214 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 215 request.Parents -= 5 216 response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 217 if err != nil { 218 t.Fatal(err) 219 } 220 221 if len(response) == 0 { 222 t.Fatal("Failed to generate valid response") 223 } 224 225 return response 226 }, 227 // If the server returns fewer than requested blocks, we should consider it valid 228 assertResponse: func(t *testing.T, response []*types.Block) { 229 assert.Equal(t, 11, len(response)) 230 }, 231 }, 232 "gibberish response": { 233 request: message.BlockRequest{ 234 Hash: blocks[100].Hash(), 235 Height: 100, 236 Parents: 16, 237 }, 238 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 239 return []byte("gibberish") 240 }, 241 expectedErr: errUnmarshalResponse.Error(), 242 }, 243 "invalid value replacing block": { 244 request: message.BlockRequest{ 245 Hash: blocks[100].Hash(), 246 Height: 100, 247 Parents: 16, 248 }, 249 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 250 response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 251 if err != nil { 252 t.Fatalf("failed to get block response: %s", err) 253 } 254 var blockResponse message.BlockResponse 255 if _, err = message.Codec.Unmarshal(response, &blockResponse); err != nil { 256 t.Fatalf("failed to marshal block response: %s", err) 257 } 258 // Replace middle value with garbage data 259 blockResponse.Blocks[10] = []byte("invalid value replacing block bytes") 260 responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) 261 if err != nil { 262 t.Fatalf("failed to marshal block response: %s", err) 263 } 264 265 return responseBytes 266 }, 267 expectedErr: "failed to unmarshal response: rlp: expected input list for types.extblock", 268 }, 269 "incorrect starting point": { 270 request: message.BlockRequest{ 271 Hash: blocks[100].Hash(), 272 Height: 100, 273 Parents: 16, 274 }, 275 getResponse: func(t *testing.T, _ message.BlockRequest) []byte { 276 response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, message.BlockRequest{ 277 Hash: blocks[99].Hash(), 278 Height: 99, 279 Parents: 16, 280 }) 281 if err != nil { 282 t.Fatal(err) 283 } 284 285 if len(response) == 0 { 286 t.Fatal("Failed to generate valid response") 287 } 288 289 return response 290 }, 291 expectedErr: errHashMismatch.Error(), 292 }, 293 "missing link in between blocks": { 294 request: message.BlockRequest{ 295 Hash: blocks[100].Hash(), 296 Height: 100, 297 Parents: 16, 298 }, 299 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 300 // Encode blocks with a missing link 301 blks := make([]*types.Block, 0) 302 blks = append(blks, blocks[84:89]...) 303 blks = append(blks, blocks[90:101]...) 304 blockBytes := encodeBlockSlice(blks) 305 306 blockResponse := message.BlockResponse{ 307 Blocks: blockBytes, 308 } 309 responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) 310 if err != nil { 311 t.Fatalf("failed to marshal block response: %s", err) 312 } 313 314 return responseBytes 315 }, 316 expectedErr: errHashMismatch.Error(), 317 }, 318 "no blocks": { 319 request: message.BlockRequest{ 320 Hash: blocks[100].Hash(), 321 Height: 100, 322 Parents: 16, 323 }, 324 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 325 blockResponse := message.BlockResponse{ 326 Blocks: nil, 327 } 328 responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) 329 if err != nil { 330 t.Fatalf("failed to marshal block response: %s", err) 331 } 332 333 return responseBytes 334 }, 335 expectedErr: errEmptyResponse.Error(), 336 }, 337 "more than requested blocks": { 338 request: message.BlockRequest{ 339 Hash: blocks[100].Hash(), 340 Height: 100, 341 Parents: 16, 342 }, 343 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 344 blockBytes := encodeBlockSlice(blocks[80:100]) 345 346 blockResponse := message.BlockResponse{ 347 Blocks: blockBytes, 348 } 349 responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) 350 if err != nil { 351 t.Fatalf("failed to marshal block response: %s", err) 352 } 353 354 return responseBytes 355 }, 356 expectedErr: errTooManyBlocks.Error(), 357 }, 358 } 359 for name, test := range tests { 360 t.Run(name, func(t *testing.T) { 361 ctx, cancel := context.WithCancel(context.Background()) 362 defer cancel() 363 364 responseBytes := test.getResponse(t, test.request) 365 if len(test.expectedErr) == 0 { 366 mockNetClient.mockResponse(1, nil, responseBytes) 367 } else { 368 attempted := false 369 mockNetClient.mockResponse(2, func() { 370 if attempted { 371 cancel() 372 } 373 attempted = true 374 }, responseBytes) 375 } 376 377 blockResponse, err := stateSyncClient.GetBlocks(ctx, test.request.Hash, test.request.Height, test.request.Parents) 378 if len(test.expectedErr) != 0 { 379 if err == nil { 380 t.Fatalf("Expected error: %s, but found no error", test.expectedErr) 381 } 382 assert.True(t, strings.Contains(err.Error(), test.expectedErr), "expected error to contain [%s], but found [%s]", test.expectedErr, err) 383 return 384 } 385 if err != nil { 386 t.Fatal(err) 387 } 388 389 test.assertResponse(t, blockResponse) 390 }) 391 } 392 } 393 394 func buildGetter(blocks []*types.Block) handlers.BlockProvider { 395 return &handlers.TestBlockProvider{ 396 GetBlockFn: func(blockHash common.Hash, blockHeight uint64) *types.Block { 397 requestedBlock := blocks[blockHeight] 398 if requestedBlock.Hash() != blockHash { 399 fmt.Printf("ERROR height=%d, hash=%s, parentHash=%s, reqHash=%s\n", blockHeight, blockHash, requestedBlock.ParentHash(), requestedBlock.Hash()) 400 return nil 401 } 402 return requestedBlock 403 }, 404 } 405 } 406 407 func TestGetLeafs(t *testing.T) { 408 rand.Seed(1) 409 410 const leafsLimit = 1024 411 412 trieDB := trie.NewDatabase(memorydb.New()) 413 largeTrieRoot, largeTrieKeys, _ := trie.GenerateTrie(t, trieDB, 100_000, common.HashLength) 414 smallTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, leafsLimit, common.HashLength) 415 416 handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats()) 417 client := NewClient(&ClientConfig{ 418 NetworkClient: &mockNetwork{}, 419 Codec: message.Codec, 420 Stats: clientstats.NewNoOpStats(), 421 StateSyncNodeIDs: nil, 422 BlockParser: mockBlockParser, 423 }) 424 425 tests := map[string]struct { 426 request message.LeafsRequest 427 getResponse func(t *testing.T, request message.LeafsRequest) []byte 428 assertResponse func(t *testing.T, response message.LeafsResponse) 429 expectedErr error 430 }{ 431 "full response for small (single request) trie": { 432 request: message.LeafsRequest{ 433 Root: smallTrieRoot, 434 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 435 End: bytes.Repeat([]byte{0xff}, common.HashLength), 436 Limit: leafsLimit, 437 }, 438 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 439 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 440 if err != nil { 441 t.Fatal("unexpected error in calling leafs request handler", err) 442 } 443 if len(response) == 0 { 444 t.Fatal("Failed to create valid response") 445 } 446 447 return response 448 }, 449 assertResponse: func(t *testing.T, response message.LeafsResponse) { 450 assert.False(t, response.More) 451 assert.Equal(t, leafsLimit, len(response.Keys)) 452 assert.Equal(t, leafsLimit, len(response.Vals)) 453 }, 454 }, 455 "too many leaves in response": { 456 request: message.LeafsRequest{ 457 Root: smallTrieRoot, 458 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 459 End: bytes.Repeat([]byte{0xff}, common.HashLength), 460 Limit: leafsLimit / 2, 461 }, 462 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 463 modifiedRequest := request 464 modifiedRequest.Limit = leafsLimit 465 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, modifiedRequest) 466 if err != nil { 467 t.Fatal("unexpected error in calling leafs request handler", err) 468 } 469 if len(response) == 0 { 470 t.Fatal("Failed to create valid response") 471 } 472 473 return response 474 }, 475 expectedErr: errTooManyLeaves, 476 }, 477 "partial response to request for entire trie (full leaf limit)": { 478 request: message.LeafsRequest{ 479 Root: largeTrieRoot, 480 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 481 End: bytes.Repeat([]byte{0xff}, common.HashLength), 482 Limit: leafsLimit, 483 }, 484 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 485 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 486 if err != nil { 487 t.Fatal("unexpected error in calling leafs request handler", err) 488 } 489 if len(response) == 0 { 490 t.Fatal("Failed to create valid response") 491 } 492 493 return response 494 }, 495 assertResponse: func(t *testing.T, response message.LeafsResponse) { 496 assert.True(t, response.More) 497 assert.Equal(t, leafsLimit, len(response.Keys)) 498 assert.Equal(t, leafsLimit, len(response.Vals)) 499 }, 500 }, 501 "partial response to request for middle range of trie (full leaf limit)": { 502 request: message.LeafsRequest{ 503 Root: largeTrieRoot, 504 Start: largeTrieKeys[1000], 505 End: largeTrieKeys[99000], 506 Limit: leafsLimit, 507 }, 508 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 509 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 510 if err != nil { 511 t.Fatal("unexpected error in calling leafs request handler", err) 512 } 513 514 if len(response) == 0 { 515 t.Fatal("Failed to create valid response") 516 } 517 return response 518 }, 519 assertResponse: func(t *testing.T, response message.LeafsResponse) { 520 assert.True(t, response.More) 521 assert.Equal(t, leafsLimit, len(response.Keys)) 522 assert.Equal(t, leafsLimit, len(response.Vals)) 523 }, 524 }, 525 "full response from near end of trie to end of trie (less than leaf limit)": { 526 request: message.LeafsRequest{ 527 Root: largeTrieRoot, 528 Start: largeTrieKeys[len(largeTrieKeys)-30], // Set start 30 keys from the end of the large trie 529 End: bytes.Repeat([]byte{0xff}, common.HashLength), 530 Limit: leafsLimit, 531 }, 532 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 533 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 534 if err != nil { 535 t.Fatal("unexpected error in calling leafs request handler", err) 536 } 537 if len(response) == 0 { 538 t.Fatal("Failed to create valid response") 539 } 540 return response 541 }, 542 assertResponse: func(t *testing.T, response message.LeafsResponse) { 543 assert.False(t, response.More) 544 assert.Equal(t, 30, len(response.Keys)) 545 assert.Equal(t, 30, len(response.Vals)) 546 }, 547 }, 548 "full response for intermediate range of trie (less than leaf limit)": { 549 request: message.LeafsRequest{ 550 Root: largeTrieRoot, 551 Start: largeTrieKeys[1000], // Set the range for 1000 leafs in an intermediate range of the trie 552 End: largeTrieKeys[1099], // (inclusive range) 553 Limit: leafsLimit, 554 }, 555 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 556 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 557 if err != nil { 558 t.Fatal("unexpected error in calling leafs request handler", err) 559 } 560 if len(response) == 0 { 561 t.Fatal("Failed to create valid response") 562 } 563 564 return response 565 }, 566 assertResponse: func(t *testing.T, response message.LeafsResponse) { 567 assert.True(t, response.More) 568 assert.Equal(t, 100, len(response.Keys)) 569 assert.Equal(t, 100, len(response.Vals)) 570 }, 571 }, 572 "removed first key in response": { 573 request: message.LeafsRequest{ 574 Root: largeTrieRoot, 575 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 576 End: bytes.Repeat([]byte{0xff}, common.HashLength), 577 Limit: leafsLimit, 578 }, 579 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 580 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 581 if err != nil { 582 t.Fatal("unexpected error in calling leafs request handler", err) 583 } 584 if len(response) == 0 { 585 t.Fatal("Failed to create valid response") 586 } 587 var leafResponse message.LeafsResponse 588 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 589 t.Fatal(err) 590 } 591 leafResponse.Keys = leafResponse.Keys[1:] 592 leafResponse.Vals = leafResponse.Vals[1:] 593 594 modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) 595 if err != nil { 596 t.Fatal(err) 597 } 598 return modifiedResponse 599 }, 600 expectedErr: errInvalidRangeProof, 601 }, 602 "removed first key in response and replaced proof": { 603 request: message.LeafsRequest{ 604 Root: largeTrieRoot, 605 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 606 End: bytes.Repeat([]byte{0xff}, common.HashLength), 607 Limit: leafsLimit, 608 }, 609 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 610 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 611 if err != nil { 612 t.Fatal("unexpected error in calling leafs request handler", err) 613 } 614 if len(response) == 0 { 615 t.Fatal("Failed to create valid response") 616 } 617 var leafResponse message.LeafsResponse 618 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 619 t.Fatal(err) 620 } 621 modifiedRequest := request 622 modifiedRequest.Start = leafResponse.Keys[1] 623 modifiedResponse, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 2, modifiedRequest) 624 if err != nil { 625 t.Fatal("unexpected error in calling leafs request handler", err) 626 } 627 return modifiedResponse 628 }, 629 expectedErr: errInvalidRangeProof, 630 }, 631 "removed last key in response": { 632 request: message.LeafsRequest{ 633 Root: largeTrieRoot, 634 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 635 End: bytes.Repeat([]byte{0xff}, common.HashLength), 636 Limit: leafsLimit, 637 }, 638 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 639 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 640 if err != nil { 641 t.Fatal("unexpected error in calling leafs request handler", err) 642 } 643 if len(response) == 0 { 644 t.Fatal("Failed to create valid response") 645 } 646 var leafResponse message.LeafsResponse 647 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 648 t.Fatal(err) 649 } 650 leafResponse.Keys = leafResponse.Keys[:len(leafResponse.Keys)-2] 651 leafResponse.Vals = leafResponse.Vals[:len(leafResponse.Vals)-2] 652 653 modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) 654 if err != nil { 655 t.Fatal(err) 656 } 657 return modifiedResponse 658 }, 659 expectedErr: errInvalidRangeProof, 660 }, 661 "removed key from middle of response": { 662 request: message.LeafsRequest{ 663 Root: largeTrieRoot, 664 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 665 End: bytes.Repeat([]byte{0xff}, common.HashLength), 666 Limit: leafsLimit, 667 }, 668 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 669 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 670 if err != nil { 671 t.Fatal("unexpected error in calling leafs request handler", err) 672 } 673 if len(response) == 0 { 674 t.Fatal("Failed to create valid response") 675 } 676 var leafResponse message.LeafsResponse 677 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 678 t.Fatal(err) 679 } 680 // Remove middle key-value pair response 681 leafResponse.Keys = append(leafResponse.Keys[:100], leafResponse.Keys[101:]...) 682 leafResponse.Vals = append(leafResponse.Vals[:100], leafResponse.Vals[101:]...) 683 684 modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) 685 if err != nil { 686 t.Fatal(err) 687 } 688 return modifiedResponse 689 }, 690 expectedErr: errInvalidRangeProof, 691 }, 692 "corrupted value in middle of response": { 693 request: message.LeafsRequest{ 694 Root: largeTrieRoot, 695 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 696 End: bytes.Repeat([]byte{0xff}, common.HashLength), 697 Limit: leafsLimit, 698 }, 699 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 700 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 701 if err != nil { 702 t.Fatal("unexpected error in calling leafs request handler", err) 703 } 704 if len(response) == 0 { 705 t.Fatal("Failed to create valid response") 706 } 707 var leafResponse message.LeafsResponse 708 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 709 t.Fatal(err) 710 } 711 // Remove middle key-value pair response 712 leafResponse.Vals[100] = []byte("garbage value data") 713 714 modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) 715 if err != nil { 716 t.Fatal(err) 717 } 718 return modifiedResponse 719 }, 720 expectedErr: errInvalidRangeProof, 721 }, 722 "all proof keys removed from response": { 723 request: message.LeafsRequest{ 724 Root: largeTrieRoot, 725 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 726 End: bytes.Repeat([]byte{0xff}, common.HashLength), 727 Limit: leafsLimit, 728 }, 729 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 730 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 731 if err != nil { 732 t.Fatal("unexpected error in calling leafs request handler", err) 733 } 734 if len(response) == 0 { 735 t.Fatal("Failed to create valid response") 736 } 737 738 var leafResponse message.LeafsResponse 739 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 740 t.Fatal(err) 741 } 742 // Remove the proof 743 leafResponse.ProofVals = nil 744 745 modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) 746 if err != nil { 747 t.Fatal(err) 748 } 749 return modifiedResponse 750 }, 751 expectedErr: errInvalidRangeProof, 752 }, 753 } 754 for name, test := range tests { 755 t.Run(name, func(t *testing.T) { 756 responseBytes := test.getResponse(t, test.request) 757 758 response, _, err := parseLeafsResponse(client.codec, test.request, responseBytes) 759 if test.expectedErr != nil { 760 if err == nil { 761 t.Fatalf("Expected error: %s, but found no error", test.expectedErr) 762 } 763 assert.True(t, strings.Contains(err.Error(), test.expectedErr.Error())) 764 return 765 } 766 767 if err != nil { 768 t.Fatal(err) 769 } 770 771 leafsResponse, ok := response.(message.LeafsResponse) 772 if !ok { 773 t.Fatalf("parseLeafsResponse returned incorrect type %T", response) 774 } 775 test.assertResponse(t, leafsResponse) 776 }) 777 } 778 } 779 780 func TestGetLeafsRetries(t *testing.T) { 781 rand.Seed(1) 782 783 trieDB := trie.NewDatabase(memorydb.New()) 784 root, _, _ := trie.GenerateTrie(t, trieDB, 100_000, common.HashLength) 785 786 handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats()) 787 mockNetClient := &mockNetwork{} 788 789 const maxAttempts = 8 790 client := NewClient(&ClientConfig{ 791 NetworkClient: mockNetClient, 792 Codec: message.Codec, 793 Stats: clientstats.NewNoOpStats(), 794 StateSyncNodeIDs: nil, 795 BlockParser: mockBlockParser, 796 }) 797 798 request := message.LeafsRequest{ 799 Root: root, 800 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 801 End: bytes.Repeat([]byte{0xff}, common.HashLength), 802 Limit: defaultLeafRequestLimit, 803 } 804 805 ctx, cancel := context.WithCancel(context.Background()) 806 defer cancel() 807 goodResponse, responseErr := handler.OnLeafsRequest(ctx, ids.GenerateTestNodeID(), 1, request) 808 assert.NoError(t, responseErr) 809 mockNetClient.mockResponse(1, nil, goodResponse) 810 811 res, err := client.GetLeafs(ctx, request) 812 if err != nil { 813 t.Fatal(err) 814 } 815 assert.Equal(t, 1024, len(res.Keys)) 816 assert.Equal(t, 1024, len(res.Vals)) 817 818 // Succeeds within the allotted number of attempts 819 invalidResponse := []byte("invalid response") 820 mockNetClient.mockResponses(nil, invalidResponse, invalidResponse, goodResponse) 821 822 res, err = client.GetLeafs(ctx, request) 823 if err != nil { 824 t.Fatal(err) 825 } 826 assert.Equal(t, 1024, len(res.Keys)) 827 assert.Equal(t, 1024, len(res.Vals)) 828 829 // Test that GetLeafs stops after the context is cancelled 830 numAttempts := 0 831 mockNetClient.mockResponse(maxAttempts, func() { 832 numAttempts++ 833 if numAttempts >= maxAttempts { 834 cancel() 835 } 836 }, invalidResponse) 837 _, err = client.GetLeafs(ctx, request) 838 assert.Error(t, err) 839 assert.True(t, strings.Contains(err.Error(), context.Canceled.Error())) 840 } 841 842 func TestStateSyncNodes(t *testing.T) { 843 mockNetClient := &mockNetwork{} 844 845 stateSyncNodes := []ids.NodeID{ 846 ids.GenerateTestNodeID(), 847 ids.GenerateTestNodeID(), 848 ids.GenerateTestNodeID(), 849 ids.GenerateTestNodeID(), 850 } 851 client := NewClient(&ClientConfig{ 852 NetworkClient: mockNetClient, 853 Codec: message.Codec, 854 Stats: clientstats.NewNoOpStats(), 855 StateSyncNodeIDs: stateSyncNodes, 856 BlockParser: mockBlockParser, 857 }) 858 ctx, cancel := context.WithCancel(context.Background()) 859 defer cancel() 860 attempt := 0 861 responses := [][]byte{{1}, {2}, {3}, {4}} 862 mockNetClient.mockResponses(func() { 863 attempt++ 864 if attempt >= 4 { 865 cancel() 866 } 867 }, responses...) 868 869 // send some request, doesn't matter what it is because we're testing the interaction with state sync nodes here 870 response, err := client.GetLeafs(ctx, message.LeafsRequest{}) 871 assert.Error(t, err) 872 assert.Empty(t, response) 873 874 // assert all nodes were called 875 assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[0]) 876 assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[1]) 877 assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[2]) 878 assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[3]) 879 }