github.com/myafeier/go-ethereum@v1.6.8-0.20170719123245-3e0dbe0eaa72/trie/trie_test.go (about)

     1  // Copyright 2014 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package trie
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/binary"
    22  	"errors"
    23  	"fmt"
    24  	"io/ioutil"
    25  	"math/rand"
    26  	"os"
    27  	"reflect"
    28  	"testing"
    29  	"testing/quick"
    30  
    31  	"github.com/davecgh/go-spew/spew"
    32  	"github.com/ethereum/go-ethereum/common"
    33  	"github.com/ethereum/go-ethereum/ethdb"
    34  )
    35  
    36  func init() {
    37  	spew.Config.Indent = "    "
    38  	spew.Config.DisableMethods = false
    39  }
    40  
    41  // Used for testing
    42  func newEmpty() *Trie {
    43  	db, _ := ethdb.NewMemDatabase()
    44  	trie, _ := New(common.Hash{}, db)
    45  	return trie
    46  }
    47  
    48  func TestEmptyTrie(t *testing.T) {
    49  	var trie Trie
    50  	res := trie.Hash()
    51  	exp := emptyRoot
    52  	if res != common.Hash(exp) {
    53  		t.Errorf("expected %x got %x", exp, res)
    54  	}
    55  }
    56  
    57  func TestNull(t *testing.T) {
    58  	var trie Trie
    59  	key := make([]byte, 32)
    60  	value := []byte("test")
    61  	trie.Update(key, value)
    62  	if !bytes.Equal(trie.Get(key), value) {
    63  		t.Fatal("wrong value")
    64  	}
    65  }
    66  
    67  func TestMissingRoot(t *testing.T) {
    68  	db, _ := ethdb.NewMemDatabase()
    69  	trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), db)
    70  	if trie != nil {
    71  		t.Error("New returned non-nil trie for invalid root")
    72  	}
    73  	if _, ok := err.(*MissingNodeError); !ok {
    74  		t.Errorf("New returned wrong error: %v", err)
    75  	}
    76  }
    77  
    78  func TestMissingNode(t *testing.T) {
    79  	db, _ := ethdb.NewMemDatabase()
    80  	trie, _ := New(common.Hash{}, db)
    81  	updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer")
    82  	updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf")
    83  	root, _ := trie.Commit()
    84  
    85  	trie, _ = New(root, db)
    86  	_, err := trie.TryGet([]byte("120000"))
    87  	if err != nil {
    88  		t.Errorf("Unexpected error: %v", err)
    89  	}
    90  
    91  	trie, _ = New(root, db)
    92  	_, err = trie.TryGet([]byte("120099"))
    93  	if err != nil {
    94  		t.Errorf("Unexpected error: %v", err)
    95  	}
    96  
    97  	trie, _ = New(root, db)
    98  	_, err = trie.TryGet([]byte("123456"))
    99  	if err != nil {
   100  		t.Errorf("Unexpected error: %v", err)
   101  	}
   102  
   103  	trie, _ = New(root, db)
   104  	err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
   105  	if err != nil {
   106  		t.Errorf("Unexpected error: %v", err)
   107  	}
   108  
   109  	trie, _ = New(root, db)
   110  	err = trie.TryDelete([]byte("123456"))
   111  	if err != nil {
   112  		t.Errorf("Unexpected error: %v", err)
   113  	}
   114  
   115  	db.Delete(common.FromHex("e1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9"))
   116  
   117  	trie, _ = New(root, db)
   118  	_, err = trie.TryGet([]byte("120000"))
   119  	if _, ok := err.(*MissingNodeError); !ok {
   120  		t.Errorf("Wrong error: %v", err)
   121  	}
   122  
   123  	trie, _ = New(root, db)
   124  	_, err = trie.TryGet([]byte("120099"))
   125  	if _, ok := err.(*MissingNodeError); !ok {
   126  		t.Errorf("Wrong error: %v", err)
   127  	}
   128  
   129  	trie, _ = New(root, db)
   130  	_, err = trie.TryGet([]byte("123456"))
   131  	if err != nil {
   132  		t.Errorf("Unexpected error: %v", err)
   133  	}
   134  
   135  	trie, _ = New(root, db)
   136  	err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
   137  	if _, ok := err.(*MissingNodeError); !ok {
   138  		t.Errorf("Wrong error: %v", err)
   139  	}
   140  
   141  	trie, _ = New(root, db)
   142  	err = trie.TryDelete([]byte("123456"))
   143  	if _, ok := err.(*MissingNodeError); !ok {
   144  		t.Errorf("Wrong error: %v", err)
   145  	}
   146  }
   147  
   148  func TestInsert(t *testing.T) {
   149  	trie := newEmpty()
   150  
   151  	updateString(trie, "doe", "reindeer")
   152  	updateString(trie, "dog", "puppy")
   153  	updateString(trie, "dogglesworth", "cat")
   154  
   155  	exp := common.HexToHash("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
   156  	root := trie.Hash()
   157  	if root != exp {
   158  		t.Errorf("exp %x got %x", exp, root)
   159  	}
   160  
   161  	trie = newEmpty()
   162  	updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
   163  
   164  	exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
   165  	root, err := trie.Commit()
   166  	if err != nil {
   167  		t.Fatalf("commit error: %v", err)
   168  	}
   169  	if root != exp {
   170  		t.Errorf("exp %x got %x", exp, root)
   171  	}
   172  }
   173  
   174  func TestGet(t *testing.T) {
   175  	trie := newEmpty()
   176  	updateString(trie, "doe", "reindeer")
   177  	updateString(trie, "dog", "puppy")
   178  	updateString(trie, "dogglesworth", "cat")
   179  
   180  	for i := 0; i < 2; i++ {
   181  		res := getString(trie, "dog")
   182  		if !bytes.Equal(res, []byte("puppy")) {
   183  			t.Errorf("expected puppy got %x", res)
   184  		}
   185  
   186  		unknown := getString(trie, "unknown")
   187  		if unknown != nil {
   188  			t.Errorf("expected nil got %x", unknown)
   189  		}
   190  
   191  		if i == 1 {
   192  			return
   193  		}
   194  		trie.Commit()
   195  	}
   196  }
   197  
   198  func TestDelete(t *testing.T) {
   199  	trie := newEmpty()
   200  	vals := []struct{ k, v string }{
   201  		{"do", "verb"},
   202  		{"ether", "wookiedoo"},
   203  		{"horse", "stallion"},
   204  		{"shaman", "horse"},
   205  		{"doge", "coin"},
   206  		{"ether", ""},
   207  		{"dog", "puppy"},
   208  		{"shaman", ""},
   209  	}
   210  	for _, val := range vals {
   211  		if val.v != "" {
   212  			updateString(trie, val.k, val.v)
   213  		} else {
   214  			deleteString(trie, val.k)
   215  		}
   216  	}
   217  
   218  	hash := trie.Hash()
   219  	exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
   220  	if hash != exp {
   221  		t.Errorf("expected %x got %x", exp, hash)
   222  	}
   223  }
   224  
   225  func TestEmptyValues(t *testing.T) {
   226  	trie := newEmpty()
   227  
   228  	vals := []struct{ k, v string }{
   229  		{"do", "verb"},
   230  		{"ether", "wookiedoo"},
   231  		{"horse", "stallion"},
   232  		{"shaman", "horse"},
   233  		{"doge", "coin"},
   234  		{"ether", ""},
   235  		{"dog", "puppy"},
   236  		{"shaman", ""},
   237  	}
   238  	for _, val := range vals {
   239  		updateString(trie, val.k, val.v)
   240  	}
   241  
   242  	hash := trie.Hash()
   243  	exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
   244  	if hash != exp {
   245  		t.Errorf("expected %x got %x", exp, hash)
   246  	}
   247  }
   248  
   249  func TestReplication(t *testing.T) {
   250  	trie := newEmpty()
   251  	vals := []struct{ k, v string }{
   252  		{"do", "verb"},
   253  		{"ether", "wookiedoo"},
   254  		{"horse", "stallion"},
   255  		{"shaman", "horse"},
   256  		{"doge", "coin"},
   257  		{"dog", "puppy"},
   258  		{"somethingveryoddindeedthis is", "myothernodedata"},
   259  	}
   260  	for _, val := range vals {
   261  		updateString(trie, val.k, val.v)
   262  	}
   263  	exp, err := trie.Commit()
   264  	if err != nil {
   265  		t.Fatalf("commit error: %v", err)
   266  	}
   267  
   268  	// create a new trie on top of the database and check that lookups work.
   269  	trie2, err := New(exp, trie.db)
   270  	if err != nil {
   271  		t.Fatalf("can't recreate trie at %x: %v", exp, err)
   272  	}
   273  	for _, kv := range vals {
   274  		if string(getString(trie2, kv.k)) != kv.v {
   275  			t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
   276  		}
   277  	}
   278  	hash, err := trie2.Commit()
   279  	if err != nil {
   280  		t.Fatalf("commit error: %v", err)
   281  	}
   282  	if hash != exp {
   283  		t.Errorf("root failure. expected %x got %x", exp, hash)
   284  	}
   285  
   286  	// perform some insertions on the new trie.
   287  	vals2 := []struct{ k, v string }{
   288  		{"do", "verb"},
   289  		{"ether", "wookiedoo"},
   290  		{"horse", "stallion"},
   291  		// {"shaman", "horse"},
   292  		// {"doge", "coin"},
   293  		// {"ether", ""},
   294  		// {"dog", "puppy"},
   295  		// {"somethingveryoddindeedthis is", "myothernodedata"},
   296  		// {"shaman", ""},
   297  	}
   298  	for _, val := range vals2 {
   299  		updateString(trie2, val.k, val.v)
   300  	}
   301  	if hash := trie2.Hash(); hash != exp {
   302  		t.Errorf("root failure. expected %x got %x", exp, hash)
   303  	}
   304  }
   305  
   306  func TestLargeValue(t *testing.T) {
   307  	trie := newEmpty()
   308  	trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
   309  	trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32))
   310  	trie.Hash()
   311  }
   312  
   313  type countingDB struct {
   314  	Database
   315  	gets map[string]int
   316  }
   317  
   318  func (db *countingDB) Get(key []byte) ([]byte, error) {
   319  	db.gets[string(key)]++
   320  	return db.Database.Get(key)
   321  }
   322  
   323  // TestCacheUnload checks that decoded nodes are unloaded after a
   324  // certain number of commit operations.
   325  func TestCacheUnload(t *testing.T) {
   326  	// Create test trie with two branches.
   327  	trie := newEmpty()
   328  	key1 := "---------------------------------"
   329  	key2 := "---some other branch"
   330  	updateString(trie, key1, "this is the branch of key1.")
   331  	updateString(trie, key2, "this is the branch of key2.")
   332  	root, _ := trie.Commit()
   333  
   334  	// Commit the trie repeatedly and access key1.
   335  	// The branch containing it is loaded from DB exactly two times:
   336  	// in the 0th and 6th iteration.
   337  	db := &countingDB{Database: trie.db, gets: make(map[string]int)}
   338  	trie, _ = New(root, db)
   339  	trie.SetCacheLimit(5)
   340  	for i := 0; i < 12; i++ {
   341  		getString(trie, key1)
   342  		trie.Commit()
   343  	}
   344  
   345  	// Check that it got loaded two times.
   346  	for dbkey, count := range db.gets {
   347  		if count != 2 {
   348  			t.Errorf("db key %x loaded %d times, want %d times", []byte(dbkey), count, 2)
   349  		}
   350  	}
   351  }
   352  
   353  // randTest performs random trie operations.
   354  // Instances of this test are created by Generate.
   355  type randTest []randTestStep
   356  
   357  type randTestStep struct {
   358  	op    int
   359  	key   []byte // for opUpdate, opDelete, opGet
   360  	value []byte // for opUpdate
   361  	err   error  // for debugging
   362  }
   363  
   364  const (
   365  	opUpdate = iota
   366  	opDelete
   367  	opGet
   368  	opCommit
   369  	opHash
   370  	opReset
   371  	opItercheckhash
   372  	opCheckCacheInvariant
   373  	opMax // boundary value, not an actual op
   374  )
   375  
   376  func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
   377  	var allKeys [][]byte
   378  	genKey := func() []byte {
   379  		if len(allKeys) < 2 || r.Intn(100) < 10 {
   380  			// new key
   381  			key := make([]byte, r.Intn(50))
   382  			r.Read(key)
   383  			allKeys = append(allKeys, key)
   384  			return key
   385  		}
   386  		// use existing key
   387  		return allKeys[r.Intn(len(allKeys))]
   388  	}
   389  
   390  	var steps randTest
   391  	for i := 0; i < size; i++ {
   392  		step := randTestStep{op: r.Intn(opMax)}
   393  		switch step.op {
   394  		case opUpdate:
   395  			step.key = genKey()
   396  			step.value = make([]byte, 8)
   397  			binary.BigEndian.PutUint64(step.value, uint64(i))
   398  		case opGet, opDelete:
   399  			step.key = genKey()
   400  		}
   401  		steps = append(steps, step)
   402  	}
   403  	return reflect.ValueOf(steps)
   404  }
   405  
   406  func runRandTest(rt randTest) bool {
   407  	db, _ := ethdb.NewMemDatabase()
   408  	tr, _ := New(common.Hash{}, db)
   409  	values := make(map[string]string) // tracks content of the trie
   410  
   411  	for i, step := range rt {
   412  		switch step.op {
   413  		case opUpdate:
   414  			tr.Update(step.key, step.value)
   415  			values[string(step.key)] = string(step.value)
   416  		case opDelete:
   417  			tr.Delete(step.key)
   418  			delete(values, string(step.key))
   419  		case opGet:
   420  			v := tr.Get(step.key)
   421  			want := values[string(step.key)]
   422  			if string(v) != want {
   423  				rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want)
   424  			}
   425  		case opCommit:
   426  			_, rt[i].err = tr.Commit()
   427  		case opHash:
   428  			tr.Hash()
   429  		case opReset:
   430  			hash, err := tr.Commit()
   431  			if err != nil {
   432  				rt[i].err = err
   433  				return false
   434  			}
   435  			newtr, err := New(hash, db)
   436  			if err != nil {
   437  				rt[i].err = err
   438  				return false
   439  			}
   440  			tr = newtr
   441  		case opItercheckhash:
   442  			checktr, _ := New(common.Hash{}, nil)
   443  			it := NewIterator(tr.NodeIterator(nil))
   444  			for it.Next() {
   445  				checktr.Update(it.Key, it.Value)
   446  			}
   447  			if tr.Hash() != checktr.Hash() {
   448  				rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash")
   449  			}
   450  		case opCheckCacheInvariant:
   451  			rt[i].err = checkCacheInvariant(tr.root, nil, tr.cachegen, false, 0)
   452  		}
   453  		// Abort the test on error.
   454  		if rt[i].err != nil {
   455  			return false
   456  		}
   457  	}
   458  	return true
   459  }
   460  
   461  func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool, depth int) error {
   462  	var children []node
   463  	var flag nodeFlag
   464  	switch n := n.(type) {
   465  	case *shortNode:
   466  		flag = n.flags
   467  		children = []node{n.Val}
   468  	case *fullNode:
   469  		flag = n.flags
   470  		children = n.Children[:]
   471  	default:
   472  		return nil
   473  	}
   474  
   475  	errorf := func(format string, args ...interface{}) error {
   476  		msg := fmt.Sprintf(format, args...)
   477  		msg += fmt.Sprintf("\nat depth %d node %s", depth, spew.Sdump(n))
   478  		msg += fmt.Sprintf("parent: %s", spew.Sdump(parent))
   479  		return errors.New(msg)
   480  	}
   481  	if flag.gen > parentCachegen {
   482  		return errorf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen)
   483  	}
   484  	if depth > 0 && !parentDirty && flag.dirty {
   485  		return errorf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen)
   486  	}
   487  	for _, child := range children {
   488  		if err := checkCacheInvariant(child, n, flag.gen, flag.dirty, depth+1); err != nil {
   489  			return err
   490  		}
   491  	}
   492  	return nil
   493  }
   494  
   495  func TestRandom(t *testing.T) {
   496  	if err := quick.Check(runRandTest, nil); err != nil {
   497  		if cerr, ok := err.(*quick.CheckError); ok {
   498  			t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In))
   499  		}
   500  		t.Fatal(err)
   501  	}
   502  }
   503  
   504  func BenchmarkGet(b *testing.B)      { benchGet(b, false) }
   505  func BenchmarkGetDB(b *testing.B)    { benchGet(b, true) }
   506  func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) }
   507  func BenchmarkUpdateLE(b *testing.B) { benchUpdate(b, binary.LittleEndian) }
   508  func BenchmarkHashBE(b *testing.B)   { benchHash(b, binary.BigEndian) }
   509  func BenchmarkHashLE(b *testing.B)   { benchHash(b, binary.LittleEndian) }
   510  
   511  const benchElemCount = 20000
   512  
   513  func benchGet(b *testing.B, commit bool) {
   514  	trie := new(Trie)
   515  	if commit {
   516  		_, tmpdb := tempDB()
   517  		trie, _ = New(common.Hash{}, tmpdb)
   518  	}
   519  	k := make([]byte, 32)
   520  	for i := 0; i < benchElemCount; i++ {
   521  		binary.LittleEndian.PutUint64(k, uint64(i))
   522  		trie.Update(k, k)
   523  	}
   524  	binary.LittleEndian.PutUint64(k, benchElemCount/2)
   525  	if commit {
   526  		trie.Commit()
   527  	}
   528  
   529  	b.ResetTimer()
   530  	for i := 0; i < b.N; i++ {
   531  		trie.Get(k)
   532  	}
   533  	b.StopTimer()
   534  
   535  	if commit {
   536  		ldb := trie.db.(*ethdb.LDBDatabase)
   537  		ldb.Close()
   538  		os.RemoveAll(ldb.Path())
   539  	}
   540  }
   541  
   542  func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {
   543  	trie := newEmpty()
   544  	k := make([]byte, 32)
   545  	for i := 0; i < b.N; i++ {
   546  		e.PutUint64(k, uint64(i))
   547  		trie.Update(k, k)
   548  	}
   549  	return trie
   550  }
   551  
   552  func benchHash(b *testing.B, e binary.ByteOrder) {
   553  	trie := newEmpty()
   554  	k := make([]byte, 32)
   555  	for i := 0; i < benchElemCount; i++ {
   556  		e.PutUint64(k, uint64(i))
   557  		trie.Update(k, k)
   558  	}
   559  
   560  	b.ResetTimer()
   561  	for i := 0; i < b.N; i++ {
   562  		trie.Hash()
   563  	}
   564  }
   565  
   566  func tempDB() (string, Database) {
   567  	dir, err := ioutil.TempDir("", "trie-bench")
   568  	if err != nil {
   569  		panic(fmt.Sprintf("can't create temporary directory: %v", err))
   570  	}
   571  	db, err := ethdb.NewLDBDatabase(dir, 256, 0)
   572  	if err != nil {
   573  		panic(fmt.Sprintf("can't create temporary database: %v", err))
   574  	}
   575  	return dir, db
   576  }
   577  
   578  func getString(trie *Trie, k string) []byte {
   579  	return trie.Get([]byte(k))
   580  }
   581  
   582  func updateString(trie *Trie, k, v string) {
   583  	trie.Update([]byte(k), []byte(v))
   584  }
   585  
   586  func deleteString(trie *Trie, k string) {
   587  	trie.Delete([]byte(k))
   588  }