github.com/wallyworld/juju@v0.0.0-20161013125918-6cf1bc9d917a/rpc/rpc_test.go (about)

     1  // Copyright 2012, 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package rpc_test
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  	"net"
    10  	"reflect"
    11  	"regexp"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/juju/errors"
    16  	"github.com/juju/loggo"
    17  	jc "github.com/juju/testing/checkers"
    18  	gc "gopkg.in/check.v1"
    19  
    20  	"github.com/juju/juju/apiserver/params"
    21  	"github.com/juju/juju/rpc"
    22  	"github.com/juju/juju/rpc/jsoncodec"
    23  	"github.com/juju/juju/rpc/rpcreflect"
    24  	"github.com/juju/juju/testing"
    25  )
    26  
    27  var logger = loggo.GetLogger("juju.rpc")
    28  
    29  type rpcSuite struct {
    30  	testing.BaseSuite
    31  }
    32  
    33  var _ = gc.Suite(&rpcSuite{})
    34  
    35  type callInfo struct {
    36  	rcvr   interface{}
    37  	method string
    38  	arg    interface{}
    39  }
    40  
    41  type callError callInfo
    42  
    43  func (e *callError) Error() string {
    44  	return fmt.Sprintf("error calling %s", e.method)
    45  }
    46  
    47  type stringVal struct {
    48  	Val string
    49  }
    50  
    51  type Root struct {
    52  	mu        sync.Mutex
    53  	conn      *rpc.Conn
    54  	calls     []*callInfo
    55  	returnErr bool
    56  	simple    map[string]*SimpleMethods
    57  	delayed   map[string]*DelayedMethods
    58  	errorInst *ErrorMethods
    59  }
    60  
    61  func (r *Root) callError(rcvr interface{}, name string, arg interface{}) error {
    62  	if r.returnErr {
    63  		return &callError{rcvr, name, arg}
    64  	}
    65  	return nil
    66  }
    67  
    68  func (r *Root) SimpleMethods(id string) (*SimpleMethods, error) {
    69  	r.mu.Lock()
    70  	defer r.mu.Unlock()
    71  	if a := r.simple[id]; a != nil {
    72  		return a, nil
    73  	}
    74  	return nil, fmt.Errorf("unknown SimpleMethods id")
    75  }
    76  
    77  func (r *Root) DelayedMethods(id string) (*DelayedMethods, error) {
    78  	r.mu.Lock()
    79  	defer r.mu.Unlock()
    80  	if a := r.delayed[id]; a != nil {
    81  		return a, nil
    82  	}
    83  	return nil, fmt.Errorf("unknown DelayedMethods id")
    84  }
    85  
    86  func (r *Root) ErrorMethods(id string) (*ErrorMethods, error) {
    87  	if r.errorInst == nil {
    88  		return nil, fmt.Errorf("no error methods")
    89  	}
    90  	return r.errorInst, nil
    91  }
    92  
    93  func (r *Root) Discard1() {}
    94  
    95  func (r *Root) Discard2(id string) error { return nil }
    96  
    97  func (r *Root) Discard3(id string) int { return 0 }
    98  
    99  func (r *Root) CallbackMethods(string) (*CallbackMethods, error) {
   100  	return &CallbackMethods{r}, nil
   101  }
   102  
   103  func (r *Root) InterfaceMethods(id string) (InterfaceMethods, error) {
   104  	logger.Infof("interface methods called")
   105  	m, err := r.SimpleMethods(id)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	return m, nil
   110  }
   111  
   112  type InterfaceMethods interface {
   113  	Call1r1e(s stringVal) (stringVal, error)
   114  }
   115  
   116  type ChangeAPIMethods struct {
   117  	r *Root
   118  }
   119  
   120  func (r *Root) ChangeAPIMethods(string) (*ChangeAPIMethods, error) {
   121  	return &ChangeAPIMethods{r}, nil
   122  }
   123  
   124  func (t *Root) called(rcvr interface{}, method string, arg interface{}) {
   125  	t.mu.Lock()
   126  	t.calls = append(t.calls, &callInfo{rcvr, method, arg})
   127  	t.mu.Unlock()
   128  }
   129  
   130  type SimpleMethods struct {
   131  	root *Root
   132  	id   string
   133  }
   134  
   135  // Each Call method is named in this standard form:
   136  //
   137  //     Call<narg>r<nret><e>
   138  //
   139  // where narg is the number of arguments, nret is the number of returned
   140  // values (not including the error) and e is the letter 'e' if the
   141  // method returns an error.
   142  
   143  func (a *SimpleMethods) Call0r0() {
   144  	a.root.called(a, "Call0r0", nil)
   145  }
   146  
   147  func (a *SimpleMethods) Call0r1() stringVal {
   148  	a.root.called(a, "Call0r1", nil)
   149  	return stringVal{"Call0r1 ret"}
   150  }
   151  
   152  func (a *SimpleMethods) Call0r1e() (stringVal, error) {
   153  	a.root.called(a, "Call0r1e", nil)
   154  	return stringVal{"Call0r1e ret"}, a.root.callError(a, "Call0r1e", nil)
   155  }
   156  
   157  func (a *SimpleMethods) Call0r0e() error {
   158  	a.root.called(a, "Call0r0e", nil)
   159  	return a.root.callError(a, "Call0r0e", nil)
   160  }
   161  
   162  func (a *SimpleMethods) Call1r0(s stringVal) {
   163  	a.root.called(a, "Call1r0", s)
   164  }
   165  
   166  func (a *SimpleMethods) Call1r1(s stringVal) stringVal {
   167  	a.root.called(a, "Call1r1", s)
   168  	return stringVal{"Call1r1 ret"}
   169  }
   170  
   171  func (a *SimpleMethods) Call1r1e(s stringVal) (stringVal, error) {
   172  	a.root.called(a, "Call1r1e", s)
   173  	return stringVal{"Call1r1e ret"}, a.root.callError(a, "Call1r1e", s)
   174  }
   175  
   176  func (a *SimpleMethods) Call1r0e(s stringVal) error {
   177  	a.root.called(a, "Call1r0e", s)
   178  	return a.root.callError(a, "Call1r0e", s)
   179  }
   180  
   181  func (a *SimpleMethods) SliceArg(struct{ X []string }) stringVal {
   182  	return stringVal{"SliceArg ret"}
   183  }
   184  
   185  func (a *SimpleMethods) Discard1(int) {}
   186  
   187  func (a *SimpleMethods) Discard2(struct{}, struct{}) {}
   188  
   189  func (a *SimpleMethods) Discard3() int { return 0 }
   190  
   191  func (a *SimpleMethods) Discard4() (_, _ struct{}) { return }
   192  
   193  type DelayedMethods struct {
   194  	ready     chan struct{}
   195  	done      chan string
   196  	doneError chan error
   197  }
   198  
   199  func (a *DelayedMethods) Delay() (stringVal, error) {
   200  	if a.ready != nil {
   201  		a.ready <- struct{}{}
   202  	}
   203  	select {
   204  	case s := <-a.done:
   205  		return stringVal{s}, nil
   206  	case err := <-a.doneError:
   207  		return stringVal{}, err
   208  	}
   209  }
   210  
   211  type ErrorMethods struct {
   212  	err error
   213  }
   214  
   215  func (e *ErrorMethods) Call() error {
   216  	return e.err
   217  }
   218  
   219  type CallbackMethods struct {
   220  	root *Root
   221  }
   222  
   223  type int64val struct {
   224  	I int64
   225  }
   226  
   227  func (a *CallbackMethods) Factorial(x int64val) (int64val, error) {
   228  	if x.I <= 1 {
   229  		return int64val{1}, nil
   230  	}
   231  	var r int64val
   232  	err := a.root.conn.Call(rpc.Request{"CallbackMethods", 0, "", "Factorial"}, int64val{x.I - 1}, &r)
   233  	if err != nil {
   234  		return int64val{}, err
   235  	}
   236  	return int64val{x.I * r.I}, nil
   237  }
   238  
   239  func (a *ChangeAPIMethods) ChangeAPI() {
   240  	a.r.conn.Serve(&changedAPIRoot{}, nil)
   241  }
   242  
   243  func (a *ChangeAPIMethods) RemoveAPI() {
   244  	a.r.conn.Serve(nil, nil)
   245  }
   246  
   247  type changedAPIRoot struct{}
   248  
   249  func (r *changedAPIRoot) NewlyAvailable(string) (newlyAvailableMethods, error) {
   250  	return newlyAvailableMethods{}, nil
   251  }
   252  
   253  type newlyAvailableMethods struct{}
   254  
   255  func (newlyAvailableMethods) NewMethod() stringVal {
   256  	return stringVal{"new method result"}
   257  }
   258  
   259  type VariableMethods1 struct {
   260  	sm *SimpleMethods
   261  }
   262  
   263  func (vm *VariableMethods1) Call0r1() stringVal {
   264  	return vm.sm.Call0r1()
   265  }
   266  
   267  type VariableMethods2 struct {
   268  	sm *SimpleMethods
   269  }
   270  
   271  func (vm *VariableMethods2) Call1r1(s stringVal) stringVal {
   272  	return vm.sm.Call1r1(s)
   273  }
   274  
   275  type RestrictedMethods struct {
   276  	InterfaceMethods
   277  }
   278  
   279  type CustomRoot struct {
   280  	root *Root
   281  }
   282  
   283  type wrapper func(*SimpleMethods) reflect.Value
   284  
   285  type customMethodCaller struct {
   286  	wrap         wrapper
   287  	root         *Root
   288  	objMethod    rpcreflect.ObjMethod
   289  	expectedType reflect.Type
   290  }
   291  
   292  func (c customMethodCaller) ParamsType() reflect.Type {
   293  	return c.objMethod.Params
   294  }
   295  
   296  func (c customMethodCaller) ResultType() reflect.Type {
   297  	return c.objMethod.Result
   298  }
   299  
   300  func (c customMethodCaller) Call(objId string, arg reflect.Value) (reflect.Value, error) {
   301  	sm, err := c.root.SimpleMethods(objId)
   302  	if err != nil {
   303  		return reflect.Value{}, err
   304  	}
   305  	obj := c.wrap(sm)
   306  	if reflect.TypeOf(obj) != c.expectedType {
   307  		logger.Errorf("got the wrong type back, expected %s got %T", c.expectedType, obj)
   308  	}
   309  	logger.Debugf("calling: %T %v %#v", obj, obj, c.objMethod)
   310  	return c.objMethod.Call(obj, arg)
   311  }
   312  
   313  func (cc *CustomRoot) Kill() {
   314  }
   315  
   316  func (cc *CustomRoot) FindMethod(
   317  	rootMethodName string, version int, objMethodName string,
   318  ) (
   319  	rpcreflect.MethodCaller, error,
   320  ) {
   321  	logger.Debugf("got to FindMethod: %q %d %q", rootMethodName, version, objMethodName)
   322  	if rootMethodName != "MultiVersion" {
   323  		return nil, &rpcreflect.CallNotImplementedError{
   324  			RootMethod: rootMethodName,
   325  		}
   326  	}
   327  	var goType reflect.Type
   328  	var wrap wrapper
   329  	switch version {
   330  	case 0:
   331  		goType = reflect.TypeOf((*VariableMethods1)(nil))
   332  		wrap = func(sm *SimpleMethods) reflect.Value {
   333  			return reflect.ValueOf(&VariableMethods1{sm})
   334  		}
   335  	case 1:
   336  		goType = reflect.TypeOf((*VariableMethods2)(nil))
   337  		wrap = func(sm *SimpleMethods) reflect.Value {
   338  			return reflect.ValueOf(&VariableMethods2{sm})
   339  		}
   340  	case 2:
   341  		goType = reflect.TypeOf((*RestrictedMethods)(nil))
   342  		wrap = func(sm *SimpleMethods) reflect.Value {
   343  			methods := &RestrictedMethods{InterfaceMethods: sm}
   344  			return reflect.ValueOf(methods)
   345  		}
   346  	default:
   347  		return nil, &rpcreflect.CallNotImplementedError{
   348  			RootMethod: rootMethodName,
   349  			Version:    version,
   350  		}
   351  	}
   352  	logger.Debugf("found type: %s", goType)
   353  	objType := rpcreflect.ObjTypeOf(goType)
   354  	objMethod, err := objType.Method(objMethodName)
   355  	if err != nil {
   356  		return nil, &rpcreflect.CallNotImplementedError{
   357  			RootMethod: rootMethodName,
   358  			Version:    version,
   359  			Method:     objMethodName,
   360  		}
   361  	}
   362  	return customMethodCaller{
   363  		objMethod:    objMethod,
   364  		root:         cc.root,
   365  		wrap:         wrap,
   366  		expectedType: goType,
   367  	}, nil
   368  }
   369  
   370  func SimpleRoot() *Root {
   371  	root := &Root{
   372  		simple: make(map[string]*SimpleMethods),
   373  	}
   374  	root.simple["a99"] = &SimpleMethods{root: root, id: "a99"}
   375  	return root
   376  }
   377  
   378  func (*rpcSuite) TestRPC(c *gc.C) {
   379  	root := SimpleRoot()
   380  	client, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   381  	defer closeClient(c, client, srvDone)
   382  	for narg := 0; narg < 2; narg++ {
   383  		for nret := 0; nret < 2; nret++ {
   384  			for nerr := 0; nerr < 2; nerr++ {
   385  				retErr := nerr != 0
   386  				p := testCallParams{
   387  					client:         client,
   388  					serverNotifier: serverNotifier,
   389  					entry:          "SimpleMethods",
   390  					narg:           narg,
   391  					nret:           nret,
   392  					retErr:         retErr,
   393  					testErr:        false,
   394  				}
   395  				root.testCall(c, p)
   396  				if retErr {
   397  					p.testErr = true
   398  					root.testCall(c, p)
   399  				}
   400  			}
   401  		}
   402  	}
   403  }
   404  
   405  func callName(narg, nret int, retErr bool) string {
   406  	e := ""
   407  	if retErr {
   408  		e = "e"
   409  	}
   410  	return fmt.Sprintf("Call%dr%d%s", narg, nret, e)
   411  }
   412  
   413  type testCallParams struct {
   414  	// client holds the client-side of the rpc connection that
   415  	// will be used to make the call.
   416  	client *rpc.Conn
   417  
   418  	// serverNotifier holds the notifier for the server side.
   419  	serverNotifier *notifier
   420  
   421  	// entry holds the top-level type that will be invoked
   422  	// (e.g. "SimpleMethods").
   423  	entry string
   424  
   425  	// narg holds the number of arguments accepted by the
   426  	// call (0 or 1).
   427  	narg int
   428  
   429  	// nret holds the number of values returned by the
   430  	// call (0 or 1).
   431  	nret int
   432  
   433  	// retErr specifies whether the call returns an error.
   434  	retErr bool
   435  
   436  	// testErr specifies whether the call should be made to return an error.
   437  	testErr bool
   438  
   439  	// version specifies what version of the interface to call, defaults to 0.
   440  	version int
   441  }
   442  
   443  // request returns the RPC request for the test call.
   444  func (p testCallParams) request() rpc.Request {
   445  	return rpc.Request{
   446  		Type:    p.entry,
   447  		Version: p.version,
   448  		Id:      "a99",
   449  		Action:  callName(p.narg, p.nret, p.retErr),
   450  	}
   451  }
   452  
   453  // error message returns the error message that the test call
   454  // should return if it returns an error.
   455  func (p testCallParams) errorMessage() string {
   456  	return fmt.Sprintf("error calling %s", p.request().Action)
   457  }
   458  
   459  func (root *Root) testCall(c *gc.C, args testCallParams) {
   460  	args.serverNotifier.reset()
   461  	root.calls = nil
   462  	root.returnErr = args.testErr
   463  	c.Logf("test call %s", args.request().Action)
   464  	var response stringVal
   465  	err := args.client.Call(args.request(), stringVal{"arg"}, &response)
   466  	switch {
   467  	case args.retErr && args.testErr:
   468  		c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   469  			Message: args.errorMessage(),
   470  		})
   471  		c.Assert(response, gc.Equals, stringVal{})
   472  	case args.nret > 0:
   473  		c.Check(response, gc.Equals, stringVal{args.request().Action + " ret"})
   474  	}
   475  	if !args.testErr {
   476  		c.Check(err, jc.ErrorIsNil)
   477  	}
   478  
   479  	// Check that the call was actually made, the right
   480  	// parameters were received and the right result returned.
   481  	root.mu.Lock()
   482  	defer root.mu.Unlock()
   483  
   484  	root.assertCallMade(c, args)
   485  	root.assertServerNotified(c, args, args.client.ClientRequestID())
   486  }
   487  
   488  func (root *Root) assertCallMade(c *gc.C, p testCallParams) {
   489  	expectCall := callInfo{
   490  		rcvr:   root.simple["a99"],
   491  		method: p.request().Action,
   492  	}
   493  	if p.narg > 0 {
   494  		expectCall.arg = stringVal{"arg"}
   495  	}
   496  	c.Assert(root.calls, gc.HasLen, 1)
   497  	c.Assert(*root.calls[0], gc.Equals, expectCall)
   498  }
   499  
   500  // assertServerNotified asserts that the right server notifications
   501  // were made for the given test call parameters. The id of the request
   502  // is held in requestId.
   503  func (root *Root) assertServerNotified(c *gc.C, p testCallParams, requestId uint64) {
   504  	// Test that there was a notification for the request.
   505  	c.Assert(p.serverNotifier.serverRequests, gc.HasLen, 1)
   506  	serverReq := p.serverNotifier.serverRequests[0]
   507  	c.Assert(serverReq.hdr, gc.DeepEquals, rpc.Header{
   508  		RequestId: requestId,
   509  		Request:   p.request(),
   510  		Version:   1,
   511  	})
   512  	if p.narg > 0 {
   513  		c.Assert(serverReq.body, gc.Equals, stringVal{"arg"})
   514  	} else {
   515  		c.Assert(serverReq.body, gc.Equals, struct{}{})
   516  	}
   517  
   518  	// Test that there was a notification for the reply.
   519  	c.Assert(p.serverNotifier.serverReplies, gc.HasLen, 1)
   520  	serverReply := p.serverNotifier.serverReplies[0]
   521  	c.Assert(serverReply.req, gc.Equals, p.request())
   522  	if p.retErr && p.testErr || p.nret == 0 {
   523  		c.Assert(serverReply.body, gc.Equals, struct{}{})
   524  	} else {
   525  		c.Assert(serverReply.body, gc.Equals, stringVal{p.request().Action + " ret"})
   526  	}
   527  	if p.retErr && p.testErr {
   528  		c.Assert(serverReply.hdr, gc.Equals, rpc.Header{
   529  			RequestId: requestId,
   530  			Error:     p.errorMessage(),
   531  			Version:   1,
   532  		})
   533  	} else {
   534  		c.Assert(serverReply.hdr, gc.Equals, rpc.Header{
   535  			RequestId: requestId,
   536  			Version:   1,
   537  		})
   538  	}
   539  }
   540  
   541  func (*rpcSuite) TestInterfaceMethods(c *gc.C) {
   542  	root := SimpleRoot()
   543  	client, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   544  	defer closeClient(c, client, srvDone)
   545  	p := testCallParams{
   546  		client:         client,
   547  		serverNotifier: serverNotifier,
   548  		entry:          "InterfaceMethods",
   549  		narg:           1,
   550  		nret:           1,
   551  		retErr:         true,
   552  		testErr:        false,
   553  	}
   554  
   555  	root.testCall(c, p)
   556  	p.testErr = true
   557  	root.testCall(c, p)
   558  	// Call0r0 is defined on the underlying SimpleMethods, but is not
   559  	// exposed at the InterfaceMethods level, so this call should fail with
   560  	// CodeNotImplemented.
   561  	var r stringVal
   562  	err := client.Call(rpc.Request{"InterfaceMethods", 0, "a99", "Call0r0"}, stringVal{"arg"}, &r)
   563  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   564  		Message: "no such request - method InterfaceMethods.Call0r0 is not implemented",
   565  		Code:    rpc.CodeNotImplemented,
   566  	})
   567  }
   568  
   569  func (*rpcSuite) TestCustomRootV0(c *gc.C) {
   570  	root := &CustomRoot{SimpleRoot()}
   571  	client, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   572  	defer closeClient(c, client, srvDone)
   573  	// V0 of MultiVersion implements only VariableMethods1.Call0r1.
   574  	p := testCallParams{
   575  		client:         client,
   576  		serverNotifier: serverNotifier,
   577  		entry:          "MultiVersion",
   578  		version:        0,
   579  		narg:           0,
   580  		nret:           1,
   581  		retErr:         false,
   582  		testErr:        false,
   583  	}
   584  
   585  	root.root.testCall(c, p)
   586  	// Call1r1 is exposed in version 1, but not in version 0.
   587  	var r stringVal
   588  	err := client.Call(rpc.Request{"MultiVersion", 0, "a99", "Call1r1"}, stringVal{"arg"}, &r)
   589  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   590  		Message: "no such request - method MultiVersion.Call1r1 is not implemented",
   591  		Code:    rpc.CodeNotImplemented,
   592  	})
   593  }
   594  
   595  func (*rpcSuite) TestCustomRootV1(c *gc.C) {
   596  	root := &CustomRoot{SimpleRoot()}
   597  	client, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   598  	defer closeClient(c, client, srvDone)
   599  	// V1 of MultiVersion implements only VariableMethods2.Call1r1.
   600  	p := testCallParams{
   601  		client:         client,
   602  		serverNotifier: serverNotifier,
   603  		entry:          "MultiVersion",
   604  		version:        1,
   605  		narg:           1,
   606  		nret:           1,
   607  		retErr:         false,
   608  		testErr:        false,
   609  	}
   610  
   611  	root.root.testCall(c, p)
   612  	// Call0r1 is exposed in version 0, but not in version 1.
   613  	var r stringVal
   614  	err := client.Call(rpc.Request{"MultiVersion", 1, "a99", "Call0r1"}, nil, &r)
   615  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   616  		Message: "no such request - method MultiVersion(1).Call0r1 is not implemented",
   617  		Code:    rpc.CodeNotImplemented,
   618  	})
   619  }
   620  
   621  func (*rpcSuite) TestCustomRootV2(c *gc.C) {
   622  	root := &CustomRoot{SimpleRoot()}
   623  	client, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   624  	defer closeClient(c, client, srvDone)
   625  	p := testCallParams{
   626  		client:         client,
   627  		serverNotifier: serverNotifier,
   628  		entry:          "MultiVersion",
   629  		version:        2,
   630  		narg:           1,
   631  		nret:           1,
   632  		retErr:         true,
   633  		testErr:        false,
   634  	}
   635  
   636  	root.root.testCall(c, p)
   637  	// By embedding the InterfaceMethods inside a concrete
   638  	// RestrictedMethods type, we actually only expose the methods defined
   639  	// in InterfaceMethods.
   640  	var r stringVal
   641  	err := client.Call(rpc.Request{"MultiVersion", 2, "a99", "Call0r1e"}, nil, &r)
   642  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   643  		Message: `no such request - method MultiVersion(2).Call0r1e is not implemented`,
   644  		Code:    rpc.CodeNotImplemented,
   645  	})
   646  }
   647  
   648  func (*rpcSuite) TestCustomRootUnknownVersion(c *gc.C) {
   649  	root := &CustomRoot{SimpleRoot()}
   650  	client, srvDone, _ := newRPCClientServer(c, root, nil, false)
   651  	defer closeClient(c, client, srvDone)
   652  	var r stringVal
   653  	// Unknown version 5
   654  	err := client.Call(rpc.Request{"MultiVersion", 5, "a99", "Call0r1"}, nil, &r)
   655  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   656  		Message: `unknown version (5) of interface "MultiVersion"`,
   657  		Code:    rpc.CodeNotImplemented,
   658  	})
   659  }
   660  
   661  func (*rpcSuite) TestConcurrentCalls(c *gc.C) {
   662  	start1 := make(chan string)
   663  	start2 := make(chan string)
   664  	ready1 := make(chan struct{})
   665  	ready2 := make(chan struct{})
   666  
   667  	root := &Root{
   668  		delayed: map[string]*DelayedMethods{
   669  			"1": {ready: ready1, done: start1},
   670  			"2": {ready: ready2, done: start2},
   671  		},
   672  	}
   673  
   674  	client, srvDone, _ := newRPCClientServer(c, root, nil, false)
   675  	defer closeClient(c, client, srvDone)
   676  	call := func(id string, done chan<- struct{}) {
   677  		var r stringVal
   678  		err := client.Call(rpc.Request{"DelayedMethods", 0, id, "Delay"}, nil, &r)
   679  		c.Check(err, jc.ErrorIsNil)
   680  		c.Check(r.Val, gc.Equals, "return "+id)
   681  		done <- struct{}{}
   682  	}
   683  	done1 := make(chan struct{})
   684  	done2 := make(chan struct{})
   685  	go call("1", done1)
   686  	go call("2", done2)
   687  
   688  	// Check that both calls are running concurrently.
   689  	chanRead(c, ready1, "method 1 ready")
   690  	chanRead(c, ready2, "method 2 ready")
   691  
   692  	// Let the requests complete.
   693  	start1 <- "return 1"
   694  	start2 <- "return 2"
   695  	chanRead(c, done1, "method 1 done")
   696  	chanRead(c, done2, "method 2 done")
   697  }
   698  
   699  type codedError struct {
   700  	m    string
   701  	code string
   702  }
   703  
   704  func (e *codedError) Error() string {
   705  	return e.m
   706  }
   707  
   708  func (e *codedError) ErrorCode() string {
   709  	return e.code
   710  }
   711  
   712  func (*rpcSuite) TestErrorCode(c *gc.C) {
   713  	root := &Root{
   714  		errorInst: &ErrorMethods{&codedError{"message", "code"}},
   715  	}
   716  	client, srvDone, _ := newRPCClientServer(c, root, nil, false)
   717  	defer closeClient(c, client, srvDone)
   718  	err := client.Call(rpc.Request{"ErrorMethods", 0, "", "Call"}, nil, nil)
   719  	c.Assert(err, gc.ErrorMatches, `message \(code\)`)
   720  	c.Assert(errors.Cause(err).(rpc.ErrorCoder).ErrorCode(), gc.Equals, "code")
   721  }
   722  
   723  func (*rpcSuite) TestTransformErrors(c *gc.C) {
   724  	root := &Root{
   725  		errorInst: &ErrorMethods{&codedError{"message", "code"}},
   726  	}
   727  	tfErr := func(err error) error {
   728  		c.Check(err, gc.NotNil)
   729  		if e, ok := err.(*codedError); ok {
   730  			return &codedError{
   731  				m:    "transformed: " + e.m,
   732  				code: "transformed: " + e.code,
   733  			}
   734  		}
   735  		return fmt.Errorf("transformed: %v", err)
   736  	}
   737  	client, srvDone, _ := newRPCClientServer(c, root, tfErr, false)
   738  	defer closeClient(c, client, srvDone)
   739  	// First, we don't transform methods we can't find.
   740  	err := client.Call(rpc.Request{"foo", 0, "", "bar"}, nil, nil)
   741  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   742  		Message: `unknown object type "foo"`,
   743  		Code:    rpc.CodeNotImplemented,
   744  	})
   745  
   746  	err = client.Call(rpc.Request{"ErrorMethods", 0, "", "NoMethod"}, nil, nil)
   747  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   748  		Message: "no such request - method ErrorMethods.NoMethod is not implemented",
   749  		Code:    rpc.CodeNotImplemented,
   750  	})
   751  
   752  	// We do transform any errors that happen from calling the RootMethod
   753  	// and beyond.
   754  	err = client.Call(rpc.Request{"ErrorMethods", 0, "", "Call"}, nil, nil)
   755  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   756  		Message: "transformed: message",
   757  		Code:    "transformed: code",
   758  	})
   759  
   760  	root.errorInst.err = nil
   761  	err = client.Call(rpc.Request{"ErrorMethods", 0, "", "Call"}, nil, nil)
   762  	c.Assert(err, jc.ErrorIsNil)
   763  
   764  	root.errorInst = nil
   765  	err = client.Call(rpc.Request{"ErrorMethods", 0, "", "Call"}, nil, nil)
   766  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   767  		Message: "transformed: no error methods",
   768  	})
   769  
   770  }
   771  
   772  func (*rpcSuite) TestServerWaitsForOutstandingCalls(c *gc.C) {
   773  	ready := make(chan struct{})
   774  	start := make(chan string)
   775  	root := &Root{
   776  		delayed: map[string]*DelayedMethods{
   777  			"1": {
   778  				ready: ready,
   779  				done:  start,
   780  			},
   781  		},
   782  	}
   783  	client, srvDone, _ := newRPCClientServer(c, root, nil, false)
   784  	defer closeClient(c, client, srvDone)
   785  	done := make(chan struct{})
   786  	go func() {
   787  		var r stringVal
   788  		err := client.Call(rpc.Request{"DelayedMethods", 0, "1", "Delay"}, nil, &r)
   789  		c.Check(errors.Cause(err), gc.Equals, rpc.ErrShutdown)
   790  		done <- struct{}{}
   791  	}()
   792  	chanRead(c, ready, "DelayedMethods.Delay ready")
   793  	client.Close()
   794  	select {
   795  	case err := <-srvDone:
   796  		c.Fatalf("server returned while outstanding operation in progress: %v", err)
   797  		<-done
   798  	case <-time.After(25 * time.Millisecond):
   799  	}
   800  	start <- "xxx"
   801  }
   802  
   803  func chanRead(c *gc.C, ch <-chan struct{}, what string) {
   804  	select {
   805  	case <-ch:
   806  		return
   807  	case <-time.After(3 * time.Second):
   808  		c.Fatalf("timeout on channel read %s", what)
   809  	}
   810  }
   811  
   812  func (*rpcSuite) TestCompatibility(c *gc.C) {
   813  	root := &Root{
   814  		simple: make(map[string]*SimpleMethods),
   815  	}
   816  	a0 := &SimpleMethods{root: root, id: "a0"}
   817  	root.simple["a0"] = a0
   818  
   819  	client, srvDone, _ := newRPCClientServer(c, root, nil, false)
   820  	defer closeClient(c, client, srvDone)
   821  	call := func(method string, arg, ret interface{}) (passedArg interface{}) {
   822  		root.calls = nil
   823  		err := client.Call(rpc.Request{"SimpleMethods", 0, "a0", method}, arg, ret)
   824  		c.Assert(err, jc.ErrorIsNil)
   825  		c.Assert(root.calls, gc.HasLen, 1)
   826  		info := root.calls[0]
   827  		c.Assert(info.rcvr, gc.Equals, a0)
   828  		c.Assert(info.method, gc.Equals, method)
   829  		return info.arg
   830  	}
   831  	type extra struct {
   832  		Val   string
   833  		Extra string
   834  	}
   835  	// Extra fields in request and response.
   836  	var r extra
   837  	arg := call("Call1r1", extra{"x", "y"}, &r)
   838  	c.Assert(arg, gc.Equals, stringVal{"x"})
   839  
   840  	// Nil argument as request.
   841  	r = extra{}
   842  	arg = call("Call1r1", nil, &r)
   843  	c.Assert(arg, gc.Equals, stringVal{})
   844  
   845  	// Nil argument as response.
   846  	arg = call("Call1r1", stringVal{"x"}, nil)
   847  	c.Assert(arg, gc.Equals, stringVal{"x"})
   848  
   849  	// Non-nil argument for no response.
   850  	r = extra{}
   851  	arg = call("Call1r0", stringVal{"x"}, &r)
   852  	c.Assert(arg, gc.Equals, stringVal{"x"})
   853  	c.Assert(r, gc.Equals, extra{})
   854  }
   855  
   856  func (*rpcSuite) TestBadCall(c *gc.C) {
   857  	loggo.GetLogger("juju.rpc").SetLogLevel(loggo.TRACE)
   858  	root := &Root{
   859  		simple: make(map[string]*SimpleMethods),
   860  	}
   861  	a0 := &SimpleMethods{root: root, id: "a0"}
   862  	root.simple["a0"] = a0
   863  	client, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   864  	defer closeClient(c, client, srvDone)
   865  
   866  	testBadCall(c, client, serverNotifier,
   867  		rpc.Request{"BadSomething", 0, "a0", "No"},
   868  		`unknown object type "BadSomething"`,
   869  		rpc.CodeNotImplemented,
   870  		false,
   871  	)
   872  	testBadCall(c, client, serverNotifier,
   873  		rpc.Request{"SimpleMethods", 0, "xx", "No"},
   874  		"no such request - method SimpleMethods.No is not implemented",
   875  		rpc.CodeNotImplemented,
   876  		false,
   877  	)
   878  	testBadCall(c, client, serverNotifier,
   879  		rpc.Request{"SimpleMethods", 0, "xx", "Call0r0"},
   880  		`unknown SimpleMethods id`,
   881  		"",
   882  		true,
   883  	)
   884  }
   885  
   886  func testBadCall(
   887  	c *gc.C,
   888  	client *rpc.Conn,
   889  	serverNotifier *notifier,
   890  	req rpc.Request,
   891  	expectedErr string,
   892  	expectedErrCode string,
   893  	requestKnown bool,
   894  ) {
   895  	serverNotifier.reset()
   896  	err := client.Call(req, nil, nil)
   897  	msg := expectedErr
   898  	if expectedErrCode != "" {
   899  		msg += " (" + expectedErrCode + ")"
   900  	}
   901  	c.Assert(err, gc.ErrorMatches, regexp.QuoteMeta(msg))
   902  
   903  	// From docs on ServerRequest:
   904  	// 	If the request was not recognized or there was
   905  	//	an error reading the body, body will be nil.
   906  	var expectBody interface{}
   907  	if requestKnown {
   908  		expectBody = struct{}{}
   909  	}
   910  	c.Assert(serverNotifier.serverRequests[0], gc.DeepEquals, requestEvent{
   911  		hdr: rpc.Header{
   912  			RequestId: client.ClientRequestID(),
   913  			Request:   req,
   914  			Version:   1,
   915  		},
   916  		body: expectBody,
   917  	})
   918  
   919  	// Test that there was a notification for the server reply.
   920  	c.Assert(serverNotifier.serverReplies, gc.HasLen, 1)
   921  	serverReply := serverNotifier.serverReplies[0]
   922  	c.Assert(serverReply, gc.DeepEquals, replyEvent{
   923  		hdr: rpc.Header{
   924  			RequestId: client.ClientRequestID(),
   925  			Error:     expectedErr,
   926  			ErrorCode: expectedErrCode,
   927  			Version:   1,
   928  		},
   929  		req:  req,
   930  		body: struct{}{},
   931  	})
   932  }
   933  
   934  func (*rpcSuite) TestContinueAfterReadBodyError(c *gc.C) {
   935  	root := &Root{
   936  		simple: make(map[string]*SimpleMethods),
   937  	}
   938  	a0 := &SimpleMethods{root: root, id: "a0"}
   939  	root.simple["a0"] = a0
   940  	client, srvDone, _ := newRPCClientServer(c, root, nil, false)
   941  	defer closeClient(c, client, srvDone)
   942  
   943  	var ret stringVal
   944  	arg0 := struct {
   945  		X map[string]int
   946  	}{
   947  		X: map[string]int{"hello": 65},
   948  	}
   949  	err := client.Call(rpc.Request{"SimpleMethods", 0, "a0", "SliceArg"}, arg0, &ret)
   950  	c.Assert(err, gc.ErrorMatches, `json: cannot unmarshal object into Go value of type \[\]string`)
   951  
   952  	err = client.Call(rpc.Request{"SimpleMethods", 0, "a0", "SliceArg"}, arg0, &ret)
   953  	c.Assert(err, gc.ErrorMatches, `json: cannot unmarshal object into Go value of type \[\]string`)
   954  
   955  	arg1 := struct {
   956  		X []string
   957  	}{
   958  		X: []string{"one"},
   959  	}
   960  	err = client.Call(rpc.Request{"SimpleMethods", 0, "a0", "SliceArg"}, arg1, &ret)
   961  	c.Assert(err, jc.ErrorIsNil)
   962  	c.Assert(ret.Val, gc.Equals, "SliceArg ret")
   963  }
   964  
   965  func (*rpcSuite) TestErrorAfterClientClose(c *gc.C) {
   966  	client, srvDone, _ := newRPCClientServer(c, &Root{}, nil, false)
   967  	err := client.Close()
   968  	c.Assert(err, jc.ErrorIsNil)
   969  	err = client.Call(rpc.Request{"Foo", 0, "", "Bar"}, nil, nil)
   970  	c.Assert(errors.Cause(err), gc.Equals, rpc.ErrShutdown)
   971  	err = chanReadError(c, srvDone, "server done")
   972  	c.Assert(err, jc.ErrorIsNil)
   973  }
   974  
   975  func (*rpcSuite) TestClientCloseIdempotent(c *gc.C) {
   976  	client, _, _ := newRPCClientServer(c, &Root{}, nil, false)
   977  	err := client.Close()
   978  	c.Assert(err, jc.ErrorIsNil)
   979  	err = client.Close()
   980  	c.Assert(err, jc.ErrorIsNil)
   981  	err = client.Close()
   982  	c.Assert(err, jc.ErrorIsNil)
   983  }
   984  
   985  func (*rpcSuite) TestBidirectional(c *gc.C) {
   986  	srvRoot := &Root{}
   987  	client, srvDone, _ := newRPCClientServer(c, srvRoot, nil, true)
   988  	defer closeClient(c, client, srvDone)
   989  	clientRoot := &Root{conn: client}
   990  	client.Serve(clientRoot, nil)
   991  	var r int64val
   992  	err := client.Call(rpc.Request{"CallbackMethods", 0, "", "Factorial"}, int64val{12}, &r)
   993  	c.Assert(err, jc.ErrorIsNil)
   994  	c.Assert(r.I, gc.Equals, int64(479001600))
   995  }
   996  
   997  func (*rpcSuite) TestServerRequestWhenNotServing(c *gc.C) {
   998  	srvRoot := &Root{}
   999  	client, srvDone, _ := newRPCClientServer(c, srvRoot, nil, true)
  1000  	defer closeClient(c, client, srvDone)
  1001  	var r int64val
  1002  	err := client.Call(rpc.Request{"CallbackMethods", 0, "", "Factorial"}, int64val{12}, &r)
  1003  	c.Assert(err, gc.ErrorMatches, "no service")
  1004  }
  1005  
  1006  func (*rpcSuite) TestChangeAPI(c *gc.C) {
  1007  	srvRoot := &Root{}
  1008  	client, srvDone, _ := newRPCClientServer(c, srvRoot, nil, true)
  1009  	defer closeClient(c, client, srvDone)
  1010  	var s stringVal
  1011  	err := client.Call(rpc.Request{"NewlyAvailable", 0, "", "NewMethod"}, nil, &s)
  1012  	c.Assert(err, gc.ErrorMatches, `unknown object type "NewlyAvailable" \(not implemented\)`)
  1013  	err = client.Call(rpc.Request{"ChangeAPIMethods", 0, "", "ChangeAPI"}, nil, nil)
  1014  	c.Assert(err, jc.ErrorIsNil)
  1015  	err = client.Call(rpc.Request{"ChangeAPIMethods", 0, "", "ChangeAPI"}, nil, nil)
  1016  	c.Assert(err, gc.ErrorMatches, `unknown object type "ChangeAPIMethods" \(not implemented\)`)
  1017  	err = client.Call(rpc.Request{"NewlyAvailable", 0, "", "NewMethod"}, nil, &s)
  1018  	c.Assert(err, jc.ErrorIsNil)
  1019  	c.Assert(s, gc.Equals, stringVal{"new method result"})
  1020  }
  1021  
  1022  func (*rpcSuite) TestChangeAPIToNil(c *gc.C) {
  1023  	srvRoot := &Root{}
  1024  	client, srvDone, _ := newRPCClientServer(c, srvRoot, nil, true)
  1025  	defer closeClient(c, client, srvDone)
  1026  
  1027  	err := client.Call(rpc.Request{"ChangeAPIMethods", 0, "", "RemoveAPI"}, nil, nil)
  1028  	c.Assert(err, jc.ErrorIsNil)
  1029  
  1030  	err = client.Call(rpc.Request{"ChangeAPIMethods", 0, "", "RemoveAPI"}, nil, nil)
  1031  	c.Assert(err, gc.ErrorMatches, "no service")
  1032  }
  1033  
  1034  func (*rpcSuite) TestChangeAPIWhileServingRequest(c *gc.C) {
  1035  	ready := make(chan struct{})
  1036  	done := make(chan error)
  1037  	srvRoot := &Root{
  1038  		delayed: map[string]*DelayedMethods{
  1039  			"1": {ready: ready, doneError: done},
  1040  		},
  1041  	}
  1042  	transform := func(err error) error {
  1043  		return fmt.Errorf("transformed: %v", err)
  1044  	}
  1045  	client, srvDone, _ := newRPCClientServer(c, srvRoot, transform, true)
  1046  	defer closeClient(c, client, srvDone)
  1047  
  1048  	result := make(chan error)
  1049  	go func() {
  1050  		result <- client.Call(rpc.Request{"DelayedMethods", 0, "1", "Delay"}, nil, nil)
  1051  	}()
  1052  	chanRead(c, ready, "method ready")
  1053  
  1054  	err := client.Call(rpc.Request{"ChangeAPIMethods", 0, "", "ChangeAPI"}, nil, nil)
  1055  	c.Assert(err, jc.ErrorIsNil)
  1056  
  1057  	// Ensure that not only does the request in progress complete,
  1058  	// but that the original transformErrors function is called.
  1059  	done <- fmt.Errorf("an error")
  1060  	select {
  1061  	case r := <-result:
  1062  		c.Assert(r, gc.ErrorMatches, "transformed: an error")
  1063  	case <-time.After(3 * time.Second):
  1064  		c.Fatalf("timeout on channel read")
  1065  	}
  1066  }
  1067  
  1068  func (*rpcSuite) TestCodeNotImplementedMatchesAPIserverParams(c *gc.C) {
  1069  	c.Assert(rpc.CodeNotImplemented, gc.Equals, params.CodeNotImplemented)
  1070  }
  1071  
  1072  func chanReadError(c *gc.C, ch <-chan error, what string) error {
  1073  	select {
  1074  	case e := <-ch:
  1075  		return e
  1076  	case <-time.After(3 * time.Second):
  1077  		c.Fatalf("timeout on channel read %s", what)
  1078  	}
  1079  	panic("unreachable")
  1080  }
  1081  
  1082  // newRPCClientServer starts an RPC server serving a connection from a
  1083  // single client.  When the server has finished serving the connection,
  1084  // it sends a value on the returned channel.
  1085  // If bidir is true, requests can flow in both directions.
  1086  func newRPCClientServer(
  1087  	c *gc.C,
  1088  	root interface{},
  1089  	tfErr func(error) error,
  1090  	bidir bool,
  1091  ) (client *rpc.Conn, srvDone chan error, serverNotifier *notifier) {
  1092  	l, err := net.Listen("tcp", "127.0.0.1:0")
  1093  	c.Assert(err, jc.ErrorIsNil)
  1094  
  1095  	srvDone = make(chan error, 1)
  1096  	serverNotifier = new(notifier)
  1097  	go func() {
  1098  		conn, err := l.Accept()
  1099  		if err != nil {
  1100  			srvDone <- nil
  1101  			return
  1102  		}
  1103  		defer l.Close()
  1104  		role := roleServer
  1105  		if bidir {
  1106  			role = roleBoth
  1107  		}
  1108  		rpcConn := rpc.NewConn(NewJSONCodec(conn, role), serverNotifier)
  1109  		if custroot, ok := root.(*CustomRoot); ok {
  1110  			rpcConn.ServeRoot(custroot, tfErr)
  1111  			custroot.root.conn = rpcConn
  1112  		} else {
  1113  			rpcConn.Serve(root, tfErr)
  1114  		}
  1115  		if root, ok := root.(*Root); ok {
  1116  			root.conn = rpcConn
  1117  		}
  1118  		rpcConn.Start()
  1119  		<-rpcConn.Dead()
  1120  		srvDone <- rpcConn.Close()
  1121  	}()
  1122  	conn, err := net.Dial("tcp", l.Addr().String())
  1123  	c.Assert(err, jc.ErrorIsNil)
  1124  	role := roleClient
  1125  	if bidir {
  1126  		role = roleBoth
  1127  	}
  1128  	client = rpc.NewConn(NewJSONCodec(conn, role), &notifier{})
  1129  	client.Start()
  1130  	return client, srvDone, serverNotifier
  1131  }
  1132  
  1133  func closeClient(c *gc.C, client *rpc.Conn, srvDone <-chan error) {
  1134  	err := client.Close()
  1135  	c.Assert(err, jc.ErrorIsNil)
  1136  	err = chanReadError(c, srvDone, "server done")
  1137  	c.Assert(err, jc.ErrorIsNil)
  1138  }
  1139  
  1140  type encoder interface {
  1141  	Encode(e interface{}) error
  1142  }
  1143  
  1144  type decoder interface {
  1145  	Decode(e interface{}) error
  1146  }
  1147  
// testCodec wraps an rpc.Codec with extra error checking code. Its
// WriteMessage, ReadHeader and ReadBody methods panic on a message whose
// direction (request or reply) does not fit the codec's role. They also
// log the traffic before or after delegating to the embedded codec.
type testCodec struct {
	// role selects which message directions this codec may send and
	// receive; roleBoth disables the direction checks.
	role connRole
	rpc.Codec
}
  1153  
  1154  func (c *testCodec) WriteMessage(hdr *rpc.Header, x interface{}) error {
  1155  	if reflect.ValueOf(x).Kind() != reflect.Struct {
  1156  		panic(fmt.Errorf("WriteRequest bad param; want struct got %T (%#v)", x, x))
  1157  	}
  1158  	if c.role != roleBoth && hdr.IsRequest() != (c.role == roleClient) {
  1159  		panic(fmt.Errorf("codec role %v; header wrong type %#v", c.role, hdr))
  1160  	}
  1161  	logger.Infof("send header: %#v; body: %#v", hdr, x)
  1162  	return c.Codec.WriteMessage(hdr, x)
  1163  }
  1164  
  1165  func (c *testCodec) ReadHeader(hdr *rpc.Header) error {
  1166  	err := c.Codec.ReadHeader(hdr)
  1167  	if err != nil {
  1168  		return err
  1169  	}
  1170  	logger.Infof("got header %#v", hdr)
  1171  	if c.role != roleBoth && hdr.IsRequest() == (c.role == roleClient) {
  1172  		panic(fmt.Errorf("codec role %v; read wrong type %#v", c.role, hdr))
  1173  	}
  1174  	return nil
  1175  }
  1176  
  1177  func (c *testCodec) ReadBody(r interface{}, isRequest bool) error {
  1178  	if v := reflect.ValueOf(r); v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
  1179  		panic(fmt.Errorf("ReadResponseBody bad destination; want *struct got %T", r))
  1180  	}
  1181  	if c.role != roleBoth && isRequest == (c.role == roleClient) {
  1182  		panic(fmt.Errorf("codec role %v; read wrong body type %#v", c.role, r))
  1183  	}
  1184  	// Note: this will need to change if we want to test a non-JSON codec.
  1185  	var m json.RawMessage
  1186  	err := c.Codec.ReadBody(&m, isRequest)
  1187  	if err != nil {
  1188  		return err
  1189  	}
  1190  	logger.Infof("got response body: %q", m)
  1191  	err = json.Unmarshal(m, r)
  1192  	logger.Infof("unmarshalled into %#v", r)
  1193  	return err
  1194  }
  1195  
  1196  type connRole string
  1197  
  1198  const (
  1199  	roleBoth   connRole = "both"
  1200  	roleClient connRole = "client"
  1201  	roleServer connRole = "server"
  1202  )
  1203  
  1204  func NewJSONCodec(c net.Conn, role connRole) rpc.Codec {
  1205  	return &testCodec{
  1206  		role:  role,
  1207  		Codec: jsoncodec.NewNet(c),
  1208  	}
  1209  }
  1210  
  1211  type requestEvent struct {
  1212  	hdr  rpc.Header
  1213  	body interface{}
  1214  }
  1215  
  1216  type replyEvent struct {
  1217  	req  rpc.Request
  1218  	hdr  rpc.Header
  1219  	body interface{}
  1220  }
  1221  
// notifier is the rpc observer used by these tests. It records every
// server request and reply notification it receives so that tests can
// inspect them afterwards.
type notifier struct {
	// mu guards serverRequests and serverReplies.
	mu             sync.Mutex
	serverRequests []requestEvent
	serverReplies  []replyEvent
}
  1227  
  1228  func (n *notifier) RPCObserver() rpc.Observer {
  1229  	// For testing, we usually won't want an actual copy of the
  1230  	// stub. To avoid confusing test failures (e.g. wondering why your
  1231  	// calls aren't showing up on your stub because the underlying
  1232  	// code has called DeepCopy) and immense complexity, just return
  1233  	// the same value.
  1234  	return n
  1235  }
  1236  
  1237  func (n *notifier) reset() {
  1238  	n.mu.Lock()
  1239  	defer n.mu.Unlock()
  1240  	n.serverRequests = nil
  1241  	n.serverReplies = nil
  1242  }
  1243  
  1244  func (n *notifier) ServerRequest(hdr *rpc.Header, body interface{}) {
  1245  	n.mu.Lock()
  1246  	defer n.mu.Unlock()
  1247  	n.serverRequests = append(n.serverRequests, requestEvent{
  1248  		hdr:  *hdr,
  1249  		body: body,
  1250  	})
  1251  }
  1252  
  1253  func (n *notifier) ServerReply(req rpc.Request, hdr *rpc.Header, body interface{}) {
  1254  	n.mu.Lock()
  1255  	defer n.mu.Unlock()
  1256  	n.serverReplies = append(n.serverReplies, replyEvent{
  1257  		req:  req,
  1258  		hdr:  *hdr,
  1259  		body: body,
  1260  	})
  1261  }