github.com/anycable/anycable-go@v1.5.1/node/node.go (about) 1 package node 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "log/slog" 8 "runtime" 9 "sync" 10 "time" 11 12 "github.com/anycable/anycable-go/broker" 13 "github.com/anycable/anycable-go/common" 14 "github.com/anycable/anycable-go/hub" 15 "github.com/anycable/anycable-go/logger" 16 "github.com/anycable/anycable-go/metrics" 17 "github.com/anycable/anycable-go/utils" 18 "github.com/anycable/anycable-go/ws" 19 "github.com/joomcode/errorx" 20 ) 21 22 const ( 23 metricsGoroutines = "goroutines_num" 24 metricsMemSys = "mem_sys_bytes" 25 metricsClientsNum = "clients_num" 26 metricsUniqClientsNum = "clients_uniq_num" 27 metricsStreamsNum = "broadcast_streams_num" 28 metricsDisconnectQueue = "disconnect_queue_size" 29 30 metricsFailedAuths = "failed_auths_total" 31 metricsReceivedMsg = "client_msg_total" 32 metricsFailedCommandReceived = "failed_client_msg_total" 33 metricsBroadcastMsg = "broadcast_msg_total" 34 metricsUnknownBroadcast = "failed_broadcast_msg_total" 35 36 metricsSentMsg = "server_msg_total" 37 metricsFailedSent = "failed_server_msg_total" 38 39 metricsDataSent = "data_sent_total" 40 metricsDataReceived = "data_rcvd_total" 41 ) 42 43 // AppNode describes a basic node interface 44 // 45 //go:generate mockery --name AppNode --output "../node_mocks" --outpkg node_mocks 46 type AppNode interface { 47 HandlePubSub(msg []byte) 48 LookupSession(id string) *Session 49 Authenticate(s *Session, opts ...AuthOption) (*common.ConnectResult, error) 50 Authenticated(s *Session, identifiers string) 51 Subscribe(s *Session, msg *common.Message) (*common.CommandResult, error) 52 Unsubscribe(s *Session, msg *common.Message) (*common.CommandResult, error) 53 Perform(s *Session, msg *common.Message) (*common.CommandResult, error) 54 Disconnect(s *Session) error 55 } 56 57 // Connection represents underlying connection 58 type Connection interface { 59 Write(msg []byte, deadline time.Time) error 60 WriteBinary(msg []byte, deadline time.Time) error 61 Read() ([]byte, error) 62 Close(code int, reason string) 63 } 64 65 // Node represents the whole application 66 type Node struct { 67 id string 68 metrics metrics.Instrumenter 69 70 config *Config 71 hub *hub.Hub 72 broker broker.Broker 73 controller Controller 74 disconnector Disconnector 75 shutdownCh chan struct{} 76 shutdownMu sync.Mutex 77 closed bool 78 log *slog.Logger 79 } 80 81 var _ AppNode = (*Node)(nil) 82 83 type NodeOption = func(*Node) 84 85 func WithController(c Controller) NodeOption { 86 return func(n *Node) { 87 n.controller = c 88 } 89 } 90 91 func WithInstrumenter(i metrics.Instrumenter) NodeOption { 92 return func(n *Node) { 93 n.metrics = i 94 } 95 } 96 97 func WithLogger(l *slog.Logger) NodeOption { 98 return func(n *Node) { 99 n.log = l.With("context", "node") 100 } 101 } 102 103 func WithID(id string) NodeOption { 104 return func(n *Node) { 105 n.id = id 106 } 107 } 108 109 // NewNode builds new node struct 110 func NewNode(config *Config, opts ...NodeOption) *Node { 111 n := &Node{ 112 config: config, 113 shutdownCh: make(chan struct{}), 114 } 115 116 for _, opt := range opts { 117 opt(n) 118 } 119 120 // Setup default logger 121 if n.log == nil { 122 n.log = slog.With("context", "node") 123 } 124 125 n.hub = hub.NewHub(config.HubGopoolSize, n.log) 126 127 if n.metrics != nil { 128 n.registerMetrics() 129 } 130 131 return n 132 } 133 134 // Start runs all the required goroutines 135 func (n *Node) Start() error { 136 go n.hub.Run() 137 go n.collectStats() 138 139 return nil 140 } 141 142 // ID returns node identifier 143 func (n *Node) ID() string { 144 return n.id 145 } 146 147 // SetDisconnector set disconnector for the node 148 func (n *Node) SetDisconnector(d Disconnector) { 149 n.disconnector = d 150 } 151 152 func (n *Node) SetBroker(b broker.Broker) { 153 n.broker = b 154 } 155 156 // Return current instrumenter for the node 157 func (n *Node) Instrumenter() metrics.Instrumenter { 158 return n.metrics 159 } 160 161 // HandleCommand parses incoming message from client and 162 // execute the command (if recognized) 163 func (n *Node) HandleCommand(s *Session, msg *common.Message) (err error) { 164 s.Log.Debug("incoming message", "data", msg) 165 switch msg.Command { 166 case "pong": 167 s.handlePong(msg) 168 case "subscribe": 169 _, err = n.Subscribe(s, msg) 170 case "unsubscribe": 171 _, err = n.Unsubscribe(s, msg) 172 case "message": 173 _, err = n.Perform(s, msg) 174 case "history": 175 err = n.History(s, msg) 176 case "whisper": 177 err = n.Whisper(s, msg) 178 default: 179 err = fmt.Errorf("unknown command: %s", msg.Command) 180 } 181 182 return 183 } 184 185 // HandleBroadcast parses incoming broadcast message, record it and re-transmit to other nodes 186 func (n *Node) HandleBroadcast(raw []byte) { 187 msg, err := common.PubSubMessageFromJSON(raw) 188 189 if err != nil { 190 n.metrics.CounterIncrement(metricsUnknownBroadcast) 191 n.log.Warn("failed to parse pubsub message", "data", logger.CompactValue(raw), "error", err) 192 return 193 } 194 195 switch v := msg.(type) { 196 case common.StreamMessage: 197 n.log.Debug("handle broadcast message", "payload", &v) 198 n.broker.HandleBroadcast(&v) 199 case []*common.StreamMessage: 200 n.log.Debug("handle batch-broadcast message", "payload", &v) 201 for _, el := range v { 202 n.broker.HandleBroadcast(el) 203 } 204 case common.RemoteCommandMessage: 205 n.log.Debug("handle remote command", "command", &v) 206 n.broker.HandleCommand(&v) 207 } 208 } 209 210 // HandlePubSub parses incoming pubsub message and broadcast it to all clients (w/o using a broker) 211 func (n *Node) HandlePubSub(raw []byte) { 212 msg, err := common.PubSubMessageFromJSON(raw) 213 214 if err != nil { 215 n.metrics.CounterIncrement(metricsUnknownBroadcast) 216 n.log.Warn("failed to parse pubsub message", "data", logger.CompactValue(raw), "error", err) 217 return 218 } 219 220 switch v := msg.(type) { 221 case common.StreamMessage: 222 n.Broadcast(&v) 223 case []*common.StreamMessage: 224 for _, el := range v { 225 n.Broadcast(el) 226 } 227 case common.RemoteCommandMessage: 228 n.ExecuteRemoteCommand(&v) 229 } 230 } 231 232 func (n *Node) LookupSession(id string) *Session { 233 hubSession := n.hub.FindByIdentifier(id) 234 session, _ := hubSession.(*Session) 235 return session 236 } 237 238 // Shutdown stops all services (hub, controller) 239 func (n *Node) Shutdown(ctx context.Context) (err error) { 240 n.shutdownMu.Lock() 241 if n.closed { 242 n.shutdownMu.Unlock() 243 return errors.New("already shut down") 244 } 245 246 close(n.shutdownCh) 247 248 n.closed = true 249 n.shutdownMu.Unlock() 250 251 if n.hub != nil { 252 active := n.hub.Size() 253 254 if active > 0 { 255 n.log.Info("closing active connections", "num", active) 256 n.disconnectAll(ctx) 257 } 258 259 n.hub.Shutdown() 260 } 261 262 if n.disconnector != nil { 263 err := n.disconnector.Shutdown(ctx) 264 265 if err != nil { 266 n.log.Warn("failed to shutdown disconnector gracefully", "error", err) 267 } 268 } 269 270 if n.controller != nil { 271 err := n.controller.Shutdown() 272 273 if err != nil { 274 n.log.Warn("failed to shutdown controller gracefully", "error", err) 275 } 276 } 277 278 return 279 } 280 281 func (n *Node) IsShuttingDown() bool { 282 n.shutdownMu.Lock() 283 defer n.shutdownMu.Unlock() 284 285 return n.closed 286 } 287 288 type AuthOptions struct { 289 DisconnectOnFailure bool 290 } 291 292 func newAuthOptions(modifiers []AuthOption) *AuthOptions { 293 base := &AuthOptions{ 294 DisconnectOnFailure: true, 295 } 296 297 for _, modifier := range modifiers { 298 modifier(base) 299 } 300 301 return base 302 } 303 304 type AuthOption = func(*AuthOptions) 305 306 func WithDisconnectOnFailure(disconnect bool) AuthOption { 307 return func(opts *AuthOptions) { 308 opts.DisconnectOnFailure = disconnect 309 } 310 } 311 312 // Authenticate calls controller to perform authentication. 313 // If authentication is successful, session is registered with a hub. 314 func (n *Node) Authenticate(s *Session, options ...AuthOption) (*common.ConnectResult, error) { 315 opts := newAuthOptions(options) 316 317 if s.IsResumeable() { 318 restored := n.TryRestoreSession(s) 319 320 if restored { 321 return &common.ConnectResult{Status: common.SUCCESS}, nil 322 } 323 } 324 325 res, err := n.controller.Authenticate(s.GetID(), s.env) 326 327 s.Log.Debug("controller authenticate", "response", res, "err", err) 328 329 if err != nil { 330 s.Disconnect("Auth Error", ws.CloseInternalServerErr) 331 return nil, errorx.Decorate(err, "failed to authenticate") 332 } 333 334 if res.Status == common.SUCCESS { 335 n.Authenticated(s, res.Identifier) 336 } else { 337 if res.Status == common.FAILURE { 338 n.metrics.CounterIncrement(metricsFailedAuths) 339 } 340 341 if opts.DisconnectOnFailure { 342 defer s.Disconnect("Auth Failed", ws.CloseNormalClosure) 343 } 344 } 345 346 n.handleCallReply(s, res.ToCallResult()) 347 n.markDisconnectable(s, res.DisconnectInterest) 348 349 if s.IsResumeable() { 350 if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { 351 s.Log.Error("failed to persist session in cache", "error", berr) 352 } 353 } 354 355 return res, nil 356 } 357 358 // Mark session as authenticated and register it with a hub. 359 // Useful when you perform authentication manually, not using a controller. 360 func (n *Node) Authenticated(s *Session, ids string) { 361 s.SetIdentifiers(ids) 362 s.Connected = true 363 n.hub.AddSession(s) 364 } 365 366 func (n *Node) TryRestoreSession(s *Session) (restored bool) { 367 sid := s.GetID() 368 prev_sid := s.PrevSid() 369 370 if prev_sid == "" { 371 return false 372 } 373 374 cached_session, err := n.broker.RestoreSession(prev_sid) 375 376 if err != nil { 377 s.Log.Error("failed to fetch session cache", "old_sid", prev_sid, "error", err) 378 return false 379 } 380 381 if cached_session == nil { 382 s.Log.Debug("session not found in cache", "old_sid", prev_sid) 383 return false 384 } 385 386 err = s.RestoreFromCache(cached_session) 387 388 if err != nil { 389 s.Log.Error("failed to restore session from cache", "old_sid", prev_sid, "error", err) 390 return false 391 } 392 393 s.Log.Debug("session restored", "old_sid", prev_sid) 394 395 s.Connected = true 396 n.hub.AddSession(s) 397 398 // Resubscribe to streams 399 for identifier, channel_streams := range s.subscriptions.channels { 400 for stream := range channel_streams { 401 streamId := n.broker.Subscribe(stream) 402 n.hub.SubscribeSession(s, streamId, identifier) 403 } 404 } 405 406 // Send welcome message 407 s.Send(&common.Reply{ 408 Type: common.WelcomeType, 409 Sid: sid, 410 Restored: true, 411 RestoredIDs: utils.Keys(s.subscriptions.channels), 412 }) 413 414 if s.IsResumeable() { 415 if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { 416 s.Log.Error("failed to persist session in cache", "error", berr) 417 } 418 } 419 420 return true 421 } 422 423 // Subscribe subscribes session to a channel 424 func (n *Node) Subscribe(s *Session, msg *common.Message) (*common.CommandResult, error) { 425 s.smu.Lock() 426 427 if ok := s.subscriptions.HasChannel(msg.Identifier); ok { 428 s.smu.Unlock() 429 return nil, fmt.Errorf("already subscribed to %s", msg.Identifier) 430 } 431 432 res, err := n.controller.Subscribe(s.GetID(), s.env, s.GetIdentifiers(), msg.Identifier) 433 434 s.Log.Debug("controller subscribe", "response", res, "err", err) 435 436 var confirmed bool 437 438 if err != nil { // nolint: gocritic 439 if res == nil || res.Status == common.ERROR { 440 return nil, errorx.Decorate(err, "subscribe failed for %s", msg.Identifier) 441 } 442 } else if res.Status == common.SUCCESS { 443 confirmed = true 444 s.subscriptions.AddChannel(msg.Identifier) 445 s.Log.Debug("subscribed", "identifier", msg.Identifier) 446 } else { 447 s.Log.Debug("subscription rejected", "identifier", msg.Identifier) 448 } 449 450 s.smu.Unlock() 451 452 if res != nil { 453 n.handleCommandReply(s, msg, res) 454 n.markDisconnectable(s, res.DisconnectInterest) 455 } 456 457 if confirmed { 458 if s.IsResumeable() { 459 if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { 460 s.Log.Error("failed to persist session in cache", "error", berr) 461 } 462 } 463 464 if msg.History.Since > 0 || msg.History.Streams != nil { 465 if err := n.History(s, msg); err != nil { 466 s.Log.Warn("couldn't retrieve history", "identifier", msg.Identifier, "error", err) 467 } 468 469 return res, nil 470 } 471 } 472 473 return res, nil 474 } 475 476 // Unsubscribe unsubscribes session from a channel 477 func (n *Node) Unsubscribe(s *Session, msg *common.Message) (*common.CommandResult, error) { 478 s.smu.Lock() 479 480 if ok := s.subscriptions.HasChannel(msg.Identifier); !ok { 481 s.smu.Unlock() 482 return nil, fmt.Errorf("unknown subscription: %s", msg.Identifier) 483 } 484 485 res, err := n.controller.Unsubscribe(s.GetID(), s.env, s.GetIdentifiers(), msg.Identifier) 486 487 s.Log.Debug("controller unsubscribe", "response", res, "err", err) 488 489 if err != nil { 490 if res == nil || res.Status == common.ERROR { 491 return nil, errorx.Decorate(err, "failed to unsubscribe from %s", msg.Identifier) 492 } 493 } else { 494 // Make sure to remove all streams subscriptions 495 res.StopAllStreams = true 496 497 s.env.RemoveChannelState(msg.Identifier) 498 s.subscriptions.RemoveChannel(msg.Identifier) 499 500 s.Log.Debug("unsubscribed", "identifier", msg.Identifier) 501 } 502 503 s.smu.Unlock() 504 505 if res != nil { 506 n.handleCommandReply(s, msg, res) 507 } 508 509 if s.IsResumeable() { 510 if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { 511 s.Log.Error("failed to persist session in cache", "error", berr) 512 } 513 } 514 515 return res, nil 516 } 517 518 // Perform executes client command 519 func (n *Node) Perform(s *Session, msg *common.Message) (*common.CommandResult, error) { 520 s.smu.Lock() 521 522 if ok := s.subscriptions.HasChannel(msg.Identifier); !ok { 523 s.smu.Unlock() 524 return nil, fmt.Errorf("unknown subscription %s", msg.Identifier) 525 } 526 527 s.smu.Unlock() 528 529 data, ok := msg.Data.(string) 530 531 if !ok { 532 return nil, fmt.Errorf("perform data must be a string, got %v", msg.Data) 533 } 534 535 res, err := n.controller.Perform(s.GetID(), s.env, s.GetIdentifiers(), msg.Identifier, data) 536 537 s.Log.Debug("controller perform", "response", res, "err", err) 538 539 if err != nil { 540 if res == nil || res.Status == common.ERROR { 541 return nil, errorx.Decorate(err, "perform failed for %s", msg.Identifier) 542 } 543 } 544 545 if res != nil { 546 if n.handleCommandReply(s, msg, res) { 547 if s.IsResumeable() { 548 if berr := n.broker.CommitSession(s.GetID(), s); berr != nil { 549 s.Log.Error("failed to persist session in cache", "error", berr) 550 } 551 } 552 } 553 } 554 555 return res, nil 556 } 557 558 // History fetches the stream history for the specified identifier 559 func (n *Node) History(s *Session, msg *common.Message) error { 560 s.smu.Lock() 561 562 if ok := s.subscriptions.HasChannel(msg.Identifier); !ok { 563 s.smu.Unlock() 564 return fmt.Errorf("unknown subscription %s", msg.Identifier) 565 } 566 567 subscriptionStreams := s.subscriptions.StreamsFor(msg.Identifier) 568 569 s.smu.Unlock() 570 571 history := msg.History 572 573 if history.Since == 0 && history.Streams == nil { 574 return fmt.Errorf("history request is missing, got %v", msg) 575 } 576 577 backlog, err := n.retreiveHistory(&history, subscriptionStreams) 578 579 if err != nil { 580 s.Send(&common.Reply{ 581 Type: common.HistoryRejectedType, 582 Identifier: msg.Identifier, 583 }) 584 585 return err 586 } 587 588 for _, el := range backlog { 589 s.Send(el.ToReplyFor(msg.Identifier)) 590 } 591 592 s.Send(&common.Reply{ 593 Type: common.HistoryConfirmedType, 594 Identifier: msg.Identifier, 595 }) 596 597 return nil 598 } 599 600 func (n *Node) retreiveHistory(history *common.HistoryRequest, streams []string) (backlog []common.StreamMessage, err error) { 601 backlog = []common.StreamMessage{} 602 603 for _, stream := range streams { 604 if history.Streams != nil { 605 pos, ok := history.Streams[stream] 606 607 if ok { 608 streamBacklog, err := n.broker.HistoryFrom(stream, pos.Epoch, pos.Offset) 609 610 if err != nil { 611 return nil, err 612 } 613 614 backlog = append(backlog, streamBacklog...) 615 616 continue 617 } 618 } 619 620 if history.Since > 0 { 621 streamBacklog, err := n.broker.HistorySince(stream, history.Since) 622 623 if err != nil { 624 return nil, err 625 } 626 627 backlog = append(backlog, streamBacklog...) 628 } 629 } 630 631 return backlog, nil 632 } 633 634 // Whisper broadcasts the message to the specified whispering stream to 635 // all clients except the sender 636 func (n *Node) Whisper(s *Session, msg *common.Message) error { 637 // The session must have the whisper stream name defined in the state to be able to whisper 638 // If the stream is not defined, the whisper message is ignored 639 env := s.GetEnv() 640 if env == nil { 641 return errors.New("session environment is missing") 642 } 643 644 stream := env.GetChannelStateField(msg.Identifier, common.WHISPER_STREAM_STATE) 645 646 if stream == "" { 647 s.Log.Debug("whisper stream not found", "identifier", msg.Identifier) 648 return nil 649 } 650 651 broadcast := &common.StreamMessage{ 652 Stream: stream, 653 Data: string(utils.ToJSON(msg.Data)), 654 Meta: &common.StreamMessageMetadata{ 655 ExcludeSocket: s.GetID(), 656 Transient: true, 657 }, 658 } 659 660 n.broker.HandleBroadcast(broadcast) 661 662 s.Log.Debug("whispered", "stream", stream) 663 664 return nil 665 } 666 667 // Broadcast message to stream (locally) 668 func (n *Node) Broadcast(msg *common.StreamMessage) { 669 n.metrics.CounterIncrement(metricsBroadcastMsg) 670 n.log.Debug("incoming broadcast message", "payload", msg) 671 n.hub.BroadcastMessage(msg) 672 } 673 674 // Execute remote command (locally) 675 func (n *Node) ExecuteRemoteCommand(msg *common.RemoteCommandMessage) { 676 // TODO: Add remote commands metrics 677 // n.metrics.CounterIncrement(metricsRemoteCommandsMsg) 678 switch msg.Command { // nolint:gocritic 679 case "disconnect": 680 dmsg, err := msg.ToRemoteDisconnectMessage() 681 if err != nil { 682 n.log.Warn("failed to parse remote disconnect command", "data", msg, "error", err) 683 return 684 } 685 686 n.log.Debug("incoming remote command", "command", dmsg) 687 688 n.RemoteDisconnect(dmsg) 689 } 690 } 691 692 // Disconnect adds session to disconnector queue and unregister session from hub 693 func (n *Node) Disconnect(s *Session) error { 694 if s.IsResumeable() { 695 n.broker.FinishSession(s.GetID()) // nolint:errcheck 696 } 697 698 if n.IsShuttingDown() { 699 if s.IsDisconnectable() { 700 return n.DisconnectNow(s) 701 } 702 } else { 703 n.hub.RemoveSessionLater(s) 704 705 if s.IsDisconnectable() { 706 return n.disconnector.Enqueue(s) 707 } 708 } 709 710 return nil 711 } 712 713 // DisconnectNow execute disconnect on controller 714 func (n *Node) DisconnectNow(s *Session) error { 715 sessionSubscriptions := s.subscriptions.Channels() 716 717 ids := s.GetIdentifiers() 718 719 s.Log.Debug("disconnect", "ids", ids, "url", s.env.URL, "headers", s.env.Headers, "subscriptions", sessionSubscriptions) 720 721 err := n.controller.Disconnect( 722 s.GetID(), 723 s.env, 724 ids, 725 sessionSubscriptions, 726 ) 727 728 if err != nil { 729 s.Log.Error("controller disconnect failed", "error", err) 730 } 731 732 s.Log.Debug("controller disconnect succeeded") 733 734 return err 735 } 736 737 // RemoteDisconnect find a session by identifier and closes it 738 func (n *Node) RemoteDisconnect(msg *common.RemoteDisconnectMessage) { 739 n.metrics.CounterIncrement(metricsBroadcastMsg) 740 n.log.Debug("incoming pubsub command", "data", msg) 741 n.hub.RemoteDisconnect(msg) 742 } 743 744 // Interest is represented as a int; -1 indicates no interest, 0 indicates lack of such information, 745 // and 1 indicates interest. 746 func (n *Node) markDisconnectable(s *Session, interest int) { 747 switch n.config.DisconnectMode { 748 case "always": 749 s.MarkDisconnectable(true) 750 case "never": 751 s.MarkDisconnectable(false) 752 case "auto": 753 s.MarkDisconnectable(interest >= 0) 754 } 755 } 756 757 func (n *Node) Size() int { 758 return n.hub.Size() 759 } 760 761 func transmit(s *Session, transmissions []string) { 762 for _, msg := range transmissions { 763 s.SendJSONTransmission(msg) 764 } 765 } 766 767 func (n *Node) handleCommandReply(s *Session, msg *common.Message, reply *common.CommandResult) bool { 768 // Returns true if any of the subscriptions/channel/connections state has changed 769 isDirty := false 770 771 if reply.Disconnect { 772 defer s.Disconnect("Command Failed", ws.CloseAbnormalClosure) 773 } 774 775 if reply.StopAllStreams { 776 n.hub.UnsubscribeSessionFromChannel(s, msg.Identifier) 777 removedStreams := s.subscriptions.RemoveChannelStreams(msg.Identifier) 778 779 for _, stream := range removedStreams { 780 isDirty = true 781 n.broker.Unsubscribe(stream) 782 } 783 784 } else if reply.StoppedStreams != nil { 785 isDirty = true 786 787 for _, stream := range reply.StoppedStreams { 788 streamId := n.broker.Unsubscribe(stream) 789 n.hub.UnsubscribeSession(s, streamId, msg.Identifier) 790 s.subscriptions.RemoveChannelStream(msg.Identifier, streamId) 791 } 792 } 793 794 if reply.Streams != nil { 795 isDirty = true 796 797 for _, stream := range reply.Streams { 798 streamId := n.broker.Subscribe(stream) 799 n.hub.SubscribeSession(s, streamId, msg.Identifier) 800 s.subscriptions.AddChannelStream(msg.Identifier, streamId) 801 } 802 } 803 804 if reply.IState != nil { 805 isDirty = true 806 807 s.smu.Lock() 808 s.env.MergeChannelState(msg.Identifier, &reply.IState) 809 s.smu.Unlock() 810 } 811 812 isConnectionDirty := n.handleCallReply(s, reply.ToCallResult()) 813 return isDirty || isConnectionDirty 814 } 815 816 func (n *Node) handleCallReply(s *Session, reply *common.CallResult) bool { 817 isDirty := false 818 819 if reply.CState != nil { 820 isDirty = true 821 822 s.smu.Lock() 823 s.env.MergeConnectionState(&reply.CState) 824 s.smu.Unlock() 825 } 826 827 if reply.Broadcasts != nil { 828 for _, broadcast := range reply.Broadcasts { 829 n.broker.HandleBroadcast(broadcast) 830 } 831 } 832 833 if reply.Transmissions != nil { 834 transmit(s, reply.Transmissions) 835 } 836 837 return isDirty 838 } 839 840 // disconnectScheduler controls how quickly to disconnect sessions 841 type disconnectScheduler interface { 842 // This method is called when a session is ready to be disconnected, 843 // so it can block the operation or cancel it (by returning false). 844 Continue() bool 845 } 846 847 type noopScheduler struct { 848 ctx context.Context 849 } 850 851 func (s *noopScheduler) Continue() bool { 852 return s.ctx.Err() == nil 853 } 854 855 func (n *Node) disconnectAll(ctx context.Context) { 856 disconnectMessage := common.NewDisconnectMessage(common.SERVER_RESTART_REASON, true) 857 858 // To speed up the process we disconnect all sessions in parallel using a pool of workers 859 pool := utils.NewGoPool("disconnect", n.config.ShutdownDisconnectPoolSize) 860 861 sessions := n.hub.Sessions() 862 863 var scheduler disconnectScheduler // nolint:gosimple 864 865 scheduler = &noopScheduler{ctx} 866 867 var wg sync.WaitGroup 868 869 wg.Add(len(sessions)) 870 871 for _, s := range sessions { 872 s := s.(*Session) 873 pool.Schedule(func() { 874 if scheduler.Continue() { 875 if s.IsConnected() { 876 s.DisconnectWithMessage(disconnectMessage, common.SERVER_RESTART_REASON) 877 } 878 wg.Done() 879 } 880 }) 881 } 882 883 done := make(chan struct{}) 884 885 go func() { 886 wg.Wait() 887 close(done) 888 }() 889 890 select { 891 case <-ctx.Done(): 892 n.log.Warn("terminated while disconnecting active sessions", "num", n.hub.Size()) 893 case <-done: 894 n.log.Info("all active connections closed") 895 } 896 } 897 898 func (n *Node) collectStats() { 899 if n.config.StatsRefreshInterval == 0 { 900 return 901 } 902 903 statsCollectInterval := time.Duration(n.config.StatsRefreshInterval) * time.Second 904 905 for { 906 select { 907 case <-n.shutdownCh: 908 return 909 case <-time.After(statsCollectInterval): 910 n.collectStatsOnce() 911 } 912 } 913 } 914 915 func (n *Node) collectStatsOnce() { 916 n.metrics.GaugeSet(metricsGoroutines, uint64(runtime.NumGoroutine())) 917 918 var m runtime.MemStats 919 runtime.ReadMemStats(&m) 920 n.metrics.GaugeSet(metricsMemSys, m.Sys) 921 922 n.metrics.GaugeSet(metricsClientsNum, uint64(n.hub.Size())) 923 n.metrics.GaugeSet(metricsUniqClientsNum, uint64(n.hub.UniqSize())) 924 n.metrics.GaugeSet(metricsStreamsNum, uint64(n.hub.StreamsSize())) 925 n.metrics.GaugeSet(metricsDisconnectQueue, uint64(n.disconnector.Size())) 926 } 927 928 func (n *Node) registerMetrics() { 929 n.metrics.RegisterGauge(metricsGoroutines, "The number of Go routines") 930 n.metrics.RegisterGauge(metricsMemSys, "The total bytes of memory obtained from the OS") 931 932 n.metrics.RegisterGauge(metricsClientsNum, "The number of active clients") 933 n.metrics.RegisterGauge(metricsUniqClientsNum, "The number of unique clients (with respect to connection identifiers)") 934 n.metrics.RegisterGauge(metricsStreamsNum, "The number of active broadcasting streams") 935 n.metrics.RegisterGauge(metricsDisconnectQueue, "The size of delayed disconnect") 936 937 n.metrics.RegisterCounter(metricsFailedAuths, "The total number of failed authentication attempts") 938 n.metrics.RegisterCounter(metricsReceivedMsg, "The total number of received messages from clients") 939 n.metrics.RegisterCounter(metricsFailedCommandReceived, "The total number of unrecognized messages received from clients") 940 n.metrics.RegisterCounter(metricsBroadcastMsg, "The total number of messages received through PubSub (for broadcast)") 941 n.metrics.RegisterCounter(metricsUnknownBroadcast, "The total number of unrecognized messages received through PubSub") 942 943 n.metrics.RegisterCounter(metricsSentMsg, "The total number of messages sent to clients") 944 n.metrics.RegisterCounter(metricsFailedSent, "The total number of messages failed to send to clients") 945 946 n.metrics.RegisterCounter(metricsDataSent, "The total amount of bytes sent to clients") 947 n.metrics.RegisterCounter(metricsDataReceived, "The total amount of bytes received from clients") 948 }