trpc.group/trpc-go/trpc-go@v1.0.3/stream/client_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
// Package stream_test contains unit tests for package stream.
    15  package stream_test
    16  
    17  import (
    18  	"context"
    19  	"crypto/rand"
    20  	"encoding/binary"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"testing"
    25  	"time"
    26  
    27  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    28  
    29  	trpc "trpc.group/trpc-go/trpc-go"
    30  	"trpc.group/trpc-go/trpc-go/client"
    31  	"trpc.group/trpc-go/trpc-go/codec"
    32  	"trpc.group/trpc-go/trpc-go/errs"
    33  	"trpc.group/trpc-go/trpc-go/server"
    34  	"trpc.group/trpc-go/trpc-go/stream"
    35  	"trpc.group/trpc-go/trpc-go/transport"
    36  
    37  	"github.com/stretchr/testify/assert"
    38  )
    39  
    40  var ctx = context.Background()
    41  
// fakeTransport is a hand-rolled stream transport used by these tests.
// Recv pops the next scripted response from expectChan, while send and close
// are optional hooks consulted by Send and Close.
type fakeTransport struct {
	expectChan chan recvExpect // scripted Recv results, one consumed per Recv call
	send       func() error    // optional Send hook; nil means Send succeeds
	close      func()          // optional Close hook; nil means Close does nothing
}
    47  
    48  // RoundTrip Mock RoundTrip method.
    49  func (c *fakeTransport) RoundTrip(ctx context.Context, req []byte,
    50  	roundTripOpts ...transport.RoundTripOption) (rsp []byte, err error) {
    51  	return nil, nil
    52  }
    53  
    54  // Send Mock Send method.
    55  func (c *fakeTransport) Send(ctx context.Context, req []byte, opts ...transport.RoundTripOption) error {
    56  	if c.send != nil {
    57  		return c.send()
    58  	}
    59  	return nil
    60  }
    61  
    62  type recvExpect func(*trpc.FrameHead, codec.Msg) ([]byte, error)
    63  
    64  // Recv Mock recv method.
    65  func (c *fakeTransport) Recv(ctx context.Context, opts ...transport.RoundTripOption) ([]byte, error) {
    66  	msg := codec.Message(ctx)
    67  	var fh *trpc.FrameHead
    68  	fh, ok := msg.FrameHead().(*trpc.FrameHead)
    69  	if !ok {
    70  		fh = &trpc.FrameHead{}
    71  		msg.WithFrameHead(fh)
    72  	}
    73  	f := <-c.expectChan
    74  	return f(fh, msg)
    75  }
    76  
    77  // Init Mock Init method.
    78  func (c *fakeTransport) Init(ctx context.Context, opts ...transport.RoundTripOption) error {
    79  	return nil
    80  }
    81  
    82  // Close Mock Close method.
    83  func (c *fakeTransport) Close(ctx context.Context) {
    84  	if c.close != nil {
    85  		c.close()
    86  	}
    87  }
    88  
    89  type fakeCodec struct {
    90  }
    91  
    92  // Encode Mock codec Encode method.
    93  func (c *fakeCodec) Encode(msg codec.Msg, reqBody []byte) (reqBuf []byte, err error) {
    94  	if string(reqBody) == "failbody" {
    95  		return nil, errors.New("encode fail")
    96  	}
    97  	return reqBody, nil
    98  }
    99  
   100  // Decode Mock codec Decode method.
   101  func (c *fakeCodec) Decode(msg codec.Msg, rspBuf []byte) (rspBody []byte, err error) {
   102  	if string(rspBuf) == "businessfail" {
   103  		return nil, errors.New("businessfail")
   104  	}
   105  	if string(rspBuf) == "msgfail" {
   106  		msg.WithClientRspErr(errors.New("msgfail"))
   107  		return nil, nil
   108  	}
   109  	return rspBuf, nil
   110  }
   111  
// TestMain replaces transport.DefaultServerTransport with a
// fakeServerTransport before running the tests in this package.
// NOTE(review): the result of m.Run is not passed to os.Exit explicitly; this
// relies on Go >= 1.15, where returning from TestMain does that automatically.
func TestMain(m *testing.M) {
	transport.DefaultServerTransport = &fakeServerTransport{}
	m.Run()
}
   117  
   118  // TestClient tests the streaming client.
   119  func TestClient(t *testing.T) {
   120  	var reqBody = &codec.Body{Data: []byte("body")}
   121  	var rspBody = &codec.Body{}
   122  	codec.RegisterSerializer(0, &codec.NoopSerialization{})
   123  	codec.Register("fake", nil, &fakeCodec{})
   124  	codec.Register("fake-nil", nil, nil)
   125  
   126  	cli := stream.NewStreamClient()
   127  	assert.Equal(t, cli, stream.DefaultStreamClient)
   128  
   129  	ctx := context.Background()
   130  	var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)}
   131  	transport.DefaultClientTransport = ft
   132  
   133  	f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   134  		return nil, nil
   135  	}
   136  	ft.expectChan <- f
   137  	cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   138  		client.WithTarget("ip://127.0.0.1:8000"),
   139  		client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   140  		client.WithCurrentCompressType(codec.CompressTypeNoop),
   141  		client.WithStreamTransport(ft))
   142  	assert.NotNil(t, cs)
   143  	assert.Nil(t, err)
   144  
   145  	// Test Context.
   146  	resultCtx := cs.Context()
   147  	assert.NotNil(t, resultCtx)
   148  	// Test to send data normally.
   149  	err = cs.SendMsg(reqBody)
   150  	assert.Nil(t, err)
   151  
   152  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   153  		fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)
   154  		return []byte("body"), nil
   155  	}
   156  	ft.expectChan <- f
   157  
   158  	// Test to receive data normally.
   159  	err = cs.RecvMsg(rspBody)
   160  	assert.Nil(t, err)
   161  	assert.Equal(t, rspBody.Data, []byte("body"))
   162  
   163  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   164  		fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE)
   165  		return nil, nil
   166  	}
   167  	ft.expectChan <- f
   168  
   169  	// Test received io.EOF.
   170  	rspBody = &codec.Body{}
   171  	err = cs.RecvMsg(rspBody)
   172  	assert.Equal(t, io.EOF, err)
   173  	assert.Nil(t, rspBody.Data)
   174  
   175  	err = cs.CloseSend()
   176  	assert.Nil(t, err)
   177  
   178  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   179  		return nil, nil
   180  	}
   181  	ft.expectChan <- f
   182  	cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   183  		client.WithTarget("ip://127.0.0.1:8000"),
   184  		client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop),
   185  		client.WithTransport(ft),
   186  		client.WithStreamTransport(ft))
   187  	assert.NotNil(t, cs)
   188  	assert.Nil(t, err)
   189  
   190  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   191  		msg.WithClientRspErr(errors.New("close type is reset"))
   192  		fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE)
   193  		return nil, nil
   194  	}
   195  	ft.expectChan <- f
   196  	// test reset.
   197  	rspBody = &codec.Body{}
   198  	err = cs.RecvMsg(rspBody)
   199  	assert.NotNil(t, err)
   200  	assert.Nil(t, rspBody.Data)
   201  	assert.Contains(t, err.Error(), "close type is reset")
   202  
   203  }
   204  
   205  // TestClientFlowControl tests the streaming client.
   206  func TestClientFlowControl(t *testing.T) {
   207  	var reqBody = &codec.Body{Data: []byte("body")}
   208  
   209  	var rspBody = &codec.Body{}
   210  	codec.RegisterSerializer(0, &codec.NoopSerialization{})
   211  	codec.Register("fake", nil, &fakeCodec{})
   212  	codec.Register("fake-nil", nil, nil)
   213  
   214  	cli := stream.NewStreamClient()
   215  	assert.Equal(t, cli, stream.DefaultStreamClient)
   216  
   217  	ctx := context.Background()
   218  	var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)}
   219  	transport.DefaultClientTransport = ft
   220  
   221  	f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   222  		fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)
   223  		msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{InitWindowSize: 2000})
   224  		return nil, nil
   225  	}
   226  	ft.expectChan <- f
   227  
   228  	cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   229  		client.WithTarget("ip://127.0.0.1:8000"),
   230  		client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   231  		client.WithCurrentCompressType(codec.CompressTypeNoop),
   232  		client.WithTransport(ft),
   233  		client.WithStreamTransport(ft))
   234  	assert.NotNil(t, cs)
   235  	assert.Nil(t, err)
   236  
   237  	// Test Context.
   238  	resultCtx := cs.Context()
   239  	assert.NotNil(t, resultCtx)
   240  	// Test to send data normally.
   241  	err = cs.SendMsg(reqBody)
   242  	assert.Nil(t, err)
   243  
   244  	for i := 0; i < 20000; i++ {
   245  		f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   246  			fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)
   247  			return []byte("body"), nil
   248  		}
   249  		ft.expectChan <- f
   250  		// Test to receive data normally.
   251  		err = cs.RecvMsg(rspBody)
   252  		assert.Nil(t, err)
   253  		assert.Equal(t, rspBody.Data, []byte("body"))
   254  	}
   255  
   256  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   257  		fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE)
   258  		return nil, nil
   259  	}
   260  	ft.expectChan <- f
   261  
   262  	// Test received io.EOF.
   263  	rspBody = &codec.Body{}
   264  	err = cs.RecvMsg(rspBody)
   265  	assert.Equal(t, io.EOF, err)
   266  	assert.Nil(t, rspBody.Data)
   267  }
   268  
   269  // TestClientError tests the case of streaming Client error handling.
   270  func TestClientError(t *testing.T) {
   271  	var rspBody = &codec.Body{}
   272  	codec.RegisterSerializer(0, &codec.NoopSerialization{})
   273  	codec.Register("fake", nil, &fakeCodec{})
   274  	codec.Register("fake-nil", nil, nil)
   275  
   276  	cli := stream.NewStreamClient()
   277  	assert.Equal(t, cli, stream.DefaultStreamClient)
   278  
   279  	var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)}
   280  	transport.DefaultClientTransport = ft
   281  	f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   282  		return nil, errors.New("init error")
   283  	}
   284  	ft.expectChan <- f
   285  
   286  	// Test for init transport errors.
   287  	cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   288  		client.WithTarget("ip://127.0.0.1:8000"),
   289  		client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop),
   290  		client.WithTransport(ft),
   291  		client.WithStreamTransport(ft))
   292  	assert.Nil(t, cs)
   293  	assert.NotNil(t, err)
   294  
   295  	// test Init error.
   296  	cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   297  		client.WithTarget("ip://127.0.0.1:8000"),
   298  		client.WithProtocol("fake-nil"), client.WithSerializationType(codec.SerializationTypeNoop),
   299  		client.WithTransport(ft),
   300  		client.WithStreamTransport(ft))
   301  	assert.Nil(t, cs)
   302  	assert.NotNil(t, err)
   303  
   304  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   305  		return nil, nil
   306  	}
   307  	ft.expectChan <- f
   308  	cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   309  		client.WithTarget("ip://127.0.0.1:8000"),
   310  		client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop),
   311  		client.WithTransport(ft),
   312  		client.WithStreamTransport(ft), client.WithClientStreamQueueSize(100000))
   313  	assert.NotNil(t, cs)
   314  	assert.Nil(t, err)
   315  	// test recv data error.
   316  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   317  		return nil, errors.New("recv data error")
   318  	}
   319  	ft.expectChan <- f
   320  	err = cs.RecvMsg(rspBody)
   321  	assert.NotNil(t, err)
   322  	assert.Nil(t, rspBody.Data)
   323  
   324  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   325  		msg.WithClientRspErr(errors.New("test init with clientRspError"))
   326  		return nil, nil
   327  	}
   328  	ft.expectChan <- f
   329  	cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   330  		client.WithTarget("ip://127.0.0.1:8000"),
   331  		client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop),
   332  		client.WithStreamTransport(ft), client.WithClientStreamQueueSize(100000))
   333  	assert.Nil(t, cs)
   334  	assert.NotNil(t, err)
   335  
   336  	// receive unexpected stream frame type
   337  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   338  		msg.WithStreamFrame(int(1))
   339  		return nil, nil
   340  	}
   341  	ft.expectChan <- f
   342  	cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   343  		client.WithTarget("ip://127.0.0.1:8000"),
   344  		client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop),
   345  		client.WithStreamTransport(ft), client.WithClientStreamQueueSize(100000))
   346  	assert.Nil(t, cs)
   347  	assert.Contains(t, err.Error(), "unexpected frame type")
   348  }
   349  
   350  // TestClientContext tests the case of streaming client context cancel and timeout.
   351  func TestClientContext(t *testing.T) {
   352  
   353  	var rspBody = &codec.Body{}
   354  	codec.RegisterSerializer(0, &codec.NoopSerialization{})
   355  	codec.Register("fake", nil, &fakeCodec{})
   356  	codec.Register("fake-nil", nil, nil)
   357  
   358  	cli := stream.NewStreamClient()
   359  	assert.Equal(t, cli, stream.DefaultStreamClient)
   360  
   361  	var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)}
   362  	transport.DefaultClientTransport = ft
   363  	// test context cancel situation.
   364  	f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   365  		return nil, nil
   366  	}
   367  	ft.expectChan <- f
   368  	ctx, cancel := context.WithCancel(context.Background())
   369  	cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   370  		client.WithTarget("ip://127.0.0.1:8000"),
   371  		client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   372  		client.WithCurrentCompressType(codec.CompressTypeNoop),
   373  		client.WithTransport(ft),
   374  		client.WithStreamTransport(ft))
   375  	assert.NotNil(t, cs)
   376  	assert.Nil(t, err)
   377  	cancel()
   378  	err = cs.RecvMsg(rspBody)
   379  	assert.Contains(t, err.Error(), "tcp client stream canceled before recv")
   380  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   381  		return nil, errors.New("close it")
   382  	}
   383  	ft.expectChan <- f
   384  	time.Sleep(5 * time.Millisecond)
   385  	// test context timeout situation.
   386  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   387  		return nil, nil
   388  	}
   389  	ft.expectChan <- f
   390  
   391  	timeoutCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
   392  	defer cancel()
   393  	cs, err = cli.NewStream(timeoutCtx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   394  		client.WithTarget("ip://127.0.0.1:8000"),
   395  		client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   396  		client.WithCurrentCompressType(codec.CompressTypeNoop),
   397  		client.WithTransport(ft),
   398  		client.WithStreamTransport(ft))
   399  	assert.NotNil(t, cs)
   400  	assert.Nil(t, err)
   401  
   402  	err = cs.RecvMsg(rspBody)
   403  	assert.Contains(t, err.Error(), "tcp client stream canceled timeout before recv")
   404  }
   405  
   406  func clientFilterAdd1(ctx context.Context, desc *client.ClientStreamDesc, newStream client.Streamer) (client.ClientStream, error) {
   407  	var msg codec.Msg
   408  	ctx, msg = codec.EnsureMessage(ctx)
   409  	meta := msg.ClientMetaData()
   410  	if meta == nil {
   411  		meta = codec.MetaData{}
   412  	}
   413  	meta[testKey1] = []byte(testData[testKey1])
   414  	msg.WithClientMetaData(meta)
   415  
   416  	s, err := newStream(ctx, desc)
   417  	if err != nil {
   418  		return nil, err
   419  	}
   420  
   421  	return newWrappedClientStream(s), nil
   422  }
   423  
   424  func clientFilterAdd2(ctx context.Context, desc *client.ClientStreamDesc, newStream client.Streamer) (client.ClientStream, error) {
   425  	var msg codec.Msg
   426  	ctx, msg = codec.EnsureMessage(ctx)
   427  	meta := msg.ClientMetaData()
   428  	if meta == nil {
   429  		meta = codec.MetaData{}
   430  	}
   431  	meta[testKey2] = []byte(testData[testKey2])
   432  	msg.WithClientMetaData(meta)
   433  
   434  	s, err := newStream(ctx, desc)
   435  	if err != nil {
   436  		return nil, err
   437  	}
   438  	return newWrappedClientStream(s), nil
   439  }
   440  
   441  type wrappedClientStream struct {
   442  	client.ClientStream
   443  }
   444  
   445  func newWrappedClientStream(s client.ClientStream) client.ClientStream {
   446  	return &wrappedClientStream{s}
   447  }
   448  
   449  func (w *wrappedClientStream) RecvMsg(m interface{}) error {
   450  	err := w.ClientStream.RecvMsg(m)
   451  	num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8])
   452  	binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1)
   453  	return err
   454  }
   455  
   456  func (w *wrappedClientStream) SendMsg(m interface{}) error {
   457  	num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8])
   458  	binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1)
   459  	return w.ClientStream.SendMsg(m)
   460  }
   461  
   462  func TestClientStreamClientFilters(t *testing.T) {
   463  	rawData := make([]byte, 1024)
   464  	rand.Read(rawData)
   465  	var beginNum uint64 = 100
   466  
   467  	counts := 1000
   468  	svrOpts := []server.Option{
   469  		server.WithAddress("127.0.0.1:30211"),
   470  		server.WithStreamFilters(serverFilterAdd1, serverFilterAdd2),
   471  	}
   472  	handle := func(s server.Stream) error {
   473  		var req *codec.Body
   474  
   475  		// server receives data.
   476  		for i := 0; i < counts; i++ {
   477  			req = getBytes(1024)
   478  			err := s.RecvMsg(req)
   479  			assert.Nil(t, err)
   480  			resultNum := binary.LittleEndian.Uint64(req.Data[:8])
   481  
   482  			// After the client SendMsg + server RecvMsg, two Filter, Num+4.
   483  			assert.Equal(t, beginNum+4, resultNum)
   484  			assert.Equal(t, rawData[8:], req.Data[8:])
   485  		}
   486  		err := s.RecvMsg(getBytes(1024))
   487  		assert.Equal(t, io.EOF, err)
   488  
   489  		// server sends data.
   490  		rsp := getBytes(1024)
   491  		for i := 0; i < counts; i++ {
   492  			copy(rsp.Data, req.Data)
   493  			err = s.SendMsg(rsp)
   494  			assert.Nil(t, err)
   495  		}
   496  		return nil
   497  	}
   498  	svr := startStreamServer(handle, svrOpts)
   499  	defer closeStreamServer(svr)
   500  
   501  	cliOpts := []client.Option{
   502  		client.WithTarget("ip://127.0.0.1:30211"),
   503  		client.WithStreamFilters(clientFilterAdd1, clientFilterAdd2),
   504  	}
   505  	cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts)
   506  	assert.Nil(t, err)
   507  
   508  	// client sends data.
   509  	for i := 0; i < counts; i++ {
   510  		req := getBytes(1024)
   511  		copy(req.Data, rawData)
   512  		binary.LittleEndian.PutUint64(req.Data[:8], beginNum)
   513  
   514  		err = cliStream.SendMsg(req)
   515  		assert.Nil(t, err)
   516  	}
   517  	err = cliStream.CloseSend()
   518  	assert.Nil(t, err)
   519  
   520  	// client receives data.
   521  	for i := 0; i < counts; i++ {
   522  		rsp := getBytes(1024)
   523  		err = cliStream.RecvMsg(rsp)
   524  		assert.Nil(t, err)
   525  
   526  		// After the client once SendMsg, once RecvMsg, two Filter, Num+4.
   527  		resultNum := binary.LittleEndian.Uint64(rsp.Data[:8])
   528  		assert.Equal(t, beginNum+8, resultNum)
   529  		assert.Equal(t, rawData[8:], rsp.Data[8:])
   530  	}
   531  	rsp := getBytes(1024)
   532  	err = cliStream.RecvMsg(rsp)
   533  	assert.Equal(t, io.EOF, err)
   534  }
   535  
   536  func TestClientStreamFlowControlStop(t *testing.T) {
   537  	windows := 102400
   538  	dataLen := 1024
   539  	maxSends := windows / dataLen
   540  	svrOpts := []server.Option{
   541  		server.WithAddress("127.0.0.1:30211"),
   542  		server.WithMaxWindowSize(uint32(windows)),
   543  	}
   544  	handle := func(s server.Stream) error {
   545  		time.Sleep(time.Hour)
   546  		return nil
   547  	}
   548  	svr := startStreamServer(handle, svrOpts)
   549  	defer closeStreamServer(svr)
   550  
   551  	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(200*time.Millisecond))
   552  	defer cancel()
   553  	cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30211")}
   554  	cliStream, err := getClientStream(ctx, bidiDesc, cliOpts)
   555  	assert.Nil(t, err)
   556  
   557  	req := getBytes(dataLen)
   558  	rand.Read(req.Data)
   559  
   560  	for i := 0; i < maxSends; i++ {
   561  		err = cliStream.SendMsg(req)
   562  		assert.Nil(t, err)
   563  	}
   564  	err = cliStream.SendMsg(req)
   565  	assert.Equal(t, errors.New("stream is already closed"), err)
   566  }
   567  
   568  func TestServerStreamFlowControlStop(t *testing.T) {
   569  	windows := 102400
   570  	dataLen := 1024
   571  	maxSends := windows / dataLen
   572  	waitCh := make(chan struct{}, 1)
   573  	svrOpts := []server.Option{server.WithAddress("127.0.0.1:30211")}
   574  	handle := func(s server.Stream) error {
   575  		rsp := getBytes(dataLen)
   576  		rand.Read(rsp.Data)
   577  		for i := 0; i < maxSends; i++ {
   578  			err := s.SendMsg(rsp)
   579  			assert.Nil(t, err)
   580  		}
   581  
   582  		finish := make(chan struct{}, 1)
   583  		go func() {
   584  			err := s.SendMsg(rsp)
   585  			assert.Equal(t, errors.New("stream is already closed"), err)
   586  			finish <- struct{}{}
   587  		}()
   588  
   589  		deadline := time.NewTimer(200 * time.Millisecond)
   590  		select {
   591  		case <-deadline.C:
   592  		case <-finish:
   593  			assert.Fail(t, "SendMsg should block")
   594  		}
   595  
   596  		waitCh <- struct{}{}
   597  		return nil
   598  	}
   599  	svr := startStreamServer(handle, svrOpts)
   600  	defer closeStreamServer(svr)
   601  
   602  	cliOpts := []client.Option{
   603  		client.WithTarget("ip://127.0.0.1:30211"),
   604  		client.WithMaxWindowSize(uint32(windows)),
   605  	}
   606  	_, err := getClientStream(context.Background(), bidiDesc, cliOpts)
   607  	assert.Nil(t, err)
   608  	<-waitCh
   609  }
   610  
   611  func TestClientStreamSendRecvNoBlock(t *testing.T) {
   612  	svrOpts := []server.Option{server.WithAddress("127.0.0.1:30210")}
   613  	handle := func(s server.Stream) error {
   614  		// Must sleep, to avoid returning before receiving the first packet from the client,
   615  		// resulting in the processing of the first packet returns an error,
   616  		// losing the chance for the test client to block on the second SendMsg.
   617  		time.Sleep(200 * time.Millisecond)
   618  		return errors.New("test error")
   619  	}
   620  	svr := startStreamServer(handle, svrOpts)
   621  	defer closeStreamServer(svr)
   622  
   623  	cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30210")}
   624  	cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts)
   625  	assert.Nil(t, err)
   626  
   627  	// defaultInitWindowSize = 65535.
   628  	req := getBytes(65535)
   629  	err = cliStream.SendMsg(req)
   630  	assert.Nil(t, err)
   631  
   632  	err = cliStream.SendMsg(req)
   633  	fmt.Println(err)
   634  	assert.NotNil(t, err)
   635  
   636  	rsp := getBytes(1024)
   637  	err = cliStream.RecvMsg(rsp)
   638  	assert.NotNil(t, err)
   639  }
   640  
   641  func TestServerStreamSendRecvNoBlock(t *testing.T) {
   642  	svrOpts := []server.Option{server.WithAddress("127.0.0.1:30210")}
   643  	SendMsgReturn := make(chan struct{}, 1)
   644  	RecvMsgReturn := make(chan struct{}, 1)
   645  	handle := func(s server.Stream) error {
   646  		go func() {
   647  			msg := getBytes(65535)
   648  			s.SendMsg(msg)
   649  			s.SendMsg(msg)
   650  			SendMsgReturn <- struct{}{}
   651  		}()
   652  		go func() {
   653  			msg := getBytes(1024)
   654  			s.RecvMsg(msg)
   655  			s.RecvMsg(msg)
   656  			RecvMsgReturn <- struct{}{}
   657  		}()
   658  		// Must sleep, to avoid the second SendMsg does not enter the waiting window to block.
   659  		time.Sleep(200 * time.Millisecond)
   660  		return nil
   661  	}
   662  	svr := startStreamServer(handle, svrOpts)
   663  	defer closeStreamServer(svr)
   664  
   665  	cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30210")}
   666  	_, err := getClientStream(context.Background(), bidiDesc, cliOpts)
   667  	assert.Nil(t, err)
   668  
   669  	<-SendMsgReturn
   670  	<-RecvMsgReturn
   671  }
   672  
   673  func TestClientStreamReturn(t *testing.T) {
   674  	const (
   675  		invalidCompressType = -1
   676  		dataLen             = 1024
   677  	)
   678  
   679  	svrOpts := []server.Option{
   680  		server.WithAddress("127.0.0.1:30211"),
   681  		server.WithCurrentCompressType(invalidCompressType),
   682  	}
   683  	handle := func(s server.Stream) error {
   684  		req := getBytes(dataLen)
   685  		s.RecvMsg(req)
   686  		rsp := req
   687  		s.SendMsg(rsp)
   688  		return errs.NewFrameError(101, "expected error")
   689  	}
   690  	svr := startStreamServer(handle, svrOpts)
   691  	defer closeStreamServer(svr)
   692  
   693  	cliOpts := []client.Option{
   694  		client.WithTarget("ip://127.0.0.1:30211"),
   695  		client.WithCompressType(invalidCompressType),
   696  	}
   697  
   698  	clientStream, err := getClientStream(context.Background(), clientDesc, cliOpts)
   699  	assert.Nil(t, err)
   700  	err = clientStream.SendMsg(getBytes(dataLen))
   701  	assert.Nil(t, err)
   702  
   703  	rsp := getBytes(dataLen)
   704  	err = clientStream.RecvMsg(rsp)
   705  
   706  	assert.EqualValues(t, int32(101), errs.Code(err.(*errs.Error).Unwrap()))
   707  }
   708  
   709  // TestClientSendFailWhenServerUnavailable test when the client blocks
   710  // on SendMsg because of flow control, if the server is closed, the client
   711  // SendMsg should return.
   712  func TestClientSendFailWhenServerUnavailable(t *testing.T) {
   713  	codec.Register("mock", nil, &fakeCodec{})
   714  	tp := &fakeTransport{expectChan: make(chan recvExpect, 1)}
   715  	tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) {
   716  		return nil, nil
   717  	}
   718  	cs, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "",
   719  		client.WithProtocol("mock"),
   720  		client.WithTarget("ip://127.0.0.1:8000"),
   721  		client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   722  		client.WithStreamTransport(tp),
   723  	)
   724  	assert.Nil(t, err)
   725  	assert.NotNil(t, cs)
   726  	assert.Nil(t, cs.SendMsg(getBytes(65535)))
   727  	tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) {
   728  		return nil, errors.New("server is closed")
   729  	}
   730  	assert.Eventually(
   731  		t,
   732  		func() bool {
   733  			cs.SendMsg(getBytes(65535))
   734  			return true
   735  		},
   736  		time.Second,
   737  		100*time.Millisecond,
   738  	)
   739  }
   740  
   741  // TestClientReceiveErrorWhenServerUnavailable tests that the client receives a non-io.EOF
   742  // error when the server crashes or the connection is closed, otherwise the client would
   743  // mistakenly think that the server closed the stream normally.
   744  func TestClientReceiveErrorWhenServerUnavailable(t *testing.T) {
   745  	codec.Register("mock", nil, &fakeCodec{})
   746  	tp := &fakeTransport{expectChan: make(chan recvExpect, 1)}
   747  	tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) {
   748  		return nil, nil
   749  	}
   750  	cs, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "",
   751  		client.WithProtocol("mock"),
   752  		client.WithTarget("ip://127.0.0.1:8000"),
   753  		client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   754  		client.WithStreamTransport(tp),
   755  	)
   756  	assert.Nil(t, err)
   757  	assert.NotNil(t, cs)
   758  	tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) {
   759  		return nil, io.EOF
   760  	}
   761  	err = cs.RecvMsg(nil)
   762  	assert.NotEqual(t, io.EOF, err)
   763  	assert.ErrorIs(t, err, io.EOF)
   764  }
   765  
   766  func TestClientNewStreamFail(t *testing.T) {
   767  	codec.Register("mock", nil, &fakeCodec{})
   768  	t.Run("Close Transport when Send Fail", func(t *testing.T) {
   769  		var isClosed bool
   770  		tp := &fakeTransport{expectChan: make(chan recvExpect, 1)}
   771  		tp.send = func() error {
   772  			return errors.New("client error")
   773  		}
   774  		tp.close = func() {
   775  			isClosed = true
   776  		}
   777  		_, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "",
   778  			client.WithProtocol("mock"),
   779  			client.WithTarget("ip://127.0.0.1:8000"),
   780  			client.WithStreamTransport(tp),
   781  		)
   782  		assert.NotNil(t, err)
   783  		assert.True(t, isClosed)
   784  	})
   785  	t.Run("Close Transport when Recv Fail", func(t *testing.T) {
   786  		var isClosed bool
   787  		tp := &fakeTransport{expectChan: make(chan recvExpect, 1)}
   788  		tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) {
   789  			m.WithClientRspErr(errors.New("server error"))
   790  			return nil, nil
   791  		}
   792  		tp.close = func() {
   793  			isClosed = true
   794  		}
   795  		_, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "",
   796  			client.WithProtocol("mock"),
   797  			client.WithTarget("ip://127.0.0.1:8000"),
   798  			client.WithStreamTransport(tp),
   799  		)
   800  		assert.NotNil(t, err)
   801  		assert.True(t, isClosed)
   802  	})
   803  }
   804  
   805  func TestClientServerCompress(t *testing.T) {
   806  	var (
   807  		dataLen      = 1024
   808  		compressType = codec.CompressTypeSnappy
   809  	)
   810  	svrOpts := []server.Option{
   811  		server.WithAddress("127.0.0.1:30211"),
   812  	}
   813  	handle := func(s server.Stream) error {
   814  		assert.Equal(t, compressType, codec.Message(s.Context()).CompressType())
   815  		req := getBytes(dataLen)
   816  		s.RecvMsg(req)
   817  		rsp := req
   818  		s.SendMsg(rsp)
   819  		return nil
   820  	}
   821  	svr := startStreamServer(handle, svrOpts)
   822  	defer closeStreamServer(svr)
   823  
   824  	cliOpts := []client.Option{
   825  		client.WithTarget("ip://127.0.0.1:30211"),
   826  		client.WithCompressType(compressType),
   827  	}
   828  
   829  	clientStream, err := getClientStream(context.Background(), clientDesc, cliOpts)
   830  	assert.Nil(t, err)
   831  	req := getBytes(dataLen)
   832  	rand.Read(req.Data)
   833  	err = clientStream.SendMsg(req)
   834  	assert.Nil(t, err)
   835  
   836  	rsp := getBytes(dataLen)
   837  	err = clientStream.RecvMsg(rsp)
   838  	assert.Equal(t, rsp.Data, req.Data)
   839  	assert.Nil(t, err)
   840  }