github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/lisafs/client.go (about) 1 // Copyright 2021 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package lisafs 16 17 import ( 18 "fmt" 19 "math" 20 21 "golang.org/x/sys/unix" 22 "github.com/nicocha30/gvisor-ligolo/pkg/context" 23 "github.com/nicocha30/gvisor-ligolo/pkg/flipcall" 24 "github.com/nicocha30/gvisor-ligolo/pkg/log" 25 "github.com/nicocha30/gvisor-ligolo/pkg/sync" 26 "github.com/nicocha30/gvisor-ligolo/pkg/unet" 27 ) 28 29 const ( 30 // fdsToCloseBatchSize is the number of closed FDs batched before an Close 31 // RPC is made to close them all. fdsToCloseBatchSize is immutable. 32 fdsToCloseBatchSize = 100 33 ) 34 35 // Client helps manage a connection to the lisafs server and pass messages 36 // efficiently. There is a 1:1 mapping between a Connection and a Client. 37 type Client struct { 38 // sockComm is the main socket by which this connections is established. 39 // Communication over the socket is synchronized by sockMu. 40 sockMu sync.Mutex 41 sockComm *sockCommunicator 42 43 // channelsMu protects channels and availableChannels. 44 channelsMu sync.Mutex 45 // channels tracks all the channels. 46 channels []*channel 47 // availableChannels is a LIFO (stack) of channels available to be used. 48 availableChannels []*channel 49 // activeWg represents active channels. 50 activeWg sync.WaitGroup 51 52 // watchdogWg only holds the watchdog goroutine. 53 watchdogWg sync.WaitGroup 54 55 // supported caches information about which messages are supported. It is 56 // indexed by MID. An MID is supported if supported[MID] is true. 57 supported []bool 58 59 // maxMessageSize is the maximum payload length (in bytes) that can be sent. 60 // It is initialized on Mount and is immutable. 61 maxMessageSize uint32 62 63 // fdsToClose tracks the FDs to close. It caches the FDs no longer being used 64 // by the client and closes them in one shot. It is not preserved across 65 // checkpoint/restore as FDIDs are not preserved. 66 fdsMu sync.Mutex 67 fdsToClose []FDID 68 } 69 70 // NewClient creates a new client for communication with the server. It mounts 71 // the server and creates channels for fast IPC. NewClient takes ownership over 72 // the passed socket. On success, it returns the initialized client along with 73 // the root Inode. 74 func NewClient(sock *unet.Socket) (*Client, Inode, int, error) { 75 c := &Client{ 76 sockComm: newSockComm(sock), 77 maxMessageSize: 1 << 20, // 1 MB for now. 78 fdsToClose: make([]FDID, 0, fdsToCloseBatchSize), 79 } 80 81 // Start a goroutine to check socket health. This goroutine is also 82 // responsible for client cleanup. 83 c.watchdogWg.Add(1) 84 go c.watchdog() 85 86 // Mount the server first. Assume Mount is supported so that we can make the 87 // Mount RPC below. 88 c.supported = make([]bool, Mount+1) 89 c.supported[Mount] = true 90 var ( 91 mountReq MountReq 92 mountResp MountResp 93 mountHostFD = [1]int{-1} 94 ) 95 if err := c.SndRcvMessage(Mount, uint32(mountReq.SizeBytes()), mountReq.MarshalBytes, mountResp.CheckedUnmarshal, mountHostFD[:], mountReq.String, mountResp.String); err != nil { 96 c.Close() 97 return nil, Inode{}, -1, err 98 } 99 100 // Initialize client. 101 c.maxMessageSize = uint32(mountResp.MaxMessageSize) 102 var maxSuppMID MID 103 for _, suppMID := range mountResp.SupportedMs { 104 if suppMID > maxSuppMID { 105 maxSuppMID = suppMID 106 } 107 } 108 c.supported = make([]bool, maxSuppMID+1) 109 for _, suppMID := range mountResp.SupportedMs { 110 c.supported[suppMID] = true 111 } 112 return c, mountResp.Root, mountHostFD[0], nil 113 } 114 115 // StartChannels starts maxChannels() channel communicators. 116 func (c *Client) StartChannels() error { 117 maxChans := maxChannels() 118 c.channelsMu.Lock() 119 c.channels = make([]*channel, 0, maxChans) 120 c.availableChannels = make([]*channel, 0, maxChans) 121 c.channelsMu.Unlock() 122 123 // Create channels parallely so that channels can be used to create more 124 // channels and costly initialization like flipcall.Endpoint.Connect can 125 // proceed parallely. 126 var channelsWg sync.WaitGroup 127 for i := 0; i < maxChans; i++ { 128 channelsWg.Add(1) 129 go func() { 130 defer channelsWg.Done() 131 ch, err := c.createChannel() 132 if err != nil { 133 if err == unix.ENOMEM { 134 log.Debugf("channel creation failed because server hit max channels limit") 135 } else { 136 log.Warningf("channel creation failed: %v", err) 137 } 138 return 139 } 140 c.channelsMu.Lock() 141 c.channels = append(c.channels, ch) 142 c.availableChannels = append(c.availableChannels, ch) 143 c.channelsMu.Unlock() 144 }() 145 } 146 channelsWg.Wait() 147 148 // Check that atleast 1 channel is created. This is not required by lisafs 149 // protocol. It exists to flag server side issues in channel creation. 150 c.channelsMu.Lock() 151 numChannels := len(c.channels) 152 c.channelsMu.Unlock() 153 if maxChans > 0 && numChannels == 0 { 154 log.Warningf("all channel RPCs failed") 155 return unix.ENOMEM 156 } 157 return nil 158 } 159 160 func (c *Client) watchdog() { 161 defer c.watchdogWg.Done() 162 163 events := []unix.PollFd{ 164 { 165 Fd: int32(c.sockComm.FD()), 166 Events: unix.POLLHUP | unix.POLLRDHUP, 167 }, 168 } 169 170 // Wait for a shutdown event. 171 for { 172 n, err := unix.Ppoll(events, nil, nil) 173 if err == unix.EINTR || err == unix.EAGAIN { 174 continue 175 } 176 if err != nil { 177 log.Warningf("lisafs.Client.watch(): %v", err) 178 } else if n != 1 { 179 log.Warningf("lisafs.Client.watch(): got %d events, wanted 1", n) 180 } 181 break 182 } 183 184 // Shutdown all active channels and wait for them to complete. 185 c.shutdownActiveChans() 186 c.activeWg.Wait() 187 188 // Close all channels. 189 c.channelsMu.Lock() 190 for _, ch := range c.channels { 191 ch.destroy() 192 } 193 c.channelsMu.Unlock() 194 195 // Close main socket. 196 c.sockComm.destroy() 197 } 198 199 func (c *Client) shutdownActiveChans() { 200 c.channelsMu.Lock() 201 defer c.channelsMu.Unlock() 202 203 availableChans := make(map[*channel]bool) 204 for _, ch := range c.availableChannels { 205 availableChans[ch] = true 206 } 207 for _, ch := range c.channels { 208 // A channel that is not available is active. 209 if _, ok := availableChans[ch]; !ok { 210 log.Debugf("shutting down active channel@%p...", ch) 211 ch.shutdown() 212 } 213 } 214 215 // Prevent channels from becoming available and serving new requests. 216 c.availableChannels = nil 217 } 218 219 // Close shuts down the main socket and waits for the watchdog to clean up. 220 func (c *Client) Close() { 221 // This shutdown has no effect if the watchdog has already fired and closed 222 // the main socket. 223 c.sockComm.shutdown() 224 c.watchdogWg.Wait() 225 } 226 227 func (c *Client) createChannel() (*channel, error) { 228 var ( 229 chanReq ChannelReq 230 chanResp ChannelResp 231 ) 232 var fds [2]int 233 if err := c.SndRcvMessage(Channel, uint32(chanReq.SizeBytes()), chanReq.MarshalBytes, chanResp.CheckedUnmarshal, fds[:], chanReq.String, chanResp.String); err != nil { 234 return nil, err 235 } 236 if fds[0] < 0 || fds[1] < 0 { 237 closeFDs(fds[:]) 238 return nil, fmt.Errorf("insufficient FDs provided in Channel response: %v", fds) 239 } 240 241 // Lets create the channel. 242 defer closeFDs(fds[:1]) // The data FD is not needed after this. 243 desc := flipcall.PacketWindowDescriptor{ 244 FD: fds[0], 245 Offset: chanResp.dataOffset, 246 Length: int(chanResp.dataLength), 247 } 248 249 ch := &channel{} 250 if err := ch.data.Init(flipcall.ClientSide, desc); err != nil { 251 closeFDs(fds[1:]) 252 return nil, err 253 } 254 ch.fdChan.Init(fds[1]) // fdChan now owns this FD. 255 256 // Only a connected channel is usable. 257 if err := ch.data.Connect(); err != nil { 258 ch.destroy() 259 return nil, err 260 } 261 return ch, nil 262 } 263 264 // IsSupported returns true if this connection supports the passed message. 265 func (c *Client) IsSupported(m MID) bool { 266 return int(m) < len(c.supported) && c.supported[m] 267 } 268 269 // CloseFD either queues the passed FD to be closed or makes a batch 270 // RPC to close all the accumulated FDs-to-close. If flush is true, the RPC 271 // is made immediately. 272 func (c *Client) CloseFD(ctx context.Context, fd FDID, flush bool) { 273 c.fdsMu.Lock() 274 c.fdsToClose = append(c.fdsToClose, fd) 275 if !flush && len(c.fdsToClose) < fdsToCloseBatchSize { 276 // We can continue batching. 277 c.fdsMu.Unlock() 278 return 279 } 280 281 // Flush the cache. We should not hold fdsMu while making an RPC, so be sure 282 // to copy the fdsToClose to another buffer before unlocking fdsMu. 283 var toCloseArr [fdsToCloseBatchSize]FDID 284 toClose := toCloseArr[:len(c.fdsToClose)] 285 copy(toClose, c.fdsToClose) 286 287 // Clear fdsToClose so other FDIDs can be appended. 288 c.fdsToClose = c.fdsToClose[:0] 289 c.fdsMu.Unlock() 290 291 req := CloseReq{FDs: toClose} 292 var resp CloseResp 293 ctx.UninterruptibleSleepStart(false) 294 err := c.SndRcvMessage(Close, uint32(req.SizeBytes()), req.MarshalBytes, resp.CheckedUnmarshal, nil, req.String, resp.String) 295 ctx.UninterruptibleSleepFinish(false) 296 if err != nil { 297 log.Warningf("lisafs: batch closing FDs returned error: %v", err) 298 } 299 } 300 301 // SyncFDs makes a Fsync RPC to sync multiple FDs. 302 func (c *Client) SyncFDs(ctx context.Context, fds []FDID) error { 303 if len(fds) == 0 { 304 return nil 305 } 306 req := FsyncReq{FDs: fds} 307 var resp FsyncResp 308 ctx.UninterruptibleSleepStart(false) 309 err := c.SndRcvMessage(FSync, uint32(req.SizeBytes()), req.MarshalBytes, resp.CheckedUnmarshal, nil, req.String, resp.String) 310 ctx.UninterruptibleSleepFinish(false) 311 return err 312 } 313 314 // SndRcvMessage invokes reqMarshal to marshal the request onto the payload 315 // buffer, wakes up the server to process the request, waits for the response 316 // and invokes respUnmarshal with the response payload. respFDs is populated 317 // with the received FDs, extra fields are set to -1. 318 // 319 // See messages.go to understand why function arguments are used instead of 320 // combining these functions into an interface type. 321 // 322 // Precondition: function arguments must be non-nil. 323 func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal marshalFunc, respUnmarshal unmarshalFunc, respFDs []int, reqString debugStringer, respString debugStringer) error { 324 if !c.IsSupported(m) { 325 return unix.EOPNOTSUPP 326 } 327 if payloadLen > c.maxMessageSize { 328 log.Warningf("message %d has payload which is too large: %d bytes", m, payloadLen) 329 return unix.EIO 330 } 331 wantFDs := len(respFDs) 332 if wantFDs > math.MaxUint8 { 333 log.Warningf("want too many FDs: %d", wantFDs) 334 return unix.EINVAL 335 } 336 337 // Acquire a communicator. 338 comm := c.acquireCommunicator() 339 defer c.releaseCommunicator(comm) 340 341 debugf("send", comm, reqString) 342 343 // Marshal the request into comm's payload buffer and make the RPC. 344 reqMarshal(comm.PayloadBuf(payloadLen)) 345 respM, respPayloadLen, err := comm.SndRcvMessage(m, payloadLen, uint8(wantFDs)) 346 347 // Handle FD donation. 348 rcvFDs := comm.ReleaseFDs() 349 if numRcvFDs := len(rcvFDs); numRcvFDs+wantFDs > 0 { 350 // releasedFDs is memory owned by comm which can not be returned to caller. 351 // Copy it into the caller's buffer. 352 numFDCopied := copy(respFDs, rcvFDs) 353 if numFDCopied < numRcvFDs { 354 log.Warningf("%d unexpected FDs were donated by the server, wanted", numRcvFDs-numFDCopied, wantFDs) 355 closeFDs(rcvFDs[numFDCopied:]) 356 } 357 if numFDCopied < wantFDs { 358 for i := numFDCopied; i < wantFDs; i++ { 359 respFDs[i] = -1 360 } 361 } 362 } 363 364 // Error cases. 365 if err != nil { 366 closeFDs(respFDs) 367 return err 368 } 369 if respPayloadLen > c.maxMessageSize { 370 log.Warningf("server response for message %d is too large: %d bytes", respM, respPayloadLen) 371 closeFDs(respFDs) 372 return unix.EIO 373 } 374 if respM == Error { 375 closeFDs(respFDs) 376 var resp ErrorResp 377 resp.UnmarshalUnsafe(comm.PayloadBuf(respPayloadLen)) 378 debugf("recv", comm, resp.String) 379 return unix.Errno(resp.errno) 380 } 381 if respM != m { 382 closeFDs(respFDs) 383 log.Warningf("sent %d message but got %d in response", m, respM) 384 return unix.EINVAL 385 } 386 387 // Success. The payload must be unmarshalled *before* comm is released. 388 if _, ok := respUnmarshal(comm.PayloadBuf(respPayloadLen)); !ok { 389 log.Warningf("server response unmarshalling for %d message failed", respM) 390 return unix.EIO 391 } 392 debugf("recv", comm, respString) 393 return nil 394 } 395 396 func debugf(action string, comm Communicator, debugMsg debugStringer) { 397 // Replicate the log.IsLogging(log.Debug) check to avoid having to call 398 // debugMsg() on the hot path. 399 if log.IsLogging(log.Debug) { 400 log.Debugf("%s [%s] %s", action, comm, debugMsg()) 401 } 402 } 403 404 // Postcondition: releaseCommunicator() must be called on the returned value. 405 func (c *Client) acquireCommunicator() Communicator { 406 // Prefer using channel over socket because: 407 // - Channel uses a shared memory region for passing messages. IO from shared 408 // memory is faster and does not involve making a syscall. 409 // - No intermediate buffer allocation needed. With a channel, the message 410 // can be directly pasted into the shared memory region. 411 if ch := c.getChannel(); ch != nil { 412 return ch 413 } 414 415 c.sockMu.Lock() 416 return c.sockComm 417 } 418 419 // Precondition: comm must have been acquired via acquireCommunicator(). 420 func (c *Client) releaseCommunicator(comm Communicator) { 421 switch t := comm.(type) { 422 case *sockCommunicator: 423 c.sockMu.Unlock() // +checklocksforce: locked in acquireCommunicator(). 424 case *channel: 425 c.releaseChannel(t) 426 default: 427 panic(fmt.Sprintf("unknown communicator type %T", t)) 428 } 429 } 430 431 // getChannel pops a channel from the available channels stack. The caller must 432 // release the channel after use. 433 func (c *Client) getChannel() *channel { 434 c.channelsMu.Lock() 435 defer c.channelsMu.Unlock() 436 if len(c.availableChannels) == 0 { 437 return nil 438 } 439 440 idx := len(c.availableChannels) - 1 441 ch := c.availableChannels[idx] 442 c.availableChannels = c.availableChannels[:idx] 443 c.activeWg.Add(1) 444 return ch 445 } 446 447 // releaseChannel pushes the passed channel onto the available channel stack if 448 // reinsert is true. 449 func (c *Client) releaseChannel(ch *channel) { 450 c.channelsMu.Lock() 451 defer c.channelsMu.Unlock() 452 453 // If availableChannels is nil, then watchdog has fired and the client is 454 // shutting down. So don't make this channel available again. 455 if !ch.dead && c.availableChannels != nil { 456 c.availableChannels = append(c.availableChannels, ch) 457 } 458 c.activeWg.Done() 459 }