github.com/daethereum/go-dae@v2.2.3+incompatible/trie/iterator_test.go (about)

     1  // Copyright 2014 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package trie
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/binary"
    22  	"fmt"
    23  	"math/rand"
    24  	"testing"
    25  
    26  	"github.com/daethereum/go-dae/common"
    27  	"github.com/daethereum/go-dae/core/rawdb"
    28  	"github.com/daethereum/go-dae/crypto"
    29  	"github.com/daethereum/go-dae/ethdb"
    30  	"github.com/daethereum/go-dae/ethdb/memorydb"
    31  )
    32  
    33  func TestEmptyIterator(t *testing.T) {
    34  	trie := newEmpty()
    35  	iter := trie.NodeIterator(nil)
    36  
    37  	seen := make(map[string]struct{})
    38  	for iter.Next(true) {
    39  		seen[string(iter.Path())] = struct{}{}
    40  	}
    41  	if len(seen) != 0 {
    42  		t.Fatal("Unexpected trie node iterated")
    43  	}
    44  }
    45  
    46  func TestIterator(t *testing.T) {
    47  	trie := newEmpty()
    48  	vals := []struct{ k, v string }{
    49  		{"do", "verb"},
    50  		{"ether", "wookiedoo"},
    51  		{"horse", "stallion"},
    52  		{"shaman", "horse"},
    53  		{"doge", "coin"},
    54  		{"dog", "puppy"},
    55  		{"somethingveryoddindeedthis is", "myothernodedata"},
    56  	}
    57  	all := make(map[string]string)
    58  	for _, val := range vals {
    59  		all[val.k] = val.v
    60  		trie.Update([]byte(val.k), []byte(val.v))
    61  	}
    62  	trie.Commit(nil)
    63  
    64  	found := make(map[string]string)
    65  	it := NewIterator(trie.NodeIterator(nil))
    66  	for it.Next() {
    67  		found[string(it.Key)] = string(it.Value)
    68  	}
    69  
    70  	for k, v := range all {
    71  		if found[k] != v {
    72  			t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v)
    73  		}
    74  	}
    75  }
    76  
    77  type kv struct {
    78  	k, v []byte
    79  	t    bool
    80  }
    81  
    82  func TestIteratorLargeData(t *testing.T) {
    83  	trie := newEmpty()
    84  	vals := make(map[string]*kv)
    85  
    86  	for i := byte(0); i < 255; i++ {
    87  		value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
    88  		value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
    89  		trie.Update(value.k, value.v)
    90  		trie.Update(value2.k, value2.v)
    91  		vals[string(value.k)] = value
    92  		vals[string(value2.k)] = value2
    93  	}
    94  
    95  	it := NewIterator(trie.NodeIterator(nil))
    96  	for it.Next() {
    97  		vals[string(it.Key)].t = true
    98  	}
    99  
   100  	var untouched []*kv
   101  	for _, value := range vals {
   102  		if !value.t {
   103  			untouched = append(untouched, value)
   104  		}
   105  	}
   106  
   107  	if len(untouched) > 0 {
   108  		t.Errorf("Missed %d nodes", len(untouched))
   109  		for _, value := range untouched {
   110  			t.Error(value)
   111  		}
   112  	}
   113  }
   114  
   115  // Tests that the node iterator indeed walks over the entire database contents.
   116  func TestNodeIteratorCoverage(t *testing.T) {
   117  	// Create some arbitrary test trie to iterate
   118  	db, trie, _ := makeTestTrie()
   119  
   120  	// Gather all the node hashes found by the iterator
   121  	hashes := make(map[common.Hash]struct{})
   122  	for it := trie.NodeIterator(nil); it.Next(true); {
   123  		if it.Hash() != (common.Hash{}) {
   124  			hashes[it.Hash()] = struct{}{}
   125  		}
   126  	}
   127  	// Cross check the hashes and the database itself
   128  	for hash := range hashes {
   129  		if _, err := db.Node(hash); err != nil {
   130  			t.Errorf("failed to retrieve reported node %x: %v", hash, err)
   131  		}
   132  	}
   133  	for hash, obj := range db.dirties {
   134  		if obj != nil && hash != (common.Hash{}) {
   135  			if _, ok := hashes[hash]; !ok {
   136  				t.Errorf("state entry not reported %x", hash)
   137  			}
   138  		}
   139  	}
   140  	it := db.diskdb.NewIterator(nil, nil)
   141  	for it.Next() {
   142  		key := it.Key()
   143  		if _, ok := hashes[common.BytesToHash(key)]; !ok {
   144  			t.Errorf("state entry not reported %x", key)
   145  		}
   146  	}
   147  	it.Release()
   148  }
   149  
   150  type kvs struct{ k, v string }
   151  
   152  var testdata1 = []kvs{
   153  	{"barb", "ba"},
   154  	{"bard", "bc"},
   155  	{"bars", "bb"},
   156  	{"bar", "b"},
   157  	{"fab", "z"},
   158  	{"food", "ab"},
   159  	{"foos", "aa"},
   160  	{"foo", "a"},
   161  }
   162  
   163  var testdata2 = []kvs{
   164  	{"aardvark", "c"},
   165  	{"bar", "b"},
   166  	{"barb", "bd"},
   167  	{"bars", "be"},
   168  	{"fab", "z"},
   169  	{"foo", "a"},
   170  	{"foos", "aa"},
   171  	{"food", "ab"},
   172  	{"jars", "d"},
   173  }
   174  
   175  func TestIteratorSeek(t *testing.T) {
   176  	trie := newEmpty()
   177  	for _, val := range testdata1 {
   178  		trie.Update([]byte(val.k), []byte(val.v))
   179  	}
   180  
   181  	// Seek to the middle.
   182  	it := NewIterator(trie.NodeIterator([]byte("fab")))
   183  	if err := checkIteratorOrder(testdata1[4:], it); err != nil {
   184  		t.Fatal(err)
   185  	}
   186  
   187  	// Seek to a non-existent key.
   188  	it = NewIterator(trie.NodeIterator([]byte("barc")))
   189  	if err := checkIteratorOrder(testdata1[1:], it); err != nil {
   190  		t.Fatal(err)
   191  	}
   192  
   193  	// Seek beyond the end.
   194  	it = NewIterator(trie.NodeIterator([]byte("z")))
   195  	if err := checkIteratorOrder(nil, it); err != nil {
   196  		t.Fatal(err)
   197  	}
   198  }
   199  
   200  func checkIteratorOrder(want []kvs, it *Iterator) error {
   201  	for it.Next() {
   202  		if len(want) == 0 {
   203  			return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
   204  		}
   205  		if !bytes.Equal(it.Key, []byte(want[0].k)) {
   206  			return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k)
   207  		}
   208  		want = want[1:]
   209  	}
   210  	if len(want) > 0 {
   211  		return fmt.Errorf("iterator ended early, want key %q", want[0])
   212  	}
   213  	return nil
   214  }
   215  
   216  func TestDifferenceIterator(t *testing.T) {
   217  	triea := newEmpty()
   218  	for _, val := range testdata1 {
   219  		triea.Update([]byte(val.k), []byte(val.v))
   220  	}
   221  	triea.Commit(nil)
   222  
   223  	trieb := newEmpty()
   224  	for _, val := range testdata2 {
   225  		trieb.Update([]byte(val.k), []byte(val.v))
   226  	}
   227  	trieb.Commit(nil)
   228  
   229  	found := make(map[string]string)
   230  	di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
   231  	it := NewIterator(di)
   232  	for it.Next() {
   233  		found[string(it.Key)] = string(it.Value)
   234  	}
   235  
   236  	all := []struct{ k, v string }{
   237  		{"aardvark", "c"},
   238  		{"barb", "bd"},
   239  		{"bars", "be"},
   240  		{"jars", "d"},
   241  	}
   242  	for _, item := range all {
   243  		if found[item.k] != item.v {
   244  			t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
   245  		}
   246  	}
   247  	if len(found) != len(all) {
   248  		t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
   249  	}
   250  }
   251  
   252  func TestUnionIterator(t *testing.T) {
   253  	triea := newEmpty()
   254  	for _, val := range testdata1 {
   255  		triea.Update([]byte(val.k), []byte(val.v))
   256  	}
   257  	triea.Commit(nil)
   258  
   259  	trieb := newEmpty()
   260  	for _, val := range testdata2 {
   261  		trieb.Update([]byte(val.k), []byte(val.v))
   262  	}
   263  	trieb.Commit(nil)
   264  
   265  	di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
   266  	it := NewIterator(di)
   267  
   268  	all := []struct{ k, v string }{
   269  		{"aardvark", "c"},
   270  		{"barb", "ba"},
   271  		{"barb", "bd"},
   272  		{"bard", "bc"},
   273  		{"bars", "bb"},
   274  		{"bars", "be"},
   275  		{"bar", "b"},
   276  		{"fab", "z"},
   277  		{"food", "ab"},
   278  		{"foos", "aa"},
   279  		{"foo", "a"},
   280  		{"jars", "d"},
   281  	}
   282  
   283  	for i, kv := range all {
   284  		if !it.Next() {
   285  			t.Errorf("Iterator ends prematurely at element %d", i)
   286  		}
   287  		if kv.k != string(it.Key) {
   288  			t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
   289  		}
   290  		if kv.v != string(it.Value) {
   291  			t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
   292  		}
   293  	}
   294  	if it.Next() {
   295  		t.Errorf("Iterator returned extra values.")
   296  	}
   297  }
   298  
   299  func TestIteratorNoDups(t *testing.T) {
   300  	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
   301  	for _, val := range testdata1 {
   302  		tr.Update([]byte(val.k), []byte(val.v))
   303  	}
   304  	checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   305  }
   306  
   307  // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
   308  func TestIteratorContinueAfterErrorDisk(t *testing.T)    { testIteratorContinueAfterError(t, false) }
   309  func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
   310  
   311  func testIteratorContinueAfterError(t *testing.T, memonly bool) {
   312  	diskdb := memorydb.New()
   313  	triedb := NewDatabase(diskdb)
   314  
   315  	tr := NewEmpty(triedb)
   316  	for _, val := range testdata1 {
   317  		tr.Update([]byte(val.k), []byte(val.v))
   318  	}
   319  	tr.Commit(nil)
   320  	if !memonly {
   321  		triedb.Commit(tr.Hash(), true, nil)
   322  	}
   323  	wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   324  
   325  	var (
   326  		diskKeys [][]byte
   327  		memKeys  []common.Hash
   328  	)
   329  	if memonly {
   330  		memKeys = triedb.Nodes()
   331  	} else {
   332  		it := diskdb.NewIterator(nil, nil)
   333  		for it.Next() {
   334  			diskKeys = append(diskKeys, it.Key())
   335  		}
   336  		it.Release()
   337  	}
   338  	for i := 0; i < 20; i++ {
   339  		// Create trie that will load all nodes from DB.
   340  		tr, _ := New(common.Hash{}, tr.Hash(), triedb)
   341  
   342  		// Remove a random node from the database. It can't be the root node
   343  		// because that one is already loaded.
   344  		var (
   345  			rkey common.Hash
   346  			rval []byte
   347  			robj *cachedNode
   348  		)
   349  		for {
   350  			if memonly {
   351  				rkey = memKeys[rand.Intn(len(memKeys))]
   352  			} else {
   353  				copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
   354  			}
   355  			if rkey != tr.Hash() {
   356  				break
   357  			}
   358  		}
   359  		if memonly {
   360  			robj = triedb.dirties[rkey]
   361  			delete(triedb.dirties, rkey)
   362  		} else {
   363  			rval, _ = diskdb.Get(rkey[:])
   364  			diskdb.Delete(rkey[:])
   365  		}
   366  		// Iterate until the error is hit.
   367  		seen := make(map[string]bool)
   368  		it := tr.NodeIterator(nil)
   369  		checkIteratorNoDups(t, it, seen)
   370  		missing, ok := it.Error().(*MissingNodeError)
   371  		if !ok || missing.NodeHash != rkey {
   372  			t.Fatal("didn't hit missing node, got", it.Error())
   373  		}
   374  
   375  		// Add the node back and continue iteration.
   376  		if memonly {
   377  			triedb.dirties[rkey] = robj
   378  		} else {
   379  			diskdb.Put(rkey[:], rval)
   380  		}
   381  		checkIteratorNoDups(t, it, seen)
   382  		if it.Error() != nil {
   383  			t.Fatal("unexpected error", it.Error())
   384  		}
   385  		if len(seen) != wantNodeCount {
   386  			t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount)
   387  		}
   388  	}
   389  }
   390  
   391  // Similar to the test above, this one checks that failure to create nodeIterator at a
   392  // certain key prefix behaves correctly when Next is called. The expectation is that Next
   393  // should retry seeking before returning true for the first time.
   394  func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) {
   395  	testIteratorContinueAfterSeekError(t, false)
   396  }
   397  func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
   398  	testIteratorContinueAfterSeekError(t, true)
   399  }
   400  
   401  func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
   402  	// Commit test trie to db, then remove the node containing "bars".
   403  	diskdb := memorydb.New()
   404  	triedb := NewDatabase(diskdb)
   405  
   406  	ctr := NewEmpty(triedb)
   407  	for _, val := range testdata1 {
   408  		ctr.Update([]byte(val.k), []byte(val.v))
   409  	}
   410  	root, _, _ := ctr.Commit(nil)
   411  	if !memonly {
   412  		triedb.Commit(root, true, nil)
   413  	}
   414  	barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
   415  	var (
   416  		barNodeBlob []byte
   417  		barNodeObj  *cachedNode
   418  	)
   419  	if memonly {
   420  		barNodeObj = triedb.dirties[barNodeHash]
   421  		delete(triedb.dirties, barNodeHash)
   422  	} else {
   423  		barNodeBlob, _ = diskdb.Get(barNodeHash[:])
   424  		diskdb.Delete(barNodeHash[:])
   425  	}
   426  	// Create a new iterator that seeks to "bars". Seeking can't proceed because
   427  	// the node is missing.
   428  	tr, _ := New(common.Hash{}, root, triedb)
   429  	it := tr.NodeIterator([]byte("bars"))
   430  	missing, ok := it.Error().(*MissingNodeError)
   431  	if !ok {
   432  		t.Fatal("want MissingNodeError, got", it.Error())
   433  	} else if missing.NodeHash != barNodeHash {
   434  		t.Fatal("wrong node missing")
   435  	}
   436  	// Reinsert the missing node.
   437  	if memonly {
   438  		triedb.dirties[barNodeHash] = barNodeObj
   439  	} else {
   440  		diskdb.Put(barNodeHash[:], barNodeBlob)
   441  	}
   442  	// Check that iteration produces the right set of values.
   443  	if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
   444  		t.Fatal(err)
   445  	}
   446  }
   447  
   448  func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
   449  	if seen == nil {
   450  		seen = make(map[string]bool)
   451  	}
   452  	for it.Next(true) {
   453  		if seen[string(it.Path())] {
   454  			t.Fatalf("iterator visited node path %x twice", it.Path())
   455  		}
   456  		seen[string(it.Path())] = true
   457  	}
   458  	return len(seen)
   459  }
   460  
   461  type loggingDb struct {
   462  	getCount uint64
   463  	backend  ethdb.KeyValueStore
   464  }
   465  
   466  func (l *loggingDb) Has(key []byte) (bool, error) {
   467  	return l.backend.Has(key)
   468  }
   469  
   470  func (l *loggingDb) Get(key []byte) ([]byte, error) {
   471  	l.getCount++
   472  	return l.backend.Get(key)
   473  }
   474  
   475  func (l *loggingDb) Put(key []byte, value []byte) error {
   476  	return l.backend.Put(key, value)
   477  }
   478  
   479  func (l *loggingDb) Delete(key []byte) error {
   480  	return l.backend.Delete(key)
   481  }
   482  
   483  func (l *loggingDb) NewBatch() ethdb.Batch {
   484  	return l.backend.NewBatch()
   485  }
   486  
   487  func (l *loggingDb) NewBatchWithSize(size int) ethdb.Batch {
   488  	return l.backend.NewBatchWithSize(size)
   489  }
   490  
   491  func (l *loggingDb) NewIterator(prefix []byte, start []byte) ethdb.Iterator {
   492  	return l.backend.NewIterator(prefix, start)
   493  }
   494  
   495  func (l *loggingDb) NewSnapshot() (ethdb.Snapshot, error) {
   496  	return l.backend.NewSnapshot()
   497  }
   498  
   499  func (l *loggingDb) Stat(property string) (string, error) {
   500  	return l.backend.Stat(property)
   501  }
   502  
   503  func (l *loggingDb) Compact(start []byte, limit []byte) error {
   504  	return l.backend.Compact(start, limit)
   505  }
   506  
   507  func (l *loggingDb) Close() error {
   508  	return l.backend.Close()
   509  }
   510  
   511  // makeLargeTestTrie create a sample test trie
   512  func makeLargeTestTrie() (*Database, *SecureTrie, *loggingDb) {
   513  	// Create an empty trie
   514  	logDb := &loggingDb{0, memorydb.New()}
   515  	triedb := NewDatabase(logDb)
   516  	trie, _ := NewSecure(common.Hash{}, common.Hash{}, triedb)
   517  
   518  	// Fill it with some arbitrary data
   519  	for i := 0; i < 10000; i++ {
   520  		key := make([]byte, 32)
   521  		val := make([]byte, 32)
   522  		binary.BigEndian.PutUint64(key, uint64(i))
   523  		binary.BigEndian.PutUint64(val, uint64(i))
   524  		key = crypto.Keccak256(key)
   525  		val = crypto.Keccak256(val)
   526  		trie.Update(key, val)
   527  	}
   528  	trie.Commit(nil)
   529  	// Return the generated trie
   530  	return triedb, trie, logDb
   531  }
   532  
   533  // Tests that the node iterator indeed walks over the entire database contents.
   534  func TestNodeIteratorLargeTrie(t *testing.T) {
   535  	// Create some arbitrary test trie to iterate
   536  	db, trie, logDb := makeLargeTestTrie()
   537  	db.Cap(0) // flush everything
   538  	// Do a seek operation
   539  	trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885"))
   540  	// master: 24 get operations
   541  	// this pr: 5 get operations
   542  	if have, want := logDb.getCount, uint64(5); have != want {
   543  		t.Fatalf("Too many lookups during seek, have %d want %d", have, want)
   544  	}
   545  }
   546  
   547  func TestIteratorNodeBlob(t *testing.T) {
   548  	var (
   549  		db     = memorydb.New()
   550  		triedb = NewDatabase(db)
   551  		trie   = NewEmpty(triedb)
   552  	)
   553  	vals := []struct{ k, v string }{
   554  		{"do", "verb"},
   555  		{"ether", "wookiedoo"},
   556  		{"horse", "stallion"},
   557  		{"shaman", "horse"},
   558  		{"doge", "coin"},
   559  		{"dog", "puppy"},
   560  		{"somethingveryoddindeedthis is", "myothernodedata"},
   561  	}
   562  	all := make(map[string]string)
   563  	for _, val := range vals {
   564  		all[val.k] = val.v
   565  		trie.Update([]byte(val.k), []byte(val.v))
   566  	}
   567  	trie.Commit(nil)
   568  	triedb.Cap(0)
   569  
   570  	found := make(map[common.Hash][]byte)
   571  	it := trie.NodeIterator(nil)
   572  	for it.Next(true) {
   573  		if it.Hash() == (common.Hash{}) {
   574  			continue
   575  		}
   576  		found[it.Hash()] = it.NodeBlob()
   577  	}
   578  
   579  	dbIter := db.NewIterator(nil, nil)
   580  	defer dbIter.Release()
   581  
   582  	var count int
   583  	for dbIter.Next() {
   584  		got, present := found[common.BytesToHash(dbIter.Key())]
   585  		if !present {
   586  			t.Fatalf("Miss trie node %v", dbIter.Key())
   587  		}
   588  		if !bytes.Equal(got, dbIter.Value()) {
   589  			t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got)
   590  		}
   591  		count += 1
   592  	}
   593  	if count != len(found) {
   594  		t.Fatal("Find extra trie node via iterator")
   595  	}
   596  }