github.com/anycable/anycable-go@v1.5.1/hub/gate.go (about)

     1  package hub
     2  
     3  import (
     4  	"context"
     5  	"log/slog"
     6  	"sync"
     7  
     8  	"github.com/anycable/anycable-go/common"
     9  	"github.com/anycable/anycable-go/encoders"
    10  )
    11  
    12  // Gate plays the role of a shard for the hub.
    13  // It keeps subscriptions for some streams (a particular shard) and is used
    14  // to broadcast messages to all subscribers of these streams.
    15  type Gate struct {
    16  	// Maps streams to sessions with identifiers
    17  	// stream -> session -> identifier -> true
    18  	streams map[string]map[HubSession]map[string]bool
    19  
    20  	// Maps sessions to identifiers to streams
    21  	// session -> identifier -> [stream]
    22  	sessionsStreams map[HubSession]map[string][]string
    23  
    24  	// This channel is used as a broadcast queue
    25  	sender chan *common.StreamMessage
    26  
    27  	mu  sync.RWMutex
    28  	log *slog.Logger
    29  }
    30  
    31  // NewGate creates a new gate.
    32  func NewGate(ctx context.Context, l *slog.Logger) *Gate {
    33  	g := Gate{
    34  		streams:         make(map[string]map[HubSession]map[string]bool),
    35  		sessionsStreams: make(map[HubSession]map[string][]string),
    36  		// Use a buffered channel to avoid blocking
    37  		sender: make(chan *common.StreamMessage, 256),
    38  		log:    l,
    39  	}
    40  
    41  	go g.broadcastLoop(ctx)
    42  
    43  	return &g
    44  }
    45  
    46  // Broadcast sends a message to all subscribers of the stream.
    47  func (g *Gate) Broadcast(streamMsg *common.StreamMessage) {
    48  	stream := streamMsg.Stream
    49  
    50  	ctx := g.log.With("stream", stream)
    51  
    52  	ctx.Debug("schedule broadcast", "message", streamMsg)
    53  
    54  	g.mu.RLock()
    55  	if _, ok := g.streams[stream]; !ok {
    56  		ctx.Debug("no sessions")
    57  		g.mu.RUnlock()
    58  		return
    59  	}
    60  	g.mu.RUnlock()
    61  
    62  	g.sender <- streamMsg
    63  }
    64  
    65  // Subscribe adds a session to the stream.
    66  func (g *Gate) Subscribe(session HubSession, stream string, identifier string) {
    67  	g.mu.Lock()
    68  	defer g.mu.Unlock()
    69  
    70  	if _, ok := g.streams[stream]; !ok {
    71  		g.streams[stream] = make(map[HubSession]map[string]bool)
    72  	}
    73  
    74  	if _, ok := g.streams[stream][session]; !ok {
    75  		g.streams[stream][session] = make(map[string]bool)
    76  	}
    77  
    78  	g.streams[stream][session][identifier] = true
    79  }
    80  
    81  // Unsubscribe removes a session from the stream.
    82  func (g *Gate) Unsubscribe(session HubSession, stream string, identifier string) {
    83  	g.mu.RLock()
    84  
    85  	if _, ok := g.streams[stream]; !ok {
    86  		g.mu.RUnlock()
    87  		return
    88  	}
    89  
    90  	if _, ok := g.streams[stream][session]; !ok {
    91  		g.mu.RUnlock()
    92  		return
    93  	}
    94  
    95  	if _, ok := g.streams[stream][session][identifier]; !ok {
    96  		g.mu.RUnlock()
    97  		return
    98  	}
    99  
   100  	g.mu.RUnlock()
   101  
   102  	g.mu.Lock()
   103  	defer g.mu.Unlock()
   104  
   105  	delete(g.streams[stream][session], identifier)
   106  
   107  	if len(g.streams[stream][session]) == 0 {
   108  		delete(g.streams[stream], session)
   109  
   110  		if len(g.streams[stream]) == 0 {
   111  			delete(g.streams, stream)
   112  		}
   113  	}
   114  }
   115  
   116  // Size returns a number of uniq streams
   117  func (g *Gate) Size() int {
   118  	g.mu.RLock()
   119  	defer g.mu.RUnlock()
   120  
   121  	return len(g.streams)
   122  }
   123  
   124  func (g *Gate) broadcastLoop(ctx context.Context) {
   125  	for {
   126  		select {
   127  		case <-ctx.Done():
   128  			return
   129  		case msg := <-g.sender:
   130  			g.performBroadcast(msg)
   131  		}
   132  	}
   133  }
   134  
   135  func (g *Gate) performBroadcast(streamMsg *common.StreamMessage) {
   136  	stream := streamMsg.Stream
   137  
   138  	buf := make(map[string](encoders.EncodedMessage))
   139  
   140  	var bdata encoders.EncodedMessage
   141  
   142  	g.mu.RLock()
   143  	streamSessions := streamSessionsSnapshot(g.streams[stream])
   144  	g.mu.RUnlock()
   145  
   146  	for session, ids := range streamSessions {
   147  		if streamMsg.Meta != nil && streamMsg.Meta.ExcludeSocket == session.GetID() {
   148  			continue
   149  		}
   150  
   151  		for _, id := range ids {
   152  			if msg, ok := buf[id]; ok {
   153  				bdata = msg
   154  			} else {
   155  				bdata = buildMessage(streamMsg, id)
   156  				buf[id] = bdata
   157  			}
   158  
   159  			session.Send(bdata)
   160  		}
   161  	}
   162  }
   163  
   164  func buildMessage(msg *common.StreamMessage, identifier string) encoders.EncodedMessage {
   165  	reply := msg.ToReplyFor(identifier)
   166  
   167  	if msg.Meta != nil {
   168  		reply.Type = msg.Meta.BroadcastType
   169  	}
   170  
   171  	return encoders.NewCachedEncodedMessage(reply)
   172  }
   173  
   174  func streamSessionsSnapshot[T comparable](src map[T]map[string]bool) map[T][]string {
   175  	dest := make(map[T][]string)
   176  
   177  	for k, v := range src {
   178  		dest[k] = make([]string, len(v))
   179  
   180  		i := 0
   181  
   182  		for id := range v {
   183  			dest[k][i] = id
   184  			i++
   185  		}
   186  	}
   187  
   188  	return dest
   189  }