github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/trie/iterator.go (about)

     1  package trie
     2  
     3  import (
     4  	"bytes"
     5  	"container/heap"
     6  	"errors"
     7  
     8  	"github.com/neatlab/neatio/utilities/common"
     9  	"github.com/neatlab/neatio/utilities/rlp"
    10  )
    11  
    12  type Iterator struct {
    13  	nodeIt NodeIterator
    14  
    15  	Key   []byte
    16  	Value []byte
    17  	Err   error
    18  }
    19  
    20  func NewIterator(it NodeIterator) *Iterator {
    21  	return &Iterator{
    22  		nodeIt: it,
    23  	}
    24  }
    25  
    26  func (it *Iterator) Next() bool {
    27  	for it.nodeIt.Next(true) {
    28  		if it.nodeIt.Leaf() {
    29  			it.Key = it.nodeIt.LeafKey()
    30  			it.Value = it.nodeIt.LeafBlob()
    31  			return true
    32  		}
    33  	}
    34  	it.Key = nil
    35  	it.Value = nil
    36  	it.Err = it.nodeIt.Error()
    37  	return false
    38  }
    39  
    40  func (it *Iterator) Prove() [][]byte {
    41  	return it.nodeIt.LeafProof()
    42  }
    43  
    44  type NodeIterator interface {
    45  	Next(bool) bool
    46  
    47  	Error() error
    48  
    49  	Hash() common.Hash
    50  
    51  	Parent() common.Hash
    52  
    53  	Path() []byte
    54  
    55  	Leaf() bool
    56  
    57  	LeafKey() []byte
    58  
    59  	LeafBlob() []byte
    60  
    61  	LeafProof() [][]byte
    62  }
    63  
    64  type nodeIteratorState struct {
    65  	hash    common.Hash
    66  	node    node
    67  	parent  common.Hash
    68  	index   int
    69  	pathlen int
    70  }
    71  
    72  type nodeIterator struct {
    73  	trie  *Trie
    74  	stack []*nodeIteratorState
    75  	path  []byte
    76  	err   error
    77  }
    78  
    79  var errIteratorEnd = errors.New("end of iteration")
    80  
    81  type seekError struct {
    82  	key []byte
    83  	err error
    84  }
    85  
    86  func (e seekError) Error() string {
    87  	return "seek error: " + e.err.Error()
    88  }
    89  
    90  func newNodeIterator(trie *Trie, start []byte) NodeIterator {
    91  	if trie.Hash() == emptyState {
    92  		return new(nodeIterator)
    93  	}
    94  	it := &nodeIterator{trie: trie}
    95  	it.err = it.seek(start)
    96  	return it
    97  }
    98  
    99  func (it *nodeIterator) Hash() common.Hash {
   100  	if len(it.stack) == 0 {
   101  		return common.Hash{}
   102  	}
   103  	return it.stack[len(it.stack)-1].hash
   104  }
   105  
   106  func (it *nodeIterator) Parent() common.Hash {
   107  	if len(it.stack) == 0 {
   108  		return common.Hash{}
   109  	}
   110  	return it.stack[len(it.stack)-1].parent
   111  }
   112  
   113  func (it *nodeIterator) Leaf() bool {
   114  	return hasTerm(it.path)
   115  }
   116  
   117  func (it *nodeIterator) LeafKey() []byte {
   118  	if len(it.stack) > 0 {
   119  		if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
   120  			return hexToKeybytes(it.path)
   121  		}
   122  	}
   123  	panic("not at leaf")
   124  }
   125  
   126  func (it *nodeIterator) LeafBlob() []byte {
   127  	if len(it.stack) > 0 {
   128  		if node, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
   129  			return []byte(node)
   130  		}
   131  	}
   132  	panic("not at leaf")
   133  }
   134  
   135  func (it *nodeIterator) LeafProof() [][]byte {
   136  	if len(it.stack) > 0 {
   137  		if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
   138  			hasher := newHasher(nil)
   139  			defer returnHasherToPool(hasher)
   140  
   141  			proofs := make([][]byte, 0, len(it.stack))
   142  
   143  			for i, item := range it.stack[:len(it.stack)-1] {
   144  
   145  				node, _, _ := hasher.hashChildren(item.node, nil)
   146  				hashed, _ := hasher.store(node, nil, false)
   147  				if _, ok := hashed.(hashNode); ok || i == 0 {
   148  					enc, _ := rlp.EncodeToBytes(node)
   149  					proofs = append(proofs, enc)
   150  				}
   151  			}
   152  			return proofs
   153  		}
   154  	}
   155  	panic("not at leaf")
   156  }
   157  
   158  func (it *nodeIterator) Path() []byte {
   159  	return it.path
   160  }
   161  
   162  func (it *nodeIterator) Error() error {
   163  	if it.err == errIteratorEnd {
   164  		return nil
   165  	}
   166  	if seek, ok := it.err.(seekError); ok {
   167  		return seek.err
   168  	}
   169  	return it.err
   170  }
   171  
   172  func (it *nodeIterator) Next(descend bool) bool {
   173  	if it.err == errIteratorEnd {
   174  		return false
   175  	}
   176  	if seek, ok := it.err.(seekError); ok {
   177  		if it.err = it.seek(seek.key); it.err != nil {
   178  			return false
   179  		}
   180  	}
   181  
   182  	state, parentIndex, path, err := it.peek(descend)
   183  	it.err = err
   184  	if it.err != nil {
   185  		return false
   186  	}
   187  	it.push(state, parentIndex, path)
   188  	return true
   189  }
   190  
   191  func (it *nodeIterator) seek(prefix []byte) error {
   192  
   193  	key := keybytesToHex(prefix)
   194  	key = key[:len(key)-1]
   195  
   196  	for {
   197  		state, parentIndex, path, err := it.peek(bytes.HasPrefix(key, it.path))
   198  		if err == errIteratorEnd {
   199  			return errIteratorEnd
   200  		} else if err != nil {
   201  			return seekError{prefix, err}
   202  		} else if bytes.Compare(path, key) >= 0 {
   203  			return nil
   204  		}
   205  		it.push(state, parentIndex, path)
   206  	}
   207  }
   208  
   209  func (it *nodeIterator) peek(descend bool) (*nodeIteratorState, *int, []byte, error) {
   210  	if len(it.stack) == 0 {
   211  
   212  		root := it.trie.Hash()
   213  		state := &nodeIteratorState{node: it.trie.root, index: -1}
   214  		if root != emptyRoot {
   215  			state.hash = root
   216  		}
   217  		err := state.resolve(it.trie, nil)
   218  		return state, nil, nil, err
   219  	}
   220  	if !descend {
   221  
   222  		it.pop()
   223  	}
   224  
   225  	for len(it.stack) > 0 {
   226  		parent := it.stack[len(it.stack)-1]
   227  		ancestor := parent.hash
   228  		if (ancestor == common.Hash{}) {
   229  			ancestor = parent.parent
   230  		}
   231  		state, path, ok := it.nextChild(parent, ancestor)
   232  		if ok {
   233  			if err := state.resolve(it.trie, path); err != nil {
   234  				return parent, &parent.index, path, err
   235  			}
   236  			return state, &parent.index, path, nil
   237  		}
   238  
   239  		it.pop()
   240  	}
   241  	return nil, nil, nil, errIteratorEnd
   242  }
   243  
   244  func (st *nodeIteratorState) resolve(tr *Trie, path []byte) error {
   245  	if hash, ok := st.node.(hashNode); ok {
   246  		resolved, err := tr.resolveHash(hash, path)
   247  		if err != nil {
   248  			return err
   249  		}
   250  		st.node = resolved
   251  		st.hash = common.BytesToHash(hash)
   252  	}
   253  	return nil
   254  }
   255  
   256  func (it *nodeIterator) nextChild(parent *nodeIteratorState, ancestor common.Hash) (*nodeIteratorState, []byte, bool) {
   257  	switch node := parent.node.(type) {
   258  	case *fullNode:
   259  
   260  		for i := parent.index + 1; i < len(node.Children); i++ {
   261  			side := node.Children[i]
   262  			if side != nil {
   263  				hash, _ := side.cache()
   264  				state := &nodeIteratorState{
   265  					hash:    common.BytesToHash(hash),
   266  					node:    side,
   267  					parent:  ancestor,
   268  					index:   -1,
   269  					pathlen: len(it.path),
   270  				}
   271  				path := append(it.path, byte(i))
   272  				parent.index = i - 1
   273  				return state, path, true
   274  			}
   275  		}
   276  	case *shortNode:
   277  
   278  		if parent.index < 0 {
   279  			hash, _ := node.Val.cache()
   280  			state := &nodeIteratorState{
   281  				hash:    common.BytesToHash(hash),
   282  				node:    node.Val,
   283  				parent:  ancestor,
   284  				index:   -1,
   285  				pathlen: len(it.path),
   286  			}
   287  			path := append(it.path, node.Key...)
   288  			return state, path, true
   289  		}
   290  	}
   291  	return parent, it.path, false
   292  }
   293  
   294  func (it *nodeIterator) push(state *nodeIteratorState, parentIndex *int, path []byte) {
   295  	it.path = path
   296  	it.stack = append(it.stack, state)
   297  	if parentIndex != nil {
   298  		*parentIndex++
   299  	}
   300  }
   301  
   302  func (it *nodeIterator) pop() {
   303  	parent := it.stack[len(it.stack)-1]
   304  	it.path = it.path[:parent.pathlen]
   305  	it.stack = it.stack[:len(it.stack)-1]
   306  }
   307  
   308  func compareNodes(a, b NodeIterator) int {
   309  	if cmp := bytes.Compare(a.Path(), b.Path()); cmp != 0 {
   310  		return cmp
   311  	}
   312  	if a.Leaf() && !b.Leaf() {
   313  		return -1
   314  	} else if b.Leaf() && !a.Leaf() {
   315  		return 1
   316  	}
   317  	if cmp := bytes.Compare(a.Hash().Bytes(), b.Hash().Bytes()); cmp != 0 {
   318  		return cmp
   319  	}
   320  	if a.Leaf() && b.Leaf() {
   321  		return bytes.Compare(a.LeafBlob(), b.LeafBlob())
   322  	}
   323  	return 0
   324  }
   325  
   326  type differenceIterator struct {
   327  	a, b  NodeIterator
   328  	eof   bool
   329  	count int
   330  }
   331  
   332  func NewDifferenceIterator(a, b NodeIterator) (NodeIterator, *int) {
   333  	a.Next(true)
   334  	it := &differenceIterator{
   335  		a: a,
   336  		b: b,
   337  	}
   338  	return it, &it.count
   339  }
   340  
   341  func (it *differenceIterator) Hash() common.Hash {
   342  	return it.b.Hash()
   343  }
   344  
   345  func (it *differenceIterator) Parent() common.Hash {
   346  	return it.b.Parent()
   347  }
   348  
   349  func (it *differenceIterator) Leaf() bool {
   350  	return it.b.Leaf()
   351  }
   352  
   353  func (it *differenceIterator) LeafKey() []byte {
   354  	return it.b.LeafKey()
   355  }
   356  
   357  func (it *differenceIterator) LeafBlob() []byte {
   358  	return it.b.LeafBlob()
   359  }
   360  
   361  func (it *differenceIterator) LeafProof() [][]byte {
   362  	return it.b.LeafProof()
   363  }
   364  
   365  func (it *differenceIterator) Path() []byte {
   366  	return it.b.Path()
   367  }
   368  
   369  func (it *differenceIterator) Next(bool) bool {
   370  
   371  	if !it.b.Next(true) {
   372  		return false
   373  	}
   374  	it.count++
   375  
   376  	if it.eof {
   377  
   378  		return true
   379  	}
   380  
   381  	for {
   382  		switch compareNodes(it.a, it.b) {
   383  		case -1:
   384  
   385  			if !it.a.Next(true) {
   386  				it.eof = true
   387  				return true
   388  			}
   389  			it.count++
   390  		case 1:
   391  
   392  			return true
   393  		case 0:
   394  
   395  			hasHash := it.a.Hash() == common.Hash{}
   396  			if !it.b.Next(hasHash) {
   397  				return false
   398  			}
   399  			it.count++
   400  			if !it.a.Next(hasHash) {
   401  				it.eof = true
   402  				return true
   403  			}
   404  			it.count++
   405  		}
   406  	}
   407  }
   408  
   409  func (it *differenceIterator) Error() error {
   410  	if err := it.a.Error(); err != nil {
   411  		return err
   412  	}
   413  	return it.b.Error()
   414  }
   415  
   416  type nodeIteratorHeap []NodeIterator
   417  
   418  func (h nodeIteratorHeap) Len() int            { return len(h) }
   419  func (h nodeIteratorHeap) Less(i, j int) bool  { return compareNodes(h[i], h[j]) < 0 }
   420  func (h nodeIteratorHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
   421  func (h *nodeIteratorHeap) Push(x interface{}) { *h = append(*h, x.(NodeIterator)) }
   422  func (h *nodeIteratorHeap) Pop() interface{} {
   423  	n := len(*h)
   424  	x := (*h)[n-1]
   425  	*h = (*h)[0 : n-1]
   426  	return x
   427  }
   428  
   429  type unionIterator struct {
   430  	items *nodeIteratorHeap
   431  	count int
   432  }
   433  
   434  func NewUnionIterator(iters []NodeIterator) (NodeIterator, *int) {
   435  	h := make(nodeIteratorHeap, len(iters))
   436  	copy(h, iters)
   437  	heap.Init(&h)
   438  
   439  	ui := &unionIterator{items: &h}
   440  	return ui, &ui.count
   441  }
   442  
   443  func (it *unionIterator) Hash() common.Hash {
   444  	return (*it.items)[0].Hash()
   445  }
   446  
   447  func (it *unionIterator) Parent() common.Hash {
   448  	return (*it.items)[0].Parent()
   449  }
   450  
   451  func (it *unionIterator) Leaf() bool {
   452  	return (*it.items)[0].Leaf()
   453  }
   454  
   455  func (it *unionIterator) LeafKey() []byte {
   456  	return (*it.items)[0].LeafKey()
   457  }
   458  
   459  func (it *unionIterator) LeafBlob() []byte {
   460  	return (*it.items)[0].LeafBlob()
   461  }
   462  
   463  func (it *unionIterator) LeafProof() [][]byte {
   464  	return (*it.items)[0].LeafProof()
   465  }
   466  
   467  func (it *unionIterator) Path() []byte {
   468  	return (*it.items)[0].Path()
   469  }
   470  
   471  func (it *unionIterator) Next(descend bool) bool {
   472  	if len(*it.items) == 0 {
   473  		return false
   474  	}
   475  
   476  	least := heap.Pop(it.items).(NodeIterator)
   477  
   478  	for len(*it.items) > 0 && ((!descend && bytes.HasPrefix((*it.items)[0].Path(), least.Path())) || compareNodes(least, (*it.items)[0]) == 0) {
   479  		skipped := heap.Pop(it.items).(NodeIterator)
   480  
   481  		if skipped.Next(skipped.Hash() == common.Hash{}) {
   482  			it.count++
   483  
   484  			heap.Push(it.items, skipped)
   485  		}
   486  	}
   487  	if least.Next(descend) {
   488  		it.count++
   489  		heap.Push(it.items, least)
   490  	}
   491  	return len(*it.items) > 0
   492  }
   493  
   494  func (it *unionIterator) Error() error {
   495  	for i := 0; i < len(*it.items); i++ {
   496  		if err := (*it.items)[i].Error(); err != nil {
   497  			return err
   498  		}
   499  	}
   500  	return nil
   501  }