github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/sync/locking/atomicptrmap_goroutine_unsafe.go

     1  package locking
     2  
     3  import (
     4  	"sync/atomic"
     5  	"unsafe"
     6  
     7  	"github.com/nicocha30/gvisor-ligolo/pkg/gohacks"
     8  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
     9  )
    10  
    11  const (
    12  	// ShardOrder is an optional parameter specifying the base-2 log of the
    13  	// number of shards per AtomicPtrMap. Higher values of ShardOrder reduce
    14  	// unnecessary synchronization between unrelated concurrent operations,
    15  	// improving performance for write-heavy workloads, but increase memory
    16  	// usage for small maps.
    17  	goroutineLocksShardOrder = 0
    18  )
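// Added note (not in the original source): with goroutineLocksShardOrder == 0,
// each goroutineLocksAtomicPtrMap has exactly one shard (1 << 0 == 1). A value
// of 2, for example, would give four shards, spreading unrelated writers over
// separate dirtyMu/rehashMu instances at the cost of a larger zero-value map.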
    19  
    20  // Hasher is an optional type parameter. If Hasher is provided, it must define
    21  // the Init and Hash methods. One Hasher will be shared by all AtomicPtrMaps.
    22  type goroutineLocksHasher struct {
    23  	goroutineLocksdefaultHasher
    24  }
    25  
    26  // defaultHasher is the default Hasher. This indirection exists because
    27  // defaultHasher must exist even if a custom Hasher is provided, to prevent the
    28  // Go compiler from complaining about defaultHasher's unused imports.
    29  type goroutineLocksdefaultHasher struct {
    30  	fn   func(unsafe.Pointer, uintptr) uintptr
    31  	seed uintptr
    32  }
    33  
    34  // Init initializes the Hasher.
    35  func (h *goroutineLocksdefaultHasher) Init() {
    36  	h.fn = sync.MapKeyHasher(map[int64]*goroutineLocks(nil))
    37  	h.seed = sync.RandUintptr()
    38  }
    39  
    40  // Hash returns the hash value for the given Key.
    41  func (h *goroutineLocksdefaultHasher) Hash(key int64) uintptr {
    42  	return h.fn(gohacks.Noescape(unsafe.Pointer(&key)), h.seed)
    43  }
    44  
    45  var goroutineLockshasher goroutineLocksHasher
    46  
    47  func init() {
    48  	goroutineLockshasher.Init()
    49  }
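
// Illustrative sketch (added; exampleShardFor is a hypothetical helper, not
// part of the original file): every operation below starts by hashing the
// int64 key with the package-level hasher and selecting a shard from the top
// ShardOrder bits of that hash.
func exampleShardFor(m *goroutineLocksAtomicPtrMap, key int64) *goroutineLocksapmShard {
	hash := goroutineLockshasher.Hash(key)
	return m.shard(hash)
}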
    50  
     51  // An AtomicPtrMap maps Keys to non-nil pointers to Values. AtomicPtrMaps are
    52  // safe for concurrent use from multiple goroutines without additional
    53  // synchronization.
    54  //
    55  // The zero value of AtomicPtrMap is empty (maps all Keys to nil) and ready for
    56  // use. AtomicPtrMaps must not be copied after first use.
    57  //
    58  // sync.Map may be faster than AtomicPtrMap if most operations on the map are
    59  // concurrent writes to a fixed set of keys. AtomicPtrMap is usually faster in
    60  // other circumstances.
    61  type goroutineLocksAtomicPtrMap struct {
    62  	shards [1 << goroutineLocksShardOrder]goroutineLocksapmShard
    63  }
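
// Usage sketch (added; exampleStoreLoadDelete is a hypothetical helper, not
// part of the original file). The zero value is immediately usable, Load
// returns nil for unmapped keys, and storing nil deletes a mapping:
func exampleStoreLoadDelete(key int64, val *goroutineLocks) {
	var m goroutineLocksAtomicPtrMap // zero value: empty and ready for use
	m.Store(key, val)                // key now maps to val
	_ = m.Load(key)                  // returns val
	m.Store(key, nil)                // deletes the mapping
	_ = m.Load(key)                  // returns nil again
}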
    64  
    65  func (m *goroutineLocksAtomicPtrMap) shard(hash uintptr) *goroutineLocksapmShard {
    66  	// Go defines right shifts >= width of shifted unsigned operand as 0, so
    67  	// this is correct even if ShardOrder is 0 (although nogo complains because
    68  	// nogo is dumb).
    69  	const indexLSB = unsafe.Sizeof(uintptr(0))*8 - goroutineLocksShardOrder
    70  	index := hash >> indexLSB
    71  	return (*goroutineLocksapmShard)(unsafe.Pointer(uintptr(unsafe.Pointer(&m.shards)) + (index * unsafe.Sizeof(goroutineLocksapmShard{}))))
    72  }
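// Added note: on 64-bit platforms indexLSB above is 64 when ShardOrder == 0,
// and Go defines hash >> 64 as 0 for unsigned operands, so every key selects
// shards[0]. With ShardOrder == 2, for example, indexLSB would be 62 and the
// shard index would be the top two bits of the hash.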
    73  
    74  type goroutineLocksapmShard struct {
    75  	goroutineLocksapmShardMutationData
    76  	_ [goroutineLocksapmShardMutationDataPadding]byte
    77  	goroutineLocksapmShardLookupData
    78  	_ [goroutineLocksapmShardLookupDataPadding]byte
    79  }
    80  
    81  type goroutineLocksapmShardMutationData struct {
    82  	dirtyMu  sync.Mutex // serializes slot transitions out of empty
    83  	dirty    uintptr    // # slots with val != nil
    84  	count    uintptr    // # slots with val != nil and val != tombstone()
    85  	rehashMu sync.Mutex // serializes rehashing
    86  }
    87  
    88  type goroutineLocksapmShardLookupData struct {
    89  	seq   sync.SeqCount  // allows atomic reads of slots+mask
    90  	slots unsafe.Pointer // [mask+1]slot or nil; protected by rehashMu/seq
    91  	mask  uintptr        // always (a power of 2) - 1; protected by rehashMu/seq
    92  }
    93  
    94  const (
    95  	goroutineLockscacheLineBytes = 64
    96  	// Cache line padding is enabled if sharding is.
    97  	goroutineLocksapmEnablePadding = (goroutineLocksShardOrder + 63) >> 6 // 0 if ShardOrder == 0, 1 otherwise
    98  	// The -1 and +1 below are required to ensure that if unsafe.Sizeof(T) %
    99  	// cacheLineBytes == 0, then padding is 0 (rather than cacheLineBytes).
   100  	goroutineLocksapmShardMutationDataRequiredPadding = goroutineLockscacheLineBytes - (((unsafe.Sizeof(goroutineLocksapmShardMutationData{}) - 1) % goroutineLockscacheLineBytes) + 1)
   101  	goroutineLocksapmShardMutationDataPadding         = goroutineLocksapmEnablePadding * goroutineLocksapmShardMutationDataRequiredPadding
   102  	goroutineLocksapmShardLookupDataRequiredPadding   = goroutineLockscacheLineBytes - (((unsafe.Sizeof(goroutineLocksapmShardLookupData{}) - 1) % goroutineLockscacheLineBytes) + 1)
   103  	goroutineLocksapmShardLookupDataPadding           = goroutineLocksapmEnablePadding * goroutineLocksapmShardLookupDataRequiredPadding
   104  
   105  	// These define fractional thresholds for when apmShard.rehash() is called
    106  // (i.e. the load factor) and when it rehashes to a larger table
   107  	// respectively. They are chosen such that the rehash threshold = the
   108  	// expansion threshold + 1/2, so that when reuse of deleted slots is rare
   109  	// or non-existent, rehashing occurs after the insertion of at least 1/2
   110  	// the table's size in new entries, which is acceptably infrequent.
   111  	goroutineLocksapmRehashThresholdNum    = 2
   112  	goroutineLocksapmRehashThresholdDen    = 3
   113  	goroutineLocksapmExpansionThresholdNum = 1
   114  	goroutineLocksapmExpansionThresholdDen = 6
   115  )
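// Worked example (added for illustration): with the 2/3 rehash threshold, a
// shard whose table has capacity 8 calls rehash() when filling an empty slot
// would bring dirty to 6 (6/8 >= 2/3). That rehash doubles the table only if
// live entries exceed 1/6 of the old capacity (count+1 >= 2 for capacity 8);
// otherwise the table is rebuilt at the same size, which simply discards
// tombstones. Also note that with ShardOrder == 0, goroutineLocksapmEnablePadding
// is 0 and the cache-line padding above occupies no space.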
   116  
   117  type goroutineLocksapmSlot struct {
   118  	// slot states are indicated by val:
   119  	//
   120  	//	* Empty: val == nil; key is meaningless. May transition to full or
   121  	//		evacuated with dirtyMu locked.
   122  	//
   123  	//	* Full: val != nil, tombstone(), or evacuated(); key is immutable. val
   124  	//		is the Value mapped to key. May transition to deleted or evacuated.
   125  	//
   126  	//	* Deleted: val == tombstone(); key is still immutable. key is mapped to
   127  	//		no Value. May transition to full or evacuated.
   128  	//
   129  	//	* Evacuated: val == evacuated(); key is immutable. Set by rehashing on
   130  	//		slots that have already been moved, requiring readers to wait for
   131  	//		rehashing to complete and use the new table. Terminal state.
   132  	//
   133  	// Note that once val is non-nil, it cannot become nil again. That is, the
   134  	// transition from empty to non-empty is irreversible for a given slot;
   135  	// the only way to create more empty slots is by rehashing.
   136  	val unsafe.Pointer
   137  	key int64
   138  }
   139  
   140  func goroutineLocksapmSlotAt(slots unsafe.Pointer, pos uintptr) *goroutineLocksapmSlot {
   141  	return (*goroutineLocksapmSlot)(unsafe.Pointer(uintptr(slots) + pos*unsafe.Sizeof(goroutineLocksapmSlot{})))
   142  }
   143  
   144  var goroutineLockstombstoneObj byte
   145  
   146  func goroutineLockstombstone() unsafe.Pointer {
   147  	return unsafe.Pointer(&goroutineLockstombstoneObj)
   148  }
   149  
   150  var goroutineLocksevacuatedObj byte
   151  
   152  func goroutineLocksevacuated() unsafe.Pointer {
   153  	return unsafe.Pointer(&goroutineLocksevacuatedObj)
   154  }
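
// Added note: tombstone() and evacuated() return the addresses of two distinct
// package-level bytes, so neither can ever equal a real *goroutineLocks value.
// A slot whose val is tombstone() is in the Deleted state described above; one
// whose val is evacuated() is in the terminal Evacuated state.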
   155  
   156  // Load returns the Value stored in m for key.
   157  func (m *goroutineLocksAtomicPtrMap) Load(key int64) *goroutineLocks {
   158  	hash := goroutineLockshasher.Hash(key)
   159  	shard := m.shard(hash)
   160  
   161  retry:
   162  	epoch := shard.seq.BeginRead()
   163  	slots := atomic.LoadPointer(&shard.slots)
   164  	mask := atomic.LoadUintptr(&shard.mask)
   165  	if !shard.seq.ReadOk(epoch) {
   166  		goto retry
   167  	}
   168  	if slots == nil {
   169  		return nil
   170  	}
   171  
   172  	i := hash & mask
   173  	inc := uintptr(1)
   174  	for {
   175  		slot := goroutineLocksapmSlotAt(slots, i)
   176  		slotVal := atomic.LoadPointer(&slot.val)
   177  		if slotVal == nil {
    178  			// Empty slot: key is not mapped, so the probe sequence ends here.
   179  			return nil
   180  		}
   181  		if slotVal == goroutineLocksevacuated() {
    182  			// Racing with rehashing; retry against the new table.
   183  			goto retry
   184  		}
   185  		if slot.key == key {
   186  			if slotVal == goroutineLockstombstone() {
   187  				return nil
   188  			}
   189  			return (*goroutineLocks)(slotVal)
   190  		}
   191  		i = (i + inc) & mask
   192  		inc++
   193  	}
   194  }
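// Probing note (added): the (i + inc) step with inc growing by one each
// iteration produces triangular offsets (1, 3, 6, 10, ...) from the initial
// position. Because the table size is a power of two, this sequence visits
// every slot before repeating, and the 2/3 load-factor cap above guarantees
// that an empty slot exists to end an unsuccessful probe.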
   195  
   196  // Store stores the Value val for key.
   197  func (m *goroutineLocksAtomicPtrMap) Store(key int64, val *goroutineLocks) {
   198  	m.maybeCompareAndSwap(key, false, nil, val)
   199  }
   200  
   201  // Swap stores the Value val for key and returns the previously-mapped Value.
   202  func (m *goroutineLocksAtomicPtrMap) Swap(key int64, val *goroutineLocks) *goroutineLocks {
   203  	return m.maybeCompareAndSwap(key, false, nil, val)
   204  }
   205  
   206  // CompareAndSwap checks that the Value stored for key is oldVal; if it is, it
   207  // stores the Value newVal for key. CompareAndSwap returns the previous Value
   208  // stored for key, whether or not it stores newVal.
   209  func (m *goroutineLocksAtomicPtrMap) CompareAndSwap(key int64, oldVal, newVal *goroutineLocks) *goroutineLocks {
   210  	return m.maybeCompareAndSwap(key, true, oldVal, newVal)
   211  }
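
// Usage sketch (added; exampleInsertIfAbsent and exampleDelete are
// hypothetical helpers, not part of the original API). Because a nil typed
// Value is represented internally by the tombstone, CompareAndSwap with a nil
// oldVal stores only when key is currently unmapped, and swapping in nil
// deletes a mapping.
func exampleInsertIfAbsent(m *goroutineLocksAtomicPtrMap, key int64, val *goroutineLocks) bool {
	// Returns true if val was stored, false if some other value was already mapped.
	return m.CompareAndSwap(key, nil, val) == nil
}

func exampleDelete(m *goroutineLocksAtomicPtrMap, key int64) *goroutineLocks {
	// Swap with nil removes the mapping and returns the previous Value, if any.
	return m.Swap(key, nil)
}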
   212  
   213  func (m *goroutineLocksAtomicPtrMap) maybeCompareAndSwap(key int64, compare bool, typedOldVal, typedNewVal *goroutineLocks) *goroutineLocks {
   214  	hash := goroutineLockshasher.Hash(key)
   215  	shard := m.shard(hash)
   216  	oldVal := goroutineLockstombstone()
   217  	if typedOldVal != nil {
   218  		oldVal = unsafe.Pointer(typedOldVal)
   219  	}
   220  	newVal := goroutineLockstombstone()
   221  	if typedNewVal != nil {
   222  		newVal = unsafe.Pointer(typedNewVal)
   223  	}
   224  
   225  retry:
   226  	epoch := shard.seq.BeginRead()
   227  	slots := atomic.LoadPointer(&shard.slots)
   228  	mask := atomic.LoadUintptr(&shard.mask)
   229  	if !shard.seq.ReadOk(epoch) {
   230  		goto retry
   231  	}
   232  	if slots == nil {
   233  		if (compare && oldVal != goroutineLockstombstone()) || newVal == goroutineLockstombstone() {
   234  			return nil
   235  		}
    236  		// No table allocated yet; create one and retry.
   237  		shard.rehash(nil)
   238  		goto retry
   239  	}
   240  
   241  	i := hash & mask
   242  	inc := uintptr(1)
   243  	for {
   244  		slot := goroutineLocksapmSlotAt(slots, i)
   245  		slotVal := atomic.LoadPointer(&slot.val)
   246  		if slotVal == nil {
   247  			if (compare && oldVal != goroutineLockstombstone()) || newVal == goroutineLockstombstone() {
   248  				return nil
   249  			}
    250  			// Try to claim this empty slot for key.
   251  			shard.dirtyMu.Lock()
   252  			slotVal = atomic.LoadPointer(&slot.val)
   253  			if slotVal == nil {
    254  				// Rehash first if claiming this slot would push the load factor past the threshold.
   255  				if dirty, capacity := shard.dirty+1, mask+1; dirty*goroutineLocksapmRehashThresholdDen >= capacity*goroutineLocksapmRehashThresholdNum {
   256  					shard.dirtyMu.Unlock()
   257  					shard.rehash(slots)
   258  					goto retry
   259  				}
   260  				slot.key = key
   261  				atomic.StorePointer(&slot.val, newVal)
   262  				shard.dirty++
   263  				atomic.AddUintptr(&shard.count, 1)
   264  				shard.dirtyMu.Unlock()
   265  				return nil
   266  			}
    267  			// Raced with another store: the slot is no longer empty. Fall through and re-examine the updated slotVal.
   268  			shard.dirtyMu.Unlock()
   269  		}
   270  		if slotVal == goroutineLocksevacuated() {
    271  			// Racing with rehashing; retry against the new table.
   272  			goto retry
   273  		}
   274  		if slot.key == key {
    275  			// The slot already holds key (full or deleted); update its value in place.
   276  			for {
   277  				if (compare && oldVal != slotVal) || newVal == slotVal {
   278  					if slotVal == goroutineLockstombstone() {
   279  						return nil
   280  					}
   281  					return (*goroutineLocks)(slotVal)
   282  				}
   283  				if atomic.CompareAndSwapPointer(&slot.val, slotVal, newVal) {
   284  					if slotVal == goroutineLockstombstone() {
   285  						atomic.AddUintptr(&shard.count, 1)
   286  						return nil
   287  					}
   288  					if newVal == goroutineLockstombstone() {
   289  						atomic.AddUintptr(&shard.count, ^uintptr(0))
   290  					}
   291  					return (*goroutineLocks)(slotVal)
   292  				}
   293  				slotVal = atomic.LoadPointer(&slot.val)
   294  				if slotVal == goroutineLocksevacuated() {
   295  					goto retry
   296  				}
   297  			}
   298  		}
   299  
   300  		i = (i + inc) & mask
   301  		inc++
   302  	}
   303  }
   304  
   305  // rehash is marked nosplit to avoid preemption during table copying.
   306  //
   307  //go:nosplit
   308  func (shard *goroutineLocksapmShard) rehash(oldSlots unsafe.Pointer) {
   309  	shard.rehashMu.Lock()
   310  	defer shard.rehashMu.Unlock()
   311  
   312  	if shard.slots != oldSlots {
    313  		// Raced with another call to rehash().
   314  		return
   315  	}
    316  	// Choose the new table size: always a power of two, and never smaller than the old table.
   317  	newSize := uintptr(8)
   318  	if oldSlots != nil {
   319  		oldSize := shard.mask + 1
   320  		newSize = oldSize
   321  		if count := atomic.LoadUintptr(&shard.count) + 1; count*goroutineLocksapmExpansionThresholdDen > oldSize*goroutineLocksapmExpansionThresholdNum {
   322  			newSize *= 2
   323  		}
   324  	}
   325  
   326  	newSlotsSlice := make([]goroutineLocksapmSlot, newSize)
   327  	newSlots := unsafe.Pointer(&newSlotsSlice[0])
   328  	newMask := newSize - 1
   329  
   330  	shard.dirtyMu.Lock()
   331  	shard.seq.BeginWrite()
   332  
   333  	if oldSlots != nil {
   334  		realCount := uintptr(0)
    335  		// Copy live entries from the old table into the new one, skipping tombstones.
   336  		oldMask := shard.mask
   337  		for i := uintptr(0); i <= oldMask; i++ {
   338  			oldSlot := goroutineLocksapmSlotAt(oldSlots, i)
   339  			val := atomic.SwapPointer(&oldSlot.val, goroutineLocksevacuated())
   340  			if val == nil || val == goroutineLockstombstone() {
   341  				continue
   342  			}
   343  			hash := goroutineLockshasher.Hash(oldSlot.key)
   344  			j := hash & newMask
   345  			inc := uintptr(1)
   346  			for {
   347  				newSlot := goroutineLocksapmSlotAt(newSlots, j)
   348  				if newSlot.val == nil {
   349  					newSlot.val = val
   350  					newSlot.key = oldSlot.key
   351  					break
   352  				}
   353  				j = (j + inc) & newMask
   354  				inc++
   355  			}
   356  			realCount++
   357  		}
    358  		// Tombstones were not copied, so the new table's dirty count equals its live count.
   359  		shard.dirty = realCount
   360  	}
   361  
   362  	atomic.StorePointer(&shard.slots, newSlots)
   363  	atomic.StoreUintptr(&shard.mask, newMask)
   364  
   365  	shard.seq.EndWrite()
   366  	shard.dirtyMu.Unlock()
   367  }
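// Concurrency note (added): rehash() publishes the new slots/mask under
// shard.seq, so Load and RangeRepeatable re-read them whenever the sequence
// changes mid-read. Each old slot is atomically swapped to evacuated() as it
// is copied, so a concurrent in-place CompareAndSwapPointer on that slot fails
// and its caller retries against the new table, while stores into empty slots
// are held off by dirtyMu for the duration of the copy.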
   368  
   369  // Range invokes f on each Key-Value pair stored in m. If any call to f returns
   370  // false, Range stops iteration and returns.
   371  //
   372  // Range does not necessarily correspond to any consistent snapshot of the
   373  // Map's contents: no Key will be visited more than once, but if the Value for
   374  // any Key is stored or deleted concurrently, Range may reflect any mapping for
   375  // that Key from any point during the Range call.
   376  //
   377  // f must not call other methods on m.
   378  func (m *goroutineLocksAtomicPtrMap) Range(f func(key int64, val *goroutineLocks) bool) {
   379  	for si := 0; si < len(m.shards); si++ {
   380  		shard := &m.shards[si]
   381  		if !shard.doRange(f) {
   382  			return
   383  		}
   384  	}
   385  }
   386  
   387  func (shard *goroutineLocksapmShard) doRange(f func(key int64, val *goroutineLocks) bool) bool {
    388  	// Hold rehashMu so a concurrent rehash cannot cause f to observe the same key twice.
   389  	shard.rehashMu.Lock()
   390  	defer shard.rehashMu.Unlock()
   391  	slots := shard.slots
   392  	if slots == nil {
   393  		return true
   394  	}
   395  	mask := shard.mask
   396  	for i := uintptr(0); i <= mask; i++ {
   397  		slot := goroutineLocksapmSlotAt(slots, i)
   398  		slotVal := atomic.LoadPointer(&slot.val)
   399  		if slotVal == nil || slotVal == goroutineLockstombstone() {
   400  			continue
   401  		}
   402  		if !f(slot.key, (*goroutineLocks)(slotVal)) {
   403  			return false
   404  		}
   405  	}
   406  	return true
   407  }
   408  
   409  // RangeRepeatable is like Range, but:
   410  //
   411  //   - RangeRepeatable may visit the same Key multiple times in the presence of
   412  //     concurrent mutators, possibly passing different Values to f in different
   413  //     calls.
   414  //
   415  //   - It is safe for f to call other methods on m.
   416  func (m *goroutineLocksAtomicPtrMap) RangeRepeatable(f func(key int64, val *goroutineLocks) bool) {
   417  	for si := 0; si < len(m.shards); si++ {
   418  		shard := &m.shards[si]
   419  
   420  	retry:
   421  		epoch := shard.seq.BeginRead()
   422  		slots := atomic.LoadPointer(&shard.slots)
   423  		mask := atomic.LoadUintptr(&shard.mask)
   424  		if !shard.seq.ReadOk(epoch) {
   425  			goto retry
   426  		}
   427  		if slots == nil {
   428  			continue
   429  		}
   430  
   431  		for i := uintptr(0); i <= mask; i++ {
   432  			slot := goroutineLocksapmSlotAt(slots, i)
   433  			slotVal := atomic.LoadPointer(&slot.val)
   434  			if slotVal == goroutineLocksevacuated() {
   435  				goto retry
   436  			}
   437  			if slotVal == nil || slotVal == goroutineLockstombstone() {
   438  				continue
   439  			}
   440  			if !f(slot.key, (*goroutineLocks)(slotVal)) {
   441  				return
   442  			}
   443  		}
   444  	}
   445  }
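
// Usage sketch (added; exampleCount and examplePurge are hypothetical helpers,
// not part of the original API). Range forbids f from calling back into m, so
// it suits pure read-outs such as counting; RangeRepeatable allows f to mutate
// m, at the cost of possibly visiting a key more than once.
func exampleCount(m *goroutineLocksAtomicPtrMap) uint64 {
	var n uint64
	m.Range(func(key int64, val *goroutineLocks) bool {
		n++
		return true // keep iterating
	})
	return n
}

func examplePurge(m *goroutineLocksAtomicPtrMap) {
	m.RangeRepeatable(func(key int64, val *goroutineLocks) bool {
		m.Store(key, nil) // deleting during iteration is allowed here, unlike in Range
		return true
	})
}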