github.com/maypok86/otter@v1.2.1/internal/hashtable/map.go

// Copyright (c) 2023 Alexey Mayshev. All rights reserved.
// Copyright (c) 2021 Andrey Pechkurov
//
// Copyright notice. This code is a fork of xsync.MapOf from this file with some changes:
// https://github.com/puzpuzpuz/xsync/blob/main/mapof.go
//
// Use of this source code is governed by a MIT license that can be found
// at https://github.com/puzpuzpuz/xsync/blob/main/LICENSE

package hashtable

import (
	"fmt"
	"sync"
	"sync/atomic"
	"unsafe"

	"github.com/dolthub/maphash"

	"github.com/maypok86/otter/internal/generated/node"
	"github.com/maypok86/otter/internal/xmath"
	"github.com/maypok86/otter/internal/xruntime"
)

type resizeHint int

const (
	growHint   resizeHint = 0
	shrinkHint resizeHint = 1
	clearHint  resizeHint = 2
)

const (
	// number of entries per bucket.
	// 3, because we need to fit them into 1 cache line (64 bytes).
	bucketSize = 3
	// fraction of the table's capacity at which the table will be expanded.
	loadFactor = 0.75
	// divisor used to compute the shrink threshold: a shrink is attempted
	// when, after deleting the last entry in a bucket chain, the table holds
	// at most capacity/shrinkFraction entries.
	shrinkFraction   = 128
	minBucketCount   = 32
	minNodeCount     = bucketSize * minBucketCount
	minCounterLength = 8
	maxCounterLength = 32
)
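
// Illustrative arithmetic for the constants above (not used by the code): the
// minimum table has 32 buckets and therefore 32*3 = 96 root-bucket slots. It
// is grown once more than 96*0.75 = 72 entries are stored, and its shrink
// threshold is 96/128 = 0, so a minimum-sized table is never shrunk (resize
// also checks minBucketCount explicitly).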

// Map is like a Go map[K]V but is safe for concurrent
// use by multiple goroutines without additional locking or
// coordination.
//
// A Map must not be copied after first use.
//
// Map uses a modified version of the Cache-Line Hash Table (CLHT)
// data structure: https://github.com/LPD-EPFL/CLHT
//
// CLHT is built around the idea of organizing the hash table in
// cache-line-sized buckets, so that on all modern CPUs update
// operations complete with at most one cache-line transfer.
// Also, Get operations involve no writes to memory and take no
// mutexes or any other sort of locks. Due to this design, in all
// considered scenarios Map outperforms sync.Map.
type Map[K comparable, V any] struct {
	table unsafe.Pointer

	nodeManager *node.Manager[K, V]
	// only used along with resizeCond
	resizeMutex sync.Mutex
	// used to wake up resize waiters (concurrent modifications)
	resizeCond sync.Cond
	// resize in progress flag; updated atomically
	resizing atomic.Int64
}
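
// A minimal usage sketch (illustrative only; manager is an assumed
// *node.Manager[string, int] and n an assumed node.Node[string, int] created
// through it, both of which are constructed outside this file):
//
//	m := NewWithSize[string, int](manager, 1000)
//	prev := m.Set(n)      // nil if n was inserted, otherwise the replaced node
//	got, ok := m.Get("a") // lock-free read
//	m.Delete("a")
//	fmt.Println(m.Size(), prev, got, ok)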

type table[K comparable] struct {
	buckets []paddedBucket
	// sharded counter for the number of table entries;
	// used to determine whether the table needs shrinking.
	// Occupies min(buckets_memory/1024, 64KB) of memory.
	size   []paddedCounter
	mask   uint64
	hasher maphash.Hasher[K]
}

func (t *table[K]) addSize(bucketIdx uint64, delta int) {
	counterIdx := uint64(len(t.size)-1) & bucketIdx
	atomic.AddInt64(&t.size[counterIdx].c, int64(delta))
}

func (t *table[K]) addSizePlain(bucketIdx uint64, delta int) {
	counterIdx := uint64(len(t.size)-1) & bucketIdx
	t.size[counterIdx].c += int64(delta)
}

func (t *table[K]) sumSize() int64 {
	sum := int64(0)
	for i := range t.size {
		sum += atomic.LoadInt64(&t.size[i].c)
	}
	return sum
}

func (t *table[K]) calcShiftHash(key K) uint64 {
	// uint64(0) is a reserved value which stands for an empty slot.
	h := t.hasher.Hash(key)
	if h == uint64(0) {
		return 1
	}

	return h
}

type counter struct {
	c int64
}

type paddedCounter struct {
	// padding prevents false sharing.
	padding [xruntime.CacheLineSize - unsafe.Sizeof(counter{})]byte

	counter
}
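
// Assuming a 64-byte cache line (the value of xruntime.CacheLineSize on most
// platforms) and an 8-byte counter, the padding above is 56 bytes, so each
// counter shard occupies its own cache line and concurrent updates to
// different shards do not cause false sharing.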

// NewWithSize creates a new Map instance with capacity enough
// to hold size nodes. If size is zero or negative, the value
// is ignored.
func NewWithSize[K comparable, V any](nodeManager *node.Manager[K, V], size int) *Map[K, V] {
	return newMap[K, V](nodeManager, size)
}

// New creates a new Map instance.
func New[K comparable, V any](nodeManager *node.Manager[K, V]) *Map[K, V] {
	return newMap[K, V](nodeManager, minNodeCount)
}

func newMap[K comparable, V any](nodeManager *node.Manager[K, V], size int) *Map[K, V] {
	m := &Map[K, V]{
		nodeManager: nodeManager,
	}
	m.resizeCond = *sync.NewCond(&m.resizeMutex)
	var t *table[K]
	if size <= minNodeCount {
		t = newTable(minBucketCount, maphash.NewHasher[K]())
	} else {
		bucketCount := xmath.RoundUpPowerOf2(uint32(size / bucketSize))
		t = newTable(int(bucketCount), maphash.NewHasher[K]())
	}
	atomic.StorePointer(&m.table, unsafe.Pointer(t))
	return m
}
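
// Illustrative sizing example: NewWithSize(manager, 1000) asks for room for
// 1000 nodes, which requires 1000/3 = 333 buckets; RoundUpPowerOf2 rounds this
// up to 512 buckets, i.e. 512*3 = 1536 root-bucket slots.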

func newTable[K comparable](bucketCount int, prevHasher maphash.Hasher[K]) *table[K] {
	buckets := make([]paddedBucket, bucketCount)
	counterLength := bucketCount >> 10
	if counterLength < minCounterLength {
		counterLength = minCounterLength
	} else if counterLength > maxCounterLength {
		counterLength = maxCounterLength
	}
	counter := make([]paddedCounter, counterLength)
	mask := uint64(len(buckets) - 1)
	t := &table[K]{
		buckets: buckets,
		size:    counter,
		mask:    mask,
		hasher:  maphash.NewSeed[K](prevHasher),
	}
	return t
}
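
// The number of counter shards scales with the table size: tables of up to
// 8192 buckets get the minimum of 8 shards, a 32768-bucket table gets
// 32768>>10 = 32 shards, and larger tables are capped at 32 shards. Since the
// shard count is always a power of two, addSize selects a shard with a simple
// mask instead of a modulo.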

// Get returns the node.Node stored in the map for a key, or nil if no node is present.
//
// The ok result indicates whether the node was found in the map.
func (m *Map[K, V]) Get(key K) (got node.Node[K, V], ok bool) {
	t := (*table[K])(atomic.LoadPointer(&m.table))
	hash := t.calcShiftHash(key)
	bucketIdx := hash & t.mask
	b := &t.buckets[bucketIdx]
	for {
		for i := 0; i < bucketSize; i++ {
			// we treat the hash code only as a hint, so there is no
			// need to get an atomic snapshot.
			h := atomic.LoadUint64(&b.hashes[i])
			if h == uint64(0) || h != hash {
				continue
			}
			// we found a matching hash code
			nodePtr := atomic.LoadPointer(&b.nodes[i])
			if nodePtr == nil {
				// a concurrent write to this slot is in progress
				continue
			}
			n := m.nodeManager.FromPointer(nodePtr)
			if key != n.Key() {
				continue
			}

			return n, true
		}
		bucketPtr := atomic.LoadPointer(&b.next)
		if bucketPtr == nil {
			return nil, false
		}
		b = (*paddedBucket)(bucketPtr)
	}
}
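
// Reads (Get above) are safe without locks because writers publish a slot in
// two steps: the hash word is stored first and the node pointer second (see
// set below), while deletion clears the hash first and the pointer second. A
// reader that observes a matching hash but a nil pointer is racing with a
// writer and simply skips the slot; the key is re-checked because distinct
// keys can share a hash value.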

// Set sets the node.Node for the key.
//
// Returns the previous (replaced) node for the key, or nil if the node was
// inserted without replacing an existing one.
func (m *Map[K, V]) Set(n node.Node[K, V]) node.Node[K, V] {
	return m.set(n, false)
}

// SetIfAbsent sets the node.Node only if the specified key is not already
// present in the map. Returns nil if the given node was inserted; otherwise
// the map is left unchanged and the given node is returned.
func (m *Map[K, V]) SetIfAbsent(n node.Node[K, V]) node.Node[K, V] {
	return m.set(n, true)
}
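
// For example (illustrative, with n1 and n2 being distinct nodes that share
// the same key):
//
//	m.Set(n1)         // returns nil: n1 was inserted
//	m.Set(n2)         // returns n1: replaced in place
//	m.SetIfAbsent(n1) // returns n1 (non-nil): the map still holds n2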

func (m *Map[K, V]) set(n node.Node[K, V], onlyIfAbsent bool) node.Node[K, V] {
	for {
	RETRY:
		var (
			emptyBucket *paddedBucket
			emptyIdx    int
		)
		t := (*table[K])(atomic.LoadPointer(&m.table))
		tableLen := len(t.buckets)
		hash := t.calcShiftHash(n.Key())
		bucketIdx := hash & t.mask
		rootBucket := &t.buckets[bucketIdx]
		rootBucket.mutex.Lock()
		// the following two checks must be done in the reverse order
		// of the corresponding updates in the resize method.
		if m.resizeInProgress() {
			// resize is in progress. wait, then go for another attempt.
			rootBucket.mutex.Unlock()
			m.waitForResize()
			goto RETRY
		}
		if m.newerTableExists(t) {
			// someone resized the table, go for another attempt.
			rootBucket.mutex.Unlock()
			goto RETRY
		}
		b := rootBucket
		for {
			for i := 0; i < bucketSize; i++ {
				h := b.hashes[i]
				if h == uint64(0) {
					if emptyBucket == nil {
						emptyBucket = b
						emptyIdx = i
					}
					continue
				}
				if h != hash {
					continue
				}
				prev := m.nodeManager.FromPointer(b.nodes[i])
				if n.Key() != prev.Key() {
					continue
				}
				if onlyIfAbsent {
					// the key is already present, drop the set.
					rootBucket.mutex.Unlock()
					return n
				}
				// in-place update: atomically publish the new node pointer so
				// that a concurrent Get observes either the previous node or
				// the new one, never a partially written slot.
				atomic.StorePointer(&b.nodes[i], n.AsPointer())
				rootBucket.mutex.Unlock()
				return prev
			}
			if b.next == nil {
				if emptyBucket != nil {
					// insertion into an existing bucket.
					// first we update the hash, then the node.
					atomic.StoreUint64(&emptyBucket.hashes[emptyIdx], hash)
					atomic.StorePointer(&emptyBucket.nodes[emptyIdx], n.AsPointer())
					rootBucket.mutex.Unlock()
					t.addSize(bucketIdx, 1)
					return nil
				}
				growThreshold := float64(tableLen) * bucketSize * loadFactor
				if t.sumSize() > int64(growThreshold) {
					// need to grow the table, then go for another attempt.
					rootBucket.mutex.Unlock()
					m.resize(t, growHint)
					goto RETRY
				}
				// insertion into a new bucket.
				// create and append the bucket.
				newBucket := &paddedBucket{}
				newBucket.hashes[0] = hash
				newBucket.nodes[0] = n.AsPointer()
				atomic.StorePointer(&b.next, unsafe.Pointer(newBucket))
				rootBucket.mutex.Unlock()
				t.addSize(bucketIdx, 1)
				return nil
			}
			b = (*paddedBucket)(b.next)
		}
	}
}

// Delete deletes the node for a key.
//
// Returns the deleted node or nil if the node wasn't deleted.
func (m *Map[K, V]) Delete(key K) node.Node[K, V] {
	return m.delete(key, func(n node.Node[K, V]) bool {
		return key == n.Key()
	})
}

// DeleteNode evicts the given node for a key, but only if the node currently
// stored in the map is the same node (compared via node.Equals).
//
// Returns the evicted node or nil if the node wasn't evicted.
func (m *Map[K, V]) DeleteNode(n node.Node[K, V]) node.Node[K, V] {
	return m.delete(n.Key(), func(current node.Node[K, V]) bool {
		return node.Equals(n, current)
	})
}

func (m *Map[K, V]) delete(key K, cmp func(node.Node[K, V]) bool) node.Node[K, V] {
	for {
	RETRY:
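		// hintNonEmpty counts occupied slots that did not match the key; if it
		// is still zero when the matching entry is deleted, that entry may have
		// been the last one in its bucket chain, so it is worth checking
		// isEmpty and possibly shrinking the table.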
		hintNonEmpty := 0
		t := (*table[K])(atomic.LoadPointer(&m.table))
		hash := t.calcShiftHash(key)
		bucketIdx := hash & t.mask
		rootBucket := &t.buckets[bucketIdx]
		rootBucket.mutex.Lock()
		// the following two checks must be done in the reverse order
		// of the corresponding updates in the resize method.
		if m.resizeInProgress() {
			// resize is in progress. Wait, then go for another attempt.
			rootBucket.mutex.Unlock()
			m.waitForResize()
			goto RETRY
		}
		if m.newerTableExists(t) {
			// someone resized the table. Go for another attempt.
			rootBucket.mutex.Unlock()
			goto RETRY
		}
		b := rootBucket
		for {
			for i := 0; i < bucketSize; i++ {
				h := b.hashes[i]
				if h == uint64(0) {
					continue
				}
				if h != hash {
					hintNonEmpty++
					continue
				}
				current := m.nodeManager.FromPointer(b.nodes[i])
				if !cmp(current) {
					hintNonEmpty++
					continue
				}
				// Deletion.
				// First we update the hash, then the node.
				atomic.StoreUint64(&b.hashes[i], uint64(0))
				atomic.StorePointer(&b.nodes[i], nil)
				leftEmpty := false
				if hintNonEmpty == 0 {
					leftEmpty = b.isEmpty()
				}
				rootBucket.mutex.Unlock()
				t.addSize(bucketIdx, -1)
				// Might need to shrink the table.
				if leftEmpty {
					m.resize(t, shrinkHint)
				}
				return current
			}
			if b.next == nil {
				// not found
				rootBucket.mutex.Unlock()
				return nil
			}
			b = (*paddedBucket)(b.next)
		}
	}
}

func (m *Map[K, V]) resize(known *table[K], hint resizeHint) {
	knownTableLen := len(known.buckets)
	// fast path for shrink attempts.
	if hint == shrinkHint {
		shrinkThreshold := int64((knownTableLen * bucketSize) / shrinkFraction)
		if knownTableLen == minBucketCount || known.sumSize() > shrinkThreshold {
			return
		}
	}
	// slow path.
	if !m.resizing.CompareAndSwap(0, 1) {
		// someone else started resize. Wait for it to finish.
		m.waitForResize()
		return
	}
	var nt *table[K]
	t := (*table[K])(atomic.LoadPointer(&m.table))
	tableLen := len(t.buckets)
	switch hint {
	case growHint:
		// grow the table by a factor of 2.
		nt = newTable(tableLen<<1, t.hasher)
	case shrinkHint:
		shrinkThreshold := int64((tableLen * bucketSize) / shrinkFraction)
		if tableLen > minBucketCount && t.sumSize() <= shrinkThreshold {
			// shrink the table by a factor of 2.
			nt = newTable(tableLen>>1, t.hasher)
		} else {
			// no need to shrink, wake up all waiters and give up.
			m.resizeMutex.Lock()
			m.resizing.Store(0)
			m.resizeCond.Broadcast()
			m.resizeMutex.Unlock()
			return
		}
	case clearHint:
		nt = newTable(minBucketCount, t.hasher)
	default:
		panic(fmt.Sprintf("unexpected resize hint: %d", hint))
	}
	// copy the data only if we're not clearing the hashtable.
	if hint != clearHint {
		for i := 0; i < tableLen; i++ {
			copied := m.copyBuckets(&t.buckets[i], nt)
			nt.addSizePlain(uint64(i), copied)
		}
	}
	// publish the new table and wake up all waiters.
	atomic.StorePointer(&m.table, unsafe.Pointer(nt))
	m.resizeMutex.Lock()
	m.resizing.Store(0)
	m.resizeCond.Broadcast()
	m.resizeMutex.Unlock()
}
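
// The resizing flag and resizeCond form a simple handshake: the single
// resizer holds the flag at 1 while it copies buckets, writers that observe
// the flag under their root-bucket lock release the lock and block in
// waitForResize, and the resizer publishes the new table before broadcasting.
// As a worked example of the shrink fast path above: a 64-bucket table has a
// capacity of 64*3 = 192 slots and a shrink threshold of 192/128 = 1, so it
// is halved only once at most one entry remains.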

func (m *Map[K, V]) copyBuckets(b *paddedBucket, dest *table[K]) (copied int) {
	rootBucket := b
	rootBucket.mutex.Lock()
	for {
		for i := 0; i < bucketSize; i++ {
			if b.nodes[i] == nil {
				continue
			}
			n := m.nodeManager.FromPointer(b.nodes[i])
			hash := dest.calcShiftHash(n.Key())
			bucketIdx := hash & dest.mask
			dest.buckets[bucketIdx].add(hash, b.nodes[i])
			copied++
		}
		if b.next == nil {
			rootBucket.mutex.Unlock()
			return copied
		}
		b = (*paddedBucket)(b.next)
	}
}
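
// Note that copyBuckets recomputes each node's hash with the destination
// table's hasher instead of reusing the stored hash word: the destination has
// a different mask, and its hasher may be seeded differently from the source
// table's (see maphash.NewSeed in newTable).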

func (m *Map[K, V]) newerTableExists(table *table[K]) bool {
	currentTable := atomic.LoadPointer(&m.table)
	return uintptr(currentTable) != uintptr(unsafe.Pointer(table))
}

func (m *Map[K, V]) resizeInProgress() bool {
	return m.resizing.Load() == 1
}

func (m *Map[K, V]) waitForResize() {
	m.resizeMutex.Lock()
	for m.resizeInProgress() {
		m.resizeCond.Wait()
	}
	m.resizeMutex.Unlock()
}

// Range calls f sequentially for each node present in the
// map. If f returns false, Range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot
// of the Map's contents: no key will be visited more than once, but
// if the value for any key is stored or deleted concurrently, Range
// may reflect any mapping for that key from any point during the
// Range call.
//
// It is safe to modify the map while iterating it. However, the
// concurrent modification rules apply, i.e. the changes may not be
// reflected in the subsequently iterated nodes.
func (m *Map[K, V]) Range(f func(node.Node[K, V]) bool) {
	var zeroPtr unsafe.Pointer
	// Pre-allocate a slice big enough to fit the nodes of most bucket chains.
	buffer := make([]unsafe.Pointer, 0, 16*bucketSize)
	tp := atomic.LoadPointer(&m.table)
	t := *(*table[K])(tp)
	for i := range t.buckets {
		rootBucket := &t.buckets[i]
		b := rootBucket
		// Prevent concurrent modifications and copy all nodes into
		// the intermediate slice.
		rootBucket.mutex.Lock()
		for {
			for i := 0; i < bucketSize; i++ {
				if b.nodes[i] != nil {
					buffer = append(buffer, b.nodes[i])
				}
			}
			if b.next == nil {
				rootBucket.mutex.Unlock()
				break
			}
			b = (*paddedBucket)(b.next)
		}
		// Call the function for all copied nodes.
		for j := range buffer {
			n := m.nodeManager.FromPointer(buffer[j])
			if !f(n) {
				return
			}
			// Remove the reference to allow the copied nodes to be GCed before this method finishes.
			buffer[j] = zeroPtr
		}
		buffer = buffer[:0]
	}
}
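
// For example (illustrative, for a Map[string, int]):
//
//	m.Range(func(n node.Node[string, int]) bool {
//		fmt.Println(n.Key())
//		return true // returning false stops the iteration
//	})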

// Clear deletes all keys and values currently stored in the map.
func (m *Map[K, V]) Clear() {
	table := (*table[K])(atomic.LoadPointer(&m.table))
	m.resize(table, clearHint)
}

// Size returns the current size of the map. Under concurrent updates the
// result is an approximation, since the per-shard counters are summed without
// a global snapshot.
func (m *Map[K, V]) Size() int {
	table := (*table[K])(atomic.LoadPointer(&m.table))
	return int(table.sumSize())
}