github.com/cryptotooltop/go-ethereum@v0.0.0-20231103184714-151d1922f3e5/trie/zk_trie_database.go (about)

     1  package trie
     2  
     3  import (
     4  	"math/big"
     5  
     6  	"github.com/syndtr/goleveldb/leveldb"
     7  
     8  	zktrie "github.com/scroll-tech/zktrie/trie"
     9  
    10  	"github.com/scroll-tech/go-ethereum/common"
    11  	"github.com/scroll-tech/go-ethereum/ethdb"
    12  )
    13  
    14  // ZktrieDatabase Database adaptor implements zktrie.ZktrieDatbase
    15  // It also reverses the bit order of the key being persisted.
    16  // This ensures that the adjacent leaf in zktrie maintains minimal
    17  // distance when persisted with dictionary order in LevelDB.
    18  // Consequently, this optimizes the snapshot operation, allowing it
    19  // to iterate through adjacent leaves at a reduced cost.
    20  
    21  type ZktrieDatabase struct {
    22  	db     *Database
    23  	prefix []byte
    24  }
    25  
    26  func NewZktrieDatabase(diskdb ethdb.KeyValueStore) *ZktrieDatabase {
    27  	return &ZktrieDatabase{db: NewDatabase(diskdb), prefix: []byte{}}
    28  }
    29  
    30  // adhoc wrapper...
    31  func NewZktrieDatabaseFromTriedb(db *Database) *ZktrieDatabase {
    32  	db.Zktrie = true
    33  	return &ZktrieDatabase{db: db, prefix: []byte{}}
    34  }
    35  
    36  // Put saves a key:value into the Storage
    37  func (l *ZktrieDatabase) Put(k, v []byte) error {
    38  	k = bitReverse(k)
    39  	l.db.lock.Lock()
    40  	l.db.rawDirties.Put(Concat(l.prefix, k[:]), v)
    41  	l.db.lock.Unlock()
    42  	return nil
    43  }
    44  
    45  // Get retrieves a value from a key in the Storage
    46  func (l *ZktrieDatabase) Get(key []byte) ([]byte, error) {
    47  	key = bitReverse(key)
    48  	concatKey := Concat(l.prefix, key[:])
    49  	l.db.lock.RLock()
    50  	value, ok := l.db.rawDirties.Get(concatKey)
    51  	l.db.lock.RUnlock()
    52  	if ok {
    53  		return value, nil
    54  	}
    55  
    56  	if l.db.cleans != nil {
    57  		if enc := l.db.cleans.Get(nil, concatKey); enc != nil {
    58  			memcacheCleanHitMeter.Mark(1)
    59  			memcacheCleanReadMeter.Mark(int64(len(enc)))
    60  			return enc, nil
    61  		}
    62  	}
    63  
    64  	v, err := l.db.diskdb.Get(concatKey)
    65  	if err == leveldb.ErrNotFound {
    66  		return nil, zktrie.ErrKeyNotFound
    67  	}
    68  	if l.db.cleans != nil {
    69  		l.db.cleans.Set(concatKey[:], v)
    70  		memcacheCleanMissMeter.Mark(1)
    71  		memcacheCleanWriteMeter.Mark(int64(len(v)))
    72  	}
    73  	return v, err
    74  }
    75  
    76  func (l *ZktrieDatabase) UpdatePreimage(preimage []byte, hashField *big.Int) {
    77  	db := l.db
    78  	if db.preimages != nil { // Ugly direct check but avoids the below write lock
    79  		// we must copy the input key
    80  		db.preimages.insertPreimage(map[common.Hash][]byte{common.BytesToHash(hashField.Bytes()): common.CopyBytes(preimage)})
    81  	}
    82  }
    83  
    84  // Iterate implements the method Iterate of the interface Storage
    85  func (l *ZktrieDatabase) Iterate(f func([]byte, []byte) (bool, error)) error {
    86  	iter := l.db.diskdb.NewIterator(l.prefix, nil)
    87  	defer iter.Release()
    88  	for iter.Next() {
    89  		localKey := bitReverse(iter.Key()[len(l.prefix):])
    90  		if cont, err := f(localKey, iter.Value()); err != nil {
    91  			return err
    92  		} else if !cont {
    93  			break
    94  		}
    95  	}
    96  	iter.Release()
    97  	return iter.Error()
    98  }
    99  
   100  // Close implements the method Close of the interface Storage
   101  func (l *ZktrieDatabase) Close() {
   102  	// FIXME: is this correct?
   103  	if err := l.db.diskdb.Close(); err != nil {
   104  		panic(err)
   105  	}
   106  }
   107  
   108  // List implements the method List of the interface Storage
   109  func (l *ZktrieDatabase) List(limit int) ([]KV, error) {
   110  	ret := []KV{}
   111  	err := l.Iterate(func(key []byte, value []byte) (bool, error) {
   112  		ret = append(ret, KV{K: Clone(key), V: Clone(value)})
   113  		if len(ret) == limit {
   114  			return false, nil
   115  		}
   116  		return true, nil
   117  	})
   118  	return ret, err
   119  }
   120  
   121  func bitReverseForNibble(b byte) byte {
   122  	switch b {
   123  	case 0:
   124  		return 0
   125  	case 1:
   126  		return 8
   127  	case 2:
   128  		return 4
   129  	case 3:
   130  		return 12
   131  	case 4:
   132  		return 2
   133  	case 5:
   134  		return 10
   135  	case 6:
   136  		return 6
   137  	case 7:
   138  		return 14
   139  	case 8:
   140  		return 1
   141  	case 9:
   142  		return 9
   143  	case 10:
   144  		return 5
   145  	case 11:
   146  		return 13
   147  	case 12:
   148  		return 3
   149  	case 13:
   150  		return 11
   151  	case 14:
   152  		return 7
   153  	case 15:
   154  		return 15
   155  	default:
   156  		panic("unexpected input")
   157  	}
   158  }
   159  
   160  func bitReverse(inp []byte) (out []byte) {
   161  
   162  	l := len(inp)
   163  	out = make([]byte, l)
   164  
   165  	for i, b := range inp {
   166  		out[l-i-1] = bitReverseForNibble(b&15)<<4 + bitReverseForNibble(b>>4)
   167  	}
   168  
   169  	return
   170  }