github.com/cryptotooltop/go-ethereum@v0.0.0-20231103184714-151d1922f3e5/trie/zk_trie_database.go (about) 1 package trie 2 3 import ( 4 "math/big" 5 6 "github.com/syndtr/goleveldb/leveldb" 7 8 zktrie "github.com/scroll-tech/zktrie/trie" 9 10 "github.com/scroll-tech/go-ethereum/common" 11 "github.com/scroll-tech/go-ethereum/ethdb" 12 ) 13 14 // ZktrieDatabase Database adaptor implements zktrie.ZktrieDatbase 15 // It also reverses the bit order of the key being persisted. 16 // This ensures that the adjacent leaf in zktrie maintains minimal 17 // distance when persisted with dictionary order in LevelDB. 18 // Consequently, this optimizes the snapshot operation, allowing it 19 // to iterate through adjacent leaves at a reduced cost. 20 21 type ZktrieDatabase struct { 22 db *Database 23 prefix []byte 24 } 25 26 func NewZktrieDatabase(diskdb ethdb.KeyValueStore) *ZktrieDatabase { 27 return &ZktrieDatabase{db: NewDatabase(diskdb), prefix: []byte{}} 28 } 29 30 // adhoc wrapper... 31 func NewZktrieDatabaseFromTriedb(db *Database) *ZktrieDatabase { 32 db.Zktrie = true 33 return &ZktrieDatabase{db: db, prefix: []byte{}} 34 } 35 36 // Put saves a key:value into the Storage 37 func (l *ZktrieDatabase) Put(k, v []byte) error { 38 k = bitReverse(k) 39 l.db.lock.Lock() 40 l.db.rawDirties.Put(Concat(l.prefix, k[:]), v) 41 l.db.lock.Unlock() 42 return nil 43 } 44 45 // Get retrieves a value from a key in the Storage 46 func (l *ZktrieDatabase) Get(key []byte) ([]byte, error) { 47 key = bitReverse(key) 48 concatKey := Concat(l.prefix, key[:]) 49 l.db.lock.RLock() 50 value, ok := l.db.rawDirties.Get(concatKey) 51 l.db.lock.RUnlock() 52 if ok { 53 return value, nil 54 } 55 56 if l.db.cleans != nil { 57 if enc := l.db.cleans.Get(nil, concatKey); enc != nil { 58 memcacheCleanHitMeter.Mark(1) 59 memcacheCleanReadMeter.Mark(int64(len(enc))) 60 return enc, nil 61 } 62 } 63 64 v, err := l.db.diskdb.Get(concatKey) 65 if err == leveldb.ErrNotFound { 66 return nil, zktrie.ErrKeyNotFound 67 } 68 if l.db.cleans != nil { 69 l.db.cleans.Set(concatKey[:], v) 70 memcacheCleanMissMeter.Mark(1) 71 memcacheCleanWriteMeter.Mark(int64(len(v))) 72 } 73 return v, err 74 } 75 76 func (l *ZktrieDatabase) UpdatePreimage(preimage []byte, hashField *big.Int) { 77 db := l.db 78 if db.preimages != nil { // Ugly direct check but avoids the below write lock 79 // we must copy the input key 80 db.preimages.insertPreimage(map[common.Hash][]byte{common.BytesToHash(hashField.Bytes()): common.CopyBytes(preimage)}) 81 } 82 } 83 84 // Iterate implements the method Iterate of the interface Storage 85 func (l *ZktrieDatabase) Iterate(f func([]byte, []byte) (bool, error)) error { 86 iter := l.db.diskdb.NewIterator(l.prefix, nil) 87 defer iter.Release() 88 for iter.Next() { 89 localKey := bitReverse(iter.Key()[len(l.prefix):]) 90 if cont, err := f(localKey, iter.Value()); err != nil { 91 return err 92 } else if !cont { 93 break 94 } 95 } 96 iter.Release() 97 return iter.Error() 98 } 99 100 // Close implements the method Close of the interface Storage 101 func (l *ZktrieDatabase) Close() { 102 // FIXME: is this correct? 103 if err := l.db.diskdb.Close(); err != nil { 104 panic(err) 105 } 106 } 107 108 // List implements the method List of the interface Storage 109 func (l *ZktrieDatabase) List(limit int) ([]KV, error) { 110 ret := []KV{} 111 err := l.Iterate(func(key []byte, value []byte) (bool, error) { 112 ret = append(ret, KV{K: Clone(key), V: Clone(value)}) 113 if len(ret) == limit { 114 return false, nil 115 } 116 return true, nil 117 }) 118 return ret, err 119 } 120 121 func bitReverseForNibble(b byte) byte { 122 switch b { 123 case 0: 124 return 0 125 case 1: 126 return 8 127 case 2: 128 return 4 129 case 3: 130 return 12 131 case 4: 132 return 2 133 case 5: 134 return 10 135 case 6: 136 return 6 137 case 7: 138 return 14 139 case 8: 140 return 1 141 case 9: 142 return 9 143 case 10: 144 return 5 145 case 11: 146 return 13 147 case 12: 148 return 3 149 case 13: 150 return 11 151 case 14: 152 return 7 153 case 15: 154 return 15 155 default: 156 panic("unexpected input") 157 } 158 } 159 160 func bitReverse(inp []byte) (out []byte) { 161 162 l := len(inp) 163 out = make([]byte, l) 164 165 for i, b := range inp { 166 out[l-i-1] = bitReverseForNibble(b&15)<<4 + bitReverseForNibble(b>>4) 167 } 168 169 return 170 }