github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/rpc/rpc_test.go (about)

     1  // Copyright 2012, 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package rpc_test
     5  
     6  import (
     7  	"context"
     8  	"encoding/json"
     9  	"fmt"
    10  	"net"
    11  	"reflect"
    12  	"regexp"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/juju/errors"
    17  	"github.com/juju/loggo"
    18  	jc "github.com/juju/testing/checkers"
    19  	gc "gopkg.in/check.v1"
    20  
    21  	"github.com/juju/juju/apiserver/params"
    22  	"github.com/juju/juju/rpc"
    23  	"github.com/juju/juju/rpc/jsoncodec"
    24  	"github.com/juju/juju/rpc/rpcreflect"
    25  	"github.com/juju/juju/testing"
    26  )
    27  
    28  var logger = loggo.GetLogger("juju.rpc")
    29  
    30  type rpcSuite struct {
    31  	testing.BaseSuite
    32  }
    33  
    34  var _ = gc.Suite(&rpcSuite{})
    35  
    36  type callInfo struct {
    37  	rcvr   interface{}
    38  	method string
    39  	arg    interface{}
    40  }
    41  
    42  type callError callInfo
    43  
    44  func (e *callError) Error() string {
    45  	return fmt.Sprintf("error calling %s", e.method)
    46  }
    47  
    48  type stringVal struct {
    49  	Val string
    50  }
    51  
    52  type Root struct {
    53  	mu          sync.Mutex
    54  	conn        *rpc.Conn
    55  	calls       []*callInfo
    56  	returnErr   bool
    57  	simple      map[string]*SimpleMethods
    58  	delayed     map[string]*DelayedMethods
    59  	errorInst   *ErrorMethods
    60  	contextInst *ContextMethods
    61  }
    62  
    63  func (r *Root) callError(rcvr interface{}, name string, arg interface{}) error {
    64  	if r.returnErr {
    65  		return &callError{rcvr, name, arg}
    66  	}
    67  	return nil
    68  }
    69  
    70  func (r *Root) SimpleMethods(id string) (*SimpleMethods, error) {
    71  	r.mu.Lock()
    72  	defer r.mu.Unlock()
    73  	if a := r.simple[id]; a != nil {
    74  		return a, nil
    75  	}
    76  	return nil, fmt.Errorf("unknown SimpleMethods id")
    77  }
    78  
    79  func (r *Root) DelayedMethods(id string) (*DelayedMethods, error) {
    80  	r.mu.Lock()
    81  	defer r.mu.Unlock()
    82  	if a := r.delayed[id]; a != nil {
    83  		return a, nil
    84  	}
    85  	return nil, fmt.Errorf("unknown DelayedMethods id")
    86  }
    87  
    88  func (r *Root) ErrorMethods(id string) (*ErrorMethods, error) {
    89  	if r.errorInst == nil {
    90  		return nil, fmt.Errorf("no error methods")
    91  	}
    92  	return r.errorInst, nil
    93  }
    94  
    95  func (r *Root) ContextMethods(id string) (*ContextMethods, error) {
    96  	if r.contextInst == nil {
    97  		return nil, fmt.Errorf("no context methods")
    98  	}
    99  	return r.contextInst, nil
   100  }
   101  
   102  func (r *Root) Discard1() {}
   103  
   104  func (r *Root) Discard2(id string) error { return nil }
   105  
   106  func (r *Root) Discard3(id string) int { return 0 }
   107  
   108  func (r *Root) CallbackMethods(string) (*CallbackMethods, error) {
   109  	return &CallbackMethods{r}, nil
   110  }
   111  
   112  func (r *Root) InterfaceMethods(id string) (InterfaceMethods, error) {
   113  	logger.Infof("interface methods called")
   114  	m, err := r.SimpleMethods(id)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	return m, nil
   119  }
   120  
   121  type InterfaceMethods interface {
   122  	Call1r1e(s stringVal) (stringVal, error)
   123  }
   124  
   125  type ChangeAPIMethods struct {
   126  	r *Root
   127  }
   128  
   129  func (r *Root) ChangeAPIMethods(string) (*ChangeAPIMethods, error) {
   130  	return &ChangeAPIMethods{r}, nil
   131  }
   132  
   133  func (t *Root) called(rcvr interface{}, method string, arg interface{}) {
   134  	t.mu.Lock()
   135  	t.calls = append(t.calls, &callInfo{rcvr, method, arg})
   136  	t.mu.Unlock()
   137  }
   138  
   139  type SimpleMethods struct {
   140  	root *Root
   141  	id   string
   142  }
   143  
   144  // Each Call method is named in this standard form:
   145  //
   146  //     Call<narg>r<nret><e>
   147  //
   148  // where narg is the number of arguments, nret is the number of returned
   149  // values (not including the error) and e is the letter 'e' if the
   150  // method returns an error.
   151  
   152  func (a *SimpleMethods) Call0r0() {
   153  	a.root.called(a, "Call0r0", nil)
   154  }
   155  
   156  func (a *SimpleMethods) Call0r1() stringVal {
   157  	a.root.called(a, "Call0r1", nil)
   158  	return stringVal{"Call0r1 ret"}
   159  }
   160  
   161  func (a *SimpleMethods) Call0r1e() (stringVal, error) {
   162  	a.root.called(a, "Call0r1e", nil)
   163  	return stringVal{"Call0r1e ret"}, a.root.callError(a, "Call0r1e", nil)
   164  }
   165  
   166  func (a *SimpleMethods) Call0r0e() error {
   167  	a.root.called(a, "Call0r0e", nil)
   168  	return a.root.callError(a, "Call0r0e", nil)
   169  }
   170  
   171  func (a *SimpleMethods) Call1r0(s stringVal) {
   172  	a.root.called(a, "Call1r0", s)
   173  }
   174  
   175  func (a *SimpleMethods) Call1r1(s stringVal) stringVal {
   176  	a.root.called(a, "Call1r1", s)
   177  	return stringVal{"Call1r1 ret"}
   178  }
   179  
   180  func (a *SimpleMethods) Call1r1e(s stringVal) (stringVal, error) {
   181  	a.root.called(a, "Call1r1e", s)
   182  	return stringVal{"Call1r1e ret"}, a.root.callError(a, "Call1r1e", s)
   183  }
   184  
   185  func (a *SimpleMethods) Call1r0e(s stringVal) error {
   186  	a.root.called(a, "Call1r0e", s)
   187  	return a.root.callError(a, "Call1r0e", s)
   188  }
   189  
   190  func (a *SimpleMethods) SliceArg(struct{ X []string }) stringVal {
   191  	return stringVal{"SliceArg ret"}
   192  }
   193  
   194  func (a *SimpleMethods) Discard1(int) {}
   195  
   196  func (a *SimpleMethods) Discard2(struct{}, struct{}) {}
   197  
   198  func (a *SimpleMethods) Discard3() int { return 0 }
   199  
   200  func (a *SimpleMethods) Discard4() (_, _ struct{}) { return }
   201  
   202  type ContextMethods struct {
   203  	root        *Root
   204  	callContext context.Context
   205  	waiting     chan struct{}
   206  }
   207  
   208  func (c *ContextMethods) Call0(ctx context.Context) error {
   209  	c.root.called(c, "Call0", nil)
   210  	c.callContext = ctx
   211  	return c.checkContext(ctx)
   212  }
   213  
   214  func (c *ContextMethods) Call1(ctx context.Context, s stringVal) error {
   215  	c.root.called(c, "Call1", s)
   216  	c.callContext = ctx
   217  	return c.checkContext(ctx)
   218  }
   219  
   220  func (c *ContextMethods) Wait(ctx context.Context) error {
   221  	c.root.called(c, "Wait", nil)
   222  	close(c.waiting)
   223  	select {
   224  	case <-ctx.Done():
   225  		return ctx.Err()
   226  	case <-time.After(testing.LongWait):
   227  		return errors.New("expected context to be cancelled")
   228  	}
   229  }
   230  
   231  func (c *ContextMethods) checkContext(ctx context.Context) error {
   232  	select {
   233  	case <-ctx.Done():
   234  		return ctx.Err()
   235  	case <-time.After(testing.ShortWait):
   236  	}
   237  	return nil
   238  }
   239  
   240  type DelayedMethods struct {
   241  	ready     chan struct{}
   242  	done      chan string
   243  	doneError chan error
   244  }
   245  
   246  func (a *DelayedMethods) Delay() (stringVal, error) {
   247  	if a.ready != nil {
   248  		a.ready <- struct{}{}
   249  	}
   250  	select {
   251  	case s := <-a.done:
   252  		return stringVal{s}, nil
   253  	case err := <-a.doneError:
   254  		return stringVal{}, err
   255  	}
   256  }
   257  
   258  type ErrorMethods struct {
   259  	err error
   260  }
   261  
   262  func (e *ErrorMethods) Call() error {
   263  	return e.err
   264  }
   265  
   266  type CallbackMethods struct {
   267  	root *Root
   268  }
   269  
   270  type int64val struct {
   271  	I int64
   272  }
   273  
   274  func (a *CallbackMethods) Factorial(x int64val) (int64val, error) {
   275  	if x.I <= 1 {
   276  		return int64val{1}, nil
   277  	}
   278  	var r int64val
   279  	err := a.root.conn.Call(rpc.Request{"CallbackMethods", 0, "", "Factorial"}, int64val{x.I - 1}, &r)
   280  	if err != nil {
   281  		return int64val{}, err
   282  	}
   283  	return int64val{x.I * r.I}, nil
   284  }
   285  
   286  func (a *ChangeAPIMethods) ChangeAPI() {
   287  	a.r.conn.Serve(&changedAPIRoot{}, nil, nil)
   288  }
   289  
   290  func (a *ChangeAPIMethods) RemoveAPI() {
   291  	a.r.conn.Serve(nil, nil, nil)
   292  }
   293  
   294  type changedAPIRoot struct{}
   295  
   296  func (r *changedAPIRoot) NewlyAvailable(string) (newlyAvailableMethods, error) {
   297  	return newlyAvailableMethods{}, nil
   298  }
   299  
   300  type newlyAvailableMethods struct{}
   301  
   302  func (newlyAvailableMethods) NewMethod() stringVal {
   303  	return stringVal{"new method result"}
   304  }
   305  
   306  type VariableMethods1 struct {
   307  	sm *SimpleMethods
   308  }
   309  
   310  func (vm *VariableMethods1) Call0r1() stringVal {
   311  	return vm.sm.Call0r1()
   312  }
   313  
   314  type VariableMethods2 struct {
   315  	sm *SimpleMethods
   316  }
   317  
   318  func (vm *VariableMethods2) Call1r1(s stringVal) stringVal {
   319  	return vm.sm.Call1r1(s)
   320  }
   321  
   322  type RestrictedMethods struct {
   323  	InterfaceMethods
   324  }
   325  
   326  type CustomRoot struct {
   327  	root *Root
   328  }
   329  
   330  type wrapper func(*SimpleMethods) reflect.Value
   331  
   332  type customMethodCaller struct {
   333  	wrap         wrapper
   334  	root         *Root
   335  	objMethod    rpcreflect.ObjMethod
   336  	expectedType reflect.Type
   337  }
   338  
   339  func (c customMethodCaller) ParamsType() reflect.Type {
   340  	return c.objMethod.Params
   341  }
   342  
   343  func (c customMethodCaller) ResultType() reflect.Type {
   344  	return c.objMethod.Result
   345  }
   346  
   347  func (c customMethodCaller) Call(_ context.Context, objId string, arg reflect.Value) (reflect.Value, error) {
   348  	sm, err := c.root.SimpleMethods(objId)
   349  	if err != nil {
   350  		return reflect.Value{}, err
   351  	}
   352  	obj := c.wrap(sm)
   353  	if reflect.TypeOf(obj) != c.expectedType {
   354  		logger.Errorf("got the wrong type back, expected %s got %T", c.expectedType, obj)
   355  	}
   356  	logger.Debugf("calling: %T %v %#v", obj, obj, c.objMethod)
   357  	return c.objMethod.Call(context.TODO(), obj, arg)
   358  }
   359  
   360  func (cc *CustomRoot) Kill() {
   361  }
   362  
   363  func (cc *CustomRoot) FindMethod(
   364  	rootMethodName string, version int, objMethodName string,
   365  ) (
   366  	rpcreflect.MethodCaller, error,
   367  ) {
   368  	logger.Debugf("got to FindMethod: %q %d %q", rootMethodName, version, objMethodName)
   369  	if rootMethodName != "MultiVersion" {
   370  		return nil, &rpcreflect.CallNotImplementedError{
   371  			RootMethod: rootMethodName,
   372  		}
   373  	}
   374  	var goType reflect.Type
   375  	var wrap wrapper
   376  	switch version {
   377  	case 0:
   378  		goType = reflect.TypeOf((*VariableMethods1)(nil))
   379  		wrap = func(sm *SimpleMethods) reflect.Value {
   380  			return reflect.ValueOf(&VariableMethods1{sm})
   381  		}
   382  	case 1:
   383  		goType = reflect.TypeOf((*VariableMethods2)(nil))
   384  		wrap = func(sm *SimpleMethods) reflect.Value {
   385  			return reflect.ValueOf(&VariableMethods2{sm})
   386  		}
   387  	case 2:
   388  		goType = reflect.TypeOf((*RestrictedMethods)(nil))
   389  		wrap = func(sm *SimpleMethods) reflect.Value {
   390  			methods := &RestrictedMethods{InterfaceMethods: sm}
   391  			return reflect.ValueOf(methods)
   392  		}
   393  	default:
   394  		return nil, &rpcreflect.CallNotImplementedError{
   395  			RootMethod: rootMethodName,
   396  			Version:    version,
   397  		}
   398  	}
   399  	logger.Debugf("found type: %s", goType)
   400  	objType := rpcreflect.ObjTypeOf(goType)
   401  	objMethod, err := objType.Method(objMethodName)
   402  	if err != nil {
   403  		return nil, &rpcreflect.CallNotImplementedError{
   404  			RootMethod: rootMethodName,
   405  			Version:    version,
   406  			Method:     objMethodName,
   407  		}
   408  	}
   409  	return customMethodCaller{
   410  		objMethod:    objMethod,
   411  		root:         cc.root,
   412  		wrap:         wrap,
   413  		expectedType: goType,
   414  	}, nil
   415  }
   416  
   417  func SimpleRoot() *Root {
   418  	root := &Root{
   419  		simple: make(map[string]*SimpleMethods),
   420  	}
   421  	root.simple["a99"] = &SimpleMethods{root: root, id: "a99"}
   422  	return root
   423  }
   424  
   425  func (*rpcSuite) TestRPC(c *gc.C) {
   426  	root := SimpleRoot()
   427  	client, _, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   428  	defer closeClient(c, client, srvDone)
   429  	for narg := 0; narg < 2; narg++ {
   430  		for nret := 0; nret < 2; nret++ {
   431  			for nerr := 0; nerr < 2; nerr++ {
   432  				retErr := nerr != 0
   433  				p := testCallParams{
   434  					client:         client,
   435  					serverNotifier: serverNotifier,
   436  					entry:          "SimpleMethods",
   437  					narg:           narg,
   438  					nret:           nret,
   439  					retErr:         retErr,
   440  					testErr:        false,
   441  				}
   442  				root.testCall(c, p)
   443  				if retErr {
   444  					p.testErr = true
   445  					root.testCall(c, p)
   446  				}
   447  			}
   448  		}
   449  	}
   450  }
   451  
   452  func callName(narg, nret int, retErr bool) string {
   453  	e := ""
   454  	if retErr {
   455  		e = "e"
   456  	}
   457  	return fmt.Sprintf("Call%dr%d%s", narg, nret, e)
   458  }
   459  
   460  type testCallParams struct {
   461  	// client holds the client-side of the rpc connection that
   462  	// will be used to make the call.
   463  	client *rpc.Conn
   464  
   465  	// serverNotifier holds the notifier for the server side.
   466  	serverNotifier *notifier
   467  
   468  	// entry holds the top-level type that will be invoked
   469  	// (e.g. "SimpleMethods").
   470  	entry string
   471  
   472  	// narg holds the number of arguments accepted by the
   473  	// call (0 or 1).
   474  	narg int
   475  
   476  	// nret holds the number of values returned by the
   477  	// call (0 or 1).
   478  	nret int
   479  
   480  	// retErr specifies whether the call returns an error.
   481  	retErr bool
   482  
   483  	// testErr specifies whether the call should be made to return an error.
   484  	testErr bool
   485  
   486  	// version specifies what version of the interface to call, defaults to 0.
   487  	version int
   488  }
   489  
   490  // request returns the RPC request for the test call.
   491  func (p testCallParams) request() rpc.Request {
   492  	return rpc.Request{
   493  		Type:    p.entry,
   494  		Version: p.version,
   495  		Id:      "a99",
   496  		Action:  callName(p.narg, p.nret, p.retErr),
   497  	}
   498  }
   499  
   500  // error message returns the error message that the test call
   501  // should return if it returns an error.
   502  func (p testCallParams) errorMessage() string {
   503  	return fmt.Sprintf("error calling %s", p.request().Action)
   504  }
   505  
   506  func (root *Root) testCall(c *gc.C, args testCallParams) {
   507  	args.serverNotifier.reset()
   508  	root.calls = nil
   509  	root.returnErr = args.testErr
   510  	c.Logf("test call %s", args.request().Action)
   511  	var response stringVal
   512  	err := args.client.Call(args.request(), stringVal{"arg"}, &response)
   513  	switch {
   514  	case args.retErr && args.testErr:
   515  		c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   516  			Message: args.errorMessage(),
   517  		})
   518  		c.Assert(response, gc.Equals, stringVal{})
   519  	case args.nret > 0:
   520  		c.Check(response, gc.Equals, stringVal{args.request().Action + " ret"})
   521  	}
   522  	if !args.testErr {
   523  		c.Check(err, jc.ErrorIsNil)
   524  	}
   525  
   526  	// Check that the call was actually made, the right
   527  	// parameters were received and the right result returned.
   528  	root.mu.Lock()
   529  	defer root.mu.Unlock()
   530  
   531  	root.assertCallMade(c, args)
   532  	root.assertServerNotified(c, args, args.client.ClientRequestID())
   533  }
   534  
   535  func (root *Root) assertCallMade(c *gc.C, p testCallParams) {
   536  	expectCall := callInfo{
   537  		rcvr:   root.simple["a99"],
   538  		method: p.request().Action,
   539  	}
   540  	if p.narg > 0 {
   541  		expectCall.arg = stringVal{"arg"}
   542  	}
   543  	c.Assert(root.calls, gc.HasLen, 1)
   544  	c.Assert(*root.calls[0], gc.Equals, expectCall)
   545  }
   546  
   547  // assertServerNotified asserts that the right server notifications
   548  // were made for the given test call parameters. The id of the request
   549  // is held in requestId.
   550  func (root *Root) assertServerNotified(c *gc.C, p testCallParams, requestId uint64) {
   551  	// Test that there was a notification for the request.
   552  	c.Assert(p.serverNotifier.serverRequests, gc.HasLen, 1)
   553  	serverReq := p.serverNotifier.serverRequests[0]
   554  	c.Assert(serverReq.hdr, gc.DeepEquals, rpc.Header{
   555  		RequestId: requestId,
   556  		Request:   p.request(),
   557  		Version:   1,
   558  	})
   559  	if p.narg > 0 {
   560  		c.Assert(serverReq.body, gc.Equals, stringVal{"arg"})
   561  	} else {
   562  		c.Assert(serverReq.body, gc.Equals, struct{}{})
   563  	}
   564  
   565  	// Test that there was a notification for the reply.
   566  	c.Assert(p.serverNotifier.serverReplies, gc.HasLen, 1)
   567  	serverReply := p.serverNotifier.serverReplies[0]
   568  	c.Assert(serverReply.req, gc.Equals, p.request())
   569  	if p.retErr && p.testErr || p.nret == 0 {
   570  		c.Assert(serverReply.body, gc.Equals, struct{}{})
   571  	} else {
   572  		c.Assert(serverReply.body, gc.Equals, stringVal{p.request().Action + " ret"})
   573  	}
   574  	if p.retErr && p.testErr {
   575  		c.Assert(serverReply.hdr, gc.Equals, rpc.Header{
   576  			RequestId: requestId,
   577  			Error:     p.errorMessage(),
   578  			Version:   1,
   579  		})
   580  	} else {
   581  		c.Assert(serverReply.hdr, gc.Equals, rpc.Header{
   582  			RequestId: requestId,
   583  			Version:   1,
   584  		})
   585  	}
   586  }
   587  
   588  func (*rpcSuite) TestInterfaceMethods(c *gc.C) {
   589  	root := SimpleRoot()
   590  	client, _, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   591  	defer closeClient(c, client, srvDone)
   592  	p := testCallParams{
   593  		client:         client,
   594  		serverNotifier: serverNotifier,
   595  		entry:          "InterfaceMethods",
   596  		narg:           1,
   597  		nret:           1,
   598  		retErr:         true,
   599  		testErr:        false,
   600  	}
   601  
   602  	root.testCall(c, p)
   603  	p.testErr = true
   604  	root.testCall(c, p)
   605  	// Call0r0 is defined on the underlying SimpleMethods, but is not
   606  	// exposed at the InterfaceMethods level, so this call should fail with
   607  	// CodeNotImplemented.
   608  	var r stringVal
   609  	err := client.Call(rpc.Request{"InterfaceMethods", 0, "a99", "Call0r0"}, stringVal{"arg"}, &r)
   610  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   611  		Message: "no such request - method InterfaceMethods.Call0r0 is not implemented",
   612  		Code:    rpc.CodeNotImplemented,
   613  	})
   614  }
   615  
   616  func (*rpcSuite) TestCustomRootV0(c *gc.C) {
   617  	root := &CustomRoot{SimpleRoot()}
   618  	client, _, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   619  	defer closeClient(c, client, srvDone)
   620  	// V0 of MultiVersion implements only VariableMethods1.Call0r1.
   621  	p := testCallParams{
   622  		client:         client,
   623  		serverNotifier: serverNotifier,
   624  		entry:          "MultiVersion",
   625  		version:        0,
   626  		narg:           0,
   627  		nret:           1,
   628  		retErr:         false,
   629  		testErr:        false,
   630  	}
   631  
   632  	root.root.testCall(c, p)
   633  	// Call1r1 is exposed in version 1, but not in version 0.
   634  	var r stringVal
   635  	err := client.Call(rpc.Request{"MultiVersion", 0, "a99", "Call1r1"}, stringVal{"arg"}, &r)
   636  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   637  		Message: "no such request - method MultiVersion.Call1r1 is not implemented",
   638  		Code:    rpc.CodeNotImplemented,
   639  	})
   640  }
   641  
   642  func (*rpcSuite) TestCustomRootV1(c *gc.C) {
   643  	root := &CustomRoot{SimpleRoot()}
   644  	client, _, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   645  	defer closeClient(c, client, srvDone)
   646  	// V1 of MultiVersion implements only VariableMethods2.Call1r1.
   647  	p := testCallParams{
   648  		client:         client,
   649  		serverNotifier: serverNotifier,
   650  		entry:          "MultiVersion",
   651  		version:        1,
   652  		narg:           1,
   653  		nret:           1,
   654  		retErr:         false,
   655  		testErr:        false,
   656  	}
   657  
   658  	root.root.testCall(c, p)
   659  	// Call0r1 is exposed in version 0, but not in version 1.
   660  	var r stringVal
   661  	err := client.Call(rpc.Request{"MultiVersion", 1, "a99", "Call0r1"}, nil, &r)
   662  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   663  		Message: "no such request - method MultiVersion(1).Call0r1 is not implemented",
   664  		Code:    rpc.CodeNotImplemented,
   665  	})
   666  }
   667  
   668  func (*rpcSuite) TestCustomRootV2(c *gc.C) {
   669  	root := &CustomRoot{SimpleRoot()}
   670  	client, _, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   671  	defer closeClient(c, client, srvDone)
   672  	p := testCallParams{
   673  		client:         client,
   674  		serverNotifier: serverNotifier,
   675  		entry:          "MultiVersion",
   676  		version:        2,
   677  		narg:           1,
   678  		nret:           1,
   679  		retErr:         true,
   680  		testErr:        false,
   681  	}
   682  
   683  	root.root.testCall(c, p)
   684  	// By embedding the InterfaceMethods inside a concrete
   685  	// RestrictedMethods type, we actually only expose the methods defined
   686  	// in InterfaceMethods.
   687  	var r stringVal
   688  	err := client.Call(rpc.Request{"MultiVersion", 2, "a99", "Call0r1e"}, nil, &r)
   689  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   690  		Message: `no such request - method MultiVersion(2).Call0r1e is not implemented`,
   691  		Code:    rpc.CodeNotImplemented,
   692  	})
   693  }
   694  
   695  func (*rpcSuite) TestCustomRootUnknownVersion(c *gc.C) {
   696  	root := &CustomRoot{SimpleRoot()}
   697  	client, _, srvDone, _ := newRPCClientServer(c, root, nil, false)
   698  	defer closeClient(c, client, srvDone)
   699  	var r stringVal
   700  	// Unknown version 5
   701  	err := client.Call(rpc.Request{"MultiVersion", 5, "a99", "Call0r1"}, nil, &r)
   702  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   703  		Message: `unknown version (5) of interface "MultiVersion"`,
   704  		Code:    rpc.CodeNotImplemented,
   705  	})
   706  }
   707  
   708  func (*rpcSuite) TestConcurrentCalls(c *gc.C) {
   709  	start1 := make(chan string)
   710  	start2 := make(chan string)
   711  	ready1 := make(chan struct{})
   712  	ready2 := make(chan struct{})
   713  
   714  	root := &Root{
   715  		delayed: map[string]*DelayedMethods{
   716  			"1": {ready: ready1, done: start1},
   717  			"2": {ready: ready2, done: start2},
   718  		},
   719  	}
   720  
   721  	client, _, srvDone, _ := newRPCClientServer(c, root, nil, false)
   722  	defer closeClient(c, client, srvDone)
   723  	call := func(id string, done chan<- struct{}) {
   724  		var r stringVal
   725  		err := client.Call(rpc.Request{"DelayedMethods", 0, id, "Delay"}, nil, &r)
   726  		c.Check(err, jc.ErrorIsNil)
   727  		c.Check(r.Val, gc.Equals, "return "+id)
   728  		done <- struct{}{}
   729  	}
   730  	done1 := make(chan struct{})
   731  	done2 := make(chan struct{})
   732  	go call("1", done1)
   733  	go call("2", done2)
   734  
   735  	// Check that both calls are running concurrently.
   736  	chanRead(c, ready1, "method 1 ready")
   737  	chanRead(c, ready2, "method 2 ready")
   738  
   739  	// Let the requests complete.
   740  	start1 <- "return 1"
   741  	start2 <- "return 2"
   742  	chanRead(c, done1, "method 1 done")
   743  	chanRead(c, done2, "method 2 done")
   744  }
   745  
   746  type codedError struct {
   747  	m    string
   748  	code string
   749  }
   750  
   751  func (e *codedError) Error() string {
   752  	return e.m
   753  }
   754  
   755  func (e *codedError) ErrorCode() string {
   756  	return e.code
   757  }
   758  
   759  func (*rpcSuite) TestErrorCode(c *gc.C) {
   760  	root := &Root{
   761  		errorInst: &ErrorMethods{&codedError{"message", "code"}},
   762  	}
   763  	client, _, srvDone, _ := newRPCClientServer(c, root, nil, false)
   764  	defer closeClient(c, client, srvDone)
   765  	err := client.Call(rpc.Request{"ErrorMethods", 0, "", "Call"}, nil, nil)
   766  	c.Assert(err, gc.ErrorMatches, `message \(code\)`)
   767  	c.Assert(errors.Cause(err).(rpc.ErrorCoder).ErrorCode(), gc.Equals, "code")
   768  }
   769  
   770  func (*rpcSuite) TestTransformErrors(c *gc.C) {
   771  	root := &Root{
   772  		errorInst: &ErrorMethods{&codedError{"message", "code"}},
   773  	}
   774  	tfErr := func(err error) error {
   775  		c.Check(err, gc.NotNil)
   776  		if e, ok := err.(*codedError); ok {
   777  			return &codedError{
   778  				m:    "transformed: " + e.m,
   779  				code: "transformed: " + e.code,
   780  			}
   781  		}
   782  		return fmt.Errorf("transformed: %v", err)
   783  	}
   784  	client, _, srvDone, _ := newRPCClientServer(c, root, tfErr, false)
   785  	defer closeClient(c, client, srvDone)
   786  	// First, we don't transform methods we can't find.
   787  	err := client.Call(rpc.Request{"foo", 0, "", "bar"}, nil, nil)
   788  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   789  		Message: `unknown object type "foo"`,
   790  		Code:    rpc.CodeNotImplemented,
   791  	})
   792  
   793  	err = client.Call(rpc.Request{"ErrorMethods", 0, "", "NoMethod"}, nil, nil)
   794  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   795  		Message: "no such request - method ErrorMethods.NoMethod is not implemented",
   796  		Code:    rpc.CodeNotImplemented,
   797  	})
   798  
   799  	// We do transform any errors that happen from calling the RootMethod
   800  	// and beyond.
   801  	err = client.Call(rpc.Request{"ErrorMethods", 0, "", "Call"}, nil, nil)
   802  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   803  		Message: "transformed: message",
   804  		Code:    "transformed: code",
   805  	})
   806  
   807  	root.errorInst.err = nil
   808  	err = client.Call(rpc.Request{"ErrorMethods", 0, "", "Call"}, nil, nil)
   809  	c.Assert(err, jc.ErrorIsNil)
   810  
   811  	root.errorInst = nil
   812  	err = client.Call(rpc.Request{"ErrorMethods", 0, "", "Call"}, nil, nil)
   813  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   814  		Message: "transformed: no error methods",
   815  	})
   816  
   817  }
   818  
   819  func (*rpcSuite) TestServerWaitsForOutstandingCalls(c *gc.C) {
   820  	ready := make(chan struct{})
   821  	start := make(chan string)
   822  	root := &Root{
   823  		delayed: map[string]*DelayedMethods{
   824  			"1": {
   825  				ready: ready,
   826  				done:  start,
   827  			},
   828  		},
   829  	}
   830  	client, _, srvDone, _ := newRPCClientServer(c, root, nil, false)
   831  	defer closeClient(c, client, srvDone)
   832  	done := make(chan struct{})
   833  	go func() {
   834  		var r stringVal
   835  		err := client.Call(rpc.Request{"DelayedMethods", 0, "1", "Delay"}, nil, &r)
   836  		c.Check(errors.Cause(err), gc.Equals, rpc.ErrShutdown)
   837  		done <- struct{}{}
   838  	}()
   839  	chanRead(c, ready, "DelayedMethods.Delay ready")
   840  	client.Close()
   841  	select {
   842  	case err := <-srvDone:
   843  		c.Fatalf("server returned while outstanding operation in progress: %v", err)
   844  		<-done
   845  	case <-time.After(25 * time.Millisecond):
   846  	}
   847  	start <- "xxx"
   848  }
   849  
   850  func chanRead(c *gc.C, ch <-chan struct{}, what string) {
   851  	select {
   852  	case <-ch:
   853  		return
   854  	case <-time.After(3 * time.Second):
   855  		c.Fatalf("timeout on channel read %s", what)
   856  	}
   857  }
   858  
   859  func (*rpcSuite) TestCompatibility(c *gc.C) {
   860  	root := &Root{
   861  		simple: make(map[string]*SimpleMethods),
   862  	}
   863  	a0 := &SimpleMethods{root: root, id: "a0"}
   864  	root.simple["a0"] = a0
   865  
   866  	client, _, srvDone, _ := newRPCClientServer(c, root, nil, false)
   867  	defer closeClient(c, client, srvDone)
   868  	call := func(method string, arg, ret interface{}) (passedArg interface{}) {
   869  		root.calls = nil
   870  		err := client.Call(rpc.Request{"SimpleMethods", 0, "a0", method}, arg, ret)
   871  		c.Assert(err, jc.ErrorIsNil)
   872  		c.Assert(root.calls, gc.HasLen, 1)
   873  		info := root.calls[0]
   874  		c.Assert(info.rcvr, gc.Equals, a0)
   875  		c.Assert(info.method, gc.Equals, method)
   876  		return info.arg
   877  	}
   878  	type extra struct {
   879  		Val   string
   880  		Extra string
   881  	}
   882  	// Extra fields in request and response.
   883  	var r extra
   884  	arg := call("Call1r1", extra{"x", "y"}, &r)
   885  	c.Assert(arg, gc.Equals, stringVal{"x"})
   886  
   887  	// Nil argument as request.
   888  	r = extra{}
   889  	arg = call("Call1r1", nil, &r)
   890  	c.Assert(arg, gc.Equals, stringVal{})
   891  
   892  	// Nil argument as response.
   893  	arg = call("Call1r1", stringVal{"x"}, nil)
   894  	c.Assert(arg, gc.Equals, stringVal{"x"})
   895  
   896  	// Non-nil argument for no response.
   897  	r = extra{}
   898  	arg = call("Call1r0", stringVal{"x"}, &r)
   899  	c.Assert(arg, gc.Equals, stringVal{"x"})
   900  	c.Assert(r, gc.Equals, extra{})
   901  }
   902  
   903  func (*rpcSuite) TestBadCall(c *gc.C) {
   904  	loggo.GetLogger("juju.rpc").SetLogLevel(loggo.TRACE)
   905  	root := &Root{
   906  		simple: make(map[string]*SimpleMethods),
   907  	}
   908  	a0 := &SimpleMethods{root: root, id: "a0"}
   909  	root.simple["a0"] = a0
   910  	client, _, srvDone, serverNotifier := newRPCClientServer(c, root, nil, false)
   911  	defer closeClient(c, client, srvDone)
   912  
   913  	testBadCall(c, client, serverNotifier,
   914  		rpc.Request{"BadSomething", 0, "a0", "No"},
   915  		`unknown object type "BadSomething"`,
   916  		rpc.CodeNotImplemented,
   917  		false,
   918  	)
   919  	testBadCall(c, client, serverNotifier,
   920  		rpc.Request{"SimpleMethods", 0, "xx", "No"},
   921  		"no such request - method SimpleMethods.No is not implemented",
   922  		rpc.CodeNotImplemented,
   923  		false,
   924  	)
   925  	testBadCall(c, client, serverNotifier,
   926  		rpc.Request{"SimpleMethods", 0, "xx", "Call0r0"},
   927  		`unknown SimpleMethods id`,
   928  		"",
   929  		true,
   930  	)
   931  }
   932  
   933  func testBadCall(
   934  	c *gc.C,
   935  	client *rpc.Conn,
   936  	serverNotifier *notifier,
   937  	req rpc.Request,
   938  	expectedErr string,
   939  	expectedErrCode string,
   940  	requestKnown bool,
   941  ) {
   942  	serverNotifier.reset()
   943  	err := client.Call(req, nil, nil)
   944  	msg := expectedErr
   945  	if expectedErrCode != "" {
   946  		msg += " (" + expectedErrCode + ")"
   947  	}
   948  	c.Assert(err, gc.ErrorMatches, regexp.QuoteMeta(msg))
   949  
   950  	// From docs on ServerRequest:
   951  	// 	If the request was not recognized or there was
   952  	//	an error reading the body, body will be nil.
   953  	var expectBody interface{}
   954  	if requestKnown {
   955  		expectBody = struct{}{}
   956  	}
   957  	c.Assert(serverNotifier.serverRequests[0], gc.DeepEquals, requestEvent{
   958  		hdr: rpc.Header{
   959  			RequestId: client.ClientRequestID(),
   960  			Request:   req,
   961  			Version:   1,
   962  		},
   963  		body: expectBody,
   964  	})
   965  
   966  	// Test that there was a notification for the server reply.
   967  	c.Assert(serverNotifier.serverReplies, gc.HasLen, 1)
   968  	serverReply := serverNotifier.serverReplies[0]
   969  	c.Assert(serverReply, gc.DeepEquals, replyEvent{
   970  		hdr: rpc.Header{
   971  			RequestId: client.ClientRequestID(),
   972  			Error:     expectedErr,
   973  			ErrorCode: expectedErrCode,
   974  			Version:   1,
   975  		},
   976  		req:  req,
   977  		body: struct{}{},
   978  	})
   979  }
   980  
   981  func (*rpcSuite) TestContinueAfterReadBodyError(c *gc.C) {
   982  	root := &Root{
   983  		simple: make(map[string]*SimpleMethods),
   984  	}
   985  	a0 := &SimpleMethods{root: root, id: "a0"}
   986  	root.simple["a0"] = a0
   987  	client, _, srvDone, _ := newRPCClientServer(c, root, nil, false)
   988  	defer closeClient(c, client, srvDone)
   989  
   990  	var ret stringVal
   991  	arg0 := struct {
   992  		X map[string]int
   993  	}{
   994  		X: map[string]int{"hello": 65},
   995  	}
   996  	err := client.Call(rpc.Request{"SimpleMethods", 0, "a0", "SliceArg"}, arg0, &ret)
   997  	c.Assert(err, gc.ErrorMatches, `json: cannot unmarshal object into Go (?:value)|(?:struct field \.X) of type \[\]string`)
   998  
   999  	err = client.Call(rpc.Request{"SimpleMethods", 0, "a0", "SliceArg"}, arg0, &ret)
  1000  	c.Assert(err, gc.ErrorMatches, `json: cannot unmarshal object into Go (?:value)|(?:struct field \.X) of type \[\]string`)
  1001  
  1002  	arg1 := struct {
  1003  		X []string
  1004  	}{
  1005  		X: []string{"one"},
  1006  	}
  1007  	err = client.Call(rpc.Request{"SimpleMethods", 0, "a0", "SliceArg"}, arg1, &ret)
  1008  	c.Assert(err, jc.ErrorIsNil)
  1009  	c.Assert(ret.Val, gc.Equals, "SliceArg ret")
  1010  }
  1011  
  1012  func (*rpcSuite) TestErrorAfterClientClose(c *gc.C) {
  1013  	client, _, srvDone, _ := newRPCClientServer(c, &Root{}, nil, false)
  1014  	err := client.Close()
  1015  	c.Assert(err, jc.ErrorIsNil)
  1016  	err = client.Call(rpc.Request{"Foo", 0, "", "Bar"}, nil, nil)
  1017  	c.Assert(errors.Cause(err), gc.Equals, rpc.ErrShutdown)
  1018  	err = chanReadError(c, srvDone, "server done")
  1019  	c.Assert(err, jc.ErrorIsNil)
  1020  }
  1021  
  1022  func (*rpcSuite) TestClientCloseIdempotent(c *gc.C) {
  1023  	client, _, _, _ := newRPCClientServer(c, &Root{}, nil, false)
  1024  	err := client.Close()
  1025  	c.Assert(err, jc.ErrorIsNil)
  1026  	err = client.Close()
  1027  	c.Assert(err, jc.ErrorIsNil)
  1028  	err = client.Close()
  1029  	c.Assert(err, jc.ErrorIsNil)
  1030  }
  1031  
  1032  func (*rpcSuite) TestBidirectional(c *gc.C) {
  1033  	srvRoot := &Root{}
  1034  	client, _, srvDone, _ := newRPCClientServer(c, srvRoot, nil, true)
  1035  	defer closeClient(c, client, srvDone)
  1036  	clientRoot := &Root{conn: client}
  1037  	client.Serve(clientRoot, nil, nil)
  1038  	var r int64val
  1039  	err := client.Call(rpc.Request{"CallbackMethods", 0, "", "Factorial"}, int64val{12}, &r)
  1040  	c.Assert(err, jc.ErrorIsNil)
  1041  	c.Assert(r.I, gc.Equals, int64(479001600))
  1042  }
  1043  
  1044  func (*rpcSuite) TestServerRequestWhenNotServing(c *gc.C) {
  1045  	srvRoot := &Root{}
  1046  	client, _, srvDone, _ := newRPCClientServer(c, srvRoot, nil, true)
  1047  	defer closeClient(c, client, srvDone)
  1048  	var r int64val
  1049  	err := client.Call(rpc.Request{"CallbackMethods", 0, "", "Factorial"}, int64val{12}, &r)
  1050  	c.Assert(err, gc.ErrorMatches, "no service")
  1051  }
  1052  
  1053  func (*rpcSuite) TestChangeAPI(c *gc.C) {
  1054  	srvRoot := &Root{}
  1055  	client, _, srvDone, _ := newRPCClientServer(c, srvRoot, nil, true)
  1056  	defer closeClient(c, client, srvDone)
  1057  	var s stringVal
  1058  	err := client.Call(rpc.Request{"NewlyAvailable", 0, "", "NewMethod"}, nil, &s)
  1059  	c.Assert(err, gc.ErrorMatches, `unknown object type "NewlyAvailable" \(not implemented\)`)
  1060  	err = client.Call(rpc.Request{"ChangeAPIMethods", 0, "", "ChangeAPI"}, nil, nil)
  1061  	c.Assert(err, jc.ErrorIsNil)
  1062  	err = client.Call(rpc.Request{"ChangeAPIMethods", 0, "", "ChangeAPI"}, nil, nil)
  1063  	c.Assert(err, gc.ErrorMatches, `unknown object type "ChangeAPIMethods" \(not implemented\)`)
  1064  	err = client.Call(rpc.Request{"NewlyAvailable", 0, "", "NewMethod"}, nil, &s)
  1065  	c.Assert(err, jc.ErrorIsNil)
  1066  	c.Assert(s, gc.Equals, stringVal{"new method result"})
  1067  }
  1068  
  1069  func (*rpcSuite) TestChangeAPIToNil(c *gc.C) {
  1070  	srvRoot := &Root{}
  1071  	client, _, srvDone, _ := newRPCClientServer(c, srvRoot, nil, true)
  1072  	defer closeClient(c, client, srvDone)
  1073  
  1074  	err := client.Call(rpc.Request{"ChangeAPIMethods", 0, "", "RemoveAPI"}, nil, nil)
  1075  	c.Assert(err, jc.ErrorIsNil)
  1076  
  1077  	err = client.Call(rpc.Request{"ChangeAPIMethods", 0, "", "RemoveAPI"}, nil, nil)
  1078  	c.Assert(err, gc.ErrorMatches, "no service")
  1079  }
  1080  
  1081  func (*rpcSuite) TestChangeAPIWhileServingRequest(c *gc.C) {
  1082  	ready := make(chan struct{})
  1083  	done := make(chan error)
  1084  	srvRoot := &Root{
  1085  		delayed: map[string]*DelayedMethods{
  1086  			"1": {ready: ready, doneError: done},
  1087  		},
  1088  	}
  1089  	transform := func(err error) error {
  1090  		return fmt.Errorf("transformed: %v", err)
  1091  	}
  1092  	client, _, srvDone, _ := newRPCClientServer(c, srvRoot, transform, true)
  1093  	defer closeClient(c, client, srvDone)
  1094  
  1095  	result := make(chan error)
  1096  	go func() {
  1097  		result <- client.Call(rpc.Request{"DelayedMethods", 0, "1", "Delay"}, nil, nil)
  1098  	}()
  1099  	chanRead(c, ready, "method ready")
  1100  
  1101  	err := client.Call(rpc.Request{"ChangeAPIMethods", 0, "", "ChangeAPI"}, nil, nil)
  1102  	c.Assert(err, jc.ErrorIsNil)
  1103  
  1104  	// Ensure that not only does the request in progress complete,
  1105  	// but that the original transformErrors function is called.
  1106  	done <- fmt.Errorf("an error")
  1107  	select {
  1108  	case r := <-result:
  1109  		c.Assert(r, gc.ErrorMatches, "transformed: an error")
  1110  	case <-time.After(3 * time.Second):
  1111  		c.Fatalf("timeout on channel read")
  1112  	}
  1113  }
  1114  
  1115  func (*rpcSuite) TestCodeNotImplementedMatchesAPIserverParams(c *gc.C) {
  1116  	c.Assert(rpc.CodeNotImplemented, gc.Equals, params.CodeNotImplemented)
  1117  }
  1118  
  1119  func (*rpcSuite) TestRequestContext(c *gc.C) {
  1120  	root := &Root{}
  1121  	root.contextInst = &ContextMethods{root: root}
  1122  
  1123  	client, _, srvDone, _ := newRPCClientServer(c, root, nil, false)
  1124  	defer closeClient(c, client, srvDone)
  1125  
  1126  	call := func(method string, arg, ret interface{}) (passedArg interface{}) {
  1127  		root.calls = nil
  1128  		root.contextInst.callContext = nil
  1129  		err := client.Call(rpc.Request{"ContextMethods", 0, "", method}, arg, ret)
  1130  		c.Assert(err, jc.ErrorIsNil)
  1131  		c.Assert(root.calls, gc.HasLen, 1)
  1132  		info := root.calls[0]
  1133  		c.Assert(info.rcvr, gc.Equals, root.contextInst)
  1134  		c.Assert(info.method, gc.Equals, method)
  1135  		c.Assert(root.contextInst.callContext, gc.NotNil)
  1136  		// context is cancelled when the method returns
  1137  		c.Assert(root.contextInst.callContext.Err(), gc.Equals, context.Canceled)
  1138  		return info.arg
  1139  	}
  1140  
  1141  	arg := call("Call0", nil, nil)
  1142  	c.Assert(arg, gc.IsNil)
  1143  
  1144  	arg = call("Call1", stringVal{"foo"}, nil)
  1145  	c.Assert(arg, gc.Equals, stringVal{"foo"})
  1146  }
  1147  
  1148  func (*rpcSuite) TestConnectionContextCloseClient(c *gc.C) {
  1149  	root := &Root{}
  1150  	root.contextInst = &ContextMethods{
  1151  		root:    root,
  1152  		waiting: make(chan struct{}),
  1153  	}
  1154  
  1155  	client, _, srvDone, _ := newRPCClientServer(c, root, nil, false)
  1156  	defer closeClient(c, client, srvDone)
  1157  
  1158  	errch := make(chan error, 1)
  1159  	go func() {
  1160  		errch <- client.Call(rpc.Request{"ContextMethods", 0, "", "Wait"}, nil, nil)
  1161  	}()
  1162  
  1163  	<-root.contextInst.waiting
  1164  	err := client.Close()
  1165  	c.Assert(err, jc.ErrorIsNil)
  1166  
  1167  	err = <-errch
  1168  	c.Assert(err, jc.Satisfies, rpc.IsShutdownErr)
  1169  }
  1170  
  1171  func (*rpcSuite) TestConnectionContextCloseServer(c *gc.C) {
  1172  	root := &Root{}
  1173  	root.contextInst = &ContextMethods{
  1174  		root:    root,
  1175  		waiting: make(chan struct{}),
  1176  	}
  1177  
  1178  	client, server, srvDone, _ := newRPCClientServer(c, root, nil, false)
  1179  	defer closeClient(c, client, srvDone)
  1180  
  1181  	errch := make(chan error, 1)
  1182  	go func() {
  1183  		errch <- client.Call(rpc.Request{"ContextMethods", 0, "", "Wait"}, nil, nil)
  1184  	}()
  1185  
  1186  	<-root.contextInst.waiting
  1187  	err := server.Close()
  1188  	c.Assert(err, jc.ErrorIsNil)
  1189  
  1190  	err = <-errch
  1191  	c.Assert(err, gc.ErrorMatches, "context canceled")
  1192  }
  1193  
  1194  func (s *rpcSuite) TestRecorderErrorPreventsRequest(c *gc.C) {
  1195  	root := &Root{
  1196  		simple: make(map[string]*SimpleMethods),
  1197  	}
  1198  	root.simple["a0"] = &SimpleMethods{
  1199  		root: root,
  1200  		id:   "a0",
  1201  	}
  1202  	client, server, srvDone, notifier := newRPCClientServer(c, root, nil, false)
  1203  	defer closeClient(c, client, srvDone)
  1204  	notifier.errors = []error{errors.Errorf("explodo"), errors.Errorf("pyronica")}
  1205  
  1206  	err := client.Call(rpc.Request{"SimpleMethods", 0, "a0", "Call0r0"}, nil, nil)
  1207  	c.Assert(err, gc.ErrorMatches, "explodo")
  1208  
  1209  	err = server.Close()
  1210  	c.Assert(err, jc.ErrorIsNil)
  1211  
  1212  	c.Assert(root.calls, gc.HasLen, 0)
  1213  }
  1214  
  1215  func chanReadError(c *gc.C, ch <-chan error, what string) error {
  1216  	select {
  1217  	case e := <-ch:
  1218  		return e
  1219  	case <-time.After(3 * time.Second):
  1220  		c.Fatalf("timeout on channel read %s", what)
  1221  	}
  1222  	panic("unreachable")
  1223  }
  1224  
  1225  // newRPCClientServer starts an RPC server serving a connection from a
  1226  // single client.  When the server has finished serving the connection,
  1227  // it sends a value on the returned channel.
  1228  // If bidir is true, requests can flow in both directions.
  1229  func newRPCClientServer(
  1230  	c *gc.C,
  1231  	root interface{},
  1232  	tfErr func(error) error,
  1233  	bidir bool,
  1234  ) (client, server *rpc.Conn, srvDone chan error, serverNotifier *notifier) {
  1235  	l, err := net.Listen("tcp", "127.0.0.1:0")
  1236  	c.Assert(err, jc.ErrorIsNil)
  1237  
  1238  	srvDone = make(chan error, 1)
  1239  	serverNotifier = new(notifier)
  1240  	srvStarted := make(chan *rpc.Conn)
  1241  	go func() {
  1242  		defer close(srvDone)
  1243  		defer close(srvStarted)
  1244  		defer l.Close()
  1245  
  1246  		conn, err := l.Accept()
  1247  		if err != nil {
  1248  			srvDone <- err
  1249  			return
  1250  		}
  1251  
  1252  		role := roleServer
  1253  		if bidir {
  1254  			role = roleBoth
  1255  		}
  1256  		recorderFactory := func() rpc.Recorder { return serverNotifier }
  1257  		rpcConn := rpc.NewConn(NewJSONCodec(conn, role), recorderFactory)
  1258  		if custroot, ok := root.(*CustomRoot); ok {
  1259  			rpcConn.ServeRoot(custroot, recorderFactory, tfErr)
  1260  			custroot.root.conn = rpcConn
  1261  		} else {
  1262  			rpcConn.Serve(root, recorderFactory, tfErr)
  1263  		}
  1264  		if root, ok := root.(*Root); ok {
  1265  			root.conn = rpcConn
  1266  		}
  1267  		rpcConn.Start(context.Background())
  1268  		srvStarted <- rpcConn
  1269  		<-rpcConn.Dead()
  1270  		srvDone <- rpcConn.Close()
  1271  	}()
  1272  	conn, err := net.Dial("tcp", l.Addr().String())
  1273  	c.Assert(err, jc.ErrorIsNil)
  1274  	server = <-srvStarted
  1275  	if server == nil {
  1276  		conn.Close()
  1277  		c.Fatal(<-srvDone)
  1278  	}
  1279  	role := roleClient
  1280  	if bidir {
  1281  		role = roleBoth
  1282  	}
  1283  	client = rpc.NewConn(NewJSONCodec(conn, role), nil)
  1284  	client.Start(context.Background())
  1285  	return client, server, srvDone, serverNotifier
  1286  }
  1287  
  1288  func closeClient(c *gc.C, client *rpc.Conn, srvDone <-chan error) {
  1289  	err := client.Close()
  1290  	c.Assert(err, jc.ErrorIsNil)
  1291  	err = chanReadError(c, srvDone, "server done")
  1292  	c.Assert(err, jc.ErrorIsNil)
  1293  }
  1294  
  1295  type encoder interface {
  1296  	Encode(e interface{}) error
  1297  }
  1298  
  1299  type decoder interface {
  1300  	Decode(e interface{}) error
  1301  }
  1302  
  1303  // testCodec wraps an rpc.Codec with extra error checking code.
  1304  type testCodec struct {
  1305  	role connRole
  1306  	rpc.Codec
  1307  }
  1308  
  1309  func (c *testCodec) WriteMessage(hdr *rpc.Header, x interface{}) error {
  1310  	if reflect.ValueOf(x).Kind() != reflect.Struct {
  1311  		panic(fmt.Errorf("WriteRequest bad param; want struct got %T (%#v)", x, x))
  1312  	}
  1313  	if c.role != roleBoth && hdr.IsRequest() != (c.role == roleClient) {
  1314  		panic(fmt.Errorf("codec role %v; header wrong type %#v", c.role, hdr))
  1315  	}
  1316  	logger.Infof("send header: %#v; body: %#v", hdr, x)
  1317  	return c.Codec.WriteMessage(hdr, x)
  1318  }
  1319  
  1320  func (c *testCodec) ReadHeader(hdr *rpc.Header) error {
  1321  	err := c.Codec.ReadHeader(hdr)
  1322  	if err != nil {
  1323  		return err
  1324  	}
  1325  	logger.Infof("got header %#v", hdr)
  1326  	if c.role != roleBoth && hdr.IsRequest() == (c.role == roleClient) {
  1327  		panic(fmt.Errorf("codec role %v; read wrong type %#v", c.role, hdr))
  1328  	}
  1329  	return nil
  1330  }
  1331  
  1332  func (c *testCodec) ReadBody(r interface{}, isRequest bool) error {
  1333  	if v := reflect.ValueOf(r); v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
  1334  		panic(fmt.Errorf("ReadResponseBody bad destination; want *struct got %T", r))
  1335  	}
  1336  	if c.role != roleBoth && isRequest == (c.role == roleClient) {
  1337  		panic(fmt.Errorf("codec role %v; read wrong body type %#v", c.role, r))
  1338  	}
  1339  	// Note: this will need to change if we want to test a non-JSON codec.
  1340  	var m json.RawMessage
  1341  	err := c.Codec.ReadBody(&m, isRequest)
  1342  	if err != nil {
  1343  		return err
  1344  	}
  1345  	logger.Infof("got response body: %q", m)
  1346  	err = json.Unmarshal(m, r)
  1347  	logger.Infof("unmarshalled into %#v", r)
  1348  	return err
  1349  }
  1350  
  1351  type connRole string
  1352  
  1353  const (
  1354  	roleBoth   connRole = "both"
  1355  	roleClient connRole = "client"
  1356  	roleServer connRole = "server"
  1357  )
  1358  
  1359  func NewJSONCodec(c net.Conn, role connRole) rpc.Codec {
  1360  	return &testCodec{
  1361  		role:  role,
  1362  		Codec: jsoncodec.NewNet(c),
  1363  	}
  1364  }
  1365  
  1366  type requestEvent struct {
  1367  	hdr  rpc.Header
  1368  	body interface{}
  1369  }
  1370  
  1371  type replyEvent struct {
  1372  	req  rpc.Request
  1373  	hdr  rpc.Header
  1374  	body interface{}
  1375  }
  1376  
  1377  type notifier struct {
  1378  	mu             sync.Mutex
  1379  	serverRequests []requestEvent
  1380  	serverReplies  []replyEvent
  1381  	errors         []error
  1382  }
  1383  
  1384  func (n *notifier) reset() {
  1385  	n.mu.Lock()
  1386  	defer n.mu.Unlock()
  1387  	n.serverRequests = nil
  1388  	n.serverReplies = nil
  1389  	n.errors = nil
  1390  }
  1391  
  1392  func (n *notifier) nextErr() error {
  1393  	if len(n.errors) == 0 {
  1394  		return nil
  1395  	}
  1396  	err := n.errors[0]
  1397  	n.errors = n.errors[1:]
  1398  	return err
  1399  }
  1400  
  1401  func (n *notifier) HandleRequest(hdr *rpc.Header, body interface{}) error {
  1402  	n.mu.Lock()
  1403  	defer n.mu.Unlock()
  1404  	n.serverRequests = append(n.serverRequests, requestEvent{
  1405  		hdr:  *hdr,
  1406  		body: body,
  1407  	})
  1408  
  1409  	return n.nextErr()
  1410  }
  1411  
  1412  func (n *notifier) HandleReply(req rpc.Request, hdr *rpc.Header, body interface{}) error {
  1413  	n.mu.Lock()
  1414  	defer n.mu.Unlock()
  1415  	n.serverReplies = append(n.serverReplies, replyEvent{
  1416  		req:  req,
  1417  		hdr:  *hdr,
  1418  		body: body,
  1419  	})
  1420  	return n.nextErr()
  1421  }