github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/network/tcp_peer.go (about)

     1  package network
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"strconv"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/nspcc-dev/neo-go/pkg/io"
    14  	"github.com/nspcc-dev/neo-go/pkg/network/capability"
    15  	"github.com/nspcc-dev/neo-go/pkg/network/payload"
    16  )
    17  
    18  type handShakeStage uint8
    19  
    20  const (
    21  	versionSent handShakeStage = 1 << iota
    22  	versionReceived
    23  	verAckSent
    24  	verAckReceived
    25  
    26  	requestQueueSize   = 32
    27  	p2pMsgQueueSize    = 16
    28  	hpRequestQueueSize = 4
    29  	incomingQueueSize  = 1 // Each message can be up to 32MB in size.
    30  )
    31  
    32  var (
    33  	errGone           = errors.New("the peer is gone already")
    34  	errStateMismatch  = errors.New("tried to send protocol message before handshake completed")
    35  	errPingPong       = errors.New("ping/pong timeout")
    36  	errUnexpectedPong = errors.New("pong message wasn't expected")
    37  )
    38  
// TCPPeer represents a connected remote node in the
// network over TCP.
type TCPPeer struct {
	// underlying TCP connection.
	conn net.Conn
	// The server this peer belongs to.
	server *Server
	// The version of the peer, set by HandleVersion (written under lock, but
	// read without it by Version and PeerAddr).
	version *payload.Version
	// Index of the last block, taken from the Version FullNode capability and
	// updated by ping/pong messages; guarded by lock.
	lastBlockIndex uint32
	// pre-handshake non-canonical connection address.
	addr string

	// lock guards handShake, isFullNode, lastBlockIndex and the ping
	// accounting below. finale makes Disconnect act only once. handShake is
	// the set of completed handshake steps; isFullNode is set when the peer
	// advertises the FullNode capability.
	lock       sync.RWMutex
	finale     sync.Once
	handShake  handShakeStage
	isFullNode bool

	// done is closed by Disconnect to stop all peer goroutines. sendQ,
	// p2pSendQ and hpSendQ hold serialized outgoing packets (regular, P2P and
	// high-priority) consumed by handleQueues. incoming passes decoded
	// messages from handleConn to handleIncoming.
	done     chan struct{}
	sendQ    chan []byte
	p2pSendQ chan []byte
	hpSendQ  chan []byte
	incoming chan *Message

	// track outstanding getaddr requests.
	getAddrSent atomic.Int32

	// number of sent pings not yet answered by a pong.
	pingSent int
	// pingTimer disconnects the peer if a pong doesn't arrive in time.
	pingTimer *time.Timer
}
    71  
    72  // NewTCPPeer returns a TCPPeer structure based on the given connection.
    73  func NewTCPPeer(conn net.Conn, addr string, s *Server) *TCPPeer {
    74  	return &TCPPeer{
    75  		conn:     conn,
    76  		server:   s,
    77  		addr:     addr,
    78  		done:     make(chan struct{}),
    79  		sendQ:    make(chan []byte, requestQueueSize),
    80  		p2pSendQ: make(chan []byte, p2pMsgQueueSize),
    81  		hpSendQ:  make(chan []byte, hpRequestQueueSize),
    82  		incoming: make(chan *Message, incomingQueueSize),
    83  	}
    84  }
    85  
    86  // putPacketIntoQueue puts the given message into the given queue if
    87  // the peer has done handshaking using the given context.
    88  func (p *TCPPeer) putPacketIntoQueue(ctx context.Context, queue chan<- []byte, msg []byte) error {
    89  	if !p.Handshaked() {
    90  		return errStateMismatch
    91  	}
    92  	select {
    93  	case queue <- msg:
    94  	case <-p.done:
    95  		return errGone
    96  	case <-ctx.Done():
    97  		return ctx.Err()
    98  	}
    99  	return nil
   100  }
   101  
   102  // BroadcastPacket implements the Peer interface.
   103  func (p *TCPPeer) BroadcastPacket(ctx context.Context, msg []byte) error {
   104  	return p.putPacketIntoQueue(ctx, p.sendQ, msg)
   105  }
   106  
   107  // BroadcastHPPacket implements the Peer interface. It the peer is not yet
   108  // handshaked it's a noop.
   109  func (p *TCPPeer) BroadcastHPPacket(ctx context.Context, msg []byte) error {
   110  	return p.putPacketIntoQueue(ctx, p.hpSendQ, msg)
   111  }
   112  
   113  // putMessageIntoQueue serializes the given Message and puts it into given queue if
   114  // the peer has done handshaking.
   115  func (p *TCPPeer) putMsgIntoQueue(queue chan<- []byte, msg *Message) error {
   116  	b, err := msg.Bytes()
   117  	if err != nil {
   118  		return err
   119  	}
   120  	return p.putPacketIntoQueue(context.Background(), queue, b)
   121  }
   122  
   123  // EnqueueP2PMessage implements the Peer interface.
   124  func (p *TCPPeer) EnqueueP2PMessage(msg *Message) error {
   125  	return p.putMsgIntoQueue(p.p2pSendQ, msg)
   126  }
   127  
   128  // EnqueueHPMessage implements the Peer interface.
   129  func (p *TCPPeer) EnqueueHPMessage(msg *Message) error {
   130  	return p.putMsgIntoQueue(p.hpSendQ, msg)
   131  }
   132  
   133  // EnqueueP2PPacket implements the Peer interface.
   134  func (p *TCPPeer) EnqueueP2PPacket(b []byte) error {
   135  	return p.putPacketIntoQueue(context.Background(), p.p2pSendQ, b)
   136  }
   137  
   138  // EnqueueHPPacket implements the Peer interface.
   139  func (p *TCPPeer) EnqueueHPPacket(b []byte) error {
   140  	return p.putPacketIntoQueue(context.Background(), p.hpSendQ, b)
   141  }
   142  
   143  func (p *TCPPeer) writeMsg(msg *Message) error {
   144  	b, err := msg.Bytes()
   145  	if err != nil {
   146  		return err
   147  	}
   148  
   149  	_, err = p.conn.Write(b)
   150  
   151  	return err
   152  }
   153  
   154  // handleConn handles the read side of the connection, it should be started as
   155  // a goroutine right after a new peer setup.
   156  func (p *TCPPeer) handleConn() {
   157  	var err error
   158  
   159  	p.server.register <- p
   160  
   161  	go p.handleQueues()
   162  	go p.handleIncoming()
   163  	// When a new peer is connected, we send out our version immediately.
   164  	err = p.SendVersion()
   165  	if err == nil {
   166  		r := io.NewBinReaderFromIO(p.conn)
   167  	loop:
   168  		for {
   169  			msg := &Message{StateRootInHeader: p.server.config.StateRootInHeader}
   170  			err = msg.Decode(r)
   171  
   172  			if errors.Is(err, payload.ErrTooManyHeaders) {
   173  				p.server.log.Warn("not all headers were processed")
   174  				r.Err = nil
   175  			} else if err != nil {
   176  				break
   177  			}
   178  			select {
   179  			case p.incoming <- msg:
   180  			case <-p.done:
   181  				break loop
   182  			}
   183  		}
   184  	}
   185  	p.Disconnect(err)
   186  	close(p.incoming)
   187  }
   188  
   189  func (p *TCPPeer) handleIncoming() {
   190  	var err error
   191  	for msg := range p.incoming {
   192  		err = p.server.handleMessage(p, msg)
   193  		if err != nil {
   194  			if p.Handshaked() {
   195  				err = fmt.Errorf("handling %s message: %w", msg.Command.String(), err)
   196  			}
   197  			break
   198  		}
   199  	}
   200  	p.Disconnect(err)
   201  }
   202  
// handleQueues is a goroutine that is started automatically to handle
// send queues. It writes queued packets to the connection, always checking
// the high-priority queue first, then (on most iterations) the P2P queue,
// and only blocking on all three queues when nothing was picked up. Every
// write gets a p.server.TimePerBlock deadline. If setting the deadline or
// writing fails, the peer is disconnected with that error and the queues are
// drained. If done is closed, it returns immediately without draining.
func (p *TCPPeer) handleQueues() {
	var err error
	// p2psend queue shares its time with send queue in around
	// ((p2pSkipDivisor - 1) * 2 + 1)/1 ratio, ratio because the third
	// select can still choose p2psend over send.
	var p2pSkipCounter uint32
	const p2pSkipDivisor = 4

	// Deadline applied to each individual Write below.
	var writeTimeout = p.server.TimePerBlock
	for {
		var msg []byte

		// This one is to give priority to the hp queue
		select {
		case <-p.done:
			return
		case msg = <-p.hpSendQ:
		default:
		}

		// Skip this select every p2pSkipDivisor iteration.
		if msg == nil && p2pSkipCounter%p2pSkipDivisor != 0 {
			// Then look at the p2p queue.
			select {
			case <-p.done:
				return
			case msg = <-p.hpSendQ:
			case msg = <-p.p2pSendQ:
			default:
			}
		}
		// If there is no message in HP or P2P queues, block until one
		// appears in any of the queues.
		if msg == nil {
			select {
			case <-p.done:
				return
			case msg = <-p.hpSendQ:
			case msg = <-p.p2pSendQ:
			case msg = <-p.sendQ:
			}
		}
		err = p.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
		if err != nil {
			break
		}
		_, err = p.conn.Write(msg)
		if err != nil {
			break
		}
		// Only successful writes advance the skip counter.
		p2pSkipCounter++
	}
	p.Disconnect(err)
	// Discard whatever is still buffered in the queues, without blocking.
drainloop:
	for {
		select {
		case <-p.hpSendQ:
		case <-p.p2pSendQ:
		case <-p.sendQ:
		default:
			break drainloop
		}
	}
}
   269  
   270  // StartProtocol starts a long running background loop that interacts
   271  // every ProtoTickInterval with the peer. It's only good to run after the
   272  // handshake.
   273  func (p *TCPPeer) StartProtocol() {
   274  	var err error
   275  
   276  	p.server.handshake <- p
   277  
   278  	err = p.server.requestBlocksOrHeaders(p)
   279  	if err != nil {
   280  		p.Disconnect(err)
   281  		return
   282  	}
   283  
   284  	timer := time.NewTimer(p.server.ProtoTickInterval)
   285  	for {
   286  		select {
   287  		case <-p.done:
   288  			return
   289  		case <-timer.C:
   290  			// Try to sync in headers and block with the peer if his block height is higher than ours.
   291  			err = p.server.requestBlocksOrHeaders(p)
   292  			if err == nil {
   293  				timer.Reset(p.server.ProtoTickInterval)
   294  			}
   295  		}
   296  		if err != nil {
   297  			timer.Stop()
   298  			p.Disconnect(err)
   299  			return
   300  		}
   301  	}
   302  }
   303  
   304  // Handshaked returns status of the handshake, whether it's completed or not.
   305  func (p *TCPPeer) Handshaked() bool {
   306  	p.lock.RLock()
   307  	defer p.lock.RUnlock()
   308  	return p.handshaked()
   309  }
   310  
   311  // handshaked is internal unlocked version of Handshaked().
   312  func (p *TCPPeer) handshaked() bool {
   313  	return p.handShake == (verAckReceived | verAckSent | versionReceived | versionSent)
   314  }
   315  
   316  // IsFullNode returns whether the node has full capability or TCP/WS only.
   317  func (p *TCPPeer) IsFullNode() bool {
   318  	p.lock.RLock()
   319  	defer p.lock.RUnlock()
   320  	return p.handshaked() && p.isFullNode
   321  }
   322  
   323  // SendVersion checks for the handshake state and sends a message to the peer.
   324  func (p *TCPPeer) SendVersion() error {
   325  	msg, err := p.server.getVersionMsg(p.conn.LocalAddr())
   326  	if err != nil {
   327  		return err
   328  	}
   329  	p.lock.Lock()
   330  	defer p.lock.Unlock()
   331  	if p.handShake&versionSent != 0 {
   332  		return errors.New("invalid handshake: already sent Version")
   333  	}
   334  	err = p.writeMsg(msg)
   335  	if err == nil {
   336  		p.handShake |= versionSent
   337  	}
   338  	return err
   339  }
   340  
   341  // HandleVersion checks for the handshake state and version message contents.
   342  func (p *TCPPeer) HandleVersion(version *payload.Version) error {
   343  	p.lock.Lock()
   344  	defer p.lock.Unlock()
   345  	if p.handShake&versionReceived != 0 {
   346  		return errors.New("invalid handshake: already received Version")
   347  	}
   348  	p.version = version
   349  	for _, cap := range version.Capabilities {
   350  		if cap.Type == capability.FullNode {
   351  			p.isFullNode = true
   352  			p.lastBlockIndex = cap.Data.(*capability.Node).StartHeight
   353  			break
   354  		}
   355  	}
   356  
   357  	p.handShake |= versionReceived
   358  	return nil
   359  }
   360  
   361  // SendVersionAck checks for the handshake state and sends a message to the peer.
   362  func (p *TCPPeer) SendVersionAck(msg *Message) error {
   363  	p.lock.Lock()
   364  	defer p.lock.Unlock()
   365  	if p.handShake&versionReceived == 0 {
   366  		return errors.New("invalid handshake: tried to send VersionAck, but no version received yet")
   367  	}
   368  	if p.handShake&versionSent == 0 {
   369  		return errors.New("invalid handshake: tried to send VersionAck, but didn't send Version yet")
   370  	}
   371  	if p.handShake&verAckSent != 0 {
   372  		return errors.New("invalid handshake: already sent VersionAck")
   373  	}
   374  	err := p.writeMsg(msg)
   375  	if err == nil {
   376  		p.handShake |= verAckSent
   377  	}
   378  	return err
   379  }
   380  
   381  // HandleVersionAck checks handshake sequence correctness when VerAck message
   382  // is received.
   383  func (p *TCPPeer) HandleVersionAck() error {
   384  	p.lock.Lock()
   385  	defer p.lock.Unlock()
   386  	if p.handShake&versionSent == 0 {
   387  		return errors.New("invalid handshake: received VersionAck, but no version sent yet")
   388  	}
   389  	if p.handShake&versionReceived == 0 {
   390  		return errors.New("invalid handshake: received VersionAck, but no version received yet")
   391  	}
   392  	if p.handShake&verAckReceived != 0 {
   393  		return errors.New("invalid handshake: already received VersionAck")
   394  	}
   395  	p.handShake |= verAckReceived
   396  	return nil
   397  }
   398  
   399  // ConnectionAddr implements the Peer interface.
   400  func (p *TCPPeer) ConnectionAddr() string {
   401  	if p.addr != "" {
   402  		return p.addr
   403  	}
   404  	return p.conn.RemoteAddr().String()
   405  }
   406  
   407  // RemoteAddr implements the Peer interface.
   408  func (p *TCPPeer) RemoteAddr() net.Addr {
   409  	return p.conn.RemoteAddr()
   410  }
   411  
   412  // PeerAddr implements the Peer interface.
   413  func (p *TCPPeer) PeerAddr() net.Addr {
   414  	remote := p.conn.RemoteAddr()
   415  	// The network can be non-tcp in unit tests.
   416  	if p.version == nil || remote.Network() != "tcp" {
   417  		return p.RemoteAddr()
   418  	}
   419  	host, _, err := net.SplitHostPort(remote.String())
   420  	if err != nil {
   421  		return p.RemoteAddr()
   422  	}
   423  	var port uint16
   424  	for _, cap := range p.version.Capabilities {
   425  		if cap.Type == capability.TCPServer {
   426  			port = cap.Data.(*capability.Server).Port
   427  		}
   428  	}
   429  	if port == 0 {
   430  		return p.RemoteAddr()
   431  	}
   432  	addrString := net.JoinHostPort(host, strconv.Itoa(int(port)))
   433  	tcpAddr, err := net.ResolveTCPAddr("tcp", addrString)
   434  	if err != nil {
   435  		return p.RemoteAddr()
   436  	}
   437  	return tcpAddr
   438  }
   439  
   440  // Disconnect will fill the peer's done channel with the given error.
   441  func (p *TCPPeer) Disconnect(err error) {
   442  	p.finale.Do(func() {
   443  		close(p.done)
   444  		p.conn.Close()
   445  		p.server.unregister <- peerDrop{p, err}
   446  	})
   447  }
   448  
   449  // Version implements the Peer interface.
   450  func (p *TCPPeer) Version() *payload.Version {
   451  	return p.version
   452  }
   453  
   454  // LastBlockIndex returns the last block index.
   455  func (p *TCPPeer) LastBlockIndex() uint32 {
   456  	p.lock.RLock()
   457  	defer p.lock.RUnlock()
   458  	return p.lastBlockIndex
   459  }
   460  
   461  // SetPingTimer adds an outgoing ping to the counter and sets a PingTimeout timer
   462  // that will shut the connection down in case of no response.
   463  func (p *TCPPeer) SetPingTimer() {
   464  	p.lock.Lock()
   465  	p.pingSent++
   466  	if p.pingTimer == nil {
   467  		p.pingTimer = time.AfterFunc(p.server.PingTimeout, func() {
   468  			p.Disconnect(errPingPong)
   469  		})
   470  	}
   471  	p.lock.Unlock()
   472  }
   473  
   474  // HandlePing handles a ping message received from the peer.
   475  func (p *TCPPeer) HandlePing(ping *payload.Ping) error {
   476  	p.lock.Lock()
   477  	defer p.lock.Unlock()
   478  	p.lastBlockIndex = ping.LastBlockIndex
   479  	return nil
   480  }
   481  
   482  // HandlePong handles a pong message received from the peer and does an appropriate
   483  // accounting of outstanding pings and timeouts.
   484  func (p *TCPPeer) HandlePong(pong *payload.Ping) error {
   485  	p.lock.Lock()
   486  	defer p.lock.Unlock()
   487  	if p.pingTimer != nil && !p.pingTimer.Stop() {
   488  		return errPingPong
   489  	}
   490  	p.pingTimer = nil
   491  	p.pingSent--
   492  	if p.pingSent < 0 {
   493  		return errUnexpectedPong
   494  	}
   495  	p.lastBlockIndex = pong.LastBlockIndex
   496  	return nil
   497  }
   498  
   499  // AddGetAddrSent increments internal outstanding getaddr requests counter. Then,
   500  // the peer can only send one addr reply per getaddr request.
   501  func (p *TCPPeer) AddGetAddrSent() {
   502  	p.getAddrSent.Add(1)
   503  }
   504  
   505  // CanProcessAddr decrements internal outstanding getaddr requests counter and
   506  // answers whether the addr command from the peer can be safely processed.
   507  func (p *TCPPeer) CanProcessAddr() bool {
   508  	v := p.getAddrSent.Add(-1)
   509  	return v >= 0
   510  }