github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/p9/client.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package p9
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  
    21  	"golang.org/x/sys/unix"
    22  	"github.com/nicocha30/gvisor-ligolo/pkg/flipcall"
    23  	"github.com/nicocha30/gvisor-ligolo/pkg/log"
    24  	"github.com/nicocha30/gvisor-ligolo/pkg/pool"
    25  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
    26  	"github.com/nicocha30/gvisor-ligolo/pkg/unet"
    27  )
    28  
    29  // ErrOutOfTags indicates no tags are available.
    30  var ErrOutOfTags = errors.New("out of tags -- messages lost?")
    31  
    32  // ErrOutOfFIDs indicates no more FIDs are available.
    33  var ErrOutOfFIDs = errors.New("out of FIDs -- messages lost?")
    34  
    35  // ErrUnexpectedTag indicates a response with an unexpected tag was received.
    36  var ErrUnexpectedTag = errors.New("unexpected tag in response")
    37  
    38  // ErrVersionsExhausted indicates that all versions to negotiate have been exhausted.
    39  var ErrVersionsExhausted = errors.New("exhausted all versions to negotiate")
    40  
    41  // ErrBadVersionString indicates that the version string is malformed or unsupported.
    42  var ErrBadVersionString = errors.New("bad version string")
    43  
    44  // ErrBadResponse indicates the response didn't match the request.
    45  type ErrBadResponse struct {
    46  	Got  MsgType
    47  	Want MsgType
    48  }
    49  
    50  // Error returns a highly descriptive error.
    51  func (e *ErrBadResponse) Error() string {
    52  	return fmt.Sprintf("unexpected message type: got %v, want %v", e.Got, e.Want)
    53  }
    54  
    55  // response is the asynchronous return from recv.
    56  //
    57  // This is used in the pending map below.
    58  type response struct {
    59  	r    message
    60  	done chan error
    61  }
    62  
    63  var responsePool = sync.Pool{
    64  	New: func() any {
    65  		return &response{
    66  			done: make(chan error, 1),
    67  		}
    68  	},
    69  }
    70  
// Client is at least a 9P2000.L client.
//
// A Client is created by NewClient, which negotiates the protocol version
// and chooses a transport (flipcall channels when available, otherwise the
// legacy tagged socket protocol). Close shuts the connection down and waits
// for the background watch goroutine to release all resources.
type Client struct {
	// socket is the connected socket. It is closed by the watch goroutine.
	socket *unet.Socket

	// tagPool is the collection of available tags for legacy requests.
	tagPool pool.Pool

	// fidPool is the collection of available fids.
	fidPool pool.Pool

	// messageSize is the maximum total size of a message, as sent in
	// Tversion.
	messageSize uint32

	// payloadSize is the maximum payload size of a read or write.
	//
	// For large reads and writes this means that the read or write is
	// broken up into buffer-size/payloadSize requests.
	payloadSize uint32

	// version is the agreed upon version X of 9P2000.L.Google.X.
	// version 0 implies 9P2000.L.
	version uint32

	// closedWg is marked as done when the Client.watch() goroutine, which is
	// responsible for closing channels and the socket fd, returns.
	closedWg sync.WaitGroup

	// sendRecv is the transport function.
	//
	// This is determined dynamically based on whether or not the server
	// supports flipcall channels (preferred as it is faster and more
	// efficient, and does not require tags).
	sendRecv func(message, message) error

	//	-- below corresponds to sendRecvChannel --

	// channelsMu protects channels, availableChannels, and the active flag
	// of each channel.
	channelsMu sync.Mutex

	// channelsWg counts the number of channels for which channel.active ==
	// true. watch() waits on it before closing channels.
	channelsWg sync.WaitGroup

	// channels is the set of all initialized channels.
	channels []*channel

	// availableChannels is a LIFO of inactive channels. watch() sets it to
	// nil on shutdown so that no channel is acquired or returned afterwards.
	availableChannels []*channel

	//	-- below corresponds to sendRecvLegacy --

	// pending maps the tag of each in-flight legacy request to the response
	// awaiting its reply. It is guarded by pendingMu.
	pending   map[Tag]*response
	pendingMu sync.Mutex

	// sendMu is the lock for sending a request, so that messages written to
	// the socket do not interleave.
	sendMu sync.Mutex

	// recvr is essentially a mutex for calling recv.
	//
	// Whoever writes to this channel is permitted to call recv. When
	// finished calling recv, this channel should be emptied.
	recvr chan bool
}
   136  
   137  // NewClient creates a new client.  It performs a Tversion exchange with
   138  // the server to assert that messageSize is ok to use.
   139  //
   140  // If NewClient succeeds, ownership of socket is transferred to the new Client.
   141  func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) {
   142  	// Need at least one byte of payload.
   143  	if messageSize <= msgRegistry.largestFixedSize {
   144  		return nil, &ErrMessageTooLarge{
   145  			size:  messageSize,
   146  			msize: msgRegistry.largestFixedSize,
   147  		}
   148  	}
   149  
   150  	// Compute a payload size and round to 512 (normal block size)
   151  	// if it's larger than a single block.
   152  	payloadSize := messageSize - msgRegistry.largestFixedSize
   153  	if payloadSize > 512 && payloadSize%512 != 0 {
   154  		payloadSize -= (payloadSize % 512)
   155  	}
   156  	c := &Client{
   157  		socket:      socket,
   158  		tagPool:     pool.Pool{Start: 1, Limit: uint64(NoTag)},
   159  		fidPool:     pool.Pool{Start: 1, Limit: uint64(NoFID)},
   160  		pending:     make(map[Tag]*response),
   161  		recvr:       make(chan bool, 1),
   162  		messageSize: messageSize,
   163  		payloadSize: payloadSize,
   164  	}
   165  	// Agree upon a version.
   166  	requested, ok := parseVersion(version)
   167  	if !ok {
   168  		return nil, ErrBadVersionString
   169  	}
   170  	for {
   171  		// Always exchange the version using the legacy version of the
   172  		// protocol. If the protocol supports flipcall, then we switch
   173  		// our sendRecv function to use that functionality.  Otherwise,
   174  		// we stick to sendRecvLegacy.
   175  		rversion := Rversion{}
   176  		_, err := c.sendRecvLegacy(&Tversion{
   177  			Version: versionString(requested),
   178  			MSize:   messageSize,
   179  		}, &rversion)
   180  
   181  		// The server told us to try again with a lower version.
   182  		if err == unix.EAGAIN {
   183  			if requested == lowestSupportedVersion {
   184  				return nil, ErrVersionsExhausted
   185  			}
   186  			requested--
   187  			continue
   188  		}
   189  
   190  		// We requested an impossible version or our other parameters were bogus.
   191  		if err != nil {
   192  			return nil, err
   193  		}
   194  
   195  		// Parse the version.
   196  		version, ok := parseVersion(rversion.Version)
   197  		if !ok {
   198  			// The server gave us a bad version. We return a generically worrisome error.
   199  			log.Warningf("server returned bad version string %q", rversion.Version)
   200  			return nil, ErrBadVersionString
   201  		}
   202  		c.version = version
   203  		break
   204  	}
   205  
   206  	// Can we switch to use the more advanced channels and create
   207  	// independent channels for communication? Prefer it if possible.
   208  	if versionSupportsFlipcall(c.version) {
   209  		// Attempt to initialize IPC-based communication.
   210  		for i := 0; i < channelsPerClient; i++ {
   211  			if err := c.openChannel(i); err != nil {
   212  				log.Warningf("error opening flipcall channel: %v", err)
   213  				break // Stop.
   214  			}
   215  		}
   216  		if len(c.channels) >= 1 {
   217  			// At least one channel created.
   218  			c.sendRecv = c.sendRecvChannel
   219  		} else {
   220  			// Channel setup failed; fallback.
   221  			c.sendRecv = c.sendRecvLegacySyscallErr
   222  		}
   223  	} else {
   224  		// No channels available: use the legacy mechanism.
   225  		c.sendRecv = c.sendRecvLegacySyscallErr
   226  	}
   227  
   228  	// Ensure that the socket and channels are closed when the socket is shut
   229  	// down.
   230  	c.closedWg.Add(1)
   231  	go c.watch(socket) // S/R-SAFE: not relevant.
   232  
   233  	return c, nil
   234  }
   235  
   236  // watch watches the given socket and releases resources on hangup events.
   237  //
   238  // This is intended to be called as a goroutine.
   239  func (c *Client) watch(socket *unet.Socket) {
   240  	defer c.closedWg.Done()
   241  
   242  	events := []unix.PollFd{
   243  		{
   244  			Fd:     int32(socket.FD()),
   245  			Events: unix.POLLHUP | unix.POLLRDHUP,
   246  		},
   247  	}
   248  
   249  	// Wait for a shutdown event.
   250  	for {
   251  		n, err := unix.Ppoll(events, nil, nil)
   252  		if err == unix.EINTR || err == unix.EAGAIN {
   253  			continue
   254  		}
   255  		if err != nil {
   256  			log.Warningf("p9.Client.watch(): %v", err)
   257  			break
   258  		}
   259  		if n != 1 {
   260  			log.Warningf("p9.Client.watch(): got %d events, wanted 1", n)
   261  		}
   262  		break
   263  	}
   264  
   265  	// Set availableChannels to nil so that future calls to c.sendRecvChannel()
   266  	// don't attempt to activate a channel, and concurrent calls to
   267  	// c.sendRecvChannel() don't mark released channels as available.
   268  	c.channelsMu.Lock()
   269  	c.availableChannels = nil
   270  
   271  	// Shut down all active channels.
   272  	for _, ch := range c.channels {
   273  		if ch.active {
   274  			log.Debugf("shutting down active channel@%p...", ch)
   275  			ch.Shutdown()
   276  		}
   277  	}
   278  	c.channelsMu.Unlock()
   279  
   280  	// Wait for active channels to become inactive.
   281  	c.channelsWg.Wait()
   282  
   283  	// Close all channels.
   284  	c.channelsMu.Lock()
   285  	for _, ch := range c.channels {
   286  		ch.Close()
   287  	}
   288  	c.channelsMu.Unlock()
   289  
   290  	// Close the main socket.
   291  	c.socket.Close()
   292  }
   293  
   294  // openChannel attempts to open a client channel.
   295  //
   296  // Note that this function returns naked errors which should not be propagated
   297  // directly to a caller. It is expected that the errors will be logged and a
   298  // fallback path will be used instead.
   299  func (c *Client) openChannel(id int) error {
   300  	var (
   301  		rchannel0 Rchannel
   302  		rchannel1 Rchannel
   303  		res       = new(channel)
   304  	)
   305  
   306  	// Open the data channel.
   307  	if _, err := c.sendRecvLegacy(&Tchannel{
   308  		ID:      uint32(id),
   309  		Control: 0,
   310  	}, &rchannel0); err != nil {
   311  		return fmt.Errorf("error handling Tchannel message: %v", err)
   312  	}
   313  	if rchannel0.FilePayload() == nil {
   314  		return fmt.Errorf("missing file descriptor on primary channel")
   315  	}
   316  
   317  	// We don't need to hold this.
   318  	defer rchannel0.FilePayload().Close()
   319  
   320  	// Open the channel for file descriptors.
   321  	if _, err := c.sendRecvLegacy(&Tchannel{
   322  		ID:      uint32(id),
   323  		Control: 1,
   324  	}, &rchannel1); err != nil {
   325  		return err
   326  	}
   327  	if rchannel1.FilePayload() == nil {
   328  		return fmt.Errorf("missing file descriptor on file descriptor channel")
   329  	}
   330  
   331  	// Construct the endpoints.
   332  	res.desc = flipcall.PacketWindowDescriptor{
   333  		FD:     rchannel0.FilePayload().FD(),
   334  		Offset: int64(rchannel0.Offset),
   335  		Length: int(rchannel0.Length),
   336  	}
   337  	if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil {
   338  		rchannel1.FilePayload().Close()
   339  		return err
   340  	}
   341  
   342  	// The fds channel owns the control payload, and it will be closed when
   343  	// the channel object is closed.
   344  	res.fds.Init(rchannel1.FilePayload().Release())
   345  
   346  	// Save the channel.
   347  	c.channelsMu.Lock()
   348  	defer c.channelsMu.Unlock()
   349  	c.channels = append(c.channels, res)
   350  	c.availableChannels = append(c.availableChannels, res)
   351  	return nil
   352  }
   353  
   354  // handleOne handles a single incoming message.
   355  //
   356  // This should only be called with the token from recvr. Note that the received
   357  // tag will automatically be cleared from pending.
   358  func (c *Client) handleOne() {
   359  	tag, r, err := recv(c.socket, c.messageSize, func(tag Tag, t MsgType) (message, error) {
   360  		c.pendingMu.Lock()
   361  		resp := c.pending[tag]
   362  		c.pendingMu.Unlock()
   363  
   364  		// Not expecting this message?
   365  		if resp == nil {
   366  			log.Warningf("client received unexpected tag %v, ignoring", tag)
   367  			return nil, ErrUnexpectedTag
   368  		}
   369  
   370  		// Is it an error? We specifically allow this to
   371  		// go through, and then we deserialize below.
   372  		if t == MsgRlerror {
   373  			return &Rlerror{}, nil
   374  		}
   375  
   376  		// Does it match expectations?
   377  		if t != resp.r.Type() {
   378  			return nil, &ErrBadResponse{Got: t, Want: resp.r.Type()}
   379  		}
   380  
   381  		// Return the response.
   382  		return resp.r, nil
   383  	})
   384  
   385  	if err != nil {
   386  		// No tag was extracted (probably a socket error).
   387  		//
   388  		// Likely catastrophic. Notify all waiters and clear pending.
   389  		c.pendingMu.Lock()
   390  		for _, resp := range c.pending {
   391  			resp.done <- err
   392  		}
   393  		c.pending = make(map[Tag]*response)
   394  		c.pendingMu.Unlock()
   395  	} else {
   396  		// Process the tag.
   397  		//
   398  		// We know that is is contained in the map because our lookup function
   399  		// above must have succeeded (found the tag) to return nil err.
   400  		c.pendingMu.Lock()
   401  		resp := c.pending[tag]
   402  		delete(c.pending, tag)
   403  		c.pendingMu.Unlock()
   404  		resp.r = r
   405  		resp.done <- err
   406  	}
   407  }
   408  
// waitAndRecv co-ordinates with other receivers to handle responses.
//
// It blocks until an outcome is posted to done and returns that outcome.
// While waiting, the caller competes for the c.recvr token (c.recvr is
// created with a buffer of one in NewClient). Whoever holds the token reads
// one message from the socket via handleOne, which may complete another
// caller's request, and then releases the token.
func (c *Client) waitAndRecv(done chan error) error {
	for {
		select {
		case err := <-done:
			// Our outcome was delivered, possibly by another receiver.
			return err
		case c.recvr <- true:
			// We now hold the receive token.
			select {
			case err := <-done:
				// It's possible that we got the token, despite
				// done also being available. Check for that.
				// Release the token before returning.
				<-c.recvr
				return err
			default:
				// Handle receiving one tag.
				c.handleOne()

				// Return the token.
				<-c.recvr
			}
		}
	}
}
   432  
   433  // sendRecvLegacySyscallErr is a wrapper for sendRecvLegacy that converts all
   434  // non-syscall errors to EIO.
   435  func (c *Client) sendRecvLegacySyscallErr(t message, r message) error {
   436  	received, err := c.sendRecvLegacy(t, r)
   437  	if !received {
   438  		log.Warningf("p9.Client.sendRecvChannel: %v", err)
   439  		return unix.EIO
   440  	}
   441  	return err
   442  }
   443  
   444  // sendRecvLegacy performs a roundtrip message exchange.
   445  //
   446  // sendRecvLegacy returns true if a message was received. This allows us to
   447  // differentiate between failed receives and successful receives where the
   448  // response was an error message.
   449  //
   450  // This is called by internal functions.
   451  func (c *Client) sendRecvLegacy(t message, r message) (bool, error) {
   452  	tag, ok := c.tagPool.Get()
   453  	if !ok {
   454  		return false, ErrOutOfTags
   455  	}
   456  	defer c.tagPool.Put(tag)
   457  
   458  	// Indicate we're expecting a response.
   459  	//
   460  	// Note that the tag will be cleared from pending
   461  	// automatically (see handleOne for details).
   462  	resp := responsePool.Get().(*response)
   463  	defer responsePool.Put(resp)
   464  	resp.r = r
   465  	c.pendingMu.Lock()
   466  	c.pending[Tag(tag)] = resp
   467  	c.pendingMu.Unlock()
   468  
   469  	// Send the request over the wire.
   470  	c.sendMu.Lock()
   471  	err := send(c.socket, Tag(tag), t)
   472  	c.sendMu.Unlock()
   473  	if err != nil {
   474  		return false, err
   475  	}
   476  
   477  	// Co-ordinate with other receivers.
   478  	if err := c.waitAndRecv(resp.done); err != nil {
   479  		return false, err
   480  	}
   481  
   482  	// Is it an error message?
   483  	//
   484  	// For convenience, we transform these directly
   485  	// into errors. Handlers need not handle this case.
   486  	if rlerr, ok := resp.r.(*Rlerror); ok {
   487  		return true, unix.Errno(rlerr.Error)
   488  	}
   489  
   490  	// At this point, we know it matches.
   491  	//
   492  	// Per recv call above, we will only allow a type
   493  	// match (and give our r) or an instance of Rlerror.
   494  	return true, nil
   495  }
   496  
   497  // sendRecvChannel uses channels to send a message.
   498  func (c *Client) sendRecvChannel(t message, r message) error {
   499  	// Acquire an available channel.
   500  	c.channelsMu.Lock()
   501  	if len(c.availableChannels) == 0 {
   502  		c.channelsMu.Unlock()
   503  		return c.sendRecvLegacySyscallErr(t, r)
   504  	}
   505  	idx := len(c.availableChannels) - 1
   506  	ch := c.availableChannels[idx]
   507  	c.availableChannels = c.availableChannels[:idx]
   508  	ch.active = true
   509  	c.channelsWg.Add(1)
   510  	c.channelsMu.Unlock()
   511  
   512  	// Ensure that it's connected.
   513  	if !ch.connected {
   514  		ch.connected = true
   515  		if err := ch.data.Connect(); err != nil {
   516  			// The channel is unusable, so don't return it to
   517  			// c.availableChannels. However, we still have to mark it as
   518  			// inactive so c.watch() doesn't wait for it.
   519  			c.channelsMu.Lock()
   520  			ch.active = false
   521  			c.channelsMu.Unlock()
   522  			c.channelsWg.Done()
   523  			// Map all transport errors to EIO, but ensure that the real error
   524  			// is logged.
   525  			log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err)
   526  			return unix.EIO
   527  		}
   528  	}
   529  
   530  	// Send the request and receive the server's response.
   531  	rsz, err := ch.send(t, false /* isServer */)
   532  	if err != nil {
   533  		// See above.
   534  		c.channelsMu.Lock()
   535  		ch.active = false
   536  		c.channelsMu.Unlock()
   537  		c.channelsWg.Done()
   538  		log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err)
   539  		return unix.EIO
   540  	}
   541  
   542  	// Parse the server's response.
   543  	resp, retErr := ch.recv(r, rsz)
   544  	if resp == nil {
   545  		log.Warningf("p9.Client.sendRecvChannel: p9.channel.recv: %v", retErr)
   546  		retErr = unix.EIO
   547  	}
   548  
   549  	// Release the channel.
   550  	c.channelsMu.Lock()
   551  	ch.active = false
   552  	// If c.availableChannels is nil, c.watch() has fired and we should not
   553  	// mark this channel as available.
   554  	if c.availableChannels != nil {
   555  		c.availableChannels = append(c.availableChannels, ch)
   556  	}
   557  	c.channelsMu.Unlock()
   558  	c.channelsWg.Done()
   559  
   560  	return retErr
   561  }
   562  
   563  // Version returns the negotiated 9P2000.L.Google version number.
   564  func (c *Client) Version() uint32 {
   565  	return c.version
   566  }
   567  
   568  // Close closes the underlying socket and channels.
   569  func (c *Client) Close() {
   570  	// unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already
   571  	// been called (by c.watch()).
   572  	if err := c.socket.Shutdown(); err != nil {
   573  		log.Warningf("Socket.Shutdown() failed (FD: %d): %v", c.socket.FD(), err)
   574  	}
   575  	c.closedWg.Wait()
   576  }