trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/multiplex/multiplex.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || freebsd || dragonfly || darwin
    15  // +build linux freebsd dragonfly darwin
    16  
    17  // Package multiplex implements a connection pool that supports connection multiplexing.
    18  package multiplex
    19  
    20  import (
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"strings"
    26  	"sync"
    27  	"time"
    28  
    29  	"go.uber.org/atomic"
    30  	"golang.org/x/sync/singleflight"
    31  	"trpc.group/trpc-go/tnet"
    32  
    33  	"trpc.group/trpc-go/trpc-go/internal/queue"
    34  	"trpc.group/trpc-go/trpc-go/log"
    35  	"trpc.group/trpc-go/trpc-go/metrics"
    36  	"trpc.group/trpc-go/trpc-go/pool/connpool"
    37  	"trpc.group/trpc-go/trpc-go/pool/multiplexed"
    38  )
    39  
/*
	The pool, each host, and each connection have their own lock.
	Lock order while a virtual connection is being created:
		host.mu.Lock ----> connection.mu.Lock ----> connection.mu.Unlock ----> host.mu.Unlock
	Lock order while a connection is being closed:
		host.mu.Lock ----> host.mu.Unlock ----> connection.mu.Lock ----> connection.mu.Unlock
	The host lock is always acquired first and is never taken while a connection
	lock is held, so the two paths cannot deadlock each other.
*/
    47  
    48  const (
    49  	defaultDialTimeout = 200 * time.Millisecond
    50  )
    51  
    52  var (
    53  	// ErrConnClosed indicates connection is closed.
    54  	ErrConnClosed = errors.New("connection is closed")
    55  	// ErrDuplicateID indicates request ID already exist.
    56  	ErrDuplicateID = errors.New("request ID already exist")
    57  	// ErrInvalid indicates the operation is invalid.
    58  	ErrInvalid = errors.New("it's invalid")
    59  
    60  	errTooManyVirConns = errors.New("the number of virtual connections exceeds the limit")
    61  )
    62  
    63  // PoolOption represents some settings for the multiplex pool.
    64  type PoolOption struct {
    65  	dialTimeout                  time.Duration
    66  	maxConcurrentVirConnsPerConn int
    67  	enableMetrics                bool
    68  }
    69  
    70  // OptPool is function to modify PoolOption.
    71  type OptPool func(*PoolOption)
    72  
    73  // WithDialTimeout returns an OptPool which sets dial timeout.
    74  func WithDialTimeout(timeout time.Duration) OptPool {
    75  	return func(o *PoolOption) {
    76  		o.dialTimeout = timeout
    77  	}
    78  }
    79  
    80  // WithMaxConcurrentVirConnsPerConn returns an OptPool which sets the number
    81  // of concurrent virtual connections per connection.
    82  func WithMaxConcurrentVirConnsPerConn(max int) OptPool {
    83  	return func(o *PoolOption) {
    84  		o.maxConcurrentVirConnsPerConn = max
    85  	}
    86  }
    87  
    88  // WithEnableMetrics returns an OptPool which enable metrics.
    89  func WithEnableMetrics() OptPool {
    90  	return func(o *PoolOption) {
    91  		o.enableMetrics = true
    92  	}
    93  }
    94  
    95  // NewPool creates a new multiplex pool, which uses dialFunc to dial new connections.
    96  func NewPool(dialFunc connpool.DialFunc, opt ...OptPool) multiplexed.Pool {
    97  	opts := &PoolOption{
    98  		dialTimeout: defaultDialTimeout,
    99  	}
   100  	for _, o := range opt {
   101  		o(opts)
   102  	}
   103  	m := &pool{
   104  		dialFunc:                     dialFunc,
   105  		dialTimeout:                  opts.dialTimeout,
   106  		maxConcurrentVirConnsPerConn: opts.maxConcurrentVirConnsPerConn,
   107  		hosts:                        make(map[string]*host),
   108  	}
   109  	if opts.enableMetrics {
   110  		go m.metrics()
   111  	}
   112  	return m
   113  }
   114  
   115  var _ multiplexed.Pool = (*pool)(nil)
   116  
   117  type pool struct {
   118  	dialFunc                     connpool.DialFunc
   119  	dialTimeout                  time.Duration
   120  	maxConcurrentVirConnsPerConn int
   121  	hosts                        map[string]*host // key is network+address
   122  	mu                           sync.RWMutex
   123  }
   124  
   125  // GetMuxConn gets a multiplexing connection to the address on named network.
   126  // Multiple MuxConns can multiplex on a real connection.
   127  func (p *pool) GetMuxConn(
   128  	ctx context.Context,
   129  	network string,
   130  	address string,
   131  	opts multiplexed.GetOptions,
   132  ) (multiplexed.MuxConn, error) {
   133  	if opts.FP == nil {
   134  		return nil, errors.New("frame parser is not provided")
   135  	}
   136  	host := p.getHost(network, address, opts)
   137  
   138  	// Rlock here to make sure that host has not been closed. If host is closed, rLock
   139  	// will return false. And it also avoids reading host.conns while it is being modified.
   140  	if !host.mu.rLock() {
   141  		return nil, ErrConnClosed
   142  	}
   143  	virConn, err := newVirConn(ctx, host.conns, opts.VID, isClosedOrFull)
   144  	if virConn != nil || err != nil {
   145  		host.mu.rUnlock()
   146  		return virConn, err
   147  	}
   148  	host.mu.rUnlock()
   149  
   150  	for {
   151  		// Lock here to ensure that the connection being created is not missed when reading host.conns,
   152  		// because singleflightDial will lock host.mu before adding the new connection to host.conns asynchronously.
   153  		if !host.mu.lock() {
   154  			return nil, ErrConnClosed
   155  		}
   156  		virConn, err = newVirConn(ctx, host.conns, opts.VID, isClosedOrFull)
   157  		if virConn != nil || err != nil {
   158  			host.mu.unlock()
   159  			return virConn, err
   160  		}
   161  		// if all connections are closed or can't take more virtual connection, create one.
   162  		dialing := host.singleflightDial()
   163  		host.mu.unlock()
   164  
   165  		conn, err := waitDialing(ctx, dialing)
   166  		if err != nil {
   167  			return nil, err
   168  		}
   169  		// create new connection when the number of virtual connections exceeds the limit.
   170  		virConn, err = newVirConn(ctx, []*connection{conn}, opts.VID, isFull)
   171  		if virConn != nil || err != nil {
   172  			return virConn, err
   173  		}
   174  	}
   175  }
   176  
   177  func (p *pool) getHost(network string, address string, opts multiplexed.GetOptions) *host {
   178  	hostName := strings.Join([]string{network, address}, "_")
   179  	p.mu.RLock()
   180  	if h, ok := p.hosts[hostName]; ok {
   181  		p.mu.RUnlock()
   182  		return h
   183  	}
   184  	p.mu.RUnlock()
   185  
   186  	p.mu.Lock()
   187  	defer p.mu.Unlock()
   188  	if h, ok := p.hosts[hostName]; ok {
   189  		return h
   190  	}
   191  	h := &host{
   192  		network:  network,
   193  		address:  address,
   194  		hostName: hostName,
   195  		dialOpts: dialOption{
   196  			fp:            opts.FP,
   197  			localAddr:     opts.LocalAddr,
   198  			caCertFile:    opts.CACertFile,
   199  			tlsCertFile:   opts.TLSCertFile,
   200  			tlsKeyFile:    opts.TLSKeyFile,
   201  			tlsServerName: opts.TLSServerName,
   202  			dialTimeout:   p.dialTimeout,
   203  		},
   204  		dialFunc:                     p.dialFunc,
   205  		maxConcurrentVirConnsPerConn: p.maxConcurrentVirConnsPerConn,
   206  	}
   207  	h.deleteHostFromPool = func() {
   208  		p.deleteHost(h)
   209  	}
   210  	p.hosts[hostName] = h
   211  	return h
   212  }
   213  
   214  func (p *pool) deleteHost(h *host) {
   215  	p.mu.Lock()
   216  	defer p.mu.Unlock()
   217  	delete(p.hosts, h.hostName)
   218  }
   219  
   220  func (p *pool) metrics() {
   221  	for {
   222  		p.mu.RLock()
   223  		hostCopied := make([]*host, 0, len(p.hosts))
   224  		for _, host := range p.hosts {
   225  			hostCopied = append(hostCopied, host)
   226  		}
   227  		p.mu.RUnlock()
   228  		for _, host := range hostCopied {
   229  			host.metrics()
   230  		}
   231  		time.Sleep(3 * time.Second)
   232  	}
   233  }
   234  
   235  type dialOption struct {
   236  	fp            multiplexed.FrameParser
   237  	localAddr     string
   238  	dialTimeout   time.Duration
   239  	caCertFile    string
   240  	tlsCertFile   string
   241  	tlsKeyFile    string
   242  	tlsServerName string
   243  }
   244  
   245  // host manages all connections to the same network and address.
   246  type host struct {
   247  	network                      string
   248  	address                      string
   249  	hostName                     string
   250  	dialOpts                     dialOption
   251  	dialFunc                     connpool.DialFunc
   252  	sfg                          singleflight.Group
   253  	deleteHostFromPool           func()
   254  	mu                           stateRWMutex
   255  	conns                        []*connection
   256  	maxConcurrentVirConnsPerConn int
   257  }
   258  
   259  func (h *host) singleflightDial() <-chan singleflight.Result {
   260  	ch := h.sfg.DoChan(h.hostName, func() (connection interface{}, err error) {
   261  		rawConn, err := h.dialFunc(&connpool.DialOptions{
   262  			Network:       h.network,
   263  			Address:       h.address,
   264  			Timeout:       h.dialOpts.dialTimeout,
   265  			LocalAddr:     h.dialOpts.localAddr,
   266  			CACertFile:    h.dialOpts.caCertFile,
   267  			TLSCertFile:   h.dialOpts.tlsCertFile,
   268  			TLSKeyFile:    h.dialOpts.tlsKeyFile,
   269  			TLSServerName: h.dialOpts.tlsServerName,
   270  		})
   271  		if err != nil {
   272  			return nil, err
   273  		}
   274  		defer func() {
   275  			if err != nil {
   276  				rawConn.Close()
   277  			}
   278  		}()
   279  		conn, err := h.wrapRawConn(rawConn, h.dialOpts.fp)
   280  		if err != nil {
   281  			return nil, err
   282  		}
   283  		// storeConn will call h.mu.Lock
   284  		if err := h.storeConn(conn); err != nil {
   285  			return nil, fmt.Errorf("store connection failed, %w", err)
   286  		}
   287  		return conn, nil
   288  	})
   289  	return ch
   290  }
   291  
   292  func waitDialing(ctx context.Context, dialing <-chan singleflight.Result) (*connection, error) {
   293  	select {
   294  	case result := <-dialing:
   295  		return expandSFResult(result)
   296  	case <-ctx.Done():
   297  		return nil, ctx.Err()
   298  	}
   299  }
   300  
   301  func (h *host) wrapRawConn(rawConn net.Conn, fp multiplexed.FrameParser) (*connection, error) {
   302  	// TODO: support tls
   303  	tc, ok := rawConn.(tnet.Conn)
   304  	if !ok {
   305  		return nil, errors.New("dialed connection must implements tnet.Conn")
   306  	}
   307  
   308  	c := &connection{
   309  		rawConn:               tc,
   310  		fp:                    fp,
   311  		idToVirConn:           newShardMap(defaultShardSize),
   312  		maxConcurrentVirConns: h.maxConcurrentVirConnsPerConn,
   313  	}
   314  	c.deleteConnFromHost = func() {
   315  		if isLastConn := h.deleteConn(c); isLastConn {
   316  			h.deleteHostFromPool()
   317  		}
   318  	}
   319  	// TODO: support closing idle connections
   320  	c.rawConn.SetOnRequest(c.onRequest)
   321  	c.rawConn.SetOnClosed(func(tnet.Conn) error {
   322  		c.close(ErrConnClosed)
   323  		return nil
   324  	})
   325  	return c, nil
   326  }
   327  
   328  func (h *host) loadAllConns() ([]*connection, error) {
   329  	if !h.mu.rLock() {
   330  		return nil, ErrConnClosed
   331  	}
   332  	defer h.mu.rUnlock()
   333  	conns := make([]*connection, len(h.conns))
   334  	copy(conns, h.conns)
   335  	return conns, nil
   336  }
   337  
   338  func (h *host) storeConn(conn *connection) error {
   339  	if !h.mu.lock() {
   340  		return ErrConnClosed
   341  	}
   342  	defer h.mu.unlock()
   343  	h.conns = append(h.conns, conn)
   344  	return nil
   345  }
   346  
   347  func (h *host) deleteConn(conn *connection) (isLastConn bool) {
   348  	if !h.mu.lock() {
   349  		return false
   350  	}
   351  	defer h.mu.unlock()
   352  	h.conns = filterOutConn(h.conns, conn)
   353  	// close host if the last conn is deleted
   354  	if len(h.conns) == 0 {
   355  		h.mu.closeLocked()
   356  		return true
   357  	}
   358  	return false
   359  }
   360  
   361  func (h *host) metrics() {
   362  	conns, err := h.loadAllConns()
   363  	if err != nil {
   364  		return
   365  	}
   366  	var virConnNum uint32
   367  	for _, conn := range conns {
   368  		virConnNum += conn.idToVirConn.length()
   369  	}
   370  	metrics.Gauge(strings.Join([]string{"trpc.MuxConcurrentConnections", h.network, h.address}, ".")).
   371  		Set(float64(len(conns)))
   372  	metrics.Gauge(strings.Join([]string{"trpc.MuxConcurrentVirConns", h.network, h.address}, ".")).
   373  		Set(float64(virConnNum))
   374  	log.Debugf("tnet multiplex status: network: %s, address: %s, connections number: %d,"+
   375  		"concurrent virtual connection number: %d\n", h.network, h.address, len(conns), virConnNum)
   376  }
   377  
   378  func expandSFResult(result singleflight.Result) (*connection, error) {
   379  	if result.Err != nil {
   380  		return nil, result.Err
   381  	}
   382  	return result.Val.(*connection), nil
   383  }
   384  
   385  // connection wraps the underlying tnet.Conn, and manages many virtualConnections.
   386  type connection struct {
   387  	rawConn               tnet.Conn
   388  	deleteConnFromHost    func()
   389  	fp                    multiplexed.FrameParser
   390  	isClosed              atomic.Bool
   391  	mu                    stateRWMutex
   392  	idToVirConn           *shardMap
   393  	maxConcurrentVirConns int
   394  }
   395  
   396  func (c *connection) onRequest(conn tnet.Conn) error {
   397  	vid, buf, err := c.fp.Parse(conn)
   398  	if err != nil {
   399  		c.close(err)
   400  		return err
   401  	}
   402  	vc, ok := c.idToVirConn.load(vid)
   403  	// If the virConn corresponding to the id cannot be found,
   404  	// the virConn has been closed and the current response is discarded.
   405  	if !ok {
   406  		return nil
   407  	}
   408  	vc.recvQueue.Put(buf)
   409  	return nil
   410  }
   411  
   412  func (c *connection) canTakeNewVirConn() bool {
   413  	return c.maxConcurrentVirConns == 0 || c.idToVirConn.length() < uint32(c.maxConcurrentVirConns)
   414  }
   415  
   416  func (c *connection) close(cause error) {
   417  	if !c.isClosed.CAS(false, true) {
   418  		return
   419  	}
   420  	c.deleteConnFromHost()
   421  	c.deleteAllVirConn(cause)
   422  	c.rawConn.Close()
   423  }
   424  
   425  func (c *connection) deleteAllVirConn(cause error) {
   426  	if !c.mu.lock() {
   427  		return
   428  	}
   429  	defer c.mu.unlock()
   430  	c.mu.closeLocked()
   431  	for _, vc := range c.idToVirConn.loadAll() {
   432  		vc.notifyRead(cause)
   433  	}
   434  	c.idToVirConn.reset()
   435  }
   436  
   437  func (c *connection) newVirConn(ctx context.Context, vid uint32) (*virtualConnection, error) {
   438  	if !c.mu.rLock() {
   439  		return nil, ErrConnClosed
   440  	}
   441  	defer c.mu.rUnlock()
   442  	if !c.rawConn.IsActive() {
   443  		return nil, ErrConnClosed
   444  	}
   445  	// CanTakeNewVirConn and loadOrStore are not atomic, which may cause
   446  	// the actual concurrent virConn numbers to exceed the limit max value.
   447  	// Implementing atomic functions requires higher lock granularity,
   448  	// which affects performance.
   449  	if !c.canTakeNewVirConn() {
   450  		return nil, errTooManyVirConns
   451  	}
   452  	ctx, cancel := context.WithCancel(ctx)
   453  	vc := &virtualConnection{
   454  		ctx:        ctx,
   455  		id:         vid,
   456  		cancelFunc: cancel,
   457  		recvQueue:  queue.New[[]byte](ctx.Done()),
   458  		write:      c.rawConn.Write,
   459  		localAddr:  c.rawConn.LocalAddr(),
   460  		remoteAddr: c.rawConn.RemoteAddr(),
   461  		deleteVirConnFromConn: func() {
   462  			c.deleteVirConn(vid)
   463  		},
   464  	}
   465  	_, loaded := c.idToVirConn.loadOrStore(vc.id, vc)
   466  	if loaded {
   467  		cancel()
   468  		return nil, ErrDuplicateID
   469  	}
   470  	return vc, nil
   471  }
   472  
   473  func (c *connection) deleteVirConn(id uint32) {
   474  	c.idToVirConn.delete(id)
   475  }
   476  
   477  var (
   478  	_ multiplexed.MuxConn = (*virtualConnection)(nil)
   479  )
   480  
   481  type virtualConnection struct {
   482  	write                 func(b []byte) (int, error)
   483  	deleteVirConnFromConn func()
   484  	recvQueue             *queue.Queue[[]byte]
   485  	err                   atomic.Error
   486  	ctx                   context.Context
   487  	cancelFunc            context.CancelFunc
   488  	id                    uint32
   489  	isClosed              atomic.Bool
   490  	localAddr             net.Addr
   491  	remoteAddr            net.Addr
   492  }
   493  
   494  // Write writes data to the connection.
   495  // Write and ReadFrame can be concurrent, multiple Write can be concurrent.
   496  func (vc *virtualConnection) Write(b []byte) error {
   497  	if vc.isClosed.Load() {
   498  		return vc.wrapError(ErrConnClosed)
   499  	}
   500  	_, err := vc.write(b)
   501  	return err
   502  }
   503  
   504  // Read reads a packet from connection.
   505  // Write and Read can be concurrent, multiple Read can't be concurrent.
   506  func (vc *virtualConnection) Read() ([]byte, error) {
   507  	if vc.isClosed.Load() {
   508  		return nil, vc.wrapError(ErrConnClosed)
   509  	}
   510  	rsp, ok := vc.recvQueue.Get()
   511  	if !ok {
   512  		return nil, vc.wrapError(errors.New("received data failed"))
   513  	}
   514  	return rsp, nil
   515  }
   516  
   517  // Close closes the connection.
   518  // Any blocked Read or Write operations will be unblocked and return errors.
   519  func (vc *virtualConnection) Close() {
   520  	vc.close(nil)
   521  }
   522  
   523  // LocalAddr returns the local network address, if known.
   524  func (vc *virtualConnection) LocalAddr() net.Addr {
   525  	return vc.localAddr
   526  }
   527  
   528  // RemoteAddr returns the remote network address, if known.
   529  func (vc *virtualConnection) RemoteAddr() net.Addr {
   530  	return vc.remoteAddr
   531  }
   532  
   533  func (vc *virtualConnection) notifyRead(cause error) {
   534  	if !vc.isClosed.CAS(false, true) {
   535  		return
   536  	}
   537  	vc.err.Store(cause)
   538  	vc.cancelFunc()
   539  }
   540  
   541  func (vc *virtualConnection) close(cause error) {
   542  	vc.notifyRead(cause)
   543  	vc.deleteVirConnFromConn()
   544  }
   545  
   546  func (vc *virtualConnection) wrapError(err error) error {
   547  	if loaded := vc.err.Load(); loaded != nil {
   548  		return fmt.Errorf("%w, %s", err, loaded.Error())
   549  	}
   550  	if ctxErr := vc.ctx.Err(); ctxErr != nil {
   551  		return fmt.Errorf("%w, %s", err, ctxErr.Error())
   552  	}
   553  	return err
   554  }
   555  
   556  func filterOutConn(in []*connection, exclude *connection) []*connection {
   557  	out := in[:0]
   558  	for _, v := range in {
   559  		if v != exclude {
   560  			out = append(out, v)
   561  		}
   562  	}
   563  	// If a connection is successfully removed, empty the last value of the slice to avoid memory leaks.
   564  	for i := len(out); i < len(in); i++ {
   565  		in[i] = nil
   566  	}
   567  	return out
   568  }
   569  
   570  func newVirConn(
   571  	ctx context.Context,
   572  	conns []*connection,
   573  	vid uint32,
   574  	isTolerable func(error) bool,
   575  ) (*virtualConnection, error) {
   576  	for _, conn := range conns {
   577  		virConn, err := conn.newVirConn(ctx, vid)
   578  		if isTolerable(err) {
   579  			continue
   580  		}
   581  		return virConn, err
   582  	}
   583  	return nil, nil
   584  }
   585  
   586  func isClosedOrFull(err error) bool {
   587  	if err == ErrConnClosed || err == errTooManyVirConns {
   588  		return true
   589  	}
   590  	return false
   591  }
   592  
   593  func isFull(err error) bool {
   594  	return err == errTooManyVirConns
   595  }