github.com/ttpreport/gvisor-ligolo@v0.0.0-20240123134145-a858404967ba/pkg/sentry/socket/netlink/socket.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package netlink provides core functionality for netlink sockets.
    16  package netlink
    17  
    18  import (
    19  	"io"
    20  	"math"
    21  
    22  	"github.com/ttpreport/gvisor-ligolo/pkg/abi/linux"
    23  	"github.com/ttpreport/gvisor-ligolo/pkg/abi/linux/errno"
    24  	"github.com/ttpreport/gvisor-ligolo/pkg/context"
    25  	"github.com/ttpreport/gvisor-ligolo/pkg/errors/linuxerr"
    26  	"github.com/ttpreport/gvisor-ligolo/pkg/hostarch"
    27  	"github.com/ttpreport/gvisor-ligolo/pkg/marshal"
    28  	"github.com/ttpreport/gvisor-ligolo/pkg/marshal/primitive"
    29  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/arch"
    30  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/kernel"
    31  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/kernel/auth"
    32  	ktime "github.com/ttpreport/gvisor-ligolo/pkg/sentry/kernel/time"
    33  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/socket"
    34  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/socket/netlink/port"
    35  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/socket/unix"
    36  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/socket/unix/transport"
    37  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/vfs"
    38  	"github.com/ttpreport/gvisor-ligolo/pkg/sync"
    39  	"github.com/ttpreport/gvisor-ligolo/pkg/syserr"
    40  	"github.com/ttpreport/gvisor-ligolo/pkg/usermem"
    41  	"github.com/ttpreport/gvisor-ligolo/pkg/waiter"
    42  )
    43  
    44  const sizeOfInt32 int = 4
    45  
    46  const (
    47  	// minBufferSize is the smallest size of a send buffer.
    48  	minSendBufferSize = 4 << 10 // 4096 bytes.
    49  
    50  	// defaultSendBufferSize is the default size for the send buffer.
    51  	defaultSendBufferSize = 16 * 1024
    52  
    53  	// maxBufferSize is the largest size a send buffer can grow to.
    54  	maxSendBufferSize = 4 << 20 // 4MB
    55  )
    56  
    57  var errNoFilter = syserr.New("no filter attached", errno.ENOENT)
    58  
    59  // Socket is the base socket type for netlink sockets.
    60  //
    61  // This implementation only supports userspace sending and receiving messages
    62  // to/from the kernel.
    63  //
    64  // Socket implements socket.Socket and transport.Credentialer.
    65  //
    66  // +stateify savable
    67  type Socket struct {
    68  	vfsfd vfs.FileDescription
    69  	vfs.FileDescriptionDefaultImpl
    70  	vfs.DentryMetadataFileDescriptionImpl
    71  	vfs.LockFD
    72  	socket.SendReceiveTimeout
    73  
    74  	// ports provides netlink port allocation.
    75  	ports *port.Manager
    76  
    77  	// protocol is the netlink protocol implementation.
    78  	protocol Protocol
    79  
    80  	// skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for
    81  	// netlink sockets.
    82  	skType linux.SockType
    83  
    84  	// ep is a datagram unix endpoint used to buffer messages sent from the
    85  	// kernel to userspace. RecvMsg reads messages from this endpoint.
    86  	ep transport.Endpoint
    87  
    88  	// connection is the kernel's connection to ep, used to write messages
    89  	// sent to userspace.
    90  	connection transport.ConnectedEndpoint
    91  
    92  	// mu protects the fields below.
    93  	mu sync.Mutex `state:"nosave"`
    94  
    95  	// bound indicates that portid is valid.
    96  	bound bool
    97  
    98  	// portID is the port ID allocated for this socket.
    99  	portID int32
   100  
   101  	// sendBufferSize is the send buffer "size". We don't actually have a
   102  	// fixed buffer but only consume this many bytes.
   103  	sendBufferSize uint32
   104  
   105  	// filter indicates that this socket has a BPF filter "installed".
   106  	//
   107  	// TODO(gvisor.dev/issue/1119): We don't actually support filtering,
   108  	// this is just bookkeeping for tracking add/remove.
   109  	filter bool
   110  }
   111  
   112  var _ socket.Socket = (*Socket)(nil)
   113  var _ transport.Credentialer = (*Socket)(nil)
   114  
   115  // New creates a new Socket.
   116  func New(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
   117  	// Datagram endpoint used to buffer kernel -> user messages.
   118  	ep := transport.NewConnectionless(t)
   119  
   120  	// Bind the endpoint for good measure so we can connect to it. The
   121  	// bound address will never be exposed.
   122  	if err := ep.Bind(transport.Address{Addr: "dummy"}); err != nil {
   123  		ep.Close(t)
   124  		return nil, err
   125  	}
   126  
   127  	// Create a connection from which the kernel can write messages.
   128  	connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
   129  	if err != nil {
   130  		ep.Close(t)
   131  		return nil, err
   132  	}
   133  
   134  	fd := &Socket{
   135  		ports:          t.Kernel().NetlinkPorts(),
   136  		protocol:       protocol,
   137  		skType:         skType,
   138  		ep:             ep,
   139  		connection:     connection,
   140  		sendBufferSize: defaultSendBufferSize,
   141  	}
   142  	fd.LockFD.Init(&vfs.FileLocks{})
   143  	return fd, nil
   144  }
   145  
   146  // Release implements vfs.FileDescriptionImpl.Release.
   147  func (s *Socket) Release(ctx context.Context) {
   148  	t := kernel.TaskFromContext(ctx)
   149  	t.Kernel().DeleteSocket(&s.vfsfd)
   150  	s.connection.Release(ctx)
   151  	s.ep.Close(ctx)
   152  
   153  	if s.bound {
   154  		s.ports.Release(s.protocol.Protocol(), s.portID)
   155  	}
   156  }
   157  
   158  // Epollable implements FileDescriptionImpl.Epollable.
   159  func (s *Socket) Epollable() bool {
   160  	return true
   161  }
   162  
   163  // Ioctl implements vfs.FileDescriptionImpl.
   164  func (*Socket) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
   165  	// TODO(b/68878065): no ioctls supported.
   166  	return 0, linuxerr.ENOTTY
   167  }
   168  
   169  // PRead implements vfs.FileDescriptionImpl.
   170  func (s *Socket) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
   171  	return 0, linuxerr.ESPIPE
   172  }
   173  
   174  // Read implements vfs.FileDescriptionImpl.
   175  func (s *Socket) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
   176  	// All flags other than RWF_NOWAIT should be ignored.
   177  	// TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
   178  	if opts.Flags != 0 {
   179  		return 0, linuxerr.EOPNOTSUPP
   180  	}
   181  
   182  	if dst.NumBytes() == 0 {
   183  		return 0, nil
   184  	}
   185  	r := unix.EndpointReader{
   186  		Endpoint: s.ep,
   187  	}
   188  	n, err := dst.CopyOutFrom(ctx, &r)
   189  	if r.Notify != nil {
   190  		r.Notify()
   191  	}
   192  	return n, err
   193  }
   194  
   195  // PWrite implements vfs.FileDescriptionImpl.
   196  func (s *Socket) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
   197  	return 0, linuxerr.ESPIPE
   198  }
   199  
   200  // Write implements vfs.FileDescriptionImpl.
   201  func (s *Socket) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
   202  	// All flags other than RWF_NOWAIT should be ignored.
   203  	// TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
   204  	if opts.Flags != 0 {
   205  		return 0, linuxerr.EOPNOTSUPP
   206  	}
   207  
   208  	n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
   209  	return int64(n), err.ToError()
   210  }
   211  
   212  // Readiness implements waiter.Waitable.Readiness.
   213  func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
   214  	// ep holds messages to be read and thus handles EventIn readiness.
   215  	ready := s.ep.Readiness(mask)
   216  
   217  	if mask&waiter.WritableEvents != 0 {
   218  		// sendMsg handles messages synchronously and is thus always
   219  		// ready for writing.
   220  		ready |= waiter.WritableEvents
   221  	}
   222  
   223  	return ready
   224  }
   225  
   226  // EventRegister implements waiter.Waitable.EventRegister.
   227  func (s *Socket) EventRegister(e *waiter.Entry) error {
   228  	return s.ep.EventRegister(e)
   229  	// Writable readiness never changes, so no registration is needed.
   230  }
   231  
   232  // EventUnregister implements waiter.Waitable.EventUnregister.
   233  func (s *Socket) EventUnregister(e *waiter.Entry) {
   234  	s.ep.EventUnregister(e)
   235  }
   236  
   237  // Passcred implements transport.Credentialer.Passcred.
   238  func (s *Socket) Passcred() bool {
   239  	return s.ep.SocketOptions().GetPassCred()
   240  }
   241  
   242  // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
   243  func (s *Socket) ConnectedPasscred() bool {
   244  	// This socket is connected to the kernel, which doesn't need creds.
   245  	//
   246  	// This is arbitrary, as ConnectedPasscred on this type has no callers.
   247  	return false
   248  }
   249  
   250  // ExtractSockAddr extracts the SockAddrNetlink from b.
   251  func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
   252  	if len(b) < linux.SockAddrNetlinkSize {
   253  		return nil, syserr.ErrBadAddress
   254  	}
   255  
   256  	var sa linux.SockAddrNetlink
   257  	sa.UnmarshalUnsafe(b)
   258  
   259  	if sa.Family != linux.AF_NETLINK {
   260  		return nil, syserr.ErrInvalidArgument
   261  	}
   262  
   263  	return &sa, nil
   264  }
   265  
   266  // bindPort binds this socket to a port, preferring 'port' if it is available.
   267  //
   268  // port of 0 defaults to the ThreadGroup ID.
   269  //
   270  // Preconditions: mu is held.
   271  func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error {
   272  	if s.bound {
   273  		// Re-binding is only allowed if the port doesn't change.
   274  		if port != s.portID {
   275  			return syserr.ErrInvalidArgument
   276  		}
   277  
   278  		return nil
   279  	}
   280  
   281  	if port == 0 {
   282  		port = int32(t.ThreadGroup().ID())
   283  	}
   284  	port, ok := s.ports.Allocate(s.protocol.Protocol(), port)
   285  	if !ok {
   286  		return syserr.ErrBusy
   287  	}
   288  
   289  	s.portID = port
   290  	s.bound = true
   291  	return nil
   292  }
   293  
   294  // Bind implements socket.Socket.Bind.
   295  func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
   296  	a, err := ExtractSockAddr(sockaddr)
   297  	if err != nil {
   298  		return err
   299  	}
   300  
   301  	// No support for multicast groups yet.
   302  	if a.Groups != 0 {
   303  		return syserr.ErrPermissionDenied
   304  	}
   305  
   306  	s.mu.Lock()
   307  	defer s.mu.Unlock()
   308  
   309  	return s.bindPort(t, int32(a.PortID))
   310  }
   311  
   312  // Connect implements socket.Socket.Connect.
   313  func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
   314  	a, err := ExtractSockAddr(sockaddr)
   315  	if err != nil {
   316  		return err
   317  	}
   318  
   319  	// No support for multicast groups yet.
   320  	if a.Groups != 0 {
   321  		return syserr.ErrPermissionDenied
   322  	}
   323  
   324  	s.mu.Lock()
   325  	defer s.mu.Unlock()
   326  
   327  	if a.PortID == 0 {
   328  		// Netlink sockets default to connected to the kernel, but
   329  		// connecting anyways automatically binds if not already bound.
   330  		if !s.bound {
   331  			// Pass port 0 to get an auto-selected port ID.
   332  			return s.bindPort(t, 0)
   333  		}
   334  		return nil
   335  	}
   336  
   337  	// We don't support non-kernel destination ports. Linux returns EPERM
   338  	// if applications attempt to do this without NL_CFG_F_NONROOT_SEND, so
   339  	// we emulate that.
   340  	return syserr.ErrPermissionDenied
   341  }
   342  
   343  // Accept implements socket.Socket.Accept.
   344  func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
   345  	// Netlink sockets never support accept.
   346  	return 0, nil, 0, syserr.ErrNotSupported
   347  }
   348  
   349  // Listen implements socket.Socket.Listen.
   350  func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error {
   351  	// Netlink sockets never support listen.
   352  	return syserr.ErrNotSupported
   353  }
   354  
   355  // Shutdown implements socket.Socket.Shutdown.
   356  func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error {
   357  	// Netlink sockets never support shutdown.
   358  	return syserr.ErrNotSupported
   359  }
   360  
   361  // GetSockOpt implements socket.Socket.GetSockOpt.
   362  func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
   363  	switch level {
   364  	case linux.SOL_SOCKET:
   365  		switch name {
   366  		case linux.SO_SNDBUF:
   367  			if outLen < sizeOfInt32 {
   368  				return nil, syserr.ErrInvalidArgument
   369  			}
   370  			s.mu.Lock()
   371  			defer s.mu.Unlock()
   372  			return primitive.AllocateInt32(int32(s.sendBufferSize)), nil
   373  
   374  		case linux.SO_RCVBUF:
   375  			if outLen < sizeOfInt32 {
   376  				return nil, syserr.ErrInvalidArgument
   377  			}
   378  			// We don't have limit on receiving size.
   379  			return primitive.AllocateInt32(math.MaxInt32), nil
   380  
   381  		case linux.SO_PASSCRED:
   382  			if outLen < sizeOfInt32 {
   383  				return nil, syserr.ErrInvalidArgument
   384  			}
   385  			var passcred primitive.Int32
   386  			if s.Passcred() {
   387  				passcred = 1
   388  			}
   389  			return &passcred, nil
   390  		}
   391  	case linux.SOL_NETLINK:
   392  		switch name {
   393  		case linux.NETLINK_BROADCAST_ERROR,
   394  			linux.NETLINK_CAP_ACK,
   395  			linux.NETLINK_DUMP_STRICT_CHK,
   396  			linux.NETLINK_EXT_ACK,
   397  			linux.NETLINK_LIST_MEMBERSHIPS,
   398  			linux.NETLINK_NO_ENOBUFS,
   399  			linux.NETLINK_PKTINFO:
   400  			// Not supported.
   401  		}
   402  	}
   403  	// TODO(b/68878065): other sockopts are not supported.
   404  	return nil, syserr.ErrProtocolNotAvailable
   405  }
   406  
   407  // SetSockOpt implements socket.Socket.SetSockOpt.
   408  func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
   409  	switch level {
   410  	case linux.SOL_SOCKET:
   411  		switch name {
   412  		case linux.SO_SNDBUF:
   413  			if len(opt) < sizeOfInt32 {
   414  				return syserr.ErrInvalidArgument
   415  			}
   416  			size := hostarch.ByteOrder.Uint32(opt)
   417  			if size < minSendBufferSize {
   418  				size = minSendBufferSize
   419  			} else if size > maxSendBufferSize {
   420  				size = maxSendBufferSize
   421  			}
   422  			s.mu.Lock()
   423  			s.sendBufferSize = size
   424  			s.mu.Unlock()
   425  			return nil
   426  
   427  		case linux.SO_RCVBUF:
   428  			if len(opt) < sizeOfInt32 {
   429  				return syserr.ErrInvalidArgument
   430  			}
   431  			// We don't have limit on receiving size. So just accept anything as
   432  			// valid for compatibility.
   433  			return nil
   434  
   435  		case linux.SO_PASSCRED:
   436  			if len(opt) < sizeOfInt32 {
   437  				return syserr.ErrInvalidArgument
   438  			}
   439  			passcred := hostarch.ByteOrder.Uint32(opt)
   440  
   441  			s.ep.SocketOptions().SetPassCred(passcred != 0)
   442  			return nil
   443  
   444  		case linux.SO_ATTACH_FILTER:
   445  			// TODO(gvisor.dev/issue/1119): We don't actually
   446  			// support filtering. If this socket can't ever send
   447  			// messages, then there is nothing to filter and we can
   448  			// advertise support. Otherwise, be conservative and
   449  			// return an error.
   450  			if s.protocol.CanSend() {
   451  				return syserr.ErrProtocolNotAvailable
   452  			}
   453  
   454  			s.mu.Lock()
   455  			s.filter = true
   456  			s.mu.Unlock()
   457  			return nil
   458  
   459  		case linux.SO_DETACH_FILTER:
   460  			// TODO(gvisor.dev/issue/1119): See above.
   461  			if s.protocol.CanSend() {
   462  				return syserr.ErrProtocolNotAvailable
   463  			}
   464  
   465  			s.mu.Lock()
   466  			filter := s.filter
   467  			s.filter = false
   468  			s.mu.Unlock()
   469  
   470  			if !filter {
   471  				return errNoFilter
   472  			}
   473  
   474  			return nil
   475  		}
   476  	case linux.SOL_NETLINK:
   477  		switch name {
   478  		case linux.NETLINK_ADD_MEMBERSHIP,
   479  			linux.NETLINK_BROADCAST_ERROR,
   480  			linux.NETLINK_CAP_ACK,
   481  			linux.NETLINK_DROP_MEMBERSHIP,
   482  			linux.NETLINK_DUMP_STRICT_CHK,
   483  			linux.NETLINK_EXT_ACK,
   484  			linux.NETLINK_LISTEN_ALL_NSID,
   485  			linux.NETLINK_NO_ENOBUFS,
   486  			linux.NETLINK_PKTINFO:
   487  			// Not supported.
   488  		}
   489  	}
   490  
   491  	// TODO(b/68878065): other sockopts are not supported.
   492  	return syserr.ErrProtocolNotAvailable
   493  }
   494  
   495  // GetSockName implements socket.Socket.GetSockName.
   496  func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
   497  	s.mu.Lock()
   498  	defer s.mu.Unlock()
   499  
   500  	sa := &linux.SockAddrNetlink{
   501  		Family: linux.AF_NETLINK,
   502  		PortID: uint32(s.portID),
   503  	}
   504  	return sa, uint32(sa.SizeBytes()), nil
   505  }
   506  
   507  // GetPeerName implements socket.Socket.GetPeerName.
   508  func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
   509  	sa := &linux.SockAddrNetlink{
   510  		Family: linux.AF_NETLINK,
   511  		// TODO(b/68878065): Support non-kernel peers. For now the peer
   512  		// must be the kernel.
   513  		PortID: 0,
   514  	}
   515  	return sa, uint32(sa.SizeBytes()), nil
   516  }
   517  
   518  // RecvMsg implements socket.Socket.RecvMsg.
   519  func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
   520  	from := &linux.SockAddrNetlink{
   521  		Family: linux.AF_NETLINK,
   522  		PortID: 0,
   523  	}
   524  	fromLen := uint32(from.SizeBytes())
   525  
   526  	trunc := flags&linux.MSG_TRUNC != 0
   527  
   528  	r := unix.EndpointReader{
   529  		Ctx:      t,
   530  		Endpoint: s.ep,
   531  		Peek:     flags&linux.MSG_PEEK != 0,
   532  	}
   533  
   534  	doRead := func() (int64, error) {
   535  		return dst.CopyOutFrom(t, &r)
   536  	}
   537  
   538  	// If MSG_TRUNC is set with a zero byte destination then we still need
   539  	// to read the message and discard it, or in the case where MSG_PEEK is
   540  	// set, leave it be. In both cases the full message length must be
   541  	// returned.
   542  	if trunc && dst.Addrs.NumBytes() == 0 {
   543  		doRead = func() (int64, error) {
   544  			err := r.Truncate()
   545  			// Always return zero for bytes read since the destination size is
   546  			// zero.
   547  			return 0, err
   548  		}
   549  	}
   550  
   551  	if n, err := doRead(); err != linuxerr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
   552  		var mflags int
   553  		if n < int64(r.MsgSize) {
   554  			mflags |= linux.MSG_TRUNC
   555  		}
   556  		if trunc {
   557  			n = int64(r.MsgSize)
   558  		}
   559  		return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
   560  	}
   561  
   562  	// We'll have to block. Register for notification and keep trying to
   563  	// receive all the data.
   564  	e, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
   565  	if err := s.EventRegister(&e); err != nil {
   566  		return 0, 0, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
   567  	}
   568  	defer s.EventUnregister(&e)
   569  
   570  	for {
   571  		if n, err := doRead(); err != linuxerr.ErrWouldBlock {
   572  			var mflags int
   573  			if n < int64(r.MsgSize) {
   574  				mflags |= linux.MSG_TRUNC
   575  			}
   576  			if trunc {
   577  				n = int64(r.MsgSize)
   578  			}
   579  			return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
   580  		}
   581  
   582  		if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
   583  			if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
   584  				return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
   585  			}
   586  			return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
   587  		}
   588  	}
   589  }
   590  
   591  // kernelSCM implements control.SCMCredentials with credentials that represent
   592  // the kernel itself rather than a Task.
   593  //
   594  // +stateify savable
   595  type kernelSCM struct{}
   596  
   597  // Equals implements transport.CredentialsControlMessage.Equals.
   598  func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool {
   599  	_, ok := oc.(kernelSCM)
   600  	return ok
   601  }
   602  
   603  // Credentials implements control.SCMCredentials.Credentials.
   604  func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
   605  	return 0, auth.RootUID, auth.RootGID
   606  }
   607  
   608  // kernelCreds is the concrete version of kernelSCM used in all creds.
   609  var kernelCreds = &kernelSCM{}
   610  
   611  // sendResponse sends the response messages in ms back to userspace.
   612  func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
   613  	// Linux combines multiple netlink messages into a single datagram.
   614  	bufs := make([][]byte, 0, len(ms.Messages))
   615  	for _, m := range ms.Messages {
   616  		bufs = append(bufs, m.Finalize())
   617  	}
   618  
   619  	// All messages are from the kernel.
   620  	cms := transport.ControlMessages{
   621  		Credentials: kernelCreds,
   622  	}
   623  
   624  	if len(bufs) > 0 {
   625  		// RecvMsg never receives the address, so we don't need to send
   626  		// one.
   627  		_, notify, err := s.connection.Send(ctx, bufs, cms, transport.Address{})
   628  		// If the buffer is full, we simply drop messages, just like
   629  		// Linux.
   630  		if err != nil && err != syserr.ErrWouldBlock {
   631  			return err
   632  		}
   633  		if notify {
   634  			s.connection.SendNotify()
   635  		}
   636  	}
   637  
   638  	// N.B. multi-part messages should still send NLMSG_DONE even if
   639  	// MessageSet contains no messages.
   640  	//
   641  	// N.B. NLMSG_DONE is always sent in a different datagram. See
   642  	// net/netlink/af_netlink.c:netlink_dump.
   643  	if ms.Multi {
   644  		m := NewMessage(linux.NetlinkMessageHeader{
   645  			Type:   linux.NLMSG_DONE,
   646  			Flags:  linux.NLM_F_MULTI,
   647  			Seq:    ms.Seq,
   648  			PortID: uint32(ms.PortID),
   649  		})
   650  
   651  		// Add the dump_done_errno payload.
   652  		m.Put(primitive.AllocateInt64(0))
   653  
   654  		_, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, transport.Address{})
   655  		if err != nil && err != syserr.ErrWouldBlock {
   656  			return err
   657  		}
   658  		if notify {
   659  			s.connection.SendNotify()
   660  		}
   661  	}
   662  
   663  	return nil
   664  }
   665  
   666  func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) {
   667  	m := ms.AddMessage(linux.NetlinkMessageHeader{
   668  		Type: linux.NLMSG_ERROR,
   669  	})
   670  	m.Put(&linux.NetlinkErrorMessage{
   671  		Error:  int32(-err.ToLinux()),
   672  		Header: hdr,
   673  	})
   674  }
   675  
   676  func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) {
   677  	m := ms.AddMessage(linux.NetlinkMessageHeader{
   678  		Type: linux.NLMSG_ERROR,
   679  	})
   680  	m.Put(&linux.NetlinkErrorMessage{
   681  		Error:  0,
   682  		Header: hdr,
   683  	})
   684  }
   685  
   686  // processMessages handles each message in buf, passing it to the protocol
   687  // handler for final handling.
   688  func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error {
   689  	for len(buf) > 0 {
   690  		msg, rest, ok := ParseMessage(buf)
   691  		if !ok {
   692  			// Linux ignores messages that are too short. See
   693  			// net/netlink/af_netlink.c:netlink_rcv_skb.
   694  			break
   695  		}
   696  		buf = rest
   697  		hdr := msg.Header()
   698  
   699  		// Ignore control messages.
   700  		if hdr.Type < linux.NLMSG_MIN_TYPE {
   701  			continue
   702  		}
   703  
   704  		ms := NewMessageSet(s.portID, hdr.Seq)
   705  		if err := s.protocol.ProcessMessage(ctx, msg, ms); err != nil {
   706  			dumpErrorMesage(hdr, ms, err)
   707  		} else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
   708  			dumpAckMesage(hdr, ms)
   709  		}
   710  
   711  		if err := s.sendResponse(ctx, ms); err != nil {
   712  			return err
   713  		}
   714  	}
   715  
   716  	return nil
   717  }
   718  
   719  // sendMsg is the core of message send, used for SendMsg and Write.
   720  func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
   721  	dstPort := int32(0)
   722  
   723  	if len(to) != 0 {
   724  		a, err := ExtractSockAddr(to)
   725  		if err != nil {
   726  			return 0, err
   727  		}
   728  
   729  		// No support for multicast groups yet.
   730  		if a.Groups != 0 {
   731  			return 0, syserr.ErrPermissionDenied
   732  		}
   733  
   734  		dstPort = int32(a.PortID)
   735  	}
   736  
   737  	if dstPort != 0 {
   738  		// Non-kernel destinations not supported yet. Treat as if
   739  		// NL_CFG_F_NONROOT_SEND is not set.
   740  		return 0, syserr.ErrPermissionDenied
   741  	}
   742  
   743  	s.mu.Lock()
   744  	defer s.mu.Unlock()
   745  
   746  	// For simplicity, and consistency with Linux, we copy in the entire
   747  	// message up front.
   748  	if src.NumBytes() > int64(s.sendBufferSize) {
   749  		return 0, syserr.ErrMessageTooLong
   750  	}
   751  
   752  	buf := make([]byte, src.NumBytes())
   753  	n, err := src.CopyIn(ctx, buf)
   754  	// io.EOF can be only returned if src is a file, this means that
   755  	// sendMsg is called from splice and the error has to be ignored in
   756  	// this case.
   757  	if err == io.EOF {
   758  		err = nil
   759  	}
   760  	if err != nil {
   761  		// Don't partially consume messages.
   762  		return 0, syserr.FromError(err)
   763  	}
   764  
   765  	if err := s.processMessages(ctx, buf); err != nil {
   766  		return 0, err
   767  	}
   768  
   769  	return n, nil
   770  }
   771  
   772  // SendMsg implements socket.Socket.SendMsg.
   773  func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
   774  	return s.sendMsg(t, src, to, flags, controlMessages)
   775  }
   776  
   777  // State implements socket.Socket.State.
   778  func (s *Socket) State() uint32 {
   779  	return s.ep.State()
   780  }
   781  
   782  // Type implements socket.Socket.Type.
   783  func (s *Socket) Type() (family int, skType linux.SockType, protocol int) {
   784  	return linux.AF_NETLINK, s.skType, s.protocol.Protocol()
   785  }