github.com/masterhung0112/hk_server/v5@v5.0.0-20220302090640-ec71aef15e1c/store/sqlstore/thread_store.go (about) 1 package sqlstore 2 3 import ( 4 "context" 5 "database/sql" 6 "time" 7 8 sq "github.com/Masterminds/squirrel" 9 10 "github.com/mattermost/gorp" 11 "github.com/pkg/errors" 12 13 "github.com/masterhung0112/hk_server/v5/model" 14 "github.com/masterhung0112/hk_server/v5/store" 15 "github.com/masterhung0112/hk_server/v5/utils" 16 ) 17 18 type SqlThreadStore struct { 19 *SqlStore 20 } 21 22 func (s *SqlThreadStore) ClearCaches() { 23 } 24 25 func newSqlThreadStore(sqlStore *SqlStore) store.ThreadStore { 26 s := &SqlThreadStore{ 27 SqlStore: sqlStore, 28 } 29 30 for _, db := range sqlStore.GetAllConns() { 31 tableThreads := db.AddTableWithName(model.Thread{}, "Threads").SetKeys(false, "PostId") 32 tableThreads.ColMap("PostId").SetMaxSize(26) 33 tableThreads.ColMap("ChannelId").SetMaxSize(26) 34 tableThreads.ColMap("Participants").SetMaxSize(0) 35 tableThreadMemberships := db.AddTableWithName(model.ThreadMembership{}, "ThreadMemberships").SetKeys(false, "PostId", "UserId") 36 tableThreadMemberships.ColMap("PostId").SetMaxSize(26) 37 tableThreadMemberships.ColMap("UserId").SetMaxSize(26) 38 } 39 40 return s 41 } 42 43 func threadSliceColumns() []string { 44 return []string{"PostId", "ChannelId", "LastReplyAt", "ReplyCount", "Participants"} 45 } 46 47 func threadToSlice(thread *model.Thread) []interface{} { 48 return []interface{}{ 49 thread.PostId, 50 thread.ChannelId, 51 thread.LastReplyAt, 52 thread.ReplyCount, 53 thread.Participants, 54 } 55 } 56 57 func (s *SqlThreadStore) createIndexesIfNotExists() { 58 s.CreateIndexIfNotExists("idx_thread_memberships_last_update_at", "ThreadMemberships", "LastUpdated") 59 s.CreateIndexIfNotExists("idx_thread_memberships_last_view_at", "ThreadMemberships", "LastViewed") 60 s.CreateIndexIfNotExists("idx_thread_memberships_user_id", "ThreadMemberships", "UserId") 61 s.CreateIndexIfNotExists("idx_threads_channel_id", "Threads", "ChannelId") 62 } 63 64 func (s *SqlThreadStore) SaveMultiple(threads []*model.Thread) ([]*model.Thread, int, error) { 65 builder := s.getQueryBuilder().Insert("Threads").Columns(threadSliceColumns()...) 66 for _, thread := range threads { 67 builder = builder.Values(threadToSlice(thread)...) 68 } 69 query, args, err := builder.ToSql() 70 if err != nil { 71 return nil, -1, errors.Wrap(err, "thread_tosql") 72 } 73 74 if _, err := s.GetMaster().Exec(query, args...); err != nil { 75 return nil, -1, errors.Wrap(err, "failed to save Post") 76 } 77 78 return threads, -1, nil 79 } 80 81 func (s *SqlThreadStore) Save(thread *model.Thread) (*model.Thread, error) { 82 threads, _, err := s.SaveMultiple([]*model.Thread{thread}) 83 if err != nil { 84 return nil, err 85 } 86 return threads[0], nil 87 } 88 89 func (s *SqlThreadStore) Update(thread *model.Thread) (*model.Thread, error) { 90 return s.update(s.GetMaster(), thread) 91 } 92 93 func (s *SqlThreadStore) update(ex gorp.SqlExecutor, thread *model.Thread) (*model.Thread, error) { 94 if _, err := ex.Update(thread); err != nil { 95 return nil, errors.Wrapf(err, "failed to update thread with id=%s", thread.PostId) 96 } 97 98 return thread, nil 99 } 100 101 func (s *SqlThreadStore) Get(id string) (*model.Thread, error) { 102 return s.get(s.GetReplica(), id) 103 } 104 105 func (s *SqlThreadStore) get(ex gorp.SqlExecutor, id string) (*model.Thread, error) { 106 var thread model.Thread 107 query, args, _ := s.getQueryBuilder().Select("*").From("Threads").Where(sq.Eq{"PostId": id}).ToSql() 108 err := ex.SelectOne(&thread, query, args...) 109 if err != nil { 110 if err == sql.ErrNoRows { 111 return nil, nil 112 } 113 114 return nil, errors.Wrapf(err, "failed to get thread with id=%s", id) 115 } 116 return &thread, nil 117 } 118 119 func (s *SqlThreadStore) GetThreadsForUser(userId, teamId string, opts model.GetUserThreadsOpts) (*model.Threads, error) { 120 type JoinedThread struct { 121 PostId string 122 ReplyCount int64 123 LastReplyAt int64 124 LastViewedAt int64 125 UnreadReplies int64 126 UnreadMentions int64 127 Participants model.StringArray 128 model.Post 129 } 130 131 unreadRepliesQuery := "SELECT COUNT(Posts.Id) From Posts Where Posts.RootId=ThreadMemberships.PostId AND Posts.CreateAt >= ThreadMemberships.LastViewed" 132 fetchConditions := sq.And{ 133 sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}, 134 sq.Eq{"ThreadMemberships.UserId": userId}, 135 sq.Eq{"ThreadMemberships.Following": true}, 136 } 137 if !opts.Deleted { 138 fetchConditions = sq.And{ 139 fetchConditions, 140 sq.Eq{"COALESCE(Posts.DeleteAt, 0)": 0}, 141 } 142 } 143 144 pageSize := uint64(30) 145 if opts.PageSize != 0 { 146 pageSize = opts.PageSize 147 } 148 149 totalUnreadThreadsChan := make(chan store.StoreResult, 1) 150 totalCountChan := make(chan store.StoreResult, 1) 151 totalUnreadMentionsChan := make(chan store.StoreResult, 1) 152 threadsChan := make(chan store.StoreResult, 1) 153 go func() { 154 repliesQuery, repliesQueryArgs, _ := s.getQueryBuilder(). 155 Select("COUNT(DISTINCT(Posts.RootId))"). 156 From("Posts"). 157 LeftJoin("ThreadMemberships ON Posts.RootId = ThreadMemberships.PostId"). 158 LeftJoin("Channels ON Posts.ChannelId = Channels.Id"). 159 Where(fetchConditions). 160 Where("Posts.CreateAt >= ThreadMemberships.LastViewed").ToSql() 161 162 totalUnreadThreads, err := s.GetMaster().SelectInt(repliesQuery, repliesQueryArgs...) 163 totalUnreadThreadsChan <- store.StoreResult{Data: totalUnreadThreads, NErr: errors.Wrapf(err, "failed to get count unread on threads for user id=%s", userId)} 164 close(totalUnreadThreadsChan) 165 }() 166 go func() { 167 newFetchConditions := fetchConditions 168 169 if opts.Unread { 170 newFetchConditions = sq.And{newFetchConditions, sq.Expr("ThreadMemberships.LastViewed < Threads.LastReplyAt")} 171 } 172 173 threadsQuery, threadsQueryArgs, _ := s.getQueryBuilder(). 174 Select("COUNT(ThreadMemberships.PostId)"). 175 LeftJoin("Threads ON Threads.PostId = ThreadMemberships.PostId"). 176 LeftJoin("Channels ON Threads.ChannelId = Channels.Id"). 177 LeftJoin("Posts ON Posts.Id = ThreadMemberships.PostId"). 178 From("ThreadMemberships"). 179 Where(newFetchConditions).ToSql() 180 181 totalCount, err := s.GetMaster().SelectInt(threadsQuery, threadsQueryArgs...) 182 totalCountChan <- store.StoreResult{Data: totalCount, NErr: err} 183 close(totalCountChan) 184 }() 185 go func() { 186 mentionsQuery, mentionsQueryArgs, _ := s.getQueryBuilder(). 187 Select("COALESCE(SUM(ThreadMemberships.UnreadMentions),0)"). 188 From("ThreadMemberships"). 189 LeftJoin("Threads ON Threads.PostId = ThreadMemberships.PostId"). 190 LeftJoin("Posts ON Posts.Id = ThreadMemberships.PostId"). 191 LeftJoin("Channels ON Threads.ChannelId = Channels.Id"). 192 Where(fetchConditions).ToSql() 193 totalUnreadMentions, err := s.GetMaster().SelectInt(mentionsQuery, mentionsQueryArgs...) 194 totalUnreadMentionsChan <- store.StoreResult{Data: totalUnreadMentions, NErr: err} 195 close(totalUnreadMentionsChan) 196 }() 197 go func() { 198 newFetchConditions := fetchConditions 199 if opts.Since > 0 { 200 newFetchConditions = sq.And{newFetchConditions, sq.GtOrEq{"ThreadMemberships.LastUpdated": opts.Since}} 201 } 202 order := "DESC" 203 if opts.Before != "" { 204 newFetchConditions = sq.And{ 205 newFetchConditions, 206 sq.Expr(`LastReplyAt < (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.Before), 207 } 208 } 209 if opts.After != "" { 210 order = "ASC" 211 newFetchConditions = sq.And{ 212 newFetchConditions, 213 sq.Expr(`LastReplyAt > (SELECT LastReplyAt FROM Threads WHERE PostId = ?)`, opts.After), 214 } 215 } 216 if opts.Unread { 217 newFetchConditions = sq.And{newFetchConditions, sq.Expr("ThreadMemberships.LastViewed < Threads.LastReplyAt")} 218 } 219 var threads []*JoinedThread 220 query, args, _ := s.getQueryBuilder(). 221 Select(`Threads.*, 222 ` + postSliceCoalesceQuery() + `, 223 ThreadMemberships.LastViewed as LastViewedAt, 224 ThreadMemberships.UnreadMentions as UnreadMentions`). 225 From("Threads"). 226 Column(sq.Alias(sq.Expr(unreadRepliesQuery), "UnreadReplies")). 227 LeftJoin("Posts ON Posts.Id = Threads.PostId"). 228 LeftJoin("Channels ON Posts.ChannelId = Channels.Id"). 229 LeftJoin("ThreadMemberships ON ThreadMemberships.PostId = Threads.PostId"). 230 Where(newFetchConditions). 231 OrderBy("Threads.LastReplyAt " + order). 232 Limit(pageSize).ToSql() 233 234 _, err := s.GetReplica().Select(&threads, query, args...) 235 threadsChan <- store.StoreResult{Data: threads, NErr: err} 236 close(threadsChan) 237 }() 238 239 threadsResult := <-threadsChan 240 if threadsResult.NErr != nil { 241 return nil, threadsResult.NErr 242 } 243 threads := threadsResult.Data.([]*JoinedThread) 244 245 totalUnreadMentionsResult := <-totalUnreadMentionsChan 246 if totalUnreadMentionsResult.NErr != nil { 247 return nil, totalUnreadMentionsResult.NErr 248 } 249 totalUnreadMentions := totalUnreadMentionsResult.Data.(int64) 250 251 totalCountResult := <-totalCountChan 252 if totalCountResult.NErr != nil { 253 return nil, totalCountResult.NErr 254 } 255 totalCount := totalCountResult.Data.(int64) 256 257 totalUnreadThreadsResult := <-totalUnreadThreadsChan 258 if totalUnreadThreadsResult.NErr != nil { 259 return nil, totalUnreadThreadsResult.NErr 260 } 261 totalUnreadThreads := totalUnreadThreadsResult.Data.(int64) 262 263 var userIds []string 264 userIdMap := map[string]bool{} 265 for _, thread := range threads { 266 for _, participantId := range thread.Participants { 267 if _, ok := userIdMap[participantId]; !ok { 268 userIdMap[participantId] = true 269 userIds = append(userIds, participantId) 270 } 271 } 272 } 273 var users []*model.User 274 if opts.Extended { 275 var err error 276 users, err = s.User().GetProfileByIds(context.Background(), userIds, &store.UserGetByIdsOpts{}, true) 277 if err != nil { 278 return nil, errors.Wrapf(err, "failed to get threads for user id=%s", userId) 279 } 280 } else { 281 for _, userId := range userIds { 282 users = append(users, &model.User{Id: userId}) 283 } 284 } 285 286 result := &model.Threads{ 287 Total: totalCount, 288 Threads: []*model.ThreadResponse{}, 289 TotalUnreadMentions: totalUnreadMentions, 290 TotalUnreadThreads: totalUnreadThreads, 291 } 292 293 for _, thread := range threads { 294 var participants []*model.User 295 for _, participantId := range thread.Participants { 296 var participant *model.User 297 for _, u := range users { 298 if u.Id == participantId { 299 participant = u 300 break 301 } 302 } 303 if participant == nil { 304 return nil, errors.New("cannot find thread participant with id=" + participantId) 305 } 306 participants = append(participants, participant) 307 } 308 result.Threads = append(result.Threads, &model.ThreadResponse{ 309 PostId: thread.PostId, 310 ReplyCount: thread.ReplyCount, 311 LastReplyAt: thread.LastReplyAt, 312 LastViewedAt: thread.LastViewedAt, 313 UnreadReplies: thread.UnreadReplies, 314 UnreadMentions: thread.UnreadMentions, 315 Participants: participants, 316 Post: thread.Post.ToNilIfInvalid(), 317 }) 318 } 319 320 return result, nil 321 } 322 323 func (s *SqlThreadStore) GetThreadFollowers(threadID string) ([]string, error) { 324 var users []string 325 query, args, _ := s.getQueryBuilder(). 326 Select("ThreadMemberships.UserId"). 327 From("ThreadMemberships"). 328 Where(sq.Eq{"PostId": threadID}).ToSql() 329 _, err := s.GetReplica().Select(&users, query, args...) 330 331 if err != nil { 332 return nil, err 333 } 334 return users, nil 335 } 336 337 func (s *SqlThreadStore) GetThreadForUser(teamId string, threadMembership *model.ThreadMembership, extended bool) (*model.ThreadResponse, error) { 338 if !threadMembership.Following { 339 return nil, nil // in case the thread is not followed anymore - return nil error to be interpreted as 404 340 } 341 type JoinedThread struct { 342 PostId string 343 Following bool 344 ReplyCount int64 345 LastReplyAt int64 346 LastViewedAt int64 347 UnreadReplies int64 348 UnreadMentions int64 349 Participants model.StringArray 350 model.Post 351 } 352 353 unreadRepliesQuery, unreadRepliesArgs := sq. 354 Select("COUNT(Posts.Id)"). 355 From("Posts"). 356 Where(sq.And{ 357 sq.Eq{"Posts.RootId": threadMembership.PostId}, 358 sq.GtOrEq{"Posts.CreateAt": threadMembership.LastViewed}, 359 sq.Eq{"Posts.DeleteAt": 0}, 360 }).MustSql() 361 fetchConditions := sq.And{ 362 sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}, 363 sq.Eq{"Threads.PostId": threadMembership.PostId}, 364 } 365 366 var thread JoinedThread 367 query, threadArgs, _ := s.getQueryBuilder(). 368 Select("Threads.*, Posts.*"). 369 From("Threads"). 370 Column(sq.Alias(sq.Expr(unreadRepliesQuery), "UnreadReplies")). 371 LeftJoin("Posts ON Posts.Id = Threads.PostId"). 372 LeftJoin("Channels ON Posts.ChannelId = Channels.Id"). 373 Where(fetchConditions).ToSql() 374 args := append(unreadRepliesArgs, threadArgs...) 375 err := s.GetReplica().SelectOne(&thread, query, args...) 376 377 if err != nil { 378 return nil, err 379 } 380 381 thread.LastViewedAt = threadMembership.LastViewed 382 thread.UnreadMentions = threadMembership.UnreadMentions 383 384 var users []*model.User 385 if extended { 386 var err error 387 users, err = s.User().GetProfileByIds(context.Background(), thread.Participants, &store.UserGetByIdsOpts{}, true) 388 if err != nil { 389 return nil, errors.Wrapf(err, "failed to get thread for user id=%s", threadMembership.UserId) 390 } 391 } else { 392 for _, userId := range thread.Participants { 393 users = append(users, &model.User{Id: userId}) 394 } 395 } 396 397 result := &model.ThreadResponse{ 398 PostId: thread.PostId, 399 ReplyCount: thread.ReplyCount, 400 LastReplyAt: thread.LastReplyAt, 401 LastViewedAt: thread.LastViewedAt, 402 UnreadReplies: thread.UnreadReplies, 403 UnreadMentions: thread.UnreadMentions, 404 Participants: users, 405 Post: thread.Post.ToNilIfInvalid(), 406 } 407 408 return result, nil 409 } 410 411 func (s *SqlThreadStore) MarkAllAsReadInChannels(userID string, channelIDs []string) error { 412 var threadIDs []string 413 414 query, args, _ := s.getQueryBuilder(). 415 Select("ThreadMemberships.PostId"). 416 Join("Threads ON Threads.PostId = ThreadMemberships.PostId"). 417 Join("Channels ON Threads.ChannelId = Channels.Id"). 418 From("ThreadMemberships"). 419 Where(sq.Eq{"Threads.ChannelId": channelIDs}). 420 Where(sq.Eq{"ThreadMemberships.UserId": userID}). 421 ToSql() 422 423 _, err := s.GetReplica().Select(&threadIDs, query, args...) 424 if err != nil { 425 return errors.Wrapf(err, "failed to get thread membership with userid=%s", userID) 426 } 427 428 timestamp := model.GetMillis() 429 query, args, _ = s.getQueryBuilder(). 430 Update("ThreadMemberships"). 431 Where(sq.Eq{"PostId": threadIDs}). 432 Where(sq.Eq{"UserId": userID}). 433 Set("LastViewed", timestamp). 434 Set("UnreadMentions", 0). 435 ToSql() 436 if _, err := s.GetMaster().Exec(query, args...); err != nil { 437 return errors.Wrapf(err, "failed to update thread read state for user id=%s", userID) 438 } 439 return nil 440 441 } 442 443 func (s *SqlThreadStore) MarkAllAsRead(userId, teamId string) error { 444 memberships, err := s.GetMembershipsForUser(userId, teamId) 445 if err != nil { 446 return err 447 } 448 var membershipIds []string 449 for _, m := range memberships { 450 membershipIds = append(membershipIds, m.PostId) 451 } 452 timestamp := model.GetMillis() 453 query, args, _ := s.getQueryBuilder(). 454 Update("ThreadMemberships"). 455 Where(sq.Eq{"PostId": membershipIds}). 456 Where(sq.Eq{"UserId": userId}). 457 Set("LastViewed", timestamp). 458 Set("UnreadMentions", 0). 459 ToSql() 460 if _, err := s.GetMaster().Exec(query, args...); err != nil { 461 return errors.Wrapf(err, "failed to update thread read state for user id=%s", userId) 462 } 463 return nil 464 } 465 466 func (s *SqlThreadStore) MarkAsRead(userId, threadId string, timestamp int64) error { 467 query, args, _ := s.getQueryBuilder(). 468 Update("ThreadMemberships"). 469 Where(sq.Eq{"UserId": userId}). 470 Where(sq.Eq{"PostId": threadId}). 471 Set("LastViewed", timestamp). 472 ToSql() 473 if _, err := s.GetMaster().Exec(query, args...); err != nil { 474 return errors.Wrapf(err, "failed to update thread read state for user id=%s thread_id=%v", userId, threadId) 475 } 476 return nil 477 } 478 479 func (s *SqlThreadStore) Delete(threadId string) error { 480 query, args, _ := s.getQueryBuilder().Delete("Threads").Where(sq.Eq{"PostId": threadId}).ToSql() 481 if _, err := s.GetMaster().Exec(query, args...); err != nil { 482 return errors.Wrap(err, "failed to update threads") 483 } 484 485 return nil 486 } 487 488 func (s *SqlThreadStore) SaveMembership(membership *model.ThreadMembership) (*model.ThreadMembership, error) { 489 return s.saveMembership(s.GetMaster(), membership) 490 } 491 492 func (s *SqlThreadStore) saveMembership(ex gorp.SqlExecutor, membership *model.ThreadMembership) (*model.ThreadMembership, error) { 493 if err := ex.Insert(membership); err != nil { 494 return nil, errors.Wrapf(err, "failed to save thread membership with postid=%s userid=%s", membership.PostId, membership.UserId) 495 } 496 497 return membership, nil 498 } 499 500 func (s *SqlThreadStore) UpdateMembership(membership *model.ThreadMembership) (*model.ThreadMembership, error) { 501 return s.updateMembership(s.GetMaster(), membership) 502 } 503 504 func (s *SqlThreadStore) updateMembership(ex gorp.SqlExecutor, membership *model.ThreadMembership) (*model.ThreadMembership, error) { 505 if _, err := ex.Update(membership); err != nil { 506 return nil, errors.Wrapf(err, "failed to update thread membership with postid=%s userid=%s", membership.PostId, membership.UserId) 507 } 508 509 return membership, nil 510 } 511 512 func (s *SqlThreadStore) GetMembershipsForUser(userId, teamId string) ([]*model.ThreadMembership, error) { 513 var memberships []*model.ThreadMembership 514 515 query, args, _ := s.getQueryBuilder(). 516 Select("ThreadMemberships.*"). 517 Join("Threads ON Threads.PostId = ThreadMemberships.PostId"). 518 Join("Channels ON Threads.ChannelId = Channels.Id"). 519 From("ThreadMemberships"). 520 Where(sq.Or{sq.Eq{"Channels.TeamId": teamId}, sq.Eq{"Channels.TeamId": ""}}). 521 Where(sq.Eq{"ThreadMemberships.UserId": userId}). 522 ToSql() 523 524 _, err := s.GetReplica().Select(&memberships, query, args...) 525 if err != nil { 526 return nil, errors.Wrapf(err, "failed to get thread membership with userid=%s", userId) 527 } 528 return memberships, nil 529 } 530 531 func (s *SqlThreadStore) GetMembershipForUser(userId, postId string) (*model.ThreadMembership, error) { 532 return s.getMembershipForUser(s.GetReplica(), userId, postId) 533 } 534 535 func (s *SqlThreadStore) getMembershipForUser(ex gorp.SqlExecutor, userId, postId string) (*model.ThreadMembership, error) { 536 var membership model.ThreadMembership 537 err := ex.SelectOne(&membership, "SELECT * from ThreadMemberships WHERE UserId = :UserId AND PostId = :PostId", map[string]interface{}{"UserId": userId, "PostId": postId}) 538 if err != nil { 539 if err == sql.ErrNoRows { 540 return nil, store.NewErrNotFound("Thread", postId) 541 } 542 return nil, errors.Wrapf(err, "failed to get thread membership with userid=%s postid=%s", userId, postId) 543 } 544 return &membership, nil 545 } 546 547 func (s *SqlThreadStore) DeleteMembershipForUser(userId string, postId string) error { 548 if _, err := s.GetMaster().Exec("DELETE FROM ThreadMemberships Where PostId = :PostId AND UserId = :UserId", map[string]interface{}{"PostId": postId, "UserId": userId}); err != nil { 549 return errors.Wrap(err, "failed to update thread membership") 550 } 551 552 return nil 553 } 554 555 // MaintainMembership creates or updates a thread membership for the given user 556 // and post. This method is used to update the state of a membership in response 557 // to some events like: 558 // - post creation (mentions handling) 559 // - channel marked unread 560 // - user explicitly following a thread 561 func (s *SqlThreadStore) MaintainMembership(userId, postId string, opts store.ThreadMembershipOpts) (*model.ThreadMembership, error) { 562 trx, err := s.GetMaster().Begin() 563 if err != nil { 564 return nil, errors.Wrap(err, "begin_transaction") 565 } 566 defer finalizeTransaction(trx) 567 568 membership, err := s.getMembershipForUser(trx, userId, postId) 569 now := utils.MillisFromTime(time.Now()) 570 // if memebership exists, update it if: 571 // a. user started/stopped following a thread 572 // b. mention count changed 573 // c. user viewed a thread 574 if err == nil { 575 followingNeedsUpdate := (opts.UpdateFollowing && !membership.Following || membership.Following != opts.Following) 576 if followingNeedsUpdate || opts.IncrementMentions || opts.UpdateViewedTimestamp { 577 if followingNeedsUpdate { 578 membership.Following = opts.Following 579 } 580 if opts.UpdateViewedTimestamp { 581 membership.LastViewed = now 582 } 583 membership.LastUpdated = now 584 if opts.IncrementMentions { 585 membership.UnreadMentions += 1 586 } 587 if _, err = s.updateMembership(trx, membership); err != nil { 588 return nil, err 589 } 590 } 591 if err = trx.Commit(); err != nil { 592 return nil, errors.Wrap(err, "commit_transaction") 593 } 594 595 return membership, err 596 } 597 598 var nfErr *store.ErrNotFound 599 600 if !errors.As(err, &nfErr) { 601 return nil, errors.Wrap(err, "failed to get thread membership") 602 } 603 membership = &model.ThreadMembership{ 604 PostId: postId, 605 UserId: userId, 606 Following: opts.Following, 607 LastUpdated: now, 608 } 609 if opts.IncrementMentions { 610 membership.UnreadMentions = 1 611 } 612 if opts.UpdateViewedTimestamp { 613 membership.LastViewed = now 614 } 615 membership, err = s.saveMembership(trx, membership) 616 if err != nil { 617 return nil, err 618 } 619 620 if opts.UpdateParticipants { 621 thread, getErr := s.get(trx, postId) 622 if getErr != nil { 623 return nil, getErr 624 } 625 if thread != nil && !thread.Participants.Contains(userId) { 626 thread.Participants = append(thread.Participants, userId) 627 if _, err = s.update(trx, thread); err != nil { 628 return nil, err 629 } 630 } 631 } 632 if err = trx.Commit(); err != nil { 633 return nil, errors.Wrap(err, "commit_transaction") 634 } 635 return membership, err 636 } 637 638 func (s *SqlThreadStore) CollectThreadsWithNewerReplies(userId string, channelIds []string, timestamp int64) ([]string, error) { 639 var changedThreads []string 640 query, args, _ := s.getQueryBuilder(). 641 Select("Threads.PostId"). 642 From("Threads"). 643 LeftJoin("ChannelMembers ON ChannelMembers.ChannelId=Threads.ChannelId"). 644 Where(sq.And{ 645 sq.Eq{"Threads.ChannelId": channelIds}, 646 sq.Eq{"ChannelMembers.UserId": userId}, 647 sq.Or{ 648 sq.Expr("Threads.LastReplyAt >= ChannelMembers.LastViewedAt"), 649 sq.GtOrEq{"Threads.LastReplyAt": timestamp}, 650 }, 651 }). 652 ToSql() 653 if _, err := s.GetReplica().Select(&changedThreads, query, args...); err != nil { 654 return nil, errors.Wrap(err, "failed to fetch threads") 655 } 656 return changedThreads, nil 657 } 658 659 func (s *SqlThreadStore) UpdateUnreadsByChannel(userId string, changedThreads []string, timestamp int64, updateViewedTimestamp bool) error { 660 if len(changedThreads) == 0 { 661 return nil 662 } 663 664 qb := s.getQueryBuilder(). 665 Update("ThreadMemberships"). 666 Where(sq.Eq{"UserId": userId, "PostId": changedThreads}). 667 Set("LastUpdated", timestamp) 668 669 if updateViewedTimestamp { 670 qb = qb.Set("LastViewed", timestamp) 671 } 672 updateQuery, updateArgs, _ := qb.ToSql() 673 674 if _, err := s.GetMaster().Exec(updateQuery, updateArgs...); err != nil { 675 return errors.Wrap(err, "failed to update thread membership") 676 } 677 678 return nil 679 } 680 681 func (s *SqlThreadStore) GetPosts(threadId string, since int64) ([]*model.Post, error) { 682 query, args, _ := s.getQueryBuilder(). 683 Select("*"). 684 From("Posts"). 685 Where(sq.Eq{"RootId": threadId}). 686 Where(sq.Eq{"DeleteAt": 0}). 687 Where(sq.GtOrEq{"UpdateAt": since}).ToSql() 688 var result []*model.Post 689 if _, err := s.GetReplica().Select(&result, query, args...); err != nil { 690 return nil, errors.Wrap(err, "failed to fetch thread posts") 691 } 692 return result, nil 693 } 694 695 // PermanentDeleteBatchForRetentionPolicies deletes a batch of records which are affected by 696 // the global or a granular retention policy. 697 // See `genericPermanentDeleteBatchForRetentionPolicies` for details. 698 func (s *SqlThreadStore) PermanentDeleteBatchForRetentionPolicies(now, globalPolicyEndTime, limit int64, cursor model.RetentionPolicyCursor) (int64, model.RetentionPolicyCursor, error) { 699 builder := s.getQueryBuilder(). 700 Select("Threads.PostId"). 701 From("Threads") 702 return genericPermanentDeleteBatchForRetentionPolicies(RetentionPolicyBatchDeletionInfo{ 703 BaseBuilder: builder, 704 Table: "Threads", 705 TimeColumn: "LastReplyAt", 706 PrimaryKeys: []string{"PostId"}, 707 ChannelIDTable: "Threads", 708 NowMillis: now, 709 GlobalPolicyEndTime: globalPolicyEndTime, 710 Limit: limit, 711 }, s.SqlStore, cursor) 712 } 713 714 // PermanentDeleteBatchThreadMembershipsForRetentionPolicies deletes a batch of records 715 // which are affected by the global or a granular retention policy. 716 // See `genericPermanentDeleteBatchForRetentionPolicies` for details. 717 func (s *SqlThreadStore) PermanentDeleteBatchThreadMembershipsForRetentionPolicies(now, globalPolicyEndTime, limit int64, cursor model.RetentionPolicyCursor) (int64, model.RetentionPolicyCursor, error) { 718 builder := s.getQueryBuilder(). 719 Select("ThreadMemberships.PostId"). 720 From("ThreadMemberships"). 721 InnerJoin("Threads ON ThreadMemberships.PostId = Threads.PostId") 722 return genericPermanentDeleteBatchForRetentionPolicies(RetentionPolicyBatchDeletionInfo{ 723 BaseBuilder: builder, 724 Table: "ThreadMemberships", 725 TimeColumn: "LastUpdated", 726 PrimaryKeys: []string{"PostId"}, 727 ChannelIDTable: "Threads", 728 NowMillis: now, 729 GlobalPolicyEndTime: globalPolicyEndTime, 730 Limit: limit, 731 }, s.SqlStore, cursor) 732 } 733 734 // DeleteOrphanedRows removes orphaned rows from Threads and ThreadMemberships 735 func (s *SqlThreadStore) DeleteOrphanedRows(limit int) (deleted int64, err error) { 736 // We need the extra level of nesting to deal with MySQL's locking 737 const threadsQuery = ` 738 DELETE FROM Threads WHERE PostId IN ( 739 SELECT * FROM ( 740 SELECT Threads.PostId FROM Threads 741 LEFT JOIN Channels ON Threads.ChannelId = Channels.Id 742 WHERE Channels.Id IS NULL 743 LIMIT :Limit 744 ) AS A 745 )` 746 // We only delete a thread membership if the entire thread no longer exists, 747 // not if the root post has been deleted 748 const threadMembershipsQuery = ` 749 DELETE FROM ThreadMemberships WHERE PostId IN ( 750 SELECT * FROM ( 751 SELECT ThreadMemberships.PostId FROM ThreadMemberships 752 LEFT JOIN Threads ON ThreadMemberships.PostId = Threads.PostId 753 WHERE Threads.PostId IS NULL 754 LIMIT :Limit 755 ) AS A 756 )` 757 props := map[string]interface{}{"Limit": limit} 758 result, err := s.GetMaster().Exec(threadsQuery, props) 759 if err != nil { 760 return 761 } 762 rpcDeleted, err := result.RowsAffected() 763 if err != nil { 764 return 765 } 766 result, err = s.GetMaster().Exec(threadMembershipsQuery, props) 767 if err != nil { 768 return 769 } 770 rptDeleted, err := result.RowsAffected() 771 if err != nil { 772 return 773 } 774 deleted = rpcDeleted + rptDeleted 775 return 776 }