github.com/cloudbase/juju-core@v0.0.0-20140504232958-a7271ac7912f/rpc/server.go (about)

     1  // Copyright 2012, 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package rpc
     5  
     6  import (
     7  	"fmt"
     8  	"io"
     9  	"reflect"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/juju/loggo"
    14  
    15  	"launchpad.net/juju-core/rpc/rpcreflect"
    16  )
    17  
// CodeNotImplemented is the error code sent back to the client (via a
// serverError built in bindRequest) when the served root value has no
// method matching the requested type and action.
const CodeNotImplemented = "not implemented"

// logger is this package's logger, named "juju.rpc".
var logger = loggo.GetLogger("juju.rpc")
    21  
// A Codec implements reading and writing of messages in an RPC
// session.  The RPC code calls WriteMessage to write a message to the
// connection and calls ReadHeader and ReadBody in pairs to read
// messages.
//
// In this file, ReadHeader and ReadBody are driven from the Conn's
// input loop, WriteMessage is always called with the Conn's sending
// lock held, and Close is called from Conn.Close.
type Codec interface {
	// ReadHeader reads a message header into hdr.
	ReadHeader(hdr *Header) error

	// ReadBody reads a message body into the given body value.  The
	// isRequest parameter specifies whether the message being read
	// is a request; if not, it's a response.  The body value will
	// be a non-nil struct pointer, or nil to signify that the body
	// should be read and discarded.
	ReadBody(body interface{}, isRequest bool) error

	// WriteMessage writes a message with the given header and body.
	// The body will always be a struct. It may be called concurrently
	// with ReadHeader and ReadBody, but will not be called
	// concurrently with itself.
	WriteMessage(hdr *Header, body interface{}) error

	// Close closes the codec. It may be called concurrently
	// and should cause the Read methods to unblock.
	Close() error
}
    47  
    48  // Header is a header written before every RPC call.  Since RPC requests
    49  // can be initiated from either side, the header may represent a request
    50  // from the other side or a response to an outstanding request.
    51  type Header struct {
    52  	// RequestId holds the sequence number of the request.
    53  	// For replies, it holds the sequence number of the request
    54  	// that is being replied to.
    55  	RequestId uint64
    56  
    57  	// Request holds the action to invoke.
    58  	Request Request
    59  
    60  	// Error holds the error, if any.
    61  	Error string
    62  
    63  	// ErrorCode holds the code of the error, if any.
    64  	ErrorCode string
    65  }
    66  
    67  // Request represents an RPC to be performed, absent its parameters.
    68  type Request struct {
    69  	// Type holds the type of object to act on.
    70  	Type string
    71  
    72  	// Id holds the id of the object to act on.
    73  	Id string
    74  
    75  	// Action holds the action to perform on the object.
    76  	Action string
    77  }
    78  
    79  // IsRequest returns whether the header represents an RPC request.  If
    80  // it is not a request, it is a response.
    81  func (hdr *Header) IsRequest() bool {
    82  	return hdr.Request.Type != "" || hdr.Request.Action != ""
    83  }
    84  
    85  // Note that we use "client request" and "server request" to name
    86  // requests initiated locally and remotely respectively.
    87  
// Conn represents an RPC endpoint.  It can both initiate and receive
// RPC requests.  There may be multiple outstanding Calls associated
// with a single Conn, and a Conn may be used by multiple goroutines
// simultaneously.
//
// Lock ordering: when both sending and mutex are needed (as in the
// input goroutine), sending is acquired first.
type Conn struct {
	// codec holds the underlying RPC connection.
	codec Codec

	// notifier is informed about RPC requests. It may be nil.
	notifier RequestNotifier

	// srvPending represents the current server requests.  A request
	// is added (under mutex, only while not closing) before its
	// goroutine starts and marked done when its reply has been
	// written; Close waits on it before closing the codec.
	srvPending sync.WaitGroup

	// sending guards the write side of the codec - it ensures
	// that codec.WriteMessage is not called concurrently.
	// It also guards shutdown.
	sending sync.Mutex

	// mutex guards the following values.
	mutex sync.Mutex

	// rootValue holds the value to use to serve RPC requests, if any.
	rootValue rpcreflect.Value

	// transformErrors is used to transform returned errors.
	transformErrors func(error) error

	// reqId holds the latest client request id.
	reqId uint64

	// clientPending holds all pending client requests.  It is set to
	// nil once the input loop has failed them all.
	clientPending map[uint64]*Call

	// closing is set when the connection is shutting down via
	// Close.  When this is set, no more client or server requests
	// will be initiated.
	closing bool

	// shutdown is set when the input loop terminates. When this
	// is set, no more client requests will be sent to the server.
	shutdown bool

	// dead is closed when the input loop terminates.  It is nil
	// until Start has been called.
	dead chan struct{}

	// inputLoopError holds the error that caused the input loop to
	// terminate prematurely.  It is set before dead is closed.
	inputLoopError error
}
   138  
// RequestNotifier can be implemented to find out about requests
// occurring in an RPC conn, for example to print requests for logging
// purposes. The calls should not block or interact with the Conn object
// as that can cause delays to the RPC server or deadlock.
// Note that the methods on RequestNotifier may
// be called concurrently.
type RequestNotifier interface {
	// ServerRequest informs the RequestNotifier of a request made
	// to the Conn. If the request was not recognized or there was
	// an error reading the body, body will be nil.
	//
	// ServerRequest is called just before the server method
	// is invoked.
	ServerRequest(hdr *Header, body interface{})

	// ServerReply informs the RequestNotifier of a reply sent to a
	// server request. The given Request gives details of the call
	// that was made; the given Header and body are the header and
	// body sent as reply.  For error replies the body is an empty
	// struct and the error is carried in the Header.
	//
	// ServerReply is called just before the reply is written.  For
	// error replies the Conn's sending lock is held during the call.
	ServerReply(req Request, hdr *Header, body interface{}, timeSpent time.Duration)

	// ClientRequest informs the RequestNotifier of a request
	// made from the Conn. It is called just before the request is
	// written.
	ClientRequest(hdr *Header, body interface{})

	// ClientReply informs the RequestNotifier of a reply received
	// to a request. If the reply was to an unrecognised request,
	// the Request will be zero-valued. If the reply contained an
	// error, body will be nil; otherwise body will be the value that
	// was passed to the Conn.Call method.
	//
	// ClientReply is called just before the reply is handed
	// back to the caller.
	ClientReply(req Request, hdr *Header, body interface{})
}
   177  
   178  // NewConn creates a new connection that uses the given codec for
   179  // transport, but it does not start it. Conn.Start must be called before
   180  // any requests are sent or received. If notifier is non-nil, the
   181  // appropriate method will be called for every RPC request.
   182  func NewConn(codec Codec, notifier RequestNotifier) *Conn {
   183  	return &Conn{
   184  		codec:         codec,
   185  		clientPending: make(map[uint64]*Call),
   186  		notifier:      notifier,
   187  	}
   188  }
   189  
   190  // Start starts the RPC connection running.  It must be called at least
   191  // once for any RPC connection (client or server side) It has no effect
   192  // if it has already been called.  By default, a connection serves no
   193  // methods.  See Conn.Serve for a description of how to serve methods on
   194  // a Conn.
   195  func (conn *Conn) Start() {
   196  	conn.mutex.Lock()
   197  	defer conn.mutex.Unlock()
   198  	if conn.dead == nil {
   199  		conn.dead = make(chan struct{})
   200  		go conn.input()
   201  	}
   202  }
   203  
   204  // Serve serves RPC requests on the connection by invoking methods on
   205  // root. Note that it does not start the connection running,
   206  // though it may be called once the connection is already started.
   207  //
   208  // The server executes each client request by calling a method on root
   209  // to obtain an object to act on; then it invokes an method on that
   210  // object with the request parameters, possibly returning some result.
   211  //
   212  // Methods on the root value are of the form:
   213  //
   214  //      M(id string) (O, error)
   215  //
   216  // where M is an exported name, conventionally naming the object type,
   217  // id is some identifier for the object and O is the type of the
   218  // returned object.
   219  //
   220  // Methods defined on O may defined in one of the following forms, where
   221  // T and R must be struct types.
   222  //
   223  //	Method()
   224  //	Method() R
   225  //	Method() (R, error)
   226  //	Method() error
   227  //	Method(T)
   228  //	Method(T) R
   229  //	Method(T) (R, error)
   230  //	Method(T) error
   231  //
   232  // If transformErrors is non-nil, it will be called on all returned
   233  // non-nil errors, for example to transform the errors into ServerErrors
   234  // with specified codes.  There will be a panic if transformErrors
   235  // returns nil.
   236  //
   237  // Serve may be called at any time on a connection to change the
   238  // set of methods being served by the connection. This will have
   239  // no effect on calls that are currently being services.
   240  // If root is nil, the connection will serve no methods.
   241  func (conn *Conn) Serve(root interface{}, transformErrors func(error) error) {
   242  	rootValue := rpcreflect.ValueOf(reflect.ValueOf(root))
   243  	if rootValue.IsValid() && transformErrors == nil {
   244  		transformErrors = func(err error) error { return err }
   245  	}
   246  	conn.mutex.Lock()
   247  	defer conn.mutex.Unlock()
   248  	conn.rootValue = rootValue
   249  	conn.transformErrors = transformErrors
   250  }
   251  
   252  // Dead returns a channel that is closed when the connection
   253  // has been closed or the underlying transport has received
   254  // an error. There may still be outstanding requests.
   255  // Dead must be called after conn.Start has been called.
   256  func (conn *Conn) Dead() <-chan struct{} {
   257  	return conn.dead
   258  }
   259  
   260  // Close closes the connection and its underlying codec; it returns when
   261  // all requests have been terminated.
   262  //
   263  // If the connection is serving requests, and the root value implements
   264  // the Killer interface, its Kill method will be called.  The codec will
   265  // then be closed only when all its outstanding server calls have
   266  // completed.
   267  //
   268  // Calling Close multiple times is not an error.
   269  func (conn *Conn) Close() error {
   270  	conn.mutex.Lock()
   271  	if conn.closing {
   272  		conn.mutex.Unlock()
   273  		// Golang's net/rpc returns rpc.ErrShutdown if you ask to close
   274  		// a closing or shutdown connection. Our choice is that Close
   275  		// is an idempotent way to ask for resources to be released and
   276  		// isn't a failure if called multiple times.
   277  		return nil
   278  	}
   279  	conn.closing = true
   280  	// Kill server requests if appropriate.  Client requests will be
   281  	// terminated when the input loop finishes.
   282  	if conn.rootValue.IsValid() {
   283  		if killer, ok := conn.rootValue.GoValue().Interface().(Killer); ok {
   284  			killer.Kill()
   285  		}
   286  	}
   287  	conn.mutex.Unlock()
   288  
   289  	// Wait for any outstanding server requests to complete
   290  	// and write their replies before closing the codec.
   291  	conn.srvPending.Wait()
   292  
   293  	// Closing the codec should cause the input loop to terminate.
   294  	if err := conn.codec.Close(); err != nil {
   295  		logger.Infof("error closing codec: %v", err)
   296  	}
   297  	<-conn.dead
   298  	return conn.inputLoopError
   299  }
   300  
// ErrorCoder represents any error that has an associated
// error code. An error code is a short string that represents the
// kind of an error.
//
// When an error being sent back as a server reply implements
// ErrorCoder, its code is copied into the reply Header's ErrorCode.
type ErrorCoder interface {
	ErrorCode() string
}
   307  
// Killer represents a type that can be asked to abort any outstanding
// requests.  The Kill method should return immediately.
//
// Conn.Close calls Kill on the served root value, if it implements
// Killer, while holding the Conn's internal mutex and before waiting
// for outstanding server requests.  Calling Conn methods that take that
// mutex (such as Serve or Close) from within Kill would deadlock.
type Killer interface {
	Kill()
}
   313  
   314  // input reads messages from the connection and handles them
   315  // appropriately.
   316  func (conn *Conn) input() {
   317  	err := conn.loop()
   318  	conn.sending.Lock()
   319  	defer conn.sending.Unlock()
   320  	conn.mutex.Lock()
   321  	defer conn.mutex.Unlock()
   322  
   323  	if conn.closing || err == io.EOF {
   324  		err = ErrShutdown
   325  	} else {
   326  		// Make the error available for Conn.Close to see.
   327  		conn.inputLoopError = err
   328  	}
   329  	// Terminate all client requests.
   330  	for _, call := range conn.clientPending {
   331  		call.Error = err
   332  		call.done()
   333  	}
   334  	conn.clientPending = nil
   335  	conn.shutdown = true
   336  	close(conn.dead)
   337  }
   338  
   339  // loop implements the looping part of Conn.input.
   340  func (conn *Conn) loop() error {
   341  	var hdr Header
   342  	for {
   343  		hdr = Header{}
   344  		err := conn.codec.ReadHeader(&hdr)
   345  		if err != nil {
   346  			return err
   347  		}
   348  		if hdr.IsRequest() {
   349  			err = conn.handleRequest(&hdr)
   350  		} else {
   351  			err = conn.handleResponse(&hdr)
   352  		}
   353  		if err != nil {
   354  			return err
   355  		}
   356  	}
   357  }
   358  
   359  func (conn *Conn) readBody(resp interface{}, isRequest bool) error {
   360  	if resp == nil {
   361  		resp = &struct{}{}
   362  	}
   363  	return conn.codec.ReadBody(resp, isRequest)
   364  }
   365  
   366  func (conn *Conn) handleRequest(hdr *Header) error {
   367  	startTime := time.Now()
   368  	req, err := conn.bindRequest(hdr)
   369  	if err != nil {
   370  		if conn.notifier != nil {
   371  			conn.notifier.ServerRequest(hdr, nil)
   372  		}
   373  		if err := conn.readBody(nil, true); err != nil {
   374  			return err
   375  		}
   376  		// We don't transform the error because there
   377  		// may be no transformErrors function available.
   378  		return conn.writeErrorResponse(hdr, err, startTime)
   379  	}
   380  	var argp interface{}
   381  	var arg reflect.Value
   382  	if req.ParamsType != nil {
   383  		v := reflect.New(req.ParamsType)
   384  		arg = v.Elem()
   385  		argp = v.Interface()
   386  	}
   387  	if err := conn.readBody(argp, true); err != nil {
   388  		if conn.notifier != nil {
   389  			conn.notifier.ServerRequest(hdr, nil)
   390  		}
   391  		// If we get EOF, we know the connection is a
   392  		// goner, so don't try to respond.
   393  		if err == io.EOF || err == io.ErrUnexpectedEOF {
   394  			return err
   395  		}
   396  		// An error reading the body often indicates bad
   397  		// request parameters rather than an issue with
   398  		// the connection itself, so we reply with an
   399  		// error rather than tearing down the connection
   400  		// unless it's obviously a connection issue.  If
   401  		// the error is actually a framing or syntax
   402  		// problem, then the next ReadHeader should pick
   403  		// up the problem and abort.
   404  		return conn.writeErrorResponse(hdr, req.transformErrors(err), startTime)
   405  	}
   406  	if conn.notifier != nil {
   407  		if req.ParamsType != nil {
   408  			conn.notifier.ServerRequest(hdr, arg.Interface())
   409  		} else {
   410  			conn.notifier.ServerRequest(hdr, struct{}{})
   411  		}
   412  	}
   413  	conn.mutex.Lock()
   414  	closing := conn.closing
   415  	if !closing {
   416  		conn.srvPending.Add(1)
   417  		go conn.runRequest(req, arg, startTime)
   418  	}
   419  	conn.mutex.Unlock()
   420  	if closing {
   421  		// We're closing down - no new requests may be initiated.
   422  		return conn.writeErrorResponse(hdr, req.transformErrors(ErrShutdown), startTime)
   423  	}
   424  	return nil
   425  }
   426  
   427  func (conn *Conn) writeErrorResponse(reqHdr *Header, err error, startTime time.Time) error {
   428  	conn.sending.Lock()
   429  	defer conn.sending.Unlock()
   430  	hdr := &Header{
   431  		RequestId: reqHdr.RequestId,
   432  	}
   433  	if err, ok := err.(ErrorCoder); ok {
   434  		hdr.ErrorCode = err.ErrorCode()
   435  	} else {
   436  		hdr.ErrorCode = ""
   437  	}
   438  	hdr.Error = err.Error()
   439  	if conn.notifier != nil {
   440  		conn.notifier.ServerReply(reqHdr.Request, hdr, struct{}{}, time.Since(startTime))
   441  	}
   442  	return conn.codec.WriteMessage(hdr, struct{}{})
   443  }
   444  
// boundRequest represents an RPC request that is
// bound to an actual implementation.
type boundRequest struct {
	// MethodCaller supplies ParamsType (used by handleRequest to
	// allocate the argument) and Call (used by runRequest).
	rpcreflect.MethodCaller

	// transformErrors is the Conn's error transformer, captured when
	// the request was bound.
	transformErrors func(error) error

	// hdr is a copy of the request header.  A copy is needed because
	// the input loop resets and reuses its Header for the next message
	// while runRequest may still be running in another goroutine.
	hdr Header
}
   452  
   453  // bindRequest searches for methods implementing the
   454  // request held in the given header and returns
   455  // a boundRequest that can call those methods.
   456  func (conn *Conn) bindRequest(hdr *Header) (boundRequest, error) {
   457  	conn.mutex.Lock()
   458  	rootValue := conn.rootValue
   459  	transformErrors := conn.transformErrors
   460  	conn.mutex.Unlock()
   461  
   462  	if !rootValue.IsValid() {
   463  		return boundRequest{}, fmt.Errorf("no service")
   464  	}
   465  	caller, err := rootValue.MethodCaller(hdr.Request.Type, hdr.Request.Action)
   466  	if err != nil {
   467  		if _, ok := err.(*rpcreflect.CallNotImplementedError); ok {
   468  			err = &serverError{
   469  				Message: err.Error(),
   470  				Code:    CodeNotImplemented,
   471  			}
   472  		}
   473  		return boundRequest{}, err
   474  	}
   475  	return boundRequest{
   476  		MethodCaller:    caller,
   477  		transformErrors: transformErrors,
   478  		hdr:             *hdr,
   479  	}, nil
   480  }
   481  
   482  // runRequest runs the given request and sends the reply.
   483  func (conn *Conn) runRequest(req boundRequest, arg reflect.Value, startTime time.Time) {
   484  	defer conn.srvPending.Done()
   485  	rv, err := req.Call(req.hdr.Request.Id, arg)
   486  	if err != nil {
   487  		err = conn.writeErrorResponse(&req.hdr, req.transformErrors(err), startTime)
   488  	} else {
   489  		hdr := &Header{
   490  			RequestId: req.hdr.RequestId,
   491  		}
   492  		var rvi interface{}
   493  		if rv.IsValid() {
   494  			rvi = rv.Interface()
   495  		} else {
   496  			rvi = struct{}{}
   497  		}
   498  		if conn.notifier != nil {
   499  			conn.notifier.ServerReply(req.hdr.Request, hdr, rvi, time.Since(startTime))
   500  		}
   501  		conn.sending.Lock()
   502  		err = conn.codec.WriteMessage(hdr, rvi)
   503  		conn.sending.Unlock()
   504  	}
   505  	if err != nil {
   506  		logger.Errorf("error writing response: %v", err)
   507  	}
   508  }
   509  
   510  type serverError RequestError
   511  
   512  func (e *serverError) Error() string {
   513  	return e.Message
   514  }
   515  
   516  func (e *serverError) ErrorCode() string {
   517  	return e.Code
   518  }