github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package atomicptrmap doesn't exist. This file must be instantiated using the
    16  // go_template_instance rule in tools/go_generics/defs.bzl.
    17  package atomicptrmap
    18  
    19  import (
    20  	"sync/atomic"
    21  	"unsafe"
    22  
    23  	"github.com/SagerNet/gvisor/pkg/gohacks"
    24  	"github.com/SagerNet/gvisor/pkg/sync"
    25  )
    26  
    27  // Key is a required type parameter.
    28  type Key struct{}
    29  
    30  // Value is a required type parameter.
    31  type Value struct{}
    32  
    33  const (
    34  	// ShardOrder is an optional parameter specifying the base-2 log of the
    35  	// number of shards per AtomicPtrMap. Higher values of ShardOrder reduce
    36  	// unnecessary synchronization between unrelated concurrent operations,
    37  	// improving performance for write-heavy workloads, but increase memory
    38  	// usage for small maps.
    39  	ShardOrder = 0
    40  )
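        // Illustrative note (not part of the template): with ShardOrder = 2, an
        // AtomicPtrMap embeds 1<<2 = 4 apmShards, so writers whose keys hash to
        // different shards never contend on the same dirtyMu or rehashMu, at the
        // cost of up to 4 independently-allocated tables even for a small map.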
    41  
    42  // Hasher is an optional type parameter. If Hasher is provided, it must define
    43  // the Init and Hash methods. One Hasher will be shared by all AtomicPtrMaps.
    44  type Hasher struct {
    45  	defaultHasher
    46  }
    47  
    48  // defaultHasher is the default Hasher. This indirection exists because
    49  // defaultHasher must exist even if a custom Hasher is provided, to prevent the
    50  // Go compiler from complaining about defaultHasher's unused imports.
    51  type defaultHasher struct {
    52  	fn   func(unsafe.Pointer, uintptr) uintptr
    53  	seed uintptr
    54  }
    55  
    56  // Init initializes the Hasher.
    57  func (h *defaultHasher) Init() {
    58  	h.fn = sync.MapKeyHasher(map[Key]*Value(nil))
    59  	h.seed = sync.RandUintptr()
    60  }
    61  
    62  // Hash returns the hash value for the given Key.
    63  func (h *defaultHasher) Hash(key Key) uintptr {
    64  	return h.fn(gohacks.Noescape(unsafe.Pointer(&key)), h.seed)
    65  }
    66  
    67  var hasher Hasher
    68  
    69  func init() {
    70  	hasher.Init()
    71  }
    72  
    73  // An AtomicPtrMap maps Keys to non-nil pointers to Values. AtomicPtrMaps are
    74  // safe for concurrent use from multiple goroutines without additional
    75  // synchronization.
    76  //
    77  // The zero value of AtomicPtrMap is empty (maps all Keys to nil) and ready for
    78  // use. AtomicPtrMaps must not be copied after first use.
    79  //
    80  // sync.Map may be faster than AtomicPtrMap if most operations on the map are
    81  // concurrent writes to a fixed set of keys. AtomicPtrMap is usually faster in
    82  // other circumstances.
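        //
        // A purely illustrative usage sketch (k and v stand for values of the
        // instantiated Key and *Value types; they are not defined by this template):
        //
        //	var m AtomicPtrMap
        //	m.Store(k, v)  // map k to v; storing a nil *Value deletes the mapping
        //	if got := m.Load(k); got != nil {
        //		// k is currently mapped to got.
        //	}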
    83  type AtomicPtrMap struct {
    84  	// AtomicPtrMap is implemented as a hash table with the following
    85  	// properties:
    86  	//
    87  	// * Collisions are resolved with quadratic probing. Of the two major
    88  	// alternatives, Robin Hood linear probing makes it difficult for writers
    89  	// to execute in parallel, and bucketing is less effective in Go due to
    90  	// lack of SIMD.
    91  	//
    92  	// * The table is optionally divided into shards indexed by hash to further
    93  	// reduce unnecessary synchronization.
    94  
    95  	shards [1 << ShardOrder]apmShard
    96  }
    97  
    98  func (m *AtomicPtrMap) shard(hash uintptr) *apmShard {
    99  	// Go defines right shifts >= width of shifted unsigned operand as 0, so
   100  	// this is correct even if ShardOrder is 0 (although nogo complains because
   101  	// nogo is dumb).
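        	// For example, on a 64-bit platform with ShardOrder = 2, indexLSB below
        	// is 64-2 = 62, so index = hash >> 62 selects one of the 4 shards by the
        	// top 2 bits of the hash; with ShardOrder = 0 the shift count equals the
        	// word width and the index is always 0.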
   102  	const indexLSB = unsafe.Sizeof(uintptr(0))*8 - ShardOrder
   103  	index := hash >> indexLSB
   104  	return (*apmShard)(unsafe.Pointer(uintptr(unsafe.Pointer(&m.shards)) + (index * unsafe.Sizeof(apmShard{}))))
   105  }
   106  
   107  type apmShard struct {
   108  	apmShardMutationData
   109  	_ [apmShardMutationDataPadding]byte
   110  	apmShardLookupData
   111  	_ [apmShardLookupDataPadding]byte
   112  }
   113  
   114  type apmShardMutationData struct {
   115  	dirtyMu  sync.Mutex // serializes slot transitions out of empty
   116  	dirty    uintptr    // # slots with val != nil
   117  	count    uintptr    // # slots with val != nil and val != tombstone()
   118  	rehashMu sync.Mutex // serializes rehashing
   119  }
   120  
   121  type apmShardLookupData struct {
   122  	seq   sync.SeqCount  // allows atomic reads of slots+mask
   123  	slots unsafe.Pointer // [mask+1]slot or nil; protected by rehashMu/seq
   124  	mask  uintptr        // always (a power of 2) - 1; protected by rehashMu/seq
   125  }
   126  
   127  const (
   128  	cacheLineBytes = 64
   129  	// Cache line padding is enabled if sharding is.
   130  	apmEnablePadding = (ShardOrder + 63) >> 6 // 0 if ShardOrder == 0, 1 otherwise
   131  	// The -1 and +1 below are required to ensure that if unsafe.Sizeof(T) %
   132  	// cacheLineBytes == 0, then padding is 0 (rather than cacheLineBytes).
   133  	apmShardMutationDataRequiredPadding = cacheLineBytes - (((unsafe.Sizeof(apmShardMutationData{}) - 1) % cacheLineBytes) + 1)
   134  	apmShardMutationDataPadding         = apmEnablePadding * apmShardMutationDataRequiredPadding
   135  	apmShardLookupDataRequiredPadding   = cacheLineBytes - (((unsafe.Sizeof(apmShardLookupData{}) - 1) % cacheLineBytes) + 1)
   136  	apmShardLookupDataPadding           = apmEnablePadding * apmShardLookupDataRequiredPadding
   137  
   138  	// These define fractional thresholds for when apmShard.rehash() is called
   139  	// (i.e. the load factor) and when it rehashes to a larger table
   140  	// respectively. They are chosen such that the rehash threshold = the
   141  	// expansion threshold + 1/2, so that when reuse of deleted slots is rare
   142  	// or non-existent, rehashing occurs after the insertion of at least 1/2
   143  	// the table's size in new entries, which is acceptably infrequent.
   144  	apmRehashThresholdNum    = 2
   145  	apmRehashThresholdDen    = 3
   146  	apmExpansionThresholdNum = 1
   147  	apmExpansionThresholdDen = 6
   148  )
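        // Worked example for the thresholds above (approximate, for illustration):
        // 2/3 = 1/6 + 1/2, which is the "rehash threshold = expansion threshold +
        // 1/2" relationship. For a 64-slot shard, a store triggers rehashing once
        // about 2/3 * 64 ~ 43 slots have been dirtied, and the rehash doubles the
        // table only if more than about 64/6 ~ 11 of them hold live values rather
        // than tombstones. Also note that apmEnablePadding = (ShardOrder+63)>>6 is 0
        // when ShardOrder == 0 and 1 for any ShardOrder in [1, 64].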
   149  
   150  type apmSlot struct {
   151  	// slot states are indicated by val:
   152  	//
   153  	// * Empty: val == nil; key is meaningless. May transition to full or
   154  	// evacuated with dirtyMu locked.
   155  	//
   156  	// * Full: val is neither nil, tombstone(), nor evacuated(); key is immutable.
   157  	// val is the Value mapped to key. May transition to deleted or evacuated.
   158  	//
   159  	// * Deleted: val == tombstone(); key is still immutable. key is mapped to
   160  	// no Value. May transition to full or evacuated.
   161  	//
   162  	// * Evacuated: val == evacuated(); key is immutable. Set by rehashing on
   163  	// slots that have already been moved, requiring readers to wait for
   164  	// rehashing to complete and use the new table. Terminal state.
   165  	//
   166  	// Note that once val is non-nil, it cannot become nil again. That is, the
   167  	// transition from empty to non-empty is irreversible for a given slot;
   168  	// the only way to create more empty slots is by rehashing.
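        	//
        	// In summary (rehashing holds dirtyMu and a seq writer critical section):
        	//
        	//	empty  -> full         store of a new key, with dirtyMu locked
        	//	full  <-> deleted      CAS of val between a Value and tombstone()
        	//	any    -> evacuated    rehashing; terminal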
   169  	val unsafe.Pointer
   170  	key Key
   171  }
   172  
   173  func apmSlotAt(slots unsafe.Pointer, pos uintptr) *apmSlot {
   174  	return (*apmSlot)(unsafe.Pointer(uintptr(slots) + pos*unsafe.Sizeof(apmSlot{})))
   175  }
   176  
   177  var tombstoneObj byte
   178  
   179  func tombstone() unsafe.Pointer {
   180  	return unsafe.Pointer(&tombstoneObj)
   181  }
   182  
   183  var evacuatedObj byte
   184  
   185  func evacuated() unsafe.Pointer {
   186  	return unsafe.Pointer(&evacuatedObj)
   187  }
   188  
   189  // Load returns the Value stored in m for key.
   190  func (m *AtomicPtrMap) Load(key Key) *Value {
   191  	hash := hasher.Hash(key)
   192  	shard := m.shard(hash)
   193  
   194  retry:
   195  	epoch := shard.seq.BeginRead()
   196  	slots := atomic.LoadPointer(&shard.slots)
   197  	mask := atomic.LoadUintptr(&shard.mask)
   198  	if !shard.seq.ReadOk(epoch) {
   199  		goto retry
   200  	}
   201  	if slots == nil {
   202  		return nil
   203  	}
   204  
   205  	i := hash & mask
   206  	inc := uintptr(1)
   207  	for {
   208  		slot := apmSlotAt(slots, i)
   209  		slotVal := atomic.LoadPointer(&slot.val)
   210  		if slotVal == nil {
   211  			// Empty slot; end of probe sequence.
   212  			return nil
   213  		}
   214  		if slotVal == evacuated() {
   215  			// Racing with rehashing.
   216  			goto retry
   217  		}
   218  		if slot.key == key {
   219  			if slotVal == tombstone() {
   220  				return nil
   221  			}
   222  			return (*Value)(slotVal)
   223  		}
   224  		i = (i + inc) & mask
   225  		inc++
   226  	}
   227  }
   228  
   229  // Store stores the Value val for key.
   230  func (m *AtomicPtrMap) Store(key Key, val *Value) {
   231  	m.maybeCompareAndSwap(key, false, nil, val)
   232  }
   233  
   234  // Swap stores the Value val for key and returns the previously-mapped Value.
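        // (Illustratively, Swap(key, nil) removes any existing mapping for key and
        // returns the removed Value, or nil if key was unmapped.)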
   235  func (m *AtomicPtrMap) Swap(key Key, val *Value) *Value {
   236  	return m.maybeCompareAndSwap(key, false, nil, val)
   237  }
   238  
   239  // CompareAndSwap checks that the Value stored for key is oldVal; if it is, it
   240  // stores the Value newVal for key. CompareAndSwap returns the previous Value
   241  // stored for key, whether or not it stores newVal.
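        //
        // A nil *Value stands for "no mapping", so, illustratively,
        // CompareAndSwap(key, nil, newVal) stores newVal only if key is currently
        // unmapped, and CompareAndSwap(key, oldVal, nil) deletes key only if it is
        // still mapped to oldVal; in both cases the return value is what was
        // previously stored for key.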
   242  func (m *AtomicPtrMap) CompareAndSwap(key Key, oldVal, newVal *Value) *Value {
   243  	return m.maybeCompareAndSwap(key, true, oldVal, newVal)
   244  }
   245  
   246  func (m *AtomicPtrMap) maybeCompareAndSwap(key Key, compare bool, typedOldVal, typedNewVal *Value) *Value {
   247  	hash := hasher.Hash(key)
   248  	shard := m.shard(hash)
   249  	oldVal := tombstone()
   250  	if typedOldVal != nil {
   251  		oldVal = unsafe.Pointer(typedOldVal)
   252  	}
   253  	newVal := tombstone()
   254  	if typedNewVal != nil {
   255  		newVal = unsafe.Pointer(typedNewVal)
   256  	}
   257  
   258  retry:
   259  	epoch := shard.seq.BeginRead()
   260  	slots := atomic.LoadPointer(&shard.slots)
   261  	mask := atomic.LoadUintptr(&shard.mask)
   262  	if !shard.seq.ReadOk(epoch) {
   263  		goto retry
   264  	}
   265  	if slots == nil {
   266  		if (compare && oldVal != tombstone()) || newVal == tombstone() {
   267  			return nil
   268  		}
   269  		// Need to allocate a table before insertion.
   270  		shard.rehash(nil)
   271  		goto retry
   272  	}
   273  
   274  	i := hash & mask
   275  	inc := uintptr(1)
   276  	for {
   277  		slot := apmSlotAt(slots, i)
   278  		slotVal := atomic.LoadPointer(&slot.val)
   279  		if slotVal == nil {
   280  			if (compare && oldVal != tombstone()) || newVal == tombstone() {
   281  				return nil
   282  			}
   283  			// Try to grab this slot for ourselves.
   284  			shard.dirtyMu.Lock()
   285  			slotVal = atomic.LoadPointer(&slot.val)
   286  			if slotVal == nil {
   287  				// Check if we need to rehash before dirtying a slot.
   288  				if dirty, capacity := shard.dirty+1, mask+1; dirty*apmRehashThresholdDen >= capacity*apmRehashThresholdNum {
   289  					shard.dirtyMu.Unlock()
   290  					shard.rehash(slots)
   291  					goto retry
   292  				}
   293  				slot.key = key
   294  				atomic.StorePointer(&slot.val, newVal) // transitions slot to full
   295  				shard.dirty++
   296  				atomic.AddUintptr(&shard.count, 1)
   297  				shard.dirtyMu.Unlock()
   298  				return nil
   299  			}
   300  			// Raced with another store; the slot is no longer empty. Continue
   301  			// with the new value of slotVal since we may have raced with
   302  			// another store of key.
   303  			shard.dirtyMu.Unlock()
   304  		}
   305  		if slotVal == evacuated() {
   306  			// Racing with rehashing.
   307  			goto retry
   308  		}
   309  		if slot.key == key {
   310  			// We're reusing an existing slot, so rehashing isn't necessary.
   311  			for {
   312  				if (compare && oldVal != slotVal) || newVal == slotVal {
   313  					if slotVal == tombstone() {
   314  						return nil
   315  					}
   316  					return (*Value)(slotVal)
   317  				}
   318  				if atomic.CompareAndSwapPointer(&slot.val, slotVal, newVal) {
   319  					if slotVal == tombstone() {
   320  						atomic.AddUintptr(&shard.count, 1)
   321  						return nil
   322  					}
   323  					if newVal == tombstone() {
   324  						atomic.AddUintptr(&shard.count, ^uintptr(0) /* -1 */)
   325  					}
   326  					return (*Value)(slotVal)
   327  				}
   328  				slotVal = atomic.LoadPointer(&slot.val)
   329  				if slotVal == evacuated() {
   330  					goto retry
   331  				}
   332  			}
   333  		}
   334  		// This produces a triangular number sequence of offsets from the
   335  		// initially-probed position.
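        		// For example, starting from slot i0 with inc = 1, 2, 3, ..., the
        		// probe visits i0+1, i0+3, i0+6, i0+10, ... modulo the power-of-two
        		// table size; triangular offsets modulo a power of two form a
        		// permutation, so every slot is eventually visited.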
   336  		i = (i + inc) & mask
   337  		inc++
   338  	}
   339  }
   340  
   341  // rehash allocates the shard's table if oldSlots is nil, or replaces it with a
   342  // table without tombstones that is at least as large. Serialized by rehashMu.
   343  func (shard *apmShard) rehash(oldSlots unsafe.Pointer) {
   344  	shard.rehashMu.Lock()
   345  	defer shard.rehashMu.Unlock()
   346  
   347  	if shard.slots != oldSlots {
   348  		// Raced with another call to rehash().
   349  		return
   350  	}
   351  
   352  	// Determine the size of the new table. Constraints:
   353  	//
   354  	// * The size of the table must be a power of two to ensure that every slot
   355  	// is visitable by every probe sequence under quadratic probing with
   356  	// triangular numbers.
   357  	//
   358  	// * The size of the table cannot decrease because even if shard.count is
   359  	// currently smaller than shard.dirty, concurrent stores that reuse
   360  	// existing slots can drive shard.count back up to a maximum of
   361  	// shard.dirty.
   362  	newSize := uintptr(8) // arbitrary initial size
   363  	if oldSlots != nil {
   364  		oldSize := shard.mask + 1
   365  		newSize = oldSize
   366  		if count := atomic.LoadUintptr(&shard.count) + 1; count*apmExpansionThresholdDen > oldSize*apmExpansionThresholdNum {
   367  			newSize *= 2
   368  		}
   369  	}
   370  
   371  	// Allocate the new table.
   372  	newSlotsSlice := make([]apmSlot, newSize)
   373  	newSlotsHeader := (*gohacks.SliceHeader)(unsafe.Pointer(&newSlotsSlice))
   374  	newSlots := newSlotsHeader.Data
   375  	newMask := newSize - 1
   376  
   377  	// Start a writer critical section now so that racing users of the old
   378  	// table that observe evacuated() wait for the new table. (But lock dirtyMu
   379  	// first since doing so may block, which we don't want to do during the
   380  	// writer critical section.)
   381  	shard.dirtyMu.Lock()
   382  	shard.seq.BeginWrite()
   383  
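        	// Note: entries are copied into the new table below with plain
        	// (non-atomic) writes. This is safe because the new table is not
        	// published until the StorePointer/StoreUintptr at the end of this
        	// function, and readers racing with rehashing either spin in
        	// seq.BeginRead()/ReadOk() or observe evacuated() and retry.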
   384  	if oldSlots != nil {
   385  		realCount := uintptr(0)
   386  		// Copy old entries to the new table.
   387  		oldMask := shard.mask
   388  		for i := uintptr(0); i <= oldMask; i++ {
   389  			oldSlot := apmSlotAt(oldSlots, i)
   390  			val := atomic.SwapPointer(&oldSlot.val, evacuated())
   391  			if val == nil || val == tombstone() {
   392  				continue
   393  			}
   394  			hash := hasher.Hash(oldSlot.key)
   395  			j := hash & newMask
   396  			inc := uintptr(1)
   397  			for {
   398  				newSlot := apmSlotAt(newSlots, j)
   399  				if newSlot.val == nil {
   400  					newSlot.val = val
   401  					newSlot.key = oldSlot.key
   402  					break
   403  				}
   404  				j = (j + inc) & newMask
   405  				inc++
   406  			}
   407  			realCount++
   408  		}
   409  		// Update dirty to reflect that tombstones were not copied to the new
   410  		// table. Use realCount since a concurrent mutator may not have updated
   411  		// shard.count yet.
   412  		shard.dirty = realCount
   413  	}
   414  
   415  	// Switch to the new table.
   416  	atomic.StorePointer(&shard.slots, newSlots)
   417  	atomic.StoreUintptr(&shard.mask, newMask)
   418  
   419  	shard.seq.EndWrite()
   420  	shard.dirtyMu.Unlock()
   421  }
   422  
   423  // Range invokes f on each Key-Value pair stored in m. If any call to f returns
   424  // false, Range stops iteration and returns.
   425  //
   426  // Range does not necessarily correspond to any consistent snapshot of the
   427  // Map's contents: no Key will be visited more than once, but if the Value for
   428  // any Key is stored or deleted concurrently, Range may reflect any mapping for
   429  // that Key from any point during the Range call.
   430  //
   431  // f must not call other methods on m.
   432  func (m *AtomicPtrMap) Range(f func(key Key, val *Value) bool) {
   433  	for si := 0; si < len(m.shards); si++ {
   434  		shard := &m.shards[si]
   435  		if !shard.doRange(f) {
   436  			return
   437  		}
   438  	}
   439  }
   440  
   441  func (shard *apmShard) doRange(f func(key Key, val *Value) bool) bool {
   442  	// We have to lock rehashMu because if we handled races with rehashing by
   443  	// retrying, f could see the same key twice.
   444  	shard.rehashMu.Lock()
   445  	defer shard.rehashMu.Unlock()
   446  	slots := shard.slots
   447  	if slots == nil {
   448  		return true
   449  	}
   450  	mask := shard.mask
   451  	for i := uintptr(0); i <= mask; i++ {
   452  		slot := apmSlotAt(slots, i)
   453  		slotVal := atomic.LoadPointer(&slot.val)
   454  		if slotVal == nil || slotVal == tombstone() {
   455  			continue
   456  		}
   457  		if !f(slot.key, (*Value)(slotVal)) {
   458  			return false
   459  		}
   460  	}
   461  	return true
   462  }
   463  
   464  // RangeRepeatable is like Range, but:
   465  //
   466  // * RangeRepeatable may visit the same Key multiple times in the presence of
   467  // concurrent mutators, possibly passing different Values to f in different
   468  // calls.
   469  //
   470  // * It is safe for f to call other methods on m.
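        //
        // For example, a caller that wants to delete entries while iterating (which
        // Range forbids, since its f must not call other methods on m) might write,
        // illustratively (shouldDrop is a hypothetical predicate):
        //
        //	m.RangeRepeatable(func(k Key, v *Value) bool {
        //		if shouldDrop(v) {
        //			m.CompareAndSwap(k, v, nil) // delete only if still mapped to v
        //		}
        //		return true
        //	})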
   471  func (m *AtomicPtrMap) RangeRepeatable(f func(key Key, val *Value) bool) {
   472  	for si := 0; si < len(m.shards); si++ {
   473  		shard := &m.shards[si]
   474  
   475  	retry:
   476  		epoch := shard.seq.BeginRead()
   477  		slots := atomic.LoadPointer(&shard.slots)
   478  		mask := atomic.LoadUintptr(&shard.mask)
   479  		if !shard.seq.ReadOk(epoch) {
   480  			goto retry
   481  		}
   482  		if slots == nil {
   483  			continue
   484  		}
   485  
   486  		for i := uintptr(0); i <= mask; i++ {
   487  			slot := apmSlotAt(slots, i)
   488  			slotVal := atomic.LoadPointer(&slot.val)
   489  			if slotVal == evacuated() {
   490  				goto retry
   491  			}
   492  			if slotVal == nil || slotVal == tombstone() {
   493  				continue
   494  			}
   495  			if !f(slot.key, (*Value)(slotVal)) {
   496  				return
   497  			}
   498  		}
   499  	}
   500  }