github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/unet/unet.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package unet provides a minimal net package based on Unix Domain Sockets.
    16  //
    17  // This does no pooling, and should only be used for a limited number of
    18  // connections in a Go process. Don't use this package for arbitrary servers.
    19  package unet
    20  
    21  import (
    22  	"errors"
    23  
    24  	"golang.org/x/sys/unix"
    25  	"github.com/nicocha30/gvisor-ligolo/pkg/atomicbitops"
    26  	"github.com/nicocha30/gvisor-ligolo/pkg/eventfd"
    27  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
    28  )
    29  
    30  // backlog is used for the listen request.
    31  const backlog = 16
    32  
    33  // errClosing is returned by wait if the Socket is in the process of closing.
    34  var errClosing = errors.New("Socket is closing")
    35  
    36  // errMessageTruncated indicates that data was lost because the provided buffer
    37  // was too small.
    38  var errMessageTruncated = errors.New("message truncated")
    39  
    40  // socketType returns the appropriate type.
    41  func socketType(packet bool) int {
    42  	if packet {
    43  		return unix.SOCK_SEQPACKET
    44  	}
    45  	return unix.SOCK_STREAM
    46  }
    47  
    48  // socket creates a new host socket.
    49  func socket(packet bool) (int, error) {
    50  	// Make a new socket.
    51  	fd, err := unix.Socket(unix.AF_UNIX, socketType(packet), 0)
    52  	if err != nil {
    53  		return 0, err
    54  	}
    55  
    56  	return fd, nil
    57  }
    58  
// Socket is a connected unix domain socket.
//
// Sockets are constructed by NewSocket (directly, or via SocketPair, Connect
// and ServerSocket.Accept).
type Socket struct {
	// gate protects use of fd.
	//
	// FD users enter the gate via enterFD; Close and Release close it, which
	// blocks until all current users have left.
	gate sync.Gate

	// fd is the bound socket.
	//
	// fd only remains valid if read while within gate. Close and Release swap
	// it to -1; a negative value means the Socket no longer owns an FD.
	fd atomicbitops.Int32

	// efd is an event FD that is signaled when the socket is closing.
	//
	// efd is immutable and remains valid until Close/Release.
	efd eventfd.Eventfd

	// race is an atomic variable used to avoid triggering the race
	// detector. See comment in SocketPair below.
	//
	// Within this file it is only set by SocketPair, which shares one
	// variable between both ends; otherwise it stays nil.
	race *atomicbitops.Int32
}
    78  
    79  // NewSocket returns a socket from an existing FD.
    80  //
    81  // NewSocket takes ownership of fd.
    82  func NewSocket(fd int) (*Socket, error) {
    83  	// fd must be non-blocking for non-blocking unix.Accept in
    84  	// ServerSocket.Accept.
    85  	if err := unix.SetNonblock(fd, true); err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	efd, err := eventfd.Create()
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	return &Socket{
    95  		fd:  atomicbitops.FromInt32(int32(fd)),
    96  		efd: efd,
    97  	}, nil
    98  }
    99  
   100  // finish completes use of s.fd by evicting any waiters, closing the gate, and
   101  // closing the event FD.
   102  func (s *Socket) finish() error {
   103  	// Signal any blocked or future polls.
   104  	if err := s.efd.Notify(); err != nil {
   105  		return err
   106  	}
   107  
   108  	// Close the gate, blocking until all FD users leave.
   109  	s.gate.Close()
   110  
   111  	return s.efd.Close()
   112  }
   113  
   114  // Close closes the socket.
   115  func (s *Socket) Close() error {
   116  	// Set the FD in the socket to -1, to ensure that all future calls to
   117  	// FD/Release get nothing and Close calls return immediately.
   118  	fd := int(s.fd.Swap(-1))
   119  	if fd < 0 {
   120  		// Already closed or closing.
   121  		return unix.EBADF
   122  	}
   123  
   124  	// Shutdown the socket to cancel any pending accepts.
   125  	s.shutdown(fd)
   126  
   127  	if err := s.finish(); err != nil {
   128  		return err
   129  	}
   130  
   131  	return unix.Close(fd)
   132  }
   133  
   134  // Release releases ownership of the socket FD.
   135  //
   136  // The returned FD is non-blocking.
   137  //
   138  // Any concurrent or future callers of Socket methods will receive EBADF.
   139  func (s *Socket) Release() (int, error) {
   140  	// Set the FD in the socket to -1, to ensure that all future calls to
   141  	// FD/Release get nothing and Close calls return immediately.
   142  	fd := int(s.fd.Swap(-1))
   143  	if fd < 0 {
   144  		// Already closed or closing.
   145  		return -1, unix.EBADF
   146  	}
   147  
   148  	if err := s.finish(); err != nil {
   149  		return -1, err
   150  	}
   151  
   152  	return fd, nil
   153  }
   154  
   155  // FD returns the FD for this Socket.
   156  //
   157  // The FD is non-blocking and must not be made blocking.
   158  //
   159  // N.B. os.File.Fd makes the FD blocking. Use of Release instead of FD is
   160  // strongly preferred.
   161  //
   162  // The returned FD cannot be used safely if there may be concurrent callers to
   163  // Close or Release.
   164  //
   165  // Use Release to take ownership of the FD.
   166  func (s *Socket) FD() int {
   167  	return int(s.fd.Load())
   168  }
   169  
   170  // enterFD enters the FD gate and returns the FD value.
   171  //
   172  // If enterFD returns ok, s.gate.Leave must be called when done with the FD.
   173  // Callers may only block while within the gate using s.wait.
   174  //
   175  // The returned FD is guaranteed to remain valid until s.gate.Leave.
   176  func (s *Socket) enterFD() (int, bool) {
   177  	if !s.gate.Enter() {
   178  		return -1, false
   179  	}
   180  
   181  	fd := int(s.fd.Load())
   182  	if fd < 0 {
   183  		s.gate.Leave()
   184  		return -1, false
   185  	}
   186  
   187  	return fd, true
   188  }
   189  
   190  // SocketPair creates a pair of connected sockets.
   191  func SocketPair(packet bool) (*Socket, *Socket, error) {
   192  	// Make a new pair.
   193  	fds, err := unix.Socketpair(unix.AF_UNIX, socketType(packet)|unix.SOCK_CLOEXEC, 0)
   194  	if err != nil {
   195  		return nil, nil, err
   196  	}
   197  
   198  	// race is an atomic variable used to avoid triggering the race
   199  	// detector. We have to fool TSAN into thinking there is a race
   200  	// variable between our two sockets. We only use SocketPair in tests
   201  	// anyway.
   202  	//
   203  	// NOTE(b/27107811): This is purely due to the fact that the raw
   204  	// syscall does not serve as a boundary for the sanitizer.
   205  	a, err := NewSocket(fds[0])
   206  	if err != nil {
   207  		unix.Close(fds[0])
   208  		unix.Close(fds[1])
   209  		return nil, nil, err
   210  	}
   211  	var race atomicbitops.Int32
   212  	a.race = &race
   213  	b, err := NewSocket(fds[1])
   214  	if err != nil {
   215  		a.Close()
   216  		unix.Close(fds[1])
   217  		return nil, nil, err
   218  	}
   219  	b.race = &race
   220  	return a, b, nil
   221  }
   222  
   223  // Connect connects to a server.
   224  func Connect(addr string, packet bool) (*Socket, error) {
   225  	fd, err := socket(packet)
   226  	if err != nil {
   227  		return nil, err
   228  	}
   229  
   230  	// Connect the socket.
   231  	usa := &unix.SockaddrUnix{Name: addr}
   232  	if err := unix.Connect(fd, usa); err != nil {
   233  		unix.Close(fd)
   234  		return nil, err
   235  	}
   236  
   237  	return NewSocket(fd)
   238  }
   239  
// ControlMessage wraps around a byte slice and provides functions for parsing
// it as a Unix Domain Socket control message.
//
// It is used to send FDs (PackFDs) and to receive them (EnableFDs, then
// ExtractFDs or CloseFDs).
type ControlMessage []byte
   243  
   244  // EnableFDs enables receiving FDs via control message.
   245  //
   246  // This guarantees only a MINIMUM number of FDs received. You may receive MORE
   247  // than this due to the way FDs are packed. To be specific, the number of
   248  // receivable buffers will be rounded up to the nearest even number.
   249  //
   250  // This must be called prior to ReadVec if you want to receive FDs.
   251  func (c *ControlMessage) EnableFDs(count int) {
   252  	*c = make([]byte, unix.CmsgSpace(count*4))
   253  }
   254  
   255  // ExtractFDs returns the list of FDs in the control message.
   256  //
   257  // Either this or CloseFDs should be used after EnableFDs.
   258  func (c *ControlMessage) ExtractFDs() ([]int, error) {
   259  	msgs, err := unix.ParseSocketControlMessage(*c)
   260  	if err != nil {
   261  		return nil, err
   262  	}
   263  	var fds []int
   264  	for _, msg := range msgs {
   265  		thisFds, err := unix.ParseUnixRights(&msg)
   266  		if err != nil {
   267  			// Different control message.
   268  			return nil, err
   269  		}
   270  		for _, fd := range thisFds {
   271  			if fd >= 0 {
   272  				fds = append(fds, fd)
   273  			}
   274  		}
   275  	}
   276  	return fds, nil
   277  }
   278  
   279  // CloseFDs closes the list of FDs in the control message.
   280  //
   281  // Either this or ExtractFDs should be used after EnableFDs.
   282  func (c *ControlMessage) CloseFDs() {
   283  	fds, _ := c.ExtractFDs()
   284  	for _, fd := range fds {
   285  		if fd >= 0 {
   286  			unix.Close(fd)
   287  		}
   288  	}
   289  }
   290  
   291  // PackFDs packs the given list of FDs in the control message.
   292  //
   293  // This must be used prior to WriteVec.
   294  func (c *ControlMessage) PackFDs(fds ...int) {
   295  	*c = ControlMessage(unix.UnixRights(fds...))
   296  }
   297  
   298  // UnpackFDs clears the control message.
   299  func (c *ControlMessage) UnpackFDs() {
   300  	*c = nil
   301  }
   302  
// SocketWriter wraps an individual send operation.
//
// It may be used for vectorized writes and/or sending control messages
// (e.g. FDs packed with the embedded ControlMessage's PackFDs). The normal
// entrypoint is WriteVec.
//
// Socket.Writer populates socket, blocking and race (race is copied from the
// Socket). It does not set to. NOTE(review): presumably to is a destination
// address; confirm against WriteVec.
type SocketWriter struct {
	socket   *Socket
	to       []byte
	blocking bool
	race     *atomicbitops.Int32

	ControlMessage
}
   314  
   315  // Writer returns a writer for this socket.
   316  func (s *Socket) Writer(blocking bool) SocketWriter {
   317  	return SocketWriter{socket: s, blocking: blocking, race: s.race}
   318  }
   319  
   320  // Write implements io.Writer.Write.
   321  func (s *Socket) Write(p []byte) (int, error) {
   322  	r := s.Writer(true)
   323  	return r.WriteVec([][]byte{p})
   324  }
   325  
   326  // GetSockOpt gets the given socket option.
   327  func (s *Socket) GetSockOpt(level int, name int, b []byte) (uint32, error) {
   328  	fd, ok := s.enterFD()
   329  	if !ok {
   330  		return 0, unix.EBADF
   331  	}
   332  	defer s.gate.Leave()
   333  
   334  	return getsockopt(fd, level, name, b)
   335  }
   336  
   337  // SetSockOpt sets the given socket option.
   338  func (s *Socket) SetSockOpt(level, name int, b []byte) error {
   339  	fd, ok := s.enterFD()
   340  	if !ok {
   341  		return unix.EBADF
   342  	}
   343  	defer s.gate.Leave()
   344  
   345  	return setsockopt(fd, level, name, b)
   346  }
   347  
   348  // GetSockName returns the socket name.
   349  func (s *Socket) GetSockName() ([]byte, error) {
   350  	fd, ok := s.enterFD()
   351  	if !ok {
   352  		return nil, unix.EBADF
   353  	}
   354  	defer s.gate.Leave()
   355  
   356  	var buf []byte
   357  	l := unix.SizeofSockaddrAny
   358  
   359  	for {
   360  		// If the buffer is not large enough, allocate a new one with the hint.
   361  		buf = make([]byte, l)
   362  		l, err := getsockname(fd, buf)
   363  		if err != nil {
   364  			return nil, err
   365  		}
   366  
   367  		if l <= uint32(len(buf)) {
   368  			return buf[:l], nil
   369  		}
   370  	}
   371  }
   372  
   373  // GetPeerName returns the peer name.
   374  func (s *Socket) GetPeerName() ([]byte, error) {
   375  	fd, ok := s.enterFD()
   376  	if !ok {
   377  		return nil, unix.EBADF
   378  	}
   379  	defer s.gate.Leave()
   380  
   381  	var buf []byte
   382  	l := unix.SizeofSockaddrAny
   383  
   384  	for {
   385  		// See above.
   386  		buf = make([]byte, l)
   387  		l, err := getpeername(fd, buf)
   388  		if err != nil {
   389  			return nil, err
   390  		}
   391  
   392  		if l <= uint32(len(buf)) {
   393  			return buf[:l], nil
   394  		}
   395  	}
   396  }
   397  
   398  // GetPeerCred returns the peer's unix credentials.
   399  func (s *Socket) GetPeerCred() (*unix.Ucred, error) {
   400  	fd, ok := s.enterFD()
   401  	if !ok {
   402  		return nil, unix.EBADF
   403  	}
   404  	defer s.gate.Leave()
   405  
   406  	return unix.GetsockoptUcred(fd, unix.SOL_SOCKET, unix.SO_PEERCRED)
   407  }
   408  
// SocketReader wraps an individual receive operation.
//
// This may be used for doing vectorized reads and/or receiving additional
// control messages (e.g. FDs). The normal entrypoint is ReadVec.
//
// One of ExtractFDs or CloseFDs must be called if EnableFDs is used.
//
// Socket.Reader populates socket, blocking and race (race is copied from the
// Socket). It does not set source. NOTE(review): presumably source is the
// sender address; confirm against ReadVec.
type SocketReader struct {
	socket   *Socket
	source   []byte
	blocking bool
	race     *atomicbitops.Int32

	ControlMessage
}
   423  
   424  // Reader returns a reader for this socket.
   425  func (s *Socket) Reader(blocking bool) SocketReader {
   426  	return SocketReader{socket: s, blocking: blocking, race: s.race}
   427  }
   428  
   429  // Read implements io.Reader.Read.
   430  func (s *Socket) Read(p []byte) (int, error) {
   431  	r := s.Reader(true)
   432  	return r.ReadVec([][]byte{p})
   433  }
   434  
   435  func (s *Socket) shutdown(fd int) error {
   436  	// Shutdown the socket to cancel any pending accepts.
   437  	return unix.Shutdown(fd, unix.SHUT_RDWR)
   438  }
   439  
   440  // Shutdown closes the socket for read and write.
   441  func (s *Socket) Shutdown() error {
   442  	fd, ok := s.enterFD()
   443  	if !ok {
   444  		return unix.EBADF
   445  	}
   446  	defer s.gate.Leave()
   447  
   448  	return s.shutdown(fd)
   449  }
   450  
// ServerSocket is a bound unix domain socket.
//
// It wraps a Socket whose FD is used by Listen and Accept; Close, FD and
// Release delegate to that Socket.
type ServerSocket struct {
	socket *Socket
}
   455  
   456  // NewServerSocket returns a socket from an existing FD.
   457  func NewServerSocket(fd int) (*ServerSocket, error) {
   458  	s, err := NewSocket(fd)
   459  	if err != nil {
   460  		return nil, err
   461  	}
   462  	return &ServerSocket{socket: s}, nil
   463  }
   464  
   465  // Bind creates and binds a new socket.
   466  func Bind(addr string, packet bool) (*ServerSocket, error) {
   467  	fd, err := socket(packet)
   468  	if err != nil {
   469  		return nil, err
   470  	}
   471  
   472  	// Do the bind.
   473  	usa := &unix.SockaddrUnix{Name: addr}
   474  	if err := unix.Bind(fd, usa); err != nil {
   475  		unix.Close(fd)
   476  		return nil, err
   477  	}
   478  
   479  	return NewServerSocket(fd)
   480  }
   481  
   482  // BindAndListen creates, binds and listens on a new socket.
   483  func BindAndListen(addr string, packet bool) (*ServerSocket, error) {
   484  	s, err := Bind(addr, packet)
   485  	if err != nil {
   486  		return nil, err
   487  	}
   488  
   489  	// Start listening.
   490  	if err := s.Listen(); err != nil {
   491  		s.Close()
   492  		return nil, err
   493  	}
   494  
   495  	return s, nil
   496  }
   497  
   498  // Listen starts listening on the socket.
   499  func (s *ServerSocket) Listen() error {
   500  	fd, ok := s.socket.enterFD()
   501  	if !ok {
   502  		return unix.EBADF
   503  	}
   504  	defer s.socket.gate.Leave()
   505  
   506  	return unix.Listen(fd, backlog)
   507  }
   508  
   509  // Accept accepts a new connection.
   510  //
   511  // This is always blocking.
   512  //
   513  // Preconditions:
   514  //   - ServerSocket is listening (Listen called).
   515  func (s *ServerSocket) Accept() (*Socket, error) {
   516  	fd, ok := s.socket.enterFD()
   517  	if !ok {
   518  		return nil, unix.EBADF
   519  	}
   520  	defer s.socket.gate.Leave()
   521  
   522  	for {
   523  		nfd, _, err := unix.Accept(fd)
   524  		switch err {
   525  		case nil:
   526  			return NewSocket(nfd)
   527  		case unix.EAGAIN:
   528  			err = s.socket.wait(false)
   529  			if err == errClosing {
   530  				err = unix.EBADF
   531  			}
   532  		}
   533  		if err != nil {
   534  			return nil, err
   535  		}
   536  	}
   537  }
   538  
   539  // Close closes the server socket.
   540  //
   541  // This must only be called once.
   542  func (s *ServerSocket) Close() error {
   543  	return s.socket.Close()
   544  }
   545  
   546  // FD returns the socket's file descriptor.
   547  //
   548  // See Socket.FD.
   549  func (s *ServerSocket) FD() int {
   550  	return s.socket.FD()
   551  }
   552  
   553  // Release releases ownership of the socket's file descriptor.
   554  //
   555  // See Socket.Release.
   556  func (s *ServerSocket) Release() (int, error) {
   557  	return s.socket.Release()
   558  }