trpc.group/trpc-go/trpc-go@v1.0.2/client/client_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package client_test 15 16 import ( 17 "context" 18 "errors" 19 "testing" 20 "time" 21 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 25 26 trpc "trpc.group/trpc-go/trpc-go" 27 "trpc.group/trpc-go/trpc-go/client" 28 "trpc.group/trpc-go/trpc-go/codec" 29 "trpc.group/trpc-go/trpc-go/errs" 30 "trpc.group/trpc-go/trpc-go/filter" 31 "trpc.group/trpc-go/trpc-go/naming/registry" 32 "trpc.group/trpc-go/trpc-go/naming/selector" 33 "trpc.group/trpc-go/trpc-go/transport" 34 35 _ "trpc.group/trpc-go/trpc-go" 36 ) 37 38 // go test -v -coverprofile=cover.out 39 // go tool cover -func=cover.out 40 41 func TestMain(m *testing.M) { 42 transport.DefaultClientTransport = &fakeTransport{} 43 selector.Register("fake", &fakeSelector{}) // fake://{endpoint} 44 transport.RegisterClientTransport("fake", &fakeTransport{}) 45 m.Run() 46 } 47 48 func TestClient(t *testing.T) { 49 ctx := context.Background() 50 codec.RegisterSerializer(0, &codec.NoopSerialization{}) 51 codec.Register("fake", nil, &fakeCodec{}) 52 53 cli := client.New() 54 require.Equal(t, cli, client.DefaultClient) 55 56 // test if response is valid 57 reqBody := &codec.Body{Data: []byte("body")} 58 rspBody := &codec.Body{} 59 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 60 client.WithTimeout(time.Second), client.WithProtocol("fake"))) 61 require.Equal(t, []byte("body"), rspBody.Data) 62 63 // test setting req/resp head 64 reqhead := ®istry.Node{} 65 rsphead := ®istry.Node{} 66 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 67 client.WithReqHead(reqhead), client.WithRspHead(rsphead), client.WithProtocol("fake"))) 68 69 // test client options 70 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 71 client.WithTimeout(time.Second), 72 client.WithServiceName("trpc.app.callee.service"), 73 client.WithCallerServiceName("trpc.app.caller.service"), 74 client.WithSerializationType(codec.SerializationTypeNoop), 75 client.WithCompressType(codec.CompressTypeGzip), 76 client.WithCurrentSerializationType(codec.SerializationTypeNoop), 77 client.WithCurrentCompressType(codec.CompressTypeNoop), 78 client.WithMetaData("key", []byte("value")), 79 client.WithProtocol("fake"))) 80 81 // test selecting node with network: udp 82 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("fake://udpnetwork"), 83 client.WithTimeout(time.Second), client.WithProtocol("fake"))) 84 85 // test selecting node with network: unknown, which will use tcp by default 86 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("fake://unknownnetwork"), 87 client.WithTimeout(time.Second), client.WithProtocol("fake"))) 88 89 // test setting namespace in msg 90 ctx = context.Background() 91 ctx, msg := codec.WithNewMessage(ctx) 92 msg.WithNamespace("Development") // getServiceInfoOptions will set env info according to the namespace 93 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 94 client.WithTimeout(time.Second), client.WithProtocol("fake"))) 95 require.Equal(t, []byte("body"), rspBody.Data) 96 97 // test that env info from upstream service has higher priority 98 ctx = context.Background() 99 ctx, msg = codec.WithNewMessage(ctx) 100 msg.WithEnvTransfer("faketransfer") // env info from upstream service exists 101 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 102 client.WithTimeout(time.Second), client.WithProtocol("fake"))) 103 require.Equal(t, []byte("body"), rspBody.Data) 104 105 // test disabling service router, which will clear env info from msg 106 ctx = context.Background() 107 ctx, msg = codec.WithNewMessage(ctx) 108 msg.WithEnvTransfer("faketransfer") // env info from upstream service exists 109 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 110 client.WithTimeout(time.Second), client.WithProtocol("fake"), 111 client.WithDisableServiceRouter())) // opts that disables service router 112 require.Equal(t, []byte("body"), rspBody.Data) 113 require.Equal(t, msg.EnvTransfer(), "") // env info from upstream service was cleared 114 115 // test setting CalleeMethod in opts 116 // updateMsg will then update CalleeMethod in msg 117 ctx = context.Background() 118 ctx, msg = codec.WithNewMessage(ctx) 119 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 120 client.WithTimeout(time.Second), client.WithProtocol("fake"), 121 client.WithCalleeMethod("fakemethod"))) // opts 中指定了 CalleeMethod 122 require.Equal(t, msg.CalleeMethod(), "fakemethod") // msg 中的 CalleeMethod 被更新 123 124 // test that the parameters can be extracted from msg in the prev filter 125 ctx = context.Background() 126 ctx, msg = codec.WithNewMessage(ctx) 127 rid := uint32(100000) 128 msg.WithRequestID(uint32(rid)) 129 130 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 131 client.WithTimeout(time.Second), client.WithProtocol("fake"), 132 client.WithFilter(func(ctx context.Context, req interface{}, rsp interface{}, f filter.ClientHandleFunc) (err error) { 133 msg := trpc.Message(ctx) 134 require.Equal(t, rid, msg.RequestID()) 135 return f(ctx, req, rsp) 136 }))) 137 138 // test setting CallType in opts 139 // updateMsg will then update CallType in msg 140 ctx = context.Background() 141 head := &trpcpb.RequestProtocol{} 142 ctx, msg = codec.WithNewMessage(ctx) 143 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 144 client.WithProtocol("fake"), 145 client.WithSendOnly(), 146 client.WithReqHead(head), 147 )) 148 require.Equal(t, msg.CallType(), codec.SendOnly) 149 } 150 151 func TestClientFail(t *testing.T) { 152 ctx := context.Background() 153 codec.RegisterSerializer(0, &codec.NoopSerialization{}) 154 codec.Register("fake", nil, &fakeCodec{}) 155 156 cli := client.New() 157 require.Equal(t, cli, client.DefaultClient) 158 159 reqBody := &codec.Body{Data: []byte("body")} 160 rspBody := &codec.Body{} 161 // test code failure 162 require.NotNil(t, cli.Invoke(ctx, reqBody, rspBody, 163 client.WithTarget("ip://127.0.0.1:8080"), 164 client.WithTimeout(time.Second), 165 client.WithSerializationType(codec.SerializationTypeNoop))) 166 167 // test invalid target 168 err := cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip/:/127.0.0.1:8080"), 169 client.WithProtocol("fake")) 170 require.NotNil(t, err) 171 require.Contains(t, err.Error(), "invalid") 172 173 // test target selector that not exists 174 err = cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("cl6://127.0.0.1:8080"), 175 client.WithProtocol("fake")) 176 require.NotNil(t, err) 177 require.Contains(t, err.Error(), "not exist") 178 179 // test recording selected node 180 node := ®istry.Node{} 181 require.Nil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 182 client.WithSelectorNode(node), client.WithProtocol("fake"))) 183 require.Equal(t, node.Address, "127.0.0.1:8080") 184 require.Equal(t, node.ServiceName, "127.0.0.1:8080") 185 require.Empty(t, node.Network) 186 187 // test encode failure 188 reqBody = &codec.Body{Data: []byte("failbody")} 189 require.NotNil(t, cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 190 client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop))) 191 192 // test network failure 193 reqBody = &codec.Body{Data: []byte("callfail")} 194 err = cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 195 client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop)) 196 assert.NotNil(t, err) 197 198 // test response failure 199 reqBody = &codec.Body{Data: []byte("businessfail")} 200 err = cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 201 client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop)) 202 203 reqBody = &codec.Body{Data: []byte("msgfail")} 204 err = cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 205 client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop)) 206 assert.NotNil(t, err) 207 208 // test nil rsp 209 reqBody = &codec.Body{Data: []byte("nilrsp")} 210 err = cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 211 client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop)) 212 assert.Nil(t, err) 213 214 // test timeout 215 reqBody = &codec.Body{Data: []byte("timeout")} 216 err = cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), 217 client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop)) 218 assert.NotNil(t, err) 219 220 // test select node failure 221 reqBody = &codec.Body{Data: []byte("body")} 222 err = cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("fake://selectfail"), 223 client.WithTimeout(time.Second), client.WithProtocol("fake")) 224 assert.NotNil(t, err) 225 226 // test selecting the node with empty addr 227 err = cli.Invoke(ctx, reqBody, rspBody, client.WithTarget("fake://emptynode"), 228 client.WithTimeout(time.Second), client.WithProtocol("fake")) 229 assert.NotNil(t, err) 230 231 } 232 233 func TestClientAddrResolve(t *testing.T) { 234 ctx := context.Background() 235 codec.RegisterSerializer(0, &codec.NoopSerialization{}) 236 codec.Register("fake", nil, &fakeCodec{}) 237 cli := client.New() 238 239 reqBody := &codec.Body{Data: []byte("body")} 240 rspBody := &codec.Body{} 241 // test target with ip schema 242 nctx, _ := codec.WithNewMessage(ctx) 243 _ = cli.Invoke(nctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), client.WithProtocol("fake")) 244 assert.Equal(t, "127.0.0.1:8080", codec.Message(nctx).RemoteAddr().String()) 245 246 // test target with ip schema and network: tcp 247 nctx, _ = codec.WithNewMessage(ctx) 248 _ = cli.Invoke(nctx, reqBody, rspBody, 249 client.WithTarget("ip://127.0.0.1:8080"), 250 client.WithNetwork("tcp"), 251 client.WithProtocol("fake"), 252 ) 253 require.Equal(t, "127.0.0.1:8080", codec.Message(nctx).RemoteAddr().String()) 254 255 // test target with hostname schema 256 nctx, _ = codec.WithNewMessage(ctx) 257 _ = cli.Invoke(nctx, reqBody, rspBody, client.WithTarget("ip://www.qq.com:8080"), client.WithProtocol("fake")) 258 assert.Nil(t, codec.Message(nctx).RemoteAddr()) 259 260 // test calling target with ip schema failure 261 nctx, msg := codec.WithNewMessage(ctx) 262 reqBody = &codec.Body{Data: []byte("callfail")} 263 err := cli.Invoke(nctx, reqBody, rspBody, client.WithTarget("ip://127.0.0.1:8080"), client.WithProtocol("fake")) 264 assert.NotNil(t, err) 265 assert.Equal(t, "127.0.0.1:8080", msg.RemoteAddr().String()) 266 267 // test target with unix schema 268 nctx, _ = codec.WithNewMessage(ctx) 269 _ = cli.Invoke(nctx, reqBody, rspBody, 270 client.WithTarget("unix://temp.sock"), 271 client.WithNetwork("unix"), 272 client.WithProtocol("fake"), 273 ) 274 require.Equal(t, "temp.sock", codec.Message(nctx).RemoteAddr().String()) 275 } 276 277 func TestTimeout(t *testing.T) { 278 codec.RegisterSerializer(0, &codec.NoopSerialization{}) 279 codec.Register("fake", nil, &fakeCodec{}) 280 target, protocol := "ip://127.0.0.1:8080", "fake" 281 282 cli := client.New() 283 rspBody := &codec.Body{} 284 err := cli.Invoke(context.Background(), 285 &codec.Body{Data: []byte("timeout")}, rspBody, 286 client.WithTarget(target), 287 client.WithProtocol(protocol)) 288 require.NotNil(t, err) 289 e, ok := err.(*errs.Error) 290 require.True(t, ok) 291 require.Equal(t, errs.RetClientTimeout, e.Code) 292 293 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) 294 defer cancel() 295 err = cli.Invoke(ctx, 296 &codec.Body{Data: []byte("timeout")}, rspBody, 297 client.WithTarget(target), 298 client.WithProtocol(protocol)) 299 require.NotNil(t, err) 300 e, ok = err.(*errs.Error) 301 require.True(t, ok) 302 require.Equal(t, errs.RetClientFullLinkTimeout, e.Code) 303 } 304 305 func TestSameCalleeMultiServiceName(t *testing.T) { 306 callee := "trpc.test.pbcallee" 307 serviceNames := []string{ 308 "trpc.test.helloworld0", 309 "trpc.test.helloworld1", 310 "trpc.test.helloworld2", 311 "trpc.test.helloworld3", 312 } 313 for i := range serviceNames { 314 if i != 2 { 315 require.Nil(t, client.RegisterClientConfig(callee, &client.BackendConfig{ 316 ServiceName: serviceNames[i], 317 Compression: codec.CompressTypeSnappy, 318 })) 319 continue 320 } 321 require.Nil(t, client.RegisterClientConfig(callee, &client.BackendConfig{ 322 ServiceName: serviceNames[i], 323 Compression: codec.CompressTypeBlockSnappy, 324 })) 325 } 326 ctx, msg := codec.EnsureMessage(context.Background()) 327 msg.WithCalleeServiceName(callee) 328 require.NotNil(t, client.DefaultClient.Invoke(ctx, nil, nil, client.WithServiceName(serviceNames[0]))) 329 require.Equal(t, codec.CompressTypeSnappy, msg.CompressType()) 330 ctx, msg = codec.EnsureMessage(context.Background()) 331 msg.WithCalleeServiceName(callee) 332 require.NotNil(t, client.DefaultClient.Invoke(ctx, nil, nil, client.WithServiceName(serviceNames[2]))) 333 require.Equal(t, codec.CompressTypeBlockSnappy, msg.CompressType()) 334 } 335 336 func TestMultiplexedUseLatestMsg(t *testing.T) { 337 codec.RegisterSerializer(0, &codec.NoopSerialization{}) 338 const target = "ip://127.0.0.1:8080" 339 340 rspBody := &codec.Body{} 341 require.Nil(t, client.New().Invoke(context.Background(), 342 &codec.Body{Data: []byte(t.Name())}, rspBody, 343 client.WithTarget(target), 344 client.WithTransport(&multiplexedTransport{ 345 require: func(_ context.Context, _ []byte, opts ...transport.RoundTripOption) { 346 var o transport.RoundTripOptions 347 for _, opt := range opts { 348 opt(&o) 349 } 350 require.NotZero(t, o.Msg.RequestID()) 351 }}), 352 client.WithMultiplexed(true), 353 client.WithFilter(func(ctx context.Context, req, rsp interface{}, next filter.ClientHandleFunc) error { 354 // make a copy of the msg, after next, copy the new msg back. 355 oldMsg := codec.Message(ctx) 356 ctx, msg := codec.WithNewMessage(ctx) 357 codec.CopyMsg(msg, oldMsg) 358 err := next(ctx, req, rsp) 359 codec.CopyMsg(oldMsg, msg) 360 return err 361 }), 362 )) 363 } 364 365 func TestFixTimeout(t *testing.T) { 366 codec.RegisterSerializer(0, &codec.NoopSerialization{}) 367 codec.Register("fake", nil, &fakeCodec{}) 368 target, protocol := "ip://127.0.0.1:8080", "fake" 369 370 cli := client.New() 371 372 rspBody := &codec.Body{} 373 t.Run("RetClientCanceled", func(t *testing.T) { 374 ctx, cancel := context.WithCancel(context.Background()) 375 cancel() 376 err := cli.Invoke(ctx, 377 &codec.Body{Data: []byte("clientCanceled")}, rspBody, 378 client.WithTarget(target), 379 client.WithProtocol(protocol)) 380 require.Equal(t, errs.RetClientCanceled, errs.Code(err)) 381 }) 382 383 t.Run("RetClientFullLinkTimeout", func(t *testing.T) { 384 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(1*time.Millisecond)) 385 defer cancel() 386 var d time.Duration 387 deadline, ok := t.Deadline() 388 if !ok { 389 d = 5 * time.Second 390 } else { 391 const arbitraryCleanupMargin = 1 * time.Second 392 d = time.Until(deadline) - arbitraryCleanupMargin 393 } 394 timer := time.NewTimer(d) 395 defer timer.Stop() 396 select { 397 case <-timer.C: 398 t.Fatalf(" context not timed out after %v", d) 399 case <-ctx.Done(): 400 } 401 if e := ctx.Err(); e != context.DeadlineExceeded { 402 t.Errorf("c.Err() == %v; want %v", e, context.DeadlineExceeded) 403 } 404 err := cli.Invoke(ctx, 405 &codec.Body{Data: []byte("fixTimeout")}, rspBody, 406 client.WithTarget(target), 407 client.WithProtocol(protocol)) 408 require.Equal(t, errs.RetClientFullLinkTimeout, errs.Code(err)) 409 }) 410 } 411 412 type multiplexedTransport struct { 413 require func(context.Context, []byte, ...transport.RoundTripOption) 414 fakeTransport 415 } 416 417 func (t *multiplexedTransport) RoundTrip( 418 ctx context.Context, 419 req []byte, 420 opts ...transport.RoundTripOption, 421 ) ([]byte, error) { 422 t.require(ctx, req, opts...) 423 return t.fakeTransport.RoundTrip(ctx, req, opts...) 424 } 425 426 type fakeTransport struct{} 427 428 func (c *fakeTransport) RoundTrip(ctx context.Context, req []byte, 429 roundTripOpts ...transport.RoundTripOption) (rsp []byte, err error) { 430 time.Sleep(time.Millisecond * 2) 431 if string(req) == "callfail" { 432 return nil, errors.New("transport call fail") 433 } 434 435 if string(req) == "timeout" { 436 return nil, &errs.Error{ 437 Type: errs.ErrorTypeFramework, 438 Code: errs.RetClientTimeout, 439 Msg: "transport call fail", 440 } 441 } 442 443 if string(req) == "nilrsp" { 444 return nil, nil 445 } 446 return req, nil 447 } 448 449 func (c *fakeTransport) Send(ctx context.Context, req []byte, opts ...transport.RoundTripOption) error { 450 return nil 451 } 452 453 func (c *fakeTransport) Recv(ctx context.Context, opts ...transport.RoundTripOption) ([]byte, error) { 454 body, ok := ctx.Value("recv-decode-error").(string) 455 if ok { 456 return []byte(body), nil 457 } 458 459 err, ok := ctx.Value("recv-error").(string) 460 if ok { 461 return nil, errors.New(err) 462 } 463 return []byte("body"), nil 464 } 465 466 func (c *fakeTransport) Init(ctx context.Context, opts ...transport.RoundTripOption) error { 467 return nil 468 } 469 func (c *fakeTransport) Close(ctx context.Context) { 470 return 471 } 472 473 type fakeCodec struct { 474 } 475 476 func (c *fakeCodec) Encode(msg codec.Msg, reqBody []byte) (reqBuf []byte, err error) { 477 if string(reqBody) == "failbody" { 478 return nil, errors.New("encode fail") 479 } 480 return reqBody, nil 481 } 482 483 func (c *fakeCodec) Decode(msg codec.Msg, rspBuf []byte) (rspBody []byte, err error) { 484 if string(rspBuf) == "businessfail" { 485 return nil, errors.New("businessfail") 486 } 487 488 if string(rspBuf) == "msgfail" { 489 msg.WithClientRspErr(errors.New("msgfail")) 490 return nil, nil 491 } 492 return rspBuf, nil 493 } 494 495 type fakeSelector struct { 496 } 497 498 func (c *fakeSelector) Select(serviceName string, opt ...selector.Option) (*registry.Node, error) { 499 if serviceName == "selectfail" { 500 return nil, errors.New("selectfail") 501 } 502 503 if serviceName == "emptynode" { 504 return ®istry.Node{}, nil 505 } 506 507 if serviceName == "udpnetwork" { 508 return ®istry.Node{ 509 Network: "udp", 510 Address: "127.0.0.1:8080", 511 }, nil 512 } 513 514 if serviceName == "unknownnetwork" { 515 return ®istry.Node{ 516 Network: "unknown", 517 Address: "127.0.0.1:8080", 518 }, nil 519 } 520 521 return nil, errors.New("unknown servicename") 522 } 523 524 func (c *fakeSelector) Report(node *registry.Node, cost time.Duration, err error) error { 525 return nil 526 }