github.com/metacubex/gvisor@v0.0.0-20240320004321-933faba989ec/pkg/sentry/socket/unix/transport/host.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transport
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"golang.org/x/sys/unix"
    21  	"github.com/metacubex/gvisor/pkg/abi/linux"
    22  	"github.com/metacubex/gvisor/pkg/atomicbitops"
    23  	"github.com/metacubex/gvisor/pkg/context"
    24  	"github.com/metacubex/gvisor/pkg/errors/linuxerr"
    25  	"github.com/metacubex/gvisor/pkg/fdnotifier"
    26  	"github.com/metacubex/gvisor/pkg/log"
    27  	"github.com/metacubex/gvisor/pkg/sync"
    28  	"github.com/metacubex/gvisor/pkg/syserr"
    29  	"github.com/metacubex/gvisor/pkg/tcpip"
    30  	"github.com/metacubex/gvisor/pkg/unet"
    31  	"github.com/metacubex/gvisor/pkg/waiter"
    32  )
    33  
// SCMRights implements RightsControlMessage with host FDs.
//
// The descriptors are raw host file descriptor numbers; Release closes each
// of them.
type SCMRights struct {
	// FDs is the list of host file descriptors carried by this message.
	FDs []int
}
    38  
// Clone implements RightsControlMessage.Clone.
//
// It always returns nil; no copy of the FD list is made.
func (c *SCMRights) Clone() RightsControlMessage {
	// Host rights never need to be cloned.
	return nil
}
    44  
    45  // Release implements RightsControlMessage.Release.
    46  func (c *SCMRights) Release(ctx context.Context) {
    47  	for _, fd := range c.FDs {
    48  		unix.Close(fd)
    49  	}
    50  	c.FDs = nil
    51  }
    52  
// HostConnectedEndpoint is an implementation of ConnectedEndpoint and
// Receiver. It is backed by a host fd that was imported at sentry startup.
// This fd is shared with a hostfs inode, which retains ownership of it.
//
// HostConnectedEndpoint is saveable, since we expect that the host will
// provide the same fd upon restore.
//
// As of this writing, we only allow Unix sockets to be imported.
//
// +stateify savable
type HostConnectedEndpoint struct {
	HostConnectedEndpointRefs

	// mu protects fd below.
	mu sync.RWMutex `state:"nosave"`

	// fd is the host fd backing this endpoint. destroyLocked sets it to -1
	// once the endpoint has been released.
	fd int

	// addr is the address at which this endpoint is bound.
	addr string

	// sndbuf is the size of the send buffer. It is read from the host's
	// SO_SNDBUF in initFromOptions.
	//
	// N.B. When this is smaller than the host size, we present it via
	// GetSockOpt and message splitting/rejection in SendMsg, but do not
	// prevent lots of small messages from filling the real send buffer
	// size on the host.
	sndbuf atomicbitops.Int64 `state:"nosave"`

	// stype is the type of Unix socket. It is read from the host's SO_TYPE
	// in initFromOptions.
	stype linux.SockType

	// rdShutdown is true if receptions have been shutdown with SHUT_RD.
	// It is set by CloseRecv.
	rdShutdown atomicbitops.Bool

	// wrShutdown is true if transmissions have been shutdown with SHUT_WR.
	// It is set by CloseSend.
	wrShutdown atomicbitops.Bool
}
    92  
// init performs initialization required for creating new
// HostConnectedEndpoints and for restoring them.
//
// It calls InitRefs and then initFromOptions, which derives the socket type
// and send buffer size from the host fd. The error returned by
// initFromOptions, if any, is passed through unchanged.
func (c *HostConnectedEndpoint) init() *syserr.Error {
	c.InitRefs()
	return c.initFromOptions()
}
    99  
   100  func (c *HostConnectedEndpoint) initFromOptions() *syserr.Error {
   101  	family, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_DOMAIN)
   102  	if err != nil {
   103  		return syserr.FromError(err)
   104  	}
   105  
   106  	if family != unix.AF_UNIX {
   107  		// We only allow Unix sockets.
   108  		return syserr.ErrInvalidEndpointState
   109  	}
   110  
   111  	stype, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_TYPE)
   112  	if err != nil {
   113  		return syserr.FromError(err)
   114  	}
   115  
   116  	if err := unix.SetNonblock(c.fd, true); err != nil {
   117  		return syserr.FromError(err)
   118  	}
   119  
   120  	sndbuf, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_SNDBUF)
   121  	if err != nil {
   122  		return syserr.FromError(err)
   123  	}
   124  
   125  	c.stype = linux.SockType(stype)
   126  	c.sndbuf.Store(int64(sndbuf))
   127  
   128  	return nil
   129  }
   130  
   131  // NewHostConnectedEndpoint creates a new HostConnectedEndpoint backed by a
   132  // host fd imported at sentry startup.
   133  //
   134  // The caller is responsible for calling Init(). Additionally, Release needs to
   135  // be called twice because HostConnectedEndpoint is both a Receiver and
   136  // HostConnectedEndpoint.
   137  func NewHostConnectedEndpoint(hostFD int, addr string) (*HostConnectedEndpoint, *syserr.Error) {
   138  	e := HostConnectedEndpoint{
   139  		fd:   hostFD,
   140  		addr: addr,
   141  	}
   142  
   143  	if err := e.init(); err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	// HostConnectedEndpointRefs start off with a single reference. We need two.
   148  	e.IncRef()
   149  	return &e, nil
   150  }
   151  
// SockType returns the underlying socket type, i.e. the value cached from
// the host's SO_TYPE by initFromOptions.
func (c *HostConnectedEndpoint) SockType() linux.SockType {
	return c.stype
}
   156  
   157  // Send implements ConnectedEndpoint.Send.
   158  func (c *HostConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMessages ControlMessages, from Address) (int64, bool, *syserr.Error) {
   159  	c.mu.RLock()
   160  	defer c.mu.RUnlock()
   161  
   162  	if !controlMessages.Empty() {
   163  		return 0, false, syserr.ErrInvalidEndpointState
   164  	}
   165  
   166  	// Since stream sockets don't preserve message boundaries, we can write
   167  	// only as much of the message as fits in the send buffer.
   168  	truncate := c.stype == linux.SOCK_STREAM
   169  
   170  	n, totalLen, err := fdWriteVec(c.fd, data, c.SendMaxQueueSize(), truncate)
   171  	if n < totalLen && err == nil {
   172  		// The host only returns a short write if it would otherwise
   173  		// block (and only for stream sockets).
   174  		err = linuxerr.EAGAIN
   175  	}
   176  	if n > 0 && !linuxerr.Equals(linuxerr.EAGAIN, err) {
   177  		// The caller may need to block to send more data, but
   178  		// otherwise there isn't anything that can be done about an
   179  		// error with a partial write.
   180  		err = nil
   181  	}
   182  
   183  	// There is no need for the callee to call SendNotify because fdWriteVec
   184  	// uses the host's sendmsg(2) and the host kernel's queue.
   185  	return n, false, syserr.FromError(err)
   186  }
   187  
// SendNotify implements ConnectedEndpoint.SendNotify. It is a no-op for
// host-backed endpoints.
func (c *HostConnectedEndpoint) SendNotify() {}
   190  
   191  // CloseSend implements ConnectedEndpoint.CloseSend.
   192  func (c *HostConnectedEndpoint) CloseSend() {
   193  	c.mu.Lock()
   194  	defer c.mu.Unlock()
   195  
   196  	if err := unix.Shutdown(c.fd, unix.SHUT_WR); err != nil {
   197  		// A well-formed UDS shutdown can't fail. See
   198  		// net/unix/af_unix.c:unix_shutdown.
   199  		panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err))
   200  	}
   201  	c.wrShutdown.Store(true)
   202  }
   203  
// CloseNotify implements ConnectedEndpoint.CloseNotify. It is a no-op for
// host-backed endpoints.
func (c *HostConnectedEndpoint) CloseNotify() {}
   206  
// IsSendClosed implements ConnectedEndpoint.IsSendClosed.
//
// It reports the wrShutdown flag, which CloseSend sets after a successful
// SHUT_WR on the host socket.
func (c *HostConnectedEndpoint) IsSendClosed() bool {
	return c.wrShutdown.Load()
}
   211  
   212  // Writable implements ConnectedEndpoint.Writable.
   213  func (c *HostConnectedEndpoint) Writable() bool {
   214  	c.mu.RLock()
   215  	defer c.mu.RUnlock()
   216  
   217  	return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.WritableEvents)&waiter.WritableEvents != 0
   218  }
   219  
// Passcred implements ConnectedEndpoint.Passcred. It always returns false.
func (c *HostConnectedEndpoint) Passcred() bool {
	// We don't support credential passing for host sockets.
	return false
}
   225  
// GetLocalAddress implements ConnectedEndpoint.GetLocalAddress.
//
// It returns an Address wrapping the addr the endpoint was created with; the
// returned error is always nil.
func (c *HostConnectedEndpoint) GetLocalAddress() (Address, tcpip.Error) {
	return Address{Addr: c.addr}, nil
}
   230  
   231  // EventUpdate implements ConnectedEndpoint.EventUpdate.
   232  func (c *HostConnectedEndpoint) EventUpdate() error {
   233  	c.mu.RLock()
   234  	defer c.mu.RUnlock()
   235  	if c.fd != -1 {
   236  		if err := fdnotifier.UpdateFD(int32(c.fd)); err != nil {
   237  			return err
   238  		}
   239  	}
   240  	return nil
   241  }
   242  
   243  // Recv implements Receiver.Recv.
   244  func (c *HostConnectedEndpoint) Recv(ctx context.Context, data [][]byte, args RecvArgs) (RecvOutput, bool, *syserr.Error) {
   245  	c.mu.RLock()
   246  	defer c.mu.RUnlock()
   247  
   248  	var cm unet.ControlMessage
   249  	if args.NumRights > 0 {
   250  		cm.EnableFDs(int(args.NumRights))
   251  	}
   252  
   253  	// N.B. Unix sockets don't have a receive buffer, the send buffer
   254  	// serves both purposes.
   255  	out := RecvOutput{Source: Address{Addr: c.addr}}
   256  	var err error
   257  	var controlLen uint64
   258  	out.RecvLen, out.MsgLen, controlLen, out.ControlTrunc, err = fdReadVec(c.fd, data, []byte(cm), args.Peek, c.RecvMaxQueueSize())
   259  	if out.RecvLen > 0 && err != nil {
   260  		// We got some data, so all we need to do on error is return
   261  		// the data that we got. Short reads are fine, no need to
   262  		// block.
   263  		err = nil
   264  	}
   265  	if err != nil {
   266  		return RecvOutput{}, false, syserr.FromError(err)
   267  	}
   268  
   269  	// There is no need for the callee to call RecvNotify because fdReadVec uses
   270  	// the host's recvmsg(2) and the host kernel's queue.
   271  
   272  	// Trim the control data if we received less than the full amount.
   273  	if controlLen < uint64(len(cm)) {
   274  		cm = cm[:controlLen]
   275  	}
   276  
   277  	// Avoid extra allocations in the case where there isn't any control data.
   278  	if len(cm) == 0 {
   279  		return out, false, nil
   280  	}
   281  
   282  	fds, err := cm.ExtractFDs()
   283  	if err != nil {
   284  		return RecvOutput{}, false, syserr.FromError(err)
   285  	}
   286  
   287  	if len(fds) == 0 {
   288  		return out, false, nil
   289  	}
   290  	out.Control = ControlMessages{
   291  		Rights: &SCMRights{fds},
   292  	}
   293  	return out, false, nil
   294  }
   295  
// RecvNotify implements Receiver.RecvNotify. It is a no-op for host-backed
// endpoints.
func (c *HostConnectedEndpoint) RecvNotify() {}
   298  
   299  // CloseRecv implements Receiver.CloseRecv.
   300  func (c *HostConnectedEndpoint) CloseRecv() {
   301  	c.mu.Lock()
   302  	defer c.mu.Unlock()
   303  
   304  	if err := unix.Shutdown(c.fd, unix.SHUT_RD); err != nil {
   305  		// A well-formed UDS shutdown can't fail. See
   306  		// net/unix/af_unix.c:unix_shutdown.
   307  		panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err))
   308  	}
   309  	c.rdShutdown.Store(true)
   310  }
   311  
// IsRecvClosed implements Receiver.IsRecvClosed.
//
// It reports the rdShutdown flag, which CloseRecv sets after a successful
// SHUT_RD on the host socket.
func (c *HostConnectedEndpoint) IsRecvClosed() bool {
	return c.rdShutdown.Load()
}
   316  
   317  // Readable implements Receiver.Readable.
   318  func (c *HostConnectedEndpoint) Readable() bool {
   319  	c.mu.RLock()
   320  	defer c.mu.RUnlock()
   321  
   322  	return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.ReadableEvents)&waiter.ReadableEvents != 0
   323  }
   324  
// SendQueuedSize implements Receiver.SendQueuedSize. It always returns -1
// because the queued size is not available for host sockets.
func (c *HostConnectedEndpoint) SendQueuedSize() int64 {
	// TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host
	// sockets because we don't allow the sentry to call ioctl(2).
	return -1
}
   331  
// RecvQueuedSize implements Receiver.RecvQueuedSize. It always returns -1
// because the queued size is not available for host sockets.
func (c *HostConnectedEndpoint) RecvQueuedSize() int64 {
	// TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host
	// sockets because we don't allow the sentry to call ioctl(2).
	return -1
}
   338  
// SendMaxQueueSize implements Receiver.SendMaxQueueSize.
//
// It returns the cached send buffer size (sndbuf), which initFromOptions
// reads from the host's SO_SNDBUF.
func (c *HostConnectedEndpoint) SendMaxQueueSize() int64 {
	return c.sndbuf.Load()
}
   343  
// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
//
// It returns the same cached value as SendMaxQueueSize.
func (c *HostConnectedEndpoint) RecvMaxQueueSize() int64 {
	// N.B. Unix sockets don't use the receive buffer. We'll claim it is
	// the same size as the send buffer.
	return c.sndbuf.Load()
}
   350  
// destroyLocked marks the endpoint as no longer having a host fd by setting
// fd to -1. It does not close the fd itself.
//
// The callers in this file hold c.mu for writing when calling it.
func (c *HostConnectedEndpoint) destroyLocked() {
	c.fd = -1
}
   354  
   355  // Release implements ConnectedEndpoint.Release and Receiver.Release.
   356  func (c *HostConnectedEndpoint) Release(ctx context.Context) {
   357  	c.DecRef(func() {
   358  		c.mu.Lock()
   359  		c.destroyLocked()
   360  		c.mu.Unlock()
   361  	})
   362  }
   363  
// CloseUnread implements ConnectedEndpoint.CloseUnread. It is a no-op for
// host-backed endpoints.
func (c *HostConnectedEndpoint) CloseUnread() {}
   366  
// SetSendBufferSize implements ConnectedEndpoint.SetSendBufferSize.
//
// The requested size v is ignored; the current cached send buffer size is
// returned unchanged.
func (c *HostConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
	// gVisor does not permit setting of SO_SNDBUF for host backed unix
	// domain sockets.
	return c.sndbuf.Load()
}
   373  
// SetReceiveBufferSize implements ConnectedEndpoint.SetReceiveBufferSize.
//
// The requested size v is ignored; the current cached send buffer size is
// returned unchanged.
func (c *HostConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) {
	// gVisor does not permit setting of SO_RCVBUF for host backed unix
	// domain sockets. Receive buffer does not have any effect for unix
	// sockets and we claim to be the same as send buffer.
	return c.sndbuf.Load()
}
   381  
// SCMConnectedEndpoint represents an endpoint backed by a host fd that was
// passed through a gofer Unix socket. It resembles HostConnectedEndpoint, with the
// following differences:
//   - SCMConnectedEndpoint is not saveable, because the host cannot guarantee
//     the same descriptor number across S/R.
//   - SCMConnectedEndpoint holds ownership of its fd and notification queue.
type SCMConnectedEndpoint struct {
	HostConnectedEndpoint

	// queue is the waiter queue that Init registers with fdnotifier for
	// the endpoint's fd.
	queue *waiter.Queue
}
   393  
// Init will do the initialization required without holding other locks.
//
// It registers the endpoint's fd together with its queue with fdnotifier
// and returns the error from that registration, if any.
func (e *SCMConnectedEndpoint) Init() error {
	return fdnotifier.AddFD(int32(e.fd), e.queue)
}
   398  
   399  // Release implements ConnectedEndpoint.Release and Receiver.Release.
   400  func (e *SCMConnectedEndpoint) Release(ctx context.Context) {
   401  	e.DecRef(func() {
   402  		e.mu.Lock()
   403  		fdnotifier.RemoveFD(int32(e.fd))
   404  		if err := unix.Close(e.fd); err != nil {
   405  			log.Warningf("Failed to close host fd %d: %v", err)
   406  		}
   407  		e.destroyLocked()
   408  		e.mu.Unlock()
   409  	})
   410  }
   411  
   412  // NewSCMEndpoint creates a new SCMConnectedEndpoint backed by a host fd that
   413  // was passed through a Unix socket.
   414  //
   415  // The caller is responsible for calling Init(). Additionally, Release needs to
   416  // be called twice because ConnectedEndpoint is both a Receiver and
   417  // ConnectedEndpoint.
   418  func NewSCMEndpoint(hostFD int, queue *waiter.Queue, addr string) (*SCMConnectedEndpoint, *syserr.Error) {
   419  	e := SCMConnectedEndpoint{
   420  		HostConnectedEndpoint: HostConnectedEndpoint{
   421  			fd:   hostFD,
   422  			addr: addr,
   423  		},
   424  		queue: queue,
   425  	}
   426  
   427  	if err := e.init(); err != nil {
   428  		return nil, err
   429  	}
   430  
   431  	// e starts off with a single reference. We need two.
   432  	e.IncRef()
   433  	return &e, nil
   434  }