github.com/aakash4dev/cometbft@v0.38.2/crypto/merkle/proof.go (about)

     1  package merkle
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  
     8  	"github.com/aakash4dev/cometbft/crypto/tmhash"
     9  	cmtcrypto "github.com/aakash4dev/cometbft/proto/tendermint/crypto"
    10  )
    11  
    12  const (
    13  	// MaxAunts is the maximum number of aunts that can be included in a Proof.
    14  	// This corresponds to a tree of size 2^100, which should be sufficient for all conceivable purposes.
    15  	// This maximum helps prevent Denial-of-Service attacks by limitting the size of the proofs.
    16  	MaxAunts = 100
    17  )
    18  
    19  // Proof represents a Merkle proof.
    20  // NOTE: The convention for proofs is to include leaf hashes but to
    21  // exclude the root hash.
    22  // This convention is implemented across IAVL range proofs as well.
    23  // Keep this consistent unless there's a very good reason to change
    24  // everything.  This also affects the generalized proof system as
    25  // well.
    26  type Proof struct {
    27  	Total    int64    `json:"total"`     // Total number of items.
    28  	Index    int64    `json:"index"`     // Index of item to prove.
    29  	LeafHash []byte   `json:"leaf_hash"` // Hash of item value.
    30  	Aunts    [][]byte `json:"aunts"`     // Hashes from leaf's sibling to a root's child.
    31  }
    32  
    33  // ProofsFromByteSlices computes inclusion proof for given items.
    34  // proofs[0] is the proof for items[0].
    35  func ProofsFromByteSlices(items [][]byte) (rootHash []byte, proofs []*Proof) {
    36  	trails, rootSPN := trailsFromByteSlices(items)
    37  	rootHash = rootSPN.Hash
    38  	proofs = make([]*Proof, len(items))
    39  	for i, trail := range trails {
    40  		proofs[i] = &Proof{
    41  			Total:    int64(len(items)),
    42  			Index:    int64(i),
    43  			LeafHash: trail.Hash,
    44  			Aunts:    trail.FlattenAunts(),
    45  		}
    46  	}
    47  	return
    48  }
    49  
    50  // Verify that the Proof proves the root hash.
    51  // Check sp.Index/sp.Total manually if needed
    52  func (sp *Proof) Verify(rootHash []byte, leaf []byte) error {
    53  	if rootHash == nil {
    54  		return fmt.Errorf("invalid root hash: cannot be nil")
    55  	}
    56  	if sp.Total < 0 {
    57  		return errors.New("proof total must be positive")
    58  	}
    59  	if sp.Index < 0 {
    60  		return errors.New("proof index cannot be negative")
    61  	}
    62  	leafHash := leafHash(leaf)
    63  	if !bytes.Equal(sp.LeafHash, leafHash) {
    64  		return fmt.Errorf("invalid leaf hash: wanted %X got %X", leafHash, sp.LeafHash)
    65  	}
    66  	computedHash, err := sp.computeRootHash()
    67  	if err != nil {
    68  		return fmt.Errorf("compute root hash: %w", err)
    69  	}
    70  	if !bytes.Equal(computedHash, rootHash) {
    71  		return fmt.Errorf("invalid root hash: wanted %X got %X", rootHash, computedHash)
    72  	}
    73  	return nil
    74  }
    75  
    76  // Compute the root hash given a leaf hash.
    77  func (sp *Proof) computeRootHash() ([]byte, error) {
    78  	return computeHashFromAunts(
    79  		sp.Index,
    80  		sp.Total,
    81  		sp.LeafHash,
    82  		sp.Aunts,
    83  	)
    84  }
    85  
    86  // String implements the stringer interface for Proof.
    87  // It is a wrapper around StringIndented.
    88  func (sp *Proof) String() string {
    89  	return sp.StringIndented("")
    90  }
    91  
    92  // StringIndented generates a canonical string representation of a Proof.
    93  func (sp *Proof) StringIndented(indent string) string {
    94  	return fmt.Sprintf(`Proof{
    95  %s  Aunts: %X
    96  %s}`,
    97  		indent, sp.Aunts,
    98  		indent)
    99  }
   100  
   101  // ValidateBasic performs basic validation.
   102  // NOTE: it expects the LeafHash and the elements of Aunts to be of size tmhash.Size,
   103  // and it expects at most MaxAunts elements in Aunts.
   104  func (sp *Proof) ValidateBasic() error {
   105  	if sp.Total < 0 {
   106  		return errors.New("negative Total")
   107  	}
   108  	if sp.Index < 0 {
   109  		return errors.New("negative Index")
   110  	}
   111  	if len(sp.LeafHash) != tmhash.Size {
   112  		return fmt.Errorf("expected LeafHash size to be %d, got %d", tmhash.Size, len(sp.LeafHash))
   113  	}
   114  	if len(sp.Aunts) > MaxAunts {
   115  		return fmt.Errorf("expected no more than %d aunts, got %d", MaxAunts, len(sp.Aunts))
   116  	}
   117  	for i, auntHash := range sp.Aunts {
   118  		if len(auntHash) != tmhash.Size {
   119  			return fmt.Errorf("expected Aunts#%d size to be %d, got %d", i, tmhash.Size, len(auntHash))
   120  		}
   121  	}
   122  	return nil
   123  }
   124  
   125  func (sp *Proof) ToProto() *cmtcrypto.Proof {
   126  	if sp == nil {
   127  		return nil
   128  	}
   129  	pb := new(cmtcrypto.Proof)
   130  
   131  	pb.Total = sp.Total
   132  	pb.Index = sp.Index
   133  	pb.LeafHash = sp.LeafHash
   134  	pb.Aunts = sp.Aunts
   135  
   136  	return pb
   137  }
   138  
   139  func ProofFromProto(pb *cmtcrypto.Proof) (*Proof, error) {
   140  	if pb == nil {
   141  		return nil, errors.New("nil proof")
   142  	}
   143  
   144  	sp := new(Proof)
   145  
   146  	sp.Total = pb.Total
   147  	sp.Index = pb.Index
   148  	sp.LeafHash = pb.LeafHash
   149  	sp.Aunts = pb.Aunts
   150  
   151  	return sp, sp.ValidateBasic()
   152  }
   153  
   154  // Use the leafHash and innerHashes to get the root merkle hash.
   155  // If the length of the innerHashes slice isn't exactly correct, the result is nil.
   156  // Recursive impl.
   157  func computeHashFromAunts(index, total int64, leafHash []byte, innerHashes [][]byte) ([]byte, error) {
   158  	if index >= total || index < 0 || total <= 0 {
   159  		return nil, fmt.Errorf("invalid index %d and/or total %d", index, total)
   160  	}
   161  	switch total {
   162  	case 0:
   163  		panic("Cannot call computeHashFromAunts() with 0 total")
   164  	case 1:
   165  		if len(innerHashes) != 0 {
   166  			return nil, fmt.Errorf("unexpected inner hashes")
   167  		}
   168  		return leafHash, nil
   169  	default:
   170  		if len(innerHashes) == 0 {
   171  			return nil, fmt.Errorf("expected at least one inner hash")
   172  		}
   173  		numLeft := getSplitPoint(total)
   174  		if index < numLeft {
   175  			leftHash, err := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1])
   176  			if err != nil {
   177  				return nil, err
   178  			}
   179  
   180  			return innerHash(leftHash, innerHashes[len(innerHashes)-1]), nil
   181  		}
   182  		rightHash, err := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1])
   183  		if err != nil {
   184  			return nil, err
   185  		}
   186  		return innerHash(innerHashes[len(innerHashes)-1], rightHash), nil
   187  	}
   188  }
   189  
   190  // ProofNode is a helper structure to construct merkle proof.
   191  // The node and the tree is thrown away afterwards.
   192  // Exactly one of node.Left and node.Right is nil, unless node is the root, in which case both are nil.
   193  // node.Parent.Hash = hash(node.Hash, node.Right.Hash) or
   194  // hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child.
   195  type ProofNode struct {
   196  	Hash   []byte
   197  	Parent *ProofNode
   198  	Left   *ProofNode // Left sibling  (only one of Left,Right is set)
   199  	Right  *ProofNode // Right sibling (only one of Left,Right is set)
   200  }
   201  
   202  // FlattenAunts will return the inner hashes for the item corresponding to the leaf,
   203  // starting from a leaf ProofNode.
   204  func (spn *ProofNode) FlattenAunts() [][]byte {
   205  	// Nonrecursive impl.
   206  	innerHashes := [][]byte{}
   207  	for spn != nil {
   208  		switch {
   209  		case spn.Left != nil:
   210  			innerHashes = append(innerHashes, spn.Left.Hash)
   211  		case spn.Right != nil:
   212  			innerHashes = append(innerHashes, spn.Right.Hash)
   213  		default:
   214  			break
   215  		}
   216  		spn = spn.Parent
   217  	}
   218  	return innerHashes
   219  }
   220  
   221  // trails[0].Hash is the leaf hash for items[0].
   222  // trails[i].Parent.Parent....Parent == root for all i.
   223  func trailsFromByteSlices(items [][]byte) (trails []*ProofNode, root *ProofNode) {
   224  	// Recursive impl.
   225  	switch len(items) {
   226  	case 0:
   227  		return []*ProofNode{}, &ProofNode{emptyHash(), nil, nil, nil}
   228  	case 1:
   229  		trail := &ProofNode{leafHash(items[0]), nil, nil, nil}
   230  		return []*ProofNode{trail}, trail
   231  	default:
   232  		k := getSplitPoint(int64(len(items)))
   233  		lefts, leftRoot := trailsFromByteSlices(items[:k])
   234  		rights, rightRoot := trailsFromByteSlices(items[k:])
   235  		rootHash := innerHash(leftRoot.Hash, rightRoot.Hash)
   236  		root := &ProofNode{rootHash, nil, nil, nil}
   237  		leftRoot.Parent = root
   238  		leftRoot.Right = rightRoot
   239  		rightRoot.Parent = root
   240  		rightRoot.Left = leftRoot
   241  		return append(lefts, rights...), root
   242  	}
   243  }