github.com/bytom/bytom@v1.1.2-0.20221014091027-bbcba3df6075/netsync/chainmgr/block_keeper_test.go (about)

     1  package chainmgr
     2  
     3  import (
     4  	"io/ioutil"
     5  	"os"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/bytom/bytom/consensus"
    10  	dbm "github.com/bytom/bytom/database/leveldb"
    11  	"github.com/bytom/bytom/errors"
    12  	msgs "github.com/bytom/bytom/netsync/messages"
    13  	"github.com/bytom/bytom/netsync/peers"
    14  	"github.com/bytom/bytom/protocol"
    15  	"github.com/bytom/bytom/protocol/bc"
    16  	"github.com/bytom/bytom/protocol/bc/types"
    17  	"github.com/bytom/bytom/test/mock"
    18  	"github.com/bytom/bytom/testcontrol"
    19  	"github.com/bytom/bytom/testutil"
    20  )
    21  
    22  func TestCheckSyncType(t *testing.T) {
    23  	tmp, err := ioutil.TempDir(".", "")
    24  	if err != nil {
    25  		t.Fatalf("failed to create temporary data folder: %v", err)
    26  	}
    27  	fastSyncDB := dbm.NewDB("testdb", "leveldb", tmp)
    28  	defer func() {
    29  		fastSyncDB.Close()
    30  		os.RemoveAll(tmp)
    31  	}()
    32  
    33  	blocks := mockBlocks(nil, 50)
    34  	chain := mock.NewChain()
    35  	chain.SetBestBlockHeader(&blocks[len(blocks)-1].BlockHeader)
    36  	for _, block := range blocks {
    37  		chain.SetBlockByHeight(block.Height, block)
    38  	}
    39  
    40  	type syncPeer struct {
    41  		peer               *P2PPeer
    42  		bestHeight         uint64
    43  		irreversibleHeight uint64
    44  	}
    45  
    46  	cases := []struct {
    47  		peers    []*syncPeer
    48  		syncType int
    49  	}{
    50  		{
    51  			peers:    []*syncPeer{},
    52  			syncType: noNeedSync,
    53  		},
    54  		{
    55  			peers: []*syncPeer{
    56  				{peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 500},
    57  				{peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 50, irreversibleHeight: 50},
    58  			},
    59  			syncType: fastSyncType,
    60  		},
    61  		{
    62  			peers: []*syncPeer{
    63  				{peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 100},
    64  				{peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 500, irreversibleHeight: 50},
    65  			},
    66  			syncType: regularSyncType,
    67  		},
    68  		{
    69  			peers: []*syncPeer{
    70  				{peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 51, irreversibleHeight: 50},
    71  			},
    72  			syncType: regularSyncType,
    73  		},
    74  		{
    75  			peers: []*syncPeer{
    76  				{peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 30, irreversibleHeight: 30},
    77  			},
    78  			syncType: noNeedSync,
    79  		},
    80  		{
    81  			peers: []*syncPeer{
    82  				{peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode}, bestHeight: 1000, irreversibleHeight: 1000},
    83  			},
    84  			syncType: regularSyncType,
    85  		},
    86  		{
    87  			peers: []*syncPeer{
    88  				{peer: &P2PPeer{id: "peer1", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 1000, irreversibleHeight: 50},
    89  				{peer: &P2PPeer{id: "peer2", flag: consensus.SFFullNode | consensus.SFFastSync}, bestHeight: 800, irreversibleHeight: 800},
    90  			},
    91  			syncType: fastSyncType,
    92  		},
    93  	}
    94  
    95  	for i, c := range cases {
    96  		peers := peers.NewPeerSet(NewPeerSet())
    97  		blockKeeper := newBlockKeeper(chain, peers, fastSyncDB)
    98  		for _, syncPeer := range c.peers {
    99  			blockKeeper.peers.AddPeer(syncPeer.peer)
   100  			blockKeeper.peers.SetStatus(syncPeer.peer.id, syncPeer.bestHeight, nil)
   101  			blockKeeper.peers.SetJustifiedStatus(syncPeer.peer.id, syncPeer.irreversibleHeight, nil)
   102  		}
   103  		gotType := blockKeeper.checkSyncType()
   104  		if c.syncType != gotType {
   105  			t.Errorf("case %d: got %d want %d", i, gotType, c.syncType)
   106  		}
   107  	}
   108  }
   109  
   110  func TestRegularBlockSync(t *testing.T) {
   111  	if testcontrol.IgnoreTestTemporary {
   112  		return
   113  	}
   114  
   115  	baseChain := mockBlocks(nil, 50)
   116  	chainX := append(baseChain, mockBlocks(baseChain[50], 60)...)
   117  	chainY := append(baseChain, mockBlocks(baseChain[50], 70)...)
   118  	chainZ := append(baseChain, mockBlocks(baseChain[50], 200)...)
   119  	chainE := append(baseChain, mockErrorBlocks(baseChain[50], 200, 60)...)
   120  
   121  	cases := []struct {
   122  		syncTimeout time.Duration
   123  		aBlocks     []*types.Block
   124  		bBlocks     []*types.Block
   125  		want        []*types.Block
   126  		err         error
   127  	}{
   128  		{
   129  			syncTimeout: 30 * time.Second,
   130  			aBlocks:     baseChain[:20],
   131  			bBlocks:     baseChain[:50],
   132  			want:        baseChain[:50],
   133  			err:         nil,
   134  		},
   135  		{
   136  			syncTimeout: 30 * time.Second,
   137  			aBlocks:     chainX,
   138  			bBlocks:     chainY,
   139  			want:        chainY,
   140  			err:         nil,
   141  		},
   142  		{
   143  			syncTimeout: 30 * time.Second,
   144  			aBlocks:     chainX[:52],
   145  			bBlocks:     chainY[:53],
   146  			want:        chainY[:53],
   147  			err:         nil,
   148  		},
   149  		{
   150  			syncTimeout: 30 * time.Second,
   151  			aBlocks:     chainX[:52],
   152  			bBlocks:     chainZ,
   153  			want:        chainZ[:180],
   154  			err:         nil,
   155  		},
   156  		{
   157  			syncTimeout: 0 * time.Second,
   158  			aBlocks:     chainX[:52],
   159  			bBlocks:     chainZ,
   160  			want:        chainX[:52],
   161  			err:         errRequestTimeout,
   162  		},
   163  		{
   164  			syncTimeout: 30 * time.Second,
   165  			aBlocks:     chainX[:52],
   166  			bBlocks:     chainE,
   167  			want:        chainE[:60],
   168  			err:         protocol.ErrBadStateRoot,
   169  		},
   170  	}
   171  	tmp, err := ioutil.TempDir(".", "")
   172  	if err != nil {
   173  		t.Fatalf("failed to create temporary data folder: %v", err)
   174  	}
   175  	testDBA := dbm.NewDB("testdba", "leveldb", tmp)
   176  	testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
   177  	defer func() {
   178  		testDBA.Close()
   179  		testDBB.Close()
   180  		os.RemoveAll(tmp)
   181  	}()
   182  
   183  	for i, c := range cases {
   184  		a := mockSync(c.aBlocks, nil, testDBA)
   185  		b := mockSync(c.bBlocks, nil, testDBB)
   186  		netWork := NewNetWork()
   187  		netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
   188  		netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
   189  		if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
   190  			t.Errorf("fail on peer hands shake %v", err)
   191  		} else {
   192  			go B2A.postMan()
   193  			go A2B.postMan()
   194  		}
   195  
   196  		requireBlockTimeout = c.syncTimeout
   197  		a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
   198  		if err := a.blockKeeper.regularBlockSync(); errors.Root(err) != c.err {
   199  			t.Errorf("case %d: got %v want %v", i, err, c.err)
   200  		}
   201  
   202  		got := []*types.Block{}
   203  		for i := uint64(0); i <= a.chain.BestBlockHeight(); i++ {
   204  			block, err := a.chain.GetBlockByHeight(i)
   205  			if err != nil {
   206  				t.Errorf("case %d got err %v", i, err)
   207  			}
   208  			got = append(got, block)
   209  		}
   210  
   211  		if !testutil.DeepEqual(got, c.want) {
   212  			t.Errorf("case %d: got %v want %v", i, got, c.want)
   213  		}
   214  	}
   215  }
   216  
   217  func TestRequireBlock(t *testing.T) {
   218  	if testcontrol.IgnoreTestTemporary {
   219  		return
   220  	}
   221  
   222  	tmp, err := ioutil.TempDir(".", "")
   223  	if err != nil {
   224  		t.Fatalf("failed to create temporary data folder: %v", err)
   225  	}
   226  	testDBA := dbm.NewDB("testdba", "leveldb", tmp)
   227  	testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
   228  	defer func() {
   229  		testDBB.Close()
   230  		testDBA.Close()
   231  		os.RemoveAll(tmp)
   232  	}()
   233  
   234  	blocks := mockBlocks(nil, 5)
   235  	a := mockSync(blocks[:1], nil, testDBA)
   236  	b := mockSync(blocks[:5], nil, testDBB)
   237  	netWork := NewNetWork()
   238  	netWork.Register(a, "192.168.0.1", "test node A", consensus.SFFullNode)
   239  	netWork.Register(b, "192.168.0.2", "test node B", consensus.SFFullNode)
   240  	if B2A, A2B, err := netWork.HandsShake(a, b); err != nil {
   241  		t.Errorf("fail on peer hands shake %v", err)
   242  	} else {
   243  		go B2A.postMan()
   244  		go A2B.postMan()
   245  	}
   246  
   247  	a.blockKeeper.syncPeer = a.peers.GetPeer("test node B")
   248  	b.blockKeeper.syncPeer = b.peers.GetPeer("test node A")
   249  	cases := []struct {
   250  		syncTimeout   time.Duration
   251  		testNode      *Manager
   252  		requireHeight uint64
   253  		want          *types.Block
   254  		err           error
   255  	}{
   256  		{
   257  			syncTimeout:   30 * time.Second,
   258  			testNode:      a,
   259  			requireHeight: 4,
   260  			want:          blocks[4],
   261  			err:           nil,
   262  		},
   263  		{
   264  			syncTimeout:   1 * time.Millisecond,
   265  			testNode:      b,
   266  			requireHeight: 4,
   267  			want:          nil,
   268  			err:           errRequestTimeout,
   269  		},
   270  	}
   271  
   272  	defer func() {
   273  		requireBlockTimeout = 20 * time.Second
   274  	}()
   275  
   276  	for i, c := range cases {
   277  		requireBlockTimeout = c.syncTimeout
   278  		got, err := c.testNode.blockKeeper.msgFetcher.requireBlock(c.testNode.blockKeeper.syncPeer.ID(), c.requireHeight)
   279  		if !testutil.DeepEqual(got, c.want) {
   280  			t.Errorf("case %d: got %v want %v", i, got, c.want)
   281  		}
   282  		if errors.Root(err) != c.err {
   283  			t.Errorf("case %d: got %v want %v", i, err, c.err)
   284  		}
   285  	}
   286  }
   287  
   288  func TestSendMerkleBlock(t *testing.T) {
   289  	if testcontrol.IgnoreTestTemporary {
   290  		return
   291  	}
   292  
   293  	tmp, err := ioutil.TempDir(".", "")
   294  	if err != nil {
   295  		t.Fatalf("failed to create temporary data folder: %v", err)
   296  	}
   297  
   298  	testDBA := dbm.NewDB("testdba", "leveldb", tmp)
   299  	testDBB := dbm.NewDB("testdbb", "leveldb", tmp)
   300  	defer func() {
   301  		testDBA.Close()
   302  		testDBB.Close()
   303  		os.RemoveAll(tmp)
   304  	}()
   305  
   306  	cases := []struct {
   307  		txCount        int
   308  		relatedTxIndex []int
   309  	}{
   310  		{
   311  			txCount:        10,
   312  			relatedTxIndex: []int{0, 2, 5},
   313  		},
   314  		{
   315  			txCount:        0,
   316  			relatedTxIndex: []int{},
   317  		},
   318  		{
   319  			txCount:        10,
   320  			relatedTxIndex: []int{},
   321  		},
   322  		{
   323  			txCount:        5,
   324  			relatedTxIndex: []int{0, 1, 2, 3, 4},
   325  		},
   326  		{
   327  			txCount:        20,
   328  			relatedTxIndex: []int{1, 6, 3, 9, 10, 19},
   329  		},
   330  	}
   331  
   332  	for _, c := range cases {
   333  		blocks := mockBlocks(nil, 2)
   334  		targetBlock := blocks[1]
   335  		txs, bcTxs := mockTxs(c.txCount)
   336  		var err error
   337  
   338  		targetBlock.Transactions = txs
   339  		if targetBlock.TransactionsMerkleRoot, err = types.TxMerkleRoot(bcTxs); err != nil {
   340  			t.Fatal(err)
   341  		}
   342  
   343  		spvNode := mockSync(blocks, nil, testDBA)
   344  		fullNode := mockSync(blocks, nil, testDBB)
   345  		netWork := NewNetWork()
   346  		netWork.Register(spvNode, "192.168.0.1", "spv_node", consensus.SFFastSync)
   347  		netWork.Register(fullNode, "192.168.0.2", "full_node", consensus.DefaultServices)
   348  
   349  		var F2S *P2PPeer
   350  		if F2S, _, err = netWork.HandsShake(spvNode, fullNode); err != nil {
   351  			t.Errorf("fail on peer hands shake %v", err)
   352  		}
   353  
   354  		completed := make(chan error)
   355  		go func() {
   356  			msgBytes := <-F2S.msgCh
   357  			_, msg, _ := decodeMessage(msgBytes)
   358  			switch m := msg.(type) {
   359  			case *msgs.MerkleBlockMessage:
   360  				var relatedTxIDs []*bc.Hash
   361  				for _, rawTx := range m.RawTxDatas {
   362  					tx := &types.Tx{}
   363  					if err := tx.UnmarshalText(rawTx); err != nil {
   364  						completed <- err
   365  					}
   366  
   367  					relatedTxIDs = append(relatedTxIDs, &tx.ID)
   368  				}
   369  				var txHashes []*bc.Hash
   370  				for _, hashByte := range m.TxHashes {
   371  					hash := bc.NewHash(hashByte)
   372  					txHashes = append(txHashes, &hash)
   373  				}
   374  				if ok := types.ValidateTxMerkleTreeProof(txHashes, m.Flags, relatedTxIDs, targetBlock.TransactionsMerkleRoot); !ok {
   375  					completed <- errors.New("validate tx fail")
   376  				}
   377  				completed <- nil
   378  			}
   379  		}()
   380  
   381  		spvPeer := fullNode.peers.GetPeer("spv_node")
   382  		for i := 0; i < len(c.relatedTxIndex); i++ {
   383  			spvPeer.AddFilterAddress(txs[c.relatedTxIndex[i]].Outputs[0].ControlProgram)
   384  		}
   385  		msg := &msgs.GetMerkleBlockMessage{RawHash: targetBlock.Hash().Byte32()}
   386  		fullNode.handleGetMerkleBlockMsg(spvPeer, msg)
   387  		if err := <-completed; err != nil {
   388  			t.Fatal(err)
   389  		}
   390  	}
   391  }
   392  
   393  func TestLocateBlocks(t *testing.T) {
   394  	if testcontrol.IgnoreTestTemporary {
   395  		return
   396  	}
   397  
   398  	maxNumOfBlocksPerMsg = 5
   399  	blocks := mockBlocks(nil, 100)
   400  	cases := []struct {
   401  		locator    []uint64
   402  		stopHash   bc.Hash
   403  		wantHeight []uint64
   404  		wantErr    error
   405  	}{
   406  		{
   407  			locator:    []uint64{20},
   408  			stopHash:   blocks[100].Hash(),
   409  			wantHeight: []uint64{20, 21, 22, 23, 24},
   410  			wantErr:    nil,
   411  		},
   412  		{
   413  			locator:    []uint64{20},
   414  			stopHash:   bc.NewHash([32]byte{0x01, 0x02}),
   415  			wantHeight: []uint64{},
   416  			wantErr:    mock.ErrFoundHeaderByHash,
   417  		},
   418  	}
   419  
   420  	mockChain := mock.NewChain()
   421  	bk := &blockKeeper{chain: mockChain}
   422  	for _, block := range blocks {
   423  		mockChain.SetBlockByHeight(block.Height, block)
   424  	}
   425  
   426  	for i, c := range cases {
   427  		locator := []*bc.Hash{}
   428  		for _, i := range c.locator {
   429  			hash := blocks[i].Hash()
   430  			locator = append(locator, &hash)
   431  		}
   432  
   433  		want := []*types.Block{}
   434  		for _, i := range c.wantHeight {
   435  			want = append(want, blocks[i])
   436  		}
   437  
   438  		mockTimeout := func() bool { return false }
   439  		got, err := bk.locateBlocks(locator, &c.stopHash, mockTimeout)
   440  		if err != c.wantErr {
   441  			t.Errorf("case %d: got %v want err = %v", i, err, c.wantErr)
   442  		}
   443  
   444  		if !testutil.DeepEqual(got, want) {
   445  			t.Errorf("case %d: got %v want %v", i, got, want)
   446  		}
   447  	}
   448  }
   449  
   450  func TestLocateHeaders(t *testing.T) {
   451  	if testcontrol.IgnoreTestTemporary {
   452  		return
   453  	}
   454  
   455  	defer func() {
   456  		maxNumOfHeadersPerMsg = 1000
   457  	}()
   458  	maxNumOfHeadersPerMsg = 10
   459  	blocks := mockBlocks(nil, 150)
   460  	blocksHash := []bc.Hash{}
   461  	for _, block := range blocks {
   462  		blocksHash = append(blocksHash, block.Hash())
   463  	}
   464  
   465  	cases := []struct {
   466  		chainHeight uint64
   467  		locator     []uint64
   468  		stopHash    *bc.Hash
   469  		skip        uint64
   470  		wantHeight  []uint64
   471  		err         error
   472  	}{
   473  		{
   474  			chainHeight: 100,
   475  			locator:     []uint64{90},
   476  			stopHash:    &blocksHash[100],
   477  			skip:        0,
   478  			wantHeight:  []uint64{90, 91, 92, 93, 94, 95, 96, 97, 98, 99},
   479  			err:         nil,
   480  		},
   481  		{
   482  			chainHeight: 100,
   483  			locator:     []uint64{20},
   484  			stopHash:    &blocksHash[24],
   485  			skip:        0,
   486  			wantHeight:  []uint64{20, 21, 22, 23, 24},
   487  			err:         nil,
   488  		},
   489  		{
   490  			chainHeight: 100,
   491  			locator:     []uint64{20},
   492  			stopHash:    &blocksHash[20],
   493  			wantHeight:  []uint64{20},
   494  			err:         nil,
   495  		},
   496  		{
   497  			chainHeight: 100,
   498  			locator:     []uint64{20},
   499  			stopHash:    &blocksHash[120],
   500  			wantHeight:  []uint64{},
   501  			err:         mock.ErrFoundHeaderByHash,
   502  		},
   503  		{
   504  			chainHeight: 100,
   505  			locator:     []uint64{120, 70},
   506  			stopHash:    &blocksHash[78],
   507  			wantHeight:  []uint64{70, 71, 72, 73, 74, 75, 76, 77, 78},
   508  			err:         nil,
   509  		},
   510  		{
   511  			chainHeight: 100,
   512  			locator:     []uint64{15},
   513  			stopHash:    &blocksHash[10],
   514  			skip:        10,
   515  			wantHeight:  []uint64{},
   516  			err:         nil,
   517  		},
   518  		{
   519  			chainHeight: 100,
   520  			locator:     []uint64{15},
   521  			stopHash:    &blocksHash[80],
   522  			skip:        10,
   523  			wantHeight:  []uint64{15, 26, 37, 48, 59, 70, 80},
   524  			err:         nil,
   525  		},
   526  		{
   527  			chainHeight: 100,
   528  			locator:     []uint64{0},
   529  			stopHash:    &blocksHash[100],
   530  			skip:        9,
   531  			wantHeight:  []uint64{0, 10, 20, 30, 40, 50, 60, 70, 80, 90},
   532  			err:         nil,
   533  		},
   534  	}
   535  
   536  	for i, c := range cases {
   537  		mockChain := mock.NewChain()
   538  		bk := &blockKeeper{chain: mockChain}
   539  		for i := uint64(0); i <= c.chainHeight; i++ {
   540  			mockChain.SetBlockByHeight(i, blocks[i])
   541  		}
   542  
   543  		locator := []*bc.Hash{}
   544  		for _, i := range c.locator {
   545  			hash := blocks[i].Hash()
   546  			locator = append(locator, &hash)
   547  		}
   548  
   549  		want := []*types.BlockHeader{}
   550  		for _, i := range c.wantHeight {
   551  			want = append(want, &blocks[i].BlockHeader)
   552  		}
   553  
   554  		got, err := bk.locateHeaders(locator, c.stopHash, c.skip, maxNumOfHeadersPerMsg)
   555  		if err != c.err {
   556  			t.Errorf("case %d: got %v want err = %v", i, err, c.err)
   557  		}
   558  		if !testutil.DeepEqual(got, want) {
   559  			t.Errorf("case %d: got %v want %v", i, got, want)
   560  		}
   561  	}
   562  }