github.com/linapex/ethereum-go-chinese@v0.0.0-20190316121929-f8b7a73c3fa1/trie/iterator_test.go (about)

     1  
     2  //<developer>
     3  //    <name>linapex 曹一峰</name>
     4  //    <email>linapex@163.com</email>
     5  //    <wx>superexc</wx>
     6  //    <qqgroup>128148617</qqgroup>
     7  //    <url>https://jsq.ink</url>
     8  //    <role>pku engineer</role>
     9  //    <date>2019-03-16 19:16:45</date>
    10  //</624450122981314560>
    11  
    12  
    13  package trie
    14  
    15  import (
    16  	"bytes"
    17  	"fmt"
    18  	"math/rand"
    19  	"testing"
    20  
    21  	"github.com/ethereum/go-ethereum/common"
    22  	"github.com/ethereum/go-ethereum/ethdb"
    23  )
    24  
    25  func TestIterator(t *testing.T) {
    26  	trie := newEmpty()
    27  	vals := []struct{ k, v string }{
    28  		{"do", "verb"},
    29  		{"ether", "wookiedoo"},
    30  		{"horse", "stallion"},
    31  		{"shaman", "horse"},
    32  		{"doge", "coin"},
    33  		{"dog", "puppy"},
    34  		{"somethingveryoddindeedthis is", "myothernodedata"},
    35  	}
    36  	all := make(map[string]string)
    37  	for _, val := range vals {
    38  		all[val.k] = val.v
    39  		trie.Update([]byte(val.k), []byte(val.v))
    40  	}
    41  	trie.Commit(nil)
    42  
    43  	found := make(map[string]string)
    44  	it := NewIterator(trie.NodeIterator(nil))
    45  	for it.Next() {
    46  		found[string(it.Key)] = string(it.Value)
    47  	}
    48  
    49  	for k, v := range all {
    50  		if found[k] != v {
    51  			t.Errorf("iterator value mismatch for %s: got %q want %q", k, found[k], v)
    52  		}
    53  	}
    54  }
    55  
    56  type kv struct {
    57  	k, v []byte
    58  	t    bool
    59  }
    60  
    61  func TestIteratorLargeData(t *testing.T) {
    62  	trie := newEmpty()
    63  	vals := make(map[string]*kv)
    64  
    65  	for i := byte(0); i < 255; i++ {
    66  		value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
    67  		value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false}
    68  		trie.Update(value.k, value.v)
    69  		trie.Update(value2.k, value2.v)
    70  		vals[string(value.k)] = value
    71  		vals[string(value2.k)] = value2
    72  	}
    73  
    74  	it := NewIterator(trie.NodeIterator(nil))
    75  	for it.Next() {
    76  		vals[string(it.Key)].t = true
    77  	}
    78  
    79  	var untouched []*kv
    80  	for _, value := range vals {
    81  		if !value.t {
    82  			untouched = append(untouched, value)
    83  		}
    84  	}
    85  
    86  	if len(untouched) > 0 {
    87  		t.Errorf("Missed %d nodes", len(untouched))
    88  		for _, value := range untouched {
    89  			t.Error(value)
    90  		}
    91  	}
    92  }
    93  
    94  //测试节点迭代器是否确实遍历整个数据库内容。
    95  func TestNodeIteratorCoverage(t *testing.T) {
    96  //创建一些要迭代的任意测试trie
    97  	db, trie, _ := makeTestTrie()
    98  
    99  //收集迭代器找到的所有节点散列
   100  	hashes := make(map[common.Hash]struct{})
   101  	for it := trie.NodeIterator(nil); it.Next(true); {
   102  		if it.Hash() != (common.Hash{}) {
   103  			hashes[it.Hash()] = struct{}{}
   104  		}
   105  	}
   106  //交叉检查哈希和数据库本身
   107  	for hash := range hashes {
   108  		if _, err := db.Node(hash); err != nil {
   109  			t.Errorf("failed to retrieve reported node %x: %v", hash, err)
   110  		}
   111  	}
   112  	for hash, obj := range db.dirties {
   113  		if obj != nil && hash != (common.Hash{}) {
   114  			if _, ok := hashes[hash]; !ok {
   115  				t.Errorf("state entry not reported %x", hash)
   116  			}
   117  		}
   118  	}
   119  	for _, key := range db.diskdb.(*ethdb.MemDatabase).Keys() {
   120  		if _, ok := hashes[common.BytesToHash(key)]; !ok {
   121  			t.Errorf("state entry not reported %x", key)
   122  		}
   123  	}
   124  }
   125  
   126  type kvs struct{ k, v string }
   127  
   128  var testdata1 = []kvs{
   129  	{"barb", "ba"},
   130  	{"bard", "bc"},
   131  	{"bars", "bb"},
   132  	{"bar", "b"},
   133  	{"fab", "z"},
   134  	{"food", "ab"},
   135  	{"foos", "aa"},
   136  	{"foo", "a"},
   137  }
   138  
   139  var testdata2 = []kvs{
   140  	{"aardvark", "c"},
   141  	{"bar", "b"},
   142  	{"barb", "bd"},
   143  	{"bars", "be"},
   144  	{"fab", "z"},
   145  	{"foo", "a"},
   146  	{"foos", "aa"},
   147  	{"food", "ab"},
   148  	{"jars", "d"},
   149  }
   150  
   151  func TestIteratorSeek(t *testing.T) {
   152  	trie := newEmpty()
   153  	for _, val := range testdata1 {
   154  		trie.Update([]byte(val.k), []byte(val.v))
   155  	}
   156  
   157  //寻求中间。
   158  	it := NewIterator(trie.NodeIterator([]byte("fab")))
   159  	if err := checkIteratorOrder(testdata1[4:], it); err != nil {
   160  		t.Fatal(err)
   161  	}
   162  
   163  //查找不存在的密钥。
   164  	it = NewIterator(trie.NodeIterator([]byte("barc")))
   165  	if err := checkIteratorOrder(testdata1[1:], it); err != nil {
   166  		t.Fatal(err)
   167  	}
   168  
   169  //超越终点。
   170  	it = NewIterator(trie.NodeIterator([]byte("z")))
   171  	if err := checkIteratorOrder(nil, it); err != nil {
   172  		t.Fatal(err)
   173  	}
   174  }
   175  
   176  func checkIteratorOrder(want []kvs, it *Iterator) error {
   177  	for it.Next() {
   178  		if len(want) == 0 {
   179  			return fmt.Errorf("didn't expect any more values, got key %q", it.Key)
   180  		}
   181  		if !bytes.Equal(it.Key, []byte(want[0].k)) {
   182  			return fmt.Errorf("wrong key: got %q, want %q", it.Key, want[0].k)
   183  		}
   184  		want = want[1:]
   185  	}
   186  	if len(want) > 0 {
   187  		return fmt.Errorf("iterator ended early, want key %q", want[0])
   188  	}
   189  	return nil
   190  }
   191  
   192  func TestDifferenceIterator(t *testing.T) {
   193  	triea := newEmpty()
   194  	for _, val := range testdata1 {
   195  		triea.Update([]byte(val.k), []byte(val.v))
   196  	}
   197  	triea.Commit(nil)
   198  
   199  	trieb := newEmpty()
   200  	for _, val := range testdata2 {
   201  		trieb.Update([]byte(val.k), []byte(val.v))
   202  	}
   203  	trieb.Commit(nil)
   204  
   205  	found := make(map[string]string)
   206  	di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
   207  	it := NewIterator(di)
   208  	for it.Next() {
   209  		found[string(it.Key)] = string(it.Value)
   210  	}
   211  
   212  	all := []struct{ k, v string }{
   213  		{"aardvark", "c"},
   214  		{"barb", "bd"},
   215  		{"bars", "be"},
   216  		{"jars", "d"},
   217  	}
   218  	for _, item := range all {
   219  		if found[item.k] != item.v {
   220  			t.Errorf("iterator value mismatch for %s: got %v want %v", item.k, found[item.k], item.v)
   221  		}
   222  	}
   223  	if len(found) != len(all) {
   224  		t.Errorf("iterator count mismatch: got %d values, want %d", len(found), len(all))
   225  	}
   226  }
   227  
   228  func TestUnionIterator(t *testing.T) {
   229  	triea := newEmpty()
   230  	for _, val := range testdata1 {
   231  		triea.Update([]byte(val.k), []byte(val.v))
   232  	}
   233  	triea.Commit(nil)
   234  
   235  	trieb := newEmpty()
   236  	for _, val := range testdata2 {
   237  		trieb.Update([]byte(val.k), []byte(val.v))
   238  	}
   239  	trieb.Commit(nil)
   240  
   241  	di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
   242  	it := NewIterator(di)
   243  
   244  	all := []struct{ k, v string }{
   245  		{"aardvark", "c"},
   246  		{"barb", "ba"},
   247  		{"barb", "bd"},
   248  		{"bard", "bc"},
   249  		{"bars", "bb"},
   250  		{"bars", "be"},
   251  		{"bar", "b"},
   252  		{"fab", "z"},
   253  		{"food", "ab"},
   254  		{"foos", "aa"},
   255  		{"foo", "a"},
   256  		{"jars", "d"},
   257  	}
   258  
   259  	for i, kv := range all {
   260  		if !it.Next() {
   261  			t.Errorf("Iterator ends prematurely at element %d", i)
   262  		}
   263  		if kv.k != string(it.Key) {
   264  			t.Errorf("iterator value mismatch for element %d: got key %s want %s", i, it.Key, kv.k)
   265  		}
   266  		if kv.v != string(it.Value) {
   267  			t.Errorf("iterator value mismatch for element %d: got value %s want %s", i, it.Value, kv.v)
   268  		}
   269  	}
   270  	if it.Next() {
   271  		t.Errorf("Iterator returned extra values.")
   272  	}
   273  }
   274  
   275  func TestIteratorNoDups(t *testing.T) {
   276  	var tr Trie
   277  	for _, val := range testdata1 {
   278  		tr.Update([]byte(val.k), []byte(val.v))
   279  	}
   280  	checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   281  }
   282  
   283  //此测试检查nodeiterator。插入缺少的trie节点后,可以重试next。
   284  func TestIteratorContinueAfterErrorDisk(t *testing.T)    { testIteratorContinueAfterError(t, false) }
   285  func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
   286  
   287  func testIteratorContinueAfterError(t *testing.T, memonly bool) {
   288  	diskdb := ethdb.NewMemDatabase()
   289  	triedb := NewDatabase(diskdb)
   290  
   291  	tr, _ := New(common.Hash{}, triedb)
   292  	for _, val := range testdata1 {
   293  		tr.Update([]byte(val.k), []byte(val.v))
   294  	}
   295  	tr.Commit(nil)
   296  	if !memonly {
   297  		triedb.Commit(tr.Hash(), true)
   298  	}
   299  	wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
   300  
   301  	var (
   302  		diskKeys [][]byte
   303  		memKeys  []common.Hash
   304  	)
   305  	if memonly {
   306  		memKeys = triedb.Nodes()
   307  	} else {
   308  		diskKeys = diskdb.Keys()
   309  	}
   310  	for i := 0; i < 20; i++ {
   311  //创建将从数据库加载所有节点的trie。
   312  		tr, _ := New(tr.Hash(), triedb)
   313  
   314  //从数据库中删除随机节点。它不能是根节点
   315  //因为那个已经加载了。
   316  		var (
   317  			rkey common.Hash
   318  			rval []byte
   319  			robj *cachedNode
   320  		)
   321  		for {
   322  			if memonly {
   323  				rkey = memKeys[rand.Intn(len(memKeys))]
   324  			} else {
   325  				copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
   326  			}
   327  			if rkey != tr.Hash() {
   328  				break
   329  			}
   330  		}
   331  		if memonly {
   332  			robj = triedb.dirties[rkey]
   333  			delete(triedb.dirties, rkey)
   334  		} else {
   335  			rval, _ = diskdb.Get(rkey[:])
   336  			diskdb.Delete(rkey[:])
   337  		}
   338  //迭代直到命中错误。
   339  		seen := make(map[string]bool)
   340  		it := tr.NodeIterator(nil)
   341  		checkIteratorNoDups(t, it, seen)
   342  		missing, ok := it.Error().(*MissingNodeError)
   343  		if !ok || missing.NodeHash != rkey {
   344  			t.Fatal("didn't hit missing node, got", it.Error())
   345  		}
   346  
   347  //重新添加节点并继续迭代。
   348  		if memonly {
   349  			triedb.dirties[rkey] = robj
   350  		} else {
   351  			diskdb.Put(rkey[:], rval)
   352  		}
   353  		checkIteratorNoDups(t, it, seen)
   354  		if it.Error() != nil {
   355  			t.Fatal("unexpected error", it.Error())
   356  		}
   357  		if len(seen) != wantNodeCount {
   358  			t.Fatal("wrong node iteration count, got", len(seen), "want", wantNodeCount)
   359  		}
   360  	}
   361  }
   362  
   363  //与上面的测试类似,这个测试检查在
   364  //调用next时,某些键前缀的行为正确。接下来的期望是
   365  //应在第一次返回true之前重试查找。
   366  func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) {
   367  	testIteratorContinueAfterSeekError(t, false)
   368  }
   369  func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
   370  	testIteratorContinueAfterSeekError(t, true)
   371  }
   372  
   373  func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
   374  //将测试trie提交到db,然后删除包含“bars”的节点。
   375  	diskdb := ethdb.NewMemDatabase()
   376  	triedb := NewDatabase(diskdb)
   377  
   378  	ctr, _ := New(common.Hash{}, triedb)
   379  	for _, val := range testdata1 {
   380  		ctr.Update([]byte(val.k), []byte(val.v))
   381  	}
   382  	root, _ := ctr.Commit(nil)
   383  	if !memonly {
   384  		triedb.Commit(root, true)
   385  	}
   386  	barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
   387  	var (
   388  		barNodeBlob []byte
   389  		barNodeObj  *cachedNode
   390  	)
   391  	if memonly {
   392  		barNodeObj = triedb.dirties[barNodeHash]
   393  		delete(triedb.dirties, barNodeHash)
   394  	} else {
   395  		barNodeBlob, _ = diskdb.Get(barNodeHash[:])
   396  		diskdb.Delete(barNodeHash[:])
   397  	}
   398  //创建一个寻找“条”的新迭代器。搜索无法继续,因为
   399  //缺少节点。
   400  	tr, _ := New(root, triedb)
   401  	it := tr.NodeIterator([]byte("bars"))
   402  	missing, ok := it.Error().(*MissingNodeError)
   403  	if !ok {
   404  		t.Fatal("want MissingNodeError, got", it.Error())
   405  	} else if missing.NodeHash != barNodeHash {
   406  		t.Fatal("wrong node missing")
   407  	}
   408  //重新插入丢失的节点。
   409  	if memonly {
   410  		triedb.dirties[barNodeHash] = barNodeObj
   411  	} else {
   412  		diskdb.Put(barNodeHash[:], barNodeBlob)
   413  	}
   414  //检查迭代是否生成正确的值集。
   415  	if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
   416  		t.Fatal(err)
   417  	}
   418  }
   419  
   420  func checkIteratorNoDups(t *testing.T, it NodeIterator, seen map[string]bool) int {
   421  	if seen == nil {
   422  		seen = make(map[string]bool)
   423  	}
   424  	for it.Next(true) {
   425  		if seen[string(it.Path())] {
   426  			t.Fatalf("iterator visited node path %x twice", it.Path())
   427  		}
   428  		seen[string(it.Path())] = true
   429  	}
   430  	return len(seen)
   431  }
   432