github.com/mailgun/holster/v4@v4.20.0/grpcconn/grpcconn.go

package grpcconn

// gRPC connection pooling

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"sync"

	"github.com/mailgun/errors"
	"github.com/mailgun/holster/v4/clock"
	"github.com/mailgun/holster/v4/setter"
	"github.com/mailgun/holster/v4/tracing"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/sirupsen/logrus"
	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
	"google.golang.org/grpc"
	grpcCodes "google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	grpcStatus "google.golang.org/grpc/status"
)

const (
	minConnectionTTL        = 10 * clock.Second
	defaultConnPoolCapacity = 16
	defaultNumConnections   = 1
)

var (
	ErrConnMgrClosed = errors.New("connection manager closed")
	errConnPoolEmpty = errors.New("connection pool empty")

	MetricGRPCConnections = prometheus.NewGaugeVec(prometheus.GaugeOpts{
		Name: "grpcconn_connections",
		Help: "The number of gRPC connections used by grpcconn.",
	}, []string{"remote_service", "peer"})
)

type Config struct {
	RPCTimeout         clock.Duration
	BackOffTimeout     clock.Duration
	Zone               string
	OverrideHostHeader string

	// NumConnections is the number of client connections to establish
	// per target endpoint.
	//
	// NOTE: A single gRPC client opens a maximum of 100 HTTP/2 connections
	// to an endpoint. Once those connections are saturated, it will queue
	// requests to be delivered once there is availability. The recommended
	// method of overcoming this limitation is establishing multiple gRPC
	// client connections. See https://grpc.io/docs/guides/performance/
	NumConnections int
}

// ConnFactory creates gRPC client objects.
type ConnFactory[T any] interface {
	NewGRPCClient(cc grpc.ClientConnInterface) T
	GetServerListURL() string
	ServiceName() string
	ShouldDisposeOfConn(err error) bool
}

// ConnMgr automates gRPC `Connection` pooling. This is necessary for use
// cases requiring frequent stream creation and high stream concurrency, in
// order to avoid hitting the default limit of 100 streams per connection.
// ConnMgr resolves gRPC instance endpoints and connects to them. Both
// resolution and connection are performed in the background, allowing any
// number of concurrent AcquireConn calls to result in only one reconnect
// event.
type ConnMgr[T any] struct {
	cfg             *Config
	getEndpointsURL string
	connFactory     ConnFactory[T]
	httpClt         *http.Client
	reconnectCh     chan struct{}
	ctx             context.Context
	cancel          context.CancelFunc
	closeWG         sync.WaitGroup
	idPool          *IDPool
	log             *logrus.Entry

	connPoolMu    sync.RWMutex
	connPool      []*Conn[T]
	nextConnPivot uint64
	connectedCh   chan struct{}
}

type Conn[T any] struct {
	inner   *grpc.ClientConn
	client  T
	counter int
	broken  bool
	id      ID
}

func (c *Conn[T]) Client() T      { return c.client }
func (c *Conn[T]) Target() string { return c.inner.Target() }
func (c *Conn[T]) ID() string     { return c.id.String() }
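
// A minimal usage sketch. The echoClient and echoFactory types below are
// hypothetical stand-ins for a caller-provided ConnFactory implementation
// and its gRPC client type; they are not part of this package.
//
//	type echoClient struct{ cc grpc.ClientConnInterface }
//
//	type echoFactory struct{}
//
//	func (echoFactory) NewGRPCClient(cc grpc.ClientConnInterface) echoClient {
//		return echoClient{cc: cc}
//	}
//	func (echoFactory) GetServerListURL() string { return "http://localhost:8080/v1/grpc-endpoints" }
//	func (echoFactory) ServiceName() string      { return "echo" }
//	func (echoFactory) ShouldDisposeOfConn(err error) bool {
//		// Drop connections only on transport-level failures.
//		return grpcStatus.Code(err) == grpcCodes.Unavailable
//	}
//
//	func example(ctx context.Context) error {
//		mgr := NewConnMgr[echoClient](&Config{
//			RPCTimeout:     5 * clock.Second,
//			BackOffTimeout: 2 * clock.Second,
//			Zone:           "us-east-1",
//			NumConnections: 2,
//		}, http.DefaultClient, echoFactory{})
//		defer mgr.Close()
//
//		conn, err := mgr.AcquireConn(ctx)
//		if err != nil {
//			return err
//		}
//		// ... issue RPCs via conn.Client() ...
//		mgr.ReleaseConn(conn, err)
//		return err
//	}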

// NewConnMgr instantiates a connection manager that maintains a gRPC
// connection pool.
func NewConnMgr[T any](cfg *Config, httpClient *http.Client, connFactory ConnFactory[T], opts ...option[T]) *ConnMgr[T] {
	// This ensures NumConnections is always at least 1.
	setter.SetDefault(&cfg.NumConnections, defaultNumConnections)
	cm := ConnMgr[T]{
		cfg:             cfg,
		getEndpointsURL: connFactory.GetServerListURL() + "?zone=" + cfg.Zone,
		connFactory:     connFactory,
		httpClt:         httpClient,
		reconnectCh:     make(chan struct{}, 1),
		connPool:        make([]*Conn[T], 0, defaultConnPoolCapacity),
		idPool:          NewIDPool(),
		log:             logrus.WithField("category", "grpcconn"),
	}
	// Apply caller-provided options, e.g. WithLogger.
	for _, opt := range opts {
		opt(&cm)
	}
	cm.ctx, cm.cancel = context.WithCancel(context.Background())
	cm.closeWG.Add(1)
	go cm.run()
	return &cm
}

type option[T any] func(cm *ConnMgr[T])

func WithLogger[T any](log *logrus.Entry) option[T] {
	return func(cm *ConnMgr[T]) {
		cm.log = log
	}
}

func (cm *ConnMgr[T]) AcquireConn(ctx context.Context) (_ *Conn[T], err error) {
	ctx = tracing.StartScope(ctx)
	defer func() {
		tracing.EndScope(ctx, err)
	}()

	for {
		// If the connection manager is already closed then return an error.
		if cm.ctx.Err() != nil {
			return nil, ErrConnMgrClosed
		}
		cm.connPoolMu.Lock()
		// Increment the connection index pivot to ensure that we select a
		// different connection every time when the load distribution is even.
		cm.nextConnPivot++
		// Select the least loaded connection.
		connPoolSize := len(cm.connPool)
		var leastLoadedConn *Conn[T]
		if connPoolSize > 0 {
			currConnIdx := cm.nextConnPivot % uint64(connPoolSize)
			leastLoadedConn = cm.connPool[currConnIdx]
			for i := 1; i < connPoolSize; i++ {
				currConnIdx = (cm.nextConnPivot + uint64(i)) % uint64(connPoolSize)
				currConn := cm.connPool[currConnIdx]
				if currConn.counter < leastLoadedConn.counter {
					leastLoadedConn = currConn
				}
			}
		}
		// If a least loaded connection is selected, then return it.
		if leastLoadedConn != nil {
			leastLoadedConn.counter++
			cm.connPoolMu.Unlock()
			return leastLoadedConn, nil
		}
		// We have nothing to offer, so refresh the connection pool to get
		// more connections.
		connectedCh := cm.ensureReconnect()
		cm.connPoolMu.Unlock()
		// Wait for the connection pool to be refreshed in the background or
		// for the operation timeout to elapse.
		select {
		case <-connectedCh:
			continue
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
}

func (cm *ConnMgr[T]) ensureReconnect() chan struct{} {
	if cm.connectedCh != nil {
		return cm.connectedCh
	}
	cm.connectedCh = make(chan struct{})
	select {
	case cm.reconnectCh <- struct{}{}:
	default:
	}
	return cm.connectedCh
}

func (cm *ConnMgr[T]) ReleaseConn(conn *Conn[T], err error) bool {
	cm.connPoolMu.Lock()
	removedFromPool := false
	connPoolSize := len(cm.connPool)
	if cm.shouldDisposeOfConn(conn, err) {
		conn.broken = true
		// Remove the connection from the pool.
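		// The loop below uses the copy-and-truncate idiom: shift the trailing
		// entries left by one, nil out the duplicated last slot so the *Conn
		// can be garbage collected, then shrink the slice by one.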
		for i, currConn := range cm.connPool {
			if currConn != conn {
				continue
			}
			copy(cm.connPool[i:], cm.connPool[i+1:])
			lastIdx := len(cm.connPool) - 1
			cm.connPool[lastIdx] = nil
			cm.connPool = cm.connPool[:lastIdx]
			removedFromPool = true
			connPoolSize = len(cm.connPool)
			cm.idPool.Release(conn.id)
			MetricGRPCConnections.WithLabelValues(cm.connFactory.ServiceName(), conn.Target()).Dec()
			break
		}
		cm.ensureReconnect()
	}
	conn.counter--
	closeConn := false
	if conn.broken && conn.counter < 1 {
		closeConn = true
	}
	cm.connPoolMu.Unlock()

	if removedFromPool {
		cm.log.WithError(err).Warnf("Server removed from %s pool: %s, poolSize=%d, reason=%s",
			cm.connFactory.ServiceName(), conn.Target(), connPoolSize, err)
	}
	if closeConn {
		_ = conn.inner.Close()
		cm.log.Warnf("Disconnected from %s server %s", cm.connFactory.ServiceName(), conn.Target())
		return true
	}
	return false
}

func (cm *ConnMgr[T]) shouldDisposeOfConn(conn *Conn[T], err error) bool {
	if conn.broken {
		return false
	}
	if err == nil {
		return false
	}

	rootErr := errors.Cause(err)
	if errors.Is(rootErr, context.Canceled) {
		return false
	}
	if errors.Is(rootErr, context.DeadlineExceeded) {
		return false
	}
	switch grpcStatus.Code(err) {
	case grpcCodes.Canceled:
		return false
	case grpcCodes.DeadlineExceeded:
		return false
	}

	return cm.connFactory.ShouldDisposeOfConn(rootErr)
}

func (cm *ConnMgr[T]) Close() {
	cm.cancel()
	cm.closeWG.Wait()
}

func (cm *ConnMgr[T]) run() {
	defer func() {
		cm.connPoolMu.Lock()
		for i, conn := range cm.connPool {
			_ = conn.inner.Close()
			cm.connPool[i] = nil
			cm.idPool.Release(conn.id)
			MetricGRPCConnections.WithLabelValues(cm.connFactory.ServiceName(), conn.Target()).Dec()
		}
		cm.connPool = cm.connPool[:0]
		if cm.connectedCh != nil {
			close(cm.connectedCh)
			cm.connectedCh = nil
		}
		cm.connPoolMu.Unlock()
		cm.closeWG.Done()
	}()
	var nilOrReconnectCh <-chan clock.Time
	for {
		select {
		case <-nilOrReconnectCh:
		case <-cm.reconnectCh:
			cm.log.Info("Force connection pool refresh")
		case <-cm.ctx.Done():
			return
		}
		reconnectPeriod, err := cm.refreshConnPool()
		if err != nil {
			// If the client is closing, then return immediately.
			if errors.Is(err, context.Canceled) {
				return
			}
			cm.log.WithError(err).Errorf("Failed to refresh connection pool")
			reconnectPeriod = cm.cfg.BackOffTimeout
		}
		// If a server returns zero TTL it means that periodic server list
		// refreshes should be disabled.
		if reconnectPeriod > 0 {
			nilOrReconnectCh = clock.After(reconnectPeriod)
		}
	}
}

func (cm *ConnMgr[T]) refreshConnPool() (clock.Duration, error) {
	begin := clock.Now()
	ctx, cancel := context.WithTimeout(cm.ctx, cm.cfg.RPCTimeout)
	defer cancel()

	getGRPCEndpointRs, err := cm.getServerEndpoints(ctx)
	if err != nil {
		return 0, errors.Wrap(err, "while getting gRPC endpoints")
	}

	// Adjust TTL to be a reasonable value. Zero disables periodic refreshes.
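	// A non-positive TTL disables periodic refreshes entirely, while a short
	// positive TTL is clamped up to minConnectionTTL so the server list is
	// not re-fetched too aggressively.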
	ttl := clock.Duration(getGRPCEndpointRs.TTL) * clock.Second
	if ttl <= 0 {
		ttl = 0
	} else if ttl < minConnectionTTL {
		ttl = minConnectionTTL
	}

	newConnCount := 0
	crossZoneCount := 0
	cm.log.Infof("Connecting to %d %s servers", len(getGRPCEndpointRs.Servers), cm.connFactory.ServiceName())
	for _, serverSpec := range getGRPCEndpointRs.Servers {
		if serverSpec.Zone != cm.cfg.Zone {
			crossZoneCount++
		}
		// Do we have the correct number of connections for this serverSpec in the pool?
		activeConnections := cm.countConnections(serverSpec.Endpoint)
		if activeConnections >= cm.cfg.NumConnections {
			continue
		}

		for i := 0; i < (cm.cfg.NumConnections - activeConnections); i++ {
			conn, err := cm.newConnection(serverSpec.Endpoint)
			if err != nil {
				// If the client is closing, then return immediately.
				if errors.Is(err, context.Canceled) {
					return 0, err
				}
				cm.log.WithError(err).Errorf("Failed to dial %s server: %s",
					cm.connFactory.ServiceName(), serverSpec.Endpoint)
				break
			}

			// Add the connection to the pool and notify
			// goroutines waiting for a connection.
			cm.connPoolMu.Lock()
			cm.connPool = append(cm.connPool, conn)
			if cm.connectedCh != nil {
				close(cm.connectedCh)
				cm.connectedCh = nil
			}
			MetricGRPCConnections.WithLabelValues(cm.connFactory.ServiceName(), conn.Target()).Inc()
			cm.connPoolMu.Unlock()
			newConnCount++
			cm.log.Infof("Connected to %s server: %s, zone=%s", cm.connFactory.ServiceName(), serverSpec.Endpoint, serverSpec.Zone)
		}
	}
	cm.connPoolMu.Lock()
	connPoolSize := len(cm.connPool)
	// If no new connection has been established but the pool is not empty,
	// then trigger the requested connected notification anyway.
	if connPoolSize > 0 && cm.connectedCh != nil {
		close(cm.connectedCh)
		cm.connectedCh = nil
	}
	cm.connPoolMu.Unlock()
	took := clock.Since(begin).Truncate(clock.Millisecond)
	cm.log.Warnf("Connection pool refreshed: took=%s, zone=%s, poolSize=%d, newConnCount=%d, knownServerCount=%d, crossZoneCount=%d, ttl=%s",
		took, cm.cfg.Zone, connPoolSize, newConnCount, len(getGRPCEndpointRs.Servers), crossZoneCount, ttl)
	if connPoolSize < 1 {
		return 0, errConnPoolEmpty
	}
	return ttl, nil
}

// countConnections returns the total number of connections in the pool for the provided endpoint.
func (cm *ConnMgr[T]) countConnections(endpoint string) (count int) {
	cm.connPoolMu.RLock()
	defer cm.connPoolMu.RUnlock()
	for i := 0; i < len(cm.connPool); i++ {
		if cm.connPool[i].Target() == endpoint {
			count++
		}
	}
	return count
}

// newConnection establishes a new gRPC connection to the provided endpoint.
func (cm *ConnMgr[T]) newConnection(endpoint string) (*Conn[T], error) {
	// Establish a connection with the server.
	ctx, cancel := context.WithTimeout(cm.ctx, cm.cfg.RPCTimeout)
	opts := []grpc.DialOption{
		grpc.WithBlock(),
		grpc.WithTransportCredentials(insecureCredentials),
		grpc.WithUnaryInterceptor(otelUnaryInterceptor),
		grpc.WithStreamInterceptor(otelStreamInterceptor),
	}
	grpcConn, err := grpc.DialContext(ctx, endpoint, opts...)
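	// grpc.WithBlock makes DialContext synchronous: it waits until the
	// connection is established or ctx expires, so the dial timeout can be
	// released as soon as the call returns.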
	cancel()
	if err != nil {
		return nil, err
	}
	id := cm.idPool.Allocate()
	return &Conn[T]{
		inner:  grpcConn,
		client: cm.connFactory.NewGRPCClient(grpcConn),
		id:     id,
	}, nil
}

var (
	insecureCredentials   = insecure.NewCredentials()
	otelUnaryInterceptor  = otelgrpc.UnaryClientInterceptor()
	otelStreamInterceptor = otelgrpc.StreamClientInterceptor()
)

func (cm *ConnMgr[T]) getServerEndpoints(ctx context.Context) (*GetGRPCEndpointsRs, error) {
	rq, err := http.NewRequestWithContext(ctx, "GET", cm.getEndpointsURL, http.NoBody)
	if err != nil {
		return nil, errors.Wrap(err, "during request")
	}

	// Override the host header if provided in config.
	if cm.cfg.OverrideHostHeader != "" {
		rq.Host = cm.cfg.OverrideHostHeader
	}

	rs, err := cm.httpClt.Do(rq)
	if err != nil {
		return nil, errors.Stack(err)
	}
	defer rs.Body.Close()

	if rs.StatusCode != http.StatusOK {
		return nil, errFromResponse(rs)
	}
	var rsBody GetGRPCEndpointsRs
	if err := readResponseBody(rs, &rsBody); err != nil {
		return nil, errors.Wrap(err, "while unmarshalling response")
	}
	return &rsBody, nil
}

type GenericResponse struct {
	Msg string `json:"message"`
}

func errFromResponse(rs *http.Response) error {
	body, err := io.ReadAll(rs.Body)
	if err != nil {
		return fmt.Errorf("HTTP request error, status=%s", rs.Status)
	}
	defer rs.Body.Close()
	var rsBody GenericResponse
	if err := json.Unmarshal(body, &rsBody); err != nil {
		return errors.Wrapf(err, "HTTP request error, status=%s, body=%s", rs.Status, body)
	}
	return fmt.Errorf("HTTP request error, status=%s, message=%s", rs.Status, rsBody.Msg)
}

// TransCountInTests returns the total number of pending read/write operations.
// It is only supposed to be used in tests, hence it is not exposed in the
// Client interface.
func (cm *ConnMgr[T]) TransCountInTests() int {
	transCount := 0
	cm.connPoolMu.RLock()
	for _, serverConn := range cm.connPool {
		transCount += serverConn.counter
	}
	cm.connPoolMu.RUnlock()
	return transCount
}

type GetGRPCEndpointsRs struct {
	Servers []ServerSpec `json:"servers"`
	TTL     int          `json:"ttl"`
	// FIXME: Remove the following fields once all clients are upgraded.
	Endpoint string `json:"grpc_endpoint"`
	Zone     string `json:"zone"`
}

type ServerSpec struct {
	Endpoint  string     `json:"endpoint"`
	Zone      string     `json:"zone"`
	Timestamp clock.Time `json:"timestamp"`
}

func readResponseBody(rs *http.Response, body interface{}) error {
	bodyBytes, err := io.ReadAll(rs.Body)
	if err != nil {
		return errors.Wrap(err, "while reading response")
	}
	if err := json.Unmarshal(bodyBytes, &body); err != nil {
		return errors.Wrapf(err, "while parsing response %s", bodyBytes)
	}
	return nil
}
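
// For reference, the json tags on GetGRPCEndpointsRs and ServerSpec above
// imply a server-list response of roughly the following shape. The values
// are illustrative only; actual endpoints, zones, and TTLs depend on the
// deployment.
//
//	{
//	  "servers": [
//	    {"endpoint": "10.0.0.5:9090", "zone": "us-east-1", "timestamp": "2023-01-02T15:04:05Z"},
//	    {"endpoint": "10.0.0.6:9090", "zone": "us-east-1", "timestamp": "2023-01-02T15:04:05Z"}
//	  ],
//	  "ttl": 30
//	}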