github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/trie/proof.go (about)

     1  package trie
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  
     7  	"github.com/neatlab/neatio/chain/log"
     8  	"github.com/neatlab/neatio/neatdb"
     9  	"github.com/neatlab/neatio/utilities/common"
    10  	"github.com/neatlab/neatio/utilities/rlp"
    11  )
    12  
    13  func (t *Trie) Prove(key []byte, fromLevel uint, proofDb neatdb.Writer) error {
    14  	key = keybytesToHex(key)
    15  	var nodes []node
    16  	tn := t.root
    17  	for len(key) > 0 && tn != nil {
    18  		switch n := tn.(type) {
    19  		case *shortNode:
    20  			if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
    21  				tn = nil
    22  			} else {
    23  				tn = n.Val
    24  				key = key[len(n.Key):]
    25  			}
    26  			nodes = append(nodes, n)
    27  		case *fullNode:
    28  			tn = n.Children[key[0]]
    29  			key = key[1:]
    30  			nodes = append(nodes, n)
    31  		case hashNode:
    32  			var err error
    33  			tn, err = t.resolveHash(n, nil)
    34  			if err != nil {
    35  				log.Error(fmt.Sprintf("Unhandled trie error: %v", err))
    36  				return err
    37  			}
    38  		default:
    39  			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
    40  		}
    41  	}
    42  	hasher := newHasher(nil)
    43  	defer returnHasherToPool(hasher)
    44  
    45  	for i, n := range nodes {
    46  		n, _, _ = hasher.hashChildren(n, nil)
    47  		hn, _ := hasher.store(n, nil, false)
    48  		if hash, ok := hn.(hashNode); ok || i == 0 {
    49  			if fromLevel > 0 {
    50  				fromLevel--
    51  			} else {
    52  				enc, _ := rlp.EncodeToBytes(n)
    53  				if !ok {
    54  					hash = hasher.makeHashNode(enc)
    55  				}
    56  				proofDb.Put(hash, enc)
    57  			}
    58  		}
    59  	}
    60  	return nil
    61  }
    62  
    63  func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb neatdb.Writer) error {
    64  	return t.trie.Prove(key, fromLevel, proofDb)
    65  }
    66  
    67  func VerifyProof(rootHash common.Hash, key []byte, proofDb neatdb.Reader) (value []byte, nodes int, err error) {
    68  	key = keybytesToHex(key)
    69  	wantHash := rootHash
    70  	for i := 0; ; i++ {
    71  		buf, _ := proofDb.Get(wantHash[:])
    72  		if buf == nil {
    73  			return nil, i, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash)
    74  		}
    75  		n, err := decodeNode(wantHash[:], buf)
    76  		if err != nil {
    77  			return nil, i, fmt.Errorf("bad proof node %d: %v", i, err)
    78  		}
    79  		keyrest, cld := get(n, key)
    80  		switch cld := cld.(type) {
    81  		case nil:
    82  			return nil, i, nil
    83  		case hashNode:
    84  			key = keyrest
    85  			copy(wantHash[:], cld)
    86  		case valueNode:
    87  			return cld, i + 1, nil
    88  		}
    89  	}
    90  }
    91  
    92  func get(tn node, key []byte) ([]byte, node) {
    93  	for {
    94  		switch n := tn.(type) {
    95  		case *shortNode:
    96  			if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
    97  				return nil, nil
    98  			}
    99  			tn = n.Val
   100  			key = key[len(n.Key):]
   101  		case *fullNode:
   102  			tn = n.Children[key[0]]
   103  			key = key[1:]
   104  		case hashNode:
   105  			return key, n
   106  		case nil:
   107  			return key, nil
   108  		case valueNode:
   109  			return nil, n
   110  		default:
   111  			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
   112  		}
   113  	}
   114  }