git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/concurrentmap/concurrent_map.go (about)

     1  package concurrentmap
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"sync"
     7  )
     8  
// SHARD_COUNT is the number of shards each ConcurrentMap is split into.
// It is read at map creation and on every shard lookup, so it must not be
// changed after any map has been created.
var SHARD_COUNT = 32
    10  
// Stringer is the key constraint used by NewStringer: a comparable type
// whose String() output is hashed to select a shard.
type Stringer interface {
	fmt.Stringer
	comparable
}
    15  
// A "thread" safe map of type K:V.
// To avoid lock bottlenecks this map is divided into several (SHARD_COUNT)
// map shards; each key is routed to one shard by the sharding function.
type ConcurrentMap[K comparable, V any] struct {
	shards   []*ConcurrentMapShared[K, V] // one independently locked map per shard
	sharding func(key K) uint32           // hashes a key to choose its shard
}
    22  
// A "thread" safe map shard: a plain map guarded by its own embedded
// read/write mutex.
type ConcurrentMapShared[K comparable, V any] struct {
	items        map[K]V
	sync.RWMutex // Read Write mutex, guards access to internal map.
}
    28  
    29  func create[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] {
    30  	m := ConcurrentMap[K, V]{
    31  		sharding: sharding,
    32  		shards:   make([]*ConcurrentMapShared[K, V], SHARD_COUNT),
    33  	}
    34  	for i := 0; i < SHARD_COUNT; i++ {
    35  		m.shards[i] = &ConcurrentMapShared[K, V]{items: make(map[K]V)}
    36  	}
    37  	return m
    38  }
    39  
    40  // Creates a new concurrent map.
    41  func New[V any]() ConcurrentMap[string, V] {
    42  	return create[string, V](fnv32)
    43  }
    44  
    45  // Creates a new concurrent map.
    46  func NewStringer[K Stringer, V any]() ConcurrentMap[K, V] {
    47  	return create[K, V](strfnv32[K])
    48  }
    49  
    50  // Creates a new concurrent map.
    51  func NewWithCustomShardingFunction[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] {
    52  	return create[K, V](sharding)
    53  }
    54  
    55  // GetShard returns shard under given key
    56  func (m ConcurrentMap[K, V]) GetShard(key K) *ConcurrentMapShared[K, V] {
    57  	return m.shards[uint(m.sharding(key))%uint(SHARD_COUNT)]
    58  }
    59  
    60  func (m ConcurrentMap[K, V]) MSet(data map[K]V) {
    61  	for key, value := range data {
    62  		shard := m.GetShard(key)
    63  		shard.Lock()
    64  		shard.items[key] = value
    65  		shard.Unlock()
    66  	}
    67  }
    68  
    69  // Sets the given value under the specified key.
    70  func (m ConcurrentMap[K, V]) Set(key K, value V) {
    71  	// Get map shard.
    72  	shard := m.GetShard(key)
    73  	shard.Lock()
    74  	shard.items[key] = value
    75  	shard.Unlock()
    76  }
    77  
// UpsertCb is the callback Upsert uses to compute the element to store.
// It receives whether the key already existed, the value currently in the
// map (zero value if absent), and the caller-supplied new value, and
// returns the value to be inserted into the map.
// It is called while the shard lock is held, therefore it MUST NOT
// try to access other keys in same map, as it can lead to deadlock since
// Go sync.RWLock is not reentrant
type UpsertCb[V any] func(exist bool, valueInMap V, newValue V) V
    83  
    84  // Insert or Update - updates existing element or inserts a new one using UpsertCb
    85  func (m ConcurrentMap[K, V]) Upsert(key K, value V, cb UpsertCb[V]) (res V) {
    86  	shard := m.GetShard(key)
    87  	shard.Lock()
    88  	v, ok := shard.items[key]
    89  	res = cb(ok, v, value)
    90  	shard.items[key] = res
    91  	shard.Unlock()
    92  	return res
    93  }
    94  
    95  // Sets the given value under the specified key if no value was associated with it.
    96  func (m ConcurrentMap[K, V]) SetIfAbsent(key K, value V) bool {
    97  	// Get map shard.
    98  	shard := m.GetShard(key)
    99  	shard.Lock()
   100  	_, ok := shard.items[key]
   101  	if !ok {
   102  		shard.items[key] = value
   103  	}
   104  	shard.Unlock()
   105  	return !ok
   106  }
   107  
   108  // Get retrieves an element from map under given key.
   109  func (m ConcurrentMap[K, V]) Get(key K) (V, bool) {
   110  	// Get shard
   111  	shard := m.GetShard(key)
   112  	shard.RLock()
   113  	// Get item from shard.
   114  	val, ok := shard.items[key]
   115  	shard.RUnlock()
   116  	return val, ok
   117  }
   118  
   119  // Count returns the number of elements within the map.
   120  func (m ConcurrentMap[K, V]) Count() int {
   121  	count := 0
   122  	for i := 0; i < SHARD_COUNT; i++ {
   123  		shard := m.shards[i]
   124  		shard.RLock()
   125  		count += len(shard.items)
   126  		shard.RUnlock()
   127  	}
   128  	return count
   129  }
   130  
   131  // Looks up an item under specified key
   132  func (m ConcurrentMap[K, V]) Has(key K) bool {
   133  	// Get shard
   134  	shard := m.GetShard(key)
   135  	shard.RLock()
   136  	// See if element is within shard.
   137  	_, ok := shard.items[key]
   138  	shard.RUnlock()
   139  	return ok
   140  }
   141  
   142  // Remove removes an element from the map.
   143  func (m ConcurrentMap[K, V]) Remove(key K) {
   144  	// Try to get shard.
   145  	shard := m.GetShard(key)
   146  	shard.Lock()
   147  	delete(shard.items, key)
   148  	shard.Unlock()
   149  }
   150  
// RemoveCb is a callback executed in a map.RemoveCb() call, while Lock is held.
// It receives the key, the current value (zero value if absent), and whether
// the key exists. If it returns true, the element will be removed from the map.
type RemoveCb[K any, V any] func(key K, v V, exists bool) bool
   154  
   155  // RemoveCb locks the shard containing the key, retrieves its current value and calls the callback with those params
   156  // If callback returns true and element exists, it will remove it from the map
   157  // Returns the value returned by the callback (even if element was not present in the map)
   158  func (m ConcurrentMap[K, V]) RemoveCb(key K, cb RemoveCb[K, V]) bool {
   159  	// Try to get shard.
   160  	shard := m.GetShard(key)
   161  	shard.Lock()
   162  	v, ok := shard.items[key]
   163  	remove := cb(key, v, ok)
   164  	if remove && ok {
   165  		delete(shard.items, key)
   166  	}
   167  	shard.Unlock()
   168  	return remove
   169  }
   170  
   171  // Pop removes an element from the map and returns it
   172  func (m ConcurrentMap[K, V]) Pop(key K) (v V, exists bool) {
   173  	// Try to get shard.
   174  	shard := m.GetShard(key)
   175  	shard.Lock()
   176  	v, exists = shard.items[key]
   177  	delete(shard.items, key)
   178  	shard.Unlock()
   179  	return v, exists
   180  }
   181  
   182  // IsEmpty checks if map is empty.
   183  func (m ConcurrentMap[K, V]) IsEmpty() bool {
   184  	return m.Count() == 0
   185  }
   186  
// Tuple is a key/value pair; used by the Iter & IterBuffered functions to
// wrap two variables together over a channel.
type Tuple[K comparable, V any] struct {
	Key K // the map key
	Val V // the value stored under Key
}
   192  
   193  // Iter returns an iterator which could be used in a for range loop.
   194  //
   195  // Deprecated: using IterBuffered() will get a better performence
   196  func (m ConcurrentMap[K, V]) Iter() <-chan Tuple[K, V] {
   197  	chans := snapshot(m)
   198  	ch := make(chan Tuple[K, V])
   199  	go fanIn(chans, ch)
   200  	return ch
   201  }
   202  
   203  // IterBuffered returns a buffered iterator which could be used in a for range loop.
   204  func (m ConcurrentMap[K, V]) IterBuffered() <-chan Tuple[K, V] {
   205  	chans := snapshot(m)
   206  	total := 0
   207  	for _, c := range chans {
   208  		total += cap(c)
   209  	}
   210  	ch := make(chan Tuple[K, V], total)
   211  	go fanIn(chans, ch)
   212  	return ch
   213  }
   214  
   215  // Clear removes all items from map.
   216  func (m ConcurrentMap[K, V]) Clear() {
   217  	for item := range m.IterBuffered() {
   218  		m.Remove(item.Key)
   219  	}
   220  }
   221  
// snapshot returns an array of channels that together contain the elements
// of each shard, which likely takes a snapshot of `m`.
// It returns once the size of each buffered channel is determined,
// before all the channels are populated by the per-shard goroutines.
func snapshot[K comparable, V any](m ConcurrentMap[K, V]) (chans []chan Tuple[K, V]) {
	// Accessing map items before initialization is a programming error.
	if len(m.shards) == 0 {
		panic(`cmap.ConcurrentMap is not initialized. Should run New() before usage.`)
	}
	chans = make([]chan Tuple[K, V], SHARD_COUNT)
	wg := sync.WaitGroup{}
	wg.Add(SHARD_COUNT)
	// One goroutine per shard copies that shard's contents into its channel.
	for index, shard := range m.shards {
		go func(index int, shard *ConcurrentMapShared[K, V]) {
			shard.RLock()
			// The channel is sized to the shard under RLock, so sends below
			// never block; wg.Done() fires as soon as the size is fixed,
			// letting the caller return before the copy completes.
			chans[index] = make(chan Tuple[K, V], len(shard.items))
			wg.Done()
			for key, val := range shard.items {
				chans[index] <- Tuple[K, V]{key, val}
			}
			shard.RUnlock()
			close(chans[index])
		}(index, shard)
	}
	// Wait only until every channel has been allocated (not populated).
	wg.Wait()
	return chans
}
   251  
   252  // fanIn reads elements from channels `chans` into channel `out`
   253  func fanIn[K comparable, V any](chans []chan Tuple[K, V], out chan Tuple[K, V]) {
   254  	wg := sync.WaitGroup{}
   255  	wg.Add(len(chans))
   256  	for _, ch := range chans {
   257  		go func(ch chan Tuple[K, V]) {
   258  			for t := range ch {
   259  				out <- t
   260  			}
   261  			wg.Done()
   262  		}(ch)
   263  	}
   264  	wg.Wait()
   265  	close(out)
   266  }
   267  
   268  // Items returns all items as map[string]V
   269  func (m ConcurrentMap[K, V]) Items() map[K]V {
   270  	tmp := make(map[K]V)
   271  
   272  	// Insert items to temporary map.
   273  	for item := range m.IterBuffered() {
   274  		tmp[item.Key] = item.Val
   275  	}
   276  
   277  	return tmp
   278  }
   279  
// IterCb is an iterator callback, called for every key/value pair found in
// the map. RLock is held for all calls within a given shard, therefore the
// callback sees a consistent view of a single shard — but not across shards.
type IterCb[K comparable, V any] func(key K, v V)
   285  
   286  // Callback based iterator, cheapest way to read
   287  // all elements in a map.
   288  func (m ConcurrentMap[K, V]) IterCb(fn IterCb[K, V]) {
   289  	for idx := range m.shards {
   290  		shard := (m.shards)[idx]
   291  		shard.RLock()
   292  		for key, value := range shard.items {
   293  			fn(key, value)
   294  		}
   295  		shard.RUnlock()
   296  	}
   297  }
   298  
   299  // Keys returns all keys as []string
   300  func (m ConcurrentMap[K, V]) Keys() []K {
   301  	count := m.Count()
   302  	ch := make(chan K, count)
   303  	go func() {
   304  		// Foreach shard.
   305  		wg := sync.WaitGroup{}
   306  		wg.Add(SHARD_COUNT)
   307  		for _, shard := range m.shards {
   308  			go func(shard *ConcurrentMapShared[K, V]) {
   309  				// Foreach key, value pair.
   310  				shard.RLock()
   311  				for key := range shard.items {
   312  					ch <- key
   313  				}
   314  				shard.RUnlock()
   315  				wg.Done()
   316  			}(shard)
   317  		}
   318  		wg.Wait()
   319  		close(ch)
   320  	}()
   321  
   322  	// Generate keys
   323  	keys := make([]K, 0, count)
   324  	for k := range ch {
   325  		keys = append(keys, k)
   326  	}
   327  	return keys
   328  }
   329  
   330  // Reviles ConcurrentMap "private" variables to json marshal.
   331  func (m ConcurrentMap[K, V]) MarshalJSON() ([]byte, error) {
   332  	// Create a temporary map, which will hold all item spread across shards.
   333  	tmp := make(map[K]V)
   334  
   335  	// Insert items to temporary map.
   336  	for item := range m.IterBuffered() {
   337  		tmp[item.Key] = item.Val
   338  	}
   339  	return json.Marshal(tmp)
   340  }
   341  func strfnv32[K fmt.Stringer](key K) uint32 {
   342  	return fnv32(key.String())
   343  }
   344  
   345  func fnv32(key string) uint32 {
   346  	hash := uint32(2166136261)
   347  	const prime32 = uint32(16777619)
   348  	keyLength := len(key)
   349  	for i := 0; i < keyLength; i++ {
   350  		hash *= prime32
   351  		hash ^= uint32(key[i])
   352  	}
   353  	return hash
   354  }
   355  
   356  // Reverse process of Marshal.
   357  func (m *ConcurrentMap[K, V]) UnmarshalJSON(b []byte) (err error) {
   358  	tmp := make(map[K]V)
   359  
   360  	// Unmarshal into a single map.
   361  	if err := json.Unmarshal(b, &tmp); err != nil {
   362  		return err
   363  	}
   364  
   365  	// foreach key,value pair in temporary map insert into our concurrent map.
   366  	for key, val := range tmp {
   367  		m.Set(key, val)
   368  	}
   369  	return nil
   370  }