github.com/anycable/anycable-go@v1.5.1/node/node.go (about)

     1  package node
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"log/slog"
     8  	"runtime"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/anycable/anycable-go/broker"
    13  	"github.com/anycable/anycable-go/common"
    14  	"github.com/anycable/anycable-go/hub"
    15  	"github.com/anycable/anycable-go/logger"
    16  	"github.com/anycable/anycable-go/metrics"
    17  	"github.com/anycable/anycable-go/utils"
    18  	"github.com/anycable/anycable-go/ws"
    19  	"github.com/joomcode/errorx"
    20  )
    21  
    22  const (
    23  	metricsGoroutines      = "goroutines_num"
    24  	metricsMemSys          = "mem_sys_bytes"
    25  	metricsClientsNum      = "clients_num"
    26  	metricsUniqClientsNum  = "clients_uniq_num"
    27  	metricsStreamsNum      = "broadcast_streams_num"
    28  	metricsDisconnectQueue = "disconnect_queue_size"
    29  
    30  	metricsFailedAuths           = "failed_auths_total"
    31  	metricsReceivedMsg           = "client_msg_total"
    32  	metricsFailedCommandReceived = "failed_client_msg_total"
    33  	metricsBroadcastMsg          = "broadcast_msg_total"
    34  	metricsUnknownBroadcast      = "failed_broadcast_msg_total"
    35  
    36  	metricsSentMsg    = "server_msg_total"
    37  	metricsFailedSent = "failed_server_msg_total"
    38  
    39  	metricsDataSent     = "data_sent_total"
    40  	metricsDataReceived = "data_rcvd_total"
    41  )
    42  
    43  // AppNode describes a basic node interface
    44  //
    45  //go:generate mockery --name AppNode --output "../node_mocks" --outpkg node_mocks
    46  type AppNode interface {
    47  	HandlePubSub(msg []byte)
    48  	LookupSession(id string) *Session
    49  	Authenticate(s *Session, opts ...AuthOption) (*common.ConnectResult, error)
    50  	Authenticated(s *Session, identifiers string)
    51  	Subscribe(s *Session, msg *common.Message) (*common.CommandResult, error)
    52  	Unsubscribe(s *Session, msg *common.Message) (*common.CommandResult, error)
    53  	Perform(s *Session, msg *common.Message) (*common.CommandResult, error)
    54  	Disconnect(s *Session) error
    55  }
    56  
    57  // Connection represents underlying connection
    58  type Connection interface {
    59  	Write(msg []byte, deadline time.Time) error
    60  	WriteBinary(msg []byte, deadline time.Time) error
    61  	Read() ([]byte, error)
    62  	Close(code int, reason string)
    63  }
    64  
    65  // Node represents the whole application
    66  type Node struct {
    67  	id      string
    68  	metrics metrics.Instrumenter
    69  
    70  	config       *Config
    71  	hub          *hub.Hub
    72  	broker       broker.Broker
    73  	controller   Controller
    74  	disconnector Disconnector
    75  	shutdownCh   chan struct{}
    76  	shutdownMu   sync.Mutex
    77  	closed       bool
    78  	log          *slog.Logger
    79  }
    80  
    81  var _ AppNode = (*Node)(nil)
    82  
    83  type NodeOption = func(*Node)
    84  
    85  func WithController(c Controller) NodeOption {
    86  	return func(n *Node) {
    87  		n.controller = c
    88  	}
    89  }
    90  
    91  func WithInstrumenter(i metrics.Instrumenter) NodeOption {
    92  	return func(n *Node) {
    93  		n.metrics = i
    94  	}
    95  }
    96  
    97  func WithLogger(l *slog.Logger) NodeOption {
    98  	return func(n *Node) {
    99  		n.log = l.With("context", "node")
   100  	}
   101  }
   102  
   103  func WithID(id string) NodeOption {
   104  	return func(n *Node) {
   105  		n.id = id
   106  	}
   107  }
   108  
   109  // NewNode builds new node struct
   110  func NewNode(config *Config, opts ...NodeOption) *Node {
   111  	n := &Node{
   112  		config:     config,
   113  		shutdownCh: make(chan struct{}),
   114  	}
   115  
   116  	for _, opt := range opts {
   117  		opt(n)
   118  	}
   119  
   120  	// Setup default logger
   121  	if n.log == nil {
   122  		n.log = slog.With("context", "node")
   123  	}
   124  
   125  	n.hub = hub.NewHub(config.HubGopoolSize, n.log)
   126  
   127  	if n.metrics != nil {
   128  		n.registerMetrics()
   129  	}
   130  
   131  	return n
   132  }
   133  
   134  // Start runs all the required goroutines
   135  func (n *Node) Start() error {
   136  	go n.hub.Run()
   137  	go n.collectStats()
   138  
   139  	return nil
   140  }
   141  
   142  // ID returns node identifier
   143  func (n *Node) ID() string {
   144  	return n.id
   145  }
   146  
   147  // SetDisconnector set disconnector for the node
   148  func (n *Node) SetDisconnector(d Disconnector) {
   149  	n.disconnector = d
   150  }
   151  
   152  func (n *Node) SetBroker(b broker.Broker) {
   153  	n.broker = b
   154  }
   155  
   156  // Return current instrumenter for the node
   157  func (n *Node) Instrumenter() metrics.Instrumenter {
   158  	return n.metrics
   159  }
   160  
   161  // HandleCommand parses incoming message from client and
   162  // execute the command (if recognized)
   163  func (n *Node) HandleCommand(s *Session, msg *common.Message) (err error) {
   164  	s.Log.Debug("incoming message", "data", msg)
   165  	switch msg.Command {
   166  	case "pong":
   167  		s.handlePong(msg)
   168  	case "subscribe":
   169  		_, err = n.Subscribe(s, msg)
   170  	case "unsubscribe":
   171  		_, err = n.Unsubscribe(s, msg)
   172  	case "message":
   173  		_, err = n.Perform(s, msg)
   174  	case "history":
   175  		err = n.History(s, msg)
   176  	case "whisper":
   177  		err = n.Whisper(s, msg)
   178  	default:
   179  		err = fmt.Errorf("unknown command: %s", msg.Command)
   180  	}
   181  
   182  	return
   183  }
   184  
   185  // HandleBroadcast parses incoming broadcast message, record it and re-transmit to other nodes
   186  func (n *Node) HandleBroadcast(raw []byte) {
   187  	msg, err := common.PubSubMessageFromJSON(raw)
   188  
   189  	if err != nil {
   190  		n.metrics.CounterIncrement(metricsUnknownBroadcast)
   191  		n.log.Warn("failed to parse pubsub message", "data", logger.CompactValue(raw), "error", err)
   192  		return
   193  	}
   194  
   195  	switch v := msg.(type) {
   196  	case common.StreamMessage:
   197  		n.log.Debug("handle broadcast message", "payload", &v)
   198  		n.broker.HandleBroadcast(&v)
   199  	case []*common.StreamMessage:
   200  		n.log.Debug("handle batch-broadcast message", "payload", &v)
   201  		for _, el := range v {
   202  			n.broker.HandleBroadcast(el)
   203  		}
   204  	case common.RemoteCommandMessage:
   205  		n.log.Debug("handle remote command", "command", &v)
   206  		n.broker.HandleCommand(&v)
   207  	}
   208  }
   209  
   210  // HandlePubSub parses incoming pubsub message and broadcast it to all clients (w/o using a broker)
   211  func (n *Node) HandlePubSub(raw []byte) {
   212  	msg, err := common.PubSubMessageFromJSON(raw)
   213  
   214  	if err != nil {
   215  		n.metrics.CounterIncrement(metricsUnknownBroadcast)
   216  		n.log.Warn("failed to parse pubsub message", "data", logger.CompactValue(raw), "error", err)
   217  		return
   218  	}
   219  
   220  	switch v := msg.(type) {
   221  	case common.StreamMessage:
   222  		n.Broadcast(&v)
   223  	case []*common.StreamMessage:
   224  		for _, el := range v {
   225  			n.Broadcast(el)
   226  		}
   227  	case common.RemoteCommandMessage:
   228  		n.ExecuteRemoteCommand(&v)
   229  	}
   230  }
   231  
   232  func (n *Node) LookupSession(id string) *Session {
   233  	hubSession := n.hub.FindByIdentifier(id)
   234  	session, _ := hubSession.(*Session)
   235  	return session
   236  }
   237  
   238  // Shutdown stops all services (hub, controller)
   239  func (n *Node) Shutdown(ctx context.Context) (err error) {
   240  	n.shutdownMu.Lock()
   241  	if n.closed {
   242  		n.shutdownMu.Unlock()
   243  		return errors.New("already shut down")
   244  	}
   245  
   246  	close(n.shutdownCh)
   247  
   248  	n.closed = true
   249  	n.shutdownMu.Unlock()
   250  
   251  	if n.hub != nil {
   252  		active := n.hub.Size()
   253  
   254  		if active > 0 {
   255  			n.log.Info("closing active connections", "num", active)
   256  			n.disconnectAll(ctx)
   257  		}
   258  
   259  		n.hub.Shutdown()
   260  	}
   261  
   262  	if n.disconnector != nil {
   263  		err := n.disconnector.Shutdown(ctx)
   264  
   265  		if err != nil {
   266  			n.log.Warn("failed to shutdown disconnector gracefully", "error", err)
   267  		}
   268  	}
   269  
   270  	if n.controller != nil {
   271  		err := n.controller.Shutdown()
   272  
   273  		if err != nil {
   274  			n.log.Warn("failed to shutdown controller gracefully", "error", err)
   275  		}
   276  	}
   277  
   278  	return
   279  }
   280  
   281  func (n *Node) IsShuttingDown() bool {
   282  	n.shutdownMu.Lock()
   283  	defer n.shutdownMu.Unlock()
   284  
   285  	return n.closed
   286  }
   287  
   288  type AuthOptions struct {
   289  	DisconnectOnFailure bool
   290  }
   291  
   292  func newAuthOptions(modifiers []AuthOption) *AuthOptions {
   293  	base := &AuthOptions{
   294  		DisconnectOnFailure: true,
   295  	}
   296  
   297  	for _, modifier := range modifiers {
   298  		modifier(base)
   299  	}
   300  
   301  	return base
   302  }
   303  
   304  type AuthOption = func(*AuthOptions)
   305  
   306  func WithDisconnectOnFailure(disconnect bool) AuthOption {
   307  	return func(opts *AuthOptions) {
   308  		opts.DisconnectOnFailure = disconnect
   309  	}
   310  }
   311  
   312  // Authenticate calls controller to perform authentication.
   313  // If authentication is successful, session is registered with a hub.
   314  func (n *Node) Authenticate(s *Session, options ...AuthOption) (*common.ConnectResult, error) {
   315  	opts := newAuthOptions(options)
   316  
   317  	if s.IsResumeable() {
   318  		restored := n.TryRestoreSession(s)
   319  
   320  		if restored {
   321  			return &common.ConnectResult{Status: common.SUCCESS}, nil
   322  		}
   323  	}
   324  
   325  	res, err := n.controller.Authenticate(s.GetID(), s.env)
   326  
   327  	s.Log.Debug("controller authenticate", "response", res, "err", err)
   328  
   329  	if err != nil {
   330  		s.Disconnect("Auth Error", ws.CloseInternalServerErr)
   331  		return nil, errorx.Decorate(err, "failed to authenticate")
   332  	}
   333  
   334  	if res.Status == common.SUCCESS {
   335  		n.Authenticated(s, res.Identifier)
   336  	} else {
   337  		if res.Status == common.FAILURE {
   338  			n.metrics.CounterIncrement(metricsFailedAuths)
   339  		}
   340  
   341  		if opts.DisconnectOnFailure {
   342  			defer s.Disconnect("Auth Failed", ws.CloseNormalClosure)
   343  		}
   344  	}
   345  
   346  	n.handleCallReply(s, res.ToCallResult())
   347  	n.markDisconnectable(s, res.DisconnectInterest)
   348  
   349  	if s.IsResumeable() {
   350  		if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
   351  			s.Log.Error("failed to persist session in cache", "error", berr)
   352  		}
   353  	}
   354  
   355  	return res, nil
   356  }
   357  
   358  // Mark session as authenticated and register it with a hub.
   359  // Useful when you perform authentication manually, not using a controller.
   360  func (n *Node) Authenticated(s *Session, ids string) {
   361  	s.SetIdentifiers(ids)
   362  	s.Connected = true
   363  	n.hub.AddSession(s)
   364  }
   365  
   366  func (n *Node) TryRestoreSession(s *Session) (restored bool) {
   367  	sid := s.GetID()
   368  	prev_sid := s.PrevSid()
   369  
   370  	if prev_sid == "" {
   371  		return false
   372  	}
   373  
   374  	cached_session, err := n.broker.RestoreSession(prev_sid)
   375  
   376  	if err != nil {
   377  		s.Log.Error("failed to fetch session cache", "old_sid", prev_sid, "error", err)
   378  		return false
   379  	}
   380  
   381  	if cached_session == nil {
   382  		s.Log.Debug("session not found in cache", "old_sid", prev_sid)
   383  		return false
   384  	}
   385  
   386  	err = s.RestoreFromCache(cached_session)
   387  
   388  	if err != nil {
   389  		s.Log.Error("failed to restore session from cache", "old_sid", prev_sid, "error", err)
   390  		return false
   391  	}
   392  
   393  	s.Log.Debug("session restored", "old_sid", prev_sid)
   394  
   395  	s.Connected = true
   396  	n.hub.AddSession(s)
   397  
   398  	// Resubscribe to streams
   399  	for identifier, channel_streams := range s.subscriptions.channels {
   400  		for stream := range channel_streams {
   401  			streamId := n.broker.Subscribe(stream)
   402  			n.hub.SubscribeSession(s, streamId, identifier)
   403  		}
   404  	}
   405  
   406  	// Send welcome message
   407  	s.Send(&common.Reply{
   408  		Type:        common.WelcomeType,
   409  		Sid:         sid,
   410  		Restored:    true,
   411  		RestoredIDs: utils.Keys(s.subscriptions.channels),
   412  	})
   413  
   414  	if s.IsResumeable() {
   415  		if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
   416  			s.Log.Error("failed to persist session in cache", "error", berr)
   417  		}
   418  	}
   419  
   420  	return true
   421  }
   422  
   423  // Subscribe subscribes session to a channel
   424  func (n *Node) Subscribe(s *Session, msg *common.Message) (*common.CommandResult, error) {
   425  	s.smu.Lock()
   426  
   427  	if ok := s.subscriptions.HasChannel(msg.Identifier); ok {
   428  		s.smu.Unlock()
   429  		return nil, fmt.Errorf("already subscribed to %s", msg.Identifier)
   430  	}
   431  
   432  	res, err := n.controller.Subscribe(s.GetID(), s.env, s.GetIdentifiers(), msg.Identifier)
   433  
   434  	s.Log.Debug("controller subscribe", "response", res, "err", err)
   435  
   436  	var confirmed bool
   437  
   438  	if err != nil { // nolint: gocritic
   439  		if res == nil || res.Status == common.ERROR {
   440  			return nil, errorx.Decorate(err, "subscribe failed for %s", msg.Identifier)
   441  		}
   442  	} else if res.Status == common.SUCCESS {
   443  		confirmed = true
   444  		s.subscriptions.AddChannel(msg.Identifier)
   445  		s.Log.Debug("subscribed", "identifier", msg.Identifier)
   446  	} else {
   447  		s.Log.Debug("subscription rejected", "identifier", msg.Identifier)
   448  	}
   449  
   450  	s.smu.Unlock()
   451  
   452  	if res != nil {
   453  		n.handleCommandReply(s, msg, res)
   454  		n.markDisconnectable(s, res.DisconnectInterest)
   455  	}
   456  
   457  	if confirmed {
   458  		if s.IsResumeable() {
   459  			if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
   460  				s.Log.Error("failed to persist session in cache", "error", berr)
   461  			}
   462  		}
   463  
   464  		if msg.History.Since > 0 || msg.History.Streams != nil {
   465  			if err := n.History(s, msg); err != nil {
   466  				s.Log.Warn("couldn't retrieve history", "identifier", msg.Identifier, "error", err)
   467  			}
   468  
   469  			return res, nil
   470  		}
   471  	}
   472  
   473  	return res, nil
   474  }
   475  
   476  // Unsubscribe unsubscribes session from a channel
   477  func (n *Node) Unsubscribe(s *Session, msg *common.Message) (*common.CommandResult, error) {
   478  	s.smu.Lock()
   479  
   480  	if ok := s.subscriptions.HasChannel(msg.Identifier); !ok {
   481  		s.smu.Unlock()
   482  		return nil, fmt.Errorf("unknown subscription: %s", msg.Identifier)
   483  	}
   484  
   485  	res, err := n.controller.Unsubscribe(s.GetID(), s.env, s.GetIdentifiers(), msg.Identifier)
   486  
   487  	s.Log.Debug("controller unsubscribe", "response", res, "err", err)
   488  
   489  	if err != nil {
   490  		if res == nil || res.Status == common.ERROR {
   491  			return nil, errorx.Decorate(err, "failed to unsubscribe from %s", msg.Identifier)
   492  		}
   493  	} else {
   494  		// Make sure to remove all streams subscriptions
   495  		res.StopAllStreams = true
   496  
   497  		s.env.RemoveChannelState(msg.Identifier)
   498  		s.subscriptions.RemoveChannel(msg.Identifier)
   499  
   500  		s.Log.Debug("unsubscribed", "identifier", msg.Identifier)
   501  	}
   502  
   503  	s.smu.Unlock()
   504  
   505  	if res != nil {
   506  		n.handleCommandReply(s, msg, res)
   507  	}
   508  
   509  	if s.IsResumeable() {
   510  		if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
   511  			s.Log.Error("failed to persist session in cache", "error", berr)
   512  		}
   513  	}
   514  
   515  	return res, nil
   516  }
   517  
   518  // Perform executes client command
   519  func (n *Node) Perform(s *Session, msg *common.Message) (*common.CommandResult, error) {
   520  	s.smu.Lock()
   521  
   522  	if ok := s.subscriptions.HasChannel(msg.Identifier); !ok {
   523  		s.smu.Unlock()
   524  		return nil, fmt.Errorf("unknown subscription %s", msg.Identifier)
   525  	}
   526  
   527  	s.smu.Unlock()
   528  
   529  	data, ok := msg.Data.(string)
   530  
   531  	if !ok {
   532  		return nil, fmt.Errorf("perform data must be a string, got %v", msg.Data)
   533  	}
   534  
   535  	res, err := n.controller.Perform(s.GetID(), s.env, s.GetIdentifiers(), msg.Identifier, data)
   536  
   537  	s.Log.Debug("controller perform", "response", res, "err", err)
   538  
   539  	if err != nil {
   540  		if res == nil || res.Status == common.ERROR {
   541  			return nil, errorx.Decorate(err, "perform failed for %s", msg.Identifier)
   542  		}
   543  	}
   544  
   545  	if res != nil {
   546  		if n.handleCommandReply(s, msg, res) {
   547  			if s.IsResumeable() {
   548  				if berr := n.broker.CommitSession(s.GetID(), s); berr != nil {
   549  					s.Log.Error("failed to persist session in cache", "error", berr)
   550  				}
   551  			}
   552  		}
   553  	}
   554  
   555  	return res, nil
   556  }
   557  
   558  // History fetches the stream history for the specified identifier
   559  func (n *Node) History(s *Session, msg *common.Message) error {
   560  	s.smu.Lock()
   561  
   562  	if ok := s.subscriptions.HasChannel(msg.Identifier); !ok {
   563  		s.smu.Unlock()
   564  		return fmt.Errorf("unknown subscription %s", msg.Identifier)
   565  	}
   566  
   567  	subscriptionStreams := s.subscriptions.StreamsFor(msg.Identifier)
   568  
   569  	s.smu.Unlock()
   570  
   571  	history := msg.History
   572  
   573  	if history.Since == 0 && history.Streams == nil {
   574  		return fmt.Errorf("history request is missing, got %v", msg)
   575  	}
   576  
   577  	backlog, err := n.retreiveHistory(&history, subscriptionStreams)
   578  
   579  	if err != nil {
   580  		s.Send(&common.Reply{
   581  			Type:       common.HistoryRejectedType,
   582  			Identifier: msg.Identifier,
   583  		})
   584  
   585  		return err
   586  	}
   587  
   588  	for _, el := range backlog {
   589  		s.Send(el.ToReplyFor(msg.Identifier))
   590  	}
   591  
   592  	s.Send(&common.Reply{
   593  		Type:       common.HistoryConfirmedType,
   594  		Identifier: msg.Identifier,
   595  	})
   596  
   597  	return nil
   598  }
   599  
   600  func (n *Node) retreiveHistory(history *common.HistoryRequest, streams []string) (backlog []common.StreamMessage, err error) {
   601  	backlog = []common.StreamMessage{}
   602  
   603  	for _, stream := range streams {
   604  		if history.Streams != nil {
   605  			pos, ok := history.Streams[stream]
   606  
   607  			if ok {
   608  				streamBacklog, err := n.broker.HistoryFrom(stream, pos.Epoch, pos.Offset)
   609  
   610  				if err != nil {
   611  					return nil, err
   612  				}
   613  
   614  				backlog = append(backlog, streamBacklog...)
   615  
   616  				continue
   617  			}
   618  		}
   619  
   620  		if history.Since > 0 {
   621  			streamBacklog, err := n.broker.HistorySince(stream, history.Since)
   622  
   623  			if err != nil {
   624  				return nil, err
   625  			}
   626  
   627  			backlog = append(backlog, streamBacklog...)
   628  		}
   629  	}
   630  
   631  	return backlog, nil
   632  }
   633  
   634  // Whisper broadcasts the message to the specified whispering stream to
   635  // all clients except the sender
   636  func (n *Node) Whisper(s *Session, msg *common.Message) error {
   637  	// The session must have the whisper stream name defined in the state to be able to whisper
   638  	// If the stream is not defined, the whisper message is ignored
   639  	env := s.GetEnv()
   640  	if env == nil {
   641  		return errors.New("session environment is missing")
   642  	}
   643  
   644  	stream := env.GetChannelStateField(msg.Identifier, common.WHISPER_STREAM_STATE)
   645  
   646  	if stream == "" {
   647  		s.Log.Debug("whisper stream not found", "identifier", msg.Identifier)
   648  		return nil
   649  	}
   650  
   651  	broadcast := &common.StreamMessage{
   652  		Stream: stream,
   653  		Data:   string(utils.ToJSON(msg.Data)),
   654  		Meta: &common.StreamMessageMetadata{
   655  			ExcludeSocket: s.GetID(),
   656  			Transient:     true,
   657  		},
   658  	}
   659  
   660  	n.broker.HandleBroadcast(broadcast)
   661  
   662  	s.Log.Debug("whispered", "stream", stream)
   663  
   664  	return nil
   665  }
   666  
   667  // Broadcast message to stream (locally)
   668  func (n *Node) Broadcast(msg *common.StreamMessage) {
   669  	n.metrics.CounterIncrement(metricsBroadcastMsg)
   670  	n.log.Debug("incoming broadcast message", "payload", msg)
   671  	n.hub.BroadcastMessage(msg)
   672  }
   673  
   674  // Execute remote command (locally)
   675  func (n *Node) ExecuteRemoteCommand(msg *common.RemoteCommandMessage) {
   676  	// TODO: Add remote commands metrics
   677  	// n.metrics.CounterIncrement(metricsRemoteCommandsMsg)
   678  	switch msg.Command { // nolint:gocritic
   679  	case "disconnect":
   680  		dmsg, err := msg.ToRemoteDisconnectMessage()
   681  		if err != nil {
   682  			n.log.Warn("failed to parse remote disconnect command", "data", msg, "error", err)
   683  			return
   684  		}
   685  
   686  		n.log.Debug("incoming remote command", "command", dmsg)
   687  
   688  		n.RemoteDisconnect(dmsg)
   689  	}
   690  }
   691  
   692  // Disconnect adds session to disconnector queue and unregister session from hub
   693  func (n *Node) Disconnect(s *Session) error {
   694  	if s.IsResumeable() {
   695  		n.broker.FinishSession(s.GetID()) // nolint:errcheck
   696  	}
   697  
   698  	if n.IsShuttingDown() {
   699  		if s.IsDisconnectable() {
   700  			return n.DisconnectNow(s)
   701  		}
   702  	} else {
   703  		n.hub.RemoveSessionLater(s)
   704  
   705  		if s.IsDisconnectable() {
   706  			return n.disconnector.Enqueue(s)
   707  		}
   708  	}
   709  
   710  	return nil
   711  }
   712  
   713  // DisconnectNow execute disconnect on controller
   714  func (n *Node) DisconnectNow(s *Session) error {
   715  	sessionSubscriptions := s.subscriptions.Channels()
   716  
   717  	ids := s.GetIdentifiers()
   718  
   719  	s.Log.Debug("disconnect", "ids", ids, "url", s.env.URL, "headers", s.env.Headers, "subscriptions", sessionSubscriptions)
   720  
   721  	err := n.controller.Disconnect(
   722  		s.GetID(),
   723  		s.env,
   724  		ids,
   725  		sessionSubscriptions,
   726  	)
   727  
   728  	if err != nil {
   729  		s.Log.Error("controller disconnect failed", "error", err)
   730  	}
   731  
   732  	s.Log.Debug("controller disconnect succeeded")
   733  
   734  	return err
   735  }
   736  
   737  // RemoteDisconnect find a session by identifier and closes it
   738  func (n *Node) RemoteDisconnect(msg *common.RemoteDisconnectMessage) {
   739  	n.metrics.CounterIncrement(metricsBroadcastMsg)
   740  	n.log.Debug("incoming pubsub command", "data", msg)
   741  	n.hub.RemoteDisconnect(msg)
   742  }
   743  
   744  // Interest is represented as a int; -1 indicates no interest, 0 indicates lack of such information,
   745  // and 1 indicates interest.
   746  func (n *Node) markDisconnectable(s *Session, interest int) {
   747  	switch n.config.DisconnectMode {
   748  	case "always":
   749  		s.MarkDisconnectable(true)
   750  	case "never":
   751  		s.MarkDisconnectable(false)
   752  	case "auto":
   753  		s.MarkDisconnectable(interest >= 0)
   754  	}
   755  }
   756  
   757  func (n *Node) Size() int {
   758  	return n.hub.Size()
   759  }
   760  
   761  func transmit(s *Session, transmissions []string) {
   762  	for _, msg := range transmissions {
   763  		s.SendJSONTransmission(msg)
   764  	}
   765  }
   766  
   767  func (n *Node) handleCommandReply(s *Session, msg *common.Message, reply *common.CommandResult) bool {
   768  	// Returns true if any of the subscriptions/channel/connections state has changed
   769  	isDirty := false
   770  
   771  	if reply.Disconnect {
   772  		defer s.Disconnect("Command Failed", ws.CloseAbnormalClosure)
   773  	}
   774  
   775  	if reply.StopAllStreams {
   776  		n.hub.UnsubscribeSessionFromChannel(s, msg.Identifier)
   777  		removedStreams := s.subscriptions.RemoveChannelStreams(msg.Identifier)
   778  
   779  		for _, stream := range removedStreams {
   780  			isDirty = true
   781  			n.broker.Unsubscribe(stream)
   782  		}
   783  
   784  	} else if reply.StoppedStreams != nil {
   785  		isDirty = true
   786  
   787  		for _, stream := range reply.StoppedStreams {
   788  			streamId := n.broker.Unsubscribe(stream)
   789  			n.hub.UnsubscribeSession(s, streamId, msg.Identifier)
   790  			s.subscriptions.RemoveChannelStream(msg.Identifier, streamId)
   791  		}
   792  	}
   793  
   794  	if reply.Streams != nil {
   795  		isDirty = true
   796  
   797  		for _, stream := range reply.Streams {
   798  			streamId := n.broker.Subscribe(stream)
   799  			n.hub.SubscribeSession(s, streamId, msg.Identifier)
   800  			s.subscriptions.AddChannelStream(msg.Identifier, streamId)
   801  		}
   802  	}
   803  
   804  	if reply.IState != nil {
   805  		isDirty = true
   806  
   807  		s.smu.Lock()
   808  		s.env.MergeChannelState(msg.Identifier, &reply.IState)
   809  		s.smu.Unlock()
   810  	}
   811  
   812  	isConnectionDirty := n.handleCallReply(s, reply.ToCallResult())
   813  	return isDirty || isConnectionDirty
   814  }
   815  
   816  func (n *Node) handleCallReply(s *Session, reply *common.CallResult) bool {
   817  	isDirty := false
   818  
   819  	if reply.CState != nil {
   820  		isDirty = true
   821  
   822  		s.smu.Lock()
   823  		s.env.MergeConnectionState(&reply.CState)
   824  		s.smu.Unlock()
   825  	}
   826  
   827  	if reply.Broadcasts != nil {
   828  		for _, broadcast := range reply.Broadcasts {
   829  			n.broker.HandleBroadcast(broadcast)
   830  		}
   831  	}
   832  
   833  	if reply.Transmissions != nil {
   834  		transmit(s, reply.Transmissions)
   835  	}
   836  
   837  	return isDirty
   838  }
   839  
   840  // disconnectScheduler controls how quickly to disconnect sessions
   841  type disconnectScheduler interface {
   842  	// This method is called when a session is ready to be disconnected,
   843  	// so it can block the operation or cancel it (by returning false).
   844  	Continue() bool
   845  }
   846  
   847  type noopScheduler struct {
   848  	ctx context.Context
   849  }
   850  
   851  func (s *noopScheduler) Continue() bool {
   852  	return s.ctx.Err() == nil
   853  }
   854  
   855  func (n *Node) disconnectAll(ctx context.Context) {
   856  	disconnectMessage := common.NewDisconnectMessage(common.SERVER_RESTART_REASON, true)
   857  
   858  	// To speed up the process we disconnect all sessions in parallel using a pool of workers
   859  	pool := utils.NewGoPool("disconnect", n.config.ShutdownDisconnectPoolSize)
   860  
   861  	sessions := n.hub.Sessions()
   862  
   863  	var scheduler disconnectScheduler // nolint:gosimple
   864  
   865  	scheduler = &noopScheduler{ctx}
   866  
   867  	var wg sync.WaitGroup
   868  
   869  	wg.Add(len(sessions))
   870  
   871  	for _, s := range sessions {
   872  		s := s.(*Session)
   873  		pool.Schedule(func() {
   874  			if scheduler.Continue() {
   875  				if s.IsConnected() {
   876  					s.DisconnectWithMessage(disconnectMessage, common.SERVER_RESTART_REASON)
   877  				}
   878  				wg.Done()
   879  			}
   880  		})
   881  	}
   882  
   883  	done := make(chan struct{})
   884  
   885  	go func() {
   886  		wg.Wait()
   887  		close(done)
   888  	}()
   889  
   890  	select {
   891  	case <-ctx.Done():
   892  		n.log.Warn("terminated while disconnecting active sessions", "num", n.hub.Size())
   893  	case <-done:
   894  		n.log.Info("all active connections closed")
   895  	}
   896  }
   897  
   898  func (n *Node) collectStats() {
   899  	if n.config.StatsRefreshInterval == 0 {
   900  		return
   901  	}
   902  
   903  	statsCollectInterval := time.Duration(n.config.StatsRefreshInterval) * time.Second
   904  
   905  	for {
   906  		select {
   907  		case <-n.shutdownCh:
   908  			return
   909  		case <-time.After(statsCollectInterval):
   910  			n.collectStatsOnce()
   911  		}
   912  	}
   913  }
   914  
   915  func (n *Node) collectStatsOnce() {
   916  	n.metrics.GaugeSet(metricsGoroutines, uint64(runtime.NumGoroutine()))
   917  
   918  	var m runtime.MemStats
   919  	runtime.ReadMemStats(&m)
   920  	n.metrics.GaugeSet(metricsMemSys, m.Sys)
   921  
   922  	n.metrics.GaugeSet(metricsClientsNum, uint64(n.hub.Size()))
   923  	n.metrics.GaugeSet(metricsUniqClientsNum, uint64(n.hub.UniqSize()))
   924  	n.metrics.GaugeSet(metricsStreamsNum, uint64(n.hub.StreamsSize()))
   925  	n.metrics.GaugeSet(metricsDisconnectQueue, uint64(n.disconnector.Size()))
   926  }
   927  
   928  func (n *Node) registerMetrics() {
   929  	n.metrics.RegisterGauge(metricsGoroutines, "The number of Go routines")
   930  	n.metrics.RegisterGauge(metricsMemSys, "The total bytes of memory obtained from the OS")
   931  
   932  	n.metrics.RegisterGauge(metricsClientsNum, "The number of active clients")
   933  	n.metrics.RegisterGauge(metricsUniqClientsNum, "The number of unique clients (with respect to connection identifiers)")
   934  	n.metrics.RegisterGauge(metricsStreamsNum, "The number of active broadcasting streams")
   935  	n.metrics.RegisterGauge(metricsDisconnectQueue, "The size of delayed disconnect")
   936  
   937  	n.metrics.RegisterCounter(metricsFailedAuths, "The total number of failed authentication attempts")
   938  	n.metrics.RegisterCounter(metricsReceivedMsg, "The total number of received messages from clients")
   939  	n.metrics.RegisterCounter(metricsFailedCommandReceived, "The total number of unrecognized messages received from clients")
   940  	n.metrics.RegisterCounter(metricsBroadcastMsg, "The total number of messages received through PubSub (for broadcast)")
   941  	n.metrics.RegisterCounter(metricsUnknownBroadcast, "The total number of unrecognized messages received through PubSub")
   942  
   943  	n.metrics.RegisterCounter(metricsSentMsg, "The total number of messages sent to clients")
   944  	n.metrics.RegisterCounter(metricsFailedSent, "The total number of messages failed to send to clients")
   945  
   946  	n.metrics.RegisterCounter(metricsDataSent, "The total amount of bytes sent to clients")
   947  	n.metrics.RegisterCounter(metricsDataReceived, "The total amount of bytes received from clients")
   948  }