github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/trie/trie.go (about) 1 package trie 2 3 import ( 4 "bytes" 5 "fmt" 6 7 "github.com/neatlab/neatio/chain/log" 8 "github.com/neatlab/neatio/utilities/common" 9 "github.com/neatlab/neatio/utilities/crypto" 10 ) 11 12 var ( 13 emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421") 14 15 emptyState = crypto.Keccak256Hash(nil) 16 ) 17 18 type LeafCallback func(leaf []byte, parent common.Hash) error 19 20 type Trie struct { 21 db *Database 22 root node 23 } 24 25 func (t *Trie) newFlag() nodeFlag { 26 return nodeFlag{dirty: true} 27 } 28 29 func New(root common.Hash, db *Database) (*Trie, error) { 30 if db == nil { 31 panic("trie.New called without a database") 32 } 33 trie := &Trie{ 34 db: db, 35 } 36 if root != (common.Hash{}) && root != emptyRoot { 37 rootnode, err := trie.resolveHash(root[:], nil) 38 if err != nil { 39 return nil, err 40 } 41 trie.root = rootnode 42 } 43 return trie, nil 44 } 45 46 func (t *Trie) NodeIterator(start []byte) NodeIterator { 47 return newNodeIterator(t, start) 48 } 49 50 func (t *Trie) Get(key []byte) []byte { 51 res, err := t.TryGet(key) 52 if err != nil { 53 log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) 54 } 55 return res 56 } 57 58 func (t *Trie) TryGet(key []byte) ([]byte, error) { 59 key = keybytesToHex(key) 60 value, newroot, didResolve, err := t.tryGet(t.root, key, 0) 61 if err == nil && didResolve { 62 t.root = newroot 63 } 64 return value, err 65 } 66 67 func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) { 68 switch n := (origNode).(type) { 69 case nil: 70 return nil, nil, false, nil 71 case valueNode: 72 return n, n, false, nil 73 case *shortNode: 74 if len(key)-pos < len(n.Key) || !bytes.Equal(n.Key, key[pos:pos+len(n.Key)]) { 75 76 return nil, n, false, nil 77 } 78 value, newnode, didResolve, err = t.tryGet(n.Val, key, pos+len(n.Key)) 79 if err == nil && didResolve { 80 n = n.copy() 81 n.Val = newnode 82 } 83 return value, n, didResolve, err 84 case *fullNode: 85 value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1) 86 if err == nil && didResolve { 87 n = n.copy() 88 n.Children[key[pos]] = newnode 89 } 90 return value, n, didResolve, err 91 case hashNode: 92 side, err := t.resolveHash(n, key[:pos]) 93 if err != nil { 94 return nil, n, true, err 95 } 96 value, newnode, _, err := t.tryGet(side, key, pos) 97 return value, newnode, true, err 98 default: 99 panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) 100 } 101 } 102 103 func (t *Trie) Update(key, value []byte) { 104 if err := t.TryUpdate(key, value); err != nil { 105 log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) 106 } 107 } 108 109 func (t *Trie) TryUpdate(key, value []byte) error { 110 k := keybytesToHex(key) 111 if len(value) != 0 { 112 _, n, err := t.insert(t.root, nil, k, valueNode(value)) 113 if err != nil { 114 return err 115 } 116 t.root = n 117 } else { 118 _, n, err := t.delete(t.root, nil, k) 119 if err != nil { 120 return err 121 } 122 t.root = n 123 } 124 return nil 125 } 126 127 func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error) { 128 if len(key) == 0 { 129 if v, ok := n.(valueNode); ok { 130 return !bytes.Equal(v, value.(valueNode)), value, nil 131 } 132 return true, value, nil 133 } 134 switch n := n.(type) { 135 case *shortNode: 136 matchlen := prefixLen(key, n.Key) 137 138 if matchlen == len(n.Key) { 139 dirty, nn, err := t.insert(n.Val, append(prefix, key[:matchlen]...), key[matchlen:], value) 140 if !dirty || err != nil { 141 return false, n, err 142 } 143 return true, &shortNode{n.Key, nn, t.newFlag()}, nil 144 } 145 146 branch := &fullNode{flags: t.newFlag()} 147 var err error 148 _, branch.Children[n.Key[matchlen]], err = t.insert(nil, append(prefix, n.Key[:matchlen+1]...), n.Key[matchlen+1:], n.Val) 149 if err != nil { 150 return false, nil, err 151 } 152 _, branch.Children[key[matchlen]], err = t.insert(nil, append(prefix, key[:matchlen+1]...), key[matchlen+1:], value) 153 if err != nil { 154 return false, nil, err 155 } 156 157 if matchlen == 0 { 158 return true, branch, nil 159 } 160 161 return true, &shortNode{key[:matchlen], branch, t.newFlag()}, nil 162 163 case *fullNode: 164 dirty, nn, err := t.insert(n.Children[key[0]], append(prefix, key[0]), key[1:], value) 165 if !dirty || err != nil { 166 return false, n, err 167 } 168 n = n.copy() 169 n.flags = t.newFlag() 170 n.Children[key[0]] = nn 171 return true, n, nil 172 173 case nil: 174 return true, &shortNode{key, value, t.newFlag()}, nil 175 176 case hashNode: 177 178 rn, err := t.resolveHash(n, prefix) 179 if err != nil { 180 return false, nil, err 181 } 182 dirty, nn, err := t.insert(rn, prefix, key, value) 183 if !dirty || err != nil { 184 return false, rn, err 185 } 186 return true, nn, nil 187 188 default: 189 panic(fmt.Sprintf("%T: invalid node: %v", n, n)) 190 } 191 } 192 193 func (t *Trie) Delete(key []byte) { 194 if err := t.TryDelete(key); err != nil { 195 log.Error(fmt.Sprintf("Unhandled trie error: %v", err)) 196 } 197 } 198 199 func (t *Trie) TryDelete(key []byte) error { 200 k := keybytesToHex(key) 201 _, n, err := t.delete(t.root, nil, k) 202 if err != nil { 203 return err 204 } 205 t.root = n 206 return nil 207 } 208 209 func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) { 210 switch n := n.(type) { 211 case *shortNode: 212 matchlen := prefixLen(key, n.Key) 213 if matchlen < len(n.Key) { 214 return false, n, nil 215 } 216 if matchlen == len(key) { 217 return true, nil, nil 218 } 219 220 dirty, side, err := t.delete(n.Val, append(prefix, key[:len(n.Key)]...), key[len(n.Key):]) 221 if !dirty || err != nil { 222 return false, n, err 223 } 224 switch side := side.(type) { 225 case *shortNode: 226 227 return true, &shortNode{concat(n.Key, side.Key...), side.Val, t.newFlag()}, nil 228 default: 229 return true, &shortNode{n.Key, side, t.newFlag()}, nil 230 } 231 232 case *fullNode: 233 dirty, nn, err := t.delete(n.Children[key[0]], append(prefix, key[0]), key[1:]) 234 if !dirty || err != nil { 235 return false, n, err 236 } 237 n = n.copy() 238 n.flags = t.newFlag() 239 n.Children[key[0]] = nn 240 241 pos := -1 242 for i, cld := range &n.Children { 243 if cld != nil { 244 if pos == -1 { 245 pos = i 246 } else { 247 pos = -2 248 break 249 } 250 } 251 } 252 if pos >= 0 { 253 if pos != 16 { 254 255 cnode, err := t.resolve(n.Children[pos], prefix) 256 if err != nil { 257 return false, nil, err 258 } 259 if cnode, ok := cnode.(*shortNode); ok { 260 k := append([]byte{byte(pos)}, cnode.Key...) 261 return true, &shortNode{k, cnode.Val, t.newFlag()}, nil 262 } 263 } 264 265 return true, &shortNode{[]byte{byte(pos)}, n.Children[pos], t.newFlag()}, nil 266 } 267 268 return true, n, nil 269 270 case valueNode: 271 return true, nil, nil 272 273 case nil: 274 return false, nil, nil 275 276 case hashNode: 277 278 rn, err := t.resolveHash(n, prefix) 279 if err != nil { 280 return false, nil, err 281 } 282 dirty, nn, err := t.delete(rn, prefix, key) 283 if !dirty || err != nil { 284 return false, rn, err 285 } 286 return true, nn, nil 287 288 default: 289 panic(fmt.Sprintf("%T: invalid node: %v (%v)", n, n, key)) 290 } 291 } 292 293 func concat(s1 []byte, s2 ...byte) []byte { 294 r := make([]byte, len(s1)+len(s2)) 295 copy(r, s1) 296 copy(r[len(s1):], s2) 297 return r 298 } 299 300 func (t *Trie) resolve(n node, prefix []byte) (node, error) { 301 if n, ok := n.(hashNode); ok { 302 return t.resolveHash(n, prefix) 303 } 304 return n, nil 305 } 306 307 func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) { 308 hash := common.BytesToHash(n) 309 if node := t.db.node(hash); node != nil { 310 return node, nil 311 } 312 return nil, &MissingNodeError{NodeHash: hash, Path: prefix} 313 } 314 315 func (t *Trie) Hash() common.Hash { 316 hash, cached, _ := t.hashRoot(nil, nil) 317 t.root = cached 318 return common.BytesToHash(hash.(hashNode)) 319 } 320 321 func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) { 322 if t.db == nil { 323 panic("commit called on trie with nil database") 324 } 325 hash, cached, err := t.hashRoot(t.db, onleaf) 326 if err != nil { 327 return common.Hash{}, err 328 } 329 t.root = cached 330 return common.BytesToHash(hash.(hashNode)), nil 331 } 332 333 func (t *Trie) hashRoot(db *Database, onleaf LeafCallback) (node, node, error) { 334 if t.root == nil { 335 return hashNode(emptyRoot.Bytes()), nil, nil 336 } 337 h := newHasher(onleaf) 338 defer returnHasherToPool(h) 339 return h.hash(t.root, db, true) 340 }