github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/storage/merkle/proof.go (about)

     1  package merkle
     2  
     3  import (
     4  	"bytes"
     5  	"math/bits"
     6  
     7  	"github.com/onflow/flow-go/ledger/common/bitutils"
     8  )
     9  
    10  // Proof captures all data needed for proving inclusion of a single key/value pair inside a merkle trie.
    11  // Verifying a proof requires knowledge of the trie path structure (node types), traversing
    12  // the trie from the leaf to the root, and computing hash values.
    13  type Proof struct {
    14  	// Key used to insert and look up the value
    15  	Key []byte
    16  	// Value stored in the trie for the given key
    17  	Value []byte
    18  	// InterimNodeTypes is designed to be consumed bit by bit to determine if the next node
    19  	// is a short node or full node while traversing the trie downward (0: fullnode, 1: shortnode).
    20  	// The very first bit corresponds to the root of the trie and last bit is the last
    21  	// interim node before reaching the leaf.
    22  	// The slice represents a bit vector where the lowest index byte represents the first 8 node types,
    23  	// while the most significant bit of the byte represents the first node type (big endianness).
    24  	// Note that we always allocate the minimal number of bytes needed to capture all
    25  	// the nodes in the path (padded with zero)
    26  	InterimNodeTypes []byte
    27  	// ShortPathLengths is read when we reach a short node, and the value represents non-zero number of common bits that were included
    28  	// in the short node (shortNode.count). Elements are ordered from root to leaf.
    29  	ShortPathLengths []uint16
    30  	// SiblingHashes is read when we reach a full node. The corresponding element represents
    31  	// the hash of the non-visited sibling node for each full node on the path. Elements are ordered from root to leaf.
    32  	SiblingHashes [][]byte
    33  }
    34  
    35  // validateFormat validates the format and size of elements of the proof (syntax check)
    36  //
    37  // A valid proof as to satisfy the following consistency conditions:
    38  //  1. A valid inclusion proof represents a full path through the merkle tree.
    39  //     We separate the path into a sequence of interim vertices and a tailing leaf.
    40  //     For interim vertex (with index i, counted from the root node) along the path,
    41  //     the proof as to contain the following information:
    42  //     (i) whether the vertex is a short node (InterimNodeTypes[i] == 1) or
    43  //     a full node (InterimNodeTypes[i] == 0)
    44  //     (ii) for each short node, we need the number of bits in the node's key segment
    45  //     (entry in ShortPathLengths)
    46  //     (iii) for a full node, we need the hash of the sibling that is _not_ on the path
    47  //     (entry in SiblingHashes)
    48  //     Hence, len(ShortPathLengths) + len(SiblingHashes) specifies how many _interim_
    49  //     vertices are on the merkle path. Consequently, we require the same number of _bits_
    50  //     in InterimNodeTypes. Therefore, we know that InterimNodeTypes should have a length
    51  //     of `(numberBits+7)>>3` _bytes_, and the remaining padding bits must be zeros.
    52  //  2. The key length (measured in bytes) has to be in the interval [1, maxKeyLength].
    53  //     Furthermore, each interim vertex on the merkle path represents:
    54  //     * either a single bit in case of a full node:
    55  //     we expect InterimNodeTypes[i] == 0
    56  //     * a positive number of bits in case of a short node:
    57  //     we expect InterimNodeTypes[i] == 1
    58  //     and the number of bits is non-zero and encoded in the respective element of ShortPathLengths
    59  //     Hence, the total key length _in bits_ should be: len(SiblingHashes) + sum(ShortPathLengths)
    60  func (p *Proof) validateFormat() error {
    61  
    62  	// step1 - validate the key as the very first step
    63  	keyLen := len(p.Key)
    64  	if keyLen == 0 || maxKeyLength < keyLen {
    65  		return NewMalformedProofErrorf("key length in bytes must be in interval [1, %d], but is %d", maxKeyLength, keyLen)
    66  	}
    67  
    68  	// step2 - check ShortPathLengths and SiblingHashes
    69  
    70  	// validate number of bits that is going to be checked matches the size of the given key
    71  	keyBitCount := len(p.SiblingHashes)
    72  	for _, sc := range p.ShortPathLengths {
    73  		if keyBitCount > maxKeyLenBits {
    74  			return NewMalformedProofErrorf("number of key bits (%d) exceed limit (%d)", keyBitCount, maxKeyLenBits)
    75  		}
    76  		// check the common bits are non-zero
    77  		if sc == 0 {
    78  			return NewMalformedProofErrorf("short path length cannot be zero")
    79  		}
    80  		keyBitCount += int(sc)
    81  	}
    82  	if keyLen*8 != keyBitCount {
    83  		return NewMalformedProofErrorf("key length in bits (%d) doesn't match the length of ShortPathLengths and SiblingHashes (%d)",
    84  			keyLen*8,
    85  			keyBitCount)
    86  	}
    87  
    88  	// step3 - check InterimNodeTypes
    89  
    90  	// size checks
    91  	if len(p.InterimNodeTypes) > maxKeyLength {
    92  		return NewMalformedProofErrorf("InterimNodeTypes is larger than max key length allowed (%d > %d)", len(p.InterimNodeTypes), maxKeyLength)
    93  	}
    94  	// InterimNodeTypes should only use the smallest number of bytes needed for steps
    95  	steps := len(p.ShortPathLengths) + len(p.SiblingHashes)
    96  	if len(p.InterimNodeTypes) != (steps+7)>>3 {
    97  		return NewMalformedProofErrorf("the length of InterimNodeTypes doesn't match the length of ShortPathLengths and SiblingHashes")
    98  	}
    99  
   100  	// semantic checks
   101  
   102  	// Verify that number of bits that are set to 1 equals to the number of short nodes, i.e. len(ShortPathLengths).
   103  	numberOfShortNodes := 0
   104  	for _, d := range p.InterimNodeTypes {
   105  		numberOfShortNodes += bits.OnesCount8(d)
   106  	}
   107  	if numberOfShortNodes != len(p.ShortPathLengths) {
   108  		return NewMalformedProofErrorf("len(ShortPathLengths) (%d) does not match number of set bits in InterimNodeTypes (%d)", len(p.ShortPathLengths), numberOfShortNodes)
   109  	}
   110  
   111  	// check that tailing auxiliary bits (to make a complete full byte) are all zero
   112  	for i := len(p.InterimNodeTypes)*8 - 1; i >= steps; i-- {
   113  		if bitutils.ReadBit(p.InterimNodeTypes, i) != 0 {
   114  			return NewMalformedProofErrorf("tailing auxiliary bits in InterimNodeTypes should all be zero")
   115  		}
   116  	}
   117  
   118  	return nil
   119  }
   120  
   121  // Verify verifies the proof by constructing the hash values bottom up and cross-check
   122  // the constructed root hash with the given one. For valid proofs, `nil` is returned.
   123  // During normal operations, the following error returns are expected:
   124  //   - MalformedProofError if the proof has a syntactically invalid structure
   125  //   - InvalidProofError if the proof is syntactically valid, but the reconstructed
   126  //     root hash does not match the expected value.
   127  func (p *Proof) Verify(expectedRootHash []byte) error {
   128  
   129  	// first validate the format of the proof
   130  	if err := p.validateFormat(); err != nil {
   131  		return err
   132  	}
   133  
   134  	// an index to consume SiblingHashes from the last element to the first element
   135  	siblingHashIndex := len(p.SiblingHashes) - 1
   136  
   137  	// an index to consume ShortPathLengths from the last element to the first element
   138  	shortPathLengthIndex := len(p.ShortPathLengths) - 1
   139  
   140  	// keyIndex keeps track of the largest index of the key that is unchecked.
   141  	// Note that we traverse bottom up, so we start with the largest key index
   142  	// build hashes until we reach to the root.
   143  	keyIndex := len(p.Key)*8 - 1
   144  
   145  	// compute the hash value of the leaf
   146  	currentHash := computeLeafHash(p.Value)
   147  
   148  	// number of steps
   149  	steps := len(p.ShortPathLengths) + len(p.SiblingHashes)
   150  
   151  	// for each step (level from bottom to top) check if it's a full node or a short node and compute the
   152  	// hash value accordingly; for full node having the sibling hash helps to compute the hash value
   153  	// of the next level; for short nodes compute the hash using the common path constructed based on
   154  	// the given short count
   155  	for interimNodeTypesIndex := steps - 1; interimNodeTypesIndex >= 0; interimNodeTypesIndex-- {
   156  
   157  		// Full node
   158  		if bitutils.ReadBit(p.InterimNodeTypes, interimNodeTypesIndex) == 0 {
   159  
   160  			// read and pop the sibling hash value from SiblingHashes
   161  			sibling := p.SiblingHashes[siblingHashIndex]
   162  			siblingHashIndex--
   163  
   164  			// based on the bit at keyIndex, compute the hash
   165  			if bitutils.ReadBit(p.Key, keyIndex) == 0 { // left branching
   166  				currentHash = computeFullHash(currentHash, sibling)
   167  			} else {
   168  				currentHash = computeFullHash(sibling, currentHash) // right branching
   169  			}
   170  
   171  			// move to the parent vertex along the path
   172  			keyIndex--
   173  
   174  			continue
   175  		}
   176  
   177  		// Short node
   178  
   179  		// read and pop from ShortPathLengths
   180  		shortPathLength := int(p.ShortPathLengths[shortPathLengthIndex])
   181  		shortPathLengthIndex--
   182  
   183  		// construct the common path
   184  		commonPath := bitutils.MakeBitVector(shortPathLength)
   185  		for c := shortPathLength - 1; c >= 0; c-- {
   186  			if bitutils.ReadBit(p.Key, keyIndex) == 1 {
   187  				bitutils.SetBit(commonPath, c)
   188  			}
   189  			keyIndex--
   190  		}
   191  		// compute the hash for the short node
   192  		currentHash = computeShortHash(shortPathLength, commonPath, currentHash)
   193  	}
   194  
   195  	// the final hash value should match whith what was expected
   196  	if !bytes.Equal(currentHash, expectedRootHash) {
   197  		return newInvalidProofErrorf("root hash doesn't match, expected %X, computed %X", expectedRootHash, currentHash)
   198  	}
   199  
   200  	return nil
   201  }