github.com/consensys/gnark-crypto@v0.14.0/accumulator/merkletree/tree.go (about)

     1  // Original Copyright (c) 2015 Nebulous
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  // The above copyright notice and this permission notice shall be included in all
    10  // copies or substantial portions of the Software.
    11  //
    12  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    13  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    14  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    15  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    16  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    17  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    18  // SOFTWARE.
    19  
// Package merkletree provides Merkle trees and inclusion proofs following RFC 6962.
    21  //
    22  // From https://gitlab.com/NebulousLabs/merkletree
    23  package merkletree
    24  
    25  import (
    26  	"errors"
    27  	"fmt"
    28  	"hash"
    29  )
    30  
    31  // A Tree takes data as leaves and returns the Merkle root. Each call to 'Push'
    32  // adds one leaf to the Merkle tree. Calling 'Root' returns the Merkle root.
    33  // The Tree also constructs proof that a single leaf is a part of the tree. The
    34  // leaf can be chosen with 'SetIndex'. The memory footprint of Tree grows in
    35  // O(log(n)) in the number of leaves.
    36  type Tree struct {
    37  	// The Tree is stored as a stack of subtrees. Each subtree has a height,
    38  	// and is the Merkle root of 2^height leaves. A Tree with 11 nodes is
    39  	// represented as a subtree of height 3 (8 nodes), a subtree of height 1 (2
    40  	// nodes), and a subtree of height 0 (1 node). Head points to the smallest
    41  	// tree. When a new leaf is inserted, it is inserted as a subtree of height
    42  	// 0. If there is another subtree of the same height, both can be removed,
    43  	// combined, and then inserted as a subtree of height n + 1.
    44  	head *subTree
    45  	hash hash.Hash
    46  
    47  	// Helper variables used to construct proofs that the data at 'proofIndex'
    48  	// is in the Merkle tree. The proofSet is constructed as elements are being
    49  	// added to the tree. The first element of the proof set is the original
    50  	// data used to create the leaf at index 'proofIndex'. proofTree indicates
    51  	// if the tree will be used to create a merkle proof.
    52  	currentIndex uint64
    53  	proofIndex   uint64
    54  	proofSet     [][]byte
    55  	proofTree    bool
    56  
    57  	// The cachedTree flag indicates that the tree is cached, meaning that
    58  	// different code is used in 'Push' for creating a new head subtree. Adding
    59  	// this flag is somewhat gross, but eliminates needing to duplicate the
    60  	// entire 'Push' function when writing the cached tree.
    61  	cachedTree bool
    62  }
    63  
    64  // A subTree contains the Merkle root of a complete (2^height leaves) subTree
    65  // of the Tree. 'sum' is the Merkle root of the subTree. If 'next' is not nil,
    66  // it will be a tree with a higher height.
    67  type subTree struct {
    68  	next   *subTree
    69  	height int // Int is okay because a height over 300 is physically unachievable.
    70  	sum    []byte
    71  }
    72  
    73  // sum returns the hash of the input data using the specified algorithm.
    74  func sum(h hash.Hash, data ...[]byte) []byte {
    75  
    76  	h.Reset()
    77  
    78  	for _, d := range data {
    79  		// the Hash interface specifies that Write never returns an error
    80  		_, err := h.Write(d)
    81  		if err != nil {
    82  			panic(err)
    83  		}
    84  	}
    85  	return h.Sum(nil)
    86  }
    87  
// leafSum returns the hash created from data inserted to form a leaf.
//
// NOTE: the leaf-prefixed form Hash(0x00 || data) is NOT what is computed
// here. The prefixed call is commented out below, so leaf sums are currently
// calculated using:
//
//	Hash(data)
func leafSum(h hash.Hash, data []byte) []byte {

	// Prefixed variant, kept for reference but not in use:
	//return sum(h, leafHashPrefix, data)
	return sum(h, data)
}
    97  
// nodeSum returns the hash created from two sibling nodes being combined into
// a parent node. a is the left sibling sum and b is the right sibling sum.
//
// NOTE: the node-prefixed form Hash(0x01 || left || right) is NOT what is
// computed here. The prefixed call is commented out below, so node sums are
// currently calculated using:
//
//	Hash(left sibling sum || right sibling sum)
func nodeSum(h hash.Hash, a, b []byte) []byte {
	// Prefixed variant, kept for reference but not in use:
	//return sum(h, nodeHashPrefix, a, b)
	return sum(h, a, b)
}
   106  
   107  // joinSubTrees combines two equal sized subTrees into a larger subTree.
   108  func joinSubTrees(h hash.Hash, a, b *subTree) *subTree {
   109  	// if DEBUG {
   110  	// 	if b.next != a {
   111  	// 		panic("invalid subtree join - 'a' is not paired with 'b'")
   112  	// 	}
   113  	// 	if a.height < b.height {
   114  	// 		panic("invalid subtree presented - height mismatch")
   115  	// 	}
   116  	// }
   117  
   118  	return &subTree{
   119  		next:   a.next,
   120  		height: a.height + 1,
   121  		sum:    nodeSum(h, a.sum, b.sum),
   122  	}
   123  }
   124  
   125  // New creates a new Tree. The provided hash will be used for all hashing
   126  // operations within the Tree.
   127  func New(h hash.Hash) *Tree {
   128  	return &Tree{
   129  		hash: h,
   130  	}
   131  }
   132  
   133  // Prove creates a proof that the leaf at the established index (established by
   134  // SetIndex) is an element of the Merkle tree. Prove will return a nil proof
   135  // set if used incorrectly. Prove does not modify the Tree. Prove can only be
   136  // called if SetIndex has been called previously.
   137  func (t *Tree) Prove() (merkleRoot []byte, proofSet [][]byte, proofIndex uint64, numLeaves uint64) {
   138  	if !t.proofTree {
   139  		panic("wrong usage: can't call prove on a tree if SetIndex wasn't called")
   140  	}
   141  
   142  	// Return nil if the Tree is empty, or if the proofIndex hasn't yet been
   143  	// reached.
   144  	if t.head == nil || len(t.proofSet) == 0 {
   145  		return t.Root(), nil, t.proofIndex, t.currentIndex
   146  	}
   147  	proofSet = t.proofSet
   148  
   149  	// The set of subtrees must now be collapsed into a single root. The proof
   150  	// set already contains all of the elements that are members of a complete
   151  	// subtree. Of what remains, there will be at most 1 element provided from
   152  	// a sibling on the right, and all of the other proofs will be provided
   153  	// from a sibling on the left. This results from the way orphans are
   154  	// treated. All subtrees smaller than the subtree containing the proofIndex
   155  	// will be combined into a single subtree that gets combined with the
   156  	// proofIndex subtree as a single right sibling. All subtrees larger than
   157  	// the subtree containing the proofIndex will be combined with the subtree
   158  	// containing the proof index as left siblings.
   159  
   160  	// Start at the smallest subtree and combine it with larger subtrees until
   161  	// it would be combining with the subtree that contains the proof index. We
   162  	// can recognize the subtree containing the proof index because the height
   163  	// of that subtree will be one less than the current length of the proof
   164  	// set.
   165  	current := t.head
   166  	for current.next != nil && current.next.height < len(proofSet)-1 {
   167  		current = joinSubTrees(t.hash, current.next, current)
   168  	}
   169  
   170  	// Sanity check - check that either 'current' or 'current.next' is the
   171  	// subtree containing the proof index.
   172  	// if DEBUG {
   173  	// 	if current.height != len(t.proofSet)-1 && (current.next != nil && current.next.height != len(t.proofSet)-1) {
   174  	// 		panic("could not find the subtree containing the proof index")
   175  	// 	}
   176  	// }
   177  
   178  	// If the current subtree is not the subtree containing the proof index,
   179  	// then it must be an aggregate subtree that is to the right of the subtree
   180  	// containing the proof index, and the next subtree is the subtree
   181  	// containing the proof index.
   182  	if current.next != nil && current.next.height == len(proofSet)-1 {
   183  		proofSet = append(proofSet, current.sum)
   184  		current = current.next
   185  	}
   186  
   187  	// The current subtree must be the subtree containing the proof index. This
   188  	// subtree does not need an entry, as the entry was created during the
   189  	// construction of the Tree. Instead, skip to the next subtree.
   190  	current = current.next
   191  
   192  	// All remaining subtrees will be added to the proof set as a left sibling,
   193  	// completing the proof set.
   194  	for current != nil {
   195  		proofSet = append(proofSet, current.sum)
   196  		current = current.next
   197  	}
   198  	return t.Root(), proofSet, t.proofIndex, t.currentIndex
   199  }
   200  
   201  // Push will add data to the set, building out the Merkle tree and Root. The
   202  // tree does not remember all elements that are added, instead only keeping the
   203  // log(n) elements that are necessary to build the Merkle root and keeping the
   204  // log(n) elements necessary to build a proof that a piece of data is in the
   205  // Merkle tree.
   206  func (t *Tree) Push(data []byte) {
   207  	// The first element of a proof is the data at the proof index. If this
   208  	// data is being inserted at the proof index, it is added to the proof set.
   209  	if t.currentIndex == t.proofIndex {
   210  		t.proofSet = append(t.proofSet, data)
   211  	}
   212  
   213  	// Hash the data to create a subtree of height 0. The sum of the new node
   214  	// is going to be the data for cached trees, and is going to be the result
   215  	// of calling leafSum() on the data for standard trees. Doing a check here
   216  	// prevents needing to duplicate the entire 'Push' function for the trees.
   217  	t.head = &subTree{
   218  		next:   t.head,
   219  		height: 0,
   220  	}
   221  	if t.cachedTree {
   222  		t.head.sum = data
   223  	} else {
   224  		t.head.sum = leafSum(t.hash, data)
   225  	}
   226  
   227  	// Join subTrees if possible.
   228  	t.joinAllSubTrees()
   229  
   230  	// Update the index.
   231  	t.currentIndex++
   232  
   233  	// Sanity check - From head to tail of the stack, the height should be
   234  	// strictly increasing.
   235  	// if DEBUG {
   236  	// 	current := t.head
   237  	// 	height := current.height
   238  	// 	for current.next != nil {
   239  	// 		current = current.next
   240  	// 		if current.height <= height {
   241  	// 			panic("subtrees are out of order")
   242  	// 		}
   243  	// 		height = current.height
   244  	// 	}
   245  	// }
   246  }
   247  
   248  // PushSubTree pushes a cached subtree into the merkle tree. The subtree has to
   249  // be smaller than the smallest subtree in the merkle tree, it has to be
   250  // balanced and it can't contain the element that needs to be proven.  Since we
   251  // can't tell if a subTree is balanced, we can't sanity check for unbalanced
   252  // trees. Therefore an unbalanced tree will cause silent errors, pain and
   253  // misery for the person who wants to debug the resulting error.
   254  func (t *Tree) PushSubTree(height int, sum []byte) error {
   255  	// Check if the cached tree that is pushed contains the element at
   256  	// proofIndex. This is not allowed.
   257  	newIndex := t.currentIndex + 1<<uint64(height)
   258  	if t.proofTree && (t.currentIndex == t.proofIndex ||
   259  		(t.currentIndex < t.proofIndex && t.proofIndex < newIndex)) {
   260  		return errors.New("the cached tree shouldn't contain the element to prove")
   261  	}
   262  
   263  	// We can only add the cached tree if its depth is <= the depth of the
   264  	// current subtree.
   265  	if t.head != nil && height > t.head.height {
   266  		return fmt.Errorf("can't add a subtree that is larger than the smallest subtree %v > %v", height, t.head.height)
   267  	}
   268  
   269  	// Insert the cached tree as the new head.
   270  	t.head = &subTree{
   271  		height: height,
   272  		next:   t.head,
   273  		sum:    sum,
   274  	}
   275  
   276  	// Join subTrees if possible.
   277  	t.joinAllSubTrees()
   278  
   279  	// Update the index.
   280  	t.currentIndex = newIndex
   281  
   282  	// Sanity check - From head to tail of the stack, the height should be
   283  	// strictly increasing.
   284  	// if DEBUG {
   285  	// 	current := t.head
   286  	// 	height := current.height
   287  	// 	for current.next != nil {
   288  	// 		current = current.next
   289  	// 		if current.height <= height {
   290  	// 			panic("subtrees are out of order")
   291  	// 		}
   292  	// 		height = current.height
   293  	// 	}
   294  	// }
   295  	return nil
   296  }
   297  
   298  // Root returns the Merkle root of the data that has been pushed.
   299  func (t *Tree) Root() []byte {
   300  	// If the Tree is empty, return nil.
   301  	if t.head == nil {
   302  		return nil
   303  	}
   304  
   305  	// The root is formed by hashing together subTrees in order from least in
   306  	// height to greatest in height. The taller subtree is the first subtree in
   307  	// the join.
   308  	current := t.head
   309  	for current.next != nil {
   310  		current = joinSubTrees(t.hash, current.next, current)
   311  	}
   312  	// Return a copy to prevent leaking a pointer to internal data.
   313  	return append(current.sum[:0:0], current.sum...)
   314  }
   315  
   316  // SetIndex will tell the Tree to create a storage proof for the leaf at the
   317  // input index. SetIndex must be called on an empty tree.
   318  func (t *Tree) SetIndex(i uint64) error {
   319  	if t.head != nil {
   320  		return errors.New("cannot call SetIndex on Tree if Tree has not been reset")
   321  	}
   322  	t.proofTree = true
   323  	t.proofIndex = i
   324  	return nil
   325  }
   326  
   327  // joinAllSubTrees inserts the subTree at t.head into the Tree. As long as the
   328  // height of the next subTree is the same as the height of the current subTree,
   329  // the two will be combined into a single subTree of height n+1.
   330  func (t *Tree) joinAllSubTrees() {
   331  	for t.head.next != nil && t.head.height == t.head.next.height {
   332  		// Before combining subtrees, check whether one of the subtree hashes
   333  		// needs to be added to the proof set. This is going to be true IFF the
   334  		// subtrees being combined are one height higher than the previous
   335  		// subtree added to the proof set. The height of the previous subtree
   336  		// added to the proof set is equal to len(t.proofSet) - 1.
   337  		if t.head.height == len(t.proofSet)-1 {
   338  			// One of the subtrees needs to be added to the proof set. The
   339  			// subtree that needs to be added is the subtree that does not
   340  			// contain the proofIndex. Because the subtrees being compared are
   341  			// the smallest and rightmost trees in the Tree, this can be
   342  			// determined by rounding the currentIndex down to the number of
   343  			// nodes in the subtree and comparing that index to the proofIndex.
   344  			leaves := uint64(1 << uint(t.head.height))
   345  			mid := (t.currentIndex / leaves) * leaves
   346  			if t.proofIndex < mid {
   347  				t.proofSet = append(t.proofSet, t.head.sum)
   348  			} else {
   349  				t.proofSet = append(t.proofSet, t.head.next.sum)
   350  			}
   351  
   352  			// Sanity check - the proofIndex should never be less than the
   353  			// midpoint minus the number of leaves in each subtree.
   354  			// if DEBUG {
   355  			// 	if t.proofIndex < mid-leaves {
   356  			// 		panic("proof being added with weird values")
   357  			// 	}
   358  			// }
   359  		}
   360  
   361  		// Join the two subTrees into one subTree with a greater height. Then
   362  		// compare the new subTree to the next subTree.
   363  		t.head = joinSubTrees(t.hash, t.head.next, t.head)
   364  	}
   365  }