trpc.group/trpc-go/trpc-go@v1.0.3/stream/client_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 // Package stream_test Unit test for package stream. 15 package stream_test 16 17 import ( 18 "context" 19 "crypto/rand" 20 "encoding/binary" 21 "errors" 22 "fmt" 23 "io" 24 "testing" 25 "time" 26 27 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 28 29 trpc "trpc.group/trpc-go/trpc-go" 30 "trpc.group/trpc-go/trpc-go/client" 31 "trpc.group/trpc-go/trpc-go/codec" 32 "trpc.group/trpc-go/trpc-go/errs" 33 "trpc.group/trpc-go/trpc-go/server" 34 "trpc.group/trpc-go/trpc-go/stream" 35 "trpc.group/trpc-go/trpc-go/transport" 36 37 "github.com/stretchr/testify/assert" 38 ) 39 40 var ctx = context.Background() 41 42 type fakeTransport struct { 43 expectChan chan recvExpect 44 send func() error 45 close func() 46 } 47 48 // RoundTrip Mock RoundTrip method. 49 func (c *fakeTransport) RoundTrip(ctx context.Context, req []byte, 50 roundTripOpts ...transport.RoundTripOption) (rsp []byte, err error) { 51 return nil, nil 52 } 53 54 // Send Mock Send method. 55 func (c *fakeTransport) Send(ctx context.Context, req []byte, opts ...transport.RoundTripOption) error { 56 if c.send != nil { 57 return c.send() 58 } 59 return nil 60 } 61 62 type recvExpect func(*trpc.FrameHead, codec.Msg) ([]byte, error) 63 64 // Recv Mock recv method. 65 func (c *fakeTransport) Recv(ctx context.Context, opts ...transport.RoundTripOption) ([]byte, error) { 66 msg := codec.Message(ctx) 67 var fh *trpc.FrameHead 68 fh, ok := msg.FrameHead().(*trpc.FrameHead) 69 if !ok { 70 fh = &trpc.FrameHead{} 71 msg.WithFrameHead(fh) 72 } 73 f := <-c.expectChan 74 return f(fh, msg) 75 } 76 77 // Init Mock Init method. 78 func (c *fakeTransport) Init(ctx context.Context, opts ...transport.RoundTripOption) error { 79 return nil 80 } 81 82 // Close Mock Close method. 83 func (c *fakeTransport) Close(ctx context.Context) { 84 if c.close != nil { 85 c.close() 86 } 87 } 88 89 type fakeCodec struct { 90 } 91 92 // Encode Mock codec Encode method. 93 func (c *fakeCodec) Encode(msg codec.Msg, reqBody []byte) (reqBuf []byte, err error) { 94 if string(reqBody) == "failbody" { 95 return nil, errors.New("encode fail") 96 } 97 return reqBody, nil 98 } 99 100 // Decode Mock codec Decode method. 101 func (c *fakeCodec) Decode(msg codec.Msg, rspBuf []byte) (rspBody []byte, err error) { 102 if string(rspBuf) == "businessfail" { 103 return nil, errors.New("businessfail") 104 } 105 if string(rspBuf) == "msgfail" { 106 msg.WithClientRspErr(errors.New("msgfail")) 107 return nil, nil 108 } 109 return rspBuf, nil 110 } 111 112 // TestMain tests the Main function. 113 func TestMain(m *testing.M) { 114 transport.DefaultServerTransport = &fakeServerTransport{} 115 m.Run() 116 } 117 118 // TestClient tests the streaming client. 119 func TestClient(t *testing.T) { 120 var reqBody = &codec.Body{Data: []byte("body")} 121 var rspBody = &codec.Body{} 122 codec.RegisterSerializer(0, &codec.NoopSerialization{}) 123 codec.Register("fake", nil, &fakeCodec{}) 124 codec.Register("fake-nil", nil, nil) 125 126 cli := stream.NewStreamClient() 127 assert.Equal(t, cli, stream.DefaultStreamClient) 128 129 ctx := context.Background() 130 var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)} 131 transport.DefaultClientTransport = ft 132 133 f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 134 return nil, nil 135 } 136 ft.expectChan <- f 137 cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", 138 client.WithTarget("ip://127.0.0.1:8000"), 139 client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop), 140 client.WithCurrentCompressType(codec.CompressTypeNoop), 141 client.WithStreamTransport(ft)) 142 assert.NotNil(t, cs) 143 assert.Nil(t, err) 144 145 // Test Context. 146 resultCtx := cs.Context() 147 assert.NotNil(t, resultCtx) 148 // Test to send data normally. 149 err = cs.SendMsg(reqBody) 150 assert.Nil(t, err) 151 152 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 153 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 154 return []byte("body"), nil 155 } 156 ft.expectChan <- f 157 158 // Test to receive data normally. 159 err = cs.RecvMsg(rspBody) 160 assert.Nil(t, err) 161 assert.Equal(t, rspBody.Data, []byte("body")) 162 163 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 164 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE) 165 return nil, nil 166 } 167 ft.expectChan <- f 168 169 // Test received io.EOF. 170 rspBody = &codec.Body{} 171 err = cs.RecvMsg(rspBody) 172 assert.Equal(t, io.EOF, err) 173 assert.Nil(t, rspBody.Data) 174 175 err = cs.CloseSend() 176 assert.Nil(t, err) 177 178 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 179 return nil, nil 180 } 181 ft.expectChan <- f 182 cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", 183 client.WithTarget("ip://127.0.0.1:8000"), 184 client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop), 185 client.WithTransport(ft), 186 client.WithStreamTransport(ft)) 187 assert.NotNil(t, cs) 188 assert.Nil(t, err) 189 190 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 191 msg.WithClientRspErr(errors.New("close type is reset")) 192 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE) 193 return nil, nil 194 } 195 ft.expectChan <- f 196 // test reset. 197 rspBody = &codec.Body{} 198 err = cs.RecvMsg(rspBody) 199 assert.NotNil(t, err) 200 assert.Nil(t, rspBody.Data) 201 assert.Contains(t, err.Error(), "close type is reset") 202 203 } 204 205 // TestClientFlowControl tests the streaming client. 206 func TestClientFlowControl(t *testing.T) { 207 var reqBody = &codec.Body{Data: []byte("body")} 208 209 var rspBody = &codec.Body{} 210 codec.RegisterSerializer(0, &codec.NoopSerialization{}) 211 codec.Register("fake", nil, &fakeCodec{}) 212 codec.Register("fake-nil", nil, nil) 213 214 cli := stream.NewStreamClient() 215 assert.Equal(t, cli, stream.DefaultStreamClient) 216 217 ctx := context.Background() 218 var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)} 219 transport.DefaultClientTransport = ft 220 221 f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 222 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 223 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{InitWindowSize: 2000}) 224 return nil, nil 225 } 226 ft.expectChan <- f 227 228 cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", 229 client.WithTarget("ip://127.0.0.1:8000"), 230 client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop), 231 client.WithCurrentCompressType(codec.CompressTypeNoop), 232 client.WithTransport(ft), 233 client.WithStreamTransport(ft)) 234 assert.NotNil(t, cs) 235 assert.Nil(t, err) 236 237 // Test Context. 238 resultCtx := cs.Context() 239 assert.NotNil(t, resultCtx) 240 // Test to send data normally. 241 err = cs.SendMsg(reqBody) 242 assert.Nil(t, err) 243 244 for i := 0; i < 20000; i++ { 245 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 246 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 247 return []byte("body"), nil 248 } 249 ft.expectChan <- f 250 // Test to receive data normally. 251 err = cs.RecvMsg(rspBody) 252 assert.Nil(t, err) 253 assert.Equal(t, rspBody.Data, []byte("body")) 254 } 255 256 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 257 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE) 258 return nil, nil 259 } 260 ft.expectChan <- f 261 262 // Test received io.EOF. 263 rspBody = &codec.Body{} 264 err = cs.RecvMsg(rspBody) 265 assert.Equal(t, io.EOF, err) 266 assert.Nil(t, rspBody.Data) 267 } 268 269 // TestClientError tests the case of streaming Client error handling. 270 func TestClientError(t *testing.T) { 271 var rspBody = &codec.Body{} 272 codec.RegisterSerializer(0, &codec.NoopSerialization{}) 273 codec.Register("fake", nil, &fakeCodec{}) 274 codec.Register("fake-nil", nil, nil) 275 276 cli := stream.NewStreamClient() 277 assert.Equal(t, cli, stream.DefaultStreamClient) 278 279 var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)} 280 transport.DefaultClientTransport = ft 281 f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 282 return nil, errors.New("init error") 283 } 284 ft.expectChan <- f 285 286 // Test for init transport errors. 287 cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", 288 client.WithTarget("ip://127.0.0.1:8000"), 289 client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop), 290 client.WithTransport(ft), 291 client.WithStreamTransport(ft)) 292 assert.Nil(t, cs) 293 assert.NotNil(t, err) 294 295 // test Init error. 296 cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", 297 client.WithTarget("ip://127.0.0.1:8000"), 298 client.WithProtocol("fake-nil"), client.WithSerializationType(codec.SerializationTypeNoop), 299 client.WithTransport(ft), 300 client.WithStreamTransport(ft)) 301 assert.Nil(t, cs) 302 assert.NotNil(t, err) 303 304 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 305 return nil, nil 306 } 307 ft.expectChan <- f 308 cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", 309 client.WithTarget("ip://127.0.0.1:8000"), 310 client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop), 311 client.WithTransport(ft), 312 client.WithStreamTransport(ft), client.WithClientStreamQueueSize(100000)) 313 assert.NotNil(t, cs) 314 assert.Nil(t, err) 315 // test recv data error. 316 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 317 return nil, errors.New("recv data error") 318 } 319 ft.expectChan <- f 320 err = cs.RecvMsg(rspBody) 321 assert.NotNil(t, err) 322 assert.Nil(t, rspBody.Data) 323 324 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 325 msg.WithClientRspErr(errors.New("test init with clientRspError")) 326 return nil, nil 327 } 328 ft.expectChan <- f 329 cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", 330 client.WithTarget("ip://127.0.0.1:8000"), 331 client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop), 332 client.WithStreamTransport(ft), client.WithClientStreamQueueSize(100000)) 333 assert.Nil(t, cs) 334 assert.NotNil(t, err) 335 336 // receive unexpected stream frame type 337 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 338 msg.WithStreamFrame(int(1)) 339 return nil, nil 340 } 341 ft.expectChan <- f 342 cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", 343 client.WithTarget("ip://127.0.0.1:8000"), 344 client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop), 345 client.WithStreamTransport(ft), client.WithClientStreamQueueSize(100000)) 346 assert.Nil(t, cs) 347 assert.Contains(t, err.Error(), "unexpected frame type") 348 } 349 350 // TestClientContext tests the case of streaming client context cancel and timeout. 351 func TestClientContext(t *testing.T) { 352 353 var rspBody = &codec.Body{} 354 codec.RegisterSerializer(0, &codec.NoopSerialization{}) 355 codec.Register("fake", nil, &fakeCodec{}) 356 codec.Register("fake-nil", nil, nil) 357 358 cli := stream.NewStreamClient() 359 assert.Equal(t, cli, stream.DefaultStreamClient) 360 361 var ft = &fakeTransport{expectChan: make(chan recvExpect, 1)} 362 transport.DefaultClientTransport = ft 363 // test context cancel situation. 364 f := func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 365 return nil, nil 366 } 367 ft.expectChan <- f 368 ctx, cancel := context.WithCancel(context.Background()) 369 cs, err := cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", 370 client.WithTarget("ip://127.0.0.1:8000"), 371 client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop), 372 client.WithCurrentCompressType(codec.CompressTypeNoop), 373 client.WithTransport(ft), 374 client.WithStreamTransport(ft)) 375 assert.NotNil(t, cs) 376 assert.Nil(t, err) 377 cancel() 378 err = cs.RecvMsg(rspBody) 379 assert.Contains(t, err.Error(), "tcp client stream canceled before recv") 380 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 381 return nil, errors.New("close it") 382 } 383 ft.expectChan <- f 384 time.Sleep(5 * time.Millisecond) 385 // test context timeout situation. 386 f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { 387 return nil, nil 388 } 389 ft.expectChan <- f 390 391 timeoutCtx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) 392 defer cancel() 393 cs, err = cli.NewStream(timeoutCtx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", 394 client.WithTarget("ip://127.0.0.1:8000"), 395 client.WithProtocol("fake"), client.WithCurrentSerializationType(codec.SerializationTypeNoop), 396 client.WithCurrentCompressType(codec.CompressTypeNoop), 397 client.WithTransport(ft), 398 client.WithStreamTransport(ft)) 399 assert.NotNil(t, cs) 400 assert.Nil(t, err) 401 402 err = cs.RecvMsg(rspBody) 403 assert.Contains(t, err.Error(), "tcp client stream canceled timeout before recv") 404 } 405 406 func clientFilterAdd1(ctx context.Context, desc *client.ClientStreamDesc, newStream client.Streamer) (client.ClientStream, error) { 407 var msg codec.Msg 408 ctx, msg = codec.EnsureMessage(ctx) 409 meta := msg.ClientMetaData() 410 if meta == nil { 411 meta = codec.MetaData{} 412 } 413 meta[testKey1] = []byte(testData[testKey1]) 414 msg.WithClientMetaData(meta) 415 416 s, err := newStream(ctx, desc) 417 if err != nil { 418 return nil, err 419 } 420 421 return newWrappedClientStream(s), nil 422 } 423 424 func clientFilterAdd2(ctx context.Context, desc *client.ClientStreamDesc, newStream client.Streamer) (client.ClientStream, error) { 425 var msg codec.Msg 426 ctx, msg = codec.EnsureMessage(ctx) 427 meta := msg.ClientMetaData() 428 if meta == nil { 429 meta = codec.MetaData{} 430 } 431 meta[testKey2] = []byte(testData[testKey2]) 432 msg.WithClientMetaData(meta) 433 434 s, err := newStream(ctx, desc) 435 if err != nil { 436 return nil, err 437 } 438 return newWrappedClientStream(s), nil 439 } 440 441 type wrappedClientStream struct { 442 client.ClientStream 443 } 444 445 func newWrappedClientStream(s client.ClientStream) client.ClientStream { 446 return &wrappedClientStream{s} 447 } 448 449 func (w *wrappedClientStream) RecvMsg(m interface{}) error { 450 err := w.ClientStream.RecvMsg(m) 451 num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8]) 452 binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1) 453 return err 454 } 455 456 func (w *wrappedClientStream) SendMsg(m interface{}) error { 457 num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8]) 458 binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1) 459 return w.ClientStream.SendMsg(m) 460 } 461 462 func TestClientStreamClientFilters(t *testing.T) { 463 rawData := make([]byte, 1024) 464 rand.Read(rawData) 465 var beginNum uint64 = 100 466 467 counts := 1000 468 svrOpts := []server.Option{ 469 server.WithAddress("127.0.0.1:30211"), 470 server.WithStreamFilters(serverFilterAdd1, serverFilterAdd2), 471 } 472 handle := func(s server.Stream) error { 473 var req *codec.Body 474 475 // server receives data. 476 for i := 0; i < counts; i++ { 477 req = getBytes(1024) 478 err := s.RecvMsg(req) 479 assert.Nil(t, err) 480 resultNum := binary.LittleEndian.Uint64(req.Data[:8]) 481 482 // After the client SendMsg + server RecvMsg, two Filter, Num+4. 483 assert.Equal(t, beginNum+4, resultNum) 484 assert.Equal(t, rawData[8:], req.Data[8:]) 485 } 486 err := s.RecvMsg(getBytes(1024)) 487 assert.Equal(t, io.EOF, err) 488 489 // server sends data. 490 rsp := getBytes(1024) 491 for i := 0; i < counts; i++ { 492 copy(rsp.Data, req.Data) 493 err = s.SendMsg(rsp) 494 assert.Nil(t, err) 495 } 496 return nil 497 } 498 svr := startStreamServer(handle, svrOpts) 499 defer closeStreamServer(svr) 500 501 cliOpts := []client.Option{ 502 client.WithTarget("ip://127.0.0.1:30211"), 503 client.WithStreamFilters(clientFilterAdd1, clientFilterAdd2), 504 } 505 cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts) 506 assert.Nil(t, err) 507 508 // client sends data. 509 for i := 0; i < counts; i++ { 510 req := getBytes(1024) 511 copy(req.Data, rawData) 512 binary.LittleEndian.PutUint64(req.Data[:8], beginNum) 513 514 err = cliStream.SendMsg(req) 515 assert.Nil(t, err) 516 } 517 err = cliStream.CloseSend() 518 assert.Nil(t, err) 519 520 // client receives data. 521 for i := 0; i < counts; i++ { 522 rsp := getBytes(1024) 523 err = cliStream.RecvMsg(rsp) 524 assert.Nil(t, err) 525 526 // After the client once SendMsg, once RecvMsg, two Filter, Num+4. 527 resultNum := binary.LittleEndian.Uint64(rsp.Data[:8]) 528 assert.Equal(t, beginNum+8, resultNum) 529 assert.Equal(t, rawData[8:], rsp.Data[8:]) 530 } 531 rsp := getBytes(1024) 532 err = cliStream.RecvMsg(rsp) 533 assert.Equal(t, io.EOF, err) 534 } 535 536 func TestClientStreamFlowControlStop(t *testing.T) { 537 windows := 102400 538 dataLen := 1024 539 maxSends := windows / dataLen 540 svrOpts := []server.Option{ 541 server.WithAddress("127.0.0.1:30211"), 542 server.WithMaxWindowSize(uint32(windows)), 543 } 544 handle := func(s server.Stream) error { 545 time.Sleep(time.Hour) 546 return nil 547 } 548 svr := startStreamServer(handle, svrOpts) 549 defer closeStreamServer(svr) 550 551 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(200*time.Millisecond)) 552 defer cancel() 553 cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30211")} 554 cliStream, err := getClientStream(ctx, bidiDesc, cliOpts) 555 assert.Nil(t, err) 556 557 req := getBytes(dataLen) 558 rand.Read(req.Data) 559 560 for i := 0; i < maxSends; i++ { 561 err = cliStream.SendMsg(req) 562 assert.Nil(t, err) 563 } 564 err = cliStream.SendMsg(req) 565 assert.Equal(t, errors.New("stream is already closed"), err) 566 } 567 568 func TestServerStreamFlowControlStop(t *testing.T) { 569 windows := 102400 570 dataLen := 1024 571 maxSends := windows / dataLen 572 waitCh := make(chan struct{}, 1) 573 svrOpts := []server.Option{server.WithAddress("127.0.0.1:30211")} 574 handle := func(s server.Stream) error { 575 rsp := getBytes(dataLen) 576 rand.Read(rsp.Data) 577 for i := 0; i < maxSends; i++ { 578 err := s.SendMsg(rsp) 579 assert.Nil(t, err) 580 } 581 582 finish := make(chan struct{}, 1) 583 go func() { 584 err := s.SendMsg(rsp) 585 assert.Equal(t, errors.New("stream is already closed"), err) 586 finish <- struct{}{} 587 }() 588 589 deadline := time.NewTimer(200 * time.Millisecond) 590 select { 591 case <-deadline.C: 592 case <-finish: 593 assert.Fail(t, "SendMsg should block") 594 } 595 596 waitCh <- struct{}{} 597 return nil 598 } 599 svr := startStreamServer(handle, svrOpts) 600 defer closeStreamServer(svr) 601 602 cliOpts := []client.Option{ 603 client.WithTarget("ip://127.0.0.1:30211"), 604 client.WithMaxWindowSize(uint32(windows)), 605 } 606 _, err := getClientStream(context.Background(), bidiDesc, cliOpts) 607 assert.Nil(t, err) 608 <-waitCh 609 } 610 611 func TestClientStreamSendRecvNoBlock(t *testing.T) { 612 svrOpts := []server.Option{server.WithAddress("127.0.0.1:30210")} 613 handle := func(s server.Stream) error { 614 // Must sleep, to avoid returning before receiving the first packet from the client, 615 // resulting in the processing of the first packet returns an error, 616 // losing the chance for the test client to block on the second SendMsg. 617 time.Sleep(200 * time.Millisecond) 618 return errors.New("test error") 619 } 620 svr := startStreamServer(handle, svrOpts) 621 defer closeStreamServer(svr) 622 623 cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30210")} 624 cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts) 625 assert.Nil(t, err) 626 627 // defaultInitWindowSize = 65535. 628 req := getBytes(65535) 629 err = cliStream.SendMsg(req) 630 assert.Nil(t, err) 631 632 err = cliStream.SendMsg(req) 633 fmt.Println(err) 634 assert.NotNil(t, err) 635 636 rsp := getBytes(1024) 637 err = cliStream.RecvMsg(rsp) 638 assert.NotNil(t, err) 639 } 640 641 func TestServerStreamSendRecvNoBlock(t *testing.T) { 642 svrOpts := []server.Option{server.WithAddress("127.0.0.1:30210")} 643 SendMsgReturn := make(chan struct{}, 1) 644 RecvMsgReturn := make(chan struct{}, 1) 645 handle := func(s server.Stream) error { 646 go func() { 647 msg := getBytes(65535) 648 s.SendMsg(msg) 649 s.SendMsg(msg) 650 SendMsgReturn <- struct{}{} 651 }() 652 go func() { 653 msg := getBytes(1024) 654 s.RecvMsg(msg) 655 s.RecvMsg(msg) 656 RecvMsgReturn <- struct{}{} 657 }() 658 // Must sleep, to avoid the second SendMsg does not enter the waiting window to block. 659 time.Sleep(200 * time.Millisecond) 660 return nil 661 } 662 svr := startStreamServer(handle, svrOpts) 663 defer closeStreamServer(svr) 664 665 cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30210")} 666 _, err := getClientStream(context.Background(), bidiDesc, cliOpts) 667 assert.Nil(t, err) 668 669 <-SendMsgReturn 670 <-RecvMsgReturn 671 } 672 673 func TestClientStreamReturn(t *testing.T) { 674 const ( 675 invalidCompressType = -1 676 dataLen = 1024 677 ) 678 679 svrOpts := []server.Option{ 680 server.WithAddress("127.0.0.1:30211"), 681 server.WithCurrentCompressType(invalidCompressType), 682 } 683 handle := func(s server.Stream) error { 684 req := getBytes(dataLen) 685 s.RecvMsg(req) 686 rsp := req 687 s.SendMsg(rsp) 688 return errs.NewFrameError(101, "expected error") 689 } 690 svr := startStreamServer(handle, svrOpts) 691 defer closeStreamServer(svr) 692 693 cliOpts := []client.Option{ 694 client.WithTarget("ip://127.0.0.1:30211"), 695 client.WithCompressType(invalidCompressType), 696 } 697 698 clientStream, err := getClientStream(context.Background(), clientDesc, cliOpts) 699 assert.Nil(t, err) 700 err = clientStream.SendMsg(getBytes(dataLen)) 701 assert.Nil(t, err) 702 703 rsp := getBytes(dataLen) 704 err = clientStream.RecvMsg(rsp) 705 706 assert.EqualValues(t, int32(101), errs.Code(err.(*errs.Error).Unwrap())) 707 } 708 709 // TestClientSendFailWhenServerUnavailable test when the client blocks 710 // on SendMsg because of flow control, if the server is closed, the client 711 // SendMsg should return. 712 func TestClientSendFailWhenServerUnavailable(t *testing.T) { 713 codec.Register("mock", nil, &fakeCodec{}) 714 tp := &fakeTransport{expectChan: make(chan recvExpect, 1)} 715 tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) { 716 return nil, nil 717 } 718 cs, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "", 719 client.WithProtocol("mock"), 720 client.WithTarget("ip://127.0.0.1:8000"), 721 client.WithCurrentSerializationType(codec.SerializationTypeNoop), 722 client.WithStreamTransport(tp), 723 ) 724 assert.Nil(t, err) 725 assert.NotNil(t, cs) 726 assert.Nil(t, cs.SendMsg(getBytes(65535))) 727 tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) { 728 return nil, errors.New("server is closed") 729 } 730 assert.Eventually( 731 t, 732 func() bool { 733 cs.SendMsg(getBytes(65535)) 734 return true 735 }, 736 time.Second, 737 100*time.Millisecond, 738 ) 739 } 740 741 // TestClientReceiveErrorWhenServerUnavailable tests that the client receives a non-io.EOF 742 // error when the server crashes or the connection is closed, otherwise the client would 743 // mistakenly think that the server closed the stream normally. 744 func TestClientReceiveErrorWhenServerUnavailable(t *testing.T) { 745 codec.Register("mock", nil, &fakeCodec{}) 746 tp := &fakeTransport{expectChan: make(chan recvExpect, 1)} 747 tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) { 748 return nil, nil 749 } 750 cs, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "", 751 client.WithProtocol("mock"), 752 client.WithTarget("ip://127.0.0.1:8000"), 753 client.WithCurrentSerializationType(codec.SerializationTypeNoop), 754 client.WithStreamTransport(tp), 755 ) 756 assert.Nil(t, err) 757 assert.NotNil(t, cs) 758 tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) { 759 return nil, io.EOF 760 } 761 err = cs.RecvMsg(nil) 762 assert.NotEqual(t, io.EOF, err) 763 assert.ErrorIs(t, err, io.EOF) 764 } 765 766 func TestClientNewStreamFail(t *testing.T) { 767 codec.Register("mock", nil, &fakeCodec{}) 768 t.Run("Close Transport when Send Fail", func(t *testing.T) { 769 var isClosed bool 770 tp := &fakeTransport{expectChan: make(chan recvExpect, 1)} 771 tp.send = func() error { 772 return errors.New("client error") 773 } 774 tp.close = func() { 775 isClosed = true 776 } 777 _, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "", 778 client.WithProtocol("mock"), 779 client.WithTarget("ip://127.0.0.1:8000"), 780 client.WithStreamTransport(tp), 781 ) 782 assert.NotNil(t, err) 783 assert.True(t, isClosed) 784 }) 785 t.Run("Close Transport when Recv Fail", func(t *testing.T) { 786 var isClosed bool 787 tp := &fakeTransport{expectChan: make(chan recvExpect, 1)} 788 tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) { 789 m.WithClientRspErr(errors.New("server error")) 790 return nil, nil 791 } 792 tp.close = func() { 793 isClosed = true 794 } 795 _, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "", 796 client.WithProtocol("mock"), 797 client.WithTarget("ip://127.0.0.1:8000"), 798 client.WithStreamTransport(tp), 799 ) 800 assert.NotNil(t, err) 801 assert.True(t, isClosed) 802 }) 803 } 804 805 func TestClientServerCompress(t *testing.T) { 806 var ( 807 dataLen = 1024 808 compressType = codec.CompressTypeSnappy 809 ) 810 svrOpts := []server.Option{ 811 server.WithAddress("127.0.0.1:30211"), 812 } 813 handle := func(s server.Stream) error { 814 assert.Equal(t, compressType, codec.Message(s.Context()).CompressType()) 815 req := getBytes(dataLen) 816 s.RecvMsg(req) 817 rsp := req 818 s.SendMsg(rsp) 819 return nil 820 } 821 svr := startStreamServer(handle, svrOpts) 822 defer closeStreamServer(svr) 823 824 cliOpts := []client.Option{ 825 client.WithTarget("ip://127.0.0.1:30211"), 826 client.WithCompressType(compressType), 827 } 828 829 clientStream, err := getClientStream(context.Background(), clientDesc, cliOpts) 830 assert.Nil(t, err) 831 req := getBytes(dataLen) 832 rand.Read(req.Data) 833 err = clientStream.SendMsg(req) 834 assert.Nil(t, err) 835 836 rsp := getBytes(dataLen) 837 err = clientStream.RecvMsg(rsp) 838 assert.Equal(t, rsp.Data, req.Data) 839 assert.Nil(t, err) 840 }