github.com/MetalBlockchain/metalgo@v1.11.9/network/network_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package network
     5  
     6  import (
     7  	"context"
     8  	"crypto"
     9  	"net/netip"
    10  	"sync"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/prometheus/client_golang/prometheus"
    15  	"github.com/stretchr/testify/require"
    16  
    17  	"github.com/MetalBlockchain/metalgo/ids"
    18  	"github.com/MetalBlockchain/metalgo/message"
    19  	"github.com/MetalBlockchain/metalgo/network/dialer"
    20  	"github.com/MetalBlockchain/metalgo/network/peer"
    21  	"github.com/MetalBlockchain/metalgo/network/throttling"
    22  	"github.com/MetalBlockchain/metalgo/snow/engine/common"
    23  	"github.com/MetalBlockchain/metalgo/snow/networking/router"
    24  	"github.com/MetalBlockchain/metalgo/snow/networking/tracker"
    25  	"github.com/MetalBlockchain/metalgo/snow/uptime"
    26  	"github.com/MetalBlockchain/metalgo/snow/validators"
    27  	"github.com/MetalBlockchain/metalgo/staking"
    28  	"github.com/MetalBlockchain/metalgo/subnets"
    29  	"github.com/MetalBlockchain/metalgo/utils"
    30  	"github.com/MetalBlockchain/metalgo/utils/constants"
    31  	"github.com/MetalBlockchain/metalgo/utils/crypto/bls"
    32  	"github.com/MetalBlockchain/metalgo/utils/ips"
    33  	"github.com/MetalBlockchain/metalgo/utils/logging"
    34  	"github.com/MetalBlockchain/metalgo/utils/math/meter"
    35  	"github.com/MetalBlockchain/metalgo/utils/resource"
    36  	"github.com/MetalBlockchain/metalgo/utils/set"
    37  	"github.com/MetalBlockchain/metalgo/utils/timer/mockable"
    38  	"github.com/MetalBlockchain/metalgo/utils/units"
    39  	"github.com/MetalBlockchain/metalgo/version"
    40  )
    41  
    42  var (
    43  	defaultHealthConfig = HealthConfig{
    44  		MinConnectedPeers:            1,
    45  		MaxTimeSinceMsgReceived:      time.Minute,
    46  		MaxTimeSinceMsgSent:          time.Minute,
    47  		MaxPortionSendQueueBytesFull: .9,
    48  		MaxSendFailRate:              .1,
    49  		SendFailRateHalflife:         time.Second,
    50  	}
    51  	defaultPeerListGossipConfig = PeerListGossipConfig{
    52  		PeerListNumValidatorIPs: 100,
    53  		PeerListPullGossipFreq:  time.Second,
    54  		PeerListBloomResetFreq:  constants.DefaultNetworkPeerListBloomResetFreq,
    55  	}
    56  	defaultTimeoutConfig = TimeoutConfig{
    57  		PingPongTimeout:      30 * time.Second,
    58  		ReadHandshakeTimeout: 15 * time.Second,
    59  	}
    60  	defaultDelayConfig = DelayConfig{
    61  		MaxReconnectDelay:     time.Hour,
    62  		InitialReconnectDelay: time.Second,
    63  	}
    64  	defaultThrottlerConfig = ThrottlerConfig{
    65  		InboundConnUpgradeThrottlerConfig: throttling.InboundConnUpgradeThrottlerConfig{
    66  			UpgradeCooldown:        time.Second,
    67  			MaxRecentConnsUpgraded: 100,
    68  		},
    69  		InboundMsgThrottlerConfig: throttling.InboundMsgThrottlerConfig{
    70  			MsgByteThrottlerConfig: throttling.MsgByteThrottlerConfig{
    71  				VdrAllocSize:        1 * units.GiB,
    72  				AtLargeAllocSize:    1 * units.GiB,
    73  				NodeMaxAtLargeBytes: constants.DefaultMaxMessageSize,
    74  			},
    75  			BandwidthThrottlerConfig: throttling.BandwidthThrottlerConfig{
    76  				RefillRate:   units.MiB,
    77  				MaxBurstSize: constants.DefaultMaxMessageSize,
    78  			},
    79  			CPUThrottlerConfig: throttling.SystemThrottlerConfig{
    80  				MaxRecheckDelay: 50 * time.Millisecond,
    81  			},
    82  			MaxProcessingMsgsPerNode: 100,
    83  			DiskThrottlerConfig: throttling.SystemThrottlerConfig{
    84  				MaxRecheckDelay: 50 * time.Millisecond,
    85  			},
    86  		},
    87  		OutboundMsgThrottlerConfig: throttling.MsgByteThrottlerConfig{
    88  			VdrAllocSize:        1 * units.GiB,
    89  			AtLargeAllocSize:    1 * units.GiB,
    90  			NodeMaxAtLargeBytes: constants.DefaultMaxMessageSize,
    91  		},
    92  		MaxInboundConnsPerSec: 100,
    93  	}
    94  	defaultDialerConfig = dialer.Config{
    95  		ThrottleRps:       100,
    96  		ConnectionTimeout: time.Second,
    97  	}
    98  
    99  	defaultConfig = Config{
   100  		HealthConfig:         defaultHealthConfig,
   101  		PeerListGossipConfig: defaultPeerListGossipConfig,
   102  		TimeoutConfig:        defaultTimeoutConfig,
   103  		DelayConfig:          defaultDelayConfig,
   104  		ThrottlerConfig:      defaultThrottlerConfig,
   105  
   106  		DialerConfig: defaultDialerConfig,
   107  
   108  		NetworkID:          49463,
   109  		MaxClockDifference: time.Minute,
   110  		PingFrequency:      constants.DefaultPingFrequency,
   111  		AllowPrivateIPs:    true,
   112  
   113  		CompressionType: constants.DefaultNetworkCompressionType,
   114  
   115  		UptimeCalculator:  uptime.NewManager(uptime.NewTestState(), &mockable.Clock{}),
   116  		UptimeMetricFreq:  30 * time.Second,
   117  		UptimeRequirement: .8,
   118  
   119  		RequireValidatorToConnect: false,
   120  
   121  		MaximumInboundMessageTimeout: 30 * time.Second,
   122  		ResourceTracker:              newDefaultResourceTracker(),
   123  		CPUTargeter:                  nil, // Set in init
   124  		DiskTargeter:                 nil, // Set in init
   125  	}
   126  )
   127  
   128  func init() {
   129  	defaultConfig.CPUTargeter = newDefaultTargeter(defaultConfig.ResourceTracker.CPUTracker())
   130  	defaultConfig.DiskTargeter = newDefaultTargeter(defaultConfig.ResourceTracker.DiskTracker())
   131  }
   132  
   133  func newDefaultTargeter(t tracker.Tracker) tracker.Targeter {
   134  	return tracker.NewTargeter(
   135  		logging.NoLog{},
   136  		&tracker.TargeterConfig{
   137  			VdrAlloc:           10,
   138  			MaxNonVdrUsage:     10,
   139  			MaxNonVdrNodeUsage: 10,
   140  		},
   141  		validators.NewManager(),
   142  		t,
   143  	)
   144  }
   145  
   146  func newDefaultResourceTracker() tracker.ResourceTracker {
   147  	tracker, err := tracker.NewResourceTracker(
   148  		prometheus.NewRegistry(),
   149  		resource.NoUsage,
   150  		meter.ContinuousFactory{},
   151  		10*time.Second,
   152  	)
   153  	if err != nil {
   154  		panic(err)
   155  	}
   156  	return tracker
   157  }
   158  
   159  func newTestNetwork(t *testing.T, count int) (*testDialer, []*testListener, []ids.NodeID, []*Config) {
   160  	var (
   161  		dialer    = newTestDialer()
   162  		listeners = make([]*testListener, count)
   163  		nodeIDs   = make([]ids.NodeID, count)
   164  		configs   = make([]*Config, count)
   165  	)
   166  	for i := 0; i < count; i++ {
   167  		ip, listener := dialer.NewListener()
   168  
   169  		tlsCert, err := staking.NewTLSCert()
   170  		require.NoError(t, err)
   171  
   172  		cert, err := staking.ParseCertificate(tlsCert.Leaf.Raw)
   173  		require.NoError(t, err)
   174  		nodeID := ids.NodeIDFromCert(cert)
   175  
   176  		blsKey, err := bls.NewSecretKey()
   177  		require.NoError(t, err)
   178  
   179  		config := defaultConfig
   180  		config.TLSConfig = peer.TLSConfig(*tlsCert, nil)
   181  		config.MyNodeID = nodeID
   182  		config.MyIPPort = utils.NewAtomic(ip)
   183  		config.TLSKey = tlsCert.PrivateKey.(crypto.Signer)
   184  		config.BLSKey = blsKey
   185  
   186  		listeners[i] = listener
   187  		nodeIDs[i] = nodeID
   188  		configs[i] = &config
   189  	}
   190  	return dialer, listeners, nodeIDs, configs
   191  }
   192  
   193  func newMessageCreator(t *testing.T) message.Creator {
   194  	t.Helper()
   195  
   196  	mc, err := message.NewCreator(
   197  		logging.NoLog{},
   198  		prometheus.NewRegistry(),
   199  		constants.DefaultNetworkCompressionType,
   200  		10*time.Second,
   201  	)
   202  	require.NoError(t, err)
   203  
   204  	return mc
   205  }
   206  
   207  func newFullyConnectedTestNetwork(t *testing.T, handlers []router.InboundHandler) ([]ids.NodeID, []*network, *sync.WaitGroup) {
   208  	require := require.New(t)
   209  
   210  	dialer, listeners, nodeIDs, configs := newTestNetwork(t, len(handlers))
   211  
   212  	var (
   213  		networks = make([]*network, len(configs))
   214  
   215  		globalLock     sync.Mutex
   216  		numConnected   int
   217  		allConnected   bool
   218  		onAllConnected = make(chan struct{})
   219  	)
   220  	for i, config := range configs {
   221  		msgCreator := newMessageCreator(t)
   222  		registry := prometheus.NewRegistry()
   223  
   224  		beacons := validators.NewManager()
   225  		require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1))
   226  
   227  		vdrs := validators.NewManager()
   228  		for _, nodeID := range nodeIDs {
   229  			require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.GenerateTestID(), 1))
   230  		}
   231  
   232  		config := config
   233  
   234  		config.Beacons = beacons
   235  		config.Validators = vdrs
   236  
   237  		var connected set.Set[ids.NodeID]
   238  		net, err := NewNetwork(
   239  			config,
   240  			msgCreator,
   241  			registry,
   242  			logging.NoLog{},
   243  			listeners[i],
   244  			dialer,
   245  			&testHandler{
   246  				InboundHandler: handlers[i],
   247  				ConnectedF: func(nodeID ids.NodeID, _ *version.Application, _ ids.ID) {
   248  					t.Logf("%s connected to %s", config.MyNodeID, nodeID)
   249  
   250  					globalLock.Lock()
   251  					defer globalLock.Unlock()
   252  
   253  					require.False(connected.Contains(nodeID))
   254  					connected.Add(nodeID)
   255  					numConnected++
   256  
   257  					if !allConnected && numConnected == len(nodeIDs)*(len(nodeIDs)-1) {
   258  						allConnected = true
   259  						close(onAllConnected)
   260  					}
   261  				},
   262  				DisconnectedF: func(nodeID ids.NodeID) {
   263  					t.Logf("%s disconnected from %s", config.MyNodeID, nodeID)
   264  
   265  					globalLock.Lock()
   266  					defer globalLock.Unlock()
   267  
   268  					require.True(connected.Contains(nodeID))
   269  					connected.Remove(nodeID)
   270  					numConnected--
   271  				},
   272  			},
   273  		)
   274  		require.NoError(err)
   275  		networks[i] = net.(*network)
   276  	}
   277  
   278  	wg := sync.WaitGroup{}
   279  	wg.Add(len(networks))
   280  	for i, net := range networks {
   281  		if i != 0 {
   282  			config := configs[0]
   283  			net.ManuallyTrack(config.MyNodeID, config.MyIPPort.Get())
   284  		}
   285  
   286  		go func(net Network) {
   287  			defer wg.Done()
   288  
   289  			require.NoError(net.Dispatch())
   290  		}(net)
   291  	}
   292  
   293  	if len(networks) > 1 {
   294  		<-onAllConnected
   295  	}
   296  
   297  	return nodeIDs, networks, &wg
   298  }
   299  
   300  func TestNewNetwork(t *testing.T) {
   301  	_, networks, wg := newFullyConnectedTestNetwork(t, []router.InboundHandler{nil, nil, nil})
   302  	for _, net := range networks {
   303  		net.StartClose()
   304  	}
   305  	wg.Wait()
   306  }
   307  
   308  func TestSend(t *testing.T) {
   309  	require := require.New(t)
   310  
   311  	received := make(chan message.InboundMessage)
   312  	nodeIDs, networks, wg := newFullyConnectedTestNetwork(
   313  		t,
   314  		[]router.InboundHandler{
   315  			router.InboundHandlerFunc(func(context.Context, message.InboundMessage) {
   316  				require.FailNow("unexpected message received")
   317  			}),
   318  			router.InboundHandlerFunc(func(_ context.Context, msg message.InboundMessage) {
   319  				received <- msg
   320  			}),
   321  			router.InboundHandlerFunc(func(context.Context, message.InboundMessage) {
   322  				require.FailNow("unexpected message received")
   323  			}),
   324  		},
   325  	)
   326  
   327  	net0 := networks[0]
   328  
   329  	mc := newMessageCreator(t)
   330  	outboundGetMsg, err := mc.Get(ids.Empty, 1, time.Second, ids.Empty)
   331  	require.NoError(err)
   332  
   333  	toSend := set.Of(nodeIDs[1])
   334  	sentTo := net0.Send(
   335  		outboundGetMsg,
   336  		common.SendConfig{
   337  			NodeIDs: toSend,
   338  		},
   339  		constants.PrimaryNetworkID,
   340  		subnets.NoOpAllower,
   341  	)
   342  	require.Equal(toSend, sentTo)
   343  
   344  	inboundGetMsg := <-received
   345  	require.Equal(message.GetOp, inboundGetMsg.Op())
   346  
   347  	for _, net := range networks {
   348  		net.StartClose()
   349  	}
   350  	wg.Wait()
   351  }
   352  
   353  func TestSendWithFilter(t *testing.T) {
   354  	require := require.New(t)
   355  
   356  	received := make(chan message.InboundMessage)
   357  	nodeIDs, networks, wg := newFullyConnectedTestNetwork(
   358  		t,
   359  		[]router.InboundHandler{
   360  			router.InboundHandlerFunc(func(context.Context, message.InboundMessage) {
   361  				require.FailNow("unexpected message received")
   362  			}),
   363  			router.InboundHandlerFunc(func(_ context.Context, msg message.InboundMessage) {
   364  				received <- msg
   365  			}),
   366  			router.InboundHandlerFunc(func(context.Context, message.InboundMessage) {
   367  				require.FailNow("unexpected message received")
   368  			}),
   369  		},
   370  	)
   371  
   372  	net0 := networks[0]
   373  
   374  	mc := newMessageCreator(t)
   375  	outboundGetMsg, err := mc.Get(ids.Empty, 1, time.Second, ids.Empty)
   376  	require.NoError(err)
   377  
   378  	toSend := set.Of(nodeIDs...)
   379  	validNodeID := nodeIDs[1]
   380  	sentTo := net0.Send(
   381  		outboundGetMsg,
   382  		common.SendConfig{
   383  			NodeIDs: toSend,
   384  		},
   385  		constants.PrimaryNetworkID,
   386  		newNodeIDConnector(validNodeID),
   387  	)
   388  	require.Len(sentTo, 1)
   389  	require.Contains(sentTo, validNodeID)
   390  
   391  	inboundGetMsg := <-received
   392  	require.Equal(message.GetOp, inboundGetMsg.Op())
   393  
   394  	for _, net := range networks {
   395  		net.StartClose()
   396  	}
   397  	wg.Wait()
   398  }
   399  
   400  func TestTrackVerifiesSignatures(t *testing.T) {
   401  	require := require.New(t)
   402  
   403  	_, networks, wg := newFullyConnectedTestNetwork(t, []router.InboundHandler{nil})
   404  
   405  	network := networks[0]
   406  
   407  	tlsCert, err := staking.NewTLSCert()
   408  	require.NoError(err)
   409  
   410  	cert, err := staking.ParseCertificate(tlsCert.Leaf.Raw)
   411  	require.NoError(err)
   412  	nodeID := ids.NodeIDFromCert(cert)
   413  
   414  	require.NoError(network.config.Validators.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.Empty, 1))
   415  
   416  	stakingCert, err := staking.ParseCertificate(tlsCert.Leaf.Raw)
   417  	require.NoError(err)
   418  
   419  	err = network.Track([]*ips.ClaimedIPPort{
   420  		ips.NewClaimedIPPort(
   421  			stakingCert,
   422  			netip.AddrPortFrom(
   423  				netip.AddrFrom4([4]byte{123, 132, 123, 123}),
   424  				10000,
   425  			),
   426  			1000, // timestamp
   427  			nil,  // signature
   428  		),
   429  	})
   430  	// The signature is wrong so this peer tracking info isn't useful.
   431  	require.ErrorIs(err, staking.ErrECDSAVerificationFailure)
   432  
   433  	network.peersLock.RLock()
   434  	require.Empty(network.trackedIPs)
   435  	network.peersLock.RUnlock()
   436  
   437  	for _, net := range networks {
   438  		net.StartClose()
   439  	}
   440  	wg.Wait()
   441  }
   442  
   443  func TestTrackDoesNotDialPrivateIPs(t *testing.T) {
   444  	require := require.New(t)
   445  
   446  	dialer, listeners, nodeIDs, configs := newTestNetwork(t, 2)
   447  
   448  	networks := make([]Network, len(configs))
   449  	for i, config := range configs {
   450  		msgCreator := newMessageCreator(t)
   451  		registry := prometheus.NewRegistry()
   452  
   453  		beacons := validators.NewManager()
   454  		require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1))
   455  
   456  		vdrs := validators.NewManager()
   457  		for _, nodeID := range nodeIDs {
   458  			require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.GenerateTestID(), 1))
   459  		}
   460  
   461  		config := config
   462  
   463  		config.Beacons = beacons
   464  		config.Validators = vdrs
   465  		config.AllowPrivateIPs = false
   466  
   467  		net, err := NewNetwork(
   468  			config,
   469  			msgCreator,
   470  			registry,
   471  			logging.NoLog{},
   472  			listeners[i],
   473  			dialer,
   474  			&testHandler{
   475  				InboundHandler: nil,
   476  				ConnectedF: func(ids.NodeID, *version.Application, ids.ID) {
   477  					require.FailNow("unexpectedly connected to a peer")
   478  				},
   479  				DisconnectedF: nil,
   480  			},
   481  		)
   482  		require.NoError(err)
   483  		networks[i] = net
   484  	}
   485  
   486  	wg := sync.WaitGroup{}
   487  	wg.Add(len(networks))
   488  	for i, net := range networks {
   489  		if i != 0 {
   490  			config := configs[0]
   491  			net.ManuallyTrack(config.MyNodeID, config.MyIPPort.Get())
   492  		}
   493  
   494  		go func(net Network) {
   495  			defer wg.Done()
   496  
   497  			require.NoError(net.Dispatch())
   498  		}(net)
   499  	}
   500  
   501  	network := networks[1].(*network)
   502  	require.Eventually(
   503  		func() bool {
   504  			network.peersLock.RLock()
   505  			defer network.peersLock.RUnlock()
   506  
   507  			nodeID := nodeIDs[0]
   508  			require.Contains(network.trackedIPs, nodeID)
   509  			ip := network.trackedIPs[nodeID]
   510  			return ip.getDelay() != 0
   511  		},
   512  		10*time.Second,
   513  		50*time.Millisecond,
   514  	)
   515  
   516  	for _, net := range networks {
   517  		net.StartClose()
   518  	}
   519  	wg.Wait()
   520  }
   521  
   522  func TestDialDeletesNonValidators(t *testing.T) {
   523  	require := require.New(t)
   524  
   525  	dialer, listeners, nodeIDs, configs := newTestNetwork(t, 2)
   526  
   527  	vdrs := validators.NewManager()
   528  	for _, nodeID := range nodeIDs {
   529  		require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.GenerateTestID(), 1))
   530  	}
   531  
   532  	networks := make([]Network, len(configs))
   533  	for i, config := range configs {
   534  		msgCreator := newMessageCreator(t)
   535  		registry := prometheus.NewRegistry()
   536  
   537  		beacons := validators.NewManager()
   538  		require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1))
   539  
   540  		config := config
   541  
   542  		config.Beacons = beacons
   543  		config.Validators = vdrs
   544  		config.AllowPrivateIPs = false
   545  
   546  		net, err := NewNetwork(
   547  			config,
   548  			msgCreator,
   549  			registry,
   550  			logging.NoLog{},
   551  			listeners[i],
   552  			dialer,
   553  			&testHandler{
   554  				InboundHandler: nil,
   555  				ConnectedF: func(ids.NodeID, *version.Application, ids.ID) {
   556  					require.FailNow("unexpectedly connected to a peer")
   557  				},
   558  				DisconnectedF: nil,
   559  			},
   560  		)
   561  		require.NoError(err)
   562  		networks[i] = net
   563  	}
   564  
   565  	config := configs[0]
   566  	signer := peer.NewIPSigner(config.MyIPPort, config.TLSKey, config.BLSKey)
   567  	ip, err := signer.GetSignedIP()
   568  	require.NoError(err)
   569  
   570  	wg := sync.WaitGroup{}
   571  	wg.Add(len(networks))
   572  	for i, net := range networks {
   573  		if i != 0 {
   574  			stakingCert, err := staking.ParseCertificate(config.TLSConfig.Certificates[0].Leaf.Raw)
   575  			require.NoError(err)
   576  
   577  			require.NoError(net.Track([]*ips.ClaimedIPPort{
   578  				ips.NewClaimedIPPort(
   579  					stakingCert,
   580  					ip.AddrPort,
   581  					ip.Timestamp,
   582  					ip.TLSSignature,
   583  				),
   584  			}))
   585  		}
   586  
   587  		go func(net Network) {
   588  			defer wg.Done()
   589  
   590  			require.NoError(net.Dispatch())
   591  		}(net)
   592  	}
   593  
   594  	// Give the dialer time to run one iteration. This is racy, but should ony
   595  	// be possible to flake as a false negative (test passes when it shouldn't).
   596  	time.Sleep(50 * time.Millisecond)
   597  
   598  	network := networks[1].(*network)
   599  	require.NoError(vdrs.RemoveWeight(constants.PrimaryNetworkID, nodeIDs[0], 1))
   600  	require.Eventually(
   601  		func() bool {
   602  			network.peersLock.RLock()
   603  			defer network.peersLock.RUnlock()
   604  
   605  			nodeID := nodeIDs[0]
   606  			_, ok := network.trackedIPs[nodeID]
   607  			return !ok
   608  		},
   609  		10*time.Second,
   610  		50*time.Millisecond,
   611  	)
   612  
   613  	for _, net := range networks {
   614  		net.StartClose()
   615  	}
   616  	wg.Wait()
   617  }
   618  
   619  // Test that cancelling the context passed into dial
   620  // causes dial to return immediately.
   621  func TestDialContext(t *testing.T) {
   622  	_, networks, wg := newFullyConnectedTestNetwork(t, []router.InboundHandler{nil})
   623  
   624  	dialer := newTestDialer()
   625  	network := networks[0]
   626  	network.dialer = dialer
   627  
   628  	var (
   629  		neverDialedNodeID = ids.GenerateTestNodeID()
   630  		dialedNodeID      = ids.GenerateTestNodeID()
   631  
   632  		neverDialedIP, neverDialedListener = dialer.NewListener()
   633  		dialedIP, dialedListener           = dialer.NewListener()
   634  
   635  		neverDialedTrackedIP = &trackedIP{
   636  			ip: neverDialedIP,
   637  		}
   638  		dialedTrackedIP = &trackedIP{
   639  			ip: dialedIP,
   640  		}
   641  	)
   642  
   643  	network.ManuallyTrack(neverDialedNodeID, neverDialedIP)
   644  	network.ManuallyTrack(dialedNodeID, dialedIP)
   645  
   646  	// Sanity check that when a non-cancelled context is given,
   647  	// we actually dial the peer.
   648  	network.dial(dialedNodeID, dialedTrackedIP)
   649  
   650  	gotDialedIPConn := make(chan struct{})
   651  	go func() {
   652  		_, _ = dialedListener.Accept()
   653  		close(gotDialedIPConn)
   654  	}()
   655  	<-gotDialedIPConn
   656  
   657  	// Asset that when [n.onCloseCtx] is cancelled, dial returns immediately.
   658  	// That is, [neverDialedListener] doesn't accept a connection.
   659  	network.onCloseCtxCancel()
   660  	network.dial(neverDialedNodeID, neverDialedTrackedIP)
   661  
   662  	gotNeverDialedIPConn := make(chan struct{})
   663  	go func() {
   664  		_, _ = neverDialedListener.Accept()
   665  		close(gotNeverDialedIPConn)
   666  	}()
   667  
   668  	select {
   669  	case <-gotNeverDialedIPConn:
   670  		require.FailNow(t, "unexpectedly connected to peer")
   671  	default:
   672  	}
   673  
   674  	network.StartClose()
   675  	wg.Wait()
   676  }
   677  
   678  func TestAllowConnectionAsAValidator(t *testing.T) {
   679  	require := require.New(t)
   680  
   681  	dialer, listeners, nodeIDs, configs := newTestNetwork(t, 2)
   682  
   683  	networks := make([]Network, len(configs))
   684  	for i, config := range configs {
   685  		msgCreator := newMessageCreator(t)
   686  		registry := prometheus.NewRegistry()
   687  
   688  		beacons := validators.NewManager()
   689  		require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1))
   690  
   691  		vdrs := validators.NewManager()
   692  		require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1))
   693  
   694  		config := config
   695  
   696  		config.Beacons = beacons
   697  		config.Validators = vdrs
   698  		config.RequireValidatorToConnect = true
   699  
   700  		net, err := NewNetwork(
   701  			config,
   702  			msgCreator,
   703  			registry,
   704  			logging.NoLog{},
   705  			listeners[i],
   706  			dialer,
   707  			&testHandler{
   708  				InboundHandler: nil,
   709  				ConnectedF:     nil,
   710  				DisconnectedF:  nil,
   711  			},
   712  		)
   713  		require.NoError(err)
   714  		networks[i] = net
   715  	}
   716  
   717  	wg := sync.WaitGroup{}
   718  	wg.Add(len(networks))
   719  	for i, net := range networks {
   720  		if i != 0 {
   721  			config := configs[0]
   722  			net.ManuallyTrack(config.MyNodeID, config.MyIPPort.Get())
   723  		}
   724  
   725  		go func(net Network) {
   726  			defer wg.Done()
   727  
   728  			require.NoError(net.Dispatch())
   729  		}(net)
   730  	}
   731  
   732  	network := networks[1].(*network)
   733  	require.Eventually(
   734  		func() bool {
   735  			network.peersLock.RLock()
   736  			defer network.peersLock.RUnlock()
   737  
   738  			nodeID := nodeIDs[0]
   739  			_, contains := network.connectedPeers.GetByID(nodeID)
   740  			return contains
   741  		},
   742  		10*time.Second,
   743  		50*time.Millisecond,
   744  	)
   745  
   746  	for _, net := range networks {
   747  		net.StartClose()
   748  	}
   749  	wg.Wait()
   750  }