github.com/mattermost/mattermost-server/v5@v5.39.3/store/sqlstore/thread_store.go (about) 1 // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. 2 // See LICENSE.txt for license information. 3 4 package sqlstore 5 6 import ( 7 "context" 8 "database/sql" 9 "time" 10 11 sq "github.com/Masterminds/squirrel" 12 "github.com/mattermost/gorp" 13 "github.com/pkg/errors" 14 15 "github.com/mattermost/mattermost-server/v5/model" 16 "github.com/mattermost/mattermost-server/v5/store" 17 "github.com/mattermost/mattermost-server/v5/utils" 18 ) 19 20 type SqlThreadStore struct { 21 *SqlStore 22 } 23 24 func (s *SqlThreadStore) ClearCaches() { 25 } 26 27 func newSqlThreadStore(sqlStore *SqlStore) store.ThreadStore { 28 s := &SqlThreadStore{ 29 SqlStore: sqlStore, 30 } 31 32 for _, db := range sqlStore.GetAllConns() { 33 tableThreads := db.AddTableWithName(model.Thread{}, "Threads").SetKeys(false, "PostId") 34 tableThreads.ColMap("PostId").SetMaxSize(26) 35 tableThreads.ColMap("ChannelId").SetMaxSize(26) 36 tableThreads.ColMap("Participants").SetMaxSize(0) 37 tableThreadMemberships := db.AddTableWithName(model.ThreadMembership{}, "ThreadMemberships").SetKeys(false, "PostId", "UserId") 38 tableThreadMemberships.ColMap("PostId").SetMaxSize(26) 39 tableThreadMemberships.ColMap("UserId").SetMaxSize(26) 40 } 41 42 return s 43 } 44 45 func threadSliceColumns() []string { 46 return []string{"PostId", "ChannelId", "LastReplyAt", "ReplyCount", "Participants"} 47 } 48 49 func threadToSlice(thread *model.Thread) []interface{} { 50 return []interface{}{ 51 thread.PostId, 52 thread.ChannelId, 53 thread.LastReplyAt, 54 thread.ReplyCount, 55 model.ArrayToJson(thread.Participants), 56 } 57 } 58 59 func (s *SqlThreadStore) createIndexesIfNotExists() { 60 s.CreateIndexIfNotExists("idx_thread_memberships_last_update_at", "ThreadMemberships", "LastUpdated") 61 s.CreateIndexIfNotExists("idx_thread_memberships_last_view_at", "ThreadMemberships", "LastViewed") 62 s.CreateIndexIfNotExists("idx_thread_memberships_user_id", "ThreadMemberships", "UserId") 63 s.CreateIndexIfNotExists("idx_threads_channel_id", "Threads", "ChannelId") 64 } 65 66 func (s *SqlThreadStore) SaveMultiple(threads []*model.Thread) ([]*model.Thread, int, error) { 67 builder := s.getQueryBuilder().Insert("Threads").Columns(threadSliceColumns()...) 68 for _, thread := range threads { 69 builder = builder.Values(threadToSlice(thread)...) 70 } 71 query, args, err := builder.ToSql() 72 if err != nil { 73 return nil, -1, errors.Wrap(err, "thread_tosql") 74 } 75 76 if _, err := s.GetMaster().Exec(query, args...); err != nil { 77 return nil, -1, errors.Wrap(err, "failed to save Post") 78 } 79 80 return threads, -1, nil 81 } 82 83 func (s *SqlThreadStore) Save(thread *model.Thread) (*model.Thread, error) { 84 threads, _, err := s.SaveMultiple([]*model.Thread{thread}) 85 if err != nil { 86 return nil, err 87 } 88 return threads[0], nil 89 } 90 91 func (s *SqlThreadStore) Update(thread *model.Thread) (*model.Thread, error) { 92 return s.update(s.GetMaster(), thread) 93 } 94 95 func (s *SqlThreadStore) update(ex gorp.SqlExecutor, thread *model.Thread) (*model.Thread, error) { 96 if _, err := ex.Update(thread); err != nil { 97 return nil, errors.Wrapf(err, "failed to update thread with id=%s", thread.PostId) 98 } 99 100 return thread, nil 101 } 102 103 func (s *SqlThreadStore) Get(id string) (*model.Thread, error) { 104 return s.get(s.GetReplica(), id) 105 } 106 107 func (s *SqlThreadStore) get(ex gorp.SqlExecutor, id string) (*model.Thread, error) { 108 var thread model.Thread 109 query, args, _ := s.getQueryBuilder().Select("*").From("Threads").Where(sq.Eq{"PostId": id}).ToSql() 110 err := ex.SelectOne(&thread, query, args...) 111 if err != nil { 112 if err == sql.ErrNoRows { 113 return nil, nil 114 } 115 116 return nil, errors.Wrapf(err, "failed to get thread with id=%s", id) 117 } 118 return &thread, nil 119 } 120 121 func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) { 122 type JoinedThread struct { 123 PostId string 124 ReplyCount int64 125 LastReplyAt int64 126 LastViewedAt int64 127 UnreadReplies int64 128 UnreadMentions int64 129 Participants model.StringArray 130 model.Post 131 } 132 133 fetchConditions := sq.And{ 134 sq.Eq{"ThreadMemberships.UserId": userId}, 135 sq.Eq{"ThreadMemberships.Following": true}, 136 } 137 if opts.TeamOnly { 138 fetchConditions = sq.And{ 139 sq.Eq{"Channels.TeamId": teamId}, 140 fetchConditions, 141 } 142 } else { 143 fetchConditions = sq.And{ 144 sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}, 145 fetchConditions, 146 } 147 } 148 if !opts.Deleted { 149 fetchConditions = sq.And{ 150 fetchConditions, 151 sq.Eq{"COALESCE(Posts.DeleteAt, 0)": 0}, 152 } 153 } 154 155 pageSize := uint64(30) 156 if opts.PageSize != 0 { 157 pageSize = opts.PageSize 158 } 159 160 totalUnreadThreadsChan := make(chan store.StoreResult, 1) 161 totalCountChan := make(chan store.StoreResult, 1) 162 totalUnreadMentionsChan := make(chan store.StoreResult, 1) 163 var threadsChan chan store.StoreResult 164 if !opts.TotalsOnly { 165 threadsChan = make(chan store.StoreResult, 1) 166 } 167 168 go func() { 169 repliesQuery, repliesQueryArgs, _ := s.getQueryBuilder(). 170 Select("COUNT(DISTINCT(Posts.RootId))"). 171 From("Posts"). 172 LeftJoin("ThreadMemberships ON Posts.RootId = ThreadMemberships.PostId"). 173 LeftJoin("Channels ON Posts.ChannelId = Channels.Id"). 174 Where(fetchConditions). 175 Where("Posts.CreateAt > ThreadMemberships.LastViewed").ToSql() 176 177 totalUnreadThreads, err := s.GetMaster().SelectInt(repliesQuery, repliesQueryArgs...) 178 totalUnreadThreadsChan <- store.StoreResult{Data: totalUnreadThreads, NErr: errors.Wrapf(err, "failed to get count unread on threads for user id=%s", userId)} 179 close(totalUnreadThreadsChan) 180 }() 181 go func() { 182 newFetchConditions := fetchConditions 183 184 if opts.Unread { 185 newFetchConditions = sq.And{newFetchConditions, sq.Expr("ThreadMemberships.LastViewed < Threads.LastReplyAt")} 186 } 187 188 threadsQuery, threadsQueryArgs, _ := s.getQueryBuilder(). 189 Select("COUNT(ThreadMemberships.PostId)"). 190 LeftJoin("Threads ON Threads.PostId = ThreadMemberships.PostId"). 191 LeftJoin("Channels ON Threads.ChannelId = Channels.Id"). 192 LeftJoin("Posts ON Posts.Id = ThreadMemberships.PostId"). 193 From("ThreadMemberships"). 194 Where(newFetchConditions).ToSql() 195 196 totalCount, err := s.GetMaster().SelectInt(threadsQuery, threadsQueryArgs...) 197 totalCountChan <- store.StoreResult{Data: totalCount, NErr: err} 198 close(totalCountChan) 199 }() 200 go func() { 201 mentionsQuery, mentionsQueryArgs, _ := s.getQueryBuilder(). 202 Select("COALESCE(SUM(ThreadMemberships.UnreadMentions),0)"). 203 From("ThreadMemberships"). 204 LeftJoin("Threads ON Threads.PostId = ThreadMemberships.PostId"). 205 LeftJoin("Posts ON Posts.Id = ThreadMemberships.PostId"). 206 LeftJoin("Channels ON Threads.ChannelId = Channels.Id"). 207 Where(fetchConditions).ToSql() 208 totalUnreadMentions, err := s.GetMaster().SelectInt(mentionsQuery, mentionsQueryArgs...) 209 totalUnreadMentionsChan <- store.StoreResult{Data: totalUnreadMentions, NErr: err} 210 close(totalUnreadMentionsChan) 211 }() 212 213 if !opts.TotalsOnly { 214 go func() { 215 newFetchConditions := fetchConditions 216 if opts.Since > 0 { 217 newFetchConditions = sq.And{newFetchConditions, sq.GtOrEq{"ThreadMemberships.LastUpdated": opts.Since}} 218 } 219 order := "DESC" 220 if opts.Before != "" { 221 newFetchConditions = sq.And{ 222 newFetchConditions, 223 sq.Expr(`LastReplyAt < (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.Before), 224 } 225 } 226 if opts.After != "" { 227 order = "ASC" 228 newFetchConditions = sq.And{ 229 newFetchConditions, 230 sq.Expr(`LastReplyAt > (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.After), 231 } 232 } 233 if opts.Unread { 234 newFetchConditions = sq.And{newFetchConditions, sq.Expr("ThreadMemberships.LastViewed < Threads.LastReplyAt")} 235 } 236 237 unreadRepliesFetchConditions := sq.And{ 238 sq.Expr("Posts.RootId = ThreadMemberships.PostId"), 239 sq.Expr("Posts.CreateAt > ThreadMemberships.LastViewed"), 240 } 241 if !opts.Deleted { 242 unreadRepliesFetchConditions = sq.And{ 243 unreadRepliesFetchConditions, 244 sq.Expr("Posts.DeleteAt = 0"), 245 } 246 } 247 248 unreadRepliesQuery, _ := sq. 249 Select("COUNT(Posts.Id)"). 250 From("Posts"). 251 Where(unreadRepliesFetchConditions). 252 MustSql() 253 254 var threads []*JoinedThread 255 query, args, _ := s.getQueryBuilder(). 256 Select(`Threads.*, 257 ` + postSliceCoalesceQuery() + `, 258 ThreadMemberships.LastViewed as LastViewedAt, 259 ThreadMemberships.UnreadMentions as UnreadMentions`). 260 From("Threads"). 261 Column(sq.Alias(sq.Expr(unreadRepliesQuery), "UnreadReplies")). 262 LeftJoin("Posts ON Posts.Id = Threads.PostId"). 263 LeftJoin("Channels ON Posts.ChannelId = Channels.Id"). 264 LeftJoin("ThreadMemberships ON ThreadMemberships.PostId = Threads.PostId"). 265 Where(newFetchConditions). 266 OrderBy("Threads.LastReplyAt " + order). 267 Limit(pageSize).ToSql() 268 269 _, err := s.GetReplica().Select(&threads, query, args...) 270 threadsChan <- store.StoreResult{Data: threads, NErr: err} 271 close(threadsChan) 272 }() 273 } 274 275 totalUnreadMentionsResult := <-totalUnreadMentionsChan 276 if totalUnreadMentionsResult.NErr != nil { 277 return nil, totalUnreadMentionsResult.NErr 278 } 279 totalUnreadMentions := totalUnreadMentionsResult.Data.(int64) 280 281 totalCountResult := <-totalCountChan 282 if totalCountResult.NErr != nil { 283 return nil, totalCountResult.NErr 284 } 285 totalCount := totalCountResult.Data.(int64) 286 287 totalUnreadThreadsResult := <-totalUnreadThreadsChan 288 if totalUnreadThreadsResult.NErr != nil { 289 return nil, totalUnreadThreadsResult.NErr 290 } 291 totalUnreadThreads := totalUnreadThreadsResult.Data.(int64) 292 293 var userIds []string 294 userIdMap := map[string]bool{} 295 296 result := &model.Threads{ 297 Total: totalCount, 298 Threads: []*model.ThreadResponse{}, 299 TotalUnreadMentions: totalUnreadMentions, 300 TotalUnreadThreads: totalUnreadThreads, 301 } 302 303 if !opts.TotalsOnly { 304 threadsResult := <-threadsChan 305 if threadsResult.NErr != nil { 306 return nil, threadsResult.NErr 307 } 308 threads := threadsResult.Data.([]*JoinedThread) 309 for _, thread := range threads { 310 for _, participantId := range thread.Participants { 311 if _, ok := userIdMap[participantId]; !ok { 312 userIdMap[participantId] = true 313 userIds = append(userIds, participantId) 314 } 315 } 316 } 317 var users []*model.User 318 if opts.Extended { 319 var err error 320 users, err = s.User().GetProfileByIds(context.Background(), userIds, &store.UserGetByIdsOpts{}, true) 321 if err != nil { 322 return nil, errors.Wrapf(err, "failed to get threads for user id=%s", userId) 323 } 324 } else { 325 for _, userId := range userIds { 326 users = append(users, &model.User{Id: userId}) 327 } 328 } 329 330 for _, thread := range threads { 331 var participants []*model.User 332 for _, participantId := range thread.Participants { 333 var participant *model.User 334 for _, u := range users { 335 if u.Id == participantId { 336 participant = u 337 break 338 } 339 } 340 if participant == nil { 341 return nil, errors.New("cannot find thread participant with id=" + participantId) 342 } 343 participants = append(participants, participant) 344 } 345 result.Threads = append(result.Threads, &model.ThreadResponse{ 346 PostId: thread.PostId, 347 ReplyCount: thread.ReplyCount, 348 LastReplyAt: thread.LastReplyAt, 349 LastViewedAt: thread.LastViewedAt, 350 UnreadReplies: thread.UnreadReplies, 351 UnreadMentions: thread.UnreadMentions, 352 Participants: participants, 353 Post: thread.Post.ToNilIfInvalid(), 354 }) 355 } 356 } 357 358 return result, nil 359 } 360 361 func (s *SqlThreadStore) GetThreadFollowers(threadID string) ([]string, error) { 362 var users []string 363 query, args, _ := s.getQueryBuilder(). 364 Select("ThreadMemberships.UserId"). 365 From("ThreadMemberships"). 366 Where(sq.Eq{"PostId": threadID}).ToSql() 367 _, err := s.GetReplica().Select(&users, query, args...) 368 369 if err != nil { 370 return nil, err 371 } 372 return users, nil 373 } 374 375 func (s *SqlThreadStore) GetThreadForUser(teamId string, threadMembership *model.ThreadMembership, extended bool) (*model.ThreadResponse, error) { 376 if !threadMembership.Following { 377 return nil, nil // in case the thread is not followed anymore - return nil error to be interpreted as 404 378 } 379 380 type JoinedThread struct { 381 PostId string 382 Following bool 383 ReplyCount int64 384 LastReplyAt int64 385 LastViewedAt int64 386 UnreadReplies int64 387 UnreadMentions int64 388 Participants model.StringArray 389 model.Post 390 } 391 392 unreadRepliesQuery, unreadRepliesArgs := sq. 393 Select("COUNT(Posts.Id)"). 394 From("Posts"). 395 Where(sq.And{ 396 sq.Eq{"Posts.RootId": threadMembership.PostId}, 397 sq.Gt{"Posts.CreateAt": threadMembership.LastViewed}, 398 sq.Eq{"Posts.DeleteAt": 0}, 399 }).MustSql() 400 401 fetchConditions := sq.And{ 402 sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}, 403 sq.Eq{"Threads.PostId": threadMembership.PostId}, 404 } 405 406 var thread JoinedThread 407 query, threadArgs, _ := s.getQueryBuilder(). 408 Select("Threads.*, Posts.*"). 409 From("Threads"). 410 Column(sq.Alias(sq.Expr(unreadRepliesQuery), "UnreadReplies")). 411 LeftJoin("Posts ON Posts.Id = Threads.PostId"). 412 LeftJoin("Channels ON Posts.ChannelId = Channels.Id"). 413 Where(fetchConditions).ToSql() 414 415 args := append(unreadRepliesArgs, threadArgs...) 416 417 err := s.GetReplica().SelectOne(&thread, query, args...) 418 if err != nil { 419 return nil, err 420 } 421 422 thread.LastViewedAt = threadMembership.LastViewed 423 thread.UnreadMentions = threadMembership.UnreadMentions 424 425 var users []*model.User 426 if extended { 427 var err error 428 users, err = s.User().GetProfileByIds(context.Background(), thread.Participants, &store.UserGetByIdsOpts{}, true) 429 if err != nil { 430 return nil, errors.Wrapf(err, "failed to get thread for user id=%s", threadMembership.UserId) 431 } 432 } else { 433 for _, userId := range thread.Participants { 434 users = append(users, &model.User{Id: userId}) 435 } 436 } 437 438 result := &model.ThreadResponse{ 439 PostId: thread.PostId, 440 ReplyCount: thread.ReplyCount, 441 LastReplyAt: thread.LastReplyAt, 442 LastViewedAt: thread.LastViewedAt, 443 UnreadReplies: thread.UnreadReplies, 444 UnreadMentions: thread.UnreadMentions, 445 Participants: users, 446 Post: thread.Post.ToNilIfInvalid(), 447 } 448 449 return result, nil 450 } 451 func (s *SqlThreadStore) MarkAllAsReadInChannels(userID string, channelIDs []string) error { 452 var threadIDs []string 453 454 query, args, _ := s.getQueryBuilder(). 455 Select("ThreadMemberships.PostId"). 456 Join("Threads ON Threads.PostId = ThreadMemberships.PostId"). 457 Join("Channels ON Threads.ChannelId = Channels.Id"). 458 From("ThreadMemberships"). 459 Where(sq.Eq{"Threads.ChannelId": channelIDs}). 460 Where(sq.Eq{"ThreadMemberships.UserId": userID}). 461 ToSql() 462 463 _, err := s.GetReplica().Select(&threadIDs, query, args...) 464 if err != nil { 465 return errors.Wrapf(err, "failed to get thread membership with userid=%s", userID) 466 } 467 468 timestamp := model.GetMillis() 469 query, args, _ = s.getQueryBuilder(). 470 Update("ThreadMemberships"). 471 Where(sq.Eq{"PostId": threadIDs}). 472 Where(sq.Eq{"UserId": userID}). 473 Set("LastViewed", timestamp). 474 Set("UnreadMentions", 0). 475 ToSql() 476 if _, err := s.GetMaster().Exec(query, args...); err != nil { 477 return errors.Wrapf(err, "failed to update thread read state for user id=%s", userID) 478 } 479 return nil 480 481 } 482 func (s *SqlThreadStore) MarkAllAsRead(userId, teamId string) error { 483 memberships, err := s.GetMembershipsForUser(userId, teamId) 484 if err != nil { 485 return err 486 } 487 var membershipIds []string 488 for _, m := range memberships { 489 membershipIds = append(membershipIds, m.PostId) 490 } 491 timestamp := model.GetMillis() 492 query, args, _ := s.getQueryBuilder(). 493 Update("ThreadMemberships"). 494 Where(sq.Eq{"PostId": membershipIds}). 495 Where(sq.Eq{"UserId": userId}). 496 Set("LastViewed", timestamp). 497 Set("UnreadMentions", 0). 498 ToSql() 499 if _, err := s.GetMaster().Exec(query, args...); err != nil { 500 return errors.Wrapf(err, "failed to update thread read state for user id=%s", userId) 501 } 502 return nil 503 } 504 505 func (s *SqlThreadStore) MarkAsRead(userId, threadId string, timestamp int64) error { 506 query, args, _ := s.getQueryBuilder(). 507 Update("ThreadMemberships"). 508 Where(sq.Eq{"UserId": userId}). 509 Where(sq.Eq{"PostId": threadId}). 510 Set("LastViewed", timestamp). 511 ToSql() 512 if _, err := s.GetMaster().Exec(query, args...); err != nil { 513 return errors.Wrapf(err, "failed to update thread read state for user id=%s thread_id=%v", userId, threadId) 514 } 515 return nil 516 } 517 518 func (s *SqlThreadStore) Delete(threadId string) error { 519 query, args, _ := s.getQueryBuilder().Delete("Threads").Where(sq.Eq{"PostId": threadId}).ToSql() 520 if _, err := s.GetMaster().Exec(query, args...); err != nil { 521 return errors.Wrap(err, "failed to update threads") 522 } 523 524 return nil 525 } 526 527 func (s *SqlThreadStore) SaveMembership(membership *model.ThreadMembership) (*model.ThreadMembership, error) { 528 return s.saveMembership(s.GetMaster(), membership) 529 } 530 531 func (s *SqlThreadStore) saveMembership(ex gorp.SqlExecutor, membership *model.ThreadMembership) (*model.ThreadMembership, error) { 532 if err := ex.Insert(membership); err != nil { 533 return nil, errors.Wrapf(err, "failed to save thread membership with postid=%s userid=%s", membership.PostId, membership.UserId) 534 } 535 536 return membership, nil 537 } 538 539 func (s *SqlThreadStore) UpdateMembership(membership *model.ThreadMembership) (*model.ThreadMembership, error) { 540 return s.updateMembership(s.GetMaster(), membership) 541 } 542 543 func (s *SqlThreadStore) updateMembership(ex gorp.SqlExecutor, membership *model.ThreadMembership) (*model.ThreadMembership, error) { 544 if _, err := ex.Update(membership); err != nil { 545 return nil, errors.Wrapf(err, "failed to update thread membership with postid=%s userid=%s", membership.PostId, membership.UserId) 546 } 547 548 return membership, nil 549 } 550 551 func (s *SqlThreadStore) GetMembershipsForUser(userId, teamId string) ([]*model.ThreadMembership, error) { 552 var memberships []*model.ThreadMembership 553 554 query, args, _ := s.getQueryBuilder(). 555 Select("ThreadMemberships.*"). 556 Join("Threads ON Threads.PostId = ThreadMemberships.PostId"). 557 Join("Channels ON Threads.ChannelId = Channels.Id"). 558 From("ThreadMemberships"). 559 Where(sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}). 560 Where(sq.Eq{"ThreadMemberships.UserId": userId}). 561 ToSql() 562 563 _, err := s.GetReplica().Select(&memberships, query, args...) 564 if err != nil { 565 return nil, errors.Wrapf(err, "failed to get thread membership with userid=%s", userId) 566 } 567 return memberships, nil 568 } 569 570 func (s *SqlThreadStore) GetMembershipForUser(userId, postId string) (*model.ThreadMembership, error) { 571 return s.getMembershipForUser(s.GetReplica(), userId, postId) 572 } 573 574 func (s *SqlThreadStore) getMembershipForUser(ex gorp.SqlExecutor, userId, postId string) (*model.ThreadMembership, error) { 575 var membership model.ThreadMembership 576 err := ex.SelectOne(&membership, "SELECT * from ThreadMemberships WHERE UserId = :UserId AND PostId = :PostId", map[string]interface{}{"UserId": userId, "PostId": postId}) 577 if err != nil { 578 if err == sql.ErrNoRows { 579 return nil, store.NewErrNotFound("Thread", postId) 580 } 581 return nil, errors.Wrapf(err, "failed to get thread membership with userid=%s postid=%s", userId, postId) 582 } 583 return &membership, nil 584 } 585 586 func (s *SqlThreadStore) DeleteMembershipForUser(userId string, postId string) error { 587 if _, err := s.GetMaster().Exec("DELETE FROM ThreadMemberships Where PostId = :PostId AND UserId = :UserId", map[string]interface{}{"PostId": postId, "UserId": userId}); err != nil { 588 return errors.Wrap(err, "failed to update thread membership") 589 } 590 591 return nil 592 } 593 594 // MaintainMembership creates or updates a thread membership for the given user 595 // and post. This method is used to update the state of a membership in response 596 // to some events like: 597 // - post creation (mentions handling) 598 // - channel marked unread 599 // - user explicitly following a thread 600 func (s *SqlThreadStore) MaintainMembership(userId, postId string, opts store.ThreadMembershipOpts) (*model.ThreadMembership, error) { 601 trx, err := s.GetMaster().Begin() 602 if err != nil { 603 return nil, errors.Wrap(err, "begin_transaction") 604 } 605 defer finalizeTransaction(trx) 606 607 membership, err := s.getMembershipForUser(trx, userId, postId) 608 now := utils.MillisFromTime(time.Now()) 609 // if memebership exists, update it if: 610 // a. user started/stopped following a thread 611 // b. mention count changed 612 // c. user viewed a thread 613 if err == nil { 614 followingNeedsUpdate := (opts.UpdateFollowing && (membership.Following != opts.Following)) 615 if followingNeedsUpdate || opts.IncrementMentions || opts.UpdateViewedTimestamp { 616 if followingNeedsUpdate { 617 membership.Following = opts.Following 618 } 619 if opts.UpdateViewedTimestamp { 620 membership.LastViewed = now 621 } 622 membership.LastUpdated = now 623 if opts.IncrementMentions { 624 membership.UnreadMentions += 1 625 } 626 if _, err = s.updateMembership(trx, membership); err != nil { 627 return nil, err 628 } 629 } 630 631 if err = trx.Commit(); err != nil { 632 return nil, errors.Wrap(err, "commit_transaction") 633 } 634 635 return membership, err 636 } 637 638 var nfErr *store.ErrNotFound 639 if !errors.As(err, &nfErr) { 640 return nil, errors.Wrap(err, "failed to get thread membership") 641 } 642 643 membership = &model.ThreadMembership{ 644 PostId: postId, 645 UserId: userId, 646 Following: opts.Following, 647 LastUpdated: now, 648 } 649 if opts.IncrementMentions { 650 membership.UnreadMentions = 1 651 } 652 if opts.UpdateViewedTimestamp { 653 membership.LastViewed = now 654 } 655 membership, err = s.saveMembership(trx, membership) 656 if err != nil { 657 return nil, err 658 } 659 660 if opts.UpdateParticipants { 661 thread, getErr := s.get(trx, postId) 662 if getErr != nil { 663 return nil, getErr 664 } 665 if thread != nil && !thread.Participants.Contains(userId) { 666 thread.Participants = append(thread.Participants, userId) 667 if _, err = s.update(trx, thread); err != nil { 668 return nil, err 669 } 670 } 671 } 672 673 if err = trx.Commit(); err != nil { 674 return nil, errors.Wrap(err, "commit_transaction") 675 } 676 677 return membership, err 678 } 679 680 func (s *SqlThreadStore) CollectThreadsWithNewerReplies(userId string, channelIds []string, timestamp int64) ([]string, error) { 681 var changedThreads []string 682 query, args, _ := s.getQueryBuilder(). 683 Select("Threads.PostId"). 684 From("Threads"). 685 LeftJoin("ChannelMembers ON ChannelMembers.ChannelId=Threads.ChannelId"). 686 Where(sq.And{ 687 sq.Eq{"Threads.ChannelId": channelIds}, 688 sq.Eq{"ChannelMembers.UserId": userId}, 689 sq.Or{ 690 sq.Expr("Threads.LastReplyAt > ChannelMembers.LastViewedAt"), 691 sq.Gt{"Threads.LastReplyAt": timestamp}, 692 }, 693 }). 694 ToSql() 695 if _, err := s.GetReplica().Select(&changedThreads, query, args...); err != nil { 696 return nil, errors.Wrap(err, "failed to fetch threads") 697 } 698 return changedThreads, nil 699 } 700 701 func (s *SqlThreadStore) UpdateUnreadsByChannel(userId string, changedThreads []string, timestamp int64, updateViewedTimestamp bool) error { 702 if len(changedThreads) == 0 { 703 return nil 704 } 705 706 qb := s.getQueryBuilder(). 707 Update("ThreadMemberships"). 708 Where(sq.Eq{"UserId": userId, "PostId": changedThreads}). 709 Set("LastUpdated", timestamp) 710 711 if updateViewedTimestamp { 712 qb = qb.Set("LastViewed", timestamp) 713 } 714 updateQuery, updateArgs, _ := qb.ToSql() 715 716 if _, err := s.GetMaster().Exec(updateQuery, updateArgs...); err != nil { 717 return errors.Wrap(err, "failed to update thread membership") 718 } 719 720 return nil 721 } 722 723 func (s *SqlThreadStore) GetPosts(threadId string, since int64) ([]*model.Post, error) { 724 query, args, _ := s.getQueryBuilder(). 725 Select("*"). 726 From("Posts"). 727 Where(sq.Eq{"RootId": threadId}). 728 Where(sq.Eq{"DeleteAt": 0}). 729 Where(sq.GtOrEq{"UpdateAt": since}).ToSql() 730 var result []*model.Post 731 if _, err := s.GetReplica().Select(&result, query, args...); err != nil { 732 return nil, errors.Wrap(err, "failed to fetch thread posts") 733 } 734 return result, nil 735 } 736 737 // PermanentDeleteBatchForRetentionPolicies deletes a batch of records which are affected by 738 // the global or a granular retention policy. 739 // See `genericPermanentDeleteBatchForRetentionPolicies` for details. 740 func (s *SqlThreadStore) PermanentDeleteBatchForRetentionPolicies(now, globalPolicyEndTime, limit int64, cursor model.RetentionPolicyCursor) (int64, model.RetentionPolicyCursor, error) { 741 builder := s.getQueryBuilder(). 742 Select("Threads.PostId"). 743 From("Threads") 744 return genericPermanentDeleteBatchForRetentionPolicies(RetentionPolicyBatchDeletionInfo{ 745 BaseBuilder: builder, 746 Table: "Threads", 747 TimeColumn: "LastReplyAt", 748 PrimaryKeys: []string{"PostId"}, 749 ChannelIDTable: "Threads", 750 NowMillis: now, 751 GlobalPolicyEndTime: globalPolicyEndTime, 752 Limit: limit, 753 }, s.SqlStore, cursor) 754 } 755 756 // PermanentDeleteBatchThreadMembershipsForRetentionPolicies deletes a batch of records 757 // which are affected by the global or a granular retention policy. 758 // See `genericPermanentDeleteBatchForRetentionPolicies` for details. 759 func (s *SqlThreadStore) PermanentDeleteBatchThreadMembershipsForRetentionPolicies(now, globalPolicyEndTime, limit int64, cursor model.RetentionPolicyCursor) (int64, model.RetentionPolicyCursor, error) { 760 builder := s.getQueryBuilder(). 761 Select("ThreadMemberships.PostId"). 762 From("ThreadMemberships"). 763 InnerJoin("Threads ON ThreadMemberships.PostId = Threads.PostId") 764 return genericPermanentDeleteBatchForRetentionPolicies(RetentionPolicyBatchDeletionInfo{ 765 BaseBuilder: builder, 766 Table: "ThreadMemberships", 767 TimeColumn: "LastUpdated", 768 PrimaryKeys: []string{"PostId"}, 769 ChannelIDTable: "Threads", 770 NowMillis: now, 771 GlobalPolicyEndTime: globalPolicyEndTime, 772 Limit: limit, 773 }, s.SqlStore, cursor) 774 } 775 776 // DeleteOrphanedRows removes orphaned rows from Threads and ThreadMemberships 777 func (s *SqlThreadStore) DeleteOrphanedRows(limit int) (deleted int64, err error) { 778 // We need the extra level of nesting to deal with MySQL's locking 779 const threadsQuery = ` 780 DELETE FROM Threads WHERE PostId IN ( 781 SELECT * FROM ( 782 SELECT Threads.PostId FROM Threads 783 LEFT JOIN Channels ON Threads.ChannelId = Channels.Id 784 WHERE Channels.Id IS NULL 785 LIMIT :Limit 786 ) AS A 787 )` 788 // We only delete a thread membership if the entire thread no longer exists, 789 // not if the root post has been deleted 790 const threadMembershipsQuery = ` 791 DELETE FROM ThreadMemberships WHERE PostId IN ( 792 SELECT * FROM ( 793 SELECT ThreadMemberships.PostId FROM ThreadMemberships 794 LEFT JOIN Threads ON ThreadMemberships.PostId = Threads.PostId 795 WHERE Threads.PostId IS NULL 796 LIMIT :Limit 797 ) AS A 798 )` 799 props := map[string]interface{}{"Limit": limit} 800 result, err := s.GetMaster().Exec(threadsQuery, props) 801 if err != nil { 802 return 803 } 804 rpcDeleted, err := result.RowsAffected() 805 if err != nil { 806 return 807 } 808 result, err = s.GetMaster().Exec(threadMembershipsQuery, props) 809 if err != nil { 810 return 811 } 812 rptDeleted, err := result.RowsAffected() 813 if err != nil { 814 return 815 } 816 deleted = rpcDeleted + rptDeleted 817 return 818 }