github.com/klaytn/klaytn@v1.12.1/node/cn/handler_msg_test.go (about)

     1  // Copyright 2019 The klaytn Authors
     2  // This file is part of the klaytn library.
     3  //
     4  // The klaytn library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The klaytn library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the klaytn library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package cn
    18  
    19  import (
    20  	"errors"
    21  	"math/big"
    22  	"strings"
    23  	"sync/atomic"
    24  	"testing"
    25  
    26  	"github.com/golang/mock/gomock"
    27  	"github.com/klaytn/klaytn/blockchain/types"
    28  	"github.com/klaytn/klaytn/common"
    29  	"github.com/klaytn/klaytn/consensus/istanbul"
    30  	"github.com/klaytn/klaytn/networks/p2p"
    31  	mocks2 "github.com/klaytn/klaytn/node/cn/mocks"
    32  	"github.com/klaytn/klaytn/params"
    33  	"github.com/klaytn/klaytn/reward"
    34  	"github.com/klaytn/klaytn/rlp"
    35  	"github.com/klaytn/klaytn/work/mocks"
    36  	"github.com/stretchr/testify/assert"
    37  )
    38  
    39  var expectedErr = errors.New("some error")
    40  
    41  // generateMsg creates a message struct for message handling tests.
    42  func generateMsg(t *testing.T, msgCode uint64, data interface{}) p2p.Msg {
    43  	size, r, err := rlp.EncodeToReader(data)
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  	return p2p.Msg{
    48  		Code:    msgCode,
    49  		Size:    uint32(size),
    50  		Payload: r,
    51  	}
    52  }
    53  
    54  // prepareTestHandleNewBlockMsg creates structs for TestHandleNewBlockMsg_ tests.
    55  func prepareTestHandleNewBlockMsg(t *testing.T, mockCtrl *gomock.Controller, blockNum int) (*types.Block, p2p.Msg, *MockPeer, *mocks2.MockProtocolManagerFetcher) {
    56  	mockPeer := NewMockPeer(mockCtrl)
    57  
    58  	newBlock := newBlock(blockNum)
    59  	newBlock.ReceivedFrom = mockPeer
    60  	msg := generateMsg(t, NewBlockMsg, newBlockData{Block: newBlock, TD: big.NewInt(int64(blockNum))})
    61  
    62  	mockPeer.EXPECT().AddToKnownBlocks(newBlock.Hash()).Times(1)
    63  	mockPeer.EXPECT().GetID().Return(nodeids[0].String()).AnyTimes()
    64  
    65  	mockFetcher := mocks2.NewMockProtocolManagerFetcher(mockCtrl)
    66  	mockFetcher.EXPECT().Enqueue(nodeids[0].String(), newBlock).Times(1)
    67  
    68  	return newBlock, msg, mockPeer, mockFetcher
    69  }
    70  
    71  func prepareDownloader(t *testing.T) (*gomock.Controller, *mocks2.MockProtocolManagerDownloader, *MockPeer, *ProtocolManager) {
    72  	mockCtrl := gomock.NewController(t)
    73  	mockDownloader := mocks2.NewMockProtocolManagerDownloader(mockCtrl)
    74  
    75  	mockPeer := NewMockPeer(mockCtrl)
    76  	mockPeer.EXPECT().GetID().Return(nodeids[0].String()).AnyTimes()
    77  
    78  	pm := &ProtocolManager{downloader: mockDownloader}
    79  
    80  	return mockCtrl, mockDownloader, mockPeer, pm
    81  }
    82  
    83  func TestHandleBlockHeadersMsg(t *testing.T) {
    84  	headers := []*types.Header{blocks[0].Header(), blocks[1].Header()}
    85  	{
    86  		mockCtrl, _, mockPeer, pm := prepareDownloader(t)
    87  		msg := generateMsg(t, BlockHeadersMsg, blocks[0].Header())
    88  
    89  		assert.Error(t, handleBlockHeadersMsg(pm, mockPeer, msg))
    90  		mockCtrl.Finish()
    91  	}
    92  	{
    93  		mockCtrl, mockDownloader, mockPeer, pm := prepareDownloader(t)
    94  		msg := generateMsg(t, BlockHeadersMsg, headers)
    95  		mockDownloader.EXPECT().DeliverHeaders(nodeids[0].String(), gomock.Eq(headers)).Return(expectedErr).Times(1)
    96  
    97  		assert.NoError(t, handleBlockHeadersMsg(pm, mockPeer, msg))
    98  		mockCtrl.Finish()
    99  	}
   100  	{
   101  		mockCtrl, mockDownloader, mockPeer, pm := prepareDownloader(t)
   102  		msg := generateMsg(t, BlockHeadersMsg, headers)
   103  		mockDownloader.EXPECT().DeliverHeaders(nodeids[0].String(), gomock.Eq(headers)).Return(nil).Times(1)
   104  
   105  		assert.NoError(t, handleBlockHeadersMsg(pm, mockPeer, msg))
   106  		mockCtrl.Finish()
   107  	}
   108  }
   109  
   110  func prepareBlockChain(t *testing.T) (*gomock.Controller, *mocks.MockBlockChain, *MockPeer, *ProtocolManager) {
   111  	mockCtrl := gomock.NewController(t)
   112  	mockBlockChain := mocks.NewMockBlockChain(mockCtrl)
   113  
   114  	mockPeer := NewMockPeer(mockCtrl)
   115  	mockPeer.EXPECT().GetID().Return(nodeids[0].String()).AnyTimes()
   116  
   117  	pm := &ProtocolManager{blockchain: mockBlockChain}
   118  
   119  	return mockCtrl, mockBlockChain, mockPeer, pm
   120  }
   121  
   122  func TestHandleBlockBodiesRequestMsg(t *testing.T) {
   123  	{
   124  		mockCtrl, _, mockPeer, pm := prepareBlockChain(t)
   125  		msg := generateMsg(t, BlockBodiesRequestMsg, uint64(123)) // Non-list value to invoke an error
   126  
   127  		bodies, err := handleBlockBodiesRequest(pm, mockPeer, msg)
   128  		assert.Nil(t, bodies)
   129  		assert.Error(t, err)
   130  		mockCtrl.Finish()
   131  	}
   132  	{
   133  		requestedHashes := []common.Hash{hashes[0], hashes[1]}
   134  		returnedData := []rlp.RawValue{hashes[1][:], hashes[0][:]}
   135  
   136  		mockCtrl, mockBlockChain, mockPeer, pm := prepareBlockChain(t)
   137  		msg := generateMsg(t, BlockBodiesRequestMsg, requestedHashes)
   138  
   139  		mockBlockChain.EXPECT().GetBodyRLP(gomock.Eq(hashes[0])).Return(returnedData[0]).Times(1)
   140  		mockBlockChain.EXPECT().GetBodyRLP(gomock.Eq(hashes[1])).Return(returnedData[1]).Times(1)
   141  
   142  		bodies, err := handleBlockBodiesRequest(pm, mockPeer, msg)
   143  		assert.Equal(t, returnedData, bodies)
   144  		assert.NoError(t, err)
   145  		mockCtrl.Finish()
   146  	}
   147  }
   148  
   149  func TestHandleBlockBodiesMsg(t *testing.T) {
   150  	{
   151  		mockCtrl, _, mockPeer, pm := prepareDownloader(t)
   152  		msg := generateMsg(t, BlockBodiesMsg, blocks[0].Header())
   153  
   154  		assert.Error(t, handleBlockBodiesMsg(pm, mockPeer, msg))
   155  		mockCtrl.Finish()
   156  	}
   157  }
   158  
   159  func TestNodeDataRequestMsg(t *testing.T) {
   160  	{
   161  		mockCtrl, _, mockPeer, pm := prepareBlockChain(t)
   162  		msg := generateMsg(t, NodeDataRequestMsg, uint64(123)) // Non-list value to invoke an error
   163  
   164  		mockPeer.EXPECT().GetVersion().Return(klay63).AnyTimes()
   165  		assert.Error(t, pm.handleMsg(mockPeer, addrs[0], msg))
   166  		mockCtrl.Finish()
   167  	}
   168  	{
   169  		requestedHashes := []common.Hash{hashes[0], hashes[1]}
   170  		returnedData := [][]byte{hashes[1][:], hashes[0][:]}
   171  
   172  		mockCtrl, mockBlockChain, mockPeer, pm := prepareBlockChain(t)
   173  		msg := generateMsg(t, NodeDataRequestMsg, requestedHashes)
   174  
   175  		mockBlockChain.EXPECT().TrieNode(gomock.Eq(hashes[0])).Return(returnedData[0], nil).Times(1)
   176  		mockBlockChain.EXPECT().TrieNode(gomock.Eq(hashes[1])).Return(returnedData[1], nil).Times(1)
   177  
   178  		mockPeer.EXPECT().SendNodeData(returnedData).Return(nil).Times(1)
   179  
   180  		mockPeer.EXPECT().GetVersion().Return(klay63).AnyTimes()
   181  		assert.NoError(t, pm.handleMsg(mockPeer, addrs[0], msg))
   182  		mockCtrl.Finish()
   183  	}
   184  }
   185  
   186  func TestHandleReceiptsRequestMsg(t *testing.T) {
   187  	{
   188  		mockCtrl, _, mockPeer, pm := prepareBlockChain(t)
   189  		msg := generateMsg(t, ReceiptsRequestMsg, uint64(123)) // Non-list value to invoke an error
   190  
   191  		mockPeer.EXPECT().GetVersion().Return(klay63).AnyTimes()
   192  		assert.Error(t, pm.handleMsg(mockPeer, addrs[0], msg))
   193  		mockCtrl.Finish()
   194  	}
   195  	{
   196  		requestedHashes := []common.Hash{hashes[0], hashes[1]}
   197  
   198  		rct1 := newReceipt(123)
   199  
   200  		mockCtrl, mockBlockChain, mockPeer, pm := prepareBlockChain(t)
   201  		msg := generateMsg(t, ReceiptsRequestMsg, requestedHashes)
   202  
   203  		mockBlockChain.EXPECT().GetReceiptsByBlockHash(gomock.Eq(hashes[0])).Return(types.Receipts{rct1}).Times(1)
   204  		mockBlockChain.EXPECT().GetReceiptsByBlockHash(gomock.Eq(hashes[1])).Return(nil).Times(1)
   205  		mockBlockChain.EXPECT().GetHeaderByHash(gomock.Eq(hashes[1])).Return(nil).Times(1)
   206  
   207  		mockPeer.EXPECT().SendReceiptsRLP(gomock.Any()).Return(nil).Times(1)
   208  
   209  		mockPeer.EXPECT().GetVersion().Return(klay63).AnyTimes()
   210  		assert.NoError(t, pm.handleMsg(mockPeer, addrs[0], msg))
   211  		mockCtrl.Finish()
   212  	}
   213  }
   214  
   215  func TestHandleNewBlockMsg_LargeLocalPeerBlockScore(t *testing.T) {
   216  	mockCtrl := gomock.NewController(t)
   217  	defer mockCtrl.Finish()
   218  
   219  	_, msg, mockPeer, mockFetcher := prepareTestHandleNewBlockMsg(t, mockCtrl, blockNum1)
   220  
   221  	pm := &ProtocolManager{}
   222  	pm.fetcher = mockFetcher
   223  
   224  	mockPeer.EXPECT().Head().Return(hash1, big.NewInt(blockNum1+1)).AnyTimes()
   225  
   226  	assert.NoError(t, handleNewBlockMsg(pm, mockPeer, msg))
   227  }
   228  
   229  func TestHandleNewBlockMsg_SmallLocalPeerBlockScore_NoSynchronise(t *testing.T) {
   230  	mockCtrl := gomock.NewController(t)
   231  	defer mockCtrl.Finish()
   232  
   233  	block, msg, mockPeer, mockFetcher := prepareTestHandleNewBlockMsg(t, mockCtrl, blockNum1)
   234  
   235  	pm := &ProtocolManager{}
   236  	pm.fetcher = mockFetcher
   237  
   238  	mockPeer.EXPECT().Head().Return(hash1, big.NewInt(blockNum1-2)).AnyTimes()
   239  	mockPeer.EXPECT().SetHead(block.ParentHash(), big.NewInt(blockNum1-1)).Times(1)
   240  
   241  	currBlock := newBlock(blockNum1 - 1)
   242  	mockBlockChain := mocks.NewMockBlockChain(mockCtrl)
   243  	mockBlockChain.EXPECT().CurrentBlock().Return(currBlock).Times(1)
   244  	mockBlockChain.EXPECT().GetTd(currBlock.Hash(), currBlock.NumberU64()).Return(big.NewInt(blockNum1)).Times(1)
   245  
   246  	pm.blockchain = mockBlockChain
   247  
   248  	assert.NoError(t, handleNewBlockMsg(pm, mockPeer, msg))
   249  }
   250  
   251  func TestHandleTxMsg(t *testing.T) {
   252  	pm := &ProtocolManager{}
   253  	mockCtrl := gomock.NewController(t)
   254  	defer mockCtrl.Finish()
   255  	mockPeer := NewMockPeer(mockCtrl)
   256  	mockPeer.EXPECT().GetVersion().Return(klay63).AnyTimes()
   257  
   258  	txs := types.Transactions{tx1}
   259  	msg := generateMsg(t, TxMsg, txs)
   260  
   261  	// If pm.acceptTxs == 0, nothing happens.
   262  	{
   263  		assert.NoError(t, pm.handleMsg(mockPeer, addrs[0], msg))
   264  	}
   265  	// If pm.acceptTxs == 1, TxPool.HandleTxMsg is called.
   266  	{
   267  		atomic.StoreUint32(&pm.acceptTxs, 1)
   268  		mockTxPool := mocks.NewMockTxPool(mockCtrl)
   269  
   270  		// The time field in received transaction through pm.handleMsg() has different value from generated transaction(`tx1`).
   271  		// It can check whether the transaction created `HandleTxMsg()` is the same as `tx1` through `AddToKnownTxs(txs[0].Hash())`.
   272  		mockTxPool.EXPECT().HandleTxMsg(gomock.Any()).AnyTimes()
   273  		pm.txpool = mockTxPool
   274  
   275  		mockPeer.EXPECT().AddToKnownTxs(txs[0].Hash()).Times(1)
   276  		assert.NoError(t, pm.handleMsg(mockPeer, addrs[0], msg))
   277  	}
   278  }
   279  
   280  func prepareTestHandleBlockHeaderFetchRequestMsg(t *testing.T) (*gomock.Controller, *MockPeer, *mocks.MockBlockChain, *ProtocolManager) {
   281  	mockCtrl := gomock.NewController(t)
   282  	mockPeer := NewMockPeer(mockCtrl)
   283  	mockBlockChain := mocks.NewMockBlockChain(mockCtrl)
   284  	mockPeer.EXPECT().GetVersion().Return(klay63).AnyTimes()
   285  
   286  	return mockCtrl, mockPeer, mockBlockChain, &ProtocolManager{blockchain: mockBlockChain}
   287  }
   288  
   289  func TestHandleBlockHeaderFetchRequestMsg(t *testing.T) {
   290  	// Decoding the message failed, an error is returned.
   291  	{
   292  		mockCtrl, mockPeer, _, pm := prepareTestHandleBlockHeaderFetchRequestMsg(t)
   293  
   294  		msg := generateMsg(t, BlockHeaderFetchRequestMsg, newBlock(blockNum1)) // use message data as a block, not a hash
   295  
   296  		assert.Error(t, pm.handleMsg(mockPeer, addrs[0], msg))
   297  		mockCtrl.Finish()
   298  	}
   299  	// GetHeaderByHash returns nil, an error is returned.
   300  	{
   301  		mockCtrl, mockPeer, mockBlockChain, pm := prepareTestHandleBlockHeaderFetchRequestMsg(t)
   302  		mockBlockChain.EXPECT().GetHeaderByHash(hash1).Return(nil).AnyTimes()
   303  		mockPeer.EXPECT().GetID().Return(nodeids[0].String()).AnyTimes()
   304  
   305  		msg := generateMsg(t, BlockHeaderFetchRequestMsg, hash1)
   306  
   307  		assert.Error(t, pm.handleMsg(mockPeer, addrs[0], msg))
   308  		mockCtrl.Finish()
   309  	}
   310  	// GetHeaderByHash returns a header, p.SendFetchedBlockHeader(header) should be called.
   311  	{
   312  		mockCtrl, mockPeer, mockBlockChain, pm := prepareTestHandleBlockHeaderFetchRequestMsg(t)
   313  
   314  		header := newBlock(blockNum1).Header()
   315  
   316  		mockBlockChain.EXPECT().GetHeaderByHash(hash1).Return(header).AnyTimes()
   317  		mockPeer.EXPECT().SendFetchedBlockHeader(header).AnyTimes()
   318  
   319  		msg := generateMsg(t, BlockHeaderFetchRequestMsg, hash1)
   320  		assert.NoError(t, pm.handleMsg(mockPeer, addrs[0], msg))
   321  		mockCtrl.Finish()
   322  	}
   323  }
   324  
   325  func prepareTestHandleBlockHeaderFetchResponseMsg(t *testing.T) (*gomock.Controller, *MockPeer, *mocks2.MockProtocolManagerFetcher, *ProtocolManager) {
   326  	mockCtrl := gomock.NewController(t)
   327  	mockPeer := NewMockPeer(mockCtrl)
   328  	mockPeer.EXPECT().GetVersion().Return(klay63).AnyTimes()
   329  
   330  	mockFetcher := mocks2.NewMockProtocolManagerFetcher(mockCtrl)
   331  	pm := &ProtocolManager{fetcher: mockFetcher}
   332  
   333  	return mockCtrl, mockPeer, mockFetcher, pm
   334  }
   335  
   336  func TestHandleBlockHeaderFetchResponseMsg(t *testing.T) {
   337  	header := newBlock(blockNum1).Header()
   338  	// Decoding the message failed, an error is returned.
   339  	{
   340  		mockCtrl := gomock.NewController(t)
   341  		mockPeer := NewMockPeer(mockCtrl)
   342  		mockPeer.EXPECT().GetVersion().Return(klay63).AnyTimes()
   343  		pm := &ProtocolManager{}
   344  		msg := generateMsg(t, BlockHeaderFetchResponseMsg, newBlock(blockNum1)) // use message data as a block, not a header
   345  		assert.Error(t, pm.handleMsg(mockPeer, addrs[0], msg))
   346  		mockCtrl.Finish()
   347  	}
   348  	// FilterHeaders returns nil, error is not returned.
   349  	{
   350  		mockCtrl, mockPeer, mockFetcher, pm := prepareTestHandleBlockHeaderFetchResponseMsg(t)
   351  		mockPeer.EXPECT().GetID().Return(nodeids[0].String()).AnyTimes()
   352  		mockFetcher.EXPECT().FilterHeaders(nodeids[0].String(), gomock.Eq([]*types.Header{header}), gomock.Any()).Return(nil).AnyTimes()
   353  
   354  		msg := generateMsg(t, BlockHeaderFetchResponseMsg, header)
   355  		assert.NoError(t, pm.handleMsg(mockPeer, addrs[0], msg))
   356  		mockCtrl.Finish()
   357  	}
   358  	// FilterHeaders returns not-nil, peer.GetID() is called twice to leave a log.
   359  	{
   360  		mockCtrl, mockPeer, mockFetcher, pm := prepareTestHandleBlockHeaderFetchResponseMsg(t)
   361  		mockPeer.EXPECT().GetID().Return(nodeids[0].String()).AnyTimes()
   362  		mockFetcher.EXPECT().FilterHeaders(nodeids[0].String(), gomock.Eq([]*types.Header{header}), gomock.Any()).Return([]*types.Header{header}).AnyTimes()
   363  
   364  		msg := generateMsg(t, BlockHeaderFetchResponseMsg, header)
   365  		assert.NoError(t, pm.handleMsg(mockPeer, addrs[0], msg))
   366  		mockCtrl.Finish()
   367  	}
   368  }
   369  
   370  func preparePeerAndDownloader(t *testing.T) (*gomock.Controller, *MockPeer, *mocks2.MockProtocolManagerDownloader, *ProtocolManager) {
   371  	mockCtrl := gomock.NewController(t)
   372  	mockPeer := NewMockPeer(mockCtrl)
   373  	mockPeer.EXPECT().GetID().Return(nodeids[0].String()).AnyTimes()
   374  	mockPeer.EXPECT().GetVersion().Return(klay63).AnyTimes()
   375  
   376  	mockDownloader := mocks2.NewMockProtocolManagerDownloader(mockCtrl)
   377  	pm := &ProtocolManager{downloader: mockDownloader}
   378  
   379  	return mockCtrl, mockPeer, mockDownloader, pm
   380  }
   381  
   382  func TestHandleReceiptMsg(t *testing.T) {
   383  	// Decoding the message failed, an error is returned.
   384  	{
   385  		mockCtrl := gomock.NewController(t)
   386  		mockPeer := NewMockPeer(mockCtrl)
   387  		mockPeer.EXPECT().GetVersion().Return(klay63).AnyTimes()
   388  
   389  		pm := &ProtocolManager{}
   390  		msg := generateMsg(t, ReceiptsMsg, newBlock(blockNum1)) // use message data as a block, not a header
   391  		assert.Error(t, pm.handleMsg(mockPeer, addrs[0], msg))
   392  		mockCtrl.Finish()
   393  	}
   394  	// DeliverReceipts returns nil, error is not returned.
   395  	{
   396  		receipts := make([][]*types.Receipt, 1)
   397  		receipts[0] = []*types.Receipt{newReceipt(123)}
   398  
   399  		mockCtrl, mockPeer, mockDownloader, pm := preparePeerAndDownloader(t)
   400  		mockDownloader.EXPECT().DeliverReceipts(nodeids[0].String(), gomock.Eq(receipts)).Times(1).Return(nil)
   401  
   402  		msg := generateMsg(t, ReceiptsMsg, receipts)
   403  		assert.NoError(t, pm.handleMsg(mockPeer, addrs[0], msg))
   404  		mockCtrl.Finish()
   405  	}
   406  	// DeliverReceipts returns an error, but the error is not returned.
   407  	{
   408  		receipts := make([][]*types.Receipt, 1)
   409  		receipts[0] = []*types.Receipt{newReceipt(123)}
   410  
   411  		mockCtrl, mockPeer, mockDownloader, pm := preparePeerAndDownloader(t)
   412  		mockDownloader.EXPECT().DeliverReceipts(nodeids[0].String(), gomock.Eq(receipts)).Times(1).Return(expectedErr)
   413  
   414  		msg := generateMsg(t, ReceiptsMsg, receipts)
   415  		assert.NoError(t, pm.handleMsg(mockPeer, addrs[0], msg))
   416  		mockCtrl.Finish()
   417  	}
   418  }
   419  
   420  func TestHandleNodeDataMsg(t *testing.T) {
   421  	// Decoding the message failed, an error is returned.
   422  	{
   423  		mockCtrl := gomock.NewController(t)
   424  		mockPeer := NewMockPeer(mockCtrl)
   425  		mockPeer.EXPECT().GetVersion().Return(klay63).AnyTimes()
   426  		pm := &ProtocolManager{}
   427  		msg := generateMsg(t, NodeDataMsg, newBlock(blockNum1)) // use message data as a block, not a node data
   428  		assert.Error(t, pm.handleMsg(mockPeer, addrs[0], msg))
   429  		mockCtrl.Finish()
   430  	}
   431  	// DeliverNodeData returns nil, error is not returned.
   432  	{
   433  		nodeData := make([][]byte, 1)
   434  		nodeData[0] = hash1[:]
   435  
   436  		mockCtrl, mockPeer, mockDownloader, pm := preparePeerAndDownloader(t)
   437  		mockDownloader.EXPECT().DeliverNodeData(nodeids[0].String(), gomock.Eq(nodeData)).Times(1).Return(nil)
   438  
   439  		msg := generateMsg(t, NodeDataMsg, nodeData)
   440  		assert.NoError(t, pm.handleMsg(mockPeer, addrs[0], msg))
   441  		mockCtrl.Finish()
   442  	}
   443  	// DeliverNodeData returns an error, but the error is not returned.
   444  	{
   445  		nodeData := make([][]byte, 1)
   446  		nodeData[0] = hash1[:]
   447  
   448  		mockCtrl, mockPeer, mockDownloader, pm := preparePeerAndDownloader(t)
   449  		mockDownloader.EXPECT().DeliverNodeData(nodeids[0].String(), gomock.Eq(nodeData)).Times(1).Return(expectedErr)
   450  
   451  		msg := generateMsg(t, NodeDataMsg, nodeData)
   452  		assert.NoError(t, pm.handleMsg(mockPeer, addrs[0], msg))
   453  		mockCtrl.Finish()
   454  	}
   455  }
   456  
   457  func TestHandleStakingInfoRequestMsg(t *testing.T) {
   458  	testChainConfig := params.TestChainConfig
   459  
   460  	{
   461  		// test if chain config istanbul is nil
   462  		mockCtrl, _, mockPeer, pm := prepareBlockChain(t)
   463  		testChainConfig.Istanbul = nil
   464  		pm.chainconfig = testChainConfig
   465  
   466  		err := handleStakingInfoRequestMsg(pm, mockPeer, p2p.Msg{})
   467  		assert.Error(t, err)
   468  		assert.Equal(t, err, errResp(ErrUnsupportedEnginePolicy, "the engine is not istanbul or the policy is not weighted random"))
   469  		mockCtrl.Finish()
   470  	}
   471  	{
   472  		// test if chain config istanbul is not nil, but proposer policy is not weighted random
   473  		mockCtrl, _, mockPeer, pm := prepareBlockChain(t)
   474  		testChainConfig.Istanbul = params.GetDefaultIstanbulConfig()
   475  		testChainConfig.Istanbul.ProposerPolicy = uint64(istanbul.RoundRobin)
   476  		pm.chainconfig = testChainConfig
   477  
   478  		err := handleStakingInfoRequestMsg(pm, mockPeer, p2p.Msg{})
   479  		assert.Error(t, err)
   480  		assert.Equal(t, err, errResp(ErrUnsupportedEnginePolicy, "the engine is not istanbul or the policy is not weighted random"))
   481  		mockCtrl.Finish()
   482  	}
   483  	{
   484  		// test if message does not contain expected data
   485  		mockCtrl, _, mockPeer, pm := prepareBlockChain(t)
   486  		testChainConfig.Istanbul = params.GetDefaultIstanbulConfig()
   487  		testChainConfig.Istanbul.ProposerPolicy = uint64(istanbul.WeightedRandom)
   488  		pm.chainconfig = testChainConfig
   489  		msg := generateMsg(t, StakingInfoRequestMsg, uint64(123)) // Non-list value to invoke an error
   490  
   491  		err := handleStakingInfoRequestMsg(pm, mockPeer, msg)
   492  		assert.Error(t, err)
   493  		assert.Equal(t, err, rlp.ErrExpectedList)
   494  		mockCtrl.Finish()
   495  	}
   496  
   497  	// Setup governance items for testing
   498  	orig := reward.GetStakingManager()
   499  	defer reward.SetTestStakingManager(orig)
   500  
   501  	testBlock := uint64(4)
   502  	testStakingInfo := newStakingInfo(testBlock)
   503  	reward.SetTestStakingManagerWithStakingInfoCache(testStakingInfo)
   504  	params.SetStakingUpdateInterval(testBlock)
   505  
   506  	{
   507  		requestedHashes := []common.Hash{hashes[0], hashes[1]}
   508  
   509  		mockCtrl, mockBlockChain, mockPeer, pm := prepareBlockChain(t)
   510  		testChainConfig.Istanbul = &params.IstanbulConfig{ProposerPolicy: uint64(istanbul.WeightedRandom)}
   511  		pm.chainconfig = testChainConfig
   512  
   513  		msg := generateMsg(t, StakingInfoRequestMsg, requestedHashes)
   514  
   515  		mockBlockChain.EXPECT().GetHeaderByHash(gomock.Eq(hashes[0])).Return(&types.Header{Number: big.NewInt(int64(testBlock))}).Times(1)
   516  		mockBlockChain.EXPECT().GetHeaderByHash(gomock.Eq(hashes[1])).Return(&types.Header{Number: big.NewInt(int64(5))}).Times(1) // not on staking block
   517  		data, _ := rlp.EncodeToBytes(testStakingInfo)
   518  		mockPeer.EXPECT().SendStakingInfoRLP(gomock.Eq([]rlp.RawValue{data})).Return(nil).Times(1)
   519  
   520  		err := handleStakingInfoRequestMsg(pm, mockPeer, msg)
   521  		assert.NoError(t, err)
   522  		mockCtrl.Finish()
   523  	}
   524  }
   525  
   526  func TestHandleStakingInfoMsg(t *testing.T) {
   527  	testChainConfig := params.TestChainConfig
   528  	{
   529  		// test if chain config istanbul is nil
   530  		mockCtrl, _, mockPeer, pm := prepareBlockChain(t)
   531  		testChainConfig.Istanbul = nil
   532  		pm.chainconfig = testChainConfig
   533  
   534  		err := handleStakingInfoMsg(pm, mockPeer, p2p.Msg{})
   535  		assert.Error(t, err)
   536  		assert.Equal(t, err, errResp(ErrUnsupportedEnginePolicy, "the engine is not istanbul or the policy is not weighted random"))
   537  		mockCtrl.Finish()
   538  	}
   539  	{
   540  		// test if chain config istanbul is not nil, but proposer policy is not weighted random
   541  		mockCtrl, _, mockPeer, pm := prepareBlockChain(t)
   542  		testChainConfig.Istanbul = params.GetDefaultIstanbulConfig()
   543  		testChainConfig.Istanbul.ProposerPolicy = uint64(istanbul.RoundRobin)
   544  		pm.chainconfig = testChainConfig
   545  
   546  		err := handleStakingInfoMsg(pm, mockPeer, p2p.Msg{})
   547  		assert.Error(t, err)
   548  		assert.Equal(t, err, errResp(ErrUnsupportedEnginePolicy, "the engine is not istanbul or the policy is not weighted random"))
   549  		mockCtrl.Finish()
   550  	}
   551  	{
   552  		// test if message does not contain expected data
   553  		mockCtrl, _, mockPeer, pm := prepareBlockChain(t)
   554  		testChainConfig.Istanbul = params.GetDefaultIstanbulConfig()
   555  		testChainConfig.Istanbul.ProposerPolicy = uint64(istanbul.WeightedRandom)
   556  		pm.chainconfig = testChainConfig
   557  		msg := generateMsg(t, StakingInfoRequestMsg, uint64(123)) // Non-list value to invoke an error
   558  
   559  		err := handleStakingInfoMsg(pm, mockPeer, msg)
   560  		assert.Error(t, err)
   561  		assert.True(t, strings.Contains(err.Error(), errCode(ErrDecode).String()))
   562  		mockCtrl.Finish()
   563  	}
   564  
   565  	// Setup governance items for testing
   566  	orig := reward.GetStakingManager()
   567  	defer reward.SetTestStakingManager(orig)
   568  
   569  	testBlock := uint64(4)
   570  	testStakingInfo := newStakingInfo(testBlock)
   571  	reward.SetTestStakingManagerWithStakingInfoCache(testStakingInfo)
   572  	params.SetStakingUpdateInterval(testBlock)
   573  
   574  	{
   575  		stakingInfos := []*reward.StakingInfo{testStakingInfo}
   576  
   577  		mockCtrl, mockPeer, mockDownloader, pm := preparePeerAndDownloader(t)
   578  		testChainConfig.Istanbul = params.GetDefaultIstanbulConfig()
   579  		testChainConfig.Istanbul.ProposerPolicy = uint64(istanbul.WeightedRandom)
   580  		pm.chainconfig = testChainConfig
   581  
   582  		mockDownloader.EXPECT().DeliverStakingInfos(gomock.Eq(nodeids[0].String()), gomock.Eq(stakingInfos)).Times(1).Return(expectedErr)
   583  
   584  		msg := generateMsg(t, StakingInfoMsg, stakingInfos)
   585  		err := handleStakingInfoMsg(pm, mockPeer, msg)
   586  		assert.NoError(t, err)
   587  		mockCtrl.Finish()
   588  	}
   589  }