github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/chain/trie/proof_test.go (about)

     1  package trie
     2  
     3  import (
     4  	"bytes"
     5  	crand "crypto/rand"
     6  	mrand "math/rand"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/neatio-net/neatio/neatdb/memorydb"
    11  	"github.com/neatio-net/neatio/utilities/common"
    12  	"github.com/neatio-net/neatio/utilities/crypto"
    13  )
    14  
    15  func init() {
    16  	mrand.Seed(time.Now().Unix())
    17  }
    18  
    19  func makeProvers(trie *Trie) []func(key []byte) *memorydb.Database {
    20  	var provers []func(key []byte) *memorydb.Database
    21  
    22  	provers = append(provers, func(key []byte) *memorydb.Database {
    23  		proof := memorydb.New()
    24  		trie.Prove(key, 0, proof)
    25  		return proof
    26  	})
    27  
    28  	provers = append(provers, func(key []byte) *memorydb.Database {
    29  		proof := memorydb.New()
    30  		if it := NewIterator(trie.NodeIterator(key)); it.Next() && bytes.Equal(key, it.Key) {
    31  			for _, p := range it.Prove() {
    32  				proof.Put(crypto.Keccak256(p), p)
    33  			}
    34  		}
    35  		return proof
    36  	})
    37  	return provers
    38  }
    39  
    40  func TestProof(t *testing.T) {
    41  	trie, vals := randomTrie(500)
    42  	root := trie.Hash()
    43  	for i, prover := range makeProvers(trie) {
    44  		for _, kv := range vals {
    45  			proof := prover(kv.k)
    46  			if proof == nil {
    47  				t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k)
    48  			}
    49  			val, _, err := VerifyProof(root, kv.k, proof)
    50  			if err != nil {
    51  				t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof)
    52  			}
    53  			if !bytes.Equal(val, kv.v) {
    54  				t.Fatalf("prover %d: verified value mismatch for key %x: have %x, want %x", i, kv.k, val, kv.v)
    55  			}
    56  		}
    57  	}
    58  }
    59  
    60  func TestOneElementProof(t *testing.T) {
    61  	trie := new(Trie)
    62  	updateString(trie, "k", "v")
    63  	for i, prover := range makeProvers(trie) {
    64  		proof := prover([]byte("k"))
    65  		if proof == nil {
    66  			t.Fatalf("prover %d: nil proof", i)
    67  		}
    68  		if proof.Len() != 1 {
    69  			t.Errorf("prover %d: proof should have one element", i)
    70  		}
    71  		val, _, err := VerifyProof(trie.Hash(), []byte("k"), proof)
    72  		if err != nil {
    73  			t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
    74  		}
    75  		if !bytes.Equal(val, []byte("v")) {
    76  			t.Fatalf("prover %d: verified value mismatch: have %x, want 'k'", i, val)
    77  		}
    78  	}
    79  }
    80  
    81  func TestBadProof(t *testing.T) {
    82  	trie, vals := randomTrie(800)
    83  	root := trie.Hash()
    84  	for i, prover := range makeProvers(trie) {
    85  		for _, kv := range vals {
    86  			proof := prover(kv.k)
    87  			if proof == nil {
    88  				t.Fatalf("prover %d: nil proof", i)
    89  			}
    90  			it := proof.NewIterator()
    91  			for i, d := 0, mrand.Intn(proof.Len()); i <= d; i++ {
    92  				it.Next()
    93  			}
    94  			key := it.Key()
    95  			val, _ := proof.Get(key)
    96  			proof.Delete(key)
    97  			it.Release()
    98  
    99  			mutateByte(val)
   100  			proof.Put(crypto.Keccak256(val), val)
   101  
   102  			if _, _, err := VerifyProof(root, kv.k, proof); err == nil {
   103  				t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k)
   104  			}
   105  		}
   106  	}
   107  }
   108  
   109  func TestMissingKeyProof(t *testing.T) {
   110  	trie := new(Trie)
   111  	updateString(trie, "k", "v")
   112  
   113  	for i, key := range []string{"a", "j", "l", "z"} {
   114  		proof := memorydb.New()
   115  		trie.Prove([]byte(key), 0, proof)
   116  
   117  		if proof.Len() != 1 {
   118  			t.Errorf("test %d: proof should have one element", i)
   119  		}
   120  		val, _, err := VerifyProof(trie.Hash(), []byte(key), proof)
   121  		if err != nil {
   122  			t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
   123  		}
   124  		if val != nil {
   125  			t.Fatalf("test %d: verified value mismatch: have %x, want nil", i, val)
   126  		}
   127  	}
   128  }
   129  
   130  func mutateByte(b []byte) {
   131  	for r := mrand.Intn(len(b)); ; {
   132  		new := byte(mrand.Intn(255))
   133  		if new != b[r] {
   134  			b[r] = new
   135  			break
   136  		}
   137  	}
   138  }
   139  
   140  func BenchmarkProve(b *testing.B) {
   141  	trie, vals := randomTrie(100)
   142  	var keys []string
   143  	for k := range vals {
   144  		keys = append(keys, k)
   145  	}
   146  
   147  	b.ResetTimer()
   148  	for i := 0; i < b.N; i++ {
   149  		kv := vals[keys[i%len(keys)]]
   150  		proofs := memorydb.New()
   151  		if trie.Prove(kv.k, 0, proofs); proofs.Len() == 0 {
   152  			b.Fatalf("zero length proof for %x", kv.k)
   153  		}
   154  	}
   155  }
   156  
   157  func BenchmarkVerifyProof(b *testing.B) {
   158  	trie, vals := randomTrie(100)
   159  	root := trie.Hash()
   160  	var keys []string
   161  	var proofs []*memorydb.Database
   162  	for k := range vals {
   163  		keys = append(keys, k)
   164  		proof := memorydb.New()
   165  		trie.Prove([]byte(k), 0, proof)
   166  		proofs = append(proofs, proof)
   167  	}
   168  
   169  	b.ResetTimer()
   170  	for i := 0; i < b.N; i++ {
   171  		im := i % len(keys)
   172  		if _, _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil {
   173  			b.Fatalf("key %x: %v", keys[im], err)
   174  		}
   175  	}
   176  }
   177  
   178  func randomTrie(n int) (*Trie, map[string]*kv) {
   179  	trie := new(Trie)
   180  	vals := make(map[string]*kv)
   181  	for i := byte(0); i < 100; i++ {
   182  		value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
   183  		value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
   184  		trie.Update(value.k, value.v)
   185  		trie.Update(value2.k, value2.v)
   186  		vals[string(value.k)] = value
   187  		vals[string(value2.k)] = value2
   188  	}
   189  	for i := 0; i < n; i++ {
   190  		value := &kv{randBytes(32), randBytes(20), false}
   191  		trie.Update(value.k, value.v)
   192  		vals[string(value.k)] = value
   193  	}
   194  	return trie, vals
   195  }
   196  
   197  func randBytes(n int) []byte {
   198  	r := make([]byte, n)
   199  	crand.Read(r)
   200  	return r
   201  }