trpc.group/trpc-go/trpc-go@v1.0.2/stream/client_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  // Package stream_test Unit test for package stream.
    15  package stream_test
    16  
    17  import (
    18  	"context"
    19  	"encoding/binary"
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"math/rand"
    24  	"testing"
    25  	"time"
    26  
    27  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    28  
    29  	trpc "trpc.group/trpc-go/trpc-go"
    30  	"trpc.group/trpc-go/trpc-go/client"
    31  	"trpc.group/trpc-go/trpc-go/codec"
    32  	"trpc.group/trpc-go/trpc-go/errs"
    33  	"trpc.group/trpc-go/trpc-go/server"
    34  	"trpc.group/trpc-go/trpc-go/stream"
    35  	"trpc.group/trpc-go/trpc-go/transport"
    36  
    37  	"github.com/stretchr/testify/assert"
    38  )
    39  
    40  var ctx = context.Background()
    41  
    42  type fakeTransport struct {
    43  	expectChan chan recvExpect
    44  }
    45  
    46  // RoundTrip Mock RoundTrip method.
    47  func (c *fakeTransport) RoundTrip(ctx context.Context, req []byte,
    48  	roundTripOpts ...transport.RoundTripOption) (rsp []byte, err error) {
    49  	return nil, nil
    50  }
    51  
    52  // Send Mock Send method.
    53  func (c *fakeTransport) Send(ctx context.Context, req []byte, opts ...transport.RoundTripOption) error {
    54  	err, ok := ctx.Value("send-error").(string)
    55  	if ok {
    56  		return errors.New(err)
    57  	}
    58  	return nil
    59  }
    60  
    61  type recvExpect func(*trpc.FrameHead, codec.Msg) ([]byte, error)
    62  
    63  // Recv Mock recv method.
    64  func (c *fakeTransport) Recv(ctx context.Context, opts ...transport.RoundTripOption) ([]byte, error) {
    65  	msg := codec.Message(ctx)
    66  	var fh *trpc.FrameHead
    67  	fh, ok := msg.FrameHead().(*trpc.FrameHead)
    68  	if !ok {
    69  		fh = &trpc.FrameHead{}
    70  		msg.WithFrameHead(fh)
    71  	}
    72  	f := <-c.expectChan
    73  	return f(fh, msg)
    74  }
    75  
    76  // Init Mock Init method.
    77  func (c *fakeTransport) Init(ctx context.Context, opts ...transport.RoundTripOption) error {
    78  	return nil
    79  }
    80  
    81  // Close Mock Close method.
    82  func (c *fakeTransport) Close(ctx context.Context) {
    83  	return
    84  }
    85  
    86  type fakeCodec struct {
    87  }
    88  
    89  // Encode Mock codec Encode method.
    90  func (c *fakeCodec) Encode(msg codec.Msg, reqBody []byte) (reqBuf []byte, err error) {
    91  	if string(reqBody) == "failbody" {
    92  		return nil, errors.New("encode fail")
    93  	}
    94  	return reqBody, nil
    95  }
    96  
    97  // Decode Mock codec Decode method.
    98  func (c *fakeCodec) Decode(msg codec.Msg, rspBuf []byte) (rspBody []byte, err error) {
    99  	if string(rspBuf) == "businessfail" {
   100  		return nil, errors.New("businessfail")
   101  	}
   102  	if string(rspBuf) == "msgfail" {
   103  		msg.WithClientRspErr(errors.New("msgfail"))
   104  		return nil, nil
   105  	}
   106  	return rspBuf, nil
   107  }
   108  
   109  // TestMain tests the Main function.
   110  func TestMain(m *testing.M) {
   111  	transport.DefaultServerTransport = &fakeServerTransport{}
   112  	m.Run()
   113  }
   114  
   115  // TestClient tests the streaming client.
   116  func TestClient(t *testing.T) {
   117  	var reqBody = &codec.Body{Data: []byte("body")}
   118  	var rspBody = &codec.Body{}
   119  	codec.RegisterSerializer(0, &codec.NoopSerialization{})
   120  	codec.Register("fake", nil, &fakeCodec{})
   121  	codec.Register("fake-nil", nil, nil)
   122  
   123  	cli := stream.NewStreamClient()
   124  	assert.Equal(t, cli, stream.DefaultStreamClient)
   125  
   126  	ctx := context.Background()
   127  	var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)}
   128  	transport.DefaultClientTransport = ft
   129  
   130  	f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   131  		return nil, nil
   132  	}
   133  	ft.expectChan <- f
   134  	cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   135  		client.WithTarget("ip://127.0.0.1:8000"),
   136  		client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   137  		client.WithCurrentCompressType(codec.CompressTypeNoop),
   138  		client.WithStreamTransport(ft))
   139  	assert.NotNil(t, cs)
   140  	assert.Nil(t, err)
   141  
   142  	// Test Context.
   143  	resultCtx := cs.Context()
   144  	assert.NotNil(t, resultCtx)
   145  	// Test to send data normally.
   146  	err = cs.SendMsg(reqBody)
   147  	assert.Nil(t, err)
   148  
   149  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   150  		fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)
   151  		return []byte("body"), nil
   152  	}
   153  	ft.expectChan <- f
   154  
   155  	// Test to receive data normally.
   156  	err = cs.RecvMsg(rspBody)
   157  	assert.Nil(t, err)
   158  	assert.Equal(t, rspBody.Data, []byte("body"))
   159  
   160  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   161  		fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE)
   162  		return nil, nil
   163  	}
   164  	ft.expectChan <- f
   165  
   166  	// Test received io.EOF.
   167  	rspBody = &codec.Body{}
   168  	err = cs.RecvMsg(rspBody)
   169  	assert.Equal(t, io.EOF, err)
   170  	assert.Nil(t, rspBody.Data)
   171  
   172  	err = cs.CloseSend()
   173  	assert.Nil(t, err)
   174  
   175  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   176  		return nil, nil
   177  	}
   178  	ft.expectChan <- f
   179  	cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   180  		client.WithTarget("ip://127.0.0.1:8000"),
   181  		client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop),
   182  		client.WithTransport(ft),
   183  		client.WithStreamTransport(ft))
   184  	assert.NotNil(t, cs)
   185  	assert.Nil(t, err)
   186  
   187  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   188  		msg.WithClientRspErr(errors.New("close type is reset"))
   189  		fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE)
   190  		return nil, nil
   191  	}
   192  	ft.expectChan <- f
   193  	// test reset.
   194  	rspBody = &codec.Body{}
   195  	err = cs.RecvMsg(rspBody)
   196  	assert.NotNil(t, err)
   197  	assert.Nil(t, rspBody.Data)
   198  	assert.Contains(t, err.Error(), "close type is reset")
   199  
   200  }
   201  
   202  // TestClientFlowControl tests the streaming client.
   203  func TestClientFlowControl(t *testing.T) {
   204  	var reqBody = &codec.Body{Data: []byte("body")}
   205  
   206  	var rspBody = &codec.Body{}
   207  	codec.RegisterSerializer(0, &codec.NoopSerialization{})
   208  	codec.Register("fake", nil, &fakeCodec{})
   209  	codec.Register("fake-nil", nil, nil)
   210  
   211  	cli := stream.NewStreamClient()
   212  	assert.Equal(t, cli, stream.DefaultStreamClient)
   213  
   214  	ctx := context.Background()
   215  	var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)}
   216  	transport.DefaultClientTransport = ft
   217  
   218  	f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   219  		fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)
   220  		msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{InitWindowSize: 2000})
   221  		return nil, nil
   222  	}
   223  	ft.expectChan <- f
   224  
   225  	cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   226  		client.WithTarget("ip://127.0.0.1:8000"),
   227  		client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   228  		client.WithCurrentCompressType(codec.CompressTypeNoop),
   229  		client.WithTransport(ft),
   230  		client.WithStreamTransport(ft))
   231  	assert.NotNil(t, cs)
   232  	assert.Nil(t, err)
   233  
   234  	// Test Context.
   235  	resultCtx := cs.Context()
   236  	assert.NotNil(t, resultCtx)
   237  	// Test to send data normally.
   238  	err = cs.SendMsg(reqBody)
   239  	assert.Nil(t, err)
   240  
   241  	for i := 0; i < 20000; i++ {
   242  		f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   243  			fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)
   244  			return []byte("body"), nil
   245  		}
   246  		ft.expectChan <- f
   247  		// Test to receive data normally.
   248  		err = cs.RecvMsg(rspBody)
   249  		assert.Nil(t, err)
   250  		assert.Equal(t, rspBody.Data, []byte("body"))
   251  	}
   252  
   253  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   254  		fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE)
   255  		return nil, nil
   256  	}
   257  	ft.expectChan <- f
   258  
   259  	// Test received io.EOF.
   260  	rspBody = &codec.Body{}
   261  	err = cs.RecvMsg(rspBody)
   262  	assert.Equal(t, io.EOF, err)
   263  	assert.Nil(t, rspBody.Data)
   264  }
   265  
   266  // TestClientError tests the case of streaming Client error handling.
   267  func TestClientError(t *testing.T) {
   268  	var rspBody = &codec.Body{}
   269  	codec.RegisterSerializer(0, &codec.NoopSerialization{})
   270  	codec.Register("fake", nil, &fakeCodec{})
   271  	codec.Register("fake-nil", nil, nil)
   272  
   273  	cli := stream.NewStreamClient()
   274  	assert.Equal(t, cli, stream.DefaultStreamClient)
   275  
   276  	var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)}
   277  	transport.DefaultClientTransport = ft
   278  	f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   279  		return nil, errors.New("init error")
   280  	}
   281  	ft.expectChan <- f
   282  
   283  	// Test for init transport errors.
   284  	cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   285  		client.WithTarget("ip://127.0.0.1:8000"),
   286  		client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop),
   287  		client.WithTransport(ft),
   288  		client.WithStreamTransport(ft))
   289  	assert.Nil(t, cs)
   290  	assert.NotNil(t, err)
   291  
   292  	// test Init error.
   293  	cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   294  		client.WithTarget("ip://127.0.0.1:8000"),
   295  		client.WithProtocol("fake-nil"), client.WithSerializationType(codec.SerializationTypeNoop),
   296  		client.WithTransport(ft),
   297  		client.WithStreamTransport(ft))
   298  	assert.Nil(t, cs)
   299  	assert.NotNil(t, err)
   300  
   301  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   302  		return nil, nil
   303  	}
   304  	ft.expectChan <- f
   305  	cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   306  		client.WithTarget("ip://127.0.0.1:8000"),
   307  		client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop),
   308  		client.WithTransport(ft),
   309  		client.WithStreamTransport(ft), client.WithClientStreamQueueSize(100000))
   310  	assert.NotNil(t, cs)
   311  	assert.Nil(t, err)
   312  	// test recv data error.
   313  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   314  		return nil, errors.New("recv data error")
   315  	}
   316  	ft.expectChan <- f
   317  	err = cs.RecvMsg(rspBody)
   318  	assert.NotNil(t, err)
   319  	assert.Nil(t, rspBody.Data)
   320  
   321  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   322  		msg.WithClientRspErr(errors.New("test init with clientRspError"))
   323  		return nil, nil
   324  	}
   325  	ft.expectChan <- f
   326  	cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   327  		client.WithTarget("ip://127.0.0.1:8000"),
   328  		client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop),
   329  		client.WithStreamTransport(ft), client.WithClientStreamQueueSize(100000))
   330  	assert.Nil(t, cs)
   331  	assert.NotNil(t, err)
   332  
   333  }
   334  
   335  // TestClientContext tests the case of streaming client context cancel and timeout.
   336  func TestClientContext(t *testing.T) {
   337  
   338  	var rspBody = &codec.Body{}
   339  	codec.RegisterSerializer(0, &codec.NoopSerialization{})
   340  	codec.Register("fake", nil, &fakeCodec{})
   341  	codec.Register("fake-nil", nil, nil)
   342  
   343  	cli := stream.NewStreamClient()
   344  	assert.Equal(t, cli, stream.DefaultStreamClient)
   345  
   346  	var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)}
   347  	transport.DefaultClientTransport = ft
   348  	// test context cancel situation.
   349  	f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   350  		return nil, nil
   351  	}
   352  	ft.expectChan <- f
   353  	ctx, cancel := context.WithCancel(context.Background())
   354  	cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   355  		client.WithTarget("ip://127.0.0.1:8000"),
   356  		client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   357  		client.WithCurrentCompressType(codec.CompressTypeNoop),
   358  		client.WithTransport(ft),
   359  		client.WithStreamTransport(ft))
   360  	assert.NotNil(t, cs)
   361  	assert.Nil(t, err)
   362  	cancel()
   363  	err = cs.RecvMsg(rspBody)
   364  	assert.Contains(t, err.Error(), "tcp client stream canceled before recv")
   365  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   366  		return nil, errors.New("close it")
   367  	}
   368  	ft.expectChan <- f
   369  	time.Sleep(5 * time.Millisecond)
   370  	// test context timeout situation.
   371  	f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) {
   372  		return nil, nil
   373  	}
   374  	ft.expectChan <- f
   375  
   376  	timeoutCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
   377  	defer cancel()
   378  	cs, err = cli.NewStream(timeoutCtx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello",
   379  		client.WithTarget("ip://127.0.0.1:8000"),
   380  		client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   381  		client.WithCurrentCompressType(codec.CompressTypeNoop),
   382  		client.WithTransport(ft),
   383  		client.WithStreamTransport(ft))
   384  	assert.NotNil(t, cs)
   385  	assert.Nil(t, err)
   386  
   387  	err = cs.RecvMsg(rspBody)
   388  	assert.Contains(t, err.Error(), "tcp client stream canceled timeout before recv")
   389  }
   390  
   391  func clientFilterAdd1(ctx context.Context, desc *client.ClientStreamDesc, newStream client.Streamer) (client.ClientStream, error) {
   392  	var msg codec.Msg
   393  	ctx, msg = codec.EnsureMessage(ctx)
   394  	meta := msg.ClientMetaData()
   395  	if meta == nil {
   396  		meta = codec.MetaData{}
   397  	}
   398  	meta[testKey1] = []byte(testData[testKey1])
   399  	msg.WithClientMetaData(meta)
   400  
   401  	s, err := newStream(ctx, desc)
   402  	if err != nil {
   403  		return nil, err
   404  	}
   405  
   406  	return newWrappedClientStream(s), nil
   407  }
   408  
   409  func clientFilterAdd2(ctx context.Context, desc *client.ClientStreamDesc, newStream client.Streamer) (client.ClientStream, error) {
   410  	var msg codec.Msg
   411  	ctx, msg = codec.EnsureMessage(ctx)
   412  	meta := msg.ClientMetaData()
   413  	if meta == nil {
   414  		meta = codec.MetaData{}
   415  	}
   416  	meta[testKey2] = []byte(testData[testKey2])
   417  	msg.WithClientMetaData(meta)
   418  
   419  	s, err := newStream(ctx, desc)
   420  	if err != nil {
   421  		return nil, err
   422  	}
   423  	return newWrappedClientStream(s), nil
   424  }
   425  
   426  type wrappedClientStream struct {
   427  	client.ClientStream
   428  }
   429  
   430  func newWrappedClientStream(s client.ClientStream) client.ClientStream {
   431  	return &wrappedClientStream{s}
   432  }
   433  
   434  func (w *wrappedClientStream) RecvMsg(m interface{}) error {
   435  	err := w.ClientStream.RecvMsg(m)
   436  	num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8])
   437  	binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1)
   438  	return err
   439  }
   440  
   441  func (w *wrappedClientStream) SendMsg(m interface{}) error {
   442  	num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8])
   443  	binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1)
   444  	return w.ClientStream.SendMsg(m)
   445  }
   446  
   447  func TestClientStreamClientFilters(t *testing.T) {
   448  	rawData := make([]byte, 1024)
   449  	rand.Read(rawData)
   450  	var beginNum uint64 = 100
   451  
   452  	counts := 1000
   453  	svrOpts := []server.Option{
   454  		server.WithAddress("127.0.0.1:30211"),
   455  		server.WithStreamFilters(serverFilterAdd1, serverFilterAdd2),
   456  	}
   457  	handle := func(s server.Stream) error {
   458  		var req *codec.Body
   459  
   460  		// server receives data.
   461  		for i := 0; i < counts; i++ {
   462  			req = getBytes(1024)
   463  			err := s.RecvMsg(req)
   464  			assert.Nil(t, err)
   465  			resultNum := binary.LittleEndian.Uint64(req.Data[:8])
   466  
   467  			// After the client SendMsg + server RecvMsg, two Filter, Num+4.
   468  			assert.Equal(t, beginNum+4, resultNum)
   469  			assert.Equal(t, rawData[8:], req.Data[8:])
   470  		}
   471  		err := s.RecvMsg(getBytes(1024))
   472  		assert.Equal(t, io.EOF, err)
   473  
   474  		// server sends data.
   475  		rsp := getBytes(1024)
   476  		for i := 0; i < counts; i++ {
   477  			copy(rsp.Data, req.Data)
   478  			err = s.SendMsg(rsp)
   479  			assert.Nil(t, err)
   480  		}
   481  		return nil
   482  	}
   483  	svr := startStreamServer(handle, svrOpts)
   484  	defer closeStreamServer(svr)
   485  
   486  	cliOpts := []client.Option{
   487  		client.WithTarget("ip://127.0.0.1:30211"),
   488  		client.WithStreamFilters(clientFilterAdd1, clientFilterAdd2),
   489  	}
   490  	cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts)
   491  	assert.Nil(t, err)
   492  
   493  	// client sends data.
   494  	for i := 0; i < counts; i++ {
   495  		req := getBytes(1024)
   496  		copy(req.Data, rawData)
   497  		binary.LittleEndian.PutUint64(req.Data[:8], beginNum)
   498  
   499  		err = cliStream.SendMsg(req)
   500  		assert.Nil(t, err)
   501  	}
   502  	err = cliStream.CloseSend()
   503  	assert.Nil(t, err)
   504  
   505  	// client receives data.
   506  	for i := 0; i < counts; i++ {
   507  		rsp := getBytes(1024)
   508  		err = cliStream.RecvMsg(rsp)
   509  		assert.Nil(t, err)
   510  
   511  		// After the client once SendMsg, once RecvMsg, two Filter, Num+4.
   512  		resultNum := binary.LittleEndian.Uint64(rsp.Data[:8])
   513  		assert.Equal(t, beginNum+8, resultNum)
   514  		assert.Equal(t, rawData[8:], rsp.Data[8:])
   515  	}
   516  	rsp := getBytes(1024)
   517  	err = cliStream.RecvMsg(rsp)
   518  	assert.Equal(t, io.EOF, err)
   519  }
   520  
   521  func TestClientStreamFlowControlStop(t *testing.T) {
   522  	windows := 102400
   523  	dataLen := 1024
   524  	maxSends := windows / dataLen
   525  	svrOpts := []server.Option{
   526  		server.WithAddress("127.0.0.1:30211"),
   527  		server.WithMaxWindowSize(uint32(windows)),
   528  	}
   529  	handle := func(s server.Stream) error {
   530  		time.Sleep(time.Hour)
   531  		return nil
   532  	}
   533  	svr := startStreamServer(handle, svrOpts)
   534  	defer closeStreamServer(svr)
   535  
   536  	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(200*time.Millisecond))
   537  	defer cancel()
   538  	cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30211")}
   539  	cliStream, err := getClientStream(ctx, bidiDesc, cliOpts)
   540  	assert.Nil(t, err)
   541  
   542  	req := getBytes(dataLen)
   543  	rand.Read(req.Data)
   544  
   545  	for i := 0; i < maxSends; i++ {
   546  		err = cliStream.SendMsg(req)
   547  		assert.Nil(t, err)
   548  	}
   549  	err = cliStream.SendMsg(req)
   550  	assert.Equal(t, errors.New("stream is already closed"), err)
   551  }
   552  
// TestServerStreamFlowControlStop is the server-side counterpart of
// TestClientStreamFlowControlStop. The client advertises a window of `windows`
// bytes (client.WithMaxWindowSize) and never reads. The server handler is
// therefore expected to send exactly windows/dataLen messages without blocking,
// and one more SendMsg must still be blocked 200ms later.
func TestServerStreamFlowControlStop(t *testing.T) {
	windows := 102400
	dataLen := 1024
	// Number of dataLen-byte messages that exactly fill the client's window.
	maxSends := windows / dataLen
	// Signals the test goroutine that the handler has finished its checks.
	waitCh := make(chan struct{}, 1)
	svrOpts := []server.Option{server.WithAddress("127.0.0.1:30211")}
	handle := func(s server.Stream) error {
		rsp := getBytes(dataLen)
		rand.Read(rsp.Data)
		for i := 0; i < maxSends; i++ {
			err := s.SendMsg(rsp)
			assert.Nil(t, err)
		}

		// The next send exceeds the window, so it runs in a goroutine that is
		// expected to stay blocked. It presumably only returns once the stream
		// is torn down after the handler exits.
		// NOTE(review): that assertion may run after the test has completed.
		finish := make(chan struct{}, 1)
		go func() {
			err := s.SendMsg(rsp)
			assert.Equal(t, errors.New("stream is already closed"), err)
			finish <- struct{}{}
		}()

		// The blocked SendMsg must not return within 200ms.
		deadline := time.NewTimer(200 * time.Millisecond)
		select {
		case <-deadline.C:
		case <-finish:
			assert.Fail(t, "SendMsg should block")
		}

		waitCh <- struct{}{}
		return nil
	}
	svr := startStreamServer(handle, svrOpts)
	defer closeStreamServer(svr)

	cliOpts := []client.Option{
		client.WithTarget("ip://127.0.0.1:30211"),
		client.WithMaxWindowSize(uint32(windows)),
	}
	// The client stream is never read from; the test only waits for the
	// handler to finish its checks.
	_, err := getClientStream(context.Background(), bidiDesc, cliOpts)
	assert.Nil(t, err)
	<-waitCh
}
   595  
   596  func TestClientStreamSendRecvNoBlock(t *testing.T) {
   597  	svrOpts := []server.Option{server.WithAddress("127.0.0.1:30210")}
   598  	handle := func(s server.Stream) error {
   599  		// Must sleep, to avoid returning before receiving the first packet from the client,
   600  		// resulting in the processing of the first packet returns an error,
   601  		// losing the chance for the test client to block on the second SendMsg.
   602  		time.Sleep(200 * time.Millisecond)
   603  		return errors.New("test error")
   604  	}
   605  	svr := startStreamServer(handle, svrOpts)
   606  	defer closeStreamServer(svr)
   607  
   608  	cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30210")}
   609  	cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts)
   610  	assert.Nil(t, err)
   611  
   612  	// defaultInitWindowSize = 65535.
   613  	req := getBytes(65535)
   614  	err = cliStream.SendMsg(req)
   615  	assert.Nil(t, err)
   616  
   617  	err = cliStream.SendMsg(req)
   618  	fmt.Println(err)
   619  	assert.NotNil(t, err)
   620  
   621  	rsp := getBytes(1024)
   622  	err = cliStream.RecvMsg(rsp)
   623  	assert.NotNil(t, err)
   624  }
   625  
   626  func TestServerStreamSendRecvNoBlock(t *testing.T) {
   627  	svrOpts := []server.Option{server.WithAddress("127.0.0.1:30210")}
   628  	SendMsgReturn := make(chan struct{}, 1)
   629  	RecvMsgReturn := make(chan struct{}, 1)
   630  	handle := func(s server.Stream) error {
   631  		go func() {
   632  			msg := getBytes(65535)
   633  			s.SendMsg(msg)
   634  			s.SendMsg(msg)
   635  			SendMsgReturn <- struct{}{}
   636  		}()
   637  		go func() {
   638  			msg := getBytes(1024)
   639  			s.RecvMsg(msg)
   640  			s.RecvMsg(msg)
   641  			RecvMsgReturn <- struct{}{}
   642  		}()
   643  		// Must sleep, to avoid the second SendMsg does not enter the waiting window to block.
   644  		time.Sleep(200 * time.Millisecond)
   645  		return nil
   646  	}
   647  	svr := startStreamServer(handle, svrOpts)
   648  	defer closeStreamServer(svr)
   649  
   650  	cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30210")}
   651  	_, err := getClientStream(context.Background(), bidiDesc, cliOpts)
   652  	assert.Nil(t, err)
   653  
   654  	<-SendMsgReturn
   655  	<-RecvMsgReturn
   656  }
   657  
   658  func TestClientStreamReturn(t *testing.T) {
   659  	const (
   660  		invalidCompressType = -1
   661  		dataLen             = 1024
   662  	)
   663  
   664  	svrOpts := []server.Option{
   665  		server.WithAddress("127.0.0.1:30211"),
   666  		server.WithCurrentCompressType(invalidCompressType),
   667  	}
   668  	handle := func(s server.Stream) error {
   669  		req := getBytes(dataLen)
   670  		s.RecvMsg(req)
   671  		rsp := req
   672  		s.SendMsg(rsp)
   673  		return errs.NewFrameError(101, "expected error")
   674  	}
   675  	svr := startStreamServer(handle, svrOpts)
   676  	defer closeStreamServer(svr)
   677  
   678  	cliOpts := []client.Option{
   679  		client.WithTarget("ip://127.0.0.1:30211"),
   680  		client.WithCompressType(invalidCompressType),
   681  	}
   682  
   683  	clientStream, err := getClientStream(context.Background(), clientDesc, cliOpts)
   684  	assert.Nil(t, err)
   685  	err = clientStream.SendMsg(getBytes(dataLen))
   686  	assert.Nil(t, err)
   687  
   688  	rsp := getBytes(dataLen)
   689  	err = clientStream.RecvMsg(rsp)
   690  	assert.EqualValues(t, int32(101), err.(*errs.Error).Code)
   691  }
   692  
   693  // TestClientSendFailWhenServerUnavailable test when the client blocks
   694  // on SendMsg because of flow control, if the server is closed, the client
   695  // SendMsg should return.
   696  func TestClientSendFailWhenServerUnavailable(t *testing.T) {
   697  	codec.Register("mock", nil, &fakeCodec{})
   698  	tp := &fakeTransport{expectChan: make(chan recvExpect, 1)}
   699  	tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) {
   700  		return nil, nil
   701  	}
   702  	cs, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "",
   703  		client.WithProtocol("mock"),
   704  		client.WithTarget("ip://127.0.0.1:8000"),
   705  		client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   706  		client.WithStreamTransport(tp),
   707  	)
   708  	assert.Nil(t, err)
   709  	assert.NotNil(t, cs)
   710  	assert.Nil(t, cs.SendMsg(getBytes(65535)))
   711  	tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) {
   712  		return nil, errors.New("server is closed")
   713  	}
   714  	assert.Eventually(
   715  		t,
   716  		func() bool {
   717  			cs.SendMsg(getBytes(65535))
   718  			return true
   719  		},
   720  		time.Second,
   721  		100*time.Millisecond,
   722  	)
   723  }
   724  
   725  // TestClientReceiveErrorWhenServerUnavailable tests that the client receives a non-io.EOF
   726  // error when the server crashes or the connection is closed, otherwise the client would
   727  // mistakenly think that the server closed the stream normally.
   728  func TestClientReceiveErrorWhenServerUnavailable(t *testing.T) {
   729  	codec.Register("mock", nil, &fakeCodec{})
   730  	tp := &fakeTransport{expectChan: make(chan recvExpect, 1)}
   731  	tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) {
   732  		return nil, nil
   733  	}
   734  	cs, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "",
   735  		client.WithProtocol("mock"),
   736  		client.WithTarget("ip://127.0.0.1:8000"),
   737  		client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   738  		client.WithStreamTransport(tp),
   739  	)
   740  	assert.Nil(t, err)
   741  	assert.NotNil(t, cs)
   742  	tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) {
   743  		return nil, io.EOF
   744  	}
   745  	err = cs.RecvMsg(nil)
   746  	assert.NotEqual(t, io.EOF, err)
   747  	assert.ErrorIs(t, err, io.EOF)
   748  }