github.com/mattermost/mattermost-server/v5@v5.39.3/store/storetest/session_store.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package storetest
     5  
     6  import (
     7  	"context"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/mattermost/mattermost-server/v5/model"
    14  	"github.com/mattermost/mattermost-server/v5/store"
    15  )
    16  
    17  const (
    18  	TenMinutes = 600000
    19  )
    20  
    21  func TestSessionStore(t *testing.T, ss store.Store) {
    22  	// Run serially to prevent interfering with other tests
    23  	testSessionCleanup(t, ss)
    24  
    25  	t.Run("Save", func(t *testing.T) { testSessionStoreSave(t, ss) })
    26  	t.Run("SessionGet", func(t *testing.T) { testSessionGet(t, ss) })
    27  	t.Run("SessionGetWithDeviceId", func(t *testing.T) { testSessionGetWithDeviceId(t, ss) })
    28  	t.Run("SessionRemove", func(t *testing.T) { testSessionRemove(t, ss) })
    29  	t.Run("SessionRemoveAll", func(t *testing.T) { testSessionRemoveAll(t, ss) })
    30  	t.Run("SessionRemoveByUser", func(t *testing.T) { testSessionRemoveByUser(t, ss) })
    31  	t.Run("SessionRemoveToken", func(t *testing.T) { testSessionRemoveToken(t, ss) })
    32  	t.Run("SessionUpdateDeviceId", func(t *testing.T) { testSessionUpdateDeviceId(t, ss) })
    33  	t.Run("SessionUpdateDeviceId2", func(t *testing.T) { testSessionUpdateDeviceId2(t, ss) })
    34  	t.Run("UpdateExpiresAt", func(t *testing.T) { testSessionStoreUpdateExpiresAt(t, ss) })
    35  	t.Run("UpdateLastActivityAt", func(t *testing.T) { testSessionStoreUpdateLastActivityAt(t, ss) })
    36  	t.Run("SessionCount", func(t *testing.T) { testSessionCount(t, ss) })
    37  	t.Run("GetSessionsExpired", func(t *testing.T) { testGetSessionsExpired(t, ss) })
    38  	t.Run("UpdateExpiredNotify", func(t *testing.T) { testUpdateExpiredNotify(t, ss) })
    39  }
    40  
    41  func testSessionStoreSave(t *testing.T, ss store.Store) {
    42  	s1 := &model.Session{}
    43  	s1.UserId = model.NewId()
    44  
    45  	_, err := ss.Session().Save(s1)
    46  	require.NoError(t, err)
    47  }
    48  
    49  func testSessionGet(t *testing.T, ss store.Store) {
    50  	s1 := &model.Session{}
    51  	s1.UserId = model.NewId()
    52  
    53  	s1, err := ss.Session().Save(s1)
    54  	require.NoError(t, err)
    55  
    56  	s2 := &model.Session{}
    57  	s2.UserId = s1.UserId
    58  
    59  	s2, err = ss.Session().Save(s2)
    60  	require.NoError(t, err)
    61  
    62  	s3 := &model.Session{}
    63  	s3.UserId = s1.UserId
    64  	s3.ExpiresAt = 1
    65  
    66  	s3, err = ss.Session().Save(s3)
    67  	require.NoError(t, err)
    68  
    69  	session, err := ss.Session().Get(context.Background(), s1.Id)
    70  	require.NoError(t, err)
    71  	require.Equal(t, session.Id, s1.Id, "should match")
    72  
    73  	data, err := ss.Session().GetSessions(s1.UserId)
    74  	require.NoError(t, err)
    75  	require.Len(t, data, 3, "should match len")
    76  }
    77  
    78  func testSessionGetWithDeviceId(t *testing.T, ss store.Store) {
    79  	s1 := &model.Session{}
    80  	s1.UserId = model.NewId()
    81  	s1.ExpiresAt = model.GetMillis() + 10000
    82  
    83  	s1, err := ss.Session().Save(s1)
    84  	require.NoError(t, err)
    85  
    86  	s2 := &model.Session{}
    87  	s2.UserId = s1.UserId
    88  	s2.DeviceId = model.NewId()
    89  	s2.ExpiresAt = model.GetMillis() + 10000
    90  
    91  	s2, err = ss.Session().Save(s2)
    92  	require.NoError(t, err)
    93  
    94  	s3 := &model.Session{}
    95  	s3.UserId = s1.UserId
    96  	s3.ExpiresAt = 1
    97  	s3.DeviceId = model.NewId()
    98  
    99  	s3, err = ss.Session().Save(s3)
   100  	require.NoError(t, err)
   101  
   102  	data, err := ss.Session().GetSessionsWithActiveDeviceIds(s1.UserId)
   103  	require.NoError(t, err)
   104  	require.Len(t, data, 1, "should match len")
   105  }
   106  
   107  func testSessionRemove(t *testing.T, ss store.Store) {
   108  	s1 := &model.Session{}
   109  	s1.UserId = model.NewId()
   110  
   111  	s1, err := ss.Session().Save(s1)
   112  	require.NoError(t, err)
   113  
   114  	session, err := ss.Session().Get(context.Background(), s1.Id)
   115  	require.NoError(t, err)
   116  	require.Equal(t, session.Id, s1.Id, "should match")
   117  
   118  	removeErr := ss.Session().Remove(s1.Id)
   119  	require.NoError(t, removeErr)
   120  
   121  	_, err = ss.Session().Get(context.Background(), s1.Id)
   122  	require.Error(t, err, "should have been removed")
   123  }
   124  
   125  func testSessionRemoveAll(t *testing.T, ss store.Store) {
   126  	s1 := &model.Session{}
   127  	s1.UserId = model.NewId()
   128  
   129  	s1, err := ss.Session().Save(s1)
   130  	require.NoError(t, err)
   131  
   132  	session, err := ss.Session().Get(context.Background(), s1.Id)
   133  	require.NoError(t, err)
   134  	require.Equal(t, session.Id, s1.Id, "should match")
   135  
   136  	removeErr := ss.Session().RemoveAllSessions()
   137  	require.NoError(t, removeErr)
   138  
   139  	_, err = ss.Session().Get(context.Background(), s1.Id)
   140  	require.Error(t, err, "should have been removed")
   141  }
   142  
   143  func testSessionRemoveByUser(t *testing.T, ss store.Store) {
   144  	s1 := &model.Session{}
   145  	s1.UserId = model.NewId()
   146  
   147  	s1, err := ss.Session().Save(s1)
   148  	require.NoError(t, err)
   149  
   150  	session, err := ss.Session().Get(context.Background(), s1.Id)
   151  	require.NoError(t, err)
   152  	require.Equal(t, session.Id, s1.Id, "should match")
   153  
   154  	deleteErr := ss.Session().PermanentDeleteSessionsByUser(s1.UserId)
   155  	require.NoError(t, deleteErr)
   156  
   157  	_, err = ss.Session().Get(context.Background(), s1.Id)
   158  	require.Error(t, err, "should have been removed")
   159  }
   160  
   161  func testSessionRemoveToken(t *testing.T, ss store.Store) {
   162  	s1 := &model.Session{}
   163  	s1.UserId = model.NewId()
   164  
   165  	s1, err := ss.Session().Save(s1)
   166  	require.NoError(t, err)
   167  
   168  	session, err := ss.Session().Get(context.Background(), s1.Id)
   169  	require.NoError(t, err)
   170  	require.Equal(t, session.Id, s1.Id, "should match")
   171  
   172  	removeErr := ss.Session().Remove(s1.Token)
   173  	require.NoError(t, removeErr)
   174  
   175  	_, err = ss.Session().Get(context.Background(), s1.Id)
   176  	require.Error(t, err, "should have been removed")
   177  
   178  	data, err := ss.Session().GetSessions(s1.UserId)
   179  	require.NoError(t, err)
   180  	require.Empty(t, data, "should match len")
   181  }
   182  
   183  func testSessionUpdateDeviceId(t *testing.T, ss store.Store) {
   184  	s1 := &model.Session{}
   185  	s1.UserId = model.NewId()
   186  
   187  	s1, err := ss.Session().Save(s1)
   188  	require.NoError(t, err)
   189  
   190  	_, err = ss.Session().UpdateDeviceId(s1.Id, model.PUSH_NOTIFY_APPLE+":1234567890", s1.ExpiresAt)
   191  	require.NoError(t, err)
   192  
   193  	s2 := &model.Session{}
   194  	s2.UserId = model.NewId()
   195  
   196  	s2, err = ss.Session().Save(s2)
   197  	require.NoError(t, err)
   198  
   199  	_, err = ss.Session().UpdateDeviceId(s2.Id, model.PUSH_NOTIFY_APPLE+":1234567890", s1.ExpiresAt)
   200  	require.NoError(t, err)
   201  }
   202  
   203  func testSessionUpdateDeviceId2(t *testing.T, ss store.Store) {
   204  	s1 := &model.Session{}
   205  	s1.UserId = model.NewId()
   206  
   207  	s1, err := ss.Session().Save(s1)
   208  	require.NoError(t, err)
   209  
   210  	_, err = ss.Session().UpdateDeviceId(s1.Id, model.PUSH_NOTIFY_APPLE_REACT_NATIVE+":1234567890", s1.ExpiresAt)
   211  	require.NoError(t, err)
   212  
   213  	s2 := &model.Session{}
   214  	s2.UserId = model.NewId()
   215  
   216  	s2, err = ss.Session().Save(s2)
   217  	require.NoError(t, err)
   218  
   219  	_, err = ss.Session().UpdateDeviceId(s2.Id, model.PUSH_NOTIFY_APPLE_REACT_NATIVE+":1234567890", s1.ExpiresAt)
   220  	require.NoError(t, err)
   221  }
   222  
   223  func testSessionStoreUpdateExpiresAt(t *testing.T, ss store.Store) {
   224  	s1 := &model.Session{}
   225  	s1.UserId = model.NewId()
   226  
   227  	s1, err := ss.Session().Save(s1)
   228  	require.NoError(t, err)
   229  
   230  	err = ss.Session().UpdateExpiresAt(s1.Id, 1234567890)
   231  	require.NoError(t, err)
   232  
   233  	session, err := ss.Session().Get(context.Background(), s1.Id)
   234  	require.NoError(t, err)
   235  	require.EqualValues(t, session.ExpiresAt, 1234567890, "ExpiresAt not updated correctly")
   236  }
   237  
   238  func testSessionStoreUpdateLastActivityAt(t *testing.T, ss store.Store) {
   239  	s1 := &model.Session{}
   240  	s1.UserId = model.NewId()
   241  
   242  	s1, err := ss.Session().Save(s1)
   243  	require.NoError(t, err)
   244  
   245  	err = ss.Session().UpdateLastActivityAt(s1.Id, 1234567890)
   246  	require.NoError(t, err)
   247  
   248  	session, err := ss.Session().Get(context.Background(), s1.Id)
   249  	require.NoError(t, err)
   250  	require.EqualValues(t, session.LastActivityAt, 1234567890, "LastActivityAt not updated correctly")
   251  }
   252  
   253  func testSessionCount(t *testing.T, ss store.Store) {
   254  	s1 := &model.Session{}
   255  	s1.UserId = model.NewId()
   256  	s1.ExpiresAt = model.GetMillis() + 100000
   257  
   258  	s1, err := ss.Session().Save(s1)
   259  	require.NoError(t, err)
   260  
   261  	count, err := ss.Session().AnalyticsSessionCount()
   262  	require.NoError(t, err)
   263  	require.NotZero(t, count, "should have at least 1 session")
   264  }
   265  
   266  func testSessionCleanup(t *testing.T, ss store.Store) {
   267  	now := model.GetMillis()
   268  
   269  	s1 := &model.Session{}
   270  	s1.UserId = model.NewId()
   271  	s1.ExpiresAt = 0 // never expires
   272  
   273  	s1, err := ss.Session().Save(s1)
   274  	require.NoError(t, err)
   275  
   276  	s2 := &model.Session{}
   277  	s2.UserId = s1.UserId
   278  	s2.ExpiresAt = now + 1000000 // expires in the future
   279  
   280  	s2, err = ss.Session().Save(s2)
   281  	require.NoError(t, err)
   282  
   283  	s3 := &model.Session{}
   284  	s3.UserId = model.NewId()
   285  	s3.ExpiresAt = 1 // expired
   286  
   287  	s3, err = ss.Session().Save(s3)
   288  	require.NoError(t, err)
   289  
   290  	s4 := &model.Session{}
   291  	s4.UserId = model.NewId()
   292  	s4.ExpiresAt = 2 // expired
   293  
   294  	s4, err = ss.Session().Save(s4)
   295  	require.NoError(t, err)
   296  
   297  	ss.Session().Cleanup(now, 1)
   298  
   299  	_, err = ss.Session().Get(context.Background(), s1.Id)
   300  	assert.NoError(t, err)
   301  
   302  	_, err = ss.Session().Get(context.Background(), s2.Id)
   303  	assert.NoError(t, err)
   304  
   305  	_, err = ss.Session().Get(context.Background(), s3.Id)
   306  	assert.Error(t, err)
   307  
   308  	_, err = ss.Session().Get(context.Background(), s4.Id)
   309  	assert.Error(t, err)
   310  
   311  	removeErr := ss.Session().Remove(s1.Id)
   312  	require.NoError(t, removeErr)
   313  
   314  	removeErr = ss.Session().Remove(s2.Id)
   315  	require.NoError(t, removeErr)
   316  }
   317  
   318  func testGetSessionsExpired(t *testing.T, ss store.Store) {
   319  	now := model.GetMillis()
   320  
   321  	// Clear existing sessions.
   322  	err := ss.Session().RemoveAllSessions()
   323  	require.NoError(t, err)
   324  
   325  	s1 := &model.Session{}
   326  	s1.UserId = model.NewId()
   327  	s1.DeviceId = model.NewId()
   328  	s1.ExpiresAt = 0 // never expires
   329  	s1, err = ss.Session().Save(s1)
   330  	require.NoError(t, err)
   331  
   332  	s2 := &model.Session{}
   333  	s2.UserId = model.NewId()
   334  	s2.DeviceId = model.NewId()
   335  	s2.ExpiresAt = now - TenMinutes // expired within threshold
   336  	s2, err = ss.Session().Save(s2)
   337  	require.NoError(t, err)
   338  
   339  	s3 := &model.Session{}
   340  	s3.UserId = model.NewId()
   341  	s3.DeviceId = model.NewId()
   342  	s3.ExpiresAt = now - (TenMinutes * 100) // expired outside threshold
   343  	s3, err = ss.Session().Save(s3)
   344  	require.NoError(t, err)
   345  
   346  	s4 := &model.Session{}
   347  	s4.UserId = model.NewId()
   348  	s4.ExpiresAt = now - TenMinutes // expired within threshold, but not mobile
   349  	s4, err = ss.Session().Save(s4)
   350  	require.NoError(t, err)
   351  
   352  	s5 := &model.Session{}
   353  	s5.UserId = model.NewId()
   354  	s5.DeviceId = model.NewId()
   355  	s5.ExpiresAt = now + (TenMinutes * 100000) // not expired
   356  	s5, err = ss.Session().Save(s5)
   357  	require.NoError(t, err)
   358  
   359  	sessions, err := ss.Session().GetSessionsExpired(TenMinutes*2, true, true) // mobile only
   360  	require.NoError(t, err)
   361  	require.Len(t, sessions, 1)
   362  	require.Equal(t, s2.Id, sessions[0].Id)
   363  
   364  	sessions, err = ss.Session().GetSessionsExpired(TenMinutes*2, false, true) // all client types
   365  	require.NoError(t, err)
   366  	require.Len(t, sessions, 2)
   367  	expected := []string{s2.Id, s4.Id}
   368  	for _, sess := range sessions {
   369  		require.Contains(t, expected, sess.Id)
   370  	}
   371  }
   372  
   373  func testUpdateExpiredNotify(t *testing.T, ss store.Store) {
   374  	s1 := &model.Session{}
   375  	s1.UserId = model.NewId()
   376  	s1.DeviceId = model.NewId()
   377  	s1.ExpiresAt = model.GetMillis() + TenMinutes
   378  	s1, err := ss.Session().Save(s1)
   379  	require.NoError(t, err)
   380  
   381  	session, err := ss.Session().Get(context.Background(), s1.Id)
   382  	require.NoError(t, err)
   383  	require.False(t, session.ExpiredNotify)
   384  
   385  	err = ss.Session().UpdateExpiredNotify(session.Id, true)
   386  	require.NoError(t, err)
   387  	session, err = ss.Session().Get(context.Background(), s1.Id)
   388  	require.NoError(t, err)
   389  	require.True(t, session.ExpiredNotify)
   390  
   391  	err = ss.Session().UpdateExpiredNotify(session.Id, false)
   392  	require.NoError(t, err)
   393  	session, err = ss.Session().Get(context.Background(), s1.Id)
   394  	require.NoError(t, err)
   395  	require.False(t, session.ExpiredNotify)
   396  }