github.com/klaytn/klaytn@v1.10.2/storage/statedb/iterator_test.go (about)

     1  // Modifications Copyright 2018 The klaytn Authors
     2  // Copyright 2014 The go-ethereum Authors
     3  // This file is part of the go-ethereum library.
     4  //
     5  // The go-ethereum library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Lesser General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // The go-ethereum library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    13  // GNU Lesser General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public License
    16  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    17  //
    18  // This file is derived from trie/iterator_test.go (2018/06/04).
    19  // Modified and improved for the klaytn development.
    20  
    21  package statedb
    22  
    23  import (
    24  	"bytes"
    25  	"fmt"
    26  	"math/rand"
    27  	"testing"
    28  
    29  	"github.com/klaytn/klaytn/common"
    30  	"github.com/klaytn/klaytn/storage/database"
    31  )
    32  
    33  func TestIterator(t *testing.T) {
    34  	trie := newEmptyTrie()
    35  	vals := []struct{ k, v string }{
    36  		{"do", "verb"},
    37  		{"klaytn", "wookiedoo"},
    38  		{"horse", "stallion"},
    39  		{"shaman", "horse"},
    40  		{"doge", "coin"},
    41  		{"dog", "puppy"},
    42  		{"somethingveryoddindeedthis is", "myothernodedata"},
    43  	}
    44  	all := make(map[string]string)
    45  	for _, val := range vals {
    46  		all[val.k] = val.v
    47  		trie.Update([]byte(val.k), []byte(val.v))
    48  	}
    49  	trie.Commit(nil)
    50  
    51  	found := make(map[string]string)
    52  	it := NewIterator(trie.NodeIterator(nil))
    53  	for it.Next() {
    54  		found[string(it.Key)] = string(it.Value)
    55  	}
    56  
    57  	for k, v := range all {
    58  		if found[k] != v {
    59  			t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v)
    60  		}
    61  	}
    62  }
    63  
    64  type kv struct {
    65  	k, v []byte
    66  	t    bool
    67  }
    68  
    69  func TestIteratorLargeData(t *testing.T) {
    70  	trie := newEmptyTrie()
    71  	vals := make(map[string]*kv)
    72  
    73  	for i := byte(0); i < 255; i++ {
    74  		value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
    75  		value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
    76  		trie.Update(value.k, value.v)
    77  		trie.Update(value2.k, value2.v)
    78  		vals[string(value.k)] = value
    79  		vals[string(value2.k)] = value2
    80  	}
    81  
    82  	it := NewIterator(trie.NodeIterator(nil))
    83  	for it.Next() {
    84  		vals[string(it.Key)].t = true
    85  	}
    86  
    87  	var untouched []*kv
    88  	for _, value := range vals {
    89  		if !value.t {
    90  			untouched = append(untouched, value)
    91  		}
    92  	}
    93  
    94  	if len(untouched) > 0 {
    95  		t.Errorf("Missed %d nodes", len(untouched))
    96  		for _, value := range untouched {
    97  			t.Error(value)
    98  		}
    99  	}
   100  }
   101  
   102  // Tests that the node iterator indeed walks over the entire database contents.
   103  func TestNodeIteratorCoverage(t *testing.T) {
   104  	// Create some arbitrary test trie to iterate
   105  	db, trie, _ := makeTestTrie()
   106  
   107  	// Gather all the node hashes found by the iterator
   108  	hashes := make(map[common.Hash]struct{})
   109  	for it := trie.NodeIterator(nil); it.Next(true); {
   110  		if it.Hash() != (common.Hash{}) {
   111  			hashes[it.Hash()] = struct{}{}
   112  		}
   113  	}
   114  	// Cross check the hashes and the database itself
   115  	for hash := range hashes {
   116  		if _, err := db.Node(hash); err != nil {
   117  			t.Errorf("failed to retrieve reported node %x: %v", hash, err)
   118  		}
   119  	}
   120  	for hash, obj := range db.nodes {
   121  		if obj != nil && hash != (common.Hash{}) {
   122  			if _, ok := hashes[hash]; !ok {
   123  				t.Errorf("state entry not reported %x", hash)
   124  			}
   125  		}
   126  	}
   127  	for _, key := range db.diskDB.GetMemDB().Keys() {
   128  		if _, ok := hashes[common.BytesToHash(key)]; !ok {
   129  			t.Errorf("state entry not reported %x", key)
   130  		}
   131  	}
   132  }
   133  
   134  type kvs struct{ k, v string }
   135  
   136  var testdata1 = []kvs{
   137  	{"barb", "ba"},
   138  	{"bard", "bc"},
   139  	{"bars", "bb"},
   140  	{"bar", "b"},
   141  	{"fab", "z"},
   142  	{"food", "ab"},
   143  	{"foos", "aa"},
   144  	{"foo", "a"},
   145  }
   146  
   147  var testdata2 = []kvs{
   148  	{"aardvark", "c"},
   149  	{"bar", "b"},
   150  	{"barb", "bd"},
   151  	{"bars", "be"},
   152  	{"fab", "z"},
   153  	{"foo", "a"},
   154  	{"foos", "aa"},
   155  	{"food", "ab"},
   156  	{"jars", "d"},
   157  }
   158  
   159  func TestIteratorSeek(t *testing.T) {
   160  	trie := newEmptyTrie()
   161  	for _, val := range testdata1 {
   162  		trie.Update([]byte(val.k), []byte(val.v))
   163  	}
   164  
   165  	// Seek to the middle.
   166  	it := NewIterator(trie.NodeIterator([]byte("fab")))
   167  	if err := checkIteratorOrder(testdata1[4:], it); err != nil {
   168  		t.Fatal(err)
   169  	}
   170  
   171  	// Seek to a non-existent key.
   172  	it = NewIterator(trie.NodeIterator([]byte("barc")))
   173  	if err := checkIteratorOrder(testdata1[1:], it); err != nil {
   174  		t.Fatal(err)
   175  	}
   176  
   177  	// Seek beyond the end.
   178  	it = NewIterator(trie.NodeIterator([]byte("z")))
   179  	if err := checkIteratorOrder(nil, it); err != nil {
   180  		t.Fatal(err)
   181  	}
   182  }
   183  
   184  func checkIteratorOrder(want []kvs, it *Iterator) error {
   185  	for it.Next() {
   186  		if len(want) == 0 {
   187  			return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
   188  		}
   189  		if !bytes.Equal(it.Key, []byte(want[0].k)) {
   190  			return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k)
   191  		}
   192  		want = want[1:]
   193  	}
   194  	if len(want) > 0 {
   195  		return fmt.Errorf("iterator ended early, want key %q", want[0])
   196  	}
   197  	return nil
   198  }
   199  
   200  func TestDifferenceIterator(t *testing.T) {
   201  	triea := newEmptyTrie()
   202  	for _, val := range testdata1 {
   203  		triea.Update([]byte(val.k), []byte(val.v))
   204  	}
   205  	triea.Commit(nil)
   206  
   207  	trieb := newEmptyTrie()
   208  	for _, val := range testdata2 {
   209  		trieb.Update([]byte(val.k), []byte(val.v))
   210  	}
   211  	trieb.Commit(nil)
   212  
   213  	found := make(map[string]string)
   214  	di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
   215  	it := NewIterator(di)
   216  	for it.Next() {
   217  		found[string(it.Key)] = string(it.Value)
   218  	}
   219  
   220  	all := []struct{ k, v string }{
   221  		{"aardvark", "c"},
   222  		{"barb", "bd"},
   223  		{"bars", "be"},
   224  		{"jars", "d"},
   225  	}
   226  	for _, item := range all {
   227  		if found[item.k] != item.v {
   228  			t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
   229  		}
   230  	}
   231  	if len(found) != len(all) {
   232  		t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
   233  	}
   234  }
   235  
   236  func TestUnionIterator(t *testing.T) {
   237  	triea := newEmptyTrie()
   238  	for _, val := range testdata1 {
   239  		triea.Update([]byte(val.k), []byte(val.v))
   240  	}
   241  	triea.Commit(nil)
   242  
   243  	trieb := newEmptyTrie()
   244  	for _, val := range testdata2 {
   245  		trieb.Update([]byte(val.k), []byte(val.v))
   246  	}
   247  	trieb.Commit(nil)
   248  
   249  	di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
   250  	it := NewIterator(di)
   251  
   252  	all := []struct{ k, v string }{
   253  		{"aardvark", "c"},
   254  		{"barb", "ba"},
   255  		{"barb", "bd"},
   256  		{"bard", "bc"},
   257  		{"bars", "bb"},
   258  		{"bars", "be"},
   259  		{"bar", "b"},
   260  		{"fab", "z"},
   261  		{"food", "ab"},
   262  		{"foos", "aa"},
   263  		{"foo", "a"},
   264  		{"jars", "d"},
   265  	}
   266  
   267  	for i, kv := range all {
   268  		if !it.Next() {
   269  			t.Errorf("Iterator ends prematurely at element %d", i)
   270  		}
   271  		if kv.k != string(it.Key) {
   272  			t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
   273  		}
   274  		if kv.v != string(it.Value) {
   275  			t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
   276  		}
   277  	}
   278  	if it.Next() {
   279  		t.Errorf("Iterator returned extra values.")
   280  	}
   281  }
   282  
   283  func TestIteratorNoDups(t *testing.T) {
   284  	var tr Trie
   285  	for _, val := range testdata1 {
   286  		tr.Update([]byte(val.k), []byte(val.v))
   287  	}
   288  	checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   289  }
   290  
   291  // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
   292  func TestIteratorContinueAfterErrorDisk(t *testing.T)    { testIteratorContinueAfterError(t, false) }
   293  func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
   294  
// testIteratorContinueAfterError builds a trie from testdata1 and repeats the
// following 20 times:
//   - remove one random non-root node from storage;
//   - iterate until the iterator reports a MissingNodeError for that exact node;
//   - put the node back and resume the same iterator.
//
// The resumed walk must finish without error and never visit a path twice. It
// must also visit as many distinct node paths in total as a clean walk of the
// intact trie.
//
// With memonly=true the trie is left in triedb's in-memory node map
// (triedb.nodes) and nodes are removed from there. Otherwise the trie is
// additionally committed via triedb.Commit, and nodes are removed from the
// memory DB returned by GetMemDB.
func testIteratorContinueAfterError(t *testing.T, memonly bool) {
	memDBManager := database.NewMemoryDBManager()
	diskdb := memDBManager.GetMemDB()
	triedb := NewDatabase(memDBManager)

	tr, _ := NewTrie(common.Hash{}, triedb)
	for _, val := range testdata1 {
		tr.Update([]byte(val.k), []byte(val.v))
	}
	tr.Commit(nil)
	if !memonly {
		// Disk variant: commit the trie's nodes through triedb so that diskdb
		// holds the keys sampled and deleted below.
		triedb.Commit(tr.Hash(), true, 0)
	}
	// Distinct node paths of a clean walk; each resumed walk must match it.
	wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)

	// Candidate keys for removal, taken from whichever store will be damaged.
	var (
		diskKeys [][]byte
		memKeys  []common.Hash
	)
	if memonly {
		memKeys = triedb.Nodes()
	} else {
		diskKeys = diskdb.Keys()
	}
	for i := 0; i < 20; i++ {
		// Create trie that will load all nodes from DB.
		// This tr shadows the outer one. The tr.Hash() on the right-hand side
		// still refers to the outer trie, so both share the same root.
		tr, _ := NewTrie(tr.Hash(), triedb)

		// Remove a random node from the database. It can't be the root node
		// because that one is already loaded.
		var (
			rkey common.Hash
			rval []byte
			robj *cachedNode
		)
		for {
			if memonly {
				rkey = memKeys[rand.Intn(len(memKeys))]
			} else {
				// NOTE(review): copy fills only len(key) bytes of rkey, so this
				// assumes every diskdb key is a 32-byte node hash — confirm.
				copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
			}
			if rkey != tr.Hash() {
				break
			}
		}
		// Take the node out, keeping what was stored so it can be restored.
		if memonly {
			robj = triedb.nodes[rkey]
			delete(triedb.nodes, rkey)
		} else {
			rval, _ = diskdb.Get(rkey[:])
			diskdb.Delete(rkey[:])
		}
		// Iterate until the error is hit.
		seen := make(map[string]bool)
		it := tr.NodeIterator(nil)
		checkIteratorNoDups(t, it, seen)
		missing, ok := it.Error().(*MissingNodeError)
		if !ok || missing.NodeHash != rkey {
			t.Fatal("didn't hit missing node, got", it.Error())
		}

		// Add the node back and continue iteration.
		if memonly {
			triedb.nodes[rkey] = robj
		} else {
			diskdb.Put(rkey[:], rval)
		}
		// Resume the same iterator with the same seen map, so paths visited
		// before and after the failure are checked together for duplicates.
		checkIteratorNoDups(t, it, seen)
		if it.Error() != nil {
			t.Fatal("unexpected error", it.Error())
		}
		if len(seen) != wantNodeCount {
			t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount)
		}
	}
}
   371  
   372  // Similar to the test above, this one checks that failure to create nodeIterator at a
   373  // certain key prefix behaves correctly when Next is called. The expectation is that Next
   374  // should retry seeking before returning true for the first time.
   375  func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) {
   376  	testIteratorContinueAfterSeekError(t, false)
   377  }
   378  
   379  func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
   380  	testIteratorContinueAfterSeekError(t, true)
   381  }
   382  
// testIteratorContinueAfterSeekError commits a trie built from testdata1 and
// removes the node whose hash is the hard-coded barNodeHash (the node holding
// "bars"). It then creates an iterator that seeks to "bars".
//
// Creating the iterator must report a MissingNodeError naming exactly that
// node. After the node is restored, the same iterator must yield the keys of
// testdata1[2:] ("bars" onward) in order.
//
// With memonly=true the node is removed from triedb's in-memory node map.
// Otherwise the trie is committed via triedb.Commit and the node is removed
// from the memory DB returned by GetMemDB.
func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
	// Commit test trie to db, then remove the node containing "bars".
	memDBManager := database.NewMemoryDBManager()
	diskdb := memDBManager.GetMemDB()
	triedb := NewDatabase(memDBManager)

	ctr, _ := NewTrie(common.Hash{}, triedb)
	for _, val := range testdata1 {
		ctr.Update([]byte(val.k), []byte(val.v))
	}
	root, _ := ctr.Commit(nil)
	if !memonly {
		triedb.Commit(root, true, 0)
	}
	// NOTE(review): this hash is tied to the exact node encoding of testdata1.
	// If it stops matching, nothing is removed below, no MissingNodeError is
	// raised, and the test fails at the type assertion.
	barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
	var (
		barNodeBlob []byte
		barNodeObj  *cachedNode
	)
	// Take the node out, keeping what was stored so it can be reinserted.
	if memonly {
		barNodeObj = triedb.nodes[barNodeHash]
		delete(triedb.nodes, barNodeHash)
	} else {
		barNodeBlob, _ = diskdb.Get(barNodeHash[:])
		diskdb.Delete(barNodeHash[:])
	}
	// Create a new iterator that seeks to "bars". Seeking can't proceed because
	// the node is missing.
	tr, _ := NewTrie(root, triedb)
	it := tr.NodeIterator([]byte("bars"))
	missing, ok := it.Error().(*MissingNodeError)
	if !ok {
		t.Fatal("want MissingNodeError, got", it.Error())
	} else if missing.NodeHash != barNodeHash {
		t.Fatal("wrong node missing")
	}
	// Reinsert the missing node.
	if memonly {
		triedb.nodes[barNodeHash] = barNodeObj
	} else {
		diskdb.Put(barNodeHash[:], barNodeBlob)
	}
	// Check that iteration produces the right set of values (checkIteratorOrder
	// compares keys only, in order).
	if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
		t.Fatal(err)
	}
}
   430  
   431  func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
   432  	if seen == nil {
   433  		seen = make(map[string]bool)
   434  	}
   435  	for it.Next(true) {
   436  		if seen[string(it.Path())] {
   437  			t.Fatalf("iterator visited node path %x twice", it.Path())
   438  		}
   439  		seen[string(it.Path())] = true
   440  	}
   441  	return len(seen)
   442  }