github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/p9/transport_flipcall.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package p9
    16  
    17  import (
    18  	"runtime"
    19  
    20  	"golang.org/x/sys/unix"
    21  	"github.com/sagernet/gvisor/pkg/fd"
    22  	"github.com/sagernet/gvisor/pkg/fdchannel"
    23  	"github.com/sagernet/gvisor/pkg/flipcall"
    24  	"github.com/sagernet/gvisor/pkg/log"
    25  )
    26  
    27  // channelsPerClient is the number of channels to create per client.
    28  //
    29  // While the client and server will generally agree on this number, in reality
    30  // it's completely up to the server. We simply define a minimum of 2, and a
    31  // maximum of 4, and select the number of available processes as a tie-breaker.
    32  // Note that we don't want the number of channels to be too large, because each
    33  // will account for channelSize memory used, which can be large.
    34  var channelsPerClient = func() int {
    35  	n := runtime.NumCPU()
    36  	if n < 2 {
    37  		return 2
    38  	}
    39  	if n > 4 {
    40  		return 4
    41  	}
    42  	return n
    43  }()
    44  
    45  // channelSize is the channel size to create.
    46  //
    47  // We simply ensure that this is larger than the largest possible message size,
    48  // plus the flipcall packet header, plus the two bytes we write below.
    49  const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength)
    50  
// channel is a fast IPC channel.
//
// The same object is used by both the server and client implementations. In
// general, the client will use only the send and recv methods.
//
// Fields:
//   - desc: the flipcall packet window descriptor for this channel
//     (NOTE(review): presumably the one data was created from; confirm).
//   - data: the flipcall endpoint whose packet window holds the encoded
//     messages; reset, send, and recv read and write it.
//   - fds: the side channel carrying file descriptors that accompany a
//     message (SendFD in send, RecvFDNonblock in recv).
//   - buf: the encode/decode buffer, re-pointed at the packet window by reset.
type channel struct {
	desc flipcall.PacketWindowDescriptor
	data flipcall.Endpoint
	fds  fdchannel.Endpoint
	buf  buffer

	//	-- client only --
	// NOTE(review): connected and active are not referenced in this file;
	// confirm their semantics against the client implementation.
	connected bool
	active    bool

	//	-- server only --
	// client, when non-nil, is closed by Close. done is not referenced in
	// this file. NOTE(review): confirm both against the server implementation.
	client *fd.FD
	done   chan struct{}
}
    69  
    70  // reset resets the channel buffer.
    71  func (ch *channel) reset(sz uint32) {
    72  	ch.buf.data = ch.data.Data()[:sz]
    73  }
    74  
    75  // service services the channel.
    76  func (ch *channel) service(cs *connState) error {
    77  	rsz, err := ch.data.RecvFirst()
    78  	if err != nil {
    79  		return err
    80  	}
    81  	for rsz > 0 {
    82  		m, err := ch.recv(nil, rsz)
    83  		if err != nil {
    84  			return err
    85  		}
    86  		r := cs.handle(m)
    87  		msgRegistry.put(m)
    88  		rsz, err = ch.send(r, true /* isServer */)
    89  		if err != nil {
    90  			return err
    91  		}
    92  	}
    93  	return nil // Done.
    94  }
    95  
    96  // Shutdown shuts down the channel.
    97  //
    98  // This must be called before Close.
    99  func (ch *channel) Shutdown() {
   100  	ch.data.Shutdown()
   101  }
   102  
   103  // Close closes the channel.
   104  //
   105  // This must only be called once, and cannot return an error. Note that
   106  // synchronization for this method is provided at a high-level, depending on
   107  // whether it is the client or server. This cannot be called while there are
   108  // active callers in either service or sendRecv.
   109  //
   110  // Precondition: the channel should be shutdown.
   111  func (ch *channel) Close() error {
   112  	// Close all backing transports.
   113  	ch.fds.Destroy()
   114  	ch.data.Destroy()
   115  	if ch.client != nil {
   116  		ch.client.Close()
   117  	}
   118  	return nil
   119  }
   120  
   121  // send sends the given message.
   122  //
   123  // The return value is the size of the received response. Not that in the
   124  // server case, this is the size of the next request.
   125  func (ch *channel) send(m message, isServer bool) (uint32, error) {
   126  	if log.IsLogging(log.Debug) {
   127  		log.Debugf("send [channel @%p] %s", ch, m.String())
   128  	}
   129  
   130  	// Send any file payload.
   131  	sentFD := false
   132  	if filer, ok := m.(filer); ok {
   133  		if f := filer.FilePayload(); f != nil {
   134  			if err := ch.fds.SendFD(f.FD()); err != nil {
   135  				return 0, err
   136  			}
   137  			f.Close()     // Per sendRecvLegacy.
   138  			sentFD = true // To mark below.
   139  		}
   140  	}
   141  
   142  	// Encode the message.
   143  	//
   144  	// Note that IPC itself encodes the length of messages, so we don't
   145  	// need to encode a standard 9P header. We write only the message type.
   146  	ch.reset(0)
   147  
   148  	ch.buf.WriteMsgType(m.Type())
   149  	if sentFD {
   150  		ch.buf.Write8(1) // Incoming FD.
   151  	} else {
   152  		ch.buf.Write8(0) // No incoming FD.
   153  	}
   154  	m.encode(&ch.buf)
   155  	ssz := uint32(len(ch.buf.data)) // Updated below.
   156  
   157  	// Is there a payload?
   158  	if payloader, ok := m.(payloader); ok {
   159  		p := payloader.Payload()
   160  		copy(ch.data.Data()[ssz:], p)
   161  		ssz += uint32(len(p))
   162  	}
   163  
   164  	// Perform the one-shot communication.
   165  	if isServer {
   166  		return ch.data.SendRecv(ssz)
   167  	}
   168  	// RPCs are expected to return quickly rather than block.
   169  	return ch.data.SendRecvFast(ssz)
   170  }
   171  
   172  // recv decodes a message that exists on the channel.
   173  //
   174  // If the passed r is non-nil, then the type must match or an error will be
   175  // generated. If the passed r is nil, then a new message will be created and
   176  // returned.
   177  func (ch *channel) recv(r message, rsz uint32) (message, error) {
   178  	// Decode the response from the inline buffer.
   179  	ch.reset(rsz)
   180  	t := ch.buf.ReadMsgType()
   181  	hasFD := ch.buf.Read8() != 0
   182  	if t == MsgRlerror {
   183  		// Change the message type. We check for this special case
   184  		// after decoding below, and transform into an error.
   185  		r = &Rlerror{}
   186  	} else if r == nil {
   187  		nr, err := msgRegistry.get(0, t)
   188  		if err != nil {
   189  			return nil, err
   190  		}
   191  		r = nr // New message.
   192  	} else if t != r.Type() {
   193  		// Not an error and not the expected response; propagate.
   194  		return nil, &ErrBadResponse{Got: t, Want: r.Type()}
   195  	}
   196  
   197  	// Is there a payload? Copy from the latter portion.
   198  	if payloader, ok := r.(payloader); ok {
   199  		fs := payloader.FixedSize()
   200  		p := payloader.Payload()
   201  		payloadData := ch.buf.data[fs:]
   202  		if len(p) < len(payloadData) {
   203  			p = make([]byte, len(payloadData))
   204  			copy(p, payloadData)
   205  			payloader.SetPayload(p)
   206  		} else if n := copy(p, payloadData); n < len(p) {
   207  			payloader.SetPayload(p[:n])
   208  		}
   209  		ch.buf.data = ch.buf.data[:fs]
   210  	}
   211  
   212  	r.decode(&ch.buf)
   213  	if ch.buf.isOverrun() {
   214  		// Nothing valid was available.
   215  		log.Debugf("recv [got %d bytes, needed more]", rsz)
   216  		return nil, ErrNoValidMessage
   217  	}
   218  
   219  	// Read any FD result.
   220  	if hasFD {
   221  		if rfd, err := ch.fds.RecvFDNonblock(); err == nil {
   222  			f := fd.New(rfd)
   223  			if filer, ok := r.(filer); ok {
   224  				// Set the payload.
   225  				filer.SetFilePayload(f)
   226  			} else {
   227  				// Don't want the FD.
   228  				f.Close()
   229  			}
   230  		} else {
   231  			// The header bit was set but nothing came in.
   232  			log.Warningf("expected FD, got err: %v", err)
   233  		}
   234  	}
   235  
   236  	// Log a message.
   237  	if log.IsLogging(log.Debug) {
   238  		log.Debugf("recv [channel @%p] %s", ch, r.String())
   239  	}
   240  
   241  	// Convert errors appropriately; see above.
   242  	if rlerr, ok := r.(*Rlerror); ok {
   243  		return r, unix.Errno(rlerr.Error)
   244  	}
   245  
   246  	return r, nil
   247  }