github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/dial.go (about)

     1  package p2p
     2  
     3  import (
     4  	"container/heap"
     5  	"crypto/rand"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"time"
    10  
    11  	"github.com/neatio-net/neatio/chain/log"
    12  	"github.com/neatio-net/neatio/network/p2p/discover"
    13  	"github.com/neatio-net/neatio/network/p2p/netutil"
    14  )
    15  
const (
	// dialHistoryExpiration is how long a node stays in the dial history
	// after a dial task finishes; until then checkDial reports
	// errRecentlyDialed for it.
	dialHistoryExpiration = 30 * time.Second

	// lookupInterval is the minimum spacing between two discovery lookups,
	// enforced by discoverTask.Do.
	lookupInterval = 4 * time.Second

	// fallbackInterval is how long newTasks waits after its first call
	// before dialing a bootnode directly when there are no peers at all.
	fallbackInterval = 20 * time.Second

	// initialResolveDelay and maxResolveDelay bound the minimum time between
	// Resolve attempts of a dialTask: the delay starts at initialResolveDelay,
	// doubles after every failed resolve, and is capped at maxResolveDelay.
	initialResolveDelay = 60 * time.Second
	maxResolveDelay     = time.Hour
)
    26  
// NodeDialer opens outgoing network connections to nodes.
// Implementations return the established connection or an error.
type NodeDialer interface {
	Dial(*discover.Node) (net.Conn, error)
}
    30  
    31  type TCPDialer struct {
    32  	*net.Dialer
    33  }
    34  
    35  func (t TCPDialer) Dial(dest *discover.Node) (net.Conn, error) {
    36  	addr := &net.TCPAddr{IP: dest.IP, Port: int(dest.TCP)}
    37  	return t.Dialer.Dial("tcp", addr.String())
    38  }
    39  
// dialstate schedules outgoing dials and discovery lookups: newTasks decides
// what to start, and taskDone records what has finished. It has no locking
// of its own. NOTE(review): presumably driven from a single goroutine —
// confirm at the call site.
type dialstate struct {
	// maxDynDials caps dynamically dialed connections (connected peers plus
	// dynamic dials in flight).
	maxDynDials int

	// ntab is the discovery table; checkDial tolerates a nil table.
	ntab discoverTable

	// netrestrict, when non-nil, is the whitelist of IPs that may be dialed.
	netrestrict *netutil.Netlist

	// lookupRunning is true while a discoverTask is outstanding.
	lookupRunning bool

	// dialing maps every node with a dial in flight to that dial's flags.
	dialing map[discover.NodeID]connFlag

	// lookupBuf holds lookup results that have not been tried yet.
	lookupBuf []*discover.Node

	// randomNodes is the scratch buffer filled by ntab.ReadRandomNodes.
	randomNodes []*discover.Node

	// static holds the reusable dial task of each static node, by node ID.
	static map[discover.NodeID]*dialTask

	// hist records recently dialed nodes until their entries expire.
	hist *dialHistory

	// start is the time passed to the first newTasks call.
	start time.Time

	// bootnodes is rotated by newTasks for the no-peers fallback dial.
	bootnodes []*discover.Node
}
    55  
// discoverTable is the discovery-table functionality the dial scheduler
// relies on.
type discoverTable interface {
	// Self returns the local node; checkDial compares its ID to refuse
	// dialing ourselves.
	Self() *discover.Node
	Close()
	// Resolve looks up the node with the given ID. NOTE(review): a nil
	// result is treated as "not found" by dialTask.resolve — confirm this is
	// the contract of the implementation.
	Resolve(target discover.NodeID) *discover.Node
	// Lookup runs a discovery lookup for target and returns the nodes found.
	Lookup(target discover.NodeID) []*discover.Node
	// ReadRandomNodes fills the given buffer; newTasks uses the returned
	// count as the number of valid entries.
	ReadRandomNodes([]*discover.Node) int
}
    63  
// dialHistory remembers recently dialed nodes. It is maintained as a
// container/heap min-heap ordered by expiry time, so the entry that
// expires first is at index 0.
type dialHistory []pastDial

// pastDial is a single dial-history entry.
type pastDial struct {
	id  discover.NodeID // node that was dialed
	exp time.Time       // expiry time; expire drops the entry once it is in the past
}
    70  
// task is a unit of work produced by dialstate.newTasks. Do carries it out
// against the server; dialstate.taskDone updates the dial state once a
// task has finished.
type task interface {
	Do(*Server)
}
    74  
// dialTask dials a single node. For static nodes the same *dialTask stored
// in dialstate.static is handed out again on every round, so its resolve
// back-off state persists between attempts.
type dialTask struct {
	flags        connFlag       // connection flags passed to SetupConn, e.g. staticDialedConn or dynDialedConn
	dest         *discover.Node // node to dial; replaced when resolve succeeds
	lastResolved time.Time      // time of the last Resolve attempt
	resolveDelay time.Duration  // minimum time between Resolve attempts
}

// discoverTask runs one discovery lookup. results holds the nodes found;
// taskDone appends them to the dial state's lookup buffer.
type discoverTask struct {
	results []*discover.Node
}

// waitExpireTask sleeps for the embedded duration. newTasks schedules it to
// wait until the earliest dial-history entry expires.
type waitExpireTask struct {
	time.Duration
}
    89  
    90  func newDialState(static []*discover.Node, bootnodes []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate {
    91  	s := &dialstate{
    92  		maxDynDials: maxdyn,
    93  		ntab:        ntab,
    94  		netrestrict: netrestrict,
    95  		static:      make(map[discover.NodeID]*dialTask),
    96  		dialing:     make(map[discover.NodeID]connFlag),
    97  		bootnodes:   make([]*discover.Node, len(bootnodes)),
    98  		randomNodes: make([]*discover.Node, maxdyn/2),
    99  		hist:        new(dialHistory),
   100  	}
   101  	copy(s.bootnodes, bootnodes)
   102  	for _, n := range static {
   103  		s.addStatic(n)
   104  	}
   105  	return s
   106  }
   107  
   108  func (s *dialstate) addStatic(n *discover.Node) {
   109  
   110  	s.static[n.ID] = &dialTask{flags: staticDialedConn, dest: n}
   111  }
   112  
   113  func (s *dialstate) removeStatic(n *discover.Node) {
   114  
   115  	delete(s.static, n.ID)
   116  
   117  	s.hist.remove(n.ID)
   118  }
   119  
// newTasks computes the tasks to start now, given the number of tasks
// already running (nRunning), the connected peers and the current time.
// It mutates the dial state: every dial it returns is recorded in s.dialing,
// and returning a discoverTask sets s.lookupRunning.
//
// Work is produced in this order:
//  1. dials for static nodes that pass checkDial (static nodes rejected
//     with errNotWhitelisted or errSelf are dropped for good);
//  2. one bootnode dial, if there are no peers at all, dynamic slots are
//     free and more than fallbackInterval has passed since the first call;
//  3. dynamic dials to random table nodes, for at most half the free slots;
//  4. dynamic dials from buffered lookup results for the remaining slots;
//  5. a discovery lookup, if the buffer cannot fill the remaining slots and
//     no lookup is running;
//  6. a waitExpireTask, if nothing else was produced, nothing is running
//     and the dial history is not empty.
func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task {
	// Remember when scheduling started; used by the bootnode fallback.
	if s.start.IsZero() {
		s.start = now
	}

	var newtasks []task
	// addDial schedules a dial of n with the given flags if checkDial
	// allows it, and reports whether it did.
	addDial := func(flag connFlag, n *discover.Node) bool {
		if err := s.checkDial(n, peers); err != nil {
			log.Trace("Skipping dial candidate", "id", n.ID, "addr", &net.TCPAddr{IP: n.IP, Port: int(n.TCP)}, "err", err)
			return false
		}
		s.dialing[n.ID] = flag
		newtasks = append(newtasks, &dialTask{flags: flag, dest: n})
		return true
	}

	// Free dynamic-dial slots: maxDynDials minus connected peers that were
	// dialed dynamically, minus dynamic dials still in flight. The value
	// may become negative, in which case no dynamic dials are made.
	needDynDials := s.maxDynDials
	for _, p := range peers {
		if p.rw.is(dynDialedConn) {
			needDynDials--
		}
	}
	for _, flag := range s.dialing {
		if flag&dynDialedConn != 0 {
			needDynDials--
		}
	}

	// Drop expired history entries so those nodes become dialable again.
	s.hist.expire(now)

	// Static nodes are dialed regardless of the dynamic-dial budget. The
	// stored *dialTask itself is returned, so its resolve state carries
	// over between rounds.
	for id, t := range s.static {
		err := s.checkDial(t.dest, peers)
		switch err {
		case errNotWhitelisted, errSelf:
			log.Warn("Removing static dial candidate", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)}, "err", err)
			delete(s.static, t.dest.ID)
		case nil:
			s.dialing[id] = t.flags
			newtasks = append(newtasks, t)
		}
	}

	// With no peers at all after the fallback interval, try a bootnode.
	// The list is rotated so each fallback attempt picks the next one.
	if len(peers) == 0 && len(s.bootnodes) > 0 && needDynDials > 0 && now.Sub(s.start) > fallbackInterval {
		bootnode := s.bootnodes[0]
		s.bootnodes = append(s.bootnodes[:0], s.bootnodes[1:]...)
		s.bootnodes = append(s.bootnodes, bootnode)

		if addDial(dynDialedConn, bootnode) {
			needDynDials--
		}
	}

	// Fill up to half of the free slots with random nodes from the table.
	// NOTE(review): s.ntab is dereferenced here without the nil check that
	// checkDial performs; this is only safe if a nil table is never paired
	// with a positive maxDynDials — confirm with the caller.
	randomCandidates := needDynDials / 2
	if randomCandidates > 0 {
		n := s.ntab.ReadRandomNodes(s.randomNodes)
		for i := 0; i < randomCandidates && i < n; i++ {
			if addDial(dynDialedConn, s.randomNodes[i]) {
				needDynDials--
			}
		}
	}

	// Fill the remaining slots from buffered lookup results, then shift the
	// untried entries to the front (reusing the backing array).
	i := 0
	for ; i < len(s.lookupBuf) && needDynDials > 0; i++ {
		if addDial(dynDialedConn, s.lookupBuf[i]) {
			needDynDials--
		}
	}
	s.lookupBuf = s.lookupBuf[:copy(s.lookupBuf, s.lookupBuf[i:])]

	// Start a lookup if the buffer cannot cover the remaining need.
	if len(s.lookupBuf) < needDynDials && !s.lookupRunning {
		s.lookupRunning = true
		newtasks = append(newtasks, &discoverTask{})
	}

	// When idle, wait for the earliest history entry to expire. expire
	// above removed every entry older than now, so the duration is not
	// negative.
	if nRunning == 0 && len(newtasks) == 0 && s.hist.Len() > 0 {
		t := &waitExpireTask{s.hist.min().exp.Sub(now)}
		newtasks = append(newtasks, t)
	}
	return newtasks
}
   201  
// Reasons reported by checkDial for not dialing a node right now.
// newTasks permanently drops static nodes that fail with errSelf or
// errNotWhitelisted; the other reasons are transient.
var (
	errSelf             = errors.New("is self")
	errAlreadyDialing   = errors.New("already dialing")
	errAlreadyConnected = errors.New("already connected")
	errRecentlyDialed   = errors.New("recently dialed")
	errNotWhitelisted   = errors.New("not contained in netrestrict whitelist")
)
   209  
   210  func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error {
   211  	_, dialing := s.dialing[n.ID]
   212  	switch {
   213  	case dialing:
   214  		return errAlreadyDialing
   215  	case peers[n.ID] != nil:
   216  		return errAlreadyConnected
   217  	case s.ntab != nil && n.ID == s.ntab.Self().ID:
   218  		return errSelf
   219  	case s.netrestrict != nil && !s.netrestrict.Contains(n.IP):
   220  		return errNotWhitelisted
   221  	case s.hist.contains(n.ID):
   222  		return errRecentlyDialed
   223  	}
   224  	return nil
   225  }
   226  
   227  func (s *dialstate) taskDone(t task, now time.Time) {
   228  	switch t := t.(type) {
   229  	case *dialTask:
   230  		s.hist.add(t.dest.ID, now.Add(dialHistoryExpiration))
   231  		delete(s.dialing, t.dest.ID)
   232  	case *discoverTask:
   233  		s.lookupRunning = false
   234  		s.lookupBuf = append(s.lookupBuf, t.results...)
   235  	}
   236  }
   237  
   238  func (t *dialTask) Do(srv *Server) {
   239  	if t.dest.Incomplete() {
   240  		if !t.resolve(srv) {
   241  			return
   242  		}
   243  	}
   244  	err := t.dial(srv, t.dest)
   245  	if err != nil {
   246  		log.Trace("Dial error", "task", t, "err", err)
   247  
   248  		if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 {
   249  			if t.resolve(srv) {
   250  				t.dial(srv, t.dest)
   251  			}
   252  		}
   253  	}
   254  }
   255  
   256  func (t *dialTask) resolve(srv *Server) bool {
   257  	if srv.ntab == nil {
   258  		log.Debug("Can't resolve node", "id", t.dest.ID, "err", "discovery is disabled")
   259  		return false
   260  	}
   261  	if t.resolveDelay == 0 {
   262  		t.resolveDelay = initialResolveDelay
   263  	}
   264  	if time.Since(t.lastResolved) < t.resolveDelay {
   265  		return false
   266  	}
   267  	resolved := srv.ntab.Resolve(t.dest.ID)
   268  	t.lastResolved = time.Now()
   269  	if resolved == nil {
   270  		t.resolveDelay *= 2
   271  		if t.resolveDelay > maxResolveDelay {
   272  			t.resolveDelay = maxResolveDelay
   273  		}
   274  		log.Debug("Resolving node failed", "id", t.dest.ID, "newdelay", t.resolveDelay)
   275  		return false
   276  	}
   277  
   278  	t.resolveDelay = initialResolveDelay
   279  	t.dest = resolved
   280  	log.Debug("Resolved node", "id", t.dest.ID, "addr", &net.TCPAddr{IP: t.dest.IP, Port: int(t.dest.TCP)})
   281  	return true
   282  }
   283  
// dialError wraps an error returned by srv.Dialer.Dial while connecting.
// It lets dialTask.Do tell connection failures apart from errors returned
// by srv.SetupConn.
type dialError struct {
	error
}
   287  
   288  func (t *dialTask) dial(srv *Server, dest *discover.Node) error {
   289  	fd, err := srv.Dialer.Dial(dest)
   290  	if err != nil {
   291  		return &dialError{err}
   292  	}
   293  	mfd := newMeteredConn(fd, false)
   294  	return srv.SetupConn(mfd, t.flags, dest)
   295  }
   296  
   297  func (t *dialTask) String() string {
   298  	return fmt.Sprintf("%v %x %v:%d", t.flags, t.dest.ID[:8], t.dest.IP, t.dest.TCP)
   299  }
   300  
   301  func (t *discoverTask) Do(srv *Server) {
   302  
   303  	next := srv.lastLookup.Add(lookupInterval)
   304  	if now := time.Now(); now.Before(next) {
   305  		time.Sleep(next.Sub(now))
   306  	}
   307  	srv.lastLookup = time.Now()
   308  	var target discover.NodeID
   309  	rand.Read(target[:])
   310  	t.results = srv.ntab.Lookup(target)
   311  }
   312  
   313  func (t *discoverTask) String() string {
   314  	s := "discovery lookup"
   315  	if len(t.results) > 0 {
   316  		s += fmt.Sprintf(" (%d results)", len(t.results))
   317  	}
   318  	return s
   319  }
   320  
   321  func (t waitExpireTask) Do(*Server) {
   322  	time.Sleep(t.Duration)
   323  }
   324  func (t waitExpireTask) String() string {
   325  	return fmt.Sprintf("wait for dial hist expire (%v)", t.Duration)
   326  }
   327  
   328  func (h dialHistory) min() pastDial {
   329  	return h[0]
   330  }
   331  func (h *dialHistory) add(id discover.NodeID, exp time.Time) {
   332  	heap.Push(h, pastDial{id, exp})
   333  
   334  }
   335  func (h *dialHistory) remove(id discover.NodeID) bool {
   336  	for i, v := range *h {
   337  		if v.id == id {
   338  			heap.Remove(h, i)
   339  			return true
   340  		}
   341  	}
   342  	return false
   343  }
   344  func (h dialHistory) contains(id discover.NodeID) bool {
   345  	for _, v := range h {
   346  		if v.id == id {
   347  			return true
   348  		}
   349  	}
   350  	return false
   351  }
   352  func (h *dialHistory) expire(now time.Time) {
   353  	for h.Len() > 0 && h.min().exp.Before(now) {
   354  		heap.Pop(h)
   355  	}
   356  }
   357  
   358  func (h dialHistory) Len() int           { return len(h) }
   359  func (h dialHistory) Less(i, j int) bool { return h[i].exp.Before(h[j].exp) }
   360  func (h dialHistory) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
   361  func (h *dialHistory) Push(x interface{}) {
   362  	*h = append(*h, x.(pastDial))
   363  }
   364  func (h *dialHistory) Pop() interface{} {
   365  	old := *h
   366  	n := len(old)
   367  	x := old[n-1]
   368  	*h = old[0 : n-1]
   369  	return x
   370  }