trpc.group/trpc-go/trpc-go@v1.0.3/pool/multiplexed/multiplexed_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package multiplexed 15 16 import ( 17 "bytes" 18 "context" 19 "encoding/binary" 20 "errors" 21 "fmt" 22 "io" 23 "log" 24 "math" 25 "net" 26 "strconv" 27 "sync" 28 "sync/atomic" 29 "testing" 30 "time" 31 32 "golang.org/x/sync/errgroup" 33 "trpc.group/trpc-go/trpc-go/codec" 34 35 "github.com/stretchr/testify/assert" 36 "github.com/stretchr/testify/require" 37 "github.com/stretchr/testify/suite" 38 ) 39 40 func TestMultiplexedSuite(t *testing.T) { 41 suite.Run(t, &msuite{}) 42 } 43 44 type msuite struct { 45 suite.Suite 46 47 network string 48 udpNetwork string 49 address string 50 udpAddr string 51 52 ts *tcpServer 53 us *udpServer 54 55 requestID uint32 56 } 57 58 func (s *msuite) SetupSuite() { 59 s.ts = newTCPServer() 60 s.us = newUDPServer() 61 62 ctx := context.Background() 63 s.ts.start(ctx) 64 s.us.start(ctx) 65 66 s.address = s.ts.ln.Addr().String() 67 s.network = s.ts.ln.Addr().Network() 68 69 s.udpAddr = s.us.conn.LocalAddr().String() 70 s.udpNetwork = s.us.conn.LocalAddr().Network() 71 72 s.requestID = 1 73 } 74 75 func (s *msuite) TearDownSuite() { 76 s.ts.stop() 77 s.us.stop() 78 } 79 80 func (s *msuite) TearDownTest() { 81 // Close all the established tcp concreteConns after each test. 82 s.ts.closeConnections() 83 } 84 85 var errDecodeDelimited = errors.New("decode error") 86 87 type lengthDelimitedFramer struct { 88 IsStream bool 89 reader io.Reader 90 decodeError bool 91 safe bool 92 } 93 94 func (f *lengthDelimitedFramer) New(reader io.Reader) codec.Framer { 95 return &lengthDelimitedFramer{ 96 IsStream: f.IsStream, 97 reader: reader, 98 decodeError: f.decodeError, 99 safe: f.safe, 100 } 101 } 102 103 func (f *lengthDelimitedFramer) ReadFrame() ([]byte, error) { 104 return nil, nil 105 } 106 107 func (f *lengthDelimitedFramer) IsSafe() bool { 108 return f.safe 109 } 110 111 func (f *lengthDelimitedFramer) Parse(rc io.Reader) (vid uint32, buf []byte, err error) { 112 head := make([]byte, 8) 113 num, err := io.ReadFull(rc, head) 114 if err != nil { 115 return 0, nil, err 116 } 117 118 if f.decodeError { 119 return 0, nil, errDecodeDelimited 120 } 121 122 if num != 8 { 123 return 0, nil, errors.New("invalid read full num") 124 } 125 126 n := binary.BigEndian.Uint32(head[:4]) 127 requestID := binary.BigEndian.Uint32(head[4:8]) 128 body := make([]byte, int(n)) 129 130 num, err = io.ReadFull(rc, body) 131 if err != nil { 132 return 0, nil, err 133 } 134 135 if num != int(n) { 136 return 0, nil, errors.New("invalid read full body") 137 } 138 139 if f.IsStream { 140 return requestID, append(head, body...), nil 141 } 142 return requestID, body, nil 143 } 144 145 type delimitedRequest struct { 146 requestID uint32 147 body []byte 148 } 149 150 func (f *lengthDelimitedFramer) Encode(req *delimitedRequest) ([]byte, error) { 151 l := len(req.body) 152 buf := bytes.NewBuffer(make([]byte, 0, 8+l)) 153 if err := binary.Write(buf, binary.BigEndian, uint32(l)); err != nil { 154 return nil, err 155 } 156 if err := binary.Write(buf, binary.BigEndian, req.requestID); err != nil { 157 return nil, err 158 } 159 160 if err := binary.Write(buf, binary.BigEndian, req.body); err != nil { 161 return nil, err 162 } 163 164 return buf.Bytes(), nil 165 } 166 167 func (s *msuite) TestMultiplexedDecodeErr() { 168 tests := []struct { 169 network string 170 address string 171 wantErr error 172 }{ 173 {s.network, s.address, errDecodeDelimited}, 174 {s.udpNetwork, s.udpAddr, context.DeadlineExceeded}, 175 } 176 177 for _, tt := range tests { 178 id := atomic.AddUint32(&s.requestID, 1) 179 ld := &lengthDelimitedFramer{ 180 decodeError: true, 181 } 182 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 183 m := New() 184 opts := NewGetOptions() 185 opts.WithVID(id) 186 opts.WithFrameParser(ld) 187 vc, err := m.GetMuxConn(ctx, tt.network, tt.address, opts) 188 assert.Nil(s.T(), err) 189 body := []byte("hello world") 190 buf, err := ld.Encode(&delimitedRequest{ 191 body: body, 192 requestID: id, 193 }) 194 require.Nil(s.T(), err) 195 require.Nil(s.T(), vc.Write(buf)) 196 _, err = vc.Read() 197 assert.Equal(s.T(), err, tt.wantErr) 198 cancel() 199 } 200 } 201 202 func (s *msuite) TestMultiplexedGetConcurrent() { 203 count := 10 204 ld := &lengthDelimitedFramer{} 205 m := New() 206 tests := []struct { 207 network string 208 address string 209 }{ 210 {s.network, s.address}, 211 {s.udpNetwork, s.udpAddr}, 212 } 213 for _, tt := range tests { 214 wg := sync.WaitGroup{} 215 wg.Add(count) 216 for i := 0; i < count; i++ { 217 go func(i int) { 218 defer wg.Done() 219 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 220 id := atomic.AddUint32(&s.requestID, 1) 221 opts := NewGetOptions() 222 opts.WithVID(id) 223 opts.WithFrameParser(ld) 224 vc, err := m.GetMuxConn(ctx, tt.network, tt.address, opts) 225 assert.Nil(s.T(), err) 226 body := []byte("hello world" + strconv.Itoa(i)) 227 buf, err := ld.Encode(&delimitedRequest{ 228 body: body, 229 requestID: id, 230 }) 231 assert.Nil(s.T(), err) 232 assert.Nil(s.T(), vc.Write(buf)) 233 rsp, err := vc.Read() 234 assert.Nil(s.T(), err) 235 assert.Equal(s.T(), rsp, body) 236 cancel() 237 }(i) 238 } 239 wg.Wait() 240 } 241 } 242 243 func (s *msuite) TestMultiplexedGet() { 244 id := atomic.AddUint32(&s.requestID, 1) 245 ld := &lengthDelimitedFramer{} 246 247 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second) 248 defer cancel() 249 250 m := New(WithConnectNumber(4), WithDropFull(true), WithQueueSize(50000)) 251 opts := NewGetOptions() 252 opts.WithVID(id) 253 opts.WithFrameParser(ld) 254 vc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 255 assert.Nil(s.T(), err) 256 257 body := []byte("hello world") 258 buf, err := ld.Encode(&delimitedRequest{ 259 body: body, 260 requestID: id, 261 }) 262 assert.Nil(s.T(), err) 263 assert.Nil(s.T(), vc.Write(buf)) 264 265 rsp, err := vc.Read() 266 assert.Nil(s.T(), err) 267 assert.Equal(s.T(), rsp, body) 268 } 269 270 func (s *msuite) TestMultiplexedGetWithSafeFramer() { 271 id := atomic.AddUint32(&s.requestID, 1) 272 ld := &lengthDelimitedFramer{safe: true} 273 274 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 275 defer cancel() 276 277 m := New(WithConnectNumber(4), WithDropFull(true), WithQueueSize(50000)) 278 opts := NewGetOptions() 279 opts.WithVID(id) 280 opts.WithFrameParser(ld) 281 vc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 282 assert.Nil(s.T(), err) 283 284 body := []byte("hello world") 285 buf, err := ld.Encode(&delimitedRequest{ 286 body: body, 287 requestID: id, 288 }) 289 assert.Nil(s.T(), err) 290 assert.Nil(s.T(), vc.Write(buf)) 291 292 rsp, err := vc.Read() 293 assert.Nil(s.T(), err) 294 assert.Equal(s.T(), rsp, body) 295 } 296 297 func (s *msuite) TestNoFramerParser() { 298 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 299 defer cancel() 300 m := New() 301 opts := NewGetOptions() 302 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 303 _, err := m.GetMuxConn(ctx, s.network, s.address, opts) 304 assert.Equal(s.T(), err, ErrFrameParserNil) 305 } 306 307 func (s *msuite) TestContextDeadline() { 308 id := atomic.AddUint32(&s.requestID, 1) 309 ld := &lengthDelimitedFramer{} 310 311 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 312 defer cancel() 313 314 m := New() 315 opts := NewGetOptions() 316 opts.WithVID(id) 317 opts.WithFrameParser(ld) 318 vc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 319 assert.Nil(s.T(), err) 320 _, err = vc.Read() 321 assert.Equal(s.T(), err, context.DeadlineExceeded) 322 err = vc.Write([]byte("hello world")) 323 assert.Equal(s.T(), err, context.DeadlineExceeded) 324 325 ctx, cancel = context.WithTimeout(context.Background(), time.Second) 326 defer cancel() 327 vc, err = m.GetMuxConn(ctx, s.network, s.address, opts) 328 assert.Nil(s.T(), err) 329 330 body := []byte("hello world") 331 buf, err := ld.Encode(&delimitedRequest{ 332 body: body, 333 requestID: id, 334 }) 335 assert.Nil(s.T(), err) 336 assert.Nil(s.T(), vc.Write(buf)) 337 338 rsp, err := vc.Read() 339 assert.Nil(s.T(), err) 340 assert.Equal(s.T(), rsp, body) 341 } 342 343 func (s *msuite) TestCloseConnection() { 344 id := atomic.AddUint32(&s.requestID, 1) 345 ld := &lengthDelimitedFramer{} 346 347 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 348 defer cancel() 349 350 m := New(WithConnectNumber(1)) 351 opts := NewGetOptions() 352 opts.WithVID(id) 353 opts.WithFrameParser(ld) 354 _, err := m.GetMuxConn(ctx, s.network, s.address, opts) 355 assert.Nil(s.T(), err) 356 357 time.Sleep(500 * time.Millisecond) 358 v, ok := m.concreteConns.Load(makeNodeKey(s.network, s.address)) 359 assert.True(s.T(), ok) 360 cs := v.(*Connections) 361 cs.conns[0].close(errors.New("fake error"), false) 362 _, ok = m.concreteConns.Load(makeNodeKey(s.network, s.address)) 363 assert.False(s.T(), ok) 364 } 365 366 func (s *msuite) TestDuplicatedClose() { 367 id := atomic.AddUint32(&s.requestID, 1) 368 ld := &lengthDelimitedFramer{} 369 370 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 371 defer cancel() 372 m := New(WithConnectNumber(1)) 373 opts := NewGetOptions() 374 opts.WithVID(id) 375 opts.WithFrameParser(ld) 376 vc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 377 assert.Nil(s.T(), err) 378 379 body := []byte("hello world") 380 buf, err := ld.Encode(&delimitedRequest{ 381 body: body, 382 requestID: id, 383 }) 384 assert.Nil(s.T(), err) 385 assert.Nil(s.T(), vc.Write(buf)) 386 387 rsp, err := vc.Read() 388 assert.Nil(s.T(), err) 389 assert.Equal(s.T(), rsp, body) 390 391 v, ok := m.concreteConns.Load(makeNodeKey(s.network, s.address)) 392 assert.True(s.T(), ok) 393 cs := v.(*Connections) 394 err1 := errors.New("error1") 395 err2 := errors.New("error2") 396 c := cs.conns[0] 397 c.close(err1, false) 398 c.close(err2, false) 399 400 _, err = vc.Read() 401 assert.Equal(s.T(), err, err1) 402 } 403 404 func (s *msuite) TestGetFail() { 405 ld := &lengthDelimitedFramer{} 406 407 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 408 defer cancel() 409 410 m := New() 411 opts := NewGetOptions() 412 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 413 opts.WithFrameParser(ld) 414 _, err := m.GetMuxConn(ctx, s.network, s.address, opts) 415 assert.Nil(s.T(), err) 416 417 m.concreteConns.Store(makeNodeKey(s.network, s.address), &Connection{}) 418 _, err = m.GetMuxConn(ctx, s.network, s.address, opts) 419 assert.NotNil(s.T(), err) 420 } 421 422 func (s *msuite) TestContextCancel() { 423 id := atomic.AddUint32(&s.requestID, 1) 424 ld := &lengthDelimitedFramer{} 425 426 // get with cancel. 427 ctx, cancel := context.WithCancel(context.Background()) 428 cancel() 429 m := New() 430 opts := NewGetOptions() 431 opts.WithVID(id) 432 opts.WithFrameParser(ld) 433 _, err := m.GetMuxConn(ctx, s.network, s.address, opts) 434 assert.NotNil(s.T(), err) 435 } 436 437 // test when send fails. 438 func (s *msuite) TestSendFail() { 439 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 440 defer cancel() 441 m := New(WithDropFull(true), WithQueueSize(1)) 442 opts := NewGetOptions() 443 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 444 opts.WithFrameParser(&emptyFrameParser{}) 445 vc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 446 assert.Nil(s.T(), err) 447 448 body := []byte("hello world") 449 err = vc.Write(body) 450 assert.Nil(s.T(), err) 451 err = vc.Write(body) 452 assert.NotNil(s.T(), err) 453 } 454 455 func (s *msuite) TestWriteErrorCleanVirtualConnection() { 456 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 457 defer cancel() 458 m := New(WithDropFull(true), WithQueueSize(0)) 459 opts := NewGetOptions() 460 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 461 opts.WithFrameParser(&emptyFrameParser{}) 462 mc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 463 assert.Nil(s.T(), err) 464 vc, ok := mc.(*VirtualConnection) 465 assert.True(s.T(), ok) 466 467 body := []byte("hello world") 468 err = vc.Write(body) 469 assert.NotNil(s.T(), err) 470 assert.Len(s.T(), vc.conn.virConns, 0) 471 } 472 473 func (s *msuite) TestReadErrorCleanVirtualConnection() { 474 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) 475 defer cancel() 476 m := New(WithDropFull(true), WithQueueSize(0)) 477 opts := NewGetOptions() 478 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 479 opts.WithFrameParser(&lengthDelimitedFramer{}) 480 mc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 481 assert.Nil(s.T(), err) 482 vc, ok := mc.(*VirtualConnection) 483 assert.True(s.T(), ok) 484 485 time.Sleep(time.Millisecond * 100) 486 _, err = vc.Read() 487 assert.NotNil(s.T(), err) 488 assert.Len(s.T(), vc.conn.virConns, 0) 489 } 490 491 func (s *msuite) TestUdpMultiplexedReadTimeout() { 492 ld := &lengthDelimitedFramer{} 493 494 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 495 defer cancel() 496 m := New() 497 opts := NewGetOptions() 498 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 499 opts.WithFrameParser(ld) 500 vc, err := m.GetMuxConn(ctx, "udp", s.udpAddr, opts) 501 assert.Nil(s.T(), err) 502 _, err = vc.Read() 503 assert.Equal(s.T(), err, ctx.Err()) 504 } 505 506 func (s *msuite) TestMultiplexedServerFail() { 507 tests := []struct { 508 network string 509 address string 510 exists bool 511 }{ 512 {s.network, "invalid address", false}, 513 {s.udpNetwork, "invalid address", false}, 514 } 515 516 for _, tt := range tests { 517 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 518 defer cancel() 519 m := New( 520 WithConnectNumber(1), 521 // On windows, it will try to use up all the timeout to do the dialling. 522 // So limit the dial timeout. 523 WithDialTimeout(time.Millisecond), 524 ) 525 opts := NewGetOptions() 526 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 527 opts.WithFrameParser(&emptyFrameParser{}) 528 _, err := m.GetMuxConn(ctx, tt.network, tt.address, opts) 529 s.T().Logf("m.GetMuxConn err: %+v\n", err) 530 // Because of possible out of order execution of goroutines, 531 // the error may or may not be nil. 532 if err != nil { 533 // If it is non-nil, it must be an expelled error. 534 require.True(s.T(), errors.Is(err, ErrConnectionsHaveBeenExpelled)) 535 } 536 time.Sleep(10 * time.Millisecond) 537 _, ok := m.concreteConns.Load(makeNodeKey(tt.network, tt.address)) 538 assert.Equal(s.T(), tt.exists, ok) 539 } 540 } 541 542 func (s *msuite) TestMultiplexedConcurrentGetInvalidAddr() { 543 const ( 544 network = "tcp" 545 invalidAddr = "invalid addr" 546 ) 547 msg := codec.Message(context.Background()) 548 msg.WithRequestID(atomic.AddUint32(&s.requestID, 1)) 549 550 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 551 defer cancel() 552 m := New(WithConnectNumber(1)) 553 opts := NewGetOptions() 554 opts.WithFrameParser(&emptyFrameParser{}) 555 start := time.Now() 556 for n := 1; ; n++ { 557 if time.Since(start) > time.Second*10 { 558 require.FailNow(s.T(), "expected expelled error in 10s") 559 } 560 var eg errgroup.Group 561 for i := 0; i < n; i++ { 562 eg.Go(func() error { 563 _, err := m.GetMuxConn(ctx, network, invalidAddr, opts) 564 return err 565 }) 566 } 567 if err := eg.Wait(); err != nil { 568 s.T().Logf("ok, m.GetMuxConn error: %+v\n", err) 569 break 570 } 571 } 572 } 573 574 func (s *msuite) TestWithLocalAddr() { 575 tests := []struct { 576 network string 577 address string 578 }{ 579 {s.network, s.address}, 580 {s.udpNetwork, s.udpAddr}, 581 } 582 localAddr := "127.0.0.1" 583 584 for _, tt := range tests { 585 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 586 defer cancel() 587 m := New() 588 opts := NewGetOptions() 589 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 590 opts.WithLocalAddr(localAddr + ":") 591 ld := &lengthDelimitedFramer{} 592 opts.WithFrameParser(ld) 593 body := []byte("hello world") 594 buf, err := ld.Encode(&delimitedRequest{ 595 body: body, 596 requestID: s.requestID, 597 }) 598 assert.Nil(s.T(), err) 599 mc, err := m.GetMuxConn(ctx, tt.network, tt.address, opts) 600 assert.Nil(s.T(), err) 601 vc, ok := mc.(*VirtualConnection) 602 assert.True(s.T(), ok) 603 assert.Nil(s.T(), vc.Write(buf)) 604 assert.Nil(s.T(), err) 605 _, err = vc.Read() 606 assert.Nil(s.T(), err) 607 if tt.network == s.network { 608 conn := vc.conn.getRawConn() 609 realAddr := conn.LocalAddr().(*net.TCPAddr).IP.String() 610 assert.Equal(s.T(), realAddr, localAddr) 611 } else if tt.network == s.udpNetwork { 612 realAddr := vc.conn.packetConn.LocalAddr().(*net.UDPAddr).IP.String() 613 assert.Equal(s.T(), realAddr, localAddr) 614 } 615 } 616 } 617 618 func (s *msuite) TestTCPReconnect() { 619 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 620 defer cancel() 621 m := New(WithConnectNumber(1)) 622 opts := NewGetOptions() 623 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 624 ld := &lengthDelimitedFramer{} 625 opts.WithFrameParser(ld) 626 body := []byte("hello world") 627 buf, err := ld.Encode(&delimitedRequest{ 628 body: body, 629 requestID: s.requestID, 630 }) 631 assert.Nil(s.T(), err) 632 vc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 633 assert.Nil(s.T(), err) 634 assert.Nil(s.T(), vc.Write(buf)) 635 _, err = vc.Read() 636 assert.Nil(s.T(), err) 637 638 // close conn 639 val, ok := m.concreteConns.Load(makeNodeKey(s.network, s.address)) 640 assert.True(s.T(), ok) 641 c := val.(*Connections).conns[0] 642 conn := c.getRawConn() 643 conn.Close() 644 time.Sleep(100 * time.Millisecond) 645 vc, err = m.GetMuxConn(ctx, s.network, s.address, opts) 646 assert.Nil(s.T(), err) 647 assert.Nil(s.T(), vc.Write(buf)) 648 _, err = vc.Read() 649 assert.Nil(s.T(), err) 650 _, ok = m.concreteConns.Load(makeNodeKey(s.network, s.address)) 651 assert.True(s.T(), ok) 652 653 // timeout after reconnected 654 ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) 655 defer done() 656 vc, err = m.GetMuxConn(ctx, s.network, s.address, opts) 657 assert.Nil(s.T(), err) 658 _, err = vc.Read() 659 assert.ErrorIs(s.T(), err, context.DeadlineExceeded) 660 } 661 662 func (s *msuite) TestTCPReconnectMaxReconnectCount() { 663 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 664 defer cancel() 665 m := New(WithConnectNumber(1)) 666 opts := NewGetOptions() 667 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 668 ld := &lengthDelimitedFramer{} 669 opts.WithFrameParser(ld) 670 _, err := m.GetMuxConn(ctx, s.network, "invalid address", opts) 671 assert.Nil(s.T(), err) 672 time.Sleep(time.Second) 673 _, ok := m.concreteConns.Load(makeNodeKey(s.network, "invalid address")) 674 assert.False(s.T(), ok) 675 } 676 677 func (s *msuite) TestStreamMultiplexd() { 678 id := atomic.AddUint32(&s.requestID, 1) 679 680 ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) 681 defer cancel() 682 683 m := New() 684 opts := NewGetOptions() 685 opts.WithVID(id) 686 ld := &lengthDelimitedFramer{IsStream: true} 687 opts.WithFrameParser(ld) 688 vc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 689 assert.Nil(s.T(), err) 690 assert.NotNil(s.T(), vc) 691 692 body := []byte("hello world") 693 buf, err := ld.Encode(&delimitedRequest{ 694 body: body, 695 requestID: id, 696 }) 697 assert.Nil(s.T(), err) 698 assert.Nil(s.T(), vc.Write(buf)) 699 700 rsp, err := vc.Read() 701 assert.Nil(s.T(), err) 702 assert.Equal(s.T(), buf, rsp) 703 } 704 705 func (s *msuite) TestStreamMultiplexd_Addr() { 706 streamID := atomic.AddUint32(&s.requestID, 1) 707 708 ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) 709 defer cancel() 710 711 m := New() 712 opts := NewGetOptions() 713 opts.WithVID(streamID) 714 ld := &lengthDelimitedFramer{IsStream: true} 715 opts.WithFrameParser(ld) 716 vc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 717 assert.Nil(s.T(), err) 718 assert.NotNil(s.T(), vc) 719 time.Sleep(50 * time.Millisecond) 720 721 la := vc.LocalAddr() 722 assert.NotNil(s.T(), la) 723 724 ra := vc.RemoteAddr() 725 assert.Equal(s.T(), s.address, ra.String()) 726 } 727 728 func (s *msuite) TestStreamMultiplexd_MaxVirConnPerConn() { 729 ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) 730 defer cancel() 731 732 m := New(WithMaxVirConnsPerConn(4)) 733 opts := NewGetOptions() 734 ld := &lengthDelimitedFramer{IsStream: true} 735 opts.WithFrameParser(ld) 736 var cs *Connections 737 for i := 0; i < 10; i++ { 738 id := atomic.AddUint32(&s.requestID, 1) 739 opts.WithVID(id) 740 vc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 741 assert.Nil(s.T(), err) 742 assert.NotNil(s.T(), vc) 743 conns, ok := m.concreteConns.Load(makeNodeKey(s.network, s.address)) 744 require.True(s.T(), ok) 745 cs, ok = conns.(*Connections) 746 require.True(s.T(), ok) 747 748 body := []byte("hello world") 749 buf, err := ld.Encode(&delimitedRequest{ 750 body: body, 751 requestID: uint32(id), 752 }) 753 assert.Nil(s.T(), err) 754 assert.Nil(s.T(), vc.Write(buf)) 755 756 rsp, err := vc.Read() 757 assert.Nil(s.T(), err) 758 assert.Equal(s.T(), buf, rsp) 759 } 760 assert.Equal(s.T(), 3, len(cs.conns)) 761 } 762 763 func (s *msuite) TestStreamMultiplexd_MaxIdleConnPerHost() { 764 ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) 765 defer cancel() 766 767 m := New(WithMaxVirConnsPerConn(2), WithMaxIdleConnsPerHost(3)) 768 opts := NewGetOptions() 769 ld := &lengthDelimitedFramer{IsStream: true} 770 opts.WithFrameParser(ld) 771 772 vcs := make([]MuxConn, 0) 773 for i := 0; i < 10; i++ { 774 id := atomic.AddUint32(&s.requestID, 1) 775 opts.WithVID(id) 776 vc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 777 assert.Nil(s.T(), err) 778 vcs = append(vcs, vc) 779 } 780 conns, ok := m.concreteConns.Load(makeNodeKey(s.network, s.address)) 781 require.True(s.T(), ok) 782 cs, ok := conns.(*Connections) 783 require.True(s.T(), ok) 784 assert.Equal(s.T(), 5, len(cs.conns)) 785 for i := 0; i < 10; i++ { 786 vcs[i].Close() 787 } 788 assert.Equal(s.T(), 3, len(cs.conns)) 789 } 790 791 func (s *msuite) TestMultiplexedGetConcurrent_MaxIdleConnPerHost() { 792 count := 100 793 ld := &lengthDelimitedFramer{} 794 m := New(WithMaxVirConnsPerConn(20), WithMaxIdleConnsPerHost(2)) 795 tests := []struct { 796 network string 797 address string 798 }{ 799 {s.network, s.address}, 800 {s.udpNetwork, s.udpAddr}, 801 } 802 for _, tt := range tests { 803 wg := sync.WaitGroup{} 804 wg.Add(count) 805 for i := 0; i < count; i++ { 806 go func(i int) { 807 defer wg.Done() 808 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 809 id := atomic.AddUint32(&s.requestID, 1) 810 opts := NewGetOptions() 811 opts.WithVID(id) 812 opts.WithFrameParser(ld) 813 vc, err := m.GetMuxConn(ctx, tt.network, tt.address, opts) 814 assert.Nil(s.T(), err) 815 body := []byte("hello world" + strconv.Itoa(i)) 816 buf, err := ld.Encode(&delimitedRequest{ 817 body: body, 818 requestID: id, 819 }) 820 assert.Nil(s.T(), err) 821 assert.Nil(s.T(), vc.Write(buf)) 822 rsp, err := vc.Read() 823 assert.Nil(s.T(), err) 824 assert.Equal(s.T(), rsp, body) 825 vc.Close() 826 cancel() 827 }(i) 828 if i%50 == 0 { 829 time.Sleep(50 * time.Millisecond) 830 } 831 } 832 wg.Wait() 833 } 834 } 835 836 func (s *msuite) TestMultiplexedReconnectOnConnectError() { 837 ctx := context.Background() 838 ts := newTCPServer() 839 ts.start(ctx) 840 defer ts.stop() 841 m := New( 842 WithConnectNumber(1), 843 // On windows, it will try to use up all the timeout to do the dialling. 844 // So limit the dial timeout. 845 WithDialTimeout(time.Millisecond*10), 846 ) 847 ctx, cancel := context.WithTimeout(ctx, time.Second) 848 defer cancel() 849 opts := NewGetOptions() 850 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 851 readTrigger := make(chan struct{}) 852 readErr := make(chan error) 853 opts.WithFrameParser(&triggeredReadFramerBuilder{readTrigger: readTrigger, readErr: readErr}) 854 mc, err := m.GetMuxConn(ctx, s.network, ts.ln.Addr().String(), opts) 855 require.Nil(s.T(), err) 856 vc, ok := mc.(*VirtualConnection) 857 assert.True(s.T(), ok) 858 <-readTrigger // Wait for the first read. 859 require.Nil(s.T(), ts.ln.Close()) // Then close the server. 860 readErr <- errAlwaysFail // Fail the first read to trigger reconnection. 861 require.Eventually(s.T(), 862 func() bool { return maxReconnectCount+1 == vc.conn.reconnectCount }, 863 time.Second, 10*time.Millisecond) 864 } 865 866 func (s *msuite) TestMultiplexedReconnectOnReadError() { 867 preInitialBackoff := initialBackoff 868 preMaxBackoff := maxBackoff 869 preMaxReconnectCount := maxReconnectCount 870 preResetInterval := reconnectCountResetInterval 871 defer func() { 872 initialBackoff = preInitialBackoff 873 maxBackoff = preMaxBackoff 874 maxReconnectCount = preMaxReconnectCount 875 reconnectCountResetInterval = preResetInterval 876 }() 877 initialBackoff = time.Microsecond 878 maxBackoff = 50 * time.Microsecond 879 maxReconnectCount = 5 880 reconnectCountResetInterval = time.Hour 881 882 m := New( 883 WithConnectNumber(1), 884 // On windows, it will try to use up all the timeout to do the dialling. 885 // So limit the dial timeout. 886 WithDialTimeout(time.Millisecond*10), 887 ) 888 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 889 defer cancel() 890 opts := NewGetOptions() 891 calledAt := make([]time.Time, 0, maxReconnectCount) 892 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 893 opts.WithFrameParser(&errFramerBuilder{readFrameCalledAt: &calledAt}) 894 mc, err := m.GetMuxConn(ctx, s.network, s.address, opts) 895 require.Nil(s.T(), err) 896 vc, ok := mc.(*VirtualConnection) 897 assert.True(s.T(), ok) 898 require.Eventually(s.T(), 899 func() bool { return maxReconnectCount+1 == vc.conn.reconnectCount }, 900 3*time.Second, time.Second, 901 fmt.Sprintf("final status: maxReconnectCount+1=%d, vc.conn.reconnectCount=%d", 902 maxReconnectCount+1, vc.conn.reconnectCount)) 903 require.Eventually(s.T(), 904 func() bool { return maxReconnectCount+1 == len(calledAt) }, 905 3*time.Second, 50*time.Millisecond, 906 fmt.Sprintf("final status: maxReconnectCount+1=%d, len(calledAt)=%d", 907 maxReconnectCount+1, len(calledAt))) 908 var differences []float64 909 for i := 1; i < len(calledAt); i++ { 910 delay := calledAt[i].Sub(calledAt[i-1]) 911 expectedBackoff := (initialBackoff * time.Duration(i)) 912 s.T().Logf("calledAt delay: %2dms, expect: %2dms (between %d and %d)\n", 913 delay.Milliseconds(), expectedBackoff.Milliseconds(), i-1, i) 914 differences = append(differences, float64(delay-expectedBackoff)) 915 } 916 require.Equal(s.T(), maxReconnectCount+1, len(calledAt), 917 "the actual times called is %d, expect %d", len(calledAt), maxReconnectCount+1) 918 s.T().Logf("differences: %+v", differences) 919 s.T().Logf("mean of differences between real retry delay and the calculated backoff: %vns", mean(differences)) 920 ss := std(differences) 921 s.T().Logf("std of differences between real retry delay and the calculated backoff: %vns", ss) 922 const expectedStdLimit = time.Second 923 require.Less(s.T(), ss, float64(expectedStdLimit), 924 "standard deviation of differences between real retry delay and calculated backoff is expected to be within %s", 925 expectedStdLimit) 926 } 927 928 func (s *msuite) TestMultiplexedReconnectOnWriteError() { 929 ctx := context.Background() 930 ts := newTCPServer() 931 ts.start(ctx) 932 defer ts.stop() 933 m := New( 934 WithConnectNumber(1), 935 // On windows, it will try to use up all the timeout to do the dialling. 936 // So limit the dial timeout. 937 WithDialTimeout(time.Millisecond*10), 938 ) 939 ctx, cancel := context.WithTimeout(ctx, time.Second) 940 defer cancel() 941 opts := NewGetOptions() 942 opts.WithVID(atomic.AddUint32(&s.requestID, 1)) 943 readTrigger := make(chan struct{}) 944 readErr := make(chan error) 945 opts.WithFrameParser(&triggeredReadFramerBuilder{readTrigger: readTrigger, readErr: readErr}) 946 mc, err := m.GetMuxConn(ctx, s.network, ts.ln.Addr().String(), opts) 947 require.Nil(s.T(), err) 948 vc, ok := mc.(*VirtualConnection) 949 assert.True(s.T(), ok) 950 <-readTrigger // Wait for the first read. 951 require.Nil(s.T(), vc.conn.getRawConn().Close()) // Now close the underlying connection. 952 require.Nil(s.T(), vc.Write([]byte("hello"))) // Then this write will trigger a reconnection on write error. 953 // Now we are cool to check that a reconnection is triggered. 954 require.Eventually(s.T(), 955 func() bool { return 1 == vc.conn.reconnectCount }, 956 time.Second, 10*time.Millisecond) 957 } 958 959 func TestMultiplexedDestroyMayCauseGoroutineLeak(t *testing.T) { 960 l, err := net.Listen("tcp", ":") 961 require.Nil(t, err) 962 const connNum = 2 963 acceptedConns, acceptErrs := make(chan net.Conn, connNum*2), make(chan error) 964 var closedConns uint32 965 go func() { 966 for { 967 c, err := l.Accept() 968 if err != nil { 969 acceptErrs <- err 970 return 971 } 972 acceptedConns <- c 973 go func() { 974 _, _ = io.Copy(c, c) 975 atomic.AddUint32(&closedConns, 1) 976 }() 977 } 978 }() 979 980 fb := fixedLenFrameBuilder{packetLen: 2} 981 dialTimeout := time.Millisecond * 50 982 m := New( 983 WithConnectNumber(connNum), 984 // replace the too long default 1s dail timeout. 985 WithDialTimeout(dialTimeout)) 986 getVirtualConn := func(requestID uint32) (MuxConn, error) { 987 getOptions := NewGetOptions() 988 getOptions.WithVID(requestID) 989 getOptions.WithFrameParser(&fb) 990 return m.GetMuxConn(context.Background(), l.Addr().Network(), l.Addr().String(), getOptions) 991 } 992 993 vc, err := getVirtualConn(1) 994 require.Nil(t, err) 995 require.Nil(t, vc.Write(fb.EncodeWithRequestID(1, []byte("1a")))) 996 read, err := vc.Read() 997 require.Nil(t, err) 998 require.Equal(t, []byte("1a"), read) 999 vc.Close() 1000 1001 var ( 1002 c1 net.Conn 1003 c2 net.Conn 1004 ) 1005 select { 1006 case c1 = <-acceptedConns: 1007 case <-time.After(time.Second): 1008 require.FailNow(t, "should accept a connection") 1009 } 1010 select { 1011 case c2 = <-acceptedConns: 1012 case <-time.After(time.Second): 1013 require.FailNow(t, "multiplexed should establish two concreteConns") 1014 } 1015 1016 require.Nil(t, l.Close()) 1017 <-acceptErrs 1018 require.Nil(t, c1.Close()) 1019 // on windows, connecting to closed listener returns an error until dial timeout, not immediately. 1020 // we should sleep additional dialTimeout * maxReconnectCount to wait all retry finished. 1021 time.Sleep((maxBackoff + dialTimeout) * time.Duration(maxReconnectCount)) 1022 require.Equal(t, uint32(1), atomic.LoadUint32(&closedConns)) 1023 1024 vc, err = getVirtualConn(2) 1025 require.Nil(t, err) 1026 require.Nil(t, vc.Write(fb.EncodeWithRequestID(2, []byte("2a")))) 1027 require.EqualValues(t, 1, atomic.LoadUint32(&closedConns)) 1028 read, err = vc.Read() 1029 require.Nil(t, err) 1030 require.Equal(t, []byte("2a"), read) 1031 require.Nil(t, err) 1032 require.Nil(t, c2.Close()) 1033 } 1034 1035 func mean(v []float64) float64 { 1036 n := len(v) 1037 if n == 0 { 1038 return 0 1039 } 1040 var res float64 1041 for i := 0; i < n; i++ { 1042 res += v[i] 1043 } 1044 return res / float64(n) 1045 } 1046 1047 func variance(v []float64) float64 { 1048 n := len(v) 1049 if n <= 1 { 1050 return 0 1051 } 1052 var res float64 1053 m := mean(v) 1054 for i := 0; i < n; i++ { 1055 res += (v[i] - m) * (v[i] - m) 1056 } 1057 return res / float64(n-1) 1058 } 1059 1060 func std(v []float64) float64 { 1061 return math.Sqrt(variance(v)) 1062 } 1063 1064 type errFramerBuilder struct { 1065 readFrameCalledAt *[]time.Time 1066 } 1067 1068 func (fb *errFramerBuilder) New(io.Reader) codec.Framer { 1069 return &errFramer{ 1070 calledAt: fb.readFrameCalledAt, 1071 } 1072 } 1073 1074 func (fb *errFramerBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) { 1075 *fb.readFrameCalledAt = append(*fb.readFrameCalledAt, time.Now()) 1076 buf, err = fb.New(rc).ReadFrame() 1077 if err != nil { 1078 return 0, nil, err 1079 } 1080 return 0, buf, nil 1081 } 1082 1083 var errAlwaysFail = errors.New("always fail") 1084 1085 type errFramer struct { 1086 calledAt *[]time.Time 1087 } 1088 1089 // ReadFrame implements codec.Framer. 1090 func (f *errFramer) ReadFrame() ([]byte, error) { 1091 return nil, errAlwaysFail 1092 } 1093 1094 type triggeredReadFramerBuilder struct { 1095 readTrigger chan struct{} 1096 readErr chan error 1097 } 1098 1099 func (fb *triggeredReadFramerBuilder) New(io.Reader) codec.Framer { 1100 return &triggeredReadFramer{ 1101 readTrigger: fb.readTrigger, 1102 readErr: fb.readErr, 1103 } 1104 } 1105 1106 func (fb *triggeredReadFramerBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) { 1107 buf, err = fb.New(rc).ReadFrame() 1108 if err != nil { 1109 return 0, nil, err 1110 } 1111 return 0, buf, nil 1112 } 1113 1114 type triggeredReadFramer struct { 1115 readTrigger chan struct{} 1116 readErr chan error 1117 } 1118 1119 // ReadFrame implements codec.Framer. 1120 func (f *triggeredReadFramer) ReadFrame() ([]byte, error) { 1121 f.readTrigger <- struct{}{} 1122 err := <-f.readErr 1123 return nil, err 1124 } 1125 1126 type fixedLenFrameBuilder struct { 1127 packetLen int 1128 } 1129 1130 func (fb *fixedLenFrameBuilder) New(r io.Reader) codec.Framer { 1131 return &fixedLenFramer{ 1132 decode: fb.Decode, 1133 buf: make([]byte, 4+fb.packetLen), // uint64 request id + packet len 1134 r: r, 1135 } 1136 } 1137 1138 func (fb *fixedLenFrameBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) { 1139 buf = make([]byte, 4+fb.packetLen) 1140 n, err := rc.Read(buf) 1141 if err != nil { 1142 return 0, nil, err 1143 } 1144 id, bts, err := fb.Decode(buf[:n]) 1145 if err != nil { 1146 return 0, nil, err 1147 } 1148 return id, bts, nil 1149 } 1150 1151 func (*fixedLenFrameBuilder) EncodeWithRequestID(id uint32, buf []byte) []byte { 1152 bts := make([]byte, 4+len(buf)) 1153 binary.BigEndian.PutUint32(bts[:4], id) 1154 copy(bts[4:], buf) 1155 return bts 1156 } 1157 1158 func (*fixedLenFrameBuilder) Decode(bts []byte) (uint32, []byte, error) { 1159 if l := len(bts); l < 4 { 1160 return 0, nil, fmt.Errorf("bts len %d must not be lesser than 8, content: %q", l, bts) 1161 } 1162 return binary.BigEndian.Uint32(bts), bts[4:], nil 1163 } 1164 1165 type fixedLenFramer struct { 1166 decode func([]byte) (uint32, []byte, error) 1167 buf []byte 1168 r io.Reader 1169 } 1170 1171 func (f *fixedLenFramer) ReadFrame() ([]byte, error) { 1172 return nil, errors.New("should not be used by multiplexed") 1173 } 1174 1175 func newTCPServer() *tcpServer { 1176 return &tcpServer{} 1177 } 1178 1179 type tcpServer struct { 1180 cancel context.CancelFunc 1181 ln net.Listener 1182 concreteConns []net.Conn 1183 } 1184 1185 func (s *tcpServer) start(ctx context.Context) error { 1186 var err error 1187 s.ln, err = net.Listen("tcp", "127.0.0.1:0") 1188 if err != nil { 1189 return err 1190 } 1191 ctx, s.cancel = context.WithCancel(ctx) 1192 go func() { 1193 for { 1194 select { 1195 case <-ctx.Done(): 1196 return 1197 default: 1198 } 1199 conn, err := s.ln.Accept() 1200 if err != nil { 1201 log.Println("l.Accept err: ", err) 1202 return 1203 } 1204 s.concreteConns = append(s.concreteConns, conn) 1205 1206 go func() { 1207 select { 1208 case <-ctx.Done(): 1209 return 1210 default: 1211 } 1212 io.Copy(conn, conn) 1213 }() 1214 } 1215 }() 1216 return nil 1217 } 1218 1219 func (s *tcpServer) stop() { 1220 s.cancel() 1221 s.closeConnections() 1222 s.ln.Close() 1223 } 1224 1225 func (s *tcpServer) closeConnections() { 1226 for i := range s.concreteConns { 1227 s.concreteConns[i].Close() 1228 } 1229 s.concreteConns = s.concreteConns[:0] 1230 } 1231 1232 func newUDPServer() *udpServer { 1233 return &udpServer{} 1234 } 1235 1236 type udpServer struct { 1237 cancel context.CancelFunc 1238 conn net.PacketConn 1239 } 1240 1241 func (s *udpServer) start(ctx context.Context) error { 1242 var err error 1243 s.conn, err = net.ListenPacket("udp", "127.0.0.1:0") 1244 if err != nil { 1245 return err 1246 } 1247 ctx, s.cancel = context.WithCancel(ctx) 1248 go func() { 1249 buf := make([]byte, 65535) 1250 for { 1251 select { 1252 case <-ctx.Done(): 1253 return 1254 default: 1255 } 1256 n, addr, err := s.conn.ReadFrom(buf) 1257 if err != nil { 1258 log.Println("l.ReadFrom err: ", err) 1259 return 1260 } 1261 1262 s.conn.WriteTo(buf[:n], addr) 1263 } 1264 }() 1265 return nil 1266 } 1267 1268 func (s *udpServer) stop() { 1269 s.cancel() 1270 s.conn.Close() 1271 }