github.com/masterhung0112/hk_server/v5@v5.0.0-20220302090640-ec71aef15e1c/services/users/session.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package users
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"time"
    10  
    11  	"github.com/masterhung0112/hk_server/v5/model"
    12  	"github.com/masterhung0112/hk_server/v5/shared/mlog"
    13  	"github.com/masterhung0112/hk_server/v5/store/sqlstore"
    14  	"github.com/pkg/errors"
    15  )
    16  
    17  func (us *UserService) ReturnSessionToPool(session *model.Session) {
    18  	if session != nil {
    19  		session.Id = ""
    20  		us.sessionPool.Put(session)
    21  	}
    22  }
    23  
    24  func (us *UserService) CreateSession(session *model.Session) (*model.Session, error) {
    25  	session.Token = ""
    26  
    27  	session, err := us.sessionStore.Save(session)
    28  	if err != nil {
    29  		return nil, err
    30  	}
    31  
    32  	us.AddSessionToCache(session)
    33  
    34  	return session, nil
    35  }
    36  
    37  func (us *UserService) GetSession(token string) (*model.Session, error) {
    38  	var session = us.sessionPool.Get().(*model.Session)
    39  	if err := us.sessionCache.Get(token, session); err == nil {
    40  		if us.metrics != nil {
    41  			us.metrics.IncrementMemCacheHitCounterSession()
    42  		}
    43  	} else {
    44  		if us.metrics != nil {
    45  			us.metrics.IncrementMemCacheMissCounterSession()
    46  		}
    47  	}
    48  
    49  	if session.Id != "" {
    50  		return session, nil
    51  	}
    52  
    53  	return us.GetSessionContext(sqlstore.WithMaster(context.Background()), token)
    54  }
    55  
    56  func (us *UserService) GetSessionContext(ctx context.Context, token string) (*model.Session, error) {
    57  	return us.sessionStore.Get(ctx, token)
    58  }
    59  
    60  func (us *UserService) GetSessions(userID string) ([]*model.Session, error) {
    61  	return us.sessionStore.GetSessions(userID)
    62  }
    63  
    64  func (us *UserService) AddSessionToCache(session *model.Session) {
    65  	us.sessionCache.SetWithExpiry(session.Token, session, time.Duration(int64(*us.config().ServiceSettings.SessionCacheInMinutes))*time.Minute)
    66  }
    67  
    68  func (us *UserService) SessionCacheLength() int {
    69  	if l, err := us.sessionCache.Len(); err == nil {
    70  		return l
    71  	}
    72  	return 0
    73  }
    74  
    75  func (us *UserService) ClearUserSessionCacheLocal(userID string) {
    76  	if keys, err := us.sessionCache.Keys(); err == nil {
    77  		var session *model.Session
    78  		for _, key := range keys {
    79  			if err := us.sessionCache.Get(key, &session); err == nil {
    80  				if session.UserId == userID {
    81  					us.sessionCache.Remove(key)
    82  					if us.metrics != nil {
    83  						us.metrics.IncrementMemCacheInvalidationCounterSession()
    84  					}
    85  				}
    86  			}
    87  		}
    88  	}
    89  }
    90  
// ClearAllUsersSessionCacheLocal purges this node's session cache. It sends no
// cluster message; ClearAllUsersSessionCache additionally notifies the cluster.
func (us *UserService) ClearAllUsersSessionCacheLocal() {
	us.sessionCache.Purge()
}
    94  
    95  func (us *UserService) ClearUserSessionCache(userID string) {
    96  	us.ClearUserSessionCacheLocal(userID)
    97  
    98  	if us.cluster != nil {
    99  		msg := &model.ClusterMessage{
   100  			Event:    model.CLUSTER_EVENT_CLEAR_SESSION_CACHE_FOR_USER,
   101  			SendType: model.CLUSTER_SEND_RELIABLE,
   102  			Data:     userID,
   103  		}
   104  		us.cluster.SendClusterMessage(msg)
   105  	}
   106  }
   107  
   108  func (us *UserService) ClearAllUsersSessionCache() {
   109  	us.ClearAllUsersSessionCacheLocal()
   110  
   111  	if us.cluster != nil {
   112  		msg := &model.ClusterMessage{
   113  			Event:    model.CLUSTER_EVENT_CLEAR_SESSION_CACHE_FOR_ALL_USERS,
   114  			SendType: model.CLUSTER_SEND_RELIABLE,
   115  		}
   116  		us.cluster.SendClusterMessage(msg)
   117  	}
   118  }
   119  
   120  func (us *UserService) GetSessionByID(sessionID string) (*model.Session, error) {
   121  	return us.sessionStore.Get(context.Background(), sessionID)
   122  }
   123  
   124  func (us *UserService) RevokeSessionsFromAllUsers() error {
   125  	// revoke tokens before sessions so they can't be used to relogin
   126  	nErr := us.oAuthStore.RemoveAllAccessData()
   127  	if nErr != nil {
   128  		return errors.Wrap(DeleteAllAccessDataError, nErr.Error())
   129  	}
   130  	err := us.sessionStore.RemoveAllSessions()
   131  	if err != nil {
   132  		return err
   133  	}
   134  
   135  	us.ClearAllUsersSessionCache()
   136  	return nil
   137  }
   138  
   139  func (us *UserService) RevokeSessionsForDeviceId(userID string, deviceID string, currentSessionId string) error {
   140  	sessions, err := us.sessionStore.GetSessions(userID)
   141  	if err != nil {
   142  		return err
   143  	}
   144  	for _, session := range sessions {
   145  		if session.DeviceId == deviceID && session.Id != currentSessionId {
   146  			mlog.Debug("Revoking sessionId for userId. Re-login with the same device Id", mlog.String("session_id", session.Id), mlog.String("user_id", userID))
   147  			if err := us.RevokeSession(session); err != nil {
   148  				mlog.Warn("Could not revoke session for device", mlog.String("device_id", deviceID), mlog.Err(err))
   149  			}
   150  		}
   151  	}
   152  
   153  	return nil
   154  }
   155  
   156  func (us *UserService) RevokeSession(session *model.Session) error {
   157  	if session.IsOAuth {
   158  		if err := us.RevokeAccessToken(session.Token); err != nil {
   159  			return err
   160  		}
   161  	} else {
   162  		if err := us.sessionStore.Remove(session.Id); err != nil {
   163  			return errors.Wrap(DeleteSessionError, err.Error())
   164  		}
   165  	}
   166  
   167  	us.ClearUserSessionCache(session.UserId)
   168  
   169  	return nil
   170  }
   171  
// RevokeAccessToken revokes an OAuth access token together with the session
// stored under that token.
//
// The session is removed from the session store in a background goroutine
// while the OAuth access data is looked up and removed. Errors are wrapped as
// GetTokenError (access data lookup failed), DeleteTokenError (access data
// removal failed) or DeleteSessionError (session removal failed). On success,
// if a session for the token was resolved, that user's session cache is
// cleared locally and across the cluster.
func (us *UserService) RevokeAccessToken(token string) error {
	// The lookup error is ignored. If no session is found (presumably nil on
	// error), only the final cache-clearing step is skipped.
	session, _ := us.GetSession(token)

	// session may be a pooled object (cache hit) or one loaded from the store;
	// either way it is put into the pool on return. ReturnSessionToPool
	// ignores nil.
	defer us.ReturnSessionToPool(session)

	// Capacity 1 lets the goroutine send and exit even when this function
	// returns early on one of the error paths below without reading schan.
	schan := make(chan error, 1)
	go func() {
		schan <- us.sessionStore.Remove(token)
		close(schan)
	}()

	if _, err := us.oAuthStore.GetAccessData(token); err != nil {
		// NOTE(review): this path (and the next) returns without waiting on
		// schan or clearing the session cache. The session may already be
		// deleted from the store yet stay usable from the cache until it
		// expires. Confirm whether that is acceptable.
		return errors.Wrap(GetTokenError, err.Error())
	}

	if err := us.oAuthStore.RemoveAccessData(token); err != nil {
		return errors.Wrap(DeleteTokenError, err.Error())
	}

	// Wait for the background session removal started above.
	if err := <-schan; err != nil {
		return errors.Wrap(DeleteSessionError, err.Error())
	}

	if session != nil {
		us.ClearUserSessionCache(session.UserId)
	}

	return nil
}
   201  
   202  // SetSessionExpireInDays sets the session's expiry the specified number of days
   203  // relative to either the session creation date or the current time, depending
   204  // on the `ExtendSessionOnActivity` config setting.
   205  func (us *UserService) SetSessionExpireInDays(session *model.Session, days int) {
   206  	if session.CreateAt == 0 || *us.config().ServiceSettings.ExtendSessionLengthWithActivity {
   207  		session.ExpiresAt = model.GetMillis() + (1000 * 60 * 60 * 24 * int64(days))
   208  	} else {
   209  		session.ExpiresAt = session.CreateAt + (1000 * 60 * 60 * 24 * int64(days))
   210  	}
   211  }
   212  
   213  func (us *UserService) UpdateSessionsIsGuest(userID string, isGuest bool) error {
   214  	sessions, err := us.GetSessions(userID)
   215  	if err != nil {
   216  		return err
   217  	}
   218  
   219  	for _, session := range sessions {
   220  		session.AddProp(model.SESSION_PROP_IS_GUEST, fmt.Sprintf("%t", isGuest))
   221  		err := us.sessionStore.UpdateProps(session)
   222  		if err != nil {
   223  			mlog.Warn("Unable to update isGuest session", mlog.Err(err))
   224  			continue
   225  		}
   226  		us.AddSessionToCache(session)
   227  	}
   228  	return nil
   229  }
   230  
   231  func (us *UserService) RevokeAllSessions(userID string) error {
   232  	sessions, err := us.sessionStore.GetSessions(userID)
   233  	if err != nil {
   234  		return errors.Wrap(GetSessionError, err.Error())
   235  	}
   236  	for _, session := range sessions {
   237  		if session.IsOAuth {
   238  			us.RevokeAccessToken(session.Token)
   239  		} else {
   240  			if err := us.sessionStore.Remove(session.Id); err != nil {
   241  				return errors.Wrap(DeleteSessionError, err.Error())
   242  			}
   243  		}
   244  	}
   245  
   246  	us.ClearUserSessionCache(userID)
   247  
   248  	return nil
   249  }