trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/server_transport_tcp_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || freebsd || dragonfly || darwin
    15  // +build linux freebsd dragonfly darwin
    16  
    17  package tnet_test
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"net"
    25  	"os"
    26  	"strconv"
    27  	"sync/atomic"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/stretchr/testify/assert"
    32  	"trpc.group/trpc-go/tnet"
    33  
    34  	trpc "trpc.group/trpc-go/trpc-go"
    35  	"trpc.group/trpc-go/trpc-go/codec"
    36  	"trpc.group/trpc-go/trpc-go/transport"
    37  	tnettrans "trpc.group/trpc-go/trpc-go/transport/tnet"
    38  )
    39  
    40  var (
    41  	port       uint64 = 9000
    42  	helloWorld        = []byte("helloworld")
    43  )
    44  
    45  func TestServerTCP_ListenAndServe(t *testing.T) {
    46  	startServerTest(
    47  		t,
    48  		defaultServerHandle,
    49  		nil,
    50  		func(addr string) {
    51  			rsp, err := gonetRequest(context.Background(), transport.WithDialAddress(addr))
    52  			assert.Nil(t, err)
    53  			assert.Equal(t, helloWorld, rsp)
    54  		},
    55  	)
    56  }
    57  
    58  func TestServerTCP_Asyn(t *testing.T) {
    59  	startServerTest(
    60  		t,
    61  		defaultServerHandle,
    62  		[]transport.ListenServeOption{transport.WithServerAsync(true)},
    63  		func(addr string) {
    64  			rsp, err := gonetRequest(context.Background(), transport.WithDialAddress(addr))
    65  			assert.Nil(t, err)
    66  			assert.Equal(t, helloWorld, rsp)
    67  		},
    68  	)
    69  }
    70  
    71  func TestServerTCP_CustomizedFramerCopyFrame(t *testing.T) {
    72  	startServerTest(
    73  		t,
    74  		func(ctx context.Context, req []byte) ([]byte, error) {
    75  			return req, nil
    76  		},
    77  		[]transport.ListenServeOption{
    78  			transport.WithServerFramerBuilder(&reuseBufferFramerBuilder{}),
    79  			transport.WithServerAsync(true),
    80  		},
    81  		func(addr string) {
    82  			req := helloWorld
    83  			ctx, _ := codec.EnsureMessage(context.Background())
    84  			reqbytes, err := (&emptyClientCodec{}).Encode(
    85  				codec.Message(ctx),
    86  				req,
    87  			)
    88  			assert.Nil(t, err)
    89  
    90  			cliOpts := []transport.RoundTripOption{
    91  				transport.WithDialAddress(addr),
    92  				transport.WithDialNetwork("tcp"),
    93  				transport.WithClientFramerBuilder(&reuseBufferFramerBuilder{}),
    94  				transport.WithDialTimeout(5 * time.Second),
    95  			}
    96  			clientTrans := transport.NewClientTransport()
    97  			rspbytes, err := clientTrans.RoundTrip(
    98  				ctx,
    99  				reqbytes,
   100  				cliOpts...,
   101  			)
   102  			assert.Nil(t, err)
   103  
   104  			rsp, err := (&emptyClientCodec{}).Decode(
   105  				codec.Message(ctx),
   106  				rspbytes,
   107  			)
   108  			assert.Nil(t, err)
   109  			assert.Equal(t, helloWorld, rsp)
   110  		},
   111  	)
   112  }
   113  
   114  func TestServerTCP_UserDefineListener(t *testing.T) {
   115  	serverAddr := getAddr()
   116  	ln, err := tnet.Listen("tcp", serverAddr)
   117  	assert.Nil(t, err)
   118  	startServerTest(
   119  		t,
   120  		defaultServerHandle,
   121  		[]transport.ListenServeOption{transport.WithListener(ln)},
   122  		func(_ string) {
   123  			rsp, err := gonetRequest(context.Background(), transport.WithDialAddress(serverAddr))
   124  			assert.Nil(t, err)
   125  			assert.Equal(t, helloWorld, rsp)
   126  		},
   127  	)
   128  }
   129  
   130  func TestServerTCP_ErrorCases(t *testing.T) {
   131  	s := tnettrans.NewServerTransport()
   132  
   133  	// Without framerBuilder
   134  	serveOpts := getListenServeOption(
   135  		transport.WithServerFramerBuilder(nil),
   136  	)
   137  	err := s.ListenAndServe(context.Background(), serveOpts...)
   138  	assert.NotNil(t, err)
   139  
   140  	// Unsupported network type
   141  	serveOpts = getListenServeOption(
   142  		transport.WithListenNetwork("ip"),
   143  	)
   144  	err = s.ListenAndServe(context.Background(), serveOpts...)
   145  	assert.NotNil(t, err)
   146  }
   147  
   148  func TestServerTCP_HandleErr(t *testing.T) {
   149  	startServerTest(
   150  		t,
   151  		errServerHandle,
   152  		nil,
   153  		func(addr string) {
   154  			_, err := gonetRequest(context.Background(), transport.WithDialAddress(addr))
   155  			fmt.Println(err)
   156  			assert.NotNil(t, err)
   157  		},
   158  	)
   159  }
   160  
   161  func TestServerTCP_IdleTimeout(t *testing.T) {
   162  	startServerTest(
   163  		t,
   164  		defaultServerHandle,
   165  		[]transport.ListenServeOption{transport.WithServerIdleTimeout(time.Second)},
   166  		func(addr string) {
   167  			cliconn, err := tnet.DialTCP("tcp", addr, 0)
   168  			assert.Nil(t, err)
   169  			_, err = cliconn.Write([]byte("0"))
   170  			assert.Nil(t, err)
   171  
   172  			// sleep to make sure ListenAndServe run into onRequest()
   173  			time.Sleep(2 * time.Second)
   174  			_, err = cliconn.Write([]byte("0"))
   175  			assert.NotNil(t, err)
   176  		},
   177  	)
   178  
   179  }
   180  
   181  func TestServerTCP_WriteFail(t *testing.T) {
   182  	ch := make(chan struct{}, 1)
   183  	var isHandled bool
   184  	startServerTest(
   185  		t,
   186  		func(ctx context.Context, req []byte) ([]byte, error) {
   187  			isHandled = true
   188  			<-ch
   189  			return nil, nil
   190  		},
   191  		[]transport.ListenServeOption{transport.WithServerAsync(true)},
   192  		func(addr string) {
   193  			ctx, _ := codec.EnsureMessage(context.Background())
   194  			req, err := trpc.DefaultClientCodec.Encode(codec.Message(ctx), helloWorld)
   195  			assert.Nil(t, err)
   196  
   197  			cliconn, err := tnet.DialTCP("tcp", addr, 0)
   198  			assert.Nil(t, err)
   199  			_, err = cliconn.Write(req)
   200  			assert.Nil(t, err)
   201  
   202  			// sleep to make sure server received data
   203  			time.Sleep(50 * time.Millisecond)
   204  			cliconn.Close()
   205  			// notify server write back data, but server will fail, because connection is closed
   206  			ch <- struct{}{}
   207  			_, err = cliconn.ReadN(1)
   208  			assert.NotNil(t, err)
   209  			// make sure server run into handle
   210  			assert.True(t, isHandled)
   211  		},
   212  	)
   213  }
   214  
   215  func TestServerTCP_PassedListener(t *testing.T) {
   216  	serverAddr := getAddr()
   217  	listener, err := net.Listen("tcp", serverAddr)
   218  	assert.Nil(t, err)
   219  
   220  	transport.SaveListener(listener)
   221  	fds := transport.GetListenersFds()
   222  	var fd int
   223  	for _, f := range fds {
   224  		if f.Address == serverAddr {
   225  			fd = int(f.Fd)
   226  		}
   227  	}
   228  
   229  	os.Setenv(transport.EnvGraceRestart, "1")
   230  	os.Setenv(transport.EnvGraceFirstFd, strconv.Itoa(fd))
   231  	os.Setenv(transport.EnvGraceRestartFdNum, "1")
   232  
   233  	defer func() {
   234  		os.Setenv(transport.EnvGraceRestart, "0")
   235  		os.Setenv(transport.EnvGraceFirstFd, "0")
   236  		os.Setenv(transport.EnvGraceRestartFdNum, "0")
   237  	}()
   238  
   239  	startServerTest(
   240  		t,
   241  		defaultServerHandle,
   242  		[]transport.ListenServeOption{transport.WithListenAddress(serverAddr)},
   243  		func(_ string) {
   244  			rsp, err := gonetRequest(context.Background(), transport.WithDialAddress(serverAddr))
   245  			assert.Nil(t, err)
   246  			assert.Equal(t, helloWorld, rsp)
   247  		},
   248  	)
   249  }
   250  
   251  func TestServerTCP_ClientWrongReq(t *testing.T) {
   252  	startServerTest(
   253  		t,
   254  		defaultServerHandle,
   255  		nil,
   256  		func(addr string) {
   257  			cliconn, err := tnet.DialTCP("tcp", addr, 0)
   258  			assert.Nil(t, err)
   259  			_, err = cliconn.Write([]byte("1234567890123456"))
   260  			assert.Nil(t, err)
   261  
   262  			// sleep to make sure ListenAndServe run into onRequest()
   263  			time.Sleep(50 * time.Millisecond)
   264  			err = cliconn.Close()
   265  			assert.Nil(t, err)
   266  		},
   267  	)
   268  }
   269  
   270  func TestServerTCP_SendAndClose(t *testing.T) {
   271  	addr := getAddr()
   272  	s := tnettrans.NewServerTransport()
   273  	serveOpts := getListenServeOption(
   274  		transport.WithListenAddress(addr),
   275  		transport.WithServerAsync(true),
   276  	)
   277  	err := s.ListenAndServe(context.Background(), serveOpts...)
   278  	assert.Nil(t, err)
   279  
   280  	cliconn, err := tnet.DialTCP("tcp", addr, 0)
   281  	assert.Nil(t, err)
   282  	cliAddr := cliconn.LocalAddr()
   283  
   284  	time.Sleep(50 * time.Millisecond)
   285  	streamTransport, ok := s.(transport.ServerStreamTransport)
   286  	assert.True(t, ok)
   287  	ctx, msg := codec.EnsureMessage(context.Background())
   288  	msg.WithRemoteAddr(cliAddr)
   289  	svrAddr, err := net.ResolveTCPAddr("tcp", addr)
   290  	assert.Nil(t, err)
   291  	msg.WithLocalAddr(svrAddr)
   292  	err = streamTransport.Send(ctx, helloWorld)
   293  	assert.Nil(t, err)
   294  
   295  	b := make([]byte, len(helloWorld))
   296  	cliconn.Read(b)
   297  	assert.Equal(t, b, helloWorld)
   298  
   299  	streamTransport.Close(ctx)
   300  	err = streamTransport.Send(ctx, helloWorld)
   301  	assert.NotNil(t, err)
   302  }
   303  
   304  func TestServerTCP_TLS(t *testing.T) {
   305  	startServerTest(
   306  		t,
   307  		defaultServerHandle,
   308  		[]transport.ListenServeOption{transport.WithServeTLS("../../testdata/server.crt", "../../testdata/server.key", "../../testdata/ca.pem")},
   309  		func(addr string) {
   310  			rsp, err := gonetRequest(
   311  				context.Background(),
   312  				transport.WithDialAddress(addr),
   313  				transport.WithDialTLS("../../testdata/client.crt", "../../testdata/client.key", "../../testdata/ca.pem", "localhost"),
   314  			)
   315  			assert.Nil(t, err)
   316  			assert.Equal(t, helloWorld, rsp)
   317  
   318  			rsp, err = gonetRequest(
   319  				context.Background(),
   320  				transport.WithDialAddress(addr),
   321  				transport.WithDialTLS("../../testdata/client.crt", "../../testdata/client.key", "none", ""),
   322  			)
   323  			assert.Nil(t, err)
   324  			assert.Equal(t, helloWorld, rsp)
   325  		},
   326  	)
   327  }
   328  
   329  func TestUDP(t *testing.T) {
   330  	// UDP is not supported, but it will switch to gonet default transport to serve.
   331  	startServerTest(
   332  		t,
   333  		defaultServerHandle,
   334  		[]transport.ListenServeOption{transport.WithListenNetwork("tcp,udp")},
   335  		func(addr string) {
   336  			rsp, err := gonetRequest(
   337  				context.Background(),
   338  				transport.WithDialAddress(addr),
   339  				transport.WithDialNetwork("udp"))
   340  			assert.Nil(t, err)
   341  			assert.Equal(t, helloWorld, rsp)
   342  
   343  			rsp, err = gonetRequest(
   344  				context.Background(),
   345  				transport.WithDialAddress(addr),
   346  				transport.WithDialNetwork("tcp"))
   347  			assert.Nil(t, err)
   348  			assert.Equal(t, helloWorld, rsp)
   349  		},
   350  	)
   351  }
   352  
   353  func TestUnix(t *testing.T) {
   354  	// Unix socket is not supported, but it will switch to gonet default transport to serve.
   355  	myAddr := "/tmp/server.sock"
   356  	os.Remove(myAddr)
   357  	startServerTest(
   358  		t,
   359  		defaultServerHandle,
   360  		[]transport.ListenServeOption{
   361  			transport.WithListenNetwork("unix"),
   362  			transport.WithListenAddress(myAddr),
   363  		},
   364  		func(_ string) {
   365  			rsp, err := gonetRequest(
   366  				context.Background(),
   367  				transport.WithDialAddress(myAddr),
   368  				transport.WithDialNetwork("unix"))
   369  			assert.Nil(t, err)
   370  			assert.Equal(t, helloWorld, rsp)
   371  		},
   372  	)
   373  }
   374  
   375  func getListenServeOption(opts ...transport.ListenServeOption) []transport.ListenServeOption {
   376  	lsopts := []transport.ListenServeOption{
   377  		transport.WithServerFramerBuilder(trpc.DefaultFramerBuilder),
   378  		transport.WithListenNetwork("tcp"),
   379  		transport.WithHandler(newUserDefineHandler(defaultServerHandle)),
   380  		transport.WithServerIdleTimeout(5 * time.Second),
   381  	}
   382  	lsopts = append(lsopts, opts...)
   383  	return lsopts
   384  }
   385  
   386  func defaultServerHandle(ctx context.Context, req []byte) (rsp []byte, err error) {
   387  	msg := codec.Message(ctx)
   388  	reqdata, err := trpc.DefaultServerCodec.Decode(msg, req)
   389  	if err != nil {
   390  		return nil, err
   391  	}
   392  	rspdata := make([]byte, len(reqdata))
   393  	copy(rspdata, reqdata)
   394  	rsp, err = trpc.DefaultServerCodec.Encode(msg, rspdata)
   395  	return rsp, err
   396  }
   397  
   398  func errServerHandle(ctx context.Context, req []byte) (rsp []byte, err error) {
   399  	return nil, errors.New("mock error")
   400  }
   401  
   402  type userDefineHandler struct {
   403  	handleFunc func(context.Context, []byte) ([]byte, error)
   404  }
   405  
   406  func newUserDefineHandler(f func(context.Context, []byte) ([]byte, error)) *userDefineHandler {
   407  	return &userDefineHandler{handleFunc: f}
   408  }
   409  
   410  func (uh *userDefineHandler) Handle(ctx context.Context, req []byte) (rsp []byte, err error) {
   411  	return uh.handleFunc(ctx, req)
   412  }
   413  
   414  func startServerTest(
   415  	t *testing.T,
   416  	serverHandle func(ctx context.Context, req []byte) ([]byte, error),
   417  	svrCustomOpts []transport.ListenServeOption,
   418  	clientHandle func(addr string),
   419  ) {
   420  	addr := getAddr()
   421  	s := tnettrans.NewServerTransport(
   422  		tnettrans.WithKeepAlivePeriod(15*time.Second),
   423  		tnettrans.WithReusePort(true),
   424  	)
   425  	handler := newUserDefineHandler(func(ctx context.Context, req []byte) ([]byte, error) {
   426  		return serverHandle(ctx, req)
   427  	})
   428  	serveOpts := getListenServeOption(
   429  		transport.WithListenAddress(addr),
   430  		transport.WithHandler(handler),
   431  	)
   432  	serveOpts = append(serveOpts, svrCustomOpts...)
   433  	err := s.ListenAndServe(context.Background(), serveOpts...)
   434  	assert.Nil(t, err)
   435  
   436  	clientHandle(addr)
   437  }
   438  
   439  func gonetRequest(ctx context.Context, opts ...transport.RoundTripOption) ([]byte, error) {
   440  	req := helloWorld
   441  	ctx, _ = codec.EnsureMessage(ctx)
   442  	reqbytes, err := trpc.DefaultClientCodec.Encode(
   443  		codec.Message(ctx),
   444  		req,
   445  	)
   446  	if err != nil {
   447  		return nil, err
   448  	}
   449  
   450  	cliOpts := getRoundTripOption(opts...)
   451  	clientTrans := transport.NewClientTransport()
   452  	rspbytes, err := clientTrans.RoundTrip(
   453  		ctx,
   454  		reqbytes,
   455  		cliOpts...,
   456  	)
   457  	if err != nil {
   458  		return nil, err
   459  	}
   460  	rsp, err := trpc.DefaultClientCodec.Decode(
   461  		codec.Message(ctx),
   462  		rspbytes,
   463  	)
   464  	return rsp, err
   465  }
   466  
   467  func getRoundTripOption(opts ...transport.RoundTripOption) []transport.RoundTripOption {
   468  	rtopts := []transport.RoundTripOption{
   469  		transport.WithDialNetwork("tcp"),
   470  		transport.WithClientFramerBuilder(trpc.DefaultFramerBuilder),
   471  		transport.WithDialTimeout(5 * time.Second),
   472  	}
   473  	rtopts = append(rtopts, opts...)
   474  	return rtopts
   475  }
   476  
   477  func getAddr() string {
   478  	atomic.AddUint64(&port, 1)
   479  	return "127.0.0.1:" + fmt.Sprint(port)
   480  }
   481  
   482  type reuseBufferFramerBuilder struct{}
   483  
   484  func (*reuseBufferFramerBuilder) New(r io.Reader) codec.Framer {
   485  	return &reuseBufferFramer{r: r, reuseBuffer: make([]byte, len(helloWorld))}
   486  }
   487  
   488  type reuseBufferFramer struct {
   489  	r           io.Reader
   490  	reuseBuffer []byte
   491  }
   492  
   493  func (f *reuseBufferFramer) ReadFrame() ([]byte, error) {
   494  	_, err := io.ReadFull(f.r, f.reuseBuffer)
   495  	if err != nil {
   496  		return nil, fmt.Errorf("io.ReadFull err: %w", err)
   497  	}
   498  	return f.reuseBuffer, nil
   499  }
   500  
   501  type emptyServerCodec struct{}
   502  
   503  func (s *emptyServerCodec) Decode(msg codec.Msg, reqBuf []byte) ([]byte, error) {
   504  	return reqBuf, nil
   505  }
   506  
   507  func (s *emptyServerCodec) Encode(msg codec.Msg, rspBody []byte) ([]byte, error) {
   508  	return rspBody, nil
   509  }
   510  
   511  type emptyClientCodec struct{}
   512  
   513  func (s *emptyClientCodec) Decode(msg codec.Msg, reqBuf []byte) ([]byte, error) {
   514  	return reqBuf, nil
   515  }
   516  
   517  func (s *emptyClientCodec) Encode(msg codec.Msg, rspBody []byte) ([]byte, error) {
   518  	return rspBody, nil
   519  }