github.com/ttpreport/gvisor-ligolo@v0.0.0-20240123134145-a858404967ba/pkg/urpc/urpc.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package urpc provides a minimal RPC package based on unet.
//
// RPC requests are _not_ concurrent and methods must be explicitly
// registered. However, files may be sent as part of the payload.
    19  package urpc
    20  
    21  import (
    22  	"bytes"
    23  	"encoding/json"
    24  	"errors"
    25  	"fmt"
    26  	"io"
    27  	"os"
    28  	"reflect"
    29  	"runtime"
    30  	"time"
    31  
    32  	"github.com/ttpreport/gvisor-ligolo/pkg/fd"
    33  	"github.com/ttpreport/gvisor-ligolo/pkg/log"
    34  	"github.com/ttpreport/gvisor-ligolo/pkg/sync"
    35  	"github.com/ttpreport/gvisor-ligolo/pkg/unet"
    36  )
    37  
// maxFiles determines the maximum file payload, i.e. the largest number of
// files that may accompany a single message in either direction. This limit
// is arbitrary. Linux allows SCM_MAX_FD = 253 FDs to be donated in one
// sendmsg(2) call.
const maxFiles = 128
    41  
// ErrTooManyFiles is returned when too many file descriptors are mapped, that
// is, when a file payload holds more than maxFiles files.
var ErrTooManyFiles = errors.New("too many files")

// ErrUnknownMethod is returned when a method is not known. The server reports
// its message to the client when the requested method was never registered.
var ErrUnknownMethod = errors.New("unknown method")

// errStopped is an internal error indicating the server has been stopped. It
// is returned by handleOne when a request arrives on a client that has
// already been closed.
var errStopped = errors.New("stopped")
    50  
    51  // RemoteError is an error returned by the remote invocation.
    52  //
    53  // This indicates that the RPC transport was correct, but that the called
    54  // function itself returned an error.
    55  type RemoteError struct {
    56  	// Message is the result of calling Error() on the remote error.
    57  	Message string
    58  }
    59  
    60  // Error returns the remote error string.
    61  func (r RemoteError) Error() string {
    62  	return r.Message
    63  }
    64  
// FilePayload may be _embedded_ in another type in order to send or receive a
// file as a result of an RPC. These are not actually serialized, rather they
// are sent via an accompanying SCM_RIGHTS message (plumbed through the unet
// package).
//
// When embedding a FilePayload in an argument struct, the argument type _must_
// be a pointer to the struct rather than the struct type itself. This is
// because the urpc package defines pointer methods on FilePayload.
type FilePayload struct {
	// Files are the files travelling with the message. The `json:"-"` tag
	// keeps them out of the JSON encoding.
	Files []*os.File `json:"-"`
}
    76  
    77  // ReleaseFD releases the FD at the specified index.
    78  func (f *FilePayload) ReleaseFD(index int) (*fd.FD, error) {
    79  	return fd.NewFromFile(f.Files[index])
    80  }
    81  
// filePayload returns the files currently held by the payload. The returned
// slice is f.Files itself (not a copy) and may be nil.
func (f *FilePayload) filePayload() []*os.File {
	return f.Files
}
    86  
// setFilePayload replaces the payload's files with fs. The slice is stored
// as-is, without copying.
func (f *FilePayload) setFilePayload(fs []*os.File) {
	f.Files = fs
}
    91  
    92  // closeAll closes a slice of files.
    93  func closeAll(files []*os.File) {
    94  	for _, f := range files {
    95  		f.Close()
    96  	}
    97  }
    98  
// filePayloader is implemented only by FilePayload and will be implicitly
// implemented by types that have the FilePayload embedded. Note that there is
// no way to implement these methods other than by embedding FilePayload, due
// to the way unexported method names are mangled.
//
// The package uses a type assertion to this interface to decide whether an
// argument or result carries files.
type filePayloader interface {
	// filePayload returns the files attached to the value.
	filePayload() []*os.File
	// setFilePayload attaches the given files to the value.
	setFilePayload([]*os.File)
}
   107  
// clientCall is the client=>server method call on the client side.
//
// It is the JSON request written by Client.Call.
type clientCall struct {
	// Method is the name of the method to invoke.
	Method string `json:"method"`
	// Arg is the caller-supplied argument, encoded as-is.
	Arg any `json:"arg"`
}
   113  
// serverCall is the client=>server method call on the server side.
//
// It mirrors clientCall, but keeps the argument as raw JSON so that decoding
// can be deferred until the method, and therefore the argument type, has been
// looked up.
type serverCall struct {
	// Method is the name of the requested method.
	Method string `json:"method"`
	// Arg is the still-encoded argument.
	Arg json.RawMessage `json:"arg"`
}
   119  
// callResult is the server=>client method call result.
type callResult struct {
	// Success is true only when the method ran and returned a nil error.
	Success bool `json:"success"`
	// Err is the error text reported to the client when Success is false.
	Err string `json:"err"`
	// Result is the method's result value. On the client side it is
	// pre-populated with the caller's result pointer so decoding fills it
	// in directly.
	Result any `json:"result"`
}
   126  
// registeredMethod is a method registered with the server.
type registeredMethod struct {
	// fn is the underlying function. It is invoked with the receiver as
	// its first argument.
	fn reflect.Value

	// rcvr is the receiver value.
	rcvr reflect.Value

	// argType is the type of the method's argument parameter. It is always
	// a pointer type.
	argType reflect.Type

	// resultType is the type of the method's result parameter. It is
	// always a pointer type.
	resultType reflect.Type
}
   141  
// clientState is client metadata.
//
// The following are valid states:
//
// idle - not processing any requests, no close request.
// processing - actively processing, no close request.
// closeRequested - actively processing, pending close.
// closed - client connection has been closed.
//
// The following transitions are possible:
//
// idle -> processing, closed
// processing -> idle, closeRequested
// closeRequested -> closed
type clientState int

// See clientState.
const (
	// idle is the zero value, so a socket missing from the clients map
	// reads as idle.
	idle clientState = iota
	// processing means a request is currently being handled.
	processing
	// closeRequested means the socket must be closed once the in-flight
	// request ends.
	closeRequested
	// closed means the socket has already been closed.
	closed
)
   165  
// Server is an RPC server.
//
// Use NewServer or NewServerWithCallback to create one; the maps below must
// be initialized before use.
type Server struct {
	// mu protects all fields, except wg.
	mu sync.Mutex

	// methods is the set of server methods, keyed by "TypeName.MethodName".
	methods map[string]registeredMethod

	// stoppers are all registered stoppers.
	stoppers []Stopper

	// clients is a map of clients, tracking the state of each registered
	// connection.
	clients map[*unet.Socket]clientState

	// wg is a wait group for all outstanding clients.
	wg sync.WaitGroup

	// afterRPCCallback, if non-nil, is called when the handling of an RPC
	// finishes. It runs for every request that was read from the client,
	// whether or not the call itself succeeded.
	afterRPCCallback func()
}
   186  
// NewServer returns a new server with no after-RPC callback.
func NewServer() *Server {
	return NewServerWithCallback(nil)
}
   191  
   192  // NewServerWithCallback returns a new server, who upon completion of each RPC
   193  // calls the given function.
   194  func NewServerWithCallback(afterRPCCallback func()) *Server {
   195  	return &Server{
   196  		methods:          make(map[string]registeredMethod),
   197  		clients:          make(map[*unet.Socket]clientState),
   198  		afterRPCCallback: afterRPCCallback,
   199  	}
   200  }
   201  
// Stopper is an optional interface that, when implemented by a registered
// object, allows that object to have a callback executed when the server is
// shutting down.
type Stopper interface {
	// Stop is called by Server.Stop before clients are closed.
	Stop()
}
   207  
   208  // Register registers the given object as an RPC receiver.
   209  //
   210  // This functions is the same way as the built-in RPC package, but it does not
   211  // tolerate any object with non-conforming methods. Any non-confirming methods
   212  // will lead to an immediate panic, instead of being skipped or an error.
   213  // Panics will also be generated by anonymous objects and duplicate entries.
   214  func (s *Server) Register(obj any) {
   215  	s.mu.Lock()
   216  	defer s.mu.Unlock()
   217  
   218  	typ := reflect.TypeOf(obj)
   219  	stopper, hasStop := obj.(Stopper)
   220  
   221  	// If we got a pointer, deref it to the underlying object. We need this to
   222  	// obtain the name of the underlying type.
   223  	typDeref := typ
   224  	if typ.Kind() == reflect.Ptr {
   225  		typDeref = typ.Elem()
   226  	}
   227  
   228  	for m := 0; m < typ.NumMethod(); m++ {
   229  		method := typ.Method(m)
   230  
   231  		if typDeref.Name() == "" {
   232  			// Can't be anonymous.
   233  			panic("type not named.")
   234  		}
   235  		if hasStop && method.Name == "Stop" {
   236  			s.stoppers = append(s.stoppers, stopper)
   237  			continue // Legal stop method.
   238  		}
   239  
   240  		prettyName := typDeref.Name() + "." + method.Name
   241  		if _, ok := s.methods[prettyName]; ok {
   242  			// Duplicate entry.
   243  			panic(fmt.Sprintf("method %s is duplicated.", prettyName))
   244  		}
   245  
   246  		if method.PkgPath != "" {
   247  			// Must be exported.
   248  			panic(fmt.Sprintf("method %s is not exported.", prettyName))
   249  		}
   250  		mtype := method.Type
   251  		if mtype.NumIn() != 3 {
   252  			// Need exactly two arguments (+ receiver).
   253  			panic(fmt.Sprintf("method %s has wrong number of arguments.", prettyName))
   254  		}
   255  		argType := mtype.In(1)
   256  		if argType.Kind() != reflect.Ptr {
   257  			// Need arg pointer.
   258  			panic(fmt.Sprintf("method %s has non-pointer first argument.", prettyName))
   259  		}
   260  		resultType := mtype.In(2)
   261  		if resultType.Kind() != reflect.Ptr {
   262  			// Need result pointer.
   263  			panic(fmt.Sprintf("method %s has non-pointer second argument.", prettyName))
   264  		}
   265  		if mtype.NumOut() != 1 {
   266  			// Need single return.
   267  			panic(fmt.Sprintf("method %s has wrong number of returns.", prettyName))
   268  		}
   269  		if returnType := mtype.Out(0); returnType != reflect.TypeOf((*error)(nil)).Elem() {
   270  			// Need error return.
   271  			panic(fmt.Sprintf("method %s has non-error return value.", prettyName))
   272  		}
   273  
   274  		// Register the method.
   275  		s.methods[prettyName] = registeredMethod{
   276  			fn:         method.Func,
   277  			rcvr:       reflect.ValueOf(obj),
   278  			argType:    argType,
   279  			resultType: resultType,
   280  		}
   281  	}
   282  }
   283  
   284  // lookup looks up the given method.
   285  func (s *Server) lookup(method string) (registeredMethod, bool) {
   286  	s.mu.Lock()
   287  	defer s.mu.Unlock()
   288  	rm, ok := s.methods[method]
   289  	return rm, ok
   290  }
   291  
// handleOne handles a single call: it reads one request from client, invokes
// the requested method and writes one reply.
//
// Failures of the call itself (unknown method, undecodable argument, an error
// returned by the method, too many result files) are reported to the client
// in the reply; in those cases the return value is whatever marshal returns
// for that reply. A non-nil error is returned when the request cannot be
// read, when the client has already been closed (errStopped), or when the
// reply cannot be written.
func (s *Server) handleOne(client *unet.Socket) error {
	// Unmarshal the call.
	var c serverCall
	newFs, err := unmarshal(client, &c)
	if err != nil {
		// Client is dead.
		return err
	}
	// Deferred first, so it runs last: after the request has been ended
	// and the received files have been closed.
	if s.afterRPCCallback != nil {
		defer s.afterRPCCallback()
	}

	// Explicitly close all these files after the call.
	//
	// This is also explicitly a reference to the files after the call,
	// which means they are kept open for the duration of the call.
	defer closeAll(newFs)

	// Start the request.
	if !s.clientBeginRequest(client) {
		// Client is dead; don't process this call.
		return errStopped
	}
	defer s.clientEndRequest(client)

	// Lookup the method.
	rm, ok := s.lookup(c.Method)
	if !ok {
		// Try to serialize the error.
		return marshal(client, &callResult{Err: ErrUnknownMethod.Error()}, nil)
	}

	// Unmarshal the arguments now that we know the type.
	// argType is a pointer type, so allocate its element and decode into
	// the resulting pointer.
	na := reflect.New(rm.argType.Elem())
	if err := json.Unmarshal(c.Arg, na.Interface()); err != nil {
		return marshal(client, &callResult{Err: err.Error()}, nil)
	}

	// Set the file payload as an argument.
	if fp, ok := na.Interface().(filePayloader); ok {
		fp.setFilePayload(newFs)
	}

	// Call the method.
	re := reflect.New(rm.resultType.Elem())
	rValues := rm.fn.Call([]reflect.Value{rm.rcvr, na, re})
	// Register guarantees a single return value of type error.
	if errVal := rValues[0].Interface(); errVal != nil {
		return marshal(client, &callResult{Err: errVal.(error).Error()}, nil)
	}

	// Set the resulting payload.
	var fs []*os.File
	if fp, ok := re.Interface().(filePayloader); ok {
		fs = fp.filePayload()
		if len(fs) > maxFiles {
			// Ugh. Send an error to the client, despite success.
			return marshal(client, &callResult{Err: ErrTooManyFiles.Error()}, nil)
		}
	}

	// Marshal the result.
	return marshal(client, &callResult{Success: true, Result: re.Interface()}, fs)
}
   356  
   357  // clientBeginRequest begins a request.
   358  //
   359  // If true is returned, the request may be processed. If false is returned,
   360  // then the server has been stopped and the request should be skipped.
   361  func (s *Server) clientBeginRequest(client *unet.Socket) bool {
   362  	s.mu.Lock()
   363  	defer s.mu.Unlock()
   364  	switch state := s.clients[client]; state {
   365  	case idle:
   366  		// Mark as processing.
   367  		s.clients[client] = processing
   368  		return true
   369  	case closed:
   370  		// Whoops, how did this happen? Must have closed immediately
   371  		// following the deserialization. Don't let the RPC actually go
   372  		// through, since we won't be able to serialize a proper
   373  		// response.
   374  		return false
   375  	default:
   376  		// Should not happen.
   377  		panic(fmt.Sprintf("expected idle or closed, got %d", state))
   378  	}
   379  }
   380  
   381  // clientEndRequest ends a request.
   382  func (s *Server) clientEndRequest(client *unet.Socket) {
   383  	s.mu.Lock()
   384  	defer s.mu.Unlock()
   385  	switch state := s.clients[client]; state {
   386  	case processing:
   387  		// Return to idle.
   388  		s.clients[client] = idle
   389  	case closeRequested:
   390  		// Close the connection.
   391  		client.Close()
   392  		s.clients[client] = closed
   393  	default:
   394  		// Should not happen.
   395  		panic(fmt.Sprintf("expected processing or requestClose, got %d", state))
   396  	}
   397  }
   398  
// clientRegister registers a connection: the client is recorded as idle and
// the server's wait group is incremented so that Stop waits for it.
//
// See Stop for more context.
func (s *Server) clientRegister(client *unet.Socket) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.clients[client] = idle
	s.wg.Add(1)
}
   408  
   409  // clientUnregister unregisters and closes a connection if necessary.
   410  //
   411  // See Stop for more context.
   412  func (s *Server) clientUnregister(client *unet.Socket) {
   413  	s.mu.Lock()
   414  	defer s.mu.Unlock()
   415  	switch state := s.clients[client]; state {
   416  	case idle:
   417  		// Close the connection.
   418  		client.Close()
   419  	case closed:
   420  		// Already done.
   421  	default:
   422  		// Should not happen.
   423  		panic(fmt.Sprintf("expected idle or closed, got %d", state))
   424  	}
   425  	delete(s.clients, client)
   426  	s.wg.Done()
   427  }
   428  
   429  // handleRegistered handles calls from a registered client.
   430  func (s *Server) handleRegistered(client *unet.Socket) error {
   431  	for {
   432  		// Handle one call.
   433  		if err := s.handleOne(client); err != nil {
   434  			// Client is dead.
   435  			return err
   436  		}
   437  	}
   438  }
   439  
// Handle synchronously handles a single client over a connection.
//
// The client is registered for the duration of the call and unregistered
// (and closed, if still open) on return. Handle returns the error that ended
// the serving loop.
func (s *Server) Handle(client *unet.Socket) error {
	s.clientRegister(client)
	defer s.clientUnregister(client)
	return s.handleRegistered(client)
}
   446  
// StartHandling creates a goroutine that handles a single client over a
// connection.
//
// The client is registered before StartHandling returns, so a subsequent Stop
// waits for it. The error that ends the serving loop is discarded.
func (s *Server) StartHandling(client *unet.Socket) {
	s.clientRegister(client)
	go func() { // S/R-SAFE: out of scope
		defer s.clientUnregister(client)
		s.handleRegistered(client)
	}()
}
   456  
   457  // Stop safely terminates outstanding clients.
   458  //
   459  // No new requests should be initiated after calling Stop. Existing clients
   460  // will be closed after completing any pending RPCs. This method will block
   461  // until all clients have disconnected.
   462  //
   463  // timeout is the time for clients to complete ongoing RPCs.
   464  func (s *Server) Stop(timeout time.Duration) {
   465  	// Call any Stop callbacks.
   466  	for _, stopper := range s.stoppers {
   467  		stopper.Stop()
   468  	}
   469  
   470  	done := make(chan bool, 1)
   471  	go func() {
   472  		if timeout != 0 {
   473  			timer := time.NewTicker(timeout)
   474  			defer timer.Stop()
   475  			select {
   476  			case <-done:
   477  				return
   478  			case <-timer.C:
   479  			}
   480  		}
   481  
   482  		// Close all known clients.
   483  		s.mu.Lock()
   484  		defer s.mu.Unlock()
   485  		for client, state := range s.clients {
   486  			switch state {
   487  			case idle:
   488  				// Close connection now.
   489  				client.Close()
   490  				s.clients[client] = closed
   491  			case processing:
   492  				// Request close when done.
   493  				s.clients[client] = closeRequested
   494  			}
   495  		}
   496  	}()
   497  
   498  	// Wait for all outstanding requests.
   499  	s.wg.Wait()
   500  	done <- true
   501  }
   502  
// Client is a urpc client.
//
// Create one with NewClient. Calls made through a Client are serialized.
type Client struct {
	// mu protects all members.
	//
	// It also enforces single-call semantics.
	mu sync.Mutex

	// Socket is the underlying socket for this client.
	//
	// This _must_ be provided and must be closed manually by calling
	// Close.
	Socket *unet.Socket
}
   516  
// NewClient returns a new client that communicates over socket. The caller
// remains responsible for eventually calling Close on the returned client.
func NewClient(socket *unet.Socket) *Client {
	return &Client{
		Socket: socket,
	}
}
   523  
// marshal encodes v as JSON and writes it to s, passing the descriptors of fs
// (if any) along with the first write.
//
// It returns the encoding error or the first write error, logging a warning
// in either case.
func marshal(s *unet.Socket, v any, fs []*os.File) error {
	// Marshal to a buffer.
	data, err := json.Marshal(v)
	if err != nil {
		log.Warningf("urpc: error marshalling %s: %s", fmt.Sprintf("%v", v), err.Error())
		return err
	}

	// Write to the socket.
	// NOTE(review): the true argument presumably selects blocking mode —
	// confirm against unet.Socket.Writer.
	w := s.Writer(true)
	if fs != nil {
		var fds []int
		for _, f := range fs {
			fds = append(fds, int(f.Fd()))
		}
		w.PackFDs(fds...)
	}

	// Send. n counts the bytes written so far; each iteration writes the
	// remaining tail.
	for n := 0; n < len(data); {
		cur, err := w.WriteVec([][]byte{data[n:]})
		if n == 0 && cur < len(data) {
			// Don't send FDs anymore. This call is only made on
			// the first successful call to WriteVec, assuming cur
			// is not sufficient to fill the entire buffer.
			w.PackFDs()
		}
		n += cur
		if err != nil {
			log.Warningf("urpc: error writing %v: %s", data[n:], err.Error())
			return err
		}
	}

	// We're done sending the fds to the client. Explicitly prevent fs from
	// being GCed until here. Urpc rpcs often unlink the file to send, relying
	// on the kernel to automatically delete it once the last reference is
	// dropped. Until we successfully call sendmsg(2), fs may contain the last
	// references to these files. Without this explicit reference to fs here,
	// the go runtime is free to assume we're done with fs after the fd
	// collection loop above, since it just sees us copying ints.
	runtime.KeepAlive(fs)

	log.Debugf("urpc: successfully marshalled %d bytes.", len(data))
	return nil
}
   571  
   572  // unmarhsal receives an FD (optional) and unmarshals the given struct.
   573  func unmarshal(s *unet.Socket, v any) ([]*os.File, error) {
   574  	// Receive a single byte.
   575  	r := s.Reader(true)
   576  	r.EnableFDs(maxFiles)
   577  	firstByte := make([]byte, 1)
   578  
   579  	// Extract any FDs that may be there.
   580  	if _, err := r.ReadVec([][]byte{firstByte}); err != nil {
   581  		return nil, err
   582  	}
   583  	fds, err := r.ExtractFDs()
   584  	if err != nil {
   585  		log.Warningf("urpc: error extracting fds: %s", err.Error())
   586  		return nil, err
   587  	}
   588  	var fs []*os.File
   589  	for _, fd := range fds {
   590  		fs = append(fs, os.NewFile(uintptr(fd), "urpc"))
   591  	}
   592  
   593  	// Read the rest.
   594  	d := json.NewDecoder(io.MultiReader(bytes.NewBuffer(firstByte), s))
   595  	// urpc internally decodes / re-encodes the data with any as the
   596  	// intermediate type. We have to unmarshal integers to json.Number type
   597  	// instead of the default float type for those intermediate values, such
   598  	// that when they get re-encoded, their values are not printed out in
   599  	// floating-point formats such as 1e9, which could not be decoded to
   600  	// explicitly typed intergers later.
   601  	d.UseNumber()
   602  	if err := d.Decode(v); err != nil {
   603  		log.Warningf("urpc: error decoding: %s", err.Error())
   604  		for _, f := range fs {
   605  			f.Close()
   606  		}
   607  		return nil, err
   608  	}
   609  
   610  	// All set.
   611  	log.Debugf("urpc: unmarshal success.")
   612  	return fs, nil
   613  }
   614  
   615  // Call calls a function.
   616  func (c *Client) Call(method string, arg any, result any) error {
   617  	c.mu.Lock()
   618  	defer c.mu.Unlock()
   619  
   620  	// If arg is a FilePayload, not a *FilePayload, files won't actually be
   621  	// sent, so error out.
   622  	if _, ok := arg.(FilePayload); ok {
   623  		return fmt.Errorf("argument is a FilePayload, but should be a *FilePayload")
   624  	}
   625  
   626  	// Are there files to send?
   627  	var fs []*os.File
   628  	if fp, ok := arg.(filePayloader); ok {
   629  		fs = fp.filePayload()
   630  		if len(fs) > maxFiles {
   631  			return ErrTooManyFiles
   632  		}
   633  	}
   634  
   635  	// Marshal the data.
   636  	if err := marshal(c.Socket, &clientCall{Method: method, Arg: arg}, fs); err != nil {
   637  		return err
   638  	}
   639  
   640  	// Wait for the response.
   641  	callR := callResult{Result: result}
   642  	newFs, err := unmarshal(c.Socket, &callR)
   643  	if err != nil {
   644  		return fmt.Errorf("urpc method %q failed: %v", method, err)
   645  	}
   646  
   647  	// Set the file payload.
   648  	if fp, ok := result.(filePayloader); ok {
   649  		fp.setFilePayload(newFs)
   650  	} else {
   651  		closeAll(newFs)
   652  	}
   653  
   654  	// Did an error occur?
   655  	if !callR.Success {
   656  		return RemoteError{Message: callR.Err}
   657  	}
   658  
   659  	// All set.
   660  	return nil
   661  }
   662  
// Close closes the underlying socket and returns the socket's Close error.
// It takes the client's lock, so it waits for an in-progress Call to finish.
//
// Further calls to the client may result in undefined behavior.
func (c *Client) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.Socket.Close()
}