github.com/cryptotooltop/go-ethereum@v0.0.0-20231103184714-151d1922f3e5/trie/zk_trie_proof_test.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package trie
    18  
    19  import (
    20  	"bytes"
    21  	mrand "math/rand"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/stretchr/testify/assert"
    26  
    27  	zkt "github.com/scroll-tech/zktrie/types"
    28  
    29  	"github.com/scroll-tech/go-ethereum/common"
    30  	"github.com/scroll-tech/go-ethereum/crypto"
    31  	"github.com/scroll-tech/go-ethereum/ethdb/memorydb"
    32  )
    33  
    34  func init() {
    35  	mrand.Seed(time.Now().Unix())
    36  }
    37  
    38  // makeProvers creates Merkle trie provers based on different implementations to
    39  // test all variations.
    40  func makeSMTProvers(mt *ZkTrie) []func(key []byte) *memorydb.Database {
    41  	var provers []func(key []byte) *memorydb.Database
    42  
    43  	// Create a direct trie based Merkle prover
    44  	provers = append(provers, func(key []byte) *memorydb.Database {
    45  		word := zkt.NewByte32FromBytesPaddingZero(key)
    46  		k, err := word.Hash()
    47  		if err != nil {
    48  			panic(err)
    49  		}
    50  		proof := memorydb.New()
    51  		err = mt.Prove(common.BytesToHash(k.Bytes()).Bytes(), 0, proof)
    52  		if err != nil {
    53  			panic(err)
    54  		}
    55  
    56  		return proof
    57  	})
    58  	return provers
    59  }
    60  
    61  func verifyValue(proveVal []byte, vPreimage []byte) bool {
    62  	return bytes.Equal(proveVal, vPreimage)
    63  }
    64  
    65  func TestSMTOneElementProof(t *testing.T) {
    66  	tr, _ := NewZkTrie(common.Hash{}, NewZktrieDatabase((memorydb.New())))
    67  	mt := &zkTrieImplTestWrapper{tr.Tree()}
    68  	err := mt.UpdateWord(
    69  		zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("k"), 32)),
    70  		zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32)),
    71  	)
    72  	assert.Nil(t, err)
    73  	for i, prover := range makeSMTProvers(tr) {
    74  		keyBytes := bytes.Repeat([]byte("k"), 32)
    75  		proof := prover(keyBytes)
    76  		if proof == nil {
    77  			t.Fatalf("prover %d: nil proof", i)
    78  		}
    79  		if proof.Len() != 2 {
    80  			t.Errorf("prover %d: proof should have 1+1 element (including the magic kv)", i)
    81  		}
    82  		val, err := VerifyProof(common.BytesToHash(mt.Root().Bytes()), keyBytes, proof)
    83  		if err != nil {
    84  			t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
    85  		}
    86  		if !verifyValue(val, bytes.Repeat([]byte("v"), 32)) {
    87  			t.Fatalf("prover %d: verified value mismatch: want 'v' get %x", i, val)
    88  		}
    89  	}
    90  }
    91  
    92  func TestSMTProof(t *testing.T) {
    93  	mt, vals := randomZktrie(t, 500)
    94  	root := mt.Tree().Root()
    95  	for i, prover := range makeSMTProvers(mt) {
    96  		for _, kv := range vals {
    97  			proof := prover(kv.k)
    98  			if proof == nil {
    99  				t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k)
   100  			}
   101  			val, err := VerifyProof(common.BytesToHash(root.Bytes()), kv.k, proof)
   102  			if err != nil {
   103  				t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x\n", i, kv.k, err, proof)
   104  			}
   105  			if !verifyValue(val, zkt.NewByte32FromBytesPaddingZero(kv.v)[:]) {
   106  				t.Fatalf("prover %d: verified value mismatch for key %x, want %x, get %x", i, kv.k, kv.v, val)
   107  			}
   108  		}
   109  	}
   110  }
   111  
   112  func TestSMTBadProof(t *testing.T) {
   113  	mt, vals := randomZktrie(t, 500)
   114  	root := mt.Tree().Root()
   115  	for i, prover := range makeSMTProvers(mt) {
   116  		for _, kv := range vals {
   117  			proof := prover(kv.k)
   118  			if proof == nil {
   119  				t.Fatalf("prover %d: nil proof", i)
   120  			}
   121  			it := proof.NewIterator(nil, nil)
   122  			for i, d := 0, mrand.Intn(proof.Len()); i <= d; i++ {
   123  				it.Next()
   124  			}
   125  			key := it.Key()
   126  			val, _ := proof.Get(key)
   127  			proof.Delete(key)
   128  			it.Release()
   129  
   130  			mutateByte(val)
   131  			proof.Put(crypto.Keccak256(val), val)
   132  
   133  			if _, err := VerifyProof(common.BytesToHash(root.Bytes()), kv.k, proof); err == nil {
   134  				t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k)
   135  			}
   136  		}
   137  	}
   138  }
   139  
   140  // Tests that missing keys can also be proven. The test explicitly uses a single
   141  // entry trie and checks for missing keys both before and after the single entry.
   142  func TestSMTMissingKeyProof(t *testing.T) {
   143  	tr, _ := NewZkTrie(common.Hash{}, NewZktrieDatabase((memorydb.New())))
   144  	mt := &zkTrieImplTestWrapper{tr.Tree()}
   145  	err := mt.UpdateWord(
   146  		zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("k"), 32)),
   147  		zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32)),
   148  	)
   149  	assert.Nil(t, err)
   150  
   151  	prover := makeSMTProvers(tr)[0]
   152  
   153  	for i, key := range []string{"a", "j", "l", "z"} {
   154  		keyBytes := bytes.Repeat([]byte(key), 32)
   155  		proof := prover(keyBytes)
   156  
   157  		if proof.Len() != 2 {
   158  			t.Errorf("test %d: proof should have 2 element (with magic kv)", i)
   159  		}
   160  		val, err := VerifyProof(common.BytesToHash(mt.Root().Bytes()), keyBytes, proof)
   161  		if err != nil {
   162  			t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
   163  		}
   164  		if val != nil {
   165  			t.Fatalf("test %d: verified value mismatch: have %x, want nil", i, val)
   166  		}
   167  	}
   168  }
   169  
   170  func randomZktrie(t *testing.T, n int) (*ZkTrie, map[string]*kv) {
   171  	tr, err := NewZkTrie(common.Hash{}, NewZktrieDatabase((memorydb.New())))
   172  	if err != nil {
   173  		panic(err)
   174  	}
   175  	mt := &zkTrieImplTestWrapper{tr.Tree()}
   176  	vals := make(map[string]*kv)
   177  	for i := byte(0); i < 100; i++ {
   178  
   179  		value := &kv{common.LeftPadBytes([]byte{i}, 32), bytes.Repeat([]byte{i}, 32), false}
   180  		value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), bytes.Repeat([]byte{i}, 32), false}
   181  
   182  		err = mt.UpdateWord(zkt.NewByte32FromBytesPaddingZero(value.k), zkt.NewByte32FromBytesPaddingZero(value.v))
   183  		assert.Nil(t, err)
   184  		err = mt.UpdateWord(zkt.NewByte32FromBytesPaddingZero(value2.k), zkt.NewByte32FromBytesPaddingZero(value2.v))
   185  		assert.Nil(t, err)
   186  		vals[string(value.k)] = value
   187  		vals[string(value2.k)] = value2
   188  	}
   189  	for i := 0; i < n; i++ {
   190  		value := &kv{randBytes(32), randBytes(20), false}
   191  		err = mt.UpdateWord(zkt.NewByte32FromBytesPaddingZero(value.k), zkt.NewByte32FromBytesPaddingZero(value.v))
   192  		assert.Nil(t, err)
   193  		vals[string(value.k)] = value
   194  	}
   195  
   196  	return tr, vals
   197  }
   198  
   199  // Tests that new "proof trace" feature
   200  func TestProofWithDeletion(t *testing.T) {
   201  	tr, _ := NewZkTrie(common.Hash{}, NewZktrieDatabase((memorydb.New())))
   202  	mt := &zkTrieImplTestWrapper{tr.Tree()}
   203  	key1 := bytes.Repeat([]byte("l"), 32)
   204  	key2 := bytes.Repeat([]byte("m"), 32)
   205  	err := mt.UpdateWord(
   206  		zkt.NewByte32FromBytesPaddingZero(key1),
   207  		zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32)),
   208  	)
   209  	assert.NoError(t, err)
   210  	err = mt.UpdateWord(
   211  		zkt.NewByte32FromBytesPaddingZero(key2),
   212  		zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("n"), 32)),
   213  	)
   214  	assert.NoError(t, err)
   215  
   216  	proof := memorydb.New()
   217  	s_key1, err := zkt.ToSecureKeyBytes(key1)
   218  	assert.NoError(t, err)
   219  
   220  	proofTracer := tr.NewProofTracer()
   221  
   222  	err = proofTracer.Prove(s_key1.Bytes(), 0, proof)
   223  	assert.NoError(t, err)
   224  	nd, err := tr.TryGet(key2)
   225  	assert.NoError(t, err)
   226  
   227  	s_key2, err := zkt.ToSecureKeyBytes(bytes.Repeat([]byte("x"), 32))
   228  	assert.NoError(t, err)
   229  
   230  	err = proofTracer.Prove(s_key2.Bytes(), 0, proof)
   231  	assert.NoError(t, err)
   232  	//assert.Equal(t, len(sibling1), len(delTracer.GetProofs()))
   233  
   234  	siblings, err := proofTracer.GetDeletionProofs()
   235  	assert.NoError(t, err)
   236  	assert.Equal(t, 0, len(siblings))
   237  
   238  	proofTracer.MarkDeletion(s_key1.Bytes())
   239  	siblings, err = proofTracer.GetDeletionProofs()
   240  	assert.NoError(t, err)
   241  	assert.Equal(t, 1, len(siblings))
   242  	l := len(siblings[0])
   243  	// a hacking to grep the value part directly from the encoded leaf node,
   244  	// notice the sibling of key `k*32`` is just the leaf of key `m*32`
   245  	assert.Equal(t, siblings[0][l-33:l-1], nd)
   246  
   247  	// Marking a key that is currently not hit (but terminated by an empty node)
   248  	// also causes it to be added to the deletion proof
   249  	proofTracer.MarkDeletion(s_key2.Bytes())
   250  	siblings, err = proofTracer.GetDeletionProofs()
   251  	assert.NoError(t, err)
   252  	assert.Equal(t, 2, len(siblings))
   253  
   254  	key3 := bytes.Repeat([]byte("x"), 32)
   255  	err = mt.UpdateWord(
   256  		zkt.NewByte32FromBytesPaddingZero(key3),
   257  		zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("z"), 32)),
   258  	)
   259  	assert.NoError(t, err)
   260  
   261  	proofTracer = tr.NewProofTracer()
   262  	err = proofTracer.Prove(s_key1.Bytes(), 0, proof)
   263  	assert.NoError(t, err)
   264  	err = proofTracer.Prove(s_key2.Bytes(), 0, proof)
   265  	assert.NoError(t, err)
   266  
   267  	proofTracer.MarkDeletion(s_key1.Bytes())
   268  	siblings, err = proofTracer.GetDeletionProofs()
   269  	assert.NoError(t, err)
   270  	assert.Equal(t, 1, len(siblings))
   271  
   272  	proofTracer.MarkDeletion(s_key2.Bytes())
   273  	siblings, err = proofTracer.GetDeletionProofs()
   274  	assert.NoError(t, err)
   275  	assert.Equal(t, 2, len(siblings))
   276  
   277  	// one of the siblings is just leaf for key2, while
   278  	// another one must be a middle node
   279  	match1 := bytes.Equal(siblings[0][l-33:l-1], nd)
   280  	match2 := bytes.Equal(siblings[1][l-33:l-1], nd)
   281  	assert.True(t, match1 || match2)
   282  	assert.False(t, match1 && match2)
   283  }