github.com/anycable/anycable-go@v1.5.1/node/session.go (about) 1 package node 2 3 import ( 4 "encoding/json" 5 "errors" 6 "fmt" 7 "log/slog" 8 "math/rand" 9 "sync" 10 "time" 11 12 "github.com/anycable/anycable-go/common" 13 "github.com/anycable/anycable-go/encoders" 14 "github.com/anycable/anycable-go/logger" 15 "github.com/anycable/anycable-go/metrics" 16 "github.com/anycable/anycable-go/ws" 17 ) 18 19 const ( 20 writeWait = 10 * time.Second 21 ) 22 23 // Executor handles incoming commands (messages) 24 type Executor interface { 25 HandleCommand(*Session, *common.Message) error 26 Disconnect(*Session) error 27 } 28 29 // Session represents active client 30 type Session struct { 31 conn Connection 32 uid string 33 encoder encoders.Encoder 34 executor Executor 35 metrics metrics.Instrumenter 36 env *common.SessionEnv 37 subscriptions *SubscriptionState 38 closed bool 39 40 // Defines if we should perform Disconnect RPC for this session 41 disconnectInterest bool 42 43 // Main mutex (for read/write and important session updates) 44 mu sync.Mutex 45 // Mutex for protocol-related state (env, subscriptions) 46 smu sync.Mutex 47 48 sendCh chan *ws.SentFrame 49 50 pingTimer *time.Timer 51 pingInterval time.Duration 52 53 pingTimestampPrecision string 54 55 handshakeDeadline time.Time 56 57 pongTimeout time.Duration 58 pongTimer *time.Timer 59 60 resumable bool 61 prevSid string 62 63 Connected bool 64 // Could be used to store arbitrary data within a session 65 InternalState map[string]interface{} 66 Log *slog.Logger 67 } 68 69 type SessionOption = func(*Session) 70 71 // WithPingInterval allows to set a custom ping interval for a session 72 // or disable pings at all (by passing 0) 73 func WithPingInterval(interval time.Duration) SessionOption { 74 return func(s *Session) { 75 s.pingInterval = interval 76 } 77 } 78 79 // WithPingPrecision allows to configure precision for timestamps attached to pings 80 func WithPingPrecision(val string) SessionOption { 81 return func(s *Session) { 82 s.pingTimestampPrecision = val 83 } 84 } 85 86 // WithEncoder allows to set a custom encoder for a session 87 func WithEncoder(enc encoders.Encoder) SessionOption { 88 return func(s *Session) { 89 s.encoder = enc 90 } 91 } 92 93 // WithExecutor allows to set a custom executor for a session 94 func WithExecutor(ex Executor) SessionOption { 95 return func(s *Session) { 96 s.executor = ex 97 } 98 } 99 100 // WithHandshakeMessageDeadline allows to set a custom deadline for handshake messages. 101 // This option also indicates that we MUST NOT perform Authenticate on connect. 102 func WithHandshakeMessageDeadline(deadline time.Time) SessionOption { 103 return func(s *Session) { 104 s.handshakeDeadline = deadline 105 } 106 } 107 108 // WithMetrics allows to set a custom metrics instrumenter for a session 109 func WithMetrics(m metrics.Instrumenter) SessionOption { 110 return func(s *Session) { 111 s.metrics = m 112 } 113 } 114 115 // WithResumable allows marking session as resumable (so we store its state in cache) 116 func WithResumable(val bool) SessionOption { 117 return func(s *Session) { 118 s.resumable = val 119 } 120 } 121 122 // WithPrevSID allows providing the previous session ID to restore from 123 func WithPrevSID(sid string) SessionOption { 124 return func(s *Session) { 125 s.prevSid = sid 126 } 127 } 128 129 // WithPongTimeout allows to set a custom pong timeout for a session 130 func WithPongTimeout(timeout time.Duration) SessionOption { 131 return func(s *Session) { 132 s.pongTimeout = timeout 133 } 134 } 135 136 // NewSession build a new Session struct from ws connetion and http request 137 func NewSession(node *Node, conn Connection, url string, headers *map[string]string, uid string, opts ...SessionOption) *Session { 138 session := &Session{ 139 conn: conn, 140 metrics: node.metrics, 141 env: common.NewSessionEnv(url, headers), 142 subscriptions: NewSubscriptionState(), 143 sendCh: make(chan *ws.SentFrame, 256), 144 closed: false, 145 Connected: false, 146 pingInterval: time.Duration(node.config.PingInterval) * time.Second, 147 pingTimestampPrecision: node.config.PingTimestampPrecision, 148 // Use JSON by default 149 encoder: encoders.JSON{}, 150 // Use Action Cable executor by default (implemented by node) 151 executor: node, 152 } 153 154 session.uid = uid 155 156 ctx := node.log.With("sid", session.uid) 157 158 session.Log = ctx 159 160 for _, opt := range opts { 161 opt(session) 162 } 163 164 if session.pingInterval > 0 { 165 session.startPing() 166 } 167 168 if !session.handshakeDeadline.IsZero() { 169 val := time.Until(session.handshakeDeadline) 170 time.AfterFunc(val, session.maybeDisconnectIdle) 171 } 172 173 go session.SendMessages() 174 175 return session 176 } 177 178 func (s *Session) GetEnv() *common.SessionEnv { 179 return s.env 180 } 181 182 func (s *Session) SetEnv(env *common.SessionEnv) { 183 s.env = env 184 } 185 186 func (s *Session) UnderlyingConn() Connection { 187 return s.conn 188 } 189 190 func (s *Session) AuthenticateOnConnect() bool { 191 return s.handshakeDeadline.IsZero() 192 } 193 194 func (s *Session) IsConnected() bool { 195 s.mu.Lock() 196 defer s.mu.Unlock() 197 198 return s.Connected 199 } 200 201 func (s *Session) IsResumeable() bool { 202 return s.resumable 203 } 204 205 func (s *Session) maybeDisconnectIdle() { 206 s.mu.Lock() 207 208 if s.Connected { 209 s.mu.Unlock() 210 return 211 } 212 213 s.mu.Unlock() 214 215 s.Log.Warn("disconnecting idle session") 216 217 s.Send(common.NewDisconnectMessage(common.IDLE_TIMEOUT_REASON, false)) 218 s.Disconnect("Idle Timeout", ws.CloseNormalClosure) 219 } 220 221 func (s *Session) GetID() string { 222 return s.uid 223 } 224 225 func (s *Session) SetID(id string) { 226 s.uid = id 227 } 228 229 func (s *Session) GetIdentifiers() string { 230 return s.env.Identifiers 231 } 232 233 func (s *Session) SetIdentifiers(ids string) { 234 s.env.Identifiers = ids 235 } 236 237 // Merge connection and channel states into current env. 238 // This method locks the state for writing (so, goroutine-safe) 239 func (s *Session) MergeEnv(env *common.SessionEnv) { 240 s.smu.Lock() 241 defer s.smu.Unlock() 242 243 if env.ConnectionState != nil { 244 s.env.MergeConnectionState(env.ConnectionState) 245 } 246 247 if env.ChannelStates != nil { 248 states := *env.ChannelStates 249 for id, state := range states { // #nosec 250 s.env.MergeChannelState(id, &state) 251 } 252 } 253 } 254 255 // WriteInternalState 256 func (s *Session) WriteInternalState(key string, val interface{}) { 257 s.mu.Lock() 258 defer s.mu.Unlock() 259 260 if s.InternalState == nil { 261 s.InternalState = make(map[string]interface{}) 262 } 263 264 s.InternalState[key] = val 265 } 266 267 // ReadInternalState reads internal state value by key 268 func (s *Session) ReadInternalState(key string) (interface{}, bool) { 269 s.mu.Lock() 270 defer s.mu.Unlock() 271 272 if s.InternalState == nil { 273 return nil, false 274 } 275 276 val, ok := s.InternalState[key] 277 278 return val, ok 279 } 280 281 func (s *Session) IsDisconnectable() bool { 282 s.mu.Lock() 283 defer s.mu.Unlock() 284 285 return s.disconnectInterest 286 } 287 288 func (s *Session) MarkDisconnectable(val bool) { 289 s.mu.Lock() 290 defer s.mu.Unlock() 291 292 s.disconnectInterest = s.disconnectInterest || val 293 } 294 295 // Serve enters a loop to read incoming data 296 func (s *Session) Serve(callback func()) error { 297 go func() { 298 defer callback() 299 300 for { 301 if s.IsClosed() { 302 return 303 } 304 305 message, err := s.conn.Read() 306 307 if err != nil { 308 if ws.IsCloseError(err) { 309 s.Log.Debug("WebSocket closed", "error", err) 310 s.disconnectNow("Read closed", ws.CloseNormalClosure) 311 } else { 312 s.Log.Debug("WebSocket close error", "error", err) 313 s.disconnectNow("Read failed", ws.CloseAbnormalClosure) 314 } 315 return 316 } 317 318 err = s.ReadMessage(message) 319 320 if err != nil { 321 s.Log.Debug("WebSocket read failed", "error", err) 322 return 323 } 324 } 325 }() 326 327 return nil 328 } 329 330 // SendMessages waits for incoming messages and send them to the client connection 331 func (s *Session) SendMessages() { 332 for message := range s.sendCh { 333 err := s.writeFrame(message) 334 335 if message.FrameType == ws.CloseFrame { 336 s.disconnectNow("Close frame sent", ws.CloseNormalClosure) 337 return 338 } 339 340 if err != nil { 341 s.metrics.CounterIncrement(metricsFailedSent) 342 s.disconnectNow("Write Failed", ws.CloseAbnormalClosure) 343 return 344 } 345 346 s.metrics.CounterIncrement(metricsSentMsg) 347 } 348 } 349 350 // ReadMessage reads messages from ws connection and send them to node 351 func (s *Session) ReadMessage(message []byte) error { 352 s.metrics.CounterAdd(metricsDataReceived, uint64(len(message))) 353 354 command, err := s.decodeMessage(message) 355 356 if err != nil { 357 s.metrics.CounterIncrement(metricsFailedCommandReceived) 358 return err 359 } 360 361 if command == nil { 362 return nil 363 } 364 365 s.metrics.CounterIncrement(metricsReceivedMsg) 366 367 if err := s.executor.HandleCommand(s, command); err != nil { 368 s.metrics.CounterIncrement(metricsFailedCommandReceived) 369 s.Log.Warn("failed to handle incoming message", "data", logger.CompactValue(message), "error", err) 370 } 371 372 return nil 373 } 374 375 // Send schedules a data transmission 376 func (s *Session) Send(msg encoders.EncodedMessage) { 377 if b, err := s.encodeMessage(msg); err == nil { 378 if b != nil { 379 s.sendFrame(b) 380 } 381 } else { 382 s.Log.Warn("failed to encode message", "data", msg, "error", err) 383 } 384 } 385 386 // SendJSONTransmission is used to propagate the direct transmission to the client 387 // (from RPC call result) 388 func (s *Session) SendJSONTransmission(msg string) { 389 if b, err := s.encodeTransmission(msg); err == nil { 390 if b != nil { 391 s.sendFrame(b) 392 } 393 } else { 394 s.Log.Warn("failed to encode transmission", "data", logger.CompactValue(msg), "error", err) 395 } 396 } 397 398 // Disconnect schedules connection disconnect 399 func (s *Session) Disconnect(reason string, code int) { 400 s.sendClose(reason, code) 401 s.close() 402 s.disconnectFromNode() 403 } 404 405 func (s *Session) DisconnectWithMessage(msg encoders.EncodedMessage, code string) { 406 s.Send(msg) 407 408 reason := "" 409 wsCode := ws.CloseNormalClosure 410 411 switch code { 412 case common.SERVER_RESTART_REASON: 413 reason = "Server restart" 414 wsCode = ws.CloseGoingAway 415 case common.REMOTE_DISCONNECT_REASON: 416 reason = "Closed remotely" 417 } 418 419 s.Disconnect(reason, wsCode) 420 } 421 422 // String returns session string representation (for %v in Printf-like functions) 423 func (s *Session) String() string { 424 return fmt.Sprintf("Session(%s)", s.uid) 425 } 426 427 type cacheEntry struct { 428 Identifiers string `json:"ids"` 429 Subscriptions map[string][]string `json:"subs"` 430 ConnectionState map[string]string `json:"cstate"` 431 ChannelsState map[string]map[string]string `json:"istate"` 432 Disconnectable bool 433 } 434 435 func (s *Session) ToCacheEntry() ([]byte, error) { 436 s.smu.Lock() 437 defer s.smu.Unlock() 438 439 entry := cacheEntry{ 440 Identifiers: s.GetIdentifiers(), 441 Subscriptions: s.subscriptions.ToMap(), 442 ConnectionState: *s.env.ConnectionState, 443 ChannelsState: *s.env.ChannelStates, 444 Disconnectable: s.disconnectInterest, 445 } 446 447 return json.Marshal(&entry) 448 } 449 450 func (s *Session) RestoreFromCache(cached []byte) error { 451 var entry cacheEntry 452 453 err := json.Unmarshal(cached, &entry) 454 455 if err != nil { 456 return err 457 } 458 459 s.smu.Lock() 460 defer s.smu.Unlock() 461 462 s.MarkDisconnectable(entry.Disconnectable) 463 s.SetIdentifiers(entry.Identifiers) 464 s.env.MergeConnectionState(&entry.ConnectionState) 465 466 for k := range entry.ChannelsState { 467 v := entry.ChannelsState[k] 468 s.env.MergeChannelState(k, &v) 469 } 470 471 for k, v := range entry.Subscriptions { 472 s.subscriptions.AddChannel(k) 473 474 for _, stream := range v { 475 s.subscriptions.AddChannelStream(k, stream) 476 } 477 } 478 479 return nil 480 } 481 482 func (s *Session) PrevSid() string { 483 return s.prevSid 484 } 485 486 func (s *Session) disconnectFromNode() { 487 s.mu.Lock() 488 if s.Connected { 489 defer s.executor.Disconnect(s) // nolint:errcheck 490 } 491 s.Connected = false 492 s.mu.Unlock() 493 } 494 495 func (s *Session) DisconnectNow(reason string, code int) { 496 s.disconnectNow(reason, code) 497 } 498 499 func (s *Session) disconnectNow(reason string, code int) { 500 s.mu.Lock() 501 if !s.Connected { 502 s.mu.Unlock() 503 return 504 } 505 s.mu.Unlock() 506 507 s.disconnectFromNode() 508 s.writeFrame(&ws.SentFrame{ // nolint:errcheck 509 FrameType: ws.CloseFrame, 510 CloseReason: reason, 511 CloseCode: code, 512 }) 513 514 s.mu.Lock() 515 if s.sendCh != nil { 516 close(s.sendCh) 517 s.sendCh = nil 518 } 519 s.mu.Unlock() 520 521 s.close() 522 } 523 524 func (s *Session) close() { 525 s.mu.Lock() 526 527 if s.closed { 528 s.mu.Unlock() 529 return 530 } 531 532 s.closed = true 533 defer s.mu.Unlock() 534 535 if s.pingTimer != nil { 536 s.pingTimer.Stop() 537 } 538 539 if s.pongTimer != nil { 540 s.pongTimer.Stop() 541 } 542 } 543 544 func (s *Session) IsClosed() bool { 545 s.mu.Lock() 546 defer s.mu.Unlock() 547 548 return s.closed 549 } 550 551 func (s *Session) sendClose(reason string, code int) { 552 s.sendFrame(&ws.SentFrame{ 553 FrameType: ws.CloseFrame, 554 CloseReason: reason, 555 CloseCode: code, 556 }) 557 } 558 559 func (s *Session) sendFrame(message *ws.SentFrame) { 560 s.mu.Lock() 561 562 if s.sendCh == nil { 563 s.mu.Unlock() 564 return 565 } 566 567 select { 568 case s.sendCh <- message: 569 default: 570 if s.sendCh != nil { 571 close(s.sendCh) 572 defer s.Disconnect("Write failed", ws.CloseAbnormalClosure) 573 } 574 575 s.sendCh = nil 576 } 577 578 s.mu.Unlock() 579 } 580 581 func (s *Session) writeFrame(message *ws.SentFrame) error { 582 return s.writeFrameWithDeadline(message, time.Now().Add(writeWait)) 583 } 584 585 func (s *Session) writeFrameWithDeadline(message *ws.SentFrame, deadline time.Time) error { 586 s.metrics.CounterAdd(metricsDataSent, uint64(len(message.Payload))) 587 588 switch message.FrameType { 589 case ws.TextFrame: 590 s.mu.Lock() 591 defer s.mu.Unlock() 592 593 err := s.conn.Write(message.Payload, deadline) 594 return err 595 case ws.BinaryFrame: 596 s.mu.Lock() 597 defer s.mu.Unlock() 598 599 err := s.conn.WriteBinary(message.Payload, deadline) 600 601 return err 602 case ws.CloseFrame: 603 s.conn.Close(message.CloseCode, message.CloseReason) 604 return errors.New("closed") 605 default: 606 s.Log.Error("unknown frame type", "msg", message) 607 return errors.New("unknown frame type") 608 } 609 } 610 611 func (s *Session) sendPing() { 612 s.mu.Lock() 613 if s.closed { 614 s.mu.Unlock() 615 return 616 } 617 s.mu.Unlock() 618 619 deadline := time.Now().Add(s.pingInterval / 2) 620 621 b, err := s.encodeMessage(newPingMessage(s.pingTimestampPrecision)) 622 623 if err != nil { 624 s.Log.Error("failed to encode ping message", "error", err) 625 } else if b != nil { 626 err = s.writeFrameWithDeadline(b, deadline) 627 } 628 629 if err != nil { 630 s.Disconnect("Ping failed", ws.CloseAbnormalClosure) 631 return 632 } 633 634 s.addPing() 635 } 636 637 func (s *Session) startPing() { 638 s.mu.Lock() 639 defer s.mu.Unlock() 640 641 // Calculate the minimum and maximum durations 642 minDuration := s.pingInterval / 2 643 maxDuration := s.pingInterval * 3 / 2 644 645 initialInterval := time.Duration(rand.Int63n(int64(maxDuration-minDuration))) + minDuration // nolint:gosec 646 647 s.pingTimer = time.AfterFunc(initialInterval, s.sendPing) 648 649 if s.pongTimeout > 0 { 650 s.pongTimer = time.AfterFunc(s.pongTimeout+initialInterval, s.handleNoPong) 651 } 652 } 653 654 func (s *Session) addPing() { 655 s.mu.Lock() 656 defer s.mu.Unlock() 657 658 s.pingTimer = time.AfterFunc(s.pingInterval, s.sendPing) 659 } 660 661 func newPingMessage(format string) *common.PingMessage { 662 var ts int64 663 664 switch format { 665 case "ns": 666 ts = time.Now().UnixNano() 667 case "ms": 668 ts = time.Now().UnixNano() / int64(time.Millisecond) 669 default: 670 ts = time.Now().Unix() 671 } 672 673 return (&common.PingMessage{Type: "ping", Message: ts}) 674 } 675 676 func (s *Session) handlePong(msg *common.Message) { 677 s.mu.Lock() 678 defer s.mu.Unlock() 679 680 if s.pongTimer == nil { 681 s.Log.Debug("unexpected pong received") 682 return 683 } 684 685 s.pongTimer.Reset(s.pongTimeout) 686 } 687 688 func (s *Session) handleNoPong() { 689 s.mu.Lock() 690 691 if !s.Connected { 692 s.mu.Unlock() 693 return 694 } 695 696 s.mu.Unlock() 697 698 s.Log.Warn("disconnecting session due to no pongs") 699 700 s.Send(common.NewDisconnectMessage(common.NO_PONG_REASON, true)) // nolint:errcheck 701 s.Disconnect("No Pong", ws.CloseNormalClosure) 702 } 703 704 func (s *Session) encodeMessage(msg encoders.EncodedMessage) (*ws.SentFrame, error) { 705 if cm, ok := msg.(*encoders.CachedEncodedMessage); ok { 706 return cm.Fetch( 707 s.encoder.ID(), 708 func(m encoders.EncodedMessage) (*ws.SentFrame, error) { 709 return s.encoder.Encode(m) 710 }) 711 } 712 713 return s.encoder.Encode(msg) 714 } 715 716 func (s *Session) encodeTransmission(msg string) (*ws.SentFrame, error) { 717 return s.encoder.EncodeTransmission(msg) 718 } 719 720 func (s *Session) decodeMessage(raw []byte) (*common.Message, error) { 721 return s.encoder.Decode(raw) 722 }