github.com/ttpreport/gvisor-ligolo@v0.0.0-20240123134145-a858404967ba/pkg/sentry/socket/unix/unix.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package unix provides an implementation of the socket.Socket interface for
    16  // the AF_UNIX protocol family.
    17  package unix
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  
    23  	"github.com/ttpreport/gvisor-ligolo/pkg/abi/linux"
    24  	"github.com/ttpreport/gvisor-ligolo/pkg/context"
    25  	"github.com/ttpreport/gvisor-ligolo/pkg/errors/linuxerr"
    26  	"github.com/ttpreport/gvisor-ligolo/pkg/fspath"
    27  	"github.com/ttpreport/gvisor-ligolo/pkg/hostarch"
    28  	"github.com/ttpreport/gvisor-ligolo/pkg/log"
    29  	"github.com/ttpreport/gvisor-ligolo/pkg/marshal"
    30  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/arch"
    31  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/fsimpl/sockfs"
    32  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/kernel"
    33  	ktime "github.com/ttpreport/gvisor-ligolo/pkg/sentry/kernel/time"
    34  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/socket"
    35  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/socket/control"
    36  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/socket/netstack"
    37  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/socket/unix/transport"
    38  	"github.com/ttpreport/gvisor-ligolo/pkg/sentry/vfs"
    39  	"github.com/ttpreport/gvisor-ligolo/pkg/syserr"
    40  	"github.com/ttpreport/gvisor-ligolo/pkg/usermem"
    41  	"github.com/ttpreport/gvisor-ligolo/pkg/waiter"
    42  	"golang.org/x/sys/unix"
    43  )
    44  
// Socket implements socket.Socket (and by extension,
// vfs.FileDescriptionImpl) for Unix sockets.
//
// A Socket couples a VFS file description with the transport.Endpoint that
// carries its data; most of its methods forward to that endpoint.
//
// +stateify savable
type Socket struct {
	// vfsfd is the file description for this socket. NewFileDescription
	// initializes it with this Socket as its implementation.
	vfsfd vfs.FileDescription
	vfs.FileDescriptionDefaultImpl
	vfs.DentryMetadataFileDescriptionImpl
	vfs.LockFD
	socket.SendReceiveTimeout
	socketRefs

	// ep is the transport endpoint that socket operations are forwarded to.
	// stype is the socket type; NewFileDescription stores SOCK_RAW as
	// SOCK_DGRAM.
	ep    transport.Endpoint
	stype linux.SockType

	// abstractName and abstractNamespace indicate the name and namespace of the
	// socket if it is bound to an abstract socket namespace. Once the socket is
	// bound, they cannot be modified. They are set by Bind and used by DecRef
	// to remove the socket from the namespace.
	abstractName      string
	abstractNamespace *kernel.AbstractSocketNamespace
}
    66  
    67  var _ = socket.Socket(&Socket{})
    68  
    69  // NewSockfsFile creates a new socket file in the global sockfs mount and
    70  // returns a corresponding file description.
    71  func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
    72  	mnt := t.Kernel().SocketMount()
    73  	d := sockfs.NewDentry(t, mnt)
    74  	defer d.DecRef(t)
    75  
    76  	fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{})
    77  	if err != nil {
    78  		return nil, syserr.FromError(err)
    79  	}
    80  	return fd, nil
    81  }
    82  
    83  // NewFileDescription creates and returns a socket file description
    84  // corresponding to the given mount and dentry.
    85  func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
    86  	// You can create AF_UNIX, SOCK_RAW sockets. They're the same as
    87  	// SOCK_DGRAM and don't require CAP_NET_RAW.
    88  	if stype == linux.SOCK_RAW {
    89  		stype = linux.SOCK_DGRAM
    90  	}
    91  
    92  	sock := &Socket{
    93  		ep:    ep,
    94  		stype: stype,
    95  	}
    96  	sock.InitRefs()
    97  	sock.LockFD.Init(locks)
    98  	vfsfd := &sock.vfsfd
    99  	if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
   100  		DenyPRead:         true,
   101  		DenyPWrite:        true,
   102  		UseDentryMetadata: true,
   103  	}); err != nil {
   104  		return nil, err
   105  	}
   106  	return vfsfd, nil
   107  }
   108  
   109  // DecRef implements RefCounter.DecRef.
   110  func (s *Socket) DecRef(ctx context.Context) {
   111  	s.socketRefs.DecRef(func() {
   112  		kernel.KernelFromContext(ctx).DeleteSocket(&s.vfsfd)
   113  		s.ep.Close(ctx)
   114  		if s.abstractNamespace != nil {
   115  			s.abstractNamespace.Remove(s.abstractName, s)
   116  		}
   117  	})
   118  }
   119  
   120  // Release implements vfs.FileDescriptionImpl.Release.
   121  func (s *Socket) Release(ctx context.Context) {
   122  	// Release only decrements a reference on s because s may be referenced in
   123  	// the abstract socket namespace.
   124  	s.DecRef(ctx)
   125  }
   126  
   127  // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
   128  // a transport.Endpoint.
   129  func (s *Socket) GetSockOpt(t *kernel.Task, level, name int, outPtr hostarch.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
   130  	return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen)
   131  }
   132  
   133  // blockingAccept implements a blocking version of accept(2), that is, if no
   134  // connections are ready to be accept, it will block until one becomes ready.
   135  func (s *Socket) blockingAccept(t *kernel.Task, peerAddr *transport.Address) (transport.Endpoint, *syserr.Error) {
   136  	// Register for notifications.
   137  	e, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
   138  	s.EventRegister(&e)
   139  	defer s.EventUnregister(&e)
   140  
   141  	// Try to accept the connection; if it fails, then wait until we get a
   142  	// notification.
   143  	for {
   144  		if ep, err := s.ep.Accept(t, peerAddr); err != syserr.ErrWouldBlock {
   145  			return ep, err
   146  		}
   147  
   148  		if err := t.Block(ch); err != nil {
   149  			return nil, syserr.FromError(err)
   150  		}
   151  	}
   152  }
   153  
   154  // Accept implements the linux syscall accept(2) for sockets backed by
   155  // a transport.Endpoint.
   156  func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
   157  	var peerAddr *transport.Address
   158  	if peerRequested {
   159  		peerAddr = &transport.Address{}
   160  	}
   161  	ep, err := s.ep.Accept(t, peerAddr)
   162  	if err != nil {
   163  		if err != syserr.ErrWouldBlock || !blocking {
   164  			return 0, nil, 0, err
   165  		}
   166  
   167  		var err *syserr.Error
   168  		ep, err = s.blockingAccept(t, peerAddr)
   169  		if err != nil {
   170  			return 0, nil, 0, err
   171  		}
   172  	}
   173  
   174  	ns, err := NewSockfsFile(t, ep, s.stype)
   175  	if err != nil {
   176  		return 0, nil, 0, err
   177  	}
   178  	defer ns.DecRef(t)
   179  
   180  	if flags&linux.SOCK_NONBLOCK != 0 {
   181  		ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK)
   182  	}
   183  
   184  	var addr linux.SockAddr
   185  	var addrLen uint32
   186  	if peerAddr != nil {
   187  		addr, addrLen = convertAddress(*peerAddr)
   188  	}
   189  
   190  	fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
   191  		CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
   192  	})
   193  	if e != nil {
   194  		return 0, nil, 0, syserr.FromError(e)
   195  	}
   196  
   197  	t.Kernel().RecordSocket(ns)
   198  	return fd, addr, addrLen, nil
   199  }
   200  
   201  // Bind implements the linux syscall bind(2) for unix sockets.
   202  func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
   203  	p, e := extractPath(sockaddr)
   204  	if e != nil {
   205  		return e
   206  	}
   207  
   208  	bep, ok := s.ep.(transport.BoundEndpoint)
   209  	if !ok {
   210  		// This socket can't be bound.
   211  		return syserr.ErrInvalidArgument
   212  	}
   213  
   214  	if p[0] == 0 {
   215  		// Abstract socket. See net/unix/af_unix.c:unix_bind_abstract().
   216  		if t.IsNetworkNamespaced() {
   217  			return syserr.ErrInvalidEndpointState
   218  		}
   219  		asn := t.AbstractSockets()
   220  		name := p[1:]
   221  		if err := asn.Bind(t, name, bep, s); err != nil {
   222  			// syserr.ErrPortInUse corresponds to EADDRINUSE.
   223  			return syserr.ErrPortInUse
   224  		}
   225  		if err := s.ep.Bind(transport.Address{Addr: p}); err != nil {
   226  			asn.Remove(name, s)
   227  			return err
   228  		}
   229  		// The socket has been successfully bound. We can update the following.
   230  		s.abstractName = name
   231  		s.abstractNamespace = asn
   232  		return nil
   233  	}
   234  
   235  	// See net/unix/af_unix.c:unix_bind_bsd().
   236  	path := fspath.Parse(p)
   237  	root := t.FSContext().RootDirectory()
   238  	defer root.DecRef(t)
   239  	start := root
   240  	relPath := !path.Absolute
   241  	if relPath {
   242  		start = t.FSContext().WorkingDirectory()
   243  		defer start.DecRef(t)
   244  	}
   245  	pop := vfs.PathOperation{
   246  		Root:  root,
   247  		Start: start,
   248  		Path:  path,
   249  	}
   250  	stat, err := s.vfsfd.Stat(t, vfs.StatOptions{Mask: linux.STATX_MODE})
   251  	if err != nil {
   252  		return syserr.FromError(err)
   253  	}
   254  	err = t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{
   255  		Mode:     linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()),
   256  		Endpoint: bep,
   257  	})
   258  	if linuxerr.Equals(linuxerr.EEXIST, err) {
   259  		return syserr.ErrAddressInUse
   260  	}
   261  	if err != nil {
   262  		return syserr.FromError(err)
   263  	}
   264  	if err := s.ep.Bind(transport.Address{Addr: p}); err != nil {
   265  		if unlinkErr := t.Kernel().VFS().UnlinkAt(t, t.Credentials(), &pop); unlinkErr != nil {
   266  			log.Warningf("failed to unlink socket file created for bind(%q): %v", p, unlinkErr)
   267  		}
   268  		return err
   269  	}
   270  	return nil
   271  }
   272  
   273  // Ioctl implements vfs.FileDescriptionImpl.
   274  func (s *Socket) Ioctl(ctx context.Context, uio usermem.IO, sysno uintptr, args arch.SyscallArguments) (uintptr, error) {
   275  	return netstack.Ioctl(ctx, s.ep, uio, sysno, args)
   276  }
   277  
   278  // PRead implements vfs.FileDescriptionImpl.
   279  func (s *Socket) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
   280  	return 0, linuxerr.ESPIPE
   281  }
   282  
   283  // Read implements vfs.FileDescriptionImpl.
   284  func (s *Socket) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
   285  	// All flags other than RWF_NOWAIT should be ignored.
   286  	// TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
   287  	if opts.Flags != 0 {
   288  		return 0, linuxerr.EOPNOTSUPP
   289  	}
   290  
   291  	if dst.NumBytes() == 0 {
   292  		return 0, nil
   293  	}
   294  	r := &EndpointReader{
   295  		Ctx:       ctx,
   296  		Endpoint:  s.ep,
   297  		NumRights: 0,
   298  		Peek:      false,
   299  		From:      nil,
   300  	}
   301  	n, err := dst.CopyOutFrom(ctx, r)
   302  	if r.Notify != nil {
   303  		r.Notify()
   304  	}
   305  	// Drop control messages.
   306  	r.Control.Release(ctx)
   307  	return n, err
   308  }
   309  
   310  // PWrite implements vfs.FileDescriptionImpl.
   311  func (s *Socket) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
   312  	return 0, linuxerr.ESPIPE
   313  }
   314  
   315  // Write implements vfs.FileDescriptionImpl.
   316  func (s *Socket) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
   317  	// All flags other than RWF_NOWAIT should be ignored.
   318  	// TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
   319  	if opts.Flags != 0 {
   320  		return 0, linuxerr.EOPNOTSUPP
   321  	}
   322  
   323  	t := kernel.TaskFromContext(ctx)
   324  	ctrl := control.New(t, s.ep)
   325  
   326  	if src.NumBytes() == 0 {
   327  		nInt, notify, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
   328  		if notify != nil {
   329  			notify()
   330  		}
   331  		return int64(nInt), err.ToError()
   332  	}
   333  
   334  	w := &EndpointWriter{
   335  		Ctx:      ctx,
   336  		Endpoint: s.ep,
   337  		Control:  ctrl,
   338  		To:       nil,
   339  	}
   340  
   341  	n, err := src.CopyInTo(ctx, w)
   342  	if w.Notify != nil {
   343  		w.Notify()
   344  	}
   345  	return n, err
   346  
   347  }
   348  
   349  // Epollable implements FileDescriptionImpl.Epollable.
   350  func (s *Socket) Epollable() bool {
   351  	return true
   352  }
   353  
   354  // SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
   355  // a transport.Endpoint.
   356  func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
   357  	return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
   358  }
   359  
// provider is a unix domain socket provider. It carries no state; its Socket
// and Pair methods create AF_UNIX sockets and connected socket pairs.
type provider struct{}
   362  
   363  func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
   364  	// Check arguments.
   365  	if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
   366  		return nil, syserr.ErrProtocolNotSupported
   367  	}
   368  
   369  	// Create the endpoint and socket.
   370  	var ep transport.Endpoint
   371  	switch stype {
   372  	case linux.SOCK_DGRAM, linux.SOCK_RAW:
   373  		ep = transport.NewConnectionless(t)
   374  	case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
   375  		ep = transport.NewConnectioned(t, stype, t.Kernel())
   376  	default:
   377  		return nil, syserr.ErrInvalidArgument
   378  	}
   379  
   380  	f, err := NewSockfsFile(t, ep, stype)
   381  	if err != nil {
   382  		ep.Close(t)
   383  		return nil, err
   384  	}
   385  	return f, nil
   386  }
   387  
   388  // Pair creates a new pair of AF_UNIX connected sockets.
   389  func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
   390  	// Check arguments.
   391  	if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
   392  		return nil, nil, syserr.ErrProtocolNotSupported
   393  	}
   394  
   395  	switch stype {
   396  	case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
   397  		// Ok
   398  	default:
   399  		return nil, nil, syserr.ErrInvalidArgument
   400  	}
   401  
   402  	// Create the endpoints and sockets.
   403  	ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
   404  	s1, err := NewSockfsFile(t, ep1, stype)
   405  	if err != nil {
   406  		ep1.Close(t)
   407  		ep2.Close(t)
   408  		return nil, nil, err
   409  	}
   410  	s2, err := NewSockfsFile(t, ep2, stype)
   411  	if err != nil {
   412  		s1.DecRef(t)
   413  		ep2.Close(t)
   414  		return nil, nil, err
   415  	}
   416  
   417  	return s1, s2, nil
   418  }
   419  
   420  func (s *Socket) isPacket() bool {
   421  	switch s.stype {
   422  	case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
   423  		return true
   424  	case linux.SOCK_STREAM:
   425  		return false
   426  	default:
   427  		// We shouldn't have allowed any other socket types during creation.
   428  		panic(fmt.Sprintf("Invalid socket type %d", s.stype))
   429  	}
   430  }
   431  
   432  // Endpoint extracts the transport.Endpoint.
   433  func (s *Socket) Endpoint() transport.Endpoint {
   434  	return s.ep
   435  }
   436  
   437  // extractPath extracts and validates the address.
   438  func extractPath(sockaddr []byte) (string, *syserr.Error) {
   439  	addr, family, err := addressAndFamily(sockaddr)
   440  	if err != nil {
   441  		if err == syserr.ErrAddressFamilyNotSupported {
   442  			err = syserr.ErrInvalidArgument
   443  		}
   444  		return "", err
   445  	}
   446  	if family != linux.AF_UNIX {
   447  		return "", syserr.ErrInvalidArgument
   448  	}
   449  
   450  	// The address is trimmed by GetAddress.
   451  	p := addr.Addr
   452  	if p == "" {
   453  		// Not allowed.
   454  		return "", syserr.ErrInvalidArgument
   455  	}
   456  	if p[len(p)-1] == '/' {
   457  		// Weird, they tried to bind '/a/b/c/'?
   458  		return "", syserr.ErrIsDir
   459  	}
   460  
   461  	return p, nil
   462  }
   463  
   464  func addressAndFamily(addr []byte) (transport.Address, uint16, *syserr.Error) {
   465  	// Make sure we have at least 2 bytes for the address family.
   466  	if len(addr) < 2 {
   467  		return transport.Address{}, 0, syserr.ErrInvalidArgument
   468  	}
   469  
   470  	// Get the rest of the fields based on the address family.
   471  	switch family := hostarch.ByteOrder.Uint16(addr); family {
   472  	case linux.AF_UNIX:
   473  		path := addr[2:]
   474  		if len(path) > linux.UnixPathMax {
   475  			return transport.Address{}, family, syserr.ErrInvalidArgument
   476  		}
   477  		// Drop the terminating NUL (if one exists) and everything after
   478  		// it for filesystem (non-abstract) addresses.
   479  		if len(path) > 0 && path[0] != 0 {
   480  			if n := bytes.IndexByte(path[1:], 0); n >= 0 {
   481  				path = path[:n+1]
   482  			}
   483  		}
   484  		return transport.Address{
   485  			Addr: string(path),
   486  		}, family, nil
   487  	}
   488  	return transport.Address{}, 0, syserr.ErrAddressFamilyNotSupported
   489  }
   490  
   491  // GetPeerName implements the linux syscall getpeername(2) for sockets backed by
   492  // a transport.Endpoint.
   493  func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
   494  	addr, err := s.ep.GetRemoteAddress()
   495  	if err != nil {
   496  		return nil, 0, syserr.TranslateNetstackError(err)
   497  	}
   498  
   499  	a, l := convertAddress(addr)
   500  	return a, l, nil
   501  }
   502  
   503  // GetSockName implements the linux syscall getsockname(2) for sockets backed by
   504  // a transport.Endpoint.
   505  func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
   506  	addr, err := s.ep.GetLocalAddress()
   507  	if err != nil {
   508  		return nil, 0, syserr.TranslateNetstackError(err)
   509  	}
   510  
   511  	a, l := convertAddress(addr)
   512  	return a, l, nil
   513  }
   514  
   515  // Listen implements the linux syscall listen(2) for sockets backed by
   516  // a transport.Endpoint.
   517  func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error {
   518  	return s.ep.Listen(t, backlog)
   519  }
   520  
   521  // extractEndpoint retrieves the transport.BoundEndpoint associated with a Unix
   522  // socket path. The Release must be called on the transport.BoundEndpoint when
   523  // the caller is done with it.
   524  func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, *syserr.Error) {
   525  	path, err := extractPath(sockaddr)
   526  	if err != nil {
   527  		return nil, err
   528  	}
   529  
   530  	// Is it abstract?
   531  	if path[0] == 0 {
   532  		if t.IsNetworkNamespaced() {
   533  			return nil, syserr.ErrInvalidArgument
   534  		}
   535  
   536  		ep := t.AbstractSockets().BoundEndpoint(path[1:])
   537  		if ep == nil {
   538  			// No socket found.
   539  			return nil, syserr.ErrConnectionRefused
   540  		}
   541  
   542  		return ep, nil
   543  	}
   544  
   545  	p := fspath.Parse(path)
   546  	root := t.FSContext().RootDirectory()
   547  	start := root
   548  	relPath := !p.Absolute
   549  	if relPath {
   550  		start = t.FSContext().WorkingDirectory()
   551  	}
   552  	pop := vfs.PathOperation{
   553  		Root:               root,
   554  		Start:              start,
   555  		Path:               p,
   556  		FollowFinalSymlink: true,
   557  	}
   558  	ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path})
   559  	root.DecRef(t)
   560  	if relPath {
   561  		start.DecRef(t)
   562  	}
   563  	if e != nil {
   564  		return nil, syserr.FromError(e)
   565  	}
   566  	return ep, nil
   567  }
   568  
   569  // Connect implements the linux syscall connect(2) for unix sockets.
   570  func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
   571  	ep, err := extractEndpoint(t, sockaddr)
   572  	if err != nil {
   573  		return err
   574  	}
   575  	defer ep.Release(t)
   576  
   577  	// Connect the server endpoint.
   578  	err = s.ep.Connect(t, ep)
   579  
   580  	if err == syserr.ErrWrongProtocolForSocket {
   581  		// Linux for abstract sockets returns ErrConnectionRefused
   582  		// instead of ErrWrongProtocolForSocket.
   583  		path, _ := extractPath(sockaddr)
   584  		if len(path) > 0 && path[0] == 0 {
   585  			err = syserr.ErrConnectionRefused
   586  		}
   587  	}
   588  
   589  	return err
   590  }
   591  
// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
// a transport.Endpoint.
//
// to, when non-empty, is the raw destination sockaddr: it is ignored for
// SOCK_SEQPACKET, rejected for SOCK_STREAM, and otherwise resolved to the
// bound endpoint that receives the message. Of flags, only MSG_DONTWAIT is
// consulted here. haveDeadline and deadline bound each wait of a blocking
// send. controlMessages.Unix is sent along with the data.
//
// It returns the number of bytes written across all attempts together with
// the error of the last attempt; a wait that times out is reported as
// ErrWouldBlock.
func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
	w := EndpointWriter{
		Ctx:      t,
		Endpoint: s.ep,
		Control:  controlMessages.Unix,
		To:       nil,
	}
	if len(to) > 0 {
		switch s.stype {
		case linux.SOCK_SEQPACKET:
			// to is ignored.
		case linux.SOCK_STREAM:
			// Stream sockets never accept an explicit destination; the error
			// depends on whether the socket is already connected.
			if s.State() == linux.SS_CONNECTED {
				return 0, syserr.ErrAlreadyConnected
			}
			return 0, syserr.ErrNotSupported
		default:
			ep, err := extractEndpoint(t, to)
			if err != nil {
				return 0, err
			}
			defer ep.Release(t)
			w.To = ep

			// Attach the sender's credentials when the destination has
			// Passcred set and the caller supplied none.
			if ep.Passcred() && w.Control.Credentials == nil {
				w.Control.Credentials = control.MakeCreds(t)
			}
		}
	}

	// First attempt. Anything other than "would block" on a blocking send is
	// final.
	n, err := src.CopyInTo(t, &w)
	if w.Notify != nil {
		w.Notify()
	}
	if err != linuxerr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
		return int(n), syserr.FromError(err)
	}

	// Only send SCM Rights once (see net/unix/af_unix.c:unix_stream_sendmsg).
	w.Control.Rights = nil

	// We'll have to block. Register for notification and keep trying to
	// send all the data.
	e, ch := waiter.NewChannelEntry(waiter.WritableEvents)
	s.EventRegister(&e)
	defer s.EventUnregister(&e)

	// total accumulates the bytes written by every attempt, including the
	// first one above.
	total := n
	for {
		// Shorten src to reflect bytes previously written.
		src = src.DropFirst64(n)

		n, err = src.CopyInTo(t, &w)
		if w.Notify != nil {
			w.Notify()
		}
		total += n
		if err != linuxerr.ErrWouldBlock {
			break
		}

		// Wait for the writable notification, up to the deadline if any.
		if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
			// A timed-out wait is reported to the caller as "would block".
			if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
				err = linuxerr.ErrWouldBlock
			}
			break
		}
	}

	return int(total), syserr.FromError(err)
}
   665  
   666  // Passcred implements transport.Credentialer.Passcred.
   667  func (s *Socket) Passcred() bool {
   668  	return s.ep.Passcred()
   669  }
   670  
   671  // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
   672  func (s *Socket) ConnectedPasscred() bool {
   673  	return s.ep.ConnectedPasscred()
   674  }
   675  
   676  // Readiness implements waiter.Waitable.Readiness.
   677  func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
   678  	return s.ep.Readiness(mask)
   679  }
   680  
   681  // EventRegister implements waiter.Waitable.EventRegister.
   682  func (s *Socket) EventRegister(e *waiter.Entry) error {
   683  	return s.ep.EventRegister(e)
   684  }
   685  
   686  // EventUnregister implements waiter.Waitable.EventUnregister.
   687  func (s *Socket) EventUnregister(e *waiter.Entry) {
   688  	s.ep.EventUnregister(e)
   689  }
   690  
   691  // Shutdown implements the linux syscall shutdown(2) for sockets backed by
   692  // a transport.Endpoint.
   693  func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error {
   694  	f, err := netstack.ConvertShutdown(how)
   695  	if err != nil {
   696  		return err
   697  	}
   698  
   699  	// Issue shutdown request.
   700  	return s.ep.Shutdown(f)
   701  }
   702  
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// a transport.Endpoint.
//
// Of flags, MSG_TRUNC, MSG_PEEK, MSG_DONTWAIT and MSG_WAITALL are honored.
// haveDeadline and deadline bound each wait of a blocking receive.
// senderRequested asks for the sender's address, and controlDataLen is the
// size of the caller's control-message buffer, which determines whether
// credentials and how many file descriptors can be received.
//
// It returns the byte count, the message flags to report (MSG_TRUNC and/or
// MSG_CTRUNC), the sender address and its length, any received control
// messages, and an error.
func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
	trunc := flags&linux.MSG_TRUNC != 0
	peek := flags&linux.MSG_PEEK != 0
	dontWait := flags&linux.MSG_DONTWAIT != 0
	waitAll := flags&linux.MSG_WAITALL != 0
	isPacket := s.isPacket()

	// Calculate the number of FDs for which we have space and if we are
	// requesting credentials.
	var wantCreds bool
	rightsLen := int(controlDataLen) - unix.SizeofCmsghdr
	if s.Passcred() {
		// Credentials take priority if they are enabled and there is space.
		wantCreds = rightsLen > 0
		if !wantCreds {
			msgFlags |= linux.MSG_CTRUNC
		}
		// The space for a credentials message is deducted whether or not
		// it fit; a negative remainder is clamped to zero rights below.
		credLen := unix.CmsgSpace(unix.SizeofUcred)
		rightsLen -= credLen
	}
	// FDs are 32 bit (4 byte) ints.
	numRights := rightsLen / 4
	if numRights < 0 {
		numRights = 0
	}

	r := EndpointReader{
		Ctx:       t,
		Endpoint:  s.ep,
		Creds:     wantCreds,
		NumRights: numRights,
		Peek:      peek,
	}
	if senderRequested {
		r.From = &transport.Address{}
	}

	// doRead performs one receive attempt into dst and runs the reader's
	// notification callback, if one was set.
	doRead := func() (int64, error) {
		n, err := dst.CopyOutFrom(t, &r)
		if r.Notify != nil {
			r.Notify()
		}
		return n, err
	}

	// If MSG_TRUNC is set with a zero byte destination then we still need
	// to read the message and discard it, or in the case where MSG_PEEK is
	// set, leave it be. In both cases the full message length must be
	// returned.
	if trunc && dst.Addrs.NumBytes() == 0 {
		doRead = func() (int64, error) {
			err := r.Truncate()
			// Always return zero for bytes read since the destination size is
			// zero.
			return 0, err
		}

	}

	// First attempt. This branch is skipped only when the read would block
	// and the caller is willing to wait.
	var total int64
	if n, err := doRead(); err != linuxerr.ErrWouldBlock || dontWait {
		var from linux.SockAddr
		var fromLen uint32
		if r.From != nil && len([]byte(r.From.Addr)) != 0 {
			from, fromLen = convertAddress(*r.From)
		}

		if r.ControlTrunc {
			msgFlags |= linux.MSG_CTRUNC
		}

		// Return now unless this is an error-free, blocking MSG_WAITALL
		// read on a stream socket that has not yet filled dst.
		if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() {
			if isPacket && n < int64(r.MsgSize) {
				msgFlags |= linux.MSG_TRUNC
			}

			// With MSG_TRUNC the message size is reported instead of the
			// number of bytes copied.
			if trunc {
				n = int64(r.MsgSize)
			}

			return int(n), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
		}

		// Don't overwrite any data we received.
		dst = dst.DropFirst64(n)
		total += n
	}

	// We'll have to block. Register for notification and keep trying to
	// receive all the data.
	e, ch := waiter.NewChannelEntry(waiter.ReadableEvents)
	s.EventRegister(&e)
	defer s.EventUnregister(&e)

	for {
		if n, err := doRead(); err != linuxerr.ErrWouldBlock {
			var from linux.SockAddr
			var fromLen uint32
			if r.From != nil {
				from, fromLen = convertAddress(*r.From)
			}

			if r.ControlTrunc {
				msgFlags |= linux.MSG_CTRUNC
			}

			if trunc {
				// n and r.MsgSize are the same for streams.
				total += int64(r.MsgSize)
			} else {
				total += n
			}

			// A stream read that returns zero bytes without an error is
			// treated as the peer having closed, which ends a MSG_WAITALL
			// read early.
			streamPeerClosed := s.stype == linux.SOCK_STREAM && n == 0 && err == nil
			if err != nil || !waitAll || isPacket || n >= dst.NumBytes() || streamPeerClosed {
				// Data already received takes precedence over a late error.
				if total > 0 {
					err = nil
				}
				if isPacket && n < int64(r.MsgSize) {
					msgFlags |= linux.MSG_TRUNC
				}
				return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
			}

			// Don't overwrite any data we received.
			dst = dst.DropFirst64(n)
		}

		// Wait for the readable notification, up to the deadline if any.
		if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
			// If some data was already received, the wait error is cleared
			// so the partial count is returned instead.
			// NOTE(review): the sender address and control messages gathered
			// so far are not returned on this path.
			if total > 0 {
				err = nil
			}
			// A timed-out wait is reported as ErrTryAgain.
			if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
				return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
			}
			return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
		}
	}
}
   844  
   845  // State implements socket.Socket.State.
   846  func (s *Socket) State() uint32 {
   847  	return s.ep.State()
   848  }
   849  
   850  // Type implements socket.Socket.Type.
   851  func (s *Socket) Type() (family int, skType linux.SockType, protocol int) {
   852  	// Unix domain sockets always have a protocol of 0.
   853  	return linux.AF_UNIX, s.stype, 0
   854  }
   855  
   856  func convertAddress(addr transport.Address) (linux.SockAddr, uint32) {
   857  	var out linux.SockAddrUnix
   858  	out.Family = linux.AF_UNIX
   859  	l := len([]byte(addr.Addr))
   860  	for i := 0; i < l; i++ {
   861  		out.Path[i] = int8(addr.Addr[i])
   862  	}
   863  
   864  	// Linux returns the used length of the address struct (including the
   865  	// null terminator) for filesystem paths. The Family field is 2 bytes.
   866  	// It is sometimes allowed to exclude the null terminator if the
   867  	// address length is the max. Abstract and empty paths always return
   868  	// the full exact length.
   869  	if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
   870  		return &out, uint32(2 + l)
   871  	}
   872  	return &out, uint32(3 + l)
   873  
   874  }
   875  
   876  func init() {
   877  	socket.RegisterProvider(linux.AF_UNIX, &provider{})
   878  }