trpc.group/trpc-go/trpc-go@v1.0.3/pool/multiplexed/multiplexed.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 // Package multiplexed provides multiplexed pool implementation. 15 package multiplexed 16 17 import ( 18 "context" 19 "errors" 20 "fmt" 21 "net" 22 "strings" 23 "sync" 24 "sync/atomic" 25 "time" 26 27 "github.com/hashicorp/go-multierror" 28 "trpc.group/trpc-go/trpc-go/internal/packetbuffer" 29 "trpc.group/trpc-go/trpc-go/internal/queue" 30 "trpc.group/trpc-go/trpc-go/internal/report" 31 "trpc.group/trpc-go/trpc-go/log" 32 "trpc.group/trpc-go/trpc-go/pool/connpool" 33 ) 34 35 // DefaultMultiplexedPool is the default multiplexed implementation. 36 var DefaultMultiplexedPool = New() 37 38 const ( 39 defaultBufferSize = 128 * 1024 40 defaultConnNumberPerHost = 2 41 defaultSendQueueSize = 1024 42 defaultDialTimeout = time.Second 43 maxBufferSize = 65535 44 ) 45 46 // The following needs to be variables according to some test cases. 47 var ( 48 initialBackoff = 5 * time.Millisecond 49 maxBackoff = 50 * time.Millisecond 50 maxReconnectCount = 10 51 // reconnectCountResetInterval is twice the expected total reconnect backoff time, 52 // i.e. 2 * \sum_{i=1}^{maxReconnectCount}(i*initialBackoff). 53 reconnectCountResetInterval = 5 * time.Millisecond * (1 + 10) * 10 54 ) 55 56 var ( 57 // ErrFrameParserNil indicates that frame parse is nil. 58 ErrFrameParserNil = errors.New("frame parser is nil") 59 // ErrRecvQueueFull receive queue full. 60 ErrRecvQueueFull = errors.New("virtual connection's recv queue is full") 61 // ErrSendQueueFull send queue is full. 62 ErrSendQueueFull = errors.New("connection's send queue is full") 63 // ErrChanClose connection is closed. 64 ErrChanClose = errors.New("unexpected recv chan close") 65 // ErrAssertFail type assert fail. 66 ErrAssertFail = errors.New("type assert fail") 67 // ErrDupRequestID duplicated request id. 68 ErrDupRequestID = errors.New("duplicated Request ID") 69 // ErrInitPoolFail failed to initialize connection. 70 ErrInitPoolFail = errors.New("init pool for specific node fail") 71 // ErrWriteNotFinished write operation is not completed. 72 ErrWriteNotFinished = errors.New("write not finished") 73 // ErrNetworkNotSupport does not support network type. 74 ErrNetworkNotSupport = errors.New("network not support") 75 // ErrConnectionsHaveBeenExpelled denotes that the connections to a certain ip:port have been expelled. 76 ErrConnectionsHaveBeenExpelled = errors.New("connections have been expelled") 77 ) 78 79 // Pool is a connection pool for multiplexing. 80 type Pool interface { 81 // GetMuxConn gets a multiplexing connection to the address on named network. 82 GetMuxConn(ctx context.Context, network string, address string, opts GetOptions) (MuxConn, error) 83 } 84 85 // New creates a new multiplexed instance. 86 func New(opt ...PoolOption) *Multiplexed { 87 opts := &PoolOptions{ 88 connectNumberPerHost: defaultConnNumberPerHost, 89 sendQueueSize: defaultSendQueueSize, 90 dialTimeout: defaultDialTimeout, 91 } 92 for _, o := range opt { 93 o(opts) 94 } 95 // The maximum number of idle connections cannot be less than the number of pre-allocated connections. 96 if opts.maxIdleConnsPerHost != 0 && opts.maxIdleConnsPerHost < opts.connectNumberPerHost { 97 opts.maxIdleConnsPerHost = opts.connectNumberPerHost 98 } 99 return &Multiplexed{ 100 concreteConns: new(sync.Map), 101 opts: opts, 102 } 103 } 104 105 // Multiplexed represents multiplexing. 106 type Multiplexed struct { 107 mu sync.RWMutex 108 // key(ip:port) 109 // => value(*Connections) <-- Multiple concrete connections to a same ip:port. 110 // => (*Connection) <-- Single concrete connection to a certain ip:port. 111 // => [](*VirtualConnection) <-- Multiple virtual connections multiplexed on a certain concrete connection. 112 concreteConns *sync.Map 113 opts *PoolOptions 114 } 115 116 // GetMuxConn gets a multiplexing connection to the address on named network. 117 func (p *Multiplexed) GetMuxConn( 118 ctx context.Context, 119 network string, 120 address string, 121 opts GetOptions, 122 ) (MuxConn, error) { 123 select { 124 case <-ctx.Done(): 125 return nil, ctx.Err() 126 default: 127 } 128 if err := opts.update(network, address); err != nil { 129 return nil, err 130 } 131 return p.get(ctx, &opts) 132 } 133 134 func (p *Multiplexed) get(ctx context.Context, opts *GetOptions) (*VirtualConnection, error) { 135 // Step 1: nodeKey(ip:port) => concrete connections. 136 value, ok := p.concreteConns.Load(opts.nodeKey) 137 if !ok { 138 p.initPoolForNode(opts) 139 value, ok = p.concreteConns.Load(opts.nodeKey) 140 if !ok { 141 return nil, ErrInitPoolFail 142 } 143 } 144 conns, ok := value.(*Connections) 145 if !ok { 146 return nil, fmt.Errorf("%w, expected: *Connections, actual: %T", ErrAssertFail, value) 147 } 148 // Step 2: concrete connections => single concrete connection. 149 conn, err := conns.pickSingleConcrete(ctx, opts) 150 if err != nil { 151 return nil, fmt.Errorf( 152 "multiplexed pick single concreate connection with node key %s err: %w", opts.nodeKey, err) 153 } 154 // Step 3: single concrete connection => virtual connection. 155 return conn.newVirConn(ctx, opts.VID), nil 156 } 157 158 func (p *Multiplexed) initPoolForNode(opts *GetOptions) { 159 p.mu.Lock() 160 defer p.mu.Unlock() 161 // Check again in case another goroutine has initialized the pool just ahead of us. 162 if _, ok := p.concreteConns.Load(opts.nodeKey); ok { 163 return 164 } 165 p.concreteConns.Store(opts.nodeKey, p.newConcreteConnections(opts)) 166 } 167 168 func (p *Multiplexed) newConcreteConnections(opts *GetOptions) *Connections { 169 conns := &Connections{ 170 nodeKey: opts.nodeKey, 171 opts: p.opts, 172 conns: make([]*Connection, 0, p.opts.connectNumberPerHost), 173 maxIdle: p.opts.maxIdleConnsPerHost, 174 destructor: func() { 175 p.concreteConns.Delete(opts.nodeKey) 176 }, 177 } 178 conns.initialize(opts) 179 return conns 180 } 181 182 func (cs *Connections) newConn(opts *GetOptions) *Connection { 183 c := &Connection{ 184 network: opts.network, 185 address: opts.address, 186 virConns: make(map[uint32]*VirtualConnection), 187 done: make(chan struct{}), 188 dropFull: cs.opts.dropFull, 189 maxVirConns: cs.opts.maxVirConnsPerConn, 190 writeBuffer: make(chan []byte, cs.opts.sendQueueSize), 191 isStream: opts.isStream, 192 isIdle: true, 193 enableIdleRemove: cs.maxIdle > 0 && cs.opts.maxVirConnsPerConn > 0, 194 connsAddIdle: func() { cs.addIdle() }, 195 connsSubIdle: func() { cs.subIdle() }, 196 connsNeedIdleRemove: func() bool { 197 return int(atomic.LoadInt32(&cs.currentIdle)) > cs.maxIdle 198 }, 199 } 200 c.destroy = func() { cs.expel(c) } 201 cs.conns = append(cs.conns, c) 202 cs.addIdle() 203 go c.startConnect(opts, cs.opts.dialTimeout) 204 return c 205 } 206 207 func dialTCP(timeout time.Duration, opts *GetOptions) (net.Conn, *connpool.DialOptions, error) { 208 dialOpts := &connpool.DialOptions{ 209 Network: opts.network, 210 Address: opts.address, 211 Timeout: timeout, 212 CACertFile: opts.CACertFile, 213 TLSCertFile: opts.TLSCertFile, 214 TLSKeyFile: opts.TLSKeyFile, 215 TLSServerName: opts.TLSServerName, 216 LocalAddr: opts.LocalAddr, 217 } 218 conn, err := tryConnect(dialOpts) 219 return conn, dialOpts, err 220 } 221 222 func dialUDP(opts *GetOptions) (net.PacketConn, *net.UDPAddr, error) { 223 addr, err := net.ResolveUDPAddr(opts.network, opts.address) 224 if err != nil { 225 return nil, nil, err 226 } 227 const defaultLocalAddr = ":" 228 localAddr := defaultLocalAddr 229 if opts.LocalAddr != "" { 230 localAddr = opts.LocalAddr 231 } 232 conn, err := net.ListenPacket(opts.network, localAddr) 233 if err != nil { 234 return nil, nil, err 235 } 236 return conn, addr, nil 237 } 238 239 func (cs *Connections) pickSingleConcrete(ctx context.Context, opts *GetOptions) (*Connection, error) { 240 // The lock is always needed because the length of cs.conns may be changed in another goroutine. 241 // Example cases: 242 // 1. During idle removal, the length of cs.conns will be reduced. 243 // 2. If max retry time is reached, the length of cs.conns will be reduced. 244 cs.mu.Lock() 245 defer cs.mu.Unlock() 246 if cs.expelled { 247 return nil, fmt.Errorf("node key: %s, err: %w, caused by sub errors on conns: %+v", 248 cs.nodeKey, ErrConnectionsHaveBeenExpelled, cs.err) 249 } 250 if cs.opts.maxVirConnsPerConn == 0 { 251 // The number of virtual connections on each concrete connection is unlimited, do round robin. 252 cs.roundRobinIndex = (cs.roundRobinIndex + 1) % cs.opts.connectNumberPerHost 253 if cs.roundRobinIndex >= len(cs.conns) { 254 // Current concrete connections have been reduced below the expected number. 255 // Fill with a new concrete connection. 256 cs.roundRobinIndex = len(cs.conns) 257 return cs.newConn(opts), nil 258 } 259 return cs.conns[cs.roundRobinIndex], nil 260 } 261 for _, c := range cs.conns { 262 if c.canGetVirConn() { 263 return c, nil 264 } 265 } 266 return cs.newConn(opts), nil 267 } 268 269 func (c *Connection) canGetVirConn() bool { 270 c.mu.RLock() 271 defer c.mu.RUnlock() 272 return c.maxVirConns == 0 || // 0 means unlimited. 273 len(c.virConns) < c.maxVirConns 274 } 275 276 // startConnect starts to actually execute the connection logic. 277 func (c *Connection) startConnect(opts *GetOptions, dialTimeout time.Duration) { 278 c.fp = opts.FP 279 if err := c.dial(dialTimeout, opts); err != nil { 280 // The first time the connection fails to be established directly fails, 281 // let the upper layer trigger the next time to re-establish the connection. 282 c.close(err, false) 283 return 284 } 285 go c.reading() 286 go c.writing() 287 } 288 289 func (c *Connection) dial(timeout time.Duration, opts *GetOptions) error { 290 if c.isStream { 291 conn, dialOpts, err := dialTCP(timeout, opts) 292 c.dialOpts = dialOpts 293 if err != nil { 294 return err 295 } 296 c.setRawConn(conn) 297 } else { 298 conn, addr, err := dialUDP(opts) 299 if err != nil { 300 return err 301 } 302 c.addr = addr 303 c.packetConn = conn 304 c.packetBuffer = packetbuffer.New(conn, maxBufferSize) 305 } 306 return nil 307 } 308 309 func (c *Connection) reading() { 310 var lastErr error 311 for { 312 select { 313 case <-c.done: 314 return 315 default: 316 } 317 vid, buf, err := c.parse() 318 if err != nil { 319 // If there is an error in tcp unpacking, it may cause problems with 320 // all subsequent parsing, so it is necessary to close the reconnection. 321 if c.isStream { 322 lastErr = err 323 report.MultiplexedTCPReconnectOnReadErr.Incr() 324 log.Tracef("reconnect on read err: %+v", err) 325 break 326 } 327 // udp is processed according to a single packet, receiving an illegal 328 // packet does not affect the subsequent packet processing logic, and can continue to receive packets. 329 log.Tracef("decode packet err: %s", err) 330 continue 331 } 332 333 c.mu.RLock() 334 vc, ok := c.virConns[vid] 335 c.mu.RUnlock() 336 if !ok { 337 continue 338 } 339 vc.recvQueue.Put(buf) 340 } 341 c.close(lastErr, true) 342 } 343 344 func (c *Connection) writing() { 345 var lastErr error 346 L: 347 for { 348 select { 349 case <-c.done: 350 return 351 case it := <-c.writeBuffer: 352 if err := c.writeAll(it); err != nil { 353 if c.isStream { // If tcp fails to write data, it will cause the peer to close the connection. 354 lastErr = err 355 report.MultiplexedTCPReconnectOnWriteErr.Incr() 356 log.Tracef("reconnect on write err: %+v", err) 357 break L 358 } 359 // udp failed to send packets, you can continue to send packets. 360 log.Tracef("multiplexed send UDP packet failed: %v", err) 361 continue 362 } 363 } 364 } 365 c.close(lastErr, true) 366 } 367 368 func (c *Connection) parse() (vid uint32, buf []byte, err error) { 369 if c.isStream { 370 return c.fp.Parse(c.getRawConn()) 371 } 372 defer func() { 373 closeErr := c.packetBuffer.Next() 374 if closeErr == nil { 375 return 376 } 377 if err == nil { 378 err = closeErr 379 return 380 } 381 err = fmt.Errorf("parse error %w, close packet error %s", err, closeErr) 382 }() 383 return c.fp.Parse(c.packetBuffer) 384 } 385 386 // Connection represents the underlying tcp connection. 387 type Connection struct { 388 err error 389 address string 390 network string 391 enableIdleRemove bool 392 destroy func() 393 connsSubIdle func() 394 connsAddIdle func() 395 connsNeedIdleRemove func() bool 396 397 // reconnectCount denotes the current reconnection times. 398 reconnectCount int 399 // lastReconnectTime denotes the time at which the last reconnect happens. 400 lastReconnectTime time.Time 401 402 // mu protects the concurrency safety of virtualConnections, isIdle, 403 // and also protects the connection closing process. 404 mu sync.RWMutex 405 virConns map[uint32]*VirtualConnection 406 isIdle bool 407 408 fp FrameParser 409 done chan struct{} // closed when underlying connection closed. 410 writeBuffer chan []byte 411 dropFull bool 412 maxVirConns int 413 414 // udp only 415 packetBuffer *packetbuffer.PacketBuffer 416 addr *net.UDPAddr 417 packetConn net.PacketConn // the underlying udp connection. 418 419 // tcp/unix stream only 420 conn net.Conn // the underlying tcp connection. 421 connLocker sync.RWMutex 422 dialOpts *connpool.DialOptions 423 isStream bool 424 closed bool 425 } 426 427 func (cs *Connections) initialize(opts *GetOptions) { 428 for i := 0; i < cs.opts.connectNumberPerHost; i++ { 429 cs.newConn(opts) 430 } 431 } 432 433 func (c *Connection) setRawConn(conn net.Conn) { 434 c.connLocker.Lock() 435 defer c.connLocker.Unlock() 436 c.conn = conn 437 } 438 439 func (c *Connection) getRawConn() net.Conn { 440 c.connLocker.RLock() 441 defer c.connLocker.RUnlock() 442 return c.conn 443 } 444 445 // Connections represents a collection of concrete connections. 446 type Connections struct { 447 nodeKey string 448 maxIdle int 449 opts *PoolOptions 450 destructor func() 451 452 // mu protects the concurrent safety of the following fields. 453 mu sync.Mutex 454 conns []*Connection 455 currentIdle int32 456 roundRobinIndex int 457 expelled bool 458 err error 459 } 460 461 func (cs *Connections) addIdle() { 462 if cs.maxIdle > 0 { 463 atomic.AddInt32(&cs.currentIdle, 1) 464 } 465 } 466 467 func (cs *Connections) subIdle() { 468 if cs.maxIdle > 0 { 469 atomic.AddInt32(&cs.currentIdle, -1) 470 } 471 } 472 473 func (cs *Connections) expel(c *Connection) { 474 cs.mu.Lock() 475 cs.subIdle() 476 cs.conns = filterOutConnection(cs.conns, c) 477 cs.err = multierror.Append(cs.err, c.err).ErrorOrNil() 478 if cs.expelled || len(cs.conns) > 0 { 479 cs.mu.Unlock() 480 return 481 } 482 cs.expelled = true 483 cs.mu.Unlock() 484 cs.destructor() 485 } 486 487 func (c *Connection) newVirConn(ctx context.Context, virConnID uint32) *VirtualConnection { 488 ctx, cancel := context.WithCancel(ctx) 489 vc := &VirtualConnection{ 490 id: virConnID, 491 conn: c, 492 ctx: ctx, 493 cancelFunc: cancel, 494 recvQueue: queue.New[[]byte](ctx.Done()), 495 } 496 c.mu.Lock() 497 defer c.mu.Unlock() 498 // If connection fails to establish or reconnect, close virtual connection directly. 499 if c.closed { 500 vc.cancel(c.err) 501 } 502 // Considering the overflow of request id or the repetition of upper-level request id, 503 // you need to first read and check the request id for whether it already exists, if it exists, 504 // you need to return error to the original virtual connection. 505 if prevConn, ok := c.virConns[virConnID]; ok { 506 prevConn.cancel(ErrDupRequestID) 507 } 508 c.virConns[virConnID] = vc 509 if c.isIdle { 510 c.isIdle = false 511 c.connsSubIdle() 512 } 513 return vc 514 } 515 516 func (c *Connection) send(b []byte) error { 517 // If dropfull is set, the queue is full, then discard. 518 if c.dropFull { 519 select { 520 case c.writeBuffer <- b: 521 return nil 522 default: 523 return ErrSendQueueFull 524 } 525 } 526 select { 527 case c.writeBuffer <- b: 528 return nil 529 case <-c.done: 530 return c.err 531 } 532 } 533 534 func (c *Connection) writeAll(b []byte) error { 535 if c.isStream { 536 return c.writeTCP(b) 537 } 538 return c.writeUDP(b) 539 } 540 541 func (c *Connection) writeUDP(b []byte) error { 542 num, err := c.packetConn.WriteTo(b, c.addr) 543 if err != nil { 544 return err 545 } 546 if num != len(b) { 547 return ErrWriteNotFinished 548 } 549 return nil 550 } 551 552 func (c *Connection) writeTCP(b []byte) error { 553 var sentNum, num int 554 var err error 555 conn := c.getRawConn() 556 for sentNum < len(b) { 557 num, err = conn.Write(b[sentNum:]) 558 if err != nil { 559 return err 560 } 561 sentNum += num 562 } 563 return nil 564 } 565 566 func (c *Connection) close(lastErr error, reconnect bool) { 567 if c.isStream { 568 c.closeTCP(lastErr, reconnect) 569 return 570 } 571 c.closeUDP(lastErr) 572 } 573 574 func (c *Connection) closeUDP(lastErr error) { 575 c.destroy() 576 c.err = lastErr 577 close(c.done) 578 579 c.mu.Lock() 580 defer c.mu.Unlock() 581 for _, vc := range c.virConns { 582 vc.cancel(lastErr) 583 } 584 } 585 586 func (c *Connection) closeTCP(lastErr error, reconnect bool) { 587 if lastErr == nil { 588 return 589 } 590 if needDestroy := c.doClose(lastErr, reconnect); needDestroy { 591 c.destroy() 592 } 593 } 594 595 func (c *Connection) doClose(lastErr error, reconnect bool) (needDestroy bool) { 596 c.mu.Lock() 597 defer c.mu.Unlock() 598 599 // Do not use c.err != nil to judge, reconnection will not clear err. 600 if c.closed { 601 return false 602 } 603 c.closed = true 604 c.err = lastErr 605 606 // when close the `c.done` channel, all Read operations will return error, 607 // so we should clean all existing connections, avoiding memory leak. 608 for _, vc := range c.virConns { 609 vc.cancel(lastErr) 610 } 611 c.virConns = make(map[uint32]*VirtualConnection) 612 close(c.done) 613 if conn := c.getRawConn(); conn != nil { 614 conn.Close() 615 } 616 if reconnect && c.doReconnectBackoff() { 617 return !c.reconnect() 618 } 619 return true 620 } 621 622 func tryConnect(opts *connpool.DialOptions) (net.Conn, error) { 623 conn, err := connpool.Dial(opts) 624 if err != nil { 625 return nil, err 626 } 627 if c, ok := conn.(*net.TCPConn); ok { 628 c.SetKeepAlive(true) 629 } 630 return conn, nil 631 } 632 633 func (c *Connection) reconnect() (success bool) { 634 for { 635 conn, err := tryConnect(c.dialOpts) 636 if err != nil { 637 report.MultiplexedTCPReconnectErr.Incr() 638 log.Tracef("reconnect fail: %+v", err) 639 if !c.doReconnectBackoff() { // If the current number of retries is greater than the maximum number 640 // of retries, doReconnectBackoff will return false, so remove the corresponding connection. 641 return false // A new request will trigger a reconnection. 642 } 643 continue 644 } 645 c.setRawConn(conn) 646 c.done = make(chan struct{}) 647 if !c.isIdle { 648 c.isIdle = true 649 c.connsAddIdle() 650 } 651 // Successfully reconnected, remove the closed flag and reset c.err. 652 c.err = nil 653 c.closed = false 654 go c.reading() 655 go c.writing() 656 return true 657 } 658 } 659 660 func (c *Connection) doReconnectBackoff() bool { 661 cur := time.Now() 662 if !c.lastReconnectTime.IsZero() && c.lastReconnectTime.Add(reconnectCountResetInterval).Before(cur) { 663 // Clear reconnect count if reset interval is reached. 664 c.reconnectCount = 0 665 } 666 c.reconnectCount++ 667 c.lastReconnectTime = cur 668 if c.reconnectCount > maxReconnectCount { 669 log.Tracef("reconnection reaches its limit: %d", maxReconnectCount) 670 return false 671 } 672 currentBackoff := time.Duration(c.reconnectCount) * initialBackoff 673 if currentBackoff > maxBackoff { 674 currentBackoff = maxBackoff 675 } 676 time.Sleep(currentBackoff) 677 return true 678 } 679 680 func (c *Connection) remove(virConnID uint32) { 681 if needDestroy := c.doRemove(virConnID); needDestroy { 682 c.destroy() 683 } 684 } 685 686 func (c *Connection) doRemove(virConnID uint32) (needDestroy bool) { 687 c.mu.Lock() 688 defer c.mu.Unlock() 689 delete(c.virConns, virConnID) 690 if c.enableIdleRemove { 691 return c.idleRemove() 692 } 693 return false 694 } 695 696 func (c *Connection) idleRemove() (needDestroy bool) { 697 // Determine if the current connection is free. 698 if len(c.virConns) != 0 { 699 return false 700 } 701 // Check if the connection has been closed. 702 if c.closed { 703 return false 704 } 705 if !c.isIdle { 706 c.isIdle = true 707 c.connsAddIdle() 708 } 709 // Determine whether the current Node idle connection exceeds the maximum value. 710 if !c.connsNeedIdleRemove() { 711 return false 712 } 713 // Close the current connection. 714 c.closed = true 715 close(c.done) 716 if conn := c.getRawConn(); conn != nil { 717 conn.Close() 718 } 719 // Remove the current connection from the connection set. 720 return true 721 } 722 723 var _ MuxConn = (*VirtualConnection)(nil) 724 725 // MuxConn is virtual connection multiplexing on a real connection. 726 type MuxConn interface { 727 // Write writes data to the connection. 728 Write([]byte) error 729 730 // Read reads a packet from connection. 731 Read() ([]byte, error) 732 733 // LocalAddr returns the local network address, if known. 734 LocalAddr() net.Addr 735 736 // RemoteAddr returns the remote network address, if known. 737 RemoteAddr() net.Addr 738 739 // Close closes the connection. 740 // Any blocked Read or Write operations will be unblocked and return errors. 741 Close() 742 } 743 744 // VirtualConnection multiplexes virtual connections. 745 type VirtualConnection struct { 746 id uint32 747 conn *Connection 748 recvQueue *queue.Queue[[]byte] 749 750 ctx context.Context 751 cancelFunc context.CancelFunc 752 closed uint32 753 754 err error 755 mu sync.RWMutex 756 } 757 758 // RemoteAddr gets the peer address of the connection. 759 func (vc *VirtualConnection) RemoteAddr() net.Addr { 760 if !vc.conn.isStream { 761 return vc.conn.addr 762 } 763 if vc.conn == nil { 764 return nil 765 } 766 conn := vc.conn.getRawConn() 767 if conn == nil { 768 return nil 769 } 770 return conn.RemoteAddr() 771 } 772 773 // LocalAddr gets the local address of the connection. 774 func (vc *VirtualConnection) LocalAddr() net.Addr { 775 if vc.conn == nil { 776 return nil 777 } 778 conn := vc.conn.getRawConn() 779 if conn == nil { 780 return nil 781 } 782 return conn.LocalAddr() 783 } 784 785 // Write writes request packet. 786 // Write and Read can be concurrent, multiple Write can be concurrent. 787 func (vc *VirtualConnection) Write(b []byte) error { 788 if err := vc.loadErr(); err != nil { 789 return err 790 } 791 select { 792 case <-vc.ctx.Done(): 793 // clean the virtual connection when context timeout or cancelled. 794 vc.Close() 795 return vc.ctx.Err() 796 default: 797 } 798 if err := vc.conn.send(b); err != nil { 799 // clean the virtual connection when send fail. 800 vc.Close() 801 return err 802 } 803 return nil 804 } 805 806 // Read reads back the packet. 807 // Write and Read can be concurrent, but not concurrent Read. 808 func (vc *VirtualConnection) Read() ([]byte, error) { 809 if err := vc.loadErr(); err != nil { 810 return nil, err 811 } 812 rsp, ok := vc.recvQueue.Get() 813 if !ok { 814 vc.Close() 815 if err := vc.loadErr(); err != nil { 816 return nil, err 817 } 818 return nil, vc.ctx.Err() 819 } 820 return rsp, nil 821 } 822 823 // Close closes the connection. 824 func (vc *VirtualConnection) Close() { 825 if atomic.CompareAndSwapUint32(&vc.closed, 0, 1) { 826 vc.conn.remove(vc.id) 827 } 828 } 829 830 func (vc *VirtualConnection) loadErr() error { 831 vc.mu.RLock() 832 defer vc.mu.RUnlock() 833 return vc.err 834 } 835 836 func (vc *VirtualConnection) storeErr(err error) { 837 if vc.loadErr() != nil { 838 return 839 } 840 vc.mu.Lock() 841 defer vc.mu.Unlock() 842 vc.err = err 843 } 844 845 func (vc *VirtualConnection) cancel(err error) { 846 vc.storeErr(err) 847 vc.cancelFunc() 848 } 849 850 func makeNodeKey(network, address string) string { 851 var key strings.Builder 852 key.Grow(len(network) + len(address) + 1) 853 key.WriteString(network) 854 key.WriteString("_") 855 key.WriteString(address) 856 return key.String() 857 } 858 859 func isStream(network string) (bool, error) { 860 switch network { 861 case "tcp", "tcp4", "tcp6", "unix": 862 return true, nil 863 case "udp", "udp4", "udp6": 864 return false, nil 865 default: 866 return false, ErrNetworkNotSupport 867 } 868 } 869 870 func filterOutConnection(in []*Connection, exclude *Connection) []*Connection { 871 out := in[:0] 872 for _, v := range in { 873 if v != exclude { 874 out = append(out, v) 875 } 876 } 877 // If a connection is successfully removed, empty the last value of the slice to avoid memory leaks. 878 for i := len(out); i < len(in); i++ { 879 in[i] = nil 880 } 881 return out 882 }