trpc.group/trpc-go/trpc-go@v1.0.3/transport/server_transport_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package transport_test
    15  
    16  import (
    17  	"context"
    18  	"encoding/binary"
    19  	"encoding/json"
    20  	"errors"
    21  	"fmt"
    22  	"net"
    23  	"runtime"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/stretchr/testify/assert"
    29  	"github.com/stretchr/testify/require"
    30  
    31  	_ "trpc.group/trpc-go/trpc-go"
    32  	"trpc.group/trpc-go/trpc-go/errs"
    33  	"trpc.group/trpc-go/trpc-go/transport"
    34  )
    35  
    36  func TestNewServerTransport(t *testing.T) {
    37  	st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute))
    38  	assert.NotNil(t, st)
    39  }
    40  
    41  func TestTCPListenAndServe(t *testing.T) {
    42  	var addr = getFreeAddr("tcp4")
    43  
    44  	// Wait until server transport is ready.
    45  	wg := sync.WaitGroup{}
    46  	wg.Add(1)
    47  	go func() {
    48  		defer wg.Done()
    49  		st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute))
    50  		err := st.ListenAndServe(context.Background(),
    51  			transport.WithListenNetwork("tcp4"),
    52  			transport.WithListenAddress(addr),
    53  			transport.WithHandler(&errorHandler{}),
    54  			transport.WithServerFramerBuilder(&framerBuilder{}),
    55  			transport.WithServiceName("test name"),
    56  		)
    57  
    58  		if err != nil {
    59  			t.Logf("ListenAndServe fail:%v", err)
    60  		}
    61  	}()
    62  	wg.Wait()
    63  
    64  	// Round trip.
    65  	req := &helloRequest{
    66  		Name: "trpc",
    67  		Msg:  "HelloWorld",
    68  	}
    69  
    70  	data, err := json.Marshal(req)
    71  	if err != nil {
    72  		t.Fatalf("json marshal fail:%v", err)
    73  	}
    74  	lenData := make([]byte, 4)
    75  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
    76  
    77  	reqData := append(lenData, data...)
    78  
    79  	ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond)
    80  	defer f()
    81  
    82  	_, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp4"),
    83  		transport.WithDialAddress(addr),
    84  		transport.WithClientFramerBuilder(&framerBuilder{}))
    85  	assert.NotNil(t, err)
    86  }
    87  
    88  func TestTCPTLSListenAndServe(t *testing.T) {
    89  	addr := getFreeAddr("tcp")
    90  
    91  	// Wait until server transport ready.
    92  	wg := &sync.WaitGroup{}
    93  	wg.Add(1)
    94  	go func() {
    95  		defer wg.Done()
    96  		st := transport.NewServerTransport()
    97  		err := st.ListenAndServe(context.Background(),
    98  			transport.WithListenNetwork("tcp"),
    99  			transport.WithListenAddress(addr),
   100  			transport.WithHandler(&echoHandler{}),
   101  			transport.WithServerFramerBuilder(&framerBuilder{}),
   102  			transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"),
   103  		)
   104  
   105  		if err != nil {
   106  			t.Logf("ListenAndServe fail:%v", err)
   107  		}
   108  	}()
   109  	wg.Wait()
   110  
   111  	// Round trip.
   112  	req := &helloRequest{
   113  		Name: "trpc",
   114  		Msg:  "HelloWorld",
   115  	}
   116  
   117  	data, err := json.Marshal(req)
   118  	if err != nil {
   119  		t.Fatalf("json marshal fail:%v", err)
   120  	}
   121  	lenData := make([]byte, 4)
   122  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
   123  
   124  	reqData := append(lenData, data...)
   125  
   126  	ctx, f := context.WithTimeout(context.Background(), 200*time.Millisecond)
   127  	defer f()
   128  
   129  	_, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp"),
   130  		transport.WithDialAddress(addr),
   131  		transport.WithClientFramerBuilder(&framerBuilder{}),
   132  		transport.WithDialTLS("../testdata/client.crt", "../testdata/client.key", "../testdata/ca.pem", "localhost"))
   133  	assert.Nil(t, err)
   134  
   135  	_, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp"),
   136  		transport.WithDialAddress(addr),
   137  		transport.WithClientFramerBuilder(&framerBuilder{}),
   138  		transport.WithDialTLS("../testdata/client.crt", "../testdata/client.key", "none", ""))
   139  	assert.Nil(t, err)
   140  }
   141  
   142  func TestHandleError(t *testing.T) {
   143  	var addr = getFreeAddr("udp4")
   144  
   145  	// Wait until server transport is ready.
   146  	wg := &sync.WaitGroup{}
   147  	wg.Add(1)
   148  	go func() {
   149  		defer wg.Done()
   150  		err := transport.ListenAndServe(
   151  			transport.WithListenNetwork("udp4"),
   152  			transport.WithListenAddress(addr),
   153  			transport.WithHandler(&errorHandler{}),
   154  			transport.WithServerFramerBuilder(&framerBuilder{}),
   155  		)
   156  
   157  		if err != nil {
   158  			t.Logf("test fail:%v", err)
   159  		}
   160  	}()
   161  	wg.Wait()
   162  
   163  	// Round trip.
   164  	req := &helloRequest{
   165  		Name: "trpc",
   166  		Msg:  "HelloWorld",
   167  	}
   168  
   169  	data, err := json.Marshal(req)
   170  	if err != nil {
   171  		t.Fatalf("test fail:%v", err)
   172  	}
   173  	lenData := make([]byte, 4)
   174  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
   175  
   176  	reqData := append(lenData, data...)
   177  
   178  	ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond)
   179  	defer f()
   180  	_, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("udp4"),
   181  		transport.WithDialAddress(addr),
   182  		transport.WithClientFramerBuilder(&framerBuilder{}))
   183  	assert.NotNil(t, err)
   184  }
   185  
   186  func TestNewServerTransport_NotSupport(t *testing.T) {
   187  	st := transport.NewServerTransport()
   188  	err := st.ListenAndServe(context.Background(), transport.WithListenNetwork("unix"))
   189  	assert.NotNil(t, err)
   190  
   191  	err = st.ListenAndServe(context.Background(), transport.WithListenNetwork("xxx"))
   192  	assert.NotNil(t, err)
   193  }
   194  
   195  func TestServerTransport_ListenAndServeUDP(t *testing.T) {
   196  	// NoReusePort
   197  	st := transport.NewServerTransport(transport.WithReusePort(false),
   198  		transport.WithKeepAlivePeriod(time.Minute))
   199  	err := st.ListenAndServe(
   200  		context.Background(),
   201  		transport.WithListenNetwork("udp"),
   202  		transport.WithServerFramerBuilder(&framerBuilder{}),
   203  	)
   204  	assert.Nil(t, err)
   205  
   206  	st = transport.NewServerTransport(transport.WithReusePort(true))
   207  	err = st.ListenAndServe(
   208  		context.Background(),
   209  		transport.WithListenNetwork("udp"),
   210  		transport.WithServerFramerBuilder(&framerBuilder{}),
   211  	)
   212  	assert.Nil(t, err)
   213  
   214  	st = transport.NewServerTransport(transport.WithReusePort(true))
   215  	err = st.ListenAndServe(
   216  		context.Background(),
   217  		transport.WithListenNetwork("ip"),
   218  		transport.WithServerFramerBuilder(&framerBuilder{}),
   219  	)
   220  	assert.NotNil(t, err)
   221  }
   222  
   223  func TestServerTransport_ListenAndServe(t *testing.T) {
   224  	// NoFramerBuilder
   225  	st := transport.NewServerTransport(transport.WithReusePort(false))
   226  	err := st.ListenAndServe(context.Background(), transport.WithListenNetwork("tcp"))
   227  	assert.NotNil(t, err)
   228  
   229  	fb := transport.GetFramerBuilder("trpc")
   230  	// NoReusePort
   231  	st = transport.NewServerTransport(transport.WithReusePort(false))
   232  	err = st.ListenAndServe(context.Background(),
   233  		transport.WithListenNetwork("tcp"),
   234  		transport.WithServerFramerBuilder(fb))
   235  	assert.Nil(t, err)
   236  
   237  	// ReusePort
   238  	st = transport.NewServerTransport(transport.WithReusePort(true))
   239  	err = st.ListenAndServe(context.Background(),
   240  		transport.WithListenNetwork("tcp"),
   241  		transport.WithServerFramerBuilder(fb))
   242  	assert.Nil(t, err)
   243  
   244  	// Listener
   245  	lis, err := net.Listen("tcp", getFreeAddr("tcp"))
   246  	assert.Nil(t, err)
   247  	st = transport.NewServerTransport()
   248  	err = st.ListenAndServe(context.Background(),
   249  		transport.WithListener(lis),
   250  		transport.WithServerFramerBuilder(fb))
   251  	assert.Nil(t, err)
   252  	lis.Close()
   253  
   254  	// ReusePort + Listen Error
   255  	st = transport.NewServerTransport(transport.WithReusePort(true))
   256  	err = st.ListenAndServe(context.Background(),
   257  		transport.WithListenNetwork("tcperror"),
   258  		transport.WithServerFramerBuilder(fb))
   259  	assert.NotNil(t, err)
   260  
   261  	// context cancel
   262  	ctx, cancel := context.WithCancel(context.Background())
   263  	cancel()
   264  	st = transport.NewServerTransport(transport.WithReusePort(true))
   265  	err = st.ListenAndServe(ctx, transport.WithListenNetwork("tcp"), transport.WithServerFramerBuilder(fb))
   266  	assert.Nil(t, err)
   267  }
   268  
   269  func TestServerTransport_ListenAndServeBothUDPAndTCP(t *testing.T) {
   270  	fb := transport.GetFramerBuilder("trpc")
   271  	// Empty network.
   272  	network := ""
   273  	st := transport.NewServerTransport()
   274  	err := st.ListenAndServe(context.Background(), transport.WithListenNetwork(network))
   275  	assert.EqualError(t, err, "server transport: not support network type "+network)
   276  
   277  	// Another unknown wrong input.
   278  	network = "wrong_type"
   279  	st = transport.NewServerTransport()
   280  	err = st.ListenAndServe(context.Background(), transport.WithListenNetwork(network))
   281  	assert.EqualError(t, err, "server transport: not support network type "+network)
   282  
   283  	// Right input.
   284  	network = "tcp,udp"
   285  	// No reuse.
   286  	st = transport.NewServerTransport(transport.WithReusePort(false))
   287  	err = st.ListenAndServe(context.Background(),
   288  		transport.WithListenNetwork(network),
   289  		transport.WithServerFramerBuilder(fb))
   290  	assert.Nil(t, err)
   291  }
   292  
   293  // TestTCPListenAndServeAsync tests asynchronous server process.
   294  func TestTCPListenAndServeAsync(t *testing.T) {
   295  	var addr = getFreeAddr("tcp4")
   296  
   297  	// Wait until server transport is ready.
   298  	wg := sync.WaitGroup{}
   299  	wg.Add(1)
   300  	go func() {
   301  		defer wg.Done()
   302  		st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute))
   303  		err := st.ListenAndServe(context.Background(),
   304  			transport.WithListenNetwork("tcp4"),
   305  			transport.WithListenAddress(addr),
   306  			transport.WithHandler(&errorHandler{}),
   307  			transport.WithServerFramerBuilder(&framerBuilder{}),
   308  			transport.WithServerAsync(true),
   309  			transport.WithWritev(true),
   310  		)
   311  
   312  		if err != nil {
   313  			t.Logf("ListenAndServe fail:%v", err)
   314  		}
   315  	}()
   316  	wg.Wait()
   317  
   318  	// round trip
   319  	req := &helloRequest{
   320  		Name: "trpc",
   321  		Msg:  "HelloWorld",
   322  	}
   323  
   324  	data, err := json.Marshal(req)
   325  	if err != nil {
   326  		t.Fatalf("json marshal fail:%v", err)
   327  	}
   328  	lenData := make([]byte, 4)
   329  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
   330  
   331  	reqData := append(lenData, data...)
   332  
   333  	ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond)
   334  	defer f()
   335  
   336  	_, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp4"),
   337  		transport.WithDialAddress(addr),
   338  		transport.WithClientFramerBuilder(&framerBuilder{}))
   339  	assert.NotNil(t, err)
   340  }
   341  
   342  // TestTCPListenAndServerRoutinePool tests serving with goroutine pool.
   343  func TestTCPListenAndServerRoutinePool(t *testing.T) {
   344  	var addr = getFreeAddr("tcp4")
   345  
   346  	// Wait until server transport is ready.
   347  	wg := sync.WaitGroup{}
   348  	wg.Add(1)
   349  	go func() {
   350  		defer wg.Done()
   351  		st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute))
   352  		err := st.ListenAndServe(context.Background(),
   353  			transport.WithListenNetwork("tcp4"),
   354  			transport.WithListenAddress(addr),
   355  			transport.WithHandler(&errorHandler{}),
   356  			transport.WithServerFramerBuilder(&framerBuilder{}),
   357  			transport.WithServerAsync(true),
   358  			transport.WithMaxRoutines(100),
   359  		)
   360  
   361  		if err != nil {
   362  			t.Logf("ListenAndServe fail:%v", err)
   363  		}
   364  	}()
   365  	wg.Wait()
   366  
   367  	// round trip
   368  	req := &helloRequest{
   369  		Name: "trpc",
   370  		Msg:  "HelloWorld",
   371  	}
   372  
   373  	data, err := json.Marshal(req)
   374  	if err != nil {
   375  		t.Fatalf("json marshal fail:%v", err)
   376  	}
   377  	lenData := make([]byte, 4)
   378  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
   379  
   380  	reqData := append(lenData, data...)
   381  
   382  	ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond)
   383  	defer f()
   384  
   385  	_, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp4"),
   386  		transport.WithDialAddress(addr),
   387  		transport.WithClientFramerBuilder(&framerBuilder{}))
   388  	assert.NotNil(t, err)
   389  }
   390  
   391  func TestWithReusePort(t *testing.T) {
   392  	opts := &transport.ServerTransportOptions{}
   393  	require.False(t, opts.ReusePort)
   394  
   395  	opt := transport.WithReusePort(true)
   396  	require.NotNil(t, opt)
   397  	opt(opts)
   398  	if runtime.GOOS != "windows" {
   399  		require.True(t, opts.ReusePort)
   400  	} else {
   401  		require.False(t, opts.ReusePort)
   402  	}
   403  
   404  	opt = transport.WithReusePort(false)
   405  	require.NotNil(t, opt)
   406  	opt(opts)
   407  	require.False(t, opts.ReusePort)
   408  }
   409  
   410  func TestWithRecvMsgChannelSize(t *testing.T) {
   411  	opt := transport.WithRecvMsgChannelSize(1000)
   412  	assert.NotNil(t, opt)
   413  	opts := &transport.ServerTransportOptions{}
   414  	opt(opts)
   415  	assert.Equal(t, 1000, opts.RecvMsgChannelSize)
   416  }
   417  
   418  func TestWithSendMsgChannelSize(t *testing.T) {
   419  	opt := transport.WithSendMsgChannelSize(1000)
   420  	assert.NotNil(t, opt)
   421  	opts := &transport.ServerTransportOptions{}
   422  	opt(opts)
   423  	assert.Equal(t, 1000, opts.SendMsgChannelSize)
   424  }
   425  
   426  func TestWithRecvUDPPacketBufferSize(t *testing.T) {
   427  	opt := transport.WithRecvUDPPacketBufferSize(1000)
   428  	assert.NotNil(t, opt)
   429  	opts := &transport.ServerTransportOptions{}
   430  	opt(opts)
   431  	assert.Equal(t, 1000, opts.RecvUDPPacketBufferSize)
   432  }
   433  
   434  func TestWithRecvUDPRawSocketBufSize(t *testing.T) {
   435  	opt := transport.WithRecvUDPRawSocketBufSize(1000)
   436  	assert.NotNil(t, opt)
   437  	opts := &transport.ServerTransportOptions{}
   438  	opt(opts)
   439  	assert.Equal(t, 1000, opts.RecvUDPRawSocketBufSize)
   440  }
   441  
   442  func TestWithIdleTimeout(t *testing.T) {
   443  	opt := transport.WithIdleTimeout(time.Second)
   444  	assert.NotNil(t, opt)
   445  	opts := &transport.ServerTransportOptions{}
   446  	opt(opts)
   447  	assert.Equal(t, time.Second, opts.IdleTimeout)
   448  }
   449  
   450  func TestWithKeepAlivePeriod(t *testing.T) {
   451  	opt := transport.WithKeepAlivePeriod(time.Minute)
   452  	assert.NotNil(t, opt)
   453  	opts := &transport.ServerTransportOptions{}
   454  	opt(opts)
   455  	assert.Equal(t, time.Minute, opts.KeepAlivePeriod)
   456  }
   457  
   458  func TestWithServeTLS(t *testing.T) {
   459  	opt := transport.WithServeTLS("certfile", "keyfile", "")
   460  	assert.NotNil(t, opt)
   461  	opts := &transport.ListenServeOptions{}
   462  	opt(opts)
   463  	assert.Equal(t, "certfile", opts.TLSCertFile)
   464  	assert.Equal(t, "keyfile", opts.TLSKeyFile)
   465  }
   466  
   467  // TestWithServeAsync tests setting server async.
   468  func TestWithServeAsync(t *testing.T) {
   469  	opt := transport.WithServerAsync(true)
   470  	assert.NotNil(t, opt)
   471  	opts := &transport.ListenServeOptions{}
   472  	opt(opts)
   473  	assert.Equal(t, true, opts.ServerAsync)
   474  }
   475  
   476  // TestWithWritev tests setting writev.
   477  func TestWithWritev(t *testing.T) {
   478  	opt := transport.WithWritev(true)
   479  	assert.NotNil(t, opt)
   480  	opts := &transport.ListenServeOptions{}
   481  	opt(opts)
   482  	assert.Equal(t, true, opts.Writev)
   483  }
   484  
   485  // TestWithMaxRoutine tests setting max number of goroutines.
   486  func TestWithMaxRoutine(t *testing.T) {
   487  	opt := transport.WithMaxRoutines(100)
   488  	assert.NotNil(t, opt)
   489  	opts := &transport.ListenServeOptions{}
   490  	opt(opts)
   491  	assert.Equal(t, 100, opts.Routines)
   492  }
   493  
   494  // TestTCPServerClosed tests if TCP listener can be closed immediately.
   495  func TestTCPListenerClosed(t *testing.T) {
   496  	err := tryCloseTCPListener(false)
   497  	if err != nil {
   498  		t.Errorf("close tcp listener err: %v", err)
   499  	}
   500  }
   501  
   502  // TestTCPListenerClosed_WithReuseport tests if TCP listener can be closed immediately.
   503  func TestTCPListenerClosed_WithReuseport(t *testing.T) {
   504  	err := tryCloseTCPListener(true)
   505  	if err != nil {
   506  		t.Errorf("close tcp listener (with reuseport) err: %v", err)
   507  	}
   508  }
   509  
   510  func tryCloseTCPListener(reuseport bool) error {
   511  	port, err := getFreePort("tcp")
   512  	if err != nil {
   513  		return fmt.Errorf("get freeport error: %v", err)
   514  	}
   515  
   516  	ctx := context.Background()
   517  	ctx, cancel := context.WithCancel(ctx)
   518  
   519  	var prepareErr error
   520  	wg := sync.WaitGroup{}
   521  	wg.Add(1)
   522  	go func() {
   523  		defer wg.Done()
   524  		st := transport.NewServerTransport(transport.WithReusePort(reuseport))
   525  		err := st.ListenAndServe(ctx,
   526  			transport.WithListenNetwork("tcp"),
   527  			transport.WithListenAddress(fmt.Sprintf(":%d", port)),
   528  			transport.WithHandler(&echoHandler{}),
   529  			transport.WithServerFramerBuilder(&framerBuilder{}),
   530  		)
   531  		if err != nil {
   532  			prepareErr = err
   533  		}
   534  	}()
   535  	wg.Wait()
   536  
   537  	if prepareErr != nil {
   538  		cancel()
   539  		return fmt.Errorf("prepare listener error: %v", prepareErr)
   540  	}
   541  
   542  	// First time dial, should work.
   543  	conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port))
   544  	if err != nil {
   545  		cancel()
   546  		return fmt.Errorf("tcp dial error: %v", err)
   547  	}
   548  	conn.Close()
   549  
   550  	// Notify and wait server close.
   551  	cancel()
   552  	time.Sleep(5 * time.Millisecond)
   553  
   554  	// Second time dial, must fail.
   555  	_, err = net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), 10*time.Millisecond)
   556  	if err == nil {
   557  		return fmt.Errorf("tcp dial (2nd time) want error")
   558  	}
   559  	return nil
   560  }
   561  
   562  func TestGetListenersFds(t *testing.T) {
   563  	ListenFds := transport.GetListenersFds()
   564  	assert.NotNil(t, ListenFds)
   565  }
   566  
   567  var savedListenerPort int
   568  
   569  func TestSaveListener(t *testing.T) {
   570  	port, err := getFreePort("tcp")
   571  	if err != nil {
   572  		t.Fatalf("get freeport error: %v", err)
   573  	}
   574  	err = transport.SaveListener(NewPacketConn{})
   575  	assert.NotNil(t, err)
   576  
   577  	newListener, _ := net.Listen("tcp", fmt.Sprintf(":%d", port))
   578  	err = transport.SaveListener(newListener)
   579  	assert.Nil(t, err)
   580  	savedListenerPort = port
   581  }
   582  
   583  func TestTCPSeverErr(t *testing.T) {
   584  	st := transport.NewServerTransport()
   585  	err := st.ListenAndServe(context.Background(),
   586  		transport.WithListenNetwork("tcp"),
   587  		transport.WithListenAddress(getFreeAddr("tcp")),
   588  		transport.WithHandler(&echoHandler{}),
   589  		transport.WithServerFramerBuilder(&framerBuilder{}))
   590  	assert.Nil(t, err)
   591  }
   592  
   593  func TestUDPServerErr(t *testing.T) {
   594  	st := transport.NewServerTransport()
   595  
   596  	err := st.ListenAndServe(context.Background(),
   597  		transport.WithListenNetwork("udp"),
   598  		transport.WithListenAddress(getFreeAddr("udp")),
   599  		transport.WithHandler(&echoHandler{}),
   600  		transport.WithServerFramerBuilder(&framerBuilder{}))
   601  	assert.Nil(t, err)
   602  }
   603  
   604  type fakeListen struct {
   605  }
   606  
   607  func (c *fakeListen) Accept() (net.Conn, error) {
   608  	return nil, &netError{errors.New("网络失败")}
   609  }
   610  func (c *fakeListen) Close() error {
   611  	return nil
   612  }
   613  
   614  func (c *fakeListen) Addr() net.Addr {
   615  	return nil
   616  }
   617  
   618  func TestTCPServerConErr(t *testing.T) {
   619  	go func() {
   620  		fb := transport.GetFramerBuilder("trpc")
   621  		st := transport.NewServerTransport()
   622  		err := st.ListenAndServe(context.Background(),
   623  			transport.WithListener(&fakeListen{}),
   624  			transport.WithServerFramerBuilder(fb))
   625  		if err != nil {
   626  			t.Logf("ListenAndServe fail:%v", err)
   627  		}
   628  	}()
   629  }
   630  
   631  func TestUDPServerConErr(t *testing.T) {
   632  	fb := transport.GetFramerBuilder("trpc")
   633  	st := transport.NewServerTransport()
   634  	err := st.ListenAndServe(context.Background(),
   635  		transport.WithListenNetwork("udp"),
   636  		transport.WithListenAddress(getFreeAddr("udp")),
   637  		transport.WithServerFramerBuilder(fb))
   638  	if err != nil {
   639  		t.Fatalf("ListenAndServe fail:%v", err)
   640  	}
   641  }
   642  
   643  func getFreePort(network string) (int, error) {
   644  	if network == "tcp" || network == "tcp4" || network == "tcp6" {
   645  		addr, err := net.ResolveTCPAddr(network, "localhost:0")
   646  		if err != nil {
   647  			return -1, err
   648  		}
   649  
   650  		l, err := net.ListenTCP(network, addr)
   651  		if err != nil {
   652  			return -1, err
   653  		}
   654  		defer l.Close()
   655  
   656  		return l.Addr().(*net.TCPAddr).Port, nil
   657  	}
   658  
   659  	if network == "udp" || network == "udp4" || network == "udp6" {
   660  		addr, err := net.ResolveUDPAddr(network, "localhost:0")
   661  		if err != nil {
   662  			return -1, err
   663  		}
   664  
   665  		l, err := net.ListenUDP(network, addr)
   666  		if err != nil {
   667  			return -1, err
   668  		}
   669  		defer l.Close()
   670  
   671  		return l.LocalAddr().(*net.UDPAddr).Port, nil
   672  	}
   673  
   674  	return -1, errors.New("invalid network")
   675  }
   676  
   677  func TestGetFreePort(t *testing.T) {
   678  	for i := 0; i < 10; i++ {
   679  		p, err := getFreePort("tcp")
   680  		assert.Nil(t, err)
   681  		assert.NotEqual(t, p, -1)
   682  		t.Logf("get freeport network:%s, port:%d", "tcp", p)
   683  	}
   684  
   685  	for i := 0; i < 10; i++ {
   686  		p, err := getFreePort("udp")
   687  		assert.Nil(t, err)
   688  		assert.NotEqual(t, p, -1)
   689  		t.Logf("get freeport network:%s, port:%d", "udp", p)
   690  	}
   691  
   692  	p1, err := getFreePort("tcp")
   693  	assert.Nil(t, err)
   694  
   695  	p2, err := getFreePort("tcp")
   696  	assert.Nil(t, err)
   697  	assert.NotEqual(t, p1, p2, "allocated 2 conflict ports")
   698  }
   699  
   700  func getFreeAddr(network string) string {
   701  	p, err := getFreePort(network)
   702  	if err != nil {
   703  		panic(err)
   704  	}
   705  
   706  	return fmt.Sprintf(":%d", p)
   707  }
   708  
   709  func TestTCPWriteToClosedConn(t *testing.T) {
   710  	l, err := net.Listen("tcp4", "localhost:0")
   711  	require.Nil(t, err)
   712  	defer l.Close()
   713  
   714  	var wg sync.WaitGroup
   715  	wg.Add(1)
   716  	go func() {
   717  		defer wg.Done()
   718  		st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute))
   719  		err := st.ListenAndServe(context.Background(),
   720  			transport.WithListener(l),
   721  			transport.WithHandler(&echoHandler{}),
   722  			transport.WithServerFramerBuilder(&framerBuilder{}),
   723  			transport.WithServerAsync(true),
   724  		)
   725  		assert.Nil(t, err)
   726  	}()
   727  	wg.Wait()
   728  	conn, err := net.Dial("tcp4", l.Addr().String())
   729  	require.Nil(t, err)
   730  	require.Nil(t, conn.Close())
   731  	_, err = conn.Write([]byte("data"))
   732  	require.Contains(t, errs.Msg(err), "use of closed network connection")
   733  }
   734  
   735  func TestTCPServerHandleErrAndClose(t *testing.T) {
   736  	var addr = getFreeAddr("tcp4")
   737  
   738  	wg := sync.WaitGroup{}
   739  	wg.Add(1)
   740  	go func() {
   741  		defer wg.Done()
   742  		st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute))
   743  		err := st.ListenAndServe(context.Background(),
   744  			transport.WithListenNetwork("tcp4"),
   745  			transport.WithListenAddress(addr),
   746  			transport.WithHandler(&errorHandler{}),
   747  			transport.WithServerFramerBuilder(&framerBuilder{}),
   748  			transport.WithServerAsync(true),
   749  		)
   750  		assert.Nil(t, err)
   751  	}()
   752  	wg.Wait()
   753  
   754  	// First time dial, should work.
   755  	conn, err := net.Dial("tcp", addr)
   756  	assert.Nil(t, err)
   757  	time.Sleep(time.Millisecond * 5)
   758  	data := []byte("hello world")
   759  	req := make([]byte, 4)
   760  	binary.BigEndian.PutUint32(req, uint32(len(data)))
   761  	req = append(req, data...)
   762  	_, err = conn.Write(req)
   763  	assert.Nil(t, err)
   764  
   765  	// Check the connection is closed by server.
   766  	time.Sleep(time.Millisecond * 5)
   767  	out := make([]byte, 8)
   768  	_, err = conn.Read(out)
   769  	assert.NotNil(t, err)
   770  }
   771  
   772  // TestTCPListenAndServeWithSafeFramer tests that we support safe framer without copying packages.
   773  func TestUDPListenAndServeWithSafeFramer(t *testing.T) {
   774  	var addr = getFreeAddr("udp")
   775  
   776  	// Wait until server transport is ready.
   777  	wg := sync.WaitGroup{}
   778  	wg.Add(1)
   779  	go func() {
   780  		defer wg.Done()
   781  		err := transport.ListenAndServe(
   782  			transport.WithListenNetwork("udp"),
   783  			transport.WithListenAddress(addr),
   784  			transport.WithHandler(&echoHandler{}),
   785  			transport.WithServerFramerBuilder(&framerBuilder{safe: true}),
   786  		)
   787  		assert.Nil(t, err)
   788  		time.Sleep(20 * time.Millisecond)
   789  	}()
   790  	wg.Wait()
   791  
   792  	req := &helloRequest{
   793  		Name: "trpc",
   794  		Msg:  "HelloWorld",
   795  	}
   796  	data, err := json.Marshal(req)
   797  	if err != nil {
   798  		t.Fatalf("json marshal fail:%v", err)
   799  	}
   800  	lenData := make([]byte, 4)
   801  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
   802  	reqData := append(lenData, data...)
   803  	ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond)
   804  	defer f()
   805  
   806  	rspData, err := transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("udp"),
   807  		transport.WithDialAddress(addr),
   808  		transport.WithClientFramerBuilder(&framerBuilder{safe: true}))
   809  	assert.Nil(t, err)
   810  
   811  	length := binary.BigEndian.Uint32(rspData[:4])
   812  	helloRsp := &helloResponse{}
   813  	err = json.Unmarshal(rspData[4:4+length], helloRsp)
   814  	assert.Nil(t, err)
   815  	assert.Equal(t, helloRsp.Msg, "HelloWorld")
   816  }
   817  
   818  // TestTCPListenAndServeWithSafeFramer tests that frame is not copied when Framer is already safe.
   819  func TestTCPListenAndServeWithSafeFramer(t *testing.T) {
   820  	var addr = getFreeAddr("tcp4")
   821  
   822  	wg := sync.WaitGroup{}
   823  	wg.Add(1)
   824  	go func() {
   825  		defer wg.Done()
   826  		st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute))
   827  		err := st.ListenAndServe(context.Background(),
   828  			transport.WithListenNetwork("tcp4"),
   829  			transport.WithListenAddress(addr),
   830  			transport.WithHandler(&echoHandler{}),
   831  			transport.WithServerFramerBuilder(&framerBuilder{safe: true}),
   832  			transport.WithServerAsync(true),
   833  		)
   834  		assert.Nil(t, err)
   835  		time.Sleep(20 * time.Millisecond)
   836  	}()
   837  	wg.Wait()
   838  
   839  	req := &helloRequest{
   840  		Name: "trpc",
   841  		Msg:  "HelloWorld",
   842  	}
   843  	data, err := json.Marshal(req)
   844  	if err != nil {
   845  		t.Fatalf("json marshal fail:%v", err)
   846  	}
   847  	lenData := make([]byte, 4)
   848  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
   849  	reqData := append(lenData, data...)
   850  	ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond)
   851  	defer f()
   852  
   853  	rspData, err := transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp4"),
   854  		transport.WithDialAddress(addr),
   855  		transport.WithClientFramerBuilder(&framerBuilder{safe: true}))
   856  	assert.Nil(t, err)
   857  
   858  	length := binary.BigEndian.Uint32(rspData[:4])
   859  	helloRsp := &helloResponse{}
   860  	err = json.Unmarshal(rspData[4:4+length], helloRsp)
   861  	assert.Nil(t, err)
   862  	assert.Equal(t, helloRsp.Msg, "HelloWorld")
   863  }
   864  
   865  func TestWithDisableKeepAlives(t *testing.T) {
   866  	disable := true
   867  	o := transport.WithDisableKeepAlives(true)
   868  	opts := &transport.ListenServeOptions{}
   869  	o(opts)
   870  	assert.Equal(t, disable, opts.DisableKeepAlives)
   871  }
   872  
   873  func TestWithServerIdleTimeout(t *testing.T) {
   874  	idleTimeout := time.Second
   875  	o := transport.WithServerIdleTimeout(idleTimeout)
   876  	opts := &transport.ListenServeOptions{}
   877  	o(opts)
   878  	assert.Equal(t, opts.IdleTimeout, idleTimeout)
   879  }
   880  
   881  func TestUDPServeClose(t *testing.T) {
   882  	ts := transport.NewServerTransport()
   883  	ctx, cancel := context.WithCancel(context.Background())
   884  	cancel()
   885  	err := ts.ListenAndServe(
   886  		ctx,
   887  		transport.WithListenNetwork("udp"),
   888  		transport.WithListenAddress(getFreeAddr("udp")),
   889  		transport.WithHandler(&echoHandler{}),
   890  		transport.WithServerFramerBuilder(&framerBuilder{safe: true}),
   891  		transport.WithServerAsync(true),
   892  	)
   893  	assert.Nil(t, err)
   894  	time.Sleep(100 * time.Millisecond)
   895  }
   896  
   897  type MockUDPError struct{}
   898  
   899  func (e MockUDPError) Error() string   { return "mock udp error" }
   900  func (e MockUDPError) Timeout() bool   { return false }
   901  func (e MockUDPError) Temporary() bool { return true }
   902  
   903  func TestUDPReadError(t *testing.T) {
   904  	addr := getFreeAddr("udp")
   905  
   906  	err := transport.ListenAndServe(
   907  		transport.WithListenNetwork("udp"),
   908  		transport.WithListenAddress(addr),
   909  		transport.WithHandler(&echoHandler{}),
   910  		transport.WithServerFramerBuilder(&framerBuilder{safe: true}),
   911  		transport.WithServerAsync(false),
   912  	)
   913  	assert.Nil(t, err)
   914  	time.Sleep(60 * time.Millisecond)
   915  }
   916  
   917  func TestUDPWriteError(t *testing.T) {
   918  	addr := getFreeAddr("udp")
   919  
   920  	err := transport.ListenAndServe(
   921  		transport.WithListenNetwork("udp"),
   922  		transport.WithListenAddress(addr),
   923  		transport.WithHandler(&echoHandler{}),
   924  		transport.WithServerFramerBuilder(&framerBuilder{safe: true}),
   925  		transport.WithServerAsync(false),
   926  	)
   927  	assert.Nil(t, err)
   928  	time.Sleep(20 * time.Millisecond)
   929  
   930  	req := &helloRequest{
   931  		Name: "trpc",
   932  		Msg:  "HelloWorld",
   933  	}
   934  	data, err := json.Marshal(req)
   935  	if err != nil {
   936  		t.Fatalf("json marshal fail:%v", err)
   937  	}
   938  	lenData := make([]byte, 4)
   939  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
   940  	reqData := append(lenData, data...)
   941  	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
   942  	defer cancel()
   943  	_, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("udp"),
   944  		transport.WithDialAddress(addr),
   945  		transport.WithClientFramerBuilder(&framerBuilder{safe: true}))
   946  	assert.Nil(t, err)
   947  }
   948  
   949  func TestPoolInvokeFail(t *testing.T) {
   950  
   951  	addr := getFreeAddr("udp")
   952  
   953  	err := transport.ListenAndServe(
   954  		transport.WithListenNetwork("udp"),
   955  		transport.WithListenAddress(addr),
   956  		transport.WithHandler(&echoHandler{}),
   957  		transport.WithServerFramerBuilder(&framerBuilder{safe: true}),
   958  		transport.WithServerAsync(true),
   959  	)
   960  	assert.Nil(t, err)
   961  	time.Sleep(20 * time.Millisecond)
   962  
   963  	req := &helloRequest{
   964  		Name: "trpc",
   965  		Msg:  "HelloWorld",
   966  	}
   967  	data, err := json.Marshal(req)
   968  	if err != nil {
   969  		t.Fatalf("json marshal fail:%v", err)
   970  	}
   971  	lenData := make([]byte, 4)
   972  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
   973  	reqData := append(lenData, data...)
   974  	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
   975  	defer cancel()
   976  	_, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("udp"),
   977  		transport.WithDialAddress(addr),
   978  		transport.WithClientFramerBuilder(&framerBuilder{safe: true}))
   979  	assert.Nil(t, err)
   980  }
   981  
   982  func TestCreatePoolFail(t *testing.T) {
   983  	addr := getFreeAddr("udp")
   984  
   985  	err := transport.ListenAndServe(
   986  		transport.WithListenNetwork("udp"),
   987  		transport.WithListenAddress(addr),
   988  		transport.WithHandler(&echoHandler{}),
   989  		transport.WithServerFramerBuilder(&framerBuilder{safe: true}),
   990  		transport.WithServerAsync(true),
   991  	)
   992  	assert.Nil(t, err)
   993  
   994  	req := &helloRequest{
   995  		Name: "trpc",
   996  		Msg:  "HelloWorld",
   997  	}
   998  	data, err := json.Marshal(req)
   999  	if err != nil {
  1000  		t.Fatalf("json marshal fail:%v", err)
  1001  	}
  1002  	lenData := make([]byte, 4)
  1003  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
  1004  	reqData := append(lenData, data...)
  1005  	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
  1006  	defer cancel()
  1007  	_, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("udp"),
  1008  		transport.WithDialAddress(addr),
  1009  		transport.WithClientFramerBuilder(&framerBuilder{safe: true}))
  1010  	assert.Nil(t, err)
  1011  }
  1012  
  1013  func TestListenAndServeTLSFail(t *testing.T) {
  1014  	s := transport.NewServerTransport()
  1015  	ctx, cancel := context.WithCancel(context.Background())
  1016  	defer cancel()
  1017  	ln, err := net.Listen("tcp", "127.0.0.1:0")
  1018  	require.Nil(t, err)
  1019  	defer ln.Close()
  1020  	require.NotNil(t, s.ListenAndServe(ctx,
  1021  		transport.WithListenNetwork("tcp"),
  1022  		transport.WithServeTLS("fakeCertFileName", "fakeKeyFileName", "fakeCAFileName"),
  1023  		transport.WithServerFramerBuilder(&framerBuilder{}),
  1024  		transport.WithListener(ln),
  1025  	))
  1026  }
  1027  
  1028  func TestListenAndServeWithStopListener(t *testing.T) {
  1029  	s := transport.NewServerTransport()
  1030  	ctx, cancel := context.WithCancel(context.Background())
  1031  	defer cancel()
  1032  	ln, err := net.Listen("tcp", "127.0.0.1:0")
  1033  	require.Nil(t, err)
  1034  	ch := make(chan struct{})
  1035  	require.Nil(t, s.ListenAndServe(ctx,
  1036  		transport.WithListenNetwork("tcp"),
  1037  		transport.WithServerFramerBuilder(&framerBuilder{}),
  1038  		transport.WithListener(ln),
  1039  		transport.WithStopListening(ch),
  1040  	))
  1041  	_, err = net.Dial("tcp", ln.Addr().String())
  1042  	require.Nil(t, err)
  1043  	close(ch)
  1044  	time.Sleep(time.Millisecond)
  1045  	_, err = net.Dial("tcp", ln.Addr().String())
  1046  	require.NotNil(t, err)
  1047  }