github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/socket/netstack/netstack_vfs2.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package netstack
    16  
    17  import (
    18  	"github.com/SagerNet/gvisor/pkg/abi/linux"
    19  	"github.com/SagerNet/gvisor/pkg/context"
    20  	"github.com/SagerNet/gvisor/pkg/hostarch"
    21  	"github.com/SagerNet/gvisor/pkg/marshal"
    22  	"github.com/SagerNet/gvisor/pkg/marshal/primitive"
    23  	"github.com/SagerNet/gvisor/pkg/sentry/arch"
    24  	"github.com/SagerNet/gvisor/pkg/sentry/fsimpl/sockfs"
    25  	"github.com/SagerNet/gvisor/pkg/sentry/kernel"
    26  	"github.com/SagerNet/gvisor/pkg/sentry/socket"
    27  	"github.com/SagerNet/gvisor/pkg/sentry/vfs"
    28  	"github.com/SagerNet/gvisor/pkg/syserr"
    29  	"github.com/SagerNet/gvisor/pkg/syserror"
    30  	"github.com/SagerNet/gvisor/pkg/tcpip"
    31  	"github.com/SagerNet/gvisor/pkg/usermem"
    32  	"github.com/SagerNet/gvisor/pkg/waiter"
    33  )
    34  
    35  // SocketVFS2 encapsulates all the state needed to represent a network stack
    36  // endpoint in the kernel context.
    37  //
    38  // +stateify savable
    39  type SocketVFS2 struct {
    40  	vfsfd vfs.FileDescription
    41  	vfs.FileDescriptionDefaultImpl
    42  	vfs.DentryMetadataFileDescriptionImpl
    43  	vfs.LockFD
    44  
    45  	socketOpsCommon
    46  }
    47  
    48  var _ = socket.SocketVFS2(&SocketVFS2{})
    49  
    50  // NewVFS2 creates a new endpoint socket.
    51  func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) {
    52  	if skType == linux.SOCK_STREAM {
    53  		endpoint.SocketOptions().SetDelayOption(true)
    54  	}
    55  
    56  	mnt := t.Kernel().SocketMount()
    57  	d := sockfs.NewDentry(t, mnt)
    58  	defer d.DecRef(t)
    59  
    60  	s := &SocketVFS2{
    61  		socketOpsCommon: socketOpsCommon{
    62  			Queue:    queue,
    63  			family:   family,
    64  			Endpoint: endpoint,
    65  			skType:   skType,
    66  			protocol: protocol,
    67  		},
    68  	}
    69  	s.LockFD.Init(&vfs.FileLocks{})
    70  	vfsfd := &s.vfsfd
    71  	if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
    72  		DenyPRead:         true,
    73  		DenyPWrite:        true,
    74  		UseDentryMetadata: true,
    75  	}); err != nil {
    76  		return nil, syserr.FromError(err)
    77  	}
    78  	return vfsfd, nil
    79  }
    80  
    81  // Release implements vfs.FileDescriptionImpl.Release.
    82  func (s *SocketVFS2) Release(ctx context.Context) {
    83  	kernel.KernelFromContext(ctx).DeleteSocketVFS2(&s.vfsfd)
    84  	s.socketOpsCommon.Release(ctx)
    85  }
    86  
    87  // Readiness implements waiter.Waitable.Readiness.
    88  func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
    89  	return s.socketOpsCommon.Readiness(mask)
    90  }
    91  
    92  // EventRegister implements waiter.Waitable.EventRegister.
    93  func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
    94  	s.socketOpsCommon.EventRegister(e, mask)
    95  }
    96  
    97  // EventUnregister implements waiter.Waitable.EventUnregister.
    98  func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
    99  	s.socketOpsCommon.EventUnregister(e)
   100  }
   101  
   102  // Read implements vfs.FileDescriptionImpl.
   103  func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
   104  	// All flags other than RWF_NOWAIT should be ignored.
   105  	// TODO(github.com/SagerNet/issue/2601): Support RWF_NOWAIT.
   106  	if opts.Flags != 0 {
   107  		return 0, syserror.EOPNOTSUPP
   108  	}
   109  
   110  	if dst.NumBytes() == 0 {
   111  		return 0, nil
   112  	}
   113  	n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
   114  	if err == syserr.ErrWouldBlock {
   115  		return int64(n), syserror.ErrWouldBlock
   116  	}
   117  	if err != nil {
   118  		return 0, err.ToError()
   119  	}
   120  	return int64(n), nil
   121  }
   122  
   123  // Write implements vfs.FileDescriptionImpl.
   124  func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
   125  	// All flags other than RWF_NOWAIT should be ignored.
   126  	// TODO(github.com/SagerNet/issue/2601): Support RWF_NOWAIT.
   127  	if opts.Flags != 0 {
   128  		return 0, syserror.EOPNOTSUPP
   129  	}
   130  
   131  	r := src.Reader(ctx)
   132  	n, err := s.Endpoint.Write(r, tcpip.WriteOptions{})
   133  	if _, ok := err.(*tcpip.ErrWouldBlock); ok {
   134  		return 0, syserror.ErrWouldBlock
   135  	}
   136  	if err != nil {
   137  		return 0, syserr.TranslateNetstackError(err).ToError()
   138  	}
   139  
   140  	if n < src.NumBytes() {
   141  		return n, syserror.ErrWouldBlock
   142  	}
   143  
   144  	return n, nil
   145  }
   146  
   147  // Accept implements the linux syscall accept(2) for sockets backed by
   148  // tcpip.Endpoint.
   149  func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
   150  	// Issue the accept request to get the new endpoint.
   151  	var peerAddr *tcpip.FullAddress
   152  	if peerRequested {
   153  		peerAddr = &tcpip.FullAddress{}
   154  	}
   155  	ep, wq, terr := s.Endpoint.Accept(peerAddr)
   156  	if terr != nil {
   157  		if _, ok := terr.(*tcpip.ErrWouldBlock); !ok || !blocking {
   158  			return 0, nil, 0, syserr.TranslateNetstackError(terr)
   159  		}
   160  
   161  		var err *syserr.Error
   162  		ep, wq, err = s.blockingAccept(t, peerAddr)
   163  		if err != nil {
   164  			return 0, nil, 0, err
   165  		}
   166  	}
   167  
   168  	ns, err := NewVFS2(t, s.family, s.skType, s.protocol, wq, ep)
   169  	if err != nil {
   170  		return 0, nil, 0, err
   171  	}
   172  	defer ns.DecRef(t)
   173  
   174  	if err := ns.SetStatusFlags(t, t.Credentials(), uint32(flags&linux.SOCK_NONBLOCK)); err != nil {
   175  		return 0, nil, 0, syserr.FromError(err)
   176  	}
   177  
   178  	var addr linux.SockAddr
   179  	var addrLen uint32
   180  	if peerAddr != nil {
   181  		// Get address of the peer and write it to peer slice.
   182  		addr, addrLen = socket.ConvertAddress(s.family, *peerAddr)
   183  	}
   184  
   185  	fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
   186  		CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
   187  	})
   188  
   189  	t.Kernel().RecordSocketVFS2(ns)
   190  
   191  	return fd, addr, addrLen, syserr.FromError(e)
   192  }
   193  
   194  // Ioctl implements vfs.FileDescriptionImpl.
   195  func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
   196  	return s.socketOpsCommon.ioctl(ctx, uio, args)
   197  }
   198  
   199  // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
   200  // tcpip.Endpoint.
   201  func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
   202  	// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
   203  	// implemented specifically for netstack.SocketVFS2 rather than
   204  	// commonEndpoint. commonEndpoint should be extended to support socket
   205  	// options where the implementation is not shared, as unix sockets need
   206  	// their own support for SO_TIMESTAMP.
   207  	if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
   208  		if outLen < sizeOfInt32 {
   209  			return nil, syserr.ErrInvalidArgument
   210  		}
   211  		val := primitive.Int32(0)
   212  		s.readMu.Lock()
   213  		defer s.readMu.Unlock()
   214  		if s.sockOptTimestamp {
   215  			val = 1
   216  		}
   217  		return &val, nil
   218  	}
   219  	if level == linux.SOL_TCP && name == linux.TCP_INQ {
   220  		if outLen < sizeOfInt32 {
   221  			return nil, syserr.ErrInvalidArgument
   222  		}
   223  		val := primitive.Int32(0)
   224  		s.readMu.Lock()
   225  		defer s.readMu.Unlock()
   226  		if s.sockOptInq {
   227  			val = 1
   228  		}
   229  		return &val, nil
   230  	}
   231  
   232  	return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outPtr, outLen)
   233  }
   234  
   235  // SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
   236  // tcpip.Endpoint.
   237  func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
   238  	// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
   239  	// implemented specifically for netstack.SocketVFS2 rather than
   240  	// commonEndpoint. commonEndpoint should be extended to support socket
   241  	// options where the implementation is not shared, as unix sockets need
   242  	// their own support for SO_TIMESTAMP.
   243  	if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
   244  		if len(optVal) < sizeOfInt32 {
   245  			return syserr.ErrInvalidArgument
   246  		}
   247  		s.readMu.Lock()
   248  		defer s.readMu.Unlock()
   249  		s.sockOptTimestamp = hostarch.ByteOrder.Uint32(optVal) != 0
   250  		return nil
   251  	}
   252  	if level == linux.SOL_TCP && name == linux.TCP_INQ {
   253  		if len(optVal) < sizeOfInt32 {
   254  			return syserr.ErrInvalidArgument
   255  		}
   256  		s.readMu.Lock()
   257  		defer s.readMu.Unlock()
   258  		s.sockOptInq = hostarch.ByteOrder.Uint32(optVal) != 0
   259  		return nil
   260  	}
   261  
   262  	return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
   263  }