github.com/lrita/cmap@v0.0.0-20231108122212-cb084a67f554/cmap.go (about) 1 package cmap 2 3 import ( 4 "sync" 5 "sync/atomic" 6 "unsafe" 7 ) 8 9 const ( 10 mInitialSize = 1 << 4 11 mOverflowThreshold = 1 << 6 12 mOverflowGrowThreshold = 1 << 7 13 ) 14 15 // Cmap is a "thread" safe map of type AnyComparableType:Any. 16 // To avoid lock bottlenecks this map is dived to several map shards. 17 // We can store different type key and value into the same map. 18 type Cmap struct { 19 lock sync.Mutex 20 inode unsafe.Pointer // *inode 21 count int64 22 } 23 24 type inode struct { 25 mask uintptr 26 overflow int64 27 growThreshold int64 28 shrinkThreshold int64 29 resizeInProgress int64 30 pred unsafe.Pointer // *inode 31 buckets []bucket 32 } 33 34 type entry struct { 35 key, value interface{} 36 } 37 38 type bucket struct { 39 lock sync.RWMutex 40 init int64 41 m map[interface{}]interface{} 42 frozen bool 43 } 44 45 // Store sets the value for a key. 46 func (m *Cmap) Store(key, value interface{}) { 47 hash := ehash(key) 48 for { 49 inode, b := m.getInodeAndBucket(hash) 50 if b.tryStore(m, inode, false, key, value) { 51 return 52 } 53 } 54 } 55 56 // Load returns the value stored in the map for a key, or nil if no 57 // value is present. 58 // The ok result indicates whether value was found in the map. 59 func (m *Cmap) Load(key interface{}) (value interface{}, ok bool) { 60 hash := ehash(key) 61 _, b := m.getInodeAndBucket(hash) 62 return b.tryLoad(key) 63 } 64 65 // LoadOrStore returns the existing value for the key if present. 66 // Otherwise, it stores and returns the given value. 67 // The loaded result is true if the value was loaded, false if stored. 68 func (m *Cmap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) { 69 hash := ehash(key) 70 for { 71 inode, b := m.getInodeAndBucket(hash) 72 actual, loaded = b.tryLoad(key) 73 if loaded { 74 return 75 } 76 if b.tryStore(m, inode, true, key, value) { 77 return value, false 78 } 79 } 80 } 81 82 // Delete deletes the value for a key. 83 func (m *Cmap) Delete(key interface{}) { 84 hash := ehash(key) 85 for { 86 inode, b := m.getInodeAndBucket(hash) 87 if b.tryDelete(m, inode, key) { 88 return 89 } 90 } 91 } 92 93 // Range calls f sequentially for each key and value present in the map. 94 // If f returns false, range stops the iteration. 95 // 96 // Range does not necessarily correspond to any consistent snapshot of the Map's 97 // contents: no key will be visited more than once, but if the value for any key 98 // is stored or deleted concurrently, Range may reflect any mapping for that key 99 // from any point during the Range call. 100 // 101 // Range may be O(N) with the number of elements in the map even if f returns 102 // false after a constant number of calls. 103 func (m *Cmap) Range(f func(key, value interface{}) bool) { 104 n := m.getInode() 105 for i := 0; i < len(n.buckets); i++ { 106 b := &(n.buckets[i]) 107 if !b.inited() { 108 n.initBucket(uintptr(i)) 109 } 110 for _, e := range b.clone() { 111 if !f(e.key, e.value) { 112 return 113 } 114 } 115 } 116 } 117 118 // Count returns the number of elements within the map. 119 func (m *Cmap) Count() int { 120 return int(atomic.LoadInt64(&m.count)) 121 } 122 123 // IsEmpty checks if map is empty. 124 func (m *Cmap) IsEmpty() bool { 125 return m.Count() == 0 126 } 127 128 func (m *Cmap) getInode() *inode { 129 n := (*inode)(atomic.LoadPointer(&m.inode)) 130 if n == nil { 131 m.lock.Lock() 132 n = (*inode)(atomic.LoadPointer(&m.inode)) 133 if n == nil { 134 n = &inode{ 135 mask: uintptr(mInitialSize - 1), 136 growThreshold: int64(mInitialSize * mOverflowThreshold), 137 shrinkThreshold: 0, 138 buckets: make([]bucket, mInitialSize), 139 } 140 atomic.StorePointer(&m.inode, unsafe.Pointer(n)) 141 } 142 m.lock.Unlock() 143 } 144 return n 145 } 146 147 func (m *Cmap) getInodeAndBucket(hash uintptr) (*inode, *bucket) { 148 n := m.getInode() 149 i := hash & n.mask 150 b := &(n.buckets[i]) 151 if !b.inited() { 152 n.initBucket(i) 153 } 154 return n, b 155 } 156 157 func (n *inode) initBuckets() { 158 for i := range n.buckets { 159 n.initBucket(uintptr(i)) 160 } 161 atomic.StorePointer(&n.pred, nil) 162 } 163 164 func (n *inode) initBucket(i uintptr) { 165 b := &(n.buckets[i]) 166 b.lock.Lock() 167 if b.inited() { 168 b.lock.Unlock() 169 return 170 } 171 172 b.m = make(map[interface{}]interface{}) 173 p := (*inode)(atomic.LoadPointer(&n.pred)) // predecessor 174 if p != nil { 175 if n.mask > p.mask { 176 // Grow 177 pb := &(p.buckets[i&p.mask]) 178 if !pb.inited() { 179 p.initBucket(i & p.mask) 180 } 181 for k, v := range pb.freeze() { 182 hash := ehash(k) 183 if hash&n.mask == i { 184 b.m[k] = v 185 } 186 } 187 } else { 188 // Shrink 189 pb0 := &(p.buckets[i]) 190 if !pb0.inited() { 191 p.initBucket(i) 192 } 193 pb1 := &(p.buckets[i+uintptr(len(n.buckets))]) 194 if !pb1.inited() { 195 p.initBucket(i + uintptr(len(n.buckets))) 196 } 197 for k, v := range pb0.freeze() { 198 b.m[k] = v 199 } 200 for k, v := range pb1.freeze() { 201 b.m[k] = v 202 } 203 } 204 if len(b.m) > mOverflowThreshold { 205 atomic.AddInt64(&n.overflow, int64(len(b.m)-mOverflowThreshold)) 206 } 207 } 208 209 atomic.StoreInt64(&b.init, 1) 210 b.lock.Unlock() 211 } 212 213 func (b *bucket) inited() bool { 214 return atomic.LoadInt64(&b.init) == 1 215 } 216 217 func (b *bucket) freeze() map[interface{}]interface{} { 218 b.lock.Lock() 219 b.frozen = true 220 m := b.m 221 b.lock.Unlock() 222 return m 223 } 224 225 func (b *bucket) clone() []entry { 226 b.lock.RLock() 227 entries := make([]entry, 0, len(b.m)) 228 for k, v := range b.m { 229 entries = append(entries, entry{key: k, value: v}) 230 } 231 b.lock.RUnlock() 232 return entries 233 } 234 235 func (b *bucket) tryLoad(key interface{}) (value interface{}, ok bool) { 236 b.lock.RLock() 237 value, ok = b.m[key] 238 b.lock.RUnlock() 239 return 240 } 241 242 func (b *bucket) tryStore(m *Cmap, n *inode, check bool, key, value interface{}) (done bool) { 243 b.lock.Lock() 244 if b.frozen { 245 b.lock.Unlock() 246 return 247 } 248 249 if check { 250 if _, ok := b.m[key]; ok { 251 b.lock.Unlock() 252 return 253 } 254 } 255 256 l0 := len(b.m) // Using length check existence is faster than accessing. 257 b.m[key] = value 258 length := len(b.m) 259 b.lock.Unlock() 260 261 if l0 == length { 262 return true 263 } 264 265 // Update counter 266 grow := atomic.AddInt64(&m.count, 1) >= n.growThreshold 267 if length > mOverflowThreshold { 268 grow = grow || atomic.AddInt64(&n.overflow, 1) >= mOverflowGrowThreshold 269 } 270 271 // Grow 272 if grow && atomic.CompareAndSwapInt64(&n.resizeInProgress, 0, 1) { 273 nlen := len(n.buckets) << 1 274 node := &inode{ 275 mask: uintptr(nlen) - 1, 276 pred: unsafe.Pointer(n), 277 growThreshold: int64(nlen) * mOverflowThreshold, 278 shrinkThreshold: int64(nlen) >> 1, 279 buckets: make([]bucket, nlen), 280 } 281 ok := atomic.CompareAndSwapPointer(&m.inode, unsafe.Pointer(n), unsafe.Pointer(node)) 282 if !ok { 283 panic("BUG: failed swapping head") 284 } 285 go node.initBuckets() 286 } 287 288 return true 289 } 290 291 func (b *bucket) tryDelete(m *Cmap, n *inode, key interface{}) (done bool) { 292 b.lock.Lock() 293 if b.frozen { 294 b.lock.Unlock() 295 return 296 } 297 298 l0 := len(b.m) 299 delete(b.m, key) 300 length := len(b.m) 301 b.lock.Unlock() 302 303 if l0 == length { 304 return true 305 } 306 307 // Update counter 308 shrink := atomic.AddInt64(&m.count, -1) < n.shrinkThreshold 309 if length >= mOverflowThreshold { 310 atomic.AddInt64(&n.overflow, -1) 311 } 312 // Shrink 313 if shrink && len(n.buckets) > mInitialSize && atomic.CompareAndSwapInt64(&n.resizeInProgress, 0, 1) { 314 nlen := len(n.buckets) >> 1 315 node := &inode{ 316 mask: uintptr(nlen) - 1, 317 pred: unsafe.Pointer(n), 318 growThreshold: int64(nlen) * mOverflowThreshold, 319 shrinkThreshold: int64(nlen) >> 1, 320 buckets: make([]bucket, nlen), 321 } 322 ok := atomic.CompareAndSwapPointer(&m.inode, unsafe.Pointer(n), unsafe.Pointer(node)) 323 if !ok { 324 panic("BUG: failed swapping head") 325 } 326 go node.initBuckets() 327 } 328 return true 329 } 330 331 func ehash(i interface{}) uintptr { 332 return nilinterhash(noescape(unsafe.Pointer(&i)), 0xdeadbeef) 333 } 334 335 //go:linkname nilinterhash runtime.nilinterhash 336 func nilinterhash(p unsafe.Pointer, h uintptr) uintptr 337 338 //go:nocheckptr 339 //go:nosplit 340 func noescape(p unsafe.Pointer) unsafe.Pointer { 341 x := uintptr(p) 342 return unsafe.Pointer(x ^ 0) 343 }