github.com/decred/dcrlnd@v0.7.6/kvdb/postgres/readwrite_tx.go (about)

     1  //go:build kvdb_postgres
     2  // +build kvdb_postgres
     3  
     4  package postgres
     5  
     6  import (
     7  	"context"
     8  	"database/sql"
     9  	"sync"
    10  
    11  	"github.com/btcsuite/btcwallet/walletdb"
    12  )
    13  
    14  // readWriteTx holds a reference to an open postgres transaction.
    15  type readWriteTx struct {
    16  	db *db
    17  	tx *sql.Tx
    18  
    19  	// onCommit gets called upon commit.
    20  	onCommit func()
    21  
    22  	// active is true if the transaction hasn't been committed yet.
    23  	active bool
    24  
    25  	// locker is a pointer to the global db lock.
    26  	locker sync.Locker
    27  }
    28  
    29  // newReadWriteTx creates an rw transaction using a connection from the
    30  // specified pool.
    31  func newReadWriteTx(db *db, readOnly bool) (*readWriteTx, error) {
    32  	// Obtain the global lock instance. An alternative here is to obtain a
    33  	// database lock from Postgres. Unfortunately there is no database-level
    34  	// lock in Postgres, meaning that each table would need to be locked
    35  	// individually. Perhaps an advisory lock could perform this function
    36  	// too.
    37  	var locker sync.Locker = &db.lock
    38  	if readOnly {
    39  		locker = db.lock.RLocker()
    40  	}
    41  	locker.Lock()
    42  
    43  	// Start the transaction. Don't use the timeout context because it would
    44  	// be applied to the transaction as a whole. If possible, mark the
    45  	// transaction as read-only to make sure that potential programming
    46  	// errors cannot cause changes to the database.
    47  	tx, err := db.db.BeginTx(
    48  		context.Background(),
    49  		&sql.TxOptions{
    50  			ReadOnly: readOnly,
    51  		},
    52  	)
    53  	if err != nil {
    54  		locker.Unlock()
    55  		return nil, err
    56  	}
    57  
    58  	return &readWriteTx{
    59  		db:     db,
    60  		tx:     tx,
    61  		active: true,
    62  		locker: locker,
    63  	}, nil
    64  }
    65  
    66  // ReadBucket opens the root bucket for read only access.  If the bucket
    67  // described by the key does not exist, nil is returned.
    68  func (tx *readWriteTx) ReadBucket(key []byte) walletdb.ReadBucket {
    69  	return tx.ReadWriteBucket(key)
    70  }
    71  
    72  // ForEachBucket iterates through all top level buckets.
    73  func (tx *readWriteTx) ForEachBucket(fn func(key []byte) error) error {
    74  	// Fetch binary top level buckets.
    75  	bucket := newReadWriteBucket(tx, nil)
    76  	err := bucket.ForEach(func(k, _ []byte) error {
    77  		return fn(k)
    78  	})
    79  	return err
    80  }
    81  
    82  // Rollback closes the transaction, discarding changes (if any) if the
    83  // database was modified by a write transaction.
    84  func (tx *readWriteTx) Rollback() error {
    85  	// If the transaction has been closed roolback will fail.
    86  	if !tx.active {
    87  		return walletdb.ErrTxClosed
    88  	}
    89  
    90  	err := tx.tx.Rollback()
    91  
    92  	// Unlock the transaction regardless of the error result.
    93  	tx.active = false
    94  	tx.locker.Unlock()
    95  	return err
    96  }
    97  
    98  // ReadWriteBucket opens the root bucket for read/write access.  If the
    99  // bucket described by the key does not exist, nil is returned.
   100  func (tx *readWriteTx) ReadWriteBucket(key []byte) walletdb.ReadWriteBucket {
   101  	if len(key) == 0 {
   102  		return nil
   103  	}
   104  
   105  	bucket := newReadWriteBucket(tx, nil)
   106  	return bucket.NestedReadWriteBucket(key)
   107  }
   108  
   109  // CreateTopLevelBucket creates the top level bucket for a key if it
   110  // does not exist.  The newly-created bucket it returned.
   111  func (tx *readWriteTx) CreateTopLevelBucket(key []byte) (walletdb.ReadWriteBucket, error) {
   112  	if len(key) == 0 {
   113  		return nil, walletdb.ErrBucketNameRequired
   114  	}
   115  
   116  	bucket := newReadWriteBucket(tx, nil)
   117  	return bucket.CreateBucketIfNotExists(key)
   118  }
   119  
   120  // DeleteTopLevelBucket deletes the top level bucket for a key.  This
   121  // errors if the bucket can not be found or the key keys a single value
   122  // instead of a bucket.
   123  func (tx *readWriteTx) DeleteTopLevelBucket(key []byte) error {
   124  	// Execute a cascading delete on the key.
   125  	result, err := tx.Exec(
   126  		"DELETE FROM "+tx.db.table+" WHERE key=$1 "+
   127  			"AND parent_id IS NULL",
   128  		key,
   129  	)
   130  	if err != nil {
   131  		return err
   132  	}
   133  
   134  	rows, err := result.RowsAffected()
   135  	if err != nil {
   136  		return err
   137  	}
   138  	if rows == 0 {
   139  		return walletdb.ErrBucketNotFound
   140  	}
   141  
   142  	return nil
   143  }
   144  
   145  // Commit commits the transaction if not already committed.
   146  func (tx *readWriteTx) Commit() error {
   147  	// Commit will fail if the transaction is already committed.
   148  	if !tx.active {
   149  		return walletdb.ErrTxClosed
   150  	}
   151  
   152  	// Try committing the transaction.
   153  	err := tx.tx.Commit()
   154  	if err == nil && tx.onCommit != nil {
   155  		tx.onCommit()
   156  	}
   157  
   158  	// Unlock the transaction regardless of the error result.
   159  	tx.active = false
   160  	tx.locker.Unlock()
   161  
   162  	return err
   163  }
   164  
   165  // OnCommit sets the commit callback (overriding if already set).
   166  func (tx *readWriteTx) OnCommit(cb func()) {
   167  	tx.onCommit = cb
   168  }
   169  
   170  // QueryRow executes a QueryRow call with a timeout context.
   171  func (tx *readWriteTx) QueryRow(query string, args ...interface{}) (*sql.Row,
   172  	func()) {
   173  
   174  	ctx, cancel := tx.db.getTimeoutCtx()
   175  	return tx.tx.QueryRowContext(ctx, query, args...), cancel
   176  }
   177  
   178  // Query executes a multi-row query call with a timeout context.
   179  func (tx *readWriteTx) Query(query string, args ...interface{}) (*sql.Rows,
   180  	func(), error) {
   181  
   182  	ctx, cancel := tx.db.getTimeoutCtx()
   183  	rows, err := tx.tx.QueryContext(ctx, query, args...)
   184  	if err != nil {
   185  		cancel()
   186  
   187  		return nil, func() {}, err
   188  	}
   189  
   190  	return rows, cancel, nil
   191  }
   192  
   193  // Exec executes a Exec call with a timeout context.
   194  func (tx *readWriteTx) Exec(query string, args ...interface{}) (sql.Result,
   195  	error) {
   196  
   197  	ctx, cancel := tx.db.getTimeoutCtx()
   198  	defer cancel()
   199  
   200  	return tx.tx.ExecContext(ctx, query, args...)
   201  }