github.com/etherite/go-etherite@v0.0.0-20171015192807-5f4dd87b2f6e/trie/iterator_test.go (about)

     1  // Copyright 2014 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package trie
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"math/rand"
    23  	"testing"
    24  
    25  	"github.com/etherite/go-etherite/common"
    26  	"github.com/etherite/go-etherite/ethdb"
    27  )
    28  
    29  func TestIterator(t *testing.T) {
    30  	trie := newEmpty()
    31  	vals := []struct{ k, v string }{
    32  		{"do", "verb"},
    33  		{"ether", "wookiedoo"},
    34  		{"horse", "stallion"},
    35  		{"shaman", "horse"},
    36  		{"doge", "coin"},
    37  		{"dog", "puppy"},
    38  		{"somethingveryoddindeedthis is", "myothernodedata"},
    39  	}
    40  	all := make(map[string]string)
    41  	for _, val := range vals {
    42  		all[val.k] = val.v
    43  		trie.Update([]byte(val.k), []byte(val.v))
    44  	}
    45  	trie.Commit()
    46  
    47  	found := make(map[string]string)
    48  	it := NewIterator(trie.NodeIterator(nil))
    49  	for it.Next() {
    50  		found[string(it.Key)] = string(it.Value)
    51  	}
    52  
    53  	for k, v := range all {
    54  		if found[k] != v {
    55  			t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v)
    56  		}
    57  	}
    58  }
    59  
    60  type kv struct {
    61  	k, v []byte
    62  	t    bool
    63  }
    64  
    65  func TestIteratorLargeData(t *testing.T) {
    66  	trie := newEmpty()
    67  	vals := make(map[string]*kv)
    68  
    69  	for i := byte(0); i < 255; i++ {
    70  		value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
    71  		value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
    72  		trie.Update(value.k, value.v)
    73  		trie.Update(value2.k, value2.v)
    74  		vals[string(value.k)] = value
    75  		vals[string(value2.k)] = value2
    76  	}
    77  
    78  	it := NewIterator(trie.NodeIterator(nil))
    79  	for it.Next() {
    80  		vals[string(it.Key)].t = true
    81  	}
    82  
    83  	var untouched []*kv
    84  	for _, value := range vals {
    85  		if !value.t {
    86  			untouched = append(untouched, value)
    87  		}
    88  	}
    89  
    90  	if len(untouched) > 0 {
    91  		t.Errorf("Missed %d nodes", len(untouched))
    92  		for _, value := range untouched {
    93  			t.Error(value)
    94  		}
    95  	}
    96  }
    97  
    98  // Tests that the node iterator indeed walks over the entire database contents.
    99  func TestNodeIteratorCoverage(t *testing.T) {
   100  	// Create some arbitrary test trie to iterate
   101  	db, trie, _ := makeTestTrie()
   102  
   103  	// Gather all the node hashes found by the iterator
   104  	hashes := make(map[common.Hash]struct{})
   105  	for it := trie.NodeIterator(nil); it.Next(true); {
   106  		if it.Hash() != (common.Hash{}) {
   107  			hashes[it.Hash()] = struct{}{}
   108  		}
   109  	}
   110  	// Cross check the hashes and the database itself
   111  	for hash := range hashes {
   112  		if _, err := db.Get(hash.Bytes()); err != nil {
   113  			t.Errorf("failed to retrieve reported node %x: %v", hash, err)
   114  		}
   115  	}
   116  	for _, key := range db.(*ethdb.MemDatabase).Keys() {
   117  		if _, ok := hashes[common.BytesToHash(key)]; !ok {
   118  			t.Errorf("state entry not reported %x", key)
   119  		}
   120  	}
   121  }
   122  
   123  type kvs struct{ k, v string }
   124  
   125  var testdata1 = []kvs{
   126  	{"barb", "ba"},
   127  	{"bard", "bc"},
   128  	{"bars", "bb"},
   129  	{"bar", "b"},
   130  	{"fab", "z"},
   131  	{"food", "ab"},
   132  	{"foos", "aa"},
   133  	{"foo", "a"},
   134  }
   135  
   136  var testdata2 = []kvs{
   137  	{"aardvark", "c"},
   138  	{"bar", "b"},
   139  	{"barb", "bd"},
   140  	{"bars", "be"},
   141  	{"fab", "z"},
   142  	{"foo", "a"},
   143  	{"foos", "aa"},
   144  	{"food", "ab"},
   145  	{"jars", "d"},
   146  }
   147  
   148  func TestIteratorSeek(t *testing.T) {
   149  	trie := newEmpty()
   150  	for _, val := range testdata1 {
   151  		trie.Update([]byte(val.k), []byte(val.v))
   152  	}
   153  
   154  	// Seek to the middle.
   155  	it := NewIterator(trie.NodeIterator([]byte("fab")))
   156  	if err := checkIteratorOrder(testdata1[4:], it); err != nil {
   157  		t.Fatal(err)
   158  	}
   159  
   160  	// Seek to a non-existent key.
   161  	it = NewIterator(trie.NodeIterator([]byte("barc")))
   162  	if err := checkIteratorOrder(testdata1[1:], it); err != nil {
   163  		t.Fatal(err)
   164  	}
   165  
   166  	// Seek beyond the end.
   167  	it = NewIterator(trie.NodeIterator([]byte("z")))
   168  	if err := checkIteratorOrder(nil, it); err != nil {
   169  		t.Fatal(err)
   170  	}
   171  }
   172  
   173  func checkIteratorOrder(want []kvs, it *Iterator) error {
   174  	for it.Next() {
   175  		if len(want) == 0 {
   176  			return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
   177  		}
   178  		if !bytes.Equal(it.Key, []byte(want[0].k)) {
   179  			return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k)
   180  		}
   181  		want = want[1:]
   182  	}
   183  	if len(want) > 0 {
   184  		return fmt.Errorf("iterator ended early, want key %q", want[0])
   185  	}
   186  	return nil
   187  }
   188  
   189  func TestDifferenceIterator(t *testing.T) {
   190  	triea := newEmpty()
   191  	for _, val := range testdata1 {
   192  		triea.Update([]byte(val.k), []byte(val.v))
   193  	}
   194  	triea.Commit()
   195  
   196  	trieb := newEmpty()
   197  	for _, val := range testdata2 {
   198  		trieb.Update([]byte(val.k), []byte(val.v))
   199  	}
   200  	trieb.Commit()
   201  
   202  	found := make(map[string]string)
   203  	di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
   204  	it := NewIterator(di)
   205  	for it.Next() {
   206  		found[string(it.Key)] = string(it.Value)
   207  	}
   208  
   209  	all := []struct{ k, v string }{
   210  		{"aardvark", "c"},
   211  		{"barb", "bd"},
   212  		{"bars", "be"},
   213  		{"jars", "d"},
   214  	}
   215  	for _, item := range all {
   216  		if found[item.k] != item.v {
   217  			t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
   218  		}
   219  	}
   220  	if len(found) != len(all) {
   221  		t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
   222  	}
   223  }
   224  
   225  func TestUnionIterator(t *testing.T) {
   226  	triea := newEmpty()
   227  	for _, val := range testdata1 {
   228  		triea.Update([]byte(val.k), []byte(val.v))
   229  	}
   230  	triea.Commit()
   231  
   232  	trieb := newEmpty()
   233  	for _, val := range testdata2 {
   234  		trieb.Update([]byte(val.k), []byte(val.v))
   235  	}
   236  	trieb.Commit()
   237  
   238  	di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
   239  	it := NewIterator(di)
   240  
   241  	all := []struct{ k, v string }{
   242  		{"aardvark", "c"},
   243  		{"barb", "ba"},
   244  		{"barb", "bd"},
   245  		{"bard", "bc"},
   246  		{"bars", "bb"},
   247  		{"bars", "be"},
   248  		{"bar", "b"},
   249  		{"fab", "z"},
   250  		{"food", "ab"},
   251  		{"foos", "aa"},
   252  		{"foo", "a"},
   253  		{"jars", "d"},
   254  	}
   255  
   256  	for i, kv := range all {
   257  		if !it.Next() {
   258  			t.Errorf("Iterator ends prematurely at element %d", i)
   259  		}
   260  		if kv.k != string(it.Key) {
   261  			t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
   262  		}
   263  		if kv.v != string(it.Value) {
   264  			t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
   265  		}
   266  	}
   267  	if it.Next() {
   268  		t.Errorf("Iterator returned extra values.")
   269  	}
   270  }
   271  
   272  func TestIteratorNoDups(t *testing.T) {
   273  	var tr Trie
   274  	for _, val := range testdata1 {
   275  		tr.Update([]byte(val.k), []byte(val.v))
   276  	}
   277  	checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   278  }
   279  
   280  // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
   281  func TestIteratorContinueAfterError(t *testing.T) {
   282  	db, _ := ethdb.NewMemDatabase()
   283  	tr, _ := New(common.Hash{}, db)
   284  	for _, val := range testdata1 {
   285  		tr.Update([]byte(val.k), []byte(val.v))
   286  	}
   287  	tr.Commit()
   288  	wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   289  	keys := db.Keys()
   290  	t.Log("node count", wantNodeCount)
   291  
   292  	for i := 0; i < 20; i++ {
   293  		// Create trie that will load all nodes from DB.
   294  		tr, _ := New(tr.Hash(), db)
   295  
   296  		// Remove a random node from the database. It can't be the root node
   297  		// because that one is already loaded.
   298  		var rkey []byte
   299  		for {
   300  			if rkey = keys[rand.Intn(len(keys))]; !bytes.Equal(rkey, tr.Hash().Bytes()) {
   301  				break
   302  			}
   303  		}
   304  		rval, _ := db.Get(rkey)
   305  		db.Delete(rkey)
   306  
   307  		// Iterate until the error is hit.
   308  		seen := make(map[string]bool)
   309  		it := tr.NodeIterator(nil)
   310  		checkIteratorNoDups(t, it, seen)
   311  		missing, ok := it.Error().(*MissingNodeError)
   312  		if !ok || !bytes.Equal(missing.NodeHash[:], rkey) {
   313  			t.Fatal("didn't hit missing node, got", it.Error())
   314  		}
   315  
   316  		// Add the node back and continue iteration.
   317  		db.Put(rkey, rval)
   318  		checkIteratorNoDups(t, it, seen)
   319  		if it.Error() != nil {
   320  			t.Fatal("unexpected error", it.Error())
   321  		}
   322  		if len(seen) != wantNodeCount {
   323  			t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount)
   324  		}
   325  	}
   326  }
   327  
   328  // Similar to the test above, this one checks that failure to create nodeIterator at a
   329  // certain key prefix behaves correctly when Next is called. The expectation is that Next
   330  // should retry seeking before returning true for the first time.
   331  func TestIteratorContinueAfterSeekError(t *testing.T) {
   332  	// Commit test trie to db, then remove the node containing "bars".
   333  	db, _ := ethdb.NewMemDatabase()
   334  	ctr, _ := New(common.Hash{}, db)
   335  	for _, val := range testdata1 {
   336  		ctr.Update([]byte(val.k), []byte(val.v))
   337  	}
   338  	root, _ := ctr.Commit()
   339  	barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
   340  	barNode, _ := db.Get(barNodeHash[:])
   341  	db.Delete(barNodeHash[:])
   342  
   343  	// Create a new iterator that seeks to "bars". Seeking can't proceed because
   344  	// the node is missing.
   345  	tr, _ := New(root, db)
   346  	it := tr.NodeIterator([]byte("bars"))
   347  	missing, ok := it.Error().(*MissingNodeError)
   348  	if !ok {
   349  		t.Fatal("want MissingNodeError, got", it.Error())
   350  	} else if missing.NodeHash != barNodeHash {
   351  		t.Fatal("wrong node missing")
   352  	}
   353  
   354  	// Reinsert the missing node.
   355  	db.Put(barNodeHash[:], barNode[:])
   356  
   357  	// Check that iteration produces the right set of values.
   358  	if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
   359  		t.Fatal(err)
   360  	}
   361  }
   362  
   363  func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
   364  	if seen == nil {
   365  		seen = make(map[string]bool)
   366  	}
   367  	for it.Next(true) {
   368  		if seen[string(it.Path())] {
   369  			t.Fatalf("iterator visited node path %x twice", it.Path())
   370  		}
   371  		seen[string(it.Path())] = true
   372  	}
   373  	return len(seen)
   374  }