trpc.group/trpc-go/trpc-go@v1.0.2/transport/client_transport_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package transport_test
    15  
    16  import (
    17  	"context"
    18  	"errors"
    19  	"fmt"
    20  	"io"
    21  	"net"
    22  	"testing"
    23  	"time"
    24  
    25  	"trpc.group/trpc-go/trpc-go/codec"
    26  	"trpc.group/trpc-go/trpc-go/errs"
    27  	"trpc.group/trpc-go/trpc-go/pool/connpool"
    28  	"trpc.group/trpc-go/trpc-go/pool/multiplexed"
    29  	"trpc.group/trpc-go/trpc-go/transport"
    30  
    31  	"github.com/stretchr/testify/assert"
    32  	"github.com/stretchr/testify/require"
    33  
    34  	trpc "trpc.group/trpc-go/trpc-go"
    35  )
    36  
    37  func TestTcpRoundTripPoolNIl(t *testing.T) {
    38  	st := transport.NewClientTransport()
    39  	optNetwork := transport.WithDialNetwork("tcp")
    40  	optPool := transport.WithDialPool(nil)
    41  	_, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool)
    42  	assert.NotNil(t, err)
    43  }
    44  
    45  func TestTcpRoundTripTCPErr(t *testing.T) {
    46  	st := transport.NewClientTransport()
    47  	optNetwork := transport.WithDialNetwork("tcp")
    48  	pool := connpool.NewConnectionPool()
    49  	optPool := transport.WithDialPool(pool)
    50  	fb := &trpc.FramerBuilder{}
    51  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
    52  	optDisabled := transport.WithDisableConnectionPool()
    53  	newCtx := context.Background()
    54  	newCtx.Done()
    55  	newCtx.Deadline()
    56  	_, err := st.RoundTrip(newCtx, []byte("hello"), optNetwork, optPool, optFramerBuilder, optDisabled)
    57  	assert.NotNil(t, err)
    58  }
    59  
    60  func TestTcpRoundTripCTXErr(t *testing.T) {
    61  	st := transport.NewClientTransport()
    62  	optNetwork := transport.WithDialNetwork("tcp")
    63  	pool := connpool.NewConnectionPool()
    64  	optPool := transport.WithDialPool(pool)
    65  	fb := &trpc.FramerBuilder{}
    66  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
    67  	_, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder)
    68  	assert.NotNil(t, err)
    69  }
    70  
    71  type fakePool struct {
    72  }
    73  
    74  func (p *fakePool) Get(network string, address string, opts connpool.GetOptions) (net.Conn, error) {
    75  	return &fakeConn{}, nil
    76  }
    77  
    78  type fakeConn struct {
    79  }
    80  
    81  func (c *fakeConn) Close() error {
    82  	return nil
    83  }
    84  
    85  func (c *fakeConn) Read(b []byte) (n int, err error) {
    86  	return 0, nil
    87  }
    88  
    89  type netError struct {
    90  	error
    91  }
    92  
    93  // Timeout() bool
    94  // Temporary() bool
    95  func (c *netError) Timeout() bool {
    96  	return true
    97  }
    98  func (c *netError) Temporary() bool {
    99  	return true
   100  }
   101  
   102  func (c *fakeConn) Write(b []byte) (n int, err error) {
   103  	if Count == 1 {
   104  		return 0, errors.New("write failure")
   105  	}
   106  	if Count == 2 {
   107  		return 0, netError{errors.New("net failure")}
   108  	}
   109  	return len(b), nil
   110  }
   111  
   112  func (c *fakeConn) LocalAddr() net.Addr {
   113  	return nil
   114  }
   115  
   116  func (c *fakeConn) RemoteAddr() net.Addr {
   117  	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8888}
   118  }
   119  
   120  func (c *fakeConn) SetDeadline(t time.Time) error {
   121  	return nil
   122  }
   123  
   124  func (c *fakeConn) SetReadDeadline(t time.Time) error {
   125  	return nil
   126  }
   127  
   128  func (c *fakeConn) SetWriteDeadline(t time.Time) error {
   129  	return nil
   130  }
   131  
   132  func TestTcpRoundTripReadFrameNil(t *testing.T) {
   133  	st := transport.NewClientTransport()
   134  	optNetwork := transport.WithDialNetwork("tcp")
   135  	optPool := transport.WithDialPool(&fakePool{})
   136  	fb := &trpc.FramerBuilder{}
   137  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
   138  	optReqType := transport.WithReqType(transport.SendOnly)
   139  	optAddress := transport.WithDialAddress(":8888")
   140  	_, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder,
   141  		optReqType, optAddress)
   142  	assert.NotNil(t, err)
   143  }
   144  
   145  func TestTCPRoundTripSetRemoteAddr(t *testing.T) {
   146  	st := transport.NewClientTransport()
   147  	optNetwork := transport.WithDialNetwork("tcp")
   148  	optPool := transport.WithDialPool(&fakePool{})
   149  	fb := &trpc.FramerBuilder{}
   150  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
   151  	optAddress := transport.WithDialAddress("127.0.0.1:8888")
   152  	ctx, msg := codec.WithNewMessage(context.Background())
   153  	_, _ = st.RoundTrip(ctx, []byte("hello"), optNetwork, optPool, optFramerBuilder, optAddress)
   154  	assert.NotNil(t, msg.RemoteAddr())
   155  	assert.Equal(t, "127.0.0.1:8888", msg.RemoteAddr().String())
   156  }
   157  
   158  type newCtx struct {
   159  }
   160  
   161  var Count int64
   162  
   163  func (c *newCtx) Deadline() (deadline time.Time, ok bool) {
   164  	deadline = time.Now()
   165  	return deadline, true
   166  }
   167  func (c *newCtx) Done() <-chan struct{} {
   168  	return nil
   169  }
   170  func (c *newCtx) Err() error {
   171  	if Count == 1 {
   172  		return context.DeadlineExceeded
   173  	}
   174  	return context.Canceled
   175  }
   176  func (c *newCtx) Value(key interface{}) interface{} {
   177  	return nil
   178  }
   179  
   180  func TestTcpRoundTripCanceled(t *testing.T) {
   181  	st := transport.NewClientTransport()
   182  	optNetwork := transport.WithDialNetwork("tcp")
   183  	optPool := transport.WithDialPool(&fakePool{})
   184  	fb := &trpc.FramerBuilder{}
   185  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
   186  	optAddress := transport.WithDialAddress(":8888")
   187  	_, err := st.RoundTrip(&newCtx{}, []byte("hello"), optNetwork, optPool, optFramerBuilder,
   188  		optAddress)
   189  	assert.NotNil(t, err)
   190  }
   191  
   192  func TestTcpRoundTripTimeout(t *testing.T) {
   193  	st := transport.NewClientTransport()
   194  	optNetwork := transport.WithDialNetwork("tcp")
   195  	optPool := transport.WithDialPool(&fakePool{})
   196  	fb := &trpc.FramerBuilder{}
   197  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
   198  	optAddress := transport.WithDialAddress(":8888")
   199  	Count = 1
   200  	_, err := st.RoundTrip(&newCtx{}, []byte("hello"), optNetwork, optPool, optFramerBuilder,
   201  		optAddress)
   202  	assert.NotNil(t, err)
   203  }
   204  
   205  func TestTcpRoundTripConnWriteErr(t *testing.T) {
   206  	st := transport.NewClientTransport()
   207  	optNetwork := transport.WithDialNetwork("tcp")
   208  	optPool := transport.WithDialPool(&fakePool{})
   209  	fb := &trpc.FramerBuilder{}
   210  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
   211  	optAddress := transport.WithDialAddress(":8888")
   212  	Count = 1
   213  	_, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder,
   214  		optAddress)
   215  	assert.NotNil(t, err)
   216  	Count = 2
   217  	_, err = st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder,
   218  		optAddress)
   219  	assert.NotNil(t, err)
   220  }
   221  
   222  type NewPacketConn struct {
   223  }
   224  
   225  func (c *NewPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   226  	return 0, nil, nil
   227  }
   228  func (c *NewPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   229  	if Count == 1 {
   230  		return len(p), errors.New("write failure")
   231  	}
   232  	return len(p), netError{errors.New("net failure")}
   233  }
   234  func (c *NewPacketConn) Close() error {
   235  	return nil
   236  }
   237  func (c *NewPacketConn) LocalAddr() net.Addr {
   238  	return nil
   239  }
   240  func (c *NewPacketConn) SetDeadline(t time.Time) error {
   241  	return nil
   242  }
   243  func (c *NewPacketConn) SetReadDeadline(t time.Time) error {
   244  	return nil
   245  }
   246  func (c *NewPacketConn) SetWriteDeadline(t time.Time) error {
   247  	return nil
   248  }
   249  func (c *NewPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
   250  	return len(b), nil, netError{errors.New("net failure")}
   251  }
   252  
   253  func TestNewClientTransport(t *testing.T) {
   254  	st := transport.NewClientTransport()
   255  	assert.NotNil(t, st)
   256  }
   257  
   258  func TestWithDialPool(t *testing.T) {
   259  	opt := transport.WithDialPool(nil)
   260  	opts := &transport.RoundTripOptions{}
   261  	opt(opts)
   262  	assert.Equal(t, nil, opts.Pool)
   263  }
   264  
   265  func TestWithReqType(t *testing.T) {
   266  	opt := transport.WithReqType(transport.SendOnly)
   267  	opts := &transport.RoundTripOptions{}
   268  	opt(opts)
   269  	assert.Equal(t, transport.SendOnly, opts.ReqType)
   270  }
   271  
   272  type emptyPool struct {
   273  }
   274  
   275  func (p *emptyPool) Get(network string, address string, opts connpool.GetOptions) (net.Conn, error) {
   276  	return nil, errors.New("empty")
   277  }
   278  
   279  var testReqByte = []byte{'a', 'b'}
   280  
   281  func TestWithDialPoolError(t *testing.T) {
   282  	ctx, f := context.WithTimeout(context.Background(), 3*time.Second)
   283  	defer f()
   284  	_, err := transport.RoundTrip(ctx, testReqByte,
   285  		transport.WithDialPool(&emptyPool{}),
   286  		transport.WithDialNetwork("tcp"))
   287  	// fmt.Printf("err: %v", err)
   288  	assert.NotNil(t, err)
   289  }
   290  
   291  func TestContextTimeout(t *testing.T) {
   292  	ctx, f := context.WithTimeout(context.Background(), time.Millisecond)
   293  	defer f()
   294  	<-ctx.Done()
   295  	fb := &trpc.FramerBuilder{}
   296  	_, err := transport.RoundTrip(ctx, testReqByte,
   297  		transport.WithDialNetwork("tcp"),
   298  		transport.WithDialAddress(":8888"),
   299  		transport.WithClientFramerBuilder(fb))
   300  	assert.NotNil(t, err)
   301  }
   302  
   303  func TestContextTimeout_Multiplexed(t *testing.T) {
   304  	ctx, f := context.WithTimeout(context.Background(), time.Millisecond)
   305  	defer f()
   306  	<-ctx.Done()
   307  	fb := &trpc.FramerBuilder{}
   308  	_, err := transport.RoundTrip(ctx, testReqByte,
   309  		transport.WithDialNetwork("tcp"),
   310  		transport.WithDialAddress(":8888"),
   311  		transport.WithMultiplexed(true),
   312  		transport.WithMsg(codec.Message(ctx)),
   313  		transport.WithClientFramerBuilder(fb))
   314  	assert.NotNil(t, err)
   315  }
   316  
   317  func TestContextCancel(t *testing.T) {
   318  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   319  	cancel()
   320  	fb := &trpc.FramerBuilder{}
   321  	_, err := transport.RoundTrip(ctx, testReqByte,
   322  		transport.WithDialNetwork("tcp"),
   323  		transport.WithDialAddress(":8888"),
   324  		transport.WithClientFramerBuilder(fb))
   325  	assert.NotNil(t, err)
   326  }
   327  
   328  func TestWithReqTypeSendOnly(t *testing.T) {
   329  	ctx, f := context.WithTimeout(context.Background(), 3*time.Second)
   330  	defer f()
   331  	_, err := transport.RoundTrip(ctx, []byte{},
   332  		transport.WithReqType(transport.SendOnly),
   333  		transport.WithDialNetwork("tcp"))
   334  	// fmt.Printf("err: %v", err)
   335  	assert.NotNil(t, err)
   336  }
   337  
   338  func TestClientTransport_RoundTrip(t *testing.T) {
   339  	fb := &lengthDelimitedBuilder{}
   340  	go func() {
   341  		err := transport.ListenAndServe(
   342  			transport.WithListenNetwork("udp"),
   343  			transport.WithListenAddress("localhost:9998"),
   344  			transport.WithHandler(&lengthDelimitedHandler{}),
   345  			transport.WithServerFramerBuilder(fb),
   346  		)
   347  		assert.Nil(t, err)
   348  	}()
   349  	time.Sleep(20 * time.Millisecond)
   350  
   351  	var err error
   352  	_, err = transport.RoundTrip(context.Background(), encodeLengthDelimited("helloworld"))
   353  	assert.NotNil(t, err)
   354  
   355  	tc := transport.NewClientTransport()
   356  	_, err = tc.RoundTrip(context.Background(), encodeLengthDelimited("helloworld"))
   357  	assert.NotNil(t, err)
   358  
   359  	// Test address invalid.
   360  	_, err = tc.RoundTrip(context.Background(), encodeLengthDelimited("helloworld"),
   361  		transport.WithDialNetwork("udp"),
   362  		transport.WithDialAddress("invalidaddress"),
   363  		transport.WithReqType(transport.SendOnly))
   364  	assert.NotNil(t, err)
   365  
   366  	// Test send only.
   367  	rsp, err := tc.RoundTrip(context.Background(), encodeLengthDelimited("helloworld"),
   368  		transport.WithDialNetwork("udp"),
   369  		transport.WithDialAddress("localhost:9998"),
   370  		transport.WithClientFramerBuilder(fb),
   371  		transport.WithReqType(transport.SendOnly),
   372  		transport.WithConnectionMode(transport.NotConnected))
   373  	assert.NotNil(t, err)
   374  	assert.Equal(t, errs.ErrClientNoResponse, err)
   375  	assert.Nil(t, rsp)
   376  
   377  	// Test multiplexed send only.
   378  	ctx, msg := codec.WithNewMessage(context.Background())
   379  	rsp, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   380  		transport.WithDialNetwork("udp"),
   381  		transport.WithMultiplexed(true),
   382  		transport.WithDialAddress("localhost:9998"),
   383  		transport.WithReqType(transport.SendOnly),
   384  		transport.WithClientFramerBuilder(fb),
   385  		transport.WithMsg(msg),
   386  	)
   387  	assert.NotNil(t, err)
   388  	assert.Equal(t, errs.ErrClientNoResponse, err)
   389  	assert.Nil(t, rsp)
   390  
   391  	// Test context canceled.
   392  	ctx, cancel := context.WithCancel(context.Background())
   393  	cancel()
   394  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   395  		transport.WithDialNetwork("udp"),
   396  		transport.WithClientFramerBuilder(fb),
   397  		transport.WithDialAddress("localhost:9998"))
   398  	assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientCanceled))
   399  
   400  	// Test context timeout.
   401  	ctx, timeout := context.WithTimeout(context.Background(), time.Millisecond)
   402  	defer timeout()
   403  	<-ctx.Done()
   404  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   405  		transport.WithDialNetwork("udp"),
   406  		transport.WithClientFramerBuilder(fb),
   407  		transport.WithDialAddress("localhost:9998"))
   408  	assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientTimeout))
   409  
   410  	// Test roundtrip.
   411  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   412  	defer cancel()
   413  	rsp, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   414  		transport.WithDialNetwork("udp"),
   415  		transport.WithDialAddress("localhost:9998"),
   416  		transport.WithConnectionMode(transport.NotConnected),
   417  		transport.WithClientFramerBuilder(fb),
   418  	)
   419  	assert.NotNil(t, rsp)
   420  	assert.Nil(t, err)
   421  
   422  	// Test setting RemoteAddr of UDP RoundTrip.
   423  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   424  	defer cancel()
   425  	ctx, msg = codec.WithNewMessage(ctx)
   426  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   427  		transport.WithDialNetwork("udp"),
   428  		transport.WithDialAddress("127.0.0.1:9998"),
   429  		transport.WithConnectionMode(transport.Connected),
   430  		transport.WithClientFramerBuilder(fb),
   431  	)
   432  	assert.Nil(t, err)
   433  	assert.Equal(t, "127.0.0.1:9998", msg.RemoteAddr().String())
   434  
   435  	// Test local addr.
   436  	localAddr := "127.0.0.1:"
   437  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   438  	defer cancel()
   439  	ctx, msg = codec.WithNewMessage(ctx)
   440  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   441  		transport.WithDialNetwork("udp"),
   442  		transport.WithDialAddress("127.0.0.1:9998"),
   443  		transport.WithConnectionMode(transport.Connected),
   444  		transport.WithClientFramerBuilder(fb),
   445  		transport.WithLocalAddr(localAddr),
   446  	)
   447  	assert.Nil(t, err)
   448  	assert.Equal(t, "127.0.0.1", msg.LocalAddr().(*net.UDPAddr).IP.String())
   449  
   450  	// Test local addr error.
   451  	localAddr = "invalid address"
   452  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   453  	defer cancel()
   454  	ctx, msg = codec.WithNewMessage(ctx)
   455  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   456  		transport.WithDialNetwork("udp"),
   457  		transport.WithDialAddress("127.0.0.1:9998"),
   458  		transport.WithConnectionMode(transport.Connected),
   459  		transport.WithClientFramerBuilder(fb),
   460  		transport.WithLocalAddr(localAddr),
   461  	)
   462  	assert.NotNil(t, err)
   463  	assert.Nil(t, msg.LocalAddr())
   464  
   465  	// Test readframer error.
   466  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   467  	defer cancel()
   468  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   469  		transport.WithDialNetwork("udp"),
   470  		transport.WithDialAddress("127.0.0.1:9998"),
   471  		transport.WithConnectionMode(transport.Connected),
   472  		transport.WithClientFramerBuilder(&lengthDelimitedBuilder{
   473  			readError: true,
   474  		}),
   475  	)
   476  	assert.Contains(t, err.Error(), readFrameError.Error())
   477  
   478  	// Test readframe bytes remaining error.
   479  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   480  	defer cancel()
   481  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   482  		transport.WithDialNetwork("udp"),
   483  		transport.WithDialAddress("127.0.0.1:9998"),
   484  		transport.WithConnectionMode(transport.Connected),
   485  		transport.WithClientFramerBuilder(&lengthDelimitedBuilder{
   486  			remainingBytes: true,
   487  		}),
   488  	)
   489  	assert.Contains(t, err.Error(), remainingBytesError.Error())
   490  }
   491  
   492  // Frame a stream of bytes based on a length prefix
   493  // +------------+--------------------------------+
   494  // | len: uint8 |          frame payload         |
   495  // +------------+--------------------------------+
   496  type lengthDelimitedBuilder struct {
   497  	remainingBytes bool
   498  	readError      bool
   499  }
   500  
   501  func (fb *lengthDelimitedBuilder) New(reader io.Reader) codec.Framer {
   502  	return &lengthDelimited{
   503  		readError:      fb.readError,
   504  		remainingBytes: fb.remainingBytes,
   505  		reader:         reader,
   506  	}
   507  }
   508  
   509  func (fb *lengthDelimitedBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) {
   510  	buf, err = fb.New(rc).ReadFrame()
   511  	if err != nil {
   512  		return 0, nil, err
   513  	}
   514  	return 0, buf, nil
   515  }
   516  
   517  type lengthDelimited struct {
   518  	reader         io.Reader
   519  	readError      bool
   520  	remainingBytes bool
   521  }
   522  
   523  func encodeLengthDelimited(data string) []byte {
   524  	result := []byte{byte(len(data))}
   525  	result = append(result, []byte(data)...)
   526  	return result
   527  }
   528  
   529  var (
   530  	readFrameError      = errors.New("read framer error")
   531  	remainingBytesError = fmt.Errorf(
   532  		"packet data is not drained, the remaining %d will be dropped",
   533  		remainingBytes,
   534  	)
   535  	remainingBytes = 1
   536  )
   537  
   538  func (f *lengthDelimited) ReadFrame() ([]byte, error) {
   539  	if f.readError {
   540  		return nil, readFrameError
   541  	}
   542  	head := make([]byte, 1)
   543  	if _, err := io.ReadFull(f.reader, head); err != nil {
   544  		return nil, err
   545  	}
   546  	bodyLen := int(head[0])
   547  	if f.remainingBytes {
   548  		bodyLen = bodyLen - remainingBytes
   549  	}
   550  	body := make([]byte, bodyLen)
   551  	if _, err := io.ReadFull(f.reader, body); err != nil {
   552  		return nil, err
   553  	}
   554  	return body, nil
   555  }
   556  
   557  type lengthDelimitedHandler struct{}
   558  
   559  func (h *lengthDelimitedHandler) Handle(ctx context.Context, req []byte) ([]byte, error) {
   560  	rsp := make([]byte, len(req)+1)
   561  	rsp[0] = byte(len(req))
   562  	copy(rsp[1:], req)
   563  	return rsp, nil
   564  }
   565  
   566  func TestClientTransport_MultiplexedErr(t *testing.T) {
   567  	listener, err := net.Listen("tcp", ":")
   568  	require.Nil(t, err)
   569  	defer listener.Close()
   570  	go func() {
   571  		transport.ListenAndServe(
   572  			transport.WithListener(listener),
   573  			transport.WithHandler(&echoHandler{}),
   574  			transport.WithServerFramerBuilder(transport.GetFramerBuilder("trpc")),
   575  		)
   576  	}()
   577  	time.Sleep(20 * time.Millisecond)
   578  
   579  	tc := transport.NewClientTransport()
   580  	fb := &trpc.FramerBuilder{}
   581  
   582  	// Test multiplexed context timeout.
   583  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   584  	defer cancel()
   585  	ctx, msg := codec.WithNewMessage(ctx)
   586  	_, err = tc.RoundTrip(ctx, []byte("helloworld"),
   587  		transport.WithDialNetwork(listener.Addr().Network()),
   588  		transport.WithDialAddress(listener.Addr().String()),
   589  		transport.WithMultiplexed(true),
   590  		transport.WithClientFramerBuilder(fb),
   591  		transport.WithMsg(msg),
   592  	)
   593  	assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientTimeout))
   594  
   595  	// Test multiplexed context canceled.
   596  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   597  	go func() {
   598  		time.Sleep(time.Millisecond * 200)
   599  		cancel()
   600  	}()
   601  	_, err = tc.RoundTrip(ctx, []byte("helloworld"),
   602  		transport.WithDialNetwork(listener.Addr().Network()),
   603  		transport.WithDialAddress(listener.Addr().String()),
   604  		transport.WithMultiplexed(true),
   605  		transport.WithClientFramerBuilder(fb),
   606  		transport.WithMsg(msg),
   607  	)
   608  	assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientCanceled))
   609  }
   610  
   611  func TestClientTransport_RoundTrip_PreConnected(t *testing.T) {
   612  
   613  	go func() {
   614  		err := transport.ListenAndServe(
   615  			transport.WithListenNetwork("udp"),
   616  			transport.WithListenAddress("localhost:9999"),
   617  			transport.WithHandler(&echoHandler{}),
   618  			transport.WithServerFramerBuilder(transport.GetFramerBuilder("trpc")),
   619  		)
   620  		assert.Nil(t, err)
   621  	}()
   622  	time.Sleep(20 * time.Millisecond)
   623  
   624  	var err error
   625  	_, err = transport.RoundTrip(context.Background(), []byte("helloworld"))
   626  	assert.NotNil(t, err)
   627  
   628  	tc := transport.NewClientTransport()
   629  
   630  	// Test connected UDPConn.
   631  	rsp, err := tc.RoundTrip(context.Background(), []byte("helloworld"),
   632  		transport.WithDialNetwork("udp"),
   633  		transport.WithDialAddress("localhost:9999"),
   634  		transport.WithDialPassword("passwd"),
   635  		transport.WithClientFramerBuilder(&trpc.FramerBuilder{}),
   636  		transport.WithReqType(transport.SendOnly),
   637  		transport.WithConnectionMode(transport.Connected))
   638  	assert.NotNil(t, err)
   639  	assert.Equal(t, errs.ErrClientNoResponse, err)
   640  	assert.Nil(t, rsp)
   641  
   642  	// Test context done.
   643  	ctx, cancel := context.WithCancel(context.Background())
   644  	cancel()
   645  	_, err = tc.RoundTrip(ctx, []byte("helloworld"),
   646  		transport.WithDialNetwork("udp"),
   647  		transport.WithDialAddress("localhost:9999"),
   648  		transport.WithConnectionMode(transport.Connected))
   649  	assert.NotNil(t, err)
   650  
   651  	// Test RoundTrip.
   652  	ctx, cancel = context.WithTimeout(ctx, time.Second)
   653  	defer cancel()
   654  	rsp, err = tc.RoundTrip(ctx, []byte("helloworld"),
   655  		transport.WithDialNetwork("udp"),
   656  		transport.WithDialAddress("localhost:9999"),
   657  		transport.WithConnectionMode(transport.Connected))
   658  	assert.NotNil(t, err)
   659  	assert.Nil(t, rsp)
   660  }
   661  
   662  func TestOptions(t *testing.T) {
   663  
   664  	opts := &transport.RoundTripOptions{}
   665  
   666  	o := transport.WithDialTLS("client.cert", "client.key", "ca.pem", "servername")
   667  	o(opts)
   668  	assert.Equal(t, "client.cert", opts.TLSCertFile)
   669  	assert.Equal(t, "client.key", opts.TLSKeyFile)
   670  	assert.Equal(t, "ca.pem", opts.CACertFile)
   671  	assert.Equal(t, "servername", opts.TLSServerName)
   672  
   673  	o = transport.WithDisableConnectionPool()
   674  	o(opts)
   675  
   676  	assert.True(t, opts.DisableConnectionPool)
   677  }
   678  
   679  // TestWithMultiplexedPool tests connection pool multiplexing.
   680  func TestWithMultiplexedPool(t *testing.T) {
   681  	opts := &transport.RoundTripOptions{}
   682  	m := multiplexed.New(multiplexed.WithConnectNumber(10))
   683  	o := transport.WithMultiplexedPool(m)
   684  	o(opts)
   685  	assert.True(t, opts.EnableMultiplexed)
   686  	assert.Equal(t, opts.Multiplexed, m)
   687  }
   688  
   689  // TestUDPTransportFramerBuilderErr tests nil FramerBuilder error.
   690  func TestUDPTransportFramerBuilderErr(t *testing.T) {
   691  	opts := []transport.RoundTripOption{
   692  		transport.WithDialNetwork("udp"),
   693  	}
   694  	ts := transport.NewClientTransport()
   695  	_, err := ts.RoundTrip(context.Background(), nil, opts...)
   696  	assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientConnectFail))
   697  }
   698  
   699  // TestWithLocalAddr tests local addr.
   700  func TestWithLocalAddr(t *testing.T) {
   701  	opts := &transport.RoundTripOptions{}
   702  	localAddr := "127.0.0.1:8080"
   703  	o := transport.WithLocalAddr(localAddr)
   704  	o(opts)
   705  	assert.Equal(t, opts.LocalAddr, localAddr)
   706  }
   707  
   708  func TestWithDialTimeout(t *testing.T) {
   709  	opts := &transport.RoundTripOptions{}
   710  	timeout := time.Second
   711  	o := transport.WithDialTimeout(timeout)
   712  	o(opts)
   713  	assert.Equal(t, opts.DialTimeout, timeout)
   714  }
   715  
   716  func TestWithProtocol(t *testing.T) {
   717  	opts := &transport.RoundTripOptions{}
   718  	protocol := "xxx-protocol"
   719  	o := transport.WithProtocol(protocol)
   720  	o(opts)
   721  	assert.Equal(t, protocol, opts.Protocol)
   722  }
   723  
   724  func TestWithDisableEncodeTransInfoBase64(t *testing.T) {
   725  	opts := &transport.ClientTransportOptions{}
   726  	transport.WithDisableEncodeTransInfoBase64()(opts)
   727  	assert.Equal(t, true, opts.DisableHTTPEncodeTransInfoBase64)
   728  }