github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/safecopy/safecopy_unsafe.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safecopy
    16  
    17  import (
    18  	"fmt"
    19  	"runtime"
    20  	"unsafe"
    21  
    22  	"golang.org/x/sys/unix"
    23  )
    24  
    25  // maxRegisterSize is the maximum register size used in memcpy and memclr. It
    26  // is used to decide by how much to rewind the copy (for memcpy) or zeroing
    27  // (for memclr) before proceeding.
    28  const maxRegisterSize = 16
    29  
// memcpy copies n bytes from src to dst. If a SIGSEGV or SIGBUS signal is
// received during the copy, it returns the address that caused the fault and
// the number of the signal that was received. Otherwise, it returns an
// unspecified address and a signal number of 0.
//
// The fault address may lie in either the source or the destination range;
// callers are expected to check which one.
//
// Data is copied in order, such that if a fault happens at address p, it is
// safe to assume that all data before p-maxRegisterSize has already been
// successfully copied. Callers rely on this to rewind by at most
// maxRegisterSize bytes and retry the copy up to p.
//
// memcpy is implemented in assembly; this is only its Go declaration.
//
//go:noescape
func memcpy(dst, src uintptr, n uintptr) (fault uintptr, sig int32)
    41  
// memclr sets the n bytes following ptr to zeroes. If a SIGSEGV or SIGBUS
// signal is received during the write, it returns the address that caused the
// fault and the number of the signal that was received. Otherwise, it returns
// an unspecified address and a signal number of 0.
//
// Data is written in order, such that if a fault happens at address p, it is
// safe to assume that all data before p-maxRegisterSize has already been
// successfully written. Callers rely on this to rewind by at most
// maxRegisterSize bytes and retry the zeroing up to p.
//
// memclr is implemented in assembly; this is only its Go declaration.
//
//go:noescape
func memclr(ptr uintptr, n uintptr) (fault uintptr, sig int32)
    53  
// swapUint32 atomically stores new into *ptr and returns (the previous *ptr
// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
// value of old is unspecified, and sig is the number of the signal that was
// received.
//
// Preconditions: ptr must be aligned to a 4-byte boundary. The exported
// wrapper checks alignment before calling this assembly routine.
//
//go:noescape
func swapUint32(ptr unsafe.Pointer, new uint32) (old uint32, sig int32)
    63  
// swapUint64 atomically stores new into *ptr and returns (the previous *ptr
// value, 0). If a SIGSEGV or SIGBUS signal is received during the swap, the
// value of old is unspecified, and sig is the number of the signal that was
// received.
//
// Preconditions: ptr must be aligned to an 8-byte boundary. The exported
// wrapper checks alignment before calling this assembly routine.
//
//go:noescape
func swapUint64(ptr unsafe.Pointer, new uint64) (old uint64, sig int32)
    73  
// compareAndSwapUint32 is like sync/atomic.CompareAndSwapUint32, but returns
// (the value previously stored at ptr, 0) instead of a success flag; callers
// compare prev against old to learn whether the swap happened. If a SIGSEGV
// or SIGBUS signal is received during the operation, the value of prev is
// unspecified, and sig is the number of the signal that was received.
//
// Preconditions: ptr must be aligned to a 4-byte boundary. The exported
// wrapper checks alignment before calling this assembly routine.
//
//go:noescape
func compareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (prev uint32, sig int32)
    83  
// loadUint32 is like sync/atomic.LoadUint32, but operates with user memory.
// It returns (the value at ptr, 0) on success. If a SIGSEGV or SIGBUS signal
// is received while reading from ptr, sig is the number of that signal.
// NOTE(review): val is presumably unspecified in that case, as for the other
// helpers above; confirm against the assembly.
//
// Preconditions: ptr must be aligned to a 4-byte boundary. The exported
// wrapper LoadUint32 checks alignment before calling this assembly routine.
//
//go:noescape
func loadUint32(ptr unsafe.Pointer) (val uint32, sig int32)
    91  
// The addrOf* functions return the start (entry) address of the corresponding
// assembly functions declared above.
//
// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal
// wrapper function rather than the function itself. These addresses must
// therefore be taken from assembly, which yields the ABI0 (i.e., primary)
// address of each function.
func addrOfMemcpy() uintptr
func addrOfMemclr() uintptr
func addrOfSwapUint32() uintptr
func addrOfSwapUint64() uintptr
func addrOfCompareAndSwapUint32() uintptr
func addrOfLoadUint32() uintptr
   103  
   104  // CopyIn copies len(dst) bytes from src to dst. It returns the number of bytes
   105  // copied and an error if SIGSEGV or SIGBUS is received while reading from src.
   106  func CopyIn(dst []byte, src unsafe.Pointer) (int, error) {
   107  	n, err := copyIn(dst, uintptr(src))
   108  	runtime.KeepAlive(src)
   109  	return n, err
   110  }
   111  
   112  // copyIn is the underlying definition for CopyIn.
   113  func copyIn(dst []byte, src uintptr) (int, error) {
   114  	toCopy := uintptr(len(dst))
   115  	if len(dst) == 0 {
   116  		return 0, nil
   117  	}
   118  
   119  	fault, sig := memcpy(uintptr(unsafe.Pointer(&dst[0])), src, toCopy)
   120  	if sig == 0 {
   121  		return len(dst), nil
   122  	}
   123  
   124  	if fault < src || fault >= src+toCopy {
   125  		panic(fmt.Sprintf("CopyIn raised signal %d at %#x, which is outside source [%#x, %#x)", sig, fault, src, src+toCopy))
   126  	}
   127  
   128  	// memcpy might have ended the copy up to maxRegisterSize bytes before
   129  	// fault, if an instruction caused a memory access that straddled two
   130  	// pages, and the second one faulted. Try to copy up to the fault.
   131  	var done int
   132  	if fault-src > maxRegisterSize {
   133  		done = int(fault - src - maxRegisterSize)
   134  	}
   135  	n, err := copyIn(dst[done:int(fault-src)], src+uintptr(done))
   136  	done += n
   137  	if err != nil {
   138  		return done, err
   139  	}
   140  	return done, errorFromFaultSignal(fault, sig)
   141  }
   142  
   143  // CopyOut copies len(src) bytes from src to dst. If returns the number of
   144  // bytes done and an error if SIGSEGV or SIGBUS is received while writing to
   145  // dst.
   146  func CopyOut(dst unsafe.Pointer, src []byte) (int, error) {
   147  	n, err := copyOut(uintptr(dst), src)
   148  	runtime.KeepAlive(dst)
   149  	return n, err
   150  }
   151  
   152  // copyOut is the underlying definition for CopyOut.
   153  func copyOut(dst uintptr, src []byte) (int, error) {
   154  	toCopy := uintptr(len(src))
   155  	if toCopy == 0 {
   156  		return 0, nil
   157  	}
   158  
   159  	fault, sig := memcpy(dst, uintptr(unsafe.Pointer(&src[0])), toCopy)
   160  	if sig == 0 {
   161  		return len(src), nil
   162  	}
   163  
   164  	if fault < dst || fault >= dst+toCopy {
   165  		panic(fmt.Sprintf("CopyOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toCopy))
   166  	}
   167  
   168  	// memcpy might have ended the copy up to maxRegisterSize bytes before
   169  	// fault, if an instruction caused a memory access that straddled two
   170  	// pages, and the second one faulted. Try to copy up to the fault.
   171  	var done int
   172  	if fault-dst > maxRegisterSize {
   173  		done = int(fault - dst - maxRegisterSize)
   174  	}
   175  	n, err := copyOut(dst+uintptr(done), src[done:int(fault-dst)])
   176  	done += n
   177  	if err != nil {
   178  		return done, err
   179  	}
   180  	return done, errorFromFaultSignal(fault, sig)
   181  }
   182  
   183  // Copy copies toCopy bytes from src to dst. It returns the number of bytes
   184  // copied and an error if SIGSEGV or SIGBUS is received while reading from src
   185  // or writing to dst.
   186  //
   187  // Data is copied in order; if [src, src+toCopy) and [dst, dst+toCopy) overlap,
   188  // the resulting contents of dst are unspecified.
   189  func Copy(dst, src unsafe.Pointer, toCopy uintptr) (uintptr, error) {
   190  	n, err := copyN(uintptr(dst), uintptr(src), toCopy)
   191  	runtime.KeepAlive(dst)
   192  	runtime.KeepAlive(src)
   193  	return n, err
   194  }
   195  
   196  // copyN is the underlying definition for Copy.
   197  func copyN(dst, src uintptr, toCopy uintptr) (uintptr, error) {
   198  	if toCopy == 0 {
   199  		return 0, nil
   200  	}
   201  
   202  	fault, sig := memcpy(dst, src, toCopy)
   203  	if sig == 0 {
   204  		return toCopy, nil
   205  	}
   206  
   207  	// Did the fault occur while reading from src or writing to dst?
   208  	faultAfterSrc := ^uintptr(0)
   209  	if fault >= src {
   210  		faultAfterSrc = fault - src
   211  	}
   212  	faultAfterDst := ^uintptr(0)
   213  	if fault >= dst {
   214  		faultAfterDst = fault - dst
   215  	}
   216  	if faultAfterSrc >= toCopy && faultAfterDst >= toCopy {
   217  		panic(fmt.Sprintf("Copy raised signal %d at %#x, which is outside source [%#x, %#x) and destination [%#x, %#x)", sig, fault, src, src+toCopy, dst, dst+toCopy))
   218  	}
   219  	faultedAfter := faultAfterSrc
   220  	if faultedAfter > faultAfterDst {
   221  		faultedAfter = faultAfterDst
   222  	}
   223  
   224  	// memcpy might have ended the copy up to maxRegisterSize bytes before
   225  	// fault, if an instruction caused a memory access that straddled two
   226  	// pages, and the second one faulted. Try to copy up to the fault.
   227  	var done uintptr
   228  	if faultedAfter > maxRegisterSize {
   229  		done = faultedAfter - maxRegisterSize
   230  	}
   231  	n, err := copyN(dst+done, src+done, faultedAfter-done)
   232  	done += n
   233  	if err != nil {
   234  		return done, err
   235  	}
   236  	return done, errorFromFaultSignal(fault, sig)
   237  }
   238  
   239  // ZeroOut writes toZero zero bytes to dst. It returns the number of bytes
   240  // written and an error if SIGSEGV or SIGBUS is received while writing to dst.
   241  func ZeroOut(dst unsafe.Pointer, toZero uintptr) (uintptr, error) {
   242  	n, err := zeroOut(uintptr(dst), toZero)
   243  	runtime.KeepAlive(dst)
   244  	return n, err
   245  }
   246  
   247  // zeroOut is the underlying definition for ZeroOut.
   248  func zeroOut(dst uintptr, toZero uintptr) (uintptr, error) {
   249  	if toZero == 0 {
   250  		return 0, nil
   251  	}
   252  
   253  	fault, sig := memclr(dst, toZero)
   254  	if sig == 0 {
   255  		return toZero, nil
   256  	}
   257  
   258  	if fault < dst || fault >= dst+toZero {
   259  		panic(fmt.Sprintf("ZeroOut raised signal %d at %#x, which is outside destination [%#x, %#x)", sig, fault, dst, dst+toZero))
   260  	}
   261  
   262  	// memclr might have ended the write up to maxRegisterSize bytes before
   263  	// fault, if an instruction caused a memory access that straddled two
   264  	// pages, and the second one faulted. Try to write up to the fault.
   265  	var done uintptr
   266  	if fault-dst > maxRegisterSize {
   267  		done = fault - dst - maxRegisterSize
   268  	}
   269  	n, err := zeroOut(dst+done, fault-dst-done)
   270  	done += n
   271  	if err != nil {
   272  		return done, err
   273  	}
   274  	return done, errorFromFaultSignal(fault, sig)
   275  }
   276  
   277  // SwapUint32 is equivalent to sync/atomic.SwapUint32, except that it returns
   278  // an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is
   279  // not aligned to a 4-byte boundary.
   280  func SwapUint32(ptr unsafe.Pointer, new uint32) (uint32, error) {
   281  	if addr := uintptr(ptr); addr&3 != 0 {
   282  		return 0, AlignmentError{addr, 4}
   283  	}
   284  	old, sig := swapUint32(ptr, new)
   285  	return old, errorFromFaultSignal(uintptr(ptr), sig)
   286  }
   287  
   288  // SwapUint64 is equivalent to sync/atomic.SwapUint64, except that it returns
   289  // an error if SIGSEGV or SIGBUS is received while accessing ptr, or if ptr is
   290  // not aligned to an 8-byte boundary.
   291  func SwapUint64(ptr unsafe.Pointer, new uint64) (uint64, error) {
   292  	if addr := uintptr(ptr); addr&7 != 0 {
   293  		return 0, AlignmentError{addr, 8}
   294  	}
   295  	old, sig := swapUint64(ptr, new)
   296  	return old, errorFromFaultSignal(uintptr(ptr), sig)
   297  }
   298  
   299  // CompareAndSwapUint32 is equivalent to atomicbitops.CompareAndSwapUint32,
   300  // except that it returns an error if SIGSEGV or SIGBUS is received while
   301  // accessing ptr, or if ptr is not aligned to a 4-byte boundary.
   302  func CompareAndSwapUint32(ptr unsafe.Pointer, old, new uint32) (uint32, error) {
   303  	if addr := uintptr(ptr); addr&3 != 0 {
   304  		return 0, AlignmentError{addr, 4}
   305  	}
   306  	prev, sig := compareAndSwapUint32(ptr, old, new)
   307  	return prev, errorFromFaultSignal(uintptr(ptr), sig)
   308  }
   309  
   310  // LoadUint32 is like sync/atomic.LoadUint32, but operates with user memory. It
   311  // may fail with SIGSEGV or SIGBUS if it is received while reading from ptr.
   312  //
   313  // Preconditions: ptr must be aligned to a 4-byte boundary.
   314  func LoadUint32(ptr unsafe.Pointer) (uint32, error) {
   315  	if addr := uintptr(ptr); addr&3 != 0 {
   316  		return 0, AlignmentError{addr, 4}
   317  	}
   318  	val, sig := loadUint32(ptr)
   319  	return val, errorFromFaultSignal(uintptr(ptr), sig)
   320  }
   321  
   322  func errorFromFaultSignal(addr uintptr, sig int32) error {
   323  	switch sig {
   324  	case 0:
   325  		return nil
   326  	case int32(unix.SIGSEGV):
   327  		return SegvError{addr}
   328  	case int32(unix.SIGBUS):
   329  		return BusError{addr}
   330  	default:
   331  		panic(fmt.Sprintf("safecopy got unexpected signal %d at address %#x", sig, addr))
   332  	}
   333  }