github.com/min1324/cmap@v1.0.3-0.20220418125848-74e72bbe3be4/cmap.go

package cmap

import (
	"runtime"
	"sync"
	"sync/atomic"
	"unsafe"
)

const (
	mInitBit  = 4
	mInitSize = 1 << mInitBit
)

// CMap is a thread-safe map of type AnyComparableType:Any.
// To avoid lock bottlenecks, the map is divided into several shards.
type CMap struct {
	count int64
	node  unsafe.Pointer // *node
}
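
// A minimal usage sketch (assuming the package is imported as cmap; the
// key and value below are illustrative only):
//
//	var m cmap.CMap
//	m.Store("answer", 42)
//	if v, ok := m.Load("answer"); ok {
//		fmt.Println(v) // 42
//	}
//	m.Delete("answer")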

type node struct {
	mask    uintptr        // 1<<B - 1
	B       uint8          // log_2 of # of buckets (can hold up to loadFactor * 2^B items)
	resize  uint32         // resize state: 0 means done, 1 means a resize is in progress
	oldNode unsafe.Pointer // *node
	buckets []bucket
}

type bucket struct {
	mu       sync.RWMutex
	init     sync.Once
	evacuted int32 // 1 means the corresponding bucket in oldNode has been evacuated into this bucket
	frozen   int32 // 1 means this bucket is frozen while a resize is in progress
	m        map[interface{}]interface{}
}
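
// A key is assigned to a shard by hashing it and masking the hash with the
// current node's mask, i.e. buckets[chash(key)&node.mask]; see getBucket
// below. Each shard serializes access with its own RWMutex, so operations
// on different shards do not contend.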

// Load returns the value stored in the map for a key, or nil if no
// value is present.
// The ok result indicates whether value was found in the map.
func (m *CMap) Load(key interface{}) (value interface{}, ok bool) {
	hash := chash(key)
	_, b := m.getNodeAndBucket(hash)
	value, ok = b.tryLoad(key)
	return
}

// Store sets the value for a key.
func (m *CMap) Store(key, value interface{}) {
	hash := chash(key)
	for {
		n, b := m.getNodeAndBucket(hash)
		// tryStore fails only if the bucket is frozen for a resize;
		// retry against the new node's bucket.
		if b.tryStore(m, n, key, value) {
			return
		}
	}
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *CMap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) {
	hash := chash(key)
	var ok bool
	for {
		n, b := m.getNodeAndBucket(hash)
		actual, loaded, ok = b.tryLoadOrStore(m, n, key, value)
		if ok {
			return
		}
		runtime.Gosched()
	}
}
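
// A sketch of the loaded semantics (hypothetical values, assuming the
// package is imported as cmap):
//
//	var m cmap.CMap
//	v, loaded := m.LoadOrStore("k", 1) // v == 1, loaded == false (stored)
//	v, loaded = m.LoadOrStore("k", 2)  // v == 1, loaded == true (existing value wins)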

// Delete deletes the value for a key.
func (m *CMap) Delete(key interface{}) {
	m.LoadAndDelete(key)
}

// LoadAndDelete deletes the value for a key, returning the previous value if any.
// The loaded result reports whether the key was present.
func (m *CMap) LoadAndDelete(key interface{}) (value interface{}, loaded bool) {
	hash := chash(key)
	var ok bool
	for {
		n, b := m.getNodeAndBucket(hash)
		value, loaded, ok = b.tryLoadAndDelete(m, n, key)
		if ok {
			return
		}
		runtime.Gosched()
	}
}

// Range calls f sequentially for each key and value present in the map.
// If f returns false, Range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot of the
// CMap's contents: no key will be visited more than once, but if the value
// for any key is stored or deleted concurrently, Range may reflect any
// mapping for that key from any point during the Range call.
//
// Range may be O(N) with the number of elements in the map even if f
// returns false after a constant number of calls.
func (m *CMap) Range(f func(key, value interface{}) bool) {
	n := m.getNode()
	for i := range n.buckets {
		b := n.getBucket(uintptr(i))
		if !b.walk(f) {
			return
		}
	}
}
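
// For example, counting entries with Range (illustrative; Count below is
// cheaper since it reads a maintained counter). Assume m is a cmap.CMap:
//
//	total := 0
//	m.Range(func(k, v interface{}) bool {
//		total++
//		return true // keep iterating
//	})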

// Count returns the number of elements within the map.
func (m *CMap) Count() int64 {
	return atomic.LoadInt64(&m.count)
}

func (m *CMap) getNodeAndBucket(hash uintptr) (n *node, b *bucket) {
	n = m.getNode()
	b = n.getBucket(hash)
	return n, b
}

func (m *CMap) getNode() *node {
	for {
		n := (*node)(atomic.LoadPointer(&m.node))
		if n != nil {
			return n
		}
		// node == nil: lazily initialize it.
		newNode := &node{
			mask:    uintptr(mInitSize - 1),
			B:       mInitBit,
			buckets: make([]bucket, mInitSize),
		}
		// If another goroutine installs its node first, loop and use that one.
		if atomic.CompareAndSwapPointer(&m.node, nil, unsafe.Pointer(newNode)) {
			return newNode
		}
	}
}

// getBucket returns the bucket that stores the given hash, evacuating it
// from the old node first if a resize is in progress.
func (n *node) getBucket(i uintptr) *bucket {
	i = i & n.mask
	b := &(n.buckets[i])
	b.onceInit()
	oldNode := (*node)(atomic.LoadPointer(&n.oldNode))
	if oldNode != nil && !b.hadEvacuted() {
		evacute(n, oldNode, b, i)
	}
	return b
}

// initBuckets eagerly evacuates every bucket, then marks the resize done.
func (n *node) initBuckets() {
	for i := range n.buckets {
		n.getBucket(uintptr(i))
	}
	// Drop the reference to oldNode.
	atomic.StorePointer(&n.oldNode, nil)
	// All evacuations finished.
	atomic.StoreUint32(&n.resize, 0)
}

// evacute moves the entries that belong in bucket b from old to new.
// i must satisfy b == &new.buckets[i&new.mask].
func evacute(new, old *node, b *bucket, i uintptr) {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.hadEvacuted() || old == nil {
		return
	}
	if new.mask > old.mask {
		// grow: the old bucket splits; keep only the keys that rehash to i.
		pb := old.getBucket(i)
		pb.freezeInLock(func(k, v interface{}) bool {
			h := chash(k)
			if h&new.mask == i {
				b.m[k] = v
			}
			return true
		})
	} else {
		// shrink: two old buckets merge into this one.
		pb0 := old.getBucket(i)
		pb1 := old.getBucket(i + bucketShift(new.B))
		pb0.freezeInLock(func(k, v interface{}) bool {
			b.m[k] = v
			return true
		})
		pb1.freezeInLock(func(k, v interface{}) bool {
			b.m[k] = v
			return true
		})
	}
	atomic.StoreInt32(&b.evacuted, 1)
}
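
// Worked index math for the two cases above (using the initial B = 4):
// grow, B 4 -> 5: old bucket i&15 splits into new buckets i and i+16; each
// new bucket copies only the keys whose hash&31 equals its own index.
// shrink, B 5 -> 4: new bucket i merges old buckets i and i+16
// (i + bucketShift(new.B)).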

// onceInit lazily allocates the bucket's map.
func (b *bucket) onceInit() {
	b.init.Do(func() {
		b.m = make(map[interface{}]interface{})
	})
}

// hadEvacuted reports whether this bucket has been filled from oldNode.
func (b *bucket) hadEvacuted() bool {
	return atomic.LoadInt32(&b.evacuted) == 1
}

// hadFrozen reports whether this bucket is frozen for a resize.
func (b *bucket) hadFrozen() bool {
	return atomic.LoadInt32(&b.frozen) == 1
}

// freezeInLock marks the bucket frozen and, still holding its lock, calls
// f for each entry. A frozen bucket rejects try* operations, which then
// retry on the new node's buckets.
func (b *bucket) freezeInLock(f func(k, v interface{}) bool) (done bool) {
	b.mu.Lock()
	defer b.mu.Unlock()
	atomic.StoreInt32(&b.frozen, 1)

	// BUG issue001 b.m race with delete(b.m, key)
	for k, v := range b.m {
		if !f(k, v) {
			return false
		}
	}
	return true
}

// walk snapshots the bucket's entries under the lock, then calls f on the
// copy outside it; used by Range.
func (b *bucket) walk(f func(k, v interface{}) bool) (done bool) {
	type entry struct {
		key, value interface{}
	}
	b.mu.Lock()
	entries := make([]entry, 0, len(b.m))
	for k, v := range b.m {
		entries = append(entries, entry{key: k, value: v})
	}
	b.mu.Unlock()

	for _, e := range entries {
		if !f(e.key, e.value) {
			return false
		}
	}
	return true
}

func (b *bucket) tryLoad(key interface{}) (value interface{}, ok bool) {
	b.mu.RLock()
	value, ok = b.m[key]
	b.mu.RUnlock()
	return
}

func (b *bucket) tryStore(m *CMap, n *node, key, value interface{}) bool {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.hadFrozen() {
		return false
	}

	// Comparing the map length before and after the insert detects whether
	// the key was new, without a separate lookup.
	l0 := len(b.m)
	b.m[key] = value
	l1 := len(b.m)
	if l0 == l1 {
		return true // existing key overwritten; count unchanged
	}
	count := atomic.AddInt64(&m.count, 1)
	// grow
	if overLoadFactor(int64(l1), n.B) || overflowGrow(count, n.B) {
		growWork(m, n, n.B+1)
	}
	return true
}

func (b *bucket) tryLoadOrStore(m *CMap, n *node, key, value interface{}) (actual interface{}, loaded, ok bool) {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.hadFrozen() {
		return nil, false, false
	}
	actual, loaded = b.m[key]
	if loaded {
		return actual, loaded, true
	}
	b.m[key] = value
	count := atomic.AddInt64(&m.count, 1)

	// grow
	if overLoadFactor(int64(len(b.m)), n.B) || overflowGrow(count, n.B) {
		growWork(m, n, n.B+1)
	}
	return value, false, true
}

func (b *bucket) tryLoadAndDelete(m *CMap, n *node, key interface{}) (actual interface{}, loaded, ok bool) {
	// Fast-path check before taking the lock; rechecked below under the lock.
	if b.hadFrozen() {
		return nil, false, false
	}
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.hadFrozen() {
		return nil, false, false
	}
	actual, loaded = b.m[key]
	if !loaded {
		return nil, false, true
	}

	// BUG issue001 b.m race with delete(b.m, key)
	delete(b.m, key)
	count := atomic.AddInt64(&m.count, -1)

	// shrink
	if belowShrink(count, n.B) {
		growWork(m, n, n.B-1)
	}
	return actual, loaded, true
}

// growWork installs a new node with 1<<B buckets and starts evacuation.
func growWork(m *CMap, n *node, B uint8) {
	// Only the goroutine that wins the resize flag builds the new node.
	if !n.growing() && atomic.CompareAndSwapUint32(&n.resize, 0, 1) {
		for {
			nn := &node{
				mask:    bucketMask(B),
				B:       B,
				resize:  1,
				oldNode: unsafe.Pointer(n),
				buckets: make([]bucket, bucketShift(B)),
			}
			if atomic.CompareAndSwapPointer(&m.node, unsafe.Pointer(n), unsafe.Pointer(nn)) {
				// Evacuate the remaining buckets in the background.
				go nn.initBuckets()
				return
			}
		}
	}
}
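
// Note the two-level synchronization above: the CAS on n.resize elects a
// single resizer, and the CAS on m.node publishes the new table. Readers
// and writers that land on a not-yet-evacuated bucket pull it over on
// demand in getBucket, so the background initBuckets only finishes the
// stragglers.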

// growing reports whether an evacuation from oldNode is still in progress.
func (n *node) growing() bool {
	return atomic.LoadPointer(&n.oldNode) != nil
}

// overLoadFactor reports whether a bucket of length blen is overloaded
// for a table of 1<<B buckets.
func overLoadFactor(blen int64, B uint8) bool {
	// TODO adjust loadfactor
	return blen*13/2 > int64(1<<(B)) && B < 31
}
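
// Worked numbers: with the initial B = 4 (16 buckets), a bucket of length 3
// triggers a grow, since 3*13/2 = 19 > 16, while 2*13/2 = 13 does not.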

// overflowGrow reports whether the total count has outgrown a table of
// 1<<B buckets.
func overflowGrow(count int64, B uint8) bool {
	if B > 31 {
		return false
	}
	return count >= int64(1<<(2*B))
}
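
// Worked numbers: at B = 4 the table grows once count reaches 1<<8 = 256,
// i.e. an average of 16 entries per bucket.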

// belowShrink reports whether the total count has fallen below the shrink
// threshold for a table of 1<<B buckets; the guard keeps the table from
// ever shrinking back to its initial size.
func belowShrink(count int64, B uint8) bool {
	if B-1 <= mInitBit {
		return false
	}
	return count < int64(1<<(B-1))
}
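
// Worked numbers: at B = 6 (64 buckets) the table shrinks to B = 5 once
// count drops below 1<<5 = 32; at B = 5 the guard B-1 <= mInitBit stops
// any further shrinking.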

// bucketShift returns 1<<b.
func bucketShift(b uint8) uintptr {
	return uintptr(1) << b
}

// bucketMask returns 1<<b - 1.
func bucketMask(b uint8) uintptr {
	return bucketShift(b) - 1
}