github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/network/p2p/cache/protocol_state_provider_test.go (about) 1 package cache_test 2 3 import ( 4 "testing" 5 6 "github.com/libp2p/go-libp2p/core/peer" 7 "github.com/rs/zerolog" 8 "github.com/stretchr/testify/assert" 9 "github.com/stretchr/testify/mock" 10 "github.com/stretchr/testify/require" 11 "github.com/stretchr/testify/suite" 12 13 "github.com/onflow/flow-go/network/p2p" 14 "github.com/onflow/flow-go/network/p2p/cache" 15 16 "github.com/onflow/flow-go/model/flow" 17 "github.com/onflow/flow-go/model/flow/filter" 18 "github.com/onflow/flow-go/network/p2p/keyutils" 19 "github.com/onflow/flow-go/state/protocol" 20 "github.com/onflow/flow-go/state/protocol/events" 21 mockprotocol "github.com/onflow/flow-go/state/protocol/mock" 22 "github.com/onflow/flow-go/utils/unittest" 23 ) 24 25 type ProtocolStateProviderTestSuite struct { 26 suite.Suite 27 provider *cache.ProtocolStateIDCache 28 distributor *events.Distributor 29 state protocol.State 30 snapshot protocol.Snapshot 31 head *flow.Header 32 participants flow.IdentityList 33 epochNum uint64 34 } 35 36 func (suite *ProtocolStateProviderTestSuite) SetupTest() { 37 suite.distributor = events.NewDistributor() 38 39 // set up protocol state mock 40 state := &mockprotocol.State{} 41 state.On("Final").Return( 42 func() protocol.Snapshot { 43 return suite.snapshot 44 }, 45 ) 46 state.On("AtBlockID", mock.Anything).Return( 47 func(blockID flow.Identifier) protocol.Snapshot { 48 if suite.head.ID() == blockID { 49 return suite.snapshot 50 } else { 51 return unittest.StateSnapshotForUnknownBlock() 52 } 53 }, 54 ) 55 suite.state = state 56 suite.epochNum = 0 57 58 suite.triggerUpdate() 59 60 provider, err := cache.NewProtocolStateIDCache(zerolog.Logger{}, state, suite.distributor) 61 require.NoError(suite.T(), err) 62 63 suite.provider = provider 64 } 65 66 // triggerUpdate simulates an epoch transition 67 func (suite *ProtocolStateProviderTestSuite) triggerUpdate() { 68 suite.participants = unittest.IdentityListFixture(5, unittest.WithAllRoles(), unittest.WithKeys) 69 70 block := unittest.BlockFixture() 71 suite.head = block.Header 72 73 // set up protocol snapshot mock 74 snapshot := &mockprotocol.Snapshot{} 75 snapshot.On("Identities", mock.Anything).Return( 76 func(filter flow.IdentityFilter[flow.Identity]) flow.IdentityList { 77 return suite.participants.Filter(filter) 78 }, 79 nil, 80 ) 81 snapshot.On("Identity", mock.Anything).Return(func(id flow.Identifier) *flow.Identity { 82 for _, n := range suite.participants { 83 if n.ID() == id { 84 return n 85 } 86 } 87 return nil 88 }, nil) 89 snapshot.On("Head").Return( 90 func() *flow.Header { 91 return suite.head 92 }, 93 nil, 94 ) 95 suite.snapshot = snapshot 96 suite.epochNum += 1 97 98 suite.distributor.EpochTransition(suite.epochNum, suite.head) 99 } 100 101 func TestProtocolStateProvider(t *testing.T) { 102 suite.Run(t, new(ProtocolStateProviderTestSuite)) 103 } 104 105 // checkStateTransition triggers an epoch transition and checks that the updated 106 // state is reflected by the provider being tested. 107 func (suite *ProtocolStateProviderTestSuite) checkStateTransition() { 108 oldParticipants := suite.participants 109 110 suite.triggerUpdate() 111 112 assert.ElementsMatch(suite.T(), suite.participants, suite.provider.Identities(filter.Any)) 113 for _, participant := range suite.participants { 114 pid, err := suite.provider.GetPeerID(participant.NodeID) 115 require.NoError(suite.T(), err) 116 fid, err := suite.provider.GetFlowID(pid) 117 require.NoError(suite.T(), err) 118 assert.Equal(suite.T(), fid, participant.NodeID) 119 } 120 for _, participant := range oldParticipants { 121 _, err := suite.provider.GetPeerID(participant.NodeID) 122 require.ErrorIs(suite.T(), err, p2p.ErrUnknownId) 123 } 124 } 125 126 func (suite *ProtocolStateProviderTestSuite) TestUpdateState() { 127 for i := 0; i < 10; i++ { 128 suite.checkStateTransition() 129 } 130 } 131 132 func (suite *ProtocolStateProviderTestSuite) TestIDTranslation() { 133 for _, participant := range suite.participants { 134 pid, err := suite.provider.GetPeerID(participant.NodeID) 135 require.NoError(suite.T(), err) 136 key, err := keyutils.LibP2PPublicKeyFromFlow(participant.NetworkPubKey) 137 require.NoError(suite.T(), err) 138 expectedPid, err := peer.IDFromPublicKey(key) 139 require.NoError(suite.T(), err) 140 assert.Equal(suite.T(), expectedPid, pid) 141 fid, err := suite.provider.GetFlowID(pid) 142 require.NoError(suite.T(), err) 143 assert.Equal(suite.T(), fid, participant.NodeID) 144 } 145 } 146 147 // TestUnknownIDs verifies that `ProtocolStateIDCache` complies with `p2p.IDTranslator` 148 // interface specification: we expect an `p2p.ErrUnknownId` when attempting to 149 // translate an unknown `peer.ID` or `flow.Identifier`. 150 func (suite *ProtocolStateProviderTestSuite) TestUnknownIDs() { 151 unknwonFlowID := unittest.IdentifierFixture() 152 _, err := suite.provider.GetPeerID(unknwonFlowID) 153 require.ErrorIs(suite.T(), err, p2p.ErrUnknownId) 154 155 unknownPeerID := peer.ID("unknownPeerID") 156 _, err = suite.provider.GetFlowID(unknownPeerID) 157 require.ErrorIs(suite.T(), err, p2p.ErrUnknownId) 158 }