
     1  package trie
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"io"
     7  	"sync"
     9  	""
    10  	""
    11  	""
    12  )
// MTrie represents a perfect in-memory full binary Merkle tree with uniform height.
// For a detailed description of the storage model, please consult `mtrie/`
//
// An MTrie is a thin wrapper around the trie's root Node. An MTrie implements the
// logic for forming MTrie-graphs from the elementary nodes. Specifically:
//   - how Nodes (graph vertices) form a Trie,
//   - how register values are read from the trie,
//   - how Merkle proofs are generated from a trie, and
//   - how a new Trie with updated values is generated.
//
// `MTrie`s are _immutable_ data structures. Updating register values is implemented through
// copy-on-write, which creates a new `MTrie`. For minimal memory consumption, all sub-tries
// that were not affected by the write operation are shared between the original MTrie
// (before the register updates) and the updated MTrie (after the register writes).
//
// MTrie expects that for a specific path, the register's key never changes.
//
// Definitions:
//   - HEIGHT of a node v in a tree is the number of edges on the longest downward path
//     between v and a tree leaf. The height of a tree is the height of its root.
//     The height of a Trie is always the height of the fully-expanded tree
//     (ledger.NodeMaxHeight; NewMTrie rejects roots of any other height).
type MTrie struct {
	root     *node.Node // root of the trie; nil for an empty trie
	regCount uint64     // number of registers allocated in the trie
	regSize  uint64     // size (number of bytes) of registers allocated in the trie
}
    41  // NewEmptyMTrie returns an empty Mtrie (root is nil)
    42  func NewEmptyMTrie() *MTrie {
    43  	return &MTrie{root: nil}
    44  }
    46  // IsEmpty checks if a trie is empty.
    47  //
    48  // An empty try doesn't mean a trie with no allocated registers.
    49  func (mt *MTrie) IsEmpty() bool {
    50  	return mt.root == nil
    51  }
    53  // NewMTrie returns a Mtrie given the root
    54  func NewMTrie(root *node.Node, regCount uint64, regSize uint64) (*MTrie, error) {
    55  	if root != nil && root.Height() != ledger.NodeMaxHeight {
    56  		return nil, fmt.Errorf("height of root node must be %d but is %d, hash: %s", ledger.NodeMaxHeight, root.Height(), root.Hash().String())
    57  	}
    58  	return &MTrie{
    59  		root:     root,
    60  		regCount: regCount,
    61  		regSize:  regSize,
    62  	}, nil
    63  }
    65  // RootHash returns the trie's root hash.
    66  // Concurrency safe (as Tries are immutable structures by convention)
    67  func (mt *MTrie) RootHash() ledger.RootHash {
    68  	if mt.IsEmpty() {
    69  		// case of an empty trie
    70  		return EmptyTrieRootHash()
    71  	}
    72  	return ledger.RootHash(mt.root.Hash())
    73  }
    75  // AllocatedRegCount returns the number of allocated registers in the trie.
    76  // Concurrency safe (as Tries are immutable structures by convention)
    77  func (mt *MTrie) AllocatedRegCount() uint64 {
    78  	return mt.regCount
    79  }
    81  // AllocatedRegSize returns the size (number of bytes) of allocated registers in the trie.
    82  // Concurrency safe (as Tries are immutable structures by convention)
    83  func (mt *MTrie) AllocatedRegSize() uint64 {
    84  	return mt.regSize
    85  }
    87  // RootNode returns the Trie's root Node
    88  // Concurrency safe (as Tries are immutable structures by convention)
    89  func (mt *MTrie) RootNode() *node.Node {
    90  	return mt.root
    91  }
    93  // String returns the trie's string representation.
    94  // Concurrency safe (as Tries are immutable structures by convention)
    95  func (mt *MTrie) String() string {
    96  	if mt.IsEmpty() {
    97  		return fmt.Sprintf("Empty Trie with default root hash: %v\n", mt.RootHash())
    98  	}
    99  	trieStr := fmt.Sprintf("Trie root hash: %v\n", mt.RootHash())
   100  	return trieStr + mt.root.FmtStr("", "")
   101  }
   103  // UnsafeValueSizes returns payload value sizes for the given paths.
   104  // UNSAFE: requires _all_ paths to have a length of mt.Height bits.
   105  // CAUTION: while getting payload value sizes, `paths` is permuted IN-PLACE for optimized processing.
   106  // Return:
   107  //   - `sizes` []int
   108  //     For each path, the corresponding payload value size is written into sizes. AFTER
   109  //     the size operation completes, the order of `path` and `sizes` are such that
   110  //     for `path[i]` the corresponding register value size is referenced by `sizes[i]`.
   111  //
   112  // TODO move consistency checks from Forest into Trie to obtain a safe, self-contained API
   113  func (mt *MTrie) UnsafeValueSizes(paths []ledger.Path) []int {
   114  	sizes := make([]int, len(paths)) // pre-allocate slice for the result
   115  	valueSizes(sizes, paths, mt.root)
   116  	return sizes
   117  }
// valueSizes writes the value sizes of the registers at `paths` that live in the subtree
// rooted at `head`. For each `paths[i]`, the corresponding value size is written into
// `sizes[i]` for the same index `i`. Entries for paths that are not found (nil subtree,
// non-matching leaf, or nil payload) are left untouched, so they keep whatever the
// caller pre-filled (0 when called from UnsafeValueSizes).
// CAUTION:
//   - while reading the payloads, `paths` is permuted IN-PLACE for optimized processing.
//   - unchecked requirement: all paths must go through the `head` node
//   - len(sizes) must equal len(paths)
func valueSizes(sizes []int, paths []ledger.Path, head *node.Node) {
	// check for empty paths
	if len(paths) == 0 {
		return
	}

	// path not found
	if head == nil {
		return
	}

	// reached a leaf node
	if head.IsLeaf() {
		for i, p := range paths {
			if *head.Path() == p {
				payload := head.Payload()
				if payload != nil {
					sizes[i] = payload.Value().Size()
				}
				// NOTE: break isn't used here because precondition
				// doesn't require paths being deduplicated.
			}
		}
		return
	}

	// reached an interim node with only one path
	if len(paths) == 1 {
		path := paths[0][:]

		// traverse nodes following the path until a leaf node or nil node is reached.
		// "for" loop helps to skip partition and recursive call when there's only one path to follow.
		// NOTE(review): head may become nil inside this loop; this relies on
		// (*node.Node).IsLeaf returning true for a nil receiver — confirm in package node.
		for {
			depth := ledger.NodeMaxHeight - head.Height() // distance to the tree root
			bit := bitutils.ReadBit(path, depth)
			if bit == 0 {
				head = head.LeftChild()
			} else {
				head = head.RightChild()
			}
			if head.IsLeaf() {
				break
			}
		}

		// head is now a leaf or nil; the recursive call handles both cases.
		valueSizes(sizes, paths, head)
		return
	}

	// reached an interim node with more than one paths

	// partition step to quick sort the paths:
	// lpaths contains all paths that have `0` at the partitionIndex
	// rpaths contains all paths that have `1` at the partitionIndex
	depth := ledger.NodeMaxHeight - head.Height() // distance to the tree root
	partitionIndex := SplitPaths(paths, depth)
	lpaths, rpaths := paths[:partitionIndex], paths[partitionIndex:]
	lsizes, rsizes := sizes[:partitionIndex], sizes[partitionIndex:]

	// read values from left and right subtrees in parallel
	parallelRecursionThreshold := 32 // threshold to avoid the parallelization going too deep in the recursion
	if len(lpaths) < parallelRecursionThreshold || len(rpaths) < parallelRecursionThreshold {
		valueSizes(lsizes, lpaths, head.LeftChild())
		valueSizes(rsizes, rpaths, head.RightChild())
	} else {
		// concurrent read of left and right subtree; the two halves write to
		// disjoint sub-slices of `sizes`, so no further synchronization is needed
		wg := sync.WaitGroup{}
		wg.Add(1)
		go func() {
			valueSizes(lsizes, lpaths, head.LeftChild())
			wg.Done()
		}()
		valueSizes(rsizes, rpaths, head.RightChild())
		wg.Wait() // wait for all threads
	}
}
   201  // ReadSinglePayload reads and returns a payload for a single path.
   202  func (mt *MTrie) ReadSinglePayload(path ledger.Path) *ledger.Payload {
   203  	return readSinglePayload(path, mt.root)
   204  }
   206  // readSinglePayload reads and returns a payload for a single path in subtree with `head` as root node.
   207  func readSinglePayload(path ledger.Path, head *node.Node) *ledger.Payload {
   208  	pathBytes := path[:]
   210  	if head == nil {
   211  		return ledger.EmptyPayload()
   212  	}
   214  	depth := ledger.NodeMaxHeight - head.Height() // distance to the tree root
   216  	// Traverse nodes following the path until a leaf node or nil node is reached.
   217  	for !head.IsLeaf() {
   218  		bit := bitutils.ReadBit(pathBytes, depth)
   219  		if bit == 0 {
   220  			head = head.LeftChild()
   221  		} else {
   222  			head = head.RightChild()
   223  		}
   224  		depth++
   225  	}
   227  	if head != nil && *head.Path() == path {
   228  		return head.Payload()
   229  	}
   231  	return ledger.EmptyPayload()
   232  }
   234  // UnsafeRead reads payloads for the given paths.
   235  // UNSAFE: requires _all_ paths to have a length of mt.Height bits.
   236  // CAUTION: while reading the payloads, `paths` is permuted IN-PLACE for optimized processing.
   237  // Return:
   238  //   - `payloads` []*ledger.Payload
   239  //     For each path, the corresponding payload is written into payloads. AFTER
   240  //     the read operation completes, the order of `path` and `payloads` are such that
   241  //     for `path[i]` the corresponding register value is referenced by 0`payloads[i]`.
   242  //
   243  // TODO move consistency checks from Forest into Trie to obtain a safe, self-contained API
   244  func (mt *MTrie) UnsafeRead(paths []ledger.Path) []*ledger.Payload {
   245  	payloads := make([]*ledger.Payload, len(paths)) // pre-allocate slice for the result
   246  	read(payloads, paths, mt.root)
   247  	return payloads
   248  }
// read reads all the registers at `paths` in the subtree with `head` as root node. For each
// `paths[i]`, the corresponding payload is written into `payloads[i]` for the same index `i`.
// Paths without a matching leaf in the subtree receive ledger.EmptyPayload().
// CAUTION:
//   - while reading the payloads, `paths` is permuted IN-PLACE for optimized processing.
//   - unchecked requirement: all paths must go through the `head` node
//   - len(payloads) must equal len(paths)
func read(payloads []*ledger.Payload, paths []ledger.Path, head *node.Node) {
	// check for empty paths
	if len(paths) == 0 {
		return
	}

	// path not found
	if head == nil {
		for i := range paths {
			payloads[i] = ledger.EmptyPayload()
		}
		return
	}

	// reached a leaf node
	if head.IsLeaf() {
		for i, p := range paths {
			if *head.Path() == p {
				payloads[i] = head.Payload()
			} else {
				payloads[i] = ledger.EmptyPayload()
			}
		}
		return
	}

	// reached an interim node
	if len(paths) == 1 {
		// call readSinglePayload to skip partition and recursive calls when there is only one path
		payloads[0] = readSinglePayload(paths[0], head)
		return
	}

	// partition step to quick sort the paths:
	// lpaths contains all paths that have `0` at the partitionIndex
	// rpaths contains all paths that have `1` at the partitionIndex
	depth := ledger.NodeMaxHeight - head.Height() // distance to the tree root
	partitionIndex := SplitPaths(paths, depth)
	lpaths, rpaths := paths[:partitionIndex], paths[partitionIndex:]
	lpayloads, rpayloads := payloads[:partitionIndex], payloads[partitionIndex:]

	// read values from left and right subtrees in parallel
	parallelRecursionThreshold := 32 // threshold to avoid the parallelization going too deep in the recursion
	if len(lpaths) < parallelRecursionThreshold || len(rpaths) < parallelRecursionThreshold {
		read(lpayloads, lpaths, head.LeftChild())
		read(rpayloads, rpaths, head.RightChild())
	} else {
		// concurrent read of left and right subtree; the two halves write to
		// disjoint sub-slices of `payloads`, so no further synchronization is needed
		wg := sync.WaitGroup{}
		wg.Add(1)
		go func() {
			read(lpayloads, lpaths, head.LeftChild())
			wg.Done()
		}()
		read(rpayloads, rpaths, head.RightChild())
		wg.Wait() // wait for all threads
	}
}
   314  // NewTrieWithUpdatedRegisters constructs a new trie containing all registers from the parent trie,
   315  // and returns:
   316  //   - updated trie
   317  //   - max depth touched during update (this isn't affected by prune flag)
   318  //   - error
   319  //
   320  // The key-value pairs specify the registers whose values are supposed to hold updated values
   321  // compared to the parent trie. Constructing the new trie is done in a COPY-ON-WRITE manner:
   322  //   - The original trie remains unchanged.
   323  //   - subtries that remain unchanged are from the parent trie instead of copied.
   324  //
   325  // UNSAFE: method requires the following conditions to be satisfied:
   326  //   - keys are NOT duplicated
   327  //   - requires _all_ paths to have a length of mt.Height bits.
   328  //
   329  // CAUTION: `updatedPaths` and `updatedPayloads` are permuted IN-PLACE for optimized processing.
   330  // CAUTION: MTrie expects that for a specific path, the payload's key never changes.
   331  // TODO: move consistency checks from MForest to here, to make API safe and self-contained
   332  func NewTrieWithUpdatedRegisters(
   333  	parentTrie *MTrie,
   334  	updatedPaths []ledger.Path,
   335  	updatedPayloads []ledger.Payload,
   336  	prune bool,
   337  ) (*MTrie, uint16, error) {
   338  	updatedRoot, regCountDelta, regSizeDelta, lowestHeightTouched := update(
   339  		ledger.NodeMaxHeight,
   340  		parentTrie.root,
   341  		updatedPaths,
   342  		updatedPayloads,
   343  		nil,
   344  		prune,
   345  	)
   347  	updatedTrieRegCount := int64(parentTrie.AllocatedRegCount()) + regCountDelta
   348  	updatedTrieRegSize := int64(parentTrie.AllocatedRegSize()) + regSizeDelta
   349  	maxDepthTouched := uint16(ledger.NodeMaxHeight - lowestHeightTouched)
   351  	updatedTrie, err := NewMTrie(updatedRoot, uint64(updatedTrieRegCount), uint64(updatedTrieRegSize))
   352  	if err != nil {
   353  		return nil, 0, fmt.Errorf("constructing updated trie failed: %w", err)
   354  	}
   355  	return updatedTrie, maxDepthTouched, nil
   356  }
// updateResult is a wrapper of the return values from update().
// It's used to communicate those values from the goroutine that updates the
// left subtree back to its parent call.
// NOTE: update() builds this struct with an unkeyed literal, so the field order
// must match update()'s return order.
type updateResult struct {
	child                  *node.Node // updated (or unchanged) subtree root
	allocatedRegCountDelta int64      // change in allocated register count in the subtree
	allocatedRegSizeDelta  int64      // change in allocated register size in the subtree
	lowestHeightTouched    int        // lowest height reached while updating the subtree
}
// update traverses the subtree rooted at currentNode recursively and creates new nodes
// with the updated payloads on the given paths. Existing nodes are never modified.
// Unaffected subtrees are returned as-is so they can be shared with the original trie.
//
// It returns:
//   - new updated node or original node if nothing was updated
//   - allocated register count delta in subtrie (allocatedRegCountDelta)
//   - allocated register size delta in subtrie (allocatedRegSizeDelta)
//   - lowest height reached during recursive update in subtrie (lowestHeightTouched)
//
// A single new path written into an empty subtree is stored as one compact leaf at the
// current height rather than as a full chain of interim nodes. When `prune` is set,
// interim nodes are built with node.NewInterimCompactifiedNode, which (per the comment in
// the prune branch below) collapses a node with one nil child and one leaf child.
//
// allocatedRegCountDelta and allocatedRegSizeDelta are used to compute the updated
// trie's allocated register count and size. lowestHeightTouched is used to
// compute the max depth touched during the update.
// CAUTION: while updating, `paths` and `payloads` are permuted IN-PLACE for optimized processing.
// UNSAFE: method requires the following conditions to be satisfied:
//   - all paths share the same common prefix, the first (ledger.NodeMaxHeight - nodeHeight)
//     bits, i.e. they all go through this node. The bit at that index selects the child.
//   - paths are NOT duplicated
//   - len(payloads) == len(paths), with payloads[i] belonging to paths[i]
func update(
	nodeHeight int, // the height of the node during traversing the subtree
	currentNode *node.Node, // the current node on the traversing path; if it's nil, the trie has no node on this path
	paths []ledger.Path, // the paths to update the payloads
	payloads []ledger.Payload, // the payloads to be updated at the given paths
	compactLeaf *node.Node, // a compact leaf node from its ancestor, it could be nil
	prune bool, // prune is a flag for whether to prune nodes with empty payload. Not pruning is useful for generating proofs, especially non-inclusion proofs
) (n *node.Node, allocatedRegCountDelta int64, allocatedRegSizeDelta int64, lowestHeightTouched int) {
	// No new path to update
	if len(paths) == 0 {
		if compactLeaf != nil {
			// if a compactLeaf from a higher height is still left,
			// then expand the compact leaf node to the current height by creating a new compact leaf
			// node with the same path and payload.
			// The old node shouldn't be recycled as it is still used by the tree copy before the update.
			n = node.NewLeaf(*compactLeaf.Path(), compactLeaf.Payload(), nodeHeight)
			return n, 0, 0, nodeHeight
		}
		// if no path to update and there is no compact leaf node on this path, we return
		// the current node regardless it exists or not.
		return currentNode, 0, 0, nodeHeight
	}

	if len(paths) == 1 && currentNode == nil && compactLeaf == nil {
		// if there is only 1 path to update, and the existing tree has no node on this path, also
		// no compact leaf node from its ancestor, it means we are storing a payload on a new path,
		n = node.NewLeaf(paths[0], payloads[0].DeepCopy(), nodeHeight)
		if payloads[0].IsEmpty() {
			// if we are storing an empty node, then no register is allocated
			// allocatedRegCountDelta and allocatedRegSizeDelta should both be 0
			return n, 0, 0, nodeHeight
		}
		// if we are storing a non-empty node, we are allocating a new register
		return n, 1, int64(payloads[0].Size()), nodeHeight
	}

	if currentNode != nil && currentNode.IsLeaf() { // if we're here then compactLeaf == nil
		// check if the current node path is among the updated paths
		found := false
		currentPath := *currentNode.Path()
		for i, p := range paths {
			if p == currentPath {
				// the case where the recursion stops: only one path to update
				if len(paths) == 1 {
					// check if the only path to update has the same payload.
					// if payload is the same, we could skip the update to avoid creating duplicated node
					if !currentNode.Payload().ValueEquals(&payloads[i]) {
						n = node.NewLeaf(paths[i], payloads[i].DeepCopy(), nodeHeight)

						allocatedRegCountDelta, allocatedRegSizeDelta =
							computeAllocatedRegDeltas(currentNode.Payload(), &payloads[i])

						return n, allocatedRegCountDelta, allocatedRegSizeDelta, nodeHeight
					}
					// avoid creating a new node when the same payload is written
					return currentNode, 0, 0, nodeHeight
				}
				// the case where the recursion carries on: len(paths)>1
				found = true

				// the old leaf's register is accounted as removed here; the recursion
				// below re-adds it (with its new payload) at a lower height
				allocatedRegCountDelta, allocatedRegSizeDelta =
					computeAllocatedRegDeltasFromHigherHeight(currentNode.Payload())

				break
			}
		}
		if !found {
			// if the current node carries a path not included in the input path, then the current node
			// represents a compact leaf that needs to be carried down the recursion.
			compactLeaf = currentNode
		}
	}

	// in the remaining code:
	//   - either len(paths) > 1
	//   - or len(paths) == 1 and compactLeaf!= nil
	//   - or len(paths) == 1 and currentNode != nil && !currentNode.IsLeaf()

	// Split paths and payloads to recurse:
	// lpaths contains all paths that have `0` at the partitionIndex
	// rpaths contains all paths that have `1` at the partitionIndex
	depth := ledger.NodeMaxHeight - nodeHeight // distance to the tree root
	partitionIndex := splitByPath(paths, payloads, depth)
	lpaths, rpaths := paths[:partitionIndex], paths[partitionIndex:]
	lpayloads, rpayloads := payloads[:partitionIndex], payloads[partitionIndex:]

	// check if there is a compact leaf that needs to get deep to height 0
	var lcompactLeaf, rcompactLeaf *node.Node
	if compactLeaf != nil {
		// if yes, check which branch it will go to.
		path := *compactLeaf.Path()
		if bitutils.ReadBit(path[:], depth) == 0 {
			lcompactLeaf = compactLeaf
		} else {
			rcompactLeaf = compactLeaf
		}
	}

	// set the node children
	var oldLeftChild, oldRightChild *node.Node
	if currentNode != nil {
		oldLeftChild = currentNode.LeftChild()
		oldRightChild = currentNode.RightChild()
	}

	// recurse over each branch
	var newLeftChild, newRightChild *node.Node
	var lRegCountDelta, rRegCountDelta int64
	var lRegSizeDelta, rRegSizeDelta int64
	var lLowestHeightTouched, rLowestHeightTouched int
	parallelRecursionThreshold := 16
	if len(lpaths) < parallelRecursionThreshold || len(rpaths) < parallelRecursionThreshold {
		// runtime optimization: if either sub-tree has only a few updates, proceed single-threaded
		newLeftChild, lRegCountDelta, lRegSizeDelta, lLowestHeightTouched = update(nodeHeight-1, oldLeftChild, lpaths, lpayloads, lcompactLeaf, prune)
		newRightChild, rRegCountDelta, rRegSizeDelta, rLowestHeightTouched = update(nodeHeight-1, oldRightChild, rpaths, rpayloads, rcompactLeaf, prune)
	} else {
		// runtime optimization: process the left child in a separate thread

		// Since we're receiving 4 values from goroutine, use a
		// struct and channel to reduce allocs/op.
		// Although WaitGroup approach can be faster than channel (esp. with 2+ goroutines),
		// we only use 1 goroutine here and need to communicate results from it. So using
		// channel is faster and uses fewer allocs/op in this case.
		results := make(chan updateResult, 1)
		go func(retChan chan<- updateResult) {
			child, regCountDelta, regSizeDelta, lowestHeightTouched := update(nodeHeight-1, oldLeftChild, lpaths, lpayloads, lcompactLeaf, prune)
			retChan <- updateResult{child, regCountDelta, regSizeDelta, lowestHeightTouched}
		}(results)

		newRightChild, rRegCountDelta, rRegSizeDelta, rLowestHeightTouched = update(nodeHeight-1, oldRightChild, rpaths, rpayloads, rcompactLeaf, prune)

		// Wait for results from goroutine.
		ret := <-results
		newLeftChild, lRegCountDelta, lRegSizeDelta, lLowestHeightTouched = ret.child, ret.allocatedRegCountDelta, ret.allocatedRegSizeDelta, ret.lowestHeightTouched
	}

	// "+=" keeps any delta already set above for a leaf whose path was among the updates
	allocatedRegCountDelta += lRegCountDelta + rRegCountDelta
	allocatedRegSizeDelta += lRegSizeDelta + rRegSizeDelta
	lowestHeightTouched = minInt(lLowestHeightTouched, rLowestHeightTouched)

	// mitigate storage exhaustion attack: avoids creating a new node when the exact same
	// payload is re-written at a register. CAUTION: we only check that the children are
	// unchanged. This is only sufficient for interim nodes (for leaf nodes, the children
	// might be unchanged, i.e. both nil, but the payload could have changed).
	// In case the current node was a leaf, we _cannot reuse_ it, because we potentially
	// updated registers in the sub-trie
	// NOTE(review): currentNode can be nil here (e.g. several new paths under an empty
	// subtree); this relies on (*node.Node).IsLeaf returning true for a nil receiver —
	// confirm in package node.
	if !currentNode.IsLeaf() && newLeftChild == oldLeftChild && newRightChild == oldRightChild {
		return currentNode, 0, 0, lowestHeightTouched
	}

	// if prune is on, then will check and create a compact leaf node if one child is nil, and the
	// other child is a leaf node
	if prune {
		n = node.NewInterimCompactifiedNode(nodeHeight, newLeftChild, newRightChild)
		return n, allocatedRegCountDelta, allocatedRegSizeDelta, lowestHeightTouched
	}

	n = node.NewInterimNode(nodeHeight, newLeftChild, newRightChild)
	return n, allocatedRegCountDelta, allocatedRegSizeDelta, lowestHeightTouched
}
   548  // computeAllocatedRegDeltasFromHigherHeight returns the deltas
   549  // needed to compute the allocated reg count and reg size when
   550  // a payload is updated or unallocated at a lower height.
   551  func computeAllocatedRegDeltasFromHigherHeight(oldPayload *ledger.Payload) (allocatedRegCountDelta, allocatedRegSizeDelta int64) {
   552  	if !oldPayload.IsEmpty() {
   553  		// Allocated register will be updated or unallocated at lower height.
   554  		allocatedRegCountDelta--
   555  	}
   556  	oldPayloadSize := oldPayload.Size()
   557  	allocatedRegSizeDelta -= int64(oldPayloadSize)
   558  	return
   559  }
   561  // computeAllocatedRegDeltas returns the allocated reg count
   562  // and reg size deltas computed from old payload and new payload.
   563  // PRECONDITION: !oldPayload.Equals(newPayload)
   564  func computeAllocatedRegDeltas(oldPayload, newPayload *ledger.Payload) (allocatedRegCountDelta, allocatedRegSizeDelta int64) {
   565  	allocatedRegCountDelta = 0
   566  	if newPayload.IsEmpty() {
   567  		// Old payload is not empty while new payload is empty.
   568  		// Allocated register will be unallocated.
   569  		allocatedRegCountDelta = -1
   570  	} else if oldPayload.IsEmpty() {
   571  		// Old payload is empty while new payload is not empty.
   572  		// Unallocated register will be allocated.
   573  		allocatedRegCountDelta = 1
   574  	}
   576  	oldPayloadSize := oldPayload.Size()
   577  	newPayloadSize := newPayload.Size()
   578  	allocatedRegSizeDelta = int64(newPayloadSize - oldPayloadSize)
   579  	return
   580  }
   582  // UnsafeProofs provides proofs for the given paths.
   583  //
   584  // CAUTION: while updating, `paths` and `proofs` are permuted IN-PLACE for optimized processing.
   585  // UNSAFE: requires _all_ paths to have a length of mt.Height bits.
   586  // Paths in the input query don't have to be deduplicated, though deduplication would
   587  // result in allocating less dynamic memory to store the proofs.
   588  func (mt *MTrie) UnsafeProofs(paths []ledger.Path) *ledger.TrieBatchProof {
   589  	batchProofs := ledger.NewTrieBatchProofWithEmptyProofs(len(paths))
   590  	prove(mt.root, paths, batchProofs.Proofs)
   591  	return batchProofs
   592  }
// prove traverses the subtree rooted at `head` and fills in the proofs for the given
// register paths in the provided `proofs` slice (proofs[i] belongs to paths[i]).
// At every interim node passed, each proof's Steps counter is incremented and the
// sibling subtree's hash is recorded (if non-default). On reaching a leaf whose path
// matches, the proof is turned into an inclusion proof. Otherwise it stays a
// non-inclusion proof.
// CAUTION: while updating, `paths` and `proofs` are permuted IN-PLACE for optimized processing.
// UNSAFE: method requires the following conditions to be satisfied:
//   - all paths share the same common prefix, the first (ledger.NodeMaxHeight - head.Height())
//     bits, i.e. they all go through `head`
//   - len(proofs) == len(paths)
func prove(head *node.Node, paths []ledger.Path, proofs []*ledger.TrieProof) {
	// check for empty paths
	if len(paths) == 0 {
		return
	}

	// we've reached the end of a trie
	// and path is not found (noninclusion proof)
	if head == nil {
		// by default, proofs are non-inclusion proofs
		return
	}

	// we've reached a leaf
	if head.IsLeaf() {
		for i, path := range paths {
			// value matches (inclusion proof)
			if *head.Path() == path {
				proofs[i].Path = *head.Path()
				proofs[i].Payload = head.Payload()
				proofs[i].Inclusion = true
			}
		}
		// by default, proofs are non-inclusion proofs
		return
	}

	// increment steps for all the proofs
	for _, p := range proofs {
		p.Steps++
	}

	// partition step to quick sort the paths:
	// lpaths contains all paths that have `0` at the partitionIndex
	// rpaths contains all paths that have `1` at the partitionIndex
	depth := ledger.NodeMaxHeight - head.Height() // distance to the tree root
	partitionIndex := splitTrieProofsByPath(paths, proofs, depth)
	lpaths, rpaths := paths[:partitionIndex], paths[partitionIndex:]
	lproofs, rproofs := proofs[:partitionIndex], proofs[partitionIndex:]

	// The sibling hash at this depth is added BEFORE recursing, so each proof's
	// Interims are appended in root-to-leaf order.
	parallelRecursionThreshold := 64 // threshold to avoid the parallelization going too deep in the recursion
	if len(lpaths) < parallelRecursionThreshold || len(rpaths) < parallelRecursionThreshold {
		// runtime optimization: below the parallelRecursionThreshold, we proceed single-threaded
		addSiblingTrieHashToProofs(head.RightChild(), depth, lproofs)
		prove(head.LeftChild(), lpaths, lproofs)

		addSiblingTrieHashToProofs(head.LeftChild(), depth, rproofs)
		prove(head.RightChild(), rpaths, rproofs)
	} else {
		// left and right halves operate on disjoint sub-slices of proofs/paths
		wg := sync.WaitGroup{}
		wg.Add(1)
		go func() {
			addSiblingTrieHashToProofs(head.RightChild(), depth, lproofs)
			prove(head.LeftChild(), lpaths, lproofs)
			wg.Done()
		}()

		addSiblingTrieHashToProofs(head.LeftChild(), depth, rproofs)
		prove(head.RightChild(), rpaths, rproofs)
		wg.Wait()
	}
}
   663  // addSiblingTrieHashToProofs inspects the sibling Trie and adds its root hash
   664  // to the proofs, if the trie contains non-empty registers (i.e. the
   665  // siblingTrie has a non-default hash).
   666  func addSiblingTrieHashToProofs(siblingTrie *node.Node, depth int, proofs []*ledger.TrieProof) {
   667  	if siblingTrie == nil || len(proofs) == 0 {
   668  		return
   669  	}
   671  	// This code is necessary, because we do not remove nodes from the trie
   672  	// when a register is deleted. Instead, we just set the respective leaf's
   673  	// payload to empty. While this will cause the lead's hash to become the
   674  	// default hash, the node itself remains as part of the trie.
   675  	// However, a proof has the convention that the hash of the sibling trie
   676  	// should only be included, if it is _non-default_. Therefore, we can
   677  	// neither use `siblingTrie == nil` nor `siblingTrie.RegisterCount == 0`,
   678  	// as the sibling trie might contain leaves with default value (which are
   679  	// still counted as occupied registers)
   680  	// TODO: On update, prune subtries which only contain empty registers.
   681  	//       Then, a child is nil if and only if the subtrie is empty.
   683  	nodeHash := siblingTrie.Hash()
   684  	isDef := nodeHash == ledger.GetDefaultHashForHeight(siblingTrie.Height())
   685  	if !isDef { // in proofs, we only provide non-default value hashes
   686  		for _, p := range proofs {
   687  			bitutils.SetBit(p.Flags, depth)
   688  			p.Interims = append(p.Interims, nodeHash)
   689  		}
   690  	}
   691  }
   693  // Equals compares two tries for equality.
   694  // Tries are equal iff they store the same data (i.e. root hash matches)
   695  // and their number and height are identical
   696  func (mt *MTrie) Equals(o *MTrie) bool {
   697  	if o == nil {
   698  		return false
   699  	}
   700  	return o.RootHash() == mt.RootHash()
   701  }
   703  // DumpAsJSON dumps the trie key value pairs to a file having each key value pair as a json row
   704  func (mt *MTrie) DumpAsJSON(w io.Writer) error {
   706  	// Use encoder to prevent building entire trie in memory
   707  	enc := json.NewEncoder(w)
   709  	err := dumpAsJSON(mt.root, enc)
   710  	if err != nil {
   711  		return err
   712  	}
   714  	return nil
   715  }
   717  // dumpAsJSON serializes the sub-trie with root n to json and feeds it into encoder
   718  func dumpAsJSON(n *node.Node, encoder *json.Encoder) error {
   719  	if n.IsLeaf() {
   720  		if n != nil {
   721  			err := encoder.Encode(n.Payload())
   722  			if err != nil {
   723  				return err
   724  			}
   725  		}
   726  		return nil
   727  	}
   729  	if lChild := n.LeftChild(); lChild != nil {
   730  		err := dumpAsJSON(lChild, encoder)
   731  		if err != nil {
   732  			return err
   733  		}
   734  	}
   736  	if rChild := n.RightChild(); rChild != nil {
   737  		err := dumpAsJSON(rChild, encoder)
   738  		if err != nil {
   739  			return err
   740  		}
   741  	}
   742  	return nil
   743  }
   745  // EmptyTrieRootHash returns the rootHash of an empty Trie for the specified path size [bytes]
   746  func EmptyTrieRootHash() ledger.RootHash {
   747  	return ledger.RootHash(ledger.GetDefaultHashForHeight(ledger.NodeMaxHeight))
   748  }
   750  // AllPayloads returns all payloads
   751  func (mt *MTrie) AllPayloads() []*ledger.Payload {
   752  	return mt.root.AllPayloads()
   753  }
   755  // IsAValidTrie verifies the content of the trie for potential issues
   756  func (mt *MTrie) IsAValidTrie() bool {
   757  	// TODO add checks on the health of node max height ...
   758  	return mt.root.VerifyCachedHash()
   759  }
   761  // splitByPath permutes the input paths to be partitioned into 2 parts. The first part contains paths with a zero bit
   762  // at the input bitIndex, the second part contains paths with a one at the bitIndex. The index of partition
   763  // is returned. The same permutation is applied to the payloads slice.
   764  //
   765  // This would be the partition step of an ascending quick sort of paths (lexicographic order)
   766  // with the pivot being the path with all zeros and 1 at bitIndex.
   767  // The comparison of paths is only based on the bit at bitIndex, the function therefore assumes all paths have
   768  // equal bits from 0 to bitIndex-1
   769  //
   770  //	For instance, if `paths` contains the following 3 paths, and bitIndex is `1`:
   771  //	[[0,0,1,1], [0,1,0,1], [0,0,0,1]]
   772  //	then `splitByPath` returns 2 and updates `paths` into:
   773  //	[[0,0,1,1], [0,0,0,1], [0,1,0,1]]
   774  func splitByPath(paths []ledger.Path, payloads []ledger.Payload, bitIndex int) int {
   775  	i := 0
   776  	for j, path := range paths {
   777  		bit := bitutils.ReadBit(path[:], bitIndex)
   778  		if bit == 0 {
   779  			paths[i], paths[j] = paths[j], paths[i]
   780  			payloads[i], payloads[j] = payloads[j], payloads[i]
   781  			i++
   782  		}
   783  	}
   784  	return i
   785  }
   787  // SplitPaths permutes the input paths to be partitioned into 2 parts. The first part contains paths with a zero bit
   788  // at the input bitIndex, the second part contains paths with a one at the bitIndex. The index of partition
   789  // is returned.
   790  //
   791  // This would be the partition step of an ascending quick sort of paths (lexicographic order)
   792  // with the pivot being the path with all zeros and 1 at bitIndex.
   793  // The comparison of paths is only based on the bit at bitIndex, the function therefore assumes all paths have
   794  // equal bits from 0 to bitIndex-1
   795  func SplitPaths(paths []ledger.Path, bitIndex int) int {
   796  	i := 0
   797  	for j, path := range paths {
   798  		bit := bitutils.ReadBit(path[:], bitIndex)
   799  		if bit == 0 {
   800  			paths[i], paths[j] = paths[j], paths[i]
   801  			i++
   802  		}
   803  	}
   804  	return i
   805  }
   807  // splitTrieProofsByPath permutes the input paths to be partitioned into 2 parts. The first part contains paths
   808  // with a zero bit at the input bitIndex, the second part contains paths with a one at the bitIndex. The index
   809  // of partition is returned. The same permutation is applied to the proofs slice.
   810  //
   811  // This would be the partition step of an ascending quick sort of paths (lexicographic order)
   812  // with the pivot being the path with all zeros and 1 at bitIndex.
   813  // The comparison of paths is only based on the bit at bitIndex, the function therefore assumes all paths have
   814  // equal bits from 0 to bitIndex-1
   815  func splitTrieProofsByPath(paths []ledger.Path, proofs []*ledger.TrieProof, bitIndex int) int {
   816  	i := 0
   817  	for j, path := range paths {
   818  		bit := bitutils.ReadBit(path[:], bitIndex)
   819  		if bit == 0 {
   820  			paths[i], paths[j] = paths[j], paths[i]
   821  			proofs[i], proofs[j] = proofs[j], proofs[i]
   822  			i++
   823  		}
   824  	}
   825  	return i
   826  }
   828  func minInt(a, b int) int {
   829  	if a < b {
   830  		return a
   831  	}
   832  	return b
   833  }
   835  // TraverseNodes traverses all nodes of the trie in DFS order
   836  func TraverseNodes(trie *MTrie, processNode func(*node.Node) error) error {
   837  	return traverseRecursive(trie.root, processNode)
   838  }
   840  func traverseRecursive(n *node.Node, processNode func(*node.Node) error) error {
   841  	if n == nil {
   842  		return nil
   843  	}
   845  	err := processNode(n)
   846  	if err != nil {
   847  		return err
   848  	}
   850  	err = traverseRecursive(n.LeftChild(), processNode)
   851  	if err != nil {
   852  		return err
   853  	}
   855  	err = traverseRecursive(n.RightChild(), processNode)
   856  	if err != nil {
   857  		return err
   858  	}
   860  	return nil
   861  }