github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/common/ws_hub.go (about)

     1  package common
     2  
     3  import (
     4  	"encoding/json"
     5  	"log"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/gorilla/websocket"
    10  )
    11  
    12  // TODO: Rename this to WebSockets?
    13  var WsHub WsHubImpl
    14  
// WsHubImpl tracks the WebSocket connections of online registered users and guests.
// Registered users are split across two maps by the parity of their user ID, each
// guarded by its own RWMutex (presumably to reduce lock contention).
// TODO: Make this an interface?
// TODO: Write tests for this
type WsHubImpl struct {
	// TODO: Implement some form of generics so we don't write as much odd-even sharding code
	evenOnlineUsers map[int]*WSUser // users with an even ID; guarded by evenUserLock
	oddOnlineUsers  map[int]*WSUser // users with an odd ID; guarded by oddUserLock
	evenUserLock    sync.RWMutex
	oddUserLock     sync.RWMutex

	// TODO: Add sharding for this too?
	// OnlineGuests holds one entry per guest connection; guarded by GuestLock.
	OnlineGuests map[*WSUser]bool
	GuestLock    sync.RWMutex

	// lastTick records when the live topic list tick last ran (or when Start enabled it).
	lastTick time.Time
	// lastTopicList is the first page of topics fetched by the previous tick, used to detect changes.
	lastTopicList []*TopicsRow
}
    31  
    32  func init() {
    33  	// TODO: Do we really want to initialise this here instead of in main.go / general_test.go like the other things?
    34  	WsHub = WsHubImpl{
    35  		evenOnlineUsers: make(map[int]*WSUser),
    36  		oddOnlineUsers:  make(map[int]*WSUser),
    37  		OnlineGuests:    make(map[*WSUser]bool),
    38  	}
    39  }
    40  
    41  func (h *WsHubImpl) Start() {
    42  	log.Print("Setting up the WebSocket ticks")
    43  	ticker := time.NewTicker(time.Minute * 5)
    44  	defer func() {
    45  		ticker.Stop()
    46  	}()
    47  
    48  	go func() {
    49  		defer EatPanics()
    50  		for {
    51  			item := func(l *sync.RWMutex, userMap map[int]*WSUser) {
    52  				l.RLock()
    53  				defer l.RUnlock()
    54  				// TODO: Copy to temporary slice for less contention?
    55  				for _, u := range userMap {
    56  					u.Ping()
    57  				}
    58  			}
    59  			select {
    60  			case <-ticker.C:
    61  				item(&h.evenUserLock, h.evenOnlineUsers)
    62  				item(&h.oddUserLock, h.oddOnlineUsers)
    63  			}
    64  		}
    65  	}()
    66  	if Config.DisableLiveTopicList {
    67  		return
    68  	}
    69  	h.lastTick = time.Now()
    70  	Tasks.Sec.Add(h.Tick)
    71  }
    72  
    73  // This Tick is separate from the admin one, as we want to process that in parallel with this due to the blocking calls to gopsutil
    74  func (h *WsHubImpl) Tick() error {
    75  	return wsTopicListTick(h)
    76  }
    77  
    78  func wsTopicListTick(h *WsHubImpl) error {
    79  	// Avoid hitting GetList when the topic list hasn't changed
    80  	if !TopicListThaw.Thawed() && h.lastTopicList != nil {
    81  		return nil
    82  	}
    83  	tickStart := time.Now()
    84  
    85  	// Don't waste CPU time if nothing has happened
    86  	// TODO: Get a topic list method which strips stickies?
    87  	tList, _, _, err := TopicList.GetList(1, 0, nil)
    88  	if err != nil {
    89  		h.lastTick = tickStart
    90  		return err // TODO: Do we get ErrNoRows here?
    91  	}
    92  	defer func() {
    93  		h.lastTick = tickStart
    94  		h.lastTopicList = tList
    95  	}()
    96  	if len(tList) == 0 {
    97  		return nil
    98  	}
    99  
   100  	// TODO: Optimise this by only sniffing the top non-sticky
   101  	// TODO: Optimise this by getting back an unsorted list so we don't have to hop around the stickies
   102  	// TODO: Add support for new stickies / replies to them
   103  	if len(tList) == len(h.lastTopicList) {
   104  		hasItem := false
   105  		for j, tItem := range tList {
   106  			if !tItem.Sticky {
   107  				if tItem.ID != h.lastTopicList[j].ID || !tItem.LastReplyAt.Equal(h.lastTopicList[j].LastReplyAt) {
   108  					hasItem = true
   109  					break
   110  				}
   111  			}
   112  		}
   113  		if !hasItem {
   114  			return nil
   115  		}
   116  	}
   117  
   118  	// TODO: Implement this for guests too? Should be able to optimise it far better there due to them sharing the same permission set
   119  	// TODO: Be less aggressive with the locking, maybe use an array of sorts instead of hitting the main map every-time
   120  	topicListMutex.RLock()
   121  	if len(topicListWatchers) == 0 {
   122  		topicListMutex.RUnlock()
   123  		return nil
   124  	}
   125  
   126  	// Copy these over so we close this loop as fast as possible so we can release the read lock, especially if the group gets are backed by calls to the database
   127  	groupIDs := make(map[int]bool)
   128  	currentWatchers := make([]*WSUser, len(topicListWatchers))
   129  	i := 0
   130  	for wsUser, _ := range topicListWatchers {
   131  		currentWatchers[i] = wsUser
   132  		groupIDs[wsUser.User.Group] = true
   133  		i++
   134  	}
   135  	topicListMutex.RUnlock()
   136  
   137  	groups := make(map[int]*Group)
   138  	canSeeMap := make(map[string][]int)
   139  	for gid, _ := range groupIDs {
   140  		g, err := Groups.Get(gid)
   141  		if err != nil {
   142  			// TODO: Do we really want to halt all pushes for what is possibly just one user?
   143  			return err
   144  		}
   145  		groups[g.ID] = g
   146  
   147  		canSee := make([]byte, len(g.CanSee))
   148  		for i, item := range g.CanSee {
   149  			canSee[i] = byte(item)
   150  		}
   151  		canSeeMap[string(canSee)] = g.CanSee
   152  	}
   153  
   154  	canSeeRenders := make(map[string][]byte)
   155  	canSeeLists := make(map[string][]*WsTopicsRow)
   156  	for name, canSee := range canSeeMap {
   157  		topicList, forumList, _, err := TopicList.GetListByCanSee(canSee, 1, 0, nil)
   158  		if err != nil {
   159  			return err // TODO: Do we get ErrNoRows here?
   160  		}
   161  		if len(topicList) == 0 {
   162  			continue
   163  		}
   164  		_ = forumList // Might use this later after we get the base feature working
   165  
   166  		if topicList[0].Sticky {
   167  			lastSticky := 0
   168  			for i, row := range topicList {
   169  				if !row.Sticky {
   170  					lastSticky = i
   171  					break
   172  				}
   173  			}
   174  			if lastSticky == 0 {
   175  				continue
   176  			}
   177  			topicList = topicList[lastSticky:]
   178  		}
   179  
   180  		// TODO: Compare to previous tick to eliminate unnecessary work and data
   181  		wsTopicList := make([]*WsTopicsRow, len(topicList))
   182  		for i, topicRow := range topicList {
   183  			wsTopicList[i] = topicRow.WebSockets()
   184  		}
   185  		canSeeLists[name] = wsTopicList
   186  
   187  		outBytes, err := json.Marshal(&WsTopicList{wsTopicList, 0, tickStart.Unix()})
   188  		if err != nil {
   189  			return err
   190  		}
   191  		canSeeRenders[name] = outBytes
   192  	}
   193  
   194  	// TODO: Use MessagePack for additional speed?
   195  	//fmt.Println("writing to the clients")
   196  	for _, wsUser := range currentWatchers {
   197  		u := wsUser.User
   198  		group := groups[u.Group]
   199  		canSee := make([]byte, len(group.CanSee))
   200  		for i, item := range group.CanSee {
   201  			canSee[i] = byte(item)
   202  		}
   203  		sCanSee := string(canSee)
   204  		l := canSeeLists[sCanSee]
   205  
   206  		// TODO: Optimise this away for guests?
   207  		anyMod, anyLock, anyMove, allMod := false, false, false, true
   208  		var modSet map[int]int
   209  		if u.IsSuperAdmin {
   210  			anyMod = true
   211  			anyLock = true
   212  			anyMove = true
   213  		} else {
   214  			modSet = make(map[int]int, len(l))
   215  			for i, t := range l {
   216  				// TODO: Abstract this?
   217  				fp, e := FPStore.Get(t.ParentID, u.Group)
   218  				if e == ErrNoRows {
   219  					fp = BlankForumPerms()
   220  				} else if e != nil {
   221  					return e
   222  				}
   223  				var ccanMod, ccanLock, ccanMove bool
   224  				if fp.Overrides {
   225  					ccanLock = fp.CloseTopic
   226  					ccanMove = fp.MoveTopic
   227  					ccanMod = t.CreatedBy == u.ID || fp.DeleteTopic || ccanLock || ccanMove
   228  				} else {
   229  					ccanLock = u.Perms.CloseTopic
   230  					ccanMove = u.Perms.MoveTopic
   231  					ccanMod = t.CreatedBy == u.ID || u.Perms.DeleteTopic || ccanLock || ccanMove
   232  				}
   233  				if ccanLock {
   234  					anyLock = true
   235  				}
   236  				if ccanMove {
   237  					anyMove = true
   238  				}
   239  				if ccanMod {
   240  					anyMod = true
   241  				} else {
   242  					allMod = false
   243  				}
   244  				var v int
   245  				if ccanMod {
   246  					v = 1
   247  				}
   248  				modSet[i] = v
   249  			}
   250  		}
   251  
   252  		//fmt.Println("writing to user #", wsUser.User.ID)
   253  		outBytes := canSeeRenders[sCanSee]
   254  		//fmt.Println("outBytes: ", string(outBytes))
   255  		//fmt.Println("outBytes[:len(outBytes)-1]: ", string(outBytes[:len(outBytes)-1]))
   256  		//e := wsUser.WriteToPageBytes(outBytes, "/topics/")
   257  		//e := wsUser.WriteToPageBytesMulti([][]byte{outBytes[:len(outBytes)-1], []byte(`,"mod":1}`)}, "/topics/")
   258  		var e error
   259  		if !anyMod {
   260  			e = wsUser.WriteToPageBytes(outBytes, "/topics/")
   261  		} else {
   262  			var lm []byte
   263  			if anyLock && anyMove {
   264  				lm = []byte(`,"lock":1,"move":1}`)
   265  			} else if anyLock {
   266  				lm = []byte(`,"lock":1}`)
   267  			} else if anyMove {
   268  				lm = []byte(`,"move":1}`)
   269  			} else {
   270  				lm = []byte("}")
   271  			}
   272  			if allMod {
   273  				e = wsUser.WriteToPageBytesMulti([][]byte{outBytes[:len(outBytes)-1], []byte(`,"mod":1`), lm}, "/topics/")
   274  			} else {
   275  				// TODO: Temporary and inefficient
   276  				mBytes, err := json.Marshal(modSet)
   277  				if err != nil {
   278  					return err
   279  				}
   280  				e = wsUser.WriteToPageBytesMulti([][]byte{outBytes[:len(outBytes)-1], []byte(`,"mod":`), mBytes, lm}, "/topics/")
   281  			}
   282  		}
   283  
   284  		if e == ErrNoneOnPage {
   285  			//fmt.Printf("werr for #%d: %s\n", wsUser.User.ID, err)
   286  			wsUser.FinalizePage("/topics/", func() {
   287  				topicListMutex.Lock()
   288  				delete(topicListWatchers, wsUser)
   289  				topicListMutex.Unlock()
   290  			})
   291  			continue
   292  		}
   293  	}
   294  	return nil
   295  }
   296  
   297  func (h *WsHubImpl) GuestCount() int {
   298  	h.GuestLock.RLock()
   299  	defer h.GuestLock.RUnlock()
   300  	return len(h.OnlineGuests)
   301  }
   302  
   303  func (h *WsHubImpl) UserCount() (count int) {
   304  	h.evenUserLock.RLock()
   305  	count += len(h.evenOnlineUsers)
   306  	h.evenUserLock.RUnlock()
   307  
   308  	h.oddUserLock.RLock()
   309  	count += len(h.oddOnlineUsers)
   310  	h.oddUserLock.RUnlock()
   311  	return count
   312  }
   313  
   314  func (h *WsHubImpl) HasUser(uid int) (exists bool) {
   315  	h.evenUserLock.RLock()
   316  	_, exists = h.evenOnlineUsers[uid]
   317  	h.evenUserLock.RUnlock()
   318  	if exists {
   319  		return exists
   320  	}
   321  
   322  	h.oddUserLock.RLock()
   323  	_, exists = h.oddOnlineUsers[uid]
   324  	h.oddUserLock.RUnlock()
   325  	return exists
   326  }
   327  
   328  func (h *WsHubImpl) broadcastMessage(msg string) error {
   329  	userLoop := func(users map[int]*WSUser, m *sync.RWMutex) error {
   330  		m.RLock()
   331  		defer m.RUnlock()
   332  		for _, wsUser := range users {
   333  			e := wsUser.WriteAll(msg)
   334  			if e != nil {
   335  				return e
   336  			}
   337  		}
   338  		return nil
   339  	}
   340  	// TODO: Can we move this RLock inside the closure safely?
   341  	e := userLoop(h.evenOnlineUsers, &h.evenUserLock)
   342  	if e != nil {
   343  		return e
   344  	}
   345  	return userLoop(h.oddOnlineUsers, &h.oddUserLock)
   346  }
   347  
   348  func (h *WsHubImpl) getUser(uid int) (wsUser *WSUser, err error) {
   349  	var ok bool
   350  	if uid%2 == 0 {
   351  		h.evenUserLock.RLock()
   352  		wsUser, ok = h.evenOnlineUsers[uid]
   353  		h.evenUserLock.RUnlock()
   354  	} else {
   355  		h.oddUserLock.RLock()
   356  		wsUser, ok = h.oddOnlineUsers[uid]
   357  		h.oddUserLock.RUnlock()
   358  	}
   359  	if !ok {
   360  		return nil, errWsNouser
   361  	}
   362  	return wsUser, nil
   363  }
   364  
   365  // Warning: For efficiency, some of the *WSUsers may be nil pointers, DO NOT EXPORT
   366  // TODO: Write tests for this
   367  func (h *WsHubImpl) getUsers(uids []int) (wsUsers []*WSUser, err error) {
   368  	if len(uids) == 0 {
   369  		return nil, errWsNouser
   370  	}
   371  	//wsUsers = make([]*WSUser, len(uids))
   372  	//i := 0
   373  	appender := func(l *sync.RWMutex, users map[int]*WSUser) {
   374  		l.RLock()
   375  		defer l.RUnlock()
   376  		// We don't want to keep a lock on this for too long, so we'll accept some nil pointers
   377  		for _, uid := range uids {
   378  			//wsUsers[i] = users[uid]
   379  			wsUsers = append(wsUsers, users[uid])
   380  			//i++
   381  		}
   382  	}
   383  	appender(&h.evenUserLock, h.evenOnlineUsers)
   384  	appender(&h.oddUserLock, h.oddOnlineUsers)
   385  	if len(wsUsers) == 0 {
   386  		return nil, errWsNouser
   387  	}
   388  	return wsUsers, nil
   389  }
   390  
   391  // For Widget WOL, please avoid using this as it might wind up being really long and slow without the right safeguards
   392  func (h *WsHubImpl) AllUsers() (users []*User) {
   393  	appender := func(l *sync.RWMutex, userMap map[int]*WSUser) {
   394  		l.RLock()
   395  		defer l.RUnlock()
   396  		for _, u := range userMap {
   397  			users = append(users, u.User)
   398  		}
   399  	}
   400  	appender(&h.evenUserLock, h.evenOnlineUsers)
   401  	appender(&h.oddUserLock, h.oddOnlineUsers)
   402  	return users
   403  }
   404  
   405  func (h *WsHubImpl) removeUser(uid int) {
   406  	if uid%2 == 0 {
   407  		h.evenUserLock.Lock()
   408  		delete(h.evenOnlineUsers, uid)
   409  		h.evenUserLock.Unlock()
   410  	} else {
   411  		h.oddUserLock.Lock()
   412  		delete(h.oddOnlineUsers, uid)
   413  		h.oddUserLock.Unlock()
   414  	}
   415  }
   416  
   417  func (h *WsHubImpl) AddConn(user *User, conn *websocket.Conn) (*WSUser, error) {
   418  	if user.ID == 0 {
   419  		wsUser := new(WSUser)
   420  		wsUser.User = new(User)
   421  		*wsUser.User = *user
   422  		wsUser.AddSocket(conn, "")
   423  		WsHub.GuestLock.Lock()
   424  		WsHub.OnlineGuests[wsUser] = true
   425  		WsHub.GuestLock.Unlock()
   426  		return wsUser, nil
   427  	}
   428  
   429  	// TODO: How should we handle user state changes if we're holding a pointer which never changes?
   430  	userptr, err := Users.Get(user.ID)
   431  	if err != nil && err != ErrStoreCapacityOverflow {
   432  		return nil, err
   433  	}
   434  
   435  	var mutex *sync.RWMutex
   436  	var theMap map[int]*WSUser
   437  	if user.ID%2 == 0 {
   438  		mutex = &h.evenUserLock
   439  		theMap = h.evenOnlineUsers
   440  	} else {
   441  		mutex = &h.oddUserLock
   442  		theMap = h.oddOnlineUsers
   443  	}
   444  
   445  	mutex.Lock()
   446  	wsUser, ok := theMap[user.ID]
   447  	if !ok {
   448  		wsUser = new(WSUser)
   449  		wsUser.User = userptr
   450  		wsUser.Sockets = []*WSUserSocket{{conn, ""}}
   451  		theMap[user.ID] = wsUser
   452  		mutex.Unlock()
   453  		return wsUser, nil
   454  	}
   455  	mutex.Unlock()
   456  	wsUser.AddSocket(conn, "")
   457  	return wsUser, nil
   458  }
   459  
   460  func (h *WsHubImpl) RemoveConn(wsUser *WSUser, conn *websocket.Conn) {
   461  	wsUser.RemoveSocket(conn)
   462  	wsUser.Lock()
   463  	if len(wsUser.Sockets) == 0 {
   464  		h.removeUser(wsUser.User.ID)
   465  	}
   466  	wsUser.Unlock()
   467  }
   468  
   469  func (h *WsHubImpl) PushMessage(targetUser int, msg string) error {
   470  	wsUser, e := h.getUser(targetUser)
   471  	if e != nil {
   472  		return e
   473  	}
   474  	return wsUser.WriteAll(msg)
   475  }
   476  
   477  func (h *WsHubImpl) pushAlert(targetUser int, a Alert) error {
   478  	wsUser, e := h.getUser(targetUser)
   479  	if e != nil {
   480  		return e
   481  	}
   482  	astr, e := BuildAlert(a, *wsUser.User)
   483  	if e != nil {
   484  		return e
   485  	}
   486  	return wsUser.WriteAll(astr)
   487  }
   488  
   489  func (h *WsHubImpl) pushAlerts(users []int, a Alert) error {
   490  	wsUsers, err := h.getUsers(users)
   491  	if err != nil {
   492  		return err
   493  	}
   494  
   495  	var errs []error
   496  	for _, wsUser := range wsUsers {
   497  		if wsUser == nil {
   498  			continue
   499  		}
   500  		alert, err := BuildAlert(a, *wsUser.User)
   501  		if err != nil {
   502  			errs = append(errs, err)
   503  		}
   504  		err = wsUser.WriteAll(alert)
   505  		if err != nil {
   506  			errs = append(errs, err)
   507  		}
   508  	}
   509  
   510  	// Return the first error
   511  	if len(errs) != 0 {
   512  		for _, e := range errs {
   513  			return e
   514  		}
   515  	}
   516  	return nil
   517  }