github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/dial.go (about) 1 package p2p 2 3 import ( 4 "container/heap" 5 "crypto/rand" 6 "errors" 7 "fmt" 8 "net" 9 "time" 10 11 "github.com/neatio-net/neatio/chain/log" 12 "github.com/neatio-net/neatio/network/p2p/discover" 13 "github.com/neatio-net/neatio/network/p2p/netutil" 14 ) 15 16 const ( 17 dialHistoryExpiration = 30 * time.Second 18 19 lookupInterval = 4 * time.Second 20 21 fallbackInterval = 20 * time.Second 22 23 initialResolveDelay = 60 * time.Second 24 maxResolveDelay = time.Hour 25 ) 26 27 type NodeDialer interface { 28 Dial(*discover.Node) (net.Conn, error) 29 } 30 31 type TCPDialer struct { 32 *net.Dialer 33 } 34 35 func (t TCPDialer) Dial(dest *discover.Node) (net.Conn, error) { 36 addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)} 37 return t.Dialer.Dial("tcp", addr.String()) 38 } 39 40 type dialstate struct { 41 maxDynDials int 42 ntab discoverTable 43 netrestrict *netutil.Netlist 44 45 lookupRunning bool 46 dialing map[discover.NodeID]connFlag 47 lookupBuf []*discover.Node 48 randomNodes []*discover.Node 49 static map[discover.NodeID]*dialTask 50 hist *dialHistory 51 52 start time.Time 53 bootnodes []*discover.Node 54 } 55 56 type discoverTable interface { 57 Self() *discover.Node 58 Close() 59 Resolve(target discover.NodeID) *discover.Node 60 Lookup(target discover.NodeID) []*discover.Node 61 ReadRandomNodes([]*discover.Node) int 62 } 63 64 type dialHistory []pastDial 65 66 type pastDial struct { 67 id discover.NodeID 68 exp time.Time 69 } 70 71 type task interface { 72 Do(*Server) 73 } 74 75 type dialTask struct { 76 flags connFlag 77 dest *discover.Node 78 lastResolved time.Time 79 resolveDelay time.Duration 80 } 81 82 type discoverTask struct { 83 results []*discover.Node 84 } 85 86 type waitExpireTask struct { 87 time.Duration 88 } 89 90 func newDialState(static []*discover.Node, bootnodes []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { 91 s := &dialstate{ 92 maxDynDials: maxdyn, 93 ntab: ntab, 94 netrestrict: netrestrict, 95 static: make(map[discover.NodeID]*dialTask), 96 dialing: make(map[discover.NodeID]connFlag), 97 bootnodes: make([]*discover.Node, len(bootnodes)), 98 randomNodes: make([]*discover.Node, maxdyn/2), 99 hist: new(dialHistory), 100 } 101 copy(s.bootnodes, bootnodes) 102 for _, n := range static { 103 s.addStatic(n) 104 } 105 return s 106 } 107 108 func (s *dialstate) addStatic(n *discover.Node) { 109 110 s.static[n.ID] = &dialTask{flags: staticDialedConn, dest: n} 111 } 112 113 func (s *dialstate) removeStatic(n *discover.Node) { 114 115 delete(s.static, n.ID) 116 117 s.hist.remove(n.ID) 118 } 119 120 func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { 121 if s.start.IsZero() { 122 s.start = now 123 } 124 125 var newtasks []task 126 addDial := func(flag connFlag, n *discover.Node) bool { 127 if err := s.checkDial(n, peers); err != nil { 128 log.Trace("Skipping dial candidate", "id", n.ID, "addr", &net.TCPAddr{IP: n.IP, Port: int(n.TCP)}, "err", err) 129 return false 130 } 131 s.dialing[n.ID] = flag 132 newtasks = append(newtasks, &dialTask{flags: flag, dest: n}) 133 return true 134 } 135 136 needDynDials := s.maxDynDials 137 for _, p := range peers { 138 if p.rw.is(dynDialedConn) { 139 needDynDials-- 140 } 141 } 142 for _, flag := range s.dialing { 143 if flag&dynDialedConn != 0 { 144 needDynDials-- 145 } 146 } 147 148 s.hist.expire(now) 149 150 for id, t := range s.static { 151 err := s.checkDial(t.dest, peers) 152 switch err { 153 case errNotWhitelisted, errSelf: 154 log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)}, "err", err) 155 delete(s.static, t.dest.ID) 156 case nil: 157 s.dialing[id] = t.flags 158 newtasks = append(newtasks, t) 159 } 160 } 161 162 if len(peers) == 0 && len(s.bootnodes) > 0 && needDynDials > 0 && now.Sub(s.start) > fallbackInterval { 163 bootnode := s.bootnodes[0] 164 s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...) 165 s.bootnodes = append(s.bootnodes, bootnode) 166 167 if addDial(dynDialedConn, bootnode) { 168 needDynDials-- 169 } 170 } 171 172 randomCandidates := needDynDials / 2 173 if randomCandidates > 0 { 174 n := s.ntab.ReadRandomNodes(s.randomNodes) 175 for i := 0; i < randomCandidates && i < n; i++ { 176 if addDial(dynDialedConn, s.randomNodes[i]) { 177 needDynDials-- 178 } 179 } 180 } 181 182 i := 0 183 for ; i < len(s.lookupBuf) && needDynDials > 0; i++ { 184 if addDial(dynDialedConn, s.lookupBuf[i]) { 185 needDynDials-- 186 } 187 } 188 s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])] 189 190 if len(s.lookupBuf) < needDynDials && !s.lookupRunning { 191 s.lookupRunning = true 192 newtasks = append(newtasks, &discoverTask{}) 193 } 194 195 if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 { 196 t := &waitExpireTask{s.hist.min().exp.Sub(now)} 197 newtasks = append(newtasks, t) 198 } 199 return newtasks 200 } 201 202 var ( 203 errSelf = errors.New("is self") 204 errAlreadyDialing = errors.New("already dialing") 205 errAlreadyConnected = errors.New("already connected") 206 errRecentlyDialed = errors.New("recently dialed") 207 errNotWhitelisted = errors.New("not contained in netrestrict whitelist") 208 ) 209 210 func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error { 211 _, dialing := s.dialing[n.ID] 212 switch { 213 case dialing: 214 return errAlreadyDialing 215 case peers[n.ID] != nil: 216 return errAlreadyConnected 217 case s.ntab != nil && n.ID == s.ntab.Self().ID: 218 return errSelf 219 case s.netrestrict != nil && !s.netrestrict.Contains(n.IP): 220 return errNotWhitelisted 221 case s.hist.contains(n.ID): 222 return errRecentlyDialed 223 } 224 return nil 225 } 226 227 func (s *dialstate) taskDone(t task, now time.Time) { 228 switch t := t.(type) { 229 case *dialTask: 230 s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration)) 231 delete(s.dialing, t.dest.ID) 232 case *discoverTask: 233 s.lookupRunning = false 234 s.lookupBuf = append(s.lookupBuf, t.results...) 235 } 236 } 237 238 func (t *dialTask) Do(srv *Server) { 239 if t.dest.Incomplete() { 240 if !t.resolve(srv) { 241 return 242 } 243 } 244 err := t.dial(srv, t.dest) 245 if err != nil { 246 log.Trace("Dial error", "task", t, "err", err) 247 248 if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { 249 if t.resolve(srv) { 250 t.dial(srv, t.dest) 251 } 252 } 253 } 254 } 255 256 func (t *dialTask) resolve(srv *Server) bool { 257 if srv.ntab == nil { 258 log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled") 259 return false 260 } 261 if t.resolveDelay == 0 { 262 t.resolveDelay = initialResolveDelay 263 } 264 if time.Since(t.lastResolved) < t.resolveDelay { 265 return false 266 } 267 resolved := srv.ntab.Resolve(t.dest.ID) 268 t.lastResolved = time.Now() 269 if resolved == nil { 270 t.resolveDelay *= 2 271 if t.resolveDelay > maxResolveDelay { 272 t.resolveDelay = maxResolveDelay 273 } 274 log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay) 275 return false 276 } 277 278 t.resolveDelay = initialResolveDelay 279 t.dest = resolved 280 log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)}) 281 return true 282 } 283 284 type dialError struct { 285 error 286 } 287 288 func (t *dialTask) dial(srv *Server, dest *discover.Node) error { 289 fd, err := srv.Dialer.Dial(dest) 290 if err != nil { 291 return &dialError{err} 292 } 293 mfd := newMeteredConn(fd, false) 294 return srv.SetupConn(mfd, t.flags, dest) 295 } 296 297 func (t *dialTask) String() string { 298 return fmt.Sprintf("%v %x %v:%d", t.flags, t.dest.ID[:8], t.dest.IP, t.dest.TCP) 299 } 300 301 func (t *discoverTask) Do(srv *Server) { 302 303 next := srv.lastLookup.Add(lookupInterval) 304 if now := time.Now(); now.Before(next) { 305 time.Sleep(next.Sub(now)) 306 } 307 srv.lastLookup = time.Now() 308 var target discover.NodeID 309 rand.Read(target[:]) 310 t.results = srv.ntab.Lookup(target) 311 } 312 313 func (t *discoverTask) String() string { 314 s := "discovery lookup" 315 if len(t.results) > 0 { 316 s += fmt.Sprintf(" (%d results)", len(t.results)) 317 } 318 return s 319 } 320 321 func (t waitExpireTask) Do(*Server) { 322 time.Sleep(t.Duration) 323 } 324 func (t waitExpireTask) String() string { 325 return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration) 326 } 327 328 func (h dialHistory) min() pastDial { 329 return h[0] 330 } 331 func (h *dialHistory) add(id discover.NodeID, exp time.Time) { 332 heap.Push(h, pastDial{id, exp}) 333 334 } 335 func (h *dialHistory) remove(id discover.NodeID) bool { 336 for i, v := range *h { 337 if v.id == id { 338 heap.Remove(h, i) 339 return true 340 } 341 } 342 return false 343 } 344 func (h dialHistory) contains(id discover.NodeID) bool { 345 for _, v := range h { 346 if v.id == id { 347 return true 348 } 349 } 350 return false 351 } 352 func (h *dialHistory) expire(now time.Time) { 353 for h.Len() > 0 && h.min().exp.Before(now) { 354 heap.Pop(h) 355 } 356 } 357 358 func (h dialHistory) Len() int { return len(h) } 359 func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) } 360 func (h dialHistory) Swap(i, j int) { h[i], h[j] = h[j], h[i] } 361 func (h *dialHistory) Push(x interface{}) { 362 *h = append(*h, x.(pastDial)) 363 } 364 func (h *dialHistory) Pop() interface{} { 365 old := *h 366 n := len(old) 367 x := old[n-1] 368 *h = old[0 : n-1] 369 return x 370 }