github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/core/statesync/mptpool.go (about)

     1  package statesync
     2  
     3  import (
     4  	"bytes"
     5  	"sort"
     6  	"sync"
     7  
     8  	"github.com/nspcc-dev/neo-go/pkg/util"
     9  )
    10  
    11  // Pool stores unknown MPT nodes along with the corresponding paths (single node is
    12  // allowed to have multiple MPT paths).
    13  type Pool struct {
    14  	lock   sync.RWMutex
    15  	hashes map[util.Uint256][][]byte
    16  }
    17  
    18  // NewPool returns new MPT node hashes pool.
    19  func NewPool() *Pool {
    20  	return &Pool{
    21  		hashes: make(map[util.Uint256][][]byte),
    22  	}
    23  }
    24  
    25  // ContainsKey checks if MPT node with the specified hash is in the Pool.
    26  func (mp *Pool) ContainsKey(hash util.Uint256) bool {
    27  	mp.lock.RLock()
    28  	defer mp.lock.RUnlock()
    29  
    30  	_, ok := mp.hashes[hash]
    31  	return ok
    32  }
    33  
    34  // TryGet returns a set of MPT paths for the specified HashNode.
    35  func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) {
    36  	mp.lock.RLock()
    37  	defer mp.lock.RUnlock()
    38  
    39  	paths, ok := mp.hashes[hash]
    40  	// need to copy here, because we can modify existing array of paths inside the pool.
    41  	res := make([][]byte, len(paths))
    42  	copy(res, paths)
    43  	return res, ok
    44  }
    45  
    46  // GetAll returns all MPT nodes with the corresponding paths from the pool.
    47  func (mp *Pool) GetAll() map[util.Uint256][][]byte {
    48  	mp.lock.RLock()
    49  	defer mp.lock.RUnlock()
    50  
    51  	return mp.hashes
    52  }
    53  
    54  // GetBatch returns set of unknown MPT nodes hashes (`limit` at max).
    55  func (mp *Pool) GetBatch(limit int) []util.Uint256 {
    56  	mp.lock.RLock()
    57  	defer mp.lock.RUnlock()
    58  
    59  	count := len(mp.hashes)
    60  	if count > limit {
    61  		count = limit
    62  	}
    63  	result := make([]util.Uint256, 0, limit)
    64  	for h := range mp.hashes {
    65  		if count == 0 {
    66  			break
    67  		}
    68  		result = append(result, h)
    69  		count--
    70  	}
    71  	return result
    72  }
    73  
    74  // Remove removes MPT node from the pool by the specified hash.
    75  func (mp *Pool) Remove(hash util.Uint256) {
    76  	mp.lock.Lock()
    77  	defer mp.lock.Unlock()
    78  
    79  	delete(mp.hashes, hash)
    80  }
    81  
    82  // Add adds path to the set of paths for the specified node.
    83  func (mp *Pool) Add(hash util.Uint256, path []byte) {
    84  	mp.lock.Lock()
    85  	defer mp.lock.Unlock()
    86  
    87  	mp.addPaths(hash, [][]byte{path})
    88  }
    89  
    90  // Update is an atomic operation and removes/adds specified nodes from/to the pool.
    91  func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) {
    92  	mp.lock.Lock()
    93  	defer mp.lock.Unlock()
    94  
    95  	for h, paths := range remove {
    96  		old := mp.hashes[h]
    97  		for _, path := range paths {
    98  			i := sort.Search(len(old), func(i int) bool {
    99  				return bytes.Compare(old[i], path) >= 0
   100  			})
   101  			if i < len(old) && bytes.Equal(old[i], path) {
   102  				old = append(old[:i], old[i+1:]...)
   103  			}
   104  		}
   105  		if len(old) == 0 {
   106  			delete(mp.hashes, h)
   107  		} else {
   108  			mp.hashes[h] = old
   109  		}
   110  	}
   111  	for h, paths := range add {
   112  		mp.addPaths(h, paths)
   113  	}
   114  }
   115  
   116  // addPaths adds set of the specified node paths to the pool.
   117  func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) {
   118  	old := mp.hashes[nodeHash]
   119  	for _, path := range paths {
   120  		i := sort.Search(len(old), func(i int) bool {
   121  			return bytes.Compare(old[i], path) >= 0
   122  		})
   123  		if i < len(old) && bytes.Equal(old[i], path) {
   124  			// then path is already added
   125  			continue
   126  		}
   127  		old = append(old, path)
   128  		if i != len(old)-1 {
   129  			copy(old[i+1:], old[i:])
   130  			old[i] = path
   131  		}
   132  	}
   133  	mp.hashes[nodeHash] = old
   134  }
   135  
   136  // Count returns the number of nodes in the pool.
   137  func (mp *Pool) Count() int {
   138  	mp.lock.RLock()
   139  	defer mp.lock.RUnlock()
   140  
   141  	return len(mp.hashes)
   142  }