github.com/carter-ya/go-ethereum@v0.0.0-20230628080049-d2309be3983b/trie/iterator_test.go (about)

     1  // Copyright 2014 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package trie
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/binary"
    22  	"fmt"
    23  	"math/rand"
    24  	"testing"
    25  
    26  	"github.com/ethereum/go-ethereum/common"
    27  	"github.com/ethereum/go-ethereum/core/rawdb"
    28  	"github.com/ethereum/go-ethereum/crypto"
    29  	"github.com/ethereum/go-ethereum/ethdb"
    30  	"github.com/ethereum/go-ethereum/ethdb/memorydb"
    31  )
    32  
    33  func TestEmptyIterator(t *testing.T) {
    34  	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
    35  	iter := trie.NodeIterator(nil)
    36  
    37  	seen := make(map[string]struct{})
    38  	for iter.Next(true) {
    39  		seen[string(iter.Path())] = struct{}{}
    40  	}
    41  	if len(seen) != 0 {
    42  		t.Fatal("Unexpected trie node iterated")
    43  	}
    44  }
    45  
    46  func TestIterator(t *testing.T) {
    47  	db := NewDatabase(rawdb.NewMemoryDatabase())
    48  	trie := NewEmpty(db)
    49  	vals := []struct{ k, v string }{
    50  		{"do", "verb"},
    51  		{"ether", "wookiedoo"},
    52  		{"horse", "stallion"},
    53  		{"shaman", "horse"},
    54  		{"doge", "coin"},
    55  		{"dog", "puppy"},
    56  		{"somethingveryoddindeedthis is", "myothernodedata"},
    57  	}
    58  	all := make(map[string]string)
    59  	for _, val := range vals {
    60  		all[val.k] = val.v
    61  		trie.Update([]byte(val.k), []byte(val.v))
    62  	}
    63  	root, nodes, err := trie.Commit(false)
    64  	if err != nil {
    65  		t.Fatalf("Failed to commit trie %v", err)
    66  	}
    67  	db.Update(NewWithNodeSet(nodes))
    68  
    69  	trie, _ = New(TrieID(root), db)
    70  	found := make(map[string]string)
    71  	it := NewIterator(trie.NodeIterator(nil))
    72  	for it.Next() {
    73  		found[string(it.Key)] = string(it.Value)
    74  	}
    75  
    76  	for k, v := range all {
    77  		if found[k] != v {
    78  			t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v)
    79  		}
    80  	}
    81  }
    82  
    83  type kv struct {
    84  	k, v []byte
    85  	t    bool
    86  }
    87  
    88  func TestIteratorLargeData(t *testing.T) {
    89  	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
    90  	vals := make(map[string]*kv)
    91  
    92  	for i := byte(0); i < 255; i++ {
    93  		value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
    94  		value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
    95  		trie.Update(value.k, value.v)
    96  		trie.Update(value2.k, value2.v)
    97  		vals[string(value.k)] = value
    98  		vals[string(value2.k)] = value2
    99  	}
   100  
   101  	it := NewIterator(trie.NodeIterator(nil))
   102  	for it.Next() {
   103  		vals[string(it.Key)].t = true
   104  	}
   105  
   106  	var untouched []*kv
   107  	for _, value := range vals {
   108  		if !value.t {
   109  			untouched = append(untouched, value)
   110  		}
   111  	}
   112  
   113  	if len(untouched) > 0 {
   114  		t.Errorf("Missed %d nodes", len(untouched))
   115  		for _, value := range untouched {
   116  			t.Error(value)
   117  		}
   118  	}
   119  }
   120  
   121  // Tests that the node iterator indeed walks over the entire database contents.
   122  func TestNodeIteratorCoverage(t *testing.T) {
   123  	// Create some arbitrary test trie to iterate
   124  	db, trie, _ := makeTestTrie()
   125  
   126  	// Gather all the node hashes found by the iterator
   127  	hashes := make(map[common.Hash]struct{})
   128  	for it := trie.NodeIterator(nil); it.Next(true); {
   129  		if it.Hash() != (common.Hash{}) {
   130  			hashes[it.Hash()] = struct{}{}
   131  		}
   132  	}
   133  	// Cross check the hashes and the database itself
   134  	for hash := range hashes {
   135  		if _, err := db.Node(hash); err != nil {
   136  			t.Errorf("failed to retrieve reported node %x: %v", hash, err)
   137  		}
   138  	}
   139  	for hash, obj := range db.dirties {
   140  		if obj != nil && hash != (common.Hash{}) {
   141  			if _, ok := hashes[hash]; !ok {
   142  				t.Errorf("state entry not reported %x", hash)
   143  			}
   144  		}
   145  	}
   146  	it := db.diskdb.NewIterator(nil, nil)
   147  	for it.Next() {
   148  		key := it.Key()
   149  		if _, ok := hashes[common.BytesToHash(key)]; !ok {
   150  			t.Errorf("state entry not reported %x", key)
   151  		}
   152  	}
   153  	it.Release()
   154  }
   155  
   156  type kvs struct{ k, v string }
   157  
   158  var testdata1 = []kvs{
   159  	{"barb", "ba"},
   160  	{"bard", "bc"},
   161  	{"bars", "bb"},
   162  	{"bar", "b"},
   163  	{"fab", "z"},
   164  	{"food", "ab"},
   165  	{"foos", "aa"},
   166  	{"foo", "a"},
   167  }
   168  
   169  var testdata2 = []kvs{
   170  	{"aardvark", "c"},
   171  	{"bar", "b"},
   172  	{"barb", "bd"},
   173  	{"bars", "be"},
   174  	{"fab", "z"},
   175  	{"foo", "a"},
   176  	{"foos", "aa"},
   177  	{"food", "ab"},
   178  	{"jars", "d"},
   179  }
   180  
   181  func TestIteratorSeek(t *testing.T) {
   182  	trie := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
   183  	for _, val := range testdata1 {
   184  		trie.Update([]byte(val.k), []byte(val.v))
   185  	}
   186  
   187  	// Seek to the middle.
   188  	it := NewIterator(trie.NodeIterator([]byte("fab")))
   189  	if err := checkIteratorOrder(testdata1[4:], it); err != nil {
   190  		t.Fatal(err)
   191  	}
   192  
   193  	// Seek to a non-existent key.
   194  	it = NewIterator(trie.NodeIterator([]byte("barc")))
   195  	if err := checkIteratorOrder(testdata1[1:], it); err != nil {
   196  		t.Fatal(err)
   197  	}
   198  
   199  	// Seek beyond the end.
   200  	it = NewIterator(trie.NodeIterator([]byte("z")))
   201  	if err := checkIteratorOrder(nil, it); err != nil {
   202  		t.Fatal(err)
   203  	}
   204  }
   205  
   206  func checkIteratorOrder(want []kvs, it *Iterator) error {
   207  	for it.Next() {
   208  		if len(want) == 0 {
   209  			return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
   210  		}
   211  		if !bytes.Equal(it.Key, []byte(want[0].k)) {
   212  			return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k)
   213  		}
   214  		want = want[1:]
   215  	}
   216  	if len(want) > 0 {
   217  		return fmt.Errorf("iterator ended early, want key %q", want[0])
   218  	}
   219  	return nil
   220  }
   221  
   222  func TestDifferenceIterator(t *testing.T) {
   223  	dba := NewDatabase(rawdb.NewMemoryDatabase())
   224  	triea := NewEmpty(dba)
   225  	for _, val := range testdata1 {
   226  		triea.Update([]byte(val.k), []byte(val.v))
   227  	}
   228  	rootA, nodesA, _ := triea.Commit(false)
   229  	dba.Update(NewWithNodeSet(nodesA))
   230  	triea, _ = New(TrieID(rootA), dba)
   231  
   232  	dbb := NewDatabase(rawdb.NewMemoryDatabase())
   233  	trieb := NewEmpty(dbb)
   234  	for _, val := range testdata2 {
   235  		trieb.Update([]byte(val.k), []byte(val.v))
   236  	}
   237  	rootB, nodesB, _ := trieb.Commit(false)
   238  	dbb.Update(NewWithNodeSet(nodesB))
   239  	trieb, _ = New(TrieID(rootB), dbb)
   240  
   241  	found := make(map[string]string)
   242  	di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
   243  	it := NewIterator(di)
   244  	for it.Next() {
   245  		found[string(it.Key)] = string(it.Value)
   246  	}
   247  
   248  	all := []struct{ k, v string }{
   249  		{"aardvark", "c"},
   250  		{"barb", "bd"},
   251  		{"bars", "be"},
   252  		{"jars", "d"},
   253  	}
   254  	for _, item := range all {
   255  		if found[item.k] != item.v {
   256  			t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
   257  		}
   258  	}
   259  	if len(found) != len(all) {
   260  		t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
   261  	}
   262  }
   263  
   264  func TestUnionIterator(t *testing.T) {
   265  	dba := NewDatabase(rawdb.NewMemoryDatabase())
   266  	triea := NewEmpty(dba)
   267  	for _, val := range testdata1 {
   268  		triea.Update([]byte(val.k), []byte(val.v))
   269  	}
   270  	rootA, nodesA, _ := triea.Commit(false)
   271  	dba.Update(NewWithNodeSet(nodesA))
   272  	triea, _ = New(TrieID(rootA), dba)
   273  
   274  	dbb := NewDatabase(rawdb.NewMemoryDatabase())
   275  	trieb := NewEmpty(dbb)
   276  	for _, val := range testdata2 {
   277  		trieb.Update([]byte(val.k), []byte(val.v))
   278  	}
   279  	rootB, nodesB, _ := trieb.Commit(false)
   280  	dbb.Update(NewWithNodeSet(nodesB))
   281  	trieb, _ = New(TrieID(rootB), dbb)
   282  
   283  	di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
   284  	it := NewIterator(di)
   285  
   286  	all := []struct{ k, v string }{
   287  		{"aardvark", "c"},
   288  		{"barb", "ba"},
   289  		{"barb", "bd"},
   290  		{"bard", "bc"},
   291  		{"bars", "bb"},
   292  		{"bars", "be"},
   293  		{"bar", "b"},
   294  		{"fab", "z"},
   295  		{"food", "ab"},
   296  		{"foos", "aa"},
   297  		{"foo", "a"},
   298  		{"jars", "d"},
   299  	}
   300  
   301  	for i, kv := range all {
   302  		if !it.Next() {
   303  			t.Errorf("Iterator ends prematurely at element %d", i)
   304  		}
   305  		if kv.k != string(it.Key) {
   306  			t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
   307  		}
   308  		if kv.v != string(it.Value) {
   309  			t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
   310  		}
   311  	}
   312  	if it.Next() {
   313  		t.Errorf("Iterator returned extra values.")
   314  	}
   315  }
   316  
   317  func TestIteratorNoDups(t *testing.T) {
   318  	tr := NewEmpty(NewDatabase(rawdb.NewMemoryDatabase()))
   319  	for _, val := range testdata1 {
   320  		tr.Update([]byte(val.k), []byte(val.v))
   321  	}
   322  	checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   323  }
   324  
   325  // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
   326  func TestIteratorContinueAfterErrorDisk(t *testing.T)    { testIteratorContinueAfterError(t, false) }
   327  func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
   328  
   329  func testIteratorContinueAfterError(t *testing.T, memonly bool) {
   330  	diskdb := memorydb.New()
   331  	triedb := NewDatabase(diskdb)
   332  
   333  	tr := NewEmpty(triedb)
   334  	for _, val := range testdata1 {
   335  		tr.Update([]byte(val.k), []byte(val.v))
   336  	}
   337  	_, nodes, _ := tr.Commit(false)
   338  	triedb.Update(NewWithNodeSet(nodes))
   339  	if !memonly {
   340  		triedb.Commit(tr.Hash(), true, nil)
   341  	}
   342  	wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   343  
   344  	var (
   345  		diskKeys [][]byte
   346  		memKeys  []common.Hash
   347  	)
   348  	if memonly {
   349  		memKeys = triedb.Nodes()
   350  	} else {
   351  		it := diskdb.NewIterator(nil, nil)
   352  		for it.Next() {
   353  			diskKeys = append(diskKeys, it.Key())
   354  		}
   355  		it.Release()
   356  	}
   357  	for i := 0; i < 20; i++ {
   358  		// Create trie that will load all nodes from DB.
   359  		tr, _ := New(TrieID(tr.Hash()), triedb)
   360  
   361  		// Remove a random node from the database. It can't be the root node
   362  		// because that one is already loaded.
   363  		var (
   364  			rkey common.Hash
   365  			rval []byte
   366  			robj *cachedNode
   367  		)
   368  		for {
   369  			if memonly {
   370  				rkey = memKeys[rand.Intn(len(memKeys))]
   371  			} else {
   372  				copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
   373  			}
   374  			if rkey != tr.Hash() {
   375  				break
   376  			}
   377  		}
   378  		if memonly {
   379  			robj = triedb.dirties[rkey]
   380  			delete(triedb.dirties, rkey)
   381  		} else {
   382  			rval, _ = diskdb.Get(rkey[:])
   383  			diskdb.Delete(rkey[:])
   384  		}
   385  		// Iterate until the error is hit.
   386  		seen := make(map[string]bool)
   387  		it := tr.NodeIterator(nil)
   388  		checkIteratorNoDups(t, it, seen)
   389  		missing, ok := it.Error().(*MissingNodeError)
   390  		if !ok || missing.NodeHash != rkey {
   391  			t.Fatal("didn't hit missing node, got", it.Error())
   392  		}
   393  
   394  		// Add the node back and continue iteration.
   395  		if memonly {
   396  			triedb.dirties[rkey] = robj
   397  		} else {
   398  			diskdb.Put(rkey[:], rval)
   399  		}
   400  		checkIteratorNoDups(t, it, seen)
   401  		if it.Error() != nil {
   402  			t.Fatal("unexpected error", it.Error())
   403  		}
   404  		if len(seen) != wantNodeCount {
   405  			t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount)
   406  		}
   407  	}
   408  }
   409  
   410  // Similar to the test above, this one checks that failure to create nodeIterator at a
   411  // certain key prefix behaves correctly when Next is called. The expectation is that Next
   412  // should retry seeking before returning true for the first time.
   413  func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) {
   414  	testIteratorContinueAfterSeekError(t, false)
   415  }
   416  func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
   417  	testIteratorContinueAfterSeekError(t, true)
   418  }
   419  
   420  func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
   421  	// Commit test trie to db, then remove the node containing "bars".
   422  	diskdb := memorydb.New()
   423  	triedb := NewDatabase(diskdb)
   424  
   425  	ctr := NewEmpty(triedb)
   426  	for _, val := range testdata1 {
   427  		ctr.Update([]byte(val.k), []byte(val.v))
   428  	}
   429  	root, nodes, _ := ctr.Commit(false)
   430  	triedb.Update(NewWithNodeSet(nodes))
   431  	if !memonly {
   432  		triedb.Commit(root, true, nil)
   433  	}
   434  	barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
   435  	var (
   436  		barNodeBlob []byte
   437  		barNodeObj  *cachedNode
   438  	)
   439  	if memonly {
   440  		barNodeObj = triedb.dirties[barNodeHash]
   441  		delete(triedb.dirties, barNodeHash)
   442  	} else {
   443  		barNodeBlob, _ = diskdb.Get(barNodeHash[:])
   444  		diskdb.Delete(barNodeHash[:])
   445  	}
   446  	// Create a new iterator that seeks to "bars". Seeking can't proceed because
   447  	// the node is missing.
   448  	tr, _ := New(TrieID(root), triedb)
   449  	it := tr.NodeIterator([]byte("bars"))
   450  	missing, ok := it.Error().(*MissingNodeError)
   451  	if !ok {
   452  		t.Fatal("want MissingNodeError, got", it.Error())
   453  	} else if missing.NodeHash != barNodeHash {
   454  		t.Fatal("wrong node missing")
   455  	}
   456  	// Reinsert the missing node.
   457  	if memonly {
   458  		triedb.dirties[barNodeHash] = barNodeObj
   459  	} else {
   460  		diskdb.Put(barNodeHash[:], barNodeBlob)
   461  	}
   462  	// Check that iteration produces the right set of values.
   463  	if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
   464  		t.Fatal(err)
   465  	}
   466  }
   467  
   468  func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
   469  	if seen == nil {
   470  		seen = make(map[string]bool)
   471  	}
   472  	for it.Next(true) {
   473  		if seen[string(it.Path())] {
   474  			t.Fatalf("iterator visited node path %x twice", it.Path())
   475  		}
   476  		seen[string(it.Path())] = true
   477  	}
   478  	return len(seen)
   479  }
   480  
   481  type loggingDb struct {
   482  	getCount uint64
   483  	backend  ethdb.KeyValueStore
   484  }
   485  
   486  func (l *loggingDb) Has(key []byte) (bool, error) {
   487  	return l.backend.Has(key)
   488  }
   489  
   490  func (l *loggingDb) Get(key []byte) ([]byte, error) {
   491  	l.getCount++
   492  	return l.backend.Get(key)
   493  }
   494  
   495  func (l *loggingDb) Put(key []byte, value []byte) error {
   496  	return l.backend.Put(key, value)
   497  }
   498  
   499  func (l *loggingDb) Delete(key []byte) error {
   500  	return l.backend.Delete(key)
   501  }
   502  
   503  func (l *loggingDb) NewBatch() ethdb.Batch {
   504  	return l.backend.NewBatch()
   505  }
   506  
   507  func (l *loggingDb) NewBatchWithSize(size int) ethdb.Batch {
   508  	return l.backend.NewBatchWithSize(size)
   509  }
   510  
   511  func (l *loggingDb) NewIterator(prefix []byte, start []byte) ethdb.Iterator {
   512  	return l.backend.NewIterator(prefix, start)
   513  }
   514  
   515  func (l *loggingDb) NewSnapshot() (ethdb.Snapshot, error) {
   516  	return l.backend.NewSnapshot()
   517  }
   518  
   519  func (l *loggingDb) Stat(property string) (string, error) {
   520  	return l.backend.Stat(property)
   521  }
   522  
   523  func (l *loggingDb) Compact(start []byte, limit []byte) error {
   524  	return l.backend.Compact(start, limit)
   525  }
   526  
   527  func (l *loggingDb) Close() error {
   528  	return l.backend.Close()
   529  }
   530  
   531  // makeLargeTestTrie create a sample test trie
   532  func makeLargeTestTrie() (*Database, *StateTrie, *loggingDb) {
   533  	// Create an empty trie
   534  	logDb := &loggingDb{0, memorydb.New()}
   535  	triedb := NewDatabase(logDb)
   536  	trie, _ := NewStateTrie(TrieID(common.Hash{}), triedb)
   537  
   538  	// Fill it with some arbitrary data
   539  	for i := 0; i < 10000; i++ {
   540  		key := make([]byte, 32)
   541  		val := make([]byte, 32)
   542  		binary.BigEndian.PutUint64(key, uint64(i))
   543  		binary.BigEndian.PutUint64(val, uint64(i))
   544  		key = crypto.Keccak256(key)
   545  		val = crypto.Keccak256(val)
   546  		trie.Update(key, val)
   547  	}
   548  	_, nodes, _ := trie.Commit(false)
   549  	triedb.Update(NewWithNodeSet(nodes))
   550  	// Return the generated trie
   551  	return triedb, trie, logDb
   552  }
   553  
   554  // Tests that the node iterator indeed walks over the entire database contents.
   555  func TestNodeIteratorLargeTrie(t *testing.T) {
   556  	// Create some arbitrary test trie to iterate
   557  	db, trie, logDb := makeLargeTestTrie()
   558  	db.Cap(0) // flush everything
   559  	// Do a seek operation
   560  	trie.NodeIterator(common.FromHex("0x77667766776677766778855885885885"))
   561  	// master: 24 get operations
   562  	// this pr: 5 get operations
   563  	if have, want := logDb.getCount, uint64(5); have != want {
   564  		t.Fatalf("Too many lookups during seek, have %d want %d", have, want)
   565  	}
   566  }
   567  
   568  func TestIteratorNodeBlob(t *testing.T) {
   569  	var (
   570  		db     = memorydb.New()
   571  		triedb = NewDatabase(db)
   572  		trie   = NewEmpty(triedb)
   573  	)
   574  	vals := []struct{ k, v string }{
   575  		{"do", "verb"},
   576  		{"ether", "wookiedoo"},
   577  		{"horse", "stallion"},
   578  		{"shaman", "horse"},
   579  		{"doge", "coin"},
   580  		{"dog", "puppy"},
   581  		{"somethingveryoddindeedthis is", "myothernodedata"},
   582  	}
   583  	all := make(map[string]string)
   584  	for _, val := range vals {
   585  		all[val.k] = val.v
   586  		trie.Update([]byte(val.k), []byte(val.v))
   587  	}
   588  	_, nodes, _ := trie.Commit(false)
   589  	triedb.Update(NewWithNodeSet(nodes))
   590  	triedb.Cap(0)
   591  
   592  	found := make(map[common.Hash][]byte)
   593  	it := trie.NodeIterator(nil)
   594  	for it.Next(true) {
   595  		if it.Hash() == (common.Hash{}) {
   596  			continue
   597  		}
   598  		found[it.Hash()] = it.NodeBlob()
   599  	}
   600  
   601  	dbIter := db.NewIterator(nil, nil)
   602  	defer dbIter.Release()
   603  
   604  	var count int
   605  	for dbIter.Next() {
   606  		got, present := found[common.BytesToHash(dbIter.Key())]
   607  		if !present {
   608  			t.Fatalf("Miss trie node %v", dbIter.Key())
   609  		}
   610  		if !bytes.Equal(got, dbIter.Value()) {
   611  			t.Fatalf("Unexpected trie node want %v got %v", dbIter.Value(), got)
   612  		}
   613  		count += 1
   614  	}
   615  	if count != len(found) {
   616  		t.Fatal("Find extra trie node via iterator")
   617  	}
   618  }