github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/storage/merkle/tree.go (about)

     1  package merkle
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  
     7  	"golang.org/x/crypto/blake2b"
     8  
     9  	"github.com/onflow/flow-go/ledger/common/bitutils"
    10  )
    11  
    12  // maxKeyLength in bytes:
    13  // For any key, we need to ensure that the entire path can be stored in a short node.
    14  // A short node stores the _number of bits_ for the path segment it represents in 2 bytes.
    15  //
    16  // Hence, the theoretically possible value range is [0,65535]. However, a short node with
    17  // zero path length is not part of our storage model. Furthermore, we always represent
    18  // keys as _byte_ slices, i.e. their number of bits must be an integer-multiple of 8.
    19  // Therefore, the range of valid key length in bytes is [1, 8191] (the corresponding
    20  // range in bits is [8, 65528]) .
    21  const maxKeyLength = 8191
    22  
    23  const maxKeyLenBits = maxKeyLength * 8
    24  
    25  var EmptyTreeRootHash []byte
    26  
    27  func init() {
    28  	h, _ := blake2b.New256([]byte{})
    29  	EmptyTreeRootHash = h.Sum(nil)
    30  }
    31  
// Tree represents a binary patricia merkle tree. The difference with a normal
// merkle tree is that it compresses paths that lead to a single leaf into a
// single intermediary node, which makes it significantly more space-efficient
// and a lot harder to exploit for denial-of-service attacks. On the downside,
// it makes insertions and deletions more complex, as we need to split nodes
// and merge them, depending on whether there are leaves or not.
//
// CONVENTION:
//   - If the tree contains _any_ elements, the tree is defined by its root vertex.
//     This case follows completely the convention for nodes: "In any existing tree,
//     all nodes are non-nil."
//   - Without any stored elements, there exists no root vertex in this data model,
//     and we set `root` to nil.
type Tree struct {
	// keyLength is the length in bytes that every key of this tree must have;
	// it is fixed at construction and checked by Put, Get, Prove and Del.
	keyLength int
	// root is the root vertex of the tree; it is nil exactly when the tree
	// holds no elements (see CONVENTION above).
	root      node
	// setting this flag would prevent more writes to the trie
	// but makes it more efficient for proof generation
	readOnlyEnabled bool
}
    52  
    53  // NewTree creates a new empty patricia merkle tree, with keys of the given
    54  // `keyLength` (length measured in bytes).
    55  // The current implementation only works with 1 ≤ keyLength ≤ 8191. Otherwise,
    56  // the sentinel error `ErrorIncompatibleKeyLength` is returned.
    57  func NewTree(keyLength int) (*Tree, error) {
    58  	if keyLength < 1 || maxKeyLength < keyLength {
    59  		return nil, fmt.Errorf("key length %d is outside of supported interval [1, %d]: %w", keyLength, maxKeyLength, ErrorIncompatibleKeyLength)
    60  	}
    61  	return &Tree{
    62  		keyLength: keyLength,
    63  		root:      nil,
    64  	}, nil
    65  }
    66  
// MakeItReadOnly switches the tree into read-only mode. This operation is not
// reversible: the flag is only ever set, never cleared.
// Once the tree is read-only, Put and Del reject any further modification with
// an error. The flag is also forwarded to the nodes' Hash method (by Hash and
// Prove), which allows hash values to be cached for faster operations.
func (t *Tree) MakeItReadOnly() {
	t.readOnlyEnabled = true
}
    73  
    74  // ComputeMaxDepth returns the maximum depth of the tree by traversing all paths
    75  //
    76  // Warning: this could be a very expensive operation for large trees, as nodes
    77  // don't cache the depth of children and have to compute by traversing.
    78  func (t *Tree) ComputeMaxDepth() uint {
    79  	return t.root.MaxDepthOfDescendants()
    80  }
    81  
    82  // Put stores the given value in the trie under the given key. If the key
    83  // already exists, it will replace the value and return true. All inputs
    84  // are internally stored and copied where necessary, thereby allowing
    85  // external code to re-use the slices.
    86  // Returns:
    87  //   - (false, nil): key-value pair is stored; key did _not_ yet exist prior to update
    88  //   - (true, nil):  key-value pair is stored; key existed prior to update and the old
    89  //     value was overwritten
    90  //   - (false, error): with possible error returns
    91  //   - ErrorIncompatibleKeyLength if `key` has different length than the pre-configured value
    92  //     No other errors are returned.
    93  func (t *Tree) Put(key []byte, val []byte) (bool, error) {
    94  	if t.readOnlyEnabled {
    95  		return false, errors.New("tree is in readonly mode, no more put operation is accepted")
    96  	}
    97  	if len(key) != t.keyLength {
    98  		return false, fmt.Errorf("trie is configured for key length of %d bytes, but got key with length %d: %w", t.keyLength, len(key), ErrorIncompatibleKeyLength)
    99  	}
   100  	replaced := t.unsafePut(key, val)
   101  	return replaced, nil
   102  }
   103  
// unsafePut stores the given value in the trie under the given key. If the
// key already exists, it will replace the value and return true; if the key
// is new, the required nodes are created and false is returned.
// The value is copied before it is stored in a leaf.
// UNSAFE:
//   - all keys must have identical lengths, which is not checked here.
func (t *Tree) unsafePut(key []byte, val []byte) bool {
	// the path through the tree is determined by the key; we decide whether to
	// go left or right based on whether the next bit is set or not

	// we use a pointer that points at the current node in the tree; since it
	// points at the _slot_ holding the node, assigning through it rewires the tree
	cur := &t.root

	// we use an index to keep track of the bit we are currently looking at
	index := 0

	// the for statement keeps running until we reach a leaf in the merkle tree
	// if the leaf is nil, it was empty and we insert a new value
	// if the leaf is a valid pointer, we overwrite the previous value
PutLoop:
	for {
		switch n := (*cur).(type) {

		// if we have a full node, we have a node on each side to go to, so we
		// just pick the next node based on whether the bit is set or not
		case *full:
			// if the bit is 0, we go left; otherwise (bit value 1), we go right
			if bitutils.ReadBit(key, index) == 0 {
				cur = &n.left
			} else {
				cur = &n.right
			}

			// we forward the index by one to look at the next bit
			index++

			continue PutLoop

		// if we have a short node, we have a path of several bits to the next
		// node; in that case, we use as much of the shared path as possible
		case *short:
			// first, we find out how many bits we have in common
			commonCount := 0
			for i := 0; i < n.count; i++ {
				if bitutils.ReadBit(key, i+index) != bitutils.ReadBit(n.path, i) {
					break
				}
				commonCount++
			}

			// if the common and node count are equal, we share all of the path
			// we can simply forward to the child of the short node and continue
			if commonCount == n.count {
				cur = &n.child
				index += commonCount
				continue PutLoop
			}

			// from here on, the key diverges from the short node's path at
			// bit `commonCount`, so the short node `n` is replaced by up to
			// three new nodes: [common short] -> full -> [remaining short]

			// if the common count is non-zero, we share some of the path;
			// first, we insert a common short node for the shared path
			if commonCount > 0 {
				commonPath := bitutils.MakeBitVector(commonCount)
				for i := 0; i < commonCount; i++ {
					bitutils.WriteBit(commonPath, i, bitutils.ReadBit(key, i+index))
				}
				commonNode := &short{count: commonCount, path: commonPath}
				*cur = commonNode
				cur = &commonNode.child
				index += commonCount
			}

			// we then insert a full node that splits the tree after the shared
			// path; we set our pointer to the side that lies on our path,
			// and use a remaining pointer for the other side of the node
			var remain *node
			splitNode := &full{}
			*cur = splitNode
			// the old path's bit at the point of divergence decides where the
			// existing subtree goes; the new key takes the opposite side
			if bitutils.ReadBit(n.path, commonCount) == 1 {
				cur = &splitNode.left
				remain = &splitNode.right
			} else {
				cur = &splitNode.right
				remain = &splitNode.left
			}
			// the full node consumes exactly one bit of the key
			index++

			// we can continue our insertion at this point, but we should first
			// insert the correct node on the other side of the created full
			// node; if we have remaining path, we create a short node and
			// forward to its path; finally, we set the leaf to original leaf
			// (one bit of the old path is consumed by the full node, hence -1)
			remainCount := n.count - commonCount - 1
			if remainCount > 0 {
				remainPath := bitutils.MakeBitVector(remainCount)
				for i := 0; i < remainCount; i++ {
					bitutils.WriteBit(remainPath, i, bitutils.ReadBit(n.path, i+commonCount+1))
				}
				remainNode := &short{count: remainCount, path: remainPath}
				*remain = remainNode
				remain = &remainNode.child
			}
			*remain = n.child

			// `cur` now points at the still-empty slot on the new key's side
			// of the full node; the next iteration handles it in the nil case
			continue PutLoop

		// if we have a leaf node, we reached a non-empty leaf
		case *leaf:
			// store a _copy_ of the value to protect it from external modification
			n.val = append(make([]byte, 0, len(val)), val...)
			return true // return true to indicate that we overwrote

		// if we have nil, we reached the end of any shared path
		case nil:
			// if we have reached the end of the key, insert the new value
			totalCount := len(key) * 8
			if index == totalCount {
				// Instantiate a new leaf holding a _copy_ of the provided value,
				// to protect the slice from external modification.
				*cur = &leaf{
					val: append(make([]byte, 0, len(val)), val...),
				}
				return false
			}

			// otherwise, insert a short node with the remainder of the path
			finalCount := totalCount - index
			finalPath := bitutils.MakeBitVector(finalCount)
			for i := 0; i < finalCount; i++ {
				bitutils.WriteBit(finalPath, i, bitutils.ReadBit(key, index+i))
			}
			finalNode := &short{count: finalCount, path: []byte(finalPath)}
			*cur = finalNode
			cur = &finalNode.child
			// the whole key is consumed now, so the next iteration lands in
			// the nil case again and creates the leaf
			index += finalCount

			continue PutLoop
		}
	}
}
   239  
   240  // Get will retrieve the value associated with the given key. It returns true
   241  // if the key was found and false otherwise.
   242  func (t *Tree) Get(key []byte) ([]byte, bool) {
   243  	if t.keyLength != len(key) {
   244  		return nil, false
   245  	}
   246  	return t.unsafeGet(key)
   247  }
   248  
   249  // unsafeGet retrieves the value associated with the given key. It returns true
   250  // if the key was found and false otherwise.
   251  // UNSAFE:
   252  //   - all keys must have identical lengths, which is not checked here.
   253  func (t *Tree) unsafeGet(key []byte) ([]byte, bool) {
   254  	cur := &t.root // start at the root
   255  	index := 0     // and we start at a zero index in the path
   256  
   257  GetLoop:
   258  	for {
   259  		switch n := (*cur).(type) {
   260  
   261  		// if we have a full node, we can follow the path for at least one more
   262  		// bit, so go left or right depending on whether it's set or not
   263  		case *full:
   264  			// forward pointer and index to the correct child
   265  			if bitutils.ReadBit(key, index) == 0 {
   266  				cur = &n.left
   267  			} else {
   268  				cur = &n.right
   269  			}
   270  
   271  			index++
   272  			continue GetLoop
   273  
   274  		// if we have a short path, we can only follow the short node if
   275  		// its paths has all bits in common with the key we are retrieving
   276  		case *short:
   277  			// if any part of the path doesn't match, key doesn't exist
   278  			for i := 0; i < n.count; i++ {
   279  				if bitutils.ReadBit(key, i+index) != bitutils.ReadBit(n.path, i) {
   280  					return nil, false
   281  				}
   282  			}
   283  
   284  			// forward pointer and index to child
   285  			cur = &n.child
   286  			index += n.count
   287  
   288  			continue GetLoop
   289  
   290  		// if we have a leaf, we found the key, return value and true
   291  		case *leaf:
   292  			return n.val, true
   293  
   294  		// if we have a nil node, key doesn't exist, return nil and false
   295  		case nil:
   296  			return nil, false
   297  		}
   298  	}
   299  }
   300  
   301  // Prove constructs an inclusion proof for the given key, provided the key exists in the trie.
   302  // It returns:
   303  // - (proof, true) if key is found
   304  // - (nil, false) if key is not found
   305  // Proof is constructed by traversing the trie from top to down and collects data for proof as follows:
   306  //   - if full node, append the sibling node hash value to sibling hash list
   307  //   - if short node, appends the node.shortCount to the short count list
   308  //   - if leaf, would capture the leaf value
   309  func (t *Tree) Prove(key []byte) (*Proof, bool) {
   310  
   311  	// check the len of key first
   312  	if t.keyLength != len(key) {
   313  		return nil, false
   314  	}
   315  
   316  	// we start at the root again
   317  	cur := &t.root
   318  
   319  	// and we start at a zero index in the path
   320  	index := 0
   321  
   322  	// init proof params
   323  	siblingHashes := make([][]byte, 0)
   324  	shortPathLengths := make([]uint16, 0)
   325  
   326  	steps := 0
   327  	shortNodeVisited := make([]bool, 0)
   328  
   329  ProveLoop:
   330  	for {
   331  		switch n := (*cur).(type) {
   332  
   333  		// if we have a full node, we can follow the path for at least one more
   334  		// bit, so go left or right depending on whether it's set or not
   335  		case *full:
   336  			var sibling node
   337  			// forward pointer and index to the correct child
   338  			if bitutils.ReadBit(key, index) == 0 {
   339  				sibling = n.right
   340  				cur = &n.left
   341  			} else {
   342  				sibling = n.left
   343  				cur = &n.right
   344  			}
   345  
   346  			index++
   347  			siblingHashes = append(siblingHashes, sibling.Hash(t.readOnlyEnabled))
   348  			shortNodeVisited = append(shortNodeVisited, false)
   349  			steps++
   350  
   351  			continue ProveLoop
   352  
   353  		// if we have a short node, we can only follow the path if the key's subsequent
   354  		// bits match the entire path segment of the short node.
   355  		case *short:
   356  
   357  			// if any part of the path doesn't match, key doesn't exist
   358  			for i := 0; i < n.count; i++ {
   359  				if bitutils.ReadBit(key, i+index) != bitutils.ReadBit(n.path, i) {
   360  					return nil, false
   361  				}
   362  			}
   363  
   364  			cur = &n.child
   365  			index += n.count
   366  			shortPathLengths = append(shortPathLengths, uint16(n.count))
   367  			shortNodeVisited = append(shortNodeVisited, true)
   368  			steps++
   369  
   370  			continue ProveLoop
   371  
   372  		// if we have a leaf, we found the key, return proof and true
   373  		case *leaf:
   374  			// compress interimNodeTypes
   375  			interimNodeTypes := bitutils.MakeBitVector(len(shortNodeVisited))
   376  			for i, b := range shortNodeVisited {
   377  				if b {
   378  					bitutils.SetBit(interimNodeTypes, i)
   379  				}
   380  			}
   381  
   382  			return &Proof{
   383  				Key:              key,
   384  				Value:            n.val,
   385  				InterimNodeTypes: interimNodeTypes,
   386  				ShortPathLengths: shortPathLengths,
   387  				SiblingHashes:    siblingHashes,
   388  			}, true
   389  
   390  		// the only possible nil node is the root node of an empty trie
   391  		case nil:
   392  			return nil, false
   393  		}
   394  	}
   395  }
   396  
   397  // Del removes the value associated with the given key from the patricia
   398  // merkle trie. It returns true if they key was found and false otherwise.
   399  // Internally, any parent nodes between the leaf up to the closest shared path
   400  // will be deleted or merged, which keeps the trie deterministic regardless of
   401  // insertion and deletion orders.
   402  func (t *Tree) Del(key []byte) (bool, error) {
   403  	if t.readOnlyEnabled {
   404  		return false, errors.New("tree is in readonly mode, no more delete operation is accepted")
   405  	}
   406  	if t.keyLength != len(key) {
   407  		return false, fmt.Errorf("trie is configured for key length of %d bytes, but got key with length %d: %w", t.keyLength, len(key), ErrorIncompatibleKeyLength)
   408  	}
   409  	return t.unsafeDel(key), nil
   410  }
   411  
// unsafeDel removes the value associated with the given key from the patricia
// merkle trie. It returns true if the key was found and false otherwise; when
// the key is not found, the trie is left unchanged.
// Internally, any parent nodes between the leaf up to the closest shared path
// will be deleted or merged, which keeps the trie deterministic regardless of
// insertion and deletion orders.
// UNSAFE:
//   - all keys must have identical lengths, which is not checked here.
func (t *Tree) unsafeDel(key []byte) bool {
	cur := &t.root // start at the root
	index := 0     // the index points to the bit we are processing in the path

	// we initialize three pointers pointing to a dummy empty node
	// this is used to keep track of the node we last pointed to, as well as
	// its parent and grand parent, which is needed in case we remove a full
	// node and have to merge several other nodes into a short node; otherwise,
	// we would not keep the tree as compact as possible, and it would no longer
	// be deterministic after deletes
	// (on the right-hand side, `dummy` still denotes the node type; the local
	// variable of the same name only comes into scope after this statement)
	dummy := node(&dummy{})
	last, parent, grand := &dummy, &dummy, &dummy

	// phase 1: descend along the key until we either find the leaf (and remove
	// it) or establish that the key is not part of the trie
DelLoop:
	for {
		switch n := (*cur).(type) {

		// if we have a full node, we forward all of the pointers
		case *full:
			// keep track of grand-parent, parent and node for cleanup
			grand = parent
			parent = last
			last = cur

			// forward pointer and index to the correct child
			if bitutils.ReadBit(key, index) == 0 {
				cur = &n.left
			} else {
				cur = &n.right
			}

			index++
			continue DelLoop

		// if we have a short node, we forward by all of the common path if
		// possible; otherwise the node wasn't found
		case *short:
			// keep track of grand-parent, parent and node for cleanup
			grand = parent
			parent = last
			last = cur

			// if the path doesn't match at any point, we can't find the node
			for i := 0; i < n.count; i++ {
				if bitutils.ReadBit(key, i+index) != bitutils.ReadBit(n.path, i) {
					return false
				}
			}

			// forward pointer and index to the node child
			cur = &n.child
			index += n.count

			continue DelLoop

		// if we have a leaf node, we remove it and continue with cleanup
		case *leaf:
			*cur = nil // replace the current pointer with nil to delete the node
			break DelLoop

		// if we reach nil, the node doesn't exist
		case nil:
			return false
		}
	}

	// phase 2: cleanup, to restore the compact form of the trie

	// if the last node before reaching the leaf is a short node, we set it to
	// nil to remove it from the tree and move the pointer to its parent
	_, ok := (*last).(*short)
	if ok {
		*last = nil
		last = parent
		parent = grand
	}

	// if the last node here is not a full node, we are done; we never have two
	// short nodes in a row, which means we have reached the root
	f, ok := (*last).(*full)
	if !ok {
		return true
	}

	// if the last node is a full node, we need to convert it into a short node
	// that holds the undeleted child and the corresponding bit as path
	// (the side on the deleted key's path has been set to nil above, so the
	// non-nil side identifies the surviving child)
	var n *short
	newPath := bitutils.MakeBitVector(1)
	if f.left != nil {
		bitutils.ClearBit(newPath, 0)
		n = &short{count: 1, path: newPath, child: f.left}
	} else {
		bitutils.SetBit(newPath, 0)
		n = &short{count: 1, path: newPath, child: f.right}
	}
	*last = n

	// if the child is also a short node, we have to merge them and use the
	// child's child as the child of the merged short node
	c, ok := n.child.(*short)
	if ok {
		merge(n, c)
	}

	// if the parent is also a short node, we have to merge them and use the
	// current child as the child of the merged node
	// (this must happen after the merge above, so that the parent absorbs the
	// complete path segment of `n`)
	p, ok := (*parent).(*short)
	if ok {
		merge(p, n)
	}

	// NOTE: if neither the parent nor the child are short nodes, we simply
	// bypass both conditional scopes and land here right away
	return true
}
   532  
   533  // Hash returns the root hash of this patricia merkle tree.
   534  // Per convention, an empty trie has an empty hash.
   535  func (t *Tree) Hash() []byte {
   536  	if t.root == nil {
   537  		return EmptyTreeRootHash
   538  	}
   539  	return t.root.Hash(t.readOnlyEnabled)
   540  }
   541  
   542  // merge will merge a child short node into a parent short node.
   543  func merge(p *short, c *short) {
   544  	totalCount := p.count + c.count
   545  	totalPath := bitutils.MakeBitVector(totalCount)
   546  	for i := 0; i < p.count; i++ {
   547  		bitutils.WriteBit(totalPath, i, bitutils.ReadBit(p.path, i))
   548  	}
   549  	for i := 0; i < c.count; i++ {
   550  		bitutils.WriteBit(totalPath, i+p.count, bitutils.ReadBit(c.path, i))
   551  	}
   552  	p.count = totalCount
   553  	p.path = totalPath
   554  	p.child = c.child
   555  }