github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/p9/client.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package p9 16 17 import ( 18 "errors" 19 "fmt" 20 21 "golang.org/x/sys/unix" 22 "github.com/sagernet/gvisor/pkg/flipcall" 23 "github.com/sagernet/gvisor/pkg/log" 24 "github.com/sagernet/gvisor/pkg/pool" 25 "github.com/sagernet/gvisor/pkg/sync" 26 "github.com/sagernet/gvisor/pkg/unet" 27 ) 28 29 // ErrOutOfTags indicates no tags are available. 30 var ErrOutOfTags = errors.New("out of tags -- messages lost?") 31 32 // ErrOutOfFIDs indicates no more FIDs are available. 33 var ErrOutOfFIDs = errors.New("out of FIDs -- messages lost?") 34 35 // ErrUnexpectedTag indicates a response with an unexpected tag was received. 36 var ErrUnexpectedTag = errors.New("unexpected tag in response") 37 38 // ErrVersionsExhausted indicates that all versions to negotiate have been exhausted. 39 var ErrVersionsExhausted = errors.New("exhausted all versions to negotiate") 40 41 // ErrBadVersionString indicates that the version string is malformed or unsupported. 42 var ErrBadVersionString = errors.New("bad version string") 43 44 // ErrBadResponse indicates the response didn't match the request. 45 type ErrBadResponse struct { 46 Got MsgType 47 Want MsgType 48 } 49 50 // Error returns a highly descriptive error. 51 func (e *ErrBadResponse) Error() string { 52 return fmt.Sprintf("unexpected message type: got %v, want %v", e.Got, e.Want) 53 } 54 55 // response is the asynchronous return from recv. 56 // 57 // This is used in the pending map below. 58 type response struct { 59 r message 60 done chan error 61 } 62 63 var responsePool = sync.Pool{ 64 New: func() any { 65 return &response{ 66 done: make(chan error, 1), 67 } 68 }, 69 } 70 71 // Client is at least a 9P2000.L client. 72 type Client struct { 73 // socket is the connected socket. 74 socket *unet.Socket 75 76 // tagPool is the collection of available tags. 77 tagPool pool.Pool 78 79 // fidPool is the collection of available fids. 80 fidPool pool.Pool 81 82 // messageSize is the maximum total size of a message. 83 messageSize uint32 84 85 // payloadSize is the maximum payload size of a read or write. 86 // 87 // For large reads and writes this means that the read or write is 88 // broken up into buffer-size/payloadSize requests. 89 payloadSize uint32 90 91 // version is the agreed upon version X of 9P2000.L.Google.X. 92 // version 0 implies 9P2000.L. 93 version uint32 94 95 // closedWg is marked as done when the Client.watch() goroutine, which is 96 // responsible for closing channels and the socket fd, returns. 97 closedWg sync.WaitGroup 98 99 // sendRecv is the transport function. 100 // 101 // This is determined dynamically based on whether or not the server 102 // supports flipcall channels (preferred as it is faster and more 103 // efficient, and does not require tags). 104 sendRecv func(message, message) error 105 106 // -- below corresponds to sendRecvChannel -- 107 108 // channelsMu protects channels. 109 channelsMu sync.Mutex 110 111 // channelsWg counts the number of channels for which channel.active == 112 // true. 113 channelsWg sync.WaitGroup 114 115 // channels is the set of all initialized channels. 116 channels []*channel 117 118 // availableChannels is a LIFO of inactive channels. 119 availableChannels []*channel 120 121 // -- below corresponds to sendRecvLegacy -- 122 123 // pending is the set of pending messages. 124 pending map[Tag]*response 125 pendingMu sync.Mutex 126 127 // sendMu is the lock for sending a request. 128 sendMu sync.Mutex 129 130 // recvr is essentially a mutex for calling recv. 131 // 132 // Whoever writes to this channel is permitted to call recv. When 133 // finished calling recv, this channel should be emptied. 134 recvr chan bool 135 } 136 137 // NewClient creates a new client. It performs a Tversion exchange with 138 // the server to assert that messageSize is ok to use. 139 // 140 // If NewClient succeeds, ownership of socket is transferred to the new Client. 141 func NewClient(socket *unet.Socket, messageSize uint32, version string) (*Client, error) { 142 // Need at least one byte of payload. 143 if messageSize <= msgRegistry.largestFixedSize { 144 return nil, &ErrMessageTooLarge{ 145 size: messageSize, 146 msize: msgRegistry.largestFixedSize, 147 } 148 } 149 150 // Compute a payload size and round to 512 (normal block size) 151 // if it's larger than a single block. 152 payloadSize := messageSize - msgRegistry.largestFixedSize 153 if payloadSize > 512 && payloadSize%512 != 0 { 154 payloadSize -= (payloadSize % 512) 155 } 156 c := &Client{ 157 socket: socket, 158 tagPool: pool.Pool{Start: 1, Limit: uint64(NoTag)}, 159 fidPool: pool.Pool{Start: 1, Limit: uint64(NoFID)}, 160 pending: make(map[Tag]*response), 161 recvr: make(chan bool, 1), 162 messageSize: messageSize, 163 payloadSize: payloadSize, 164 } 165 // Agree upon a version. 166 requested, ok := parseVersion(version) 167 if !ok { 168 return nil, ErrBadVersionString 169 } 170 for { 171 // Always exchange the version using the legacy version of the 172 // protocol. If the protocol supports flipcall, then we switch 173 // our sendRecv function to use that functionality. Otherwise, 174 // we stick to sendRecvLegacy. 175 rversion := Rversion{} 176 _, err := c.sendRecvLegacy(&Tversion{ 177 Version: versionString(requested), 178 MSize: messageSize, 179 }, &rversion) 180 181 // The server told us to try again with a lower version. 182 if err == unix.EAGAIN { 183 if requested == lowestSupportedVersion { 184 return nil, ErrVersionsExhausted 185 } 186 requested-- 187 continue 188 } 189 190 // We requested an impossible version or our other parameters were bogus. 191 if err != nil { 192 return nil, err 193 } 194 195 // Parse the version. 196 version, ok := parseVersion(rversion.Version) 197 if !ok { 198 // The server gave us a bad version. We return a generically worrisome error. 199 log.Warningf("server returned bad version string %q", rversion.Version) 200 return nil, ErrBadVersionString 201 } 202 c.version = version 203 break 204 } 205 206 // Can we switch to use the more advanced channels and create 207 // independent channels for communication? Prefer it if possible. 208 if versionSupportsFlipcall(c.version) { 209 // Attempt to initialize IPC-based communication. 210 for i := 0; i < channelsPerClient; i++ { 211 if err := c.openChannel(i); err != nil { 212 log.Warningf("error opening flipcall channel: %v", err) 213 break // Stop. 214 } 215 } 216 if len(c.channels) >= 1 { 217 // At least one channel created. 218 c.sendRecv = c.sendRecvChannel 219 } else { 220 // Channel setup failed; fallback. 221 c.sendRecv = c.sendRecvLegacySyscallErr 222 } 223 } else { 224 // No channels available: use the legacy mechanism. 225 c.sendRecv = c.sendRecvLegacySyscallErr 226 } 227 228 // Ensure that the socket and channels are closed when the socket is shut 229 // down. 230 c.closedWg.Add(1) 231 go c.watch(socket) // S/R-SAFE: not relevant. 232 233 return c, nil 234 } 235 236 // watch watches the given socket and releases resources on hangup events. 237 // 238 // This is intended to be called as a goroutine. 239 func (c *Client) watch(socket *unet.Socket) { 240 defer c.closedWg.Done() 241 242 events := []unix.PollFd{ 243 { 244 Fd: int32(socket.FD()), 245 Events: unix.POLLHUP | unix.POLLRDHUP, 246 }, 247 } 248 249 // Wait for a shutdown event. 250 for { 251 n, err := unix.Ppoll(events, nil, nil) 252 if err == unix.EINTR || err == unix.EAGAIN { 253 continue 254 } 255 if err != nil { 256 log.Warningf("p9.Client.watch(): %v", err) 257 break 258 } 259 if n != 1 { 260 log.Warningf("p9.Client.watch(): got %d events, wanted 1", n) 261 } 262 break 263 } 264 265 // Set availableChannels to nil so that future calls to c.sendRecvChannel() 266 // don't attempt to activate a channel, and concurrent calls to 267 // c.sendRecvChannel() don't mark released channels as available. 268 c.channelsMu.Lock() 269 c.availableChannels = nil 270 271 // Shut down all active channels. 272 for _, ch := range c.channels { 273 if ch.active { 274 log.Debugf("shutting down active channel@%p...", ch) 275 ch.Shutdown() 276 } 277 } 278 c.channelsMu.Unlock() 279 280 // Wait for active channels to become inactive. 281 c.channelsWg.Wait() 282 283 // Close all channels. 284 c.channelsMu.Lock() 285 for _, ch := range c.channels { 286 ch.Close() 287 } 288 c.channelsMu.Unlock() 289 290 // Close the main socket. 291 c.socket.Close() 292 } 293 294 // openChannel attempts to open a client channel. 295 // 296 // Note that this function returns naked errors which should not be propagated 297 // directly to a caller. It is expected that the errors will be logged and a 298 // fallback path will be used instead. 299 func (c *Client) openChannel(id int) error { 300 var ( 301 rchannel0 Rchannel 302 rchannel1 Rchannel 303 res = new(channel) 304 ) 305 306 // Open the data channel. 307 if _, err := c.sendRecvLegacy(&Tchannel{ 308 ID: uint32(id), 309 Control: 0, 310 }, &rchannel0); err != nil { 311 return fmt.Errorf("error handling Tchannel message: %v", err) 312 } 313 if rchannel0.FilePayload() == nil { 314 return fmt.Errorf("missing file descriptor on primary channel") 315 } 316 317 // We don't need to hold this. 318 defer rchannel0.FilePayload().Close() 319 320 // Open the channel for file descriptors. 321 if _, err := c.sendRecvLegacy(&Tchannel{ 322 ID: uint32(id), 323 Control: 1, 324 }, &rchannel1); err != nil { 325 return err 326 } 327 if rchannel1.FilePayload() == nil { 328 return fmt.Errorf("missing file descriptor on file descriptor channel") 329 } 330 331 // Construct the endpoints. 332 res.desc = flipcall.PacketWindowDescriptor{ 333 FD: rchannel0.FilePayload().FD(), 334 Offset: int64(rchannel0.Offset), 335 Length: int(rchannel0.Length), 336 } 337 if err := res.data.Init(flipcall.ClientSide, res.desc); err != nil { 338 rchannel1.FilePayload().Close() 339 return err 340 } 341 342 // The fds channel owns the control payload, and it will be closed when 343 // the channel object is closed. 344 res.fds.Init(rchannel1.FilePayload().Release()) 345 346 // Save the channel. 347 c.channelsMu.Lock() 348 defer c.channelsMu.Unlock() 349 c.channels = append(c.channels, res) 350 c.availableChannels = append(c.availableChannels, res) 351 return nil 352 } 353 354 // handleOne handles a single incoming message. 355 // 356 // This should only be called with the token from recvr. Note that the received 357 // tag will automatically be cleared from pending. 358 func (c *Client) handleOne() { 359 tag, r, err := recv(c.socket, c.messageSize, func(tag Tag, t MsgType) (message, error) { 360 c.pendingMu.Lock() 361 resp := c.pending[tag] 362 c.pendingMu.Unlock() 363 364 // Not expecting this message? 365 if resp == nil { 366 log.Warningf("client received unexpected tag %v, ignoring", tag) 367 return nil, ErrUnexpectedTag 368 } 369 370 // Is it an error? We specifically allow this to 371 // go through, and then we deserialize below. 372 if t == MsgRlerror { 373 return &Rlerror{}, nil 374 } 375 376 // Does it match expectations? 377 if t != resp.r.Type() { 378 return nil, &ErrBadResponse{Got: t, Want: resp.r.Type()} 379 } 380 381 // Return the response. 382 return resp.r, nil 383 }) 384 385 if err != nil { 386 // No tag was extracted (probably a socket error). 387 // 388 // Likely catastrophic. Notify all waiters and clear pending. 389 c.pendingMu.Lock() 390 for _, resp := range c.pending { 391 resp.done <- err 392 } 393 clear(c.pending) 394 c.pendingMu.Unlock() 395 } else { 396 // Process the tag. 397 // 398 // We know that is is contained in the map because our lookup function 399 // above must have succeeded (found the tag) to return nil err. 400 c.pendingMu.Lock() 401 resp := c.pending[tag] 402 delete(c.pending, tag) 403 c.pendingMu.Unlock() 404 resp.r = r 405 resp.done <- err 406 } 407 } 408 409 // waitAndRecv coordinates with other receivers to handle responses. 410 func (c *Client) waitAndRecv(done chan error) error { 411 for { 412 select { 413 case err := <-done: 414 return err 415 case c.recvr <- true: 416 select { 417 case err := <-done: 418 // It's possible that we got the token, despite 419 // done also being available. Check for that. 420 <-c.recvr 421 return err 422 default: 423 // Handle receiving one tag. 424 c.handleOne() 425 426 // Return the token. 427 <-c.recvr 428 } 429 } 430 } 431 } 432 433 // sendRecvLegacySyscallErr is a wrapper for sendRecvLegacy that converts all 434 // non-syscall errors to EIO. 435 func (c *Client) sendRecvLegacySyscallErr(t message, r message) error { 436 received, err := c.sendRecvLegacy(t, r) 437 if !received { 438 log.Warningf("p9.Client.sendRecvChannel: %v", err) 439 return unix.EIO 440 } 441 return err 442 } 443 444 // sendRecvLegacy performs a roundtrip message exchange. 445 // 446 // sendRecvLegacy returns true if a message was received. This allows us to 447 // differentiate between failed receives and successful receives where the 448 // response was an error message. 449 // 450 // This is called by internal functions. 451 func (c *Client) sendRecvLegacy(t message, r message) (bool, error) { 452 tag, ok := c.tagPool.Get() 453 if !ok { 454 return false, ErrOutOfTags 455 } 456 defer c.tagPool.Put(tag) 457 458 // Indicate we're expecting a response. 459 // 460 // Note that the tag will be cleared from pending 461 // automatically (see handleOne for details). 462 resp := responsePool.Get().(*response) 463 defer responsePool.Put(resp) 464 resp.r = r 465 c.pendingMu.Lock() 466 c.pending[Tag(tag)] = resp 467 c.pendingMu.Unlock() 468 469 // Send the request over the wire. 470 c.sendMu.Lock() 471 err := send(c.socket, Tag(tag), t) 472 c.sendMu.Unlock() 473 if err != nil { 474 return false, err 475 } 476 477 // Coordinate with other receivers. 478 if err := c.waitAndRecv(resp.done); err != nil { 479 return false, err 480 } 481 482 // Is it an error message? 483 // 484 // For convenience, we transform these directly 485 // into errors. Handlers need not handle this case. 486 if rlerr, ok := resp.r.(*Rlerror); ok { 487 return true, unix.Errno(rlerr.Error) 488 } 489 490 // At this point, we know it matches. 491 // 492 // Per recv call above, we will only allow a type 493 // match (and give our r) or an instance of Rlerror. 494 return true, nil 495 } 496 497 // sendRecvChannel uses channels to send a message. 498 func (c *Client) sendRecvChannel(t message, r message) error { 499 // Acquire an available channel. 500 c.channelsMu.Lock() 501 if len(c.availableChannels) == 0 { 502 c.channelsMu.Unlock() 503 return c.sendRecvLegacySyscallErr(t, r) 504 } 505 idx := len(c.availableChannels) - 1 506 ch := c.availableChannels[idx] 507 c.availableChannels = c.availableChannels[:idx] 508 ch.active = true 509 c.channelsWg.Add(1) 510 c.channelsMu.Unlock() 511 512 // Ensure that it's connected. 513 if !ch.connected { 514 ch.connected = true 515 if err := ch.data.Connect(); err != nil { 516 // The channel is unusable, so don't return it to 517 // c.availableChannels. However, we still have to mark it as 518 // inactive so c.watch() doesn't wait for it. 519 c.channelsMu.Lock() 520 ch.active = false 521 c.channelsMu.Unlock() 522 c.channelsWg.Done() 523 // Map all transport errors to EIO, but ensure that the real error 524 // is logged. 525 log.Warningf("p9.Client.sendRecvChannel: flipcall.Endpoint.Connect: %v", err) 526 return unix.EIO 527 } 528 } 529 530 // Send the request and receive the server's response. 531 rsz, err := ch.send(t, false /* isServer */) 532 if err != nil { 533 // See above. 534 c.channelsMu.Lock() 535 ch.active = false 536 c.channelsMu.Unlock() 537 c.channelsWg.Done() 538 log.Warningf("p9.Client.sendRecvChannel: p9.channel.send: %v", err) 539 return unix.EIO 540 } 541 542 // Parse the server's response. 543 resp, retErr := ch.recv(r, rsz) 544 if resp == nil { 545 log.Warningf("p9.Client.sendRecvChannel: p9.channel.recv: %v", retErr) 546 retErr = unix.EIO 547 } 548 549 // Release the channel. 550 c.channelsMu.Lock() 551 ch.active = false 552 // If c.availableChannels is nil, c.watch() has fired and we should not 553 // mark this channel as available. 554 if c.availableChannels != nil { 555 c.availableChannels = append(c.availableChannels, ch) 556 } 557 c.channelsMu.Unlock() 558 c.channelsWg.Done() 559 560 return retErr 561 } 562 563 // Version returns the negotiated 9P2000.L.Google version number. 564 func (c *Client) Version() uint32 { 565 return c.version 566 } 567 568 // Close closes the underlying socket and channels. 569 func (c *Client) Close() { 570 // unet.Socket.Shutdown() has no effect if unet.Socket.Close() has already 571 // been called (by c.watch()). 572 if err := c.socket.Shutdown(); err != nil { 573 log.Warningf("Socket.Shutdown() failed (FD: %d): %v", c.socket.FD(), err) 574 } 575 c.closedWg.Wait() 576 }