github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/trie/iterator.go (about) 1 package trie 2 3 import ( 4 "bytes" 5 "container/heap" 6 "errors" 7 8 "github.com/neatlab/neatio/utilities/common" 9 "github.com/neatlab/neatio/utilities/rlp" 10 ) 11 12 type Iterator struct { 13 nodeIt NodeIterator 14 15 Key []byte 16 Value []byte 17 Err error 18 } 19 20 func NewIterator(it NodeIterator) *Iterator { 21 return &Iterator{ 22 nodeIt: it, 23 } 24 } 25 26 func (it *Iterator) Next() bool { 27 for it.nodeIt.Next(true) { 28 if it.nodeIt.Leaf() { 29 it.Key = it.nodeIt.LeafKey() 30 it.Value = it.nodeIt.LeafBlob() 31 return true 32 } 33 } 34 it.Key = nil 35 it.Value = nil 36 it.Err = it.nodeIt.Error() 37 return false 38 } 39 40 func (it *Iterator) Prove() [][]byte { 41 return it.nodeIt.LeafProof() 42 } 43 44 type NodeIterator interface { 45 Next(bool) bool 46 47 Error() error 48 49 Hash() common.Hash 50 51 Parent() common.Hash 52 53 Path() []byte 54 55 Leaf() bool 56 57 LeafKey() []byte 58 59 LeafBlob() []byte 60 61 LeafProof() [][]byte 62 } 63 64 type nodeIteratorState struct { 65 hash common.Hash 66 node node 67 parent common.Hash 68 index int 69 pathlen int 70 } 71 72 type nodeIterator struct { 73 trie *Trie 74 stack []*nodeIteratorState 75 path []byte 76 err error 77 } 78 79 var errIteratorEnd = errors.New("end of iteration") 80 81 type seekError struct { 82 key []byte 83 err error 84 } 85 86 func (e seekError) Error() string { 87 return "seek error: " + e.err.Error() 88 } 89 90 func newNodeIterator(trie *Trie, start []byte) NodeIterator { 91 if trie.Hash() == emptyState { 92 return new(nodeIterator) 93 } 94 it := &nodeIterator{trie: trie} 95 it.err = it.seek(start) 96 return it 97 } 98 99 func (it *nodeIterator) Hash() common.Hash { 100 if len(it.stack) == 0 { 101 return common.Hash{} 102 } 103 return it.stack[len(it.stack)-1].hash 104 } 105 106 func (it *nodeIterator) Parent() common.Hash { 107 if len(it.stack) == 0 { 108 return common.Hash{} 109 } 110 return it.stack[len(it.stack)-1].parent 111 } 112 113 func (it *nodeIterator) Leaf() bool { 114 return hasTerm(it.path) 115 } 116 117 func (it *nodeIterator) LeafKey() []byte { 118 if len(it.stack) > 0 { 119 if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { 120 return hexToKeybytes(it.path) 121 } 122 } 123 panic("not at leaf") 124 } 125 126 func (it *nodeIterator) LeafBlob() []byte { 127 if len(it.stack) > 0 { 128 if node, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { 129 return []byte(node) 130 } 131 } 132 panic("not at leaf") 133 } 134 135 func (it *nodeIterator) LeafProof() [][]byte { 136 if len(it.stack) > 0 { 137 if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok { 138 hasher := newHasher(nil) 139 defer returnHasherToPool(hasher) 140 141 proofs := make([][]byte, 0, len(it.stack)) 142 143 for i, item := range it.stack[:len(it.stack)-1] { 144 145 node, _, _ := hasher.hashChildren(item.node, nil) 146 hashed, _ := hasher.store(node, nil, false) 147 if _, ok := hashed.(hashNode); ok || i == 0 { 148 enc, _ := rlp.EncodeToBytes(node) 149 proofs = append(proofs, enc) 150 } 151 } 152 return proofs 153 } 154 } 155 panic("not at leaf") 156 } 157 158 func (it *nodeIterator) Path() []byte { 159 return it.path 160 } 161 162 func (it *nodeIterator) Error() error { 163 if it.err == errIteratorEnd { 164 return nil 165 } 166 if seek, ok := it.err.(seekError); ok { 167 return seek.err 168 } 169 return it.err 170 } 171 172 func (it *nodeIterator) Next(descend bool) bool { 173 if it.err == errIteratorEnd { 174 return false 175 } 176 if seek, ok := it.err.(seekError); ok { 177 if it.err = it.seek(seek.key); it.err != nil { 178 return false 179 } 180 } 181 182 state, parentIndex, path, err := it.peek(descend) 183 it.err = err 184 if it.err != nil { 185 return false 186 } 187 it.push(state, parentIndex, path) 188 return true 189 } 190 191 func (it *nodeIterator) seek(prefix []byte) error { 192 193 key := keybytesToHex(prefix) 194 key = key[:len(key)-1] 195 196 for { 197 state, parentIndex, path, err := it.peek(bytes.HasPrefix(key, it.path)) 198 if err == errIteratorEnd { 199 return errIteratorEnd 200 } else if err != nil { 201 return seekError{prefix, err} 202 } else if bytes.Compare(path, key) >= 0 { 203 return nil 204 } 205 it.push(state, parentIndex, path) 206 } 207 } 208 209 func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, error) { 210 if len(it.stack) == 0 { 211 212 root := it.trie.Hash() 213 state := &nodeIteratorState{node: it.trie.root, index: -1} 214 if root != emptyRoot { 215 state.hash = root 216 } 217 err := state.resolve(it.trie, nil) 218 return state, nil, nil, err 219 } 220 if !descend { 221 222 it.pop() 223 } 224 225 for len(it.stack) > 0 { 226 parent := it.stack[len(it.stack)-1] 227 ancestor := parent.hash 228 if (ancestor == common.Hash{}) { 229 ancestor = parent.parent 230 } 231 state, path, ok := it.nextChild(parent, ancestor) 232 if ok { 233 if err := state.resolve(it.trie, path); err != nil { 234 return parent, &parent.index, path, err 235 } 236 return state, &parent.index, path, nil 237 } 238 239 it.pop() 240 } 241 return nil, nil, nil, errIteratorEnd 242 } 243 244 func (st *nodeIteratorState) resolve(tr *Trie, path []byte) error { 245 if hash, ok := st.node.(hashNode); ok { 246 resolved, err := tr.resolveHash(hash, path) 247 if err != nil { 248 return err 249 } 250 st.node = resolved 251 st.hash = common.BytesToHash(hash) 252 } 253 return nil 254 } 255 256 func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Hash) (*nodeIteratorState, []byte, bool) { 257 switch node := parent.node.(type) { 258 case *fullNode: 259 260 for i := parent.index + 1; i < len(node.Children); i++ { 261 side := node.Children[i] 262 if side != nil { 263 hash, _ := side.cache() 264 state := &nodeIteratorState{ 265 hash: common.BytesToHash(hash), 266 node: side, 267 parent: ancestor, 268 index: -1, 269 pathlen: len(it.path), 270 } 271 path := append(it.path, byte(i)) 272 parent.index = i - 1 273 return state, path, true 274 } 275 } 276 case *shortNode: 277 278 if parent.index < 0 { 279 hash, _ := node.Val.cache() 280 state := &nodeIteratorState{ 281 hash: common.BytesToHash(hash), 282 node: node.Val, 283 parent: ancestor, 284 index: -1, 285 pathlen: len(it.path), 286 } 287 path := append(it.path, node.Key...) 288 return state, path, true 289 } 290 } 291 return parent, it.path, false 292 } 293 294 func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []byte) { 295 it.path = path 296 it.stack = append(it.stack, state) 297 if parentIndex != nil { 298 *parentIndex++ 299 } 300 } 301 302 func (it *nodeIterator) pop() { 303 parent := it.stack[len(it.stack)-1] 304 it.path = it.path[:parent.pathlen] 305 it.stack = it.stack[:len(it.stack)-1] 306 } 307 308 func compareNodes(a, b NodeIterator) int { 309 if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 { 310 return cmp 311 } 312 if a.Leaf() && !b.Leaf() { 313 return -1 314 } else if b.Leaf() && !a.Leaf() { 315 return 1 316 } 317 if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 { 318 return cmp 319 } 320 if a.Leaf() && b.Leaf() { 321 return bytes.Compare(a.LeafBlob(), b.LeafBlob()) 322 } 323 return 0 324 } 325 326 type differenceIterator struct { 327 a, b NodeIterator 328 eof bool 329 count int 330 } 331 332 func NewDifferenceIterator(a, b NodeIterator) (NodeIterator, *int) { 333 a.Next(true) 334 it := &differenceIterator{ 335 a: a, 336 b: b, 337 } 338 return it, &it.count 339 } 340 341 func (it *differenceIterator) Hash() common.Hash { 342 return it.b.Hash() 343 } 344 345 func (it *differenceIterator) Parent() common.Hash { 346 return it.b.Parent() 347 } 348 349 func (it *differenceIterator) Leaf() bool { 350 return it.b.Leaf() 351 } 352 353 func (it *differenceIterator) LeafKey() []byte { 354 return it.b.LeafKey() 355 } 356 357 func (it *differenceIterator) LeafBlob() []byte { 358 return it.b.LeafBlob() 359 } 360 361 func (it *differenceIterator) LeafProof() [][]byte { 362 return it.b.LeafProof() 363 } 364 365 func (it *differenceIterator) Path() []byte { 366 return it.b.Path() 367 } 368 369 func (it *differenceIterator) Next(bool) bool { 370 371 if !it.b.Next(true) { 372 return false 373 } 374 it.count++ 375 376 if it.eof { 377 378 return true 379 } 380 381 for { 382 switch compareNodes(it.a, it.b) { 383 case -1: 384 385 if !it.a.Next(true) { 386 it.eof = true 387 return true 388 } 389 it.count++ 390 case 1: 391 392 return true 393 case 0: 394 395 hasHash := it.a.Hash() == common.Hash{} 396 if !it.b.Next(hasHash) { 397 return false 398 } 399 it.count++ 400 if !it.a.Next(hasHash) { 401 it.eof = true 402 return true 403 } 404 it.count++ 405 } 406 } 407 } 408 409 func (it *differenceIterator) Error() error { 410 if err := it.a.Error(); err != nil { 411 return err 412 } 413 return it.b.Error() 414 } 415 416 type nodeIteratorHeap []NodeIterator 417 418 func (h nodeIteratorHeap) Len() int { return len(h) } 419 func (h nodeIteratorHeap) Less(i, j int) bool { return compareNodes(h[i], h[j]) < 0 } 420 func (h nodeIteratorHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } 421 func (h *nodeIteratorHeap) Push(x interface{}) { *h = append(*h, x.(NodeIterator)) } 422 func (h *nodeIteratorHeap) Pop() interface{} { 423 n := len(*h) 424 x := (*h)[n-1] 425 *h = (*h)[0 : n-1] 426 return x 427 } 428 429 type unionIterator struct { 430 items *nodeIteratorHeap 431 count int 432 } 433 434 func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) { 435 h := make(nodeIteratorHeap, len(iters)) 436 copy(h, iters) 437 heap.Init(&h) 438 439 ui := &unionIterator{items: &h} 440 return ui, &ui.count 441 } 442 443 func (it *unionIterator) Hash() common.Hash { 444 return (*it.items)[0].Hash() 445 } 446 447 func (it *unionIterator) Parent() common.Hash { 448 return (*it.items)[0].Parent() 449 } 450 451 func (it *unionIterator) Leaf() bool { 452 return (*it.items)[0].Leaf() 453 } 454 455 func (it *unionIterator) LeafKey() []byte { 456 return (*it.items)[0].LeafKey() 457 } 458 459 func (it *unionIterator) LeafBlob() []byte { 460 return (*it.items)[0].LeafBlob() 461 } 462 463 func (it *unionIterator) LeafProof() [][]byte { 464 return (*it.items)[0].LeafProof() 465 } 466 467 func (it *unionIterator) Path() []byte { 468 return (*it.items)[0].Path() 469 } 470 471 func (it *unionIterator) Next(descend bool) bool { 472 if len(*it.items) == 0 { 473 return false 474 } 475 476 least := heap.Pop(it.items).(NodeIterator) 477 478 for len(*it.items) > 0 && ((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) || compareNodes(least, (*it.items)[0]) == 0) { 479 skipped := heap.Pop(it.items).(NodeIterator) 480 481 if skipped.Next(skipped.Hash() == common.Hash{}) { 482 it.count++ 483 484 heap.Push(it.items, skipped) 485 } 486 } 487 if least.Next(descend) { 488 it.count++ 489 heap.Push(it.items, least) 490 } 491 return len(*it.items) > 0 492 } 493 494 func (it *unionIterator) Error() error { 495 for i := 0; i < len(*it.items); i++ { 496 if err := (*it.items)[i].Error(); err != nil { 497 return err 498 } 499 } 500 return nil 501 }