github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/internal/p2p/transport_mconn.go (about)

     1  package p2p
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  	"net"
    10  	"strconv"
    11  	"sync"
    12  
    13  	"golang.org/x/net/netutil"
    14  
    15  	"github.com/ari-anchor/sei-tendermint/crypto"
    16  	"github.com/ari-anchor/sei-tendermint/internal/libs/protoio"
    17  	"github.com/ari-anchor/sei-tendermint/internal/p2p/conn"
    18  	"github.com/ari-anchor/sei-tendermint/libs/log"
    19  	p2pproto "github.com/ari-anchor/sei-tendermint/proto/tendermint/p2p"
    20  	"github.com/ari-anchor/sei-tendermint/types"
    21  )
    22  
// Protocol names accepted by MConnTransport (see Protocols and
// validateEndpoint).
const (
	// MConnProtocol is the protocol name that this transport sets on the
	// endpoints it reports.
	MConnProtocol Protocol = "mconn"
	// TCPProtocol is accepted as an alternative protocol name for
	// backwards-compatibility.
	TCPProtocol   Protocol = "tcp"
)
    27  
// MConnTransportOptions sets options for MConnTransport.
type MConnTransportOptions struct {
	// MaxAcceptedConnections is the maximum number of simultaneous accepted
	// (incoming) connections. Beyond this, new connections will block until
	// a slot is free. 0 means unlimited.
	//
	// The limit is applied in Listen by wrapping the listener with
	// netutil.LimitListener.
	//
	// FIXME: We may want to replace this with connection accounting in the
	// Router, since it will need to do e.g. rate limiting and such as well.
	// But it might also make sense to have per-transport limits.
	MaxAcceptedConnections uint32
}
    39  
// MConnTransport is a Transport implementation using the current multiplexed
// Tendermint protocol ("MConn").
type MConnTransport struct {
	// logger, mConnConfig and channelDescs are handed to every connection
	// created by Accept and Dial.
	logger       log.Logger
	options      MConnTransportOptions
	mConnConfig  conn.MConnConfig
	channelDescs []*ChannelDescriptor

	// closeOnce ensures the body of Close runs only once.
	closeOnce sync.Once
	// doneCh is closed by Close to signal that the transport is shut down.
	doneCh    chan struct{}
	// listener is nil until Listen succeeds.
	listener  net.Listener
}
    52  
    53  // NewMConnTransport sets up a new MConnection transport. This uses the
    54  // proprietary Tendermint MConnection protocol, which is implemented as
    55  // conn.MConnection.
    56  func NewMConnTransport(
    57  	logger log.Logger,
    58  	mConnConfig conn.MConnConfig,
    59  	channelDescs []*ChannelDescriptor,
    60  	options MConnTransportOptions,
    61  ) *MConnTransport {
    62  	return &MConnTransport{
    63  		logger:       logger,
    64  		options:      options,
    65  		mConnConfig:  mConnConfig,
    66  		doneCh:       make(chan struct{}),
    67  		channelDescs: channelDescs,
    68  	}
    69  }
    70  
// String implements Transport. It returns the MConn protocol name.
func (m *MConnTransport) String() string {
	return string(MConnProtocol)
}
    75  
// Protocols implements Transport. We support tcp for backwards-compatibility.
// The returned slice is freshly allocated on every call.
func (m *MConnTransport) Protocols() []Protocol {
	return []Protocol{MConnProtocol, TCPProtocol}
}
    80  
    81  // Endpoint implements Transport.
    82  func (m *MConnTransport) Endpoint() (*Endpoint, error) {
    83  	if m.listener == nil {
    84  		return nil, errors.New("listenter not defined")
    85  	}
    86  	select {
    87  	case <-m.doneCh:
    88  		return nil, errors.New("transport closed")
    89  	default:
    90  	}
    91  
    92  	endpoint := &Endpoint{
    93  		Protocol: MConnProtocol,
    94  	}
    95  	if addr, ok := m.listener.Addr().(*net.TCPAddr); ok {
    96  		endpoint.IP = addr.IP
    97  		endpoint.Port = uint16(addr.Port)
    98  	}
    99  	return endpoint, nil
   100  }
   101  
   102  // Listen asynchronously listens for inbound connections on the given endpoint.
   103  // It must be called exactly once before calling Accept(), and the caller must
   104  // call Close() to shut down the listener.
   105  //
   106  // FIXME: Listen currently only supports listening on a single endpoint, it
   107  // might be useful to support listening on multiple addresses (e.g. IPv4 and
   108  // IPv6, or a private and public address) via multiple Listen() calls.
   109  func (m *MConnTransport) Listen(endpoint *Endpoint) error {
   110  	if m.listener != nil {
   111  		return errors.New("transport is already listening")
   112  	}
   113  	if err := m.validateEndpoint(endpoint); err != nil {
   114  		return err
   115  	}
   116  
   117  	listener, err := net.Listen("tcp", net.JoinHostPort(
   118  		endpoint.IP.String(), strconv.Itoa(int(endpoint.Port))))
   119  	if err != nil {
   120  		return err
   121  	}
   122  	if m.options.MaxAcceptedConnections > 0 {
   123  		// FIXME: This will establish the inbound connection but simply hang it
   124  		// until another connection is released. It would probably be better to
   125  		// return an error to the remote peer or close the connection. This is
   126  		// also a DoS vector since the connection will take up kernel resources.
   127  		// This was just carried over from the legacy P2P stack.
   128  		listener = netutil.LimitListener(listener, int(m.options.MaxAcceptedConnections))
   129  	}
   130  	m.listener = listener
   131  
   132  	return nil
   133  }
   134  
   135  // Accept implements Transport.
   136  func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) {
   137  	if m.listener == nil {
   138  		return nil, errors.New("transport is not listening")
   139  	}
   140  
   141  	conCh := make(chan net.Conn)
   142  	errCh := make(chan error)
   143  	go func() {
   144  		tcpConn, err := m.listener.Accept()
   145  		if err != nil {
   146  			select {
   147  			case errCh <- err:
   148  			case <-ctx.Done():
   149  			}
   150  		}
   151  		select {
   152  		case conCh <- tcpConn:
   153  		case <-ctx.Done():
   154  		}
   155  	}()
   156  
   157  	select {
   158  	case <-ctx.Done():
   159  		m.listener.Close()
   160  		return nil, io.EOF
   161  	case <-m.doneCh:
   162  		m.listener.Close()
   163  		return nil, io.EOF
   164  	case err := <-errCh:
   165  		return nil, err
   166  	case tcpConn := <-conCh:
   167  		return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil
   168  	}
   169  
   170  }
   171  
   172  // Dial implements Transport.
   173  func (m *MConnTransport) Dial(ctx context.Context, endpoint *Endpoint) (Connection, error) {
   174  	if err := m.validateEndpoint(endpoint); err != nil {
   175  		return nil, err
   176  	}
   177  	if endpoint.Port == 0 {
   178  		endpoint.Port = 26657
   179  	}
   180  
   181  	dialer := net.Dialer{}
   182  	tcpConn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort(
   183  		endpoint.IP.String(), strconv.Itoa(int(endpoint.Port))))
   184  	if err != nil {
   185  		select {
   186  		case <-ctx.Done():
   187  			return nil, ctx.Err()
   188  		default:
   189  			return nil, err
   190  		}
   191  	}
   192  
   193  	return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil
   194  }
   195  
   196  // Close implements Transport.
   197  func (m *MConnTransport) Close() error {
   198  	var err error
   199  	m.closeOnce.Do(func() {
   200  		close(m.doneCh)
   201  		if m.listener != nil {
   202  			err = m.listener.Close()
   203  		}
   204  	})
   205  	return err
   206  }
   207  
// AddChannelDescriptors appends the given channel descriptors to the ones
// passed to connections created by subsequent Accept and Dial calls.
//
// FIXME: To be removed when the legacy p2p stack is removed. Channel
// descriptors should be managed by the router. The underlying transport and
// connections should be agnostic to everything but the channel ID's which are
// initialized in the handshake.
func (m *MConnTransport) AddChannelDescriptors(channelDesc []*ChannelDescriptor) {
	m.channelDescs = append(m.channelDescs, channelDesc...)
}
   218  
   219  // validateEndpoint validates an endpoint.
   220  func (m *MConnTransport) validateEndpoint(endpoint *Endpoint) error {
   221  	if err := endpoint.Validate(); err != nil {
   222  		return err
   223  	}
   224  	if endpoint.Protocol != MConnProtocol && endpoint.Protocol != TCPProtocol {
   225  		return fmt.Errorf("unsupported protocol %q", endpoint.Protocol)
   226  	}
   227  	if len(endpoint.IP) == 0 {
   228  		return errors.New("endpoint has no IP address")
   229  	}
   230  	if endpoint.Path != "" {
   231  		return fmt.Errorf("endpoints with path not supported (got %q)", endpoint.Path)
   232  	}
   233  	return nil
   234  }
   235  
// mConnConnection implements Connection for MConnTransport.
type mConnConnection struct {
	logger       log.Logger
	// conn is the raw network connection the MConnection is built on.
	conn         net.Conn
	mConnConfig  conn.MConnConfig
	channelDescs []*ChannelDescriptor
	// receiveCh carries messages from onReceive to ReceiveMessage.
	receiveCh    chan mConnMessage
	// errorCh carries errors reported through onError.
	errorCh      chan error
	// doneCh is closed by Close.
	doneCh       chan struct{}
	// closeOnce ensures the body of Close runs only once.
	closeOnce    sync.Once

	mconn *conn.MConnection // set during Handshake()
}
   249  
// mConnMessage passes MConnection messages through internal channels. It
// pairs a received payload with the ID of the channel it arrived on.
type mConnMessage struct {
	channelID ChannelID
	payload   []byte
}
   255  
   256  // newMConnConnection creates a new mConnConnection.
   257  func newMConnConnection(
   258  	logger log.Logger,
   259  	conn net.Conn,
   260  	mConnConfig conn.MConnConfig,
   261  	channelDescs []*ChannelDescriptor,
   262  ) *mConnConnection {
   263  	return &mConnConnection{
   264  		logger:       logger,
   265  		conn:         conn,
   266  		mConnConfig:  mConnConfig,
   267  		channelDescs: channelDescs,
   268  		receiveCh:    make(chan mConnMessage),
   269  		errorCh:      make(chan error, 1), // buffered to avoid onError leak
   270  		doneCh:       make(chan struct{}),
   271  	}
   272  }
   273  
// Handshake implements Connection. It runs the handshake helper in a
// goroutine, and on success stores and starts the resulting MConnection and
// returns the peer's node info and public key.
//
// If ctx is canceled first, the connection is closed and ctx.Err() is
// returned. A panic raised during the handshake is recovered and returned as
// an error.
func (c *mConnConnection) Handshake(
	ctx context.Context,
	nodeInfo types.NodeInfo,
	privKey crypto.PrivKey,
) (types.NodeInfo, crypto.PubKey, error) {
	// mconn, peerInfo and peerKey are written by the goroutine below and only
	// read here after its result has been received from errCh. errCh has a
	// buffer of one so that the recover handler's send cannot block.
	var (
		mconn    *conn.MConnection
		peerInfo types.NodeInfo
		peerKey  crypto.PubKey
		errCh    = make(chan error, 1)
	)
	// To handle context cancellation, we need to do the handshake in a
	// goroutine and abort the blocking network calls by closing the connection
	// when the context is canceled.
	go func() {
		// FIXME: Since the MConnection code panics, we need to recover it and turn it
		// into an error. We should remove panics instead.
		defer func() {
			if r := recover(); r != nil {
				errCh <- fmt.Errorf("recovered from panic: %v", r)
			}
		}()
		var err error
		mconn, peerInfo, peerKey, err = c.handshake(ctx, nodeInfo, privKey)

		// Report the result (nil on success) unless the context is done.
		select {
		case errCh <- err:
		case <-ctx.Done():
		}

	}()

	select {
	case <-ctx.Done():
		// Closing the connection aborts the blocking network calls of the
		// handshake goroutine.
		_ = c.Close()
		return types.NodeInfo{}, nil, ctx.Err()

	case err := <-errCh:
		if err != nil {
			return types.NodeInfo{}, nil, err
		}
		// The field is assigned here rather than in the goroutine, to avoid
		// concurrent field writes.
		c.mconn = mconn
		if err = c.mconn.Start(ctx); err != nil {
			return types.NodeInfo{}, nil, err
		}
		return peerInfo, peerKey, nil
	}
}
   323  
   324  // handshake is a helper for Handshake, simplifying error handling so we can
   325  // keep context handling and panic recovery in Handshake. It returns an
   326  // unstarted but handshaked MConnection, to avoid concurrent field writes.
   327  func (c *mConnConnection) handshake(
   328  	ctx context.Context,
   329  	nodeInfo types.NodeInfo,
   330  	privKey crypto.PrivKey,
   331  ) (*conn.MConnection, types.NodeInfo, crypto.PubKey, error) {
   332  	if c.mconn != nil {
   333  		return nil, types.NodeInfo{}, nil, errors.New("connection is already handshaked")
   334  	}
   335  
   336  	secretConn, err := conn.MakeSecretConnection(c.conn, privKey)
   337  	if err != nil {
   338  		return nil, types.NodeInfo{}, nil, err
   339  	}
   340  
   341  	wg := &sync.WaitGroup{}
   342  	var pbPeerInfo p2pproto.NodeInfo
   343  	errCh := make(chan error, 2)
   344  	wg.Add(1)
   345  	go func() {
   346  		defer wg.Done()
   347  		_, err := protoio.NewDelimitedWriter(secretConn).WriteMsg(nodeInfo.ToProto())
   348  		select {
   349  		case errCh <- err:
   350  		case <-ctx.Done():
   351  		}
   352  
   353  	}()
   354  	wg.Add(1)
   355  	go func() {
   356  		defer wg.Done()
   357  		_, err := protoio.NewDelimitedReader(secretConn, types.MaxNodeInfoSize()).ReadMsg(&pbPeerInfo)
   358  		select {
   359  		case errCh <- err:
   360  		case <-ctx.Done():
   361  		}
   362  	}()
   363  
   364  	wg.Wait()
   365  
   366  	if err, ok := <-errCh; ok && err != nil {
   367  		return nil, types.NodeInfo{}, nil, err
   368  	}
   369  
   370  	if err := ctx.Err(); err != nil {
   371  		return nil, types.NodeInfo{}, nil, err
   372  	}
   373  
   374  	peerInfo, err := types.NodeInfoFromProto(&pbPeerInfo)
   375  	if err != nil {
   376  		return nil, types.NodeInfo{}, nil, err
   377  	}
   378  
   379  	c.logger.Debug(fmt.Sprintf("Creating a new MConnection with peerId %s, moniker %s, listenAddr %s", peerInfo.NodeID, peerInfo.Moniker, peerInfo.ListenAddr))
   380  
   381  	mconn := conn.NewMConnection(
   382  		c.logger.With("peer", c.RemoteEndpoint().NodeAddress(peerInfo.NodeID)),
   383  		secretConn,
   384  		c.channelDescs,
   385  		c.onReceive,
   386  		c.onError,
   387  		c.mConnConfig,
   388  	)
   389  
   390  	return mconn, peerInfo, secretConn.RemotePubKey(), nil
   391  }
   392  
// onReceive is a callback for MConnection received messages. It forwards the
// message to receiveCh, blocking until it has been taken from the channel or
// ctx is done, in which case the message is dropped.
func (c *mConnConnection) onReceive(ctx context.Context, chID ChannelID, payload []byte) {
	select {
	case c.receiveCh <- mConnMessage{channelID: chID, payload: payload}:
	case <-ctx.Done():
	}
}
   400  
   401  // onError is a callback for MConnection errors. The error is passed via errorCh
   402  // to ReceiveMessage (but not SendMessage, for legacy P2P stack behavior).
   403  func (c *mConnConnection) onError(ctx context.Context, e interface{}) {
   404  	err, ok := e.(error)
   405  	if !ok {
   406  		err = fmt.Errorf("%v", err)
   407  	}
   408  	// We have to close the connection here, since MConnection will have stopped
   409  	// the service on any errors.
   410  	_ = c.Close()
   411  	select {
   412  	case c.errorCh <- err:
   413  		c.logger.Error(fmt.Sprintf("mConnection Error %s", err))
   414  	case <-ctx.Done():
   415  	}
   416  }
   417  
// String displays connection information, namely the string form of the
// remote endpoint.
func (c *mConnConnection) String() string {
	return c.RemoteEndpoint().String()
}
   422  
   423  // SendMessage implements Connection.
   424  func (c *mConnConnection) SendMessage(ctx context.Context, chID ChannelID, msg []byte) error {
   425  	if chID > math.MaxUint8 {
   426  		return fmt.Errorf("MConnection only supports 1-byte channel IDs (got %v)", chID)
   427  	}
   428  	select {
   429  	case err := <-c.errorCh:
   430  		return err
   431  	case <-ctx.Done():
   432  		return io.EOF
   433  	default:
   434  		if ok := c.mconn.Send(chID, msg); !ok {
   435  			return errors.New("sending message timed out")
   436  		}
   437  
   438  		return nil
   439  	}
   440  }
   441  
// ReceiveMessage implements Connection. It blocks until a message arrives on
// receiveCh, an error arrives on errorCh, the connection is closed, or ctx is
// done. Closure and cancellation are both reported as io.EOF.
func (c *mConnConnection) ReceiveMessage(ctx context.Context) (ChannelID, []byte, error) {
	select {
	case err := <-c.errorCh:
		return 0, nil, err
	case <-c.doneCh:
		return 0, nil, io.EOF
	case <-ctx.Done():
		return 0, nil, io.EOF
	case msg := <-c.receiveCh:
		return msg.channelID, msg.payload, nil
	}
}
   455  
   456  // LocalEndpoint implements Connection.
   457  func (c *mConnConnection) LocalEndpoint() Endpoint {
   458  	endpoint := Endpoint{
   459  		Protocol: MConnProtocol,
   460  	}
   461  	if addr, ok := c.conn.LocalAddr().(*net.TCPAddr); ok {
   462  		endpoint.IP = addr.IP
   463  		endpoint.Port = uint16(addr.Port)
   464  	}
   465  	return endpoint
   466  }
   467  
   468  // RemoteEndpoint implements Connection.
   469  func (c *mConnConnection) RemoteEndpoint() Endpoint {
   470  	endpoint := Endpoint{
   471  		Protocol: MConnProtocol,
   472  	}
   473  	if addr, ok := c.conn.RemoteAddr().(*net.TCPAddr); ok {
   474  		endpoint.IP = addr.IP
   475  		endpoint.Port = uint16(addr.Port)
   476  	}
   477  	return endpoint
   478  }
   479  
   480  // Close implements Connection.
   481  func (c *mConnConnection) Close() error {
   482  	var err error
   483  	c.closeOnce.Do(func() {
   484  		defer close(c.doneCh)
   485  
   486  		if c.mconn != nil && c.mconn.IsRunning() {
   487  			c.mconn.Stop()
   488  		} else {
   489  			err = c.conn.Close()
   490  		}
   491  	})
   492  	return err
   493  }