github.com/SmartMeshFoundation/Spectrum@v0.0.0-20220621030607-452a266fee1e/trie/proof.go (about)

     1  // Copyright 2015 The Spectrum Authors
     2  // This file is part of the Spectrum library.
     3  //
     4  // The Spectrum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The Spectrum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the Spectrum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package trie
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  
    23  	"github.com/SmartMeshFoundation/Spectrum/common"
    24  	"github.com/SmartMeshFoundation/Spectrum/crypto"
    25  	"github.com/SmartMeshFoundation/Spectrum/log"
    26  	"github.com/SmartMeshFoundation/Spectrum/rlp"
    27  )
    28  
    29  // Prove constructs a merkle proof for key. The result contains all
    30  // encoded nodes on the path to the value at key. The value itself is
    31  // also included in the last node and can be retrieved by verifying
    32  // the proof.
    33  //
    34  // If the trie does not contain a value for key, the returned proof
    35  // contains all nodes of the longest existing prefix of the key
    36  // (at least the root node), ending with the node that proves the
    37  // absence of the key.
    38  func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
    39  	// Collect all nodes on the path to key.
    40  	key = keybytesToHex(key)
    41  	nodes := []node{}
    42  	tn := t.root
    43  	for len(key) > 0 && tn != nil {
    44  		switch n := tn.(type) {
    45  		case *shortNode:
    46  			if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
    47  				// The trie doesn't contain the key.
    48  				tn = nil
    49  			} else {
    50  				tn = n.Val
    51  				key = key[len(n.Key):]
    52  			}
    53  			nodes = append(nodes, n)
    54  		case *fullNode:
    55  			tn = n.Children[key[0]]
    56  			key = key[1:]
    57  			nodes = append(nodes, n)
    58  		case hashNode:
    59  			var err error
    60  			tn, err = t.resolveHash(n, nil)
    61  			if err != nil {
    62  				log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
    63  				return err
    64  			}
    65  		default:
    66  			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
    67  		}
    68  	}
    69  	hasher := newHasher(0, 0)
    70  	for i, n := range nodes {
    71  		// Don't bother checking for errors here since hasher panics
    72  		// if encoding doesn't work and we're not writing to any database.
    73  		n, _, _ = hasher.hashChildren(n, nil)
    74  		hn, _ := hasher.store(n, nil, false)
    75  		if hash, ok := hn.(hashNode); ok || i == 0 {
    76  			// If the node's database encoding is a hash (or is the
    77  			// root node), it becomes a proof element.
    78  			if fromLevel > 0 {
    79  				fromLevel--
    80  			} else {
    81  				enc, _ := rlp.EncodeToBytes(n)
    82  				if !ok {
    83  					hash = crypto.Keccak256(enc)
    84  				}
    85  				proofDb.Put(hash, enc)
    86  			}
    87  		}
    88  	}
    89  	return nil
    90  }
    91  
    92  // VerifyProof checks merkle proofs. The given proof must contain the
    93  // value for key in a trie with the given root hash. VerifyProof
    94  // returns an error if the proof contains invalid trie nodes or the
    95  // wrong value.
    96  func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) {
    97  	key = keybytesToHex(key)
    98  	wantHash := rootHash[:]
    99  	for i := 0; ; i++ {
   100  		buf, _ := proofDb.Get(wantHash)
   101  		if buf == nil {
   102  			return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash[:]), i
   103  		}
   104  		n, err := decodeNode(wantHash, buf, 0)
   105  		if err != nil {
   106  			return nil, fmt.Errorf("bad proof node %d: %v", i, err), i
   107  		}
   108  		keyrest, cld := get(n, key)
   109  		switch cld := cld.(type) {
   110  		case nil:
   111  			// The trie doesn't contain the key.
   112  			return nil, nil, i
   113  		case hashNode:
   114  			key = keyrest
   115  			wantHash = cld
   116  		case valueNode:
   117  			return cld, nil, i + 1
   118  		}
   119  	}
   120  }
   121  
   122  func get(tn node, key []byte) ([]byte, node) {
   123  	for {
   124  		switch n := tn.(type) {
   125  		case *shortNode:
   126  			if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
   127  				return nil, nil
   128  			}
   129  			tn = n.Val
   130  			key = key[len(n.Key):]
   131  		case *fullNode:
   132  			tn = n.Children[key[0]]
   133  			key = key[1:]
   134  		case hashNode:
   135  			return key, n
   136  		case nil:
   137  			return key, nil
   138  		case valueNode:
   139  			return nil, n
   140  		default:
   141  			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
   142  		}
   143  	}
   144  }