github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/sentry/vfs/mount_unsafe.go

// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package vfs

import (
	"fmt"
	"math/bits"
	"sync/atomic"
	"unsafe"

	"github.com/nicocha30/gvisor-ligolo/pkg/atomicbitops"
	"github.com/nicocha30/gvisor-ligolo/pkg/gohacks"
	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
)

// mountKey represents the location at which a Mount is mounted. It is
// structurally identical to VirtualDentry, but stores its fields as
// unsafe.Pointer since mutators synchronize with VFS path traversal using
// seqcounts.
//
// This is explicitly not savable.
type mountKey struct {
	parent unsafe.Pointer // *Mount
	point  unsafe.Pointer // *Dentry
}

var (
	mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil))
	mountKeySeed   = sync.RandUintptr()
)

func (k *mountKey) hash() uintptr {
	return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed)
}

func (mnt *Mount) parent() *Mount {
	return (*Mount)(atomic.LoadPointer(&mnt.key.parent))
}

func (mnt *Mount) point() *Dentry {
	return (*Dentry)(atomic.LoadPointer(&mnt.key.point))
}

func (mnt *Mount) getKey() VirtualDentry {
	return VirtualDentry{
		mount:  mnt.parent(),
		dentry: mnt.point(),
	}
}

// Preconditions: mnt.key.parent == nil. vd.Ok().
func (mnt *Mount) setKey(vd VirtualDentry) {
	atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
	atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
}
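
// The atomic stores in setKey pair with seqcount-protected readers: a reader
// takes a read epoch, performs the atomic loads in parent()/point(), and
// retries if a concurrent writer invalidated the epoch. A minimal sketch of
// that read pattern, with a hypothetical helper name (readMountKey is
// illustrative and not part of this package):
//
//	func readMountKey(seq *sync.SeqCount, mnt *Mount) VirtualDentry {
//		for {
//			epoch := seq.BeginRead()
//			vd := mnt.getKey() // atomic loads of key.parent and key.point
//			if seq.ReadOk(epoch) {
//				return vd
//			}
//			// A writer ran setKey concurrently; retry.
//		}
//	}
//
// Lookup below applies the same epoch/retry discipline against mt.seq.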

// mountTable maps (mount parent, mount point) pairs to mounts. It supports
// efficient concurrent lookup, even in the presence of concurrent mutators
// (provided mutation is sufficiently uncommon).
//
// mountTable.Init() must be called on new mountTables before use.
type mountTable struct {
	// mountTable is implemented as a seqcount-protected hash table that
	// resolves collisions with linear probing, featuring Robin Hood insertion
	// and backward shift deletion. These minimize probe length variance,
	// significantly improving the performance of linear probing at high load
	// factors. (mountTable doesn't use bucketing, which is the other major
	// technique commonly used in high-performance hash tables; the efficiency
	// of bucketing is largely due to SIMD lookup, and Go lacks both SIMD
	// intrinsics and inline assembly, limiting the performance of this
	// approach.)

	seq sync.SeqCount `state:"nosave"`

	// size holds both length (number of elements) and capacity (number of
	// slots): capacity is stored as its base-2 log (referred to as order) in
	// the least significant bits of size, and length is stored in the
	// remaining bits. Go defines bit shifts >= width of shifted unsigned
	// operand as shifting to 0, which differs from x86's SHL, so the Go
	// compiler inserts a bounds check for each bit shift unless we mask order
	// anyway (cf. runtime.bucketShift()), and length isn't used by lookup;
	// thus this bit packing gets us more bits for the length (vs. storing
	// length and cap in separate uint32s) for ~free.
	size atomicbitops.Uint64

	slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init
}
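
// An illustrative, non-normative sketch of how callers in this package drive
// the table: Init exactly once, then Insert/Remove (which enter mt.seq writer
// critical sections themselves), with Lookup safe to run concurrently. The
// parent, point, and childMount values below are placeholders, not names from
// this file:
//
//	var mt mountTable
//	mt.Init()             // allocate the initial mtInitCap slots
//	mt.Insert(childMount) // childMount.key encodes (parent, point)
//	if m := mt.Lookup(parent, point); m != nil {
//		mt.Remove(m)
//	}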

type mountSlot struct {
	// We don't store keys in slots; instead, we just check Mount.parent and
	// Mount.point directly. Any practical use of lookup will need to touch
	// Mounts anyway, and comparing hashes means that false positives are
	// extremely rare, so this isn't an extra cache line touch overall.
	value unsafe.Pointer // *Mount
	hash  uintptr
}

const (
	mtSizeOrderBits = 6 // log2 of pointer size in bits
	mtSizeOrderMask = (1 << mtSizeOrderBits) - 1
	mtSizeOrderOne  = 1
	mtSizeLenLSB    = mtSizeOrderBits
	mtSizeLenOne    = 1 << mtSizeLenLSB
	mtSizeLenNegOne = ^uint64(mtSizeOrderMask) // uint64(-1) << mtSizeLenLSB

	mountSlotBytes = unsafe.Sizeof(mountSlot{})
	mountKeyBytes  = unsafe.Sizeof(mountKey{})

	// Tuning parameters.
	//
	// Essentially every mountTable will contain at least /proc, /sys, and
	// /dev/shm, so there is ~no reason for mtInitCap to be < 4.
	mtInitOrder  = 2
	mtInitCap    = 1 << mtInitOrder
	mtMaxLoadNum = 13
	mtMaxLoadDen = 16
)
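
// A worked example of the size packing, assuming a 64-bit uintptr (so
// mtSizeOrderBits == 6): a table of order 2 (capacity 1<<2 == 4) holding 3
// mounts stores size == 3<<mtSizeLenLSB | 2 == 194, which unpacks as:
//
//	order := size & mtSizeOrderMask // 194 & 63 == 2
//	tcap := uintptr(1) << order     // 4
//	tlen := size >> mtSizeLenLSB    // 194 >> 6 == 3
//
// Insertion without expansion adds mtSizeLenOne (length+1); expansion adds
// mtSizeLenOne|mtSizeOrderOne (length+1 and order+1, i.e. capacity doubles);
// removal adds mtSizeLenNegOne, the two's-complement encoding of -1 <<
// mtSizeLenLSB, which decrements length while leaving order unchanged.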

func init() {
	// Ideally mtSizeOrderBits would be defined directly from the pointer
	// width, but Go doesn't have constexpr, so verify it here instead.
	if ptrBits := uint(unsafe.Sizeof(uintptr(0)) * 8); mtSizeOrderBits != bits.TrailingZeros(ptrBits) {
		panic(fmt.Sprintf("mtSizeOrderBits (%d) must be %d = log2 of pointer size in bits (%d)", mtSizeOrderBits, bits.TrailingZeros(ptrBits), ptrBits))
	}
	if bits.OnesCount(uint(mountSlotBytes)) != 1 {
		panic(fmt.Sprintf("sizeof(mountSlot) (%d) must be a power of 2 to use bit masking for wraparound", mountSlotBytes))
	}
	if mtInitCap <= 1 {
		panic(fmt.Sprintf("mtInitCap (%d) must be at least 2 since mountTable methods assume that there will always be at least one empty slot", mtInitCap))
	}
	if mtMaxLoadNum >= mtMaxLoadDen {
		panic(fmt.Sprintf("invalid mountTable maximum load factor (%d/%d)", mtMaxLoadNum, mtMaxLoadDen))
	}
}

// Init must be called exactly once on each mountTable before use.
func (mt *mountTable) Init() {
	mt.size = atomicbitops.FromUint64(mtInitOrder)
	mt.slots = newMountTableSlots(mtInitCap)
}

func newMountTableSlots(cap uintptr) unsafe.Pointer {
	slice := make([]mountSlot, cap, cap)
	return unsafe.Pointer(&slice[0])
}

// Lookup returns the Mount with the given parent, mounted at the given point.
// If no such Mount exists, Lookup returns nil.
//
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
	key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
	hash := key.hash()

loop:
	for {
		epoch := mt.seq.BeginRead()
		size := mt.size.Load()
		slots := atomic.LoadPointer(&mt.slots)
		if !mt.seq.ReadOk(epoch) {
			continue
		}
		tcap := uintptr(1) << (size & mtSizeOrderMask)
		mask := tcap - 1
		off := (hash & mask) * mountSlotBytes
		offmask := mask * mountSlotBytes
		for {
			// This avoids bounds checking.
			slot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + off))
			slotValue := atomic.LoadPointer(&slot.value)
			slotHash := atomic.LoadUintptr(&slot.hash)
			if !mt.seq.ReadOk(epoch) {
				// The element we're looking for might have been moved into a
				// slot we've previously checked, so restart entirely.
				continue loop
			}
			if slotValue == nil {
				return nil
			}
			if slotHash == hash {
				mount := (*Mount)(slotValue)
				var mountKey mountKey
				mountKey.parent = atomic.LoadPointer(&mount.key.parent)
				mountKey.point = atomic.LoadPointer(&mount.key.point)
				if !mt.seq.ReadOk(epoch) {
					continue loop
				}
				if key == mountKey {
					return mount
				}
			}
			off = (off + mountSlotBytes) & offmask
		}
	}
}
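
// The pointer arithmetic in Lookup is linear probing over a power-of-two
// []mountSlot, with the byte offset masked for wraparound. A bounds-checked
// sketch of the same probe sequence (probeSketch is illustrative and not part
// of this package; it omits the seqcount retries and the key comparison that
// filters out hash-only false positives):
//
//	func probeSketch(slots []mountSlot, hash uintptr) int {
//		mask := uintptr(len(slots)) - 1 // len(slots) is a power of 2
//		for i := hash & mask; ; i = (i + 1) & mask {
//			if slots[i].value == nil || slots[i].hash == hash {
//				return int(i) // empty slot means a miss; equal hash is a candidate
//			}
//		}
//	}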

// Range calls f on each Mount in mt. If f returns false, Range stops iteration
// and returns immediately.
func (mt *mountTable) Range(f func(*Mount) bool) {
	tcap := uintptr(1) << (mt.size.Load() & mtSizeOrderMask)
	slotPtr := mt.slots
	last := unsafe.Pointer(uintptr(mt.slots) + ((tcap - 1) * mountSlotBytes))
	for {
		slot := (*mountSlot)(slotPtr)
		if slot.value != nil {
			if !f((*Mount)(slot.value)) {
				return
			}
		}
		if slotPtr == last {
			return
		}
		slotPtr = unsafe.Pointer(uintptr(slotPtr) + mountSlotBytes)
	}
}

// Insert inserts the given mount into mt.
//
// Preconditions: mt must not already contain a Mount with the same mount point
// and parent.
func (mt *mountTable) Insert(mount *Mount) {
	mt.seq.BeginWrite()
	mt.insertSeqed(mount)
	mt.seq.EndWrite()
}

// insertSeqed inserts the given mount into mt.
//
// Preconditions:
//   - mt.seq must be in a writer critical section.
//   - mt must not already contain a Mount with the same mount point and parent.
func (mt *mountTable) insertSeqed(mount *Mount) {
	hash := mount.key.hash()

	// We're under the maximum load factor if:
	//
	//          (len+1) / cap <= mtMaxLoadNum / mtMaxLoadDen
	// (len+1) * mtMaxLoadDen <= mtMaxLoadNum * cap
	tlen := mt.size.RacyLoad() >> mtSizeLenLSB
	order := mt.size.RacyLoad() & mtSizeOrderMask
	tcap := uintptr(1) << order
	if ((tlen + 1) * mtMaxLoadDen) <= (uint64(mtMaxLoadNum) << order) {
		// Atomically insert the new element into the table.
		mt.size.Add(mtSizeLenOne)
		mtInsertLocked(mt.slots, tcap, unsafe.Pointer(mount), hash)
		return
	}

	// Otherwise, we have to expand. Double the number of slots in the new
	// table.
	newOrder := order + 1
	if newOrder > mtSizeOrderMask {
		panic("mount table size overflow")
	}
	newCap := uintptr(1) << newOrder
	newSlots := newMountTableSlots(newCap)
	// Copy existing elements to the new table.
	oldCur := mt.slots
	// Go does not permit pointers to the end of allocated objects, so we
	// must use a pointer to the last element of the old table. The
	// following expression is equivalent to
	// `slots+(cap-1)*mountSlotBytes` but has a critical path length of 2
	// arithmetic instructions instead of 3.
	oldLast := unsafe.Pointer((uintptr(mt.slots) - mountSlotBytes) + (tcap * mountSlotBytes))
	for {
		oldSlot := (*mountSlot)(oldCur)
		if oldSlot.value != nil {
			mtInsertLocked(newSlots, newCap, oldSlot.value, oldSlot.hash)
		}
		if oldCur == oldLast {
			break
		}
		oldCur = unsafe.Pointer(uintptr(oldCur) + mountSlotBytes)
	}
	// Insert the new element into the new table.
	mtInsertLocked(newSlots, newCap, unsafe.Pointer(mount), hash)
	// Switch to the new table.
	mt.size.Add(mtSizeLenOne | mtSizeOrderOne)
	atomic.StorePointer(&mt.slots, newSlots)
}
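
// A worked example of the load factor check with the default tuning
// (mtMaxLoadNum/mtMaxLoadDen == 13/16): at order 2 (cap == 4) with len == 2,
// inserting gives (2+1)*16 == 48 <= 13<<2 == 52, so the new mount goes into
// the existing slots. At len == 3, (3+1)*16 == 64 > 52, so the table first
// expands to order 3 (cap == 8) and rehashes the existing mounts.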

// Preconditions:
//   - There are no concurrent mutators of the table (slots, cap).
//   - If the table is visible to readers, then mt.seq must be in a writer
//     critical section.
//   - cap must be a power of 2.
func mtInsertLocked(slots unsafe.Pointer, cap uintptr, value unsafe.Pointer, hash uintptr) {
	mask := cap - 1
	off := (hash & mask) * mountSlotBytes
	offmask := mask * mountSlotBytes
	disp := uintptr(0)
	for {
		slot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + off))
		slotValue := slot.value
		if slotValue == nil {
			atomic.StorePointer(&slot.value, value)
			atomic.StoreUintptr(&slot.hash, hash)
			return
		}
		// If we've been displaced farther from our first-probed slot than the
		// element stored in this one, swap elements and switch to inserting
		// the replaced one. (This is Robin Hood insertion.)
		slotHash := slot.hash
		slotDisp := ((off / mountSlotBytes) - slotHash) & mask
		if disp > slotDisp {
			atomic.StorePointer(&slot.value, value)
			atomic.StoreUintptr(&slot.hash, hash)
			value = slotValue
			hash = slotHash
			disp = slotDisp
		}
		off = (off + mountSlotBytes) & offmask
		disp++
	}
}
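
// A small worked example of the Robin Hood swap, assuming cap == 8 (mask 7)
// and hashes chosen purely for illustration: suppose slot 3 holds an element
// whose hash maps to index 3 (slotDisp == 0), and we are inserting an element
// whose hash maps to index 1 but has already been pushed to index 3
// (disp == 2). Since disp > slotDisp, the resident element is evicted: the
// new element takes slot 3, and insertion continues at slot 4 with the
// evicted element, now at displacement 1. Always evicting the "richer"
// element (the one nearer its first-probed slot) is what keeps probe lengths
// close to uniform.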

// Remove removes the given mount from mt.
//
// Preconditions:
//   - mt must contain mount.
//   - mount.key should be valid.
func (mt *mountTable) Remove(mount *Mount) {
	mt.seq.BeginWrite()
	mt.removeSeqed(mount)
	mt.seq.EndWrite()
}

// removeSeqed removes the given mount from mt.
//
// Preconditions same as Remove() plus:
//   - mt.seq must be in a writer critical section.
func (mt *mountTable) removeSeqed(mount *Mount) {
	hash := mount.key.hash()
	tcap := uintptr(1) << (mt.size.RacyLoad() & mtSizeOrderMask)
	mask := tcap - 1
	slots := mt.slots
	off := (hash & mask) * mountSlotBytes
	offmask := mask * mountSlotBytes
	for {
		slot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + off))
		slotValue := slot.value
		if slotValue == unsafe.Pointer(mount) {
			// Found the element to remove. Move all subsequent elements
			// backward until we either find an empty slot, or an element that
			// is already in its first-probed slot. (This is backward shift
			// deletion.)
			for {
				nextOff := (off + mountSlotBytes) & offmask
				nextSlot := (*mountSlot)(unsafe.Pointer(uintptr(slots) + nextOff))
				nextSlotValue := nextSlot.value
				if nextSlotValue == nil {
					break
				}
				nextSlotHash := nextSlot.hash
				if (nextOff / mountSlotBytes) == (nextSlotHash & mask) {
					break
				}
				atomic.StorePointer(&slot.value, nextSlotValue)
				atomic.StoreUintptr(&slot.hash, nextSlotHash)
				off = nextOff
				slot = nextSlot
			}
			atomic.StorePointer(&slot.value, nil)
			mt.size.Add(mtSizeLenNegOne)
			return
		}
		if checkInvariants && slotValue == nil {
			panic(fmt.Sprintf("mountTable.Remove() called on missing Mount %v", mount))
		}
		off = (off + mountSlotBytes) & offmask
	}
}
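
// A small worked example of backward shift deletion, assuming cap == 8 and
// hashes chosen purely for illustration: suppose slots 2, 3, and 4 hold
// mounts A, B, and C, where A and B both first-probe to index 2 and C
// first-probes to index 4. Removing A shifts B back from slot 3 to slot 2,
// shortening B's probe chain; the shift stops at slot 4 because C already
// sits in its first-probed slot, and slot 3 is cleared. Because no tombstones
// are left behind, Lookup can still stop at the first empty slot.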