github.com/celestiaorg/celestia-node@v0.15.0-beta.1/share/p2p/peers/manager.go

     1  package peers
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync"
     8  	"sync/atomic"
     9  	"time"
    10  
    11  	logging "github.com/ipfs/go-log/v2"
    12  	pubsub "github.com/libp2p/go-libp2p-pubsub"
    13  	"github.com/libp2p/go-libp2p/core/event"
    14  	"github.com/libp2p/go-libp2p/core/host"
    15  	"github.com/libp2p/go-libp2p/core/network"
    16  	"github.com/libp2p/go-libp2p/core/peer"
    17  	"github.com/libp2p/go-libp2p/p2p/host/eventbus"
    18  	"github.com/libp2p/go-libp2p/p2p/net/conngater"
    19  
    20  	libhead "github.com/celestiaorg/go-header"
    21  
    22  	"github.com/celestiaorg/celestia-node/header"
    23  	"github.com/celestiaorg/celestia-node/share"
    24  	"github.com/celestiaorg/celestia-node/share/p2p/shrexsub"
    25  )
    26  
    27  const (
    28  	// ResultNoop indicates that the operation was successful and no extra action is required
    29  	ResultNoop result = "result_noop"
    30  	// ResultCooldownPeer puts the returned peer on cooldown, meaning it won't be returned by the
    31  	// Peer method for some time
    32  	ResultCooldownPeer result = "result_cooldown_peer"
    33  	// ResultBlacklistPeer blacklists the peer. Blacklisted peers will be disconnected and blocked from
    34  	// any p2p communication in the future by the libp2p Gater
    35  	ResultBlacklistPeer result = "result_blacklist_peer"
    36  
    37  	// eventbusBufSize is the size of the buffered channel to handle
    38  	// events in libp2p
    39  	eventbusBufSize = 32
    40  
    41  	// storedPoolsAmount is the number of pools for recent headers that will be stored in the peer
    42  	// manager
    43  	storedPoolsAmount = 10
    44  )
    45  
    46  type result string
    47  
    48  var log = logging.Logger("shrex/peer-manager")
    49  
    50  // Manager keeps track of peers coming from shrex.Sub and from discovery
    51  type Manager struct {
    52  	lock   sync.Mutex
    53  	params Parameters
    54  
    55  	// header subscription is necessary in order to Validate the inbound eds hash
    56  	headerSub libhead.Subscriber[*header.ExtendedHeader]
    57  	shrexSub  *shrexsub.PubSub
    58  	host      host.Host
    59  	connGater *conngater.BasicConnectionGater
    60  
    61  	// pools collect peers from shrexSub and store them by datahash
    62  	pools map[string]*syncPool
    63  
    64  	// initialHeight is the height of the first header received from headersub
    65  	initialHeight atomic.Uint64
    66  	// messages from shrex.Sub with height below storeFrom will be ignored, since we don't need to
    67  	// track peers for those headers
    68  	storeFrom atomic.Uint64
    69  
    70  	// fullNodes collects the peer.IDs of full nodes found via discovery
    71  	fullNodes *pool
    72  
    73  	// hashes that are not in the chain
    74  	blacklistedHashes map[string]bool
    75  
    76  	metrics *metrics
    77  
    78  	headerSubDone         chan struct{}
    79  	disconnectedPeersDone chan struct{}
    80  	cancel                context.CancelFunc
    81  }
    82  
    83  // DoneFunc updates internal state depending on the call result. It should be called exactly once
    84  // per peer returned by the Peer method
    85  type DoneFunc func(result)
    86  
    87  type syncPool struct {
    88  	*pool
    89  
    90  	// isValidatedDataHash indicates if datahash was validated by receiving corresponding extended
    91  	// header from headerSub
    92  	isValidatedDataHash atomic.Bool
    93  	// height is the height of the header that corresponds to datahash
    94  	height uint64
    95  	// createdAt is the syncPool creation time
    96  	createdAt time.Time
    97  }
    98  
    99  func NewManager(
   100  	params Parameters,
   101  	host host.Host,
   102  	connGater *conngater.BasicConnectionGater,
   103  	options ...Option,
   104  ) (*Manager, error) {
   105  	if err := params.Validate(); err != nil {
   106  		return nil, err
   107  	}
   108  
   109  	s := &Manager{
   110  		params:                params,
   111  		connGater:             connGater,
   112  		host:                  host,
   113  		pools:                 make(map[string]*syncPool),
   114  		blacklistedHashes:     make(map[string]bool),
   115  		headerSubDone:         make(chan struct{}),
   116  		disconnectedPeersDone: make(chan struct{}),
   117  	}
   118  
   119  	for _, opt := range options {
   120  		err := opt(s)
   121  		if err != nil {
   122  			return nil, err
   123  		}
   124  	}
   125  
   126  	s.fullNodes = newPool(s.params.PeerCooldown)
   127  	return s, nil
   128  }
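
        // A minimal wiring sketch (illustrative, not taken from the node's construction code).
        // It assumes that host, connGater, shrexSub and headerSub are built elsewhere, and that
        // the WithShrexSubPools option (referenced in Start below) accepts those two subscribers:
        //
        //	manager, err := NewManager(params, host, connGater, WithShrexSubPools(shrexSub, headerSub))
        //	if err != nil {
        //		return err
        //	}
        //	if err := manager.Start(ctx); err != nil {
        //		return err
        //	}
        //	defer manager.Stop(ctx) // error ignored for brevity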
   129  
   130  func (m *Manager) Start(startCtx context.Context) error {
   131  	ctx, cancel := context.WithCancel(context.Background())
   132  	m.cancel = cancel
   133  
   134  	// pools will only be populated with senders of shrexsub notifications if the WithShrexSubPools
   135  	// option is used.
   136  	if m.shrexSub == nil && m.headerSub == nil {
   137  		return nil
   138  	}
   139  
   140  	validatorFn := m.metrics.validationObserver(m.Validate)
   141  	err := m.shrexSub.AddValidator(validatorFn)
   142  	if err != nil {
   143  		return fmt.Errorf("registering validator: %w", err)
   144  	}
   145  	err = m.shrexSub.Start(startCtx)
   146  	if err != nil {
   147  		return fmt.Errorf("starting shrexsub: %w", err)
   148  	}
   149  
   150  	headerSub, err := m.headerSub.Subscribe()
   151  	if err != nil {
   152  		return fmt.Errorf("subscribing to headersub: %w", err)
   153  	}
   154  
   155  	sub, err := m.host.EventBus().Subscribe(&event.EvtPeerConnectednessChanged{}, eventbus.BufSize(eventbusBufSize))
   156  	if err != nil {
   157  		return fmt.Errorf("subscribing to libp2p events: %w", err)
   158  	}
   159  
   160  	go m.subscribeHeader(ctx, headerSub)
   161  	go m.subscribeDisconnectedPeers(ctx, sub)
   162  	go m.GC(ctx)
   163  	return nil
   164  }
   165  
   166  func (m *Manager) Stop(ctx context.Context) error {
   167  	m.cancel()
   168  
   169  	// we do not need to wait for headersub and disconnected peers to finish
   170  	// here, since they were never started
   171  	if m.headerSub == nil && m.shrexSub == nil {
   172  		return nil
   173  	}
   174  
   175  	select {
   176  	case <-m.headerSubDone:
   177  	case <-ctx.Done():
   178  		return ctx.Err()
   179  	}
   180  
   181  	select {
   182  	case <-m.disconnectedPeersDone:
   183  	case <-ctx.Done():
   184  		return ctx.Err()
   185  	}
   186  
   187  	return nil
   188  }
   189  
   190  // Peer returns a peer collected from shrex.Sub for the given datahash, if any is available.
   191  // If there is none, it will look for full nodes collected from discovery. If there are no discovered
   192  // full nodes, it will wait until a peer appears in either source or the context times out.
   193  // After fetching data from the given peer, the caller is required to call the returned DoneFunc with
   194  // an appropriate result value
   195  func (m *Manager) Peer(ctx context.Context, datahash share.DataHash, height uint64,
   196  ) (peer.ID, DoneFunc, error) {
   197  	p := m.validatedPool(datahash.String(), height)
   198  
   199  	// first, check if a peer is available for the given datahash
   200  	peerID, ok := p.tryGet()
   201  	if ok {
   202  		if m.removeIfUnreachable(p, peerID) {
   203  			return m.Peer(ctx, datahash, height)
   204  		}
   205  		return m.newPeer(ctx, datahash, peerID, sourceShrexSub, p.len(), 0)
   206  	}
   207  
   208  	// if no peer for datahash is currently available, try to use full node
   209  	// obtained from discovery
   210  	peerID, ok = m.fullNodes.tryGet()
   211  	if ok {
   212  		return m.newPeer(ctx, datahash, peerID, sourceFullNodes, m.fullNodes.len(), 0)
   213  	}
   214  
   215  	// no peers are available right now, wait for the first one
   216  	start := time.Now()
   217  	select {
   218  	case peerID = <-p.next(ctx):
   219  		if m.removeIfUnreachable(p, peerID) {
   220  			return m.Peer(ctx, datahash, height)
   221  		}
   222  		return m.newPeer(ctx, datahash, peerID, sourceShrexSub, p.len(), time.Since(start))
   223  	case peerID = <-m.fullNodes.next(ctx):
   224  		return m.newPeer(ctx, datahash, peerID, sourceFullNodes, m.fullNodes.len(), time.Since(start))
   225  	case <-ctx.Done():
   226  		return "", nil, ctx.Err()
   227  	}
   228  }
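
        // A caller-side sketch of the Peer/DoneFunc contract (illustrative): dataHash and height
        // identify the block being fetched, and fetch/isInvalidData stand in for a shrex request
        // and its error classification, which live outside this package:
        //
        //	peerID, done, err := manager.Peer(ctx, dataHash, height)
        //	if err != nil {
        //		return err // e.g. ctx expired before any peer became available
        //	}
        //	switch err := fetch(ctx, peerID); {
        //	case err == nil:
        //		done(ResultNoop)
        //	case isInvalidData(err):
        //		done(ResultBlacklistPeer) // provably misbehaving peer
        //	default:
        //		done(ResultCooldownPeer) // transient failure, try another peer later
        //	}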
   229  
   230  // UpdateFullNodePool is called by discovery when a new full node is discovered or an existing one is removed
   231  func (m *Manager) UpdateFullNodePool(peerID peer.ID, isAdded bool) {
   232  	if isAdded {
   233  		if m.isBlacklistedPeer(peerID) {
   234  			log.Debugw("got blacklisted peer from discovery", "peer", peerID.String())
   235  			return
   236  		}
   237  		m.fullNodes.add(peerID)
   238  		log.Debugw("added to full nodes", "peer", peerID)
   239  		return
   240  	}
   241  
   242  	log.Debugw("removing peer from discovered full nodes", "peer", peerID.String())
   243  	m.fullNodes.remove(peerID)
   244  }
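
        // Discovery is expected to invoke this callback on every change in its peer set,
        // e.g. (illustrative):
        //
        //	manager.UpdateFullNodePool(peerID, true)  // full node discovered
        //	manager.UpdateFullNodePool(peerID, false) // full node dropped by discovery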
   245  
   246  func (m *Manager) newPeer(
   247  	ctx context.Context,
   248  	datahash share.DataHash,
   249  	peerID peer.ID,
   250  	source peerSource,
   251  	poolSize int,
   252  	waitTime time.Duration,
   253  ) (peer.ID, DoneFunc, error) {
   254  	log.Debugw("got peer",
   255  		"hash", datahash.String(),
   256  		"peer", peerID.String(),
   257  		"source", source,
   258  		"pool_size", poolSize,
   259  		"wait (s)", waitTime)
   260  	m.metrics.observeGetPeer(ctx, source, poolSize, waitTime)
   261  	return peerID, m.doneFunc(datahash, peerID, source), nil
   262  }
   263  
   264  func (m *Manager) doneFunc(datahash share.DataHash, peerID peer.ID, source peerSource) DoneFunc {
   265  	return func(result result) {
   266  		log.Debugw("set peer result",
   267  			"hash", datahash.String(),
   268  			"peer", peerID.String(),
   269  			"source", source,
   270  			"result", result)
   271  		m.metrics.observeDoneResult(source, result)
   272  		switch result {
   273  		case ResultNoop:
   274  		case ResultCooldownPeer:
   275  			if source == sourceFullNodes {
   276  				m.fullNodes.putOnCooldown(peerID)
   277  				return
   278  			}
   279  			m.getPool(datahash.String()).putOnCooldown(peerID)
   280  		case ResultBlacklistPeer:
   281  			m.blacklistPeers(reasonMisbehave, peerID)
   282  		}
   283  	}
   284  }
   285  
   286  // subscribeHeader takes datahash from received header and validates corresponding peer pool.
   287  func (m *Manager) subscribeHeader(ctx context.Context, headerSub libhead.Subscription[*header.ExtendedHeader]) {
   288  	defer close(m.headerSubDone)
   289  	defer headerSub.Cancel()
   290  
   291  	for {
   292  		h, err := headerSub.NextHeader(ctx)
   293  		if err != nil {
   294  			if errors.Is(err, context.Canceled) {
   295  				return
   296  			}
   297  			log.Errorw("get next header from sub", "err", err)
   298  			continue
   299  		}
   300  		m.validatedPool(h.DataHash.String(), h.Height())
   301  
   302  		// store first header for validation purposes
   303  		if m.initialHeight.CompareAndSwap(0, h.Height()) {
   304  			log.Debugw("stored initial height", "height", h.Height())
   305  		}
   306  
   307  		// update storeFrom so that messages for headers more than storedPoolsAmount below the tip are ignored
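        		// e.g. with storedPoolsAmount = 10, a header at height 105 moves storeFrom to 95,
        		// so shrexsub messages for heights below 95 are ignored by Validate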
   308  		m.storeFrom.Store(uint64(max(0, int(h.Height())-storedPoolsAmount)))
   309  		log.Debugw("updated lowest stored height", "height", m.storeFrom.Load())
   310  	}
   311  }
   312  
   313  // subscribeDisconnectedPeers subscribes to libp2p connectivity events and removes disconnected
   314  // peers from full nodes pool
   315  func (m *Manager) subscribeDisconnectedPeers(ctx context.Context, sub event.Subscription) {
   316  	defer close(m.disconnectedPeersDone)
   317  	defer sub.Close()
   318  	for {
   319  		select {
   320  		case <-ctx.Done():
   321  			return
   322  		case e, ok := <-sub.Out():
   323  			if !ok {
   324  				log.Fatal("Subscription for connectedness events is closed.") //nolint:gocritic
   325  				return
   326  			}
   327  			// listen to disconnect event to remove peer from full nodes pool
   328  			connStatus := e.(event.EvtPeerConnectednessChanged)
   329  			if connStatus.Connectedness == network.NotConnected {
   330  				peer := connStatus.Peer
   331  				if m.fullNodes.has(peer) {
   332  					log.Debugw("peer disconnected, removing from full nodes", "peer", peer.String())
   333  					m.fullNodes.remove(peer)
   334  				}
   335  			}
   336  		}
   337  	}
   338  }
   339  
   340  // Validate collects the sender's peer.ID into the pool that corresponds to the notification's datahash
   341  func (m *Manager) Validate(_ context.Context, peerID peer.ID, msg shrexsub.Notification) pubsub.ValidationResult {
   342  	logger := log.With("peer", peerID.String(), "hash", msg.DataHash.String())
   343  
   344  	// messages broadcast from self should bypass the validation with Accept
   345  	if peerID == m.host.ID() {
   346  		logger.Debug("received datahash from self")
   347  		return pubsub.ValidationAccept
   348  	}
   349  
   350  	// reject messages that carry a previously blacklisted hash, penalizing the sender
   351  	if m.isBlacklistedHash(msg.DataHash) {
   352  		logger.Debug("received blacklisted hash, reject validation")
   353  		return pubsub.ValidationReject
   354  	}
   355  
   356  	if m.isBlacklistedPeer(peerID) {
   357  		logger.Debug("received message from blacklisted peer, reject validation")
   358  		return pubsub.ValidationReject
   359  	}
   360  
   361  	if msg.Height < m.storeFrom.Load() {
   362  		logger.Debug("received message for past header")
   363  		return pubsub.ValidationIgnore
   364  	}
   365  
   366  	p := m.getOrCreatePool(msg.DataHash.String(), msg.Height)
   367  	logger.Debugw("got hash from shrex-sub")
   368  
   369  	p.add(peerID)
   370  	if p.isValidatedDataHash.Load() {
   371  		// add peer to the full nodes pool only if the datahash has already been validated
   372  		m.fullNodes.add(peerID)
   373  	}
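        	// ValidationIgnore keeps the notification from being propagated further by this node
        	// without penalizing the sender, unlike ValidationReject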
   374  	return pubsub.ValidationIgnore
   375  }
   376  
   377  func (m *Manager) getPool(datahash string) *syncPool {
   378  	m.lock.Lock()
   379  	defer m.lock.Unlock()
   380  	return m.pools[datahash]
   381  }
   382  
   383  func (m *Manager) getOrCreatePool(datahash string, height uint64) *syncPool {
   384  	m.lock.Lock()
   385  	defer m.lock.Unlock()
   386  
   387  	p, ok := m.pools[datahash]
   388  	if !ok {
   389  		p = &syncPool{
   390  			height:    height,
   391  			pool:      newPool(m.params.PeerCooldown),
   392  			createdAt: time.Now(),
   393  		}
   394  		m.pools[datahash] = p
   395  	}
   396  
   397  	return p
   398  }
   399  
   400  func (m *Manager) blacklistPeers(reason blacklistPeerReason, peerIDs ...peer.ID) {
   401  	m.metrics.observeBlacklistPeers(reason, len(peerIDs))
   402  
   403  	for _, peerID := range peerIDs {
   404  		// blacklisted peers are logged regardless of whether the EnableBlackListing option is
   405  		// enabled, until blacklisting is properly tested and enabled by default.
   406  		log.Debugw("blacklisting peer", "peer", peerID.String(), "reason", reason)
   407  		if !m.params.EnableBlackListing {
   408  			continue
   409  		}
   410  
   411  		m.fullNodes.remove(peerID)
   412  		// add peer to the blacklist, so we can't connect to it in the future.
   413  		err := m.connGater.BlockPeer(peerID)
   414  		if err != nil {
   415  			log.Warnw("failed to block peer", "peer", peerID, "err", err)
   416  		}
   417  		// close connections to peer.
   418  		err = m.host.Network().ClosePeer(peerID)
   419  		if err != nil {
   420  			log.Warnw("failed to close connection with peer", "peer", peerID, "err", err)
   421  		}
   422  	}
   423  }
   424  
   425  func (m *Manager) isBlacklistedPeer(peerID peer.ID) bool {
   426  	return !m.connGater.InterceptPeerDial(peerID)
   427  }
   428  
   429  func (m *Manager) isBlacklistedHash(hash share.DataHash) bool {
   430  	m.lock.Lock()
   431  	defer m.lock.Unlock()
   432  	return m.blacklistedHashes[hash.String()]
   433  }
   434  
   435  func (m *Manager) validatedPool(hashStr string, height uint64) *syncPool {
   436  	p := m.getOrCreatePool(hashStr, height)
   437  	if p.isValidatedDataHash.CompareAndSwap(false, true) {
   438  		log.Debugw("pool marked validated", "datahash", hashStr)
   439  		// if pool is proven to be valid, add all collected peers to full nodes
   440  		m.fullNodes.add(p.peers()...)
   441  	}
   442  	return p
   443  }
   444  
   445  // removeIfUnreachable removes the peer from the given pool if it is blacklisted or disconnected (no longer in the full nodes pool)
   446  func (m *Manager) removeIfUnreachable(pool *syncPool, peerID peer.ID) bool {
   447  	if m.isBlacklistedPeer(peerID) || !m.fullNodes.has(peerID) {
   448  		log.Debugw("removing outdated peer from pool", "peer", peerID.String())
   449  		pool.remove(peerID)
   450  		return true
   451  	}
   452  	return false
   453  }
   454  
   455  func (m *Manager) GC(ctx context.Context) {
   456  	ticker := time.NewTicker(m.params.GcInterval)
   457  	defer ticker.Stop()
   458  
   459  	var blacklist []peer.ID
   460  	for {
   461  		select {
   462  		case <-ticker.C:
   463  		case <-ctx.Done():
   464  			return
   465  		}
   466  
   467  		blacklist = m.cleanUp()
   468  		if len(blacklist) > 0 {
   469  			m.blacklistPeers(reasonInvalidHash, blacklist...)
   470  		}
   471  	}
   472  }
   473  
   474  func (m *Manager) cleanUp() []peer.ID {
   475  	if m.initialHeight.Load() == 0 {
   476  		// can't blacklist peers until initialHeight is set
   477  		return nil
   478  	}
   479  
   480  	m.lock.Lock()
   481  	defer m.lock.Unlock()
   482  
   483  	addToBlackList := make(map[peer.ID]struct{})
   484  	for h, p := range m.pools {
   485  		if p.isValidatedDataHash.Load() {
   486  			// remove pools that are outdated
   487  			if p.height < m.storeFrom.Load() {
   488  				delete(m.pools, h)
   489  			}
   490  			continue
   491  		}
   492  
   493  		// can't validate datahashes below initial height
   494  		if p.height < m.initialHeight.Load() {
   495  			delete(m.pools, h)
   496  			continue
   497  		}
   498  
   499  		// find pools that are not validated in time
   500  		if time.Since(p.createdAt) > m.params.PoolValidationTimeout {
   501  			delete(m.pools, h)
   502  
   503  			log.Debugw("blacklisting datahash with all corresponding peers",
   504  				"hash", h,
   505  				"peer_list", p.peersList)
   506  			// blacklist hash
   507  			m.blacklistedHashes[h] = true
   508  
   509  			// blacklist peers
   510  			for _, peer := range p.peersList {
   511  				addToBlackList[peer] = struct{}{}
   512  			}
   513  		}
   514  	}
   515  
   516  	blacklist := make([]peer.ID, 0, len(addToBlackList))
   517  	for peerID := range addToBlackList {
   518  		blacklist = append(blacklist, peerID)
   519  	}
   520  	return blacklist
   521  }