trpc.group/trpc-go/trpc-go@v1.0.3/transport/client_transport_stream_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package transport_test
    15  
    16  import (
    17  	"context"
    18  	"encoding/binary"
    19  	"encoding/json"
    20  	"fmt"
    21  	"io"
    22  	"sync"
    23  	"sync/atomic"
    24  	"testing"
    25  	"time"
    26  
    27  	"trpc.group/trpc-go/trpc-go/codec"
    28  
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  
    32  	"trpc.group/trpc-go/trpc-go/transport"
    33  
    34  	_ "trpc.group/trpc-go/trpc-go"
    35  )
    36  
    37  // TestClientStreamNetworkError test client decode error.
    38  func TestClientStreamNetworkError(t *testing.T) {
    39  	st := transport.NewServerStreamTransport()
    40  	svrCtx, close := context.WithTimeout(context.Background(), 100*time.Millisecond)
    41  	defer close()
    42  	go func() {
    43  		err := st.ListenAndServe(svrCtx,
    44  			transport.WithListenNetwork("tcp"),
    45  			transport.WithListenAddress(":12017"),
    46  			transport.WithHandler(&echoHandler{}),
    47  			transport.WithServerFramerBuilder(&multiplexedFramerBuilder{}),
    48  		)
    49  		require.Nil(t, err)
    50  	}()
    51  
    52  	roundTripOpts := []transport.RoundTripOption{
    53  		transport.WithDialNetwork("tcp"),
    54  		transport.WithDialAddress(":12017"),
    55  	}
    56  
    57  	time.Sleep(20 * time.Millisecond)
    58  	req := &helloRequest{
    59  		Name: "trpc",
    60  		Msg:  "HelloWorld",
    61  	}
    62  
    63  	data, err := json.Marshal(req)
    64  	require.Nil(t, err)
    65  
    66  	lenData := make([]byte, 4)
    67  	binary.BigEndian.PutUint32(lenData, uint32(len(data)))
    68  
    69  	headData := make([]byte, 8)
    70  	binary.BigEndian.PutUint32(headData[:4], defaultStreamID)
    71  	binary.BigEndian.PutUint32(headData[4:8], uint32(len(data)))
    72  
    73  	ctx := context.Background()
    74  	ctx, msg := codec.WithNewMessage(ctx)
    75  	msg.WithStreamID(100)
    76  
    77  	// test IO EOF.
    78  	ct := transport.NewClientStreamTransport()
    79  	fb := &multiplexedFramerBuilder{}
    80  	fb.SetError(io.EOF)
    81  	roundTripOpts = append(roundTripOpts, transport.WithClientFramerBuilder(fb), transport.WithMsg(msg))
    82  	err = ct.Init(ctx, roundTripOpts...)
    83  	assert.Nil(t, err)
    84  	rsp, err := ct.Recv(ctx)
    85  	assert.Equal(t, io.EOF, err, fmt.Sprintf("current err: %+v", err))
    86  	assert.Nil(t, rsp)
    87  
    88  	// test ctx canceled.
    89  	msg.WithStreamID(101)
    90  	fb = &multiplexedFramerBuilder{}
    91  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
    92  	cancel()
    93  	roundTripOpts = append(roundTripOpts, transport.WithClientFramerBuilder(fb), transport.WithMsg(msg))
    94  	err = ct.Init(ctx, roundTripOpts...)
    95  	assert.NotNil(t, err)
    96  
    97  	// test ctx timeout.
    98  	msg.WithStreamID(102)
    99  	fb = &multiplexedFramerBuilder{}
   100  	ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
   101  	defer cancel()
   102  	<-ctx.Done()
   103  	roundTripOpts = append(roundTripOpts, transport.WithClientFramerBuilder(fb), transport.WithMsg(msg))
   104  	err = ct.Init(ctx, roundTripOpts...)
   105  	assert.NotNil(t, err)
   106  }
   107  
   108  func TestConcurrent(t *testing.T) {
   109  	st := transport.NewServerStreamTransport()
   110  	serverFinish := make(chan int)
   111  	go func() {
   112  		err := st.ListenAndServe(context.Background(),
   113  			transport.WithListenNetwork("tcp,udp"),
   114  			transport.WithListenAddress(":12015"),
   115  			transport.WithHandler(&echoHandler{}),
   116  			transport.WithServerFramerBuilder(&multiplexedFramerBuilder{}),
   117  		)
   118  		require.Nil(t, err)
   119  		serverFinish <- 1
   120  	}()
   121  	<-serverFinish
   122  
   123  	req := &helloRequest{
   124  		Name: "trpc",
   125  		Msg:  "HelloWorld",
   126  	}
   127  	data, err := json.Marshal(req)
   128  	require.Nil(t, err)
   129  	headData := make([]byte, 8) // head = streamID + data length
   130  	binary.BigEndian.PutUint32(headData[4:8], uint32(len(data)))
   131  	reqData := append(headData, data...)
   132  
   133  	ct := transport.NewClientStreamTransport(transport.WithMaxConcurrentStreams(20), transport.WithMaxIdleConnsPerHost(2))
   134  
   135  	// close stream send and receive.
   136  	var wg sync.WaitGroup
   137  	var index uint32
   138  	for i := 0; i < 200; i++ {
   139  		wg.Add(1)
   140  		go func() {
   141  			ctx := context.Background()
   142  			ctx, msg := codec.WithNewMessage(ctx)
   143  			newIndex := atomic.AddUint32(&index, 1)
   144  			streamID := defaultStreamID + newIndex
   145  			msg.WithStreamID(streamID)
   146  			msg.WithRequestID(streamID)
   147  
   148  			err = ct.Init(ctx, transport.WithDialNetwork("tcp"), transport.WithDialAddress(":12015"),
   149  				transport.WithClientFramerBuilder(&multiplexedFramerBuilder{}),
   150  				transport.WithMsg(msg))
   151  			assert.Nil(t, err)
   152  
   153  			copyData := make([]byte, len(reqData))
   154  			copy(copyData, reqData)
   155  			binary.BigEndian.PutUint32(copyData, streamID)
   156  
   157  			err = ct.Send(ctx, copyData)
   158  			assert.Nil(t, err)
   159  			rspData, err := ct.Recv(ctx)
   160  			assert.Nil(t, err)
   161  			assert.Equal(t, copyData, rspData)
   162  			ct.Close(ctx)
   163  			wg.Done()
   164  		}()
   165  		if i%50 == 0 {
   166  			time.Sleep(50 * time.Millisecond)
   167  		}
   168  	}
   169  	wg.Wait()
   170  }
   171  
   172  /* --------------------------------------------------- mock multiplexed framer ---------------------------------------------- */
   173  
   174  type multiplexedFramerBuilder struct {
   175  	errSet bool
   176  	err    error
   177  	safe   bool
   178  }
   179  
   180  func (fb *multiplexedFramerBuilder) SetError(err error) {
   181  	fb.errSet = true
   182  	fb.err = err
   183  }
   184  
   185  func (fb *multiplexedFramerBuilder) ClearError() {
   186  	fb.errSet = false
   187  	fb.err = nil
   188  }
   189  
   190  func (fb *multiplexedFramerBuilder) New(r io.Reader) codec.Framer {
   191  	return &multiplexedFramer{r: r, fb: fb}
   192  }
   193  
   194  func (fb *multiplexedFramerBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) {
   195  	buf, err = fb.New(rc).ReadFrame()
   196  	if err != nil {
   197  		return 0, nil, err
   198  	}
   199  	return binary.BigEndian.Uint32(buf[:4]), buf, nil
   200  }
   201  
   202  type multiplexedFramer struct {
   203  	fb *multiplexedFramerBuilder
   204  	r  io.Reader
   205  }
   206  
   207  func (f *multiplexedFramer) ReadFrame() ([]byte, error) {
   208  	if f.fb.errSet {
   209  		return nil, f.fb.err
   210  	}
   211  	var headData [8]byte
   212  
   213  	_, err := io.ReadFull(f.r, headData[:])
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  
   218  	length := binary.BigEndian.Uint32(headData[4:])
   219  
   220  	msg := make([]byte, len(headData)+int(length))
   221  	copy(msg, headData[:])
   222  
   223  	_, err = io.ReadFull(f.r, msg[len(headData):])
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  	return msg, nil
   228  }
   229  
   230  func (f *multiplexedFramer) IsSafe() bool {
   231  	return f.fb.safe
   232  }