github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/network/test/cohort2/unicast_authorization_test.go (about)

     1  package cohort2
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"reflect"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/rs/zerolog"
    11  	mockery "github.com/stretchr/testify/mock"
    12  	"github.com/stretchr/testify/require"
    13  	"github.com/stretchr/testify/suite"
    14  
    15  	"github.com/onflow/flow-go/model/flow"
    16  	libp2pmessage "github.com/onflow/flow-go/model/libp2p/message"
    17  	"github.com/onflow/flow-go/model/messages"
    18  	"github.com/onflow/flow-go/module/irrecoverable"
    19  	"github.com/onflow/flow-go/network"
    20  	"github.com/onflow/flow-go/network/channels"
    21  	"github.com/onflow/flow-go/network/codec"
    22  	"github.com/onflow/flow-go/network/internal/testutils"
    23  	"github.com/onflow/flow-go/network/message"
    24  	"github.com/onflow/flow-go/network/mocknetwork"
    25  	"github.com/onflow/flow-go/network/p2p"
    26  	p2plogging "github.com/onflow/flow-go/network/p2p/logging"
    27  	"github.com/onflow/flow-go/network/underlay"
    28  	"github.com/onflow/flow-go/network/validator"
    29  	"github.com/onflow/flow-go/utils/unittest"
    30  )
    31  
    32  // UnicastAuthorizationTestSuite tests that messages sent via unicast that are unauthenticated or unauthorized are correctly rejected. Each test on the test suite
    33  // uses 2 networks, a sender and receiver. A mock slashing violation's consumer is used to assert the messages were rejected. Networks and the cancel func
    34  // are set during each test run inside the test and remove after each test run in the TearDownTest callback.
    35  type UnicastAuthorizationTestSuite struct {
    36  	suite.Suite
    37  	channelCloseDuration time.Duration
    38  	logger               zerolog.Logger
    39  
    40  	codec *overridableMessageEncoder
    41  
    42  	libP2PNodes []p2p.LibP2PNode
    43  	// senderNetwork is the networking layer instance that will be used to send the message.
    44  	senderNetwork network.EngineRegistry
    45  	// senderID the identity on the mw sending the message
    46  	senderID *flow.Identity
    47  	// receiverNetwork is the networking layer instance that will be used to receive the message.
    48  	receiverNetwork network.EngineRegistry
    49  	// receiverID the identity on the mw sending the message
    50  	receiverID *flow.Identity
    51  	// providers id providers generated at beginning of a test run
    52  	providers []*unittest.UpdatableIDProvider
    53  	// cancel is the cancel func from the context that was used to start the networks in a test run
    54  	cancel  context.CancelFunc
    55  	sporkId flow.Identifier
    56  	// waitCh is the channel used to wait for the networks to perform authorization and invoke the slashing
    57  	// violation's consumer before making mock assertions and cleaning up resources
    58  	waitCh chan struct{}
    59  }
    60  
    61  // TestUnicastAuthorizationTestSuite runs all the test methods in this test suit
    62  func TestUnicastAuthorizationTestSuite(t *testing.T) {
    63  	suite.Run(t, new(UnicastAuthorizationTestSuite))
    64  }
    65  
    66  func (u *UnicastAuthorizationTestSuite) SetupTest() {
    67  	u.logger = unittest.Logger()
    68  	u.channelCloseDuration = 100 * time.Millisecond
    69  	// this ch will allow us to wait until the expected method call happens before shutting down networks.
    70  	u.waitCh = make(chan struct{})
    71  }
    72  
    73  func (u *UnicastAuthorizationTestSuite) TearDownTest() {
    74  	u.stopNetworksAndLibp2pNodes()
    75  }
    76  
    77  // setupNetworks will setup the sender and receiver networks with the given slashing violations consumer.
    78  func (u *UnicastAuthorizationTestSuite) setupNetworks(slashingViolationsConsumer network.ViolationsConsumer) {
    79  	u.sporkId = unittest.IdentifierFixture()
    80  	ids, libP2PNodes := testutils.LibP2PNodeForNetworkFixture(u.T(), u.sporkId, 2)
    81  	u.codec = newOverridableMessageEncoder(unittest.NetworkCodec())
    82  	nets, providers := testutils.NetworksFixture(
    83  		u.T(),
    84  		u.sporkId,
    85  		ids,
    86  		libP2PNodes,
    87  		underlay.WithCodec(u.codec),
    88  		underlay.WithSlashingViolationConsumerFactory(func(_ network.ConduitAdapter) network.ViolationsConsumer {
    89  			return slashingViolationsConsumer
    90  		}))
    91  	require.Len(u.T(), ids, 2)
    92  	require.Len(u.T(), providers, 2)
    93  	require.Len(u.T(), nets, 2)
    94  
    95  	u.senderNetwork = nets[0]
    96  	u.receiverNetwork = nets[1]
    97  	u.senderID = ids[0]
    98  	u.receiverID = ids[1]
    99  	u.providers = providers
   100  	u.libP2PNodes = libP2PNodes
   101  }
   102  
   103  // startNetworksAndLibp2pNodes will start both sender and receiver networks with an irrecoverable signaler context and set the context cancel func.
   104  func (u *UnicastAuthorizationTestSuite) startNetworksAndLibp2pNodes() {
   105  	ctx, cancel := context.WithCancel(context.Background())
   106  	sigCtx, _ := irrecoverable.WithSignaler(ctx)
   107  
   108  	testutils.StartNodes(sigCtx, u.T(), u.libP2PNodes)
   109  	testutils.StartNetworks(sigCtx, u.T(), []network.EngineRegistry{u.senderNetwork, u.receiverNetwork})
   110  	unittest.RequireComponentsReadyBefore(u.T(), 1*time.Second, u.senderNetwork, u.receiverNetwork)
   111  
   112  	u.cancel = cancel
   113  }
   114  
   115  // stopNetworksAndLibp2pNodes will stop all networks and libp2p nodes and wait for them to stop.
   116  func (u *UnicastAuthorizationTestSuite) stopNetworksAndLibp2pNodes() {
   117  	u.cancel() // cancel context to stop libp2p nodes.
   118  
   119  	testutils.StopComponents(u.T(), []network.EngineRegistry{u.senderNetwork, u.receiverNetwork}, 1*time.Second)
   120  	unittest.RequireComponentsDoneBefore(u.T(), 1*time.Second, u.senderNetwork, u.receiverNetwork)
   121  }
   122  
   123  // TestUnicastAuthorization_UnstakedPeer tests that messages sent via unicast by an unstaked peer is correctly rejected.
   124  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnstakedPeer() {
   125  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   126  	u.setupNetworks(slashingViolationsConsumer)
   127  
   128  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   129  	require.NoError(u.T(), err)
   130  
   131  	var nilID *flow.Identity
   132  	expectedViolation := &network.Violation{
   133  		Identity: nilID, // because the peer will be unverified this identity will be nil
   134  		PeerID:   p2plogging.PeerId(expectedSenderPeerID),
   135  		MsgType:  "",                          // message will not be decoded before OnSenderEjectedError is logged, we won't log message type
   136  		Channel:  channels.TestNetworkChannel, // message will not be decoded before OnSenderEjectedError is logged, we won't log peer ID
   137  		Protocol: message.ProtocolTypeUnicast,
   138  		Err:      validator.ErrIdentityUnverified,
   139  	}
   140  	slashingViolationsConsumer.On("OnUnAuthorizedSenderError", expectedViolation).Return(nil).Once().Run(func(args mockery.Arguments) {
   141  		close(u.waitCh)
   142  	})
   143  
   144  	u.startNetworksAndLibp2pNodes()
   145  
   146  	// overriding the identity provide of the receiver node to return an empty identity list so that the
   147  	// sender node looks unstaked to its networking layer and hence it sends an UnAuthorizedSenderError upon receiving a message
   148  	// from the sender node
   149  	u.providers[1].SetIdentities(nil)
   150  
   151  	_, err = u.receiverNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{})
   152  	require.NoError(u.T(), err)
   153  
   154  	senderCon, err := u.senderNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{})
   155  	require.NoError(u.T(), err)
   156  
   157  	// send message via unicast
   158  	err = senderCon.Unicast(&libp2pmessage.TestMessage{
   159  		Text: string("hello"),
   160  	}, u.receiverID.NodeID)
   161  	require.NoError(u.T(), err)
   162  
   163  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   164  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   165  }
   166  
   167  // TestUnicastAuthorization_EjectedPeer tests that messages sent via unicast by an ejected peer is correctly rejected.
   168  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_EjectedPeer() {
   169  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   170  	u.setupNetworks(slashingViolationsConsumer)
   171  	//NOTE: setup ejected identity
   172  	u.senderID.EpochParticipationStatus = flow.EpochParticipationStatusEjected
   173  
   174  	// overriding the identity provide of the receiver node to return the ejected identity so that the
   175  	// sender node looks ejected to its networking layer and hence it sends a SenderEjectedError upon receiving a message
   176  	// from the sender node
   177  	u.providers[1].SetIdentities(flow.IdentityList{u.senderID})
   178  
   179  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   180  	require.NoError(u.T(), err)
   181  
   182  	expectedViolation := &network.Violation{
   183  		Identity: u.senderID, // we expect this method to be called with the ejected identity
   184  		OriginID: u.senderID.NodeID,
   185  		PeerID:   p2plogging.PeerId(expectedSenderPeerID),
   186  		MsgType:  "",                          // message will not be decoded before OnSenderEjectedError is logged, we won't log message type
   187  		Channel:  channels.TestNetworkChannel, // message will not be decoded before OnSenderEjectedError is logged, we won't log peer ID
   188  		Protocol: message.ProtocolTypeUnicast,
   189  		Err:      validator.ErrSenderEjected,
   190  	}
   191  	slashingViolationsConsumer.On("OnSenderEjectedError", expectedViolation).
   192  		Return(nil).Once().Run(func(args mockery.Arguments) {
   193  		close(u.waitCh)
   194  	})
   195  
   196  	u.startNetworksAndLibp2pNodes()
   197  
   198  	_, err = u.receiverNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{})
   199  	require.NoError(u.T(), err)
   200  
   201  	senderCon, err := u.senderNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{})
   202  	require.NoError(u.T(), err)
   203  
   204  	// send message via unicast
   205  	err = senderCon.Unicast(&libp2pmessage.TestMessage{
   206  		Text: string("hello"),
   207  	}, u.receiverID.NodeID)
   208  	require.NoError(u.T(), err)
   209  
   210  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   211  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   212  }
   213  
   214  // TestUnicastAuthorization_UnauthorizedPeer tests that messages sent via unicast by an unauthorized peer is correctly rejected.
   215  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnauthorizedPeer() {
   216  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   217  	u.setupNetworks(slashingViolationsConsumer)
   218  
   219  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   220  	require.NoError(u.T(), err)
   221  
   222  	expectedViolation := &network.Violation{
   223  		Identity: u.senderID,
   224  		OriginID: u.senderID.NodeID,
   225  		PeerID:   p2plogging.PeerId(expectedSenderPeerID),
   226  		MsgType:  "*message.TestMessage",
   227  		Channel:  channels.ConsensusCommittee,
   228  		Protocol: message.ProtocolTypeUnicast,
   229  		Err:      message.ErrUnauthorizedMessageOnChannel,
   230  	}
   231  
   232  	slashingViolationsConsumer.On("OnUnAuthorizedSenderError", expectedViolation).
   233  		Return(nil).Once().Run(func(args mockery.Arguments) {
   234  		close(u.waitCh)
   235  	})
   236  
   237  	u.startNetworksAndLibp2pNodes()
   238  
   239  	_, err = u.receiverNetwork.Register(channels.ConsensusCommittee, &mocknetwork.MessageProcessor{})
   240  	require.NoError(u.T(), err)
   241  
   242  	senderCon, err := u.senderNetwork.Register(channels.ConsensusCommittee, &mocknetwork.MessageProcessor{})
   243  	require.NoError(u.T(), err)
   244  
   245  	// send message via unicast; a test message must only be unicasted on the TestNetworkChannel, not on the ConsensusCommittee channel
   246  	// so we expect an unauthorized sender error
   247  	err = senderCon.Unicast(&libp2pmessage.TestMessage{
   248  		Text: string("hello"),
   249  	}, u.receiverID.NodeID)
   250  	require.NoError(u.T(), err)
   251  
   252  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   253  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   254  }
   255  
   256  // TestUnicastAuthorization_UnknownMsgCode tests that messages sent via unicast with an unknown message code is correctly rejected.
   257  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnknownMsgCode() {
   258  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   259  	u.setupNetworks(slashingViolationsConsumer)
   260  
   261  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   262  	require.NoError(u.T(), err)
   263  
   264  	invalidMessageCode := codec.MessageCode(byte('X'))
   265  	// register a custom encoder that encodes the message with an invalid message code when encoding a string.
   266  	u.codec.RegisterEncoder(reflect.TypeOf(""), func(v interface{}) ([]byte, error) {
   267  		e, err := unittest.NetworkCodec().Encode(&libp2pmessage.TestMessage{
   268  			Text: v.(string),
   269  		})
   270  		require.NoError(u.T(), err)
   271  		// manipulate message code byte
   272  		invalidMessageCode := codec.MessageCode(byte('X'))
   273  		e[0] = invalidMessageCode.Uint8()
   274  		return e, nil
   275  	})
   276  
   277  	var nilID *flow.Identity
   278  	expectedViolation := &network.Violation{
   279  		Identity: nilID,
   280  		PeerID:   p2plogging.PeerId(expectedSenderPeerID),
   281  		MsgType:  "",
   282  		Channel:  channels.TestNetworkChannel,
   283  		Protocol: message.ProtocolTypeUnicast,
   284  		Err:      codec.NewUnknownMsgCodeErr(invalidMessageCode),
   285  	}
   286  
   287  	slashingViolationsConsumer.On("OnUnknownMsgTypeError", expectedViolation).
   288  		Return(nil).Once().Run(func(args mockery.Arguments) {
   289  		close(u.waitCh)
   290  	})
   291  
   292  	u.startNetworksAndLibp2pNodes()
   293  
   294  	_, err = u.receiverNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{})
   295  	require.NoError(u.T(), err)
   296  
   297  	senderCon, err := u.senderNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{})
   298  	require.NoError(u.T(), err)
   299  
   300  	// send message via unicast
   301  	err = senderCon.Unicast("hello!", u.receiverID.NodeID)
   302  	require.NoError(u.T(), err)
   303  
   304  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   305  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   306  }
   307  
   308  // TestUnicastAuthorization_WrongMsgCode tests that messages sent via unicast with a message code that does not match the underlying message type are correctly rejected.
   309  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_WrongMsgCode() {
   310  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   311  	u.setupNetworks(slashingViolationsConsumer)
   312  
   313  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   314  	require.NoError(u.T(), err)
   315  
   316  	modifiedMessageCode := codec.CodeDKGMessage
   317  	// register a custom encoder that overrides the message code when encoding a TestMessage.
   318  	u.codec.RegisterEncoder(reflect.TypeOf(&libp2pmessage.TestMessage{}), func(v interface{}) ([]byte, error) {
   319  		e, err := unittest.NetworkCodec().Encode(v)
   320  		require.NoError(u.T(), err)
   321  		e[0] = modifiedMessageCode.Uint8()
   322  		return e, nil
   323  	})
   324  
   325  	expectedViolation := &network.Violation{
   326  		Identity: u.senderID,
   327  		OriginID: u.senderID.NodeID,
   328  		PeerID:   p2plogging.PeerId(expectedSenderPeerID),
   329  		MsgType:  "*messages.DKGMessage",
   330  		Channel:  channels.TestNetworkChannel,
   331  		Protocol: message.ProtocolTypeUnicast,
   332  		Err:      message.ErrUnauthorizedMessageOnChannel,
   333  	}
   334  
   335  	slashingViolationsConsumer.On("OnUnAuthorizedSenderError", expectedViolation).
   336  		Return(nil).Once().Run(func(args mockery.Arguments) {
   337  		close(u.waitCh)
   338  	})
   339  
   340  	u.startNetworksAndLibp2pNodes()
   341  
   342  	_, err = u.receiverNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{})
   343  	require.NoError(u.T(), err)
   344  
   345  	senderCon, err := u.senderNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{})
   346  	require.NoError(u.T(), err)
   347  
   348  	// send message via unicast
   349  	err = senderCon.Unicast(&libp2pmessage.TestMessage{
   350  		Text: string("hello"),
   351  	}, u.receiverID.NodeID)
   352  	require.NoError(u.T(), err)
   353  
   354  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   355  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   356  }
   357  
   358  // TestUnicastAuthorization_PublicChannel tests that messages sent via unicast on a public channel are not rejected for any reason.
   359  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_PublicChannel() {
   360  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   361  	u.setupNetworks(slashingViolationsConsumer)
   362  	u.startNetworksAndLibp2pNodes()
   363  
   364  	msg := &libp2pmessage.TestMessage{
   365  		Text: string("hello"),
   366  	}
   367  
   368  	// mock a message processor that will receive the message.
   369  	receiverEngine := &mocknetwork.MessageProcessor{}
   370  	receiverEngine.On("Process", channels.PublicPushBlocks, u.senderID.NodeID, msg).Run(
   371  		func(args mockery.Arguments) {
   372  			close(u.waitCh)
   373  		}).Return(nil).Once()
   374  	_, err := u.receiverNetwork.Register(channels.PublicPushBlocks, receiverEngine)
   375  	require.NoError(u.T(), err)
   376  
   377  	senderCon, err := u.senderNetwork.Register(channels.PublicPushBlocks, &mocknetwork.MessageProcessor{})
   378  	require.NoError(u.T(), err)
   379  
   380  	// send message via unicast
   381  	err = senderCon.Unicast(&libp2pmessage.TestMessage{
   382  		Text: string("hello"),
   383  	}, u.receiverID.NodeID)
   384  	require.NoError(u.T(), err)
   385  
   386  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   387  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   388  }
   389  
   390  // TestUnicastAuthorization_UnauthorizedUnicastOnChannel tests that messages sent via unicast that are not authorized for unicast are rejected.
   391  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnauthorizedUnicastOnChannel() {
   392  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   393  	u.setupNetworks(slashingViolationsConsumer)
   394  
   395  	// set sender id role to RoleConsensus to avoid unauthorized sender validation error
   396  	u.senderID.Role = flow.RoleConsensus
   397  
   398  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   399  	require.NoError(u.T(), err)
   400  
   401  	expectedViolation := &network.Violation{
   402  		Identity: u.senderID,
   403  		OriginID: u.senderID.NodeID,
   404  		PeerID:   p2plogging.PeerId(expectedSenderPeerID),
   405  		MsgType:  "*messages.BlockProposal",
   406  		Channel:  channels.ConsensusCommittee,
   407  		Protocol: message.ProtocolTypeUnicast,
   408  		Err:      message.ErrUnauthorizedUnicastOnChannel,
   409  	}
   410  
   411  	slashingViolationsConsumer.On("OnUnauthorizedUnicastOnChannel", expectedViolation).
   412  		Return(nil).Once().Run(func(args mockery.Arguments) {
   413  		close(u.waitCh)
   414  	})
   415  
   416  	u.startNetworksAndLibp2pNodes()
   417  
   418  	_, err = u.receiverNetwork.Register(channels.ConsensusCommittee, &mocknetwork.MessageProcessor{})
   419  	require.NoError(u.T(), err)
   420  
   421  	senderCon, err := u.senderNetwork.Register(channels.ConsensusCommittee, &mocknetwork.MessageProcessor{})
   422  	require.NoError(u.T(), err)
   423  
   424  	// messages.BlockProposal is not authorized to be sent via unicast over the ConsensusCommittee channel
   425  	payload := unittest.ProposalFixture()
   426  	// send message via unicast
   427  	err = senderCon.Unicast(payload, u.receiverID.NodeID)
   428  	require.NoError(u.T(), err)
   429  
   430  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   431  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   432  }
   433  
   434  // TestUnicastAuthorization_ReceiverHasNoSubscription tests that messages sent via unicast are rejected on the receiver end if the receiver does not have a subscription
   435  // to the channel of the message.
   436  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_ReceiverHasNoSubscription() {
   437  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   438  	u.setupNetworks(slashingViolationsConsumer)
   439  
   440  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   441  	require.NoError(u.T(), err)
   442  
   443  	expectedViolation := &network.Violation{
   444  		Identity: nil,
   445  		PeerID:   p2plogging.PeerId(expectedSenderPeerID),
   446  		MsgType:  "*message.TestMessage",
   447  		Channel:  channels.TestNetworkChannel,
   448  		Protocol: message.ProtocolTypeUnicast,
   449  		Err:      underlay.ErrUnicastMsgWithoutSub,
   450  	}
   451  
   452  	slashingViolationsConsumer.On("OnUnauthorizedUnicastOnChannel", expectedViolation).
   453  		Return(nil).Once().Run(func(args mockery.Arguments) {
   454  		close(u.waitCh)
   455  	})
   456  
   457  	u.startNetworksAndLibp2pNodes()
   458  
   459  	senderCon, err := u.senderNetwork.Register(channels.TestNetworkChannel, &mocknetwork.MessageProcessor{})
   460  	require.NoError(u.T(), err)
   461  
   462  	// send message via unicast
   463  	err = senderCon.Unicast(&libp2pmessage.TestMessage{
   464  		Text: string("hello"),
   465  	}, u.receiverID.NodeID)
   466  	require.NoError(u.T(), err)
   467  
   468  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   469  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   470  }
   471  
   472  // TestUnicastAuthorization_ReceiverHasSubscription tests that messages sent via unicast are processed on the receiver end if the receiver does have a subscription
   473  // to the channel of the message.
   474  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_ReceiverHasSubscription() {
   475  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   476  	u.setupNetworks(slashingViolationsConsumer)
   477  	u.startNetworksAndLibp2pNodes()
   478  
   479  	msg := &messages.EntityRequest{
   480  		EntityIDs: unittest.IdentifierListFixture(10),
   481  	}
   482  
   483  	// both sender and receiver must have an authorized role to send and receive messages on the ConsensusCommittee channel.
   484  	u.senderID.Role = flow.RoleConsensus
   485  	u.receiverID.Role = flow.RoleExecution
   486  
   487  	receiverEngine := &mocknetwork.MessageProcessor{}
   488  	receiverEngine.On("Process", channels.RequestReceiptsByBlockID, u.senderID.NodeID, msg).Run(
   489  		func(args mockery.Arguments) {
   490  			close(u.waitCh)
   491  		}).Return(nil).Once()
   492  	_, err := u.receiverNetwork.Register(channels.RequestReceiptsByBlockID, receiverEngine)
   493  	require.NoError(u.T(), err)
   494  
   495  	senderCon, err := u.senderNetwork.Register(channels.RequestReceiptsByBlockID, &mocknetwork.MessageProcessor{})
   496  	require.NoError(u.T(), err)
   497  
   498  	// send message via unicast
   499  	err = senderCon.Unicast(msg, u.receiverID.NodeID)
   500  	require.NoError(u.T(), err)
   501  
   502  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   503  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   504  }
   505  
   506  // overridableMessageEncoder is a codec that allows to override the encoder for a specific type only for sake of testing.
   507  // We specifically use this to override the encoder for the TestMessage type to encode it with an invalid message code.
   508  type overridableMessageEncoder struct {
   509  	codec           network.Codec
   510  	specificEncoder map[reflect.Type]func(interface{}) ([]byte, error)
   511  }
   512  
   513  var _ network.Codec = (*overridableMessageEncoder)(nil)
   514  
   515  func newOverridableMessageEncoder(codec network.Codec) *overridableMessageEncoder {
   516  	return &overridableMessageEncoder{
   517  		codec:           codec,
   518  		specificEncoder: make(map[reflect.Type]func(interface{}) ([]byte, error)),
   519  	}
   520  }
   521  
   522  // RegisterEncoder registers an encoder for a specific type, overriding the default encoder for that type.
   523  func (u *overridableMessageEncoder) RegisterEncoder(t reflect.Type, encoder func(interface{}) ([]byte, error)) {
   524  	u.specificEncoder[t] = encoder
   525  }
   526  
   527  // NewEncoder creates a new encoder.
   528  func (u *overridableMessageEncoder) NewEncoder(w io.Writer) network.Encoder {
   529  	return u.codec.NewEncoder(w)
   530  }
   531  
   532  // NewDecoder creates a new decoder.
   533  func (u *overridableMessageEncoder) NewDecoder(r io.Reader) network.Decoder {
   534  	return u.codec.NewDecoder(r)
   535  }
   536  
   537  // Encode encodes a value into a byte slice. If a specific encoder is registered for the type of the value, it will be used.
   538  // Otherwise, the default encoder will be used.
   539  func (u *overridableMessageEncoder) Encode(v interface{}) ([]byte, error) {
   540  	if encoder, ok := u.specificEncoder[reflect.TypeOf(v)]; ok {
   541  		return encoder(v)
   542  	}
   543  	return u.codec.Encode(v)
   544  }
   545  
   546  // Decode decodes a byte slice into a value. It uses the default decoder.
   547  func (u *overridableMessageEncoder) Decode(data []byte) (interface{}, error) {
   548  	return u.codec.Decode(data)
   549  }