github.com/anycable/anycable-go@v1.5.1/hub/hub.go (about) 1 package hub 2 3 import ( 4 "context" 5 "hash/fnv" 6 "log/slog" 7 "sync" 8 9 "github.com/anycable/anycable-go/common" 10 "github.com/anycable/anycable-go/encoders" 11 "github.com/anycable/anycable-go/utils" 12 ) 13 14 type HubSession interface { 15 GetID() string 16 GetIdentifiers() string 17 Send(msg encoders.EncodedMessage) 18 DisconnectWithMessage(msg encoders.EncodedMessage, code string) 19 } 20 21 // HubRegistration represents registration event ("add" or "remove") 22 type HubRegistration struct { 23 event string 24 session HubSession 25 } 26 27 // HubSessionInfo is used to track registered sessions 28 type HubSessionInfo struct { 29 session HubSession 30 // List of stream-identifier pairs 31 streams [][]string 32 } 33 34 func NewHubSessionInfo(session HubSession) *HubSessionInfo { 35 return &HubSessionInfo{ 36 session: session, 37 streams: make([][]string, 0), 38 } 39 } 40 41 func (hs *HubSessionInfo) AddStream(stream string, identifier string) { 42 hs.streams = append(hs.streams, []string{stream, identifier}) 43 } 44 45 func (hs *HubSessionInfo) RemoveStream(stream string, identifier string) { 46 for i, s := range hs.streams { 47 if s[0] == stream && s[1] == identifier { 48 hs.streams = append(hs.streams[:i], hs.streams[i+1:]...) 49 break 50 } 51 } 52 } 53 54 // Hub stores all the sessions and the corresponding subscriptions info 55 type Hub struct { 56 // Gates (=shards) 57 gates []*Gate 58 gatesNum int 59 60 // Registered sessions 61 sessions map[string]*HubSessionInfo 62 63 // Identifiers to session 64 identifiers map[string]map[string]bool 65 66 // Messages for specified stream 67 broadcast chan *common.StreamMessage 68 69 // Remote disconnect commands 70 disconnect chan *common.RemoteDisconnectMessage 71 72 // Register requests from the sessions 73 register chan HubRegistration 74 75 // Control channel to shutdown hub 76 shutdown chan struct{} 77 78 // Synchronization group to wait for gracefully disconnect of all sessions 79 done sync.WaitGroup 80 81 doneFn context.CancelFunc 82 83 // Log context 84 log *slog.Logger 85 86 // go pool 87 pool *utils.GoPool 88 89 // mutex for sessions data tracking 90 mu sync.RWMutex 91 } 92 93 // NewHub builds new hub instance 94 func NewHub(poolSize int, l *slog.Logger) *Hub { 95 ctx, doneFn := context.WithCancel(context.Background()) 96 97 return &Hub{ 98 broadcast: make(chan *common.StreamMessage, 256), 99 disconnect: make(chan *common.RemoteDisconnectMessage, 128), 100 register: make(chan HubRegistration, 2048), 101 sessions: make(map[string]*HubSessionInfo), 102 identifiers: make(map[string]map[string]bool), 103 gates: buildGates(ctx, poolSize, l), 104 gatesNum: poolSize, 105 pool: utils.NewGoPool("remote commands", 256), 106 doneFn: doneFn, 107 shutdown: make(chan struct{}), 108 log: l.With("component", "hub"), 109 } 110 } 111 112 // Run makes hub active 113 func (h *Hub) Run() { 114 h.done.Add(1) 115 for { 116 select { 117 case r := <-h.register: 118 if r.event == "add" { 119 h.AddSession(r.session) 120 } else { 121 h.RemoveSession(r.session) 122 } 123 124 case message := <-h.broadcast: 125 h.broadcastToStream(message) 126 127 case command := <-h.disconnect: 128 h.disconnectSessions(command.Identifier, command.Reconnect) 129 130 case <-h.shutdown: 131 h.done.Done() 132 return 133 } 134 } 135 } 136 137 // RemoveSession enqueues session un-registration 138 func (h *Hub) RemoveSessionLater(s HubSession) { 139 h.register <- HubRegistration{event: "remove", session: s} 140 } 141 142 // Broadcast enqueues data broadcasting to a stream 143 func (h *Hub) Broadcast(stream string, data string) { 144 h.broadcast <- &common.StreamMessage{Stream: stream, Data: data} 145 } 146 147 // BroadcastMessage enqueues broadcasting a pre-built StreamMessage 148 func (h *Hub) BroadcastMessage(msg *common.StreamMessage) { 149 h.broadcast <- msg 150 } 151 152 // RemoteDisconnect enqueues remote disconnect command 153 func (h *Hub) RemoteDisconnect(msg *common.RemoteDisconnectMessage) { 154 h.disconnect <- msg 155 } 156 157 // Shutdown sends shutdown command to hub 158 func (h *Hub) Shutdown() { 159 h.shutdown <- struct{}{} 160 161 // Wait for stop listening channels 162 h.done.Wait() 163 } 164 165 // Size returns a number of active sessions 166 func (h *Hub) Size() int { 167 h.mu.RLock() 168 defer h.mu.RUnlock() 169 170 return len(h.sessions) 171 } 172 173 // UniqSize returns a number of uniq identifiers 174 func (h *Hub) UniqSize() int { 175 h.mu.RLock() 176 defer h.mu.RUnlock() 177 178 return len(h.identifiers) 179 } 180 181 // StreamsSize returns a number of uniq streams 182 func (h *Hub) StreamsSize() int { 183 size := 0 184 for _, gate := range h.gates { 185 size += gate.Size() 186 } 187 return size 188 } 189 190 func (h *Hub) AddSession(session HubSession) { 191 h.mu.Lock() 192 defer h.mu.Unlock() 193 194 uid := session.GetID() 195 identifiers := session.GetIdentifiers() 196 197 h.sessions[uid] = NewHubSessionInfo(session) 198 199 if _, ok := h.identifiers[identifiers]; !ok { 200 h.identifiers[identifiers] = make(map[string]bool) 201 } 202 203 h.identifiers[identifiers][uid] = true 204 205 h.log.With("sid", uid).Debug( 206 "registered", "ids", identifiers, 207 ) 208 } 209 210 func (h *Hub) RemoveSession(session HubSession) { 211 h.mu.RLock() 212 uid := session.GetID() 213 214 if _, ok := h.sessions[uid]; !ok { 215 h.mu.RUnlock() 216 h.log.With("sid", uid).Warn("session hasn't been registered") 217 return 218 } 219 h.mu.RUnlock() 220 221 identifiers := session.GetIdentifiers() 222 h.unsubscribeSessionFromAllChannels(session) 223 224 h.mu.Lock() 225 226 delete(h.sessions, uid) 227 delete(h.identifiers[identifiers], uid) 228 229 if len(h.identifiers[identifiers]) == 0 { 230 delete(h.identifiers, identifiers) 231 } 232 233 h.mu.Unlock() 234 235 h.log.With("sid", uid).Debug("unregistered") 236 } 237 238 func (h *Hub) unsubscribeSessionFromAllChannels(session HubSession) { 239 h.mu.Lock() 240 defer h.mu.Unlock() 241 242 sid := session.GetID() 243 244 if sessionInfo, ok := h.sessions[sid]; ok { 245 for _, streamInfo := range sessionInfo.streams { 246 stream, identifier := streamInfo[0], streamInfo[1] 247 248 h.gates[index(stream, h.gatesNum)].Unsubscribe(session, stream, identifier) 249 } 250 } 251 } 252 253 func (h *Hub) UnsubscribeSessionFromChannel(session HubSession, targetIdentifier string) { 254 h.mu.Lock() 255 defer h.mu.Unlock() 256 257 sid := session.GetID() 258 259 if sessionInfo, ok := h.sessions[sid]; ok { 260 for _, streamInfo := range sessionInfo.streams { 261 stream, identifier := streamInfo[0], streamInfo[1] 262 263 if targetIdentifier == identifier { 264 h.gates[index(stream, h.gatesNum)].Unsubscribe(session, stream, identifier) 265 sessionInfo.RemoveStream(stream, identifier) 266 } 267 } 268 } 269 270 h.log.With("sid", sid).Debug("unsubscribed", "identifier", targetIdentifier) 271 } 272 273 func (h *Hub) SubscribeSession(session HubSession, stream string, identifier string) { 274 h.gates[index(stream, h.gatesNum)].Subscribe(session, stream, identifier) 275 276 h.mu.Lock() 277 defer h.mu.Unlock() 278 279 sid := session.GetID() 280 281 if _, ok := h.sessions[sid]; !ok { 282 h.sessions[sid] = NewHubSessionInfo(session) 283 } 284 285 h.sessions[sid].AddStream(stream, identifier) 286 287 h.log.With("sid", sid).Debug("subscribed", "identifier", identifier, "stream", stream) 288 } 289 290 func (h *Hub) UnsubscribeSession(session HubSession, stream string, identifier string) { 291 h.gates[index(stream, h.gatesNum)].Unsubscribe(session, stream, identifier) 292 293 h.mu.Lock() 294 defer h.mu.Unlock() 295 296 sid := session.GetID() 297 298 if info, ok := h.sessions[sid]; ok { 299 info.RemoveStream(stream, identifier) 300 } 301 302 h.log.With("sid", sid).Debug("unsubscribed", "identifier", identifier, "stream", stream) 303 } 304 305 func (h *Hub) broadcastToStream(streamMsg *common.StreamMessage) { 306 h.gates[index(streamMsg.Stream, h.gatesNum)].Broadcast(streamMsg) 307 } 308 309 func (h *Hub) disconnectSessions(identifier string, reconnect bool) { 310 h.mu.RLock() 311 ids, ok := h.identifiers[identifier] 312 h.mu.RUnlock() 313 314 if !ok { 315 h.log.Debug("cannot disconnect session", "identifier", identifier, "reason", "not found") 316 return 317 } 318 319 msg := common.NewDisconnectMessage(common.REMOTE_DISCONNECT_REASON, reconnect) 320 321 h.pool.Schedule(func() { 322 h.mu.RLock() 323 defer h.mu.RUnlock() 324 325 for id := range ids { 326 if sinfo, ok := h.sessions[id]; ok { 327 sinfo.session.DisconnectWithMessage(msg, common.REMOTE_DISCONNECT_REASON) 328 } 329 } 330 }) 331 } 332 333 func (h *Hub) FindByIdentifier(id string) HubSession { 334 h.mu.RLock() 335 defer h.mu.RUnlock() 336 337 ids, ok := h.identifiers[id] 338 339 if !ok { 340 return nil 341 } 342 343 for id := range ids { 344 if info, ok := h.sessions[id]; ok { 345 return info.session 346 } 347 } 348 349 return nil 350 } 351 352 func (h *Hub) Sessions() []HubSession { 353 h.mu.RLock() 354 defer h.mu.RUnlock() 355 356 sessions := make([]HubSession, 0, len(h.sessions)) 357 358 for _, info := range h.sessions { 359 sessions = append(sessions, info.session) 360 } 361 362 return sessions 363 } 364 365 func buildGates(ctx context.Context, num int, l *slog.Logger) []*Gate { 366 gates := make([]*Gate, 0, num) 367 for i := 0; i < num; i++ { 368 gates = append(gates, NewGate(ctx, l.With("component", "hub", "gate", i))) 369 } 370 371 return gates 372 } 373 374 func index(stream string, size int) int { 375 if size == 1 { 376 return 0 377 } 378 379 hash := fnv.New64a() 380 hash.Write([]byte(stream)) 381 return int(hash.Sum64() % uint64(size)) 382 }