github.com/igggame/nebulas-go@v2.1.0+incompatible/net/stream.go (about) 1 // Copyright (C) 2018 go-nebulas authors 2 // 3 // This file is part of the go-nebulas library. 4 // 5 // the go-nebulas library is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // the go-nebulas library is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU General Public License for more details. 14 // 15 // You should have received a copy of the GNU General Public License 16 // along with the go-nebulas library. If not, see <http://www.gnu.org/licenses/>. 17 // 18 19 package net 20 21 import ( 22 "errors" 23 "fmt" 24 "strings" 25 "sync" 26 "time" 27 28 "github.com/gogo/protobuf/proto" 29 libnet "github.com/libp2p/go-libp2p-net" 30 peer "github.com/libp2p/go-libp2p-peer" 31 ma "github.com/multiformats/go-multiaddr" 32 netpb "github.com/nebulasio/go-nebulas/net/pb" 33 "github.com/nebulasio/go-nebulas/util/logging" 34 "github.com/sirupsen/logrus" 35 ) 36 37 // Stream Message Type 38 const ( 39 ClientVersion = "0.3.0" 40 NebProtocolID = "/neb/1.0.0" 41 HELLO = "hello" 42 OK = "ok" 43 BYE = "bye" 44 SYNCROUTE = "syncroute" 45 ROUTETABLE = "routetable" 46 RECVEDMSG = "recvedmsg" 47 CurrentVersion = 0x0 48 ) 49 50 // Stream Status 51 const ( 52 streamStatusInit = iota 53 streamStatusHandshakeSucceed 54 streamStatusClosed 55 ) 56 57 // Stream Errors 58 var ( 59 ErrShouldCloseConnectionAndExitLoop = errors.New("should close connection and exit loop") 60 ErrStreamIsNotConnected = errors.New("stream is not connected") 61 ) 62 63 // Stream define the structure of a stream in p2p network 64 type Stream struct { 65 syncMutex sync.Mutex 66 pid peer.ID 67 addr 
ma.Multiaddr 68 stream libnet.Stream 69 node *Node 70 handshakeSucceedCh chan bool 71 messageNotifChan chan int 72 highPriorityMessageChan chan *NebMessage 73 normalPriorityMessageChan chan *NebMessage 74 lowPriorityMessageChan chan *NebMessage 75 quitWriteCh chan bool 76 status int 77 connectedAt int64 78 latestReadAt int64 79 latestWriteAt int64 80 msgCount map[string]int 81 reservedFlag []byte 82 } 83 84 // NewStream return a new Stream 85 func NewStream(stream libnet.Stream, node *Node) *Stream { 86 return newStreamInstance(stream.Conn().RemotePeer(), stream.Conn().RemoteMultiaddr(), stream, node) 87 } 88 89 // NewStreamFromPID return a new Stream based on the pid 90 func NewStreamFromPID(pid peer.ID, node *Node) *Stream { 91 return newStreamInstance(pid, nil, nil, node) 92 } 93 94 func newStreamInstance(pid peer.ID, addr ma.Multiaddr, stream libnet.Stream, node *Node) *Stream { 95 return &Stream{ 96 pid: pid, 97 addr: addr, 98 stream: stream, 99 node: node, 100 handshakeSucceedCh: make(chan bool, 1), 101 messageNotifChan: make(chan int, 6*1024), 102 highPriorityMessageChan: make(chan *NebMessage, 2*1024), 103 normalPriorityMessageChan: make(chan *NebMessage, 2*1024), 104 lowPriorityMessageChan: make(chan *NebMessage, 2*1024), 105 quitWriteCh: make(chan bool, 1), 106 status: streamStatusInit, 107 connectedAt: time.Now().Unix(), 108 latestReadAt: 0, 109 latestWriteAt: 0, 110 msgCount: make(map[string]int), 111 reservedFlag: DefaultReserved, 112 } 113 } 114 115 // Connect to the stream 116 func (s *Stream) Connect() error { 117 logging.VLog().WithFields(logrus.Fields{ 118 "stream": s.String(), 119 }).Debug("Connecting to peer.") 120 121 // connect to host. 
122 stream, err := s.node.host.NewStream( 123 s.node.context, 124 s.pid, 125 NebProtocolID, 126 ) 127 if err != nil { 128 logging.VLog().WithFields(logrus.Fields{ 129 "stream": s.String(), 130 "err": err, 131 }).Debug("Failed to connect to host.") 132 return err 133 } 134 s.stream = stream 135 s.addr = stream.Conn().RemoteMultiaddr() 136 137 return nil 138 } 139 140 // IsConnected return if the stream is connected 141 func (s *Stream) IsConnected() bool { 142 return s.stream != nil 143 } 144 145 // IsHandshakeSucceed return if the handshake in the stream succeed 146 func (s *Stream) IsHandshakeSucceed() bool { 147 return s.status == streamStatusHandshakeSucceed 148 } 149 150 func (s *Stream) String() string { 151 addrStr := "" 152 if s.addr != nil { 153 addrStr = s.addr.String() 154 } 155 156 return fmt.Sprintf("Peer Stream: %s,%s", s.pid.Pretty(), addrStr) 157 } 158 159 // SendProtoMessage send proto msg to buffer 160 func (s *Stream) SendProtoMessage(messageName string, pb proto.Message, priority int) error { 161 data, err := proto.Marshal(pb) 162 if err != nil { 163 logging.VLog().WithFields(logrus.Fields{ 164 "err": err, 165 "messageName": messageName, 166 "stream": s.String(), 167 }).Debug("Failed to marshal proto message.") 168 return err 169 } 170 171 return s.SendMessage(messageName, data, priority) 172 } 173 174 // SendMessage send msg to buffer 175 func (s *Stream) SendMessage(messageName string, data []byte, priority int) error { 176 message, err := NewNebMessage(s.node.config.ChainID, s.reservedFlag, CurrentVersion, messageName, data) 177 if err != nil { 178 return err 179 } 180 181 // metrics. 182 metricsPacketsOutByMessageName(messageName, message.Length()) 183 184 // send to pool. 185 message.FlagSendMessageAt() 186 187 // use a non-blocking channel to avoid blocking when the channel is full. 
188 switch priority { 189 case MessagePriorityHigh: 190 s.highPriorityMessageChan <- message 191 case MessagePriorityNormal: 192 select { 193 case s.normalPriorityMessageChan <- message: 194 default: 195 logging.VLog().WithFields(logrus.Fields{ 196 "normalPriorityMessageChan.len": len(s.normalPriorityMessageChan), 197 "stream": s.String(), 198 }).Debug("Received too many normal priority message.") 199 return nil 200 } 201 default: 202 select { 203 case s.lowPriorityMessageChan <- message: 204 default: 205 logging.VLog().WithFields(logrus.Fields{ 206 "lowPriorityMessageChan.len": len(s.lowPriorityMessageChan), 207 "stream": s.String(), 208 }).Debug("Received too many low priority message.") 209 return nil 210 } 211 } 212 select { 213 case s.messageNotifChan <- 1: 214 default: 215 logging.VLog().WithFields(logrus.Fields{ 216 "messageNotifChan.len": len(s.messageNotifChan), 217 "stream": s.String(), 218 }).Debug("Received too many message notifChan.") 219 return nil 220 } 221 return nil 222 } 223 224 func (s *Stream) Write(data []byte) error { 225 if s.stream == nil { 226 s.close(ErrStreamIsNotConnected) 227 return ErrStreamIsNotConnected 228 } 229 230 // at least 5kb/s to write message 231 deadline := time.Now().Add(time.Duration(len(data)/1024/5+1) * time.Second) 232 if err := s.stream.SetWriteDeadline(deadline); err != nil { 233 return err 234 } 235 n, err := s.stream.Write(data) 236 if err != nil { 237 logging.VLog().WithFields(logrus.Fields{ 238 "err": err, 239 "stream": s.String(), 240 }).Warn("Failed to send message to peer.") 241 s.close(err) 242 return err 243 } 244 s.latestWriteAt = time.Now().Unix() 245 246 // metrics. 247 metricsPacketsOut.Mark(1) 248 metricsBytesOut.Mark(int64(n)) 249 250 return nil 251 } 252 253 // WriteNebMessage write neb msg in the stream 254 func (s *Stream) WriteNebMessage(message *NebMessage) error { 255 // metrics. 
256 metricsPacketsOutByMessageName(message.MessageName(), message.Length()) 257 258 err := s.Write(message.Content()) 259 message.FlagWriteMessageAt() 260 261 return err 262 } 263 264 // WriteProtoMessage write proto msg in the stream 265 func (s *Stream) WriteProtoMessage(messageName string, pb proto.Message, reservedClientFlag byte) error { 266 data, err := proto.Marshal(pb) 267 if err != nil { 268 logging.VLog().WithFields(logrus.Fields{ 269 "err": err, 270 "messageName": messageName, 271 "stream": s.String(), 272 }).Debug("Failed to marshal proto message.") 273 return err 274 } 275 276 return s.WriteMessage(messageName, data, reservedClientFlag) 277 } 278 279 // WriteMessage write raw msg in the stream 280 func (s *Stream) WriteMessage(messageName string, data []byte, reservedClientFlag byte) error { 281 // hello and ok messages come with the client flag bit. 282 var reserved = make([]byte, len(s.reservedFlag)) 283 copy(reserved, s.reservedFlag) 284 285 if reservedClientFlag == ReservedCompressionClientFlag { 286 reserved[2] = s.reservedFlag[2] | reservedClientFlag 287 } 288 289 message, err := NewNebMessage(s.node.config.ChainID, reserved, CurrentVersion, messageName, data) 290 if err != nil { 291 return err 292 } 293 294 return s.WriteNebMessage(message) 295 } 296 297 // StartLoop start stream handling loop. 298 func (s *Stream) StartLoop() { 299 go s.writeLoop() 300 go s.readLoop() 301 } 302 303 func (s *Stream) readLoop() { 304 // send Hello to host if stream is not connected. 305 if !s.IsConnected() { 306 if err := s.Connect(); err != nil { 307 s.close(err) 308 return 309 } 310 if err := s.Hello(); err != nil { 311 s.close(err) 312 return 313 } 314 } 315 316 // loop. 
317 buf := make([]byte, 1024*4) 318 messageBuffer := make([]byte, 0) 319 320 var message *NebMessage 321 322 for { 323 n, err := s.stream.Read(buf) 324 if err != nil { 325 logging.VLog().WithFields(logrus.Fields{ 326 "err": err, 327 "stream": s.String(), 328 }).Debug("Error occurred when reading data from network connection.") 329 s.close(err) 330 return 331 } 332 333 messageBuffer = append(messageBuffer, buf[:n]...) 334 s.latestReadAt = time.Now().Unix() 335 336 for { 337 if message == nil { 338 var err error 339 340 // waiting for header data. 341 if len(messageBuffer) < NebMessageHeaderLength { 342 // continue reading. 343 break 344 } 345 346 message, err = ParseNebMessage(messageBuffer) 347 if err != nil { 348 s.Bye() 349 return 350 } 351 352 // check ChainID. 353 if s.node.config.ChainID != message.ChainID() { 354 logging.VLog().WithFields(logrus.Fields{ 355 "err": err, 356 "stream": s.String(), 357 "conf.chainID": s.node.config.ChainID, 358 "message.chainID": message.ChainID(), 359 }).Warn("Invalid chainID, disconnect the connection.") 360 s.Bye() 361 return 362 } 363 364 // remove header from buffer. 365 messageBuffer = messageBuffer[NebMessageHeaderLength:] 366 } 367 368 // waiting for data. 369 if len(messageBuffer) < int(message.DataLength()) { 370 // continue reading. 371 break 372 } 373 374 if err := message.ParseMessageData(messageBuffer); err != nil { 375 s.Bye() 376 return 377 } 378 379 // remove data from buffer. 380 messageBuffer = messageBuffer[message.DataLength():] 381 382 // metrics. 383 metricsPacketsIn.Mark(1) 384 metricsBytesIn.Mark(int64(message.Length())) 385 metricsPacketsInByMessageName(message.MessageName(), message.Length()) 386 387 // handle message. 388 if err := s.handleMessage(message); err == ErrShouldCloseConnectionAndExitLoop { 389 s.Bye() 390 return 391 } 392 393 // reset message. 394 message = nil 395 } 396 } 397 } 398 399 func (s *Stream) writeLoop() { 400 // waiting for handshake succeed. 
401 handshakeTimeoutTicker := time.NewTicker(30 * time.Second) 402 select { 403 case <-s.handshakeSucceedCh: 404 // handshake succeed. 405 case <-s.quitWriteCh: 406 logging.VLog().WithFields(logrus.Fields{ 407 "stream": s.String(), 408 }).Debug("Quiting Stream Write Loop.") 409 return 410 case <-handshakeTimeoutTicker.C: 411 logging.VLog().WithFields(logrus.Fields{ 412 "stream": s.String(), 413 }).Debug("Handshaking Stream timeout, quiting.") 414 s.close(errors.New("Handshake timeout")) 415 return 416 } 417 418 for { 419 select { 420 case <-s.quitWriteCh: 421 logging.VLog().WithFields(logrus.Fields{ 422 "stream": s.String(), 423 }).Debug("Quiting Stream Write Loop.") 424 return 425 case <-s.messageNotifChan: 426 select { 427 case message := <-s.highPriorityMessageChan: 428 s.WriteNebMessage(message) 429 continue 430 default: 431 } 432 433 select { 434 case message := <-s.normalPriorityMessageChan: 435 s.WriteNebMessage(message) 436 continue 437 default: 438 } 439 440 select { 441 case message := <-s.lowPriorityMessageChan: 442 s.WriteNebMessage(message) 443 continue 444 default: 445 } 446 } 447 } 448 } 449 450 func (s *Stream) handleMessage(message *NebMessage) error { 451 messageName := message.MessageName() 452 s.msgCount[messageName]++ 453 454 switch messageName { 455 case HELLO: 456 return s.onHello(message) 457 case OK: 458 return s.onOk(message) 459 case BYE: 460 return s.onBye(message) 461 } 462 463 // check handshake status. 
464 if s.status != streamStatusHandshakeSucceed { 465 return ErrShouldCloseConnectionAndExitLoop 466 } 467 468 switch messageName { 469 case SYNCROUTE: 470 return s.onSyncRoute(message) 471 case ROUTETABLE: 472 return s.onRouteTable(message) 473 default: 474 data, err := s.getData(message) 475 if err != nil { 476 logging.VLog().WithFields(logrus.Fields{ 477 "err": err, 478 "messageName": message.MessageName(), 479 }).Info("Handle message data occurs error.") 480 return err 481 } 482 s.node.netService.PutMessage(NewBaseMessage(message.MessageName(), s.pid.Pretty(), data)) 483 // record recv message. 484 RecordRecvMessage(s, message.DataCheckSum()) 485 } 486 487 return nil 488 } 489 490 // Close close the stream 491 func (s *Stream) close(reason error) { 492 // Add lock & close flag to prevent multi call. 493 s.syncMutex.Lock() 494 defer s.syncMutex.Unlock() 495 496 if s.status == streamStatusClosed { 497 return 498 } 499 s.status = streamStatusClosed 500 501 logging.VLog().WithFields(logrus.Fields{ 502 "stream": s.String(), 503 "reason": reason, 504 }).Debug("Closing stream.") 505 506 // cleanup. 507 s.node.streamManager.RemoveStream(s) 508 s.node.routeTable.RemovePeerStream(s) 509 510 // quit. 511 s.quitWriteCh <- true 512 513 // close stream. 
514 if s.stream != nil { 515 s.stream.Close() 516 } 517 } 518 519 // Bye say bye in the stream 520 func (s *Stream) Bye() { 521 s.WriteMessage(BYE, []byte{}, DefaultReservedFlag) 522 s.close(errors.New("bye: force close")) 523 } 524 525 func (s *Stream) onBye(message *NebMessage) error { 526 logging.VLog().WithFields(logrus.Fields{ 527 "stream": s.String(), 528 }).Debug("Received Bye message, close the connection.") 529 return ErrShouldCloseConnectionAndExitLoop 530 } 531 532 // Hello say hello in the stream 533 func (s *Stream) Hello() error { 534 msg := &netpb.Hello{ 535 NodeId: s.node.id.String(), 536 ClientVersion: ClientVersion, 537 } 538 return s.WriteProtoMessage(HELLO, msg, ReservedCompressionClientFlag) 539 } 540 541 func (s *Stream) onHello(message *NebMessage) error { 542 msg, err := netpb.HelloMessageFromProto(message.OriginalData()) 543 if err != nil { 544 return ErrShouldCloseConnectionAndExitLoop 545 } 546 547 if msg.NodeId != s.pid.String() || !CheckClientVersionCompatibility(ClientVersion, msg.ClientVersion) { 548 // invalid client, bye(). 549 logging.VLog().WithFields(logrus.Fields{ 550 "pid": s.pid.Pretty(), 551 "address": s.addr, 552 "ok.node_id": msg.NodeId, 553 "ok.client_version": msg.ClientVersion, 554 }).Warn("Invalid NodeId or incompatible client version.") 555 return ErrShouldCloseConnectionAndExitLoop 556 } 557 558 if (message.Reserved()[2] & ReservedCompressionClientFlag) > 0 { 559 s.reservedFlag = CurrentReserved 560 } 561 562 // add to route table. 563 s.node.routeTable.AddPeerStream(s) 564 565 // handshake finished. 566 s.finishHandshake() 567 568 return s.Ok() 569 } 570 571 // Ok say ok in the stream 572 func (s *Stream) Ok() error { 573 // send OK. 
574 resp := &netpb.OK{ 575 NodeId: s.node.id.String(), 576 ClientVersion: ClientVersion, 577 } 578 579 return s.WriteProtoMessage(OK, resp, ReservedCompressionClientFlag) 580 } 581 582 func (s *Stream) onOk(message *NebMessage) error { 583 msg, err := netpb.OKMessageFromProto(message.OriginalData()) 584 if err != nil { 585 return ErrShouldCloseConnectionAndExitLoop 586 } 587 588 if msg.NodeId != s.pid.String() || !CheckClientVersionCompatibility(ClientVersion, msg.ClientVersion) { 589 // invalid client, bye(). 590 logging.VLog().WithFields(logrus.Fields{ 591 "pid": s.pid.Pretty(), 592 "address": s.addr, 593 "ok.node_id": msg.NodeId, 594 "ok.client_version": msg.ClientVersion, 595 }).Warn("Invalid NodeId or incompatible client version.") 596 return ErrShouldCloseConnectionAndExitLoop 597 } 598 599 if (message.Reserved()[2] & ReservedCompressionClientFlag) > 0 { 600 s.reservedFlag = CurrentReserved 601 } 602 603 // add to route table. 604 s.node.routeTable.AddPeerStream(s) 605 606 // handshake finished. 607 s.finishHandshake() 608 609 return nil 610 } 611 612 // SyncRoute send sync route request 613 func (s *Stream) SyncRoute() error { 614 return s.SendMessage(SYNCROUTE, []byte{}, MessagePriorityHigh) 615 } 616 617 func (s *Stream) onSyncRoute(message *NebMessage) error { 618 return s.RouteTable() 619 } 620 621 // RouteTable send sync table request 622 func (s *Stream) RouteTable() error { 623 // get random peers from routeTable 624 peers := s.node.routeTable.GetRandomPeers(s.pid) 625 626 // prepare the protobuf message. 
627 msg := &netpb.Peers{ 628 Peers: make([]*netpb.PeerInfo, len(peers)), 629 } 630 631 for i, v := range peers { 632 pi := &netpb.PeerInfo{ 633 Id: v.ID.Pretty(), 634 Addrs: make([]string, len(v.Addrs)), 635 } 636 for j, addr := range v.Addrs { 637 pi.Addrs[j] = addr.String() 638 } 639 msg.Peers[i] = pi 640 } 641 642 logging.VLog().WithFields(logrus.Fields{ 643 "stream": s.String(), 644 "routetableCount": len(peers), 645 }).Debug("Replied sync route message.") 646 647 return s.SendProtoMessage(ROUTETABLE, msg, MessagePriorityHigh) 648 } 649 650 func (s *Stream) onRouteTable(message *NebMessage) error { 651 data, err := s.getData(message) 652 if err != nil { 653 return err 654 } 655 656 peers := new(netpb.Peers) 657 if err := proto.Unmarshal(data, peers); err != nil { 658 logging.VLog().WithFields(logrus.Fields{ 659 "err": err, 660 }).Debug("Invalid Peers proto message.") 661 return ErrShouldCloseConnectionAndExitLoop 662 } 663 664 s.node.routeTable.AddPeers(s.node.ID(), peers) 665 666 return nil 667 } 668 669 func (s *Stream) finishHandshake() { 670 logging.VLog().WithFields(logrus.Fields{ 671 "stream": s.String(), 672 }).Debug("Finished handshake.") 673 674 s.status = streamStatusHandshakeSucceed 675 s.handshakeSucceedCh <- true 676 } 677 678 func (s *Stream) getData(message *NebMessage) ([]byte, error) { 679 var data []byte 680 if ByteSliceEqualBCE(s.reservedFlag, CurrentReserved) { 681 var err error 682 data, err = message.Data() 683 if err != nil { 684 return nil, err 685 } 686 } else { 687 data = message.OriginalData() 688 } 689 return data, nil 690 } 691 692 // CheckClientVersionCompatibility if two clients are compatible 693 // If the clientVersion of node A is X.Y.Z, then node B must be X.Y.{} to be compatible with A. 
func CheckClientVersionCompatibility(v1, v2 string) bool {
	// Compare the local version v1 with the remote version v2; both must have
	// the form X.Y.Z to be comparable.
	s1 := strings.Split(v1, ".")
	s2 := strings.Split(v2, ".")

	if len(s1) != 3 || len(s2) != 3 {
		return false
	}

	// Major (X) and minor (Y) must match; the patch component (Z) is ignored.
	return s1[0] == s2[0] && s1[1] == s2[1]
}

// ByteSliceEqualBCE determines whether two byte slices are equal.
//
// Unlike bytes.Equal, a nil slice is NOT equal to an empty non-nil slice;
// two nil slices are equal.
func ByteSliceEqualBCE(a, b []byte) bool {
	if len(a) != len(b) {
		return false
	}

	if (a == nil) != (b == nil) {
		return false
	}

	// Reslicing b to len(a) lets the compiler prove every b[i] below is in
	// range and drop the per-iteration bounds check.
	b = b[:len(a)]
	for i, v := range a {
		if v != b[i] {
			return false
		}
	}

	return true
}