github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/core/mpt/trie.go (about) 1 package mpt 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "fmt" 8 9 "github.com/nspcc-dev/neo-go/pkg/core/storage" 10 "github.com/nspcc-dev/neo-go/pkg/io" 11 "github.com/nspcc-dev/neo-go/pkg/util" 12 ) 13 14 // TrieMode is the storage mode of a trie, it affects the DB scheme. 15 type TrieMode byte 16 17 // TrieMode is the storage mode of a trie. 18 const ( 19 // ModeAll is used to store everything. 20 ModeAll TrieMode = 0 21 // ModeLatest is used to only store the latest root. 22 ModeLatest TrieMode = 0x01 23 // ModeGCFlag is a flag for GC. 24 ModeGCFlag TrieMode = 0x02 25 // ModeGC is used to store a set of roots with GC possible, it combines 26 // GCFlag and Latest (because it needs RC, but it has GC enabled). 27 ModeGC TrieMode = 0x03 28 ) 29 30 // Trie is an MPT trie storing all key-value pairs. 31 type Trie struct { 32 Store *storage.MemCachedStore 33 34 root Node 35 mode TrieMode 36 refcount map[util.Uint256]*cachedNode 37 } 38 39 type cachedNode struct { 40 bytes []byte 41 initial int32 42 refcount int32 43 } 44 45 // ErrNotFound is returned when the requested trie item is missing. 46 var ErrNotFound = errors.New("item not found") 47 48 // RC returns true when reference counting is enabled. 49 func (m TrieMode) RC() bool { 50 return m&ModeLatest != 0 51 } 52 53 // GC returns true when garbage collection is enabled. 54 func (m TrieMode) GC() bool { 55 return m&ModeGCFlag != 0 56 } 57 58 // NewTrie returns a new MPT trie. It accepts a MemCachedStore to decouple storage errors from logic errors, 59 // so that all storage errors are processed during `store.Persist()` at the caller. 60 // Another benefit is that every `Put` can be considered an atomic operation. 61 func NewTrie(root Node, mode TrieMode, store *storage.MemCachedStore) *Trie { 62 if root == nil { 63 root = EmptyNode{} 64 } 65 66 return &Trie{ 67 Store: store, 68 root: root, 69 70 mode: mode, 71 refcount: make(map[util.Uint256]*cachedNode), 72 } 73 } 74 75 // Get returns the value for the provided key in t. 76 func (t *Trie) Get(key []byte) ([]byte, error) { 77 if len(key) > MaxKeyLength { 78 return nil, errors.New("key is too big") 79 } 80 path := toNibbles(key) 81 r, leaf, _, err := t.getWithPath(t.root, path, true) 82 if err != nil { 83 return nil, err 84 } 85 t.root = r 86 return bytes.Clone(leaf.(*LeafNode).value), nil 87 } 88 89 // getWithPath returns the current node with all hash nodes along the path replaced 90 // with their "unhashed" counterparts. It also returns node which the provided path in a 91 // subtrie rooting in curr points to. In case of `strict` set to `false`, the 92 // provided path can be incomplete, so it also returns the full path that points to 93 // the node found at the specified incomplete path. In case of `strict` set to `true`, 94 // the resulting path matches the provided one. 95 func (t *Trie) getWithPath(curr Node, path []byte, strict bool) (Node, Node, []byte, error) { 96 switch n := curr.(type) { 97 case *LeafNode: 98 if len(path) == 0 { 99 return curr, n, []byte{}, nil 100 } 101 case *BranchNode: 102 i, path := splitPath(path) 103 if i == lastChild && !strict { 104 return curr, n, []byte{}, nil 105 } 106 r, res, prefix, err := t.getWithPath(n.Children[i], path, strict) 107 if err != nil { 108 return nil, nil, nil, err 109 } 110 n.Children[i] = r 111 return n, res, append([]byte{i}, prefix...), nil 112 case EmptyNode: 113 case *HashNode: 114 if r, err := t.getFromStore(n.hash); err == nil { 115 return t.getWithPath(r, path, strict) 116 } 117 case *ExtensionNode: 118 if len(path) == 0 && !strict { 119 return curr, n.next, n.key, nil 120 } 121 if bytes.HasPrefix(path, n.key) { 122 r, res, prefix, err := t.getWithPath(n.next, path[len(n.key):], strict) 123 if err != nil { 124 return nil, nil, nil, err 125 } 126 n.next = r 127 return curr, res, append(n.key, prefix...), err 128 } 129 if !strict && bytes.HasPrefix(n.key, path) { 130 // path is shorter than prefix, stop seeking 131 return curr, n.next, n.key, nil 132 } 133 default: 134 panic("invalid MPT node type") 135 } 136 return curr, nil, nil, ErrNotFound 137 } 138 139 // Put puts key-value pair in t. 140 func (t *Trie) Put(key, value []byte) error { 141 if len(key) == 0 { 142 return errors.New("key is empty") 143 } else if len(key) > MaxKeyLength { 144 return errors.New("key is too big") 145 } else if len(value) > MaxValueLength { 146 return errors.New("value is too big") 147 } else if value == nil { 148 // (t *Trie).Delete should be used to remove value 149 return errors.New("value is nil") 150 } 151 path := toNibbles(key) 152 n := NewLeafNode(value) 153 r, err := t.putIntoNode(t.root, path, n) 154 if err != nil { 155 return err 156 } 157 t.root = r 158 return nil 159 } 160 161 // putIntoLeaf puts the val to the trie if the current node is a Leaf. 162 // It returns a Node if curr needs to be replaced and an error has occurred, if any. 163 func (t *Trie) putIntoLeaf(curr *LeafNode, path []byte, val Node) (Node, error) { 164 v := val.(*LeafNode) 165 if len(path) == 0 { 166 t.removeRef(curr.Hash(), curr.bytes) 167 t.addRef(val.Hash(), val.Bytes()) 168 return v, nil 169 } 170 171 b := NewBranchNode() 172 b.Children[path[0]] = t.newSubTrie(path[1:], v, true) 173 b.Children[lastChild] = curr 174 t.addRef(b.Hash(), b.bytes) 175 return b, nil 176 } 177 178 // putIntoBranch puts the val to the trie if the current node is a Branch. 179 // It returns the Node if curr needs to be replaced and an error has occurred, if any. 180 func (t *Trie) putIntoBranch(curr *BranchNode, path []byte, val Node) (Node, error) { 181 i, path := splitPath(path) 182 t.removeRef(curr.Hash(), curr.bytes) 183 r, err := t.putIntoNode(curr.Children[i], path, val) 184 if err != nil { 185 return nil, err 186 } 187 curr.Children[i] = r 188 curr.invalidateCache() 189 t.addRef(curr.Hash(), curr.bytes) 190 return curr, nil 191 } 192 193 // putIntoExtension puts the val to the trie if the current node is an Extension. 194 // It returns the Node if curr needs to be replaced and an error has occurred, if any. 195 func (t *Trie) putIntoExtension(curr *ExtensionNode, path []byte, val Node) (Node, error) { 196 t.removeRef(curr.Hash(), curr.bytes) 197 if bytes.HasPrefix(path, curr.key) { 198 r, err := t.putIntoNode(curr.next, path[len(curr.key):], val) 199 if err != nil { 200 return nil, err 201 } 202 curr.next = r 203 curr.invalidateCache() 204 t.addRef(curr.Hash(), curr.bytes) 205 return curr, nil 206 } 207 208 pref := lcp(curr.key, path) 209 lp := len(pref) 210 keyTail := curr.key[lp:] 211 pathTail := path[lp:] 212 213 s1 := t.newSubTrie(keyTail[1:], curr.next, false) 214 b := NewBranchNode() 215 b.Children[keyTail[0]] = s1 216 217 i, pathTail := splitPath(pathTail) 218 s2 := t.newSubTrie(pathTail, val, true) 219 b.Children[i] = s2 220 221 t.addRef(b.Hash(), b.bytes) 222 if lp > 0 { 223 e := NewExtensionNode(bytes.Clone(pref), b) 224 t.addRef(e.Hash(), e.bytes) 225 return e, nil 226 } 227 return b, nil 228 } 229 230 func (t *Trie) putIntoEmpty(path []byte, val Node) (Node, error) { 231 return t.newSubTrie(path, val, true), nil 232 } 233 234 // putIntoHash puts the val to the trie if the current node is a HashNode. 235 // It returns the Node if curr needs to be replaced and an error has occurred, if any. 236 func (t *Trie) putIntoHash(curr *HashNode, path []byte, val Node) (Node, error) { 237 result, err := t.getFromStore(curr.hash) 238 if err != nil { 239 return nil, err 240 } 241 return t.putIntoNode(result, path, val) 242 } 243 244 // newSubTrie creates a new trie containing the node at the provided path. 245 func (t *Trie) newSubTrie(path []byte, val Node, newVal bool) Node { 246 if newVal { 247 t.addRef(val.Hash(), val.Bytes()) 248 } 249 if len(path) == 0 { 250 return val 251 } 252 e := NewExtensionNode(path, val) 253 t.addRef(e.Hash(), e.bytes) 254 return e 255 } 256 257 // putIntoNode puts the val with the provided path inside curr and returns an updated node. 258 // Reference counters are updated for both curr and returned value. 259 func (t *Trie) putIntoNode(curr Node, path []byte, val Node) (Node, error) { 260 switch n := curr.(type) { 261 case *LeafNode: 262 return t.putIntoLeaf(n, path, val) 263 case *BranchNode: 264 return t.putIntoBranch(n, path, val) 265 case *ExtensionNode: 266 return t.putIntoExtension(n, path, val) 267 case *HashNode: 268 return t.putIntoHash(n, path, val) 269 case EmptyNode: 270 return t.putIntoEmpty(path, val) 271 default: 272 panic("invalid MPT node type") 273 } 274 } 275 276 // Delete removes the key from the trie. 277 // It returns no error on a missing key. 278 func (t *Trie) Delete(key []byte) error { 279 if len(key) > MaxKeyLength { 280 return errors.New("key is too big") 281 } 282 path := toNibbles(key) 283 r, err := t.deleteFromNode(t.root, path) 284 if err != nil { 285 return err 286 } 287 t.root = r 288 return nil 289 } 290 291 func (t *Trie) deleteFromBranch(b *BranchNode, path []byte) (Node, error) { 292 i, path := splitPath(path) 293 h := b.Hash() 294 bs := b.bytes 295 r, err := t.deleteFromNode(b.Children[i], path) 296 if err != nil { 297 return nil, err 298 } 299 t.removeRef(h, bs) 300 b.Children[i] = r 301 b.invalidateCache() 302 var count, index int 303 for i := range b.Children { 304 if !isEmpty(b.Children[i]) { 305 index = i 306 count++ 307 } 308 } 309 // count is >= 1 because branch node had at least 2 children before deletion. 310 if count > 1 { 311 t.addRef(b.Hash(), b.bytes) 312 return b, nil 313 } 314 c := b.Children[index] 315 if index == lastChild { 316 return c, nil 317 } 318 if h, ok := c.(*HashNode); ok { 319 c, err = t.getFromStore(h.Hash()) 320 if err != nil { 321 return nil, err 322 } 323 } 324 if e, ok := c.(*ExtensionNode); ok { 325 t.removeRef(e.Hash(), e.bytes) 326 e.key = append([]byte{byte(index)}, e.key...) 327 e.invalidateCache() 328 t.addRef(e.Hash(), e.bytes) 329 return e, nil 330 } 331 332 e := NewExtensionNode([]byte{byte(index)}, c) 333 t.addRef(e.Hash(), e.bytes) 334 return e, nil 335 } 336 337 func (t *Trie) deleteFromExtension(n *ExtensionNode, path []byte) (Node, error) { 338 if !bytes.HasPrefix(path, n.key) { 339 return n, nil 340 } 341 h := n.Hash() 342 bs := n.bytes 343 r, err := t.deleteFromNode(n.next, path[len(n.key):]) 344 if err != nil { 345 return nil, err 346 } 347 t.removeRef(h, bs) 348 switch nxt := r.(type) { 349 case *ExtensionNode: 350 t.removeRef(nxt.Hash(), nxt.bytes) 351 n.key = append(n.key, nxt.key...) 352 n.next = nxt.next 353 case EmptyNode: 354 return nxt, nil 355 case *HashNode: 356 n.next = nxt 357 default: 358 n.next = r 359 } 360 n.invalidateCache() 361 t.addRef(n.Hash(), n.bytes) 362 return n, nil 363 } 364 365 // deleteFromNode removes the value with the provided path from curr and returns an updated node. 366 // Reference counters are updated for both curr and returned value. 367 func (t *Trie) deleteFromNode(curr Node, path []byte) (Node, error) { 368 switch n := curr.(type) { 369 case *LeafNode: 370 if len(path) == 0 { 371 t.removeRef(curr.Hash(), curr.Bytes()) 372 return EmptyNode{}, nil 373 } 374 return curr, nil 375 case *BranchNode: 376 return t.deleteFromBranch(n, path) 377 case *ExtensionNode: 378 return t.deleteFromExtension(n, path) 379 case EmptyNode: 380 return n, nil 381 case *HashNode: 382 newNode, err := t.getFromStore(n.Hash()) 383 if err != nil { 384 return nil, err 385 } 386 return t.deleteFromNode(newNode, path) 387 default: 388 panic("invalid MPT node type") 389 } 390 } 391 392 // StateRoot returns root hash of t. 393 func (t *Trie) StateRoot() util.Uint256 { 394 if isEmpty(t.root) { 395 return util.Uint256{} 396 } 397 return t.root.Hash() 398 } 399 400 func makeStorageKey(mptKey util.Uint256) []byte { 401 return append([]byte{byte(storage.DataMPT)}, mptKey[:]...) 402 } 403 404 // Flush puts every node (except Hash ones) in the trie to the storage. 405 // Because we care about block-level changes only, there is no need to put every 406 // new node to the storage. Normally, flush should be called with every StateRoot persist, i.e. 407 // after every block. 408 func (t *Trie) Flush(index uint32) { 409 key := makeStorageKey(util.Uint256{}) 410 for h, node := range t.refcount { 411 if node.refcount != 0 { 412 copy(key[1:], h[:]) 413 if node.bytes == nil { 414 panic("item not in trie") 415 } 416 if t.mode.RC() { 417 node.initial = t.updateRefCount(h, key, index) 418 if node.initial == 0 { 419 delete(t.refcount, h) 420 } 421 } else if node.refcount > 0 { 422 t.Store.Put(key, node.bytes) 423 } 424 node.refcount = 0 425 } else { 426 delete(t.refcount, h) 427 } 428 } 429 } 430 431 func IsActiveValue(v []byte) bool { 432 return len(v) > 4 && v[len(v)-5] == 1 433 } 434 435 func getFromStore(key []byte, mode TrieMode, store *storage.MemCachedStore) ([]byte, error) { 436 data, err := store.Get(key) 437 if err == nil && mode.GC() && !IsActiveValue(data) { 438 return nil, storage.ErrKeyNotFound 439 } 440 return data, err 441 } 442 443 // updateRefCount should be called only when refcounting is enabled. 444 func (t *Trie) updateRefCount(h util.Uint256, key []byte, index uint32) int32 { 445 if !t.mode.RC() { 446 panic("`updateRefCount` is called, but GC is disabled") 447 } 448 var data []byte 449 node := t.refcount[h] 450 cnt := node.initial 451 if cnt == 0 { 452 // A newly created item which may be in store. 453 var err error 454 data, err = getFromStore(key, t.mode, t.Store) 455 if err == nil { 456 cnt = int32(binary.LittleEndian.Uint32(data[len(data)-4:])) 457 } 458 } 459 if len(data) == 0 { 460 data = append(node.bytes, 1, 0, 0, 0, 0) 461 } 462 cnt += node.refcount 463 switch { 464 case cnt < 0: 465 // BUG: negative reference count 466 panic(fmt.Sprintf("negative reference count: %s new %d, upd %d", h.StringBE(), cnt, t.refcount[h])) 467 case cnt == 0: 468 if !t.mode.GC() { 469 t.Store.Delete(key) 470 } else { 471 data[len(data)-5] = 0 472 binary.LittleEndian.PutUint32(data[len(data)-4:], index) 473 t.Store.Put(key, data) 474 } 475 default: 476 binary.LittleEndian.PutUint32(data[len(data)-4:], uint32(cnt)) 477 t.Store.Put(key, data) 478 } 479 return cnt 480 } 481 482 func (t *Trie) addRef(h util.Uint256, bs []byte) { 483 node := t.refcount[h] 484 if node == nil { 485 t.refcount[h] = &cachedNode{ 486 refcount: 1, 487 bytes: bs, 488 } 489 return 490 } 491 node.refcount++ 492 if node.bytes == nil { 493 node.bytes = bs 494 } 495 } 496 497 func (t *Trie) removeRef(h util.Uint256, bs []byte) { 498 node := t.refcount[h] 499 if node == nil { 500 t.refcount[h] = &cachedNode{ 501 refcount: -1, 502 bytes: bs, 503 } 504 return 505 } 506 node.refcount-- 507 if node.bytes == nil { 508 node.bytes = bs 509 } 510 } 511 512 func (t *Trie) getFromStore(h util.Uint256) (Node, error) { 513 data, err := getFromStore(makeStorageKey(h), t.mode, t.Store) 514 if err != nil { 515 return nil, err 516 } 517 518 var n NodeObject 519 r := io.NewBinReaderFromBuf(data) 520 n.DecodeBinary(r) 521 if r.Err != nil { 522 return nil, r.Err 523 } 524 525 if t.mode.RC() { 526 data = data[:len(data)-5] 527 node := t.refcount[h] 528 if node != nil { 529 node.bytes = data 530 _ = r.ReadB() 531 node.initial = int32(r.ReadU32LE()) 532 } 533 } 534 n.Node.(flushedNode).setCache(data, h) 535 return n.Node, nil 536 } 537 538 // Collapse compresses all nodes at depth n to the hash nodes. 539 // Note: this function does not perform any kind of storage flushing so 540 // `Flush()` should be called explicitly before invoking function. 541 func (t *Trie) Collapse(depth int) { 542 if depth < 0 { 543 panic("negative depth") 544 } 545 t.root = collapse(depth, t.root) 546 t.refcount = make(map[util.Uint256]*cachedNode) 547 } 548 549 func collapse(depth int, node Node) Node { 550 switch node.(type) { 551 case *HashNode, EmptyNode: 552 return node 553 } 554 if depth == 0 { 555 return NewHashNode(node.Hash()) 556 } 557 558 switch n := node.(type) { 559 case *BranchNode: 560 for i := range n.Children { 561 n.Children[i] = collapse(depth-1, n.Children[i]) 562 } 563 case *ExtensionNode: 564 n.next = collapse(depth-1, n.next) 565 case *LeafNode: 566 case *HashNode: 567 default: 568 panic("invalid MPT node type") 569 } 570 return node 571 } 572 573 // Find returns a list of storage key-value pairs whose key is prefixed by the specified 574 // prefix starting from the specified `prefix`+`from` path (not including the item at 575 // the specified `prefix`+`from` path if so). The `max` number of elements is returned at max. 576 func (t *Trie) Find(prefix, from []byte, max int) ([]storage.KeyValue, error) { 577 if len(prefix) > MaxKeyLength { 578 return nil, errors.New("invalid prefix length") 579 } 580 if len(from) > MaxKeyLength-len(prefix) { 581 return nil, errors.New("invalid from length") 582 } 583 prefixP := toNibbles(prefix) 584 fromP := []byte{} 585 if len(from) > 0 { 586 fromP = toNibbles(from) 587 } 588 _, start, path, err := t.getWithPath(t.root, prefixP, false) 589 if err != nil { 590 return nil, fmt.Errorf("failed to determine the start node: %w", err) 591 } 592 path = path[len(prefixP):] 593 594 if len(fromP) > 0 { 595 if len(path) <= len(fromP) && bytes.HasPrefix(fromP, path) { 596 fromP = fromP[len(path):] 597 } else if len(path) > len(fromP) && bytes.HasPrefix(path, fromP) { 598 fromP = []byte{} 599 } else { 600 cmp := bytes.Compare(path, fromP) 601 switch { 602 case cmp < 0: 603 return []storage.KeyValue{}, nil 604 case cmp > 0: 605 fromP = []byte{} 606 } 607 } 608 } 609 610 var ( 611 res []storage.KeyValue 612 count int 613 ) 614 b := NewBillet(t.root.Hash(), t.mode, 0, t.Store) 615 process := func(pathToNode []byte, node Node, _ []byte) bool { 616 if leaf, ok := node.(*LeafNode); ok { 617 if from == nil || !bytes.Equal(pathToNode, from) { // (*Billet).traverse includes `from` path into result if so. Need to filter out manually. 618 res = append(res, storage.KeyValue{ 619 Key: append(bytes.Clone(prefix), pathToNode...), 620 Value: bytes.Clone(leaf.value), 621 }) 622 count++ 623 } 624 } 625 return count >= max 626 } 627 _, err = b.traverse(start, path, fromP, process, false, false) 628 if err != nil && !errors.Is(err, errStop) { 629 return nil, err 630 } 631 return res, nil 632 }