github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/sentry/socket/hostinet/socket.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package hostinet
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"golang.org/x/sys/unix"
    21  	"github.com/nicocha30/gvisor-ligolo/pkg/abi/linux"
    22  	"github.com/nicocha30/gvisor-ligolo/pkg/atomicbitops"
    23  	"github.com/nicocha30/gvisor-ligolo/pkg/context"
    24  	"github.com/nicocha30/gvisor-ligolo/pkg/errors/linuxerr"
    25  	"github.com/nicocha30/gvisor-ligolo/pkg/fdnotifier"
    26  	"github.com/nicocha30/gvisor-ligolo/pkg/log"
    27  	"github.com/nicocha30/gvisor-ligolo/pkg/marshal/primitive"
    28  	"github.com/nicocha30/gvisor-ligolo/pkg/safemem"
    29  	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/arch"
    30  	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/fsimpl/sockfs"
    31  	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/hostfd"
    32  	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/kernel"
    33  	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/kernel/auth"
    34  	ktime "github.com/nicocha30/gvisor-ligolo/pkg/sentry/kernel/time"
    35  	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/socket"
    36  	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/socket/control"
    37  	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/vfs"
    38  	"github.com/nicocha30/gvisor-ligolo/pkg/syserr"
    39  	"github.com/nicocha30/gvisor-ligolo/pkg/usermem"
    40  	"github.com/nicocha30/gvisor-ligolo/pkg/waiter"
    41  )
    42  
    43  const (
    44  	// sizeofSockaddr is the size in bytes of the largest sockaddr type
    45  	// supported by this package.
    46  	sizeofSockaddr = unix.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in)
    47  
    48  	// maxControlLen is the maximum size of a control message buffer used in a
    49  	// recvmsg or sendmsg unix.
    50  	maxControlLen = 1024
    51  )
    52  
// AllowedSocketType is a tuple of socket family, type, and protocol.
type AllowedSocketType struct {
	// Family is the socket address family (e.g. unix.AF_INET).
	Family int
	// Type is the socket type (e.g. unix.SOCK_STREAM).
	Type int

	// Protocol of AllowAllProtocols indicates that all protocols are
	// allowed.
	Protocol int
}
    62  
// AllowAllProtocols indicates that all protocols are allowed by the stack and
// in the syscall filters. It is used as the Protocol of an AllowedSocketType
// entry to act as a wildcard.
var AllowAllProtocols = -1
    66  
    67  // AllowedSocketTypes are the socket types which are supported by hostinet.
    68  // These are used to validate the arguments to socket(), and also to generate
    69  // syscall filters.
    70  var AllowedSocketTypes = []AllowedSocketType{
    71  	// Family, Type, Protocol.
    72  	{unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP},
    73  	{unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_UDP},
    74  	{unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_ICMP},
    75  
    76  	{unix.AF_INET6, unix.SOCK_STREAM, unix.IPPROTO_TCP},
    77  	{unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_UDP},
    78  	{unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_ICMPV6},
    79  }
    80  
    81  // AllowedRawSocketTypes are the socket types which are supported by hostinet
    82  // with raw sockets enabled.
    83  var AllowedRawSocketTypes = []AllowedSocketType{
    84  	{unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_RAW},
    85  	{unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_TCP},
    86  	{unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_UDP},
    87  	{unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_ICMP},
    88  
    89  	{unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_RAW},
    90  	{unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_TCP},
    91  	{unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP},
    92  	{unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_ICMPV6},
    93  
    94  	// AF_PACKET do not allow Write or SendMsg.
    95  	{unix.AF_PACKET, unix.SOCK_DGRAM, AllowAllProtocols},
    96  	{unix.AF_PACKET, unix.SOCK_RAW, AllowAllProtocols},
    97  }
    98  
// Socket implements socket.Socket (and by extension, vfs.FileDescriptionImpl)
// for host sockets.
//
// +stateify savable
type Socket struct {
	// vfsfd is the file description backed by this Socket; newSocket
	// initializes it and returns a pointer to it.
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl
	vfs.LockFD
	// We store metadata for hostinet sockets internally. Technically, we should
	// access metadata (e.g. through stat, chmod) on the host for correctness,
	// but this is not very useful for inet socket fds, which do not belong to a
	// concrete file anyway.
	vfs.DentryMetadataFileDescriptionImpl
	socket.SendReceiveTimeout

	family   int            // Read-only.
	stype    linux.SockType // Read-only.
	protocol int            // Read-only.
	// queue holds the waiters for this socket; it is handed to fdnotifier
	// together with fd when the socket is created.
	queue waiter.Queue

	// fd is the host socket fd. It must have O_NONBLOCK, so that operations
	// will return EWOULDBLOCK instead of blocking on the host. This allows us to
	// handle blocking behavior independently in the sentry.
	fd int

	// recvClosed indicates that the socket has been shutdown for reading
	// (SHUT_RD or SHUT_RDWR).
	recvClosed atomicbitops.Bool
}

// Compile-time check that *Socket satisfies socket.Socket.
var _ = socket.Socket(&Socket{})
   130  
   131  func newSocket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) {
   132  	mnt := t.Kernel().SocketMount()
   133  	d := sockfs.NewDentry(t, mnt)
   134  	defer d.DecRef(t)
   135  
   136  	s := &Socket{
   137  		family:   family,
   138  		stype:    stype,
   139  		protocol: protocol,
   140  		fd:       fd,
   141  	}
   142  	s.LockFD.Init(&vfs.FileLocks{})
   143  	if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
   144  		return nil, syserr.FromError(err)
   145  	}
   146  	vfsfd := &s.vfsfd
   147  	if err := vfsfd.Init(s, linux.O_RDWR|(flags&linux.O_NONBLOCK), mnt, d, &vfs.FileDescriptionOptions{
   148  		DenyPRead:         true,
   149  		DenyPWrite:        true,
   150  		UseDentryMetadata: true,
   151  	}); err != nil {
   152  		fdnotifier.RemoveFD(int32(s.fd))
   153  		return nil, syserr.FromError(err)
   154  	}
   155  	return vfsfd, nil
   156  }
   157  
   158  // Release implements vfs.FileDescriptionImpl.Release.
   159  func (s *Socket) Release(ctx context.Context) {
   160  	kernel.KernelFromContext(ctx).DeleteSocket(&s.vfsfd)
   161  	fdnotifier.RemoveFD(int32(s.fd))
   162  	_ = unix.Close(s.fd)
   163  }
   164  
   165  // Epollable implements FileDescriptionImpl.Epollable.
   166  func (s *Socket) Epollable() bool {
   167  	return true
   168  }
   169  
   170  // Ioctl implements vfs.FileDescriptionImpl.
   171  func (s *Socket) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
   172  	return ioctl(ctx, s.fd, uio, sysno, args)
   173  }
   174  
   175  // PRead implements vfs.FileDescriptionImpl.PRead.
   176  func (s *Socket) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
   177  	return 0, linuxerr.ESPIPE
   178  }
   179  
   180  // Read implements vfs.FileDescriptionImpl.
   181  func (s *Socket) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
   182  	// All flags other than RWF_NOWAIT should be ignored.
   183  	// TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
   184  	if opts.Flags != 0 {
   185  		return 0, linuxerr.EOPNOTSUPP
   186  	}
   187  
   188  	reader := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags)
   189  	defer hostfd.PutReadWriterAt(reader)
   190  	n, err := dst.CopyOutFrom(ctx, reader)
   191  	return int64(n), err
   192  }
   193  
   194  // PWrite implements vfs.FileDescriptionImpl.
   195  func (s *Socket) PWrite(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
   196  	return 0, linuxerr.ESPIPE
   197  }
   198  
   199  // Write implements vfs.FileDescriptionImpl.
   200  func (s *Socket) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
   201  	if s.family == linux.AF_PACKET {
   202  		// Don't allow Write for AF_PACKET.
   203  		return 0, linuxerr.EACCES
   204  	}
   205  
   206  	// All flags other than RWF_NOWAIT should be ignored.
   207  	// TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
   208  	if opts.Flags != 0 {
   209  		return 0, linuxerr.EOPNOTSUPP
   210  	}
   211  
   212  	writer := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags)
   213  	defer hostfd.PutReadWriterAt(writer)
   214  	n, err := src.CopyInTo(ctx, writer)
   215  	return int64(n), err
   216  }
   217  
// socketProvider creates host-backed sockets for a single address family.
// One provider per family is registered by this package's init function.
type socketProvider struct {
	// family is the address family this provider was registered for.
	family int
}
   221  
   222  // Socket implements socket.Provider.Socket.
   223  func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
   224  	// Check that we are using the host network stack.
   225  	netCtx := t.NetworkContext()
   226  	if netCtx == nil {
   227  		return nil, nil
   228  	}
   229  	stack, ok := netCtx.(*Stack)
   230  	if !ok {
   231  		return nil, nil
   232  	}
   233  
   234  	stype := stypeflags & linux.SOCK_TYPE_MASK
   235  
   236  	// Raw and packet sockets require CAP_NET_RAW.
   237  	if stype == linux.SOCK_RAW || p.family == linux.AF_PACKET {
   238  		if creds := auth.CredentialsFromContext(t); !creds.HasCapability(linux.CAP_NET_RAW) {
   239  			return nil, syserr.ErrNotPermitted
   240  		}
   241  	}
   242  
   243  	// Convert generic IPPROTO_IP protocol to the actual protocol depending
   244  	// on family and type.
   245  	if protocol == linux.IPPROTO_IP && (p.family == linux.AF_INET || p.family == linux.AF_INET6) {
   246  		switch stype {
   247  		case linux.SOCK_STREAM:
   248  			protocol = linux.IPPROTO_TCP
   249  		case linux.SOCK_DGRAM:
   250  			protocol = linux.IPPROTO_UDP
   251  		}
   252  	}
   253  
   254  	// Validate the socket based on family, type, and protocol.
   255  	var supported bool
   256  	for _, allowed := range stack.allowedSocketTypes {
   257  		isAllowedFamily := p.family == allowed.Family
   258  		isAllowedType := int(stype) == allowed.Type
   259  		isAllowedProtocol := protocol == allowed.Protocol || allowed.Protocol == AllowAllProtocols
   260  		if isAllowedFamily && isAllowedType && isAllowedProtocol {
   261  			supported = true
   262  			break
   263  		}
   264  	}
   265  	if !supported {
   266  		// Return nil error here to give other socket providers a
   267  		// chance to create this socket.
   268  		return nil, nil
   269  	}
   270  
   271  	// Conservatively ignore all flags specified by the application and add
   272  	// SOCK_NONBLOCK since socketOperations requires it.
   273  	st := int(stype) | unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC
   274  	fd, err := unix.Socket(p.family, st, protocol)
   275  	if err != nil {
   276  		return nil, syserr.FromError(err)
   277  	}
   278  	return newSocket(t, p.family, stype, protocol, fd, uint32(stypeflags&unix.SOCK_NONBLOCK))
   279  }
   280  
   281  // Pair implements socket.Provider.Pair.
   282  func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
   283  	// Not supported by AF_INET/AF_INET6.
   284  	return nil, nil, nil
   285  }
   286  
   287  // Readiness implements waiter.Waitable.Readiness.
   288  func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
   289  	return fdnotifier.NonBlockingPoll(int32(s.fd), mask)
   290  }
   291  
   292  // EventRegister implements waiter.Waitable.EventRegister.
   293  func (s *Socket) EventRegister(e *waiter.Entry) error {
   294  	s.queue.EventRegister(e)
   295  	if err := fdnotifier.UpdateFD(int32(s.fd)); err != nil {
   296  		s.queue.EventUnregister(e)
   297  		return err
   298  	}
   299  	return nil
   300  }
   301  
   302  // EventUnregister implements waiter.Waitable.EventUnregister.
   303  func (s *Socket) EventUnregister(e *waiter.Entry) {
   304  	s.queue.EventUnregister(e)
   305  	if err := fdnotifier.UpdateFD(int32(s.fd)); err != nil {
   306  		panic(err)
   307  	}
   308  }
   309  
   310  // Connect implements socket.Socket.Connect.
   311  func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
   312  	if len(sockaddr) > sizeofSockaddr {
   313  		sockaddr = sockaddr[:sizeofSockaddr]
   314  	}
   315  
   316  	_, _, errno := unix.Syscall(unix.SYS_CONNECT, uintptr(s.fd), uintptr(firstBytePtr(sockaddr)), uintptr(len(sockaddr)))
   317  	if errno == 0 {
   318  		return nil
   319  	}
   320  	// The host socket is always non-blocking, so we expect connect to
   321  	// return EINPROGRESS. If we are emulating a blocking socket, we will
   322  	// wait for the connect to complete below.
   323  	// But if we are not emulating a blocking socket, or if we got some
   324  	// other error, then return it now.
   325  	if errno != unix.EINPROGRESS || !blocking {
   326  		return syserr.FromError(translateIOSyscallError(errno))
   327  	}
   328  
   329  	// "EINPROGRESS: The socket is nonblocking and the connection cannot be
   330  	// completed immediately. It is possible to select(2) or poll(2) for
   331  	// completion by selecting the socket for writing. After select(2)
   332  	// indicates writability, use getsockopt(2) to read the SO_ERROR option at
   333  	// level SOL-SOCKET to determine whether connect() completed successfully
   334  	// (SO_ERROR is zero) or unsuccessfully (SO_ERROR is one of the usual error
   335  	// codes listed here, explaining the reason for the failure)." - connect(2)
   336  	writableMask := waiter.WritableEvents
   337  	e, ch := waiter.NewChannelEntry(writableMask)
   338  	s.EventRegister(&e)
   339  	defer s.EventUnregister(&e)
   340  	if s.Readiness(writableMask)&writableMask == 0 {
   341  		if err := t.Block(ch); err != nil {
   342  			return syserr.FromError(err)
   343  		}
   344  	}
   345  
   346  	val, err := unix.GetsockoptInt(s.fd, unix.SOL_SOCKET, unix.SO_ERROR)
   347  	if err != nil {
   348  		return syserr.FromError(err)
   349  	}
   350  	if val != 0 {
   351  		return syserr.FromError(unix.Errno(uintptr(val)))
   352  	}
   353  
   354  	// It seems like we are all good now, but Linux has left the socket
   355  	// state as CONNECTING (not CONNECTED). This is a strange quirk of
   356  	// non-blocking sockets. See tcp_finish_connect() which sets tcp state
   357  	// but not socket state.
   358  	//
   359  	// Sockets in the CONNECTING state can call connect() a second time,
   360  	// whereas CONNECTED sockets will reject the second connect() call.
   361  	// Because we are emulating a blocking socket, we want a subsequent
   362  	// connect() call to fail. So we must kick Linux to update the socket
   363  	// to state CONNECTED, which we can do by calling connect() a second
   364  	// time ourselves.
   365  	_, _, errno = unix.Syscall(unix.SYS_CONNECT, uintptr(s.fd), uintptr(firstBytePtr(sockaddr)), uintptr(len(sockaddr)))
   366  	if errno != 0 && errno != unix.EALREADY {
   367  		return syserr.FromError(translateIOSyscallError(errno))
   368  	}
   369  	return nil
   370  }
   371  
   372  // Accept implements socket.Socket.Accept.
   373  func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
   374  	var peerAddr linux.SockAddr
   375  	var peerAddrBuf []byte
   376  	var peerAddrlen uint32
   377  	var peerAddrPtr *byte
   378  	var peerAddrlenPtr *uint32
   379  	if peerRequested {
   380  		peerAddrBuf = make([]byte, sizeofSockaddr)
   381  		peerAddrlen = uint32(len(peerAddrBuf))
   382  		peerAddrPtr = &peerAddrBuf[0]
   383  		peerAddrlenPtr = &peerAddrlen
   384  	}
   385  
   386  	// Conservatively ignore all flags specified by the application and add
   387  	// SOCK_NONBLOCK since socketOpsCommon requires it.
   388  	fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC)
   389  	if blocking {
   390  		var ch chan struct{}
   391  		for linuxerr.Equals(linuxerr.ErrWouldBlock, syscallErr) {
   392  			if ch != nil {
   393  				if syscallErr = t.Block(ch); syscallErr != nil {
   394  					break
   395  				}
   396  			} else {
   397  				var e waiter.Entry
   398  				e, ch = waiter.NewChannelEntry(waiter.ReadableEvents | waiter.EventHUp | waiter.EventErr)
   399  				s.EventRegister(&e)
   400  				defer s.EventUnregister(&e)
   401  			}
   402  			fd, syscallErr = accept4(s.fd, peerAddrPtr, peerAddrlenPtr, unix.SOCK_NONBLOCK|unix.SOCK_CLOEXEC)
   403  		}
   404  	}
   405  
   406  	if peerRequested {
   407  		peerAddr = socket.UnmarshalSockAddr(s.family, peerAddrBuf[:peerAddrlen])
   408  	}
   409  	if syscallErr != nil {
   410  		return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
   411  	}
   412  
   413  	var (
   414  		kfd  int32
   415  		kerr error
   416  	)
   417  	f, err := newSocket(t, s.family, s.stype, s.protocol, fd, uint32(flags&unix.SOCK_NONBLOCK))
   418  	if err != nil {
   419  		_ = unix.Close(fd)
   420  		return 0, nil, 0, err
   421  	}
   422  	defer f.DecRef(t)
   423  
   424  	kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{
   425  		CloseOnExec: flags&unix.SOCK_CLOEXEC != 0,
   426  	})
   427  	t.Kernel().RecordSocket(f)
   428  
   429  	return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
   430  }
   431  
   432  // Bind implements socket.Socket.Bind.
   433  func (s *Socket) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error {
   434  	if len(sockaddr) > sizeofSockaddr {
   435  		sockaddr = sockaddr[:sizeofSockaddr]
   436  	}
   437  
   438  	_, _, errno := unix.Syscall(unix.SYS_BIND, uintptr(s.fd), uintptr(firstBytePtr(sockaddr)), uintptr(len(sockaddr)))
   439  	if errno != 0 {
   440  		return syserr.FromError(errno)
   441  	}
   442  	return nil
   443  }
   444  
   445  // Listen implements socket.Socket.Listen.
   446  func (s *Socket) Listen(_ *kernel.Task, backlog int) *syserr.Error {
   447  	return syserr.FromError(unix.Listen(s.fd, backlog))
   448  }
   449  
   450  // Shutdown implements socket.Socket.Shutdown.
   451  func (s *Socket) Shutdown(_ *kernel.Task, how int) *syserr.Error {
   452  	switch how {
   453  	case unix.SHUT_RD, unix.SHUT_RDWR:
   454  		// Mark the socket as closed for reading.
   455  		s.recvClosed.Store(true)
   456  		fallthrough
   457  	case unix.SHUT_WR:
   458  		return syserr.FromError(unix.Shutdown(s.fd, how))
   459  	default:
   460  		return syserr.ErrInvalidArgument
   461  	}
   462  }
   463  
   464  func (s *Socket) recvMsgFromHost(iovs []unix.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) {
   465  	// We always do a non-blocking recv*().
   466  	sysflags := flags | unix.MSG_DONTWAIT
   467  
   468  	msg := unix.Msghdr{}
   469  	if len(iovs) > 0 {
   470  		msg.Iov = &iovs[0]
   471  		msg.Iovlen = uint64(len(iovs))
   472  	}
   473  	var senderAddrBuf []byte
   474  	if senderRequested {
   475  		senderAddrBuf = make([]byte, sizeofSockaddr)
   476  		msg.Name = &senderAddrBuf[0]
   477  		msg.Namelen = uint32(sizeofSockaddr)
   478  	}
   479  	var controlBuf []byte
   480  	if controlLen > 0 {
   481  		if controlLen > maxControlLen {
   482  			controlLen = maxControlLen
   483  		}
   484  		controlBuf = make([]byte, controlLen)
   485  		msg.Control = &controlBuf[0]
   486  		msg.Controllen = controlLen
   487  	}
   488  	n, err := recvmsg(s.fd, &msg, sysflags)
   489  	if err != nil {
   490  		return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err
   491  	}
   492  	return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err
   493  }
   494  
   495  const allowedRecvMsgFlags = unix.MSG_CTRUNC |
   496  	unix.MSG_DONTWAIT |
   497  	unix.MSG_ERRQUEUE |
   498  	unix.MSG_OOB |
   499  	unix.MSG_PEEK |
   500  	unix.MSG_TRUNC |
   501  	unix.MSG_WAITALL
   502  
   503  // RecvMsg implements socket.Socket.RecvMsg.
   504  func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
   505  	// Only allow known and safe flags.
   506  	if flags&^allowedRecvMsgFlags != 0 {
   507  		return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
   508  	}
   509  
   510  	var senderAddrBuf []byte
   511  	var controlBuf []byte
   512  	var msgFlags int
   513  	copyToDst := func() (int64, error) {
   514  		var n uint64
   515  		var err error
   516  		if dst.NumBytes() == 0 {
   517  			// We want to make the recvmsg(2) call to the host even if dst is empty
   518  			// to fetch control messages, sender address or errors if any occur.
   519  			n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen)
   520  			return int64(n), err
   521  		}
   522  
   523  		recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
   524  			// Refuse to do anything if any part of dst.Addrs was unusable.
   525  			if uint64(dst.NumBytes()) != dsts.NumBytes() {
   526  				return 0, nil
   527  			}
   528  			if dsts.IsEmpty() {
   529  				return 0, nil
   530  			}
   531  
   532  			n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen)
   533  			return n, err
   534  		})
   535  		return dst.CopyOutFrom(t, recvmsgToBlocks)
   536  	}
   537  
   538  	var ch chan struct{}
   539  	n, err := copyToDst()
   540  
   541  	// recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT.
   542  	if flags&(unix.MSG_DONTWAIT|unix.MSG_ERRQUEUE) == 0 {
   543  		for linuxerr.Equals(linuxerr.ErrWouldBlock, err) {
   544  			// We only expect blocking to come from the actual syscall, in which
   545  			// case it can't have returned any data.
   546  			if n != 0 {
   547  				panic(fmt.Sprintf("CopyOutFrom: got (%d, %v), wanted (0, %v)", n, err, err))
   548  			}
   549  			// Are we closed for reading? No sense in trying to read if so.
   550  			if s.recvClosed.Load() {
   551  				break
   552  			}
   553  			if ch != nil {
   554  				if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
   555  					if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
   556  						err = linuxerr.ErrWouldBlock
   557  					}
   558  					break
   559  				}
   560  			} else {
   561  				var e waiter.Entry
   562  				e, ch = waiter.NewChannelEntry(waiter.ReadableEvents | waiter.EventRdHUp | waiter.EventHUp | waiter.EventErr)
   563  				s.EventRegister(&e)
   564  				defer s.EventUnregister(&e)
   565  			}
   566  			n, err = copyToDst()
   567  		}
   568  	}
   569  	if err != nil {
   570  		return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
   571  	}
   572  
   573  	// In some circumstances (like MSG_PEEK specified), the sender address
   574  	// field is purposefully ignored. recvMsgFromHost will return an empty
   575  	// senderAddrBuf in those cases.
   576  	var senderAddr linux.SockAddr
   577  	if senderRequested && len(senderAddrBuf) > 0 {
   578  		senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
   579  	}
   580  
   581  	unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf)
   582  	if err != nil {
   583  		return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
   584  	}
   585  	return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil
   586  }
   587  
   588  func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages {
   589  	controlMessages := socket.ControlMessages{}
   590  	for _, unixCmsg := range unixControlMessages {
   591  		switch unixCmsg.Header.Level {
   592  		case linux.SOL_SOCKET:
   593  			switch unixCmsg.Header.Type {
   594  			case linux.SO_TIMESTAMP:
   595  				controlMessages.IP.HasTimestamp = true
   596  				ts := linux.Timeval{}
   597  				ts.UnmarshalUnsafe(unixCmsg.Data)
   598  				controlMessages.IP.Timestamp = ts.ToTime()
   599  			}
   600  
   601  		case linux.SOL_IP:
   602  			switch unixCmsg.Header.Type {
   603  			case linux.IP_TOS:
   604  				controlMessages.IP.HasTOS = true
   605  				var tos primitive.Uint8
   606  				tos.UnmarshalUnsafe(unixCmsg.Data)
   607  				controlMessages.IP.TOS = uint8(tos)
   608  
   609  			case linux.IP_TTL:
   610  				controlMessages.IP.HasTTL = true
   611  				var ttl primitive.Uint32
   612  				ttl.UnmarshalUnsafe(unixCmsg.Data)
   613  				controlMessages.IP.TTL = uint32(ttl)
   614  
   615  			case linux.IP_PKTINFO:
   616  				controlMessages.IP.HasIPPacketInfo = true
   617  				var packetInfo linux.ControlMessageIPPacketInfo
   618  				packetInfo.UnmarshalUnsafe(unixCmsg.Data)
   619  				controlMessages.IP.PacketInfo = packetInfo
   620  
   621  			case linux.IP_RECVORIGDSTADDR:
   622  				var addr linux.SockAddrInet
   623  				addr.UnmarshalUnsafe(unixCmsg.Data)
   624  				controlMessages.IP.OriginalDstAddress = &addr
   625  
   626  			case unix.IP_RECVERR:
   627  				var errCmsg linux.SockErrCMsgIPv4
   628  				errCmsg.UnmarshalBytes(unixCmsg.Data)
   629  				controlMessages.IP.SockErr = &errCmsg
   630  			}
   631  
   632  		case linux.SOL_IPV6:
   633  			switch unixCmsg.Header.Type {
   634  			case linux.IPV6_TCLASS:
   635  				controlMessages.IP.HasTClass = true
   636  				var tclass primitive.Uint32
   637  				tclass.UnmarshalUnsafe(unixCmsg.Data)
   638  				controlMessages.IP.TClass = uint32(tclass)
   639  
   640  			case linux.IPV6_PKTINFO:
   641  				controlMessages.IP.HasIPv6PacketInfo = true
   642  				var packetInfo linux.ControlMessageIPv6PacketInfo
   643  				packetInfo.UnmarshalUnsafe(unixCmsg.Data)
   644  				controlMessages.IP.IPv6PacketInfo = packetInfo
   645  
   646  			case linux.IPV6_HOPLIMIT:
   647  				controlMessages.IP.HasHopLimit = true
   648  				var hoplimit primitive.Uint32
   649  				hoplimit.UnmarshalUnsafe(unixCmsg.Data)
   650  				controlMessages.IP.HopLimit = uint32(hoplimit)
   651  
   652  			case linux.IPV6_RECVORIGDSTADDR:
   653  				var addr linux.SockAddrInet6
   654  				addr.UnmarshalUnsafe(unixCmsg.Data)
   655  				controlMessages.IP.OriginalDstAddress = &addr
   656  
   657  			case unix.IPV6_RECVERR:
   658  				var errCmsg linux.SockErrCMsgIPv6
   659  				errCmsg.UnmarshalBytes(unixCmsg.Data)
   660  				controlMessages.IP.SockErr = &errCmsg
   661  			}
   662  
   663  		case linux.SOL_TCP:
   664  			switch unixCmsg.Header.Type {
   665  			case linux.TCP_INQ:
   666  				controlMessages.IP.HasInq = true
   667  				var inq primitive.Int32
   668  				inq.UnmarshalUnsafe(unixCmsg.Data)
   669  				controlMessages.IP.Inq = int32(inq)
   670  			}
   671  		}
   672  	}
   673  	return controlMessages
   674  }
   675  
   676  const allowedSendMsgFlags = unix.MSG_DONTWAIT |
   677  	unix.MSG_EOR |
   678  	unix.MSG_FASTOPEN |
   679  	unix.MSG_MORE |
   680  	unix.MSG_NOSIGNAL |
   681  	unix.MSG_OOB
   682  
   683  // SendMsg implements socket.Socket.SendMsg.
   684  func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
   685  	if s.family == linux.AF_PACKET {
   686  		// Don't allow SendMesg for AF_PACKET.
   687  		return 0, syserr.ErrPermissionDenied
   688  	}
   689  
   690  	// Only allow known and safe flags.
   691  	if flags&^allowedSendMsgFlags != 0 {
   692  		return 0, syserr.ErrInvalidArgument
   693  	}
   694  
   695  	// If the src is zero-length, call SENDTO directly with a null buffer in
   696  	// order to generate poll/epoll notifications.
   697  	if src.NumBytes() == 0 {
   698  		sysflags := flags | unix.MSG_DONTWAIT
   699  		n, _, errno := unix.Syscall6(unix.SYS_SENDTO, uintptr(s.fd), 0, 0, uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
   700  		if errno != 0 {
   701  			return 0, syserr.FromError(errno)
   702  		}
   703  		return int(n), nil
   704  	}
   705  
   706  	space := uint64(control.CmsgsSpace(t, controlMessages))
   707  	if space > maxControlLen {
   708  		space = maxControlLen
   709  	}
   710  	controlBuf := make([]byte, 0, space)
   711  	// PackControlMessages will append up to space bytes to controlBuf.
   712  	controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)
   713  
   714  	sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
   715  		// Refuse to do anything if any part of src.Addrs was unusable.
   716  		if uint64(src.NumBytes()) != srcs.NumBytes() {
   717  			return 0, nil
   718  		}
   719  		if srcs.IsEmpty() && len(controlBuf) == 0 {
   720  			return 0, nil
   721  		}
   722  
   723  		// We always do a non-blocking send*().
   724  		sysflags := flags | unix.MSG_DONTWAIT
   725  
   726  		if srcs.NumBlocks() == 1 && len(controlBuf) == 0 {
   727  			// Skip allocating []unix.Iovec.
   728  			src := srcs.Head()
   729  			n, _, errno := unix.Syscall6(unix.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
   730  			if errno != 0 {
   731  				return 0, translateIOSyscallError(errno)
   732  			}
   733  			return uint64(n), nil
   734  		}
   735  
   736  		iovs := safemem.IovecsFromBlockSeq(srcs)
   737  		msg := unix.Msghdr{
   738  			Iov:    &iovs[0],
   739  			Iovlen: uint64(len(iovs)),
   740  		}
   741  		if len(to) != 0 {
   742  			msg.Name = &to[0]
   743  			msg.Namelen = uint32(len(to))
   744  		}
   745  		if len(controlBuf) != 0 {
   746  			msg.Control = &controlBuf[0]
   747  			msg.Controllen = uint64(len(controlBuf))
   748  		}
   749  		return sendmsg(s.fd, &msg, sysflags)
   750  	})
   751  
   752  	var ch chan struct{}
   753  	n, err := src.CopyInTo(t, sendmsgFromBlocks)
   754  	if flags&unix.MSG_DONTWAIT == 0 {
   755  		for linuxerr.Equals(linuxerr.ErrWouldBlock, err) {
   756  			// We only expect blocking to come from the actual syscall, in which
   757  			// case it can't have returned any data.
   758  			if n != 0 {
   759  				panic(fmt.Sprintf("CopyInTo: got (%d, %v), wanted (0, %v)", n, err, err))
   760  			}
   761  			if ch != nil {
   762  				if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
   763  					if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
   764  						err = linuxerr.ErrWouldBlock
   765  					}
   766  					break
   767  				}
   768  			} else {
   769  				var e waiter.Entry
   770  				e, ch = waiter.NewChannelEntry(waiter.WritableEvents | waiter.EventHUp | waiter.EventErr)
   771  				s.EventRegister(&e)
   772  				defer s.EventUnregister(&e)
   773  			}
   774  			n, err = src.CopyInTo(t, sendmsgFromBlocks)
   775  		}
   776  	}
   777  
   778  	return int(n), syserr.FromError(err)
   779  }
   780  
   781  func translateIOSyscallError(err error) error {
   782  	if err == unix.EAGAIN || err == unix.EWOULDBLOCK {
   783  		return linuxerr.ErrWouldBlock
   784  	}
   785  	return err
   786  }
   787  
   788  // State implements socket.Socket.State.
   789  func (s *Socket) State() uint32 {
   790  	info := linux.TCPInfo{}
   791  	buf := make([]byte, linux.SizeOfTCPInfo)
   792  	var err error
   793  	buf, err = getsockopt(s.fd, unix.SOL_TCP, unix.TCP_INFO, buf)
   794  	if err != nil {
   795  		if err != unix.ENOPROTOOPT {
   796  			log.Warningf("Failed to get TCP socket info from %+v: %v", s, err)
   797  		}
   798  		// For non-TCP sockets, silently ignore the failure.
   799  		return 0
   800  	}
   801  	if len(buf) != linux.SizeOfTCPInfo {
   802  		// Unmarshal below will panic if getsockopt returns a buffer of
   803  		// unexpected size.
   804  		log.Warningf("Failed to get TCP socket info from %+v: getsockopt(2) returned %d bytes, expecting %d bytes.", s, len(buf), linux.SizeOfTCPInfo)
   805  		return 0
   806  	}
   807  
   808  	info.UnmarshalUnsafe(buf[:info.SizeBytes()])
   809  	return uint32(info.State)
   810  }
   811  
   812  // Type implements socket.Socket.Type.
   813  func (s *Socket) Type() (family int, skType linux.SockType, protocol int) {
   814  	return s.family, s.stype, s.protocol
   815  }
   816  
   817  func init() {
   818  	// Register all families in AllowedSocketTypes and AllowedRawSocket
   819  	// types. If we don't allow raw sockets, they will be rejected in the
   820  	// Socket call.
   821  	registered := make(map[int]struct{})
   822  	for _, sockType := range append(AllowedSocketTypes, AllowedRawSocketTypes...) {
   823  		fam := sockType.Family
   824  		if _, ok := registered[fam]; ok {
   825  			continue
   826  		}
   827  		socket.RegisterProvider(fam, &socketProvider{fam})
   828  		registered[fam] = struct{}{}
   829  	}
   830  }