
     1  // Copyright 2022 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    17  package pathdb
    19  import (
    20  	"bytes"
    21  	"errors"
    22  	"fmt"
    23  	"math/big"
    24  	"math/rand"
    25  	"testing"
    27  	""
    28  	""
    29  	""
    30  	""
    31  	""
    32  	""
    33  	""
    34  	""
    35  )
    37  func updateTrie(addrHash common.Hash, root common.Hash, dirties, cleans map[common.Hash][]byte) (common.Hash, *trienode.NodeSet) {
    38  	h, err := newTestHasher(addrHash, root, cleans)
    39  	if err != nil {
    40  		panic(fmt.Errorf("failed to create hasher, err: %w", err))
    41  	}
    42  	for key, val := range dirties {
    43  		if len(val) == 0 {
    44  			h.Delete(key.Bytes())
    45  		} else {
    46  			h.Update(key.Bytes(), val)
    47  		}
    48  	}
    49  	root, nodes, _ := h.Commit(false)
    50  	return root, nodes
    51  }
    53  func generateAccount(storageRoot common.Hash) types.StateAccount {
    54  	return types.StateAccount{
    55  		Nonce:    uint64(rand.Intn(100)),
    56  		Balance:  big.NewInt(rand.Int63()),
    57  		CodeHash: testutil.RandBytes(32),
    58  		Root:     storageRoot,
    59  	}
    60  }
    62  const (
    63  	createAccountOp int = iota
    64  	modifyAccountOp
    65  	deleteAccountOp
    66  	opLen
    67  )
    69  type genctx struct {
    70  	accounts      map[common.Hash][]byte
    71  	storages      map[common.Hash]map[common.Hash][]byte
    72  	accountOrigin map[common.Address][]byte
    73  	storageOrigin map[common.Address]map[common.Hash][]byte
    74  	nodes         *trienode.MergedNodeSet
    75  }
    77  func newCtx() *genctx {
    78  	return &genctx{
    79  		accounts:      make(map[common.Hash][]byte),
    80  		storages:      make(map[common.Hash]map[common.Hash][]byte),
    81  		accountOrigin: make(map[common.Address][]byte),
    82  		storageOrigin: make(map[common.Address]map[common.Hash][]byte),
    83  		nodes:         trienode.NewMergedNodeSet(),
    84  	}
    85  }
    87  type tester struct {
    88  	db        *Database
    89  	roots     []common.Hash
    90  	preimages map[common.Hash]common.Address
    91  	accounts  map[common.Hash][]byte
    92  	storages  map[common.Hash]map[common.Hash][]byte
    94  	// state snapshots
    95  	snapAccounts map[common.Hash]map[common.Hash][]byte
    96  	snapStorages map[common.Hash]map[common.Hash]map[common.Hash][]byte
    97  }
    99  func newTester(t *testing.T) *tester {
   100  	var (
   101  		disk, _ = rawdb.NewDatabaseWithFreezer(rawdb.NewMemoryDatabase(), t.TempDir(), "", false)
   102  		db      = New(disk, &Config{CleanCacheSize: 256 * 1024, DirtyCacheSize: 256 * 1024})
   103  		obj     = &tester{
   104  			db:           db,
   105  			preimages:    make(map[common.Hash]common.Address),
   106  			accounts:     make(map[common.Hash][]byte),
   107  			storages:     make(map[common.Hash]map[common.Hash][]byte),
   108  			snapAccounts: make(map[common.Hash]map[common.Hash][]byte),
   109  			snapStorages: make(map[common.Hash]map[common.Hash]map[common.Hash][]byte),
   110  		}
   111  	)
   112  	for i := 0; i < 2*128; i++ {
   113  		var parent = types.EmptyRootHash
   114  		if len(obj.roots) != 0 {
   115  			parent = obj.roots[len(obj.roots)-1]
   116  		}
   117  		root, nodes, states := obj.generate(parent)
   118  		if err := db.Update(root, parent, uint64(i), nodes, states); err != nil {
   119  			panic(fmt.Errorf("failed to update state changes, err: %w", err))
   120  		}
   121  		obj.roots = append(obj.roots, root)
   122  	}
   123  	return obj
   124  }
   126  func (t *tester) release() {
   127  	t.db.Close()
   128  	t.db.diskdb.Close()
   129  }
   131  func (t *tester) randAccount() (common.Address, []byte) {
   132  	for addrHash, account := range t.accounts {
   133  		return t.preimages[addrHash], account
   134  	}
   135  	return common.Address{}, nil
   136  }
   138  func (t *tester) generateStorage(ctx *genctx, addr common.Address) common.Hash {
   139  	var (
   140  		addrHash = crypto.Keccak256Hash(addr.Bytes())
   141  		storage  = make(map[common.Hash][]byte)
   142  		origin   = make(map[common.Hash][]byte)
   143  	)
   144  	for i := 0; i < 10; i++ {
   145  		v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(testutil.RandBytes(32)))
   146  		hash := testutil.RandomHash()
   148  		storage[hash] = v
   149  		origin[hash] = nil
   150  	}
   151  	root, set := updateTrie(addrHash, types.EmptyRootHash, storage, nil)
   153  	ctx.storages[addrHash] = storage
   154  	ctx.storageOrigin[addr] = origin
   155  	ctx.nodes.Merge(set)
   156  	return root
   157  }
   159  func (t *tester) mutateStorage(ctx *genctx, addr common.Address, root common.Hash) common.Hash {
   160  	var (
   161  		addrHash = crypto.Keccak256Hash(addr.Bytes())
   162  		storage  = make(map[common.Hash][]byte)
   163  		origin   = make(map[common.Hash][]byte)
   164  	)
   165  	for hash, val := range t.storages[addrHash] {
   166  		origin[hash] = val
   167  		storage[hash] = nil
   169  		if len(origin) == 3 {
   170  			break
   171  		}
   172  	}
   173  	for i := 0; i < 3; i++ {
   174  		v, _ := rlp.EncodeToBytes(common.TrimLeftZeroes(testutil.RandBytes(32)))
   175  		hash := testutil.RandomHash()
   177  		storage[hash] = v
   178  		origin[hash] = nil
   179  	}
   180  	root, set := updateTrie(crypto.Keccak256Hash(addr.Bytes()), root, storage, t.storages[addrHash])
   182  	ctx.storages[addrHash] = storage
   183  	ctx.storageOrigin[addr] = origin
   184  	ctx.nodes.Merge(set)
   185  	return root
   186  }
   188  func (t *tester) clearStorage(ctx *genctx, addr common.Address, root common.Hash) common.Hash {
   189  	var (
   190  		addrHash = crypto.Keccak256Hash(addr.Bytes())
   191  		storage  = make(map[common.Hash][]byte)
   192  		origin   = make(map[common.Hash][]byte)
   193  	)
   194  	for hash, val := range t.storages[addrHash] {
   195  		origin[hash] = val
   196  		storage[hash] = nil
   197  	}
   198  	root, set := updateTrie(addrHash, root, storage, t.storages[addrHash])
   199  	if root != types.EmptyRootHash {
   200  		panic("failed to clear storage trie")
   201  	}
   202  	ctx.storages[addrHash] = storage
   203  	ctx.storageOrigin[addr] = origin
   204  	ctx.nodes.Merge(set)
   205  	return root
   206  }
   208  func (t *tester) generate(parent common.Hash) (common.Hash, *trienode.MergedNodeSet, *triestate.Set) {
   209  	var (
   210  		ctx     = newCtx()
   211  		dirties = make(map[common.Hash]struct{})
   212  	)
   213  	for i := 0; i < 20; i++ {
   214  		switch rand.Intn(opLen) {
   215  		case createAccountOp:
   216  			// account creation
   217  			addr := testutil.RandomAddress()
   218  			addrHash := crypto.Keccak256Hash(addr.Bytes())
   219  			if _, ok := t.accounts[addrHash]; ok {
   220  				continue
   221  			}
   222  			if _, ok := dirties[addrHash]; ok {
   223  				continue
   224  			}
   225  			dirties[addrHash] = struct{}{}
   227  			root := t.generateStorage(ctx, addr)
   228  			ctx.accounts[addrHash] = types.SlimAccountRLP(generateAccount(root))
   229  			ctx.accountOrigin[addr] = nil
   230  			t.preimages[addrHash] = addr
   232  		case modifyAccountOp:
   233  			// account mutation
   234  			addr, account := t.randAccount()
   235  			if addr == (common.Address{}) {
   236  				continue
   237  			}
   238  			addrHash := crypto.Keccak256Hash(addr.Bytes())
   239  			if _, ok := dirties[addrHash]; ok {
   240  				continue
   241  			}
   242  			dirties[addrHash] = struct{}{}
   244  			acct, _ := types.FullAccount(account)
   245  			stRoot := t.mutateStorage(ctx, addr, acct.Root)
   246  			newAccount := types.SlimAccountRLP(generateAccount(stRoot))
   248  			ctx.accounts[addrHash] = newAccount
   249  			ctx.accountOrigin[addr] = account
   251  		case deleteAccountOp:
   252  			// account deletion
   253  			addr, account := t.randAccount()
   254  			if addr == (common.Address{}) {
   255  				continue
   256  			}
   257  			addrHash := crypto.Keccak256Hash(addr.Bytes())
   258  			if _, ok := dirties[addrHash]; ok {
   259  				continue
   260  			}
   261  			dirties[addrHash] = struct{}{}
   263  			acct, _ := types.FullAccount(account)
   264  			if acct.Root != types.EmptyRootHash {
   265  				t.clearStorage(ctx, addr, acct.Root)
   266  			}
   267  			ctx.accounts[addrHash] = nil
   268  			ctx.accountOrigin[addr] = account
   269  		}
   270  	}
   271  	root, set := updateTrie(common.Hash{}, parent, ctx.accounts, t.accounts)
   272  	ctx.nodes.Merge(set)
   274  	// Save state snapshot before commit
   275  	t.snapAccounts[parent] = copyAccounts(t.accounts)
   276  	t.snapStorages[parent] = copyStorages(t.storages)
   278  	// Commit all changes to live state set
   279  	for addrHash, account := range ctx.accounts {
   280  		if len(account) == 0 {
   281  			delete(t.accounts, addrHash)
   282  		} else {
   283  			t.accounts[addrHash] = account
   284  		}
   285  	}
   286  	for addrHash, slots := range ctx.storages {
   287  		if _, ok := t.storages[addrHash]; !ok {
   288  			t.storages[addrHash] = make(map[common.Hash][]byte)
   289  		}
   290  		for sHash, slot := range slots {
   291  			if len(slot) == 0 {
   292  				delete(t.storages[addrHash], sHash)
   293  			} else {
   294  				t.storages[addrHash][sHash] = slot
   295  			}
   296  		}
   297  	}
   298  	return root, ctx.nodes, triestate.New(ctx.accountOrigin, ctx.storageOrigin, nil)
   299  }
   301  // lastRoot returns the latest root hash, or empty if nothing is cached.
   302  func (t *tester) lastHash() common.Hash {
   303  	if len(t.roots) == 0 {
   304  		return common.Hash{}
   305  	}
   306  	return t.roots[len(t.roots)-1]
   307  }
   309  func (t *tester) verifyState(root common.Hash) error {
   310  	reader, err := t.db.Reader(root)
   311  	if err != nil {
   312  		return err
   313  	}
   314  	_, err = reader.Node(common.Hash{}, nil, root)
   315  	if err != nil {
   316  		return errors.New("root node is not available")
   317  	}
   318  	for addrHash, account := range t.snapAccounts[root] {
   319  		blob, err := reader.Node(common.Hash{}, addrHash.Bytes(), crypto.Keccak256Hash(account))
   320  		if err != nil || !bytes.Equal(blob, account) {
   321  			return fmt.Errorf("account is mismatched: %w", err)
   322  		}
   323  	}
   324  	for addrHash, slots := range t.snapStorages[root] {
   325  		for hash, slot := range slots {
   326  			blob, err := reader.Node(addrHash, hash.Bytes(), crypto.Keccak256Hash(slot))
   327  			if err != nil || !bytes.Equal(blob, slot) {
   328  				return fmt.Errorf("slot is mismatched: %w", err)
   329  			}
   330  		}
   331  	}
   332  	return nil
   333  }
   335  func (t *tester) verifyHistory() error {
   336  	bottom := t.bottomIndex()
   337  	for i, root := range t.roots {
   338  		// The state history related to the state above disk layer should not exist.
   339  		if i > bottom {
   340  			_, err := readHistory(t.db.freezer, uint64(i+1))
   341  			if err == nil {
   342  				return errors.New("unexpected state history")
   343  			}
   344  			continue
   345  		}
   346  		// The state history related to the state below or equal to the disk layer
   347  		// should exist.
   348  		obj, err := readHistory(t.db.freezer, uint64(i+1))
   349  		if err != nil {
   350  			return err
   351  		}
   352  		parent := types.EmptyRootHash
   353  		if i != 0 {
   354  			parent = t.roots[i-1]
   355  		}
   356  		if obj.meta.parent != parent {
   357  			return fmt.Errorf("unexpected parent, want: %x, got: %x", parent, obj.meta.parent)
   358  		}
   359  		if obj.meta.root != root {
   360  			return fmt.Errorf("unexpected root, want: %x, got: %x", root, obj.meta.root)
   361  		}
   362  	}
   363  	return nil
   364  }
   366  // bottomIndex returns the index of current disk layer.
   367  func (t *tester) bottomIndex() int {
   368  	bottom := t.db.tree.bottom()
   369  	for i := 0; i < len(t.roots); i++ {
   370  		if t.roots[i] == bottom.rootHash() {
   371  			return i
   372  		}
   373  	}
   374  	return -1
   375  }
   377  func TestDatabaseRollback(t *testing.T) {
   378  	// Verify state histories
   379  	tester := newTester(t)
   380  	defer tester.release()
   382  	if err := tester.verifyHistory(); err != nil {
   383  		t.Fatalf("Invalid state history, err: %v", err)
   384  	}
   385  	// Revert database from top to bottom
   386  	for i := tester.bottomIndex(); i >= 0; i-- {
   387  		root := tester.roots[i]
   388  		parent := types.EmptyRootHash
   389  		if i > 0 {
   390  			parent = tester.roots[i-1]
   391  		}
   392  		loader := newHashLoader(tester.snapAccounts[root], tester.snapStorages[root])
   393  		if err := tester.db.Recover(parent, loader); err != nil {
   394  			t.Fatalf("Failed to revert db, err: %v", err)
   395  		}
   396  		tester.verifyState(parent)
   397  	}
   398  	if tester.db.tree.len() != 1 {
   399  		t.Fatal("Only disk layer is expected")
   400  	}
   401  }
   403  func TestDatabaseRecoverable(t *testing.T) {
   404  	var (
   405  		tester = newTester(t)
   406  		index  = tester.bottomIndex()
   407  	)
   408  	defer tester.release()
   410  	var cases = []struct {
   411  		root   common.Hash
   412  		expect bool
   413  	}{
   414  		// Unknown state should be unrecoverable
   415  		{common.Hash{0x1}, false},
   417  		// Initial state should be recoverable
   418  		{types.EmptyRootHash, true},
   420  		// Initial state should be recoverable
   421  		{common.Hash{}, true},
   423  		// Layers below current disk layer are recoverable
   424  		{tester.roots[index-1], true},
   426  		// Disklayer itself is not recoverable, since it's
   427  		// available for accessing.
   428  		{tester.roots[index], false},
   430  		// Layers above current disk layer are not recoverable
   431  		// since they are available for accessing.
   432  		{tester.roots[index+1], false},
   433  	}
   434  	for i, c := range cases {
   435  		result := tester.db.Recoverable(c.root)
   436  		if result != c.expect {
   437  			t.Fatalf("case: %d, unexpected result, want %t, got %t", i, c.expect, result)
   438  		}
   439  	}
   440  }
   442  func TestReset(t *testing.T) {
   443  	var (
   444  		tester = newTester(t)
   445  		index  = tester.bottomIndex()
   446  	)
   447  	defer tester.release()
   449  	// Reset database to unknown target, should reject it
   450  	if err := tester.db.Reset(testutil.RandomHash()); err == nil {
   451  		t.Fatal("Failed to reject invalid reset")
   452  	}
   453  	// Reset database to state persisted in the disk
   454  	if err := tester.db.Reset(types.EmptyRootHash); err != nil {
   455  		t.Fatalf("Failed to reset database %v", err)
   456  	}
   457  	// Ensure journal is deleted from disk
   458  	if blob := rawdb.ReadTrieJournal(tester.db.diskdb); len(blob) != 0 {
   459  		t.Fatal("Failed to clean journal")
   460  	}
   461  	// Ensure all trie histories are removed
   462  	for i := 0; i <= index; i++ {
   463  		_, err := readHistory(tester.db.freezer, uint64(i+1))
   464  		if err == nil {
   465  			t.Fatalf("Failed to clean state history, index %d", i+1)
   466  		}
   467  	}
   468  	// Verify layer tree structure, single disk layer is expected
   469  	if tester.db.tree.len() != 1 {
   470  		t.Fatalf("Extra layer kept %d", tester.db.tree.len())
   471  	}
   472  	if tester.db.tree.bottom().rootHash() != types.EmptyRootHash {
   473  		t.Fatalf("Root hash is not matched exp %x got %x", types.EmptyRootHash, tester.db.tree.bottom().rootHash())
   474  	}
   475  }
   477  func TestCommit(t *testing.T) {
   478  	tester := newTester(t)
   479  	defer tester.release()
   481  	if err := tester.db.Commit(tester.lastHash(), false); err != nil {
   482  		t.Fatalf("Failed to cap database, err: %v", err)
   483  	}
   484  	// Verify layer tree structure, single disk layer is expected
   485  	if tester.db.tree.len() != 1 {
   486  		t.Fatal("Layer tree structure is invalid")
   487  	}
   488  	if tester.db.tree.bottom().rootHash() != tester.lastHash() {
   489  		t.Fatal("Layer tree structure is invalid")
   490  	}
   491  	// Verify states
   492  	if err := tester.verifyState(tester.lastHash()); err != nil {
   493  		t.Fatalf("State is invalid, err: %v", err)
   494  	}
   495  	// Verify state histories
   496  	if err := tester.verifyHistory(); err != nil {
   497  		t.Fatalf("State history is invalid, err: %v", err)
   498  	}
   499  }
   501  func TestJournal(t *testing.T) {
   502  	tester := newTester(t)
   503  	defer tester.release()
   505  	if err := tester.db.Journal(tester.lastHash()); err != nil {
   506  		t.Errorf("Failed to journal, err: %v", err)
   507  	}
   508  	tester.db.Close()
   509  	tester.db = New(tester.db.diskdb, nil)
   511  	// Verify states including disk layer and all diff on top.
   512  	for i := 0; i < len(tester.roots); i++ {
   513  		if i >= tester.bottomIndex() {
   514  			if err := tester.verifyState(tester.roots[i]); err != nil {
   515  				t.Fatalf("Invalid state, err: %v", err)
   516  			}
   517  			continue
   518  		}
   519  		if err := tester.verifyState(tester.roots[i]); err == nil {
   520  			t.Fatal("Unexpected state")
   521  		}
   522  	}
   523  }
   525  func TestCorruptedJournal(t *testing.T) {
   526  	tester := newTester(t)
   527  	defer tester.release()
   529  	if err := tester.db.Journal(tester.lastHash()); err != nil {
   530  		t.Errorf("Failed to journal, err: %v", err)
   531  	}
   532  	tester.db.Close()
   533  	_, root := rawdb.ReadAccountTrieNode(tester.db.diskdb, nil)
   535  	// Mutate the journal in disk, it should be regarded as invalid
   536  	blob := rawdb.ReadTrieJournal(tester.db.diskdb)
   537  	blob[0] = 1
   538  	rawdb.WriteTrieJournal(tester.db.diskdb, blob)
   540  	// Verify states, all not-yet-written states should be discarded
   541  	tester.db = New(tester.db.diskdb, nil)
   542  	for i := 0; i < len(tester.roots); i++ {
   543  		if tester.roots[i] == root {
   544  			if err := tester.verifyState(root); err != nil {
   545  				t.Fatalf("Disk state is corrupted, err: %v", err)
   546  			}
   547  			continue
   548  		}
   549  		if err := tester.verifyState(tester.roots[i]); err == nil {
   550  			t.Fatal("Unexpected state")
   551  		}
   552  	}
   553  }
   555  // copyAccounts returns a deep-copied account set of the provided one.
   556  func copyAccounts(set map[common.Hash][]byte) map[common.Hash][]byte {
   557  	copied := make(map[common.Hash][]byte, len(set))
   558  	for key, val := range set {
   559  		copied[key] = common.CopyBytes(val)
   560  	}
   561  	return copied
   562  }
   564  // copyStorages returns a deep-copied storage set of the provided one.
   565  func copyStorages(set map[common.Hash]map[common.Hash][]byte) map[common.Hash]map[common.Hash][]byte {
   566  	copied := make(map[common.Hash]map[common.Hash][]byte, len(set))
   567  	for addrHash, subset := range set {
   568  		copied[addrHash] = make(map[common.Hash][]byte, len(subset))
   569  		for key, val := range subset {
   570  			copied[addrHash][key] = common.CopyBytes(val)
   571  		}
   572  	}
   573  	return copied
   574  }