github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/neatptc/fetcher/fetcher_test.go (about)

     1  package fetcher
     2  
     3  import (
     4  	"errors"
     5  	"math/big"
     6  	"sync"
     7  	"sync/atomic"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/neatlab/neatio/chain/core"
    12  	"github.com/neatlab/neatio/chain/core/rawdb"
    13  	"github.com/neatlab/neatio/chain/core/types"
    14  	"github.com/neatlab/neatio/params"
    15  	"github.com/neatlab/neatio/utilities/common"
    16  	"github.com/neatlab/neatio/utilities/crypto"
    17  )
    18  
    19  var (
    20  	testdb       = rawdb.NewMemoryDatabase()
    21  	testKey, _   = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
    22  	testAddress  = crypto.PubkeyToAddress(testKey.PublicKey)
    23  	genesis      = core.GenesisBlockForTesting(testdb, testAddress, big.NewInt(1000000000))
    24  	unknownBlock = types.NewBlock(&types.Header{GasLimit: params.GenesisGasLimit}, nil, nil, nil)
    25  )
    26  
    27  func makeChain(n int, seed byte, parent *types.Block) ([]common.Hash, map[common.Hash]*types.Block) {
    28  	blocks, _ := core.GenerateChain(params.TestChainConfig, parent, nil, testdb, n, func(i int, block *core.BlockGen) {
    29  		block.SetCoinbase(common.Address{seed})
    30  
    31  		if parent == genesis && i%3 == 0 {
    32  			signer := types.MakeSigner(params.TestChainConfig, block.Number())
    33  			tx, err := types.SignTx(types.NewTransaction(block.TxNonce(testAddress), common.Address{seed}, big.NewInt(1000), params.TxGas, nil, nil), signer, testKey)
    34  			if err != nil {
    35  				panic(err)
    36  			}
    37  			block.AddTx(tx)
    38  		}
    39  
    40  		if i%5 == 0 {
    41  			block.AddUncle(&types.Header{ParentHash: block.PrevBlock(i - 1).Hash(), Number: big.NewInt(int64(i - 1))})
    42  		}
    43  	})
    44  	hashes := make([]common.Hash, n+1)
    45  	hashes[len(hashes)-1] = parent.Hash()
    46  	blockm := make(map[common.Hash]*types.Block, n+1)
    47  	blockm[parent.Hash()] = parent
    48  	for i, b := range blocks {
    49  		hashes[len(hashes)-i-2] = b.Hash()
    50  		blockm[b.Hash()] = b
    51  	}
    52  	return hashes, blockm
    53  }
    54  
    55  type fetcherTester struct {
    56  	fetcher *Fetcher
    57  
    58  	hashes []common.Hash
    59  	blocks map[common.Hash]*types.Block
    60  	drops  map[string]bool
    61  
    62  	lock sync.RWMutex
    63  }
    64  
    65  func newTester() *fetcherTester {
    66  	tester := &fetcherTester{
    67  		hashes: []common.Hash{genesis.Hash()},
    68  		blocks: map[common.Hash]*types.Block{genesis.Hash(): genesis},
    69  		drops:  make(map[string]bool),
    70  	}
    71  	tester.fetcher = New(tester.getBlock, tester.verifyHeader, tester.broadcastBlock, tester.chainHeight, tester.insertChain, tester.dropPeer)
    72  	tester.fetcher.Start()
    73  
    74  	return tester
    75  }
    76  
    77  func (f *fetcherTester) getBlock(hash common.Hash) *types.Block {
    78  	f.lock.RLock()
    79  	defer f.lock.RUnlock()
    80  
    81  	return f.blocks[hash]
    82  }
    83  
    84  func (f *fetcherTester) verifyHeader(header *types.Header) error {
    85  	return nil
    86  }
    87  
    88  func (f *fetcherTester) broadcastBlock(block *types.Block, propagate bool) {
    89  }
    90  
    91  func (f *fetcherTester) chainHeight() uint64 {
    92  	f.lock.RLock()
    93  	defer f.lock.RUnlock()
    94  
    95  	return f.blocks[f.hashes[len(f.hashes)-1]].NumberU64()
    96  }
    97  
    98  func (f *fetcherTester) insertChain(blocks types.Blocks) (int, error) {
    99  	f.lock.Lock()
   100  	defer f.lock.Unlock()
   101  
   102  	for i, block := range blocks {
   103  
   104  		if _, ok := f.blocks[block.ParentHash()]; !ok {
   105  			return i, errors.New("unknown parent")
   106  		}
   107  
   108  		if block.NumberU64() <= f.blocks[f.hashes[len(f.hashes)-1]].NumberU64() {
   109  			return i, nil
   110  		}
   111  
   112  		f.hashes = append(f.hashes, block.Hash())
   113  		f.blocks[block.Hash()] = block
   114  	}
   115  	return 0, nil
   116  }
   117  
   118  func (f *fetcherTester) dropPeer(peer string) {
   119  	f.lock.Lock()
   120  	defer f.lock.Unlock()
   121  
   122  	f.drops[peer] = true
   123  }
   124  
   125  func (f *fetcherTester) makeHeaderFetcher(peer string, blocks map[common.Hash]*types.Block, drift time.Duration) headerRequesterFn {
   126  	closure := make(map[common.Hash]*types.Block)
   127  	for hash, block := range blocks {
   128  		closure[hash] = block
   129  	}
   130  
   131  	return func(hash common.Hash) error {
   132  
   133  		headers := make([]*types.Header, 0, 1)
   134  		if block, ok := closure[hash]; ok {
   135  			headers = append(headers, block.Header())
   136  		}
   137  
   138  		go f.fetcher.FilterHeaders(peer, headers, time.Now().Add(drift))
   139  
   140  		return nil
   141  	}
   142  }
   143  
   144  func (f *fetcherTester) makeBodyFetcher(peer string, blocks map[common.Hash]*types.Block, drift time.Duration) bodyRequesterFn {
   145  	closure := make(map[common.Hash]*types.Block)
   146  	for hash, block := range blocks {
   147  		closure[hash] = block
   148  	}
   149  
   150  	return func(hashes []common.Hash) error {
   151  
   152  		transactions := make([][]*types.Transaction, 0, len(hashes))
   153  		uncles := make([][]*types.Header, 0, len(hashes))
   154  
   155  		for _, hash := range hashes {
   156  			if block, ok := closure[hash]; ok {
   157  				transactions = append(transactions, block.Transactions())
   158  				uncles = append(uncles, block.Uncles())
   159  			}
   160  		}
   161  
   162  		go f.fetcher.FilterBodies(peer, transactions, uncles, time.Now().Add(drift))
   163  
   164  		return nil
   165  	}
   166  }
   167  
   168  func verifyFetchingEvent(t *testing.T, fetching chan []common.Hash, arrive bool) {
   169  	if arrive {
   170  		select {
   171  		case <-fetching:
   172  		case <-time.After(time.Second):
   173  			t.Fatalf("fetching timeout")
   174  		}
   175  	} else {
   176  		select {
   177  		case <-fetching:
   178  			t.Fatalf("fetching invoked")
   179  		case <-time.After(10 * time.Millisecond):
   180  		}
   181  	}
   182  }
   183  
   184  func verifyCompletingEvent(t *testing.T, completing chan []common.Hash, arrive bool) {
   185  	if arrive {
   186  		select {
   187  		case <-completing:
   188  		case <-time.After(time.Second):
   189  			t.Fatalf("completing timeout")
   190  		}
   191  	} else {
   192  		select {
   193  		case <-completing:
   194  			t.Fatalf("completing invoked")
   195  		case <-time.After(10 * time.Millisecond):
   196  		}
   197  	}
   198  }
   199  
   200  func verifyImportEvent(t *testing.T, imported chan *types.Block, arrive bool) {
   201  	if arrive {
   202  		select {
   203  		case <-imported:
   204  		case <-time.After(time.Second):
   205  			t.Fatalf("import timeout")
   206  		}
   207  	} else {
   208  		select {
   209  		case <-imported:
   210  			t.Fatalf("import invoked")
   211  		case <-time.After(10 * time.Millisecond):
   212  		}
   213  	}
   214  }
   215  
   216  func verifyImportCount(t *testing.T, imported chan *types.Block, count int) {
   217  	for i := 0; i < count; i++ {
   218  		select {
   219  		case <-imported:
   220  		case <-time.After(time.Second):
   221  			t.Fatalf("block %d: import timeout", i+1)
   222  		}
   223  	}
   224  	verifyImportDone(t, imported)
   225  }
   226  
   227  func verifyImportDone(t *testing.T, imported chan *types.Block) {
   228  	select {
   229  	case <-imported:
   230  		t.Fatalf("extra block imported")
   231  	case <-time.After(50 * time.Millisecond):
   232  	}
   233  }
   234  
   235  func TestSequentialAnnouncements62(t *testing.T) { testSequentialAnnouncements(t, 62) }
   236  func TestSequentialAnnouncements63(t *testing.T) { testSequentialAnnouncements(t, 63) }
   237  func TestSequentialAnnouncements64(t *testing.T) { testSequentialAnnouncements(t, 64) }
   238  
   239  func testSequentialAnnouncements(t *testing.T, protocol int) {
   240  
   241  	targetBlocks := 4 * hashLimit
   242  	hashes, blocks := makeChain(targetBlocks, 0, genesis)
   243  
   244  	tester := newTester()
   245  	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
   246  	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
   247  
   248  	imported := make(chan *types.Block)
   249  	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
   250  
   251  	for i := len(hashes) - 2; i >= 0; i-- {
   252  		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
   253  		verifyImportEvent(t, imported, true)
   254  	}
   255  	verifyImportDone(t, imported)
   256  }
   257  
   258  func TestConcurrentAnnouncements62(t *testing.T) { testConcurrentAnnouncements(t, 62) }
   259  func TestConcurrentAnnouncements63(t *testing.T) { testConcurrentAnnouncements(t, 63) }
   260  func TestConcurrentAnnouncements64(t *testing.T) { testConcurrentAnnouncements(t, 64) }
   261  
   262  func testConcurrentAnnouncements(t *testing.T, protocol int) {
   263  
   264  	targetBlocks := 4 * hashLimit
   265  	hashes, blocks := makeChain(targetBlocks, 0, genesis)
   266  
   267  	tester := newTester()
   268  	firstHeaderFetcher := tester.makeHeaderFetcher("first", blocks, -gatherSlack)
   269  	firstBodyFetcher := tester.makeBodyFetcher("first", blocks, 0)
   270  	secondHeaderFetcher := tester.makeHeaderFetcher("second", blocks, -gatherSlack)
   271  	secondBodyFetcher := tester.makeBodyFetcher("second", blocks, 0)
   272  
   273  	counter := uint32(0)
   274  	firstHeaderWrapper := func(hash common.Hash) error {
   275  		atomic.AddUint32(&counter, 1)
   276  		return firstHeaderFetcher(hash)
   277  	}
   278  	secondHeaderWrapper := func(hash common.Hash) error {
   279  		atomic.AddUint32(&counter, 1)
   280  		return secondHeaderFetcher(hash)
   281  	}
   282  
   283  	imported := make(chan *types.Block)
   284  	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
   285  
   286  	for i := len(hashes) - 2; i >= 0; i-- {
   287  		tester.fetcher.Notify("first", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), firstHeaderWrapper, firstBodyFetcher)
   288  		tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout+time.Millisecond), secondHeaderWrapper, secondBodyFetcher)
   289  		tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout-time.Millisecond), secondHeaderWrapper, secondBodyFetcher)
   290  		verifyImportEvent(t, imported, true)
   291  	}
   292  	verifyImportDone(t, imported)
   293  
   294  	if int(counter) != targetBlocks {
   295  		t.Fatalf("retrieval count mismatch: have %v, want %v", counter, targetBlocks)
   296  	}
   297  }
   298  
   299  func TestOverlappingAnnouncements62(t *testing.T) { testOverlappingAnnouncements(t, 62) }
   300  func TestOverlappingAnnouncements63(t *testing.T) { testOverlappingAnnouncements(t, 63) }
   301  func TestOverlappingAnnouncements64(t *testing.T) { testOverlappingAnnouncements(t, 64) }
   302  
   303  func testOverlappingAnnouncements(t *testing.T, protocol int) {
   304  
   305  	targetBlocks := 4 * hashLimit
   306  	hashes, blocks := makeChain(targetBlocks, 0, genesis)
   307  
   308  	tester := newTester()
   309  	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
   310  	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
   311  
   312  	overlap := 16
   313  	imported := make(chan *types.Block, len(hashes)-1)
   314  	for i := 0; i < overlap; i++ {
   315  		imported <- nil
   316  	}
   317  	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
   318  
   319  	for i := len(hashes) - 2; i >= 0; i-- {
   320  		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
   321  		select {
   322  		case <-imported:
   323  		case <-time.After(time.Second):
   324  			t.Fatalf("block %d: import timeout", len(hashes)-i)
   325  		}
   326  	}
   327  
   328  	verifyImportCount(t, imported, overlap)
   329  }
   330  
   331  func TestPendingDeduplication62(t *testing.T) { testPendingDeduplication(t, 62) }
   332  func TestPendingDeduplication63(t *testing.T) { testPendingDeduplication(t, 63) }
   333  func TestPendingDeduplication64(t *testing.T) { testPendingDeduplication(t, 64) }
   334  
   335  func testPendingDeduplication(t *testing.T, protocol int) {
   336  
   337  	hashes, blocks := makeChain(1, 0, genesis)
   338  
   339  	tester := newTester()
   340  	headerFetcher := tester.makeHeaderFetcher("repeater", blocks, -gatherSlack)
   341  	bodyFetcher := tester.makeBodyFetcher("repeater", blocks, 0)
   342  
   343  	delay := 50 * time.Millisecond
   344  	counter := uint32(0)
   345  	headerWrapper := func(hash common.Hash) error {
   346  		atomic.AddUint32(&counter, 1)
   347  
   348  		go func() {
   349  			time.Sleep(delay)
   350  			headerFetcher(hash)
   351  		}()
   352  		return nil
   353  	}
   354  
   355  	for tester.getBlock(hashes[0]) == nil {
   356  		tester.fetcher.Notify("repeater", hashes[0], 1, time.Now().Add(-arriveTimeout), headerWrapper, bodyFetcher)
   357  		time.Sleep(time.Millisecond)
   358  	}
   359  	time.Sleep(delay)
   360  
   361  	if imported := len(tester.blocks); imported != 2 {
   362  		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 2)
   363  	}
   364  	if int(counter) != 1 {
   365  		t.Fatalf("retrieval count mismatch: have %v, want %v", counter, 1)
   366  	}
   367  }
   368  
   369  func TestRandomArrivalImport62(t *testing.T) { testRandomArrivalImport(t, 62) }
   370  func TestRandomArrivalImport63(t *testing.T) { testRandomArrivalImport(t, 63) }
   371  func TestRandomArrivalImport64(t *testing.T) { testRandomArrivalImport(t, 64) }
   372  
   373  func testRandomArrivalImport(t *testing.T, protocol int) {
   374  
   375  	targetBlocks := maxQueueDist
   376  	hashes, blocks := makeChain(targetBlocks, 0, genesis)
   377  	skip := targetBlocks / 2
   378  
   379  	tester := newTester()
   380  	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
   381  	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
   382  
   383  	imported := make(chan *types.Block, len(hashes)-1)
   384  	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
   385  
   386  	for i := len(hashes) - 1; i >= 0; i-- {
   387  		if i != skip {
   388  			tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
   389  			time.Sleep(time.Millisecond)
   390  		}
   391  	}
   392  
   393  	tester.fetcher.Notify("valid", hashes[skip], uint64(len(hashes)-skip-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
   394  	verifyImportCount(t, imported, len(hashes)-1)
   395  }
   396  
   397  func TestQueueGapFill62(t *testing.T) { testQueueGapFill(t, 62) }
   398  func TestQueueGapFill63(t *testing.T) { testQueueGapFill(t, 63) }
   399  func TestQueueGapFill64(t *testing.T) { testQueueGapFill(t, 64) }
   400  
   401  func testQueueGapFill(t *testing.T, protocol int) {
   402  
   403  	targetBlocks := maxQueueDist
   404  	hashes, blocks := makeChain(targetBlocks, 0, genesis)
   405  	skip := targetBlocks / 2
   406  
   407  	tester := newTester()
   408  	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
   409  	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
   410  
   411  	imported := make(chan *types.Block, len(hashes)-1)
   412  	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
   413  
   414  	for i := len(hashes) - 1; i >= 0; i-- {
   415  		if i != skip {
   416  			tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
   417  			time.Sleep(time.Millisecond)
   418  		}
   419  	}
   420  
   421  	tester.fetcher.Enqueue("valid", blocks[hashes[skip]])
   422  	verifyImportCount(t, imported, len(hashes)-1)
   423  }
   424  
   425  func TestImportDeduplication62(t *testing.T) { testImportDeduplication(t, 62) }
   426  func TestImportDeduplication63(t *testing.T) { testImportDeduplication(t, 63) }
   427  func TestImportDeduplication64(t *testing.T) { testImportDeduplication(t, 64) }
   428  
   429  func testImportDeduplication(t *testing.T, protocol int) {
   430  
   431  	hashes, blocks := makeChain(2, 0, genesis)
   432  
   433  	tester := newTester()
   434  	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
   435  	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
   436  
   437  	counter := uint32(0)
   438  	tester.fetcher.insertChain = func(blocks types.Blocks) (int, error) {
   439  		atomic.AddUint32(&counter, uint32(len(blocks)))
   440  		return tester.insertChain(blocks)
   441  	}
   442  
   443  	fetching := make(chan []common.Hash)
   444  	imported := make(chan *types.Block, len(hashes)-1)
   445  	tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- hashes }
   446  	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
   447  
   448  	tester.fetcher.Notify("valid", hashes[0], 1, time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
   449  	<-fetching
   450  
   451  	tester.fetcher.Enqueue("valid", blocks[hashes[0]])
   452  	tester.fetcher.Enqueue("valid", blocks[hashes[0]])
   453  	tester.fetcher.Enqueue("valid", blocks[hashes[0]])
   454  
   455  	tester.fetcher.Enqueue("valid", blocks[hashes[1]])
   456  	verifyImportCount(t, imported, 2)
   457  
   458  	if counter != 2 {
   459  		t.Fatalf("import invocation count mismatch: have %v, want %v", counter, 2)
   460  	}
   461  }
   462  
   463  func TestDistantPropagationDiscarding(t *testing.T) {
   464  
   465  	hashes, blocks := makeChain(3*maxQueueDist, 0, genesis)
   466  	head := hashes[len(hashes)/2]
   467  
   468  	low, high := len(hashes)/2+maxUncleDist+1, len(hashes)/2-maxQueueDist-1
   469  
   470  	tester := newTester()
   471  
   472  	tester.lock.Lock()
   473  	tester.hashes = []common.Hash{head}
   474  	tester.blocks = map[common.Hash]*types.Block{head: blocks[head]}
   475  	tester.lock.Unlock()
   476  
   477  	tester.fetcher.Enqueue("lower", blocks[hashes[low]])
   478  	time.Sleep(10 * time.Millisecond)
   479  	if !tester.fetcher.queue.Empty() {
   480  		t.Fatalf("fetcher queued stale block")
   481  	}
   482  
   483  	tester.fetcher.Enqueue("higher", blocks[hashes[high]])
   484  	time.Sleep(10 * time.Millisecond)
   485  	if !tester.fetcher.queue.Empty() {
   486  		t.Fatalf("fetcher queued future block")
   487  	}
   488  }
   489  
   490  func TestDistantAnnouncementDiscarding62(t *testing.T) { testDistantAnnouncementDiscarding(t, 62) }
   491  func TestDistantAnnouncementDiscarding63(t *testing.T) { testDistantAnnouncementDiscarding(t, 63) }
   492  func TestDistantAnnouncementDiscarding64(t *testing.T) { testDistantAnnouncementDiscarding(t, 64) }
   493  
   494  func testDistantAnnouncementDiscarding(t *testing.T, protocol int) {
   495  
   496  	hashes, blocks := makeChain(3*maxQueueDist, 0, genesis)
   497  	head := hashes[len(hashes)/2]
   498  
   499  	low, high := len(hashes)/2+maxUncleDist+1, len(hashes)/2-maxQueueDist-1
   500  
   501  	tester := newTester()
   502  
   503  	tester.lock.Lock()
   504  	tester.hashes = []common.Hash{head}
   505  	tester.blocks = map[common.Hash]*types.Block{head: blocks[head]}
   506  	tester.lock.Unlock()
   507  
   508  	headerFetcher := tester.makeHeaderFetcher("lower", blocks, -gatherSlack)
   509  	bodyFetcher := tester.makeBodyFetcher("lower", blocks, 0)
   510  
   511  	fetching := make(chan struct{}, 2)
   512  	tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- struct{}{} }
   513  
   514  	tester.fetcher.Notify("lower", hashes[low], blocks[hashes[low]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
   515  	select {
   516  	case <-time.After(50 * time.Millisecond):
   517  	case <-fetching:
   518  		t.Fatalf("fetcher requested stale header")
   519  	}
   520  
   521  	tester.fetcher.Notify("higher", hashes[high], blocks[hashes[high]].NumberU64(), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
   522  	select {
   523  	case <-time.After(50 * time.Millisecond):
   524  	case <-fetching:
   525  		t.Fatalf("fetcher requested future header")
   526  	}
   527  }
   528  
   529  func TestInvalidNumberAnnouncement62(t *testing.T) { testInvalidNumberAnnouncement(t, 62) }
   530  func TestInvalidNumberAnnouncement63(t *testing.T) { testInvalidNumberAnnouncement(t, 63) }
   531  func TestInvalidNumberAnnouncement64(t *testing.T) { testInvalidNumberAnnouncement(t, 64) }
   532  
   533  func testInvalidNumberAnnouncement(t *testing.T, protocol int) {
   534  
   535  	hashes, blocks := makeChain(1, 0, genesis)
   536  
   537  	tester := newTester()
   538  	badHeaderFetcher := tester.makeHeaderFetcher("bad", blocks, -gatherSlack)
   539  	badBodyFetcher := tester.makeBodyFetcher("bad", blocks, 0)
   540  
   541  	imported := make(chan *types.Block)
   542  	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
   543  
   544  	tester.fetcher.Notify("bad", hashes[0], 2, time.Now().Add(-arriveTimeout), badHeaderFetcher, badBodyFetcher)
   545  	verifyImportEvent(t, imported, false)
   546  
   547  	tester.lock.RLock()
   548  	dropped := tester.drops["bad"]
   549  	tester.lock.RUnlock()
   550  
   551  	if !dropped {
   552  		t.Fatalf("peer with invalid numbered announcement not dropped")
   553  	}
   554  
   555  	goodHeaderFetcher := tester.makeHeaderFetcher("good", blocks, -gatherSlack)
   556  	goodBodyFetcher := tester.makeBodyFetcher("good", blocks, 0)
   557  
   558  	tester.fetcher.Notify("good", hashes[0], 1, time.Now().Add(-arriveTimeout), goodHeaderFetcher, goodBodyFetcher)
   559  	verifyImportEvent(t, imported, true)
   560  
   561  	tester.lock.RLock()
   562  	dropped = tester.drops["good"]
   563  	tester.lock.RUnlock()
   564  
   565  	if dropped {
   566  		t.Fatalf("peer with valid numbered announcement dropped")
   567  	}
   568  	verifyImportDone(t, imported)
   569  }
   570  
   571  func TestEmptyBlockShortCircuit62(t *testing.T) { testEmptyBlockShortCircuit(t, 62) }
   572  func TestEmptyBlockShortCircuit63(t *testing.T) { testEmptyBlockShortCircuit(t, 63) }
   573  func TestEmptyBlockShortCircuit64(t *testing.T) { testEmptyBlockShortCircuit(t, 64) }
   574  
   575  func testEmptyBlockShortCircuit(t *testing.T, protocol int) {
   576  
   577  	hashes, blocks := makeChain(32, 0, genesis)
   578  
   579  	tester := newTester()
   580  	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
   581  	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
   582  
   583  	fetching := make(chan []common.Hash)
   584  	tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- hashes }
   585  
   586  	completing := make(chan []common.Hash)
   587  	tester.fetcher.completingHook = func(hashes []common.Hash) { completing <- hashes }
   588  
   589  	imported := make(chan *types.Block)
   590  	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
   591  
   592  	for i := len(hashes) - 2; i >= 0; i-- {
   593  		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
   594  
   595  		verifyFetchingEvent(t, fetching, true)
   596  
   597  		verifyCompletingEvent(t, completing, len(blocks[hashes[i]].Transactions()) > 0 || len(blocks[hashes[i]].Uncles()) > 0)
   598  
   599  		verifyImportEvent(t, imported, true)
   600  	}
   601  	verifyImportDone(t, imported)
   602  }
   603  
   604  func TestHashMemoryExhaustionAttack62(t *testing.T) { testHashMemoryExhaustionAttack(t, 62) }
   605  func TestHashMemoryExhaustionAttack63(t *testing.T) { testHashMemoryExhaustionAttack(t, 63) }
   606  func TestHashMemoryExhaustionAttack64(t *testing.T) { testHashMemoryExhaustionAttack(t, 64) }
   607  
   608  func testHashMemoryExhaustionAttack(t *testing.T, protocol int) {
   609  
   610  	tester := newTester()
   611  
   612  	imported, announces := make(chan *types.Block), int32(0)
   613  	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
   614  	tester.fetcher.announceChangeHook = func(hash common.Hash, added bool) {
   615  		if added {
   616  			atomic.AddInt32(&announces, 1)
   617  		} else {
   618  			atomic.AddInt32(&announces, -1)
   619  		}
   620  	}
   621  
   622  	targetBlocks := hashLimit + 2*maxQueueDist
   623  	hashes, blocks := makeChain(targetBlocks, 0, genesis)
   624  	validHeaderFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
   625  	validBodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
   626  
   627  	attack, _ := makeChain(targetBlocks, 0, unknownBlock)
   628  	attackerHeaderFetcher := tester.makeHeaderFetcher("attacker", nil, -gatherSlack)
   629  	attackerBodyFetcher := tester.makeBodyFetcher("attacker", nil, 0)
   630  
   631  	for i := 0; i < len(attack); i++ {
   632  		if i < maxQueueDist {
   633  			tester.fetcher.Notify("valid", hashes[len(hashes)-2-i], uint64(i+1), time.Now(), validHeaderFetcher, validBodyFetcher)
   634  		}
   635  		tester.fetcher.Notify("attacker", attack[i], 1, time.Now(), attackerHeaderFetcher, attackerBodyFetcher)
   636  	}
   637  	if count := atomic.LoadInt32(&announces); count != hashLimit+maxQueueDist {
   638  		t.Fatalf("queued announce count mismatch: have %d, want %d", count, hashLimit+maxQueueDist)
   639  	}
   640  
   641  	verifyImportCount(t, imported, maxQueueDist)
   642  
   643  	for i := len(hashes) - maxQueueDist - 2; i >= 0; i-- {
   644  		tester.fetcher.Notify("valid", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), validHeaderFetcher, validBodyFetcher)
   645  		verifyImportEvent(t, imported, true)
   646  	}
   647  	verifyImportDone(t, imported)
   648  }
   649  
   650  func TestBlockMemoryExhaustionAttack(t *testing.T) {
   651  
   652  	tester := newTester()
   653  
   654  	imported, enqueued := make(chan *types.Block), int32(0)
   655  	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
   656  	tester.fetcher.queueChangeHook = func(hash common.Hash, added bool) {
   657  		if added {
   658  			atomic.AddInt32(&enqueued, 1)
   659  		} else {
   660  			atomic.AddInt32(&enqueued, -1)
   661  		}
   662  	}
   663  
   664  	targetBlocks := hashLimit + 2*maxQueueDist
   665  	hashes, blocks := makeChain(targetBlocks, 0, genesis)
   666  	attack := make(map[common.Hash]*types.Block)
   667  	for i := byte(0); len(attack) < blockLimit+2*maxQueueDist; i++ {
   668  		hashes, blocks := makeChain(maxQueueDist-1, i, unknownBlock)
   669  		for _, hash := range hashes[:maxQueueDist-2] {
   670  			attack[hash] = blocks[hash]
   671  		}
   672  	}
   673  
   674  	for _, block := range attack {
   675  		tester.fetcher.Enqueue("attacker", block)
   676  	}
   677  	time.Sleep(200 * time.Millisecond)
   678  	if queued := atomic.LoadInt32(&enqueued); queued != blockLimit {
   679  		t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit)
   680  	}
   681  
   682  	for i := 0; i < maxQueueDist-1; i++ {
   683  		tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-3-i]])
   684  	}
   685  	time.Sleep(100 * time.Millisecond)
   686  	if queued := atomic.LoadInt32(&enqueued); queued != blockLimit+maxQueueDist-1 {
   687  		t.Fatalf("queued block count mismatch: have %d, want %d", queued, blockLimit+maxQueueDist-1)
   688  	}
   689  
   690  	tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-2]])
   691  	verifyImportCount(t, imported, maxQueueDist)
   692  
   693  	for i := maxQueueDist; i < len(hashes)-1; i++ {
   694  		tester.fetcher.Enqueue("valid", blocks[hashes[len(hashes)-2-i]])
   695  		verifyImportEvent(t, imported, true)
   696  	}
   697  	verifyImportDone(t, imported)
   698  }