github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/network/p2p/cache/protocol_state_provider_test.go (about)

     1  package cache_test
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/libp2p/go-libp2p/core/peer"
     7  	"github.com/rs/zerolog"
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/stretchr/testify/mock"
    10  	"github.com/stretchr/testify/require"
    11  	"github.com/stretchr/testify/suite"
    12  
    13  	"github.com/onflow/flow-go/network/p2p"
    14  	"github.com/onflow/flow-go/network/p2p/cache"
    15  
    16  	"github.com/onflow/flow-go/model/flow"
    17  	"github.com/onflow/flow-go/model/flow/filter"
    18  	"github.com/onflow/flow-go/network/p2p/keyutils"
    19  	"github.com/onflow/flow-go/state/protocol"
    20  	"github.com/onflow/flow-go/state/protocol/events"
    21  	mockprotocol "github.com/onflow/flow-go/state/protocol/mock"
    22  	"github.com/onflow/flow-go/utils/unittest"
    23  )
    24  
    25  type ProtocolStateProviderTestSuite struct {
    26  	suite.Suite
    27  	provider     *cache.ProtocolStateIDCache
    28  	distributor  *events.Distributor
    29  	state        protocol.State
    30  	snapshot     protocol.Snapshot
    31  	head         *flow.Header
    32  	participants flow.IdentityList
    33  	epochNum     uint64
    34  }
    35  
    36  func (suite *ProtocolStateProviderTestSuite) SetupTest() {
    37  	suite.distributor = events.NewDistributor()
    38  
    39  	// set up protocol state mock
    40  	state := &mockprotocol.State{}
    41  	state.On("Final").Return(
    42  		func() protocol.Snapshot {
    43  			return suite.snapshot
    44  		},
    45  	)
    46  	state.On("AtBlockID", mock.Anything).Return(
    47  		func(blockID flow.Identifier) protocol.Snapshot {
    48  			if suite.head.ID() == blockID {
    49  				return suite.snapshot
    50  			} else {
    51  				return unittest.StateSnapshotForUnknownBlock()
    52  			}
    53  		},
    54  	)
    55  	suite.state = state
    56  	suite.epochNum = 0
    57  
    58  	suite.triggerUpdate()
    59  
    60  	provider, err := cache.NewProtocolStateIDCache(zerolog.Logger{}, state, suite.distributor)
    61  	require.NoError(suite.T(), err)
    62  
    63  	suite.provider = provider
    64  }
    65  
    66  // triggerUpdate simulates an epoch transition
    67  func (suite *ProtocolStateProviderTestSuite) triggerUpdate() {
    68  	suite.participants = unittest.IdentityListFixture(5, unittest.WithAllRoles(), unittest.WithKeys)
    69  
    70  	block := unittest.BlockFixture()
    71  	suite.head = block.Header
    72  
    73  	// set up protocol snapshot mock
    74  	snapshot := &mockprotocol.Snapshot{}
    75  	snapshot.On("Identities", mock.Anything).Return(
    76  		func(filter flow.IdentityFilter[flow.Identity]) flow.IdentityList {
    77  			return suite.participants.Filter(filter)
    78  		},
    79  		nil,
    80  	)
    81  	snapshot.On("Identity", mock.Anything).Return(func(id flow.Identifier) *flow.Identity {
    82  		for _, n := range suite.participants {
    83  			if n.ID() == id {
    84  				return n
    85  			}
    86  		}
    87  		return nil
    88  	}, nil)
    89  	snapshot.On("Head").Return(
    90  		func() *flow.Header {
    91  			return suite.head
    92  		},
    93  		nil,
    94  	)
    95  	suite.snapshot = snapshot
    96  	suite.epochNum += 1
    97  
    98  	suite.distributor.EpochTransition(suite.epochNum, suite.head)
    99  }
   100  
   101  func TestProtocolStateProvider(t *testing.T) {
   102  	suite.Run(t, new(ProtocolStateProviderTestSuite))
   103  }
   104  
   105  // checkStateTransition triggers an epoch transition and checks that the updated
   106  // state is reflected by the provider being tested.
   107  func (suite *ProtocolStateProviderTestSuite) checkStateTransition() {
   108  	oldParticipants := suite.participants
   109  
   110  	suite.triggerUpdate()
   111  
   112  	assert.ElementsMatch(suite.T(), suite.participants, suite.provider.Identities(filter.Any))
   113  	for _, participant := range suite.participants {
   114  		pid, err := suite.provider.GetPeerID(participant.NodeID)
   115  		require.NoError(suite.T(), err)
   116  		fid, err := suite.provider.GetFlowID(pid)
   117  		require.NoError(suite.T(), err)
   118  		assert.Equal(suite.T(), fid, participant.NodeID)
   119  	}
   120  	for _, participant := range oldParticipants {
   121  		_, err := suite.provider.GetPeerID(participant.NodeID)
   122  		require.ErrorIs(suite.T(), err, p2p.ErrUnknownId)
   123  	}
   124  }
   125  
   126  func (suite *ProtocolStateProviderTestSuite) TestUpdateState() {
   127  	for i := 0; i < 10; i++ {
   128  		suite.checkStateTransition()
   129  	}
   130  }
   131  
   132  func (suite *ProtocolStateProviderTestSuite) TestIDTranslation() {
   133  	for _, participant := range suite.participants {
   134  		pid, err := suite.provider.GetPeerID(participant.NodeID)
   135  		require.NoError(suite.T(), err)
   136  		key, err := keyutils.LibP2PPublicKeyFromFlow(participant.NetworkPubKey)
   137  		require.NoError(suite.T(), err)
   138  		expectedPid, err := peer.IDFromPublicKey(key)
   139  		require.NoError(suite.T(), err)
   140  		assert.Equal(suite.T(), expectedPid, pid)
   141  		fid, err := suite.provider.GetFlowID(pid)
   142  		require.NoError(suite.T(), err)
   143  		assert.Equal(suite.T(), fid, participant.NodeID)
   144  	}
   145  }
   146  
   147  // TestUnknownIDs verifies that `ProtocolStateIDCache` complies with `p2p.IDTranslator`
   148  // interface specification: we expect an `p2p.ErrUnknownId` when attempting to
   149  // translate an unknown `peer.ID` or `flow.Identifier`.
   150  func (suite *ProtocolStateProviderTestSuite) TestUnknownIDs() {
   151  	unknwonFlowID := unittest.IdentifierFixture()
   152  	_, err := suite.provider.GetPeerID(unknwonFlowID)
   153  	require.ErrorIs(suite.T(), err, p2p.ErrUnknownId)
   154  
   155  	unknownPeerID := peer.ID("unknownPeerID")
   156  	_, err = suite.provider.GetFlowID(unknownPeerID)
   157  	require.ErrorIs(suite.T(), err, p2p.ErrUnknownId)
   158  }