github.com/mattermost/mattermost-server/v5@v5.39.3/store/sqlstore/thread_store.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package sqlstore
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"time"
    10  
    11  	sq "github.com/Masterminds/squirrel"
    12  	"github.com/mattermost/gorp"
    13  	"github.com/pkg/errors"
    14  
    15  	"github.com/mattermost/mattermost-server/v5/model"
    16  	"github.com/mattermost/mattermost-server/v5/store"
    17  	"github.com/mattermost/mattermost-server/v5/utils"
    18  )
    19  
// SqlThreadStore is the SQL implementation of store.ThreadStore returned by
// newSqlThreadStore. It reads and writes the Threads and ThreadMemberships
// tables through the embedded SqlStore connections.
type SqlThreadStore struct {
	*SqlStore
}
    23  
    24  func (s *SqlThreadStore) ClearCaches() {
    25  }
    26  
    27  func newSqlThreadStore(sqlStore *SqlStore) store.ThreadStore {
    28  	s := &SqlThreadStore{
    29  		SqlStore: sqlStore,
    30  	}
    31  
    32  	for _, db := range sqlStore.GetAllConns() {
    33  		tableThreads := db.AddTableWithName(model.Thread{}, "Threads").SetKeys(false, "PostId")
    34  		tableThreads.ColMap("PostId").SetMaxSize(26)
    35  		tableThreads.ColMap("ChannelId").SetMaxSize(26)
    36  		tableThreads.ColMap("Participants").SetMaxSize(0)
    37  		tableThreadMemberships := db.AddTableWithName(model.ThreadMembership{}, "ThreadMemberships").SetKeys(false, "PostId", "UserId")
    38  		tableThreadMemberships.ColMap("PostId").SetMaxSize(26)
    39  		tableThreadMemberships.ColMap("UserId").SetMaxSize(26)
    40  	}
    41  
    42  	return s
    43  }
    44  
    45  func threadSliceColumns() []string {
    46  	return []string{"PostId", "ChannelId", "LastReplyAt", "ReplyCount", "Participants"}
    47  }
    48  
    49  func threadToSlice(thread *model.Thread) []interface{} {
    50  	return []interface{}{
    51  		thread.PostId,
    52  		thread.ChannelId,
    53  		thread.LastReplyAt,
    54  		thread.ReplyCount,
    55  		model.ArrayToJson(thread.Participants),
    56  	}
    57  }
    58  
    59  func (s *SqlThreadStore) createIndexesIfNotExists() {
    60  	s.CreateIndexIfNotExists("idx_thread_memberships_last_update_at", "ThreadMemberships", "LastUpdated")
    61  	s.CreateIndexIfNotExists("idx_thread_memberships_last_view_at", "ThreadMemberships", "LastViewed")
    62  	s.CreateIndexIfNotExists("idx_thread_memberships_user_id", "ThreadMemberships", "UserId")
    63  	s.CreateIndexIfNotExists("idx_threads_channel_id", "Threads", "ChannelId")
    64  }
    65  
    66  func (s *SqlThreadStore) SaveMultiple(threads []*model.Thread) ([]*model.Thread, int, error) {
    67  	builder := s.getQueryBuilder().Insert("Threads").Columns(threadSliceColumns()...)
    68  	for _, thread := range threads {
    69  		builder = builder.Values(threadToSlice(thread)...)
    70  	}
    71  	query, args, err := builder.ToSql()
    72  	if err != nil {
    73  		return nil, -1, errors.Wrap(err, "thread_tosql")
    74  	}
    75  
    76  	if _, err := s.GetMaster().Exec(query, args...); err != nil {
    77  		return nil, -1, errors.Wrap(err, "failed to save Post")
    78  	}
    79  
    80  	return threads, -1, nil
    81  }
    82  
    83  func (s *SqlThreadStore) Save(thread *model.Thread) (*model.Thread, error) {
    84  	threads, _, err := s.SaveMultiple([]*model.Thread{thread})
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  	return threads[0], nil
    89  }
    90  
    91  func (s *SqlThreadStore) Update(thread *model.Thread) (*model.Thread, error) {
    92  	return s.update(s.GetMaster(), thread)
    93  }
    94  
    95  func (s *SqlThreadStore) update(ex gorp.SqlExecutor, thread *model.Thread) (*model.Thread, error) {
    96  	if _, err := ex.Update(thread); err != nil {
    97  		return nil, errors.Wrapf(err, "failed to update thread with id=%s", thread.PostId)
    98  	}
    99  
   100  	return thread, nil
   101  }
   102  
   103  func (s *SqlThreadStore) Get(id string) (*model.Thread, error) {
   104  	return s.get(s.GetReplica(), id)
   105  }
   106  
   107  func (s *SqlThreadStore) get(ex gorp.SqlExecutor, id string) (*model.Thread, error) {
   108  	var thread model.Thread
   109  	query, args, _ := s.getQueryBuilder().Select("*").From("Threads").Where(sq.Eq{"PostId": id}).ToSql()
   110  	err := ex.SelectOne(&thread, query, args...)
   111  	if err != nil {
   112  		if err == sql.ErrNoRows {
   113  			return nil, nil
   114  		}
   115  
   116  		return nil, errors.Wrapf(err, "failed to get thread with id=%s", id)
   117  	}
   118  	return &thread, nil
   119  }
   120  
   121  func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
   122  	type JoinedThread struct {
   123  		PostId         string
   124  		ReplyCount     int64
   125  		LastReplyAt    int64
   126  		LastViewedAt   int64
   127  		UnreadReplies  int64
   128  		UnreadMentions int64
   129  		Participants   model.StringArray
   130  		model.Post
   131  	}
   132  
   133  	fetchConditions := sq.And{
   134  		sq.Eq{"ThreadMemberships.UserId": userId},
   135  		sq.Eq{"ThreadMemberships.Following": true},
   136  	}
   137  	if opts.TeamOnly {
   138  		fetchConditions = sq.And{
   139  			sq.Eq{"Channels.TeamId": teamId},
   140  			fetchConditions,
   141  		}
   142  	} else {
   143  		fetchConditions = sq.And{
   144  			sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}},
   145  			fetchConditions,
   146  		}
   147  	}
   148  	if !opts.Deleted {
   149  		fetchConditions = sq.And{
   150  			fetchConditions,
   151  			sq.Eq{"COALESCE(Posts.DeleteAt, 0)": 0},
   152  		}
   153  	}
   154  
   155  	pageSize := uint64(30)
   156  	if opts.PageSize != 0 {
   157  		pageSize = opts.PageSize
   158  	}
   159  
   160  	totalUnreadThreadsChan := make(chan store.StoreResult, 1)
   161  	totalCountChan := make(chan store.StoreResult, 1)
   162  	totalUnreadMentionsChan := make(chan store.StoreResult, 1)
   163  	var threadsChan chan store.StoreResult
   164  	if !opts.TotalsOnly {
   165  		threadsChan = make(chan store.StoreResult, 1)
   166  	}
   167  
   168  	go func() {
   169  		repliesQuery, repliesQueryArgs, _ := s.getQueryBuilder().
   170  			Select("COUNT(DISTINCT(Posts.RootId))").
   171  			From("Posts").
   172  			LeftJoin("ThreadMemberships ON Posts.RootId = ThreadMemberships.PostId").
   173  			LeftJoin("Channels ON Posts.ChannelId = Channels.Id").
   174  			Where(fetchConditions).
   175  			Where("Posts.CreateAt > ThreadMemberships.LastViewed").ToSql()
   176  
   177  		totalUnreadThreads, err := s.GetMaster().SelectInt(repliesQuery, repliesQueryArgs...)
   178  		totalUnreadThreadsChan <- store.StoreResult{Data: totalUnreadThreads, NErr: errors.Wrapf(err, "failed to get count unread on threads for user id=%s", userId)}
   179  		close(totalUnreadThreadsChan)
   180  	}()
   181  	go func() {
   182  		newFetchConditions := fetchConditions
   183  
   184  		if opts.Unread {
   185  			newFetchConditions = sq.And{newFetchConditions, sq.Expr("ThreadMemberships.LastViewed < Threads.LastReplyAt")}
   186  		}
   187  
   188  		threadsQuery, threadsQueryArgs, _ := s.getQueryBuilder().
   189  			Select("COUNT(ThreadMemberships.PostId)").
   190  			LeftJoin("Threads ON Threads.PostId = ThreadMemberships.PostId").
   191  			LeftJoin("Channels ON Threads.ChannelId = Channels.Id").
   192  			LeftJoin("Posts ON Posts.Id = ThreadMemberships.PostId").
   193  			From("ThreadMemberships").
   194  			Where(newFetchConditions).ToSql()
   195  
   196  		totalCount, err := s.GetMaster().SelectInt(threadsQuery, threadsQueryArgs...)
   197  		totalCountChan <- store.StoreResult{Data: totalCount, NErr: err}
   198  		close(totalCountChan)
   199  	}()
   200  	go func() {
   201  		mentionsQuery, mentionsQueryArgs, _ := s.getQueryBuilder().
   202  			Select("COALESCE(SUM(ThreadMemberships.UnreadMentions),0)").
   203  			From("ThreadMemberships").
   204  			LeftJoin("Threads ON Threads.PostId = ThreadMemberships.PostId").
   205  			LeftJoin("Posts ON Posts.Id = ThreadMemberships.PostId").
   206  			LeftJoin("Channels ON Threads.ChannelId = Channels.Id").
   207  			Where(fetchConditions).ToSql()
   208  		totalUnreadMentions, err := s.GetMaster().SelectInt(mentionsQuery, mentionsQueryArgs...)
   209  		totalUnreadMentionsChan <- store.StoreResult{Data: totalUnreadMentions, NErr: err}
   210  		close(totalUnreadMentionsChan)
   211  	}()
   212  
   213  	if !opts.TotalsOnly {
   214  		go func() {
   215  			newFetchConditions := fetchConditions
   216  			if opts.Since > 0 {
   217  				newFetchConditions = sq.And{newFetchConditions, sq.GtOrEq{"ThreadMemberships.LastUpdated": opts.Since}}
   218  			}
   219  			order := "DESC"
   220  			if opts.Before != "" {
   221  				newFetchConditions = sq.And{
   222  					newFetchConditions,
   223  					sq.Expr(`LastReplyAt < (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.Before),
   224  				}
   225  			}
   226  			if opts.After != "" {
   227  				order = "ASC"
   228  				newFetchConditions = sq.And{
   229  					newFetchConditions,
   230  					sq.Expr(`LastReplyAt > (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.After),
   231  				}
   232  			}
   233  			if opts.Unread {
   234  				newFetchConditions = sq.And{newFetchConditions, sq.Expr("ThreadMemberships.LastViewed < Threads.LastReplyAt")}
   235  			}
   236  
   237  			unreadRepliesFetchConditions := sq.And{
   238  				sq.Expr("Posts.RootId = ThreadMemberships.PostId"),
   239  				sq.Expr("Posts.CreateAt > ThreadMemberships.LastViewed"),
   240  			}
   241  			if !opts.Deleted {
   242  				unreadRepliesFetchConditions = sq.And{
   243  					unreadRepliesFetchConditions,
   244  					sq.Expr("Posts.DeleteAt = 0"),
   245  				}
   246  			}
   247  
   248  			unreadRepliesQuery, _ := sq.
   249  				Select("COUNT(Posts.Id)").
   250  				From("Posts").
   251  				Where(unreadRepliesFetchConditions).
   252  				MustSql()
   253  
   254  			var threads []*JoinedThread
   255  			query, args, _ := s.getQueryBuilder().
   256  				Select(`Threads.*,
   257  				` + postSliceCoalesceQuery() + `,
   258  				ThreadMemberships.LastViewed as LastViewedAt,
   259  				ThreadMemberships.UnreadMentions as UnreadMentions`).
   260  				From("Threads").
   261  				Column(sq.Alias(sq.Expr(unreadRepliesQuery), "UnreadReplies")).
   262  				LeftJoin("Posts ON Posts.Id = Threads.PostId").
   263  				LeftJoin("Channels ON Posts.ChannelId = Channels.Id").
   264  				LeftJoin("ThreadMemberships ON ThreadMemberships.PostId = Threads.PostId").
   265  				Where(newFetchConditions).
   266  				OrderBy("Threads.LastReplyAt " + order).
   267  				Limit(pageSize).ToSql()
   268  
   269  			_, err := s.GetReplica().Select(&threads, query, args...)
   270  			threadsChan <- store.StoreResult{Data: threads, NErr: err}
   271  			close(threadsChan)
   272  		}()
   273  	}
   274  
   275  	totalUnreadMentionsResult := <-totalUnreadMentionsChan
   276  	if totalUnreadMentionsResult.NErr != nil {
   277  		return nil, totalUnreadMentionsResult.NErr
   278  	}
   279  	totalUnreadMentions := totalUnreadMentionsResult.Data.(int64)
   280  
   281  	totalCountResult := <-totalCountChan
   282  	if totalCountResult.NErr != nil {
   283  		return nil, totalCountResult.NErr
   284  	}
   285  	totalCount := totalCountResult.Data.(int64)
   286  
   287  	totalUnreadThreadsResult := <-totalUnreadThreadsChan
   288  	if totalUnreadThreadsResult.NErr != nil {
   289  		return nil, totalUnreadThreadsResult.NErr
   290  	}
   291  	totalUnreadThreads := totalUnreadThreadsResult.Data.(int64)
   292  
   293  	var userIds []string
   294  	userIdMap := map[string]bool{}
   295  
   296  	result := &model.Threads{
   297  		Total:               totalCount,
   298  		Threads:             []*model.ThreadResponse{},
   299  		TotalUnreadMentions: totalUnreadMentions,
   300  		TotalUnreadThreads:  totalUnreadThreads,
   301  	}
   302  
   303  	if !opts.TotalsOnly {
   304  		threadsResult := <-threadsChan
   305  		if threadsResult.NErr != nil {
   306  			return nil, threadsResult.NErr
   307  		}
   308  		threads := threadsResult.Data.([]*JoinedThread)
   309  		for _, thread := range threads {
   310  			for _, participantId := range thread.Participants {
   311  				if _, ok := userIdMap[participantId]; !ok {
   312  					userIdMap[participantId] = true
   313  					userIds = append(userIds, participantId)
   314  				}
   315  			}
   316  		}
   317  		var users []*model.User
   318  		if opts.Extended {
   319  			var err error
   320  			users, err = s.User().GetProfileByIds(context.Background(), userIds, &store.UserGetByIdsOpts{}, true)
   321  			if err != nil {
   322  				return nil, errors.Wrapf(err, "failed to get threads for user id=%s", userId)
   323  			}
   324  		} else {
   325  			for _, userId := range userIds {
   326  				users = append(users, &model.User{Id: userId})
   327  			}
   328  		}
   329  
   330  		for _, thread := range threads {
   331  			var participants []*model.User
   332  			for _, participantId := range thread.Participants {
   333  				var participant *model.User
   334  				for _, u := range users {
   335  					if u.Id == participantId {
   336  						participant = u
   337  						break
   338  					}
   339  				}
   340  				if participant == nil {
   341  					return nil, errors.New("cannot find thread participant with id=" + participantId)
   342  				}
   343  				participants = append(participants, participant)
   344  			}
   345  			result.Threads = append(result.Threads, &model.ThreadResponse{
   346  				PostId:         thread.PostId,
   347  				ReplyCount:     thread.ReplyCount,
   348  				LastReplyAt:    thread.LastReplyAt,
   349  				LastViewedAt:   thread.LastViewedAt,
   350  				UnreadReplies:  thread.UnreadReplies,
   351  				UnreadMentions: thread.UnreadMentions,
   352  				Participants:   participants,
   353  				Post:           thread.Post.ToNilIfInvalid(),
   354  			})
   355  		}
   356  	}
   357  
   358  	return result, nil
   359  }
   360  
   361  func (s *SqlThreadStore) GetThreadFollowers(threadID string) ([]string, error) {
   362  	var users []string
   363  	query, args, _ := s.getQueryBuilder().
   364  		Select("ThreadMemberships.UserId").
   365  		From("ThreadMemberships").
   366  		Where(sq.Eq{"PostId": threadID}).ToSql()
   367  	_, err := s.GetReplica().Select(&users, query, args...)
   368  
   369  	if err != nil {
   370  		return nil, err
   371  	}
   372  	return users, nil
   373  }
   374  
   375  func (s *SqlThreadStore) GetThreadForUser(teamId string, threadMembership *model.ThreadMembership, extended bool) (*model.ThreadResponse, error) {
   376  	if !threadMembership.Following {
   377  		return nil, nil // in case the thread is not followed anymore - return nil error to be interpreted as 404
   378  	}
   379  
   380  	type JoinedThread struct {
   381  		PostId         string
   382  		Following      bool
   383  		ReplyCount     int64
   384  		LastReplyAt    int64
   385  		LastViewedAt   int64
   386  		UnreadReplies  int64
   387  		UnreadMentions int64
   388  		Participants   model.StringArray
   389  		model.Post
   390  	}
   391  
   392  	unreadRepliesQuery, unreadRepliesArgs := sq.
   393  		Select("COUNT(Posts.Id)").
   394  		From("Posts").
   395  		Where(sq.And{
   396  			sq.Eq{"Posts.RootId": threadMembership.PostId},
   397  			sq.Gt{"Posts.CreateAt": threadMembership.LastViewed},
   398  			sq.Eq{"Posts.DeleteAt": 0},
   399  		}).MustSql()
   400  
   401  	fetchConditions := sq.And{
   402  		sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}},
   403  		sq.Eq{"Threads.PostId": threadMembership.PostId},
   404  	}
   405  
   406  	var thread JoinedThread
   407  	query, threadArgs, _ := s.getQueryBuilder().
   408  		Select("Threads.*, Posts.*").
   409  		From("Threads").
   410  		Column(sq.Alias(sq.Expr(unreadRepliesQuery), "UnreadReplies")).
   411  		LeftJoin("Posts ON Posts.Id = Threads.PostId").
   412  		LeftJoin("Channels ON Posts.ChannelId = Channels.Id").
   413  		Where(fetchConditions).ToSql()
   414  
   415  	args := append(unreadRepliesArgs, threadArgs...)
   416  
   417  	err := s.GetReplica().SelectOne(&thread, query, args...)
   418  	if err != nil {
   419  		return nil, err
   420  	}
   421  
   422  	thread.LastViewedAt = threadMembership.LastViewed
   423  	thread.UnreadMentions = threadMembership.UnreadMentions
   424  
   425  	var users []*model.User
   426  	if extended {
   427  		var err error
   428  		users, err = s.User().GetProfileByIds(context.Background(), thread.Participants, &store.UserGetByIdsOpts{}, true)
   429  		if err != nil {
   430  			return nil, errors.Wrapf(err, "failed to get thread for user id=%s", threadMembership.UserId)
   431  		}
   432  	} else {
   433  		for _, userId := range thread.Participants {
   434  			users = append(users, &model.User{Id: userId})
   435  		}
   436  	}
   437  
   438  	result := &model.ThreadResponse{
   439  		PostId:         thread.PostId,
   440  		ReplyCount:     thread.ReplyCount,
   441  		LastReplyAt:    thread.LastReplyAt,
   442  		LastViewedAt:   thread.LastViewedAt,
   443  		UnreadReplies:  thread.UnreadReplies,
   444  		UnreadMentions: thread.UnreadMentions,
   445  		Participants:   users,
   446  		Post:           thread.Post.ToNilIfInvalid(),
   447  	}
   448  
   449  	return result, nil
   450  }
   451  func (s *SqlThreadStore) MarkAllAsReadInChannels(userID string, channelIDs []string) error {
   452  	var threadIDs []string
   453  
   454  	query, args, _ := s.getQueryBuilder().
   455  		Select("ThreadMemberships.PostId").
   456  		Join("Threads ON Threads.PostId = ThreadMemberships.PostId").
   457  		Join("Channels ON Threads.ChannelId = Channels.Id").
   458  		From("ThreadMemberships").
   459  		Where(sq.Eq{"Threads.ChannelId": channelIDs}).
   460  		Where(sq.Eq{"ThreadMemberships.UserId": userID}).
   461  		ToSql()
   462  
   463  	_, err := s.GetReplica().Select(&threadIDs, query, args...)
   464  	if err != nil {
   465  		return errors.Wrapf(err, "failed to get thread membership with userid=%s", userID)
   466  	}
   467  
   468  	timestamp := model.GetMillis()
   469  	query, args, _ = s.getQueryBuilder().
   470  		Update("ThreadMemberships").
   471  		Where(sq.Eq{"PostId": threadIDs}).
   472  		Where(sq.Eq{"UserId": userID}).
   473  		Set("LastViewed", timestamp).
   474  		Set("UnreadMentions", 0).
   475  		ToSql()
   476  	if _, err := s.GetMaster().Exec(query, args...); err != nil {
   477  		return errors.Wrapf(err, "failed to update thread read state for user id=%s", userID)
   478  	}
   479  	return nil
   480  
   481  }
   482  func (s *SqlThreadStore) MarkAllAsRead(userId, teamId string) error {
   483  	memberships, err := s.GetMembershipsForUser(userId, teamId)
   484  	if err != nil {
   485  		return err
   486  	}
   487  	var membershipIds []string
   488  	for _, m := range memberships {
   489  		membershipIds = append(membershipIds, m.PostId)
   490  	}
   491  	timestamp := model.GetMillis()
   492  	query, args, _ := s.getQueryBuilder().
   493  		Update("ThreadMemberships").
   494  		Where(sq.Eq{"PostId": membershipIds}).
   495  		Where(sq.Eq{"UserId": userId}).
   496  		Set("LastViewed", timestamp).
   497  		Set("UnreadMentions", 0).
   498  		ToSql()
   499  	if _, err := s.GetMaster().Exec(query, args...); err != nil {
   500  		return errors.Wrapf(err, "failed to update thread read state for user id=%s", userId)
   501  	}
   502  	return nil
   503  }
   504  
   505  func (s *SqlThreadStore) MarkAsRead(userId, threadId string, timestamp int64) error {
   506  	query, args, _ := s.getQueryBuilder().
   507  		Update("ThreadMemberships").
   508  		Where(sq.Eq{"UserId": userId}).
   509  		Where(sq.Eq{"PostId": threadId}).
   510  		Set("LastViewed", timestamp).
   511  		ToSql()
   512  	if _, err := s.GetMaster().Exec(query, args...); err != nil {
   513  		return errors.Wrapf(err, "failed to update thread read state for user id=%s thread_id=%v", userId, threadId)
   514  	}
   515  	return nil
   516  }
   517  
   518  func (s *SqlThreadStore) Delete(threadId string) error {
   519  	query, args, _ := s.getQueryBuilder().Delete("Threads").Where(sq.Eq{"PostId": threadId}).ToSql()
   520  	if _, err := s.GetMaster().Exec(query, args...); err != nil {
   521  		return errors.Wrap(err, "failed to update threads")
   522  	}
   523  
   524  	return nil
   525  }
   526  
   527  func (s *SqlThreadStore) SaveMembership(membership *model.ThreadMembership) (*model.ThreadMembership, error) {
   528  	return s.saveMembership(s.GetMaster(), membership)
   529  }
   530  
   531  func (s *SqlThreadStore) saveMembership(ex gorp.SqlExecutor, membership *model.ThreadMembership) (*model.ThreadMembership, error) {
   532  	if err := ex.Insert(membership); err != nil {
   533  		return nil, errors.Wrapf(err, "failed to save thread membership with postid=%s userid=%s", membership.PostId, membership.UserId)
   534  	}
   535  
   536  	return membership, nil
   537  }
   538  
   539  func (s *SqlThreadStore) UpdateMembership(membership *model.ThreadMembership) (*model.ThreadMembership, error) {
   540  	return s.updateMembership(s.GetMaster(), membership)
   541  }
   542  
   543  func (s *SqlThreadStore) updateMembership(ex gorp.SqlExecutor, membership *model.ThreadMembership) (*model.ThreadMembership, error) {
   544  	if _, err := ex.Update(membership); err != nil {
   545  		return nil, errors.Wrapf(err, "failed to update thread membership with postid=%s userid=%s", membership.PostId, membership.UserId)
   546  	}
   547  
   548  	return membership, nil
   549  }
   550  
   551  func (s *SqlThreadStore) GetMembershipsForUser(userId, teamId string) ([]*model.ThreadMembership, error) {
   552  	var memberships []*model.ThreadMembership
   553  
   554  	query, args, _ := s.getQueryBuilder().
   555  		Select("ThreadMemberships.*").
   556  		Join("Threads ON Threads.PostId = ThreadMemberships.PostId").
   557  		Join("Channels ON Threads.ChannelId = Channels.Id").
   558  		From("ThreadMemberships").
   559  		Where(sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}).
   560  		Where(sq.Eq{"ThreadMemberships.UserId": userId}).
   561  		ToSql()
   562  
   563  	_, err := s.GetReplica().Select(&memberships, query, args...)
   564  	if err != nil {
   565  		return nil, errors.Wrapf(err, "failed to get thread membership with userid=%s", userId)
   566  	}
   567  	return memberships, nil
   568  }
   569  
   570  func (s *SqlThreadStore) GetMembershipForUser(userId, postId string) (*model.ThreadMembership, error) {
   571  	return s.getMembershipForUser(s.GetReplica(), userId, postId)
   572  }
   573  
   574  func (s *SqlThreadStore) getMembershipForUser(ex gorp.SqlExecutor, userId, postId string) (*model.ThreadMembership, error) {
   575  	var membership model.ThreadMembership
   576  	err := ex.SelectOne(&membership, "SELECT * from ThreadMemberships WHERE UserId = :UserId AND PostId = :PostId", map[string]interface{}{"UserId": userId, "PostId": postId})
   577  	if err != nil {
   578  		if err == sql.ErrNoRows {
   579  			return nil, store.NewErrNotFound("Thread", postId)
   580  		}
   581  		return nil, errors.Wrapf(err, "failed to get thread membership with userid=%s postid=%s", userId, postId)
   582  	}
   583  	return &membership, nil
   584  }
   585  
   586  func (s *SqlThreadStore) DeleteMembershipForUser(userId string, postId string) error {
   587  	if _, err := s.GetMaster().Exec("DELETE FROM ThreadMemberships Where PostId = :PostId AND UserId = :UserId", map[string]interface{}{"PostId": postId, "UserId": userId}); err != nil {
   588  		return errors.Wrap(err, "failed to update thread membership")
   589  	}
   590  
   591  	return nil
   592  }
   593  
   594  // MaintainMembership creates or updates a thread membership for the given user
   595  // and post. This method is used to update the state of a membership in response
   596  // to some events like:
   597  // - post creation (mentions handling)
   598  // - channel marked unread
   599  // - user explicitly following a thread
   600  func (s *SqlThreadStore) MaintainMembership(userId, postId string, opts store.ThreadMembershipOpts) (*model.ThreadMembership, error) {
   601  	trx, err := s.GetMaster().Begin()
   602  	if err != nil {
   603  		return nil, errors.Wrap(err, "begin_transaction")
   604  	}
   605  	defer finalizeTransaction(trx)
   606  
   607  	membership, err := s.getMembershipForUser(trx, userId, postId)
   608  	now := utils.MillisFromTime(time.Now())
   609  	// if memebership exists, update it if:
   610  	// a. user started/stopped following a thread
   611  	// b. mention count changed
   612  	// c. user viewed a thread
   613  	if err == nil {
   614  		followingNeedsUpdate := (opts.UpdateFollowing && (membership.Following != opts.Following))
   615  		if followingNeedsUpdate || opts.IncrementMentions || opts.UpdateViewedTimestamp {
   616  			if followingNeedsUpdate {
   617  				membership.Following = opts.Following
   618  			}
   619  			if opts.UpdateViewedTimestamp {
   620  				membership.LastViewed = now
   621  			}
   622  			membership.LastUpdated = now
   623  			if opts.IncrementMentions {
   624  				membership.UnreadMentions += 1
   625  			}
   626  			if _, err = s.updateMembership(trx, membership); err != nil {
   627  				return nil, err
   628  			}
   629  		}
   630  
   631  		if err = trx.Commit(); err != nil {
   632  			return nil, errors.Wrap(err, "commit_transaction")
   633  		}
   634  
   635  		return membership, err
   636  	}
   637  
   638  	var nfErr *store.ErrNotFound
   639  	if !errors.As(err, &nfErr) {
   640  		return nil, errors.Wrap(err, "failed to get thread membership")
   641  	}
   642  
   643  	membership = &model.ThreadMembership{
   644  		PostId:      postId,
   645  		UserId:      userId,
   646  		Following:   opts.Following,
   647  		LastUpdated: now,
   648  	}
   649  	if opts.IncrementMentions {
   650  		membership.UnreadMentions = 1
   651  	}
   652  	if opts.UpdateViewedTimestamp {
   653  		membership.LastViewed = now
   654  	}
   655  	membership, err = s.saveMembership(trx, membership)
   656  	if err != nil {
   657  		return nil, err
   658  	}
   659  
   660  	if opts.UpdateParticipants {
   661  		thread, getErr := s.get(trx, postId)
   662  		if getErr != nil {
   663  			return nil, getErr
   664  		}
   665  		if thread != nil && !thread.Participants.Contains(userId) {
   666  			thread.Participants = append(thread.Participants, userId)
   667  			if _, err = s.update(trx, thread); err != nil {
   668  				return nil, err
   669  			}
   670  		}
   671  	}
   672  
   673  	if err = trx.Commit(); err != nil {
   674  		return nil, errors.Wrap(err, "commit_transaction")
   675  	}
   676  
   677  	return membership, err
   678  }
   679  
   680  func (s *SqlThreadStore) CollectThreadsWithNewerReplies(userId string, channelIds []string, timestamp int64) ([]string, error) {
   681  	var changedThreads []string
   682  	query, args, _ := s.getQueryBuilder().
   683  		Select("Threads.PostId").
   684  		From("Threads").
   685  		LeftJoin("ChannelMembers ON ChannelMembers.ChannelId=Threads.ChannelId").
   686  		Where(sq.And{
   687  			sq.Eq{"Threads.ChannelId": channelIds},
   688  			sq.Eq{"ChannelMembers.UserId": userId},
   689  			sq.Or{
   690  				sq.Expr("Threads.LastReplyAt > ChannelMembers.LastViewedAt"),
   691  				sq.Gt{"Threads.LastReplyAt": timestamp},
   692  			},
   693  		}).
   694  		ToSql()
   695  	if _, err := s.GetReplica().Select(&changedThreads, query, args...); err != nil {
   696  		return nil, errors.Wrap(err, "failed to fetch threads")
   697  	}
   698  	return changedThreads, nil
   699  }
   700  
   701  func (s *SqlThreadStore) UpdateUnreadsByChannel(userId string, changedThreads []string, timestamp int64, updateViewedTimestamp bool) error {
   702  	if len(changedThreads) == 0 {
   703  		return nil
   704  	}
   705  
   706  	qb := s.getQueryBuilder().
   707  		Update("ThreadMemberships").
   708  		Where(sq.Eq{"UserId": userId, "PostId": changedThreads}).
   709  		Set("LastUpdated", timestamp)
   710  
   711  	if updateViewedTimestamp {
   712  		qb = qb.Set("LastViewed", timestamp)
   713  	}
   714  	updateQuery, updateArgs, _ := qb.ToSql()
   715  
   716  	if _, err := s.GetMaster().Exec(updateQuery, updateArgs...); err != nil {
   717  		return errors.Wrap(err, "failed to update thread membership")
   718  	}
   719  
   720  	return nil
   721  }
   722  
   723  func (s *SqlThreadStore) GetPosts(threadId string, since int64) ([]*model.Post, error) {
   724  	query, args, _ := s.getQueryBuilder().
   725  		Select("*").
   726  		From("Posts").
   727  		Where(sq.Eq{"RootId": threadId}).
   728  		Where(sq.Eq{"DeleteAt": 0}).
   729  		Where(sq.GtOrEq{"UpdateAt": since}).ToSql()
   730  	var result []*model.Post
   731  	if _, err := s.GetReplica().Select(&result, query, args...); err != nil {
   732  		return nil, errors.Wrap(err, "failed to fetch thread posts")
   733  	}
   734  	return result, nil
   735  }
   736  
   737  // PermanentDeleteBatchForRetentionPolicies deletes a batch of records which are affected by
   738  // the global or a granular retention policy.
   739  // See `genericPermanentDeleteBatchForRetentionPolicies` for details.
   740  func (s *SqlThreadStore) PermanentDeleteBatchForRetentionPolicies(now, globalPolicyEndTime, limit int64, cursor model.RetentionPolicyCursor) (int64, model.RetentionPolicyCursor, error) {
   741  	builder := s.getQueryBuilder().
   742  		Select("Threads.PostId").
   743  		From("Threads")
   744  	return genericPermanentDeleteBatchForRetentionPolicies(RetentionPolicyBatchDeletionInfo{
   745  		BaseBuilder:         builder,
   746  		Table:               "Threads",
   747  		TimeColumn:          "LastReplyAt",
   748  		PrimaryKeys:         []string{"PostId"},
   749  		ChannelIDTable:      "Threads",
   750  		NowMillis:           now,
   751  		GlobalPolicyEndTime: globalPolicyEndTime,
   752  		Limit:               limit,
   753  	}, s.SqlStore, cursor)
   754  }
   755  
   756  // PermanentDeleteBatchThreadMembershipsForRetentionPolicies deletes a batch of records
   757  // which are affected by the global or a granular retention policy.
   758  // See `genericPermanentDeleteBatchForRetentionPolicies` for details.
   759  func (s *SqlThreadStore) PermanentDeleteBatchThreadMembershipsForRetentionPolicies(now, globalPolicyEndTime, limit int64, cursor model.RetentionPolicyCursor) (int64, model.RetentionPolicyCursor, error) {
   760  	builder := s.getQueryBuilder().
   761  		Select("ThreadMemberships.PostId").
   762  		From("ThreadMemberships").
   763  		InnerJoin("Threads ON ThreadMemberships.PostId = Threads.PostId")
   764  	return genericPermanentDeleteBatchForRetentionPolicies(RetentionPolicyBatchDeletionInfo{
   765  		BaseBuilder:         builder,
   766  		Table:               "ThreadMemberships",
   767  		TimeColumn:          "LastUpdated",
   768  		PrimaryKeys:         []string{"PostId"},
   769  		ChannelIDTable:      "Threads",
   770  		NowMillis:           now,
   771  		GlobalPolicyEndTime: globalPolicyEndTime,
   772  		Limit:               limit,
   773  	}, s.SqlStore, cursor)
   774  }
   775  
   776  // DeleteOrphanedRows removes orphaned rows from Threads and ThreadMemberships
   777  func (s *SqlThreadStore) DeleteOrphanedRows(limit int) (deleted int64, err error) {
   778  	// We need the extra level of nesting to deal with MySQL's locking
   779  	const threadsQuery = `
   780  	DELETE FROM Threads WHERE PostId IN (
   781  		SELECT * FROM (
   782  			SELECT Threads.PostId FROM Threads
   783  			LEFT JOIN Channels ON Threads.ChannelId = Channels.Id
   784  			WHERE Channels.Id IS NULL
   785  			LIMIT :Limit
   786  		) AS A
   787  	)`
   788  	// We only delete a thread membership if the entire thread no longer exists,
   789  	// not if the root post has been deleted
   790  	const threadMembershipsQuery = `
   791  	DELETE FROM ThreadMemberships WHERE PostId IN (
   792  		SELECT * FROM (
   793  			SELECT ThreadMemberships.PostId FROM ThreadMemberships
   794  			LEFT JOIN Threads ON ThreadMemberships.PostId = Threads.PostId
   795  			WHERE Threads.PostId IS NULL
   796  			LIMIT :Limit
   797  		) AS A
   798  	)`
   799  	props := map[string]interface{}{"Limit": limit}
   800  	result, err := s.GetMaster().Exec(threadsQuery, props)
   801  	if err != nil {
   802  		return
   803  	}
   804  	rpcDeleted, err := result.RowsAffected()
   805  	if err != nil {
   806  		return
   807  	}
   808  	result, err = s.GetMaster().Exec(threadMembershipsQuery, props)
   809  	if err != nil {
   810  		return
   811  	}
   812  	rptDeleted, err := result.RowsAffected()
   813  	if err != nil {
   814  		return
   815  	}
   816  	deleted = rpcDeleted + rptDeleted
   817  	return
   818  }