github.com/Bytom/bytom@v1.1.2-0.20210127130405-ae40204c0b09/netsync/block_keeper_test.go (about) 1 package netsync 2 3 import ( 4 "container/list" 5 "encoding/hex" 6 "encoding/json" 7 "testing" 8 "time" 9 10 "github.com/bytom/bytom/consensus" 11 "github.com/bytom/bytom/errors" 12 "github.com/bytom/bytom/protocol/bc" 13 "github.com/bytom/bytom/protocol/bc/types" 14 "github.com/bytom/bytom/test/mock" 15 "github.com/bytom/bytom/testutil" 16 ) 17 18 func TestAppendHeaderList(t *testing.T) { 19 blocks := mockBlocks(nil, 7) 20 cases := []struct { 21 originalHeaders []*types.BlockHeader 22 inputHeaders []*types.BlockHeader 23 wantHeaders []*types.BlockHeader 24 err error 25 }{ 26 { 27 originalHeaders: []*types.BlockHeader{&blocks[0].BlockHeader}, 28 inputHeaders: []*types.BlockHeader{&blocks[1].BlockHeader, &blocks[2].BlockHeader}, 29 wantHeaders: []*types.BlockHeader{&blocks[0].BlockHeader, &blocks[1].BlockHeader, &blocks[2].BlockHeader}, 30 err: nil, 31 }, 32 { 33 originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader}, 34 inputHeaders: []*types.BlockHeader{&blocks[6].BlockHeader}, 35 wantHeaders: []*types.BlockHeader{&blocks[5].BlockHeader, &blocks[6].BlockHeader}, 36 err: nil, 37 }, 38 { 39 originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader}, 40 inputHeaders: []*types.BlockHeader{&blocks[7].BlockHeader}, 41 wantHeaders: []*types.BlockHeader{&blocks[5].BlockHeader}, 42 err: errAppendHeaders, 43 }, 44 { 45 originalHeaders: []*types.BlockHeader{&blocks[5].BlockHeader}, 46 inputHeaders: []*types.BlockHeader{&blocks[7].BlockHeader, &blocks[6].BlockHeader}, 47 wantHeaders: []*types.BlockHeader{&blocks[5].BlockHeader}, 48 err: errAppendHeaders, 49 }, 50 { 51 originalHeaders: []*types.BlockHeader{&blocks[2].BlockHeader}, 52 inputHeaders: []*types.BlockHeader{&blocks[3].BlockHeader, &blocks[4].BlockHeader, &blocks[6].BlockHeader}, 53 wantHeaders: []*types.BlockHeader{&blocks[2].BlockHeader, &blocks[3].BlockHeader, &blocks[4].BlockHeader}, 54 err: errAppendHeaders, 55 }, 56 } 57 58 for i, c := range cases { 59 bk := &blockKeeper{headerList: list.New()} 60 for _, header := range c.originalHeaders { 61 bk.headerList.PushBack(header) 62 } 63 64 if err := bk.appendHeaderList(c.inputHeaders); err != c.err { 65 t.Errorf("case %d: got error %v want error %v", i, err, c.err) 66 } 67 68 gotHeaders := []*types.BlockHeader{} 69 for e := bk.headerList.Front(); e != nil; e = e.Next() { 70 gotHeaders = append(gotHeaders, e.Value.(*types.BlockHeader)) 71 } 72 73 if !testutil.DeepEqual(gotHeaders, c.wantHeaders) { 74 t.Errorf("case %d: got %v want %v", i, gotHeaders, c.wantHeaders) 75 } 76 } 77 } 78 79 func TestBlockLocator(t *testing.T) { 80 blocks := mockBlocks(nil, 500) 81 cases := []struct { 82 bestHeight uint64 83 wantHeight []uint64 84 }{ 85 { 86 bestHeight: 0, 87 wantHeight: []uint64{0}, 88 }, 89 { 90 bestHeight: 1, 91 wantHeight: []uint64{1, 0}, 92 }, 93 { 94 bestHeight: 7, 95 wantHeight: []uint64{7, 6, 5, 4, 3, 2, 1, 0}, 96 }, 97 { 98 bestHeight: 10, 99 wantHeight: []uint64{10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, 100 }, 101 { 102 bestHeight: 100, 103 wantHeight: []uint64{100, 99, 98, 97, 96, 95, 94, 93, 92, 91, 89, 85, 77, 61, 29, 0}, 104 }, 105 { 106 bestHeight: 500, 107 wantHeight: []uint64{500, 499, 498, 497, 496, 495, 494, 493, 492, 491, 489, 485, 477, 461, 429, 365, 237, 0}, 108 }, 109 } 110 111 for i, c := range cases { 112 mockChain := mock.NewChain() 113 bk := &blockKeeper{chain: mockChain} 114 mockChain.SetBestBlockHeader(&blocks[c.bestHeight].BlockHeader) 115 for i := uint64(0); i <= c.bestHeight; i++ { 116 mockChain.SetBlockByHeight(i, blocks[i]) 117 } 118 119 want := []*bc.Hash{} 120 for _, i := range c.wantHeight { 121 hash := blocks[i].Hash() 122 want = append(want, &hash) 123 } 124 125 if got := bk.blockLocator(); !testutil.DeepEqual(got, want) { 126 t.Errorf("case %d: got %v want %v", i, got, want) 127 } 128 } 129 } 130 131 func TestFastBlockSync(t *testing.T) { 132 maxBlockPerMsg = 5 133 maxBlockHeadersPerMsg = 10 134 baseChain := mockBlocks(nil, 300) 135 136 cases := []struct { 137 syncTimeout time.Duration 138 aBlocks []*types.Block 139 bBlocks []*types.Block 140 checkPoint *consensus.Checkpoint 141 want []*types.Block 142 err error 143 }{ 144 { 145 syncTimeout: 30 * time.Second, 146 aBlocks: baseChain[:100], 147 bBlocks: baseChain[:301], 148 checkPoint: &consensus.Checkpoint{ 149 Height: baseChain[250].Height, 150 Hash: baseChain[250].Hash(), 151 }, 152 want: baseChain[:251], 153 err: nil, 154 }, 155 { 156 syncTimeout: 30 * time.Second, 157 aBlocks: baseChain[:100], 158 bBlocks: baseChain[:301], 159 checkPoint: &consensus.Checkpoint{ 160 Height: baseChain[100].Height, 161 Hash: baseChain[100].Hash(), 162 }, 163 want: baseChain[:101], 164 err: nil, 165 }, 166 { 167 syncTimeout: 1 * time.Millisecond, 168 aBlocks: baseChain[:100], 169 bBlocks: baseChain[:100], 170 checkPoint: &consensus.Checkpoint{ 171 Height: baseChain[200].Height, 172 Hash: baseChain[200].Hash(), 173 }, 174 want: baseChain[:100], 175 err: errRequestTimeout, 176 }, 177 } 178 179 for i, c := range cases { 180 syncTimeout = c.syncTimeout 181 a := mockSync(c.aBlocks) 182 b := mockSync(c.bBlocks) 183 netWork := NewNetWork() 184 netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode) 185 netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode) 186 if B2A, A2B, err := netWork.HandsShake(a, b); err != nil { 187 t.Errorf("fail on peer hands shake %v", err) 188 } else { 189 go B2A.postMan() 190 go A2B.postMan() 191 } 192 193 a.blockKeeper.syncPeer = a.peers.getPeer("test node B") 194 if err := a.blockKeeper.fastBlockSync(c.checkPoint); errors.Root(err) != c.err { 195 t.Errorf("case %d: got %v want %v", i, err, c.err) 196 } 197 198 got := []*types.Block{} 199 for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ { 200 block, err := a.chain.GetBlockByHeight(i) 201 if err != nil { 202 t.Errorf("case %d got err %v", i, err) 203 } 204 got = append(got, block) 205 } 206 207 if !testutil.DeepEqual(got, c.want) { 208 t.Errorf("case %d: got %v want %v", i, got, c.want) 209 } 210 } 211 } 212 213 func TestLocateBlocks(t *testing.T) { 214 maxBlockPerMsg = 5 215 blocks := mockBlocks(nil, 100) 216 cases := []struct { 217 locator []uint64 218 stopHash bc.Hash 219 wantHeight []uint64 220 }{ 221 { 222 locator: []uint64{20}, 223 stopHash: blocks[100].Hash(), 224 wantHeight: []uint64{21, 22, 23, 24, 25}, 225 }, 226 } 227 228 mockChain := mock.NewChain() 229 bk := &blockKeeper{chain: mockChain} 230 for _, block := range blocks { 231 mockChain.SetBlockByHeight(block.Height, block) 232 } 233 234 for i, c := range cases { 235 locator := []*bc.Hash{} 236 for _, i := range c.locator { 237 hash := blocks[i].Hash() 238 locator = append(locator, &hash) 239 } 240 241 want := []*types.Block{} 242 for _, i := range c.wantHeight { 243 want = append(want, blocks[i]) 244 } 245 246 got, _ := bk.locateBlocks(locator, &c.stopHash) 247 if !testutil.DeepEqual(got, want) { 248 t.Errorf("case %d: got %v want %v", i, got, want) 249 } 250 } 251 } 252 253 func TestLocateHeaders(t *testing.T) { 254 maxBlockHeadersPerMsg = 10 255 blocks := mockBlocks(nil, 150) 256 cases := []struct { 257 chainHeight uint64 258 locator []uint64 259 stopHash bc.Hash 260 wantHeight []uint64 261 err bool 262 }{ 263 { 264 chainHeight: 100, 265 locator: []uint64{}, 266 stopHash: blocks[100].Hash(), 267 wantHeight: []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, 268 err: false, 269 }, 270 { 271 chainHeight: 100, 272 locator: []uint64{20}, 273 stopHash: blocks[100].Hash(), 274 wantHeight: []uint64{21, 22, 23, 24, 25, 26, 27, 28, 29, 30}, 275 err: false, 276 }, 277 { 278 chainHeight: 100, 279 locator: []uint64{20}, 280 stopHash: blocks[24].Hash(), 281 wantHeight: []uint64{21, 22, 23, 24}, 282 err: false, 283 }, 284 { 285 chainHeight: 100, 286 locator: []uint64{20}, 287 stopHash: blocks[20].Hash(), 288 wantHeight: []uint64{}, 289 err: false, 290 }, 291 { 292 chainHeight: 100, 293 locator: []uint64{20}, 294 stopHash: bc.Hash{}, 295 wantHeight: []uint64{}, 296 err: true, 297 }, 298 { 299 chainHeight: 100, 300 locator: []uint64{120, 70}, 301 stopHash: blocks[78].Hash(), 302 wantHeight: []uint64{71, 72, 73, 74, 75, 76, 77, 78}, 303 err: false, 304 }, 305 } 306 307 for i, c := range cases { 308 mockChain := mock.NewChain() 309 bk := &blockKeeper{chain: mockChain} 310 for i := uint64(0); i <= c.chainHeight; i++ { 311 mockChain.SetBlockByHeight(i, blocks[i]) 312 } 313 314 locator := []*bc.Hash{} 315 for _, i := range c.locator { 316 hash := blocks[i].Hash() 317 locator = append(locator, &hash) 318 } 319 320 want := []*types.BlockHeader{} 321 for _, i := range c.wantHeight { 322 want = append(want, &blocks[i].BlockHeader) 323 } 324 325 got, err := bk.locateHeaders(locator, &c.stopHash) 326 if err != nil != c.err { 327 t.Errorf("case %d: got %v want err = %v", i, err, c.err) 328 } 329 if !testutil.DeepEqual(got, want) { 330 t.Errorf("case %d: got %v want %v", i, got, want) 331 } 332 } 333 } 334 335 func TestNextCheckpoint(t *testing.T) { 336 cases := []struct { 337 checkPoints []consensus.Checkpoint 338 bestHeight uint64 339 want *consensus.Checkpoint 340 }{ 341 { 342 checkPoints: []consensus.Checkpoint{}, 343 bestHeight: 5000, 344 want: nil, 345 }, 346 { 347 checkPoints: []consensus.Checkpoint{ 348 {10000, bc.Hash{V0: 1}}, 349 }, 350 bestHeight: 5000, 351 want: &consensus.Checkpoint{10000, bc.Hash{V0: 1}}, 352 }, 353 { 354 checkPoints: []consensus.Checkpoint{ 355 {10000, bc.Hash{V0: 1}}, 356 {20000, bc.Hash{V0: 2}}, 357 {30000, bc.Hash{V0: 3}}, 358 }, 359 bestHeight: 15000, 360 want: &consensus.Checkpoint{20000, bc.Hash{V0: 2}}, 361 }, 362 { 363 checkPoints: []consensus.Checkpoint{ 364 {10000, bc.Hash{V0: 1}}, 365 {20000, bc.Hash{V0: 2}}, 366 {30000, bc.Hash{V0: 3}}, 367 }, 368 bestHeight: 10000, 369 want: &consensus.Checkpoint{20000, bc.Hash{V0: 2}}, 370 }, 371 { 372 checkPoints: []consensus.Checkpoint{ 373 {10000, bc.Hash{V0: 1}}, 374 {20000, bc.Hash{V0: 2}}, 375 {30000, bc.Hash{V0: 3}}, 376 }, 377 bestHeight: 35000, 378 want: nil, 379 }, 380 } 381 382 mockChain := mock.NewChain() 383 for i, c := range cases { 384 consensus.ActiveNetParams.Checkpoints = c.checkPoints 385 mockChain.SetBestBlockHeader(&types.BlockHeader{Height: c.bestHeight}) 386 bk := &blockKeeper{chain: mockChain} 387 388 if got := bk.nextCheckpoint(); !testutil.DeepEqual(got, c.want) { 389 t.Errorf("case %d: got %v want %v", i, got, c.want) 390 } 391 } 392 } 393 394 func TestRegularBlockSync(t *testing.T) { 395 baseChain := mockBlocks(nil, 50) 396 chainX := append(baseChain, mockBlocks(baseChain[50], 60)...) 397 chainY := append(baseChain, mockBlocks(baseChain[50], 70)...) 398 cases := []struct { 399 syncTimeout time.Duration 400 aBlocks []*types.Block 401 bBlocks []*types.Block 402 syncHeight uint64 403 want []*types.Block 404 err error 405 }{ 406 { 407 syncTimeout: 30 * time.Second, 408 aBlocks: baseChain[:20], 409 bBlocks: baseChain[:50], 410 syncHeight: 45, 411 want: baseChain[:46], 412 err: nil, 413 }, 414 { 415 syncTimeout: 30 * time.Second, 416 aBlocks: chainX, 417 bBlocks: chainY, 418 syncHeight: 70, 419 want: chainY, 420 err: nil, 421 }, 422 { 423 syncTimeout: 30 * time.Second, 424 aBlocks: chainX[:52], 425 bBlocks: chainY[:53], 426 syncHeight: 52, 427 want: chainY[:53], 428 err: nil, 429 }, 430 { 431 syncTimeout: 1 * time.Millisecond, 432 aBlocks: baseChain, 433 bBlocks: baseChain, 434 syncHeight: 52, 435 want: baseChain, 436 err: errRequestTimeout, 437 }, 438 } 439 440 for i, c := range cases { 441 syncTimeout = c.syncTimeout 442 a := mockSync(c.aBlocks) 443 b := mockSync(c.bBlocks) 444 netWork := NewNetWork() 445 netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode) 446 netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode) 447 if B2A, A2B, err := netWork.HandsShake(a, b); err != nil { 448 t.Errorf("fail on peer hands shake %v", err) 449 } else { 450 go B2A.postMan() 451 go A2B.postMan() 452 } 453 454 a.blockKeeper.syncPeer = a.peers.getPeer("test node B") 455 if err := a.blockKeeper.regularBlockSync(c.syncHeight); errors.Root(err) != c.err { 456 t.Errorf("case %d: got %v want %v", i, err, c.err) 457 } 458 459 got := []*types.Block{} 460 for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ { 461 block, err := a.chain.GetBlockByHeight(i) 462 if err != nil { 463 t.Errorf("case %d got err %v", i, err) 464 } 465 got = append(got, block) 466 } 467 468 if !testutil.DeepEqual(got, c.want) { 469 t.Errorf("case %d: got %v want %v", i, got, c.want) 470 } 471 } 472 } 473 474 func TestRequireBlock(t *testing.T) { 475 blocks := mockBlocks(nil, 5) 476 a := mockSync(blocks[:1]) 477 b := mockSync(blocks[:5]) 478 netWork := NewNetWork() 479 netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode) 480 netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode) 481 if B2A, A2B, err := netWork.HandsShake(a, b); err != nil { 482 t.Errorf("fail on peer hands shake %v", err) 483 } else { 484 go B2A.postMan() 485 go A2B.postMan() 486 } 487 488 a.blockKeeper.syncPeer = a.peers.getPeer("test node B") 489 b.blockKeeper.syncPeer = b.peers.getPeer("test node A") 490 cases := []struct { 491 syncTimeout time.Duration 492 testNode *SyncManager 493 requireHeight uint64 494 want *types.Block 495 err error 496 }{ 497 { 498 syncTimeout: 30 * time.Second, 499 testNode: a, 500 requireHeight: 4, 501 want: blocks[4], 502 err: nil, 503 }, 504 { 505 syncTimeout: 1 * time.Millisecond, 506 testNode: b, 507 requireHeight: 4, 508 want: nil, 509 err: errRequestTimeout, 510 }, 511 } 512 513 for i, c := range cases { 514 syncTimeout = c.syncTimeout 515 got, err := c.testNode.blockKeeper.requireBlock(c.requireHeight) 516 if !testutil.DeepEqual(got, c.want) { 517 t.Errorf("case %d: got %v want %v", i, got, c.want) 518 } 519 if errors.Root(err) != c.err { 520 t.Errorf("case %d: got %v want %v", i, err, c.err) 521 } 522 } 523 } 524 525 func TestSendMerkleBlock(t *testing.T) { 526 cases := []struct { 527 txCount int 528 relatedTxIndex []int 529 }{ 530 { 531 txCount: 10, 532 relatedTxIndex: []int{0, 2, 5}, 533 }, 534 { 535 txCount: 0, 536 relatedTxIndex: []int{}, 537 }, 538 { 539 txCount: 10, 540 relatedTxIndex: []int{}, 541 }, 542 { 543 txCount: 5, 544 relatedTxIndex: []int{0, 1, 2, 3, 4}, 545 }, 546 { 547 txCount: 20, 548 relatedTxIndex: []int{1, 6, 3, 9, 10, 19}, 549 }, 550 } 551 552 for _, c := range cases { 553 blocks := mockBlocks(nil, 2) 554 targetBlock := blocks[1] 555 txs, bcTxs := mockTxs(c.txCount) 556 var err error 557 558 targetBlock.Transactions = txs 559 if targetBlock.TransactionsMerkleRoot, err = types.TxMerkleRoot(bcTxs); err != nil { 560 t.Fatal(err) 561 } 562 563 spvNode := mockSync(blocks) 564 blockHash := targetBlock.Hash() 565 var statusResult *bc.TransactionStatus 566 if statusResult, err = spvNode.chain.GetTransactionStatus(&blockHash); err != nil { 567 t.Fatal(err) 568 } 569 570 if targetBlock.TransactionStatusHash, err = types.TxStatusMerkleRoot(statusResult.VerifyStatus); err != nil { 571 t.Fatal(err) 572 } 573 574 fullNode := mockSync(blocks) 575 netWork := NewNetWork() 576 netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync) 577 netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices) 578 579 var F2S *P2PPeer 580 if F2S, _, err = netWork.HandsShake(spvNode, fullNode); err != nil { 581 t.Errorf("fail on peer hands shake %v", err) 582 } 583 584 completed := make(chan error) 585 go func() { 586 msgBytes := <-F2S.msgCh 587 _, msg, _ := DecodeMessage(msgBytes) 588 switch m := msg.(type) { 589 case *MerkleBlockMessage: 590 var relatedTxIDs []*bc.Hash 591 for _, rawTx := range m.RawTxDatas { 592 tx := &types.Tx{} 593 if err := tx.UnmarshalText(rawTx); err != nil { 594 completed <- err 595 } 596 597 relatedTxIDs = append(relatedTxIDs, &tx.ID) 598 } 599 var txHashes []*bc.Hash 600 for _, hashByte := range m.TxHashes { 601 hash := bc.NewHash(hashByte) 602 txHashes = append(txHashes, &hash) 603 } 604 if ok := types.ValidateTxMerkleTreeProof(txHashes, m.Flags, relatedTxIDs, targetBlock.TransactionsMerkleRoot); !ok { 605 completed <- errors.New("validate tx fail") 606 } 607 608 var statusHashes []*bc.Hash 609 for _, statusByte := range m.StatusHashes { 610 hash := bc.NewHash(statusByte) 611 statusHashes = append(statusHashes, &hash) 612 } 613 var relatedStatuses []*bc.TxVerifyResult 614 for _, statusByte := range m.RawTxStatuses { 615 status := &bc.TxVerifyResult{} 616 err := json.Unmarshal(statusByte, status) 617 if err != nil { 618 completed <- err 619 } 620 relatedStatuses = append(relatedStatuses, status) 621 } 622 if ok := types.ValidateStatusMerkleTreeProof(statusHashes, m.Flags, relatedStatuses, targetBlock.TransactionStatusHash); !ok { 623 completed <- errors.New("validate status fail") 624 } 625 626 completed <- nil 627 } 628 }() 629 630 spvPeer := fullNode.peers.getPeer("spv_node") 631 for i := 0; i < len(c.relatedTxIndex); i++ { 632 spvPeer.filterAdds.Add(hex.EncodeToString(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram)) 633 } 634 msg := &GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()} 635 fullNode.handleGetMerkleBlockMsg(spvPeer, msg) 636 if err := <-completed; err != nil { 637 t.Fatal(err) 638 } 639 } 640 }