github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/ledger/complete/mtrie/trie/trie.go (about) 1 package trie 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "io" 7 "sync" 8 9 "github.com/onflow/flow-go/ledger" 10 "github.com/onflow/flow-go/ledger/common/bitutils" 11 "github.com/onflow/flow-go/ledger/complete/mtrie/node" 12 ) 13 14 // MTrie represents a perfect in-memory full binary Merkle tree with uniform height. 15 // For a detailed description of the storage model, please consult `mtrie/README.md` 16 // 17 // A MTrie is a thin wrapper around a the trie's root Node. An MTrie implements the 18 // logic for forming MTrie-graphs from the elementary nodes. Specifically: 19 // - how Nodes (graph vertices) form a Trie, 20 // - how register values are read from the trie, 21 // - how Merkle proofs are generated from a trie, and 22 // - how a new Trie with updated values is generated. 23 // 24 // `MTrie`s are _immutable_ data structures. Updating register values is implemented through 25 // copy-on-write, which creates a new `MTrie`. For minimal memory consumption, all sub-tries 26 // that where not affected by the write operation are shared between the original MTrie 27 // (before the register updates) and the updated MTrie (after the register writes). 28 // 29 // MTrie expects that for a specific path, the register's key never changes. 30 // 31 // DEFINITIONS and CONVENTIONS: 32 // - HEIGHT of a node v in a tree is the number of edges on the longest downward path 33 // between v and a tree leaf. The height of a tree is the height of its root. 34 // The height of a Trie is always the height of the fully-expanded tree. 35 type MTrie struct { 36 root *node.Node 37 regCount uint64 // number of registers allocated in the trie 38 regSize uint64 // size of registers allocated in the trie 39 } 40 41 // NewEmptyMTrie returns an empty Mtrie (root is nil) 42 func NewEmptyMTrie() *MTrie { 43 return &MTrie{root: nil} 44 } 45 46 // IsEmpty checks if a trie is empty. 47 // 48 // An empty try doesn't mean a trie with no allocated registers. 49 func (mt *MTrie) IsEmpty() bool { 50 return mt.root == nil 51 } 52 53 // NewMTrie returns a Mtrie given the root 54 func NewMTrie(root *node.Node, regCount uint64, regSize uint64) (*MTrie, error) { 55 if root != nil && root.Height() != ledger.NodeMaxHeight { 56 return nil, fmt.Errorf("height of root node must be %d but is %d, hash: %s", ledger.NodeMaxHeight, root.Height(), root.Hash().String()) 57 } 58 return &MTrie{ 59 root: root, 60 regCount: regCount, 61 regSize: regSize, 62 }, nil 63 } 64 65 // RootHash returns the trie's root hash. 66 // Concurrency safe (as Tries are immutable structures by convention) 67 func (mt *MTrie) RootHash() ledger.RootHash { 68 if mt.IsEmpty() { 69 // case of an empty trie 70 return EmptyTrieRootHash() 71 } 72 return ledger.RootHash(mt.root.Hash()) 73 } 74 75 // AllocatedRegCount returns the number of allocated registers in the trie. 76 // Concurrency safe (as Tries are immutable structures by convention) 77 func (mt *MTrie) AllocatedRegCount() uint64 { 78 return mt.regCount 79 } 80 81 // AllocatedRegSize returns the size (number of bytes) of allocated registers in the trie. 82 // Concurrency safe (as Tries are immutable structures by convention) 83 func (mt *MTrie) AllocatedRegSize() uint64 { 84 return mt.regSize 85 } 86 87 // RootNode returns the Trie's root Node 88 // Concurrency safe (as Tries are immutable structures by convention) 89 func (mt *MTrie) RootNode() *node.Node { 90 return mt.root 91 } 92 93 // String returns the trie's string representation. 94 // Concurrency safe (as Tries are immutable structures by convention) 95 func (mt *MTrie) String() string { 96 if mt.IsEmpty() { 97 return fmt.Sprintf("Empty Trie with default root hash: %v\n", mt.RootHash()) 98 } 99 trieStr := fmt.Sprintf("Trie root hash: %v\n", mt.RootHash()) 100 return trieStr + mt.root.FmtStr("", "") 101 } 102 103 // UnsafeValueSizes returns payload value sizes for the given paths. 104 // UNSAFE: requires _all_ paths to have a length of mt.Height bits. 105 // CAUTION: while getting payload value sizes, `paths` is permuted IN-PLACE for optimized processing. 106 // Return: 107 // - `sizes` []int 108 // For each path, the corresponding payload value size is written into sizes. AFTER 109 // the size operation completes, the order of `path` and `sizes` are such that 110 // for `path[i]` the corresponding register value size is referenced by `sizes[i]`. 111 // 112 // TODO move consistency checks from Forest into Trie to obtain a safe, self-contained API 113 func (mt *MTrie) UnsafeValueSizes(paths []ledger.Path) []int { 114 sizes := make([]int, len(paths)) // pre-allocate slice for the result 115 valueSizes(sizes, paths, mt.root) 116 return sizes 117 } 118 119 // valueSizes returns value sizes of all the registers in `paths“ in subtree with `head` as root node. 120 // For each `path[i]`, the corresponding value size is written into `sizes[i]` for the same index `i`. 121 // CAUTION: 122 // - while reading the payloads, `paths` is permuted IN-PLACE for optimized processing. 123 // - unchecked requirement: all paths must go through the `head` node 124 func valueSizes(sizes []int, paths []ledger.Path, head *node.Node) { 125 // check for empty paths 126 if len(paths) == 0 { 127 return 128 } 129 130 // path not found 131 if head == nil { 132 return 133 } 134 135 // reached a leaf node 136 if head.IsLeaf() { 137 for i, p := range paths { 138 if *head.Path() == p { 139 payload := head.Payload() 140 if payload != nil { 141 sizes[i] = payload.Value().Size() 142 } 143 // NOTE: break isn't used here because precondition 144 // doesn't require paths being deduplicated. 145 } 146 } 147 return 148 } 149 150 // reached an interim node with only one path 151 if len(paths) == 1 { 152 path := paths[0][:] 153 154 // traverse nodes following the path until a leaf node or nil node is reached. 155 // "for" loop helps to skip partition and recursive call when there's only one path to follow. 156 for { 157 depth := ledger.NodeMaxHeight - head.Height() // distance to the tree root 158 bit := bitutils.ReadBit(path, depth) 159 if bit == 0 { 160 head = head.LeftChild() 161 } else { 162 head = head.RightChild() 163 } 164 if head.IsLeaf() { 165 break 166 } 167 } 168 169 valueSizes(sizes, paths, head) 170 return 171 } 172 173 // reached an interim node with more than one paths 174 175 // partition step to quick sort the paths: 176 // lpaths contains all paths that have `0` at the partitionIndex 177 // rpaths contains all paths that have `1` at the partitionIndex 178 depth := ledger.NodeMaxHeight - head.Height() // distance to the tree root 179 partitionIndex := SplitPaths(paths, depth) 180 lpaths, rpaths := paths[:partitionIndex], paths[partitionIndex:] 181 lsizes, rsizes := sizes[:partitionIndex], sizes[partitionIndex:] 182 183 // read values from left and right subtrees in parallel 184 parallelRecursionThreshold := 32 // threshold to avoid the parallelization going too deep in the recursion 185 if len(lpaths) < parallelRecursionThreshold || len(rpaths) < parallelRecursionThreshold { 186 valueSizes(lsizes, lpaths, head.LeftChild()) 187 valueSizes(rsizes, rpaths, head.RightChild()) 188 } else { 189 // concurrent read of left and right subtree 190 wg := sync.WaitGroup{} 191 wg.Add(1) 192 go func() { 193 valueSizes(lsizes, lpaths, head.LeftChild()) 194 wg.Done() 195 }() 196 valueSizes(rsizes, rpaths, head.RightChild()) 197 wg.Wait() // wait for all threads 198 } 199 } 200 201 // ReadSinglePayload reads and returns a payload for a single path. 202 func (mt *MTrie) ReadSinglePayload(path ledger.Path) *ledger.Payload { 203 return readSinglePayload(path, mt.root) 204 } 205 206 // readSinglePayload reads and returns a payload for a single path in subtree with `head` as root node. 207 func readSinglePayload(path ledger.Path, head *node.Node) *ledger.Payload { 208 pathBytes := path[:] 209 210 if head == nil { 211 return ledger.EmptyPayload() 212 } 213 214 depth := ledger.NodeMaxHeight - head.Height() // distance to the tree root 215 216 // Traverse nodes following the path until a leaf node or nil node is reached. 217 for !head.IsLeaf() { 218 bit := bitutils.ReadBit(pathBytes, depth) 219 if bit == 0 { 220 head = head.LeftChild() 221 } else { 222 head = head.RightChild() 223 } 224 depth++ 225 } 226 227 if head != nil && *head.Path() == path { 228 return head.Payload() 229 } 230 231 return ledger.EmptyPayload() 232 } 233 234 // UnsafeRead reads payloads for the given paths. 235 // UNSAFE: requires _all_ paths to have a length of mt.Height bits. 236 // CAUTION: while reading the payloads, `paths` is permuted IN-PLACE for optimized processing. 237 // Return: 238 // - `payloads` []*ledger.Payload 239 // For each path, the corresponding payload is written into payloads. AFTER 240 // the read operation completes, the order of `path` and `payloads` are such that 241 // for `path[i]` the corresponding register value is referenced by 0`payloads[i]`. 242 // 243 // TODO move consistency checks from Forest into Trie to obtain a safe, self-contained API 244 func (mt *MTrie) UnsafeRead(paths []ledger.Path) []*ledger.Payload { 245 payloads := make([]*ledger.Payload, len(paths)) // pre-allocate slice for the result 246 read(payloads, paths, mt.root) 247 return payloads 248 } 249 250 // read reads all the registers in subtree with `head` as root node. For each 251 // `path[i]`, the corresponding payload is written into `payloads[i]` for the same index `i`. 252 // CAUTION: 253 // - while reading the payloads, `paths` is permuted IN-PLACE for optimized processing. 254 // - unchecked requirement: all paths must go through the `head` node 255 func read(payloads []*ledger.Payload, paths []ledger.Path, head *node.Node) { 256 // check for empty paths 257 if len(paths) == 0 { 258 return 259 } 260 261 // path not found 262 if head == nil { 263 for i := range paths { 264 payloads[i] = ledger.EmptyPayload() 265 } 266 return 267 } 268 269 // reached a leaf node 270 if head.IsLeaf() { 271 for i, p := range paths { 272 if *head.Path() == p { 273 payloads[i] = head.Payload() 274 } else { 275 payloads[i] = ledger.EmptyPayload() 276 } 277 } 278 return 279 } 280 281 // reached an interim node 282 if len(paths) == 1 { 283 // call readSinglePayload to skip partition and recursive calls when there is only one path 284 payloads[0] = readSinglePayload(paths[0], head) 285 return 286 } 287 288 // partition step to quick sort the paths: 289 // lpaths contains all paths that have `0` at the partitionIndex 290 // rpaths contains all paths that have `1` at the partitionIndex 291 depth := ledger.NodeMaxHeight - head.Height() // distance to the tree root 292 partitionIndex := SplitPaths(paths, depth) 293 lpaths, rpaths := paths[:partitionIndex], paths[partitionIndex:] 294 lpayloads, rpayloads := payloads[:partitionIndex], payloads[partitionIndex:] 295 296 // read values from left and right subtrees in parallel 297 parallelRecursionThreshold := 32 // threshold to avoid the parallelization going too deep in the recursion 298 if len(lpaths) < parallelRecursionThreshold || len(rpaths) < parallelRecursionThreshold { 299 read(lpayloads, lpaths, head.LeftChild()) 300 read(rpayloads, rpaths, head.RightChild()) 301 } else { 302 // concurrent read of left and right subtree 303 wg := sync.WaitGroup{} 304 wg.Add(1) 305 go func() { 306 read(lpayloads, lpaths, head.LeftChild()) 307 wg.Done() 308 }() 309 read(rpayloads, rpaths, head.RightChild()) 310 wg.Wait() // wait for all threads 311 } 312 } 313 314 // NewTrieWithUpdatedRegisters constructs a new trie containing all registers from the parent trie, 315 // and returns: 316 // - updated trie 317 // - max depth touched during update (this isn't affected by prune flag) 318 // - error 319 // 320 // The key-value pairs specify the registers whose values are supposed to hold updated values 321 // compared to the parent trie. Constructing the new trie is done in a COPY-ON-WRITE manner: 322 // - The original trie remains unchanged. 323 // - subtries that remain unchanged are from the parent trie instead of copied. 324 // 325 // UNSAFE: method requires the following conditions to be satisfied: 326 // - keys are NOT duplicated 327 // - requires _all_ paths to have a length of mt.Height bits. 328 // 329 // CAUTION: `updatedPaths` and `updatedPayloads` are permuted IN-PLACE for optimized processing. 330 // CAUTION: MTrie expects that for a specific path, the payload's key never changes. 331 // TODO: move consistency checks from MForest to here, to make API safe and self-contained 332 func NewTrieWithUpdatedRegisters( 333 parentTrie *MTrie, 334 updatedPaths []ledger.Path, 335 updatedPayloads []ledger.Payload, 336 prune bool, 337 ) (*MTrie, uint16, error) { 338 updatedRoot, regCountDelta, regSizeDelta, lowestHeightTouched := update( 339 ledger.NodeMaxHeight, 340 parentTrie.root, 341 updatedPaths, 342 updatedPayloads, 343 nil, 344 prune, 345 ) 346 347 updatedTrieRegCount := int64(parentTrie.AllocatedRegCount()) + regCountDelta 348 updatedTrieRegSize := int64(parentTrie.AllocatedRegSize()) + regSizeDelta 349 maxDepthTouched := uint16(ledger.NodeMaxHeight - lowestHeightTouched) 350 351 updatedTrie, err := NewMTrie(updatedRoot, uint64(updatedTrieRegCount), uint64(updatedTrieRegSize)) 352 if err != nil { 353 return nil, 0, fmt.Errorf("constructing updated trie failed: %w", err) 354 } 355 return updatedTrie, maxDepthTouched, nil 356 } 357 358 // updateResult is a wrapper of return values from update(). 359 // It's used to communicate values from goroutine. 360 type updateResult struct { 361 child *node.Node 362 allocatedRegCountDelta int64 363 allocatedRegSizeDelta int64 364 lowestHeightTouched int 365 } 366 367 // update traverses the subtree recursively and create new nodes with 368 // the updated payloads on the given paths 369 // 370 // it returns: 371 // - new updated node or original node if nothing was updated 372 // - allocated register count delta in subtrie (allocatedRegCountDelta) 373 // - allocated register size delta in subtrie (allocatedRegSizeDelta) 374 // - lowest height reached during recursive update in subtrie (lowestHeightTouched) 375 // 376 // update also compact a subtree into a single compact leaf node in the case where 377 // there is only 1 payload stored in the subtree. 378 // 379 // allocatedRegCountDelta and allocatedRegSizeDelta are used to compute updated 380 // trie's allocated register count and size. lowestHeightTouched is used to 381 // compute max depth touched during update. 382 // CAUTION: while updating, `paths` and `payloads` are permuted IN-PLACE for optimized processing. 383 // UNSAFE: method requires the following conditions to be satisfied: 384 // - paths all share the same common prefix [0 : mt.maxHeight-1 - nodeHeight) 385 // (excluding the bit at index headHeight) 386 // - paths are NOT duplicated 387 func update( 388 nodeHeight int, // the height of the node during traversing the subtree 389 currentNode *node.Node, // the current node on the travesing path, if it's nil it means the trie has no node on this path 390 paths []ledger.Path, // the paths to update the payloads 391 payloads []ledger.Payload, // the payloads to be updated at the given paths 392 compactLeaf *node.Node, // a compact leaf node from its ancester, it could be nil 393 prune bool, // prune is a flag for whether pruning nodes with empty payload. not pruning is useful for generating proof, expecially non-inclusion proof 394 ) (n *node.Node, allocatedRegCountDelta int64, allocatedRegSizeDelta int64, lowestHeightTouched int) { 395 // No new path to update 396 if len(paths) == 0 { 397 if compactLeaf != nil { 398 // if a compactLeaf from a higher height is still left, 399 // then expand the compact leaf node to the current height by creating a new compact leaf 400 // node with the same path and payload. 401 // The old node shouldn't be recycled as it is still used by the tree copy before the update. 402 n = node.NewLeaf(*compactLeaf.Path(), compactLeaf.Payload(), nodeHeight) 403 return n, 0, 0, nodeHeight 404 } 405 // if no path to update and there is no compact leaf node on this path, we return 406 // the current node regardless it exists or not. 407 return currentNode, 0, 0, nodeHeight 408 } 409 410 if len(paths) == 1 && currentNode == nil && compactLeaf == nil { 411 // if there is only 1 path to update, and the existing tree has no node on this path, also 412 // no compact leaf node from its ancester, it means we are storing a payload on a new path, 413 n = node.NewLeaf(paths[0], payloads[0].DeepCopy(), nodeHeight) 414 if payloads[0].IsEmpty() { 415 // if we are storing an empty node, then no register is allocated 416 // allocatedRegCountDelta and allocatedRegSizeDelta should both be 0 417 return n, 0, 0, nodeHeight 418 } 419 // if we are storing a non-empty node, we are allocating a new register 420 return n, 1, int64(payloads[0].Size()), nodeHeight 421 } 422 423 if currentNode != nil && currentNode.IsLeaf() { // if we're here then compactLeaf == nil 424 // check if the current node path is among the updated paths 425 found := false 426 currentPath := *currentNode.Path() 427 for i, p := range paths { 428 if p == currentPath { 429 // the case where the recursion stops: only one path to update 430 if len(paths) == 1 { 431 // check if the only path to update has the same payload. 432 // if payload is the same, we could skip the update to avoid creating duplicated node 433 if !currentNode.Payload().ValueEquals(&payloads[i]) { 434 n = node.NewLeaf(paths[i], payloads[i].DeepCopy(), nodeHeight) 435 436 allocatedRegCountDelta, allocatedRegSizeDelta = 437 computeAllocatedRegDeltas(currentNode.Payload(), &payloads[i]) 438 439 return n, allocatedRegCountDelta, allocatedRegSizeDelta, nodeHeight 440 } 441 // avoid creating a new node when the same payload is written 442 return currentNode, 0, 0, nodeHeight 443 } 444 // the case where the recursion carries on: len(paths)>1 445 found = true 446 447 allocatedRegCountDelta, allocatedRegSizeDelta = 448 computeAllocatedRegDeltasFromHigherHeight(currentNode.Payload()) 449 450 break 451 } 452 } 453 if !found { 454 // if the current node carries a path not included in the input path, then the current node 455 // represents a compact leaf that needs to be carried down the recursion. 456 compactLeaf = currentNode 457 } 458 } 459 460 // in the remaining code: 461 // - either len(paths) > 1 462 // - or len(paths) == 1 and compactLeaf!= nil 463 // - or len(paths) == 1 and currentNode != nil && !currentNode.IsLeaf() 464 465 // Split paths and payloads to recurse: 466 // lpaths contains all paths that have `0` at the partitionIndex 467 // rpaths contains all paths that have `1` at the partitionIndex 468 depth := ledger.NodeMaxHeight - nodeHeight // distance to the tree root 469 partitionIndex := splitByPath(paths, payloads, depth) 470 lpaths, rpaths := paths[:partitionIndex], paths[partitionIndex:] 471 lpayloads, rpayloads := payloads[:partitionIndex], payloads[partitionIndex:] 472 473 // check if there is a compact leaf that needs to get deep to height 0 474 var lcompactLeaf, rcompactLeaf *node.Node 475 if compactLeaf != nil { 476 // if yes, check which branch it will go to. 477 path := *compactLeaf.Path() 478 if bitutils.ReadBit(path[:], depth) == 0 { 479 lcompactLeaf = compactLeaf 480 } else { 481 rcompactLeaf = compactLeaf 482 } 483 } 484 485 // set the node children 486 var oldLeftChild, oldRightChild *node.Node 487 if currentNode != nil { 488 oldLeftChild = currentNode.LeftChild() 489 oldRightChild = currentNode.RightChild() 490 } 491 492 // recurse over each branch 493 var newLeftChild, newRightChild *node.Node 494 var lRegCountDelta, rRegCountDelta int64 495 var lRegSizeDelta, rRegSizeDelta int64 496 var lLowestHeightTouched, rLowestHeightTouched int 497 parallelRecursionThreshold := 16 498 if len(lpaths) < parallelRecursionThreshold || len(rpaths) < parallelRecursionThreshold { 499 // runtime optimization: if there are _no_ updates for either left or right sub-tree, proceed single-threaded 500 newLeftChild, lRegCountDelta, lRegSizeDelta, lLowestHeightTouched = update(nodeHeight-1, oldLeftChild, lpaths, lpayloads, lcompactLeaf, prune) 501 newRightChild, rRegCountDelta, rRegSizeDelta, rLowestHeightTouched = update(nodeHeight-1, oldRightChild, rpaths, rpayloads, rcompactLeaf, prune) 502 } else { 503 // runtime optimization: process the left child in a separate thread 504 505 // Since we're receiving 4 values from goroutine, use a 506 // struct and channel to reduce allocs/op. 507 // Although WaitGroup approach can be faster than channel (esp. with 2+ goroutines), 508 // we only use 1 goroutine here and need to communicate results from it. So using 509 // channel is faster and uses fewer allocs/op in this case. 510 results := make(chan updateResult, 1) 511 go func(retChan chan<- updateResult) { 512 child, regCountDelta, regSizeDelta, lowestHeightTouched := update(nodeHeight-1, oldLeftChild, lpaths, lpayloads, lcompactLeaf, prune) 513 retChan <- updateResult{child, regCountDelta, regSizeDelta, lowestHeightTouched} 514 }(results) 515 516 newRightChild, rRegCountDelta, rRegSizeDelta, rLowestHeightTouched = update(nodeHeight-1, oldRightChild, rpaths, rpayloads, rcompactLeaf, prune) 517 518 // Wait for results from goroutine. 519 ret := <-results 520 newLeftChild, lRegCountDelta, lRegSizeDelta, lLowestHeightTouched = ret.child, ret.allocatedRegCountDelta, ret.allocatedRegSizeDelta, ret.lowestHeightTouched 521 } 522 523 allocatedRegCountDelta += lRegCountDelta + rRegCountDelta 524 allocatedRegSizeDelta += lRegSizeDelta + rRegSizeDelta 525 lowestHeightTouched = minInt(lLowestHeightTouched, rLowestHeightTouched) 526 527 // mitigate storage exhaustion attack: avoids creating a new node when the exact same 528 // payload is re-written at a register. CAUTION: we only check that the children are 529 // unchanged. This is only sufficient for interim nodes (for leaf nodes, the children 530 // might be unchanged, i.e. both nil, but the payload could have changed). 531 // In case the current node was a leaf, we _cannot reuse_ it, because we potentially 532 // updated registers in the sub-trie 533 if !currentNode.IsLeaf() && newLeftChild == oldLeftChild && newRightChild == oldRightChild { 534 return currentNode, 0, 0, lowestHeightTouched 535 } 536 537 // if prune is on, then will check and create a compact leaf node if one child is nil, and the 538 // other child is a leaf node 539 if prune { 540 n = node.NewInterimCompactifiedNode(nodeHeight, newLeftChild, newRightChild) 541 return n, allocatedRegCountDelta, allocatedRegSizeDelta, lowestHeightTouched 542 } 543 544 n = node.NewInterimNode(nodeHeight, newLeftChild, newRightChild) 545 return n, allocatedRegCountDelta, allocatedRegSizeDelta, lowestHeightTouched 546 } 547 548 // computeAllocatedRegDeltasFromHigherHeight returns the deltas 549 // needed to compute the allocated reg count and reg size when 550 // a payload is updated or unallocated at a lower height. 551 func computeAllocatedRegDeltasFromHigherHeight(oldPayload *ledger.Payload) (allocatedRegCountDelta, allocatedRegSizeDelta int64) { 552 if !oldPayload.IsEmpty() { 553 // Allocated register will be updated or unallocated at lower height. 554 allocatedRegCountDelta-- 555 } 556 oldPayloadSize := oldPayload.Size() 557 allocatedRegSizeDelta -= int64(oldPayloadSize) 558 return 559 } 560 561 // computeAllocatedRegDeltas returns the allocated reg count 562 // and reg size deltas computed from old payload and new payload. 563 // PRECONDITION: !oldPayload.Equals(newPayload) 564 func computeAllocatedRegDeltas(oldPayload, newPayload *ledger.Payload) (allocatedRegCountDelta, allocatedRegSizeDelta int64) { 565 allocatedRegCountDelta = 0 566 if newPayload.IsEmpty() { 567 // Old payload is not empty while new payload is empty. 568 // Allocated register will be unallocated. 569 allocatedRegCountDelta = -1 570 } else if oldPayload.IsEmpty() { 571 // Old payload is empty while new payload is not empty. 572 // Unallocated register will be allocated. 573 allocatedRegCountDelta = 1 574 } 575 576 oldPayloadSize := oldPayload.Size() 577 newPayloadSize := newPayload.Size() 578 allocatedRegSizeDelta = int64(newPayloadSize - oldPayloadSize) 579 return 580 } 581 582 // UnsafeProofs provides proofs for the given paths. 583 // 584 // CAUTION: while updating, `paths` and `proofs` are permuted IN-PLACE for optimized processing. 585 // UNSAFE: requires _all_ paths to have a length of mt.Height bits. 586 // Paths in the input query don't have to be deduplicated, though deduplication would 587 // result in allocating less dynamic memory to store the proofs. 588 func (mt *MTrie) UnsafeProofs(paths []ledger.Path) *ledger.TrieBatchProof { 589 batchProofs := ledger.NewTrieBatchProofWithEmptyProofs(len(paths)) 590 prove(mt.root, paths, batchProofs.Proofs) 591 return batchProofs 592 } 593 594 // prove traverses the subtree and stores proofs for the given register paths in 595 // the provided `proofs` slice 596 // CAUTION: while updating, `paths` and `proofs` are permuted IN-PLACE for optimized processing. 597 // UNSAFE: method requires the following conditions to be satisfied: 598 // - paths all share the same common prefix [0 : mt.maxHeight-1 - nodeHeight) 599 // (excluding the bit at index headHeight) 600 func prove(head *node.Node, paths []ledger.Path, proofs []*ledger.TrieProof) { 601 // check for empty paths 602 if len(paths) == 0 { 603 return 604 } 605 606 // we've reached the end of a trie 607 // and path is not found (noninclusion proof) 608 if head == nil { 609 // by default, proofs are non-inclusion proofs 610 return 611 } 612 613 // we've reached a leaf 614 if head.IsLeaf() { 615 for i, path := range paths { 616 // value matches (inclusion proof) 617 if *head.Path() == path { 618 proofs[i].Path = *head.Path() 619 proofs[i].Payload = head.Payload() 620 proofs[i].Inclusion = true 621 } 622 } 623 // by default, proofs are non-inclusion proofs 624 return 625 } 626 627 // increment steps for all the proofs 628 for _, p := range proofs { 629 p.Steps++ 630 } 631 632 // partition step to quick sort the paths: 633 // lpaths contains all paths that have `0` at the partitionIndex 634 // rpaths contains all paths that have `1` at the partitionIndex 635 depth := ledger.NodeMaxHeight - head.Height() // distance to the tree root 636 partitionIndex := splitTrieProofsByPath(paths, proofs, depth) 637 lpaths, rpaths := paths[:partitionIndex], paths[partitionIndex:] 638 lproofs, rproofs := proofs[:partitionIndex], proofs[partitionIndex:] 639 640 parallelRecursionThreshold := 64 // threshold to avoid the parallelization going too deep in the recursion 641 if len(lpaths) < parallelRecursionThreshold || len(rpaths) < parallelRecursionThreshold { 642 // runtime optimization: below the parallelRecursionThreshold, we proceed single-threaded 643 addSiblingTrieHashToProofs(head.RightChild(), depth, lproofs) 644 prove(head.LeftChild(), lpaths, lproofs) 645 646 addSiblingTrieHashToProofs(head.LeftChild(), depth, rproofs) 647 prove(head.RightChild(), rpaths, rproofs) 648 } else { 649 wg := sync.WaitGroup{} 650 wg.Add(1) 651 go func() { 652 addSiblingTrieHashToProofs(head.RightChild(), depth, lproofs) 653 prove(head.LeftChild(), lpaths, lproofs) 654 wg.Done() 655 }() 656 657 addSiblingTrieHashToProofs(head.LeftChild(), depth, rproofs) 658 prove(head.RightChild(), rpaths, rproofs) 659 wg.Wait() 660 } 661 } 662 663 // addSiblingTrieHashToProofs inspects the sibling Trie and adds its root hash 664 // to the proofs, if the trie contains non-empty registers (i.e. the 665 // siblingTrie has a non-default hash). 666 func addSiblingTrieHashToProofs(siblingTrie *node.Node, depth int, proofs []*ledger.TrieProof) { 667 if siblingTrie == nil || len(proofs) == 0 { 668 return 669 } 670 671 // This code is necessary, because we do not remove nodes from the trie 672 // when a register is deleted. Instead, we just set the respective leaf's 673 // payload to empty. While this will cause the lead's hash to become the 674 // default hash, the node itself remains as part of the trie. 675 // However, a proof has the convention that the hash of the sibling trie 676 // should only be included, if it is _non-default_. Therefore, we can 677 // neither use `siblingTrie == nil` nor `siblingTrie.RegisterCount == 0`, 678 // as the sibling trie might contain leaves with default value (which are 679 // still counted as occupied registers) 680 // TODO: On update, prune subtries which only contain empty registers. 681 // Then, a child is nil if and only if the subtrie is empty. 682 683 nodeHash := siblingTrie.Hash() 684 isDef := nodeHash == ledger.GetDefaultHashForHeight(siblingTrie.Height()) 685 if !isDef { // in proofs, we only provide non-default value hashes 686 for _, p := range proofs { 687 bitutils.SetBit(p.Flags, depth) 688 p.Interims = append(p.Interims, nodeHash) 689 } 690 } 691 } 692 693 // Equals compares two tries for equality. 694 // Tries are equal iff they store the same data (i.e. root hash matches) 695 // and their number and height are identical 696 func (mt *MTrie) Equals(o *MTrie) bool { 697 if o == nil { 698 return false 699 } 700 return o.RootHash() == mt.RootHash() 701 } 702 703 // DumpAsJSON dumps the trie key value pairs to a file having each key value pair as a json row 704 func (mt *MTrie) DumpAsJSON(w io.Writer) error { 705 706 // Use encoder to prevent building entire trie in memory 707 enc := json.NewEncoder(w) 708 709 err := dumpAsJSON(mt.root, enc) 710 if err != nil { 711 return err 712 } 713 714 return nil 715 } 716 717 // dumpAsJSON serializes the sub-trie with root n to json and feeds it into encoder 718 func dumpAsJSON(n *node.Node, encoder *json.Encoder) error { 719 if n.IsLeaf() { 720 if n != nil { 721 err := encoder.Encode(n.Payload()) 722 if err != nil { 723 return err 724 } 725 } 726 return nil 727 } 728 729 if lChild := n.LeftChild(); lChild != nil { 730 err := dumpAsJSON(lChild, encoder) 731 if err != nil { 732 return err 733 } 734 } 735 736 if rChild := n.RightChild(); rChild != nil { 737 err := dumpAsJSON(rChild, encoder) 738 if err != nil { 739 return err 740 } 741 } 742 return nil 743 } 744 745 // EmptyTrieRootHash returns the rootHash of an empty Trie for the specified path size [bytes] 746 func EmptyTrieRootHash() ledger.RootHash { 747 return ledger.RootHash(ledger.GetDefaultHashForHeight(ledger.NodeMaxHeight)) 748 } 749 750 // AllPayloads returns all payloads 751 func (mt *MTrie) AllPayloads() []*ledger.Payload { 752 return mt.root.AllPayloads() 753 } 754 755 // IsAValidTrie verifies the content of the trie for potential issues 756 func (mt *MTrie) IsAValidTrie() bool { 757 // TODO add checks on the health of node max height ... 758 return mt.root.VerifyCachedHash() 759 } 760 761 // splitByPath permutes the input paths to be partitioned into 2 parts. The first part contains paths with a zero bit 762 // at the input bitIndex, the second part contains paths with a one at the bitIndex. The index of partition 763 // is returned. The same permutation is applied to the payloads slice. 764 // 765 // This would be the partition step of an ascending quick sort of paths (lexicographic order) 766 // with the pivot being the path with all zeros and 1 at bitIndex. 767 // The comparison of paths is only based on the bit at bitIndex, the function therefore assumes all paths have 768 // equal bits from 0 to bitIndex-1 769 // 770 // For instance, if `paths` contains the following 3 paths, and bitIndex is `1`: 771 // [[0,0,1,1], [0,1,0,1], [0,0,0,1]] 772 // then `splitByPath` returns 2 and updates `paths` into: 773 // [[0,0,1,1], [0,0,0,1], [0,1,0,1]] 774 func splitByPath(paths []ledger.Path, payloads []ledger.Payload, bitIndex int) int { 775 i := 0 776 for j, path := range paths { 777 bit := bitutils.ReadBit(path[:], bitIndex) 778 if bit == 0 { 779 paths[i], paths[j] = paths[j], paths[i] 780 payloads[i], payloads[j] = payloads[j], payloads[i] 781 i++ 782 } 783 } 784 return i 785 } 786 787 // SplitPaths permutes the input paths to be partitioned into 2 parts. The first part contains paths with a zero bit 788 // at the input bitIndex, the second part contains paths with a one at the bitIndex. The index of partition 789 // is returned. 790 // 791 // This would be the partition step of an ascending quick sort of paths (lexicographic order) 792 // with the pivot being the path with all zeros and 1 at bitIndex. 793 // The comparison of paths is only based on the bit at bitIndex, the function therefore assumes all paths have 794 // equal bits from 0 to bitIndex-1 795 func SplitPaths(paths []ledger.Path, bitIndex int) int { 796 i := 0 797 for j, path := range paths { 798 bit := bitutils.ReadBit(path[:], bitIndex) 799 if bit == 0 { 800 paths[i], paths[j] = paths[j], paths[i] 801 i++ 802 } 803 } 804 return i 805 } 806 807 // splitTrieProofsByPath permutes the input paths to be partitioned into 2 parts. The first part contains paths 808 // with a zero bit at the input bitIndex, the second part contains paths with a one at the bitIndex. The index 809 // of partition is returned. The same permutation is applied to the proofs slice. 810 // 811 // This would be the partition step of an ascending quick sort of paths (lexicographic order) 812 // with the pivot being the path with all zeros and 1 at bitIndex. 813 // The comparison of paths is only based on the bit at bitIndex, the function therefore assumes all paths have 814 // equal bits from 0 to bitIndex-1 815 func splitTrieProofsByPath(paths []ledger.Path, proofs []*ledger.TrieProof, bitIndex int) int { 816 i := 0 817 for j, path := range paths { 818 bit := bitutils.ReadBit(path[:], bitIndex) 819 if bit == 0 { 820 paths[i], paths[j] = paths[j], paths[i] 821 proofs[i], proofs[j] = proofs[j], proofs[i] 822 i++ 823 } 824 } 825 return i 826 } 827 828 func minInt(a, b int) int { 829 if a < b { 830 return a 831 } 832 return b 833 } 834 835 // TraverseNodes traverses all nodes of the trie in DFS order 836 func TraverseNodes(trie *MTrie, processNode func(*node.Node) error) error { 837 return traverseRecursive(trie.root, processNode) 838 } 839 840 func traverseRecursive(n *node.Node, processNode func(*node.Node) error) error { 841 if n == nil { 842 return nil 843 } 844 845 err := processNode(n) 846 if err != nil { 847 return err 848 } 849 850 err = traverseRecursive(n.LeftChild(), processNode) 851 if err != nil { 852 return err 853 } 854 855 err = traverseRecursive(n.RightChild(), processNode) 856 if err != nil { 857 return err 858 } 859 860 return nil 861 }