trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/client_transport_tcp_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || freebsd || dragonfly || darwin
    15  // +build linux freebsd dragonfly darwin
    16  
    17  package tnet_test
    18  
    19  import (
    20  	"context"
    21  	"net"
    22  	"os"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  	"trpc.group/trpc-go/tnet"
    28  
    29  	trpc "trpc.group/trpc-go/trpc-go"
    30  	"trpc.group/trpc-go/trpc-go/codec"
    31  	"trpc.group/trpc-go/trpc-go/errs"
    32  	"trpc.group/trpc-go/trpc-go/pool/connpool"
    33  	"trpc.group/trpc-go/trpc-go/transport"
    34  	tnettrans "trpc.group/trpc-go/trpc-go/transport/tnet"
    35  )
    36  
    37  func TestClientTCP(t *testing.T) {
    38  	startClientTest(
    39  		t,
    40  		defaultServerHandle,
    41  		nil,
    42  		func(addr string) {
    43  			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
    44  			defer cancel()
    45  			rsp, err := tnetRequest(ctx, helloWorld,
    46  				transport.WithDialAddress(addr),
    47  				transport.WithDialTimeout(500*time.Millisecond),
    48  			)
    49  			assert.Equal(t, helloWorld, rsp)
    50  			assert.Nil(t, err)
    51  		},
    52  	)
    53  }
    54  
    55  func TestClientTCP_NoFrameBuilder(t *testing.T) {
    56  	startClientTest(
    57  		t,
    58  		defaultServerHandle,
    59  		nil,
    60  		func(addr string) {
    61  			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
    62  			defer cancel()
    63  			_, err := tnetRequest(ctx, helloWorld,
    64  				transport.WithDialAddress(addr),
    65  				transport.WithClientFramerBuilder(nil),
    66  			)
    67  			assert.Equal(t, errs.RetClientConnectFail, errs.Code(err))
    68  		},
    69  	)
    70  }
    71  
    72  func TestClientTCP_CtxErr(t *testing.T) {
    73  	startClientTest(
    74  		t,
    75  		defaultServerHandle,
    76  		nil,
    77  		func(addr string) {
    78  			// canceled context error
    79  			ctx, cancel := context.WithCancel(context.Background())
    80  			cancel()
    81  			_, err := tnetRequest(ctx, helloWorld,
    82  				transport.WithDialAddress(addr),
    83  			)
    84  			assert.Equal(t, errs.RetClientCanceled, errs.Code(err))
    85  
    86  			// timeout context error
    87  			ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Nanosecond))
    88  			defer cancel()
    89  			time.Sleep(time.Nanosecond)
    90  			_, err = tnetRequest(ctx, helloWorld,
    91  				transport.WithDialAddress(addr),
    92  			)
    93  			assert.Equal(t, errs.RetClientTimeout, errs.Code(err))
    94  		},
    95  	)
    96  }
    97  
    98  func TestClientTCP_DisableConnPool(t *testing.T) {
    99  	// success case
   100  	startClientTest(
   101  		t,
   102  		defaultServerHandle,
   103  		nil,
   104  		func(addr string) {
   105  			rsp, err := tnetRequest(
   106  				context.Background(),
   107  				helloWorld,
   108  				transport.WithDialAddress(addr),
   109  				transport.WithDisableConnectionPool(),
   110  			)
   111  			assert.Nil(t, err)
   112  			assert.Equal(t, helloWorld, rsp)
   113  		},
   114  	)
   115  	// dial wrong address
   116  	_, err := tnetRequest(
   117  		context.Background(),
   118  		helloWorld,
   119  		transport.WithDialAddress("0"),
   120  		transport.WithDisableConnectionPool(),
   121  	)
   122  	assert.Equal(t, errs.RetClientConnectFail, errs.Code(err))
   123  }
   124  
   125  func TestClientTCP_ReadTimeout(t *testing.T) {
   126  	startClientTest(
   127  		t,
   128  		func(ctx context.Context, req []byte) ([]byte, error) {
   129  			time.Sleep(time.Hour)
   130  			return nil, nil
   131  		},
   132  		nil,
   133  		func(addr string) {
   134  			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   135  			defer cancel()
   136  			_, err := tnetRequest(
   137  				ctx,
   138  				helloWorld,
   139  				transport.WithDialAddress(addr),
   140  			)
   141  			assert.Equal(t, errs.RetClientTimeout, errs.Code(err))
   142  		},
   143  	)
   144  }
   145  
   146  func TestClientTCP_CustomPool(t *testing.T) {
   147  	startClientTest(
   148  		t,
   149  		defaultServerHandle,
   150  		nil,
   151  		func(addr string) {
   152  			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   153  			defer cancel()
   154  			rsp, err := tnetRequest(
   155  				ctx,
   156  				helloWorld,
   157  				transport.WithDialAddress(addr),
   158  				transport.WithDialPool(&customPool{}),
   159  			)
   160  			assert.Equal(t, helloWorld, rsp)
   161  			assert.Nil(t, err)
   162  		},
   163  	)
   164  }
   165  
   166  func TestClientUDP(t *testing.T) {
   167  	// UDP is not supported, but it will switch to gonet default transport to roundtrip.
   168  	startClientTest(
   169  		t,
   170  		defaultServerHandle,
   171  		[]transport.ListenServeOption{transport.WithListenNetwork("udp")},
   172  		func(addr string) {
   173  			rsp, err := tnetRequest(
   174  				context.Background(),
   175  				helloWorld,
   176  				transport.WithDialAddress(addr),
   177  				transport.WithDialNetwork("udp"))
   178  			assert.Nil(t, err)
   179  			assert.Equal(t, helloWorld, rsp)
   180  		},
   181  	)
   182  }
   183  
   184  func TestClientUnix(t *testing.T) {
   185  	// Unix socket is not supported, but it will switch to gonet default transport to roundtrip.
   186  	unixAddr := "/tmp/server.sock"
   187  	os.Remove(unixAddr)
   188  	startClientTest(
   189  		t,
   190  		defaultServerHandle,
   191  		[]transport.ListenServeOption{
   192  			transport.WithListenAddress(unixAddr),
   193  			transport.WithListenNetwork("unix"),
   194  		},
   195  		func(addr string) {
   196  			rsp, err := tnetRequest(
   197  				context.Background(),
   198  				helloWorld,
   199  				transport.WithDialAddress(unixAddr),
   200  				transport.WithDialNetwork("unix"))
   201  			assert.Nil(t, err)
   202  			assert.Equal(t, helloWorld, rsp)
   203  		},
   204  	)
   205  
   206  }
   207  
   208  func TestClientTCP_Multiplex(t *testing.T) {
   209  	startClientTest(
   210  		t,
   211  		defaultServerHandle,
   212  		nil,
   213  		func(addr string) {
   214  			req := helloWorld
   215  			ctx, msg := codec.EnsureMessage(context.Background())
   216  			reqFrame, err := trpc.DefaultClientCodec.Encode(codec.Message(ctx), req)
   217  			assert.Nil(t, err)
   218  
   219  			cliOpts := getRoundTripOption(
   220  				transport.WithDialAddress(addr),
   221  				transport.WithMultiplexed(true),
   222  				transport.WithMsg(msg),
   223  			)
   224  			clientTrans := tnettrans.NewClientTransport()
   225  			rspFrame, err := clientTrans.RoundTrip(ctx, reqFrame, cliOpts...)
   226  			assert.Nil(t, err)
   227  			rsp, err := trpc.DefaultClientCodec.Decode(msg, rspFrame)
   228  			assert.Nil(t, err)
   229  			assert.Equal(t, helloWorld, rsp)
   230  		},
   231  	)
   232  }
   233  
   234  func TestClientTCP_TLS(t *testing.T) {
   235  	startClientTest(
   236  		t,
   237  		defaultServerHandle,
   238  		[]transport.ListenServeOption{transport.WithServeTLS("../../testdata/server.crt", "../../testdata/server.key", "../../testdata/ca.pem")},
   239  		func(addr string) {
   240  			rsp, err := tnetRequest(
   241  				context.Background(),
   242  				helloWorld,
   243  				transport.WithDialAddress(addr),
   244  				transport.WithDialTLS("../../testdata/client.crt", "../../testdata/client.key", "../../testdata/ca.pem", "localhost"),
   245  			)
   246  			assert.Nil(t, err)
   247  			assert.Equal(t, helloWorld, rsp)
   248  
   249  			rsp, err = tnetRequest(
   250  				context.Background(),
   251  				helloWorld,
   252  				transport.WithDialAddress(addr),
   253  				transport.WithDialTLS("../../testdata/client.crt", "../../testdata/client.key", "none", ""),
   254  			)
   255  			assert.Nil(t, err)
   256  			assert.Equal(t, helloWorld, rsp)
   257  		},
   258  	)
   259  }
   260  
   261  func TestClientTCP_HealthCheck(t *testing.T) {
   262  	addr := getAddr()
   263  	s := transport.NewServerTransport()
   264  	serveOpts := getListenServeOption(transport.WithListenAddress(addr))
   265  	err := s.ListenAndServe(context.Background(), serveOpts...)
   266  	assert.Nil(t, err)
   267  
   268  	c, err := net.Dial("tcp", addr)
   269  	assert.Nil(t, err)
   270  	assert.True(t, tnettrans.HealthChecker(&connpool.PoolConn{Conn: c}, true))
   271  
   272  	c, err = tnet.DialTCP("tcp", addr, 0)
   273  	assert.Nil(t, err)
   274  	assert.True(t, tnettrans.HealthChecker(&connpool.PoolConn{Conn: c}, true))
   275  
   276  	c.Close()
   277  	assert.False(t, tnettrans.HealthChecker(&connpool.PoolConn{Conn: c}, true))
   278  }
   279  
   280  func TestNewConnectionPool(t *testing.T) {
   281  	p := tnettrans.NewConnectionPool()
   282  	assert.NotNil(t, p)
   283  }
   284  
   285  func startClientTest(
   286  	t *testing.T,
   287  	serverHandle func(ctx context.Context, req []byte) ([]byte, error),
   288  	svrCustomOpts []transport.ListenServeOption,
   289  	clientHandle func(addr string),
   290  ) {
   291  	addr := getAddr()
   292  	s := transport.NewServerTransport()
   293  	handler := newUserDefineHandler(func(ctx context.Context, req []byte) ([]byte, error) {
   294  		return serverHandle(ctx, req)
   295  	})
   296  	serveOpts := getListenServeOption(
   297  		transport.WithListenAddress(addr),
   298  		transport.WithHandler(handler),
   299  	)
   300  	serveOpts = append(serveOpts, svrCustomOpts...)
   301  	err := s.ListenAndServe(context.Background(), serveOpts...)
   302  	assert.Nil(t, err)
   303  
   304  	clientHandle(addr)
   305  }
   306  
   307  type customPool struct{}
   308  
   309  type customConn struct {
   310  	tnet.Conn
   311  	framer codec.Framer
   312  }
   313  
   314  func (c *customConn) ReadFrame() ([]byte, error) {
   315  	return c.framer.ReadFrame()
   316  }
   317  
   318  func (p *customPool) Get(network string, address string, opts connpool.GetOptions) (net.Conn, error) {
   319  	c, err := tnet.DialTCP(network, address, opts.DialTimeout)
   320  	if err != nil {
   321  		return nil, err
   322  	}
   323  	return &customConn{Conn: c, framer: opts.FramerBuilder.New(c)}, nil
   324  }
   325  
   326  func tnetRequest(ctx context.Context, req []byte, opts ...transport.RoundTripOption) ([]byte, error) {
   327  	ctx, _ = codec.EnsureMessage(ctx)
   328  	reqbytes, err := trpc.DefaultClientCodec.Encode(
   329  		codec.Message(ctx),
   330  		req,
   331  	)
   332  	if err != nil {
   333  		return nil, err
   334  	}
   335  
   336  	cliOpts := getRoundTripOption(opts...)
   337  	clientTrans := tnettrans.NewClientTransport()
   338  	rspbytes, err := clientTrans.RoundTrip(
   339  		ctx,
   340  		reqbytes,
   341  		cliOpts...,
   342  	)
   343  	if err != nil {
   344  		return nil, err
   345  	}
   346  	rsp, err := trpc.DefaultClientCodec.Decode(
   347  		codec.Message(ctx),
   348  		rspbytes,
   349  	)
   350  	return rsp, err
   351  }