github.com/peggyl/go@v0.0.0-20151008231540-ae315999c2d5/src/net/rpc/server_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rpc
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"net/http/httptest"
    14  	"runtime"
    15  	"strings"
    16  	"sync"
    17  	"sync/atomic"
    18  	"testing"
    19  	"time"
    20  )
    21  
    22  var (
    23  	newServer                 *Server
    24  	serverAddr, newServerAddr string
    25  	httpServerAddr            string
    26  	once, newOnce, httpOnce   sync.Once
    27  )
    28  
    29  const (
    30  	newHttpPath = "/foo"
    31  )
    32  
    33  type Args struct {
    34  	A, B int
    35  }
    36  
    37  type Reply struct {
    38  	C int
    39  }
    40  
    41  type Arith int
    42  
    43  // Some of Arith's methods have value args, some have pointer args. That's deliberate.
    44  
    45  func (t *Arith) Add(args Args, reply *Reply) error {
    46  	reply.C = args.A + args.B
    47  	return nil
    48  }
    49  
    50  func (t *Arith) Mul(args *Args, reply *Reply) error {
    51  	reply.C = args.A * args.B
    52  	return nil
    53  }
    54  
    55  func (t *Arith) Div(args Args, reply *Reply) error {
    56  	if args.B == 0 {
    57  		return errors.New("divide by zero")
    58  	}
    59  	reply.C = args.A / args.B
    60  	return nil
    61  }
    62  
    63  func (t *Arith) String(args *Args, reply *string) error {
    64  	*reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
    65  	return nil
    66  }
    67  
    68  func (t *Arith) Scan(args string, reply *Reply) (err error) {
    69  	_, err = fmt.Sscan(args, &reply.C)
    70  	return
    71  }
    72  
    73  func (t *Arith) Error(args *Args, reply *Reply) error {
    74  	panic("ERROR")
    75  }
    76  
    77  func listenTCP() (net.Listener, string) {
    78  	l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
    79  	if e != nil {
    80  		log.Fatalf("net.Listen tcp :0: %v", e)
    81  	}
    82  	return l, l.Addr().String()
    83  }
    84  
    85  func startServer() {
    86  	Register(new(Arith))
    87  	RegisterName("net.rpc.Arith", new(Arith))
    88  
    89  	var l net.Listener
    90  	l, serverAddr = listenTCP()
    91  	log.Println("Test RPC server listening on", serverAddr)
    92  	go Accept(l)
    93  
    94  	HandleHTTP()
    95  	httpOnce.Do(startHttpServer)
    96  }
    97  
    98  func startNewServer() {
    99  	newServer = NewServer()
   100  	newServer.Register(new(Arith))
   101  	newServer.RegisterName("net.rpc.Arith", new(Arith))
   102  	newServer.RegisterName("newServer.Arith", new(Arith))
   103  
   104  	var l net.Listener
   105  	l, newServerAddr = listenTCP()
   106  	log.Println("NewServer test RPC server listening on", newServerAddr)
   107  	go newServer.Accept(l)
   108  
   109  	newServer.HandleHTTP(newHttpPath, "/bar")
   110  	httpOnce.Do(startHttpServer)
   111  }
   112  
   113  func startHttpServer() {
   114  	server := httptest.NewServer(nil)
   115  	httpServerAddr = server.Listener.Addr().String()
   116  	log.Println("Test HTTP RPC server listening on", httpServerAddr)
   117  }
   118  
   119  func TestRPC(t *testing.T) {
   120  	once.Do(startServer)
   121  	testRPC(t, serverAddr)
   122  	newOnce.Do(startNewServer)
   123  	testRPC(t, newServerAddr)
   124  	testNewServerRPC(t, newServerAddr)
   125  }
   126  
   127  func testRPC(t *testing.T, addr string) {
   128  	client, err := Dial("tcp", addr)
   129  	if err != nil {
   130  		t.Fatal("dialing", err)
   131  	}
   132  	defer client.Close()
   133  
   134  	// Synchronous calls
   135  	args := &Args{7, 8}
   136  	reply := new(Reply)
   137  	err = client.Call("Arith.Add", args, reply)
   138  	if err != nil {
   139  		t.Errorf("Add: expected no error but got string %q", err.Error())
   140  	}
   141  	if reply.C != args.A+args.B {
   142  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   143  	}
   144  
   145  	// Nonexistent method
   146  	args = &Args{7, 0}
   147  	reply = new(Reply)
   148  	err = client.Call("Arith.BadOperation", args, reply)
   149  	// expect an error
   150  	if err == nil {
   151  		t.Error("BadOperation: expected error")
   152  	} else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
   153  		t.Errorf("BadOperation: expected can't find method error; got %q", err)
   154  	}
   155  
   156  	// Unknown service
   157  	args = &Args{7, 8}
   158  	reply = new(Reply)
   159  	err = client.Call("Arith.Unknown", args, reply)
   160  	if err == nil {
   161  		t.Error("expected error calling unknown service")
   162  	} else if strings.Index(err.Error(), "method") < 0 {
   163  		t.Error("expected error about method; got", err)
   164  	}
   165  
   166  	// Out of order.
   167  	args = &Args{7, 8}
   168  	mulReply := new(Reply)
   169  	mulCall := client.Go("Arith.Mul", args, mulReply, nil)
   170  	addReply := new(Reply)
   171  	addCall := client.Go("Arith.Add", args, addReply, nil)
   172  
   173  	addCall = <-addCall.Done
   174  	if addCall.Error != nil {
   175  		t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
   176  	}
   177  	if addReply.C != args.A+args.B {
   178  		t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
   179  	}
   180  
   181  	mulCall = <-mulCall.Done
   182  	if mulCall.Error != nil {
   183  		t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
   184  	}
   185  	if mulReply.C != args.A*args.B {
   186  		t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
   187  	}
   188  
   189  	// Error test
   190  	args = &Args{7, 0}
   191  	reply = new(Reply)
   192  	err = client.Call("Arith.Div", args, reply)
   193  	// expect an error: zero divide
   194  	if err == nil {
   195  		t.Error("Div: expected error")
   196  	} else if err.Error() != "divide by zero" {
   197  		t.Error("Div: expected divide by zero error; got", err)
   198  	}
   199  
   200  	// Bad type.
   201  	reply = new(Reply)
   202  	err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
   203  	if err == nil {
   204  		t.Error("expected error calling Arith.Add with wrong arg type")
   205  	} else if strings.Index(err.Error(), "type") < 0 {
   206  		t.Error("expected error about type; got", err)
   207  	}
   208  
   209  	// Non-struct argument
   210  	const Val = 12345
   211  	str := fmt.Sprint(Val)
   212  	reply = new(Reply)
   213  	err = client.Call("Arith.Scan", &str, reply)
   214  	if err != nil {
   215  		t.Errorf("Scan: expected no error but got string %q", err.Error())
   216  	} else if reply.C != Val {
   217  		t.Errorf("Scan: expected %d got %d", Val, reply.C)
   218  	}
   219  
   220  	// Non-struct reply
   221  	args = &Args{27, 35}
   222  	str = ""
   223  	err = client.Call("Arith.String", args, &str)
   224  	if err != nil {
   225  		t.Errorf("String: expected no error but got string %q", err.Error())
   226  	}
   227  	expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
   228  	if str != expect {
   229  		t.Errorf("String: expected %s got %s", expect, str)
   230  	}
   231  
   232  	args = &Args{7, 8}
   233  	reply = new(Reply)
   234  	err = client.Call("Arith.Mul", args, reply)
   235  	if err != nil {
   236  		t.Errorf("Mul: expected no error but got string %q", err.Error())
   237  	}
   238  	if reply.C != args.A*args.B {
   239  		t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
   240  	}
   241  
   242  	// ServiceName contain "." character
   243  	args = &Args{7, 8}
   244  	reply = new(Reply)
   245  	err = client.Call("net.rpc.Arith.Add", args, reply)
   246  	if err != nil {
   247  		t.Errorf("Add: expected no error but got string %q", err.Error())
   248  	}
   249  	if reply.C != args.A+args.B {
   250  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   251  	}
   252  }
   253  
   254  func testNewServerRPC(t *testing.T, addr string) {
   255  	client, err := Dial("tcp", addr)
   256  	if err != nil {
   257  		t.Fatal("dialing", err)
   258  	}
   259  	defer client.Close()
   260  
   261  	// Synchronous calls
   262  	args := &Args{7, 8}
   263  	reply := new(Reply)
   264  	err = client.Call("newServer.Arith.Add", args, reply)
   265  	if err != nil {
   266  		t.Errorf("Add: expected no error but got string %q", err.Error())
   267  	}
   268  	if reply.C != args.A+args.B {
   269  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   270  	}
   271  }
   272  
   273  func TestHTTP(t *testing.T) {
   274  	once.Do(startServer)
   275  	testHTTPRPC(t, "")
   276  	newOnce.Do(startNewServer)
   277  	testHTTPRPC(t, newHttpPath)
   278  }
   279  
   280  func testHTTPRPC(t *testing.T, path string) {
   281  	var client *Client
   282  	var err error
   283  	if path == "" {
   284  		client, err = DialHTTP("tcp", httpServerAddr)
   285  	} else {
   286  		client, err = DialHTTPPath("tcp", httpServerAddr, path)
   287  	}
   288  	if err != nil {
   289  		t.Fatal("dialing", err)
   290  	}
   291  	defer client.Close()
   292  
   293  	// Synchronous calls
   294  	args := &Args{7, 8}
   295  	reply := new(Reply)
   296  	err = client.Call("Arith.Add", args, reply)
   297  	if err != nil {
   298  		t.Errorf("Add: expected no error but got string %q", err.Error())
   299  	}
   300  	if reply.C != args.A+args.B {
   301  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   302  	}
   303  }
   304  
   305  // CodecEmulator provides a client-like api and a ServerCodec interface.
   306  // Can be used to test ServeRequest.
   307  type CodecEmulator struct {
   308  	server        *Server
   309  	serviceMethod string
   310  	args          *Args
   311  	reply         *Reply
   312  	err           error
   313  }
   314  
   315  func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error {
   316  	codec.serviceMethod = serviceMethod
   317  	codec.args = args
   318  	codec.reply = reply
   319  	codec.err = nil
   320  	var serverError error
   321  	if codec.server == nil {
   322  		serverError = ServeRequest(codec)
   323  	} else {
   324  		serverError = codec.server.ServeRequest(codec)
   325  	}
   326  	if codec.err == nil && serverError != nil {
   327  		codec.err = serverError
   328  	}
   329  	return codec.err
   330  }
   331  
   332  func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
   333  	req.ServiceMethod = codec.serviceMethod
   334  	req.Seq = 0
   335  	return nil
   336  }
   337  
   338  func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error {
   339  	if codec.args == nil {
   340  		return io.ErrUnexpectedEOF
   341  	}
   342  	*(argv.(*Args)) = *codec.args
   343  	return nil
   344  }
   345  
   346  func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error {
   347  	if resp.Error != "" {
   348  		codec.err = errors.New(resp.Error)
   349  	} else {
   350  		*codec.reply = *(reply.(*Reply))
   351  	}
   352  	return nil
   353  }
   354  
   355  func (codec *CodecEmulator) Close() error {
   356  	return nil
   357  }
   358  
   359  func TestServeRequest(t *testing.T) {
   360  	once.Do(startServer)
   361  	testServeRequest(t, nil)
   362  	newOnce.Do(startNewServer)
   363  	testServeRequest(t, newServer)
   364  }
   365  
   366  func testServeRequest(t *testing.T, server *Server) {
   367  	client := CodecEmulator{server: server}
   368  	defer client.Close()
   369  
   370  	args := &Args{7, 8}
   371  	reply := new(Reply)
   372  	err := client.Call("Arith.Add", args, reply)
   373  	if err != nil {
   374  		t.Errorf("Add: expected no error but got string %q", err.Error())
   375  	}
   376  	if reply.C != args.A+args.B {
   377  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   378  	}
   379  
   380  	err = client.Call("Arith.Add", nil, reply)
   381  	if err == nil {
   382  		t.Errorf("expected error calling Arith.Add with nil arg")
   383  	}
   384  }
   385  
   386  type ReplyNotPointer int
   387  type ArgNotPublic int
   388  type ReplyNotPublic int
   389  type NeedsPtrType int
   390  type local struct{}
   391  
   392  func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
   393  	return nil
   394  }
   395  
   396  func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error {
   397  	return nil
   398  }
   399  
   400  func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
   401  	return nil
   402  }
   403  
   404  func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
   405  	return nil
   406  }
   407  
   408  // Check that registration handles lots of bad methods and a type with no suitable methods.
   409  func TestRegistrationError(t *testing.T) {
   410  	err := Register(new(ReplyNotPointer))
   411  	if err == nil {
   412  		t.Error("expected error registering ReplyNotPointer")
   413  	}
   414  	err = Register(new(ArgNotPublic))
   415  	if err == nil {
   416  		t.Error("expected error registering ArgNotPublic")
   417  	}
   418  	err = Register(new(ReplyNotPublic))
   419  	if err == nil {
   420  		t.Error("expected error registering ReplyNotPublic")
   421  	}
   422  	err = Register(NeedsPtrType(0))
   423  	if err == nil {
   424  		t.Error("expected error registering NeedsPtrType")
   425  	} else if !strings.Contains(err.Error(), "pointer") {
   426  		t.Error("expected hint when registering NeedsPtrType")
   427  	}
   428  }
   429  
   430  type WriteFailCodec int
   431  
   432  func (WriteFailCodec) WriteRequest(*Request, interface{}) error {
   433  	// the panic caused by this error used to not unlock a lock.
   434  	return errors.New("fail")
   435  }
   436  
   437  func (WriteFailCodec) ReadResponseHeader(*Response) error {
   438  	select {}
   439  }
   440  
   441  func (WriteFailCodec) ReadResponseBody(interface{}) error {
   442  	select {}
   443  }
   444  
   445  func (WriteFailCodec) Close() error {
   446  	return nil
   447  }
   448  
   449  func TestSendDeadlock(t *testing.T) {
   450  	client := NewClientWithCodec(WriteFailCodec(0))
   451  	defer client.Close()
   452  
   453  	done := make(chan bool)
   454  	go func() {
   455  		testSendDeadlock(client)
   456  		testSendDeadlock(client)
   457  		done <- true
   458  	}()
   459  	select {
   460  	case <-done:
   461  		return
   462  	case <-time.After(5 * time.Second):
   463  		t.Fatal("deadlock")
   464  	}
   465  }
   466  
   467  func testSendDeadlock(client *Client) {
   468  	defer func() {
   469  		recover()
   470  	}()
   471  	args := &Args{7, 8}
   472  	reply := new(Reply)
   473  	client.Call("Arith.Add", args, reply)
   474  }
   475  
   476  func dialDirect() (*Client, error) {
   477  	return Dial("tcp", serverAddr)
   478  }
   479  
   480  func dialHTTP() (*Client, error) {
   481  	return DialHTTP("tcp", httpServerAddr)
   482  }
   483  
   484  func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
   485  	once.Do(startServer)
   486  	client, err := dial()
   487  	if err != nil {
   488  		t.Fatal("error dialing", err)
   489  	}
   490  	defer client.Close()
   491  
   492  	args := &Args{7, 8}
   493  	reply := new(Reply)
   494  	return testing.AllocsPerRun(100, func() {
   495  		err := client.Call("Arith.Add", args, reply)
   496  		if err != nil {
   497  			t.Errorf("Add: expected no error but got string %q", err.Error())
   498  		}
   499  		if reply.C != args.A+args.B {
   500  			t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   501  		}
   502  	})
   503  }
   504  
   505  func TestCountMallocs(t *testing.T) {
   506  	if testing.Short() {
   507  		t.Skip("skipping malloc count in short mode")
   508  	}
   509  	if runtime.GOMAXPROCS(0) > 1 {
   510  		t.Skip("skipping; GOMAXPROCS>1")
   511  	}
   512  	fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
   513  }
   514  
   515  func TestCountMallocsOverHTTP(t *testing.T) {
   516  	if testing.Short() {
   517  		t.Skip("skipping malloc count in short mode")
   518  	}
   519  	if runtime.GOMAXPROCS(0) > 1 {
   520  		t.Skip("skipping; GOMAXPROCS>1")
   521  	}
   522  	fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
   523  }
   524  
   525  type writeCrasher struct {
   526  	done chan bool
   527  }
   528  
   529  func (writeCrasher) Close() error {
   530  	return nil
   531  }
   532  
   533  func (w *writeCrasher) Read(p []byte) (int, error) {
   534  	<-w.done
   535  	return 0, io.EOF
   536  }
   537  
   538  func (writeCrasher) Write(p []byte) (int, error) {
   539  	return 0, errors.New("fake write failure")
   540  }
   541  
   542  func TestClientWriteError(t *testing.T) {
   543  	w := &writeCrasher{done: make(chan bool)}
   544  	c := NewClient(w)
   545  	defer c.Close()
   546  
   547  	res := false
   548  	err := c.Call("foo", 1, &res)
   549  	if err == nil {
   550  		t.Fatal("expected error")
   551  	}
   552  	if err.Error() != "fake write failure" {
   553  		t.Error("unexpected value of error:", err)
   554  	}
   555  	w.done <- true
   556  }
   557  
   558  func TestTCPClose(t *testing.T) {
   559  	once.Do(startServer)
   560  
   561  	client, err := dialHTTP()
   562  	if err != nil {
   563  		t.Fatalf("dialing: %v", err)
   564  	}
   565  	defer client.Close()
   566  
   567  	args := Args{17, 8}
   568  	var reply Reply
   569  	err = client.Call("Arith.Mul", args, &reply)
   570  	if err != nil {
   571  		t.Fatal("arith error:", err)
   572  	}
   573  	t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
   574  	if reply.C != args.A*args.B {
   575  		t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
   576  	}
   577  }
   578  
   579  func TestErrorAfterClientClose(t *testing.T) {
   580  	once.Do(startServer)
   581  
   582  	client, err := dialHTTP()
   583  	if err != nil {
   584  		t.Fatalf("dialing: %v", err)
   585  	}
   586  	err = client.Close()
   587  	if err != nil {
   588  		t.Fatal("close error:", err)
   589  	}
   590  	err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
   591  	if err != ErrShutdown {
   592  		t.Errorf("Forever: expected ErrShutdown got %v", err)
   593  	}
   594  }
   595  
   596  // Tests the fix to issue 11221. Without the fix, this loops forever or crashes.
   597  func TestAcceptExitAfterListenerClose(t *testing.T) {
   598  	newServer = NewServer()
   599  	newServer.Register(new(Arith))
   600  	newServer.RegisterName("net.rpc.Arith", new(Arith))
   601  	newServer.RegisterName("newServer.Arith", new(Arith))
   602  
   603  	var l net.Listener
   604  	l, newServerAddr = listenTCP()
   605  	l.Close()
   606  	newServer.Accept(l)
   607  }
   608  
   609  func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
   610  	once.Do(startServer)
   611  	client, err := dial()
   612  	if err != nil {
   613  		b.Fatal("error dialing:", err)
   614  	}
   615  	defer client.Close()
   616  
   617  	// Synchronous calls
   618  	args := &Args{7, 8}
   619  	b.ResetTimer()
   620  
   621  	b.RunParallel(func(pb *testing.PB) {
   622  		reply := new(Reply)
   623  		for pb.Next() {
   624  			err := client.Call("Arith.Add", args, reply)
   625  			if err != nil {
   626  				b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
   627  			}
   628  			if reply.C != args.A+args.B {
   629  				b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
   630  			}
   631  		}
   632  	})
   633  }
   634  
   635  func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
   636  	const MaxConcurrentCalls = 100
   637  	once.Do(startServer)
   638  	client, err := dial()
   639  	if err != nil {
   640  		b.Fatal("error dialing:", err)
   641  	}
   642  	defer client.Close()
   643  
   644  	// Asynchronous calls
   645  	args := &Args{7, 8}
   646  	procs := 4 * runtime.GOMAXPROCS(-1)
   647  	send := int32(b.N)
   648  	recv := int32(b.N)
   649  	var wg sync.WaitGroup
   650  	wg.Add(procs)
   651  	gate := make(chan bool, MaxConcurrentCalls)
   652  	res := make(chan *Call, MaxConcurrentCalls)
   653  	b.ResetTimer()
   654  
   655  	for p := 0; p < procs; p++ {
   656  		go func() {
   657  			for atomic.AddInt32(&send, -1) >= 0 {
   658  				gate <- true
   659  				reply := new(Reply)
   660  				client.Go("Arith.Add", args, reply, res)
   661  			}
   662  		}()
   663  		go func() {
   664  			for call := range res {
   665  				A := call.Args.(*Args).A
   666  				B := call.Args.(*Args).B
   667  				C := call.Reply.(*Reply).C
   668  				if A+B != C {
   669  					b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C)
   670  				}
   671  				<-gate
   672  				if atomic.AddInt32(&recv, -1) == 0 {
   673  					close(res)
   674  				}
   675  			}
   676  			wg.Done()
   677  		}()
   678  	}
   679  	wg.Wait()
   680  }
   681  
   682  func BenchmarkEndToEnd(b *testing.B) {
   683  	benchmarkEndToEnd(dialDirect, b)
   684  }
   685  
   686  func BenchmarkEndToEndHTTP(b *testing.B) {
   687  	benchmarkEndToEnd(dialHTTP, b)
   688  }
   689  
   690  func BenchmarkEndToEndAsync(b *testing.B) {
   691  	benchmarkEndToEndAsync(dialDirect, b)
   692  }
   693  
   694  func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
   695  	benchmarkEndToEndAsync(dialHTTP, b)
   696  }