github.com/klaytn/klaytn@v1.12.1/storage/statedb/iterator_test.go (about)

     1  // Modifications Copyright 2018 The klaytn Authors
     2  // Copyright 2014 The go-ethereum Authors
     3  // This file is part of the go-ethereum library.
     4  //
     5  // The go-ethereum library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Lesser General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // The go-ethereum library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    13  // GNU Lesser General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public License
    16  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    17  //
    18  // This file is derived from trie/iterator_test.go (2018/06/04).
    19  // Modified and improved for the klaytn development.
    20  
    21  package statedb
    22  
    23  import (
    24  	"bytes"
    25  	"fmt"
    26  	"math/rand"
    27  	"testing"
    28  
    29  	"github.com/klaytn/klaytn/common"
    30  	"github.com/klaytn/klaytn/storage/database"
    31  )
    32  
    33  func TestIterator(t *testing.T) {
    34  	trie := newEmptyTrie()
    35  	vals := []struct{ k, v string }{
    36  		{"do", "verb"},
    37  		{"klaytn", "wookiedoo"},
    38  		{"horse", "stallion"},
    39  		{"shaman", "horse"},
    40  		{"doge", "coin"},
    41  		{"dog", "puppy"},
    42  		{"somethingveryoddindeedthis is", "myothernodedata"},
    43  	}
    44  	all := make(map[string]string)
    45  	for _, val := range vals {
    46  		all[val.k] = val.v
    47  		trie.Update([]byte(val.k), []byte(val.v))
    48  	}
    49  	trie.Commit(nil)
    50  
    51  	found := make(map[string]string)
    52  	it := NewIterator(trie.NodeIterator(nil))
    53  	for it.Next() {
    54  		found[string(it.Key)] = string(it.Value)
    55  	}
    56  
    57  	for k, v := range all {
    58  		if found[k] != v {
    59  			t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v)
    60  		}
    61  	}
    62  }
    63  
    64  type kv struct {
    65  	k, v []byte
    66  	t    bool
    67  }
    68  
    69  func TestIteratorLargeData(t *testing.T) {
    70  	trie := newEmptyTrie()
    71  	vals := make(map[string]*kv)
    72  
    73  	for i := byte(0); i < 255; i++ {
    74  		value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
    75  		value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
    76  		trie.Update(value.k, value.v)
    77  		trie.Update(value2.k, value2.v)
    78  		vals[string(value.k)] = value
    79  		vals[string(value2.k)] = value2
    80  	}
    81  
    82  	it := NewIterator(trie.NodeIterator(nil))
    83  	for it.Next() {
    84  		vals[string(it.Key)].t = true
    85  	}
    86  
    87  	var untouched []*kv
    88  	for _, value := range vals {
    89  		if !value.t {
    90  			untouched = append(untouched, value)
    91  		}
    92  	}
    93  
    94  	if len(untouched) > 0 {
    95  		t.Errorf("Missed %d nodes", len(untouched))
    96  		for _, value := range untouched {
    97  			t.Error(value)
    98  		}
    99  	}
   100  }
   101  
   102  // Tests that the node iterator indeed walks over the entire database contents.
   103  func TestNodeIteratorCoverage(t *testing.T) {
   104  	// Create some arbitrary test trie to iterate
   105  	db, trie, _ := makeTestTrie()
   106  
   107  	// Gather all the node hashes found by the iterator
   108  	iterated := make(map[common.Hash]struct{})
   109  	for it := trie.NodeIterator(nil); it.Next(true); {
   110  		if it.Hash() != (common.Hash{}) {
   111  			iterated[it.Hash()] = struct{}{}
   112  		}
   113  	}
   114  
   115  	// Cross check the hashes and the database itself
   116  	// db.Node contains all iterated hashes
   117  	for itHash := range iterated {
   118  		if _, err := db.Node(itHash.ExtendZero()); err != nil {
   119  			t.Errorf("failed to retrieve reported node %x: %v", itHash, err)
   120  		}
   121  	}
   122  	// iterated hashes contains all from db.nodes
   123  	for exthash, obj := range db.nodes {
   124  		hash := exthash.Unextend()
   125  		if obj == nil || common.EmptyExtHash(exthash) {
   126  			continue // skip empty entry
   127  		}
   128  		if _, ok := iterated[hash]; !ok {
   129  			t.Errorf("state entry not reported %x", hash)
   130  		}
   131  	}
   132  	// iterated hashes contains all diskDB keys
   133  	db.Cap(0) // flush to diskDB
   134  	for _, key := range db.diskDB.GetMemDB().Keys() {
   135  		hash := common.BytesToExtHash(key).Unextend()
   136  		if _, ok := iterated[hash]; !ok {
   137  			t.Errorf("state entry not reported %x", hash)
   138  		}
   139  	}
   140  }
   141  
   142  // NodeIterator yields exact same result for Trie and StroageTrie
   143  func TestNodeIteratorStorageTrie(t *testing.T) {
   144  	dbm := database.NewMemoryDBManager()
   145  	dbm.WritePruningEnabled()
   146  	triedb := NewDatabase(dbm)
   147  
   148  	trie1, _ := NewTrie(common.Hash{}, triedb, nil)
   149  	hashes1 := make(map[common.Hash]struct{})
   150  	for it := trie1.NodeIterator(nil); it.Next(true); {
   151  		hashes1[it.Hash()] = struct{}{}
   152  	}
   153  
   154  	trie2, _ := NewStorageTrie(common.ExtHash{}, triedb, nil)
   155  	hashes2 := make(map[common.Hash]struct{})
   156  	for it := trie2.NodeIterator(nil); it.Next(true); {
   157  		hashes2[it.Hash()] = struct{}{}
   158  	}
   159  
   160  	for hash := range hashes1 {
   161  		if _, ok := hashes2[hash]; !ok {
   162  			t.Errorf("state entry not reported %x", hash)
   163  		}
   164  	}
   165  	for hash := range hashes2 {
   166  		if _, ok := hashes1[hash]; !ok {
   167  			t.Errorf("state entry not reported %x", hash)
   168  		}
   169  	}
   170  }
   171  
   172  type kvs struct{ k, v string }
   173  
   174  var testdata1 = []kvs{
   175  	{"barb", "ba"},
   176  	{"bard", "bc"},
   177  	{"bars", "bb"},
   178  	{"bar", "b"},
   179  	{"fab", "z"},
   180  	{"food", "ab"},
   181  	{"foos", "aa"},
   182  	{"foo", "a"},
   183  }
   184  
   185  var testdata2 = []kvs{
   186  	{"aardvark", "c"},
   187  	{"bar", "b"},
   188  	{"barb", "bd"},
   189  	{"bars", "be"},
   190  	{"fab", "z"},
   191  	{"foo", "a"},
   192  	{"foos", "aa"},
   193  	{"food", "ab"},
   194  	{"jars", "d"},
   195  }
   196  
   197  func TestIteratorSeek(t *testing.T) {
   198  	trie := newEmptyTrie()
   199  	for _, val := range testdata1 {
   200  		trie.Update([]byte(val.k), []byte(val.v))
   201  	}
   202  
   203  	// Seek to the middle.
   204  	it := NewIterator(trie.NodeIterator([]byte("fab")))
   205  	if err := checkIteratorOrder(testdata1[4:], it); err != nil {
   206  		t.Fatal(err)
   207  	}
   208  
   209  	// Seek to a non-existent key.
   210  	it = NewIterator(trie.NodeIterator([]byte("barc")))
   211  	if err := checkIteratorOrder(testdata1[1:], it); err != nil {
   212  		t.Fatal(err)
   213  	}
   214  
   215  	// Seek beyond the end.
   216  	it = NewIterator(trie.NodeIterator([]byte("z")))
   217  	if err := checkIteratorOrder(nil, it); err != nil {
   218  		t.Fatal(err)
   219  	}
   220  }
   221  
   222  func checkIteratorOrder(want []kvs, it *Iterator) error {
   223  	for it.Next() {
   224  		if len(want) == 0 {
   225  			return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
   226  		}
   227  		if !bytes.Equal(it.Key, []byte(want[0].k)) {
   228  			return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k)
   229  		}
   230  		want = want[1:]
   231  	}
   232  	if len(want) > 0 {
   233  		return fmt.Errorf("iterator ended early, want key %q", want[0])
   234  	}
   235  	return nil
   236  }
   237  
   238  func TestDifferenceIterator(t *testing.T) {
   239  	triea := newEmptyTrie()
   240  	for _, val := range testdata1 {
   241  		triea.Update([]byte(val.k), []byte(val.v))
   242  	}
   243  	triea.Commit(nil)
   244  
   245  	trieb := newEmptyTrie()
   246  	for _, val := range testdata2 {
   247  		trieb.Update([]byte(val.k), []byte(val.v))
   248  	}
   249  	trieb.Commit(nil)
   250  
   251  	found := make(map[string]string)
   252  	di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
   253  	it := NewIterator(di)
   254  	for it.Next() {
   255  		found[string(it.Key)] = string(it.Value)
   256  	}
   257  
   258  	all := []struct{ k, v string }{
   259  		{"aardvark", "c"},
   260  		{"barb", "bd"},
   261  		{"bars", "be"},
   262  		{"jars", "d"},
   263  	}
   264  	for _, item := range all {
   265  		if found[item.k] != item.v {
   266  			t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
   267  		}
   268  	}
   269  	if len(found) != len(all) {
   270  		t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
   271  	}
   272  }
   273  
   274  func TestUnionIterator(t *testing.T) {
   275  	triea := newEmptyTrie()
   276  	for _, val := range testdata1 {
   277  		triea.Update([]byte(val.k), []byte(val.v))
   278  	}
   279  	triea.Commit(nil)
   280  
   281  	trieb := newEmptyTrie()
   282  	for _, val := range testdata2 {
   283  		trieb.Update([]byte(val.k), []byte(val.v))
   284  	}
   285  	trieb.Commit(nil)
   286  
   287  	di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
   288  	it := NewIterator(di)
   289  
   290  	all := []struct{ k, v string }{
   291  		{"aardvark", "c"},
   292  		{"barb", "ba"},
   293  		{"barb", "bd"},
   294  		{"bard", "bc"},
   295  		{"bars", "bb"},
   296  		{"bars", "be"},
   297  		{"bar", "b"},
   298  		{"fab", "z"},
   299  		{"food", "ab"},
   300  		{"foos", "aa"},
   301  		{"foo", "a"},
   302  		{"jars", "d"},
   303  	}
   304  
   305  	for i, kv := range all {
   306  		if !it.Next() {
   307  			t.Errorf("Iterator ends prematurely at element %d", i)
   308  		}
   309  		if kv.k != string(it.Key) {
   310  			t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
   311  		}
   312  		if kv.v != string(it.Value) {
   313  			t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
   314  		}
   315  	}
   316  	if it.Next() {
   317  		t.Errorf("Iterator returned extra values.")
   318  	}
   319  }
   320  
   321  func TestIteratorNoDups(t *testing.T) {
   322  	var tr Trie
   323  	for _, val := range testdata1 {
   324  		tr.Update([]byte(val.k), []byte(val.v))
   325  	}
   326  	checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   327  }
   328  
   329  // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
   330  func TestIteratorContinueAfterErrorDisk(t *testing.T)    { testIteratorContinueAfterError(t, false) }
   331  func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
   332  
   333  func testIteratorContinueAfterError(t *testing.T, memonly bool) {
   334  	dbm := database.NewMemoryDBManager()
   335  	diskdb := dbm.GetMemDB()
   336  	triedb := NewDatabase(dbm)
   337  
   338  	tr, _ := NewTrie(common.Hash{}, triedb, nil)
   339  	for _, val := range testdata1 {
   340  		tr.Update([]byte(val.k), []byte(val.v))
   341  	}
   342  	tr.Commit(nil)
   343  	if !memonly {
   344  		triedb.Commit(tr.Hash(), true, 0)
   345  	}
   346  	wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   347  
   348  	var (
   349  		diskKeys [][]byte
   350  		memKeys  []common.ExtHash
   351  	)
   352  	if memonly {
   353  		memKeys = triedb.Nodes()
   354  	} else {
   355  		diskKeys = diskdb.Keys()
   356  	}
   357  	for i := 0; i < 20; i++ {
   358  		// Create trie that will load all nodes from DB.
   359  		tr, _ := NewTrie(tr.Hash(), triedb, nil)
   360  
   361  		// Remove a random node from the database. It can't be the root node
   362  		// because that one is already loaded.
   363  		var (
   364  			rkey     common.Hash
   365  			nodehash common.ExtHash
   366  			rval     []byte
   367  			robj     *cachedNode
   368  		)
   369  		for {
   370  			if memonly {
   371  				idx := rand.Intn(len(memKeys))
   372  				nodehash = memKeys[idx]
   373  			} else {
   374  				idx := rand.Intn(len(diskKeys))
   375  				nodehash = common.BytesToExtHash(diskKeys[idx])
   376  			}
   377  			rkey = nodehash.Unextend()
   378  			if rkey != tr.Hash() {
   379  				break
   380  			}
   381  		}
   382  
   383  		if memonly {
   384  			robj = triedb.nodes[nodehash]
   385  			delete(triedb.nodes, nodehash)
   386  		} else {
   387  			rval, _ = dbm.ReadTrieNode(nodehash)
   388  			dbm.DeleteTrieNode(nodehash)
   389  		}
   390  		// Iterate until the error is hit.
   391  		seen := make(map[string]bool)
   392  		it := tr.NodeIterator(nil)
   393  		checkIteratorNoDups(t, it, seen)
   394  		missing, ok := it.Error().(*MissingNodeError)
   395  		if !ok || missing.NodeHash != rkey {
   396  			t.Fatal("didn't hit missing node, got", it.Error())
   397  		}
   398  
   399  		// Add the node back and continue iteration.
   400  		if memonly {
   401  			triedb.nodes[nodehash] = robj
   402  		} else {
   403  			dbm.WriteTrieNode(nodehash, rval)
   404  		}
   405  		checkIteratorNoDups(t, it, seen)
   406  		if it.Error() != nil {
   407  			t.Fatal("unexpected error", it.Error())
   408  		}
   409  		if len(seen) != wantNodeCount {
   410  			t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount)
   411  		}
   412  	}
   413  }
   414  
   415  // Similar to the test above, this one checks that failure to create nodeIterator at a
   416  // certain key prefix behaves correctly when Next is called. The expectation is that Next
   417  // should retry seeking before returning true for the first time.
   418  func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) {
   419  	testIteratorContinueAfterSeekError(t, false)
   420  }
   421  
   422  func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
   423  	testIteratorContinueAfterSeekError(t, true)
   424  }
   425  
   426  func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
   427  	// Commit test trie to db, then remove the node containing "bars".
   428  	dbm := database.NewMemoryDBManager()
   429  	triedb := NewDatabase(dbm)
   430  
   431  	ctr, _ := NewTrie(common.Hash{}, triedb, nil)
   432  	for _, val := range testdata1 {
   433  		ctr.Update([]byte(val.k), []byte(val.v))
   434  	}
   435  	root, _ := ctr.Commit(nil)
   436  	if !memonly {
   437  		triedb.Commit(root, true, 0)
   438  	}
   439  	barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
   440  	nodehash := barNodeHash.ExtendZero()
   441  	var (
   442  		barNodeBlob []byte
   443  		barNodeObj  *cachedNode
   444  	)
   445  	if memonly {
   446  		barNodeObj = triedb.nodes[nodehash]
   447  		delete(triedb.nodes, nodehash)
   448  	} else {
   449  		barNodeBlob, _ = dbm.ReadTrieNode(nodehash)
   450  		dbm.DeleteTrieNode(nodehash)
   451  	}
   452  	// Create a new iterator that seeks to "bars". Seeking can't proceed because
   453  	// the node is missing.
   454  	tr, _ := NewTrie(root, triedb, nil)
   455  	it := tr.NodeIterator([]byte("bars"))
   456  	missing, ok := it.Error().(*MissingNodeError)
   457  	if !ok {
   458  		t.Fatal("want MissingNodeError, got", it.Error())
   459  	} else if missing.NodeHash != barNodeHash {
   460  		t.Fatal("wrong node missing")
   461  	}
   462  	// Reinsert the missing node.
   463  	if memonly {
   464  		triedb.nodes[nodehash] = barNodeObj
   465  	} else {
   466  		dbm.WriteTrieNode(nodehash, barNodeBlob)
   467  	}
   468  	// Check that iteration produces the right set of values.
   469  	if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
   470  		t.Fatal(err)
   471  	}
   472  }
   473  
   474  func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
   475  	if seen == nil {
   476  		seen = make(map[string]bool)
   477  	}
   478  	for it.Next(true) {
   479  		if seen[string(it.Path())] {
   480  			t.Fatalf("iterator visited node path %x twice", it.Path())
   481  		}
   482  		seen[string(it.Path())] = true
   483  	}
   484  	return len(seen)
   485  }