github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/network/test/cohort2/unicast_authorization_test.go (about) 1 package cohort2 2 3 import ( 4 "context" 5 "io" 6 "reflect" 7 "testing" 8 "time" 9 10 "github.com/rs/zerolog" 11 mockery "github.com/stretchr/testify/mock" 12 "github.com/stretchr/testify/require" 13 "github.com/stretchr/testify/suite" 14 15 "github.com/onflow/flow-go/model/flow" 16 libp2pmessage "github.com/onflow/flow-go/model/libp2p/message" 17 "github.com/onflow/flow-go/model/messages" 18 "github.com/onflow/flow-go/module/irrecoverable" 19 "github.com/onflow/flow-go/network" 20 "github.com/onflow/flow-go/network/channels" 21 "github.com/onflow/flow-go/network/codec" 22 "github.com/onflow/flow-go/network/internal/testutils" 23 "github.com/onflow/flow-go/network/message" 24 "github.com/onflow/flow-go/network/mocknetwork" 25 "github.com/onflow/flow-go/network/p2p" 26 p2plogging "github.com/onflow/flow-go/network/p2p/logging" 27 "github.com/onflow/flow-go/network/underlay" 28 "github.com/onflow/flow-go/network/validator" 29 "github.com/onflow/flow-go/utils/unittest" 30 ) 31 32 // UnicastAuthorizationTestSuite tests that messages sent via unicast that are unauthenticated or unauthorized are correctly rejected. Each test on the test suite 33 // uses 2 networks, a sender and receiver. A mock slashing violation's consumer is used to assert the messages were rejected. Networks and the cancel func 34 // are set during each test run inside the test and remove after each test run in the TearDownTest callback. 35 type UnicastAuthorizationTestSuite struct { 36 suite.Suite 37 channelCloseDuration time.Duration 38 logger zerolog.Logger 39 40 codec *overridableMessageEncoder 41 42 libP2PNodes []p2p.LibP2PNode 43 // senderNetwork is the networking layer instance that will be used to send the message. 44 senderNetwork network.EngineRegistry 45 // senderID the identity on the mw sending the message 46 senderID *flow.Identity 47 // receiverNetwork is the networking layer instance that will be used to receive the message. 48 receiverNetwork network.EngineRegistry 49 // receiverID the identity on the mw sending the message 50 receiverID *flow.Identity 51 // providers id providers generated at beginning of a test run 52 providers []*unittest.UpdatableIDProvider 53 // cancel is the cancel func from the context that was used to start the networks in a test run 54 cancel context.CancelFunc 55 sporkId flow.Identifier 56 // waitCh is the channel used to wait for the networks to perform authorization and invoke the slashing 57 // violation's consumer before making mock assertions and cleaning up resources 58 waitCh chan struct{} 59 } 60 61 // TestUnicastAuthorizationTestSuite runs all the test methods in this test suit 62 func TestUnicastAuthorizationTestSuite(t *testing.T) { 63 suite.Run(t, new(UnicastAuthorizationTestSuite)) 64 } 65 66 func (u *UnicastAuthorizationTestSuite) SetupTest() { 67 u.logger = unittest.Logger() 68 u.channelCloseDuration = 100 * time.Millisecond 69 // this ch will allow us to wait until the expected method call happens before shutting down networks. 70 u.waitCh = make(chan struct{}) 71 } 72 73 func (u *UnicastAuthorizationTestSuite) TearDownTest() { 74 u.stopNetworksAndLibp2pNodes() 75 } 76 77 // setupNetworks will setup the sender and receiver networks with the given slashing violations consumer. 78 func (u *UnicastAuthorizationTestSuite) setupNetworks(slashingViolationsConsumer network.ViolationsConsumer) { 79 u.sporkId = unittest.IdentifierFixture() 80 ids, libP2PNodes := testutils.LibP2PNodeForNetworkFixture(u.T(), u.sporkId, 2) 81 u.codec = newOverridableMessageEncoder(unittest.NetworkCodec()) 82 nets, providers := testutils.NetworksFixture( 83 u.T(), 84 u.sporkId, 85 ids, 86 libP2PNodes, 87 underlay.WithCodec(u.codec), 88 underlay.WithSlashingViolationConsumerFactory(func(_ network.ConduitAdapter) network.ViolationsConsumer { 89 return slashingViolationsConsumer 90 })) 91 require.Len(u.T(), ids, 2) 92 require.Len(u.T(), providers, 2) 93 require.Len(u.T(), nets, 2) 94 95 u.senderNetwork = nets[0] 96 u.receiverNetwork = nets[1] 97 u.senderID = ids[0] 98 u.receiverID = ids[1] 99 u.providers = providers 100 u.libP2PNodes = libP2PNodes 101 } 102 103 // startNetworksAndLibp2pNodes will start both sender and receiver networks with an irrecoverable signaler context and set the context cancel func. 104 func (u *UnicastAuthorizationTestSuite) startNetworksAndLibp2pNodes() { 105 ctx, cancel := context.WithCancel(context.Background()) 106 sigCtx, _ := irrecoverable.WithSignaler(ctx) 107 108 testutils.StartNodes(sigCtx, u.T(), u.libP2PNodes) 109 testutils.StartNetworks(sigCtx, u.T(), []network.EngineRegistry{u.senderNetwork, u.receiverNetwork}) 110 unittest.RequireComponentsReadyBefore(u.T(), 1*time.Second, u.senderNetwork, u.receiverNetwork) 111 112 u.cancel = cancel 113 } 114 115 // stopNetworksAndLibp2pNodes will stop all networks and libp2p nodes and wait for them to stop. 116 func (u *UnicastAuthorizationTestSuite) stopNetworksAndLibp2pNodes() { 117 u.cancel() // cancel context to stop libp2p nodes. 118 119 testutils.StopComponents(u.T(), []network.EngineRegistry{u.senderNetwork, u.receiverNetwork}, 1*time.Second) 120 unittest.RequireComponentsDoneBefore(u.T(), 1*time.Second, u.senderNetwork, u.receiverNetwork) 121 } 122 123 // TestUnicastAuthorization_UnstakedPeer tests that messages sent via unicast by an unstaked peer is correctly rejected. 124 func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnstakedPeer() { 125 slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T()) 126 u.setupNetworks(slashingViolationsConsumer) 127 128 expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID) 129 require.NoError(u.T(), err) 130 131 var nilID *flow.Identity 132 expectedViolation := &network.Violation{ 133 Identity: nilID, // because the peer will be unverified this identity will be nil 134 PeerID: p2plogging.PeerId(expectedSenderPeerID), 135 MsgType: "", // message will not be decoded before OnSenderEjectedError is logged, we won't log message type 136 Channel: channels.TestNetworkChannel, // message will not be decoded before OnSenderEjectedError is logged, we won't log peer ID 137 Protocol: message.ProtocolTypeUnicast, 138 Err: validator.ErrIdentityUnverified, 139 } 140 slashingViolationsConsumer.On("OnUnAuthorizedSenderError", expectedViolation).Return(nil).Once().Run(func(args mockery.Arguments) { 141 close(u.waitCh) 142 }) 143 144 u.startNetworksAndLibp2pNodes() 145 146 // overriding the identity provide of the receiver node to return an empty identity list so that the 147 // sender node looks unstaked to its networking layer and hence it sends an UnAuthorizedSenderError upon receiving a message 148 // from the sender node 149 u.providers[1].SetIdentities(nil) 150 151 _, err = u.receiverNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{}) 152 require.NoError(u.T(), err) 153 154 senderCon, err := u.senderNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{}) 155 require.NoError(u.T(), err) 156 157 // send message via unicast 158 err = senderCon.Unicast(&libp2pmessage.TestMessage{ 159 Text: string("hello"), 160 }, u.receiverID.NodeID) 161 require.NoError(u.T(), err) 162 163 // wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens 164 unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time") 165 } 166 167 // TestUnicastAuthorization_EjectedPeer tests that messages sent via unicast by an ejected peer is correctly rejected. 168 func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_EjectedPeer() { 169 slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T()) 170 u.setupNetworks(slashingViolationsConsumer) 171 //NOTE: setup ejected identity 172 u.senderID.EpochParticipationStatus = flow.EpochParticipationStatusEjected 173 174 // overriding the identity provide of the receiver node to return the ejected identity so that the 175 // sender node looks ejected to its networking layer and hence it sends a SenderEjectedError upon receiving a message 176 // from the sender node 177 u.providers[1].SetIdentities(flow.IdentityList{u.senderID}) 178 179 expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID) 180 require.NoError(u.T(), err) 181 182 expectedViolation := &network.Violation{ 183 Identity: u.senderID, // we expect this method to be called with the ejected identity 184 OriginID: u.senderID.NodeID, 185 PeerID: p2plogging.PeerId(expectedSenderPeerID), 186 MsgType: "", // message will not be decoded before OnSenderEjectedError is logged, we won't log message type 187 Channel: channels.TestNetworkChannel, // message will not be decoded before OnSenderEjectedError is logged, we won't log peer ID 188 Protocol: message.ProtocolTypeUnicast, 189 Err: validator.ErrSenderEjected, 190 } 191 slashingViolationsConsumer.On("OnSenderEjectedError", expectedViolation). 192 Return(nil).Once().Run(func(args mockery.Arguments) { 193 close(u.waitCh) 194 }) 195 196 u.startNetworksAndLibp2pNodes() 197 198 _, err = u.receiverNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{}) 199 require.NoError(u.T(), err) 200 201 senderCon, err := u.senderNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{}) 202 require.NoError(u.T(), err) 203 204 // send message via unicast 205 err = senderCon.Unicast(&libp2pmessage.TestMessage{ 206 Text: string("hello"), 207 }, u.receiverID.NodeID) 208 require.NoError(u.T(), err) 209 210 // wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens 211 unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time") 212 } 213 214 // TestUnicastAuthorization_UnauthorizedPeer tests that messages sent via unicast by an unauthorized peer is correctly rejected. 215 func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnauthorizedPeer() { 216 slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T()) 217 u.setupNetworks(slashingViolationsConsumer) 218 219 expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID) 220 require.NoError(u.T(), err) 221 222 expectedViolation := &network.Violation{ 223 Identity: u.senderID, 224 OriginID: u.senderID.NodeID, 225 PeerID: p2plogging.PeerId(expectedSenderPeerID), 226 MsgType: "*message.TestMessage", 227 Channel: channels.ConsensusCommittee, 228 Protocol: message.ProtocolTypeUnicast, 229 Err: message.ErrUnauthorizedMessageOnChannel, 230 } 231 232 slashingViolationsConsumer.On("OnUnAuthorizedSenderError", expectedViolation). 233 Return(nil).Once().Run(func(args mockery.Arguments) { 234 close(u.waitCh) 235 }) 236 237 u.startNetworksAndLibp2pNodes() 238 239 _, err = u.receiverNetwork.Register(channels.ConsensusCommittee, &mocknetwork.MessageProcessor{}) 240 require.NoError(u.T(), err) 241 242 senderCon, err := u.senderNetwork.Register(channels.ConsensusCommittee, &mocknetwork.MessageProcessor{}) 243 require.NoError(u.T(), err) 244 245 // send message via unicast; a test message must only be unicasted on the TestNetworkChannel, not on the ConsensusCommittee channel 246 // so we expect an unauthorized sender error 247 err = senderCon.Unicast(&libp2pmessage.TestMessage{ 248 Text: string("hello"), 249 }, u.receiverID.NodeID) 250 require.NoError(u.T(), err) 251 252 // wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens 253 unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time") 254 } 255 256 // TestUnicastAuthorization_UnknownMsgCode tests that messages sent via unicast with an unknown message code is correctly rejected. 257 func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnknownMsgCode() { 258 slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T()) 259 u.setupNetworks(slashingViolationsConsumer) 260 261 expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID) 262 require.NoError(u.T(), err) 263 264 invalidMessageCode := codec.MessageCode(byte('X')) 265 // register a custom encoder that encodes the message with an invalid message code when encoding a string. 266 u.codec.RegisterEncoder(reflect.TypeOf(""), func(v interface{}) ([]byte, error) { 267 e, err := unittest.NetworkCodec().Encode(&libp2pmessage.TestMessage{ 268 Text: v.(string), 269 }) 270 require.NoError(u.T(), err) 271 // manipulate message code byte 272 invalidMessageCode := codec.MessageCode(byte('X')) 273 e[0] = invalidMessageCode.Uint8() 274 return e, nil 275 }) 276 277 var nilID *flow.Identity 278 expectedViolation := &network.Violation{ 279 Identity: nilID, 280 PeerID: p2plogging.PeerId(expectedSenderPeerID), 281 MsgType: "", 282 Channel: channels.TestNetworkChannel, 283 Protocol: message.ProtocolTypeUnicast, 284 Err: codec.NewUnknownMsgCodeErr(invalidMessageCode), 285 } 286 287 slashingViolationsConsumer.On("OnUnknownMsgTypeError", expectedViolation). 288 Return(nil).Once().Run(func(args mockery.Arguments) { 289 close(u.waitCh) 290 }) 291 292 u.startNetworksAndLibp2pNodes() 293 294 _, err = u.receiverNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{}) 295 require.NoError(u.T(), err) 296 297 senderCon, err := u.senderNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{}) 298 require.NoError(u.T(), err) 299 300 // send message via unicast 301 err = senderCon.Unicast("hello!", u.receiverID.NodeID) 302 require.NoError(u.T(), err) 303 304 // wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens 305 unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time") 306 } 307 308 // TestUnicastAuthorization_WrongMsgCode tests that messages sent via unicast with a message code that does not match the underlying message type are correctly rejected. 309 func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_WrongMsgCode() { 310 slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T()) 311 u.setupNetworks(slashingViolationsConsumer) 312 313 expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID) 314 require.NoError(u.T(), err) 315 316 modifiedMessageCode := codec.CodeDKGMessage 317 // register a custom encoder that overrides the message code when encoding a TestMessage. 318 u.codec.RegisterEncoder(reflect.TypeOf(&libp2pmessage.TestMessage{}), func(v interface{}) ([]byte, error) { 319 e, err := unittest.NetworkCodec().Encode(v) 320 require.NoError(u.T(), err) 321 e[0] = modifiedMessageCode.Uint8() 322 return e, nil 323 }) 324 325 expectedViolation := &network.Violation{ 326 Identity: u.senderID, 327 OriginID: u.senderID.NodeID, 328 PeerID: p2plogging.PeerId(expectedSenderPeerID), 329 MsgType: "*messages.DKGMessage", 330 Channel: channels.TestNetworkChannel, 331 Protocol: message.ProtocolTypeUnicast, 332 Err: message.ErrUnauthorizedMessageOnChannel, 333 } 334 335 slashingViolationsConsumer.On("OnUnAuthorizedSenderError", expectedViolation). 336 Return(nil).Once().Run(func(args mockery.Arguments) { 337 close(u.waitCh) 338 }) 339 340 u.startNetworksAndLibp2pNodes() 341 342 _, err = u.receiverNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{}) 343 require.NoError(u.T(), err) 344 345 senderCon, err := u.senderNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{}) 346 require.NoError(u.T(), err) 347 348 // send message via unicast 349 err = senderCon.Unicast(&libp2pmessage.TestMessage{ 350 Text: string("hello"), 351 }, u.receiverID.NodeID) 352 require.NoError(u.T(), err) 353 354 // wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens 355 unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time") 356 } 357 358 // TestUnicastAuthorization_PublicChannel tests that messages sent via unicast on a public channel are not rejected for any reason. 359 func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_PublicChannel() { 360 slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T()) 361 u.setupNetworks(slashingViolationsConsumer) 362 u.startNetworksAndLibp2pNodes() 363 364 msg := &libp2pmessage.TestMessage{ 365 Text: string("hello"), 366 } 367 368 // mock a message processor that will receive the message. 369 receiverEngine := &mocknetwork.MessageProcessor{} 370 receiverEngine.On("Process", channels.PublicPushBlocks, u.senderID.NodeID, msg).Run( 371 func(args mockery.Arguments) { 372 close(u.waitCh) 373 }).Return(nil).Once() 374 _, err := u.receiverNetwork.Register(channels.PublicPushBlocks, receiverEngine) 375 require.NoError(u.T(), err) 376 377 senderCon, err := u.senderNetwork.Register(channels.PublicPushBlocks, &mocknetwork.MessageProcessor{}) 378 require.NoError(u.T(), err) 379 380 // send message via unicast 381 err = senderCon.Unicast(&libp2pmessage.TestMessage{ 382 Text: string("hello"), 383 }, u.receiverID.NodeID) 384 require.NoError(u.T(), err) 385 386 // wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens 387 unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time") 388 } 389 390 // TestUnicastAuthorization_UnauthorizedUnicastOnChannel tests that messages sent via unicast that are not authorized for unicast are rejected. 391 func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnauthorizedUnicastOnChannel() { 392 slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T()) 393 u.setupNetworks(slashingViolationsConsumer) 394 395 // set sender id role to RoleConsensus to avoid unauthorized sender validation error 396 u.senderID.Role = flow.RoleConsensus 397 398 expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID) 399 require.NoError(u.T(), err) 400 401 expectedViolation := &network.Violation{ 402 Identity: u.senderID, 403 OriginID: u.senderID.NodeID, 404 PeerID: p2plogging.PeerId(expectedSenderPeerID), 405 MsgType: "*messages.BlockProposal", 406 Channel: channels.ConsensusCommittee, 407 Protocol: message.ProtocolTypeUnicast, 408 Err: message.ErrUnauthorizedUnicastOnChannel, 409 } 410 411 slashingViolationsConsumer.On("OnUnauthorizedUnicastOnChannel", expectedViolation). 412 Return(nil).Once().Run(func(args mockery.Arguments) { 413 close(u.waitCh) 414 }) 415 416 u.startNetworksAndLibp2pNodes() 417 418 _, err = u.receiverNetwork.Register(channels.ConsensusCommittee, &mocknetwork.MessageProcessor{}) 419 require.NoError(u.T(), err) 420 421 senderCon, err := u.senderNetwork.Register(channels.ConsensusCommittee, &mocknetwork.MessageProcessor{}) 422 require.NoError(u.T(), err) 423 424 // messages.BlockProposal is not authorized to be sent via unicast over the ConsensusCommittee channel 425 payload := unittest.ProposalFixture() 426 // send message via unicast 427 err = senderCon.Unicast(payload, u.receiverID.NodeID) 428 require.NoError(u.T(), err) 429 430 // wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens 431 unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time") 432 } 433 434 // TestUnicastAuthorization_ReceiverHasNoSubscription tests that messages sent via unicast are rejected on the receiver end if the receiver does not have a subscription 435 // to the channel of the message. 436 func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_ReceiverHasNoSubscription() { 437 slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T()) 438 u.setupNetworks(slashingViolationsConsumer) 439 440 expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID) 441 require.NoError(u.T(), err) 442 443 expectedViolation := &network.Violation{ 444 Identity: nil, 445 PeerID: p2plogging.PeerId(expectedSenderPeerID), 446 MsgType: "*message.TestMessage", 447 Channel: channels.TestNetworkChannel, 448 Protocol: message.ProtocolTypeUnicast, 449 Err: underlay.ErrUnicastMsgWithoutSub, 450 } 451 452 slashingViolationsConsumer.On("OnUnauthorizedUnicastOnChannel", expectedViolation). 453 Return(nil).Once().Run(func(args mockery.Arguments) { 454 close(u.waitCh) 455 }) 456 457 u.startNetworksAndLibp2pNodes() 458 459 senderCon, err := u.senderNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{}) 460 require.NoError(u.T(), err) 461 462 // send message via unicast 463 err = senderCon.Unicast(&libp2pmessage.TestMessage{ 464 Text: string("hello"), 465 }, u.receiverID.NodeID) 466 require.NoError(u.T(), err) 467 468 // wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens 469 unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time") 470 } 471 472 // TestUnicastAuthorization_ReceiverHasSubscription tests that messages sent via unicast are processed on the receiver end if the receiver does have a subscription 473 // to the channel of the message. 474 func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_ReceiverHasSubscription() { 475 slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T()) 476 u.setupNetworks(slashingViolationsConsumer) 477 u.startNetworksAndLibp2pNodes() 478 479 msg := &messages.EntityRequest{ 480 EntityIDs: unittest.IdentifierListFixture(10), 481 } 482 483 // both sender and receiver must have an authorized role to send and receive messages on the ConsensusCommittee channel. 484 u.senderID.Role = flow.RoleConsensus 485 u.receiverID.Role = flow.RoleExecution 486 487 receiverEngine := &mocknetwork.MessageProcessor{} 488 receiverEngine.On("Process", channels.RequestReceiptsByBlockID, u.senderID.NodeID, msg).Run( 489 func(args mockery.Arguments) { 490 close(u.waitCh) 491 }).Return(nil).Once() 492 _, err := u.receiverNetwork.Register(channels.RequestReceiptsByBlockID, receiverEngine) 493 require.NoError(u.T(), err) 494 495 senderCon, err := u.senderNetwork.Register(channels.RequestReceiptsByBlockID, &mocknetwork.MessageProcessor{}) 496 require.NoError(u.T(), err) 497 498 // send message via unicast 499 err = senderCon.Unicast(msg, u.receiverID.NodeID) 500 require.NoError(u.T(), err) 501 502 // wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens 503 unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time") 504 } 505 506 // overridableMessageEncoder is a codec that allows to override the encoder for a specific type only for sake of testing. 507 // We specifically use this to override the encoder for the TestMessage type to encode it with an invalid message code. 508 type overridableMessageEncoder struct { 509 codec network.Codec 510 specificEncoder map[reflect.Type]func(interface{}) ([]byte, error) 511 } 512 513 var _ network.Codec = (*overridableMessageEncoder)(nil) 514 515 func newOverridableMessageEncoder(codec network.Codec) *overridableMessageEncoder { 516 return &overridableMessageEncoder{ 517 codec: codec, 518 specificEncoder: make(map[reflect.Type]func(interface{}) ([]byte, error)), 519 } 520 } 521 522 // RegisterEncoder registers an encoder for a specific type, overriding the default encoder for that type. 523 func (u *overridableMessageEncoder) RegisterEncoder(t reflect.Type, encoder func(interface{}) ([]byte, error)) { 524 u.specificEncoder[t] = encoder 525 } 526 527 // NewEncoder creates a new encoder. 528 func (u *overridableMessageEncoder) NewEncoder(w io.Writer) network.Encoder { 529 return u.codec.NewEncoder(w) 530 } 531 532 // NewDecoder creates a new decoder. 533 func (u *overridableMessageEncoder) NewDecoder(r io.Reader) network.Decoder { 534 return u.codec.NewDecoder(r) 535 } 536 537 // Encode encodes a value into a byte slice. If a specific encoder is registered for the type of the value, it will be used. 538 // Otherwise, the default encoder will be used. 539 func (u *overridableMessageEncoder) Encode(v interface{}) ([]byte, error) { 540 if encoder, ok := u.specificEncoder[reflect.TypeOf(v)]; ok { 541 return encoder(v) 542 } 543 return u.codec.Encode(v) 544 } 545 546 // Decode decodes a byte slice into a value. It uses the default decoder. 547 func (u *overridableMessageEncoder) Decode(data []byte) (interface{}, error) { 548 return u.codec.Decode(data) 549 }