github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/network/validator/authorized_sender_validator_test.go (about)

     1  package validator
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/libp2p/go-libp2p/core/peer"
     8  	"github.com/rs/zerolog"
     9  	"github.com/stretchr/testify/mock"
    10  	"github.com/stretchr/testify/require"
    11  	"github.com/stretchr/testify/suite"
    12  
    13  	"github.com/onflow/flow-go/model/flow"
    14  	libp2pmessage "github.com/onflow/flow-go/model/libp2p/message"
    15  	"github.com/onflow/flow-go/model/messages"
    16  	"github.com/onflow/flow-go/module/metrics"
    17  	"github.com/onflow/flow-go/network"
    18  	"github.com/onflow/flow-go/network/alsp"
    19  	"github.com/onflow/flow-go/network/channels"
    20  	"github.com/onflow/flow-go/network/codec"
    21  	"github.com/onflow/flow-go/network/message"
    22  	"github.com/onflow/flow-go/network/mocknetwork"
    23  	"github.com/onflow/flow-go/network/p2p"
    24  	"github.com/onflow/flow-go/network/slashing"
    25  	"github.com/onflow/flow-go/utils/unittest"
    26  )
    27  
    28  type TestCase struct {
    29  	Identity    *flow.Identity
    30  	GetIdentity func(pid peer.ID) (*flow.Identity, bool)
    31  	Channel     channels.Channel
    32  	Message     interface{}
    33  	MessageCode codec.MessageCode
    34  	MessageStr  string
    35  	Protocols   message.Protocols
    36  }
    37  
    38  func TestIsAuthorizedSender(t *testing.T) {
    39  	suite.Run(t, new(TestAuthorizedSenderValidatorSuite))
    40  }
    41  
    42  type TestAuthorizedSenderValidatorSuite struct {
    43  	suite.Suite
    44  	authorizedSenderTestCases             []TestCase
    45  	unauthorizedSenderTestCases           []TestCase
    46  	unauthorizedMessageOnChannelTestCases []TestCase
    47  	unauthorizedUnicastOnChannel          []TestCase
    48  	authorizedUnicastOnChannel            []TestCase
    49  	log                                   zerolog.Logger
    50  	slashingViolationsConsumer            network.ViolationsConsumer
    51  	allMsgConfigs                         []message.MsgAuthConfig
    52  	codec                                 network.Codec
    53  }
    54  
    55  func (s *TestAuthorizedSenderValidatorSuite) SetupTest() {
    56  	s.allMsgConfigs = message.GetAllMessageAuthConfigs()
    57  	s.initializeAuthorizationTestCases()
    58  	s.initializeInvalidMessageOnChannelTestCases()
    59  	s.initializeUnicastOnChannelTestCases()
    60  	s.log = unittest.Logger()
    61  	s.codec = unittest.NetworkCodec()
    62  }
    63  
    64  // TestValidatorCallback_AuthorizedSender checks that AuthorizedSenderValidator.Validate does not return false positive
    65  // validation errors for all possible valid combinations (authorized sender role, message type).
    66  func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_AuthorizedSender() {
    67  	for _, c := range s.authorizedSenderTestCases {
    68  		str := fmt.Sprintf("role (%s) should be authorized to send message type (%s) on channel (%s)", c.Identity.Role, c.MessageStr, c.Channel)
    69  		s.Run(str, func() {
    70  			misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T())
    71  			defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport"))
    72  			violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer)
    73  			authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity)
    74  			validateUnicast := authorizedSenderValidator.Validate
    75  			validatePubsub := authorizedSenderValidator.PubSubMessageValidator(c.Channel)
    76  			pid, err := unittest.PeerIDFromFlowID(c.Identity)
    77  			require.NoError(s.T(), err)
    78  			switch {
    79  			// ensure according to the message auth config, if a message is authorized to be sent via unicast it
    80  			// is accepted.
    81  			case c.Protocols.Contains(message.ProtocolTypeUnicast):
    82  				msgType, err := validateUnicast(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast)
    83  				if c.Protocols.Contains(message.ProtocolTypeUnicast) {
    84  					require.NoError(s.T(), err)
    85  					require.Equal(s.T(), c.MessageStr, msgType)
    86  				}
    87  			// ensure according to the message auth config, if a message is authorized to be sent via pubsub it
    88  			// is accepted.
    89  			case c.Protocols.Contains(message.ProtocolTypePubSub):
    90  				payload, err := s.codec.Encode(c.Message)
    91  				require.NoError(s.T(), err)
    92  				m := &message.Message{
    93  					ChannelID: c.Channel.String(),
    94  					Payload:   payload,
    95  				}
    96  				pubsubResult := validatePubsub(pid, m)
    97  				require.Equal(s.T(), p2p.ValidationAccept, pubsubResult)
    98  			default:
    99  				s.T().Fatal("authconfig does not contain any protocols")
   100  			}
   101  		})
   102  	}
   103  
   104  	s.Run("test messages should be allowed to be sent via both protocols unicast/pubsub on test channel", func() {
   105  		identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(flow.RoleCollection))
   106  		misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T())
   107  		defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport"))
   108  		violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer)
   109  		getIdentityFunc := s.getIdentity(identity)
   110  		pid, err := unittest.PeerIDFromFlowID(identity)
   111  		require.NoError(s.T(), err)
   112  		authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc)
   113  
   114  		msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeEcho.Uint8()}, channels.TestNetworkChannel, message.ProtocolTypeUnicast)
   115  		require.NoError(s.T(), err)
   116  		require.Equal(s.T(), "*message.TestMessage", msgType)
   117  
   118  		payload, err := s.codec.Encode(&libp2pmessage.TestMessage{})
   119  		require.NoError(s.T(), err)
   120  		m := &message.Message{
   121  			ChannelID: channels.TestNetworkChannel.String(),
   122  			Payload:   payload,
   123  		}
   124  		validatePubsub := authorizedSenderValidator.PubSubMessageValidator(channels.TestNetworkChannel)
   125  		pubsubResult := validatePubsub(pid, m)
   126  		require.Equal(s.T(), p2p.ValidationAccept, pubsubResult)
   127  	})
   128  }
   129  
   130  // TestValidatorCallback_UnAuthorizedSender checks that AuthorizedSenderValidator.Validate return's p2p.ValidationReject
   131  // validation error for all possible invalid combinations (unauthorized sender role, message type).
   132  func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnAuthorizedSender() {
   133  	for _, c := range s.unauthorizedSenderTestCases {
   134  		str := fmt.Sprintf("role (%s) should not be authorized to send message type (%s) on channel (%s)", c.Identity.Role, c.MessageStr, c.Channel)
   135  		s.Run(str, func() {
   136  			pid, err := unittest.PeerIDFromFlowID(c.Identity)
   137  			require.NoError(s.T(), err)
   138  			expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnAuthorizedSender)
   139  			require.NoError(s.T(), err)
   140  			misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T())
   141  			misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Once()
   142  			violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer)
   143  			authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity)
   144  
   145  			payload, err := s.codec.Encode(c.Message)
   146  			require.NoError(s.T(), err)
   147  			m := &message.Message{
   148  				ChannelID: c.Channel.String(),
   149  				Payload:   payload,
   150  			}
   151  			validatePubsub := authorizedSenderValidator.PubSubMessageValidator(c.Channel)
   152  			pubsubResult := validatePubsub(pid, m)
   153  			require.Equal(s.T(), p2p.ValidationReject, pubsubResult)
   154  		})
   155  	}
   156  }
   157  
   158  // TestValidatorCallback_AuthorizedUnicastOnChannel checks that AuthorizedSenderValidator.Validate does not return an error
   159  // for messages sent via unicast that are authorized to be sent via unicast.
   160  func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_AuthorizedUnicastOnChannel() {
   161  	for _, c := range s.authorizedUnicastOnChannel {
   162  		str := fmt.Sprintf("role (%s) should be authorized to send message type (%s) on channel (%s) via unicast", c.Identity.Role, c.MessageStr, c.Channel)
   163  		s.Run(str, func() {
   164  			pid, err := unittest.PeerIDFromFlowID(c.Identity)
   165  			require.NoError(s.T(), err)
   166  			misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T())
   167  			defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport"))
   168  			violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer)
   169  			authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity)
   170  
   171  			msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast)
   172  			require.NoError(s.T(), err)
   173  			require.Equal(s.T(), c.MessageStr, msgType)
   174  		})
   175  	}
   176  }
   177  
   178  // TestValidatorCallback_UnAuthorizedUnicastOnChannel checks that AuthorizedSenderValidator.Validate returns message.ErrUnauthorizedUnicastOnChannel
   179  // when a message not authorized to be sent via unicast is sent via unicast.
   180  func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnAuthorizedUnicastOnChannel() {
   181  	for _, c := range s.unauthorizedUnicastOnChannel {
   182  		str := fmt.Sprintf("role (%s) should not be authorized to send message type (%s) on channel (%s) via unicast", c.Identity.Role, c.MessageStr, c.Channel)
   183  		s.Run(str, func() {
   184  			pid, err := unittest.PeerIDFromFlowID(c.Identity)
   185  			require.NoError(s.T(), err)
   186  			expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnauthorizedUnicastOnChannel)
   187  			require.NoError(s.T(), err)
   188  			misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T())
   189  			misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Once()
   190  			violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer)
   191  			authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity)
   192  
   193  			msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast)
   194  			require.ErrorIs(s.T(), err, message.ErrUnauthorizedUnicastOnChannel)
   195  			require.Equal(s.T(), c.MessageStr, msgType)
   196  		})
   197  	}
   198  }
   199  
   200  // TestValidatorCallback_UnAuthorizedMessageOnChannel checks that for each invalid combination of message type and channel
   201  // AuthorizedSenderValidator.Validate returns the appropriate error message.ErrUnauthorizedMessageOnChannel.
   202  func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnAuthorizedMessageOnChannel() {
   203  	for _, c := range s.unauthorizedMessageOnChannelTestCases {
   204  		str := fmt.Sprintf("message type (%s) should not be authorized to be sent on channel (%s)", c.MessageStr, c.Channel)
   205  		s.Run(str, func() {
   206  			pid, err := unittest.PeerIDFromFlowID(c.Identity)
   207  			require.NoError(s.T(), err)
   208  			expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnAuthorizedSender)
   209  			require.NoError(s.T(), err)
   210  			misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T())
   211  			misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Twice()
   212  			violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer)
   213  			authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity)
   214  
   215  			msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast)
   216  			require.ErrorIs(s.T(), err, message.ErrUnauthorizedMessageOnChannel)
   217  			require.Equal(s.T(), c.MessageStr, msgType)
   218  
   219  			payload, err := s.codec.Encode(c.Message)
   220  			require.NoError(s.T(), err)
   221  			m := &message.Message{
   222  				ChannelID: c.Channel.String(),
   223  				Payload:   payload,
   224  			}
   225  			validatePubsub := authorizedSenderValidator.PubSubMessageValidator(c.Channel)
   226  			pubsubResult := validatePubsub(pid, m)
   227  			require.Equal(s.T(), p2p.ValidationReject, pubsubResult)
   228  		})
   229  	}
   230  }
   231  
   232  // TestValidatorCallback_ClusterPrefixedChannels checks that AuthorizedSenderValidator.Validate correctly
   233  // handles cluster prefixed channels during validation.
   234  func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_ClusterPrefixedChannels() {
   235  	identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(flow.RoleCollection))
   236  	clusterID := flow.Localnet
   237  
   238  	getIdentityFunc := s.getIdentity(identity)
   239  	pid, err := unittest.PeerIDFromFlowID(identity)
   240  	require.NoError(s.T(), err)
   241  
   242  	expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(identity.NodeID, alsp.UnauthorizedUnicastOnChannel)
   243  	require.NoError(s.T(), err)
   244  	misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T())
   245  	misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.SyncCluster(clusterID), expectedMisbehaviorReport).Once()
   246  	misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.ConsensusCluster(clusterID), expectedMisbehaviorReport).Once()
   247  
   248  	violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer)
   249  	authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc)
   250  
   251  	// validate collection sync cluster SyncRequest is not allowed to be sent on channel via unicast
   252  	msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeSyncRequest.Uint8()}, channels.SyncCluster(clusterID), message.ProtocolTypeUnicast)
   253  	require.ErrorIs(s.T(), err, message.ErrUnauthorizedUnicastOnChannel)
   254  	require.Equal(s.T(), "*messages.SyncRequest", msgType)
   255  
   256  	// ensure ClusterBlockProposal not allowed to be sent on channel via unicast
   257  	msgType, err = authorizedSenderValidator.Validate(pid, []byte{codec.CodeClusterBlockProposal.Uint8()}, channels.ConsensusCluster(clusterID), message.ProtocolTypeUnicast)
   258  	require.ErrorIs(s.T(), err, message.ErrUnauthorizedUnicastOnChannel)
   259  	require.Equal(s.T(), "*messages.ClusterBlockProposal", msgType)
   260  
   261  	// ensure ClusterBlockProposal is allowed to be sent via pubsub by authorized sender
   262  	payload, err := s.codec.Encode(&messages.ClusterBlockProposal{})
   263  	require.NoError(s.T(), err)
   264  	m := &message.Message{
   265  		ChannelID: channels.ConsensusCluster(clusterID).String(),
   266  		Payload:   payload,
   267  	}
   268  	validateCollConsensusPubsub := authorizedSenderValidator.PubSubMessageValidator(channels.ConsensusCluster(clusterID))
   269  	pubsubResult := validateCollConsensusPubsub(pid, m)
   270  	require.Equal(s.T(), p2p.ValidationAccept, pubsubResult)
   271  
   272  	// ensure SyncRequest is allowed to be sent via pubsub by authorized sender
   273  	payload, err = s.codec.Encode(&messages.SyncRequest{})
   274  	require.NoError(s.T(), err)
   275  	m = &message.Message{
   276  		ChannelID: channels.SyncCluster(clusterID).String(),
   277  		Payload:   payload,
   278  	}
   279  	validateSyncClusterPubsub := authorizedSenderValidator.PubSubMessageValidator(channels.SyncCluster(clusterID))
   280  	pubsubResult = validateSyncClusterPubsub(pid, m)
   281  	require.Equal(s.T(), p2p.ValidationAccept, pubsubResult)
   282  }
   283  
   284  // TestValidatorCallback_ValidationFailure checks that AuthorizedSenderValidator.Validate returns the expected validation error.
   285  func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_ValidationFailure() {
   286  	s.Run("sender is ejected", func() {
   287  		identity, _ := unittest.IdentityWithNetworkingKeyFixture()
   288  		identity.EpochParticipationStatus = flow.EpochParticipationStatusEjected
   289  		getIdentityFunc := s.getIdentity(identity)
   290  		pid, err := unittest.PeerIDFromFlowID(identity)
   291  		require.NoError(s.T(), err)
   292  
   293  		expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(identity.NodeID, alsp.SenderEjected)
   294  		require.NoError(s.T(), err)
   295  		misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T())
   296  		misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.SyncCommittee, expectedMisbehaviorReport).Twice()
   297  		violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer)
   298  		authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc)
   299  
   300  		msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeSyncRequest.Uint8()}, channels.SyncCommittee, message.ProtocolTypeUnicast)
   301  		require.ErrorIs(s.T(), err, ErrSenderEjected)
   302  		require.Equal(s.T(), "", msgType)
   303  
   304  		payload, err := s.codec.Encode(&messages.SyncRequest{})
   305  		require.NoError(s.T(), err)
   306  		m := &message.Message{
   307  			ChannelID: channels.SyncCommittee.String(),
   308  			Payload:   payload,
   309  		}
   310  		validatePubsub := authorizedSenderValidator.PubSubMessageValidator(channels.SyncCommittee)
   311  		pubsubResult := validatePubsub(pid, m)
   312  		require.Equal(s.T(), p2p.ValidationReject, pubsubResult)
   313  	})
   314  
   315  	s.Run("unknown message code", func() {
   316  		identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(flow.RoleConsensus))
   317  
   318  		getIdentityFunc := s.getIdentity(identity)
   319  		pid, err := unittest.PeerIDFromFlowID(identity)
   320  		require.NoError(s.T(), err)
   321  
   322  		expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(identity.NodeID, alsp.UnknownMsgType)
   323  		require.NoError(s.T(), err)
   324  		misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T())
   325  		misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.ConsensusCommittee, expectedMisbehaviorReport).Twice()
   326  		violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer)
   327  		authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc)
   328  		validatePubsub := authorizedSenderValidator.PubSubMessageValidator(channels.ConsensusCommittee)
   329  
   330  		// unknown message types are rejected
   331  		msgType, err := authorizedSenderValidator.Validate(pid, []byte{'x'}, channels.ConsensusCommittee, message.ProtocolTypeUnicast)
   332  		require.True(s.T(), codec.IsErrUnknownMsgCode(err))
   333  		require.Equal(s.T(), "", msgType)
   334  
   335  		payload, err := s.codec.Encode(&messages.BlockProposal{})
   336  		require.NoError(s.T(), err)
   337  		payload[0] = byte('x')
   338  		netMsg := &message.Message{
   339  			ChannelID: channels.ConsensusCommittee.String(),
   340  			Payload:   payload,
   341  		}
   342  		pubsubResult := validatePubsub(pid, netMsg)
   343  		require.Equal(s.T(), p2p.ValidationReject, pubsubResult)
   344  	})
   345  
   346  	s.Run("sender is not staked getIdentityFunc does not return identity ", func() {
   347  		identity, _ := unittest.IdentityWithNetworkingKeyFixture()
   348  
   349  		// getIdentityFunc simulates unstaked node not found in participant list
   350  		getIdentityFunc := func(id peer.ID) (*flow.Identity, bool) { return nil, false }
   351  
   352  		pid, err := unittest.PeerIDFromFlowID(identity)
   353  		require.NoError(s.T(), err)
   354  
   355  		misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T())
   356  		// we cannot penalize a peer if identity is not known, in this case we don't expect any misbehavior reports to be reported
   357  		defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport"))
   358  		violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer)
   359  		authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc)
   360  
   361  		msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeSyncRequest.Uint8()}, channels.SyncCommittee, message.ProtocolTypeUnicast)
   362  		require.ErrorIs(s.T(), err, ErrIdentityUnverified)
   363  		require.Equal(s.T(), "", msgType)
   364  
   365  		payload, err := s.codec.Encode(&messages.SyncRequest{})
   366  		require.NoError(s.T(), err)
   367  		m := &message.Message{
   368  			ChannelID: channels.SyncCommittee.String(),
   369  			Payload:   payload,
   370  		}
   371  		validatePubsub := authorizedSenderValidator.PubSubMessageValidator(channels.SyncCommittee)
   372  		pubsubResult := validatePubsub(pid, m)
   373  		require.Equal(s.T(), p2p.ValidationReject, pubsubResult)
   374  	})
   375  }
   376  
   377  // TestValidatorCallback_ValidationFailure checks that AuthorizedSenderValidator returns the expected validation error when a unicast-only message is published.
   378  func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnauthorizedPublishOnChannel() {
   379  	for _, c := range s.authorizedUnicastOnChannel {
   380  		str := fmt.Sprintf("message type (%s) is not authorized to be sent via libp2p publish", c.MessageStr)
   381  		s.Run(str, func() {
   382  			// skip test message check
   383  			if c.MessageStr == "*message.TestMessage" {
   384  				return
   385  			}
   386  			pid, err := unittest.PeerIDFromFlowID(c.Identity)
   387  			require.NoError(s.T(), err)
   388  			expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnauthorizedPublishOnChannel)
   389  			require.NoError(s.T(), err)
   390  			misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T())
   391  			misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Once()
   392  			violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer)
   393  			authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity)
   394  			msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypePubSub)
   395  			require.ErrorIs(s.T(), err, message.ErrUnauthorizedPublishOnChannel)
   396  			require.Equal(s.T(), c.MessageStr, msgType)
   397  		})
   398  	}
   399  }
   400  
   401  // initializeAuthorizationTestCases initializes happy and sad path test cases for checking authorized and unauthorized role message combinations.
   402  func (s *TestAuthorizedSenderValidatorSuite) initializeAuthorizationTestCases() {
   403  	for _, c := range s.allMsgConfigs {
   404  		for channel, channelAuthConfig := range c.Config {
   405  			for _, role := range flow.Roles() {
   406  				identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(role))
   407  				code, what, err := codec.MessageCodeFromInterface(c.Type())
   408  				require.NoError(s.T(), err)
   409  				tc := TestCase{
   410  					Identity:    identity,
   411  					GetIdentity: s.getIdentity(identity),
   412  					Channel:     channel,
   413  					Message:     c.Type(),
   414  					MessageCode: code,
   415  					MessageStr:  what,
   416  					Protocols:   channelAuthConfig.AllowedProtocols,
   417  				}
   418  				if channelAuthConfig.AuthorizedRoles.Contains(role) {
   419  					// test cases for validation success happy path
   420  					s.authorizedSenderTestCases = append(s.authorizedSenderTestCases, tc)
   421  				} else {
   422  					// test cases for validation unsuccessful sad path
   423  					s.unauthorizedSenderTestCases = append(s.unauthorizedSenderTestCases, tc)
   424  				}
   425  			}
   426  		}
   427  	}
   428  }
   429  
   430  // initializeInvalidMessageOnChannelTestCases initializes test cases for all possible combinations of invalid message types on channel.
   431  // NOTE: the role in the test case does not matter since ErrUnauthorizedMessageOnChannel will be returned before the role is checked.
   432  func (s *TestAuthorizedSenderValidatorSuite) initializeInvalidMessageOnChannelTestCases() {
   433  	// iterate all channels
   434  	for _, c := range s.allMsgConfigs {
   435  		for channel, channelAuthConfig := range c.Config {
   436  			identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(channelAuthConfig.AuthorizedRoles[0]))
   437  
   438  			// iterate all message types
   439  			for _, config := range s.allMsgConfigs {
   440  				// include test if message type is not authorized on channel
   441  				_, ok := config.Config[channel]
   442  				code, what, err := codec.MessageCodeFromInterface(config.Type())
   443  				require.NoError(s.T(), err)
   444  				if config.Name != c.Name && !ok {
   445  					tc := TestCase{
   446  						Identity:    identity,
   447  						GetIdentity: s.getIdentity(identity),
   448  						Channel:     channel,
   449  						Message:     config.Type(),
   450  						MessageCode: code,
   451  						MessageStr:  what,
   452  						Protocols:   channelAuthConfig.AllowedProtocols,
   453  					}
   454  					s.unauthorizedMessageOnChannelTestCases = append(s.unauthorizedMessageOnChannelTestCases, tc)
   455  				}
   456  			}
   457  		}
   458  	}
   459  }
   460  
   461  // initializeUnicastOnChannelTestCases initializes happy and sad path test cases for unicast on channel message combinations.
   462  func (s *TestAuthorizedSenderValidatorSuite) initializeUnicastOnChannelTestCases() {
   463  	for _, c := range s.allMsgConfigs {
   464  		for channel, channelAuthConfig := range c.Config {
   465  			identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(channelAuthConfig.AuthorizedRoles[0]))
   466  			code, what, err := codec.MessageCodeFromInterface(c.Type())
   467  			require.NoError(s.T(), err)
   468  			tc := TestCase{
   469  				Identity:    identity,
   470  				GetIdentity: s.getIdentity(identity),
   471  				Channel:     channel,
   472  				Message:     c.Type(),
   473  				MessageCode: code,
   474  				MessageStr:  what,
   475  				Protocols:   channelAuthConfig.AllowedProtocols,
   476  			}
   477  			if channelAuthConfig.AllowedProtocols.Contains(message.ProtocolTypeUnicast) {
   478  				s.authorizedUnicastOnChannel = append(s.authorizedUnicastOnChannel, tc)
   479  			} else {
   480  				s.unauthorizedUnicastOnChannel = append(s.unauthorizedUnicastOnChannel, tc)
   481  			}
   482  		}
   483  	}
   484  }
   485  
   486  // getIdentity returns a callback that simply returns the provided identity.
   487  func (s *TestAuthorizedSenderValidatorSuite) getIdentity(id *flow.Identity) func(pid peer.ID) (*flow.Identity, bool) {
   488  	return func(pid peer.ID) (*flow.Identity, bool) {
   489  		return id, true
   490  	}
   491  }