github.com/line/ostracon@v1.0.10-0.20230328032236-7f20145f065d/crypto/merkle/proof.go (about)

     1  package merkle
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  
     8  	tmcrypto "github.com/tendermint/tendermint/proto/tendermint/crypto"
     9  
    10  	"github.com/line/ostracon/crypto/tmhash"
    11  )
    12  
const (
	// MaxAunts is the maximum number of aunts that can be included in a Proof.
	// This corresponds to a tree of size 2^100, which should be sufficient for all conceivable purposes.
	// This maximum helps prevent Denial-of-Service attacks by limiting the size of the proofs.
	MaxAunts = 100
)
    19  
// Proof represents a Merkle proof.
// NOTE: The convention for proofs is to include leaf hashes but to
// exclude the root hash.
// This convention is implemented across IAVL range proofs as well.
// Keep this consistent unless there's a very good reason to change
// everything. This also affects the generalized proof system.
//
// ValidateBasic requires Total and Index to be non-negative, LeafHash and
// every aunt to be tmhash.Size bytes long, and at most MaxAunts aunts.
type Proof struct {
	Total    int64    `json:"total"`     // Total number of items.
	Index    int64    `json:"index"`     // Index of item to prove.
	LeafHash []byte   `json:"leaf_hash"` // Hash of item value.
	Aunts    [][]byte `json:"aunts"`     // Hashes from leaf's sibling to a root's child.
}
    33  
    34  // ProofsFromByteSlices computes inclusion proof for given items.
    35  // proofs[0] is the proof for items[0].
    36  func ProofsFromByteSlices(items [][]byte) (rootHash []byte, proofs []*Proof) {
    37  	trails, rootSPN := trailsFromByteSlices(items)
    38  	rootHash = rootSPN.Hash
    39  	proofs = make([]*Proof, len(items))
    40  	for i, trail := range trails {
    41  		proofs[i] = &Proof{
    42  			Total:    int64(len(items)),
    43  			Index:    int64(i),
    44  			LeafHash: trail.Hash,
    45  			Aunts:    trail.FlattenAunts(),
    46  		}
    47  	}
    48  	return
    49  }
    50  
    51  // Verify that the Proof proves the root hash.
    52  // Check sp.Index/sp.Total manually if needed
    53  func (sp *Proof) Verify(rootHash []byte, leaf []byte) error {
    54  	leafHash := leafHash(leaf)
    55  	if sp.Total < 0 {
    56  		return errors.New("proof total must be positive")
    57  	}
    58  	if sp.Index < 0 {
    59  		return errors.New("proof index cannot be negative")
    60  	}
    61  	if !bytes.Equal(sp.LeafHash, leafHash) {
    62  		return fmt.Errorf("invalid leaf hash: wanted %X got %X", leafHash, sp.LeafHash)
    63  	}
    64  	computedHash := sp.ComputeRootHash()
    65  	if !bytes.Equal(computedHash, rootHash) {
    66  		return fmt.Errorf("invalid root hash: wanted %X got %X", rootHash, computedHash)
    67  	}
    68  	return nil
    69  }
    70  
    71  // Compute the root hash given a leaf hash.  Does not verify the result.
    72  func (sp *Proof) ComputeRootHash() []byte {
    73  	return computeHashFromAunts(
    74  		sp.Index,
    75  		sp.Total,
    76  		sp.LeafHash,
    77  		sp.Aunts,
    78  	)
    79  }
    80  
    81  // String implements the stringer interface for Proof.
    82  // It is a wrapper around StringIndented.
    83  func (sp *Proof) String() string {
    84  	return sp.StringIndented("")
    85  }
    86  
    87  // StringIndented generates a canonical string representation of a Proof.
    88  func (sp *Proof) StringIndented(indent string) string {
    89  	return fmt.Sprintf(`Proof{
    90  %s  Aunts: %X
    91  %s}`,
    92  		indent, sp.Aunts,
    93  		indent)
    94  }
    95  
    96  // ValidateBasic performs basic validation.
    97  // NOTE: it expects the LeafHash and the elements of Aunts to be of size tmhash.Size,
    98  // and it expects at most MaxAunts elements in Aunts.
    99  func (sp *Proof) ValidateBasic() error {
   100  	if sp.Total < 0 {
   101  		return errors.New("negative Total")
   102  	}
   103  	if sp.Index < 0 {
   104  		return errors.New("negative Index")
   105  	}
   106  	if len(sp.LeafHash) != tmhash.Size {
   107  		return fmt.Errorf("expected LeafHash size to be %d, got %d", tmhash.Size, len(sp.LeafHash))
   108  	}
   109  	if len(sp.Aunts) > MaxAunts {
   110  		return fmt.Errorf("expected no more than %d aunts, got %d", MaxAunts, len(sp.Aunts))
   111  	}
   112  	for i, auntHash := range sp.Aunts {
   113  		if len(auntHash) != tmhash.Size {
   114  			return fmt.Errorf("expected Aunts#%d size to be %d, got %d", i, tmhash.Size, len(auntHash))
   115  		}
   116  	}
   117  	return nil
   118  }
   119  
   120  func (sp *Proof) ToProto() *tmcrypto.Proof {
   121  	if sp == nil {
   122  		return nil
   123  	}
   124  	pb := new(tmcrypto.Proof)
   125  
   126  	pb.Total = sp.Total
   127  	pb.Index = sp.Index
   128  	pb.LeafHash = sp.LeafHash
   129  	pb.Aunts = sp.Aunts
   130  
   131  	return pb
   132  }
   133  
   134  func ProofFromProto(pb *tmcrypto.Proof) (*Proof, error) {
   135  	if pb == nil {
   136  		return nil, errors.New("nil proof")
   137  	}
   138  
   139  	sp := new(Proof)
   140  
   141  	sp.Total = pb.Total
   142  	sp.Index = pb.Index
   143  	sp.LeafHash = pb.LeafHash
   144  	sp.Aunts = pb.Aunts
   145  
   146  	return sp, sp.ValidateBasic()
   147  }
   148  
   149  // Use the leafHash and innerHashes to get the root merkle hash.
   150  // If the length of the innerHashes slice isn't exactly correct, the result is nil.
   151  // Recursive impl.
   152  func computeHashFromAunts(index, total int64, leafHash []byte, innerHashes [][]byte) []byte {
   153  	if index >= total || index < 0 || total <= 0 {
   154  		return nil
   155  	}
   156  	switch total {
   157  	case 0:
   158  		panic("Cannot call computeHashFromAunts() with 0 total")
   159  	case 1:
   160  		if len(innerHashes) != 0 {
   161  			return nil
   162  		}
   163  		return leafHash
   164  	default:
   165  		if len(innerHashes) == 0 {
   166  			return nil
   167  		}
   168  		numLeft := getSplitPoint(total)
   169  		if index < numLeft {
   170  			leftHash := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1])
   171  			if leftHash == nil {
   172  				return nil
   173  			}
   174  			return innerHash(leftHash, innerHashes[len(innerHashes)-1])
   175  		}
   176  		rightHash := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1])
   177  		if rightHash == nil {
   178  			return nil
   179  		}
   180  		return innerHash(innerHashes[len(innerHashes)-1], rightHash)
   181  	}
   182  }
   183  
   184  // ProofNode is a helper structure to construct merkle proof.
   185  // The node and the tree is thrown away afterwards.
   186  // Exactly one of node.Left and node.Right is nil, unless node is the root, in which case both are nil.
   187  // node.Parent.Hash = hash(node.Hash, node.Right.Hash) or
   188  // hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child.
   189  type ProofNode struct {
   190  	Hash   []byte
   191  	Parent *ProofNode
   192  	Left   *ProofNode // Left sibling  (only one of Left,Right is set)
   193  	Right  *ProofNode // Right sibling (only one of Left,Right is set)
   194  }
   195  
   196  // FlattenAunts will return the inner hashes for the item corresponding to the leaf,
   197  // starting from a leaf ProofNode.
   198  func (spn *ProofNode) FlattenAunts() [][]byte {
   199  	// Nonrecursive impl.
   200  	innerHashes := [][]byte{}
   201  	for spn != nil {
   202  		switch {
   203  		case spn.Left != nil:
   204  			innerHashes = append(innerHashes, spn.Left.Hash)
   205  		case spn.Right != nil:
   206  			innerHashes = append(innerHashes, spn.Right.Hash)
   207  		default:
   208  			break
   209  		}
   210  		spn = spn.Parent
   211  	}
   212  	return innerHashes
   213  }
   214  
   215  // trails[0].Hash is the leaf hash for items[0].
   216  // trails[i].Parent.Parent....Parent == root for all i.
   217  func trailsFromByteSlices(items [][]byte) (trails []*ProofNode, root *ProofNode) {
   218  	// Recursive impl.
   219  	switch len(items) {
   220  	case 0:
   221  		return []*ProofNode{}, &ProofNode{emptyHash(), nil, nil, nil}
   222  	case 1:
   223  		trail := &ProofNode{leafHash(items[0]), nil, nil, nil}
   224  		return []*ProofNode{trail}, trail
   225  	default:
   226  		k := getSplitPoint(int64(len(items)))
   227  		lefts, leftRoot := trailsFromByteSlices(items[:k])
   228  		rights, rightRoot := trailsFromByteSlices(items[k:])
   229  		rootHash := innerHash(leftRoot.Hash, rightRoot.Hash)
   230  		root := &ProofNode{rootHash, nil, nil, nil}
   231  		leftRoot.Parent = root
   232  		leftRoot.Right = rightRoot
   233  		rightRoot.Parent = root
   234  		rightRoot.Left = leftRoot
   235  		return append(lefts, rights...), root
   236  	}
   237  }