github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/neatptc/fetcher/fetcher_test.go (about) 1 package fetcher 2 3 import ( 4 "errors" 5 "math/big" 6 "sync" 7 "sync/atomic" 8 "testing" 9 "time" 10 11 "github.com/neatlab/neatio/chain/core" 12 "github.com/neatlab/neatio/chain/core/rawdb" 13 "github.com/neatlab/neatio/chain/core/types" 14 "github.com/neatlab/neatio/params" 15 "github.com/neatlab/neatio/utilities/common" 16 "github.com/neatlab/neatio/utilities/crypto" 17 ) 18 19 var ( 20 testdb = rawdb.NewMemoryDatabase() 21 testKey, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") 22 testAddress = crypto.PubkeyToAddress(testKey.PublicKey) 23 genesis = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000)) 24 unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil) 25 ) 26 27 func makeChain(n int, seed byte, parent *types.Block) ([]common.Hash, map[common.Hash]*types.Block) { 28 blocks, _ := core.GenerateChain(params.TestChainConfig, parent, nil, testdb, n, func(i int, block *core.BlockGen) { 29 block.SetCoinbase(common.Address{seed}) 30 31 if parent == genesis && i%3 == 0 { 32 signer := types.MakeSigner(params.TestChainConfig, block.Number()) 33 tx, err := types.SignTx(types.NewTransaction(block.TxNonce(testAddress), common.Address{seed}, big.NewInt(1000), params.TxGas, nil, nil), signer, testKey) 34 if err != nil { 35 panic(err) 36 } 37 block.AddTx(tx) 38 } 39 40 if i%5 == 0 { 41 block.AddUncle(&types.Header{ParentHash: block.PrevBlock(i - 1).Hash(), Number: big.NewInt(int64(i - 1))}) 42 } 43 }) 44 hashes := make([]common.Hash, n+1) 45 hashes[len(hashes)-1] = parent.Hash() 46 blockm := make(map[common.Hash]*types.Block, n+1) 47 blockm[parent.Hash()] = parent 48 for i, b := range blocks { 49 hashes[len(hashes)-i-2] = b.Hash() 50 blockm[b.Hash()] = b 51 } 52 return hashes, blockm 53 } 54 55 type fetcherTester struct { 56 fetcher *Fetcher 57 58 hashes []common.Hash 59 blocks map[common.Hash]*types.Block 60 drops map[string]bool 61 62 lock sync.RWMutex 63 } 64 65 func newTester() *fetcherTester { 66 tester := &fetcherTester{ 67 hashes: []common.Hash{genesis.Hash()}, 68 blocks: map[common.Hash]*types.Block{genesis.Hash(): genesis}, 69 drops: make(map[string]bool), 70 } 71 tester.fetcher = New(tester.getBlock, tester.verifyHeader, tester.broadcastBlock, tester.chainHeight, tester.insertChain, tester.dropPeer) 72 tester.fetcher.Start() 73 74 return tester 75 } 76 77 func (f *fetcherTester) getBlock(hash common.Hash) *types.Block { 78 f.lock.RLock() 79 defer f.lock.RUnlock() 80 81 return f.blocks[hash] 82 } 83 84 func (f *fetcherTester) verifyHeader(header *types.Header) error { 85 return nil 86 } 87 88 func (f *fetcherTester) broadcastBlock(block *types.Block, propagate bool) { 89 } 90 91 func (f *fetcherTester) chainHeight() uint64 { 92 f.lock.RLock() 93 defer f.lock.RUnlock() 94 95 return f.blocks[f.hashes[len(f.hashes)-1]].NumberU64() 96 } 97 98 func (f *fetcherTester) insertChain(blocks types.Blocks) (int, error) { 99 f.lock.Lock() 100 defer f.lock.Unlock() 101 102 for i, block := range blocks { 103 104 if _, ok := f.blocks[block.ParentHash()]; !ok { 105 return i, errors.New("unknown parent") 106 } 107 108 if block.NumberU64() <= f.blocks[f.hashes[len(f.hashes)-1]].NumberU64() { 109 return i, nil 110 } 111 112 f.hashes = append(f.hashes, block.Hash()) 113 f.blocks[block.Hash()] = block 114 } 115 return 0, nil 116 } 117 118 func (f *fetcherTester) dropPeer(peer string) { 119 f.lock.Lock() 120 defer f.lock.Unlock() 121 122 f.drops[peer] = true 123 } 124 125 func (f *fetcherTester) makeHeaderFetcher(peer string, blocks map[common.Hash]*types.Block, drift time.Duration) headerRequesterFn { 126 closure := make(map[common.Hash]*types.Block) 127 for hash, block := range blocks { 128 closure[hash] = block 129 } 130 131 return func(hash common.Hash) error { 132 133 headers := make([]*types.Header, 0, 1) 134 if block, ok := closure[hash]; ok { 135 headers = append(headers, block.Header()) 136 } 137 138 go f.fetcher.FilterHeaders(peer, headers, time.Now().Add(drift)) 139 140 return nil 141 } 142 } 143 144 func (f *fetcherTester) makeBodyFetcher(peer string, blocks map[common.Hash]*types.Block, drift time.Duration) bodyRequesterFn { 145 closure := make(map[common.Hash]*types.Block) 146 for hash, block := range blocks { 147 closure[hash] = block 148 } 149 150 return func(hashes []common.Hash) error { 151 152 transactions := make([][]*types.Transaction, 0, len(hashes)) 153 uncles := make([][]*types.Header, 0, len(hashes)) 154 155 for _, hash := range hashes { 156 if block, ok := closure[hash]; ok { 157 transactions = append(transactions, block.Transactions()) 158 uncles = append(uncles, block.Uncles()) 159 } 160 } 161 162 go f.fetcher.FilterBodies(peer, transactions, uncles, time.Now().Add(drift)) 163 164 return nil 165 } 166 } 167 168 func verifyFetchingEvent(t *testing.T, fetching chan []common.Hash, arrive bool) { 169 if arrive { 170 select { 171 case <-fetching: 172 case <-time.After(time.Second): 173 t.Fatalf("fetching timeout") 174 } 175 } else { 176 select { 177 case <-fetching: 178 t.Fatalf("fetching invoked") 179 case <-time.After(10 * time.Millisecond): 180 } 181 } 182 } 183 184 func verifyCompletingEvent(t *testing.T, completing chan []common.Hash, arrive bool) { 185 if arrive { 186 select { 187 case <-completing: 188 case <-time.After(time.Second): 189 t.Fatalf("completing timeout") 190 } 191 } else { 192 select { 193 case <-completing: 194 t.Fatalf("completing invoked") 195 case <-time.After(10 * time.Millisecond): 196 } 197 } 198 } 199 200 func verifyImportEvent(t *testing.T, imported chan *types.Block, arrive bool) { 201 if arrive { 202 select { 203 case <-imported: 204 case <-time.After(time.Second): 205 t.Fatalf("import timeout") 206 } 207 } else { 208 select { 209 case <-imported: 210 t.Fatalf("import invoked") 211 case <-time.After(10 * time.Millisecond): 212 } 213 } 214 } 215 216 func verifyImportCount(t *testing.T, imported chan *types.Block, count int) { 217 for i := 0; i < count; i++ { 218 select { 219 case <-imported: 220 case <-time.After(time.Second): 221 t.Fatalf("block %d: import timeout", i+1) 222 } 223 } 224 verifyImportDone(t, imported) 225 } 226 227 func verifyImportDone(t *testing.T, imported chan *types.Block) { 228 select { 229 case <-imported: 230 t.Fatalf("extra block imported") 231 case <-time.After(50 * time.Millisecond): 232 } 233 } 234 235 func TestSequentialAnnouncements62(t *testing.T) { testSequentialAnnouncements(t, 62) } 236 func TestSequentialAnnouncements63(t *testing.T) { testSequentialAnnouncements(t, 63) } 237 func TestSequentialAnnouncements64(t *testing.T) { testSequentialAnnouncements(t, 64) } 238 239 func testSequentialAnnouncements(t *testing.T, protocol int) { 240 241 targetBlocks := 4 * hashLimit 242 hashes, blocks := makeChain(targetBlocks, 0, genesis) 243 244 tester := newTester() 245 headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) 246 bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) 247 248 imported := make(chan *types.Block) 249 tester.fetcher.importedHook = func(block *types.Block) { imported <- block } 250 251 for i := len(hashes) - 2; i >= 0; i-- { 252 tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) 253 verifyImportEvent(t, imported, true) 254 } 255 verifyImportDone(t, imported) 256 } 257 258 func TestConcurrentAnnouncements62(t *testing.T) { testConcurrentAnnouncements(t, 62) } 259 func TestConcurrentAnnouncements63(t *testing.T) { testConcurrentAnnouncements(t, 63) } 260 func TestConcurrentAnnouncements64(t *testing.T) { testConcurrentAnnouncements(t, 64) } 261 262 func testConcurrentAnnouncements(t *testing.T, protocol int) { 263 264 targetBlocks := 4 * hashLimit 265 hashes, blocks := makeChain(targetBlocks, 0, genesis) 266 267 tester := newTester() 268 firstHeaderFetcher := tester.makeHeaderFetcher("first", blocks, -gatherSlack) 269 firstBodyFetcher := tester.makeBodyFetcher("first", blocks, 0) 270 secondHeaderFetcher := tester.makeHeaderFetcher("second", blocks, -gatherSlack) 271 secondBodyFetcher := tester.makeBodyFetcher("second", blocks, 0) 272 273 counter := uint32(0) 274 firstHeaderWrapper := func(hash common.Hash) error { 275 atomic.AddUint32(&counter, 1) 276 return firstHeaderFetcher(hash) 277 } 278 secondHeaderWrapper := func(hash common.Hash) error { 279 atomic.AddUint32(&counter, 1) 280 return secondHeaderFetcher(hash) 281 } 282 283 imported := make(chan *types.Block) 284 tester.fetcher.importedHook = func(block *types.Block) { imported <- block } 285 286 for i := len(hashes) - 2; i >= 0; i-- { 287 tester.fetcher.Notify("first", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), firstHeaderWrapper, firstBodyFetcher) 288 tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout+time.Millisecond), secondHeaderWrapper, secondBodyFetcher) 289 tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout-time.Millisecond), secondHeaderWrapper, secondBodyFetcher) 290 verifyImportEvent(t, imported, true) 291 } 292 verifyImportDone(t, imported) 293 294 if int(counter) != targetBlocks { 295 t.Fatalf("retrieval count mismatch: have %v, want %v", counter, targetBlocks) 296 } 297 } 298 299 func TestOverlappingAnnouncements62(t *testing.T) { testOverlappingAnnouncements(t, 62) } 300 func TestOverlappingAnnouncements63(t *testing.T) { testOverlappingAnnouncements(t, 63) } 301 func TestOverlappingAnnouncements64(t *testing.T) { testOverlappingAnnouncements(t, 64) } 302 303 func testOverlappingAnnouncements(t *testing.T, protocol int) { 304 305 targetBlocks := 4 * hashLimit 306 hashes, blocks := makeChain(targetBlocks, 0, genesis) 307 308 tester := newTester() 309 headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) 310 bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) 311 312 overlap := 16 313 imported := make(chan *types.Block, len(hashes)-1) 314 for i := 0; i < overlap; i++ { 315 imported <- nil 316 } 317 tester.fetcher.importedHook = func(block *types.Block) { imported <- block } 318 319 for i := len(hashes) - 2; i >= 0; i-- { 320 tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) 321 select { 322 case <-imported: 323 case <-time.After(time.Second): 324 t.Fatalf("block %d: import timeout", len(hashes)-i) 325 } 326 } 327 328 verifyImportCount(t, imported, overlap) 329 } 330 331 func TestPendingDeduplication62(t *testing.T) { testPendingDeduplication(t, 62) } 332 func TestPendingDeduplication63(t *testing.T) { testPendingDeduplication(t, 63) } 333 func TestPendingDeduplication64(t *testing.T) { testPendingDeduplication(t, 64) } 334 335 func testPendingDeduplication(t *testing.T, protocol int) { 336 337 hashes, blocks := makeChain(1, 0, genesis) 338 339 tester := newTester() 340 headerFetcher := tester.makeHeaderFetcher("repeater", blocks, -gatherSlack) 341 bodyFetcher := tester.makeBodyFetcher("repeater", blocks, 0) 342 343 delay := 50 * time.Millisecond 344 counter := uint32(0) 345 headerWrapper := func(hash common.Hash) error { 346 atomic.AddUint32(&counter, 1) 347 348 go func() { 349 time.Sleep(delay) 350 headerFetcher(hash) 351 }() 352 return nil 353 } 354 355 for tester.getBlock(hashes[0]) == nil { 356 tester.fetcher.Notify("repeater", hashes[0], 1, time.Now().Add(-arriveTimeout), headerWrapper, bodyFetcher) 357 time.Sleep(time.Millisecond) 358 } 359 time.Sleep(delay) 360 361 if imported := len(tester.blocks); imported != 2 { 362 t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 2) 363 } 364 if int(counter) != 1 { 365 t.Fatalf("retrieval count mismatch: have %v, want %v", counter, 1) 366 } 367 } 368 369 func TestRandomArrivalImport62(t *testing.T) { testRandomArrivalImport(t, 62) } 370 func TestRandomArrivalImport63(t *testing.T) { testRandomArrivalImport(t, 63) } 371 func TestRandomArrivalImport64(t *testing.T) { testRandomArrivalImport(t, 64) } 372 373 func testRandomArrivalImport(t *testing.T, protocol int) { 374 375 targetBlocks := maxQueueDist 376 hashes, blocks := makeChain(targetBlocks, 0, genesis) 377 skip := targetBlocks / 2 378 379 tester := newTester() 380 headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) 381 bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) 382 383 imported := make(chan *types.Block, len(hashes)-1) 384 tester.fetcher.importedHook = func(block *types.Block) { imported <- block } 385 386 for i := len(hashes) - 1; i >= 0; i-- { 387 if i != skip { 388 tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) 389 time.Sleep(time.Millisecond) 390 } 391 } 392 393 tester.fetcher.Notify("valid", hashes[skip], uint64(len(hashes)-skip-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) 394 verifyImportCount(t, imported, len(hashes)-1) 395 } 396 397 func TestQueueGapFill62(t *testing.T) { testQueueGapFill(t, 62) } 398 func TestQueueGapFill63(t *testing.T) { testQueueGapFill(t, 63) } 399 func TestQueueGapFill64(t *testing.T) { testQueueGapFill(t, 64) } 400 401 func testQueueGapFill(t *testing.T, protocol int) { 402 403 targetBlocks := maxQueueDist 404 hashes, blocks := makeChain(targetBlocks, 0, genesis) 405 skip := targetBlocks / 2 406 407 tester := newTester() 408 headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) 409 bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) 410 411 imported := make(chan *types.Block, len(hashes)-1) 412 tester.fetcher.importedHook = func(block *types.Block) { imported <- block } 413 414 for i := len(hashes) - 1; i >= 0; i-- { 415 if i != skip { 416 tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) 417 time.Sleep(time.Millisecond) 418 } 419 } 420 421 tester.fetcher.Enqueue("valid", blocks[hashes[skip]]) 422 verifyImportCount(t, imported, len(hashes)-1) 423 } 424 425 func TestImportDeduplication62(t *testing.T) { testImportDeduplication(t, 62) } 426 func TestImportDeduplication63(t *testing.T) { testImportDeduplication(t, 63) } 427 func TestImportDeduplication64(t *testing.T) { testImportDeduplication(t, 64) } 428 429 func testImportDeduplication(t *testing.T, protocol int) { 430 431 hashes, blocks := makeChain(2, 0, genesis) 432 433 tester := newTester() 434 headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) 435 bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) 436 437 counter := uint32(0) 438 tester.fetcher.insertChain = func(blocks types.Blocks) (int, error) { 439 atomic.AddUint32(&counter, uint32(len(blocks))) 440 return tester.insertChain(blocks) 441 } 442 443 fetching := make(chan []common.Hash) 444 imported := make(chan *types.Block, len(hashes)-1) 445 tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- hashes } 446 tester.fetcher.importedHook = func(block *types.Block) { imported <- block } 447 448 tester.fetcher.Notify("valid", hashes[0], 1, time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) 449 <-fetching 450 451 tester.fetcher.Enqueue("valid", blocks[hashes[0]]) 452 tester.fetcher.Enqueue("valid", blocks[hashes[0]]) 453 tester.fetcher.Enqueue("valid", blocks[hashes[0]]) 454 455 tester.fetcher.Enqueue("valid", blocks[hashes[1]]) 456 verifyImportCount(t, imported, 2) 457 458 if counter != 2 { 459 t.Fatalf("import invocation count mismatch: have %v, want %v", counter, 2) 460 } 461 } 462 463 func TestDistantPropagationDiscarding(t *testing.T) { 464 465 hashes, blocks := makeChain(3*maxQueueDist, 0, genesis) 466 head := hashes[len(hashes)/2] 467 468 low, high := len(hashes)/2+maxUncleDist+1, len(hashes)/2-maxQueueDist-1 469 470 tester := newTester() 471 472 tester.lock.Lock() 473 tester.hashes = []common.Hash{head} 474 tester.blocks = map[common.Hash]*types.Block{head: blocks[head]} 475 tester.lock.Unlock() 476 477 tester.fetcher.Enqueue("lower", blocks[hashes[low]]) 478 time.Sleep(10 * time.Millisecond) 479 if !tester.fetcher.queue.Empty() { 480 t.Fatalf("fetcher queued stale block") 481 } 482 483 tester.fetcher.Enqueue("higher", blocks[hashes[high]]) 484 time.Sleep(10 * time.Millisecond) 485 if !tester.fetcher.queue.Empty() { 486 t.Fatalf("fetcher queued future block") 487 } 488 } 489 490 func TestDistantAnnouncementDiscarding62(t *testing.T) { testDistantAnnouncementDiscarding(t, 62) } 491 func TestDistantAnnouncementDiscarding63(t *testing.T) { testDistantAnnouncementDiscarding(t, 63) } 492 func TestDistantAnnouncementDiscarding64(t *testing.T) { testDistantAnnouncementDiscarding(t, 64) } 493 494 func testDistantAnnouncementDiscarding(t *testing.T, protocol int) { 495 496 hashes, blocks := makeChain(3*maxQueueDist, 0, genesis) 497 head := hashes[len(hashes)/2] 498 499 low, high := len(hashes)/2+maxUncleDist+1, len(hashes)/2-maxQueueDist-1 500 501 tester := newTester() 502 503 tester.lock.Lock() 504 tester.hashes = []common.Hash{head} 505 tester.blocks = map[common.Hash]*types.Block{head: blocks[head]} 506 tester.lock.Unlock() 507 508 headerFetcher := tester.makeHeaderFetcher("lower", blocks, -gatherSlack) 509 bodyFetcher := tester.makeBodyFetcher("lower", blocks, 0) 510 511 fetching := make(chan struct{}, 2) 512 tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- struct{}{} } 513 514 tester.fetcher.Notify("lower", hashes[low], blocks[hashes[low]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) 515 select { 516 case <-time.After(50 * time.Millisecond): 517 case <-fetching: 518 t.Fatalf("fetcher requested stale header") 519 } 520 521 tester.fetcher.Notify("higher", hashes[high], blocks[hashes[high]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) 522 select { 523 case <-time.After(50 * time.Millisecond): 524 case <-fetching: 525 t.Fatalf("fetcher requested future header") 526 } 527 } 528 529 func TestInvalidNumberAnnouncement62(t *testing.T) { testInvalidNumberAnnouncement(t, 62) } 530 func TestInvalidNumberAnnouncement63(t *testing.T) { testInvalidNumberAnnouncement(t, 63) } 531 func TestInvalidNumberAnnouncement64(t *testing.T) { testInvalidNumberAnnouncement(t, 64) } 532 533 func testInvalidNumberAnnouncement(t *testing.T, protocol int) { 534 535 hashes, blocks := makeChain(1, 0, genesis) 536 537 tester := newTester() 538 badHeaderFetcher := tester.makeHeaderFetcher("bad", blocks, -gatherSlack) 539 badBodyFetcher := tester.makeBodyFetcher("bad", blocks, 0) 540 541 imported := make(chan *types.Block) 542 tester.fetcher.importedHook = func(block *types.Block) { imported <- block } 543 544 tester.fetcher.Notify("bad", hashes[0], 2, time.Now().Add(-arriveTimeout), badHeaderFetcher, badBodyFetcher) 545 verifyImportEvent(t, imported, false) 546 547 tester.lock.RLock() 548 dropped := tester.drops["bad"] 549 tester.lock.RUnlock() 550 551 if !dropped { 552 t.Fatalf("peer with invalid numbered announcement not dropped") 553 } 554 555 goodHeaderFetcher := tester.makeHeaderFetcher("good", blocks, -gatherSlack) 556 goodBodyFetcher := tester.makeBodyFetcher("good", blocks, 0) 557 558 tester.fetcher.Notify("good", hashes[0], 1, time.Now().Add(-arriveTimeout), goodHeaderFetcher, goodBodyFetcher) 559 verifyImportEvent(t, imported, true) 560 561 tester.lock.RLock() 562 dropped = tester.drops["good"] 563 tester.lock.RUnlock() 564 565 if dropped { 566 t.Fatalf("peer with valid numbered announcement dropped") 567 } 568 verifyImportDone(t, imported) 569 } 570 571 func TestEmptyBlockShortCircuit62(t *testing.T) { testEmptyBlockShortCircuit(t, 62) } 572 func TestEmptyBlockShortCircuit63(t *testing.T) { testEmptyBlockShortCircuit(t, 63) } 573 func TestEmptyBlockShortCircuit64(t *testing.T) { testEmptyBlockShortCircuit(t, 64) } 574 575 func testEmptyBlockShortCircuit(t *testing.T, protocol int) { 576 577 hashes, blocks := makeChain(32, 0, genesis) 578 579 tester := newTester() 580 headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) 581 bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) 582 583 fetching := make(chan []common.Hash) 584 tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- hashes } 585 586 completing := make(chan []common.Hash) 587 tester.fetcher.completingHook = func(hashes []common.Hash) { completing <- hashes } 588 589 imported := make(chan *types.Block) 590 tester.fetcher.importedHook = func(block *types.Block) { imported <- block } 591 592 for i := len(hashes) - 2; i >= 0; i-- { 593 tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher) 594 595 verifyFetchingEvent(t, fetching, true) 596 597 verifyCompletingEvent(t, completing, len(blocks[hashes[i]].Transactions()) > 0 || len(blocks[hashes[i]].Uncles()) > 0) 598 599 verifyImportEvent(t, imported, true) 600 } 601 verifyImportDone(t, imported) 602 } 603 604 func TestHashMemoryExhaustionAttack62(t *testing.T) { testHashMemoryExhaustionAttack(t, 62) } 605 func TestHashMemoryExhaustionAttack63(t *testing.T) { testHashMemoryExhaustionAttack(t, 63) } 606 func TestHashMemoryExhaustionAttack64(t *testing.T) { testHashMemoryExhaustionAttack(t, 64) } 607 608 func testHashMemoryExhaustionAttack(t *testing.T, protocol int) { 609 610 tester := newTester() 611 612 imported, announces := make(chan *types.Block), int32(0) 613 tester.fetcher.importedHook = func(block *types.Block) { imported <- block } 614 tester.fetcher.announceChangeHook = func(hash common.Hash, added bool) { 615 if added { 616 atomic.AddInt32(&announces, 1) 617 } else { 618 atomic.AddInt32(&announces, -1) 619 } 620 } 621 622 targetBlocks := hashLimit + 2*maxQueueDist 623 hashes, blocks := makeChain(targetBlocks, 0, genesis) 624 validHeaderFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack) 625 validBodyFetcher := tester.makeBodyFetcher("valid", blocks, 0) 626 627 attack, _ := makeChain(targetBlocks, 0, unknownBlock) 628 attackerHeaderFetcher := tester.makeHeaderFetcher("attacker", nil, -gatherSlack) 629 attackerBodyFetcher := tester.makeBodyFetcher("attacker", nil, 0) 630 631 for i := 0; i < len(attack); i++ { 632 if i < maxQueueDist { 633 tester.fetcher.Notify("valid", hashes[len(hashes)-2-i], uint64(i+1), time.Now(), validHeaderFetcher, validBodyFetcher) 634 } 635 tester.fetcher.Notify("attacker", attack[i], 1, time.Now(), attackerHeaderFetcher, attackerBodyFetcher) 636 } 637 if count := atomic.LoadInt32(&announces); count != hashLimit+maxQueueDist { 638 t.Fatalf("queued announce count mismatch: have %d, want %d", count, hashLimit+maxQueueDist) 639 } 640 641 verifyImportCount(t, imported, maxQueueDist) 642 643 for i := len(hashes) - maxQueueDist - 2; i >= 0; i-- { 644 tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), validHeaderFetcher, validBodyFetcher) 645 verifyImportEvent(t, imported, true) 646 } 647 verifyImportDone(t, imported) 648 } 649 650 func TestBlockMemoryExhaustionAttack(t *testing.T) { 651 652 tester := newTester() 653 654 imported, enqueued := make(chan *types.Block), int32(0) 655 tester.fetcher.importedHook = func(block *types.Block) { imported <- block } 656 tester.fetcher.queueChangeHook = func(hash common.Hash, added bool) { 657 if added { 658 atomic.AddInt32(&enqueued, 1) 659 } else { 660 atomic.AddInt32(&enqueued, -1) 661 } 662 } 663 664 targetBlocks := hashLimit + 2*maxQueueDist 665 hashes, blocks := makeChain(targetBlocks, 0, genesis) 666 attack := make(map[common.Hash]*types.Block) 667 for i := byte(0); len(attack) < blockLimit+2*maxQueueDist; i++ { 668 hashes, blocks := makeChain(maxQueueDist-1, i, unknownBlock) 669 for _, hash := range hashes[:maxQueueDist-2] { 670 attack[hash] = blocks[hash] 671 } 672 } 673 674 for _, block := range attack { 675 tester.fetcher.Enqueue("attacker", block) 676 } 677 time.Sleep(200 * time.Millisecond) 678 if queued := atomic.LoadInt32(&enqueued); queued != blockLimit { 679 t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit) 680 } 681 682 for i := 0; i < maxQueueDist-1; i++ { 683 tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-3-i]]) 684 } 685 time.Sleep(100 * time.Millisecond) 686 if queued := atomic.LoadInt32(&enqueued); queued != blockLimit+maxQueueDist-1 { 687 t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit+maxQueueDist-1) 688 } 689 690 tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-2]]) 691 verifyImportCount(t, imported, maxQueueDist) 692 693 for i := maxQueueDist; i < len(hashes)-1; i++ { 694 tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-2-i]]) 695 verifyImportEvent(t, imported, true) 696 } 697 verifyImportDone(t, imported) 698 }