github.com/cryptotooltop/go-ethereum@v0.0.0-20231103184714-151d1922f3e5/trie/zk_trie_proof_test.go (about) 1 // Copyright 2015 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package trie 18 19 import ( 20 "bytes" 21 mrand "math/rand" 22 "testing" 23 "time" 24 25 "github.com/stretchr/testify/assert" 26 27 zkt "github.com/scroll-tech/zktrie/types" 28 29 "github.com/scroll-tech/go-ethereum/common" 30 "github.com/scroll-tech/go-ethereum/crypto" 31 "github.com/scroll-tech/go-ethereum/ethdb/memorydb" 32 ) 33 34 func init() { 35 mrand.Seed(time.Now().Unix()) 36 } 37 38 // makeProvers creates Merkle trie provers based on different implementations to 39 // test all variations. 40 func makeSMTProvers(mt *ZkTrie) []func(key []byte) *memorydb.Database { 41 var provers []func(key []byte) *memorydb.Database 42 43 // Create a direct trie based Merkle prover 44 provers = append(provers, func(key []byte) *memorydb.Database { 45 word := zkt.NewByte32FromBytesPaddingZero(key) 46 k, err := word.Hash() 47 if err != nil { 48 panic(err) 49 } 50 proof := memorydb.New() 51 err = mt.Prove(common.BytesToHash(k.Bytes()).Bytes(), 0, proof) 52 if err != nil { 53 panic(err) 54 } 55 56 return proof 57 }) 58 return provers 59 } 60 61 func verifyValue(proveVal []byte, vPreimage []byte) bool { 62 return bytes.Equal(proveVal, vPreimage) 63 } 64 65 func TestSMTOneElementProof(t *testing.T) { 66 tr, _ := NewZkTrie(common.Hash{}, NewZktrieDatabase((memorydb.New()))) 67 mt := &zkTrieImplTestWrapper{tr.Tree()} 68 err := mt.UpdateWord( 69 zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("k"), 32)), 70 zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32)), 71 ) 72 assert.Nil(t, err) 73 for i, prover := range makeSMTProvers(tr) { 74 keyBytes := bytes.Repeat([]byte("k"), 32) 75 proof := prover(keyBytes) 76 if proof == nil { 77 t.Fatalf("prover %d: nil proof", i) 78 } 79 if proof.Len() != 2 { 80 t.Errorf("prover %d: proof should have 1+1 element (including the magic kv)", i) 81 } 82 val, err := VerifyProof(common.BytesToHash(mt.Root().Bytes()), keyBytes, proof) 83 if err != nil { 84 t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) 85 } 86 if !verifyValue(val, bytes.Repeat([]byte("v"), 32)) { 87 t.Fatalf("prover %d: verified value mismatch: want 'v' get %x", i, val) 88 } 89 } 90 } 91 92 func TestSMTProof(t *testing.T) { 93 mt, vals := randomZktrie(t, 500) 94 root := mt.Tree().Root() 95 for i, prover := range makeSMTProvers(mt) { 96 for _, kv := range vals { 97 proof := prover(kv.k) 98 if proof == nil { 99 t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k) 100 } 101 val, err := VerifyProof(common.BytesToHash(root.Bytes()), kv.k, proof) 102 if err != nil { 103 t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x\n", i, kv.k, err, proof) 104 } 105 if !verifyValue(val, zkt.NewByte32FromBytesPaddingZero(kv.v)[:]) { 106 t.Fatalf("prover %d: verified value mismatch for key %x, want %x, get %x", i, kv.k, kv.v, val) 107 } 108 } 109 } 110 } 111 112 func TestSMTBadProof(t *testing.T) { 113 mt, vals := randomZktrie(t, 500) 114 root := mt.Tree().Root() 115 for i, prover := range makeSMTProvers(mt) { 116 for _, kv := range vals { 117 proof := prover(kv.k) 118 if proof == nil { 119 t.Fatalf("prover %d: nil proof", i) 120 } 121 it := proof.NewIterator(nil, nil) 122 for i, d := 0, mrand.Intn(proof.Len()); i <= d; i++ { 123 it.Next() 124 } 125 key := it.Key() 126 val, _ := proof.Get(key) 127 proof.Delete(key) 128 it.Release() 129 130 mutateByte(val) 131 proof.Put(crypto.Keccak256(val), val) 132 133 if _, err := VerifyProof(common.BytesToHash(root.Bytes()), kv.k, proof); err == nil { 134 t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k) 135 } 136 } 137 } 138 } 139 140 // Tests that missing keys can also be proven. The test explicitly uses a single 141 // entry trie and checks for missing keys both before and after the single entry. 142 func TestSMTMissingKeyProof(t *testing.T) { 143 tr, _ := NewZkTrie(common.Hash{}, NewZktrieDatabase((memorydb.New()))) 144 mt := &zkTrieImplTestWrapper{tr.Tree()} 145 err := mt.UpdateWord( 146 zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("k"), 32)), 147 zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32)), 148 ) 149 assert.Nil(t, err) 150 151 prover := makeSMTProvers(tr)[0] 152 153 for i, key := range []string{"a", "j", "l", "z"} { 154 keyBytes := bytes.Repeat([]byte(key), 32) 155 proof := prover(keyBytes) 156 157 if proof.Len() != 2 { 158 t.Errorf("test %d: proof should have 2 element (with magic kv)", i) 159 } 160 val, err := VerifyProof(common.BytesToHash(mt.Root().Bytes()), keyBytes, proof) 161 if err != nil { 162 t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof) 163 } 164 if val != nil { 165 t.Fatalf("test %d: verified value mismatch: have %x, want nil", i, val) 166 } 167 } 168 } 169 170 func randomZktrie(t *testing.T, n int) (*ZkTrie, map[string]*kv) { 171 tr, err := NewZkTrie(common.Hash{}, NewZktrieDatabase((memorydb.New()))) 172 if err != nil { 173 panic(err) 174 } 175 mt := &zkTrieImplTestWrapper{tr.Tree()} 176 vals := make(map[string]*kv) 177 for i := byte(0); i < 100; i++ { 178 179 value := &kv{common.LeftPadBytes([]byte{i}, 32), bytes.Repeat([]byte{i}, 32), false} 180 value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), bytes.Repeat([]byte{i}, 32), false} 181 182 err = mt.UpdateWord(zkt.NewByte32FromBytesPaddingZero(value.k), zkt.NewByte32FromBytesPaddingZero(value.v)) 183 assert.Nil(t, err) 184 err = mt.UpdateWord(zkt.NewByte32FromBytesPaddingZero(value2.k), zkt.NewByte32FromBytesPaddingZero(value2.v)) 185 assert.Nil(t, err) 186 vals[string(value.k)] = value 187 vals[string(value2.k)] = value2 188 } 189 for i := 0; i < n; i++ { 190 value := &kv{randBytes(32), randBytes(20), false} 191 err = mt.UpdateWord(zkt.NewByte32FromBytesPaddingZero(value.k), zkt.NewByte32FromBytesPaddingZero(value.v)) 192 assert.Nil(t, err) 193 vals[string(value.k)] = value 194 } 195 196 return tr, vals 197 } 198 199 // Tests that new "proof trace" feature 200 func TestProofWithDeletion(t *testing.T) { 201 tr, _ := NewZkTrie(common.Hash{}, NewZktrieDatabase((memorydb.New()))) 202 mt := &zkTrieImplTestWrapper{tr.Tree()} 203 key1 := bytes.Repeat([]byte("l"), 32) 204 key2 := bytes.Repeat([]byte("m"), 32) 205 err := mt.UpdateWord( 206 zkt.NewByte32FromBytesPaddingZero(key1), 207 zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("v"), 32)), 208 ) 209 assert.NoError(t, err) 210 err = mt.UpdateWord( 211 zkt.NewByte32FromBytesPaddingZero(key2), 212 zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("n"), 32)), 213 ) 214 assert.NoError(t, err) 215 216 proof := memorydb.New() 217 s_key1, err := zkt.ToSecureKeyBytes(key1) 218 assert.NoError(t, err) 219 220 proofTracer := tr.NewProofTracer() 221 222 err = proofTracer.Prove(s_key1.Bytes(), 0, proof) 223 assert.NoError(t, err) 224 nd, err := tr.TryGet(key2) 225 assert.NoError(t, err) 226 227 s_key2, err := zkt.ToSecureKeyBytes(bytes.Repeat([]byte("x"), 32)) 228 assert.NoError(t, err) 229 230 err = proofTracer.Prove(s_key2.Bytes(), 0, proof) 231 assert.NoError(t, err) 232 //assert.Equal(t, len(sibling1), len(delTracer.GetProofs())) 233 234 siblings, err := proofTracer.GetDeletionProofs() 235 assert.NoError(t, err) 236 assert.Equal(t, 0, len(siblings)) 237 238 proofTracer.MarkDeletion(s_key1.Bytes()) 239 siblings, err = proofTracer.GetDeletionProofs() 240 assert.NoError(t, err) 241 assert.Equal(t, 1, len(siblings)) 242 l := len(siblings[0]) 243 // a hacking to grep the value part directly from the encoded leaf node, 244 // notice the sibling of key `k*32`` is just the leaf of key `m*32` 245 assert.Equal(t, siblings[0][l-33:l-1], nd) 246 247 // Marking a key that is currently not hit (but terminated by an empty node) 248 // also causes it to be added to the deletion proof 249 proofTracer.MarkDeletion(s_key2.Bytes()) 250 siblings, err = proofTracer.GetDeletionProofs() 251 assert.NoError(t, err) 252 assert.Equal(t, 2, len(siblings)) 253 254 key3 := bytes.Repeat([]byte("x"), 32) 255 err = mt.UpdateWord( 256 zkt.NewByte32FromBytesPaddingZero(key3), 257 zkt.NewByte32FromBytesPaddingZero(bytes.Repeat([]byte("z"), 32)), 258 ) 259 assert.NoError(t, err) 260 261 proofTracer = tr.NewProofTracer() 262 err = proofTracer.Prove(s_key1.Bytes(), 0, proof) 263 assert.NoError(t, err) 264 err = proofTracer.Prove(s_key2.Bytes(), 0, proof) 265 assert.NoError(t, err) 266 267 proofTracer.MarkDeletion(s_key1.Bytes()) 268 siblings, err = proofTracer.GetDeletionProofs() 269 assert.NoError(t, err) 270 assert.Equal(t, 1, len(siblings)) 271 272 proofTracer.MarkDeletion(s_key2.Bytes()) 273 siblings, err = proofTracer.GetDeletionProofs() 274 assert.NoError(t, err) 275 assert.Equal(t, 2, len(siblings)) 276 277 // one of the siblings is just leaf for key2, while 278 // another one must be a middle node 279 match1 := bytes.Equal(siblings[0][l-33:l-1], nd) 280 match2 := bytes.Equal(siblings[1][l-33:l-1], nd) 281 assert.True(t, match1 || match2) 282 assert.False(t, match1 && match2) 283 }