trpc.group/trpc-go/trpc-go@v1.0.3/transport/client_transport_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package transport_test 15 16 import ( 17 "context" 18 "errors" 19 "fmt" 20 "io" 21 "math" 22 "net" 23 "strings" 24 "testing" 25 "time" 26 27 "trpc.group/trpc-go/trpc-go/codec" 28 "trpc.group/trpc-go/trpc-go/errs" 29 "trpc.group/trpc-go/trpc-go/pool/connpool" 30 "trpc.group/trpc-go/trpc-go/pool/multiplexed" 31 "trpc.group/trpc-go/trpc-go/transport" 32 33 "github.com/stretchr/testify/assert" 34 "github.com/stretchr/testify/require" 35 36 trpc "trpc.group/trpc-go/trpc-go" 37 ) 38 39 func TestTcpRoundTripPoolNIl(t *testing.T) { 40 st := transport.NewClientTransport() 41 optNetwork := transport.WithDialNetwork("tcp") 42 optPool := transport.WithDialPool(nil) 43 _, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool) 44 assert.NotNil(t, err) 45 } 46 47 func TestTcpRoundTripTCPErr(t *testing.T) { 48 st := transport.NewClientTransport() 49 optNetwork := transport.WithDialNetwork("tcp") 50 pool := connpool.NewConnectionPool() 51 optPool := transport.WithDialPool(pool) 52 fb := &trpc.FramerBuilder{} 53 optFramerBuilder := transport.WithClientFramerBuilder(fb) 54 optDisabled := transport.WithDisableConnectionPool() 55 newCtx := context.Background() 56 newCtx.Done() 57 newCtx.Deadline() 58 _, err := st.RoundTrip(newCtx, []byte("hello"), optNetwork, optPool, optFramerBuilder, optDisabled) 59 assert.NotNil(t, err) 60 } 61 62 func TestTcpRoundTripCTXErr(t *testing.T) { 63 st := transport.NewClientTransport() 64 optNetwork := transport.WithDialNetwork("tcp") 65 pool := connpool.NewConnectionPool() 66 optPool := transport.WithDialPool(pool) 67 fb := &trpc.FramerBuilder{} 68 optFramerBuilder := transport.WithClientFramerBuilder(fb) 69 _, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder) 70 assert.NotNil(t, err) 71 } 72 73 type fakePool struct { 74 } 75 76 func (p *fakePool) Get(network string, address string, opts connpool.GetOptions) (net.Conn, error) { 77 return &fakeConn{}, nil 78 } 79 80 type fakeConn struct { 81 } 82 83 func (c *fakeConn) Close() error { 84 return nil 85 } 86 87 func (c *fakeConn) Read(b []byte) (n int, err error) { 88 return 0, nil 89 } 90 91 type netError struct { 92 error 93 } 94 95 // Timeout() bool 96 // Temporary() bool 97 func (c *netError) Timeout() bool { 98 return true 99 } 100 func (c *netError) Temporary() bool { 101 return true 102 } 103 104 func (c *fakeConn) Write(b []byte) (n int, err error) { 105 if Count == 1 { 106 return 0, errors.New("write failure") 107 } 108 if Count == 2 { 109 return 0, netError{errors.New("net failure")} 110 } 111 return len(b), nil 112 } 113 114 func (c *fakeConn) LocalAddr() net.Addr { 115 return nil 116 } 117 118 func (c *fakeConn) RemoteAddr() net.Addr { 119 return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8888} 120 } 121 122 func (c *fakeConn) SetDeadline(t time.Time) error { 123 return nil 124 } 125 126 func (c *fakeConn) SetReadDeadline(t time.Time) error { 127 return nil 128 } 129 130 func (c *fakeConn) SetWriteDeadline(t time.Time) error { 131 return nil 132 } 133 134 func TestTcpRoundTripReadFrameNil(t *testing.T) { 135 st := transport.NewClientTransport() 136 optNetwork := transport.WithDialNetwork("tcp") 137 optPool := transport.WithDialPool(&fakePool{}) 138 fb := &trpc.FramerBuilder{} 139 optFramerBuilder := transport.WithClientFramerBuilder(fb) 140 optReqType := transport.WithReqType(transport.SendOnly) 141 optAddress := transport.WithDialAddress(":8888") 142 _, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder, 143 optReqType, optAddress) 144 assert.NotNil(t, err) 145 } 146 147 func TestTCPRoundTripSetRemoteAddr(t *testing.T) { 148 st := transport.NewClientTransport() 149 optNetwork := transport.WithDialNetwork("tcp") 150 optPool := transport.WithDialPool(&fakePool{}) 151 fb := &trpc.FramerBuilder{} 152 optFramerBuilder := transport.WithClientFramerBuilder(fb) 153 optAddress := transport.WithDialAddress("127.0.0.1:8888") 154 ctx, msg := codec.WithNewMessage(context.Background()) 155 _, _ = st.RoundTrip(ctx, []byte("hello"), optNetwork, optPool, optFramerBuilder, optAddress) 156 assert.NotNil(t, msg.RemoteAddr()) 157 assert.Equal(t, "127.0.0.1:8888", msg.RemoteAddr().String()) 158 } 159 160 type newCtx struct { 161 } 162 163 var Count int64 164 165 func (c *newCtx) Deadline() (deadline time.Time, ok bool) { 166 deadline = time.Now() 167 return deadline, true 168 } 169 func (c *newCtx) Done() <-chan struct{} { 170 return nil 171 } 172 func (c *newCtx) Err() error { 173 if Count == 1 { 174 return context.DeadlineExceeded 175 } 176 return context.Canceled 177 } 178 func (c *newCtx) Value(key interface{}) interface{} { 179 return nil 180 } 181 182 func TestTcpRoundTripCanceled(t *testing.T) { 183 st := transport.NewClientTransport() 184 optNetwork := transport.WithDialNetwork("tcp") 185 optPool := transport.WithDialPool(&fakePool{}) 186 fb := &trpc.FramerBuilder{} 187 optFramerBuilder := transport.WithClientFramerBuilder(fb) 188 optAddress := transport.WithDialAddress(":8888") 189 _, err := st.RoundTrip(&newCtx{}, []byte("hello"), optNetwork, optPool, optFramerBuilder, 190 optAddress) 191 assert.NotNil(t, err) 192 } 193 194 func TestTcpRoundTripTimeout(t *testing.T) { 195 st := transport.NewClientTransport() 196 optNetwork := transport.WithDialNetwork("tcp") 197 optPool := transport.WithDialPool(&fakePool{}) 198 fb := &trpc.FramerBuilder{} 199 optFramerBuilder := transport.WithClientFramerBuilder(fb) 200 optAddress := transport.WithDialAddress(":8888") 201 Count = 1 202 _, err := st.RoundTrip(&newCtx{}, []byte("hello"), optNetwork, optPool, optFramerBuilder, 203 optAddress) 204 assert.NotNil(t, err) 205 } 206 207 func TestTcpRoundTripConnWriteErr(t *testing.T) { 208 st := transport.NewClientTransport() 209 optNetwork := transport.WithDialNetwork("tcp") 210 optPool := transport.WithDialPool(&fakePool{}) 211 fb := &trpc.FramerBuilder{} 212 optFramerBuilder := transport.WithClientFramerBuilder(fb) 213 optAddress := transport.WithDialAddress(":8888") 214 Count = 1 215 _, err := st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder, 216 optAddress) 217 assert.NotNil(t, err) 218 Count = 2 219 _, err = st.RoundTrip(context.Background(), []byte("hello"), optNetwork, optPool, optFramerBuilder, 220 optAddress) 221 assert.NotNil(t, err) 222 } 223 224 type NewPacketConn struct { 225 } 226 227 func (c *NewPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 228 return 0, nil, nil 229 } 230 func (c *NewPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 231 if Count == 1 { 232 return len(p), errors.New("write failure") 233 } 234 return len(p), netError{errors.New("net failure")} 235 } 236 func (c *NewPacketConn) Close() error { 237 return nil 238 } 239 func (c *NewPacketConn) LocalAddr() net.Addr { 240 return nil 241 } 242 func (c *NewPacketConn) SetDeadline(t time.Time) error { 243 return nil 244 } 245 func (c *NewPacketConn) SetReadDeadline(t time.Time) error { 246 return nil 247 } 248 func (c *NewPacketConn) SetWriteDeadline(t time.Time) error { 249 return nil 250 } 251 func (c *NewPacketConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { 252 return len(b), nil, netError{errors.New("net failure")} 253 } 254 255 func TestNewClientTransport(t *testing.T) { 256 st := transport.NewClientTransport() 257 assert.NotNil(t, st) 258 } 259 260 func TestWithDialPool(t *testing.T) { 261 opt := transport.WithDialPool(nil) 262 opts := &transport.RoundTripOptions{} 263 opt(opts) 264 assert.Equal(t, nil, opts.Pool) 265 } 266 267 func TestWithReqType(t *testing.T) { 268 opt := transport.WithReqType(transport.SendOnly) 269 opts := &transport.RoundTripOptions{} 270 opt(opts) 271 assert.Equal(t, transport.SendOnly, opts.ReqType) 272 } 273 274 type emptyPool struct { 275 } 276 277 func (p *emptyPool) Get(network string, address string, opts connpool.GetOptions) (net.Conn, error) { 278 return nil, errors.New("empty") 279 } 280 281 var testReqByte = []byte{'a', 'b'} 282 283 func TestWithDialPoolError(t *testing.T) { 284 ctx, f := context.WithTimeout(context.Background(), 3*time.Second) 285 defer f() 286 _, err := transport.RoundTrip(ctx, testReqByte, 287 transport.WithDialPool(&emptyPool{}), 288 transport.WithDialNetwork("tcp")) 289 // fmt.Printf("err: %v", err) 290 assert.NotNil(t, err) 291 } 292 293 func TestContextTimeout(t *testing.T) { 294 ctx, f := context.WithTimeout(context.Background(), time.Millisecond) 295 defer f() 296 <-ctx.Done() 297 fb := &trpc.FramerBuilder{} 298 _, err := transport.RoundTrip(ctx, testReqByte, 299 transport.WithDialNetwork("tcp"), 300 transport.WithDialAddress(":8888"), 301 transport.WithClientFramerBuilder(fb)) 302 assert.NotNil(t, err) 303 } 304 305 func TestContextTimeout_Multiplexed(t *testing.T) { 306 ctx, f := context.WithTimeout(context.Background(), time.Millisecond) 307 defer f() 308 <-ctx.Done() 309 fb := &trpc.FramerBuilder{} 310 _, err := transport.RoundTrip(ctx, testReqByte, 311 transport.WithDialNetwork("tcp"), 312 transport.WithDialAddress(":8888"), 313 transport.WithMultiplexed(true), 314 transport.WithMsg(codec.Message(ctx)), 315 transport.WithClientFramerBuilder(fb)) 316 assert.NotNil(t, err) 317 } 318 319 func TestContextCancel(t *testing.T) { 320 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 321 cancel() 322 fb := &trpc.FramerBuilder{} 323 _, err := transport.RoundTrip(ctx, testReqByte, 324 transport.WithDialNetwork("tcp"), 325 transport.WithDialAddress(":8888"), 326 transport.WithClientFramerBuilder(fb)) 327 assert.NotNil(t, err) 328 } 329 330 func TestWithReqTypeSendOnly(t *testing.T) { 331 ctx, f := context.WithTimeout(context.Background(), 3*time.Second) 332 defer f() 333 _, err := transport.RoundTrip(ctx, []byte{}, 334 transport.WithReqType(transport.SendOnly), 335 transport.WithDialNetwork("tcp")) 336 // fmt.Printf("err: %v", err) 337 assert.NotNil(t, err) 338 } 339 340 func TestClientTransport_RoundTrip(t *testing.T) { 341 fb := &lengthDelimitedBuilder{} 342 go func() { 343 err := transport.ListenAndServe( 344 transport.WithListenNetwork("udp"), 345 transport.WithListenAddress("localhost:9998"), 346 transport.WithHandler(&lengthDelimitedHandler{}), 347 transport.WithServerFramerBuilder(fb), 348 ) 349 assert.Nil(t, err) 350 }() 351 time.Sleep(20 * time.Millisecond) 352 353 t.Run("write: message too long", func(t *testing.T) { 354 c := mustListenUDP(t) 355 t.Cleanup(func() { 356 if err := c.Close(); err != nil { 357 t.Log(err) 358 } 359 }) 360 largeRequest := encodeLengthDelimited(strings.Repeat("1", math.MaxInt32/4)) 361 _, err := transport.RoundTrip(context.Background(), largeRequest, 362 transport.WithClientFramerBuilder(fb), 363 transport.WithDialNetwork("udp"), 364 transport.WithDialAddress(c.LocalAddr().String()), 365 transport.WithReqType(transport.SendAndRecv), 366 ) 367 require.Equal(t, errs.RetClientNetErr, errs.Code(err)) 368 require.Contains(t, errs.Msg(err), "udp client transport WriteTo") 369 }) 370 371 var err error 372 _, err = transport.RoundTrip(context.Background(), encodeLengthDelimited("helloworld")) 373 assert.NotNil(t, err) 374 375 tc := transport.NewClientTransport() 376 _, err = tc.RoundTrip(context.Background(), encodeLengthDelimited("helloworld")) 377 assert.NotNil(t, err) 378 379 // Test address invalid. 380 _, err = tc.RoundTrip(context.Background(), encodeLengthDelimited("helloworld"), 381 transport.WithDialNetwork("udp"), 382 transport.WithDialAddress("invalidaddress"), 383 transport.WithReqType(transport.SendOnly)) 384 assert.NotNil(t, err) 385 386 // Test send only. 387 rsp, err := tc.RoundTrip(context.Background(), encodeLengthDelimited("helloworld"), 388 transport.WithDialNetwork("udp"), 389 transport.WithDialAddress("localhost:9998"), 390 transport.WithClientFramerBuilder(fb), 391 transport.WithReqType(transport.SendOnly), 392 transport.WithConnectionMode(transport.NotConnected)) 393 assert.NotNil(t, err) 394 assert.Equal(t, errs.ErrClientNoResponse, err) 395 assert.Nil(t, rsp) 396 397 // Test multiplexed send only. 398 ctx, msg := codec.WithNewMessage(context.Background()) 399 rsp, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"), 400 transport.WithDialNetwork("udp"), 401 transport.WithMultiplexed(true), 402 transport.WithDialAddress("localhost:9998"), 403 transport.WithReqType(transport.SendOnly), 404 transport.WithClientFramerBuilder(fb), 405 transport.WithMsg(msg), 406 ) 407 assert.NotNil(t, err) 408 assert.Equal(t, errs.ErrClientNoResponse, err) 409 assert.Nil(t, rsp) 410 411 // Test context canceled. 412 ctx, cancel := context.WithCancel(context.Background()) 413 cancel() 414 _, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"), 415 transport.WithDialNetwork("udp"), 416 transport.WithClientFramerBuilder(fb), 417 transport.WithDialAddress("localhost:9998")) 418 assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientCanceled)) 419 420 // Test context timeout. 421 ctx, timeout := context.WithTimeout(context.Background(), time.Millisecond) 422 defer timeout() 423 <-ctx.Done() 424 _, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"), 425 transport.WithDialNetwork("udp"), 426 transport.WithClientFramerBuilder(fb), 427 transport.WithDialAddress("localhost:9998")) 428 assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientTimeout)) 429 430 // Test roundtrip. 431 ctx, cancel = context.WithTimeout(context.Background(), time.Second) 432 defer cancel() 433 rsp, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"), 434 transport.WithDialNetwork("udp"), 435 transport.WithDialAddress("localhost:9998"), 436 transport.WithConnectionMode(transport.NotConnected), 437 transport.WithClientFramerBuilder(fb), 438 ) 439 assert.NotNil(t, rsp) 440 assert.Nil(t, err) 441 442 // Test setting RemoteAddr of UDP RoundTrip. 443 ctx, cancel = context.WithTimeout(context.Background(), time.Second) 444 defer cancel() 445 ctx, msg = codec.WithNewMessage(ctx) 446 _, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"), 447 transport.WithDialNetwork("udp"), 448 transport.WithDialAddress("127.0.0.1:9998"), 449 transport.WithConnectionMode(transport.Connected), 450 transport.WithClientFramerBuilder(fb), 451 ) 452 assert.Nil(t, err) 453 assert.Equal(t, "127.0.0.1:9998", msg.RemoteAddr().String()) 454 455 // Test local addr. 456 localAddr := "127.0.0.1:" 457 ctx, cancel = context.WithTimeout(context.Background(), time.Second) 458 defer cancel() 459 ctx, msg = codec.WithNewMessage(ctx) 460 _, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"), 461 transport.WithDialNetwork("udp"), 462 transport.WithDialAddress("127.0.0.1:9998"), 463 transport.WithConnectionMode(transport.Connected), 464 transport.WithClientFramerBuilder(fb), 465 transport.WithLocalAddr(localAddr), 466 ) 467 assert.Nil(t, err) 468 assert.Equal(t, "127.0.0.1", msg.LocalAddr().(*net.UDPAddr).IP.String()) 469 470 // Test local addr error. 471 localAddr = "invalid address" 472 ctx, cancel = context.WithTimeout(context.Background(), time.Second) 473 defer cancel() 474 ctx, msg = codec.WithNewMessage(ctx) 475 _, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"), 476 transport.WithDialNetwork("udp"), 477 transport.WithDialAddress("127.0.0.1:9998"), 478 transport.WithConnectionMode(transport.Connected), 479 transport.WithClientFramerBuilder(fb), 480 transport.WithLocalAddr(localAddr), 481 ) 482 assert.NotNil(t, err) 483 assert.Nil(t, msg.LocalAddr()) 484 485 // Test readframer error. 486 ctx, cancel = context.WithTimeout(context.Background(), time.Second) 487 defer cancel() 488 _, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"), 489 transport.WithDialNetwork("udp"), 490 transport.WithDialAddress("127.0.0.1:9998"), 491 transport.WithConnectionMode(transport.Connected), 492 transport.WithClientFramerBuilder(&lengthDelimitedBuilder{ 493 readError: true, 494 }), 495 ) 496 assert.Contains(t, err.Error(), readFrameError.Error()) 497 498 // Test readframe bytes remaining error. 499 ctx, cancel = context.WithTimeout(context.Background(), time.Second) 500 defer cancel() 501 _, err = tc.RoundTrip(ctx, encodeLengthDelimited("helloworld"), 502 transport.WithDialNetwork("udp"), 503 transport.WithDialAddress("127.0.0.1:9998"), 504 transport.WithConnectionMode(transport.Connected), 505 transport.WithClientFramerBuilder(&lengthDelimitedBuilder{ 506 remainingBytes: true, 507 }), 508 ) 509 assert.Contains(t, err.Error(), remainingBytesError.Error()) 510 } 511 512 func mustListenUDP(t *testing.T) net.PacketConn { 513 c, err := net.ListenPacket("udp", "127.0.0.1:0") 514 if err != nil { 515 t.Fatal(err) 516 } 517 return c 518 } 519 520 // Frame a stream of bytes based on a length prefix 521 // +------------+--------------------------------+ 522 // | len: uint8 | frame payload | 523 // +------------+--------------------------------+ 524 type lengthDelimitedBuilder struct { 525 remainingBytes bool 526 readError bool 527 } 528 529 func (fb *lengthDelimitedBuilder) New(reader io.Reader) codec.Framer { 530 return &lengthDelimited{ 531 readError: fb.readError, 532 remainingBytes: fb.remainingBytes, 533 reader: reader, 534 } 535 } 536 537 func (fb *lengthDelimitedBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) { 538 buf, err = fb.New(rc).ReadFrame() 539 if err != nil { 540 return 0, nil, err 541 } 542 return 0, buf, nil 543 } 544 545 type lengthDelimited struct { 546 reader io.Reader 547 readError bool 548 remainingBytes bool 549 } 550 551 func encodeLengthDelimited(data string) []byte { 552 result := []byte{byte(len(data))} 553 result = append(result, []byte(data)...) 554 return result 555 } 556 557 var ( 558 readFrameError = errors.New("read framer error") 559 remainingBytesError = fmt.Errorf( 560 "packet data is not drained, the remaining %d will be dropped", 561 remainingBytes, 562 ) 563 remainingBytes = 1 564 ) 565 566 func (f *lengthDelimited) ReadFrame() ([]byte, error) { 567 if f.readError { 568 return nil, readFrameError 569 } 570 head := make([]byte, 1) 571 if _, err := io.ReadFull(f.reader, head); err != nil { 572 return nil, err 573 } 574 bodyLen := int(head[0]) 575 if f.remainingBytes { 576 bodyLen = bodyLen - remainingBytes 577 } 578 body := make([]byte, bodyLen) 579 if _, err := io.ReadFull(f.reader, body); err != nil { 580 return nil, err 581 } 582 return body, nil 583 } 584 585 type lengthDelimitedHandler struct{} 586 587 func (h *lengthDelimitedHandler) Handle(ctx context.Context, req []byte) ([]byte, error) { 588 rsp := make([]byte, len(req)+1) 589 rsp[0] = byte(len(req)) 590 copy(rsp[1:], req) 591 return rsp, nil 592 } 593 594 func TestClientTransport_MultiplexedErr(t *testing.T) { 595 listener, err := net.Listen("tcp", ":") 596 require.Nil(t, err) 597 defer listener.Close() 598 go func() { 599 transport.ListenAndServe( 600 transport.WithListener(listener), 601 transport.WithHandler(&echoHandler{}), 602 transport.WithServerFramerBuilder(transport.GetFramerBuilder("trpc")), 603 ) 604 }() 605 time.Sleep(20 * time.Millisecond) 606 607 tc := transport.NewClientTransport() 608 fb := &trpc.FramerBuilder{} 609 610 // Test multiplexed context timeout. 611 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 612 defer cancel() 613 ctx, msg := codec.WithNewMessage(ctx) 614 _, err = tc.RoundTrip(ctx, []byte("helloworld"), 615 transport.WithDialNetwork(listener.Addr().Network()), 616 transport.WithDialAddress(listener.Addr().String()), 617 transport.WithMultiplexed(true), 618 transport.WithClientFramerBuilder(fb), 619 transport.WithMsg(msg), 620 ) 621 assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientTimeout)) 622 623 // Test multiplexed context canceled. 624 ctx, cancel = context.WithTimeout(context.Background(), time.Second) 625 go func() { 626 time.Sleep(time.Millisecond * 200) 627 cancel() 628 }() 629 _, err = tc.RoundTrip(ctx, []byte("helloworld"), 630 transport.WithDialNetwork(listener.Addr().Network()), 631 transport.WithDialAddress(listener.Addr().String()), 632 transport.WithMultiplexed(true), 633 transport.WithClientFramerBuilder(fb), 634 transport.WithMsg(msg), 635 ) 636 assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientCanceled)) 637 } 638 639 func TestClientTransport_RoundTrip_PreConnected(t *testing.T) { 640 go func() { 641 err := transport.ListenAndServe( 642 transport.WithListenNetwork("udp"), 643 transport.WithListenAddress("localhost:9999"), 644 transport.WithHandler(&echoHandler{}), 645 transport.WithServerFramerBuilder(transport.GetFramerBuilder("trpc")), 646 ) 647 assert.Nil(t, err) 648 }() 649 time.Sleep(20 * time.Millisecond) 650 651 var err error 652 _, err = transport.RoundTrip(context.Background(), []byte("helloworld")) 653 assert.NotNil(t, err) 654 655 tc := transport.NewClientTransport() 656 657 // Test connected UDPConn. 658 rsp, err := tc.RoundTrip(context.Background(), []byte("helloworld"), 659 transport.WithDialNetwork("udp"), 660 transport.WithDialAddress("localhost:9999"), 661 transport.WithDialPassword("passwd"), 662 transport.WithClientFramerBuilder(&trpc.FramerBuilder{}), 663 transport.WithReqType(transport.SendOnly), 664 transport.WithConnectionMode(transport.Connected)) 665 assert.NotNil(t, err) 666 assert.Equal(t, errs.ErrClientNoResponse, err) 667 assert.Nil(t, rsp) 668 669 // Test context done. 670 ctx, cancel := context.WithCancel(context.Background()) 671 cancel() 672 _, err = tc.RoundTrip(ctx, []byte("helloworld"), 673 transport.WithDialNetwork("udp"), 674 transport.WithDialAddress("localhost:9999"), 675 transport.WithConnectionMode(transport.Connected)) 676 assert.NotNil(t, err) 677 678 // Test RoundTrip. 679 ctx, cancel = context.WithTimeout(ctx, time.Second) 680 defer cancel() 681 rsp, err = tc.RoundTrip(ctx, []byte("helloworld"), 682 transport.WithDialNetwork("udp"), 683 transport.WithDialAddress("localhost:9999"), 684 transport.WithConnectionMode(transport.Connected)) 685 assert.NotNil(t, err) 686 assert.Nil(t, rsp) 687 } 688 689 func TestOptions(t *testing.T) { 690 691 opts := &transport.RoundTripOptions{} 692 693 o := transport.WithDialTLS("client.cert", "client.key", "ca.pem", "servername") 694 o(opts) 695 assert.Equal(t, "client.cert", opts.TLSCertFile) 696 assert.Equal(t, "client.key", opts.TLSKeyFile) 697 assert.Equal(t, "ca.pem", opts.CACertFile) 698 assert.Equal(t, "servername", opts.TLSServerName) 699 700 o = transport.WithDisableConnectionPool() 701 o(opts) 702 703 assert.True(t, opts.DisableConnectionPool) 704 } 705 706 // TestWithMultiplexedPool tests connection pool multiplexing. 707 func TestWithMultiplexedPool(t *testing.T) { 708 opts := &transport.RoundTripOptions{} 709 m := multiplexed.New(multiplexed.WithConnectNumber(10)) 710 o := transport.WithMultiplexedPool(m) 711 o(opts) 712 assert.True(t, opts.EnableMultiplexed) 713 assert.Equal(t, opts.Multiplexed, m) 714 } 715 716 // TestUDPTransportFramerBuilderErr tests nil FramerBuilder error. 717 func TestUDPTransportFramerBuilderErr(t *testing.T) { 718 opts := []transport.RoundTripOption{ 719 transport.WithDialNetwork("udp"), 720 } 721 ts := transport.NewClientTransport() 722 _, err := ts.RoundTrip(context.Background(), nil, opts...) 723 assert.EqualValues(t, err.(*errs.Error).Code, int32(errs.RetClientConnectFail)) 724 } 725 726 // TestWithLocalAddr tests local addr. 727 func TestWithLocalAddr(t *testing.T) { 728 opts := &transport.RoundTripOptions{} 729 localAddr := "127.0.0.1:8080" 730 o := transport.WithLocalAddr(localAddr) 731 o(opts) 732 assert.Equal(t, opts.LocalAddr, localAddr) 733 } 734 735 func TestWithDialTimeout(t *testing.T) { 736 opts := &transport.RoundTripOptions{} 737 timeout := time.Second 738 o := transport.WithDialTimeout(timeout) 739 o(opts) 740 assert.Equal(t, opts.DialTimeout, timeout) 741 } 742 743 func TestWithProtocol(t *testing.T) { 744 opts := &transport.RoundTripOptions{} 745 protocol := "xxx-protocol" 746 o := transport.WithProtocol(protocol) 747 o(opts) 748 assert.Equal(t, protocol, opts.Protocol) 749 } 750 751 func TestWithDisableEncodeTransInfoBase64(t *testing.T) { 752 opts := &transport.ClientTransportOptions{} 753 transport.WithDisableEncodeTransInfoBase64()(opts) 754 assert.Equal(t, true, opts.DisableHTTPEncodeTransInfoBase64) 755 }