github.com/anycable/anycable-go@v1.5.1/hub/gate.go (about) 1 package hub 2 3 import ( 4 "context" 5 "log/slog" 6 "sync" 7 8 "github.com/anycable/anycable-go/common" 9 "github.com/anycable/anycable-go/encoders" 10 ) 11 12 // Gate plays the role of a shard for the hub. 13 // It keeps subscriptions for some streams (a particular shard) and is used 14 // to broadcast messages to all subscribers of these streams. 15 type Gate struct { 16 // Maps streams to sessions with identifiers 17 // stream -> session -> identifier -> true 18 streams map[string]map[HubSession]map[string]bool 19 20 // Maps sessions to identifiers to streams 21 // session -> identifier -> [stream] 22 sessionsStreams map[HubSession]map[string][]string 23 24 // This channel is used as a broadcast queue 25 sender chan *common.StreamMessage 26 27 mu sync.RWMutex 28 log *slog.Logger 29 } 30 31 // NewGate creates a new gate. 32 func NewGate(ctx context.Context, l *slog.Logger) *Gate { 33 g := Gate{ 34 streams: make(map[string]map[HubSession]map[string]bool), 35 sessionsStreams: make(map[HubSession]map[string][]string), 36 // Use a buffered channel to avoid blocking 37 sender: make(chan *common.StreamMessage, 256), 38 log: l, 39 } 40 41 go g.broadcastLoop(ctx) 42 43 return &g 44 } 45 46 // Broadcast sends a message to all subscribers of the stream. 47 func (g *Gate) Broadcast(streamMsg *common.StreamMessage) { 48 stream := streamMsg.Stream 49 50 ctx := g.log.With("stream", stream) 51 52 ctx.Debug("schedule broadcast", "message", streamMsg) 53 54 g.mu.RLock() 55 if _, ok := g.streams[stream]; !ok { 56 ctx.Debug("no sessions") 57 g.mu.RUnlock() 58 return 59 } 60 g.mu.RUnlock() 61 62 g.sender <- streamMsg 63 } 64 65 // Subscribe adds a session to the stream. 66 func (g *Gate) Subscribe(session HubSession, stream string, identifier string) { 67 g.mu.Lock() 68 defer g.mu.Unlock() 69 70 if _, ok := g.streams[stream]; !ok { 71 g.streams[stream] = make(map[HubSession]map[string]bool) 72 } 73 74 if _, ok := g.streams[stream][session]; !ok { 75 g.streams[stream][session] = make(map[string]bool) 76 } 77 78 g.streams[stream][session][identifier] = true 79 } 80 81 // Unsubscribe removes a session from the stream. 82 func (g *Gate) Unsubscribe(session HubSession, stream string, identifier string) { 83 g.mu.RLock() 84 85 if _, ok := g.streams[stream]; !ok { 86 g.mu.RUnlock() 87 return 88 } 89 90 if _, ok := g.streams[stream][session]; !ok { 91 g.mu.RUnlock() 92 return 93 } 94 95 if _, ok := g.streams[stream][session][identifier]; !ok { 96 g.mu.RUnlock() 97 return 98 } 99 100 g.mu.RUnlock() 101 102 g.mu.Lock() 103 defer g.mu.Unlock() 104 105 delete(g.streams[stream][session], identifier) 106 107 if len(g.streams[stream][session]) == 0 { 108 delete(g.streams[stream], session) 109 110 if len(g.streams[stream]) == 0 { 111 delete(g.streams, stream) 112 } 113 } 114 } 115 116 // Size returns a number of uniq streams 117 func (g *Gate) Size() int { 118 g.mu.RLock() 119 defer g.mu.RUnlock() 120 121 return len(g.streams) 122 } 123 124 func (g *Gate) broadcastLoop(ctx context.Context) { 125 for { 126 select { 127 case <-ctx.Done(): 128 return 129 case msg := <-g.sender: 130 g.performBroadcast(msg) 131 } 132 } 133 } 134 135 func (g *Gate) performBroadcast(streamMsg *common.StreamMessage) { 136 stream := streamMsg.Stream 137 138 buf := make(map[string](encoders.EncodedMessage)) 139 140 var bdata encoders.EncodedMessage 141 142 g.mu.RLock() 143 streamSessions := streamSessionsSnapshot(g.streams[stream]) 144 g.mu.RUnlock() 145 146 for session, ids := range streamSessions { 147 if streamMsg.Meta != nil && streamMsg.Meta.ExcludeSocket == session.GetID() { 148 continue 149 } 150 151 for _, id := range ids { 152 if msg, ok := buf[id]; ok { 153 bdata = msg 154 } else { 155 bdata = buildMessage(streamMsg, id) 156 buf[id] = bdata 157 } 158 159 session.Send(bdata) 160 } 161 } 162 } 163 164 func buildMessage(msg *common.StreamMessage, identifier string) encoders.EncodedMessage { 165 reply := msg.ToReplyFor(identifier) 166 167 if msg.Meta != nil { 168 reply.Type = msg.Meta.BroadcastType 169 } 170 171 return encoders.NewCachedEncodedMessage(reply) 172 } 173 174 func streamSessionsSnapshot[T comparable](src map[T]map[string]bool) map[T][]string { 175 dest := make(map[T][]string) 176 177 for k, v := range src { 178 dest[k] = make([]string, len(v)) 179 180 i := 0 181 182 for id := range v { 183 dest[k][i] = id 184 i++ 185 } 186 } 187 188 return dest 189 }