git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/concurrentmap/concurrent_map.go (about) 1 package concurrentmap 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "sync" 7 ) 8 9 var SHARD_COUNT = 32 10 11 type Stringer interface { 12 fmt.Stringer 13 comparable 14 } 15 16 // A "thread" safe map of type string:Anything. 17 // To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards. 18 type ConcurrentMap[K comparable, V any] struct { 19 shards []*ConcurrentMapShared[K, V] 20 sharding func(key K) uint32 21 } 22 23 // A "thread" safe string to anything map. 24 type ConcurrentMapShared[K comparable, V any] struct { 25 items map[K]V 26 sync.RWMutex // Read Write mutex, guards access to internal map. 27 } 28 29 func create[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] { 30 m := ConcurrentMap[K, V]{ 31 sharding: sharding, 32 shards: make([]*ConcurrentMapShared[K, V], SHARD_COUNT), 33 } 34 for i := 0; i < SHARD_COUNT; i++ { 35 m.shards[i] = &ConcurrentMapShared[K, V]{items: make(map[K]V)} 36 } 37 return m 38 } 39 40 // Creates a new concurrent map. 41 func New[V any]() ConcurrentMap[string, V] { 42 return create[string, V](fnv32) 43 } 44 45 // Creates a new concurrent map. 46 func NewStringer[K Stringer, V any]() ConcurrentMap[K, V] { 47 return create[K, V](strfnv32[K]) 48 } 49 50 // Creates a new concurrent map. 51 func NewWithCustomShardingFunction[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] { 52 return create[K, V](sharding) 53 } 54 55 // GetShard returns shard under given key 56 func (m ConcurrentMap[K, V]) GetShard(key K) *ConcurrentMapShared[K, V] { 57 return m.shards[uint(m.sharding(key))%uint(SHARD_COUNT)] 58 } 59 60 func (m ConcurrentMap[K, V]) MSet(data map[K]V) { 61 for key, value := range data { 62 shard := m.GetShard(key) 63 shard.Lock() 64 shard.items[key] = value 65 shard.Unlock() 66 } 67 } 68 69 // Sets the given value under the specified key. 70 func (m ConcurrentMap[K, V]) Set(key K, value V) { 71 // Get map shard. 72 shard := m.GetShard(key) 73 shard.Lock() 74 shard.items[key] = value 75 shard.Unlock() 76 } 77 78 // Callback to return new element to be inserted into the map 79 // It is called while lock is held, therefore it MUST NOT 80 // try to access other keys in same map, as it can lead to deadlock since 81 // Go sync.RWLock is not reentrant 82 type UpsertCb[V any] func(exist bool, valueInMap V, newValue V) V 83 84 // Insert or Update - updates existing element or inserts a new one using UpsertCb 85 func (m ConcurrentMap[K, V]) Upsert(key K, value V, cb UpsertCb[V]) (res V) { 86 shard := m.GetShard(key) 87 shard.Lock() 88 v, ok := shard.items[key] 89 res = cb(ok, v, value) 90 shard.items[key] = res 91 shard.Unlock() 92 return res 93 } 94 95 // Sets the given value under the specified key if no value was associated with it. 96 func (m ConcurrentMap[K, V]) SetIfAbsent(key K, value V) bool { 97 // Get map shard. 98 shard := m.GetShard(key) 99 shard.Lock() 100 _, ok := shard.items[key] 101 if !ok { 102 shard.items[key] = value 103 } 104 shard.Unlock() 105 return !ok 106 } 107 108 // Get retrieves an element from map under given key. 109 func (m ConcurrentMap[K, V]) Get(key K) (V, bool) { 110 // Get shard 111 shard := m.GetShard(key) 112 shard.RLock() 113 // Get item from shard. 114 val, ok := shard.items[key] 115 shard.RUnlock() 116 return val, ok 117 } 118 119 // Count returns the number of elements within the map. 120 func (m ConcurrentMap[K, V]) Count() int { 121 count := 0 122 for i := 0; i < SHARD_COUNT; i++ { 123 shard := m.shards[i] 124 shard.RLock() 125 count += len(shard.items) 126 shard.RUnlock() 127 } 128 return count 129 } 130 131 // Looks up an item under specified key 132 func (m ConcurrentMap[K, V]) Has(key K) bool { 133 // Get shard 134 shard := m.GetShard(key) 135 shard.RLock() 136 // See if element is within shard. 137 _, ok := shard.items[key] 138 shard.RUnlock() 139 return ok 140 } 141 142 // Remove removes an element from the map. 143 func (m ConcurrentMap[K, V]) Remove(key K) { 144 // Try to get shard. 145 shard := m.GetShard(key) 146 shard.Lock() 147 delete(shard.items, key) 148 shard.Unlock() 149 } 150 151 // RemoveCb is a callback executed in a map.RemoveCb() call, while Lock is held 152 // If returns true, the element will be removed from the map 153 type RemoveCb[K any, V any] func(key K, v V, exists bool) bool 154 155 // RemoveCb locks the shard containing the key, retrieves its current value and calls the callback with those params 156 // If callback returns true and element exists, it will remove it from the map 157 // Returns the value returned by the callback (even if element was not present in the map) 158 func (m ConcurrentMap[K, V]) RemoveCb(key K, cb RemoveCb[K, V]) bool { 159 // Try to get shard. 160 shard := m.GetShard(key) 161 shard.Lock() 162 v, ok := shard.items[key] 163 remove := cb(key, v, ok) 164 if remove && ok { 165 delete(shard.items, key) 166 } 167 shard.Unlock() 168 return remove 169 } 170 171 // Pop removes an element from the map and returns it 172 func (m ConcurrentMap[K, V]) Pop(key K) (v V, exists bool) { 173 // Try to get shard. 174 shard := m.GetShard(key) 175 shard.Lock() 176 v, exists = shard.items[key] 177 delete(shard.items, key) 178 shard.Unlock() 179 return v, exists 180 } 181 182 // IsEmpty checks if map is empty. 183 func (m ConcurrentMap[K, V]) IsEmpty() bool { 184 return m.Count() == 0 185 } 186 187 // Used by the Iter & IterBuffered functions to wrap two variables together over a channel, 188 type Tuple[K comparable, V any] struct { 189 Key K 190 Val V 191 } 192 193 // Iter returns an iterator which could be used in a for range loop. 194 // 195 // Deprecated: using IterBuffered() will get a better performence 196 func (m ConcurrentMap[K, V]) Iter() <-chan Tuple[K, V] { 197 chans := snapshot(m) 198 ch := make(chan Tuple[K, V]) 199 go fanIn(chans, ch) 200 return ch 201 } 202 203 // IterBuffered returns a buffered iterator which could be used in a for range loop. 204 func (m ConcurrentMap[K, V]) IterBuffered() <-chan Tuple[K, V] { 205 chans := snapshot(m) 206 total := 0 207 for _, c := range chans { 208 total += cap(c) 209 } 210 ch := make(chan Tuple[K, V], total) 211 go fanIn(chans, ch) 212 return ch 213 } 214 215 // Clear removes all items from map. 216 func (m ConcurrentMap[K, V]) Clear() { 217 for item := range m.IterBuffered() { 218 m.Remove(item.Key) 219 } 220 } 221 222 // Returns a array of channels that contains elements in each shard, 223 // which likely takes a snapshot of `m`. 224 // It returns once the size of each buffered channel is determined, 225 // before all the channels are populated using goroutines. 226 func snapshot[K comparable, V any](m ConcurrentMap[K, V]) (chans []chan Tuple[K, V]) { 227 //When you access map items before initializing. 228 if len(m.shards) == 0 { 229 panic(`cmap.ConcurrentMap is not initialized. Should run New() before usage.`) 230 } 231 chans = make([]chan Tuple[K, V], SHARD_COUNT) 232 wg := sync.WaitGroup{} 233 wg.Add(SHARD_COUNT) 234 // Foreach shard. 235 for index, shard := range m.shards { 236 go func(index int, shard *ConcurrentMapShared[K, V]) { 237 // Foreach key, value pair. 238 shard.RLock() 239 chans[index] = make(chan Tuple[K, V], len(shard.items)) 240 wg.Done() 241 for key, val := range shard.items { 242 chans[index] <- Tuple[K, V]{key, val} 243 } 244 shard.RUnlock() 245 close(chans[index]) 246 }(index, shard) 247 } 248 wg.Wait() 249 return chans 250 } 251 252 // fanIn reads elements from channels `chans` into channel `out` 253 func fanIn[K comparable, V any](chans []chan Tuple[K, V], out chan Tuple[K, V]) { 254 wg := sync.WaitGroup{} 255 wg.Add(len(chans)) 256 for _, ch := range chans { 257 go func(ch chan Tuple[K, V]) { 258 for t := range ch { 259 out <- t 260 } 261 wg.Done() 262 }(ch) 263 } 264 wg.Wait() 265 close(out) 266 } 267 268 // Items returns all items as map[string]V 269 func (m ConcurrentMap[K, V]) Items() map[K]V { 270 tmp := make(map[K]V) 271 272 // Insert items to temporary map. 273 for item := range m.IterBuffered() { 274 tmp[item.Key] = item.Val 275 } 276 277 return tmp 278 } 279 280 // Iterator callbacalled for every key,value found in 281 // maps. RLock is held for all calls for a given shard 282 // therefore callback sess consistent view of a shard, 283 // but not across the shards 284 type IterCb[K comparable, V any] func(key K, v V) 285 286 // Callback based iterator, cheapest way to read 287 // all elements in a map. 288 func (m ConcurrentMap[K, V]) IterCb(fn IterCb[K, V]) { 289 for idx := range m.shards { 290 shard := (m.shards)[idx] 291 shard.RLock() 292 for key, value := range shard.items { 293 fn(key, value) 294 } 295 shard.RUnlock() 296 } 297 } 298 299 // Keys returns all keys as []string 300 func (m ConcurrentMap[K, V]) Keys() []K { 301 count := m.Count() 302 ch := make(chan K, count) 303 go func() { 304 // Foreach shard. 305 wg := sync.WaitGroup{} 306 wg.Add(SHARD_COUNT) 307 for _, shard := range m.shards { 308 go func(shard *ConcurrentMapShared[K, V]) { 309 // Foreach key, value pair. 310 shard.RLock() 311 for key := range shard.items { 312 ch <- key 313 } 314 shard.RUnlock() 315 wg.Done() 316 }(shard) 317 } 318 wg.Wait() 319 close(ch) 320 }() 321 322 // Generate keys 323 keys := make([]K, 0, count) 324 for k := range ch { 325 keys = append(keys, k) 326 } 327 return keys 328 } 329 330 // Reviles ConcurrentMap "private" variables to json marshal. 331 func (m ConcurrentMap[K, V]) MarshalJSON() ([]byte, error) { 332 // Create a temporary map, which will hold all item spread across shards. 333 tmp := make(map[K]V) 334 335 // Insert items to temporary map. 336 for item := range m.IterBuffered() { 337 tmp[item.Key] = item.Val 338 } 339 return json.Marshal(tmp) 340 } 341 func strfnv32[K fmt.Stringer](key K) uint32 { 342 return fnv32(key.String()) 343 } 344 345 func fnv32(key string) uint32 { 346 hash := uint32(2166136261) 347 const prime32 = uint32(16777619) 348 keyLength := len(key) 349 for i := 0; i < keyLength; i++ { 350 hash *= prime32 351 hash ^= uint32(key[i]) 352 } 353 return hash 354 } 355 356 // Reverse process of Marshal. 357 func (m *ConcurrentMap[K, V]) UnmarshalJSON(b []byte) (err error) { 358 tmp := make(map[K]V) 359 360 // Unmarshal into a single map. 361 if err := json.Unmarshal(b, &tmp); err != nil { 362 return err 363 } 364 365 // foreach key,value pair in temporary map insert into our concurrent map. 366 for key, val := range tmp { 367 m.Set(key, val) 368 } 369 return nil 370 }