github.com/kotovmak/go-admin@v1.1.1/modules/auth/session.go (about) 1 // Copyright 2019 GoAdmin Core Team. All rights reserved. 2 // Use of this source code is governed by a Apache-2.0 style 3 // license that can be found in the LICENSE file. 4 5 package auth 6 7 import ( 8 "encoding/json" 9 "net/http" 10 "strconv" 11 "time" 12 13 "github.com/kotovmak/go-admin/context" 14 "github.com/kotovmak/go-admin/modules/config" 15 "github.com/kotovmak/go-admin/modules/db" 16 "github.com/kotovmak/go-admin/modules/db/dialect" 17 "github.com/kotovmak/go-admin/modules/logger" 18 "github.com/kotovmak/go-admin/plugins/admin/modules" 19 ) 20 21 const DefaultCookieKey = "go_admin_session" 22 23 // NewDBDriver return the default PersistenceDriver. 24 func newDBDriver(conn db.Connection) *DBDriver { 25 return &DBDriver{ 26 conn: conn, 27 tableName: "goadmin_session", 28 } 29 } 30 31 // PersistenceDriver is a driver of storing and getting the session info. 32 type PersistenceDriver interface { 33 Load(string) (map[string]interface{}, error) 34 Update(sid string, values map[string]interface{}) error 35 } 36 37 // GetSessionByKey get the session value by key. 38 func GetSessionByKey(sesKey, key string, conn db.Connection) (interface{}, error) { 39 m, err := newDBDriver(conn).Load(sesKey) 40 return m[key], err 41 } 42 43 // Session contains info of session. 44 type Session struct { 45 Expires time.Duration 46 Cookie string 47 Values map[string]interface{} 48 Driver PersistenceDriver 49 Sid string 50 Context *context.Context 51 } 52 53 // Config wraps the Session info. 54 type Config struct { 55 Expires time.Duration 56 Cookie string 57 } 58 59 // UpdateConfig update the Expires and Cookie of Session. 60 func (ses *Session) UpdateConfig(config Config) { 61 ses.Expires = config.Expires 62 ses.Cookie = config.Cookie 63 } 64 65 // Get get the session value. 66 func (ses *Session) Get(key string) interface{} { 67 return ses.Values[key] 68 } 69 70 // Add add the session value of key. 71 func (ses *Session) Add(key string, value interface{}) error { 72 ses.Values[key] = value 73 if err := ses.Driver.Update(ses.Sid, ses.Values); err != nil { 74 return err 75 } 76 cookie := http.Cookie{ 77 Name: ses.Cookie, 78 Value: ses.Sid, 79 MaxAge: config.GetSessionLifeTime(), 80 Expires: time.Now().Add(ses.Expires), 81 HttpOnly: true, 82 Path: "/", 83 } 84 if config.GetDomain() != "" { 85 cookie.Domain = config.GetDomain() 86 } 87 ses.Context.SetCookie(&cookie) 88 return nil 89 } 90 91 // Clear clear a Session. 92 func (ses *Session) Clear() error { 93 ses.Values = map[string]interface{}{} 94 return ses.Driver.Update(ses.Sid, ses.Values) 95 } 96 97 // UseDriver set the driver of the Session. 98 func (ses *Session) UseDriver(driver PersistenceDriver) { 99 ses.Driver = driver 100 } 101 102 // StartCtx return a Session from the given Context. 103 func (ses *Session) StartCtx(ctx *context.Context) (*Session, error) { 104 if cookie, err := ctx.Request.Cookie(ses.Cookie); err == nil && cookie.Value != "" { 105 ses.Sid = cookie.Value 106 valueFromDriver, err := ses.Driver.Load(cookie.Value) 107 if err != nil { 108 return nil, err 109 } 110 if len(valueFromDriver) > 0 { 111 ses.Values = valueFromDriver 112 } 113 } else { 114 ses.Sid = modules.Uuid() 115 } 116 ses.Context = ctx 117 return ses, nil 118 } 119 120 // InitSession return the default Session. 121 func InitSession(ctx *context.Context, conn db.Connection) (*Session, error) { 122 123 sessions := new(Session) 124 sessions.UpdateConfig(Config{ 125 Expires: time.Second * time.Duration(config.GetSessionLifeTime()), 126 Cookie: DefaultCookieKey, 127 }) 128 129 sessions.UseDriver(newDBDriver(conn)) 130 sessions.Values = make(map[string]interface{}) 131 132 return sessions.StartCtx(ctx) 133 } 134 135 // DBDriver is a driver which uses database as a persistence tool. 136 type DBDriver struct { 137 conn db.Connection 138 tableName string 139 } 140 141 // Load implements the PersistenceDriver.Load. 142 func (driver *DBDriver) Load(sid string) (map[string]interface{}, error) { 143 sesModel, err := driver.table().Where("sid", "=", sid).First() 144 145 if db.CheckError(err, db.QUERY) { 146 return nil, err 147 } 148 149 if sesModel == nil { 150 return map[string]interface{}{}, nil 151 } 152 153 var values map[string]interface{} 154 err = json.Unmarshal([]byte(sesModel["values"].(string)), &values) 155 return values, err 156 } 157 158 func (driver *DBDriver) deleteOverdueSession() { 159 160 defer func() { 161 if err := recover(); err != nil { 162 logger.Error(err) 163 panic(err) 164 } 165 }() 166 167 var ( 168 duration = strconv.Itoa(config.GetSessionLifeTime() + 1000) 169 driverName = config.GetDatabases().GetDefault().Driver 170 raw = `` 171 ) 172 173 if db.DriverPostgresql == driverName { 174 raw = `extract(epoch from now()) - ` + duration + ` > extract(epoch from created_at)` 175 } else if db.DriverMysql == driverName { 176 raw = `unix_timestamp(created_at) < unix_timestamp() - ` + duration 177 } else if db.DriverSqlite == driverName { 178 raw = `strftime('%s', created_at) < strftime('%s', 'now') - ` + duration 179 } else if db.DriverMssql == driverName { 180 raw = `DATEDIFF(second, [created_at], GETDATE()) > ` + duration 181 } else if db.DriverOceanBase == driverName { 182 raw = `unix_timestamp(created_at) < unix_timestamp() - ` + duration 183 } 184 185 if raw != "" { 186 _ = driver.table().WhereRaw(raw).Delete() 187 } 188 } 189 190 // Update implements the PersistenceDriver.Update. 191 func (driver *DBDriver) Update(sid string, values map[string]interface{}) error { 192 193 go driver.deleteOverdueSession() 194 195 if sid != "" { 196 if len(values) == 0 { 197 err := driver.table().Where("sid", "=", sid).Delete() 198 if db.CheckError(err, db.DELETE) { 199 return err 200 } 201 } 202 valuesByte, err := json.Marshal(values) 203 if err != nil { 204 return err 205 } 206 sesValue := string(valuesByte) 207 sesModel, _ := driver.table().Where("sid", "=", sid).First() 208 if sesModel == nil { 209 if !config.GetNoLimitLoginIP() { 210 err = driver.table().Where("values", "=", sesValue).Delete() 211 if db.CheckError(err, db.DELETE) { 212 return err 213 } 214 } 215 _, err := driver.table().Insert(dialect.H{ 216 "values": sesValue, 217 "sid": sid, 218 }) 219 if db.CheckError(err, db.INSERT) { 220 return err 221 } 222 } else { 223 _, err := driver.table(). 224 Where("sid", "=", sid). 225 Update(dialect.H{ 226 "values": sesValue, 227 }) 228 if db.CheckError(err, db.UPDATE) { 229 return err 230 } 231 } 232 } 233 return nil 234 } 235 236 func (driver *DBDriver) table() *db.SQL { 237 return db.Table(driver.tableName).WithDriver(driver.conn) 238 }