github.com/anycable/anycable-go@v1.5.1/hub/hub.go (about)

     1  package hub
     2  
     3  import (
     4  	"context"
     5  	"hash/fnv"
     6  	"log/slog"
     7  	"sync"
     8  
     9  	"github.com/anycable/anycable-go/common"
    10  	"github.com/anycable/anycable-go/encoders"
    11  	"github.com/anycable/anycable-go/utils"
    12  )
    13  
    14  type HubSession interface {
    15  	GetID() string
    16  	GetIdentifiers() string
    17  	Send(msg encoders.EncodedMessage)
    18  	DisconnectWithMessage(msg encoders.EncodedMessage, code string)
    19  }
    20  
    21  // HubRegistration represents registration event ("add" or "remove")
    22  type HubRegistration struct {
    23  	event   string
    24  	session HubSession
    25  }
    26  
    27  // HubSessionInfo is used to track registered sessions
    28  type HubSessionInfo struct {
    29  	session HubSession
    30  	// List of stream-identifier pairs
    31  	streams [][]string
    32  }
    33  
    34  func NewHubSessionInfo(session HubSession) *HubSessionInfo {
    35  	return &HubSessionInfo{
    36  		session: session,
    37  		streams: make([][]string, 0),
    38  	}
    39  }
    40  
    41  func (hs *HubSessionInfo) AddStream(stream string, identifier string) {
    42  	hs.streams = append(hs.streams, []string{stream, identifier})
    43  }
    44  
    45  func (hs *HubSessionInfo) RemoveStream(stream string, identifier string) {
    46  	for i, s := range hs.streams {
    47  		if s[0] == stream && s[1] == identifier {
    48  			hs.streams = append(hs.streams[:i], hs.streams[i+1:]...)
    49  			break
    50  		}
    51  	}
    52  }
    53  
    54  // Hub stores all the sessions and the corresponding subscriptions info
    55  type Hub struct {
    56  	// Gates (=shards)
    57  	gates    []*Gate
    58  	gatesNum int
    59  
    60  	// Registered sessions
    61  	sessions map[string]*HubSessionInfo
    62  
    63  	// Identifiers to session
    64  	identifiers map[string]map[string]bool
    65  
    66  	// Messages for specified stream
    67  	broadcast chan *common.StreamMessage
    68  
    69  	// Remote disconnect commands
    70  	disconnect chan *common.RemoteDisconnectMessage
    71  
    72  	// Register requests from the sessions
    73  	register chan HubRegistration
    74  
    75  	// Control channel to shutdown hub
    76  	shutdown chan struct{}
    77  
    78  	// Synchronization group to wait for gracefully disconnect of all sessions
    79  	done sync.WaitGroup
    80  
    81  	doneFn context.CancelFunc
    82  
    83  	// Log context
    84  	log *slog.Logger
    85  
    86  	// go pool
    87  	pool *utils.GoPool
    88  
    89  	// mutex for sessions data tracking
    90  	mu sync.RWMutex
    91  }
    92  
    // NewHub builds a new hub with poolSize gates (shards). All gates share one
// cancellable context whose cancel func is stored in doneFn. Queued
// operations (Broadcast, RemoteDisconnect, RemoveSessionLater) are only
// processed once Run is started.
func NewHub(poolSize int, l *slog.Logger) *Hub {
	gatesCtx, cancel := context.WithCancel(context.Background())

	h := &Hub{
		gates:    buildGates(gatesCtx, poolSize, l),
		gatesNum: poolSize,
		doneFn:   cancel,
	}

	// Session registry.
	h.sessions = make(map[string]*HubSessionInfo)
	h.identifiers = make(map[string]map[string]bool)

	// Queues consumed by Run.
	h.broadcast = make(chan *common.StreamMessage, 256)
	h.disconnect = make(chan *common.RemoteDisconnectMessage, 128)
	h.register = make(chan HubRegistration, 2048)
	h.shutdown = make(chan struct{})

	h.pool = utils.NewGoPool("remote commands", 256)
	h.log = l.With("component", "hub")

	return h
}
   111  
   112  // Run makes hub active
   113  func (h *Hub) Run() {
   114  	h.done.Add(1)
   115  	for {
   116  		select {
   117  		case r := <-h.register:
   118  			if r.event == "add" {
   119  				h.AddSession(r.session)
   120  			} else {
   121  				h.RemoveSession(r.session)
   122  			}
   123  
   124  		case message := <-h.broadcast:
   125  			h.broadcastToStream(message)
   126  
   127  		case command := <-h.disconnect:
   128  			h.disconnectSessions(command.Identifier, command.Reconnect)
   129  
   130  		case <-h.shutdown:
   131  			h.done.Done()
   132  			return
   133  		}
   134  	}
   135  }
   136  
   137  // RemoveSession enqueues session un-registration
   138  func (h *Hub) RemoveSessionLater(s HubSession) {
   139  	h.register <- HubRegistration{event: "remove", session: s}
   140  }
   141  
   142  // Broadcast enqueues data broadcasting to a stream
   143  func (h *Hub) Broadcast(stream string, data string) {
   144  	h.broadcast <- &common.StreamMessage{Stream: stream, Data: data}
   145  }
   146  
   147  // BroadcastMessage enqueues broadcasting a pre-built StreamMessage
   148  func (h *Hub) BroadcastMessage(msg *common.StreamMessage) {
   149  	h.broadcast <- msg
   150  }
   151  
   152  // RemoteDisconnect enqueues remote disconnect command
   153  func (h *Hub) RemoteDisconnect(msg *common.RemoteDisconnectMessage) {
   154  	h.disconnect <- msg
   155  }
   156  
   157  // Shutdown sends shutdown command to hub
   158  func (h *Hub) Shutdown() {
   159  	h.shutdown <- struct{}{}
   160  
   161  	// Wait for stop listening channels
   162  	h.done.Wait()
   163  }
   164  
   165  // Size returns a number of active sessions
   166  func (h *Hub) Size() int {
   167  	h.mu.RLock()
   168  	defer h.mu.RUnlock()
   169  
   170  	return len(h.sessions)
   171  }
   172  
   173  // UniqSize returns a number of uniq identifiers
   174  func (h *Hub) UniqSize() int {
   175  	h.mu.RLock()
   176  	defer h.mu.RUnlock()
   177  
   178  	return len(h.identifiers)
   179  }
   180  
   181  // StreamsSize returns a number of uniq streams
   182  func (h *Hub) StreamsSize() int {
   183  	size := 0
   184  	for _, gate := range h.gates {
   185  		size += gate.Size()
   186  	}
   187  	return size
   188  }
   189  
   190  func (h *Hub) AddSession(session HubSession) {
   191  	h.mu.Lock()
   192  	defer h.mu.Unlock()
   193  
   194  	uid := session.GetID()
   195  	identifiers := session.GetIdentifiers()
   196  
   197  	h.sessions[uid] = NewHubSessionInfo(session)
   198  
   199  	if _, ok := h.identifiers[identifiers]; !ok {
   200  		h.identifiers[identifiers] = make(map[string]bool)
   201  	}
   202  
   203  	h.identifiers[identifiers][uid] = true
   204  
   205  	h.log.With("sid", uid).Debug(
   206  		"registered", "ids", identifiers,
   207  	)
   208  }
   209  
   210  func (h *Hub) RemoveSession(session HubSession) {
   211  	h.mu.RLock()
   212  	uid := session.GetID()
   213  
   214  	if _, ok := h.sessions[uid]; !ok {
   215  		h.mu.RUnlock()
   216  		h.log.With("sid", uid).Warn("session hasn't been registered")
   217  		return
   218  	}
   219  	h.mu.RUnlock()
   220  
   221  	identifiers := session.GetIdentifiers()
   222  	h.unsubscribeSessionFromAllChannels(session)
   223  
   224  	h.mu.Lock()
   225  
   226  	delete(h.sessions, uid)
   227  	delete(h.identifiers[identifiers], uid)
   228  
   229  	if len(h.identifiers[identifiers]) == 0 {
   230  		delete(h.identifiers, identifiers)
   231  	}
   232  
   233  	h.mu.Unlock()
   234  
   235  	h.log.With("sid", uid).Debug("unregistered")
   236  }
   237  
   238  func (h *Hub) unsubscribeSessionFromAllChannels(session HubSession) {
   239  	h.mu.Lock()
   240  	defer h.mu.Unlock()
   241  
   242  	sid := session.GetID()
   243  
   244  	if sessionInfo, ok := h.sessions[sid]; ok {
   245  		for _, streamInfo := range sessionInfo.streams {
   246  			stream, identifier := streamInfo[0], streamInfo[1]
   247  
   248  			h.gates[index(stream, h.gatesNum)].Unsubscribe(session, stream, identifier)
   249  		}
   250  	}
   251  }
   252  
   253  func (h *Hub) UnsubscribeSessionFromChannel(session HubSession, targetIdentifier string) {
   254  	h.mu.Lock()
   255  	defer h.mu.Unlock()
   256  
   257  	sid := session.GetID()
   258  
   259  	if sessionInfo, ok := h.sessions[sid]; ok {
   260  		for _, streamInfo := range sessionInfo.streams {
   261  			stream, identifier := streamInfo[0], streamInfo[1]
   262  
   263  			if targetIdentifier == identifier {
   264  				h.gates[index(stream, h.gatesNum)].Unsubscribe(session, stream, identifier)
   265  				sessionInfo.RemoveStream(stream, identifier)
   266  			}
   267  		}
   268  	}
   269  
   270  	h.log.With("sid", sid).Debug("unsubscribed", "identifier", targetIdentifier)
   271  }
   272  
   273  func (h *Hub) SubscribeSession(session HubSession, stream string, identifier string) {
   274  	h.gates[index(stream, h.gatesNum)].Subscribe(session, stream, identifier)
   275  
   276  	h.mu.Lock()
   277  	defer h.mu.Unlock()
   278  
   279  	sid := session.GetID()
   280  
   281  	if _, ok := h.sessions[sid]; !ok {
   282  		h.sessions[sid] = NewHubSessionInfo(session)
   283  	}
   284  
   285  	h.sessions[sid].AddStream(stream, identifier)
   286  
   287  	h.log.With("sid", sid).Debug("subscribed", "identifier", identifier, "stream", stream)
   288  }
   289  
   290  func (h *Hub) UnsubscribeSession(session HubSession, stream string, identifier string) {
   291  	h.gates[index(stream, h.gatesNum)].Unsubscribe(session, stream, identifier)
   292  
   293  	h.mu.Lock()
   294  	defer h.mu.Unlock()
   295  
   296  	sid := session.GetID()
   297  
   298  	if info, ok := h.sessions[sid]; ok {
   299  		info.RemoveStream(stream, identifier)
   300  	}
   301  
   302  	h.log.With("sid", sid).Debug("unsubscribed", "identifier", identifier, "stream", stream)
   303  }
   304  
   305  func (h *Hub) broadcastToStream(streamMsg *common.StreamMessage) {
   306  	h.gates[index(streamMsg.Stream, h.gatesNum)].Broadcast(streamMsg)
   307  }
   308  
   309  func (h *Hub) disconnectSessions(identifier string, reconnect bool) {
   310  	h.mu.RLock()
   311  	ids, ok := h.identifiers[identifier]
   312  	h.mu.RUnlock()
   313  
   314  	if !ok {
   315  		h.log.Debug("cannot disconnect session", "identifier", identifier, "reason", "not found")
   316  		return
   317  	}
   318  
   319  	msg := common.NewDisconnectMessage(common.REMOTE_DISCONNECT_REASON, reconnect)
   320  
   321  	h.pool.Schedule(func() {
   322  		h.mu.RLock()
   323  		defer h.mu.RUnlock()
   324  
   325  		for id := range ids {
   326  			if sinfo, ok := h.sessions[id]; ok {
   327  				sinfo.session.DisconnectWithMessage(msg, common.REMOTE_DISCONNECT_REASON)
   328  			}
   329  		}
   330  	})
   331  }
   332  
   333  func (h *Hub) FindByIdentifier(id string) HubSession {
   334  	h.mu.RLock()
   335  	defer h.mu.RUnlock()
   336  
   337  	ids, ok := h.identifiers[id]
   338  
   339  	if !ok {
   340  		return nil
   341  	}
   342  
   343  	for id := range ids {
   344  		if info, ok := h.sessions[id]; ok {
   345  			return info.session
   346  		}
   347  	}
   348  
   349  	return nil
   350  }
   351  
   352  func (h *Hub) Sessions() []HubSession {
   353  	h.mu.RLock()
   354  	defer h.mu.RUnlock()
   355  
   356  	sessions := make([]HubSession, 0, len(h.sessions))
   357  
   358  	for _, info := range h.sessions {
   359  		sessions = append(sessions, info.session)
   360  	}
   361  
   362  	return sessions
   363  }
   364  
   // buildGates creates num gates that share ctx; each gate's logger is tagged
// with component=hub and its position in the returned slice.
func buildGates(ctx context.Context, num int, l *slog.Logger) []*Gate {
	gates := make([]*Gate, num)
	for i := range gates {
		gates[i] = NewGate(ctx, l.With("component", "hub", "gate", i))
	}
	return gates
}
   373  
   // index maps stream to a gate number in [0, size) using the 64-bit FNV-1a
// hash of the stream name; with a single gate the hash is skipped.
// size is expected to be positive (zero would panic on the modulo).
func index(stream string, size int) int {
	if size == 1 {
		return 0
	}

	h := fnv.New64a()
	// hash.Hash writes never return an error.
	_, _ = h.Write([]byte(stream))
	sum := h.Sum64()

	return int(sum % uint64(size))
}