github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/trie/proof_test.go (about)

     1  package trie
     2  
     3  import (
     4  	"bytes"
     5  	crand "crypto/rand"
     6  	mrand "math/rand"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/quickchainproject/quickchain/common"
    11  	"github.com/quickchainproject/quickchain/crypto"
    12  	"github.com/quickchainproject/quickchain/qctdb"
    13  )
    14  
    15  func init() {
    16  	mrand.Seed(time.Now().Unix())
    17  }
    18  
    19  func TestProof(t *testing.T) {
    20  	trie, vals := randomTrie(500)
    21  	root := trie.Hash()
    22  	for _, kv := range vals {
    23  		proofs, _ := qctdb.NewMemDatabase()
    24  		if trie.Prove(kv.k, 0, proofs) != nil {
    25  			t.Fatalf("missing key %x while constructing proof", kv.k)
    26  		}
    27  		val, err, _ := VerifyProof(root, kv.k, proofs)
    28  		if err != nil {
    29  			t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %v", kv.k, err, proofs)
    30  		}
    31  		if !bytes.Equal(val, kv.v) {
    32  			t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", kv.k, val, kv.v)
    33  		}
    34  	}
    35  }
    36  
    37  func TestOneElementProof(t *testing.T) {
    38  	trie := new(Trie)
    39  	updateString(trie, "k", "v")
    40  	proofs, _ := qctdb.NewMemDatabase()
    41  	trie.Prove([]byte("k"), 0, proofs)
    42  	if len(proofs.Keys()) != 1 {
    43  		t.Error("proof should have one element")
    44  	}
    45  	val, err, _ := VerifyProof(trie.Hash(), []byte("k"), proofs)
    46  	if err != nil {
    47  		t.Fatalf("VerifyProof error: %v\nproof hashes: %v", err, proofs.Keys())
    48  	}
    49  	if !bytes.Equal(val, []byte("v")) {
    50  		t.Fatalf("VerifyProof returned wrong value: got %x, want 'k'", val)
    51  	}
    52  }
    53  
    54  func TestVerifyBadProof(t *testing.T) {
    55  	trie, vals := randomTrie(800)
    56  	root := trie.Hash()
    57  	for _, kv := range vals {
    58  		proofs, _ := qctdb.NewMemDatabase()
    59  		trie.Prove(kv.k, 0, proofs)
    60  		if len(proofs.Keys()) == 0 {
    61  			t.Fatal("zero length proof")
    62  		}
    63  		keys := proofs.Keys()
    64  		key := keys[mrand.Intn(len(keys))]
    65  		node, _ := proofs.Get(key)
    66  		proofs.Delete(key)
    67  		mutateByte(node)
    68  		proofs.Put(crypto.Keccak256(node), node)
    69  		if _, err, _ := VerifyProof(root, kv.k, proofs); err == nil {
    70  			t.Fatalf("expected proof to fail for key %x", kv.k)
    71  		}
    72  	}
    73  }
    74  
    75  // mutateByte changes one byte in b.
    76  func mutateByte(b []byte) {
    77  	for r := mrand.Intn(len(b)); ; {
    78  		new := byte(mrand.Intn(255))
    79  		if new != b[r] {
    80  			b[r] = new
    81  			break
    82  		}
    83  	}
    84  }
    85  
    86  func BenchmarkProve(b *testing.B) {
    87  	trie, vals := randomTrie(100)
    88  	var keys []string
    89  	for k := range vals {
    90  		keys = append(keys, k)
    91  	}
    92  
    93  	b.ResetTimer()
    94  	for i := 0; i < b.N; i++ {
    95  		kv := vals[keys[i%len(keys)]]
    96  		proofs, _ := qctdb.NewMemDatabase()
    97  		if trie.Prove(kv.k, 0, proofs); len(proofs.Keys()) == 0 {
    98  			b.Fatalf("zero length proof for %x", kv.k)
    99  		}
   100  	}
   101  }
   102  
   103  func BenchmarkVerifyProof(b *testing.B) {
   104  	trie, vals := randomTrie(100)
   105  	root := trie.Hash()
   106  	var keys []string
   107  	var proofs []*qctdb.MemDatabase
   108  	for k := range vals {
   109  		keys = append(keys, k)
   110  		proof, _ := qctdb.NewMemDatabase()
   111  		trie.Prove([]byte(k), 0, proof)
   112  		proofs = append(proofs, proof)
   113  	}
   114  
   115  	b.ResetTimer()
   116  	for i := 0; i < b.N; i++ {
   117  		im := i % len(keys)
   118  		if _, err, _ := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil {
   119  			b.Fatalf("key %x: %v", keys[im], err)
   120  		}
   121  	}
   122  }
   123  
   124  func randomTrie(n int) (*Trie, map[string]*kv) {
   125  	trie := new(Trie)
   126  	vals := make(map[string]*kv)
   127  	for i := byte(0); i < 100; i++ {
   128  		value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
   129  		value2 := &kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false}
   130  		trie.Update(value.k, value.v)
   131  		trie.Update(value2.k, value2.v)
   132  		vals[string(value.k)] = value
   133  		vals[string(value2.k)] = value2
   134  	}
   135  	for i := 0; i < n; i++ {
   136  		value := &kv{randBytes(32), randBytes(20), false}
   137  		trie.Update(value.k, value.v)
   138  		vals[string(value.k)] = value
   139  	}
   140  	return trie, vals
   141  }
   142  
   143  func randBytes(n int) []byte {
   144  	r := make([]byte, n)
   145  	crand.Read(r)
   146  	return r
   147  }