github.com/lrita/cmap@v0.0.0-20231108122212-cb084a67f554/map.go

//go:build go1.18
// +build go1.18

package cmap

import (
	"reflect"
	"sync"
	"sync/atomic"
	"unsafe"
)

// Map is a thread-safe generic map from a comparable key type K to a value
// type V. Interface key types are not supported.
// To avoid lock bottlenecks the map is divided into several shards.
type Map[K comparable, V any] struct {
	lock  sync.Mutex
	inode unsafe.Pointer // *inode2
	typ   *rtype
	count int64
}

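// A minimal usage sketch (illustrative only; it assumes the package is
// imported as "github.com/lrita/cmap"). The zero value of Map is ready to
// use:
//
//	var m cmap.Map[string, int]
//	m.Store("answer", 42)
//	if v, ok := m.Load("answer"); ok {
//		fmt.Println(v) // 42
//	}
//	actual, loaded := m.LoadOrStore("answer", 7) // actual == 42, loaded == true
//	_, _ = actual, loaded
//	m.Delete("answer")
//	fmt.Println(m.Count(), m.IsEmpty()) // 0 true
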
// bucket2 is a single shard of the map, guarded by its own RWMutex. Once
// frozen (by a resize), its contents are only read while being migrated to
// the successor inode.
type bucket2[K comparable, V any] struct {
	lock   sync.RWMutex
	init   int64
	m      map[K]V
	frozen bool
}

// entry2 is a key/value pair snapshotted by clone for Range.
type entry2[K any, V any] struct {
	key   K
	value V
}

// inode2 holds one generation of the shard table. During a resize the new
// inode keeps a pointer to its predecessor and migrates buckets lazily.
type inode2[K comparable, V any] struct {
	mask             uintptr
	overflow         int64
	growThreshold    int64
	shrinkThreshold  int64
	resizeInProgress int64
	pred             unsafe.Pointer // *inode2
	buckets          []bucket2[K, V]
}

// Store sets the value for a key.
func (m *Map[K, V]) Store(key K, value V) {
	hash := m.ehash(key)
	for {
		inode, b := m.getInodeAndBucket(hash)
		if b.tryStore(m, inode, false, key, value) {
			return
		}
	}
}

// Load returns the value stored in the map for a key, or the zero value of V
// if no value is present.
// The ok result indicates whether value was found in the map.
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
	hash := m.ehash(key)
	_, b := m.getInodeAndBucket(hash)
	return b.tryLoad(key)
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
	hash := m.ehash(key)
	for {
		inode, b := m.getInodeAndBucket(hash)
		actual, loaded = b.tryLoad(key)
		if loaded {
			return
		}
		if b.tryStore(m, inode, true, key, value) {
			return value, false
		}
	}
}

// Delete deletes the value for a key.
func (m *Map[K, V]) Delete(key K) {
	hash := m.ehash(key)
	for {
		inode, b := m.getInodeAndBucket(hash)
		if b.tryDelete(m, inode, key) {
			return
		}
	}
}

// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot of the Map's
// contents: no key will be visited more than once, but if the value for any key
// is stored or deleted concurrently, Range may reflect any mapping for that key
// from any point during the Range call.
//
// Range may be O(N) with the number of elements in the map even if f returns
// false after a constant number of calls.
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
	n := m.getInode()
	for i := 0; i < len(n.buckets); i++ {
		b := &(n.buckets[i])
		if !b.inited() {
			n.initBucket(m, uintptr(i))
		}
		for _, e := range b.clone() {
			if !f(e.key, e.value) {
				return
			}
		}
	}
}
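
// An illustrative sketch of collecting keys with Range (hypothetical helper,
// not part of this package):
//
//	func keysOf[K comparable, V any](m *Map[K, V]) []K {
//		keys := make([]K, 0, m.Count())
//		m.Range(func(k K, _ V) bool {
//			keys = append(keys, k)
//			return true // keep iterating
//		})
//		return keys
//	}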

// Count returns the number of elements within the map.
func (m *Map[K, V]) Count() int {
	return int(atomic.LoadInt64(&m.count))
}

// IsEmpty checks if the map is empty.
func (m *Map[K, V]) IsEmpty() bool {
	return m.Count() == 0
}

// getInode returns the current shard table, lazily allocating the initial
// one on first use.
func (m *Map[K, V]) getInode() *inode2[K, V] {
	n := (*inode2[K, V])(atomic.LoadPointer(&m.inode))
	if n == nil {
		m.lock.Lock()
		n = (*inode2[K, V])(atomic.LoadPointer(&m.inode))
		if n == nil {
			n = &inode2[K, V]{
				mask:            uintptr(mInitialSize - 1),
				growThreshold:   int64(mInitialSize * mOverflowThreshold),
				shrinkThreshold: 0,
				buckets:         make([]bucket2[K, V], mInitialSize),
			}
			atomic.StorePointer(&m.inode, unsafe.Pointer(n))
		}
		m.lock.Unlock()
	}
	return n
}

// getInodeAndBucket returns the current inode and the bucket the hash maps
// to, initializing the bucket if necessary.
func (m *Map[K, V]) getInodeAndBucket(hash uintptr) (*inode2[K, V], *bucket2[K, V]) {
	n := m.getInode()
	i := hash & n.mask
	b := &(n.buckets[i])
	if !b.inited() {
		n.initBucket(m, i)
	}
	return n, b
}
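
// Shard selection relies on the bucket count being a power of two, so
// hash & mask is equivalent to hash % len(buckets). For example, with 8
// buckets (mask == 7) a hash of 45 lands in bucket 45 & 7 == 5. The real
// initial size is mInitialSize, defined elsewhere in the package; 8 is used
// here purely for illustration.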

// initBuckets eagerly initializes every bucket of n and then detaches the
// predecessor, completing the resize.
func (n *inode2[K, V]) initBuckets(m *Map[K, V]) {
	for i := range n.buckets {
		n.initBucket(m, uintptr(i))
	}
	atomic.StorePointer(&n.pred, nil)
}

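// initBucket lazily initializes bucket i of inode n. If the inode still has a
// predecessor (a resize is in progress), the bucket first pulls its entries
// from there: on grow it rehashes the one old bucket that maps onto it, and
// on shrink it merges the two old buckets that collapse into it. The source
// buckets are frozen so that concurrent writers retry against the new inode
// instead of updating stale shards.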
func (n *inode2[K, V]) initBucket(m *Map[K, V], i uintptr) {
	b := &(n.buckets[i])
	b.lock.Lock()
	if b.inited() {
		b.lock.Unlock()
		return
	}

	b.m = make(map[K]V)
	p := (*inode2[K, V])(atomic.LoadPointer(&n.pred)) // predecessor
	if p != nil {
		if n.mask > p.mask {
			// Grow
			pb := &(p.buckets[i&p.mask])
			if !pb.inited() {
				p.initBucket(m, i&p.mask)
			}
			for k, v := range pb.freeze() {
				hash := m.ehash(k)
				if hash&n.mask == i {
					b.m[k] = v
				}
			}
		} else {
			// Shrink
			pb0 := &(p.buckets[i])
			if !pb0.inited() {
				p.initBucket(m, i)
			}
			pb1 := &(p.buckets[i+uintptr(len(n.buckets))])
			if !pb1.inited() {
				p.initBucket(m, i+uintptr(len(n.buckets)))
			}
			for k, v := range pb0.freeze() {
				b.m[k] = v
			}
			for k, v := range pb1.freeze() {
				b.m[k] = v
			}
		}
		if len(b.m) > mOverflowThreshold {
			atomic.AddInt64(&n.overflow, int64(len(b.m)-mOverflowThreshold))
		}
	}

	atomic.StoreInt64(&b.init, 1)
	b.lock.Unlock()
}

// inited reports whether the bucket has been initialized.
func (b *bucket2[K, V]) inited() bool {
	return atomic.LoadInt64(&b.init) == 1
}

// freeze marks the bucket as read-only and returns its underlying map so the
// entries can be migrated to a new inode.
func (b *bucket2[K, V]) freeze() map[K]V {
	b.lock.Lock()
	b.frozen = true
	m := b.m
	b.lock.Unlock()
	return m
}

// clone copies the bucket's entries into a slice so Range can iterate without
// holding the lock during the callback.
func (b *bucket2[K, V]) clone() []entry2[K, V] {
	b.lock.RLock()
	entries := make([]entry2[K, V], 0, len(b.m))
	for k, v := range b.m {
		entries = append(entries, entry2[K, V]{key: k, value: v})
	}
	b.lock.RUnlock()
	return entries
}

// tryLoad looks up key in the bucket.
func (b *bucket2[K, V]) tryLoad(key K) (value V, ok bool) {
	b.lock.RLock()
	value, ok = b.m[key]
	b.lock.RUnlock()
	return
}

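// tryStore inserts key/value into the bucket. It reports false, so the caller
// retries, when the bucket has been frozen by an in-flight resize or when
// check is true and the key already exists. A successful insert of a new key
// updates the element counter and may kick off an asynchronous grow of the
// shard table.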
func (b *bucket2[K, V]) tryStore(m *Map[K, V], n *inode2[K, V], check bool, key K, value V) (done bool) {
	b.lock.Lock()
	if b.frozen {
		b.lock.Unlock()
		return
	}

	if check {
		if _, ok := b.m[key]; ok {
			b.lock.Unlock()
			return
		}
	}

	l0 := len(b.m) // Comparing the length before and after is cheaper than a separate existence lookup.
	b.m[key] = value
	length := len(b.m)
	b.lock.Unlock()

	if l0 == length {
		return true
	}

	// Update counter
	grow := atomic.AddInt64(&m.count, 1) >= n.growThreshold
	if length > mOverflowThreshold {
		grow = grow || atomic.AddInt64(&n.overflow, 1) >= mOverflowGrowThreshold
	}

	// Grow
	if grow && atomic.CompareAndSwapInt64(&n.resizeInProgress, 0, 1) {
		nlen := len(n.buckets) << 1
		node := &inode2[K, V]{
			mask:            uintptr(nlen) - 1,
			pred:            unsafe.Pointer(n),
			growThreshold:   int64(nlen) * mOverflowThreshold,
			shrinkThreshold: int64(nlen) >> 1,
			buckets:         make([]bucket2[K, V], nlen),
		}
		ok := atomic.CompareAndSwapPointer(&m.inode, unsafe.Pointer(n), unsafe.Pointer(node))
		if !ok {
			panic("BUG: failed swapping head")
		}
		go node.initBuckets(m)
	}

	return true
}

// tryDelete removes key from the bucket. It reports false when the bucket has
// been frozen by an in-flight resize, in which case the caller retries. A
// successful removal updates the element counter and may kick off an
// asynchronous shrink of the shard table.
func (b *bucket2[K, V]) tryDelete(m *Map[K, V], n *inode2[K, V], key K) (done bool) {
	b.lock.Lock()
	if b.frozen {
		b.lock.Unlock()
		return
	}

	l0 := len(b.m)
	delete(b.m, key)
	length := len(b.m)
	b.lock.Unlock()

	if l0 == length {
		return true
	}

	// Update counter
	shrink := atomic.AddInt64(&m.count, -1) < n.shrinkThreshold
	if length >= mOverflowThreshold {
		atomic.AddInt64(&n.overflow, -1)
	}
	// Shrink
	if shrink && len(n.buckets) > mInitialSize && atomic.CompareAndSwapInt64(&n.resizeInProgress, 0, 1) {
		nlen := len(n.buckets) >> 1
		node := &inode2[K, V]{
			mask:            uintptr(nlen) - 1,
			pred:            unsafe.Pointer(n),
			growThreshold:   int64(nlen) * mOverflowThreshold,
			shrinkThreshold: int64(nlen) >> 1,
			buckets:         make([]bucket2[K, V], nlen),
		}
		ok := atomic.CompareAndSwapPointer(&m.inode, unsafe.Pointer(n), unsafe.Pointer(node))
		if !ok {
			panic("BUG: failed swapping head")
		}
		go node.initBuckets(m)
	}
	return true
}

// tflag is used by an rtype to signal what extra type information is
// available in the memory directly following the rtype value.
//
// tflag values must be kept in sync with copies in:
//	cmd/compile/internal/reflectdata/reflect.go
//	cmd/link/internal/ld/decodesym.go
//	runtime/type.go
type tflag uint8

const (
	// tflagUncommon means that there is a pointer, *uncommonType,
	// just beyond the outer type structure.
	//
	// For example, if t.Kind() == Struct and t.tflag&tflagUncommon != 0,
	// then t has uncommonType data and it can be accessed as:
	//
	//	type tUncommon struct {
	//		structType
	//		u uncommonType
	//	}
	//	u := &(*tUncommon)(unsafe.Pointer(t)).u
	tflagUncommon tflag = 1 << 0

	// tflagExtraStar means the name in the str field has an
	// extraneous '*' prefix. This is because for most types T in
	// a program, the type *T also exists and reusing the str data
	// saves binary size.
	tflagExtraStar tflag = 1 << 1

	// tflagNamed means the type has a name.
	tflagNamed tflag = 1 << 2

	// tflagRegularMemory means that equal and hash functions can treat
	// this type as a single region of t.size bytes.
	tflagRegularMemory tflag = 1 << 3
)

// rtype is the common implementation of most values.
// It is embedded in other struct types.
//
// rtype must be kept in sync with ../runtime/type.go:/^type._type.
type rtype struct {
	size       uintptr
	ptrdata    uintptr // number of bytes in the type that can contain pointers
	hash       uint32  // hash of type; avoids computation in hash tables
	tflag      tflag   // extra type information flags
	align      uint8   // alignment of variable with this type
	fieldAlign uint8   // alignment of struct field with this type
	kind       uint8   // enumeration for C
}

//func (t *rtype) IsRegularMemory() bool {
//	return t.tflag&tflagRegularMemory != 0
//}

// IsDirectIface reports whether the type is stored directly in an interface
// value (rather than behind a pointer).
func (t *rtype) IsDirectIface() bool {
	const kindDirectIface = 1 << 5
	return t.kind&kindDirectIface != 0
}

// eface must be kept in sync with ../src/runtime/runtime2.go:/^eface.
type eface struct {
	typ  *rtype
	data unsafe.Pointer
}

// efaceOf reinterprets an empty interface value as its runtime representation.
func efaceOf(ep *any) *eface {
	return (*eface)(unsafe.Pointer(ep))
}

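// ehash computes the hash of a key by handing it to the runtime's interface
// hash function: it captures K's *rtype once (rejecting interface key types),
// rebuilds an eface pointing at the key, and hashes that value. nilinterhash
// and noescape are declared elsewhere in the package (presumably linked to
// the runtime via go:linkname).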
func (m *Map[K, V]) ehash(i K) uintptr {
	if m.typ == nil {
		func() {
			m.lock.Lock()
			defer m.lock.Unlock()
			if m.typ == nil {
				// reflect.TypeOf(i) would report the dynamic type stored in an
				// interface key, so take the address and inspect the pointer's
				// element type instead.
				if typ := reflect.TypeOf(&i); typ.Elem().Kind() == reflect.Interface {
					panic("interface types are not supported")
				}
				var e any = i
				m.typ = efaceOf(&e).typ
			}
		}()
	}

	var f eface
	f.typ = m.typ
	if f.typ.IsDirectIface() {
		f.data = *(*unsafe.Pointer)(unsafe.Pointer(&i))
	} else {
		f.data = noescape(unsafe.Pointer(&i))
	}

	return nilinterhash(noescape(unsafe.Pointer(&f)), 0xdeadbeef)
}