trpc.group/trpc-go/trpc-go@v1.0.2/stream/server_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package stream_test 15 16 import ( 17 "bytes" 18 "context" 19 "encoding/binary" 20 "errors" 21 "fmt" 22 "io" 23 "math/rand" 24 "sync" 25 "testing" 26 "time" 27 28 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 29 30 "trpc.group/trpc-go/trpc-go/client" 31 "trpc.group/trpc-go/trpc-go/errs" 32 33 trpc "trpc.group/trpc-go/trpc-go" 34 "trpc.group/trpc-go/trpc-go/stream" 35 36 "trpc.group/trpc-go/trpc-go/codec" 37 "trpc.group/trpc-go/trpc-go/server" 38 "trpc.group/trpc-go/trpc-go/transport" 39 40 "github.com/stretchr/testify/assert" 41 ) 42 43 type fakeStreamHandle struct { 44 } 45 46 // StreamHandleFunc Mock StreamHandleFunc method 47 func (fs *fakeStreamHandle) StreamHandleFunc(ctx context.Context, sh server.StreamHandler, req []byte) ([]byte, error) { 48 return nil, nil 49 } 50 51 // Init Mock Init method 52 func (fs *fakeStreamHandle) Init(opts *server.Options) { 53 return 54 } 55 56 type fakeServerTransport struct{} 57 58 type fakeServerCodec struct{} 59 60 // Send Mock Send method 61 func (s *fakeServerTransport) Send(ctx context.Context, rspBuf []byte) error { 62 if string(rspBuf) == "init-error" { 63 return errors.New("init-error") 64 } 65 return nil 66 } 67 68 // Close Mock Close method 69 func (s *fakeServerTransport) Close(ctx context.Context) { 70 return 71 } 72 73 // ListenAndServe Mock ListenAndServe method 74 func (s *fakeServerTransport) ListenAndServe(ctx context.Context, opts ...transport.ListenServeOption) error { 75 76 return nil 77 } 78 79 // Decode Mock codec Decode method 80 func (c *fakeServerCodec) Decode(msg codec.Msg, reqBuf []byte) (reqBody []byte, err error) { 81 return reqBuf, nil 82 } 83 84 // Encode Mock codec Encode method 85 func (c *fakeServerCodec) Encode(msg codec.Msg, rspBody []byte) (rspBuf []byte, err error) { 86 rsp := string(rspBody) 87 if rsp == "encode-error" { 88 return nil, errors.New("server encode response fail") 89 } 90 if msg.StreamID() < uint32(100) { 91 return nil, errors.New("streamID less than 100") 92 } 93 if msg.StreamID() == uint32(101) { 94 return []byte("init-error"), nil 95 } 96 return rspBody, nil 97 } 98 99 func streamHandler(stream server.Stream) error { 100 time.Sleep(time.Second) 101 return nil 102 } 103 104 func errorStreamHandler(stream server.Stream) error { 105 return errors.New("handle fail") 106 } 107 108 type fakeAddr struct { 109 } 110 111 // Network method of Network Mock net.Addr interface 112 func (f *fakeAddr) Network() string { 113 return "tcp" 114 } 115 116 // String method of String Mock net.Addr interface 117 func (f *fakeAddr) String() string { 118 return "127.0.0.01:67891" 119 } 120 121 // TestStreamDispatcherHandleInit Test Stream Dispatcher 122 func TestStreamDispatcherHandleInit(t *testing.T) { 123 codec.Register("fake", &fakeServerCodec{}, nil) 124 125 si := &server.StreamServerInfo{} 126 dispatcher := stream.NewStreamDispatcher() 127 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 128 129 // Init test 130 opts := &server.Options{} 131 ft := &fakeServerTransport{} 132 opts.Transport = ft 133 opts.Codec = codec.GetServer("fake") 134 err := dispatcher.Init(opts) 135 assert.Nil(t, err) 136 assert.Equal(t, opts.Transport, opts.StreamTransport) 137 // StreamHandleFunc msg not nil 138 ctx := context.Background() 139 ctx, msg := codec.WithNewMessage(ctx) 140 rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, nil) 141 assert.Nil(t, rsp) 142 assert.Contains(t, err.Error(), "frameHead is not contained in msg") 143 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 144 // StreamHandleFunc handle init 145 fh := &trpc.FrameHead{} 146 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 147 msg.WithFrameHead(fh) 148 msg.WithStreamID(uint32(100)) 149 msg.WithRemoteAddr(&fakeAddr{}) 150 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) 151 assert.Nil(t, rsp) 152 assert.Equal(t, err, errs.ErrServerNoResponse) 153 154 // StreamHandleFunc handle init with codec encode error 155 msg.WithFrameHead(fh) 156 msg.WithStreamID(uint32(99)) 157 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) 158 assert.Nil(t, rsp) 159 assert.Equal(t, err.Error(), "streamID less than 100") 160 161 // StreamHandleFunc handle init send error 162 msg.WithFrameHead(fh) 163 msg.WithStreamID(uint32(101)) 164 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init-error")) 165 assert.Nil(t, rsp) 166 assert.Contains(t, err.Error(), "init-error") 167 168 // StreamHandleFun handle data to validate streamID was stored 169 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 170 msg.WithFrameHead(fh) 171 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data")) 172 assert.Nil(t, rsp) 173 assert.Equal(t, err, errs.ErrServerNoResponse) 174 175 // StreamHandleFunc handle error 176 msg.WithFrameHead(fh) 177 msg.WithStreamID(100) 178 rsp, err = dispatcher.StreamHandleFunc(ctx, errorStreamHandler, si, []byte("init")) 179 assert.Nil(t, rsp) 180 assert.Equal(t, err, errs.ErrServerNoResponse) 181 time.Sleep(100 * time.Millisecond) 182 } 183 184 // TestStreamDispatcherHandleData test StreamDispatcher Handle data 185 func TestStreamDispatcherHandleData(t *testing.T) { 186 codec.Register("fake", &fakeServerCodec{}, nil) 187 188 si := &server.StreamServerInfo{} 189 dispatcher := stream.NewStreamDispatcher() 190 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 191 192 // Init test 193 opts := &server.Options{} 194 ft := &fakeServerTransport{} 195 opts.Transport = ft 196 opts.Codec = codec.GetServer("fake") 197 err := dispatcher.Init(opts) 198 assert.Nil(t, err) 199 assert.Equal(t, opts.Transport, opts.StreamTransport) 200 201 ctx := context.Background() 202 ctx, msg := codec.WithNewMessage(ctx) 203 fh := &trpc.FrameHead{} 204 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 205 msg.WithFrameHead(fh) 206 msg.WithStreamID(uint32(100)) 207 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 208 addr := &fakeAddr{} 209 msg.WithRemoteAddr(addr) 210 rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) 211 assert.Nil(t, rsp) 212 assert.Equal(t, err, errs.ErrServerNoResponse) 213 214 // handleData normal 215 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 216 msg.WithFrameHead(fh) 217 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data")) 218 assert.Nil(t, rsp) 219 assert.Equal(t, err, errs.ErrServerNoResponse) 220 221 // handleData error no such addr 222 msg.WithRemoteAddr(nil) 223 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 224 msg.WithFrameHead(fh) 225 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data")) 226 assert.Nil(t, rsp) 227 assert.Contains(t, err.Error(), "no such addr") 228 229 // handle data error no such stream id 230 msg.WithRemoteAddr(addr) 231 msg.WithStreamID(uint32(101)) 232 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 233 msg.WithFrameHead(fh) 234 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data")) 235 assert.Nil(t, rsp) 236 assert.Contains(t, err.Error(), "no such stream ID") 237 } 238 239 // TestStreamDispatcherHandleClose test handles Close frame 240 func TestStreamDispatcherHandleClose(t *testing.T) { 241 242 codec.Register("fake", &fakeServerCodec{}, nil) 243 244 si := &server.StreamServerInfo{} 245 dispatcher := stream.NewStreamDispatcher() 246 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 247 248 // Init test 249 opts := &server.Options{} 250 ft := &fakeServerTransport{} 251 opts.Transport = ft 252 opts.Codec = codec.GetServer("fake") 253 err := dispatcher.Init(opts) 254 assert.Nil(t, err) 255 assert.Equal(t, opts.Transport, opts.StreamTransport) 256 257 ctx := context.Background() 258 ctx, msg := codec.WithNewMessage(ctx) 259 fh := &trpc.FrameHead{} 260 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 261 msg.WithFrameHead(fh) 262 msg.WithStreamID(uint32(100)) 263 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 264 265 addr := &fakeAddr{} 266 msg.WithRemoteAddr(addr) 267 msg.WithFrameHead(fh) 268 rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) 269 assert.Nil(t, rsp) 270 assert.Equal(t, err, errs.ErrServerNoResponse) 271 272 // handle close normal 273 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE) 274 msg.WithFrameHead(fh) 275 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("close")) 276 assert.Nil(t, rsp) 277 assert.Equal(t, errs.ErrServerNoResponse, err) 278 279 // handle close no such addr 280 msg.WithFrameHead(fh) 281 msg.WithRemoteAddr(nil) 282 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("close")) 283 assert.Nil(t, rsp) 284 assert.Equal(t, errs.ErrServerNoResponse, err) 285 286 // handle close server rsp err 287 msg.WithRemoteAddr(addr) 288 msg.WithFrameHead(fh) 289 msg.WithServerRspErr(errors.New("server rsp error")) 290 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("close")) 291 assert.Nil(t, rsp) 292 assert.Equal(t, errs.ErrServerNoResponse, err) 293 294 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) 295 msg.WithFrameHead(fh) 296 msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{}) 297 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("feedback")) 298 assert.Nil(t, rsp) 299 assert.Equal(t, err, errs.ErrServerNoResponse) 300 301 fh.StreamFrameType = uint8(8) 302 msg.WithFrameHead(fh) 303 rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("unknown")) 304 assert.Nil(t, rsp) 305 assert.Contains(t, err.Error(), "unknown frame type") 306 } 307 308 // TestServerStreamSendMsg test server receives messages 309 func TestServerStreamSendMsg(t *testing.T) { 310 codec.Register("fake", &fakeServerCodec{}, nil) 311 312 si := &server.StreamServerInfo{} 313 dispatcher := stream.NewStreamDispatcher() 314 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 315 316 // Init test 317 opts := &server.Options{} 318 ft := &fakeServerTransport{} 319 opts.Transport = ft 320 opts.Codec = codec.GetServer("fake") 321 err := dispatcher.Init(opts) 322 assert.Nil(t, err) 323 assert.Equal(t, opts.Transport, opts.StreamTransport) 324 325 // StreamHandleFunc msg not nil 326 ctx := context.Background() 327 ctx, msg := codec.WithNewMessage(ctx) 328 fh := &trpc.FrameHead{} 329 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 330 msg.WithFrameHead(fh) 331 msg.WithStreamID(uint32(100)) 332 msg.WithRemoteAddr(&fakeAddr{}) 333 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 334 335 opts.CurrentCompressType = codec.CompressTypeNoop 336 opts.CurrentSerializationType = codec.SerializationTypeNoop 337 338 sh := func(ss server.Stream) error { 339 ctx = ss.Context() 340 assert.NotNil(t, ctx) 341 err := ss.SendMsg(&codec.Body{Data: []byte("init")}) 342 assert.Nil(t, err) 343 return err 344 } 345 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 346 assert.Nil(t, rsp) 347 assert.Equal(t, err, errs.ErrServerNoResponse) 348 time.Sleep(100 * time.Millisecond) 349 350 opts.CurrentCompressType = 5 351 opts.CurrentSerializationType = codec.SerializationTypeNoop 352 sh = func(ss server.Stream) error { 353 ctx = ss.Context() 354 assert.NotNil(t, ctx) 355 err := ss.SendMsg(&codec.Body{Data: []byte("init")}) 356 assert.NotNil(t, err) 357 return err 358 } 359 dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 360 time.Sleep(200 * time.Millisecond) 361 362 opts.CurrentCompressType = codec.CompressTypeNoop 363 opts.CurrentSerializationType = codec.SerializationTypeNoop 364 sh = func(ss server.Stream) error { 365 ctx = ss.Context() 366 assert.NotNil(t, ctx) 367 err := ss.SendMsg(&codec.Body{Data: []byte("encode-error")}) 368 assert.Contains(t, err.Error(), "server codec Encode") 369 return err 370 } 371 dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 372 time.Sleep(200 * time.Millisecond) 373 374 opts.CurrentCompressType = codec.CompressTypeNoop 375 opts.CurrentSerializationType = codec.SerializationTypeNoop 376 sh = func(ss server.Stream) error { 377 ctx = ss.Context() 378 assert.NotNil(t, ctx) 379 err := ss.SendMsg(&codec.Body{Data: []byte("init-error")}) 380 return err 381 } 382 dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 383 time.Sleep(200 * time.Millisecond) 384 } 385 386 // TestServerStreamRecvMsg test receive message 387 func TestServerStreamRecvMsg(t *testing.T) { 388 codec.Register("fake", &fakeServerCodec{}, nil) 389 390 si := &server.StreamServerInfo{} 391 dispatcher := stream.NewStreamDispatcher() 392 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 393 394 // Init test 395 opts := &server.Options{} 396 ft := &fakeServerTransport{} 397 opts.Transport = ft 398 opts.Codec = codec.GetServer("fake") 399 err := dispatcher.Init(opts) 400 assert.Nil(t, err) 401 assert.Equal(t, opts.Transport, opts.StreamTransport) 402 403 // StreamHandleFunc msg not nil 404 ctx := context.Background() 405 ctx, msg := codec.WithNewMessage(ctx) 406 fh := &trpc.FrameHead{} 407 msg.WithFrameHead(fh) 408 msg.WithStreamID(uint32(100)) 409 msg.WithRemoteAddr(&fakeAddr{}) 410 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 411 opts.CurrentCompressType = codec.CompressTypeNoop 412 opts.CurrentSerializationType = codec.SerializationTypeNoop 413 414 sh := func(ss server.Stream) error { 415 ctx := ss.Context() 416 assert.NotNil(t, ctx) 417 body := &codec.Body{} 418 err := ss.RecvMsg(body) 419 assert.Nil(t, err) 420 assert.Equal(t, string(body.Data), "data") 421 err = ss.RecvMsg(body) 422 assert.Equal(t, err, io.EOF) 423 424 err = ss.RecvMsg(body) 425 assert.Equal(t, err, io.EOF) 426 return err 427 } 428 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 429 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 430 assert.Nil(t, rsp) 431 assert.Equal(t, err, errs.ErrServerNoResponse) 432 // handleData normal 433 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 434 msg.WithFrameHead(fh) 435 rsp, err = dispatcher.StreamHandleFunc(ctx, sh, si, []byte("data")) 436 assert.Nil(t, rsp) 437 assert.Equal(t, err, errs.ErrServerNoResponse) 438 439 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE) 440 msg.WithFrameHead(fh) 441 rsp, err = dispatcher.StreamHandleFunc(ctx, sh, si, []byte("close")) 442 assert.Nil(t, rsp) 443 assert.Equal(t, err, errs.ErrServerNoResponse) 444 445 time.Sleep(300 * time.Millisecond) 446 } 447 448 // TestServerStreamRecvMsgFail test for failure to receive data 449 func TestServerStreamRecvMsgFail(t *testing.T) { 450 codec.Register("fake", &fakeServerCodec{}, nil) 451 si := &server.StreamServerInfo{} 452 dispatcher := stream.NewStreamDispatcher() 453 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 454 // Init test 455 opts := &server.Options{} 456 ft := &fakeServerTransport{} 457 opts.Transport = ft 458 opts.Codec = codec.GetServer("fake") 459 err := dispatcher.Init(opts) 460 assert.Nil(t, err) 461 assert.Equal(t, opts.Transport, opts.StreamTransport) 462 463 // StreamHandleFunc msg not nil 464 ctx := context.Background() 465 ctx, msg := codec.WithNewMessage(ctx) 466 fh := &trpc.FrameHead{} 467 msg.WithFrameHead(fh) 468 msg.WithStreamID(uint32(100)) 469 msg.WithRemoteAddr(&fakeAddr{}) 470 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 471 472 opts.CurrentCompressType = codec.CompressTypeGzip 473 opts.CurrentSerializationType = codec.SerializationTypeNoop 474 475 sh := func(ss server.Stream) error { 476 ctx := ss.Context() 477 assert.NotNil(t, ctx) 478 body := &codec.Body{} 479 err := ss.RecvMsg(body) 480 assert.NotNil(t, err) 481 assert.Contains(t, err.Error(), "server codec Decompress") 482 483 err = ss.RecvMsg(body) 484 assert.NotNil(t, err) 485 assert.Contains(t, err.Error(), "server codec Unmarshal") 486 return err 487 } 488 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 489 msg.WithFrameHead(fh) 490 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 491 assert.Nil(t, rsp) 492 assert.Equal(t, err, errs.ErrServerNoResponse) 493 // handleData normal 494 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 495 msg.WithFrameHead(fh) 496 rsp, err = dispatcher.StreamHandleFunc(ctx, sh, si, []byte("data")) 497 assert.Nil(t, rsp) 498 assert.Equal(t, err, errs.ErrServerNoResponse) 499 } 500 501 // TesthandleError test server error condition 502 func TestHandleError(t *testing.T) { 503 codec.Register("fake", &fakeServerCodec{}, nil) 504 si := &server.StreamServerInfo{} 505 dispatcher := stream.NewStreamDispatcher() 506 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 507 // Init test 508 opts := &server.Options{} 509 ft := &fakeServerTransport{} 510 opts.Transport = ft 511 opts.Codec = codec.GetServer("fake") 512 err := dispatcher.Init(opts) 513 assert.Nil(t, err) 514 assert.Equal(t, opts.Transport, opts.StreamTransport) 515 516 // StreamHandleFunc msg not nil 517 ctx := context.Background() 518 ctx, msg := codec.WithNewMessage(ctx) 519 fh := &trpc.FrameHead{} 520 msg.WithFrameHead(fh) 521 msg.WithStreamID(uint32(100)) 522 msg.WithRemoteAddr(&fakeAddr{}) 523 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) 524 525 opts.CurrentCompressType = codec.CompressTypeGzip 526 opts.CurrentSerializationType = codec.SerializationTypeNoop 527 528 sh := func(ss server.Stream) error { 529 ctx := ss.Context() 530 assert.NotNil(t, ctx) 531 body := &codec.Body{} 532 err := ss.RecvMsg(body) 533 assert.NotNil(t, err) 534 assert.Contains(t, err.Error(), "Connection is closed") 535 return err 536 } 537 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 538 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 539 assert.Nil(t, rsp) 540 assert.Equal(t, err, errs.ErrServerNoResponse) 541 // handleError 542 msg.WithFrameHead(nil) 543 msg.WithServerRspErr(errors.New("Connection is closed")) 544 545 noopSh := func(ss server.Stream) error { 546 return nil 547 } 548 msg.WithFrameHead(fh) 549 rsp, err = dispatcher.StreamHandleFunc(ctx, noopSh, si, nil) 550 assert.Nil(t, rsp) 551 assert.Equal(t, err, errs.ErrServerNoResponse) 552 time.Sleep(100 * time.Millisecond) 553 } 554 555 // TestStreamDispatcherHandleFeedback test handles feedback frame 556 func TestStreamDispatcherHandleFeedback(t *testing.T) { 557 558 codec.Register("fake", &fakeServerCodec{}, nil) 559 si := &server.StreamServerInfo{} 560 561 dispatcher := stream.NewStreamDispatcher() 562 assert.Equal(t, dispatcher, stream.DefaultStreamDispatcher) 563 564 // Init test 565 opts := &server.Options{} 566 ft := &fakeServerTransport{} 567 opts.Transport = ft 568 opts.Codec = codec.GetServer("fake") 569 err := dispatcher.Init(opts) 570 assert.Nil(t, err) 571 assert.Equal(t, opts.Transport, opts.StreamTransport) 572 573 ctx := context.Background() 574 ctx, msg := codec.WithNewMessage(ctx) 575 fh := &trpc.FrameHead{} 576 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 577 msg.WithFrameHead(fh) 578 msg.WithStreamID(uint32(100)) 579 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{InitWindowSize: 10}) 580 581 sh := func(ss server.Stream) error { 582 time.Sleep(time.Second) 583 return nil 584 } 585 586 addr := &fakeAddr{} 587 msg.WithRemoteAddr(addr) 588 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 589 assert.Nil(t, rsp) 590 assert.Equal(t, err, errs.ErrServerNoResponse) 591 592 // handle feedback get server stream fail 593 msg.WithRemoteAddr(nil) 594 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) 595 msg.WithFrameHead(fh) 596 rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback")) 597 assert.Nil(t, rsp) 598 assert.NotNil(t, err) 599 600 // handle feedback invalid stream 601 msg.WithRemoteAddr(addr) 602 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) 603 msg.WithFrameHead(fh) 604 rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback")) 605 assert.Nil(t, rsp) 606 assert.NotNil(t, err) 607 608 // normal feedback 609 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) 610 msg.WithFrameHead(fh) 611 msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: 1000}) 612 rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback")) 613 assert.Nil(t, rsp) 614 assert.Equal(t, err, errs.ErrServerNoResponse) 615 } 616 617 // TestServerFlowControl tests the situation of server-side flow control 618 func TestServerFlowControl(t *testing.T) { 619 codec.Register("fake", &fakeServerCodec{}, nil) 620 si := &server.StreamServerInfo{} 621 dispatcher := stream.NewStreamDispatcher() 622 // Init test 623 opts := &server.Options{} 624 ft := &fakeServerTransport{} 625 opts.Transport = ft 626 opts.Codec = codec.GetServer("fake") 627 err := dispatcher.Init(opts) 628 assert.Nil(t, err) 629 assert.Equal(t, opts.Transport, opts.StreamTransport) 630 // StreamHandleFunc msg not nil 631 ctx := context.Background() 632 ctx, msg := codec.WithNewMessage(ctx) 633 fh := &trpc.FrameHead{} 634 msg.WithFrameHead(fh) 635 msg.WithStreamID(uint32(100)) 636 addr := &fakeAddr{} 637 msg.WithRemoteAddr(addr) 638 msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{InitWindowSize: 65535}) 639 opts.CurrentCompressType = codec.CompressTypeNoop 640 opts.CurrentSerializationType = codec.SerializationTypeNoop 641 var wg sync.WaitGroup 642 wg.Add(1) 643 sh := func(ss server.Stream) error { 644 defer wg.Done() 645 for i := 0; i < 20000; i++ { 646 body := &codec.Body{} 647 err := ss.RecvMsg(body) 648 assert.Nil(t, err) 649 assert.Equal(t, string(body.Data), "data") 650 } 651 return nil 652 } 653 fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT) 654 rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) 655 assert.Nil(t, rsp) 656 assert.Equal(t, err, errs.ErrServerNoResponse) 657 658 // handleData normal 659 for i := 0; i < 20000; i++ { 660 newCtx := context.Background() 661 newCtx, newMsg := codec.WithNewMessage(newCtx) 662 newMsg.WithStreamID(uint32(100)) 663 newMsg.WithRemoteAddr(addr) 664 newFh := &trpc.FrameHead{} 665 newFh.StreamID = uint32(100) 666 newFh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) 667 newMsg.WithFrameHead(newFh) 668 rsp, err := dispatcher.StreamHandleFunc(newCtx, sh, si, []byte("data")) 669 assert.Nil(t, rsp) 670 assert.Equal(t, err, errs.ErrServerNoResponse) 671 } 672 wg.Wait() 673 } 674 675 func TestClientStreamFlowControl(t *testing.T) { 676 svrOpts := []server.Option{server.WithAddress("127.0.0.1:30210")} 677 handle := func(s server.Stream) error { 678 req := getBytes(1024) 679 for i := 0; i < 1000; i++ { 680 err := s.RecvMsg(req) 681 assert.Nil(t, err) 682 } 683 err := s.RecvMsg(req) 684 assert.Equal(t, io.EOF, err) 685 686 rsp := getBytes(1024) 687 copy(rsp.Data, req.Data) 688 for i := 0; i < 1000; i++ { 689 err = s.SendMsg(rsp) 690 assert.Nil(t, err) 691 } 692 return nil 693 } 694 svr := startStreamServer(handle, svrOpts) 695 defer closeStreamServer(svr) 696 697 cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30210")} 698 cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts) 699 assert.Nil(t, err) 700 701 req := getBytes(1024) 702 rand.Read(req.Data) 703 for i := 0; i < 1000; i++ { 704 err = cliStream.SendMsg(req) 705 assert.Nil(t, err) 706 } 707 err = cliStream.CloseSend() 708 assert.Nil(t, err) 709 rsp := getBytes(1024) 710 for i := 0; i < 1000; i++ { 711 err = cliStream.RecvMsg(rsp) 712 assert.Nil(t, err) 713 assert.Equal(t, req, rsp) 714 } 715 err = cliStream.RecvMsg(rsp) 716 assert.Equal(t, io.EOF, err) 717 } 718 719 func TestServerStreamFlowControl(t *testing.T) { 720 svrOpts := []server.Option{server.WithAddress("127.0.0.1:30211")} 721 handle := func(s server.Stream) error { 722 req := getBytes(1024) 723 err := s.RecvMsg(req) 724 assert.Nil(t, err) 725 726 rsp := getBytes(1024) 727 copy(rsp.Data, req.Data) 728 for i := 0; i < 1000; i++ { 729 err := s.SendMsg(rsp) 730 assert.Nil(t, err) 731 } 732 return nil 733 } 734 svr := startStreamServer(handle, svrOpts) 735 defer closeStreamServer(svr) 736 737 cliOpts := []client.Option{client.WithTarget("ip://127.0.0.1:30211")} 738 cliStream, err := getClientStream(context.Background(), bidiDesc, cliOpts) 739 assert.Nil(t, err) 740 741 req := getBytes(1024) 742 rand.Read(req.Data) 743 err = cliStream.SendMsg(req) 744 assert.Nil(t, err) 745 err = cliStream.CloseSend() 746 assert.Nil(t, err) 747 rsp := getBytes(1024) 748 for i := 0; i < 1000; i++ { 749 err = cliStream.RecvMsg(rsp) 750 assert.Nil(t, err) 751 assert.Equal(t, req, rsp) 752 } 753 err = cliStream.RecvMsg(rsp) 754 assert.Equal(t, err, io.EOF) 755 } 756 757 func startStreamServer(handle func(server.Stream) error, opts []server.Option) server.Service { 758 svrOpts := []server.Option{ 759 server.WithProtocol("trpc"), 760 server.WithNetwork("tcp"), 761 server.WithStreamTransport(transport.NewServerStreamTransport(transport.WithReusePort(true))), 762 server.WithTransport(transport.NewServerStreamTransport(transport.WithReusePort(true))), 763 // The server must actively set the serialization method 764 server.WithCurrentSerializationType(codec.SerializationTypeNoop), 765 } 766 svrOpts = append(svrOpts, opts...) 767 svr := server.New(svrOpts...) 768 register(svr, handle) 769 go func() { 770 err := svr.Serve() 771 if err != nil { 772 panic(err) 773 } 774 }() 775 time.Sleep(100 * time.Millisecond) 776 return svr 777 } 778 779 func closeStreamServer(svr server.Service) { 780 ch := make(chan struct{}, 1) 781 svr.Close(ch) 782 <-ch 783 } 784 785 var ( 786 clientDesc = &client.ClientStreamDesc{ 787 StreamName: "streamTest", 788 ClientStreams: true, 789 ServerStreams: false, 790 } 791 serverDesc = &client.ClientStreamDesc{ 792 StreamName: "streamTest", 793 ClientStreams: false, 794 ServerStreams: true, 795 } 796 bidiDesc = &client.ClientStreamDesc{ 797 StreamName: "streamTest", 798 ClientStreams: true, 799 ServerStreams: true, 800 } 801 ) 802 803 func getClientStream(ctx context.Context, desc *client.ClientStreamDesc, opts []client.Option) (client.ClientStream, error) { 804 cli := stream.NewStreamClient() 805 method := "/trpc.test.stream.Greeter/StreamSayHello" 806 cliOpts := []client.Option{ 807 client.WithProtocol("trpc"), 808 client.WithTransport(transport.NewClientTransport()), 809 client.WithStreamTransport(transport.NewClientStreamTransport()), 810 client.WithCurrentSerializationType(codec.SerializationTypeNoop), 811 } 812 cliOpts = append(cliOpts, opts...) 813 return cli.NewStream(ctx, desc, method, cliOpts...) 814 } 815 816 func register(s server.Service, f func(server.Stream) error) { 817 svr := &greeterServiceImpl{f: f} 818 if err := s.Register(&GreeterServer_ServiceDesc, svr); err != nil { 819 panic(fmt.Sprintf("Greeter register error: %v", err)) 820 } 821 } 822 823 type greeterServiceImpl struct { 824 f func(server.Stream) error 825 } 826 827 func (s *greeterServiceImpl) BidiStreamSayHello(stream server.Stream) error { 828 return s.f(stream) 829 } 830 831 func GreeterService_BidiStreamSayHello_Handler(srv interface{}, stream server.Stream) error { 832 return srv.(GreeterService).BidiStreamSayHello(stream) 833 } 834 835 type GreeterService interface { 836 // BidiStreamSayHello Bidi streaming 837 BidiStreamSayHello(server.Stream) error 838 } 839 840 var GreeterServer_ServiceDesc = server.ServiceDesc{ 841 ServiceName: "trpc.test.stream.Greeter", 842 HandlerType: (*GreeterService)(nil), 843 StreamHandle: stream.NewStreamDispatcher(), 844 Streams: []server.StreamDesc{ 845 { 846 StreamName: "/trpc.test.stream.Greeter/StreamSayHello", 847 Handler: GreeterService_BidiStreamSayHello_Handler, 848 ServerStreams: true, 849 }, 850 }, 851 } 852 853 func getBytes(size int) *codec.Body { 854 return &codec.Body{Data: make([]byte, size)} 855 } 856 857 /* --------------- Filter Unit Test -------------*/ 858 859 type wrappedServerStream struct { 860 server.Stream 861 } 862 863 func newWrappedServerStream(s server.Stream) server.Stream { 864 return &wrappedServerStream{s} 865 } 866 867 func (w *wrappedServerStream) RecvMsg(m interface{}) error { 868 err := w.Stream.RecvMsg(m) 869 num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8]) 870 binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1) 871 return err 872 } 873 874 func (w *wrappedServerStream) SendMsg(m interface{}) error { 875 num := binary.LittleEndian.Uint64(m.(*codec.Body).Data[:8]) 876 binary.LittleEndian.PutUint64(m.(*codec.Body).Data[:8], num+1) 877 return w.Stream.SendMsg(m) 878 } 879 880 var ( 881 testKey1 = "hello" 882 testKey2 = "ping" 883 testData = map[string][]byte{ 884 testKey1: []byte("world"), 885 testKey2: []byte("pong"), 886 } 887 ) 888 889 func serverFilterAdd1(ss server.Stream, si *server.StreamServerInfo, 890 handler server.StreamHandler) error { 891 msg := trpc.Message(ss.Context()) 892 meta := msg.ServerMetaData() 893 if v, ok := meta[testKey1]; !ok { 894 return errors.New("meta not exist") 895 } else if !bytes.Equal(v, testData[testKey1]) { 896 return errors.New("meta not match") 897 } 898 err := handler(newWrappedServerStream(ss)) 899 return err 900 } 901 902 func serverFilterAdd2(ss server.Stream, si *server.StreamServerInfo, 903 handler server.StreamHandler) error { 904 msg := trpc.Message(ss.Context()) 905 meta := msg.ServerMetaData() 906 if v, ok := meta[testKey2]; !ok { 907 return errors.New("meta not exist") 908 } else if !bytes.Equal(v, testData[testKey2]) { 909 return errors.New("meta not match") 910 } 911 err := handler(newWrappedServerStream(ss)) 912 return err 913 }