github.com/sagernet/gvisor@v0.0.0-20240428053021-e691de28565f/pkg/urpc/urpc.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package urpc provides a minimal RPC package based on unet. 16 // 17 // RPC requests are _not_ concurrent and methods must be explicitly 18 // registered. However, files may be send as part of the payload. 19 package urpc 20 21 import ( 22 "bytes" 23 "encoding/json" 24 "errors" 25 "fmt" 26 "io" 27 "os" 28 "reflect" 29 "runtime" 30 "time" 31 32 "github.com/sagernet/gvisor/pkg/fd" 33 "github.com/sagernet/gvisor/pkg/log" 34 "github.com/sagernet/gvisor/pkg/sync" 35 "github.com/sagernet/gvisor/pkg/unet" 36 ) 37 38 // maxFiles determines the maximum file payload. This limit is arbitrary. Linux 39 // allows SCM_MAX_FD = 253 FDs to be donated in one sendmsg(2) call. 40 const maxFiles = 128 41 42 // ErrTooManyFiles is returned when too many file descriptors are mapped. 43 var ErrTooManyFiles = errors.New("too many files") 44 45 // ErrUnknownMethod is returned when a method is not known. 46 var ErrUnknownMethod = errors.New("unknown method") 47 48 // errStopped is an internal error indicating the server has been stopped. 49 var errStopped = errors.New("stopped") 50 51 // RemoteError is an error returned by the remote invocation. 52 // 53 // This indicates that the RPC transport was correct, but that the called 54 // function itself returned an error. 55 type RemoteError struct { 56 // Message is the result of calling Error() on the remote error. 57 Message string 58 } 59 60 // Error returns the remote error string. 61 func (r RemoteError) Error() string { 62 return r.Message 63 } 64 65 // FilePayload may be _embedded_ in another type in order to send or receive a 66 // file as a result of an RPC. These are not actually serialized, rather they 67 // are sent via an accompanying SCM_RIGHTS message (plumbed through the unet 68 // package). 69 // 70 // When embedding a FilePayload in an argument struct, the argument type _must_ 71 // be a pointer to the struct rather than the struct type itself. This is 72 // because the urpc package defines pointer methods on FilePayload. 73 type FilePayload struct { 74 Files []*os.File `json:"-"` 75 } 76 77 // ReleaseFD releases the FD at the specified index. 78 func (f *FilePayload) ReleaseFD(index int) (*fd.FD, error) { 79 return fd.NewFromFile(f.Files[index]) 80 } 81 82 // filePayload returns the file. It may be nil. 83 func (f *FilePayload) filePayload() []*os.File { 84 return f.Files 85 } 86 87 // setFilePayload sets the payload. 88 func (f *FilePayload) setFilePayload(fs []*os.File) { 89 f.Files = fs 90 } 91 92 // closeAll closes a slice of files. 93 func closeAll(files []*os.File) { 94 for _, f := range files { 95 f.Close() 96 } 97 } 98 99 // filePayloader is implemented only by FilePayload and will be implicitly 100 // implemented by types that have the FilePayload embedded. Note that there is 101 // no way to implement these methods other than by embedding FilePayload, due 102 // to the way unexported method names are mangled. 103 type filePayloader interface { 104 filePayload() []*os.File 105 setFilePayload([]*os.File) 106 } 107 108 // clientCall is the client=>server method call on the client side. 109 type clientCall struct { 110 Method string `json:"method"` 111 Arg any `json:"arg"` 112 } 113 114 // serverCall is the client=>server method call on the server side. 115 type serverCall struct { 116 Method string `json:"method"` 117 Arg json.RawMessage `json:"arg"` 118 } 119 120 // callResult is the server=>client method call result. 121 type callResult struct { 122 Success bool `json:"success"` 123 Err string `json:"err"` 124 Result any `json:"result"` 125 } 126 127 // registeredMethod is method registered with the server. 128 type registeredMethod struct { 129 // fn is the underlying function. 130 fn reflect.Value 131 132 // rcvr is the receiver value. 133 rcvr reflect.Value 134 135 // argType is a typed argument. 136 argType reflect.Type 137 138 // resultType is also a type result. 139 resultType reflect.Type 140 } 141 142 // clientState is client metadata. 143 // 144 // The following are valid states: 145 // 146 // idle - not processing any requests, no close request. 147 // processing - actively processing, no close request. 148 // closeRequested - actively processing, pending close. 149 // closed - client connection has been closed. 150 // 151 // The following transitions are possible: 152 // 153 // idle -> processing, closed 154 // processing -> idle, closeRequested 155 // closeRequested -> closed 156 type clientState int 157 158 // See clientState. 159 const ( 160 idle clientState = iota 161 processing 162 closeRequested 163 closed 164 ) 165 166 // Server is an RPC server. 167 type Server struct { 168 // mu protects all fields, except wg. 169 mu sync.Mutex 170 171 // methods is the set of server methods. 172 methods map[string]registeredMethod 173 174 // stoppers are all registered stoppers. 175 stoppers []Stopper 176 177 // clients is a map of clients. 178 clients map[*unet.Socket]clientState 179 180 // wg is a wait group for all outstanding clients. 181 wg sync.WaitGroup 182 183 // afterRPCCallback is called after each RPC is successfully completed. 184 afterRPCCallback func() 185 } 186 187 // NewServer returns a new server. 188 func NewServer() *Server { 189 return NewServerWithCallback(nil) 190 } 191 192 // NewServerWithCallback returns a new server, who upon completion of each RPC 193 // calls the given function. 194 func NewServerWithCallback(afterRPCCallback func()) *Server { 195 return &Server{ 196 methods: make(map[string]registeredMethod), 197 clients: make(map[*unet.Socket]clientState), 198 afterRPCCallback: afterRPCCallback, 199 } 200 } 201 202 // Stopper is an optional interface, that when implemented, allows an object 203 // to have a callback executed when the server is shutting down. 204 type Stopper interface { 205 Stop() 206 } 207 208 // Register registers the given object as an RPC receiver. 209 // 210 // This functions is the same way as the built-in RPC package, but it does not 211 // tolerate any object with non-conforming methods. Any non-confirming methods 212 // will lead to an immediate panic, instead of being skipped or an error. 213 // Panics will also be generated by anonymous objects and duplicate entries. 214 func (s *Server) Register(obj any) { 215 s.mu.Lock() 216 defer s.mu.Unlock() 217 218 typ := reflect.TypeOf(obj) 219 stopper, hasStop := obj.(Stopper) 220 221 // If we got a pointer, deref it to the underlying object. We need this to 222 // obtain the name of the underlying type. 223 typDeref := typ 224 if typ.Kind() == reflect.Ptr { 225 typDeref = typ.Elem() 226 } 227 228 for m := 0; m < typ.NumMethod(); m++ { 229 method := typ.Method(m) 230 231 if typDeref.Name() == "" { 232 // Can't be anonymous. 233 panic("type not named.") 234 } 235 if hasStop && method.Name == "Stop" { 236 s.stoppers = append(s.stoppers, stopper) 237 continue // Legal stop method. 238 } 239 240 prettyName := typDeref.Name() + "." + method.Name 241 if _, ok := s.methods[prettyName]; ok { 242 // Duplicate entry. 243 panic(fmt.Sprintf("method %s is duplicated.", prettyName)) 244 } 245 246 if method.PkgPath != "" { 247 // Must be exported. 248 panic(fmt.Sprintf("method %s is not exported.", prettyName)) 249 } 250 mtype := method.Type 251 if mtype.NumIn() != 3 { 252 // Need exactly two arguments (+ receiver). 253 panic(fmt.Sprintf("method %s has wrong number of arguments.", prettyName)) 254 } 255 argType := mtype.In(1) 256 if argType.Kind() != reflect.Ptr { 257 // Need arg pointer. 258 panic(fmt.Sprintf("method %s has non-pointer first argument.", prettyName)) 259 } 260 resultType := mtype.In(2) 261 if resultType.Kind() != reflect.Ptr { 262 // Need result pointer. 263 panic(fmt.Sprintf("method %s has non-pointer second argument.", prettyName)) 264 } 265 if mtype.NumOut() != 1 { 266 // Need single return. 267 panic(fmt.Sprintf("method %s has wrong number of returns.", prettyName)) 268 } 269 if returnType := mtype.Out(0); returnType != reflect.TypeOf((*error)(nil)).Elem() { 270 // Need error return. 271 panic(fmt.Sprintf("method %s has non-error return value.", prettyName)) 272 } 273 274 // Register the method. 275 s.methods[prettyName] = registeredMethod{ 276 fn: method.Func, 277 rcvr: reflect.ValueOf(obj), 278 argType: argType, 279 resultType: resultType, 280 } 281 } 282 } 283 284 // lookup looks up the given method. 285 func (s *Server) lookup(method string) (registeredMethod, bool) { 286 s.mu.Lock() 287 defer s.mu.Unlock() 288 rm, ok := s.methods[method] 289 return rm, ok 290 } 291 292 // handleOne handles a single call. 293 func (s *Server) handleOne(client *unet.Socket) error { 294 // Unmarshal the call. 295 var c serverCall 296 newFs, err := unmarshal(client, &c) 297 if err != nil { 298 // Client is dead. 299 return err 300 } 301 if s.afterRPCCallback != nil { 302 defer s.afterRPCCallback() 303 } 304 305 // Explicitly close all these files after the call. 306 // 307 // This is also explicitly a reference to the files after the call, 308 // which means they are kept open for the duration of the call. 309 defer closeAll(newFs) 310 311 // Start the request. 312 if !s.clientBeginRequest(client) { 313 // Client is dead; don't process this call. 314 return errStopped 315 } 316 defer s.clientEndRequest(client) 317 318 // Lookup the method. 319 rm, ok := s.lookup(c.Method) 320 if !ok { 321 // Try to serialize the error. 322 return marshal(client, &callResult{Err: ErrUnknownMethod.Error()}, nil) 323 } 324 325 // Unmarshal the arguments now that we know the type. 326 na := reflect.New(rm.argType.Elem()) 327 if err := json.Unmarshal(c.Arg, na.Interface()); err != nil { 328 return marshal(client, &callResult{Err: err.Error()}, nil) 329 } 330 331 // Set the file payload as an argument. 332 if fp, ok := na.Interface().(filePayloader); ok { 333 fp.setFilePayload(newFs) 334 } 335 336 // Call the method. 337 re := reflect.New(rm.resultType.Elem()) 338 rValues := rm.fn.Call([]reflect.Value{rm.rcvr, na, re}) 339 if errVal := rValues[0].Interface(); errVal != nil { 340 return marshal(client, &callResult{Err: errVal.(error).Error()}, nil) 341 } 342 343 // Set the resulting payload. 344 var fs []*os.File 345 if fp, ok := re.Interface().(filePayloader); ok { 346 fs = fp.filePayload() 347 if len(fs) > maxFiles { 348 // Ugh. Send an error to the client, despite success. 349 return marshal(client, &callResult{Err: ErrTooManyFiles.Error()}, nil) 350 } 351 } 352 353 // Marshal the result. 354 return marshal(client, &callResult{Success: true, Result: re.Interface()}, fs) 355 } 356 357 // clientBeginRequest begins a request. 358 // 359 // If true is returned, the request may be processed. If false is returned, 360 // then the server has been stopped and the request should be skipped. 361 func (s *Server) clientBeginRequest(client *unet.Socket) bool { 362 s.mu.Lock() 363 defer s.mu.Unlock() 364 switch state := s.clients[client]; state { 365 case idle: 366 // Mark as processing. 367 s.clients[client] = processing 368 return true 369 case closed: 370 // Whoops, how did this happen? Must have closed immediately 371 // following the deserialization. Don't let the RPC actually go 372 // through, since we won't be able to serialize a proper 373 // response. 374 return false 375 default: 376 // Should not happen. 377 panic(fmt.Sprintf("expected idle or closed, got %d", state)) 378 } 379 } 380 381 // clientEndRequest ends a request. 382 func (s *Server) clientEndRequest(client *unet.Socket) { 383 s.mu.Lock() 384 defer s.mu.Unlock() 385 switch state := s.clients[client]; state { 386 case processing: 387 // Return to idle. 388 s.clients[client] = idle 389 case closeRequested: 390 // Close the connection. 391 client.Close() 392 s.clients[client] = closed 393 default: 394 // Should not happen. 395 panic(fmt.Sprintf("expected processing or requestClose, got %d", state)) 396 } 397 } 398 399 // clientRegister registers a connection. 400 // 401 // See Stop for more context. 402 func (s *Server) clientRegister(client *unet.Socket) { 403 s.mu.Lock() 404 defer s.mu.Unlock() 405 s.clients[client] = idle 406 s.wg.Add(1) 407 } 408 409 // clientUnregister unregisters and closes a connection if necessary. 410 // 411 // See Stop for more context. 412 func (s *Server) clientUnregister(client *unet.Socket) { 413 s.mu.Lock() 414 defer s.mu.Unlock() 415 switch state := s.clients[client]; state { 416 case idle: 417 // Close the connection. 418 client.Close() 419 case closed: 420 // Already done. 421 default: 422 // Should not happen. 423 panic(fmt.Sprintf("expected idle or closed, got %d", state)) 424 } 425 delete(s.clients, client) 426 s.wg.Done() 427 } 428 429 // handleRegistered handles calls from a registered client. 430 func (s *Server) handleRegistered(client *unet.Socket) error { 431 for { 432 // Handle one call. 433 if err := s.handleOne(client); err != nil { 434 // Client is dead. 435 return err 436 } 437 } 438 } 439 440 // Handle synchronously handles a single client over a connection. 441 func (s *Server) Handle(client *unet.Socket) error { 442 s.clientRegister(client) 443 defer s.clientUnregister(client) 444 return s.handleRegistered(client) 445 } 446 447 // StartHandling creates a goroutine that handles a single client over a 448 // connection. 449 func (s *Server) StartHandling(client *unet.Socket) { 450 s.clientRegister(client) 451 go func() { // S/R-SAFE: out of scope 452 defer s.clientUnregister(client) 453 s.handleRegistered(client) 454 }() 455 } 456 457 // Stop safely terminates outstanding clients. 458 // 459 // No new requests should be initiated after calling Stop. Existing clients 460 // will be closed after completing any pending RPCs. This method will block 461 // until all clients have disconnected. 462 // 463 // timeout is the time for clients to complete ongoing RPCs. 464 func (s *Server) Stop(timeout time.Duration) { 465 // Call any Stop callbacks. 466 for _, stopper := range s.stoppers { 467 stopper.Stop() 468 } 469 470 done := make(chan bool, 1) 471 go func() { 472 if timeout != 0 { 473 timer := time.NewTicker(timeout) 474 defer timer.Stop() 475 select { 476 case <-done: 477 return 478 case <-timer.C: 479 } 480 } 481 482 // Close all known clients. 483 s.mu.Lock() 484 defer s.mu.Unlock() 485 for client, state := range s.clients { 486 switch state { 487 case idle: 488 // Close connection now. 489 client.Close() 490 s.clients[client] = closed 491 case processing: 492 // Request close when done. 493 s.clients[client] = closeRequested 494 } 495 } 496 }() 497 498 // Wait for all outstanding requests. 499 s.wg.Wait() 500 done <- true 501 } 502 503 // Client is a urpc client. 504 type Client struct { 505 // mu protects all members. 506 // 507 // It also enforces single-call semantics. 508 mu sync.Mutex 509 510 // Socket is the underlying socket for this client. 511 // 512 // This _must_ be provided and must be closed manually by calling 513 // Close. 514 Socket *unet.Socket 515 } 516 517 // NewClient returns a new client. 518 func NewClient(socket *unet.Socket) *Client { 519 return &Client{ 520 Socket: socket, 521 } 522 } 523 524 // marshal sends the given FD and json struct. 525 func marshal(s *unet.Socket, v any, fs []*os.File) error { 526 // Marshal to a buffer. 527 data, err := json.Marshal(v) 528 if err != nil { 529 log.Warningf("urpc: error marshalling %s: %s", fmt.Sprintf("%v", v), err.Error()) 530 return err 531 } 532 533 // Write to the socket. 534 w := s.Writer(true) 535 if fs != nil { 536 var fds []int 537 for _, f := range fs { 538 fds = append(fds, int(f.Fd())) 539 } 540 w.PackFDs(fds...) 541 } 542 543 // Send. 544 for n := 0; n < len(data); { 545 cur, err := w.WriteVec([][]byte{data[n:]}) 546 if n == 0 && cur < len(data) { 547 // Don't send FDs anymore. This call is only made on 548 // the first successful call to WriteVec, assuming cur 549 // is not sufficient to fill the entire buffer. 550 w.PackFDs() 551 } 552 n += cur 553 if err != nil { 554 log.Warningf("urpc: error writing %v: %s", data[n:], err.Error()) 555 return err 556 } 557 } 558 559 // We're done sending the fds to the client. Explicitly prevent fs from 560 // being GCed until here. Urpc rpcs often unlink the file to send, relying 561 // on the kernel to automatically delete it once the last reference is 562 // dropped. Until we successfully call sendmsg(2), fs may contain the last 563 // references to these files. Without this explicit reference to fs here, 564 // the go runtime is free to assume we're done with fs after the fd 565 // collection loop above, since it just sees us copying ints. 566 runtime.KeepAlive(fs) 567 568 log.Debugf("urpc: successfully marshalled %d bytes.", len(data)) 569 return nil 570 } 571 572 // unmarhsal receives an FD (optional) and unmarshals the given struct. 573 func unmarshal(s *unet.Socket, v any) ([]*os.File, error) { 574 // Receive a single byte. 575 r := s.Reader(true) 576 r.EnableFDs(maxFiles) 577 firstByte := make([]byte, 1) 578 579 // Extract any FDs that may be there. 580 if _, err := r.ReadVec([][]byte{firstByte}); err != nil { 581 return nil, err 582 } 583 fds, err := r.ExtractFDs() 584 if err != nil { 585 log.Warningf("urpc: error extracting fds: %s", err.Error()) 586 return nil, err 587 } 588 var fs []*os.File 589 for _, fd := range fds { 590 fs = append(fs, os.NewFile(uintptr(fd), "urpc")) 591 } 592 593 // Read the rest. 594 d := json.NewDecoder(io.MultiReader(bytes.NewBuffer(firstByte), s)) 595 // urpc internally decodes / re-encodes the data with any as the 596 // intermediate type. We have to unmarshal integers to json.Number type 597 // instead of the default float type for those intermediate values, such 598 // that when they get re-encoded, their values are not printed out in 599 // floating-point formats such as 1e9, which could not be decoded to 600 // explicitly typed integers later. 601 d.UseNumber() 602 if err := d.Decode(v); err != nil { 603 log.Warningf("urpc: error decoding: %s", err.Error()) 604 for _, f := range fs { 605 f.Close() 606 } 607 return nil, err 608 } 609 610 // All set. 611 log.Debugf("urpc: unmarshal success.") 612 return fs, nil 613 } 614 615 // Call calls a function. 616 func (c *Client) Call(method string, arg any, result any) error { 617 c.mu.Lock() 618 defer c.mu.Unlock() 619 620 // If arg is a FilePayload, not a *FilePayload, files won't actually be 621 // sent, so error out. 622 if _, ok := arg.(FilePayload); ok { 623 return fmt.Errorf("argument is a FilePayload, but should be a *FilePayload") 624 } 625 626 // Are there files to send? 627 var fs []*os.File 628 if fp, ok := arg.(filePayloader); ok { 629 fs = fp.filePayload() 630 if len(fs) > maxFiles { 631 return ErrTooManyFiles 632 } 633 } 634 635 // Marshal the data. 636 if err := marshal(c.Socket, &clientCall{Method: method, Arg: arg}, fs); err != nil { 637 return err 638 } 639 640 // Wait for the response. 641 callR := callResult{Result: result} 642 newFs, err := unmarshal(c.Socket, &callR) 643 if err != nil { 644 return fmt.Errorf("urpc method %q failed: %v", method, err) 645 } 646 647 // Set the file payload. 648 if fp, ok := result.(filePayloader); ok { 649 fp.setFilePayload(newFs) 650 } else { 651 closeAll(newFs) 652 } 653 654 // Did an error occur? 655 if !callR.Success { 656 return RemoteError{Message: callR.Err} 657 } 658 659 // All set. 660 return nil 661 } 662 663 // Close closes the underlying socket. 664 // 665 // Further calls to the client may result in undefined behavior. 666 func (c *Client) Close() error { 667 c.mu.Lock() 668 defer c.mu.Unlock() 669 return c.Socket.Close() 670 }