github.com/koko1123/flow-go-1@v0.29.6/network/test/middleware_test.go

package test

import (
    "bytes"
    "context"
    "fmt"
    "regexp"
    "sync"
    "testing"
    "time"

    "github.com/libp2p/go-libp2p/core/peer"
    "github.com/libp2p/go-libp2p/p2p/net/swarm"
    "github.com/rs/zerolog"
    "github.com/stretchr/testify/assert"
    mockery "github.com/stretchr/testify/mock"
    "github.com/stretchr/testify/require"
    "github.com/stretchr/testify/suite"
    "go.uber.org/atomic"
    "golang.org/x/time/rate"

    "github.com/koko1123/flow-go-1/model/flow"
    "github.com/koko1123/flow-go-1/model/flow/filter"
    libp2pmessage "github.com/koko1123/flow-go-1/model/libp2p/message"
    "github.com/koko1123/flow-go-1/module/irrecoverable"
    "github.com/koko1123/flow-go-1/module/metrics"
    "github.com/koko1123/flow-go-1/module/mock"
    "github.com/koko1123/flow-go-1/module/observable"
    "github.com/koko1123/flow-go-1/network"
    "github.com/koko1123/flow-go-1/network/channels"
    "github.com/koko1123/flow-go-1/network/internal/testutils"
    "github.com/koko1123/flow-go-1/network/mocknetwork"
    "github.com/koko1123/flow-go-1/network/p2p"
    "github.com/koko1123/flow-go-1/network/p2p/middleware"
    "github.com/koko1123/flow-go-1/network/p2p/p2pnode"
    "github.com/koko1123/flow-go-1/network/p2p/unicast/ratelimit"
    "github.com/koko1123/flow-go-1/network/slashing"
    "github.com/koko1123/flow-go-1/utils/unittest"
)

const testChannel = channels.TestNetworkChannel

// libp2p emits a call to `Protect` with a topic-specific tag upon establishing each peering connection in a GossipSub mesh, see:
// https://github.com/libp2p/go-libp2p-pubsub/blob/master/tag_tracer.go
// One way to make sure such a mesh has formed, asynchronously, in unit tests, is to wait for libp2p.GossipSubD-many such calls,
// and that's what we do with tagsObserver.
type tagsObserver struct {
    tags chan string
    log  zerolog.Logger
}

func (co *tagsObserver) OnNext(peertag interface{}) {
    pt, ok := peertag.(testutils.PeerTag)

    if ok {
        co.tags <- fmt.Sprintf("peer: %v tag: %v", pt.Peer, pt.Tag)
    }

}
func (co *tagsObserver) OnError(err error) {
    co.log.Error().Err(err).Msg("Tags Observer closed on an error")
    close(co.tags)
}
func (co *tagsObserver) OnComplete() {
    close(co.tags)
}

type MiddlewareTestSuite struct {
    suite.Suite
    sync.RWMutex
    size      int                    // used to determine number of middlewares under test
    nodes     []p2p.LibP2PNode
    mws       []network.Middleware   // used to keep track of middlewares under test
    ov        []*mocknetwork.Overlay
    obs       chan string            // used to keep track of Protect events tagged by pubsub messages
    ids       []*flow.Identity
    metrics   *metrics.NoopCollector // no-op performance monitoring simulation
    logger    zerolog.Logger
    providers []*testutils.UpdatableIDProvider

    mwCancel context.CancelFunc
    mwCtx    irrecoverable.SignalerContext

    slashingViolationsConsumer slashing.ViolationsConsumer
}

// TestMiddlewareTestSuite runs all the test methods in this test suite.
func TestMiddlewareTestSuite(t *testing.T) {
    t.Parallel()
    suite.Run(t, new(MiddlewareTestSuite))
}

// SetupTest initiates the test setup prior to each test.
func (m *MiddlewareTestSuite) SetupTest() {
    m.logger = unittest.Logger()

    m.size = 2 // operates on two middlewares
    m.metrics = metrics.NewNoopCollector()

    // create and start the middlewares and inject a connection observer
    var obs []observable.Observable
    peerChannel := make(chan string)
    ob := tagsObserver{
        tags: peerChannel,
        log:  m.logger,
    }

    m.slashingViolationsConsumer = mocknetwork.NewViolationsConsumer(m.T())

    m.ids, m.nodes, m.mws, obs, m.providers = testutils.GenerateIDsAndMiddlewares(m.T(),
        m.size,
        m.logger,
        unittest.NetworkCodec(),
        m.slashingViolationsConsumer)

    for _, observableConnMgr := range obs {
        observableConnMgr.Subscribe(&ob)
    }
    m.obs = peerChannel

    require.Len(m.Suite.T(), obs, m.size)
    require.Len(m.Suite.T(), m.ids, m.size)
    require.Len(m.Suite.T(), m.mws, m.size)

    // create the mock overlays
    for i := 0; i < m.size; i++ {
        m.ov = append(m.ov, m.createOverlay(m.providers[i]))
    }

    ctx, cancel := context.WithCancel(context.Background())
    m.mwCancel = cancel

    m.mwCtx = irrecoverable.NewMockSignalerContext(m.T(), ctx)

    testutils.StartNodes(m.mwCtx, m.T(), m.nodes, 100*time.Millisecond)

    for i, mw := range m.mws {
        mw.SetOverlay(m.ov[i])
        mw.Start(m.mwCtx)
        unittest.RequireComponentsReadyBefore(m.T(), 100*time.Millisecond, mw)
        require.NoError(m.T(), mw.Subscribe(testChannel))
    }
}

func (m *MiddlewareTestSuite) TearDownTest() {
    m.mwCancel()

    testutils.StopComponents(m.T(), m.mws, 100*time.Millisecond)
    testutils.StopComponents(m.T(), m.nodes, 100*time.Millisecond)

    m.mws = nil
    m.nodes = nil
    m.ov = nil
    m.ids = nil
    m.size = 0
}

// TestUpdateNodeAddresses tests that the UpdateNodeAddresses method correctly updates
// the addresses of the staked network participants.
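// Before the address update, a unicast to the newly added identity is expected to fail with
// swarm.ErrNoAddresses; after calling UpdateNodeAddresses on the sender, the same message
// is expected to be delivered successfully.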
func (m *MiddlewareTestSuite) TestUpdateNodeAddresses() {
    ctx, cancel := context.WithCancel(m.mwCtx)
    irrecoverableCtx := irrecoverable.NewMockSignalerContext(m.T(), ctx)

    // create a new staked identity
    ids, libP2PNodes, _ := testutils.GenerateIDs(m.T(), m.logger, 1)

    mws, providers := testutils.GenerateMiddlewares(m.T(), m.logger, ids, libP2PNodes, unittest.NetworkCodec(), m.slashingViolationsConsumer)
    require.Len(m.T(), ids, 1)
    require.Len(m.T(), providers, 1)
    require.Len(m.T(), mws, 1)
    newId := ids[0]
    newMw := mws[0]

    overlay := m.createOverlay(providers[0])
    overlay.On("Receive", m.ids[0].NodeID, mockery.AnythingOfType("*message.Message")).Return(nil)
    newMw.SetOverlay(overlay)

    // start up nodes and peer managers
    testutils.StartNodes(irrecoverableCtx, m.T(), libP2PNodes, 100*time.Millisecond)
    defer testutils.StopComponents(m.T(), libP2PNodes, 100*time.Millisecond)

    newMw.Start(irrecoverableCtx)
    defer testutils.StopComponents(m.T(), mws, 100*time.Millisecond)
    unittest.RequireComponentsReadyBefore(m.T(), 100*time.Millisecond, newMw)

    idList := flow.IdentityList(append(m.ids, newId))

    // needed to enable ID translation
    m.providers[0].SetIdentities(idList)

    outMsg, err := network.NewOutgoingScope(
        flow.IdentifierList{newId.NodeID},
        testChannel,
        &libp2pmessage.TestMessage{
            Text: "TestUpdateNodeAddresses",
        },
        unittest.NetworkCodec().Encode,
        network.ProtocolTypeUnicast)
    require.NoError(m.T(), err)
    // message should fail to send because no address is known yet
    // for the new identity
    err = m.mws[0].SendDirect(outMsg)
    require.ErrorIs(m.T(), err, swarm.ErrNoAddresses)

    // update the addresses
    m.mws[0].UpdateNodeAddresses()

    // now the message should send successfully
    err = m.mws[0].SendDirect(outMsg)
    require.NoError(m.T(), err)

    cancel()
    unittest.RequireComponentsReadyBefore(m.T(), 100*time.Millisecond, newMw)
}

func (m *MiddlewareTestSuite) TestUnicastRateLimit_Messages() {
    unittest.SkipUnless(m.T(), unittest.TEST_FLAKY, "disabling so that flaky metrics can be gathered before re-enabling")

    // limiter limit will be set to 5 events/sec, so the 6th event per interval will be rate limited
    limit := rate.Limit(5)

    // burst per interval
    burst := 5

    messageRateLimiter := ratelimit.NewMessageRateLimiter(limit, burst, 1)

    // the onUnicastRateLimitedPeerFunc callback we will use to keep track of how many times a rate limit happens;
    // ch is closed once the metrics mock below has observed the 5 expected inbound messages.
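    // Expected flow of this test: 6 unicast messages are sent back-to-back; the limiter
    // (5 msg/sec with a burst of 5) lets the first 5 through, and the 6th is rate limited,
    // triggering onRateLimit exactly once.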
    ch := make(chan struct{})
    rateLimits := atomic.NewUint64(0)
    onRateLimit := func(peerID peer.ID, role, msgType string, topic channels.Topic, reason ratelimit.RateLimitReason) {
        require.Equal(m.T(), reason, ratelimit.ReasonMessageCount)

        // we only expect messages from the first middleware on the test suite
        expectedPID, err := unittest.PeerIDFromFlowID(m.ids[0])
        require.NoError(m.T(), err)
        require.Equal(m.T(), expectedPID, peerID)

        // update hook calls
        rateLimits.Inc()
    }

    rateLimiters := ratelimit.NewRateLimiters(messageRateLimiter,
        &ratelimit.NoopRateLimiter{},
        onRateLimit,
        ratelimit.WithDisabledRateLimiting(false))

    // create a new staked identity
    ids, libP2PNodes, _ := testutils.GenerateIDs(m.T(), m.logger, 1)

    // create middleware
    netmet := mock.NewNetworkMetrics(m.T())
    calls := 0
    netmet.On("InboundMessageReceived", mockery.Anything, mockery.Anything, mockery.Anything).Times(5).Run(func(args mockery.Arguments) {
        calls++
        if calls == 5 {
            close(ch)
        }
    })
    // we expect 5 messages to be processed; the rest will be rate limited
    defer netmet.AssertNumberOfCalls(m.T(), "InboundMessageReceived", 5)

    mws, providers := testutils.GenerateMiddlewares(m.T(),
        m.logger,
        ids,
        libP2PNodes,
        unittest.NetworkCodec(),
        m.slashingViolationsConsumer,
        testutils.WithUnicastRateLimiters(rateLimiters),
        testutils.WithNetworkMetrics(netmet))

    require.Len(m.T(), ids, 1)
    require.Len(m.T(), providers, 1)
    require.Len(m.T(), mws, 1)
    newId := ids[0]
    newMw := mws[0]

    overlay := m.createOverlay(providers[0])
    overlay.On("Receive", m.ids[0].NodeID, mockery.AnythingOfType("*message.Message")).Return(nil)

    newMw.SetOverlay(overlay)

    ctx, cancel := context.WithCancel(m.mwCtx)
    irrecoverableCtx := irrecoverable.NewMockSignalerContext(m.T(), ctx)

    testutils.StartNodes(irrecoverableCtx, m.T(), libP2PNodes, 100*time.Millisecond)
    defer testutils.StopComponents(m.T(), libP2PNodes, 100*time.Millisecond)

    newMw.Start(irrecoverableCtx)
    unittest.RequireComponentsReadyBefore(m.T(), 100*time.Millisecond, newMw)

    require.NoError(m.T(), newMw.Subscribe(testChannel))

    idList := flow.IdentityList(append(m.ids, newId))

    // needed to enable ID translation
    m.providers[0].SetIdentities(idList)

    // update the addresses
    m.mws[0].UpdateNodeAddresses()

    // send 6 unicast messages, 5 should be allowed and the 6th should be rate limited
    for i := 0; i < 6; i++ {
        msg, err := network.NewOutgoingScope(
            flow.IdentifierList{newId.NodeID},
            testChannel,
            &libp2pmessage.TestMessage{
                Text: fmt.Sprintf("hello-%d", i),
            },
            unittest.NetworkCodec().Encode,
            network.ProtocolTypeUnicast)
        require.NoError(m.T(), err)
        err = m.mws[0].SendDirect(msg)

        require.NoError(m.T(), err)
    }

    // wait for all rate limits before shutting down middleware
    unittest.RequireCloseBefore(m.T(), ch, 100*time.Millisecond, "could not stop on rate limit test ch on time")

    // shutdown our middleware so that each message can be processed
    cancel()
    unittest.RequireCloseBefore(m.T(), libP2PNodes[0].Done(), 100*time.Millisecond, "could not stop libp2p node on time")
    unittest.RequireCloseBefore(m.T(), newMw.Done(), 100*time.Millisecond, "could not stop middleware on time")

    // expect our rate limited peer callback to be invoked once
    require.Equal(m.T(), uint64(1), rateLimits.Load())
}

func (m *MiddlewareTestSuite) TestUnicastRateLimit_Bandwidth() {
    unittest.SkipUnless(m.T(), unittest.TEST_FLAKY, "disabling so that flaky metrics can be gathered before re-enabling")

    // limiter limit will be set to 1000 bytes/sec
    limit := rate.Limit(1000)

    // burst per interval
    burst := 1000

    // create test time
    testtime := unittest.NewTestTime()

    // setup bandwidth rate limiter
    bandwidthRateLimiter := ratelimit.NewBandWidthRateLimiter(limit, burst, 1, p2p.WithGetTimeNowFunc(testtime.Now))

    // the onUnicastRateLimitedPeerFunc callback we will use to keep track of how many times a rate limit happens;
    // after the first rate limit we close ch.
    ch := make(chan struct{})
    rateLimits := atomic.NewUint64(0)
    onRateLimit := func(peerID peer.ID, role, msgType string, topic channels.Topic, reason ratelimit.RateLimitReason) {
        require.Equal(m.T(), reason, ratelimit.ReasonBandwidth)

        // we only expect messages from the first middleware on the test suite
        expectedPID, err := unittest.PeerIDFromFlowID(m.ids[0])
        require.NoError(m.T(), err)
        require.Equal(m.T(), expectedPID, peerID)
        // update hook calls
        rateLimits.Inc()
        close(ch)
    }

    rateLimiters := ratelimit.NewRateLimiters(&ratelimit.NoopRateLimiter{},
        bandwidthRateLimiter,
        onRateLimit,
        ratelimit.WithDisabledRateLimiting(false))

    // create a new staked identity
    ids, libP2PNodes, _ := testutils.GenerateIDs(m.T(), m.logger, 1)

    // create middleware
    opts := testutils.WithUnicastRateLimiters(rateLimiters)
    mws, providers := testutils.GenerateMiddlewares(m.T(),
        m.logger,
        ids,
        libP2PNodes,
        unittest.NetworkCodec(),
        m.slashingViolationsConsumer, opts)
    require.Len(m.T(), ids, 1)
    require.Len(m.T(), providers, 1)
    require.Len(m.T(), mws, 1)
    newId := ids[0]
    newMw := mws[0]

    overlay := m.createOverlay(providers[0])
    overlay.On("Receive", m.ids[0].NodeID, mockery.AnythingOfType("*message.Message")).Return(nil)

    newMw.SetOverlay(overlay)

    ctx, cancel := context.WithCancel(m.mwCtx)
    irrecoverableCtx := irrecoverable.NewMockSignalerContext(m.T(), ctx)

    testutils.StartNodes(irrecoverableCtx, m.T(), libP2PNodes, 100*time.Millisecond)
    defer testutils.StopComponents(m.T(), libP2PNodes, 100*time.Millisecond)

    newMw.Start(irrecoverableCtx)
    unittest.RequireComponentsReadyBefore(m.T(), 100*time.Millisecond, newMw)

    require.NoError(m.T(), newMw.Subscribe(testChannel))

    idList := flow.IdentityList(append(m.ids, newId))

    // needed to enable ID translation
    m.providers[0].SetIdentities(idList)

    // create a message of about 400 bytes (300 payload bytes + ~100 bytes of message info)
    b := make([]byte, 300)
    for i := range b {
        b[i] = byte('X')
    }

    msg, err := network.NewOutgoingScope(
        flow.IdentifierList{newId.NodeID},
        testChannel,
        &libp2pmessage.TestMessage{
            Text: string(b),
        },
        unittest.NetworkCodec().Encode,
        network.ProtocolTypeUnicast)
    require.NoError(m.T(), err)

    // update the addresses
    m.mws[0].UpdateNodeAddresses()

    // for the duration of a simulated second we will send 3 messages. Each message is about
    // 400 bytes, so the 3rd message will put our limiter over the 1000 byte limit at 1200 bytes. Thus
    // the 3rd message should be rate limited.
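    // Note: the bandwidth limiter above was constructed with p2p.WithGetTimeNowFunc(testtime.Now),
    // so advancing testtime below drives the limiter's clock deterministically instead of the wall clock.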
    start := testtime.Now()
    end := start.Add(time.Second)
    for testtime.Now().Before(end) {

        err := m.mws[0].SendDirect(msg)
        require.NoError(m.T(), err)

        // send 3 messages
        testtime.Advance(334 * time.Millisecond)
    }

    // wait for all rate limits before shutting down middleware
    unittest.RequireCloseBefore(m.T(), ch, 100*time.Millisecond, "could not stop on rate limit test ch on time")

    // shutdown our middleware so that each message can be processed
    cancel()
    unittest.RequireComponentsDoneBefore(m.T(), 100*time.Millisecond, newMw)

    // expect our rate limited peer callback to be invoked once
    require.Equal(m.T(), uint64(1), rateLimits.Load())
}

func (m *MiddlewareTestSuite) createOverlay(provider *testutils.UpdatableIDProvider) *mocknetwork.Overlay {
    overlay := &mocknetwork.Overlay{}
    overlay.On("Identities").Maybe().Return(func() flow.IdentityList {
        return provider.Identities(filter.Any)
    })
    overlay.On("Topology").Maybe().Return(func() flow.IdentityList {
        return provider.Identities(filter.Any)
    }, nil)
    // these tests do not exercise the topic validator (in particular spoofing),
    // so we always return a valid identity. We only care about the node role for the test TestMaxMessageSize_SendDirect,
    // where ENs are the only nodes authorized to send chunk data responses.
    identityOpts := unittest.WithRole(flow.RoleExecution)
    overlay.On("Identity", mockery.AnythingOfType("peer.ID")).Maybe().Return(unittest.IdentityFixture(identityOpts), true)
    return overlay
}

// TestMultiPing tests the middleware against the payload types of distinct messages
// that are sent concurrently from one node to another.
func (m *MiddlewareTestSuite) TestMultiPing() {
    // one distinct message
    m.MultiPing(1)

    // two distinct messages
    m.MultiPing(2)

    // 10 distinct messages
    m.MultiPing(10)
}

// TestPing sends a message from the first middleware of the test suite to the last one and checks that the
// last middleware receives the message and that the message is correctly decoded.
func (m *MiddlewareTestSuite) TestPing() {
    receiveWG := sync.WaitGroup{}
    receiveWG.Add(1)
    // extracts sender id based on the mock option
    var err error

    // mocks Overlay.Receive for middleware.Overlay.Receive(*nodeID, payload)
    firstNodeIndex := 0
    lastNodeIndex := m.size - 1

    expectedPayload := "TestPingContentReception"
    msg, err := network.NewOutgoingScope(
        flow.IdentifierList{m.ids[lastNodeIndex].NodeID},
        testChannel,
        &libp2pmessage.TestMessage{
            Text: expectedPayload,
        },
        unittest.NetworkCodec().Encode,
        network.ProtocolTypeUnicast)
    require.NoError(m.T(), err)

    m.ov[lastNodeIndex].On("Receive", mockery.Anything).Return(nil).Once().
        Run(func(args mockery.Arguments) {
            receiveWG.Done()

            msg, ok := args[0].(*network.IncomingMessageScope)
            require.True(m.T(), ok)

            require.Equal(m.T(), testChannel, msg.Channel())                                             // channel
            require.Equal(m.T(), m.ids[firstNodeIndex].NodeID, msg.OriginId())                           // sender id
            require.Equal(m.T(), m.ids[lastNodeIndex].NodeID, msg.TargetIDs()[0])                        // target id
            require.Equal(m.T(), network.ProtocolTypeUnicast, msg.Protocol())                            // protocol
            require.Equal(m.T(), expectedPayload, msg.DecodedPayload().(*libp2pmessage.TestMessage).Text) // payload
        })

    // sends a direct message from first node to the last node
    err = m.mws[firstNodeIndex].SendDirect(msg)
    require.NoError(m.Suite.T(), err)

    unittest.RequireReturnsBefore(m.T(), receiveWG.Wait, 1000*time.Millisecond, "did not receive message")

    // evaluates the mock calls
    for i := 1; i < m.size; i++ {
        m.ov[i].AssertExpectations(m.T())
    }

}

// MultiPing sends count-many distinct messages concurrently from the first middleware of the test suite to the last one.
// It evaluates the correctness of reception of the content of the messages. Each message must be received by the
// last middleware of the test suite exactly once.
func (m *MiddlewareTestSuite) MultiPing(count int) {
    receiveWG := sync.WaitGroup{}
    sendWG := sync.WaitGroup{}
    // extracts sender id based on the mock option
    // mocks Overlay.Receive for middleware.Overlay.Receive(*nodeID, payload)
    firstNodeIndex := 0
    lastNodeIndex := m.size - 1

    receivedPayloads := unittest.NewProtectedMap[string, struct{}]() // keep track of unique payloads received.

    // regex to match the expected payload format of the messages
    regex := regexp.MustCompile(`^hello from: \d`)

    for i := 0; i < count; i++ {
        receiveWG.Add(1)
        sendWG.Add(1)

        expectedPayloadText := fmt.Sprintf("hello from: %d", i)
        msg, err := network.NewOutgoingScope(
            flow.IdentifierList{m.ids[lastNodeIndex].NodeID},
            testChannel,
            &libp2pmessage.TestMessage{
                Text: expectedPayloadText,
            },
            unittest.NetworkCodec().Encode,
            network.ProtocolTypeUnicast)
        require.NoError(m.T(), err)

        m.ov[lastNodeIndex].On("Receive", mockery.Anything).Return(nil).Once().
            Run(func(args mockery.Arguments) {
                receiveWG.Done()

                msg, ok := args[0].(*network.IncomingMessageScope)
                require.True(m.T(), ok)

                require.Equal(m.T(), testChannel, msg.Channel())                      // channel
                require.Equal(m.T(), m.ids[firstNodeIndex].NodeID, msg.OriginId())    // sender id
                require.Equal(m.T(), m.ids[lastNodeIndex].NodeID, msg.TargetIDs()[0]) // target id
                require.Equal(m.T(), network.ProtocolTypeUnicast, msg.Protocol())     // protocol

                // payload
                decodedPayload := msg.DecodedPayload().(*libp2pmessage.TestMessage).Text
                require.True(m.T(), regex.MatchString(decodedPayload))
                require.False(m.T(), receivedPayloads.Has(decodedPayload)) // payload must be unique
                receivedPayloads.Add(decodedPayload, struct{}{})
            })
        go func() {
            // sends a direct message from first node to the last node
            err := m.mws[firstNodeIndex].SendDirect(msg)
            require.NoError(m.Suite.T(), err)

            sendWG.Done()
        }()
    }

    unittest.RequireReturnsBefore(m.T(), sendWG.Wait, 1*time.Second, "could not send unicasts on time")
    unittest.RequireReturnsBefore(m.T(), receiveWG.Wait, 1*time.Second, "could not receive unicasts on time")

    // evaluates the mock calls
    for i := 1; i < m.size; i++ {
        m.ov[i].AssertExpectations(m.T())
    }
}

// TestEcho sends a message from the first middleware to the last middleware;
// the last middleware echoes the message back. The test evaluates the correctness
// of the message reception as well as its content.
func (m *MiddlewareTestSuite) TestEcho() {
    wg := sync.WaitGroup{}
    // extracts sender id based on the mock option
    var err error

    wg.Add(2)
    // mocks Overlay.Receive for middleware.Overlay.Receive(*nodeID, payload)
    first := 0
    last := m.size - 1
    firstNode := m.ids[first].NodeID
    lastNode := m.ids[last].NodeID

    // message sent from first node to the last node.
    expectedSendMsg := "TestEcho"
    sendMsg, err := network.NewOutgoingScope(
        flow.IdentifierList{lastNode},
        testChannel,
        &libp2pmessage.TestMessage{
            Text: expectedSendMsg,
        },
        unittest.NetworkCodec().Encode,
        network.ProtocolTypeUnicast)
    require.NoError(m.T(), err)

    // reply from last node to the first node.
    expectedReplyMsg := "TestEcho response"
    replyMsg, err := network.NewOutgoingScope(
        flow.IdentifierList{firstNode},
        testChannel,
        &libp2pmessage.TestMessage{
            Text: expectedReplyMsg,
        },
        unittest.NetworkCodec().Encode,
        network.ProtocolTypeUnicast)
    require.NoError(m.T(), err)

    // last node
    m.ov[last].On("Receive", mockery.Anything).Return(nil).Once().
        Run(func(args mockery.Arguments) {
            wg.Done()

            // sanity checks the message content.
            msg, ok := args[0].(*network.IncomingMessageScope)
            require.True(m.T(), ok)

            require.Equal(m.T(), testChannel, msg.Channel())                                              // channel
            require.Equal(m.T(), m.ids[first].NodeID, msg.OriginId())                                     // sender id
            require.Equal(m.T(), lastNode, msg.TargetIDs()[0])                                            // target id
            require.Equal(m.T(), network.ProtocolTypeUnicast, msg.Protocol())                             // protocol
            require.Equal(m.T(), expectedSendMsg, msg.DecodedPayload().(*libp2pmessage.TestMessage).Text) // payload
            // event id
            eventId, err := network.EventId(msg.Channel(), msg.Proto().Payload)
            require.NoError(m.T(), err)
            require.True(m.T(), bytes.Equal(eventId, msg.EventID()))

            // echoes the same message back to the sender
            err = m.mws[last].SendDirect(replyMsg)
            assert.NoError(m.T(), err)
        })

    // first node
    m.ov[first].On("Receive", mockery.Anything).Return(nil).Once().
        Run(func(args mockery.Arguments) {
            wg.Done()
            // sanity checks the message content.
            msg, ok := args[0].(*network.IncomingMessageScope)
            require.True(m.T(), ok)

            require.Equal(m.T(), testChannel, msg.Channel())                                               // channel
            require.Equal(m.T(), m.ids[last].NodeID, msg.OriginId())                                       // sender id
            require.Equal(m.T(), firstNode, msg.TargetIDs()[0])                                            // target id
            require.Equal(m.T(), network.ProtocolTypeUnicast, msg.Protocol())                              // protocol
            require.Equal(m.T(), expectedReplyMsg, msg.DecodedPayload().(*libp2pmessage.TestMessage).Text) // payload
            // event id
            eventId, err := network.EventId(msg.Channel(), msg.Proto().Payload)
            require.NoError(m.T(), err)
            require.True(m.T(), bytes.Equal(eventId, msg.EventID()))
        })

    // sends a direct message from first node to the last node
    err = m.mws[first].SendDirect(sendMsg)
    require.NoError(m.Suite.T(), err)

    unittest.RequireReturnsBefore(m.T(), wg.Wait, 100*time.Second, "could not receive unicast on time")

    // evaluates the mock calls
    for i := 1; i < m.size; i++ {
        m.ov[i].AssertExpectations(m.T())
    }
}

// TestMaxMessageSize_SendDirect evaluates that invoking the SendDirect method of the middleware on a message
// whose size is beyond the permissible unicast message size returns an error.
func (m *MiddlewareTestSuite) TestMaxMessageSize_SendDirect() {
    first := 0
    last := m.size - 1
    lastNode := m.ids[last].NodeID

    // creates a network payload beyond the maximum message size
    // Note: networkPayloadFixture considers 1000 bytes as the overhead of the encoded message,
    // so the generated payload is 1000 bytes below the maximum unicast message size.
    // We hence add 1000 bytes to the input of the network payload fixture to make
    // sure that the payload is beyond the permissible size.
    payload := testutils.NetworkPayloadFixture(m.T(), uint(middleware.DefaultMaxUnicastMsgSize)+1000)
    event := &libp2pmessage.TestMessage{
        Text: string(payload),
    }

    msg, err := network.NewOutgoingScope(
        flow.IdentifierList{lastNode},
        testChannel,
        event,
        unittest.NetworkCodec().Encode,
        network.ProtocolTypeUnicast)
    require.NoError(m.T(), err)

    // sends a direct message from first node to the last node
    err = m.mws[first].SendDirect(msg)
    require.Error(m.Suite.T(), err)
}

// TestLargeMessageSize_SendDirect asserts that a ChunkDataResponse is treated as a large message and can be unicasted
// successfully even though its size is greater than the default message size.
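// The target middleware subscribes to channels.ProvideChunks so the message is not dropped, and the mocked
// overlay reports an Execution role for peers (see createOverlay), the role authorized to send chunk data responses.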
func (m *MiddlewareTestSuite) TestLargeMessageSize_SendDirect() {
    sourceIndex := 0
    targetIndex := m.size - 1
    targetNode := m.ids[targetIndex].NodeID
    targetMW := m.mws[targetIndex]

    // subscribe to channels.ProvideChunks so that the message is not dropped
    require.NoError(m.T(), targetMW.Subscribe(channels.ProvideChunks))

    // creates a network payload with a size greater than the default max size using a known large message type
    targetSize := uint64(middleware.DefaultMaxUnicastMsgSize) + 1000
    event := unittest.ChunkDataResponseMsgFixture(unittest.IdentifierFixture(), unittest.WithApproximateSize(targetSize))

    msg, err := network.NewOutgoingScope(
        flow.IdentifierList{targetNode},
        channels.ProvideChunks,
        event,
        unittest.NetworkCodec().Encode,
        network.ProtocolTypeUnicast)
    require.NoError(m.T(), err)

    // expect one message to be received by the target
    ch := make(chan struct{})
    m.ov[targetIndex].On("Receive", mockery.Anything).Return(nil).Once().
        Run(func(args mockery.Arguments) {
            msg, ok := args[0].(*network.IncomingMessageScope)
            require.True(m.T(), ok)

            require.Equal(m.T(), channels.ProvideChunks, msg.Channel())
            require.Equal(m.T(), m.ids[sourceIndex].NodeID, msg.OriginId())
            require.Equal(m.T(), targetNode, msg.TargetIDs()[0])
            require.Equal(m.T(), network.ProtocolTypeUnicast, msg.Protocol())

            eventId, err := network.EventId(msg.Channel(), msg.Proto().Payload)
            require.NoError(m.T(), err)
            require.True(m.T(), bytes.Equal(eventId, msg.EventID()))
            close(ch)
        })

    // sends a direct message from source node to the target node
    err = m.mws[sourceIndex].SendDirect(msg)
    // SendDirect should not error since this is a known large message
    require.NoError(m.Suite.T(), err)

    // check message reception on target
    unittest.RequireCloseBefore(m.T(), ch, 60*time.Second, "source node failed to send large message to target")

    m.ov[targetIndex].AssertExpectations(m.T())
}

// TestMaxMessageSize_Publish evaluates that invoking the Publish method of the middleware on a message
// whose size is beyond the permissible publish message size returns an error.
func (m *MiddlewareTestSuite) TestMaxMessageSize_Publish() {
    first := 0
    last := m.size - 1
    lastNode := m.ids[last].NodeID

    // creates a network payload beyond the maximum message size
    // Note: networkPayloadFixture considers 1000 bytes as the overhead of the encoded message,
    // so the generated payload is 1000 bytes below the maximum publish message size.
    // We hence add 1000 bytes to the input of the network payload fixture to make
    // sure that the payload is beyond the permissible size.
    payload := testutils.NetworkPayloadFixture(m.T(), uint(p2pnode.DefaultMaxPubSubMsgSize)+1000)
    event := &libp2pmessage.TestMessage{
        Text: string(payload),
    }
    msg, err := network.NewOutgoingScope(
        flow.IdentifierList{lastNode},
        testChannel,
        event,
        unittest.NetworkCodec().Encode,
        network.ProtocolTypePubSub)
    require.NoError(m.T(), err)

    // publishes the message from the first node to the channel
    err = m.mws[first].Publish(msg)
    require.Error(m.Suite.T(), err)
}

// TestUnsubscribe tests that an engine can unsubscribe from a topic it was earlier subscribed to and stop receiving
// messages.
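// The test first waits for the pubsub mesh to form (via the tag observer), verifies that a published message
// is delivered while the target node is subscribed, then unsubscribes the target node and asserts that a
// subsequent publish is never delivered.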
func (m *MiddlewareTestSuite) TestUnsubscribe() {
    first := 0
    last := m.size - 1
    firstNode := m.ids[first].NodeID
    lastNode := m.ids[last].NodeID

    // set up waiting for m.size pubsub tags indicating a mesh has formed
    for i := 0; i < m.size; i++ {
        select {
        case <-m.obs:
        case <-time.After(2 * time.Second):
            assert.FailNow(m.T(), "could not receive pubsub tag indicating mesh formed")
        }
    }

    msgRcvd := make(chan struct{}, 2)
    msgRcvdFun := func() {
        <-msgRcvd
    }

    message1, err := network.NewOutgoingScope(
        flow.IdentifierList{lastNode},
        testChannel,
        &libp2pmessage.TestMessage{
            Text: string("hello1"),
        },
        unittest.NetworkCodec().Encode,
        network.ProtocolTypeUnicast)
    require.NoError(m.T(), err)

    m.ov[last].On("Receive", mockery.Anything).Return(nil).Run(func(args mockery.Arguments) {
        msg, ok := args[0].(*network.IncomingMessageScope)
        require.True(m.T(), ok)
        require.Equal(m.T(), firstNode, msg.OriginId())
        msgRcvd <- struct{}{}
    })

    // first test that when both nodes are subscribed to the channel, the target node receives the message
    err = m.mws[first].Publish(message1)
    assert.NoError(m.T(), err)

    unittest.RequireReturnsBefore(m.T(), msgRcvdFun, 2*time.Second, "message not received")

    // now unsubscribe the target node from the channel
    err = m.mws[last].Unsubscribe(testChannel)
    assert.NoError(m.T(), err)

    // create and send a new message on the channel from the origin node
    message2, err := network.NewOutgoingScope(
        flow.IdentifierList{lastNode},
        testChannel,
        &libp2pmessage.TestMessage{
            Text: string("hello2"),
        },
        unittest.NetworkCodec().Encode,
        network.ProtocolTypeUnicast)
    require.NoError(m.T(), err)

    err = m.mws[first].Publish(message2)
    assert.NoError(m.T(), err)

    // assert that the new message is not received by the target node
    unittest.RequireNeverReturnBefore(m.T(), msgRcvdFun, 2*time.Second, "message received unexpectedly")
}
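// NOTE: the following test is an illustrative addition, not part of the original suite. It is a minimal
// sketch of the golang.org/x/time/rate semantics that TestUnicastRateLimit_Messages relies on: a limiter
// configured with limit 5 and burst 5 admits 5 events at the same instant and rejects the 6th.
func TestRateLimiterSemanticsExample(t *testing.T) {
    t.Parallel()

    limiter := rate.NewLimiter(rate.Limit(5), 5)
    now := time.Now()

    // the first 5 events fit into the burst and are allowed
    for i := 0; i < 5; i++ {
        require.True(t, limiter.AllowN(now, 1), "event %d should be allowed within the burst", i)
    }

    // the 6th event at the same instant exceeds the burst and is rejected,
    // mirroring the 6th unicast message being rate limited in TestUnicastRateLimit_Messages
    require.False(t, limiter.AllowN(now, 1))
}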