github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/trie/proof.go (about) 1 package trie 2 3 import ( 4 "bytes" 5 "fmt" 6 7 "github.com/neatlab/neatio/chain/log" 8 "github.com/neatlab/neatio/neatdb" 9 "github.com/neatlab/neatio/utilities/common" 10 "github.com/neatlab/neatio/utilities/rlp" 11 ) 12 13 func (t *Trie) Prove(key []byte, fromLevel uint, proofDb neatdb.Writer) error { 14 key = keybytesToHex(key) 15 var nodes []node 16 tn := t.root 17 for len(key) > 0 && tn != nil { 18 switch n := tn.(type) { 19 case *shortNode: 20 if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { 21 tn = nil 22 } else { 23 tn = n.Val 24 key = key[len(n.Key):] 25 } 26 nodes = append(nodes, n) 27 case *fullNode: 28 tn = n.Children[key[0]] 29 key = key[1:] 30 nodes = append(nodes, n) 31 case hashNode: 32 var err error 33 tn, err = t.resolveHash(n, nil) 34 if err != nil { 35 log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) 36 return err 37 } 38 default: 39 panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) 40 } 41 } 42 hasher := newHasher(nil) 43 defer returnHasherToPool(hasher) 44 45 for i, n := range nodes { 46 n, _, _ = hasher.hashChildren(n, nil) 47 hn, _ := hasher.store(n, nil, false) 48 if hash, ok := hn.(hashNode); ok || i == 0 { 49 if fromLevel > 0 { 50 fromLevel-- 51 } else { 52 enc, _ := rlp.EncodeToBytes(n) 53 if !ok { 54 hash = hasher.makeHashNode(enc) 55 } 56 proofDb.Put(hash, enc) 57 } 58 } 59 } 60 return nil 61 } 62 63 func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb neatdb.Writer) error { 64 return t.trie.Prove(key, fromLevel, proofDb) 65 } 66 67 func VerifyProof(rootHash common.Hash, key []byte, proofDb neatdb.Reader) (value []byte, nodes int, err error) { 68 key = keybytesToHex(key) 69 wantHash := rootHash 70 for i := 0; ; i++ { 71 buf, _ := proofDb.Get(wantHash[:]) 72 if buf == nil { 73 return nil, i, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash) 74 } 75 n, err := decodeNode(wantHash[:], buf) 76 if err != nil { 77 return nil, i, fmt.Errorf("bad proof node %d: %v", i, err) 78 } 79 keyrest, cld := get(n, key) 80 switch cld := cld.(type) { 81 case nil: 82 return nil, i, nil 83 case hashNode: 84 key = keyrest 85 copy(wantHash[:], cld) 86 case valueNode: 87 return cld, i + 1, nil 88 } 89 } 90 } 91 92 func get(tn node, key []byte) ([]byte, node) { 93 for { 94 switch n := tn.(type) { 95 case *shortNode: 96 if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) { 97 return nil, nil 98 } 99 tn = n.Val 100 key = key[len(n.Key):] 101 case *fullNode: 102 tn = n.Children[key[0]] 103 key = key[1:] 104 case hashNode: 105 return key, n 106 case nil: 107 return key, nil 108 case valueNode: 109 return nil, n 110 default: 111 panic(fmt.Sprintf("%T: invalid node: %v", tn, tn)) 112 } 113 } 114 }