github.com/tacshi/go-ethereum@v0.0.0-20230616113857-84a434e20921/trie/iterator_test.go (about)

     1  // Copyright 2014 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package trie
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/binary"
    22  	"fmt"
    23  	"math/rand"
    24  	"testing"
    25  
    26  	"github.com/tacshi/go-ethereum/common"
    27  	"github.com/tacshi/go-ethereum/core/rawdb"
    28  	"github.com/tacshi/go-ethereum/crypto"
    29  	"github.com/tacshi/go-ethereum/ethdb"
    30  	"github.com/tacshi/go-ethereum/ethdb/memorydb"
    31  )
    32  
    33  func TestEmptyIterator(t *testing.T) {
    34  	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
    35  	iter := trie.NodeIterator(nil)
    36  
    37  	seen := make(map[string]struct{})
    38  	for iter.Next(true) {
    39  		seen[string(iter.Path())] = struct{}{}
    40  	}
    41  	if len(seen) != 0 {
    42  		t.Fatal("Unexpected trie node iterated")
    43  	}
    44  }
    45  
    46  func TestIterator(t *testing.T) {
    47  	db := NewDatabase(rawdb.NewMemoryDatabase())
    48  	trie := NewEmpty(db)
    49  	vals := []struct{ k, v string }{
    50  		{"do", "verb"},
    51  		{"ether", "wookiedoo"},
    52  		{"horse", "stallion"},
    53  		{"shaman", "horse"},
    54  		{"doge", "coin"},
    55  		{"dog", "puppy"},
    56  		{"somethingveryoddindeedthis is", "myothernodedata"},
    57  	}
    58  	all := make(map[string]string)
    59  	for _, val := range vals {
    60  		all[val.k] = val.v
    61  		trie.Update([]byte(val.k), []byte(val.v))
    62  	}
    63  	root, nodes := trie.Commit(false)
    64  	db.Update(NewWithNodeSet(nodes))
    65  
    66  	trie, _ = New(TrieID(root), db)
    67  	found := make(map[string]string)
    68  	it := NewIterator(trie.NodeIterator(nil))
    69  	for it.Next() {
    70  		found[string(it.Key)] = string(it.Value)
    71  	}
    72  
    73  	for k, v := range all {
    74  		if found[k] != v {
    75  			t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v)
    76  		}
    77  	}
    78  }
    79  
    80  type kv struct {
    81  	k, v []byte
    82  	t    bool
    83  }
    84  
    85  func TestIteratorLargeData(t *testing.T) {
    86  	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
    87  	vals := make(map[string]*kv)
    88  
    89  	for i := byte(0); i < 255; i++ {
    90  		value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
    91  		value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
    92  		trie.Update(value.k, value.v)
    93  		trie.Update(value2.k, value2.v)
    94  		vals[string(value.k)] = value
    95  		vals[string(value2.k)] = value2
    96  	}
    97  
    98  	it := NewIterator(trie.NodeIterator(nil))
    99  	for it.Next() {
   100  		vals[string(it.Key)].t = true
   101  	}
   102  
   103  	var untouched []*kv
   104  	for _, value := range vals {
   105  		if !value.t {
   106  			untouched = append(untouched, value)
   107  		}
   108  	}
   109  
   110  	if len(untouched) > 0 {
   111  		t.Errorf("Missed %d nodes", len(untouched))
   112  		for _, value := range untouched {
   113  			t.Error(value)
   114  		}
   115  	}
   116  }
   117  
   118  // Tests that the node iterator indeed walks over the entire database contents.
   119  func TestNodeIteratorCoverage(t *testing.T) {
   120  	// Create some arbitrary test trie to iterate
   121  	db, trie, _ := makeTestTrie()
   122  
   123  	// Gather all the node hashes found by the iterator
   124  	hashes := make(map[common.Hash]struct{})
   125  	for it := trie.NodeIterator(nil); it.Next(true); {
   126  		if it.Hash() != (common.Hash{}) {
   127  			hashes[it.Hash()] = struct{}{}
   128  		}
   129  	}
   130  	// Cross check the hashes and the database itself
   131  	for hash := range hashes {
   132  		if _, err := db.Node(hash); err != nil {
   133  			t.Errorf("failed to retrieve reported node %x: %v", hash, err)
   134  		}
   135  	}
   136  	for hash, obj := range db.dirties {
   137  		if obj != nil && hash != (common.Hash{}) {
   138  			if _, ok := hashes[hash]; !ok {
   139  				t.Errorf("state entry not reported %x", hash)
   140  			}
   141  		}
   142  	}
   143  	it := db.diskdb.NewIterator(nil, nil)
   144  	for it.Next() {
   145  		key := it.Key()
   146  		if _, ok := hashes[common.BytesToHash(key)]; !ok {
   147  			t.Errorf("state entry not reported %x", key)
   148  		}
   149  	}
   150  	it.Release()
   151  }
   152  
   153  type kvs struct{ k, v string }
   154  
   155  var testdata1 = []kvs{
   156  	{"barb", "ba"},
   157  	{"bard", "bc"},
   158  	{"bars", "bb"},
   159  	{"bar", "b"},
   160  	{"fab", "z"},
   161  	{"food", "ab"},
   162  	{"foos", "aa"},
   163  	{"foo", "a"},
   164  }
   165  
   166  var testdata2 = []kvs{
   167  	{"aardvark", "c"},
   168  	{"bar", "b"},
   169  	{"barb", "bd"},
   170  	{"bars", "be"},
   171  	{"fab", "z"},
   172  	{"foo", "a"},
   173  	{"foos", "aa"},
   174  	{"food", "ab"},
   175  	{"jars", "d"},
   176  }
   177  
   178  func TestIteratorSeek(t *testing.T) {
   179  	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
   180  	for _, val := range testdata1 {
   181  		trie.Update([]byte(val.k), []byte(val.v))
   182  	}
   183  
   184  	// Seek to the middle.
   185  	it := NewIterator(trie.NodeIterator([]byte("fab")))
   186  	if err := checkIteratorOrder(testdata1[4:], it); err != nil {
   187  		t.Fatal(err)
   188  	}
   189  
   190  	// Seek to a non-existent key.
   191  	it = NewIterator(trie.NodeIterator([]byte("barc")))
   192  	if err := checkIteratorOrder(testdata1[1:], it); err != nil {
   193  		t.Fatal(err)
   194  	}
   195  
   196  	// Seek beyond the end.
   197  	it = NewIterator(trie.NodeIterator([]byte("z")))
   198  	if err := checkIteratorOrder(nil, it); err != nil {
   199  		t.Fatal(err)
   200  	}
   201  }
   202  
   203  func checkIteratorOrder(want []kvs, it *Iterator) error {
   204  	for it.Next() {
   205  		if len(want) == 0 {
   206  			return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
   207  		}
   208  		if !bytes.Equal(it.Key, []byte(want[0].k)) {
   209  			return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k)
   210  		}
   211  		want = want[1:]
   212  	}
   213  	if len(want) > 0 {
   214  		return fmt.Errorf("iterator ended early, want key %q", want[0])
   215  	}
   216  	return nil
   217  }
   218  
   219  func TestDifferenceIterator(t *testing.T) {
   220  	dba := NewDatabase(rawdb.NewMemoryDatabase())
   221  	triea := NewEmpty(dba)
   222  	for _, val := range testdata1 {
   223  		triea.Update([]byte(val.k), []byte(val.v))
   224  	}
   225  	rootA, nodesA := triea.Commit(false)
   226  	dba.Update(NewWithNodeSet(nodesA))
   227  	triea, _ = New(TrieID(rootA), dba)
   228  
   229  	dbb := NewDatabase(rawdb.NewMemoryDatabase())
   230  	trieb := NewEmpty(dbb)
   231  	for _, val := range testdata2 {
   232  		trieb.Update([]byte(val.k), []byte(val.v))
   233  	}
   234  	rootB, nodesB := trieb.Commit(false)
   235  	dbb.Update(NewWithNodeSet(nodesB))
   236  	trieb, _ = New(TrieID(rootB), dbb)
   237  
   238  	found := make(map[string]string)
   239  	di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
   240  	it := NewIterator(di)
   241  	for it.Next() {
   242  		found[string(it.Key)] = string(it.Value)
   243  	}
   244  
   245  	all := []struct{ k, v string }{
   246  		{"aardvark", "c"},
   247  		{"barb", "bd"},
   248  		{"bars", "be"},
   249  		{"jars", "d"},
   250  	}
   251  	for _, item := range all {
   252  		if found[item.k] != item.v {
   253  			t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
   254  		}
   255  	}
   256  	if len(found) != len(all) {
   257  		t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
   258  	}
   259  }
   260  
   261  func TestUnionIterator(t *testing.T) {
   262  	dba := NewDatabase(rawdb.NewMemoryDatabase())
   263  	triea := NewEmpty(dba)
   264  	for _, val := range testdata1 {
   265  		triea.Update([]byte(val.k), []byte(val.v))
   266  	}
   267  	rootA, nodesA := triea.Commit(false)
   268  	dba.Update(NewWithNodeSet(nodesA))
   269  	triea, _ = New(TrieID(rootA), dba)
   270  
   271  	dbb := NewDatabase(rawdb.NewMemoryDatabase())
   272  	trieb := NewEmpty(dbb)
   273  	for _, val := range testdata2 {
   274  		trieb.Update([]byte(val.k), []byte(val.v))
   275  	}
   276  	rootB, nodesB := trieb.Commit(false)
   277  	dbb.Update(NewWithNodeSet(nodesB))
   278  	trieb, _ = New(TrieID(rootB), dbb)
   279  
   280  	di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
   281  	it := NewIterator(di)
   282  
   283  	all := []struct{ k, v string }{
   284  		{"aardvark", "c"},
   285  		{"barb", "ba"},
   286  		{"barb", "bd"},
   287  		{"bard", "bc"},
   288  		{"bars", "bb"},
   289  		{"bars", "be"},
   290  		{"bar", "b"},
   291  		{"fab", "z"},
   292  		{"food", "ab"},
   293  		{"foos", "aa"},
   294  		{"foo", "a"},
   295  		{"jars", "d"},
   296  	}
   297  
   298  	for i, kv := range all {
   299  		if !it.Next() {
   300  			t.Errorf("Iterator ends prematurely at element %d", i)
   301  		}
   302  		if kv.k != string(it.Key) {
   303  			t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
   304  		}
   305  		if kv.v != string(it.Value) {
   306  			t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
   307  		}
   308  	}
   309  	if it.Next() {
   310  		t.Errorf("Iterator returned extra values.")
   311  	}
   312  }
   313  
   314  func TestIteratorNoDups(t *testing.T) {
   315  	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
   316  	for _, val := range testdata1 {
   317  		tr.Update([]byte(val.k), []byte(val.v))
   318  	}
   319  	checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   320  }
   321  
   322  // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
   323  func TestIteratorContinueAfterErrorDisk(t *testing.T)    { testIteratorContinueAfterError(t, false) }
   324  func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
   325  
   326  func testIteratorContinueAfterError(t *testing.T, memonly bool) {
   327  	diskdb := rawdb.NewMemoryDatabase()
   328  	triedb := NewDatabase(diskdb)
   329  
   330  	tr := NewEmpty(triedb)
   331  	for _, val := range testdata1 {
   332  		tr.Update([]byte(val.k), []byte(val.v))
   333  	}
   334  	_, nodes := tr.Commit(false)
   335  	triedb.Update(NewWithNodeSet(nodes))
   336  	if !memonly {
   337  		triedb.Commit(tr.Hash(), false)
   338  	}
   339  	wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   340  
   341  	var (
   342  		diskKeys [][]byte
   343  		memKeys  []common.Hash
   344  	)
   345  	if memonly {
   346  		memKeys = triedb.Nodes()
   347  	} else {
   348  		it := diskdb.NewIterator(nil, nil)
   349  		for it.Next() {
   350  			diskKeys = append(diskKeys, it.Key())
   351  		}
   352  		it.Release()
   353  	}
   354  	for i := 0; i < 20; i++ {
   355  		// Create trie that will load all nodes from DB.
   356  		tr, _ := New(TrieID(tr.Hash()), triedb)
   357  
   358  		// Remove a random node from the database. It can't be the root node
   359  		// because that one is already loaded.
   360  		var (
   361  			rkey common.Hash
   362  			rval []byte
   363  			robj *cachedNode
   364  		)
   365  		for {
   366  			if memonly {
   367  				rkey = memKeys[rand.Intn(len(memKeys))]
   368  			} else {
   369  				copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
   370  			}
   371  			if rkey != tr.Hash() {
   372  				break
   373  			}
   374  		}
   375  		if memonly {
   376  			robj = triedb.dirties[rkey]
   377  			delete(triedb.dirties, rkey)
   378  		} else {
   379  			rval, _ = diskdb.Get(rkey[:])
   380  			diskdb.Delete(rkey[:])
   381  		}
   382  		// Iterate until the error is hit.
   383  		seen := make(map[string]bool)
   384  		it := tr.NodeIterator(nil)
   385  		checkIteratorNoDups(t, it, seen)
   386  		missing, ok := it.Error().(*MissingNodeError)
   387  		if !ok || missing.NodeHash != rkey {
   388  			t.Fatal("didn't hit missing node, got", it.Error())
   389  		}
   390  
   391  		// Add the node back and continue iteration.
   392  		if memonly {
   393  			triedb.dirties[rkey] = robj
   394  		} else {
   395  			diskdb.Put(rkey[:], rval)
   396  		}
   397  		checkIteratorNoDups(t, it, seen)
   398  		if it.Error() != nil {
   399  			t.Fatal("unexpected error", it.Error())
   400  		}
   401  		if len(seen) != wantNodeCount {
   402  			t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount)
   403  		}
   404  	}
   405  }
   406  
   407  // Similar to the test above, this one checks that failure to create nodeIterator at a
   408  // certain key prefix behaves correctly when Next is called. The expectation is that Next
   409  // should retry seeking before returning true for the first time.
   410  func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) {
   411  	testIteratorContinueAfterSeekError(t, false)
   412  }
   413  func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
   414  	testIteratorContinueAfterSeekError(t, true)
   415  }
   416  
   417  func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
   418  	// Commit test trie to db, then remove the node containing "bars".
   419  	diskdb := rawdb.NewMemoryDatabase()
   420  	triedb := NewDatabase(diskdb)
   421  
   422  	ctr := NewEmpty(triedb)
   423  	for _, val := range testdata1 {
   424  		ctr.Update([]byte(val.k), []byte(val.v))
   425  	}
   426  	root, nodes := ctr.Commit(false)
   427  	triedb.Update(NewWithNodeSet(nodes))
   428  	if !memonly {
   429  		triedb.Commit(root, false)
   430  	}
   431  	barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
   432  	var (
   433  		barNodeBlob []byte
   434  		barNodeObj  *cachedNode
   435  	)
   436  	if memonly {
   437  		barNodeObj = triedb.dirties[barNodeHash]
   438  		delete(triedb.dirties, barNodeHash)
   439  	} else {
   440  		barNodeBlob, _ = diskdb.Get(barNodeHash[:])
   441  		diskdb.Delete(barNodeHash[:])
   442  	}
   443  	// Create a new iterator that seeks to "bars". Seeking can't proceed because
   444  	// the node is missing.
   445  	tr, _ := New(TrieID(root), triedb)
   446  	it := tr.NodeIterator([]byte("bars"))
   447  	missing, ok := it.Error().(*MissingNodeError)
   448  	if !ok {
   449  		t.Fatal("want MissingNodeError, got", it.Error())
   450  	} else if missing.NodeHash != barNodeHash {
   451  		t.Fatal("wrong node missing")
   452  	}
   453  	// Reinsert the missing node.
   454  	if memonly {
   455  		triedb.dirties[barNodeHash] = barNodeObj
   456  	} else {
   457  		diskdb.Put(barNodeHash[:], barNodeBlob)
   458  	}
   459  	// Check that iteration produces the right set of values.
   460  	if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
   461  		t.Fatal(err)
   462  	}
   463  }
   464  
   465  func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
   466  	if seen == nil {
   467  		seen = make(map[string]bool)
   468  	}
   469  	for it.Next(true) {
   470  		if seen[string(it.Path())] {
   471  			t.Fatalf("iterator visited node path %x twice", it.Path())
   472  		}
   473  		seen[string(it.Path())] = true
   474  	}
   475  	return len(seen)
   476  }
   477  
   478  type loggingDb struct {
   479  	getCount uint64
   480  	backend  ethdb.KeyValueStore
   481  }
   482  
   483  func (l *loggingDb) Has(key []byte) (bool, error) {
   484  	return l.backend.Has(key)
   485  }
   486  
   487  func (l *loggingDb) Get(key []byte) ([]byte, error) {
   488  	l.getCount++
   489  	return l.backend.Get(key)
   490  }
   491  
   492  func (l *loggingDb) Put(key []byte, value []byte) error {
   493  	return l.backend.Put(key, value)
   494  }
   495  
   496  func (l *loggingDb) Delete(key []byte) error {
   497  	return l.backend.Delete(key)
   498  }
   499  
   500  func (l *loggingDb) NewBatch() ethdb.Batch {
   501  	return l.backend.NewBatch()
   502  }
   503  
   504  func (l *loggingDb) NewBatchWithSize(size int) ethdb.Batch {
   505  	return l.backend.NewBatchWithSize(size)
   506  }
   507  
   508  func (l *loggingDb) NewIterator(prefix []byte, start []byte) ethdb.Iterator {
   509  	return l.backend.NewIterator(prefix, start)
   510  }
   511  
   512  func (l *loggingDb) NewSnapshot() (ethdb.Snapshot, error) {
   513  	return l.backend.NewSnapshot()
   514  }
   515  
   516  func (l *loggingDb) Stat(property string) (string, error) {
   517  	return l.backend.Stat(property)
   518  }
   519  
   520  func (l *loggingDb) Compact(start []byte, limit []byte) error {
   521  	return l.backend.Compact(start, limit)
   522  }
   523  
   524  func (l *loggingDb) Close() error {
   525  	return l.backend.Close()
   526  }
   527  
   528  // makeLargeTestTrie create a sample test trie
   529  func makeLargeTestTrie() (*Database, *StateTrie, *loggingDb) {
   530  	// Create an empty trie
   531  	logDb := &loggingDb{0, memorydb.New()}
   532  	triedb := NewDatabase(rawdb.NewDatabase(logDb))
   533  	trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb)
   534  
   535  	// Fill it with some arbitrary data
   536  	for i := 0; i < 10000; i++ {
   537  		key := make([]byte, 32)
   538  		val := make([]byte, 32)
   539  		binary.BigEndian.PutUint64(key, uint64(i))
   540  		binary.BigEndian.PutUint64(val, uint64(i))
   541  		key = crypto.Keccak256(key)
   542  		val = crypto.Keccak256(val)
   543  		trie.Update(key, val)
   544  	}
   545  	_, nodes := trie.Commit(false)
   546  	triedb.Update(NewWithNodeSet(nodes))
   547  	// Return the generated trie
   548  	return triedb, trie, logDb
   549  }
   550  
   551  // Tests that the node iterator indeed walks over the entire database contents.
   552  func TestNodeIteratorLargeTrie(t *testing.T) {
   553  	// Create some arbitrary test trie to iterate
   554  	db, trie, logDb := makeLargeTestTrie()
   555  	db.Cap(0) // flush everything
   556  	// Do a seek operation
   557  	trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885"))
   558  	// master: 24 get operations
   559  	// this pr: 5 get operations
   560  	if have, want := logDb.getCount, uint64(5); have != want {
   561  		t.Fatalf("Too many lookups during seek, have %d want %d", have, want)
   562  	}
   563  }
   564  
   565  func TestIteratorNodeBlob(t *testing.T) {
   566  	var (
   567  		db     = rawdb.NewMemoryDatabase()
   568  		triedb = NewDatabase(db)
   569  		trie   = NewEmpty(triedb)
   570  	)
   571  	vals := []struct{ k, v string }{
   572  		{"do", "verb"},
   573  		{"ether", "wookiedoo"},
   574  		{"horse", "stallion"},
   575  		{"shaman", "horse"},
   576  		{"doge", "coin"},
   577  		{"dog", "puppy"},
   578  		{"somethingveryoddindeedthis is", "myothernodedata"},
   579  	}
   580  	all := make(map[string]string)
   581  	for _, val := range vals {
   582  		all[val.k] = val.v
   583  		trie.Update([]byte(val.k), []byte(val.v))
   584  	}
   585  	_, nodes := trie.Commit(false)
   586  	triedb.Update(NewWithNodeSet(nodes))
   587  	triedb.Cap(0)
   588  
   589  	found := make(map[common.Hash][]byte)
   590  	it := trie.NodeIterator(nil)
   591  	for it.Next(true) {
   592  		if it.Hash() == (common.Hash{}) {
   593  			continue
   594  		}
   595  		found[it.Hash()] = it.NodeBlob()
   596  	}
   597  
   598  	dbIter := db.NewIterator(nil, nil)
   599  	defer dbIter.Release()
   600  
   601  	var count int
   602  	for dbIter.Next() {
   603  		got, present := found[common.BytesToHash(dbIter.Key())]
   604  		if !present {
   605  			t.Fatalf("Miss trie node %v", dbIter.Key())
   606  		}
   607  		if !bytes.Equal(got, dbIter.Value()) {
   608  			t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got)
   609  		}
   610  		count += 1
   611  	}
   612  	if count != len(found) {
   613  		t.Fatal("Find extra trie node via iterator")
   614  	}
   615  }