github.com/masterhung0112/hk_server/v5@v5.0.0-20220302090640-ec71aef15e1c/store/sqlstore/thread_store.go (about)

     1  package sqlstore
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"time"
     7  
     8  	sq "github.com/Masterminds/squirrel"
     9  
    10  	"github.com/mattermost/gorp"
    11  	"github.com/pkg/errors"
    12  
    13  	"github.com/masterhung0112/hk_server/v5/model"
    14  	"github.com/masterhung0112/hk_server/v5/store"
    15  	"github.com/masterhung0112/hk_server/v5/utils"
    16  )
    17  
    18  type SqlThreadStore struct {
    19  	*SqlStore
    20  }
    21  
    22  func (s *SqlThreadStore) ClearCaches() {
    23  }
    24  
    25  func newSqlThreadStore(sqlStore *SqlStore) store.ThreadStore {
    26  	s := &SqlThreadStore{
    27  		SqlStore: sqlStore,
    28  	}
    29  
    30  	for _, db := range sqlStore.GetAllConns() {
    31  		tableThreads := db.AddTableWithName(model.Thread{}, "Threads").SetKeys(false, "PostId")
    32  		tableThreads.ColMap("PostId").SetMaxSize(26)
    33  		tableThreads.ColMap("ChannelId").SetMaxSize(26)
    34  		tableThreads.ColMap("Participants").SetMaxSize(0)
    35  		tableThreadMemberships := db.AddTableWithName(model.ThreadMembership{}, "ThreadMemberships").SetKeys(false, "PostId", "UserId")
    36  		tableThreadMemberships.ColMap("PostId").SetMaxSize(26)
    37  		tableThreadMemberships.ColMap("UserId").SetMaxSize(26)
    38  	}
    39  
    40  	return s
    41  }
    42  
    43  func threadSliceColumns() []string {
    44  	return []string{"PostId", "ChannelId", "LastReplyAt", "ReplyCount", "Participants"}
    45  }
    46  
    47  func threadToSlice(thread *model.Thread) []interface{} {
    48  	return []interface{}{
    49  		thread.PostId,
    50  		thread.ChannelId,
    51  		thread.LastReplyAt,
    52  		thread.ReplyCount,
    53  		thread.Participants,
    54  	}
    55  }
    56  
    57  func (s *SqlThreadStore) createIndexesIfNotExists() {
    58  	s.CreateIndexIfNotExists("idx_thread_memberships_last_update_at", "ThreadMemberships", "LastUpdated")
    59  	s.CreateIndexIfNotExists("idx_thread_memberships_last_view_at", "ThreadMemberships", "LastViewed")
    60  	s.CreateIndexIfNotExists("idx_thread_memberships_user_id", "ThreadMemberships", "UserId")
    61  	s.CreateIndexIfNotExists("idx_threads_channel_id", "Threads", "ChannelId")
    62  }
    63  
    64  func (s *SqlThreadStore) SaveMultiple(threads []*model.Thread) ([]*model.Thread, int, error) {
    65  	builder := s.getQueryBuilder().Insert("Threads").Columns(threadSliceColumns()...)
    66  	for _, thread := range threads {
    67  		builder = builder.Values(threadToSlice(thread)...)
    68  	}
    69  	query, args, err := builder.ToSql()
    70  	if err != nil {
    71  		return nil, -1, errors.Wrap(err, "thread_tosql")
    72  	}
    73  
    74  	if _, err := s.GetMaster().Exec(query, args...); err != nil {
    75  		return nil, -1, errors.Wrap(err, "failed to save Post")
    76  	}
    77  
    78  	return threads, -1, nil
    79  }
    80  
    81  func (s *SqlThreadStore) Save(thread *model.Thread) (*model.Thread, error) {
    82  	threads, _, err := s.SaveMultiple([]*model.Thread{thread})
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	return threads[0], nil
    87  }
    88  
    89  func (s *SqlThreadStore) Update(thread *model.Thread) (*model.Thread, error) {
    90  	return s.update(s.GetMaster(), thread)
    91  }
    92  
    93  func (s *SqlThreadStore) update(ex gorp.SqlExecutor, thread *model.Thread) (*model.Thread, error) {
    94  	if _, err := ex.Update(thread); err != nil {
    95  		return nil, errors.Wrapf(err, "failed to update thread with id=%s", thread.PostId)
    96  	}
    97  
    98  	return thread, nil
    99  }
   100  
   101  func (s *SqlThreadStore) Get(id string) (*model.Thread, error) {
   102  	return s.get(s.GetReplica(), id)
   103  }
   104  
   105  func (s *SqlThreadStore) get(ex gorp.SqlExecutor, id string) (*model.Thread, error) {
   106  	var thread model.Thread
   107  	query, args, _ := s.getQueryBuilder().Select("*").From("Threads").Where(sq.Eq{"PostId": id}).ToSql()
   108  	err := ex.SelectOne(&thread, query, args...)
   109  	if err != nil {
   110  		if err == sql.ErrNoRows {
   111  			return nil, nil
   112  		}
   113  
   114  		return nil, errors.Wrapf(err, "failed to get thread with id=%s", id)
   115  	}
   116  	return &thread, nil
   117  }
   118  
   119  func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) {
   120  	type JoinedThread struct {
   121  		PostId         string
   122  		ReplyCount     int64
   123  		LastReplyAt    int64
   124  		LastViewedAt   int64
   125  		UnreadReplies  int64
   126  		UnreadMentions int64
   127  		Participants   model.StringArray
   128  		model.Post
   129  	}
   130  
   131  	unreadRepliesQuery := "SELECT COUNT(Posts.Id) From Posts Where Posts.RootId=ThreadMemberships.PostId AND Posts.CreateAt >= ThreadMemberships.LastViewed"
   132  	fetchConditions := sq.And{
   133  		sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}},
   134  		sq.Eq{"ThreadMemberships.UserId": userId},
   135  		sq.Eq{"ThreadMemberships.Following": true},
   136  	}
   137  	if !opts.Deleted {
   138  		fetchConditions = sq.And{
   139  			fetchConditions,
   140  			sq.Eq{"COALESCE(Posts.DeleteAt, 0)": 0},
   141  		}
   142  	}
   143  
   144  	pageSize := uint64(30)
   145  	if opts.PageSize != 0 {
   146  		pageSize = opts.PageSize
   147  	}
   148  
   149  	totalUnreadThreadsChan := make(chan store.StoreResult, 1)
   150  	totalCountChan := make(chan store.StoreResult, 1)
   151  	totalUnreadMentionsChan := make(chan store.StoreResult, 1)
   152  	threadsChan := make(chan store.StoreResult, 1)
   153  	go func() {
   154  		repliesQuery, repliesQueryArgs, _ := s.getQueryBuilder().
   155  			Select("COUNT(DISTINCT(Posts.RootId))").
   156  			From("Posts").
   157  			LeftJoin("ThreadMemberships ON Posts.RootId = ThreadMemberships.PostId").
   158  			LeftJoin("Channels ON Posts.ChannelId = Channels.Id").
   159  			Where(fetchConditions).
   160  			Where("Posts.CreateAt >= ThreadMemberships.LastViewed").ToSql()
   161  
   162  		totalUnreadThreads, err := s.GetMaster().SelectInt(repliesQuery, repliesQueryArgs...)
   163  		totalUnreadThreadsChan <- store.StoreResult{Data: totalUnreadThreads, NErr: errors.Wrapf(err, "failed to get count unread on threads for user id=%s", userId)}
   164  		close(totalUnreadThreadsChan)
   165  	}()
   166  	go func() {
   167  		newFetchConditions := fetchConditions
   168  
   169  		if opts.Unread {
   170  			newFetchConditions = sq.And{newFetchConditions, sq.Expr("ThreadMemberships.LastViewed < Threads.LastReplyAt")}
   171  		}
   172  
   173  		threadsQuery, threadsQueryArgs, _ := s.getQueryBuilder().
   174  			Select("COUNT(ThreadMemberships.PostId)").
   175  			LeftJoin("Threads ON Threads.PostId = ThreadMemberships.PostId").
   176  			LeftJoin("Channels ON Threads.ChannelId = Channels.Id").
   177  			LeftJoin("Posts ON Posts.Id = ThreadMemberships.PostId").
   178  			From("ThreadMemberships").
   179  			Where(newFetchConditions).ToSql()
   180  
   181  		totalCount, err := s.GetMaster().SelectInt(threadsQuery, threadsQueryArgs...)
   182  		totalCountChan <- store.StoreResult{Data: totalCount, NErr: err}
   183  		close(totalCountChan)
   184  	}()
   185  	go func() {
   186  		mentionsQuery, mentionsQueryArgs, _ := s.getQueryBuilder().
   187  			Select("COALESCE(SUM(ThreadMemberships.UnreadMentions),0)").
   188  			From("ThreadMemberships").
   189  			LeftJoin("Threads ON Threads.PostId = ThreadMemberships.PostId").
   190  			LeftJoin("Posts ON Posts.Id = ThreadMemberships.PostId").
   191  			LeftJoin("Channels ON Threads.ChannelId = Channels.Id").
   192  			Where(fetchConditions).ToSql()
   193  		totalUnreadMentions, err := s.GetMaster().SelectInt(mentionsQuery, mentionsQueryArgs...)
   194  		totalUnreadMentionsChan <- store.StoreResult{Data: totalUnreadMentions, NErr: err}
   195  		close(totalUnreadMentionsChan)
   196  	}()
   197  	go func() {
   198  		newFetchConditions := fetchConditions
   199  		if opts.Since > 0 {
   200  			newFetchConditions = sq.And{newFetchConditions, sq.GtOrEq{"ThreadMemberships.LastUpdated": opts.Since}}
   201  		}
   202  		order := "DESC"
   203  		if opts.Before != "" {
   204  			newFetchConditions = sq.And{
   205  				newFetchConditions,
   206  				sq.Expr(`LastReplyAt < (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.Before),
   207  			}
   208  		}
   209  		if opts.After != "" {
   210  			order = "ASC"
   211  			newFetchConditions = sq.And{
   212  				newFetchConditions,
   213  				sq.Expr(`LastReplyAt > (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.After),
   214  			}
   215  		}
   216  		if opts.Unread {
   217  			newFetchConditions = sq.And{newFetchConditions, sq.Expr("ThreadMemberships.LastViewed < Threads.LastReplyAt")}
   218  		}
   219  		var threads []*JoinedThread
   220  		query, args, _ := s.getQueryBuilder().
   221  			Select(`Threads.*,
   222  				` + postSliceCoalesceQuery() + `,
   223  				ThreadMemberships.LastViewed as LastViewedAt,
   224  				ThreadMemberships.UnreadMentions as UnreadMentions`).
   225  			From("Threads").
   226  			Column(sq.Alias(sq.Expr(unreadRepliesQuery), "UnreadReplies")).
   227  			LeftJoin("Posts ON Posts.Id = Threads.PostId").
   228  			LeftJoin("Channels ON Posts.ChannelId = Channels.Id").
   229  			LeftJoin("ThreadMemberships ON ThreadMemberships.PostId = Threads.PostId").
   230  			Where(newFetchConditions).
   231  			OrderBy("Threads.LastReplyAt " + order).
   232  			Limit(pageSize).ToSql()
   233  
   234  		_, err := s.GetReplica().Select(&threads, query, args...)
   235  		threadsChan <- store.StoreResult{Data: threads, NErr: err}
   236  		close(threadsChan)
   237  	}()
   238  
   239  	threadsResult := <-threadsChan
   240  	if threadsResult.NErr != nil {
   241  		return nil, threadsResult.NErr
   242  	}
   243  	threads := threadsResult.Data.([]*JoinedThread)
   244  
   245  	totalUnreadMentionsResult := <-totalUnreadMentionsChan
   246  	if totalUnreadMentionsResult.NErr != nil {
   247  		return nil, totalUnreadMentionsResult.NErr
   248  	}
   249  	totalUnreadMentions := totalUnreadMentionsResult.Data.(int64)
   250  
   251  	totalCountResult := <-totalCountChan
   252  	if totalCountResult.NErr != nil {
   253  		return nil, totalCountResult.NErr
   254  	}
   255  	totalCount := totalCountResult.Data.(int64)
   256  
   257  	totalUnreadThreadsResult := <-totalUnreadThreadsChan
   258  	if totalUnreadThreadsResult.NErr != nil {
   259  		return nil, totalUnreadThreadsResult.NErr
   260  	}
   261  	totalUnreadThreads := totalUnreadThreadsResult.Data.(int64)
   262  
   263  	var userIds []string
   264  	userIdMap := map[string]bool{}
   265  	for _, thread := range threads {
   266  		for _, participantId := range thread.Participants {
   267  			if _, ok := userIdMap[participantId]; !ok {
   268  				userIdMap[participantId] = true
   269  				userIds = append(userIds, participantId)
   270  			}
   271  		}
   272  	}
   273  	var users []*model.User
   274  	if opts.Extended {
   275  		var err error
   276  		users, err = s.User().GetProfileByIds(context.Background(), userIds, &store.UserGetByIdsOpts{}, true)
   277  		if err != nil {
   278  			return nil, errors.Wrapf(err, "failed to get threads for user id=%s", userId)
   279  		}
   280  	} else {
   281  		for _, userId := range userIds {
   282  			users = append(users, &model.User{Id: userId})
   283  		}
   284  	}
   285  
   286  	result := &model.Threads{
   287  		Total:               totalCount,
   288  		Threads:             []*model.ThreadResponse{},
   289  		TotalUnreadMentions: totalUnreadMentions,
   290  		TotalUnreadThreads:  totalUnreadThreads,
   291  	}
   292  
   293  	for _, thread := range threads {
   294  		var participants []*model.User
   295  		for _, participantId := range thread.Participants {
   296  			var participant *model.User
   297  			for _, u := range users {
   298  				if u.Id == participantId {
   299  					participant = u
   300  					break
   301  				}
   302  			}
   303  			if participant == nil {
   304  				return nil, errors.New("cannot find thread participant with id=" + participantId)
   305  			}
   306  			participants = append(participants, participant)
   307  		}
   308  		result.Threads = append(result.Threads, &model.ThreadResponse{
   309  			PostId:         thread.PostId,
   310  			ReplyCount:     thread.ReplyCount,
   311  			LastReplyAt:    thread.LastReplyAt,
   312  			LastViewedAt:   thread.LastViewedAt,
   313  			UnreadReplies:  thread.UnreadReplies,
   314  			UnreadMentions: thread.UnreadMentions,
   315  			Participants:   participants,
   316  			Post:           thread.Post.ToNilIfInvalid(),
   317  		})
   318  	}
   319  
   320  	return result, nil
   321  }
   322  
   323  func (s *SqlThreadStore) GetThreadFollowers(threadID string) ([]string, error) {
   324  	var users []string
   325  	query, args, _ := s.getQueryBuilder().
   326  		Select("ThreadMemberships.UserId").
   327  		From("ThreadMemberships").
   328  		Where(sq.Eq{"PostId": threadID}).ToSql()
   329  	_, err := s.GetReplica().Select(&users, query, args...)
   330  
   331  	if err != nil {
   332  		return nil, err
   333  	}
   334  	return users, nil
   335  }
   336  
   337  func (s *SqlThreadStore) GetThreadForUser(teamId string, threadMembership *model.ThreadMembership, extended bool) (*model.ThreadResponse, error) {
   338  	if !threadMembership.Following {
   339  		return nil, nil // in case the thread is not followed anymore - return nil error to be interpreted as 404
   340  	}
   341  	type JoinedThread struct {
   342  		PostId         string
   343  		Following      bool
   344  		ReplyCount     int64
   345  		LastReplyAt    int64
   346  		LastViewedAt   int64
   347  		UnreadReplies  int64
   348  		UnreadMentions int64
   349  		Participants   model.StringArray
   350  		model.Post
   351  	}
   352  
   353  	unreadRepliesQuery, unreadRepliesArgs := sq.
   354  		Select("COUNT(Posts.Id)").
   355  		From("Posts").
   356  		Where(sq.And{
   357  			sq.Eq{"Posts.RootId": threadMembership.PostId},
   358  			sq.GtOrEq{"Posts.CreateAt": threadMembership.LastViewed},
   359  			sq.Eq{"Posts.DeleteAt": 0},
   360  		}).MustSql()
   361  	fetchConditions := sq.And{
   362  		sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}},
   363  		sq.Eq{"Threads.PostId": threadMembership.PostId},
   364  	}
   365  
   366  	var thread JoinedThread
   367  	query, threadArgs, _ := s.getQueryBuilder().
   368  		Select("Threads.*, Posts.*").
   369  		From("Threads").
   370  		Column(sq.Alias(sq.Expr(unreadRepliesQuery), "UnreadReplies")).
   371  		LeftJoin("Posts ON Posts.Id = Threads.PostId").
   372  		LeftJoin("Channels ON Posts.ChannelId = Channels.Id").
   373  		Where(fetchConditions).ToSql()
   374  	args := append(unreadRepliesArgs, threadArgs...)
   375  	err := s.GetReplica().SelectOne(&thread, query, args...)
   376  
   377  	if err != nil {
   378  		return nil, err
   379  	}
   380  
   381  	thread.LastViewedAt = threadMembership.LastViewed
   382  	thread.UnreadMentions = threadMembership.UnreadMentions
   383  
   384  	var users []*model.User
   385  	if extended {
   386  		var err error
   387  		users, err = s.User().GetProfileByIds(context.Background(), thread.Participants, &store.UserGetByIdsOpts{}, true)
   388  		if err != nil {
   389  			return nil, errors.Wrapf(err, "failed to get thread for user id=%s", threadMembership.UserId)
   390  		}
   391  	} else {
   392  		for _, userId := range thread.Participants {
   393  			users = append(users, &model.User{Id: userId})
   394  		}
   395  	}
   396  
   397  	result := &model.ThreadResponse{
   398  		PostId:         thread.PostId,
   399  		ReplyCount:     thread.ReplyCount,
   400  		LastReplyAt:    thread.LastReplyAt,
   401  		LastViewedAt:   thread.LastViewedAt,
   402  		UnreadReplies:  thread.UnreadReplies,
   403  		UnreadMentions: thread.UnreadMentions,
   404  		Participants:   users,
   405  		Post:           thread.Post.ToNilIfInvalid(),
   406  	}
   407  
   408  	return result, nil
   409  }
   410  
   411  func (s *SqlThreadStore) MarkAllAsReadInChannels(userID string, channelIDs []string) error {
   412  	var threadIDs []string
   413  
   414  	query, args, _ := s.getQueryBuilder().
   415  		Select("ThreadMemberships.PostId").
   416  		Join("Threads ON Threads.PostId = ThreadMemberships.PostId").
   417  		Join("Channels ON Threads.ChannelId = Channels.Id").
   418  		From("ThreadMemberships").
   419  		Where(sq.Eq{"Threads.ChannelId": channelIDs}).
   420  		Where(sq.Eq{"ThreadMemberships.UserId": userID}).
   421  		ToSql()
   422  
   423  	_, err := s.GetReplica().Select(&threadIDs, query, args...)
   424  	if err != nil {
   425  		return errors.Wrapf(err, "failed to get thread membership with userid=%s", userID)
   426  	}
   427  
   428  	timestamp := model.GetMillis()
   429  	query, args, _ = s.getQueryBuilder().
   430  		Update("ThreadMemberships").
   431  		Where(sq.Eq{"PostId": threadIDs}).
   432  		Where(sq.Eq{"UserId": userID}).
   433  		Set("LastViewed", timestamp).
   434  		Set("UnreadMentions", 0).
   435  		ToSql()
   436  	if _, err := s.GetMaster().Exec(query, args...); err != nil {
   437  		return errors.Wrapf(err, "failed to update thread read state for user id=%s", userID)
   438  	}
   439  	return nil
   440  
   441  }
   442  
   443  func (s *SqlThreadStore) MarkAllAsRead(userId, teamId string) error {
   444  	memberships, err := s.GetMembershipsForUser(userId, teamId)
   445  	if err != nil {
   446  		return err
   447  	}
   448  	var membershipIds []string
   449  	for _, m := range memberships {
   450  		membershipIds = append(membershipIds, m.PostId)
   451  	}
   452  	timestamp := model.GetMillis()
   453  	query, args, _ := s.getQueryBuilder().
   454  		Update("ThreadMemberships").
   455  		Where(sq.Eq{"PostId": membershipIds}).
   456  		Where(sq.Eq{"UserId": userId}).
   457  		Set("LastViewed", timestamp).
   458  		Set("UnreadMentions", 0).
   459  		ToSql()
   460  	if _, err := s.GetMaster().Exec(query, args...); err != nil {
   461  		return errors.Wrapf(err, "failed to update thread read state for user id=%s", userId)
   462  	}
   463  	return nil
   464  }
   465  
   466  func (s *SqlThreadStore) MarkAsRead(userId, threadId string, timestamp int64) error {
   467  	query, args, _ := s.getQueryBuilder().
   468  		Update("ThreadMemberships").
   469  		Where(sq.Eq{"UserId": userId}).
   470  		Where(sq.Eq{"PostId": threadId}).
   471  		Set("LastViewed", timestamp).
   472  		ToSql()
   473  	if _, err := s.GetMaster().Exec(query, args...); err != nil {
   474  		return errors.Wrapf(err, "failed to update thread read state for user id=%s thread_id=%v", userId, threadId)
   475  	}
   476  	return nil
   477  }
   478  
   479  func (s *SqlThreadStore) Delete(threadId string) error {
   480  	query, args, _ := s.getQueryBuilder().Delete("Threads").Where(sq.Eq{"PostId": threadId}).ToSql()
   481  	if _, err := s.GetMaster().Exec(query, args...); err != nil {
   482  		return errors.Wrap(err, "failed to update threads")
   483  	}
   484  
   485  	return nil
   486  }
   487  
   488  func (s *SqlThreadStore) SaveMembership(membership *model.ThreadMembership) (*model.ThreadMembership, error) {
   489  	return s.saveMembership(s.GetMaster(), membership)
   490  }
   491  
   492  func (s *SqlThreadStore) saveMembership(ex gorp.SqlExecutor, membership *model.ThreadMembership) (*model.ThreadMembership, error) {
   493  	if err := ex.Insert(membership); err != nil {
   494  		return nil, errors.Wrapf(err, "failed to save thread membership with postid=%s userid=%s", membership.PostId, membership.UserId)
   495  	}
   496  
   497  	return membership, nil
   498  }
   499  
   500  func (s *SqlThreadStore) UpdateMembership(membership *model.ThreadMembership) (*model.ThreadMembership, error) {
   501  	return s.updateMembership(s.GetMaster(), membership)
   502  }
   503  
   504  func (s *SqlThreadStore) updateMembership(ex gorp.SqlExecutor, membership *model.ThreadMembership) (*model.ThreadMembership, error) {
   505  	if _, err := ex.Update(membership); err != nil {
   506  		return nil, errors.Wrapf(err, "failed to update thread membership with postid=%s userid=%s", membership.PostId, membership.UserId)
   507  	}
   508  
   509  	return membership, nil
   510  }
   511  
   512  func (s *SqlThreadStore) GetMembershipsForUser(userId, teamId string) ([]*model.ThreadMembership, error) {
   513  	var memberships []*model.ThreadMembership
   514  
   515  	query, args, _ := s.getQueryBuilder().
   516  		Select("ThreadMemberships.*").
   517  		Join("Threads ON Threads.PostId = ThreadMemberships.PostId").
   518  		Join("Channels ON Threads.ChannelId = Channels.Id").
   519  		From("ThreadMemberships").
   520  		Where(sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}).
   521  		Where(sq.Eq{"ThreadMemberships.UserId": userId}).
   522  		ToSql()
   523  
   524  	_, err := s.GetReplica().Select(&memberships, query, args...)
   525  	if err != nil {
   526  		return nil, errors.Wrapf(err, "failed to get thread membership with userid=%s", userId)
   527  	}
   528  	return memberships, nil
   529  }
   530  
   531  func (s *SqlThreadStore) GetMembershipForUser(userId, postId string) (*model.ThreadMembership, error) {
   532  	return s.getMembershipForUser(s.GetReplica(), userId, postId)
   533  }
   534  
   535  func (s *SqlThreadStore) getMembershipForUser(ex gorp.SqlExecutor, userId, postId string) (*model.ThreadMembership, error) {
   536  	var membership model.ThreadMembership
   537  	err := ex.SelectOne(&membership, "SELECT * from ThreadMemberships WHERE UserId = :UserId AND PostId = :PostId", map[string]interface{}{"UserId": userId, "PostId": postId})
   538  	if err != nil {
   539  		if err == sql.ErrNoRows {
   540  			return nil, store.NewErrNotFound("Thread", postId)
   541  		}
   542  		return nil, errors.Wrapf(err, "failed to get thread membership with userid=%s postid=%s", userId, postId)
   543  	}
   544  	return &membership, nil
   545  }
   546  
   547  func (s *SqlThreadStore) DeleteMembershipForUser(userId string, postId string) error {
   548  	if _, err := s.GetMaster().Exec("DELETE FROM ThreadMemberships Where PostId = :PostId AND UserId = :UserId", map[string]interface{}{"PostId": postId, "UserId": userId}); err != nil {
   549  		return errors.Wrap(err, "failed to update thread membership")
   550  	}
   551  
   552  	return nil
   553  }
   554  
   555  // MaintainMembership creates or updates a thread membership for the given user
   556  // and post. This method is used to update the state of a membership in response
   557  // to some events like:
   558  // - post creation (mentions handling)
   559  // - channel marked unread
   560  // - user explicitly following a thread
   561  func (s *SqlThreadStore) MaintainMembership(userId, postId string, opts store.ThreadMembershipOpts) (*model.ThreadMembership, error) {
   562  	trx, err := s.GetMaster().Begin()
   563  	if err != nil {
   564  		return nil, errors.Wrap(err, "begin_transaction")
   565  	}
   566  	defer finalizeTransaction(trx)
   567  
   568  	membership, err := s.getMembershipForUser(trx, userId, postId)
   569  	now := utils.MillisFromTime(time.Now())
   570  	// if memebership exists, update it if:
   571  	// a. user started/stopped following a thread
   572  	// b. mention count changed
   573  	// c. user viewed a thread
   574  	if err == nil {
   575  		followingNeedsUpdate := (opts.UpdateFollowing && !membership.Following || membership.Following != opts.Following)
   576  		if followingNeedsUpdate || opts.IncrementMentions || opts.UpdateViewedTimestamp {
   577  			if followingNeedsUpdate {
   578  				membership.Following = opts.Following
   579  			}
   580  			if opts.UpdateViewedTimestamp {
   581  				membership.LastViewed = now
   582  			}
   583  			membership.LastUpdated = now
   584  			if opts.IncrementMentions {
   585  				membership.UnreadMentions += 1
   586  			}
   587  			if _, err = s.updateMembership(trx, membership); err != nil {
   588  				return nil, err
   589  			}
   590  		}
   591  		if err = trx.Commit(); err != nil {
   592  			return nil, errors.Wrap(err, "commit_transaction")
   593  		}
   594  
   595  		return membership, err
   596  	}
   597  
   598  	var nfErr *store.ErrNotFound
   599  
   600  	if !errors.As(err, &nfErr) {
   601  		return nil, errors.Wrap(err, "failed to get thread membership")
   602  	}
   603  	membership = &model.ThreadMembership{
   604  		PostId:      postId,
   605  		UserId:      userId,
   606  		Following:   opts.Following,
   607  		LastUpdated: now,
   608  	}
   609  	if opts.IncrementMentions {
   610  		membership.UnreadMentions = 1
   611  	}
   612  	if opts.UpdateViewedTimestamp {
   613  		membership.LastViewed = now
   614  	}
   615  	membership, err = s.saveMembership(trx, membership)
   616  	if err != nil {
   617  		return nil, err
   618  	}
   619  
   620  	if opts.UpdateParticipants {
   621  		thread, getErr := s.get(trx, postId)
   622  		if getErr != nil {
   623  			return nil, getErr
   624  		}
   625  		if thread != nil && !thread.Participants.Contains(userId) {
   626  			thread.Participants = append(thread.Participants, userId)
   627  			if _, err = s.update(trx, thread); err != nil {
   628  				return nil, err
   629  			}
   630  		}
   631  	}
   632  	if err = trx.Commit(); err != nil {
   633  		return nil, errors.Wrap(err, "commit_transaction")
   634  	}
   635  	return membership, err
   636  }
   637  
   638  func (s *SqlThreadStore) CollectThreadsWithNewerReplies(userId string, channelIds []string, timestamp int64) ([]string, error) {
   639  	var changedThreads []string
   640  	query, args, _ := s.getQueryBuilder().
   641  		Select("Threads.PostId").
   642  		From("Threads").
   643  		LeftJoin("ChannelMembers ON ChannelMembers.ChannelId=Threads.ChannelId").
   644  		Where(sq.And{
   645  			sq.Eq{"Threads.ChannelId": channelIds},
   646  			sq.Eq{"ChannelMembers.UserId": userId},
   647  			sq.Or{
   648  				sq.Expr("Threads.LastReplyAt >= ChannelMembers.LastViewedAt"),
   649  				sq.GtOrEq{"Threads.LastReplyAt": timestamp},
   650  			},
   651  		}).
   652  		ToSql()
   653  	if _, err := s.GetReplica().Select(&changedThreads, query, args...); err != nil {
   654  		return nil, errors.Wrap(err, "failed to fetch threads")
   655  	}
   656  	return changedThreads, nil
   657  }
   658  
   659  func (s *SqlThreadStore) UpdateUnreadsByChannel(userId string, changedThreads []string, timestamp int64, updateViewedTimestamp bool) error {
   660  	if len(changedThreads) == 0 {
   661  		return nil
   662  	}
   663  
   664  	qb := s.getQueryBuilder().
   665  		Update("ThreadMemberships").
   666  		Where(sq.Eq{"UserId": userId, "PostId": changedThreads}).
   667  		Set("LastUpdated", timestamp)
   668  
   669  	if updateViewedTimestamp {
   670  		qb = qb.Set("LastViewed", timestamp)
   671  	}
   672  	updateQuery, updateArgs, _ := qb.ToSql()
   673  
   674  	if _, err := s.GetMaster().Exec(updateQuery, updateArgs...); err != nil {
   675  		return errors.Wrap(err, "failed to update thread membership")
   676  	}
   677  
   678  	return nil
   679  }
   680  
   681  func (s *SqlThreadStore) GetPosts(threadId string, since int64) ([]*model.Post, error) {
   682  	query, args, _ := s.getQueryBuilder().
   683  		Select("*").
   684  		From("Posts").
   685  		Where(sq.Eq{"RootId": threadId}).
   686  		Where(sq.Eq{"DeleteAt": 0}).
   687  		Where(sq.GtOrEq{"UpdateAt": since}).ToSql()
   688  	var result []*model.Post
   689  	if _, err := s.GetReplica().Select(&result, query, args...); err != nil {
   690  		return nil, errors.Wrap(err, "failed to fetch thread posts")
   691  	}
   692  	return result, nil
   693  }
   694  
   695  // PermanentDeleteBatchForRetentionPolicies deletes a batch of records which are affected by
   696  // the global or a granular retention policy.
   697  // See `genericPermanentDeleteBatchForRetentionPolicies` for details.
   698  func (s *SqlThreadStore) PermanentDeleteBatchForRetentionPolicies(now, globalPolicyEndTime, limit int64, cursor model.RetentionPolicyCursor) (int64, model.RetentionPolicyCursor, error) {
   699  	builder := s.getQueryBuilder().
   700  		Select("Threads.PostId").
   701  		From("Threads")
   702  	return genericPermanentDeleteBatchForRetentionPolicies(RetentionPolicyBatchDeletionInfo{
   703  		BaseBuilder:         builder,
   704  		Table:               "Threads",
   705  		TimeColumn:          "LastReplyAt",
   706  		PrimaryKeys:         []string{"PostId"},
   707  		ChannelIDTable:      "Threads",
   708  		NowMillis:           now,
   709  		GlobalPolicyEndTime: globalPolicyEndTime,
   710  		Limit:               limit,
   711  	}, s.SqlStore, cursor)
   712  }
   713  
   714  // PermanentDeleteBatchThreadMembershipsForRetentionPolicies deletes a batch of records
   715  // which are affected by the global or a granular retention policy.
   716  // See `genericPermanentDeleteBatchForRetentionPolicies` for details.
   717  func (s *SqlThreadStore) PermanentDeleteBatchThreadMembershipsForRetentionPolicies(now, globalPolicyEndTime, limit int64, cursor model.RetentionPolicyCursor) (int64, model.RetentionPolicyCursor, error) {
   718  	builder := s.getQueryBuilder().
   719  		Select("ThreadMemberships.PostId").
   720  		From("ThreadMemberships").
   721  		InnerJoin("Threads ON ThreadMemberships.PostId = Threads.PostId")
   722  	return genericPermanentDeleteBatchForRetentionPolicies(RetentionPolicyBatchDeletionInfo{
   723  		BaseBuilder:         builder,
   724  		Table:               "ThreadMemberships",
   725  		TimeColumn:          "LastUpdated",
   726  		PrimaryKeys:         []string{"PostId"},
   727  		ChannelIDTable:      "Threads",
   728  		NowMillis:           now,
   729  		GlobalPolicyEndTime: globalPolicyEndTime,
   730  		Limit:               limit,
   731  	}, s.SqlStore, cursor)
   732  }
   733  
   734  // DeleteOrphanedRows removes orphaned rows from Threads and ThreadMemberships
   735  func (s *SqlThreadStore) DeleteOrphanedRows(limit int) (deleted int64, err error) {
   736  	// We need the extra level of nesting to deal with MySQL's locking
   737  	const threadsQuery = `
   738  	DELETE FROM Threads WHERE PostId IN (
   739  		SELECT * FROM (
   740  			SELECT Threads.PostId FROM Threads
   741  			LEFT JOIN Channels ON Threads.ChannelId = Channels.Id
   742  			WHERE Channels.Id IS NULL
   743  			LIMIT :Limit
   744  		) AS A
   745  	)`
   746  	// We only delete a thread membership if the entire thread no longer exists,
   747  	// not if the root post has been deleted
   748  	const threadMembershipsQuery = `
   749  	DELETE FROM ThreadMemberships WHERE PostId IN (
   750  		SELECT * FROM (
   751  			SELECT ThreadMemberships.PostId FROM ThreadMemberships
   752  			LEFT JOIN Threads ON ThreadMemberships.PostId = Threads.PostId
   753  			WHERE Threads.PostId IS NULL
   754  			LIMIT :Limit
   755  		) AS A
   756  	)`
   757  	props := map[string]interface{}{"Limit": limit}
   758  	result, err := s.GetMaster().Exec(threadsQuery, props)
   759  	if err != nil {
   760  		return
   761  	}
   762  	rpcDeleted, err := result.RowsAffected()
   763  	if err != nil {
   764  		return
   765  	}
   766  	result, err = s.GetMaster().Exec(threadMembershipsQuery, props)
   767  	if err != nil {
   768  		return
   769  	}
   770  	rptDeleted, err := result.RowsAffected()
   771  	if err != nil {
   772  		return
   773  	}
   774  	deleted = rpcDeleted + rptDeleted
   775  	return
   776  }