github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/p9/transport_flipcall.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package p9
    16  
    17  import (
    18  	"runtime"
    19  
    20  	"golang.org/x/sys/unix"
    21  	"github.com/SagerNet/gvisor/pkg/fd"
    22  	"github.com/SagerNet/gvisor/pkg/fdchannel"
    23  	"github.com/SagerNet/gvisor/pkg/flipcall"
    24  	"github.com/SagerNet/gvisor/pkg/log"
    25  )
    26  
    27  // channelsPerClient is the number of channels to create per client.
    28  //
    29  // While the client and server will generally agree on this number, in reality
    30  // it's completely up to the server. We simply define a minimum of 2, and a
    31  // maximum of 4, and select the number of available processes as a tie-breaker.
    32  // Note that we don't want the number of channels to be too large, because each
    33  // will account for channelSize memory used, which can be large.
    34  var channelsPerClient = func() int {
    35  	n := runtime.NumCPU()
    36  	if n < 2 {
    37  		return 2
    38  	}
    39  	if n > 4 {
    40  		return 4
    41  	}
    42  	return n
    43  }()
    44  
    45  // channelSize is the channel size to create.
    46  //
    47  // We simply ensure that this is larger than the largest possible message size,
    48  // plus the flipcall packet header, plus the two bytes we write below.
    49  const channelSize = int(2 + flipcall.PacketHeaderBytes + 2 + maximumLength)
    50  
// channel is a fast IPC channel.
//
// The same object is used by both the server and client implementations. In
// general, the client will use only the send and recv methods.
type channel struct {
	// desc presumably describes the packet window backing data; it is not
	// used in this file — confirm against the setup code.
	desc flipcall.PacketWindowDescriptor
	// data is the flipcall endpoint whose packet window holds the inline
	// message bytes exchanged by send and recv.
	data flipcall.Endpoint
	// fds carries file descriptors out of band: send passes message file
	// payloads over it and recv receives them.
	fds  fdchannel.Endpoint
	// buf is the encode/decode buffer; reset points it at data's window.
	buf  buffer

	// -- client only --
	// NOTE(review): connected and active are not used in this file; their
	// exact meaning is defined by the client code.
	connected bool
	active    bool

	// -- server only --
	// client, when non-nil, is closed by Close.
	client *fd.FD
	// done is not used in this file; presumably signals server-side
	// completion — confirm against the server code.
	done   chan struct{}
}
    67  	done   chan struct{}
    68  }
    69  
    70  // reset resets the channel buffer.
    71  func (ch *channel) reset(sz uint32) {
    72  	ch.buf.data = ch.data.Data()[:sz]
    73  }
    74  
    75  // service services the channel.
    76  func (ch *channel) service(cs *connState) error {
    77  	rsz, err := ch.data.RecvFirst()
    78  	if err != nil {
    79  		return err
    80  	}
    81  	for rsz > 0 {
    82  		m, err := ch.recv(nil, rsz)
    83  		if err != nil {
    84  			return err
    85  		}
    86  		r := cs.handle(m)
    87  		msgRegistry.put(m)
    88  		rsz, err = ch.send(r)
    89  		if err != nil {
    90  			return err
    91  		}
    92  	}
    93  	return nil // Done.
    94  }
    95  
    96  // Shutdown shuts down the channel.
    97  //
    98  // This must be called before Close.
    99  func (ch *channel) Shutdown() {
   100  	ch.data.Shutdown()
   101  }
   102  
   103  // Close closes the channel.
   104  //
   105  // This must only be called once, and cannot return an error. Note that
   106  // synchronization for this method is provided at a high-level, depending on
   107  // whether it is the client or server. This cannot be called while there are
   108  // active callers in either service or sendRecv.
   109  //
   110  // Precondition: the channel should be shutdown.
   111  func (ch *channel) Close() error {
   112  	// Close all backing transports.
   113  	ch.fds.Destroy()
   114  	ch.data.Destroy()
   115  	if ch.client != nil {
   116  		ch.client.Close()
   117  	}
   118  	return nil
   119  }
   120  
   121  // send sends the given message.
   122  //
   123  // The return value is the size of the received response. Not that in the
   124  // server case, this is the size of the next request.
   125  func (ch *channel) send(m message) (uint32, error) {
   126  	if log.IsLogging(log.Debug) {
   127  		log.Debugf("send [channel @%p] %s", ch, m.String())
   128  	}
   129  
   130  	// Send any file payload.
   131  	sentFD := false
   132  	if filer, ok := m.(filer); ok {
   133  		if f := filer.FilePayload(); f != nil {
   134  			if err := ch.fds.SendFD(f.FD()); err != nil {
   135  				return 0, err
   136  			}
   137  			f.Close()     // Per sendRecvLegacy.
   138  			sentFD = true // To mark below.
   139  		}
   140  	}
   141  
   142  	// Encode the message.
   143  	//
   144  	// Note that IPC itself encodes the length of messages, so we don't
   145  	// need to encode a standard 9P header. We write only the message type.
   146  	ch.reset(0)
   147  
   148  	ch.buf.WriteMsgType(m.Type())
   149  	if sentFD {
   150  		ch.buf.Write8(1) // Incoming FD.
   151  	} else {
   152  		ch.buf.Write8(0) // No incoming FD.
   153  	}
   154  	m.encode(&ch.buf)
   155  	ssz := uint32(len(ch.buf.data)) // Updated below.
   156  
   157  	// Is there a payload?
   158  	if payloader, ok := m.(payloader); ok {
   159  		p := payloader.Payload()
   160  		copy(ch.data.Data()[ssz:], p)
   161  		ssz += uint32(len(p))
   162  	}
   163  
   164  	// Perform the one-shot communication.
   165  	return ch.data.SendRecv(ssz)
   166  }
   167  
   168  // recv decodes a message that exists on the channel.
   169  //
   170  // If the passed r is non-nil, then the type must match or an error will be
   171  // generated. If the passed r is nil, then a new message will be created and
   172  // returned.
   173  func (ch *channel) recv(r message, rsz uint32) (message, error) {
   174  	// Decode the response from the inline buffer.
   175  	ch.reset(rsz)
   176  	t := ch.buf.ReadMsgType()
   177  	hasFD := ch.buf.Read8() != 0
   178  	if t == MsgRlerror {
   179  		// Change the message type. We check for this special case
   180  		// after decoding below, and transform into an error.
   181  		r = &Rlerror{}
   182  	} else if r == nil {
   183  		nr, err := msgRegistry.get(0, t)
   184  		if err != nil {
   185  			return nil, err
   186  		}
   187  		r = nr // New message.
   188  	} else if t != r.Type() {
   189  		// Not an error and not the expected response; propagate.
   190  		return nil, &ErrBadResponse{Got: t, Want: r.Type()}
   191  	}
   192  
   193  	// Is there a payload? Copy from the latter portion.
   194  	if payloader, ok := r.(payloader); ok {
   195  		fs := payloader.FixedSize()
   196  		p := payloader.Payload()
   197  		payloadData := ch.buf.data[fs:]
   198  		if len(p) < len(payloadData) {
   199  			p = make([]byte, len(payloadData))
   200  			copy(p, payloadData)
   201  			payloader.SetPayload(p)
   202  		} else if n := copy(p, payloadData); n < len(p) {
   203  			payloader.SetPayload(p[:n])
   204  		}
   205  		ch.buf.data = ch.buf.data[:fs]
   206  	}
   207  
   208  	r.decode(&ch.buf)
   209  	if ch.buf.isOverrun() {
   210  		// Nothing valid was available.
   211  		log.Debugf("recv [got %d bytes, needed more]", rsz)
   212  		return nil, ErrNoValidMessage
   213  	}
   214  
   215  	// Read any FD result.
   216  	if hasFD {
   217  		if rfd, err := ch.fds.RecvFDNonblock(); err == nil {
   218  			f := fd.New(rfd)
   219  			if filer, ok := r.(filer); ok {
   220  				// Set the payload.
   221  				filer.SetFilePayload(f)
   222  			} else {
   223  				// Don't want the FD.
   224  				f.Close()
   225  			}
   226  		} else {
   227  			// The header bit was set but nothing came in.
   228  			log.Warningf("expected FD, got err: %v", err)
   229  		}
   230  	}
   231  
   232  	// Log a message.
   233  	if log.IsLogging(log.Debug) {
   234  		log.Debugf("recv [channel @%p] %s", ch, r.String())
   235  	}
   236  
   237  	// Convert errors appropriately; see above.
   238  	if rlerr, ok := r.(*Rlerror); ok {
   239  		return r, unix.Errno(rlerr.Error)
   240  	}
   241  
   242  	return r, nil
   243  }