github.com/cryptotooltop/go-ethereum@v0.0.0-20231103184714-151d1922f3e5/trie/zk_trie_impl_test.go (about) 1 package trie 2 3 import ( 4 "math/big" 5 "testing" 6 7 "github.com/iden3/go-iden3-crypto/constants" 8 cryptoUtils "github.com/iden3/go-iden3-crypto/utils" 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/require" 11 12 zktrie "github.com/scroll-tech/zktrie/trie" 13 zkt "github.com/scroll-tech/zktrie/types" 14 15 "github.com/scroll-tech/go-ethereum/common" 16 "github.com/scroll-tech/go-ethereum/core/types" 17 "github.com/scroll-tech/go-ethereum/ethdb/memorydb" 18 ) 19 20 // we do not need zktrie impl anymore, only made a wrapper for adapting testing 21 type zkTrieImplTestWrapper struct { 22 *zktrie.ZkTrieImpl 23 } 24 25 func newZkTrieImpl(storage *ZktrieDatabase, maxLevels int) (*zkTrieImplTestWrapper, error) { 26 return newZkTrieImplWithRoot(storage, &zkt.HashZero, maxLevels) 27 } 28 29 // NewZkTrieImplWithRoot loads a new ZkTrieImpl. If in the storage already exists one 30 // will open that one, if not, will create a new one. 31 func newZkTrieImplWithRoot(storage *ZktrieDatabase, root *zkt.Hash, maxLevels int) (*zkTrieImplTestWrapper, error) { 32 impl, err := zktrie.NewZkTrieImplWithRoot(storage, root, maxLevels) 33 if err != nil { 34 return nil, err 35 } 36 37 return &zkTrieImplTestWrapper{impl}, nil 38 } 39 40 // AddWord 41 // Deprecated: Add a Bytes32 kv to ZkTrieImpl, only for testing 42 func (mt *zkTrieImplTestWrapper) AddWord(kPreimage, vPreimage *zkt.Byte32) error { 43 44 k, err := kPreimage.Hash() 45 if err != nil { 46 return err 47 } 48 49 if v, _ := mt.TryGet(k.Bytes()); v != nil { 50 return zktrie.ErrEntryIndexAlreadyExists 51 } 52 53 return mt.ZkTrieImpl.TryUpdate(zkt.NewHashFromBigInt(k), 1, []zkt.Byte32{*vPreimage}) 54 } 55 56 // GetLeafNodeByWord 57 // Deprecated: Get a Bytes32 kv to ZkTrieImpl, only for testing 58 func (mt *zkTrieImplTestWrapper) GetLeafNodeByWord(kPreimage *zkt.Byte32) (*zktrie.Node, error) { 59 k, err := kPreimage.Hash() 60 if err != nil { 61 return nil, err 62 } 63 return mt.ZkTrieImpl.GetLeafNode(zkt.NewHashFromBigInt(k)) 64 } 65 66 // Deprecated: only for testing 67 func (mt *zkTrieImplTestWrapper) UpdateWord(kPreimage, vPreimage *zkt.Byte32) error { 68 k, err := kPreimage.Hash() 69 if err != nil { 70 return err 71 } 72 73 return mt.ZkTrieImpl.TryUpdate(zkt.NewHashFromBigInt(k), 1, []zkt.Byte32{*vPreimage}) 74 } 75 76 // Deprecated: only for testing 77 func (mt *zkTrieImplTestWrapper) DeleteWord(kPreimage *zkt.Byte32) error { 78 k, err := kPreimage.Hash() 79 if err != nil { 80 return err 81 } 82 return mt.ZkTrieImpl.TryDelete(zkt.NewHashFromBigInt(k)) 83 } 84 85 func (mt *zkTrieImplTestWrapper) TryGet(key []byte) ([]byte, error) { 86 return mt.ZkTrieImpl.TryGet(zkt.NewHashFromBytes(key)) 87 } 88 89 func (mt *zkTrieImplTestWrapper) TryDelete(key []byte) error { 90 return mt.ZkTrieImpl.TryDelete(zkt.NewHashFromBytes(key)) 91 } 92 93 // TryUpdateAccount will abstract the write of an account to the trie 94 func (mt *zkTrieImplTestWrapper) TryUpdateAccount(key []byte, acc *types.StateAccount) error { 95 value, flag := acc.MarshalFields() 96 return mt.ZkTrieImpl.TryUpdate(zkt.NewHashFromBytes(key), flag, value) 97 } 98 99 // NewHashFromHex returns a *Hash representation of the given hex string 100 func NewHashFromHex(h string) (*zkt.Hash, error) { 101 return zkt.NewHashFromCheckedBytes(common.FromHex(h)) 102 } 103 104 type Fatalable interface { 105 Fatal(args ...interface{}) 106 } 107 108 func newTestingMerkle(f Fatalable, numLevels int) *zkTrieImplTestWrapper { 109 mt, err := newZkTrieImpl(NewZktrieDatabase((memorydb.New())), numLevels) 110 if err != nil { 111 f.Fatal(err) 112 return nil 113 } 114 return mt 115 } 116 117 func TestHashParsers(t *testing.T) { 118 h0 := zkt.NewHashFromBigInt(big.NewInt(0)) 119 assert.Equal(t, "0", h0.String()) 120 h1 := zkt.NewHashFromBigInt(big.NewInt(1)) 121 assert.Equal(t, "1", h1.String()) 122 h10 := zkt.NewHashFromBigInt(big.NewInt(10)) 123 assert.Equal(t, "10", h10.String()) 124 125 h7l := zkt.NewHashFromBigInt(big.NewInt(1234567)) 126 assert.Equal(t, "1234567", h7l.String()) 127 h8l := zkt.NewHashFromBigInt(big.NewInt(12345678)) 128 assert.Equal(t, "12345678...", h8l.String()) 129 130 b, ok := new(big.Int).SetString("4932297968297298434239270129193057052722409868268166443802652458940273154854", 10) //nolint:lll 131 assert.True(t, ok) 132 h := zkt.NewHashFromBigInt(b) 133 assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String()) //nolint:lll 134 assert.Equal(t, "49322979...", h.String()) 135 assert.Equal(t, "0ae794eb9c3d8bbb9002e993fc2ed301dcbd2af5508ed072c375e861f1aa5b26", h.Hex()) 136 137 b1, err := zkt.NewBigIntFromHashBytes(b.Bytes()) 138 assert.Nil(t, err) 139 assert.Equal(t, new(big.Int).SetBytes(b.Bytes()).String(), b1.String()) 140 141 b2, err := zkt.NewHashFromCheckedBytes(b.Bytes()) 142 assert.Nil(t, err) 143 assert.Equal(t, b.String(), b2.BigInt().String()) 144 145 h2, err := NewHashFromHex(h.Hex()) 146 assert.Nil(t, err) 147 assert.Equal(t, h, h2) 148 _, err = NewHashFromHex("0x12") 149 assert.NotNil(t, err) 150 151 // check limits 152 a := new(big.Int).Sub(constants.Q, big.NewInt(1)) 153 testHashParsers(t, a) 154 a = big.NewInt(int64(1)) 155 testHashParsers(t, a) 156 } 157 158 func testHashParsers(t *testing.T, a *big.Int) { 159 require.True(t, cryptoUtils.CheckBigIntInField(a)) 160 h := zkt.NewHashFromBigInt(a) 161 assert.Equal(t, a, h.BigInt()) 162 hFromBytes, err := zkt.NewHashFromCheckedBytes(h.Bytes()) 163 assert.Nil(t, err) 164 assert.Equal(t, h, hFromBytes) 165 assert.Equal(t, a, hFromBytes.BigInt()) 166 assert.Equal(t, a.String(), hFromBytes.BigInt().String()) 167 hFromHex, err := NewHashFromHex(h.Hex()) 168 assert.Nil(t, err) 169 assert.Equal(t, h, hFromHex) 170 171 aBIFromHBytes, err := zkt.NewBigIntFromHashBytes(h.Bytes()) 172 assert.Nil(t, err) 173 assert.Equal(t, a, aBIFromHBytes) 174 assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String()) 175 } 176 177 func TestMerkleTree_AddUpdateGetWord(t *testing.T) { 178 mt := newTestingMerkle(t, 10) 179 err := mt.AddWord(&zkt.Byte32{1}, &zkt.Byte32{2}) 180 assert.Nil(t, err) 181 err = mt.AddWord(&zkt.Byte32{3}, &zkt.Byte32{4}) 182 assert.Nil(t, err) 183 err = mt.AddWord(&zkt.Byte32{5}, &zkt.Byte32{6}) 184 assert.Nil(t, err) 185 err = mt.AddWord(&zkt.Byte32{5}, &zkt.Byte32{7}) 186 assert.Equal(t, zktrie.ErrEntryIndexAlreadyExists, err) 187 188 node, err := mt.GetLeafNodeByWord(&zkt.Byte32{1}) 189 assert.Nil(t, err) 190 assert.Equal(t, len(node.ValuePreimage), 1) 191 assert.Equal(t, (&zkt.Byte32{2})[:], node.ValuePreimage[0][:]) 192 node, err = mt.GetLeafNodeByWord(&zkt.Byte32{3}) 193 assert.Nil(t, err) 194 assert.Equal(t, len(node.ValuePreimage), 1) 195 assert.Equal(t, (&zkt.Byte32{4})[:], node.ValuePreimage[0][:]) 196 node, err = mt.GetLeafNodeByWord(&zkt.Byte32{5}) 197 assert.Nil(t, err) 198 assert.Equal(t, len(node.ValuePreimage), 1) 199 assert.Equal(t, (&zkt.Byte32{6})[:], node.ValuePreimage[0][:]) 200 201 err = mt.UpdateWord(&zkt.Byte32{1}, &zkt.Byte32{7}) 202 assert.Nil(t, err) 203 err = mt.UpdateWord(&zkt.Byte32{3}, &zkt.Byte32{8}) 204 assert.Nil(t, err) 205 err = mt.UpdateWord(&zkt.Byte32{5}, &zkt.Byte32{9}) 206 assert.Nil(t, err) 207 208 node, err = mt.GetLeafNodeByWord(&zkt.Byte32{1}) 209 assert.Nil(t, err) 210 assert.Equal(t, len(node.ValuePreimage), 1) 211 assert.Equal(t, (&zkt.Byte32{7})[:], node.ValuePreimage[0][:]) 212 node, err = mt.GetLeafNodeByWord(&zkt.Byte32{3}) 213 assert.Nil(t, err) 214 assert.Equal(t, len(node.ValuePreimage), 1) 215 assert.Equal(t, (&zkt.Byte32{8})[:], node.ValuePreimage[0][:]) 216 node, err = mt.GetLeafNodeByWord(&zkt.Byte32{5}) 217 assert.Nil(t, err) 218 assert.Equal(t, len(node.ValuePreimage), 1) 219 assert.Equal(t, (&zkt.Byte32{9})[:], node.ValuePreimage[0][:]) 220 _, err = mt.GetLeafNodeByWord(&zkt.Byte32{100}) 221 assert.Equal(t, zktrie.ErrKeyNotFound, err) 222 } 223 224 func TestMerkleTree_UpdateAccount(t *testing.T) { 225 226 mt := newTestingMerkle(t, 10) 227 228 acc1 := &types.StateAccount{ 229 Nonce: 1, 230 Balance: big.NewInt(10000000), 231 Root: common.HexToHash("22fb59aa5410ed465267023713ab42554c250f394901455a3366e223d5f7d147"), 232 KeccakCodeHash: common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(), 233 PoseidonCodeHash: common.HexToHash("0c0a77f6e063b4b62eb7d9ed6f427cf687d8d0071d751850cfe5d136bc60d3ab").Bytes(), 234 CodeSize: 0, 235 } 236 237 err := mt.TryUpdateAccount(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes(), acc1) 238 assert.Nil(t, err) 239 240 acc2 := &types.StateAccount{ 241 Nonce: 5, 242 Balance: big.NewInt(50000000), 243 Root: common.HexToHash("0"), 244 KeccakCodeHash: common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(), 245 PoseidonCodeHash: common.HexToHash("05d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(), 246 CodeSize: 5, 247 } 248 err = mt.TryUpdateAccount(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes(), acc2) 249 assert.Nil(t, err) 250 251 bt, err := mt.TryGet(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes()) 252 assert.Nil(t, err) 253 254 acc, err := types.UnmarshalStateAccount(bt) 255 assert.Nil(t, err) 256 assert.Equal(t, acc1.Nonce, acc.Nonce) 257 assert.Equal(t, acc1.Balance.Uint64(), acc.Balance.Uint64()) 258 assert.Equal(t, acc1.Root.Bytes(), acc.Root.Bytes()) 259 assert.Equal(t, acc1.KeccakCodeHash, acc.KeccakCodeHash) 260 assert.Equal(t, acc1.PoseidonCodeHash, acc.PoseidonCodeHash) 261 assert.Equal(t, acc1.CodeSize, acc.CodeSize) 262 263 bt, err = mt.TryGet(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes()) 264 assert.Nil(t, err) 265 266 acc, err = types.UnmarshalStateAccount(bt) 267 assert.Nil(t, err) 268 assert.Equal(t, acc2.Nonce, acc.Nonce) 269 assert.Equal(t, acc2.Balance.Uint64(), acc.Balance.Uint64()) 270 assert.Equal(t, acc2.Root.Bytes(), acc.Root.Bytes()) 271 assert.Equal(t, acc2.KeccakCodeHash, acc.KeccakCodeHash) 272 assert.Equal(t, acc2.PoseidonCodeHash, acc.PoseidonCodeHash) 273 assert.Equal(t, acc2.CodeSize, acc.CodeSize) 274 275 bt, err = mt.TryGet(common.HexToAddress("0x8dE13967F19410A7991D63c2c0179feBFDA0c261").Bytes()) 276 assert.Nil(t, err) 277 assert.Nil(t, bt) 278 279 err = mt.TryDelete(common.HexToHash("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes()) 280 assert.Nil(t, err) 281 282 bt, err = mt.TryGet(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes()) 283 assert.Nil(t, err) 284 assert.Nil(t, bt) 285 286 err = mt.TryDelete(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes()) 287 assert.Nil(t, err) 288 289 bt, err = mt.TryGet(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes()) 290 assert.Nil(t, err) 291 assert.Nil(t, bt) 292 }