github.com/gogf/gf/v2@v2.7.4/os/gsession/gsession_storage_file.go (about)

     1  // Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the MIT License.
     4  // If a copy of the MIT was not distributed with this file,
     5  // You can obtain one at https://github.com/gogf/gf.
     6  
     7  package gsession
     8  
     9  import (
    10  	"context"
    11  	"fmt"
    12  	"os"
    13  	"time"
    14  
    15  	"github.com/gogf/gf/v2/container/gmap"
    16  	"github.com/gogf/gf/v2/container/gset"
    17  	"github.com/gogf/gf/v2/crypto/gaes"
    18  	"github.com/gogf/gf/v2/encoding/gbinary"
    19  	"github.com/gogf/gf/v2/errors/gcode"
    20  	"github.com/gogf/gf/v2/errors/gerror"
    21  	"github.com/gogf/gf/v2/internal/intlog"
    22  	"github.com/gogf/gf/v2/internal/json"
    23  	"github.com/gogf/gf/v2/os/gfile"
    24  	"github.com/gogf/gf/v2/os/gtime"
    25  	"github.com/gogf/gf/v2/os/gtimer"
    26  )
    27  
    28  // StorageFile implements the Session Storage interface with file system.
    29  type StorageFile struct {
    30  	StorageBase
    31  	path          string        // Session file storage folder path.
    32  	ttl           time.Duration // Session TTL.
    33  	cryptoKey     []byte        // Used when enable crypto feature.
    34  	cryptoEnabled bool          // Used when enable crypto feature.
    35  	updatingIdSet *gset.StrSet  // To be batched updated session id set.
    36  }
    37  
    38  const (
    39  	DefaultStorageFileCryptoEnabled        = false
    40  	DefaultStorageFileUpdateTTLInterval    = 10 * time.Second
    41  	DefaultStorageFileClearExpiredInterval = time.Hour
    42  )
    43  
    44  var (
    45  	DefaultStorageFilePath      = gfile.Temp("gsessions")
    46  	DefaultStorageFileCryptoKey = []byte("Session storage file crypto key!")
    47  )
    48  
    49  // NewStorageFile creates and returns a file storage object for session.
    50  func NewStorageFile(path string, ttl time.Duration) *StorageFile {
    51  	var (
    52  		ctx         = context.TODO()
    53  		storagePath = DefaultStorageFilePath
    54  	)
    55  	if path != "" {
    56  		storagePath, _ = gfile.Search(path)
    57  		if storagePath == "" {
    58  			panic(gerror.NewCodef(gcode.CodeInvalidParameter, `"%s" does not exist`, path))
    59  		}
    60  		if !gfile.IsWritable(storagePath) {
    61  			panic(gerror.NewCodef(gcode.CodeInvalidParameter, `"%s" is not writable`, path))
    62  		}
    63  	}
    64  	if storagePath != "" {
    65  		if err := gfile.Mkdir(storagePath); err != nil {
    66  			panic(gerror.Wrapf(err, `Mkdir "%s" failed in PWD "%s"`, path, gfile.Pwd()))
    67  		}
    68  	}
    69  	s := &StorageFile{
    70  		path:          storagePath,
    71  		ttl:           ttl,
    72  		cryptoKey:     DefaultStorageFileCryptoKey,
    73  		cryptoEnabled: DefaultStorageFileCryptoEnabled,
    74  		updatingIdSet: gset.NewStrSet(true),
    75  	}
    76  
    77  	gtimer.AddSingleton(ctx, DefaultStorageFileUpdateTTLInterval, s.timelyUpdateSessionTTL)
    78  	gtimer.AddSingleton(ctx, DefaultStorageFileClearExpiredInterval, s.timelyClearExpiredSessionFile)
    79  	return s
    80  }
    81  
    82  // timelyUpdateSessionTTL batch updates the TTL for sessions timely.
    83  func (s *StorageFile) timelyUpdateSessionTTL(ctx context.Context) {
    84  	var (
    85  		sessionId string
    86  		err       error
    87  	)
    88  	// Batch updating sessions.
    89  	for {
    90  		if sessionId = s.updatingIdSet.Pop(); sessionId == "" {
    91  			break
    92  		}
    93  		if err = s.updateSessionTTl(context.TODO(), sessionId); err != nil {
    94  			intlog.Errorf(context.TODO(), `%+v`, err)
    95  		}
    96  	}
    97  }
    98  
    99  // timelyClearExpiredSessionFile deletes all expired files timely.
   100  func (s *StorageFile) timelyClearExpiredSessionFile(ctx context.Context) {
   101  	files, err := gfile.ScanDirFile(s.path, "*.session", false)
   102  	if err != nil {
   103  		intlog.Errorf(ctx, `%+v`, err)
   104  		return
   105  	}
   106  	for _, file := range files {
   107  		if err = s.checkAndClearSessionFile(ctx, file); err != nil {
   108  			intlog.Errorf(ctx, `%+v`, err)
   109  		}
   110  	}
   111  }
   112  
   113  // SetCryptoKey sets the crypto key for session storage.
   114  // The crypto key is used when crypto feature is enabled.
   115  func (s *StorageFile) SetCryptoKey(key []byte) {
   116  	s.cryptoKey = key
   117  }
   118  
   119  // SetCryptoEnabled enables/disables the crypto feature for session storage.
   120  func (s *StorageFile) SetCryptoEnabled(enabled bool) {
   121  	s.cryptoEnabled = enabled
   122  }
   123  
   124  // sessionFilePath returns the storage file path for given session id.
   125  func (s *StorageFile) sessionFilePath(sessionId string) string {
   126  	return gfile.Join(s.path, sessionId) + ".session"
   127  }
   128  
   129  // RemoveAll deletes all key-value pairs from storage.
   130  func (s *StorageFile) RemoveAll(ctx context.Context, sessionId string) error {
   131  	return gfile.Remove(s.sessionFilePath(sessionId))
   132  }
   133  
   134  // GetSession returns the session data as *gmap.StrAnyMap for given session id from storage.
   135  //
   136  // The parameter `ttl` specifies the TTL for this session, and it returns nil if the TTL is exceeded.
   137  // The parameter `data` is the current old session data stored in memory,
   138  // and for some storage it might be nil if memory storage is disabled.
   139  //
   140  // This function is called ever when session starts.
   141  func (s *StorageFile) GetSession(ctx context.Context, sessionId string, ttl time.Duration) (sessionData *gmap.StrAnyMap, err error) {
   142  	var (
   143  		path    = s.sessionFilePath(sessionId)
   144  		content = gfile.GetBytes(path)
   145  	)
   146  	// It updates the TTL only if the session file already exists.
   147  	if len(content) > 8 {
   148  		timestampMilli := gbinary.DecodeToInt64(content[:8])
   149  		if timestampMilli+ttl.Nanoseconds()/1e6 < gtime.TimestampMilli() {
   150  			return nil, nil
   151  		}
   152  		content = content[8:]
   153  		// Decrypt with AES.
   154  		if s.cryptoEnabled {
   155  			content, err = gaes.Decrypt(content, DefaultStorageFileCryptoKey)
   156  			if err != nil {
   157  				return nil, err
   158  			}
   159  		}
   160  		var m map[string]interface{}
   161  		if err = json.UnmarshalUseNumber(content, &m); err != nil {
   162  			return nil, err
   163  		}
   164  		if m == nil {
   165  			return nil, nil
   166  		}
   167  		return gmap.NewStrAnyMapFrom(m, true), nil
   168  	}
   169  	return nil, nil
   170  }
   171  
   172  // SetSession updates the data map for specified session id.
   173  // This function is called ever after session, which is changed dirty, is closed.
   174  // This copy all session data map from memory to storage.
   175  func (s *StorageFile) SetSession(ctx context.Context, sessionId string, sessionData *gmap.StrAnyMap, ttl time.Duration) error {
   176  	intlog.Printf(ctx, "StorageFile.SetSession: %s, %v, %v", sessionId, sessionData, ttl)
   177  	path := s.sessionFilePath(sessionId)
   178  	content, err := json.Marshal(sessionData)
   179  	if err != nil {
   180  		return err
   181  	}
   182  	// Encrypt with AES.
   183  	if s.cryptoEnabled {
   184  		content, err = gaes.Encrypt(content, DefaultStorageFileCryptoKey)
   185  		if err != nil {
   186  			return err
   187  		}
   188  	}
   189  	file, err := gfile.OpenWithFlagPerm(
   190  		path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm,
   191  	)
   192  	if err != nil {
   193  		return err
   194  	}
   195  	defer file.Close()
   196  	if _, err = file.Write(gbinary.EncodeInt64(gtime.TimestampMilli())); err != nil {
   197  		err = gerror.Wrapf(err, `write data failed to file "%s"`, path)
   198  		return err
   199  	}
   200  	if _, err = file.Write(content); err != nil {
   201  		err = gerror.Wrapf(err, `write data failed to file "%s"`, path)
   202  		return err
   203  	}
   204  	return nil
   205  }
   206  
   207  // UpdateTTL updates the TTL for specified session id.
   208  // This function is called ever after session, which is not dirty, is closed.
   209  // It just adds the session id to the async handling queue.
   210  func (s *StorageFile) UpdateTTL(ctx context.Context, sessionId string, ttl time.Duration) error {
   211  	intlog.Printf(ctx, "StorageFile.UpdateTTL: %s, %v", sessionId, ttl)
   212  	if ttl >= DefaultStorageFileUpdateTTLInterval {
   213  		s.updatingIdSet.Add(sessionId)
   214  	}
   215  	return nil
   216  }
   217  
   218  // updateSessionTTL updates the TTL for specified session id.
   219  func (s *StorageFile) updateSessionTTl(ctx context.Context, sessionId string) error {
   220  	intlog.Printf(ctx, "StorageFile.updateSession: %s", sessionId)
   221  	path := s.sessionFilePath(sessionId)
   222  	file, err := gfile.OpenWithFlag(path, os.O_WRONLY)
   223  	if err != nil {
   224  		return err
   225  	}
   226  	if _, err = file.WriteAt(gbinary.EncodeInt64(gtime.TimestampMilli()), 0); err != nil {
   227  		err = gerror.Wrapf(err, `write data failed to file "%s"`, path)
   228  		return err
   229  	}
   230  	return file.Close()
   231  }
   232  
   233  func (s *StorageFile) checkAndClearSessionFile(ctx context.Context, path string) (err error) {
   234  	var (
   235  		file                *os.File
   236  		readBytesCount      int
   237  		timestampMilliBytes = make([]byte, 8)
   238  	)
   239  	file, err = gfile.OpenWithFlag(path, os.O_RDONLY)
   240  	if err != nil {
   241  		return err
   242  	}
   243  	defer file.Close()
   244  	// Read the session file updated timestamp in milliseconds.
   245  	readBytesCount, err = file.Read(timestampMilliBytes)
   246  	if err != nil {
   247  		return
   248  	}
   249  	if readBytesCount != 8 {
   250  		return gerror.Newf(`invalid read bytes count "%d", expect "8"`, readBytesCount)
   251  	}
   252  	// Remove expired session file.
   253  	var (
   254  		ttlInMilliseconds     = s.ttl.Nanoseconds() / 1e6
   255  		fileTimestampMilli    = gbinary.DecodeToInt64(timestampMilliBytes)
   256  		currentTimestampMilli = gtime.TimestampMilli()
   257  	)
   258  	if fileTimestampMilli+ttlInMilliseconds < currentTimestampMilli {
   259  		intlog.PrintFunc(ctx, func() string {
   260  			return fmt.Sprintf(
   261  				`clear expired session file "%s": updated datetime "%s", ttl "%s"`,
   262  				path, gtime.NewFromTimeStamp(fileTimestampMilli), s.ttl,
   263  			)
   264  		})
   265  		return gfile.Remove(path)
   266  	}
   267  	return nil
   268  }