github.com/prysmaticlabs/prysm@v1.4.4/shared/trieutil/sparse_merkle.go (about)

     1  // Package trieutil defines utilities for sparse merkle tries for Ethereum consensus.
     2  package trieutil
     3  
     4  import (
     5  	"bytes"
     6  	"encoding/binary"
     7  	"errors"
     8  	"fmt"
     9  
    10  	protodb "github.com/prysmaticlabs/prysm/proto/beacon/db"
    11  	"github.com/prysmaticlabs/prysm/shared/bytesutil"
    12  	"github.com/prysmaticlabs/prysm/shared/hashutil"
    13  	"github.com/prysmaticlabs/prysm/shared/mathutil"
    14  )
    15  
    16  // SparseMerkleTrie implements a sparse, general purpose Merkle trie to be used
    17  // across Ethereum consensus functionality.
    18  type SparseMerkleTrie struct {
    19  	depth         uint
    20  	branches      [][][]byte
    21  	originalItems [][]byte // list of provided items before hashing them into leaves.
    22  }
    23  
    24  // NewTrie returns a new merkle trie filled with zerohashes to use.
    25  func NewTrie(depth uint64) (*SparseMerkleTrie, error) {
    26  	var zeroBytes [32]byte
    27  	items := [][]byte{zeroBytes[:]}
    28  	return GenerateTrieFromItems(items, depth)
    29  }
    30  
    31  // CreateTrieFromProto creates a Sparse Merkle Trie from its corresponding merkle trie.
    32  func CreateTrieFromProto(trieObj *protodb.SparseMerkleTrie) *SparseMerkleTrie {
    33  	trie := &SparseMerkleTrie{
    34  		depth:         uint(trieObj.Depth),
    35  		originalItems: trieObj.OriginalItems,
    36  	}
    37  	branches := make([][][]byte, len(trieObj.Layers))
    38  	for i, layer := range trieObj.Layers {
    39  		branches[i] = layer.Layer
    40  	}
    41  	trie.branches = branches
    42  	return trie
    43  }
    44  
    45  // GenerateTrieFromItems constructs a Merkle trie from a sequence of byte slices.
    46  func GenerateTrieFromItems(items [][]byte, depth uint64) (*SparseMerkleTrie, error) {
    47  	if len(items) == 0 {
    48  		return nil, errors.New("no items provided to generate Merkle trie")
    49  	}
    50  	leaves := items
    51  	layers := make([][][]byte, depth+1)
    52  	transformedLeaves := make([][]byte, len(leaves))
    53  	for i := range leaves {
    54  		arr := bytesutil.ToBytes32(leaves[i])
    55  		transformedLeaves[i] = arr[:]
    56  	}
    57  	layers[0] = transformedLeaves
    58  	for i := uint64(0); i < depth; i++ {
    59  		if len(layers[i])%2 == 1 {
    60  			layers[i] = append(layers[i], ZeroHashes[i][:])
    61  		}
    62  		updatedValues := make([][]byte, 0)
    63  		for j := 0; j < len(layers[i]); j += 2 {
    64  			concat := hashutil.Hash(append(layers[i][j], layers[i][j+1]...))
    65  			updatedValues = append(updatedValues, concat[:])
    66  		}
    67  		layers[i+1] = updatedValues
    68  	}
    69  	return &SparseMerkleTrie{
    70  		branches:      layers,
    71  		originalItems: items,
    72  		depth:         uint(depth),
    73  	}, nil
    74  }
    75  
    76  // Items returns the original items passed in when creating the Merkle trie.
    77  func (m *SparseMerkleTrie) Items() [][]byte {
    78  	return m.originalItems
    79  }
    80  
    81  // Root returns the top-most, Merkle root of the trie.
    82  func (m *SparseMerkleTrie) Root() [32]byte {
    83  	enc := [32]byte{}
    84  	binary.LittleEndian.PutUint64(enc[:], uint64(len(m.originalItems)))
    85  	return hashutil.Hash(append(m.branches[len(m.branches)-1][0], enc[:]...))
    86  }
    87  
    88  // Insert an item into the trie.
    89  func (m *SparseMerkleTrie) Insert(item []byte, index int) {
    90  	for index >= len(m.branches[0]) {
    91  		m.branches[0] = append(m.branches[0], ZeroHashes[0][:])
    92  	}
    93  	someItem := bytesutil.ToBytes32(item)
    94  	m.branches[0][index] = someItem[:]
    95  	if index >= len(m.originalItems) {
    96  		m.originalItems = append(m.originalItems, someItem[:])
    97  	} else {
    98  		m.originalItems[index] = someItem[:]
    99  	}
   100  	currentIndex := index
   101  	root := bytesutil.ToBytes32(item)
   102  	for i := 0; i < int(m.depth); i++ {
   103  		isLeft := currentIndex%2 == 0
   104  		neighborIdx := currentIndex ^ 1
   105  		var neighbor []byte
   106  		if neighborIdx >= len(m.branches[i]) {
   107  			neighbor = ZeroHashes[i][:]
   108  		} else {
   109  			neighbor = m.branches[i][neighborIdx]
   110  		}
   111  		if isLeft {
   112  			parentHash := hashutil.Hash(append(root[:], neighbor...))
   113  			root = parentHash
   114  		} else {
   115  			parentHash := hashutil.Hash(append(neighbor, root[:]...))
   116  			root = parentHash
   117  		}
   118  		parentIdx := currentIndex / 2
   119  		if len(m.branches[i+1]) == 0 || parentIdx >= len(m.branches[i+1]) {
   120  			newItem := root
   121  			m.branches[i+1] = append(m.branches[i+1], newItem[:])
   122  		} else {
   123  			newItem := root
   124  			m.branches[i+1][parentIdx] = newItem[:]
   125  		}
   126  		currentIndex = parentIdx
   127  	}
   128  }
   129  
   130  // MerkleProof computes a proof from a trie's branches using a Merkle index.
   131  func (m *SparseMerkleTrie) MerkleProof(index int) ([][]byte, error) {
   132  	merkleIndex := uint(index)
   133  	leaves := m.branches[0]
   134  	if index >= len(leaves) {
   135  		return nil, fmt.Errorf("merkle index out of range in trie, max range: %d, received: %d", len(leaves), index)
   136  	}
   137  	proof := make([][]byte, m.depth+1)
   138  	for i := uint(0); i < m.depth; i++ {
   139  		subIndex := (merkleIndex / (1 << i)) ^ 1
   140  		if subIndex < uint(len(m.branches[i])) {
   141  			item := bytesutil.ToBytes32(m.branches[i][subIndex])
   142  			proof[i] = item[:]
   143  		} else {
   144  			proof[i] = ZeroHashes[i][:]
   145  		}
   146  	}
   147  	enc := [32]byte{}
   148  	binary.LittleEndian.PutUint64(enc[:], uint64(len(m.originalItems)))
   149  	proof[len(proof)-1] = enc[:]
   150  	return proof, nil
   151  }
   152  
   153  // HashTreeRoot of the Merkle trie as defined in the deposit contract.
   154  //  Spec Definition:
   155  //   sha256(concat(node, self.to_little_endian_64(self.deposit_count), slice(zero_bytes32, start=0, len=24)))
   156  func (m *SparseMerkleTrie) HashTreeRoot() [32]byte {
   157  	var zeroBytes [32]byte
   158  	depositCount := uint64(len(m.originalItems))
   159  	if len(m.originalItems) == 1 && bytes.Equal(m.originalItems[0], zeroBytes[:]) {
   160  		// Accounting for empty tries
   161  		depositCount = 0
   162  	}
   163  	newNode := append(m.branches[len(m.branches)-1][0], bytesutil.Bytes8(depositCount)...)
   164  	newNode = append(newNode, zeroBytes[:24]...)
   165  	return hashutil.Hash(newNode)
   166  }
   167  
   168  // ToProto converts the underlying trie into its corresponding
   169  // proto object
   170  func (m *SparseMerkleTrie) ToProto() *protodb.SparseMerkleTrie {
   171  	trie := &protodb.SparseMerkleTrie{
   172  		Depth:         uint64(m.depth),
   173  		Layers:        make([]*protodb.TrieLayer, len(m.branches)),
   174  		OriginalItems: m.originalItems,
   175  	}
   176  	for i, l := range m.branches {
   177  		trie.Layers[i] = &protodb.TrieLayer{
   178  			Layer: l,
   179  		}
   180  	}
   181  	return trie
   182  }
   183  
   184  // VerifyMerkleBranch verifies a Merkle branch against a root of a trie.
   185  func VerifyMerkleBranch(root, item []byte, merkleIndex int, proof [][]byte, depth uint64) bool {
   186  	if len(proof) != int(depth)+1 {
   187  		return false
   188  	}
   189  	node := bytesutil.ToBytes32(item)
   190  	for i := 0; i <= int(depth); i++ {
   191  		if (uint64(merkleIndex) / mathutil.PowerOf2(uint64(i)) % 2) != 0 {
   192  			node = hashutil.Hash(append(proof[i], node[:]...))
   193  		} else {
   194  			node = hashutil.Hash(append(node[:], proof[i]...))
   195  		}
   196  	}
   197  
   198  	return bytes.Equal(root, node[:])
   199  }
   200  
   201  // Copy performs a deep copy of the trie.
   202  func (m *SparseMerkleTrie) Copy() *SparseMerkleTrie {
   203  	dstBranches := make([][][]byte, len(m.branches))
   204  	for i1, srcB1 := range m.branches {
   205  		dstBranches[i1] = bytesutil.Copy2dBytes(srcB1)
   206  	}
   207  
   208  	return &SparseMerkleTrie{
   209  		depth:         m.depth,
   210  		branches:      dstBranches,
   211  		originalItems: bytesutil.Copy2dBytes(m.originalItems),
   212  	}
   213  }
   214  
   215  // NumOfItems returns the num of items stored in
   216  // the sparse merkle trie. We handle a special case
   217  // where if there is only one item stored and it is a
   218  // empty 32-byte root.
   219  func (m *SparseMerkleTrie) NumOfItems() int {
   220  	var zeroBytes [32]byte
   221  	if len(m.originalItems) == 1 && bytes.Equal(m.originalItems[0], zeroBytes[:]) {
   222  		return 0
   223  	}
   224  	return len(m.originalItems)
   225  }