github.com/slayercat/go@v0.0.0-20170428012452-c51559813f61/src/net/rpc/server_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rpc 6 7 import ( 8 "errors" 9 "fmt" 10 "io" 11 "log" 12 "net" 13 "net/http/httptest" 14 "reflect" 15 "runtime" 16 "strings" 17 "sync" 18 "sync/atomic" 19 "testing" 20 "time" 21 ) 22 23 var ( 24 newServer *Server 25 serverAddr, newServerAddr string 26 httpServerAddr string 27 once, newOnce, httpOnce sync.Once 28 ) 29 30 const ( 31 newHttpPath = "/foo" 32 ) 33 34 type Args struct { 35 A, B int 36 } 37 38 type Reply struct { 39 C int 40 } 41 42 type Arith int 43 44 // Some of Arith's methods have value args, some have pointer args. That's deliberate. 45 46 func (t *Arith) Add(args Args, reply *Reply) error { 47 reply.C = args.A + args.B 48 return nil 49 } 50 51 func (t *Arith) Mul(args *Args, reply *Reply) error { 52 reply.C = args.A * args.B 53 return nil 54 } 55 56 func (t *Arith) Div(args Args, reply *Reply) error { 57 if args.B == 0 { 58 return errors.New("divide by zero") 59 } 60 reply.C = args.A / args.B 61 return nil 62 } 63 64 func (t *Arith) String(args *Args, reply *string) error { 65 *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 66 return nil 67 } 68 69 func (t *Arith) Scan(args string, reply *Reply) (err error) { 70 _, err = fmt.Sscan(args, &reply.C) 71 return 72 } 73 74 func (t *Arith) Error(args *Args, reply *Reply) error { 75 panic("ERROR") 76 } 77 78 type hidden int 79 80 func (t *hidden) Exported(args Args, reply *Reply) error { 81 reply.C = args.A + args.B 82 return nil 83 } 84 85 type Embed struct { 86 hidden 87 } 88 89 type BuiltinTypes struct{} 90 91 func (BuiltinTypes) Map(args *Args, reply *map[int]int) error { 92 (*reply)[args.A] = args.B 93 return nil 94 } 95 96 func (BuiltinTypes) Slice(args *Args, reply *[]int) error { 97 *reply = append(*reply, args.A, args.B) 98 return nil 99 } 100 101 func (BuiltinTypes) Array(args *Args, reply *[2]int) error { 102 (*reply)[0] = args.A 103 (*reply)[1] = args.B 104 return nil 105 } 106 107 func listenTCP() (net.Listener, string) { 108 l, e := net.Listen("tcp", "127.0.0.1:0") // any available address 109 if e != nil { 110 log.Fatalf("net.Listen tcp :0: %v", e) 111 } 112 return l, l.Addr().String() 113 } 114 115 func startServer() { 116 Register(new(Arith)) 117 Register(new(Embed)) 118 RegisterName("net.rpc.Arith", new(Arith)) 119 Register(BuiltinTypes{}) 120 121 var l net.Listener 122 l, serverAddr = listenTCP() 123 log.Println("Test RPC server listening on", serverAddr) 124 go Accept(l) 125 126 HandleHTTP() 127 httpOnce.Do(startHttpServer) 128 } 129 130 func startNewServer() { 131 newServer = NewServer() 132 newServer.Register(new(Arith)) 133 newServer.Register(new(Embed)) 134 newServer.RegisterName("net.rpc.Arith", new(Arith)) 135 newServer.RegisterName("newServer.Arith", new(Arith)) 136 137 var l net.Listener 138 l, newServerAddr = listenTCP() 139 log.Println("NewServer test RPC server listening on", newServerAddr) 140 go newServer.Accept(l) 141 142 newServer.HandleHTTP(newHttpPath, "/bar") 143 httpOnce.Do(startHttpServer) 144 } 145 146 func startHttpServer() { 147 server := httptest.NewServer(nil) 148 httpServerAddr = server.Listener.Addr().String() 149 log.Println("Test HTTP RPC server listening on", httpServerAddr) 150 } 151 152 func TestRPC(t *testing.T) { 153 once.Do(startServer) 154 testRPC(t, serverAddr) 155 newOnce.Do(startNewServer) 156 testRPC(t, newServerAddr) 157 testNewServerRPC(t, newServerAddr) 158 } 159 160 func testRPC(t *testing.T, addr string) { 161 client, err := Dial("tcp", addr) 162 if err != nil { 163 t.Fatal("dialing", err) 164 } 165 defer client.Close() 166 167 // Synchronous calls 168 args := &Args{7, 8} 169 reply := new(Reply) 170 err = client.Call("Arith.Add", args, reply) 171 if err != nil { 172 t.Errorf("Add: expected no error but got string %q", err.Error()) 173 } 174 if reply.C != args.A+args.B { 175 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 176 } 177 178 // Methods exported from unexported embedded structs 179 args = &Args{7, 0} 180 reply = new(Reply) 181 err = client.Call("Embed.Exported", args, reply) 182 if err != nil { 183 t.Errorf("Add: expected no error but got string %q", err.Error()) 184 } 185 if reply.C != args.A+args.B { 186 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 187 } 188 189 // Nonexistent method 190 args = &Args{7, 0} 191 reply = new(Reply) 192 err = client.Call("Arith.BadOperation", args, reply) 193 // expect an error 194 if err == nil { 195 t.Error("BadOperation: expected error") 196 } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") { 197 t.Errorf("BadOperation: expected can't find method error; got %q", err) 198 } 199 200 // Unknown service 201 args = &Args{7, 8} 202 reply = new(Reply) 203 err = client.Call("Arith.Unknown", args, reply) 204 if err == nil { 205 t.Error("expected error calling unknown service") 206 } else if !strings.Contains(err.Error(), "method") { 207 t.Error("expected error about method; got", err) 208 } 209 210 // Out of order. 211 args = &Args{7, 8} 212 mulReply := new(Reply) 213 mulCall := client.Go("Arith.Mul", args, mulReply, nil) 214 addReply := new(Reply) 215 addCall := client.Go("Arith.Add", args, addReply, nil) 216 217 addCall = <-addCall.Done 218 if addCall.Error != nil { 219 t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) 220 } 221 if addReply.C != args.A+args.B { 222 t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) 223 } 224 225 mulCall = <-mulCall.Done 226 if mulCall.Error != nil { 227 t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) 228 } 229 if mulReply.C != args.A*args.B { 230 t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) 231 } 232 233 // Error test 234 args = &Args{7, 0} 235 reply = new(Reply) 236 err = client.Call("Arith.Div", args, reply) 237 // expect an error: zero divide 238 if err == nil { 239 t.Error("Div: expected error") 240 } else if err.Error() != "divide by zero" { 241 t.Error("Div: expected divide by zero error; got", err) 242 } 243 244 // Bad type. 245 reply = new(Reply) 246 err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use 247 if err == nil { 248 t.Error("expected error calling Arith.Add with wrong arg type") 249 } else if !strings.Contains(err.Error(), "type") { 250 t.Error("expected error about type; got", err) 251 } 252 253 // Non-struct argument 254 const Val = 12345 255 str := fmt.Sprint(Val) 256 reply = new(Reply) 257 err = client.Call("Arith.Scan", &str, reply) 258 if err != nil { 259 t.Errorf("Scan: expected no error but got string %q", err.Error()) 260 } else if reply.C != Val { 261 t.Errorf("Scan: expected %d got %d", Val, reply.C) 262 } 263 264 // Non-struct reply 265 args = &Args{27, 35} 266 str = "" 267 err = client.Call("Arith.String", args, &str) 268 if err != nil { 269 t.Errorf("String: expected no error but got string %q", err.Error()) 270 } 271 expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 272 if str != expect { 273 t.Errorf("String: expected %s got %s", expect, str) 274 } 275 276 args = &Args{7, 8} 277 reply = new(Reply) 278 err = client.Call("Arith.Mul", args, reply) 279 if err != nil { 280 t.Errorf("Mul: expected no error but got string %q", err.Error()) 281 } 282 if reply.C != args.A*args.B { 283 t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) 284 } 285 286 // ServiceName contain "." character 287 args = &Args{7, 8} 288 reply = new(Reply) 289 err = client.Call("net.rpc.Arith.Add", args, reply) 290 if err != nil { 291 t.Errorf("Add: expected no error but got string %q", err.Error()) 292 } 293 if reply.C != args.A+args.B { 294 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 295 } 296 } 297 298 func testNewServerRPC(t *testing.T, addr string) { 299 client, err := Dial("tcp", addr) 300 if err != nil { 301 t.Fatal("dialing", err) 302 } 303 defer client.Close() 304 305 // Synchronous calls 306 args := &Args{7, 8} 307 reply := new(Reply) 308 err = client.Call("newServer.Arith.Add", args, reply) 309 if err != nil { 310 t.Errorf("Add: expected no error but got string %q", err.Error()) 311 } 312 if reply.C != args.A+args.B { 313 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 314 } 315 } 316 317 func TestHTTP(t *testing.T) { 318 once.Do(startServer) 319 testHTTPRPC(t, "") 320 newOnce.Do(startNewServer) 321 testHTTPRPC(t, newHttpPath) 322 } 323 324 func testHTTPRPC(t *testing.T, path string) { 325 var client *Client 326 var err error 327 if path == "" { 328 client, err = DialHTTP("tcp", httpServerAddr) 329 } else { 330 client, err = DialHTTPPath("tcp", httpServerAddr, path) 331 } 332 if err != nil { 333 t.Fatal("dialing", err) 334 } 335 defer client.Close() 336 337 // Synchronous calls 338 args := &Args{7, 8} 339 reply := new(Reply) 340 err = client.Call("Arith.Add", args, reply) 341 if err != nil { 342 t.Errorf("Add: expected no error but got string %q", err.Error()) 343 } 344 if reply.C != args.A+args.B { 345 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 346 } 347 } 348 349 func TestBuiltinTypes(t *testing.T) { 350 once.Do(startServer) 351 352 client, err := DialHTTP("tcp", httpServerAddr) 353 if err != nil { 354 t.Fatal("dialing", err) 355 } 356 defer client.Close() 357 358 // Map 359 args := &Args{7, 8} 360 replyMap := map[int]int{} 361 err = client.Call("BuiltinTypes.Map", args, &replyMap) 362 if err != nil { 363 t.Errorf("Map: expected no error but got string %q", err.Error()) 364 } 365 if replyMap[args.A] != args.B { 366 t.Errorf("Map: expected %d got %d", args.B, replyMap[args.A]) 367 } 368 369 // Slice 370 args = &Args{7, 8} 371 replySlice := []int{} 372 err = client.Call("BuiltinTypes.Slice", args, &replySlice) 373 if err != nil { 374 t.Errorf("Slice: expected no error but got string %q", err.Error()) 375 } 376 if e := []int{args.A, args.B}; !reflect.DeepEqual(replySlice, e) { 377 t.Errorf("Slice: expected %v got %v", e, replySlice) 378 } 379 380 // Array 381 args = &Args{7, 8} 382 replyArray := [2]int{} 383 err = client.Call("BuiltinTypes.Array", args, &replyArray) 384 if err != nil { 385 t.Errorf("Array: expected no error but got string %q", err.Error()) 386 } 387 if e := [2]int{args.A, args.B}; !reflect.DeepEqual(replyArray, e) { 388 t.Errorf("Array: expected %v got %v", e, replyArray) 389 } 390 } 391 392 // CodecEmulator provides a client-like api and a ServerCodec interface. 393 // Can be used to test ServeRequest. 394 type CodecEmulator struct { 395 server *Server 396 serviceMethod string 397 args *Args 398 reply *Reply 399 err error 400 } 401 402 func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error { 403 codec.serviceMethod = serviceMethod 404 codec.args = args 405 codec.reply = reply 406 codec.err = nil 407 var serverError error 408 if codec.server == nil { 409 serverError = ServeRequest(codec) 410 } else { 411 serverError = codec.server.ServeRequest(codec) 412 } 413 if codec.err == nil && serverError != nil { 414 codec.err = serverError 415 } 416 return codec.err 417 } 418 419 func (codec *CodecEmulator) ReadRequestHeader(req *Request) error { 420 req.ServiceMethod = codec.serviceMethod 421 req.Seq = 0 422 return nil 423 } 424 425 func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error { 426 if codec.args == nil { 427 return io.ErrUnexpectedEOF 428 } 429 *(argv.(*Args)) = *codec.args 430 return nil 431 } 432 433 func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error { 434 if resp.Error != "" { 435 codec.err = errors.New(resp.Error) 436 } else { 437 *codec.reply = *(reply.(*Reply)) 438 } 439 return nil 440 } 441 442 func (codec *CodecEmulator) Close() error { 443 return nil 444 } 445 446 func TestServeRequest(t *testing.T) { 447 once.Do(startServer) 448 testServeRequest(t, nil) 449 newOnce.Do(startNewServer) 450 testServeRequest(t, newServer) 451 } 452 453 func testServeRequest(t *testing.T, server *Server) { 454 client := CodecEmulator{server: server} 455 defer client.Close() 456 457 args := &Args{7, 8} 458 reply := new(Reply) 459 err := client.Call("Arith.Add", args, reply) 460 if err != nil { 461 t.Errorf("Add: expected no error but got string %q", err.Error()) 462 } 463 if reply.C != args.A+args.B { 464 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 465 } 466 467 err = client.Call("Arith.Add", nil, reply) 468 if err == nil { 469 t.Errorf("expected error calling Arith.Add with nil arg") 470 } 471 } 472 473 type ReplyNotPointer int 474 type ArgNotPublic int 475 type ReplyNotPublic int 476 type NeedsPtrType int 477 type local struct{} 478 479 func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error { 480 return nil 481 } 482 483 func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error { 484 return nil 485 } 486 487 func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error { 488 return nil 489 } 490 491 func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error { 492 return nil 493 } 494 495 // Check that registration handles lots of bad methods and a type with no suitable methods. 496 func TestRegistrationError(t *testing.T) { 497 err := Register(new(ReplyNotPointer)) 498 if err == nil { 499 t.Error("expected error registering ReplyNotPointer") 500 } 501 err = Register(new(ArgNotPublic)) 502 if err == nil { 503 t.Error("expected error registering ArgNotPublic") 504 } 505 err = Register(new(ReplyNotPublic)) 506 if err == nil { 507 t.Error("expected error registering ReplyNotPublic") 508 } 509 err = Register(NeedsPtrType(0)) 510 if err == nil { 511 t.Error("expected error registering NeedsPtrType") 512 } else if !strings.Contains(err.Error(), "pointer") { 513 t.Error("expected hint when registering NeedsPtrType") 514 } 515 } 516 517 type WriteFailCodec int 518 519 func (WriteFailCodec) WriteRequest(*Request, interface{}) error { 520 // the panic caused by this error used to not unlock a lock. 521 return errors.New("fail") 522 } 523 524 func (WriteFailCodec) ReadResponseHeader(*Response) error { 525 select {} 526 } 527 528 func (WriteFailCodec) ReadResponseBody(interface{}) error { 529 select {} 530 } 531 532 func (WriteFailCodec) Close() error { 533 return nil 534 } 535 536 func TestSendDeadlock(t *testing.T) { 537 client := NewClientWithCodec(WriteFailCodec(0)) 538 defer client.Close() 539 540 done := make(chan bool) 541 go func() { 542 testSendDeadlock(client) 543 testSendDeadlock(client) 544 done <- true 545 }() 546 select { 547 case <-done: 548 return 549 case <-time.After(5 * time.Second): 550 t.Fatal("deadlock") 551 } 552 } 553 554 func testSendDeadlock(client *Client) { 555 defer func() { 556 recover() 557 }() 558 args := &Args{7, 8} 559 reply := new(Reply) 560 client.Call("Arith.Add", args, reply) 561 } 562 563 func dialDirect() (*Client, error) { 564 return Dial("tcp", serverAddr) 565 } 566 567 func dialHTTP() (*Client, error) { 568 return DialHTTP("tcp", httpServerAddr) 569 } 570 571 func countMallocs(dial func() (*Client, error), t *testing.T) float64 { 572 once.Do(startServer) 573 client, err := dial() 574 if err != nil { 575 t.Fatal("error dialing", err) 576 } 577 defer client.Close() 578 579 args := &Args{7, 8} 580 reply := new(Reply) 581 return testing.AllocsPerRun(100, func() { 582 err := client.Call("Arith.Add", args, reply) 583 if err != nil { 584 t.Errorf("Add: expected no error but got string %q", err.Error()) 585 } 586 if reply.C != args.A+args.B { 587 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 588 } 589 }) 590 } 591 592 func TestCountMallocs(t *testing.T) { 593 if testing.Short() { 594 t.Skip("skipping malloc count in short mode") 595 } 596 if runtime.GOMAXPROCS(0) > 1 { 597 t.Skip("skipping; GOMAXPROCS>1") 598 } 599 fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t)) 600 } 601 602 func TestCountMallocsOverHTTP(t *testing.T) { 603 if testing.Short() { 604 t.Skip("skipping malloc count in short mode") 605 } 606 if runtime.GOMAXPROCS(0) > 1 { 607 t.Skip("skipping; GOMAXPROCS>1") 608 } 609 fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t)) 610 } 611 612 type writeCrasher struct { 613 done chan bool 614 } 615 616 func (writeCrasher) Close() error { 617 return nil 618 } 619 620 func (w *writeCrasher) Read(p []byte) (int, error) { 621 <-w.done 622 return 0, io.EOF 623 } 624 625 func (writeCrasher) Write(p []byte) (int, error) { 626 return 0, errors.New("fake write failure") 627 } 628 629 func TestClientWriteError(t *testing.T) { 630 w := &writeCrasher{done: make(chan bool)} 631 c := NewClient(w) 632 defer c.Close() 633 634 res := false 635 err := c.Call("foo", 1, &res) 636 if err == nil { 637 t.Fatal("expected error") 638 } 639 if err.Error() != "fake write failure" { 640 t.Error("unexpected value of error:", err) 641 } 642 w.done <- true 643 } 644 645 func TestTCPClose(t *testing.T) { 646 once.Do(startServer) 647 648 client, err := dialHTTP() 649 if err != nil { 650 t.Fatalf("dialing: %v", err) 651 } 652 defer client.Close() 653 654 args := Args{17, 8} 655 var reply Reply 656 err = client.Call("Arith.Mul", args, &reply) 657 if err != nil { 658 t.Fatal("arith error:", err) 659 } 660 t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply) 661 if reply.C != args.A*args.B { 662 t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B) 663 } 664 } 665 666 func TestErrorAfterClientClose(t *testing.T) { 667 once.Do(startServer) 668 669 client, err := dialHTTP() 670 if err != nil { 671 t.Fatalf("dialing: %v", err) 672 } 673 err = client.Close() 674 if err != nil { 675 t.Fatal("close error:", err) 676 } 677 err = client.Call("Arith.Add", &Args{7, 9}, new(Reply)) 678 if err != ErrShutdown { 679 t.Errorf("Forever: expected ErrShutdown got %v", err) 680 } 681 } 682 683 // Tests the fix to issue 11221. Without the fix, this loops forever or crashes. 684 func TestAcceptExitAfterListenerClose(t *testing.T) { 685 newServer := NewServer() 686 newServer.Register(new(Arith)) 687 newServer.RegisterName("net.rpc.Arith", new(Arith)) 688 newServer.RegisterName("newServer.Arith", new(Arith)) 689 690 var l net.Listener 691 l, _ = listenTCP() 692 l.Close() 693 newServer.Accept(l) 694 } 695 696 func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { 697 once.Do(startServer) 698 client, err := dial() 699 if err != nil { 700 b.Fatal("error dialing:", err) 701 } 702 defer client.Close() 703 704 // Synchronous calls 705 args := &Args{7, 8} 706 b.ResetTimer() 707 708 b.RunParallel(func(pb *testing.PB) { 709 reply := new(Reply) 710 for pb.Next() { 711 err := client.Call("Arith.Add", args, reply) 712 if err != nil { 713 b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error()) 714 } 715 if reply.C != args.A+args.B { 716 b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B) 717 } 718 } 719 }) 720 } 721 722 func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { 723 if b.N == 0 { 724 return 725 } 726 const MaxConcurrentCalls = 100 727 once.Do(startServer) 728 client, err := dial() 729 if err != nil { 730 b.Fatal("error dialing:", err) 731 } 732 defer client.Close() 733 734 // Asynchronous calls 735 args := &Args{7, 8} 736 procs := 4 * runtime.GOMAXPROCS(-1) 737 send := int32(b.N) 738 recv := int32(b.N) 739 var wg sync.WaitGroup 740 wg.Add(procs) 741 gate := make(chan bool, MaxConcurrentCalls) 742 res := make(chan *Call, MaxConcurrentCalls) 743 b.ResetTimer() 744 745 for p := 0; p < procs; p++ { 746 go func() { 747 for atomic.AddInt32(&send, -1) >= 0 { 748 gate <- true 749 reply := new(Reply) 750 client.Go("Arith.Add", args, reply, res) 751 } 752 }() 753 go func() { 754 for call := range res { 755 A := call.Args.(*Args).A 756 B := call.Args.(*Args).B 757 C := call.Reply.(*Reply).C 758 if A+B != C { 759 b.Errorf("incorrect reply: Add: expected %d got %d", A+B, C) 760 return 761 } 762 <-gate 763 if atomic.AddInt32(&recv, -1) == 0 { 764 close(res) 765 } 766 } 767 wg.Done() 768 }() 769 } 770 wg.Wait() 771 } 772 773 func BenchmarkEndToEnd(b *testing.B) { 774 benchmarkEndToEnd(dialDirect, b) 775 } 776 777 func BenchmarkEndToEndHTTP(b *testing.B) { 778 benchmarkEndToEnd(dialHTTP, b) 779 } 780 781 func BenchmarkEndToEndAsync(b *testing.B) { 782 benchmarkEndToEndAsync(dialDirect, b) 783 } 784 785 func BenchmarkEndToEndAsyncHTTP(b *testing.B) { 786 benchmarkEndToEndAsync(dialHTTP, b) 787 }