github.com/aergoio/aergo@v1.3.1/p2p/v030/v030handshake.go (about)

     1  /*
     2   * @file
     3   * @copyright defined in aergo/LICENSE.txt
     4   */
     5  
     6  package v030
     7  
     8  import (
     9  	"context"
    10  	"fmt"
    11  	"github.com/aergoio/aergo/internal/network"
    12  	"github.com/aergoio/aergo/p2p/p2pkey"
    13  	"io"
    14  	"time"
    15  
    16  	"github.com/aergoio/aergo-lib/log"
    17  	"github.com/aergoio/aergo/p2p/p2pcommon"
    18  	"github.com/aergoio/aergo/p2p/p2putil"
    19  	"github.com/aergoio/aergo/types"
    20  )
    21  
    22  // V030Handshaker exchange status data over protocol version .0.3.0
    23  type V030Handshaker struct {
    24  	pm      p2pcommon.PeerManager
    25  	actor   p2pcommon.ActorService
    26  	logger  *log.Logger
    27  	peerID  types.PeerID
    28  	chainID *types.ChainID
    29  
    30  	msgRW p2pcommon.MsgReadWriter
    31  }
    32  
    33  var _ p2pcommon.VersionedHandshaker = (*V030Handshaker)(nil)
    34  
    35  func (h *V030Handshaker) GetMsgRW() p2pcommon.MsgReadWriter {
    36  	return h.msgRW
    37  }
    38  
    39  func NewV030VersionedHS(pm p2pcommon.PeerManager, actor p2pcommon.ActorService, log *log.Logger, chainID *types.ChainID, peerID types.PeerID, rwc io.ReadWriteCloser) *V030Handshaker {
    40  	h := &V030Handshaker{pm: pm, actor: actor, logger: log, chainID: chainID, peerID: peerID}
    41  	h.msgRW = NewV030MsgPipe(rwc)
    42  	return h
    43  }
    44  
    45  // handshakeOutboundPeer start handshake with outbound peer
    46  func (h *V030Handshaker) DoForOutbound(ctx context.Context) (*types.Status, error) {
    47  	// TODO need to check auth at first...
    48  	h.logger.Debug().Str(p2putil.LogPeerID, p2putil.ShortForm(h.peerID)).Msg("Starting versioned handshake for outbound peer connection")
    49  
    50  	status, err := createStatus(h.pm, h.actor, h.chainID, nil)
    51  	if err != nil {
    52  		h.logger.Warn().Err(err).Msg("Failed to create status message.")
    53  		h.sendGoAway("internal error")
    54  		return nil, err
    55  	}
    56  
    57  	err = h.sendLocalStatus(ctx, status)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	remotePeerStatus, err := h.receiveRemoteStatus(ctx)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	if err = h.checkRemoteStatus(remotePeerStatus); err != nil {
    68  		return nil, err
    69  	} else {
    70  		return remotePeerStatus, nil
    71  	}
    72  }
    73  
    74  func (h *V030Handshaker) sendLocalStatus(ctx context.Context, hostStatus *types.Status) error {
    75  	var err error
    76  	container := createMessage(p2pcommon.StatusRequest, p2pcommon.NewMsgID(), hostStatus)
    77  	if container == nil {
    78  		h.logger.Warn().Str(p2putil.LogPeerID, p2putil.ShortForm(h.peerID)).Msg("failed to create p2p message")
    79  		h.sendGoAway("internal error")
    80  		// h.logger.Warn().Str(LogPeerID, ShortForm(peerID)).Err(err).Msg("failed to create p2p message")
    81  		return fmt.Errorf("failed to craete container message")
    82  	}
    83  	if err = h.msgRW.WriteMsg(container); err != nil {
    84  		h.logger.Info().Str(p2putil.LogPeerID, p2putil.ShortForm(h.peerID)).Err(err).Msg("failed to write local status ")
    85  		return err
    86  	}
    87  	select {
    88  	case <-ctx.Done():
    89  		return ctx.Err()
    90  	default:
    91  		// go on
    92  	}
    93  	return nil
    94  }
    95  
    96  func (h *V030Handshaker) receiveRemoteStatus(ctx context.Context) (*types.Status, error) {
    97  	// and wait to response status
    98  	data, err := h.msgRW.ReadMsg()
    99  	if err != nil {
   100  		h.sendGoAway("malformed message")
   101  		// h.logger.Info().Err(err).Msg("fail to decode")
   102  		return nil, err
   103  	}
   104  	select {
   105  	case <-ctx.Done():
   106  		return nil, ctx.Err()
   107  	default:
   108  		// go on
   109  	}
   110  	if data.Subprotocol() != p2pcommon.StatusRequest {
   111  		if data.Subprotocol() == p2pcommon.GoAway {
   112  			return h.handleGoAway(h.peerID, data)
   113  		} else {
   114  			h.logger.Info().Str(p2putil.LogPeerID, p2putil.ShortForm(h.peerID)).Str("expected", p2pcommon.StatusRequest.String()).Str("actual", data.Subprotocol().String()).Msg("unexpected message type")
   115  			h.sendGoAway("unexpected message type")
   116  			return nil, fmt.Errorf("unexpected message type")
   117  		}
   118  	}
   119  
   120  	remotePeerStatus := &types.Status{}
   121  	err = p2putil.UnmarshalMessageBody(data.Payload(), remotePeerStatus)
   122  	if err != nil {
   123  		h.sendGoAway("malformed status message")
   124  		return nil, err
   125  	}
   126  
   127  	return remotePeerStatus, nil
   128  }
   129  
   130  func (h *V030Handshaker) checkRemoteStatus(remotePeerStatus *types.Status) error {
   131  	// check if chainID is same or not
   132  	remoteChainID := types.NewChainID()
   133  	err := remoteChainID.Read(remotePeerStatus.ChainID)
   134  	if err != nil {
   135  		h.sendGoAway("wrong status")
   136  		return err
   137  	}
   138  	if !h.chainID.Equals(remoteChainID) {
   139  		h.sendGoAway("different chainID")
   140  		return fmt.Errorf("different chainID : %s", remoteChainID.ToJSON())
   141  	}
   142  
   143  	peerAddress := remotePeerStatus.Sender
   144  	if peerAddress == nil || network.CheckAddressType(peerAddress.Address) == network.AddressTypeError {
   145  		h.sendGoAway("invalid peer address")
   146  		return fmt.Errorf("invalid peer address : %s", peerAddress)
   147  	}
   148  
   149  	rMeta := p2pcommon.FromPeerAddress(peerAddress)
   150  	if rMeta.ID != h.peerID {
   151  		h.logger.Debug().Str("received_peer_id", rMeta.ID.Pretty()).Str(p2putil.LogPeerID, p2putil.ShortForm(h.peerID)).Msg("Inconsistent peerID")
   152  		h.sendGoAway("Inconsistent peerID")
   153  		return fmt.Errorf("Inconsistent peerID")
   154  	}
   155  
   156  	return nil
   157  }
   158  
   159  // DoForInbound is handle handshake from inbound peer
   160  func (h *V030Handshaker) DoForInbound(ctx context.Context) (*types.Status, error) {
   161  	// TODO need to check auth at first...
   162  	h.logger.Debug().Str(p2putil.LogPeerID, p2putil.ShortForm(h.peerID)).Msg("Starting versioned handshake for inbound peer connection")
   163  
   164  	// inbound: receive, check and send
   165  	remotePeerStatus, err := h.receiveRemoteStatus(ctx)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  	if err = h.checkRemoteStatus(remotePeerStatus); err != nil {
   170  		return nil, err
   171  	}
   172  
   173  	// send my localStatus message as response
   174  	localStatus, err := createStatus(h.pm, h.actor, h.chainID, nil)
   175  	if err != nil {
   176  		h.logger.Warn().Err(err).Msg("Failed to create localStatus message.")
   177  		h.sendGoAway("internal error")
   178  		return nil, err
   179  	}
   180  	err = h.sendLocalStatus(ctx, localStatus)
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  	return remotePeerStatus, nil
   185  }
   186  
   187  func (h *V030Handshaker) handleGoAway(peerID types.PeerID, data p2pcommon.Message) (*types.Status, error) {
   188  	goAway := &types.GoAwayNotice{}
   189  	if err := p2putil.UnmarshalMessageBody(data.Payload(), goAway); err != nil {
   190  		h.logger.Warn().Str(p2putil.LogPeerID, p2putil.ShortForm(peerID)).Err(err).Msg("Remote peer sent goAway but failed to decode internal message")
   191  		return nil, err
   192  	}
   193  	return nil, fmt.Errorf("remote peer refuse handshake: %s", goAway.GetMessage())
   194  }
   195  
   196  func (h *V030Handshaker) sendGoAway(msg string) {
   197  	goMsg := createMessage(p2pcommon.GoAway, p2pcommon.NewMsgID(), &types.GoAwayNotice{Message: msg})
   198  	if goMsg != nil {
   199  		h.msgRW.WriteMsg(goMsg)
   200  	}
   201  }
   202  
   203  func createStatus(pm p2pcommon.PeerManager, actor p2pcommon.ActorService, chainID *types.ChainID, genesis []byte) (*types.Status, error) {
   204  	// find my best block
   205  	bestBlock, err := actor.GetChainAccessor().GetBestBlock()
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  	selfAddr := pm.SelfMeta().ToPeerAddress()
   210  	chainIDbytes, err := chainID.Bytes()
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  	// create message data
   215  	statusMsg := &types.Status{
   216  		Sender:        &selfAddr,
   217  		ChainID:       chainIDbytes,
   218  		BestBlockHash: bestBlock.BlockHash(),
   219  		BestHeight:    bestBlock.GetHeader().GetBlockNo(),
   220  		NoExpose:      pm.SelfMeta().Hidden,
   221  		Version:       p2pkey.NodeVersion(),
   222  		Genesis:       genesis,
   223  	}
   224  
   225  	return statusMsg, nil
   226  }
   227  
   228  func createMessage(protocolID p2pcommon.SubProtocol, msgID p2pcommon.MsgID, msgBody p2pcommon.MessageBody) p2pcommon.Message {
   229  	bytes, err := p2putil.MarshalMessageBody(msgBody)
   230  	if err != nil {
   231  		return nil
   232  	}
   233  
   234  	msg := p2pcommon.NewMessageValue(protocolID, msgID, p2pcommon.EmptyID, time.Now().UnixNano(), bytes)
   235  	return msg
   236  }