github.com/slayercat/go@v0.0.0-20170428012452-c51559813f61/src/net/rpc/server_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rpc
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"net/http/httptest"
    14  	"reflect"
    15  	"runtime"
    16  	"strings"
    17  	"sync"
    18  	"sync/atomic"
    19  	"testing"
    20  	"time"
    21  )
    22  
    23  var (
    24  	newServer                 *Server
    25  	serverAddr, newServerAddr string
    26  	httpServerAddr            string
    27  	once, newOnce, httpOnce   sync.Once
    28  )
    29  
    30  const (
    31  	newHttpPath = "/foo"
    32  )
    33  
    34  type Args struct {
    35  	A, B int
    36  }
    37  
    38  type Reply struct {
    39  	C int
    40  }
    41  
    42  type Arith int
    43  
    44  // Some of Arith's methods have value args, some have pointer args. That's deliberate.
    45  
    46  func (t *Arith) Add(args Args, reply *Reply) error {
    47  	reply.C = args.A + args.B
    48  	return nil
    49  }
    50  
    51  func (t *Arith) Mul(args *Args, reply *Reply) error {
    52  	reply.C = args.A * args.B
    53  	return nil
    54  }
    55  
    56  func (t *Arith) Div(args Args, reply *Reply) error {
    57  	if args.B == 0 {
    58  		return errors.New("divide by zero")
    59  	}
    60  	reply.C = args.A / args.B
    61  	return nil
    62  }
    63  
    64  func (t *Arith) String(args *Args, reply *string) error {
    65  	*reply = fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
    66  	return nil
    67  }
    68  
    69  func (t *Arith) Scan(args string, reply *Reply) (err error) {
    70  	_, err = fmt.Sscan(args, &reply.C)
    71  	return
    72  }
    73  
    74  func (t *Arith) Error(args *Args, reply *Reply) error {
    75  	panic("ERROR")
    76  }
    77  
    78  type hidden int
    79  
    80  func (t *hidden) Exported(args Args, reply *Reply) error {
    81  	reply.C = args.A + args.B
    82  	return nil
    83  }
    84  
    85  type Embed struct {
    86  	hidden
    87  }
    88  
    89  type BuiltinTypes struct{}
    90  
    91  func (BuiltinTypes) Map(args *Args, reply *map[int]int) error {
    92  	(*reply)[args.A] = args.B
    93  	return nil
    94  }
    95  
    96  func (BuiltinTypes) Slice(args *Args, reply *[]int) error {
    97  	*reply = append(*reply, args.A, args.B)
    98  	return nil
    99  }
   100  
   101  func (BuiltinTypes) Array(args *Args, reply *[2]int) error {
   102  	(*reply)[0] = args.A
   103  	(*reply)[1] = args.B
   104  	return nil
   105  }
   106  
   107  func listenTCP() (net.Listener, string) {
   108  	l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
   109  	if e != nil {
   110  		log.Fatalf("net.Listen tcp :0: %v", e)
   111  	}
   112  	return l, l.Addr().String()
   113  }
   114  
   115  func startServer() {
   116  	Register(new(Arith))
   117  	Register(new(Embed))
   118  	RegisterName("net.rpc.Arith", new(Arith))
   119  	Register(BuiltinTypes{})
   120  
   121  	var l net.Listener
   122  	l, serverAddr = listenTCP()
   123  	log.Println("Test RPC server listening on", serverAddr)
   124  	go Accept(l)
   125  
   126  	HandleHTTP()
   127  	httpOnce.Do(startHttpServer)
   128  }
   129  
   130  func startNewServer() {
   131  	newServer = NewServer()
   132  	newServer.Register(new(Arith))
   133  	newServer.Register(new(Embed))
   134  	newServer.RegisterName("net.rpc.Arith", new(Arith))
   135  	newServer.RegisterName("newServer.Arith", new(Arith))
   136  
   137  	var l net.Listener
   138  	l, newServerAddr = listenTCP()
   139  	log.Println("NewServer test RPC server listening on", newServerAddr)
   140  	go newServer.Accept(l)
   141  
   142  	newServer.HandleHTTP(newHttpPath, "/bar")
   143  	httpOnce.Do(startHttpServer)
   144  }
   145  
   146  func startHttpServer() {
   147  	server := httptest.NewServer(nil)
   148  	httpServerAddr = server.Listener.Addr().String()
   149  	log.Println("Test HTTP RPC server listening on", httpServerAddr)
   150  }
   151  
   152  func TestRPC(t *testing.T) {
   153  	once.Do(startServer)
   154  	testRPC(t, serverAddr)
   155  	newOnce.Do(startNewServer)
   156  	testRPC(t, newServerAddr)
   157  	testNewServerRPC(t, newServerAddr)
   158  }
   159  
   160  func testRPC(t *testing.T, addr string) {
   161  	client, err := Dial("tcp", addr)
   162  	if err != nil {
   163  		t.Fatal("dialing", err)
   164  	}
   165  	defer client.Close()
   166  
   167  	// Synchronous calls
   168  	args := &Args{7, 8}
   169  	reply := new(Reply)
   170  	err = client.Call("Arith.Add", args, reply)
   171  	if err != nil {
   172  		t.Errorf("Add: expected no error but got string %q", err.Error())
   173  	}
   174  	if reply.C != args.A+args.B {
   175  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   176  	}
   177  
   178  	// Methods exported from unexported embedded structs
   179  	args = &Args{7, 0}
   180  	reply = new(Reply)
   181  	err = client.Call("Embed.Exported", args, reply)
   182  	if err != nil {
   183  		t.Errorf("Add: expected no error but got string %q", err.Error())
   184  	}
   185  	if reply.C != args.A+args.B {
   186  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   187  	}
   188  
   189  	// Nonexistent method
   190  	args = &Args{7, 0}
   191  	reply = new(Reply)
   192  	err = client.Call("Arith.BadOperation", args, reply)
   193  	// expect an error
   194  	if err == nil {
   195  		t.Error("BadOperation: expected error")
   196  	} else if !strings.HasPrefix(err.Error(), "rpc: can't find method ") {
   197  		t.Errorf("BadOperation: expected can't find method error; got %q", err)
   198  	}
   199  
   200  	// Unknown service
   201  	args = &Args{7, 8}
   202  	reply = new(Reply)
   203  	err = client.Call("Arith.Unknown", args, reply)
   204  	if err == nil {
   205  		t.Error("expected error calling unknown service")
   206  	} else if !strings.Contains(err.Error(), "method") {
   207  		t.Error("expected error about method; got", err)
   208  	}
   209  
   210  	// Out of order.
   211  	args = &Args{7, 8}
   212  	mulReply := new(Reply)
   213  	mulCall := client.Go("Arith.Mul", args, mulReply, nil)
   214  	addReply := new(Reply)
   215  	addCall := client.Go("Arith.Add", args, addReply, nil)
   216  
   217  	addCall = <-addCall.Done
   218  	if addCall.Error != nil {
   219  		t.Errorf("Add: expected no error but got string %q", addCall.Error.Error())
   220  	}
   221  	if addReply.C != args.A+args.B {
   222  		t.Errorf("Add: expected %d got %d", addReply.C, args.A+args.B)
   223  	}
   224  
   225  	mulCall = <-mulCall.Done
   226  	if mulCall.Error != nil {
   227  		t.Errorf("Mul: expected no error but got string %q", mulCall.Error.Error())
   228  	}
   229  	if mulReply.C != args.A*args.B {
   230  		t.Errorf("Mul: expected %d got %d", mulReply.C, args.A*args.B)
   231  	}
   232  
   233  	// Error test
   234  	args = &Args{7, 0}
   235  	reply = new(Reply)
   236  	err = client.Call("Arith.Div", args, reply)
   237  	// expect an error: zero divide
   238  	if err == nil {
   239  		t.Error("Div: expected error")
   240  	} else if err.Error() != "divide by zero" {
   241  		t.Error("Div: expected divide by zero error; got", err)
   242  	}
   243  
   244  	// Bad type.
   245  	reply = new(Reply)
   246  	err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
   247  	if err == nil {
   248  		t.Error("expected error calling Arith.Add with wrong arg type")
   249  	} else if !strings.Contains(err.Error(), "type") {
   250  		t.Error("expected error about type; got", err)
   251  	}
   252  
   253  	// Non-struct argument
   254  	const Val = 12345
   255  	str := fmt.Sprint(Val)
   256  	reply = new(Reply)
   257  	err = client.Call("Arith.Scan", &str, reply)
   258  	if err != nil {
   259  		t.Errorf("Scan: expected no error but got string %q", err.Error())
   260  	} else if reply.C != Val {
   261  		t.Errorf("Scan: expected %d got %d", Val, reply.C)
   262  	}
   263  
   264  	// Non-struct reply
   265  	args = &Args{27, 35}
   266  	str = ""
   267  	err = client.Call("Arith.String", args, &str)
   268  	if err != nil {
   269  		t.Errorf("String: expected no error but got string %q", err.Error())
   270  	}
   271  	expect := fmt.Sprintf("%d+%d=%d", args.A, args.B, args.A+args.B)
   272  	if str != expect {
   273  		t.Errorf("String: expected %s got %s", expect, str)
   274  	}
   275  
   276  	args = &Args{7, 8}
   277  	reply = new(Reply)
   278  	err = client.Call("Arith.Mul", args, reply)
   279  	if err != nil {
   280  		t.Errorf("Mul: expected no error but got string %q", err.Error())
   281  	}
   282  	if reply.C != args.A*args.B {
   283  		t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
   284  	}
   285  
   286  	// ServiceName contain "." character
   287  	args = &Args{7, 8}
   288  	reply = new(Reply)
   289  	err = client.Call("net.rpc.Arith.Add", args, reply)
   290  	if err != nil {
   291  		t.Errorf("Add: expected no error but got string %q", err.Error())
   292  	}
   293  	if reply.C != args.A+args.B {
   294  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   295  	}
   296  }
   297  
   298  func testNewServerRPC(t *testing.T, addr string) {
   299  	client, err := Dial("tcp", addr)
   300  	if err != nil {
   301  		t.Fatal("dialing", err)
   302  	}
   303  	defer client.Close()
   304  
   305  	// Synchronous calls
   306  	args := &Args{7, 8}
   307  	reply := new(Reply)
   308  	err = client.Call("newServer.Arith.Add", args, reply)
   309  	if err != nil {
   310  		t.Errorf("Add: expected no error but got string %q", err.Error())
   311  	}
   312  	if reply.C != args.A+args.B {
   313  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   314  	}
   315  }
   316  
   317  func TestHTTP(t *testing.T) {
   318  	once.Do(startServer)
   319  	testHTTPRPC(t, "")
   320  	newOnce.Do(startNewServer)
   321  	testHTTPRPC(t, newHttpPath)
   322  }
   323  
   324  func testHTTPRPC(t *testing.T, path string) {
   325  	var client *Client
   326  	var err error
   327  	if path == "" {
   328  		client, err = DialHTTP("tcp", httpServerAddr)
   329  	} else {
   330  		client, err = DialHTTPPath("tcp", httpServerAddr, path)
   331  	}
   332  	if err != nil {
   333  		t.Fatal("dialing", err)
   334  	}
   335  	defer client.Close()
   336  
   337  	// Synchronous calls
   338  	args := &Args{7, 8}
   339  	reply := new(Reply)
   340  	err = client.Call("Arith.Add", args, reply)
   341  	if err != nil {
   342  		t.Errorf("Add: expected no error but got string %q", err.Error())
   343  	}
   344  	if reply.C != args.A+args.B {
   345  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   346  	}
   347  }
   348  
   349  func TestBuiltinTypes(t *testing.T) {
   350  	once.Do(startServer)
   351  
   352  	client, err := DialHTTP("tcp", httpServerAddr)
   353  	if err != nil {
   354  		t.Fatal("dialing", err)
   355  	}
   356  	defer client.Close()
   357  
   358  	// Map
   359  	args := &Args{7, 8}
   360  	replyMap := map[int]int{}
   361  	err = client.Call("BuiltinTypes.Map", args, &replyMap)
   362  	if err != nil {
   363  		t.Errorf("Map: expected no error but got string %q", err.Error())
   364  	}
   365  	if replyMap[args.A] != args.B {
   366  		t.Errorf("Map: expected %d got %d", args.B, replyMap[args.A])
   367  	}
   368  
   369  	// Slice
   370  	args = &Args{7, 8}
   371  	replySlice := []int{}
   372  	err = client.Call("BuiltinTypes.Slice", args, &replySlice)
   373  	if err != nil {
   374  		t.Errorf("Slice: expected no error but got string %q", err.Error())
   375  	}
   376  	if e := []int{args.A, args.B}; !reflect.DeepEqual(replySlice, e) {
   377  		t.Errorf("Slice: expected %v got %v", e, replySlice)
   378  	}
   379  
   380  	// Array
   381  	args = &Args{7, 8}
   382  	replyArray := [2]int{}
   383  	err = client.Call("BuiltinTypes.Array", args, &replyArray)
   384  	if err != nil {
   385  		t.Errorf("Array: expected no error but got string %q", err.Error())
   386  	}
   387  	if e := [2]int{args.A, args.B}; !reflect.DeepEqual(replyArray, e) {
   388  		t.Errorf("Array: expected %v got %v", e, replyArray)
   389  	}
   390  }
   391  
   392  // CodecEmulator provides a client-like api and a ServerCodec interface.
   393  // Can be used to test ServeRequest.
   394  type CodecEmulator struct {
   395  	server        *Server
   396  	serviceMethod string
   397  	args          *Args
   398  	reply         *Reply
   399  	err           error
   400  }
   401  
   402  func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) error {
   403  	codec.serviceMethod = serviceMethod
   404  	codec.args = args
   405  	codec.reply = reply
   406  	codec.err = nil
   407  	var serverError error
   408  	if codec.server == nil {
   409  		serverError = ServeRequest(codec)
   410  	} else {
   411  		serverError = codec.server.ServeRequest(codec)
   412  	}
   413  	if codec.err == nil && serverError != nil {
   414  		codec.err = serverError
   415  	}
   416  	return codec.err
   417  }
   418  
   419  func (codec *CodecEmulator) ReadRequestHeader(req *Request) error {
   420  	req.ServiceMethod = codec.serviceMethod
   421  	req.Seq = 0
   422  	return nil
   423  }
   424  
   425  func (codec *CodecEmulator) ReadRequestBody(argv interface{}) error {
   426  	if codec.args == nil {
   427  		return io.ErrUnexpectedEOF
   428  	}
   429  	*(argv.(*Args)) = *codec.args
   430  	return nil
   431  }
   432  
   433  func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) error {
   434  	if resp.Error != "" {
   435  		codec.err = errors.New(resp.Error)
   436  	} else {
   437  		*codec.reply = *(reply.(*Reply))
   438  	}
   439  	return nil
   440  }
   441  
   442  func (codec *CodecEmulator) Close() error {
   443  	return nil
   444  }
   445  
   446  func TestServeRequest(t *testing.T) {
   447  	once.Do(startServer)
   448  	testServeRequest(t, nil)
   449  	newOnce.Do(startNewServer)
   450  	testServeRequest(t, newServer)
   451  }
   452  
   453  func testServeRequest(t *testing.T, server *Server) {
   454  	client := CodecEmulator{server: server}
   455  	defer client.Close()
   456  
   457  	args := &Args{7, 8}
   458  	reply := new(Reply)
   459  	err := client.Call("Arith.Add", args, reply)
   460  	if err != nil {
   461  		t.Errorf("Add: expected no error but got string %q", err.Error())
   462  	}
   463  	if reply.C != args.A+args.B {
   464  		t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   465  	}
   466  
   467  	err = client.Call("Arith.Add", nil, reply)
   468  	if err == nil {
   469  		t.Errorf("expected error calling Arith.Add with nil arg")
   470  	}
   471  }
   472  
   473  type ReplyNotPointer int
   474  type ArgNotPublic int
   475  type ReplyNotPublic int
   476  type NeedsPtrType int
   477  type local struct{}
   478  
   479  func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
   480  	return nil
   481  }
   482  
   483  func (t *ArgNotPublic) ArgNotPublic(args *local, reply *Reply) error {
   484  	return nil
   485  }
   486  
   487  func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
   488  	return nil
   489  }
   490  
   491  func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
   492  	return nil
   493  }
   494  
   495  // Check that registration handles lots of bad methods and a type with no suitable methods.
   496  func TestRegistrationError(t *testing.T) {
   497  	err := Register(new(ReplyNotPointer))
   498  	if err == nil {
   499  		t.Error("expected error registering ReplyNotPointer")
   500  	}
   501  	err = Register(new(ArgNotPublic))
   502  	if err == nil {
   503  		t.Error("expected error registering ArgNotPublic")
   504  	}
   505  	err = Register(new(ReplyNotPublic))
   506  	if err == nil {
   507  		t.Error("expected error registering ReplyNotPublic")
   508  	}
   509  	err = Register(NeedsPtrType(0))
   510  	if err == nil {
   511  		t.Error("expected error registering NeedsPtrType")
   512  	} else if !strings.Contains(err.Error(), "pointer") {
   513  		t.Error("expected hint when registering NeedsPtrType")
   514  	}
   515  }
   516  
   517  type WriteFailCodec int
   518  
   519  func (WriteFailCodec) WriteRequest(*Request, interface{}) error {
   520  	// the panic caused by this error used to not unlock a lock.
   521  	return errors.New("fail")
   522  }
   523  
   524  func (WriteFailCodec) ReadResponseHeader(*Response) error {
   525  	select {}
   526  }
   527  
   528  func (WriteFailCodec) ReadResponseBody(interface{}) error {
   529  	select {}
   530  }
   531  
   532  func (WriteFailCodec) Close() error {
   533  	return nil
   534  }
   535  
   536  func TestSendDeadlock(t *testing.T) {
   537  	client := NewClientWithCodec(WriteFailCodec(0))
   538  	defer client.Close()
   539  
   540  	done := make(chan bool)
   541  	go func() {
   542  		testSendDeadlock(client)
   543  		testSendDeadlock(client)
   544  		done <- true
   545  	}()
   546  	select {
   547  	case <-done:
   548  		return
   549  	case <-time.After(5 * time.Second):
   550  		t.Fatal("deadlock")
   551  	}
   552  }
   553  
   554  func testSendDeadlock(client *Client) {
   555  	defer func() {
   556  		recover()
   557  	}()
   558  	args := &Args{7, 8}
   559  	reply := new(Reply)
   560  	client.Call("Arith.Add", args, reply)
   561  }
   562  
   563  func dialDirect() (*Client, error) {
   564  	return Dial("tcp", serverAddr)
   565  }
   566  
   567  func dialHTTP() (*Client, error) {
   568  	return DialHTTP("tcp", httpServerAddr)
   569  }
   570  
   571  func countMallocs(dial func() (*Client, error), t *testing.T) float64 {
   572  	once.Do(startServer)
   573  	client, err := dial()
   574  	if err != nil {
   575  		t.Fatal("error dialing", err)
   576  	}
   577  	defer client.Close()
   578  
   579  	args := &Args{7, 8}
   580  	reply := new(Reply)
   581  	return testing.AllocsPerRun(100, func() {
   582  		err := client.Call("Arith.Add", args, reply)
   583  		if err != nil {
   584  			t.Errorf("Add: expected no error but got string %q", err.Error())
   585  		}
   586  		if reply.C != args.A+args.B {
   587  			t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
   588  		}
   589  	})
   590  }
   591  
   592  func TestCountMallocs(t *testing.T) {
   593  	if testing.Short() {
   594  		t.Skip("skipping malloc count in short mode")
   595  	}
   596  	if runtime.GOMAXPROCS(0) > 1 {
   597  		t.Skip("skipping; GOMAXPROCS>1")
   598  	}
   599  	fmt.Printf("mallocs per rpc round trip: %v\n", countMallocs(dialDirect, t))
   600  }
   601  
   602  func TestCountMallocsOverHTTP(t *testing.T) {
   603  	if testing.Short() {
   604  		t.Skip("skipping malloc count in short mode")
   605  	}
   606  	if runtime.GOMAXPROCS(0) > 1 {
   607  		t.Skip("skipping; GOMAXPROCS>1")
   608  	}
   609  	fmt.Printf("mallocs per HTTP rpc round trip: %v\n", countMallocs(dialHTTP, t))
   610  }
   611  
   612  type writeCrasher struct {
   613  	done chan bool
   614  }
   615  
   616  func (writeCrasher) Close() error {
   617  	return nil
   618  }
   619  
   620  func (w *writeCrasher) Read(p []byte) (int, error) {
   621  	<-w.done
   622  	return 0, io.EOF
   623  }
   624  
   625  func (writeCrasher) Write(p []byte) (int, error) {
   626  	return 0, errors.New("fake write failure")
   627  }
   628  
   629  func TestClientWriteError(t *testing.T) {
   630  	w := &writeCrasher{done: make(chan bool)}
   631  	c := NewClient(w)
   632  	defer c.Close()
   633  
   634  	res := false
   635  	err := c.Call("foo", 1, &res)
   636  	if err == nil {
   637  		t.Fatal("expected error")
   638  	}
   639  	if err.Error() != "fake write failure" {
   640  		t.Error("unexpected value of error:", err)
   641  	}
   642  	w.done <- true
   643  }
   644  
   645  func TestTCPClose(t *testing.T) {
   646  	once.Do(startServer)
   647  
   648  	client, err := dialHTTP()
   649  	if err != nil {
   650  		t.Fatalf("dialing: %v", err)
   651  	}
   652  	defer client.Close()
   653  
   654  	args := Args{17, 8}
   655  	var reply Reply
   656  	err = client.Call("Arith.Mul", args, &reply)
   657  	if err != nil {
   658  		t.Fatal("arith error:", err)
   659  	}
   660  	t.Logf("Arith: %d*%d=%d\n", args.A, args.B, reply)
   661  	if reply.C != args.A*args.B {
   662  		t.Errorf("Add: expected %d got %d", reply.C, args.A*args.B)
   663  	}
   664  }
   665  
   666  func TestErrorAfterClientClose(t *testing.T) {
   667  	once.Do(startServer)
   668  
   669  	client, err := dialHTTP()
   670  	if err != nil {
   671  		t.Fatalf("dialing: %v", err)
   672  	}
   673  	err = client.Close()
   674  	if err != nil {
   675  		t.Fatal("close error:", err)
   676  	}
   677  	err = client.Call("Arith.Add", &Args{7, 9}, new(Reply))
   678  	if err != ErrShutdown {
   679  		t.Errorf("Forever: expected ErrShutdown got %v", err)
   680  	}
   681  }
   682  
   683  // Tests the fix to issue 11221. Without the fix, this loops forever or crashes.
   684  func TestAcceptExitAfterListenerClose(t *testing.T) {
   685  	newServer := NewServer()
   686  	newServer.Register(new(Arith))
   687  	newServer.RegisterName("net.rpc.Arith", new(Arith))
   688  	newServer.RegisterName("newServer.Arith", new(Arith))
   689  
   690  	var l net.Listener
   691  	l, _ = listenTCP()
   692  	l.Close()
   693  	newServer.Accept(l)
   694  }
   695  
   696  func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
   697  	once.Do(startServer)
   698  	client, err := dial()
   699  	if err != nil {
   700  		b.Fatal("error dialing:", err)
   701  	}
   702  	defer client.Close()
   703  
   704  	// Synchronous calls
   705  	args := &Args{7, 8}
   706  	b.ResetTimer()
   707  
   708  	b.RunParallel(func(pb *testing.PB) {
   709  		reply := new(Reply)
   710  		for pb.Next() {
   711  			err := client.Call("Arith.Add", args, reply)
   712  			if err != nil {
   713  				b.Fatalf("rpc error: Add: expected no error but got string %q", err.Error())
   714  			}
   715  			if reply.C != args.A+args.B {
   716  				b.Fatalf("rpc error: Add: expected %d got %d", reply.C, args.A+args.B)
   717  			}
   718  		}
   719  	})
   720  }
   721  
   722  func benchmarkEndToEndAsync(dial func() (*Client, error), b *testing.B) {
   723  	if b.N == 0 {
   724  		return
   725  	}
   726  	const MaxConcurrentCalls = 100
   727  	once.Do(startServer)
   728  	client, err := dial()
   729  	if err != nil {
   730  		b.Fatal("error dialing:", err)
   731  	}
   732  	defer client.Close()
   733  
   734  	// Asynchronous calls
   735  	args := &Args{7, 8}
   736  	procs := 4 * runtime.GOMAXPROCS(-1)
   737  	send := int32(b.N)
   738  	recv := int32(b.N)
   739  	var wg sync.WaitGroup
   740  	wg.Add(procs)
   741  	gate := make(chan bool, MaxConcurrentCalls)
   742  	res := make(chan *Call, MaxConcurrentCalls)
   743  	b.ResetTimer()
   744  
   745  	for p := 0; p < procs; p++ {
   746  		go func() {
   747  			for atomic.AddInt32(&send, -1) >= 0 {
   748  				gate <- true
   749  				reply := new(Reply)
   750  				client.Go("Arith.Add", args, reply, res)
   751  			}
   752  		}()
   753  		go func() {
   754  			for call := range res {
   755  				A := call.Args.(*Args).A
   756  				B := call.Args.(*Args).B
   757  				C := call.Reply.(*Reply).C
   758  				if A+B != C {
   759  					b.Errorf("incorrect reply: Add: expected %d got %d", A+B, C)
   760  					return
   761  				}
   762  				<-gate
   763  				if atomic.AddInt32(&recv, -1) == 0 {
   764  					close(res)
   765  				}
   766  			}
   767  			wg.Done()
   768  		}()
   769  	}
   770  	wg.Wait()
   771  }
   772  
   773  func BenchmarkEndToEnd(b *testing.B) {
   774  	benchmarkEndToEnd(dialDirect, b)
   775  }
   776  
   777  func BenchmarkEndToEndHTTP(b *testing.B) {
   778  	benchmarkEndToEnd(dialHTTP, b)
   779  }
   780  
   781  func BenchmarkEndToEndAsync(b *testing.B) {
   782  	benchmarkEndToEndAsync(dialDirect, b)
   783  }
   784  
   785  func BenchmarkEndToEndAsyncHTTP(b *testing.B) {
   786  	benchmarkEndToEndAsync(dialHTTP, b)
   787  }