github.com/decred/politeia@v1.4.0/politeiad/backendv2/tstorebe/store/mysql/mysql.go (about)

     1  // Copyright (c) 2020-2022 The Decred developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package mysql
     6  
     7  import (
     8  	"context"
     9  	"database/sql"
    10  	"fmt"
    11  	"strings"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/DATA-DOG/go-sqlmock"
    16  	"github.com/decred/politeia/politeiad/backendv2/tstorebe/store"
    17  	"github.com/decred/politeia/util"
    18  	"github.com/pkg/errors"
    19  
    20  	_ "github.com/go-sql-driver/mysql"
    21  )
    22  
const (
	// Database options. connTimeout is the timeout applied by ctxWithTimeout
	// to database operations. connMaxLifetime, maxOpenConns and maxIdleConns
	// are applied to the connection pool in New.
	connTimeout     = 1 * time.Minute
	connMaxLifetime = 1 * time.Minute
	maxOpenConns    = 0 // 0 is unlimited
	maxIdleConns    = 100

	// Database table names. NOTE: the kv queries in this file (REPLACE,
	// DELETE, SELECT) hard-code the name "kv" rather than referencing
	// tableNameKeyValue; keep them in sync if this value ever changes.
	tableNameKeyValue = "kv"
	tableNameNonce    = "nonce"

	// maxPlaceholders is the maximum number of "?" parameter placeholders
	// that can be used in a single prepared statement. MySQL uses a uint16
	// for this, so the limit is the maximum value of a uint16.
	maxPlaceholders = 65535
)
    39  
// tableKeyValue defines the columns of the key-value table. It is used by New
// in a CREATE TABLE IF NOT EXISTS statement. The key column k is the primary
// key (up to 255 characters) and the value column v holds a binary blob.
const tableKeyValue = `
  k VARCHAR(255) NOT NULL PRIMARY KEY,
  v LONGBLOB NOT NULL
`

// tableNonce defines the table used to track the encryption nonce. It is used
// by New in a CREATE TABLE IF NOT EXISTS statement and consists of a single
// auto-incrementing BIGINT primary key column n.
const tableNonce = `
  n BIGINT PRIMARY KEY AUTO_INCREMENT
`

var (
	// Compile-time check that *mysqlCtx satisfies the store.BlobKV interface.
	_ store.BlobKV = (*mysqlCtx)(nil)
)
    54  
// mysqlCtx implements the store BlobKV interface using a mysql driver.
type mysqlCtx struct {
	// shutdown is non-zero once Close has been called and is only accessed
	// through sync/atomic. It is kept as the first field because 64-bit
	// atomic operations require 64-bit alignment on 32-bit platforms, and
	// the first word of an allocated struct is guaranteed to be aligned.
	shutdown uint64

	// db is the connection pool opened by New.
	db *sql.DB

	// key is the 32-byte encryption key. It is zeroed by Close.
	// NOTE(review): presumably populated by deriveEncryptionKey in New;
	// confirm against encrypt.go.
	key [32]byte

	// The following fields are only used during unit tests.
	testing bool
	mock    sqlmock.Sqlmock
}
    65  
    66  func ctxWithTimeout() (context.Context, func()) {
    67  	return context.WithTimeout(context.Background(), connTimeout)
    68  }
    69  
    70  func (s *mysqlCtx) isShutdown() bool {
    71  	return atomic.LoadUint64(&s.shutdown) != 0
    72  }
    73  
    74  // put saves the provided key-value pairs to the database using a transaction.
    75  // New entries are inserted. Existing entries are updated.
    76  func (s *mysqlCtx) put(blobs map[string][]byte, encrypt bool, ctx context.Context, tx *sql.Tx) error {
    77  	// Encrypt blobs
    78  	if encrypt {
    79  		encrypted := make(map[string][]byte, len(blobs))
    80  		for k, v := range blobs {
    81  			e, err := s.encrypt(ctx, tx, v)
    82  			if err != nil {
    83  				return err
    84  			}
    85  			encrypted[k] = e
    86  		}
    87  
    88  		// Sanity check
    89  		if len(encrypted) != len(blobs) {
    90  			return errors.Errorf("unexpected number of encrypted blobs")
    91  		}
    92  
    93  		blobs = encrypted
    94  	}
    95  
    96  	// Save blobs
    97  	for k, v := range blobs {
    98  		_, err := tx.ExecContext(ctx,
    99  			"REPLACE INTO kv (k, v) VALUES (?, ?);", k, v)
   100  		if err != nil {
   101  			return errors.WithStack(err)
   102  		}
   103  	}
   104  
   105  	return nil
   106  }
   107  
   108  // Put saves the provided key-value entries to the database. New entries are
   109  // inserted. Existing entries are updated.
   110  //
   111  // This operation is atomic.
   112  //
   113  // This function satisfies the store BlobKV interface.
   114  func (s *mysqlCtx) Put(blobs map[string][]byte, encrypt bool) error {
   115  	log.Tracef("Put: %v blobs", len(blobs))
   116  
   117  	if s.isShutdown() {
   118  		return store.ErrShutdown
   119  	}
   120  
   121  	ctx, cancel := ctxWithTimeout()
   122  	defer cancel()
   123  
   124  	// Start transaction
   125  	opts := &sql.TxOptions{
   126  		Isolation: sql.LevelDefault,
   127  	}
   128  	tx, err := s.db.BeginTx(ctx, opts)
   129  	if err != nil {
   130  		return err
   131  	}
   132  
   133  	// Save blobs
   134  	err = s.put(blobs, encrypt, ctx, tx)
   135  	if err != nil {
   136  		// Attempt to roll back the transaction
   137  		if err2 := tx.Rollback(); err2 != nil {
   138  			// We're in trouble!
   139  			e := fmt.Sprintf("put: %v, unable to rollback: %v", err, err2)
   140  			panic(e)
   141  		}
   142  		return err
   143  	}
   144  
   145  	// Commit transaction
   146  	err = tx.Commit()
   147  	if err != nil {
   148  		return err
   149  	}
   150  
   151  	log.Debugf("Saved blobs (%v) to store", len(blobs))
   152  
   153  	return nil
   154  }
   155  
   156  // Del deletes the key-value entries from the database for the provided keys.
   157  //
   158  // This operation is atomic.
   159  //
   160  // This function satisfies the store BlobKV interface.
   161  func (s *mysqlCtx) Del(keys []string) error {
   162  	log.Tracef("Del: %v", keys)
   163  
   164  	if s.isShutdown() {
   165  		return store.ErrShutdown
   166  	}
   167  
   168  	ctx, cancel := ctxWithTimeout()
   169  	defer cancel()
   170  
   171  	// Start transaction
   172  	opts := &sql.TxOptions{
   173  		Isolation: sql.LevelDefault,
   174  	}
   175  	tx, err := s.db.BeginTx(ctx, opts)
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	// Delete blobs
   181  	for _, v := range keys {
   182  		_, err = tx.ExecContext(ctx, "DELETE FROM kv WHERE k IN (?);", v)
   183  		if err != nil {
   184  			// Attempt to roll back the transaction
   185  			if err2 := tx.Rollback(); err2 != nil {
   186  				// We're in trouble!
   187  				e := fmt.Sprintf("del: %v, unable to rollback: %v", err, err2)
   188  				panic(e)
   189  			}
   190  			return err
   191  		}
   192  	}
   193  
   194  	// Commit transaction
   195  	err = tx.Commit()
   196  	if err != nil {
   197  		return err
   198  	}
   199  
   200  	log.Debugf("Deleted blobs (%v) from store", len(keys))
   201  
   202  	return nil
   203  }
   204  
   205  // Get retrieves the key-value entries from the database for the provided
   206  // keys.
   207  //
   208  // An entry will not exist in the returned map for any blobs that are not
   209  // found. It is the responsibility of the caller to ensure a blob was returned
   210  // for all provided keys.
   211  //
   212  // This function satisfies the store BlobKV interface.
   213  func (s *mysqlCtx) Get(keys []string) (map[string][]byte, error) {
   214  	log.Tracef("Get: %v", keys)
   215  
   216  	if s.isShutdown() {
   217  		return nil, store.ErrShutdown
   218  	}
   219  
   220  	// Build the select statements
   221  	statements := buildSelectStatements(keys, maxPlaceholders)
   222  
   223  	log.Debugf("Get %v blobs using %v prepared statements",
   224  		len(keys), len(statements))
   225  
   226  	// Execute the statements
   227  	reply := make(map[string][]byte, len(keys))
   228  	for i, e := range statements {
   229  		log.Debugf("Executing select statement %v/%v", i+1, len(statements))
   230  
   231  		ctx, cancel := ctxWithTimeout()
   232  		defer cancel()
   233  
   234  		rows, err := s.db.QueryContext(ctx, e.Query, e.Args...)
   235  		if err != nil {
   236  			return nil, errors.WithStack(err)
   237  		}
   238  		defer rows.Close()
   239  
   240  		// Unpack the reply
   241  		for rows.Next() {
   242  			var k string
   243  			var v []byte
   244  			err = rows.Scan(&k, &v)
   245  			if err != nil {
   246  				return nil, errors.WithStack(err)
   247  			}
   248  
   249  			// Decrypt the blob if required
   250  			if isEncrypted(v) {
   251  				log.Tracef("Encrypted blob: %v", k)
   252  				v, _, err = s.decrypt(v)
   253  				if err != nil {
   254  					return nil, err
   255  				}
   256  			}
   257  
   258  			// Save the blob
   259  			reply[k] = v
   260  		}
   261  		err = rows.Err()
   262  		if err != nil {
   263  			return nil, errors.WithStack(err)
   264  		}
   265  	}
   266  
   267  	return reply, nil
   268  }
   269  
   270  // Close closes the database connection.
   271  func (s *mysqlCtx) Close() {
   272  	log.Tracef("Close")
   273  
   274  	atomic.AddUint64(&s.shutdown, 1)
   275  
   276  	// Zero the encryption key
   277  	util.Zero(s.key[:])
   278  
   279  	// Close mysql connection
   280  	s.db.Close()
   281  }
   282  
// selectStatement contains the query string and arguments for a SELECT
// statement. Query is SQL text containing "?" placeholders, and Args holds
// the values bound to those placeholders, in order. buildSelectStatements
// builds Args with one entry per placeholder.
type selectStatement struct {
	Query string
	Args  []interface{}
}
   289  
   290  // buildSelectStatements builds the SELECT statements that can be executed
   291  // against the MySQL key-value store. The maximum number of records that will
   292  // be retrieved in any individual SELECT statement is determined by the size
   293  // argument. The keys are split up into multiple statements if they exceed this
   294  // limit.
   295  func buildSelectStatements(keys []string, size int) []selectStatement {
   296  	statements := make([]selectStatement, 0, (len(keys)/size)+1)
   297  	var startIdx int
   298  	for startIdx < len(keys) {
   299  		// Find the end index
   300  		endIdx := startIdx + size
   301  		if endIdx > len(keys) {
   302  			// We've reached the end of the slice
   303  			endIdx = len(keys)
   304  		}
   305  
   306  		// startIdx is included. endIdx is excluded.
   307  		statementKeys := keys[startIdx:endIdx]
   308  
   309  		// Build the query
   310  		q := buildSelectQuery(len(statementKeys))
   311  		log.Tracef("%v", q)
   312  
   313  		// Convert the keys to interfaces. The sql query
   314  		// methods require arguments be interfaces.
   315  		args := make([]interface{}, len(statementKeys))
   316  		for i, v := range statementKeys {
   317  			args[i] = v
   318  		}
   319  
   320  		// Save the statement
   321  		statements = append(statements, selectStatement{
   322  			Query: q,
   323  			Args:  args,
   324  		})
   325  
   326  		// Update the start index
   327  		startIdx = endIdx
   328  	}
   329  
   330  	return statements
   331  }
   332  
   333  // buildSelectQuery returns a query string for the MySQL key-value store.
   334  //
   335  // Example: "SELECT k, v FROM kv WHERE k IN (?,?);"
   336  func buildSelectQuery(placeholders int) string {
   337  	return fmt.Sprintf("SELECT k, v FROM kv WHERE k IN %v;",
   338  		buildPlaceholders(placeholders))
   339  }
   340  
   341  // buildPlaceholders builds and returns a parameter placeholder string with the
   342  // specified number of placeholders.
   343  //
   344  // Input: 1  Output: "(?)"
   345  // Input: 3  Output: "(?,?,?)"
   346  func buildPlaceholders(placeholders int) string {
   347  	var b strings.Builder
   348  
   349  	b.WriteString("(")
   350  	for i := 0; i < placeholders; i++ {
   351  		b.WriteString("?")
   352  		// Don't add a comma on the last one
   353  		if i < placeholders-1 {
   354  			b.WriteString(",")
   355  		}
   356  	}
   357  	b.WriteString(")")
   358  
   359  	return b.String()
   360  }
   361  
   362  // New connects to a mysql instance using the given connection params,
   363  // and returns pointer to the created mysql struct.
   364  func New(host, user, password, dbname string) (*mysqlCtx, error) {
   365  	// The password is required to derive the encryption key
   366  	if password == "" {
   367  		return nil, errors.Errorf("password not provided")
   368  	}
   369  
   370  	// Connect to database
   371  	log.Infof("MySQL host: %v:[password]@tcp(%v)/%v", user, host, dbname)
   372  
   373  	h := fmt.Sprintf("%v:%v@tcp(%v)/%v", user, password, host, dbname)
   374  	db, err := sql.Open("mysql", h)
   375  	if err != nil {
   376  		return nil, err
   377  	}
   378  
   379  	// Setup database options
   380  	db.SetConnMaxLifetime(connMaxLifetime)
   381  	db.SetMaxOpenConns(maxOpenConns)
   382  	db.SetMaxIdleConns(maxIdleConns)
   383  
   384  	// Verify database connection
   385  	err = db.Ping()
   386  	if err != nil {
   387  		return nil, err
   388  	}
   389  
   390  	// Setup key-value table
   391  	q := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %v (%v)`,
   392  		tableNameKeyValue, tableKeyValue)
   393  	_, err = db.Exec(q)
   394  	if err != nil {
   395  		return nil, errors.WithStack(err)
   396  	}
   397  
   398  	// Setup nonce table
   399  	q = fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %v (%v)`,
   400  		tableNameNonce, tableNonce)
   401  	_, err = db.Exec(q)
   402  	if err != nil {
   403  		return nil, errors.WithStack(err)
   404  	}
   405  
   406  	// Setup mysql context
   407  	s := &mysqlCtx{
   408  		db: db,
   409  	}
   410  
   411  	// Derive encryption key from password. Key is set in argon2idKey
   412  	err = s.deriveEncryptionKey(password)
   413  	if err != nil {
   414  		return nil, err
   415  	}
   416  
   417  	return s, nil
   418  }