trpc.group/trpc-go/trpc-go@v1.0.3/transport/client_transport_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package transport_test
    15  
    16  import (
    17  	"context"
    18  	"errors"
    19  	"fmt"
    20  	"io"
    21  	"math"
    22  	"net"
    23  	"strings"
    24  	"testing"
    25  	"time"
    26  
    27  	"trpc.group/trpc-go/trpc-go/codec"
    28  	"trpc.group/trpc-go/trpc-go/errs"
    29  	"trpc.group/trpc-go/trpc-go/pool/connpool"
    30  	"trpc.group/trpc-go/trpc-go/pool/multiplexed"
    31  	"trpc.group/trpc-go/trpc-go/transport"
    32  
    33  	"github.com/stretchr/testify/assert"
    34  	"github.com/stretchr/testify/require"
    35  
    36  	trpc "trpc.group/trpc-go/trpc-go"
    37  )
    38  
    39  func TestTcpRoundTripPoolNIl(t *testing.T) {
    40  	st := transport.NewClientTransport()
    41  	optNetwork := transport.WithDialNetwork("tcp")
    42  	optPool := transport.WithDialPool(nil)
    43  	_, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool)
    44  	assert.NotNil(t, err)
    45  }
    46  
    47  func TestTcpRoundTripTCPErr(t *testing.T) {
    48  	st := transport.NewClientTransport()
    49  	optNetwork := transport.WithDialNetwork("tcp")
    50  	pool := connpool.NewConnectionPool()
    51  	optPool := transport.WithDialPool(pool)
    52  	fb := &trpc.FramerBuilder{}
    53  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
    54  	optDisabled := transport.WithDisableConnectionPool()
    55  	newCtx := context.Background()
    56  	newCtx.Done()
    57  	newCtx.Deadline()
    58  	_, err := st.RoundTrip(newCtx, []byte("hello"), optNetwork, optPool, optFramerBuilder, optDisabled)
    59  	assert.NotNil(t, err)
    60  }
    61  
    62  func TestTcpRoundTripCTXErr(t *testing.T) {
    63  	st := transport.NewClientTransport()
    64  	optNetwork := transport.WithDialNetwork("tcp")
    65  	pool := connpool.NewConnectionPool()
    66  	optPool := transport.WithDialPool(pool)
    67  	fb := &trpc.FramerBuilder{}
    68  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
    69  	_, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder)
    70  	assert.NotNil(t, err)
    71  }
    72  
    73  type fakePool struct {
    74  }
    75  
    76  func (p *fakePool) Get(network string, address string, opts connpool.GetOptions) (net.Conn, error) {
    77  	return &fakeConn{}, nil
    78  }
    79  
    80  type fakeConn struct {
    81  }
    82  
    83  func (c *fakeConn) Close() error {
    84  	return nil
    85  }
    86  
    87  func (c *fakeConn) Read(b []byte) (n int, err error) {
    88  	return 0, nil
    89  }
    90  
    91  type netError struct {
    92  	error
    93  }
    94  
    95  // Timeout() bool
    96  // Temporary() bool
    97  func (c *netError) Timeout() bool {
    98  	return true
    99  }
   100  func (c *netError) Temporary() bool {
   101  	return true
   102  }
   103  
   104  func (c *fakeConn) Write(b []byte) (n int, err error) {
   105  	if Count == 1 {
   106  		return 0, errors.New("write failure")
   107  	}
   108  	if Count == 2 {
   109  		return 0, netError{errors.New("net failure")}
   110  	}
   111  	return len(b), nil
   112  }
   113  
   114  func (c *fakeConn) LocalAddr() net.Addr {
   115  	return nil
   116  }
   117  
   118  func (c *fakeConn) RemoteAddr() net.Addr {
   119  	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8888}
   120  }
   121  
   122  func (c *fakeConn) SetDeadline(t time.Time) error {
   123  	return nil
   124  }
   125  
   126  func (c *fakeConn) SetReadDeadline(t time.Time) error {
   127  	return nil
   128  }
   129  
   130  func (c *fakeConn) SetWriteDeadline(t time.Time) error {
   131  	return nil
   132  }
   133  
   134  func TestTcpRoundTripReadFrameNil(t *testing.T) {
   135  	st := transport.NewClientTransport()
   136  	optNetwork := transport.WithDialNetwork("tcp")
   137  	optPool := transport.WithDialPool(&fakePool{})
   138  	fb := &trpc.FramerBuilder{}
   139  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
   140  	optReqType := transport.WithReqType(transport.SendOnly)
   141  	optAddress := transport.WithDialAddress(":8888")
   142  	_, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder,
   143  		optReqType, optAddress)
   144  	assert.NotNil(t, err)
   145  }
   146  
   147  func TestTCPRoundTripSetRemoteAddr(t *testing.T) {
   148  	st := transport.NewClientTransport()
   149  	optNetwork := transport.WithDialNetwork("tcp")
   150  	optPool := transport.WithDialPool(&fakePool{})
   151  	fb := &trpc.FramerBuilder{}
   152  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
   153  	optAddress := transport.WithDialAddress("127.0.0.1:8888")
   154  	ctx, msg := codec.WithNewMessage(context.Background())
   155  	_, _ = st.RoundTrip(ctx, []byte("hello"), optNetwork, optPool, optFramerBuilder, optAddress)
   156  	assert.NotNil(t, msg.RemoteAddr())
   157  	assert.Equal(t, "127.0.0.1:8888", msg.RemoteAddr().String())
   158  }
   159  
   160  type newCtx struct {
   161  }
   162  
   163  var Count int64
   164  
   165  func (c *newCtx) Deadline() (deadline time.Time, ok bool) {
   166  	deadline = time.Now()
   167  	return deadline, true
   168  }
   169  func (c *newCtx) Done() <-chan struct{} {
   170  	return nil
   171  }
   172  func (c *newCtx) Err() error {
   173  	if Count == 1 {
   174  		return context.DeadlineExceeded
   175  	}
   176  	return context.Canceled
   177  }
   178  func (c *newCtx) Value(key interface{}) interface{} {
   179  	return nil
   180  }
   181  
   182  func TestTcpRoundTripCanceled(t *testing.T) {
   183  	st := transport.NewClientTransport()
   184  	optNetwork := transport.WithDialNetwork("tcp")
   185  	optPool := transport.WithDialPool(&fakePool{})
   186  	fb := &trpc.FramerBuilder{}
   187  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
   188  	optAddress := transport.WithDialAddress(":8888")
   189  	_, err := st.RoundTrip(&newCtx{}, []byte("hello"), optNetwork, optPool, optFramerBuilder,
   190  		optAddress)
   191  	assert.NotNil(t, err)
   192  }
   193  
   194  func TestTcpRoundTripTimeout(t *testing.T) {
   195  	st := transport.NewClientTransport()
   196  	optNetwork := transport.WithDialNetwork("tcp")
   197  	optPool := transport.WithDialPool(&fakePool{})
   198  	fb := &trpc.FramerBuilder{}
   199  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
   200  	optAddress := transport.WithDialAddress(":8888")
   201  	Count = 1
   202  	_, err := st.RoundTrip(&newCtx{}, []byte("hello"), optNetwork, optPool, optFramerBuilder,
   203  		optAddress)
   204  	assert.NotNil(t, err)
   205  }
   206  
   207  func TestTcpRoundTripConnWriteErr(t *testing.T) {
   208  	st := transport.NewClientTransport()
   209  	optNetwork := transport.WithDialNetwork("tcp")
   210  	optPool := transport.WithDialPool(&fakePool{})
   211  	fb := &trpc.FramerBuilder{}
   212  	optFramerBuilder := transport.WithClientFramerBuilder(fb)
   213  	optAddress := transport.WithDialAddress(":8888")
   214  	Count = 1
   215  	_, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder,
   216  		optAddress)
   217  	assert.NotNil(t, err)
   218  	Count = 2
   219  	_, err = st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder,
   220  		optAddress)
   221  	assert.NotNil(t, err)
   222  }
   223  
   224  type NewPacketConn struct {
   225  }
   226  
   227  func (c *NewPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   228  	return 0, nil, nil
   229  }
   230  func (c *NewPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   231  	if Count == 1 {
   232  		return len(p), errors.New("write failure")
   233  	}
   234  	return len(p), netError{errors.New("net failure")}
   235  }
   236  func (c *NewPacketConn) Close() error {
   237  	return nil
   238  }
   239  func (c *NewPacketConn) LocalAddr() net.Addr {
   240  	return nil
   241  }
   242  func (c *NewPacketConn) SetDeadline(t time.Time) error {
   243  	return nil
   244  }
   245  func (c *NewPacketConn) SetReadDeadline(t time.Time) error {
   246  	return nil
   247  }
   248  func (c *NewPacketConn) SetWriteDeadline(t time.Time) error {
   249  	return nil
   250  }
   251  func (c *NewPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
   252  	return len(b), nil, netError{errors.New("net failure")}
   253  }
   254  
   255  func TestNewClientTransport(t *testing.T) {
   256  	st := transport.NewClientTransport()
   257  	assert.NotNil(t, st)
   258  }
   259  
   260  func TestWithDialPool(t *testing.T) {
   261  	opt := transport.WithDialPool(nil)
   262  	opts := &transport.RoundTripOptions{}
   263  	opt(opts)
   264  	assert.Equal(t, nil, opts.Pool)
   265  }
   266  
   267  func TestWithReqType(t *testing.T) {
   268  	opt := transport.WithReqType(transport.SendOnly)
   269  	opts := &transport.RoundTripOptions{}
   270  	opt(opts)
   271  	assert.Equal(t, transport.SendOnly, opts.ReqType)
   272  }
   273  
   274  type emptyPool struct {
   275  }
   276  
   277  func (p *emptyPool) Get(network string, address string, opts connpool.GetOptions) (net.Conn, error) {
   278  	return nil, errors.New("empty")
   279  }
   280  
   281  var testReqByte = []byte{'a', 'b'}
   282  
   283  func TestWithDialPoolError(t *testing.T) {
   284  	ctx, f := context.WithTimeout(context.Background(), 3*time.Second)
   285  	defer f()
   286  	_, err := transport.RoundTrip(ctx, testReqByte,
   287  		transport.WithDialPool(&emptyPool{}),
   288  		transport.WithDialNetwork("tcp"))
   289  	// fmt.Printf("err: %v", err)
   290  	assert.NotNil(t, err)
   291  }
   292  
   293  func TestContextTimeout(t *testing.T) {
   294  	ctx, f := context.WithTimeout(context.Background(), time.Millisecond)
   295  	defer f()
   296  	<-ctx.Done()
   297  	fb := &trpc.FramerBuilder{}
   298  	_, err := transport.RoundTrip(ctx, testReqByte,
   299  		transport.WithDialNetwork("tcp"),
   300  		transport.WithDialAddress(":8888"),
   301  		transport.WithClientFramerBuilder(fb))
   302  	assert.NotNil(t, err)
   303  }
   304  
   305  func TestContextTimeout_Multiplexed(t *testing.T) {
   306  	ctx, f := context.WithTimeout(context.Background(), time.Millisecond)
   307  	defer f()
   308  	<-ctx.Done()
   309  	fb := &trpc.FramerBuilder{}
   310  	_, err := transport.RoundTrip(ctx, testReqByte,
   311  		transport.WithDialNetwork("tcp"),
   312  		transport.WithDialAddress(":8888"),
   313  		transport.WithMultiplexed(true),
   314  		transport.WithMsg(codec.Message(ctx)),
   315  		transport.WithClientFramerBuilder(fb))
   316  	assert.NotNil(t, err)
   317  }
   318  
   319  func TestContextCancel(t *testing.T) {
   320  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   321  	cancel()
   322  	fb := &trpc.FramerBuilder{}
   323  	_, err := transport.RoundTrip(ctx, testReqByte,
   324  		transport.WithDialNetwork("tcp"),
   325  		transport.WithDialAddress(":8888"),
   326  		transport.WithClientFramerBuilder(fb))
   327  	assert.NotNil(t, err)
   328  }
   329  
   330  func TestWithReqTypeSendOnly(t *testing.T) {
   331  	ctx, f := context.WithTimeout(context.Background(), 3*time.Second)
   332  	defer f()
   333  	_, err := transport.RoundTrip(ctx, []byte{},
   334  		transport.WithReqType(transport.SendOnly),
   335  		transport.WithDialNetwork("tcp"))
   336  	// fmt.Printf("err: %v", err)
   337  	assert.NotNil(t, err)
   338  }
   339  
   340  func TestClientTransport_RoundTrip(t *testing.T) {
   341  	fb := &lengthDelimitedBuilder{}
   342  	go func() {
   343  		err := transport.ListenAndServe(
   344  			transport.WithListenNetwork("udp"),
   345  			transport.WithListenAddress("localhost:9998"),
   346  			transport.WithHandler(&lengthDelimitedHandler{}),
   347  			transport.WithServerFramerBuilder(fb),
   348  		)
   349  		assert.Nil(t, err)
   350  	}()
   351  	time.Sleep(20 * time.Millisecond)
   352  
   353  	t.Run("write: message too long", func(t *testing.T) {
   354  		c := mustListenUDP(t)
   355  		t.Cleanup(func() {
   356  			if err := c.Close(); err != nil {
   357  				t.Log(err)
   358  			}
   359  		})
   360  		largeRequest := encodeLengthDelimited(strings.Repeat("1", math.MaxInt32/4))
   361  		_, err := transport.RoundTrip(context.Background(), largeRequest,
   362  			transport.WithClientFramerBuilder(fb),
   363  			transport.WithDialNetwork("udp"),
   364  			transport.WithDialAddress(c.LocalAddr().String()),
   365  			transport.WithReqType(transport.SendAndRecv),
   366  		)
   367  		require.Equal(t, errs.RetClientNetErr, errs.Code(err))
   368  		require.Contains(t, errs.Msg(err), "udp client transport WriteTo")
   369  	})
   370  
   371  	var err error
   372  	_, err = transport.RoundTrip(context.Background(), encodeLengthDelimited("helloworld"))
   373  	assert.NotNil(t, err)
   374  
   375  	tc := transport.NewClientTransport()
   376  	_, err = tc.RoundTrip(context.Background(), encodeLengthDelimited("helloworld"))
   377  	assert.NotNil(t, err)
   378  
   379  	// Test address invalid.
   380  	_, err = tc.RoundTrip(context.Background(), encodeLengthDelimited("helloworld"),
   381  		transport.WithDialNetwork("udp"),
   382  		transport.WithDialAddress("invalidaddress"),
   383  		transport.WithReqType(transport.SendOnly))
   384  	assert.NotNil(t, err)
   385  
   386  	// Test send only.
   387  	rsp, err := tc.RoundTrip(context.Background(), encodeLengthDelimited("helloworld"),
   388  		transport.WithDialNetwork("udp"),
   389  		transport.WithDialAddress("localhost:9998"),
   390  		transport.WithClientFramerBuilder(fb),
   391  		transport.WithReqType(transport.SendOnly),
   392  		transport.WithConnectionMode(transport.NotConnected))
   393  	assert.NotNil(t, err)
   394  	assert.Equal(t, errs.ErrClientNoResponse, err)
   395  	assert.Nil(t, rsp)
   396  
   397  	// Test multiplexed send only.
   398  	ctx, msg := codec.WithNewMessage(context.Background())
   399  	rsp, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   400  		transport.WithDialNetwork("udp"),
   401  		transport.WithMultiplexed(true),
   402  		transport.WithDialAddress("localhost:9998"),
   403  		transport.WithReqType(transport.SendOnly),
   404  		transport.WithClientFramerBuilder(fb),
   405  		transport.WithMsg(msg),
   406  	)
   407  	assert.NotNil(t, err)
   408  	assert.Equal(t, errs.ErrClientNoResponse, err)
   409  	assert.Nil(t, rsp)
   410  
   411  	// Test context canceled.
   412  	ctx, cancel := context.WithCancel(context.Background())
   413  	cancel()
   414  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   415  		transport.WithDialNetwork("udp"),
   416  		transport.WithClientFramerBuilder(fb),
   417  		transport.WithDialAddress("localhost:9998"))
   418  	assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientCanceled))
   419  
   420  	// Test context timeout.
   421  	ctx, timeout := context.WithTimeout(context.Background(), time.Millisecond)
   422  	defer timeout()
   423  	<-ctx.Done()
   424  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   425  		transport.WithDialNetwork("udp"),
   426  		transport.WithClientFramerBuilder(fb),
   427  		transport.WithDialAddress("localhost:9998"))
   428  	assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientTimeout))
   429  
   430  	// Test roundtrip.
   431  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   432  	defer cancel()
   433  	rsp, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   434  		transport.WithDialNetwork("udp"),
   435  		transport.WithDialAddress("localhost:9998"),
   436  		transport.WithConnectionMode(transport.NotConnected),
   437  		transport.WithClientFramerBuilder(fb),
   438  	)
   439  	assert.NotNil(t, rsp)
   440  	assert.Nil(t, err)
   441  
   442  	// Test setting RemoteAddr of UDP RoundTrip.
   443  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   444  	defer cancel()
   445  	ctx, msg = codec.WithNewMessage(ctx)
   446  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   447  		transport.WithDialNetwork("udp"),
   448  		transport.WithDialAddress("127.0.0.1:9998"),
   449  		transport.WithConnectionMode(transport.Connected),
   450  		transport.WithClientFramerBuilder(fb),
   451  	)
   452  	assert.Nil(t, err)
   453  	assert.Equal(t, "127.0.0.1:9998", msg.RemoteAddr().String())
   454  
   455  	// Test local addr.
   456  	localAddr := "127.0.0.1:"
   457  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   458  	defer cancel()
   459  	ctx, msg = codec.WithNewMessage(ctx)
   460  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   461  		transport.WithDialNetwork("udp"),
   462  		transport.WithDialAddress("127.0.0.1:9998"),
   463  		transport.WithConnectionMode(transport.Connected),
   464  		transport.WithClientFramerBuilder(fb),
   465  		transport.WithLocalAddr(localAddr),
   466  	)
   467  	assert.Nil(t, err)
   468  	assert.Equal(t, "127.0.0.1", msg.LocalAddr().(*net.UDPAddr).IP.String())
   469  
   470  	// Test local addr error.
   471  	localAddr = "invalid address"
   472  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   473  	defer cancel()
   474  	ctx, msg = codec.WithNewMessage(ctx)
   475  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   476  		transport.WithDialNetwork("udp"),
   477  		transport.WithDialAddress("127.0.0.1:9998"),
   478  		transport.WithConnectionMode(transport.Connected),
   479  		transport.WithClientFramerBuilder(fb),
   480  		transport.WithLocalAddr(localAddr),
   481  	)
   482  	assert.NotNil(t, err)
   483  	assert.Nil(t, msg.LocalAddr())
   484  
   485  	// Test readframer error.
   486  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   487  	defer cancel()
   488  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   489  		transport.WithDialNetwork("udp"),
   490  		transport.WithDialAddress("127.0.0.1:9998"),
   491  		transport.WithConnectionMode(transport.Connected),
   492  		transport.WithClientFramerBuilder(&lengthDelimitedBuilder{
   493  			readError: true,
   494  		}),
   495  	)
   496  	assert.Contains(t, err.Error(), readFrameError.Error())
   497  
   498  	// Test readframe bytes remaining error.
   499  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   500  	defer cancel()
   501  	_, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"),
   502  		transport.WithDialNetwork("udp"),
   503  		transport.WithDialAddress("127.0.0.1:9998"),
   504  		transport.WithConnectionMode(transport.Connected),
   505  		transport.WithClientFramerBuilder(&lengthDelimitedBuilder{
   506  			remainingBytes: true,
   507  		}),
   508  	)
   509  	assert.Contains(t, err.Error(), remainingBytesError.Error())
   510  }
   511  
   512  func mustListenUDP(t *testing.T) net.PacketConn {
   513  	c, err := net.ListenPacket("udp", "127.0.0.1:0")
   514  	if err != nil {
   515  		t.Fatal(err)
   516  	}
   517  	return c
   518  }
   519  
   520  // Frame a stream of bytes based on a length prefix
   521  // +------------+--------------------------------+
   522  // | len: uint8 |          frame payload         |
   523  // +------------+--------------------------------+
   524  type lengthDelimitedBuilder struct {
   525  	remainingBytes bool
   526  	readError      bool
   527  }
   528  
   529  func (fb *lengthDelimitedBuilder) New(reader io.Reader) codec.Framer {
   530  	return &lengthDelimited{
   531  		readError:      fb.readError,
   532  		remainingBytes: fb.remainingBytes,
   533  		reader:         reader,
   534  	}
   535  }
   536  
   537  func (fb *lengthDelimitedBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) {
   538  	buf, err = fb.New(rc).ReadFrame()
   539  	if err != nil {
   540  		return 0, nil, err
   541  	}
   542  	return 0, buf, nil
   543  }
   544  
   545  type lengthDelimited struct {
   546  	reader         io.Reader
   547  	readError      bool
   548  	remainingBytes bool
   549  }
   550  
   551  func encodeLengthDelimited(data string) []byte {
   552  	result := []byte{byte(len(data))}
   553  	result = append(result, []byte(data)...)
   554  	return result
   555  }
   556  
   557  var (
   558  	readFrameError      = errors.New("read framer error")
   559  	remainingBytesError = fmt.Errorf(
   560  		"packet data is not drained, the remaining %d will be dropped",
   561  		remainingBytes,
   562  	)
   563  	remainingBytes = 1
   564  )
   565  
   566  func (f *lengthDelimited) ReadFrame() ([]byte, error) {
   567  	if f.readError {
   568  		return nil, readFrameError
   569  	}
   570  	head := make([]byte, 1)
   571  	if _, err := io.ReadFull(f.reader, head); err != nil {
   572  		return nil, err
   573  	}
   574  	bodyLen := int(head[0])
   575  	if f.remainingBytes {
   576  		bodyLen = bodyLen - remainingBytes
   577  	}
   578  	body := make([]byte, bodyLen)
   579  	if _, err := io.ReadFull(f.reader, body); err != nil {
   580  		return nil, err
   581  	}
   582  	return body, nil
   583  }
   584  
   585  type lengthDelimitedHandler struct{}
   586  
   587  func (h *lengthDelimitedHandler) Handle(ctx context.Context, req []byte) ([]byte, error) {
   588  	rsp := make([]byte, len(req)+1)
   589  	rsp[0] = byte(len(req))
   590  	copy(rsp[1:], req)
   591  	return rsp, nil
   592  }
   593  
   594  func TestClientTransport_MultiplexedErr(t *testing.T) {
   595  	listener, err := net.Listen("tcp", ":")
   596  	require.Nil(t, err)
   597  	defer listener.Close()
   598  	go func() {
   599  		transport.ListenAndServe(
   600  			transport.WithListener(listener),
   601  			transport.WithHandler(&echoHandler{}),
   602  			transport.WithServerFramerBuilder(transport.GetFramerBuilder("trpc")),
   603  		)
   604  	}()
   605  	time.Sleep(20 * time.Millisecond)
   606  
   607  	tc := transport.NewClientTransport()
   608  	fb := &trpc.FramerBuilder{}
   609  
   610  	// Test multiplexed context timeout.
   611  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   612  	defer cancel()
   613  	ctx, msg := codec.WithNewMessage(ctx)
   614  	_, err = tc.RoundTrip(ctx, []byte("helloworld"),
   615  		transport.WithDialNetwork(listener.Addr().Network()),
   616  		transport.WithDialAddress(listener.Addr().String()),
   617  		transport.WithMultiplexed(true),
   618  		transport.WithClientFramerBuilder(fb),
   619  		transport.WithMsg(msg),
   620  	)
   621  	assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientTimeout))
   622  
   623  	// Test multiplexed context canceled.
   624  	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
   625  	go func() {
   626  		time.Sleep(time.Millisecond * 200)
   627  		cancel()
   628  	}()
   629  	_, err = tc.RoundTrip(ctx, []byte("helloworld"),
   630  		transport.WithDialNetwork(listener.Addr().Network()),
   631  		transport.WithDialAddress(listener.Addr().String()),
   632  		transport.WithMultiplexed(true),
   633  		transport.WithClientFramerBuilder(fb),
   634  		transport.WithMsg(msg),
   635  	)
   636  	assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientCanceled))
   637  }
   638  
   639  func TestClientTransport_RoundTrip_PreConnected(t *testing.T) {
   640  	go func() {
   641  		err := transport.ListenAndServe(
   642  			transport.WithListenNetwork("udp"),
   643  			transport.WithListenAddress("localhost:9999"),
   644  			transport.WithHandler(&echoHandler{}),
   645  			transport.WithServerFramerBuilder(transport.GetFramerBuilder("trpc")),
   646  		)
   647  		assert.Nil(t, err)
   648  	}()
   649  	time.Sleep(20 * time.Millisecond)
   650  
   651  	var err error
   652  	_, err = transport.RoundTrip(context.Background(), []byte("helloworld"))
   653  	assert.NotNil(t, err)
   654  
   655  	tc := transport.NewClientTransport()
   656  
   657  	// Test connected UDPConn.
   658  	rsp, err := tc.RoundTrip(context.Background(), []byte("helloworld"),
   659  		transport.WithDialNetwork("udp"),
   660  		transport.WithDialAddress("localhost:9999"),
   661  		transport.WithDialPassword("passwd"),
   662  		transport.WithClientFramerBuilder(&trpc.FramerBuilder{}),
   663  		transport.WithReqType(transport.SendOnly),
   664  		transport.WithConnectionMode(transport.Connected))
   665  	assert.NotNil(t, err)
   666  	assert.Equal(t, errs.ErrClientNoResponse, err)
   667  	assert.Nil(t, rsp)
   668  
   669  	// Test context done.
   670  	ctx, cancel := context.WithCancel(context.Background())
   671  	cancel()
   672  	_, err = tc.RoundTrip(ctx, []byte("helloworld"),
   673  		transport.WithDialNetwork("udp"),
   674  		transport.WithDialAddress("localhost:9999"),
   675  		transport.WithConnectionMode(transport.Connected))
   676  	assert.NotNil(t, err)
   677  
   678  	// Test RoundTrip.
   679  	ctx, cancel = context.WithTimeout(ctx, time.Second)
   680  	defer cancel()
   681  	rsp, err = tc.RoundTrip(ctx, []byte("helloworld"),
   682  		transport.WithDialNetwork("udp"),
   683  		transport.WithDialAddress("localhost:9999"),
   684  		transport.WithConnectionMode(transport.Connected))
   685  	assert.NotNil(t, err)
   686  	assert.Nil(t, rsp)
   687  }
   688  
   689  func TestOptions(t *testing.T) {
   690  
   691  	opts := &transport.RoundTripOptions{}
   692  
   693  	o := transport.WithDialTLS("client.cert", "client.key", "ca.pem", "servername")
   694  	o(opts)
   695  	assert.Equal(t, "client.cert", opts.TLSCertFile)
   696  	assert.Equal(t, "client.key", opts.TLSKeyFile)
   697  	assert.Equal(t, "ca.pem", opts.CACertFile)
   698  	assert.Equal(t, "servername", opts.TLSServerName)
   699  
   700  	o = transport.WithDisableConnectionPool()
   701  	o(opts)
   702  
   703  	assert.True(t, opts.DisableConnectionPool)
   704  }
   705  
   706  // TestWithMultiplexedPool tests connection pool multiplexing.
   707  func TestWithMultiplexedPool(t *testing.T) {
   708  	opts := &transport.RoundTripOptions{}
   709  	m := multiplexed.New(multiplexed.WithConnectNumber(10))
   710  	o := transport.WithMultiplexedPool(m)
   711  	o(opts)
   712  	assert.True(t, opts.EnableMultiplexed)
   713  	assert.Equal(t, opts.Multiplexed, m)
   714  }
   715  
   716  // TestUDPTransportFramerBuilderErr tests nil FramerBuilder error.
   717  func TestUDPTransportFramerBuilderErr(t *testing.T) {
   718  	opts := []transport.RoundTripOption{
   719  		transport.WithDialNetwork("udp"),
   720  	}
   721  	ts := transport.NewClientTransport()
   722  	_, err := ts.RoundTrip(context.Background(), nil, opts...)
   723  	assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientConnectFail))
   724  }
   725  
   726  // TestWithLocalAddr tests local addr.
   727  func TestWithLocalAddr(t *testing.T) {
   728  	opts := &transport.RoundTripOptions{}
   729  	localAddr := "127.0.0.1:8080"
   730  	o := transport.WithLocalAddr(localAddr)
   731  	o(opts)
   732  	assert.Equal(t, opts.LocalAddr, localAddr)
   733  }
   734  
   735  func TestWithDialTimeout(t *testing.T) {
   736  	opts := &transport.RoundTripOptions{}
   737  	timeout := time.Second
   738  	o := transport.WithDialTimeout(timeout)
   739  	o(opts)
   740  	assert.Equal(t, opts.DialTimeout, timeout)
   741  }
   742  
   743  func TestWithProtocol(t *testing.T) {
   744  	opts := &transport.RoundTripOptions{}
   745  	protocol := "xxx-protocol"
   746  	o := transport.WithProtocol(protocol)
   747  	o(opts)
   748  	assert.Equal(t, protocol, opts.Protocol)
   749  }
   750  
   751  func TestWithDisableEncodeTransInfoBase64(t *testing.T) {
   752  	opts := &transport.ClientTransportOptions{}
   753  	transport.WithDisableEncodeTransInfoBase64()(opts)
   754  	assert.Equal(t, true, opts.DisableHTTPEncodeTransInfoBase64)
   755  }