github.com/MetalBlockchain/metalgo@v1.11.9/network/p2p/network_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package p2p
     5  
     6  import (
     7  	"context"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/prometheus/client_golang/prometheus"
    12  	"github.com/stretchr/testify/require"
    13  
    14  	"github.com/MetalBlockchain/metalgo/ids"
    15  	"github.com/MetalBlockchain/metalgo/snow/engine/common"
    16  	"github.com/MetalBlockchain/metalgo/snow/validators"
    17  	"github.com/MetalBlockchain/metalgo/utils/logging"
    18  	"github.com/MetalBlockchain/metalgo/utils/set"
    19  	"github.com/MetalBlockchain/metalgo/version"
    20  )
    21  
    22  const (
    23  	handlerID     = 123
    24  	handlerPrefix = byte(handlerID)
    25  )
    26  
    27  var errFoo = &common.AppError{
    28  	Code:    123,
    29  	Message: "foo",
    30  }
    31  
    32  func TestMessageRouting(t *testing.T) {
    33  	require := require.New(t)
    34  	ctx := context.Background()
    35  	wantNodeID := ids.GenerateTestNodeID()
    36  	wantChainID := ids.GenerateTestID()
    37  	wantMsg := []byte("message")
    38  
    39  	var appGossipCalled, appRequestCalled, crossChainAppRequestCalled bool
    40  	testHandler := &TestHandler{
    41  		AppGossipF: func(_ context.Context, nodeID ids.NodeID, msg []byte) {
    42  			appGossipCalled = true
    43  			require.Equal(wantNodeID, nodeID)
    44  			require.Equal(wantMsg, msg)
    45  		},
    46  		AppRequestF: func(_ context.Context, nodeID ids.NodeID, _ time.Time, msg []byte) ([]byte, error) {
    47  			appRequestCalled = true
    48  			require.Equal(wantNodeID, nodeID)
    49  			require.Equal(wantMsg, msg)
    50  			return nil, nil
    51  		},
    52  		CrossChainAppRequestF: func(_ context.Context, chainID ids.ID, _ time.Time, msg []byte) ([]byte, error) {
    53  			crossChainAppRequestCalled = true
    54  			require.Equal(wantChainID, chainID)
    55  			require.Equal(wantMsg, msg)
    56  			return nil, nil
    57  		},
    58  	}
    59  
    60  	sender := &common.FakeSender{
    61  		SentAppGossip:            make(chan []byte, 1),
    62  		SentAppRequest:           make(chan []byte, 1),
    63  		SentCrossChainAppRequest: make(chan []byte, 1),
    64  	}
    65  
    66  	network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "")
    67  	require.NoError(err)
    68  	require.NoError(network.AddHandler(1, testHandler))
    69  	client := network.NewClient(1)
    70  
    71  	require.NoError(client.AppGossip(
    72  		ctx,
    73  		common.SendConfig{
    74  			Peers: 1,
    75  		},
    76  		wantMsg,
    77  	))
    78  	require.NoError(network.AppGossip(ctx, wantNodeID, <-sender.SentAppGossip))
    79  	require.True(appGossipCalled)
    80  
    81  	require.NoError(client.AppRequest(ctx, set.Of(ids.EmptyNodeID), wantMsg, func(context.Context, ids.NodeID, []byte, error) {}))
    82  	require.NoError(network.AppRequest(ctx, wantNodeID, 1, time.Time{}, <-sender.SentAppRequest))
    83  	require.True(appRequestCalled)
    84  
    85  	require.NoError(client.CrossChainAppRequest(ctx, ids.Empty, wantMsg, func(context.Context, ids.ID, []byte, error) {}))
    86  	require.NoError(network.CrossChainAppRequest(ctx, wantChainID, 1, time.Time{}, <-sender.SentCrossChainAppRequest))
    87  	require.True(crossChainAppRequestCalled)
    88  }
    89  
    90  // Tests that the Client prefixes messages with the handler prefix
    91  func TestClientPrefixesMessages(t *testing.T) {
    92  	require := require.New(t)
    93  	ctx := context.Background()
    94  
    95  	sender := common.FakeSender{
    96  		SentAppRequest:           make(chan []byte, 1),
    97  		SentAppGossip:            make(chan []byte, 1),
    98  		SentCrossChainAppRequest: make(chan []byte, 1),
    99  	}
   100  
   101  	network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "")
   102  	require.NoError(err)
   103  	require.NoError(network.Connected(ctx, ids.EmptyNodeID, nil))
   104  	client := network.NewClient(handlerID)
   105  
   106  	want := []byte("message")
   107  
   108  	require.NoError(client.AppRequest(
   109  		ctx,
   110  		set.Of(ids.EmptyNodeID),
   111  		want,
   112  		func(context.Context, ids.NodeID, []byte, error) {},
   113  	))
   114  	gotAppRequest := <-sender.SentAppRequest
   115  	require.Equal(handlerPrefix, gotAppRequest[0])
   116  	require.Equal(want, gotAppRequest[1:])
   117  
   118  	require.NoError(client.AppRequestAny(
   119  		ctx,
   120  		want,
   121  		func(context.Context, ids.NodeID, []byte, error) {},
   122  	))
   123  	gotAppRequest = <-sender.SentAppRequest
   124  	require.Equal(handlerPrefix, gotAppRequest[0])
   125  	require.Equal(want, gotAppRequest[1:])
   126  
   127  	require.NoError(client.CrossChainAppRequest(
   128  		ctx,
   129  		ids.Empty,
   130  		want,
   131  		func(context.Context, ids.ID, []byte, error) {},
   132  	))
   133  	gotCrossChainAppRequest := <-sender.SentCrossChainAppRequest
   134  	require.Equal(handlerPrefix, gotCrossChainAppRequest[0])
   135  	require.Equal(want, gotCrossChainAppRequest[1:])
   136  
   137  	require.NoError(client.AppGossip(
   138  		ctx,
   139  		common.SendConfig{
   140  			Peers: 1,
   141  		},
   142  		want,
   143  	))
   144  	gotAppGossip := <-sender.SentAppGossip
   145  	require.Equal(handlerPrefix, gotAppGossip[0])
   146  	require.Equal(want, gotAppGossip[1:])
   147  }
   148  
   149  // Tests that the Client callback is called on a successful response
   150  func TestAppRequestResponse(t *testing.T) {
   151  	require := require.New(t)
   152  	ctx := context.Background()
   153  
   154  	sender := common.FakeSender{
   155  		SentAppRequest: make(chan []byte, 1),
   156  	}
   157  	network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "")
   158  	require.NoError(err)
   159  	client := network.NewClient(handlerID)
   160  
   161  	wantResponse := []byte("response")
   162  	wantNodeID := ids.GenerateTestNodeID()
   163  	done := make(chan struct{})
   164  
   165  	callback := func(_ context.Context, gotNodeID ids.NodeID, gotResponse []byte, err error) {
   166  		require.Equal(wantNodeID, gotNodeID)
   167  		require.NoError(err)
   168  		require.Equal(wantResponse, gotResponse)
   169  
   170  		close(done)
   171  	}
   172  
   173  	want := []byte("request")
   174  	require.NoError(client.AppRequest(ctx, set.Of(wantNodeID), want, callback))
   175  	got := <-sender.SentAppRequest
   176  	require.Equal(handlerPrefix, got[0])
   177  	require.Equal(want, got[1:])
   178  
   179  	require.NoError(network.AppResponse(ctx, wantNodeID, 1, wantResponse))
   180  	<-done
   181  }
   182  
   183  // Tests that the Client does not provide a cancelled context to the AppSender.
   184  func TestAppRequestCancelledContext(t *testing.T) {
   185  	require := require.New(t)
   186  	ctx := context.Background()
   187  
   188  	sentMessages := make(chan []byte, 1)
   189  	sender := &common.SenderTest{
   190  		SendAppRequestF: func(ctx context.Context, _ set.Set[ids.NodeID], _ uint32, msgBytes []byte) error {
   191  			require.NoError(ctx.Err())
   192  			sentMessages <- msgBytes
   193  			return nil
   194  		},
   195  	}
   196  	network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "")
   197  	require.NoError(err)
   198  	client := network.NewClient(handlerID)
   199  
   200  	wantResponse := []byte("response")
   201  	wantNodeID := ids.GenerateTestNodeID()
   202  	done := make(chan struct{})
   203  
   204  	callback := func(_ context.Context, gotNodeID ids.NodeID, gotResponse []byte, err error) {
   205  		require.Equal(wantNodeID, gotNodeID)
   206  		require.NoError(err)
   207  		require.Equal(wantResponse, gotResponse)
   208  
   209  		close(done)
   210  	}
   211  
   212  	cancelledCtx, cancel := context.WithCancel(ctx)
   213  	cancel()
   214  
   215  	want := []byte("request")
   216  	require.NoError(client.AppRequest(cancelledCtx, set.Of(wantNodeID), want, callback))
   217  	got := <-sentMessages
   218  	require.Equal(handlerPrefix, got[0])
   219  	require.Equal(want, got[1:])
   220  
   221  	require.NoError(network.AppResponse(ctx, wantNodeID, 1, wantResponse))
   222  	<-done
   223  }
   224  
   225  // Tests that the Client callback is given an error if the request fails
   226  func TestAppRequestFailed(t *testing.T) {
   227  	require := require.New(t)
   228  	ctx := context.Background()
   229  
   230  	sender := common.FakeSender{
   231  		SentAppRequest: make(chan []byte, 1),
   232  	}
   233  	network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "")
   234  	require.NoError(err)
   235  	client := network.NewClient(handlerID)
   236  
   237  	wantNodeID := ids.GenerateTestNodeID()
   238  	done := make(chan struct{})
   239  
   240  	callback := func(_ context.Context, gotNodeID ids.NodeID, gotResponse []byte, err error) {
   241  		require.Equal(wantNodeID, gotNodeID)
   242  		require.ErrorIs(err, errFoo)
   243  		require.Nil(gotResponse)
   244  
   245  		close(done)
   246  	}
   247  
   248  	require.NoError(client.AppRequest(ctx, set.Of(wantNodeID), []byte("request"), callback))
   249  	<-sender.SentAppRequest
   250  
   251  	require.NoError(network.AppRequestFailed(ctx, wantNodeID, 1, errFoo))
   252  	<-done
   253  }
   254  
   255  // Tests that the Client callback is called on a successful response
   256  func TestCrossChainAppRequestResponse(t *testing.T) {
   257  	require := require.New(t)
   258  	ctx := context.Background()
   259  
   260  	sender := common.FakeSender{
   261  		SentCrossChainAppRequest: make(chan []byte, 1),
   262  	}
   263  	network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "")
   264  	require.NoError(err)
   265  	client := network.NewClient(handlerID)
   266  
   267  	wantChainID := ids.GenerateTestID()
   268  	wantResponse := []byte("response")
   269  	done := make(chan struct{})
   270  
   271  	callback := func(_ context.Context, gotChainID ids.ID, gotResponse []byte, err error) {
   272  		require.Equal(wantChainID, gotChainID)
   273  		require.NoError(err)
   274  		require.Equal(wantResponse, gotResponse)
   275  
   276  		close(done)
   277  	}
   278  
   279  	require.NoError(client.CrossChainAppRequest(ctx, wantChainID, []byte("request"), callback))
   280  	<-sender.SentCrossChainAppRequest
   281  
   282  	require.NoError(network.CrossChainAppResponse(ctx, wantChainID, 1, wantResponse))
   283  	<-done
   284  }
   285  
   286  // Tests that the Client does not provide a cancelled context to the AppSender.
   287  func TestCrossChainAppRequestCancelledContext(t *testing.T) {
   288  	require := require.New(t)
   289  	ctx := context.Background()
   290  
   291  	sentMessages := make(chan []byte, 1)
   292  	sender := &common.SenderTest{
   293  		SendCrossChainAppRequestF: func(ctx context.Context, _ ids.ID, _ uint32, msgBytes []byte) {
   294  			require.NoError(ctx.Err())
   295  			sentMessages <- msgBytes
   296  		},
   297  	}
   298  	network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "")
   299  	require.NoError(err)
   300  	client := network.NewClient(handlerID)
   301  
   302  	cancelledCtx, cancel := context.WithCancel(ctx)
   303  	cancel()
   304  
   305  	wantChainID := ids.GenerateTestID()
   306  	wantResponse := []byte("response")
   307  	done := make(chan struct{})
   308  
   309  	callback := func(_ context.Context, gotChainID ids.ID, gotResponse []byte, err error) {
   310  		require.Equal(wantChainID, gotChainID)
   311  		require.NoError(err)
   312  		require.Equal(wantResponse, gotResponse)
   313  
   314  		close(done)
   315  	}
   316  
   317  	require.NoError(client.CrossChainAppRequest(cancelledCtx, wantChainID, []byte("request"), callback))
   318  	<-sentMessages
   319  
   320  	require.NoError(network.CrossChainAppResponse(ctx, wantChainID, 1, wantResponse))
   321  	<-done
   322  }
   323  
   324  // Tests that the Client callback is given an error if the request fails
   325  func TestCrossChainAppRequestFailed(t *testing.T) {
   326  	require := require.New(t)
   327  	ctx := context.Background()
   328  
   329  	sender := common.FakeSender{
   330  		SentCrossChainAppRequest: make(chan []byte, 1),
   331  	}
   332  	network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "")
   333  	require.NoError(err)
   334  	client := network.NewClient(handlerID)
   335  
   336  	wantChainID := ids.GenerateTestID()
   337  	done := make(chan struct{})
   338  
   339  	callback := func(_ context.Context, gotChainID ids.ID, gotResponse []byte, err error) {
   340  		require.Equal(wantChainID, gotChainID)
   341  		require.ErrorIs(err, errFoo)
   342  		require.Nil(gotResponse)
   343  
   344  		close(done)
   345  	}
   346  
   347  	require.NoError(client.CrossChainAppRequest(ctx, wantChainID, []byte("request"), callback))
   348  	<-sender.SentCrossChainAppRequest
   349  
   350  	require.NoError(network.CrossChainAppRequestFailed(ctx, wantChainID, 1, errFoo))
   351  	<-done
   352  }
   353  
   354  // Messages for unregistered handlers should be dropped gracefully
   355  func TestMessageForUnregisteredHandler(t *testing.T) {
   356  	tests := []struct {
   357  		name string
   358  		msg  []byte
   359  	}{
   360  		{
   361  			name: "nil",
   362  			msg:  nil,
   363  		},
   364  		{
   365  			name: "empty",
   366  			msg:  []byte{},
   367  		},
   368  		{
   369  			name: "non-empty",
   370  			msg:  []byte("foobar"),
   371  		},
   372  	}
   373  
   374  	for _, tt := range tests {
   375  		t.Run(tt.name, func(t *testing.T) {
   376  			require := require.New(t)
   377  			ctx := context.Background()
   378  			handler := &TestHandler{
   379  				AppGossipF: func(context.Context, ids.NodeID, []byte) {
   380  					require.Fail("should not be called")
   381  				},
   382  				AppRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) {
   383  					require.Fail("should not be called")
   384  					return nil, nil
   385  				},
   386  				CrossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) {
   387  					require.Fail("should not be called")
   388  					return nil, nil
   389  				},
   390  			}
   391  			network, err := NewNetwork(logging.NoLog{}, nil, prometheus.NewRegistry(), "")
   392  			require.NoError(err)
   393  			require.NoError(network.AddHandler(handlerID, handler))
   394  
   395  			require.NoError(network.AppRequest(ctx, ids.EmptyNodeID, 0, time.Time{}, tt.msg))
   396  			require.NoError(network.AppGossip(ctx, ids.EmptyNodeID, tt.msg))
   397  			require.NoError(network.CrossChainAppRequest(ctx, ids.Empty, 0, time.Time{}, tt.msg))
   398  		})
   399  	}
   400  }
   401  
   402  // A response or timeout for a request we never made should return an error
   403  func TestResponseForUnrequestedRequest(t *testing.T) {
   404  	tests := []struct {
   405  		name string
   406  		msg  []byte
   407  	}{
   408  		{
   409  			name: "nil",
   410  			msg:  nil,
   411  		},
   412  		{
   413  			name: "empty",
   414  			msg:  []byte{},
   415  		},
   416  		{
   417  			name: "non-empty",
   418  			msg:  []byte("foobar"),
   419  		},
   420  	}
   421  
   422  	for _, tt := range tests {
   423  		t.Run(tt.name, func(t *testing.T) {
   424  			require := require.New(t)
   425  			ctx := context.Background()
   426  			handler := &TestHandler{
   427  				AppGossipF: func(context.Context, ids.NodeID, []byte) {
   428  					require.Fail("should not be called")
   429  				},
   430  				AppRequestF: func(context.Context, ids.NodeID, time.Time, []byte) ([]byte, error) {
   431  					require.Fail("should not be called")
   432  					return nil, nil
   433  				},
   434  				CrossChainAppRequestF: func(context.Context, ids.ID, time.Time, []byte) ([]byte, error) {
   435  					require.Fail("should not be called")
   436  					return nil, nil
   437  				},
   438  			}
   439  			network, err := NewNetwork(logging.NoLog{}, nil, prometheus.NewRegistry(), "")
   440  			require.NoError(err)
   441  			require.NoError(network.AddHandler(handlerID, handler))
   442  
   443  			err = network.AppResponse(ctx, ids.EmptyNodeID, 0, []byte("foobar"))
   444  			require.ErrorIs(err, ErrUnrequestedResponse)
   445  			err = network.AppRequestFailed(ctx, ids.EmptyNodeID, 0, common.ErrTimeout)
   446  			require.ErrorIs(err, ErrUnrequestedResponse)
   447  			err = network.CrossChainAppResponse(ctx, ids.Empty, 0, []byte("foobar"))
   448  			require.ErrorIs(err, ErrUnrequestedResponse)
   449  			err = network.CrossChainAppRequestFailed(ctx, ids.Empty, 0, common.ErrTimeout)
   450  
   451  			require.ErrorIs(err, ErrUnrequestedResponse)
   452  		})
   453  	}
   454  }
   455  
   456  // It's possible for the request id to overflow and wrap around.
   457  // If there are still pending requests with the same request id, we should
   458  // not attempt to issue another request until the previous one has cleared.
   459  func TestAppRequestDuplicateRequestIDs(t *testing.T) {
   460  	require := require.New(t)
   461  	ctx := context.Background()
   462  
   463  	sender := &common.FakeSender{
   464  		SentAppRequest: make(chan []byte, 1),
   465  	}
   466  
   467  	network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "")
   468  	require.NoError(err)
   469  	client := network.NewClient(0x1)
   470  
   471  	noOpCallback := func(context.Context, ids.NodeID, []byte, error) {}
   472  	// create a request that never gets a response
   473  	network.router.requestID = 1
   474  	require.NoError(client.AppRequest(ctx, set.Of(ids.EmptyNodeID), []byte{}, noOpCallback))
   475  	<-sender.SentAppRequest
   476  
   477  	// force the network to use the same requestID
   478  	network.router.requestID = 1
   479  	err = client.AppRequest(context.Background(), set.Of(ids.EmptyNodeID), []byte{}, noOpCallback)
   480  	require.ErrorIs(err, ErrRequestPending)
   481  }
   482  
   483  // Sample should always return up to [limit] peers, and less if fewer than
   484  // [limit] peers are available.
   485  func TestPeersSample(t *testing.T) {
   486  	nodeID1 := ids.GenerateTestNodeID()
   487  	nodeID2 := ids.GenerateTestNodeID()
   488  	nodeID3 := ids.GenerateTestNodeID()
   489  
   490  	tests := []struct {
   491  		name         string
   492  		connected    set.Set[ids.NodeID]
   493  		disconnected set.Set[ids.NodeID]
   494  		limit        int
   495  	}{
   496  		{
   497  			name:  "no peers",
   498  			limit: 1,
   499  		},
   500  		{
   501  			name:      "one peer connected",
   502  			connected: set.Of(nodeID1),
   503  			limit:     1,
   504  		},
   505  		{
   506  			name:      "multiple peers connected",
   507  			connected: set.Of(nodeID1, nodeID2, nodeID3),
   508  			limit:     1,
   509  		},
   510  		{
   511  			name:         "peer connects and disconnects - 1",
   512  			connected:    set.Of(nodeID1),
   513  			disconnected: set.Of(nodeID1),
   514  			limit:        1,
   515  		},
   516  		{
   517  			name:         "peer connects and disconnects - 2",
   518  			connected:    set.Of(nodeID1, nodeID2),
   519  			disconnected: set.Of(nodeID2),
   520  			limit:        1,
   521  		},
   522  		{
   523  			name:         "peer connects and disconnects - 2",
   524  			connected:    set.Of(nodeID1, nodeID2, nodeID3),
   525  			disconnected: set.Of(nodeID1, nodeID2),
   526  			limit:        1,
   527  		},
   528  		{
   529  			name:      "less than limit peers",
   530  			connected: set.Of(nodeID1, nodeID2, nodeID3),
   531  			limit:     4,
   532  		},
   533  		{
   534  			name:      "limit peers",
   535  			connected: set.Of(nodeID1, nodeID2, nodeID3),
   536  			limit:     3,
   537  		},
   538  		{
   539  			name:      "more than limit peers",
   540  			connected: set.Of(nodeID1, nodeID2, nodeID3),
   541  			limit:     2,
   542  		},
   543  	}
   544  
   545  	for _, tt := range tests {
   546  		t.Run(tt.name, func(t *testing.T) {
   547  			require := require.New(t)
   548  
   549  			network, err := NewNetwork(logging.NoLog{}, &common.FakeSender{}, prometheus.NewRegistry(), "")
   550  			require.NoError(err)
   551  
   552  			for connected := range tt.connected {
   553  				require.NoError(network.Connected(context.Background(), connected, nil))
   554  			}
   555  
   556  			for disconnected := range tt.disconnected {
   557  				require.NoError(network.Disconnected(context.Background(), disconnected))
   558  			}
   559  
   560  			sampleable := set.Set[ids.NodeID]{}
   561  			sampleable.Union(tt.connected)
   562  			sampleable.Difference(tt.disconnected)
   563  
   564  			sampled := network.Peers.Sample(tt.limit)
   565  			require.Len(sampled, min(tt.limit, len(sampleable)))
   566  			require.Subset(sampleable, sampled)
   567  		})
   568  	}
   569  }
   570  
   571  func TestAppRequestAnyNodeSelection(t *testing.T) {
   572  	tests := []struct {
   573  		name     string
   574  		peers    []ids.NodeID
   575  		expected error
   576  	}{
   577  		{
   578  			name:     "no peers",
   579  			expected: ErrNoPeers,
   580  		},
   581  		{
   582  			name:  "has peers",
   583  			peers: []ids.NodeID{ids.GenerateTestNodeID()},
   584  		},
   585  	}
   586  
   587  	for _, tt := range tests {
   588  		t.Run(tt.name, func(t *testing.T) {
   589  			require := require.New(t)
   590  
   591  			sent := set.Set[ids.NodeID]{}
   592  			sender := &common.SenderTest{
   593  				SendAppRequestF: func(_ context.Context, nodeIDs set.Set[ids.NodeID], _ uint32, _ []byte) error {
   594  					sent = nodeIDs
   595  					return nil
   596  				},
   597  			}
   598  
   599  			n, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "")
   600  			require.NoError(err)
   601  			for _, peer := range tt.peers {
   602  				require.NoError(n.Connected(context.Background(), peer, &version.Application{}))
   603  			}
   604  
   605  			client := n.NewClient(1)
   606  
   607  			err = client.AppRequestAny(context.Background(), []byte("foobar"), nil)
   608  			require.ErrorIs(err, tt.expected)
   609  			require.Subset(tt.peers, sent.List())
   610  		})
   611  	}
   612  }
   613  
   614  func TestNodeSamplerClientOption(t *testing.T) {
   615  	nodeID0 := ids.GenerateTestNodeID()
   616  	nodeID1 := ids.GenerateTestNodeID()
   617  	nodeID2 := ids.GenerateTestNodeID()
   618  
   619  	tests := []struct {
   620  		name        string
   621  		peers       []ids.NodeID
   622  		option      func(t *testing.T, n *Network) ClientOption
   623  		expected    []ids.NodeID
   624  		expectedErr error
   625  	}{
   626  		{
   627  			name:  "default",
   628  			peers: []ids.NodeID{nodeID0, nodeID1, nodeID2},
   629  			option: func(*testing.T, *Network) ClientOption {
   630  				return clientOptionFunc(func(*clientOptions) {})
   631  			},
   632  			expected: []ids.NodeID{nodeID0, nodeID1, nodeID2},
   633  		},
   634  		{
   635  			name:  "validator connected",
   636  			peers: []ids.NodeID{nodeID0, nodeID1},
   637  			option: func(_ *testing.T, n *Network) ClientOption {
   638  				state := &validators.TestState{
   639  					GetCurrentHeightF: func(context.Context) (uint64, error) {
   640  						return 0, nil
   641  					},
   642  					GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
   643  						return map[ids.NodeID]*validators.GetValidatorOutput{
   644  							nodeID1: {
   645  								NodeID: nodeID1,
   646  								Weight: 1,
   647  							},
   648  						}, nil
   649  					},
   650  				}
   651  
   652  				validators := NewValidators(n.Peers, n.log, ids.Empty, state, 0)
   653  				return WithValidatorSampling(validators)
   654  			},
   655  			expected: []ids.NodeID{nodeID1},
   656  		},
   657  		{
   658  			name:  "validator disconnected",
   659  			peers: []ids.NodeID{nodeID0},
   660  			option: func(_ *testing.T, n *Network) ClientOption {
   661  				state := &validators.TestState{
   662  					GetCurrentHeightF: func(context.Context) (uint64, error) {
   663  						return 0, nil
   664  					},
   665  					GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) {
   666  						return map[ids.NodeID]*validators.GetValidatorOutput{
   667  							nodeID1: {
   668  								NodeID: nodeID1,
   669  								Weight: 1,
   670  							},
   671  						}, nil
   672  					},
   673  				}
   674  
   675  				validators := NewValidators(n.Peers, n.log, ids.Empty, state, 0)
   676  				return WithValidatorSampling(validators)
   677  			},
   678  			expectedErr: ErrNoPeers,
   679  		},
   680  	}
   681  
   682  	for _, tt := range tests {
   683  		t.Run(tt.name, func(t *testing.T) {
   684  			require := require.New(t)
   685  
   686  			done := make(chan struct{})
   687  			sender := &common.SenderTest{
   688  				SendAppRequestF: func(_ context.Context, nodeIDs set.Set[ids.NodeID], _ uint32, _ []byte) error {
   689  					require.Subset(tt.expected, nodeIDs.List())
   690  					close(done)
   691  					return nil
   692  				},
   693  			}
   694  			network, err := NewNetwork(logging.NoLog{}, sender, prometheus.NewRegistry(), "")
   695  			require.NoError(err)
   696  			ctx := context.Background()
   697  			for _, peer := range tt.peers {
   698  				require.NoError(network.Connected(ctx, peer, nil))
   699  			}
   700  
   701  			client := network.NewClient(0, tt.option(t, network))
   702  
   703  			if err = client.AppRequestAny(ctx, []byte("request"), nil); err != nil {
   704  				close(done)
   705  			}
   706  
   707  			require.ErrorIs(err, tt.expectedErr)
   708  			<-done
   709  		})
   710  	}
   711  }
   712  
   713  // Tests that a given protocol can have more than one client
   714  func TestMultipleClients(t *testing.T) {
   715  	require := require.New(t)
   716  
   717  	n, err := NewNetwork(logging.NoLog{}, &common.SenderTest{}, prometheus.NewRegistry(), "")
   718  	require.NoError(err)
   719  	_ = n.NewClient(0)
   720  	_ = n.NewClient(0)
   721  }