github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/network/discovery_test.go (about) 1 package network 2 3 import ( 4 "errors" 5 "net" 6 "sort" 7 "sync/atomic" 8 "testing" 9 "time" 10 11 "github.com/nspcc-dev/neo-go/pkg/network/capability" 12 "github.com/nspcc-dev/neo-go/pkg/network/payload" 13 "github.com/stretchr/testify/assert" 14 "github.com/stretchr/testify/require" 15 ) 16 17 type fakeTransp struct { 18 retFalse atomic.Int32 19 started atomic.Bool 20 closed atomic.Bool 21 dialCh chan string 22 host string 23 port string 24 } 25 26 type fakeAPeer struct { 27 addr string 28 peer string 29 version *payload.Version 30 } 31 32 func (f *fakeAPeer) ConnectionAddr() string { 33 return f.addr 34 } 35 36 func (f *fakeAPeer) PeerAddr() net.Addr { 37 tcpAddr, err := net.ResolveTCPAddr("tcp", f.peer) 38 if err != nil { 39 panic(err) 40 } 41 return tcpAddr 42 } 43 44 func (f *fakeAPeer) Version() *payload.Version { 45 return f.version 46 } 47 48 func newFakeTransp(s *Server, addr string) Transporter { 49 tr := &fakeTransp{} 50 h, p, err := net.SplitHostPort(addr) 51 if err == nil { 52 tr.host = h 53 tr.port = p 54 } 55 return tr 56 } 57 58 func (ft *fakeTransp) Dial(addr string, timeout time.Duration) (AddressablePeer, error) { 59 var ret error 60 if ft.retFalse.Load() > 0 { 61 ret = errors.New("smth bad happened") 62 } 63 ft.dialCh <- addr 64 65 return &fakeAPeer{addr: addr, peer: addr}, ret 66 } 67 func (ft *fakeTransp) Accept() { 68 if ft.started.Load() { 69 panic("started twice") 70 } 71 ft.host = "0.0.0.0" 72 ft.port = "42" 73 ft.started.Store(true) 74 } 75 func (ft *fakeTransp) Proto() string { 76 return "" 77 } 78 func (ft *fakeTransp) HostPort() (string, string) { 79 return ft.host, ft.port 80 } 81 func (ft *fakeTransp) Close() { 82 if ft.closed.Load() { 83 panic("closed twice") 84 } 85 ft.closed.Store(true) 86 } 87 func TestDefaultDiscoverer(t *testing.T) { 88 ts := &fakeTransp{} 89 ts.dialCh = make(chan string) 90 d := NewDefaultDiscovery(nil, time.Second/16, ts) 91 92 tryMaxWait = 1 // Don't waste time. 93 var set1 = []string{"1.1.1.1:10333", "2.2.2.2:10333"} 94 sort.Strings(set1) 95 96 // Added addresses should end up in the pool and in the unconnected set. 97 // Done twice to check re-adding unconnected addresses, which should be 98 // a no-op. 99 for i := 0; i < 2; i++ { 100 d.BackFill(set1...) 101 assert.Equal(t, len(set1), d.PoolCount()) 102 set1D := d.UnconnectedPeers() 103 sort.Strings(set1D) 104 assert.Equal(t, 0, len(d.GoodPeers())) 105 assert.Equal(t, 0, len(d.BadPeers())) 106 require.Equal(t, set1, set1D) 107 } 108 require.Equal(t, 2, d.GetFanOut()) 109 110 // Request should make goroutines dial our addresses draining the pool. 111 d.RequestRemote(len(set1)) 112 dialled := make([]string, 0) 113 for i := 0; i < len(set1); i++ { 114 select { 115 case a := <-ts.dialCh: 116 dialled = append(dialled, a) 117 d.RegisterConnected(&fakeAPeer{addr: a, peer: a}) 118 case <-time.After(time.Second): 119 t.Fatalf("timeout expecting for transport dial") 120 } 121 } 122 require.Eventually(t, func() bool { return len(d.UnconnectedPeers()) == 0 }, 2*time.Second, 50*time.Millisecond) 123 sort.Strings(dialled) 124 assert.Equal(t, 0, d.PoolCount()) 125 assert.Equal(t, 0, len(d.BadPeers())) 126 assert.Equal(t, 0, len(d.GoodPeers())) 127 require.Equal(t, set1, dialled) 128 129 // Registered good addresses should end up in appropriate set. 130 for _, addr := range set1 { 131 d.RegisterGood(&fakeAPeer{ 132 addr: addr, 133 peer: addr, 134 version: &payload.Version{ 135 Capabilities: capability.Capabilities{{ 136 Type: capability.FullNode, 137 Data: &capability.Node{StartHeight: 123}, 138 }}, 139 }, 140 }) 141 } 142 gAddrWithCap := d.GoodPeers() 143 gAddrs := make([]string, len(gAddrWithCap)) 144 for i, addr := range gAddrWithCap { 145 require.Equal(t, capability.Capabilities{ 146 { 147 Type: capability.FullNode, 148 Data: &capability.Node{StartHeight: 123}, 149 }, 150 }, addr.Capabilities) 151 gAddrs[i] = addr.Address 152 } 153 sort.Strings(gAddrs) 154 assert.Equal(t, 0, d.PoolCount()) 155 assert.Equal(t, 0, len(d.UnconnectedPeers())) 156 assert.Equal(t, 0, len(d.BadPeers())) 157 require.Equal(t, set1, gAddrs) 158 159 // Re-adding connected addresses should be no-op. 160 d.BackFill(set1...) 161 assert.Equal(t, 0, len(d.UnconnectedPeers())) 162 assert.Equal(t, 0, len(d.BadPeers())) 163 assert.Equal(t, len(set1), len(d.GoodPeers())) 164 require.Equal(t, 0, d.PoolCount()) 165 166 // Unregistering connected should work. 167 for _, addr := range set1 { 168 d.UnregisterConnected(&fakeAPeer{addr: addr, peer: addr}, false) 169 } 170 assert.Equal(t, 2, len(d.UnconnectedPeers())) // They're re-added automatically. 171 assert.Equal(t, 0, len(d.BadPeers())) 172 assert.Equal(t, len(set1), len(d.GoodPeers())) 173 require.Equal(t, 2, d.PoolCount()) 174 175 // Now make Dial() fail and wait to see addresses in the bad list. 176 ts.retFalse.Store(1) 177 assert.Equal(t, len(set1), d.PoolCount()) 178 set1D := d.UnconnectedPeers() 179 sort.Strings(set1D) 180 assert.Equal(t, 0, len(d.BadPeers())) 181 require.Equal(t, set1, set1D) 182 183 dialledBad := make([]string, 0) 184 d.RequestRemote(len(set1)) 185 for i := 0; i < connRetries; i++ { 186 for j := 0; j < len(set1); j++ { 187 select { 188 case a := <-ts.dialCh: 189 dialledBad = append(dialledBad, a) 190 case <-time.After(time.Second): 191 t.Fatalf("timeout expecting for transport dial; i: %d, j: %d", i, j) 192 } 193 } 194 } 195 require.Eventually(t, func() bool { return d.PoolCount() == 0 }, 2*time.Second, 50*time.Millisecond) 196 sort.Strings(dialledBad) 197 for i := 0; i < len(set1); i++ { 198 for j := 0; j < connRetries; j++ { 199 assert.Equal(t, set1[i], dialledBad[i*connRetries+j]) 200 } 201 } 202 require.Eventually(t, func() bool { return len(d.BadPeers()) == len(set1) }, 2*time.Second, 50*time.Millisecond) 203 assert.Equal(t, 0, len(d.GoodPeers())) 204 assert.Equal(t, 0, len(d.UnconnectedPeers())) 205 206 // Re-adding bad addresses is a no-op. 207 d.BackFill(set1...) 208 assert.Equal(t, 0, len(d.UnconnectedPeers())) 209 assert.Equal(t, len(set1), len(d.BadPeers())) 210 assert.Equal(t, 0, len(d.GoodPeers())) 211 require.Equal(t, 0, d.PoolCount()) 212 } 213 214 func TestSeedDiscovery(t *testing.T) { 215 var seeds = []string{"1.1.1.1:10333", "2.2.2.2:10333"} 216 ts := &fakeTransp{} 217 ts.dialCh = make(chan string) 218 ts.retFalse.Store(1) // Fail all dial requests. 219 sort.Strings(seeds) 220 221 d := NewDefaultDiscovery(seeds, time.Second/10, ts) 222 tryMaxWait = 1 // Don't waste time. 223 224 d.RequestRemote(len(seeds)) 225 for i := 0; i < connRetries*2; i++ { 226 for range seeds { 227 select { 228 case <-ts.dialCh: 229 case <-time.After(time.Second): 230 t.Fatalf("timeout expecting for transport dial") 231 } 232 } 233 } 234 }