github.com/koko1123/flow-go-1@v0.29.6/network/test/middleware_test.go (about)

     1  package test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"regexp"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/libp2p/go-libp2p/core/peer"
    13  	"github.com/libp2p/go-libp2p/p2p/net/swarm"
    14  	"github.com/rs/zerolog"
    15  	"github.com/stretchr/testify/assert"
    16  	mockery "github.com/stretchr/testify/mock"
    17  	"github.com/stretchr/testify/require"
    18  	"github.com/stretchr/testify/suite"
    19  	"go.uber.org/atomic"
    20  	"golang.org/x/time/rate"
    21  
    22  	"github.com/koko1123/flow-go-1/model/flow"
    23  	"github.com/koko1123/flow-go-1/model/flow/filter"
    24  	libp2pmessage "github.com/koko1123/flow-go-1/model/libp2p/message"
    25  	"github.com/koko1123/flow-go-1/module/irrecoverable"
    26  	"github.com/koko1123/flow-go-1/module/metrics"
    27  	"github.com/koko1123/flow-go-1/module/mock"
    28  	"github.com/koko1123/flow-go-1/module/observable"
    29  	"github.com/koko1123/flow-go-1/network"
    30  	"github.com/koko1123/flow-go-1/network/channels"
    31  	"github.com/koko1123/flow-go-1/network/internal/testutils"
    32  	"github.com/koko1123/flow-go-1/network/mocknetwork"
    33  	"github.com/koko1123/flow-go-1/network/p2p"
    34  	"github.com/koko1123/flow-go-1/network/p2p/middleware"
    35  	"github.com/koko1123/flow-go-1/network/p2p/p2pnode"
    36  	"github.com/koko1123/flow-go-1/network/p2p/unicast/ratelimit"
    37  	"github.com/koko1123/flow-go-1/network/slashing"
    38  	"github.com/koko1123/flow-go-1/utils/unittest"
    39  )
    40  
    41  const testChannel = channels.TestNetworkChannel
    42  
    43  // libp2p emits a call to `Protect` with a topic-specific tag upon establishing each peering connection in a GossipSUb mesh, see:
    44  // https://github.com/libp2p/go-libp2p-pubsub/blob/master/tag_tracer.go
    45  // One way to make sure such a mesh has formed, asynchronously, in unit tests, is to wait for libp2p.GossipSubD such calls,
    46  // and that's what we do with tagsObserver.
    47  type tagsObserver struct {
    48  	tags chan string
    49  	log  zerolog.Logger
    50  }
    51  
    52  func (co *tagsObserver) OnNext(peertag interface{}) {
    53  	pt, ok := peertag.(testutils.PeerTag)
    54  
    55  	if ok {
    56  		co.tags <- fmt.Sprintf("peer: %v tag: %v", pt.Peer, pt.Tag)
    57  	}
    58  
    59  }
    60  func (co *tagsObserver) OnError(err error) {
    61  	co.log.Error().Err(err).Msg("Tags Observer closed on an error")
    62  	close(co.tags)
    63  }
    64  func (co *tagsObserver) OnComplete() {
    65  	close(co.tags)
    66  }
    67  
    68  type MiddlewareTestSuite struct {
    69  	suite.Suite
    70  	sync.RWMutex
    71  	size      int // used to determine number of middlewares under test
    72  	nodes     []p2p.LibP2PNode
    73  	mws       []network.Middleware // used to keep track of middlewares under test
    74  	ov        []*mocknetwork.Overlay
    75  	obs       chan string // used to keep track of Protect events tagged by pubsub messages
    76  	ids       []*flow.Identity
    77  	metrics   *metrics.NoopCollector // no-op performance monitoring simulation
    78  	logger    zerolog.Logger
    79  	providers []*testutils.UpdatableIDProvider
    80  
    81  	mwCancel context.CancelFunc
    82  	mwCtx    irrecoverable.SignalerContext
    83  
    84  	slashingViolationsConsumer slashing.ViolationsConsumer
    85  }
    86  
    87  // TestMiddlewareTestSuit runs all the test methods in this test suit
    88  func TestMiddlewareTestSuite(t *testing.T) {
    89  	t.Parallel()
    90  	suite.Run(t, new(MiddlewareTestSuite))
    91  }
    92  
    93  // SetupTest initiates the test setups prior to each test
    94  func (m *MiddlewareTestSuite) SetupTest() {
    95  	m.logger = unittest.Logger()
    96  
    97  	m.size = 2 // operates on two middlewares
    98  	m.metrics = metrics.NewNoopCollector()
    99  
   100  	// create and start the middlewares and inject a connection observer
   101  	var obs []observable.Observable
   102  	peerChannel := make(chan string)
   103  	ob := tagsObserver{
   104  		tags: peerChannel,
   105  		log:  m.logger,
   106  	}
   107  
   108  	m.slashingViolationsConsumer = mocknetwork.NewViolationsConsumer(m.T())
   109  
   110  	m.ids, m.nodes, m.mws, obs, m.providers = testutils.GenerateIDsAndMiddlewares(m.T(),
   111  		m.size,
   112  		m.logger,
   113  		unittest.NetworkCodec(),
   114  		m.slashingViolationsConsumer)
   115  
   116  	for _, observableConnMgr := range obs {
   117  		observableConnMgr.Subscribe(&ob)
   118  	}
   119  	m.obs = peerChannel
   120  
   121  	require.Len(m.Suite.T(), obs, m.size)
   122  	require.Len(m.Suite.T(), m.ids, m.size)
   123  	require.Len(m.Suite.T(), m.mws, m.size)
   124  
   125  	// create the mock overlays
   126  	for i := 0; i < m.size; i++ {
   127  		m.ov = append(m.ov, m.createOverlay(m.providers[i]))
   128  	}
   129  
   130  	ctx, cancel := context.WithCancel(context.Background())
   131  	m.mwCancel = cancel
   132  
   133  	m.mwCtx = irrecoverable.NewMockSignalerContext(m.T(), ctx)
   134  
   135  	testutils.StartNodes(m.mwCtx, m.T(), m.nodes, 100*time.Millisecond)
   136  
   137  	for i, mw := range m.mws {
   138  		mw.SetOverlay(m.ov[i])
   139  		mw.Start(m.mwCtx)
   140  		unittest.RequireComponentsReadyBefore(m.T(), 100*time.Millisecond, mw)
   141  		require.NoError(m.T(), mw.Subscribe(testChannel))
   142  	}
   143  }
   144  
   145  func (m *MiddlewareTestSuite) TearDownTest() {
   146  	m.mwCancel()
   147  
   148  	testutils.StopComponents(m.T(), m.mws, 100*time.Millisecond)
   149  	testutils.StopComponents(m.T(), m.nodes, 100*time.Millisecond)
   150  
   151  	m.mws = nil
   152  	m.nodes = nil
   153  	m.ov = nil
   154  	m.ids = nil
   155  	m.size = 0
   156  }
   157  
   158  // TestUpdateNodeAddresses tests that the UpdateNodeAddresses method correctly updates
   159  // the addresses of the staked network participants.
   160  func (m *MiddlewareTestSuite) TestUpdateNodeAddresses() {
   161  	ctx, cancel := context.WithCancel(m.mwCtx)
   162  	irrecoverableCtx := irrecoverable.NewMockSignalerContext(m.T(), ctx)
   163  
   164  	// create a new staked identity
   165  	ids, libP2PNodes, _ := testutils.GenerateIDs(m.T(), m.logger, 1)
   166  
   167  	mws, providers := testutils.GenerateMiddlewares(m.T(), m.logger, ids, libP2PNodes, unittest.NetworkCodec(), m.slashingViolationsConsumer)
   168  	require.Len(m.T(), ids, 1)
   169  	require.Len(m.T(), providers, 1)
   170  	require.Len(m.T(), mws, 1)
   171  	newId := ids[0]
   172  	newMw := mws[0]
   173  
   174  	overlay := m.createOverlay(providers[0])
   175  	overlay.On("Receive", m.ids[0].NodeID, mockery.AnythingOfType("*message.Message")).Return(nil)
   176  	newMw.SetOverlay(overlay)
   177  
   178  	// start up nodes and peer managers
   179  	testutils.StartNodes(irrecoverableCtx, m.T(), libP2PNodes, 100*time.Millisecond)
   180  	defer testutils.StopComponents(m.T(), libP2PNodes, 100*time.Millisecond)
   181  
   182  	newMw.Start(irrecoverableCtx)
   183  	defer testutils.StopComponents(m.T(), mws, 100*time.Millisecond)
   184  	unittest.RequireComponentsReadyBefore(m.T(), 100*time.Millisecond, newMw)
   185  
   186  	idList := flow.IdentityList(append(m.ids, newId))
   187  
   188  	// needed to enable ID translation
   189  	m.providers[0].SetIdentities(idList)
   190  
   191  	outMsg, err := network.NewOutgoingScope(
   192  		flow.IdentifierList{newId.NodeID},
   193  		testChannel,
   194  		&libp2pmessage.TestMessage{
   195  			Text: "TestUpdateNodeAddresses",
   196  		},
   197  		unittest.NetworkCodec().Encode,
   198  		network.ProtocolTypeUnicast)
   199  	require.NoError(m.T(), err)
   200  	// message should fail to send because no address is known yet
   201  	// for the new identity
   202  	err = m.mws[0].SendDirect(outMsg)
   203  	require.ErrorIs(m.T(), err, swarm.ErrNoAddresses)
   204  
   205  	// update the addresses
   206  	m.mws[0].UpdateNodeAddresses()
   207  
   208  	// now the message should send successfully
   209  	err = m.mws[0].SendDirect(outMsg)
   210  	require.NoError(m.T(), err)
   211  
   212  	cancel()
   213  	unittest.RequireComponentsReadyBefore(m.T(), 100*time.Millisecond, newMw)
   214  }
   215  
   216  func (m *MiddlewareTestSuite) TestUnicastRateLimit_Messages() {
   217  	unittest.SkipUnless(m.T(), unittest.TEST_FLAKY, "disabling so that flaky metrics can be gathered before re-enabling")
   218  
   219  	// limiter limit will be set to 5 events/sec the 6th event per interval will be rate limited
   220  	limit := rate.Limit(5)
   221  
   222  	// burst per interval
   223  	burst := 5
   224  
   225  	messageRateLimiter := ratelimit.NewMessageRateLimiter(limit, burst, 1)
   226  
   227  	// the onUnicastRateLimitedPeerFunc call back we will use to keep track of how many times a rate limit happens
   228  	// after 5 rate limits we will close ch. O
   229  	ch := make(chan struct{})
   230  	rateLimits := atomic.NewUint64(0)
   231  	onRateLimit := func(peerID peer.ID, role, msgType string, topic channels.Topic, reason ratelimit.RateLimitReason) {
   232  		require.Equal(m.T(), reason, ratelimit.ReasonMessageCount)
   233  
   234  		// we only expect messages from the first middleware on the test suite
   235  		expectedPID, err := unittest.PeerIDFromFlowID(m.ids[0])
   236  		require.NoError(m.T(), err)
   237  		require.Equal(m.T(), expectedPID, peerID)
   238  
   239  		// update hook calls
   240  		rateLimits.Inc()
   241  	}
   242  
   243  	rateLimiters := ratelimit.NewRateLimiters(messageRateLimiter,
   244  		&ratelimit.NoopRateLimiter{},
   245  		onRateLimit,
   246  		ratelimit.WithDisabledRateLimiting(false))
   247  
   248  	// create a new staked identity
   249  	ids, libP2PNodes, _ := testutils.GenerateIDs(m.T(), m.logger, 1)
   250  
   251  	// create middleware
   252  	netmet := mock.NewNetworkMetrics(m.T())
   253  	calls := 0
   254  	netmet.On("InboundMessageReceived", mockery.Anything, mockery.Anything, mockery.Anything).Times(5).Run(func(args mockery.Arguments) {
   255  		calls++
   256  		if calls == 5 {
   257  			close(ch)
   258  		}
   259  	})
   260  	// we expect 5 messages to be processed the rest will be rate limited
   261  	defer netmet.AssertNumberOfCalls(m.T(), "InboundMessageReceived", 5)
   262  
   263  	mws, providers := testutils.GenerateMiddlewares(m.T(),
   264  		m.logger,
   265  		ids,
   266  		libP2PNodes,
   267  		unittest.NetworkCodec(),
   268  		m.slashingViolationsConsumer,
   269  		testutils.WithUnicastRateLimiters(rateLimiters),
   270  		testutils.WithNetworkMetrics(netmet))
   271  
   272  	require.Len(m.T(), ids, 1)
   273  	require.Len(m.T(), providers, 1)
   274  	require.Len(m.T(), mws, 1)
   275  	newId := ids[0]
   276  	newMw := mws[0]
   277  
   278  	overlay := m.createOverlay(providers[0])
   279  	overlay.On("Receive", m.ids[0].NodeID, mockery.AnythingOfType("*message.Message")).Return(nil)
   280  
   281  	newMw.SetOverlay(overlay)
   282  
   283  	ctx, cancel := context.WithCancel(m.mwCtx)
   284  	irrecoverableCtx := irrecoverable.NewMockSignalerContext(m.T(), ctx)
   285  
   286  	testutils.StartNodes(irrecoverableCtx, m.T(), libP2PNodes, 100*time.Millisecond)
   287  	defer testutils.StopComponents(m.T(), libP2PNodes, 100*time.Millisecond)
   288  
   289  	newMw.Start(irrecoverableCtx)
   290  	unittest.RequireComponentsReadyBefore(m.T(), 100*time.Millisecond, newMw)
   291  
   292  	require.NoError(m.T(), newMw.Subscribe(testChannel))
   293  
   294  	idList := flow.IdentityList(append(m.ids, newId))
   295  
   296  	// needed to enable ID translation
   297  	m.providers[0].SetIdentities(idList)
   298  
   299  	// update the addresses
   300  	m.mws[0].UpdateNodeAddresses()
   301  
   302  	// send 6 unicast messages, 5 should be allowed and the 6th should be rate limited
   303  	for i := 0; i < 6; i++ {
   304  		msg, err := network.NewOutgoingScope(
   305  			flow.IdentifierList{newId.NodeID},
   306  			testChannel,
   307  			&libp2pmessage.TestMessage{
   308  				Text: fmt.Sprintf("hello-%d", i),
   309  			},
   310  			unittest.NetworkCodec().Encode,
   311  			network.ProtocolTypeUnicast)
   312  		require.NoError(m.T(), err)
   313  		err = m.mws[0].SendDirect(msg)
   314  
   315  		require.NoError(m.T(), err)
   316  	}
   317  
   318  	// wait for all rate limits before shutting down middleware
   319  	unittest.RequireCloseBefore(m.T(), ch, 100*time.Millisecond, "could not stop on rate limit test ch on time")
   320  
   321  	// shutdown our middleware so that each message can be processed
   322  	cancel()
   323  	unittest.RequireCloseBefore(m.T(), libP2PNodes[0].Done(), 100*time.Millisecond, "could not stop libp2p node on time")
   324  	unittest.RequireCloseBefore(m.T(), newMw.Done(), 100*time.Millisecond, "could not stop middleware on time")
   325  
   326  	// expect our rate limited peer callback to be invoked once
   327  	require.Equal(m.T(), uint64(1), rateLimits.Load())
   328  }
   329  
   330  func (m *MiddlewareTestSuite) TestUnicastRateLimit_Bandwidth() {
   331  	unittest.SkipUnless(m.T(), unittest.TEST_FLAKY, "disabling so that flaky metrics can be gathered before re-enabling")
   332  
   333  	//limiter limit will be set up to 1000 bytes/sec
   334  	limit := rate.Limit(1000)
   335  
   336  	//burst per interval
   337  	burst := 1000
   338  
   339  	// create test time
   340  	testtime := unittest.NewTestTime()
   341  
   342  	// setup bandwidth rate limiter
   343  	bandwidthRateLimiter := ratelimit.NewBandWidthRateLimiter(limit, burst, 1, p2p.WithGetTimeNowFunc(testtime.Now))
   344  
   345  	// the onUnicastRateLimitedPeerFunc call back we will use to keep track of how many times a rate limit happens
   346  	// after 5 rate limits we will close ch.
   347  	ch := make(chan struct{})
   348  	rateLimits := atomic.NewUint64(0)
   349  	onRateLimit := func(peerID peer.ID, role, msgType string, topic channels.Topic, reason ratelimit.RateLimitReason) {
   350  		require.Equal(m.T(), reason, ratelimit.ReasonBandwidth)
   351  
   352  		// we only expect messages from the first middleware on the test suite
   353  		expectedPID, err := unittest.PeerIDFromFlowID(m.ids[0])
   354  		require.NoError(m.T(), err)
   355  		require.Equal(m.T(), expectedPID, peerID)
   356  		// update hook calls
   357  		rateLimits.Inc()
   358  		close(ch)
   359  	}
   360  
   361  	rateLimiters := ratelimit.NewRateLimiters(&ratelimit.NoopRateLimiter{},
   362  		bandwidthRateLimiter,
   363  		onRateLimit,
   364  		ratelimit.WithDisabledRateLimiting(false))
   365  
   366  	// create a new staked identity
   367  	ids, libP2PNodes, _ := testutils.GenerateIDs(m.T(), m.logger, 1)
   368  
   369  	// create middleware
   370  	opts := testutils.WithUnicastRateLimiters(rateLimiters)
   371  	mws, providers := testutils.GenerateMiddlewares(m.T(),
   372  		m.logger,
   373  		ids,
   374  		libP2PNodes,
   375  		unittest.NetworkCodec(),
   376  		m.slashingViolationsConsumer, opts)
   377  	require.Len(m.T(), ids, 1)
   378  	require.Len(m.T(), providers, 1)
   379  	require.Len(m.T(), mws, 1)
   380  	newId := ids[0]
   381  	newMw := mws[0]
   382  
   383  	overlay := m.createOverlay(providers[0])
   384  	overlay.On("Receive", m.ids[0].NodeID, mockery.AnythingOfType("*message.Message")).Return(nil)
   385  
   386  	newMw.SetOverlay(overlay)
   387  
   388  	ctx, cancel := context.WithCancel(m.mwCtx)
   389  	irrecoverableCtx := irrecoverable.NewMockSignalerContext(m.T(), ctx)
   390  
   391  	testutils.StartNodes(irrecoverableCtx, m.T(), libP2PNodes, 100*time.Millisecond)
   392  	defer testutils.StopComponents(m.T(), libP2PNodes, 100*time.Millisecond)
   393  
   394  	newMw.Start(irrecoverableCtx)
   395  	unittest.RequireComponentsReadyBefore(m.T(), 100*time.Millisecond, newMw)
   396  
   397  	require.NoError(m.T(), newMw.Subscribe(testChannel))
   398  
   399  	idList := flow.IdentityList(append(m.ids, newId))
   400  
   401  	// needed to enable ID translation
   402  	m.providers[0].SetIdentities(idList)
   403  
   404  	// create message with about 400bytes (300 random bytes + 100bytes message info)
   405  	b := make([]byte, 300)
   406  	for i := range b {
   407  		b[i] = byte('X')
   408  	}
   409  
   410  	msg, err := network.NewOutgoingScope(
   411  		flow.IdentifierList{newId.NodeID},
   412  		testChannel,
   413  		&libp2pmessage.TestMessage{
   414  			Text: string(b),
   415  		},
   416  		unittest.NetworkCodec().Encode,
   417  		network.ProtocolTypeUnicast)
   418  	require.NoError(m.T(), err)
   419  
   420  	// update the addresses
   421  	m.mws[0].UpdateNodeAddresses()
   422  
   423  	// for the duration of a simulated second we will send 3 messages. Each message is about
   424  	// 400 bytes, the 3rd message will put our limiter over the 1000 byte limit at 1200 bytes. Thus
   425  	// the 3rd message should be rate limited.
   426  	start := testtime.Now()
   427  	end := start.Add(time.Second)
   428  	for testtime.Now().Before(end) {
   429  
   430  		err := m.mws[0].SendDirect(msg)
   431  		require.NoError(m.T(), err)
   432  
   433  		// send 3 messages
   434  		testtime.Advance(334 * time.Millisecond)
   435  	}
   436  
   437  	// wait for all rate limits before shutting down middleware
   438  	unittest.RequireCloseBefore(m.T(), ch, 100*time.Millisecond, "could not stop on rate limit test ch on time")
   439  
   440  	// shutdown our middleware so that each message can be processed
   441  	cancel()
   442  	unittest.RequireComponentsDoneBefore(m.T(), 100*time.Millisecond, newMw)
   443  
   444  	// expect our rate limited peer callback to be invoked once
   445  	require.Equal(m.T(), uint64(1), rateLimits.Load())
   446  }
   447  
   448  func (m *MiddlewareTestSuite) createOverlay(provider *testutils.UpdatableIDProvider) *mocknetwork.Overlay {
   449  	overlay := &mocknetwork.Overlay{}
   450  	overlay.On("Identities").Maybe().Return(func() flow.IdentityList {
   451  		return provider.Identities(filter.Any)
   452  	})
   453  	overlay.On("Topology").Maybe().Return(func() flow.IdentityList {
   454  		return provider.Identities(filter.Any)
   455  	}, nil)
   456  	// this test is not testing the topic validator, especially in spoofing,
   457  	// so we always return a valid identity. We only care about the node role for the test TestMaxMessageSize_SendDirect
   458  	// where EN are the only node authorized to send chunk data response.
   459  	identityOpts := unittest.WithRole(flow.RoleExecution)
   460  	overlay.On("Identity", mockery.AnythingOfType("peer.ID")).Maybe().Return(unittest.IdentityFixture(identityOpts), true)
   461  	return overlay
   462  }
   463  
   464  // TestMultiPing tests the middleware against type of received payload
   465  // of distinct messages that are sent concurrently from a node to another
   466  func (m *MiddlewareTestSuite) TestMultiPing() {
   467  	// one distinct message
   468  	m.MultiPing(1)
   469  
   470  	// two distinct messages
   471  	m.MultiPing(2)
   472  
   473  	// 10 distinct messages
   474  	m.MultiPing(10)
   475  }
   476  
   477  // TestPing sends a message from the first middleware of the test suit to the last one and checks that the
   478  // last middleware receives the message and that the message is correctly decoded.
   479  func (m *MiddlewareTestSuite) TestPing() {
   480  	receiveWG := sync.WaitGroup{}
   481  	receiveWG.Add(1)
   482  	// extracts sender id based on the mock option
   483  	var err error
   484  
   485  	// mocks Overlay.Receive for middleware.Overlay.Receive(*nodeID, payload)
   486  	firstNodeIndex := 0
   487  	lastNodeIndex := m.size - 1
   488  
   489  	expectedPayload := "TestPingContentReception"
   490  	msg, err := network.NewOutgoingScope(
   491  		flow.IdentifierList{m.ids[lastNodeIndex].NodeID},
   492  		testChannel,
   493  		&libp2pmessage.TestMessage{
   494  			Text: expectedPayload,
   495  		},
   496  		unittest.NetworkCodec().Encode,
   497  		network.ProtocolTypeUnicast)
   498  	require.NoError(m.T(), err)
   499  
   500  	m.ov[lastNodeIndex].On("Receive", mockery.Anything).Return(nil).Once().
   501  		Run(func(args mockery.Arguments) {
   502  			receiveWG.Done()
   503  
   504  			msg, ok := args[0].(*network.IncomingMessageScope)
   505  			require.True(m.T(), ok)
   506  
   507  			require.Equal(m.T(), testChannel, msg.Channel())                                              // channel
   508  			require.Equal(m.T(), m.ids[firstNodeIndex].NodeID, msg.OriginId())                            // sender id
   509  			require.Equal(m.T(), m.ids[lastNodeIndex].NodeID, msg.TargetIDs()[0])                         // target id
   510  			require.Equal(m.T(), network.ProtocolTypeUnicast, msg.Protocol())                             // protocol
   511  			require.Equal(m.T(), expectedPayload, msg.DecodedPayload().(*libp2pmessage.TestMessage).Text) // payload
   512  		})
   513  
   514  	// sends a direct message from first node to the last node
   515  	err = m.mws[firstNodeIndex].SendDirect(msg)
   516  	require.NoError(m.Suite.T(), err)
   517  
   518  	unittest.RequireReturnsBefore(m.T(), receiveWG.Wait, 1000*time.Millisecond, "did not receive message")
   519  
   520  	// evaluates the mock calls
   521  	for i := 1; i < m.size; i++ {
   522  		m.ov[i].AssertExpectations(m.T())
   523  	}
   524  
   525  }
   526  
   527  // MultiPing sends count-many distinct messages concurrently from the first middleware of the test suit to the last one.
   528  // It evaluates the correctness of reception of the content of the messages. Each message must be received by the
   529  // last middleware of the test suit exactly once.
   530  func (m *MiddlewareTestSuite) MultiPing(count int) {
   531  	receiveWG := sync.WaitGroup{}
   532  	sendWG := sync.WaitGroup{}
   533  	// extracts sender id based on the mock option
   534  	// mocks Overlay.Receive for  middleware.Overlay.Receive(*nodeID, payload)
   535  	firstNodeIndex := 0
   536  	lastNodeIndex := m.size - 1
   537  
   538  	receivedPayloads := unittest.NewProtectedMap[string, struct{}]() // keep track of unique payloads received.
   539  
   540  	// regex to extract the payload from the message
   541  	regex := regexp.MustCompile(`^hello from: \d`)
   542  
   543  	for i := 0; i < count; i++ {
   544  		receiveWG.Add(1)
   545  		sendWG.Add(1)
   546  
   547  		expectedPayloadText := fmt.Sprintf("hello from: %d", i)
   548  		msg, err := network.NewOutgoingScope(
   549  			flow.IdentifierList{m.ids[lastNodeIndex].NodeID},
   550  			testChannel,
   551  			&libp2pmessage.TestMessage{
   552  				Text: expectedPayloadText,
   553  			},
   554  			unittest.NetworkCodec().Encode,
   555  			network.ProtocolTypeUnicast)
   556  		require.NoError(m.T(), err)
   557  
   558  		m.ov[lastNodeIndex].On("Receive", mockery.Anything).Return(nil).Once().
   559  			Run(func(args mockery.Arguments) {
   560  				receiveWG.Done()
   561  
   562  				msg, ok := args[0].(*network.IncomingMessageScope)
   563  				require.True(m.T(), ok)
   564  
   565  				require.Equal(m.T(), testChannel, msg.Channel())                      // channel
   566  				require.Equal(m.T(), m.ids[firstNodeIndex].NodeID, msg.OriginId())    // sender id
   567  				require.Equal(m.T(), m.ids[lastNodeIndex].NodeID, msg.TargetIDs()[0]) // target id
   568  				require.Equal(m.T(), network.ProtocolTypeUnicast, msg.Protocol())     // protocol
   569  
   570  				// payload
   571  				decodedPayload := msg.DecodedPayload().(*libp2pmessage.TestMessage).Text
   572  				require.True(m.T(), regex.MatchString(decodedPayload))
   573  				require.False(m.T(), receivedPayloads.Has(decodedPayload)) // payload must be unique
   574  				receivedPayloads.Add(decodedPayload, struct{}{})
   575  			})
   576  		go func() {
   577  			// sends a direct message from first node to the last node
   578  			err := m.mws[firstNodeIndex].SendDirect(msg)
   579  			require.NoError(m.Suite.T(), err)
   580  
   581  			sendWG.Done()
   582  		}()
   583  	}
   584  
   585  	unittest.RequireReturnsBefore(m.T(), sendWG.Wait, 1*time.Second, "could not send unicasts on time")
   586  	unittest.RequireReturnsBefore(m.T(), receiveWG.Wait, 1*time.Second, "could not receive unicasts on time")
   587  
   588  	// evaluates the mock calls
   589  	for i := 1; i < m.size; i++ {
   590  		m.ov[i].AssertExpectations(m.T())
   591  	}
   592  }
   593  
   594  // TestEcho sends an echo message from first middleware to the last middleware
   595  // the last middleware echos back the message. The test evaluates the correctness
   596  // of the message reception as well as its content
   597  func (m *MiddlewareTestSuite) TestEcho() {
   598  	wg := sync.WaitGroup{}
   599  	// extracts sender id based on the mock option
   600  	var err error
   601  
   602  	wg.Add(2)
   603  	// mocks Overlay.Receive for middleware.Overlay.Receive(*nodeID, payload)
   604  	first := 0
   605  	last := m.size - 1
   606  	firstNode := m.ids[first].NodeID
   607  	lastNode := m.ids[last].NodeID
   608  
   609  	// message sent from first node to the last node.
   610  	expectedSendMsg := "TestEcho"
   611  	sendMsg, err := network.NewOutgoingScope(
   612  		flow.IdentifierList{lastNode},
   613  		testChannel,
   614  		&libp2pmessage.TestMessage{
   615  			Text: expectedSendMsg,
   616  		},
   617  		unittest.NetworkCodec().Encode,
   618  		network.ProtocolTypeUnicast)
   619  	require.NoError(m.T(), err)
   620  
   621  	// reply from last node to the first node.
   622  	expectedReplyMsg := "TestEcho response"
   623  	replyMsg, err := network.NewOutgoingScope(
   624  		flow.IdentifierList{firstNode},
   625  		testChannel,
   626  		&libp2pmessage.TestMessage{
   627  			Text: expectedReplyMsg,
   628  		},
   629  		unittest.NetworkCodec().Encode,
   630  		network.ProtocolTypeUnicast)
   631  	require.NoError(m.T(), err)
   632  
   633  	// last node
   634  	m.ov[last].On("Receive", mockery.Anything).Return(nil).Once().
   635  		Run(func(args mockery.Arguments) {
   636  			wg.Done()
   637  
   638  			// sanity checks the message content.
   639  			msg, ok := args[0].(*network.IncomingMessageScope)
   640  			require.True(m.T(), ok)
   641  
   642  			require.Equal(m.T(), testChannel, msg.Channel())                                              // channel
   643  			require.Equal(m.T(), m.ids[first].NodeID, msg.OriginId())                                     // sender id
   644  			require.Equal(m.T(), lastNode, msg.TargetIDs()[0])                                            // target id
   645  			require.Equal(m.T(), network.ProtocolTypeUnicast, msg.Protocol())                             // protocol
   646  			require.Equal(m.T(), expectedSendMsg, msg.DecodedPayload().(*libp2pmessage.TestMessage).Text) // payload
   647  			// event id
   648  			eventId, err := network.EventId(msg.Channel(), msg.Proto().Payload)
   649  			require.NoError(m.T(), err)
   650  			require.True(m.T(), bytes.Equal(eventId, msg.EventID()))
   651  
   652  			// echos back the same message back to the sender
   653  			err = m.mws[last].SendDirect(replyMsg)
   654  			assert.NoError(m.T(), err)
   655  		})
   656  
   657  	// first node
   658  	m.ov[first].On("Receive", mockery.Anything).Return(nil).Once().
   659  		Run(func(args mockery.Arguments) {
   660  			wg.Done()
   661  			// sanity checks the message content.
   662  			msg, ok := args[0].(*network.IncomingMessageScope)
   663  			require.True(m.T(), ok)
   664  
   665  			require.Equal(m.T(), testChannel, msg.Channel())                                               // channel
   666  			require.Equal(m.T(), m.ids[last].NodeID, msg.OriginId())                                       // sender id
   667  			require.Equal(m.T(), firstNode, msg.TargetIDs()[0])                                            // target id
   668  			require.Equal(m.T(), network.ProtocolTypeUnicast, msg.Protocol())                              // protocol
   669  			require.Equal(m.T(), expectedReplyMsg, msg.DecodedPayload().(*libp2pmessage.TestMessage).Text) // payload
   670  			// event id
   671  			eventId, err := network.EventId(msg.Channel(), msg.Proto().Payload)
   672  			require.NoError(m.T(), err)
   673  			require.True(m.T(), bytes.Equal(eventId, msg.EventID()))
   674  		})
   675  
   676  	// sends a direct message from first node to the last node
   677  	err = m.mws[first].SendDirect(sendMsg)
   678  	require.NoError(m.Suite.T(), err)
   679  
   680  	unittest.RequireReturnsBefore(m.T(), wg.Wait, 100*time.Second, "could not receive unicast on time")
   681  
   682  	// evaluates the mock calls
   683  	for i := 1; i < m.size; i++ {
   684  		m.ov[i].AssertExpectations(m.T())
   685  	}
   686  }
   687  
   688  // TestMaxMessageSize_SendDirect evaluates that invoking SendDirect method of the middleware on a message
   689  // size beyond the permissible unicast message size returns an error.
   690  func (m *MiddlewareTestSuite) TestMaxMessageSize_SendDirect() {
   691  	first := 0
   692  	last := m.size - 1
   693  	lastNode := m.ids[last].NodeID
   694  
   695  	// creates a network payload beyond the maximum message size
   696  	// Note: networkPayloadFixture considers 1000 bytes as the overhead of the encoded message,
   697  	// so the generated payload is 1000 bytes below the maximum unicast message size.
   698  	// We hence add up 1000 bytes to the input of network payload fixture to make
   699  	// sure that payload is beyond the permissible size.
   700  	payload := testutils.NetworkPayloadFixture(m.T(), uint(middleware.DefaultMaxUnicastMsgSize)+1000)
   701  	event := &libp2pmessage.TestMessage{
   702  		Text: string(payload),
   703  	}
   704  
   705  	msg, err := network.NewOutgoingScope(
   706  		flow.IdentifierList{lastNode},
   707  		testChannel,
   708  		event,
   709  		unittest.NetworkCodec().Encode,
   710  		network.ProtocolTypeUnicast)
   711  	require.NoError(m.T(), err)
   712  
   713  	// sends a direct message from first node to the last node
   714  	err = m.mws[first].SendDirect(msg)
   715  	require.Error(m.Suite.T(), err)
   716  }
   717  
   718  // TestLargeMessageSize_SendDirect asserts that a ChunkDataResponse is treated as a large message and can be unicasted
   719  // successfully even though it's size is greater than the default message size.
   720  func (m *MiddlewareTestSuite) TestLargeMessageSize_SendDirect() {
   721  	sourceIndex := 0
   722  	targetIndex := m.size - 1
   723  	targetNode := m.ids[targetIndex].NodeID
   724  	targetMW := m.mws[targetIndex]
   725  
   726  	// subscribe to channels.ProvideChunks so that the message is not dropped
   727  	require.NoError(m.T(), targetMW.Subscribe(channels.ProvideChunks))
   728  
   729  	// creates a network payload with a size greater than the default max size using a known large message type
   730  	targetSize := uint64(middleware.DefaultMaxUnicastMsgSize) + 1000
   731  	event := unittest.ChunkDataResponseMsgFixture(unittest.IdentifierFixture(), unittest.WithApproximateSize(targetSize))
   732  
   733  	msg, err := network.NewOutgoingScope(
   734  		flow.IdentifierList{targetNode},
   735  		channels.ProvideChunks,
   736  		event,
   737  		unittest.NetworkCodec().Encode,
   738  		network.ProtocolTypeUnicast)
   739  	require.NoError(m.T(), err)
   740  
   741  	// expect one message to be received by the target
   742  	ch := make(chan struct{})
   743  	m.ov[targetIndex].On("Receive", mockery.Anything).Return(nil).Once().
   744  		Run(func(args mockery.Arguments) {
   745  			msg, ok := args[0].(*network.IncomingMessageScope)
   746  			require.True(m.T(), ok)
   747  
   748  			require.Equal(m.T(), channels.ProvideChunks, msg.Channel())
   749  			require.Equal(m.T(), m.ids[sourceIndex].NodeID, msg.OriginId())
   750  			require.Equal(m.T(), targetNode, msg.TargetIDs()[0])
   751  			require.Equal(m.T(), network.ProtocolTypeUnicast, msg.Protocol())
   752  
   753  			eventId, err := network.EventId(msg.Channel(), msg.Proto().Payload)
   754  			require.NoError(m.T(), err)
   755  			require.True(m.T(), bytes.Equal(eventId, msg.EventID()))
   756  			close(ch)
   757  		})
   758  
   759  	// sends a direct message from source node to the target node
   760  	err = m.mws[sourceIndex].SendDirect(msg)
   761  	// SendDirect should not error since this is a known large message
   762  	require.NoError(m.Suite.T(), err)
   763  
   764  	// check message reception on target
   765  	unittest.RequireCloseBefore(m.T(), ch, 60*time.Second, "source node failed to send large message to target")
   766  
   767  	m.ov[targetIndex].AssertExpectations(m.T())
   768  }
   769  
   770  // TestMaxMessageSize_Publish evaluates that invoking Publish method of the middleware on a message
   771  // size beyond the permissible publish message size returns an error.
   772  func (m *MiddlewareTestSuite) TestMaxMessageSize_Publish() {
   773  	first := 0
   774  	last := m.size - 1
   775  	lastNode := m.ids[last].NodeID
   776  
   777  	// creates a network payload beyond the maximum message size
   778  	// Note: networkPayloadFixture considers 1000 bytes as the overhead of the encoded message,
   779  	// so the generated payload is 1000 bytes below the maximum publish message size.
   780  	// We hence add up 1000 bytes to the input of network payload fixture to make
   781  	// sure that payload is beyond the permissible size.
   782  	payload := testutils.NetworkPayloadFixture(m.T(), uint(p2pnode.DefaultMaxPubSubMsgSize)+1000)
   783  	event := &libp2pmessage.TestMessage{
   784  		Text: string(payload),
   785  	}
   786  	msg, err := network.NewOutgoingScope(
   787  		flow.IdentifierList{lastNode},
   788  		testChannel,
   789  		event,
   790  		unittest.NetworkCodec().Encode,
   791  		network.ProtocolTypePubSub)
   792  	require.NoError(m.T(), err)
   793  
   794  	// sends a direct message from first node to the last node
   795  	err = m.mws[first].Publish(msg)
   796  	require.Error(m.Suite.T(), err)
   797  }
   798  
   799  // TestUnsubscribe tests that an engine can unsubscribe from a topic it was earlier subscribed to and stop receiving
   800  // messages.
   801  func (m *MiddlewareTestSuite) TestUnsubscribe() {
   802  	first := 0
   803  	last := m.size - 1
   804  	firstNode := m.ids[first].NodeID
   805  	lastNode := m.ids[last].NodeID
   806  
   807  	// set up waiting for m.size pubsub tags indicating a mesh has formed
   808  	for i := 0; i < m.size; i++ {
   809  		select {
   810  		case <-m.obs:
   811  		case <-time.After(2 * time.Second):
   812  			assert.FailNow(m.T(), "could not receive pubsub tag indicating mesh formed")
   813  		}
   814  	}
   815  
   816  	msgRcvd := make(chan struct{}, 2)
   817  	msgRcvdFun := func() {
   818  		<-msgRcvd
   819  	}
   820  
   821  	message1, err := network.NewOutgoingScope(
   822  		flow.IdentifierList{lastNode},
   823  		testChannel,
   824  		&libp2pmessage.TestMessage{
   825  			Text: string("hello1"),
   826  		},
   827  		unittest.NetworkCodec().Encode,
   828  		network.ProtocolTypeUnicast)
   829  	require.NoError(m.T(), err)
   830  
   831  	m.ov[last].On("Receive", mockery.Anything).Return(nil).Run(func(args mockery.Arguments) {
   832  		msg, ok := args[0].(*network.IncomingMessageScope)
   833  		require.True(m.T(), ok)
   834  		require.Equal(m.T(), firstNode, msg.OriginId())
   835  		msgRcvd <- struct{}{}
   836  	})
   837  
   838  	// first test that when both nodes are subscribed to the channel, the target node receives the message
   839  	err = m.mws[first].Publish(message1)
   840  	assert.NoError(m.T(), err)
   841  
   842  	unittest.RequireReturnsBefore(m.T(), msgRcvdFun, 2*time.Second, "message not received")
   843  
   844  	// now unsubscribe the target node from the channel
   845  	err = m.mws[last].Unsubscribe(testChannel)
   846  	assert.NoError(m.T(), err)
   847  
   848  	// create and send a new message on the channel from the origin node
   849  	message2, err := network.NewOutgoingScope(
   850  		flow.IdentifierList{lastNode},
   851  		testChannel,
   852  		&libp2pmessage.TestMessage{
   853  			Text: string("hello2"),
   854  		},
   855  		unittest.NetworkCodec().Encode,
   856  		network.ProtocolTypeUnicast)
   857  	require.NoError(m.T(), err)
   858  
   859  	err = m.mws[first].Publish(message2)
   860  	assert.NoError(m.T(), err)
   861  
   862  	// assert that the new message is not received by the target node
   863  	unittest.RequireNeverReturnBefore(m.T(), msgRcvdFun, 2*time.Second, "message received unexpectedly")
   864  }