github.com/weedge/lib@v0.0.0-20230424045628-a36dcc1d90e4/container/concurrent_map/concurrent_map.go

package concurrent_map

import (
	"sync"
)

// ConcurrentMap is a thread-safe map collection with lower lock contention
// than a single-mutex map. The backing entries are split across separate
// partitions, so goroutines accessing different partitions never contend on
// the same lock.
type ConcurrentMap struct {
	partitions    []*innerMap
	numOfBlockets int
}

// Partitionable is the interface a key type must implement.
// It defines how entries are assigned to partitions.
type Partitionable interface {
	// Value returns the raw value of the key.
	Value() interface{}

	// PartitionKey returns the value used to pick the partition storing the
	// entry, e.g. the key's hash. The partition for the key is
	// partitions[PartitionKey % m.numOfBlockets].
	//
	// Why is there no default hash function for partitioning?
	// The partitioning scheme has a significant impact on performance: a good
	// scheme balances access across the partitions and avoids hot partitions,
	// and the access pattern depends entirely on your workload. So the
	// partitioning scheme is best designed around your own business logic.
	PartitionKey() int64
}
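
// Example key type, a minimal sketch that is not part of the original package:
// Int64Key is a hypothetical key that partitions by its own value, assuming
// callers already have an int64-valued identifier to use as the key.
type Int64Key int64

// Value returns the raw key value used for map lookups.
func (k Int64Key) Value() interface{} { return int64(k) }

// PartitionKey reuses the key itself to choose a partition.
func (k Int64Key) PartitionKey() int64 { return int64(k) }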

// innerMap is a single partition: a plain map guarded by its own RWMutex.
type innerMap struct {
	m    map[interface{}]interface{}
	lock sync.RWMutex
}

func createInnerMap() *innerMap {
	return &innerMap{
		m: make(map[interface{}]interface{}),
	}
}

func (im *innerMap) get(key Partitionable) (interface{}, bool) {
	keyVal := key.Value()
	im.lock.RLock()
	v, ok := im.m[keyVal]
	im.lock.RUnlock()
	return v, ok
}

func (im *innerMap) set(key Partitionable, v interface{}) {
	keyVal := key.Value()
	im.lock.Lock()
	im.m[keyVal] = v
	im.lock.Unlock()
}

func (im *innerMap) del(key Partitionable) {
	keyVal := key.Value()
	im.lock.Lock()
	delete(im.m, keyVal)
	im.lock.Unlock()
}

// CreateConcurrentMap creates a ConcurrentMap with the given number of partitions.
func CreateConcurrentMap(numOfPartitions int) *ConcurrentMap {
	var partitions []*innerMap
	for i := 0; i < numOfPartitions; i++ {
		partitions = append(partitions, createInnerMap())
	}
	return &ConcurrentMap{partitions, numOfPartitions}
}

func (m *ConcurrentMap) getPartition(key Partitionable) *innerMap {
	partitionID := key.PartitionKey() % int64(m.numOfBlockets)
	if partitionID < 0 {
		// adjust negative partition keys (e.g. negative hashes) to a valid index
		partitionID += int64(m.numOfBlockets)
	}
	return m.partitions[partitionID]
}

// Get returns the value stored for the key and whether it was present.
func (m *ConcurrentMap) Get(key Partitionable) (interface{}, bool) {
	return m.getPartition(key).get(key)
}

// Set stores the key/value entry in the map.
func (m *ConcurrentMap) Set(key Partitionable, v interface{}) {
	im := m.getPartition(key)
	im.set(key, v)
}

// Del deletes the entry for the key.
func (m *ConcurrentMap) Del(key Partitionable) {
	im := m.getPartition(key)
	im.del(key)
}
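
// Usage sketch (hypothetical, relying on the Int64Key example above and not
// part of the original API): create a map with a fixed number of partitions,
// write an entry, and read it back.
func exampleUsage() (interface{}, bool) {
	m := CreateConcurrentMap(32)
	m.Set(Int64Key(42), "value")
	return m.Get(Int64Key(42))
}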

// Tuple is a key/value pair emitted by Snapshot and IterBuffFromSnapshot.
type Tuple struct {
	Key interface{}
	Val interface{}
}

// IterBuffFromSnapshot snapshots every partition into its own buffered channel
// (fan-out), then fans those channels into a single buffered channel that the
// caller can range over.
func (m *ConcurrentMap) IterBuffFromSnapshot() <-chan Tuple {
	snapshotChs := m.Snapshot()
	outChCap := 0
	for _, ch := range snapshotChs {
		outChCap += cap(ch)
	}
	out := make(chan Tuple, outChCap)
	go m.fanIn(snapshotChs, out)

	return out
}
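
// Sketch of draining the iterator (collectSnapshot is a hypothetical helper,
// not part of the original API): the output channel is closed once every
// partition has been fanned in, so a plain range loop collects all entries.
func collectSnapshot(m *ConcurrentMap) []Tuple {
	items := make([]Tuple, 0, m.Count())
	for item := range m.IterBuffFromSnapshot() {
		items = append(items, item)
	}
	return items
}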

// Clear removes all entries from every partition. It deletes the keys seen in
// a snapshot, so entries inserted concurrently with Clear may survive.
func (m *ConcurrentMap) Clear() {
	snapshotChs := m.Snapshot()
	wg := &sync.WaitGroup{}
	wg.Add(len(snapshotChs))
	for index, ch := range snapshotChs {
		go func(i int, ch chan Tuple) {
			for item := range ch {
				m.partitions[i].lock.Lock()
				delete(m.partitions[i].m, item.Key)
				m.partitions[i].lock.Unlock()
			}
			wg.Done()
		}(index, ch)
	}
	wg.Wait()
}

// Snapshot copies each partition into its own buffered channel (fan-out) and
// returns the channels. Each channel's buffer holds the whole partition, so
// the channel is closed immediately and can be drained without blocking.
func (m *ConcurrentMap) Snapshot() (snapshotChs []chan Tuple) {
	snapshotChs = make([]chan Tuple, m.numOfBlockets)
	wg := &sync.WaitGroup{}
	wg.Add(m.numOfBlockets)
	for i := 0; i < m.numOfBlockets; i++ {
		go func(index int, imap *innerMap) {
			imap.lock.RLock()
			snapshotChs[index] = make(chan Tuple, len(imap.m))
			for key, val := range imap.m {
				snapshotChs[index] <- Tuple{Key: key, Val: val}
			}
			imap.lock.RUnlock()
			// the buffer already holds the full partition, so closing here
			// still lets readers drain every entry
			close(snapshotChs[index])
			wg.Done()
		}(i, m.partitions[i])
	}
	wg.Wait()

	return
}

// fanIn forwards every Tuple from the per-partition channels into out and
// closes out once all of them are drained.
func (m *ConcurrentMap) fanIn(chs []chan Tuple, out chan Tuple) {
	wg := &sync.WaitGroup{}
	wg.Add(len(chs))
	for _, ch := range chs {
		go func(ch chan Tuple) {
			for item := range ch {
				out <- item
			}
			wg.Done()
		}(ch)
	}
	wg.Wait()
	// all producers are done; close so the reader's range loop terminates
	close(out)
}

// Count returns the number of elements within the map.
func (m *ConcurrentMap) Count() int {
	count := 0
	for i := 0; i < m.numOfBlockets; i++ {
		shard := m.partitions[i]
		shard.lock.RLock()
		count += len(shard.m)
		shard.lock.RUnlock()
	}
	return count
}