github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/socket/netlink/socket.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package netlink provides core functionality for netlink sockets.
    16  package netlink
    17  
    18  import (
    19  	"io"
    20  	"math"
    21  
    22  	"github.com/SagerNet/gvisor/pkg/abi/linux"
    23  	"github.com/SagerNet/gvisor/pkg/abi/linux/errno"
    24  	"github.com/SagerNet/gvisor/pkg/context"
    25  	"github.com/SagerNet/gvisor/pkg/errors/linuxerr"
    26  	"github.com/SagerNet/gvisor/pkg/hostarch"
    27  	"github.com/SagerNet/gvisor/pkg/marshal"
    28  	"github.com/SagerNet/gvisor/pkg/marshal/primitive"
    29  	"github.com/SagerNet/gvisor/pkg/sentry/arch"
    30  	"github.com/SagerNet/gvisor/pkg/sentry/device"
    31  	"github.com/SagerNet/gvisor/pkg/sentry/fs"
    32  	"github.com/SagerNet/gvisor/pkg/sentry/fs/fsutil"
    33  	"github.com/SagerNet/gvisor/pkg/sentry/kernel"
    34  	"github.com/SagerNet/gvisor/pkg/sentry/kernel/auth"
    35  	ktime "github.com/SagerNet/gvisor/pkg/sentry/kernel/time"
    36  	"github.com/SagerNet/gvisor/pkg/sentry/socket"
    37  	"github.com/SagerNet/gvisor/pkg/sentry/socket/netlink/port"
    38  	"github.com/SagerNet/gvisor/pkg/sentry/socket/unix"
    39  	"github.com/SagerNet/gvisor/pkg/sentry/socket/unix/transport"
    40  	"github.com/SagerNet/gvisor/pkg/sync"
    41  	"github.com/SagerNet/gvisor/pkg/syserr"
    42  	"github.com/SagerNet/gvisor/pkg/syserror"
    43  	"github.com/SagerNet/gvisor/pkg/tcpip"
    44  	"github.com/SagerNet/gvisor/pkg/usermem"
    45  	"github.com/SagerNet/gvisor/pkg/waiter"
    46  )
    47  
    48  const sizeOfInt32 int = 4
    49  
    50  const (
    51  	// minBufferSize is the smallest size of a send buffer.
    52  	minSendBufferSize = 4 << 10 // 4096 bytes.
    53  
    54  	// defaultSendBufferSize is the default size for the send buffer.
    55  	defaultSendBufferSize = 16 * 1024
    56  
    57  	// maxBufferSize is the largest size a send buffer can grow to.
    58  	maxSendBufferSize = 4 << 20 // 4MB
    59  )
    60  
    61  var errNoFilter = syserr.New("no filter attached", errno.ENOENT)
    62  
    63  // netlinkSocketDevice is the netlink socket virtual device.
    64  var netlinkSocketDevice = device.NewAnonDevice()
    65  
    66  // LINT.IfChange
    67  
// Socket is the base socket type for netlink sockets.
//
// This implementation only supports userspace sending and receiving messages
// to/from the kernel.
//
// Socket implements socket.Socket and transport.Credentialer.
//
// +stateify savable
type Socket struct {
	// The embedded fsutil helpers supply default implementations of the
	// remaining fs.FileOperations methods. Judging by their names, these are
	// pipe-style seek, no readdir/fsync/mmap/splice, a no-op flush, and
	// unstable attributes taken from the inode (NOTE(review): confirm against
	// the fsutil package). None of them take part in save/restore.
	fsutil.FilePipeSeek             `state:"nosave"`
	fsutil.FileNotDirReaddir        `state:"nosave"`
	fsutil.FileNoFsync              `state:"nosave"`
	fsutil.FileNoMMap               `state:"nosave"`
	fsutil.FileNoSplice             `state:"nosave"`
	fsutil.FileNoopFlush            `state:"nosave"`
	fsutil.FileUseInodeUnstableAttr `state:"nosave"`

	// socketOpsCommon holds the state and the socket operations shared by
	// the VFS1 and VFS2 implementations.
	socketOpsCommon
}
    87  
// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
//
// +stateify savable
type socketOpsCommon struct {
	socket.SendReceiveTimeout

	// ports provides netlink port allocation.
	ports *port.Manager

	// protocol is the netlink protocol implementation. It handles every
	// message userspace sends to the kernel.
	protocol Protocol

	// skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for
	// netlink sockets.
	skType linux.SockType

	// ep is a datagram unix endpoint used to buffer messages sent from the
	// kernel to userspace. RecvMsg reads messages from this endpoint.
	ep transport.Endpoint

	// connection is the kernel's connection to ep, used to write messages
	// sent to userspace.
	connection transport.ConnectedEndpoint

	// mu protects the fields below.
	mu sync.Mutex `state:"nosave"`

	// bound indicates that portID is valid.
	bound bool

	// portID is the port ID allocated for this socket. It is 0 until the
	// socket is bound.
	portID int32

	// sendBufferSize is the send buffer "size". We don't actually have a
	// fixed buffer but only consume this many bytes. SO_SNDBUF values are
	// clamped to [minSendBufferSize, maxSendBufferSize].
	sendBufferSize uint32

	// filter indicates that this socket has a BPF filter "installed".
	//
	// TODO(gvisor.dev/issue/1119): We don't actually support filtering,
	// this is just bookkeeping for tracking add/remove.
	filter bool
}
   131  
   132  var _ socket.Socket = (*Socket)(nil)
   133  var _ transport.Credentialer = (*Socket)(nil)
   134  
   135  // NewSocket creates a new Socket.
   136  func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
   137  	// Datagram endpoint used to buffer kernel -> user messages.
   138  	ep := transport.NewConnectionless(t)
   139  
   140  	// Bind the endpoint for good measure so we can connect to it. The
   141  	// bound address will never be exposed.
   142  	if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil {
   143  		ep.Close(t)
   144  		return nil, err
   145  	}
   146  
   147  	// Create a connection from which the kernel can write messages.
   148  	connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
   149  	if err != nil {
   150  		ep.Close(t)
   151  		return nil, err
   152  	}
   153  
   154  	return &Socket{
   155  		socketOpsCommon: socketOpsCommon{
   156  			ports:          t.Kernel().NetlinkPorts(),
   157  			protocol:       protocol,
   158  			skType:         skType,
   159  			ep:             ep,
   160  			connection:     connection,
   161  			sendBufferSize: defaultSendBufferSize,
   162  		},
   163  	}, nil
   164  }
   165  
   166  // Release implements fs.FileOperations.Release.
   167  func (s *socketOpsCommon) Release(ctx context.Context) {
   168  	s.connection.Release(ctx)
   169  	s.ep.Close(ctx)
   170  
   171  	if s.bound {
   172  		s.ports.Release(s.protocol.Protocol(), s.portID)
   173  	}
   174  }
   175  
   176  // Readiness implements waiter.Waitable.Readiness.
   177  func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
   178  	// ep holds messages to be read and thus handles EventIn readiness.
   179  	ready := s.ep.Readiness(mask)
   180  
   181  	if mask&waiter.WritableEvents != 0 {
   182  		// sendMsg handles messages synchronously and is thus always
   183  		// ready for writing.
   184  		ready |= waiter.WritableEvents
   185  	}
   186  
   187  	return ready
   188  }
   189  
   190  // EventRegister implements waiter.Waitable.EventRegister.
   191  func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
   192  	s.ep.EventRegister(e, mask)
   193  	// Writable readiness never changes, so no registration is needed.
   194  }
   195  
   196  // EventUnregister implements waiter.Waitable.EventUnregister.
   197  func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
   198  	s.ep.EventUnregister(e)
   199  }
   200  
   201  // Passcred implements transport.Credentialer.Passcred.
   202  func (s *socketOpsCommon) Passcred() bool {
   203  	return s.ep.SocketOptions().GetPassCred()
   204  }
   205  
   206  // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
   207  func (s *socketOpsCommon) ConnectedPasscred() bool {
   208  	// This socket is connected to the kernel, which doesn't need creds.
   209  	//
   210  	// This is arbitrary, as ConnectedPasscred on this type has no callers.
   211  	return false
   212  }
   213  
   214  // Ioctl implements fs.FileOperations.Ioctl.
   215  func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) {
   216  	// TODO(b/68878065): no ioctls supported.
   217  	return 0, syserror.ENOTTY
   218  }
   219  
   220  // ExtractSockAddr extracts the SockAddrNetlink from b.
   221  func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
   222  	if len(b) < linux.SockAddrNetlinkSize {
   223  		return nil, syserr.ErrBadAddress
   224  	}
   225  
   226  	var sa linux.SockAddrNetlink
   227  	sa.UnmarshalUnsafe(b[:sa.SizeBytes()])
   228  
   229  	if sa.Family != linux.AF_NETLINK {
   230  		return nil, syserr.ErrInvalidArgument
   231  	}
   232  
   233  	return &sa, nil
   234  }
   235  
   236  // bindPort binds this socket to a port, preferring 'port' if it is available.
   237  //
   238  // port of 0 defaults to the ThreadGroup ID.
   239  //
   240  // Preconditions: mu is held.
   241  func (s *socketOpsCommon) bindPort(t *kernel.Task, port int32) *syserr.Error {
   242  	if s.bound {
   243  		// Re-binding is only allowed if the port doesn't change.
   244  		if port != s.portID {
   245  			return syserr.ErrInvalidArgument
   246  		}
   247  
   248  		return nil
   249  	}
   250  
   251  	if port == 0 {
   252  		port = int32(t.ThreadGroup().ID())
   253  	}
   254  	port, ok := s.ports.Allocate(s.protocol.Protocol(), port)
   255  	if !ok {
   256  		return syserr.ErrBusy
   257  	}
   258  
   259  	s.portID = port
   260  	s.bound = true
   261  	return nil
   262  }
   263  
   264  // Bind implements socket.Socket.Bind.
   265  func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
   266  	a, err := ExtractSockAddr(sockaddr)
   267  	if err != nil {
   268  		return err
   269  	}
   270  
   271  	// No support for multicast groups yet.
   272  	if a.Groups != 0 {
   273  		return syserr.ErrPermissionDenied
   274  	}
   275  
   276  	s.mu.Lock()
   277  	defer s.mu.Unlock()
   278  
   279  	return s.bindPort(t, int32(a.PortID))
   280  }
   281  
   282  // Connect implements socket.Socket.Connect.
   283  func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
   284  	a, err := ExtractSockAddr(sockaddr)
   285  	if err != nil {
   286  		return err
   287  	}
   288  
   289  	// No support for multicast groups yet.
   290  	if a.Groups != 0 {
   291  		return syserr.ErrPermissionDenied
   292  	}
   293  
   294  	s.mu.Lock()
   295  	defer s.mu.Unlock()
   296  
   297  	if a.PortID == 0 {
   298  		// Netlink sockets default to connected to the kernel, but
   299  		// connecting anyways automatically binds if not already bound.
   300  		if !s.bound {
   301  			// Pass port 0 to get an auto-selected port ID.
   302  			return s.bindPort(t, 0)
   303  		}
   304  		return nil
   305  	}
   306  
   307  	// We don't support non-kernel destination ports. Linux returns EPERM
   308  	// if applications attempt to do this without NL_CFG_F_NONROOT_SEND, so
   309  	// we emulate that.
   310  	return syserr.ErrPermissionDenied
   311  }
   312  
   313  // Accept implements socket.Socket.Accept.
   314  func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
   315  	// Netlink sockets never support accept.
   316  	return 0, nil, 0, syserr.ErrNotSupported
   317  }
   318  
   319  // Listen implements socket.Socket.Listen.
   320  func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
   321  	// Netlink sockets never support listen.
   322  	return syserr.ErrNotSupported
   323  }
   324  
   325  // Shutdown implements socket.Socket.Shutdown.
   326  func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
   327  	// Netlink sockets never support shutdown.
   328  	return syserr.ErrNotSupported
   329  }
   330  
   331  // GetSockOpt implements socket.Socket.GetSockOpt.
   332  func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
   333  	switch level {
   334  	case linux.SOL_SOCKET:
   335  		switch name {
   336  		case linux.SO_SNDBUF:
   337  			if outLen < sizeOfInt32 {
   338  				return nil, syserr.ErrInvalidArgument
   339  			}
   340  			s.mu.Lock()
   341  			defer s.mu.Unlock()
   342  			return primitive.AllocateInt32(int32(s.sendBufferSize)), nil
   343  
   344  		case linux.SO_RCVBUF:
   345  			if outLen < sizeOfInt32 {
   346  				return nil, syserr.ErrInvalidArgument
   347  			}
   348  			// We don't have limit on receiving size.
   349  			return primitive.AllocateInt32(math.MaxInt32), nil
   350  
   351  		case linux.SO_PASSCRED:
   352  			if outLen < sizeOfInt32 {
   353  				return nil, syserr.ErrInvalidArgument
   354  			}
   355  			var passcred primitive.Int32
   356  			if s.Passcred() {
   357  				passcred = 1
   358  			}
   359  			return &passcred, nil
   360  
   361  		default:
   362  			socket.GetSockOptEmitUnimplementedEvent(t, name)
   363  		}
   364  
   365  	case linux.SOL_NETLINK:
   366  		switch name {
   367  		case linux.NETLINK_BROADCAST_ERROR,
   368  			linux.NETLINK_CAP_ACK,
   369  			linux.NETLINK_DUMP_STRICT_CHK,
   370  			linux.NETLINK_EXT_ACK,
   371  			linux.NETLINK_LIST_MEMBERSHIPS,
   372  			linux.NETLINK_NO_ENOBUFS,
   373  			linux.NETLINK_PKTINFO:
   374  
   375  			t.Kernel().EmitUnimplementedEvent(t)
   376  		}
   377  	}
   378  	// TODO(b/68878065): other sockopts are not supported.
   379  	return nil, syserr.ErrProtocolNotAvailable
   380  }
   381  
   382  // SetSockOpt implements socket.Socket.SetSockOpt.
   383  func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
   384  	switch level {
   385  	case linux.SOL_SOCKET:
   386  		switch name {
   387  		case linux.SO_SNDBUF:
   388  			if len(opt) < sizeOfInt32 {
   389  				return syserr.ErrInvalidArgument
   390  			}
   391  			size := hostarch.ByteOrder.Uint32(opt)
   392  			if size < minSendBufferSize {
   393  				size = minSendBufferSize
   394  			} else if size > maxSendBufferSize {
   395  				size = maxSendBufferSize
   396  			}
   397  			s.mu.Lock()
   398  			s.sendBufferSize = size
   399  			s.mu.Unlock()
   400  			return nil
   401  
   402  		case linux.SO_RCVBUF:
   403  			if len(opt) < sizeOfInt32 {
   404  				return syserr.ErrInvalidArgument
   405  			}
   406  			// We don't have limit on receiving size. So just accept anything as
   407  			// valid for compatibility.
   408  			return nil
   409  
   410  		case linux.SO_PASSCRED:
   411  			if len(opt) < sizeOfInt32 {
   412  				return syserr.ErrInvalidArgument
   413  			}
   414  			passcred := hostarch.ByteOrder.Uint32(opt)
   415  
   416  			s.ep.SocketOptions().SetPassCred(passcred != 0)
   417  			return nil
   418  
   419  		case linux.SO_ATTACH_FILTER:
   420  			// TODO(github.com/SagerNet/issue/1119): We don't actually
   421  			// support filtering. If this socket can't ever send
   422  			// messages, then there is nothing to filter and we can
   423  			// advertise support. Otherwise, be conservative and
   424  			// return an error.
   425  			if s.protocol.CanSend() {
   426  				socket.SetSockOptEmitUnimplementedEvent(t, name)
   427  				return syserr.ErrProtocolNotAvailable
   428  			}
   429  
   430  			s.mu.Lock()
   431  			s.filter = true
   432  			s.mu.Unlock()
   433  			return nil
   434  
   435  		case linux.SO_DETACH_FILTER:
   436  			// TODO(github.com/SagerNet/issue/1119): See above.
   437  			if s.protocol.CanSend() {
   438  				socket.SetSockOptEmitUnimplementedEvent(t, name)
   439  				return syserr.ErrProtocolNotAvailable
   440  			}
   441  
   442  			s.mu.Lock()
   443  			filter := s.filter
   444  			s.filter = false
   445  			s.mu.Unlock()
   446  
   447  			if !filter {
   448  				return errNoFilter
   449  			}
   450  
   451  			return nil
   452  
   453  		default:
   454  			socket.SetSockOptEmitUnimplementedEvent(t, name)
   455  		}
   456  
   457  	case linux.SOL_NETLINK:
   458  		switch name {
   459  		case linux.NETLINK_ADD_MEMBERSHIP,
   460  			linux.NETLINK_BROADCAST_ERROR,
   461  			linux.NETLINK_CAP_ACK,
   462  			linux.NETLINK_DROP_MEMBERSHIP,
   463  			linux.NETLINK_DUMP_STRICT_CHK,
   464  			linux.NETLINK_EXT_ACK,
   465  			linux.NETLINK_LISTEN_ALL_NSID,
   466  			linux.NETLINK_NO_ENOBUFS,
   467  			linux.NETLINK_PKTINFO:
   468  
   469  			t.Kernel().EmitUnimplementedEvent(t)
   470  		}
   471  
   472  	}
   473  	// TODO(b/68878065): other sockopts are not supported.
   474  	return syserr.ErrProtocolNotAvailable
   475  }
   476  
   477  // GetSockName implements socket.Socket.GetSockName.
   478  func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
   479  	s.mu.Lock()
   480  	defer s.mu.Unlock()
   481  
   482  	sa := &linux.SockAddrNetlink{
   483  		Family: linux.AF_NETLINK,
   484  		PortID: uint32(s.portID),
   485  	}
   486  	return sa, uint32(sa.SizeBytes()), nil
   487  }
   488  
   489  // GetPeerName implements socket.Socket.GetPeerName.
   490  func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
   491  	sa := &linux.SockAddrNetlink{
   492  		Family: linux.AF_NETLINK,
   493  		// TODO(b/68878065): Support non-kernel peers. For now the peer
   494  		// must be the kernel.
   495  		PortID: 0,
   496  	}
   497  	return sa, uint32(sa.SizeBytes()), nil
   498  }
   499  
   500  // RecvMsg implements socket.Socket.RecvMsg.
   501  func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
   502  	from := &linux.SockAddrNetlink{
   503  		Family: linux.AF_NETLINK,
   504  		PortID: 0,
   505  	}
   506  	fromLen := uint32(from.SizeBytes())
   507  
   508  	trunc := flags&linux.MSG_TRUNC != 0
   509  
   510  	r := unix.EndpointReader{
   511  		Ctx:      t,
   512  		Endpoint: s.ep,
   513  		Peek:     flags&linux.MSG_PEEK != 0,
   514  	}
   515  
   516  	doRead := func() (int64, error) {
   517  		return dst.CopyOutFrom(t, &r)
   518  	}
   519  
   520  	// If MSG_TRUNC is set with a zero byte destination then we still need
   521  	// to read the message and discard it, or in the case where MSG_PEEK is
   522  	// set, leave it be. In both cases the full message length must be
   523  	// returned.
   524  	if trunc && dst.Addrs.NumBytes() == 0 {
   525  		doRead = func() (int64, error) {
   526  			err := r.Truncate()
   527  			// Always return zero for bytes read since the destination size is
   528  			// zero.
   529  			return 0, err
   530  		}
   531  	}
   532  
   533  	if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
   534  		var mflags int
   535  		if n < int64(r.MsgSize) {
   536  			mflags |= linux.MSG_TRUNC
   537  		}
   538  		if trunc {
   539  			n = int64(r.MsgSize)
   540  		}
   541  		return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
   542  	}
   543  
   544  	// We'll have to block. Register for notification and keep trying to
   545  	// receive all the data.
   546  	e, ch := waiter.NewChannelEntry(nil)
   547  	s.EventRegister(&e, waiter.ReadableEvents)
   548  	defer s.EventUnregister(&e)
   549  
   550  	for {
   551  		if n, err := doRead(); err != syserror.ErrWouldBlock {
   552  			var mflags int
   553  			if n < int64(r.MsgSize) {
   554  				mflags |= linux.MSG_TRUNC
   555  			}
   556  			if trunc {
   557  				n = int64(r.MsgSize)
   558  			}
   559  			return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
   560  		}
   561  
   562  		if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
   563  			if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
   564  				return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
   565  			}
   566  			return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
   567  		}
   568  	}
   569  }
   570  
   571  // Read implements fs.FileOperations.Read.
   572  func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
   573  	if dst.NumBytes() == 0 {
   574  		return 0, nil
   575  	}
   576  	return dst.CopyOutFrom(ctx, &unix.EndpointReader{
   577  		Endpoint: s.ep,
   578  	})
   579  }
   580  
   581  // kernelSCM implements control.SCMCredentials with credentials that represent
   582  // the kernel itself rather than a Task.
   583  //
   584  // +stateify savable
   585  type kernelSCM struct{}
   586  
   587  // Equals implements transport.CredentialsControlMessage.Equals.
   588  func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool {
   589  	_, ok := oc.(kernelSCM)
   590  	return ok
   591  }
   592  
   593  // Credentials implements control.SCMCredentials.Credentials.
   594  func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
   595  	return 0, auth.RootUID, auth.RootGID
   596  }
   597  
   598  // kernelCreds is the concrete version of kernelSCM used in all creds.
   599  var kernelCreds = &kernelSCM{}
   600  
   601  // sendResponse sends the response messages in ms back to userspace.
   602  func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
   603  	// Linux combines multiple netlink messages into a single datagram.
   604  	bufs := make([][]byte, 0, len(ms.Messages))
   605  	for _, m := range ms.Messages {
   606  		bufs = append(bufs, m.Finalize())
   607  	}
   608  
   609  	// All messages are from the kernel.
   610  	cms := transport.ControlMessages{
   611  		Credentials: kernelCreds,
   612  	}
   613  
   614  	if len(bufs) > 0 {
   615  		// RecvMsg never receives the address, so we don't need to send
   616  		// one.
   617  		_, notify, err := s.connection.Send(ctx, bufs, cms, tcpip.FullAddress{})
   618  		// If the buffer is full, we simply drop messages, just like
   619  		// Linux.
   620  		if err != nil && err != syserr.ErrWouldBlock {
   621  			return err
   622  		}
   623  		if notify {
   624  			s.connection.SendNotify()
   625  		}
   626  	}
   627  
   628  	// N.B. multi-part messages should still send NLMSG_DONE even if
   629  	// MessageSet contains no messages.
   630  	//
   631  	// N.B. NLMSG_DONE is always sent in a different datagram. See
   632  	// net/netlink/af_netlink.c:netlink_dump.
   633  	if ms.Multi {
   634  		m := NewMessage(linux.NetlinkMessageHeader{
   635  			Type:   linux.NLMSG_DONE,
   636  			Flags:  linux.NLM_F_MULTI,
   637  			Seq:    ms.Seq,
   638  			PortID: uint32(ms.PortID),
   639  		})
   640  
   641  		// Add the dump_done_errno payload.
   642  		m.Put(primitive.AllocateInt64(0))
   643  
   644  		_, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
   645  		if err != nil && err != syserr.ErrWouldBlock {
   646  			return err
   647  		}
   648  		if notify {
   649  			s.connection.SendNotify()
   650  		}
   651  	}
   652  
   653  	return nil
   654  }
   655  
   656  func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) {
   657  	m := ms.AddMessage(linux.NetlinkMessageHeader{
   658  		Type: linux.NLMSG_ERROR,
   659  	})
   660  	m.Put(&linux.NetlinkErrorMessage{
   661  		Error:  int32(-err.ToLinux()),
   662  		Header: hdr,
   663  	})
   664  }
   665  
   666  func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) {
   667  	m := ms.AddMessage(linux.NetlinkMessageHeader{
   668  		Type: linux.NLMSG_ERROR,
   669  	})
   670  	m.Put(&linux.NetlinkErrorMessage{
   671  		Error:  0,
   672  		Header: hdr,
   673  	})
   674  }
   675  
   676  // processMessages handles each message in buf, passing it to the protocol
   677  // handler for final handling.
   678  func (s *socketOpsCommon) processMessages(ctx context.Context, buf []byte) *syserr.Error {
   679  	for len(buf) > 0 {
   680  		msg, rest, ok := ParseMessage(buf)
   681  		if !ok {
   682  			// Linux ignores messages that are too short. See
   683  			// net/netlink/af_netlink.c:netlink_rcv_skb.
   684  			break
   685  		}
   686  		buf = rest
   687  		hdr := msg.Header()
   688  
   689  		// Ignore control messages.
   690  		if hdr.Type < linux.NLMSG_MIN_TYPE {
   691  			continue
   692  		}
   693  
   694  		ms := NewMessageSet(s.portID, hdr.Seq)
   695  		if err := s.protocol.ProcessMessage(ctx, msg, ms); err != nil {
   696  			dumpErrorMesage(hdr, ms, err)
   697  		} else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
   698  			dumpAckMesage(hdr, ms)
   699  		}
   700  
   701  		if err := s.sendResponse(ctx, ms); err != nil {
   702  			return err
   703  		}
   704  	}
   705  
   706  	return nil
   707  }
   708  
   709  // sendMsg is the core of message send, used for SendMsg and Write.
   710  func (s *socketOpsCommon) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
   711  	dstPort := int32(0)
   712  
   713  	if len(to) != 0 {
   714  		a, err := ExtractSockAddr(to)
   715  		if err != nil {
   716  			return 0, err
   717  		}
   718  
   719  		// No support for multicast groups yet.
   720  		if a.Groups != 0 {
   721  			return 0, syserr.ErrPermissionDenied
   722  		}
   723  
   724  		dstPort = int32(a.PortID)
   725  	}
   726  
   727  	if dstPort != 0 {
   728  		// Non-kernel destinations not supported yet. Treat as if
   729  		// NL_CFG_F_NONROOT_SEND is not set.
   730  		return 0, syserr.ErrPermissionDenied
   731  	}
   732  
   733  	s.mu.Lock()
   734  	defer s.mu.Unlock()
   735  
   736  	// For simplicity, and consistency with Linux, we copy in the entire
   737  	// message up front.
   738  	if src.NumBytes() > int64(s.sendBufferSize) {
   739  		return 0, syserr.ErrMessageTooLong
   740  	}
   741  
   742  	buf := make([]byte, src.NumBytes())
   743  	n, err := src.CopyIn(ctx, buf)
   744  	// io.EOF can be only returned if src is a file, this means that
   745  	// sendMsg is called from splice and the error has to be ignored in
   746  	// this case.
   747  	if err == io.EOF {
   748  		err = nil
   749  	}
   750  	if err != nil {
   751  		// Don't partially consume messages.
   752  		return 0, syserr.FromError(err)
   753  	}
   754  
   755  	if err := s.processMessages(ctx, buf); err != nil {
   756  		return 0, err
   757  	}
   758  
   759  	return n, nil
   760  }
   761  
   762  // SendMsg implements socket.Socket.SendMsg.
   763  func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
   764  	return s.sendMsg(t, src, to, flags, controlMessages)
   765  }
   766  
   767  // Write implements fs.FileOperations.Write.
   768  func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
   769  	n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
   770  	return int64(n), err.ToError()
   771  }
   772  
   773  // State implements socket.Socket.State.
   774  func (s *socketOpsCommon) State() uint32 {
   775  	return s.ep.State()
   776  }
   777  
   778  // Type implements socket.Socket.Type.
   779  func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
   780  	return linux.AF_NETLINK, s.skType, s.protocol.Protocol()
   781  }
   782  
   783  // LINT.ThenChange(./socket_vfs2.go)