github.com/xushiwei/go@v0.0.0-20130601165731-2b9d83f45bc9/src/pkg/net/rpc/server_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package rpc 6 7 import ( 8 "errors" 9 "fmt" 10 "io" 11 "log" 12 "net" 13 "net/http/httptest" 14 "runtime" 15 "strings" 16 "sync" 17 "sync/atomic" 18 "testing" 19 "time" 20 ) 21 22 var ( 23 newServer *Server 24 serverAddr, newServerAddr string 25 httpServerAddr string 26 once, newOnce, httpOnce sync.Once 27 ) 28 29 const ( 30 newHttpPath = "/foo" 31 ) 32 33 type Args struct { 34 A, B int 35 } 36 37 type Reply struct { 38 C int 39 } 40 41 type Arith int 42 43 // Some of Arith's methods have value args, some have pointer args. That's deliberate. 44 45 func (t *Arith) Add(args Args, reply *Reply) error { 46 reply.C = args.A + args.B 47 return nil 48 } 49 50 func (t *Arith) Mul(args *Args, reply *Reply) error { 51 reply.C = args.A * args.B 52 return nil 53 } 54 55 func (t *Arith) Div(args Args, reply *Reply) error { 56 if args.B == 0 { 57 return errors.New("divide by zero") 58 } 59 reply.C = args.A / args.B 60 return nil 61 } 62 63 func (t *Arith) String(args *Args, reply *string) error { 64 *reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 65 return nil 66 } 67 68 func (t *Arith) Scan(args string, reply *Reply) (err error) { 69 _, err = fmt.Sscan(args, &reply.C) 70 return 71 } 72 73 func (t *Arith) Error(args *Args, reply *Reply) error { 74 panic("ERROR") 75 } 76 77 func listenTCP() (net.Listener, string) { 78 l, e := net.Listen("tcp", "127.0.0.1:0") // any available address 79 if e != nil { 80 log.Fatalf("net.Listen tcp :0: %v", e) 81 } 82 return l, l.Addr().String() 83 } 84 85 func startServer() { 86 Register(new(Arith)) 87 88 var l net.Listener 89 l, serverAddr = listenTCP() 90 log.Println("Test RPC server listening on", serverAddr) 91 go Accept(l) 92 93 HandleHTTP() 94 httpOnce.Do(startHttpServer) 95 } 96 97 func startNewServer() { 98 newServer = NewServer() 99 newServer.Register(new(Arith)) 100 101 var l net.Listener 102 l, newServerAddr = listenTCP() 103 log.Println("NewServer test RPC server listening on", newServerAddr) 104 go Accept(l) 105 106 newServer.HandleHTTP(newHttpPath, "/bar") 107 httpOnce.Do(startHttpServer) 108 } 109 110 func startHttpServer() { 111 server := httptest.NewServer(nil) 112 httpServerAddr = server.Listener.Addr().String() 113 log.Println("Test HTTP RPC server listening on", httpServerAddr) 114 } 115 116 func TestRPC(t *testing.T) { 117 once.Do(startServer) 118 testRPC(t, serverAddr) 119 newOnce.Do(startNewServer) 120 testRPC(t, newServerAddr) 121 } 122 123 func testRPC(t *testing.T, addr string) { 124 client, err := Dial("tcp", addr) 125 if err != nil { 126 t.Fatal("dialing", err) 127 } 128 129 // Synchronous calls 130 args := &Args{7, 8} 131 reply := new(Reply) 132 err = client.Call("Arith.Add", args, reply) 133 if err != nil { 134 t.Errorf("Add: expected no error but got string %q", err.Error()) 135 } 136 if reply.C != args.A+args.B { 137 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 138 } 139 140 // Nonexistent method 141 args = &Args{7, 0} 142 reply = new(Reply) 143 err = client.Call("Arith.BadOperation", args, reply) 144 // expect an error 145 if err == nil { 146 t.Error("BadOperation: expected error") 147 } else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") { 148 t.Errorf("BadOperation: expected can't find method error; got %q", err) 149 } 150 151 // Unknown service 152 args = &Args{7, 8} 153 reply = new(Reply) 154 err = client.Call("Arith.Unknown", args, reply) 155 if err == nil { 156 t.Error("expected error calling unknown service") 157 } else if strings.Index(err.Error(), "method") < 0 { 158 t.Error("expected error about method; got", err) 159 } 160 161 // Out of order. 162 args = &Args{7, 8} 163 mulReply := new(Reply) 164 mulCall := client.Go("Arith.Mul", args, mulReply, nil) 165 addReply := new(Reply) 166 addCall := client.Go("Arith.Add", args, addReply, nil) 167 168 addCall = <-addCall.Done 169 if addCall.Error != nil { 170 t.Errorf("Add: expected no error but got string %q", addCall.Error.Error()) 171 } 172 if addReply.C != args.A+args.B { 173 t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B) 174 } 175 176 mulCall = <-mulCall.Done 177 if mulCall.Error != nil { 178 t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error()) 179 } 180 if mulReply.C != args.A*args.B { 181 t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B) 182 } 183 184 // Error test 185 args = &Args{7, 0} 186 reply = new(Reply) 187 err = client.Call("Arith.Div", args, reply) 188 // expect an error: zero divide 189 if err == nil { 190 t.Error("Div: expected error") 191 } else if err.Error() != "divide by zero" { 192 t.Error("Div: expected divide by zero error; got", err) 193 } 194 195 // Bad type. 196 reply = new(Reply) 197 err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use 198 if err == nil { 199 t.Error("expected error calling Arith.Add with wrong arg type") 200 } else if strings.Index(err.Error(), "type") < 0 { 201 t.Error("expected error about type; got", err) 202 } 203 204 // Non-struct argument 205 const Val = 12345 206 str := fmt.Sprint(Val) 207 reply = new(Reply) 208 err = client.Call("Arith.Scan", &str, reply) 209 if err != nil { 210 t.Errorf("Scan: expected no error but got string %q", err.Error()) 211 } else if reply.C != Val { 212 t.Errorf("Scan: expected %d got %d", Val, reply.C) 213 } 214 215 // Non-struct reply 216 args = &Args{27, 35} 217 str = "" 218 err = client.Call("Arith.String", args, &str) 219 if err != nil { 220 t.Errorf("String: expected no error but got string %q", err.Error()) 221 } 222 expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B) 223 if str != expect { 224 t.Errorf("String: expected %s got %s", expect, str) 225 } 226 227 args = &Args{7, 8} 228 reply = new(Reply) 229 err = client.Call("Arith.Mul", args, reply) 230 if err != nil { 231 t.Errorf("Mul: expected no error but got string %q", err.Error()) 232 } 233 if reply.C != args.A*args.B { 234 t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B) 235 } 236 } 237 238 func TestHTTP(t *testing.T) { 239 once.Do(startServer) 240 testHTTPRPC(t, "") 241 newOnce.Do(startNewServer) 242 testHTTPRPC(t, newHttpPath) 243 } 244 245 func testHTTPRPC(t *testing.T, path string) { 246 var client *Client 247 var err error 248 if path == "" { 249 client, err = DialHTTP("tcp", httpServerAddr) 250 } else { 251 client, err = DialHTTPPath("tcp", httpServerAddr, path) 252 } 253 if err != nil { 254 t.Fatal("dialing", err) 255 } 256 257 // Synchronous calls 258 args := &Args{7, 8} 259 reply := new(Reply) 260 err = client.Call("Arith.Add", args, reply) 261 if err != nil { 262 t.Errorf("Add: expected no error but got string %q", err.Error()) 263 } 264 if reply.C != args.A+args.B { 265 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 266 } 267 } 268 269 // CodecEmulator provides a client-like api and a ServerCodec interface. 270 // Can be used to test ServeRequest. 271 type CodecEmulator struct { 272 server *Server 273 serviceMethod string 274 args *Args 275 reply *Reply 276 err error 277 } 278 279 func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error { 280 codec.serviceMethod = serviceMethod 281 codec.args = args 282 codec.reply = reply 283 codec.err = nil 284 var serverError error 285 if codec.server == nil { 286 serverError = ServeRequest(codec) 287 } else { 288 serverError = codec.server.ServeRequest(codec) 289 } 290 if codec.err == nil && serverError != nil { 291 codec.err = serverError 292 } 293 return codec.err 294 } 295 296 func (codec *CodecEmulator) ReadRequestHeader(req *Request) error { 297 req.ServiceMethod = codec.serviceMethod 298 req.Seq = 0 299 return nil 300 } 301 302 func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error { 303 if codec.args == nil { 304 return io.ErrUnexpectedEOF 305 } 306 *(argv.(*Args)) = *codec.args 307 return nil 308 } 309 310 func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error { 311 if resp.Error != "" { 312 codec.err = errors.New(resp.Error) 313 } else { 314 *codec.reply = *(reply.(*Reply)) 315 } 316 return nil 317 } 318 319 func (codec *CodecEmulator) Close() error { 320 return nil 321 } 322 323 func TestServeRequest(t *testing.T) { 324 once.Do(startServer) 325 testServeRequest(t, nil) 326 newOnce.Do(startNewServer) 327 testServeRequest(t, newServer) 328 } 329 330 func testServeRequest(t *testing.T, server *Server) { 331 client := CodecEmulator{server: server} 332 333 args := &Args{7, 8} 334 reply := new(Reply) 335 err := client.Call("Arith.Add", args, reply) 336 if err != nil { 337 t.Errorf("Add: expected no error but got string %q", err.Error()) 338 } 339 if reply.C != args.A+args.B { 340 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 341 } 342 343 err = client.Call("Arith.Add", nil, reply) 344 if err == nil { 345 t.Errorf("expected error calling Arith.Add with nil arg") 346 } 347 } 348 349 type ReplyNotPointer int 350 type ArgNotPublic int 351 type ReplyNotPublic int 352 type NeedsPtrType int 353 type local struct{} 354 355 func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error { 356 return nil 357 } 358 359 func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error { 360 return nil 361 } 362 363 func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error { 364 return nil 365 } 366 367 func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error { 368 return nil 369 } 370 371 // Check that registration handles lots of bad methods and a type with no suitable methods. 372 func TestRegistrationError(t *testing.T) { 373 err := Register(new(ReplyNotPointer)) 374 if err == nil { 375 t.Error("expected error registering ReplyNotPointer") 376 } 377 err = Register(new(ArgNotPublic)) 378 if err == nil { 379 t.Error("expected error registering ArgNotPublic") 380 } 381 err = Register(new(ReplyNotPublic)) 382 if err == nil { 383 t.Error("expected error registering ReplyNotPublic") 384 } 385 err = Register(NeedsPtrType(0)) 386 if err == nil { 387 t.Error("expected error registering NeedsPtrType") 388 } else if !strings.Contains(err.Error(), "pointer") { 389 t.Error("expected hint when registering NeedsPtrType") 390 } 391 } 392 393 type WriteFailCodec int 394 395 func (WriteFailCodec) WriteRequest(*Request, interface{}) error { 396 // the panic caused by this error used to not unlock a lock. 397 return errors.New("fail") 398 } 399 400 func (WriteFailCodec) ReadResponseHeader(*Response) error { 401 select {} 402 } 403 404 func (WriteFailCodec) ReadResponseBody(interface{}) error { 405 select {} 406 } 407 408 func (WriteFailCodec) Close() error { 409 return nil 410 } 411 412 func TestSendDeadlock(t *testing.T) { 413 client := NewClientWithCodec(WriteFailCodec(0)) 414 415 done := make(chan bool) 416 go func() { 417 testSendDeadlock(client) 418 testSendDeadlock(client) 419 done <- true 420 }() 421 select { 422 case <-done: 423 return 424 case <-time.After(5 * time.Second): 425 t.Fatal("deadlock") 426 } 427 } 428 429 func testSendDeadlock(client *Client) { 430 defer func() { 431 recover() 432 }() 433 args := &Args{7, 8} 434 reply := new(Reply) 435 client.Call("Arith.Add", args, reply) 436 } 437 438 func dialDirect() (*Client, error) { 439 return Dial("tcp", serverAddr) 440 } 441 442 func dialHTTP() (*Client, error) { 443 return DialHTTP("tcp", httpServerAddr) 444 } 445 446 func countMallocs(dial func() (*Client, error), t *testing.T) float64 { 447 once.Do(startServer) 448 client, err := dial() 449 if err != nil { 450 t.Fatal("error dialing", err) 451 } 452 args := &Args{7, 8} 453 reply := new(Reply) 454 return testing.AllocsPerRun(100, func() { 455 err := client.Call("Arith.Add", args, reply) 456 if err != nil { 457 t.Errorf("Add: expected no error but got string %q", err.Error()) 458 } 459 if reply.C != args.A+args.B { 460 t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B) 461 } 462 }) 463 } 464 465 func TestCountMallocs(t *testing.T) { 466 if runtime.GOMAXPROCS(0) > 1 { 467 t.Skip("skipping; GOMAXPROCS>1") 468 } 469 fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t)) 470 } 471 472 func TestCountMallocsOverHTTP(t *testing.T) { 473 if runtime.GOMAXPROCS(0) > 1 { 474 t.Skip("skipping; GOMAXPROCS>1") 475 } 476 fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t)) 477 } 478 479 type writeCrasher struct { 480 done chan bool 481 } 482 483 func (writeCrasher) Close() error { 484 return nil 485 } 486 487 func (w *writeCrasher) Read(p []byte) (int, error) { 488 <-w.done 489 return 0, io.EOF 490 } 491 492 func (writeCrasher) Write(p []byte) (int, error) { 493 return 0, errors.New("fake write failure") 494 } 495 496 func TestClientWriteError(t *testing.T) { 497 w := &writeCrasher{done: make(chan bool)} 498 c := NewClient(w) 499 res := false 500 err := c.Call("foo", 1, &res) 501 if err == nil { 502 t.Fatal("expected error") 503 } 504 if err.Error() != "fake write failure" { 505 t.Error("unexpected value of error:", err) 506 } 507 w.done <- true 508 } 509 510 func TestTCPClose(t *testing.T) { 511 once.Do(startServer) 512 513 client, err := dialHTTP() 514 if err != nil { 515 t.Fatalf("dialing: %v", err) 516 } 517 defer client.Close() 518 519 args := Args{17, 8} 520 var reply Reply 521 err = client.Call("Arith.Mul", args, &reply) 522 if err != nil { 523 t.Fatal("arith error:", err) 524 } 525 t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply) 526 if reply.C != args.A*args.B { 527 t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B) 528 } 529 } 530 531 func TestErrorAfterClientClose(t *testing.T) { 532 once.Do(startServer) 533 534 client, err := dialHTTP() 535 if err != nil { 536 t.Fatalf("dialing: %v", err) 537 } 538 err = client.Close() 539 if err != nil { 540 t.Fatal("close error:", err) 541 } 542 err = client.Call("Arith.Add", &Args{7, 9}, new(Reply)) 543 if err != ErrShutdown { 544 t.Errorf("Forever: expected ErrShutdown got %v", err) 545 } 546 } 547 548 func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { 549 b.StopTimer() 550 once.Do(startServer) 551 client, err := dial() 552 if err != nil { 553 b.Fatal("error dialing:", err) 554 } 555 556 // Synchronous calls 557 args := &Args{7, 8} 558 procs := runtime.GOMAXPROCS(-1) 559 N := int32(b.N) 560 var wg sync.WaitGroup 561 wg.Add(procs) 562 b.StartTimer() 563 564 for p := 0; p < procs; p++ { 565 go func() { 566 reply := new(Reply) 567 for atomic.AddInt32(&N, -1) >= 0 { 568 err := client.Call("Arith.Add", args, reply) 569 if err != nil { 570 b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error()) 571 } 572 if reply.C != args.A+args.B { 573 b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B) 574 } 575 } 576 wg.Done() 577 }() 578 } 579 wg.Wait() 580 } 581 582 func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) { 583 const MaxConcurrentCalls = 100 584 b.StopTimer() 585 once.Do(startServer) 586 client, err := dial() 587 if err != nil { 588 b.Fatal("error dialing:", err) 589 } 590 591 // Asynchronous calls 592 args := &Args{7, 8} 593 procs := 4 * runtime.GOMAXPROCS(-1) 594 send := int32(b.N) 595 recv := int32(b.N) 596 var wg sync.WaitGroup 597 wg.Add(procs) 598 gate := make(chan bool, MaxConcurrentCalls) 599 res := make(chan *Call, MaxConcurrentCalls) 600 b.StartTimer() 601 602 for p := 0; p < procs; p++ { 603 go func() { 604 for atomic.AddInt32(&send, -1) >= 0 { 605 gate <- true 606 reply := new(Reply) 607 client.Go("Arith.Add", args, reply, res) 608 } 609 }() 610 go func() { 611 for call := range res { 612 A := call.Args.(*Args).A 613 B := call.Args.(*Args).B 614 C := call.Reply.(*Reply).C 615 if A+B != C { 616 b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C) 617 } 618 <-gate 619 if atomic.AddInt32(&recv, -1) == 0 { 620 close(res) 621 } 622 } 623 wg.Done() 624 }() 625 } 626 wg.Wait() 627 } 628 629 func BenchmarkEndToEnd(b *testing.B) { 630 benchmarkEndToEnd(dialDirect, b) 631 } 632 633 func BenchmarkEndToEndHTTP(b *testing.B) { 634 benchmarkEndToEnd(dialHTTP, b) 635 } 636 637 func BenchmarkEndToEndAsync(b *testing.B) { 638 benchmarkEndToEndAsync(dialDirect, b) 639 } 640 641 func BenchmarkEndToEndAsyncHTTP(b *testing.B) { 642 benchmarkEndToEndAsync(dialHTTP, b) 643 }