github.com/metacubex/gvisor@v0.0.0-20240320004321-933faba989ec/pkg/sentry/socket/unix/unix.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package unix provides an implementation of the socket.Socket interface for
    16  // the AF_UNIX protocol family.
    17  package unix
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  
    23  	"golang.org/x/sys/unix"
    24  	"github.com/metacubex/gvisor/pkg/abi/linux"
    25  	"github.com/metacubex/gvisor/pkg/context"
    26  	"github.com/metacubex/gvisor/pkg/errors/linuxerr"
    27  	"github.com/metacubex/gvisor/pkg/fspath"
    28  	"github.com/metacubex/gvisor/pkg/hostarch"
    29  	"github.com/metacubex/gvisor/pkg/log"
    30  	"github.com/metacubex/gvisor/pkg/marshal"
    31  	"github.com/metacubex/gvisor/pkg/sentry/arch"
    32  	"github.com/metacubex/gvisor/pkg/sentry/fsimpl/sockfs"
    33  	"github.com/metacubex/gvisor/pkg/sentry/inet"
    34  	"github.com/metacubex/gvisor/pkg/sentry/kernel"
    35  	ktime "github.com/metacubex/gvisor/pkg/sentry/kernel/time"
    36  	"github.com/metacubex/gvisor/pkg/sentry/socket"
    37  	"github.com/metacubex/gvisor/pkg/sentry/socket/control"
    38  	"github.com/metacubex/gvisor/pkg/sentry/socket/netstack"
    39  	"github.com/metacubex/gvisor/pkg/sentry/socket/unix/transport"
    40  	"github.com/metacubex/gvisor/pkg/sentry/vfs"
    41  	"github.com/metacubex/gvisor/pkg/syserr"
    42  	"github.com/metacubex/gvisor/pkg/usermem"
    43  	"github.com/metacubex/gvisor/pkg/waiter"
    44  )
    45  
    46  // Socket implements socket.Socket (and by extension,
    47  // vfs.FileDescriptionImpl) for Unix sockets.
    48  //
    49  // +stateify savable
    50  type Socket struct {
    51  	vfsfd vfs.FileDescription
    52  	vfs.FileDescriptionDefaultImpl
    53  	vfs.DentryMetadataFileDescriptionImpl
    54  	vfs.LockFD
    55  	socket.SendReceiveTimeout
    56  	socketRefs
    57  
    58  	namespace *inet.Namespace
    59  	ep        transport.Endpoint
    60  	stype     linux.SockType
    61  
    62  	// abstractName and abstractNamespace indicate the name and namespace of the
    63  	// socket if it is bound to an abstract socket namespace. Once the socket is
    64  	// bound, they cannot be modified.
    65  	abstractName  string
    66  	abstractBound bool
    67  }
    68  
    69  var _ = socket.Socket(&Socket{})
    70  
    71  // NewSockfsFile creates a new socket file in the global sockfs mount and
    72  // returns a corresponding file description.
    73  func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
    74  	mnt := t.Kernel().SocketMount()
    75  	d := sockfs.NewDentry(t, mnt)
    76  	defer d.DecRef(t)
    77  
    78  	ns := t.GetNetworkNamespace()
    79  	fd, err := NewFileDescription(ep, stype, linux.O_RDWR, ns, mnt, d, &vfs.FileLocks{})
    80  	if err != nil {
    81  		ns.DecRef(t)
    82  		return nil, syserr.FromError(err)
    83  	}
    84  	return fd, nil
    85  }
    86  
    87  // NewFileDescription creates and returns a socket file description
    88  // corresponding to the given mount and dentry.
    89  func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, ns *inet.Namespace, mnt *vfs.Mount, d *vfs.Dentry, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
    90  	// You can create AF_UNIX, SOCK_RAW sockets. They're the same as
    91  	// SOCK_DGRAM and don't require CAP_NET_RAW.
    92  	if stype == linux.SOCK_RAW {
    93  		stype = linux.SOCK_DGRAM
    94  	}
    95  
    96  	sock := &Socket{
    97  		ep:        ep,
    98  		stype:     stype,
    99  		namespace: ns,
   100  	}
   101  	sock.InitRefs()
   102  	sock.LockFD.Init(locks)
   103  	vfsfd := &sock.vfsfd
   104  	if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
   105  		DenyPRead:         true,
   106  		DenyPWrite:        true,
   107  		UseDentryMetadata: true,
   108  	}); err != nil {
   109  		return nil, err
   110  	}
   111  	return vfsfd, nil
   112  }
   113  
   114  // DecRef implements RefCounter.DecRef.
   115  func (s *Socket) DecRef(ctx context.Context) {
   116  	s.socketRefs.DecRef(func() {
   117  		kernel.KernelFromContext(ctx).DeleteSocket(&s.vfsfd)
   118  		s.ep.Close(ctx)
   119  		if s.abstractBound {
   120  			s.namespace.AbstractSockets().Remove(s.abstractName, s)
   121  		}
   122  		if s.namespace != nil {
   123  			s.namespace.DecRef(ctx)
   124  		}
   125  	})
   126  }
   127  
   128  // Release implements vfs.FileDescriptionImpl.Release.
   129  func (s *Socket) Release(ctx context.Context) {
   130  	// Release only decrements a reference on s because s may be referenced in
   131  	// the abstract socket namespace.
   132  	s.DecRef(ctx)
   133  }
   134  
   135  // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
   136  // a transport.Endpoint.
   137  func (s *Socket) GetSockOpt(t *kernel.Task, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
   138  	return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen)
   139  }
   140  
   141  // blockingAccept implements a blocking version of accept(2), that is, if no
   142  // connections are ready to be accept, it will block until one becomes ready.
   143  func (s *Socket) blockingAccept(t *kernel.Task, peerAddr *transport.Address) (transport.Endpoint, *syserr.Error) {
   144  	// Register for notifications.
   145  	e, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
   146  	s.EventRegister(&e)
   147  	defer s.EventUnregister(&e)
   148  
   149  	// Try to accept the connection; if it fails, then wait until we get a
   150  	// notification.
   151  	for {
   152  		if ep, err := s.ep.Accept(t, peerAddr); err != syserr.ErrWouldBlock {
   153  			return ep, err
   154  		}
   155  
   156  		if err := t.Block(ch); err != nil {
   157  			return nil, syserr.FromError(err)
   158  		}
   159  	}
   160  }
   161  
   162  // Accept implements the linux syscall accept(2) for sockets backed by
   163  // a transport.Endpoint.
   164  func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
   165  	var peerAddr *transport.Address
   166  	if peerRequested {
   167  		peerAddr = &transport.Address{}
   168  	}
   169  	ep, err := s.ep.Accept(t, peerAddr)
   170  	if err != nil {
   171  		if err != syserr.ErrWouldBlock || !blocking {
   172  			return 0, nil, 0, err
   173  		}
   174  
   175  		var err *syserr.Error
   176  		ep, err = s.blockingAccept(t, peerAddr)
   177  		if err != nil {
   178  			return 0, nil, 0, err
   179  		}
   180  	}
   181  
   182  	ns, err := NewSockfsFile(t, ep, s.stype)
   183  	if err != nil {
   184  		return 0, nil, 0, err
   185  	}
   186  	defer ns.DecRef(t)
   187  
   188  	if flags&linux.SOCK_NONBLOCK != 0 {
   189  		ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK)
   190  	}
   191  
   192  	var addr linux.SockAddr
   193  	var addrLen uint32
   194  	if peerAddr != nil {
   195  		addr, addrLen = convertAddress(*peerAddr)
   196  	}
   197  
   198  	fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
   199  		CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
   200  	})
   201  	if e != nil {
   202  		return 0, nil, 0, syserr.FromError(e)
   203  	}
   204  
   205  	t.Kernel().RecordSocket(ns)
   206  	return fd, addr, addrLen, nil
   207  }
   208  
   209  // Bind implements the linux syscall bind(2) for unix sockets.
   210  func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
   211  	p, e := extractPath(sockaddr)
   212  	if e != nil {
   213  		return e
   214  	}
   215  
   216  	bep, ok := s.ep.(transport.BoundEndpoint)
   217  	if !ok {
   218  		// This socket can't be bound.
   219  		return syserr.ErrInvalidArgument
   220  	}
   221  
   222  	// If path is empty, the socket is autobound to an abstract address.
   223  	if len(p) == 0 || p[0] == 0 {
   224  		// Abstract socket. See net/unix/af_unix.c:unix_bind_abstract().
   225  		asn := s.namespace.AbstractSockets()
   226  		p, err := asn.Bind(t, p, bep, s)
   227  		if err != nil {
   228  			return err
   229  		}
   230  		name := p[1:]
   231  		if err := s.ep.Bind(transport.Address{Addr: p}); err != nil {
   232  			asn.Remove(name, s)
   233  			return err
   234  		}
   235  		// The socket has been successfully bound. We can update the following.
   236  		s.abstractName = name
   237  		s.abstractBound = true
   238  		return nil
   239  	}
   240  
   241  	// See net/unix/af_unix.c:unix_bind_bsd().
   242  	path := fspath.Parse(p)
   243  	root := t.FSContext().RootDirectory()
   244  	defer root.DecRef(t)
   245  	start := root
   246  	relPath := !path.Absolute
   247  	if relPath {
   248  		start = t.FSContext().WorkingDirectory()
   249  		defer start.DecRef(t)
   250  	}
   251  	pop := vfs.PathOperation{
   252  		Root:  root,
   253  		Start: start,
   254  		Path:  path,
   255  	}
   256  	stat, err := s.vfsfd.Stat(t, vfs.StatOptions{Mask: linux.STATX_MODE})
   257  	if err != nil {
   258  		return syserr.FromError(err)
   259  	}
   260  	err = t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{
   261  		Mode:     linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()),
   262  		Endpoint: bep,
   263  	})
   264  	if linuxerr.Equals(linuxerr.EEXIST, err) {
   265  		return syserr.ErrAddressInUse
   266  	}
   267  	if err != nil {
   268  		return syserr.FromError(err)
   269  	}
   270  	if err := s.ep.Bind(transport.Address{Addr: p}); err != nil {
   271  		if unlinkErr := t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &pop); unlinkErr != nil {
   272  			log.Warningf("failed to unlink socket file created for bind(%q): %v", p, unlinkErr)
   273  		}
   274  		return err
   275  	}
   276  	return nil
   277  }
   278  
   279  // Ioctl implements vfs.FileDescriptionImpl.
   280  func (s *Socket) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
   281  	return netstack.Ioctl(ctx, s.ep, uio, sysno, args)
   282  }
   283  
   284  // PRead implements vfs.FileDescriptionImpl.
   285  func (s *Socket) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
   286  	return 0, linuxerr.ESPIPE
   287  }
   288  
   289  // Read implements vfs.FileDescriptionImpl.
   290  func (s *Socket) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
   291  	// All flags other than RWF_NOWAIT should be ignored.
   292  	// TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
   293  	if opts.Flags != 0 {
   294  		return 0, linuxerr.EOPNOTSUPP
   295  	}
   296  
   297  	if dst.NumBytes() == 0 {
   298  		return 0, nil
   299  	}
   300  	r := &EndpointReader{
   301  		Ctx:       ctx,
   302  		Endpoint:  s.ep,
   303  		NumRights: 0,
   304  		Peek:      false,
   305  	}
   306  	n, err := dst.CopyOutFrom(ctx, r)
   307  	if r.Notify != nil {
   308  		r.Notify()
   309  	}
   310  	// Drop any unused rights messages.
   311  	for _, rm := range r.UnusedRights {
   312  		rm.Release(ctx)
   313  	}
   314  	// Drop control messages.
   315  	r.Control.Release(ctx)
   316  	return n, err
   317  }
   318  
   319  // PWrite implements vfs.FileDescriptionImpl.
   320  func (s *Socket) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
   321  	return 0, linuxerr.ESPIPE
   322  }
   323  
   324  // Write implements vfs.FileDescriptionImpl.
   325  func (s *Socket) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
   326  	// All flags other than RWF_NOWAIT should be ignored.
   327  	// TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
   328  	if opts.Flags != 0 {
   329  		return 0, linuxerr.EOPNOTSUPP
   330  	}
   331  
   332  	t := kernel.TaskFromContext(ctx)
   333  	ctrl := control.New(t, s.ep)
   334  
   335  	if src.NumBytes() == 0 {
   336  		nInt, notify, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
   337  		if notify != nil {
   338  			notify()
   339  		}
   340  		return int64(nInt), err.ToError()
   341  	}
   342  
   343  	w := &EndpointWriter{
   344  		Ctx:      ctx,
   345  		Endpoint: s.ep,
   346  		Control:  ctrl,
   347  		To:       nil,
   348  	}
   349  
   350  	n, err := src.CopyInTo(ctx, w)
   351  	if w.Notify != nil {
   352  		w.Notify()
   353  	}
   354  	return n, err
   355  
   356  }
   357  
   358  // Epollable implements FileDescriptionImpl.Epollable.
   359  func (s *Socket) Epollable() bool {
   360  	return true
   361  }
   362  
   363  // SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
   364  // a transport.Endpoint.
   365  func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
   366  	return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
   367  }
   368  
   369  // provider is a unix domain socket provider.
   370  type provider struct{}
   371  
   372  func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
   373  	// Check arguments.
   374  	if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
   375  		return nil, syserr.ErrProtocolNotSupported
   376  	}
   377  
   378  	// Create the endpoint and socket.
   379  	var ep transport.Endpoint
   380  	switch stype {
   381  	case linux.SOCK_DGRAM, linux.SOCK_RAW:
   382  		ep = transport.NewConnectionless(t)
   383  	case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
   384  		ep = transport.NewConnectioned(t, stype, t.Kernel())
   385  	default:
   386  		return nil, syserr.ErrInvalidArgument
   387  	}
   388  
   389  	f, err := NewSockfsFile(t, ep, stype)
   390  	if err != nil {
   391  		ep.Close(t)
   392  		return nil, err
   393  	}
   394  	return f, nil
   395  }
   396  
   397  // Pair creates a new pair of AF_UNIX connected sockets.
   398  func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
   399  	// Check arguments.
   400  	if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
   401  		return nil, nil, syserr.ErrProtocolNotSupported
   402  	}
   403  
   404  	switch stype {
   405  	case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
   406  		// Ok
   407  	default:
   408  		return nil, nil, syserr.ErrInvalidArgument
   409  	}
   410  
   411  	// Create the endpoints and sockets.
   412  	ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
   413  	s1, err := NewSockfsFile(t, ep1, stype)
   414  	if err != nil {
   415  		ep1.Close(t)
   416  		ep2.Close(t)
   417  		return nil, nil, err
   418  	}
   419  	s2, err := NewSockfsFile(t, ep2, stype)
   420  	if err != nil {
   421  		s1.DecRef(t)
   422  		ep2.Close(t)
   423  		return nil, nil, err
   424  	}
   425  
   426  	return s1, s2, nil
   427  }
   428  
   429  func (s *Socket) isPacket() bool {
   430  	switch s.stype {
   431  	case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
   432  		return true
   433  	case linux.SOCK_STREAM:
   434  		return false
   435  	default:
   436  		// We shouldn't have allowed any other socket types during creation.
   437  		panic(fmt.Sprintf("Invalid socket type %d", s.stype))
   438  	}
   439  }
   440  
   441  // Endpoint extracts the transport.Endpoint.
   442  func (s *Socket) Endpoint() transport.Endpoint {
   443  	return s.ep
   444  }
   445  
   446  // extractPath extracts and validates the address.
   447  func extractPath(sockaddr []byte) (string, *syserr.Error) {
   448  	addr, family, err := addressAndFamily(sockaddr)
   449  	if err != nil {
   450  		if err == syserr.ErrAddressFamilyNotSupported {
   451  			err = syserr.ErrInvalidArgument
   452  		}
   453  		return "", err
   454  	}
   455  	if family != linux.AF_UNIX {
   456  		return "", syserr.ErrInvalidArgument
   457  	}
   458  
   459  	// The address is trimmed by GetAddress.
   460  	p := addr.Addr
   461  	if len(p) > 0 && p[len(p)-1] == '/' {
   462  		// Weird, they tried to bind '/a/b/c/'?
   463  		return "", syserr.ErrIsDir
   464  	}
   465  
   466  	return p, nil
   467  }
   468  
   469  func addressAndFamily(addr []byte) (transport.Address, uint16, *syserr.Error) {
   470  	// Make sure we have at least 2 bytes for the address family.
   471  	if len(addr) < 2 {
   472  		return transport.Address{}, 0, syserr.ErrInvalidArgument
   473  	}
   474  
   475  	// Get the rest of the fields based on the address family.
   476  	switch family := hostarch.ByteOrder.Uint16(addr); family {
   477  	case linux.AF_UNIX:
   478  		path := addr[2:]
   479  		if len(path) > linux.UnixPathMax {
   480  			return transport.Address{}, family, syserr.ErrInvalidArgument
   481  		}
   482  		// Drop the terminating NUL (if one exists) and everything after
   483  		// it for filesystem (non-abstract) addresses.
   484  		if len(path) > 0 && path[0] != 0 {
   485  			if n := bytes.IndexByte(path[1:], 0); n >= 0 {
   486  				path = path[:n+1]
   487  			}
   488  		}
   489  		return transport.Address{
   490  			Addr: string(path),
   491  		}, family, nil
   492  	}
   493  	return transport.Address{}, 0, syserr.ErrAddressFamilyNotSupported
   494  }
   495  
   496  // GetPeerName implements the linux syscall getpeername(2) for sockets backed by
   497  // a transport.Endpoint.
   498  func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
   499  	addr, err := s.ep.GetRemoteAddress()
   500  	if err != nil {
   501  		return nil, 0, syserr.TranslateNetstackError(err)
   502  	}
   503  
   504  	a, l := convertAddress(addr)
   505  	return a, l, nil
   506  }
   507  
   508  // GetSockName implements the linux syscall getsockname(2) for sockets backed by
   509  // a transport.Endpoint.
   510  func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
   511  	addr, err := s.ep.GetLocalAddress()
   512  	if err != nil {
   513  		return nil, 0, syserr.TranslateNetstackError(err)
   514  	}
   515  
   516  	a, l := convertAddress(addr)
   517  	return a, l, nil
   518  }
   519  
   520  // Listen implements the linux syscall listen(2) for sockets backed by
   521  // a transport.Endpoint.
   522  func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error {
   523  	return s.ep.Listen(t, backlog)
   524  }
   525  
   526  // extractEndpoint retrieves the transport.BoundEndpoint associated with a Unix
   527  // socket path. The Release must be called on the transport.BoundEndpoint when
   528  // the caller is done with it.
   529  func (s *Socket) extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, *syserr.Error) {
   530  	path, err := extractPath(sockaddr)
   531  	if err != nil {
   532  		return nil, err
   533  	}
   534  	if path == "" {
   535  		// Not allowed.
   536  		return nil, syserr.ErrInvalidArgument
   537  	}
   538  
   539  	// Is it abstract?
   540  	if path[0] == 0 {
   541  		ep := s.namespace.AbstractSockets().BoundEndpoint(path[1:])
   542  		if ep == nil {
   543  			// No socket found.
   544  			return nil, syserr.ErrConnectionRefused
   545  		}
   546  
   547  		return ep, nil
   548  	}
   549  
   550  	p := fspath.Parse(path)
   551  	root := t.FSContext().RootDirectory()
   552  	start := root
   553  	relPath := !p.Absolute
   554  	if relPath {
   555  		start = t.FSContext().WorkingDirectory()
   556  	}
   557  	pop := vfs.PathOperation{
   558  		Root:               root,
   559  		Start:              start,
   560  		Path:               p,
   561  		FollowFinalSymlink: true,
   562  	}
   563  	ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path})
   564  	root.DecRef(t)
   565  	if relPath {
   566  		start.DecRef(t)
   567  	}
   568  	if e != nil {
   569  		return nil, syserr.FromError(e)
   570  	}
   571  	return ep, nil
   572  }
   573  
   574  // Connect implements the linux syscall connect(2) for unix sockets.
   575  func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
   576  	ep, err := s.extractEndpoint(t, sockaddr)
   577  	if err != nil {
   578  		return err
   579  	}
   580  	defer ep.Release(t)
   581  
   582  	// Connect the server endpoint.
   583  	err = s.ep.Connect(t, ep)
   584  
   585  	if err == syserr.ErrWrongProtocolForSocket {
   586  		// Linux for abstract sockets returns ErrConnectionRefused
   587  		// instead of ErrWrongProtocolForSocket.
   588  		path, _ := extractPath(sockaddr)
   589  		if len(path) > 0 && path[0] == 0 {
   590  			err = syserr.ErrConnectionRefused
   591  		}
   592  	}
   593  
   594  	return err
   595  }
   596  
// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
// a transport.Endpoint.
//
// to is an optional destination address. It is ignored for SOCK_SEQPACKET.
// For SOCK_STREAM it is rejected with ErrAlreadyConnected if the socket is
// connected, otherwise ErrNotSupported. For other types it is resolved to a
// bound endpoint. If the first write attempt would block and MSG_DONTWAIT is
// not set, SendMsg waits for writability (honoring haveDeadline/deadline)
// and retries. It stops when an attempt returns something other than
// ErrWouldBlock or the wait fails; a timeout is reported as ErrWouldBlock.
// The returned count is the total number of bytes written across all
// attempts.
func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
	w := EndpointWriter{
		Ctx:      t,
		Endpoint: s.ep,
		Control:  controlMessages.Unix,
		To:       nil,
	}
	if len(to) > 0 {
		switch s.stype {
		case linux.SOCK_SEQPACKET:
			// to is ignored.
		case linux.SOCK_STREAM:
			if s.State() == linux.SS_CONNECTED {
				return 0, syserr.ErrAlreadyConnected
			}
			return 0, syserr.ErrNotSupported
		default:
			ep, err := s.extractEndpoint(t, to)
			if err != nil {
				return 0, err
			}
			// Keep the destination referenced until the whole send,
			// including any blocking retries below, has finished.
			defer ep.Release(t)
			w.To = ep

			// Attach the sender's credentials when the destination asks
			// for them and the caller did not supply any.
			if ep.Passcred() && w.Control.Credentials == nil {
				w.Control.Credentials = control.MakeCreds(t)
			}
		}
	}

	n, err := src.CopyInTo(t, &w)
	if w.Notify != nil {
		w.Notify()
	}
	if err != linuxerr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
		return int(n), syserr.FromError(err)
	}

	// Only send SCM Rights once (see net/unix/af_unix.c:unix_stream_sendmsg).
	w.Control.Rights = nil

	// We'll have to block. Register for notification and keep trying to
	// send all the data.
	e, ch := waiter.NewChannelEntry(waiter.WritableEvents)
	s.EventRegister(&e)
	defer s.EventUnregister(&e)

	total := n
	for {
		// Shorten src to reflect bytes previously written.
		// On the first iteration n is the count from the initial attempt.
		src = src.DropFirst64(n)

		n, err = src.CopyInTo(t, &w)
		if w.Notify != nil {
			w.Notify()
		}
		total += n
		if err != linuxerr.ErrWouldBlock {
			break
		}

		if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
			// A deadline expiry is reported as EWOULDBLOCK, like a
			// non-blocking send that could not complete.
			if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
				err = linuxerr.ErrWouldBlock
			}
			break
		}
	}

	return int(total), syserr.FromError(err)
}
   670  
   671  // Passcred implements transport.Credentialer.Passcred.
   672  func (s *Socket) Passcred() bool {
   673  	return s.ep.Passcred()
   674  }
   675  
   676  // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
   677  func (s *Socket) ConnectedPasscred() bool {
   678  	return s.ep.ConnectedPasscred()
   679  }
   680  
   681  // Readiness implements waiter.Waitable.Readiness.
   682  func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
   683  	return s.ep.Readiness(mask)
   684  }
   685  
   686  // EventRegister implements waiter.Waitable.EventRegister.
   687  func (s *Socket) EventRegister(e *waiter.Entry) error {
   688  	return s.ep.EventRegister(e)
   689  }
   690  
   691  // EventUnregister implements waiter.Waitable.EventUnregister.
   692  func (s *Socket) EventUnregister(e *waiter.Entry) {
   693  	s.ep.EventUnregister(e)
   694  }
   695  
   696  // Shutdown implements the linux syscall shutdown(2) for sockets backed by
   697  // a transport.Endpoint.
   698  func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error {
   699  	f, err := netstack.ConvertShutdown(how)
   700  	if err != nil {
   701  		return err
   702  	}
   703  
   704  	// Issue shutdown request.
   705  	return s.ep.Shutdown(f)
   706  }
   707  
   708  // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
   709  // a transport.Endpoint.
   710  func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
   711  	trunc := flags&linux.MSG_TRUNC != 0
   712  	peek := flags&linux.MSG_PEEK != 0
   713  	dontWait := flags&linux.MSG_DONTWAIT != 0
   714  	waitAll := flags&linux.MSG_WAITALL != 0
   715  	isPacket := s.isPacket()
   716  
   717  	// Calculate the number of FDs for which we have space and if we are
   718  	// requesting credentials.
   719  	var wantCreds bool
   720  	rightsLen := int(controlDataLen) - unix.SizeofCmsghdr
   721  	if s.Passcred() {
   722  		// Credentials take priority if they are enabled and there is space.
   723  		wantCreds = rightsLen > 0
   724  		if !wantCreds {
   725  			msgFlags |= linux.MSG_CTRUNC
   726  		}
   727  		credLen := unix.CmsgSpace(unix.SizeofUcred)
   728  		rightsLen -= credLen
   729  	}
   730  	// FDs are 32 bit (4 byte) ints.
   731  	numRights := rightsLen / 4
   732  	if numRights < 0 {
   733  		numRights = 0
   734  	}
   735  
   736  	r := EndpointReader{
   737  		Ctx:       t,
   738  		Endpoint:  s.ep,
   739  		Creds:     wantCreds,
   740  		NumRights: numRights,
   741  		Peek:      peek,
   742  	}
   743  
   744  	doRead := func() (int64, error) {
   745  		n, err := dst.CopyOutFrom(t, &r)
   746  		if r.Notify != nil {
   747  			r.Notify()
   748  		}
   749  		return n, err
   750  	}
   751  
   752  	// Drop any unused rights messages after reading.
   753  	defer func() {
   754  		for _, rm := range r.UnusedRights {
   755  			rm.Release(t)
   756  		}
   757  	}()
   758  
   759  	// If MSG_TRUNC is set with a zero byte destination then we still need
   760  	// to read the message and discard it, or in the case where MSG_PEEK is
   761  	// set, leave it be. In both cases the full message length must be
   762  	// returned.
   763  	if trunc && dst.Addrs.NumBytes() == 0 {
   764  		doRead = func() (int64, error) {
   765  			err := r.Truncate()
   766  			// Always return zero for bytes read since the destination size is
   767  			// zero.
   768  			return 0, err
   769  		}
   770  
   771  	}
   772  
   773  	var total int64
   774  	if n, err := doRead(); err != linuxerr.ErrWouldBlock || dontWait {
   775  		var from linux.SockAddr
   776  		var fromLen uint32
   777  		if senderRequested && len([]byte(r.From.Addr)) != 0 {
   778  			from, fromLen = convertAddress(r.From)
   779  		}
   780  
   781  		if r.ControlTrunc {
   782  			msgFlags |= linux.MSG_CTRUNC
   783  		}
   784  
   785  		if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() {
   786  			if isPacket && n < int64(r.MsgSize) {
   787  				msgFlags |= linux.MSG_TRUNC
   788  			}
   789  
   790  			if trunc {
   791  				n = int64(r.MsgSize)
   792  			}
   793  
   794  			return int(n), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
   795  		}
   796  
   797  		// Don't overwrite any data we received.
   798  		dst = dst.DropFirst64(n)
   799  		total += n
   800  	}
   801  
   802  	// We'll have to block. Register for notification and keep trying to
   803  	// send all the data.
   804  	e, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
   805  	s.EventRegister(&e)
   806  	defer s.EventUnregister(&e)
   807  
   808  	for {
   809  		if n, err := doRead(); err != linuxerr.ErrWouldBlock {
   810  			var from linux.SockAddr
   811  			var fromLen uint32
   812  			if senderRequested {
   813  				from, fromLen = convertAddress(r.From)
   814  			}
   815  
   816  			if r.ControlTrunc {
   817  				msgFlags |= linux.MSG_CTRUNC
   818  			}
   819  
   820  			if trunc {
   821  				// n and r.MsgSize are the same for streams.
   822  				total += int64(r.MsgSize)
   823  			} else {
   824  				total += n
   825  			}
   826  
   827  			streamPeerClosed := s.stype == linux.SOCK_STREAM && n == 0 && err == nil
   828  			if err != nil || !waitAll || isPacket || n >= dst.NumBytes() || streamPeerClosed {
   829  				if total > 0 {
   830  					err = nil
   831  				}
   832  				if isPacket && n < int64(r.MsgSize) {
   833  					msgFlags |= linux.MSG_TRUNC
   834  				}
   835  				return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
   836  			}
   837  
   838  			// Don't overwrite any data we received.
   839  			dst = dst.DropFirst64(n)
   840  		}
   841  
   842  		if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
   843  			if total > 0 {
   844  				err = nil
   845  			}
   846  			if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
   847  				return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
   848  			}
   849  			return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
   850  		}
   851  	}
   852  }
   853  
   854  // State implements socket.Socket.State.
   855  func (s *Socket) State() uint32 {
   856  	return s.ep.State()
   857  }
   858  
   859  // Type implements socket.Socket.Type.
   860  func (s *Socket) Type() (family int, skType linux.SockType, protocol int) {
   861  	// Unix domain sockets always have a protocol of 0.
   862  	return linux.AF_UNIX, s.stype, 0
   863  }
   864  
   865  func convertAddress(addr transport.Address) (linux.SockAddr, uint32) {
   866  	var out linux.SockAddrUnix
   867  	out.Family = linux.AF_UNIX
   868  	l := len([]byte(addr.Addr))
   869  	for i := 0; i < l; i++ {
   870  		out.Path[i] = int8(addr.Addr[i])
   871  	}
   872  
   873  	// Linux returns the used length of the address struct (including the
   874  	// null terminator) for filesystem paths. The Family field is 2 bytes.
   875  	// It is sometimes allowed to exclude the null terminator if the
   876  	// address length is the max. Abstract and empty paths always return
   877  	// the full exact length.
   878  	if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
   879  		return &out, uint32(2 + l)
   880  	}
   881  	return &out, uint32(3 + l)
   882  
   883  }
   884  
   885  func init() {
   886  	socket.RegisterProvider(linux.AF_UNIX, &provider{})
   887  }