github.com/dim4egster/coreth@v0.10.2/sync/client/client_test.go (about) 1 // (c) 2021-2022, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package statesyncclient 5 6 import ( 7 "bytes" 8 "context" 9 "fmt" 10 "math/rand" 11 "strings" 12 "testing" 13 14 "github.com/stretchr/testify/assert" 15 16 "github.com/dim4egster/qmallgo/ids" 17 18 "github.com/dim4egster/coreth/consensus/dummy" 19 "github.com/dim4egster/coreth/core" 20 "github.com/dim4egster/coreth/core/types" 21 "github.com/dim4egster/coreth/ethdb/memorydb" 22 "github.com/dim4egster/coreth/params" 23 "github.com/dim4egster/coreth/plugin/evm/message" 24 clientstats "github.com/dim4egster/coreth/sync/client/stats" 25 "github.com/dim4egster/coreth/sync/handlers" 26 handlerstats "github.com/dim4egster/coreth/sync/handlers/stats" 27 "github.com/dim4egster/coreth/trie" 28 "github.com/ethereum/go-ethereum/common" 29 "github.com/ethereum/go-ethereum/crypto" 30 ) 31 32 func TestGetCode(t *testing.T) { 33 mockNetClient := &mockNetwork{} 34 35 tests := map[string]struct { 36 setupRequest func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) 37 expectedErr error 38 }{ 39 "normal": { 40 setupRequest: func() ([]common.Hash, message.CodeResponse, [][]byte) { 41 code := []byte("this is the code") 42 codeHash := crypto.Keccak256Hash(code) 43 codeSlices := [][]byte{code} 44 return []common.Hash{codeHash}, message.CodeResponse{ 45 Data: codeSlices, 46 }, codeSlices 47 }, 48 expectedErr: nil, 49 }, 50 "unexpected code bytes": { 51 setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) { 52 return []common.Hash{{1}}, message.CodeResponse{ 53 Data: [][]byte{{1}}, 54 }, nil 55 }, 56 expectedErr: errHashMismatch, 57 }, 58 "too many code elements returned": { 59 setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) { 60 return []common.Hash{{1}}, message.CodeResponse{ 61 Data: [][]byte{{1}, {2}}, 62 }, nil 63 }, 64 expectedErr: errInvalidCodeResponseLen, 65 }, 66 "too few code elements returned": { 67 setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) { 68 return []common.Hash{{1}}, message.CodeResponse{ 69 Data: [][]byte{}, 70 }, nil 71 }, 72 expectedErr: errInvalidCodeResponseLen, 73 }, 74 "code size is too large": { 75 setupRequest: func() (requestHashes []common.Hash, mockResponse message.CodeResponse, expectedCode [][]byte) { 76 oversizedCode := make([]byte, params.MaxCodeSize+1) 77 codeHash := crypto.Keccak256Hash(oversizedCode) 78 return []common.Hash{codeHash}, message.CodeResponse{ 79 Data: [][]byte{oversizedCode}, 80 }, nil 81 }, 82 expectedErr: errMaxCodeSizeExceeded, 83 }, 84 } 85 86 stateSyncClient := NewClient(&ClientConfig{ 87 NetworkClient: mockNetClient, 88 Codec: message.Codec, 89 Stats: clientstats.NewNoOpStats(), 90 StateSyncNodeIDs: nil, 91 BlockParser: mockBlockParser, 92 }) 93 94 for name, test := range tests { 95 t.Run(name, func(t *testing.T) { 96 ctx, cancel := context.WithCancel(context.Background()) 97 defer cancel() 98 codeHashes, res, expectedCode := test.setupRequest() 99 100 responseBytes, err := message.Codec.Marshal(message.Version, res) 101 if err != nil { 102 t.Fatal(err) 103 } 104 // Dirty hack required because the client will re-request if it encounters 105 // an error. 106 attempted := false 107 if test.expectedErr == nil { 108 mockNetClient.mockResponse(1, nil, responseBytes) 109 } else { 110 mockNetClient.mockResponse(2, func() { 111 // Cancel before the second attempt is processed. 112 if attempted { 113 cancel() 114 } 115 attempted = true 116 }, responseBytes) 117 } 118 119 codeBytes, err := stateSyncClient.GetCode(ctx, codeHashes) 120 // If we expect an error, assert that one occurred and return 121 if test.expectedErr != nil { 122 assert.ErrorIs(t, err, test.expectedErr) 123 assert.EqualValues(t, 2, mockNetClient.numCalls) 124 return 125 } 126 // Otherwise, assert there was no error and that the result is as expected 127 assert.NoError(t, err) 128 assert.Equal(t, len(codeBytes), len(expectedCode)) 129 for i, code := range codeBytes { 130 assert.Equal(t, expectedCode[i], code) 131 } 132 assert.Equal(t, uint(1), mockNetClient.numCalls) 133 }) 134 } 135 } 136 137 func TestGetBlocks(t *testing.T) { 138 // set random seed for deterministic tests 139 rand.Seed(1) 140 141 var gspec = &core.Genesis{ 142 Config: params.TestChainConfig, 143 } 144 memdb := memorydb.New() 145 genesis := gspec.MustCommit(memdb) 146 engine := dummy.NewETHFaker() 147 numBlocks := 110 148 blocks, _, err := core.GenerateChain(params.TestChainConfig, genesis, engine, memdb, numBlocks, 0, func(i int, b *core.BlockGen) {}) 149 if err != nil { 150 t.Fatal("unexpected error when generating test blockchain", err) 151 } 152 assert.Equal(t, numBlocks, len(blocks)) 153 154 // Construct client 155 mockNetClient := &mockNetwork{} 156 stateSyncClient := NewClient(&ClientConfig{ 157 NetworkClient: mockNetClient, 158 Codec: message.Codec, 159 Stats: clientstats.NewNoOpStats(), 160 StateSyncNodeIDs: nil, 161 BlockParser: mockBlockParser, 162 }) 163 164 blocksRequestHandler := handlers.NewBlockRequestHandler(buildGetter(blocks), message.Codec, handlerstats.NewNoopHandlerStats()) 165 166 // encodeBlockSlice takes a slice of blocks that are ordered in increasing height order 167 // and returns a slice of byte slices with those blocks encoded in reverse order 168 encodeBlockSlice := func(blocks []*types.Block) [][]byte { 169 blockBytes := make([][]byte, 0, len(blocks)) 170 for i := len(blocks) - 1; i >= 0; i-- { 171 buf := new(bytes.Buffer) 172 if err := blocks[i].EncodeRLP(buf); err != nil { 173 t.Fatalf("failed to generate expected response %s", err) 174 } 175 blockBytes = append(blockBytes, buf.Bytes()) 176 } 177 178 return blockBytes 179 } 180 tests := map[string]struct { 181 request message.BlockRequest 182 getResponse func(t *testing.T, request message.BlockRequest) []byte 183 assertResponse func(t *testing.T, response []*types.Block) 184 expectedErr string 185 }{ 186 "normal resonse": { 187 request: message.BlockRequest{ 188 Hash: blocks[100].Hash(), 189 Height: 100, 190 Parents: 16, 191 }, 192 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 193 response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 194 if err != nil { 195 t.Fatal(err) 196 } 197 198 if len(response) == 0 { 199 t.Fatal("Failed to generate valid response") 200 } 201 202 return response 203 }, 204 assertResponse: func(t *testing.T, response []*types.Block) { 205 assert.Equal(t, 16, len(response)) 206 }, 207 }, 208 "fewer than requested blocks": { 209 request: message.BlockRequest{ 210 Hash: blocks[100].Hash(), 211 Height: 100, 212 Parents: 16, 213 }, 214 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 215 request.Parents -= 5 216 response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 217 if err != nil { 218 t.Fatal(err) 219 } 220 221 if len(response) == 0 { 222 t.Fatal("Failed to generate valid response") 223 } 224 225 return response 226 }, 227 // If the server returns fewer than requested blocks, we should consider it valid 228 assertResponse: func(t *testing.T, response []*types.Block) { 229 assert.Equal(t, 11, len(response)) 230 }, 231 }, 232 "gibberish response": { 233 request: message.BlockRequest{ 234 Hash: blocks[100].Hash(), 235 Height: 100, 236 Parents: 16, 237 }, 238 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 239 return []byte("gibberish") 240 }, 241 expectedErr: errUnmarshalResponse.Error(), 242 }, 243 "invalid value replacing block": { 244 request: message.BlockRequest{ 245 Hash: blocks[100].Hash(), 246 Height: 100, 247 Parents: 16, 248 }, 249 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 250 response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 251 if err != nil { 252 t.Fatalf("failed to get block response: %s", err) 253 } 254 var blockResponse message.BlockResponse 255 if _, err = message.Codec.Unmarshal(response, &blockResponse); err != nil { 256 t.Fatalf("failed to marshal block response: %s", err) 257 } 258 // Replace middle value with garbage data 259 blockResponse.Blocks[10] = []byte("invalid value replacing block bytes") 260 responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) 261 if err != nil { 262 t.Fatalf("failed to marshal block response: %s", err) 263 } 264 265 return responseBytes 266 }, 267 expectedErr: "failed to unmarshal response: rlp: expected input list for types.extblock", 268 }, 269 "incorrect starting point": { 270 request: message.BlockRequest{ 271 Hash: blocks[100].Hash(), 272 Height: 100, 273 Parents: 16, 274 }, 275 getResponse: func(t *testing.T, _ message.BlockRequest) []byte { 276 response, err := blocksRequestHandler.OnBlockRequest(context.Background(), ids.GenerateTestNodeID(), 1, message.BlockRequest{ 277 Hash: blocks[99].Hash(), 278 Height: 99, 279 Parents: 16, 280 }) 281 if err != nil { 282 t.Fatal(err) 283 } 284 285 if len(response) == 0 { 286 t.Fatal("Failed to generate valid response") 287 } 288 289 return response 290 }, 291 expectedErr: errHashMismatch.Error(), 292 }, 293 "missing link in between blocks": { 294 request: message.BlockRequest{ 295 Hash: blocks[100].Hash(), 296 Height: 100, 297 Parents: 16, 298 }, 299 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 300 // Encode blocks with a missing link 301 blks := make([]*types.Block, 0) 302 blks = append(blks, blocks[84:89]...) 303 blks = append(blks, blocks[90:101]...) 304 blockBytes := encodeBlockSlice(blks) 305 306 blockResponse := message.BlockResponse{ 307 Blocks: blockBytes, 308 } 309 responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) 310 if err != nil { 311 t.Fatalf("failed to marshal block response: %s", err) 312 } 313 314 return responseBytes 315 }, 316 expectedErr: errHashMismatch.Error(), 317 }, 318 "no blocks": { 319 request: message.BlockRequest{ 320 Hash: blocks[100].Hash(), 321 Height: 100, 322 Parents: 16, 323 }, 324 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 325 blockResponse := message.BlockResponse{ 326 Blocks: nil, 327 } 328 responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) 329 if err != nil { 330 t.Fatalf("failed to marshal block response: %s", err) 331 } 332 333 return responseBytes 334 }, 335 expectedErr: errEmptyResponse.Error(), 336 }, 337 "more than requested blocks": { 338 request: message.BlockRequest{ 339 Hash: blocks[100].Hash(), 340 Height: 100, 341 Parents: 16, 342 }, 343 getResponse: func(t *testing.T, request message.BlockRequest) []byte { 344 blockBytes := encodeBlockSlice(blocks[80:100]) 345 346 blockResponse := message.BlockResponse{ 347 Blocks: blockBytes, 348 } 349 responseBytes, err := message.Codec.Marshal(message.Version, blockResponse) 350 if err != nil { 351 t.Fatalf("failed to marshal block response: %s", err) 352 } 353 354 return responseBytes 355 }, 356 expectedErr: errTooManyBlocks.Error(), 357 }, 358 } 359 for name, test := range tests { 360 t.Run(name, func(t *testing.T) { 361 ctx, cancel := context.WithCancel(context.Background()) 362 defer cancel() 363 364 responseBytes := test.getResponse(t, test.request) 365 if len(test.expectedErr) == 0 { 366 mockNetClient.mockResponse(1, nil, responseBytes) 367 } else { 368 attempted := false 369 mockNetClient.mockResponse(2, func() { 370 if attempted { 371 cancel() 372 } 373 attempted = true 374 }, responseBytes) 375 } 376 377 blockResponse, err := stateSyncClient.GetBlocks(ctx, test.request.Hash, test.request.Height, test.request.Parents) 378 if len(test.expectedErr) != 0 { 379 if err == nil { 380 t.Fatalf("Expected error: %s, but found no error", test.expectedErr) 381 } 382 assert.True(t, strings.Contains(err.Error(), test.expectedErr), "expected error to contain [%s], but found [%s]", test.expectedErr, err) 383 return 384 } 385 if err != nil { 386 t.Fatal(err) 387 } 388 389 test.assertResponse(t, blockResponse) 390 }) 391 } 392 } 393 394 func buildGetter(blocks []*types.Block) handlers.BlockProvider { 395 return &handlers.TestBlockProvider{ 396 GetBlockFn: func(blockHash common.Hash, blockHeight uint64) *types.Block { 397 requestedBlock := blocks[blockHeight] 398 if requestedBlock.Hash() != blockHash { 399 fmt.Printf("ERROR height=%d, hash=%s, parentHash=%s, reqHash=%s\n", blockHeight, blockHash, requestedBlock.ParentHash(), requestedBlock.Hash()) 400 return nil 401 } 402 return requestedBlock 403 }, 404 } 405 } 406 407 func TestGetLeafs(t *testing.T) { 408 rand.Seed(1) 409 410 const leafsLimit = 1024 411 412 trieDB := trie.NewDatabase(memorydb.New()) 413 largeTrieRoot, largeTrieKeys, _ := trie.GenerateTrie(t, trieDB, 100_000, common.HashLength) 414 smallTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, leafsLimit, common.HashLength) 415 416 handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats()) 417 client := NewClient(&ClientConfig{ 418 NetworkClient: &mockNetwork{}, 419 Codec: message.Codec, 420 Stats: clientstats.NewNoOpStats(), 421 StateSyncNodeIDs: nil, 422 BlockParser: mockBlockParser, 423 }) 424 425 tests := map[string]struct { 426 request message.LeafsRequest 427 getResponse func(t *testing.T, request message.LeafsRequest) []byte 428 assertResponse func(t *testing.T, response message.LeafsResponse) 429 expectedErr error 430 }{ 431 "full response for small (single request) trie": { 432 request: message.LeafsRequest{ 433 Root: smallTrieRoot, 434 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 435 End: bytes.Repeat([]byte{0xff}, common.HashLength), 436 Limit: leafsLimit, 437 NodeType: message.StateTrieNode, 438 }, 439 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 440 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 441 if err != nil { 442 t.Fatal("unexpected error in calling leafs request handler", err) 443 } 444 if len(response) == 0 { 445 t.Fatal("Failed to create valid response") 446 } 447 448 return response 449 }, 450 assertResponse: func(t *testing.T, response message.LeafsResponse) { 451 assert.False(t, response.More) 452 assert.Equal(t, leafsLimit, len(response.Keys)) 453 assert.Equal(t, leafsLimit, len(response.Vals)) 454 }, 455 }, 456 "too many leaves in response": { 457 request: message.LeafsRequest{ 458 Root: smallTrieRoot, 459 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 460 End: bytes.Repeat([]byte{0xff}, common.HashLength), 461 Limit: leafsLimit / 2, 462 NodeType: message.StateTrieNode, 463 }, 464 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 465 modifiedRequest := request 466 modifiedRequest.Limit = leafsLimit 467 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, modifiedRequest) 468 if err != nil { 469 t.Fatal("unexpected error in calling leafs request handler", err) 470 } 471 if len(response) == 0 { 472 t.Fatal("Failed to create valid response") 473 } 474 475 return response 476 }, 477 expectedErr: errTooManyLeaves, 478 }, 479 "partial response to request for entire trie (full leaf limit)": { 480 request: message.LeafsRequest{ 481 Root: largeTrieRoot, 482 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 483 End: bytes.Repeat([]byte{0xff}, common.HashLength), 484 Limit: leafsLimit, 485 NodeType: message.StateTrieNode, 486 }, 487 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 488 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 489 if err != nil { 490 t.Fatal("unexpected error in calling leafs request handler", err) 491 } 492 if len(response) == 0 { 493 t.Fatal("Failed to create valid response") 494 } 495 496 return response 497 }, 498 assertResponse: func(t *testing.T, response message.LeafsResponse) { 499 assert.True(t, response.More) 500 assert.Equal(t, leafsLimit, len(response.Keys)) 501 assert.Equal(t, leafsLimit, len(response.Vals)) 502 }, 503 }, 504 "partial response to request for middle range of trie (full leaf limit)": { 505 request: message.LeafsRequest{ 506 Root: largeTrieRoot, 507 Start: largeTrieKeys[1000], 508 End: largeTrieKeys[99000], 509 Limit: leafsLimit, 510 NodeType: message.StateTrieNode, 511 }, 512 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 513 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 514 if err != nil { 515 t.Fatal("unexpected error in calling leafs request handler", err) 516 } 517 518 if len(response) == 0 { 519 t.Fatal("Failed to create valid response") 520 } 521 return response 522 }, 523 assertResponse: func(t *testing.T, response message.LeafsResponse) { 524 assert.True(t, response.More) 525 assert.Equal(t, leafsLimit, len(response.Keys)) 526 assert.Equal(t, leafsLimit, len(response.Vals)) 527 }, 528 }, 529 "full response from near end of trie to end of trie (less than leaf limit)": { 530 request: message.LeafsRequest{ 531 Root: largeTrieRoot, 532 Start: largeTrieKeys[len(largeTrieKeys)-30], // Set start 30 keys from the end of the large trie 533 End: bytes.Repeat([]byte{0xff}, common.HashLength), 534 Limit: leafsLimit, 535 NodeType: message.StateTrieNode, 536 }, 537 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 538 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 539 if err != nil { 540 t.Fatal("unexpected error in calling leafs request handler", err) 541 } 542 if len(response) == 0 { 543 t.Fatal("Failed to create valid response") 544 } 545 return response 546 }, 547 assertResponse: func(t *testing.T, response message.LeafsResponse) { 548 assert.False(t, response.More) 549 assert.Equal(t, 30, len(response.Keys)) 550 assert.Equal(t, 30, len(response.Vals)) 551 }, 552 }, 553 "full response for intermediate range of trie (less than leaf limit)": { 554 request: message.LeafsRequest{ 555 Root: largeTrieRoot, 556 Start: largeTrieKeys[1000], // Set the range for 1000 leafs in an intermediate range of the trie 557 End: largeTrieKeys[1099], // (inclusive range) 558 Limit: leafsLimit, 559 NodeType: message.StateTrieNode, 560 }, 561 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 562 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 563 if err != nil { 564 t.Fatal("unexpected error in calling leafs request handler", err) 565 } 566 if len(response) == 0 { 567 t.Fatal("Failed to create valid response") 568 } 569 570 return response 571 }, 572 assertResponse: func(t *testing.T, response message.LeafsResponse) { 573 assert.True(t, response.More) 574 assert.Equal(t, 100, len(response.Keys)) 575 assert.Equal(t, 100, len(response.Vals)) 576 }, 577 }, 578 "removed first key in response": { 579 request: message.LeafsRequest{ 580 Root: largeTrieRoot, 581 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 582 End: bytes.Repeat([]byte{0xff}, common.HashLength), 583 Limit: leafsLimit, 584 NodeType: message.StateTrieNode, 585 }, 586 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 587 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 588 if err != nil { 589 t.Fatal("unexpected error in calling leafs request handler", err) 590 } 591 if len(response) == 0 { 592 t.Fatal("Failed to create valid response") 593 } 594 var leafResponse message.LeafsResponse 595 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 596 t.Fatal(err) 597 } 598 leafResponse.Keys = leafResponse.Keys[1:] 599 leafResponse.Vals = leafResponse.Vals[1:] 600 601 modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) 602 if err != nil { 603 t.Fatal(err) 604 } 605 return modifiedResponse 606 }, 607 expectedErr: errInvalidRangeProof, 608 }, 609 "removed first key in response and replaced proof": { 610 request: message.LeafsRequest{ 611 Root: largeTrieRoot, 612 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 613 End: bytes.Repeat([]byte{0xff}, common.HashLength), 614 Limit: leafsLimit, 615 NodeType: message.StateTrieNode, 616 }, 617 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 618 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 619 if err != nil { 620 t.Fatal("unexpected error in calling leafs request handler", err) 621 } 622 if len(response) == 0 { 623 t.Fatal("Failed to create valid response") 624 } 625 var leafResponse message.LeafsResponse 626 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 627 t.Fatal(err) 628 } 629 modifiedRequest := request 630 modifiedRequest.Start = leafResponse.Keys[1] 631 modifiedResponse, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 2, modifiedRequest) 632 if err != nil { 633 t.Fatal("unexpected error in calling leafs request handler", err) 634 } 635 return modifiedResponse 636 }, 637 expectedErr: errInvalidRangeProof, 638 }, 639 "removed last key in response": { 640 request: message.LeafsRequest{ 641 Root: largeTrieRoot, 642 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 643 End: bytes.Repeat([]byte{0xff}, common.HashLength), 644 Limit: leafsLimit, 645 NodeType: message.StateTrieNode, 646 }, 647 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 648 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 649 if err != nil { 650 t.Fatal("unexpected error in calling leafs request handler", err) 651 } 652 if len(response) == 0 { 653 t.Fatal("Failed to create valid response") 654 } 655 var leafResponse message.LeafsResponse 656 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 657 t.Fatal(err) 658 } 659 leafResponse.Keys = leafResponse.Keys[:len(leafResponse.Keys)-2] 660 leafResponse.Vals = leafResponse.Vals[:len(leafResponse.Vals)-2] 661 662 modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) 663 if err != nil { 664 t.Fatal(err) 665 } 666 return modifiedResponse 667 }, 668 expectedErr: errInvalidRangeProof, 669 }, 670 "removed key from middle of response": { 671 request: message.LeafsRequest{ 672 Root: largeTrieRoot, 673 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 674 End: bytes.Repeat([]byte{0xff}, common.HashLength), 675 Limit: leafsLimit, 676 NodeType: message.StateTrieNode, 677 }, 678 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 679 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 680 if err != nil { 681 t.Fatal("unexpected error in calling leafs request handler", err) 682 } 683 if len(response) == 0 { 684 t.Fatal("Failed to create valid response") 685 } 686 var leafResponse message.LeafsResponse 687 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 688 t.Fatal(err) 689 } 690 // Remove middle key-value pair response 691 leafResponse.Keys = append(leafResponse.Keys[:100], leafResponse.Keys[101:]...) 692 leafResponse.Vals = append(leafResponse.Vals[:100], leafResponse.Vals[101:]...) 693 694 modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) 695 if err != nil { 696 t.Fatal(err) 697 } 698 return modifiedResponse 699 }, 700 expectedErr: errInvalidRangeProof, 701 }, 702 "corrupted value in middle of response": { 703 request: message.LeafsRequest{ 704 Root: largeTrieRoot, 705 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 706 End: bytes.Repeat([]byte{0xff}, common.HashLength), 707 Limit: leafsLimit, 708 NodeType: message.StateTrieNode, 709 }, 710 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 711 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 712 if err != nil { 713 t.Fatal("unexpected error in calling leafs request handler", err) 714 } 715 if len(response) == 0 { 716 t.Fatal("Failed to create valid response") 717 } 718 var leafResponse message.LeafsResponse 719 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 720 t.Fatal(err) 721 } 722 // Remove middle key-value pair response 723 leafResponse.Vals[100] = []byte("garbage value data") 724 725 modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) 726 if err != nil { 727 t.Fatal(err) 728 } 729 return modifiedResponse 730 }, 731 expectedErr: errInvalidRangeProof, 732 }, 733 "all proof keys removed from response": { 734 request: message.LeafsRequest{ 735 Root: largeTrieRoot, 736 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 737 End: bytes.Repeat([]byte{0xff}, common.HashLength), 738 Limit: leafsLimit, 739 NodeType: message.StateTrieNode, 740 }, 741 getResponse: func(t *testing.T, request message.LeafsRequest) []byte { 742 response, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 1, request) 743 if err != nil { 744 t.Fatal("unexpected error in calling leafs request handler", err) 745 } 746 if len(response) == 0 { 747 t.Fatal("Failed to create valid response") 748 } 749 750 var leafResponse message.LeafsResponse 751 if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { 752 t.Fatal(err) 753 } 754 // Remove the proof 755 leafResponse.ProofVals = nil 756 757 modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) 758 if err != nil { 759 t.Fatal(err) 760 } 761 return modifiedResponse 762 }, 763 expectedErr: errInvalidRangeProof, 764 }, 765 } 766 for name, test := range tests { 767 t.Run(name, func(t *testing.T) { 768 responseBytes := test.getResponse(t, test.request) 769 770 response, _, err := parseLeafsResponse(client.codec, test.request, responseBytes) 771 if test.expectedErr != nil { 772 if err == nil { 773 t.Fatalf("Expected error: %s, but found no error", test.expectedErr) 774 } 775 assert.True(t, strings.Contains(err.Error(), test.expectedErr.Error())) 776 return 777 } 778 779 if err != nil { 780 t.Fatal(err) 781 } 782 783 leafsResponse, ok := response.(message.LeafsResponse) 784 if !ok { 785 t.Fatalf("parseLeafsResponse returned incorrect type %T", response) 786 } 787 test.assertResponse(t, leafsResponse) 788 }) 789 } 790 } 791 792 func TestGetLeafsRetries(t *testing.T) { 793 rand.Seed(1) 794 795 trieDB := trie.NewDatabase(memorydb.New()) 796 root, _, _ := trie.GenerateTrie(t, trieDB, 100_000, common.HashLength) 797 798 handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats()) 799 mockNetClient := &mockNetwork{} 800 801 const maxAttempts = 8 802 client := NewClient(&ClientConfig{ 803 NetworkClient: mockNetClient, 804 Codec: message.Codec, 805 Stats: clientstats.NewNoOpStats(), 806 StateSyncNodeIDs: nil, 807 BlockParser: mockBlockParser, 808 }) 809 810 request := message.LeafsRequest{ 811 Root: root, 812 Start: bytes.Repeat([]byte{0x00}, common.HashLength), 813 End: bytes.Repeat([]byte{0xff}, common.HashLength), 814 Limit: defaultLeafRequestLimit, 815 NodeType: message.StateTrieNode, 816 } 817 818 ctx, cancel := context.WithCancel(context.Background()) 819 defer cancel() 820 goodResponse, responseErr := handler.OnLeafsRequest(ctx, ids.GenerateTestNodeID(), 1, request) 821 assert.NoError(t, responseErr) 822 mockNetClient.mockResponse(1, nil, goodResponse) 823 824 res, err := client.GetLeafs(ctx, request) 825 if err != nil { 826 t.Fatal(err) 827 } 828 assert.Equal(t, 1024, len(res.Keys)) 829 assert.Equal(t, 1024, len(res.Vals)) 830 831 // Succeeds within the allotted number of attempts 832 invalidResponse := []byte("invalid response") 833 mockNetClient.mockResponses(nil, invalidResponse, invalidResponse, goodResponse) 834 835 res, err = client.GetLeafs(ctx, request) 836 if err != nil { 837 t.Fatal(err) 838 } 839 assert.Equal(t, 1024, len(res.Keys)) 840 assert.Equal(t, 1024, len(res.Vals)) 841 842 // Test that GetLeafs stops after the context is cancelled 843 numAttempts := 0 844 mockNetClient.mockResponse(maxAttempts, func() { 845 numAttempts++ 846 if numAttempts >= maxAttempts { 847 cancel() 848 } 849 }, invalidResponse) 850 _, err = client.GetLeafs(ctx, request) 851 assert.Error(t, err) 852 assert.True(t, strings.Contains(err.Error(), context.Canceled.Error())) 853 } 854 855 func TestStateSyncNodes(t *testing.T) { 856 mockNetClient := &mockNetwork{} 857 858 stateSyncNodes := []ids.NodeID{ 859 ids.GenerateTestNodeID(), 860 ids.GenerateTestNodeID(), 861 ids.GenerateTestNodeID(), 862 ids.GenerateTestNodeID(), 863 } 864 client := NewClient(&ClientConfig{ 865 NetworkClient: mockNetClient, 866 Codec: message.Codec, 867 Stats: clientstats.NewNoOpStats(), 868 StateSyncNodeIDs: stateSyncNodes, 869 BlockParser: mockBlockParser, 870 }) 871 ctx, cancel := context.WithCancel(context.Background()) 872 defer cancel() 873 attempt := 0 874 responses := [][]byte{{1}, {2}, {3}, {4}} 875 mockNetClient.mockResponses(func() { 876 attempt++ 877 if attempt >= 4 { 878 cancel() 879 } 880 }, responses...) 881 882 // send some request, doesn't matter what it is because we're testing the interaction with state sync nodes here 883 response, err := client.GetLeafs(ctx, message.LeafsRequest{}) 884 assert.Error(t, err) 885 assert.Empty(t, response) 886 887 // assert all nodes were called 888 assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[0]) 889 assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[1]) 890 assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[2]) 891 assert.Contains(t, mockNetClient.nodesRequested, stateSyncNodes[3]) 892 }