github.com/koko1123/flow-go-1@v0.29.6/network/p2p/unicast/manager.go (about)

     1  package unicast
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/hashicorp/go-multierror"
    12  	libp2pnet "github.com/libp2p/go-libp2p/core/network"
    13  	"github.com/libp2p/go-libp2p/core/peer"
    14  	"github.com/libp2p/go-libp2p/core/protocol"
    15  	"github.com/libp2p/go-libp2p/p2p/net/swarm"
    16  	"github.com/multiformats/go-multiaddr"
    17  	"github.com/rs/zerolog"
    18  
    19  	"github.com/koko1123/flow-go-1/model/flow"
    20  )
    21  
// MaxConnectAttemptSleepDuration is the exclusive upper bound, in milliseconds, of the random delay inserted
// between consecutive attempts at establishing a 1-1 direct connection. rawStreamWithProtocol draws the delay
// with rand.Intn(MaxConnectAttemptSleepDuration), so the actual wait is between 0 and
// MaxConnectAttemptSleepDuration-1 milliseconds.
const MaxConnectAttemptSleepDuration = 5
    24  
// Manager manages libp2p stream negotiation and creation, which is utilized for unicast dispatches.
//
// Fields:
//   - logger: logger tagged with the "unicast-manager" module by NewUnicastManager.
//   - streamFactory: used to register stream handlers, and to connect to and open streams with remote peers.
//   - unicasts: registered unicast protocols in registration order. CreateStream walks this list from the
//     last element to the first, so protocols registered later are tried first.
//   - defaultHandler: stream handler set by WithDefaultHandler; it is handed to every protocol factory
//     invoked by Register.
//   - sporkId: identifier from which the default protocol id is derived in WithDefaultHandler, and which
//     is passed to the protocol factories in Register.
type Manager struct {
	logger         zerolog.Logger
	streamFactory  StreamFactory
	unicasts       []Protocol
	defaultHandler libp2pnet.StreamHandler
	sporkId        flow.Identifier
}
    33  
    34  func NewUnicastManager(logger zerolog.Logger, streamFactory StreamFactory, sporkId flow.Identifier) *Manager {
    35  	return &Manager{
    36  		logger:        logger.With().Str("module", "unicast-manager").Logger(),
    37  		streamFactory: streamFactory,
    38  		sporkId:       sporkId,
    39  	}
    40  }
    41  
    42  // WithDefaultHandler sets the default stream handler for this unicast manager. The default handler is utilized
    43  // as the core handler for other unicast protocols, e.g., compressions.
    44  func (m *Manager) WithDefaultHandler(defaultHandler libp2pnet.StreamHandler) {
    45  	defaultProtocolID := FlowProtocolID(m.sporkId)
    46  	m.defaultHandler = defaultHandler
    47  
    48  	if len(m.unicasts) > 0 {
    49  		panic("default handler must be set only once before any unicast registration")
    50  	}
    51  
    52  	m.unicasts = []Protocol{
    53  		&PlainStream{
    54  			protocolId: defaultProtocolID,
    55  			handler:    defaultHandler,
    56  		},
    57  	}
    58  
    59  	m.streamFactory.SetStreamHandler(defaultProtocolID, defaultHandler)
    60  	m.logger.Info().Str("protocol_id", string(defaultProtocolID)).Msg("default unicast handler registered")
    61  }
    62  
    63  // Register registers given protocol name as preferred unicast. Each invocation of register prioritizes the current protocol
    64  // over previously registered ones.
    65  func (m *Manager) Register(unicast ProtocolName) error {
    66  	factory, err := ToProtocolFactory(unicast)
    67  	if err != nil {
    68  		return fmt.Errorf("could not translate protocol name into factory: %w", err)
    69  	}
    70  
    71  	u := factory(m.logger, m.sporkId, m.defaultHandler)
    72  
    73  	m.unicasts = append(m.unicasts, u)
    74  	m.streamFactory.SetStreamHandler(u.ProtocolId(), u.Handler)
    75  	m.logger.Info().Str("protocol_id", string(u.ProtocolId())).Msg("unicast handler registered")
    76  
    77  	return nil
    78  }
    79  
    80  // CreateStream tries establishing a libp2p stream to the remote peer id. It tries creating streams in the descending order of preference until
    81  // it either creates a successful stream or runs out of options. Creating stream on each protocol is tried at most `maxAttempt` one, and then falls
    82  // back to the less preferred one.
    83  func (m *Manager) CreateStream(ctx context.Context, peerID peer.ID, maxAttempts int) (libp2pnet.Stream, []multiaddr.Multiaddr, error) {
    84  	var errs error
    85  
    86  	for i := len(m.unicasts) - 1; i >= 0; i-- {
    87  		s, addrs, err := m.rawStreamWithProtocol(ctx, m.unicasts[i].ProtocolId(), peerID, maxAttempts)
    88  		if err != nil {
    89  			errs = multierror.Append(errs, err)
    90  			continue
    91  		}
    92  
    93  		s, err = m.unicasts[i].UpgradeRawStream(s)
    94  		if err != nil {
    95  			errs = multierror.Append(errs, fmt.Errorf("could not upgrade stream: %w", err))
    96  			continue
    97  		}
    98  
    99  		// return first successful stream
   100  		return s, addrs, nil
   101  	}
   102  
   103  	return nil, nil, fmt.Errorf("could not create stream on any available unicast protocol: %w", errs)
   104  }
   105  
   106  // rawStreamWithProtocol creates a stream raw libp2p stream on specified protocol.
   107  //
   108  // Note: a raw stream must be upgraded by the given unicast protocol id.
   109  //
   110  // It makes at most `maxAttempts` to create a stream with the peer.
   111  // This was put in as a fix for #2416. PubSub and 1-1 communication compete with each other when trying to connect to
   112  // remote nodes and once in a while NewStream returns an error 'both yamux endpoints are clients'.
   113  //
   114  // Note that in case an existing TCP connection underneath to `peerID` exists, that connection is utilized for creating a new stream.
   115  // The multiaddr.Multiaddr return value represents the addresses of `peerID` we dial while trying to create a stream to it.
   116  func (m *Manager) rawStreamWithProtocol(ctx context.Context,
   117  	protocolID protocol.ID,
   118  	peerID peer.ID,
   119  	maxAttempts int) (libp2pnet.Stream, []multiaddr.Multiaddr, error) {
   120  
   121  	var errs error
   122  	var s libp2pnet.Stream
   123  	var retries = 0
   124  	var dialAddr []multiaddr.Multiaddr // address on which we dial peerID
   125  	for ; retries < maxAttempts; retries++ {
   126  		select {
   127  		case <-ctx.Done():
   128  			return nil, nil, fmt.Errorf("context done before stream could be created (retry attempt: %d, errors: %w)", retries, errs)
   129  		default:
   130  		}
   131  
   132  		// libp2p internally uses swarm dial - https://github.com/libp2p/go-libp2p-swarm/blob/master/swarm_dial.go
   133  		// to connect to a peer. Swarm dial adds a back off each time it fails connecting to a peer. While this is
   134  		// the desired behaviour for pub-sub (1-k style of communication) for 1-1 style we want to retry the connection
   135  		// immediately without backing off and fail-fast.
   136  		// Hence, explicitly cancel the dial back off (if any) and try connecting again
   137  
   138  		// cancel the dial back off (if any), since we want to connect immediately
   139  		dialAddr = m.streamFactory.DialAddress(peerID)
   140  		m.streamFactory.ClearBackoff(peerID)
   141  
   142  		// if this is a retry attempt, wait for some time before retrying
   143  		if retries > 0 {
   144  			// choose a random interval between 0 to 5
   145  			// (to ensure that this node and the target node don't attempt to reconnect at the same time)
   146  			r := rand.Intn(MaxConnectAttemptSleepDuration)
   147  			time.Sleep(time.Duration(r) * time.Millisecond)
   148  		}
   149  
   150  		err := m.streamFactory.Connect(ctx, peer.AddrInfo{ID: peerID})
   151  		if err != nil {
   152  
   153  			// if the connection was rejected due to invalid node id, skip the re-attempt
   154  			if strings.Contains(err.Error(), "failed to negotiate security protocol") {
   155  				return s, dialAddr, fmt.Errorf("invalid node id: %w", err)
   156  			}
   157  
   158  			// if the connection was rejected due to allowlisting, skip the re-attempt
   159  			if errors.Is(err, swarm.ErrGaterDisallowedConnection) {
   160  				return s, dialAddr, fmt.Errorf("target node is not on the approved list of nodes: %w", err)
   161  			}
   162  
   163  			errs = multierror.Append(errs, err)
   164  			continue
   165  		}
   166  
   167  		// creates stream using stream factory
   168  		s, err = m.streamFactory.NewStream(ctx, peerID, protocolID)
   169  		if err != nil {
   170  			// if the stream creation failed due to invalid protocol id, skip the re-attempt
   171  			if strings.Contains(err.Error(), "protocol not supported") {
   172  				return nil, dialAddr, fmt.Errorf("remote node is running on a different spork: %w, protocol attempted: %s", err, protocolID)
   173  			}
   174  			errs = multierror.Append(errs, err)
   175  			continue
   176  		}
   177  
   178  		break
   179  	}
   180  
   181  	if retries == maxAttempts {
   182  		return s, dialAddr, errs
   183  	}
   184  
   185  	return s, dialAddr, nil
   186  }