github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/vfs/mount_unsafe.go

// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package vfs

import (
	"fmt"
	"math/bits"
	"sync/atomic"
	"unsafe"

	"github.com/SagerNet/gvisor/pkg/gohacks"
	"github.com/SagerNet/gvisor/pkg/sync"
)

// mountKey represents the location at which a Mount is mounted. It is
// structurally identical to VirtualDentry, but stores its fields as
// unsafe.Pointer since mutators synchronize with VFS path traversal using
// seqcounts.
//
// This is explicitly not savable.
type mountKey struct {
	parent unsafe.Pointer // *Mount
	point  unsafe.Pointer // *Dentry
}

var (
	mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil))
	mountKeySeed   = sync.RandUintptr()
)

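// hash uses the hash function that the Go runtime itself uses for
// map[mountKey]struct{} keys (obtained via sync.MapKeyHasher), seeded once
// per process by mountKeySeed.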
func (k *mountKey) hash() uintptr {
	return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed)
}

func (mnt *Mount) parent() *Mount {
	return (*Mount)(atomic.LoadPointer(&mnt.key.parent))
}

func (mnt *Mount) point() *Dentry {
	return (*Dentry)(atomic.LoadPointer(&mnt.key.point))
}

func (mnt *Mount) getKey() VirtualDentry {
	return VirtualDentry{
		mount:  mnt.parent(),
		dentry: mnt.point(),
	}
}

// Invariant: mnt.key.parent == nil. vd.Ok().
func (mnt *Mount) setKey(vd VirtualDentry) {
	atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
	atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
}

// mountTable maps (mount parent, mount point) pairs to mounts. It supports
// efficient concurrent lookup, even in the presence of concurrent mutators
// (provided mutation is sufficiently uncommon).
//
// mountTable.Init() must be called on new mountTables before use.
type mountTable struct {
	// mountTable is implemented as a seqcount-protected hash table that
	// resolves collisions with linear probing, featuring Robin Hood insertion
	// and backward shift deletion. These minimize probe length variance,
	// significantly improving the performance of linear probing at high load
	// factors. (mountTable doesn't use bucketing, which is the other major
	// technique commonly used in high-performance hash tables; the efficiency
	// of bucketing is largely due to SIMD lookup, and Go lacks both SIMD
	// intrinsics and inline assembly, limiting the performance of this
	// approach.)

	seq sync.SeqCount `state:"nosave"`

	// size holds both length (number of elements) and capacity (number of
	// slots): capacity is stored as its base-2 log (referred to as order) in
	// the least significant bits of size, and length is stored in the
	// remaining bits. Go defines bit shifts >= width of shifted unsigned
	// operand as shifting to 0, which differs from x86's SHL, so the Go
	// compiler inserts a bounds check for each bit shift unless we mask order
	// anyway (cf. runtime.bucketShift()), and length isn't used by lookup;
	// thus this bit packing gets us more bits for the length (vs. storing
	// length and cap in separate uint32s) for ~free.
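	//
	// For example (assuming a 64-bit platform, so mtSizeOrderBits = 6): a
	// table with 32 slots (order 5) holding 20 elements has
	//
	//	size = 20<<mtSizeLenLSB | 5 = 20<<6 | 5 = 1285
	//
	// so order = size & mtSizeOrderMask = 5 and length = size >> mtSizeLenLSB
	// = 20. A plain insertion adds mtSizeLenOne, a removal adds
	// mtSizeLenNegOne, and a table expansion adds mtSizeLenOne|mtSizeOrderOne
	// to bump length and order together.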
	size uint64

	slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init
}

type mountSlot struct {
	// We don't store keys in slots; instead, we just check Mount.parent and
	// Mount.point directly. Any practical use of lookup will need to touch
	// Mounts anyway, and comparing hashes means that false positives are
	// extremely rare, so this isn't an extra cache line touch overall.
	value unsafe.Pointer // *Mount
	hash  uintptr
}

const (
	mtSizeOrderBits = 6 // log2 of pointer size in bits
	mtSizeOrderMask = (1 << mtSizeOrderBits) - 1
	mtSizeOrderOne  = 1
	mtSizeLenLSB    = mtSizeOrderBits
	mtSizeLenOne    = 1 << mtSizeLenLSB
	mtSizeLenNegOne = ^uint64(mtSizeOrderMask) // uint64(-1) << mtSizeLenLSB

	mountSlotBytes = unsafe.Sizeof(mountSlot{})
	mountKeyBytes  = unsafe.Sizeof(mountKey{})

	// Tuning parameters.
	//
	// Essentially every mountTable will contain at least /proc, /sys, and
	// /dev/shm, so there is ~no reason for mtInitCap to be < 4.
	mtInitOrder  = 2
	mtInitCap    = 1 << mtInitOrder
	mtMaxLoadNum = 13
	mtMaxLoadDen = 16
)
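
// For illustration of the tuning above: mtMaxLoadNum/mtMaxLoadDen = 13/16 =
// 0.8125, so a table with 16 slots accepts insertions while
// (len+1)*16 <= 13*16, i.e. up to 13 elements, and doubles to 32 slots when
// the 14th is inserted. A fresh table (mtInitCap = 4) therefore holds 3
// mounts before its first resize.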

func init() {
	// We can't just compute mtSizeOrderBits from the pointer size at compile
	// time because Go doesn't have constexpr, so verify the relationship at
	// startup instead.
	if ptrBits := uint(unsafe.Sizeof(uintptr(0)) * 8); mtSizeOrderBits != bits.TrailingZeros(ptrBits) {
		panic(fmt.Sprintf("mtSizeOrderBits (%d) must be %d = log2 of pointer size in bits (%d)", mtSizeOrderBits, bits.TrailingZeros(ptrBits), ptrBits))
	}
	if bits.OnesCount(uint(mountSlotBytes)) != 1 {
		panic(fmt.Sprintf("mountSlotBytes (%d) must be a power of 2 to use bit masking for wraparound", mountSlotBytes))
	}
	if mtInitCap <= 1 {
		panic(fmt.Sprintf("mtInitCap (%d) must be at least 2 since mountTable methods assume that there will always be at least one empty slot", mtInitCap))
	}
	if mtMaxLoadNum >= mtMaxLoadDen {
		panic(fmt.Sprintf("invalid mountTable maximum load factor (%d/%d)", mtMaxLoadNum, mtMaxLoadDen))
	}
}

// Init must be called exactly once on each mountTable before use.
func (mt *mountTable) Init() {
	mt.size = mtInitOrder
	mt.slots = newMountTableSlots(mtInitCap)
}

func newMountTableSlots(cap uintptr) unsafe.Pointer {
	slice := make([]mountSlot, cap, cap)
	hdr := (*gohacks.SliceHeader)(unsafe.Pointer(&slice))
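	// Returning hdr.Data (a pointer into the slice's backing array) is what
	// keeps that array reachable by the GC once the slice header itself goes
	// out of scope; callers store it in mountTable.slots.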
	return hdr.Data
}

// Lookup returns the Mount with the given parent, mounted at the given point.
// If no such Mount exists, Lookup returns nil.
//
// Lookup may be called even if there are concurrent mutators of mt.
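//
// As a usage sketch (hypothetical caller and variable names, for illustration
// only):
//
//	// Is (parentMount, d) a mount point? If so, descend into the child mount.
//	if child := mt.Lookup(parentMount, d); child != nil {
//		// ... use child ...
//	}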
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
	key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
	hash := key.hash()

loop:
	for {
		epoch := mt.seq.BeginRead()
		size := atomic.LoadUint64(&mt.size)
		slots := atomic.LoadPointer(&mt.slots)
		if !mt.seq.ReadOk(epoch) {
			continue
		}
		tcap := uintptr(1) << (size & mtSizeOrderMask)
		mask := tcap - 1
		off := (hash & mask) * mountSlotBytes
		offmask := mask * mountSlotBytes
		for {
			// This avoids bounds checking.
			slot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + off))
			slotValue := atomic.LoadPointer(&slot.value)
			slotHash := atomic.LoadUintptr(&slot.hash)
			if !mt.seq.ReadOk(epoch) {
				// The element we're looking for might have been moved into a
				// slot we've previously checked, so restart entirely.
				continue loop
			}
			if slotValue == nil {
				return nil
			}
			if slotHash == hash {
				mount := (*Mount)(slotValue)
				var mountKey mountKey
				mountKey.parent = atomic.LoadPointer(&mount.key.parent)
				mountKey.point = atomic.LoadPointer(&mount.key.point)
				if !mt.seq.ReadOk(epoch) {
					continue loop
				}
				if key == mountKey {
					return mount
				}
			}
			off = (off + mountSlotBytes) & offmask
		}
	}
}

// Range calls f on each Mount in mt. If f returns false, Range stops iteration
// and returns immediately.
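//
// Unlike Lookup, Range reads mt.size and mt.slots without seqcount
// protection, so it is presumably intended for callers that already exclude
// concurrent mutation. As a usage sketch (hypothetical caller, for
// illustration only), counting the mounts in mt:
//
//	n := 0
//	mt.Range(func(*Mount) bool {
//		n++
//		return true
//	})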
func (mt *mountTable) Range(f func(*Mount) bool) {
	tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
	slotPtr := mt.slots
	last := unsafe.Pointer(uintptr(mt.slots) + ((tcap - 1) * mountSlotBytes))
	for {
		slot := (*mountSlot)(slotPtr)
		if slot.value != nil {
			if !f((*Mount)(slot.value)) {
				return
			}
		}
		if slotPtr == last {
			return
		}
		slotPtr = unsafe.Pointer(uintptr(slotPtr) + mountSlotBytes)
	}
}

// Insert inserts the given mount into mt.
//
// Preconditions: mt must not already contain a Mount with the same mount point
// and parent.
func (mt *mountTable) Insert(mount *Mount) {
	mt.seq.BeginWrite()
	mt.insertSeqed(mount)
	mt.seq.EndWrite()
}

// insertSeqed inserts the given mount into mt.
//
// Preconditions:
// * mt.seq must be in a writer critical section.
// * mt must not already contain a Mount with the same mount point and parent.
func (mt *mountTable) insertSeqed(mount *Mount) {
	hash := mount.key.hash()

	// We're under the maximum load factor if:
	//
	//          (len+1) / cap <= mtMaxLoadNum / mtMaxLoadDen
	// (len+1) * mtMaxLoadDen <= mtMaxLoadNum * cap
	tlen := mt.size >> mtSizeLenLSB
	order := mt.size & mtSizeOrderMask
	tcap := uintptr(1) << order
	if ((tlen + 1) * mtMaxLoadDen) <= (uint64(mtMaxLoadNum) << order) {
		// Atomically insert the new element into the table.
		atomic.AddUint64(&mt.size, mtSizeLenOne)
		mtInsertLocked(mt.slots, tcap, unsafe.Pointer(mount), hash)
		return
	}

	// Otherwise, we have to expand. Double the number of slots in the new
	// table.
	newOrder := order + 1
	if newOrder > mtSizeOrderMask {
		panic("mount table size overflow")
	}
	newCap := uintptr(1) << newOrder
	newSlots := newMountTableSlots(newCap)
	// Copy existing elements to the new table.
	oldCur := mt.slots
	// Go does not permit pointers to the end of allocated objects, so we
	// must use a pointer to the last element of the old table. The
	// following expression is equivalent to
	// `slots+(cap-1)*mountSlotBytes` but has a critical path length of 2
	// arithmetic instructions instead of 3.
	oldLast := unsafe.Pointer((uintptr(mt.slots) - mountSlotBytes) + (tcap * mountSlotBytes))
	for {
		oldSlot := (*mountSlot)(oldCur)
		if oldSlot.value != nil {
			mtInsertLocked(newSlots, newCap, oldSlot.value, oldSlot.hash)
		}
		if oldCur == oldLast {
			break
		}
		oldCur = unsafe.Pointer(uintptr(oldCur) + mountSlotBytes)
	}
	// Insert the new element into the new table.
	mtInsertLocked(newSlots, newCap, unsafe.Pointer(mount), hash)
	// Switch to the new table.
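	// A single atomic add updates both packed fields of mt.size:
	// mtSizeLenOne accounts for the newly inserted mount and mtSizeOrderOne
	// bumps the order, doubling the recorded capacity to match newSlots.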
	atomic.AddUint64(&mt.size, mtSizeLenOne|mtSizeOrderOne)
	atomic.StorePointer(&mt.slots, newSlots)
}

// Preconditions:
// * There are no concurrent mutators of the table (slots, cap).
// * If the table is visible to readers, then mt.seq must be in a writer
//   critical section.
// * cap must be a power of 2.
func mtInsertLocked(slots unsafe.Pointer, cap uintptr, value unsafe.Pointer, hash uintptr) {
	mask := cap - 1
	off := (hash & mask) * mountSlotBytes
	offmask := mask * mountSlotBytes
	disp := uintptr(0)
	for {
		slot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + off))
		slotValue := slot.value
		if slotValue == nil {
			atomic.StorePointer(&slot.value, value)
			atomic.StoreUintptr(&slot.hash, hash)
			return
		}
		// If we've been displaced farther from our first-probed slot than the
		// element stored in this one, swap elements and switch to inserting
		// the replaced one. (This is Robin Hood insertion.)
		slotHash := slot.hash
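		// slotDisp is how far the resident element has been displaced from
		// its home slot: since mask is cap-1, (index - slotHash) & mask
		// equals (index - (slotHash & mask)) mod cap. For example (made-up
		// numbers), with cap = 8 an element whose home slot is 6 but which
		// sits in slot 1 has displacement (1 - 6) & 7 = 3.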
		slotDisp := ((off / mountSlotBytes) - slotHash) & mask
		if disp > slotDisp {
			atomic.StorePointer(&slot.value, value)
			atomic.StoreUintptr(&slot.hash, hash)
			value = slotValue
			hash = slotHash
			disp = slotDisp
		}
		off = (off + mountSlotBytes) & offmask
		disp++
	}
}

// Remove removes the given mount from mt.
//
// Preconditions: mt must contain mount.
func (mt *mountTable) Remove(mount *Mount) {
	mt.seq.BeginWrite()
	mt.removeSeqed(mount)
	mt.seq.EndWrite()
}

// removeSeqed removes the given mount from mt.
//
// Preconditions:
// * mt.seq must be in a writer critical section.
// * mt must contain mount.
func (mt *mountTable) removeSeqed(mount *Mount) {
	hash := mount.key.hash()
	tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
	mask := tcap - 1
	slots := mt.slots
	off := (hash & mask) * mountSlotBytes
	offmask := mask * mountSlotBytes
	for {
		slot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + off))
		slotValue := slot.value
		if slotValue == unsafe.Pointer(mount) {
			// Found the element to remove. Move all subsequent elements
			// backward until we either find an empty slot, or an element that
			// is already in its first-probed slot. (This is backward shift
			// deletion.)
			for {
				nextOff := (off + mountSlotBytes) & offmask
				nextSlot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + nextOff))
				nextSlotValue := nextSlot.value
				if nextSlotValue == nil {
					break
				}
				nextSlotHash := nextSlot.hash
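				// If the next element is already in its home slot
				// (displacement 0), shifting it backward would move it to a
				// slot that probing for its hash never visits, so the shift
				// must stop here.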
				if (nextOff / mountSlotBytes) == (nextSlotHash & mask) {
					break
				}
				atomic.StorePointer(&slot.value, nextSlotValue)
				atomic.StoreUintptr(&slot.hash, nextSlotHash)
				off = nextOff
				slot = nextSlot
			}
			atomic.StorePointer(&slot.value, nil)
			atomic.AddUint64(&mt.size, mtSizeLenNegOne)
			return
		}
		if checkInvariants && slotValue == nil {
			panic(fmt.Sprintf("mountTable.Remove() called on missing Mount %v", mount))
		}
		off = (off + mountSlotBytes) & offmask
	}
}