github.com/aergoio/aergo@v1.3.1/p2p/handshake.go (about)

     1  /**
     2   *  @file
     3   *  @copyright defined in aergo/LICENSE.txt
     4   */
     5  
     6  package p2p
     7  
     8  import (
     9  	"context"
    10  	"fmt"
    11  	"github.com/aergoio/aergo/p2p/v030"
    12  	"io"
    13  	"time"
    14  
    15  	"github.com/aergoio/aergo-lib/log"
    16  	"github.com/aergoio/aergo/p2p/p2pcommon"
    17  	"github.com/aergoio/aergo/types"
    18  )
    19  
    20  // LegacyInboundHSHandler handshake handler for legacy version
    21  type LegacyInboundHSHandler struct {
    22  	*LegacyWireHandshaker
    23  }
    24  
    25  func (ih *LegacyInboundHSHandler) Handle(s io.ReadWriteCloser, ttl time.Duration) (p2pcommon.MsgReadWriter, *types.Status, error) {
    26  	ctx, cancel := context.WithTimeout(context.Background(), ttl)
    27  	defer cancel()
    28  	return ih.handshakeInboundPeer(ctx, s)
    29  }
    30  
    31  // LegacyOutboundHSHandler handshake handler for legacy version
    32  type LegacyOutboundHSHandler struct {
    33  	*LegacyWireHandshaker
    34  }
    35  
    36  func (oh *LegacyOutboundHSHandler) Handle(s io.ReadWriteCloser, ttl time.Duration) (p2pcommon.MsgReadWriter, *types.Status, error) {
    37  	ctx, cancel := context.WithTimeout(context.Background(), ttl)
    38  	defer cancel()
    39  	return oh.handshakeOutboundPeer(ctx, s)
    40  }
    41  
// LegacyWireHandshaker performs the legacy handshake with a just-connected peer.
// It exchanges a fixed-size handshake header carrying a magic value and a
// protocol version, then selects the inner VersionedHandshaker for that protocol
// version to finish the handshake.
// NOTE(review): only the version is acted on in this file; the magic value is
// not validated here.
type LegacyWireHandshaker struct {
	// pm, actor and logger are handed to the versioned handshaker.
	pm     p2pcommon.PeerManager
	actor  p2pcommon.ActorService
	logger *log.Logger
	// peerID is passed to the versioned handshaker.
	// NOTE(review): confirm whether this is the local or the remote peer's ID.
	peerID types.PeerID
	// localChainID is this node's chain ID, passed to the versioned handshaker.
	// NOTE(review): an earlier note said "check if it is ad-hoc"; its intent is
	// unclear — confirm.
	localChainID *types.ChainID

	// remoteStatus holds the status returned by the most recent versioned
	// handshake. It is set even when that handshake returned an error.
	remoteStatus *types.Status
}
    54  
    55  func newHandshaker(pm p2pcommon.PeerManager, actor p2pcommon.ActorService, log *log.Logger, chainID *types.ChainID, peerID types.PeerID) *LegacyWireHandshaker {
    56  	return &LegacyWireHandshaker{pm: pm, actor: actor, logger: log, localChainID: chainID, peerID: peerID}
    57  }
    58  
    59  func (h *LegacyWireHandshaker) handshakeOutboundPeer(ctx context.Context, rwc io.ReadWriteCloser) (p2pcommon.MsgReadWriter, *types.Status, error) {
    60  	// send initial hsmessage
    61  	hsHeader := p2pcommon.HSHeader{Magic: p2pcommon.MAGICTest, Version: p2pcommon.P2PVersion030}
    62  	sent, err := rwc.Write(hsHeader.Marshal())
    63  	if err != nil {
    64  		return nil, nil, err
    65  	}
    66  	select {
    67  	case <-ctx.Done():
    68  		return nil, nil, ctx.Err()
    69  	default:
    70  		// go on
    71  	}
    72  	if sent != len(hsHeader.Marshal()) {
    73  		return nil, nil, fmt.Errorf("transport error")
    74  	}
    75  	// continue to handshake with VersionedHandshaker
    76  	innerHS, err := h.selectProtocolVersion(hsHeader.Version, rwc)
    77  	if err != nil {
    78  		return nil, nil, err
    79  	}
    80  	status, err := innerHS.DoForOutbound(ctx)
    81  	h.remoteStatus = status
    82  	return innerHS.GetMsgRW(), status, err
    83  }
    84  
    85  func (h *LegacyWireHandshaker) handshakeInboundPeer(ctx context.Context, rwc io.ReadWriteCloser) (p2pcommon.MsgReadWriter, *types.Status, error) {
    86  	var hsHeader p2pcommon.HSHeader
    87  	// wait initial hsmessage
    88  	headBuf := make([]byte, p2pcommon.V030HSHeaderLength)
    89  	read, err := h.readToLen(rwc, headBuf, 8)
    90  	if err != nil {
    91  		return nil, nil, err
    92  	}
    93  	select {
    94  	case <-ctx.Done():
    95  		return nil, nil, ctx.Err()
    96  	default:
    97  		// go on
    98  	}
    99  	if read != p2pcommon.V030HSHeaderLength {
   100  		return nil, nil, fmt.Errorf("transport error")
   101  	}
   102  	hsHeader.Unmarshal(headBuf)
   103  
   104  	// continue to handshake with VersionedHandshaker
   105  	innerHS, err := h.selectProtocolVersion(hsHeader.Version, rwc)
   106  	if err != nil {
   107  		return nil, nil, err
   108  	}
   109  	status, err := innerHS.DoForInbound(ctx)
   110  	// send hsresponse
   111  	h.remoteStatus = status
   112  	return innerHS.GetMsgRW(), status, err
   113  }
   114  
   115  func (h *LegacyWireHandshaker) readToLen(rd io.Reader, bf []byte, max int) (int, error) {
   116  	remain := max
   117  	offset := 0
   118  	for remain > 0 {
   119  		read, err := rd.Read(bf[offset:])
   120  		if err != nil {
   121  			return offset, err
   122  		}
   123  		remain -= read
   124  		offset += read
   125  	}
   126  	return offset, nil
   127  }
   128  
   129  func (h *LegacyWireHandshaker) selectProtocolVersion(version p2pcommon.P2PVersion, rwc io.ReadWriteCloser) (p2pcommon.VersionedHandshaker, error) {
   130  	switch version {
   131  	case p2pcommon.P2PVersion030:
   132  		v030hs := v030.NewV030VersionedHS(h.pm, h.actor, h.logger, h.localChainID, h.peerID, rwc)
   133  		return v030hs, nil
   134  	default:
   135  		return nil, fmt.Errorf("not supported version")
   136  	}
   137  }
   138