github.com/bytom/bytom@v1.1.2-0.20221014091027-bbcba3df6075/p2p/switch.go (about)

     1  package p2p
     2  
     3  import (
     4  	"encoding/hex"
     5  	"fmt"
     6  	"net"
     7  	"sync"
     8  	"time"
     9  
    10  	log "github.com/sirupsen/logrus"
    11  	cmn "github.com/tendermint/tmlibs/common"
    12  
    13  	cfg "github.com/bytom/bytom/config"
    14  	"github.com/bytom/bytom/consensus"
    15  	"github.com/bytom/bytom/crypto/ed25519/chainkd"
    16  	"github.com/bytom/bytom/errors"
    17  	"github.com/bytom/bytom/event"
    18  	"github.com/bytom/bytom/p2p/connection"
    19  	"github.com/bytom/bytom/p2p/discover/dht"
    20  	"github.com/bytom/bytom/p2p/discover/mdns"
    21  	"github.com/bytom/bytom/p2p/netutil"
    22  	"github.com/bytom/bytom/p2p/security"
    23  	"github.com/bytom/bytom/version"
    24  )
    25  
    26  const (
    27  	logModule = "p2p"
    28  
    29  	minNumOutboundPeers = 4
    30  	maxNumLANPeers      = 15
    31  )
    32  
    33  //pre-define errors for connecting fail
    34  var (
    35  	ErrDuplicatePeer  = errors.New("Duplicate peer")
    36  	ErrConnectSelf    = errors.New("Connect self")
    37  	ErrConnectSpvPeer = errors.New("Outbound connect spv peer")
    38  )
    39  
    40  type discv interface {
    41  	ReadRandomNodes(buf []*dht.Node) (n int)
    42  }
    43  
    44  type lanDiscv interface {
    45  	Subscribe() (*event.Subscription, error)
    46  	Stop()
    47  }
    48  
    49  type Security interface {
    50  	DoFilter(ip string, pubKey string) error
    51  	IsBanned(ip string, level byte, reason string) bool
    52  	RegisterFilter(filter security.Filter)
    53  	Start() error
    54  }
    55  
    56  // Switch handles peer connections and exposes an API to receive incoming messages
    57  // on `Reactors`.  Each `Reactor` is responsible for handling incoming messages of one
    58  // or more `Channels`.  So while sending outgoing messages is typically performed on the peer,
    59  // incoming messages are received on the reactor.
    60  type Switch struct {
    61  	cmn.BaseService
    62  
    63  	Config       *cfg.Config
    64  	peerConfig   *PeerConfig
    65  	listeners    []Listener
    66  	reactors     map[string]Reactor
    67  	chDescs      []*connection.ChannelDescriptor
    68  	reactorsByCh map[byte]Reactor
    69  	peers        *PeerSet
    70  	dialing      *cmn.CMap
    71  	nodeInfo     *NodeInfo             // our node info
    72  	nodePrivKey  chainkd.XPrv // our node privkey
    73  	discv        discv
    74  	lanDiscv     lanDiscv
    75  	security     Security
    76  }
    77  
    78  // NewSwitch create a new Switch and set discover.
    79  func NewSwitch(config *cfg.Config) (*Switch, error) {
    80  	var err error
    81  	var l Listener
    82  	var listenAddr string
    83  	var discv *dht.Network
    84  	var lanDiscv *mdns.LANDiscover
    85  
    86  	xPrv := config.PrivateKey()
    87  	if !config.VaultMode {
    88  		// Create listener
    89  		l, listenAddr = GetListener(config.P2P)
    90  		discv, err = dht.NewDiscover(config, *xPrv, l.ExternalAddress().Port)
    91  		if err != nil {
    92  			return nil, err
    93  		}
    94  		if config.P2P.LANDiscover {
    95  			lanDiscv = mdns.NewLANDiscover(mdns.NewProtocol(config.ChainID), int(l.ExternalAddress().Port))
    96  		}
    97  	}
    98  
    99  	return newSwitch(config, discv, lanDiscv, l, *xPrv, listenAddr)
   100  }
   101  
   102  // newSwitch creates a new Switch with the given config.
   103  func newSwitch(config *cfg.Config, discv discv, lanDiscv lanDiscv, l Listener, priv chainkd.XPrv, listenAddr string) (*Switch, error) {
   104  	sw := &Switch{
   105  		Config:       config,
   106  		peerConfig:   DefaultPeerConfig(config.P2P),
   107  		reactors:     make(map[string]Reactor),
   108  		chDescs:      make([]*connection.ChannelDescriptor, 0),
   109  		reactorsByCh: make(map[byte]Reactor),
   110  		peers:        NewPeerSet(),
   111  		dialing:      cmn.NewCMap(),
   112  		nodePrivKey:  priv,
   113  		discv:        discv,
   114  		lanDiscv:     lanDiscv,
   115  		nodeInfo:     NewNodeInfo(config, priv.XPub().PublicKey(), listenAddr),
   116  		security:     security.NewSecurity(config),
   117  	}
   118  
   119  	sw.AddListener(l)
   120  	sw.BaseService = *cmn.NewBaseService(nil, "P2P Switch", sw)
   121  	return sw, nil
   122  }
   123  
   124  // OnStart implements BaseService. It starts all the reactors, peers, and listeners.
   125  func (sw *Switch) OnStart() error {
   126  	for _, reactor := range sw.reactors {
   127  		if err := reactor.Start(); err != nil {
   128  			return err
   129  		}
   130  	}
   131  
   132  	sw.security.RegisterFilter(sw.nodeInfo)
   133  	sw.security.RegisterFilter(sw.peers)
   134  	if err := sw.security.Start(); err != nil {
   135  		return err
   136  	}
   137  
   138  	for _, listener := range sw.listeners {
   139  		go sw.listenerRoutine(listener)
   140  	}
   141  	go sw.ensureOutboundPeersRoutine()
   142  	go sw.connectLANPeersRoutine()
   143  
   144  	return nil
   145  }
   146  
   147  // OnStop implements BaseService. It stops all listeners, peers, and reactors.
   148  func (sw *Switch) OnStop() {
   149  	if sw.Config.P2P.LANDiscover {
   150  		sw.lanDiscv.Stop()
   151  	}
   152  
   153  	for _, listener := range sw.listeners {
   154  		listener.Stop()
   155  	}
   156  	sw.listeners = nil
   157  
   158  	for _, peer := range sw.peers.List() {
   159  		peer.Stop()
   160  		sw.peers.Remove(peer)
   161  	}
   162  
   163  	for _, reactor := range sw.reactors {
   164  		reactor.Stop()
   165  	}
   166  }
   167  
   168  // AddPeer performs the P2P handshake with a peer
   169  // that already has a SecretConnection. If all goes well,
   170  // it starts the peer and adds it to the switch.
   171  // NOTE: This performs a blocking handshake before the peer is added.
   172  // CONTRACT: If error is returned, peer is nil, and conn is immediately closed.
   173  func (sw *Switch) AddPeer(pc *peerConn, isLAN bool) error {
   174  	peerNodeInfo, err := pc.HandshakeTimeout(sw.nodeInfo, sw.peerConfig.HandshakeTimeout)
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	if err := version.Status.CheckUpdate(sw.nodeInfo.Version, peerNodeInfo.Version, peerNodeInfo.RemoteAddr); err != nil {
   180  		return err
   181  	}
   182  	if err := sw.nodeInfo.CompatibleWith(peerNodeInfo); err != nil {
   183  		return err
   184  	}
   185  
   186  	peer := newPeer(pc, peerNodeInfo, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, isLAN)
   187  	if err := sw.security.DoFilter(peer.RemoteAddrHost(), hex.EncodeToString(peer.PubKey())); err != nil {
   188  		return err
   189  	}
   190  
   191  	if pc.outbound && !peer.ServiceFlag().IsEnable(consensus.SFFullNode) {
   192  		return ErrConnectSpvPeer
   193  	}
   194  
   195  	// Start peer
   196  	if sw.IsRunning() {
   197  		if err := sw.startInitPeer(peer); err != nil {
   198  			return err
   199  		}
   200  	}
   201  
   202  	return sw.peers.Add(peer)
   203  }
   204  
   205  // AddReactor adds the given reactor to the switch.
   206  // NOTE: Not goroutine safe.
   207  func (sw *Switch) AddReactor(name string, reactor Reactor) Reactor {
   208  	// Validate the reactor.
   209  	// No two reactors can share the same channel.
   210  	for _, chDesc := range reactor.GetChannels() {
   211  		chID := chDesc.ID
   212  		if sw.reactorsByCh[chID] != nil {
   213  			cmn.PanicSanity(fmt.Sprintf("Channel %X has multiple reactors %v & %v", chID, sw.reactorsByCh[chID], reactor))
   214  		}
   215  		sw.chDescs = append(sw.chDescs, chDesc)
   216  		sw.reactorsByCh[chID] = reactor
   217  	}
   218  	sw.reactors[name] = reactor
   219  	reactor.SetSwitch(sw)
   220  	return reactor
   221  }
   222  
   223  // AddListener adds the given listener to the switch for listening to incoming peer connections.
   224  // NOTE: Not goroutine safe.
   225  func (sw *Switch) AddListener(l Listener) {
   226  	sw.listeners = append(sw.listeners, l)
   227  }
   228  
   229  //DialPeerWithAddress dial node from net address
   230  func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
   231  	log.WithFields(log.Fields{"module": logModule, "address": addr}).Debug("Dialing peer")
   232  	sw.dialing.Set(addr.IP.String(), addr)
   233  	defer sw.dialing.Delete(addr.IP.String())
   234  	if err := sw.security.DoFilter(addr.IP.String(), ""); err != nil {
   235  		return err
   236  	}
   237  
   238  	pc, err := newOutboundPeerConn(addr, sw.nodePrivKey, sw.peerConfig)
   239  	if err != nil {
   240  		log.WithFields(log.Fields{"module": logModule, "address": addr, " err": err}).Warn("DialPeer fail on newOutboundPeerConn")
   241  		return err
   242  	}
   243  
   244  	if err = sw.AddPeer(pc, addr.isLAN); err != nil {
   245  		log.WithFields(log.Fields{"module": logModule, "address": addr, " err": err}).Warn("DialPeer fail on switch AddPeer")
   246  		pc.CloseConn()
   247  		return err
   248  	}
   249  	log.WithFields(log.Fields{"module": logModule, "address": addr, "peer num": sw.peers.Size()}).Debug("DialPeer added peer")
   250  	return nil
   251  }
   252  
   253  func (sw *Switch) IsBanned(ip string, level byte, reason string) bool {
   254  	return sw.security.IsBanned(ip, level, reason)
   255  }
   256  
   257  //IsDialing prevent duplicate dialing
   258  func (sw *Switch) IsDialing(addr *NetAddress) bool {
   259  	return sw.dialing.Has(addr.IP.String())
   260  }
   261  
   262  // IsListening returns true if the switch has at least one listener.
   263  // NOTE: Not goroutine safe.
   264  func (sw *Switch) IsListening() bool {
   265  	return len(sw.listeners) > 0
   266  }
   267  
   268  // Listeners returns the list of listeners the switch listens on.
   269  // NOTE: Not goroutine safe.
   270  func (sw *Switch) Listeners() []Listener {
   271  	return sw.listeners
   272  }
   273  
   274  // NumPeers Returns the count of outbound/inbound and outbound-dialing peers.
   275  func (sw *Switch) NumPeers() (lan, outbound, inbound, dialing int) {
   276  	peers := sw.peers.List()
   277  	for _, peer := range peers {
   278  		if peer.outbound && !peer.isLAN {
   279  			outbound++
   280  		} else {
   281  			inbound++
   282  		}
   283  		if peer.isLAN {
   284  			lan++
   285  		}
   286  	}
   287  	dialing = sw.dialing.Size()
   288  	return
   289  }
   290  
   291  // NodeInfo returns the switch's NodeInfo.
   292  // NOTE: Not goroutine safe.
   293  func (sw *Switch) NodeInfo() *NodeInfo {
   294  	return sw.nodeInfo
   295  }
   296  
   297  //Peers return switch peerset
   298  func (sw *Switch) Peers() *PeerSet {
   299  	return sw.peers
   300  }
   301  
   302  // StopPeerForError disconnects from a peer due to external error.
   303  func (sw *Switch) StopPeerForError(peer *Peer, reason interface{}) {
   304  	log.WithFields(log.Fields{"module": logModule, "peer": peer, " err": reason}).Debug("stopping peer for error")
   305  	sw.stopAndRemovePeer(peer, reason)
   306  }
   307  
   308  // StopPeerGracefully disconnect from a peer gracefully.
   309  func (sw *Switch) StopPeerGracefully(peerID string) {
   310  	if peer := sw.peers.Get(peerID); peer != nil {
   311  		sw.stopAndRemovePeer(peer, nil)
   312  	}
   313  }
   314  
   315  func (sw *Switch) addPeerWithConnection(conn net.Conn) error {
   316  	peerConn, err := newInboundPeerConn(conn, sw.nodePrivKey, sw.Config.P2P)
   317  	if err != nil {
   318  		if err := conn.Close(); err != nil {
   319  			log.WithFields(log.Fields{"module": logModule, "remote peer:": conn.RemoteAddr().String(), " err:": err}).Warn("closes connection err")
   320  		}
   321  		return err
   322  	}
   323  
   324  	if err = sw.AddPeer(peerConn, false); err != nil {
   325  		if err := conn.Close(); err != nil {
   326  			log.WithFields(log.Fields{"module": logModule, "remote peer:": conn.RemoteAddr().String(), " err:": err}).Warn("closes connection err")
   327  		}
   328  		return err
   329  	}
   330  
   331  	log.WithFields(log.Fields{"module": logModule, "address": conn.RemoteAddr().String(), "peer num": sw.peers.Size()}).Debug("add inbound peer")
   332  	return nil
   333  }
   334  
   335  func (sw *Switch) connectLANPeers(lanPeer mdns.LANPeerEvent) {
   336  	lanPeers, _, _, numDialing := sw.NumPeers()
   337  	numToDial := maxNumLANPeers - lanPeers
   338  	log.WithFields(log.Fields{"module": logModule, "numDialing": numDialing, "numToDial": numToDial}).Debug("connect LAN peers")
   339  	if numToDial <= 0 {
   340  		return
   341  	}
   342  	addresses := make([]*NetAddress, 0)
   343  	for i := 0; i < len(lanPeer.IP); i++ {
   344  		addresses = append(addresses, NewLANNetAddressIPPort(lanPeer.IP[i], uint16(lanPeer.Port)))
   345  	}
   346  	sw.dialPeers(addresses)
   347  }
   348  
   349  func (sw *Switch) connectLANPeersRoutine() {
   350  	if !sw.Config.P2P.LANDiscover {
   351  		return
   352  	}
   353  
   354  	lanPeerEventSub, err := sw.lanDiscv.Subscribe()
   355  	if err != nil {
   356  		log.WithFields(log.Fields{"module": logModule, "err": err}).Warning("subscribe LAN Peer Event error")
   357  		return
   358  	}
   359  
   360  	for {
   361  		select {
   362  		case obj, ok := <-lanPeerEventSub.Chan():
   363  			if !ok {
   364  				log.WithFields(log.Fields{"module": logModule}).Warning("LAN peer event subscription channel closed")
   365  				return
   366  			}
   367  			LANPeer, ok := obj.Data.(mdns.LANPeerEvent)
   368  			if !ok {
   369  				log.WithFields(log.Fields{"module": logModule}).Error("event type error")
   370  				continue
   371  			}
   372  			sw.connectLANPeers(LANPeer)
   373  		case <-sw.Quit():
   374  			return
   375  		}
   376  	}
   377  }
   378  
   379  func (sw *Switch) listenerRoutine(l Listener) {
   380  	for {
   381  		inConn, ok := <-l.Connections()
   382  		if !ok {
   383  			break
   384  		}
   385  
   386  		// disconnect if we alrady have MaxNumPeers
   387  		if sw.peers.Size() >= sw.Config.P2P.MaxNumPeers {
   388  			if err := inConn.Close(); err != nil {
   389  				log.WithFields(log.Fields{"module": logModule, "remote peer:": inConn.RemoteAddr().String(), " err:": err}).Warn("closes connection err")
   390  			}
   391  			log.Info("Ignoring inbound connection: already have enough peers.")
   392  			continue
   393  		}
   394  
   395  		// New inbound connection!
   396  		if err := sw.addPeerWithConnection(inConn); err != nil {
   397  			log.Info("Ignoring inbound connection: error while adding peer.", " address:", inConn.RemoteAddr().String(), " error:", err)
   398  			continue
   399  		}
   400  	}
   401  }
   402  
   403  func (sw *Switch) dialPeerWorker(a *NetAddress, wg *sync.WaitGroup) {
   404  	if err := sw.DialPeerWithAddress(a); err != nil {
   405  		log.WithFields(log.Fields{"module": logModule, "addr": a, "err": err}).Warn("dialPeerWorker fail on dial peer")
   406  	}
   407  	wg.Done()
   408  }
   409  
   410  func (sw *Switch) dialPeers(addresses []*NetAddress) {
   411  	connectedPeers := make(map[string]struct{})
   412  	for _, peer := range sw.Peers().List() {
   413  		connectedPeers[peer.RemoteAddrHost()] = struct{}{}
   414  	}
   415  
   416  	var wg sync.WaitGroup
   417  	for _, address := range addresses {
   418  		if sw.NodeInfo().ListenAddr == address.String() {
   419  			continue
   420  		}
   421  		if dialling := sw.IsDialing(address); dialling {
   422  			continue
   423  		}
   424  		if _, ok := connectedPeers[address.IP.String()]; ok {
   425  			continue
   426  		}
   427  
   428  		wg.Add(1)
   429  		go sw.dialPeerWorker(address, &wg)
   430  	}
   431  	wg.Wait()
   432  }
   433  
   434  func (sw *Switch) ensureKeepConnectPeers() {
   435  	keepDials := netutil.CheckAndSplitAddresses(sw.Config.P2P.KeepDial)
   436  	addresses := make([]*NetAddress, 0)
   437  	for _, keepDial := range keepDials {
   438  		address, err := NewNetAddressString(keepDial)
   439  		if err != nil {
   440  			log.WithFields(log.Fields{"module": logModule, "err": err, "address": keepDial}).Warn("parse address to NetAddress")
   441  			continue
   442  		}
   443  		addresses = append(addresses, address)
   444  	}
   445  
   446  	sw.dialPeers(addresses)
   447  }
   448  
   449  func (sw *Switch) ensureOutboundPeers() {
   450  	lanPeers, numOutPeers, _, numDialing := sw.NumPeers()
   451  	numToDial := minNumOutboundPeers - (numOutPeers + numDialing)
   452  	log.WithFields(log.Fields{"module": logModule, "numOutPeers": numOutPeers, "LANPeers": lanPeers, "numDialing": numDialing, "numToDial": numToDial}).Debug("ensure peers")
   453  	if numToDial <= 0 {
   454  		return
   455  	}
   456  
   457  	nodes := make([]*dht.Node, numToDial)
   458  	n := sw.discv.ReadRandomNodes(nodes)
   459  	addresses := make([]*NetAddress, 0)
   460  	for i := 0; i < n; i++ {
   461  		address := NewNetAddressIPPort(nodes[i].IP, nodes[i].TCP)
   462  		addresses = append(addresses, address)
   463  	}
   464  	sw.dialPeers(addresses)
   465  }
   466  
   467  func (sw *Switch) ensureOutboundPeersRoutine() {
   468  	sw.ensureKeepConnectPeers()
   469  	sw.ensureOutboundPeers()
   470  
   471  	ticker := time.NewTicker(10 * time.Second)
   472  	defer ticker.Stop()
   473  
   474  	for {
   475  		select {
   476  		case <-ticker.C:
   477  			sw.ensureKeepConnectPeers()
   478  			sw.ensureOutboundPeers()
   479  		case <-sw.Quit():
   480  			return
   481  		}
   482  	}
   483  }
   484  
   485  func (sw *Switch) startInitPeer(peer *Peer) error {
   486  	// spawn send/recv routines
   487  	if err := peer.Start(); err != nil {
   488  		log.WithFields(log.Fields{"module": logModule, "remote peer:": peer.RemoteAddr, " err:": err}).Error("init peer err")
   489  	}
   490  
   491  	for _, reactor := range sw.reactors {
   492  		if err := reactor.AddPeer(peer); err != nil {
   493  			return err
   494  		}
   495  	}
   496  	return nil
   497  }
   498  
   499  func (sw *Switch) stopAndRemovePeer(peer *Peer, reason interface{}) {
   500  	sw.peers.Remove(peer)
   501  	for _, reactor := range sw.reactors {
   502  		reactor.RemovePeer(peer, reason)
   503  	}
   504  	peer.Stop()
   505  
   506  	sentStatus, receivedStatus := peer.TrafficStatus()
   507  	log.WithFields(log.Fields{
   508  		"module":                logModule,
   509  		"address":               peer.Addr().String(),
   510  		"reason":                reason,
   511  		"duration":              sentStatus.Duration.String(),
   512  		"total_sent":            sentStatus.Bytes,
   513  		"total_received":        receivedStatus.Bytes,
   514  		"average_sent_rate":     sentStatus.AvgRate,
   515  		"average_received_rate": receivedStatus.AvgRate,
   516  		"peer num":              sw.peers.Size(),
   517  	}).Info("disconnect with peer")
   518  }