github.com/anycable/anycable-go@v1.5.1/node/session.go (about)

     1  package node
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"log/slog"
     8  	"math/rand"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/anycable/anycable-go/common"
    13  	"github.com/anycable/anycable-go/encoders"
    14  	"github.com/anycable/anycable-go/logger"
    15  	"github.com/anycable/anycable-go/metrics"
    16  	"github.com/anycable/anycable-go/ws"
    17  )
    18  
    19  const (
    20  	writeWait = 10 * time.Second
    21  )
    22  
    23  // Executor handles incoming commands (messages)
    24  type Executor interface {
    25  	HandleCommand(*Session, *common.Message) error
    26  	Disconnect(*Session) error
    27  }
    28  
    29  // Session represents active client
    30  type Session struct {
    31  	conn          Connection
    32  	uid           string
    33  	encoder       encoders.Encoder
    34  	executor      Executor
    35  	metrics       metrics.Instrumenter
    36  	env           *common.SessionEnv
    37  	subscriptions *SubscriptionState
    38  	closed        bool
    39  
    40  	// Defines if we should perform Disconnect RPC for this session
    41  	disconnectInterest bool
    42  
    43  	// Main mutex (for read/write and important session updates)
    44  	mu sync.Mutex
    45  	// Mutex for protocol-related state (env, subscriptions)
    46  	smu sync.Mutex
    47  
    48  	sendCh chan *ws.SentFrame
    49  
    50  	pingTimer    *time.Timer
    51  	pingInterval time.Duration
    52  
    53  	pingTimestampPrecision string
    54  
    55  	handshakeDeadline time.Time
    56  
    57  	pongTimeout time.Duration
    58  	pongTimer   *time.Timer
    59  
    60  	resumable bool
    61  	prevSid   string
    62  
    63  	Connected bool
    64  	// Could be used to store arbitrary data within a session
    65  	InternalState map[string]interface{}
    66  	Log           *slog.Logger
    67  }
    68  
    69  type SessionOption = func(*Session)
    70  
    71  // WithPingInterval allows to set a custom ping interval for a session
    72  // or disable pings at all (by passing 0)
    73  func WithPingInterval(interval time.Duration) SessionOption {
    74  	return func(s *Session) {
    75  		s.pingInterval = interval
    76  	}
    77  }
    78  
    79  // WithPingPrecision allows to configure precision for timestamps attached to pings
    80  func WithPingPrecision(val string) SessionOption {
    81  	return func(s *Session) {
    82  		s.pingTimestampPrecision = val
    83  	}
    84  }
    85  
    86  // WithEncoder allows to set a custom encoder for a session
    87  func WithEncoder(enc encoders.Encoder) SessionOption {
    88  	return func(s *Session) {
    89  		s.encoder = enc
    90  	}
    91  }
    92  
    93  // WithExecutor allows to set a custom executor for a session
    94  func WithExecutor(ex Executor) SessionOption {
    95  	return func(s *Session) {
    96  		s.executor = ex
    97  	}
    98  }
    99  
   100  // WithHandshakeMessageDeadline allows to set a custom deadline for handshake messages.
   101  // This option also indicates that we MUST NOT perform Authenticate on connect.
   102  func WithHandshakeMessageDeadline(deadline time.Time) SessionOption {
   103  	return func(s *Session) {
   104  		s.handshakeDeadline = deadline
   105  	}
   106  }
   107  
   108  // WithMetrics allows to set a custom metrics instrumenter for a session
   109  func WithMetrics(m metrics.Instrumenter) SessionOption {
   110  	return func(s *Session) {
   111  		s.metrics = m
   112  	}
   113  }
   114  
   115  // WithResumable allows marking session as resumable (so we store its state in cache)
   116  func WithResumable(val bool) SessionOption {
   117  	return func(s *Session) {
   118  		s.resumable = val
   119  	}
   120  }
   121  
   122  // WithPrevSID allows providing the previous session ID to restore from
   123  func WithPrevSID(sid string) SessionOption {
   124  	return func(s *Session) {
   125  		s.prevSid = sid
   126  	}
   127  }
   128  
   129  // WithPongTimeout allows to set a custom pong timeout for a session
   130  func WithPongTimeout(timeout time.Duration) SessionOption {
   131  	return func(s *Session) {
   132  		s.pongTimeout = timeout
   133  	}
   134  }
   135  
   136  // NewSession build a new Session struct from ws connetion and http request
   137  func NewSession(node *Node, conn Connection, url string, headers *map[string]string, uid string, opts ...SessionOption) *Session {
   138  	session := &Session{
   139  		conn:                   conn,
   140  		metrics:                node.metrics,
   141  		env:                    common.NewSessionEnv(url, headers),
   142  		subscriptions:          NewSubscriptionState(),
   143  		sendCh:                 make(chan *ws.SentFrame, 256),
   144  		closed:                 false,
   145  		Connected:              false,
   146  		pingInterval:           time.Duration(node.config.PingInterval) * time.Second,
   147  		pingTimestampPrecision: node.config.PingTimestampPrecision,
   148  		// Use JSON by default
   149  		encoder: encoders.JSON{},
   150  		// Use Action Cable executor by default (implemented by node)
   151  		executor: node,
   152  	}
   153  
   154  	session.uid = uid
   155  
   156  	ctx := node.log.With("sid", session.uid)
   157  
   158  	session.Log = ctx
   159  
   160  	for _, opt := range opts {
   161  		opt(session)
   162  	}
   163  
   164  	if session.pingInterval > 0 {
   165  		session.startPing()
   166  	}
   167  
   168  	if !session.handshakeDeadline.IsZero() {
   169  		val := time.Until(session.handshakeDeadline)
   170  		time.AfterFunc(val, session.maybeDisconnectIdle)
   171  	}
   172  
   173  	go session.SendMessages()
   174  
   175  	return session
   176  }
   177  
   178  func (s *Session) GetEnv() *common.SessionEnv {
   179  	return s.env
   180  }
   181  
   182  func (s *Session) SetEnv(env *common.SessionEnv) {
   183  	s.env = env
   184  }
   185  
   186  func (s *Session) UnderlyingConn() Connection {
   187  	return s.conn
   188  }
   189  
   190  func (s *Session) AuthenticateOnConnect() bool {
   191  	return s.handshakeDeadline.IsZero()
   192  }
   193  
   194  func (s *Session) IsConnected() bool {
   195  	s.mu.Lock()
   196  	defer s.mu.Unlock()
   197  
   198  	return s.Connected
   199  }
   200  
   201  func (s *Session) IsResumeable() bool {
   202  	return s.resumable
   203  }
   204  
   205  func (s *Session) maybeDisconnectIdle() {
   206  	s.mu.Lock()
   207  
   208  	if s.Connected {
   209  		s.mu.Unlock()
   210  		return
   211  	}
   212  
   213  	s.mu.Unlock()
   214  
   215  	s.Log.Warn("disconnecting idle session")
   216  
   217  	s.Send(common.NewDisconnectMessage(common.IDLE_TIMEOUT_REASON, false))
   218  	s.Disconnect("Idle Timeout", ws.CloseNormalClosure)
   219  }
   220  
   221  func (s *Session) GetID() string {
   222  	return s.uid
   223  }
   224  
   225  func (s *Session) SetID(id string) {
   226  	s.uid = id
   227  }
   228  
   229  func (s *Session) GetIdentifiers() string {
   230  	return s.env.Identifiers
   231  }
   232  
   233  func (s *Session) SetIdentifiers(ids string) {
   234  	s.env.Identifiers = ids
   235  }
   236  
   237  // Merge connection and channel states into current env.
   238  // This method locks the state for writing (so, goroutine-safe)
   239  func (s *Session) MergeEnv(env *common.SessionEnv) {
   240  	s.smu.Lock()
   241  	defer s.smu.Unlock()
   242  
   243  	if env.ConnectionState != nil {
   244  		s.env.MergeConnectionState(env.ConnectionState)
   245  	}
   246  
   247  	if env.ChannelStates != nil {
   248  		states := *env.ChannelStates
   249  		for id, state := range states { // #nosec
   250  			s.env.MergeChannelState(id, &state)
   251  		}
   252  	}
   253  }
   254  
   255  // WriteInternalState
   256  func (s *Session) WriteInternalState(key string, val interface{}) {
   257  	s.mu.Lock()
   258  	defer s.mu.Unlock()
   259  
   260  	if s.InternalState == nil {
   261  		s.InternalState = make(map[string]interface{})
   262  	}
   263  
   264  	s.InternalState[key] = val
   265  }
   266  
   267  // ReadInternalState reads internal state value by key
   268  func (s *Session) ReadInternalState(key string) (interface{}, bool) {
   269  	s.mu.Lock()
   270  	defer s.mu.Unlock()
   271  
   272  	if s.InternalState == nil {
   273  		return nil, false
   274  	}
   275  
   276  	val, ok := s.InternalState[key]
   277  
   278  	return val, ok
   279  }
   280  
   281  func (s *Session) IsDisconnectable() bool {
   282  	s.mu.Lock()
   283  	defer s.mu.Unlock()
   284  
   285  	return s.disconnectInterest
   286  }
   287  
   288  func (s *Session) MarkDisconnectable(val bool) {
   289  	s.mu.Lock()
   290  	defer s.mu.Unlock()
   291  
   292  	s.disconnectInterest = s.disconnectInterest || val
   293  }
   294  
   295  // Serve enters a loop to read incoming data
   296  func (s *Session) Serve(callback func()) error {
   297  	go func() {
   298  		defer callback()
   299  
   300  		for {
   301  			if s.IsClosed() {
   302  				return
   303  			}
   304  
   305  			message, err := s.conn.Read()
   306  
   307  			if err != nil {
   308  				if ws.IsCloseError(err) {
   309  					s.Log.Debug("WebSocket closed", "error", err)
   310  					s.disconnectNow("Read closed", ws.CloseNormalClosure)
   311  				} else {
   312  					s.Log.Debug("WebSocket close error", "error", err)
   313  					s.disconnectNow("Read failed", ws.CloseAbnormalClosure)
   314  				}
   315  				return
   316  			}
   317  
   318  			err = s.ReadMessage(message)
   319  
   320  			if err != nil {
   321  				s.Log.Debug("WebSocket read failed", "error", err)
   322  				return
   323  			}
   324  		}
   325  	}()
   326  
   327  	return nil
   328  }
   329  
   330  // SendMessages waits for incoming messages and send them to the client connection
   331  func (s *Session) SendMessages() {
   332  	for message := range s.sendCh {
   333  		err := s.writeFrame(message)
   334  
   335  		if message.FrameType == ws.CloseFrame {
   336  			s.disconnectNow("Close frame sent", ws.CloseNormalClosure)
   337  			return
   338  		}
   339  
   340  		if err != nil {
   341  			s.metrics.CounterIncrement(metricsFailedSent)
   342  			s.disconnectNow("Write Failed", ws.CloseAbnormalClosure)
   343  			return
   344  		}
   345  
   346  		s.metrics.CounterIncrement(metricsSentMsg)
   347  	}
   348  }
   349  
   350  // ReadMessage reads messages from ws connection and send them to node
   351  func (s *Session) ReadMessage(message []byte) error {
   352  	s.metrics.CounterAdd(metricsDataReceived, uint64(len(message)))
   353  
   354  	command, err := s.decodeMessage(message)
   355  
   356  	if err != nil {
   357  		s.metrics.CounterIncrement(metricsFailedCommandReceived)
   358  		return err
   359  	}
   360  
   361  	if command == nil {
   362  		return nil
   363  	}
   364  
   365  	s.metrics.CounterIncrement(metricsReceivedMsg)
   366  
   367  	if err := s.executor.HandleCommand(s, command); err != nil {
   368  		s.metrics.CounterIncrement(metricsFailedCommandReceived)
   369  		s.Log.Warn("failed to handle incoming message", "data", logger.CompactValue(message), "error", err)
   370  	}
   371  
   372  	return nil
   373  }
   374  
   375  // Send schedules a data transmission
   376  func (s *Session) Send(msg encoders.EncodedMessage) {
   377  	if b, err := s.encodeMessage(msg); err == nil {
   378  		if b != nil {
   379  			s.sendFrame(b)
   380  		}
   381  	} else {
   382  		s.Log.Warn("failed to encode message", "data", msg, "error", err)
   383  	}
   384  }
   385  
   386  // SendJSONTransmission is used to propagate the direct transmission to the client
   387  // (from RPC call result)
   388  func (s *Session) SendJSONTransmission(msg string) {
   389  	if b, err := s.encodeTransmission(msg); err == nil {
   390  		if b != nil {
   391  			s.sendFrame(b)
   392  		}
   393  	} else {
   394  		s.Log.Warn("failed to encode transmission", "data", logger.CompactValue(msg), "error", err)
   395  	}
   396  }
   397  
   398  // Disconnect schedules connection disconnect
   399  func (s *Session) Disconnect(reason string, code int) {
   400  	s.sendClose(reason, code)
   401  	s.close()
   402  	s.disconnectFromNode()
   403  }
   404  
   405  func (s *Session) DisconnectWithMessage(msg encoders.EncodedMessage, code string) {
   406  	s.Send(msg)
   407  
   408  	reason := ""
   409  	wsCode := ws.CloseNormalClosure
   410  
   411  	switch code {
   412  	case common.SERVER_RESTART_REASON:
   413  		reason = "Server restart"
   414  		wsCode = ws.CloseGoingAway
   415  	case common.REMOTE_DISCONNECT_REASON:
   416  		reason = "Closed remotely"
   417  	}
   418  
   419  	s.Disconnect(reason, wsCode)
   420  }
   421  
   422  // String returns session string representation (for %v in Printf-like functions)
   423  func (s *Session) String() string {
   424  	return fmt.Sprintf("Session(%s)", s.uid)
   425  }
   426  
   427  type cacheEntry struct {
   428  	Identifiers     string                       `json:"ids"`
   429  	Subscriptions   map[string][]string          `json:"subs"`
   430  	ConnectionState map[string]string            `json:"cstate"`
   431  	ChannelsState   map[string]map[string]string `json:"istate"`
   432  	Disconnectable  bool
   433  }
   434  
   435  func (s *Session) ToCacheEntry() ([]byte, error) {
   436  	s.smu.Lock()
   437  	defer s.smu.Unlock()
   438  
   439  	entry := cacheEntry{
   440  		Identifiers:     s.GetIdentifiers(),
   441  		Subscriptions:   s.subscriptions.ToMap(),
   442  		ConnectionState: *s.env.ConnectionState,
   443  		ChannelsState:   *s.env.ChannelStates,
   444  		Disconnectable:  s.disconnectInterest,
   445  	}
   446  
   447  	return json.Marshal(&entry)
   448  }
   449  
   450  func (s *Session) RestoreFromCache(cached []byte) error {
   451  	var entry cacheEntry
   452  
   453  	err := json.Unmarshal(cached, &entry)
   454  
   455  	if err != nil {
   456  		return err
   457  	}
   458  
   459  	s.smu.Lock()
   460  	defer s.smu.Unlock()
   461  
   462  	s.MarkDisconnectable(entry.Disconnectable)
   463  	s.SetIdentifiers(entry.Identifiers)
   464  	s.env.MergeConnectionState(&entry.ConnectionState)
   465  
   466  	for k := range entry.ChannelsState {
   467  		v := entry.ChannelsState[k]
   468  		s.env.MergeChannelState(k, &v)
   469  	}
   470  
   471  	for k, v := range entry.Subscriptions {
   472  		s.subscriptions.AddChannel(k)
   473  
   474  		for _, stream := range v {
   475  			s.subscriptions.AddChannelStream(k, stream)
   476  		}
   477  	}
   478  
   479  	return nil
   480  }
   481  
   482  func (s *Session) PrevSid() string {
   483  	return s.prevSid
   484  }
   485  
   486  func (s *Session) disconnectFromNode() {
   487  	s.mu.Lock()
   488  	if s.Connected {
   489  		defer s.executor.Disconnect(s) // nolint:errcheck
   490  	}
   491  	s.Connected = false
   492  	s.mu.Unlock()
   493  }
   494  
   495  func (s *Session) DisconnectNow(reason string, code int) {
   496  	s.disconnectNow(reason, code)
   497  }
   498  
   499  func (s *Session) disconnectNow(reason string, code int) {
   500  	s.mu.Lock()
   501  	if !s.Connected {
   502  		s.mu.Unlock()
   503  		return
   504  	}
   505  	s.mu.Unlock()
   506  
   507  	s.disconnectFromNode()
   508  	s.writeFrame(&ws.SentFrame{ // nolint:errcheck
   509  		FrameType:   ws.CloseFrame,
   510  		CloseReason: reason,
   511  		CloseCode:   code,
   512  	})
   513  
   514  	s.mu.Lock()
   515  	if s.sendCh != nil {
   516  		close(s.sendCh)
   517  		s.sendCh = nil
   518  	}
   519  	s.mu.Unlock()
   520  
   521  	s.close()
   522  }
   523  
   524  func (s *Session) close() {
   525  	s.mu.Lock()
   526  
   527  	if s.closed {
   528  		s.mu.Unlock()
   529  		return
   530  	}
   531  
   532  	s.closed = true
   533  	defer s.mu.Unlock()
   534  
   535  	if s.pingTimer != nil {
   536  		s.pingTimer.Stop()
   537  	}
   538  
   539  	if s.pongTimer != nil {
   540  		s.pongTimer.Stop()
   541  	}
   542  }
   543  
   544  func (s *Session) IsClosed() bool {
   545  	s.mu.Lock()
   546  	defer s.mu.Unlock()
   547  
   548  	return s.closed
   549  }
   550  
   551  func (s *Session) sendClose(reason string, code int) {
   552  	s.sendFrame(&ws.SentFrame{
   553  		FrameType:   ws.CloseFrame,
   554  		CloseReason: reason,
   555  		CloseCode:   code,
   556  	})
   557  }
   558  
   559  func (s *Session) sendFrame(message *ws.SentFrame) {
   560  	s.mu.Lock()
   561  
   562  	if s.sendCh == nil {
   563  		s.mu.Unlock()
   564  		return
   565  	}
   566  
   567  	select {
   568  	case s.sendCh <- message:
   569  	default:
   570  		if s.sendCh != nil {
   571  			close(s.sendCh)
   572  			defer s.Disconnect("Write failed", ws.CloseAbnormalClosure)
   573  		}
   574  
   575  		s.sendCh = nil
   576  	}
   577  
   578  	s.mu.Unlock()
   579  }
   580  
   581  func (s *Session) writeFrame(message *ws.SentFrame) error {
   582  	return s.writeFrameWithDeadline(message, time.Now().Add(writeWait))
   583  }
   584  
   585  func (s *Session) writeFrameWithDeadline(message *ws.SentFrame, deadline time.Time) error {
   586  	s.metrics.CounterAdd(metricsDataSent, uint64(len(message.Payload)))
   587  
   588  	switch message.FrameType {
   589  	case ws.TextFrame:
   590  		s.mu.Lock()
   591  		defer s.mu.Unlock()
   592  
   593  		err := s.conn.Write(message.Payload, deadline)
   594  		return err
   595  	case ws.BinaryFrame:
   596  		s.mu.Lock()
   597  		defer s.mu.Unlock()
   598  
   599  		err := s.conn.WriteBinary(message.Payload, deadline)
   600  
   601  		return err
   602  	case ws.CloseFrame:
   603  		s.conn.Close(message.CloseCode, message.CloseReason)
   604  		return errors.New("closed")
   605  	default:
   606  		s.Log.Error("unknown frame type", "msg", message)
   607  		return errors.New("unknown frame type")
   608  	}
   609  }
   610  
   611  func (s *Session) sendPing() {
   612  	s.mu.Lock()
   613  	if s.closed {
   614  		s.mu.Unlock()
   615  		return
   616  	}
   617  	s.mu.Unlock()
   618  
   619  	deadline := time.Now().Add(s.pingInterval / 2)
   620  
   621  	b, err := s.encodeMessage(newPingMessage(s.pingTimestampPrecision))
   622  
   623  	if err != nil {
   624  		s.Log.Error("failed to encode ping message", "error", err)
   625  	} else if b != nil {
   626  		err = s.writeFrameWithDeadline(b, deadline)
   627  	}
   628  
   629  	if err != nil {
   630  		s.Disconnect("Ping failed", ws.CloseAbnormalClosure)
   631  		return
   632  	}
   633  
   634  	s.addPing()
   635  }
   636  
   637  func (s *Session) startPing() {
   638  	s.mu.Lock()
   639  	defer s.mu.Unlock()
   640  
   641  	// Calculate the minimum and maximum durations
   642  	minDuration := s.pingInterval / 2
   643  	maxDuration := s.pingInterval * 3 / 2
   644  
   645  	initialInterval := time.Duration(rand.Int63n(int64(maxDuration-minDuration))) + minDuration // nolint:gosec
   646  
   647  	s.pingTimer = time.AfterFunc(initialInterval, s.sendPing)
   648  
   649  	if s.pongTimeout > 0 {
   650  		s.pongTimer = time.AfterFunc(s.pongTimeout+initialInterval, s.handleNoPong)
   651  	}
   652  }
   653  
   654  func (s *Session) addPing() {
   655  	s.mu.Lock()
   656  	defer s.mu.Unlock()
   657  
   658  	s.pingTimer = time.AfterFunc(s.pingInterval, s.sendPing)
   659  }
   660  
   661  func newPingMessage(format string) *common.PingMessage {
   662  	var ts int64
   663  
   664  	switch format {
   665  	case "ns":
   666  		ts = time.Now().UnixNano()
   667  	case "ms":
   668  		ts = time.Now().UnixNano() / int64(time.Millisecond)
   669  	default:
   670  		ts = time.Now().Unix()
   671  	}
   672  
   673  	return (&common.PingMessage{Type: "ping", Message: ts})
   674  }
   675  
   676  func (s *Session) handlePong(msg *common.Message) {
   677  	s.mu.Lock()
   678  	defer s.mu.Unlock()
   679  
   680  	if s.pongTimer == nil {
   681  		s.Log.Debug("unexpected pong received")
   682  		return
   683  	}
   684  
   685  	s.pongTimer.Reset(s.pongTimeout)
   686  }
   687  
   688  func (s *Session) handleNoPong() {
   689  	s.mu.Lock()
   690  
   691  	if !s.Connected {
   692  		s.mu.Unlock()
   693  		return
   694  	}
   695  
   696  	s.mu.Unlock()
   697  
   698  	s.Log.Warn("disconnecting session due to no pongs")
   699  
   700  	s.Send(common.NewDisconnectMessage(common.NO_PONG_REASON, true)) // nolint:errcheck
   701  	s.Disconnect("No Pong", ws.CloseNormalClosure)
   702  }
   703  
   704  func (s *Session) encodeMessage(msg encoders.EncodedMessage) (*ws.SentFrame, error) {
   705  	if cm, ok := msg.(*encoders.CachedEncodedMessage); ok {
   706  		return cm.Fetch(
   707  			s.encoder.ID(),
   708  			func(m encoders.EncodedMessage) (*ws.SentFrame, error) {
   709  				return s.encoder.Encode(m)
   710  			})
   711  	}
   712  
   713  	return s.encoder.Encode(msg)
   714  }
   715  
   716  func (s *Session) encodeTransmission(msg string) (*ws.SentFrame, error) {
   717  	return s.encoder.EncodeTransmission(msg)
   718  }
   719  
   720  func (s *Session) decodeMessage(raw []byte) (*common.Message, error) {
   721  	return s.encoder.Decode(raw)
   722  }