github.1485827954.workers.dev/ethereum/go-ethereum@v1.14.3/eth/protocols/snap/gentrie_test.go (about)

     1  // Copyright 2024 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package snap
    18  
    19  import (
    20  	"bytes"
    21  	"math/rand"
    22  	"slices"
    23  	"testing"
    24  
    25  	"github.com/ethereum/go-ethereum/common"
    26  	"github.com/ethereum/go-ethereum/core/rawdb"
    27  	"github.com/ethereum/go-ethereum/crypto"
    28  	"github.com/ethereum/go-ethereum/ethdb"
    29  	"github.com/ethereum/go-ethereum/internal/testrand"
    30  	"github.com/ethereum/go-ethereum/trie"
    31  )
    32  
// replayer records the trie node writes handed to it through Put and Delete.
// Each recognised write appends one element to paths and one to hashes, so
// the two slices stay index-aligned. Writes whose key is neither an account
// nor a storage trie node key are only counted in unknowns.
type replayer struct {
	paths    []string      // node paths in the order the writes arrived (FIFO)
	hashes   []common.Hash // keccak256 of the written value; zero hash for deletion
	unknowns int           // counter for writes with an unrecognised key
}
    38  
    39  func newBatchReplay() *replayer {
    40  	return &replayer{}
    41  }
    42  
    43  func (r *replayer) decode(key []byte, value []byte) {
    44  	account := rawdb.IsAccountTrieNode(key)
    45  	storage := rawdb.IsStorageTrieNode(key)
    46  	if !account && !storage {
    47  		r.unknowns += 1
    48  		return
    49  	}
    50  	var path []byte
    51  	if account {
    52  		_, path = rawdb.ResolveAccountTrieNodeKey(key)
    53  	} else {
    54  		_, owner, inner := rawdb.ResolveStorageTrieNode(key)
    55  		path = append(owner.Bytes(), inner...)
    56  	}
    57  	r.paths = append(r.paths, string(path))
    58  
    59  	if len(value) == 0 {
    60  		r.hashes = append(r.hashes, common.Hash{})
    61  	} else {
    62  		r.hashes = append(r.hashes, crypto.Keccak256Hash(value))
    63  	}
    64  }
    65  
    66  // updates returns a set of effective mutations. Multiple mutations targeting
    67  // the same node path will be merged in FIFO order.
    68  func (r *replayer) modifies() map[string]common.Hash {
    69  	set := make(map[string]common.Hash)
    70  	for i, path := range r.paths {
    71  		set[path] = r.hashes[i]
    72  	}
    73  	return set
    74  }
    75  
    76  // updates returns the number of updates.
    77  func (r *replayer) updates() int {
    78  	var count int
    79  	for _, hash := range r.modifies() {
    80  		if hash == (common.Hash{}) {
    81  			continue
    82  		}
    83  		count++
    84  	}
    85  	return count
    86  }
    87  
// Put records the write of value under key by passing both to decode. Nothing
// is stored in a database and the returned error is always nil.
func (r *replayer) Put(key []byte, value []byte) error {
	r.decode(key, value)
	return nil
}
    93  
// Delete records the removal of key by passing it to decode with a nil value,
// which decode tracks as a zero hash. The returned error is always nil.
func (r *replayer) Delete(key []byte) error {
	r.decode(key, nil)
	return nil
}
    99  
   100  func byteToHex(str []byte) []byte {
   101  	l := len(str) * 2
   102  	var nibbles = make([]byte, l)
   103  	for i, b := range str {
   104  		nibbles[i*2] = b / 16
   105  		nibbles[i*2+1] = b % 16
   106  	}
   107  	return nibbles
   108  }
   109  
   110  // innerNodes returns the internal nodes narrowed by two boundaries along with
   111  // the leftmost and rightmost sub-trie roots.
   112  func innerNodes(first, last []byte, includeLeft, includeRight bool, nodes map[string]common.Hash, t *testing.T) (map[string]common.Hash, []byte, []byte) {
   113  	var (
   114  		leftRoot  []byte
   115  		rightRoot []byte
   116  		firstHex  = byteToHex(first)
   117  		lastHex   = byteToHex(last)
   118  		inner     = make(map[string]common.Hash)
   119  	)
   120  	for path, hash := range nodes {
   121  		if hash == (common.Hash{}) {
   122  			t.Fatalf("Unexpected deletion, %v", []byte(path))
   123  		}
   124  		// Filter out the siblings on the left side or the left boundary nodes.
   125  		if !includeLeft && (bytes.Compare(firstHex, []byte(path)) > 0 || bytes.HasPrefix(firstHex, []byte(path))) {
   126  			continue
   127  		}
   128  		// Filter out the siblings on the right side or the right boundary nodes.
   129  		if !includeRight && (bytes.Compare(lastHex, []byte(path)) < 0 || bytes.HasPrefix(lastHex, []byte(path))) {
   130  			continue
   131  		}
   132  		inner[path] = hash
   133  
   134  		// Track the path of the leftmost sub trie root
   135  		if leftRoot == nil || bytes.Compare(leftRoot, []byte(path)) > 0 {
   136  			leftRoot = []byte(path)
   137  		}
   138  		// Track the path of the rightmost sub trie root
   139  		if rightRoot == nil ||
   140  			(bytes.Compare(rightRoot, []byte(path)) < 0) ||
   141  			(bytes.Compare(rightRoot, []byte(path)) > 0 && bytes.HasPrefix(rightRoot, []byte(path))) {
   142  			rightRoot = []byte(path)
   143  		}
   144  	}
   145  	return inner, leftRoot, rightRoot
   146  }
   147  
   148  func buildPartial(owner common.Hash, db ethdb.KeyValueReader, batch ethdb.Batch, entries []*kv, first, last int) *replayer {
   149  	tr := newPathTrie(owner, first != 0, db, batch)
   150  	for i := first; i <= last; i++ {
   151  		tr.update(entries[i].k, entries[i].v)
   152  	}
   153  	tr.commit(last == len(entries)-1)
   154  
   155  	replay := newBatchReplay()
   156  	batch.Replay(replay)
   157  
   158  	return replay
   159  }
   160  
   161  // TestPartialGentree verifies if the trie constructed with partial states can
   162  // generate consistent trie nodes that match those of the full trie.
   163  func TestPartialGentree(t *testing.T) {
   164  	for round := 0; round < 100; round++ {
   165  		var (
   166  			n       = rand.Intn(1024) + 10
   167  			entries []*kv
   168  		)
   169  		for i := 0; i < n; i++ {
   170  			var val []byte
   171  			if rand.Intn(3) == 0 {
   172  				val = testrand.Bytes(3)
   173  			} else {
   174  				val = testrand.Bytes(32)
   175  			}
   176  			entries = append(entries, &kv{
   177  				k: testrand.Bytes(32),
   178  				v: val,
   179  			})
   180  		}
   181  		slices.SortFunc(entries, (*kv).cmp)
   182  
   183  		nodes := make(map[string]common.Hash)
   184  		tr := trie.NewStackTrie(func(path []byte, hash common.Hash, blob []byte) {
   185  			nodes[string(path)] = hash
   186  		})
   187  		for i := 0; i < len(entries); i++ {
   188  			tr.Update(entries[i].k, entries[i].v)
   189  		}
   190  		tr.Hash()
   191  
   192  		check := func(first, last int) {
   193  			var (
   194  				db    = rawdb.NewMemoryDatabase()
   195  				batch = db.NewBatch()
   196  			)
   197  			// Build the partial tree with specific boundaries
   198  			r := buildPartial(common.Hash{}, db, batch, entries, first, last)
   199  			if r.unknowns > 0 {
   200  				t.Fatalf("Unknown database write: %d", r.unknowns)
   201  			}
   202  
   203  			// Ensure all the internal nodes are produced
   204  			var (
   205  				set         = r.modifies()
   206  				inner, _, _ = innerNodes(entries[first].k, entries[last].k, first == 0, last == len(entries)-1, nodes, t)
   207  			)
   208  			for path, hash := range inner {
   209  				if _, ok := set[path]; !ok {
   210  					t.Fatalf("Missing nodes %v", []byte(path))
   211  				}
   212  				if hash != set[path] {
   213  					t.Fatalf("Inconsistent node, want %x, got: %x", hash, set[path])
   214  				}
   215  			}
   216  			if r.updates() != len(inner) {
   217  				t.Fatalf("Unexpected node write detected, want: %d, got: %d", len(inner), r.updates())
   218  			}
   219  		}
   220  		for j := 0; j < 100; j++ {
   221  			var (
   222  				first int
   223  				last  int
   224  			)
   225  			for {
   226  				first = rand.Intn(len(entries))
   227  				last = rand.Intn(len(entries))
   228  				if first <= last {
   229  					break
   230  				}
   231  			}
   232  			check(first, last)
   233  		}
   234  		var cases = []struct {
   235  			first int
   236  			last  int
   237  		}{
   238  			{0, len(entries) - 1},                // full
   239  			{1, len(entries) - 1},                // no left
   240  			{2, len(entries) - 1},                // no left
   241  			{2, len(entries) - 2},                // no left and right
   242  			{2, len(entries) - 2},                // no left and right
   243  			{len(entries) / 2, len(entries) / 2}, // single
   244  			{0, 0},                               // single first
   245  			{len(entries) - 1, len(entries) - 1}, // single last
   246  		}
   247  		for _, c := range cases {
   248  			check(c.first, c.last)
   249  		}
   250  	}
   251  }
   252  
   253  // TestGentreeDanglingClearing tests if the dangling nodes falling within the
   254  // path space of constructed tree can be correctly removed.
   255  func TestGentreeDanglingClearing(t *testing.T) {
   256  	for round := 0; round < 100; round++ {
   257  		var (
   258  			n       = rand.Intn(1024) + 10
   259  			entries []*kv
   260  		)
   261  		for i := 0; i < n; i++ {
   262  			var val []byte
   263  			if rand.Intn(3) == 0 {
   264  				val = testrand.Bytes(3)
   265  			} else {
   266  				val = testrand.Bytes(32)
   267  			}
   268  			entries = append(entries, &kv{
   269  				k: testrand.Bytes(32),
   270  				v: val,
   271  			})
   272  		}
   273  		slices.SortFunc(entries, (*kv).cmp)
   274  
   275  		nodes := make(map[string]common.Hash)
   276  		tr := trie.NewStackTrie(func(path []byte, hash common.Hash, blob []byte) {
   277  			nodes[string(path)] = hash
   278  		})
   279  		for i := 0; i < len(entries); i++ {
   280  			tr.Update(entries[i].k, entries[i].v)
   281  		}
   282  		tr.Hash()
   283  
   284  		check := func(first, last int) {
   285  			var (
   286  				db    = rawdb.NewMemoryDatabase()
   287  				batch = db.NewBatch()
   288  			)
   289  			// Write the junk nodes as the dangling
   290  			var injects []string
   291  			for path := range nodes {
   292  				for i := 0; i < len(path); i++ {
   293  					_, ok := nodes[path[:i]]
   294  					if ok {
   295  						continue
   296  					}
   297  					injects = append(injects, path[:i])
   298  				}
   299  			}
   300  			if len(injects) == 0 {
   301  				return
   302  			}
   303  			for _, path := range injects {
   304  				rawdb.WriteAccountTrieNode(db, []byte(path), testrand.Bytes(32))
   305  			}
   306  
   307  			// Build the partial tree with specific range
   308  			replay := buildPartial(common.Hash{}, db, batch, entries, first, last)
   309  			if replay.unknowns > 0 {
   310  				t.Fatalf("Unknown database write: %d", replay.unknowns)
   311  			}
   312  			set := replay.modifies()
   313  
   314  			// Make sure the injected junks falling within the path space of
   315  			// committed trie nodes are correctly deleted.
   316  			_, leftRoot, rightRoot := innerNodes(entries[first].k, entries[last].k, first == 0, last == len(entries)-1, nodes, t)
   317  			for _, path := range injects {
   318  				if bytes.Compare([]byte(path), leftRoot) < 0 && !bytes.HasPrefix(leftRoot, []byte(path)) {
   319  					continue
   320  				}
   321  				if bytes.Compare([]byte(path), rightRoot) > 0 {
   322  					continue
   323  				}
   324  				if hash, ok := set[path]; !ok || hash != (common.Hash{}) {
   325  					t.Fatalf("Missing delete, %v", []byte(path))
   326  				}
   327  			}
   328  		}
   329  		for j := 0; j < 100; j++ {
   330  			var (
   331  				first int
   332  				last  int
   333  			)
   334  			for {
   335  				first = rand.Intn(len(entries))
   336  				last = rand.Intn(len(entries))
   337  				if first <= last {
   338  					break
   339  				}
   340  			}
   341  			check(first, last)
   342  		}
   343  		var cases = []struct {
   344  			first int
   345  			last  int
   346  		}{
   347  			{0, len(entries) - 1},                // full
   348  			{1, len(entries) - 1},                // no left
   349  			{2, len(entries) - 1},                // no left
   350  			{2, len(entries) - 2},                // no left and right
   351  			{2, len(entries) - 2},                // no left and right
   352  			{len(entries) / 2, len(entries) / 2}, // single
   353  			{0, 0},                               // single first
   354  			{len(entries) - 1, len(entries) - 1}, // single last
   355  		}
   356  		for _, c := range cases {
   357  			check(c.first, c.last)
   358  		}
   359  	}
   360  }
   361  
   362  // TestFlushPartialTree tests the gentrie can produce complete inner trie nodes
   363  // even with lots of batch flushes.
   364  func TestFlushPartialTree(t *testing.T) {
   365  	var entries []*kv
   366  	for i := 0; i < 1024; i++ {
   367  		var val []byte
   368  		if rand.Intn(3) == 0 {
   369  			val = testrand.Bytes(3)
   370  		} else {
   371  			val = testrand.Bytes(32)
   372  		}
   373  		entries = append(entries, &kv{
   374  			k: testrand.Bytes(32),
   375  			v: val,
   376  		})
   377  	}
   378  	slices.SortFunc(entries, (*kv).cmp)
   379  
   380  	nodes := make(map[string]common.Hash)
   381  	tr := trie.NewStackTrie(func(path []byte, hash common.Hash, blob []byte) {
   382  		nodes[string(path)] = hash
   383  	})
   384  	for i := 0; i < len(entries); i++ {
   385  		tr.Update(entries[i].k, entries[i].v)
   386  	}
   387  	tr.Hash()
   388  
   389  	var cases = []struct {
   390  		first int
   391  		last  int
   392  	}{
   393  		{0, len(entries) - 1},                // full
   394  		{1, len(entries) - 1},                // no left
   395  		{10, len(entries) - 1},               // no left
   396  		{10, len(entries) - 2},               // no left and right
   397  		{10, len(entries) - 10},              // no left and right
   398  		{11, 11},                             // single
   399  		{0, 0},                               // single first
   400  		{len(entries) - 1, len(entries) - 1}, // single last
   401  	}
   402  	for _, c := range cases {
   403  		var (
   404  			db       = rawdb.NewMemoryDatabase()
   405  			batch    = db.NewBatch()
   406  			combined = db.NewBatch()
   407  		)
   408  		inner, _, _ := innerNodes(entries[c.first].k, entries[c.last].k, c.first == 0, c.last == len(entries)-1, nodes, t)
   409  
   410  		tr := newPathTrie(common.Hash{}, c.first != 0, db, batch)
   411  		for i := c.first; i <= c.last; i++ {
   412  			tr.update(entries[i].k, entries[i].v)
   413  			if rand.Intn(2) == 0 {
   414  				tr.commit(false)
   415  
   416  				batch.Replay(combined)
   417  				batch.Write()
   418  				batch.Reset()
   419  			}
   420  		}
   421  		tr.commit(c.last == len(entries)-1)
   422  
   423  		batch.Replay(combined)
   424  		batch.Write()
   425  		batch.Reset()
   426  
   427  		r := newBatchReplay()
   428  		combined.Replay(r)
   429  
   430  		// Ensure all the internal nodes are produced
   431  		set := r.modifies()
   432  		for path, hash := range inner {
   433  			if _, ok := set[path]; !ok {
   434  				t.Fatalf("Missing nodes %v", []byte(path))
   435  			}
   436  			if hash != set[path] {
   437  				t.Fatalf("Inconsistent node, want %x, got: %x", hash, set[path])
   438  			}
   439  		}
   440  		if r.updates() != len(inner) {
   441  			t.Fatalf("Unexpected node write detected, want: %d, got: %d", len(inner), r.updates())
   442  		}
   443  	}
   444  }
   445  
   446  // TestBoundSplit ensures two consecutive trie chunks are not overlapped with
   447  // each other.
   448  func TestBoundSplit(t *testing.T) {
   449  	var entries []*kv
   450  	for i := 0; i < 1024; i++ {
   451  		var val []byte
   452  		if rand.Intn(3) == 0 {
   453  			val = testrand.Bytes(3)
   454  		} else {
   455  			val = testrand.Bytes(32)
   456  		}
   457  		entries = append(entries, &kv{
   458  			k: testrand.Bytes(32),
   459  			v: val,
   460  		})
   461  	}
   462  	slices.SortFunc(entries, (*kv).cmp)
   463  
   464  	for j := 0; j < 100; j++ {
   465  		var (
   466  			next int
   467  			last int
   468  			db   = rawdb.NewMemoryDatabase()
   469  
   470  			lastRightRoot []byte
   471  		)
   472  		for {
   473  			if next == len(entries) {
   474  				break
   475  			}
   476  			last = rand.Intn(len(entries)-next) + next
   477  
   478  			r := buildPartial(common.Hash{}, db, db.NewBatch(), entries, next, last)
   479  			set := r.modifies()
   480  
   481  			// Skip if the chunk is zero-size
   482  			if r.updates() == 0 {
   483  				next = last + 1
   484  				continue
   485  			}
   486  
   487  			// Ensure the updates in two consecutive chunks are not overlapped.
   488  			// The only overlapping part should be deletion.
   489  			if lastRightRoot != nil && len(set) > 0 {
   490  				// Derive the path of left-most node in this chunk
   491  				var leftRoot []byte
   492  				for path, hash := range r.modifies() {
   493  					if hash == (common.Hash{}) {
   494  						t.Fatalf("Unexpected deletion %v", []byte(path))
   495  					}
   496  					if leftRoot == nil || bytes.Compare(leftRoot, []byte(path)) > 0 {
   497  						leftRoot = []byte(path)
   498  					}
   499  				}
   500  				if bytes.HasPrefix(lastRightRoot, leftRoot) || bytes.HasPrefix(leftRoot, lastRightRoot) {
   501  					t.Fatalf("Two chunks are not correctly separated, lastRight: %v, left: %v", lastRightRoot, leftRoot)
   502  				}
   503  			}
   504  
   505  			// Track the updates as the last chunk
   506  			var rightRoot []byte
   507  			for path := range set {
   508  				if rightRoot == nil ||
   509  					(bytes.Compare(rightRoot, []byte(path)) < 0) ||
   510  					(bytes.Compare(rightRoot, []byte(path)) > 0 && bytes.HasPrefix(rightRoot, []byte(path))) {
   511  					rightRoot = []byte(path)
   512  				}
   513  			}
   514  			lastRightRoot = rightRoot
   515  			next = last + 1
   516  		}
   517  	}
   518  }
   519  
   520  // TestTinyPartialTree tests if the partial tree is too tiny(has less than two
   521  // states), then nothing should be committed.
   522  func TestTinyPartialTree(t *testing.T) {
   523  	var entries []*kv
   524  	for i := 0; i < 1024; i++ {
   525  		var val []byte
   526  		if rand.Intn(3) == 0 {
   527  			val = testrand.Bytes(3)
   528  		} else {
   529  			val = testrand.Bytes(32)
   530  		}
   531  		entries = append(entries, &kv{
   532  			k: testrand.Bytes(32),
   533  			v: val,
   534  		})
   535  	}
   536  	slices.SortFunc(entries, (*kv).cmp)
   537  
   538  	for i := 0; i < len(entries); i++ {
   539  		next := i
   540  		last := i + 1
   541  		if last >= len(entries) {
   542  			last = len(entries) - 1
   543  		}
   544  		db := rawdb.NewMemoryDatabase()
   545  		r := buildPartial(common.Hash{}, db, db.NewBatch(), entries, next, last)
   546  
   547  		if next != 0 && last != len(entries)-1 {
   548  			if r.updates() != 0 {
   549  				t.Fatalf("Unexpected data writes, got: %d", r.updates())
   550  			}
   551  		}
   552  	}
   553  }