github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/p9/transport.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package p9 16 17 import ( 18 "errors" 19 "fmt" 20 "io" 21 "io/ioutil" 22 23 "golang.org/x/sys/unix" 24 "github.com/sagernet/gvisor/pkg/fd" 25 "github.com/sagernet/gvisor/pkg/log" 26 "github.com/sagernet/gvisor/pkg/sync" 27 "github.com/sagernet/gvisor/pkg/unet" 28 ) 29 30 // ErrSocket is returned in cases of a socket issue. 31 // 32 // This may be treated differently than other errors. 33 type ErrSocket struct { 34 // error is the socket error. 35 error 36 } 37 38 // ErrMessageTooLarge indicates the size was larger than reasonable. 39 type ErrMessageTooLarge struct { 40 size uint32 41 msize uint32 42 } 43 44 // Error returns a sensible error. 45 func (e *ErrMessageTooLarge) Error() string { 46 return fmt.Sprintf("message too large for fixed buffer: size is %d, limit is %d", e.size, e.msize) 47 } 48 49 // ErrNoValidMessage indicates no valid message could be decoded. 50 var ErrNoValidMessage = errors.New("buffer contained no valid message") 51 52 const ( 53 // headerLength is the number of bytes required for a header. 54 headerLength uint32 = 7 55 56 // maximumLength is the largest possible message. 57 maximumLength uint32 = 1 << 20 58 59 // DefaultMessageSize is a sensible default. 60 DefaultMessageSize uint32 = 64 << 10 61 62 // initialBufferLength is the initial data buffer we allocate. 63 initialBufferLength uint32 = 64 64 ) 65 66 var dataPool = sync.Pool{ 67 New: func() any { 68 // These buffers are used for decoding without a payload. 69 // We need to return a pointer to avoid unnecessary allocations 70 // (see https://staticcheck.io/docs/checks#SA6002). 71 b := make([]byte, initialBufferLength) 72 return &b 73 }, 74 } 75 76 // send sends the given message over the socket. 77 func send(s *unet.Socket, tag Tag, m message) error { 78 data := dataPool.Get().(*[]byte) 79 dataBuf := buffer{data: (*data)[:0]} 80 81 if log.IsLogging(log.Debug) { 82 log.Debugf("send [FD %d] [Tag %06d] %s", s.FD(), tag, m.String()) 83 } 84 85 // Encode the message. The buffer will grow automatically. 86 m.encode(&dataBuf) 87 88 // Get our vectors to send. 89 var hdr [headerLength]byte 90 vecs := make([][]byte, 0, 3) 91 vecs = append(vecs, hdr[:]) 92 if len(dataBuf.data) > 0 { 93 vecs = append(vecs, dataBuf.data) 94 } 95 totalLength := headerLength + uint32(len(dataBuf.data)) 96 97 // Is there a payload? 98 if payloader, ok := m.(payloader); ok { 99 p := payloader.Payload() 100 if len(p) > 0 { 101 vecs = append(vecs, p) 102 totalLength += uint32(len(p)) 103 } 104 } 105 106 // Construct the header. 107 headerBuf := buffer{data: hdr[:0]} 108 headerBuf.Write32(totalLength) 109 headerBuf.WriteMsgType(m.Type()) 110 headerBuf.WriteTag(tag) 111 112 // Pack any files if necessary. 113 w := s.Writer(true) 114 if filer, ok := m.(filer); ok { 115 if f := filer.FilePayload(); f != nil { 116 defer f.Close() 117 // Pack the file into the message. 118 w.PackFDs(f.FD()) 119 } 120 } 121 122 for n := 0; n < int(totalLength); { 123 cur, err := w.WriteVec(vecs) 124 if err != nil { 125 return ErrSocket{err} 126 } 127 n += cur 128 129 // Consume iovecs. 130 for consumed := 0; consumed < cur; { 131 if len(vecs[0]) <= cur-consumed { 132 consumed += len(vecs[0]) 133 vecs = vecs[1:] 134 } else { 135 vecs[0] = vecs[0][cur-consumed:] 136 break 137 } 138 } 139 140 if n > 0 && n < int(totalLength) { 141 // Don't resend any control message. 142 w.UnpackFDs() 143 } 144 } 145 146 // All set. 147 dataPool.Put(&dataBuf.data) 148 return nil 149 } 150 151 // lookupTagAndType looks up an existing message or creates a new one. 152 // 153 // This is called by recv after decoding the header. Any error returned will be 154 // propagating back to the caller. You may use messageByType directly as a 155 // lookupTagAndType function (by design). 156 type lookupTagAndType func(tag Tag, t MsgType) (message, error) 157 158 // recv decodes a message from the socket. 159 // 160 // This is done in two parts, and is thus not safe for multiple callers. 161 // 162 // On a socket error, the special error type ErrSocket is returned. 163 // 164 // The tag value NoTag will always be returned if err is non-nil. 165 func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, error) { 166 // Read a header. 167 // 168 // Since the send above is atomic, we must always receive control 169 // messages along with the header. This means we need to be careful 170 // about closing FDs during errors to prevent leaks. 171 var hdr [headerLength]byte 172 r := s.Reader(true) 173 r.EnableFDs(1) 174 175 n, err := r.ReadVec([][]byte{hdr[:]}) 176 if err != nil && (n == 0 || err != io.EOF) { 177 r.CloseFDs() 178 return NoTag, nil, ErrSocket{err} 179 } 180 181 fds, err := r.ExtractFDs() 182 if err != nil { 183 return NoTag, nil, ErrSocket{err} 184 } 185 defer func() { 186 // Close anything left open. The case where 187 // fds are caught and used is handled below, 188 // and the fds variable will be set to nil. 189 for _, fd := range fds { 190 unix.Close(fd) 191 } 192 }() 193 r.EnableFDs(0) 194 195 // Continuing reading for a short header. 196 for n < int(headerLength) { 197 cur, err := r.ReadVec([][]byte{hdr[n:]}) 198 if err != nil && (cur == 0 || err != io.EOF) { 199 return NoTag, nil, ErrSocket{err} 200 } 201 n += cur 202 } 203 204 // Decode the header. 205 headerBuf := buffer{data: hdr[:]} 206 size := headerBuf.Read32() 207 t := headerBuf.ReadMsgType() 208 tag := headerBuf.ReadTag() 209 if size < headerLength { 210 // The message is too small. 211 // 212 // See above: it's probably screwed. 213 return NoTag, nil, ErrSocket{ErrNoValidMessage} 214 } 215 if size > maximumLength || size > msize { 216 // The message is too big. 217 return NoTag, nil, ErrSocket{&ErrMessageTooLarge{size, msize}} 218 } 219 remaining := size - headerLength 220 221 // Find our message to decode. 222 m, err := lookup(tag, t) 223 if err != nil { 224 // Throw away the contents of this message. 225 if remaining > 0 { 226 io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)}) 227 } 228 return tag, nil, err 229 } 230 231 // Not yet initialized. 232 var dataBuf buffer 233 var vecs [][]byte 234 235 appendBuffer := func(size int) *[]byte { 236 // Pull a data buffer from the pool. 237 datap := dataPool.Get().(*[]byte) 238 data := *datap 239 if size > len(data) { 240 // Create a larger data buffer. 241 data = make([]byte, size) 242 datap = &data 243 } else { 244 // Limit the data buffer. 245 data = data[:size] 246 } 247 dataBuf = buffer{data: data} 248 vecs = append(vecs, data) 249 return datap 250 } 251 252 // Read the rest of the payload. 253 // 254 // This requires some special care to ensure that the vectors all line 255 // up the way they should. We do this to minimize copying data around. 256 if payloader, ok := m.(payloader); ok { 257 fixedSize := payloader.FixedSize() 258 259 // Do we need more than there is? 260 if fixedSize > remaining { 261 // This is not a valid message. 262 if remaining > 0 { 263 io.Copy(ioutil.Discard, &io.LimitedReader{R: s, N: int64(remaining)}) 264 } 265 return NoTag, nil, ErrNoValidMessage 266 } 267 268 if fixedSize != 0 { 269 datap := appendBuffer(int(fixedSize)) 270 defer dataPool.Put(datap) 271 } 272 273 // Include the payload. 274 p := payloader.Payload() 275 if p == nil || len(p) != int(remaining-fixedSize) { 276 p = make([]byte, remaining-fixedSize) 277 payloader.SetPayload(p) 278 } 279 if len(p) > 0 { 280 vecs = append(vecs, p) 281 } 282 } else if remaining != 0 { 283 datap := appendBuffer(int(remaining)) 284 defer dataPool.Put(datap) 285 } 286 287 if len(vecs) > 0 { 288 // Read the rest of the message. 289 // 290 // No need to handle a control message. 291 r := s.Reader(true) 292 for n := 0; n < int(remaining); { 293 cur, err := r.ReadVec(vecs) 294 if err != nil && (cur == 0 || err != io.EOF) { 295 return NoTag, nil, ErrSocket{err} 296 } 297 n += cur 298 299 // Consume iovecs. 300 for consumed := 0; consumed < cur; { 301 if len(vecs[0]) <= cur-consumed { 302 consumed += len(vecs[0]) 303 vecs = vecs[1:] 304 } else { 305 vecs[0] = vecs[0][cur-consumed:] 306 break 307 } 308 } 309 } 310 } 311 312 // Decode the message data. 313 m.decode(&dataBuf) 314 if dataBuf.isOverrun() { 315 // No need to drain the socket. 316 return NoTag, nil, ErrNoValidMessage 317 } 318 319 // Save the file, if any came out. 320 if filer, ok := m.(filer); ok && len(fds) > 0 { 321 // Set the file object. 322 filer.SetFilePayload(fd.New(fds[0])) 323 324 // Close the rest. We support only one. 325 for i := 1; i < len(fds); i++ { 326 unix.Close(fds[i]) 327 } 328 329 // Don't close in the defer. 330 fds = nil 331 } 332 333 if log.IsLogging(log.Debug) { 334 log.Debugf("recv [FD %d] [Tag %06d] %s", s.FD(), tag, m.String()) 335 } 336 337 // All set. 338 return tag, m, nil 339 }