github.com/hdt3213/godis@v1.2.9/datastruct/dict/concurrent.go (about)

     1  package dict
     2  
     3  import (
     4  	"math"
     5  	"math/rand"
     6  	"sort"
     7  	"sync"
     8  	"sync/atomic"
     9  	"time"
    10  )
    11  
    12  // ConcurrentDict is thread safe map using sharding lock
    13  type ConcurrentDict struct {
    14  	table      []*shard
    15  	count      int32
    16  	shardCount int
    17  }
    18  
    19  type shard struct {
    20  	m     map[string]interface{}
    21  	mutex sync.RWMutex
    22  }
    23  
    24  func computeCapacity(param int) (size int) {
    25  	if param <= 16 {
    26  		return 16
    27  	}
    28  	n := param - 1
    29  	n |= n >> 1
    30  	n |= n >> 2
    31  	n |= n >> 4
    32  	n |= n >> 8
    33  	n |= n >> 16
    34  	if n < 0 {
    35  		return math.MaxInt32
    36  	}
    37  	return n + 1
    38  }
    39  
    40  // MakeConcurrent creates ConcurrentDict with the given shard count
    41  func MakeConcurrent(shardCount int) *ConcurrentDict {
    42  	shardCount = computeCapacity(shardCount)
    43  	table := make([]*shard, shardCount)
    44  	for i := 0; i < shardCount; i++ {
    45  		table[i] = &shard{
    46  			m: make(map[string]interface{}),
    47  		}
    48  	}
    49  	d := &ConcurrentDict{
    50  		count:      0,
    51  		table:      table,
    52  		shardCount: shardCount,
    53  	}
    54  	return d
    55  }
    56  
    57  const prime32 = uint32(16777619)
    58  
    59  func fnv32(key string) uint32 {
    60  	hash := uint32(2166136261)
    61  	for i := 0; i < len(key); i++ {
    62  		hash *= prime32
    63  		hash ^= uint32(key[i])
    64  	}
    65  	return hash
    66  }
    67  
    68  func (dict *ConcurrentDict) spread(hashCode uint32) uint32 {
    69  	if dict == nil {
    70  		panic("dict is nil")
    71  	}
    72  	tableSize := uint32(len(dict.table))
    73  	return (tableSize - 1) & hashCode
    74  }
    75  
    76  func (dict *ConcurrentDict) getShard(index uint32) *shard {
    77  	if dict == nil {
    78  		panic("dict is nil")
    79  	}
    80  	return dict.table[index]
    81  }
    82  
    83  // Get returns the binding value and whether the key is exist
    84  func (dict *ConcurrentDict) Get(key string) (val interface{}, exists bool) {
    85  	if dict == nil {
    86  		panic("dict is nil")
    87  	}
    88  	hashCode := fnv32(key)
    89  	index := dict.spread(hashCode)
    90  	s := dict.getShard(index)
    91  	s.mutex.Lock()
    92  	defer s.mutex.Unlock()
    93  	val, exists = s.m[key]
    94  	return
    95  }
    96  
    97  func (dict *ConcurrentDict) GetWithLock(key string) (val interface{}, exists bool) {
    98  	if dict == nil {
    99  		panic("dict is nil")
   100  	}
   101  	hashCode := fnv32(key)
   102  	index := dict.spread(hashCode)
   103  	s := dict.getShard(index)
   104  	val, exists = s.m[key]
   105  	return
   106  }
   107  
   108  // Len returns the number of dict
   109  func (dict *ConcurrentDict) Len() int {
   110  	if dict == nil {
   111  		panic("dict is nil")
   112  	}
   113  	return int(atomic.LoadInt32(&dict.count))
   114  }
   115  
   116  // Put puts key value into dict and returns the number of new inserted key-value
   117  func (dict *ConcurrentDict) Put(key string, val interface{}) (result int) {
   118  	if dict == nil {
   119  		panic("dict is nil")
   120  	}
   121  	hashCode := fnv32(key)
   122  	index := dict.spread(hashCode)
   123  	s := dict.getShard(index)
   124  	s.mutex.Lock()
   125  	defer s.mutex.Unlock()
   126  
   127  	if _, ok := s.m[key]; ok {
   128  		s.m[key] = val
   129  		return 0
   130  	}
   131  	dict.addCount()
   132  	s.m[key] = val
   133  	return 1
   134  }
   135  
   136  func (dict *ConcurrentDict) PutWithLock(key string, val interface{}) (result int) {
   137  	if dict == nil {
   138  		panic("dict is nil")
   139  	}
   140  	hashCode := fnv32(key)
   141  	index := dict.spread(hashCode)
   142  	s := dict.getShard(index)
   143  
   144  	if _, ok := s.m[key]; ok {
   145  		s.m[key] = val
   146  		return 0
   147  	}
   148  	dict.addCount()
   149  	s.m[key] = val
   150  	return 1
   151  }
   152  
   153  // PutIfAbsent puts value if the key is not exists and returns the number of updated key-value
   154  func (dict *ConcurrentDict) PutIfAbsent(key string, val interface{}) (result int) {
   155  	if dict == nil {
   156  		panic("dict is nil")
   157  	}
   158  	hashCode := fnv32(key)
   159  	index := dict.spread(hashCode)
   160  	s := dict.getShard(index)
   161  	s.mutex.Lock()
   162  	defer s.mutex.Unlock()
   163  
   164  	if _, ok := s.m[key]; ok {
   165  		return 0
   166  	}
   167  	s.m[key] = val
   168  	dict.addCount()
   169  	return 1
   170  }
   171  
   172  func (dict *ConcurrentDict) PutIfAbsentWithLock(key string, val interface{}) (result int) {
   173  	if dict == nil {
   174  		panic("dict is nil")
   175  	}
   176  	hashCode := fnv32(key)
   177  	index := dict.spread(hashCode)
   178  	s := dict.getShard(index)
   179  
   180  	if _, ok := s.m[key]; ok {
   181  		return 0
   182  	}
   183  	s.m[key] = val
   184  	dict.addCount()
   185  	return 1
   186  }
   187  
   188  // PutIfExists puts value if the key is exist and returns the number of inserted key-value
   189  func (dict *ConcurrentDict) PutIfExists(key string, val interface{}) (result int) {
   190  	if dict == nil {
   191  		panic("dict is nil")
   192  	}
   193  	hashCode := fnv32(key)
   194  	index := dict.spread(hashCode)
   195  	s := dict.getShard(index)
   196  	s.mutex.Lock()
   197  	defer s.mutex.Unlock()
   198  
   199  	if _, ok := s.m[key]; ok {
   200  		s.m[key] = val
   201  		return 1
   202  	}
   203  	return 0
   204  }
   205  
   206  func (dict *ConcurrentDict) PutIfExistsWithLock(key string, val interface{}) (result int) {
   207  	if dict == nil {
   208  		panic("dict is nil")
   209  	}
   210  	hashCode := fnv32(key)
   211  	index := dict.spread(hashCode)
   212  	s := dict.getShard(index)
   213  
   214  	if _, ok := s.m[key]; ok {
   215  		s.m[key] = val
   216  		return 1
   217  	}
   218  	return 0
   219  }
   220  
   221  // Remove removes the key and return the number of deleted key-value
   222  func (dict *ConcurrentDict) Remove(key string) (result int) {
   223  	if dict == nil {
   224  		panic("dict is nil")
   225  	}
   226  	hashCode := fnv32(key)
   227  	index := dict.spread(hashCode)
   228  	s := dict.getShard(index)
   229  	s.mutex.Lock()
   230  	defer s.mutex.Unlock()
   231  
   232  	if _, ok := s.m[key]; ok {
   233  		delete(s.m, key)
   234  		dict.decreaseCount()
   235  		return 1
   236  	}
   237  	return 0
   238  }
   239  
   240  func (dict *ConcurrentDict) RemoveWithLock(key string) (result int) {
   241  	if dict == nil {
   242  		panic("dict is nil")
   243  	}
   244  	hashCode := fnv32(key)
   245  	index := dict.spread(hashCode)
   246  	s := dict.getShard(index)
   247  
   248  	if _, ok := s.m[key]; ok {
   249  		delete(s.m, key)
   250  		dict.decreaseCount()
   251  		return 1
   252  	}
   253  	return 0
   254  }
   255  
   256  func (dict *ConcurrentDict) addCount() int32 {
   257  	return atomic.AddInt32(&dict.count, 1)
   258  }
   259  
   260  func (dict *ConcurrentDict) decreaseCount() int32 {
   261  	return atomic.AddInt32(&dict.count, -1)
   262  }
   263  
   264  // ForEach traversal the dict
   265  // it may not visits new entry inserted during traversal
   266  func (dict *ConcurrentDict) ForEach(consumer Consumer) {
   267  	if dict == nil {
   268  		panic("dict is nil")
   269  	}
   270  
   271  	for _, s := range dict.table {
   272  		s.mutex.RLock()
   273  		f := func() bool {
   274  			defer s.mutex.RUnlock()
   275  			for key, value := range s.m {
   276  				continues := consumer(key, value)
   277  				if !continues {
   278  					return false
   279  				}
   280  			}
   281  			return true
   282  		}
   283  		if !f() {
   284  			break
   285  		}
   286  	}
   287  }
   288  
   289  // Keys returns all keys in dict
   290  func (dict *ConcurrentDict) Keys() []string {
   291  	keys := make([]string, dict.Len())
   292  	i := 0
   293  	dict.ForEach(func(key string, val interface{}) bool {
   294  		if i < len(keys) {
   295  			keys[i] = key
   296  			i++
   297  		} else {
   298  			keys = append(keys, key)
   299  		}
   300  		return true
   301  	})
   302  	return keys
   303  }
   304  
   305  // RandomKey returns a key randomly
   306  func (shard *shard) RandomKey() string {
   307  	if shard == nil {
   308  		panic("shard is nil")
   309  	}
   310  	shard.mutex.RLock()
   311  	defer shard.mutex.RUnlock()
   312  
   313  	for key := range shard.m {
   314  		return key
   315  	}
   316  	return ""
   317  }
   318  
   319  // RandomKeys randomly returns keys of the given number, may contain duplicated key
   320  func (dict *ConcurrentDict) RandomKeys(limit int) []string {
   321  	size := dict.Len()
   322  	if limit >= size {
   323  		return dict.Keys()
   324  	}
   325  	shardCount := len(dict.table)
   326  
   327  	result := make([]string, limit)
   328  	nR := rand.New(rand.NewSource(time.Now().UnixNano()))
   329  	for i := 0; i < limit; {
   330  		s := dict.getShard(uint32(nR.Intn(shardCount)))
   331  		if s == nil {
   332  			continue
   333  		}
   334  		key := s.RandomKey()
   335  		if key != "" {
   336  			result[i] = key
   337  			i++
   338  		}
   339  	}
   340  	return result
   341  }
   342  
   343  // RandomDistinctKeys randomly returns keys of the given number, won't contain duplicated key
   344  func (dict *ConcurrentDict) RandomDistinctKeys(limit int) []string {
   345  	size := dict.Len()
   346  	if limit >= size {
   347  		return dict.Keys()
   348  	}
   349  
   350  	shardCount := len(dict.table)
   351  	result := make(map[string]struct{})
   352  	nR := rand.New(rand.NewSource(time.Now().UnixNano()))
   353  	for len(result) < limit {
   354  		shardIndex := uint32(nR.Intn(shardCount))
   355  		s := dict.getShard(shardIndex)
   356  		if s == nil {
   357  			continue
   358  		}
   359  		key := s.RandomKey()
   360  		if key != "" {
   361  			if _, exists := result[key]; !exists {
   362  				result[key] = struct{}{}
   363  			}
   364  		}
   365  	}
   366  	arr := make([]string, limit)
   367  	i := 0
   368  	for k := range result {
   369  		arr[i] = k
   370  		i++
   371  	}
   372  	return arr
   373  }
   374  
   375  // Clear removes all keys in dict
   376  func (dict *ConcurrentDict) Clear() {
   377  	*dict = *MakeConcurrent(dict.shardCount)
   378  }
   379  
   380  func (dict *ConcurrentDict) toLockIndices(keys []string, reverse bool) []uint32 {
   381  	indexMap := make(map[uint32]struct{})
   382  	for _, key := range keys {
   383  		index := dict.spread(fnv32(key))
   384  		indexMap[index] = struct{}{}
   385  	}
   386  	indices := make([]uint32, 0, len(indexMap))
   387  	for index := range indexMap {
   388  		indices = append(indices, index)
   389  	}
   390  	sort.Slice(indices, func(i, j int) bool {
   391  		if !reverse {
   392  			return indices[i] < indices[j]
   393  		}
   394  		return indices[i] > indices[j]
   395  	})
   396  	return indices
   397  }
   398  
   399  // RWLocks locks write keys and read keys together. allow duplicate keys
   400  func (dict *ConcurrentDict) RWLocks(writeKeys []string, readKeys []string) {
   401  	keys := append(writeKeys, readKeys...)
   402  	indices := dict.toLockIndices(keys, false)
   403  	writeIndexSet := make(map[uint32]struct{})
   404  	for _, wKey := range writeKeys {
   405  		idx := dict.spread(fnv32(wKey))
   406  		writeIndexSet[idx] = struct{}{}
   407  	}
   408  	for _, index := range indices {
   409  		_, w := writeIndexSet[index]
   410  		mu := &dict.table[index].mutex
   411  		if w {
   412  			mu.Lock()
   413  		} else {
   414  			mu.RLock()
   415  		}
   416  	}
   417  }
   418  
   419  // RWUnLocks unlocks write keys and read keys together. allow duplicate keys
   420  func (dict *ConcurrentDict) RWUnLocks(writeKeys []string, readKeys []string) {
   421  	keys := append(writeKeys, readKeys...)
   422  	indices := dict.toLockIndices(keys, true)
   423  	writeIndexSet := make(map[uint32]struct{})
   424  	for _, wKey := range writeKeys {
   425  		idx := dict.spread(fnv32(wKey))
   426  		writeIndexSet[idx] = struct{}{}
   427  	}
   428  	for _, index := range indices {
   429  		_, w := writeIndexSet[index]
   430  		mu := &dict.table[index].mutex
   431  		if w {
   432  			mu.Unlock()
   433  		} else {
   434  			mu.RUnlock()
   435  		}
   436  	}
   437  }