github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/lisafs/connection.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package lisafs
    16  
    17  import (
    18  	"path"
    19  	"path/filepath"
    20  	"runtime/debug"
    21  
    22  	"golang.org/x/sys/unix"
    23  	"github.com/nicocha30/gvisor-ligolo/pkg/flipcall"
    24  	"github.com/nicocha30/gvisor-ligolo/pkg/log"
    25  	"github.com/nicocha30/gvisor-ligolo/pkg/p9"
    26  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
    27  	"github.com/nicocha30/gvisor-ligolo/pkg/unet"
    28  )
    29  
// Connection represents a connection between a mount point in the client and a
// mount point in the server. It is owned by the server on which it was started
// and facilitates communication with the client mount.
//
// Each connection is set up using a unix domain socket. One end is owned by
// the server and the other end is owned by the client. The connection may
// spawn additional communication channels for the same mount for increased
// RPC concurrency.
//
// Reference model:
//   - When any FD is created, the connection takes a ref on it which represents
//     the client's ref on the FD.
//   - The client can drop its ref via the Close RPC which will in turn make the
//     connection drop its ref.
type Connection struct {
	// server is the server on which this connection was created. It is immutably
	// associated with it for its entire lifetime.
	server *Server

	// mountPath is the path to a file inside the server that is served to this
	// connection as its root FD. IOW, this connection is mounted at this path.
	// mountPath is trusted because it is configured by the server (trusted) as
	// per the user's sandbox configuration. mountPath is immutable.
	mountPath string

	// maxMessageSize is the cached value of server.impl.MaxMessageSize().
	// Request and response payloads larger than this are rejected with EIO.
	maxMessageSize uint32

	// readonly indicates if this connection is readonly. All write operations
	// will fail with EROFS.
	readonly bool

	// sockComm is the main socket by which this connection is established.
	sockComm *sockCommunicator

	// channelsMu protects channels.
	channelsMu sync.Mutex
	// channels keeps track of all open channels. It is set to nil during
	// connection teardown to prevent additional channels from being created.
	channels []*channel

	// activeWg represents active channels. Teardown waits on it after shutting
	// down every channel.
	activeWg sync.WaitGroup

	// reqGate counts requests that are still being handled. Once it is closed
	// during teardown, new requests are answered with ECONNRESET.
	reqGate sync.Gate

	// channelAlloc is used to allocate memory for channels. It is destroyed
	// during connection teardown.
	channelAlloc *flipcall.PacketWindowAllocator

	// fdsMu protects fds and nextFDID.
	fdsMu sync.RWMutex
	// fds keeps track of the FDs open on this connection, keyed by FDID. Each
	// entry holds a ref owned by the connection. It is protected by fdsMu.
	fds map[FDID]genericFD
	// nextFDID is the next available FDID. It is protected by fdsMu.
	nextFDID FDID
}
    85  
    86  // CreateConnection initializes a new connection which will be mounted at
    87  // mountPath. The connection must be started separately.
    88  func (s *Server) CreateConnection(sock *unet.Socket, mountPath string, readonly bool) (*Connection, error) {
    89  	mountPath = path.Clean(mountPath)
    90  	if !filepath.IsAbs(mountPath) {
    91  		log.Warningf("mountPath %q is not absolute", mountPath)
    92  		return nil, unix.EINVAL
    93  	}
    94  
    95  	c := &Connection{
    96  		sockComm:       newSockComm(sock),
    97  		server:         s,
    98  		maxMessageSize: s.impl.MaxMessageSize(),
    99  		mountPath:      mountPath,
   100  		readonly:       readonly,
   101  		channels:       make([]*channel, 0, maxChannels()),
   102  		fds:            make(map[FDID]genericFD),
   103  		nextFDID:       InvalidFDID + 1,
   104  	}
   105  
   106  	alloc, err := flipcall.NewPacketWindowAllocator()
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	c.channelAlloc = alloc
   111  	return c, nil
   112  }
   113  
   114  // ServerImpl returns the associated server implementation.
   115  func (c *Connection) ServerImpl() ServerImpl {
   116  	return c.server.impl
   117  }
   118  
   119  // Run defines the lifecycle of a connection.
   120  func (c *Connection) Run() {
   121  	defer c.close()
   122  
   123  	// Start handling requests on this connection.
   124  	for {
   125  		m, payloadLen, err := c.sockComm.rcvMsg(0 /* wantFDs */)
   126  		if err != nil {
   127  			log.Debugf("sock read failed, closing connection: %v", err)
   128  			return
   129  		}
   130  
   131  		respM, respPayloadLen, respFDs := c.handleMsg(c.sockComm, m, payloadLen)
   132  		err = c.sockComm.sndPrepopulatedMsg(respM, respPayloadLen, respFDs)
   133  		closeFDs(respFDs)
   134  		if err != nil {
   135  			log.Debugf("sock write failed, closing connection: %v", err)
   136  			return
   137  		}
   138  	}
   139  }
   140  
   141  // service starts servicing the passed channel until the channel is shutdown.
   142  // This is a blocking method and hence must be called in a separate goroutine.
   143  func (c *Connection) service(ch *channel) error {
   144  	rcvDataLen, err := ch.data.RecvFirst()
   145  	if err != nil {
   146  		return err
   147  	}
   148  	for rcvDataLen > 0 {
   149  		m, payloadLen, err := ch.rcvMsg(rcvDataLen)
   150  		if err != nil {
   151  			return err
   152  		}
   153  		respM, respPayloadLen, respFDs := c.handleMsg(ch, m, payloadLen)
   154  		numFDs := ch.sendFDs(respFDs)
   155  		closeFDs(respFDs)
   156  
   157  		ch.marshalHdr(respM, numFDs)
   158  		rcvDataLen, err = ch.data.SendRecv(respPayloadLen + chanHeaderLen)
   159  		if err != nil {
   160  			return err
   161  		}
   162  	}
   163  	return nil
   164  }
   165  
   166  func (c *Connection) respondError(comm Communicator, err unix.Errno) (MID, uint32, []int) {
   167  	resp := &ErrorResp{errno: uint32(err)}
   168  	respLen := uint32(resp.SizeBytes())
   169  	resp.MarshalUnsafe(comm.PayloadBuf(respLen))
   170  	return Error, respLen, nil
   171  }
   172  
   173  func (c *Connection) handleMsg(comm Communicator, m MID, payloadLen uint32) (retM MID, retPayloadLen uint32, retFDs []int) {
   174  	if payloadLen > c.maxMessageSize {
   175  		log.Warningf("received payload is too large: %d bytes", payloadLen)
   176  		return c.respondError(comm, unix.EIO)
   177  	}
   178  	if !c.reqGate.Enter() {
   179  		// c.close() has been called; the connection is shutting down.
   180  		return c.respondError(comm, unix.ECONNRESET)
   181  	}
   182  	defer func() {
   183  		c.reqGate.Leave()
   184  
   185  		// Don't allow a panic to propagate.
   186  		if err := recover(); err != nil {
   187  			// Include a useful log message.
   188  			log.Warningf("panic in handler: %v\n%s", err, debug.Stack())
   189  
   190  			// Wrap in an EREMOTEIO error; we don't really have a better way to
   191  			// describe this kind of error. EREMOTEIO is appropriate for a generic
   192  			// failed RPC message.
   193  			retM, retPayloadLen, retFDs = c.respondError(comm, unix.EREMOTEIO)
   194  		}
   195  	}()
   196  
   197  	// Check if the message is supported for forward compatibility.
   198  	if int(m) >= len(c.server.handlers) || c.server.handlers[m] == nil {
   199  		log.Warningf("received request which is not supported by the server, MID = %d", m)
   200  		return c.respondError(comm, unix.EOPNOTSUPP)
   201  	}
   202  
   203  	// Try handling the request.
   204  	respPayloadLen, err := c.server.handlers[m](c, comm, payloadLen)
   205  	fds := comm.ReleaseFDs()
   206  	if err != nil {
   207  		closeFDs(fds)
   208  		return c.respondError(comm, p9.ExtractErrno(err))
   209  	}
   210  	if respPayloadLen > c.maxMessageSize {
   211  		log.Warningf("handler for message %d responded with payload which is too large: %d bytes", m, respPayloadLen)
   212  		closeFDs(fds)
   213  		return c.respondError(comm, unix.EIO)
   214  	}
   215  
   216  	return m, respPayloadLen, fds
   217  }
   218  
   219  func (c *Connection) close() {
   220  	// Wait for completion of all inflight requests. This is mostly so that if
   221  	// a request is stuck, the sandbox supervisor has the opportunity to kill
   222  	// us with SIGABRT to get a stack dump of the offending handler.
   223  	c.reqGate.Close()
   224  
   225  	// Shutdown and clean up channels.
   226  	c.channelsMu.Lock()
   227  	for _, ch := range c.channels {
   228  		ch.shutdown()
   229  	}
   230  	c.activeWg.Wait()
   231  	for _, ch := range c.channels {
   232  		ch.destroy()
   233  	}
   234  	// This is to prevent additional channels from being created.
   235  	c.channels = nil
   236  	c.channelsMu.Unlock()
   237  
   238  	// Free the channel memory.
   239  	if c.channelAlloc != nil {
   240  		c.channelAlloc.Destroy()
   241  	}
   242  
   243  	// Ensure the connection is closed.
   244  	c.sockComm.destroy()
   245  
   246  	// Cleanup all FDs.
   247  	c.fdsMu.Lock()
   248  	defer c.fdsMu.Unlock()
   249  	for fdid := range c.fds {
   250  		fd := c.stopTrackingFD(fdid)
   251  		fd.DecRef(nil) // Drop the ref held by c.
   252  	}
   253  }
   254  
   255  // Postcondition: The caller gains a ref on the FD on success.
   256  func (c *Connection) lookupFD(id FDID) (genericFD, error) {
   257  	c.fdsMu.RLock()
   258  	defer c.fdsMu.RUnlock()
   259  
   260  	fd, ok := c.fds[id]
   261  	if !ok {
   262  		return nil, unix.EBADF
   263  	}
   264  	fd.IncRef()
   265  	return fd, nil
   266  }
   267  
   268  // lookupControlFD retrieves the control FD identified by id on this
   269  // connection. On success, the caller gains a ref on the FD.
   270  func (c *Connection) lookupControlFD(id FDID) (*ControlFD, error) {
   271  	fd, err := c.lookupFD(id)
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  
   276  	cfd, ok := fd.(*ControlFD)
   277  	if !ok {
   278  		fd.DecRef(nil)
   279  		return nil, unix.EINVAL
   280  	}
   281  	return cfd, nil
   282  }
   283  
   284  // lookupOpenFD retrieves the open FD identified by id on this
   285  // connection. On success, the caller gains a ref on the FD.
   286  func (c *Connection) lookupOpenFD(id FDID) (*OpenFD, error) {
   287  	fd, err := c.lookupFD(id)
   288  	if err != nil {
   289  		return nil, err
   290  	}
   291  
   292  	ofd, ok := fd.(*OpenFD)
   293  	if !ok {
   294  		fd.DecRef(nil)
   295  		return nil, unix.EINVAL
   296  	}
   297  	return ofd, nil
   298  }
   299  
   300  // lookupBoundSocketFD retrieves the boundSockedFD identified by id on this
   301  // connection. On success, the caller gains a ref on the FD.
   302  func (c *Connection) lookupBoundSocketFD(id FDID) (*BoundSocketFD, error) {
   303  	fd, err := c.lookupFD(id)
   304  	if err != nil {
   305  		return nil, err
   306  	}
   307  
   308  	bsfd, ok := fd.(*BoundSocketFD)
   309  	if !ok {
   310  		fd.DecRef(nil)
   311  		return nil, unix.EINVAL
   312  	}
   313  	return bsfd, nil
   314  }
   315  
   316  // insertFD inserts the passed fd into the internal datastructure to track FDs.
   317  // The caller must hold a ref on fd which is transferred to the connection.
   318  func (c *Connection) insertFD(fd genericFD) FDID {
   319  	c.fdsMu.Lock()
   320  	defer c.fdsMu.Unlock()
   321  
   322  	res := c.nextFDID
   323  	c.nextFDID++
   324  	if c.nextFDID < res {
   325  		panic("ran out of FDIDs")
   326  	}
   327  	c.fds[res] = fd
   328  	return res
   329  }
   330  
   331  // removeFD makes c stop tracking the passed FDID and drops its ref on it.
   332  func (c *Connection) removeFD(id FDID) {
   333  	c.fdsMu.Lock()
   334  	fd := c.stopTrackingFD(id)
   335  	c.fdsMu.Unlock()
   336  	if fd != nil {
   337  		// Drop the ref held by c. This can take arbitrarily long. So do not hold
   338  		// c.fdsMu while calling it.
   339  		fd.DecRef(nil)
   340  	}
   341  }
   342  
   343  // removeControlFDLocked is the same as removeFD with added preconditions.
   344  //
   345  // Preconditions:
   346  //   - server's rename mutex must at least be read locked.
   347  //   - id must be pointing to a control FD.
   348  func (c *Connection) removeControlFDLocked(id FDID) {
   349  	c.fdsMu.Lock()
   350  	fd := c.stopTrackingFD(id)
   351  	c.fdsMu.Unlock()
   352  	if fd != nil {
   353  		// Drop the ref held by c. This can take arbitrarily long. So do not hold
   354  		// c.fdsMu while calling it.
   355  		fd.(*ControlFD).decRefLocked()
   356  	}
   357  }
   358  
   359  // stopTrackingFD makes c stop tracking the passed FDID. Note that the caller
   360  // must drop ref on the returned fd (preferably without holding c.fdsMu).
   361  //
   362  // Precondition: c.fdsMu is locked.
   363  func (c *Connection) stopTrackingFD(id FDID) genericFD {
   364  	fd := c.fds[id]
   365  	if fd == nil {
   366  		log.Warningf("removeFDLocked called on non-existent FDID %d", id)
   367  		return nil
   368  	}
   369  	delete(c.fds, id)
   370  	return fd
   371  }