github.com/koko1123/flow-go-1@v0.29.6/engine/collection/compliance/engine_test.go (about) 1 package compliance 2 3 import ( 4 "math/rand" 5 "sync" 6 "testing" 7 "time" 8 9 "github.com/stretchr/testify/mock" 10 "github.com/stretchr/testify/require" 11 "github.com/stretchr/testify/suite" 12 13 "github.com/koko1123/flow-go-1/consensus/hotstuff/model" 14 "github.com/koko1123/flow-go-1/engine" 15 "github.com/koko1123/flow-go-1/model/cluster" 16 "github.com/koko1123/flow-go-1/model/flow" 17 "github.com/koko1123/flow-go-1/model/messages" 18 module "github.com/koko1123/flow-go-1/module/mock" 19 netint "github.com/koko1123/flow-go-1/network" 20 "github.com/koko1123/flow-go-1/network/channels" 21 "github.com/koko1123/flow-go-1/network/mocknetwork" 22 protocol "github.com/koko1123/flow-go-1/state/protocol/mock" 23 storerr "github.com/koko1123/flow-go-1/storage" 24 storage "github.com/koko1123/flow-go-1/storage/mock" 25 "github.com/koko1123/flow-go-1/utils/unittest" 26 ) 27 28 func TestComplianceEngine(t *testing.T) { 29 suite.Run(t, new(ComplianceSuite)) 30 } 31 32 type ComplianceSuite struct { 33 ComplianceCoreSuite 34 35 clusterID flow.ChainID 36 myID flow.Identifier 37 cluster flow.IdentityList 38 me *module.Local 39 net *mocknetwork.Network 40 payloads *storage.ClusterPayloads 41 protoState *protocol.MutableState 42 con *mocknetwork.Conduit 43 44 payloadDB map[flow.Identifier]*cluster.Payload 45 46 engine *Engine 47 } 48 49 func (cs *ComplianceSuite) SetupTest() { 50 cs.ComplianceCoreSuite.SetupTest() 51 52 // initialize the paramaters 53 cs.cluster = unittest.IdentityListFixture(3, 54 unittest.WithRole(flow.RoleCollection), 55 unittest.WithWeight(1000), 56 ) 57 cs.myID = cs.cluster[0].NodeID 58 59 protoEpoch := &protocol.Epoch{} 60 clusters := flow.ClusterList{cs.cluster} 61 protoEpoch.On("Clustering").Return(clusters, nil) 62 63 protoQuery := &protocol.EpochQuery{} 64 protoQuery.On("Current").Return(protoEpoch) 65 66 protoSnapshot := &protocol.Snapshot{} 67 protoSnapshot.On("Epochs").Return(protoQuery) 68 protoSnapshot.On("Identities", mock.Anything).Return( 69 func(selector flow.IdentityFilter) flow.IdentityList { 70 return cs.cluster.Filter(selector) 71 }, 72 nil, 73 ) 74 75 cs.protoState = &protocol.MutableState{} 76 cs.protoState.On("Final").Return(protoSnapshot) 77 78 cs.clusterID = "cluster-id" 79 clusterParams := &protocol.Params{} 80 clusterParams.On("ChainID").Return(cs.clusterID, nil) 81 82 cs.state.On("Params").Return(clusterParams) 83 84 // set up local module mock 85 cs.me = &module.Local{} 86 cs.me.On("NodeID").Return( 87 func() flow.Identifier { 88 return cs.myID 89 }, 90 ) 91 92 cs.payloadDB = make(map[flow.Identifier]*cluster.Payload) 93 94 // set up payload storage mock 95 cs.payloads = &storage.ClusterPayloads{} 96 cs.payloads.On("Store", mock.Anything, mock.Anything).Return( 97 func(blockID flow.Identifier, payload *cluster.Payload) error { 98 cs.payloadDB[blockID] = payload 99 return nil 100 }, 101 ) 102 cs.payloads.On("ByBlockID", mock.Anything).Return( 103 func(blockID flow.Identifier) *cluster.Payload { 104 return cs.payloadDB[blockID] 105 }, 106 func(blockID flow.Identifier) error { 107 _, exists := cs.payloadDB[blockID] 108 if !exists { 109 return storerr.ErrNotFound 110 } 111 return nil 112 }, 113 ) 114 115 // set up network conduit mock 116 cs.con = &mocknetwork.Conduit{} 117 cs.con.On("Publish", mock.Anything, mock.Anything).Return(nil) 118 cs.con.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil) 119 cs.con.On("Publish", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) 120 cs.con.On("Unicast", mock.Anything, mock.Anything).Return(nil) 121 122 // set up network module mock 123 cs.net = &mocknetwork.Network{} 124 cs.net.On("Register", mock.Anything, mock.Anything).Return( 125 func(channel channels.Channel, engine netint.MessageProcessor) netint.Conduit { 126 return cs.con 127 }, 128 nil, 129 ) 130 131 e, err := NewEngine(unittest.Logger(), cs.net, cs.me, cs.protoState, cs.payloads, cs.core) 132 require.NoError(cs.T(), err) 133 cs.engine = e 134 135 ready := func() <-chan struct{} { 136 channel := make(chan struct{}) 137 close(channel) 138 return channel 139 }() 140 141 cs.hotstuff.On("Start", mock.Anything) 142 cs.hotstuff.On("Ready", mock.Anything).Return(ready) 143 <-cs.engine.Ready() 144 } 145 146 // TestSendVote tests that single vote can be send and properly processed 147 func (cs *ComplianceSuite) TestSendVote() { 148 // create parameters to send a vote 149 blockID := unittest.IdentifierFixture() 150 view := rand.Uint64() 151 sig := unittest.SignatureFixture() 152 recipientID := unittest.IdentifierFixture() 153 154 // submit the vote 155 err := cs.engine.SendVote(blockID, view, sig, recipientID) 156 require.NoError(cs.T(), err, "should pass send vote") 157 158 done := func() <-chan struct{} { 159 channel := make(chan struct{}) 160 close(channel) 161 return channel 162 }() 163 164 cs.hotstuff.On("Done", mock.Anything).Return(done) 165 166 // The vote is transmitted asynchronously. We allow 10ms for the vote to be received: 167 <-time.After(10 * time.Millisecond) 168 <-cs.engine.Done() 169 170 // check it was called with right params 171 vote := messages.ClusterBlockVote{ 172 BlockID: blockID, 173 View: view, 174 SigData: sig, 175 } 176 cs.con.AssertCalled(cs.T(), "Unicast", &vote, recipientID) 177 } 178 179 // TestBroadcastProposalWithDelay tests broadcasting proposals with different 180 // inputs 181 func (cs *ComplianceSuite) TestBroadcastProposalWithDelay() { 182 183 // generate a parent with height and chain ID set 184 parent := unittest.ClusterBlockFixture() 185 parent.Header.ChainID = "test" 186 parent.Header.Height = 10 187 cs.headerDB[parent.ID()] = &parent 188 189 // create a block with the parent and store the payload with correct ID 190 block := unittest.ClusterBlockWithParent(&parent) 191 block.Header.ProposerID = cs.myID 192 cs.payloadDB[block.ID()] = block.Payload 193 194 // keep a duplicate of the correct header to check against leader 195 header := block.Header 196 197 // unset chain and height to make sure they are correctly reconstructed 198 block.Header.ChainID = "" 199 block.Header.Height = 0 200 201 cs.hotstuff.On("SubmitProposal", block.Header, parent.Header.View).Return(doneChan()).Once() 202 203 // submit to broadcast proposal 204 err := cs.engine.BroadcastProposalWithDelay(block.Header, 0) 205 require.NoError(cs.T(), err, "header broadcast should pass") 206 207 // make sure chain ID and height were reconstructed and 208 // we broadcast to correct nodes 209 header.ChainID = "test" 210 header.Height = 11 211 msg := messages.NewClusterBlockProposal(&block) 212 213 done := func() <-chan struct{} { 214 channel := make(chan struct{}) 215 close(channel) 216 return channel 217 }() 218 219 cs.hotstuff.On("Done", mock.Anything).Return(done) 220 221 <-time.After(10 * time.Millisecond) 222 <-cs.engine.Done() 223 cs.con.AssertCalled(cs.T(), "Publish", msg, cs.cluster[1].NodeID, cs.cluster[2].NodeID) 224 225 // should fail with wrong proposer 226 header.ProposerID = unittest.IdentifierFixture() 227 err = cs.engine.BroadcastProposalWithDelay(header, 0) 228 require.Error(cs.T(), err, "should fail with wrong proposer") 229 header.ProposerID = cs.myID 230 231 // should fail with changed (missing) parent 232 header.ParentID[0]++ 233 err = cs.engine.BroadcastProposalWithDelay(header, 0) 234 require.Error(cs.T(), err, "should fail with missing parent") 235 header.ParentID[0]-- 236 237 // should fail with wrong block ID (payload unavailable) 238 header.View++ 239 err = cs.engine.BroadcastProposalWithDelay(header, 0) 240 require.Error(cs.T(), err, "should fail with missing payload") 241 header.View-- 242 } 243 244 // TestSubmittingMultipleVotes tests that we can send multiple votes and they 245 // are queued and processed in expected way 246 func (cs *ComplianceSuite) TestSubmittingMultipleEntries() { 247 // create a vote 248 originID := unittest.IdentifierFixture() 249 voteCount := 15 250 251 channel := channels.ConsensusCluster(cs.clusterID) 252 253 var wg sync.WaitGroup 254 wg.Add(1) 255 go func() { 256 for i := 0; i < voteCount; i++ { 257 vote := messages.ClusterBlockVote{ 258 BlockID: unittest.IdentifierFixture(), 259 View: rand.Uint64(), 260 SigData: unittest.SignatureFixture(), 261 } 262 cs.voteAggregator.On("AddVote", &model.Vote{ 263 View: vote.View, 264 BlockID: vote.BlockID, 265 SignerID: originID, 266 SigData: vote.SigData, 267 }).Return().Once() 268 // execute the vote submission 269 _ = cs.engine.Process(channel, originID, &vote) 270 } 271 wg.Done() 272 }() 273 wg.Add(1) 274 go func() { 275 // create a proposal that directly descends from the latest finalized header 276 originID := cs.cluster[1].NodeID 277 block := unittest.ClusterBlockWithParent(cs.head) 278 proposal := messages.NewClusterBlockProposal(&block) 279 280 // store the data for retrieval 281 cs.headerDB[block.Header.ParentID] = cs.head 282 cs.hotstuff.On("SubmitProposal", block.Header, cs.head.Header.View).Return(doneChan()) 283 _ = cs.engine.Process(channel, originID, proposal) 284 wg.Done() 285 }() 286 287 wg.Wait() 288 289 time.Sleep(time.Second) 290 291 // check that submit vote was called with correct parameters 292 cs.hotstuff.AssertExpectations(cs.T()) 293 cs.voteAggregator.AssertExpectations(cs.T()) 294 } 295 296 // TestProcessUnsupportedMessageType tests that Process and ProcessLocal correctly handle a case where invalid message type 297 // was submitted from network layer. 298 func (cs *ComplianceSuite) TestProcessUnsupportedMessageType() { 299 invalidEvent := uint64(42) 300 err := cs.engine.Process("ch", unittest.IdentifierFixture(), invalidEvent) 301 // shouldn't result in error since byzantine inputs are expected 302 require.NoError(cs.T(), err) 303 // in case of local processing error cannot be consumed since all inputs are trusted 304 err = cs.engine.ProcessLocal(invalidEvent) 305 require.Error(cs.T(), err) 306 require.True(cs.T(), engine.IsIncompatibleInputTypeError(err)) 307 } 308 309 // TestOnFinalizedBlock tests if finalized block gets processed when send through `Engine`. 310 // Tests the whole processing pipeline. 311 func (cs *ComplianceSuite) TestOnFinalizedBlock() { 312 finalizedBlock := unittest.ClusterBlockFixture() 313 cs.head = &finalizedBlock 314 315 *cs.pending = module.PendingClusterBlockBuffer{} 316 cs.pending.On("PruneByView", finalizedBlock.Header.View).Return(nil).Once() 317 cs.pending.On("Size").Return(uint(0)).Once() 318 cs.engine.OnFinalizedBlock(model.BlockFromFlow(finalizedBlock.Header, finalizedBlock.Header.View-1)) 319 320 require.Eventually(cs.T(), 321 func() bool { 322 return cs.pending.AssertCalled(cs.T(), "PruneByView", finalizedBlock.Header.View) 323 }, time.Second, time.Millisecond*20) 324 }