github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/p9/transport.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package p9
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"io"
    21  	"io/ioutil"
    22  
    23  	"golang.org/x/sys/unix"
    24  	"github.com/nicocha30/gvisor-ligolo/pkg/fd"
    25  	"github.com/nicocha30/gvisor-ligolo/pkg/log"
    26  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
    27  	"github.com/nicocha30/gvisor-ligolo/pkg/unet"
    28  )
    29  
    30  // ErrSocket is returned in cases of a socket issue.
    31  //
    32  // This may be treated differently than other errors.
    33  type ErrSocket struct {
    34  	// error is the socket error.
    35  	error
    36  }
    37  
    38  // ErrMessageTooLarge indicates the size was larger than reasonable.
    39  type ErrMessageTooLarge struct {
    40  	size  uint32
    41  	msize uint32
    42  }
    43  
    44  // Error returns a sensible error.
    45  func (e *ErrMessageTooLarge) Error() string {
    46  	return fmt.Sprintf("message too large for fixed buffer: size is %d, limit is %d", e.size, e.msize)
    47  }
    48  
    49  // ErrNoValidMessage indicates no valid message could be decoded.
    50  var ErrNoValidMessage = errors.New("buffer contained no valid message")
    51  
    52  const (
    53  	// headerLength is the number of bytes required for a header.
    54  	headerLength uint32 = 7
    55  
    56  	// maximumLength is the largest possible message.
    57  	maximumLength uint32 = 1 << 20
    58  
    59  	// DefaultMessageSize is a sensible default.
    60  	DefaultMessageSize uint32 = 64 << 10
    61  
    62  	// initialBufferLength is the initial data buffer we allocate.
    63  	initialBufferLength uint32 = 64
    64  )
    65  
    66  var dataPool = sync.Pool{
    67  	New: func() any {
    68  		// These buffers are used for decoding without a payload.
    69  		// We need to return a pointer to avoid unnecessary allocations
    70  		// (see https://staticcheck.io/docs/checks#SA6002).
    71  		b := make([]byte, initialBufferLength)
    72  		return &b
    73  	},
    74  }
    75  
    76  // send sends the given message over the socket.
    77  func send(s *unet.Socket, tag Tag, m message) error {
    78  	data := dataPool.Get().(*[]byte)
    79  	dataBuf := buffer{data: (*data)[:0]}
    80  
    81  	if log.IsLogging(log.Debug) {
    82  		log.Debugf("send [FD %d] [Tag %06d] %s", s.FD(), tag, m.String())
    83  	}
    84  
    85  	// Encode the message. The buffer will grow automatically.
    86  	m.encode(&dataBuf)
    87  
    88  	// Get our vectors to send.
    89  	var hdr [headerLength]byte
    90  	vecs := make([][]byte, 0, 3)
    91  	vecs = append(vecs, hdr[:])
    92  	if len(dataBuf.data) > 0 {
    93  		vecs = append(vecs, dataBuf.data)
    94  	}
    95  	totalLength := headerLength + uint32(len(dataBuf.data))
    96  
    97  	// Is there a payload?
    98  	if payloader, ok := m.(payloader); ok {
    99  		p := payloader.Payload()
   100  		if len(p) > 0 {
   101  			vecs = append(vecs, p)
   102  			totalLength += uint32(len(p))
   103  		}
   104  	}
   105  
   106  	// Construct the header.
   107  	headerBuf := buffer{data: hdr[:0]}
   108  	headerBuf.Write32(totalLength)
   109  	headerBuf.WriteMsgType(m.Type())
   110  	headerBuf.WriteTag(tag)
   111  
   112  	// Pack any files if necessary.
   113  	w := s.Writer(true)
   114  	if filer, ok := m.(filer); ok {
   115  		if f := filer.FilePayload(); f != nil {
   116  			defer f.Close()
   117  			// Pack the file into the message.
   118  			w.PackFDs(f.FD())
   119  		}
   120  	}
   121  
   122  	for n := 0; n < int(totalLength); {
   123  		cur, err := w.WriteVec(vecs)
   124  		if err != nil {
   125  			return ErrSocket{err}
   126  		}
   127  		n += cur
   128  
   129  		// Consume iovecs.
   130  		for consumed := 0; consumed < cur; {
   131  			if len(vecs[0]) <= cur-consumed {
   132  				consumed += len(vecs[0])
   133  				vecs = vecs[1:]
   134  			} else {
   135  				vecs[0] = vecs[0][cur-consumed:]
   136  				break
   137  			}
   138  		}
   139  
   140  		if n > 0 && n < int(totalLength) {
   141  			// Don't resend any control message.
   142  			w.UnpackFDs()
   143  		}
   144  	}
   145  
   146  	// All set.
   147  	dataPool.Put(&dataBuf.data)
   148  	return nil
   149  }
   150  
   151  // lookupTagAndType looks up an existing message or creates a new one.
   152  //
   153  // This is called by recv after decoding the header. Any error returned will be
   154  // propagating back to the caller. You may use messageByType directly as a
   155  // lookupTagAndType function (by design).
   156  type lookupTagAndType func(tag Tag, t MsgType) (message, error)
   157  
   158  // recv decodes a message from the socket.
   159  //
   160  // This is done in two parts, and is thus not safe for multiple callers.
   161  //
   162  // On a socket error, the special error type ErrSocket is returned.
   163  //
   164  // The tag value NoTag will always be returned if err is non-nil.
   165  func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, error) {
   166  	// Read a header.
   167  	//
   168  	// Since the send above is atomic, we must always receive control
   169  	// messages along with the header. This means we need to be careful
   170  	// about closing FDs during errors to prevent leaks.
   171  	var hdr [headerLength]byte
   172  	r := s.Reader(true)
   173  	r.EnableFDs(1)
   174  
   175  	n, err := r.ReadVec([][]byte{hdr[:]})
   176  	if err != nil && (n == 0 || err != io.EOF) {
   177  		r.CloseFDs()
   178  		return NoTag, nil, ErrSocket{err}
   179  	}
   180  
   181  	fds, err := r.ExtractFDs()
   182  	if err != nil {
   183  		return NoTag, nil, ErrSocket{err}
   184  	}
   185  	defer func() {
   186  		// Close anything left open. The case where
   187  		// fds are caught and used is handled below,
   188  		// and the fds variable will be set to nil.
   189  		for _, fd := range fds {
   190  			unix.Close(fd)
   191  		}
   192  	}()
   193  	r.EnableFDs(0)
   194  
   195  	// Continuing reading for a short header.
   196  	for n < int(headerLength) {
   197  		cur, err := r.ReadVec([][]byte{hdr[n:]})
   198  		if err != nil && (cur == 0 || err != io.EOF) {
   199  			return NoTag, nil, ErrSocket{err}
   200  		}
   201  		n += cur
   202  	}
   203  
   204  	// Decode the header.
   205  	headerBuf := buffer{data: hdr[:]}
   206  	size := headerBuf.Read32()
   207  	t := headerBuf.ReadMsgType()
   208  	tag := headerBuf.ReadTag()
   209  	if size < headerLength {
   210  		// The message is too small.
   211  		//
   212  		// See above: it's probably screwed.
   213  		return NoTag, nil, ErrSocket{ErrNoValidMessage}
   214  	}
   215  	if size > maximumLength || size > msize {
   216  		// The message is too big.
   217  		return NoTag, nil, ErrSocket{&ErrMessageTooLarge{size, msize}}
   218  	}
   219  	remaining := size - headerLength
   220  
   221  	// Find our message to decode.
   222  	m, err := lookup(tag, t)
   223  	if err != nil {
   224  		// Throw away the contents of this message.
   225  		if remaining > 0 {
   226  			io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)})
   227  		}
   228  		return tag, nil, err
   229  	}
   230  
   231  	// Not yet initialized.
   232  	var dataBuf buffer
   233  	var vecs [][]byte
   234  
   235  	appendBuffer := func(size int) *[]byte {
   236  		// Pull a data buffer from the pool.
   237  		datap := dataPool.Get().(*[]byte)
   238  		data := *datap
   239  		if size > len(data) {
   240  			// Create a larger data buffer.
   241  			data = make([]byte, size)
   242  			datap = &data
   243  		} else {
   244  			// Limit the data buffer.
   245  			data = data[:size]
   246  		}
   247  		dataBuf = buffer{data: data}
   248  		vecs = append(vecs, data)
   249  		return datap
   250  	}
   251  
   252  	// Read the rest of the payload.
   253  	//
   254  	// This requires some special care to ensure that the vectors all line
   255  	// up the way they should. We do this to minimize copying data around.
   256  	if payloader, ok := m.(payloader); ok {
   257  		fixedSize := payloader.FixedSize()
   258  
   259  		// Do we need more than there is?
   260  		if fixedSize > remaining {
   261  			// This is not a valid message.
   262  			if remaining > 0 {
   263  				io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)})
   264  			}
   265  			return NoTag, nil, ErrNoValidMessage
   266  		}
   267  
   268  		if fixedSize != 0 {
   269  			datap := appendBuffer(int(fixedSize))
   270  			defer dataPool.Put(datap)
   271  		}
   272  
   273  		// Include the payload.
   274  		p := payloader.Payload()
   275  		if p == nil || len(p) != int(remaining-fixedSize) {
   276  			p = make([]byte, remaining-fixedSize)
   277  			payloader.SetPayload(p)
   278  		}
   279  		if len(p) > 0 {
   280  			vecs = append(vecs, p)
   281  		}
   282  	} else if remaining != 0 {
   283  		datap := appendBuffer(int(remaining))
   284  		defer dataPool.Put(datap)
   285  	}
   286  
   287  	if len(vecs) > 0 {
   288  		// Read the rest of the message.
   289  		//
   290  		// No need to handle a control message.
   291  		r := s.Reader(true)
   292  		for n := 0; n < int(remaining); {
   293  			cur, err := r.ReadVec(vecs)
   294  			if err != nil && (cur == 0 || err != io.EOF) {
   295  				return NoTag, nil, ErrSocket{err}
   296  			}
   297  			n += cur
   298  
   299  			// Consume iovecs.
   300  			for consumed := 0; consumed < cur; {
   301  				if len(vecs[0]) <= cur-consumed {
   302  					consumed += len(vecs[0])
   303  					vecs = vecs[1:]
   304  				} else {
   305  					vecs[0] = vecs[0][cur-consumed:]
   306  					break
   307  				}
   308  			}
   309  		}
   310  	}
   311  
   312  	// Decode the message data.
   313  	m.decode(&dataBuf)
   314  	if dataBuf.isOverrun() {
   315  		// No need to drain the socket.
   316  		return NoTag, nil, ErrNoValidMessage
   317  	}
   318  
   319  	// Save the file, if any came out.
   320  	if filer, ok := m.(filer); ok && len(fds) > 0 {
   321  		// Set the file object.
   322  		filer.SetFilePayload(fd.New(fds[0]))
   323  
   324  		// Close the rest. We support only one.
   325  		for i := 1; i < len(fds); i++ {
   326  			unix.Close(fds[i])
   327  		}
   328  
   329  		// Don't close in the defer.
   330  		fds = nil
   331  	}
   332  
   333  	if log.IsLogging(log.Debug) {
   334  		log.Debugf("recv [FD %d] [Tag %06d] %s", s.FD(), tag, m.String())
   335  	}
   336  
   337  	// All set.
   338  	return tag, m, nil
   339  }