trpc.group/trpc-go/trpc-go@v1.0.3/client/stream_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package client_test 15 16 import ( 17 "context" 18 "errors" 19 "testing" 20 "time" 21 22 "github.com/stretchr/testify/require" 23 24 "trpc.group/trpc-go/trpc-go/client" 25 "trpc.group/trpc-go/trpc-go/codec" 26 "trpc.group/trpc-go/trpc-go/naming/registry" 27 28 _ "trpc.group/trpc-go/trpc-go" 29 ) 30 31 // TestStream tests client stream. 32 func TestStream(t *testing.T) { 33 ctx := context.Background() 34 reqBody := &codec.Body{} 35 codec.RegisterSerializer(0, &codec.NoopSerialization{}) 36 codec.Register("fake", nil, &fakeCodec{}) 37 codec.Register("fake-nil", nil, nil) 38 39 // calling without error 40 streamCli := client.NewStream() 41 t.Run("calling without error", func(t *testing.T) { 42 require.NotNil(t, streamCli) 43 opts, err := streamCli.Init(ctx, 44 client.WithTarget("ip://127.0.0.1:8000"), 45 client.WithTimeout(time.Second), 46 client.WithSerializationType(codec.SerializationTypeNoop), 47 client.WithStreamTransport(&fakeTransport{}), 48 client.WithProtocol("fake"), 49 ) 50 require.Nil(t, err) 51 require.NotNil(t, opts) 52 err = streamCli.Invoke(ctx) 53 require.Nil(t, err) 54 err = streamCli.Send(ctx, reqBody) 55 require.Nil(t, err) 56 rsp, err := streamCli.Recv(ctx) 57 require.Nil(t, err) 58 require.Equal(t, []byte("body"), rsp) 59 err = streamCli.Close(ctx) 60 require.Nil(t, err) 61 }) 62 63 t.Run("test nil Codec", func(t *testing.T) { 64 opts, err := streamCli.Init(ctx, 65 client.WithTarget("ip://127.0.0.1:8080"), 66 client.WithTimeout(time.Second), 67 client.WithProtocol("fake-nil"), 68 client.WithSerializationType(codec.SerializationTypeNoop), 69 client.WithStreamTransport(&fakeTransport{})) 70 require.NotNil(t, err) 71 require.Nil(t, opts) 72 err = streamCli.Invoke(ctx) 73 require.Nil(t, err) 74 }) 75 76 t.Run("test selectNode with error", func(t *testing.T) { 77 opts, err := streamCli.Init(ctx, 78 client.WithTarget("ip/:/127.0.0.1:8080"), 79 client.WithProtocol("fake"), 80 ) 81 require.NotNil(t, err) 82 require.Contains(t, err.Error(), "invalid") 83 require.Nil(t, opts) 84 }) 85 86 t.Run("test stream recv failure", func(t *testing.T) { 87 opts, err := streamCli.Init(ctx, 88 client.WithTarget("ip://127.0.0.1:8000"), 89 client.WithTimeout(time.Second), 90 client.WithSerializationType(codec.SerializationTypeNoop), 91 client.WithStreamTransport(&fakeTransport{ 92 recv: func() ([]byte, error) { 93 return nil, errors.New("recv failed") 94 }, 95 }), 96 client.WithProtocol("fake"), 97 ) 98 require.Nil(t, err) 99 require.NotNil(t, opts) 100 err = streamCli.Invoke(ctx) 101 require.Nil(t, err) 102 rsp, err := streamCli.Recv(ctx) 103 require.Nil(t, rsp) 104 require.NotNil(t, err) 105 }) 106 107 t.Run("test decode failure", func(t *testing.T) { 108 _, err := streamCli.Init(ctx, 109 client.WithTarget("ip://127.0.0.1:8000"), 110 client.WithTimeout(time.Second), 111 client.WithSerializationType(codec.SerializationTypeNoop), 112 client.WithStreamTransport(&fakeTransport{ 113 recv: func() ([]byte, error) { 114 return []byte("businessfail"), nil 115 }, 116 }), 117 client.WithProtocol("fake"), 118 ) 119 require.Nil(t, err) 120 rsp, err := streamCli.Recv(ctx) 121 require.Nil(t, rsp) 122 require.NotNil(t, err) 123 }) 124 125 t.Run("test compress failure", func(t *testing.T) { 126 opts, err := streamCli.Init(context.Background(), 127 client.WithTarget("ip://127.0.0.1:8000"), 128 client.WithTimeout(time.Second), 129 client.WithSerializationType(codec.SerializationTypeNoop), 130 client.WithStreamTransport(&fakeTransport{}), 131 client.WithCurrentCompressType(codec.CompressTypeGzip), 132 client.WithProtocol("fake")) 133 require.Nil(t, err) 134 require.NotNil(t, opts) 135 err = streamCli.Invoke(ctx) 136 require.Nil(t, err) 137 _, err = streamCli.Recv(ctx) 138 require.NotNil(t, err) 139 }) 140 141 t.Run("test compress without error", func(t *testing.T) { 142 opts, err := streamCli.Init(ctx, 143 client.WithTarget("ip://127.0.0.1:8000"), 144 client.WithTimeout(time.Second), 145 client.WithSerializationType(codec.SerializationTypeNoop), 146 client.WithStreamTransport(&fakeTransport{}), 147 client.WithCurrentCompressType(codec.CompressTypeNoop), 148 client.WithProtocol("fake"), 149 ) 150 require.Nil(t, err) 151 require.NotNil(t, opts) 152 err = streamCli.Invoke(ctx) 153 require.Nil(t, err) 154 rsp, err := streamCli.Recv(ctx) 155 require.Nil(t, err) 156 require.NotNil(t, rsp) 157 }) 158 } 159 160 func TestGetStreamFilter(t *testing.T) { 161 type noopClientStream struct { 162 client.ClientStream 163 } 164 testClientStream := &noopClientStream{} 165 testFilter := func(ctx context.Context, desc *client.ClientStreamDesc, 166 streamer client.Streamer) (client.ClientStream, error) { 167 return testClientStream, nil 168 } 169 client.RegisterStreamFilter("testFilter", testFilter) 170 filter := client.GetStreamFilter("testFilter") 171 cs, err := filter(context.Background(), &client.ClientStreamDesc{}, nil) 172 require.Nil(t, err) 173 require.Equal(t, testClientStream, cs) 174 } 175 176 func TestStreamGetAddress(t *testing.T) { 177 s := client.NewStream() 178 require.NotNil(t, s) 179 ctx, msg := codec.EnsureMessage(context.Background()) 180 node := ®istry.Node{} 181 const addr = "127.0.0.1:8000" 182 opts, err := s.Init(ctx, 183 client.WithTarget("ip://"+addr), 184 client.WithTimeout(time.Second), 185 client.WithSelectorNode(node), 186 ) 187 require.Nil(t, err) 188 require.NotNil(t, opts) 189 require.Equal(t, addr, node.Address) 190 require.NotNil(t, msg.RemoteAddr()) 191 require.Equal(t, addr, msg.RemoteAddr().String()) 192 } 193 194 func TestStreamCloseTransport(t *testing.T) { 195 codec.Register("fake", nil, &fakeCodec{}) 196 t.Run("close transport when send fail", func(t *testing.T) { 197 var isClose bool 198 streamCli := client.NewStream() 199 _, err := streamCli.Init(context.Background(), 200 client.WithTarget("ip://127.0.0.1:8000"), 201 client.WithStreamTransport(&fakeTransport{ 202 send: func() error { 203 return errors.New("expected error") 204 }, 205 close: func() { 206 isClose = true 207 }, 208 }), 209 client.WithProtocol("fake"), 210 ) 211 require.Nil(t, err) 212 require.NotNil(t, streamCli.Send(context.Background(), nil)) 213 require.True(t, isClose) 214 }) 215 t.Run("close transport when recv fail", func(t *testing.T) { 216 var isClose bool 217 streamCli := client.NewStream() 218 _, err := streamCli.Init(context.Background(), 219 client.WithTarget("ip://127.0.0.1:8000"), 220 client.WithStreamTransport(&fakeTransport{ 221 recv: func() ([]byte, error) { 222 return nil, errors.New("expected error") 223 }, 224 close: func() { 225 isClose = true 226 }, 227 }), 228 client.WithProtocol("fake"), 229 ) 230 require.Nil(t, err) 231 _, err = streamCli.Recv(context.Background()) 232 require.NotNil(t, err) 233 require.True(t, isClose) 234 }) 235 }