github.com/koko1123/flow-go-1@v0.29.6/storage/merkle/tree.go (about)

     1  // (c) 2019 Dapper Labs - ALL RIGHTS RESERVED
     2  
     3  package merkle
     4  
     5  import (
     6  	"errors"
     7  	"fmt"
     8  
     9  	"golang.org/x/crypto/blake2b"
    10  
    11  	"github.com/koko1123/flow-go-1/ledger/common/bitutils"
    12  )
    13  
    14  // maxKeyLength in bytes:
    15  // For any key, we need to ensure that the entire path can be stored in a short node.
    16  // A short node stores the _number of bits_ for the path segment it represents in 2 bytes.
    17  //
    18  // Hence, the theoretically possible value range is [0,65535]. However, a short node with
    19  // zero path length is not part of our storage model. Furthermore, we always represent
    20  // keys as _byte_ slices, i.e. their number of bits must be an integer-multiple of 8.
    21  // Therefore, the range of valid key length in bytes is [1, 8191] (the corresponding
    22  // range in bits is [8, 65528]) .
    23  const maxKeyLength = 8191
    24  const maxKeyLenBits = maxKeyLength * 8
    25  
    26  var EmptyTreeRootHash []byte
    27  
    28  func init() {
    29  	h, _ := blake2b.New256([]byte{})
    30  	EmptyTreeRootHash = h.Sum(nil)
    31  }
    32  
// Tree represents a binary patricia merkle tree. The difference with a normal
// merkle tree is that it compresses paths that lead to a single leaf into a
// single intermediary node, which makes it significantly more space-efficient
// and a lot harder to exploit for denial-of-service attacks. On the downside,
// it makes insertions and deletions more complex, as we need to split nodes
// and merge them, depending on whether there are leaves or not.
//
// CONVENTION:
//   - If the tree contains _any_ elements, the tree is defined by its root vertex.
//     This case follows completely the convention for nodes: "In any existing tree,
//     all nodes are non-nil."
//   - Without any stored elements, there exists no root vertex in this data model,
//     and we set `root` to nil.
type Tree struct {
	// keyLength is the length, in bytes, that every key must have; it is fixed
	// by NewTree and checked by the exported accessors.
	keyLength int
	// root is the root vertex of the tree, or nil if the tree holds no elements.
	root node
	// setting this flag would prevent more writes to the trie
	// but makes it more efficient for proof generation
	// (the flag is passed to every node Hash call; see MakeItReadOnly).
	readOnlyEnabled bool
}
    53  
    54  // NewTree creates a new empty patricia merkle tree, with keys of the given
    55  // `keyLength` (length measured in bytes).
    56  // The current implementation only works with 1 ≤ keyLength ≤ 8191. Otherwise,
    57  // the sentinel error `ErrorIncompatibleKeyLength` is returned.
    58  func NewTree(keyLength int) (*Tree, error) {
    59  	if keyLength < 1 || maxKeyLength < keyLength {
    60  		return nil, fmt.Errorf("key length %d is outside of supported interval [1, %d]: %w", keyLength, maxKeyLength, ErrorIncompatibleKeyLength)
    61  	}
    62  	return &Tree{
    63  		keyLength: keyLength,
    64  		root:      nil,
    65  	}, nil
    66  }
    67  
    68  // MakeItReadOnly makes the tree read only, this operation is not reversible.
    69  // when tree becomes readonly, while doing operations it starts caching hashValues
    70  // for faster operations.
    71  func (t *Tree) MakeItReadOnly() {
    72  	t.readOnlyEnabled = true
    73  }
    74  
    75  // ComputeMaxDepth returns the maximum depth of the tree by traversing all paths
    76  //
    77  // Warning: this could be a very expensive operation for large trees, as nodes
    78  // don't cache the depth of children and have to compute by traversing.
    79  func (t *Tree) ComputeMaxDepth() uint {
    80  	return t.root.MaxDepthOfDescendants()
    81  }
    82  
    83  // Put stores the given value in the trie under the given key. If the key
    84  // already exists, it will replace the value and return true. All inputs
    85  // are internally stored and copied where necessary, thereby allowing
    86  // external code to re-use the slices.
    87  // Returns:
    88  //   - (false, nil): key-value pair is stored; key did _not_ yet exist prior to update
    89  //   - (true, nil):  key-value pair is stored; key existed prior to update and the old
    90  //     value was overwritten
    91  //   - (false, error): with possible error returns
    92  //   - ErrorIncompatibleKeyLength if `key` has different length than the pre-configured value
    93  //     No other errors are returned.
    94  func (t *Tree) Put(key []byte, val []byte) (bool, error) {
    95  	if t.readOnlyEnabled {
    96  		return false, errors.New("tree is in readonly mode, no more put operation is accepted")
    97  	}
    98  	if len(key) != t.keyLength {
    99  		return false, fmt.Errorf("trie is configured for key length of %d bytes, but got key with length %d: %w", t.keyLength, len(key), ErrorIncompatibleKeyLength)
   100  	}
   101  	replaced := t.unsafePut(key, val)
   102  	return replaced, nil
   103  }
   104  
// unsafePut stores the given value in the trie under the given key. If the
// key already exists, it will replace the value and return true; if a new leaf
// had to be created, it returns false. The value is copied before being stored;
// the key's bits are copied into any newly created short-node paths, so neither
// input slice is retained.
// UNSAFE:
//   - all keys must have identical lengths, which is not checked here.
func (t *Tree) unsafePut(key []byte, val []byte) bool {
	// the path through the tree is determined by the key; we decide whether to
	// go left or right based on whether the next bit is set or not

	// we use a pointer that points at the current node in the tree
	// (a pointer to the slot holding the node, so the slot itself can be replaced)
	cur := &t.root

	// we use an index to keep track of the bit we are currently looking at
	index := 0

	// the for statement keeps running until we reach a leaf in the merkle tree
	// if the leaf is nil, it was empty and we insert a new value
	// if the leaf is a valid pointer, we overwrite the previous value
PutLoop:
	for {
		switch n := (*cur).(type) {

		// if we have a full node, we have a node on each side to go to, so we
		// just pick the next node based on whether the bit is set or not
		case *full:
			// if the bit is 0, we go left; otherwise (bit value 1), we go right
			if bitutils.ReadBit(key, index) == 0 {
				cur = &n.left
			} else {
				cur = &n.right
			}

			// we forward the index by one to look at the next bit
			index++

			continue PutLoop

		// if we have a short node, we have a path of several bits to the next
		// node; in that case, we use as much of the shared path as possible
		case *short:
			// first, we find out how many bits we have in common
			commonCount := 0
			for i := 0; i < n.count; i++ {
				if bitutils.ReadBit(key, i+index) != bitutils.ReadBit(n.path, i) {
					break
				}
				commonCount++
			}

			// if the common and node count are equal, we share all of the path
			// we can simply forward to the child of the short node and continue
			if commonCount == n.count {
				cur = &n.child
				index += commonCount
				continue PutLoop
			}

			// if the common count is non-zero, we share some of the path;
			// first, we insert a common short node for the shared path
			if commonCount > 0 {
				commonPath := bitutils.MakeBitVector(commonCount)
				for i := 0; i < commonCount; i++ {
					bitutils.WriteBit(commonPath, i, bitutils.ReadBit(key, i+index))
				}
				commonNode := &short{count: commonCount, path: commonPath}
				*cur = commonNode
				cur = &commonNode.child
				index += commonCount
			}

			// we then insert a full node that splits the tree after the shared
			// path; we set our pointer to the side that lies on our path,
			// and use a remaining pointer for the other side of the node.
			// At bit position commonCount the key and the short node's path differ,
			// so if the old path continues with 1, the key continues with 0 (left).
			var remain *node
			splitNode := &full{}
			*cur = splitNode
			if bitutils.ReadBit(n.path, commonCount) == 1 {
				cur = &splitNode.left
				remain = &splitNode.right
			} else {
				cur = &splitNode.right
				remain = &splitNode.left
			}
			// the split node consumes exactly one key bit
			index++

			// we can continue our insertion at this point, but we should first
			// insert the correct node on the other side of the created full
			// node; if the old short node has path bits left beyond the split
			// bit, we create a short node holding them and forward to its child
			// slot; finally, we attach the original short node's child there
			remainCount := n.count - commonCount - 1
			if remainCount > 0 {
				remainPath := bitutils.MakeBitVector(remainCount)
				for i := 0; i < remainCount; i++ {
					bitutils.WriteBit(remainPath, i, bitutils.ReadBit(n.path, i+commonCount+1))
				}
				remainNode := &short{count: remainCount, path: remainPath}
				*remain = remainNode
				remain = &remainNode.child
			}
			*remain = n.child

			// cur now points at the empty slot on the key's side of the split node;
			// the next iteration handles it in the nil case
			continue PutLoop

		// if we have a leaf node, we reached a non-empty leaf
		case *leaf:
			// store a copy so the caller may re-use `val`
			n.val = append(make([]byte, 0, len(val)), val...)
			return true // return true to indicate that we overwrote

		// if we have nil, we reached the end of any shared path
		case nil:
			// if we have reached the end of the key, insert the new value
			totalCount := len(key) * 8
			if index == totalCount {
				// Instantiate a new leaf holding a _copy_ of the provided value,
				// to protect the slice from external modification.
				*cur = &leaf{
					val: append(make([]byte, 0, len(val)), val...),
				}
				return false
			}

			// otherwise, insert a short node with the remainder of the path;
			// the next iteration then lands on its (nil) child slot and
			// creates the leaf there
			finalCount := totalCount - index
			finalPath := bitutils.MakeBitVector(finalCount)
			for i := 0; i < finalCount; i++ {
				bitutils.WriteBit(finalPath, i, bitutils.ReadBit(key, index+i))
			}
			finalNode := &short{count: finalCount, path: []byte(finalPath)}
			*cur = finalNode
			cur = &finalNode.child
			index += finalCount

			continue PutLoop
		}
	}
}
   240  
   241  // Get will retrieve the value associated with the given key. It returns true
   242  // if the key was found and false otherwise.
   243  func (t *Tree) Get(key []byte) ([]byte, bool) {
   244  	if t.keyLength != len(key) {
   245  		return nil, false
   246  	}
   247  	return t.unsafeGet(key)
   248  }
   249  
   250  // unsafeGet retrieves the value associated with the given key. It returns true
   251  // if the key was found and false otherwise.
   252  // UNSAFE:
   253  //   - all keys must have identical lengths, which is not checked here.
   254  func (t *Tree) unsafeGet(key []byte) ([]byte, bool) {
   255  	cur := &t.root // start at the root
   256  	index := 0     // and we start at a zero index in the path
   257  
   258  GetLoop:
   259  	for {
   260  		switch n := (*cur).(type) {
   261  
   262  		// if we have a full node, we can follow the path for at least one more
   263  		// bit, so go left or right depending on whether it's set or not
   264  		case *full:
   265  			// forward pointer and index to the correct child
   266  			if bitutils.ReadBit(key, index) == 0 {
   267  				cur = &n.left
   268  			} else {
   269  				cur = &n.right
   270  			}
   271  
   272  			index++
   273  			continue GetLoop
   274  
   275  		// if we have a short path, we can only follow the short node if
   276  		// its paths has all bits in common with the key we are retrieving
   277  		case *short:
   278  			// if any part of the path doesn't match, key doesn't exist
   279  			for i := 0; i < n.count; i++ {
   280  				if bitutils.ReadBit(key, i+index) != bitutils.ReadBit(n.path, i) {
   281  					return nil, false
   282  				}
   283  			}
   284  
   285  			// forward pointer and index to child
   286  			cur = &n.child
   287  			index += n.count
   288  
   289  			continue GetLoop
   290  
   291  		// if we have a leaf, we found the key, return value and true
   292  		case *leaf:
   293  			return n.val, true
   294  
   295  		// if we have a nil node, key doesn't exist, return nil and false
   296  		case nil:
   297  			return nil, false
   298  		}
   299  	}
   300  }
   301  
   302  // Prove constructs an inclusion proof for the given key, provided the key exists in the trie.
   303  // It returns:
   304  // - (proof, true) if key is found
   305  // - (nil, false) if key is not found
   306  // Proof is constructed by traversing the trie from top to down and collects data for proof as follows:
   307  //   - if full node, append the sibling node hash value to sibling hash list
   308  //   - if short node, appends the node.shortCount to the short count list
   309  //   - if leaf, would capture the leaf value
   310  func (t *Tree) Prove(key []byte) (*Proof, bool) {
   311  
   312  	// check the len of key first
   313  	if t.keyLength != len(key) {
   314  		return nil, false
   315  	}
   316  
   317  	// we start at the root again
   318  	cur := &t.root
   319  
   320  	// and we start at a zero index in the path
   321  	index := 0
   322  
   323  	// init proof params
   324  	siblingHashes := make([][]byte, 0)
   325  	shortPathLengths := make([]uint16, 0)
   326  
   327  	steps := 0
   328  	shortNodeVisited := make([]bool, 0)
   329  
   330  ProveLoop:
   331  	for {
   332  		switch n := (*cur).(type) {
   333  
   334  		// if we have a full node, we can follow the path for at least one more
   335  		// bit, so go left or right depending on whether it's set or not
   336  		case *full:
   337  			var sibling node
   338  			// forward pointer and index to the correct child
   339  			if bitutils.ReadBit(key, index) == 0 {
   340  				sibling = n.right
   341  				cur = &n.left
   342  			} else {
   343  				sibling = n.left
   344  				cur = &n.right
   345  			}
   346  
   347  			index++
   348  			siblingHashes = append(siblingHashes, sibling.Hash(t.readOnlyEnabled))
   349  			shortNodeVisited = append(shortNodeVisited, false)
   350  			steps++
   351  
   352  			continue ProveLoop
   353  
   354  		// if we have a short node, we can only follow the path if the key's subsequent
   355  		// bits match the entire path segment of the short node.
   356  		case *short:
   357  
   358  			// if any part of the path doesn't match, key doesn't exist
   359  			for i := 0; i < n.count; i++ {
   360  				if bitutils.ReadBit(key, i+index) != bitutils.ReadBit(n.path, i) {
   361  					return nil, false
   362  				}
   363  			}
   364  
   365  			cur = &n.child
   366  			index += n.count
   367  			shortPathLengths = append(shortPathLengths, uint16(n.count))
   368  			shortNodeVisited = append(shortNodeVisited, true)
   369  			steps++
   370  
   371  			continue ProveLoop
   372  
   373  		// if we have a leaf, we found the key, return proof and true
   374  		case *leaf:
   375  			// compress interimNodeTypes
   376  			interimNodeTypes := bitutils.MakeBitVector(len(shortNodeVisited))
   377  			for i, b := range shortNodeVisited {
   378  				if b {
   379  					bitutils.SetBit(interimNodeTypes, i)
   380  				}
   381  			}
   382  
   383  			return &Proof{
   384  				Key:              key,
   385  				Value:            n.val,
   386  				InterimNodeTypes: interimNodeTypes,
   387  				ShortPathLengths: shortPathLengths,
   388  				SiblingHashes:    siblingHashes,
   389  			}, true
   390  
   391  		// the only possible nil node is the root node of an empty trie
   392  		case nil:
   393  			return nil, false
   394  		}
   395  	}
   396  }
   397  
   398  // Del removes the value associated with the given key from the patricia
   399  // merkle trie. It returns true if they key was found and false otherwise.
   400  // Internally, any parent nodes between the leaf up to the closest shared path
   401  // will be deleted or merged, which keeps the trie deterministic regardless of
   402  // insertion and deletion orders.
   403  func (t *Tree) Del(key []byte) (bool, error) {
   404  	if t.readOnlyEnabled {
   405  		return false, errors.New("tree is in readonly mode, no more delete operation is accepted")
   406  	}
   407  	if t.keyLength != len(key) {
   408  		return false, fmt.Errorf("trie is configured for key length of %d bytes, but got key with length %d: %w", t.keyLength, len(key), ErrorIncompatibleKeyLength)
   409  	}
   410  	return t.unsafeDel(key), nil
   411  }
   412  
// unsafeDel removes the value associated with the given key from the patricia
// merkle trie. It returns true if the key was found and false otherwise; when
// it returns false, the tree has not been modified.
// Internally, any parent nodes between the leaf up to the closest shared path
// will be deleted or merged, which keeps the trie deterministic regardless of
// insertion and deletion orders.
// UNSAFE:
//   - all keys must have identical lengths, which is not checked here.
func (t *Tree) unsafeDel(key []byte) bool {
	cur := &t.root // start at the root
	index := 0     // the index points to the bit we are processing in the path

	// we initialize three pointers pointing to a dummy empty node
	// this is used to keep track of the node we last pointed to, as well as
	// its parent and grand parent, which is needed in case we remove a full
	// node and have to merge several other nodes into a short node; otherwise,
	// we would not keep the tree as compact as possible, and it would no longer
	// be deterministic after deletes
	// (note: from this line on, the local variable `dummy` shadows the type of
	// the same name; the right-hand side still refers to the type)
	dummy := node(&dummy{})
	last, parent, grand := &dummy, &dummy, &dummy

DelLoop:
	for {
		switch n := (*cur).(type) {

		// if we have a full node, we forward all of the pointers
		case *full:
			// keep track of grand-parent, parent and node for cleanup
			grand = parent
			parent = last
			last = cur

			// forward pointer and index to the correct child
			if bitutils.ReadBit(key, index) == 0 {
				cur = &n.left
			} else {
				cur = &n.right
			}

			index++
			continue DelLoop

		// if we have a short node, we forward by all of the common path if
		// possible; otherwise the node wasn't found
		case *short:
			// keep track of grand-parent, parent and node for cleanup
			grand = parent
			parent = last
			last = cur

			// if the path doesn't match at any point, we can't find the node
			for i := 0; i < n.count; i++ {
				if bitutils.ReadBit(key, i+index) != bitutils.ReadBit(n.path, i) {
					return false
				}
			}

			// forward pointer and index to the node child
			cur = &n.child
			index += n.count

			continue DelLoop

		// if we have a leaf node, we remove it and continue with cleanup
		case *leaf:
			*cur = nil // replace the current pointer with nil to delete the node
			break DelLoop

		// if we reach nil, the node doesn't exist
		case nil:
			return false
		}
	}

	// if the last node before reaching the leaf is a short node, we set it to
	// nil to remove it from the tree and move the pointer to its parent
	// (a short node with a nil child would otherwise dangle)
	_, ok := (*last).(*short)
	if ok {
		*last = nil
		last = parent
		parent = grand
	}

	// if the last node here is not a full node, we are done; we never have two
	// short nodes in a row, so `last` can then only be the dummy, which means the
	// removed short node was the root and the tree is now empty
	f, ok := (*last).(*full)
	if !ok {
		return true
	}

	// if the last node is a full node, we need to convert it into a short node
	// that holds the undeleted child and the corresponding bit as path
	// (exactly one of the full node's children is nil at this point)
	var n *short
	newPath := bitutils.MakeBitVector(1)
	if f.left != nil {
		bitutils.ClearBit(newPath, 0)
		n = &short{count: 1, path: newPath, child: f.left}
	} else {
		bitutils.SetBit(newPath, 0)
		n = &short{count: 1, path: newPath, child: f.right}
	}
	*last = n

	// if the child is also a short node, we have to merge them and use the
	// child's child as the child of the merged short node
	c, ok := n.child.(*short)
	if ok {
		merge(n, c)
	}

	// if the parent is also a short node, we have to merge them and use the
	// current child as the child of the merged node
	p, ok := (*parent).(*short)
	if ok {
		merge(p, n)
	}

	// NOTE: if neither the parent nor the child are short nodes, we simply
	// bypass both conditional scopes and land here right away
	return true
}
   533  
   534  // Hash returns the root hash of this patricia merkle tree.
   535  // Per convention, an empty trie has an empty hash.
   536  func (t *Tree) Hash() []byte {
   537  	if t.root == nil {
   538  		return EmptyTreeRootHash
   539  	}
   540  	return t.root.Hash(t.readOnlyEnabled)
   541  }
   542  
   543  // merge will merge a child short node into a parent short node.
   544  func merge(p *short, c *short) {
   545  	totalCount := p.count + c.count
   546  	totalPath := bitutils.MakeBitVector(totalCount)
   547  	for i := 0; i < p.count; i++ {
   548  		bitutils.WriteBit(totalPath, i, bitutils.ReadBit(p.path, i))
   549  	}
   550  	for i := 0; i < c.count; i++ {
   551  		bitutils.WriteBit(totalPath, i+p.count, bitutils.ReadBit(c.path, i))
   552  	}
   553  	p.count = totalCount
   554  	p.path = totalPath
   555  	p.child = c.child
   556  }