// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package lisafs

import (
	"fmt"
	"io"

	"golang.org/x/sys/unix"
	"github.com/sagernet/gvisor/pkg/log"
	"github.com/sagernet/gvisor/pkg/unet"
)

var (
	// sockHeaderLen is the marshalled size of a sockHeader in bytes, as
	// reported by its (generated) SizeBytes method. Every message on the
	// socket starts with exactly this many header bytes.
	sockHeaderLen = uint32((*sockHeader)(nil).SizeBytes())
)

// sockHeader is the header present in front of each message received on a UDS.
//
// The field layout is consumed by generated marshalling code, so field names,
// order and padding must not be changed casually.
//
// +marshal
type sockHeader struct {
	// payloadLen is the number of payload bytes that immediately follow the
	// header on the wire.
	payloadLen uint32
	// message is the message ID (MID) carried by this message.
	message MID
	_       uint16 // Need to make struct packed.
}

// sockCommunicator implements Communicator over a unet.Socket (a UDS). A single
// buffer is reused for both sending and receiving messages. This is not thread
// safe.
40 type sockCommunicator struct { 41 fdTracker 42 sock *unet.Socket 43 buf []byte 44 } 45 46 var _ Communicator = (*sockCommunicator)(nil) 47 48 func newSockComm(sock *unet.Socket) *sockCommunicator { 49 return &sockCommunicator{ 50 sock: sock, 51 buf: make([]byte, sockHeaderLen), 52 } 53 } 54 55 func (s *sockCommunicator) FD() int { 56 return s.sock.FD() 57 } 58 59 func (s *sockCommunicator) destroy() { 60 s.sock.Close() 61 } 62 63 func (s *sockCommunicator) shutdown() { 64 if err := s.sock.Shutdown(); err != nil { 65 log.Warningf("Socket.Shutdown() failed (FD: %d): %v", s.sock.FD(), err) 66 } 67 } 68 69 func (s *sockCommunicator) resizeBuf(size uint32) { 70 if cap(s.buf) < int(size) { 71 s.buf = s.buf[:cap(s.buf)] 72 s.buf = append(s.buf, make([]byte, int(size)-cap(s.buf))...) 73 } else { 74 s.buf = s.buf[:size] 75 } 76 } 77 78 // PayloadBuf implements Communicator.PayloadBuf. 79 func (s *sockCommunicator) PayloadBuf(size uint32) []byte { 80 s.resizeBuf(sockHeaderLen + size) 81 return s.buf[sockHeaderLen : sockHeaderLen+size] 82 } 83 84 // SndRcvMessage implements Communicator.SndRcvMessage. 85 func (s *sockCommunicator) SndRcvMessage(m MID, payloadLen uint32, wantFDs uint8) (MID, uint32, error) { 86 // Map the transport errors to EIO, but also log the real error. 87 if err := s.sndPrepopulatedMsg(m, payloadLen, nil); err != nil { 88 log.Warningf("socketCommunicator.SndRcvMessage: sndPrepopulatedMsg failed: %v", err) 89 return 0, 0, unix.EIO 90 } 91 92 respM, respPayloadLen, err := s.rcvMsg(wantFDs) 93 if err != nil { 94 log.Warningf("socketCommunicator.SndRcvMessage: rcvMsg failed: %v", err) 95 return 0, 0, unix.EIO 96 } 97 return respM, respPayloadLen, nil 98 } 99 100 // String implements fmt.Stringer.String. 101 func (s *sockCommunicator) String() string { 102 return fmt.Sprintf("sockComm %d", s.sock.FD()) 103 } 104 105 // sndPrepopulatedMsg assumes that s.buf has already been populated with 106 // `payloadLen` bytes of data. 
107 func (s *sockCommunicator) sndPrepopulatedMsg(m MID, payloadLen uint32, fds []int) error { 108 header := sockHeader{payloadLen: payloadLen, message: m} 109 header.MarshalUnsafe(s.buf) 110 dataLen := sockHeaderLen + payloadLen 111 return writeTo(s.sock, [][]byte{s.buf[:dataLen]}, int(dataLen), fds) 112 } 113 114 // writeTo writes the passed iovec to the UDS and donates any passed FDs. 115 func writeTo(sock *unet.Socket, iovec [][]byte, dataLen int, fds []int) error { 116 w := sock.Writer(true) 117 if len(fds) > 0 { 118 w.PackFDs(fds...) 119 } 120 121 fdsUnpacked := false 122 for n := 0; n < dataLen; { 123 cur, err := w.WriteVec(iovec) 124 if err != nil { 125 return err 126 } 127 n += cur 128 129 // Fast common path. 130 if n >= dataLen { 131 break 132 } 133 134 // Consume iovecs. 135 for consumed := 0; consumed < cur; { 136 if len(iovec[0]) <= cur-consumed { 137 consumed += len(iovec[0]) 138 iovec = iovec[1:] 139 } else { 140 iovec[0] = iovec[0][cur-consumed:] 141 break 142 } 143 } 144 145 if n > 0 && !fdsUnpacked { 146 // Don't resend any control message. 147 fdsUnpacked = true 148 w.UnpackFDs() 149 } 150 } 151 return nil 152 } 153 154 // rcvMsg reads the message header and payload from the UDS. It also populates 155 // fds with any donated FDs. 156 func (s *sockCommunicator) rcvMsg(wantFDs uint8) (MID, uint32, error) { 157 fds, err := readFrom(s.sock, s.buf[:sockHeaderLen], wantFDs) 158 if err != nil { 159 return 0, 0, err 160 } 161 for _, fd := range fds { 162 s.TrackFD(fd) 163 } 164 165 var header sockHeader 166 header.UnmarshalUnsafe(s.buf) 167 168 // No payload? We are done. 169 if header.payloadLen == 0 { 170 return header.message, 0, nil 171 } 172 173 if _, err := readFrom(s.sock, s.PayloadBuf(header.payloadLen), 0); err != nil { 174 return 0, 0, err 175 } 176 177 return header.message, header.payloadLen, nil 178 } 179 180 // readFrom fills the passed buffer with data from the socket. It also returns 181 // any donated FDs. 
182 func readFrom(sock *unet.Socket, buf []byte, wantFDs uint8) ([]int, error) { 183 r := sock.Reader(true) 184 r.EnableFDs(int(wantFDs)) 185 186 var ( 187 fds []int 188 fdInit bool 189 ) 190 n := len(buf) 191 for got := 0; got < n; { 192 cur, err := r.ReadVec([][]byte{buf[got:]}) 193 194 // Ignore EOF if cur > 0. 195 if err != nil && (err != io.EOF || cur == 0) { 196 r.CloseFDs() 197 return nil, err 198 } 199 200 if !fdInit && cur > 0 { 201 fds, err = r.ExtractFDs() 202 if err != nil { 203 return nil, err 204 } 205 206 fdInit = true 207 r.EnableFDs(0) 208 } 209 210 got += cur 211 } 212 return fds, nil 213 } 214 215 func closeFDs(fds []int) { 216 for _, fd := range fds { 217 if fd >= 0 { 218 unix.Close(fd) 219 } 220 } 221 }