github.com/mhmtszr/concurrent-swiss-map@v1.0.8/concurrent_swiss_map.go (about)

     1  package csmap
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"sync"
     7  
     8  	"github.com/mhmtszr/concurrent-swiss-map/maphash"
     9  
    10  	"github.com/mhmtszr/concurrent-swiss-map/swiss"
    11  )
    12  
// CsMap is a concurrent map that partitions its entries across shards. Each
// shard is a swiss.Map guarded by its own RWMutex, and a key is routed to the
// shard at index hasher(key) % shardCount (see getShard).
//
// Fields:
//   - hasher: computes the 64-bit hash used both for shard selection and as the
//     precomputed hash passed to the swiss.Map *WithHash methods.
//   - shards: the shardCount independently locked sub-maps.
//   - shardCount: number of shards (set via WithShardCount).
//   - size: expected total number of entries (set via WithSize); Create uses it
//     only to compute each shard's initial capacity.
type CsMap[K comparable, V any] struct {
	hasher     func(key K) uint64
	shards     []shard[K, V]
	shardCount uint64
	size       uint64
}
    19  
// shard is one partition of a CsMap: a swiss.Map together with the RWMutex
// that every CsMap method takes before touching items. Both are held by
// pointer, so copies of a shard value (such as the one stored in
// HashShardPair) share the same map and the same lock.
type shard[K comparable, V any] struct {
	items *swiss.Map[K, V]
	*sync.RWMutex
}
    24  
    25  func Create[K comparable, V any](options ...func(options *CsMap[K, V])) *CsMap[K, V] {
    26  	m := CsMap[K, V]{
    27  		hasher:     maphash.NewHasher[K]().Hash,
    28  		shardCount: 32,
    29  	}
    30  	for _, option := range options {
    31  		option(&m)
    32  	}
    33  
    34  	m.shards = make([]shard[K, V], m.shardCount)
    35  
    36  	for i := 0; i < int(m.shardCount); i++ {
    37  		m.shards[i] = shard[K, V]{items: swiss.NewMap[K, V](uint32((m.size / m.shardCount) + 1)), RWMutex: &sync.RWMutex{}}
    38  	}
    39  	return &m
    40  }
    41  
    42  func WithShardCount[K comparable, V any](count uint64) func(csMap *CsMap[K, V]) {
    43  	return func(csMap *CsMap[K, V]) {
    44  		csMap.shardCount = count
    45  	}
    46  }
    47  
    48  func WithCustomHasher[K comparable, V any](h func(key K) uint64) func(csMap *CsMap[K, V]) {
    49  	return func(csMap *CsMap[K, V]) {
    50  		csMap.hasher = h
    51  	}
    52  }
    53  
    54  func WithSize[K comparable, V any](size uint64) func(csMap *CsMap[K, V]) {
    55  	return func(csMap *CsMap[K, V]) {
    56  		csMap.size = size
    57  	}
    58  }
    59  
    60  func (m *CsMap[K, V]) getShard(key K) HashShardPair[K, V] {
    61  	u := m.hasher(key)
    62  	return HashShardPair[K, V]{
    63  		hash:  u,
    64  		shard: m.shards[u%m.shardCount],
    65  	}
    66  }
    67  
    68  func (m *CsMap[K, V]) Store(key K, value V) {
    69  	hashShardPair := m.getShard(key)
    70  	shard := hashShardPair.shard
    71  	shard.Lock()
    72  	shard.items.PutWithHash(key, value, hashShardPair.hash)
    73  	shard.Unlock()
    74  }
    75  
    76  func (m *CsMap[K, V]) Delete(key K) bool {
    77  	hashShardPair := m.getShard(key)
    78  	shard := hashShardPair.shard
    79  	shard.Lock()
    80  	defer shard.Unlock()
    81  	return shard.items.DeleteWithHash(key, hashShardPair.hash)
    82  }
    83  
    84  func (m *CsMap[K, V]) DeleteIf(key K, condition func(value V) bool) bool {
    85  	hashShardPair := m.getShard(key)
    86  	shard := hashShardPair.shard
    87  	shard.Lock()
    88  	defer shard.Unlock()
    89  	value, ok := shard.items.GetWithHash(key, hashShardPair.hash)
    90  	if ok && condition(value) {
    91  		return shard.items.DeleteWithHash(key, hashShardPair.hash)
    92  	}
    93  	return false
    94  }
    95  
    96  func (m *CsMap[K, V]) Load(key K) (V, bool) {
    97  	hashShardPair := m.getShard(key)
    98  	shard := hashShardPair.shard
    99  	shard.RLock()
   100  	defer shard.RUnlock()
   101  	return shard.items.GetWithHash(key, hashShardPair.hash)
   102  }
   103  
   104  func (m *CsMap[K, V]) Has(key K) bool {
   105  	hashShardPair := m.getShard(key)
   106  	shard := hashShardPair.shard
   107  	shard.RLock()
   108  	defer shard.RUnlock()
   109  	return shard.items.HasWithHash(key, hashShardPair.hash)
   110  }
   111  
   112  func (m *CsMap[K, V]) Clear() {
   113  	for i := range m.shards {
   114  		shard := m.shards[i]
   115  
   116  		shard.Lock()
   117  		shard.items.Clear()
   118  		shard.Unlock()
   119  	}
   120  }
   121  
   122  func (m *CsMap[K, V]) Count() int {
   123  	count := 0
   124  	for i := range m.shards {
   125  		shard := m.shards[i]
   126  		shard.RLock()
   127  		count += shard.items.Count()
   128  		shard.RUnlock()
   129  	}
   130  	return count
   131  }
   132  
   133  func (m *CsMap[K, V]) SetIfAbsent(key K, value V) {
   134  	hashShardPair := m.getShard(key)
   135  	shard := hashShardPair.shard
   136  	shard.Lock()
   137  	_, ok := shard.items.GetWithHash(key, hashShardPair.hash)
   138  	if !ok {
   139  		shard.items.PutWithHash(key, value, hashShardPair.hash)
   140  	}
   141  	shard.Unlock()
   142  }
   143  
   144  func (m *CsMap[K, V]) SetIf(key K, conditionFn func(previousVale V, previousFound bool) (value V, set bool)) {
   145  	hashShardPair := m.getShard(key)
   146  	shard := hashShardPair.shard
   147  	shard.Lock()
   148  	value, found := shard.items.GetWithHash(key, hashShardPair.hash)
   149  	value, ok := conditionFn(value, found)
   150  	if ok {
   151  		shard.items.PutWithHash(key, value, hashShardPair.hash)
   152  	}
   153  	shard.Unlock()
   154  }
   155  
   156  func (m *CsMap[K, V]) SetIfPresent(key K, value V) {
   157  	hashShardPair := m.getShard(key)
   158  	shard := hashShardPair.shard
   159  	shard.Lock()
   160  	_, ok := shard.items.GetWithHash(key, hashShardPair.hash)
   161  	if ok {
   162  		shard.items.PutWithHash(key, value, hashShardPair.hash)
   163  	}
   164  	shard.Unlock()
   165  }
   166  
   167  func (m *CsMap[K, V]) IsEmpty() bool {
   168  	return m.Count() == 0
   169  }
   170  
// Tuple is a single key/value pair. Range uses it to carry entries over a
// channel from the per-shard producer goroutines to the goroutine that invokes
// the callback.
type Tuple[K comparable, V any] struct {
	Key K
	Val V
}
   175  
   176  // Range If the callback function returns true iteration will stop.
   177  func (m *CsMap[K, V]) Range(f func(key K, value V) (stop bool)) {
   178  	ch := make(chan Tuple[K, V], m.Count())
   179  
   180  	ctx, cancel := context.WithCancel(context.Background())
   181  	defer cancel()
   182  
   183  	listenCompleted := m.listen(f, ch)
   184  	m.produce(ctx, ch)
   185  	listenCompleted.Wait()
   186  }
   187  
   188  func (m *CsMap[K, V]) MarshalJSON() ([]byte, error) {
   189  	tmp := make(map[K]V, m.Count())
   190  	m.Range(func(key K, value V) (stop bool) {
   191  		tmp[key] = value
   192  		return false
   193  	})
   194  	return json.Marshal(tmp)
   195  }
   196  
   197  func (m *CsMap[K, V]) UnmarshalJSON(b []byte) error {
   198  	tmp := make(map[K]V, m.Count())
   199  
   200  	if err := json.Unmarshal(b, &tmp); err != nil {
   201  		return err
   202  	}
   203  
   204  	for key, val := range tmp {
   205  		m.Store(key, val)
   206  	}
   207  	return nil
   208  }
   209  
   210  func (m *CsMap[K, V]) produce(ctx context.Context, ch chan Tuple[K, V]) {
   211  	var wg sync.WaitGroup
   212  	wg.Add(len(m.shards))
   213  	for i := range m.shards {
   214  		go func(i int) {
   215  			defer wg.Done()
   216  
   217  			shard := m.shards[i]
   218  			shard.RLock()
   219  			shard.items.Iter(func(k K, v V) (stop bool) {
   220  				select {
   221  				case <-ctx.Done():
   222  					return true
   223  				default:
   224  					ch <- Tuple[K, V]{Key: k, Val: v}
   225  				}
   226  				return false
   227  			})
   228  			shard.RUnlock()
   229  		}(i)
   230  	}
   231  	go func() {
   232  		wg.Wait()
   233  		close(ch)
   234  	}()
   235  }
   236  
   237  func (m *CsMap[K, V]) listen(f func(key K, value V) (stop bool), ch chan Tuple[K, V]) *sync.WaitGroup {
   238  	var wg sync.WaitGroup
   239  	wg.Add(1)
   240  	go func() {
   241  		defer wg.Done()
   242  		for t := range ch {
   243  			if stop := f(t.Key, t.Val); stop {
   244  				return
   245  			}
   246  		}
   247  	}()
   248  	return &wg
   249  }
   250  
// HashShardPair is the result of getShard: the hash computed for a key and a
// copy of the shard that hash selects. Because shard holds its map and mutex
// by pointer, the copy refers to the same underlying storage and lock.
type HashShardPair[K comparable, V any] struct {
	shard shard[K, V]
	hash  uint64
}