github.com/decred/dcrlnd@v0.7.6/kvdb/postgres/readwrite_bucket.go (about)

     1  //go:build kvdb_postgres
     2  // +build kvdb_postgres
     3  
     4  package postgres
     5  
     6  import (
     7  	"database/sql"
     8  	"errors"
     9  	"fmt"
    10  
    11  	"github.com/btcsuite/btcwallet/walletdb"
    12  )
    13  
// readWriteBucket stores the bucket id and the bucket's transaction.
type readWriteBucket struct {
	// id is used to identify the bucket. If id is nil, it refers to the
	// root bucket.
	id *int64

	// tx holds the parent transaction.
	tx *readWriteTx

	// table is the name of the database table holding this bucket's rows.
	// newReadWriteBucket copies it from tx.db.table, and the bucket's
	// methods concatenate it directly into their SQL statements.
	table string
}
    25  
    26  // newReadWriteBucket creates a new rw bucket with the passed transaction
    27  // and bucket id.
    28  func newReadWriteBucket(tx *readWriteTx, id *int64) *readWriteBucket {
    29  	return &readWriteBucket{
    30  		id:    id,
    31  		tx:    tx,
    32  		table: tx.db.table,
    33  	}
    34  }
    35  
    36  // NestedReadBucket retrieves a nested read bucket with the given key.
    37  // Returns nil if the bucket does not exist.
    38  func (b *readWriteBucket) NestedReadBucket(key []byte) walletdb.ReadBucket {
    39  	return b.NestedReadWriteBucket(key)
    40  }
    41  
    42  func parentSelector(id *int64) string {
    43  	if id == nil {
    44  		return "parent_id IS NULL"
    45  	}
    46  	return fmt.Sprintf("parent_id=%v", *id)
    47  }
    48  
    49  // ForEach invokes the passed function with every key/value pair in
    50  // the bucket. This includes nested buckets, in which case the value
    51  // is nil, but it does not include the key/value pairs within those
    52  // nested buckets.
    53  func (b *readWriteBucket) ForEach(cb func(k, v []byte) error) error {
    54  	cursor := b.ReadWriteCursor()
    55  
    56  	k, v := cursor.First()
    57  	for k != nil {
    58  		err := cb(k, v)
    59  		if err != nil {
    60  			return err
    61  		}
    62  
    63  		k, v = cursor.Next()
    64  	}
    65  
    66  	return nil
    67  }
    68  
    69  // Get returns the value for the given key. Returns nil if the key does
    70  // not exist in this bucket.
    71  func (b *readWriteBucket) Get(key []byte) []byte {
    72  	// Return nil if the key is empty.
    73  	if len(key) == 0 {
    74  		return nil
    75  	}
    76  
    77  	var value *[]byte
    78  	row, cancel := b.tx.QueryRow(
    79  		"SELECT value FROM "+b.table+" WHERE "+parentSelector(b.id)+
    80  			" AND key=$1", key,
    81  	)
    82  	defer cancel()
    83  	err := row.Scan(&value)
    84  
    85  	switch {
    86  	case err == sql.ErrNoRows:
    87  		return nil
    88  
    89  	case err != nil:
    90  		panic(err)
    91  	}
    92  
    93  	return *value
    94  }
    95  
    96  // ReadCursor returns a new read-only cursor for this bucket.
    97  func (b *readWriteBucket) ReadCursor() walletdb.ReadCursor {
    98  	return newReadWriteCursor(b)
    99  }
   100  
   101  // NestedReadWriteBucket retrieves a nested bucket with the given key.
   102  // Returns nil if the bucket does not exist.
   103  func (b *readWriteBucket) NestedReadWriteBucket(
   104  	key []byte) walletdb.ReadWriteBucket {
   105  
   106  	if len(key) == 0 {
   107  		return nil
   108  	}
   109  
   110  	var id int64
   111  	row, cancel := b.tx.QueryRow(
   112  		"SELECT id FROM "+b.table+" WHERE "+parentSelector(b.id)+
   113  			" AND key=$1 AND value IS NULL", key,
   114  	)
   115  	defer cancel()
   116  	err := row.Scan(&id)
   117  
   118  	switch {
   119  	case err == sql.ErrNoRows:
   120  		return nil
   121  
   122  	case err != nil:
   123  		panic(err)
   124  	}
   125  
   126  	return newReadWriteBucket(b.tx, &id)
   127  }
   128  
   129  // CreateBucket creates and returns a new nested bucket with the given key.
   130  // Returns ErrBucketExists if the bucket already exists, ErrBucketNameRequired
   131  // if the key is empty, or ErrIncompatibleValue if the key value is otherwise
   132  // invalid for the particular database implementation.  Other errors are
   133  // possible depending on the implementation.
   134  func (b *readWriteBucket) CreateBucket(key []byte) (
   135  	walletdb.ReadWriteBucket, error) {
   136  
   137  	if len(key) == 0 {
   138  		return nil, walletdb.ErrBucketNameRequired
   139  	}
   140  
   141  	// Check to see if the bucket already exists.
   142  	var (
   143  		value *[]byte
   144  		id    int64
   145  	)
   146  	row, cancel := b.tx.QueryRow(
   147  		"SELECT id,value FROM "+b.table+" WHERE "+parentSelector(b.id)+
   148  			" AND key=$1", key,
   149  	)
   150  	defer cancel()
   151  	err := row.Scan(&id, &value)
   152  
   153  	switch {
   154  	case err == sql.ErrNoRows:
   155  
   156  	case err == nil && value == nil:
   157  		return nil, walletdb.ErrBucketExists
   158  
   159  	case err == nil && value != nil:
   160  		return nil, walletdb.ErrIncompatibleValue
   161  
   162  	case err != nil:
   163  		return nil, err
   164  	}
   165  
   166  	// Bucket does not yet exist, so create it. Postgres will generate a
   167  	// bucket id for the new bucket.
   168  	row, cancel = b.tx.QueryRow(
   169  		"INSERT INTO "+b.table+" (parent_id, key) "+
   170  			"VALUES($1, $2) RETURNING id", b.id, key,
   171  	)
   172  	defer cancel()
   173  	err = row.Scan(&id)
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	return newReadWriteBucket(b.tx, &id), nil
   179  }
   180  
   181  // CreateBucketIfNotExists creates and returns a new nested bucket with
   182  // the given key if it does not already exist.  Returns
   183  // ErrBucketNameRequired if the key is empty or ErrIncompatibleValue
   184  // if the key value is otherwise invalid for the particular database
   185  // backend.  Other errors are possible depending on the implementation.
   186  func (b *readWriteBucket) CreateBucketIfNotExists(key []byte) (
   187  	walletdb.ReadWriteBucket, error) {
   188  
   189  	if len(key) == 0 {
   190  		return nil, walletdb.ErrBucketNameRequired
   191  	}
   192  
   193  	// Check to see if the bucket already exists.
   194  	var (
   195  		value *[]byte
   196  		id    int64
   197  	)
   198  	row, cancel := b.tx.QueryRow(
   199  		"SELECT id,value FROM "+b.table+" WHERE "+parentSelector(b.id)+
   200  			" AND key=$1", key,
   201  	)
   202  	defer cancel()
   203  	err := row.Scan(&id, &value)
   204  
   205  	switch {
   206  	// Bucket does not yet exist, so create it now. Postgres will generate a
   207  	// bucket id for the new bucket.
   208  	case err == sql.ErrNoRows:
   209  		row, cancel := b.tx.QueryRow(
   210  			"INSERT INTO "+b.table+" (parent_id, key) "+
   211  				"VALUES($1, $2) RETURNING id", b.id, key,
   212  		)
   213  		defer cancel()
   214  		err := row.Scan(&id)
   215  		if err != nil {
   216  			return nil, err
   217  		}
   218  
   219  	case err == nil && value != nil:
   220  		return nil, walletdb.ErrIncompatibleValue
   221  
   222  	case err != nil:
   223  		return nil, err
   224  	}
   225  
   226  	return newReadWriteBucket(b.tx, &id), nil
   227  }
   228  
   229  // DeleteNestedBucket deletes the nested bucket and its sub-buckets
   230  // pointed to by the passed key. All values in the bucket and sub-buckets
   231  // will be deleted as well.
   232  func (b *readWriteBucket) DeleteNestedBucket(key []byte) error {
   233  	if len(key) == 0 {
   234  		return walletdb.ErrIncompatibleValue
   235  	}
   236  
   237  	result, err := b.tx.Exec(
   238  		"DELETE FROM "+b.table+" WHERE "+parentSelector(b.id)+
   239  			" AND key=$1 AND value IS NULL",
   240  		key,
   241  	)
   242  	if err != nil {
   243  		return err
   244  	}
   245  
   246  	rows, err := result.RowsAffected()
   247  	if err != nil {
   248  		return err
   249  	}
   250  	if rows == 0 {
   251  		return walletdb.ErrBucketNotFound
   252  	}
   253  
   254  	return nil
   255  }
   256  
   257  // Put updates the value for the passed key.
   258  // Returns ErrKeyRequired if te passed key is empty.
   259  func (b *readWriteBucket) Put(key, value []byte) error {
   260  	if len(key) == 0 {
   261  		return walletdb.ErrKeyRequired
   262  	}
   263  
   264  	// Prevent NULL being written for an empty value slice.
   265  	if value == nil {
   266  		value = []byte{}
   267  	}
   268  
   269  	var (
   270  		result sql.Result
   271  		err    error
   272  	)
   273  
   274  	// We are putting a value in a bucket in this table. Try to insert the
   275  	// key first. If the key already exists (ON CONFLICT), update the key.
   276  	// Do not update a NULL value, because this indicates that the key
   277  	// contains a sub-bucket. This case will be caught via RowsAffected
   278  	// below.
   279  	if b.id == nil {
   280  		// ON CONFLICT requires the WHERE parent_id IS NULL hint to let
   281  		// Postgres find the NULL-parent_id unique index (<table>_unp).
   282  		result, err = b.tx.Exec(
   283  			"INSERT INTO "+b.table+" (key, value) VALUES($1, $2) "+
   284  				"ON CONFLICT (key) WHERE parent_id IS NULL "+
   285  				"DO UPDATE SET value=$2 "+
   286  				"WHERE "+b.table+".value IS NOT NULL",
   287  			key, value,
   288  		)
   289  	} else {
   290  		// ON CONFLICT requires the WHERE parent_id NOT IS NULL hint to
   291  		// let Postgres find the non-NULL-parent_id unique index
   292  		// (<table>_up).
   293  		result, err = b.tx.Exec(
   294  			"INSERT INTO "+b.table+" (key, value, parent_id) "+
   295  				"VALUES($1, $2, $3) "+
   296  				"ON CONFLICT (key, parent_id) "+
   297  				"WHERE parent_id IS NOT NULL "+
   298  				"DO UPDATE SET value=$2 "+
   299  				"WHERE "+b.table+".value IS NOT NULL",
   300  			key, value, b.id,
   301  		)
   302  	}
   303  	if err != nil {
   304  		return err
   305  	}
   306  
   307  	rows, err := result.RowsAffected()
   308  	if err != nil {
   309  		return err
   310  	}
   311  	if rows != 1 {
   312  		return walletdb.ErrIncompatibleValue
   313  	}
   314  
   315  	return nil
   316  }
   317  
   318  // Delete deletes the key/value pointed to by the passed key.
   319  // Returns ErrKeyRequired if the passed key is empty.
   320  func (b *readWriteBucket) Delete(key []byte) error {
   321  	if key == nil {
   322  		return nil
   323  	}
   324  	if len(key) == 0 {
   325  		return walletdb.ErrKeyRequired
   326  	}
   327  
   328  	// Check to see if a bucket with this key exists.
   329  	var dummy int
   330  	row, cancel := b.tx.QueryRow(
   331  		"SELECT 1 FROM "+b.table+" WHERE "+parentSelector(b.id)+
   332  			" AND key=$1 AND value IS NULL", key,
   333  	)
   334  	defer cancel()
   335  	err := row.Scan(&dummy)
   336  	switch {
   337  	// No bucket exists, proceed to deletion of the key.
   338  	case err == sql.ErrNoRows:
   339  
   340  	case err != nil:
   341  		return err
   342  
   343  	// Bucket exists.
   344  	default:
   345  		return walletdb.ErrIncompatibleValue
   346  	}
   347  
   348  	_, err = b.tx.Exec(
   349  		"DELETE FROM "+b.table+" WHERE key=$1 AND "+
   350  			parentSelector(b.id)+" AND value IS NOT NULL",
   351  		key,
   352  	)
   353  	if err != nil {
   354  		return err
   355  	}
   356  
   357  	return nil
   358  }
   359  
   360  // ReadWriteCursor returns a new read-write cursor for this bucket.
   361  func (b *readWriteBucket) ReadWriteCursor() walletdb.ReadWriteCursor {
   362  	return newReadWriteCursor(b)
   363  }
   364  
   365  // Tx returns the buckets transaction.
   366  func (b *readWriteBucket) Tx() walletdb.ReadWriteTx {
   367  	return b.tx
   368  }
   369  
   370  // NextSequence returns an autoincrementing sequence number for this bucket.
   371  // Note that this is not a thread safe function and as such it must not be used
   372  // for synchronization.
   373  func (b *readWriteBucket) NextSequence() (uint64, error) {
   374  	seq := b.Sequence() + 1
   375  
   376  	return seq, b.SetSequence(seq)
   377  }
   378  
   379  // SetSequence updates the sequence number for the bucket.
   380  func (b *readWriteBucket) SetSequence(v uint64) error {
   381  	if b.id == nil {
   382  		panic("sequence not supported on top level bucket")
   383  	}
   384  
   385  	result, err := b.tx.Exec(
   386  		"UPDATE "+b.table+" SET sequence=$2 WHERE id=$1",
   387  		b.id, int64(v),
   388  	)
   389  	if err != nil {
   390  		return err
   391  	}
   392  
   393  	rows, err := result.RowsAffected()
   394  	if err != nil {
   395  		return err
   396  	}
   397  	if rows != 1 {
   398  		return errors.New("cannot set sequence")
   399  	}
   400  
   401  	return nil
   402  }
   403  
   404  // Sequence returns the current sequence number for this bucket without
   405  // incrementing it.
   406  func (b *readWriteBucket) Sequence() uint64 {
   407  	if b.id == nil {
   408  		panic("sequence not supported on top level bucket")
   409  	}
   410  
   411  	var seq int64
   412  	row, cancel := b.tx.QueryRow(
   413  		"SELECT sequence FROM "+b.table+" WHERE id=$1 "+
   414  			"AND sequence IS NOT NULL",
   415  		b.id,
   416  	)
   417  	defer cancel()
   418  	err := row.Scan(&seq)
   419  
   420  	switch {
   421  	case err == sql.ErrNoRows:
   422  		return 0
   423  
   424  	case err != nil:
   425  		panic(err)
   426  	}
   427  
   428  	return uint64(seq)
   429  }
   430  
   431  // Prefetch will attempt to prefetch all values under a path from the passed
   432  // bucket.
   433  func (b *readWriteBucket) Prefetch(paths ...[]string) {}
   434  
   435  // ForAll is an optimized version of ForEach with the limitation that no
   436  // additional queries can be executed within the callback.
   437  func (b *readWriteBucket) ForAll(cb func(k, v []byte) error) error {
   438  	rows, cancel, err := b.tx.Query(
   439  		"SELECT key, value FROM " + b.table + " WHERE " +
   440  			parentSelector(b.id) + " ORDER BY key",
   441  	)
   442  	if err != nil {
   443  		return err
   444  	}
   445  	defer cancel()
   446  
   447  	for rows.Next() {
   448  		var key, value []byte
   449  
   450  		err := rows.Scan(&key, &value)
   451  		if err != nil {
   452  			return err
   453  		}
   454  
   455  		err = cb(key, value)
   456  		if err != nil {
   457  			return err
   458  		}
   459  	}
   460  
   461  	return nil
   462  }