github.com/koko1123/flow-go-1@v0.29.6/network/test/unicast_authorization_test.go (about)

     1  package test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/rs/zerolog"
    10  	"github.com/stretchr/testify/mock"
    11  	mockery "github.com/stretchr/testify/mock"
    12  	"github.com/stretchr/testify/require"
    13  	"github.com/stretchr/testify/suite"
    14  
    15  	"github.com/koko1123/flow-go-1/model/flow"
    16  	"github.com/koko1123/flow-go-1/model/flow/filter"
    17  	libp2pmessage "github.com/koko1123/flow-go-1/model/libp2p/message"
    18  	"github.com/koko1123/flow-go-1/model/messages"
    19  	"github.com/koko1123/flow-go-1/module/irrecoverable"
    20  	"github.com/koko1123/flow-go-1/network"
    21  	"github.com/koko1123/flow-go-1/network/channels"
    22  	"github.com/koko1123/flow-go-1/network/codec"
    23  	"github.com/koko1123/flow-go-1/network/internal/testutils"
    24  	"github.com/koko1123/flow-go-1/network/message"
    25  	"github.com/koko1123/flow-go-1/network/mocknetwork"
    26  	"github.com/koko1123/flow-go-1/network/p2p"
    27  	"github.com/koko1123/flow-go-1/network/p2p/middleware"
    28  	"github.com/koko1123/flow-go-1/network/slashing"
    29  	"github.com/koko1123/flow-go-1/network/validator"
    30  	"github.com/koko1123/flow-go-1/utils/unittest"
    31  )
    32  
    33  // UnicastAuthorizationTestSuite tests that messages sent via unicast that are unauthenticated or unauthorized are correctly rejected. Each test on the test suite
    34  // uses 2 middlewares, a sender and receiver. A mock slashing violation's consumer is used to assert the messages were rejected. Middleware and the cancel func
    35  // are set during each test run inside the test and remove after each test run in the TearDownTest callback.
    36  type UnicastAuthorizationTestSuite struct {
    37  	suite.Suite
    38  	channelCloseDuration time.Duration
    39  	logger               zerolog.Logger
    40  
    41  	libP2PNodes []p2p.LibP2PNode
    42  	// senderMW is the mw that will be sending the message
    43  	senderMW network.Middleware
    44  	// senderID the identity on the mw sending the message
    45  	senderID *flow.Identity
    46  	// receiverMW is the mw that will be sending the message
    47  	receiverMW network.Middleware
    48  	// receiverID the identity on the mw sending the message
    49  	receiverID *flow.Identity
    50  	// providers id providers generated at beginning of a test run
    51  	providers []*testutils.UpdatableIDProvider
    52  	// cancel is the cancel func from the context that was used to start the middlewares in a test run
    53  	cancel context.CancelFunc
    54  	// waitCh is the channel used to wait for the middleware to perform authorization and invoke the slashing
    55  	//violation's consumer before making mock assertions and cleaning up resources
    56  	waitCh chan struct{}
    57  }
    58  
    59  // TestUnicastAuthorizationTestSuite runs all the test methods in this test suit
    60  func TestUnicastAuthorizationTestSuite(t *testing.T) {
    61  	t.Parallel()
    62  	suite.Run(t, new(UnicastAuthorizationTestSuite))
    63  }
    64  
    65  func (u *UnicastAuthorizationTestSuite) SetupTest() {
    66  	u.logger = unittest.Logger()
    67  	u.channelCloseDuration = 100 * time.Millisecond
    68  	// this ch will allow us to wait until the expected method call happens before shutting down middleware
    69  	u.waitCh = make(chan struct{})
    70  }
    71  
    72  func (u *UnicastAuthorizationTestSuite) TearDownTest() {
    73  	u.stopMiddlewares()
    74  }
    75  
    76  // setupMiddlewaresAndProviders will setup 2 middlewares that will be used as a sender and receiver in each suite test.
    77  func (u *UnicastAuthorizationTestSuite) setupMiddlewaresAndProviders(slashingViolationsConsumer slashing.ViolationsConsumer) {
    78  	ids, libP2PNodes, _ := testutils.GenerateIDs(u.T(), u.logger, 2)
    79  	mws, providers := testutils.GenerateMiddlewares(u.T(), u.logger, ids, libP2PNodes, unittest.NetworkCodec(), slashingViolationsConsumer)
    80  	require.Len(u.T(), ids, 2)
    81  	require.Len(u.T(), providers, 2)
    82  	require.Len(u.T(), mws, 2)
    83  
    84  	u.senderID = ids[0]
    85  	u.senderMW = mws[0]
    86  	u.receiverID = ids[1]
    87  	u.receiverMW = mws[1]
    88  	u.providers = providers
    89  	u.libP2PNodes = libP2PNodes
    90  }
    91  
    92  // startMiddlewares will start both sender and receiver middlewares with an irrecoverable signaler context and set the context cancel func.
    93  func (u *UnicastAuthorizationTestSuite) startMiddlewares(overlay *mocknetwork.Overlay) {
    94  	ctx, cancel := context.WithCancel(context.Background())
    95  	sigCtx, _ := irrecoverable.WithSignaler(ctx)
    96  
    97  	testutils.StartNodes(sigCtx, u.T(), u.libP2PNodes, 100*time.Millisecond)
    98  
    99  	u.senderMW.SetOverlay(overlay)
   100  	u.senderMW.Start(sigCtx)
   101  
   102  	u.receiverMW.SetOverlay(overlay)
   103  	u.receiverMW.Start(sigCtx)
   104  
   105  	unittest.RequireComponentsReadyBefore(u.T(), 100*time.Millisecond, u.senderMW, u.receiverMW)
   106  
   107  	u.cancel = cancel
   108  }
   109  
   110  // stopMiddlewares will stop all middlewares.
   111  func (u *UnicastAuthorizationTestSuite) stopMiddlewares() {
   112  	u.cancel()
   113  	unittest.RequireCloseBefore(u.T(), u.senderMW.Done(), u.channelCloseDuration, "could not stop middleware on time")
   114  	unittest.RequireCloseBefore(u.T(), u.receiverMW.Done(), u.channelCloseDuration, "could not stop middleware on time")
   115  }
   116  
   117  // TestUnicastAuthorization_UnstakedPeer tests that messages sent via unicast by an unstaked peer is correctly rejected.
   118  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnstakedPeer() {
   119  	// setup mock slashing violations consumer and middlewares
   120  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   121  	u.setupMiddlewaresAndProviders(slashingViolationsConsumer)
   122  
   123  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   124  	require.NoError(u.T(), err)
   125  
   126  	var nilID *flow.Identity
   127  	expectedViolation := &slashing.Violation{
   128  		Identity: nilID, // because the peer will be unverified this identity will be nil
   129  		PeerID:   expectedSenderPeerID.String(),
   130  		MsgType:  "",                          // message will not be decoded before OnSenderEjectedError is logged, we won't log message type
   131  		Channel:  channels.TestNetworkChannel, // message will not be decoded before OnSenderEjectedError is logged, we won't log peer ID
   132  		Protocol: message.ProtocolUnicast,
   133  		Err:      validator.ErrIdentityUnverified,
   134  	}
   135  	slashingViolationsConsumer.On(
   136  		"OnUnAuthorizedSenderError",
   137  		expectedViolation,
   138  	).Once().Run(func(args mockery.Arguments) {
   139  		close(u.waitCh)
   140  	})
   141  
   142  	overlay := mocknetwork.NewOverlay(u.T())
   143  	overlay.On("Identities").Maybe().Return(func() flow.IdentityList {
   144  		return u.providers[0].Identities(filter.Any)
   145  	})
   146  	overlay.On("Topology").Maybe().Return(func() flow.IdentityList {
   147  		return u.providers[0].Identities(filter.Any)
   148  	}, nil)
   149  
   150  	//NOTE: return (nil, false) simulating unstaked node
   151  	overlay.On("Identity", mock.AnythingOfType("peer.ID")).Return(nil, false)
   152  	// message will be rejected so assert overlay never receives it
   153  	defer overlay.AssertNotCalled(u.T(), "Receive", mockery.Anything)
   154  
   155  	u.startMiddlewares(overlay)
   156  
   157  	require.NoError(u.T(), u.receiverMW.Subscribe(testChannel))
   158  	require.NoError(u.T(), u.senderMW.Subscribe(testChannel))
   159  
   160  	msg, err := network.NewOutgoingScope(
   161  		flow.IdentifierList{u.receiverID.NodeID},
   162  		testChannel,
   163  		&libp2pmessage.TestMessage{
   164  			Text: string("hello"),
   165  		},
   166  		unittest.NetworkCodec().Encode,
   167  		network.ProtocolTypeUnicast)
   168  	require.NoError(u.T(), err)
   169  
   170  	// send message via unicast
   171  	err = u.senderMW.SendDirect(msg)
   172  	require.NoError(u.T(), err)
   173  
   174  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   175  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   176  }
   177  
   178  // TestUnicastAuthorization_EjectedPeer tests that messages sent via unicast by an ejected peer is correctly rejected.
   179  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_EjectedPeer() {
   180  	// setup mock slashing violations consumer and middlewares
   181  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   182  	u.setupMiddlewaresAndProviders(slashingViolationsConsumer)
   183  	//NOTE: setup ejected identity
   184  	u.senderID.Ejected = true
   185  
   186  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   187  	require.NoError(u.T(), err)
   188  
   189  	expectedViolation := &slashing.Violation{
   190  		Identity: u.senderID, // we expect this method to be called with the ejected identity
   191  		PeerID:   expectedSenderPeerID.String(),
   192  		MsgType:  "",                          // message will not be decoded before OnSenderEjectedError is logged, we won't log message type
   193  		Channel:  channels.TestNetworkChannel, // message will not be decoded before OnSenderEjectedError is logged, we won't log peer ID
   194  		Protocol: message.ProtocolUnicast,
   195  		Err:      validator.ErrSenderEjected,
   196  	}
   197  	slashingViolationsConsumer.On(
   198  		"OnSenderEjectedError",
   199  		expectedViolation,
   200  	).Once().Run(func(args mockery.Arguments) {
   201  		close(u.waitCh)
   202  	})
   203  
   204  	overlay := mocknetwork.NewOverlay(u.T())
   205  	overlay.On("Identities").Maybe().Return(func() flow.IdentityList {
   206  		return u.providers[0].Identities(filter.Any)
   207  	})
   208  	overlay.On("Topology").Maybe().Return(func() flow.IdentityList {
   209  		return u.providers[0].Identities(filter.Any)
   210  	}, nil)
   211  	//NOTE: return ejected identity causing validation to fail
   212  	overlay.On("Identity", mock.AnythingOfType("peer.ID")).Return(u.senderID, true)
   213  	// message will be rejected so assert overlay never receives it
   214  	defer overlay.AssertNotCalled(u.T(), "Receive", mockery.Anything)
   215  
   216  	u.startMiddlewares(overlay)
   217  
   218  	require.NoError(u.T(), u.receiverMW.Subscribe(testChannel))
   219  	require.NoError(u.T(), u.senderMW.Subscribe(testChannel))
   220  
   221  	msg, err := network.NewOutgoingScope(
   222  		flow.IdentifierList{u.receiverID.NodeID},
   223  		testChannel,
   224  		&libp2pmessage.TestMessage{
   225  			Text: string("hello"),
   226  		},
   227  		unittest.NetworkCodec().Encode,
   228  		network.ProtocolTypeUnicast)
   229  	require.NoError(u.T(), err)
   230  
   231  	// send message via unicast
   232  	err = u.senderMW.SendDirect(msg)
   233  	require.NoError(u.T(), err)
   234  
   235  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   236  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   237  }
   238  
   239  // TestUnicastAuthorization_UnauthorizedPeer tests that messages sent via unicast by an unauthorized peer is correctly rejected.
   240  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnauthorizedPeer() {
   241  	// setup mock slashing violations consumer and middlewares
   242  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   243  	u.setupMiddlewaresAndProviders(slashingViolationsConsumer)
   244  
   245  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   246  	require.NoError(u.T(), err)
   247  
   248  	expectedViolation := &slashing.Violation{
   249  		Identity: u.senderID,
   250  		PeerID:   expectedSenderPeerID.String(),
   251  		MsgType:  message.TestMessage,
   252  		Channel:  channels.ConsensusCommittee,
   253  		Protocol: message.ProtocolUnicast,
   254  		Err:      message.ErrUnauthorizedMessageOnChannel,
   255  	}
   256  
   257  	slashingViolationsConsumer.On(
   258  		"OnUnAuthorizedSenderError",
   259  		expectedViolation,
   260  	).Once().Run(func(args mockery.Arguments) {
   261  		close(u.waitCh)
   262  	})
   263  
   264  	overlay := mocknetwork.NewOverlay(u.T())
   265  	overlay.On("Identities").Maybe().Return(func() flow.IdentityList {
   266  		return u.providers[0].Identities(filter.Any)
   267  	})
   268  	overlay.On("Topology").Maybe().Return(func() flow.IdentityList {
   269  		return u.providers[0].Identities(filter.Any)
   270  	}, nil)
   271  	overlay.On("Identity", mock.AnythingOfType("peer.ID")).Return(u.senderID, true)
   272  	// message will be rejected so assert overlay never receives it
   273  	defer overlay.AssertNotCalled(u.T(), "Receive", mockery.Anything)
   274  
   275  	u.startMiddlewares(overlay)
   276  
   277  	channel := channels.ConsensusCommittee
   278  	require.NoError(u.T(), u.receiverMW.Subscribe(channel))
   279  	require.NoError(u.T(), u.senderMW.Subscribe(channel))
   280  
   281  	msg, err := network.NewOutgoingScope(
   282  		flow.IdentifierList{u.receiverID.NodeID},
   283  		channel,
   284  		&libp2pmessage.TestMessage{
   285  			Text: string("hello"),
   286  		},
   287  		unittest.NetworkCodec().Encode,
   288  		network.ProtocolTypeUnicast)
   289  	require.NoError(u.T(), err)
   290  
   291  	// send message via unicast
   292  	err = u.senderMW.SendDirect(msg)
   293  	require.NoError(u.T(), err)
   294  
   295  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   296  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   297  }
   298  
   299  // TestUnicastAuthorization_UnknownMsgCode tests that messages sent via unicast with an unknown message code is correctly rejected.
   300  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnknownMsgCode() {
   301  	// setup mock slashing violations consumer and middlewares
   302  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   303  	u.setupMiddlewaresAndProviders(slashingViolationsConsumer)
   304  
   305  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   306  	require.NoError(u.T(), err)
   307  
   308  	invalidMessageCode := byte('X')
   309  
   310  	var nilID *flow.Identity
   311  	expectedViolation := &slashing.Violation{
   312  		Identity: nilID,
   313  		PeerID:   expectedSenderPeerID.String(),
   314  		MsgType:  "",
   315  		Channel:  channels.TestNetworkChannel,
   316  		Protocol: message.ProtocolUnicast,
   317  		Err:      codec.NewUnknownMsgCodeErr(invalidMessageCode),
   318  	}
   319  
   320  	slashingViolationsConsumer.On(
   321  		"OnUnknownMsgTypeError",
   322  		expectedViolation,
   323  	).Once().Run(func(args mockery.Arguments) {
   324  		close(u.waitCh)
   325  	})
   326  
   327  	overlay := mocknetwork.NewOverlay(u.T())
   328  	overlay.On("Identities").Maybe().Return(func() flow.IdentityList {
   329  		return u.providers[0].Identities(filter.Any)
   330  	})
   331  	overlay.On("Topology").Maybe().Return(func() flow.IdentityList {
   332  		return u.providers[0].Identities(filter.Any)
   333  	}, nil)
   334  	overlay.On("Identity", mock.AnythingOfType("peer.ID")).Return(u.senderID, true)
   335  
   336  	// message will be rejected so assert overlay never receives it
   337  	defer overlay.AssertNotCalled(u.T(), "Receive", u.senderID.NodeID, mock.AnythingOfType("*message.Message"))
   338  
   339  	u.startMiddlewares(overlay)
   340  
   341  	require.NoError(u.T(), u.receiverMW.Subscribe(testChannel))
   342  	require.NoError(u.T(), u.senderMW.Subscribe(testChannel))
   343  
   344  	msg, err := network.NewOutgoingScope(
   345  		flow.IdentifierList{u.receiverID.NodeID},
   346  		testChannel,
   347  		&libp2pmessage.TestMessage{
   348  			Text: "hello",
   349  		},
   350  		// we use a custom encoder that encodes the message with an invalid message code.
   351  		func(msg interface{}) ([]byte, error) {
   352  			e, err := unittest.NetworkCodec().Encode(msg)
   353  			require.NoError(u.T(), err)
   354  			// manipulate message code byte
   355  			e[0] = invalidMessageCode
   356  			return e, nil
   357  		},
   358  		network.ProtocolTypeUnicast)
   359  	require.NoError(u.T(), err)
   360  
   361  	// send message via unicast
   362  	err = u.senderMW.SendDirect(msg)
   363  	require.NoError(u.T(), err)
   364  
   365  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   366  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   367  }
   368  
   369  // TestUnicastAuthorization_WrongMsgCode tests that messages sent via unicast with a message code that does not match the underlying message type are correctly rejected.
   370  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_WrongMsgCode() {
   371  	// setup mock slashing violations consumer and middlewares
   372  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   373  	u.setupMiddlewaresAndProviders(slashingViolationsConsumer)
   374  
   375  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   376  	require.NoError(u.T(), err)
   377  
   378  	modifiedMessageCode := codec.CodeDKGMessage
   379  
   380  	var nilID *flow.Identity
   381  	expectedViolation := &slashing.Violation{
   382  		Identity: nilID,
   383  		PeerID:   expectedSenderPeerID.String(),
   384  		MsgType:  "",
   385  		Channel:  channels.TestNetworkChannel,
   386  		Protocol: message.ProtocolUnicast,
   387  		//NOTE: in this test the message code does not match the underlying message type causing the codec to fail to unmarshal the message when decoding.
   388  		Err: codec.NewMsgUnmarshalErr(modifiedMessageCode, message.DKGMessage, fmt.Errorf("cbor: found unknown field at map element index 0")),
   389  	}
   390  
   391  	slashingViolationsConsumer.On(
   392  		"OnInvalidMsgError",
   393  		expectedViolation,
   394  	).Once().Run(func(args mockery.Arguments) {
   395  		close(u.waitCh)
   396  	})
   397  
   398  	overlay := mocknetwork.NewOverlay(u.T())
   399  	overlay.On("Identities").Maybe().Return(func() flow.IdentityList {
   400  		return u.providers[0].Identities(filter.Any)
   401  	})
   402  	overlay.On("Topology").Maybe().Return(func() flow.IdentityList {
   403  		return u.providers[0].Identities(filter.Any)
   404  	}, nil)
   405  	overlay.On("Identity", expectedSenderPeerID).Return(u.senderID, true)
   406  
   407  	// message will be rejected so assert overlay never receives it
   408  	defer overlay.AssertNotCalled(u.T(), "Receive", u.senderID.NodeID, mock.AnythingOfType("*message.Message"))
   409  
   410  	u.startMiddlewares(overlay)
   411  
   412  	require.NoError(u.T(), u.receiverMW.Subscribe(testChannel))
   413  	require.NoError(u.T(), u.senderMW.Subscribe(testChannel))
   414  
   415  	msg, err := network.NewOutgoingScope(
   416  		flow.IdentifierList{u.receiverID.NodeID},
   417  		testChannel,
   418  		&libp2pmessage.TestMessage{
   419  			Text: "hello",
   420  		},
   421  		// we use a custom encoder that encodes the message with an invalid message code.
   422  		func(msg interface{}) ([]byte, error) {
   423  			e, err := unittest.NetworkCodec().Encode(msg)
   424  			require.NoError(u.T(), err)
   425  			// manipulate message code byte
   426  			e[0] = modifiedMessageCode
   427  			return e, nil
   428  		},
   429  		network.ProtocolTypeUnicast)
   430  	require.NoError(u.T(), err)
   431  
   432  	// send message via unicast
   433  	err = u.senderMW.SendDirect(msg)
   434  	require.NoError(u.T(), err)
   435  
   436  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   437  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   438  }
   439  
   440  // TestUnicastAuthorization_PublicChannel tests that messages sent via unicast on a public channel are not rejected for any reason.
   441  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_PublicChannel() {
   442  	// setup mock slashing violations consumer and middlewares
   443  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   444  	u.setupMiddlewaresAndProviders(slashingViolationsConsumer)
   445  
   446  	expectedPayload := "hello"
   447  	msg, err := network.NewOutgoingScope(
   448  		flow.IdentifierList{u.receiverID.NodeID},
   449  		testChannel,
   450  		&libp2pmessage.TestMessage{
   451  			Text: expectedPayload,
   452  		},
   453  		unittest.NetworkCodec().Encode,
   454  		network.ProtocolTypeUnicast)
   455  	require.NoError(u.T(), err)
   456  
   457  	overlay := mocknetwork.NewOverlay(u.T())
   458  	overlay.On("Identities").Maybe().Return(func() flow.IdentityList {
   459  		return u.providers[0].Identities(filter.Any)
   460  	})
   461  	overlay.On("Topology").Maybe().Return(func() flow.IdentityList {
   462  		return u.providers[0].Identities(filter.Any)
   463  	}, nil)
   464  	overlay.On("Identity", mock.AnythingOfType("peer.ID")).Return(u.senderID, true)
   465  
   466  	// we should receive the message on our overlay, at this point close the waitCh
   467  	overlay.On("Receive", mockery.Anything).Return(nil).
   468  		Once().
   469  		Run(func(args mockery.Arguments) {
   470  			close(u.waitCh)
   471  
   472  			msg, ok := args[0].(*network.IncomingMessageScope)
   473  			require.True(u.T(), ok)
   474  
   475  			require.Equal(u.T(), testChannel, msg.Channel())                                              // channel
   476  			require.Equal(u.T(), u.senderID.NodeID, msg.OriginId())                                       // sender id
   477  			require.Equal(u.T(), u.receiverID.NodeID, msg.TargetIDs()[0])                                 // target id
   478  			require.Equal(u.T(), network.ProtocolTypeUnicast, msg.Protocol())                             // protocol
   479  			require.Equal(u.T(), expectedPayload, msg.DecodedPayload().(*libp2pmessage.TestMessage).Text) // payload
   480  		})
   481  
   482  	u.startMiddlewares(overlay)
   483  
   484  	require.NoError(u.T(), u.receiverMW.Subscribe(testChannel))
   485  	require.NoError(u.T(), u.senderMW.Subscribe(testChannel))
   486  
   487  	// send message via unicast
   488  	err = u.senderMW.SendDirect(msg)
   489  	require.NoError(u.T(), err)
   490  
   491  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   492  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   493  }
   494  
   495  // TestUnicastAuthorization_UnauthorizedUnicastOnChannel tests that messages sent via unicast that are not authorized for unicast are rejected.
   496  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_UnauthorizedUnicastOnChannel() {
   497  	// setup mock slashing violations consumer and middlewares
   498  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   499  	u.setupMiddlewaresAndProviders(slashingViolationsConsumer)
   500  
   501  	// set sender id role to RoleConsensus to avoid unauthorized sender validation error
   502  	u.senderID.Role = flow.RoleConsensus
   503  
   504  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   505  	require.NoError(u.T(), err)
   506  
   507  	expectedViolation := &slashing.Violation{
   508  		Identity: u.senderID,
   509  		PeerID:   expectedSenderPeerID.String(),
   510  		MsgType:  "BlockProposal",
   511  		Channel:  channels.ConsensusCommittee,
   512  		Protocol: message.ProtocolUnicast,
   513  		Err:      message.ErrUnauthorizedUnicastOnChannel,
   514  	}
   515  
   516  	slashingViolationsConsumer.On(
   517  		"OnUnauthorizedUnicastOnChannel",
   518  		expectedViolation,
   519  	).Once().Run(func(args mockery.Arguments) {
   520  		close(u.waitCh)
   521  	})
   522  
   523  	overlay := mocknetwork.NewOverlay(u.T())
   524  	overlay.On("Identities").Maybe().Return(func() flow.IdentityList {
   525  		return u.providers[0].Identities(filter.Any)
   526  	})
   527  	overlay.On("Topology").Maybe().Return(func() flow.IdentityList {
   528  		return u.providers[0].Identities(filter.Any)
   529  	}, nil)
   530  	overlay.On("Identity", expectedSenderPeerID).Return(u.senderID, true)
   531  
   532  	// message will be rejected so assert overlay never receives it
   533  	defer overlay.AssertNotCalled(u.T(), "Receive", u.senderID.NodeID, mock.AnythingOfType("*message.Message"))
   534  
   535  	u.startMiddlewares(overlay)
   536  
   537  	channel := channels.ConsensusCommittee
   538  	require.NoError(u.T(), u.receiverMW.Subscribe(channel))
   539  	require.NoError(u.T(), u.senderMW.Subscribe(channel))
   540  
   541  	// messages.BlockProposal is not authorized to be sent via unicast over the ConsensusCommittee channel
   542  	payload := unittest.ProposalFixture()
   543  
   544  	msg, err := network.NewOutgoingScope(
   545  		flow.IdentifierList{u.receiverID.NodeID},
   546  		channel,
   547  		payload,
   548  		unittest.NetworkCodec().Encode,
   549  		network.ProtocolTypeUnicast)
   550  	require.NoError(u.T(), err)
   551  
   552  	// send message via unicast
   553  	err = u.senderMW.SendDirect(msg)
   554  	require.NoError(u.T(), err)
   555  
   556  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   557  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   558  }
   559  
   560  // TestUnicastAuthorization_ReceiverHasNoSubscription tests that messages sent via unicast are rejected on the receiver end if the receiver does not have a subscription
   561  // to the channel of the message.
   562  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_ReceiverHasNoSubscription() {
   563  	// setup mock slashing violations consumer and middlewares
   564  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   565  	u.setupMiddlewaresAndProviders(slashingViolationsConsumer)
   566  
   567  	expectedSenderPeerID, err := unittest.PeerIDFromFlowID(u.senderID)
   568  	require.NoError(u.T(), err)
   569  
   570  	expectedViolation := &slashing.Violation{
   571  		Identity: nil,
   572  		PeerID:   expectedSenderPeerID.String(),
   573  		MsgType:  message.TestMessage,
   574  		Channel:  channels.TestNetworkChannel,
   575  		Protocol: message.ProtocolUnicast,
   576  		Err:      middleware.ErrUnicastMsgWithoutSub,
   577  	}
   578  
   579  	slashingViolationsConsumer.On(
   580  		"OnUnauthorizedUnicastOnChannel",
   581  		expectedViolation,
   582  	).Once().Run(func(args mockery.Arguments) {
   583  		close(u.waitCh)
   584  	})
   585  
   586  	overlay := mocknetwork.NewOverlay(u.T())
   587  	overlay.On("Identities").Maybe().Return(func() flow.IdentityList {
   588  		return u.providers[0].Identities(filter.Any)
   589  	})
   590  	overlay.On("Topology").Maybe().Return(func() flow.IdentityList {
   591  		return u.providers[0].Identities(filter.Any)
   592  	}, nil)
   593  
   594  	// message will be rejected so assert overlay never receives it
   595  	defer overlay.AssertNotCalled(u.T(), "Receive", u.senderID.NodeID, mock.AnythingOfType("*message.Message"))
   596  
   597  	u.startMiddlewares(overlay)
   598  
   599  	channel := channels.TestNetworkChannel
   600  
   601  	msg, err := network.NewOutgoingScope(
   602  		flow.IdentifierList{u.receiverID.NodeID},
   603  		channel,
   604  		&libp2pmessage.TestMessage{
   605  			Text: "TestUnicastAuthorization_ReceiverHasNoSubscription",
   606  		},
   607  		unittest.NetworkCodec().Encode,
   608  		network.ProtocolTypeUnicast)
   609  	require.NoError(u.T(), err)
   610  
   611  	// send message via unicast
   612  	err = u.senderMW.SendDirect(msg)
   613  	require.NoError(u.T(), err)
   614  
   615  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   616  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   617  }
   618  
   619  // TestUnicastAuthorization_ReceiverHasSubscription tests that messages sent via unicast are processed on the receiver end if the receiver does have a subscription
   620  // to the channel of the message.
   621  func (u *UnicastAuthorizationTestSuite) TestUnicastAuthorization_ReceiverHasSubscription() {
   622  	// setup mock slashing violations consumer and middlewares
   623  	slashingViolationsConsumer := mocknetwork.NewViolationsConsumer(u.T())
   624  	u.setupMiddlewaresAndProviders(slashingViolationsConsumer)
   625  	channel := channels.RequestReceiptsByBlockID
   626  
   627  	msg, err := network.NewOutgoingScope(
   628  		flow.IdentifierList{u.receiverID.NodeID},
   629  		channel,
   630  		&messages.EntityRequest{},
   631  		unittest.NetworkCodec().Encode,
   632  		network.ProtocolTypeUnicast)
   633  	require.NoError(u.T(), err)
   634  
   635  	u.senderID.Role = flow.RoleConsensus
   636  	u.receiverID.Role = flow.RoleExecution
   637  
   638  	overlay := mocknetwork.NewOverlay(u.T())
   639  	overlay.On("Identities").Maybe().Return(func() flow.IdentityList {
   640  		return u.providers[0].Identities(filter.Any)
   641  	})
   642  	overlay.On("Topology").Maybe().Return(func() flow.IdentityList {
   643  		return u.providers[0].Identities(filter.Any)
   644  	}, nil)
   645  	overlay.On("Identity", mock.AnythingOfType("peer.ID")).Return(u.senderID, true)
   646  
   647  	// we should receive the message on our overlay, at this point close the waitCh
   648  	overlay.On("Receive", mockery.Anything).Return(nil).
   649  		Once().
   650  		Run(func(args mockery.Arguments) {
   651  			close(u.waitCh)
   652  
   653  			msg, ok := args[0].(*network.IncomingMessageScope)
   654  			require.True(u.T(), ok)
   655  
   656  			require.Equal(u.T(), channel, msg.Channel())                      // channel
   657  			require.Equal(u.T(), u.senderID.NodeID, msg.OriginId())           // sender id
   658  			require.Equal(u.T(), u.receiverID.NodeID, msg.TargetIDs()[0])     // target id
   659  			require.Equal(u.T(), network.ProtocolTypeUnicast, msg.Protocol()) // protocol
   660  		})
   661  
   662  	u.startMiddlewares(overlay)
   663  
   664  	require.NoError(u.T(), u.receiverMW.Subscribe(channel))
   665  	require.NoError(u.T(), u.senderMW.Subscribe(channel))
   666  
   667  	// send message via unicast
   668  	err = u.senderMW.SendDirect(msg)
   669  	require.NoError(u.T(), err)
   670  
   671  	// wait for slashing violations consumer mock to invoke run func and close ch if expected method call happens
   672  	unittest.RequireCloseBefore(u.T(), u.waitCh, u.channelCloseDuration, "could close ch on time")
   673  }