github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/consensus/message_hub/message_hub_test.go (about)

     1  package message_hub
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/mock"
    11  	"github.com/stretchr/testify/require"
    12  	"github.com/stretchr/testify/suite"
    13  
    14  	"github.com/onflow/flow-go/consensus/hotstuff/helper"
    15  	hotstuff "github.com/onflow/flow-go/consensus/hotstuff/mocks"
    16  	"github.com/onflow/flow-go/consensus/hotstuff/model"
    17  	mockconsensus "github.com/onflow/flow-go/engine/consensus/mock"
    18  	"github.com/onflow/flow-go/model/flow"
    19  	"github.com/onflow/flow-go/model/messages"
    20  	"github.com/onflow/flow-go/module/irrecoverable"
    21  	"github.com/onflow/flow-go/module/metrics"
    22  	module "github.com/onflow/flow-go/module/mock"
    23  	"github.com/onflow/flow-go/module/util"
    24  	netint "github.com/onflow/flow-go/network"
    25  	"github.com/onflow/flow-go/network/channels"
    26  	"github.com/onflow/flow-go/network/mocknetwork"
    27  	protint "github.com/onflow/flow-go/state/protocol"
    28  	protocol "github.com/onflow/flow-go/state/protocol/mock"
    29  	storerr "github.com/onflow/flow-go/storage"
    30  	storage "github.com/onflow/flow-go/storage/mock"
    31  	"github.com/onflow/flow-go/utils/unittest"
    32  )
    33  
    34  func TestMessageHub(t *testing.T) {
    35  	suite.Run(t, new(MessageHubSuite))
    36  }
    37  
    38  // MessageHubSuite tests the consensus message hub. Holds mocked dependencies that are used by different test scenarios.
    39  type MessageHubSuite struct {
    40  	suite.Suite
    41  
    42  	// parameters
    43  	participants flow.IdentityList
    44  	myID         flow.Identifier
    45  	head         *flow.Header
    46  
    47  	// mocked dependencies
    48  	payloads          *storage.Payloads
    49  	me                *module.Local
    50  	state             *protocol.State
    51  	net               *mocknetwork.Network
    52  	con               *mocknetwork.Conduit
    53  	pushBlocksCon     *mocknetwork.Conduit
    54  	hotstuff          *module.HotStuff
    55  	voteAggregator    *hotstuff.VoteAggregator
    56  	timeoutAggregator *hotstuff.TimeoutAggregator
    57  	compliance        *mockconsensus.Compliance
    58  	snapshot          *protocol.Snapshot
    59  
    60  	ctx    irrecoverable.SignalerContext
    61  	cancel context.CancelFunc
    62  	errs   <-chan error
    63  	hub    *MessageHub
    64  }
    65  
    66  func (s *MessageHubSuite) SetupTest() {
    67  	// initialize the paramaters
    68  	s.participants = unittest.IdentityListFixture(3,
    69  		unittest.WithRole(flow.RoleConsensus),
    70  		unittest.WithInitialWeight(1000),
    71  	)
    72  	s.myID = s.participants[0].NodeID
    73  	block := unittest.BlockFixture()
    74  	s.head = block.Header
    75  
    76  	s.payloads = storage.NewPayloads(s.T())
    77  	s.me = module.NewLocal(s.T())
    78  	s.state = protocol.NewState(s.T())
    79  	s.net = mocknetwork.NewNetwork(s.T())
    80  	s.con = mocknetwork.NewConduit(s.T())
    81  	s.pushBlocksCon = mocknetwork.NewConduit(s.T())
    82  	s.hotstuff = module.NewHotStuff(s.T())
    83  	s.voteAggregator = hotstuff.NewVoteAggregator(s.T())
    84  	s.timeoutAggregator = hotstuff.NewTimeoutAggregator(s.T())
    85  	s.compliance = mockconsensus.NewCompliance(s.T())
    86  
    87  	// set up protocol state mock
    88  	s.state.On("Final").Return(
    89  		func() protint.Snapshot {
    90  			return s.snapshot
    91  		},
    92  	).Maybe()
    93  	s.state.On("AtBlockID", mock.Anything).Return(
    94  		func(blockID flow.Identifier) protint.Snapshot {
    95  			return s.snapshot
    96  		},
    97  	).Maybe()
    98  
    99  	// set up local module mock
   100  	s.me.On("NodeID").Return(
   101  		func() flow.Identifier {
   102  			return s.myID
   103  		},
   104  	).Maybe()
   105  
   106  	// set up network module mock
   107  	s.net.On("Register", mock.Anything, mock.Anything).Return(
   108  		func(channel channels.Channel, engine netint.MessageProcessor) netint.Conduit {
   109  			if channel == channels.ConsensusCommittee {
   110  				return s.con
   111  			} else if channel == channels.PushBlocks {
   112  				return s.pushBlocksCon
   113  			} else {
   114  				s.T().Fail()
   115  			}
   116  			return nil
   117  		},
   118  		nil,
   119  	)
   120  
   121  	// set up protocol snapshot mock
   122  	s.snapshot = &protocol.Snapshot{}
   123  	s.snapshot.On("Identities", mock.Anything).Return(
   124  		func(filter flow.IdentityFilter[flow.Identity]) flow.IdentityList {
   125  			return s.participants.Filter(filter)
   126  		},
   127  		nil,
   128  	)
   129  	s.snapshot.On("Head").Return(
   130  		func() *flow.Header {
   131  			return s.head
   132  		},
   133  		nil,
   134  	)
   135  
   136  	engineMetrics := metrics.NewNoopCollector()
   137  	hub, err := NewMessageHub(
   138  		unittest.Logger(),
   139  		engineMetrics,
   140  		s.net,
   141  		s.me,
   142  		s.compliance,
   143  		s.hotstuff,
   144  		s.voteAggregator,
   145  		s.timeoutAggregator,
   146  		s.state,
   147  		s.payloads,
   148  	)
   149  	require.NoError(s.T(), err)
   150  	s.hub = hub
   151  
   152  	s.ctx, s.cancel, s.errs = irrecoverable.WithSignallerAndCancel(context.Background())
   153  	s.hub.Start(s.ctx)
   154  
   155  	unittest.AssertClosesBefore(s.T(), s.hub.Ready(), time.Second)
   156  }
   157  
   158  // TearDownTest stops the hub and checks there are no errors thrown to the SignallerContext.
   159  func (s *MessageHubSuite) TearDownTest() {
   160  	s.cancel()
   161  	unittest.RequireCloseBefore(s.T(), s.hub.Done(), time.Second, "hub failed to stop")
   162  	select {
   163  	case err := <-s.errs:
   164  		assert.NoError(s.T(), err)
   165  	default:
   166  	}
   167  }
   168  
   169  // TestProcessIncomingMessages tests processing of incoming messages, MessageHub matches messages by type
   170  // and sends them to other modules which execute business logic.
   171  func (s *MessageHubSuite) TestProcessIncomingMessages() {
   172  	var channel channels.Channel
   173  	originID := unittest.IdentifierFixture()
   174  	s.Run("to-compliance-engine", func() {
   175  		block := unittest.BlockFixture()
   176  
   177  		blockProposalMsg := messages.NewBlockProposal(&block)
   178  		expectedComplianceMsg := flow.Slashable[*messages.BlockProposal]{
   179  			OriginID: originID,
   180  			Message:  blockProposalMsg,
   181  		}
   182  		s.compliance.On("OnBlockProposal", expectedComplianceMsg).Return(nil).Once()
   183  		err := s.hub.Process(channel, originID, blockProposalMsg)
   184  		require.NoError(s.T(), err)
   185  	})
   186  	s.Run("to-vote-aggregator", func() {
   187  		expectedVote := unittest.VoteFixture(unittest.WithVoteSignerID(originID))
   188  		msg := &messages.BlockVote{
   189  			View:    expectedVote.View,
   190  			BlockID: expectedVote.BlockID,
   191  			SigData: expectedVote.SigData,
   192  		}
   193  		s.voteAggregator.On("AddVote", expectedVote)
   194  		err := s.hub.Process(channel, originID, msg)
   195  		require.NoError(s.T(), err)
   196  	})
   197  	s.Run("to-timeout-aggregator", func() {
   198  		expectedTimeout := helper.TimeoutObjectFixture(helper.WithTimeoutObjectSignerID(originID))
   199  		msg := &messages.TimeoutObject{
   200  			View:       expectedTimeout.View,
   201  			NewestQC:   expectedTimeout.NewestQC,
   202  			LastViewTC: expectedTimeout.LastViewTC,
   203  			SigData:    expectedTimeout.SigData,
   204  		}
   205  		s.timeoutAggregator.On("AddTimeout", expectedTimeout)
   206  		err := s.hub.Process(channel, originID, msg)
   207  		require.NoError(s.T(), err)
   208  	})
   209  	s.Run("unsupported-msg-type", func() {
   210  		err := s.hub.Process(channel, originID, struct{}{})
   211  		require.NoError(s.T(), err)
   212  	})
   213  }
   214  
   215  // TestOnOwnProposal tests broadcasting proposals with different inputs
   216  func (s *MessageHubSuite) TestOnOwnProposal() {
   217  	// add execution node to participants to make sure we exclude them from broadcast
   218  	s.participants = append(s.participants, unittest.IdentityFixture(unittest.WithRole(flow.RoleExecution)))
   219  
   220  	// generate a parent with height and chain ID set
   221  	parent := unittest.BlockHeaderFixture()
   222  	parent.ChainID = "test"
   223  	parent.Height = 10
   224  
   225  	// create a block with the parent and store the payload with correct ID
   226  	block := unittest.BlockWithParentFixture(parent)
   227  	block.Header.ProposerID = s.myID
   228  
   229  	s.payloads.On("ByBlockID", block.Header.ID()).Return(block.Payload, nil)
   230  	s.payloads.On("ByBlockID", mock.Anything).Return(nil, storerr.ErrNotFound)
   231  
   232  	s.Run("should fail with wrong proposer", func() {
   233  		header := *block.Header
   234  		header.ProposerID = unittest.IdentifierFixture()
   235  		err := s.hub.sendOwnProposal(&header)
   236  		require.Error(s.T(), err, "should fail with wrong proposer")
   237  		header.ProposerID = s.myID
   238  	})
   239  
   240  	// should fail with wrong block ID (payload unavailable)
   241  	s.Run("should fail with wrong block ID", func() {
   242  		header := *block.Header
   243  		header.View++
   244  		err := s.hub.sendOwnProposal(&header)
   245  		require.Error(s.T(), err, "should fail with missing payload")
   246  		header.View--
   247  	})
   248  
   249  	s.Run("should broadcast proposal and pass to HotStuff for valid proposals", func() {
   250  		expectedBroadcastMsg := messages.NewBlockProposal(block)
   251  
   252  		submitted := make(chan struct{}) // closed when proposal is submitted to hotstuff
   253  		hotstuffProposal := model.ProposalFromFlow(block.Header)
   254  		s.voteAggregator.On("AddBlock", hotstuffProposal).Once()
   255  		s.hotstuff.On("SubmitProposal", hotstuffProposal).
   256  			Run(func(args mock.Arguments) { close(submitted) }).
   257  			Once()
   258  
   259  		broadcast := make(chan struct{}) // closed when proposal is broadcast
   260  		s.con.On("Publish", expectedBroadcastMsg, s.participants[1].NodeID, s.participants[2].NodeID).
   261  			Run(func(_ mock.Arguments) { close(broadcast) }).
   262  			Return(nil).
   263  			Once()
   264  
   265  		s.pushBlocksCon.On("Publish", expectedBroadcastMsg, s.participants[3].NodeID).Return(nil)
   266  
   267  		// submit to broadcast proposal
   268  		s.hub.OnOwnProposal(block.Header, time.Now())
   269  
   270  		unittest.AssertClosesBefore(s.T(), util.AllClosed(broadcast, submitted), time.Second)
   271  	})
   272  }
   273  
   274  // TestProcessMultipleMessagesHappyPath tests submitting all types of messages through full processing pipeline and
   275  // asserting that expected message transmissions happened as expected.
   276  func (s *MessageHubSuite) TestProcessMultipleMessagesHappyPath() {
   277  	var wg sync.WaitGroup
   278  
   279  	// add execution node to participants to make sure we exclude them from broadcast
   280  	s.participants = append(s.participants, unittest.IdentityFixture(unittest.WithRole(flow.RoleExecution)))
   281  
   282  	s.Run("vote", func() {
   283  		wg.Add(1)
   284  		// prepare vote fixture
   285  		vote := unittest.VoteFixture()
   286  		recipientID := unittest.IdentifierFixture()
   287  		s.con.On("Unicast", mock.Anything, recipientID).Run(func(_ mock.Arguments) {
   288  			wg.Done()
   289  		}).Return(nil)
   290  
   291  		// submit vote
   292  		s.hub.OnOwnVote(vote.BlockID, vote.View, vote.SigData, recipientID)
   293  	})
   294  	s.Run("timeout", func() {
   295  		wg.Add(1)
   296  		// prepare timeout fixture
   297  		timeout := helper.TimeoutObjectFixture()
   298  		expectedBroadcastMsg := &messages.TimeoutObject{
   299  			View:       timeout.View,
   300  			NewestQC:   timeout.NewestQC,
   301  			LastViewTC: timeout.LastViewTC,
   302  			SigData:    timeout.SigData,
   303  		}
   304  		s.con.On("Publish", expectedBroadcastMsg, s.participants[1].NodeID, s.participants[2].NodeID).
   305  			Run(func(_ mock.Arguments) { wg.Done() }).
   306  			Return(nil)
   307  		s.timeoutAggregator.On("AddTimeout", timeout).Once()
   308  		// submit timeout
   309  		s.hub.OnOwnTimeout(timeout)
   310  	})
   311  	s.Run("proposal", func() {
   312  		wg.Add(1)
   313  		// prepare proposal fixture
   314  		proposal := unittest.BlockWithParentAndProposerFixture(s.T(), s.head, s.myID)
   315  		s.payloads.On("ByBlockID", proposal.Header.ID()).Return(proposal.Payload, nil)
   316  
   317  		// unset chain and height to make sure they are correctly reconstructed
   318  		hotstuffProposal := model.ProposalFromFlow(proposal.Header)
   319  		s.voteAggregator.On("AddBlock", hotstuffProposal).Once()
   320  		s.hotstuff.On("SubmitProposal", hotstuffProposal)
   321  		expectedBroadcastMsg := messages.NewBlockProposal(&proposal)
   322  		s.con.On("Publish", expectedBroadcastMsg, s.participants[1].NodeID, s.participants[2].NodeID).
   323  			Run(func(_ mock.Arguments) { wg.Done() }).
   324  			Return(nil)
   325  		s.pushBlocksCon.On("Publish", expectedBroadcastMsg, s.participants[3].NodeID).Return(nil)
   326  
   327  		// submit proposal
   328  		s.hub.OnOwnProposal(proposal.Header, time.Now())
   329  	})
   330  
   331  	unittest.RequireReturnsBefore(s.T(), func() {
   332  		wg.Wait()
   333  	}, time.Second, "expect to process messages before timeout")
   334  }