github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/network/validator/authorized_sender_validator_test.go (about) 1 package validator 2 3 import ( 4 "fmt" 5 "testing" 6 7 "github.com/libp2p/go-libp2p/core/peer" 8 "github.com/rs/zerolog" 9 "github.com/stretchr/testify/mock" 10 "github.com/stretchr/testify/require" 11 "github.com/stretchr/testify/suite" 12 13 "github.com/onflow/flow-go/model/flow" 14 libp2pmessage "github.com/onflow/flow-go/model/libp2p/message" 15 "github.com/onflow/flow-go/model/messages" 16 "github.com/onflow/flow-go/module/metrics" 17 "github.com/onflow/flow-go/network" 18 "github.com/onflow/flow-go/network/alsp" 19 "github.com/onflow/flow-go/network/channels" 20 "github.com/onflow/flow-go/network/codec" 21 "github.com/onflow/flow-go/network/message" 22 "github.com/onflow/flow-go/network/mocknetwork" 23 "github.com/onflow/flow-go/network/p2p" 24 "github.com/onflow/flow-go/network/slashing" 25 "github.com/onflow/flow-go/utils/unittest" 26 ) 27 28 type TestCase struct { 29 Identity *flow.Identity 30 GetIdentity func(pid peer.ID) (*flow.Identity, bool) 31 Channel channels.Channel 32 Message interface{} 33 MessageCode codec.MessageCode 34 MessageStr string 35 Protocols message.Protocols 36 } 37 38 func TestIsAuthorizedSender(t *testing.T) { 39 suite.Run(t, new(TestAuthorizedSenderValidatorSuite)) 40 } 41 42 type TestAuthorizedSenderValidatorSuite struct { 43 suite.Suite 44 authorizedSenderTestCases []TestCase 45 unauthorizedSenderTestCases []TestCase 46 unauthorizedMessageOnChannelTestCases []TestCase 47 unauthorizedUnicastOnChannel []TestCase 48 authorizedUnicastOnChannel []TestCase 49 log zerolog.Logger 50 slashingViolationsConsumer network.ViolationsConsumer 51 allMsgConfigs []message.MsgAuthConfig 52 codec network.Codec 53 } 54 55 func (s *TestAuthorizedSenderValidatorSuite) SetupTest() { 56 s.allMsgConfigs = message.GetAllMessageAuthConfigs() 57 s.initializeAuthorizationTestCases() 58 s.initializeInvalidMessageOnChannelTestCases() 59 s.initializeUnicastOnChannelTestCases() 60 s.log = unittest.Logger() 61 s.codec = unittest.NetworkCodec() 62 } 63 64 // TestValidatorCallback_AuthorizedSender checks that AuthorizedSenderValidator.Validate does not return false positive 65 // validation errors for all possible valid combinations (authorized sender role, message type). 66 func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_AuthorizedSender() { 67 for _, c := range s.authorizedSenderTestCases { 68 str := fmt.Sprintf("role (%s) should be authorized to send message type (%s) on channel (%s)", c.Identity.Role, c.MessageStr, c.Channel) 69 s.Run(str, func() { 70 misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) 71 defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport")) 72 violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) 73 authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) 74 validateUnicast := authorizedSenderValidator.Validate 75 validatePubsub := authorizedSenderValidator.PubSubMessageValidator(c.Channel) 76 pid, err := unittest.PeerIDFromFlowID(c.Identity) 77 require.NoError(s.T(), err) 78 switch { 79 // ensure according to the message auth config, if a message is authorized to be sent via unicast it 80 // is accepted. 81 case c.Protocols.Contains(message.ProtocolTypeUnicast): 82 msgType, err := validateUnicast(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast) 83 if c.Protocols.Contains(message.ProtocolTypeUnicast) { 84 require.NoError(s.T(), err) 85 require.Equal(s.T(), c.MessageStr, msgType) 86 } 87 // ensure according to the message auth config, if a message is authorized to be sent via pubsub it 88 // is accepted. 89 case c.Protocols.Contains(message.ProtocolTypePubSub): 90 payload, err := s.codec.Encode(c.Message) 91 require.NoError(s.T(), err) 92 m := &message.Message{ 93 ChannelID: c.Channel.String(), 94 Payload: payload, 95 } 96 pubsubResult := validatePubsub(pid, m) 97 require.Equal(s.T(), p2p.ValidationAccept, pubsubResult) 98 default: 99 s.T().Fatal("authconfig does not contain any protocols") 100 } 101 }) 102 } 103 104 s.Run("test messages should be allowed to be sent via both protocols unicast/pubsub on test channel", func() { 105 identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(flow.RoleCollection)) 106 misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) 107 defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport")) 108 violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) 109 getIdentityFunc := s.getIdentity(identity) 110 pid, err := unittest.PeerIDFromFlowID(identity) 111 require.NoError(s.T(), err) 112 authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc) 113 114 msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeEcho.Uint8()}, channels.TestNetworkChannel, message.ProtocolTypeUnicast) 115 require.NoError(s.T(), err) 116 require.Equal(s.T(), "*message.TestMessage", msgType) 117 118 payload, err := s.codec.Encode(&libp2pmessage.TestMessage{}) 119 require.NoError(s.T(), err) 120 m := &message.Message{ 121 ChannelID: channels.TestNetworkChannel.String(), 122 Payload: payload, 123 } 124 validatePubsub := authorizedSenderValidator.PubSubMessageValidator(channels.TestNetworkChannel) 125 pubsubResult := validatePubsub(pid, m) 126 require.Equal(s.T(), p2p.ValidationAccept, pubsubResult) 127 }) 128 } 129 130 // TestValidatorCallback_UnAuthorizedSender checks that AuthorizedSenderValidator.Validate return's p2p.ValidationReject 131 // validation error for all possible invalid combinations (unauthorized sender role, message type). 132 func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnAuthorizedSender() { 133 for _, c := range s.unauthorizedSenderTestCases { 134 str := fmt.Sprintf("role (%s) should not be authorized to send message type (%s) on channel (%s)", c.Identity.Role, c.MessageStr, c.Channel) 135 s.Run(str, func() { 136 pid, err := unittest.PeerIDFromFlowID(c.Identity) 137 require.NoError(s.T(), err) 138 expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnAuthorizedSender) 139 require.NoError(s.T(), err) 140 misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) 141 misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Once() 142 violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) 143 authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) 144 145 payload, err := s.codec.Encode(c.Message) 146 require.NoError(s.T(), err) 147 m := &message.Message{ 148 ChannelID: c.Channel.String(), 149 Payload: payload, 150 } 151 validatePubsub := authorizedSenderValidator.PubSubMessageValidator(c.Channel) 152 pubsubResult := validatePubsub(pid, m) 153 require.Equal(s.T(), p2p.ValidationReject, pubsubResult) 154 }) 155 } 156 } 157 158 // TestValidatorCallback_AuthorizedUnicastOnChannel checks that AuthorizedSenderValidator.Validate does not return an error 159 // for messages sent via unicast that are authorized to be sent via unicast. 160 func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_AuthorizedUnicastOnChannel() { 161 for _, c := range s.authorizedUnicastOnChannel { 162 str := fmt.Sprintf("role (%s) should be authorized to send message type (%s) on channel (%s) via unicast", c.Identity.Role, c.MessageStr, c.Channel) 163 s.Run(str, func() { 164 pid, err := unittest.PeerIDFromFlowID(c.Identity) 165 require.NoError(s.T(), err) 166 misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) 167 defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport")) 168 violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) 169 authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) 170 171 msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast) 172 require.NoError(s.T(), err) 173 require.Equal(s.T(), c.MessageStr, msgType) 174 }) 175 } 176 } 177 178 // TestValidatorCallback_UnAuthorizedUnicastOnChannel checks that AuthorizedSenderValidator.Validate returns message.ErrUnauthorizedUnicastOnChannel 179 // when a message not authorized to be sent via unicast is sent via unicast. 180 func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnAuthorizedUnicastOnChannel() { 181 for _, c := range s.unauthorizedUnicastOnChannel { 182 str := fmt.Sprintf("role (%s) should not be authorized to send message type (%s) on channel (%s) via unicast", c.Identity.Role, c.MessageStr, c.Channel) 183 s.Run(str, func() { 184 pid, err := unittest.PeerIDFromFlowID(c.Identity) 185 require.NoError(s.T(), err) 186 expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnauthorizedUnicastOnChannel) 187 require.NoError(s.T(), err) 188 misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) 189 misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Once() 190 violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) 191 authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) 192 193 msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast) 194 require.ErrorIs(s.T(), err, message.ErrUnauthorizedUnicastOnChannel) 195 require.Equal(s.T(), c.MessageStr, msgType) 196 }) 197 } 198 } 199 200 // TestValidatorCallback_UnAuthorizedMessageOnChannel checks that for each invalid combination of message type and channel 201 // AuthorizedSenderValidator.Validate returns the appropriate error message.ErrUnauthorizedMessageOnChannel. 202 func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnAuthorizedMessageOnChannel() { 203 for _, c := range s.unauthorizedMessageOnChannelTestCases { 204 str := fmt.Sprintf("message type (%s) should not be authorized to be sent on channel (%s)", c.MessageStr, c.Channel) 205 s.Run(str, func() { 206 pid, err := unittest.PeerIDFromFlowID(c.Identity) 207 require.NoError(s.T(), err) 208 expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnAuthorizedSender) 209 require.NoError(s.T(), err) 210 misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) 211 misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Twice() 212 violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) 213 authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) 214 215 msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypeUnicast) 216 require.ErrorIs(s.T(), err, message.ErrUnauthorizedMessageOnChannel) 217 require.Equal(s.T(), c.MessageStr, msgType) 218 219 payload, err := s.codec.Encode(c.Message) 220 require.NoError(s.T(), err) 221 m := &message.Message{ 222 ChannelID: c.Channel.String(), 223 Payload: payload, 224 } 225 validatePubsub := authorizedSenderValidator.PubSubMessageValidator(c.Channel) 226 pubsubResult := validatePubsub(pid, m) 227 require.Equal(s.T(), p2p.ValidationReject, pubsubResult) 228 }) 229 } 230 } 231 232 // TestValidatorCallback_ClusterPrefixedChannels checks that AuthorizedSenderValidator.Validate correctly 233 // handles cluster prefixed channels during validation. 234 func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_ClusterPrefixedChannels() { 235 identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(flow.RoleCollection)) 236 clusterID := flow.Localnet 237 238 getIdentityFunc := s.getIdentity(identity) 239 pid, err := unittest.PeerIDFromFlowID(identity) 240 require.NoError(s.T(), err) 241 242 expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(identity.NodeID, alsp.UnauthorizedUnicastOnChannel) 243 require.NoError(s.T(), err) 244 misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) 245 misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.SyncCluster(clusterID), expectedMisbehaviorReport).Once() 246 misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.ConsensusCluster(clusterID), expectedMisbehaviorReport).Once() 247 248 violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) 249 authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc) 250 251 // validate collection sync cluster SyncRequest is not allowed to be sent on channel via unicast 252 msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeSyncRequest.Uint8()}, channels.SyncCluster(clusterID), message.ProtocolTypeUnicast) 253 require.ErrorIs(s.T(), err, message.ErrUnauthorizedUnicastOnChannel) 254 require.Equal(s.T(), "*messages.SyncRequest", msgType) 255 256 // ensure ClusterBlockProposal not allowed to be sent on channel via unicast 257 msgType, err = authorizedSenderValidator.Validate(pid, []byte{codec.CodeClusterBlockProposal.Uint8()}, channels.ConsensusCluster(clusterID), message.ProtocolTypeUnicast) 258 require.ErrorIs(s.T(), err, message.ErrUnauthorizedUnicastOnChannel) 259 require.Equal(s.T(), "*messages.ClusterBlockProposal", msgType) 260 261 // ensure ClusterBlockProposal is allowed to be sent via pubsub by authorized sender 262 payload, err := s.codec.Encode(&messages.ClusterBlockProposal{}) 263 require.NoError(s.T(), err) 264 m := &message.Message{ 265 ChannelID: channels.ConsensusCluster(clusterID).String(), 266 Payload: payload, 267 } 268 validateCollConsensusPubsub := authorizedSenderValidator.PubSubMessageValidator(channels.ConsensusCluster(clusterID)) 269 pubsubResult := validateCollConsensusPubsub(pid, m) 270 require.Equal(s.T(), p2p.ValidationAccept, pubsubResult) 271 272 // ensure SyncRequest is allowed to be sent via pubsub by authorized sender 273 payload, err = s.codec.Encode(&messages.SyncRequest{}) 274 require.NoError(s.T(), err) 275 m = &message.Message{ 276 ChannelID: channels.SyncCluster(clusterID).String(), 277 Payload: payload, 278 } 279 validateSyncClusterPubsub := authorizedSenderValidator.PubSubMessageValidator(channels.SyncCluster(clusterID)) 280 pubsubResult = validateSyncClusterPubsub(pid, m) 281 require.Equal(s.T(), p2p.ValidationAccept, pubsubResult) 282 } 283 284 // TestValidatorCallback_ValidationFailure checks that AuthorizedSenderValidator.Validate returns the expected validation error. 285 func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_ValidationFailure() { 286 s.Run("sender is ejected", func() { 287 identity, _ := unittest.IdentityWithNetworkingKeyFixture() 288 identity.EpochParticipationStatus = flow.EpochParticipationStatusEjected 289 getIdentityFunc := s.getIdentity(identity) 290 pid, err := unittest.PeerIDFromFlowID(identity) 291 require.NoError(s.T(), err) 292 293 expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(identity.NodeID, alsp.SenderEjected) 294 require.NoError(s.T(), err) 295 misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) 296 misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.SyncCommittee, expectedMisbehaviorReport).Twice() 297 violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) 298 authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc) 299 300 msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeSyncRequest.Uint8()}, channels.SyncCommittee, message.ProtocolTypeUnicast) 301 require.ErrorIs(s.T(), err, ErrSenderEjected) 302 require.Equal(s.T(), "", msgType) 303 304 payload, err := s.codec.Encode(&messages.SyncRequest{}) 305 require.NoError(s.T(), err) 306 m := &message.Message{ 307 ChannelID: channels.SyncCommittee.String(), 308 Payload: payload, 309 } 310 validatePubsub := authorizedSenderValidator.PubSubMessageValidator(channels.SyncCommittee) 311 pubsubResult := validatePubsub(pid, m) 312 require.Equal(s.T(), p2p.ValidationReject, pubsubResult) 313 }) 314 315 s.Run("unknown message code", func() { 316 identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(flow.RoleConsensus)) 317 318 getIdentityFunc := s.getIdentity(identity) 319 pid, err := unittest.PeerIDFromFlowID(identity) 320 require.NoError(s.T(), err) 321 322 expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(identity.NodeID, alsp.UnknownMsgType) 323 require.NoError(s.T(), err) 324 misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) 325 misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", channels.ConsensusCommittee, expectedMisbehaviorReport).Twice() 326 violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) 327 authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc) 328 validatePubsub := authorizedSenderValidator.PubSubMessageValidator(channels.ConsensusCommittee) 329 330 // unknown message types are rejected 331 msgType, err := authorizedSenderValidator.Validate(pid, []byte{'x'}, channels.ConsensusCommittee, message.ProtocolTypeUnicast) 332 require.True(s.T(), codec.IsErrUnknownMsgCode(err)) 333 require.Equal(s.T(), "", msgType) 334 335 payload, err := s.codec.Encode(&messages.BlockProposal{}) 336 require.NoError(s.T(), err) 337 payload[0] = byte('x') 338 netMsg := &message.Message{ 339 ChannelID: channels.ConsensusCommittee.String(), 340 Payload: payload, 341 } 342 pubsubResult := validatePubsub(pid, netMsg) 343 require.Equal(s.T(), p2p.ValidationReject, pubsubResult) 344 }) 345 346 s.Run("sender is not staked getIdentityFunc does not return identity ", func() { 347 identity, _ := unittest.IdentityWithNetworkingKeyFixture() 348 349 // getIdentityFunc simulates unstaked node not found in participant list 350 getIdentityFunc := func(id peer.ID) (*flow.Identity, bool) { return nil, false } 351 352 pid, err := unittest.PeerIDFromFlowID(identity) 353 require.NoError(s.T(), err) 354 355 misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) 356 // we cannot penalize a peer if identity is not known, in this case we don't expect any misbehavior reports to be reported 357 defer misbehaviorReportConsumer.AssertNotCalled(s.T(), "ReportMisbehaviorOnChannel", mock.AnythingOfType("channels.Channel"), mock.AnythingOfType("*alsp.MisbehaviorReport")) 358 violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) 359 authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, getIdentityFunc) 360 361 msgType, err := authorizedSenderValidator.Validate(pid, []byte{codec.CodeSyncRequest.Uint8()}, channels.SyncCommittee, message.ProtocolTypeUnicast) 362 require.ErrorIs(s.T(), err, ErrIdentityUnverified) 363 require.Equal(s.T(), "", msgType) 364 365 payload, err := s.codec.Encode(&messages.SyncRequest{}) 366 require.NoError(s.T(), err) 367 m := &message.Message{ 368 ChannelID: channels.SyncCommittee.String(), 369 Payload: payload, 370 } 371 validatePubsub := authorizedSenderValidator.PubSubMessageValidator(channels.SyncCommittee) 372 pubsubResult := validatePubsub(pid, m) 373 require.Equal(s.T(), p2p.ValidationReject, pubsubResult) 374 }) 375 } 376 377 // TestValidatorCallback_ValidationFailure checks that AuthorizedSenderValidator returns the expected validation error when a unicast-only message is published. 378 func (s *TestAuthorizedSenderValidatorSuite) TestValidatorCallback_UnauthorizedPublishOnChannel() { 379 for _, c := range s.authorizedUnicastOnChannel { 380 str := fmt.Sprintf("message type (%s) is not authorized to be sent via libp2p publish", c.MessageStr) 381 s.Run(str, func() { 382 // skip test message check 383 if c.MessageStr == "*message.TestMessage" { 384 return 385 } 386 pid, err := unittest.PeerIDFromFlowID(c.Identity) 387 require.NoError(s.T(), err) 388 expectedMisbehaviorReport, err := alsp.NewMisbehaviorReport(c.Identity.NodeID, alsp.UnauthorizedPublishOnChannel) 389 require.NoError(s.T(), err) 390 misbehaviorReportConsumer := mocknetwork.NewMisbehaviorReportConsumer(s.T()) 391 misbehaviorReportConsumer.On("ReportMisbehaviorOnChannel", c.Channel, expectedMisbehaviorReport).Once() 392 violationsConsumer := slashing.NewSlashingViolationsConsumer(s.log, metrics.NewNoopCollector(), misbehaviorReportConsumer) 393 authorizedSenderValidator := NewAuthorizedSenderValidator(s.log, violationsConsumer, c.GetIdentity) 394 msgType, err := authorizedSenderValidator.Validate(pid, []byte{c.MessageCode.Uint8()}, c.Channel, message.ProtocolTypePubSub) 395 require.ErrorIs(s.T(), err, message.ErrUnauthorizedPublishOnChannel) 396 require.Equal(s.T(), c.MessageStr, msgType) 397 }) 398 } 399 } 400 401 // initializeAuthorizationTestCases initializes happy and sad path test cases for checking authorized and unauthorized role message combinations. 402 func (s *TestAuthorizedSenderValidatorSuite) initializeAuthorizationTestCases() { 403 for _, c := range s.allMsgConfigs { 404 for channel, channelAuthConfig := range c.Config { 405 for _, role := range flow.Roles() { 406 identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(role)) 407 code, what, err := codec.MessageCodeFromInterface(c.Type()) 408 require.NoError(s.T(), err) 409 tc := TestCase{ 410 Identity: identity, 411 GetIdentity: s.getIdentity(identity), 412 Channel: channel, 413 Message: c.Type(), 414 MessageCode: code, 415 MessageStr: what, 416 Protocols: channelAuthConfig.AllowedProtocols, 417 } 418 if channelAuthConfig.AuthorizedRoles.Contains(role) { 419 // test cases for validation success happy path 420 s.authorizedSenderTestCases = append(s.authorizedSenderTestCases, tc) 421 } else { 422 // test cases for validation unsuccessful sad path 423 s.unauthorizedSenderTestCases = append(s.unauthorizedSenderTestCases, tc) 424 } 425 } 426 } 427 } 428 } 429 430 // initializeInvalidMessageOnChannelTestCases initializes test cases for all possible combinations of invalid message types on channel. 431 // NOTE: the role in the test case does not matter since ErrUnauthorizedMessageOnChannel will be returned before the role is checked. 432 func (s *TestAuthorizedSenderValidatorSuite) initializeInvalidMessageOnChannelTestCases() { 433 // iterate all channels 434 for _, c := range s.allMsgConfigs { 435 for channel, channelAuthConfig := range c.Config { 436 identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(channelAuthConfig.AuthorizedRoles[0])) 437 438 // iterate all message types 439 for _, config := range s.allMsgConfigs { 440 // include test if message type is not authorized on channel 441 _, ok := config.Config[channel] 442 code, what, err := codec.MessageCodeFromInterface(config.Type()) 443 require.NoError(s.T(), err) 444 if config.Name != c.Name && !ok { 445 tc := TestCase{ 446 Identity: identity, 447 GetIdentity: s.getIdentity(identity), 448 Channel: channel, 449 Message: config.Type(), 450 MessageCode: code, 451 MessageStr: what, 452 Protocols: channelAuthConfig.AllowedProtocols, 453 } 454 s.unauthorizedMessageOnChannelTestCases = append(s.unauthorizedMessageOnChannelTestCases, tc) 455 } 456 } 457 } 458 } 459 } 460 461 // initializeUnicastOnChannelTestCases initializes happy and sad path test cases for unicast on channel message combinations. 462 func (s *TestAuthorizedSenderValidatorSuite) initializeUnicastOnChannelTestCases() { 463 for _, c := range s.allMsgConfigs { 464 for channel, channelAuthConfig := range c.Config { 465 identity, _ := unittest.IdentityWithNetworkingKeyFixture(unittest.WithRole(channelAuthConfig.AuthorizedRoles[0])) 466 code, what, err := codec.MessageCodeFromInterface(c.Type()) 467 require.NoError(s.T(), err) 468 tc := TestCase{ 469 Identity: identity, 470 GetIdentity: s.getIdentity(identity), 471 Channel: channel, 472 Message: c.Type(), 473 MessageCode: code, 474 MessageStr: what, 475 Protocols: channelAuthConfig.AllowedProtocols, 476 } 477 if channelAuthConfig.AllowedProtocols.Contains(message.ProtocolTypeUnicast) { 478 s.authorizedUnicastOnChannel = append(s.authorizedUnicastOnChannel, tc) 479 } else { 480 s.unauthorizedUnicastOnChannel = append(s.unauthorizedUnicastOnChannel, tc) 481 } 482 } 483 } 484 } 485 486 // getIdentity returns a callback that simply returns the provided identity. 487 func (s *TestAuthorizedSenderValidatorSuite) getIdentity(id *flow.Identity) func(pid peer.ID) (*flow.Identity, bool) { 488 return func(pid peer.ID) (*flow.Identity, bool) { 489 return id, true 490 } 491 }