github.com/lrita/cmap@v0.0.0-20231108122212-cb084a67f554/cmap.go

package cmap

import (
	"sync"
	"sync/atomic"
	"unsafe"
)

const (
	mInitialSize           = 1 << 4
	mOverflowThreshold     = 1 << 6
	mOverflowGrowThreshold = 1 << 7
)

// Cmap is a thread-safe map of type AnyComparableType:Any.
// To avoid lock bottlenecks the map is divided into several shards, and
// keys and values of different types can be stored in the same map.
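//
// A minimal usage sketch (illustrative only; the zero value of Cmap is
// ready to use):
//
//	var m cmap.Cmap
//	m.Store("answer", 42)
//	if v, ok := m.Load("answer"); ok {
//		fmt.Println(v) // prints 42
//	}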
type Cmap struct {
	lock  sync.Mutex
	inode unsafe.Pointer // *inode
	count int64
}

type inode struct {
	mask             uintptr        // len(buckets) - 1, for indexing buckets by hash
	overflow         int64          // approximate count of entries in excess of mOverflowThreshold per bucket
	growThreshold    int64          // total count at which the map grows
	shrinkThreshold  int64          // total count below which the map shrinks
	resizeInProgress int64          // set to 1 while a new inode is being installed
	pred             unsafe.Pointer // *inode, predecessor during a resize
	buckets          []bucket
}

type entry struct {
	key, value interface{}
}

type bucket struct {
	lock   sync.RWMutex
	init   int64 // set to 1 once m has been initialized (read atomically)
	m      map[interface{}]interface{}
	frozen bool // set while the bucket's entries are migrated to a new inode
}

// Store sets the value for a key.
func (m *Cmap) Store(key, value interface{}) {
	hash := ehash(key)
	for {
		inode, b := m.getInodeAndBucket(hash)
		if b.tryStore(m, inode, false, key, value) {
			return
		}
	}
}

// Load returns the value stored in the map for a key, or nil if no
// value is present.
// The ok result indicates whether value was found in the map.
func (m *Cmap) Load(key interface{}) (value interface{}, ok bool) {
	hash := ehash(key)
	_, b := m.getInodeAndBucket(hash)
	return b.tryLoad(key)
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *Cmap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) {
	hash := ehash(key)
	for {
		inode, b := m.getInodeAndBucket(hash)
		actual, loaded = b.tryLoad(key)
		if loaded {
			return
		}
		if b.tryStore(m, inode, true, key, value) {
			return value, false
		}
	}
}

// Delete deletes the value for a key.
func (m *Cmap) Delete(key interface{}) {
	hash := ehash(key)
	for {
		inode, b := m.getInodeAndBucket(hash)
		if b.tryDelete(m, inode, key) {
			return
		}
	}
}

// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot of the
// map's contents: no key will be visited more than once, but if the value
// for any key is stored or deleted concurrently, Range may reflect any
// mapping for that key from any point during the Range call.
//
// Range may be O(N) with the number of elements in the map even if f returns
// false after a constant number of calls.
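//
// A hedged usage sketch (illustrative only), collecting every key:
//
//	keys := make([]interface{}, 0, m.Count())
//	m.Range(func(k, v interface{}) bool {
//		keys = append(keys, k)
//		return true
//	})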
func (m *Cmap) Range(f func(key, value interface{}) bool) {
	n := m.getInode()
	for i := 0; i < len(n.buckets); i++ {
		b := &(n.buckets[i])
		if !b.inited() {
			n.initBucket(uintptr(i))
		}
		for _, e := range b.clone() {
			if !f(e.key, e.value) {
				return
			}
		}
	}
}

// Count returns the number of elements within the map.
func (m *Cmap) Count() int {
	return int(atomic.LoadInt64(&m.count))
}

// IsEmpty reports whether the map is empty.
func (m *Cmap) IsEmpty() bool {
	return m.Count() == 0
}

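// getInode returns the current inode, lazily allocating the initial one
// (mInitialSize buckets) under m.lock on first use.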
func (m *Cmap) getInode() *inode {
	n := (*inode)(atomic.LoadPointer(&m.inode))
	if n == nil {
		m.lock.Lock()
		n = (*inode)(atomic.LoadPointer(&m.inode))
		if n == nil {
			n = &inode{
				mask:            uintptr(mInitialSize - 1),
				growThreshold:   int64(mInitialSize * mOverflowThreshold),
				shrinkThreshold: 0,
				buckets:         make([]bucket, mInitialSize),
			}
			atomic.StorePointer(&m.inode, unsafe.Pointer(n))
		}
		m.lock.Unlock()
	}
	return n
}

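// getInodeAndBucket returns the current inode and the bucket that owns the
// given hash, initializing the bucket first if necessary.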
func (m *Cmap) getInodeAndBucket(hash uintptr) (*inode, *bucket) {
	n := m.getInode()
	i := hash & n.mask
	b := &(n.buckets[i])
	if !b.inited() {
		n.initBucket(i)
	}
	return n, b
}

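// initBuckets eagerly initializes every bucket of n (run in the background
// after a resize), then drops the reference to the predecessor inode so it
// can be garbage collected.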
func (n *inode) initBuckets() {
	for i := range n.buckets {
		n.initBucket(uintptr(i))
	}
	atomic.StorePointer(&n.pred, nil)
}

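// initBucket lazily initializes bucket i. If a predecessor inode exists,
// the relevant predecessor buckets are frozen and their entries migrated:
// when growing, one old bucket's entries are split between two new buckets;
// when shrinking, two old buckets are merged into one.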
func (n *inode) initBucket(i uintptr) {
	b := &(n.buckets[i])
	b.lock.Lock()
	if b.inited() {
		b.lock.Unlock()
		return
	}

	b.m = make(map[interface{}]interface{})
	p := (*inode)(atomic.LoadPointer(&n.pred)) // predecessor inode, non-nil during a resize
	if p != nil {
		if n.mask > p.mask {
			// Grow: take the entries of the one predecessor bucket that
			// rehash to this bucket.
			pb := &(p.buckets[i&p.mask])
			if !pb.inited() {
				p.initBucket(i & p.mask)
			}
			for k, v := range pb.freeze() {
				hash := ehash(k)
				if hash&n.mask == i {
					b.m[k] = v
				}
			}
		} else {
			// Shrink: merge the two predecessor buckets that map to this bucket.
			pb0 := &(p.buckets[i])
			if !pb0.inited() {
				p.initBucket(i)
			}
			pb1 := &(p.buckets[i+uintptr(len(n.buckets))])
			if !pb1.inited() {
				p.initBucket(i + uintptr(len(n.buckets)))
			}
			for k, v := range pb0.freeze() {
				b.m[k] = v
			}
			for k, v := range pb1.freeze() {
				b.m[k] = v
			}
		}
		if len(b.m) > mOverflowThreshold {
			atomic.AddInt64(&n.overflow, int64(len(b.m)-mOverflowThreshold))
		}
	}

	atomic.StoreInt64(&b.init, 1)
	b.lock.Unlock()
}

// inited reports whether the bucket's map has been initialized.
func (b *bucket) inited() bool {
	return atomic.LoadInt64(&b.init) == 1
}

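// freeze marks the bucket as frozen so concurrent writers back off and retry
// against the successor inode, then returns its map for migration.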
func (b *bucket) freeze() map[interface{}]interface{} {
	b.lock.Lock()
	b.frozen = true
	m := b.m
	b.lock.Unlock()
	return m
}

// clone returns a point-in-time snapshot of the bucket's entries, taken
// under the read lock.
func (b *bucket) clone() []entry {
	b.lock.RLock()
	entries := make([]entry, 0, len(b.m))
	for k, v := range b.m {
		entries = append(entries, entry{key: k, value: v})
	}
	b.lock.RUnlock()
	return entries
}

func (b *bucket) tryLoad(key interface{}) (value interface{}, ok bool) {
	b.lock.RLock()
	value, ok = b.m[key]
	b.lock.RUnlock()
	return
}

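// tryStore attempts to store key/value in the bucket. It fails (returns
// false) if the bucket has been frozen by a migration, in which case the
// caller retries against the current inode. When check is true, an existing
// key is left untouched (LoadOrStore semantics). A successful insert updates
// the global count and may kick off a background grow.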
func (b *bucket) tryStore(m *Cmap, n *inode, check bool, key, value interface{}) (done bool) {
	b.lock.Lock()
	if b.frozen {
		b.lock.Unlock()
		return
	}

	if check {
		if _, ok := b.m[key]; ok {
			b.lock.Unlock()
			return
		}
	}

	l0 := len(b.m) // Comparing lengths is cheaper than a second lookup for existence.
	b.m[key] = value
	length := len(b.m)
	b.lock.Unlock()

	if l0 == length {
		return true // The key already existed; the count is unchanged.
	}

	// Update the counters.
	grow := atomic.AddInt64(&m.count, 1) >= n.growThreshold
	if length > mOverflowThreshold {
		grow = grow || atomic.AddInt64(&n.overflow, 1) >= mOverflowGrowThreshold
	}

	// Grow: double the bucket count and install the new inode.
	if grow && atomic.CompareAndSwapInt64(&n.resizeInProgress, 0, 1) {
		nlen := len(n.buckets) << 1
		node := &inode{
			mask:            uintptr(nlen) - 1,
			pred:            unsafe.Pointer(n),
			growThreshold:   int64(nlen) * mOverflowThreshold,
			shrinkThreshold: int64(nlen) >> 1,
			buckets:         make([]bucket, nlen),
		}
		ok := atomic.CompareAndSwapPointer(&m.inode, unsafe.Pointer(n), unsafe.Pointer(node))
		if !ok {
			panic("BUG: failed swapping head")
		}
		go node.initBuckets()
	}

	return true
}

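// tryDelete attempts to delete key from the bucket. Like tryStore, it fails
// if the bucket has been frozen by a migration. A successful delete updates
// the global count and may kick off a background shrink, but never below
// mInitialSize buckets.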
func (b *bucket) tryDelete(m *Cmap, n *inode, key interface{}) (done bool) {
	b.lock.Lock()
	if b.frozen {
		b.lock.Unlock()
		return
	}

	l0 := len(b.m)
	delete(b.m, key)
	length := len(b.m)
	b.lock.Unlock()

	if l0 == length {
		return true // The key was not present; the count is unchanged.
	}

	// Update the counters.
	shrink := atomic.AddInt64(&m.count, -1) < n.shrinkThreshold
	if length >= mOverflowThreshold {
		atomic.AddInt64(&n.overflow, -1)
	}
	// Shrink: halve the bucket count and install the new inode.
	if shrink && len(n.buckets) > mInitialSize && atomic.CompareAndSwapInt64(&n.resizeInProgress, 0, 1) {
		nlen := len(n.buckets) >> 1
		node := &inode{
			mask:            uintptr(nlen) - 1,
			pred:            unsafe.Pointer(n),
			growThreshold:   int64(nlen) * mOverflowThreshold,
			shrinkThreshold: int64(nlen) >> 1,
			buckets:         make([]bucket, nlen),
		}
		ok := atomic.CompareAndSwapPointer(&m.inode, unsafe.Pointer(n), unsafe.Pointer(node))
		if !ok {
			panic("BUG: failed swapping head")
		}
		go node.initBuckets()
	}
	return true
}

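// ehash hashes an arbitrary comparable value by delegating to the runtime's
// hash function for empty interfaces (linked in below via go:linkname), so
// keys hash exactly as they would in a built-in map.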
func ehash(i interface{}) uintptr {
	return nilinterhash(noescape(unsafe.Pointer(&i)), 0xdeadbeef)
}

//go:linkname nilinterhash runtime.nilinterhash
func nilinterhash(p unsafe.Pointer, h uintptr) uintptr

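// noescape hides a pointer from escape analysis: the uintptr round trip with
// a no-op XOR breaks the compiler's tracking of p, so the interface header
// passed to nilinterhash does not force the key to escape to the heap.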
//go:nocheckptr
//go:nosplit
func noescape(p unsafe.Pointer) unsafe.Pointer {
	x := uintptr(p)
	return unsafe.Pointer(x ^ 0)
}