github.com/decred/politeia@v1.4.0/politeiawww/sessions/mysql/mysql.go (about) 1 // Copyright (c) 2021 The Decred developers 2 // Use of this source code is governed by an ISC 3 // license that can be found in the LICENSE file. 4 5 package mysql 6 7 import ( 8 "context" 9 "database/sql" 10 "encoding/json" 11 "fmt" 12 "time" 13 14 "github.com/decred/politeia/politeiawww/sessions" 15 "github.com/pkg/errors" 16 ) 17 18 // sessionsTable is the table for the encoded session values. 19 // 20 // The id column is 128 bytes so that it can accomidate a 64 byte base64, 21 // base32, or hex encoded key. 22 // 23 // The encoded_session column has a max length of 2^16 bytes, which is around 24 // 64KB. 25 // 26 // The created_at column contains a Unix timestamp and is used to manually 27 // clean up expired sessions. The gorilla/sessions Store does not do this 28 // automatically. 29 const sessionsTable = ` 30 id CHAR(128) PRIMARY KEY, 31 encoded_session BLOB NOT NULL, 32 created_at BIGINT NOT NULL 33 ` 34 35 var ( 36 _ sessions.DB = (*mysql)(nil) 37 ) 38 39 // mysql implements the sessions.DB interface. 40 type mysql struct { 41 // db is the mysql DB context. 42 db *sql.DB 43 44 // sessionMaxAge is the max age of a session in seconds. This is used to 45 // periodically clean up expired sessions from the database. The 46 // gorilla/sessions Store implemenation does not do this automatically. It 47 // must be done manually in the database layer. 48 sessionMaxAge int64 49 50 // opts contains the session database options. 51 opts *Opts 52 } 53 54 // Opts contains configurable options for the sessions database. These are 55 // not required. Sane defaults are used when the options are not provided. 56 type Opts struct { 57 // TableName is the table name for the sessions table. 58 TableName string 59 60 // OpTimeout is the timeout for a single database operation. 61 OpTimeout time.Duration 62 } 63 64 const ( 65 // defaultTableName is the default table name for the sessions table. 66 defaultTableName = "sessions" 67 68 // defaultOpTimeout is the default timeout for a single database operation. 69 defaultOpTimeout = 1 * time.Minute 70 ) 71 72 // New returns a new mysql context that implements the sessions DB interface. 73 // The opts param can be used to override the default mysql context settings. 74 // 75 // The sessionMaxAge is the max age in seconds of a session. This function 76 // cleans up any expired sessions from the database as part of the 77 // initialization. A sessionMaxAge of <=0 will cause the sessions database 78 // to be dropped and recreated. 79 func New(db *sql.DB, sessionMaxAge int64, opts *Opts) (*mysql, error) { 80 // Setup the database options 81 if opts == nil { 82 opts = &Opts{} 83 } 84 if opts.TableName == "" { 85 opts.TableName = defaultTableName 86 } 87 if opts.OpTimeout == 0 { 88 opts.OpTimeout = defaultOpTimeout 89 } 90 91 // Setup the mysql context 92 m := mysql{ 93 db: db, 94 sessionMaxAge: sessionMaxAge, 95 opts: opts, 96 } 97 98 // Perform database setup 99 if sessionMaxAge <= 0 { 100 err := m.dropTable() 101 if err != nil { 102 return nil, err 103 } 104 } 105 err := m.createTable() 106 if err != nil { 107 return nil, err 108 } 109 err = m.cleanup() 110 if err != nil { 111 return nil, err 112 } 113 114 return &m, nil 115 } 116 117 // Save saves a session to the database. 118 // 119 // Save satisfies the sessions.DB interface. 120 func (m *mysql) Save(sessionID string, s sessions.EncodedSession) error { 121 log.Tracef("Save %v", sessionID) 122 123 es, err := json.Marshal(s) 124 if err != nil { 125 return err 126 } 127 128 ctx, cancel := m.ctxForOp() 129 defer cancel() 130 131 q := `INSERT INTO %v 132 (id, encoded_session, created_at) VALUES (?, ?, ?) 133 ON DUPLICATE KEY UPDATE 134 encoded_session = VALUES(encoded_session)` 135 136 q = fmt.Sprintf(q, m.opts.TableName) 137 _, err = m.db.ExecContext(ctx, q, sessionID, es, time.Now().Unix()) 138 if err != nil { 139 return errors.WithStack(err) 140 } 141 142 return nil 143 } 144 145 // Del deletes a session from the database. An error is not returned if the 146 // session does not exist. 147 // 148 // Del satisfies the sessions.DB interface. 149 func (m *mysql) Del(sessionID string) error { 150 log.Tracef("Del %v", sessionID) 151 152 ctx, cancel := m.ctxForOp() 153 defer cancel() 154 155 q := fmt.Sprintf("DELETE FROM %v WHERE id = ?", m.opts.TableName) 156 _, err := m.db.ExecContext(ctx, q, sessionID) 157 if err != nil { 158 return errors.WithStack(err) 159 } 160 161 return nil 162 } 163 164 // Get gets a session from the database. An ErrNotFound error is returned if 165 // a session is not found for the session ID. 166 // 167 // Get statisfies the sessions.DB interface. 168 func (m *mysql) Get(sessionID string) (*sessions.EncodedSession, error) { 169 log.Tracef("Get %v", sessionID) 170 171 ctx, cancel := m.ctxForOp() 172 defer cancel() 173 174 q := fmt.Sprintf("SELECT encoded_session FROM %v WHERE id = ?", 175 m.opts.TableName) 176 177 var encodedBlob []byte 178 err := m.db.QueryRowContext(ctx, q, sessionID).Scan(&encodedBlob) 179 switch { 180 case err == sql.ErrNoRows: 181 return nil, sessions.ErrNotFound 182 case err != nil: 183 return nil, errors.WithStack(err) 184 } 185 186 var es sessions.EncodedSession 187 err = json.Unmarshal(encodedBlob, &es) 188 if err != nil { 189 return nil, err 190 } 191 192 return &es, nil 193 } 194 195 // createTable creates the sessions table. 196 func (m *mysql) createTable() error { 197 ctx, cancel := m.ctxForOp() 198 defer cancel() 199 200 q := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %v (%v)", 201 m.opts.TableName, sessionsTable) 202 _, err := m.db.ExecContext(ctx, q) 203 if err != nil { 204 return errors.WithStack(err) 205 } 206 207 log.Debugf("Created %v database table", m.opts.TableName) 208 209 return nil 210 } 211 212 // dropTable drops the sessions table. 213 func (m *mysql) dropTable() error { 214 ctx, cancel := m.ctxForOp() 215 defer cancel() 216 217 q := fmt.Sprintf("DROP TABLE IF EXISTS %v", m.opts.TableName) 218 _, err := m.db.ExecContext(ctx, q) 219 if err != nil { 220 return errors.WithStack(err) 221 } 222 223 log.Debugf("Dropped %v database table", m.opts.TableName) 224 225 return nil 226 } 227 228 // cleanup performs database cleanup by deleting all sessions that have 229 // expired. 230 func (m *mysql) cleanup() error { 231 ctx, cancel := m.ctxForOp() 232 defer cancel() 233 234 q := "DELETE FROM %v WHERE created_at + ? <= ?" 235 q = fmt.Sprintf(q, m.opts.TableName) 236 r, err := m.db.ExecContext(ctx, q, m.sessionMaxAge, time.Now().Unix()) 237 if err != nil { 238 return errors.WithStack(err) 239 } 240 rowsAffected, err := r.RowsAffected() 241 if err != nil { 242 return err 243 } 244 245 log.Debugf("Deleted %v expired sessions from the database", rowsAffected) 246 247 return nil 248 } 249 250 // ctxForOp returns a context and cancel function for a single database 251 // operation. 252 func (m *mysql) ctxForOp() (context.Context, func()) { 253 return context.WithTimeout(context.Background(), m.opts.OpTimeout) 254 }