github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/tcpip/link/rawfile/rawfile_unsafe.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build linux
    16  // +build linux
    17  
    18  // Package rawfile contains utilities for using the netstack with raw host
    19  // files on Linux hosts.
    20  package rawfile
    21  
    22  import (
    23  	"reflect"
    24  	"unsafe"
    25  
    26  	"golang.org/x/sys/unix"
    27  	"github.com/sagernet/gvisor/pkg/tcpip"
    28  )
    29  
    30  // SizeofIovec is the size of a unix.Iovec in bytes.
    31  const SizeofIovec = unsafe.Sizeof(unix.Iovec{})
    32  
    33  // MaxIovs is UIO_MAXIOV, the maximum number of iovecs that may be passed to a
    34  // host system call in a single array.
    35  const MaxIovs = 1024
    36  
    37  // IovecFromBytes returns a unix.Iovec representing bs.
    38  //
    39  // Preconditions: len(bs) > 0.
    40  func IovecFromBytes(bs []byte) unix.Iovec {
    41  	iov := unix.Iovec{
    42  		Base: &bs[0],
    43  	}
    44  	iov.SetLen(len(bs))
    45  	return iov
    46  }
    47  
    48  func bytesFromIovec(iov unix.Iovec) (bs []byte) {
    49  	sh := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
    50  	sh.Data = uintptr(unsafe.Pointer(iov.Base))
    51  	sh.Len = int(iov.Len)
    52  	sh.Cap = int(iov.Len)
    53  	return
    54  }
    55  
    56  // AppendIovecFromBytes returns append(iovs, IovecFromBytes(bs)). If len(bs) ==
    57  // 0, AppendIovecFromBytes returns iovs without modification. If len(iovs) >=
    58  // max, AppendIovecFromBytes replaces the final iovec in iovs with one that
    59  // also includes the contents of bs. Note that this implies that
    60  // AppendIovecFromBytes is only usable when the returned iovec slice is used as
    61  // the source of a write.
    62  func AppendIovecFromBytes(iovs []unix.Iovec, bs []byte, max int) []unix.Iovec {
    63  	if len(bs) == 0 {
    64  		return iovs
    65  	}
    66  	if len(iovs) < max {
    67  		return append(iovs, IovecFromBytes(bs))
    68  	}
    69  	iovs[len(iovs)-1] = IovecFromBytes(append(bytesFromIovec(iovs[len(iovs)-1]), bs...))
    70  	return iovs
    71  }
    72  
    73  // MMsgHdr represents the mmsg_hdr structure required by recvmmsg() on linux.
    74  type MMsgHdr struct {
    75  	Msg unix.Msghdr
    76  	Len uint32
    77  	_   [4]byte
    78  }
    79  
    80  // SizeofMMsgHdr is the size of a MMsgHdr in bytes.
    81  const SizeofMMsgHdr = unsafe.Sizeof(MMsgHdr{})
    82  
    83  // GetMTU determines the MTU of a network interface device.
    84  func GetMTU(name string) (uint32, error) {
    85  	fd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_DGRAM, 0)
    86  	if err != nil {
    87  		return 0, err
    88  	}
    89  
    90  	defer unix.Close(fd)
    91  
    92  	var ifreq struct {
    93  		name [16]byte
    94  		mtu  int32
    95  		_    [20]byte
    96  	}
    97  
    98  	copy(ifreq.name[:], name)
    99  	_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq)))
   100  	if errno != 0 {
   101  		return 0, errno
   102  	}
   103  
   104  	return uint32(ifreq.mtu), nil
   105  }
   106  
   107  // NonBlockingWrite writes the given buffer to a file descriptor. It fails if
   108  // partial data is written.
   109  func NonBlockingWrite(fd int, buf []byte) tcpip.Error {
   110  	var ptr unsafe.Pointer
   111  	if len(buf) > 0 {
   112  		ptr = unsafe.Pointer(&buf[0])
   113  	}
   114  
   115  	_, _, e := unix.RawSyscall(unix.SYS_WRITE, uintptr(fd), uintptr(ptr), uintptr(len(buf)))
   116  	if e != 0 {
   117  		return TranslateErrno(e)
   118  	}
   119  
   120  	return nil
   121  }
   122  
   123  // NonBlockingWriteIovec writes iovec to a file descriptor in a single unix.
   124  // It fails if partial data is written.
   125  func NonBlockingWriteIovec(fd int, iovec []unix.Iovec) tcpip.Error {
   126  	iovecLen := uintptr(len(iovec))
   127  	_, _, e := unix.RawSyscall(unix.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
   128  	if e != 0 {
   129  		return TranslateErrno(e)
   130  	}
   131  	return nil
   132  }
   133  
   134  // NonBlockingSendMMsg sends multiple messages on a socket.
   135  func NonBlockingSendMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
   136  	n, _, e := unix.RawSyscall6(unix.SYS_SENDMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0)
   137  	if e != 0 {
   138  		return 0, TranslateErrno(e)
   139  	}
   140  
   141  	return int(n), nil
   142  }
   143  
// PollEvent represents the pollfd structure passed to a poll() system call.
// Within this file its address is handed directly to the kernel, so the field
// order and widths must not be changed.
type PollEvent struct {
	// FD is the file descriptor to watch.
	FD      int32
	// Events is the set of events being requested (e.g. POLLIN).
	Events  int16
	// Revents is read back after the poll to see which events were reported.
	Revents int16
}
   150  
   151  // BlockingRead reads from a file descriptor that is set up as non-blocking. If
   152  // no data is available, it will block in a poll() syscall until the file
   153  // descriptor becomes readable.
   154  func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
   155  	n, err := BlockingReadUntranslated(fd, b)
   156  	if err != 0 {
   157  		return n, TranslateErrno(err)
   158  	}
   159  	return n, nil
   160  }
   161  
   162  // BlockingReadUntranslated reads from a file descriptor that is set up as
   163  // non-blocking. If no data is available, it will block in a poll() syscall
   164  // until the file descriptor becomes readable. It returns the raw unix.Errno
   165  // value returned by the underlying syscalls.
   166  func BlockingReadUntranslated(fd int, b []byte) (int, unix.Errno) {
   167  	for {
   168  		n, _, e := unix.RawSyscall(unix.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
   169  		if e == 0 {
   170  			return int(n), 0
   171  		}
   172  
   173  		event := PollEvent{
   174  			FD:     int32(fd),
   175  			Events: 1, // POLLIN
   176  		}
   177  
   178  		_, e = BlockingPoll(&event, 1, nil)
   179  		if e != 0 && e != unix.EINTR {
   180  			return 0, e
   181  		}
   182  	}
   183  }
   184  
   185  // BlockingReadvUntilStopped reads from a file descriptor that is set up as
   186  // non-blocking and stores the data in a list of iovecs buffers. If no data is
   187  // available, it will block in a poll() syscall until the file descriptor
   188  // becomes readable or stop is signalled (efd becomes readable). Returns -1 in
   189  // the latter case.
   190  func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip.Error) {
   191  	for {
   192  		n, _, e := unix.RawSyscall(unix.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
   193  		if e == 0 {
   194  			return int(n), nil
   195  		}
   196  		if e != 0 && e != unix.EWOULDBLOCK {
   197  			return 0, TranslateErrno(e)
   198  		}
   199  		stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
   200  		if stopped {
   201  			return -1, nil
   202  		}
   203  		if e != 0 && e != unix.EINTR {
   204  			return 0, TranslateErrno(e)
   205  		}
   206  	}
   207  }
   208  
   209  // BlockingRecvMMsgUntilStopped reads from a file descriptor that is set up as
   210  // non-blocking and stores the received messages in a slice of MMsgHdr
   211  // structures. If no data is available, it will block in a poll() syscall until
   212  // the file descriptor becomes readable or stop is signalled (efd becomes
   213  // readable). Returns -1 in the latter case.
   214  func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
   215  	for {
   216  		n, _, e := unix.RawSyscall6(unix.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0)
   217  		if e == 0 {
   218  			return int(n), nil
   219  		}
   220  
   221  		if e != 0 && e != unix.EWOULDBLOCK {
   222  			return 0, TranslateErrno(e)
   223  		}
   224  
   225  		stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
   226  		if stopped {
   227  			return -1, nil
   228  		}
   229  		if e != 0 && e != unix.EINTR {
   230  			return 0, TranslateErrno(e)
   231  		}
   232  	}
   233  }
   234  
   235  // BlockingPollUntilStopped polls for events on fd or until a stop is signalled
   236  // on the event fd efd. Returns true if stopped, i.e., efd has event POLLIN.
   237  func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) {
   238  	pevents := [...]PollEvent{
   239  		{
   240  			FD:     int32(efd),
   241  			Events: unix.POLLIN,
   242  		},
   243  		{
   244  			FD:     int32(fd),
   245  			Events: events,
   246  		},
   247  	}
   248  	_, _, errno := unix.Syscall6(unix.SYS_PPOLL, uintptr(unsafe.Pointer(&pevents[0])), uintptr(len(pevents)), 0, 0, 0, 0)
   249  	if errno != 0 {
   250  		return pevents[0].Revents&unix.POLLIN != 0, errno
   251  	}
   252  
   253  	if pevents[1].Revents&unix.POLLHUP != 0 || pevents[1].Revents&unix.POLLERR != 0 {
   254  		errno = unix.ECONNRESET
   255  	}
   256  
   257  	return pevents[0].Revents&unix.POLLIN != 0, errno
   258  }