
     1  package grpcconn
     3  // gRPC connection pooling
     5  import (
     6  	"context"
     7  	"encoding/json"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"sync"
    13  	""
    14  	""
    15  	""
    16  	""
    17  	""
    18  	""
    19  	""
    20  	""
    21  	grpcCodes ""
    22  	""
    23  	grpcStatus ""
    24  )
    26  const (
    27  	minConnectionTTL        = 10 * clock.Second
    28  	defaultConnPoolCapacity = 16
    29  	defaultNumConnections   = 1
    30  )
    32  var (
    33  	ErrConnMgrClosed = errors.New("connection manager closed")
    34  	errConnPoolEmpty = errors.New("connection pool empty")
    36  	MetricGRPCConnections = prometheus.NewGaugeVec(prometheus.GaugeOpts{
    37  		Name: "grpcconn_connections",
    38  		Help: "The number of gRPC connections used by grpcconn.",
    39  	}, []string{"remote_service", "peer"})
    40  )
// Config controls how a ConnMgr discovers gRPC servers and connects to them.
type Config struct {
	// RPCTimeout bounds both the server-list HTTP request and each
	// individual gRPC dial.
	RPCTimeout         clock.Duration
	// BackOffTimeout is how long to wait before retrying after a failed
	// connection pool refresh. If it is not positive, no retry timer is
	// armed after a failure.
	BackOffTimeout     clock.Duration
	// Zone is sent to the server-list endpoint as the `zone` query
	// parameter and is compared with each server's zone to count
	// cross-zone servers in refresh logs.
	Zone               string
	// OverrideHostHeader, when non-empty, is used as the Host of the
	// server-list HTTP request.
	OverrideHostHeader string

	// NumConnections is the number of client connections to establish
	// per target endpoint. NewConnMgr defaults it (in place, via
	// setter.SetDefault) to defaultNumConnections when it is left unset.
	//
	// NOTE: A single gRPC client connection multiplexes RPCs over HTTP/2,
	// and each HTTP/2 connection is typically limited to 100 concurrent
	// streams. Once that limit is reached, further RPCs are queued on the
	// client until a stream becomes available. The recommended way to
	// overcome this limitation is to spread load over multiple gRPC client
	// connections. See https://grpc.io/docs/guides/performance/
	NumConnections int
}
    59  // ConnFactory creates gRPC client objects.
    60  type ConnFactory[T any] interface {
    61  	NewGRPCClient(cc grpc.ClientConnInterface) T
    62  	GetServerListURL() string
    63  	ServiceName() string
    64  	ShouldDisposeOfConn(err error) bool
    65  }
// ConnMgr automates gRPC connection pooling. Pooling is needed by use cases
// that create streams frequently and with high concurrency, to avoid hitting
// the default limit of 100 concurrent streams per connection.
//
// ConnMgr resolves gRPC server endpoints over HTTP and connects to them.
// Both resolution and dialing are performed in a background goroutine, so
// any number of concurrent AcquireConn calls results in only one reconnect
// event.
type ConnMgr[T any] struct {
	cfg             *Config
	getEndpointsURL string
	connFactory     ConnFactory[T]
	httpClt         *http.Client
	// reconnectCh (buffered, capacity 1) asks the background goroutine to
	// refresh the pool now.
	reconnectCh     chan struct{}
	// ctx/cancel control the background goroutine's lifetime; Close cancels.
	ctx             context.Context
	cancel          context.CancelFunc
	// closeWG lets Close wait for the background goroutine to exit.
	closeWG         sync.WaitGroup
	// idPool hands out the IDs stored in Conn.id.
	idPool          *IDPool
	log             *logrus.Entry

	// connPoolMu guards the fields below as well as Conn.counter and
	// Conn.broken of pooled connections.
	connPoolMu    sync.RWMutex
	connPool      []*Conn[T]
	// nextConnPivot rotates the starting index of the least-loaded scan
	// in AcquireConn.
	nextConnPivot uint64
	// connectedCh is non-nil while AcquireConn/ReleaseConn are waiting for a
	// refresh; it is closed (and reset to nil) once the pool has a
	// connection or the manager shuts down.
	connectedCh   chan struct{}
}
// Conn is a pooled gRPC connection handed out by ConnMgr.AcquireConn and
// returned with ConnMgr.ReleaseConn.
type Conn[T any] struct {
	// inner is the underlying dialed connection.
	inner   *grpc.ClientConn
	// client is the typed client built by ConnFactory.NewGRPCClient.
	client  T
	// counter is the number of AcquireConn calls not yet matched by a
	// ReleaseConn; guarded by ConnMgr.connPoolMu.
	counter int
	// broken is set when the connection is taken out of service; a broken
	// connection is closed once counter drops below 1. Guarded by
	// ConnMgr.connPoolMu.
	broken  bool
	// id is allocated from ConnMgr.idPool when the connection is dialed.
	id      ID
}
// Client returns the typed gRPC client built for this connection.
func (c *Conn[T]) Client() T      { return c.client }
// Target returns the dial target of the underlying *grpc.ClientConn.
func (c *Conn[T]) Target() string { return c.inner.Target() }
// ID returns the connection's identifier as a string.
// NOTE(review): the return expression is missing in this copy of the file
// (a bare `return` does not compile for a string result); it presumably
// formats c.id — confirm against the ID type's API.
func (c *Conn[T]) ID() string     { return }
   103  // NewConnMgr instantiates a connection manager that maintains a gRPC
   104  // connection pool.
   105  func NewConnMgr[T any](cfg *Config, httpClient *http.Client, connFactory ConnFactory[T], opts ...option[T]) *ConnMgr[T] {
   106  	// This ensures NumConnections is always at least 1
   107  	setter.SetDefault(&cfg.NumConnections, defaultNumConnections)
   108  	cm := ConnMgr[T]{
   109  		cfg:             cfg,
   110  		getEndpointsURL: connFactory.GetServerListURL() + "?zone=" + cfg.Zone,
   111  		connFactory:     connFactory,
   112  		httpClt:         httpClient,
   113  		reconnectCh:     make(chan struct{}, 1),
   114  		connPool:        make([]*Conn[T], 0, defaultConnPoolCapacity),
   115  		idPool:          NewIDPool(),
   116  		log:             logrus.WithField("category", "grpcconn"),
   117  	}
   118  	cm.ctx, cm.cancel = context.WithCancel(context.Background())
   119  	cm.closeWG.Add(1)
   120  	go
   121  	return &cm
   122  }
   124  type option[T any] func(cm *ConnMgr[T])
   126  func WithLogger[T any](log *logrus.Entry) option[T] {
   127  	return func(cm *ConnMgr[T]) {
   128  		cm.log = log
   129  	}
   130  }
   132  func (cm *ConnMgr[T]) AcquireConn(ctx context.Context) (_ *Conn[T], err error) {
   133  	ctx = tracing.StartScope(ctx)
   134  	defer func() {
   135  		tracing.EndScope(ctx, err)
   136  	}()
   138  	for {
   139  		// If the connection manager is already closed then return an error.
   140  		if cm.ctx.Err() != nil {
   141  			return nil, ErrConnMgrClosed
   142  		}
   143  		cm.connPoolMu.Lock()
   144  		// Increment the connection index pivot to ensure that we select a
   145  		// different connection every time when the load distribution is even.
   146  		cm.nextConnPivot++
   147  		// Select the least loaded connection.
   148  		connPoolSize := len(cm.connPool)
   149  		var leastLoadedConn *Conn[T]
   150  		if connPoolSize > 0 {
   151  			currConnIdx := cm.nextConnPivot % uint64(connPoolSize)
   152  			leastLoadedConn = cm.connPool[currConnIdx]
   153  			for i := 1; i < connPoolSize; i++ {
   154  				currConnIdx = (cm.nextConnPivot + uint64(i)) % uint64(connPoolSize)
   155  				currConn := cm.connPool[currConnIdx]
   156  				if currConn.counter < leastLoadedConn.counter {
   157  					leastLoadedConn = currConn
   158  				}
   159  			}
   160  		}
   161  		// If a least loaded connection is selected, then return it.
   162  		if leastLoadedConn != nil {
   163  			leastLoadedConn.counter++
   164  			cm.connPoolMu.Unlock()
   165  			return leastLoadedConn, nil
   166  		}
   167  		// We have got nothing to offer, let's refresh the connection pool to
   168  		// get more connections.
   169  		connectedCh := cm.ensureReconnect()
   170  		cm.connPoolMu.Unlock()
   171  		// Wait for the connection pool to be refreshed on the background or
   172  		// the operation timeout elapsing.
   173  		select {
   174  		case <-connectedCh:
   175  			continue
   176  		case <-ctx.Done():
   177  			return nil, ctx.Err()
   178  		}
   179  	}
   180  }
   182  func (cm *ConnMgr[T]) ensureReconnect() chan struct{} {
   183  	if cm.connectedCh != nil {
   184  		return cm.connectedCh
   185  	}
   186  	cm.connectedCh = make(chan struct{})
   187  	select {
   188  	case cm.reconnectCh <- struct{}{}:
   189  	default:
   190  	}
   191  	return cm.connectedCh
   192  }
   194  func (cm *ConnMgr[T]) ReleaseConn(conn *Conn[T], err error) bool {
   195  	cm.connPoolMu.Lock()
   196  	removedFromPool := false
   197  	connPoolSize := len(cm.connPool)
   198  	if cm.shouldDisposeOfConn(conn, err) {
   199  		conn.broken = true
   200  		// Remove the connection from the pool.
   201  		for i, currConn := range cm.connPool {
   202  			if currConn != conn {
   203  				continue
   204  			}
   205  			copy(cm.connPool[i:], cm.connPool[i+1:])
   206  			lastIdx := len(cm.connPool) - 1
   207  			cm.connPool[lastIdx] = nil
   208  			cm.connPool = cm.connPool[:lastIdx]
   209  			removedFromPool = true
   210  			connPoolSize = len(cm.connPool)
   211  			cm.idPool.Release(
   212  			MetricGRPCConnections.WithLabelValues(cm.connFactory.ServiceName(), conn.Target()).Dec()
   213  			break
   214  		}
   215  		cm.ensureReconnect()
   216  	}
   217  	conn.counter--
   218  	closeConn := false
   219  	if conn.broken && conn.counter < 1 {
   220  		closeConn = true
   221  	}
   222  	cm.connPoolMu.Unlock()
   224  	if removedFromPool {
   225  		cm.log.WithError(err).Warnf("Server removed from %s pool: %s, poolSize=%d, reason=%s",
   226  			cm.connFactory.ServiceName(), conn.Target(), connPoolSize, err)
   227  	}
   228  	if closeConn {
   229  		_ = conn.inner.Close()
   230  		cm.log.Warnf("Disconnected from %s server %s", cm.connFactory.ServiceName(), conn.Target())
   231  		return true
   232  	}
   233  	return false
   234  }
   236  func (cm *ConnMgr[T]) shouldDisposeOfConn(conn *Conn[T], err error) bool {
   237  	if conn.broken {
   238  		return false
   239  	}
   240  	if err == nil {
   241  		return false
   242  	}
   244  	rootErr := errors.Cause(err)
   245  	if errors.Is(rootErr, context.Canceled) {
   246  		return false
   247  	}
   248  	if errors.Is(rootErr, context.DeadlineExceeded) {
   249  		return false
   250  	}
   251  	switch grpcStatus.Code(err) {
   252  	case grpcCodes.Canceled:
   253  		return false
   254  	case grpcCodes.DeadlineExceeded:
   255  		return false
   256  	}
   258  	return cm.connFactory.ShouldDisposeOfConn(rootErr)
   259  }
   261  func (cm *ConnMgr[T]) Close() {
   262  	cm.cancel()
   263  	cm.closeWG.Wait()
   264  }
   266  func (cm *ConnMgr[T]) run() {
   267  	defer func() {
   268  		cm.connPoolMu.Lock()
   269  		for i, conn := range cm.connPool {
   270  			_ = conn.inner.Close()
   271  			cm.connPool[i] = nil
   272  			cm.idPool.Release(
   273  			MetricGRPCConnections.WithLabelValues(cm.connFactory.ServiceName(), conn.Target()).Dec()
   274  		}
   275  		cm.connPool = cm.connPool[:0]
   276  		if cm.connectedCh != nil {
   277  			close(cm.connectedCh)
   278  			cm.connectedCh = nil
   279  		}
   280  		cm.connPoolMu.Unlock()
   281  		cm.closeWG.Done()
   282  	}()
   283  	var nilOrReconnectCh <-chan clock.Time
   284  	for {
   285  		select {
   286  		case <-nilOrReconnectCh:
   287  		case <-cm.reconnectCh:
   288  			cm.log.Info("Force connection pool refresh")
   289  		case <-cm.ctx.Done():
   290  			return
   291  		}
   292  		reconnectPeriod, err := cm.refreshConnPool()
   293  		if err != nil {
   294  			// If the client is closing, then return immediately.
   295  			if errors.Is(err, context.Canceled) {
   296  				return
   297  			}
   298  			cm.log.WithError(err).Errorf("Failed to refresh connection pool")
   299  			reconnectPeriod = cm.cfg.BackOffTimeout
   300  		}
   301  		// If a server returns zero TTL it means that periodic server list
   302  		// refreshes should be disabled.
   303  		if reconnectPeriod > 0 {
   304  			nilOrReconnectCh = clock.After(reconnectPeriod)
   305  		}
   306  	}
   307  }
   309  func (cm *ConnMgr[T]) refreshConnPool() (clock.Duration, error) {
   310  	begin := clock.Now()
   311  	ctx, cancel := context.WithTimeout(cm.ctx, cm.cfg.RPCTimeout)
   312  	defer cancel()
   314  	getGRPCEndpointRs, err := cm.getServerEndpoints(ctx)
   315  	if err != nil {
   316  		return 0, errors.Wrap(err, "while getting gRPC endpoints")
   317  	}
   319  	// Adjust TTL to be a reasonable value. Zero disables periodic refreshes.
   320  	ttl := clock.Duration(getGRPCEndpointRs.TTL) * clock.Second
   321  	if ttl <= 0 {
   322  		ttl = 0
   323  	} else if ttl < minConnectionTTL {
   324  		ttl = minConnectionTTL
   325  	}
   327  	newConnCount := 0
   328  	crossZoneCount := 0
   329  	cm.log.Infof("Connecting to %d %s servers", len(getGRPCEndpointRs.Servers), cm.connFactory.ServiceName())
   330  	for _, serverSpec := range getGRPCEndpointRs.Servers {
   331  		if serverSpec.Zone != cm.cfg.Zone {
   332  			crossZoneCount++
   333  		}
   334  		// Do we have the correct number of connections for this serverSpec in the pool?
   335  		activeConnections := cm.countConnections(serverSpec.Endpoint)
   336  		if activeConnections >= cm.cfg.NumConnections {
   337  			continue
   338  		}
   340  		for i := 0; i < (cm.cfg.NumConnections - activeConnections); i++ {
   341  			conn, err := cm.newConnection(serverSpec.Endpoint)
   342  			if err != nil {
   343  				// If the client is closing, then return immediately.
   344  				if errors.Is(err, context.Canceled) {
   345  					return 0, err
   346  				}
   347  				cm.log.WithError(err).Errorf("Failed to dial %s server: %s",
   348  					cm.connFactory.ServiceName(), serverSpec.Endpoint)
   349  				break
   350  			}
   352  			// Add the connection to the pool and notify
   353  			// goroutines waiting for a connection.
   354  			cm.connPoolMu.Lock()
   355  			cm.connPool = append(cm.connPool, conn)
   356  			if cm.connectedCh != nil {
   357  				close(cm.connectedCh)
   358  				cm.connectedCh = nil
   359  			}
   360  			MetricGRPCConnections.WithLabelValues(cm.connFactory.ServiceName(), conn.Target()).Inc()
   361  			cm.connPoolMu.Unlock()
   362  			newConnCount++
   363  			cm.log.Infof("Connected to %s server: %s, zone=%s", cm.connFactory.ServiceName(), serverSpec.Endpoint, serverSpec.Zone)
   364  		}
   365  	}
   366  	cm.connPoolMu.Lock()
   367  	connPoolSize := len(cm.connPool)
   368  	// If there has been no new connection established but the pool is not
   369  	// empty then trigger the requested connected notification anyway.
   370  	if connPoolSize > 0 && cm.connectedCh != nil {
   371  		close(cm.connectedCh)
   372  		cm.connectedCh = nil
   373  	}
   374  	cm.connPoolMu.Unlock()
   375  	took := clock.Since(begin).Truncate(clock.Millisecond)
   376  	cm.log.Warnf("Connection pool refreshed: took=%s, zone=%s, poolSize=%d, newConnCount=%d, knownServerCount=%d, crossZoneCount=%d, ttl=%s",
   377  		took, cm.cfg.Zone, connPoolSize, newConnCount, len(getGRPCEndpointRs.Servers), crossZoneCount, ttl)
   378  	if connPoolSize < 1 {
   379  		return 0, errConnPoolEmpty
   380  	}
   381  	return ttl, nil
   382  }
   384  // countConnections returns the total number of connections in the pool for the provided endpoint.
   385  func (cm *ConnMgr[T]) countConnections(endpoint string) (count int) {
   386  	cm.connPoolMu.RLock()
   387  	defer cm.connPoolMu.RUnlock()
   388  	for i := 0; i < len(cm.connPool); i++ {
   389  		if cm.connPool[i].Target() == endpoint {
   390  			count++
   391  		}
   392  	}
   393  	return count
   394  }
   396  // newConnection establishes a new GRPC connection to the provided endpoint
   397  func (cm *ConnMgr[T]) newConnection(endpoint string) (*Conn[T], error) {
   398  	// Establish a connection with the server.
   399  	ctx, cancel := context.WithTimeout(cm.ctx, cm.cfg.RPCTimeout)
   400  	opts := []grpc.DialOption{
   401  		grpc.WithBlock(),
   402  		grpc.WithTransportCredentials(insecureCredentials),
   403  		grpc.WithUnaryInterceptor(otelUnaryInterceptor),
   404  		grpc.WithStreamInterceptor(otelStreamInterceptor),
   405  	}
   406  	grpcConn, err := grpc.DialContext(ctx, endpoint, opts...)
   407  	cancel()
   408  	if err != nil {
   409  		return nil, err
   410  	}
   411  	id := cm.idPool.Allocate()
   412  	return &Conn[T]{
   413  		inner:  grpcConn,
   414  		client: cm.connFactory.NewGRPCClient(grpcConn),
   415  		id:     id,
   416  	}, nil
   417  }
   419  var (
   420  	insecureCredentials   = insecure.NewCredentials()
   421  	otelUnaryInterceptor  = otelgrpc.UnaryClientInterceptor()
   422  	otelStreamInterceptor = otelgrpc.StreamClientInterceptor()
   423  )
   425  func (cm *ConnMgr[T]) getServerEndpoints(ctx context.Context) (*GetGRPCEndpointsRs, error) {
   426  	rq, err := http.NewRequestWithContext(ctx, "GET", cm.getEndpointsURL, http.NoBody)
   427  	if err != nil {
   428  		return nil, errors.Wrap(err, "during request")
   429  	}
   431  	// Override the host header if provided in config
   432  	if cm.cfg.OverrideHostHeader != "" {
   433  		rq.Host = cm.cfg.OverrideHostHeader
   434  	}
   436  	rs, err := cm.httpClt.Do(rq)
   437  	if err != nil {
   438  		return nil, errors.Stack(err)
   439  	}
   440  	defer rs.Body.Close()
   442  	if rs.StatusCode != http.StatusOK {
   443  		return nil, errFromResponse(rs)
   444  	}
   445  	var rsBody GetGRPCEndpointsRs
   446  	if err := readResponseBody(rs, &rsBody); err != nil {
   447  		return nil, errors.Wrap(err, "while unmarshalling response")
   448  	}
   449  	return &rsBody, nil
   450  }
// GenericResponse is the JSON error body decoded by errFromResponse from
// non-200 server-list responses.
type GenericResponse struct {
	// Msg is the server-provided error message.
	Msg string `json:"message"`
}
   456  func errFromResponse(rs *http.Response) error {
   457  	body, err := io.ReadAll(rs.Body)
   458  	if err != nil {
   459  		return fmt.Errorf("HTTP request error, status=%s", rs.Status)
   460  	}
   461  	defer rs.Body.Close()
   462  	var rsBody GenericResponse
   463  	if err := json.Unmarshal(body, &rsBody); err != nil {
   464  		return errors.Wrapf(err, "HTTP request error, status=%s, body=%s", rs.Status, body)
   465  	}
   466  	return fmt.Errorf("HTTP request error, status=%s, message=%s", rs.Status, rsBody.Msg)
   467  }
   469  // TransCountInTests returns the total number of pending read/write operations.
   470  // It is only supposed to be used in tests, hence it is not exposed in Client
   471  // interface.
   472  func (cm *ConnMgr[T]) TransCountInTests() int {
   473  	transCount := 0
   474  	cm.connPoolMu.RLock()
   475  	for _, serverConn := range cm.connPool {
   476  		transCount += serverConn.counter
   477  	}
   478  	cm.connPoolMu.RUnlock()
   479  	return transCount
   480  }
// GetGRPCEndpointsRs is the JSON body returned by the server-list endpoint.
type GetGRPCEndpointsRs struct {
	// Servers lists the gRPC servers to connect to.
	Servers []ServerSpec `json:"servers"`
	// TTL is the refresh period in seconds (refreshConnPool multiplies it by
	// clock.Second); non-positive disables periodic refreshes.
	TTL     int          `json:"ttl"`
	// FIXME: Remove the following fields once all clients are upgraded.
	Endpoint string `json:"grpc_endpoint"`
	Zone     string `json:"zone"`
}
// ServerSpec describes one gRPC server in a GetGRPCEndpointsRs.
type ServerSpec struct {
	// Endpoint is the gRPC dial target; it is also compared with
	// Conn.Target() to count existing connections.
	Endpoint  string     `json:"endpoint"`
	// Zone is compared with Config.Zone to count cross-zone servers.
	Zone      string     `json:"zone"`
	// Timestamp is decoded but not read anywhere in this file.
	Timestamp clock.Time `json:"timestamp"`
}
   496  func readResponseBody(rs *http.Response, body interface{}) error {
   497  	bodyBytes, err := io.ReadAll(rs.Body)
   498  	if err != nil {
   499  		return errors.Wrap(err, "while reading response")
   500  	}
   501  	if err := json.Unmarshal(bodyBytes, &body); err != nil {
   502  		return errors.Wrapf(err, "while parsing response %s", bodyBytes)
   503  	}
   504  	return nil
   505  }