github.com/koko1123/flow-go-1@v0.29.6/network/p2p/p2pnode/protocolPeerCache.go (about)

     1  package p2pnode
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  
     7  	"github.com/libp2p/go-libp2p/core/event"
     8  	"github.com/libp2p/go-libp2p/core/host"
     9  	libp2pnet "github.com/libp2p/go-libp2p/core/network"
    10  	"github.com/libp2p/go-libp2p/core/peer"
    11  	"github.com/libp2p/go-libp2p/core/protocol"
    12  	"github.com/rs/zerolog"
    13  )
    14  
// ProtocolPeerCache stores a mapping from protocol ID to the set of peers that support that protocol.
type ProtocolPeerCache struct {
	// protocolPeers maps each protocol ID to the set of peer IDs recorded for it.
	// The cache's methods read and write it only while holding the embedded RWMutex.
	protocolPeers map[protocol.ID]map[peer.ID]struct{}
	// RWMutex is embedded, so Lock/Unlock/RLock/RUnlock are promoted onto the cache itself.
	sync.RWMutex
}
    20  
    21  func NewProtocolPeerCache(logger zerolog.Logger, h host.Host) (*ProtocolPeerCache, error) {
    22  	sub, err := h.EventBus().
    23  		Subscribe([]interface{}{new(event.EvtPeerIdentificationCompleted), new(event.EvtPeerProtocolsUpdated)})
    24  	if err != nil {
    25  		return nil, fmt.Errorf("could not subscribe to peer protocol update events: %w", err)
    26  	}
    27  	p := &ProtocolPeerCache{protocolPeers: make(map[protocol.ID]map[peer.ID]struct{})}
    28  	h.Network().Notify(&libp2pnet.NotifyBundle{
    29  		DisconnectedF: func(n libp2pnet.Network, c libp2pnet.Conn) {
    30  			peer := c.RemotePeer()
    31  			if len(n.ConnsToPeer(peer)) == 0 {
    32  				p.RemovePeer(peer)
    33  			}
    34  		},
    35  	})
    36  	go p.consumeSubscription(logger, h, sub)
    37  
    38  	return p, nil
    39  }
    40  
    41  func (p *ProtocolPeerCache) RemovePeer(peerID peer.ID) {
    42  	p.Lock()
    43  	defer p.Unlock()
    44  	for pid, peers := range p.protocolPeers {
    45  		delete(peers, peerID)
    46  		if len(peers) == 0 {
    47  			delete(p.protocolPeers, pid)
    48  		}
    49  	}
    50  }
    51  
    52  func (p *ProtocolPeerCache) AddProtocols(peerID peer.ID, protocols []protocol.ID) {
    53  	p.Lock()
    54  	defer p.Unlock()
    55  	for _, pid := range protocols {
    56  		peers, ok := p.protocolPeers[pid]
    57  		if !ok {
    58  			peers = make(map[peer.ID]struct{})
    59  			p.protocolPeers[pid] = peers
    60  		}
    61  		peers[peerID] = struct{}{}
    62  	}
    63  }
    64  
    65  func (p *ProtocolPeerCache) RemoveProtocols(peerID peer.ID, protocols []protocol.ID) {
    66  	p.Lock()
    67  	defer p.Unlock()
    68  	for _, pid := range protocols {
    69  		peers := p.protocolPeers[pid]
    70  		delete(peers, peerID)
    71  		if len(peers) == 0 {
    72  			delete(p.protocolPeers, pid)
    73  		}
    74  	}
    75  }
    76  
    77  func (p *ProtocolPeerCache) GetPeers(pid protocol.ID) map[peer.ID]struct{} {
    78  	p.RLock()
    79  	defer p.RUnlock()
    80  
    81  	// it is not safe to return a reference to the map, so we make a copy
    82  	peersCopy := make(map[peer.ID]struct{}, len(p.protocolPeers[pid]))
    83  	for peerID := range p.protocolPeers[pid] {
    84  		peersCopy[peerID] = struct{}{}
    85  	}
    86  	return peersCopy
    87  }
    88  
    89  func (p *ProtocolPeerCache) consumeSubscription(logger zerolog.Logger, h host.Host, sub event.Subscription) {
    90  	defer sub.Close()
    91  	logger.Debug().Msg("starting peer protocol event subscription loop")
    92  	for e := range sub.Out() {
    93  		logger.Debug().Interface("event", e).Msg("received new peer protocol event")
    94  		switch evt := e.(type) {
    95  		case event.EvtPeerIdentificationCompleted:
    96  			protocols, err := h.Peerstore().GetProtocols(evt.Peer)
    97  			if err != nil {
    98  				logger.Err(err).Str("peer", evt.Peer.String()).Msg("failed to get protocols for peer")
    99  				continue
   100  			}
   101  			p.AddProtocols(evt.Peer, protocol.ConvertFromStrings(protocols))
   102  		case event.EvtPeerProtocolsUpdated:
   103  			p.AddProtocols(evt.Peer, evt.Added)
   104  			p.RemoveProtocols(evt.Peer, evt.Removed)
   105  		}
   106  	}
   107  	logger.Debug().Msg("exiting peer protocol event subscription loop")
   108  }