github.com/rogpeppe/juju@v0.0.0-20140613142852-6337964b789e/rpc/rpc_test.go (about) 1 // Copyright 2012, 2013 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package rpc_test 5 6 import ( 7 "encoding/json" 8 "fmt" 9 "net" 10 "reflect" 11 "regexp" 12 "sync" 13 stdtesting "testing" 14 "time" 15 16 "github.com/juju/loggo" 17 gc "launchpad.net/gocheck" 18 19 "github.com/juju/juju/rpc" 20 "github.com/juju/juju/rpc/jsoncodec" 21 "github.com/juju/juju/testing" 22 ) 23 24 var logger = loggo.GetLogger("juju.rpc") 25 26 type rpcSuite struct { 27 testing.BaseSuite 28 } 29 30 var _ = gc.Suite(&rpcSuite{}) 31 32 func TestAll(t *stdtesting.T) { 33 gc.TestingT(t) 34 } 35 36 type callInfo struct { 37 rcvr interface{} 38 method string 39 arg interface{} 40 } 41 42 type callError callInfo 43 44 func (e *callError) Error() string { 45 return fmt.Sprintf("error calling %s", e.method) 46 } 47 48 type stringVal struct { 49 Val string 50 } 51 52 type Root struct { 53 mu sync.Mutex 54 conn *rpc.Conn 55 calls []*callInfo 56 returnErr bool 57 simple map[string]*SimpleMethods 58 delayed map[string]*DelayedMethods 59 errorInst *ErrorMethods 60 } 61 62 func (r *Root) callError(rcvr interface{}, name string, arg interface{}) error { 63 if r.returnErr { 64 return &callError{rcvr, name, arg} 65 } 66 return nil 67 } 68 69 func (r *Root) SimpleMethods(id string) (*SimpleMethods, error) { 70 r.mu.Lock() 71 defer r.mu.Unlock() 72 if a := r.simple[id]; a != nil { 73 return a, nil 74 } 75 return nil, fmt.Errorf("unknown SimpleMethods id") 76 } 77 78 func (r *Root) DelayedMethods(id string) (*DelayedMethods, error) { 79 r.mu.Lock() 80 defer r.mu.Unlock() 81 if a := r.delayed[id]; a != nil { 82 return a, nil 83 } 84 return nil, fmt.Errorf("unknown DelayedMethods id") 85 } 86 87 func (r *Root) ErrorMethods(id string) (*ErrorMethods, error) { 88 if r.errorInst == nil { 89 return nil, fmt.Errorf("no error methods") 90 } 91 return r.errorInst, nil 92 } 93 94 func (r *Root) Discard1() {} 95 96 func (r *Root) Discard2(id string) error { return nil } 97 98 func (r *Root) Discard3(id string) int { return 0 } 99 100 func (r *Root) CallbackMethods(string) (*CallbackMethods, error) { 101 return &CallbackMethods{r}, nil 102 } 103 104 func (r *Root) InterfaceMethods(id string) (InterfaceMethods, error) { 105 logger.Infof("interface methods called") 106 m, err := r.SimpleMethods(id) 107 if err != nil { 108 return nil, err 109 } 110 return m, nil 111 } 112 113 type InterfaceMethods interface { 114 Call1r1e(s stringVal) (stringVal, error) 115 } 116 117 type ChangeAPIMethods struct { 118 r *Root 119 } 120 121 func (r *Root) ChangeAPIMethods(string) (*ChangeAPIMethods, error) { 122 return &ChangeAPIMethods{r}, nil 123 } 124 125 func (t *Root) called(rcvr interface{}, method string, arg interface{}) { 126 t.mu.Lock() 127 t.calls = append(t.calls, &callInfo{rcvr, method, arg}) 128 t.mu.Unlock() 129 } 130 131 type SimpleMethods struct { 132 root *Root 133 id string 134 } 135 136 // Each Call method is named in this standard form: 137 // 138 // Call<narg>r<nret><e> 139 // 140 // where narg is the number of arguments, nret is the number of returned 141 // values (not including the error) and e is the letter 'e' if the 142 // method returns an error. 143 144 func (a *SimpleMethods) Call0r0() { 145 a.root.called(a, "Call0r0", nil) 146 } 147 148 func (a *SimpleMethods) Call0r1() stringVal { 149 a.root.called(a, "Call0r1", nil) 150 return stringVal{"Call0r1 ret"} 151 } 152 153 func (a *SimpleMethods) Call0r1e() (stringVal, error) { 154 a.root.called(a, "Call0r1e", nil) 155 return stringVal{"Call0r1e ret"}, a.root.callError(a, "Call0r1e", nil) 156 } 157 158 func (a *SimpleMethods) Call0r0e() error { 159 a.root.called(a, "Call0r0e", nil) 160 return a.root.callError(a, "Call0r0e", nil) 161 } 162 163 func (a *SimpleMethods) Call1r0(s stringVal) { 164 a.root.called(a, "Call1r0", s) 165 } 166 167 func (a *SimpleMethods) Call1r1(s stringVal) stringVal { 168 a.root.called(a, "Call1r1", s) 169 return stringVal{"Call1r1 ret"} 170 } 171 172 func (a *SimpleMethods) Call1r1e(s stringVal) (stringVal, error) { 173 a.root.called(a, "Call1r1e", s) 174 return stringVal{"Call1r1e ret"}, a.root.callError(a, "Call1r1e", s) 175 } 176 177 func (a *SimpleMethods) Call1r0e(s stringVal) error { 178 a.root.called(a, "Call1r0e", s) 179 return a.root.callError(a, "Call1r0e", s) 180 } 181 182 func (a *SimpleMethods) SliceArg(struct{ X []string }) stringVal { 183 return stringVal{"SliceArg ret"} 184 } 185 186 func (a *SimpleMethods) Discard1(int) {} 187 188 func (a *SimpleMethods) Discard2(struct{}, struct{}) {} 189 190 func (a *SimpleMethods) Discard3() int { return 0 } 191 192 func (a *SimpleMethods) Discard4() (_, _ struct{}) { return } 193 194 type DelayedMethods struct { 195 ready chan struct{} 196 done chan string 197 doneError chan error 198 } 199 200 func (a *DelayedMethods) Delay() (stringVal, error) { 201 if a.ready != nil { 202 a.ready <- struct{}{} 203 } 204 select { 205 case s := <-a.done: 206 return stringVal{s}, nil 207 case err := <-a.doneError: 208 return stringVal{}, err 209 } 210 } 211 212 type ErrorMethods struct { 213 err error 214 } 215 216 func (e *ErrorMethods) Call() error { 217 return e.err 218 } 219 220 type CallbackMethods struct { 221 root *Root 222 } 223 224 type int64val struct { 225 I int64 226 } 227 228 func (a *CallbackMethods) Factorial(x int64val) (int64val, error) { 229 if x.I <= 1 { 230 return int64val{1}, nil 231 } 232 var r int64val 233 err := a.root.conn.Call(rpc.Request{"CallbackMethods", "", "Factorial"}, int64val{x.I - 1}, &r) 234 if err != nil { 235 return int64val{}, err 236 } 237 return int64val{x.I * r.I}, nil 238 } 239 240 func (a *ChangeAPIMethods) ChangeAPI() { 241 a.r.conn.Serve(&changedAPIRoot{}, nil) 242 } 243 244 func (a *ChangeAPIMethods) RemoveAPI() { 245 a.r.conn.Serve(nil, nil) 246 } 247 248 type changedAPIRoot struct{} 249 250 func (r *changedAPIRoot) NewlyAvailable(string) (newlyAvailableMethods, error) { 251 return newlyAvailableMethods{}, nil 252 } 253 254 type newlyAvailableMethods struct{} 255 256 func (newlyAvailableMethods) NewMethod() stringVal { 257 return stringVal{"new method result"} 258 } 259 260 func (*rpcSuite) TestRPC(c *gc.C) { 261 root := &Root{ 262 simple: make(map[string]*SimpleMethods), 263 } 264 root.simple["a99"] = &SimpleMethods{root: root, id: "a99"} 265 client, srvDone, clientNotifier, serverNotifier := newRPCClientServer(c, root, nil, false) 266 defer closeClient(c, client, srvDone) 267 for narg := 0; narg < 2; narg++ { 268 for nret := 0; nret < 2; nret++ { 269 for nerr := 0; nerr < 2; nerr++ { 270 retErr := nerr != 0 271 p := testCallParams{ 272 client: client, 273 clientNotifier: clientNotifier, 274 serverNotifier: serverNotifier, 275 entry: "SimpleMethods", 276 narg: narg, 277 nret: nret, 278 retErr: retErr, 279 testErr: false, 280 } 281 root.testCall(c, p) 282 if retErr { 283 p.testErr = true 284 root.testCall(c, p) 285 } 286 } 287 } 288 } 289 } 290 291 func callName(narg, nret int, retErr bool) string { 292 e := "" 293 if retErr { 294 e = "e" 295 } 296 return fmt.Sprintf("Call%dr%d%s", narg, nret, e) 297 } 298 299 type testCallParams struct { 300 // client holds the client-side of the rpc connection that 301 // will be used to make the call. 302 client *rpc.Conn 303 304 // clientNotifier holds the notifier for the client side. 305 clientNotifier *notifier 306 307 // serverNotifier holds the notifier for the server side. 308 serverNotifier *notifier 309 310 // entry holds the top-level type that will be invoked 311 // (e.g. "SimpleMethods") 312 entry string 313 314 // narg holds the number of arguments accepted by the 315 // call (0 or 1) 316 narg int 317 318 // nret holds the number of values returned by the 319 // call (0 or 1). 320 nret int 321 322 // retErr specifies whether the call returns an error. 323 retErr bool 324 325 // testErr specifies whether the call should be made to return an error. 326 testErr bool 327 } 328 329 // request returns the RPC request for the test call. 330 func (p testCallParams) request() rpc.Request { 331 return rpc.Request{ 332 Type: p.entry, 333 Id: "a99", 334 Action: callName(p.narg, p.nret, p.retErr), 335 } 336 } 337 338 // error message returns the error message that the test call 339 // should return if it returns an error. 340 func (p testCallParams) errorMessage() string { 341 return fmt.Sprintf("error calling %s", p.request().Action) 342 } 343 344 func (root *Root) testCall(c *gc.C, p testCallParams) { 345 p.clientNotifier.reset() 346 p.serverNotifier.reset() 347 root.calls = nil 348 root.returnErr = p.testErr 349 c.Logf("test call %s", p.request().Action) 350 var r stringVal 351 err := p.client.Call(p.request(), stringVal{"arg"}, &r) 352 switch { 353 case p.retErr && p.testErr: 354 c.Assert(err, gc.DeepEquals, &rpc.RequestError{ 355 Message: p.errorMessage(), 356 }) 357 c.Assert(r, gc.Equals, stringVal{}) 358 case p.nret > 0: 359 c.Assert(r, gc.Equals, stringVal{p.request().Action + " ret"}) 360 } 361 362 // Check that the call was actually made, the right 363 // parameters were received and the right result returned. 364 root.mu.Lock() 365 defer root.mu.Unlock() 366 367 root.assertCallMade(c, p) 368 369 requestId := root.assertClientNotified(c, p, &r) 370 371 root.assertServerNotified(c, p, requestId) 372 } 373 374 func (root *Root) assertCallMade(c *gc.C, p testCallParams) { 375 expectCall := callInfo{ 376 rcvr: root.simple["a99"], 377 method: p.request().Action, 378 } 379 if p.narg > 0 { 380 expectCall.arg = stringVal{"arg"} 381 } 382 c.Assert(root.calls, gc.HasLen, 1) 383 c.Assert(*root.calls[0], gc.Equals, expectCall) 384 } 385 386 // assertClientNotified asserts that the right client notifications 387 // were made for the given test call parameters. The value of r 388 // holds the result parameter passed to the call. 389 // It returns the request id. 390 func (root *Root) assertClientNotified(c *gc.C, p testCallParams, r interface{}) uint64 { 391 c.Assert(p.clientNotifier.serverRequests, gc.HasLen, 0) 392 c.Assert(p.clientNotifier.serverReplies, gc.HasLen, 0) 393 394 // Test that there was a notification for the request. 395 c.Assert(p.clientNotifier.clientRequests, gc.HasLen, 1) 396 clientReq := p.clientNotifier.clientRequests[0] 397 requestId := clientReq.hdr.RequestId 398 clientReq.hdr.RequestId = 0 // Ignore the exact value of the request id to start with. 399 c.Assert(clientReq.hdr, gc.DeepEquals, rpc.Header{ 400 Request: p.request(), 401 }) 402 c.Assert(clientReq.body, gc.Equals, stringVal{"arg"}) 403 404 // Test that there was a notification for the reply. 405 c.Assert(p.clientNotifier.clientReplies, gc.HasLen, 1) 406 clientReply := p.clientNotifier.clientReplies[0] 407 c.Assert(clientReply.req, gc.Equals, p.request()) 408 if p.retErr && p.testErr { 409 c.Assert(clientReply.body, gc.Equals, nil) 410 } else { 411 c.Assert(clientReply.body, gc.Equals, r) 412 } 413 if p.retErr && p.testErr { 414 c.Assert(clientReply.hdr, gc.DeepEquals, rpc.Header{ 415 RequestId: requestId, 416 Error: p.errorMessage(), 417 }) 418 } else { 419 c.Assert(clientReply.hdr, gc.DeepEquals, rpc.Header{ 420 RequestId: requestId, 421 }) 422 } 423 return requestId 424 } 425 426 // assertServerNotified asserts that the right server notifications 427 // were made for the given test call parameters. The id of the request 428 // is held in requestId. 429 func (root *Root) assertServerNotified(c *gc.C, p testCallParams, requestId uint64) { 430 // Check that the right server notifications were made. 431 c.Assert(p.serverNotifier.clientRequests, gc.HasLen, 0) 432 c.Assert(p.serverNotifier.clientReplies, gc.HasLen, 0) 433 434 // Test that there was a notification for the request. 435 c.Assert(p.serverNotifier.serverRequests, gc.HasLen, 1) 436 serverReq := p.serverNotifier.serverRequests[0] 437 c.Assert(serverReq.hdr, gc.DeepEquals, rpc.Header{ 438 RequestId: requestId, 439 Request: p.request(), 440 }) 441 if p.narg > 0 { 442 c.Assert(serverReq.body, gc.Equals, stringVal{"arg"}) 443 } else { 444 c.Assert(serverReq.body, gc.Equals, struct{}{}) 445 } 446 447 // Test that there was a notification for the reply. 448 c.Assert(p.serverNotifier.serverReplies, gc.HasLen, 1) 449 serverReply := p.serverNotifier.serverReplies[0] 450 c.Assert(serverReply.req, gc.Equals, p.request()) 451 if p.retErr && p.testErr || p.nret == 0 { 452 c.Assert(serverReply.body, gc.Equals, struct{}{}) 453 } else { 454 c.Assert(serverReply.body, gc.Equals, stringVal{p.request().Action + " ret"}) 455 } 456 if p.retErr && p.testErr { 457 c.Assert(serverReply.hdr, gc.Equals, rpc.Header{ 458 RequestId: requestId, 459 Error: p.errorMessage(), 460 }) 461 } else { 462 c.Assert(serverReply.hdr, gc.Equals, rpc.Header{ 463 RequestId: requestId, 464 }) 465 } 466 } 467 468 func (*rpcSuite) TestInterfaceMethods(c *gc.C) { 469 root := &Root{ 470 simple: make(map[string]*SimpleMethods), 471 } 472 root.simple["a99"] = &SimpleMethods{root: root, id: "a99"} 473 client, srvDone, clientNotifier, serverNotifier := newRPCClientServer(c, root, nil, false) 474 defer closeClient(c, client, srvDone) 475 p := testCallParams{ 476 client: client, 477 clientNotifier: clientNotifier, 478 serverNotifier: serverNotifier, 479 entry: "InterfaceMethods", 480 narg: 1, 481 nret: 1, 482 retErr: true, 483 testErr: false, 484 } 485 486 root.testCall(c, p) 487 p.testErr = true 488 root.testCall(c, p) 489 } 490 491 func (*rpcSuite) TestConcurrentCalls(c *gc.C) { 492 start1 := make(chan string) 493 start2 := make(chan string) 494 ready1 := make(chan struct{}) 495 ready2 := make(chan struct{}) 496 497 root := &Root{ 498 delayed: map[string]*DelayedMethods{ 499 "1": {ready: ready1, done: start1}, 500 "2": {ready: ready2, done: start2}, 501 }, 502 } 503 504 client, srvDone, _, _ := newRPCClientServer(c, root, nil, false) 505 defer closeClient(c, client, srvDone) 506 call := func(id string, done chan<- struct{}) { 507 var r stringVal 508 err := client.Call(rpc.Request{"DelayedMethods", id, "Delay"}, nil, &r) 509 c.Check(err, gc.IsNil) 510 c.Check(r.Val, gc.Equals, "return "+id) 511 done <- struct{}{} 512 } 513 done1 := make(chan struct{}) 514 done2 := make(chan struct{}) 515 go call("1", done1) 516 go call("2", done2) 517 518 // Check that both calls are running concurrently. 519 chanRead(c, ready1, "method 1 ready") 520 chanRead(c, ready2, "method 2 ready") 521 522 // Let the requests complete. 523 start1 <- "return 1" 524 start2 <- "return 2" 525 chanRead(c, done1, "method 1 done") 526 chanRead(c, done2, "method 2 done") 527 } 528 529 type codedError struct { 530 m string 531 code string 532 } 533 534 func (e *codedError) Error() string { 535 return e.m 536 } 537 538 func (e *codedError) ErrorCode() string { 539 return e.code 540 } 541 542 func (*rpcSuite) TestErrorCode(c *gc.C) { 543 root := &Root{ 544 errorInst: &ErrorMethods{&codedError{"message", "code"}}, 545 } 546 client, srvDone, _, _ := newRPCClientServer(c, root, nil, false) 547 defer closeClient(c, client, srvDone) 548 err := client.Call(rpc.Request{"ErrorMethods", "", "Call"}, nil, nil) 549 c.Assert(err, gc.ErrorMatches, `request error: message \(code\)`) 550 c.Assert(err.(rpc.ErrorCoder).ErrorCode(), gc.Equals, "code") 551 } 552 553 func (*rpcSuite) TestTransformErrors(c *gc.C) { 554 root := &Root{ 555 errorInst: &ErrorMethods{&codedError{"message", "code"}}, 556 } 557 tfErr := func(err error) error { 558 c.Check(err, gc.NotNil) 559 if e, ok := err.(*codedError); ok { 560 return &codedError{ 561 m: "transformed: " + e.m, 562 code: "transformed: " + e.code, 563 } 564 } 565 return fmt.Errorf("transformed: %v", err) 566 } 567 client, srvDone, _, _ := newRPCClientServer(c, root, tfErr, false) 568 defer closeClient(c, client, srvDone) 569 err := client.Call(rpc.Request{"ErrorMethods", "", "Call"}, nil, nil) 570 c.Assert(err, gc.DeepEquals, &rpc.RequestError{ 571 Message: "transformed: message", 572 Code: "transformed: code", 573 }) 574 575 root.errorInst.err = nil 576 err = client.Call(rpc.Request{"ErrorMethods", "", "Call"}, nil, nil) 577 c.Assert(err, gc.IsNil) 578 579 root.errorInst = nil 580 err = client.Call(rpc.Request{"ErrorMethods", "", "Call"}, nil, nil) 581 c.Assert(err, gc.DeepEquals, &rpc.RequestError{ 582 Message: "transformed: no error methods", 583 }) 584 585 } 586 587 func (*rpcSuite) TestServerWaitsForOutstandingCalls(c *gc.C) { 588 ready := make(chan struct{}) 589 start := make(chan string) 590 root := &Root{ 591 delayed: map[string]*DelayedMethods{ 592 "1": { 593 ready: ready, 594 done: start, 595 }, 596 }, 597 } 598 client, srvDone, _, _ := newRPCClientServer(c, root, nil, false) 599 defer closeClient(c, client, srvDone) 600 done := make(chan struct{}) 601 go func() { 602 var r stringVal 603 err := client.Call(rpc.Request{"DelayedMethods", "1", "Delay"}, nil, &r) 604 c.Check(err, gc.Equals, rpc.ErrShutdown) 605 done <- struct{}{} 606 }() 607 chanRead(c, ready, "DelayedMethods.Delay ready") 608 client.Close() 609 select { 610 case err := <-srvDone: 611 c.Fatalf("server returned while outstanding operation in progress: %v", err) 612 <-done 613 case <-time.After(25 * time.Millisecond): 614 } 615 start <- "xxx" 616 } 617 618 func chanRead(c *gc.C, ch <-chan struct{}, what string) { 619 select { 620 case <-ch: 621 return 622 case <-time.After(3 * time.Second): 623 c.Fatalf("timeout on channel read %s", what) 624 } 625 } 626 627 func (*rpcSuite) TestCompatibility(c *gc.C) { 628 root := &Root{ 629 simple: make(map[string]*SimpleMethods), 630 } 631 a0 := &SimpleMethods{root: root, id: "a0"} 632 root.simple["a0"] = a0 633 634 client, srvDone, _, _ := newRPCClientServer(c, root, nil, false) 635 defer closeClient(c, client, srvDone) 636 call := func(method string, arg, ret interface{}) (passedArg interface{}) { 637 root.calls = nil 638 err := client.Call(rpc.Request{"SimpleMethods", "a0", method}, arg, ret) 639 c.Assert(err, gc.IsNil) 640 c.Assert(root.calls, gc.HasLen, 1) 641 info := root.calls[0] 642 c.Assert(info.rcvr, gc.Equals, a0) 643 c.Assert(info.method, gc.Equals, method) 644 return info.arg 645 } 646 type extra struct { 647 Val string 648 Extra string 649 } 650 // Extra fields in request and response. 651 var r extra 652 arg := call("Call1r1", extra{"x", "y"}, &r) 653 c.Assert(arg, gc.Equals, stringVal{"x"}) 654 655 // Nil argument as request. 656 r = extra{} 657 arg = call("Call1r1", nil, &r) 658 c.Assert(arg, gc.Equals, stringVal{}) 659 660 // Nil argument as response. 661 arg = call("Call1r1", stringVal{"x"}, nil) 662 c.Assert(arg, gc.Equals, stringVal{"x"}) 663 664 // Non-nil argument for no response. 665 r = extra{} 666 arg = call("Call1r0", stringVal{"x"}, &r) 667 c.Assert(arg, gc.Equals, stringVal{"x"}) 668 c.Assert(r, gc.Equals, extra{}) 669 } 670 671 func (*rpcSuite) TestBadCall(c *gc.C) { 672 root := &Root{ 673 simple: make(map[string]*SimpleMethods), 674 } 675 a0 := &SimpleMethods{root: root, id: "a0"} 676 root.simple["a0"] = a0 677 client, srvDone, clientNotifier, serverNotifier := newRPCClientServer(c, root, nil, false) 678 defer closeClient(c, client, srvDone) 679 680 testBadCall(c, client, clientNotifier, serverNotifier, 681 rpc.Request{"BadSomething", "a0", "No"}, 682 `unknown object type "BadSomething"`, 683 rpc.CodeNotImplemented, 684 false, 685 ) 686 testBadCall(c, client, clientNotifier, serverNotifier, 687 rpc.Request{"SimpleMethods", "xx", "No"}, 688 "no such request - method SimpleMethods.No is not implemented", 689 rpc.CodeNotImplemented, 690 false, 691 ) 692 testBadCall(c, client, clientNotifier, serverNotifier, 693 rpc.Request{"SimpleMethods", "xx", "Call0r0"}, 694 `unknown SimpleMethods id`, 695 "", 696 true, 697 ) 698 } 699 700 func testBadCall( 701 c *gc.C, 702 client *rpc.Conn, 703 clientNotifier, serverNotifier *notifier, 704 req rpc.Request, 705 expectedErr string, 706 expectedErrCode string, 707 requestKnown bool, 708 ) { 709 clientNotifier.reset() 710 serverNotifier.reset() 711 err := client.Call(req, nil, nil) 712 msg := expectedErr 713 if expectedErrCode != "" { 714 msg += " (" + expectedErrCode + ")" 715 } 716 c.Assert(err, gc.ErrorMatches, regexp.QuoteMeta("request error: "+msg)) 717 718 // Test that there was a notification for the client request. 719 c.Assert(clientNotifier.clientRequests, gc.HasLen, 1) 720 clientReq := clientNotifier.clientRequests[0] 721 requestId := clientReq.hdr.RequestId 722 c.Assert(clientReq, gc.DeepEquals, requestEvent{ 723 hdr: rpc.Header{ 724 RequestId: requestId, 725 Request: req, 726 }, 727 body: struct{}{}, 728 }) 729 // Test that there was a notification for the client reply. 730 c.Assert(clientNotifier.clientReplies, gc.HasLen, 1) 731 clientReply := clientNotifier.clientReplies[0] 732 c.Assert(clientReply, gc.DeepEquals, replyEvent{ 733 req: req, 734 hdr: rpc.Header{ 735 RequestId: requestId, 736 Error: expectedErr, 737 ErrorCode: expectedErrCode, 738 }, 739 }) 740 741 // Test that there was a notification for the server request. 742 c.Assert(serverNotifier.serverRequests, gc.HasLen, 1) 743 serverReq := serverNotifier.serverRequests[0] 744 745 // From docs on ServerRequest: 746 // If the request was not recognized or there was 747 // an error reading the body, body will be nil. 748 var expectBody interface{} 749 if requestKnown { 750 expectBody = struct{}{} 751 } 752 c.Assert(serverReq, gc.DeepEquals, requestEvent{ 753 hdr: rpc.Header{ 754 RequestId: requestId, 755 Request: req, 756 }, 757 body: expectBody, 758 }) 759 760 // Test that there was a notification for the server reply. 761 c.Assert(serverNotifier.serverReplies, gc.HasLen, 1) 762 serverReply := serverNotifier.serverReplies[0] 763 c.Assert(serverReply, gc.DeepEquals, replyEvent{ 764 hdr: rpc.Header{ 765 RequestId: requestId, 766 Error: expectedErr, 767 ErrorCode: expectedErrCode, 768 }, 769 req: req, 770 body: struct{}{}, 771 }) 772 } 773 774 func (*rpcSuite) TestContinueAfterReadBodyError(c *gc.C) { 775 root := &Root{ 776 simple: make(map[string]*SimpleMethods), 777 } 778 a0 := &SimpleMethods{root: root, id: "a0"} 779 root.simple["a0"] = a0 780 client, srvDone, _, _ := newRPCClientServer(c, root, nil, false) 781 defer closeClient(c, client, srvDone) 782 783 var ret stringVal 784 arg0 := struct { 785 X map[string]int 786 }{ 787 X: map[string]int{"hello": 65}, 788 } 789 err := client.Call(rpc.Request{"SimpleMethods", "a0", "SliceArg"}, arg0, &ret) 790 c.Assert(err, gc.ErrorMatches, `request error: json: cannot unmarshal object into Go value of type \[\]string`) 791 792 err = client.Call(rpc.Request{"SimpleMethods", "a0", "SliceArg"}, arg0, &ret) 793 c.Assert(err, gc.ErrorMatches, `request error: json: cannot unmarshal object into Go value of type \[\]string`) 794 795 arg1 := struct { 796 X []string 797 }{ 798 X: []string{"one"}, 799 } 800 err = client.Call(rpc.Request{"SimpleMethods", "a0", "SliceArg"}, arg1, &ret) 801 c.Assert(err, gc.IsNil) 802 c.Assert(ret.Val, gc.Equals, "SliceArg ret") 803 } 804 805 func (*rpcSuite) TestErrorAfterClientClose(c *gc.C) { 806 client, srvDone, _, _ := newRPCClientServer(c, &Root{}, nil, false) 807 err := client.Close() 808 c.Assert(err, gc.IsNil) 809 err = client.Call(rpc.Request{"Foo", "", "Bar"}, nil, nil) 810 c.Assert(err, gc.Equals, rpc.ErrShutdown) 811 err = chanReadError(c, srvDone, "server done") 812 c.Assert(err, gc.IsNil) 813 } 814 815 func (*rpcSuite) TestClientCloseIdempotent(c *gc.C) { 816 client, _, _, _ := newRPCClientServer(c, &Root{}, nil, false) 817 err := client.Close() 818 c.Assert(err, gc.IsNil) 819 err = client.Close() 820 c.Assert(err, gc.IsNil) 821 err = client.Close() 822 c.Assert(err, gc.IsNil) 823 } 824 825 type KillerRoot struct { 826 killed bool 827 Root 828 } 829 830 func (r *KillerRoot) Kill() { 831 r.killed = true 832 } 833 834 func (*rpcSuite) TestRootIsKilled(c *gc.C) { 835 root := &KillerRoot{} 836 client, srvDone, _, _ := newRPCClientServer(c, root, nil, false) 837 err := client.Close() 838 c.Assert(err, gc.IsNil) 839 err = chanReadError(c, srvDone, "server done") 840 c.Assert(err, gc.IsNil) 841 c.Assert(root.killed, gc.Equals, true) 842 } 843 844 func (*rpcSuite) TestBidirectional(c *gc.C) { 845 srvRoot := &Root{} 846 client, srvDone, _, _ := newRPCClientServer(c, srvRoot, nil, true) 847 defer closeClient(c, client, srvDone) 848 clientRoot := &Root{conn: client} 849 client.Serve(clientRoot, nil) 850 var r int64val 851 err := client.Call(rpc.Request{"CallbackMethods", "", "Factorial"}, int64val{12}, &r) 852 c.Assert(err, gc.IsNil) 853 c.Assert(r.I, gc.Equals, int64(479001600)) 854 } 855 856 func (*rpcSuite) TestServerRequestWhenNotServing(c *gc.C) { 857 srvRoot := &Root{} 858 client, srvDone, _, _ := newRPCClientServer(c, srvRoot, nil, true) 859 defer closeClient(c, client, srvDone) 860 var r int64val 861 err := client.Call(rpc.Request{"CallbackMethods", "", "Factorial"}, int64val{12}, &r) 862 c.Assert(err, gc.ErrorMatches, "request error: request error: no service") 863 } 864 865 func (*rpcSuite) TestChangeAPI(c *gc.C) { 866 srvRoot := &Root{} 867 client, srvDone, _, _ := newRPCClientServer(c, srvRoot, nil, true) 868 defer closeClient(c, client, srvDone) 869 var s stringVal 870 err := client.Call(rpc.Request{"NewlyAvailable", "", "NewMethod"}, nil, &s) 871 c.Assert(err, gc.ErrorMatches, `request error: unknown object type "NewlyAvailable" \(not implemented\)`) 872 err = client.Call(rpc.Request{"ChangeAPIMethods", "", "ChangeAPI"}, nil, nil) 873 c.Assert(err, gc.IsNil) 874 err = client.Call(rpc.Request{"ChangeAPIMethods", "", "ChangeAPI"}, nil, nil) 875 c.Assert(err, gc.ErrorMatches, `request error: unknown object type "ChangeAPIMethods" \(not implemented\)`) 876 err = client.Call(rpc.Request{"NewlyAvailable", "", "NewMethod"}, nil, &s) 877 c.Assert(err, gc.IsNil) 878 c.Assert(s, gc.Equals, stringVal{"new method result"}) 879 } 880 881 func (*rpcSuite) TestChangeAPIToNil(c *gc.C) { 882 srvRoot := &Root{} 883 client, srvDone, _, _ := newRPCClientServer(c, srvRoot, nil, true) 884 defer closeClient(c, client, srvDone) 885 886 err := client.Call(rpc.Request{"ChangeAPIMethods", "", "RemoveAPI"}, nil, nil) 887 c.Assert(err, gc.IsNil) 888 889 err = client.Call(rpc.Request{"ChangeAPIMethods", "", "RemoveAPI"}, nil, nil) 890 c.Assert(err, gc.ErrorMatches, "request error: no service") 891 } 892 893 func (*rpcSuite) TestChangeAPIWhileServingRequest(c *gc.C) { 894 ready := make(chan struct{}) 895 done := make(chan error) 896 srvRoot := &Root{ 897 delayed: map[string]*DelayedMethods{ 898 "1": {ready: ready, doneError: done}, 899 }, 900 } 901 transform := func(err error) error { 902 return fmt.Errorf("transformed: %v", err) 903 } 904 client, srvDone, _, _ := newRPCClientServer(c, srvRoot, transform, true) 905 defer closeClient(c, client, srvDone) 906 907 result := make(chan error) 908 go func() { 909 result <- client.Call(rpc.Request{"DelayedMethods", "1", "Delay"}, nil, nil) 910 }() 911 chanRead(c, ready, "method ready") 912 913 err := client.Call(rpc.Request{"ChangeAPIMethods", "", "ChangeAPI"}, nil, nil) 914 c.Assert(err, gc.IsNil) 915 916 // Ensure that not only does the request in progress complete, 917 // but that the original transformErrors function is called. 918 done <- fmt.Errorf("an error") 919 select { 920 case r := <-result: 921 c.Assert(r, gc.ErrorMatches, "request error: transformed: an error") 922 case <-time.After(3 * time.Second): 923 c.Fatalf("timeout on channel read") 924 } 925 } 926 927 func chanReadError(c *gc.C, ch <-chan error, what string) error { 928 select { 929 case e := <-ch: 930 return e 931 case <-time.After(3 * time.Second): 932 c.Fatalf("timeout on channel read %s", what) 933 } 934 panic("unreachable") 935 } 936 937 // newRPCClientServer starts an RPC server serving a connection from a 938 // single client. When the server has finished serving the connection, 939 // it sends a value on the returned channel. 940 // If bidir is true, requests can flow in both directions. 941 func newRPCClientServer(c *gc.C, root interface{}, tfErr func(error) error, bidir bool) (client *rpc.Conn, srvDone chan error, clientNotifier, serverNotifier *notifier) { 942 l, err := net.Listen("tcp", "127.0.0.1:0") 943 c.Assert(err, gc.IsNil) 944 945 srvDone = make(chan error, 1) 946 clientNotifier = new(notifier) 947 serverNotifier = new(notifier) 948 go func() { 949 conn, err := l.Accept() 950 if err != nil { 951 srvDone <- nil 952 return 953 } 954 defer l.Close() 955 role := roleServer 956 if bidir { 957 role = roleBoth 958 } 959 rpcConn := rpc.NewConn(NewJSONCodec(conn, role), serverNotifier) 960 rpcConn.Serve(root, tfErr) 961 if root, ok := root.(*Root); ok { 962 root.conn = rpcConn 963 } 964 rpcConn.Start() 965 <-rpcConn.Dead() 966 srvDone <- rpcConn.Close() 967 }() 968 conn, err := net.Dial("tcp", l.Addr().String()) 969 c.Assert(err, gc.IsNil) 970 role := roleClient 971 if bidir { 972 role = roleBoth 973 } 974 client = rpc.NewConn(NewJSONCodec(conn, role), clientNotifier) 975 client.Start() 976 return client, srvDone, clientNotifier, serverNotifier 977 } 978 979 func closeClient(c *gc.C, client *rpc.Conn, srvDone <-chan error) { 980 err := client.Close() 981 c.Assert(err, gc.IsNil) 982 err = chanReadError(c, srvDone, "server done") 983 c.Assert(err, gc.IsNil) 984 } 985 986 type encoder interface { 987 Encode(e interface{}) error 988 } 989 990 type decoder interface { 991 Decode(e interface{}) error 992 } 993 994 // testCodec wraps an rpc.Codec with extra error checking code. 995 type testCodec struct { 996 role connRole 997 rpc.Codec 998 } 999 1000 func (c *testCodec) WriteMessage(hdr *rpc.Header, x interface{}) error { 1001 if reflect.ValueOf(x).Kind() != reflect.Struct { 1002 panic(fmt.Errorf("WriteRequest bad param; want struct got %T (%#v)", x, x)) 1003 } 1004 if c.role != roleBoth && hdr.IsRequest() != (c.role == roleClient) { 1005 panic(fmt.Errorf("codec role %v; header wrong type %#v", c.role, hdr)) 1006 } 1007 logger.Infof("send header: %#v; body: %#v", hdr, x) 1008 return c.Codec.WriteMessage(hdr, x) 1009 } 1010 1011 func (c *testCodec) ReadHeader(hdr *rpc.Header) error { 1012 err := c.Codec.ReadHeader(hdr) 1013 if err != nil { 1014 return err 1015 } 1016 logger.Infof("got header %#v", hdr) 1017 if c.role != roleBoth && hdr.IsRequest() == (c.role == roleClient) { 1018 panic(fmt.Errorf("codec role %v; read wrong type %#v", c.role, hdr)) 1019 } 1020 return nil 1021 } 1022 1023 func (c *testCodec) ReadBody(r interface{}, isRequest bool) error { 1024 if v := reflect.ValueOf(r); v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { 1025 panic(fmt.Errorf("ReadResponseBody bad destination; want *struct got %T", r)) 1026 } 1027 if c.role != roleBoth && isRequest == (c.role == roleClient) { 1028 panic(fmt.Errorf("codec role %v; read wrong body type %#v", c.role, r)) 1029 } 1030 // Note: this will need to change if we want to test a non-JSON codec. 1031 var m json.RawMessage 1032 err := c.Codec.ReadBody(&m, isRequest) 1033 if err != nil { 1034 return err 1035 } 1036 logger.Infof("got response body: %q", m) 1037 err = json.Unmarshal(m, r) 1038 logger.Infof("unmarshalled into %#v", r) 1039 return err 1040 } 1041 1042 type connRole string 1043 1044 const ( 1045 roleBoth connRole = "both" 1046 roleClient connRole = "client" 1047 roleServer connRole = "server" 1048 ) 1049 1050 func NewJSONCodec(c net.Conn, role connRole) rpc.Codec { 1051 return &testCodec{ 1052 role: role, 1053 Codec: jsoncodec.NewNet(c), 1054 } 1055 } 1056 1057 type requestEvent struct { 1058 hdr rpc.Header 1059 body interface{} 1060 } 1061 1062 type replyEvent struct { 1063 req rpc.Request 1064 hdr rpc.Header 1065 body interface{} 1066 } 1067 1068 type notifier struct { 1069 mu sync.Mutex 1070 serverRequests []requestEvent 1071 serverReplies []replyEvent 1072 clientRequests []requestEvent 1073 clientReplies []replyEvent 1074 } 1075 1076 func (n *notifier) reset() { 1077 n.mu.Lock() 1078 defer n.mu.Unlock() 1079 n.serverRequests = nil 1080 n.serverReplies = nil 1081 n.clientRequests = nil 1082 n.clientReplies = nil 1083 } 1084 1085 func (n *notifier) ServerRequest(hdr *rpc.Header, body interface{}) { 1086 n.mu.Lock() 1087 defer n.mu.Unlock() 1088 n.serverRequests = append(n.serverRequests, requestEvent{ 1089 hdr: *hdr, 1090 body: body, 1091 }) 1092 } 1093 1094 func (n *notifier) ServerReply(req rpc.Request, hdr *rpc.Header, body interface{}, timeSpent time.Duration) { 1095 n.mu.Lock() 1096 defer n.mu.Unlock() 1097 n.serverReplies = append(n.serverReplies, replyEvent{ 1098 req: req, 1099 hdr: *hdr, 1100 body: body, 1101 }) 1102 } 1103 1104 func (n *notifier) ClientRequest(hdr *rpc.Header, body interface{}) { 1105 n.mu.Lock() 1106 defer n.mu.Unlock() 1107 n.clientRequests = append(n.clientRequests, requestEvent{ 1108 hdr: *hdr, 1109 body: body, 1110 }) 1111 } 1112 1113 func (n *notifier) ClientReply(req rpc.Request, hdr *rpc.Header, body interface{}) { 1114 n.mu.Lock() 1115 defer n.mu.Unlock() 1116 n.clientReplies = append(n.clientReplies, replyEvent{ 1117 req: req, 1118 hdr: *hdr, 1119 body: body, 1120 }) 1121 }