github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/core/statesync/mptpool.go (about) 1 package statesync 2 3 import ( 4 "bytes" 5 "sort" 6 "sync" 7 8 "github.com/nspcc-dev/neo-go/pkg/util" 9 ) 10 11 // Pool stores unknown MPT nodes along with the corresponding paths (single node is 12 // allowed to have multiple MPT paths). 13 type Pool struct { 14 lock sync.RWMutex 15 hashes map[util.Uint256][][]byte 16 } 17 18 // NewPool returns new MPT node hashes pool. 19 func NewPool() *Pool { 20 return &Pool{ 21 hashes: make(map[util.Uint256][][]byte), 22 } 23 } 24 25 // ContainsKey checks if MPT node with the specified hash is in the Pool. 26 func (mp *Pool) ContainsKey(hash util.Uint256) bool { 27 mp.lock.RLock() 28 defer mp.lock.RUnlock() 29 30 _, ok := mp.hashes[hash] 31 return ok 32 } 33 34 // TryGet returns a set of MPT paths for the specified HashNode. 35 func (mp *Pool) TryGet(hash util.Uint256) ([][]byte, bool) { 36 mp.lock.RLock() 37 defer mp.lock.RUnlock() 38 39 paths, ok := mp.hashes[hash] 40 // need to copy here, because we can modify existing array of paths inside the pool. 41 res := make([][]byte, len(paths)) 42 copy(res, paths) 43 return res, ok 44 } 45 46 // GetAll returns all MPT nodes with the corresponding paths from the pool. 47 func (mp *Pool) GetAll() map[util.Uint256][][]byte { 48 mp.lock.RLock() 49 defer mp.lock.RUnlock() 50 51 return mp.hashes 52 } 53 54 // GetBatch returns set of unknown MPT nodes hashes (`limit` at max). 55 func (mp *Pool) GetBatch(limit int) []util.Uint256 { 56 mp.lock.RLock() 57 defer mp.lock.RUnlock() 58 59 count := len(mp.hashes) 60 if count > limit { 61 count = limit 62 } 63 result := make([]util.Uint256, 0, limit) 64 for h := range mp.hashes { 65 if count == 0 { 66 break 67 } 68 result = append(result, h) 69 count-- 70 } 71 return result 72 } 73 74 // Remove removes MPT node from the pool by the specified hash. 75 func (mp *Pool) Remove(hash util.Uint256) { 76 mp.lock.Lock() 77 defer mp.lock.Unlock() 78 79 delete(mp.hashes, hash) 80 } 81 82 // Add adds path to the set of paths for the specified node. 83 func (mp *Pool) Add(hash util.Uint256, path []byte) { 84 mp.lock.Lock() 85 defer mp.lock.Unlock() 86 87 mp.addPaths(hash, [][]byte{path}) 88 } 89 90 // Update is an atomic operation and removes/adds specified nodes from/to the pool. 91 func (mp *Pool) Update(remove map[util.Uint256][][]byte, add map[util.Uint256][][]byte) { 92 mp.lock.Lock() 93 defer mp.lock.Unlock() 94 95 for h, paths := range remove { 96 old := mp.hashes[h] 97 for _, path := range paths { 98 i := sort.Search(len(old), func(i int) bool { 99 return bytes.Compare(old[i], path) >= 0 100 }) 101 if i < len(old) && bytes.Equal(old[i], path) { 102 old = append(old[:i], old[i+1:]...) 103 } 104 } 105 if len(old) == 0 { 106 delete(mp.hashes, h) 107 } else { 108 mp.hashes[h] = old 109 } 110 } 111 for h, paths := range add { 112 mp.addPaths(h, paths) 113 } 114 } 115 116 // addPaths adds set of the specified node paths to the pool. 117 func (mp *Pool) addPaths(nodeHash util.Uint256, paths [][]byte) { 118 old := mp.hashes[nodeHash] 119 for _, path := range paths { 120 i := sort.Search(len(old), func(i int) bool { 121 return bytes.Compare(old[i], path) >= 0 122 }) 123 if i < len(old) && bytes.Equal(old[i], path) { 124 // then path is already added 125 continue 126 } 127 old = append(old, path) 128 if i != len(old)-1 { 129 copy(old[i+1:], old[i:]) 130 old[i] = path 131 } 132 } 133 mp.hashes[nodeHash] = old 134 } 135 136 // Count returns the number of nodes in the pool. 137 func (mp *Pool) Count() int { 138 mp.lock.RLock() 139 defer mp.lock.RUnlock() 140 141 return len(mp.hashes) 142 }