github.com/haalcala/mattermost-server-change-repo/v5@v5.33.2/store/storetest/thread_store.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package storetest
     5  
     6  import (
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/mattermost/mattermost-server/v5/model"
    14  	"github.com/mattermost/mattermost-server/v5/store"
    15  )
    16  
    17  func TestThreadStore(t *testing.T, ss store.Store, s SqlStore) {
    18  	t.Run("ThreadStorePopulation", func(t *testing.T) { testThreadStorePopulation(t, ss) })
    19  }
    20  
    21  func testThreadStorePopulation(t *testing.T, ss store.Store) {
    22  	makeSomePosts := func() []*model.Post {
    23  
    24  		u1 := model.User{
    25  			Email:    MakeEmail(),
    26  			Username: model.NewId(),
    27  		}
    28  
    29  		u, err := ss.User().Save(&u1)
    30  		require.NoError(t, err)
    31  
    32  		c, err2 := ss.Channel().Save(&model.Channel{
    33  			DisplayName: model.NewId(),
    34  			Type:        model.CHANNEL_OPEN,
    35  			Name:        model.NewId(),
    36  		}, 999)
    37  		require.NoError(t, err2)
    38  
    39  		_, err44 := ss.Channel().SaveMember(&model.ChannelMember{
    40  			ChannelId:   c.Id,
    41  			UserId:      u1.Id,
    42  			NotifyProps: model.GetDefaultChannelNotifyProps(),
    43  			MsgCount:    0,
    44  		})
    45  		require.NoError(t, err44)
    46  		o := model.Post{}
    47  		o.ChannelId = c.Id
    48  		o.UserId = u.Id
    49  		o.Message = "zz" + model.NewId() + "b"
    50  
    51  		otmp, err3 := ss.Post().Save(&o)
    52  		require.NoError(t, err3)
    53  		o2 := model.Post{}
    54  		o2.ChannelId = c.Id
    55  		o2.UserId = model.NewId()
    56  		o2.RootId = otmp.Id
    57  		o2.Message = "zz" + model.NewId() + "b"
    58  
    59  		o3 := model.Post{}
    60  		o3.ChannelId = c.Id
    61  		o3.UserId = u.Id
    62  		o3.RootId = otmp.Id
    63  		o3.Message = "zz" + model.NewId() + "b"
    64  
    65  		o4 := model.Post{}
    66  		o4.ChannelId = c.Id
    67  		o4.UserId = model.NewId()
    68  		o4.Message = "zz" + model.NewId() + "b"
    69  
    70  		newPosts, errIdx, err3 := ss.Post().SaveMultiple([]*model.Post{&o2, &o3, &o4})
    71  
    72  		olist, _ := ss.Post().Get(otmp.Id, true, false, false)
    73  		o1 := olist.Posts[olist.Order[0]]
    74  
    75  		newPosts = append([]*model.Post{o1}, newPosts...)
    76  		require.NoError(t, err3, "couldn't save item")
    77  		require.Equal(t, -1, errIdx)
    78  		require.Len(t, newPosts, 4)
    79  		require.Equal(t, int64(2), newPosts[0].ReplyCount)
    80  		require.Equal(t, int64(2), newPosts[1].ReplyCount)
    81  		require.Equal(t, int64(2), newPosts[2].ReplyCount)
    82  		require.Equal(t, int64(0), newPosts[3].ReplyCount)
    83  
    84  		return newPosts
    85  	}
    86  	t.Run("Save replies creates a thread", func(t *testing.T) {
    87  		newPosts := makeSomePosts()
    88  		thread, err := ss.Thread().Get(newPosts[0].Id)
    89  		require.NoError(t, err, "couldn't get thread")
    90  		require.NotNil(t, thread)
    91  		require.Equal(t, int64(2), thread.ReplyCount)
    92  		require.ElementsMatch(t, model.StringArray{newPosts[0].UserId, newPosts[1].UserId}, thread.Participants)
    93  
    94  		o5 := model.Post{}
    95  		o5.ChannelId = model.NewId()
    96  		o5.UserId = model.NewId()
    97  		o5.RootId = newPosts[0].Id
    98  		o5.Message = "zz" + model.NewId() + "b"
    99  
   100  		_, _, err = ss.Post().SaveMultiple([]*model.Post{&o5})
   101  		require.NoError(t, err, "couldn't save item")
   102  
   103  		thread, err = ss.Thread().Get(newPosts[0].Id)
   104  		require.NoError(t, err, "couldn't get thread")
   105  		require.NotNil(t, thread)
   106  		require.Equal(t, int64(3), thread.ReplyCount)
   107  		require.ElementsMatch(t, model.StringArray{newPosts[0].UserId, newPosts[1].UserId, o5.UserId}, thread.Participants)
   108  	})
   109  
   110  	t.Run("Delete a reply updates count on a thread", func(t *testing.T) {
   111  		newPosts := makeSomePosts()
   112  		thread, err := ss.Thread().Get(newPosts[0].Id)
   113  		require.NoError(t, err, "couldn't get thread")
   114  		require.NotNil(t, thread)
   115  		require.Equal(t, int64(2), thread.ReplyCount)
   116  		require.ElementsMatch(t, model.StringArray{newPosts[0].UserId, newPosts[1].UserId}, thread.Participants)
   117  
   118  		err = ss.Post().Delete(newPosts[1].Id, 1234, model.NewId())
   119  		require.NoError(t, err, "couldn't delete post")
   120  
   121  		thread, err = ss.Thread().Get(newPosts[0].Id)
   122  		require.NoError(t, err, "couldn't get thread")
   123  		require.NotNil(t, thread)
   124  		require.Equal(t, int64(1), thread.ReplyCount)
   125  		require.ElementsMatch(t, model.StringArray{newPosts[0].UserId, newPosts[1].UserId}, thread.Participants)
   126  	})
   127  
   128  	t.Run("Update reply should update the UpdateAt of the thread", func(t *testing.T) {
   129  		rootPost := model.Post{}
   130  		rootPost.RootId = model.NewId()
   131  		rootPost.ChannelId = model.NewId()
   132  		rootPost.UserId = model.NewId()
   133  		rootPost.Message = "zz" + model.NewId() + "b"
   134  
   135  		replyPost := model.Post{}
   136  		replyPost.ChannelId = rootPost.ChannelId
   137  		replyPost.UserId = model.NewId()
   138  		replyPost.Message = "zz" + model.NewId() + "b"
   139  		replyPost.RootId = rootPost.RootId
   140  
   141  		newPosts, _, err := ss.Post().SaveMultiple([]*model.Post{&rootPost, &replyPost})
   142  		require.NoError(t, err)
   143  
   144  		thread1, err := ss.Thread().Get(newPosts[0].RootId)
   145  		require.NoError(t, err)
   146  
   147  		rrootPost, err := ss.Post().GetSingle(rootPost.Id)
   148  		require.NoError(t, err)
   149  		require.Equal(t, rrootPost.UpdateAt, rootPost.UpdateAt)
   150  
   151  		replyPost2 := model.Post{}
   152  		replyPost2.ChannelId = rootPost.ChannelId
   153  		replyPost2.UserId = model.NewId()
   154  		replyPost2.Message = "zz" + model.NewId() + "b"
   155  		replyPost2.RootId = rootPost.Id
   156  
   157  		replyPost3 := model.Post{}
   158  		replyPost3.ChannelId = rootPost.ChannelId
   159  		replyPost3.UserId = model.NewId()
   160  		replyPost3.Message = "zz" + model.NewId() + "b"
   161  		replyPost3.RootId = rootPost.Id
   162  
   163  		_, _, err = ss.Post().SaveMultiple([]*model.Post{&replyPost2, &replyPost3})
   164  		require.NoError(t, err)
   165  
   166  		rrootPost2, err := ss.Post().GetSingle(rootPost.Id)
   167  		require.NoError(t, err)
   168  		require.Greater(t, rrootPost2.UpdateAt, rrootPost.UpdateAt)
   169  
   170  		thread2, err := ss.Thread().Get(rootPost.Id)
   171  		require.NoError(t, err)
   172  		require.Greater(t, thread2.LastReplyAt, thread1.LastReplyAt)
   173  	})
   174  
   175  	t.Run("Deleting reply should update the thread", func(t *testing.T) {
   176  		rootPost := model.Post{}
   177  		rootPost.RootId = model.NewId()
   178  		rootPost.ChannelId = model.NewId()
   179  		rootPost.UserId = model.NewId()
   180  		rootPost.Message = "zz" + model.NewId() + "b"
   181  
   182  		replyPost := model.Post{}
   183  		replyPost.ChannelId = rootPost.ChannelId
   184  		replyPost.UserId = model.NewId()
   185  		replyPost.Message = "zz" + model.NewId() + "b"
   186  		replyPost.RootId = rootPost.RootId
   187  
   188  		newPosts, _, err := ss.Post().SaveMultiple([]*model.Post{&rootPost, &replyPost})
   189  		require.NoError(t, err)
   190  
   191  		thread1, err := ss.Thread().Get(newPosts[0].RootId)
   192  		require.NoError(t, err)
   193  		require.EqualValues(t, thread1.ReplyCount, 2)
   194  		require.Len(t, thread1.Participants, 2)
   195  
   196  		err = ss.Post().Delete(replyPost.Id, 123, model.NewId())
   197  		require.NoError(t, err)
   198  
   199  		thread2, err := ss.Thread().Get(rootPost.RootId)
   200  		require.NoError(t, err)
   201  		require.EqualValues(t, thread2.ReplyCount, 1)
   202  		require.Len(t, thread2.Participants, 2)
   203  	})
   204  
   205  	t.Run("Deleting root post should delete the thread", func(t *testing.T) {
   206  		rootPost := model.Post{}
   207  		rootPost.ChannelId = model.NewId()
   208  		rootPost.UserId = model.NewId()
   209  		rootPost.Message = "zz" + model.NewId() + "b"
   210  
   211  		newPosts1, _, err := ss.Post().SaveMultiple([]*model.Post{&rootPost})
   212  		require.NoError(t, err)
   213  
   214  		replyPost := model.Post{}
   215  		replyPost.ChannelId = rootPost.ChannelId
   216  		replyPost.UserId = model.NewId()
   217  		replyPost.Message = "zz" + model.NewId() + "b"
   218  		replyPost.RootId = newPosts1[0].Id
   219  
   220  		_, _, err = ss.Post().SaveMultiple([]*model.Post{&replyPost})
   221  		require.NoError(t, err)
   222  
   223  		thread1, err := ss.Thread().Get(newPosts1[0].Id)
   224  		require.NoError(t, err)
   225  		require.EqualValues(t, thread1.ReplyCount, 1)
   226  		require.Len(t, thread1.Participants, 2)
   227  
   228  		err = ss.Post().PermanentDeleteByUser(rootPost.UserId)
   229  		require.NoError(t, err)
   230  
   231  		thread2, _ := ss.Thread().Get(rootPost.Id)
   232  		require.Nil(t, thread2)
   233  	})
   234  
   235  	t.Run("Thread last updated is changed when channel is updated after UpdateLastViewedAtPost", func(t *testing.T) {
   236  		newPosts := makeSomePosts()
   237  
   238  		require.NoError(t, ss.Thread().CreateMembershipIfNeeded(newPosts[0].UserId, newPosts[0].Id, true, false, true))
   239  		m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
   240  		require.NoError(t, err1)
   241  		m.LastUpdated -= 1000
   242  		_, err := ss.Thread().UpdateMembership(m)
   243  		require.NoError(t, err)
   244  
   245  		_, err = ss.Channel().UpdateLastViewedAtPost(newPosts[0], newPosts[0].UserId, 0, true)
   246  		require.NoError(t, err)
   247  
   248  		assert.Eventually(t, func() bool {
   249  			m2, err2 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
   250  			require.NoError(t, err2)
   251  			return m2.LastUpdated > m.LastUpdated
   252  		}, time.Second, 10*time.Millisecond)
   253  	})
   254  
   255  	t.Run("Thread last updated is changed when channel is updated after IncrementMentionCount", func(t *testing.T) {
   256  		newPosts := makeSomePosts()
   257  
   258  		require.NoError(t, ss.Thread().CreateMembershipIfNeeded(newPosts[0].UserId, newPosts[0].Id, true, false, true))
   259  		m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
   260  		require.NoError(t, err1)
   261  		m.LastUpdated -= 1000
   262  		_, err := ss.Thread().UpdateMembership(m)
   263  		require.NoError(t, err)
   264  
   265  		err = ss.Channel().IncrementMentionCount(newPosts[0].ChannelId, newPosts[0].UserId, true)
   266  		require.NoError(t, err)
   267  
   268  		assert.Eventually(t, func() bool {
   269  			m2, err2 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
   270  			require.NoError(t, err2)
   271  			return m2.LastUpdated > m.LastUpdated
   272  		}, time.Second, 10*time.Millisecond)
   273  	})
   274  
   275  	t.Run("Thread last updated is changed when channel is updated after UpdateLastViewedAt", func(t *testing.T) {
   276  		newPosts := makeSomePosts()
   277  
   278  		require.NoError(t, ss.Thread().CreateMembershipIfNeeded(newPosts[0].UserId, newPosts[0].Id, true, false, true))
   279  		m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
   280  		require.NoError(t, err1)
   281  		m.LastUpdated -= 1000
   282  		_, err := ss.Thread().UpdateMembership(m)
   283  		require.NoError(t, err)
   284  
   285  		_, err = ss.Channel().UpdateLastViewedAt([]string{newPosts[0].ChannelId}, newPosts[0].UserId, true)
   286  		require.NoError(t, err)
   287  
   288  		assert.Eventually(t, func() bool {
   289  			m2, err2 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
   290  			require.NoError(t, err2)
   291  			return m2.LastUpdated > m.LastUpdated
   292  		}, time.Second, 10*time.Millisecond)
   293  	})
   294  
   295  	t.Run("Thread last updated is changed when channel is updated after UpdateLastViewedAtPost for mark unread", func(t *testing.T) {
   296  		newPosts := makeSomePosts()
   297  
   298  		require.NoError(t, ss.Thread().CreateMembershipIfNeeded(newPosts[0].UserId, newPosts[0].Id, true, false, true))
   299  		m, err1 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
   300  		require.NoError(t, err1)
   301  		m.LastUpdated += 1000
   302  		_, err := ss.Thread().UpdateMembership(m)
   303  		require.NoError(t, err)
   304  
   305  		_, err = ss.Channel().UpdateLastViewedAtPost(newPosts[0], newPosts[0].UserId, 0, true)
   306  		require.NoError(t, err)
   307  
   308  		assert.Eventually(t, func() bool {
   309  			m2, err2 := ss.Thread().GetMembershipForUser(newPosts[0].UserId, newPosts[0].Id)
   310  			require.NoError(t, err2)
   311  			return m2.LastUpdated < m.LastUpdated
   312  		}, time.Second, 10*time.Millisecond)
   313  	})
   314  }