github.com/decred/dcrlnd@v0.7.6/kvdb/postgres/db.go (about)

     1  //go:build kvdb_postgres
     2  // +build kvdb_postgres
     3  
     4  package postgres
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/btcsuite/btcwallet/walletdb"
    16  )
    17  
    18  const (
    19  	// kvTableName is the name of the table that will contain all the kv
    20  	// pairs.
    21  	kvTableName = "kv"
    22  )
    23  
    24  // KV stores a key/value pair.
    25  type KV struct {
    26  	key string
    27  	val string
    28  }
    29  
    30  // db holds a reference to the postgres connection connection.
    31  type db struct {
    32  	// cfg is the postgres connection config.
    33  	cfg *Config
    34  
    35  	// prefix is the table name prefix that is used to simulate namespaces.
    36  	// We don't use schemas because at least sqlite does not support that.
    37  	prefix string
    38  
    39  	// ctx is the overall context for the database driver.
    40  	//
    41  	// TODO: This is an anti-pattern that is in place until the kvdb
    42  	// interface supports a context.
    43  	ctx context.Context
    44  
    45  	// db is the underlying database connection instance.
    46  	db *sql.DB
    47  
    48  	// lock is the global write lock that ensures single writer.
    49  	lock sync.RWMutex
    50  
    51  	// table is the name of the table that contains the data for all
    52  	// top-level buckets that have keys that cannot be mapped to a distinct
    53  	// sql table.
    54  	table string
    55  }
    56  
    57  // Enforce db implements the walletdb.DB interface.
    58  var _ walletdb.DB = (*db)(nil)
    59  
    60  // Global set of database connections.
    61  var dbConns *dbConnSet
    62  
    63  // Init initializes the global set of database connections.
    64  func Init(maxConnections int) {
    65  	dbConns = newDbConnSet(maxConnections)
    66  }
    67  
    68  // newPostgresBackend returns a db object initialized with the passed backend
    69  // config. If postgres connection cannot be estabished, then returns error.
    70  func newPostgresBackend(ctx context.Context, config *Config, prefix string) (
    71  	*db, error) {
    72  
    73  	if prefix == "" {
    74  		return nil, errors.New("empty postgres prefix")
    75  	}
    76  
    77  	if dbConns == nil {
    78  		return nil, errors.New("db connection set not initialized")
    79  	}
    80  
    81  	dbConn, err := dbConns.Open(config.Dsn)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	// Compose system table names.
    87  	table := fmt.Sprintf(
    88  		"%s_%s", prefix, kvTableName,
    89  	)
    90  
    91  	// Execute the create statements to set up a kv table in postgres. Every
    92  	// row points to the bucket that it is one via its parent_id field. A
    93  	// NULL parent_id means that the key belongs to the upper-most bucket in
    94  	// this table. A constraint on parent_id is enforcing referential
    95  	// integrity.
    96  	//
    97  	// Furthermore there is a <table>_p index on parent_id that is required
    98  	// for the foreign key constraint.
    99  	//
   100  	// Finally there are unique indices on (parent_id, key) to prevent the
   101  	// same key being present in a bucket more than once (<table>_up and
   102  	// <table>_unp). In postgres, a single index wouldn't enforce the unique
   103  	// constraint on rows with a NULL parent_id. Therefore two indices are
   104  	// defined.
   105  	_, err = dbConn.ExecContext(ctx, `
   106  CREATE SCHEMA IF NOT EXISTS public;
   107  CREATE TABLE IF NOT EXISTS public.`+table+`
   108  (
   109      key bytea NOT NULL,
   110      value bytea,
   111      parent_id bigint,
   112      id bigserial PRIMARY KEY,
   113      sequence bigint,
   114      CONSTRAINT `+table+`_parent FOREIGN KEY (parent_id)
   115          REFERENCES public.`+table+` (id)
   116          ON UPDATE NO ACTION
   117          ON DELETE CASCADE
   118  );
   119  
   120  CREATE INDEX IF NOT EXISTS `+table+`_p
   121      ON public.`+table+` (parent_id);
   122  
   123  CREATE UNIQUE INDEX IF NOT EXISTS `+table+`_up
   124      ON public.`+table+`
   125      (parent_id, key) WHERE parent_id IS NOT NULL;
   126  
   127  CREATE UNIQUE INDEX IF NOT EXISTS `+table+`_unp 
   128      ON public.`+table+` (key) WHERE parent_id IS NULL;
   129  `)
   130  	if err != nil {
   131  		_ = dbConn.Close()
   132  
   133  		return nil, err
   134  	}
   135  
   136  	backend := &db{
   137  		cfg:    config,
   138  		prefix: prefix,
   139  		ctx:    ctx,
   140  		db:     dbConn,
   141  		table:  table,
   142  	}
   143  
   144  	return backend, nil
   145  }
   146  
   147  // getTimeoutCtx gets a timeout context for database requests.
   148  func (db *db) getTimeoutCtx() (context.Context, func()) {
   149  	if db.cfg.Timeout == time.Duration(0) {
   150  		return db.ctx, func() {}
   151  	}
   152  
   153  	return context.WithTimeout(db.ctx, db.cfg.Timeout)
   154  }
   155  
   156  // getPrefixedTableName returns a table name for this prefix (namespace).
   157  func (db *db) getPrefixedTableName(table string) string {
   158  	return fmt.Sprintf("%s_%s", db.prefix, table)
   159  }
   160  
   161  // catchPanic executes the specified function. If a panic occurs, it is returned
   162  // as an error value.
   163  func catchPanic(f func() error) (err error) {
   164  	defer func() {
   165  		if r := recover(); r != nil {
   166  			log.Criticalf("Caught unhandled error: %v", r)
   167  
   168  			switch data := r.(type) {
   169  			case error:
   170  				err = data
   171  
   172  			default:
   173  				err = errors.New(fmt.Sprintf("%v", data))
   174  			}
   175  		}
   176  	}()
   177  
   178  	err = f()
   179  
   180  	return
   181  }
   182  
   183  // View opens a database read transaction and executes the function f with the
   184  // transaction passed as a parameter. After f exits, the transaction is rolled
   185  // back. If f errors, its error is returned, not a rollback error (if any
   186  // occur). The passed reset function is called before the start of the
   187  // transaction and can be used to reset intermediate state. As callers may
   188  // expect retries of the f closure (depending on the database backend used), the
   189  // reset function will be called before each retry respectively.
   190  func (db *db) View(f func(tx walletdb.ReadTx) error, reset func()) error {
   191  	return db.executeTransaction(
   192  		func(tx walletdb.ReadWriteTx) error {
   193  			return f(tx.(walletdb.ReadTx))
   194  		},
   195  		reset, true,
   196  	)
   197  }
   198  
   199  // Update opens a database read/write transaction and executes the function f
   200  // with the transaction passed as a parameter. After f exits, if f did not
   201  // error, the transaction is committed. Otherwise, if f did error, the
   202  // transaction is rolled back. If the rollback fails, the original error
   203  // returned by f is still returned. If the commit fails, the commit error is
   204  // returned. As callers may expect retries of the f closure, the reset function
   205  // will be called before each retry respectively.
   206  func (db *db) Update(f func(tx walletdb.ReadWriteTx) error, reset func()) (err error) {
   207  	return db.executeTransaction(f, reset, false)
   208  }
   209  
   210  // executeTransaction creates a new read-only or read-write transaction and
   211  // executes the given function within it.
   212  func (db *db) executeTransaction(f func(tx walletdb.ReadWriteTx) error,
   213  	reset func(), readOnly bool) error {
   214  
   215  	reset()
   216  
   217  	tx, err := newReadWriteTx(db, readOnly)
   218  	if err != nil {
   219  		return err
   220  	}
   221  
   222  	err = catchPanic(func() error { return f(tx) })
   223  	if err != nil {
   224  		if rollbackErr := tx.Rollback(); rollbackErr != nil {
   225  			log.Errorf("Error rolling back tx: %v", rollbackErr)
   226  		}
   227  
   228  		return err
   229  	}
   230  
   231  	return tx.Commit()
   232  }
   233  
   234  // PrintStats returns all collected stats pretty printed into a string.
   235  func (db *db) PrintStats() string {
   236  	return "stats not supported by Postgres driver"
   237  }
   238  
   239  // BeginReadWriteTx opens a database read+write transaction.
   240  func (db *db) BeginReadWriteTx() (walletdb.ReadWriteTx, error) {
   241  	return newReadWriteTx(db, false)
   242  }
   243  
   244  // BeginReadTx opens a database read transaction.
   245  func (db *db) BeginReadTx() (walletdb.ReadTx, error) {
   246  	return newReadWriteTx(db, true)
   247  }
   248  
   249  // Copy writes a copy of the database to the provided writer. This call will
   250  // start a read-only transaction to perform all operations.
   251  // This function is part of the walletdb.Db interface implementation.
   252  func (db *db) Copy(w io.Writer) error {
   253  	return errors.New("not implemented")
   254  }
   255  
   256  // Close cleanly shuts down the database and syncs all data.
   257  // This function is part of the walletdb.Db interface implementation.
   258  func (db *db) Close() error {
   259  	log.Infof("Closing database %v", db.prefix)
   260  
   261  	return dbConns.Close(db.cfg.Dsn)
   262  }