github.com/MetalBlockchain/metalgo@v1.11.9/database/encdb/db.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package encdb 5 6 import ( 7 "context" 8 "crypto/cipher" 9 "crypto/rand" 10 "slices" 11 "sync" 12 13 "golang.org/x/crypto/chacha20poly1305" 14 15 "github.com/MetalBlockchain/metalgo/database" 16 "github.com/MetalBlockchain/metalgo/utils/hashing" 17 ) 18 19 var ( 20 _ database.Database = (*Database)(nil) 21 _ database.Batch = (*batch)(nil) 22 _ database.Iterator = (*iterator)(nil) 23 ) 24 25 // Database encrypts all values that are provided 26 type Database struct { 27 lock sync.RWMutex 28 cipher cipher.AEAD 29 db database.Database 30 closed bool 31 } 32 33 // New returns a new encrypted database 34 func New(password []byte, db database.Database) (*Database, error) { 35 h := hashing.ComputeHash256(password) 36 aead, err := chacha20poly1305.NewX(h) 37 return &Database{ 38 cipher: aead, 39 db: db, 40 }, err 41 } 42 43 func (db *Database) Has(key []byte) (bool, error) { 44 db.lock.RLock() 45 defer db.lock.RUnlock() 46 47 if db.closed { 48 return false, database.ErrClosed 49 } 50 return db.db.Has(key) 51 } 52 53 func (db *Database) Get(key []byte) ([]byte, error) { 54 db.lock.RLock() 55 defer db.lock.RUnlock() 56 57 if db.closed { 58 return nil, database.ErrClosed 59 } 60 encVal, err := db.db.Get(key) 61 if err != nil { 62 return nil, err 63 } 64 return db.decrypt(encVal) 65 } 66 67 func (db *Database) Put(key, value []byte) error { 68 db.lock.Lock() 69 defer db.lock.Unlock() 70 71 if db.closed { 72 return database.ErrClosed 73 } 74 75 encValue, err := db.encrypt(value) 76 if err != nil { 77 return err 78 } 79 return db.db.Put(key, encValue) 80 } 81 82 func (db *Database) Delete(key []byte) error { 83 db.lock.Lock() 84 defer db.lock.Unlock() 85 86 if db.closed { 87 return database.ErrClosed 88 } 89 return db.db.Delete(key) 90 } 91 92 func (db *Database) NewBatch() database.Batch { 93 return &batch{ 94 Batch: db.db.NewBatch(), 95 db: db, 96 } 97 } 98 99 func (db *Database) NewIterator() database.Iterator { 100 return db.NewIteratorWithStartAndPrefix(nil, nil) 101 } 102 103 func (db *Database) NewIteratorWithStart(start []byte) database.Iterator { 104 return db.NewIteratorWithStartAndPrefix(start, nil) 105 } 106 107 func (db *Database) NewIteratorWithPrefix(prefix []byte) database.Iterator { 108 return db.NewIteratorWithStartAndPrefix(nil, prefix) 109 } 110 111 func (db *Database) NewIteratorWithStartAndPrefix(start, prefix []byte) database.Iterator { 112 db.lock.RLock() 113 defer db.lock.RUnlock() 114 115 if db.closed { 116 return &database.IteratorError{ 117 Err: database.ErrClosed, 118 } 119 } 120 return &iterator{ 121 Iterator: db.db.NewIteratorWithStartAndPrefix(start, prefix), 122 db: db, 123 } 124 } 125 126 func (db *Database) Compact(start, limit []byte) error { 127 db.lock.Lock() 128 defer db.lock.Unlock() 129 130 if db.closed { 131 return database.ErrClosed 132 } 133 return db.db.Compact(start, limit) 134 } 135 136 func (db *Database) Close() error { 137 db.lock.Lock() 138 defer db.lock.Unlock() 139 140 if db.closed { 141 return database.ErrClosed 142 } 143 db.closed = true 144 return nil 145 } 146 147 func (db *Database) isClosed() bool { 148 db.lock.RLock() 149 defer db.lock.RUnlock() 150 151 return db.closed 152 } 153 154 func (db *Database) HealthCheck(ctx context.Context) (interface{}, error) { 155 db.lock.RLock() 156 defer db.lock.RUnlock() 157 158 if db.closed { 159 return nil, database.ErrClosed 160 } 161 return db.db.HealthCheck(ctx) 162 } 163 164 type batch struct { 165 database.Batch 166 167 db *Database 168 ops []database.BatchOp 169 } 170 171 func (b *batch) Put(key, value []byte) error { 172 b.ops = append(b.ops, database.BatchOp{ 173 Key: slices.Clone(key), 174 Value: slices.Clone(value), 175 }) 176 encValue, err := b.db.encrypt(value) 177 if err != nil { 178 return err 179 } 180 return b.Batch.Put(key, encValue) 181 } 182 183 func (b *batch) Delete(key []byte) error { 184 b.ops = append(b.ops, database.BatchOp{ 185 Key: slices.Clone(key), 186 Delete: true, 187 }) 188 return b.Batch.Delete(key) 189 } 190 191 func (b *batch) Write() error { 192 b.db.lock.Lock() 193 defer b.db.lock.Unlock() 194 195 if b.db.closed { 196 return database.ErrClosed 197 } 198 199 return b.Batch.Write() 200 } 201 202 // Reset resets the batch for reuse. 203 func (b *batch) Reset() { 204 if cap(b.ops) > len(b.ops)*database.MaxExcessCapacityFactor { 205 b.ops = make([]database.BatchOp, 0, cap(b.ops)/database.CapacityReductionFactor) 206 } else { 207 b.ops = b.ops[:0] 208 } 209 b.Batch.Reset() 210 } 211 212 // Replay replays the batch contents. 213 func (b *batch) Replay(w database.KeyValueWriterDeleter) error { 214 for _, op := range b.ops { 215 if op.Delete { 216 if err := w.Delete(op.Key); err != nil { 217 return err 218 } 219 } else if err := w.Put(op.Key, op.Value); err != nil { 220 return err 221 } 222 } 223 return nil 224 } 225 226 type iterator struct { 227 database.Iterator 228 db *Database 229 230 val, key []byte 231 err error 232 } 233 234 func (it *iterator) Next() bool { 235 // Short-circuit and set an error if the underlying database has been closed. 236 if it.db.isClosed() { 237 it.val = nil 238 it.key = nil 239 it.err = database.ErrClosed 240 return false 241 } 242 243 next := it.Iterator.Next() 244 if next { 245 encVal := it.Iterator.Value() 246 val, err := it.db.decrypt(encVal) 247 if err != nil { 248 it.err = err 249 return false 250 } 251 it.val = val 252 it.key = it.Iterator.Key() 253 } else { 254 it.val = nil 255 it.key = nil 256 } 257 return next 258 } 259 260 func (it *iterator) Error() error { 261 if it.err != nil { 262 return it.err 263 } 264 return it.Iterator.Error() 265 } 266 267 func (it *iterator) Key() []byte { 268 return it.key 269 } 270 271 func (it *iterator) Value() []byte { 272 return it.val 273 } 274 275 type encryptedValue struct { 276 Ciphertext []byte `serialize:"true"` 277 Nonce []byte `serialize:"true"` 278 } 279 280 func (db *Database) encrypt(plaintext []byte) ([]byte, error) { 281 nonce := make([]byte, chacha20poly1305.NonceSizeX) 282 if _, err := rand.Read(nonce); err != nil { 283 return nil, err 284 } 285 ciphertext := db.cipher.Seal(nil, nonce, plaintext, nil) 286 return Codec.Marshal(CodecVersion, &encryptedValue{ 287 Ciphertext: ciphertext, 288 Nonce: nonce, 289 }) 290 } 291 292 func (db *Database) decrypt(ciphertext []byte) ([]byte, error) { 293 val := encryptedValue{} 294 if _, err := Codec.Unmarshal(ciphertext, &val); err != nil { 295 return nil, err 296 } 297 return db.cipher.Open(nil, val.Nonce, val.Ciphertext, nil) 298 }