github.com/ava-labs/avalanchego@v1.11.11/network/network_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package network
     5  
     6  import (
     7  	"context"
     8  	"crypto"
     9  	"net/netip"
    10  	"sync"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/prometheus/client_golang/prometheus"
    15  	"github.com/stretchr/testify/require"
    16  
    17  	"github.com/ava-labs/avalanchego/ids"
    18  	"github.com/ava-labs/avalanchego/message"
    19  	"github.com/ava-labs/avalanchego/network/dialer"
    20  	"github.com/ava-labs/avalanchego/network/peer"
    21  	"github.com/ava-labs/avalanchego/network/throttling"
    22  	"github.com/ava-labs/avalanchego/snow/engine/common"
    23  	"github.com/ava-labs/avalanchego/snow/networking/router"
    24  	"github.com/ava-labs/avalanchego/snow/networking/tracker"
    25  	"github.com/ava-labs/avalanchego/snow/uptime"
    26  	"github.com/ava-labs/avalanchego/snow/validators"
    27  	"github.com/ava-labs/avalanchego/staking"
    28  	"github.com/ava-labs/avalanchego/subnets"
    29  	"github.com/ava-labs/avalanchego/upgrade"
    30  	"github.com/ava-labs/avalanchego/utils"
    31  	"github.com/ava-labs/avalanchego/utils/constants"
    32  	"github.com/ava-labs/avalanchego/utils/crypto/bls"
    33  	"github.com/ava-labs/avalanchego/utils/ips"
    34  	"github.com/ava-labs/avalanchego/utils/logging"
    35  	"github.com/ava-labs/avalanchego/utils/math/meter"
    36  	"github.com/ava-labs/avalanchego/utils/resource"
    37  	"github.com/ava-labs/avalanchego/utils/set"
    38  	"github.com/ava-labs/avalanchego/utils/timer/mockable"
    39  	"github.com/ava-labs/avalanchego/utils/units"
    40  	"github.com/ava-labs/avalanchego/version"
    41  )
    42  
    43  var (
    44  	defaultHealthConfig = HealthConfig{
    45  		MinConnectedPeers:            1,
    46  		MaxTimeSinceMsgReceived:      time.Minute,
    47  		MaxTimeSinceMsgSent:          time.Minute,
    48  		MaxPortionSendQueueBytesFull: .9,
    49  		MaxSendFailRate:              .1,
    50  		SendFailRateHalflife:         time.Second,
    51  	}
    52  	defaultPeerListGossipConfig = PeerListGossipConfig{
    53  		PeerListNumValidatorIPs: 100,
    54  		PeerListPullGossipFreq:  time.Second,
    55  		PeerListBloomResetFreq:  constants.DefaultNetworkPeerListBloomResetFreq,
    56  	}
    57  	defaultTimeoutConfig = TimeoutConfig{
    58  		PingPongTimeout:      30 * time.Second,
    59  		ReadHandshakeTimeout: 15 * time.Second,
    60  	}
    61  	defaultDelayConfig = DelayConfig{
    62  		MaxReconnectDelay:     time.Hour,
    63  		InitialReconnectDelay: time.Second,
    64  	}
    65  	defaultThrottlerConfig = ThrottlerConfig{
    66  		InboundConnUpgradeThrottlerConfig: throttling.InboundConnUpgradeThrottlerConfig{
    67  			UpgradeCooldown:        time.Second,
    68  			MaxRecentConnsUpgraded: 100,
    69  		},
    70  		InboundMsgThrottlerConfig: throttling.InboundMsgThrottlerConfig{
    71  			MsgByteThrottlerConfig: throttling.MsgByteThrottlerConfig{
    72  				VdrAllocSize:        1 * units.GiB,
    73  				AtLargeAllocSize:    1 * units.GiB,
    74  				NodeMaxAtLargeBytes: constants.DefaultMaxMessageSize,
    75  			},
    76  			BandwidthThrottlerConfig: throttling.BandwidthThrottlerConfig{
    77  				RefillRate:   units.MiB,
    78  				MaxBurstSize: constants.DefaultMaxMessageSize,
    79  			},
    80  			CPUThrottlerConfig: throttling.SystemThrottlerConfig{
    81  				MaxRecheckDelay: 50 * time.Millisecond,
    82  			},
    83  			MaxProcessingMsgsPerNode: 100,
    84  			DiskThrottlerConfig: throttling.SystemThrottlerConfig{
    85  				MaxRecheckDelay: 50 * time.Millisecond,
    86  			},
    87  		},
    88  		OutboundMsgThrottlerConfig: throttling.MsgByteThrottlerConfig{
    89  			VdrAllocSize:        1 * units.GiB,
    90  			AtLargeAllocSize:    1 * units.GiB,
    91  			NodeMaxAtLargeBytes: constants.DefaultMaxMessageSize,
    92  		},
    93  		MaxInboundConnsPerSec: 100,
    94  	}
    95  	defaultDialerConfig = dialer.Config{
    96  		ThrottleRps:       100,
    97  		ConnectionTimeout: time.Second,
    98  	}
    99  
   100  	defaultConfig = Config{
   101  		HealthConfig:         defaultHealthConfig,
   102  		PeerListGossipConfig: defaultPeerListGossipConfig,
   103  		TimeoutConfig:        defaultTimeoutConfig,
   104  		DelayConfig:          defaultDelayConfig,
   105  		ThrottlerConfig:      defaultThrottlerConfig,
   106  
   107  		DialerConfig: defaultDialerConfig,
   108  
   109  		NetworkID:          49463,
   110  		MaxClockDifference: time.Minute,
   111  		PingFrequency:      constants.DefaultPingFrequency,
   112  		AllowPrivateIPs:    true,
   113  
   114  		CompressionType: constants.DefaultNetworkCompressionType,
   115  
   116  		UptimeCalculator:  uptime.NewManager(uptime.NewTestState(), &mockable.Clock{}),
   117  		UptimeMetricFreq:  30 * time.Second,
   118  		UptimeRequirement: .8,
   119  
   120  		RequireValidatorToConnect: false,
   121  
   122  		MaximumInboundMessageTimeout: 30 * time.Second,
   123  		ResourceTracker:              newDefaultResourceTracker(),
   124  		CPUTargeter:                  nil, // Set in init
   125  		DiskTargeter:                 nil, // Set in init
   126  	}
   127  )
   128  
   129  func init() {
   130  	defaultConfig.CPUTargeter = newDefaultTargeter(defaultConfig.ResourceTracker.CPUTracker())
   131  	defaultConfig.DiskTargeter = newDefaultTargeter(defaultConfig.ResourceTracker.DiskTracker())
   132  }
   133  
   134  func newDefaultTargeter(t tracker.Tracker) tracker.Targeter {
   135  	return tracker.NewTargeter(
   136  		logging.NoLog{},
   137  		&tracker.TargeterConfig{
   138  			VdrAlloc:           10,
   139  			MaxNonVdrUsage:     10,
   140  			MaxNonVdrNodeUsage: 10,
   141  		},
   142  		validators.NewManager(),
   143  		t,
   144  	)
   145  }
   146  
   147  func newDefaultResourceTracker() tracker.ResourceTracker {
   148  	tracker, err := tracker.NewResourceTracker(
   149  		prometheus.NewRegistry(),
   150  		resource.NoUsage,
   151  		meter.ContinuousFactory{},
   152  		10*time.Second,
   153  	)
   154  	if err != nil {
   155  		panic(err)
   156  	}
   157  	return tracker
   158  }
   159  
   160  func newTestNetwork(t *testing.T, count int) (*testDialer, []*testListener, []ids.NodeID, []*Config) {
   161  	var (
   162  		dialer    = newTestDialer()
   163  		listeners = make([]*testListener, count)
   164  		nodeIDs   = make([]ids.NodeID, count)
   165  		configs   = make([]*Config, count)
   166  	)
   167  	for i := 0; i < count; i++ {
   168  		ip, listener := dialer.NewListener()
   169  
   170  		tlsCert, err := staking.NewTLSCert()
   171  		require.NoError(t, err)
   172  
   173  		cert, err := staking.ParseCertificate(tlsCert.Leaf.Raw)
   174  		require.NoError(t, err)
   175  		nodeID := ids.NodeIDFromCert(cert)
   176  
   177  		blsKey, err := bls.NewSecretKey()
   178  		require.NoError(t, err)
   179  
   180  		config := defaultConfig
   181  		config.TLSConfig = peer.TLSConfig(*tlsCert, nil)
   182  		config.MyNodeID = nodeID
   183  		config.MyIPPort = utils.NewAtomic(ip)
   184  		config.TLSKey = tlsCert.PrivateKey.(crypto.Signer)
   185  		config.BLSKey = blsKey
   186  
   187  		listeners[i] = listener
   188  		nodeIDs[i] = nodeID
   189  		configs[i] = &config
   190  	}
   191  	return dialer, listeners, nodeIDs, configs
   192  }
   193  
   194  func newMessageCreator(t *testing.T) message.Creator {
   195  	t.Helper()
   196  
   197  	mc, err := message.NewCreator(
   198  		logging.NoLog{},
   199  		prometheus.NewRegistry(),
   200  		constants.DefaultNetworkCompressionType,
   201  		10*time.Second,
   202  	)
   203  	require.NoError(t, err)
   204  
   205  	return mc
   206  }
   207  
   208  func newFullyConnectedTestNetwork(t *testing.T, handlers []router.InboundHandler) ([]ids.NodeID, []*network, *sync.WaitGroup) {
   209  	require := require.New(t)
   210  
   211  	dialer, listeners, nodeIDs, configs := newTestNetwork(t, len(handlers))
   212  
   213  	var (
   214  		networks = make([]*network, len(configs))
   215  
   216  		globalLock     sync.Mutex
   217  		numConnected   int
   218  		allConnected   bool
   219  		onAllConnected = make(chan struct{})
   220  	)
   221  	for i, config := range configs {
   222  		msgCreator := newMessageCreator(t)
   223  		registry := prometheus.NewRegistry()
   224  
   225  		beacons := validators.NewManager()
   226  		require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1))
   227  
   228  		vdrs := validators.NewManager()
   229  		for _, nodeID := range nodeIDs {
   230  			require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.GenerateTestID(), 1))
   231  		}
   232  
   233  		config := config
   234  
   235  		config.Beacons = beacons
   236  		config.Validators = vdrs
   237  
   238  		var connected set.Set[ids.NodeID]
   239  		net, err := NewNetwork(
   240  			config,
   241  			upgrade.InitiallyActiveTime,
   242  			msgCreator,
   243  			registry,
   244  			logging.NoLog{},
   245  			listeners[i],
   246  			dialer,
   247  			&testHandler{
   248  				InboundHandler: handlers[i],
   249  				ConnectedF: func(nodeID ids.NodeID, _ *version.Application, _ ids.ID) {
   250  					t.Logf("%s connected to %s", config.MyNodeID, nodeID)
   251  
   252  					globalLock.Lock()
   253  					defer globalLock.Unlock()
   254  
   255  					require.False(connected.Contains(nodeID))
   256  					connected.Add(nodeID)
   257  					numConnected++
   258  
   259  					if !allConnected && numConnected == len(nodeIDs)*(len(nodeIDs)-1) {
   260  						allConnected = true
   261  						close(onAllConnected)
   262  					}
   263  				},
   264  				DisconnectedF: func(nodeID ids.NodeID) {
   265  					t.Logf("%s disconnected from %s", config.MyNodeID, nodeID)
   266  
   267  					globalLock.Lock()
   268  					defer globalLock.Unlock()
   269  
   270  					require.True(connected.Contains(nodeID))
   271  					connected.Remove(nodeID)
   272  					numConnected--
   273  				},
   274  			},
   275  		)
   276  		require.NoError(err)
   277  		networks[i] = net.(*network)
   278  	}
   279  
   280  	wg := sync.WaitGroup{}
   281  	wg.Add(len(networks))
   282  	for i, net := range networks {
   283  		if i != 0 {
   284  			config := configs[0]
   285  			net.ManuallyTrack(config.MyNodeID, config.MyIPPort.Get())
   286  		}
   287  
   288  		go func(net Network) {
   289  			defer wg.Done()
   290  
   291  			require.NoError(net.Dispatch())
   292  		}(net)
   293  	}
   294  
   295  	if len(networks) > 1 {
   296  		<-onAllConnected
   297  	}
   298  
   299  	return nodeIDs, networks, &wg
   300  }
   301  
   302  func TestNewNetwork(t *testing.T) {
   303  	_, networks, wg := newFullyConnectedTestNetwork(t, []router.InboundHandler{nil, nil, nil})
   304  	for _, net := range networks {
   305  		net.StartClose()
   306  	}
   307  	wg.Wait()
   308  }
   309  
   310  func TestSend(t *testing.T) {
   311  	require := require.New(t)
   312  
   313  	received := make(chan message.InboundMessage)
   314  	nodeIDs, networks, wg := newFullyConnectedTestNetwork(
   315  		t,
   316  		[]router.InboundHandler{
   317  			router.InboundHandlerFunc(func(context.Context, message.InboundMessage) {
   318  				require.FailNow("unexpected message received")
   319  			}),
   320  			router.InboundHandlerFunc(func(_ context.Context, msg message.InboundMessage) {
   321  				received <- msg
   322  			}),
   323  			router.InboundHandlerFunc(func(context.Context, message.InboundMessage) {
   324  				require.FailNow("unexpected message received")
   325  			}),
   326  		},
   327  	)
   328  
   329  	net0 := networks[0]
   330  
   331  	mc := newMessageCreator(t)
   332  	outboundGetMsg, err := mc.Get(ids.Empty, 1, time.Second, ids.Empty)
   333  	require.NoError(err)
   334  
   335  	toSend := set.Of(nodeIDs[1])
   336  	sentTo := net0.Send(
   337  		outboundGetMsg,
   338  		common.SendConfig{
   339  			NodeIDs: toSend,
   340  		},
   341  		constants.PrimaryNetworkID,
   342  		subnets.NoOpAllower,
   343  	)
   344  	require.Equal(toSend, sentTo)
   345  
   346  	inboundGetMsg := <-received
   347  	require.Equal(message.GetOp, inboundGetMsg.Op())
   348  
   349  	for _, net := range networks {
   350  		net.StartClose()
   351  	}
   352  	wg.Wait()
   353  }
   354  
   355  func TestSendWithFilter(t *testing.T) {
   356  	require := require.New(t)
   357  
   358  	received := make(chan message.InboundMessage)
   359  	nodeIDs, networks, wg := newFullyConnectedTestNetwork(
   360  		t,
   361  		[]router.InboundHandler{
   362  			router.InboundHandlerFunc(func(context.Context, message.InboundMessage) {
   363  				require.FailNow("unexpected message received")
   364  			}),
   365  			router.InboundHandlerFunc(func(_ context.Context, msg message.InboundMessage) {
   366  				received <- msg
   367  			}),
   368  			router.InboundHandlerFunc(func(context.Context, message.InboundMessage) {
   369  				require.FailNow("unexpected message received")
   370  			}),
   371  		},
   372  	)
   373  
   374  	net0 := networks[0]
   375  
   376  	mc := newMessageCreator(t)
   377  	outboundGetMsg, err := mc.Get(ids.Empty, 1, time.Second, ids.Empty)
   378  	require.NoError(err)
   379  
   380  	toSend := set.Of(nodeIDs...)
   381  	validNodeID := nodeIDs[1]
   382  	sentTo := net0.Send(
   383  		outboundGetMsg,
   384  		common.SendConfig{
   385  			NodeIDs: toSend,
   386  		},
   387  		constants.PrimaryNetworkID,
   388  		newNodeIDConnector(validNodeID),
   389  	)
   390  	require.Len(sentTo, 1)
   391  	require.Contains(sentTo, validNodeID)
   392  
   393  	inboundGetMsg := <-received
   394  	require.Equal(message.GetOp, inboundGetMsg.Op())
   395  
   396  	for _, net := range networks {
   397  		net.StartClose()
   398  	}
   399  	wg.Wait()
   400  }
   401  
   402  func TestTrackVerifiesSignatures(t *testing.T) {
   403  	require := require.New(t)
   404  
   405  	_, networks, wg := newFullyConnectedTestNetwork(t, []router.InboundHandler{nil})
   406  
   407  	network := networks[0]
   408  
   409  	tlsCert, err := staking.NewTLSCert()
   410  	require.NoError(err)
   411  
   412  	cert, err := staking.ParseCertificate(tlsCert.Leaf.Raw)
   413  	require.NoError(err)
   414  	nodeID := ids.NodeIDFromCert(cert)
   415  
   416  	require.NoError(network.config.Validators.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.Empty, 1))
   417  
   418  	stakingCert, err := staking.ParseCertificate(tlsCert.Leaf.Raw)
   419  	require.NoError(err)
   420  
   421  	err = network.Track([]*ips.ClaimedIPPort{
   422  		ips.NewClaimedIPPort(
   423  			stakingCert,
   424  			netip.AddrPortFrom(
   425  				netip.AddrFrom4([4]byte{123, 132, 123, 123}),
   426  				10000,
   427  			),
   428  			1000, // timestamp
   429  			nil,  // signature
   430  		),
   431  	})
   432  	// The signature is wrong so this peer tracking info isn't useful.
   433  	require.ErrorIs(err, staking.ErrECDSAVerificationFailure)
   434  
   435  	network.peersLock.RLock()
   436  	require.Empty(network.trackedIPs)
   437  	network.peersLock.RUnlock()
   438  
   439  	for _, net := range networks {
   440  		net.StartClose()
   441  	}
   442  	wg.Wait()
   443  }
   444  
   445  func TestTrackDoesNotDialPrivateIPs(t *testing.T) {
   446  	require := require.New(t)
   447  
   448  	dialer, listeners, nodeIDs, configs := newTestNetwork(t, 2)
   449  
   450  	networks := make([]Network, len(configs))
   451  	for i, config := range configs {
   452  		msgCreator := newMessageCreator(t)
   453  		registry := prometheus.NewRegistry()
   454  
   455  		beacons := validators.NewManager()
   456  		require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1))
   457  
   458  		vdrs := validators.NewManager()
   459  		for _, nodeID := range nodeIDs {
   460  			require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.GenerateTestID(), 1))
   461  		}
   462  
   463  		config := config
   464  
   465  		config.Beacons = beacons
   466  		config.Validators = vdrs
   467  		config.AllowPrivateIPs = false
   468  
   469  		net, err := NewNetwork(
   470  			config,
   471  			upgrade.InitiallyActiveTime,
   472  			msgCreator,
   473  			registry,
   474  			logging.NoLog{},
   475  			listeners[i],
   476  			dialer,
   477  			&testHandler{
   478  				InboundHandler: nil,
   479  				ConnectedF: func(ids.NodeID, *version.Application, ids.ID) {
   480  					require.FailNow("unexpectedly connected to a peer")
   481  				},
   482  				DisconnectedF: nil,
   483  			},
   484  		)
   485  		require.NoError(err)
   486  		networks[i] = net
   487  	}
   488  
   489  	wg := sync.WaitGroup{}
   490  	wg.Add(len(networks))
   491  	for i, net := range networks {
   492  		if i != 0 {
   493  			config := configs[0]
   494  			net.ManuallyTrack(config.MyNodeID, config.MyIPPort.Get())
   495  		}
   496  
   497  		go func(net Network) {
   498  			defer wg.Done()
   499  
   500  			require.NoError(net.Dispatch())
   501  		}(net)
   502  	}
   503  
   504  	network := networks[1].(*network)
   505  	require.Eventually(
   506  		func() bool {
   507  			network.peersLock.RLock()
   508  			defer network.peersLock.RUnlock()
   509  
   510  			nodeID := nodeIDs[0]
   511  			require.Contains(network.trackedIPs, nodeID)
   512  			ip := network.trackedIPs[nodeID]
   513  			return ip.getDelay() != 0
   514  		},
   515  		10*time.Second,
   516  		50*time.Millisecond,
   517  	)
   518  
   519  	for _, net := range networks {
   520  		net.StartClose()
   521  	}
   522  	wg.Wait()
   523  }
   524  
   525  func TestDialDeletesNonValidators(t *testing.T) {
   526  	require := require.New(t)
   527  
   528  	dialer, listeners, nodeIDs, configs := newTestNetwork(t, 2)
   529  
   530  	vdrs := validators.NewManager()
   531  	for _, nodeID := range nodeIDs {
   532  		require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeID, nil, ids.GenerateTestID(), 1))
   533  	}
   534  
   535  	networks := make([]Network, len(configs))
   536  	for i, config := range configs {
   537  		msgCreator := newMessageCreator(t)
   538  		registry := prometheus.NewRegistry()
   539  
   540  		beacons := validators.NewManager()
   541  		require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1))
   542  
   543  		config := config
   544  
   545  		config.Beacons = beacons
   546  		config.Validators = vdrs
   547  		config.AllowPrivateIPs = false
   548  
   549  		net, err := NewNetwork(
   550  			config,
   551  			upgrade.InitiallyActiveTime,
   552  			msgCreator,
   553  			registry,
   554  			logging.NoLog{},
   555  			listeners[i],
   556  			dialer,
   557  			&testHandler{
   558  				InboundHandler: nil,
   559  				ConnectedF: func(ids.NodeID, *version.Application, ids.ID) {
   560  					require.FailNow("unexpectedly connected to a peer")
   561  				},
   562  				DisconnectedF: nil,
   563  			},
   564  		)
   565  		require.NoError(err)
   566  		networks[i] = net
   567  	}
   568  
   569  	config := configs[0]
   570  	signer := peer.NewIPSigner(config.MyIPPort, config.TLSKey, config.BLSKey)
   571  	ip, err := signer.GetSignedIP()
   572  	require.NoError(err)
   573  
   574  	wg := sync.WaitGroup{}
   575  	wg.Add(len(networks))
   576  	for i, net := range networks {
   577  		if i != 0 {
   578  			stakingCert, err := staking.ParseCertificate(config.TLSConfig.Certificates[0].Leaf.Raw)
   579  			require.NoError(err)
   580  
   581  			require.NoError(net.Track([]*ips.ClaimedIPPort{
   582  				ips.NewClaimedIPPort(
   583  					stakingCert,
   584  					ip.AddrPort,
   585  					ip.Timestamp,
   586  					ip.TLSSignature,
   587  				),
   588  			}))
   589  		}
   590  
   591  		go func(net Network) {
   592  			defer wg.Done()
   593  
   594  			require.NoError(net.Dispatch())
   595  		}(net)
   596  	}
   597  
   598  	// Give the dialer time to run one iteration. This is racy, but should ony
   599  	// be possible to flake as a false negative (test passes when it shouldn't).
   600  	time.Sleep(50 * time.Millisecond)
   601  
   602  	network := networks[1].(*network)
   603  	require.NoError(vdrs.RemoveWeight(constants.PrimaryNetworkID, nodeIDs[0], 1))
   604  	require.Eventually(
   605  		func() bool {
   606  			network.peersLock.RLock()
   607  			defer network.peersLock.RUnlock()
   608  
   609  			nodeID := nodeIDs[0]
   610  			_, ok := network.trackedIPs[nodeID]
   611  			return !ok
   612  		},
   613  		10*time.Second,
   614  		50*time.Millisecond,
   615  	)
   616  
   617  	for _, net := range networks {
   618  		net.StartClose()
   619  	}
   620  	wg.Wait()
   621  }
   622  
   623  // Test that cancelling the context passed into dial
   624  // causes dial to return immediately.
   625  func TestDialContext(t *testing.T) {
   626  	_, networks, wg := newFullyConnectedTestNetwork(t, []router.InboundHandler{nil})
   627  
   628  	dialer := newTestDialer()
   629  	network := networks[0]
   630  	network.dialer = dialer
   631  
   632  	var (
   633  		neverDialedNodeID = ids.GenerateTestNodeID()
   634  		dialedNodeID      = ids.GenerateTestNodeID()
   635  
   636  		neverDialedIP, neverDialedListener = dialer.NewListener()
   637  		dialedIP, dialedListener           = dialer.NewListener()
   638  
   639  		neverDialedTrackedIP = &trackedIP{
   640  			ip: neverDialedIP,
   641  		}
   642  		dialedTrackedIP = &trackedIP{
   643  			ip: dialedIP,
   644  		}
   645  	)
   646  
   647  	network.ManuallyTrack(neverDialedNodeID, neverDialedIP)
   648  	network.ManuallyTrack(dialedNodeID, dialedIP)
   649  
   650  	// Sanity check that when a non-cancelled context is given,
   651  	// we actually dial the peer.
   652  	network.dial(dialedNodeID, dialedTrackedIP)
   653  
   654  	gotDialedIPConn := make(chan struct{})
   655  	go func() {
   656  		_, _ = dialedListener.Accept()
   657  		close(gotDialedIPConn)
   658  	}()
   659  	<-gotDialedIPConn
   660  
   661  	// Asset that when [n.onCloseCtx] is cancelled, dial returns immediately.
   662  	// That is, [neverDialedListener] doesn't accept a connection.
   663  	network.onCloseCtxCancel()
   664  	network.dial(neverDialedNodeID, neverDialedTrackedIP)
   665  
   666  	gotNeverDialedIPConn := make(chan struct{})
   667  	go func() {
   668  		_, _ = neverDialedListener.Accept()
   669  		close(gotNeverDialedIPConn)
   670  	}()
   671  
   672  	select {
   673  	case <-gotNeverDialedIPConn:
   674  		require.FailNow(t, "unexpectedly connected to peer")
   675  	default:
   676  	}
   677  
   678  	network.StartClose()
   679  	wg.Wait()
   680  }
   681  
   682  func TestAllowConnectionAsAValidator(t *testing.T) {
   683  	require := require.New(t)
   684  
   685  	dialer, listeners, nodeIDs, configs := newTestNetwork(t, 2)
   686  
   687  	networks := make([]Network, len(configs))
   688  	for i, config := range configs {
   689  		msgCreator := newMessageCreator(t)
   690  		registry := prometheus.NewRegistry()
   691  
   692  		beacons := validators.NewManager()
   693  		require.NoError(beacons.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1))
   694  
   695  		vdrs := validators.NewManager()
   696  		require.NoError(vdrs.AddStaker(constants.PrimaryNetworkID, nodeIDs[0], nil, ids.GenerateTestID(), 1))
   697  
   698  		config := config
   699  
   700  		config.Beacons = beacons
   701  		config.Validators = vdrs
   702  		config.RequireValidatorToConnect = true
   703  
   704  		net, err := NewNetwork(
   705  			config,
   706  			upgrade.InitiallyActiveTime,
   707  			msgCreator,
   708  			registry,
   709  			logging.NoLog{},
   710  			listeners[i],
   711  			dialer,
   712  			&testHandler{
   713  				InboundHandler: nil,
   714  				ConnectedF:     nil,
   715  				DisconnectedF:  nil,
   716  			},
   717  		)
   718  		require.NoError(err)
   719  		networks[i] = net
   720  	}
   721  
   722  	wg := sync.WaitGroup{}
   723  	wg.Add(len(networks))
   724  	for i, net := range networks {
   725  		if i != 0 {
   726  			config := configs[0]
   727  			net.ManuallyTrack(config.MyNodeID, config.MyIPPort.Get())
   728  		}
   729  
   730  		go func(net Network) {
   731  			defer wg.Done()
   732  
   733  			require.NoError(net.Dispatch())
   734  		}(net)
   735  	}
   736  
   737  	network := networks[1].(*network)
   738  	require.Eventually(
   739  		func() bool {
   740  			network.peersLock.RLock()
   741  			defer network.peersLock.RUnlock()
   742  
   743  			nodeID := nodeIDs[0]
   744  			_, contains := network.connectedPeers.GetByID(nodeID)
   745  			return contains
   746  		},
   747  		10*time.Second,
   748  		50*time.Millisecond,
   749  	)
   750  
   751  	for _, net := range networks {
   752  		net.StartClose()
   753  	}
   754  	wg.Wait()
   755  }