github.com/rogpeppe/juju@v0.0.0-20140613142852-6337964b789e/rpc/rpc_test.go (about)

     1  // Copyright 2012, 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package rpc_test
     5  
     6  import (
     7  	"encoding/json"
     8  	"fmt"
     9  	"net"
    10  	"reflect"
    11  	"regexp"
    12  	"sync"
    13  	stdtesting "testing"
    14  	"time"
    15  
    16  	"github.com/juju/loggo"
    17  	gc "launchpad.net/gocheck"
    18  
    19  	"github.com/juju/juju/rpc"
    20  	"github.com/juju/juju/rpc/jsoncodec"
    21  	"github.com/juju/juju/testing"
    22  )
    23  
    24  var logger = loggo.GetLogger("juju.rpc")
    25  
    26  type rpcSuite struct {
    27  	testing.BaseSuite
    28  }
    29  
    30  var _ = gc.Suite(&rpcSuite{})
    31  
    32  func TestAll(t *stdtesting.T) {
    33  	gc.TestingT(t)
    34  }
    35  
    36  type callInfo struct {
    37  	rcvr   interface{}
    38  	method string
    39  	arg    interface{}
    40  }
    41  
    42  type callError callInfo
    43  
    44  func (e *callError) Error() string {
    45  	return fmt.Sprintf("error calling %s", e.method)
    46  }
    47  
    48  type stringVal struct {
    49  	Val string
    50  }
    51  
    52  type Root struct {
    53  	mu        sync.Mutex
    54  	conn      *rpc.Conn
    55  	calls     []*callInfo
    56  	returnErr bool
    57  	simple    map[string]*SimpleMethods
    58  	delayed   map[string]*DelayedMethods
    59  	errorInst *ErrorMethods
    60  }
    61  
    62  func (r *Root) callError(rcvr interface{}, name string, arg interface{}) error {
    63  	if r.returnErr {
    64  		return &callError{rcvr, name, arg}
    65  	}
    66  	return nil
    67  }
    68  
    69  func (r *Root) SimpleMethods(id string) (*SimpleMethods, error) {
    70  	r.mu.Lock()
    71  	defer r.mu.Unlock()
    72  	if a := r.simple[id]; a != nil {
    73  		return a, nil
    74  	}
    75  	return nil, fmt.Errorf("unknown SimpleMethods id")
    76  }
    77  
    78  func (r *Root) DelayedMethods(id string) (*DelayedMethods, error) {
    79  	r.mu.Lock()
    80  	defer r.mu.Unlock()
    81  	if a := r.delayed[id]; a != nil {
    82  		return a, nil
    83  	}
    84  	return nil, fmt.Errorf("unknown DelayedMethods id")
    85  }
    86  
    87  func (r *Root) ErrorMethods(id string) (*ErrorMethods, error) {
    88  	if r.errorInst == nil {
    89  		return nil, fmt.Errorf("no error methods")
    90  	}
    91  	return r.errorInst, nil
    92  }
    93  
    94  func (r *Root) Discard1() {}
    95  
    96  func (r *Root) Discard2(id string) error { return nil }
    97  
    98  func (r *Root) Discard3(id string) int { return 0 }
    99  
   100  func (r *Root) CallbackMethods(string) (*CallbackMethods, error) {
   101  	return &CallbackMethods{r}, nil
   102  }
   103  
   104  func (r *Root) InterfaceMethods(id string) (InterfaceMethods, error) {
   105  	logger.Infof("interface methods called")
   106  	m, err := r.SimpleMethods(id)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	return m, nil
   111  }
   112  
   113  type InterfaceMethods interface {
   114  	Call1r1e(s stringVal) (stringVal, error)
   115  }
   116  
   117  type ChangeAPIMethods struct {
   118  	r *Root
   119  }
   120  
   121  func (r *Root) ChangeAPIMethods(string) (*ChangeAPIMethods, error) {
   122  	return &ChangeAPIMethods{r}, nil
   123  }
   124  
   125  func (t *Root) called(rcvr interface{}, method string, arg interface{}) {
   126  	t.mu.Lock()
   127  	t.calls = append(t.calls, &callInfo{rcvr, method, arg})
   128  	t.mu.Unlock()
   129  }
   130  
   131  type SimpleMethods struct {
   132  	root *Root
   133  	id   string
   134  }
   135  
   136  // Each Call method is named in this standard form:
   137  //
   138  //     Call<narg>r<nret><e>
   139  //
   140  // where narg is the number of arguments, nret is the number of returned
   141  // values (not including the error) and e is the letter 'e' if the
   142  // method returns an error.
   143  
   144  func (a *SimpleMethods) Call0r0() {
   145  	a.root.called(a, "Call0r0", nil)
   146  }
   147  
   148  func (a *SimpleMethods) Call0r1() stringVal {
   149  	a.root.called(a, "Call0r1", nil)
   150  	return stringVal{"Call0r1 ret"}
   151  }
   152  
   153  func (a *SimpleMethods) Call0r1e() (stringVal, error) {
   154  	a.root.called(a, "Call0r1e", nil)
   155  	return stringVal{"Call0r1e ret"}, a.root.callError(a, "Call0r1e", nil)
   156  }
   157  
   158  func (a *SimpleMethods) Call0r0e() error {
   159  	a.root.called(a, "Call0r0e", nil)
   160  	return a.root.callError(a, "Call0r0e", nil)
   161  }
   162  
   163  func (a *SimpleMethods) Call1r0(s stringVal) {
   164  	a.root.called(a, "Call1r0", s)
   165  }
   166  
   167  func (a *SimpleMethods) Call1r1(s stringVal) stringVal {
   168  	a.root.called(a, "Call1r1", s)
   169  	return stringVal{"Call1r1 ret"}
   170  }
   171  
   172  func (a *SimpleMethods) Call1r1e(s stringVal) (stringVal, error) {
   173  	a.root.called(a, "Call1r1e", s)
   174  	return stringVal{"Call1r1e ret"}, a.root.callError(a, "Call1r1e", s)
   175  }
   176  
   177  func (a *SimpleMethods) Call1r0e(s stringVal) error {
   178  	a.root.called(a, "Call1r0e", s)
   179  	return a.root.callError(a, "Call1r0e", s)
   180  }
   181  
   182  func (a *SimpleMethods) SliceArg(struct{ X []string }) stringVal {
   183  	return stringVal{"SliceArg ret"}
   184  }
   185  
   186  func (a *SimpleMethods) Discard1(int) {}
   187  
   188  func (a *SimpleMethods) Discard2(struct{}, struct{}) {}
   189  
   190  func (a *SimpleMethods) Discard3() int { return 0 }
   191  
   192  func (a *SimpleMethods) Discard4() (_, _ struct{}) { return }
   193  
   194  type DelayedMethods struct {
   195  	ready     chan struct{}
   196  	done      chan string
   197  	doneError chan error
   198  }
   199  
   200  func (a *DelayedMethods) Delay() (stringVal, error) {
   201  	if a.ready != nil {
   202  		a.ready <- struct{}{}
   203  	}
   204  	select {
   205  	case s := <-a.done:
   206  		return stringVal{s}, nil
   207  	case err := <-a.doneError:
   208  		return stringVal{}, err
   209  	}
   210  }
   211  
   212  type ErrorMethods struct {
   213  	err error
   214  }
   215  
   216  func (e *ErrorMethods) Call() error {
   217  	return e.err
   218  }
   219  
   220  type CallbackMethods struct {
   221  	root *Root
   222  }
   223  
   224  type int64val struct {
   225  	I int64
   226  }
   227  
   228  func (a *CallbackMethods) Factorial(x int64val) (int64val, error) {
   229  	if x.I <= 1 {
   230  		return int64val{1}, nil
   231  	}
   232  	var r int64val
   233  	err := a.root.conn.Call(rpc.Request{"CallbackMethods", "", "Factorial"}, int64val{x.I - 1}, &r)
   234  	if err != nil {
   235  		return int64val{}, err
   236  	}
   237  	return int64val{x.I * r.I}, nil
   238  }
   239  
   240  func (a *ChangeAPIMethods) ChangeAPI() {
   241  	a.r.conn.Serve(&changedAPIRoot{}, nil)
   242  }
   243  
   244  func (a *ChangeAPIMethods) RemoveAPI() {
   245  	a.r.conn.Serve(nil, nil)
   246  }
   247  
   248  type changedAPIRoot struct{}
   249  
   250  func (r *changedAPIRoot) NewlyAvailable(string) (newlyAvailableMethods, error) {
   251  	return newlyAvailableMethods{}, nil
   252  }
   253  
   254  type newlyAvailableMethods struct{}
   255  
   256  func (newlyAvailableMethods) NewMethod() stringVal {
   257  	return stringVal{"new method result"}
   258  }
   259  
   260  func (*rpcSuite) TestRPC(c *gc.C) {
   261  	root := &Root{
   262  		simple: make(map[string]*SimpleMethods),
   263  	}
   264  	root.simple["a99"] = &SimpleMethods{root: root, id: "a99"}
   265  	client, srvDone, clientNotifier, serverNotifier := newRPCClientServer(c, root, nil, false)
   266  	defer closeClient(c, client, srvDone)
   267  	for narg := 0; narg < 2; narg++ {
   268  		for nret := 0; nret < 2; nret++ {
   269  			for nerr := 0; nerr < 2; nerr++ {
   270  				retErr := nerr != 0
   271  				p := testCallParams{
   272  					client:         client,
   273  					clientNotifier: clientNotifier,
   274  					serverNotifier: serverNotifier,
   275  					entry:          "SimpleMethods",
   276  					narg:           narg,
   277  					nret:           nret,
   278  					retErr:         retErr,
   279  					testErr:        false,
   280  				}
   281  				root.testCall(c, p)
   282  				if retErr {
   283  					p.testErr = true
   284  					root.testCall(c, p)
   285  				}
   286  			}
   287  		}
   288  	}
   289  }
   290  
   291  func callName(narg, nret int, retErr bool) string {
   292  	e := ""
   293  	if retErr {
   294  		e = "e"
   295  	}
   296  	return fmt.Sprintf("Call%dr%d%s", narg, nret, e)
   297  }
   298  
   299  type testCallParams struct {
   300  	// client holds the client-side of the rpc connection that
   301  	// will be used to make the call.
   302  	client *rpc.Conn
   303  
   304  	// clientNotifier holds the notifier for the client side.
   305  	clientNotifier *notifier
   306  
   307  	// serverNotifier holds the notifier for the server side.
   308  	serverNotifier *notifier
   309  
   310  	// entry holds the top-level type that will be invoked
   311  	// (e.g. "SimpleMethods")
   312  	entry string
   313  
   314  	// narg holds the number of arguments accepted by the
   315  	// call (0 or 1)
   316  	narg int
   317  
   318  	// nret holds the number of values returned by the
   319  	// call (0 or 1).
   320  	nret int
   321  
   322  	// retErr specifies whether the call returns an error.
   323  	retErr bool
   324  
   325  	// testErr specifies whether the call should be made to return an error.
   326  	testErr bool
   327  }
   328  
   329  // request returns the RPC request for the test call.
   330  func (p testCallParams) request() rpc.Request {
   331  	return rpc.Request{
   332  		Type:   p.entry,
   333  		Id:     "a99",
   334  		Action: callName(p.narg, p.nret, p.retErr),
   335  	}
   336  }
   337  
   338  // error message returns the error message that the test call
   339  // should return if it returns an error.
   340  func (p testCallParams) errorMessage() string {
   341  	return fmt.Sprintf("error calling %s", p.request().Action)
   342  }
   343  
   344  func (root *Root) testCall(c *gc.C, p testCallParams) {
   345  	p.clientNotifier.reset()
   346  	p.serverNotifier.reset()
   347  	root.calls = nil
   348  	root.returnErr = p.testErr
   349  	c.Logf("test call %s", p.request().Action)
   350  	var r stringVal
   351  	err := p.client.Call(p.request(), stringVal{"arg"}, &r)
   352  	switch {
   353  	case p.retErr && p.testErr:
   354  		c.Assert(err, gc.DeepEquals, &rpc.RequestError{
   355  			Message: p.errorMessage(),
   356  		})
   357  		c.Assert(r, gc.Equals, stringVal{})
   358  	case p.nret > 0:
   359  		c.Assert(r, gc.Equals, stringVal{p.request().Action + " ret"})
   360  	}
   361  
   362  	// Check that the call was actually made, the right
   363  	// parameters were received and the right result returned.
   364  	root.mu.Lock()
   365  	defer root.mu.Unlock()
   366  
   367  	root.assertCallMade(c, p)
   368  
   369  	requestId := root.assertClientNotified(c, p, &r)
   370  
   371  	root.assertServerNotified(c, p, requestId)
   372  }
   373  
   374  func (root *Root) assertCallMade(c *gc.C, p testCallParams) {
   375  	expectCall := callInfo{
   376  		rcvr:   root.simple["a99"],
   377  		method: p.request().Action,
   378  	}
   379  	if p.narg > 0 {
   380  		expectCall.arg = stringVal{"arg"}
   381  	}
   382  	c.Assert(root.calls, gc.HasLen, 1)
   383  	c.Assert(*root.calls[0], gc.Equals, expectCall)
   384  }
   385  
   386  // assertClientNotified asserts that the right client notifications
   387  // were made for the given test call parameters. The value of r
   388  // holds the result parameter passed to the call.
   389  // It returns the request id.
   390  func (root *Root) assertClientNotified(c *gc.C, p testCallParams, r interface{}) uint64 {
   391  	c.Assert(p.clientNotifier.serverRequests, gc.HasLen, 0)
   392  	c.Assert(p.clientNotifier.serverReplies, gc.HasLen, 0)
   393  
   394  	// Test that there was a notification for the request.
   395  	c.Assert(p.clientNotifier.clientRequests, gc.HasLen, 1)
   396  	clientReq := p.clientNotifier.clientRequests[0]
   397  	requestId := clientReq.hdr.RequestId
   398  	clientReq.hdr.RequestId = 0 // Ignore the exact value of the request id to start with.
   399  	c.Assert(clientReq.hdr, gc.DeepEquals, rpc.Header{
   400  		Request: p.request(),
   401  	})
   402  	c.Assert(clientReq.body, gc.Equals, stringVal{"arg"})
   403  
   404  	// Test that there was a notification for the reply.
   405  	c.Assert(p.clientNotifier.clientReplies, gc.HasLen, 1)
   406  	clientReply := p.clientNotifier.clientReplies[0]
   407  	c.Assert(clientReply.req, gc.Equals, p.request())
   408  	if p.retErr && p.testErr {
   409  		c.Assert(clientReply.body, gc.Equals, nil)
   410  	} else {
   411  		c.Assert(clientReply.body, gc.Equals, r)
   412  	}
   413  	if p.retErr && p.testErr {
   414  		c.Assert(clientReply.hdr, gc.DeepEquals, rpc.Header{
   415  			RequestId: requestId,
   416  			Error:     p.errorMessage(),
   417  		})
   418  	} else {
   419  		c.Assert(clientReply.hdr, gc.DeepEquals, rpc.Header{
   420  			RequestId: requestId,
   421  		})
   422  	}
   423  	return requestId
   424  }
   425  
   426  // assertServerNotified asserts that the right server notifications
   427  // were made for the given test call parameters. The id of the request
   428  // is held in requestId.
   429  func (root *Root) assertServerNotified(c *gc.C, p testCallParams, requestId uint64) {
   430  	// Check that the right server notifications were made.
   431  	c.Assert(p.serverNotifier.clientRequests, gc.HasLen, 0)
   432  	c.Assert(p.serverNotifier.clientReplies, gc.HasLen, 0)
   433  
   434  	// Test that there was a notification for the request.
   435  	c.Assert(p.serverNotifier.serverRequests, gc.HasLen, 1)
   436  	serverReq := p.serverNotifier.serverRequests[0]
   437  	c.Assert(serverReq.hdr, gc.DeepEquals, rpc.Header{
   438  		RequestId: requestId,
   439  		Request:   p.request(),
   440  	})
   441  	if p.narg > 0 {
   442  		c.Assert(serverReq.body, gc.Equals, stringVal{"arg"})
   443  	} else {
   444  		c.Assert(serverReq.body, gc.Equals, struct{}{})
   445  	}
   446  
   447  	// Test that there was a notification for the reply.
   448  	c.Assert(p.serverNotifier.serverReplies, gc.HasLen, 1)
   449  	serverReply := p.serverNotifier.serverReplies[0]
   450  	c.Assert(serverReply.req, gc.Equals, p.request())
   451  	if p.retErr && p.testErr || p.nret == 0 {
   452  		c.Assert(serverReply.body, gc.Equals, struct{}{})
   453  	} else {
   454  		c.Assert(serverReply.body, gc.Equals, stringVal{p.request().Action + " ret"})
   455  	}
   456  	if p.retErr && p.testErr {
   457  		c.Assert(serverReply.hdr, gc.Equals, rpc.Header{
   458  			RequestId: requestId,
   459  			Error:     p.errorMessage(),
   460  		})
   461  	} else {
   462  		c.Assert(serverReply.hdr, gc.Equals, rpc.Header{
   463  			RequestId: requestId,
   464  		})
   465  	}
   466  }
   467  
   468  func (*rpcSuite) TestInterfaceMethods(c *gc.C) {
   469  	root := &Root{
   470  		simple: make(map[string]*SimpleMethods),
   471  	}
   472  	root.simple["a99"] = &SimpleMethods{root: root, id: "a99"}
   473  	client, srvDone, clientNotifier, serverNotifier := newRPCClientServer(c, root, nil, false)
   474  	defer closeClient(c, client, srvDone)
   475  	p := testCallParams{
   476  		client:         client,
   477  		clientNotifier: clientNotifier,
   478  		serverNotifier: serverNotifier,
   479  		entry:          "InterfaceMethods",
   480  		narg:           1,
   481  		nret:           1,
   482  		retErr:         true,
   483  		testErr:        false,
   484  	}
   485  
   486  	root.testCall(c, p)
   487  	p.testErr = true
   488  	root.testCall(c, p)
   489  }
   490  
   491  func (*rpcSuite) TestConcurrentCalls(c *gc.C) {
   492  	start1 := make(chan string)
   493  	start2 := make(chan string)
   494  	ready1 := make(chan struct{})
   495  	ready2 := make(chan struct{})
   496  
   497  	root := &Root{
   498  		delayed: map[string]*DelayedMethods{
   499  			"1": {ready: ready1, done: start1},
   500  			"2": {ready: ready2, done: start2},
   501  		},
   502  	}
   503  
   504  	client, srvDone, _, _ := newRPCClientServer(c, root, nil, false)
   505  	defer closeClient(c, client, srvDone)
   506  	call := func(id string, done chan<- struct{}) {
   507  		var r stringVal
   508  		err := client.Call(rpc.Request{"DelayedMethods", id, "Delay"}, nil, &r)
   509  		c.Check(err, gc.IsNil)
   510  		c.Check(r.Val, gc.Equals, "return "+id)
   511  		done <- struct{}{}
   512  	}
   513  	done1 := make(chan struct{})
   514  	done2 := make(chan struct{})
   515  	go call("1", done1)
   516  	go call("2", done2)
   517  
   518  	// Check that both calls are running concurrently.
   519  	chanRead(c, ready1, "method 1 ready")
   520  	chanRead(c, ready2, "method 2 ready")
   521  
   522  	// Let the requests complete.
   523  	start1 <- "return 1"
   524  	start2 <- "return 2"
   525  	chanRead(c, done1, "method 1 done")
   526  	chanRead(c, done2, "method 2 done")
   527  }
   528  
   529  type codedError struct {
   530  	m    string
   531  	code string
   532  }
   533  
   534  func (e *codedError) Error() string {
   535  	return e.m
   536  }
   537  
   538  func (e *codedError) ErrorCode() string {
   539  	return e.code
   540  }
   541  
   542  func (*rpcSuite) TestErrorCode(c *gc.C) {
   543  	root := &Root{
   544  		errorInst: &ErrorMethods{&codedError{"message", "code"}},
   545  	}
   546  	client, srvDone, _, _ := newRPCClientServer(c, root, nil, false)
   547  	defer closeClient(c, client, srvDone)
   548  	err := client.Call(rpc.Request{"ErrorMethods", "", "Call"}, nil, nil)
   549  	c.Assert(err, gc.ErrorMatches, `request error: message \(code\)`)
   550  	c.Assert(err.(rpc.ErrorCoder).ErrorCode(), gc.Equals, "code")
   551  }
   552  
   553  func (*rpcSuite) TestTransformErrors(c *gc.C) {
   554  	root := &Root{
   555  		errorInst: &ErrorMethods{&codedError{"message", "code"}},
   556  	}
   557  	tfErr := func(err error) error {
   558  		c.Check(err, gc.NotNil)
   559  		if e, ok := err.(*codedError); ok {
   560  			return &codedError{
   561  				m:    "transformed: " + e.m,
   562  				code: "transformed: " + e.code,
   563  			}
   564  		}
   565  		return fmt.Errorf("transformed: %v", err)
   566  	}
   567  	client, srvDone, _, _ := newRPCClientServer(c, root, tfErr, false)
   568  	defer closeClient(c, client, srvDone)
   569  	err := client.Call(rpc.Request{"ErrorMethods", "", "Call"}, nil, nil)
   570  	c.Assert(err, gc.DeepEquals, &rpc.RequestError{
   571  		Message: "transformed: message",
   572  		Code:    "transformed: code",
   573  	})
   574  
   575  	root.errorInst.err = nil
   576  	err = client.Call(rpc.Request{"ErrorMethods", "", "Call"}, nil, nil)
   577  	c.Assert(err, gc.IsNil)
   578  
   579  	root.errorInst = nil
   580  	err = client.Call(rpc.Request{"ErrorMethods", "", "Call"}, nil, nil)
   581  	c.Assert(err, gc.DeepEquals, &rpc.RequestError{
   582  		Message: "transformed: no error methods",
   583  	})
   584  
   585  }
   586  
   587  func (*rpcSuite) TestServerWaitsForOutstandingCalls(c *gc.C) {
   588  	ready := make(chan struct{})
   589  	start := make(chan string)
   590  	root := &Root{
   591  		delayed: map[string]*DelayedMethods{
   592  			"1": {
   593  				ready: ready,
   594  				done:  start,
   595  			},
   596  		},
   597  	}
   598  	client, srvDone, _, _ := newRPCClientServer(c, root, nil, false)
   599  	defer closeClient(c, client, srvDone)
   600  	done := make(chan struct{})
   601  	go func() {
   602  		var r stringVal
   603  		err := client.Call(rpc.Request{"DelayedMethods", "1", "Delay"}, nil, &r)
   604  		c.Check(err, gc.Equals, rpc.ErrShutdown)
   605  		done <- struct{}{}
   606  	}()
   607  	chanRead(c, ready, "DelayedMethods.Delay ready")
   608  	client.Close()
   609  	select {
   610  	case err := <-srvDone:
   611  		c.Fatalf("server returned while outstanding operation in progress: %v", err)
   612  		<-done
   613  	case <-time.After(25 * time.Millisecond):
   614  	}
   615  	start <- "xxx"
   616  }
   617  
   618  func chanRead(c *gc.C, ch <-chan struct{}, what string) {
   619  	select {
   620  	case <-ch:
   621  		return
   622  	case <-time.After(3 * time.Second):
   623  		c.Fatalf("timeout on channel read %s", what)
   624  	}
   625  }
   626  
   627  func (*rpcSuite) TestCompatibility(c *gc.C) {
   628  	root := &Root{
   629  		simple: make(map[string]*SimpleMethods),
   630  	}
   631  	a0 := &SimpleMethods{root: root, id: "a0"}
   632  	root.simple["a0"] = a0
   633  
   634  	client, srvDone, _, _ := newRPCClientServer(c, root, nil, false)
   635  	defer closeClient(c, client, srvDone)
   636  	call := func(method string, arg, ret interface{}) (passedArg interface{}) {
   637  		root.calls = nil
   638  		err := client.Call(rpc.Request{"SimpleMethods", "a0", method}, arg, ret)
   639  		c.Assert(err, gc.IsNil)
   640  		c.Assert(root.calls, gc.HasLen, 1)
   641  		info := root.calls[0]
   642  		c.Assert(info.rcvr, gc.Equals, a0)
   643  		c.Assert(info.method, gc.Equals, method)
   644  		return info.arg
   645  	}
   646  	type extra struct {
   647  		Val   string
   648  		Extra string
   649  	}
   650  	// Extra fields in request and response.
   651  	var r extra
   652  	arg := call("Call1r1", extra{"x", "y"}, &r)
   653  	c.Assert(arg, gc.Equals, stringVal{"x"})
   654  
   655  	// Nil argument as request.
   656  	r = extra{}
   657  	arg = call("Call1r1", nil, &r)
   658  	c.Assert(arg, gc.Equals, stringVal{})
   659  
   660  	// Nil argument as response.
   661  	arg = call("Call1r1", stringVal{"x"}, nil)
   662  	c.Assert(arg, gc.Equals, stringVal{"x"})
   663  
   664  	// Non-nil argument for no response.
   665  	r = extra{}
   666  	arg = call("Call1r0", stringVal{"x"}, &r)
   667  	c.Assert(arg, gc.Equals, stringVal{"x"})
   668  	c.Assert(r, gc.Equals, extra{})
   669  }
   670  
   671  func (*rpcSuite) TestBadCall(c *gc.C) {
   672  	root := &Root{
   673  		simple: make(map[string]*SimpleMethods),
   674  	}
   675  	a0 := &SimpleMethods{root: root, id: "a0"}
   676  	root.simple["a0"] = a0
   677  	client, srvDone, clientNotifier, serverNotifier := newRPCClientServer(c, root, nil, false)
   678  	defer closeClient(c, client, srvDone)
   679  
   680  	testBadCall(c, client, clientNotifier, serverNotifier,
   681  		rpc.Request{"BadSomething", "a0", "No"},
   682  		`unknown object type "BadSomething"`,
   683  		rpc.CodeNotImplemented,
   684  		false,
   685  	)
   686  	testBadCall(c, client, clientNotifier, serverNotifier,
   687  		rpc.Request{"SimpleMethods", "xx", "No"},
   688  		"no such request - method SimpleMethods.No is not implemented",
   689  		rpc.CodeNotImplemented,
   690  		false,
   691  	)
   692  	testBadCall(c, client, clientNotifier, serverNotifier,
   693  		rpc.Request{"SimpleMethods", "xx", "Call0r0"},
   694  		`unknown SimpleMethods id`,
   695  		"",
   696  		true,
   697  	)
   698  }
   699  
   700  func testBadCall(
   701  	c *gc.C,
   702  	client *rpc.Conn,
   703  	clientNotifier, serverNotifier *notifier,
   704  	req rpc.Request,
   705  	expectedErr string,
   706  	expectedErrCode string,
   707  	requestKnown bool,
   708  ) {
   709  	clientNotifier.reset()
   710  	serverNotifier.reset()
   711  	err := client.Call(req, nil, nil)
   712  	msg := expectedErr
   713  	if expectedErrCode != "" {
   714  		msg += " (" + expectedErrCode + ")"
   715  	}
   716  	c.Assert(err, gc.ErrorMatches, regexp.QuoteMeta("request error: "+msg))
   717  
   718  	// Test that there was a notification for the client request.
   719  	c.Assert(clientNotifier.clientRequests, gc.HasLen, 1)
   720  	clientReq := clientNotifier.clientRequests[0]
   721  	requestId := clientReq.hdr.RequestId
   722  	c.Assert(clientReq, gc.DeepEquals, requestEvent{
   723  		hdr: rpc.Header{
   724  			RequestId: requestId,
   725  			Request:   req,
   726  		},
   727  		body: struct{}{},
   728  	})
   729  	// Test that there was a notification for the client reply.
   730  	c.Assert(clientNotifier.clientReplies, gc.HasLen, 1)
   731  	clientReply := clientNotifier.clientReplies[0]
   732  	c.Assert(clientReply, gc.DeepEquals, replyEvent{
   733  		req: req,
   734  		hdr: rpc.Header{
   735  			RequestId: requestId,
   736  			Error:     expectedErr,
   737  			ErrorCode: expectedErrCode,
   738  		},
   739  	})
   740  
   741  	// Test that there was a notification for the server request.
   742  	c.Assert(serverNotifier.serverRequests, gc.HasLen, 1)
   743  	serverReq := serverNotifier.serverRequests[0]
   744  
   745  	// From docs on ServerRequest:
   746  	// 	If the request was not recognized or there was
   747  	//	an error reading the body, body will be nil.
   748  	var expectBody interface{}
   749  	if requestKnown {
   750  		expectBody = struct{}{}
   751  	}
   752  	c.Assert(serverReq, gc.DeepEquals, requestEvent{
   753  		hdr: rpc.Header{
   754  			RequestId: requestId,
   755  			Request:   req,
   756  		},
   757  		body: expectBody,
   758  	})
   759  
   760  	// Test that there was a notification for the server reply.
   761  	c.Assert(serverNotifier.serverReplies, gc.HasLen, 1)
   762  	serverReply := serverNotifier.serverReplies[0]
   763  	c.Assert(serverReply, gc.DeepEquals, replyEvent{
   764  		hdr: rpc.Header{
   765  			RequestId: requestId,
   766  			Error:     expectedErr,
   767  			ErrorCode: expectedErrCode,
   768  		},
   769  		req:  req,
   770  		body: struct{}{},
   771  	})
   772  }
   773  
   774  func (*rpcSuite) TestContinueAfterReadBodyError(c *gc.C) {
   775  	root := &Root{
   776  		simple: make(map[string]*SimpleMethods),
   777  	}
   778  	a0 := &SimpleMethods{root: root, id: "a0"}
   779  	root.simple["a0"] = a0
   780  	client, srvDone, _, _ := newRPCClientServer(c, root, nil, false)
   781  	defer closeClient(c, client, srvDone)
   782  
   783  	var ret stringVal
   784  	arg0 := struct {
   785  		X map[string]int
   786  	}{
   787  		X: map[string]int{"hello": 65},
   788  	}
   789  	err := client.Call(rpc.Request{"SimpleMethods", "a0", "SliceArg"}, arg0, &ret)
   790  	c.Assert(err, gc.ErrorMatches, `request error: json: cannot unmarshal object into Go value of type \[\]string`)
   791  
   792  	err = client.Call(rpc.Request{"SimpleMethods", "a0", "SliceArg"}, arg0, &ret)
   793  	c.Assert(err, gc.ErrorMatches, `request error: json: cannot unmarshal object into Go value of type \[\]string`)
   794  
   795  	arg1 := struct {
   796  		X []string
   797  	}{
   798  		X: []string{"one"},
   799  	}
   800  	err = client.Call(rpc.Request{"SimpleMethods", "a0", "SliceArg"}, arg1, &ret)
   801  	c.Assert(err, gc.IsNil)
   802  	c.Assert(ret.Val, gc.Equals, "SliceArg ret")
   803  }
   804  
   805  func (*rpcSuite) TestErrorAfterClientClose(c *gc.C) {
   806  	client, srvDone, _, _ := newRPCClientServer(c, &Root{}, nil, false)
   807  	err := client.Close()
   808  	c.Assert(err, gc.IsNil)
   809  	err = client.Call(rpc.Request{"Foo", "", "Bar"}, nil, nil)
   810  	c.Assert(err, gc.Equals, rpc.ErrShutdown)
   811  	err = chanReadError(c, srvDone, "server done")
   812  	c.Assert(err, gc.IsNil)
   813  }
   814  
   815  func (*rpcSuite) TestClientCloseIdempotent(c *gc.C) {
   816  	client, _, _, _ := newRPCClientServer(c, &Root{}, nil, false)
   817  	err := client.Close()
   818  	c.Assert(err, gc.IsNil)
   819  	err = client.Close()
   820  	c.Assert(err, gc.IsNil)
   821  	err = client.Close()
   822  	c.Assert(err, gc.IsNil)
   823  }
   824  
   825  type KillerRoot struct {
   826  	killed bool
   827  	Root
   828  }
   829  
   830  func (r *KillerRoot) Kill() {
   831  	r.killed = true
   832  }
   833  
   834  func (*rpcSuite) TestRootIsKilled(c *gc.C) {
   835  	root := &KillerRoot{}
   836  	client, srvDone, _, _ := newRPCClientServer(c, root, nil, false)
   837  	err := client.Close()
   838  	c.Assert(err, gc.IsNil)
   839  	err = chanReadError(c, srvDone, "server done")
   840  	c.Assert(err, gc.IsNil)
   841  	c.Assert(root.killed, gc.Equals, true)
   842  }
   843  
   844  func (*rpcSuite) TestBidirectional(c *gc.C) {
   845  	srvRoot := &Root{}
   846  	client, srvDone, _, _ := newRPCClientServer(c, srvRoot, nil, true)
   847  	defer closeClient(c, client, srvDone)
   848  	clientRoot := &Root{conn: client}
   849  	client.Serve(clientRoot, nil)
   850  	var r int64val
   851  	err := client.Call(rpc.Request{"CallbackMethods", "", "Factorial"}, int64val{12}, &r)
   852  	c.Assert(err, gc.IsNil)
   853  	c.Assert(r.I, gc.Equals, int64(479001600))
   854  }
   855  
   856  func (*rpcSuite) TestServerRequestWhenNotServing(c *gc.C) {
   857  	srvRoot := &Root{}
   858  	client, srvDone, _, _ := newRPCClientServer(c, srvRoot, nil, true)
   859  	defer closeClient(c, client, srvDone)
   860  	var r int64val
   861  	err := client.Call(rpc.Request{"CallbackMethods", "", "Factorial"}, int64val{12}, &r)
   862  	c.Assert(err, gc.ErrorMatches, "request error: request error: no service")
   863  }
   864  
   865  func (*rpcSuite) TestChangeAPI(c *gc.C) {
   866  	srvRoot := &Root{}
   867  	client, srvDone, _, _ := newRPCClientServer(c, srvRoot, nil, true)
   868  	defer closeClient(c, client, srvDone)
   869  	var s stringVal
   870  	err := client.Call(rpc.Request{"NewlyAvailable", "", "NewMethod"}, nil, &s)
   871  	c.Assert(err, gc.ErrorMatches, `request error: unknown object type "NewlyAvailable" \(not implemented\)`)
   872  	err = client.Call(rpc.Request{"ChangeAPIMethods", "", "ChangeAPI"}, nil, nil)
   873  	c.Assert(err, gc.IsNil)
   874  	err = client.Call(rpc.Request{"ChangeAPIMethods", "", "ChangeAPI"}, nil, nil)
   875  	c.Assert(err, gc.ErrorMatches, `request error: unknown object type "ChangeAPIMethods" \(not implemented\)`)
   876  	err = client.Call(rpc.Request{"NewlyAvailable", "", "NewMethod"}, nil, &s)
   877  	c.Assert(err, gc.IsNil)
   878  	c.Assert(s, gc.Equals, stringVal{"new method result"})
   879  }
   880  
   881  func (*rpcSuite) TestChangeAPIToNil(c *gc.C) {
   882  	srvRoot := &Root{}
   883  	client, srvDone, _, _ := newRPCClientServer(c, srvRoot, nil, true)
   884  	defer closeClient(c, client, srvDone)
   885  
   886  	err := client.Call(rpc.Request{"ChangeAPIMethods", "", "RemoveAPI"}, nil, nil)
   887  	c.Assert(err, gc.IsNil)
   888  
   889  	err = client.Call(rpc.Request{"ChangeAPIMethods", "", "RemoveAPI"}, nil, nil)
   890  	c.Assert(err, gc.ErrorMatches, "request error: no service")
   891  }
   892  
   893  func (*rpcSuite) TestChangeAPIWhileServingRequest(c *gc.C) {
   894  	ready := make(chan struct{})
   895  	done := make(chan error)
   896  	srvRoot := &Root{
   897  		delayed: map[string]*DelayedMethods{
   898  			"1": {ready: ready, doneError: done},
   899  		},
   900  	}
   901  	transform := func(err error) error {
   902  		return fmt.Errorf("transformed: %v", err)
   903  	}
   904  	client, srvDone, _, _ := newRPCClientServer(c, srvRoot, transform, true)
   905  	defer closeClient(c, client, srvDone)
   906  
   907  	result := make(chan error)
   908  	go func() {
   909  		result <- client.Call(rpc.Request{"DelayedMethods", "1", "Delay"}, nil, nil)
   910  	}()
   911  	chanRead(c, ready, "method ready")
   912  
   913  	err := client.Call(rpc.Request{"ChangeAPIMethods", "", "ChangeAPI"}, nil, nil)
   914  	c.Assert(err, gc.IsNil)
   915  
   916  	// Ensure that not only does the request in progress complete,
   917  	// but that the original transformErrors function is called.
   918  	done <- fmt.Errorf("an error")
   919  	select {
   920  	case r := <-result:
   921  		c.Assert(r, gc.ErrorMatches, "request error: transformed: an error")
   922  	case <-time.After(3 * time.Second):
   923  		c.Fatalf("timeout on channel read")
   924  	}
   925  }
   926  
   927  func chanReadError(c *gc.C, ch <-chan error, what string) error {
   928  	select {
   929  	case e := <-ch:
   930  		return e
   931  	case <-time.After(3 * time.Second):
   932  		c.Fatalf("timeout on channel read %s", what)
   933  	}
   934  	panic("unreachable")
   935  }
   936  
   937  // newRPCClientServer starts an RPC server serving a connection from a
   938  // single client.  When the server has finished serving the connection,
   939  // it sends a value on the returned channel.
   940  // If bidir is true, requests can flow in both directions.
   941  func newRPCClientServer(c *gc.C, root interface{}, tfErr func(error) error, bidir bool) (client *rpc.Conn, srvDone chan error, clientNotifier, serverNotifier *notifier) {
   942  	l, err := net.Listen("tcp", "127.0.0.1:0")
   943  	c.Assert(err, gc.IsNil)
   944  
   945  	srvDone = make(chan error, 1)
   946  	clientNotifier = new(notifier)
   947  	serverNotifier = new(notifier)
   948  	go func() {
   949  		conn, err := l.Accept()
   950  		if err != nil {
   951  			srvDone <- nil
   952  			return
   953  		}
   954  		defer l.Close()
   955  		role := roleServer
   956  		if bidir {
   957  			role = roleBoth
   958  		}
   959  		rpcConn := rpc.NewConn(NewJSONCodec(conn, role), serverNotifier)
   960  		rpcConn.Serve(root, tfErr)
   961  		if root, ok := root.(*Root); ok {
   962  			root.conn = rpcConn
   963  		}
   964  		rpcConn.Start()
   965  		<-rpcConn.Dead()
   966  		srvDone <- rpcConn.Close()
   967  	}()
   968  	conn, err := net.Dial("tcp", l.Addr().String())
   969  	c.Assert(err, gc.IsNil)
   970  	role := roleClient
   971  	if bidir {
   972  		role = roleBoth
   973  	}
   974  	client = rpc.NewConn(NewJSONCodec(conn, role), clientNotifier)
   975  	client.Start()
   976  	return client, srvDone, clientNotifier, serverNotifier
   977  }
   978  
   979  func closeClient(c *gc.C, client *rpc.Conn, srvDone <-chan error) {
   980  	err := client.Close()
   981  	c.Assert(err, gc.IsNil)
   982  	err = chanReadError(c, srvDone, "server done")
   983  	c.Assert(err, gc.IsNil)
   984  }
   985  
   986  type encoder interface {
   987  	Encode(e interface{}) error
   988  }
   989  
   990  type decoder interface {
   991  	Decode(e interface{}) error
   992  }
   993  
// testCodec wraps an rpc.Codec with extra error checking code.
// Its methods panic when a message's direction does not match role,
// and log every message sent or received.
type testCodec struct {
	// role selects which message directions this codec accepts:
	// roleClient, roleServer, or roleBoth (no direction check).
	role connRole
	// Codec is the underlying codec that does the actual encoding.
	rpc.Codec
}
   999  
  1000  func (c *testCodec) WriteMessage(hdr *rpc.Header, x interface{}) error {
  1001  	if reflect.ValueOf(x).Kind() != reflect.Struct {
  1002  		panic(fmt.Errorf("WriteRequest bad param; want struct got %T (%#v)", x, x))
  1003  	}
  1004  	if c.role != roleBoth && hdr.IsRequest() != (c.role == roleClient) {
  1005  		panic(fmt.Errorf("codec role %v; header wrong type %#v", c.role, hdr))
  1006  	}
  1007  	logger.Infof("send header: %#v; body: %#v", hdr, x)
  1008  	return c.Codec.WriteMessage(hdr, x)
  1009  }
  1010  
  1011  func (c *testCodec) ReadHeader(hdr *rpc.Header) error {
  1012  	err := c.Codec.ReadHeader(hdr)
  1013  	if err != nil {
  1014  		return err
  1015  	}
  1016  	logger.Infof("got header %#v", hdr)
  1017  	if c.role != roleBoth && hdr.IsRequest() == (c.role == roleClient) {
  1018  		panic(fmt.Errorf("codec role %v; read wrong type %#v", c.role, hdr))
  1019  	}
  1020  	return nil
  1021  }
  1022  
  1023  func (c *testCodec) ReadBody(r interface{}, isRequest bool) error {
  1024  	if v := reflect.ValueOf(r); v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
  1025  		panic(fmt.Errorf("ReadResponseBody bad destination; want *struct got %T", r))
  1026  	}
  1027  	if c.role != roleBoth && isRequest == (c.role == roleClient) {
  1028  		panic(fmt.Errorf("codec role %v; read wrong body type %#v", c.role, r))
  1029  	}
  1030  	// Note: this will need to change if we want to test a non-JSON codec.
  1031  	var m json.RawMessage
  1032  	err := c.Codec.ReadBody(&m, isRequest)
  1033  	if err != nil {
  1034  		return err
  1035  	}
  1036  	logger.Infof("got response body: %q", m)
  1037  	err = json.Unmarshal(m, r)
  1038  	logger.Infof("unmarshalled into %#v", r)
  1039  	return err
  1040  }
  1041  
  1042  type connRole string
  1043  
  1044  const (
  1045  	roleBoth   connRole = "both"
  1046  	roleClient connRole = "client"
  1047  	roleServer connRole = "server"
  1048  )
  1049  
  1050  func NewJSONCodec(c net.Conn, role connRole) rpc.Codec {
  1051  	return &testCodec{
  1052  		role:  role,
  1053  		Codec: jsoncodec.NewNet(c),
  1054  	}
  1055  }
  1056  
// requestEvent is one request recorded by notifier.
type requestEvent struct {
	hdr  rpc.Header  // copy of the header passed to the notifier
	body interface{} // body value passed to the notifier
}
  1061  
// replyEvent is one reply recorded by notifier.
type replyEvent struct {
	req  rpc.Request // request value passed alongside the reply
	hdr  rpc.Header  // copy of the reply header
	body interface{} // reply body value
}
  1067  
// notifier records the events reported to it. newRPCClientServer passes
// one instance to rpc.NewConn for each side of the connection.
// The event slices are only accessed while holding mu.
type notifier struct {
	mu             sync.Mutex
	serverRequests []requestEvent
	serverReplies  []replyEvent
	clientRequests []requestEvent
	clientReplies  []replyEvent
}
  1075  
  1076  func (n *notifier) reset() {
  1077  	n.mu.Lock()
  1078  	defer n.mu.Unlock()
  1079  	n.serverRequests = nil
  1080  	n.serverReplies = nil
  1081  	n.clientRequests = nil
  1082  	n.clientReplies = nil
  1083  }
  1084  
  1085  func (n *notifier) ServerRequest(hdr *rpc.Header, body interface{}) {
  1086  	n.mu.Lock()
  1087  	defer n.mu.Unlock()
  1088  	n.serverRequests = append(n.serverRequests, requestEvent{
  1089  		hdr:  *hdr,
  1090  		body: body,
  1091  	})
  1092  }
  1093  
  1094  func (n *notifier) ServerReply(req rpc.Request, hdr *rpc.Header, body interface{}, timeSpent time.Duration) {
  1095  	n.mu.Lock()
  1096  	defer n.mu.Unlock()
  1097  	n.serverReplies = append(n.serverReplies, replyEvent{
  1098  		req:  req,
  1099  		hdr:  *hdr,
  1100  		body: body,
  1101  	})
  1102  }
  1103  
  1104  func (n *notifier) ClientRequest(hdr *rpc.Header, body interface{}) {
  1105  	n.mu.Lock()
  1106  	defer n.mu.Unlock()
  1107  	n.clientRequests = append(n.clientRequests, requestEvent{
  1108  		hdr:  *hdr,
  1109  		body: body,
  1110  	})
  1111  }
  1112  
  1113  func (n *notifier) ClientReply(req rpc.Request, hdr *rpc.Header, body interface{}) {
  1114  	n.mu.Lock()
  1115  	defer n.mu.Unlock()
  1116  	n.clientReplies = append(n.clientReplies, replyEvent{
  1117  		req:  req,
  1118  		hdr:  *hdr,
  1119  		body: body,
  1120  	})
  1121  }