github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/syscalls/linux/sys_socket.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package linux
    16  
    17  import (
    18  	"time"
    19  
    20  	"github.com/SagerNet/gvisor/pkg/abi/linux"
    21  	"github.com/SagerNet/gvisor/pkg/errors/linuxerr"
    22  	"github.com/SagerNet/gvisor/pkg/hostarch"
    23  	"github.com/SagerNet/gvisor/pkg/marshal"
    24  	"github.com/SagerNet/gvisor/pkg/marshal/primitive"
    25  	"github.com/SagerNet/gvisor/pkg/sentry/arch"
    26  	"github.com/SagerNet/gvisor/pkg/sentry/fs"
    27  	"github.com/SagerNet/gvisor/pkg/sentry/kernel"
    28  	ktime "github.com/SagerNet/gvisor/pkg/sentry/kernel/time"
    29  	"github.com/SagerNet/gvisor/pkg/sentry/socket"
    30  	"github.com/SagerNet/gvisor/pkg/sentry/socket/control"
    31  	"github.com/SagerNet/gvisor/pkg/sentry/socket/unix/transport"
    32  	"github.com/SagerNet/gvisor/pkg/syserr"
    33  	"github.com/SagerNet/gvisor/pkg/syserror"
    34  	"github.com/SagerNet/gvisor/pkg/usermem"
    35  )
    36  
    37  // LINT.IfChange
    38  
    39  // maxAddrLen is the maximum socket address length we're willing to accept.
    40  const maxAddrLen = 200
    41  
    42  // maxOptLen is the maximum sockopt parameter length we're willing to accept.
    43  const maxOptLen = 1024 * 8
    44  
    45  // maxControlLen is the maximum length of the msghdr.msg_control buffer we're
    46  // willing to accept. Note that this limit is smaller than Linux, which allows
    47  // buffers upto INT_MAX.
    48  const maxControlLen = 10 * 1024 * 1024
    49  
    50  // maxListenBacklog is the maximum limit of listen backlog supported.
    51  const maxListenBacklog = 1024
    52  
    53  // nameLenOffset is the offset from the start of the MessageHeader64 struct to
    54  // the NameLen field.
    55  const nameLenOffset = 8
    56  
    57  // controlLenOffset is the offset form the start of the MessageHeader64 struct
    58  // to the ControlLen field.
    59  const controlLenOffset = 40
    60  
    61  // flagsOffset is the offset form the start of the MessageHeader64 struct
    62  // to the Flags field.
    63  const flagsOffset = 48
    64  
    65  const sizeOfInt32 = 4
    66  
    67  // messageHeader64Len is the length of a MessageHeader64 struct.
    68  var messageHeader64Len = uint64((*MessageHeader64)(nil).SizeBytes())
    69  
    70  // multipleMessageHeader64Len is the length of a multipeMessageHeader64 struct.
    71  var multipleMessageHeader64Len = uint64((*multipleMessageHeader64)(nil).SizeBytes())
    72  
    73  // baseRecvFlags are the flags that are accepted across recvmsg(2),
    74  // recvmmsg(2), and recvfrom(2).
    75  const baseRecvFlags = linux.MSG_OOB | linux.MSG_DONTROUTE | linux.MSG_DONTWAIT | linux.MSG_NOSIGNAL | linux.MSG_WAITALL | linux.MSG_TRUNC | linux.MSG_CTRUNC
    76  
// MessageHeader64 is the 64-bit representation of the msghdr struct used in
// the recvmsg and sendmsg syscalls.
//
// The field order and the blank padding fields determine the byte layout
// that nameLenOffset, controlLenOffset and flagsOffset index into (NameLen
// at 8, ControlLen at 40, Flags at 48), so the fields must not be reordered
// or resized without updating those constants.
//
// +marshal
type MessageHeader64 struct {
	// Name is the optional pointer to a network address buffer.
	Name uint64

	// NameLen is the length of the buffer pointed to by Name. The blank
	// field after it is explicit padding so that Iov begins 16 bytes into
	// the struct.
	NameLen uint32
	_       uint32

	// Iov is a pointer to an array of io vectors that describe the memory
	// locations involved in the io operation.
	Iov uint64

	// IovLen is the length of the array pointed to by Iov.
	IovLen uint64

	// Control is the optional pointer to ancillary control data.
	Control uint64

	// ControlLen is the length of the data pointed to by Control.
	ControlLen uint64

	// Flags on the sent/received message. The blank field after it is
	// explicit trailing padding.
	Flags int32
	_     int32
}
   106  
// multipleMessageHeader64 is the 64-bit representation of the mmsghdr struct used in
// the recvmmsg and sendmmsg syscalls.
//
// msgHdr is the embedded message header, msgLen is the per-message byte
// count (RecvMMsg writes it at offset messageHeader64Len from the start of
// each entry), and the blank field is explicit trailing padding.
//
// +marshal
type multipleMessageHeader64 struct {
	msgHdr MessageHeader64
	msgLen uint32
	_      int32
}
   116  
   117  // CaptureAddress allocates memory for and copies a socket address structure
   118  // from the untrusted address space range.
   119  func CaptureAddress(t *kernel.Task, addr hostarch.Addr, addrlen uint32) ([]byte, error) {
   120  	if addrlen > maxAddrLen {
   121  		return nil, linuxerr.EINVAL
   122  	}
   123  
   124  	addrBuf := make([]byte, addrlen)
   125  	if _, err := t.CopyInBytes(addr, addrBuf); err != nil {
   126  		return nil, err
   127  	}
   128  
   129  	return addrBuf, nil
   130  }
   131  
   132  // writeAddress writes a sockaddr structure and its length to an output buffer
   133  // in the unstrusted address space range. If the address is bigger than the
   134  // buffer, it is truncated.
   135  func writeAddress(t *kernel.Task, addr linux.SockAddr, addrLen uint32, addrPtr hostarch.Addr, addrLenPtr hostarch.Addr) error {
   136  	// Get the buffer length.
   137  	var bufLen uint32
   138  	if _, err := primitive.CopyUint32In(t, addrLenPtr, &bufLen); err != nil {
   139  		return err
   140  	}
   141  
   142  	if int32(bufLen) < 0 {
   143  		return linuxerr.EINVAL
   144  	}
   145  
   146  	// Write the length unconditionally.
   147  	if _, err := primitive.CopyUint32Out(t, addrLenPtr, addrLen); err != nil {
   148  		return err
   149  	}
   150  
   151  	if addr == nil {
   152  		return nil
   153  	}
   154  
   155  	if bufLen > addrLen {
   156  		bufLen = addrLen
   157  	}
   158  
   159  	// Copy as much of the address as will fit in the buffer.
   160  	encodedAddr := t.CopyScratchBuffer(addr.SizeBytes())
   161  	addr.MarshalUnsafe(encodedAddr)
   162  	if bufLen > uint32(len(encodedAddr)) {
   163  		bufLen = uint32(len(encodedAddr))
   164  	}
   165  	_, err := t.CopyOutBytes(addrPtr, encodedAddr[:int(bufLen)])
   166  	return err
   167  }
   168  
   169  // Socket implements the linux syscall socket(2).
   170  func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   171  	domain := int(args[0].Int())
   172  	stype := args[1].Int()
   173  	protocol := int(args[2].Int())
   174  
   175  	// Check and initialize the flags.
   176  	if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
   177  		return 0, nil, linuxerr.EINVAL
   178  	}
   179  
   180  	// Create the new socket.
   181  	s, e := socket.New(t, domain, linux.SockType(stype&0xf), protocol)
   182  	if e != nil {
   183  		return 0, nil, e.ToError()
   184  	}
   185  	s.SetFlags(fs.SettableFileFlags{
   186  		NonBlocking: stype&linux.SOCK_NONBLOCK != 0,
   187  	})
   188  	defer s.DecRef(t)
   189  
   190  	fd, err := t.NewFDFrom(0, s, kernel.FDFlags{
   191  		CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
   192  	})
   193  	if err != nil {
   194  		return 0, nil, err
   195  	}
   196  
   197  	return uintptr(fd), nil, nil
   198  }
   199  
   200  // SocketPair implements the linux syscall socketpair(2).
   201  func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   202  	domain := int(args[0].Int())
   203  	stype := args[1].Int()
   204  	protocol := int(args[2].Int())
   205  	socks := args[3].Pointer()
   206  
   207  	// Check and initialize the flags.
   208  	if stype & ^(0xf|linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
   209  		return 0, nil, linuxerr.EINVAL
   210  	}
   211  
   212  	fileFlags := fs.SettableFileFlags{
   213  		NonBlocking: stype&linux.SOCK_NONBLOCK != 0,
   214  	}
   215  
   216  	// Create the socket pair.
   217  	s1, s2, e := socket.Pair(t, domain, linux.SockType(stype&0xf), protocol)
   218  	if e != nil {
   219  		return 0, nil, e.ToError()
   220  	}
   221  	s1.SetFlags(fileFlags)
   222  	s2.SetFlags(fileFlags)
   223  	defer s1.DecRef(t)
   224  	defer s2.DecRef(t)
   225  
   226  	// Create the FDs for the sockets.
   227  	fds, err := t.NewFDs(0, []*fs.File{s1, s2}, kernel.FDFlags{
   228  		CloseOnExec: stype&linux.SOCK_CLOEXEC != 0,
   229  	})
   230  	if err != nil {
   231  		return 0, nil, err
   232  	}
   233  
   234  	// Copy the file descriptors out.
   235  	if _, err := primitive.CopyInt32SliceOut(t, socks, fds); err != nil {
   236  		for _, fd := range fds {
   237  			if file, _ := t.FDTable().Remove(t, fd); file != nil {
   238  				file.DecRef(t)
   239  			}
   240  		}
   241  		return 0, nil, err
   242  	}
   243  
   244  	return 0, nil, nil
   245  }
   246  
   247  // Connect implements the linux syscall connect(2).
   248  func Connect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   249  	fd := args[0].Int()
   250  	addr := args[1].Pointer()
   251  	addrlen := args[2].Uint()
   252  
   253  	// Get socket from the file descriptor.
   254  	file := t.GetFile(fd)
   255  	if file == nil {
   256  		return 0, nil, linuxerr.EBADF
   257  	}
   258  	defer file.DecRef(t)
   259  
   260  	// Extract the socket.
   261  	s, ok := file.FileOperations.(socket.Socket)
   262  	if !ok {
   263  		return 0, nil, syserror.ENOTSOCK
   264  	}
   265  
   266  	// Capture address and call syscall implementation.
   267  	a, err := CaptureAddress(t, addr, addrlen)
   268  	if err != nil {
   269  		return 0, nil, err
   270  	}
   271  
   272  	blocking := !file.Flags().NonBlocking
   273  	return 0, nil, syserror.ConvertIntr(s.Connect(t, a, blocking).ToError(), syserror.ERESTARTSYS)
   274  }
   275  
   276  // accept is the implementation of the accept syscall. It is called by accept
   277  // and accept4 syscall handlers.
   278  func accept(t *kernel.Task, fd int32, addr hostarch.Addr, addrLen hostarch.Addr, flags int) (uintptr, error) {
   279  	// Check that no unsupported flags are passed in.
   280  	if flags & ^(linux.SOCK_NONBLOCK|linux.SOCK_CLOEXEC) != 0 {
   281  		return 0, linuxerr.EINVAL
   282  	}
   283  
   284  	// Get socket from the file descriptor.
   285  	file := t.GetFile(fd)
   286  	if file == nil {
   287  		return 0, linuxerr.EBADF
   288  	}
   289  	defer file.DecRef(t)
   290  
   291  	// Extract the socket.
   292  	s, ok := file.FileOperations.(socket.Socket)
   293  	if !ok {
   294  		return 0, syserror.ENOTSOCK
   295  	}
   296  
   297  	// Call the syscall implementation for this socket, then copy the
   298  	// output address if one is specified.
   299  	blocking := !file.Flags().NonBlocking
   300  
   301  	peerRequested := addrLen != 0
   302  	nfd, peer, peerLen, e := s.Accept(t, peerRequested, flags, blocking)
   303  	if e != nil {
   304  		return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS)
   305  	}
   306  	if peerRequested {
   307  		// NOTE(magi): Linux does not give you an error if it can't
   308  		// write the data back out so neither do we.
   309  		if err := writeAddress(t, peer, peerLen, addr, addrLen); linuxerr.Equals(linuxerr.EINVAL, err) {
   310  			return 0, err
   311  		}
   312  	}
   313  	return uintptr(nfd), nil
   314  }
   315  
   316  // Accept4 implements the linux syscall accept4(2).
   317  func Accept4(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   318  	fd := args[0].Int()
   319  	addr := args[1].Pointer()
   320  	addrlen := args[2].Pointer()
   321  	flags := int(args[3].Int())
   322  
   323  	n, err := accept(t, fd, addr, addrlen, flags)
   324  	return n, nil, err
   325  }
   326  
   327  // Accept implements the linux syscall accept(2).
   328  func Accept(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   329  	fd := args[0].Int()
   330  	addr := args[1].Pointer()
   331  	addrlen := args[2].Pointer()
   332  
   333  	n, err := accept(t, fd, addr, addrlen, 0)
   334  	return n, nil, err
   335  }
   336  
   337  // Bind implements the linux syscall bind(2).
   338  func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   339  	fd := args[0].Int()
   340  	addr := args[1].Pointer()
   341  	addrlen := args[2].Uint()
   342  
   343  	// Get socket from the file descriptor.
   344  	file := t.GetFile(fd)
   345  	if file == nil {
   346  		return 0, nil, linuxerr.EBADF
   347  	}
   348  	defer file.DecRef(t)
   349  
   350  	// Extract the socket.
   351  	s, ok := file.FileOperations.(socket.Socket)
   352  	if !ok {
   353  		return 0, nil, syserror.ENOTSOCK
   354  	}
   355  
   356  	// Capture address and call syscall implementation.
   357  	a, err := CaptureAddress(t, addr, addrlen)
   358  	if err != nil {
   359  		return 0, nil, err
   360  	}
   361  
   362  	return 0, nil, s.Bind(t, a).ToError()
   363  }
   364  
   365  // Listen implements the linux syscall listen(2).
   366  func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   367  	fd := args[0].Int()
   368  	backlog := args[1].Uint()
   369  
   370  	// Get socket from the file descriptor.
   371  	file := t.GetFile(fd)
   372  	if file == nil {
   373  		return 0, nil, linuxerr.EBADF
   374  	}
   375  	defer file.DecRef(t)
   376  
   377  	// Extract the socket.
   378  	s, ok := file.FileOperations.(socket.Socket)
   379  	if !ok {
   380  		return 0, nil, syserror.ENOTSOCK
   381  	}
   382  
   383  	if backlog > maxListenBacklog {
   384  		// Linux treats incoming backlog as uint with a limit defined by
   385  		// sysctl_somaxconn.
   386  		// https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
   387  		backlog = maxListenBacklog
   388  	}
   389  
   390  	// Accept one more than the configured listen backlog to keep in parity with
   391  	// Linux. Ref, because of missing equality check here:
   392  	// https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937
   393  	//
   394  	// In case of unix domain sockets, the following check
   395  	// https://github.com/torvalds/linux/blob/7d6beb71da3/net/unix/af_unix.c#L1293
   396  	// will allow 1 connect through since it checks for a receive queue len >
   397  	// backlog and not >=.
   398  	backlog++
   399  
   400  	return 0, nil, s.Listen(t, int(backlog)).ToError()
   401  }
   402  
   403  // Shutdown implements the linux syscall shutdown(2).
   404  func Shutdown(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   405  	fd := args[0].Int()
   406  	how := args[1].Int()
   407  
   408  	// Get socket from the file descriptor.
   409  	file := t.GetFile(fd)
   410  	if file == nil {
   411  		return 0, nil, linuxerr.EBADF
   412  	}
   413  	defer file.DecRef(t)
   414  
   415  	// Extract the socket.
   416  	s, ok := file.FileOperations.(socket.Socket)
   417  	if !ok {
   418  		return 0, nil, syserror.ENOTSOCK
   419  	}
   420  
   421  	// Validate how, then call syscall implementation.
   422  	switch how {
   423  	case linux.SHUT_RD, linux.SHUT_WR, linux.SHUT_RDWR:
   424  	default:
   425  		return 0, nil, linuxerr.EINVAL
   426  	}
   427  
   428  	return 0, nil, s.Shutdown(t, int(how)).ToError()
   429  }
   430  
   431  // GetSockOpt implements the linux syscall getsockopt(2).
   432  func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   433  	fd := args[0].Int()
   434  	level := args[1].Int()
   435  	name := args[2].Int()
   436  	optValAddr := args[3].Pointer()
   437  	optLenAddr := args[4].Pointer()
   438  
   439  	// Get socket from the file descriptor.
   440  	file := t.GetFile(fd)
   441  	if file == nil {
   442  		return 0, nil, linuxerr.EBADF
   443  	}
   444  	defer file.DecRef(t)
   445  
   446  	// Extract the socket.
   447  	s, ok := file.FileOperations.(socket.Socket)
   448  	if !ok {
   449  		return 0, nil, syserror.ENOTSOCK
   450  	}
   451  
   452  	// Read the length. Reject negative values.
   453  	var optLen int32
   454  	if _, err := primitive.CopyInt32In(t, optLenAddr, &optLen); err != nil {
   455  		return 0, nil, err
   456  	}
   457  	if optLen < 0 {
   458  		return 0, nil, linuxerr.EINVAL
   459  	}
   460  
   461  	// Call syscall implementation then copy both value and value len out.
   462  	v, e := getSockOpt(t, s, int(level), int(name), optValAddr, int(optLen))
   463  	if e != nil {
   464  		return 0, nil, e.ToError()
   465  	}
   466  
   467  	if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil {
   468  		return 0, nil, err
   469  	}
   470  
   471  	if v != nil {
   472  		if _, err := v.CopyOut(t, optValAddr); err != nil {
   473  			return 0, nil, err
   474  		}
   475  	}
   476  
   477  	return 0, nil, nil
   478  }
   479  
   480  // getSockOpt tries to handle common socket options, or dispatches to a specific
   481  // socket implementation.
   482  func getSockOpt(t *kernel.Task, s socket.Socket, level, name int, optValAddr hostarch.Addr, len int) (marshal.Marshallable, *syserr.Error) {
   483  	if level == linux.SOL_SOCKET {
   484  		switch name {
   485  		case linux.SO_TYPE, linux.SO_DOMAIN, linux.SO_PROTOCOL:
   486  			if len < sizeOfInt32 {
   487  				return nil, syserr.ErrInvalidArgument
   488  			}
   489  		}
   490  
   491  		switch name {
   492  		case linux.SO_TYPE:
   493  			_, skType, _ := s.Type()
   494  			v := primitive.Int32(skType)
   495  			return &v, nil
   496  		case linux.SO_DOMAIN:
   497  			family, _, _ := s.Type()
   498  			v := primitive.Int32(family)
   499  			return &v, nil
   500  		case linux.SO_PROTOCOL:
   501  			_, _, protocol := s.Type()
   502  			v := primitive.Int32(protocol)
   503  			return &v, nil
   504  		}
   505  	}
   506  
   507  	return s.GetSockOpt(t, level, name, optValAddr, len)
   508  }
   509  
   510  // SetSockOpt implements the linux syscall setsockopt(2).
   511  //
   512  // Note that unlike Linux, enabling SO_PASSCRED does not autobind the socket.
   513  func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   514  	fd := args[0].Int()
   515  	level := args[1].Int()
   516  	name := args[2].Int()
   517  	optValAddr := args[3].Pointer()
   518  	optLen := args[4].Int()
   519  
   520  	// Get socket from the file descriptor.
   521  	file := t.GetFile(fd)
   522  	if file == nil {
   523  		return 0, nil, linuxerr.EBADF
   524  	}
   525  	defer file.DecRef(t)
   526  
   527  	// Extract the socket.
   528  	s, ok := file.FileOperations.(socket.Socket)
   529  	if !ok {
   530  		return 0, nil, syserror.ENOTSOCK
   531  	}
   532  
   533  	if optLen < 0 {
   534  		return 0, nil, linuxerr.EINVAL
   535  	}
   536  	if optLen > maxOptLen {
   537  		return 0, nil, linuxerr.EINVAL
   538  	}
   539  	buf := t.CopyScratchBuffer(int(optLen))
   540  	if _, err := t.CopyInBytes(optValAddr, buf); err != nil {
   541  		return 0, nil, err
   542  	}
   543  
   544  	// Call syscall implementation.
   545  	if err := s.SetSockOpt(t, int(level), int(name), buf); err != nil {
   546  		return 0, nil, err.ToError()
   547  	}
   548  
   549  	return 0, nil, nil
   550  }
   551  
   552  // GetSockName implements the linux syscall getsockname(2).
   553  func GetSockName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   554  	fd := args[0].Int()
   555  	addr := args[1].Pointer()
   556  	addrlen := args[2].Pointer()
   557  
   558  	// Get socket from the file descriptor.
   559  	file := t.GetFile(fd)
   560  	if file == nil {
   561  		return 0, nil, linuxerr.EBADF
   562  	}
   563  	defer file.DecRef(t)
   564  
   565  	// Extract the socket.
   566  	s, ok := file.FileOperations.(socket.Socket)
   567  	if !ok {
   568  		return 0, nil, syserror.ENOTSOCK
   569  	}
   570  
   571  	// Get the socket name and copy it to the caller.
   572  	v, vl, err := s.GetSockName(t)
   573  	if err != nil {
   574  		return 0, nil, err.ToError()
   575  	}
   576  
   577  	return 0, nil, writeAddress(t, v, vl, addr, addrlen)
   578  }
   579  
   580  // GetPeerName implements the linux syscall getpeername(2).
   581  func GetPeerName(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   582  	fd := args[0].Int()
   583  	addr := args[1].Pointer()
   584  	addrlen := args[2].Pointer()
   585  
   586  	// Get socket from the file descriptor.
   587  	file := t.GetFile(fd)
   588  	if file == nil {
   589  		return 0, nil, linuxerr.EBADF
   590  	}
   591  	defer file.DecRef(t)
   592  
   593  	// Extract the socket.
   594  	s, ok := file.FileOperations.(socket.Socket)
   595  	if !ok {
   596  		return 0, nil, syserror.ENOTSOCK
   597  	}
   598  
   599  	// Get the socket peer name and copy it to the caller.
   600  	v, vl, err := s.GetPeerName(t)
   601  	if err != nil {
   602  		return 0, nil, err.ToError()
   603  	}
   604  
   605  	return 0, nil, writeAddress(t, v, vl, addr, addrlen)
   606  }
   607  
   608  // RecvMsg implements the linux syscall recvmsg(2).
   609  func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   610  	fd := args[0].Int()
   611  	msgPtr := args[1].Pointer()
   612  	flags := args[2].Int()
   613  
   614  	if t.Arch().Width() != 8 {
   615  		// We only handle 64-bit for now.
   616  		return 0, nil, linuxerr.EINVAL
   617  	}
   618  
   619  	// Get socket from the file descriptor.
   620  	file := t.GetFile(fd)
   621  	if file == nil {
   622  		return 0, nil, linuxerr.EBADF
   623  	}
   624  	defer file.DecRef(t)
   625  
   626  	// Extract the socket.
   627  	s, ok := file.FileOperations.(socket.Socket)
   628  	if !ok {
   629  		return 0, nil, syserror.ENOTSOCK
   630  	}
   631  
   632  	// Reject flags that we don't handle yet.
   633  	if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
   634  		return 0, nil, linuxerr.EINVAL
   635  	}
   636  
   637  	if file.Flags().NonBlocking {
   638  		flags |= linux.MSG_DONTWAIT
   639  	}
   640  
   641  	var haveDeadline bool
   642  	var deadline ktime.Time
   643  	if dl := s.RecvTimeout(); dl > 0 {
   644  		deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
   645  		haveDeadline = true
   646  	} else if dl < 0 {
   647  		flags |= linux.MSG_DONTWAIT
   648  	}
   649  
   650  	n, err := recvSingleMsg(t, s, msgPtr, flags, haveDeadline, deadline)
   651  	return n, nil, err
   652  }
   653  
   654  // RecvMMsg implements the linux syscall recvmmsg(2).
   655  func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   656  	fd := args[0].Int()
   657  	msgPtr := args[1].Pointer()
   658  	vlen := args[2].Uint()
   659  	flags := args[3].Int()
   660  	toPtr := args[4].Pointer()
   661  
   662  	if t.Arch().Width() != 8 {
   663  		// We only handle 64-bit for now.
   664  		return 0, nil, linuxerr.EINVAL
   665  	}
   666  
   667  	if vlen > linux.UIO_MAXIOV {
   668  		vlen = linux.UIO_MAXIOV
   669  	}
   670  
   671  	// Reject flags that we don't handle yet.
   672  	if flags & ^(baseRecvFlags|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
   673  		return 0, nil, linuxerr.EINVAL
   674  	}
   675  
   676  	// Get socket from the file descriptor.
   677  	file := t.GetFile(fd)
   678  	if file == nil {
   679  		return 0, nil, linuxerr.EBADF
   680  	}
   681  	defer file.DecRef(t)
   682  
   683  	// Extract the socket.
   684  	s, ok := file.FileOperations.(socket.Socket)
   685  	if !ok {
   686  		return 0, nil, syserror.ENOTSOCK
   687  	}
   688  
   689  	if file.Flags().NonBlocking {
   690  		flags |= linux.MSG_DONTWAIT
   691  	}
   692  
   693  	var haveDeadline bool
   694  	var deadline ktime.Time
   695  	if toPtr != 0 {
   696  		ts, err := copyTimespecIn(t, toPtr)
   697  		if err != nil {
   698  			return 0, nil, err
   699  		}
   700  		if !ts.Valid() {
   701  			return 0, nil, linuxerr.EINVAL
   702  		}
   703  		deadline = t.Kernel().MonotonicClock().Now().Add(ts.ToDuration())
   704  		haveDeadline = true
   705  	}
   706  
   707  	if !haveDeadline {
   708  		if dl := s.RecvTimeout(); dl > 0 {
   709  			deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
   710  			haveDeadline = true
   711  		} else if dl < 0 {
   712  			flags |= linux.MSG_DONTWAIT
   713  		}
   714  	}
   715  
   716  	var count uint32
   717  	var err error
   718  	for i := uint64(0); i < uint64(vlen); i++ {
   719  		mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
   720  		if !ok {
   721  			return 0, nil, syserror.EFAULT
   722  		}
   723  		var n uintptr
   724  		if n, err = recvSingleMsg(t, s, mp, flags, haveDeadline, deadline); err != nil {
   725  			break
   726  		}
   727  
   728  		// Copy the received length to the caller.
   729  		lp, ok := mp.AddLength(messageHeader64Len)
   730  		if !ok {
   731  			return 0, nil, syserror.EFAULT
   732  		}
   733  		if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil {
   734  			break
   735  		}
   736  		count++
   737  	}
   738  
   739  	if count == 0 {
   740  		return 0, nil, err
   741  	}
   742  	return uintptr(count), nil, nil
   743  }
   744  
   745  func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr hostarch.Addr, flags int32, haveDeadline bool, deadline ktime.Time) (uintptr, error) {
   746  	// Capture the message header and io vectors.
   747  	var msg MessageHeader64
   748  	if _, err := msg.CopyIn(t, msgPtr); err != nil {
   749  		return 0, err
   750  	}
   751  
   752  	if msg.IovLen > linux.UIO_MAXIOV {
   753  		return 0, linuxerr.EMSGSIZE
   754  	}
   755  	dst, err := t.IovecsIOSequence(hostarch.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
   756  		AddressSpaceActive: true,
   757  	})
   758  	if err != nil {
   759  		return 0, err
   760  	}
   761  
   762  	// Fast path when no control message nor name buffers are provided.
   763  	if msg.ControlLen == 0 && msg.NameLen == 0 {
   764  		n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
   765  		if err != nil {
   766  			return 0, syserror.ConvertIntr(err.ToError(), syserror.ERESTARTSYS)
   767  		}
   768  		if !cms.Unix.Empty() {
   769  			mflags |= linux.MSG_CTRUNC
   770  			cms.Release(t)
   771  		}
   772  
   773  		if int(msg.Flags) != mflags {
   774  			// Copy out the flags to the caller.
   775  			if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil {
   776  				return 0, err
   777  			}
   778  		}
   779  
   780  		return uintptr(n), nil
   781  	}
   782  
   783  	if msg.ControlLen > maxControlLen {
   784  		return 0, linuxerr.ENOBUFS
   785  	}
   786  	n, mflags, sender, senderLen, cms, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, msg.NameLen != 0, msg.ControlLen)
   787  	if e != nil {
   788  		return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS)
   789  	}
   790  	defer cms.Release(t)
   791  
   792  	controlData := make([]byte, 0, msg.ControlLen)
   793  	controlData = control.PackControlMessages(t, cms, controlData)
   794  
   795  	if cr, ok := s.(transport.Credentialer); ok && cr.Passcred() {
   796  		creds, _ := cms.Unix.Credentials.(control.SCMCredentials)
   797  		controlData, mflags = control.PackCredentials(t, creds, controlData, mflags)
   798  	}
   799  
   800  	if cms.Unix.Rights != nil {
   801  		controlData, mflags = control.PackRights(t, cms.Unix.Rights.(control.SCMRights), flags&linux.MSG_CMSG_CLOEXEC != 0, controlData, mflags)
   802  	}
   803  
   804  	// Copy the address to the caller.
   805  	if msg.NameLen != 0 {
   806  		if err := writeAddress(t, sender, senderLen, hostarch.Addr(msg.Name), hostarch.Addr(msgPtr+nameLenOffset)); err != nil {
   807  			return 0, err
   808  		}
   809  	}
   810  
   811  	// Copy the control data to the caller.
   812  	if _, err := primitive.CopyUint64Out(t, msgPtr+controlLenOffset, uint64(len(controlData))); err != nil {
   813  		return 0, err
   814  	}
   815  	if len(controlData) > 0 {
   816  		if _, err := t.CopyOutBytes(hostarch.Addr(msg.Control), controlData); err != nil {
   817  			return 0, err
   818  		}
   819  	}
   820  
   821  	// Copy out the flags to the caller.
   822  	if _, err := primitive.CopyInt32Out(t, msgPtr+flagsOffset, int32(mflags)); err != nil {
   823  		return 0, err
   824  	}
   825  
   826  	return uintptr(n), nil
   827  }
   828  
   829  // recvFrom is the implementation of the recvfrom syscall. It is called by
   830  // recvfrom and recv syscall handlers.
   831  func recvFrom(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLenPtr hostarch.Addr) (uintptr, error) {
   832  	if int(bufLen) < 0 {
   833  		return 0, linuxerr.EINVAL
   834  	}
   835  
   836  	// Reject flags that we don't handle yet.
   837  	if flags & ^(baseRecvFlags|linux.MSG_PEEK|linux.MSG_CONFIRM) != 0 {
   838  		return 0, linuxerr.EINVAL
   839  	}
   840  
   841  	// Get socket from the file descriptor.
   842  	file := t.GetFile(fd)
   843  	if file == nil {
   844  		return 0, linuxerr.EBADF
   845  	}
   846  	defer file.DecRef(t)
   847  
   848  	// Extract the socket.
   849  	s, ok := file.FileOperations.(socket.Socket)
   850  	if !ok {
   851  		return 0, syserror.ENOTSOCK
   852  	}
   853  
   854  	if file.Flags().NonBlocking {
   855  		flags |= linux.MSG_DONTWAIT
   856  	}
   857  
   858  	dst, err := t.SingleIOSequence(bufPtr, int(bufLen), usermem.IOOpts{
   859  		AddressSpaceActive: true,
   860  	})
   861  	if err != nil {
   862  		return 0, err
   863  	}
   864  
   865  	var haveDeadline bool
   866  	var deadline ktime.Time
   867  	if dl := s.RecvTimeout(); dl > 0 {
   868  		deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
   869  		haveDeadline = true
   870  	} else if dl < 0 {
   871  		flags |= linux.MSG_DONTWAIT
   872  	}
   873  
   874  	n, _, sender, senderLen, cm, e := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, nameLenPtr != 0, 0)
   875  	cm.Release(t)
   876  	if e != nil {
   877  		return 0, syserror.ConvertIntr(e.ToError(), syserror.ERESTARTSYS)
   878  	}
   879  
   880  	// Copy the address to the caller.
   881  	if nameLenPtr != 0 {
   882  		if err := writeAddress(t, sender, senderLen, namePtr, nameLenPtr); err != nil {
   883  			return 0, err
   884  		}
   885  	}
   886  
   887  	return uintptr(n), nil
   888  }
   889  
   890  // RecvFrom implements the linux syscall recvfrom(2).
   891  func RecvFrom(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   892  	fd := args[0].Int()
   893  	bufPtr := args[1].Pointer()
   894  	bufLen := args[2].Uint64()
   895  	flags := args[3].Int()
   896  	namePtr := args[4].Pointer()
   897  	nameLenPtr := args[5].Pointer()
   898  
   899  	n, err := recvFrom(t, fd, bufPtr, bufLen, flags, namePtr, nameLenPtr)
   900  	return n, nil, err
   901  }
   902  
   903  // SendMsg implements the linux syscall sendmsg(2).
   904  func SendMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   905  	fd := args[0].Int()
   906  	msgPtr := args[1].Pointer()
   907  	flags := args[2].Int()
   908  
   909  	if t.Arch().Width() != 8 {
   910  		// We only handle 64-bit for now.
   911  		return 0, nil, linuxerr.EINVAL
   912  	}
   913  
   914  	// Get socket from the file descriptor.
   915  	file := t.GetFile(fd)
   916  	if file == nil {
   917  		return 0, nil, linuxerr.EBADF
   918  	}
   919  	defer file.DecRef(t)
   920  
   921  	// Extract the socket.
   922  	s, ok := file.FileOperations.(socket.Socket)
   923  	if !ok {
   924  		return 0, nil, syserror.ENOTSOCK
   925  	}
   926  
   927  	// Reject flags that we don't handle yet.
   928  	if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
   929  		return 0, nil, linuxerr.EINVAL
   930  	}
   931  
   932  	if file.Flags().NonBlocking {
   933  		flags |= linux.MSG_DONTWAIT
   934  	}
   935  
   936  	n, err := sendSingleMsg(t, s, file, msgPtr, flags)
   937  	return n, nil, err
   938  }
   939  
   940  // SendMMsg implements the linux syscall sendmmsg(2).
   941  func SendMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   942  	fd := args[0].Int()
   943  	msgPtr := args[1].Pointer()
   944  	vlen := args[2].Uint()
   945  	flags := args[3].Int()
   946  
   947  	if t.Arch().Width() != 8 {
   948  		// We only handle 64-bit for now.
   949  		return 0, nil, linuxerr.EINVAL
   950  	}
   951  
   952  	if vlen > linux.UIO_MAXIOV {
   953  		vlen = linux.UIO_MAXIOV
   954  	}
   955  
   956  	// Get socket from the file descriptor.
   957  	file := t.GetFile(fd)
   958  	if file == nil {
   959  		return 0, nil, linuxerr.EBADF
   960  	}
   961  	defer file.DecRef(t)
   962  
   963  	// Extract the socket.
   964  	s, ok := file.FileOperations.(socket.Socket)
   965  	if !ok {
   966  		return 0, nil, syserror.ENOTSOCK
   967  	}
   968  
   969  	// Reject flags that we don't handle yet.
   970  	if flags & ^(linux.MSG_DONTWAIT|linux.MSG_EOR|linux.MSG_MORE|linux.MSG_NOSIGNAL) != 0 {
   971  		return 0, nil, linuxerr.EINVAL
   972  	}
   973  
   974  	if file.Flags().NonBlocking {
   975  		flags |= linux.MSG_DONTWAIT
   976  	}
   977  
   978  	var count uint32
   979  	var err error
   980  	for i := uint64(0); i < uint64(vlen); i++ {
   981  		mp, ok := msgPtr.AddLength(i * multipleMessageHeader64Len)
   982  		if !ok {
   983  			return 0, nil, syserror.EFAULT
   984  		}
   985  		var n uintptr
   986  		if n, err = sendSingleMsg(t, s, file, mp, flags); err != nil {
   987  			break
   988  		}
   989  
   990  		// Copy the received length to the caller.
   991  		lp, ok := mp.AddLength(messageHeader64Len)
   992  		if !ok {
   993  			return 0, nil, syserror.EFAULT
   994  		}
   995  		if _, err = primitive.CopyUint32Out(t, lp, uint32(n)); err != nil {
   996  			break
   997  		}
   998  		count++
   999  	}
  1000  
  1001  	if count == 0 {
  1002  		return 0, nil, err
  1003  	}
  1004  	return uintptr(count), nil, nil
  1005  }
  1006  
  1007  func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr hostarch.Addr, flags int32) (uintptr, error) {
  1008  	// Capture the message header.
  1009  	var msg MessageHeader64
  1010  	if _, err := msg.CopyIn(t, msgPtr); err != nil {
  1011  		return 0, err
  1012  	}
  1013  
  1014  	var controlData []byte
  1015  	if msg.ControlLen > 0 {
  1016  		// Put an upper bound to prevent large allocations.
  1017  		if msg.ControlLen > maxControlLen {
  1018  			return 0, linuxerr.ENOBUFS
  1019  		}
  1020  		controlData = make([]byte, msg.ControlLen)
  1021  		if _, err := t.CopyInBytes(hostarch.Addr(msg.Control), controlData); err != nil {
  1022  			return 0, err
  1023  		}
  1024  	}
  1025  
  1026  	// Read the destination address if one is specified.
  1027  	var to []byte
  1028  	if msg.NameLen != 0 {
  1029  		var err error
  1030  		to, err = CaptureAddress(t, hostarch.Addr(msg.Name), msg.NameLen)
  1031  		if err != nil {
  1032  			return 0, err
  1033  		}
  1034  	}
  1035  
  1036  	// Read data then call the sendmsg implementation.
  1037  	if msg.IovLen > linux.UIO_MAXIOV {
  1038  		return 0, linuxerr.EMSGSIZE
  1039  	}
  1040  	src, err := t.IovecsIOSequence(hostarch.Addr(msg.Iov), int(msg.IovLen), usermem.IOOpts{
  1041  		AddressSpaceActive: true,
  1042  	})
  1043  	if err != nil {
  1044  		return 0, err
  1045  	}
  1046  
  1047  	controlMessages, err := control.Parse(t, s, controlData, t.Arch().Width())
  1048  	if err != nil {
  1049  		return 0, err
  1050  	}
  1051  
  1052  	var haveDeadline bool
  1053  	var deadline ktime.Time
  1054  	if dl := s.SendTimeout(); dl > 0 {
  1055  		deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
  1056  		haveDeadline = true
  1057  	} else if dl < 0 {
  1058  		flags |= linux.MSG_DONTWAIT
  1059  	}
  1060  
  1061  	// Call the syscall implementation.
  1062  	n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, controlMessages)
  1063  	err = handleIOError(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendmsg", file)
  1064  	// Control messages should be released on error as well as for zero-length
  1065  	// messages, which are discarded by the receiver.
  1066  	if n == 0 || err != nil {
  1067  		controlMessages.Release(t)
  1068  	}
  1069  	return uintptr(n), err
  1070  }
  1071  
  1072  // sendTo is the implementation of the sendto syscall. It is called by sendto
  1073  // and send syscall handlers.
  1074  func sendTo(t *kernel.Task, fd int32, bufPtr hostarch.Addr, bufLen uint64, flags int32, namePtr hostarch.Addr, nameLen uint32) (uintptr, error) {
  1075  	bl := int(bufLen)
  1076  	if bl < 0 {
  1077  		return 0, linuxerr.EINVAL
  1078  	}
  1079  
  1080  	// Get socket from the file descriptor.
  1081  	file := t.GetFile(fd)
  1082  	if file == nil {
  1083  		return 0, linuxerr.EBADF
  1084  	}
  1085  	defer file.DecRef(t)
  1086  
  1087  	// Extract the socket.
  1088  	s, ok := file.FileOperations.(socket.Socket)
  1089  	if !ok {
  1090  		return 0, syserror.ENOTSOCK
  1091  	}
  1092  
  1093  	if file.Flags().NonBlocking {
  1094  		flags |= linux.MSG_DONTWAIT
  1095  	}
  1096  
  1097  	// Read the destination address if one is specified.
  1098  	var to []byte
  1099  	var err error
  1100  	if namePtr != 0 {
  1101  		to, err = CaptureAddress(t, namePtr, nameLen)
  1102  		if err != nil {
  1103  			return 0, err
  1104  		}
  1105  	}
  1106  
  1107  	src, err := t.SingleIOSequence(bufPtr, bl, usermem.IOOpts{
  1108  		AddressSpaceActive: true,
  1109  	})
  1110  	if err != nil {
  1111  		return 0, err
  1112  	}
  1113  
  1114  	var haveDeadline bool
  1115  	var deadline ktime.Time
  1116  	if dl := s.SendTimeout(); dl > 0 {
  1117  		deadline = t.Kernel().MonotonicClock().Now().Add(time.Duration(dl) * time.Nanosecond)
  1118  		haveDeadline = true
  1119  	} else if dl < 0 {
  1120  		flags |= linux.MSG_DONTWAIT
  1121  	}
  1122  
  1123  	// Call the syscall implementation.
  1124  	n, e := s.SendMsg(t, src, to, int(flags), haveDeadline, deadline, socket.ControlMessages{Unix: control.New(t, s, nil)})
  1125  	return uintptr(n), handleIOError(t, n != 0, e.ToError(), syserror.ERESTARTSYS, "sendto", file)
  1126  }
  1127  
  1128  // SendTo implements the linux syscall sendto(2).
  1129  func SendTo(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
  1130  	fd := args[0].Int()
  1131  	bufPtr := args[1].Pointer()
  1132  	bufLen := args[2].Uint64()
  1133  	flags := args[3].Int()
  1134  	namePtr := args[4].Pointer()
  1135  	nameLen := args[5].Uint()
  1136  
  1137  	n, err := sendTo(t, fd, bufPtr, bufLen, flags, namePtr, nameLen)
  1138  	return n, nil, err
  1139  }
  1140  
  1141  // LINT.ThenChange(./vfs2/socket.go)