trpc.group/trpc-go/trpc-go@v1.0.3/stream/server_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package stream_test 15 16 import ( 17 "bytes" 18 "context" 19 "encoding/binary" 20 "errors" 21 "fmt" 22 "io" 23 "math/rand" 24 "net" 25 "sync" 26 "testing" 27 "time" 28 29 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 30 31 "trpc.group/trpc-go/trpc-go/client" 32 "trpc.group/trpc-go/trpc-go/errs" 33 34 trpc "trpc.group/trpc-go/trpc-go" 35 "trpc.group/trpc-go/trpc-go/stream" 36 37 "trpc.group/trpc-go/trpc-go/codec" 38 "trpc.group/trpc-go/trpc-go/server" 39 "trpc.group/trpc-go/trpc-go/transport" 40 41 "github.com/stretchr/testify/assert" 42 ) 43 44 type fakeStreamHandle struct { 45 } 46 47 // StreamHandleFunc Mock StreamHandleFunc method 48 func (fs *fakeStreamHandle) StreamHandleFunc(ctx context.Context, sh server.StreamHandler, req []byte) ([]byte, error) { 49 return nil, nil 50 } 51 52 // Init Mock Init method 53 func (fs *fakeStreamHandle) Init(opts *server.Options) { 54 return 55 } 56 57 type fakeServerTransport struct{} 58 59 type fakeServerCodec struct{} 60 61 // Send Mock Send method 62 func (s *fakeServerTransport) Send(ctx context.Context, rspBuf []byte) error { 63 if string(rspBuf) == "init-error" { 64 return errors.New("init-error") 65 } 66 return nil 67 } 68 69 // Close Mock Close method 70 func (s *fakeServerTransport) Close(ctx context.Context) { 71 return 72 } 73 74 // ListenAndServe Mock ListenAndServe method 75 func (s *fakeServerTransport) ListenAndServe(ctx context.Context, opts ...transport.ListenServeOption) error { 76 77 return nil 78 } 79 80 // Decode Mock codec Decode method 81 func (c *fakeServerCodec) Decode(msg codec.Msg, reqBuf []byte) (reqBody []byte, err error) { 82 return reqBuf, nil 83 } 84 85 // Encode Mock codec Encode method 86 func (c *fakeServerCodec) Encode(msg codec.Msg, rspBody []byte) (rspBuf []byte, err error) { 87 rsp := string(rspBody) 88 if rsp == "encode-error" { 89 return nil, errors.New("server encode response fail") 90 } 91 if msg.StreamID() < uint32(100) { 92 return nil, errors.New("streamID less than 100") 93 } 94 if msg.StreamID() == uint32(101) { 95 return []byte("init-error"), nil 96 } 97 return rspBody, nil 98 } 99 100 func streamHandler(stream server.Stream) error { 101 time.Sleep(time.Second) 102 return nil 103 } 104 105 func errorStreamHandler(stream server.Stream) error { 106 return errors.New("handle fail") 107 } 108 109 type fakeAddr struct { 110 } 111 112 // Network method of Network Mock net.Addr interface 113 func (f *fakeAddr) Network() string { 114 return "tcp" 115 } 116 117 // String method of String Mock net.Addr interface 118 func (f *fakeAddr) String() string { 119 return "127.0.0.01:67891" 120 } 121 122 // TestStreamDispatcherHandleInit Test Stream Dispatcher 123 func TestStreamDispatcherHandleInit(t *testing.T) { 124 codec.Register("fake", &fakeServerCodec{}, nil) 125 126 si := &server.StreamServerInfo{} 127 dispatcher := stream.NewStreamDispatcher() 128 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 129 130 // Init test 131 opts := &server.Options{} 132 ft := &fakeServerTransport{} 133 opts.Transport = ft 134 opts.Codec = codec.GetServer("fake") 135 err := dispatcher.Init(opts) 136 assert.Nil(t, err) 137 assert.Equal(t, opts.Transport, opts.StreamTransport) 138 // StreamHandleFunc msg not nil 139 ctx := context.Background() 140 ctx, msg := codec.WithNewMessage(ctx) 141 rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, nil) 142 assert.Nil(t, rsp) 143 assert.Contains(t, err.Error(), "frameHead is not contained in msg") 144 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 145 // StreamHandleFunc handle init 146 fh := &trpc.FrameHead{} 147 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 148 msg.WithFrameHead(fh) 149 msg.WithStreamID(uint32(100)) 150 msg.WithRemoteAddr(&fakeAddr{}) 151 msg.WithLocalAddr(&fakeAddr{}) 152 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) 153 assert.Nil(t, rsp) 154 assert.Equal(t, err, errs.ErrServerNoResponse) 155 156 // StreamHandleFunc handle init with codec encode error 157 msg.WithFrameHead(fh) 158 msg.WithStreamID(uint32(99)) 159 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) 160 assert.Nil(t, rsp) 161 assert.Equal(t, err.Error(), "streamID less than 100") 162 163 // StreamHandleFunc handle init send error 164 msg.WithFrameHead(fh) 165 msg.WithStreamID(uint32(101)) 166 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init-error")) 167 assert.Nil(t, rsp) 168 assert.Contains(t, err.Error(), "init-error") 169 170 // StreamHandleFun handle data to validate streamID was stored 171 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 172 msg.WithFrameHead(fh) 173 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data")) 174 assert.Nil(t, rsp) 175 assert.Equal(t, err, errs.ErrServerNoResponse) 176 177 // StreamHandleFunc handle error 178 msg.WithFrameHead(fh) 179 msg.WithStreamID(100) 180 rsp, err = dispatcher.StreamHandleFunc(ctx, errorStreamHandler, si, []byte("init")) 181 assert.Nil(t, rsp) 182 assert.Equal(t, err, errs.ErrServerNoResponse) 183 time.Sleep(100 * time.Millisecond) 184 } 185 186 // TestStreamDispatcherHandleData test StreamDispatcher Handle data 187 func TestStreamDispatcherHandleData(t *testing.T) { 188 codec.Register("fake", &fakeServerCodec{}, nil) 189 190 si := &server.StreamServerInfo{} 191 dispatcher := stream.NewStreamDispatcher() 192 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 193 194 // Init test 195 opts := &server.Options{} 196 ft := &fakeServerTransport{} 197 opts.Transport = ft 198 opts.Codec = codec.GetServer("fake") 199 err := dispatcher.Init(opts) 200 assert.Nil(t, err) 201 assert.Equal(t, opts.Transport, opts.StreamTransport) 202 203 ctx := context.Background() 204 ctx, msg := codec.WithNewMessage(ctx) 205 fh := &trpc.FrameHead{} 206 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 207 msg.WithFrameHead(fh) 208 msg.WithStreamID(uint32(100)) 209 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 210 addr := &fakeAddr{} 211 msg.WithRemoteAddr(addr) 212 msg.WithLocalAddr(addr) 213 rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) 214 assert.Nil(t, rsp) 215 assert.Equal(t, err, errs.ErrServerNoResponse) 216 217 // handleData normal 218 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 219 msg.WithFrameHead(fh) 220 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data")) 221 assert.Nil(t, rsp) 222 assert.Equal(t, err, errs.ErrServerNoResponse) 223 224 // handleData error no such addr 225 raddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:1") 226 msg.WithRemoteAddr(raddr) 227 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 228 msg.WithFrameHead(fh) 229 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data")) 230 assert.Nil(t, rsp) 231 assert.Contains(t, err.Error(), "no such addr") 232 233 // handle data error no such stream id 234 msg.WithRemoteAddr(addr) 235 msg.WithStreamID(uint32(101)) 236 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 237 msg.WithFrameHead(fh) 238 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data")) 239 assert.Nil(t, rsp) 240 assert.Contains(t, err.Error(), "no such stream ID") 241 } 242 243 // TestStreamDispatcherHandleClose test handles Close frame 244 func TestStreamDispatcherHandleClose(t *testing.T) { 245 246 codec.Register("fake", &fakeServerCodec{}, nil) 247 248 si := &server.StreamServerInfo{} 249 dispatcher := stream.NewStreamDispatcher() 250 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 251 252 // Init test 253 opts := &server.Options{} 254 ft := &fakeServerTransport{} 255 opts.Transport = ft 256 opts.Codec = codec.GetServer("fake") 257 err := dispatcher.Init(opts) 258 assert.Nil(t, err) 259 assert.Equal(t, opts.Transport, opts.StreamTransport) 260 261 ctx := context.Background() 262 ctx, msg := codec.WithNewMessage(ctx) 263 fh := &trpc.FrameHead{} 264 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 265 msg.WithFrameHead(fh) 266 msg.WithStreamID(uint32(100)) 267 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 268 269 addr := &fakeAddr{} 270 msg.WithRemoteAddr(addr) 271 msg.WithLocalAddr(addr) 272 msg.WithFrameHead(fh) 273 rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) 274 assert.Nil(t, rsp) 275 assert.Equal(t, err, errs.ErrServerNoResponse) 276 277 // handle close normal 278 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE) 279 msg.WithFrameHead(fh) 280 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("close")) 281 assert.Nil(t, rsp) 282 assert.Equal(t, errs.ErrServerNoResponse, err) 283 284 // handle close no such addr 285 msg.WithFrameHead(fh) 286 raddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:1") 287 msg.WithRemoteAddr(raddr) 288 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("close")) 289 assert.Nil(t, rsp) 290 assert.Equal(t, errs.ErrServerNoResponse, err) 291 292 // handle close server rsp err 293 msg.WithRemoteAddr(addr) 294 msg.WithFrameHead(fh) 295 msg.WithServerRspErr(errors.New("server rsp error")) 296 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("close")) 297 assert.Nil(t, rsp) 298 assert.Equal(t, errs.ErrServerNoResponse, err) 299 300 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) 301 msg.WithFrameHead(fh) 302 msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{}) 303 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("feedback")) 304 assert.Nil(t, rsp) 305 assert.Equal(t, err, errs.ErrServerNoResponse) 306 307 fh.StreamFrameType = uint8(8) 308 msg.WithFrameHead(fh) 309 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("unknown")) 310 assert.Nil(t, rsp) 311 assert.Contains(t, err.Error(), "unknown frame type") 312 } 313 314 // TestServerStreamSendMsg test server receives messages 315 func TestServerStreamSendMsg(t *testing.T) { 316 codec.Register("fake", &fakeServerCodec{}, nil) 317 318 si := &server.StreamServerInfo{} 319 dispatcher := stream.NewStreamDispatcher() 320 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 321 322 // Init test 323 opts := &server.Options{} 324 ft := &fakeServerTransport{} 325 opts.Transport = ft 326 opts.Codec = codec.GetServer("fake") 327 err := dispatcher.Init(opts) 328 assert.Nil(t, err) 329 assert.Equal(t, opts.Transport, opts.StreamTransport) 330 331 // StreamHandleFunc msg not nil 332 ctx := context.Background() 333 ctx, msg := codec.WithNewMessage(ctx) 334 fh := &trpc.FrameHead{} 335 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 336 msg.WithFrameHead(fh) 337 msg.WithStreamID(uint32(100)) 338 msg.WithRemoteAddr(&fakeAddr{}) 339 msg.WithLocalAddr(&fakeAddr{}) 340 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 341 342 opts.CurrentCompressType = codec.CompressTypeNoop 343 opts.CurrentSerializationType = codec.SerializationTypeNoop 344 345 sh := func(ss server.Stream) error { 346 ctx = ss.Context() 347 assert.NotNil(t, ctx) 348 err := ss.SendMsg(&codec.Body{Data: []byte("init")}) 349 assert.Nil(t, err) 350 return err 351 } 352 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 353 assert.Nil(t, rsp) 354 assert.Equal(t, err, errs.ErrServerNoResponse) 355 time.Sleep(100 * time.Millisecond) 356 357 opts.CurrentCompressType = 5 358 opts.CurrentSerializationType = codec.SerializationTypeNoop 359 sh = func(ss server.Stream) error { 360 ctx = ss.Context() 361 assert.NotNil(t, ctx) 362 err := ss.SendMsg(&codec.Body{Data: []byte("init")}) 363 assert.NotNil(t, err) 364 return err 365 } 366 dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 367 time.Sleep(200 * time.Millisecond) 368 369 opts.CurrentCompressType = codec.CompressTypeNoop 370 opts.CurrentSerializationType = codec.SerializationTypeNoop 371 sh = func(ss server.Stream) error { 372 ctx = ss.Context() 373 assert.NotNil(t, ctx) 374 err := ss.SendMsg(&codec.Body{Data: []byte("encode-error")}) 375 assert.Contains(t, err.Error(), "server codec Encode") 376 return err 377 } 378 dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 379 time.Sleep(200 * time.Millisecond) 380 381 opts.CurrentCompressType = codec.CompressTypeNoop 382 opts.CurrentSerializationType = codec.SerializationTypeNoop 383 sh = func(ss server.Stream) error { 384 ctx = ss.Context() 385 assert.NotNil(t, ctx) 386 err := ss.SendMsg(&codec.Body{Data: []byte("init-error")}) 387 return err 388 } 389 dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 390 time.Sleep(200 * time.Millisecond) 391 } 392 393 // TestServerStreamRecvMsg test receive message 394 func TestServerStreamRecvMsg(t *testing.T) { 395 codec.Register("fake", &fakeServerCodec{}, nil) 396 397 si := &server.StreamServerInfo{} 398 dispatcher := stream.NewStreamDispatcher() 399 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 400 401 // Init test 402 opts := &server.Options{} 403 ft := &fakeServerTransport{} 404 opts.Transport = ft 405 opts.Codec = codec.GetServer("fake") 406 err := dispatcher.Init(opts) 407 assert.Nil(t, err) 408 assert.Equal(t, opts.Transport, opts.StreamTransport) 409 410 // StreamHandleFunc msg not nil 411 ctx := context.Background() 412 ctx, msg := codec.WithNewMessage(ctx) 413 fh := &trpc.FrameHead{} 414 msg.WithFrameHead(fh) 415 msg.WithStreamID(uint32(100)) 416 msg.WithRemoteAddr(&fakeAddr{}) 417 msg.WithLocalAddr(&fakeAddr{}) 418 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 419 opts.CurrentCompressType = codec.CompressTypeNoop 420 opts.CurrentSerializationType = codec.SerializationTypeNoop 421 422 sh := func(ss server.Stream) error { 423 ctx := ss.Context() 424 assert.NotNil(t, ctx) 425 body := &codec.Body{} 426 err := ss.RecvMsg(body) 427 assert.Nil(t, err) 428 assert.Equal(t, string(body.Data), "data") 429 err = ss.RecvMsg(body) 430 assert.Equal(t, err, io.EOF) 431 432 err = ss.RecvMsg(body) 433 assert.Equal(t, err, io.EOF) 434 return err 435 } 436 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 437 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 438 assert.Nil(t, rsp) 439 assert.Equal(t, err, errs.ErrServerNoResponse) 440 // handleData normal 441 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 442 msg.WithFrameHead(fh) 443 rsp, err = dispatcher.StreamHandleFunc(ctx, sh, si, []byte("data")) 444 assert.Nil(t, rsp) 445 assert.Equal(t, err, errs.ErrServerNoResponse) 446 447 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE) 448 msg.WithFrameHead(fh) 449 rsp, err = dispatcher.StreamHandleFunc(ctx, sh, si, []byte("close")) 450 assert.Nil(t, rsp) 451 assert.Equal(t, err, errs.ErrServerNoResponse) 452 453 time.Sleep(300 * time.Millisecond) 454 } 455 456 // TestServerStreamRecvMsgFail test for failure to receive data 457 func TestServerStreamRecvMsgFail(t *testing.T) { 458 codec.Register("fake", &fakeServerCodec{}, nil) 459 si := &server.StreamServerInfo{} 460 dispatcher := stream.NewStreamDispatcher() 461 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 462 // Init test 463 opts := &server.Options{} 464 ft := &fakeServerTransport{} 465 opts.Transport = ft 466 opts.Codec = codec.GetServer("fake") 467 err := dispatcher.Init(opts) 468 assert.Nil(t, err) 469 assert.Equal(t, opts.Transport, opts.StreamTransport) 470 471 // StreamHandleFunc msg not nil 472 ctx := context.Background() 473 ctx, msg := codec.WithNewMessage(ctx) 474 fh := &trpc.FrameHead{} 475 msg.WithFrameHead(fh) 476 msg.WithStreamID(uint32(100)) 477 msg.WithRemoteAddr(&fakeAddr{}) 478 msg.WithLocalAddr(&fakeAddr{}) 479 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 480 481 opts.CurrentCompressType = codec.CompressTypeGzip 482 opts.CurrentSerializationType = codec.SerializationTypeNoop 483 484 sh := func(ss server.Stream) error { 485 ctx := ss.Context() 486 assert.NotNil(t, ctx) 487 body := &codec.Body{} 488 err := ss.RecvMsg(body) 489 assert.NotNil(t, err) 490 assert.Contains(t, err.Error(), "server codec Decompress") 491 492 err = ss.RecvMsg(body) 493 assert.NotNil(t, err) 494 assert.Contains(t, err.Error(), "server codec Unmarshal") 495 return err 496 } 497 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 498 msg.WithFrameHead(fh) 499 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 500 assert.Nil(t, rsp) 501 assert.Equal(t, err, errs.ErrServerNoResponse) 502 // handleData normal 503 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 504 msg.WithFrameHead(fh) 505 rsp, err = dispatcher.StreamHandleFunc(ctx, sh, si, []byte("data")) 506 assert.Nil(t, rsp) 507 assert.Equal(t, err, errs.ErrServerNoResponse) 508 } 509 510 // TesthandleError test server error condition 511 func TestHandleError(t *testing.T) { 512 codec.Register("fake", &fakeServerCodec{}, nil) 513 si := &server.StreamServerInfo{} 514 dispatcher := stream.NewStreamDispatcher() 515 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 516 // Init test 517 opts := &server.Options{} 518 ft := &fakeServerTransport{} 519 opts.Transport = ft 520 opts.Codec = codec.GetServer("fake") 521 err := dispatcher.Init(opts) 522 assert.Nil(t, err) 523 assert.Equal(t, opts.Transport, opts.StreamTransport) 524 525 // StreamHandleFunc msg not nil 526 ctx := context.Background() 527 ctx, msg := codec.WithNewMessage(ctx) 528 fh := &trpc.FrameHead{} 529 msg.WithFrameHead(fh) 530 msg.WithStreamID(uint32(100)) 531 msg.WithRemoteAddr(&fakeAddr{}) 532 msg.WithLocalAddr(&fakeAddr{}) 533 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 534 535 opts.CurrentCompressType = codec.CompressTypeGzip 536 opts.CurrentSerializationType = codec.SerializationTypeNoop 537 538 sh := func(ss server.Stream) error { 539 ctx := ss.Context() 540 assert.NotNil(t, ctx) 541 body := &codec.Body{} 542 err := ss.RecvMsg(body) 543 assert.NotNil(t, err) 544 assert.Contains(t, err.Error(), "Connection is closed") 545 return err 546 } 547 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 548 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 549 assert.Nil(t, rsp) 550 assert.Equal(t, err, errs.ErrServerNoResponse) 551 // handleError 552 msg.WithFrameHead(nil) 553 msg.WithServerRspErr(errors.New("Connection is closed")) 554 555 noopSh := func(ss server.Stream) error { 556 return nil 557 } 558 msg.WithFrameHead(fh) 559 rsp, err = dispatcher.StreamHandleFunc(ctx, noopSh, si, nil) 560 assert.Nil(t, rsp) 561 assert.Equal(t, err, errs.ErrServerNoResponse) 562 time.Sleep(100 * time.Millisecond) 563 } 564 565 // TestStreamDispatcherHandleFeedback test handles feedback frame 566 func TestStreamDispatcherHandleFeedback(t *testing.T) { 567 568 codec.Register("fake", &fakeServerCodec{}, nil) 569 si := &server.StreamServerInfo{} 570 571 dispatcher := stream.NewStreamDispatcher() 572 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 573 574 // Init test 575 opts := &server.Options{} 576 ft := &fakeServerTransport{} 577 opts.Transport = ft 578 opts.Codec = codec.GetServer("fake") 579 err := dispatcher.Init(opts) 580 assert.Nil(t, err) 581 assert.Equal(t, opts.Transport, opts.StreamTransport) 582 583 ctx := context.Background() 584 ctx, msg := codec.WithNewMessage(ctx) 585 fh := &trpc.FrameHead{} 586 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 587 msg.WithFrameHead(fh) 588 msg.WithStreamID(uint32(100)) 589 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{InitWindowSize: 10}) 590 591 sh := func(ss server.Stream) error { 592 time.Sleep(time.Second) 593 return nil 594 } 595 596 addr := &fakeAddr{} 597 msg.WithRemoteAddr(addr) 598 msg.WithLocalAddr(addr) 599 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 600 assert.Nil(t, rsp) 601 assert.Equal(t, err, errs.ErrServerNoResponse) 602 603 // handle feedback get server stream fail 604 raddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:1") 605 msg.WithRemoteAddr(raddr) 606 msg.WithLocalAddr(raddr) 607 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) 608 msg.WithFrameHead(fh) 609 rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback")) 610 assert.Nil(t, rsp) 611 assert.NotNil(t, err) 612 613 // handle feedback invalid stream 614 msg.WithRemoteAddr(addr) 615 msg.WithLocalAddr(addr) 616 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) 617 msg.WithFrameHead(fh) 618 rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback")) 619 assert.Nil(t, rsp) 620 assert.NotNil(t, err) 621 622 // normal feedback 623 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) 624 msg.WithFrameHead(fh) 625 msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: 1000}) 626 rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback")) 627 assert.Nil(t, rsp) 628 assert.Equal(t, err, errs.ErrServerNoResponse) 629 } 630 631 // TestServerFlowControl tests the situation of server-side flow control 632 func TestServerFlowControl(t *testing.T) { 633 codec.Register("fake", &fakeServerCodec{}, nil) 634 si := &server.StreamServerInfo{} 635 dispatcher := stream.NewStreamDispatcher() 636 // Init test 637 opts := &server.Options{} 638 ft := &fakeServerTransport{} 639 opts.Transport = ft 640 opts.Codec = codec.GetServer("fake") 641 err := dispatcher.Init(opts) 642 assert.Nil(t, err) 643 assert.Equal(t, opts.Transport, opts.StreamTransport) 644 // StreamHandleFunc msg not nil 645 ctx := context.Background() 646 ctx, msg := codec.WithNewMessage(ctx) 647 fh := &trpc.FrameHead{} 648 msg.WithFrameHead(fh) 649 msg.WithStreamID(uint32(100)) 650 addr := &fakeAddr{} 651 msg.WithRemoteAddr(addr) 652 msg.WithLocalAddr(addr) 653 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{InitWindowSize: 65535}) 654 opts.CurrentCompressType = codec.CompressTypeNoop 655 opts.CurrentSerializationType = codec.SerializationTypeNoop 656 var wg sync.WaitGroup 657 wg.Add(1) 658 sh := func(ss server.Stream) error { 659 defer wg.Done() 660 for i := 0; i < 20000; i++ { 661 body := &codec.Body{} 662 err := ss.RecvMsg(body) 663 assert.Nil(t, err) 664 assert.Equal(t, string(body.Data), "data") 665 } 666 return nil 667 } 668 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 669 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 670 assert.Nil(t, rsp) 671 assert.Equal(t, err, errs.ErrServerNoResponse) 672 673 // handleData normal 674 for i := 0; i < 20000; i++ { 675 newCtx := context.Background() 676 newCtx, newMsg := codec.WithNewMessage(newCtx) 677 newMsg.WithStreamID(uint32(100)) 678 newMsg.WithRemoteAddr(addr) 679 newMsg.WithLocalAddr(addr) 680 newFh := &trpc.FrameHead{} 681 newFh.StreamID = uint32(100) 682 newFh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 683 newMsg.WithFrameHead(newFh) 684 rsp, err := dispatcher.StreamHandleFunc(newCtx, sh, si, []byte("data")) 685 assert.Nil(t, rsp) 686 assert.Equal(t, err, errs.ErrServerNoResponse) 687 } 688 wg.Wait() 689 } 690 691 func TestClientStreamFlowControl(t *testing.T) { 692 svrOpts := []server.Option{server.WithAddress("127.0.0.1:30210")} 693 handle := func(s server.Stream) error { 694 req := getBytes(1024) 695 for i := 0; i < 1000; i++ { 696 err := s.RecvMsg(req) 697 assert.Nil(t, err) 698 } 699 err := s.RecvMsg(req) 700 assert.Equal(t, io.EOF, err) 701 702 rsp := getBytes(1024) 703 copy(rsp.Data, req.Data) 704 for i := 0; i < 1000; i++ { 705 err = s.SendMsg(rsp) 706 assert.Nil(t, err) 707 } 708 return nil 709 } 710 svr := startStreamServer(handle, svrOpts) 711 defer closeStreamServer(svr) 712 713 cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30210")} 714 cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts) 715 assert.Nil(t, err) 716 717 req := getBytes(1024) 718 rand.Read(req.Data) 719 for i := 0; i < 1000; i++ { 720 err = cliStream.SendMsg(req) 721 assert.Nil(t, err) 722 } 723 err = cliStream.CloseSend() 724 assert.Nil(t, err) 725 rsp := getBytes(1024) 726 for i := 0; i < 1000; i++ { 727 err = cliStream.RecvMsg(rsp) 728 assert.Nil(t, err) 729 assert.Equal(t, req, rsp) 730 } 731 err = cliStream.RecvMsg(rsp) 732 assert.Equal(t, io.EOF, err) 733 } 734 735 func TestServerStreamFlowControl(t *testing.T) { 736 svrOpts := []server.Option{server.WithAddress("127.0.0.1:30211")} 737 handle := func(s server.Stream) error { 738 req := getBytes(1024) 739 err := s.RecvMsg(req) 740 assert.Nil(t, err) 741 742 rsp := getBytes(1024) 743 copy(rsp.Data, req.Data) 744 for i := 0; i < 1000; i++ { 745 err := s.SendMsg(rsp) 746 assert.Nil(t, err) 747 } 748 return nil 749 } 750 svr := startStreamServer(handle, svrOpts) 751 defer closeStreamServer(svr) 752 753 cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30211")} 754 cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts) 755 assert.Nil(t, err) 756 757 req := getBytes(1024) 758 rand.Read(req.Data) 759 err = cliStream.SendMsg(req) 760 assert.Nil(t, err) 761 err = cliStream.CloseSend() 762 assert.Nil(t, err) 763 rsp := getBytes(1024) 764 for i := 0; i < 1000; i++ { 765 err = cliStream.RecvMsg(rsp) 766 assert.Nil(t, err) 767 assert.Equal(t, req, rsp) 768 } 769 err = cliStream.RecvMsg(rsp) 770 assert.Equal(t, err, io.EOF) 771 } 772 773 func startStreamServer(handle func(server.Stream) error, opts []server.Option) server.Service { 774 svrOpts := []server.Option{ 775 server.WithProtocol("trpc"), 776 server.WithNetwork("tcp"), 777 server.WithStreamTransport(transport.NewServerStreamTransport(transport.WithReusePort(true))), 778 server.WithTransport(transport.NewServerStreamTransport(transport.WithReusePort(true))), 779 // The server must actively set the serialization method 780 server.WithCurrentSerializationType(codec.SerializationTypeNoop), 781 } 782 svrOpts = append(svrOpts, opts...) 783 svr := server.New(svrOpts...) 784 register(svr, handle) 785 go func() { 786 err := svr.Serve() 787 if err != nil { 788 panic(err) 789 } 790 }() 791 time.Sleep(100 * time.Millisecond) 792 return svr 793 } 794 795 func closeStreamServer(svr server.Service) { 796 ch := make(chan struct{}, 1) 797 svr.Close(ch) 798 <-ch 799 } 800 801 var ( 802 clientDesc = &client.ClientStreamDesc{ 803 StreamName: "streamTest", 804 ClientStreams: true, 805 ServerStreams: false, 806 } 807 serverDesc = &client.ClientStreamDesc{ 808 StreamName: "streamTest", 809 ClientStreams: false, 810 ServerStreams: true, 811 } 812 bidiDesc = &client.ClientStreamDesc{ 813 StreamName: "streamTest", 814 ClientStreams: true, 815 ServerStreams: true, 816 } 817 ) 818 819 func getClientStream(ctx context.Context, desc *client.ClientStreamDesc, opts []client.Option) (client.ClientStream, error) { 820 cli := stream.NewStreamClient() 821 method := "/trpc.test.stream.Greeter/StreamSayHello" 822 cliOpts := []client.Option{ 823 client.WithProtocol("trpc"), 824 client.WithTransport(transport.NewClientTransport()), 825 client.WithStreamTransport(transport.NewClientStreamTransport()), 826 client.WithCurrentSerializationType(codec.SerializationTypeNoop), 827 } 828 cliOpts = append(cliOpts, opts...) 829 return cli.NewStream(ctx, desc, method, cliOpts...) 830 } 831 832 func register(s server.Service, f func(server.Stream) error) { 833 svr := &greeterServiceImpl{f: f} 834 if err := s.Register(&GreeterServer_ServiceDesc, svr); err != nil { 835 panic(fmt.Sprintf("Greeter register error: %v", err)) 836 } 837 } 838 839 type greeterServiceImpl struct { 840 f func(server.Stream) error 841 } 842 843 func (s *greeterServiceImpl) BidiStreamSayHello(stream server.Stream) error { 844 return s.f(stream) 845 } 846 847 func GreeterService_BidiStreamSayHello_Handler(srv interface{}, stream server.Stream) error { 848 return srv.(GreeterService).BidiStreamSayHello(stream) 849 } 850 851 type GreeterService interface { 852 // BidiStreamSayHello Bidi streaming 853 BidiStreamSayHello(server.Stream) error 854 } 855 856 var GreeterServer_ServiceDesc = server.ServiceDesc{ 857 ServiceName: "trpc.test.stream.Greeter", 858 HandlerType: (*GreeterService)(nil), 859 StreamHandle: stream.NewStreamDispatcher(), 860 Streams: []server.StreamDesc{ 861 { 862 StreamName: "/trpc.test.stream.Greeter/StreamSayHello", 863 Handler: GreeterService_BidiStreamSayHello_Handler, 864 ServerStreams: true, 865 }, 866 }, 867 } 868 869 func getBytes(size int) *codec.Body { 870 return &codec.Body{Data: make([]byte, size)} 871 } 872 873 /* --------------- Filter Unit Test -------------*/ 874 875 type wrappedServerStream struct { 876 server.Stream 877 } 878 879 func newWrappedServerStream(s server.Stream) server.Stream { 880 return &wrappedServerStream{s} 881 } 882 883 func (w *wrappedServerStream) RecvMsg(m interface{}) error { 884 err := w.Stream.RecvMsg(m) 885 num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8]) 886 binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1) 887 return err 888 } 889 890 func (w *wrappedServerStream) SendMsg(m interface{}) error { 891 num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8]) 892 binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1) 893 return w.Stream.SendMsg(m) 894 } 895 896 var ( 897 testKey1 = "hello" 898 testKey2 = "ping" 899 testData = map[string][]byte{ 900 testKey1: []byte("world"), 901 testKey2: []byte("pong"), 902 } 903 ) 904 905 func serverFilterAdd1(ss server.Stream, si *server.StreamServerInfo, 906 handler server.StreamHandler) error { 907 msg := trpc.Message(ss.Context()) 908 meta := msg.ServerMetaData() 909 if v, ok := meta[testKey1]; !ok { 910 return errors.New("meta not exist") 911 } else if !bytes.Equal(v, testData[testKey1]) { 912 return errors.New("meta not match") 913 } 914 err := handler(newWrappedServerStream(ss)) 915 return err 916 } 917 918 func serverFilterAdd2(ss server.Stream, si *server.StreamServerInfo, 919 handler server.StreamHandler) error { 920 msg := trpc.Message(ss.Context()) 921 meta := msg.ServerMetaData() 922 if v, ok := meta[testKey2]; !ok { 923 return errors.New("meta not exist") 924 } else if !bytes.Equal(v, testData[testKey2]) { 925 return errors.New("meta not match") 926 } 927 err := handler(newWrappedServerStream(ss)) 928 return err 929 } 930 931 // TestServerStreamAllFailWhenConnectionClosedAndReconnect tests when a connection 932 // is closed and then reconnected (with the same client IP and port), both SendMsg 933 // and RecvMsg on the server side result in errors. 934 func TestServerStreamAllFailWhenConnectionClosedAndReconnect(t *testing.T) { 935 ch := make(chan struct{}) 936 addr := "127.0.0.1:30211" 937 svrOpts := []server.Option{ 938 server.WithAddress(addr), 939 } 940 handle := func(s server.Stream) error { 941 <-ch 942 err := s.SendMsg(getBytes(100)) 943 assert.Equal(t, errs.Code(err), errs.RetServerSystemErr) 944 err = s.RecvMsg(getBytes(100)) 945 assert.Equal(t, errs.Code(err), errs.RetServerSystemErr) 946 ch <- struct{}{} 947 return nil 948 } 949 svr := startStreamServer(handle, svrOpts) 950 defer closeStreamServer(svr) 951 952 // Init a stream 953 dialer := net.Dialer{ 954 LocalAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 20001}, 955 } 956 conn, err := dialer.Dial("tcp", addr) 957 assert.Nil(t, err) 958 _, msg := codec.WithNewMessage(context.Background()) 959 msg.WithFrameHead(&trpc.FrameHead{ 960 FrameType: uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME), 961 StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT), 962 }) 963 msg.WithClientRPCName("/trpc.test.stream.Greeter/StreamSayHello") 964 initReq, err := trpc.DefaultClientCodec.Encode(msg, nil) 965 assert.Nil(t, err) 966 _, err = conn.Write(initReq) 967 assert.Nil(t, err) 968 969 // Close the connection 970 conn.Close() 971 972 // Dial another connection using the same client ip:port 973 time.Sleep(time.Millisecond * 200) 974 _, err = dialer.Dial("tcp", addr) 975 assert.Nil(t, err) 976 977 // Notify server to send and receive 978 ch <- struct{}{} 979 980 // Wait server sending and receiving result assertion 981 <-ch 982 } 983 984 func TestSameClientAddrDiffServerAddr(t *testing.T) { 985 dp := stream.NewStreamDispatcher() 986 dp.Init(&server.Options{ 987 Transport: &fakeServerTransport{}, 988 Codec: &fakeServerCodec{}, 989 CurrentSerializationType: codec.SerializationTypeNoop}) 990 wg := sync.WaitGroup{} 991 992 initFrame := func(localAddr, remoteAddr net.Addr) { 993 ctx, msg := codec.WithNewMessage(context.Background()) 994 msg.WithFrameHead(&trpc.FrameHead{StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)}) 995 msg.WithStreamID(200) 996 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 997 msg.WithRemoteAddr(remoteAddr) 998 msg.WithLocalAddr(localAddr) 999 msg.WithSerializationType(codec.SerializationTypeNoop) 1000 wg.Add(1) 1001 rsp, err := dp.StreamHandleFunc( 1002 ctx, 1003 func(s server.Stream) error { 1004 err := s.RecvMsg(&codec.Body{}) 1005 assert.Nil(t, err) 1006 wg.Done() 1007 return nil 1008 }, 1009 &server.StreamServerInfo{}, 1010 []byte("init")) 1011 assert.Nil(t, rsp) 1012 assert.Equal(t, errs.ErrServerNoResponse, err) 1013 } 1014 1015 dataFrame := func(localAddr, remoteAddr net.Addr) { 1016 ctx, msg := codec.WithNewMessage(context.Background()) 1017 msg.WithFrameHead(&trpc.FrameHead{StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)}) 1018 msg.WithStreamID(200) 1019 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 1020 msg.WithRemoteAddr(remoteAddr) 1021 msg.WithLocalAddr(localAddr) 1022 rsp, err := dp.StreamHandleFunc(ctx, nil, &server.StreamServerInfo{}, []byte("data")) 1023 assert.Nil(t, rsp) 1024 assert.Equal(t, errs.ErrServerNoResponse, err) 1025 } 1026 1027 clientAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:9000") 1028 serverAddr1, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:10001") 1029 serverAddr2, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:10002") 1030 initFrame(serverAddr1, clientAddr) 1031 initFrame(serverAddr2, clientAddr) 1032 dataFrame(serverAddr1, clientAddr) 1033 dataFrame(serverAddr2, clientAddr) 1034 1035 c := make(chan struct{}) 1036 go func() { 1037 defer close(c) 1038 wg.Wait() 1039 }() 1040 select { 1041 case <-c: // completed normally 1042 case <-time.After(time.Millisecond * 500): // timed out 1043 assert.FailNow(t, "server did not receive data frame") 1044 } 1045 }