trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/multiplex/multiplex.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 //go:build linux || freebsd || dragonfly || darwin 15 // +build linux freebsd dragonfly darwin 16 17 // Package multiplex implements a connection pool that supports connection multiplexing. 18 package multiplex 19 20 import ( 21 "context" 22 "errors" 23 "fmt" 24 "net" 25 "strings" 26 "sync" 27 "time" 28 29 "go.uber.org/atomic" 30 "golang.org/x/sync/singleflight" 31 "trpc.group/trpc-go/tnet" 32 33 "trpc.group/trpc-go/trpc-go/internal/queue" 34 "trpc.group/trpc-go/trpc-go/log" 35 "trpc.group/trpc-go/trpc-go/metrics" 36 "trpc.group/trpc-go/trpc-go/pool/connpool" 37 "trpc.group/trpc-go/trpc-go/pool/multiplexed" 38 ) 39 40 /* 41 Pool, host, connection all have lock. 42 The process of acquiring a lock during connection creation: 43 host.mu.Lock ----> connection.mu.Lock ----> connection.mu.Unlock ----> host.mu.Unlock 44 The process of acquiring a lock during connection closure: 45 host.mu.Lock ----> host.mu.Unlock ----> connection.mu.Lock ----> connection.mu.Unlock 46 */ 47 48 const ( 49 defaultDialTimeout = 200 * time.Millisecond 50 ) 51 52 var ( 53 // ErrConnClosed indicates connection is closed. 54 ErrConnClosed = errors.New("connection is closed") 55 // ErrDuplicateID indicates request ID already exist. 56 ErrDuplicateID = errors.New("request ID already exist") 57 // ErrInvalid indicates the operation is invalid. 58 ErrInvalid = errors.New("it's invalid") 59 60 errTooManyVirConns = errors.New("the number of virtual connections exceeds the limit") 61 ) 62 63 // PoolOption represents some settings for the multiplex pool. 64 type PoolOption struct { 65 dialTimeout time.Duration 66 maxConcurrentVirConnsPerConn int 67 enableMetrics bool 68 } 69 70 // OptPool is function to modify PoolOption. 71 type OptPool func(*PoolOption) 72 73 // WithDialTimeout returns an OptPool which sets dial timeout. 74 func WithDialTimeout(timeout time.Duration) OptPool { 75 return func(o *PoolOption) { 76 o.dialTimeout = timeout 77 } 78 } 79 80 // WithMaxConcurrentVirConnsPerConn returns an OptPool which sets the number 81 // of concurrent virtual connections per connection. 82 func WithMaxConcurrentVirConnsPerConn(max int) OptPool { 83 return func(o *PoolOption) { 84 o.maxConcurrentVirConnsPerConn = max 85 } 86 } 87 88 // WithEnableMetrics returns an OptPool which enable metrics. 89 func WithEnableMetrics() OptPool { 90 return func(o *PoolOption) { 91 o.enableMetrics = true 92 } 93 } 94 95 // NewPool creates a new multiplex pool, which uses dialFunc to dial new connections. 96 func NewPool(dialFunc connpool.DialFunc, opt ...OptPool) multiplexed.Pool { 97 opts := &PoolOption{ 98 dialTimeout: defaultDialTimeout, 99 } 100 for _, o := range opt { 101 o(opts) 102 } 103 m := &pool{ 104 dialFunc: dialFunc, 105 dialTimeout: opts.dialTimeout, 106 maxConcurrentVirConnsPerConn: opts.maxConcurrentVirConnsPerConn, 107 hosts: make(map[string]*host), 108 } 109 if opts.enableMetrics { 110 go m.metrics() 111 } 112 return m 113 } 114 115 var _ multiplexed.Pool = (*pool)(nil) 116 117 type pool struct { 118 dialFunc connpool.DialFunc 119 dialTimeout time.Duration 120 maxConcurrentVirConnsPerConn int 121 hosts map[string]*host // key is network+address 122 mu sync.RWMutex 123 } 124 125 // GetMuxConn gets a multiplexing connection to the address on named network. 126 // Multiple MuxConns can multiplex on a real connection. 127 func (p *pool) GetMuxConn( 128 ctx context.Context, 129 network string, 130 address string, 131 opts multiplexed.GetOptions, 132 ) (multiplexed.MuxConn, error) { 133 if opts.FP == nil { 134 return nil, errors.New("frame parser is not provided") 135 } 136 host := p.getHost(network, address, opts) 137 138 // Rlock here to make sure that host has not been closed. If host is closed, rLock 139 // will return false. And it also avoids reading host.conns while it is being modified. 140 if !host.mu.rLock() { 141 return nil, ErrConnClosed 142 } 143 virConn, err := newVirConn(ctx, host.conns, opts.VID, isClosedOrFull) 144 if virConn != nil || err != nil { 145 host.mu.rUnlock() 146 return virConn, err 147 } 148 host.mu.rUnlock() 149 150 for { 151 // Lock here to ensure that the connection being created is not missed when reading host.conns, 152 // because singleflightDial will lock host.mu before adding the new connection to host.conns asynchronously. 153 if !host.mu.lock() { 154 return nil, ErrConnClosed 155 } 156 virConn, err = newVirConn(ctx, host.conns, opts.VID, isClosedOrFull) 157 if virConn != nil || err != nil { 158 host.mu.unlock() 159 return virConn, err 160 } 161 // if all connections are closed or can't take more virtual connection, create one. 162 dialing := host.singleflightDial() 163 host.mu.unlock() 164 165 conn, err := waitDialing(ctx, dialing) 166 if err != nil { 167 return nil, err 168 } 169 // create new connection when the number of virtual connections exceeds the limit. 170 virConn, err = newVirConn(ctx, []*connection{conn}, opts.VID, isFull) 171 if virConn != nil || err != nil { 172 return virConn, err 173 } 174 } 175 } 176 177 func (p *pool) getHost(network string, address string, opts multiplexed.GetOptions) *host { 178 hostName := strings.Join([]string{network, address}, "_") 179 p.mu.RLock() 180 if h, ok := p.hosts[hostName]; ok { 181 p.mu.RUnlock() 182 return h 183 } 184 p.mu.RUnlock() 185 186 p.mu.Lock() 187 defer p.mu.Unlock() 188 if h, ok := p.hosts[hostName]; ok { 189 return h 190 } 191 h := &host{ 192 network: network, 193 address: address, 194 hostName: hostName, 195 dialOpts: dialOption{ 196 fp: opts.FP, 197 localAddr: opts.LocalAddr, 198 caCertFile: opts.CACertFile, 199 tlsCertFile: opts.TLSCertFile, 200 tlsKeyFile: opts.TLSKeyFile, 201 tlsServerName: opts.TLSServerName, 202 dialTimeout: p.dialTimeout, 203 }, 204 dialFunc: p.dialFunc, 205 maxConcurrentVirConnsPerConn: p.maxConcurrentVirConnsPerConn, 206 } 207 h.deleteHostFromPool = func() { 208 p.deleteHost(h) 209 } 210 p.hosts[hostName] = h 211 return h 212 } 213 214 func (p *pool) deleteHost(h *host) { 215 p.mu.Lock() 216 defer p.mu.Unlock() 217 delete(p.hosts, h.hostName) 218 } 219 220 func (p *pool) metrics() { 221 for { 222 p.mu.RLock() 223 hostCopied := make([]*host, 0, len(p.hosts)) 224 for _, host := range p.hosts { 225 hostCopied = append(hostCopied, host) 226 } 227 p.mu.RUnlock() 228 for _, host := range hostCopied { 229 host.metrics() 230 } 231 time.Sleep(3 * time.Second) 232 } 233 } 234 235 type dialOption struct { 236 fp multiplexed.FrameParser 237 localAddr string 238 dialTimeout time.Duration 239 caCertFile string 240 tlsCertFile string 241 tlsKeyFile string 242 tlsServerName string 243 } 244 245 // host manages all connections to the same network and address. 246 type host struct { 247 network string 248 address string 249 hostName string 250 dialOpts dialOption 251 dialFunc connpool.DialFunc 252 sfg singleflight.Group 253 deleteHostFromPool func() 254 mu stateRWMutex 255 conns []*connection 256 maxConcurrentVirConnsPerConn int 257 } 258 259 func (h *host) singleflightDial() <-chan singleflight.Result { 260 ch := h.sfg.DoChan(h.hostName, func() (connection interface{}, err error) { 261 rawConn, err := h.dialFunc(&connpool.DialOptions{ 262 Network: h.network, 263 Address: h.address, 264 Timeout: h.dialOpts.dialTimeout, 265 LocalAddr: h.dialOpts.localAddr, 266 CACertFile: h.dialOpts.caCertFile, 267 TLSCertFile: h.dialOpts.tlsCertFile, 268 TLSKeyFile: h.dialOpts.tlsKeyFile, 269 TLSServerName: h.dialOpts.tlsServerName, 270 }) 271 if err != nil { 272 return nil, err 273 } 274 defer func() { 275 if err != nil { 276 rawConn.Close() 277 } 278 }() 279 conn, err := h.wrapRawConn(rawConn, h.dialOpts.fp) 280 if err != nil { 281 return nil, err 282 } 283 // storeConn will call h.mu.Lock 284 if err := h.storeConn(conn); err != nil { 285 return nil, fmt.Errorf("store connection failed, %w", err) 286 } 287 return conn, nil 288 }) 289 return ch 290 } 291 292 func waitDialing(ctx context.Context, dialing <-chan singleflight.Result) (*connection, error) { 293 select { 294 case result := <-dialing: 295 return expandSFResult(result) 296 case <-ctx.Done(): 297 return nil, ctx.Err() 298 } 299 } 300 301 func (h *host) wrapRawConn(rawConn net.Conn, fp multiplexed.FrameParser) (*connection, error) { 302 // TODO: support tls 303 tc, ok := rawConn.(tnet.Conn) 304 if !ok { 305 return nil, errors.New("dialed connection must implements tnet.Conn") 306 } 307 308 c := &connection{ 309 rawConn: tc, 310 fp: fp, 311 idToVirConn: newShardMap(defaultShardSize), 312 maxConcurrentVirConns: h.maxConcurrentVirConnsPerConn, 313 } 314 c.deleteConnFromHost = func() { 315 if isLastConn := h.deleteConn(c); isLastConn { 316 h.deleteHostFromPool() 317 } 318 } 319 // TODO: support closing idle connections 320 c.rawConn.SetOnRequest(c.onRequest) 321 c.rawConn.SetOnClosed(func(tnet.Conn) error { 322 c.close(ErrConnClosed) 323 return nil 324 }) 325 return c, nil 326 } 327 328 func (h *host) loadAllConns() ([]*connection, error) { 329 if !h.mu.rLock() { 330 return nil, ErrConnClosed 331 } 332 defer h.mu.rUnlock() 333 conns := make([]*connection, len(h.conns)) 334 copy(conns, h.conns) 335 return conns, nil 336 } 337 338 func (h *host) storeConn(conn *connection) error { 339 if !h.mu.lock() { 340 return ErrConnClosed 341 } 342 defer h.mu.unlock() 343 h.conns = append(h.conns, conn) 344 return nil 345 } 346 347 func (h *host) deleteConn(conn *connection) (isLastConn bool) { 348 if !h.mu.lock() { 349 return false 350 } 351 defer h.mu.unlock() 352 h.conns = filterOutConn(h.conns, conn) 353 // close host if the last conn is deleted 354 if len(h.conns) == 0 { 355 h.mu.closeLocked() 356 return true 357 } 358 return false 359 } 360 361 func (h *host) metrics() { 362 conns, err := h.loadAllConns() 363 if err != nil { 364 return 365 } 366 var virConnNum uint32 367 for _, conn := range conns { 368 virConnNum += conn.idToVirConn.length() 369 } 370 metrics.Gauge(strings.Join([]string{"trpc.MuxConcurrentConnections", h.network, h.address}, ".")). 371 Set(float64(len(conns))) 372 metrics.Gauge(strings.Join([]string{"trpc.MuxConcurrentVirConns", h.network, h.address}, ".")). 373 Set(float64(virConnNum)) 374 log.Debugf("tnet multiplex status: network: %s, address: %s, connections number: %d,"+ 375 "concurrent virtual connection number: %d\n", h.network, h.address, len(conns), virConnNum) 376 } 377 378 func expandSFResult(result singleflight.Result) (*connection, error) { 379 if result.Err != nil { 380 return nil, result.Err 381 } 382 return result.Val.(*connection), nil 383 } 384 385 // connection wraps the underlying tnet.Conn, and manages many virtualConnections. 386 type connection struct { 387 rawConn tnet.Conn 388 deleteConnFromHost func() 389 fp multiplexed.FrameParser 390 isClosed atomic.Bool 391 mu stateRWMutex 392 idToVirConn *shardMap 393 maxConcurrentVirConns int 394 } 395 396 func (c *connection) onRequest(conn tnet.Conn) error { 397 vid, buf, err := c.fp.Parse(conn) 398 if err != nil { 399 c.close(err) 400 return err 401 } 402 vc, ok := c.idToVirConn.load(vid) 403 // If the virConn corresponding to the id cannot be found, 404 // the virConn has been closed and the current response is discarded. 405 if !ok { 406 return nil 407 } 408 vc.recvQueue.Put(buf) 409 return nil 410 } 411 412 func (c *connection) canTakeNewVirConn() bool { 413 return c.maxConcurrentVirConns == 0 || c.idToVirConn.length() < uint32(c.maxConcurrentVirConns) 414 } 415 416 func (c *connection) close(cause error) { 417 if !c.isClosed.CAS(false, true) { 418 return 419 } 420 c.deleteConnFromHost() 421 c.deleteAllVirConn(cause) 422 c.rawConn.Close() 423 } 424 425 func (c *connection) deleteAllVirConn(cause error) { 426 if !c.mu.lock() { 427 return 428 } 429 defer c.mu.unlock() 430 c.mu.closeLocked() 431 for _, vc := range c.idToVirConn.loadAll() { 432 vc.notifyRead(cause) 433 } 434 c.idToVirConn.reset() 435 } 436 437 func (c *connection) newVirConn(ctx context.Context, vid uint32) (*virtualConnection, error) { 438 if !c.mu.rLock() { 439 return nil, ErrConnClosed 440 } 441 defer c.mu.rUnlock() 442 if !c.rawConn.IsActive() { 443 return nil, ErrConnClosed 444 } 445 // CanTakeNewVirConn and loadOrStore are not atomic, which may cause 446 // the actual concurrent virConn numbers to exceed the limit max value. 447 // Implementing atomic functions requires higher lock granularity, 448 // which affects performance. 449 if !c.canTakeNewVirConn() { 450 return nil, errTooManyVirConns 451 } 452 ctx, cancel := context.WithCancel(ctx) 453 vc := &virtualConnection{ 454 ctx: ctx, 455 id: vid, 456 cancelFunc: cancel, 457 recvQueue: queue.New[[]byte](ctx.Done()), 458 write: c.rawConn.Write, 459 localAddr: c.rawConn.LocalAddr(), 460 remoteAddr: c.rawConn.RemoteAddr(), 461 deleteVirConnFromConn: func() { 462 c.deleteVirConn(vid) 463 }, 464 } 465 _, loaded := c.idToVirConn.loadOrStore(vc.id, vc) 466 if loaded { 467 cancel() 468 return nil, ErrDuplicateID 469 } 470 return vc, nil 471 } 472 473 func (c *connection) deleteVirConn(id uint32) { 474 c.idToVirConn.delete(id) 475 } 476 477 var ( 478 _ multiplexed.MuxConn = (*virtualConnection)(nil) 479 ) 480 481 type virtualConnection struct { 482 write func(b []byte) (int, error) 483 deleteVirConnFromConn func() 484 recvQueue *queue.Queue[[]byte] 485 err atomic.Error 486 ctx context.Context 487 cancelFunc context.CancelFunc 488 id uint32 489 isClosed atomic.Bool 490 localAddr net.Addr 491 remoteAddr net.Addr 492 } 493 494 // Write writes data to the connection. 495 // Write and ReadFrame can be concurrent, multiple Write can be concurrent. 496 func (vc *virtualConnection) Write(b []byte) error { 497 if vc.isClosed.Load() { 498 return vc.wrapError(ErrConnClosed) 499 } 500 _, err := vc.write(b) 501 return err 502 } 503 504 // Read reads a packet from connection. 505 // Write and Read can be concurrent, multiple Read can't be concurrent. 506 func (vc *virtualConnection) Read() ([]byte, error) { 507 if vc.isClosed.Load() { 508 return nil, vc.wrapError(ErrConnClosed) 509 } 510 rsp, ok := vc.recvQueue.Get() 511 if !ok { 512 return nil, vc.wrapError(errors.New("received data failed")) 513 } 514 return rsp, nil 515 } 516 517 // Close closes the connection. 518 // Any blocked Read or Write operations will be unblocked and return errors. 519 func (vc *virtualConnection) Close() { 520 vc.close(nil) 521 } 522 523 // LocalAddr returns the local network address, if known. 524 func (vc *virtualConnection) LocalAddr() net.Addr { 525 return vc.localAddr 526 } 527 528 // RemoteAddr returns the remote network address, if known. 529 func (vc *virtualConnection) RemoteAddr() net.Addr { 530 return vc.remoteAddr 531 } 532 533 func (vc *virtualConnection) notifyRead(cause error) { 534 if !vc.isClosed.CAS(false, true) { 535 return 536 } 537 vc.err.Store(cause) 538 vc.cancelFunc() 539 } 540 541 func (vc *virtualConnection) close(cause error) { 542 vc.notifyRead(cause) 543 vc.deleteVirConnFromConn() 544 } 545 546 func (vc *virtualConnection) wrapError(err error) error { 547 if loaded := vc.err.Load(); loaded != nil { 548 return fmt.Errorf("%w, %s", err, loaded.Error()) 549 } 550 if ctxErr := vc.ctx.Err(); ctxErr != nil { 551 return fmt.Errorf("%w, %s", err, ctxErr.Error()) 552 } 553 return err 554 } 555 556 func filterOutConn(in []*connection, exclude *connection) []*connection { 557 out := in[:0] 558 for _, v := range in { 559 if v != exclude { 560 out = append(out, v) 561 } 562 } 563 // If a connection is successfully removed, empty the last value of the slice to avoid memory leaks. 564 for i := len(out); i < len(in); i++ { 565 in[i] = nil 566 } 567 return out 568 } 569 570 func newVirConn( 571 ctx context.Context, 572 conns []*connection, 573 vid uint32, 574 isTolerable func(error) bool, 575 ) (*virtualConnection, error) { 576 for _, conn := range conns { 577 virConn, err := conn.newVirConn(ctx, vid) 578 if isTolerable(err) { 579 continue 580 } 581 return virConn, err 582 } 583 return nil, nil 584 } 585 586 func isClosedOrFull(err error) bool { 587 if err == ErrConnClosed || err == errTooManyVirConns { 588 return true 589 } 590 return false 591 } 592 593 func isFull(err error) bool { 594 return err == errTooManyVirConns 595 }