github.com/MerlinKodo/gvisor@v0.0.0-20231110090155-957f62ecf90e/pkg/sentry/socket/netlink/socket.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package netlink provides core functionality for netlink sockets.
    16  package netlink
    17  
    18  import (
    19  	"io"
    20  	"math"
    21  	"time"
    22  
    23  	"github.com/MerlinKodo/gvisor/pkg/abi/linux"
    24  	"github.com/MerlinKodo/gvisor/pkg/abi/linux/errno"
    25  	"github.com/MerlinKodo/gvisor/pkg/context"
    26  	"github.com/MerlinKodo/gvisor/pkg/errors/linuxerr"
    27  	"github.com/MerlinKodo/gvisor/pkg/hostarch"
    28  	"github.com/MerlinKodo/gvisor/pkg/marshal"
    29  	"github.com/MerlinKodo/gvisor/pkg/marshal/primitive"
    30  	"github.com/MerlinKodo/gvisor/pkg/sentry/arch"
    31  	"github.com/MerlinKodo/gvisor/pkg/sentry/kernel"
    32  	"github.com/MerlinKodo/gvisor/pkg/sentry/kernel/auth"
    33  	ktime "github.com/MerlinKodo/gvisor/pkg/sentry/kernel/time"
    34  	"github.com/MerlinKodo/gvisor/pkg/sentry/socket"
    35  	"github.com/MerlinKodo/gvisor/pkg/sentry/socket/netlink/port"
    36  	"github.com/MerlinKodo/gvisor/pkg/sentry/socket/unix"
    37  	"github.com/MerlinKodo/gvisor/pkg/sentry/socket/unix/transport"
    38  	"github.com/MerlinKodo/gvisor/pkg/sentry/vfs"
    39  	"github.com/MerlinKodo/gvisor/pkg/sync"
    40  	"github.com/MerlinKodo/gvisor/pkg/syserr"
    41  	"github.com/MerlinKodo/gvisor/pkg/usermem"
    42  	"github.com/MerlinKodo/gvisor/pkg/waiter"
    43  )
    44  
    45  const sizeOfInt32 int = 4
    46  
    47  const (
    48  	// minBufferSize is the smallest size of a send buffer.
    49  	minSendBufferSize = 4 << 10 // 4096 bytes.
    50  
    51  	// defaultSendBufferSize is the default size for the send buffer.
    52  	defaultSendBufferSize = 16 * 1024
    53  
    54  	// maxBufferSize is the largest size a send buffer can grow to.
    55  	maxSendBufferSize = 4 << 20 // 4MB
    56  )
    57  
    58  var errNoFilter = syserr.New("no filter attached", errno.ENOENT)
    59  
// Socket is the base socket type for netlink sockets.
//
// This implementation only supports userspace sending and receiving messages
// to/from the kernel: messages written by userspace are handed synchronously
// to protocol, and the replies are queued on ep, from which RecvMsg and Read
// dequeue them.
//
// Socket implements socket.Socket and transport.Credentialer.
//
// +stateify savable
type Socket struct {
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl
	vfs.DentryMetadataFileDescriptionImpl
	vfs.LockFD
	socket.SendReceiveTimeout

	// ports provides netlink port allocation. New obtains it from the
	// Kernel (Kernel.NetlinkPorts).
	ports *port.Manager

	// protocol is the netlink protocol implementation. It processes each
	// incoming request message and supplies the protocol number used for
	// port allocation and reported by Type.
	protocol Protocol

	// skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for
	// netlink sockets.
	skType linux.SockType

	// ep is a datagram unix endpoint used to buffer messages sent from the
	// kernel to userspace. RecvMsg and Read read messages from this endpoint.
	ep transport.Endpoint

	// connection is the kernel's connection to ep, used to write messages
	// sent to userspace.
	connection transport.ConnectedEndpoint

	// mu protects the fields below.
	//
	// NOTE(review): Release reads bound and portID without holding mu;
	// presumably safe because the file has no other users by then — confirm.
	mu sync.Mutex `state:"nosave"`

	// bound indicates that portID is valid.
	bound bool

	// portID is the port ID allocated for this socket. It is only meaningful
	// while bound is true; until then it is zero.
	portID int32

	// sendBufferSize is the send buffer "size". We don't actually have a
	// fixed buffer but only consume this many bytes: sendMsg rejects any
	// single send larger than this.
	sendBufferSize uint32

	// filter indicates that this socket has a BPF filter "installed".
	//
	// TODO(gvisor.dev/issue/1119): We don't actually support filtering,
	// this is just bookkeeping for tracking add/remove.
	filter bool
}
   112  
   113  var _ socket.Socket = (*Socket)(nil)
   114  var _ transport.Credentialer = (*Socket)(nil)
   115  
   116  // New creates a new Socket.
   117  func New(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
   118  	// Datagram endpoint used to buffer kernel -> user messages.
   119  	ep := transport.NewConnectionless(t)
   120  
   121  	// Bind the endpoint for good measure so we can connect to it. The
   122  	// bound address will never be exposed.
   123  	if err := ep.Bind(transport.Address{Addr: "dummy"}); err != nil {
   124  		ep.Close(t)
   125  		return nil, err
   126  	}
   127  
   128  	// Create a connection from which the kernel can write messages.
   129  	connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
   130  	if err != nil {
   131  		ep.Close(t)
   132  		return nil, err
   133  	}
   134  
   135  	fd := &Socket{
   136  		ports:          t.Kernel().NetlinkPorts(),
   137  		protocol:       protocol,
   138  		skType:         skType,
   139  		ep:             ep,
   140  		connection:     connection,
   141  		sendBufferSize: defaultSendBufferSize,
   142  	}
   143  	fd.LockFD.Init(&vfs.FileLocks{})
   144  	return fd, nil
   145  }
   146  
   147  // Release implements vfs.FileDescriptionImpl.Release.
   148  func (s *Socket) Release(ctx context.Context) {
   149  	t := kernel.TaskFromContext(ctx)
   150  	t.Kernel().DeleteSocket(&s.vfsfd)
   151  	s.connection.Release(ctx)
   152  	s.ep.Close(ctx)
   153  
   154  	if s.bound {
   155  		s.ports.Release(s.protocol.Protocol(), s.portID)
   156  	}
   157  }
   158  
   159  // Epollable implements FileDescriptionImpl.Epollable.
   160  func (s *Socket) Epollable() bool {
   161  	return true
   162  }
   163  
   164  // Ioctl implements vfs.FileDescriptionImpl.
   165  func (*Socket) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
   166  	// TODO(b/68878065): no ioctls supported.
   167  	return 0, linuxerr.ENOTTY
   168  }
   169  
   170  // PRead implements vfs.FileDescriptionImpl.
   171  func (s *Socket) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
   172  	return 0, linuxerr.ESPIPE
   173  }
   174  
   175  // Read implements vfs.FileDescriptionImpl.
   176  func (s *Socket) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
   177  	// All flags other than RWF_NOWAIT should be ignored.
   178  	// TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
   179  	if opts.Flags != 0 {
   180  		return 0, linuxerr.EOPNOTSUPP
   181  	}
   182  
   183  	if dst.NumBytes() == 0 {
   184  		return 0, nil
   185  	}
   186  	r := unix.EndpointReader{
   187  		Endpoint: s.ep,
   188  	}
   189  	n, err := dst.CopyOutFrom(ctx, &r)
   190  	if r.Notify != nil {
   191  		r.Notify()
   192  	}
   193  	return n, err
   194  }
   195  
   196  // PWrite implements vfs.FileDescriptionImpl.
   197  func (s *Socket) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
   198  	return 0, linuxerr.ESPIPE
   199  }
   200  
   201  // Write implements vfs.FileDescriptionImpl.
   202  func (s *Socket) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
   203  	// All flags other than RWF_NOWAIT should be ignored.
   204  	// TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
   205  	if opts.Flags != 0 {
   206  		return 0, linuxerr.EOPNOTSUPP
   207  	}
   208  
   209  	n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
   210  	return int64(n), err.ToError()
   211  }
   212  
   213  // Readiness implements waiter.Waitable.Readiness.
   214  func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
   215  	// ep holds messages to be read and thus handles EventIn readiness.
   216  	ready := s.ep.Readiness(mask)
   217  
   218  	if mask&waiter.WritableEvents != 0 {
   219  		// sendMsg handles messages synchronously and is thus always
   220  		// ready for writing.
   221  		ready |= waiter.WritableEvents
   222  	}
   223  
   224  	return ready
   225  }
   226  
   227  // EventRegister implements waiter.Waitable.EventRegister.
   228  func (s *Socket) EventRegister(e *waiter.Entry) error {
   229  	return s.ep.EventRegister(e)
   230  	// Writable readiness never changes, so no registration is needed.
   231  }
   232  
   233  // EventUnregister implements waiter.Waitable.EventUnregister.
   234  func (s *Socket) EventUnregister(e *waiter.Entry) {
   235  	s.ep.EventUnregister(e)
   236  }
   237  
   238  // Passcred implements transport.Credentialer.Passcred.
   239  func (s *Socket) Passcred() bool {
   240  	return s.ep.SocketOptions().GetPassCred()
   241  }
   242  
   243  // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
   244  func (s *Socket) ConnectedPasscred() bool {
   245  	// This socket is connected to the kernel, which doesn't need creds.
   246  	//
   247  	// This is arbitrary, as ConnectedPasscred on this type has no callers.
   248  	return false
   249  }
   250  
   251  // ExtractSockAddr extracts the SockAddrNetlink from b.
   252  func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
   253  	if len(b) < linux.SockAddrNetlinkSize {
   254  		return nil, syserr.ErrBadAddress
   255  	}
   256  
   257  	var sa linux.SockAddrNetlink
   258  	sa.UnmarshalUnsafe(b)
   259  
   260  	if sa.Family != linux.AF_NETLINK {
   261  		return nil, syserr.ErrInvalidArgument
   262  	}
   263  
   264  	return &sa, nil
   265  }
   266  
   267  // bindPort binds this socket to a port, preferring 'port' if it is available.
   268  //
   269  // port of 0 defaults to the ThreadGroup ID.
   270  //
   271  // Preconditions: mu is held.
   272  func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error {
   273  	if s.bound {
   274  		// Re-binding is only allowed if the port doesn't change.
   275  		if port != s.portID {
   276  			return syserr.ErrInvalidArgument
   277  		}
   278  
   279  		return nil
   280  	}
   281  
   282  	if port == 0 {
   283  		port = int32(t.ThreadGroup().ID())
   284  	}
   285  	port, ok := s.ports.Allocate(s.protocol.Protocol(), port)
   286  	if !ok {
   287  		return syserr.ErrBusy
   288  	}
   289  
   290  	s.portID = port
   291  	s.bound = true
   292  	return nil
   293  }
   294  
   295  // Bind implements socket.Socket.Bind.
   296  func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
   297  	a, err := ExtractSockAddr(sockaddr)
   298  	if err != nil {
   299  		return err
   300  	}
   301  
   302  	// No support for multicast groups yet.
   303  	if a.Groups != 0 {
   304  		return syserr.ErrPermissionDenied
   305  	}
   306  
   307  	s.mu.Lock()
   308  	defer s.mu.Unlock()
   309  
   310  	return s.bindPort(t, int32(a.PortID))
   311  }
   312  
   313  // Connect implements socket.Socket.Connect.
   314  func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
   315  	a, err := ExtractSockAddr(sockaddr)
   316  	if err != nil {
   317  		return err
   318  	}
   319  
   320  	// No support for multicast groups yet.
   321  	if a.Groups != 0 {
   322  		return syserr.ErrPermissionDenied
   323  	}
   324  
   325  	s.mu.Lock()
   326  	defer s.mu.Unlock()
   327  
   328  	if a.PortID == 0 {
   329  		// Netlink sockets default to connected to the kernel, but
   330  		// connecting anyways automatically binds if not already bound.
   331  		if !s.bound {
   332  			// Pass port 0 to get an auto-selected port ID.
   333  			return s.bindPort(t, 0)
   334  		}
   335  		return nil
   336  	}
   337  
   338  	// We don't support non-kernel destination ports. Linux returns EPERM
   339  	// if applications attempt to do this without NL_CFG_F_NONROOT_SEND, so
   340  	// we emulate that.
   341  	return syserr.ErrPermissionDenied
   342  }
   343  
   344  // Accept implements socket.Socket.Accept.
   345  func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
   346  	// Netlink sockets never support accept.
   347  	return 0, nil, 0, syserr.ErrNotSupported
   348  }
   349  
   350  // Listen implements socket.Socket.Listen.
   351  func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error {
   352  	// Netlink sockets never support listen.
   353  	return syserr.ErrNotSupported
   354  }
   355  
   356  // Shutdown implements socket.Socket.Shutdown.
   357  func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error {
   358  	// Netlink sockets never support shutdown.
   359  	return syserr.ErrNotSupported
   360  }
   361  
   362  // GetSockOpt implements socket.Socket.GetSockOpt.
   363  func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
   364  	switch level {
   365  	case linux.SOL_SOCKET:
   366  		switch name {
   367  		case linux.SO_SNDBUF:
   368  			if outLen < sizeOfInt32 {
   369  				return nil, syserr.ErrInvalidArgument
   370  			}
   371  			s.mu.Lock()
   372  			defer s.mu.Unlock()
   373  			return primitive.AllocateInt32(int32(s.sendBufferSize)), nil
   374  
   375  		case linux.SO_RCVBUF:
   376  			if outLen < sizeOfInt32 {
   377  				return nil, syserr.ErrInvalidArgument
   378  			}
   379  			// We don't have limit on receiving size.
   380  			return primitive.AllocateInt32(math.MaxInt32), nil
   381  
   382  		case linux.SO_PASSCRED:
   383  			if outLen < sizeOfInt32 {
   384  				return nil, syserr.ErrInvalidArgument
   385  			}
   386  			var passcred primitive.Int32
   387  			if s.Passcred() {
   388  				passcred = 1
   389  			}
   390  			return &passcred, nil
   391  
   392  		case linux.SO_SNDTIMEO:
   393  			if outLen < linux.SizeOfTimeval {
   394  				return nil, syserr.ErrInvalidArgument
   395  			}
   396  			sendTimeout := linux.NsecToTimeval(s.SendTimeout())
   397  			return &sendTimeout, nil
   398  
   399  		case linux.SO_RCVTIMEO:
   400  			if outLen < linux.SizeOfTimeval {
   401  				return nil, syserr.ErrInvalidArgument
   402  			}
   403  			recvTimeout := linux.NsecToTimeval(s.RecvTimeout())
   404  			return &recvTimeout, nil
   405  		}
   406  	case linux.SOL_NETLINK:
   407  		switch name {
   408  		case linux.NETLINK_BROADCAST_ERROR,
   409  			linux.NETLINK_CAP_ACK,
   410  			linux.NETLINK_DUMP_STRICT_CHK,
   411  			linux.NETLINK_EXT_ACK,
   412  			linux.NETLINK_LIST_MEMBERSHIPS,
   413  			linux.NETLINK_NO_ENOBUFS,
   414  			linux.NETLINK_PKTINFO:
   415  			// Not supported.
   416  		}
   417  	}
   418  	// TODO(b/68878065): other sockopts are not supported.
   419  	return nil, syserr.ErrProtocolNotAvailable
   420  }
   421  
   422  // SetSockOpt implements socket.Socket.SetSockOpt.
   423  func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
   424  	switch level {
   425  	case linux.SOL_SOCKET:
   426  		switch name {
   427  		case linux.SO_SNDBUF:
   428  			if len(opt) < sizeOfInt32 {
   429  				return syserr.ErrInvalidArgument
   430  			}
   431  			size := hostarch.ByteOrder.Uint32(opt)
   432  			if size < minSendBufferSize {
   433  				size = minSendBufferSize
   434  			} else if size > maxSendBufferSize {
   435  				size = maxSendBufferSize
   436  			}
   437  			s.mu.Lock()
   438  			s.sendBufferSize = size
   439  			s.mu.Unlock()
   440  			return nil
   441  
   442  		case linux.SO_RCVBUF:
   443  			if len(opt) < sizeOfInt32 {
   444  				return syserr.ErrInvalidArgument
   445  			}
   446  			// We don't have limit on receiving size. So just accept anything as
   447  			// valid for compatibility.
   448  			return nil
   449  
   450  		case linux.SO_PASSCRED:
   451  			if len(opt) < sizeOfInt32 {
   452  				return syserr.ErrInvalidArgument
   453  			}
   454  			passcred := hostarch.ByteOrder.Uint32(opt)
   455  
   456  			s.ep.SocketOptions().SetPassCred(passcred != 0)
   457  			return nil
   458  
   459  		case linux.SO_ATTACH_FILTER:
   460  			// TODO(gvisor.dev/issue/1119): We don't actually
   461  			// support filtering. If this socket can't ever send
   462  			// messages, then there is nothing to filter and we can
   463  			// advertise support. Otherwise, be conservative and
   464  			// return an error.
   465  			if s.protocol.CanSend() {
   466  				return syserr.ErrProtocolNotAvailable
   467  			}
   468  
   469  			s.mu.Lock()
   470  			s.filter = true
   471  			s.mu.Unlock()
   472  			return nil
   473  
   474  		case linux.SO_DETACH_FILTER:
   475  			// TODO(gvisor.dev/issue/1119): See above.
   476  			if s.protocol.CanSend() {
   477  				return syserr.ErrProtocolNotAvailable
   478  			}
   479  
   480  			s.mu.Lock()
   481  			filter := s.filter
   482  			s.filter = false
   483  			s.mu.Unlock()
   484  
   485  			if !filter {
   486  				return errNoFilter
   487  			}
   488  
   489  			return nil
   490  
   491  		case linux.SO_SNDTIMEO:
   492  			if len(opt) < linux.SizeOfTimeval {
   493  				return syserr.ErrInvalidArgument
   494  			}
   495  
   496  			var v linux.Timeval
   497  			v.UnmarshalBytes(opt)
   498  			if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
   499  				return syserr.ErrDomain
   500  			}
   501  			s.SetSendTimeout(v.ToNsecCapped())
   502  			return nil
   503  
   504  		case linux.SO_RCVTIMEO:
   505  			if len(opt) < linux.SizeOfTimeval {
   506  				return syserr.ErrInvalidArgument
   507  			}
   508  
   509  			var v linux.Timeval
   510  			v.UnmarshalBytes(opt)
   511  			if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
   512  				return syserr.ErrDomain
   513  			}
   514  			s.SetRecvTimeout(v.ToNsecCapped())
   515  			return nil
   516  		}
   517  	case linux.SOL_NETLINK:
   518  		switch name {
   519  		case linux.NETLINK_ADD_MEMBERSHIP,
   520  			linux.NETLINK_BROADCAST_ERROR,
   521  			linux.NETLINK_CAP_ACK,
   522  			linux.NETLINK_DROP_MEMBERSHIP,
   523  			linux.NETLINK_DUMP_STRICT_CHK,
   524  			linux.NETLINK_EXT_ACK,
   525  			linux.NETLINK_LISTEN_ALL_NSID,
   526  			linux.NETLINK_NO_ENOBUFS,
   527  			linux.NETLINK_PKTINFO:
   528  			// Not supported.
   529  		}
   530  	}
   531  
   532  	// TODO(b/68878065): other sockopts are not supported.
   533  	return syserr.ErrProtocolNotAvailable
   534  }
   535  
   536  // GetSockName implements socket.Socket.GetSockName.
   537  func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
   538  	s.mu.Lock()
   539  	defer s.mu.Unlock()
   540  
   541  	sa := &linux.SockAddrNetlink{
   542  		Family: linux.AF_NETLINK,
   543  		PortID: uint32(s.portID),
   544  	}
   545  	return sa, uint32(sa.SizeBytes()), nil
   546  }
   547  
   548  // GetPeerName implements socket.Socket.GetPeerName.
   549  func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
   550  	sa := &linux.SockAddrNetlink{
   551  		Family: linux.AF_NETLINK,
   552  		// TODO(b/68878065): Support non-kernel peers. For now the peer
   553  		// must be the kernel.
   554  		PortID: 0,
   555  	}
   556  	return sa, uint32(sa.SizeBytes()), nil
   557  }
   558  
   559  // RecvMsg implements socket.Socket.RecvMsg.
   560  func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
   561  	from := &linux.SockAddrNetlink{
   562  		Family: linux.AF_NETLINK,
   563  		PortID: 0,
   564  	}
   565  	fromLen := uint32(from.SizeBytes())
   566  
   567  	trunc := flags&linux.MSG_TRUNC != 0
   568  
   569  	r := unix.EndpointReader{
   570  		Ctx:      t,
   571  		Endpoint: s.ep,
   572  		Peek:     flags&linux.MSG_PEEK != 0,
   573  	}
   574  
   575  	doRead := func() (int64, error) {
   576  		return dst.CopyOutFrom(t, &r)
   577  	}
   578  
   579  	// If MSG_TRUNC is set with a zero byte destination then we still need
   580  	// to read the message and discard it, or in the case where MSG_PEEK is
   581  	// set, leave it be. In both cases the full message length must be
   582  	// returned.
   583  	if trunc && dst.Addrs.NumBytes() == 0 {
   584  		doRead = func() (int64, error) {
   585  			err := r.Truncate()
   586  			// Always return zero for bytes read since the destination size is
   587  			// zero.
   588  			return 0, err
   589  		}
   590  	}
   591  
   592  	if n, err := doRead(); err != linuxerr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
   593  		var mflags int
   594  		if n < int64(r.MsgSize) {
   595  			mflags |= linux.MSG_TRUNC
   596  		}
   597  		if trunc {
   598  			n = int64(r.MsgSize)
   599  		}
   600  		return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
   601  	}
   602  
   603  	// We'll have to block. Register for notification and keep trying to
   604  	// receive all the data.
   605  	e, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
   606  	if err := s.EventRegister(&e); err != nil {
   607  		return 0, 0, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
   608  	}
   609  	defer s.EventUnregister(&e)
   610  
   611  	for {
   612  		if n, err := doRead(); err != linuxerr.ErrWouldBlock {
   613  			var mflags int
   614  			if n < int64(r.MsgSize) {
   615  				mflags |= linux.MSG_TRUNC
   616  			}
   617  			if trunc {
   618  				n = int64(r.MsgSize)
   619  			}
   620  			return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
   621  		}
   622  
   623  		if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
   624  			if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
   625  				return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
   626  			}
   627  			return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
   628  		}
   629  	}
   630  }
   631  
   632  // kernelSCM implements control.SCMCredentials with credentials that represent
   633  // the kernel itself rather than a Task.
   634  //
   635  // +stateify savable
   636  type kernelSCM struct{}
   637  
   638  // Equals implements transport.CredentialsControlMessage.Equals.
   639  func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool {
   640  	_, ok := oc.(kernelSCM)
   641  	return ok
   642  }
   643  
   644  // Credentials implements control.SCMCredentials.Credentials.
   645  func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
   646  	return 0, auth.RootUID, auth.RootGID
   647  }
   648  
   649  // kernelCreds is the concrete version of kernelSCM used in all creds.
   650  var kernelCreds = &kernelSCM{}
   651  
   652  // sendResponse sends the response messages in ms back to userspace.
   653  func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
   654  	// Linux combines multiple netlink messages into a single datagram.
   655  	bufs := make([][]byte, 0, len(ms.Messages))
   656  	for _, m := range ms.Messages {
   657  		bufs = append(bufs, m.Finalize())
   658  	}
   659  
   660  	// All messages are from the kernel.
   661  	cms := transport.ControlMessages{
   662  		Credentials: kernelCreds,
   663  	}
   664  
   665  	if len(bufs) > 0 {
   666  		// RecvMsg never receives the address, so we don't need to send
   667  		// one.
   668  		_, notify, err := s.connection.Send(ctx, bufs, cms, transport.Address{})
   669  		// If the buffer is full, we simply drop messages, just like
   670  		// Linux.
   671  		if err != nil && err != syserr.ErrWouldBlock {
   672  			return err
   673  		}
   674  		if notify {
   675  			s.connection.SendNotify()
   676  		}
   677  	}
   678  
   679  	// N.B. multi-part messages should still send NLMSG_DONE even if
   680  	// MessageSet contains no messages.
   681  	//
   682  	// N.B. NLMSG_DONE is always sent in a different datagram. See
   683  	// net/netlink/af_netlink.c:netlink_dump.
   684  	if ms.Multi {
   685  		m := NewMessage(linux.NetlinkMessageHeader{
   686  			Type:   linux.NLMSG_DONE,
   687  			Flags:  linux.NLM_F_MULTI,
   688  			Seq:    ms.Seq,
   689  			PortID: uint32(ms.PortID),
   690  		})
   691  
   692  		// Add the dump_done_errno payload.
   693  		m.Put(primitive.AllocateInt64(0))
   694  
   695  		_, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, transport.Address{})
   696  		if err != nil && err != syserr.ErrWouldBlock {
   697  			return err
   698  		}
   699  		if notify {
   700  			s.connection.SendNotify()
   701  		}
   702  	}
   703  
   704  	return nil
   705  }
   706  
   707  func dumpErrorMessage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) {
   708  	m := ms.AddMessage(linux.NetlinkMessageHeader{
   709  		Type: linux.NLMSG_ERROR,
   710  	})
   711  	m.Put(&linux.NetlinkErrorMessage{
   712  		Error:  int32(-err.ToLinux()),
   713  		Header: hdr,
   714  	})
   715  }
   716  
   717  func dumpAckMessage(hdr linux.NetlinkMessageHeader, ms *MessageSet) {
   718  	m := ms.AddMessage(linux.NetlinkMessageHeader{
   719  		Type: linux.NLMSG_ERROR,
   720  	})
   721  	m.Put(&linux.NetlinkErrorMessage{
   722  		Error:  0,
   723  		Header: hdr,
   724  	})
   725  }
   726  
   727  // processMessages handles each message in buf, passing it to the protocol
   728  // handler for final handling.
   729  func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error {
   730  	for len(buf) > 0 {
   731  		msg, rest, ok := ParseMessage(buf)
   732  		if !ok {
   733  			// Linux ignores messages that are too short. See
   734  			// net/netlink/af_netlink.c:netlink_rcv_skb.
   735  			break
   736  		}
   737  		buf = rest
   738  		hdr := msg.Header()
   739  
   740  		// Ignore control messages.
   741  		if hdr.Type < linux.NLMSG_MIN_TYPE {
   742  			continue
   743  		}
   744  
   745  		ms := NewMessageSet(s.portID, hdr.Seq)
   746  		if err := s.protocol.ProcessMessage(ctx, msg, ms); err != nil {
   747  			dumpErrorMessage(hdr, ms, err)
   748  		} else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
   749  			dumpAckMessage(hdr, ms)
   750  		}
   751  
   752  		if err := s.sendResponse(ctx, ms); err != nil {
   753  			return err
   754  		}
   755  	}
   756  
   757  	return nil
   758  }
   759  
   760  // sendMsg is the core of message send, used for SendMsg and Write.
   761  func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
   762  	dstPort := int32(0)
   763  
   764  	if len(to) != 0 {
   765  		a, err := ExtractSockAddr(to)
   766  		if err != nil {
   767  			return 0, err
   768  		}
   769  
   770  		// No support for multicast groups yet.
   771  		if a.Groups != 0 {
   772  			return 0, syserr.ErrPermissionDenied
   773  		}
   774  
   775  		dstPort = int32(a.PortID)
   776  	}
   777  
   778  	if dstPort != 0 {
   779  		// Non-kernel destinations not supported yet. Treat as if
   780  		// NL_CFG_F_NONROOT_SEND is not set.
   781  		return 0, syserr.ErrPermissionDenied
   782  	}
   783  
   784  	s.mu.Lock()
   785  	defer s.mu.Unlock()
   786  
   787  	// For simplicity, and consistency with Linux, we copy in the entire
   788  	// message up front.
   789  	if src.NumBytes() > int64(s.sendBufferSize) {
   790  		return 0, syserr.ErrMessageTooLong
   791  	}
   792  
   793  	buf := make([]byte, src.NumBytes())
   794  	n, err := src.CopyIn(ctx, buf)
   795  	// io.EOF can be only returned if src is a file, this means that
   796  	// sendMsg is called from splice and the error has to be ignored in
   797  	// this case.
   798  	if err == io.EOF {
   799  		err = nil
   800  	}
   801  	if err != nil {
   802  		// Don't partially consume messages.
   803  		return 0, syserr.FromError(err)
   804  	}
   805  
   806  	if err := s.processMessages(ctx, buf); err != nil {
   807  		return 0, err
   808  	}
   809  
   810  	return n, nil
   811  }
   812  
   813  // SendMsg implements socket.Socket.SendMsg.
   814  func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
   815  	return s.sendMsg(t, src, to, flags, controlMessages)
   816  }
   817  
   818  // State implements socket.Socket.State.
   819  func (s *Socket) State() uint32 {
   820  	return s.ep.State()
   821  }
   822  
   823  // Type implements socket.Socket.Type.
   824  func (s *Socket) Type() (family int, skType linux.SockType, protocol int) {
   825  	return linux.AF_NETLINK, s.skType, s.protocol.Protocol()
   826  }