github.com/cryptotooltop/go-ethereum@v0.0.0-20231103184714-151d1922f3e5/trie/zk_trie.go (about) 1 // Copyright 2015 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package trie 18 19 import ( 20 "fmt" 21 22 zktrie "github.com/scroll-tech/zktrie/trie" 23 zkt "github.com/scroll-tech/zktrie/types" 24 25 "github.com/scroll-tech/go-ethereum/common" 26 "github.com/scroll-tech/go-ethereum/core/types" 27 "github.com/scroll-tech/go-ethereum/crypto/poseidon" 28 "github.com/scroll-tech/go-ethereum/ethdb" 29 "github.com/scroll-tech/go-ethereum/log" 30 ) 31 32 var magicHash []byte = []byte("THIS IS THE MAGIC INDEX FOR ZKTRIE") 33 34 // wrap zktrie for trie interface 35 type ZkTrie struct { 36 *zktrie.ZkTrie 37 db *ZktrieDatabase 38 } 39 40 func init() { 41 zkt.InitHashScheme(poseidon.HashFixedWithDomain) 42 } 43 44 func sanityCheckByte32Key(b []byte) { 45 if len(b) != 32 && len(b) != 20 { 46 panic(fmt.Errorf("do not support length except for 120bit and 256bit now. data: %v len: %v", b, len(b))) 47 } 48 } 49 50 // NewZkTrie creates a trie 51 // NewZkTrie bypasses all the buffer mechanism in *Database, it directly uses the 52 // underlying diskdb 53 func NewZkTrie(root common.Hash, db *ZktrieDatabase) (*ZkTrie, error) { 54 tr, err := zktrie.NewZkTrie(*zkt.NewByte32FromBytes(root.Bytes()), db) 55 if err != nil { 56 return nil, err 57 } 58 return &ZkTrie{tr, db}, nil 59 } 60 61 // Get returns the value for key stored in the trie. 62 // The value bytes must not be modified by the caller. 63 func (t *ZkTrie) Get(key []byte) []byte { 64 sanityCheckByte32Key(key) 65 res, err := t.TryGet(key) 66 if err != nil { 67 log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) 68 } 69 return res 70 } 71 72 // TryUpdateAccount will abstract the write of an account to the 73 // secure trie. 74 func (t *ZkTrie) TryUpdateAccount(key []byte, acc *types.StateAccount) error { 75 sanityCheckByte32Key(key) 76 value, flag := acc.MarshalFields() 77 return t.ZkTrie.TryUpdate(key, flag, value) 78 } 79 80 // Update associates key with value in the trie. Subsequent calls to 81 // Get will return value. If value has length zero, any existing value 82 // is deleted from the trie and calls to Get will return nil. 83 // 84 // The value bytes must not be modified by the caller while they are 85 // stored in the trie. 86 func (t *ZkTrie) Update(key, value []byte) { 87 if err := t.TryUpdate(key, value); err != nil { 88 log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) 89 } 90 } 91 92 // NOTE: value is restricted to length of bytes32. 93 // we override the underlying zktrie's TryUpdate method 94 func (t *ZkTrie) TryUpdate(key, value []byte) error { 95 sanityCheckByte32Key(key) 96 return t.ZkTrie.TryUpdate(key, 1, []zkt.Byte32{*zkt.NewByte32FromBytes(value)}) 97 } 98 99 // Delete removes any existing value for key from the trie. 100 func (t *ZkTrie) Delete(key []byte) { 101 sanityCheckByte32Key(key) 102 if err := t.TryDelete(key); err != nil { 103 log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) 104 } 105 } 106 107 // GetKey returns the preimage of a hashed key that was 108 // previously used to store a value. 109 func (t *ZkTrie) GetKey(kHashBytes []byte) []byte { 110 // TODO: use a kv cache in memory 111 k, err := zkt.NewBigIntFromHashBytes(kHashBytes) 112 if err != nil { 113 log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) 114 } 115 if t.db.db.preimages != nil { 116 return t.db.db.preimages.preimage(common.BytesToHash(k.Bytes())) 117 } 118 return nil 119 } 120 121 // Commit writes all nodes and the secure hash pre-images to the trie's database. 122 // Nodes are stored with their sha3 hash as the key. 123 // 124 // Committing flushes nodes from memory. Subsequent Get calls will load nodes 125 // from the database. 126 func (t *ZkTrie) Commit(LeafCallback) (common.Hash, int, error) { 127 // in current implmentation, every update of trie already writes into database 128 // so Commmit does nothing 129 return t.Hash(), 0, nil 130 } 131 132 // Hash returns the root hash of SecureBinaryTrie. It does not write to the 133 // database and can be used even if the trie doesn't have one. 134 func (t *ZkTrie) Hash() common.Hash { 135 var hash common.Hash 136 hash.SetBytes(t.ZkTrie.Hash()) 137 return hash 138 } 139 140 // Copy returns a copy of SecureBinaryTrie. 141 func (t *ZkTrie) Copy() *ZkTrie { 142 return &ZkTrie{t.ZkTrie.Copy(), t.db} 143 } 144 145 // NodeIterator returns an iterator that returns nodes of the underlying trie. Iteration 146 // starts at the key after the given start key. 147 func (t *ZkTrie) NodeIterator(start []byte) NodeIterator { 148 /// FIXME 149 panic("not implemented") 150 } 151 152 // hashKey returns the hash of key as an ephemeral buffer. 153 // The caller must not hold onto the return value because it will become 154 // invalid on the next call to hashKey or secKey. 155 /*func (t *ZkTrie) hashKey(key []byte) []byte { 156 if len(key) != 32 { 157 panic("non byte32 input to hashKey") 158 } 159 low16 := new(big.Int).SetBytes(key[:16]) 160 high16 := new(big.Int).SetBytes(key[16:]) 161 hash, err := poseidon.Hash([]*big.Int{low16, high16}) 162 if err != nil { 163 panic(err) 164 } 165 return hash.Bytes() 166 } 167 */ 168 169 // Prove constructs a merkle proof for key. The result contains all encoded nodes 170 // on the path to the value at key. The value itself is also included in the last 171 // node and can be retrieved by verifying the proof. 172 // 173 // If the trie does not contain a value for key, the returned proof contains all 174 // nodes of the longest existing prefix of the key (at least the root node), ending 175 // with the node that proves the absence of the key. 176 func (t *ZkTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error { 177 err := t.ZkTrie.Prove(key, fromLevel, func(n *zktrie.Node) error { 178 nodeHash, err := n.NodeHash() 179 if err != nil { 180 return err 181 } 182 183 if n.Type == zktrie.NodeTypeLeaf_New { 184 preImage := t.GetKey(n.NodeKey.Bytes()) 185 if len(preImage) > 0 { 186 n.KeyPreimage = &zkt.Byte32{} 187 copy(n.KeyPreimage[:], preImage) 188 //return fmt.Errorf("key preimage not found for [%x] ref %x", n.NodeKey.Bytes(), k.Bytes()) 189 } 190 } 191 return proofDb.Put(nodeHash[:], n.Value()) 192 }) 193 if err != nil { 194 return err 195 } 196 197 // we put this special kv pair in db so we can distinguish the type and 198 // make suitable Proof 199 return proofDb.Put(magicHash, zktrie.ProofMagicBytes()) 200 } 201 202 // VerifyProof checks merkle proofs. The given proof must contain the value for 203 // key in a trie with the given root hash. VerifyProof returns an error if the 204 // proof contains invalid trie nodes or the wrong value. 205 func VerifyProofSMT(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) { 206 207 h := zkt.NewHashFromBytes(rootHash.Bytes()) 208 k, err := zkt.ToSecureKey(key) 209 if err != nil { 210 return nil, err 211 } 212 213 proof, n, err := zktrie.BuildZkTrieProof(h, k, len(key)*8, func(key *zkt.Hash) (*zktrie.Node, error) { 214 buf, _ := proofDb.Get(key[:]) 215 if buf == nil { 216 return nil, zktrie.ErrKeyNotFound 217 } 218 n, err := zktrie.NewNodeFromBytes(buf) 219 return n, err 220 }) 221 222 if err != nil { 223 // do not contain the key 224 return nil, err 225 } else if !proof.Existence { 226 return nil, nil 227 } 228 229 if zktrie.VerifyProofZkTrie(h, proof, n) { 230 return n.Data(), nil 231 } else { 232 return nil, fmt.Errorf("bad proof node %v", proof) 233 } 234 }