github.com/MetalBlockchain/metalgo@v1.11.9/database/encdb/db.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package encdb
     5  
import (
	"context"
	"crypto/cipher"
	"crypto/rand"
	"errors"
	"slices"
	"sync"

	"golang.org/x/crypto/chacha20poly1305"

	"github.com/MetalBlockchain/metalgo/database"
	"github.com/MetalBlockchain/metalgo/utils/hashing"
)
    18  
    19  var (
    20  	_ database.Database = (*Database)(nil)
    21  	_ database.Batch    = (*batch)(nil)
    22  	_ database.Iterator = (*iterator)(nil)
    23  )
    24  
    25  // Database encrypts all values that are provided
    26  type Database struct {
    27  	lock   sync.RWMutex
    28  	cipher cipher.AEAD
    29  	db     database.Database
    30  	closed bool
    31  }
    32  
    33  // New returns a new encrypted database
    34  func New(password []byte, db database.Database) (*Database, error) {
    35  	h := hashing.ComputeHash256(password)
    36  	aead, err := chacha20poly1305.NewX(h)
    37  	return &Database{
    38  		cipher: aead,
    39  		db:     db,
    40  	}, err
    41  }
    42  
    43  func (db *Database) Has(key []byte) (bool, error) {
    44  	db.lock.RLock()
    45  	defer db.lock.RUnlock()
    46  
    47  	if db.closed {
    48  		return false, database.ErrClosed
    49  	}
    50  	return db.db.Has(key)
    51  }
    52  
    53  func (db *Database) Get(key []byte) ([]byte, error) {
    54  	db.lock.RLock()
    55  	defer db.lock.RUnlock()
    56  
    57  	if db.closed {
    58  		return nil, database.ErrClosed
    59  	}
    60  	encVal, err := db.db.Get(key)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	return db.decrypt(encVal)
    65  }
    66  
    67  func (db *Database) Put(key, value []byte) error {
    68  	db.lock.Lock()
    69  	defer db.lock.Unlock()
    70  
    71  	if db.closed {
    72  		return database.ErrClosed
    73  	}
    74  
    75  	encValue, err := db.encrypt(value)
    76  	if err != nil {
    77  		return err
    78  	}
    79  	return db.db.Put(key, encValue)
    80  }
    81  
    82  func (db *Database) Delete(key []byte) error {
    83  	db.lock.Lock()
    84  	defer db.lock.Unlock()
    85  
    86  	if db.closed {
    87  		return database.ErrClosed
    88  	}
    89  	return db.db.Delete(key)
    90  }
    91  
    92  func (db *Database) NewBatch() database.Batch {
    93  	return &batch{
    94  		Batch: db.db.NewBatch(),
    95  		db:    db,
    96  	}
    97  }
    98  
    99  func (db *Database) NewIterator() database.Iterator {
   100  	return db.NewIteratorWithStartAndPrefix(nil, nil)
   101  }
   102  
   103  func (db *Database) NewIteratorWithStart(start []byte) database.Iterator {
   104  	return db.NewIteratorWithStartAndPrefix(start, nil)
   105  }
   106  
   107  func (db *Database) NewIteratorWithPrefix(prefix []byte) database.Iterator {
   108  	return db.NewIteratorWithStartAndPrefix(nil, prefix)
   109  }
   110  
   111  func (db *Database) NewIteratorWithStartAndPrefix(start, prefix []byte) database.Iterator {
   112  	db.lock.RLock()
   113  	defer db.lock.RUnlock()
   114  
   115  	if db.closed {
   116  		return &database.IteratorError{
   117  			Err: database.ErrClosed,
   118  		}
   119  	}
   120  	return &iterator{
   121  		Iterator: db.db.NewIteratorWithStartAndPrefix(start, prefix),
   122  		db:       db,
   123  	}
   124  }
   125  
   126  func (db *Database) Compact(start, limit []byte) error {
   127  	db.lock.Lock()
   128  	defer db.lock.Unlock()
   129  
   130  	if db.closed {
   131  		return database.ErrClosed
   132  	}
   133  	return db.db.Compact(start, limit)
   134  }
   135  
   136  func (db *Database) Close() error {
   137  	db.lock.Lock()
   138  	defer db.lock.Unlock()
   139  
   140  	if db.closed {
   141  		return database.ErrClosed
   142  	}
   143  	db.closed = true
   144  	return nil
   145  }
   146  
   147  func (db *Database) isClosed() bool {
   148  	db.lock.RLock()
   149  	defer db.lock.RUnlock()
   150  
   151  	return db.closed
   152  }
   153  
   154  func (db *Database) HealthCheck(ctx context.Context) (interface{}, error) {
   155  	db.lock.RLock()
   156  	defer db.lock.RUnlock()
   157  
   158  	if db.closed {
   159  		return nil, database.ErrClosed
   160  	}
   161  	return db.db.HealthCheck(ctx)
   162  }
   163  
   164  type batch struct {
   165  	database.Batch
   166  
   167  	db  *Database
   168  	ops []database.BatchOp
   169  }
   170  
   171  func (b *batch) Put(key, value []byte) error {
   172  	b.ops = append(b.ops, database.BatchOp{
   173  		Key:   slices.Clone(key),
   174  		Value: slices.Clone(value),
   175  	})
   176  	encValue, err := b.db.encrypt(value)
   177  	if err != nil {
   178  		return err
   179  	}
   180  	return b.Batch.Put(key, encValue)
   181  }
   182  
   183  func (b *batch) Delete(key []byte) error {
   184  	b.ops = append(b.ops, database.BatchOp{
   185  		Key:    slices.Clone(key),
   186  		Delete: true,
   187  	})
   188  	return b.Batch.Delete(key)
   189  }
   190  
   191  func (b *batch) Write() error {
   192  	b.db.lock.Lock()
   193  	defer b.db.lock.Unlock()
   194  
   195  	if b.db.closed {
   196  		return database.ErrClosed
   197  	}
   198  
   199  	return b.Batch.Write()
   200  }
   201  
   202  // Reset resets the batch for reuse.
   203  func (b *batch) Reset() {
   204  	if cap(b.ops) > len(b.ops)*database.MaxExcessCapacityFactor {
   205  		b.ops = make([]database.BatchOp, 0, cap(b.ops)/database.CapacityReductionFactor)
   206  	} else {
   207  		b.ops = b.ops[:0]
   208  	}
   209  	b.Batch.Reset()
   210  }
   211  
   212  // Replay replays the batch contents.
   213  func (b *batch) Replay(w database.KeyValueWriterDeleter) error {
   214  	for _, op := range b.ops {
   215  		if op.Delete {
   216  			if err := w.Delete(op.Key); err != nil {
   217  				return err
   218  			}
   219  		} else if err := w.Put(op.Key, op.Value); err != nil {
   220  			return err
   221  		}
   222  	}
   223  	return nil
   224  }
   225  
   226  type iterator struct {
   227  	database.Iterator
   228  	db *Database
   229  
   230  	val, key []byte
   231  	err      error
   232  }
   233  
   234  func (it *iterator) Next() bool {
   235  	// Short-circuit and set an error if the underlying database has been closed.
   236  	if it.db.isClosed() {
   237  		it.val = nil
   238  		it.key = nil
   239  		it.err = database.ErrClosed
   240  		return false
   241  	}
   242  
   243  	next := it.Iterator.Next()
   244  	if next {
   245  		encVal := it.Iterator.Value()
   246  		val, err := it.db.decrypt(encVal)
   247  		if err != nil {
   248  			it.err = err
   249  			return false
   250  		}
   251  		it.val = val
   252  		it.key = it.Iterator.Key()
   253  	} else {
   254  		it.val = nil
   255  		it.key = nil
   256  	}
   257  	return next
   258  }
   259  
   260  func (it *iterator) Error() error {
   261  	if it.err != nil {
   262  		return it.err
   263  	}
   264  	return it.Iterator.Error()
   265  }
   266  
   267  func (it *iterator) Key() []byte {
   268  	return it.key
   269  }
   270  
   271  func (it *iterator) Value() []byte {
   272  	return it.val
   273  }
   274  
// encryptedValue is the envelope stored in the underlying database for each
// value: the AEAD output plus the random nonce used to produce it. It is
// encoded and decoded with Codec.
//
// NOTE(review): the field order and `serialize` tags presumably define the
// stored encoding. Do not reorder or retag without migrating existing data.
type encryptedValue struct {
	Ciphertext []byte `serialize:"true"`
	Nonce      []byte `serialize:"true"`
}
   279  
   280  func (db *Database) encrypt(plaintext []byte) ([]byte, error) {
   281  	nonce := make([]byte, chacha20poly1305.NonceSizeX)
   282  	if _, err := rand.Read(nonce); err != nil {
   283  		return nil, err
   284  	}
   285  	ciphertext := db.cipher.Seal(nil, nonce, plaintext, nil)
   286  	return Codec.Marshal(CodecVersion, &encryptedValue{
   287  		Ciphertext: ciphertext,
   288  		Nonce:      nonce,
   289  	})
   290  }
   291  
   292  func (db *Database) decrypt(ciphertext []byte) ([]byte, error) {
   293  	val := encryptedValue{}
   294  	if _, err := Codec.Unmarshal(ciphertext, &val); err != nil {
   295  		return nil, err
   296  	}
   297  	return db.cipher.Open(nil, val.Nonce, val.Ciphertext, nil)
   298  }