github.com/haalcala/mattermost-server-change-repo/v5@v5.33.2/app/web_hub.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package app
     5  
     6  import (
     7  	"hash/maphash"
     8  	"runtime"
     9  	"runtime/debug"
    10  	"strconv"
    11  	"sync/atomic"
    12  
    13  	"github.com/mattermost/mattermost-server/v5/mlog"
    14  	"github.com/mattermost/mattermost-server/v5/model"
    15  )
    16  
    17  const (
    18  	broadcastQueueSize = 4096
    19  )
    20  
    21  type webConnActivityMessage struct {
    22  	userID       string
    23  	sessionToken string
    24  	activityAt   int64
    25  }
    26  
// webConnDirectMessage asks a hub to deliver msg to the single connection
// conn rather than broadcasting it.
type webConnDirectMessage struct {
	conn *WebConn
	msg  model.WebSocketMessage
}
    31  
    32  type webConnSessionMessage struct {
    33  	userID       string
    34  	sessionToken string
    35  	isRegistered chan bool
    36  }
    37  
// Hub is the central place to manage all websocket connections in the server.
// It handles different websocket events and sending messages to individual
// user connections.
//
// Requests reach the hub's event loop (started by Start) through the channels
// below; the loop exits when stop is closed.
type Hub struct {
	// connectionCount should be kept first.
	// See https://github.com/mattermost/mattermost-server/pull/7281
	// NOTE(review): presumably so that this int64, which is written with
	// sync/atomic, stays 64-bit aligned on 32-bit platforms. It mirrors the
	// number of connections in the event loop's index.
	connectionCount int64
	// app is the owning App, used for metrics and user status updates.
	app *App
	// connectionIndex is this hub's position in the hubs slice built by
	// HubStart; it is also used as the hub's metrics label.
	connectionIndex int
	// register and unregister carry connections to add to or remove from the
	// hub (unbuffered when created by NewWebHub).
	register   chan *WebConn
	unregister chan *WebConn
	// broadcast queues events for the hub's connections (buffered to
	// broadcastQueueSize when created by NewWebHub).
	broadcast chan *model.WebSocketEvent
	// stop is closed by Stop; didStop is closed by the event loop once it has
	// finished shutting down.
	stop    chan struct{}
	didStop chan struct{}
	// invalidateUser carries user IDs whose connections must invalidate their
	// caches.
	invalidateUser chan string
	// activity carries last-activity updates for a user session.
	activity chan *webConnActivityMessage
	// directMsg carries messages addressed to a single connection.
	directMsg chan *webConnDirectMessage
	// explicitStop is set by the event loop when it exits because stop was
	// closed, so the recovery logic does not restart it.
	explicitStop bool
	// checkRegistered carries "is this session connected to the hub?" queries.
	checkRegistered chan *webConnSessionMessage
}
    58  
    59  // NewWebHub creates a new Hub.
    60  func (a *App) NewWebHub() *Hub {
    61  	return &Hub{
    62  		app:             a,
    63  		register:        make(chan *WebConn),
    64  		unregister:      make(chan *WebConn),
    65  		broadcast:       make(chan *model.WebSocketEvent, broadcastQueueSize),
    66  		stop:            make(chan struct{}),
    67  		didStop:         make(chan struct{}),
    68  		invalidateUser:  make(chan string),
    69  		activity:        make(chan *webConnActivityMessage),
    70  		directMsg:       make(chan *webConnDirectMessage),
    71  		checkRegistered: make(chan *webConnSessionMessage),
    72  	}
    73  }
    74  
    75  func (a *App) TotalWebsocketConnections() int {
    76  	return a.Srv().TotalWebsocketConnections()
    77  }
    78  
    79  // HubStart starts all the hubs.
    80  func (a *App) HubStart() {
    81  	// Total number of hubs is twice the number of CPUs.
    82  	numberOfHubs := runtime.NumCPU() * 2
    83  	mlog.Info("Starting websocket hubs", mlog.Int("number_of_hubs", numberOfHubs))
    84  
    85  	hubs := make([]*Hub, numberOfHubs)
    86  
    87  	for i := 0; i < numberOfHubs; i++ {
    88  		hubs[i] = a.NewWebHub()
    89  		hubs[i].connectionIndex = i
    90  		hubs[i].Start()
    91  	}
    92  	// Assigning to the hubs slice without any mutex is fine because it is only assigned once
    93  	// during the start of the program and always read from after that.
    94  	a.srv.hubs = hubs
    95  }
    96  
    97  func (a *App) invalidateCacheForUserSkipClusterSend(userID string) {
    98  	a.Srv().Store.Channel().InvalidateAllChannelMembersForUser(userID)
    99  	a.InvalidateWebConnSessionCacheForUser(userID)
   100  }
   101  
   102  func (a *App) invalidateCacheForWebhook(webhookID string) {
   103  	a.Srv().Store.Webhook().InvalidateWebhookCache(webhookID)
   104  }
   105  
   106  func (a *App) InvalidateWebConnSessionCacheForUser(userID string) {
   107  	hub := a.GetHubForUserId(userID)
   108  	if hub != nil {
   109  		hub.InvalidateUser(userID)
   110  	}
   111  }
   112  
   113  // HubStop stops all the hubs.
   114  func (s *Server) HubStop() {
   115  	mlog.Info("stopping websocket hub connections")
   116  
   117  	for _, hub := range s.hubs {
   118  		hub.Stop()
   119  	}
   120  }
   121  
   122  func (a *App) HubStop() {
   123  	a.Srv().HubStop()
   124  }
   125  
   126  // GetHubForUserId returns the hub for a given user id.
   127  func (s *Server) GetHubForUserId(userID string) *Hub {
   128  	// TODO: check if caching the userID -> hub mapping
   129  	// is worth the memory tradeoff.
   130  	// https://mattermost.atlassian.net/browse/MM-26629.
   131  	var hash maphash.Hash
   132  	hash.SetSeed(s.hashSeed)
   133  	hash.Write([]byte(userID))
   134  	index := hash.Sum64() % uint64(len(s.hubs))
   135  
   136  	return s.hubs[int(index)]
   137  }
   138  
   139  func (a *App) GetHubForUserId(userID string) *Hub {
   140  	return a.Srv().GetHubForUserId(userID)
   141  }
   142  
   143  // HubRegister registers a connection to a hub.
   144  func (a *App) HubRegister(webConn *WebConn) {
   145  	hub := a.GetHubForUserId(webConn.UserId)
   146  	if hub != nil {
   147  		if metrics := a.Metrics(); metrics != nil {
   148  			metrics.IncrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
   149  		}
   150  		hub.Register(webConn)
   151  	}
   152  }
   153  
   154  // HubUnregister unregisters a connection from a hub.
   155  func (a *App) HubUnregister(webConn *WebConn) {
   156  	hub := a.GetHubForUserId(webConn.UserId)
   157  	if hub != nil {
   158  		if metrics := a.Metrics(); metrics != nil {
   159  			metrics.DecrementWebSocketBroadcastUsersRegistered(strconv.Itoa(hub.connectionIndex), 1)
   160  		}
   161  		hub.Unregister(webConn)
   162  	}
   163  }
   164  
   165  func (s *Server) Publish(message *model.WebSocketEvent) {
   166  	if s.Metrics != nil {
   167  		s.Metrics.IncrementWebsocketEvent(message.EventType())
   168  	}
   169  
   170  	s.PublishSkipClusterSend(message)
   171  
   172  	if s.Cluster != nil {
   173  		cm := &model.ClusterMessage{
   174  			Event:    model.CLUSTER_EVENT_PUBLISH,
   175  			SendType: model.CLUSTER_SEND_BEST_EFFORT,
   176  			Data:     message.ToJson(),
   177  		}
   178  
   179  		if message.EventType() == model.WEBSOCKET_EVENT_POSTED ||
   180  			message.EventType() == model.WEBSOCKET_EVENT_POST_EDITED ||
   181  			message.EventType() == model.WEBSOCKET_EVENT_DIRECT_ADDED ||
   182  			message.EventType() == model.WEBSOCKET_EVENT_GROUP_ADDED ||
   183  			message.EventType() == model.WEBSOCKET_EVENT_ADDED_TO_TEAM {
   184  			cm.SendType = model.CLUSTER_SEND_RELIABLE
   185  		}
   186  
   187  		s.Cluster.SendClusterMessage(cm)
   188  	}
   189  }
   190  
   191  func (a *App) Publish(message *model.WebSocketEvent) {
   192  	a.Srv().Publish(message)
   193  }
   194  
   195  func (s *Server) PublishSkipClusterSend(message *model.WebSocketEvent) {
   196  	if message.GetBroadcast().UserId != "" {
   197  		hub := s.GetHubForUserId(message.GetBroadcast().UserId)
   198  		if hub != nil {
   199  			hub.Broadcast(message)
   200  		}
   201  	} else {
   202  		for _, hub := range s.hubs {
   203  			hub.Broadcast(message)
   204  		}
   205  	}
   206  }
   207  
   208  func (a *App) PublishSkipClusterSend(message *model.WebSocketEvent) {
   209  	a.Srv().PublishSkipClusterSend(message)
   210  }
   211  
   212  func (a *App) invalidateCacheForChannel(channel *model.Channel) {
   213  	a.Srv().Store.Channel().InvalidateChannel(channel.Id)
   214  	a.invalidateCacheForChannelByNameSkipClusterSend(channel.TeamId, channel.Name)
   215  
   216  	if a.Cluster() != nil {
   217  		nameMsg := &model.ClusterMessage{
   218  			Event:    model.CLUSTER_EVENT_INVALIDATE_CACHE_FOR_CHANNEL_BY_NAME,
   219  			SendType: model.CLUSTER_SEND_BEST_EFFORT,
   220  			Props:    make(map[string]string),
   221  		}
   222  
   223  		nameMsg.Props["name"] = channel.Name
   224  		if channel.TeamId == "" {
   225  			nameMsg.Props["id"] = "dm"
   226  		} else {
   227  			nameMsg.Props["id"] = channel.TeamId
   228  		}
   229  
   230  		a.Cluster().SendClusterMessage(nameMsg)
   231  	}
   232  }
   233  
   234  func (a *App) invalidateCacheForChannelMembers(channelId string) {
   235  	a.Srv().Store.User().InvalidateProfilesInChannelCache(channelId)
   236  	a.Srv().Store.Channel().InvalidateMemberCount(channelId)
   237  	a.Srv().Store.Channel().InvalidateGuestCount(channelId)
   238  }
   239  
   240  func (a *App) invalidateCacheForChannelMembersNotifyProps(channelId string) {
   241  	a.invalidateCacheForChannelMembersNotifyPropsSkipClusterSend(channelId)
   242  
   243  	if a.Cluster() != nil {
   244  		msg := &model.ClusterMessage{
   245  			Event:    model.CLUSTER_EVENT_INVALIDATE_CACHE_FOR_CHANNEL_MEMBERS_NOTIFY_PROPS,
   246  			SendType: model.CLUSTER_SEND_BEST_EFFORT,
   247  			Data:     channelId,
   248  		}
   249  		a.Cluster().SendClusterMessage(msg)
   250  	}
   251  }
   252  
   253  func (a *App) invalidateCacheForChannelMembersNotifyPropsSkipClusterSend(channelId string) {
   254  	a.Srv().Store.Channel().InvalidateCacheForChannelMembersNotifyProps(channelId)
   255  }
   256  
   257  func (a *App) invalidateCacheForChannelByNameSkipClusterSend(teamID, name string) {
   258  	if teamID == "" {
   259  		teamID = "dm"
   260  	}
   261  
   262  	a.Srv().Store.Channel().InvalidateChannelByName(teamID, name)
   263  }
   264  
   265  func (a *App) invalidateCacheForChannelPosts(channelId string) {
   266  	a.Srv().Store.Channel().InvalidatePinnedPostCount(channelId)
   267  	a.Srv().Store.Post().InvalidateLastPostTimeCache(channelId)
   268  }
   269  
   270  func (a *App) InvalidateCacheForUser(userID string) {
   271  	a.invalidateCacheForUserSkipClusterSend(userID)
   272  
   273  	a.Srv().Store.User().InvalidateProfilesInChannelCacheByUser(userID)
   274  	a.Srv().Store.User().InvalidateProfileCacheForUser(userID)
   275  
   276  	if a.Cluster() != nil {
   277  		msg := &model.ClusterMessage{
   278  			Event:    model.CLUSTER_EVENT_INVALIDATE_CACHE_FOR_USER,
   279  			SendType: model.CLUSTER_SEND_BEST_EFFORT,
   280  			Data:     userID,
   281  		}
   282  		a.Cluster().SendClusterMessage(msg)
   283  	}
   284  }
   285  
   286  func (a *App) invalidateCacheForUserTeams(userID string) {
   287  	a.InvalidateWebConnSessionCacheForUser(userID)
   288  	a.Srv().Store.Team().InvalidateAllTeamIdsForUser(userID)
   289  
   290  	if a.Cluster() != nil {
   291  		msg := &model.ClusterMessage{
   292  			Event:    model.CLUSTER_EVENT_INVALIDATE_CACHE_FOR_USER_TEAMS,
   293  			SendType: model.CLUSTER_SEND_BEST_EFFORT,
   294  			Data:     userID,
   295  		}
   296  		a.Cluster().SendClusterMessage(msg)
   297  	}
   298  }
   299  
   300  // UpdateWebConnUserActivity sets the LastUserActivityAt of the hub for the given session.
   301  func (a *App) UpdateWebConnUserActivity(session model.Session, activityAt int64) {
   302  	hub := a.GetHubForUserId(session.UserId)
   303  	if hub != nil {
   304  		hub.UpdateActivity(session.UserId, session.Token, activityAt)
   305  	}
   306  }
   307  
   308  // SessionIsRegistered determines if a specific session has been registered
   309  func (a *App) SessionIsRegistered(session model.Session) bool {
   310  	hub := a.GetHubForUserId(session.UserId)
   311  	if hub != nil {
   312  		return hub.IsRegistered(session.UserId, session.Token)
   313  	}
   314  	return false
   315  }
   316  
   317  // Register registers a connection to the hub.
   318  func (h *Hub) Register(webConn *WebConn) {
   319  	select {
   320  	case h.register <- webConn:
   321  	case <-h.stop:
   322  	}
   323  }
   324  
   325  // Unregister unregisters a connection from the hub.
   326  func (h *Hub) Unregister(webConn *WebConn) {
   327  	select {
   328  	case h.unregister <- webConn:
   329  	case <-h.stop:
   330  	}
   331  }
   332  
   333  // Determines if a user's session is registered a connection from the hub.
   334  func (h *Hub) IsRegistered(userID, sessionToken string) bool {
   335  	ws := &webConnSessionMessage{
   336  		userID:       userID,
   337  		sessionToken: sessionToken,
   338  		isRegistered: make(chan bool),
   339  	}
   340  	select {
   341  	case h.checkRegistered <- ws:
   342  		return <-ws.isRegistered
   343  	case <-h.stop:
   344  	}
   345  	return false
   346  }
   347  
   348  // Broadcast broadcasts the message to all connections in the hub.
   349  func (h *Hub) Broadcast(message *model.WebSocketEvent) {
   350  	// XXX: The hub nil check is because of the way we setup our tests. We call
   351  	// `app.NewServer()` which returns a server, but only after that, we call
   352  	// `wsapi.Init()` to initialize the hub.  But in the `NewServer` call
   353  	// itself proceeds to broadcast some messages happily.  This needs to be
   354  	// fixed once the wsapi cyclic dependency with server/app goes away.
   355  	// And possibly, we can look into doing the hub initialization inside
   356  	// NewServer itself.
   357  	if h != nil && message != nil {
   358  		if metrics := h.app.Metrics(); metrics != nil {
   359  			metrics.IncrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
   360  		}
   361  		select {
   362  		case h.broadcast <- message:
   363  		case <-h.stop:
   364  		}
   365  	}
   366  }
   367  
   368  // InvalidateUser invalidates the cache for the given user.
   369  func (h *Hub) InvalidateUser(userID string) {
   370  	select {
   371  	case h.invalidateUser <- userID:
   372  	case <-h.stop:
   373  	}
   374  }
   375  
   376  // UpdateActivity sets the LastUserActivityAt field for the connection
   377  // of the user.
   378  func (h *Hub) UpdateActivity(userID, sessionToken string, activityAt int64) {
   379  	select {
   380  	case h.activity <- &webConnActivityMessage{
   381  		userID:       userID,
   382  		sessionToken: sessionToken,
   383  		activityAt:   activityAt,
   384  	}:
   385  	case <-h.stop:
   386  	}
   387  }
   388  
   389  // SendMessage sends the given message to the given connection.
   390  func (h *Hub) SendMessage(conn *WebConn, msg model.WebSocketMessage) {
   391  	select {
   392  	case h.directMsg <- &webConnDirectMessage{
   393  		conn: conn,
   394  		msg:  msg,
   395  	}:
   396  	case <-h.stop:
   397  	}
   398  }
   399  
// Stop stops the hub: it closes the stop channel and then blocks until the
// event loop started by Start has closed didStop. Calling Stop a second time
// panics, because stop would be closed twice.
func (h *Hub) Stop() {
	close(h.stop)
	<-h.didStop
}

// Start starts the hub.
//
// It launches the hub's event loop in a new goroutine and returns
// immediately. The loop keeps its connections in a hubConnectionIndex that is
// local to the loop, so the index is only touched from that goroutine. The
// loop handles one message at a time:
//   - checkRegistered: replies whether any of the user's connections uses the
//     given session token.
//   - register: indexes the connection, republishes connectionCount, and
//     sends a hello message if the connection is authenticated.
//   - unregister: removes the connection, republishes connectionCount, and
//     updates the user's status (see the case body).
//   - invalidateUser: calls InvalidateCache on each of the user's connections.
//   - activity: sets lastUserActivityAt on the user's connections that use
//     the given session token.
//   - directMsg, broadcast: non-blocking sends; a connection whose send
//     channel is full is closed and dropped from the index.
//   - stop: closes every connection, sets each user offline, and closes
//     didStop.
//
// If the loop panics, or returns without having been stopped, it is restarted
// in a new goroutine with a fresh, empty connection index.
// NOTE(review): connections registered before such a restart are no longer
// tracked by the new index.
func (h *Hub) Start() {
	var doStart func()
	var doRecoverableStart func()
	var doRecover func()

	// doStart runs the event loop until stop is closed.
	doStart = func() {
		mlog.Debug("Hub is starting", mlog.Int("index", h.connectionIndex))

		connIndex := newHubConnectionIndex()

		for {
			select {
			case webSessionMessage := <-h.checkRegistered:
				conns := connIndex.ForUser(webSessionMessage.userID)
				var isRegistered bool
				for _, conn := range conns {
					if conn.GetSessionToken() == webSessionMessage.sessionToken {
						isRegistered = true
					}
				}
				// The asking goroutine is blocked waiting for this reply.
				webSessionMessage.isRegistered <- isRegistered
			case webConn := <-h.register:
				connIndex.Add(webConn)
				// Stored atomically because connectionCount may be read from
				// other goroutines. NOTE(review): presumably read with
				// atomic.LoadInt64 elsewhere; confirm.
				atomic.StoreInt64(&h.connectionCount, int64(len(connIndex.All())))
				if webConn.IsAuthenticated() {
					// Blocking send. NOTE(review): assumes the connection's
					// send channel has room for the hello message.
					webConn.send <- webConn.createHelloMessage()
				}
			case webConn := <-h.unregister:
				connIndex.Remove(webConn)
				atomic.StoreInt64(&h.connectionCount, int64(len(connIndex.All())))

				// Connections without a user ID have no status to update.
				if webConn.UserId == "" {
					continue
				}

				// Connections are routed to hubs by user ID, so an empty list
				// presumably means the user has no connections left on this
				// server: mark them offline via Srv().Go.
				conns := connIndex.ForUser(webConn.UserId)
				if len(conns) == 0 {
					h.app.Srv().Go(func() {
						h.app.SetStatusOffline(webConn.UserId, false)
					})
					continue
				}
				// Otherwise take the most recent activity among the remaining
				// connections and, if IsUserAway reports it as away, record
				// it as the user's last activity.
				var latestActivity int64 = 0
				for _, conn := range conns {
					if conn.lastUserActivityAt > latestActivity {
						latestActivity = conn.lastUserActivityAt
					}
				}

				if h.app.IsUserAway(latestActivity) {
					h.app.Srv().Go(func() {
						h.app.SetStatusLastActivityAt(webConn.UserId, latestActivity)
					})
				}
			case userID := <-h.invalidateUser:
				for _, webConn := range connIndex.ForUser(userID) {
					webConn.InvalidateCache()
				}
			case activity := <-h.activity:
				for _, webConn := range connIndex.ForUser(activity.userID) {
					if webConn.GetSessionToken() == activity.sessionToken {
						webConn.lastUserActivityAt = activity.activityAt
					}
				}
			case directMsg := <-h.directMsg:
				// The target may already have been removed from the index.
				if !connIndex.Has(directMsg.conn) {
					continue
				}
				// Never block the loop on a slow client: if its send channel
				// is full, close it and drop the connection.
				select {
				case directMsg.conn.send <- directMsg.msg:
				default:
					mlog.Error("webhub.broadcast: cannot send, closing websocket for user", mlog.String("user_id", directMsg.conn.UserId))
					close(directMsg.conn.send)
					connIndex.Remove(directMsg.conn)
				}
			case msg := <-h.broadcast:
				// Pairs with the increment made when the event was queued.
				if metrics := h.app.Metrics(); metrics != nil {
					metrics.DecrementWebSocketBroadcastBufferSize(strconv.Itoa(h.connectionIndex), 1)
				}
				msg = msg.PrecomputeJSON()
				// broadcast may Remove connections from connIndex while the
				// loops below are ranging over it; the Has check skips
				// connections removed earlier in the same pass.
				broadcast := func(webConn *WebConn) {
					if !connIndex.Has(webConn) {
						return
					}
					if webConn.shouldSendEvent(msg) {
						select {
						case webConn.send <- msg:
						default:
							mlog.Error("webhub.broadcast: cannot send, closing websocket for user", mlog.String("user_id", webConn.UserId))
							close(webConn.send)
							connIndex.Remove(webConn)
						}
					}
				}
				// Events addressed to one user only need that user's
				// connections.
				if msg.GetBroadcast().UserId != "" {
					candidates := connIndex.ForUser(msg.GetBroadcast().UserId)
					for _, webConn := range candidates {
						broadcast(webConn)
					}
					continue
				}
				candidates := connIndex.All()
				for webConn := range candidates {
					broadcast(webConn)
				}
			case <-h.stop:
				// Unlike unregister, status updates here run synchronously
				// on the loop goroutine.
				for webConn := range connIndex.All() {
					webConn.Close()
					h.app.SetStatusOffline(webConn.UserId, false)
				}

				// Tells doRecover not to restart the loop.
				h.explicitStop = true
				close(h.didStop)

				return
			}
		}
	}

	// doRecoverableStart runs the loop with doRecover deferred directly;
	// recover() only takes effect when called from the deferred function
	// itself.
	doRecoverableStart = func() {
		defer doRecover()
		doStart()
	}

	// doRecover restarts the loop in a new goroutine unless it exited through
	// an explicit stop, logging the panic value (if any) and the stack first.
	doRecover = func() {
		if !h.explicitStop {
			if r := recover(); r != nil {
				mlog.Error("Recovering from Hub panic.", mlog.Any("panic", r))
			} else {
				mlog.Error("Webhub stopped unexpectedly. Recovering.")
			}

			mlog.Error(string(debug.Stack()))

			go doRecoverableStart()
		}
	}

	go doRecoverableStart()
}
   547  
   548  // hubConnectionIndex provides fast addition, removal, and iteration of web connections.
   549  // It requires 3 functionalities which need to be very fast:
   550  // - check if a connection exists or not.
   551  // - get all connections for a given userID.
   552  // - get all connections.
   553  type hubConnectionIndex struct {
   554  	// byUserId stores the list of connections for a given userID
   555  	byUserId map[string][]*WebConn
   556  	// byConnection serves the dual purpose of storing the index of the webconn
   557  	// in the value of byUserId map, and also to get all connections.
   558  	byConnection map[*WebConn]int
   559  }
   560  
   561  func newHubConnectionIndex() *hubConnectionIndex {
   562  	return &hubConnectionIndex{
   563  		byUserId:     make(map[string][]*WebConn),
   564  		byConnection: make(map[*WebConn]int),
   565  	}
   566  }
   567  
   568  func (i *hubConnectionIndex) Add(wc *WebConn) {
   569  	i.byUserId[wc.UserId] = append(i.byUserId[wc.UserId], wc)
   570  	i.byConnection[wc] = len(i.byUserId[wc.UserId]) - 1
   571  }
   572  
   573  func (i *hubConnectionIndex) Remove(wc *WebConn) {
   574  	userConnIndex, ok := i.byConnection[wc]
   575  	if !ok {
   576  		return
   577  	}
   578  
   579  	// get the conn slice.
   580  	userConnections := i.byUserId[wc.UserId]
   581  	// get the last connection.
   582  	last := userConnections[len(userConnections)-1]
   583  	// set the slot that we are trying to remove to be the last connection.
   584  	userConnections[userConnIndex] = last
   585  	// remove the last connection from the slice.
   586  	i.byUserId[wc.UserId] = userConnections[:len(userConnections)-1]
   587  	// set the index of the connection that was moved to the new index.
   588  	i.byConnection[last] = userConnIndex
   589  
   590  	delete(i.byConnection, wc)
   591  }
   592  
   593  func (i *hubConnectionIndex) Has(wc *WebConn) bool {
   594  	_, ok := i.byConnection[wc]
   595  	return ok
   596  }
   597  
   598  func (i *hubConnectionIndex) ForUser(id string) []*WebConn {
   599  	return i.byUserId[id]
   600  }
   601  
   602  func (i *hubConnectionIndex) All() map[*WebConn]int {
   603  	return i.byConnection
   604  }