github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/network/p2p/peer.go (about)

     1  package p2p
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"net"
     7  	"sort"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/neatlab/neatio/chain/log"
    12  	"github.com/neatlab/neatio/network/p2p/discover"
    13  	"github.com/neatlab/neatio/utilities/common/mclock"
    14  	"github.com/neatlab/neatio/utilities/event"
    15  	"github.com/neatlab/neatio/utilities/rlp"
    16  )
    17  
    18  const (
    19  	baseProtocolVersion    = 5
    20  	baseProtocolLength     = uint64(16)
    21  	baseProtocolMaxMsgSize = 2 * 1024
    22  
    23  	snappyProtocolVersion = 5
    24  
    25  	pingInterval = 15 * time.Second
    26  )
    27  
    28  const (
    29  	handshakeMsg = 0x00
    30  	discMsg      = 0x01
    31  	pingMsg      = 0x02
    32  	pongMsg      = 0x03
    33  
    34  	BroadcastNewSideChainMsg = 0x04
    35  	ConfirmNewSideChainMsg   = 0x05
    36  
    37  	RefreshValidatorNodeInfoMsg = 0x06
    38  	RemoveValidatorNodeInfoMsg  = 0x07
    39  )
    40  
    41  type protoHandshake struct {
    42  	Version    uint64
    43  	Name       string
    44  	Caps       []Cap
    45  	ListenPort uint64
    46  	ID         discover.NodeID
    47  
    48  	Rest []rlp.RawValue `rlp:"tail"`
    49  }
    50  
    51  type PeerEventType string
    52  
    53  const (
    54  	PeerEventTypeAdd PeerEventType = "add"
    55  
    56  	PeerEventTypeDrop PeerEventType = "drop"
    57  
    58  	PeerEventTypeMsgSend PeerEventType = "msgsend"
    59  
    60  	PeerEventTypeMsgRecv PeerEventType = "msgrecv"
    61  
    62  	PeerEventTypeRefreshValidator PeerEventType = "refreshvalidator"
    63  
    64  	PeerEventTypeRemoveValidator PeerEventType = "removevalidator"
    65  )
    66  
    67  type PeerEvent struct {
    68  	Type     PeerEventType   `json:"type"`
    69  	Peer     discover.NodeID `json:"peer"`
    70  	Error    string          `json:"error,omitempty"`
    71  	Protocol string          `json:"protocol,omitempty"`
    72  	MsgCode  *uint64         `json:"msg_code,omitempty"`
    73  	MsgSize  *uint32         `json:"msg_size,omitempty"`
    74  }
    75  
    76  type Peer struct {
    77  	rw      *conn
    78  	running map[string]*protoRW
    79  	log     log.Logger
    80  	created mclock.AbsTime
    81  
    82  	wg       sync.WaitGroup
    83  	protoErr chan error
    84  	closed   chan struct{}
    85  	disc     chan DiscReason
    86  
    87  	events *event.Feed
    88  
    89  	srvProtocols *[]Protocol
    90  }
    91  
    92  func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
    93  	pipe, _ := net.Pipe()
    94  	conn := &conn{fd: pipe, transport: nil, id: id, caps: caps, name: name}
    95  	peer := newPeer(conn, nil)
    96  	close(peer.closed)
    97  	return peer
    98  }
    99  
   100  func (p *Peer) ID() discover.NodeID {
   101  	return p.rw.id
   102  }
   103  
   104  func (p *Peer) Name() string {
   105  	return p.rw.name
   106  }
   107  
   108  func (p *Peer) Caps() []Cap {
   109  
   110  	return p.rw.caps
   111  }
   112  
   113  func (p *Peer) RemoteAddr() net.Addr {
   114  	return p.rw.fd.RemoteAddr()
   115  }
   116  
   117  func (p *Peer) LocalAddr() net.Addr {
   118  	return p.rw.fd.LocalAddr()
   119  }
   120  
   121  func (p *Peer) Disconnect(reason DiscReason) {
   122  	select {
   123  	case p.disc <- reason:
   124  	case <-p.closed:
   125  	}
   126  }
   127  
   128  func (p *Peer) String() string {
   129  	return fmt.Sprintf("Peer %x %v", p.rw.id[:8], p.RemoteAddr())
   130  }
   131  
   132  func (p *Peer) Inbound() bool {
   133  	return p.rw.flags&inboundConn != 0
   134  }
   135  
   136  func newPeer(conn *conn, protocols []Protocol) *Peer {
   137  	protomap := matchProtocols(protocols, conn.caps, conn)
   138  	p := &Peer{
   139  		rw:       conn,
   140  		running:  protomap,
   141  		created:  mclock.Now(),
   142  		disc:     make(chan DiscReason),
   143  		protoErr: make(chan error, len(protomap)+1),
   144  		closed:   make(chan struct{}),
   145  		log:      log.New("id", conn.id, "conn", conn.flags),
   146  	}
   147  	return p
   148  }
   149  
   150  func (p *Peer) Log() log.Logger {
   151  	return p.log
   152  }
   153  
   154  func (p *Peer) run() (remoteRequested bool, err error) {
   155  	var (
   156  		writeStart = make(chan struct{}, 1)
   157  		writeErr   = make(chan error, 1)
   158  		readErr    = make(chan error, 1)
   159  		reason     DiscReason
   160  	)
   161  	p.wg.Add(2)
   162  	go p.readLoop(readErr)
   163  	go p.pingLoop()
   164  
   165  	writeStart <- struct{}{}
   166  	p.startProtocols(writeStart, writeErr)
   167  
   168  loop:
   169  	for {
   170  		select {
   171  		case err = <-writeErr:
   172  
   173  			if err != nil {
   174  				reason = DiscNetworkError
   175  				break loop
   176  			}
   177  			writeStart <- struct{}{}
   178  		case err = <-readErr:
   179  			if r, ok := err.(DiscReason); ok {
   180  				remoteRequested = true
   181  				reason = r
   182  			} else {
   183  				reason = DiscNetworkError
   184  			}
   185  			break loop
   186  		case err = <-p.protoErr:
   187  			reason = discReasonForError(err)
   188  			break loop
   189  		case err = <-p.disc:
   190  			break loop
   191  		}
   192  	}
   193  
   194  	close(p.closed)
   195  	p.rw.close(reason)
   196  	p.wg.Wait()
   197  	return remoteRequested, err
   198  }
   199  
   200  func (p *Peer) pingLoop() {
   201  	ping := time.NewTimer(pingInterval)
   202  	defer p.wg.Done()
   203  	defer ping.Stop()
   204  	for {
   205  		select {
   206  		case <-ping.C:
   207  			if err := SendItems(p.rw, pingMsg); err != nil {
   208  				p.protoErr <- err
   209  				return
   210  			}
   211  			ping.Reset(pingInterval)
   212  		case <-p.closed:
   213  			return
   214  		}
   215  	}
   216  }
   217  
   218  func (p *Peer) readLoop(errc chan<- error) {
   219  	defer p.wg.Done()
   220  	for {
   221  		msg, err := p.rw.ReadMsg()
   222  		if err != nil {
   223  			errc <- err
   224  			return
   225  		}
   226  		msg.ReceivedAt = time.Now()
   227  		if err = p.handle(msg); err != nil {
   228  			errc <- err
   229  			return
   230  		}
   231  	}
   232  }
   233  
   234  func (p *Peer) handle(msg Msg) error {
   235  	switch {
   236  	case msg.Code == pingMsg:
   237  		msg.Discard()
   238  		go SendItems(p.rw, pongMsg)
   239  	case msg.Code == discMsg:
   240  		var reason [1]DiscReason
   241  
   242  		rlp.Decode(msg.Payload, &reason)
   243  		return reason[0]
   244  	case msg.Code == BroadcastNewSideChainMsg:
   245  
   246  		var chainId string
   247  		if err := msg.Decode(&chainId); err != nil {
   248  			return err
   249  		}
   250  
   251  		p.log.Infof("Got new side chain msg from Peer %v, Before add protocol. Caps %v, Running Proto %+v", p.String(), p.Caps(), p.Info().Protocols)
   252  
   253  		newRunning := p.checkAndUpdateProtocol(chainId)
   254  		if newRunning {
   255  
   256  			go Send(p.rw, ConfirmNewSideChainMsg, chainId)
   257  		}
   258  
   259  		p.log.Infof("Got new side chain msg After add protocol. Caps %v, Running Proto %+v", p.Caps(), p.Info().Protocols)
   260  
   261  	case msg.Code == ConfirmNewSideChainMsg:
   262  
   263  		var chainId string
   264  		if err := msg.Decode(&chainId); err != nil {
   265  			return err
   266  		}
   267  		p.log.Infof("Got confirm msg from Peer %v, Before add protocol. Caps %v, Running Proto %+v", p.String(), p.Caps(), p.Info().Protocols)
   268  		p.checkAndUpdateProtocol(chainId)
   269  		p.log.Infof("Got confirm msg After add protocol. Caps %v, Running Proto %+v", p.Caps(), p.Info().Protocols)
   270  
   271  	case msg.Code == RefreshValidatorNodeInfoMsg:
   272  		p.log.Debug("Got refresh validation node infomation")
   273  		var valNodeInfo P2PValidatorNodeInfo
   274  		if err := msg.Decode(&valNodeInfo); err != nil {
   275  			p.log.Debugf("decode error: %v", err)
   276  			return err
   277  		}
   278  		p.log.Debugf("validation node address: %x", valNodeInfo.Validator.Address)
   279  
   280  		if valNodeInfo.Original && p.Info().ID == valNodeInfo.Node.ID.String() {
   281  			valNodeInfo.Node.IP = p.RemoteAddr().(*net.TCPAddr).IP
   282  		}
   283  		valNodeInfo.Original = false
   284  
   285  		p.log.Debugf("validator node info: %v", valNodeInfo)
   286  
   287  		data, err := rlp.EncodeToBytes(valNodeInfo)
   288  		if err != nil {
   289  			p.log.Debugf("encode error: %v", err)
   290  			return err
   291  		}
   292  		p.events.Send(&PeerEvent{
   293  			Type:     PeerEventTypeRefreshValidator,
   294  			Peer:     p.ID(),
   295  			Protocol: string(data),
   296  		})
   297  
   298  		p.log.Debugf("RefreshValidatorNodeInfoMsg handled")
   299  
   300  	case msg.Code == RemoveValidatorNodeInfoMsg:
   301  		p.log.Debug("Got remove validation node infomation")
   302  		var valNodeInfo P2PValidatorNodeInfo
   303  		if err := msg.Decode(&valNodeInfo); err != nil {
   304  			p.log.Debugf("decode error: %v", err)
   305  			return err
   306  		}
   307  		p.log.Debugf("validation node address: %x", valNodeInfo.Validator.Address)
   308  
   309  		if valNodeInfo.Original {
   310  			valNodeInfo.Node.IP = p.RemoteAddr().(*net.TCPAddr).IP
   311  			valNodeInfo.Original = false
   312  		}
   313  		p.log.Debugf("validator node info: %v", valNodeInfo)
   314  
   315  		data, err := rlp.EncodeToBytes(valNodeInfo)
   316  		if err != nil {
   317  			p.log.Debugf("encode error: %v", err)
   318  			return err
   319  		}
   320  		p.events.Send(&PeerEvent{
   321  			Type:     PeerEventTypeRemoveValidator,
   322  			Peer:     p.ID(),
   323  			Protocol: string(data),
   324  		})
   325  
   326  		p.log.Debug("RemoveValidatorNodeInfoMsg handled")
   327  
   328  	case msg.Code < baseProtocolLength:
   329  
   330  		return msg.Discard()
   331  	default:
   332  
   333  		proto, err := p.getProto(msg.Code)
   334  		if err != nil {
   335  			return fmt.Errorf("msg code out of range: %v", msg.Code)
   336  		}
   337  		select {
   338  		case proto.in <- msg:
   339  			return nil
   340  		case <-p.closed:
   341  			return io.EOF
   342  		}
   343  	}
   344  	return nil
   345  }
   346  
   347  func (p *Peer) checkAndUpdateProtocol(chainId string) bool {
   348  
   349  	sideProtocolName := "neatchain_" + chainId
   350  
   351  	if _, exist := p.running[sideProtocolName]; exist {
   352  		p.log.Infof("Side Chain %v is already running on peer", sideProtocolName)
   353  		return false
   354  	}
   355  
   356  	sideProtocolOffset := getLargestOffset(p.running)
   357  	if match, protoRW := matchServerProtocol(*p.srvProtocols, sideProtocolName, sideProtocolOffset, p.rw); match {
   358  
   359  		p.startSideChainProtocol(protoRW)
   360  
   361  		p.running[sideProtocolName] = protoRW
   362  
   363  		protoCap := protoRW.cap()
   364  		capExist := false
   365  		for _, cap := range p.rw.caps {
   366  			if cap.Name == protoCap.Name && cap.Version == protoCap.Version {
   367  				capExist = true
   368  			}
   369  		}
   370  		if !capExist {
   371  			p.rw.caps = append(p.rw.caps, protoCap)
   372  		}
   373  		return true
   374  	}
   375  
   376  	p.log.Infof("No Local Server Protocol matched, perhaps local server has not start the side chain %v yet.", sideProtocolName)
   377  	return false
   378  }
   379  
   380  func countMatchingProtocols(protocols []Protocol, caps []Cap) int {
   381  	n := 0
   382  	for _, cap := range caps {
   383  		for _, proto := range protocols {
   384  			if proto.Name == cap.Name && proto.Version == cap.Version {
   385  				n++
   386  			}
   387  		}
   388  	}
   389  	return n
   390  }
   391  
   392  func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
   393  	sort.Sort(capsByNameAndVersion(caps))
   394  	offset := baseProtocolLength
   395  	result := make(map[string]*protoRW)
   396  
   397  outer:
   398  	for _, cap := range caps {
   399  		for _, proto := range protocols {
   400  			if proto.Name == cap.Name && proto.Version == cap.Version {
   401  
   402  				if old := result[cap.Name]; old != nil {
   403  					offset -= old.Length
   404  				}
   405  
   406  				result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
   407  				offset += proto.Length
   408  
   409  				continue outer
   410  			}
   411  		}
   412  	}
   413  	return result
   414  }
   415  
   416  func matchServerProtocol(protocols []Protocol, name string, offset uint64, rw MsgReadWriter) (bool, *protoRW) {
   417  	for _, proto := range protocols {
   418  		if proto.Name == name {
   419  
   420  			return true, &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
   421  		}
   422  	}
   423  	return false, nil
   424  }
   425  
   426  func getLargestOffset(running map[string]*protoRW) uint64 {
   427  	var largestOffset uint64 = 0
   428  	for _, proto := range running {
   429  		offsetEnd := proto.offset + proto.Length
   430  		if offsetEnd > largestOffset {
   431  			largestOffset = offsetEnd
   432  		}
   433  	}
   434  	return largestOffset
   435  }
   436  
   437  func (p *Peer) startProtocols(writeStart <-chan struct{}, writeErr chan<- error) {
   438  	p.wg.Add(len(p.running))
   439  	for _, proto := range p.running {
   440  		proto := proto
   441  		proto.closed = p.closed
   442  		proto.wstart = writeStart
   443  		proto.werr = writeErr
   444  		var rw MsgReadWriter = proto
   445  		if p.events != nil {
   446  			rw = newMsgEventer(rw, p.events, p.ID(), proto.Name)
   447  		}
   448  		p.log.Trace(fmt.Sprintf("Starting protocol %s/%d", proto.Name, proto.Version))
   449  		go func() {
   450  			err := proto.Run(p, rw)
   451  			if err == nil {
   452  				p.log.Trace(fmt.Sprintf("Protocol %s/%d returned", proto.Name, proto.Version))
   453  				err = errProtocolReturned
   454  			} else if err != io.EOF {
   455  				p.log.Trace(fmt.Sprintf("Protocol %s/%d failed", proto.Name, proto.Version), "err", err)
   456  			}
   457  			p.protoErr <- err
   458  			p.wg.Done()
   459  		}()
   460  	}
   461  }
   462  
   463  func (p *Peer) startSideChainProtocol(proto *protoRW) {
   464  	p.wg.Add(1)
   465  
   466  	proto.closed = p.closed
   467  	proto.wstart = p.running["neatio"].wstart
   468  	proto.werr = p.running["neatio"].werr
   469  
   470  	var rw MsgReadWriter = proto
   471  	if p.events != nil {
   472  		rw = newMsgEventer(rw, p.events, p.ID(), proto.Name)
   473  	}
   474  	p.log.Trace(fmt.Sprintf("Starting protocol %s/%d", proto.Name, proto.Version))
   475  	go func() {
   476  		err := proto.Run(p, rw)
   477  		if err == nil {
   478  			p.log.Trace(fmt.Sprintf("Protocol %s/%d returned", proto.Name, proto.Version))
   479  			err = errProtocolReturned
   480  		} else if err != io.EOF {
   481  			p.log.Trace(fmt.Sprintf("Protocol %s/%d failed", proto.Name, proto.Version), "err", err)
   482  		}
   483  		p.protoErr <- err
   484  		p.wg.Done()
   485  	}()
   486  }
   487  
   488  func (p *Peer) getProto(code uint64) (*protoRW, error) {
   489  	for _, proto := range p.running {
   490  		if code >= proto.offset && code < proto.offset+proto.Length {
   491  			return proto, nil
   492  		}
   493  	}
   494  	return nil, newPeerError(errInvalidMsgCode, "%d", code)
   495  }
   496  
   497  type protoRW struct {
   498  	Protocol
   499  	in     chan Msg
   500  	closed <-chan struct{}
   501  	wstart <-chan struct{}
   502  	werr   chan<- error
   503  	offset uint64
   504  	w      MsgWriter
   505  }
   506  
   507  func (rw *protoRW) WriteMsg(msg Msg) (err error) {
   508  	if msg.Code >= rw.Length {
   509  		return newPeerError(errInvalidMsgCode, "not handled")
   510  	}
   511  	msg.Code += rw.offset
   512  	select {
   513  	case <-rw.wstart:
   514  		err = rw.w.WriteMsg(msg)
   515  
   516  		rw.werr <- err
   517  	case <-rw.closed:
   518  		err = fmt.Errorf("shutting down")
   519  	}
   520  	return err
   521  }
   522  
   523  func (rw *protoRW) ReadMsg() (Msg, error) {
   524  	select {
   525  	case msg := <-rw.in:
   526  		msg.Code -= rw.offset
   527  		return msg, nil
   528  	case <-rw.closed:
   529  		return Msg{}, io.EOF
   530  	}
   531  }
   532  
   533  type PeerInfo struct {
   534  	ID      string   `json:"id"`
   535  	Name    string   `json:"name"`
   536  	Caps    []string `json:"caps"`
   537  	Network struct {
   538  		LocalAddress  string `json:"localAddress"`
   539  		RemoteAddress string `json:"remoteAddress"`
   540  		Inbound       bool   `json:"inbound"`
   541  		Trusted       bool   `json:"trusted"`
   542  		Static        bool   `json:"static"`
   543  	} `json:"network"`
   544  	Protocols map[string]interface{} `json:"protocols"`
   545  }
   546  
   547  func (p *Peer) Info() *PeerInfo {
   548  
   549  	var caps []string
   550  	for _, cap := range p.Caps() {
   551  		caps = append(caps, cap.String())
   552  	}
   553  
   554  	info := &PeerInfo{
   555  		ID:        p.ID().String(),
   556  		Name:      p.Name(),
   557  		Caps:      caps,
   558  		Protocols: make(map[string]interface{}),
   559  	}
   560  	info.Network.LocalAddress = p.LocalAddr().String()
   561  	info.Network.RemoteAddress = p.RemoteAddr().String()
   562  	info.Network.Inbound = p.rw.is(inboundConn)
   563  	info.Network.Trusted = p.rw.is(trustedConn)
   564  	info.Network.Static = p.rw.is(staticDialedConn)
   565  
   566  	for _, proto := range p.running {
   567  		protoInfo := interface{}("unknown")
   568  		if query := proto.Protocol.PeerInfo; query != nil {
   569  			if metadata := query(p.ID()); metadata != nil {
   570  				protoInfo = metadata
   571  			} else {
   572  				protoInfo = "handshake"
   573  			}
   574  		}
   575  		info.Protocols[proto.Name] = protoInfo
   576  	}
   577  	return info
   578  }