github.com/SmartMeshFoundation/Spectrum@v0.0.0-20220621030607-452a266fee1e/trie/proof.go (about) 1 // Copyright 2015 The Spectrum Authors 2 // This file is part of the Spectrum library. 3 // 4 // The Spectrum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The Spectrum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the Spectrum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package trie 18 19 import ( 20 "bytes" 21 "fmt" 22 23 "github.com/SmartMeshFoundation/Spectrum/common" 24 "github.com/SmartMeshFoundation/Spectrum/crypto" 25 "github.com/SmartMeshFoundation/Spectrum/log" 26 "github.com/SmartMeshFoundation/Spectrum/rlp" 27 ) 28 29 // Prove constructs a merkle proof for key. The result contains all 30 // encoded nodes on the path to the value at key. The value itself is 31 // also included in the last node and can be retrieved by verifying 32 // the proof. 33 // 34 // If the trie does not contain a value for key, the returned proof 35 // contains all nodes of the longest existing prefix of the key 36 // (at least the root node), ending with the node that proves the 37 // absence of the key. 38 func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error { 39 // Collect all nodes on the path to key. 40 key = keybytesToHex(key) 41 nodes := []node{} 42 tn := t.root 43 for len(key) > 0 && tn != nil { 44 switch n := tn.(type) { 45 case *shortNode: 46 if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { 47 // The trie doesn't contain the key. 48 tn = nil 49 } else { 50 tn = n.Val 51 key = key[len(n.Key):] 52 } 53 nodes = append(nodes, n) 54 case *fullNode: 55 tn = n.Children[key[0]] 56 key = key[1:] 57 nodes = append(nodes, n) 58 case hashNode: 59 var err error 60 tn, err = t.resolveHash(n, nil) 61 if err != nil { 62 log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) 63 return err 64 } 65 default: 66 panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) 67 } 68 } 69 hasher := newHasher(0, 0) 70 for i, n := range nodes { 71 // Don't bother checking for errors here since hasher panics 72 // if encoding doesn't work and we're not writing to any database. 73 n, _, _ = hasher.hashChildren(n, nil) 74 hn, _ := hasher.store(n, nil, false) 75 if hash, ok := hn.(hashNode); ok || i == 0 { 76 // If the node's database encoding is a hash (or is the 77 // root node), it becomes a proof element. 78 if fromLevel > 0 { 79 fromLevel-- 80 } else { 81 enc, _ := rlp.EncodeToBytes(n) 82 if !ok { 83 hash = crypto.Keccak256(enc) 84 } 85 proofDb.Put(hash, enc) 86 } 87 } 88 } 89 return nil 90 } 91 92 // VerifyProof checks merkle proofs. The given proof must contain the 93 // value for key in a trie with the given root hash. VerifyProof 94 // returns an error if the proof contains invalid trie nodes or the 95 // wrong value. 96 func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) { 97 key = keybytesToHex(key) 98 wantHash := rootHash[:] 99 for i := 0; ; i++ { 100 buf, _ := proofDb.Get(wantHash) 101 if buf == nil { 102 return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash[:]), i 103 } 104 n, err := decodeNode(wantHash, buf, 0) 105 if err != nil { 106 return nil, fmt.Errorf("bad proof node %d: %v", i, err), i 107 } 108 keyrest, cld := get(n, key) 109 switch cld := cld.(type) { 110 case nil: 111 // The trie doesn't contain the key. 112 return nil, nil, i 113 case hashNode: 114 key = keyrest 115 wantHash = cld 116 case valueNode: 117 return cld, nil, i + 1 118 } 119 } 120 } 121 122 func get(tn node, key []byte) ([]byte, node) { 123 for { 124 switch n := tn.(type) { 125 case *shortNode: 126 if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { 127 return nil, nil 128 } 129 tn = n.Val 130 key = key[len(n.Key):] 131 case *fullNode: 132 tn = n.Children[key[0]] 133 key = key[1:] 134 case hashNode: 135 return key, n 136 case nil: 137 return key, nil 138 case valueNode: 139 return nil, n 140 default: 141 panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) 142 } 143 } 144 }