github.com/avence12/go-ethereum@v1.5.10-0.20170320123548-1dfd65f6d047/trie/proof.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package trie
    18  
    19  import (
    20  	"bytes"
    21  	"errors"
    22  	"fmt"
    23  
    24  	"github.com/ethereum/go-ethereum/common"
    25  	"github.com/ethereum/go-ethereum/crypto/sha3"
    26  	"github.com/ethereum/go-ethereum/log"
    27  	"github.com/ethereum/go-ethereum/rlp"
    28  )
    29  
    30  // Prove constructs a merkle proof for key. The result contains all
    31  // encoded nodes on the path to the value at key. The value itself is
    32  // also included in the last node and can be retrieved by verifying
    33  // the proof.
    34  //
    35  // If the trie does not contain a value for key, the returned proof
    36  // contains all nodes of the longest existing prefix of the key
    37  // (at least the root node), ending with the node that proves the
    38  // absence of the key.
    39  func (t *Trie) Prove(key []byte) []rlp.RawValue {
    40  	// Collect all nodes on the path to key.
    41  	key = compactHexDecode(key)
    42  	nodes := []node{}
    43  	tn := t.root
    44  	for len(key) > 0 && tn != nil {
    45  		switch n := tn.(type) {
    46  		case *shortNode:
    47  			if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
    48  				// The trie doesn't contain the key.
    49  				tn = nil
    50  			} else {
    51  				tn = n.Val
    52  				key = key[len(n.Key):]
    53  			}
    54  			nodes = append(nodes, n)
    55  		case *fullNode:
    56  			tn = n.Children[key[0]]
    57  			key = key[1:]
    58  			nodes = append(nodes, n)
    59  		case hashNode:
    60  			var err error
    61  			tn, err = t.resolveHash(n, nil, nil)
    62  			if err != nil {
    63  				log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
    64  				return nil
    65  			}
    66  		default:
    67  			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
    68  		}
    69  	}
    70  	hasher := newHasher(0, 0)
    71  	proof := make([]rlp.RawValue, 0, len(nodes))
    72  	for i, n := range nodes {
    73  		// Don't bother checking for errors here since hasher panics
    74  		// if encoding doesn't work and we're not writing to any database.
    75  		n, _, _ = hasher.hashChildren(n, nil)
    76  		hn, _ := hasher.store(n, nil, false)
    77  		if _, ok := hn.(hashNode); ok || i == 0 {
    78  			// If the node's database encoding is a hash (or is the
    79  			// root node), it becomes a proof element.
    80  			enc, _ := rlp.EncodeToBytes(n)
    81  			proof = append(proof, enc)
    82  		}
    83  	}
    84  	return proof
    85  }
    86  
    87  // VerifyProof checks merkle proofs. The given proof must contain the
    88  // value for key in a trie with the given root hash. VerifyProof
    89  // returns an error if the proof contains invalid trie nodes or the
    90  // wrong value.
    91  func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) {
    92  	key = compactHexDecode(key)
    93  	sha := sha3.NewKeccak256()
    94  	wantHash := rootHash.Bytes()
    95  	for i, buf := range proof {
    96  		sha.Reset()
    97  		sha.Write(buf)
    98  		if !bytes.Equal(sha.Sum(nil), wantHash) {
    99  			return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
   100  		}
   101  		n, err := decodeNode(wantHash, buf, 0)
   102  		if err != nil {
   103  			return nil, fmt.Errorf("bad proof node %d: %v", i, err)
   104  		}
   105  		keyrest, cld := get(n, key)
   106  		switch cld := cld.(type) {
   107  		case nil:
   108  			if i != len(proof)-1 {
   109  				return nil, fmt.Errorf("key mismatch at proof node %d", i)
   110  			} else {
   111  				// The trie doesn't contain the key.
   112  				return nil, nil
   113  			}
   114  		case hashNode:
   115  			key = keyrest
   116  			wantHash = cld
   117  		case valueNode:
   118  			if i != len(proof)-1 {
   119  				return nil, errors.New("additional nodes at end of proof")
   120  			}
   121  			return cld, nil
   122  		}
   123  	}
   124  	return nil, errors.New("unexpected end of proof")
   125  }
   126  
   127  func get(tn node, key []byte) ([]byte, node) {
   128  	for len(key) > 0 {
   129  		switch n := tn.(type) {
   130  		case *shortNode:
   131  			if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
   132  				return nil, nil
   133  			}
   134  			tn = n.Val
   135  			key = key[len(n.Key):]
   136  		case *fullNode:
   137  			tn = n.Children[key[0]]
   138  			key = key[1:]
   139  		case hashNode:
   140  			return key, n
   141  		case nil:
   142  			return key, nil
   143  		default:
   144  			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
   145  		}
   146  	}
   147  	return nil, tn.(valueNode)
   148  }