github.com/kotovmak/go-admin@v1.1.1/modules/auth/session.go (about)

     1  // Copyright 2019 GoAdmin Core Team. All rights reserved.
     2  // Use of this source code is governed by a Apache-2.0 style
     3  // license that can be found in the LICENSE file.
     4  
     5  package auth
     6  
     7  import (
     8  	"encoding/json"
     9  	"net/http"
    10  	"strconv"
    11  	"time"
    12  
    13  	"github.com/kotovmak/go-admin/context"
    14  	"github.com/kotovmak/go-admin/modules/config"
    15  	"github.com/kotovmak/go-admin/modules/db"
    16  	"github.com/kotovmak/go-admin/modules/db/dialect"
    17  	"github.com/kotovmak/go-admin/modules/logger"
    18  	"github.com/kotovmak/go-admin/plugins/admin/modules"
    19  )
    20  
    21  const DefaultCookieKey = "go_admin_session"
    22  
    23  // NewDBDriver return the default PersistenceDriver.
    24  func newDBDriver(conn db.Connection) *DBDriver {
    25  	return &DBDriver{
    26  		conn:      conn,
    27  		tableName: "goadmin_session",
    28  	}
    29  }
    30  
    31  // PersistenceDriver is a driver of storing and getting the session info.
    32  type PersistenceDriver interface {
    33  	Load(string) (map[string]interface{}, error)
    34  	Update(sid string, values map[string]interface{}) error
    35  }
    36  
    37  // GetSessionByKey get the session value by key.
    38  func GetSessionByKey(sesKey, key string, conn db.Connection) (interface{}, error) {
    39  	m, err := newDBDriver(conn).Load(sesKey)
    40  	return m[key], err
    41  }
    42  
    43  // Session contains info of session.
    44  type Session struct {
    45  	Expires time.Duration
    46  	Cookie  string
    47  	Values  map[string]interface{}
    48  	Driver  PersistenceDriver
    49  	Sid     string
    50  	Context *context.Context
    51  }
    52  
    53  // Config wraps the Session info.
    54  type Config struct {
    55  	Expires time.Duration
    56  	Cookie  string
    57  }
    58  
    59  // UpdateConfig update the Expires and Cookie of Session.
    60  func (ses *Session) UpdateConfig(config Config) {
    61  	ses.Expires = config.Expires
    62  	ses.Cookie = config.Cookie
    63  }
    64  
    65  // Get get the session value.
    66  func (ses *Session) Get(key string) interface{} {
    67  	return ses.Values[key]
    68  }
    69  
    70  // Add add the session value of key.
    71  func (ses *Session) Add(key string, value interface{}) error {
    72  	ses.Values[key] = value
    73  	if err := ses.Driver.Update(ses.Sid, ses.Values); err != nil {
    74  		return err
    75  	}
    76  	cookie := http.Cookie{
    77  		Name:     ses.Cookie,
    78  		Value:    ses.Sid,
    79  		MaxAge:   config.GetSessionLifeTime(),
    80  		Expires:  time.Now().Add(ses.Expires),
    81  		HttpOnly: true,
    82  		Path:     "/",
    83  	}
    84  	if config.GetDomain() != "" {
    85  		cookie.Domain = config.GetDomain()
    86  	}
    87  	ses.Context.SetCookie(&cookie)
    88  	return nil
    89  }
    90  
    91  // Clear clear a Session.
    92  func (ses *Session) Clear() error {
    93  	ses.Values = map[string]interface{}{}
    94  	return ses.Driver.Update(ses.Sid, ses.Values)
    95  }
    96  
    97  // UseDriver set the driver of the Session.
    98  func (ses *Session) UseDriver(driver PersistenceDriver) {
    99  	ses.Driver = driver
   100  }
   101  
   102  // StartCtx return a Session from the given Context.
   103  func (ses *Session) StartCtx(ctx *context.Context) (*Session, error) {
   104  	if cookie, err := ctx.Request.Cookie(ses.Cookie); err == nil && cookie.Value != "" {
   105  		ses.Sid = cookie.Value
   106  		valueFromDriver, err := ses.Driver.Load(cookie.Value)
   107  		if err != nil {
   108  			return nil, err
   109  		}
   110  		if len(valueFromDriver) > 0 {
   111  			ses.Values = valueFromDriver
   112  		}
   113  	} else {
   114  		ses.Sid = modules.Uuid()
   115  	}
   116  	ses.Context = ctx
   117  	return ses, nil
   118  }
   119  
   120  // InitSession return the default Session.
   121  func InitSession(ctx *context.Context, conn db.Connection) (*Session, error) {
   122  
   123  	sessions := new(Session)
   124  	sessions.UpdateConfig(Config{
   125  		Expires: time.Second * time.Duration(config.GetSessionLifeTime()),
   126  		Cookie:  DefaultCookieKey,
   127  	})
   128  
   129  	sessions.UseDriver(newDBDriver(conn))
   130  	sessions.Values = make(map[string]interface{})
   131  
   132  	return sessions.StartCtx(ctx)
   133  }
   134  
   135  // DBDriver is a driver which uses database as a persistence tool.
   136  type DBDriver struct {
   137  	conn      db.Connection
   138  	tableName string
   139  }
   140  
   141  // Load implements the PersistenceDriver.Load.
   142  func (driver *DBDriver) Load(sid string) (map[string]interface{}, error) {
   143  	sesModel, err := driver.table().Where("sid", "=", sid).First()
   144  
   145  	if db.CheckError(err, db.QUERY) {
   146  		return nil, err
   147  	}
   148  
   149  	if sesModel == nil {
   150  		return map[string]interface{}{}, nil
   151  	}
   152  
   153  	var values map[string]interface{}
   154  	err = json.Unmarshal([]byte(sesModel["values"].(string)), &values)
   155  	return values, err
   156  }
   157  
   158  func (driver *DBDriver) deleteOverdueSession() {
   159  
   160  	defer func() {
   161  		if err := recover(); err != nil {
   162  			logger.Error(err)
   163  			panic(err)
   164  		}
   165  	}()
   166  
   167  	var (
   168  		duration   = strconv.Itoa(config.GetSessionLifeTime() + 1000)
   169  		driverName = config.GetDatabases().GetDefault().Driver
   170  		raw        = ``
   171  	)
   172  
   173  	if db.DriverPostgresql == driverName {
   174  		raw = `extract(epoch from now()) - ` + duration + ` > extract(epoch from created_at)`
   175  	} else if db.DriverMysql == driverName {
   176  		raw = `unix_timestamp(created_at) < unix_timestamp() - ` + duration
   177  	} else if db.DriverSqlite == driverName {
   178  		raw = `strftime('%s', created_at) < strftime('%s', 'now') - ` + duration
   179  	} else if db.DriverMssql == driverName {
   180  		raw = `DATEDIFF(second, [created_at], GETDATE()) > ` + duration
   181  	} else if db.DriverOceanBase == driverName {
   182  		raw = `unix_timestamp(created_at) < unix_timestamp() - ` + duration
   183  	}
   184  
   185  	if raw != "" {
   186  		_ = driver.table().WhereRaw(raw).Delete()
   187  	}
   188  }
   189  
   190  // Update implements the PersistenceDriver.Update.
   191  func (driver *DBDriver) Update(sid string, values map[string]interface{}) error {
   192  
   193  	go driver.deleteOverdueSession()
   194  
   195  	if sid != "" {
   196  		if len(values) == 0 {
   197  			err := driver.table().Where("sid", "=", sid).Delete()
   198  			if db.CheckError(err, db.DELETE) {
   199  				return err
   200  			}
   201  		}
   202  		valuesByte, err := json.Marshal(values)
   203  		if err != nil {
   204  			return err
   205  		}
   206  		sesValue := string(valuesByte)
   207  		sesModel, _ := driver.table().Where("sid", "=", sid).First()
   208  		if sesModel == nil {
   209  			if !config.GetNoLimitLoginIP() {
   210  				err = driver.table().Where("values", "=", sesValue).Delete()
   211  				if db.CheckError(err, db.DELETE) {
   212  					return err
   213  				}
   214  			}
   215  			_, err := driver.table().Insert(dialect.H{
   216  				"values": sesValue,
   217  				"sid":    sid,
   218  			})
   219  			if db.CheckError(err, db.INSERT) {
   220  				return err
   221  			}
   222  		} else {
   223  			_, err := driver.table().
   224  				Where("sid", "=", sid).
   225  				Update(dialect.H{
   226  					"values": sesValue,
   227  				})
   228  			if db.CheckError(err, db.UPDATE) {
   229  				return err
   230  			}
   231  		}
   232  	}
   233  	return nil
   234  }
   235  
   236  func (driver *DBDriver) table() *db.SQL {
   237  	return db.Table(driver.tableName).WithDriver(driver.conn)
   238  }