github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/core/mpt/trie.go (about)

     1  package mpt
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  
     9  	"github.com/nspcc-dev/neo-go/pkg/core/storage"
    10  	"github.com/nspcc-dev/neo-go/pkg/io"
    11  	"github.com/nspcc-dev/neo-go/pkg/util"
    12  )
    13  
    14  // TrieMode is the storage mode of a trie, it affects the DB scheme.
    15  type TrieMode byte
    16  
    17  // TrieMode is the storage mode of a trie.
    18  const (
    19  	// ModeAll is used to store everything.
    20  	ModeAll TrieMode = 0
    21  	// ModeLatest is used to only store the latest root.
    22  	ModeLatest TrieMode = 0x01
    23  	// ModeGCFlag is a flag for GC.
    24  	ModeGCFlag TrieMode = 0x02
    25  	// ModeGC is used to store a set of roots with GC possible, it combines
    26  	// GCFlag and Latest (because it needs RC, but it has GC enabled).
    27  	ModeGC TrieMode = 0x03
    28  )
    29  
    30  // Trie is an MPT trie storing all key-value pairs.
    31  type Trie struct {
    32  	Store *storage.MemCachedStore
    33  
    34  	root     Node
    35  	mode     TrieMode
    36  	refcount map[util.Uint256]*cachedNode
    37  }
    38  
    39  type cachedNode struct {
    40  	bytes    []byte
    41  	initial  int32
    42  	refcount int32
    43  }
    44  
    45  // ErrNotFound is returned when the requested trie item is missing.
    46  var ErrNotFound = errors.New("item not found")
    47  
    48  // RC returns true when reference counting is enabled.
    49  func (m TrieMode) RC() bool {
    50  	return m&ModeLatest != 0
    51  }
    52  
    53  // GC returns true when garbage collection is enabled.
    54  func (m TrieMode) GC() bool {
    55  	return m&ModeGCFlag != 0
    56  }
    57  
    58  // NewTrie returns a new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors,
    59  // so that all storage errors are processed during `store.Persist()` at the caller.
    60  // Another benefit is that every `Put` can be considered an atomic operation.
    61  func NewTrie(root Node, mode TrieMode, store *storage.MemCachedStore) *Trie {
    62  	if root == nil {
    63  		root = EmptyNode{}
    64  	}
    65  
    66  	return &Trie{
    67  		Store: store,
    68  		root:  root,
    69  
    70  		mode:     mode,
    71  		refcount: make(map[util.Uint256]*cachedNode),
    72  	}
    73  }
    74  
    75  // Get returns the value for the provided key in t.
    76  func (t *Trie) Get(key []byte) ([]byte, error) {
    77  	if len(key) > MaxKeyLength {
    78  		return nil, errors.New("key is too big")
    79  	}
    80  	path := toNibbles(key)
    81  	r, leaf, _, err := t.getWithPath(t.root, path, true)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	t.root = r
    86  	return bytes.Clone(leaf.(*LeafNode).value), nil
    87  }
    88  
    89  // getWithPath returns the current node with all hash nodes along the path replaced
    90  // with their "unhashed" counterparts. It also returns node which the provided path in a
    91  // subtrie rooting in curr points to. In case of `strict` set to `false`, the
    92  // provided path can be incomplete, so it also returns the full path that points to
    93  // the node found at the specified incomplete path. In case of `strict` set to `true`,
    94  // the resulting path matches the provided one.
    95  func (t *Trie) getWithPath(curr Node, path []byte, strict bool) (Node, Node, []byte, error) {
    96  	switch n := curr.(type) {
    97  	case *LeafNode:
    98  		if len(path) == 0 {
    99  			return curr, n, []byte{}, nil
   100  		}
   101  	case *BranchNode:
   102  		i, path := splitPath(path)
   103  		if i == lastChild && !strict {
   104  			return curr, n, []byte{}, nil
   105  		}
   106  		r, res, prefix, err := t.getWithPath(n.Children[i], path, strict)
   107  		if err != nil {
   108  			return nil, nil, nil, err
   109  		}
   110  		n.Children[i] = r
   111  		return n, res, append([]byte{i}, prefix...), nil
   112  	case EmptyNode:
   113  	case *HashNode:
   114  		if r, err := t.getFromStore(n.hash); err == nil {
   115  			return t.getWithPath(r, path, strict)
   116  		}
   117  	case *ExtensionNode:
   118  		if len(path) == 0 && !strict {
   119  			return curr, n.next, n.key, nil
   120  		}
   121  		if bytes.HasPrefix(path, n.key) {
   122  			r, res, prefix, err := t.getWithPath(n.next, path[len(n.key):], strict)
   123  			if err != nil {
   124  				return nil, nil, nil, err
   125  			}
   126  			n.next = r
   127  			return curr, res, append(n.key, prefix...), err
   128  		}
   129  		if !strict && bytes.HasPrefix(n.key, path) {
   130  			// path is shorter than prefix, stop seeking
   131  			return curr, n.next, n.key, nil
   132  		}
   133  	default:
   134  		panic("invalid MPT node type")
   135  	}
   136  	return curr, nil, nil, ErrNotFound
   137  }
   138  
   139  // Put puts key-value pair in t.
   140  func (t *Trie) Put(key, value []byte) error {
   141  	if len(key) == 0 {
   142  		return errors.New("key is empty")
   143  	} else if len(key) > MaxKeyLength {
   144  		return errors.New("key is too big")
   145  	} else if len(value) > MaxValueLength {
   146  		return errors.New("value is too big")
   147  	} else if value == nil {
   148  		// (t *Trie).Delete should be used to remove value
   149  		return errors.New("value is nil")
   150  	}
   151  	path := toNibbles(key)
   152  	n := NewLeafNode(value)
   153  	r, err := t.putIntoNode(t.root, path, n)
   154  	if err != nil {
   155  		return err
   156  	}
   157  	t.root = r
   158  	return nil
   159  }
   160  
   161  // putIntoLeaf puts the val to the trie if the current node is a Leaf.
   162  // It returns a Node if curr needs to be replaced and an error has occurred, if any.
   163  func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) {
   164  	v := val.(*LeafNode)
   165  	if len(path) == 0 {
   166  		t.removeRef(curr.Hash(), curr.bytes)
   167  		t.addRef(val.Hash(), val.Bytes())
   168  		return v, nil
   169  	}
   170  
   171  	b := NewBranchNode()
   172  	b.Children[path[0]] = t.newSubTrie(path[1:], v, true)
   173  	b.Children[lastChild] = curr
   174  	t.addRef(b.Hash(), b.bytes)
   175  	return b, nil
   176  }
   177  
   178  // putIntoBranch puts the val to the trie if the current node is a Branch.
   179  // It returns the Node if curr needs to be replaced and an error has occurred, if any.
   180  func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) {
   181  	i, path := splitPath(path)
   182  	t.removeRef(curr.Hash(), curr.bytes)
   183  	r, err := t.putIntoNode(curr.Children[i], path, val)
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  	curr.Children[i] = r
   188  	curr.invalidateCache()
   189  	t.addRef(curr.Hash(), curr.bytes)
   190  	return curr, nil
   191  }
   192  
   193  // putIntoExtension puts the val to the trie if the current node is an Extension.
   194  // It returns the Node if curr needs to be replaced and an error has occurred, if any.
   195  func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) {
   196  	t.removeRef(curr.Hash(), curr.bytes)
   197  	if bytes.HasPrefix(path, curr.key) {
   198  		r, err := t.putIntoNode(curr.next, path[len(curr.key):], val)
   199  		if err != nil {
   200  			return nil, err
   201  		}
   202  		curr.next = r
   203  		curr.invalidateCache()
   204  		t.addRef(curr.Hash(), curr.bytes)
   205  		return curr, nil
   206  	}
   207  
   208  	pref := lcp(curr.key, path)
   209  	lp := len(pref)
   210  	keyTail := curr.key[lp:]
   211  	pathTail := path[lp:]
   212  
   213  	s1 := t.newSubTrie(keyTail[1:], curr.next, false)
   214  	b := NewBranchNode()
   215  	b.Children[keyTail[0]] = s1
   216  
   217  	i, pathTail := splitPath(pathTail)
   218  	s2 := t.newSubTrie(pathTail, val, true)
   219  	b.Children[i] = s2
   220  
   221  	t.addRef(b.Hash(), b.bytes)
   222  	if lp > 0 {
   223  		e := NewExtensionNode(bytes.Clone(pref), b)
   224  		t.addRef(e.Hash(), e.bytes)
   225  		return e, nil
   226  	}
   227  	return b, nil
   228  }
   229  
   230  func (t *Trie) putIntoEmpty(path []byte, val Node) (Node, error) {
   231  	return t.newSubTrie(path, val, true), nil
   232  }
   233  
   234  // putIntoHash puts the val to the trie if the current node is a HashNode.
   235  // It returns the Node if curr needs to be replaced and an error has occurred, if any.
   236  func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) {
   237  	result, err := t.getFromStore(curr.hash)
   238  	if err != nil {
   239  		return nil, err
   240  	}
   241  	return t.putIntoNode(result, path, val)
   242  }
   243  
   244  // newSubTrie creates a new trie containing the node at the provided path.
   245  func (t *Trie) newSubTrie(path []byte, val Node, newVal bool) Node {
   246  	if newVal {
   247  		t.addRef(val.Hash(), val.Bytes())
   248  	}
   249  	if len(path) == 0 {
   250  		return val
   251  	}
   252  	e := NewExtensionNode(path, val)
   253  	t.addRef(e.Hash(), e.bytes)
   254  	return e
   255  }
   256  
   257  // putIntoNode puts the val with the provided path inside curr and returns an updated node.
   258  // Reference counters are updated for both curr and returned value.
   259  func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) {
   260  	switch n := curr.(type) {
   261  	case *LeafNode:
   262  		return t.putIntoLeaf(n, path, val)
   263  	case *BranchNode:
   264  		return t.putIntoBranch(n, path, val)
   265  	case *ExtensionNode:
   266  		return t.putIntoExtension(n, path, val)
   267  	case *HashNode:
   268  		return t.putIntoHash(n, path, val)
   269  	case EmptyNode:
   270  		return t.putIntoEmpty(path, val)
   271  	default:
   272  		panic("invalid MPT node type")
   273  	}
   274  }
   275  
   276  // Delete removes the key from the trie.
   277  // It returns no error on a missing key.
   278  func (t *Trie) Delete(key []byte) error {
   279  	if len(key) > MaxKeyLength {
   280  		return errors.New("key is too big")
   281  	}
   282  	path := toNibbles(key)
   283  	r, err := t.deleteFromNode(t.root, path)
   284  	if err != nil {
   285  		return err
   286  	}
   287  	t.root = r
   288  	return nil
   289  }
   290  
   291  func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) {
   292  	i, path := splitPath(path)
   293  	h := b.Hash()
   294  	bs := b.bytes
   295  	r, err := t.deleteFromNode(b.Children[i], path)
   296  	if err != nil {
   297  		return nil, err
   298  	}
   299  	t.removeRef(h, bs)
   300  	b.Children[i] = r
   301  	b.invalidateCache()
   302  	var count, index int
   303  	for i := range b.Children {
   304  		if !isEmpty(b.Children[i]) {
   305  			index = i
   306  			count++
   307  		}
   308  	}
   309  	// count is >= 1 because branch node had at least 2 children before deletion.
   310  	if count > 1 {
   311  		t.addRef(b.Hash(), b.bytes)
   312  		return b, nil
   313  	}
   314  	c := b.Children[index]
   315  	if index == lastChild {
   316  		return c, nil
   317  	}
   318  	if h, ok := c.(*HashNode); ok {
   319  		c, err = t.getFromStore(h.Hash())
   320  		if err != nil {
   321  			return nil, err
   322  		}
   323  	}
   324  	if e, ok := c.(*ExtensionNode); ok {
   325  		t.removeRef(e.Hash(), e.bytes)
   326  		e.key = append([]byte{byte(index)}, e.key...)
   327  		e.invalidateCache()
   328  		t.addRef(e.Hash(), e.bytes)
   329  		return e, nil
   330  	}
   331  
   332  	e := NewExtensionNode([]byte{byte(index)}, c)
   333  	t.addRef(e.Hash(), e.bytes)
   334  	return e, nil
   335  }
   336  
   337  func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) {
   338  	if !bytes.HasPrefix(path, n.key) {
   339  		return n, nil
   340  	}
   341  	h := n.Hash()
   342  	bs := n.bytes
   343  	r, err := t.deleteFromNode(n.next, path[len(n.key):])
   344  	if err != nil {
   345  		return nil, err
   346  	}
   347  	t.removeRef(h, bs)
   348  	switch nxt := r.(type) {
   349  	case *ExtensionNode:
   350  		t.removeRef(nxt.Hash(), nxt.bytes)
   351  		n.key = append(n.key, nxt.key...)
   352  		n.next = nxt.next
   353  	case EmptyNode:
   354  		return nxt, nil
   355  	case *HashNode:
   356  		n.next = nxt
   357  	default:
   358  		n.next = r
   359  	}
   360  	n.invalidateCache()
   361  	t.addRef(n.Hash(), n.bytes)
   362  	return n, nil
   363  }
   364  
   365  // deleteFromNode removes the value with the provided path from curr and returns an updated node.
   366  // Reference counters are updated for both curr and returned value.
   367  func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) {
   368  	switch n := curr.(type) {
   369  	case *LeafNode:
   370  		if len(path) == 0 {
   371  			t.removeRef(curr.Hash(), curr.Bytes())
   372  			return EmptyNode{}, nil
   373  		}
   374  		return curr, nil
   375  	case *BranchNode:
   376  		return t.deleteFromBranch(n, path)
   377  	case *ExtensionNode:
   378  		return t.deleteFromExtension(n, path)
   379  	case EmptyNode:
   380  		return n, nil
   381  	case *HashNode:
   382  		newNode, err := t.getFromStore(n.Hash())
   383  		if err != nil {
   384  			return nil, err
   385  		}
   386  		return t.deleteFromNode(newNode, path)
   387  	default:
   388  		panic("invalid MPT node type")
   389  	}
   390  }
   391  
   392  // StateRoot returns root hash of t.
   393  func (t *Trie) StateRoot() util.Uint256 {
   394  	if isEmpty(t.root) {
   395  		return util.Uint256{}
   396  	}
   397  	return t.root.Hash()
   398  }
   399  
   400  func makeStorageKey(mptKey util.Uint256) []byte {
   401  	return append([]byte{byte(storage.DataMPT)}, mptKey[:]...)
   402  }
   403  
   404  // Flush puts every node (except Hash ones) in the trie to the storage.
   405  // Because we care about block-level changes only, there is no need to put every
   406  // new node to the storage. Normally, flush should be called with every StateRoot persist, i.e.
   407  // after every block.
   408  func (t *Trie) Flush(index uint32) {
   409  	key := makeStorageKey(util.Uint256{})
   410  	for h, node := range t.refcount {
   411  		if node.refcount != 0 {
   412  			copy(key[1:], h[:])
   413  			if node.bytes == nil {
   414  				panic("item not in trie")
   415  			}
   416  			if t.mode.RC() {
   417  				node.initial = t.updateRefCount(h, key, index)
   418  				if node.initial == 0 {
   419  					delete(t.refcount, h)
   420  				}
   421  			} else if node.refcount > 0 {
   422  				t.Store.Put(key, node.bytes)
   423  			}
   424  			node.refcount = 0
   425  		} else {
   426  			delete(t.refcount, h)
   427  		}
   428  	}
   429  }
   430  
   431  func IsActiveValue(v []byte) bool {
   432  	return len(v) > 4 && v[len(v)-5] == 1
   433  }
   434  
   435  func getFromStore(key []byte, mode TrieMode, store *storage.MemCachedStore) ([]byte, error) {
   436  	data, err := store.Get(key)
   437  	if err == nil && mode.GC() && !IsActiveValue(data) {
   438  		return nil, storage.ErrKeyNotFound
   439  	}
   440  	return data, err
   441  }
   442  
   443  // updateRefCount should be called only when refcounting is enabled.
   444  func (t *Trie) updateRefCount(h util.Uint256, key []byte, index uint32) int32 {
   445  	if !t.mode.RC() {
   446  		panic("`updateRefCount` is called, but GC is disabled")
   447  	}
   448  	var data []byte
   449  	node := t.refcount[h]
   450  	cnt := node.initial
   451  	if cnt == 0 {
   452  		// A newly created item which may be in store.
   453  		var err error
   454  		data, err = getFromStore(key, t.mode, t.Store)
   455  		if err == nil {
   456  			cnt = int32(binary.LittleEndian.Uint32(data[len(data)-4:]))
   457  		}
   458  	}
   459  	if len(data) == 0 {
   460  		data = append(node.bytes, 1, 0, 0, 0, 0)
   461  	}
   462  	cnt += node.refcount
   463  	switch {
   464  	case cnt < 0:
   465  		// BUG: negative reference count
   466  		panic(fmt.Sprintf("negative reference count: %s new %d, upd %d", h.StringBE(), cnt, t.refcount[h]))
   467  	case cnt == 0:
   468  		if !t.mode.GC() {
   469  			t.Store.Delete(key)
   470  		} else {
   471  			data[len(data)-5] = 0
   472  			binary.LittleEndian.PutUint32(data[len(data)-4:], index)
   473  			t.Store.Put(key, data)
   474  		}
   475  	default:
   476  		binary.LittleEndian.PutUint32(data[len(data)-4:], uint32(cnt))
   477  		t.Store.Put(key, data)
   478  	}
   479  	return cnt
   480  }
   481  
   482  func (t *Trie) addRef(h util.Uint256, bs []byte) {
   483  	node := t.refcount[h]
   484  	if node == nil {
   485  		t.refcount[h] = &cachedNode{
   486  			refcount: 1,
   487  			bytes:    bs,
   488  		}
   489  		return
   490  	}
   491  	node.refcount++
   492  	if node.bytes == nil {
   493  		node.bytes = bs
   494  	}
   495  }
   496  
   497  func (t *Trie) removeRef(h util.Uint256, bs []byte) {
   498  	node := t.refcount[h]
   499  	if node == nil {
   500  		t.refcount[h] = &cachedNode{
   501  			refcount: -1,
   502  			bytes:    bs,
   503  		}
   504  		return
   505  	}
   506  	node.refcount--
   507  	if node.bytes == nil {
   508  		node.bytes = bs
   509  	}
   510  }
   511  
   512  func (t *Trie) getFromStore(h util.Uint256) (Node, error) {
   513  	data, err := getFromStore(makeStorageKey(h), t.mode, t.Store)
   514  	if err != nil {
   515  		return nil, err
   516  	}
   517  
   518  	var n NodeObject
   519  	r := io.NewBinReaderFromBuf(data)
   520  	n.DecodeBinary(r)
   521  	if r.Err != nil {
   522  		return nil, r.Err
   523  	}
   524  
   525  	if t.mode.RC() {
   526  		data = data[:len(data)-5]
   527  		node := t.refcount[h]
   528  		if node != nil {
   529  			node.bytes = data
   530  			_ = r.ReadB()
   531  			node.initial = int32(r.ReadU32LE())
   532  		}
   533  	}
   534  	n.Node.(flushedNode).setCache(data, h)
   535  	return n.Node, nil
   536  }
   537  
   538  // Collapse compresses all nodes at depth n to the hash nodes.
   539  // Note: this function does not perform any kind of storage flushing so
   540  // `Flush()` should be called explicitly before invoking function.
   541  func (t *Trie) Collapse(depth int) {
   542  	if depth < 0 {
   543  		panic("negative depth")
   544  	}
   545  	t.root = collapse(depth, t.root)
   546  	t.refcount = make(map[util.Uint256]*cachedNode)
   547  }
   548  
   549  func collapse(depth int, node Node) Node {
   550  	switch node.(type) {
   551  	case *HashNode, EmptyNode:
   552  		return node
   553  	}
   554  	if depth == 0 {
   555  		return NewHashNode(node.Hash())
   556  	}
   557  
   558  	switch n := node.(type) {
   559  	case *BranchNode:
   560  		for i := range n.Children {
   561  			n.Children[i] = collapse(depth-1, n.Children[i])
   562  		}
   563  	case *ExtensionNode:
   564  		n.next = collapse(depth-1, n.next)
   565  	case *LeafNode:
   566  	case *HashNode:
   567  	default:
   568  		panic("invalid MPT node type")
   569  	}
   570  	return node
   571  }
   572  
   573  // Find returns a list of storage key-value pairs whose key is prefixed by the specified
   574  // prefix starting from the specified `prefix`+`from` path (not including the item at
   575  // the specified `prefix`+`from` path if so). The `max` number of elements is returned at max.
   576  func (t *Trie) Find(prefix, from []byte, max int) ([]storage.KeyValue, error) {
   577  	if len(prefix) > MaxKeyLength {
   578  		return nil, errors.New("invalid prefix length")
   579  	}
   580  	if len(from) > MaxKeyLength-len(prefix) {
   581  		return nil, errors.New("invalid from length")
   582  	}
   583  	prefixP := toNibbles(prefix)
   584  	fromP := []byte{}
   585  	if len(from) > 0 {
   586  		fromP = toNibbles(from)
   587  	}
   588  	_, start, path, err := t.getWithPath(t.root, prefixP, false)
   589  	if err != nil {
   590  		return nil, fmt.Errorf("failed to determine the start node: %w", err)
   591  	}
   592  	path = path[len(prefixP):]
   593  
   594  	if len(fromP) > 0 {
   595  		if len(path) <= len(fromP) && bytes.HasPrefix(fromP, path) {
   596  			fromP = fromP[len(path):]
   597  		} else if len(path) > len(fromP) && bytes.HasPrefix(path, fromP) {
   598  			fromP = []byte{}
   599  		} else {
   600  			cmp := bytes.Compare(path, fromP)
   601  			switch {
   602  			case cmp < 0:
   603  				return []storage.KeyValue{}, nil
   604  			case cmp > 0:
   605  				fromP = []byte{}
   606  			}
   607  		}
   608  	}
   609  
   610  	var (
   611  		res   []storage.KeyValue
   612  		count int
   613  	)
   614  	b := NewBillet(t.root.Hash(), t.mode, 0, t.Store)
   615  	process := func(pathToNode []byte, node Node, _ []byte) bool {
   616  		if leaf, ok := node.(*LeafNode); ok {
   617  			if from == nil || !bytes.Equal(pathToNode, from) { // (*Billet).traverse includes `from` path into result if so. Need to filter out manually.
   618  				res = append(res, storage.KeyValue{
   619  					Key:   append(bytes.Clone(prefix), pathToNode...),
   620  					Value: bytes.Clone(leaf.value),
   621  				})
   622  				count++
   623  			}
   624  		}
   625  		return count >= max
   626  	}
   627  	_, err = b.traverse(start, path, fromP, process, false, false)
   628  	if err != nil && !errors.Is(err, errStop) {
   629  		return nil, err
   630  	}
   631  	return res, nil
   632  }