github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/socket/unix/transport/connectioned.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transport
    16  
    17  import (
    18  	"github.com/SagerNet/gvisor/pkg/abi/linux"
    19  	"github.com/SagerNet/gvisor/pkg/context"
    20  	"github.com/SagerNet/gvisor/pkg/sync"
    21  	"github.com/SagerNet/gvisor/pkg/syserr"
    22  	"github.com/SagerNet/gvisor/pkg/tcpip"
    23  	"github.com/SagerNet/gvisor/pkg/waiter"
    24  )
    25  
    26  // UniqueIDProvider generates a sequence of unique identifiers useful for,
    27  // among other things, lock ordering.
    28  type UniqueIDProvider interface {
    29  	// UniqueID returns a new unique identifier.
    30  	UniqueID() uint64
    31  }
    32  
    33  // A ConnectingEndpoint is a connectioned unix endpoint that is attempting to
    34  // establish a bidirectional connection with a BoundEndpoint.
    35  type ConnectingEndpoint interface {
    36  	// ID returns the endpoint's globally unique identifier. This identifier
    37  	// must be used to determine locking order if more than one endpoint is
    38  	// to be locked in the same codepath. The endpoint with the smaller
    39  	// identifier must be locked before endpoints with larger identifiers.
    40  	ID() uint64
    41  
    42  	// Passcred implements socket.Credentialer.Passcred.
    43  	Passcred() bool
    44  
    45  	// Type returns the socket type, typically either SockStream or
    46  	// SockSeqpacket. The connection attempt must be aborted if this
    47  	// value doesn't match the ConnectableEndpoint's type.
    48  	Type() linux.SockType
    49  
    50  	// GetLocalAddress returns the bound path.
    51  	GetLocalAddress() (tcpip.FullAddress, tcpip.Error)
    52  
    53  	// Locker protects the following methods. While locked, only the holder of
    54  	// the lock can change the return value of the protected methods.
    55  	sync.Locker
    56  
    57  	// Connected returns true iff the ConnectingEndpoint is in the connected
    58  	// state. ConnectingEndpoints can only be connected to a single endpoint,
    59  	// so the connection attempt must be aborted if this returns true.
    60  	Connected() bool
    61  
    62  	// Listening returns true iff the ConnectingEndpoint is in the listening
    63  	// state. ConnectingEndpoints cannot make connections while listening, so
    64  	// the connection attempt must be aborted if this returns true.
    65  	Listening() bool
    66  
    67  	// WaiterQueue returns a pointer to the endpoint's waiter queue.
    68  	WaiterQueue() *waiter.Queue
    69  }
    70  
// connectionedEndpoint is a Unix-domain connected or connectable endpoint. It
// implements ConnectingEndpoint, BoundEndpoint and Endpoint.
//
// connectionedEndpoints must be in connected state in order to transfer data.
//
// This implementation includes STREAM and SEQPACKET Unix sockets created with
// socket(2), accept(2) or socketpair(2) and dgram unix sockets created with
// socketpair(2). DGRAM Unix sockets created with socket(2) are handled by the
// connectionless endpoint implementation in this package.
//
// The state is much simpler than a TCP endpoint, so it is not encoded
// explicitly. Instead we enforce the following invariants:
//
// receiver != nil, connected != nil => connected.
// path != "" && acceptedChan == nil => bound, not listening.
// path != "" && acceptedChan != nil => bound and listening.
//
// Only one of these will be true at any moment.
//
// +stateify savable
type connectionedEndpoint struct {
	baseEndpoint

	// id is the unique endpoint identifier. This is used exclusively for
	// lock ordering within connect.
	id uint64

	// idGenerator is used to generate new unique endpoint identifiers, e.g.
	// for the endpoints created by BidirectionalConnect on a listener.
	idGenerator UniqueIDProvider

	// stype is used by connecting sockets to ensure that they are the
	// same type. The value is typically either linux.SOCK_SEQPACKET or
	// linux.SOCK_STREAM.
	stype linux.SockType

	// acceptedChan is the accept queue of a listening endpoint, modelled on
	// the TCP endpoint implementation: BidirectionalConnect enqueues new
	// endpoints, Accept dequeues them, and its capacity is the listen
	// backlog. Note that the sockets in this channel are _already in the
	// connected state_, and have another associated connectionedEndpoint.
	//
	// If nil, then no listen call has been made.
	acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"`
}
   113  
   114  var (
   115  	_ = BoundEndpoint((*connectionedEndpoint)(nil))
   116  	_ = Endpoint((*connectionedEndpoint)(nil))
   117  )
   118  
   119  // NewConnectioned creates a new unbound connectionedEndpoint.
   120  func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint {
   121  	return newConnectioned(ctx, stype, uid)
   122  }
   123  
   124  func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint {
   125  	ep := &connectionedEndpoint{
   126  		baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
   127  		id:           uid.UniqueID(),
   128  		idGenerator:  uid,
   129  		stype:        stype,
   130  	}
   131  
   132  	ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
   133  	ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
   134  	ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
   135  	return ep
   136  }
   137  
   138  // NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
   139  func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
   140  	a := newConnectioned(ctx, stype, uid)
   141  	b := newConnectioned(ctx, stype, uid)
   142  
   143  	q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: defaultBufferSize}
   144  	q1.InitRefs()
   145  	q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: defaultBufferSize}
   146  	q2.InitRefs()
   147  
   148  	if stype == linux.SOCK_STREAM {
   149  		a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
   150  		b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}}
   151  	} else {
   152  		a.receiver = &queueReceiver{q1}
   153  		b.receiver = &queueReceiver{q2}
   154  	}
   155  
   156  	q2.IncRef()
   157  	a.connected = &connectedEndpoint{
   158  		endpoint:   b,
   159  		writeQueue: q2,
   160  	}
   161  	q1.IncRef()
   162  	b.connected = &connectedEndpoint{
   163  		endpoint:   a,
   164  		writeQueue: q1,
   165  	}
   166  
   167  	return a, b
   168  }
   169  
   170  // NewExternal creates a new externally backed Endpoint. It behaves like a
   171  // socketpair.
   172  func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
   173  	ep := &connectionedEndpoint{
   174  		baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
   175  		id:           uid.UniqueID(),
   176  		idGenerator:  uid,
   177  		stype:        stype,
   178  	}
   179  	ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
   180  	ep.ops.SetSendBufferSize(connected.SendMaxQueueSize(), false /* notify */)
   181  	ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
   182  	return ep
   183  }
   184  
   185  // ID implements ConnectingEndpoint.ID.
   186  func (e *connectionedEndpoint) ID() uint64 {
   187  	return e.id
   188  }
   189  
   190  // Type implements ConnectingEndpoint.Type and Endpoint.Type.
   191  func (e *connectionedEndpoint) Type() linux.SockType {
   192  	return e.stype
   193  }
   194  
   195  // WaiterQueue implements ConnectingEndpoint.WaiterQueue.
   196  func (e *connectionedEndpoint) WaiterQueue() *waiter.Queue {
   197  	return e.Queue
   198  }
   199  
   200  // isBound returns true iff the connectionedEndpoint is bound (but not
   201  // listening).
   202  func (e *connectionedEndpoint) isBound() bool {
   203  	return e.path != "" && e.acceptedChan == nil
   204  }
   205  
   206  // Listening implements ConnectingEndpoint.Listening.
   207  func (e *connectionedEndpoint) Listening() bool {
   208  	return e.acceptedChan != nil
   209  }
   210  
   211  // Close puts the connectionedEndpoint in a closed state and frees all
   212  // resources associated with it.
   213  //
   214  // The socket will be a fresh state after a call to close and may be reused.
   215  // That is, close may be used to "unbind" or "disconnect" the socket in error
   216  // paths.
   217  func (e *connectionedEndpoint) Close(ctx context.Context) {
   218  	e.Lock()
   219  	var c ConnectedEndpoint
   220  	var r Receiver
   221  	switch {
   222  	case e.Connected():
   223  		e.connected.CloseSend()
   224  		e.receiver.CloseRecv()
   225  		// Still have unread data? If yes, we set this into the write
   226  		// end so that the peer can get ECONNRESET) when it does read.
   227  		if e.receiver.RecvQueuedSize() > 0 {
   228  			e.connected.CloseUnread()
   229  		}
   230  		c = e.connected
   231  		r = e.receiver
   232  		e.connected = nil
   233  		e.receiver = nil
   234  	case e.isBound():
   235  		e.path = ""
   236  	case e.Listening():
   237  		close(e.acceptedChan)
   238  		for n := range e.acceptedChan {
   239  			n.Close(ctx)
   240  		}
   241  		e.acceptedChan = nil
   242  		e.path = ""
   243  	}
   244  	e.Unlock()
   245  	if c != nil {
   246  		c.CloseNotify()
   247  		c.Release(ctx)
   248  	}
   249  	if r != nil {
   250  		r.CloseNotify()
   251  		r.Release(ctx)
   252  	}
   253  }
   254  
   255  // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
   256  func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
   257  	if ce.Type() != e.stype {
   258  		return syserr.ErrWrongProtocolForSocket
   259  	}
   260  
   261  	// Check if ce is e to avoid a deadlock.
   262  	if ce, ok := ce.(*connectionedEndpoint); ok && ce == e {
   263  		return syserr.ErrInvalidEndpointState
   264  	}
   265  
   266  	// Do a dance to safely acquire locks on both endpoints.
   267  	if e.id < ce.ID() {
   268  		e.Lock()
   269  		ce.Lock()
   270  	} else {
   271  		ce.Lock()
   272  		e.Lock()
   273  	}
   274  
   275  	// Check connecting state.
   276  	if ce.Connected() {
   277  		e.Unlock()
   278  		ce.Unlock()
   279  		return syserr.ErrAlreadyConnected
   280  	}
   281  	if ce.Listening() {
   282  		e.Unlock()
   283  		ce.Unlock()
   284  		return syserr.ErrInvalidEndpointState
   285  	}
   286  
   287  	// Check bound state.
   288  	if !e.Listening() {
   289  		e.Unlock()
   290  		ce.Unlock()
   291  		return syserr.ErrConnectionRefused
   292  	}
   293  
   294  	// Create a newly bound connectionedEndpoint.
   295  	ne := &connectionedEndpoint{
   296  		baseEndpoint: baseEndpoint{
   297  			path:  e.path,
   298  			Queue: &waiter.Queue{},
   299  		},
   300  		id:          e.idGenerator.UniqueID(),
   301  		idGenerator: e.idGenerator,
   302  		stype:       e.stype,
   303  	}
   304  	ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
   305  	ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
   306  	ne.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
   307  
   308  	readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize}
   309  	readQueue.InitRefs()
   310  	ne.connected = &connectedEndpoint{
   311  		endpoint:   ce,
   312  		writeQueue: readQueue,
   313  	}
   314  
   315  	// Make sure the accepted endpoint inherits this listening socket's SO_SNDBUF.
   316  	writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: e.ops.GetSendBufferSize()}
   317  	writeQueue.InitRefs()
   318  	if e.stype == linux.SOCK_STREAM {
   319  		ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
   320  	} else {
   321  		ne.receiver = &queueReceiver{readQueue: writeQueue}
   322  	}
   323  
   324  	select {
   325  	case e.acceptedChan <- ne:
   326  		// Commit state.
   327  		writeQueue.IncRef()
   328  		connected := &connectedEndpoint{
   329  			endpoint:   ne,
   330  			writeQueue: writeQueue,
   331  		}
   332  		readQueue.IncRef()
   333  		if e.stype == linux.SOCK_STREAM {
   334  			returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected)
   335  		} else {
   336  			returnConnect(&queueReceiver{readQueue: readQueue}, connected)
   337  		}
   338  
   339  		// Notify can deadlock if we are holding these locks.
   340  		e.Unlock()
   341  		ce.Unlock()
   342  
   343  		// Notify on both ends.
   344  		e.Notify(waiter.ReadableEvents)
   345  		ce.WaiterQueue().Notify(waiter.WritableEvents)
   346  
   347  		return nil
   348  	default:
   349  		// Busy; return EAGAIN per spec.
   350  		ne.Close(ctx)
   351  		e.Unlock()
   352  		ce.Unlock()
   353  		return syserr.ErrTryAgain
   354  	}
   355  }
   356  
   357  // UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
   358  func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
   359  	return nil, syserr.ErrConnectionRefused
   360  }
   361  
   362  // Connect attempts to directly connect to another Endpoint.
   363  // Implements Endpoint.Connect.
   364  func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
   365  	returnConnect := func(r Receiver, ce ConnectedEndpoint) {
   366  		e.receiver = r
   367  		e.connected = ce
   368  		// Make sure the newly created connected endpoint's write queue is updated
   369  		// to reflect this endpoint's send buffer size.
   370  		if bufSz := e.connected.SetSendBufferSize(e.ops.GetSendBufferSize()); bufSz != e.ops.GetSendBufferSize() {
   371  			e.ops.SetSendBufferSize(bufSz, false /* notify */)
   372  			e.ops.SetReceiveBufferSize(bufSz, false /* notify */)
   373  		}
   374  	}
   375  
   376  	return server.BidirectionalConnect(ctx, e, returnConnect)
   377  }
   378  
   379  // Listen starts listening on the connection.
   380  func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error {
   381  	e.Lock()
   382  	defer e.Unlock()
   383  	if e.Listening() {
   384  		// Adjust the size of the channel iff we can fix existing
   385  		// pending connections into the new one.
   386  		if len(e.acceptedChan) > backlog {
   387  			return syserr.ErrInvalidEndpointState
   388  		}
   389  		origChan := e.acceptedChan
   390  		e.acceptedChan = make(chan *connectionedEndpoint, backlog)
   391  		close(origChan)
   392  		for ep := range origChan {
   393  			e.acceptedChan <- ep
   394  		}
   395  		return nil
   396  	}
   397  	if !e.isBound() {
   398  		return syserr.ErrInvalidEndpointState
   399  	}
   400  
   401  	// Normal case.
   402  	e.acceptedChan = make(chan *connectionedEndpoint, backlog)
   403  	return nil
   404  }
   405  
   406  // Accept accepts a new connection.
   407  func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) {
   408  	e.Lock()
   409  	defer e.Unlock()
   410  
   411  	if !e.Listening() {
   412  		return nil, syserr.ErrInvalidEndpointState
   413  	}
   414  
   415  	select {
   416  	case ne := <-e.acceptedChan:
   417  		if peerAddr != nil {
   418  			ne.Lock()
   419  			c := ne.connected
   420  			ne.Unlock()
   421  			if c != nil {
   422  				addr, err := c.GetLocalAddress()
   423  				if err != nil {
   424  					return nil, syserr.TranslateNetstackError(err)
   425  				}
   426  				*peerAddr = addr
   427  			}
   428  		}
   429  		return ne, nil
   430  
   431  	default:
   432  		// Nothing left.
   433  		return nil, syserr.ErrWouldBlock
   434  	}
   435  }
   436  
   437  // Bind binds the connection.
   438  //
   439  // For Unix connectionedEndpoints, this _only sets the address associated with
   440  // the socket_. Work associated with sockets in the filesystem or finding those
   441  // sockets must be done by a higher level.
   442  //
   443  // Bind will fail only if the socket is connected, bound or the passed address
   444  // is invalid (the empty string).
   445  func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error {
   446  	e.Lock()
   447  	defer e.Unlock()
   448  	if e.isBound() || e.Listening() {
   449  		return syserr.ErrAlreadyBound
   450  	}
   451  	if addr.Addr == "" {
   452  		// The empty string is not permitted.
   453  		return syserr.ErrBadLocalAddress
   454  	}
   455  	if commit != nil {
   456  		if err := commit(); err != nil {
   457  			return err
   458  		}
   459  	}
   460  
   461  	// Save the bound address.
   462  	e.path = string(addr.Addr)
   463  	return nil
   464  }
   465  
   466  // SendMsg writes data and a control message to the endpoint's peer.
   467  // This method does not block if the data cannot be written.
   468  func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
   469  	// Stream sockets do not support specifying the endpoint. Seqpacket
   470  	// sockets ignore the passed endpoint.
   471  	if e.stype == linux.SOCK_STREAM && to != nil {
   472  		return 0, syserr.ErrNotSupported
   473  	}
   474  	return e.baseEndpoint.SendMsg(ctx, data, c, to)
   475  }
   476  
   477  // Readiness returns the current readiness of the connectionedEndpoint. For
   478  // example, if waiter.EventIn is set, the connectionedEndpoint is immediately
   479  // readable.
   480  func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
   481  	e.Lock()
   482  	defer e.Unlock()
   483  
   484  	ready := waiter.EventMask(0)
   485  	switch {
   486  	case e.Connected():
   487  		if mask&waiter.ReadableEvents != 0 && e.receiver.Readable() {
   488  			ready |= waiter.ReadableEvents
   489  		}
   490  		if mask&waiter.WritableEvents != 0 && e.connected.Writable() {
   491  			ready |= waiter.WritableEvents
   492  		}
   493  	case e.Listening():
   494  		if mask&waiter.ReadableEvents != 0 && len(e.acceptedChan) > 0 {
   495  			ready |= waiter.ReadableEvents
   496  		}
   497  	}
   498  
   499  	return ready
   500  }
   501  
   502  // State implements socket.Socket.State.
   503  func (e *connectionedEndpoint) State() uint32 {
   504  	e.Lock()
   505  	defer e.Unlock()
   506  
   507  	if e.Connected() {
   508  		return linux.SS_CONNECTED
   509  	}
   510  	return linux.SS_UNCONNECTED
   511  }
   512  
   513  // OnSetSendBufferSize implements tcpip.SocketOptionsHandler.OnSetSendBufferSize.
   514  func (e *connectionedEndpoint) OnSetSendBufferSize(v int64) (newSz int64) {
   515  	if e.Connected() {
   516  		return e.baseEndpoint.connected.SetSendBufferSize(v)
   517  	}
   518  	return v
   519  }