package network

import (
	"context"
	"errors"
	"fmt"
	"net"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/nspcc-dev/neo-go/pkg/io"
	"github.com/nspcc-dev/neo-go/pkg/network/capability"
	"github.com/nspcc-dev/neo-go/pkg/network/payload"
)

// handShakeStage is a bit field tracking which handshake messages have
// been sent/received; the handshake is complete when all four bits are set.
type handShakeStage uint8

const (
	versionSent handShakeStage = 1 << iota
	versionReceived
	verAckSent
	verAckReceived

	requestQueueSize   = 32
	p2pMsgQueueSize    = 16
	hpRequestQueueSize = 4
	incomingQueueSize  = 1 // Each message can be up to 32MB in size.
)

var (
	errGone          = errors.New("the peer is gone already")
	errStateMismatch = errors.New("tried to send protocol message before handshake completed")
	errPingPong      = errors.New("ping/pong timeout")
	errUnexpectedPong = errors.New("pong message wasn't expected")
)

// TCPPeer represents a connected remote node in the
// network over TCP.
type TCPPeer struct {
	// underlying TCP connection.
	conn net.Conn
	// The server this peer belongs to.
	server *Server
	// The version of the peer.
	version *payload.Version
	// Index of the last block.
	lastBlockIndex uint32
	// pre-handshake non-canonical connection address.
	addr string

	lock      sync.RWMutex
	finale    sync.Once // guards the one-shot shutdown in Disconnect.
	handShake handShakeStage
	isFullNode bool

	done     chan struct{} // closed on Disconnect; signals all peer goroutines.
	sendQ    chan []byte   // regular broadcast queue.
	p2pSendQ chan []byte   // ordinary p2p reply queue.
	hpSendQ  chan []byte   // high-priority queue, drained first.
	incoming chan *Message // decoded messages from the read loop.

	// track outstanding getaddr requests.
	getAddrSent atomic.Int32

	// number of sent pings.
	pingSent  int
	pingTimer *time.Timer

}

// NewTCPPeer returns a TCPPeer structure based on the given connection.
73 func NewTCPPeer(conn net.Conn, addr string, s *Server) *TCPPeer { 74 return &TCPPeer{ 75 conn: conn, 76 server: s, 77 addr: addr, 78 done: make(chan struct{}), 79 sendQ: make(chan []byte, requestQueueSize), 80 p2pSendQ: make(chan []byte, p2pMsgQueueSize), 81 hpSendQ: make(chan []byte, hpRequestQueueSize), 82 incoming: make(chan *Message, incomingQueueSize), 83 } 84 } 85 86 // putPacketIntoQueue puts the given message into the given queue if 87 // the peer has done handshaking using the given context. 88 func (p *TCPPeer) putPacketIntoQueue(ctx context.Context, queue chan<- []byte, msg []byte) error { 89 if !p.Handshaked() { 90 return errStateMismatch 91 } 92 select { 93 case queue <- msg: 94 case <-p.done: 95 return errGone 96 case <-ctx.Done(): 97 return ctx.Err() 98 } 99 return nil 100 } 101 102 // BroadcastPacket implements the Peer interface. 103 func (p *TCPPeer) BroadcastPacket(ctx context.Context, msg []byte) error { 104 return p.putPacketIntoQueue(ctx, p.sendQ, msg) 105 } 106 107 // BroadcastHPPacket implements the Peer interface. It the peer is not yet 108 // handshaked it's a noop. 109 func (p *TCPPeer) BroadcastHPPacket(ctx context.Context, msg []byte) error { 110 return p.putPacketIntoQueue(ctx, p.hpSendQ, msg) 111 } 112 113 // putMessageIntoQueue serializes the given Message and puts it into given queue if 114 // the peer has done handshaking. 115 func (p *TCPPeer) putMsgIntoQueue(queue chan<- []byte, msg *Message) error { 116 b, err := msg.Bytes() 117 if err != nil { 118 return err 119 } 120 return p.putPacketIntoQueue(context.Background(), queue, b) 121 } 122 123 // EnqueueP2PMessage implements the Peer interface. 124 func (p *TCPPeer) EnqueueP2PMessage(msg *Message) error { 125 return p.putMsgIntoQueue(p.p2pSendQ, msg) 126 } 127 128 // EnqueueHPMessage implements the Peer interface. 
func (p *TCPPeer) EnqueueHPMessage(msg *Message) error {
	return p.putMsgIntoQueue(p.hpSendQ, msg)
}

// EnqueueP2PPacket implements the Peer interface.
func (p *TCPPeer) EnqueueP2PPacket(b []byte) error {
	return p.putPacketIntoQueue(context.Background(), p.p2pSendQ, b)
}

// EnqueueHPPacket implements the Peer interface.
func (p *TCPPeer) EnqueueHPPacket(b []byte) error {
	return p.putPacketIntoQueue(context.Background(), p.hpSendQ, b)
}

// writeMsg serializes msg and writes it directly to the underlying
// connection, bypassing the send queues.
func (p *TCPPeer) writeMsg(msg *Message) error {
	b, err := msg.Bytes()
	if err != nil {
		return err
	}

	_, err = p.conn.Write(b)

	return err
}

// handleConn handles the read side of the connection, it should be started as
// a goroutine right after a new peer setup. It spawns the write-side
// (handleQueues) and dispatch (handleIncoming) goroutines, sends our Version
// and then decodes messages from the wire until an error occurs or the peer
// is shut down.
func (p *TCPPeer) handleConn() {
	var err error

	p.server.register <- p

	go p.handleQueues()
	go p.handleIncoming()
	// When a new peer is connected, we send out our version immediately.
	err = p.SendVersion()
	if err == nil {
		r := io.NewBinReaderFromIO(p.conn)
	loop:
		for {
			msg := &Message{StateRootInHeader: p.server.config.StateRootInHeader}
			err = msg.Decode(r)

			// Too many headers is a recoverable condition: warn, reset the
			// reader error and keep the (partially processed) message.
			if errors.Is(err, payload.ErrTooManyHeaders) {
				p.server.log.Warn("not all headers were processed")
				r.Err = nil
			} else if err != nil {
				break
			}
			select {
			case p.incoming <- msg:
			case <-p.done:
				break loop
			}
		}
	}
	// Disconnect first (closes p.done), then close p.incoming so that
	// handleIncoming's range loop terminates.
	p.Disconnect(err)
	close(p.incoming)
}

// handleIncoming drains p.incoming and dispatches each message to the
// server, stopping at the first handling error or when handleConn closes
// the channel.
func (p *TCPPeer) handleIncoming() {
	var err error
	for msg := range p.incoming {
		err = p.server.handleMessage(p, msg)
		if err != nil {
			// Only annotate errors of handshaked peers; handshake-stage
			// errors are reported as is.
			if p.Handshaked() {
				err = fmt.Errorf("handling %s message: %w", msg.Command.String(), err)
			}
			break
		}
	}
	p.Disconnect(err)
}

// handleQueues is a goroutine that is started automatically to handle
// send queues.
// handleQueues is the write side of the connection: it multiplexes the three
// send queues (hp, p2p, regular) onto the connection with a strict priority
// for the hp queue, writes each packet under a deadline and drains all
// queues after disconnect so that blocked senders are released.
func (p *TCPPeer) handleQueues() {
	var err error
	// p2psend queue shares its time with send queue in around
	// ((p2pSkipDivisor - 1) * 2 + 1)/1 ratio, ratio because the third
	// select can still choose p2psend over send.
	var p2pSkipCounter uint32
	const p2pSkipDivisor = 4

	var writeTimeout = p.server.TimePerBlock
	for {
		var msg []byte

		// This one is to give priority to the hp queue
		select {
		case <-p.done:
			return
		case msg = <-p.hpSendQ:
		default:
		}

		// Skip this select every p2pSkipDivisor iteration.
		if msg == nil && p2pSkipCounter%p2pSkipDivisor != 0 {
			// Then look at the p2p queue.
			select {
			case <-p.done:
				return
			case msg = <-p.hpSendQ:
			case msg = <-p.p2pSendQ:
			default:
			}
		}
		// If there is no message in HP or P2P queues, block until one
		// appears in any of the queues.
		if msg == nil {
			select {
			case <-p.done:
				return
			case msg = <-p.hpSendQ:
			case msg = <-p.p2pSendQ:
			case msg = <-p.sendQ:
			}
		}
		// Bound each write by the configured per-block time so a stalled
		// peer can't block this goroutine forever.
		err = p.conn.SetWriteDeadline(time.Now().Add(writeTimeout))
		if err != nil {
			break
		}
		_, err = p.conn.Write(msg)
		if err != nil {
			break
		}
		p2pSkipCounter++
	}
	p.Disconnect(err)
	// Drain all queues non-blockingly: after Disconnect no one reads them,
	// and senders blocked on a full queue must be unblocked.
drainloop:
	for {
		select {
		case <-p.hpSendQ:
		case <-p.p2pSendQ:
		case <-p.sendQ:
		default:
			break drainloop
		}
	}
}

// StartProtocol starts a long running background loop that interacts
// every ProtoTickInterval with the peer. It's only good to run after the
// handshake.
273 func (p *TCPPeer) StartProtocol() { 274 var err error 275 276 p.server.handshake <- p 277 278 err = p.server.requestBlocksOrHeaders(p) 279 if err != nil { 280 p.Disconnect(err) 281 return 282 } 283 284 timer := time.NewTimer(p.server.ProtoTickInterval) 285 for { 286 select { 287 case <-p.done: 288 return 289 case <-timer.C: 290 // Try to sync in headers and block with the peer if his block height is higher than ours. 291 err = p.server.requestBlocksOrHeaders(p) 292 if err == nil { 293 timer.Reset(p.server.ProtoTickInterval) 294 } 295 } 296 if err != nil { 297 timer.Stop() 298 p.Disconnect(err) 299 return 300 } 301 } 302 } 303 304 // Handshaked returns status of the handshake, whether it's completed or not. 305 func (p *TCPPeer) Handshaked() bool { 306 p.lock.RLock() 307 defer p.lock.RUnlock() 308 return p.handshaked() 309 } 310 311 // handshaked is internal unlocked version of Handshaked(). 312 func (p *TCPPeer) handshaked() bool { 313 return p.handShake == (verAckReceived | verAckSent | versionReceived | versionSent) 314 } 315 316 // IsFullNode returns whether the node has full capability or TCP/WS only. 317 func (p *TCPPeer) IsFullNode() bool { 318 p.lock.RLock() 319 defer p.lock.RUnlock() 320 return p.handshaked() && p.isFullNode 321 } 322 323 // SendVersion checks for the handshake state and sends a message to the peer. 324 func (p *TCPPeer) SendVersion() error { 325 msg, err := p.server.getVersionMsg(p.conn.LocalAddr()) 326 if err != nil { 327 return err 328 } 329 p.lock.Lock() 330 defer p.lock.Unlock() 331 if p.handShake&versionSent != 0 { 332 return errors.New("invalid handshake: already sent Version") 333 } 334 err = p.writeMsg(msg) 335 if err == nil { 336 p.handShake |= versionSent 337 } 338 return err 339 } 340 341 // HandleVersion checks for the handshake state and version message contents. 
342 func (p *TCPPeer) HandleVersion(version *payload.Version) error { 343 p.lock.Lock() 344 defer p.lock.Unlock() 345 if p.handShake&versionReceived != 0 { 346 return errors.New("invalid handshake: already received Version") 347 } 348 p.version = version 349 for _, cap := range version.Capabilities { 350 if cap.Type == capability.FullNode { 351 p.isFullNode = true 352 p.lastBlockIndex = cap.Data.(*capability.Node).StartHeight 353 break 354 } 355 } 356 357 p.handShake |= versionReceived 358 return nil 359 } 360 361 // SendVersionAck checks for the handshake state and sends a message to the peer. 362 func (p *TCPPeer) SendVersionAck(msg *Message) error { 363 p.lock.Lock() 364 defer p.lock.Unlock() 365 if p.handShake&versionReceived == 0 { 366 return errors.New("invalid handshake: tried to send VersionAck, but no version received yet") 367 } 368 if p.handShake&versionSent == 0 { 369 return errors.New("invalid handshake: tried to send VersionAck, but didn't send Version yet") 370 } 371 if p.handShake&verAckSent != 0 { 372 return errors.New("invalid handshake: already sent VersionAck") 373 } 374 err := p.writeMsg(msg) 375 if err == nil { 376 p.handShake |= verAckSent 377 } 378 return err 379 } 380 381 // HandleVersionAck checks handshake sequence correctness when VerAck message 382 // is received. 383 func (p *TCPPeer) HandleVersionAck() error { 384 p.lock.Lock() 385 defer p.lock.Unlock() 386 if p.handShake&versionSent == 0 { 387 return errors.New("invalid handshake: received VersionAck, but no version sent yet") 388 } 389 if p.handShake&versionReceived == 0 { 390 return errors.New("invalid handshake: received VersionAck, but no version received yet") 391 } 392 if p.handShake&verAckReceived != 0 { 393 return errors.New("invalid handshake: already received VersionAck") 394 } 395 p.handShake |= verAckReceived 396 return nil 397 } 398 399 // ConnectionAddr implements the Peer interface. 
400 func (p *TCPPeer) ConnectionAddr() string { 401 if p.addr != "" { 402 return p.addr 403 } 404 return p.conn.RemoteAddr().String() 405 } 406 407 // RemoteAddr implements the Peer interface. 408 func (p *TCPPeer) RemoteAddr() net.Addr { 409 return p.conn.RemoteAddr() 410 } 411 412 // PeerAddr implements the Peer interface. 413 func (p *TCPPeer) PeerAddr() net.Addr { 414 remote := p.conn.RemoteAddr() 415 // The network can be non-tcp in unit tests. 416 if p.version == nil || remote.Network() != "tcp" { 417 return p.RemoteAddr() 418 } 419 host, _, err := net.SplitHostPort(remote.String()) 420 if err != nil { 421 return p.RemoteAddr() 422 } 423 var port uint16 424 for _, cap := range p.version.Capabilities { 425 if cap.Type == capability.TCPServer { 426 port = cap.Data.(*capability.Server).Port 427 } 428 } 429 if port == 0 { 430 return p.RemoteAddr() 431 } 432 addrString := net.JoinHostPort(host, strconv.Itoa(int(port))) 433 tcpAddr, err := net.ResolveTCPAddr("tcp", addrString) 434 if err != nil { 435 return p.RemoteAddr() 436 } 437 return tcpAddr 438 } 439 440 // Disconnect will fill the peer's done channel with the given error. 441 func (p *TCPPeer) Disconnect(err error) { 442 p.finale.Do(func() { 443 close(p.done) 444 p.conn.Close() 445 p.server.unregister <- peerDrop{p, err} 446 }) 447 } 448 449 // Version implements the Peer interface. 450 func (p *TCPPeer) Version() *payload.Version { 451 return p.version 452 } 453 454 // LastBlockIndex returns the last block index. 455 func (p *TCPPeer) LastBlockIndex() uint32 { 456 p.lock.RLock() 457 defer p.lock.RUnlock() 458 return p.lastBlockIndex 459 } 460 461 // SetPingTimer adds an outgoing ping to the counter and sets a PingTimeout timer 462 // that will shut the connection down in case of no response. 
463 func (p *TCPPeer) SetPingTimer() { 464 p.lock.Lock() 465 p.pingSent++ 466 if p.pingTimer == nil { 467 p.pingTimer = time.AfterFunc(p.server.PingTimeout, func() { 468 p.Disconnect(errPingPong) 469 }) 470 } 471 p.lock.Unlock() 472 } 473 474 // HandlePing handles a ping message received from the peer. 475 func (p *TCPPeer) HandlePing(ping *payload.Ping) error { 476 p.lock.Lock() 477 defer p.lock.Unlock() 478 p.lastBlockIndex = ping.LastBlockIndex 479 return nil 480 } 481 482 // HandlePong handles a pong message received from the peer and does an appropriate 483 // accounting of outstanding pings and timeouts. 484 func (p *TCPPeer) HandlePong(pong *payload.Ping) error { 485 p.lock.Lock() 486 defer p.lock.Unlock() 487 if p.pingTimer != nil && !p.pingTimer.Stop() { 488 return errPingPong 489 } 490 p.pingTimer = nil 491 p.pingSent-- 492 if p.pingSent < 0 { 493 return errUnexpectedPong 494 } 495 p.lastBlockIndex = pong.LastBlockIndex 496 return nil 497 } 498 499 // AddGetAddrSent increments internal outstanding getaddr requests counter. Then, 500 // the peer can only send one addr reply per getaddr request. 501 func (p *TCPPeer) AddGetAddrSent() { 502 p.getAddrSent.Add(1) 503 } 504 505 // CanProcessAddr decrements internal outstanding getaddr requests counter and 506 // answers whether the addr command from the peer can be safely processed. 507 func (p *TCPPeer) CanProcessAddr() bool { 508 v := p.getAddrSent.Add(-1) 509 return v >= 0 510 }