github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/collection/message_hub/message_hub_test.go (about)

     1  package message_hub
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/mock"
    11  	"github.com/stretchr/testify/require"
    12  	"github.com/stretchr/testify/suite"
    13  
    14  	"github.com/onflow/flow-go/consensus/hotstuff/helper"
    15  	hotstuff "github.com/onflow/flow-go/consensus/hotstuff/mocks"
    16  	"github.com/onflow/flow-go/consensus/hotstuff/model"
    17  	mockcollection "github.com/onflow/flow-go/engine/collection/mock"
    18  	"github.com/onflow/flow-go/model/cluster"
    19  	"github.com/onflow/flow-go/model/flow"
    20  	"github.com/onflow/flow-go/model/messages"
    21  	"github.com/onflow/flow-go/module/irrecoverable"
    22  	"github.com/onflow/flow-go/module/metrics"
    23  	module "github.com/onflow/flow-go/module/mock"
    24  	"github.com/onflow/flow-go/module/util"
    25  	netint "github.com/onflow/flow-go/network"
    26  	"github.com/onflow/flow-go/network/channels"
    27  	"github.com/onflow/flow-go/network/mocknetwork"
    28  	clusterint "github.com/onflow/flow-go/state/cluster"
    29  	clusterstate "github.com/onflow/flow-go/state/cluster/mock"
    30  	protocol "github.com/onflow/flow-go/state/protocol/mock"
    31  	storerr "github.com/onflow/flow-go/storage"
    32  	storage "github.com/onflow/flow-go/storage/mock"
    33  	"github.com/onflow/flow-go/utils/unittest"
    34  )
    35  
    36  func TestMessageHub(t *testing.T) {
    37  	suite.Run(t, new(MessageHubSuite))
    38  }
    39  
    40  // MessageHubSuite tests the cluster consensus message hub. Holds mocked dependencies that are used by different test scenarios.
    41  type MessageHubSuite struct {
    42  	suite.Suite
    43  
    44  	// parameters
    45  	cluster   flow.IdentityList
    46  	clusterID flow.ChainID
    47  	myID      flow.Identifier
    48  	head      *cluster.Block
    49  
    50  	// mocked dependencies
    51  	payloads          *storage.ClusterPayloads
    52  	me                *module.Local
    53  	state             *clusterstate.MutableState
    54  	protoState        *protocol.State
    55  	net               *mocknetwork.Network
    56  	con               *mocknetwork.Conduit
    57  	hotstuff          *module.HotStuff
    58  	voteAggregator    *hotstuff.VoteAggregator
    59  	timeoutAggregator *hotstuff.TimeoutAggregator
    60  	compliance        *mockcollection.Compliance
    61  	snapshot          *clusterstate.Snapshot
    62  
    63  	ctx    irrecoverable.SignalerContext
    64  	cancel context.CancelFunc
    65  	errs   <-chan error
    66  	hub    *MessageHub
    67  }
    68  
    69  func (s *MessageHubSuite) SetupTest() {
    70  	// initialize the paramaters
    71  	s.cluster = unittest.IdentityListFixture(3,
    72  		unittest.WithRole(flow.RoleCollection),
    73  		unittest.WithInitialWeight(1000),
    74  	)
    75  	s.myID = s.cluster[0].NodeID
    76  	s.clusterID = "cluster-id"
    77  	block := unittest.ClusterBlockFixture()
    78  	s.head = &block
    79  
    80  	s.payloads = storage.NewClusterPayloads(s.T())
    81  	s.me = module.NewLocal(s.T())
    82  	s.protoState = protocol.NewState(s.T())
    83  	s.net = mocknetwork.NewNetwork(s.T())
    84  	s.con = mocknetwork.NewConduit(s.T())
    85  	s.hotstuff = module.NewHotStuff(s.T())
    86  	s.voteAggregator = hotstuff.NewVoteAggregator(s.T())
    87  	s.timeoutAggregator = hotstuff.NewTimeoutAggregator(s.T())
    88  	s.compliance = mockcollection.NewCompliance(s.T())
    89  
    90  	// set up proto state mock
    91  	protoEpoch := &protocol.Epoch{}
    92  	clusters := flow.ClusterList{s.cluster.ToSkeleton()}
    93  	protoEpoch.On("Clustering").Return(clusters, nil)
    94  
    95  	protoQuery := &protocol.EpochQuery{}
    96  	protoQuery.On("Current").Return(protoEpoch)
    97  
    98  	protoSnapshot := &protocol.Snapshot{}
    99  	protoSnapshot.On("Epochs").Return(protoQuery)
   100  	protoSnapshot.On("Identities", mock.Anything).Return(
   101  		func(selector flow.IdentityFilter[flow.Identity]) flow.IdentityList {
   102  			return s.cluster.Filter(selector)
   103  		},
   104  		nil,
   105  	)
   106  	s.protoState.On("Final").Return(protoSnapshot)
   107  
   108  	// set up cluster state mock
   109  	s.state = &clusterstate.MutableState{}
   110  	s.state.On("Final").Return(
   111  		func() clusterint.Snapshot {
   112  			return s.snapshot
   113  		},
   114  	)
   115  	s.state.On("AtBlockID", mock.Anything).Return(
   116  		func(blockID flow.Identifier) clusterint.Snapshot {
   117  			return s.snapshot
   118  		},
   119  	)
   120  	clusterParams := &protocol.Params{}
   121  	clusterParams.On("ChainID").Return(s.clusterID, nil)
   122  
   123  	s.state.On("Params").Return(clusterParams)
   124  
   125  	// set up local module mock
   126  	s.me.On("NodeID").Return(
   127  		func() flow.Identifier {
   128  			return s.myID
   129  		},
   130  	).Maybe()
   131  
   132  	// set up network module mock
   133  	s.net.On("Register", mock.Anything, mock.Anything).Return(
   134  		func(channel channels.Channel, engine netint.MessageProcessor) netint.Conduit {
   135  			return s.con
   136  		},
   137  		nil,
   138  	)
   139  
   140  	// set up protocol snapshot mock
   141  	s.snapshot = &clusterstate.Snapshot{}
   142  	s.snapshot.On("Head").Return(
   143  		func() *flow.Header {
   144  			return s.head.Header
   145  		},
   146  		nil,
   147  	)
   148  
   149  	engineMetrics := metrics.NewNoopCollector()
   150  	hub, err := NewMessageHub(
   151  		unittest.Logger(),
   152  		engineMetrics,
   153  		s.net,
   154  		s.me,
   155  		s.compliance,
   156  		s.hotstuff,
   157  		s.voteAggregator,
   158  		s.timeoutAggregator,
   159  		s.protoState,
   160  		s.state,
   161  		s.payloads,
   162  	)
   163  	require.NoError(s.T(), err)
   164  	s.hub = hub
   165  
   166  	s.ctx, s.cancel, s.errs = irrecoverable.WithSignallerAndCancel(context.Background())
   167  	s.hub.Start(s.ctx)
   168  
   169  	unittest.AssertClosesBefore(s.T(), s.hub.Ready(), time.Second)
   170  }
   171  
   172  // TearDownTest stops the hub and checks there are no errors thrown to the SignallerContext.
   173  func (s *MessageHubSuite) TearDownTest() {
   174  	s.cancel()
   175  	unittest.RequireCloseBefore(s.T(), s.hub.Done(), time.Second, "hub failed to stop")
   176  	select {
   177  	case err := <-s.errs:
   178  		assert.NoError(s.T(), err)
   179  	default:
   180  	}
   181  }
   182  
   183  // TestProcessIncomingMessages tests processing of incoming messages, MessageHub matches messages by type
   184  // and sends them to other modules which execute business logic.
   185  func (s *MessageHubSuite) TestProcessIncomingMessages() {
   186  	var channel channels.Channel
   187  	originID := unittest.IdentifierFixture()
   188  	s.Run("to-compliance-engine", func() {
   189  		block := unittest.ClusterBlockFixture()
   190  
   191  		blockProposalMsg := messages.NewClusterBlockProposal(&block)
   192  		expectedComplianceMsg := flow.Slashable[*messages.ClusterBlockProposal]{
   193  			OriginID: originID,
   194  			Message:  blockProposalMsg,
   195  		}
   196  		s.compliance.On("OnClusterBlockProposal", expectedComplianceMsg).Return(nil).Once()
   197  		err := s.hub.Process(channel, originID, blockProposalMsg)
   198  		require.NoError(s.T(), err)
   199  	})
   200  	s.Run("to-vote-aggregator", func() {
   201  		expectedVote := unittest.VoteFixture(unittest.WithVoteSignerID(originID))
   202  		msg := &messages.ClusterBlockVote{
   203  			View:    expectedVote.View,
   204  			BlockID: expectedVote.BlockID,
   205  			SigData: expectedVote.SigData,
   206  		}
   207  		s.voteAggregator.On("AddVote", expectedVote)
   208  		err := s.hub.Process(channel, originID, msg)
   209  		require.NoError(s.T(), err)
   210  	})
   211  	s.Run("to-timeout-aggregator", func() {
   212  		expectedTimeout := helper.TimeoutObjectFixture(helper.WithTimeoutObjectSignerID(originID))
   213  		msg := &messages.ClusterTimeoutObject{
   214  			View:       expectedTimeout.View,
   215  			NewestQC:   expectedTimeout.NewestQC,
   216  			LastViewTC: expectedTimeout.LastViewTC,
   217  			SigData:    expectedTimeout.SigData,
   218  		}
   219  		s.timeoutAggregator.On("AddTimeout", expectedTimeout)
   220  		err := s.hub.Process(channel, originID, msg)
   221  		require.NoError(s.T(), err)
   222  	})
   223  	s.Run("unsupported-msg-type", func() {
   224  		err := s.hub.Process(channel, originID, struct{}{})
   225  		require.NoError(s.T(), err)
   226  	})
   227  }
   228  
   229  // TestOnOwnProposal tests broadcasting proposals with different inputs
   230  func (s *MessageHubSuite) TestOnOwnProposal() {
   231  	// add execution node to cluster to make sure we exclude them from broadcast
   232  	s.cluster = append(s.cluster, unittest.IdentityFixture(unittest.WithRole(flow.RoleExecution)))
   233  
   234  	// generate a parent with height and chain ID set
   235  	parent := unittest.ClusterBlockFixture()
   236  	parent.Header.ChainID = "test"
   237  	parent.Header.Height = 10
   238  
   239  	// create a block with the parent and store the payload with correct ID
   240  	block := unittest.ClusterBlockWithParent(&parent)
   241  	block.Header.ProposerID = s.myID
   242  
   243  	s.payloads.On("ByBlockID", block.Header.ID()).Return(block.Payload, nil)
   244  	s.payloads.On("ByBlockID", mock.Anything).Return(nil, storerr.ErrNotFound)
   245  
   246  	s.Run("should fail with wrong proposer", func() {
   247  		header := *block.Header
   248  		header.ProposerID = unittest.IdentifierFixture()
   249  		err := s.hub.sendOwnProposal(&header)
   250  		require.Error(s.T(), err, "should fail with wrong proposer")
   251  		header.ProposerID = s.myID
   252  	})
   253  
   254  	// should fail since we can't query payload
   255  	s.Run("should fail with changed/missing parent", func() {
   256  		header := *block.Header
   257  		header.ParentID[0]++
   258  		err := s.hub.sendOwnProposal(&header)
   259  		require.Error(s.T(), err, "should fail with missing parent")
   260  		header.ParentID[0]--
   261  	})
   262  
   263  	// should fail with wrong block ID (payload unavailable)
   264  	s.Run("should fail with wrong block ID", func() {
   265  		header := *block.Header
   266  		header.View++
   267  		err := s.hub.sendOwnProposal(&header)
   268  		require.Error(s.T(), err, "should fail with missing payload")
   269  		header.View--
   270  	})
   271  
   272  	s.Run("should broadcast proposal and pass to HotStuff for valid proposals", func() {
   273  		expectedBroadcastMsg := messages.NewClusterBlockProposal(&block)
   274  
   275  		submitted := make(chan struct{}) // closed when proposal is submitted to hotstuff
   276  		hotstuffProposal := model.ProposalFromFlow(block.Header)
   277  		s.voteAggregator.On("AddBlock", hotstuffProposal).Once()
   278  		s.hotstuff.On("SubmitProposal", hotstuffProposal).
   279  			Run(func(args mock.Arguments) { close(submitted) }).
   280  			Once()
   281  
   282  		broadcast := make(chan struct{}) // closed when proposal is broadcast
   283  		s.con.On("Publish", expectedBroadcastMsg, s.cluster[1].NodeID, s.cluster[2].NodeID).
   284  			Run(func(_ mock.Arguments) { close(broadcast) }).
   285  			Return(nil).
   286  			Once()
   287  
   288  		// submit to broadcast proposal
   289  		s.hub.OnOwnProposal(block.Header, time.Now())
   290  
   291  		unittest.AssertClosesBefore(s.T(), util.AllClosed(broadcast, submitted), time.Second)
   292  	})
   293  }
   294  
   295  // TestProcessMultipleMessagesHappyPath tests submitting all types of messages through full processing pipeline and
   296  // asserting that expected message transmissions happened as expected.
   297  func (s *MessageHubSuite) TestProcessMultipleMessagesHappyPath() {
   298  	var wg sync.WaitGroup
   299  
   300  	s.Run("vote", func() {
   301  		wg.Add(1)
   302  		// prepare vote fixture
   303  		vote := unittest.VoteFixture()
   304  		recipientID := unittest.IdentifierFixture()
   305  		s.con.On("Unicast", mock.Anything, recipientID).Run(func(_ mock.Arguments) {
   306  			wg.Done()
   307  		}).Return(nil)
   308  
   309  		// submit vote
   310  		s.hub.OnOwnVote(vote.BlockID, vote.View, vote.SigData, recipientID)
   311  	})
   312  	s.Run("timeout", func() {
   313  		wg.Add(1)
   314  		// prepare timeout fixture
   315  		timeout := helper.TimeoutObjectFixture()
   316  		expectedBroadcastMsg := &messages.ClusterTimeoutObject{
   317  			View:       timeout.View,
   318  			NewestQC:   timeout.NewestQC,
   319  			LastViewTC: timeout.LastViewTC,
   320  			SigData:    timeout.SigData,
   321  		}
   322  		s.con.On("Publish", expectedBroadcastMsg, s.cluster[1].NodeID, s.cluster[2].NodeID).
   323  			Run(func(_ mock.Arguments) { wg.Done() }).
   324  			Return(nil)
   325  		s.timeoutAggregator.On("AddTimeout", timeout).Once()
   326  		// submit timeout
   327  		s.hub.OnOwnTimeout(timeout)
   328  	})
   329  	s.Run("proposal", func() {
   330  		wg.Add(1)
   331  		// prepare proposal fixture
   332  		proposal := unittest.ClusterBlockWithParent(s.head)
   333  		proposal.Header.ProposerID = s.myID
   334  		s.payloads.On("ByBlockID", proposal.Header.ID()).Return(proposal.Payload, nil)
   335  
   336  		// unset chain and height to make sure they are correctly reconstructed
   337  		hotstuffProposal := model.ProposalFromFlow(proposal.Header)
   338  		s.voteAggregator.On("AddBlock", hotstuffProposal)
   339  		s.hotstuff.On("SubmitProposal", hotstuffProposal)
   340  		expectedBroadcastMsg := messages.NewClusterBlockProposal(&proposal)
   341  		s.con.On("Publish", expectedBroadcastMsg, s.cluster[1].NodeID, s.cluster[2].NodeID).
   342  			Run(func(_ mock.Arguments) { wg.Done() }).
   343  			Return(nil)
   344  
   345  		// submit proposal
   346  		s.hub.OnOwnProposal(proposal.Header, time.Now())
   347  	})
   348  
   349  	unittest.RequireReturnsBefore(s.T(), func() {
   350  		wg.Wait()
   351  	}, time.Second, "expect to process messages before timeout")
   352  }