github.com/xushiwei/go@v0.0.0-20130601165731-2b9d83f45bc9/src/pkg/net/rpc/server_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rpc
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"net/http/httptest"
    14  	"runtime"
    15  	"strings"
    16  	"sync"
    17  	"sync/atomic"
    18  	"testing"
    19  	"time"
    20  )
    21  
    22  var (
    23  	newServer                 *Server
    24  	serverAddr, newServerAddr string
    25  	httpServerAddr            string
    26  	once, newOnce, httpOnce   sync.Once
    27  )
    28  
    29  const (
    30  	newHttpPath = "/foo"
    31  )
    32  
    33  type Args struct {
    34  	A, B int
    35  }
    36  
    37  type Reply struct {
    38  	C int
    39  }
    40  
    41  type Arith int
    42  
    43  // Some of Arith's methods have value args, some have pointer args. That's deliberate.
    44  
    45  func (t *Arith) Add(args Args, reply *Reply) error {
    46  	reply.C = args.A + args.B
    47  	return nil
    48  }
    49  
    50  func (t *Arith) Mul(args *Args, reply *Reply) error {
    51  	reply.C = args.A * args.B
    52  	return nil
    53  }
    54  
    55  func (t *Arith) Div(args Args, reply *Reply) error {
    56  	if args.B == 0 {
    57  		return errors.New("divide by zero")
    58  	}
    59  	reply.C = args.A / args.B
    60  	return nil
    61  }
    62  
    63  func (t *Arith) String(args *Args, reply *string) error {
    64  	*reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
    65  	return nil
    66  }
    67  
    68  func (t *Arith) Scan(args string, reply *Reply) (err error) {
    69  	_, err = fmt.Sscan(args, &reply.C)
    70  	return
    71  }
    72  
    73  func (t *Arith) Error(args *Args, reply *Reply) error {
    74  	panic("ERROR")
    75  }
    76  
    77  func listenTCP() (net.Listener, string) {
    78  	l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
    79  	if e != nil {
    80  		log.Fatalf("net.Listen tcp :0: %v", e)
    81  	}
    82  	return l, l.Addr().String()
    83  }
    84  
    85  func startServer() {
    86  	Register(new(Arith))
    87  
    88  	var l net.Listener
    89  	l, serverAddr = listenTCP()
    90  	log.Println("Test RPC server listening on", serverAddr)
    91  	go Accept(l)
    92  
    93  	HandleHTTP()
    94  	httpOnce.Do(startHttpServer)
    95  }
    96  
    97  func startNewServer() {
    98  	newServer = NewServer()
    99  	newServer.Register(new(Arith))
   100  
   101  	var l net.Listener
   102  	l, newServerAddr = listenTCP()
   103  	log.Println("NewServer test RPC server listening on", newServerAddr)
   104  	go Accept(l)
   105  
   106  	newServer.HandleHTTP(newHttpPath, "/bar")
   107  	httpOnce.Do(startHttpServer)
   108  }
   109  
   110  func startHttpServer() {
   111  	server := httptest.NewServer(nil)
   112  	httpServerAddr = server.Listener.Addr().String()
   113  	log.Println("Test HTTP RPC server listening on", httpServerAddr)
   114  }
   115  
   116  func TestRPC(t *testing.T) {
   117  	once.Do(startServer)
   118  	testRPC(t, serverAddr)
   119  	newOnce.Do(startNewServer)
   120  	testRPC(t, newServerAddr)
   121  }
   122  
   123  func testRPC(t *testing.T, addr string) {
   124  	client, err := Dial("tcp", addr)
   125  	if err != nil {
   126  		t.Fatal("dialing", err)
   127  	}
   128  
   129  	// Synchronous calls
   130  	args := &Args{7, 8}
   131  	reply := new(Reply)
   132  	err = client.Call("Arith.Add", args, reply)
   133  	if err != nil {
   134  		t.Errorf("Add: expected no error but got string %q", err.Error())
   135  	}
   136  	if reply.C != args.A+args.B {
   137  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   138  	}
   139  
   140  	// Nonexistent method
   141  	args = &Args{7, 0}
   142  	reply = new(Reply)
   143  	err = client.Call("Arith.BadOperation", args, reply)
   144  	// expect an error
   145  	if err == nil {
   146  		t.Error("BadOperation: expected error")
   147  	} else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
   148  		t.Errorf("BadOperation: expected can't find method error; got %q", err)
   149  	}
   150  
   151  	// Unknown service
   152  	args = &Args{7, 8}
   153  	reply = new(Reply)
   154  	err = client.Call("Arith.Unknown", args, reply)
   155  	if err == nil {
   156  		t.Error("expected error calling unknown service")
   157  	} else if strings.Index(err.Error(), "method") < 0 {
   158  		t.Error("expected error about method; got", err)
   159  	}
   160  
   161  	// Out of order.
   162  	args = &Args{7, 8}
   163  	mulReply := new(Reply)
   164  	mulCall := client.Go("Arith.Mul", args, mulReply, nil)
   165  	addReply := new(Reply)
   166  	addCall := client.Go("Arith.Add", args, addReply, nil)
   167  
   168  	addCall = <-addCall.Done
   169  	if addCall.Error != nil {
   170  		t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
   171  	}
   172  	if addReply.C != args.A+args.B {
   173  		t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
   174  	}
   175  
   176  	mulCall = <-mulCall.Done
   177  	if mulCall.Error != nil {
   178  		t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
   179  	}
   180  	if mulReply.C != args.A*args.B {
   181  		t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
   182  	}
   183  
   184  	// Error test
   185  	args = &Args{7, 0}
   186  	reply = new(Reply)
   187  	err = client.Call("Arith.Div", args, reply)
   188  	// expect an error: zero divide
   189  	if err == nil {
   190  		t.Error("Div: expected error")
   191  	} else if err.Error() != "divide by zero" {
   192  		t.Error("Div: expected divide by zero error; got", err)
   193  	}
   194  
   195  	// Bad type.
   196  	reply = new(Reply)
   197  	err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
   198  	if err == nil {
   199  		t.Error("expected error calling Arith.Add with wrong arg type")
   200  	} else if strings.Index(err.Error(), "type") < 0 {
   201  		t.Error("expected error about type; got", err)
   202  	}
   203  
   204  	// Non-struct argument
   205  	const Val = 12345
   206  	str := fmt.Sprint(Val)
   207  	reply = new(Reply)
   208  	err = client.Call("Arith.Scan", &str, reply)
   209  	if err != nil {
   210  		t.Errorf("Scan: expected no error but got string %q", err.Error())
   211  	} else if reply.C != Val {
   212  		t.Errorf("Scan: expected %d got %d", Val, reply.C)
   213  	}
   214  
   215  	// Non-struct reply
   216  	args = &Args{27, 35}
   217  	str = ""
   218  	err = client.Call("Arith.String", args, &str)
   219  	if err != nil {
   220  		t.Errorf("String: expected no error but got string %q", err.Error())
   221  	}
   222  	expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
   223  	if str != expect {
   224  		t.Errorf("String: expected %s got %s", expect, str)
   225  	}
   226  
   227  	args = &Args{7, 8}
   228  	reply = new(Reply)
   229  	err = client.Call("Arith.Mul", args, reply)
   230  	if err != nil {
   231  		t.Errorf("Mul: expected no error but got string %q", err.Error())
   232  	}
   233  	if reply.C != args.A*args.B {
   234  		t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
   235  	}
   236  }
   237  
   238  func TestHTTP(t *testing.T) {
   239  	once.Do(startServer)
   240  	testHTTPRPC(t, "")
   241  	newOnce.Do(startNewServer)
   242  	testHTTPRPC(t, newHttpPath)
   243  }
   244  
   245  func testHTTPRPC(t *testing.T, path string) {
   246  	var client *Client
   247  	var err error
   248  	if path == "" {
   249  		client, err = DialHTTP("tcp", httpServerAddr)
   250  	} else {
   251  		client, err = DialHTTPPath("tcp", httpServerAddr, path)
   252  	}
   253  	if err != nil {
   254  		t.Fatal("dialing", err)
   255  	}
   256  
   257  	// Synchronous calls
   258  	args := &Args{7, 8}
   259  	reply := new(Reply)
   260  	err = client.Call("Arith.Add", args, reply)
   261  	if err != nil {
   262  		t.Errorf("Add: expected no error but got string %q", err.Error())
   263  	}
   264  	if reply.C != args.A+args.B {
   265  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   266  	}
   267  }
   268  
   269  // CodecEmulator provides a client-like api and a ServerCodec interface.
   270  // Can be used to test ServeRequest.
   271  type CodecEmulator struct {
   272  	server        *Server
   273  	serviceMethod string
   274  	args          *Args
   275  	reply         *Reply
   276  	err           error
   277  }
   278  
   279  func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error {
   280  	codec.serviceMethod = serviceMethod
   281  	codec.args = args
   282  	codec.reply = reply
   283  	codec.err = nil
   284  	var serverError error
   285  	if codec.server == nil {
   286  		serverError = ServeRequest(codec)
   287  	} else {
   288  		serverError = codec.server.ServeRequest(codec)
   289  	}
   290  	if codec.err == nil && serverError != nil {
   291  		codec.err = serverError
   292  	}
   293  	return codec.err
   294  }
   295  
   296  func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
   297  	req.ServiceMethod = codec.serviceMethod
   298  	req.Seq = 0
   299  	return nil
   300  }
   301  
   302  func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error {
   303  	if codec.args == nil {
   304  		return io.ErrUnexpectedEOF
   305  	}
   306  	*(argv.(*Args)) = *codec.args
   307  	return nil
   308  }
   309  
   310  func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error {
   311  	if resp.Error != "" {
   312  		codec.err = errors.New(resp.Error)
   313  	} else {
   314  		*codec.reply = *(reply.(*Reply))
   315  	}
   316  	return nil
   317  }
   318  
   319  func (codec *CodecEmulator) Close() error {
   320  	return nil
   321  }
   322  
   323  func TestServeRequest(t *testing.T) {
   324  	once.Do(startServer)
   325  	testServeRequest(t, nil)
   326  	newOnce.Do(startNewServer)
   327  	testServeRequest(t, newServer)
   328  }
   329  
   330  func testServeRequest(t *testing.T, server *Server) {
   331  	client := CodecEmulator{server: server}
   332  
   333  	args := &Args{7, 8}
   334  	reply := new(Reply)
   335  	err := client.Call("Arith.Add", args, reply)
   336  	if err != nil {
   337  		t.Errorf("Add: expected no error but got string %q", err.Error())
   338  	}
   339  	if reply.C != args.A+args.B {
   340  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   341  	}
   342  
   343  	err = client.Call("Arith.Add", nil, reply)
   344  	if err == nil {
   345  		t.Errorf("expected error calling Arith.Add with nil arg")
   346  	}
   347  }
   348  
   349  type ReplyNotPointer int
   350  type ArgNotPublic int
   351  type ReplyNotPublic int
   352  type NeedsPtrType int
   353  type local struct{}
   354  
   355  func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
   356  	return nil
   357  }
   358  
   359  func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error {
   360  	return nil
   361  }
   362  
   363  func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
   364  	return nil
   365  }
   366  
   367  func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
   368  	return nil
   369  }
   370  
   371  // Check that registration handles lots of bad methods and a type with no suitable methods.
   372  func TestRegistrationError(t *testing.T) {
   373  	err := Register(new(ReplyNotPointer))
   374  	if err == nil {
   375  		t.Error("expected error registering ReplyNotPointer")
   376  	}
   377  	err = Register(new(ArgNotPublic))
   378  	if err == nil {
   379  		t.Error("expected error registering ArgNotPublic")
   380  	}
   381  	err = Register(new(ReplyNotPublic))
   382  	if err == nil {
   383  		t.Error("expected error registering ReplyNotPublic")
   384  	}
   385  	err = Register(NeedsPtrType(0))
   386  	if err == nil {
   387  		t.Error("expected error registering NeedsPtrType")
   388  	} else if !strings.Contains(err.Error(), "pointer") {
   389  		t.Error("expected hint when registering NeedsPtrType")
   390  	}
   391  }
   392  
   393  type WriteFailCodec int
   394  
   395  func (WriteFailCodec) WriteRequest(*Request, interface{}) error {
   396  	// the panic caused by this error used to not unlock a lock.
   397  	return errors.New("fail")
   398  }
   399  
   400  func (WriteFailCodec) ReadResponseHeader(*Response) error {
   401  	select {}
   402  }
   403  
   404  func (WriteFailCodec) ReadResponseBody(interface{}) error {
   405  	select {}
   406  }
   407  
   408  func (WriteFailCodec) Close() error {
   409  	return nil
   410  }
   411  
   412  func TestSendDeadlock(t *testing.T) {
   413  	client := NewClientWithCodec(WriteFailCodec(0))
   414  
   415  	done := make(chan bool)
   416  	go func() {
   417  		testSendDeadlock(client)
   418  		testSendDeadlock(client)
   419  		done <- true
   420  	}()
   421  	select {
   422  	case <-done:
   423  		return
   424  	case <-time.After(5 * time.Second):
   425  		t.Fatal("deadlock")
   426  	}
   427  }
   428  
   429  func testSendDeadlock(client *Client) {
   430  	defer func() {
   431  		recover()
   432  	}()
   433  	args := &Args{7, 8}
   434  	reply := new(Reply)
   435  	client.Call("Arith.Add", args, reply)
   436  }
   437  
   438  func dialDirect() (*Client, error) {
   439  	return Dial("tcp", serverAddr)
   440  }
   441  
   442  func dialHTTP() (*Client, error) {
   443  	return DialHTTP("tcp", httpServerAddr)
   444  }
   445  
   446  func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
   447  	once.Do(startServer)
   448  	client, err := dial()
   449  	if err != nil {
   450  		t.Fatal("error dialing", err)
   451  	}
   452  	args := &Args{7, 8}
   453  	reply := new(Reply)
   454  	return testing.AllocsPerRun(100, func() {
   455  		err := client.Call("Arith.Add", args, reply)
   456  		if err != nil {
   457  			t.Errorf("Add: expected no error but got string %q", err.Error())
   458  		}
   459  		if reply.C != args.A+args.B {
   460  			t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   461  		}
   462  	})
   463  }
   464  
   465  func TestCountMallocs(t *testing.T) {
   466  	if runtime.GOMAXPROCS(0) > 1 {
   467  		t.Skip("skipping; GOMAXPROCS>1")
   468  	}
   469  	fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
   470  }
   471  
   472  func TestCountMallocsOverHTTP(t *testing.T) {
   473  	if runtime.GOMAXPROCS(0) > 1 {
   474  		t.Skip("skipping; GOMAXPROCS>1")
   475  	}
   476  	fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
   477  }
   478  
   479  type writeCrasher struct {
   480  	done chan bool
   481  }
   482  
   483  func (writeCrasher) Close() error {
   484  	return nil
   485  }
   486  
   487  func (w *writeCrasher) Read(p []byte) (int, error) {
   488  	<-w.done
   489  	return 0, io.EOF
   490  }
   491  
   492  func (writeCrasher) Write(p []byte) (int, error) {
   493  	return 0, errors.New("fake write failure")
   494  }
   495  
   496  func TestClientWriteError(t *testing.T) {
   497  	w := &writeCrasher{done: make(chan bool)}
   498  	c := NewClient(w)
   499  	res := false
   500  	err := c.Call("foo", 1, &res)
   501  	if err == nil {
   502  		t.Fatal("expected error")
   503  	}
   504  	if err.Error() != "fake write failure" {
   505  		t.Error("unexpected value of error:", err)
   506  	}
   507  	w.done <- true
   508  }
   509  
   510  func TestTCPClose(t *testing.T) {
   511  	once.Do(startServer)
   512  
   513  	client, err := dialHTTP()
   514  	if err != nil {
   515  		t.Fatalf("dialing: %v", err)
   516  	}
   517  	defer client.Close()
   518  
   519  	args := Args{17, 8}
   520  	var reply Reply
   521  	err = client.Call("Arith.Mul", args, &reply)
   522  	if err != nil {
   523  		t.Fatal("arith error:", err)
   524  	}
   525  	t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
   526  	if reply.C != args.A*args.B {
   527  		t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
   528  	}
   529  }
   530  
   531  func TestErrorAfterClientClose(t *testing.T) {
   532  	once.Do(startServer)
   533  
   534  	client, err := dialHTTP()
   535  	if err != nil {
   536  		t.Fatalf("dialing: %v", err)
   537  	}
   538  	err = client.Close()
   539  	if err != nil {
   540  		t.Fatal("close error:", err)
   541  	}
   542  	err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
   543  	if err != ErrShutdown {
   544  		t.Errorf("Forever: expected ErrShutdown got %v", err)
   545  	}
   546  }
   547  
   548  func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
   549  	b.StopTimer()
   550  	once.Do(startServer)
   551  	client, err := dial()
   552  	if err != nil {
   553  		b.Fatal("error dialing:", err)
   554  	}
   555  
   556  	// Synchronous calls
   557  	args := &Args{7, 8}
   558  	procs := runtime.GOMAXPROCS(-1)
   559  	N := int32(b.N)
   560  	var wg sync.WaitGroup
   561  	wg.Add(procs)
   562  	b.StartTimer()
   563  
   564  	for p := 0; p < procs; p++ {
   565  		go func() {
   566  			reply := new(Reply)
   567  			for atomic.AddInt32(&N, -1) >= 0 {
   568  				err := client.Call("Arith.Add", args, reply)
   569  				if err != nil {
   570  					b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
   571  				}
   572  				if reply.C != args.A+args.B {
   573  					b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
   574  				}
   575  			}
   576  			wg.Done()
   577  		}()
   578  	}
   579  	wg.Wait()
   580  }
   581  
   582  func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
   583  	const MaxConcurrentCalls = 100
   584  	b.StopTimer()
   585  	once.Do(startServer)
   586  	client, err := dial()
   587  	if err != nil {
   588  		b.Fatal("error dialing:", err)
   589  	}
   590  
   591  	// Asynchronous calls
   592  	args := &Args{7, 8}
   593  	procs := 4 * runtime.GOMAXPROCS(-1)
   594  	send := int32(b.N)
   595  	recv := int32(b.N)
   596  	var wg sync.WaitGroup
   597  	wg.Add(procs)
   598  	gate := make(chan bool, MaxConcurrentCalls)
   599  	res := make(chan *Call, MaxConcurrentCalls)
   600  	b.StartTimer()
   601  
   602  	for p := 0; p < procs; p++ {
   603  		go func() {
   604  			for atomic.AddInt32(&send, -1) >= 0 {
   605  				gate <- true
   606  				reply := new(Reply)
   607  				client.Go("Arith.Add", args, reply, res)
   608  			}
   609  		}()
   610  		go func() {
   611  			for call := range res {
   612  				A := call.Args.(*Args).A
   613  				B := call.Args.(*Args).B
   614  				C := call.Reply.(*Reply).C
   615  				if A+B != C {
   616  					b.Fatalf("incorrect reply: Add: expected %d got %d", A+B, C)
   617  				}
   618  				<-gate
   619  				if atomic.AddInt32(&recv, -1) == 0 {
   620  					close(res)
   621  				}
   622  			}
   623  			wg.Done()
   624  		}()
   625  	}
   626  	wg.Wait()
   627  }
   628  
   629  func BenchmarkEndToEnd(b *testing.B) {
   630  	benchmarkEndToEnd(dialDirect, b)
   631  }
   632  
   633  func BenchmarkEndToEndHTTP(b *testing.B) {
   634  	benchmarkEndToEnd(dialHTTP, b)
   635  }
   636  
   637  func BenchmarkEndToEndAsync(b *testing.B) {
   638  	benchmarkEndToEndAsync(dialDirect, b)
   639  }
   640  
   641  func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
   642  	benchmarkEndToEndAsync(dialHTTP, b)
   643  }