github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/collection/epochmgr/engine_test.go (about)

     1  package epochmgr
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/rs/zerolog"
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/mock"
    11  	"github.com/stretchr/testify/require"
    12  	"github.com/stretchr/testify/suite"
    13  
    14  	"github.com/onflow/flow-go/consensus/hotstuff"
    15  	mockhotstuff "github.com/onflow/flow-go/consensus/hotstuff/mocks"
    16  	epochmgr "github.com/onflow/flow-go/engine/collection/epochmgr/mock"
    17  	mockcollection "github.com/onflow/flow-go/engine/collection/mock"
    18  	"github.com/onflow/flow-go/model/flow"
    19  	realmodule "github.com/onflow/flow-go/module"
    20  	"github.com/onflow/flow-go/module/component"
    21  	mockcomponent "github.com/onflow/flow-go/module/component/mock"
    22  	"github.com/onflow/flow-go/module/irrecoverable"
    23  	"github.com/onflow/flow-go/module/mempool"
    24  	"github.com/onflow/flow-go/module/mempool/epochs"
    25  	"github.com/onflow/flow-go/module/mempool/herocache"
    26  	"github.com/onflow/flow-go/module/metrics"
    27  	mockmodule "github.com/onflow/flow-go/module/mock"
    28  	realcluster "github.com/onflow/flow-go/state/cluster"
    29  	cluster "github.com/onflow/flow-go/state/cluster/mock"
    30  	realprotocol "github.com/onflow/flow-go/state/protocol"
    31  	events "github.com/onflow/flow-go/state/protocol/events/mock"
    32  	protocol "github.com/onflow/flow-go/state/protocol/mock"
    33  	"github.com/onflow/flow-go/utils/unittest"
    34  	"github.com/onflow/flow-go/utils/unittest/mocks"
    35  )
    36  
    37  // mockComponents is a container for the mocked version of epoch components.
    38  type mockComponents struct {
    39  	state             *cluster.State
    40  	prop              *mockcomponent.Component
    41  	sync              *mockmodule.ReadyDoneAware
    42  	hotstuff          *mockmodule.HotStuff
    43  	voteAggregator    *mockhotstuff.VoteAggregator
    44  	timeoutAggregator *mockhotstuff.TimeoutAggregator
    45  	messageHub        *mockcomponent.Component
    46  }
    47  
    48  func newMockComponents(t *testing.T) *mockComponents {
    49  	components := &mockComponents{
    50  		state:             cluster.NewState(t),
    51  		prop:              mockcomponent.NewComponent(t),
    52  		sync:              mockmodule.NewReadyDoneAware(t),
    53  		hotstuff:          mockmodule.NewHotStuff(t),
    54  		voteAggregator:    mockhotstuff.NewVoteAggregator(t),
    55  		timeoutAggregator: mockhotstuff.NewTimeoutAggregator(t),
    56  		messageHub:        mockcomponent.NewComponent(t),
    57  	}
    58  	unittest.ReadyDoneify(components.prop)
    59  	unittest.ReadyDoneify(components.sync)
    60  	unittest.ReadyDoneify(components.hotstuff)
    61  	unittest.ReadyDoneify(components.voteAggregator)
    62  	unittest.ReadyDoneify(components.timeoutAggregator)
    63  	unittest.ReadyDoneify(components.messageHub)
    64  
    65  	components.prop.On("Start", mock.Anything)
    66  	components.hotstuff.On("Start", mock.Anything)
    67  	components.voteAggregator.On("Start", mock.Anything)
    68  	components.timeoutAggregator.On("Start", mock.Anything)
    69  	components.messageHub.On("Start", mock.Anything)
    70  	params := cluster.NewParams(t)
    71  	params.On("ChainID").Return(flow.ChainID("chain-id"), nil).Maybe()
    72  	components.state.On("Params").Return(params).Maybe()
    73  	return components
    74  }
    75  
    76  type Suite struct {
    77  	suite.Suite
    78  
    79  	// engine dependencies
    80  	log   zerolog.Logger
    81  	me    *mockmodule.Local
    82  	state *protocol.State
    83  	snap  *protocol.Snapshot
    84  	pools *epochs.TransactionPools
    85  
    86  	// qc voter dependencies
    87  	signer  *mockhotstuff.Signer
    88  	client  *mockmodule.QCContractClient
    89  	voter   *mockmodule.ClusterRootQCVoter
    90  	factory *epochmgr.EpochComponentsFactory
    91  	heights *events.Heights
    92  
    93  	epochQuery *mocks.EpochQuery
    94  	counter    uint64                     // reflects the counter of the current epoch
    95  	phase      flow.EpochPhase            // phase at mocked snapshot
    96  	header     *flow.Header               // header at mocked snapshot
    97  	epochs     map[uint64]*protocol.Epoch // track all epochs
    98  	components map[uint64]*mockComponents // track all epoch components
    99  
   100  	ctx    irrecoverable.SignalerContext
   101  	cancel context.CancelFunc
   102  	errs   <-chan error
   103  
   104  	engine *Engine
   105  
   106  	engineEventsDistributor *mockcollection.EngineEvents
   107  }
   108  
   109  // MockFactoryCreate mocks the epoch factory to create epoch components for the given epoch.
   110  func (suite *Suite) MockFactoryCreate(arg any) {
   111  	suite.factory.On("Create", arg).
   112  		Run(func(args mock.Arguments) {
   113  			epoch, ok := args.Get(0).(realprotocol.Epoch)
   114  			suite.Require().Truef(ok, "invalid type %T", args.Get(0))
   115  			counter, err := epoch.Counter()
   116  			suite.Require().Nil(err)
   117  			suite.components[counter] = newMockComponents(suite.T())
   118  		}).
   119  		Return(
   120  			func(epoch realprotocol.Epoch) realcluster.State { return suite.ComponentsForEpoch(epoch).state },
   121  			func(epoch realprotocol.Epoch) component.Component { return suite.ComponentsForEpoch(epoch).prop },
   122  			func(epoch realprotocol.Epoch) realmodule.ReadyDoneAware { return suite.ComponentsForEpoch(epoch).sync },
   123  			func(epoch realprotocol.Epoch) realmodule.HotStuff { return suite.ComponentsForEpoch(epoch).hotstuff },
   124  			func(epoch realprotocol.Epoch) hotstuff.VoteAggregator {
   125  				return suite.ComponentsForEpoch(epoch).voteAggregator
   126  			},
   127  			func(epoch realprotocol.Epoch) hotstuff.TimeoutAggregator {
   128  				return suite.ComponentsForEpoch(epoch).timeoutAggregator
   129  			},
   130  			func(epoch realprotocol.Epoch) component.Component { return suite.ComponentsForEpoch(epoch).messageHub },
   131  			func(epoch realprotocol.Epoch) error { return nil },
   132  		).Maybe()
   133  }
   134  
   135  func (suite *Suite) SetupTest() {
   136  	suite.log = unittest.Logger()
   137  	suite.me = mockmodule.NewLocal(suite.T())
   138  	suite.state = protocol.NewState(suite.T())
   139  	suite.snap = protocol.NewSnapshot(suite.T())
   140  
   141  	suite.epochs = make(map[uint64]*protocol.Epoch)
   142  	suite.components = make(map[uint64]*mockComponents)
   143  
   144  	suite.signer = mockhotstuff.NewSigner(suite.T())
   145  	suite.client = mockmodule.NewQCContractClient(suite.T())
   146  	suite.voter = mockmodule.NewClusterRootQCVoter(suite.T())
   147  	suite.factory = epochmgr.NewEpochComponentsFactory(suite.T())
   148  	suite.heights = events.NewHeights(suite.T())
   149  
   150  	// mock out Create so that it instantiates the appropriate mocks
   151  	suite.MockFactoryCreate(mock.Anything)
   152  
   153  	suite.phase = flow.EpochPhaseSetup
   154  	suite.header = unittest.BlockHeaderFixture()
   155  	suite.epochQuery = mocks.NewEpochQuery(suite.T(), suite.counter)
   156  
   157  	suite.state.On("Final").Return(suite.snap)
   158  	suite.state.On("AtBlockID", suite.header.ID()).Return(suite.snap).Maybe()
   159  	suite.snap.On("Epochs").Return(suite.epochQuery)
   160  	suite.snap.On("Head").Return(
   161  		func() *flow.Header { return suite.header },
   162  		func() error { return nil })
   163  	suite.snap.On("Phase").Return(
   164  		func() flow.EpochPhase { return suite.phase },
   165  		func() error { return nil })
   166  
   167  	// add current and next epochs
   168  	suite.AddEpoch(suite.counter)
   169  	suite.AddEpoch(suite.counter + 1)
   170  
   171  	suite.pools = epochs.NewTransactionPools(func(_ uint64) mempool.Transactions {
   172  		return herocache.NewTransactions(1000, suite.log, metrics.NewNoopCollector())
   173  	})
   174  
   175  	suite.engineEventsDistributor = mockcollection.NewEngineEvents(suite.T())
   176  
   177  	var err error
   178  	suite.engine, err = New(suite.log, suite.me, suite.state, suite.pools, suite.voter, suite.factory, suite.heights, suite.engineEventsDistributor)
   179  	suite.Require().Nil(err)
   180  
   181  }
   182  
   183  // StartEngine starts the engine under test, and spawns a routine to check for irrecoverable errors.
   184  func (suite *Suite) StartEngine() {
   185  	suite.ctx, suite.cancel, suite.errs = irrecoverable.WithSignallerAndCancel(context.Background())
   186  	go unittest.FailOnIrrecoverableError(suite.T(), suite.ctx.Done(), suite.errs)
   187  	suite.engine.Start(suite.ctx)
   188  	unittest.AssertClosesBefore(suite.T(), suite.engine.Ready(), time.Second)
   189  }
   190  
   191  // TearDownTest stops the engine and checks for any irrecoverable errors.
   192  func (suite *Suite) TearDownTest() {
   193  	if suite.cancel == nil {
   194  		return
   195  	}
   196  	suite.cancel()
   197  	unittest.RequireCloseBefore(suite.T(), suite.engine.Done(), time.Second, "engine failed to stop")
   198  	select {
   199  	case err := <-suite.errs:
   200  		assert.NoError(suite.T(), err)
   201  	default:
   202  	}
   203  }
   204  
   205  func TestEpochManager(t *testing.T) {
   206  	suite.Run(t, new(Suite))
   207  }
   208  
   209  // TransitionEpoch triggers an epoch transition in the suite's mocks.
   210  func (suite *Suite) TransitionEpoch() {
   211  	suite.counter++
   212  	suite.epochQuery.Transition()
   213  }
   214  
   215  // AddEpoch adds an epoch with the given counter.
   216  func (suite *Suite) AddEpoch(counter uint64) *protocol.Epoch {
   217  	epoch := new(protocol.Epoch)
   218  	epoch.On("Counter").Return(counter, nil)
   219  	suite.epochs[counter] = epoch
   220  	suite.epochQuery.Add(epoch)
   221  	return epoch
   222  }
   223  
   224  // AssertEpochStarted asserts that the components for the given epoch have been started.
   225  func (suite *Suite) AssertEpochStarted(counter uint64) {
   226  	components, ok := suite.components[counter]
   227  	suite.Assert().True(ok, "asserting nonexistent epoch %d started", counter)
   228  	components.prop.AssertCalled(suite.T(), "Ready")
   229  	components.sync.AssertCalled(suite.T(), "Ready")
   230  	components.voteAggregator.AssertCalled(suite.T(), "Ready")
   231  	components.voteAggregator.AssertCalled(suite.T(), "Start", mock.Anything)
   232  }
   233  
   234  // AssertEpochStopped asserts that the components for the given epoch have been stopped.
   235  func (suite *Suite) AssertEpochStopped(counter uint64) {
   236  	components, ok := suite.components[counter]
   237  	suite.Assert().True(ok, "asserting nonexistent epoch stopped", counter)
   238  	components.prop.AssertCalled(suite.T(), "Done")
   239  	components.sync.AssertCalled(suite.T(), "Done")
   240  }
   241  
   242  func (suite *Suite) ComponentsForEpoch(epoch realprotocol.Epoch) *mockComponents {
   243  	counter, err := epoch.Counter()
   244  	suite.Require().Nil(err, "cannot get counter")
   245  	components, ok := suite.components[counter]
   246  	suite.Require().True(ok, "missing component for counter", counter)
   247  	return components
   248  }
   249  
   250  // MockAsUnauthorizedNode mocks the factory to return a sentinel indicating
   251  // we are not authorized in the epoch
   252  func (suite *Suite) MockAsUnauthorizedNode(forEpoch uint64) {
   253  
   254  	// mock as unauthorized for given epoch only
   255  	unauthorizedMatcher := func(epoch realprotocol.Epoch) bool {
   256  		counter, err := epoch.Counter()
   257  		require.NoError(suite.T(), err)
   258  		return counter == forEpoch
   259  	}
   260  	authorizedMatcher := func(epoch realprotocol.Epoch) bool { return !unauthorizedMatcher(epoch) }
   261  
   262  	suite.factory = epochmgr.NewEpochComponentsFactory(suite.T())
   263  	suite.factory.
   264  		On("Create", mock.MatchedBy(unauthorizedMatcher)).
   265  		Return(nil, nil, nil, nil, nil, nil, nil, ErrNotAuthorizedForEpoch)
   266  	suite.MockFactoryCreate(mock.MatchedBy(authorizedMatcher))
   267  
   268  	var err error
   269  	suite.engine, err = New(suite.log, suite.me, suite.state, suite.pools, suite.voter, suite.factory, suite.heights, suite.engineEventsDistributor)
   270  	suite.Require().Nil(err)
   271  }
   272  
   273  // TestRestartInSetupPhase tests that, if we start up during the setup phase,
   274  // we should kick off the root QC voter
   275  func (suite *Suite) TestRestartInSetupPhase() {
   276  	// we expect 1 ActiveClustersChanged events when the engine first starts and the first set of epoch components are started
   277  	suite.engineEventsDistributor.On("ActiveClustersChanged", mock.AnythingOfType("flow.ChainIDList")).Once()
   278  	defer suite.engineEventsDistributor.AssertExpectations(suite.T())
   279  	// we are in setup phase
   280  	suite.phase = flow.EpochPhaseSetup
   281  	// should call voter with next epoch
   282  	var called = make(chan struct{})
   283  	suite.voter.On("Vote", mock.Anything, suite.epochQuery.Next()).
   284  		Return(nil).
   285  		Run(func(args mock.Arguments) {
   286  			close(called)
   287  		}).Once()
   288  
   289  	// start up the engine
   290  	suite.StartEngine()
   291  
   292  	unittest.AssertClosesBefore(suite.T(), called, time.Second)
   293  }
   294  
   295  // TestStartAfterEpochBoundary_WithinTxExpiry tests starting the engine shortly after an epoch transition.
   296  // When the finalized height is within the first tx_expiry blocks of the new epoch
   297  // the engine should restart the previous epoch cluster consensus.
   298  func (suite *Suite) TestStartAfterEpochBoundary_WithinTxExpiry() {
   299  	// we expect 2 ActiveClustersChanged events once when the engine first starts and the first set of epoch components are started and on restart
   300  	suite.engineEventsDistributor.On("ActiveClustersChanged", mock.AnythingOfType("flow.ChainIDList")).Twice()
   301  	defer suite.engineEventsDistributor.AssertExpectations(suite.T())
   302  	suite.phase = flow.EpochPhaseStaking
   303  	// transition epochs, so that a Previous epoch is queryable
   304  	suite.TransitionEpoch()
   305  	prevEpoch := suite.epochs[suite.counter-1]
   306  	// the finalized height is within [1,tx_expiry] heights of previous epoch final height
   307  	prevEpochFinalHeight := uint64(100)
   308  	prevEpoch.On("FinalHeight").Return(prevEpochFinalHeight, nil)
   309  	suite.header.Height = prevEpochFinalHeight + 1
   310  	suite.heights.On("OnHeight", prevEpochFinalHeight+flow.DefaultTransactionExpiry+1, mock.Anything)
   311  
   312  	suite.StartEngine()
   313  	// previous epoch components should have been started
   314  	suite.AssertEpochStarted(suite.counter - 1)
   315  	suite.AssertEpochStarted(suite.counter)
   316  }
   317  
   318  // TestStartAfterEpochBoundary_BeyondTxExpiry tests starting the engine shortly after an epoch transition.
   319  // When the finalized height is beyond the first tx_expiry blocks of the new epoch
   320  // the engine should NOT restart the previous epoch cluster consensus.
   321  func (suite *Suite) TestStartAfterEpochBoundary_BeyondTxExpiry() {
   322  	// we expect 1 ActiveClustersChanged events when the engine first starts and the first set of epoch components are started
   323  	suite.engineEventsDistributor.On("ActiveClustersChanged", mock.AnythingOfType("flow.ChainIDList")).Once()
   324  	defer suite.engineEventsDistributor.AssertExpectations(suite.T())
   325  	suite.phase = flow.EpochPhaseStaking
   326  	// transition epochs, so that a Previous epoch is queryable
   327  	suite.TransitionEpoch()
   328  	prevEpoch := suite.epochs[suite.counter-1]
   329  	// the finalized height is more than tx_expiry above previous epoch final height
   330  	prevEpochFinalHeight := uint64(100)
   331  	prevEpoch.On("FinalHeight").Return(prevEpochFinalHeight, nil)
   332  	suite.header.Height = prevEpochFinalHeight + flow.DefaultTransactionExpiry + 100
   333  
   334  	suite.StartEngine()
   335  	// previous epoch components should not have been started
   336  	suite.AssertEpochStarted(suite.counter)
   337  	suite.Assert().Len(suite.components, 1)
   338  }
   339  
   340  // TestStartAfterEpochBoundary_NotApprovedForPreviousEpoch tests starting the engine
   341  // shortly after an epoch transition. The finalized boundary is near enough the epoch
   342  // boundary that we could start the previous epoch cluster consensus - however,
   343  // since we are not approved for the epoch, we should only start current epoch components.
   344  func (suite *Suite) TestStartAfterEpochBoundary_NotApprovedForPreviousEpoch() {
   345  	// we expect 1 ActiveClustersChanged events when the current epoch components are started
   346  	suite.engineEventsDistributor.On("ActiveClustersChanged", mock.AnythingOfType("flow.ChainIDList")).Once()
   347  	defer suite.engineEventsDistributor.AssertExpectations(suite.T())
   348  	suite.phase = flow.EpochPhaseStaking
   349  	// transition epochs, so that a Previous epoch is queryable
   350  	suite.TransitionEpoch()
   351  	prevEpoch := suite.epochs[suite.counter-1]
   352  	// the finalized height is within [1,tx_expiry] heights of previous epoch final height
   353  	prevEpochFinalHeight := uint64(100)
   354  	prevEpoch.On("FinalHeight").Return(prevEpochFinalHeight, nil)
   355  	suite.header.Height = 101
   356  	suite.MockAsUnauthorizedNode(suite.counter - 1)
   357  
   358  	suite.StartEngine()
   359  	// previous epoch components should not have been started
   360  	suite.AssertEpochStarted(suite.counter)
   361  	suite.Assert().Len(suite.components, 1)
   362  }
   363  
   364  // TestStartAfterEpochBoundary_NotApprovedForCurrentEpoch tests starting the engine
   365  // shortly after an epoch transition. The finalized boundary is near enough the epoch
   366  // boundary that we should start the previous epoch cluster consensus. However, we are
   367  // not approved for the current epoch -> we should only start *previous* epoch components.
   368  func (suite *Suite) TestStartAfterEpochBoundary_NotApprovedForCurrentEpoch() {
   369  	// we expect 1 ActiveClustersChanged events when the current epoch components are started
   370  	suite.engineEventsDistributor.On("ActiveClustersChanged", mock.AnythingOfType("flow.ChainIDList")).Once()
   371  	defer suite.engineEventsDistributor.AssertExpectations(suite.T())
   372  	suite.phase = flow.EpochPhaseStaking
   373  	// transition epochs, so that a Previous epoch is queryable
   374  	suite.TransitionEpoch()
   375  	prevEpoch := suite.epochs[suite.counter-1]
   376  	// the finalized height is within [1,tx_expiry] heights of previous epoch final height
   377  	prevEpochFinalHeight := uint64(100)
   378  	prevEpoch.On("FinalHeight").Return(prevEpochFinalHeight, nil)
   379  	suite.header.Height = 101
   380  	suite.heights.On("OnHeight", prevEpochFinalHeight+flow.DefaultTransactionExpiry+1, mock.Anything)
   381  	suite.MockAsUnauthorizedNode(suite.counter)
   382  
   383  	suite.StartEngine()
   384  	// only previous epoch components should have been started
   385  	suite.AssertEpochStarted(suite.counter - 1)
   386  	suite.Assert().Len(suite.components, 1)
   387  }
   388  
   389  // TestStartAfterEpochBoundary_PreviousEpochTransitionBeforeRoot tests starting the engine
   390  // with a root snapshot whose sealing segment excludes the last epoch boundary.
   391  // In this case we should only start up current-epoch components.
   392  func (suite *Suite) TestStartAfterEpochBoundary_PreviousEpochTransitionBeforeRoot() {
   393  	// we expect 1 ActiveClustersChanged events when the current epoch components are started
   394  	suite.engineEventsDistributor.On("ActiveClustersChanged", mock.AnythingOfType("flow.ChainIDList")).Once()
   395  	defer suite.engineEventsDistributor.AssertExpectations(suite.T())
   396  	suite.phase = flow.EpochPhaseStaking
   397  	// transition epochs, so that a Previous epoch is queryable
   398  	suite.TransitionEpoch()
   399  	prevEpoch := suite.epochs[suite.counter-1]
   400  	// Previous epoch end boundary is unknown because it is before our root snapshot
   401  	prevEpoch.On("FinalHeight").Return(uint64(0), realprotocol.ErrUnknownEpochBoundary)
   402  
   403  	suite.StartEngine()
   404  	// only current epoch components should have been started
   405  	suite.AssertEpochStarted(suite.counter)
   406  	suite.Assert().Len(suite.components, 1)
   407  }
   408  
   409  // TestStartAsUnauthorizedNode test that when a collection node joins the network
   410  // at an epoch boundary, they must start running during the EpochSetup phase in the
   411  // epoch before they become an authorized member so they submit their cluster QC vote.
   412  //
   413  // These nodes must kick off the root QC voter but should not attempt to participate
   414  // in cluster consensus in the current epoch.
   415  func (suite *Suite) TestStartAsUnauthorizedNode() {
   416  	suite.MockAsUnauthorizedNode(suite.counter)
   417  	// we are in setup phase
   418  	suite.phase = flow.EpochPhaseSetup
   419  	// should call voter with next epoch
   420  	var called = make(chan struct{})
   421  	suite.voter.On("Vote", mock.Anything, suite.epochQuery.Next()).
   422  		Return(nil).
   423  		Run(func(args mock.Arguments) {
   424  			close(called)
   425  		}).Once()
   426  
   427  	// start the engine
   428  	suite.StartEngine()
   429  
   430  	// should have submitted vote
   431  	unittest.AssertClosesBefore(suite.T(), called, time.Second)
   432  	// should have no epoch components
   433  	assert.Empty(suite.T(), suite.engine.epochs, "should have 0 epoch components")
   434  }
   435  
   436  // TestRespondToPhaseChange should kick off root QC voter when we receive an event
   437  // indicating the EpochSetup phase has started.
   438  func (suite *Suite) TestRespondToPhaseChange() {
   439  	// we expect 1 ActiveClustersChanged events when the engine first starts and the first set of epoch components are started
   440  	suite.engineEventsDistributor.On("ActiveClustersChanged", mock.AnythingOfType("flow.ChainIDList")).Once()
   441  	defer suite.engineEventsDistributor.AssertExpectations(suite.T())
   442  
   443  	// start in staking phase
   444  	suite.phase = flow.EpochPhaseStaking
   445  	// should call voter with next epoch
   446  	var called = make(chan struct{})
   447  	suite.voter.On("Vote", mock.Anything, suite.epochQuery.Next()).
   448  		Return(nil).
   449  		Run(func(args mock.Arguments) {
   450  			close(called)
   451  		}).Once()
   452  
   453  	firstBlockOfEpochSetupPhase := unittest.BlockHeaderFixture()
   454  	suite.state.On("AtBlockID", firstBlockOfEpochSetupPhase.ID()).Return(suite.snap)
   455  	suite.StartEngine()
   456  
   457  	// after receiving the protocol event, we should submit our root QC vote
   458  	suite.engine.EpochSetupPhaseStarted(0, firstBlockOfEpochSetupPhase)
   459  	unittest.AssertClosesBefore(suite.T(), called, time.Second)
   460  }
   461  
   462  // TestRespondToEpochTransition tests the engine's behaviour during epoch transition.
   463  // It should:
   464  //   - instantiate cluster consensus for the new epoch
   465  //   - register callback to stop the previous epoch's cluster consensus
   466  //   - stop the previous epoch's cluster consensus when the callback is invoked
   467  func (suite *Suite) TestRespondToEpochTransition() {
   468  	// we expect 3 ActiveClustersChanged events
   469  	// - once when the engine first starts and the first set of epoch components are started
   470  	// - once when the epoch transitions and the new set of epoch components are started
   471  	// - once when the epoch transitions and the old set of epoch components are stopped
   472  	expectedNumOfEvents := 3
   473  	suite.engineEventsDistributor.On("ActiveClustersChanged", mock.AnythingOfType("flow.ChainIDList")).Times(expectedNumOfEvents)
   474  	defer suite.engineEventsDistributor.AssertExpectations(suite.T())
   475  
   476  	// we are in committed phase
   477  	suite.phase = flow.EpochPhaseCommitted
   478  	suite.StartEngine()
   479  
   480  	firstBlockOfEpoch := unittest.BlockHeaderFixture()
   481  	suite.state.On("AtBlockID", firstBlockOfEpoch.ID()).Return(suite.snap)
   482  
   483  	// should set up callback for height at which previous epoch expires
   484  	var expiryCallback func()
   485  	heightRegistered := make(chan struct{})
   486  	suite.heights.On("OnHeight", firstBlockOfEpoch.Height+flow.DefaultTransactionExpiry, mock.Anything).
   487  		Run(func(args mock.Arguments) {
   488  			expiryCallback = args.Get(1).(func())
   489  			close(heightRegistered)
   490  		}).
   491  		Once()
   492  
   493  	// mock the epoch transition
   494  	suite.TransitionEpoch()
   495  	// notify the engine of the epoch transition
   496  	suite.engine.EpochTransition(suite.counter, firstBlockOfEpoch)
   497  	// ensure we registered a height callback
   498  	unittest.AssertClosesBefore(suite.T(), heightRegistered, time.Second)
   499  	suite.Assert().NotNil(expiryCallback)
   500  
   501  	// the engine should have two epochs under management, the just ended epoch
   502  	// and the newly started epoch
   503  	suite.Eventually(func() bool {
   504  		suite.engine.mu.Lock()
   505  		defer suite.engine.mu.Unlock()
   506  		return len(suite.engine.epochs) == 2
   507  	}, time.Second, 10*time.Millisecond)
   508  	_, exists := suite.engine.epochs[suite.counter-1]
   509  	suite.Assert().True(exists, "should have previous epoch components")
   510  	_, exists = suite.engine.epochs[suite.counter]
   511  	suite.Assert().True(exists, "should have current epoch components")
   512  
   513  	// the newly started (current) epoch should have been started
   514  	suite.AssertEpochStarted(suite.counter)
   515  
   516  	// when we invoke the callback registered to handle the previous epoch's
   517  	// expiry, the previous epoch components should be cleaned up
   518  	expiryCallback()
   519  
   520  	suite.Assert().Eventually(func() bool {
   521  		suite.engine.mu.Lock()
   522  		defer suite.engine.mu.Unlock()
   523  		return len(suite.engine.epochs) == 1
   524  	}, time.Second, 10*time.Millisecond)
   525  
   526  	// after the previous epoch expires, we should only have current epoch
   527  	_, exists = suite.engine.epochs[suite.counter]
   528  	suite.Assert().True(exists, "should have current epoch components")
   529  	_, exists = suite.engine.epochs[suite.counter-1]
   530  	suite.Assert().False(exists, "should not have previous epoch components")
   531  
   532  	// the expired epoch should have been stopped
   533  	suite.AssertEpochStopped(suite.counter - 1)
   534  }