trpc.group/trpc-go/trpc-go@v1.0.2/stream/server_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package stream_test
    15  
    16  import (
    17  	"bytes"
    18  	"context"
    19  	"encoding/binary"
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"math/rand"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  
    28  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    29  
    30  	"trpc.group/trpc-go/trpc-go/client"
    31  	"trpc.group/trpc-go/trpc-go/errs"
    32  
    33  	trpc "trpc.group/trpc-go/trpc-go"
    34  	"trpc.group/trpc-go/trpc-go/stream"
    35  
    36  	"trpc.group/trpc-go/trpc-go/codec"
    37  	"trpc.group/trpc-go/trpc-go/server"
    38  	"trpc.group/trpc-go/trpc-go/transport"
    39  
    40  	"github.com/stretchr/testify/assert"
    41  )
    42  
    43  type fakeStreamHandle struct {
    44  }
    45  
    46  // StreamHandleFunc Mock StreamHandleFunc method
    47  func (fs *fakeStreamHandle) StreamHandleFunc(ctx context.Context, sh server.StreamHandler, req []byte) ([]byte, error) {
    48  	return nil, nil
    49  }
    50  
    51  // Init Mock Init method
    52  func (fs *fakeStreamHandle) Init(opts *server.Options) {
    53  	return
    54  }
    55  
    56  type fakeServerTransport struct{}
    57  
    58  type fakeServerCodec struct{}
    59  
    60  // Send Mock Send method
    61  func (s *fakeServerTransport) Send(ctx context.Context, rspBuf []byte) error {
    62  	if string(rspBuf) == "init-error" {
    63  		return errors.New("init-error")
    64  	}
    65  	return nil
    66  }
    67  
    68  // Close Mock Close method
    69  func (s *fakeServerTransport) Close(ctx context.Context) {
    70  	return
    71  }
    72  
    73  // ListenAndServe Mock ListenAndServe method
    74  func (s *fakeServerTransport) ListenAndServe(ctx context.Context, opts ...transport.ListenServeOption) error {
    75  
    76  	return nil
    77  }
    78  
    79  // Decode Mock codec Decode method
    80  func (c *fakeServerCodec) Decode(msg codec.Msg, reqBuf []byte) (reqBody []byte, err error) {
    81  	return reqBuf, nil
    82  }
    83  
    84  // Encode Mock codec Encode method
    85  func (c *fakeServerCodec) Encode(msg codec.Msg, rspBody []byte) (rspBuf []byte, err error) {
    86  	rsp := string(rspBody)
    87  	if rsp == "encode-error" {
    88  		return nil, errors.New("server encode response fail")
    89  	}
    90  	if msg.StreamID() < uint32(100) {
    91  		return nil, errors.New("streamID less than 100")
    92  	}
    93  	if msg.StreamID() == uint32(101) {
    94  		return []byte("init-error"), nil
    95  	}
    96  	return rspBody, nil
    97  }
    98  
    99  func streamHandler(stream server.Stream) error {
   100  	time.Sleep(time.Second)
   101  	return nil
   102  }
   103  
   104  func errorStreamHandler(stream server.Stream) error {
   105  	return errors.New("handle fail")
   106  }
   107  
   108  type fakeAddr struct {
   109  }
   110  
   111  // Network method of Network Mock net.Addr interface
   112  func (f *fakeAddr) Network() string {
   113  	return "tcp"
   114  }
   115  
   116  // String method of String Mock net.Addr interface
   117  func (f *fakeAddr) String() string {
   118  	return "127.0.0.01:67891"
   119  }
   120  
   121  // TestStreamDispatcherHandleInit Test Stream Dispatcher
   122  func TestStreamDispatcherHandleInit(t *testing.T) {
   123  	codec.Register("fake", &fakeServerCodec{}, nil)
   124  
   125  	si := &server.StreamServerInfo{}
   126  	dispatcher := stream.NewStreamDispatcher()
   127  	assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher)
   128  
   129  	// Init test
   130  	opts := &server.Options{}
   131  	ft := &fakeServerTransport{}
   132  	opts.Transport = ft
   133  	opts.Codec = codec.GetServer("fake")
   134  	err := dispatcher.Init(opts)
   135  	assert.Nil(t, err)
   136  	assert.Equal(t, opts.Transport, opts.StreamTransport)
   137  	// StreamHandleFunc msg not nil
   138  	ctx := context.Background()
   139  	ctx, msg := codec.WithNewMessage(ctx)
   140  	rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, nil)
   141  	assert.Nil(t, rsp)
   142  	assert.Contains(t, err.Error(), "frameHead is not contained in msg")
   143  	msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{})
   144  	// StreamHandleFunc handle init
   145  	fh := &trpc.FrameHead{}
   146  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)
   147  	msg.WithFrameHead(fh)
   148  	msg.WithStreamID(uint32(100))
   149  	msg.WithRemoteAddr(&fakeAddr{})
   150  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init"))
   151  	assert.Nil(t, rsp)
   152  	assert.Equal(t, err, errs.ErrServerNoResponse)
   153  
   154  	// StreamHandleFunc handle init with codec encode error
   155  	msg.WithFrameHead(fh)
   156  	msg.WithStreamID(uint32(99))
   157  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init"))
   158  	assert.Nil(t, rsp)
   159  	assert.Equal(t, err.Error(), "streamID less than 100")
   160  
   161  	// StreamHandleFunc handle init send error
   162  	msg.WithFrameHead(fh)
   163  	msg.WithStreamID(uint32(101))
   164  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init-error"))
   165  	assert.Nil(t, rsp)
   166  	assert.Contains(t, err.Error(), "init-error")
   167  
   168  	// StreamHandleFun handle data to validate streamID was stored
   169  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)
   170  	msg.WithFrameHead(fh)
   171  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data"))
   172  	assert.Nil(t, rsp)
   173  	assert.Equal(t, err, errs.ErrServerNoResponse)
   174  
   175  	// StreamHandleFunc handle error
   176  	msg.WithFrameHead(fh)
   177  	msg.WithStreamID(100)
   178  	rsp, err = dispatcher.StreamHandleFunc(ctx, errorStreamHandler, si, []byte("init"))
   179  	assert.Nil(t, rsp)
   180  	assert.Equal(t, err, errs.ErrServerNoResponse)
   181  	time.Sleep(100 * time.Millisecond)
   182  }
   183  
   184  // TestStreamDispatcherHandleData test StreamDispatcher Handle data
   185  func TestStreamDispatcherHandleData(t *testing.T) {
   186  	codec.Register("fake", &fakeServerCodec{}, nil)
   187  
   188  	si := &server.StreamServerInfo{}
   189  	dispatcher := stream.NewStreamDispatcher()
   190  	assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher)
   191  
   192  	// Init test
   193  	opts := &server.Options{}
   194  	ft := &fakeServerTransport{}
   195  	opts.Transport = ft
   196  	opts.Codec = codec.GetServer("fake")
   197  	err := dispatcher.Init(opts)
   198  	assert.Nil(t, err)
   199  	assert.Equal(t, opts.Transport, opts.StreamTransport)
   200  
   201  	ctx := context.Background()
   202  	ctx, msg := codec.WithNewMessage(ctx)
   203  	fh := &trpc.FrameHead{}
   204  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)
   205  	msg.WithFrameHead(fh)
   206  	msg.WithStreamID(uint32(100))
   207  	msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{})
   208  	addr := &fakeAddr{}
   209  	msg.WithRemoteAddr(addr)
   210  	rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init"))
   211  	assert.Nil(t, rsp)
   212  	assert.Equal(t, err, errs.ErrServerNoResponse)
   213  
   214  	// handleData normal
   215  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)
   216  	msg.WithFrameHead(fh)
   217  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data"))
   218  	assert.Nil(t, rsp)
   219  	assert.Equal(t, err, errs.ErrServerNoResponse)
   220  
   221  	// handleData error no such addr
   222  	msg.WithRemoteAddr(nil)
   223  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)
   224  	msg.WithFrameHead(fh)
   225  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data"))
   226  	assert.Nil(t, rsp)
   227  	assert.Contains(t, err.Error(), "no such addr")
   228  
   229  	// handle data error no such stream id
   230  	msg.WithRemoteAddr(addr)
   231  	msg.WithStreamID(uint32(101))
   232  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)
   233  	msg.WithFrameHead(fh)
   234  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data"))
   235  	assert.Nil(t, rsp)
   236  	assert.Contains(t, err.Error(), "no such stream ID")
   237  }
   238  
   239  // TestStreamDispatcherHandleClose test handles Close frame
   240  func TestStreamDispatcherHandleClose(t *testing.T) {
   241  
   242  	codec.Register("fake", &fakeServerCodec{}, nil)
   243  
   244  	si := &server.StreamServerInfo{}
   245  	dispatcher := stream.NewStreamDispatcher()
   246  	assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher)
   247  
   248  	// Init test
   249  	opts := &server.Options{}
   250  	ft := &fakeServerTransport{}
   251  	opts.Transport = ft
   252  	opts.Codec = codec.GetServer("fake")
   253  	err := dispatcher.Init(opts)
   254  	assert.Nil(t, err)
   255  	assert.Equal(t, opts.Transport, opts.StreamTransport)
   256  
   257  	ctx := context.Background()
   258  	ctx, msg := codec.WithNewMessage(ctx)
   259  	fh := &trpc.FrameHead{}
   260  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)
   261  	msg.WithFrameHead(fh)
   262  	msg.WithStreamID(uint32(100))
   263  	msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{})
   264  
   265  	addr := &fakeAddr{}
   266  	msg.WithRemoteAddr(addr)
   267  	msg.WithFrameHead(fh)
   268  	rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init"))
   269  	assert.Nil(t, rsp)
   270  	assert.Equal(t, err, errs.ErrServerNoResponse)
   271  
   272  	// handle close normal
   273  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE)
   274  	msg.WithFrameHead(fh)
   275  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("close"))
   276  	assert.Nil(t, rsp)
   277  	assert.Equal(t, errs.ErrServerNoResponse, err)
   278  
   279  	// handle close no such addr
   280  	msg.WithFrameHead(fh)
   281  	msg.WithRemoteAddr(nil)
   282  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("close"))
   283  	assert.Nil(t, rsp)
   284  	assert.Equal(t, errs.ErrServerNoResponse, err)
   285  
   286  	// handle close server rsp err
   287  	msg.WithRemoteAddr(addr)
   288  	msg.WithFrameHead(fh)
   289  	msg.WithServerRspErr(errors.New("server rsp error"))
   290  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("close"))
   291  	assert.Nil(t, rsp)
   292  	assert.Equal(t, errs.ErrServerNoResponse, err)
   293  
   294  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK)
   295  	msg.WithFrameHead(fh)
   296  	msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{})
   297  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("feedback"))
   298  	assert.Nil(t, rsp)
   299  	assert.Equal(t, err, errs.ErrServerNoResponse)
   300  
   301  	fh.StreamFrameType = uint8(8)
   302  	msg.WithFrameHead(fh)
   303  	rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("unknown"))
   304  	assert.Nil(t, rsp)
   305  	assert.Contains(t, err.Error(), "unknown frame type")
   306  }
   307  
   308  // TestServerStreamSendMsg test server receives messages
   309  func TestServerStreamSendMsg(t *testing.T) {
   310  	codec.Register("fake", &fakeServerCodec{}, nil)
   311  
   312  	si := &server.StreamServerInfo{}
   313  	dispatcher := stream.NewStreamDispatcher()
   314  	assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher)
   315  
   316  	// Init test
   317  	opts := &server.Options{}
   318  	ft := &fakeServerTransport{}
   319  	opts.Transport = ft
   320  	opts.Codec = codec.GetServer("fake")
   321  	err := dispatcher.Init(opts)
   322  	assert.Nil(t, err)
   323  	assert.Equal(t, opts.Transport, opts.StreamTransport)
   324  
   325  	// StreamHandleFunc msg not nil
   326  	ctx := context.Background()
   327  	ctx, msg := codec.WithNewMessage(ctx)
   328  	fh := &trpc.FrameHead{}
   329  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)
   330  	msg.WithFrameHead(fh)
   331  	msg.WithStreamID(uint32(100))
   332  	msg.WithRemoteAddr(&fakeAddr{})
   333  	msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{})
   334  
   335  	opts.CurrentCompressType = codec.CompressTypeNoop
   336  	opts.CurrentSerializationType = codec.SerializationTypeNoop
   337  
   338  	sh := func(ss server.Stream) error {
   339  		ctx = ss.Context()
   340  		assert.NotNil(t, ctx)
   341  		err := ss.SendMsg(&codec.Body{Data: []byte("init")})
   342  		assert.Nil(t, err)
   343  		return err
   344  	}
   345  	rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init"))
   346  	assert.Nil(t, rsp)
   347  	assert.Equal(t, err, errs.ErrServerNoResponse)
   348  	time.Sleep(100 * time.Millisecond)
   349  
   350  	opts.CurrentCompressType = 5
   351  	opts.CurrentSerializationType = codec.SerializationTypeNoop
   352  	sh = func(ss server.Stream) error {
   353  		ctx = ss.Context()
   354  		assert.NotNil(t, ctx)
   355  		err := ss.SendMsg(&codec.Body{Data: []byte("init")})
   356  		assert.NotNil(t, err)
   357  		return err
   358  	}
   359  	dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init"))
   360  	time.Sleep(200 * time.Millisecond)
   361  
   362  	opts.CurrentCompressType = codec.CompressTypeNoop
   363  	opts.CurrentSerializationType = codec.SerializationTypeNoop
   364  	sh = func(ss server.Stream) error {
   365  		ctx = ss.Context()
   366  		assert.NotNil(t, ctx)
   367  		err := ss.SendMsg(&codec.Body{Data: []byte("encode-error")})
   368  		assert.Contains(t, err.Error(), "server codec Encode")
   369  		return err
   370  	}
   371  	dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init"))
   372  	time.Sleep(200 * time.Millisecond)
   373  
   374  	opts.CurrentCompressType = codec.CompressTypeNoop
   375  	opts.CurrentSerializationType = codec.SerializationTypeNoop
   376  	sh = func(ss server.Stream) error {
   377  		ctx = ss.Context()
   378  		assert.NotNil(t, ctx)
   379  		err := ss.SendMsg(&codec.Body{Data: []byte("init-error")})
   380  		return err
   381  	}
   382  	dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init"))
   383  	time.Sleep(200 * time.Millisecond)
   384  }
   385  
   386  // TestServerStreamRecvMsg test receive message
   387  func TestServerStreamRecvMsg(t *testing.T) {
   388  	codec.Register("fake", &fakeServerCodec{}, nil)
   389  
   390  	si := &server.StreamServerInfo{}
   391  	dispatcher := stream.NewStreamDispatcher()
   392  	assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher)
   393  
   394  	// Init test
   395  	opts := &server.Options{}
   396  	ft := &fakeServerTransport{}
   397  	opts.Transport = ft
   398  	opts.Codec = codec.GetServer("fake")
   399  	err := dispatcher.Init(opts)
   400  	assert.Nil(t, err)
   401  	assert.Equal(t, opts.Transport, opts.StreamTransport)
   402  
   403  	// StreamHandleFunc msg not nil
   404  	ctx := context.Background()
   405  	ctx, msg := codec.WithNewMessage(ctx)
   406  	fh := &trpc.FrameHead{}
   407  	msg.WithFrameHead(fh)
   408  	msg.WithStreamID(uint32(100))
   409  	msg.WithRemoteAddr(&fakeAddr{})
   410  	msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{})
   411  	opts.CurrentCompressType = codec.CompressTypeNoop
   412  	opts.CurrentSerializationType = codec.SerializationTypeNoop
   413  
   414  	sh := func(ss server.Stream) error {
   415  		ctx := ss.Context()
   416  		assert.NotNil(t, ctx)
   417  		body := &codec.Body{}
   418  		err := ss.RecvMsg(body)
   419  		assert.Nil(t, err)
   420  		assert.Equal(t, string(body.Data), "data")
   421  		err = ss.RecvMsg(body)
   422  		assert.Equal(t, err, io.EOF)
   423  
   424  		err = ss.RecvMsg(body)
   425  		assert.Equal(t, err, io.EOF)
   426  		return err
   427  	}
   428  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)
   429  	rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init"))
   430  	assert.Nil(t, rsp)
   431  	assert.Equal(t, err, errs.ErrServerNoResponse)
   432  	// handleData normal
   433  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)
   434  	msg.WithFrameHead(fh)
   435  	rsp, err = dispatcher.StreamHandleFunc(ctx, sh, si, []byte("data"))
   436  	assert.Nil(t, rsp)
   437  	assert.Equal(t, err, errs.ErrServerNoResponse)
   438  
   439  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE)
   440  	msg.WithFrameHead(fh)
   441  	rsp, err = dispatcher.StreamHandleFunc(ctx, sh, si, []byte("close"))
   442  	assert.Nil(t, rsp)
   443  	assert.Equal(t, err, errs.ErrServerNoResponse)
   444  
   445  	time.Sleep(300 * time.Millisecond)
   446  }
   447  
   448  // TestServerStreamRecvMsgFail test for failure to receive data
   449  func TestServerStreamRecvMsgFail(t *testing.T) {
   450  	codec.Register("fake", &fakeServerCodec{}, nil)
   451  	si := &server.StreamServerInfo{}
   452  	dispatcher := stream.NewStreamDispatcher()
   453  	assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher)
   454  	// Init test
   455  	opts := &server.Options{}
   456  	ft := &fakeServerTransport{}
   457  	opts.Transport = ft
   458  	opts.Codec = codec.GetServer("fake")
   459  	err := dispatcher.Init(opts)
   460  	assert.Nil(t, err)
   461  	assert.Equal(t, opts.Transport, opts.StreamTransport)
   462  
   463  	// StreamHandleFunc msg not nil
   464  	ctx := context.Background()
   465  	ctx, msg := codec.WithNewMessage(ctx)
   466  	fh := &trpc.FrameHead{}
   467  	msg.WithFrameHead(fh)
   468  	msg.WithStreamID(uint32(100))
   469  	msg.WithRemoteAddr(&fakeAddr{})
   470  	msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{})
   471  
   472  	opts.CurrentCompressType = codec.CompressTypeGzip
   473  	opts.CurrentSerializationType = codec.SerializationTypeNoop
   474  
   475  	sh := func(ss server.Stream) error {
   476  		ctx := ss.Context()
   477  		assert.NotNil(t, ctx)
   478  		body := &codec.Body{}
   479  		err := ss.RecvMsg(body)
   480  		assert.NotNil(t, err)
   481  		assert.Contains(t, err.Error(), "server codec Decompress")
   482  
   483  		err = ss.RecvMsg(body)
   484  		assert.NotNil(t, err)
   485  		assert.Contains(t, err.Error(), "server codec Unmarshal")
   486  		return err
   487  	}
   488  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)
   489  	msg.WithFrameHead(fh)
   490  	rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init"))
   491  	assert.Nil(t, rsp)
   492  	assert.Equal(t, err, errs.ErrServerNoResponse)
   493  	// handleData normal
   494  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)
   495  	msg.WithFrameHead(fh)
   496  	rsp, err = dispatcher.StreamHandleFunc(ctx, sh, si, []byte("data"))
   497  	assert.Nil(t, rsp)
   498  	assert.Equal(t, err, errs.ErrServerNoResponse)
   499  }
   500  
   501  // TesthandleError test server error condition
   502  func TestHandleError(t *testing.T) {
   503  	codec.Register("fake", &fakeServerCodec{}, nil)
   504  	si := &server.StreamServerInfo{}
   505  	dispatcher := stream.NewStreamDispatcher()
   506  	assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher)
   507  	// Init test
   508  	opts := &server.Options{}
   509  	ft := &fakeServerTransport{}
   510  	opts.Transport = ft
   511  	opts.Codec = codec.GetServer("fake")
   512  	err := dispatcher.Init(opts)
   513  	assert.Nil(t, err)
   514  	assert.Equal(t, opts.Transport, opts.StreamTransport)
   515  
   516  	// StreamHandleFunc msg not nil
   517  	ctx := context.Background()
   518  	ctx, msg := codec.WithNewMessage(ctx)
   519  	fh := &trpc.FrameHead{}
   520  	msg.WithFrameHead(fh)
   521  	msg.WithStreamID(uint32(100))
   522  	msg.WithRemoteAddr(&fakeAddr{})
   523  	msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{})
   524  
   525  	opts.CurrentCompressType = codec.CompressTypeGzip
   526  	opts.CurrentSerializationType = codec.SerializationTypeNoop
   527  
   528  	sh := func(ss server.Stream) error {
   529  		ctx := ss.Context()
   530  		assert.NotNil(t, ctx)
   531  		body := &codec.Body{}
   532  		err := ss.RecvMsg(body)
   533  		assert.NotNil(t, err)
   534  		assert.Contains(t, err.Error(), "Connection is closed")
   535  		return err
   536  	}
   537  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)
   538  	rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init"))
   539  	assert.Nil(t, rsp)
   540  	assert.Equal(t, err, errs.ErrServerNoResponse)
   541  	// handleError
   542  	msg.WithFrameHead(nil)
   543  	msg.WithServerRspErr(errors.New("Connection is closed"))
   544  
   545  	noopSh := func(ss server.Stream) error {
   546  		return nil
   547  	}
   548  	msg.WithFrameHead(fh)
   549  	rsp, err = dispatcher.StreamHandleFunc(ctx, noopSh, si, nil)
   550  	assert.Nil(t, rsp)
   551  	assert.Equal(t, err, errs.ErrServerNoResponse)
   552  	time.Sleep(100 * time.Millisecond)
   553  }
   554  
   555  // TestStreamDispatcherHandleFeedback test handles feedback frame
   556  func TestStreamDispatcherHandleFeedback(t *testing.T) {
   557  
   558  	codec.Register("fake", &fakeServerCodec{}, nil)
   559  	si := &server.StreamServerInfo{}
   560  
   561  	dispatcher := stream.NewStreamDispatcher()
   562  	assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher)
   563  
   564  	// Init test
   565  	opts := &server.Options{}
   566  	ft := &fakeServerTransport{}
   567  	opts.Transport = ft
   568  	opts.Codec = codec.GetServer("fake")
   569  	err := dispatcher.Init(opts)
   570  	assert.Nil(t, err)
   571  	assert.Equal(t, opts.Transport, opts.StreamTransport)
   572  
   573  	ctx := context.Background()
   574  	ctx, msg := codec.WithNewMessage(ctx)
   575  	fh := &trpc.FrameHead{}
   576  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)
   577  	msg.WithFrameHead(fh)
   578  	msg.WithStreamID(uint32(100))
   579  	msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{InitWindowSize: 10})
   580  
   581  	sh := func(ss server.Stream) error {
   582  		time.Sleep(time.Second)
   583  		return nil
   584  	}
   585  
   586  	addr := &fakeAddr{}
   587  	msg.WithRemoteAddr(addr)
   588  	rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init"))
   589  	assert.Nil(t, rsp)
   590  	assert.Equal(t, err, errs.ErrServerNoResponse)
   591  
   592  	// handle feedback get server stream fail
   593  	msg.WithRemoteAddr(nil)
   594  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK)
   595  	msg.WithFrameHead(fh)
   596  	rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback"))
   597  	assert.Nil(t, rsp)
   598  	assert.NotNil(t, err)
   599  
   600  	// handle feedback invalid stream
   601  	msg.WithRemoteAddr(addr)
   602  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK)
   603  	msg.WithFrameHead(fh)
   604  	rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback"))
   605  	assert.Nil(t, rsp)
   606  	assert.NotNil(t, err)
   607  
   608  	// normal feedback
   609  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK)
   610  	msg.WithFrameHead(fh)
   611  	msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: 1000})
   612  	rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback"))
   613  	assert.Nil(t, rsp)
   614  	assert.Equal(t, err, errs.ErrServerNoResponse)
   615  }
   616  
   617  // TestServerFlowControl tests the situation of server-side flow control
   618  func TestServerFlowControl(t *testing.T) {
   619  	codec.Register("fake", &fakeServerCodec{}, nil)
   620  	si := &server.StreamServerInfo{}
   621  	dispatcher := stream.NewStreamDispatcher()
   622  	// Init test
   623  	opts := &server.Options{}
   624  	ft := &fakeServerTransport{}
   625  	opts.Transport = ft
   626  	opts.Codec = codec.GetServer("fake")
   627  	err := dispatcher.Init(opts)
   628  	assert.Nil(t, err)
   629  	assert.Equal(t, opts.Transport, opts.StreamTransport)
   630  	// StreamHandleFunc msg not nil
   631  	ctx := context.Background()
   632  	ctx, msg := codec.WithNewMessage(ctx)
   633  	fh := &trpc.FrameHead{}
   634  	msg.WithFrameHead(fh)
   635  	msg.WithStreamID(uint32(100))
   636  	addr := &fakeAddr{}
   637  	msg.WithRemoteAddr(addr)
   638  	msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{InitWindowSize: 65535})
   639  	opts.CurrentCompressType = codec.CompressTypeNoop
   640  	opts.CurrentSerializationType = codec.SerializationTypeNoop
   641  	var wg sync.WaitGroup
   642  	wg.Add(1)
   643  	sh := func(ss server.Stream) error {
   644  		defer wg.Done()
   645  		for i := 0; i < 20000; i++ {
   646  			body := &codec.Body{}
   647  			err := ss.RecvMsg(body)
   648  			assert.Nil(t, err)
   649  			assert.Equal(t, string(body.Data), "data")
   650  		}
   651  		return nil
   652  	}
   653  	fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)
   654  	rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init"))
   655  	assert.Nil(t, rsp)
   656  	assert.Equal(t, err, errs.ErrServerNoResponse)
   657  
   658  	// handleData normal
   659  	for i := 0; i < 20000; i++ {
   660  		newCtx := context.Background()
   661  		newCtx, newMsg := codec.WithNewMessage(newCtx)
   662  		newMsg.WithStreamID(uint32(100))
   663  		newMsg.WithRemoteAddr(addr)
   664  		newFh := &trpc.FrameHead{}
   665  		newFh.StreamID = uint32(100)
   666  		newFh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)
   667  		newMsg.WithFrameHead(newFh)
   668  		rsp, err := dispatcher.StreamHandleFunc(newCtx, sh, si, []byte("data"))
   669  		assert.Nil(t, rsp)
   670  		assert.Equal(t, err, errs.ErrServerNoResponse)
   671  	}
   672  	wg.Wait()
   673  }
   674  
   675  func TestClientStreamFlowControl(t *testing.T) {
   676  	svrOpts := []server.Option{server.WithAddress("127.0.0.1:30210")}
   677  	handle := func(s server.Stream) error {
   678  		req := getBytes(1024)
   679  		for i := 0; i < 1000; i++ {
   680  			err := s.RecvMsg(req)
   681  			assert.Nil(t, err)
   682  		}
   683  		err := s.RecvMsg(req)
   684  		assert.Equal(t, io.EOF, err)
   685  
   686  		rsp := getBytes(1024)
   687  		copy(rsp.Data, req.Data)
   688  		for i := 0; i < 1000; i++ {
   689  			err = s.SendMsg(rsp)
   690  			assert.Nil(t, err)
   691  		}
   692  		return nil
   693  	}
   694  	svr := startStreamServer(handle, svrOpts)
   695  	defer closeStreamServer(svr)
   696  
   697  	cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30210")}
   698  	cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts)
   699  	assert.Nil(t, err)
   700  
   701  	req := getBytes(1024)
   702  	rand.Read(req.Data)
   703  	for i := 0; i < 1000; i++ {
   704  		err = cliStream.SendMsg(req)
   705  		assert.Nil(t, err)
   706  	}
   707  	err = cliStream.CloseSend()
   708  	assert.Nil(t, err)
   709  	rsp := getBytes(1024)
   710  	for i := 0; i < 1000; i++ {
   711  		err = cliStream.RecvMsg(rsp)
   712  		assert.Nil(t, err)
   713  		assert.Equal(t, req, rsp)
   714  	}
   715  	err = cliStream.RecvMsg(rsp)
   716  	assert.Equal(t, io.EOF, err)
   717  }
   718  
   719  func TestServerStreamFlowControl(t *testing.T) {
   720  	svrOpts := []server.Option{server.WithAddress("127.0.0.1:30211")}
   721  	handle := func(s server.Stream) error {
   722  		req := getBytes(1024)
   723  		err := s.RecvMsg(req)
   724  		assert.Nil(t, err)
   725  
   726  		rsp := getBytes(1024)
   727  		copy(rsp.Data, req.Data)
   728  		for i := 0; i < 1000; i++ {
   729  			err := s.SendMsg(rsp)
   730  			assert.Nil(t, err)
   731  		}
   732  		return nil
   733  	}
   734  	svr := startStreamServer(handle, svrOpts)
   735  	defer closeStreamServer(svr)
   736  
   737  	cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30211")}
   738  	cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts)
   739  	assert.Nil(t, err)
   740  
   741  	req := getBytes(1024)
   742  	rand.Read(req.Data)
   743  	err = cliStream.SendMsg(req)
   744  	assert.Nil(t, err)
   745  	err = cliStream.CloseSend()
   746  	assert.Nil(t, err)
   747  	rsp := getBytes(1024)
   748  	for i := 0; i < 1000; i++ {
   749  		err = cliStream.RecvMsg(rsp)
   750  		assert.Nil(t, err)
   751  		assert.Equal(t, req, rsp)
   752  	}
   753  	err = cliStream.RecvMsg(rsp)
   754  	assert.Equal(t, err, io.EOF)
   755  }
   756  
   757  func startStreamServer(handle func(server.Stream) error, opts []server.Option) server.Service {
   758  	svrOpts := []server.Option{
   759  		server.WithProtocol("trpc"),
   760  		server.WithNetwork("tcp"),
   761  		server.WithStreamTransport(transport.NewServerStreamTransport(transport.WithReusePort(true))),
   762  		server.WithTransport(transport.NewServerStreamTransport(transport.WithReusePort(true))),
   763  		// The server must actively set the serialization method
   764  		server.WithCurrentSerializationType(codec.SerializationTypeNoop),
   765  	}
   766  	svrOpts = append(svrOpts, opts...)
   767  	svr := server.New(svrOpts...)
   768  	register(svr, handle)
   769  	go func() {
   770  		err := svr.Serve()
   771  		if err != nil {
   772  			panic(err)
   773  		}
   774  	}()
   775  	time.Sleep(100 * time.Millisecond)
   776  	return svr
   777  }
   778  
   779  func closeStreamServer(svr server.Service) {
   780  	ch := make(chan struct{}, 1)
   781  	svr.Close(ch)
   782  	<-ch
   783  }
   784  
   785  var (
   786  	clientDesc = &client.ClientStreamDesc{
   787  		StreamName:    "streamTest",
   788  		ClientStreams: true,
   789  		ServerStreams: false,
   790  	}
   791  	serverDesc = &client.ClientStreamDesc{
   792  		StreamName:    "streamTest",
   793  		ClientStreams: false,
   794  		ServerStreams: true,
   795  	}
   796  	bidiDesc = &client.ClientStreamDesc{
   797  		StreamName:    "streamTest",
   798  		ClientStreams: true,
   799  		ServerStreams: true,
   800  	}
   801  )
   802  
   803  func getClientStream(ctx context.Context, desc *client.ClientStreamDesc, opts []client.Option) (client.ClientStream, error) {
   804  	cli := stream.NewStreamClient()
   805  	method := "/trpc.test.stream.Greeter/StreamSayHello"
   806  	cliOpts := []client.Option{
   807  		client.WithProtocol("trpc"),
   808  		client.WithTransport(transport.NewClientTransport()),
   809  		client.WithStreamTransport(transport.NewClientStreamTransport()),
   810  		client.WithCurrentSerializationType(codec.SerializationTypeNoop),
   811  	}
   812  	cliOpts = append(cliOpts, opts...)
   813  	return cli.NewStream(ctx, desc, method, cliOpts...)
   814  }
   815  
   816  func register(s server.Service, f func(server.Stream) error) {
   817  	svr := &greeterServiceImpl{f: f}
   818  	if err := s.Register(&GreeterServer_ServiceDesc, svr); err != nil {
   819  		panic(fmt.Sprintf("Greeter register error: %v", err))
   820  	}
   821  }
   822  
   823  type greeterServiceImpl struct {
   824  	f func(server.Stream) error
   825  }
   826  
   827  func (s *greeterServiceImpl) BidiStreamSayHello(stream server.Stream) error {
   828  	return s.f(stream)
   829  }
   830  
   831  func GreeterService_BidiStreamSayHello_Handler(srv interface{}, stream server.Stream) error {
   832  	return srv.(GreeterService).BidiStreamSayHello(stream)
   833  }
   834  
   835  type GreeterService interface {
   836  	// BidiStreamSayHello Bidi streaming
   837  	BidiStreamSayHello(server.Stream) error
   838  }
   839  
   840  var GreeterServer_ServiceDesc = server.ServiceDesc{
   841  	ServiceName:  "trpc.test.stream.Greeter",
   842  	HandlerType:  (*GreeterService)(nil),
   843  	StreamHandle: stream.NewStreamDispatcher(),
   844  	Streams: []server.StreamDesc{
   845  		{
   846  			StreamName:    "/trpc.test.stream.Greeter/StreamSayHello",
   847  			Handler:       GreeterService_BidiStreamSayHello_Handler,
   848  			ServerStreams: true,
   849  		},
   850  	},
   851  }
   852  
   853  func getBytes(size int) *codec.Body {
   854  	return &codec.Body{Data: make([]byte, size)}
   855  }
   856  
   857  /* --------------- Filter Unit Test -------------*/
   858  
   859  type wrappedServerStream struct {
   860  	server.Stream
   861  }
   862  
   863  func newWrappedServerStream(s server.Stream) server.Stream {
   864  	return &wrappedServerStream{s}
   865  }
   866  
   867  func (w *wrappedServerStream) RecvMsg(m interface{}) error {
   868  	err := w.Stream.RecvMsg(m)
   869  	num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8])
   870  	binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1)
   871  	return err
   872  }
   873  
   874  func (w *wrappedServerStream) SendMsg(m interface{}) error {
   875  	num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8])
   876  	binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1)
   877  	return w.Stream.SendMsg(m)
   878  }
   879  
   880  var (
   881  	testKey1 = "hello"
   882  	testKey2 = "ping"
   883  	testData = map[string][]byte{
   884  		testKey1: []byte("world"),
   885  		testKey2: []byte("pong"),
   886  	}
   887  )
   888  
   889  func serverFilterAdd1(ss server.Stream, si *server.StreamServerInfo,
   890  	handler server.StreamHandler) error {
   891  	msg := trpc.Message(ss.Context())
   892  	meta := msg.ServerMetaData()
   893  	if v, ok := meta[testKey1]; !ok {
   894  		return errors.New("meta not exist")
   895  	} else if !bytes.Equal(v, testData[testKey1]) {
   896  		return errors.New("meta not match")
   897  	}
   898  	err := handler(newWrappedServerStream(ss))
   899  	return err
   900  }
   901  
   902  func serverFilterAdd2(ss server.Stream, si *server.StreamServerInfo,
   903  	handler server.StreamHandler) error {
   904  	msg := trpc.Message(ss.Context())
   905  	meta := msg.ServerMetaData()
   906  	if v, ok := meta[testKey2]; !ok {
   907  		return errors.New("meta not exist")
   908  	} else if !bytes.Equal(v, testData[testKey2]) {
   909  		return errors.New("meta not match")
   910  	}
   911  	err := handler(newWrappedServerStream(ss))
   912  	return err
   913  }