trpc.group/trpc-go/trpc-go@v1.0.3/client/stream_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package client_test
    15  
    16  import (
    17  	"context"
    18  	"errors"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/stretchr/testify/require"
    23  
    24  	"trpc.group/trpc-go/trpc-go/client"
    25  	"trpc.group/trpc-go/trpc-go/codec"
    26  	"trpc.group/trpc-go/trpc-go/naming/registry"
    27  
    28  	_ "trpc.group/trpc-go/trpc-go"
    29  )
    30  
    31  // TestStream tests client stream.
    32  func TestStream(t *testing.T) {
    33  	ctx := context.Background()
    34  	reqBody := &codec.Body{}
    35  	codec.RegisterSerializer(0, &codec.NoopSerialization{})
    36  	codec.Register("fake", nil, &fakeCodec{})
    37  	codec.Register("fake-nil", nil, nil)
    38  
    39  	// calling without error
    40  	streamCli := client.NewStream()
    41  	t.Run("calling without error", func(t *testing.T) {
    42  		require.NotNil(t, streamCli)
    43  		opts, err := streamCli.Init(ctx,
    44  			client.WithTarget("ip://127.0.0.1:8000"),
    45  			client.WithTimeout(time.Second),
    46  			client.WithSerializationType(codec.SerializationTypeNoop),
    47  			client.WithStreamTransport(&fakeTransport{}),
    48  			client.WithProtocol("fake"),
    49  		)
    50  		require.Nil(t, err)
    51  		require.NotNil(t, opts)
    52  		err = streamCli.Invoke(ctx)
    53  		require.Nil(t, err)
    54  		err = streamCli.Send(ctx, reqBody)
    55  		require.Nil(t, err)
    56  		rsp, err := streamCli.Recv(ctx)
    57  		require.Nil(t, err)
    58  		require.Equal(t, []byte("body"), rsp)
    59  		err = streamCli.Close(ctx)
    60  		require.Nil(t, err)
    61  	})
    62  
    63  	t.Run("test nil Codec", func(t *testing.T) {
    64  		opts, err := streamCli.Init(ctx,
    65  			client.WithTarget("ip://127.0.0.1:8080"),
    66  			client.WithTimeout(time.Second),
    67  			client.WithProtocol("fake-nil"),
    68  			client.WithSerializationType(codec.SerializationTypeNoop),
    69  			client.WithStreamTransport(&fakeTransport{}))
    70  		require.NotNil(t, err)
    71  		require.Nil(t, opts)
    72  		err = streamCli.Invoke(ctx)
    73  		require.Nil(t, err)
    74  	})
    75  
    76  	t.Run("test selectNode with error", func(t *testing.T) {
    77  		opts, err := streamCli.Init(ctx,
    78  			client.WithTarget("ip/:/127.0.0.1:8080"),
    79  			client.WithProtocol("fake"),
    80  		)
    81  		require.NotNil(t, err)
    82  		require.Contains(t, err.Error(), "invalid")
    83  		require.Nil(t, opts)
    84  	})
    85  
    86  	t.Run("test stream recv failure", func(t *testing.T) {
    87  		opts, err := streamCli.Init(ctx,
    88  			client.WithTarget("ip://127.0.0.1:8000"),
    89  			client.WithTimeout(time.Second),
    90  			client.WithSerializationType(codec.SerializationTypeNoop),
    91  			client.WithStreamTransport(&fakeTransport{
    92  				recv: func() ([]byte, error) {
    93  					return nil, errors.New("recv failed")
    94  				},
    95  			}),
    96  			client.WithProtocol("fake"),
    97  		)
    98  		require.Nil(t, err)
    99  		require.NotNil(t, opts)
   100  		err = streamCli.Invoke(ctx)
   101  		require.Nil(t, err)
   102  		rsp, err := streamCli.Recv(ctx)
   103  		require.Nil(t, rsp)
   104  		require.NotNil(t, err)
   105  	})
   106  
   107  	t.Run("test decode failure", func(t *testing.T) {
   108  		_, err := streamCli.Init(ctx,
   109  			client.WithTarget("ip://127.0.0.1:8000"),
   110  			client.WithTimeout(time.Second),
   111  			client.WithSerializationType(codec.SerializationTypeNoop),
   112  			client.WithStreamTransport(&fakeTransport{
   113  				recv: func() ([]byte, error) {
   114  					return []byte("businessfail"), nil
   115  				},
   116  			}),
   117  			client.WithProtocol("fake"),
   118  		)
   119  		require.Nil(t, err)
   120  		rsp, err := streamCli.Recv(ctx)
   121  		require.Nil(t, rsp)
   122  		require.NotNil(t, err)
   123  	})
   124  
   125  	t.Run("test compress failure", func(t *testing.T) {
   126  		opts, err := streamCli.Init(context.Background(),
   127  			client.WithTarget("ip://127.0.0.1:8000"),
   128  			client.WithTimeout(time.Second),
   129  			client.WithSerializationType(codec.SerializationTypeNoop),
   130  			client.WithStreamTransport(&fakeTransport{}),
   131  			client.WithCurrentCompressType(codec.CompressTypeGzip),
   132  			client.WithProtocol("fake"))
   133  		require.Nil(t, err)
   134  		require.NotNil(t, opts)
   135  		err = streamCli.Invoke(ctx)
   136  		require.Nil(t, err)
   137  		_, err = streamCli.Recv(ctx)
   138  		require.NotNil(t, err)
   139  	})
   140  
   141  	t.Run("test compress without error", func(t *testing.T) {
   142  		opts, err := streamCli.Init(ctx,
   143  			client.WithTarget("ip://127.0.0.1:8000"),
   144  			client.WithTimeout(time.Second),
   145  			client.WithSerializationType(codec.SerializationTypeNoop),
   146  			client.WithStreamTransport(&fakeTransport{}),
   147  			client.WithCurrentCompressType(codec.CompressTypeNoop),
   148  			client.WithProtocol("fake"),
   149  		)
   150  		require.Nil(t, err)
   151  		require.NotNil(t, opts)
   152  		err = streamCli.Invoke(ctx)
   153  		require.Nil(t, err)
   154  		rsp, err := streamCli.Recv(ctx)
   155  		require.Nil(t, err)
   156  		require.NotNil(t, rsp)
   157  	})
   158  }
   159  
   160  func TestGetStreamFilter(t *testing.T) {
   161  	type noopClientStream struct {
   162  		client.ClientStream
   163  	}
   164  	testClientStream := &noopClientStream{}
   165  	testFilter := func(ctx context.Context, desc *client.ClientStreamDesc,
   166  		streamer client.Streamer) (client.ClientStream, error) {
   167  		return testClientStream, nil
   168  	}
   169  	client.RegisterStreamFilter("testFilter", testFilter)
   170  	filter := client.GetStreamFilter("testFilter")
   171  	cs, err := filter(context.Background(), &client.ClientStreamDesc{}, nil)
   172  	require.Nil(t, err)
   173  	require.Equal(t, testClientStream, cs)
   174  }
   175  
   176  func TestStreamGetAddress(t *testing.T) {
   177  	s := client.NewStream()
   178  	require.NotNil(t, s)
   179  	ctx, msg := codec.EnsureMessage(context.Background())
   180  	node := &registry.Node{}
   181  	const addr = "127.0.0.1:8000"
   182  	opts, err := s.Init(ctx,
   183  		client.WithTarget("ip://"+addr),
   184  		client.WithTimeout(time.Second),
   185  		client.WithSelectorNode(node),
   186  	)
   187  	require.Nil(t, err)
   188  	require.NotNil(t, opts)
   189  	require.Equal(t, addr, node.Address)
   190  	require.NotNil(t, msg.RemoteAddr())
   191  	require.Equal(t, addr, msg.RemoteAddr().String())
   192  }
   193  
   194  func TestStreamCloseTransport(t *testing.T) {
   195  	codec.Register("fake", nil, &fakeCodec{})
   196  	t.Run("close transport when send fail", func(t *testing.T) {
   197  		var isClose bool
   198  		streamCli := client.NewStream()
   199  		_, err := streamCli.Init(context.Background(),
   200  			client.WithTarget("ip://127.0.0.1:8000"),
   201  			client.WithStreamTransport(&fakeTransport{
   202  				send: func() error {
   203  					return errors.New("expected error")
   204  				},
   205  				close: func() {
   206  					isClose = true
   207  				},
   208  			}),
   209  			client.WithProtocol("fake"),
   210  		)
   211  		require.Nil(t, err)
   212  		require.NotNil(t, streamCli.Send(context.Background(), nil))
   213  		require.True(t, isClose)
   214  	})
   215  	t.Run("close transport when recv fail", func(t *testing.T) {
   216  		var isClose bool
   217  		streamCli := client.NewStream()
   218  		_, err := streamCli.Init(context.Background(),
   219  			client.WithTarget("ip://127.0.0.1:8000"),
   220  			client.WithStreamTransport(&fakeTransport{
   221  				recv: func() ([]byte, error) {
   222  					return nil, errors.New("expected error")
   223  				},
   224  				close: func() {
   225  					isClose = true
   226  				},
   227  			}),
   228  			client.WithProtocol("fake"),
   229  		)
   230  		require.Nil(t, err)
   231  		_, err = streamCli.Recv(context.Background())
   232  		require.NotNil(t, err)
   233  		require.True(t, isClose)
   234  	})
   235  }