github.com/sberex/go-sberex@v1.8.2-0.20181113200658-ed96ac38f7d7/trie/trie_test.go (about)

     1  // This file is part of the go-sberex library. The go-sberex library is 
     2  // free software: you can redistribute it and/or modify it under the terms 
     3  // of the GNU Lesser General Public License as published by the Free 
     4  // Software Foundation, either version 3 of the License, or (at your option)
     5  // any later version.
     6  //
     7  // The go-sberex library is distributed in the hope that it will be useful, 
     8  // but WITHOUT ANY WARRANTY; without even the implied warranty of
     9  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser 
    10  // General Public License <http://www.gnu.org/licenses/> for more details.
    11  
    12  package trie
    13  
    14  import (
    15  	"bytes"
    16  	"encoding/binary"
    17  	"errors"
    18  	"fmt"
    19  	"io/ioutil"
    20  	"math/big"
    21  	"math/rand"
    22  	"os"
    23  	"reflect"
    24  	"testing"
    25  	"testing/quick"
    26  
    27  	"github.com/davecgh/go-spew/spew"
    28  	"github.com/Sberex/go-sberex/common"
    29  	"github.com/Sberex/go-sberex/crypto"
    30  	"github.com/Sberex/go-sberex/ethdb"
    31  	"github.com/Sberex/go-sberex/rlp"
    32  )
    33  
    34  func init() {
    35  	spew.Config.Indent = "    "
    36  	spew.Config.DisableMethods = false
    37  }
    38  
    39  // Used for testing
    40  func newEmpty() *Trie {
    41  	diskdb, _ := ethdb.NewMemDatabase()
    42  	trie, _ := New(common.Hash{}, NewDatabase(diskdb))
    43  	return trie
    44  }
    45  
    46  func TestEmptyTrie(t *testing.T) {
    47  	var trie Trie
    48  	res := trie.Hash()
    49  	exp := emptyRoot
    50  	if res != common.Hash(exp) {
    51  		t.Errorf("expected %x got %x", exp, res)
    52  	}
    53  }
    54  
    55  func TestNull(t *testing.T) {
    56  	var trie Trie
    57  	key := make([]byte, 32)
    58  	value := []byte("test")
    59  	trie.Update(key, value)
    60  	if !bytes.Equal(trie.Get(key), value) {
    61  		t.Fatal("wrong value")
    62  	}
    63  }
    64  
    65  func TestMissingRoot(t *testing.T) {
    66  	diskdb, _ := ethdb.NewMemDatabase()
    67  	trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(diskdb))
    68  	if trie != nil {
    69  		t.Error("New returned non-nil trie for invalid root")
    70  	}
    71  	if _, ok := err.(*MissingNodeError); !ok {
    72  		t.Errorf("New returned wrong error: %v", err)
    73  	}
    74  }
    75  
    76  func TestMissingNodeDisk(t *testing.T)    { testMissingNode(t, false) }
    77  func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) }
    78  
    79  func testMissingNode(t *testing.T, memonly bool) {
    80  	diskdb, _ := ethdb.NewMemDatabase()
    81  	triedb := NewDatabase(diskdb)
    82  
    83  	trie, _ := New(common.Hash{}, triedb)
    84  	updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer")
    85  	updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf")
    86  	root, _ := trie.Commit(nil)
    87  	if !memonly {
    88  		triedb.Commit(root, true)
    89  	}
    90  
    91  	trie, _ = New(root, triedb)
    92  	_, err := trie.TryGet([]byte("120000"))
    93  	if err != nil {
    94  		t.Errorf("Unexpected error: %v", err)
    95  	}
    96  	trie, _ = New(root, triedb)
    97  	_, err = trie.TryGet([]byte("120099"))
    98  	if err != nil {
    99  		t.Errorf("Unexpected error: %v", err)
   100  	}
   101  	trie, _ = New(root, triedb)
   102  	_, err = trie.TryGet([]byte("123456"))
   103  	if err != nil {
   104  		t.Errorf("Unexpected error: %v", err)
   105  	}
   106  	trie, _ = New(root, triedb)
   107  	err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
   108  	if err != nil {
   109  		t.Errorf("Unexpected error: %v", err)
   110  	}
   111  	trie, _ = New(root, triedb)
   112  	err = trie.TryDelete([]byte("123456"))
   113  	if err != nil {
   114  		t.Errorf("Unexpected error: %v", err)
   115  	}
   116  
   117  	hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")
   118  	if memonly {
   119  		delete(triedb.nodes, hash)
   120  	} else {
   121  		diskdb.Delete(hash[:])
   122  	}
   123  
   124  	trie, _ = New(root, triedb)
   125  	_, err = trie.TryGet([]byte("120000"))
   126  	if _, ok := err.(*MissingNodeError); !ok {
   127  		t.Errorf("Wrong error: %v", err)
   128  	}
   129  	trie, _ = New(root, triedb)
   130  	_, err = trie.TryGet([]byte("120099"))
   131  	if _, ok := err.(*MissingNodeError); !ok {
   132  		t.Errorf("Wrong error: %v", err)
   133  	}
   134  	trie, _ = New(root, triedb)
   135  	_, err = trie.TryGet([]byte("123456"))
   136  	if err != nil {
   137  		t.Errorf("Unexpected error: %v", err)
   138  	}
   139  	trie, _ = New(root, triedb)
   140  	err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
   141  	if _, ok := err.(*MissingNodeError); !ok {
   142  		t.Errorf("Wrong error: %v", err)
   143  	}
   144  	trie, _ = New(root, triedb)
   145  	err = trie.TryDelete([]byte("123456"))
   146  	if _, ok := err.(*MissingNodeError); !ok {
   147  		t.Errorf("Wrong error: %v", err)
   148  	}
   149  }
   150  
   151  func TestInsert(t *testing.T) {
   152  	trie := newEmpty()
   153  
   154  	updateString(trie, "doe", "reindeer")
   155  	updateString(trie, "dog", "puppy")
   156  	updateString(trie, "dogglesworth", "cat")
   157  
   158  	exp := common.HexToHash("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
   159  	root := trie.Hash()
   160  	if root != exp {
   161  		t.Errorf("exp %x got %x", exp, root)
   162  	}
   163  
   164  	trie = newEmpty()
   165  	updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
   166  
   167  	exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
   168  	root, err := trie.Commit(nil)
   169  	if err != nil {
   170  		t.Fatalf("commit error: %v", err)
   171  	}
   172  	if root != exp {
   173  		t.Errorf("exp %x got %x", exp, root)
   174  	}
   175  }
   176  
   177  func TestGet(t *testing.T) {
   178  	trie := newEmpty()
   179  	updateString(trie, "doe", "reindeer")
   180  	updateString(trie, "dog", "puppy")
   181  	updateString(trie, "dogglesworth", "cat")
   182  
   183  	for i := 0; i < 2; i++ {
   184  		res := getString(trie, "dog")
   185  		if !bytes.Equal(res, []byte("puppy")) {
   186  			t.Errorf("expected puppy got %x", res)
   187  		}
   188  
   189  		unknown := getString(trie, "unknown")
   190  		if unknown != nil {
   191  			t.Errorf("expected nil got %x", unknown)
   192  		}
   193  
   194  		if i == 1 {
   195  			return
   196  		}
   197  		trie.Commit(nil)
   198  	}
   199  }
   200  
   201  func TestDelete(t *testing.T) {
   202  	trie := newEmpty()
   203  	vals := []struct{ k, v string }{
   204  		{"do", "verb"},
   205  		{"sbr", "wookiedoo"},
   206  		{"horse", "stallion"},
   207  		{"shaman", "horse"},
   208  		{"doge", "coin"},
   209  		{"sbr", ""},
   210  		{"dog", "puppy"},
   211  		{"shaman", ""},
   212  	}
   213  	for _, val := range vals {
   214  		if val.v != "" {
   215  			updateString(trie, val.k, val.v)
   216  		} else {
   217  			deleteString(trie, val.k)
   218  		}
   219  	}
   220  
   221  	hash := trie.Hash()
   222  	exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
   223  	if hash != exp {
   224  		t.Errorf("expected %x got %x", exp, hash)
   225  	}
   226  }
   227  
   228  func TestEmptyValues(t *testing.T) {
   229  	trie := newEmpty()
   230  
   231  	vals := []struct{ k, v string }{
   232  		{"do", "verb"},
   233  		{"sbr", "wookiedoo"},
   234  		{"horse", "stallion"},
   235  		{"shaman", "horse"},
   236  		{"doge", "coin"},
   237  		{"sbr", ""},
   238  		{"dog", "puppy"},
   239  		{"shaman", ""},
   240  	}
   241  	for _, val := range vals {
   242  		updateString(trie, val.k, val.v)
   243  	}
   244  
   245  	hash := trie.Hash()
   246  	exp := common.HexToHash("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
   247  	if hash != exp {
   248  		t.Errorf("expected %x got %x", exp, hash)
   249  	}
   250  }
   251  
   252  func TestReplication(t *testing.T) {
   253  	trie := newEmpty()
   254  	vals := []struct{ k, v string }{
   255  		{"do", "verb"},
   256  		{"sbr", "wookiedoo"},
   257  		{"horse", "stallion"},
   258  		{"shaman", "horse"},
   259  		{"doge", "coin"},
   260  		{"dog", "puppy"},
   261  		{"somethingveryoddindeedthis is", "myothernodedata"},
   262  	}
   263  	for _, val := range vals {
   264  		updateString(trie, val.k, val.v)
   265  	}
   266  	exp, err := trie.Commit(nil)
   267  	if err != nil {
   268  		t.Fatalf("commit error: %v", err)
   269  	}
   270  
   271  	// create a new trie on top of the database and check that lookups work.
   272  	trie2, err := New(exp, trie.db)
   273  	if err != nil {
   274  		t.Fatalf("can't recreate trie at %x: %v", exp, err)
   275  	}
   276  	for _, kv := range vals {
   277  		if string(getString(trie2, kv.k)) != kv.v {
   278  			t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
   279  		}
   280  	}
   281  	hash, err := trie2.Commit(nil)
   282  	if err != nil {
   283  		t.Fatalf("commit error: %v", err)
   284  	}
   285  	if hash != exp {
   286  		t.Errorf("root failure. expected %x got %x", exp, hash)
   287  	}
   288  
   289  	// perform some insertions on the new trie.
   290  	vals2 := []struct{ k, v string }{
   291  		{"do", "verb"},
   292  		{"sbr", "wookiedoo"},
   293  		{"horse", "stallion"},
   294  		// {"shaman", "horse"},
   295  		// {"doge", "coin"},
   296  		// {"sbr", ""},
   297  		// {"dog", "puppy"},
   298  		// {"somethingveryoddindeedthis is", "myothernodedata"},
   299  		// {"shaman", ""},
   300  	}
   301  	for _, val := range vals2 {
   302  		updateString(trie2, val.k, val.v)
   303  	}
   304  	if hash := trie2.Hash(); hash != exp {
   305  		t.Errorf("root failure. expected %x got %x", exp, hash)
   306  	}
   307  }
   308  
   309  func TestLargeValue(t *testing.T) {
   310  	trie := newEmpty()
   311  	trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
   312  	trie.Update([]byte("key2"), bytes.Repeat([]byte{1}, 32))
   313  	trie.Hash()
   314  }
   315  
   316  type countingDB struct {
   317  	ethdb.Database
   318  	gets map[string]int
   319  }
   320  
   321  func (db *countingDB) Get(key []byte) ([]byte, error) {
   322  	db.gets[string(key)]++
   323  	return db.Database.Get(key)
   324  }
   325  
   326  // TestCacheUnload checks that decoded nodes are unloaded after a
   327  // certain number of commit operations.
   328  func TestCacheUnload(t *testing.T) {
   329  	// Create test trie with two branches.
   330  	trie := newEmpty()
   331  	key1 := "---------------------------------"
   332  	key2 := "---some other branch"
   333  	updateString(trie, key1, "this is the branch of key1.")
   334  	updateString(trie, key2, "this is the branch of key2.")
   335  
   336  	root, _ := trie.Commit(nil)
   337  	trie.db.Commit(root, true)
   338  
   339  	// Commit the trie repeatedly and access key1.
   340  	// The branch containing it is loaded from DB exactly two times:
   341  	// in the 0th and 6th iteration.
   342  	db := &countingDB{Database: trie.db.diskdb, gets: make(map[string]int)}
   343  	trie, _ = New(root, NewDatabase(db))
   344  	trie.SetCacheLimit(5)
   345  	for i := 0; i < 12; i++ {
   346  		getString(trie, key1)
   347  		trie.Commit(nil)
   348  	}
   349  	// Check that it got loaded two times.
   350  	for dbkey, count := range db.gets {
   351  		if count != 2 {
   352  			t.Errorf("db key %x loaded %d times, want %d times", []byte(dbkey), count, 2)
   353  		}
   354  	}
   355  }
   356  
// randTest performs random trie operations.
// Instances of this test are created by Generate.
type randTest []randTestStep

// randTestStep is a single operation in a randTest sequence. Which fields
// are populated depends on op (see the per-field notes).
type randTestStep struct {
	op    int
	key   []byte // for opUpdate, opDelete, opGet
	value []byte // for opUpdate
	err   error  // for debugging; set by runRandTest when this step fails
}

// Operation codes for randTestStep.op. Generate draws op uniformly from
// [0, opMax), so the order and count of these constants determine the
// generated mix.
const (
	// opUpdate stores value under key and mirrors it in the expected-values map.
	opUpdate = iota
	// opDelete removes key from the trie and the expected-values map.
	opDelete
	// opGet compares the trie's value for key with the expected value.
	opGet
	// opCommit commits the trie.
	opCommit
	// opHash computes the root hash without checking it.
	opHash
	// opReset commits and reopens the trie from the resulting root hash.
	opReset
	// opItercheckhash rebuilds a trie from iterator output and compares hashes.
	opItercheckhash
	// opCheckCacheInvariant runs checkCacheInvariant on the current root.
	opCheckCacheInvariant
	opMax // boundary value, not an actual op
)
   379  
   380  func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
   381  	var allKeys [][]byte
   382  	genKey := func() []byte {
   383  		if len(allKeys) < 2 || r.Intn(100) < 10 {
   384  			// new key
   385  			key := make([]byte, r.Intn(50))
   386  			r.Read(key)
   387  			allKeys = append(allKeys, key)
   388  			return key
   389  		}
   390  		// use existing key
   391  		return allKeys[r.Intn(len(allKeys))]
   392  	}
   393  
   394  	var steps randTest
   395  	for i := 0; i < size; i++ {
   396  		step := randTestStep{op: r.Intn(opMax)}
   397  		switch step.op {
   398  		case opUpdate:
   399  			step.key = genKey()
   400  			step.value = make([]byte, 8)
   401  			binary.BigEndian.PutUint64(step.value, uint64(i))
   402  		case opGet, opDelete:
   403  			step.key = genKey()
   404  		}
   405  		steps = append(steps, step)
   406  	}
   407  	return reflect.ValueOf(steps)
   408  }
   409  
   410  func runRandTest(rt randTest) bool {
   411  	diskdb, _ := ethdb.NewMemDatabase()
   412  	triedb := NewDatabase(diskdb)
   413  
   414  	tr, _ := New(common.Hash{}, triedb)
   415  	values := make(map[string]string) // tracks content of the trie
   416  
   417  	for i, step := range rt {
   418  		switch step.op {
   419  		case opUpdate:
   420  			tr.Update(step.key, step.value)
   421  			values[string(step.key)] = string(step.value)
   422  		case opDelete:
   423  			tr.Delete(step.key)
   424  			delete(values, string(step.key))
   425  		case opGet:
   426  			v := tr.Get(step.key)
   427  			want := values[string(step.key)]
   428  			if string(v) != want {
   429  				rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want)
   430  			}
   431  		case opCommit:
   432  			_, rt[i].err = tr.Commit(nil)
   433  		case opHash:
   434  			tr.Hash()
   435  		case opReset:
   436  			hash, err := tr.Commit(nil)
   437  			if err != nil {
   438  				rt[i].err = err
   439  				return false
   440  			}
   441  			newtr, err := New(hash, triedb)
   442  			if err != nil {
   443  				rt[i].err = err
   444  				return false
   445  			}
   446  			tr = newtr
   447  		case opItercheckhash:
   448  			checktr, _ := New(common.Hash{}, triedb)
   449  			it := NewIterator(tr.NodeIterator(nil))
   450  			for it.Next() {
   451  				checktr.Update(it.Key, it.Value)
   452  			}
   453  			if tr.Hash() != checktr.Hash() {
   454  				rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash")
   455  			}
   456  		case opCheckCacheInvariant:
   457  			rt[i].err = checkCacheInvariant(tr.root, nil, tr.cachegen, false, 0)
   458  		}
   459  		// Abort the test on error.
   460  		if rt[i].err != nil {
   461  			return false
   462  		}
   463  	}
   464  	return true
   465  }
   466  
   467  func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool, depth int) error {
   468  	var children []node
   469  	var flag nodeFlag
   470  	switch n := n.(type) {
   471  	case *shortNode:
   472  		flag = n.flags
   473  		children = []node{n.Val}
   474  	case *fullNode:
   475  		flag = n.flags
   476  		children = n.Children[:]
   477  	default:
   478  		return nil
   479  	}
   480  
   481  	errorf := func(format string, args ...interface{}) error {
   482  		msg := fmt.Sprintf(format, args...)
   483  		msg += fmt.Sprintf("\nat depth %d node %s", depth, spew.Sdump(n))
   484  		msg += fmt.Sprintf("parent: %s", spew.Sdump(parent))
   485  		return errors.New(msg)
   486  	}
   487  	if flag.gen > parentCachegen {
   488  		return errorf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen)
   489  	}
   490  	if depth > 0 && !parentDirty && flag.dirty {
   491  		return errorf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen)
   492  	}
   493  	for _, child := range children {
   494  		if err := checkCacheInvariant(child, n, flag.gen, flag.dirty, depth+1); err != nil {
   495  			return err
   496  		}
   497  	}
   498  	return nil
   499  }
   500  
   501  func TestRandom(t *testing.T) {
   502  	if err := quick.Check(runRandTest, nil); err != nil {
   503  		if cerr, ok := err.(*quick.CheckError); ok {
   504  			t.Fatalf("random test iteration %d failed: %s", cerr.Count, spew.Sdump(cerr.In))
   505  		}
   506  		t.Fatal(err)
   507  	}
   508  }
   509  
   510  func BenchmarkGet(b *testing.B)      { benchGet(b, false) }
   511  func BenchmarkGetDB(b *testing.B)    { benchGet(b, true) }
   512  func BenchmarkUpdateBE(b *testing.B) { benchUpdate(b, binary.BigEndian) }
   513  func BenchmarkUpdateLE(b *testing.B) { benchUpdate(b, binary.LittleEndian) }
   514  
   515  const benchElemCount = 20000
   516  
   517  func benchGet(b *testing.B, commit bool) {
   518  	trie := new(Trie)
   519  	if commit {
   520  		_, tmpdb := tempDB()
   521  		trie, _ = New(common.Hash{}, tmpdb)
   522  	}
   523  	k := make([]byte, 32)
   524  	for i := 0; i < benchElemCount; i++ {
   525  		binary.LittleEndian.PutUint64(k, uint64(i))
   526  		trie.Update(k, k)
   527  	}
   528  	binary.LittleEndian.PutUint64(k, benchElemCount/2)
   529  	if commit {
   530  		trie.Commit(nil)
   531  	}
   532  
   533  	b.ResetTimer()
   534  	for i := 0; i < b.N; i++ {
   535  		trie.Get(k)
   536  	}
   537  	b.StopTimer()
   538  
   539  	if commit {
   540  		ldb := trie.db.diskdb.(*ethdb.LDBDatabase)
   541  		ldb.Close()
   542  		os.RemoveAll(ldb.Path())
   543  	}
   544  }
   545  
   546  func benchUpdate(b *testing.B, e binary.ByteOrder) *Trie {
   547  	trie := newEmpty()
   548  	k := make([]byte, 32)
   549  	for i := 0; i < b.N; i++ {
   550  		e.PutUint64(k, uint64(i))
   551  		trie.Update(k, k)
   552  	}
   553  	return trie
   554  }
   555  
   556  // Benchmarks the trie hashing. Since the trie caches the result of any operation,
   557  // we cannot use b.N as the number of hashing rouns, since all rounds apart from
   558  // the first one will be NOOP. As such, we'll use b.N as the number of account to
   559  // insert into the trie before measuring the hashing.
   560  func BenchmarkHash(b *testing.B) {
   561  	// Make the random benchmark deterministic
   562  	random := rand.New(rand.NewSource(0))
   563  
   564  	// Create a realistic account trie to hash
   565  	addresses := make([][20]byte, b.N)
   566  	for i := 0; i < len(addresses); i++ {
   567  		for j := 0; j < len(addresses[i]); j++ {
   568  			addresses[i][j] = byte(random.Intn(256))
   569  		}
   570  	}
   571  	accounts := make([][]byte, len(addresses))
   572  	for i := 0; i < len(accounts); i++ {
   573  		var (
   574  			nonce   = uint64(random.Int63())
   575  			balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil))
   576  			root    = emptyRoot
   577  			code    = crypto.Keccak256(nil)
   578  		)
   579  		accounts[i], _ = rlp.EncodeToBytes([]interface{}{nonce, balance, root, code})
   580  	}
   581  	// Insert the accounts into the trie and hash it
   582  	trie := newEmpty()
   583  	for i := 0; i < len(addresses); i++ {
   584  		trie.Update(crypto.Keccak256(addresses[i][:]), accounts[i])
   585  	}
   586  	b.ResetTimer()
   587  	b.ReportAllocs()
   588  	trie.Hash()
   589  }
   590  
   591  func tempDB() (string, *Database) {
   592  	dir, err := ioutil.TempDir("", "trie-bench")
   593  	if err != nil {
   594  		panic(fmt.Sprintf("can't create temporary directory: %v", err))
   595  	}
   596  	diskdb, err := ethdb.NewLDBDatabase(dir, 256, 0)
   597  	if err != nil {
   598  		panic(fmt.Sprintf("can't create temporary database: %v", err))
   599  	}
   600  	return dir, NewDatabase(diskdb)
   601  }
   602  
   603  func getString(trie *Trie, k string) []byte {
   604  	return trie.Get([]byte(k))
   605  }
   606  
   607  func updateString(trie *Trie, k, v string) {
   608  	trie.Update([]byte(k), []byte(v))
   609  }
   610  
   611  func deleteString(trie *Trie, k string) {
   612  	trie.Delete([]byte(k))
   613  }