github.com/cryptotooltop/go-ethereum@v0.0.0-20231103184714-151d1922f3e5/trie/zk_trie_impl_test.go (about)

     1  package trie
     2  
     3  import (
     4  	"math/big"
     5  	"testing"
     6  
     7  	"github.com/iden3/go-iden3-crypto/constants"
     8  	cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/require"
    11  
    12  	zktrie "github.com/scroll-tech/zktrie/trie"
    13  	zkt "github.com/scroll-tech/zktrie/types"
    14  
    15  	"github.com/scroll-tech/go-ethereum/common"
    16  	"github.com/scroll-tech/go-ethereum/core/types"
    17  	"github.com/scroll-tech/go-ethereum/ethdb/memorydb"
    18  )
    19  
    20  // we do not need zktrie impl anymore, only made a wrapper for adapting testing
    21  type zkTrieImplTestWrapper struct {
    22  	*zktrie.ZkTrieImpl
    23  }
    24  
    25  func newZkTrieImpl(storage *ZktrieDatabase, maxLevels int) (*zkTrieImplTestWrapper, error) {
    26  	return newZkTrieImplWithRoot(storage, &zkt.HashZero, maxLevels)
    27  }
    28  
    29  // NewZkTrieImplWithRoot loads a new ZkTrieImpl. If in the storage already exists one
    30  // will open that one, if not, will create a new one.
    31  func newZkTrieImplWithRoot(storage *ZktrieDatabase, root *zkt.Hash, maxLevels int) (*zkTrieImplTestWrapper, error) {
    32  	impl, err := zktrie.NewZkTrieImplWithRoot(storage, root, maxLevels)
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  
    37  	return &zkTrieImplTestWrapper{impl}, nil
    38  }
    39  
    40  // AddWord
    41  // Deprecated: Add a Bytes32 kv to ZkTrieImpl, only for testing
    42  func (mt *zkTrieImplTestWrapper) AddWord(kPreimage, vPreimage *zkt.Byte32) error {
    43  
    44  	k, err := kPreimage.Hash()
    45  	if err != nil {
    46  		return err
    47  	}
    48  
    49  	if v, _ := mt.TryGet(k.Bytes()); v != nil {
    50  		return zktrie.ErrEntryIndexAlreadyExists
    51  	}
    52  
    53  	return mt.ZkTrieImpl.TryUpdate(zkt.NewHashFromBigInt(k), 1, []zkt.Byte32{*vPreimage})
    54  }
    55  
    56  // GetLeafNodeByWord
    57  // Deprecated: Get a Bytes32 kv to ZkTrieImpl, only for testing
    58  func (mt *zkTrieImplTestWrapper) GetLeafNodeByWord(kPreimage *zkt.Byte32) (*zktrie.Node, error) {
    59  	k, err := kPreimage.Hash()
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  	return mt.ZkTrieImpl.GetLeafNode(zkt.NewHashFromBigInt(k))
    64  }
    65  
    66  // Deprecated: only for testing
    67  func (mt *zkTrieImplTestWrapper) UpdateWord(kPreimage, vPreimage *zkt.Byte32) error {
    68  	k, err := kPreimage.Hash()
    69  	if err != nil {
    70  		return err
    71  	}
    72  
    73  	return mt.ZkTrieImpl.TryUpdate(zkt.NewHashFromBigInt(k), 1, []zkt.Byte32{*vPreimage})
    74  }
    75  
    76  // Deprecated: only for testing
    77  func (mt *zkTrieImplTestWrapper) DeleteWord(kPreimage *zkt.Byte32) error {
    78  	k, err := kPreimage.Hash()
    79  	if err != nil {
    80  		return err
    81  	}
    82  	return mt.ZkTrieImpl.TryDelete(zkt.NewHashFromBigInt(k))
    83  }
    84  
    85  func (mt *zkTrieImplTestWrapper) TryGet(key []byte) ([]byte, error) {
    86  	return mt.ZkTrieImpl.TryGet(zkt.NewHashFromBytes(key))
    87  }
    88  
    89  func (mt *zkTrieImplTestWrapper) TryDelete(key []byte) error {
    90  	return mt.ZkTrieImpl.TryDelete(zkt.NewHashFromBytes(key))
    91  }
    92  
    93  // TryUpdateAccount will abstract the write of an account to the trie
    94  func (mt *zkTrieImplTestWrapper) TryUpdateAccount(key []byte, acc *types.StateAccount) error {
    95  	value, flag := acc.MarshalFields()
    96  	return mt.ZkTrieImpl.TryUpdate(zkt.NewHashFromBytes(key), flag, value)
    97  }
    98  
    99  // NewHashFromHex returns a *Hash representation of the given hex string
   100  func NewHashFromHex(h string) (*zkt.Hash, error) {
   101  	return zkt.NewHashFromCheckedBytes(common.FromHex(h))
   102  }
   103  
   104  type Fatalable interface {
   105  	Fatal(args ...interface{})
   106  }
   107  
   108  func newTestingMerkle(f Fatalable, numLevels int) *zkTrieImplTestWrapper {
   109  	mt, err := newZkTrieImpl(NewZktrieDatabase((memorydb.New())), numLevels)
   110  	if err != nil {
   111  		f.Fatal(err)
   112  		return nil
   113  	}
   114  	return mt
   115  }
   116  
   117  func TestHashParsers(t *testing.T) {
   118  	h0 := zkt.NewHashFromBigInt(big.NewInt(0))
   119  	assert.Equal(t, "0", h0.String())
   120  	h1 := zkt.NewHashFromBigInt(big.NewInt(1))
   121  	assert.Equal(t, "1", h1.String())
   122  	h10 := zkt.NewHashFromBigInt(big.NewInt(10))
   123  	assert.Equal(t, "10", h10.String())
   124  
   125  	h7l := zkt.NewHashFromBigInt(big.NewInt(1234567))
   126  	assert.Equal(t, "1234567", h7l.String())
   127  	h8l := zkt.NewHashFromBigInt(big.NewInt(12345678))
   128  	assert.Equal(t, "12345678...", h8l.String())
   129  
   130  	b, ok := new(big.Int).SetString("4932297968297298434239270129193057052722409868268166443802652458940273154854", 10) //nolint:lll
   131  	assert.True(t, ok)
   132  	h := zkt.NewHashFromBigInt(b)
   133  	assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String()) //nolint:lll
   134  	assert.Equal(t, "49322979...", h.String())
   135  	assert.Equal(t, "0ae794eb9c3d8bbb9002e993fc2ed301dcbd2af5508ed072c375e861f1aa5b26", h.Hex())
   136  
   137  	b1, err := zkt.NewBigIntFromHashBytes(b.Bytes())
   138  	assert.Nil(t, err)
   139  	assert.Equal(t, new(big.Int).SetBytes(b.Bytes()).String(), b1.String())
   140  
   141  	b2, err := zkt.NewHashFromCheckedBytes(b.Bytes())
   142  	assert.Nil(t, err)
   143  	assert.Equal(t, b.String(), b2.BigInt().String())
   144  
   145  	h2, err := NewHashFromHex(h.Hex())
   146  	assert.Nil(t, err)
   147  	assert.Equal(t, h, h2)
   148  	_, err = NewHashFromHex("0x12")
   149  	assert.NotNil(t, err)
   150  
   151  	// check limits
   152  	a := new(big.Int).Sub(constants.Q, big.NewInt(1))
   153  	testHashParsers(t, a)
   154  	a = big.NewInt(int64(1))
   155  	testHashParsers(t, a)
   156  }
   157  
   158  func testHashParsers(t *testing.T, a *big.Int) {
   159  	require.True(t, cryptoUtils.CheckBigIntInField(a))
   160  	h := zkt.NewHashFromBigInt(a)
   161  	assert.Equal(t, a, h.BigInt())
   162  	hFromBytes, err := zkt.NewHashFromCheckedBytes(h.Bytes())
   163  	assert.Nil(t, err)
   164  	assert.Equal(t, h, hFromBytes)
   165  	assert.Equal(t, a, hFromBytes.BigInt())
   166  	assert.Equal(t, a.String(), hFromBytes.BigInt().String())
   167  	hFromHex, err := NewHashFromHex(h.Hex())
   168  	assert.Nil(t, err)
   169  	assert.Equal(t, h, hFromHex)
   170  
   171  	aBIFromHBytes, err := zkt.NewBigIntFromHashBytes(h.Bytes())
   172  	assert.Nil(t, err)
   173  	assert.Equal(t, a, aBIFromHBytes)
   174  	assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String())
   175  }
   176  
   177  func TestMerkleTree_AddUpdateGetWord(t *testing.T) {
   178  	mt := newTestingMerkle(t, 10)
   179  	err := mt.AddWord(&zkt.Byte32{1}, &zkt.Byte32{2})
   180  	assert.Nil(t, err)
   181  	err = mt.AddWord(&zkt.Byte32{3}, &zkt.Byte32{4})
   182  	assert.Nil(t, err)
   183  	err = mt.AddWord(&zkt.Byte32{5}, &zkt.Byte32{6})
   184  	assert.Nil(t, err)
   185  	err = mt.AddWord(&zkt.Byte32{5}, &zkt.Byte32{7})
   186  	assert.Equal(t, zktrie.ErrEntryIndexAlreadyExists, err)
   187  
   188  	node, err := mt.GetLeafNodeByWord(&zkt.Byte32{1})
   189  	assert.Nil(t, err)
   190  	assert.Equal(t, len(node.ValuePreimage), 1)
   191  	assert.Equal(t, (&zkt.Byte32{2})[:], node.ValuePreimage[0][:])
   192  	node, err = mt.GetLeafNodeByWord(&zkt.Byte32{3})
   193  	assert.Nil(t, err)
   194  	assert.Equal(t, len(node.ValuePreimage), 1)
   195  	assert.Equal(t, (&zkt.Byte32{4})[:], node.ValuePreimage[0][:])
   196  	node, err = mt.GetLeafNodeByWord(&zkt.Byte32{5})
   197  	assert.Nil(t, err)
   198  	assert.Equal(t, len(node.ValuePreimage), 1)
   199  	assert.Equal(t, (&zkt.Byte32{6})[:], node.ValuePreimage[0][:])
   200  
   201  	err = mt.UpdateWord(&zkt.Byte32{1}, &zkt.Byte32{7})
   202  	assert.Nil(t, err)
   203  	err = mt.UpdateWord(&zkt.Byte32{3}, &zkt.Byte32{8})
   204  	assert.Nil(t, err)
   205  	err = mt.UpdateWord(&zkt.Byte32{5}, &zkt.Byte32{9})
   206  	assert.Nil(t, err)
   207  
   208  	node, err = mt.GetLeafNodeByWord(&zkt.Byte32{1})
   209  	assert.Nil(t, err)
   210  	assert.Equal(t, len(node.ValuePreimage), 1)
   211  	assert.Equal(t, (&zkt.Byte32{7})[:], node.ValuePreimage[0][:])
   212  	node, err = mt.GetLeafNodeByWord(&zkt.Byte32{3})
   213  	assert.Nil(t, err)
   214  	assert.Equal(t, len(node.ValuePreimage), 1)
   215  	assert.Equal(t, (&zkt.Byte32{8})[:], node.ValuePreimage[0][:])
   216  	node, err = mt.GetLeafNodeByWord(&zkt.Byte32{5})
   217  	assert.Nil(t, err)
   218  	assert.Equal(t, len(node.ValuePreimage), 1)
   219  	assert.Equal(t, (&zkt.Byte32{9})[:], node.ValuePreimage[0][:])
   220  	_, err = mt.GetLeafNodeByWord(&zkt.Byte32{100})
   221  	assert.Equal(t, zktrie.ErrKeyNotFound, err)
   222  }
   223  
   224  func TestMerkleTree_UpdateAccount(t *testing.T) {
   225  
   226  	mt := newTestingMerkle(t, 10)
   227  
   228  	acc1 := &types.StateAccount{
   229  		Nonce:            1,
   230  		Balance:          big.NewInt(10000000),
   231  		Root:             common.HexToHash("22fb59aa5410ed465267023713ab42554c250f394901455a3366e223d5f7d147"),
   232  		KeccakCodeHash:   common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(),
   233  		PoseidonCodeHash: common.HexToHash("0c0a77f6e063b4b62eb7d9ed6f427cf687d8d0071d751850cfe5d136bc60d3ab").Bytes(),
   234  		CodeSize:         0,
   235  	}
   236  
   237  	err := mt.TryUpdateAccount(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes(), acc1)
   238  	assert.Nil(t, err)
   239  
   240  	acc2 := &types.StateAccount{
   241  		Nonce:            5,
   242  		Balance:          big.NewInt(50000000),
   243  		Root:             common.HexToHash("0"),
   244  		KeccakCodeHash:   common.HexToHash("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(),
   245  		PoseidonCodeHash: common.HexToHash("05d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").Bytes(),
   246  		CodeSize:         5,
   247  	}
   248  	err = mt.TryUpdateAccount(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes(), acc2)
   249  	assert.Nil(t, err)
   250  
   251  	bt, err := mt.TryGet(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes())
   252  	assert.Nil(t, err)
   253  
   254  	acc, err := types.UnmarshalStateAccount(bt)
   255  	assert.Nil(t, err)
   256  	assert.Equal(t, acc1.Nonce, acc.Nonce)
   257  	assert.Equal(t, acc1.Balance.Uint64(), acc.Balance.Uint64())
   258  	assert.Equal(t, acc1.Root.Bytes(), acc.Root.Bytes())
   259  	assert.Equal(t, acc1.KeccakCodeHash, acc.KeccakCodeHash)
   260  	assert.Equal(t, acc1.PoseidonCodeHash, acc.PoseidonCodeHash)
   261  	assert.Equal(t, acc1.CodeSize, acc.CodeSize)
   262  
   263  	bt, err = mt.TryGet(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes())
   264  	assert.Nil(t, err)
   265  
   266  	acc, err = types.UnmarshalStateAccount(bt)
   267  	assert.Nil(t, err)
   268  	assert.Equal(t, acc2.Nonce, acc.Nonce)
   269  	assert.Equal(t, acc2.Balance.Uint64(), acc.Balance.Uint64())
   270  	assert.Equal(t, acc2.Root.Bytes(), acc.Root.Bytes())
   271  	assert.Equal(t, acc2.KeccakCodeHash, acc.KeccakCodeHash)
   272  	assert.Equal(t, acc2.PoseidonCodeHash, acc.PoseidonCodeHash)
   273  	assert.Equal(t, acc2.CodeSize, acc.CodeSize)
   274  
   275  	bt, err = mt.TryGet(common.HexToAddress("0x8dE13967F19410A7991D63c2c0179feBFDA0c261").Bytes())
   276  	assert.Nil(t, err)
   277  	assert.Nil(t, bt)
   278  
   279  	err = mt.TryDelete(common.HexToHash("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes())
   280  	assert.Nil(t, err)
   281  
   282  	bt, err = mt.TryGet(common.HexToAddress("0x05fDbDfaE180345C6Cff5316c286727CF1a43327").Bytes())
   283  	assert.Nil(t, err)
   284  	assert.Nil(t, bt)
   285  
   286  	err = mt.TryDelete(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes())
   287  	assert.Nil(t, err)
   288  
   289  	bt, err = mt.TryGet(common.HexToAddress("0x4cb1aB63aF5D8931Ce09673EbD8ae2ce16fD6571").Bytes())
   290  	assert.Nil(t, err)
   291  	assert.Nil(t, bt)
   292  }