github.com/amazechain/amc@v0.1.3/internal/network/node.go (about)

     1  // Copyright 2022 The AmazeChain Authors
     2  // This file is part of the AmazeChain library.
     3  //
     4  // The AmazeChain library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The AmazeChain library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the AmazeChain library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package network
    18  
    19  import (
    20  	"bufio"
    21  	"bytes"
    22  	"context"
    23  	"fmt"
    24  	"github.com/amazechain/amc/common/hexutil"
    25  	"github.com/amazechain/amc/log"
    26  	"github.com/amazechain/amc/utils"
    27  	"github.com/holiman/uint256"
    28  	"google.golang.org/protobuf/proto"
    29  	"io"
    30  	"sync"
    31  	"time"
    32  
    33  	"github.com/amazechain/amc/api/protocol/msg_proto"
    34  	"github.com/amazechain/amc/common"
    35  	"github.com/amazechain/amc/common/message"
    36  	"github.com/amazechain/amc/common/types"
    37  	"github.com/libp2p/go-libp2p/core/crypto"
    38  	"github.com/libp2p/go-libp2p/core/host"
    39  	"github.com/libp2p/go-libp2p/core/network"
    40  	libpeer "github.com/libp2p/go-libp2p/core/peer"
    41  	"github.com/libp2p/go-libp2p/core/protocol"
    42  )
    43  
    44  var (
    45  	pingPongTimeOut  = time.Duration(1) * time.Second
    46  	handshakeTimeOut = time.Duration(2) * time.Second
    47  )
    48  
// NodeConfig holds the optional settings NewNode applies through NodeOption values.
type NodeConfig struct {
	// Stream, when non-nil, is used as the node's stream instead of opening a new one.
	Stream     network.Stream
	// Protocol is the protocol ID used to open the stream and to key it in Node.streams;
	// NewNode defaults it to MSGProtocol.
	Protocol   protocol.ID
	// CacheCount is the capacity of the outgoing message channel; NewNode defaults it to 100.
	CacheCount int
}
    54  
// NodeOption mutates a NodeConfig; any number of them may be passed to NewNode.
type NodeOption func(*NodeConfig)
    56  
    57  func WithStream(stream network.Stream) NodeOption {
    58  	return func(config *NodeConfig) {
    59  		config.Stream = stream
    60  	}
    61  }
    62  
    63  func WithNodeProtocol(protocol protocol.ID) NodeOption {
    64  	return func(config *NodeConfig) {
    65  		config.Protocol = protocol
    66  	}
    67  }
    68  
    69  func WithNodeCacheCount(count int) NodeOption {
    70  	return func(config *NodeConfig) {
    71  		config.CacheCount = count
    72  	}
    73  }
    74  
// Node is the connection to a single remote peer, built on top of the local
// libp2p host. It keeps one stream per protocol ID, a buffered queue of
// outgoing messages drained by writeData, and the message-type -> handler
// table consulted by readData.
//
// NOTE: Node defines its own ID() method, which shadows host.Host.ID():
// n.ID() is the remote peer, n.Host.ID() is the local host.
type Node struct {
	// Embedded local host; provides NewStream, Peerstore and the local ID.
	host.Host

	// peer is the remote peer this node talks to.
	peer    *libpeer.AddrInfo
	streams map[string]network.Stream // protocol ID -> stream used to send and read messages

	// Embedded lock used by some methods (e.g. Close).
	// NOTE(review): embedding exports Lock/Unlock/RLock/RUnlock on Node, and
	// isOK is also read and written elsewhere without holding it.
	sync.RWMutex

	// ctx is derived from the context given to NewNode; cancel is called by
	// Close and when the read/write loops shut the node down.
	ctx    context.Context
	cancel context.CancelFunc

	// msgCh is the outgoing queue (capacity NodeConfig.CacheCount): Write
	// enqueues and writeData dequeues and sends.
	msgCh chan message.IMessage

	// service is notified on its removeCh when a read/write loop shuts the node down.
	service *Service

	config      *NodeConfig
	// msgCallback maps message types to handlers; it is the map passed to
	// NewNode, not a copy. msgLock guards it in readData and ClearHandler.
	msgCallback map[message.MessageType]common.ConnHandler
	msgLock     sync.RWMutex

	// isOK is set by Start and cleared on shutdown; Write/WriteMsg refuse to queue when false.
	isOK bool
}
    96  
    97  func NewNode(ctx context.Context, h host.Host, s *Service, peer libpeer.AddrInfo, callback map[message.MessageType]common.ConnHandler, opts ...NodeOption) (*Node, error) {
    98  	c, cancel := context.WithCancel(ctx)
    99  	node := &Node{
   100  		Host:        h,
   101  		peer:        &peer,
   102  		streams:     make(map[string]network.Stream),
   103  		ctx:         c,
   104  		cancel:      cancel,
   105  		isOK:        false,
   106  		service:     s,
   107  		msgCallback: callback,
   108  	}
   109  
   110  	config := NodeConfig{
   111  		Protocol:   MSGProtocol,
   112  		CacheCount: 100,
   113  	}
   114  	for _, opt := range opts {
   115  		opt(&config)
   116  	}
   117  
   118  	var stream network.Stream
   119  	var err error
   120  
   121  	node.msgCh = make(chan message.IMessage, config.CacheCount)
   122  
   123  	if config.Stream != nil {
   124  		stream = config.Stream
   125  	} else {
   126  		stream, err = node.NewStream(node.ctx, peer.ID, config.Protocol)
   127  		if err != nil {
   128  			log.Error("failed to created stream to peer", "PeerID", peer.ID, "PeerAddress", peer.Addrs, "err", err)
   129  			return nil, err
   130  		}
   131  	}
   132  
   133  	node.Lock()
   134  	defer node.Unlock()
   135  	node.streams[string(config.Protocol)] = stream
   136  	node.config = &config
   137  
   138  	return node, nil
   139  }
   140  
   141  func (n *Node) Start() {
   142  	n.isOK = true
   143  	for _, s := range n.streams {
   144  		go n.readData(s)
   145  		go n.writeData(s)
   146  	}
   147  }
   148  
   149  // ProcessHandshake read peer's genesisHash and currentHeight
   150  func (n *Node) ProcessHandshake(h *msg_proto.ProtocolHandshakeMessage) error {
   151  	stream, ok := n.streams[string(n.config.Protocol)]
   152  	if !ok {
   153  		return fmt.Errorf("invalid protocol %stream", n.config.Protocol)
   154  	}
   155  
   156  	errCh := make(chan error)
   157  	defer close(errCh)
   158  
   159  	go func() {
   160  		var header Header
   161  		reader := bufio.NewReader(stream)
   162  		msgType, payloadLen, err := header.Decode(reader)
   163  		if err != nil {
   164  			log.Error("failed to decode protocol handshake msg", "ID", n.peer.ID, "data", stream, "err", err)
   165  			errCh <- err
   166  			return
   167  		}
   168  
   169  		if msgType != message.MsgAppHandshake {
   170  			log.Errorf(badMsgTypeError.Error())
   171  			errCh <- badMsgTypeError
   172  			return
   173  		}
   174  
   175  		payload := make([]byte, payloadLen)
   176  		_, err = io.ReadFull(reader, payload)
   177  		if err != nil {
   178  			log.Error("failed read payload", fmt.Errorf("payload len %d", payloadLen), err)
   179  			errCh <- err
   180  			return
   181  		}
   182  
   183  		var m msg_proto.MessageData
   184  		if err = proto.Unmarshal(payload, &m); err != nil {
   185  			log.Errorf("failed unmarshal to msg, from peer:%s", stream.Conn().ID())
   186  			errCh <- err
   187  			return
   188  		}
   189  		if err = proto.Unmarshal(m.Payload, h); err != nil {
   190  			log.Errorf("failed unmarshal to protocol handshake msessage err: %v", err)
   191  			errCh <- err
   192  			return
   193  		}
   194  		errCh <- nil
   195  	}()
   196  
   197  	select {
   198  	case <-n.ctx.Done():
   199  		return n.ctx.Err()
   200  	case <-time.After(handshakeTimeOut):
   201  		log.Warn("Peer handshake timeout", "PeerID", n.peer.ID, "StreamId", stream.ID(), "ProtocolId", n.config.Protocol)
   202  		_ = stream.Close()
   203  	case err, ok := <-errCh:
   204  		if ok {
   205  			if err != nil {
   206  				return err
   207  			}
   208  			return err
   209  		}
   210  	}
   211  
   212  	return fmt.Errorf("unknown cause failure")
   213  }
   214  
   215  func (n *Node) AcceptHandshake(h *msg_proto.ProtocolHandshakeMessage, version string, genesisHash types.Hash, currentHeight *uint256.Int) error {
   216  	if err := n.ProcessHandshake(h); err != nil {
   217  		return err
   218  	}
   219  
   220  	if err := n.ProtocolHandshake(h, version, genesisHash, currentHeight, false); err != nil {
   221  		return err
   222  	}
   223  
   224  	return nil
   225  }
   226  
   227  // ProtocolHandshake send current peer's genesisHash and height
   228  func (n *Node) ProtocolHandshake(h *msg_proto.ProtocolHandshakeMessage, version string, genesisHash types.Hash, currentHeight *uint256.Int, process bool) error {
   229  	phm := msg_proto.ProtocolHandshakeMessage{
   230  		Version:       version,
   231  		GenesisHash:   utils.ConvertHashToH256(genesisHash),
   232  		CurrentHeight: utils.ConvertUint256IntToH256(currentHeight),
   233  	}
   234  
   235  	b, err := proto.Marshal(&phm)
   236  	if err != nil {
   237  		return err
   238  	}
   239  
   240  	msg := P2PMessage{
   241  		MsgType: message.MsgAppHandshake,
   242  		Payload: b,
   243  		id:      "",
   244  	}
   245  
   246  	s, ok := n.streams[string(n.config.Protocol)]
   247  	if !ok {
   248  		return fmt.Errorf("invalid protocol %s", n.config.Protocol)
   249  	}
   250  
   251  	if err := n.writeMsg(s, &msg); err != nil {
   252  		return err
   253  	}
   254  
   255  	if process {
   256  		return n.ProcessHandshake(h)
   257  	}
   258  
   259  	return nil
   260  }
   261  
   262  func (n *Node) SetHandler(msgType message.MessageType, handler common.ConnHandler) error {
   263  	n.Lock()
   264  	defer n.Unlock()
   265  
   266  	if _, ok := n.msgCallback[msgType]; ok {
   267  		return nil
   268  	}
   269  
   270  	n.msgCallback[msgType] = handler
   271  
   272  	return nil
   273  }
   274  
   275  func (n *Node) readData(stream network.Stream) error {
   276  	defer func() {
   277  		log.Debugf("quit")
   278  		if n.isOK != false {
   279  			n.isOK = false
   280  			stream.Close()
   281  			n.cancel()
   282  			n.service.removeCh <- stream.Conn().RemotePeer()
   283  		}
   284  	}()
   285  	reader := bufio.NewReader(stream)
   286  	for {
   287  		select {
   288  		case <-n.ctx.Done():
   289  			return stream.Close()
   290  		default:
   291  			//log.Debug("start receive peer msg")
   292  			var header Header
   293  			msgType, payloadLen, err := header.Decode(reader)
   294  			if err != nil {
   295  				log.Error("failed read msg", "StreamID", stream.ID(), "PeerID", stream.Conn().RemotePeer().String(), "err", err)
   296  				return err
   297  			}
   298  			//
   299  			ingressTrafficMeter.Mark(int64(payloadLen))
   300  
   301  			payload := make([]byte, payloadLen)
   302  			_, err = io.ReadFull(reader, payload)
   303  			if err != nil {
   304  				log.Error("failed read payload", fmt.Errorf("payload len %d", payloadLen), err)
   305  				return err
   306  			}
   307  
   308  			//log.Debugf("read %d data, payload len %d", c, payloadLen)
   309  			var msg msg_proto.MessageData
   310  			if err = proto.Unmarshal(payload, &msg); err != nil {
   311  				log.Errorf("failed unmarshal to msg, from peer:%s", stream.Conn().ID())
   312  				return err
   313  			}
   314  
   315  			switch msgType {
   316  			case message.MsgPingReq:
   317  				log.Tracef("receive ping msg %s", string(msg.Payload))
   318  				msg := P2PMessage{
   319  					MsgType: message.MsgPingResp,
   320  					Payload: []byte("Hi boy!"),
   321  				}
   322  				n.Write(&msg)
   323  			case message.MsgPingResp:
   324  				log.Tracef("receive pong msg %s", string(msg.Payload))
   325  			default:
   326  				n.msgLock.RLock()
   327  
   328  				if h, ok := n.msgCallback[msgType]; ok {
   329  					log.Debug("receive a p2p msg ", "msgType", msgType, "PeerID", n.ID(), "Content", hexutil.Encode(msg.Payload))
   330  					if err := h(msg.Payload, n.ID()); err != nil {
   331  						n.msgLock.RUnlock()
   332  						log.Errorf("failed dispense data err: %v", err)
   333  						return err
   334  					}
   335  				} else {
   336  					log.Warnf("receive invalid msg, err: %v", badMsgTypeError)
   337  				}
   338  				n.msgLock.RUnlock()
   339  			}
   340  		}
   341  	}
   342  }
   343  
   344  func (n *Node) makeMsg(payload []byte) proto.Message {
   345  	key := n.Peerstore().PrivKey(n.Host.ID())
   346  	//log.Errorf("id: %v", n.Host.ID())
   347  	sign, err := key.Sign(payload)
   348  	if err != nil {
   349  		log.Error("failed to sign ping msg", "err", err)
   350  		return nil
   351  	}
   352  
   353  	nodePubKey, err := crypto.MarshalPublicKey(n.Peerstore().PubKey(n.Host.ID()))
   354  	if err != nil {
   355  		log.Errorf("failed to get public key for sender from local peer store id=%s", n.ID())
   356  		return nil
   357  	}
   358  
   359  	msg := msg_proto.MessageData{
   360  		ClientVersion: "",
   361  		Timestamp:     time.Now().Unix(),
   362  		Id:            "",
   363  		NodeID:        n.ID().String(),
   364  		NodePubKey:    nodePubKey,
   365  		Sign:          sign,
   366  		Payload:       payload,
   367  		Gossip:        false,
   368  	}
   369  
   370  	return &msg
   371  }
   372  
   373  func (n *Node) writeData(stream network.Stream) error {
   374  	defer func() {
   375  		if n.isOK != false {
   376  			n.isOK = false
   377  			n.cancel()
   378  			n.service.removeCh <- stream.Conn().RemotePeer()
   379  			stream.Close()
   380  		}
   381  		close(n.msgCh)
   382  	}()
   383  
   384  	ticker := time.NewTicker(pingPongTimeOut)
   385  	defer ticker.Stop()
   386  
   387  	writeF := func(msg message.IMessage) error {
   388  		var header Header
   389  		//log.Debugf("send msg to peer[%s] type %d", stream.Conn().ID(), msg.Type())
   390  		payload, err := msg.Encode()
   391  		if err != nil {
   392  			log.Errorf("failed to encode msg")
   393  			return err
   394  		}
   395  		if msgData := n.makeMsg(payload); msgData != nil {
   396  			data, err := proto.Marshal(msgData)
   397  			if err != nil {
   398  				log.Errorf("failed marshal message data to byts")
   399  			} else {
   400  				if err := header.Encode(stream, msg.Type(), int32(len(data))); err != nil {
   401  					log.Error("failed to send header", msg.Type(), len(data), err)
   402  					return err
   403  				} else {
   404  					if _, err := stream.Write(data); err != nil {
   405  						log.Errorf("failed to send payload node:%s", stream.Conn().ID())
   406  					} else {
   407  						egressTrafficMeter.Mark(int64(len(data)))
   408  						//log.Debugf("send %d size to node:%s", c, stream.Conn().ID())
   409  					}
   410  				}
   411  			}
   412  		}
   413  		return nil
   414  	}
   415  
   416  	for {
   417  		select {
   418  		case <-n.ctx.Done():
   419  			return stream.Close()
   420  		case m, ok := <-n.msgCh:
   421  			if ok {
   422  				if err := writeF(m); err != nil {
   423  					return err
   424  				}
   425  			} else {
   426  				log.Error("chan was closed")
   427  				return fmt.Errorf("chan ws closed")
   428  			}
   429  		case <-ticker.C:
   430  			msg := P2PMessage{
   431  				MsgType: message.MsgPingReq,
   432  				Payload: []byte("Hello girl"),
   433  			}
   434  			if err := writeF(&msg); err != nil {
   435  				return err
   436  			}
   437  			ticker.Reset(pingPongTimeOut)
   438  		}
   439  	}
   440  }
   441  
   442  func (n *Node) writeMsg(stream network.Stream, msg message.IMessage) error {
   443  	var header Header
   444  	payload, err := msg.Encode()
   445  	if err != nil {
   446  		log.Errorf("failed to encode msg")
   447  		return err
   448  	}
   449  	if msgData := n.makeMsg(payload); msgData != nil {
   450  		data, err := proto.Marshal(msgData)
   451  		if err != nil {
   452  			log.Errorf("failed marshal message data to byts")
   453  		} else {
   454  			var buf = new(bytes.Buffer)
   455  			if err := header.Encode(buf, msg.Type(), int32(len(data))); err != nil {
   456  				log.Error("failed to send header", msg.Type(), len(data))
   457  				return err
   458  			} else {
   459  				buf.Write(data)
   460  				if _, err := stream.Write(buf.Bytes()); err != nil {
   461  					log.Errorf("failed to send payload node:%s", stream.Conn().ID())
   462  				} else {
   463  					//Trace
   464  					log.Debug("send data to peer", "data", hexutil.Encode(buf.Bytes()), "PeerID", stream.Conn().RemotePeer(), "ProtocolID", stream.Protocol(), "StreamID", stream.Conn().ID(), "StreamDirection", stream.Stat().Direction)
   465  				}
   466  			}
   467  		}
   468  	}
   469  	return nil
   470  }
   471  
   472  func (n *Node) Write(msg message.IMessage) error {
   473  	if !n.isOK {
   474  		return fmt.Errorf("node already closed")
   475  	}
   476  
   477  	n.msgCh <- msg
   478  	return nil
   479  }
   480  
   481  func (n *Node) WriteMsg(messageType message.MessageType, payload []byte) error {
   482  	if !n.isOK {
   483  		return fmt.Errorf("node already closed")
   484  	}
   485  	//n.msgLock.Lock()
   486  	//defer n.msgLock.Unlock()
   487  	msg := P2PMessage{
   488  		MsgType: messageType,
   489  		Payload: payload,
   490  	}
   491  
   492  	//n.msgCh <- &msg
   493  
   494  	return n.Write(&msg)
   495  }
   496  
   497  func (n *Node) Close() error {
   498  	n.Lock()
   499  	defer n.Unlock()
   500  	if n.isOK {
   501  		n.isOK = false
   502  		n.cancel()
   503  	}
   504  	return nil
   505  }
   506  
   507  func (n *Node) ID() libpeer.ID {
   508  	return n.peer.ID
   509  }
   510  
   511  func (n *Node) ClearHandler(msgType message.MessageType) error {
   512  	n.msgLock.Lock()
   513  	defer n.msgLock.Unlock()
   514  	if _, ok := n.msgCallback[msgType]; ok {
   515  		delete(n.msgCallback, msgType)
   516  	}
   517  	return fmt.Errorf("failed to remove connhanlder, msg type %v", msgType)
   518  }