github.com/koko1123/flow-go-1@v0.29.6/engine/collection/compliance/engine_test.go (about)

     1  package compliance
     2  
     3  import (
     4  	"math/rand"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/mock"
    10  	"github.com/stretchr/testify/require"
    11  	"github.com/stretchr/testify/suite"
    12  
    13  	"github.com/koko1123/flow-go-1/consensus/hotstuff/model"
    14  	"github.com/koko1123/flow-go-1/engine"
    15  	"github.com/koko1123/flow-go-1/model/cluster"
    16  	"github.com/koko1123/flow-go-1/model/flow"
    17  	"github.com/koko1123/flow-go-1/model/messages"
    18  	module "github.com/koko1123/flow-go-1/module/mock"
    19  	netint "github.com/koko1123/flow-go-1/network"
    20  	"github.com/koko1123/flow-go-1/network/channels"
    21  	"github.com/koko1123/flow-go-1/network/mocknetwork"
    22  	protocol "github.com/koko1123/flow-go-1/state/protocol/mock"
    23  	storerr "github.com/koko1123/flow-go-1/storage"
    24  	storage "github.com/koko1123/flow-go-1/storage/mock"
    25  	"github.com/koko1123/flow-go-1/utils/unittest"
    26  )
    27  
    28  func TestComplianceEngine(t *testing.T) {
    29  	suite.Run(t, new(ComplianceSuite))
    30  }
    31  
    32  type ComplianceSuite struct {
    33  	ComplianceCoreSuite
    34  
    35  	clusterID  flow.ChainID
    36  	myID       flow.Identifier
    37  	cluster    flow.IdentityList
    38  	me         *module.Local
    39  	net        *mocknetwork.Network
    40  	payloads   *storage.ClusterPayloads
    41  	protoState *protocol.MutableState
    42  	con        *mocknetwork.Conduit
    43  
    44  	payloadDB map[flow.Identifier]*cluster.Payload
    45  
    46  	engine *Engine
    47  }
    48  
    49  func (cs *ComplianceSuite) SetupTest() {
    50  	cs.ComplianceCoreSuite.SetupTest()
    51  
    52  	// initialize the paramaters
    53  	cs.cluster = unittest.IdentityListFixture(3,
    54  		unittest.WithRole(flow.RoleCollection),
    55  		unittest.WithWeight(1000),
    56  	)
    57  	cs.myID = cs.cluster[0].NodeID
    58  
    59  	protoEpoch := &protocol.Epoch{}
    60  	clusters := flow.ClusterList{cs.cluster}
    61  	protoEpoch.On("Clustering").Return(clusters, nil)
    62  
    63  	protoQuery := &protocol.EpochQuery{}
    64  	protoQuery.On("Current").Return(protoEpoch)
    65  
    66  	protoSnapshot := &protocol.Snapshot{}
    67  	protoSnapshot.On("Epochs").Return(protoQuery)
    68  	protoSnapshot.On("Identities", mock.Anything).Return(
    69  		func(selector flow.IdentityFilter) flow.IdentityList {
    70  			return cs.cluster.Filter(selector)
    71  		},
    72  		nil,
    73  	)
    74  
    75  	cs.protoState = &protocol.MutableState{}
    76  	cs.protoState.On("Final").Return(protoSnapshot)
    77  
    78  	cs.clusterID = "cluster-id"
    79  	clusterParams := &protocol.Params{}
    80  	clusterParams.On("ChainID").Return(cs.clusterID, nil)
    81  
    82  	cs.state.On("Params").Return(clusterParams)
    83  
    84  	// set up local module mock
    85  	cs.me = &module.Local{}
    86  	cs.me.On("NodeID").Return(
    87  		func() flow.Identifier {
    88  			return cs.myID
    89  		},
    90  	)
    91  
    92  	cs.payloadDB = make(map[flow.Identifier]*cluster.Payload)
    93  
    94  	// set up payload storage mock
    95  	cs.payloads = &storage.ClusterPayloads{}
    96  	cs.payloads.On("Store", mock.Anything, mock.Anything).Return(
    97  		func(blockID flow.Identifier, payload *cluster.Payload) error {
    98  			cs.payloadDB[blockID] = payload
    99  			return nil
   100  		},
   101  	)
   102  	cs.payloads.On("ByBlockID", mock.Anything).Return(
   103  		func(blockID flow.Identifier) *cluster.Payload {
   104  			return cs.payloadDB[blockID]
   105  		},
   106  		func(blockID flow.Identifier) error {
   107  			_, exists := cs.payloadDB[blockID]
   108  			if !exists {
   109  				return storerr.ErrNotFound
   110  			}
   111  			return nil
   112  		},
   113  	)
   114  
   115  	// set up network conduit mock
   116  	cs.con = &mocknetwork.Conduit{}
   117  	cs.con.On("Publish", mock.Anything, mock.Anything).Return(nil)
   118  	cs.con.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil)
   119  	cs.con.On("Publish", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)
   120  	cs.con.On("Unicast", mock.Anything, mock.Anything).Return(nil)
   121  
   122  	// set up network module mock
   123  	cs.net = &mocknetwork.Network{}
   124  	cs.net.On("Register", mock.Anything, mock.Anything).Return(
   125  		func(channel channels.Channel, engine netint.MessageProcessor) netint.Conduit {
   126  			return cs.con
   127  		},
   128  		nil,
   129  	)
   130  
   131  	e, err := NewEngine(unittest.Logger(), cs.net, cs.me, cs.protoState, cs.payloads, cs.core)
   132  	require.NoError(cs.T(), err)
   133  	cs.engine = e
   134  
   135  	ready := func() <-chan struct{} {
   136  		channel := make(chan struct{})
   137  		close(channel)
   138  		return channel
   139  	}()
   140  
   141  	cs.hotstuff.On("Start", mock.Anything)
   142  	cs.hotstuff.On("Ready", mock.Anything).Return(ready)
   143  	<-cs.engine.Ready()
   144  }
   145  
   146  // TestSendVote tests that single vote can be send and properly processed
   147  func (cs *ComplianceSuite) TestSendVote() {
   148  	// create parameters to send a vote
   149  	blockID := unittest.IdentifierFixture()
   150  	view := rand.Uint64()
   151  	sig := unittest.SignatureFixture()
   152  	recipientID := unittest.IdentifierFixture()
   153  
   154  	// submit the vote
   155  	err := cs.engine.SendVote(blockID, view, sig, recipientID)
   156  	require.NoError(cs.T(), err, "should pass send vote")
   157  
   158  	done := func() <-chan struct{} {
   159  		channel := make(chan struct{})
   160  		close(channel)
   161  		return channel
   162  	}()
   163  
   164  	cs.hotstuff.On("Done", mock.Anything).Return(done)
   165  
   166  	// The vote is transmitted asynchronously. We allow 10ms for the vote to be received:
   167  	<-time.After(10 * time.Millisecond)
   168  	<-cs.engine.Done()
   169  
   170  	// check it was called with right params
   171  	vote := messages.ClusterBlockVote{
   172  		BlockID: blockID,
   173  		View:    view,
   174  		SigData: sig,
   175  	}
   176  	cs.con.AssertCalled(cs.T(), "Unicast", &vote, recipientID)
   177  }
   178  
   179  // TestBroadcastProposalWithDelay tests broadcasting proposals with different
   180  // inputs
   181  func (cs *ComplianceSuite) TestBroadcastProposalWithDelay() {
   182  
   183  	// generate a parent with height and chain ID set
   184  	parent := unittest.ClusterBlockFixture()
   185  	parent.Header.ChainID = "test"
   186  	parent.Header.Height = 10
   187  	cs.headerDB[parent.ID()] = &parent
   188  
   189  	// create a block with the parent and store the payload with correct ID
   190  	block := unittest.ClusterBlockWithParent(&parent)
   191  	block.Header.ProposerID = cs.myID
   192  	cs.payloadDB[block.ID()] = block.Payload
   193  
   194  	// keep a duplicate of the correct header to check against leader
   195  	header := block.Header
   196  
   197  	// unset chain and height to make sure they are correctly reconstructed
   198  	block.Header.ChainID = ""
   199  	block.Header.Height = 0
   200  
   201  	cs.hotstuff.On("SubmitProposal", block.Header, parent.Header.View).Return(doneChan()).Once()
   202  
   203  	// submit to broadcast proposal
   204  	err := cs.engine.BroadcastProposalWithDelay(block.Header, 0)
   205  	require.NoError(cs.T(), err, "header broadcast should pass")
   206  
   207  	// make sure chain ID and height were reconstructed and
   208  	// we broadcast to correct nodes
   209  	header.ChainID = "test"
   210  	header.Height = 11
   211  	msg := messages.NewClusterBlockProposal(&block)
   212  
   213  	done := func() <-chan struct{} {
   214  		channel := make(chan struct{})
   215  		close(channel)
   216  		return channel
   217  	}()
   218  
   219  	cs.hotstuff.On("Done", mock.Anything).Return(done)
   220  
   221  	<-time.After(10 * time.Millisecond)
   222  	<-cs.engine.Done()
   223  	cs.con.AssertCalled(cs.T(), "Publish", msg, cs.cluster[1].NodeID, cs.cluster[2].NodeID)
   224  
   225  	// should fail with wrong proposer
   226  	header.ProposerID = unittest.IdentifierFixture()
   227  	err = cs.engine.BroadcastProposalWithDelay(header, 0)
   228  	require.Error(cs.T(), err, "should fail with wrong proposer")
   229  	header.ProposerID = cs.myID
   230  
   231  	// should fail with changed (missing) parent
   232  	header.ParentID[0]++
   233  	err = cs.engine.BroadcastProposalWithDelay(header, 0)
   234  	require.Error(cs.T(), err, "should fail with missing parent")
   235  	header.ParentID[0]--
   236  
   237  	// should fail with wrong block ID (payload unavailable)
   238  	header.View++
   239  	err = cs.engine.BroadcastProposalWithDelay(header, 0)
   240  	require.Error(cs.T(), err, "should fail with missing payload")
   241  	header.View--
   242  }
   243  
   244  // TestSubmittingMultipleVotes tests that we can send multiple votes and they
   245  // are queued and processed in expected way
   246  func (cs *ComplianceSuite) TestSubmittingMultipleEntries() {
   247  	// create a vote
   248  	originID := unittest.IdentifierFixture()
   249  	voteCount := 15
   250  
   251  	channel := channels.ConsensusCluster(cs.clusterID)
   252  
   253  	var wg sync.WaitGroup
   254  	wg.Add(1)
   255  	go func() {
   256  		for i := 0; i < voteCount; i++ {
   257  			vote := messages.ClusterBlockVote{
   258  				BlockID: unittest.IdentifierFixture(),
   259  				View:    rand.Uint64(),
   260  				SigData: unittest.SignatureFixture(),
   261  			}
   262  			cs.voteAggregator.On("AddVote", &model.Vote{
   263  				View:     vote.View,
   264  				BlockID:  vote.BlockID,
   265  				SignerID: originID,
   266  				SigData:  vote.SigData,
   267  			}).Return().Once()
   268  			// execute the vote submission
   269  			_ = cs.engine.Process(channel, originID, &vote)
   270  		}
   271  		wg.Done()
   272  	}()
   273  	wg.Add(1)
   274  	go func() {
   275  		// create a proposal that directly descends from the latest finalized header
   276  		originID := cs.cluster[1].NodeID
   277  		block := unittest.ClusterBlockWithParent(cs.head)
   278  		proposal := messages.NewClusterBlockProposal(&block)
   279  
   280  		// store the data for retrieval
   281  		cs.headerDB[block.Header.ParentID] = cs.head
   282  		cs.hotstuff.On("SubmitProposal", block.Header, cs.head.Header.View).Return(doneChan())
   283  		_ = cs.engine.Process(channel, originID, proposal)
   284  		wg.Done()
   285  	}()
   286  
   287  	wg.Wait()
   288  
   289  	time.Sleep(time.Second)
   290  
   291  	// check that submit vote was called with correct parameters
   292  	cs.hotstuff.AssertExpectations(cs.T())
   293  	cs.voteAggregator.AssertExpectations(cs.T())
   294  }
   295  
   296  // TestProcessUnsupportedMessageType tests that Process and ProcessLocal correctly handle a case where invalid message type
   297  // was submitted from network layer.
   298  func (cs *ComplianceSuite) TestProcessUnsupportedMessageType() {
   299  	invalidEvent := uint64(42)
   300  	err := cs.engine.Process("ch", unittest.IdentifierFixture(), invalidEvent)
   301  	// shouldn't result in error since byzantine inputs are expected
   302  	require.NoError(cs.T(), err)
   303  	// in case of local processing error cannot be consumed since all inputs are trusted
   304  	err = cs.engine.ProcessLocal(invalidEvent)
   305  	require.Error(cs.T(), err)
   306  	require.True(cs.T(), engine.IsIncompatibleInputTypeError(err))
   307  }
   308  
   309  // TestOnFinalizedBlock tests if finalized block gets processed when send through `Engine`.
   310  // Tests the whole processing pipeline.
   311  func (cs *ComplianceSuite) TestOnFinalizedBlock() {
   312  	finalizedBlock := unittest.ClusterBlockFixture()
   313  	cs.head = &finalizedBlock
   314  
   315  	*cs.pending = module.PendingClusterBlockBuffer{}
   316  	cs.pending.On("PruneByView", finalizedBlock.Header.View).Return(nil).Once()
   317  	cs.pending.On("Size").Return(uint(0)).Once()
   318  	cs.engine.OnFinalizedBlock(model.BlockFromFlow(finalizedBlock.Header, finalizedBlock.Header.View-1))
   319  
   320  	require.Eventually(cs.T(),
   321  		func() bool {
   322  			return cs.pending.AssertCalled(cs.T(), "PruneByView", finalizedBlock.Header.View)
   323  		}, time.Second, time.Millisecond*20)
   324  }