github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/network/discovery_test.go (about)

     1  package network
     2  
     3  import (
     4  	"errors"
     5  	"net"
     6  	"sort"
     7  	"sync/atomic"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/nspcc-dev/neo-go/pkg/network/capability"
    12  	"github.com/nspcc-dev/neo-go/pkg/network/payload"
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  type fakeTransp struct {
    18  	retFalse atomic.Int32
    19  	started  atomic.Bool
    20  	closed   atomic.Bool
    21  	dialCh   chan string
    22  	host     string
    23  	port     string
    24  }
    25  
    26  type fakeAPeer struct {
    27  	addr    string
    28  	peer    string
    29  	version *payload.Version
    30  }
    31  
    32  func (f *fakeAPeer) ConnectionAddr() string {
    33  	return f.addr
    34  }
    35  
    36  func (f *fakeAPeer) PeerAddr() net.Addr {
    37  	tcpAddr, err := net.ResolveTCPAddr("tcp", f.peer)
    38  	if err != nil {
    39  		panic(err)
    40  	}
    41  	return tcpAddr
    42  }
    43  
    44  func (f *fakeAPeer) Version() *payload.Version {
    45  	return f.version
    46  }
    47  
    48  func newFakeTransp(s *Server, addr string) Transporter {
    49  	tr := &fakeTransp{}
    50  	h, p, err := net.SplitHostPort(addr)
    51  	if err == nil {
    52  		tr.host = h
    53  		tr.port = p
    54  	}
    55  	return tr
    56  }
    57  
    58  func (ft *fakeTransp) Dial(addr string, timeout time.Duration) (AddressablePeer, error) {
    59  	var ret error
    60  	if ft.retFalse.Load() > 0 {
    61  		ret = errors.New("smth bad happened")
    62  	}
    63  	ft.dialCh <- addr
    64  
    65  	return &fakeAPeer{addr: addr, peer: addr}, ret
    66  }
    67  func (ft *fakeTransp) Accept() {
    68  	if ft.started.Load() {
    69  		panic("started twice")
    70  	}
    71  	ft.host = "0.0.0.0"
    72  	ft.port = "42"
    73  	ft.started.Store(true)
    74  }
    75  func (ft *fakeTransp) Proto() string {
    76  	return ""
    77  }
    78  func (ft *fakeTransp) HostPort() (string, string) {
    79  	return ft.host, ft.port
    80  }
    81  func (ft *fakeTransp) Close() {
    82  	if ft.closed.Load() {
    83  		panic("closed twice")
    84  	}
    85  	ft.closed.Store(true)
    86  }
    87  func TestDefaultDiscoverer(t *testing.T) {
    88  	ts := &fakeTransp{}
    89  	ts.dialCh = make(chan string)
    90  	d := NewDefaultDiscovery(nil, time.Second/16, ts)
    91  
    92  	tryMaxWait = 1 // Don't waste time.
    93  	var set1 = []string{"1.1.1.1:10333", "2.2.2.2:10333"}
    94  	sort.Strings(set1)
    95  
    96  	// Added addresses should end up in the pool and in the unconnected set.
    97  	// Done twice to check re-adding unconnected addresses, which should be
    98  	// a no-op.
    99  	for i := 0; i < 2; i++ {
   100  		d.BackFill(set1...)
   101  		assert.Equal(t, len(set1), d.PoolCount())
   102  		set1D := d.UnconnectedPeers()
   103  		sort.Strings(set1D)
   104  		assert.Equal(t, 0, len(d.GoodPeers()))
   105  		assert.Equal(t, 0, len(d.BadPeers()))
   106  		require.Equal(t, set1, set1D)
   107  	}
   108  	require.Equal(t, 2, d.GetFanOut())
   109  
   110  	// Request should make goroutines dial our addresses draining the pool.
   111  	d.RequestRemote(len(set1))
   112  	dialled := make([]string, 0)
   113  	for i := 0; i < len(set1); i++ {
   114  		select {
   115  		case a := <-ts.dialCh:
   116  			dialled = append(dialled, a)
   117  			d.RegisterConnected(&fakeAPeer{addr: a, peer: a})
   118  		case <-time.After(time.Second):
   119  			t.Fatalf("timeout expecting for transport dial")
   120  		}
   121  	}
   122  	require.Eventually(t, func() bool { return len(d.UnconnectedPeers()) == 0 }, 2*time.Second, 50*time.Millisecond)
   123  	sort.Strings(dialled)
   124  	assert.Equal(t, 0, d.PoolCount())
   125  	assert.Equal(t, 0, len(d.BadPeers()))
   126  	assert.Equal(t, 0, len(d.GoodPeers()))
   127  	require.Equal(t, set1, dialled)
   128  
   129  	// Registered good addresses should end up in appropriate set.
   130  	for _, addr := range set1 {
   131  		d.RegisterGood(&fakeAPeer{
   132  			addr: addr,
   133  			peer: addr,
   134  			version: &payload.Version{
   135  				Capabilities: capability.Capabilities{{
   136  					Type: capability.FullNode,
   137  					Data: &capability.Node{StartHeight: 123},
   138  				}},
   139  			},
   140  		})
   141  	}
   142  	gAddrWithCap := d.GoodPeers()
   143  	gAddrs := make([]string, len(gAddrWithCap))
   144  	for i, addr := range gAddrWithCap {
   145  		require.Equal(t, capability.Capabilities{
   146  			{
   147  				Type: capability.FullNode,
   148  				Data: &capability.Node{StartHeight: 123},
   149  			},
   150  		}, addr.Capabilities)
   151  		gAddrs[i] = addr.Address
   152  	}
   153  	sort.Strings(gAddrs)
   154  	assert.Equal(t, 0, d.PoolCount())
   155  	assert.Equal(t, 0, len(d.UnconnectedPeers()))
   156  	assert.Equal(t, 0, len(d.BadPeers()))
   157  	require.Equal(t, set1, gAddrs)
   158  
   159  	// Re-adding connected addresses should be no-op.
   160  	d.BackFill(set1...)
   161  	assert.Equal(t, 0, len(d.UnconnectedPeers()))
   162  	assert.Equal(t, 0, len(d.BadPeers()))
   163  	assert.Equal(t, len(set1), len(d.GoodPeers()))
   164  	require.Equal(t, 0, d.PoolCount())
   165  
   166  	// Unregistering connected should work.
   167  	for _, addr := range set1 {
   168  		d.UnregisterConnected(&fakeAPeer{addr: addr, peer: addr}, false)
   169  	}
   170  	assert.Equal(t, 2, len(d.UnconnectedPeers())) // They're re-added automatically.
   171  	assert.Equal(t, 0, len(d.BadPeers()))
   172  	assert.Equal(t, len(set1), len(d.GoodPeers()))
   173  	require.Equal(t, 2, d.PoolCount())
   174  
   175  	// Now make Dial() fail and wait to see addresses in the bad list.
   176  	ts.retFalse.Store(1)
   177  	assert.Equal(t, len(set1), d.PoolCount())
   178  	set1D := d.UnconnectedPeers()
   179  	sort.Strings(set1D)
   180  	assert.Equal(t, 0, len(d.BadPeers()))
   181  	require.Equal(t, set1, set1D)
   182  
   183  	dialledBad := make([]string, 0)
   184  	d.RequestRemote(len(set1))
   185  	for i := 0; i < connRetries; i++ {
   186  		for j := 0; j < len(set1); j++ {
   187  			select {
   188  			case a := <-ts.dialCh:
   189  				dialledBad = append(dialledBad, a)
   190  			case <-time.After(time.Second):
   191  				t.Fatalf("timeout expecting for transport dial; i: %d, j: %d", i, j)
   192  			}
   193  		}
   194  	}
   195  	require.Eventually(t, func() bool { return d.PoolCount() == 0 }, 2*time.Second, 50*time.Millisecond)
   196  	sort.Strings(dialledBad)
   197  	for i := 0; i < len(set1); i++ {
   198  		for j := 0; j < connRetries; j++ {
   199  			assert.Equal(t, set1[i], dialledBad[i*connRetries+j])
   200  		}
   201  	}
   202  	require.Eventually(t, func() bool { return len(d.BadPeers()) == len(set1) }, 2*time.Second, 50*time.Millisecond)
   203  	assert.Equal(t, 0, len(d.GoodPeers()))
   204  	assert.Equal(t, 0, len(d.UnconnectedPeers()))
   205  
   206  	// Re-adding bad addresses is a no-op.
   207  	d.BackFill(set1...)
   208  	assert.Equal(t, 0, len(d.UnconnectedPeers()))
   209  	assert.Equal(t, len(set1), len(d.BadPeers()))
   210  	assert.Equal(t, 0, len(d.GoodPeers()))
   211  	require.Equal(t, 0, d.PoolCount())
   212  }
   213  
   214  func TestSeedDiscovery(t *testing.T) {
   215  	var seeds = []string{"1.1.1.1:10333", "2.2.2.2:10333"}
   216  	ts := &fakeTransp{}
   217  	ts.dialCh = make(chan string)
   218  	ts.retFalse.Store(1) // Fail all dial requests.
   219  	sort.Strings(seeds)
   220  
   221  	d := NewDefaultDiscovery(seeds, time.Second/10, ts)
   222  	tryMaxWait = 1 // Don't waste time.
   223  
   224  	d.RequestRemote(len(seeds))
   225  	for i := 0; i < connRetries*2; i++ {
   226  		for range seeds {
   227  			select {
   228  			case <-ts.dialCh:
   229  			case <-time.After(time.Second):
   230  				t.Fatalf("timeout expecting for transport dial")
   231  			}
   232  		}
   233  	}
   234  }