github.com/4ad/go@v0.0.0-20161219182952-69a12818b605/src/net/rpc/server_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rpc 6 7 import ( 8 "errors" 9 "fmt" 10 "io" 11 "log" 12 "net" 13 "net/http/httptest" 14 "runtime" 15 "strings" 16 "sync" 17 "sync/atomic" 18 "testing" 19 "time" 20 ) 21 22 var ( 23 newServer *Server 24 serverAddr, newServerAddr string 25 httpServerAddr string 26 once, newOnce, httpOnce sync.Once 27 ) 28 29 const ( 30 newHttpPath = "/foo" 31 ) 32 33 type Args struct { 34 A, B int 35 } 36 37 type Reply struct { 38 C int 39 } 40 41 type Arith int 42 43 // Some of Arith's methods have value args, some have pointer args. That's deliberate. 44 45 func (t *Arith) Add(args Args, reply *Reply) error { 46 reply.C = args.A + args.B 47 return nil 48 } 49 50 func (t *Arith) Mul(args *Args, reply *Reply) error { 51 reply.C = args.A * args.B 52 return nil 53 } 54 55 func (t *Arith) Div(args Args, reply *Reply) error { 56 if args.B == 0 { 57 return errors.New("divide by zero") 58 } 59 reply.C = args.A / args.B 60 return nil 61 } 62 63 func (t *Arith) String(args *Args, reply *string) error { 64 *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 65 return nil 66 } 67 68 func (t *Arith) Scan(args string, reply *Reply) (err error) { 69 _, err = fmt.Sscan(args, &reply.C) 70 return 71 } 72 73 func (t *Arith) Error(args *Args, reply *Reply) error { 74 panic("ERROR") 75 } 76 77 type hidden int 78 79 func (t *hidden) Exported(args Args, reply *Reply) error { 80 reply.C = args.A + args.B 81 return nil 82 } 83 84 type Embed struct { 85 hidden 86 } 87 88 func listenTCP() (net.Listener, string) { 89 l, e := net.Listen("tcp", "127.0.0.1:0") // any available address 90 if e != nil { 91 log.Fatalf("net.Listen tcp :0: %v", e) 92 } 93 return l, l.Addr().String() 94 } 95 96 func startServer() { 97 Register(new(Arith)) 98 Register(new(Embed)) 99 RegisterName("net.rpc.Arith", new(Arith)) 100 101 var l net.Listener 102 l, serverAddr = listenTCP() 103 log.Println("Test RPC server listening on", serverAddr) 104 go Accept(l) 105 106 HandleHTTP() 107 httpOnce.Do(startHttpServer) 108 } 109 110 func startNewServer() { 111 newServer = NewServer() 112 newServer.Register(new(Arith)) 113 newServer.Register(new(Embed)) 114 newServer.RegisterName("net.rpc.Arith", new(Arith)) 115 newServer.RegisterName("newServer.Arith", new(Arith)) 116 117 var l net.Listener 118 l, newServerAddr = listenTCP() 119 log.Println("NewServer test RPC server listening on", newServerAddr) 120 go newServer.Accept(l) 121 122 newServer.HandleHTTP(newHttpPath, "/bar") 123 httpOnce.Do(startHttpServer) 124 } 125 126 func startHttpServer() { 127 server := httptest.NewServer(nil) 128 httpServerAddr = server.Listener.Addr().String() 129 log.Println("Test HTTP RPC server listening on", httpServerAddr) 130 } 131 132 func TestRPC(t *testing.T) { 133 once.Do(startServer) 134 testRPC(t, serverAddr) 135 newOnce.Do(startNewServer) 136 testRPC(t, newServerAddr) 137 testNewServerRPC(t, newServerAddr) 138 } 139 140 func testRPC(t *testing.T, addr string) { 141 client, err := Dial("tcp", addr) 142 if err != nil { 143 t.Fatal("dialing", err) 144 } 145 defer client.Close() 146 147 // Synchronous calls 148 args := &Args{7, 8} 149 reply := new(Reply) 150 err = client.Call("Arith.Add", args, reply) 151 if err != nil { 152 t.Errorf("Add: expected no error but got string %q", err.Error()) 153 } 154 if reply.C != args.A+args.B { 155 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 156 } 157 158 // Methods exported from unexported embedded structs 159 args = &Args{7, 0} 160 reply = new(Reply) 161 err = client.Call("Embed.Exported", args, reply) 162 if err != nil { 163 t.Errorf("Add: expected no error but got string %q", err.Error()) 164 } 165 if reply.C != args.A+args.B { 166 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 167 } 168 169 // Nonexistent method 170 args = &Args{7, 0} 171 reply = new(Reply) 172 err = client.Call("Arith.BadOperation", args, reply) 173 // expect an error 174 if err == nil { 175 t.Error("BadOperation: expected error") 176 } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") { 177 t.Errorf("BadOperation: expected can't find method error; got %q", err) 178 } 179 180 // Unknown service 181 args = &Args{7, 8} 182 reply = new(Reply) 183 err = client.Call("Arith.Unknown", args, reply) 184 if err == nil { 185 t.Error("expected error calling unknown service") 186 } else if !strings.Contains(err.Error(), "method") { 187 t.Error("expected error about method; got", err) 188 } 189 190 // Out of order. 191 args = &Args{7, 8} 192 mulReply := new(Reply) 193 mulCall := client.Go("Arith.Mul", args, mulReply, nil) 194 addReply := new(Reply) 195 addCall := client.Go("Arith.Add", args, addReply, nil) 196 197 addCall = <-addCall.Done 198 if addCall.Error != nil { 199 t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) 200 } 201 if addReply.C != args.A+args.B { 202 t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) 203 } 204 205 mulCall = <-mulCall.Done 206 if mulCall.Error != nil { 207 t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) 208 } 209 if mulReply.C != args.A*args.B { 210 t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) 211 } 212 213 // Error test 214 args = &Args{7, 0} 215 reply = new(Reply) 216 err = client.Call("Arith.Div", args, reply) 217 // expect an error: zero divide 218 if err == nil { 219 t.Error("Div: expected error") 220 } else if err.Error() != "divide by zero" { 221 t.Error("Div: expected divide by zero error; got", err) 222 } 223 224 // Bad type. 225 reply = new(Reply) 226 err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use 227 if err == nil { 228 t.Error("expected error calling Arith.Add with wrong arg type") 229 } else if !strings.Contains(err.Error(), "type") { 230 t.Error("expected error about type; got", err) 231 } 232 233 // Non-struct argument 234 const Val = 12345 235 str := fmt.Sprint(Val) 236 reply = new(Reply) 237 err = client.Call("Arith.Scan", &str, reply) 238 if err != nil { 239 t.Errorf("Scan: expected no error but got string %q", err.Error()) 240 } else if reply.C != Val { 241 t.Errorf("Scan: expected %d got %d", Val, reply.C) 242 } 243 244 // Non-struct reply 245 args = &Args{27, 35} 246 str = "" 247 err = client.Call("Arith.String", args, &str) 248 if err != nil { 249 t.Errorf("String: expected no error but got string %q", err.Error()) 250 } 251 expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 252 if str != expect { 253 t.Errorf("String: expected %s got %s", expect, str) 254 } 255 256 args = &Args{7, 8} 257 reply = new(Reply) 258 err = client.Call("Arith.Mul", args, reply) 259 if err != nil { 260 t.Errorf("Mul: expected no error but got string %q", err.Error()) 261 } 262 if reply.C != args.A*args.B { 263 t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) 264 } 265 266 // ServiceName contain "." character 267 args = &Args{7, 8} 268 reply = new(Reply) 269 err = client.Call("net.rpc.Arith.Add", args, reply) 270 if err != nil { 271 t.Errorf("Add: expected no error but got string %q", err.Error()) 272 } 273 if reply.C != args.A+args.B { 274 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 275 } 276 } 277 278 func testNewServerRPC(t *testing.T, addr string) { 279 client, err := Dial("tcp", addr) 280 if err != nil { 281 t.Fatal("dialing", err) 282 } 283 defer client.Close() 284 285 // Synchronous calls 286 args := &Args{7, 8} 287 reply := new(Reply) 288 err = client.Call("newServer.Arith.Add", args, reply) 289 if err != nil { 290 t.Errorf("Add: expected no error but got string %q", err.Error()) 291 } 292 if reply.C != args.A+args.B { 293 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 294 } 295 } 296 297 func TestHTTP(t *testing.T) { 298 once.Do(startServer) 299 testHTTPRPC(t, "") 300 newOnce.Do(startNewServer) 301 testHTTPRPC(t, newHttpPath) 302 } 303 304 func testHTTPRPC(t *testing.T, path string) { 305 var client *Client 306 var err error 307 if path == "" { 308 client, err = DialHTTP("tcp", httpServerAddr) 309 } else { 310 client, err = DialHTTPPath("tcp", httpServerAddr, path) 311 } 312 if err != nil { 313 t.Fatal("dialing", err) 314 } 315 defer client.Close() 316 317 // Synchronous calls 318 args := &Args{7, 8} 319 reply := new(Reply) 320 err = client.Call("Arith.Add", args, reply) 321 if err != nil { 322 t.Errorf("Add: expected no error but got string %q", err.Error()) 323 } 324 if reply.C != args.A+args.B { 325 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 326 } 327 } 328 329 // CodecEmulator provides a client-like api and a ServerCodec interface. 330 // Can be used to test ServeRequest. 331 type CodecEmulator struct { 332 server *Server 333 serviceMethod string 334 args *Args 335 reply *Reply 336 err error 337 } 338 339 func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error { 340 codec.serviceMethod = serviceMethod 341 codec.args = args 342 codec.reply = reply 343 codec.err = nil 344 var serverError error 345 if codec.server == nil { 346 serverError = ServeRequest(codec) 347 } else { 348 serverError = codec.server.ServeRequest(codec) 349 } 350 if codec.err == nil && serverError != nil { 351 codec.err = serverError 352 } 353 return codec.err 354 } 355 356 func (codec *CodecEmulator) ReadRequestHeader(req *Request) error { 357 req.ServiceMethod = codec.serviceMethod 358 req.Seq = 0 359 return nil 360 } 361 362 func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error { 363 if codec.args == nil { 364 return io.ErrUnexpectedEOF 365 } 366 *(argv.(*Args)) = *codec.args 367 return nil 368 } 369 370 func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error { 371 if resp.Error != "" { 372 codec.err = errors.New(resp.Error) 373 } else { 374 *codec.reply = *(reply.(*Reply)) 375 } 376 return nil 377 } 378 379 func (codec *CodecEmulator) Close() error { 380 return nil 381 } 382 383 func TestServeRequest(t *testing.T) { 384 once.Do(startServer) 385 testServeRequest(t, nil) 386 newOnce.Do(startNewServer) 387 testServeRequest(t, newServer) 388 } 389 390 func testServeRequest(t *testing.T, server *Server) { 391 client := CodecEmulator{server: server} 392 defer client.Close() 393 394 args := &Args{7, 8} 395 reply := new(Reply) 396 err := client.Call("Arith.Add", args, reply) 397 if err != nil { 398 t.Errorf("Add: expected no error but got string %q", err.Error()) 399 } 400 if reply.C != args.A+args.B { 401 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 402 } 403 404 err = client.Call("Arith.Add", nil, reply) 405 if err == nil { 406 t.Errorf("expected error calling Arith.Add with nil arg") 407 } 408 } 409 410 type ReplyNotPointer int 411 type ArgNotPublic int 412 type ReplyNotPublic int 413 type NeedsPtrType int 414 type local struct{} 415 416 func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error { 417 return nil 418 } 419 420 func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error { 421 return nil 422 } 423 424 func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error { 425 return nil 426 } 427 428 func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error { 429 return nil 430 } 431 432 // Check that registration handles lots of bad methods and a type with no suitable methods. 433 func TestRegistrationError(t *testing.T) { 434 err := Register(new(ReplyNotPointer)) 435 if err == nil { 436 t.Error("expected error registering ReplyNotPointer") 437 } 438 err = Register(new(ArgNotPublic)) 439 if err == nil { 440 t.Error("expected error registering ArgNotPublic") 441 } 442 err = Register(new(ReplyNotPublic)) 443 if err == nil { 444 t.Error("expected error registering ReplyNotPublic") 445 } 446 err = Register(NeedsPtrType(0)) 447 if err == nil { 448 t.Error("expected error registering NeedsPtrType") 449 } else if !strings.Contains(err.Error(), "pointer") { 450 t.Error("expected hint when registering NeedsPtrType") 451 } 452 } 453 454 type WriteFailCodec int 455 456 func (WriteFailCodec) WriteRequest(*Request, interface{}) error { 457 // the panic caused by this error used to not unlock a lock. 458 return errors.New("fail") 459 } 460 461 func (WriteFailCodec) ReadResponseHeader(*Response) error { 462 select {} 463 } 464 465 func (WriteFailCodec) ReadResponseBody(interface{}) error { 466 select {} 467 } 468 469 func (WriteFailCodec) Close() error { 470 return nil 471 } 472 473 func TestSendDeadlock(t *testing.T) { 474 client := NewClientWithCodec(WriteFailCodec(0)) 475 defer client.Close() 476 477 done := make(chan bool) 478 go func() { 479 testSendDeadlock(client) 480 testSendDeadlock(client) 481 done <- true 482 }() 483 select { 484 case <-done: 485 return 486 case <-time.After(5 * time.Second): 487 t.Fatal("deadlock") 488 } 489 } 490 491 func testSendDeadlock(client *Client) { 492 defer func() { 493 recover() 494 }() 495 args := &Args{7, 8} 496 reply := new(Reply) 497 client.Call("Arith.Add", args, reply) 498 } 499 500 func dialDirect() (*Client, error) { 501 return Dial("tcp", serverAddr) 502 } 503 504 func dialHTTP() (*Client, error) { 505 return DialHTTP("tcp", httpServerAddr) 506 } 507 508 func countMallocs(dial func() (*Client, error), t *testing.T) float64 { 509 once.Do(startServer) 510 client, err := dial() 511 if err != nil { 512 t.Fatal("error dialing", err) 513 } 514 defer client.Close() 515 516 args := &Args{7, 8} 517 reply := new(Reply) 518 return testing.AllocsPerRun(100, func() { 519 err := client.Call("Arith.Add", args, reply) 520 if err != nil { 521 t.Errorf("Add: expected no error but got string %q", err.Error()) 522 } 523 if reply.C != args.A+args.B { 524 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 525 } 526 }) 527 } 528 529 func TestCountMallocs(t *testing.T) { 530 if testing.Short() { 531 t.Skip("skipping malloc count in short mode") 532 } 533 if runtime.GOMAXPROCS(0) > 1 { 534 t.Skip("skipping; GOMAXPROCS>1") 535 } 536 fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t)) 537 } 538 539 func TestCountMallocsOverHTTP(t *testing.T) { 540 if testing.Short() { 541 t.Skip("skipping malloc count in short mode") 542 } 543 if runtime.GOMAXPROCS(0) > 1 { 544 t.Skip("skipping; GOMAXPROCS>1") 545 } 546 fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t)) 547 } 548 549 type writeCrasher struct { 550 done chan bool 551 } 552 553 func (writeCrasher) Close() error { 554 return nil 555 } 556 557 func (w *writeCrasher) Read(p []byte) (int, error) { 558 <-w.done 559 return 0, io.EOF 560 } 561 562 func (writeCrasher) Write(p []byte) (int, error) { 563 return 0, errors.New("fake write failure") 564 } 565 566 func TestClientWriteError(t *testing.T) { 567 w := &writeCrasher{done: make(chan bool)} 568 c := NewClient(w) 569 defer c.Close() 570 571 res := false 572 err := c.Call("foo", 1, &res) 573 if err == nil { 574 t.Fatal("expected error") 575 } 576 if err.Error() != "fake write failure" { 577 t.Error("unexpected value of error:", err) 578 } 579 w.done <- true 580 } 581 582 func TestTCPClose(t *testing.T) { 583 once.Do(startServer) 584 585 client, err := dialHTTP() 586 if err != nil { 587 t.Fatalf("dialing: %v", err) 588 } 589 defer client.Close() 590 591 args := Args{17, 8} 592 var reply Reply 593 err = client.Call("Arith.Mul", args, &reply) 594 if err != nil { 595 t.Fatal("arith error:", err) 596 } 597 t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply) 598 if reply.C != args.A*args.B { 599 t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B) 600 } 601 } 602 603 func TestErrorAfterClientClose(t *testing.T) { 604 once.Do(startServer) 605 606 client, err := dialHTTP() 607 if err != nil { 608 t.Fatalf("dialing: %v", err) 609 } 610 err = client.Close() 611 if err != nil { 612 t.Fatal("close error:", err) 613 } 614 err = client.Call("Arith.Add", &Args{7, 9}, new(Reply)) 615 if err != ErrShutdown { 616 t.Errorf("Forever: expected ErrShutdown got %v", err) 617 } 618 } 619 620 // Tests the fix to issue 11221. Without the fix, this loops forever or crashes. 621 func TestAcceptExitAfterListenerClose(t *testing.T) { 622 newServer = NewServer() 623 newServer.Register(new(Arith)) 624 newServer.RegisterName("net.rpc.Arith", new(Arith)) 625 newServer.RegisterName("newServer.Arith", new(Arith)) 626 627 var l net.Listener 628 l, newServerAddr = listenTCP() 629 l.Close() 630 newServer.Accept(l) 631 } 632 633 func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { 634 once.Do(startServer) 635 client, err := dial() 636 if err != nil { 637 b.Fatal("error dialing:", err) 638 } 639 defer client.Close() 640 641 // Synchronous calls 642 args := &Args{7, 8} 643 b.ResetTimer() 644 645 b.RunParallel(func(pb *testing.PB) { 646 reply := new(Reply) 647 for pb.Next() { 648 err := client.Call("Arith.Add", args, reply) 649 if err != nil { 650 b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error()) 651 } 652 if reply.C != args.A+args.B { 653 b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B) 654 } 655 } 656 }) 657 } 658 659 func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { 660 if b.N == 0 { 661 return 662 } 663 const MaxConcurrentCalls = 100 664 once.Do(startServer) 665 client, err := dial() 666 if err != nil { 667 b.Fatal("error dialing:", err) 668 } 669 defer client.Close() 670 671 // Asynchronous calls 672 args := &Args{7, 8} 673 procs := 4 * runtime.GOMAXPROCS(-1) 674 send := int32(b.N) 675 recv := int32(b.N) 676 var wg sync.WaitGroup 677 wg.Add(procs) 678 gate := make(chan bool, MaxConcurrentCalls) 679 res := make(chan *Call, MaxConcurrentCalls) 680 b.ResetTimer() 681 682 for p := 0; p < procs; p++ { 683 go func() { 684 for atomic.AddInt32(&send, -1) >= 0 { 685 gate <- true 686 reply := new(Reply) 687 client.Go("Arith.Add", args, reply, res) 688 } 689 }() 690 go func() { 691 for call := range res { 692 A := call.Args.(*Args).A 693 B := call.Args.(*Args).B 694 C := call.Reply.(*Reply).C 695 if A+B != C { 696 b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C) 697 } 698 <-gate 699 if atomic.AddInt32(&recv, -1) == 0 { 700 close(res) 701 } 702 } 703 wg.Done() 704 }() 705 } 706 wg.Wait() 707 } 708 709 func BenchmarkEndToEnd(b *testing.B) { 710 benchmarkEndToEnd(dialDirect, b) 711 } 712 713 func BenchmarkEndToEndHTTP(b *testing.B) { 714 benchmarkEndToEnd(dialHTTP, b) 715 } 716 717 func BenchmarkEndToEndAsync(b *testing.B) { 718 benchmarkEndToEndAsync(dialDirect, b) 719 } 720 721 func BenchmarkEndToEndAsyncHTTP(b *testing.B) { 722 benchmarkEndToEndAsync(dialHTTP, b) 723 }