github.com/decred/politeia@v1.4.0/politeiawww/sessions/mysql/mysql.go (about)

     1  // Copyright (c) 2021 The Decred developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package mysql
     6  
     7  import (
     8  	"context"
     9  	"database/sql"
    10  	"encoding/json"
    11  	"fmt"
    12  	"time"
    13  
    14  	"github.com/decred/politeia/politeiawww/sessions"
    15  	"github.com/pkg/errors"
    16  )
    17  
// sessionsTable is the column definition of the table that stores the
// encoded session values.
//
// The id column holds up to 128 characters so that it can accommodate a
// 64 byte key that has been base64, base32, or hex encoded (hex, the longest
// of the three encodings, needs exactly 128 characters).
//
// The encoded_session column is a BLOB, which MySQL limits to 2^16 - 1 bytes,
// or roughly 64KB.
//
// The created_at column contains a Unix timestamp and is used to manually
// clean up expired sessions. The gorilla/sessions Store does not do this
// automatically.
const sessionsTable = `
  id              CHAR(128) PRIMARY KEY,
  encoded_session BLOB NOT NULL,
  created_at      BIGINT NOT NULL
`
    34  
    35  var (
    36  	_ sessions.DB = (*mysql)(nil)
    37  )
    38  
    39  // mysql implements the sessions.DB interface.
    40  type mysql struct {
    41  	// db is the mysql DB context.
    42  	db *sql.DB
    43  
    44  	// sessionMaxAge is the max age of a session in seconds. This is used to
    45  	// periodically clean up expired sessions from the database. The
    46  	// gorilla/sessions Store implemenation does not do this automatically. It
    47  	// must be done manually in the database layer.
    48  	sessionMaxAge int64
    49  
    50  	// opts contains the session database options.
    51  	opts *Opts
    52  }
    53  
    54  // Opts contains configurable options for the sessions database. These are
    55  // not required. Sane defaults are used when the options are not provided.
    56  type Opts struct {
    57  	// TableName is the table name for the sessions table.
    58  	TableName string
    59  
    60  	// OpTimeout is the timeout for a single database operation.
    61  	OpTimeout time.Duration
    62  }
    63  
    64  const (
    65  	// defaultTableName is the default table name for the sessions table.
    66  	defaultTableName = "sessions"
    67  
    68  	// defaultOpTimeout is the default timeout for a single database operation.
    69  	defaultOpTimeout = 1 * time.Minute
    70  )
    71  
    72  // New returns a new mysql context that implements the sessions DB interface.
    73  // The opts param can be used to override the default mysql context settings.
    74  //
    75  // The sessionMaxAge is the max age in seconds of a session. This function
    76  // cleans up any expired sessions from the database as part of the
    77  // initialization. A sessionMaxAge of <=0 will cause the sessions database
    78  // to be dropped and recreated.
    79  func New(db *sql.DB, sessionMaxAge int64, opts *Opts) (*mysql, error) {
    80  	// Setup the database options
    81  	if opts == nil {
    82  		opts = &Opts{}
    83  	}
    84  	if opts.TableName == "" {
    85  		opts.TableName = defaultTableName
    86  	}
    87  	if opts.OpTimeout == 0 {
    88  		opts.OpTimeout = defaultOpTimeout
    89  	}
    90  
    91  	// Setup the mysql context
    92  	m := mysql{
    93  		db:            db,
    94  		sessionMaxAge: sessionMaxAge,
    95  		opts:          opts,
    96  	}
    97  
    98  	// Perform database setup
    99  	if sessionMaxAge <= 0 {
   100  		err := m.dropTable()
   101  		if err != nil {
   102  			return nil, err
   103  		}
   104  	}
   105  	err := m.createTable()
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	err = m.cleanup()
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	return &m, nil
   115  }
   116  
   117  // Save saves a session to the database.
   118  //
   119  // Save satisfies the sessions.DB interface.
   120  func (m *mysql) Save(sessionID string, s sessions.EncodedSession) error {
   121  	log.Tracef("Save %v", sessionID)
   122  
   123  	es, err := json.Marshal(s)
   124  	if err != nil {
   125  		return err
   126  	}
   127  
   128  	ctx, cancel := m.ctxForOp()
   129  	defer cancel()
   130  
   131  	q := `INSERT INTO %v
   132      (id, encoded_session, created_at) VALUES (?, ?, ?)
   133      ON DUPLICATE KEY UPDATE
   134      encoded_session = VALUES(encoded_session)`
   135  
   136  	q = fmt.Sprintf(q, m.opts.TableName)
   137  	_, err = m.db.ExecContext(ctx, q, sessionID, es, time.Now().Unix())
   138  	if err != nil {
   139  		return errors.WithStack(err)
   140  	}
   141  
   142  	return nil
   143  }
   144  
   145  // Del deletes a session from the database. An error is not returned if the
   146  // session does not exist.
   147  //
   148  // Del satisfies the sessions.DB interface.
   149  func (m *mysql) Del(sessionID string) error {
   150  	log.Tracef("Del %v", sessionID)
   151  
   152  	ctx, cancel := m.ctxForOp()
   153  	defer cancel()
   154  
   155  	q := fmt.Sprintf("DELETE FROM %v WHERE id = ?", m.opts.TableName)
   156  	_, err := m.db.ExecContext(ctx, q, sessionID)
   157  	if err != nil {
   158  		return errors.WithStack(err)
   159  	}
   160  
   161  	return nil
   162  }
   163  
   164  // Get gets a session from the database. An ErrNotFound error is returned if
   165  // a session is not found for the session ID.
   166  //
   167  // Get statisfies the sessions.DB interface.
   168  func (m *mysql) Get(sessionID string) (*sessions.EncodedSession, error) {
   169  	log.Tracef("Get %v", sessionID)
   170  
   171  	ctx, cancel := m.ctxForOp()
   172  	defer cancel()
   173  
   174  	q := fmt.Sprintf("SELECT encoded_session FROM %v WHERE id = ?",
   175  		m.opts.TableName)
   176  
   177  	var encodedBlob []byte
   178  	err := m.db.QueryRowContext(ctx, q, sessionID).Scan(&encodedBlob)
   179  	switch {
   180  	case err == sql.ErrNoRows:
   181  		return nil, sessions.ErrNotFound
   182  	case err != nil:
   183  		return nil, errors.WithStack(err)
   184  	}
   185  
   186  	var es sessions.EncodedSession
   187  	err = json.Unmarshal(encodedBlob, &es)
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  
   192  	return &es, nil
   193  }
   194  
   195  // createTable creates the sessions table.
   196  func (m *mysql) createTable() error {
   197  	ctx, cancel := m.ctxForOp()
   198  	defer cancel()
   199  
   200  	q := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %v (%v)",
   201  		m.opts.TableName, sessionsTable)
   202  	_, err := m.db.ExecContext(ctx, q)
   203  	if err != nil {
   204  		return errors.WithStack(err)
   205  	}
   206  
   207  	log.Debugf("Created %v database table", m.opts.TableName)
   208  
   209  	return nil
   210  }
   211  
   212  // dropTable drops the sessions table.
   213  func (m *mysql) dropTable() error {
   214  	ctx, cancel := m.ctxForOp()
   215  	defer cancel()
   216  
   217  	q := fmt.Sprintf("DROP TABLE IF EXISTS %v", m.opts.TableName)
   218  	_, err := m.db.ExecContext(ctx, q)
   219  	if err != nil {
   220  		return errors.WithStack(err)
   221  	}
   222  
   223  	log.Debugf("Dropped %v database table", m.opts.TableName)
   224  
   225  	return nil
   226  }
   227  
   228  // cleanup performs database cleanup by deleting all sessions that have
   229  // expired.
   230  func (m *mysql) cleanup() error {
   231  	ctx, cancel := m.ctxForOp()
   232  	defer cancel()
   233  
   234  	q := "DELETE FROM %v WHERE created_at + ? <= ?"
   235  	q = fmt.Sprintf(q, m.opts.TableName)
   236  	r, err := m.db.ExecContext(ctx, q, m.sessionMaxAge, time.Now().Unix())
   237  	if err != nil {
   238  		return errors.WithStack(err)
   239  	}
   240  	rowsAffected, err := r.RowsAffected()
   241  	if err != nil {
   242  		return err
   243  	}
   244  
   245  	log.Debugf("Deleted %v expired sessions from the database", rowsAffected)
   246  
   247  	return nil
   248  }
   249  
   250  // ctxForOp returns a context and cancel function for a single database
   251  // operation.
   252  func (m *mysql) ctxForOp() (context.Context, func()) {
   253  	return context.WithTimeout(context.Background(), m.opts.OpTimeout)
   254  }