github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/rpc/server.go (about)

     1  // Copyright 2012, 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package rpc
     5  
     6  import (
     7  	"context"
     8  	"io"
     9  	"reflect"
    10  	"runtime/debug"
    11  	"strings"
    12  	"sync"
    13  
    14  	"github.com/juju/errors"
    15  	"github.com/juju/loggo"
    16  
    17  	"github.com/juju/juju/rpc/rpcreflect"
    18  )
    19  
    20  const codeNotImplemented = "not implemented"
    21  
    22  var logger = loggo.GetLogger("juju.rpc")
    23  
    24  // A Codec implements reading and writing of messages in an RPC
    25  // session.  The RPC code calls WriteMessage to write a message to the
    26  // connection and calls ReadHeader and ReadBody in pairs to read
    27  // messages.
    28  type Codec interface {
    29  	// ReadHeader reads a message header into hdr.
    30  	ReadHeader(hdr *Header) error
    31  
    32  	// ReadBody reads a message body into the given body value.  The
    33  	// isRequest parameter specifies whether the message being read
    34  	// is a request; if not, it's a response.  The body value will
    35  	// be a non-nil struct pointer, or nil to signify that the body
    36  	// should be read and discarded.
    37  	ReadBody(body interface{}, isRequest bool) error
    38  
    39  	// WriteMessage writes a message with the given header and body.
    40  	// The body will always be a struct. It may be called concurrently
    41  	// with ReadHeader and ReadBody, but will not be called
    42  	// concurrently with itself.
    43  	WriteMessage(hdr *Header, body interface{}) error
    44  
    45  	// Close closes the codec. It may be called concurrently
    46  	// and should cause the Read methods to unblock.
    47  	Close() error
    48  }
    49  
    50  // Header is a header written before every RPC call.  Since RPC requests
    51  // can be initiated from either side, the header may represent a request
    52  // from the other side or a response to an outstanding request.
    53  type Header struct {
    54  	// RequestId holds the sequence number of the request.
    55  	// For replies, it holds the sequence number of the request
    56  	// that is being replied to.
    57  	RequestId uint64
    58  
    59  	// Request holds the action to invoke.
    60  	Request Request
    61  
    62  	// Error holds the error, if any.
    63  	Error string
    64  
    65  	// ErrorCode holds the code of the error, if any.
    66  	ErrorCode string
    67  
    68  	// Version defines the wire format of the request and response structure.
    69  	Version int
    70  }
    71  
    72  // Request represents an RPC to be performed, absent its parameters.
    73  type Request struct {
    74  	// Type holds the type of object to act on.
    75  	Type string
    76  
    77  	// Version holds the version of Type we will be acting on
    78  	Version int
    79  
    80  	// Id holds the id of the object to act on.
    81  	Id string
    82  
    83  	// Action holds the action to perform on the object.
    84  	Action string
    85  }
    86  
    87  // IsRequest returns whether the header represents an RPC request.  If
    88  // it is not a request, it is a response.
    89  func (hdr *Header) IsRequest() bool {
    90  	return hdr.Request.Type != "" || hdr.Request.Action != ""
    91  }
    92  
    93  // RecorderFactory is a function that returns a recorder to record
    94  // details of a single request/response.
    95  type RecorderFactory func() Recorder
    96  
    97  // Recorder represents something the connection uses to record
    98  // requests and replies. Recording a message can fail (for example for
    99  // audit logging), and when it does the request should be failed as
   100  // well.
   101  type Recorder interface {
   102  	HandleRequest(hdr *Header, body interface{}) error
   103  	HandleReply(req Request, replyHdr *Header, body interface{}) error
   104  }
   105  
   106  // Note that we use "client request" and "server request" to name
   107  // requests initiated locally and remotely respectively.
   108  
   109  // Conn represents an RPC endpoint.  It can both initiate and receive
   110  // RPC requests.  There may be multiple outstanding Calls associated
   111  // with a single Client, and a Client may be used by multiple goroutines
   112  // simultaneously.
   113  type Conn struct {
   114  	// codec holds the underlying RPC connection.
   115  	codec Codec
   116  
   117  	// srvPending represents the current server requests.
   118  	srvPending sync.WaitGroup
   119  
   120  	// sending guards the write side of the codec - it ensures
   121  	// that codec.WriteMessage is not called concurrently.
   122  	// It also guards shutdown.
   123  	sending sync.Mutex
   124  
   125  	// mutex guards the following values.
   126  	mutex sync.Mutex
   127  
   128  	// root represents  the current root object that serves the RPC requests.
   129  	// It may be nil if nothing is being served.
   130  	root Root
   131  
   132  	// transformErrors is used to transform returned errors.
   133  	transformErrors func(error) error
   134  
   135  	// reqId holds the latest client request id.
   136  	reqId uint64
   137  
   138  	// clientPending holds all pending client requests.
   139  	clientPending map[uint64]*Call
   140  
   141  	// closing is set when the connection is shutting down via
   142  	// Close.  When this is set, no more client or server requests
   143  	// will be initiated.
   144  	closing bool
   145  
   146  	// shutdown is set when the input loop terminates. When this
   147  	// is set, no more client requests will be sent to the server.
   148  	shutdown bool
   149  
   150  	// dead is closed when the input loop terminates.
   151  	dead chan struct{}
   152  
   153  	// context is created when the connection is started, and is
   154  	// cancelled when the connection is closed.
   155  	context       context.Context
   156  	cancelContext context.CancelFunc
   157  
   158  	// inputLoopError holds the error that caused the input loop to
   159  	// terminate prematurely.  It is set before dead is closed.
   160  	inputLoopError error
   161  
   162  	recorderFactory RecorderFactory
   163  }
   164  
   165  // NewConn creates a new connection that uses the given codec for
   166  // transport, but it does not start it. Conn.Start must be called
   167  // before any requests are sent or received. If recorderFactory is
   168  // non-nil, it will be called to get a new recorder for every request.
   169  func NewConn(codec Codec, factory RecorderFactory) *Conn {
   170  	return &Conn{
   171  		codec:           codec,
   172  		clientPending:   make(map[uint64]*Call),
   173  		recorderFactory: ensureFactory(factory),
   174  	}
   175  }
   176  
   177  // Start starts the RPC connection running.  It must be called at
   178  // least once for any RPC connection (client or server side) It has no
   179  // effect if it has already been called.  By default, a connection
   180  // serves no methods.  See Conn.Serve for a description of how to
   181  // serve methods on a Conn.
   182  //
   183  // The context passed in will be propagated to requests served by
   184  // the connection.
   185  func (conn *Conn) Start(ctx context.Context) {
   186  	conn.mutex.Lock()
   187  	defer conn.mutex.Unlock()
   188  	if conn.dead == nil {
   189  		conn.context, conn.cancelContext = context.WithCancel(ctx)
   190  		conn.dead = make(chan struct{})
   191  		go conn.input()
   192  	}
   193  }
   194  
   195  // Serve serves RPC requests on the connection by invoking methods on
   196  // root. Note that it does not start the connection running,
   197  // though it may be called once the connection is already started.
   198  //
   199  // The server executes each client request by calling a method on root
   200  // to obtain an object to act on; then it invokes an method on that
   201  // object with the request parameters, possibly returning some result.
   202  //
   203  // Methods on the root value are of the form:
   204  //
   205  //      M(id string) (O, error)
   206  //
   207  // where M is an exported name, conventionally naming the object type,
   208  // id is some identifier for the object and O is the type of the
   209  // returned object.
   210  //
   211  // Methods defined on O may defined in one of the following forms, where
   212  // T and R must be struct types.
   213  //
   214  //	Method([context.Context])
   215  //	Method([context.Context]) R
   216  //	Method([context.Context]) (R, error)
   217  //	Method([context.Context]) error
   218  //	Method([context.Context,]T)
   219  //	Method([context.Context,]T) R
   220  //	Method([context.Context,]T) (R, error)
   221  //	Method([context.Context,]T) error
   222  //
   223  // If transformErrors is non-nil, it will be called on all returned
   224  // non-nil errors, for example to transform the errors into ServerErrors
   225  // with specified codes.  There will be a panic if transformErrors
   226  // returns nil.
   227  //
   228  // Serve may be called at any time on a connection to change the
   229  // set of methods being served by the connection. This will have
   230  // no effect on calls that are currently being services.
   231  // If root is nil, the connection will serve no methods.
   232  func (conn *Conn) Serve(root interface{}, factory RecorderFactory, transformErrors func(error) error) {
   233  	rootValue := rpcreflect.ValueOf(reflect.ValueOf(root))
   234  	if rootValue.IsValid() {
   235  		conn.serve(rootValue, factory, transformErrors)
   236  	} else {
   237  		conn.serve(nil, factory, transformErrors)
   238  	}
   239  }
   240  
   241  // ServeRoot is like Serve except that it gives the root object dynamic
   242  // control over what methods are available instead of using reflection
   243  // on the type.
   244  //
   245  // The server executes each client request by calling FindMethod to obtain a
   246  // method to invoke. It invokes that method with the request parameters,
   247  // possibly returning some result.
   248  //
   249  // The Kill method will be called when the connection is closed.
   250  func (conn *Conn) ServeRoot(root Root, factory RecorderFactory, transformErrors func(error) error) {
   251  	conn.serve(root, factory, transformErrors)
   252  }
   253  
   254  func (conn *Conn) serve(root Root, factory RecorderFactory, transformErrors func(error) error) {
   255  	if transformErrors == nil {
   256  		transformErrors = noopTransform
   257  	}
   258  	conn.mutex.Lock()
   259  	defer conn.mutex.Unlock()
   260  	conn.root = root
   261  	conn.recorderFactory = ensureFactory(factory)
   262  	conn.transformErrors = transformErrors
   263  }
   264  
   265  // noopTransform is used when transformErrors is not supplied to Serve.
   266  func noopTransform(err error) error {
   267  	return err
   268  }
   269  
   270  // Dead returns a channel that is closed when the connection
   271  // has been closed or the underlying transport has received
   272  // an error. There may still be outstanding requests.
   273  // Dead must be called after conn.Start has been called.
   274  func (conn *Conn) Dead() <-chan struct{} {
   275  	return conn.dead
   276  }
   277  
   278  // Close closes the connection and its underlying codec; it returns when
   279  // all requests have been terminated.
   280  //
   281  // If the connection is serving requests, and the root value implements
   282  // the Killer interface, its Kill method will be called.  The codec will
   283  // then be closed only when all its outstanding server calls have
   284  // completed.
   285  //
   286  // Calling Close multiple times is not an error.
   287  func (conn *Conn) Close() error {
   288  	conn.mutex.Lock()
   289  	if conn.closing {
   290  		conn.mutex.Unlock()
   291  		// Golang's net/rpc returns rpc.ErrShutdown if you ask to close
   292  		// a closing or shutdown connection. Our choice is that Close
   293  		// is an idempotent way to ask for resources to be released and
   294  		// isn't a failure if called multiple times.
   295  		return nil
   296  	}
   297  	conn.closing = true
   298  	if conn.root != nil {
   299  		// Kill calls down into the resources to stop all the resources which
   300  		// includes watchers. The watches need to be killed in order for their
   301  		// API methods to return, otherwise they are just waiting.
   302  		conn.root.Kill()
   303  	}
   304  	conn.mutex.Unlock()
   305  
   306  	// Wait for any outstanding server requests to complete
   307  	// and write their replies before closing the codec. We
   308  	// cancel the context so that any requests that would
   309  	// block will be notified that the server is shutting
   310  	// down.
   311  	conn.cancelContext()
   312  	conn.srvPending.Wait()
   313  
   314  	conn.mutex.Lock()
   315  	if conn.root != nil {
   316  		// It is possible that since we last Killed the root, other resources
   317  		// may have been added during some of the pending call resoulutions.
   318  		// So to release these resources, double tap the root.
   319  		conn.root.Kill()
   320  	}
   321  	conn.mutex.Unlock()
   322  
   323  	// Closing the codec should cause the input loop to terminate.
   324  	if err := conn.codec.Close(); err != nil {
   325  		logger.Debugf("error closing codec: %v", err)
   326  	}
   327  	<-conn.dead
   328  
   329  	return conn.inputLoopError
   330  }
   331  
   332  // ErrorCoder represents an any error that has an associated
   333  // error code. An error code is a short string that represents the
   334  // kind of an error.
   335  type ErrorCoder interface {
   336  	ErrorCode() string
   337  }
   338  
   339  // Root represents a type that can be used to lookup a Method and place
   340  // calls on that method.
   341  type Root interface {
   342  	FindMethod(rootName string, version int, methodName string) (rpcreflect.MethodCaller, error)
   343  	Killer
   344  }
   345  
   346  // Killer represents a type that can be asked to abort any outstanding
   347  // requests.  The Kill method should return immediately.
   348  type Killer interface {
   349  	Kill()
   350  }
   351  
   352  // input reads messages from the connection and handles them
   353  // appropriately.
   354  func (conn *Conn) input() {
   355  	err := conn.loop()
   356  	conn.sending.Lock()
   357  	defer conn.sending.Unlock()
   358  	conn.mutex.Lock()
   359  	defer conn.mutex.Unlock()
   360  
   361  	if conn.closing || errors.Cause(err) == io.EOF {
   362  		err = ErrShutdown
   363  	} else {
   364  		// Make the error available for Conn.Close to see.
   365  		conn.inputLoopError = err
   366  	}
   367  	// Terminate all client requests.
   368  	for _, call := range conn.clientPending {
   369  		call.Error = err
   370  		call.done()
   371  	}
   372  	conn.clientPending = nil
   373  	conn.shutdown = true
   374  	close(conn.dead)
   375  }
   376  
   377  // loop implements the looping part of Conn.input.
   378  func (conn *Conn) loop() error {
   379  	defer conn.cancelContext()
   380  	for {
   381  		var hdr Header
   382  		err := conn.codec.ReadHeader(&hdr)
   383  		switch {
   384  		case errors.Cause(err) == io.EOF:
   385  			// handle sentinel error specially
   386  			return err
   387  		case err != nil:
   388  			return errors.Annotate(err, "codec.ReadHeader error")
   389  		case hdr.IsRequest():
   390  			if err := conn.handleRequest(&hdr); err != nil {
   391  				return errors.Annotatef(err, "codec.handleRequest %#v error", hdr)
   392  			}
   393  		default:
   394  			if err := conn.handleResponse(&hdr); err != nil {
   395  				return errors.Annotatef(err, "codec.handleResponse %#v error", hdr)
   396  			}
   397  		}
   398  	}
   399  }
   400  
   401  func (conn *Conn) readBody(resp interface{}, isRequest bool) error {
   402  	if resp == nil {
   403  		resp = &struct{}{}
   404  	}
   405  	return conn.codec.ReadBody(resp, isRequest)
   406  }
   407  
   408  func (conn *Conn) getRecorder() Recorder {
   409  	conn.mutex.Lock()
   410  	defer conn.mutex.Unlock()
   411  	return conn.recorderFactory()
   412  }
   413  
   414  func (conn *Conn) handleRequest(hdr *Header) error {
   415  	recorder := conn.getRecorder()
   416  	req, err := conn.bindRequest(hdr)
   417  	if err != nil {
   418  		if err := recorder.HandleRequest(hdr, nil); err != nil {
   419  			return errors.Trace(err)
   420  		}
   421  		if err := conn.readBody(nil, true); err != nil {
   422  			return err
   423  		}
   424  		// We don't transform the error here. bindRequest will have
   425  		// already transformed it and returned a zero req.
   426  		return conn.writeErrorResponse(hdr, err, recorder)
   427  	}
   428  	var argp interface{}
   429  	var arg reflect.Value
   430  	if req.ParamsType() != nil {
   431  		v := reflect.New(req.ParamsType())
   432  		arg = v.Elem()
   433  		argp = v.Interface()
   434  	}
   435  	if err := conn.readBody(argp, true); err != nil {
   436  		if err := recorder.HandleRequest(hdr, nil); err != nil {
   437  			return errors.Trace(err)
   438  		}
   439  
   440  		// If we get EOF, we know the connection is a
   441  		// goner, so don't try to respond.
   442  		if err == io.EOF || err == io.ErrUnexpectedEOF {
   443  			return err
   444  		}
   445  		// An error reading the body often indicates bad
   446  		// request parameters rather than an issue with
   447  		// the connection itself, so we reply with an
   448  		// error rather than tearing down the connection
   449  		// unless it's obviously a connection issue.  If
   450  		// the error is actually a framing or syntax
   451  		// problem, then the next ReadHeader should pick
   452  		// up the problem and abort.
   453  		return conn.writeErrorResponse(hdr, req.transformErrors(err), recorder)
   454  	}
   455  	var body interface{} = struct{}{}
   456  	if req.ParamsType() != nil {
   457  		body = arg.Interface()
   458  	}
   459  	if err := recorder.HandleRequest(hdr, body); err != nil {
   460  		logger.Errorf("error recording request %+v with arg %+v: %T %+v", req, arg, err, err)
   461  		return conn.writeErrorResponse(hdr, req.transformErrors(err), recorder)
   462  	}
   463  	conn.mutex.Lock()
   464  	closing := conn.closing
   465  	if !closing {
   466  		conn.srvPending.Add(1)
   467  		go conn.runRequest(req, arg, hdr.Version, recorder)
   468  	}
   469  	conn.mutex.Unlock()
   470  	if closing {
   471  		// We're closing down - no new requests may be initiated.
   472  		return conn.writeErrorResponse(hdr, req.transformErrors(ErrShutdown), recorder)
   473  	}
   474  	return nil
   475  }
   476  
   477  func (conn *Conn) writeErrorResponse(reqHdr *Header, err error, recorder Recorder) error {
   478  	conn.sending.Lock()
   479  	defer conn.sending.Unlock()
   480  	hdr := &Header{
   481  		RequestId: reqHdr.RequestId,
   482  		Version:   reqHdr.Version,
   483  	}
   484  	if err, ok := err.(ErrorCoder); ok {
   485  		hdr.ErrorCode = err.ErrorCode()
   486  	} else {
   487  		hdr.ErrorCode = ""
   488  	}
   489  	hdr.Error = err.Error()
   490  	if err := recorder.HandleReply(reqHdr.Request, hdr, struct{}{}); err != nil {
   491  		logger.Errorf("error recording reply %+v: %T %+v", hdr, err, err)
   492  	}
   493  
   494  	return conn.codec.WriteMessage(hdr, struct{}{})
   495  }
   496  
   497  // boundRequest represents an RPC request that is
   498  // bound to an actual implementation.
   499  type boundRequest struct {
   500  	rpcreflect.MethodCaller
   501  	transformErrors func(error) error
   502  	hdr             Header
   503  }
   504  
   505  // bindRequest searches for methods implementing the
   506  // request held in the given header and returns
   507  // a boundRequest that can call those methods.
   508  func (conn *Conn) bindRequest(hdr *Header) (boundRequest, error) {
   509  	conn.mutex.Lock()
   510  	root := conn.root
   511  	transformErrors := conn.transformErrors
   512  	conn.mutex.Unlock()
   513  
   514  	if root == nil {
   515  		return boundRequest{}, errors.New("no service")
   516  	}
   517  	caller, err := root.FindMethod(
   518  		hdr.Request.Type, hdr.Request.Version, hdr.Request.Action)
   519  	if err != nil {
   520  		if _, ok := err.(*rpcreflect.CallNotImplementedError); ok {
   521  			err = &serverError{
   522  				error: err,
   523  			}
   524  		} else {
   525  			err = transformErrors(err)
   526  		}
   527  		return boundRequest{}, err
   528  	}
   529  	return boundRequest{
   530  		MethodCaller:    caller,
   531  		transformErrors: transformErrors,
   532  		hdr:             *hdr,
   533  	}, nil
   534  }
   535  
   536  // runRequest runs the given request and sends the reply.
   537  func (conn *Conn) runRequest(
   538  	req boundRequest,
   539  	arg reflect.Value,
   540  	version int,
   541  	recorder Recorder,
   542  ) {
   543  	// If the request causes a panic, ensure we log that before closing the connection.
   544  	defer func() {
   545  		if panicResult := recover(); panicResult != nil {
   546  			logger.Criticalf(
   547  				"panic running request %+v with arg %+v: %v\n%v", req, arg, panicResult, string(debug.Stack()))
   548  			conn.writeErrorResponse(&req.hdr, errors.Errorf("%v", panicResult), recorder)
   549  		}
   550  	}()
   551  	defer conn.srvPending.Done()
   552  
   553  	// Create a request-specific context, cancelled when the
   554  	// request returns.
   555  	//
   556  	// TODO(axw) provide a means for clients to cancel a request.
   557  	ctx, cancel := context.WithCancel(conn.context)
   558  	defer cancel()
   559  
   560  	rv, err := req.Call(ctx, req.hdr.Request.Id, arg)
   561  	if err != nil {
   562  		err = conn.writeErrorResponse(&req.hdr, req.transformErrors(err), recorder)
   563  	} else {
   564  		hdr := &Header{
   565  			RequestId: req.hdr.RequestId,
   566  			Version:   version,
   567  		}
   568  		var rvi interface{}
   569  		if rv.IsValid() {
   570  			rvi = rv.Interface()
   571  		} else {
   572  			rvi = struct{}{}
   573  		}
   574  		if err := recorder.HandleReply(req.hdr.Request, hdr, rvi); err != nil {
   575  			logger.Errorf("error recording reply %+v: %T %+v", hdr, err, err)
   576  		}
   577  		conn.sending.Lock()
   578  		err = conn.codec.WriteMessage(hdr, rvi)
   579  		conn.sending.Unlock()
   580  	}
   581  	if err != nil {
   582  		// If the message failed due to the other end closing the socket, that
   583  		// is expected when an agent restarts so no need to log an  error.
   584  		// The error type here is errors.errorString so all we can do is a match
   585  		// on the error string content.
   586  		msg := err.Error()
   587  		if !strings.Contains(msg, "websocket: close sent") &&
   588  			!strings.Contains(msg, "write: broken pipe") {
   589  			logger.Errorf("error writing response: %T %+v", err, err)
   590  		}
   591  	}
   592  }
   593  
   594  type serverError struct {
   595  	error
   596  }
   597  
   598  func (e *serverError) ErrorCode() string {
   599  	// serverError only knows one error code.
   600  	return codeNotImplemented
   601  }
   602  
   603  func ensureFactory(f RecorderFactory) RecorderFactory {
   604  	if f != nil {
   605  		return f
   606  	}
   607  	var nop nopRecorder
   608  	return func() Recorder {
   609  		return &nop
   610  	}
   611  }
   612  
   613  type nopRecorder struct{}
   614  
   615  func (nopRecorder) HandleRequest(hdr *Header, body interface{}) error { return nil }
   616  
   617  func (nopRecorder) HandleReply(req Request, hdr *Header, body interface{}) error { return nil }