github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/fsimpl/host/socket.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package host
    16  
    17  import (
    18  	"fmt"
    19  	"sync/atomic"
    20  
    21  	"golang.org/x/sys/unix"
    22  	"github.com/SagerNet/gvisor/pkg/abi/linux"
    23  	"github.com/SagerNet/gvisor/pkg/context"
    24  	"github.com/SagerNet/gvisor/pkg/errors/linuxerr"
    25  	"github.com/SagerNet/gvisor/pkg/fdnotifier"
    26  	"github.com/SagerNet/gvisor/pkg/log"
    27  	"github.com/SagerNet/gvisor/pkg/sentry/socket/control"
    28  	"github.com/SagerNet/gvisor/pkg/sentry/socket/unix/transport"
    29  	"github.com/SagerNet/gvisor/pkg/sentry/uniqueid"
    30  	"github.com/SagerNet/gvisor/pkg/sync"
    31  	"github.com/SagerNet/gvisor/pkg/syserr"
    32  	"github.com/SagerNet/gvisor/pkg/tcpip"
    33  	"github.com/SagerNet/gvisor/pkg/unet"
    34  	"github.com/SagerNet/gvisor/pkg/waiter"
    35  )
    36  
    37  // Create a new host-backed endpoint from the given fd and its corresponding
    38  // notification queue.
    39  func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transport.Endpoint, error) {
    40  	// Set up an external transport.Endpoint using the host fd.
    41  	addr := fmt.Sprintf("hostfd:[%d]", hostFD)
    42  	e, err := NewConnectedEndpoint(hostFD, addr)
    43  	if err != nil {
    44  		return nil, err.ToError()
    45  	}
    46  	ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), queue, e, e)
    47  	return ep, nil
    48  }
    49  
    50  // ConnectedEndpoint is an implementation of transport.ConnectedEndpoint and
    51  // transport.Receiver. It is backed by a host fd that was imported at sentry
    52  // startup. This fd is shared with a hostfs inode, which retains ownership of
    53  // it.
    54  //
    55  // ConnectedEndpoint is saveable, since we expect that the host will provide
    56  // the same fd upon restore.
    57  //
    58  // As of this writing, we only allow Unix sockets to be imported.
    59  //
    60  // +stateify savable
    61  type ConnectedEndpoint struct {
    62  	ConnectedEndpointRefs
    63  
    64  	// mu protects fd below.
    65  	mu sync.RWMutex `state:"nosave"`
    66  
    67  	// fd is the host fd backing this endpoint.
    68  	fd int
    69  
    70  	// addr is the address at which this endpoint is bound.
    71  	addr string
    72  
    73  	// sndbuf is the size of the send buffer.
    74  	//
    75  	// N.B. When this is smaller than the host size, we present it via
    76  	// GetSockOpt and message splitting/rejection in SendMsg, but do not
    77  	// prevent lots of small messages from filling the real send buffer
    78  	// size on the host.
    79  	sndbuf int64 `state:"nosave"`
    80  
    81  	// stype is the type of Unix socket.
    82  	stype linux.SockType
    83  }
    84  
    85  // init performs initialization required for creating new ConnectedEndpoints and
    86  // for restoring them.
    87  func (c *ConnectedEndpoint) init() *syserr.Error {
    88  	c.InitRefs()
    89  	return c.initFromOptions()
    90  }
    91  
    92  func (c *ConnectedEndpoint) initFromOptions() *syserr.Error {
    93  	family, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_DOMAIN)
    94  	if err != nil {
    95  		return syserr.FromError(err)
    96  	}
    97  
    98  	if family != unix.AF_UNIX {
    99  		// We only allow Unix sockets.
   100  		return syserr.ErrInvalidEndpointState
   101  	}
   102  
   103  	stype, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_TYPE)
   104  	if err != nil {
   105  		return syserr.FromError(err)
   106  	}
   107  
   108  	if err := unix.SetNonblock(c.fd, true); err != nil {
   109  		return syserr.FromError(err)
   110  	}
   111  
   112  	sndbuf, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_SNDBUF)
   113  	if err != nil {
   114  		return syserr.FromError(err)
   115  	}
   116  
   117  	c.stype = linux.SockType(stype)
   118  	atomic.StoreInt64(&c.sndbuf, int64(sndbuf))
   119  
   120  	return nil
   121  }
   122  
   123  // NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host fd
   124  // imported at sentry startup,
   125  //
   126  // The caller is responsible for calling Init(). Additionaly, Release needs to
   127  // be called twice because ConnectedEndpoint is both a transport.Receiver and
   128  // transport.ConnectedEndpoint.
   129  func NewConnectedEndpoint(hostFD int, addr string) (*ConnectedEndpoint, *syserr.Error) {
   130  	e := ConnectedEndpoint{
   131  		fd:   hostFD,
   132  		addr: addr,
   133  	}
   134  
   135  	if err := e.init(); err != nil {
   136  		return nil, err
   137  	}
   138  
   139  	// ConnectedEndpointRefs start off with a single reference. We need two.
   140  	e.IncRef()
   141  	return &e, nil
   142  }
   143  
   144  // Send implements transport.ConnectedEndpoint.Send.
   145  func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
   146  	c.mu.RLock()
   147  	defer c.mu.RUnlock()
   148  
   149  	if !controlMessages.Empty() {
   150  		return 0, false, syserr.ErrInvalidEndpointState
   151  	}
   152  
   153  	// Since stream sockets don't preserve message boundaries, we can write
   154  	// only as much of the message as fits in the send buffer.
   155  	truncate := c.stype == linux.SOCK_STREAM
   156  
   157  	n, totalLen, err := fdWriteVec(c.fd, data, c.SendMaxQueueSize(), truncate)
   158  	if n < totalLen && err == nil {
   159  		// The host only returns a short write if it would otherwise
   160  		// block (and only for stream sockets).
   161  		err = linuxerr.EAGAIN
   162  	}
   163  	if n > 0 && !linuxerr.Equals(linuxerr.EAGAIN, err) {
   164  		// The caller may need to block to send more data, but
   165  		// otherwise there isn't anything that can be done about an
   166  		// error with a partial write.
   167  		err = nil
   168  	}
   169  
   170  	// There is no need for the callee to call SendNotify because fdWriteVec
   171  	// uses the host's sendmsg(2) and the host kernel's queue.
   172  	return n, false, syserr.FromError(err)
   173  }
   174  
   175  // SendNotify implements transport.ConnectedEndpoint.SendNotify.
   176  func (c *ConnectedEndpoint) SendNotify() {}
   177  
   178  // CloseSend implements transport.ConnectedEndpoint.CloseSend.
   179  func (c *ConnectedEndpoint) CloseSend() {
   180  	c.mu.Lock()
   181  	defer c.mu.Unlock()
   182  
   183  	if err := unix.Shutdown(c.fd, unix.SHUT_WR); err != nil {
   184  		// A well-formed UDS shutdown can't fail. See
   185  		// net/unix/af_unix.c:unix_shutdown.
   186  		panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err))
   187  	}
   188  }
   189  
   190  // CloseNotify implements transport.ConnectedEndpoint.CloseNotify.
   191  func (c *ConnectedEndpoint) CloseNotify() {}
   192  
   193  // Writable implements transport.ConnectedEndpoint.Writable.
   194  func (c *ConnectedEndpoint) Writable() bool {
   195  	c.mu.RLock()
   196  	defer c.mu.RUnlock()
   197  
   198  	return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.WritableEvents)&waiter.WritableEvents != 0
   199  }
   200  
   201  // Passcred implements transport.ConnectedEndpoint.Passcred.
   202  func (c *ConnectedEndpoint) Passcred() bool {
   203  	// We don't support credential passing for host sockets.
   204  	return false
   205  }
   206  
   207  // GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress.
   208  func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
   209  	return tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, nil
   210  }
   211  
   212  // EventUpdate implements transport.ConnectedEndpoint.EventUpdate.
   213  func (c *ConnectedEndpoint) EventUpdate() {
   214  	c.mu.RLock()
   215  	defer c.mu.RUnlock()
   216  	if c.fd != -1 {
   217  		fdnotifier.UpdateFD(int32(c.fd))
   218  	}
   219  }
   220  
   221  // Recv implements transport.Receiver.Recv.
   222  func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
   223  	c.mu.RLock()
   224  	defer c.mu.RUnlock()
   225  
   226  	var cm unet.ControlMessage
   227  	if numRights > 0 {
   228  		cm.EnableFDs(int(numRights))
   229  	}
   230  
   231  	// N.B. Unix sockets don't have a receive buffer, the send buffer
   232  	// serves both purposes.
   233  	rl, ml, cl, cTrunc, err := fdReadVec(c.fd, data, []byte(cm), peek, c.RecvMaxQueueSize())
   234  	if rl > 0 && err != nil {
   235  		// We got some data, so all we need to do on error is return
   236  		// the data that we got. Short reads are fine, no need to
   237  		// block.
   238  		err = nil
   239  	}
   240  	if err != nil {
   241  		return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
   242  	}
   243  
   244  	// There is no need for the callee to call RecvNotify because fdReadVec uses
   245  	// the host's recvmsg(2) and the host kernel's queue.
   246  
   247  	// Trim the control data if we received less than the full amount.
   248  	if cl < uint64(len(cm)) {
   249  		cm = cm[:cl]
   250  	}
   251  
   252  	// Avoid extra allocations in the case where there isn't any control data.
   253  	if len(cm) == 0 {
   254  		return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
   255  	}
   256  
   257  	fds, err := cm.ExtractFDs()
   258  	if err != nil {
   259  		return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
   260  	}
   261  
   262  	if len(fds) == 0 {
   263  		return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
   264  	}
   265  	return rl, ml, control.NewVFS2(nil, nil, newSCMRights(fds)), cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.addr)}, false, nil
   266  }
   267  
   268  // RecvNotify implements transport.Receiver.RecvNotify.
   269  func (c *ConnectedEndpoint) RecvNotify() {}
   270  
   271  // CloseRecv implements transport.Receiver.CloseRecv.
   272  func (c *ConnectedEndpoint) CloseRecv() {
   273  	c.mu.Lock()
   274  	defer c.mu.Unlock()
   275  
   276  	if err := unix.Shutdown(c.fd, unix.SHUT_RD); err != nil {
   277  		// A well-formed UDS shutdown can't fail. See
   278  		// net/unix/af_unix.c:unix_shutdown.
   279  		panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err))
   280  	}
   281  }
   282  
   283  // Readable implements transport.Receiver.Readable.
   284  func (c *ConnectedEndpoint) Readable() bool {
   285  	c.mu.RLock()
   286  	defer c.mu.RUnlock()
   287  
   288  	return fdnotifier.NonBlockingPoll(int32(c.fd), waiter.ReadableEvents)&waiter.ReadableEvents != 0
   289  }
   290  
   291  // SendQueuedSize implements transport.Receiver.SendQueuedSize.
   292  func (c *ConnectedEndpoint) SendQueuedSize() int64 {
   293  	// TODO(github.com/SagerNet/issue/273): SendQueuedSize isn't supported for host
   294  	// sockets because we don't allow the sentry to call ioctl(2).
   295  	return -1
   296  }
   297  
   298  // RecvQueuedSize implements transport.Receiver.RecvQueuedSize.
   299  func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
   300  	// TODO(github.com/SagerNet/issue/273): RecvQueuedSize isn't supported for host
   301  	// sockets because we don't allow the sentry to call ioctl(2).
   302  	return -1
   303  }
   304  
   305  // SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
   306  func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
   307  	return atomic.LoadInt64(&c.sndbuf)
   308  }
   309  
   310  // RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
   311  func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
   312  	// N.B. Unix sockets don't use the receive buffer. We'll claim it is
   313  	// the same size as the send buffer.
   314  	return atomic.LoadInt64(&c.sndbuf)
   315  }
   316  
   317  func (c *ConnectedEndpoint) destroyLocked() {
   318  	c.fd = -1
   319  }
   320  
   321  // Release implements transport.ConnectedEndpoint.Release and
   322  // transport.Receiver.Release.
   323  func (c *ConnectedEndpoint) Release(ctx context.Context) {
   324  	c.DecRef(func() {
   325  		c.mu.Lock()
   326  		c.destroyLocked()
   327  		c.mu.Unlock()
   328  	})
   329  }
   330  
   331  // CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
   332  func (c *ConnectedEndpoint) CloseUnread() {}
   333  
   334  // SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
   335  func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
   336  	// gVisor does not permit setting of SO_SNDBUF for host backed unix
   337  	// domain sockets.
   338  	return atomic.LoadInt64(&c.sndbuf)
   339  }
   340  
   341  // SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize.
   342  func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) {
   343  	// gVisor does not permit setting of SO_RCVBUF for host backed unix
   344  	// domain sockets. Receive buffer does not have any effect for unix
   345  	// sockets and we claim to be the same as send buffer.
   346  	return atomic.LoadInt64(&c.sndbuf)
   347  }
   348  
   349  // SCMConnectedEndpoint represents an endpoint backed by a host fd that was
   350  // passed through a gofer Unix socket. It resembles ConnectedEndpoint, with the
   351  // following differences:
   352  // - SCMConnectedEndpoint is not saveable, because the host cannot guarantee
   353  // the same descriptor number across S/R.
   354  // - SCMConnectedEndpoint holds ownership of its fd and notification queue.
   355  type SCMConnectedEndpoint struct {
   356  	ConnectedEndpoint
   357  
   358  	queue *waiter.Queue
   359  }
   360  
   361  // Init will do the initialization required without holding other locks.
   362  func (e *SCMConnectedEndpoint) Init() error {
   363  	return fdnotifier.AddFD(int32(e.fd), e.queue)
   364  }
   365  
   366  // Release implements transport.ConnectedEndpoint.Release and
   367  // transport.Receiver.Release.
   368  func (e *SCMConnectedEndpoint) Release(ctx context.Context) {
   369  	e.DecRef(func() {
   370  		e.mu.Lock()
   371  		fdnotifier.RemoveFD(int32(e.fd))
   372  		if err := unix.Close(e.fd); err != nil {
   373  			log.Warningf("Failed to close host fd %d: %v", err)
   374  		}
   375  		e.destroyLocked()
   376  		e.mu.Unlock()
   377  	})
   378  }
   379  
   380  // NewSCMEndpoint creates a new SCMConnectedEndpoint backed by a host fd that
   381  // was passed through a Unix socket.
   382  //
   383  // The caller is responsible for calling Init(). Additionaly, Release needs to
   384  // be called twice because ConnectedEndpoint is both a transport.Receiver and
   385  // transport.ConnectedEndpoint.
   386  func NewSCMEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue, addr string) (*SCMConnectedEndpoint, *syserr.Error) {
   387  	e := SCMConnectedEndpoint{
   388  		ConnectedEndpoint: ConnectedEndpoint{
   389  			fd:   hostFD,
   390  			addr: addr,
   391  		},
   392  		queue: queue,
   393  	}
   394  
   395  	if err := e.init(); err != nil {
   396  		return nil, err
   397  	}
   398  
   399  	// e starts off with a single reference. We need two.
   400  	e.IncRef()
   401  	return &e, nil
   402  }