trpc.group/trpc-go/trpc-go@v1.0.2/client/stream_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package client_test
    15  
    16  import (
    17  	"context"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/stretchr/testify/require"
    22  
    23  	"trpc.group/trpc-go/trpc-go/client"
    24  	"trpc.group/trpc-go/trpc-go/codec"
    25  	"trpc.group/trpc-go/trpc-go/naming/registry"
    26  
    27  	_ "trpc.group/trpc-go/trpc-go"
    28  )
    29  
    30  // TestStream tests client stream.
    31  func TestStream(t *testing.T) {
    32  	ctx := context.Background()
    33  	reqBody := &codec.Body{}
    34  	codec.RegisterSerializer(0, &codec.NoopSerialization{})
    35  	codec.Register("fake", nil, &fakeCodec{})
    36  	codec.Register("fake-nil", nil, nil)
    37  
    38  	// calling without error
    39  	streamCli := client.NewStream()
    40  	require.NotNil(t, streamCli)
    41  	opts, err := streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"),
    42  		client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop),
    43  		client.WithStreamTransport(&fakeTransport{}), client.WithProtocol("fake"))
    44  	require.Nil(t, err)
    45  	require.NotNil(t, opts)
    46  	err = streamCli.Invoke(ctx)
    47  	require.Nil(t, err)
    48  	err = streamCli.Send(ctx, reqBody)
    49  	require.Nil(t, err)
    50  	rsp, err := streamCli.Recv(ctx)
    51  	require.Nil(t, err)
    52  	require.Equal(t, []byte("body"), rsp)
    53  	err = streamCli.Close(ctx)
    54  	require.Nil(t, err)
    55  
    56  	// test nil Codec
    57  	opts, err = streamCli.Init(ctx,
    58  		client.WithTarget("ip://127.0.0.1:8080"),
    59  		client.WithTimeout(time.Second),
    60  		client.WithProtocol("fake-nil"),
    61  		client.WithSerializationType(codec.SerializationTypeNoop),
    62  		client.WithStreamTransport(&fakeTransport{}))
    63  	require.NotNil(t, err)
    64  	require.Nil(t, opts)
    65  	err = streamCli.Invoke(ctx)
    66  	require.Nil(t, err)
    67  
    68  	// test selectNode with error
    69  	opts, err = streamCli.Init(ctx, client.WithTarget("ip/:/127.0.0.1:8080"),
    70  		client.WithProtocol("fake"))
    71  	require.NotNil(t, err)
    72  	require.Contains(t, err.Error(), "invalid")
    73  	require.Nil(t, opts)
    74  
    75  	// test stream recv failure
    76  	ctx = context.WithValue(ctx, "recv-error", "recv failed")
    77  	opts, err = streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"),
    78  		client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop),
    79  		client.WithStreamTransport(&fakeTransport{}), client.WithProtocol("fake"))
    80  	require.Nil(t, err)
    81  	require.NotNil(t, opts)
    82  	err = streamCli.Invoke(ctx)
    83  	require.Nil(t, err)
    84  	rsp, err = streamCli.Recv(ctx)
    85  	require.Nil(t, rsp)
    86  	require.NotNil(t, err)
    87  
    88  	// test decode failure
    89  	ctx = context.WithValue(ctx, "recv-decode-error", "businessfail")
    90  	rsp, err = streamCli.Recv(ctx)
    91  	require.Nil(t, rsp)
    92  	require.NotNil(t, err)
    93  
    94  	// test compress failure
    95  	ctx = context.Background()
    96  	opts, err = streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"),
    97  		client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop),
    98  		client.WithStreamTransport(&fakeTransport{}), client.WithCurrentCompressType(codec.CompressTypeGzip),
    99  		client.WithProtocol("fake"))
   100  	require.Nil(t, err)
   101  	require.NotNil(t, opts)
   102  	err = streamCli.Invoke(ctx)
   103  	require.Nil(t, err)
   104  	_, err = streamCli.Recv(ctx)
   105  	require.NotNil(t, err)
   106  
   107  	// test compress without error
   108  	opts, err = streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"),
   109  		client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop),
   110  		client.WithStreamTransport(&fakeTransport{}), client.WithCurrentCompressType(codec.CompressTypeNoop),
   111  		client.WithProtocol("fake"))
   112  	require.Nil(t, err)
   113  	require.NotNil(t, opts)
   114  	err = streamCli.Invoke(ctx)
   115  	require.Nil(t, err)
   116  	rsp, err = streamCli.Recv(ctx)
   117  	require.Nil(t, err)
   118  	require.NotNil(t, rsp)
   119  }
   120  
   121  func TestGetStreamFilter(t *testing.T) {
   122  	type noopClientStream struct {
   123  		client.ClientStream
   124  	}
   125  	testClientStream := &noopClientStream{}
   126  	testFilter := func(ctx context.Context, desc *client.ClientStreamDesc,
   127  		streamer client.Streamer) (client.ClientStream, error) {
   128  		return testClientStream, nil
   129  	}
   130  	client.RegisterStreamFilter("testFilter", testFilter)
   131  	filter := client.GetStreamFilter("testFilter")
   132  	cs, err := filter(context.Background(), &client.ClientStreamDesc{}, nil)
   133  	require.Nil(t, err)
   134  	require.Equal(t, testClientStream, cs)
   135  }
   136  
   137  func TestStreamGetAddress(t *testing.T) {
   138  	s := client.NewStream()
   139  	require.NotNil(t, s)
   140  	ctx, msg := codec.EnsureMessage(context.Background())
   141  	node := &registry.Node{}
   142  	const addr = "127.0.0.1:8000"
   143  	opts, err := s.Init(ctx,
   144  		client.WithTarget("ip://"+addr),
   145  		client.WithTimeout(time.Second),
   146  		client.WithSelectorNode(node),
   147  	)
   148  	require.Nil(t, err)
   149  	require.NotNil(t, opts)
   150  	require.Equal(t, addr, node.Address)
   151  	require.NotNil(t, msg.RemoteAddr())
   152  	require.Equal(t, addr, msg.RemoteAddr().String())
   153  }