github.com/ttpreport/gvisor-ligolo@v0.0.0-20240123134145-a858404967ba/pkg/lisafs/sock.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package lisafs
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  
    21  	"github.com/ttpreport/gvisor-ligolo/pkg/log"
    22  	"github.com/ttpreport/gvisor-ligolo/pkg/unet"
    23  	"golang.org/x/sys/unix"
    24  )
    25  
    26  var (
    27  	sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes())
    28  )
    29  
// sockHeader is the header present in front of each message sent or received
// on a UDS: sndPrepopulatedMsg marshals it into the start of the outgoing
// buffer, and rcvMsg unmarshals it before reading the payload.
//
// payloadLen is the number of payload bytes that follow the header; message is
// the message ID. The struct is marshalled directly (MarshalUnsafe /
// UnmarshalUnsafe), so its layout is the wire format: do not reorder or resize
// fields. NOTE(review): the trailing blank field is explicit padding that
// presumably assumes MID is 2 bytes wide (making the header 8 bytes with no
// implicit padding) — confirm against MID's definition.
//
// +marshal
type sockHeader struct {
	payloadLen uint32
	message    MID
	_          uint16 // Need to make struct packed.
}
    38  
// sockCommunicator implements Communicator over a unet.Socket. This is not
// thread safe: buf is reused by every send and receive without any locking.
//
// The embedded fdTracker receives every FD donated by the peer (rcvMsg hands
// each one to TrackFD). sock is the underlying socket. buf is a reusable
// scratch buffer holding a message header immediately followed by its payload
// (see PayloadBuf).
type sockCommunicator struct {
	fdTracker
	sock *unet.Socket
	buf  []byte
}
    45  
// Compile-time check that *sockCommunicator satisfies the Communicator interface.
var _ Communicator = (*sockCommunicator)(nil)
    47  
    48  func newSockComm(sock *unet.Socket) *sockCommunicator {
    49  	return &sockCommunicator{
    50  		sock: sock,
    51  		buf:  make([]byte, sockHeaderLen),
    52  	}
    53  }
    54  
    55  func (s *sockCommunicator) FD() int {
    56  	return s.sock.FD()
    57  }
    58  
    59  func (s *sockCommunicator) destroy() {
    60  	s.sock.Close()
    61  }
    62  
    63  func (s *sockCommunicator) shutdown() {
    64  	if err := s.sock.Shutdown(); err != nil {
    65  		log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err)
    66  	}
    67  }
    68  
    69  func (s *sockCommunicator) resizeBuf(size uint32) {
    70  	if cap(s.buf) < int(size) {
    71  		s.buf = s.buf[:cap(s.buf)]
    72  		s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...)
    73  	} else {
    74  		s.buf = s.buf[:size]
    75  	}
    76  }
    77  
    78  // PayloadBuf implements Communicator.PayloadBuf.
    79  func (s *sockCommunicator) PayloadBuf(size uint32) []byte {
    80  	s.resizeBuf(sockHeaderLen + size)
    81  	return s.buf[sockHeaderLen : sockHeaderLen+size]
    82  }
    83  
    84  // SndRcvMessage implements Communicator.SndRcvMessage.
    85  func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) {
    86  	// Map the transport errors to EIO, but also log the real error.
    87  	if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil {
    88  		log.Warningf("socketCommunicator.SndRcvMessage: sndPrepopulatedMsg failed: %v", err)
    89  		return 0, 0, unix.EIO
    90  	}
    91  
    92  	respM, respPayloadLen, err := s.rcvMsg(wantFDs)
    93  	if err != nil {
    94  		log.Warningf("socketCommunicator.SndRcvMessage: rcvMsg failed: %v", err)
    95  		return 0, 0, unix.EIO
    96  	}
    97  	return respM, respPayloadLen, nil
    98  }
    99  
   100  // String implements fmt.Stringer.String.
   101  func (s *sockCommunicator) String() string {
   102  	return fmt.Sprintf("sockComm %d", s.sock.FD())
   103  }
   104  
   105  // sndPrepopulatedMsg assumes that s.buf has already been populated with
   106  // `payloadLen` bytes of data.
   107  func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error {
   108  	header := sockHeader{payloadLen: payloadLen, message: m}
   109  	header.MarshalUnsafe(s.buf)
   110  	dataLen := sockHeaderLen + payloadLen
   111  	return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds)
   112  }
   113  
   114  // writeTo writes the passed iovec to the UDS and donates any passed FDs.
   115  func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error {
   116  	w := sock.Writer(true)
   117  	if len(fds) > 0 {
   118  		w.PackFDs(fds...)
   119  	}
   120  
   121  	fdsUnpacked := false
   122  	for n := 0; n < dataLen; {
   123  		cur, err := w.WriteVec(iovec)
   124  		if err != nil {
   125  			return err
   126  		}
   127  		n += cur
   128  
   129  		// Fast common path.
   130  		if n >= dataLen {
   131  			break
   132  		}
   133  
   134  		// Consume iovecs.
   135  		for consumed := 0; consumed < cur; {
   136  			if len(iovec[0]) <= cur-consumed {
   137  				consumed += len(iovec[0])
   138  				iovec = iovec[1:]
   139  			} else {
   140  				iovec[0] = iovec[0][cur-consumed:]
   141  				break
   142  			}
   143  		}
   144  
   145  		if n > 0 && !fdsUnpacked {
   146  			// Don't resend any control message.
   147  			fdsUnpacked = true
   148  			w.UnpackFDs()
   149  		}
   150  	}
   151  	return nil
   152  }
   153  
   154  // rcvMsg reads the message header and payload from the UDS. It also populates
   155  // fds with any donated FDs.
   156  func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) {
   157  	fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs)
   158  	if err != nil {
   159  		return 0, 0, err
   160  	}
   161  	for _, fd := range fds {
   162  		s.TrackFD(fd)
   163  	}
   164  
   165  	var header sockHeader
   166  	header.UnmarshalUnsafe(s.buf)
   167  
   168  	// No payload? We are done.
   169  	if header.payloadLen == 0 {
   170  		return header.message, 0, nil
   171  	}
   172  
   173  	if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil {
   174  		return 0, 0, err
   175  	}
   176  
   177  	return header.message, header.payloadLen, nil
   178  }
   179  
   180  // readFrom fills the passed buffer with data from the socket. It also returns
   181  // any donated FDs.
   182  func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) {
   183  	r := sock.Reader(true)
   184  	r.EnableFDs(int(wantFDs))
   185  
   186  	var (
   187  		fds    []int
   188  		fdInit bool
   189  	)
   190  	n := len(buf)
   191  	for got := 0; got < n; {
   192  		cur, err := r.ReadVec([][]byte{buf[got:]})
   193  
   194  		// Ignore EOF if cur > 0.
   195  		if err != nil && (err != io.EOF || cur == 0) {
   196  			r.CloseFDs()
   197  			return nil, err
   198  		}
   199  
   200  		if !fdInit && cur > 0 {
   201  			fds, err = r.ExtractFDs()
   202  			if err != nil {
   203  				return nil, err
   204  			}
   205  
   206  			fdInit = true
   207  			r.EnableFDs(0)
   208  		}
   209  
   210  		got += cur
   211  	}
   212  	return fds, nil
   213  }
   214  
   215  func closeFDs(fds []int) {
   216  	for _, fd := range fds {
   217  		if fd >= 0 {
   218  			unix.Close(fd)
   219  		}
   220  	}
   221  }