github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/zarray/hashmap.go (about)

     1  //go:build go1.18
     2  // +build go1.18
     3  
     4  package zarray
     5  
     6  import (
     7  	"encoding/json"
     8  	"reflect"
     9  	"sort"
    10  	"strconv"
    11  	"sync/atomic"
    12  	"unsafe"
    13  
    14  	"github.com/sohaha/zlsgo/zutil"
    15  	"golang.org/x/exp/constraints"
    16  	"golang.org/x/sync/singleflight"
    17  )
    18  
    19  const (
    20  	// defaultSize is the default size for a zero allocated map
    21  	defaultSize = 8
    22  
    23  	// maxFillRate is the maximum fill rate for the slice before a resize will happen
    24  	maxFillRate = 50
    25  
    26  	// intSizeBytes is the size in byte of an int or uint value
    27  	intSizeBytes = strconv.IntSize >> 3
    28  )
    29  
    30  // indicates resizing operation status enums
    31  const (
    32  	notResizing uint32 = iota
    33  	resizingInProgress
    34  )
    35  
    36  type (
    37  	hashable interface {
    38  		constraints.Integer | constraints.Float | constraints.Complex | ~string | uintptr | unsafe.Pointer
    39  	}
    40  
    41  	metadata[K hashable, V any] struct {
    42  		count     *zutil.Uintptr
    43  		data      unsafe.Pointer
    44  		index     []*element[K, V]
    45  		keyshifts uintptr
    46  	}
    47  
    48  	// Maper implements the concurrent hashmap
    49  	Maper[K hashable, V any] struct {
    50  		gsf      singleflight.Group
    51  		listHead *element[K, V]
    52  		hasher   func(K) uintptr
    53  		metadata atomicPointer[metadata[K, V]]
    54  		resizing *zutil.Uint32
    55  		numItems *zutil.Uintptr
    56  	}
    57  
    58  	deletionRequest[K hashable] struct {
    59  		key     K
    60  		keyHash uintptr
    61  	}
    62  )
    63  
    64  func NewHashMap[K hashable, V any](size ...uintptr) *Maper[K, V] {
    65  	m := &Maper[K, V]{
    66  		listHead: newListHead[K, V](),
    67  		resizing: zutil.NewUint32(0),
    68  		numItems: zutil.NewUintptr(0),
    69  	}
    70  	m.numItems.Store(0)
    71  	if len(size) > 0 && size[0] != 0 {
    72  		m.allocate(size[0])
    73  	} else {
    74  		m.allocate(defaultSize)
    75  	}
    76  	m.setDefaultHasher()
    77  	return m
    78  }
    79  
    80  // Delete a map data structure.
    81  func (m *Maper[K, V]) Delete(keys ...K) {
    82  	size := len(keys)
    83  	switch {
    84  	case size == 0:
    85  		return
    86  	case size == 1:
    87  		var (
    88  			h        = m.hasher(keys[0])
    89  			existing = m.metadata.Load().indexElement(h)
    90  		)
    91  		if existing == nil || existing.keyHash > h {
    92  			existing = m.listHead.next()
    93  		}
    94  		for ; existing != nil && existing.keyHash <= h; existing = existing.next() {
    95  			if existing.key == keys[0] {
    96  				if existing.remove() {
    97  					m.removeItemFromIndex(existing)
    98  				}
    99  				return
   100  			}
   101  		}
   102  	default:
   103  		var (
   104  			delQueue = make([]deletionRequest[K], size)
   105  			iter     = 0
   106  		)
   107  		for idx := 0; idx < size; idx++ {
   108  			delQueue[idx].keyHash, delQueue[idx].key = m.hasher(keys[idx]), keys[idx]
   109  		}
   110  
   111  		sort.Slice(delQueue, func(i, j int) bool {
   112  			return delQueue[i].keyHash < delQueue[j].keyHash
   113  		})
   114  
   115  		elem := m.metadata.Load().indexElement(delQueue[0].keyHash)
   116  
   117  		if elem == nil || elem.keyHash > delQueue[0].keyHash {
   118  			elem = m.listHead.next()
   119  		}
   120  
   121  		for elem != nil && iter < size {
   122  			if elem.keyHash == delQueue[iter].keyHash && elem.key == delQueue[iter].key {
   123  				if elem.remove() {
   124  					m.removeItemFromIndex(elem)
   125  				}
   126  				iter++
   127  				elem = elem.next()
   128  			} else if elem.keyHash > delQueue[iter].keyHash {
   129  				iter++
   130  			} else {
   131  				elem = elem.next()
   132  			}
   133  		}
   134  	}
   135  }
   136  
   137  // Has a parameter "key" of type K and returns a boolean value "ok".
   138  func (m *Maper[K, V]) Has(key K) (ok bool) {
   139  	_, ok = m.get(m.hasher(key), key)
   140  	return
   141  }
   142  
   143  // Get `Maper` struct is used to retrieve the value associated with a given key
   144  // from the hashmap. It takes a key as input and returns the corresponding value and a boolean
   145  // indicating whether the key exists in the hashmap.
   146  func (m *Maper[K, V]) Get(key K) (value V, ok bool) {
   147  	return m.get(m.hasher(key), key)
   148  }
   149  
   150  func (m *Maper[K, V]) GetAndDelete(key K) (value V, ok bool) {
   151  	var (
   152  		h        = m.hasher(key)
   153  		existing = m.metadata.Load().indexElement(h)
   154  	)
   155  	if existing == nil || existing.keyHash > h {
   156  		existing = m.listHead.next()
   157  	}
   158  	for ; existing != nil && existing.keyHash <= h; existing = existing.next() {
   159  		if existing.key == key {
   160  			value, ok = *existing.value.Load(), !existing.isDeleted()
   161  			if existing.remove() {
   162  				m.removeItemFromIndex(existing)
   163  			}
   164  			return
   165  		}
   166  	}
   167  	return
   168  }
   169  
   170  func (m *Maper[K, V]) get(h uintptr, key K) (value V, ok bool) {
   171  	for elem := m.metadata.Load().indexElement(h); elem != nil && elem.keyHash <= h; elem = elem.nextPtr.Load() {
   172  		if elem.key == key {
   173  			ok = !elem.isDeleted()
   174  			if ok {
   175  				value = *elem.value.Load()
   176  			}
   177  			return
   178  		}
   179  	}
   180  	return
   181  }
   182  
   183  // ProvideGet `Maper` struct is used to retrieve the value associated with a
   184  // given key from the hashmap. If the key exists in the hashmap, the function returns the value and
   185  // sets the `loaded` flag to true. If the key does not exist, the function calls the `provide` function
   186  // to compute the value and sets the `computed` flag to true. The computed value is then added to the
   187  // hashmap and returned.
   188  func (m *Maper[K, V]) ProvideGet(key K, provide func() (V, bool)) (actual V, loaded, computed bool) {
   189  	var (
   190  		h        = m.hasher(key)
   191  		data     = m.metadata.Load()
   192  		existing = data.indexElement(h)
   193  	)
   194  
   195  	for elem := existing; elem != nil && elem.keyHash <= h; elem = elem.nextPtr.Load() {
   196  		if elem.key == key {
   197  			loaded = !elem.isDeleted()
   198  			if loaded {
   199  				actual = *elem.value.Load()
   200  				return
   201  			}
   202  		}
   203  	}
   204  
   205  	r := false
   206  	_, _, _ = m.gsf.Do(strconv.FormatInt(int64(h), 10), func() (interface{}, error) {
   207  		actual, loaded = provide()
   208  		if loaded {
   209  			m.Set(key, actual)
   210  			computed = true
   211  		}
   212  		r = true
   213  		return nil, nil
   214  	})
   215  
   216  	if !r {
   217  		actual, loaded = m.get(h, key)
   218  	}
   219  	return
   220  }
   221  
   222  func (m *Maper[K, V]) GetOrSet(key K, value V) (actual V, loaded bool) {
   223  	actual, loaded, _ = m.ProvideGet(key, func() (V, bool) {
   224  		return value, true
   225  	})
   226  	return
   227  }
   228  
   229  func (m *Maper[K, V]) Set(key K, value V) {
   230  	m.set(m.hasher(key), key, value)
   231  }
   232  
   233  func (m *Maper[K, V]) set(h uintptr, key K, value V) {
   234  	var (
   235  		created  bool
   236  		valPtr   = &value
   237  		alloc    *element[K, V]
   238  		data     = m.metadata.Load()
   239  		existing = data.indexElement(h)
   240  	)
   241  
   242  	if existing == nil || existing.keyHash > h {
   243  		existing = m.listHead
   244  	}
   245  
   246  	if alloc, created = existing.inject(h, key, valPtr); alloc != nil {
   247  		if created {
   248  			m.numItems.Add(1)
   249  		}
   250  	} else {
   251  		for existing = m.listHead; alloc == nil; alloc, created = existing.inject(h, key, valPtr) {
   252  		}
   253  		if created {
   254  			m.numItems.Add(1)
   255  		}
   256  	}
   257  
   258  	count := data.addItemToIndex(alloc)
   259  	if resizeNeeded(uintptr(len(data.index)), count) && m.resizing.CAS(notResizing, resizingInProgress) {
   260  		m.grow(0)
   261  	}
   262  }
   263  
   264  // Swap `Maper` struct is used to atomically swap the value associated with a
   265  // given key in the hashmap. It takes a key and a new value as input parameters and returns the old
   266  // value that was swapped out and a boolean indicating whether the swap was successful.
   267  func (m *Maper[K, V]) Swap(key K, newValue V) (oldValue V, swapped bool) {
   268  	var (
   269  		h        = m.hasher(key)
   270  		existing = m.metadata.Load().indexElement(h)
   271  	)
   272  
   273  	if existing == nil || existing.keyHash > h {
   274  		existing = m.listHead
   275  	}
   276  
   277  	if _, current, _ := existing.search(h, key); current != nil {
   278  		oldValue, swapped = *current.value.Swap(&newValue), true
   279  	} else {
   280  		swapped = false
   281  	}
   282  	return
   283  }
   284  
   285  // CAS `Maper` struct is used to perform a Compare-and-Swap operation on a
   286  // key-value pair in the hashmap.
   287  func (m *Maper[K, V]) CAS(key K, oldValue, newValue V) bool {
   288  	var (
   289  		h        = m.hasher(key)
   290  		existing = m.metadata.Load().indexElement(h)
   291  	)
   292  
   293  	if existing == nil || existing.keyHash > h {
   294  		existing = m.listHead
   295  	}
   296  
   297  	if _, current, _ := existing.search(h, key); current != nil {
   298  		if oldPtr := current.value.Load(); reflect.DeepEqual(*oldPtr, oldValue) {
   299  			return current.value.CompareAndSwap(oldPtr, &newValue)
   300  		}
   301  	}
   302  	return false
   303  }
   304  
   305  // ForEach `Maper` struct iterates over each key-value pair in the hashmap and
   306  // applies a lambda function to each pair. The lambda function takes a key and value as input
   307  // parameters and returns a boolean value. If the lambda function returns `true`, the iteration
   308  // continues to the next key-value pair. If the lambda function returns `false`, the iteration stops.
   309  func (m *Maper[K, V]) ForEach(lambda func(K, V) bool) {
   310  	for item := m.listHead.next(); item != nil && lambda(item.key, *item.value.Load()); item = item.next() {
   311  	}
   312  }
   313  
   314  func (m *Maper[K, V]) Grow(newSize uintptr) {
   315  	if m.resizing.CAS(notResizing, resizingInProgress) {
   316  		m.grow(newSize)
   317  	}
   318  }
   319  
   320  func (m *Maper[K, V]) SetHasher(hasher func(K) uintptr) {
   321  	m.hasher = hasher
   322  }
   323  
   324  func (m *Maper[K, V]) Len() uintptr {
   325  	return m.numItems.Load()
   326  }
   327  
   328  func (m *Maper[K, V]) Clear() {
   329  	index := make([]*element[K, V], defaultSize)
   330  	header := (*reflect.SliceHeader)(unsafe.Pointer(&index))
   331  
   332  	newdata := &metadata[K, V]{
   333  		keyshifts: strconv.IntSize - log2(defaultSize),
   334  		data:      unsafe.Pointer(header.Data),
   335  		index:     index,
   336  		count:     zutil.NewUintptr(0),
   337  	}
   338  
   339  	m.listHead.nextPtr.Store(nil)
   340  	m.metadata.Store(newdata)
   341  	m.numItems.Store(0)
   342  }
   343  
   344  // Keys returns the keys of the map.
   345  func (m *Maper[K, V]) Keys() (keys []K) {
   346  	keys = make([]K, m.Len())
   347  	var (
   348  		idx  = 0
   349  		item = m.listHead.next()
   350  	)
   351  	for item != nil {
   352  		keys[idx] = item.key
   353  		idx++
   354  		item = item.next()
   355  	}
   356  	return
   357  }
   358  
   359  // MarshalJSON convert the `Maper` object into a JSON-encoded byte slice.
   360  func (m *Maper[K, V]) MarshalJSON() ([]byte, error) {
   361  	gomap := make(map[K]V)
   362  	for i := m.listHead.next(); i != nil; i = i.next() {
   363  		gomap[i.key] = *i.value.Load()
   364  	}
   365  	return json.Marshal(gomap)
   366  }
   367  
   368  // UnmarshalJSON used to deserialize a JSON-encoded byte slice into a `Maper` object.
   369  func (m *Maper[K, V]) UnmarshalJSON(i []byte) error {
   370  	gomap := make(map[K]V)
   371  	err := json.Unmarshal(i, &gomap)
   372  	if err != nil {
   373  		return err
   374  	}
   375  	for k, v := range gomap {
   376  		m.Set(k, v)
   377  	}
   378  	return nil
   379  }
   380  
   381  // Fillrate calculates the fill rate of the hashmap.
   382  // It returns the percentage of slots in the hashmap that are currently occupied by elements.
   383  func (m *Maper[K, V]) Fillrate() uintptr {
   384  	data := m.metadata.Load()
   385  
   386  	return (data.count.Load() * 100) / uintptr(len(data.index))
   387  }
   388  
   389  func (m *Maper[K, V]) allocate(newSize uintptr) {
   390  	if m.resizing.CAS(notResizing, resizingInProgress) {
   391  		m.grow(newSize)
   392  	}
   393  }
   394  
   395  func (m *Maper[K, V]) fillIndexItems(mapData *metadata[K, V]) {
   396  	var (
   397  		first     = m.listHead.next()
   398  		item      = first
   399  		lastIndex = uintptr(0)
   400  	)
   401  
   402  	for item != nil {
   403  		index := item.keyHash >> mapData.keyshifts
   404  		if item == first || index != lastIndex {
   405  			mapData.addItemToIndex(item)
   406  			lastIndex = index
   407  		}
   408  		item = item.next()
   409  	}
   410  }
   411  
   412  func (m *Maper[K, V]) removeItemFromIndex(item *element[K, V]) {
   413  	for {
   414  		data := m.metadata.Load()
   415  		index := item.keyHash >> data.keyshifts
   416  		ptr := (*unsafe.Pointer)(unsafe.Pointer(uintptr(data.data) + index*intSizeBytes))
   417  
   418  		next := item.next()
   419  		if next != nil && next.keyHash>>data.keyshifts != index {
   420  			next = nil
   421  		}
   422  
   423  		swappedToNil := atomic.CompareAndSwapPointer(ptr, unsafe.Pointer(item), unsafe.Pointer(next)) && next == nil
   424  		if data == m.metadata.Load() {
   425  			m.numItems.Add(^uintptr(0))
   426  			if swappedToNil {
   427  				data.count.Add(^uintptr(0))
   428  			}
   429  			return
   430  		}
   431  	}
   432  }
   433  
   434  func (m *Maper[K, V]) grow(newSize uintptr) {
   435  	for {
   436  		currentStore := m.metadata.Load()
   437  		if newSize == 0 {
   438  			newSize = uintptr(len(currentStore.index)) << 1
   439  		} else {
   440  			newSize = roundUpPower2(newSize)
   441  		}
   442  
   443  		index := make([]*element[K, V], newSize)
   444  		header := (*reflect.SliceHeader)(unsafe.Pointer(&index))
   445  
   446  		newdata := &metadata[K, V]{
   447  			keyshifts: strconv.IntSize - log2(newSize),
   448  			data:      unsafe.Pointer(header.Data),
   449  			count:     zutil.NewUintptr(0),
   450  			index:     index,
   451  		}
   452  
   453  		m.fillIndexItems(newdata)
   454  		m.metadata.Store(newdata)
   455  
   456  		if !resizeNeeded(newSize, m.Len()) {
   457  			m.resizing.Store(notResizing)
   458  			return
   459  		}
   460  		newSize = 0
   461  	}
   462  }
   463  
   464  func (md *metadata[K, V]) indexElement(hashedKey uintptr) *element[K, V] {
   465  	index := hashedKey >> md.keyshifts
   466  	ptr := (*unsafe.Pointer)(unsafe.Pointer(uintptr(md.data) + index*intSizeBytes))
   467  	item := (*element[K, V])(atomic.LoadPointer(ptr))
   468  	for (item == nil || hashedKey < item.keyHash || item.isDeleted()) && index > 0 {
   469  		index--
   470  		ptr = (*unsafe.Pointer)(unsafe.Pointer(uintptr(md.data) + index*intSizeBytes))
   471  		item = (*element[K, V])(atomic.LoadPointer(ptr))
   472  	}
   473  	return item
   474  }
   475  
   476  func (md *metadata[K, V]) addItemToIndex(item *element[K, V]) uintptr {
   477  	index := item.keyHash >> md.keyshifts
   478  	ptr := (*unsafe.Pointer)(unsafe.Pointer(uintptr(md.data) + index*intSizeBytes))
   479  	for {
   480  		elem := (*element[K, V])(atomic.LoadPointer(ptr))
   481  		if elem == nil {
   482  			if atomic.CompareAndSwapPointer(ptr, nil, unsafe.Pointer(item)) {
   483  				return md.count.Add(1)
   484  			}
   485  			continue
   486  		}
   487  		if item.keyHash < elem.keyHash {
   488  			if !atomic.CompareAndSwapPointer(ptr, unsafe.Pointer(elem), unsafe.Pointer(item)) {
   489  				continue
   490  			}
   491  		}
   492  		return 0
   493  	}
   494  }
   495  
   496  func resizeNeeded(length, count uintptr) bool {
   497  	return (count*100)/length > maxFillRate
   498  }
   499  
   500  func roundUpPower2(i uintptr) uintptr {
   501  	i--
   502  	i |= i >> 1
   503  	i |= i >> 2
   504  	i |= i >> 4
   505  	i |= i >> 8
   506  	i |= i >> 16
   507  	i |= i >> 32
   508  	i++
   509  	return i
   510  }
   511  
   512  func log2(i uintptr) (n uintptr) {
   513  	for p := uintptr(1); p < i; p, n = p<<1, n+1 {
   514  	}
   515  	return
   516  }