github.com/letsencrypt/trillian@v1.1.2-0.20180615153820-ae375a99d36a/merkle/memory_merkle_tree.go (about)

     1  // Copyright 2016 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package merkle
    16  
    17  // This is a fairly direct port of the C++ Merkle Tree. It has the same API and should have
    18  // similar performance. It keeps all its data in RAM and is not part of the Trillian API.
    19  //
// Note: this implementation evaluates the root lazily in the same way as the C++ code, so
// some methods that appear to be accessors can mutate the structure to bring it up to the
// point required to obtain the result.
    23  //
    24  // -------------------------------------------------------------------------------------------
    25  // IMPORTANT NOTE: This code uses 1-based leaf indexing as this is how the original C++
    26  // works. There is scope for confusion if it is mixed with the Trillian specific trees in
    27  // this package, which index leaves starting from zero. This code is primarily meant for use in
    28  // cross checks of the new implementation and it is advantageous to be able to compare it
    29  // directly with the C++ code.
    30  // -------------------------------------------------------------------------------------------
    31  
    32  import (
    33  	"errors"
    34  	"fmt"
    35  
    36  	"github.com/google/trillian/merkle/hashers"
    37  )
    38  
    39  // TreeEntry is used for nodes in the tree for better readability. Just holds a hash but could be extended
    40  type TreeEntry struct {
    41  	hash []byte
    42  }
    43  
    44  // Hash returns the current hash in a newly created byte slice that the caller owns and may modify.
    45  func (t TreeEntry) Hash() []byte {
    46  	var newSlice []byte
    47  
    48  	return t.HashInto(newSlice)
    49  }
    50  
    51  // HashInto returns the current hash in a provided byte slice that the caller
    52  // may use to make multiple calls to obtain hashes without reallocating memory.
    53  func (t TreeEntry) HashInto(dest []byte) []byte {
    54  	dest = dest[:0] // reuse the existing space
    55  
    56  	dest = append(dest, t.hash...)
    57  	return dest
    58  }
    59  
    60  // TreeEntryDescriptor wraps a node and is used to describe tree paths, which are useful to have
    61  // access to when testing the code and examining how it works
    62  type TreeEntryDescriptor struct {
    63  	Value  TreeEntry
    64  	XCoord int64 // The horizontal node coordinate
    65  	YCoord int64 // The vertical node coordinate
    66  }
    67  
    68  // InMemoryMerkleTree holds a Merkle Tree in memory as a 2D node array
    69  type InMemoryMerkleTree struct {
    70  	// A container for nodes, organized according to levels and sorted
    71  	// left-to-right in each level. tree_[0] is the leaf level, etc.
    72  	// The hash of nodes tree_[i][j] and tree_[i][j+1] (j even) is stored
    73  	// at tree_[i+1][j/2]. When tree_[i][j] is the last node of the level with
    74  	// no right sibling, we store its dummy copy: tree_[i+1][j/2] = tree_[i][j].
    75  	//
    76  	// For example, a tree with 5 leaf hashes a0, a1, a2, a3, a4
    77  	//
    78  	//        __ hash__
    79  	//       |         |
    80  	//    __ h20__     a4
    81  	//   |        |
    82  	//  h10     h11
    83  	//  | |     | |
    84  	// a0 a1   a2 a3
    85  	//
    86  	// is internally represented, top-down
    87  	//
    88  	// --------
    89  	// | hash |                        tree_[3]
    90  	// --------------
    91  	// | h20  | a4  |                  tree_[2]
    92  	// -------------------
    93  	// | h10  | h11 | a4 |             tree_[1]
    94  	// -----------------------------
    95  	// | a0   | a1  | a2 | a3 | a4 |   tree_[0]
    96  	// -----------------------------
    97  	//
    98  	// Since the tree is append-only from the right, at any given point in time,
    99  	// at each level, all nodes computed so far, except possibly the last node,
   100  	// are fixed and will no longer change.
   101  	tree            [][]TreeEntry
   102  	leavesProcessed int64
   103  	levelCount      int64
   104  	hasher          hashers.LogHasher
   105  }
   106  
   107  // isPowerOfTwoPlusOne tests whether a number is (2^x)-1 for some x. From MerkleTreeMath in C++
   108  func isPowerOfTwoPlusOne(leafCount int64) bool {
   109  	if leafCount == 0 {
   110  		return false
   111  	}
   112  
   113  	if leafCount == 1 {
   114  		return true
   115  	}
   116  	// leaf_count is a power of two plus one if and only if
   117  	// ((leaf_count - 1) & (leaf_count - 2)) has no bits set.
   118  	return (((leafCount - 1) & (leafCount - 2)) == 0)
   119  }
   120  
   121  // sibling returns the index of the node's (left or right) sibling in the same level.
   122  func sibling(leaf int64) int64 {
   123  	if isRightChild(leaf) {
   124  		return leaf - 1
   125  	}
   126  	return leaf + 1
   127  }
   128  
   129  // NewInMemoryMerkleTree creates a new empty Merkle Tree using the specified Hasher
   130  func NewInMemoryMerkleTree(hasher hashers.LogHasher) *InMemoryMerkleTree {
   131  	mt := InMemoryMerkleTree{}
   132  
   133  	mt.hasher = hasher
   134  	mt.levelCount = 0
   135  	mt.leavesProcessed = 0
   136  
   137  	return &mt
   138  }
   139  
   140  // LeafHash returns the hash of the requested leaf.
   141  func (mt *InMemoryMerkleTree) LeafHash(leaf int64) []byte {
   142  	if leaf == 0 || leaf > mt.LeafCount() {
   143  		return nil
   144  	}
   145  
   146  	return mt.tree[0][leaf-1].hash
   147  }
   148  
   149  // NodeCount gets the current node count (of the lazily evaluated tree).
   150  // Caller is responsible for keeping track of the lazy evaluation status. This will not
   151  // update the tree.
   152  func (mt *InMemoryMerkleTree) NodeCount(level int64) int64 {
   153  	if mt.lazyLevelCount() <= level {
   154  		panic(fmt.Errorf("lazyLevelCount <= level in nodeCount: %d", mt.lazyLevelCount()))
   155  	}
   156  
   157  	return int64(len(mt.tree[level]))
   158  }
   159  
   160  // LevelCount returns the number of levels in the current Merkle tree
   161  func (mt *InMemoryMerkleTree) LevelCount() int64 {
   162  	return mt.levelCount
   163  }
   164  
   165  // lazyLevelCount is the current level count of the lazily evaluated tree.
   166  func (mt *InMemoryMerkleTree) lazyLevelCount() int64 {
   167  	return int64(len(mt.tree))
   168  }
   169  
   170  // LeafCount returns the number of leaves in the tree.
   171  func (mt *InMemoryMerkleTree) LeafCount() int64 {
   172  	if len(mt.tree) == 0 {
   173  		return 0
   174  	}
   175  	return mt.NodeCount(0)
   176  }
   177  
   178  // root gets the current root (of the lazily evaluated tree).
   179  // Caller is responsible for keeping track of the lazy evaluation status.
   180  func (mt *InMemoryMerkleTree) root() TreeEntry {
   181  	lastLevel := len(mt.tree) - 1
   182  
   183  	if len(mt.tree[lastLevel]) > 1 {
   184  		panic(fmt.Errorf("found multiple nodes in root: %d", len(mt.tree[lastLevel])))
   185  	}
   186  
   187  	return mt.tree[lastLevel][0]
   188  }
   189  
   190  // lastNode returns the last node of the given level in the tree.
   191  func (mt *InMemoryMerkleTree) lastNode(level int64) TreeEntry {
   192  	levelNodes := mt.NodeCount(level)
   193  
   194  	if levelNodes < 1 {
   195  		panic(fmt.Errorf("no nodes at level %d in lastNode", level))
   196  	}
   197  
   198  	return mt.tree[level][levelNodes-1]
   199  }
   200  
   201  // addLevel start a new tree level.
   202  func (mt *InMemoryMerkleTree) addLevel() {
   203  	mt.tree = append(mt.tree, []TreeEntry{})
   204  }
   205  
   206  // pushBack appends a node to the level.
   207  func (mt *InMemoryMerkleTree) pushBack(level int64, treeEntry TreeEntry) {
   208  	if mt.lazyLevelCount() <= level {
   209  		panic(fmt.Errorf("lazyLevelCount <= level in pushBack: %d", mt.lazyLevelCount()))
   210  	}
   211  
   212  	mt.tree[level] = append(mt.tree[level], treeEntry)
   213  }
   214  
   215  // popBack pops (removes and returns) the last node of the level.
   216  func (mt *InMemoryMerkleTree) popBack(level int64) {
   217  	if len(mt.tree[level]) < 1 {
   218  		panic(errors.New("no nodes to pop in popBack"))
   219  	}
   220  
   221  	mt.tree[level] = mt.tree[level][:len(mt.tree[level])-1]
   222  }
   223  
   224  // AddLeaf adds a new leaf to the hash tree. Stores the hash of the leaf data in the
   225  // tree structure, does not store the data itself.
   226  //
   227  // (We will evaluate the tree lazily, and not update the root here.)
   228  //
   229  // Returns the position of the leaf in the tree. Indexing starts at 1,
   230  // so position = number of leaves in the tree after this update.
   231  func (mt *InMemoryMerkleTree) AddLeaf(leafData []byte) (int64, TreeEntry, error) {
   232  	leafHash, err := mt.hasher.HashLeaf(leafData)
   233  	if err != nil {
   234  		return 0, TreeEntry{}, err
   235  	}
   236  	leafCount, treeEntry := mt.addLeafHash(leafHash)
   237  	return leafCount, treeEntry, nil
   238  }
   239  
   240  func (mt *InMemoryMerkleTree) addLeafHash(leafData []byte) (int64, TreeEntry) {
   241  	treeEntry := TreeEntry{}
   242  	treeEntry.hash = leafData
   243  
   244  	if mt.lazyLevelCount() == 0 {
   245  		// The first leaf hash is also the first root.
   246  		mt.addLevel()
   247  		mt.leavesProcessed = 1
   248  	}
   249  
   250  	mt.pushBack(0, treeEntry)
   251  	leafCount := mt.LeafCount()
   252  
   253  	// Update level count: a k-level tree can hold 2^{k-1} leaves,
   254  	// so increment level count every time we overflow a power of two.
   255  	// Do not update the root; we evaluate the tree lazily.
   256  	if isPowerOfTwoPlusOne(leafCount) {
   257  		mt.levelCount++
   258  	}
   259  
   260  	return leafCount, treeEntry
   261  }
   262  
   263  // CurrentRoot set the current root of the tree.
   264  // Updates the root to reflect the current shape of the tree and returns the tree digest.
   265  //
   266  // Returns the hash of an empty string if the tree has no leaves
   267  // (and hence, no root).
   268  func (mt *InMemoryMerkleTree) CurrentRoot() TreeEntry {
   269  	return mt.RootAtSnapshot(mt.LeafCount())
   270  }
   271  
   272  // RootAtSnapshot gets the root of the tree for a previous snapshot,
   273  // where snapshot 0 is an empty tree, snapshot 1 is the tree with
   274  // 1 leaf, etc.
   275  //
   276  // Returns an empty string if the snapshot requested is in the future
   277  // (i.e., the tree is not large enough).
   278  func (mt *InMemoryMerkleTree) RootAtSnapshot(snapshot int64) TreeEntry {
   279  	if snapshot == 0 {
   280  		return TreeEntry{mt.hasher.EmptyRoot()}
   281  	}
   282  
   283  	// Snapshot index bigger than tree, this is not the TreeEntry you're looking for
   284  	if snapshot > mt.LeafCount() {
   285  		return TreeEntry{nil}
   286  	}
   287  
   288  	if snapshot >= mt.leavesProcessed {
   289  		return mt.updateToSnapshot(snapshot)
   290  	}
   291  
   292  	// snapshot < leaves_processed_: recompute the snapshot root.
   293  	return mt.recomputePastSnapshot(snapshot, 0, nil)
   294  }
   295  
// updateToSnapshot brings the lazily evaluated tree up to the given snapshot
// (if necessary) and returns the root of the tree at that snapshot.
//
// Three cases are answered without evaluating anything:
//   - snapshot 0 yields the hasher's empty root;
//   - snapshot 1 yields the first leaf;
//   - snapshot == leavesProcessed yields the current root.
//
// Otherwise snapshot must lie in (leavesProcessed, LeafCount()]; anything
// outside that range panics. The upper levels are then extended to account
// for leaves [leavesProcessed, snapshot), and leavesProcessed is set to
// snapshot.
func (mt *InMemoryMerkleTree) updateToSnapshot(snapshot int64) TreeEntry {
	if snapshot == 0 {
		return TreeEntry{mt.hasher.EmptyRoot()}
	}

	if snapshot == 1 {
		// A one-leaf tree's root is the leaf itself.
		return mt.tree[0][0]
	}

	if snapshot == mt.leavesProcessed {
		// Already evaluated up to this snapshot.
		return mt.root()
	}

	if snapshot > mt.LeafCount() {
		panic(errors.New("snapshot size > leaf count in updateToSnapshot"))
	}

	if snapshot <= mt.leavesProcessed {
		panic(errors.New("snapshot size <= leavesProcessed in updateToSnapshot"))
	}

	// Update tree, moving up level-by-level.
	level := int64(0)
	// Index of the first node to be processed at the current level.
	firstNode := mt.leavesProcessed
	// Index of the last node.
	lastNode := snapshot - 1

	// Process level-by-level until we converge to a single node.
	// (firstNode, lastNode) = (0, 0) means we have reached the root level.
	for lastNode != 0 {
		if mt.lazyLevelCount() <= level+1 {
			// The parent level has not been materialized yet.
			mt.addLevel()
		} else if mt.NodeCount(level+1) == parent(firstNode)+1 {
			// The leftmost parent at level 'level+1' may already exist,
			// so we need to update it. Nuke the old parent.
			mt.popBack(level + 1)
		}

		// Compute the parents of new nodes at the current level.
		// Start with a left sibling and parse an even number of nodes.
		// (firstNode &^ 1 clears the low bit, so a right-child firstNode is
		// rehashed together with its left sibling.)
		for j := firstNode &^ 1; j < lastNode; j += 2 {
			mt.pushBack(level+1, TreeEntry{mt.hasher.HashChildren(mt.tree[level][j].hash, mt.tree[level][j+1].hash)})
		}

		// If the last node at the current level is a left sibling,
		// dummy-propagate it one level up.
		if !isRightChild(lastNode) {
			mt.pushBack(level+1, mt.tree[level][lastNode])
		}

		firstNode = parent(firstNode)
		lastNode = parent(lastNode)
		level++
	}

	mt.leavesProcessed = snapshot

	return mt.root()
}
   357  
// recomputePastSnapshot returns the root of the tree as it was at a snapshot
// no newer than leavesProcessed. It reads already evaluated nodes and never
// modifies the tree. It panics if snapshot > leavesProcessed. The callers
// visible here pass snapshot >= 1; snapshot 0 would start from leaf index -1.
//
// If node is not nil, the rightmost node at level nodeLevel of the snapshot
// tree is also written to *node, provided the walk reaches that level. When
// snapshot == leavesProcessed, *node is only written if nodeLevel is a
// currently materialized level.
func (mt *InMemoryMerkleTree) recomputePastSnapshot(snapshot int64, nodeLevel int64, node *TreeEntry) TreeEntry {
	level := int64(0)
	// Index of the rightmost node at the current level for this snapshot.
	lastNode := snapshot - 1

	if snapshot == mt.leavesProcessed {
		// Nothing to recompute.
		if node != nil && mt.lazyLevelCount() > nodeLevel {
			if nodeLevel > 0 {
				// Levels above the leaves only hold nodes for processed
				// leaves, so their last entry belongs to this snapshot.
				*node = mt.lastNode(nodeLevel)
			} else {
				// Leaf level: grab the last processed leaf.
				// (tree[0] may already hold unprocessed leaves beyond it.)
				*node = mt.tree[nodeLevel][lastNode]
			}
		}

		return mt.root()
	}

	// Equality was handled above, so this effectively rejects
	// snapshot > leavesProcessed.
	if snapshot >= mt.leavesProcessed {
		panic(errors.New("snapshot size >= leavesProcessed in recomputePastSnapshot"))
	}

	// Recompute nodes on the path of the last leaf.
	for isRightChild(lastNode) {
		if node != nil && nodeLevel == level {
			*node = mt.tree[level][lastNode]
		}

		// Left sibling and parent exist in the snapshot, and are equal to
		// those in the tree; no need to rehash, move one level up.
		lastNode = parent(lastNode)
		level++
	}

	// Now lastNode is the index of a left sibling with no right sibling.
	// Record the node.
	subtreeRoot := mt.tree[level][lastNode]

	if node != nil && nodeLevel == level {
		*node = subtreeRoot
	}

	for lastNode != 0 {
		if isRightChild(lastNode) {
			// Recompute the parent of tree[level][lastNode].
			subtreeRoot = TreeEntry{mt.hasher.HashChildren(mt.tree[level][lastNode-1].hash, subtreeRoot.hash)}
		}
		// Else the parent is a dummy copy of the current node; do nothing.

		lastNode = parent(lastNode)
		level++
		if node != nil && nodeLevel == level {
			*node = subtreeRoot
		}
	}

	return subtreeRoot
}
   419  
   420  // PathToCurrentRoot get the Merkle path from leaf to root for a given leaf.
   421  //
   422  // Returns a slice of node hashes, ordered by levels from leaf to root.
   423  // The first element is the sibling of the leaf hash, and the last element
   424  // is one below the root.
   425  // Returns an empty slice if the tree is not large enough
   426  // or the leaf index is 0.
   427  func (mt *InMemoryMerkleTree) PathToCurrentRoot(leaf int64) []TreeEntryDescriptor {
   428  	return mt.PathToRootAtSnapshot(leaf, mt.LeafCount())
   429  }
   430  
   431  // PathToRootAtSnapshot gets the Merkle path from a leaf to the root for a previous snapshot.
   432  //
   433  // Returns a slice of node hashes, ordered by levels from leaf to
   434  // root.  The first element is the sibling of the leaf hash, and the
   435  // last element is one below the root.  Returns an empty slice if
   436  // the leaf index is 0, the snapshot requested is in the future or
   437  // the snapshot tree is not large enough.
   438  func (mt *InMemoryMerkleTree) PathToRootAtSnapshot(leaf int64, snapshot int64) []TreeEntryDescriptor {
   439  	if leaf > snapshot || snapshot > mt.LeafCount() || leaf == 0 {
   440  		return []TreeEntryDescriptor{}
   441  	}
   442  
   443  	return mt.pathFromNodeToRootAtSnapshot(leaf-1, 0, snapshot)
   444  }
   445  
// pathFromNodeToRootAtSnapshot returns the audit path from a node to the root
// of the tree at the given snapshot. node is the node's index within its level
// and level is its height above the leaves (both indexed starting with 0).
// Each entry is the sibling of the current node at one level, ordered from the
// starting level upwards. Levels where the sibling does not exist in the
// snapshot tree contribute nothing.
//
// It returns nil if snapshot is 0, level >= LevelCount(), node lies beyond the
// last node of its level in the snapshot, or snapshot exceeds the number of
// leaves. If snapshot is newer than leavesProcessed, the lazily evaluated tree
// is brought up to date first.
//
// Entries read directly from the tree carry XCoord = level and
// YCoord = sibling index. A sibling that is the rightmost node of its level is
// obtained through recomputePastSnapshot and recorded with both values negated.
//
// NOTE(review): negative node or level arguments are not validated here;
// callers are expected to pass non-negative values.
func (mt *InMemoryMerkleTree) pathFromNodeToRootAtSnapshot(node int64, level int64, snapshot int64) []TreeEntryDescriptor {
	var path []TreeEntryDescriptor

	if snapshot == 0 {
		return path
	}

	// Index of the last node.
	// (The snapshot's rightmost leaf index, shifted up to this level.)
	lastNode := (snapshot - 1) >> uint64(level)

	if level >= mt.levelCount || node > lastNode || snapshot > mt.LeafCount() {
		return path
	}

	if snapshot > mt.leavesProcessed {
		// Bring the tree sufficiently up to date.
		mt.updateToSnapshot(snapshot)
	}

	// Move up, recording the sibling of the current node at each level.
	for lastNode != 0 {
		sibling := sibling(node)

		if sibling < lastNode {
			// The sibling is not the last node of the level in the snapshot
			// tree, so its value is correct in the tree.
			path = append(path, TreeEntryDescriptor{mt.tree[level][sibling], level, sibling})
		} else if sibling == lastNode {
			// The sibling is the last node of the level in the snapshot tree,
			// so we get its value for the snapshot. Get the root in the same pass.
			// (The returned root itself is discarded; only recomputeNode is used.)
			var recomputeNode TreeEntry

			mt.recomputePastSnapshot(snapshot, level, &recomputeNode)
			path = append(path, TreeEntryDescriptor{recomputeNode, -level, -sibling})
		}
		// Else sibling > lastNode so the sibling does not exist. Do nothing.
		// Continue moving up in the tree, ignoring dummy copies.
		node = parent(node)
		lastNode = parent(lastNode)
		level++
	}

	return path
}
   492  
   493  // SnapshotConsistency gets the Merkle consistency proof between two snapshots.
   494  // Returns a slice of node hashes, ordered according to levels.
   495  // Returns an empty slice if snapshot1 is 0, snapshot 1 >= snapshot2,
   496  // or one of the snapshots requested is in the future.
   497  func (mt *InMemoryMerkleTree) SnapshotConsistency(snapshot1 int64, snapshot2 int64) []TreeEntryDescriptor {
   498  	var proof []TreeEntryDescriptor
   499  
   500  	if snapshot1 == 0 || snapshot1 >= snapshot2 || snapshot2 > mt.LeafCount() {
   501  		return proof
   502  	}
   503  
   504  	level := int64(0)
   505  	// Rightmost node in snapshot1.
   506  	node := snapshot1 - 1
   507  
   508  	// Compute the (compressed) path to the root of snapshot2.
   509  	// Everything left of 'node' is equal in both trees; no need to record.
   510  	for isRightChild(node) {
   511  		node = parent(node)
   512  		level++
   513  	}
   514  
   515  	if snapshot2 > mt.leavesProcessed {
   516  		// Bring the tree sufficiently up to date.
   517  		mt.updateToSnapshot(snapshot2)
   518  	}
   519  
   520  	// Record the node, unless we already reached the root of snapshot1.
   521  	if node != 0 {
   522  		proof = append(proof, TreeEntryDescriptor{mt.tree[level][node], level, node})
   523  	}
   524  
   525  	// Now record the path from this node to the root of snapshot2.
   526  	path := mt.pathFromNodeToRootAtSnapshot(node, level, snapshot2)
   527  
   528  	return append(proof, path...)
   529  }