trpc.group/trpc-go/trpc-go@v1.0.3/transport/server_transport_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package transport_test 15 16 import ( 17 "context" 18 "encoding/binary" 19 "encoding/json" 20 "errors" 21 "fmt" 22 "net" 23 "runtime" 24 "sync" 25 "testing" 26 "time" 27 28 "github.com/stretchr/testify/assert" 29 "github.com/stretchr/testify/require" 30 31 _ "trpc.group/trpc-go/trpc-go" 32 "trpc.group/trpc-go/trpc-go/errs" 33 "trpc.group/trpc-go/trpc-go/transport" 34 ) 35 36 func TestNewServerTransport(t *testing.T) { 37 st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute)) 38 assert.NotNil(t, st) 39 } 40 41 func TestTCPListenAndServe(t *testing.T) { 42 var addr = getFreeAddr("tcp4") 43 44 // Wait until server transport is ready. 45 wg := sync.WaitGroup{} 46 wg.Add(1) 47 go func() { 48 defer wg.Done() 49 st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute)) 50 err := st.ListenAndServe(context.Background(), 51 transport.WithListenNetwork("tcp4"), 52 transport.WithListenAddress(addr), 53 transport.WithHandler(&errorHandler{}), 54 transport.WithServerFramerBuilder(&framerBuilder{}), 55 transport.WithServiceName("test name"), 56 ) 57 58 if err != nil { 59 t.Logf("ListenAndServe fail:%v", err) 60 } 61 }() 62 wg.Wait() 63 64 // Round trip. 65 req := &helloRequest{ 66 Name: "trpc", 67 Msg: "HelloWorld", 68 } 69 70 data, err := json.Marshal(req) 71 if err != nil { 72 t.Fatalf("json marshal fail:%v", err) 73 } 74 lenData := make([]byte, 4) 75 binary.BigEndian.PutUint32(lenData, uint32(len(data))) 76 77 reqData := append(lenData, data...) 78 79 ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond) 80 defer f() 81 82 _, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp4"), 83 transport.WithDialAddress(addr), 84 transport.WithClientFramerBuilder(&framerBuilder{})) 85 assert.NotNil(t, err) 86 } 87 88 func TestTCPTLSListenAndServe(t *testing.T) { 89 addr := getFreeAddr("tcp") 90 91 // Wait until server transport ready. 92 wg := &sync.WaitGroup{} 93 wg.Add(1) 94 go func() { 95 defer wg.Done() 96 st := transport.NewServerTransport() 97 err := st.ListenAndServe(context.Background(), 98 transport.WithListenNetwork("tcp"), 99 transport.WithListenAddress(addr), 100 transport.WithHandler(&echoHandler{}), 101 transport.WithServerFramerBuilder(&framerBuilder{}), 102 transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"), 103 ) 104 105 if err != nil { 106 t.Logf("ListenAndServe fail:%v", err) 107 } 108 }() 109 wg.Wait() 110 111 // Round trip. 112 req := &helloRequest{ 113 Name: "trpc", 114 Msg: "HelloWorld", 115 } 116 117 data, err := json.Marshal(req) 118 if err != nil { 119 t.Fatalf("json marshal fail:%v", err) 120 } 121 lenData := make([]byte, 4) 122 binary.BigEndian.PutUint32(lenData, uint32(len(data))) 123 124 reqData := append(lenData, data...) 125 126 ctx, f := context.WithTimeout(context.Background(), 200*time.Millisecond) 127 defer f() 128 129 _, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp"), 130 transport.WithDialAddress(addr), 131 transport.WithClientFramerBuilder(&framerBuilder{}), 132 transport.WithDialTLS("../testdata/client.crt", "../testdata/client.key", "../testdata/ca.pem", "localhost")) 133 assert.Nil(t, err) 134 135 _, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp"), 136 transport.WithDialAddress(addr), 137 transport.WithClientFramerBuilder(&framerBuilder{}), 138 transport.WithDialTLS("../testdata/client.crt", "../testdata/client.key", "none", "")) 139 assert.Nil(t, err) 140 } 141 142 func TestHandleError(t *testing.T) { 143 var addr = getFreeAddr("udp4") 144 145 // Wait until server transport is ready. 146 wg := &sync.WaitGroup{} 147 wg.Add(1) 148 go func() { 149 defer wg.Done() 150 err := transport.ListenAndServe( 151 transport.WithListenNetwork("udp4"), 152 transport.WithListenAddress(addr), 153 transport.WithHandler(&errorHandler{}), 154 transport.WithServerFramerBuilder(&framerBuilder{}), 155 ) 156 157 if err != nil { 158 t.Logf("test fail:%v", err) 159 } 160 }() 161 wg.Wait() 162 163 // Round trip. 164 req := &helloRequest{ 165 Name: "trpc", 166 Msg: "HelloWorld", 167 } 168 169 data, err := json.Marshal(req) 170 if err != nil { 171 t.Fatalf("test fail:%v", err) 172 } 173 lenData := make([]byte, 4) 174 binary.BigEndian.PutUint32(lenData, uint32(len(data))) 175 176 reqData := append(lenData, data...) 177 178 ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond) 179 defer f() 180 _, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("udp4"), 181 transport.WithDialAddress(addr), 182 transport.WithClientFramerBuilder(&framerBuilder{})) 183 assert.NotNil(t, err) 184 } 185 186 func TestNewServerTransport_NotSupport(t *testing.T) { 187 st := transport.NewServerTransport() 188 err := st.ListenAndServe(context.Background(), transport.WithListenNetwork("unix")) 189 assert.NotNil(t, err) 190 191 err = st.ListenAndServe(context.Background(), transport.WithListenNetwork("xxx")) 192 assert.NotNil(t, err) 193 } 194 195 func TestServerTransport_ListenAndServeUDP(t *testing.T) { 196 // NoReusePort 197 st := transport.NewServerTransport(transport.WithReusePort(false), 198 transport.WithKeepAlivePeriod(time.Minute)) 199 err := st.ListenAndServe( 200 context.Background(), 201 transport.WithListenNetwork("udp"), 202 transport.WithServerFramerBuilder(&framerBuilder{}), 203 ) 204 assert.Nil(t, err) 205 206 st = transport.NewServerTransport(transport.WithReusePort(true)) 207 err = st.ListenAndServe( 208 context.Background(), 209 transport.WithListenNetwork("udp"), 210 transport.WithServerFramerBuilder(&framerBuilder{}), 211 ) 212 assert.Nil(t, err) 213 214 st = transport.NewServerTransport(transport.WithReusePort(true)) 215 err = st.ListenAndServe( 216 context.Background(), 217 transport.WithListenNetwork("ip"), 218 transport.WithServerFramerBuilder(&framerBuilder{}), 219 ) 220 assert.NotNil(t, err) 221 } 222 223 func TestServerTransport_ListenAndServe(t *testing.T) { 224 // NoFramerBuilder 225 st := transport.NewServerTransport(transport.WithReusePort(false)) 226 err := st.ListenAndServe(context.Background(), transport.WithListenNetwork("tcp")) 227 assert.NotNil(t, err) 228 229 fb := transport.GetFramerBuilder("trpc") 230 // NoReusePort 231 st = transport.NewServerTransport(transport.WithReusePort(false)) 232 err = st.ListenAndServe(context.Background(), 233 transport.WithListenNetwork("tcp"), 234 transport.WithServerFramerBuilder(fb)) 235 assert.Nil(t, err) 236 237 // ReusePort 238 st = transport.NewServerTransport(transport.WithReusePort(true)) 239 err = st.ListenAndServe(context.Background(), 240 transport.WithListenNetwork("tcp"), 241 transport.WithServerFramerBuilder(fb)) 242 assert.Nil(t, err) 243 244 // Listener 245 lis, err := net.Listen("tcp", getFreeAddr("tcp")) 246 assert.Nil(t, err) 247 st = transport.NewServerTransport() 248 err = st.ListenAndServe(context.Background(), 249 transport.WithListener(lis), 250 transport.WithServerFramerBuilder(fb)) 251 assert.Nil(t, err) 252 lis.Close() 253 254 // ReusePort + Listen Error 255 st = transport.NewServerTransport(transport.WithReusePort(true)) 256 err = st.ListenAndServe(context.Background(), 257 transport.WithListenNetwork("tcperror"), 258 transport.WithServerFramerBuilder(fb)) 259 assert.NotNil(t, err) 260 261 // context cancel 262 ctx, cancel := context.WithCancel(context.Background()) 263 cancel() 264 st = transport.NewServerTransport(transport.WithReusePort(true)) 265 err = st.ListenAndServe(ctx, transport.WithListenNetwork("tcp"), transport.WithServerFramerBuilder(fb)) 266 assert.Nil(t, err) 267 } 268 269 func TestServerTransport_ListenAndServeBothUDPAndTCP(t *testing.T) { 270 fb := transport.GetFramerBuilder("trpc") 271 // Empty network. 272 network := "" 273 st := transport.NewServerTransport() 274 err := st.ListenAndServe(context.Background(), transport.WithListenNetwork(network)) 275 assert.EqualError(t, err, "server transport: not support network type "+network) 276 277 // Another unknown wrong input. 278 network = "wrong_type" 279 st = transport.NewServerTransport() 280 err = st.ListenAndServe(context.Background(), transport.WithListenNetwork(network)) 281 assert.EqualError(t, err, "server transport: not support network type "+network) 282 283 // Right input. 284 network = "tcp,udp" 285 // No reuse. 286 st = transport.NewServerTransport(transport.WithReusePort(false)) 287 err = st.ListenAndServe(context.Background(), 288 transport.WithListenNetwork(network), 289 transport.WithServerFramerBuilder(fb)) 290 assert.Nil(t, err) 291 } 292 293 // TestTCPListenAndServeAsync tests asynchronous server process. 294 func TestTCPListenAndServeAsync(t *testing.T) { 295 var addr = getFreeAddr("tcp4") 296 297 // Wait until server transport is ready. 298 wg := sync.WaitGroup{} 299 wg.Add(1) 300 go func() { 301 defer wg.Done() 302 st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute)) 303 err := st.ListenAndServe(context.Background(), 304 transport.WithListenNetwork("tcp4"), 305 transport.WithListenAddress(addr), 306 transport.WithHandler(&errorHandler{}), 307 transport.WithServerFramerBuilder(&framerBuilder{}), 308 transport.WithServerAsync(true), 309 transport.WithWritev(true), 310 ) 311 312 if err != nil { 313 t.Logf("ListenAndServe fail:%v", err) 314 } 315 }() 316 wg.Wait() 317 318 // round trip 319 req := &helloRequest{ 320 Name: "trpc", 321 Msg: "HelloWorld", 322 } 323 324 data, err := json.Marshal(req) 325 if err != nil { 326 t.Fatalf("json marshal fail:%v", err) 327 } 328 lenData := make([]byte, 4) 329 binary.BigEndian.PutUint32(lenData, uint32(len(data))) 330 331 reqData := append(lenData, data...) 332 333 ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond) 334 defer f() 335 336 _, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp4"), 337 transport.WithDialAddress(addr), 338 transport.WithClientFramerBuilder(&framerBuilder{})) 339 assert.NotNil(t, err) 340 } 341 342 // TestTCPListenAndServerRoutinePool tests serving with goroutine pool. 343 func TestTCPListenAndServerRoutinePool(t *testing.T) { 344 var addr = getFreeAddr("tcp4") 345 346 // Wait until server transport is ready. 347 wg := sync.WaitGroup{} 348 wg.Add(1) 349 go func() { 350 defer wg.Done() 351 st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute)) 352 err := st.ListenAndServe(context.Background(), 353 transport.WithListenNetwork("tcp4"), 354 transport.WithListenAddress(addr), 355 transport.WithHandler(&errorHandler{}), 356 transport.WithServerFramerBuilder(&framerBuilder{}), 357 transport.WithServerAsync(true), 358 transport.WithMaxRoutines(100), 359 ) 360 361 if err != nil { 362 t.Logf("ListenAndServe fail:%v", err) 363 } 364 }() 365 wg.Wait() 366 367 // round trip 368 req := &helloRequest{ 369 Name: "trpc", 370 Msg: "HelloWorld", 371 } 372 373 data, err := json.Marshal(req) 374 if err != nil { 375 t.Fatalf("json marshal fail:%v", err) 376 } 377 lenData := make([]byte, 4) 378 binary.BigEndian.PutUint32(lenData, uint32(len(data))) 379 380 reqData := append(lenData, data...) 381 382 ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond) 383 defer f() 384 385 _, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp4"), 386 transport.WithDialAddress(addr), 387 transport.WithClientFramerBuilder(&framerBuilder{})) 388 assert.NotNil(t, err) 389 } 390 391 func TestWithReusePort(t *testing.T) { 392 opts := &transport.ServerTransportOptions{} 393 require.False(t, opts.ReusePort) 394 395 opt := transport.WithReusePort(true) 396 require.NotNil(t, opt) 397 opt(opts) 398 if runtime.GOOS != "windows" { 399 require.True(t, opts.ReusePort) 400 } else { 401 require.False(t, opts.ReusePort) 402 } 403 404 opt = transport.WithReusePort(false) 405 require.NotNil(t, opt) 406 opt(opts) 407 require.False(t, opts.ReusePort) 408 } 409 410 func TestWithRecvMsgChannelSize(t *testing.T) { 411 opt := transport.WithRecvMsgChannelSize(1000) 412 assert.NotNil(t, opt) 413 opts := &transport.ServerTransportOptions{} 414 opt(opts) 415 assert.Equal(t, 1000, opts.RecvMsgChannelSize) 416 } 417 418 func TestWithSendMsgChannelSize(t *testing.T) { 419 opt := transport.WithSendMsgChannelSize(1000) 420 assert.NotNil(t, opt) 421 opts := &transport.ServerTransportOptions{} 422 opt(opts) 423 assert.Equal(t, 1000, opts.SendMsgChannelSize) 424 } 425 426 func TestWithRecvUDPPacketBufferSize(t *testing.T) { 427 opt := transport.WithRecvUDPPacketBufferSize(1000) 428 assert.NotNil(t, opt) 429 opts := &transport.ServerTransportOptions{} 430 opt(opts) 431 assert.Equal(t, 1000, opts.RecvUDPPacketBufferSize) 432 } 433 434 func TestWithRecvUDPRawSocketBufSize(t *testing.T) { 435 opt := transport.WithRecvUDPRawSocketBufSize(1000) 436 assert.NotNil(t, opt) 437 opts := &transport.ServerTransportOptions{} 438 opt(opts) 439 assert.Equal(t, 1000, opts.RecvUDPRawSocketBufSize) 440 } 441 442 func TestWithIdleTimeout(t *testing.T) { 443 opt := transport.WithIdleTimeout(time.Second) 444 assert.NotNil(t, opt) 445 opts := &transport.ServerTransportOptions{} 446 opt(opts) 447 assert.Equal(t, time.Second, opts.IdleTimeout) 448 } 449 450 func TestWithKeepAlivePeriod(t *testing.T) { 451 opt := transport.WithKeepAlivePeriod(time.Minute) 452 assert.NotNil(t, opt) 453 opts := &transport.ServerTransportOptions{} 454 opt(opts) 455 assert.Equal(t, time.Minute, opts.KeepAlivePeriod) 456 } 457 458 func TestWithServeTLS(t *testing.T) { 459 opt := transport.WithServeTLS("certfile", "keyfile", "") 460 assert.NotNil(t, opt) 461 opts := &transport.ListenServeOptions{} 462 opt(opts) 463 assert.Equal(t, "certfile", opts.TLSCertFile) 464 assert.Equal(t, "keyfile", opts.TLSKeyFile) 465 } 466 467 // TestWithServeAsync tests setting server async. 468 func TestWithServeAsync(t *testing.T) { 469 opt := transport.WithServerAsync(true) 470 assert.NotNil(t, opt) 471 opts := &transport.ListenServeOptions{} 472 opt(opts) 473 assert.Equal(t, true, opts.ServerAsync) 474 } 475 476 // TestWithWritev tests setting writev. 477 func TestWithWritev(t *testing.T) { 478 opt := transport.WithWritev(true) 479 assert.NotNil(t, opt) 480 opts := &transport.ListenServeOptions{} 481 opt(opts) 482 assert.Equal(t, true, opts.Writev) 483 } 484 485 // TestWithMaxRoutine tests setting max number of goroutines. 486 func TestWithMaxRoutine(t *testing.T) { 487 opt := transport.WithMaxRoutines(100) 488 assert.NotNil(t, opt) 489 opts := &transport.ListenServeOptions{} 490 opt(opts) 491 assert.Equal(t, 100, opts.Routines) 492 } 493 494 // TestTCPServerClosed tests if TCP listener can be closed immediately. 495 func TestTCPListenerClosed(t *testing.T) { 496 err := tryCloseTCPListener(false) 497 if err != nil { 498 t.Errorf("close tcp listener err: %v", err) 499 } 500 } 501 502 // TestTCPListenerClosed_WithReuseport tests if TCP listener can be closed immediately. 503 func TestTCPListenerClosed_WithReuseport(t *testing.T) { 504 err := tryCloseTCPListener(true) 505 if err != nil { 506 t.Errorf("close tcp listener (with reuseport) err: %v", err) 507 } 508 } 509 510 func tryCloseTCPListener(reuseport bool) error { 511 port, err := getFreePort("tcp") 512 if err != nil { 513 return fmt.Errorf("get freeport error: %v", err) 514 } 515 516 ctx := context.Background() 517 ctx, cancel := context.WithCancel(ctx) 518 519 var prepareErr error 520 wg := sync.WaitGroup{} 521 wg.Add(1) 522 go func() { 523 defer wg.Done() 524 st := transport.NewServerTransport(transport.WithReusePort(reuseport)) 525 err := st.ListenAndServe(ctx, 526 transport.WithListenNetwork("tcp"), 527 transport.WithListenAddress(fmt.Sprintf(":%d", port)), 528 transport.WithHandler(&echoHandler{}), 529 transport.WithServerFramerBuilder(&framerBuilder{}), 530 ) 531 if err != nil { 532 prepareErr = err 533 } 534 }() 535 wg.Wait() 536 537 if prepareErr != nil { 538 cancel() 539 return fmt.Errorf("prepare listener error: %v", prepareErr) 540 } 541 542 // First time dial, should work. 543 conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", port)) 544 if err != nil { 545 cancel() 546 return fmt.Errorf("tcp dial error: %v", err) 547 } 548 conn.Close() 549 550 // Notify and wait server close. 551 cancel() 552 time.Sleep(5 * time.Millisecond) 553 554 // Second time dial, must fail. 555 _, err = net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), 10*time.Millisecond) 556 if err == nil { 557 return fmt.Errorf("tcp dial (2nd time) want error") 558 } 559 return nil 560 } 561 562 func TestGetListenersFds(t *testing.T) { 563 ListenFds := transport.GetListenersFds() 564 assert.NotNil(t, ListenFds) 565 } 566 567 var savedListenerPort int 568 569 func TestSaveListener(t *testing.T) { 570 port, err := getFreePort("tcp") 571 if err != nil { 572 t.Fatalf("get freeport error: %v", err) 573 } 574 err = transport.SaveListener(NewPacketConn{}) 575 assert.NotNil(t, err) 576 577 newListener, _ := net.Listen("tcp", fmt.Sprintf(":%d", port)) 578 err = transport.SaveListener(newListener) 579 assert.Nil(t, err) 580 savedListenerPort = port 581 } 582 583 func TestTCPSeverErr(t *testing.T) { 584 st := transport.NewServerTransport() 585 err := st.ListenAndServe(context.Background(), 586 transport.WithListenNetwork("tcp"), 587 transport.WithListenAddress(getFreeAddr("tcp")), 588 transport.WithHandler(&echoHandler{}), 589 transport.WithServerFramerBuilder(&framerBuilder{})) 590 assert.Nil(t, err) 591 } 592 593 func TestUDPServerErr(t *testing.T) { 594 st := transport.NewServerTransport() 595 596 err := st.ListenAndServe(context.Background(), 597 transport.WithListenNetwork("udp"), 598 transport.WithListenAddress(getFreeAddr("udp")), 599 transport.WithHandler(&echoHandler{}), 600 transport.WithServerFramerBuilder(&framerBuilder{})) 601 assert.Nil(t, err) 602 } 603 604 type fakeListen struct { 605 } 606 607 func (c *fakeListen) Accept() (net.Conn, error) { 608 return nil, &netError{errors.New("网络失败")} 609 } 610 func (c *fakeListen) Close() error { 611 return nil 612 } 613 614 func (c *fakeListen) Addr() net.Addr { 615 return nil 616 } 617 618 func TestTCPServerConErr(t *testing.T) { 619 go func() { 620 fb := transport.GetFramerBuilder("trpc") 621 st := transport.NewServerTransport() 622 err := st.ListenAndServe(context.Background(), 623 transport.WithListener(&fakeListen{}), 624 transport.WithServerFramerBuilder(fb)) 625 if err != nil { 626 t.Logf("ListenAndServe fail:%v", err) 627 } 628 }() 629 } 630 631 func TestUDPServerConErr(t *testing.T) { 632 fb := transport.GetFramerBuilder("trpc") 633 st := transport.NewServerTransport() 634 err := st.ListenAndServe(context.Background(), 635 transport.WithListenNetwork("udp"), 636 transport.WithListenAddress(getFreeAddr("udp")), 637 transport.WithServerFramerBuilder(fb)) 638 if err != nil { 639 t.Fatalf("ListenAndServe fail:%v", err) 640 } 641 } 642 643 func getFreePort(network string) (int, error) { 644 if network == "tcp" || network == "tcp4" || network == "tcp6" { 645 addr, err := net.ResolveTCPAddr(network, "localhost:0") 646 if err != nil { 647 return -1, err 648 } 649 650 l, err := net.ListenTCP(network, addr) 651 if err != nil { 652 return -1, err 653 } 654 defer l.Close() 655 656 return l.Addr().(*net.TCPAddr).Port, nil 657 } 658 659 if network == "udp" || network == "udp4" || network == "udp6" { 660 addr, err := net.ResolveUDPAddr(network, "localhost:0") 661 if err != nil { 662 return -1, err 663 } 664 665 l, err := net.ListenUDP(network, addr) 666 if err != nil { 667 return -1, err 668 } 669 defer l.Close() 670 671 return l.LocalAddr().(*net.UDPAddr).Port, nil 672 } 673 674 return -1, errors.New("invalid network") 675 } 676 677 func TestGetFreePort(t *testing.T) { 678 for i := 0; i < 10; i++ { 679 p, err := getFreePort("tcp") 680 assert.Nil(t, err) 681 assert.NotEqual(t, p, -1) 682 t.Logf("get freeport network:%s, port:%d", "tcp", p) 683 } 684 685 for i := 0; i < 10; i++ { 686 p, err := getFreePort("udp") 687 assert.Nil(t, err) 688 assert.NotEqual(t, p, -1) 689 t.Logf("get freeport network:%s, port:%d", "udp", p) 690 } 691 692 p1, err := getFreePort("tcp") 693 assert.Nil(t, err) 694 695 p2, err := getFreePort("tcp") 696 assert.Nil(t, err) 697 assert.NotEqual(t, p1, p2, "allocated 2 conflict ports") 698 } 699 700 func getFreeAddr(network string) string { 701 p, err := getFreePort(network) 702 if err != nil { 703 panic(err) 704 } 705 706 return fmt.Sprintf(":%d", p) 707 } 708 709 func TestTCPWriteToClosedConn(t *testing.T) { 710 l, err := net.Listen("tcp4", "localhost:0") 711 require.Nil(t, err) 712 defer l.Close() 713 714 var wg sync.WaitGroup 715 wg.Add(1) 716 go func() { 717 defer wg.Done() 718 st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute)) 719 err := st.ListenAndServe(context.Background(), 720 transport.WithListener(l), 721 transport.WithHandler(&echoHandler{}), 722 transport.WithServerFramerBuilder(&framerBuilder{}), 723 transport.WithServerAsync(true), 724 ) 725 assert.Nil(t, err) 726 }() 727 wg.Wait() 728 conn, err := net.Dial("tcp4", l.Addr().String()) 729 require.Nil(t, err) 730 require.Nil(t, conn.Close()) 731 _, err = conn.Write([]byte("data")) 732 require.Contains(t, errs.Msg(err), "use of closed network connection") 733 } 734 735 func TestTCPServerHandleErrAndClose(t *testing.T) { 736 var addr = getFreeAddr("tcp4") 737 738 wg := sync.WaitGroup{} 739 wg.Add(1) 740 go func() { 741 defer wg.Done() 742 st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute)) 743 err := st.ListenAndServe(context.Background(), 744 transport.WithListenNetwork("tcp4"), 745 transport.WithListenAddress(addr), 746 transport.WithHandler(&errorHandler{}), 747 transport.WithServerFramerBuilder(&framerBuilder{}), 748 transport.WithServerAsync(true), 749 ) 750 assert.Nil(t, err) 751 }() 752 wg.Wait() 753 754 // First time dial, should work. 755 conn, err := net.Dial("tcp", addr) 756 assert.Nil(t, err) 757 time.Sleep(time.Millisecond * 5) 758 data := []byte("hello world") 759 req := make([]byte, 4) 760 binary.BigEndian.PutUint32(req, uint32(len(data))) 761 req = append(req, data...) 762 _, err = conn.Write(req) 763 assert.Nil(t, err) 764 765 // Check the connection is closed by server. 766 time.Sleep(time.Millisecond * 5) 767 out := make([]byte, 8) 768 _, err = conn.Read(out) 769 assert.NotNil(t, err) 770 } 771 772 // TestTCPListenAndServeWithSafeFramer tests that we support safe framer without copying packages. 773 func TestUDPListenAndServeWithSafeFramer(t *testing.T) { 774 var addr = getFreeAddr("udp") 775 776 // Wait until server transport is ready. 777 wg := sync.WaitGroup{} 778 wg.Add(1) 779 go func() { 780 defer wg.Done() 781 err := transport.ListenAndServe( 782 transport.WithListenNetwork("udp"), 783 transport.WithListenAddress(addr), 784 transport.WithHandler(&echoHandler{}), 785 transport.WithServerFramerBuilder(&framerBuilder{safe: true}), 786 ) 787 assert.Nil(t, err) 788 time.Sleep(20 * time.Millisecond) 789 }() 790 wg.Wait() 791 792 req := &helloRequest{ 793 Name: "trpc", 794 Msg: "HelloWorld", 795 } 796 data, err := json.Marshal(req) 797 if err != nil { 798 t.Fatalf("json marshal fail:%v", err) 799 } 800 lenData := make([]byte, 4) 801 binary.BigEndian.PutUint32(lenData, uint32(len(data))) 802 reqData := append(lenData, data...) 803 ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond) 804 defer f() 805 806 rspData, err := transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("udp"), 807 transport.WithDialAddress(addr), 808 transport.WithClientFramerBuilder(&framerBuilder{safe: true})) 809 assert.Nil(t, err) 810 811 length := binary.BigEndian.Uint32(rspData[:4]) 812 helloRsp := &helloResponse{} 813 err = json.Unmarshal(rspData[4:4+length], helloRsp) 814 assert.Nil(t, err) 815 assert.Equal(t, helloRsp.Msg, "HelloWorld") 816 } 817 818 // TestTCPListenAndServeWithSafeFramer tests that frame is not copied when Framer is already safe. 819 func TestTCPListenAndServeWithSafeFramer(t *testing.T) { 820 var addr = getFreeAddr("tcp4") 821 822 wg := sync.WaitGroup{} 823 wg.Add(1) 824 go func() { 825 defer wg.Done() 826 st := transport.NewServerTransport(transport.WithKeepAlivePeriod(time.Minute)) 827 err := st.ListenAndServe(context.Background(), 828 transport.WithListenNetwork("tcp4"), 829 transport.WithListenAddress(addr), 830 transport.WithHandler(&echoHandler{}), 831 transport.WithServerFramerBuilder(&framerBuilder{safe: true}), 832 transport.WithServerAsync(true), 833 ) 834 assert.Nil(t, err) 835 time.Sleep(20 * time.Millisecond) 836 }() 837 wg.Wait() 838 839 req := &helloRequest{ 840 Name: "trpc", 841 Msg: "HelloWorld", 842 } 843 data, err := json.Marshal(req) 844 if err != nil { 845 t.Fatalf("json marshal fail:%v", err) 846 } 847 lenData := make([]byte, 4) 848 binary.BigEndian.PutUint32(lenData, uint32(len(data))) 849 reqData := append(lenData, data...) 850 ctx, f := context.WithTimeout(context.Background(), 20*time.Millisecond) 851 defer f() 852 853 rspData, err := transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("tcp4"), 854 transport.WithDialAddress(addr), 855 transport.WithClientFramerBuilder(&framerBuilder{safe: true})) 856 assert.Nil(t, err) 857 858 length := binary.BigEndian.Uint32(rspData[:4]) 859 helloRsp := &helloResponse{} 860 err = json.Unmarshal(rspData[4:4+length], helloRsp) 861 assert.Nil(t, err) 862 assert.Equal(t, helloRsp.Msg, "HelloWorld") 863 } 864 865 func TestWithDisableKeepAlives(t *testing.T) { 866 disable := true 867 o := transport.WithDisableKeepAlives(true) 868 opts := &transport.ListenServeOptions{} 869 o(opts) 870 assert.Equal(t, disable, opts.DisableKeepAlives) 871 } 872 873 func TestWithServerIdleTimeout(t *testing.T) { 874 idleTimeout := time.Second 875 o := transport.WithServerIdleTimeout(idleTimeout) 876 opts := &transport.ListenServeOptions{} 877 o(opts) 878 assert.Equal(t, opts.IdleTimeout, idleTimeout) 879 } 880 881 func TestUDPServeClose(t *testing.T) { 882 ts := transport.NewServerTransport() 883 ctx, cancel := context.WithCancel(context.Background()) 884 cancel() 885 err := ts.ListenAndServe( 886 ctx, 887 transport.WithListenNetwork("udp"), 888 transport.WithListenAddress(getFreeAddr("udp")), 889 transport.WithHandler(&echoHandler{}), 890 transport.WithServerFramerBuilder(&framerBuilder{safe: true}), 891 transport.WithServerAsync(true), 892 ) 893 assert.Nil(t, err) 894 time.Sleep(100 * time.Millisecond) 895 } 896 897 type MockUDPError struct{} 898 899 func (e MockUDPError) Error() string { return "mock udp error" } 900 func (e MockUDPError) Timeout() bool { return false } 901 func (e MockUDPError) Temporary() bool { return true } 902 903 func TestUDPReadError(t *testing.T) { 904 addr := getFreeAddr("udp") 905 906 err := transport.ListenAndServe( 907 transport.WithListenNetwork("udp"), 908 transport.WithListenAddress(addr), 909 transport.WithHandler(&echoHandler{}), 910 transport.WithServerFramerBuilder(&framerBuilder{safe: true}), 911 transport.WithServerAsync(false), 912 ) 913 assert.Nil(t, err) 914 time.Sleep(60 * time.Millisecond) 915 } 916 917 func TestUDPWriteError(t *testing.T) { 918 addr := getFreeAddr("udp") 919 920 err := transport.ListenAndServe( 921 transport.WithListenNetwork("udp"), 922 transport.WithListenAddress(addr), 923 transport.WithHandler(&echoHandler{}), 924 transport.WithServerFramerBuilder(&framerBuilder{safe: true}), 925 transport.WithServerAsync(false), 926 ) 927 assert.Nil(t, err) 928 time.Sleep(20 * time.Millisecond) 929 930 req := &helloRequest{ 931 Name: "trpc", 932 Msg: "HelloWorld", 933 } 934 data, err := json.Marshal(req) 935 if err != nil { 936 t.Fatalf("json marshal fail:%v", err) 937 } 938 lenData := make([]byte, 4) 939 binary.BigEndian.PutUint32(lenData, uint32(len(data))) 940 reqData := append(lenData, data...) 941 ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) 942 defer cancel() 943 _, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("udp"), 944 transport.WithDialAddress(addr), 945 transport.WithClientFramerBuilder(&framerBuilder{safe: true})) 946 assert.Nil(t, err) 947 } 948 949 func TestPoolInvokeFail(t *testing.T) { 950 951 addr := getFreeAddr("udp") 952 953 err := transport.ListenAndServe( 954 transport.WithListenNetwork("udp"), 955 transport.WithListenAddress(addr), 956 transport.WithHandler(&echoHandler{}), 957 transport.WithServerFramerBuilder(&framerBuilder{safe: true}), 958 transport.WithServerAsync(true), 959 ) 960 assert.Nil(t, err) 961 time.Sleep(20 * time.Millisecond) 962 963 req := &helloRequest{ 964 Name: "trpc", 965 Msg: "HelloWorld", 966 } 967 data, err := json.Marshal(req) 968 if err != nil { 969 t.Fatalf("json marshal fail:%v", err) 970 } 971 lenData := make([]byte, 4) 972 binary.BigEndian.PutUint32(lenData, uint32(len(data))) 973 reqData := append(lenData, data...) 974 ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) 975 defer cancel() 976 _, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("udp"), 977 transport.WithDialAddress(addr), 978 transport.WithClientFramerBuilder(&framerBuilder{safe: true})) 979 assert.Nil(t, err) 980 } 981 982 func TestCreatePoolFail(t *testing.T) { 983 addr := getFreeAddr("udp") 984 985 err := transport.ListenAndServe( 986 transport.WithListenNetwork("udp"), 987 transport.WithListenAddress(addr), 988 transport.WithHandler(&echoHandler{}), 989 transport.WithServerFramerBuilder(&framerBuilder{safe: true}), 990 transport.WithServerAsync(true), 991 ) 992 assert.Nil(t, err) 993 994 req := &helloRequest{ 995 Name: "trpc", 996 Msg: "HelloWorld", 997 } 998 data, err := json.Marshal(req) 999 if err != nil { 1000 t.Fatalf("json marshal fail:%v", err) 1001 } 1002 lenData := make([]byte, 4) 1003 binary.BigEndian.PutUint32(lenData, uint32(len(data))) 1004 reqData := append(lenData, data...) 1005 ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) 1006 defer cancel() 1007 _, err = transport.RoundTrip(ctx, reqData, transport.WithDialNetwork("udp"), 1008 transport.WithDialAddress(addr), 1009 transport.WithClientFramerBuilder(&framerBuilder{safe: true})) 1010 assert.Nil(t, err) 1011 } 1012 1013 func TestListenAndServeTLSFail(t *testing.T) { 1014 s := transport.NewServerTransport() 1015 ctx, cancel := context.WithCancel(context.Background()) 1016 defer cancel() 1017 ln, err := net.Listen("tcp", "127.0.0.1:0") 1018 require.Nil(t, err) 1019 defer ln.Close() 1020 require.NotNil(t, s.ListenAndServe(ctx, 1021 transport.WithListenNetwork("tcp"), 1022 transport.WithServeTLS("fakeCertFileName", "fakeKeyFileName", "fakeCAFileName"), 1023 transport.WithServerFramerBuilder(&framerBuilder{}), 1024 transport.WithListener(ln), 1025 )) 1026 } 1027 1028 func TestListenAndServeWithStopListener(t *testing.T) { 1029 s := transport.NewServerTransport() 1030 ctx, cancel := context.WithCancel(context.Background()) 1031 defer cancel() 1032 ln, err := net.Listen("tcp", "127.0.0.1:0") 1033 require.Nil(t, err) 1034 ch := make(chan struct{}) 1035 require.Nil(t, s.ListenAndServe(ctx, 1036 transport.WithListenNetwork("tcp"), 1037 transport.WithServerFramerBuilder(&framerBuilder{}), 1038 transport.WithListener(ln), 1039 transport.WithStopListening(ch), 1040 )) 1041 _, err = net.Dial("tcp", ln.Addr().String()) 1042 require.Nil(t, err) 1043 close(ch) 1044 time.Sleep(time.Millisecond) 1045 _, err = net.Dial("tcp", ln.Addr().String()) 1046 require.NotNil(t, err) 1047 }