github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/fs/host/socket.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package host
    16  
    17  import (
    18  	"fmt"
    19  	"sync/atomic"
    20  
    21  	"golang.org/x/sys/unix"
    22  	"github.com/SagerNet/gvisor/pkg/abi/linux"
    23  	"github.com/SagerNet/gvisor/pkg/context"
    24  	"github.com/SagerNet/gvisor/pkg/errors/linuxerr"
    25  	"github.com/SagerNet/gvisor/pkg/fd"
    26  	"github.com/SagerNet/gvisor/pkg/fdnotifier"
    27  	"github.com/SagerNet/gvisor/pkg/refs"
    28  	"github.com/SagerNet/gvisor/pkg/sentry/fs"
    29  	"github.com/SagerNet/gvisor/pkg/sentry/socket/control"
    30  	unixsocket "github.com/SagerNet/gvisor/pkg/sentry/socket/unix"
    31  	"github.com/SagerNet/gvisor/pkg/sentry/socket/unix/transport"
    32  	"github.com/SagerNet/gvisor/pkg/sentry/uniqueid"
    33  	"github.com/SagerNet/gvisor/pkg/sync"
    34  	"github.com/SagerNet/gvisor/pkg/syserr"
    35  	"github.com/SagerNet/gvisor/pkg/tcpip"
    36  	"github.com/SagerNet/gvisor/pkg/unet"
    37  	"github.com/SagerNet/gvisor/pkg/waiter"
    38  )
    39  
    40  // LINT.IfChange
    41  
// ConnectedEndpoint is a host FD backed implementation of
// transport.ConnectedEndpoint and transport.Receiver.
//
// Because it plays both roles, it is created holding two references and must
// be released once per role (see NewConnectedEndpoint and Release).
//
// +stateify savable
type ConnectedEndpoint struct {
	// ref keeps track of references to a connectedEndpoint. When the last
	// reference is dropped, close releases the host FD.
	ref refs.AtomicRefCount

	// queue receives readiness events for the host FD once Init registers
	// it with fdnotifier. path is the sentry path the endpoint pretends to
	// be bound at; GetLocalAddress and Recv report it as the address.
	queue *waiter.Queue
	path  string

	// If srfd >= 0, it is the host FD that file was imported from.
	// NewConnectedEndpoint initializes it to -1; newSocket sets it to the
	// caller's original FD when the socket is saveable.
	srfd int `state:"wait"`

	// stype is the type of Unix socket, as reported by SO_TYPE in init.
	stype linux.SockType

	// sndbuf is the size of the send buffer, as reported by SO_SNDBUF in
	// init. It is read with atomic loads and is not guarded by mu.
	//
	// N.B. When this is smaller than the host size, we present it via
	// GetSockOpt and message splitting/rejection in SendMsg, but do not
	// prevent lots of small messages from filling the real send buffer
	// size on the host.
	sndbuf int64 `state:"nosave"`

	// mu protects the fields below.
	mu sync.RWMutex `state:"nosave"`

	// file is an *fd.FD containing the FD backing this endpoint. It must be
	// set to nil if it has been closed.
	file *fd.FD `state:"nosave"`
}
    74  
    75  // init performs initialization required for creating new ConnectedEndpoints and
    76  // for restoring them.
    77  func (c *ConnectedEndpoint) init() *syserr.Error {
    78  	family, err := unix.GetsockoptInt(c.file.FD(), unix.SOL_SOCKET, unix.SO_DOMAIN)
    79  	if err != nil {
    80  		return syserr.FromError(err)
    81  	}
    82  
    83  	if family != unix.AF_UNIX {
    84  		// We only allow Unix sockets.
    85  		return syserr.ErrInvalidEndpointState
    86  	}
    87  
    88  	stype, err := unix.GetsockoptInt(c.file.FD(), unix.SOL_SOCKET, unix.SO_TYPE)
    89  	if err != nil {
    90  		return syserr.FromError(err)
    91  	}
    92  
    93  	if err := unix.SetNonblock(c.file.FD(), true); err != nil {
    94  		return syserr.FromError(err)
    95  	}
    96  
    97  	sndbuf, err := unix.GetsockoptInt(c.file.FD(), unix.SOL_SOCKET, unix.SO_SNDBUF)
    98  	if err != nil {
    99  		return syserr.FromError(err)
   100  	}
   101  
   102  	c.stype = linux.SockType(stype)
   103  	c.sndbuf = int64(sndbuf)
   104  
   105  	return nil
   106  }
   107  
   108  // NewConnectedEndpoint creates a new ConnectedEndpoint backed by a host FD
   109  // that will pretend to be bound at a given sentry path.
   110  //
   111  // The caller is responsible for calling Init(). Additionaly, Release needs to
   112  // be called twice because ConnectedEndpoint is both a transport.Receiver and
   113  // transport.ConnectedEndpoint.
   114  func NewConnectedEndpoint(ctx context.Context, file *fd.FD, queue *waiter.Queue, path string) (*ConnectedEndpoint, *syserr.Error) {
   115  	e := ConnectedEndpoint{
   116  		path:  path,
   117  		queue: queue,
   118  		file:  file,
   119  		srfd:  -1,
   120  	}
   121  
   122  	if err := e.init(); err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	// AtomicRefCounters start off with a single reference. We need two.
   127  	e.ref.IncRef()
   128  
   129  	e.ref.EnableLeakCheck("host.ConnectedEndpoint")
   130  
   131  	return &e, nil
   132  }
   133  
   134  // Init will do initialization required without holding other locks.
   135  func (c *ConnectedEndpoint) Init() {
   136  	if err := fdnotifier.AddFD(int32(c.file.FD()), c.queue); err != nil {
   137  		panic(err)
   138  	}
   139  }
   140  
   141  // NewSocketWithDirent allocates a new unix socket with host endpoint.
   142  //
   143  // This is currently only used by unsaveable Gofer nodes.
   144  //
   145  // NewSocketWithDirent takes ownership of f on success.
   146  func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.FileFlags) (*fs.File, error) {
   147  	f2 := fd.New(f.FD())
   148  	var q waiter.Queue
   149  	e, err := NewConnectedEndpoint(ctx, f2, &q, "" /* path */)
   150  	if err != nil {
   151  		f2.Release()
   152  		return nil, err.ToError()
   153  	}
   154  
   155  	// Take ownship of the FD.
   156  	f.Release()
   157  
   158  	e.Init()
   159  
   160  	ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
   161  
   162  	return unixsocket.NewWithDirent(ctx, d, ep, e.stype, flags), nil
   163  }
   164  
   165  // newSocket allocates a new unix socket with host endpoint.
   166  func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) {
   167  	ownedfd := orgfd
   168  	srfd := -1
   169  	if saveable {
   170  		var err error
   171  		ownedfd, err = unix.Dup(orgfd)
   172  		if err != nil {
   173  			return nil, err
   174  		}
   175  		srfd = orgfd
   176  	}
   177  	f := fd.New(ownedfd)
   178  	var q waiter.Queue
   179  	e, err := NewConnectedEndpoint(ctx, f, &q, "" /* path */)
   180  	if err != nil {
   181  		if saveable {
   182  			f.Close()
   183  		} else {
   184  			f.Release()
   185  		}
   186  		return nil, err.ToError()
   187  	}
   188  
   189  	e.srfd = srfd
   190  	e.Init()
   191  
   192  	ep := transport.NewExternal(ctx, e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
   193  
   194  	return unixsocket.New(ctx, ep, e.stype), nil
   195  }
   196  
   197  // Send implements transport.ConnectedEndpoint.Send.
   198  func (c *ConnectedEndpoint) Send(ctx context.Context, data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
   199  	c.mu.RLock()
   200  	defer c.mu.RUnlock()
   201  
   202  	if !controlMessages.Empty() {
   203  		return 0, false, syserr.ErrInvalidEndpointState
   204  	}
   205  
   206  	// Since stream sockets don't preserve message boundaries, we can write
   207  	// only as much of the message as fits in the send buffer.
   208  	truncate := c.stype == linux.SOCK_STREAM
   209  
   210  	n, totalLen, err := fdWriteVec(c.file.FD(), data, c.SendMaxQueueSize(), truncate)
   211  	if n < totalLen && err == nil {
   212  		// The host only returns a short write if it would otherwise
   213  		// block (and only for stream sockets).
   214  		err = linuxerr.EAGAIN
   215  	}
   216  	if n > 0 && !linuxerr.Equals(linuxerr.EAGAIN, err) {
   217  		// The caller may need to block to send more data, but
   218  		// otherwise there isn't anything that can be done about an
   219  		// error with a partial write.
   220  		err = nil
   221  	}
   222  
   223  	// There is no need for the callee to call SendNotify because fdWriteVec
   224  	// uses the host's sendmsg(2) and the host kernel's queue.
   225  	return n, false, syserr.FromError(err)
   226  }
   227  
   228  // SendNotify implements transport.ConnectedEndpoint.SendNotify.
   229  func (c *ConnectedEndpoint) SendNotify() {}
   230  
   231  // CloseSend implements transport.ConnectedEndpoint.CloseSend.
   232  func (c *ConnectedEndpoint) CloseSend() {
   233  	c.mu.Lock()
   234  	defer c.mu.Unlock()
   235  
   236  	if err := unix.Shutdown(c.file.FD(), unix.SHUT_WR); err != nil {
   237  		// A well-formed UDS shutdown can't fail. See
   238  		// net/unix/af_unix.c:unix_shutdown.
   239  		panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err))
   240  	}
   241  }
   242  
   243  // CloseNotify implements transport.ConnectedEndpoint.CloseNotify.
   244  func (c *ConnectedEndpoint) CloseNotify() {}
   245  
   246  // Writable implements transport.ConnectedEndpoint.Writable.
   247  func (c *ConnectedEndpoint) Writable() bool {
   248  	c.mu.RLock()
   249  	defer c.mu.RUnlock()
   250  
   251  	return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.WritableEvents)&waiter.WritableEvents != 0
   252  }
   253  
   254  // Passcred implements transport.ConnectedEndpoint.Passcred.
   255  func (c *ConnectedEndpoint) Passcred() bool {
   256  	// We don't support credential passing for host sockets.
   257  	return false
   258  }
   259  
   260  // GetLocalAddress implements transport.ConnectedEndpoint.GetLocalAddress.
   261  func (c *ConnectedEndpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
   262  	return tcpip.FullAddress{Addr: tcpip.Address(c.path)}, nil
   263  }
   264  
   265  // EventUpdate implements transport.ConnectedEndpoint.EventUpdate.
   266  func (c *ConnectedEndpoint) EventUpdate() {
   267  	c.mu.RLock()
   268  	defer c.mu.RUnlock()
   269  	if c.file.FD() != -1 {
   270  		fdnotifier.UpdateFD(int32(c.file.FD()))
   271  	}
   272  }
   273  
   274  // Recv implements transport.Receiver.Recv.
   275  func (c *ConnectedEndpoint) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
   276  	c.mu.RLock()
   277  	defer c.mu.RUnlock()
   278  
   279  	var cm unet.ControlMessage
   280  	if numRights > 0 {
   281  		cm.EnableFDs(int(numRights))
   282  	}
   283  
   284  	// N.B. Unix sockets don't have a receive buffer, the send buffer
   285  	// serves both purposes.
   286  	rl, ml, cl, cTrunc, err := fdReadVec(c.file.FD(), data, []byte(cm), peek, c.RecvMaxQueueSize())
   287  	if rl > 0 && err != nil {
   288  		// We got some data, so all we need to do on error is return
   289  		// the data that we got. Short reads are fine, no need to
   290  		// block.
   291  		err = nil
   292  	}
   293  	if err != nil {
   294  		return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
   295  	}
   296  
   297  	// There is no need for the callee to call RecvNotify because fdReadVec uses
   298  	// the host's recvmsg(2) and the host kernel's queue.
   299  
   300  	// Trim the control data if we received less than the full amount.
   301  	if cl < uint64(len(cm)) {
   302  		cm = cm[:cl]
   303  	}
   304  
   305  	// Avoid extra allocations in the case where there isn't any control data.
   306  	if len(cm) == 0 {
   307  		return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
   308  	}
   309  
   310  	fds, err := cm.ExtractFDs()
   311  	if err != nil {
   312  		return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.FromError(err)
   313  	}
   314  
   315  	if len(fds) == 0 {
   316  		return rl, ml, transport.ControlMessages{}, cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
   317  	}
   318  	return rl, ml, control.New(nil, nil, newSCMRights(fds)), cTrunc, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, nil
   319  }
   320  
   321  // close releases all resources related to the endpoint.
   322  func (c *ConnectedEndpoint) close(context.Context) {
   323  	fdnotifier.RemoveFD(int32(c.file.FD()))
   324  	c.file.Close()
   325  	c.file = nil
   326  }
   327  
   328  // RecvNotify implements transport.Receiver.RecvNotify.
   329  func (c *ConnectedEndpoint) RecvNotify() {}
   330  
   331  // CloseRecv implements transport.Receiver.CloseRecv.
   332  func (c *ConnectedEndpoint) CloseRecv() {
   333  	c.mu.Lock()
   334  	defer c.mu.Unlock()
   335  
   336  	if err := unix.Shutdown(c.file.FD(), unix.SHUT_RD); err != nil {
   337  		// A well-formed UDS shutdown can't fail. See
   338  		// net/unix/af_unix.c:unix_shutdown.
   339  		panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err))
   340  	}
   341  }
   342  
   343  // Readable implements transport.Receiver.Readable.
   344  func (c *ConnectedEndpoint) Readable() bool {
   345  	c.mu.RLock()
   346  	defer c.mu.RUnlock()
   347  
   348  	return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.ReadableEvents)&waiter.ReadableEvents != 0
   349  }
   350  
   351  // SendQueuedSize implements transport.Receiver.SendQueuedSize.
   352  func (c *ConnectedEndpoint) SendQueuedSize() int64 {
   353  	// TODO(github.com/SagerNet/issue/273): SendQueuedSize isn't supported for host
   354  	// sockets because we don't allow the sentry to call ioctl(2).
   355  	return -1
   356  }
   357  
   358  // RecvQueuedSize implements transport.Receiver.RecvQueuedSize.
   359  func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
   360  	// TODO(github.com/SagerNet/issue/273): RecvQueuedSize isn't supported for host
   361  	// sockets because we don't allow the sentry to call ioctl(2).
   362  	return -1
   363  }
   364  
   365  // SendMaxQueueSize implements transport.Receiver.SendMaxQueueSize.
   366  func (c *ConnectedEndpoint) SendMaxQueueSize() int64 {
   367  	return atomic.LoadInt64(&c.sndbuf)
   368  }
   369  
   370  // RecvMaxQueueSize implements transport.Receiver.RecvMaxQueueSize.
   371  func (c *ConnectedEndpoint) RecvMaxQueueSize() int64 {
   372  	// N.B. Unix sockets don't use the receive buffer. We'll claim it is
   373  	// the same size as the send buffer.
   374  	return atomic.LoadInt64(&c.sndbuf)
   375  }
   376  
   377  // Release implements transport.ConnectedEndpoint.Release and transport.Receiver.Release.
   378  func (c *ConnectedEndpoint) Release(ctx context.Context) {
   379  	c.ref.DecRefWithDestructor(ctx, c.close)
   380  }
   381  
   382  // CloseUnread implements transport.ConnectedEndpoint.CloseUnread.
   383  func (c *ConnectedEndpoint) CloseUnread() {}
   384  
   385  // SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
   386  func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
   387  	// gVisor does not permit setting of SO_SNDBUF for host backed unix
   388  	// domain sockets.
   389  	return atomic.LoadInt64(&c.sndbuf)
   390  }
   391  
   392  // SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize.
   393  func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) {
   394  	// gVisor does not permit setting of SO_RCVBUF for host backed unix
   395  	// domain sockets. Receive buffer does not have any effect for unix
   396  	// sockets and we claim to be the same as send buffer.
   397  	return atomic.LoadInt64(&c.sndbuf)
   398  }
   399  
   400  // LINT.ThenChange(../../fsimpl/host/socket.go)