github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/engine/access/rpc/connection/cache.go (about)

     1  package connection
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  	"time"
     7  
     8  	lru "github.com/hashicorp/golang-lru/v2"
     9  	"github.com/onflow/crypto"
    10  	"github.com/rs/zerolog"
    11  	"go.uber.org/atomic"
    12  	"google.golang.org/grpc"
    13  	"google.golang.org/grpc/connectivity"
    14  
    15  	"github.com/onflow/flow-go/module"
    16  )
    17  
// CachedClient represents a gRPC client connection that is cached for reuse.
//
// Instances are created by Cache.GetConnected and stored in the Cache's LRU keyed
// by address. The connection is established lazily while holding mu, and Close
// tears it down only after all in-flight requests registered via AddRequest have
// finished.
type CachedClient struct {
	// conn is the underlying connection; nil until a connect attempt succeeds.
	// Guarded by mu.
	conn    *grpc.ClientConn
	// address is the remote server address and the key of this entry in cache.
	address string
	// timeout is passed to the connect function whenever the client (re)connects.
	timeout time.Duration

	// cache is the owning cache; Invalidate uses it to remove this entry.
	cache          *Cache
	// closeRequested is set by the first call to Close; later calls are no-ops.
	closeRequested *atomic.Bool
	// wg counts in-flight requests registered via AddRequest; Close waits on it.
	wg             sync.WaitGroup
	// mu guards conn; GetConnected holds it exclusively while connecting.
	mu             sync.RWMutex
}
    29  
    30  // ClientConn returns the underlying gRPC client connection.
    31  func (cc *CachedClient) ClientConn() *grpc.ClientConn {
    32  	cc.mu.RLock()
    33  	defer cc.mu.RUnlock()
    34  	return cc.conn
    35  }
    36  
    37  // Address returns the address of the remote server.
    38  func (cc *CachedClient) Address() string {
    39  	return cc.address
    40  }
    41  
    42  // CloseRequested returns true if the CachedClient has been marked for closure.
    43  func (cc *CachedClient) CloseRequested() bool {
    44  	return cc.closeRequested.Load()
    45  }
    46  
    47  // AddRequest increments the in-flight request counter for the CachedClient.
    48  // It returns a function that should be called when the request completes to decrement the counter
    49  func (cc *CachedClient) AddRequest() func() {
    50  	cc.wg.Add(1)
    51  	return cc.wg.Done
    52  }
    53  
    54  // Invalidate removes the CachedClient from the cache and closes the connection.
    55  func (cc *CachedClient) Invalidate() {
    56  	cc.cache.invalidate(cc.address)
    57  
    58  	// Close the connection asynchronously to avoid blocking requests
    59  	go cc.Close()
    60  }
    61  
    62  // Close closes the CachedClient connection. It marks the connection for closure and waits asynchronously for ongoing
    63  // requests to complete before closing the connection.
    64  func (cc *CachedClient) Close() {
    65  	// Mark the connection for closure
    66  	if !cc.closeRequested.CompareAndSwap(false, true) {
    67  		return
    68  	}
    69  
    70  	// Obtain the lock to ensure that any connection attempts have completed
    71  	cc.mu.RLock()
    72  	conn := cc.conn
    73  	cc.mu.RUnlock()
    74  
    75  	// If the initial connection attempt failed, conn will be nil
    76  	if conn == nil {
    77  		return
    78  	}
    79  
    80  	// If there are ongoing requests, wait for them to complete asynchronously
    81  	// this avoids tearing down the connection while requests are in-flight resulting in errors
    82  	cc.wg.Wait()
    83  
    84  	// Close the connection
    85  	conn.Close()
    86  }
    87  
// Cache represents a cache of CachedClient instances with a given maximum size.
//
// Clients are keyed by remote address and held in an LRU. When an entry is
// evicted, the callback installed by NewCache closes its connection.
type Cache struct {
	// cache holds the clients, keyed by address.
	cache   *lru.Cache[string, *CachedClient]
	// maxSize is the capacity the LRU was created with; reported by MaxSize.
	maxSize int

	// logger receives debug messages, e.g. when an entry is invalidated.
	logger  zerolog.Logger
	// metrics records pool events (additions, reuse, new connections, invalidations).
	metrics module.GRPCConnectionPoolMetrics
}
    96  
    97  // NewCache creates a new Cache with the specified maximum size and the underlying LRU cache.
    98  func NewCache(
    99  	log zerolog.Logger,
   100  	metrics module.GRPCConnectionPoolMetrics,
   101  	maxSize int,
   102  ) (*Cache, error) {
   103  	cache, err := lru.NewWithEvict(maxSize, func(_ string, client *CachedClient) {
   104  		go client.Close() // close is blocking, so run in a goroutine
   105  
   106  		log.Debug().Str("grpc_conn_evicted", client.address).Msg("closing grpc connection evicted from pool")
   107  		metrics.ConnectionFromPoolEvicted()
   108  	})
   109  
   110  	if err != nil {
   111  		return nil, fmt.Errorf("could not initialize connection pool cache: %w", err)
   112  	}
   113  
   114  	return &Cache{
   115  		cache:   cache,
   116  		maxSize: maxSize,
   117  		logger:  log,
   118  		metrics: metrics,
   119  	}, nil
   120  }
   121  
   122  // GetConnected returns a CachedClient for the given address that has an active connection.
   123  // If the address is not in the cache, it creates a new entry and connects.
   124  func (c *Cache) GetConnected(
   125  	address string,
   126  	timeout time.Duration,
   127  	networkPubKey crypto.PublicKey,
   128  	connectFn func(string, time.Duration, crypto.PublicKey, *CachedClient) (*grpc.ClientConn, error),
   129  ) (*CachedClient, error) {
   130  	client := &CachedClient{
   131  		address:        address,
   132  		timeout:        timeout,
   133  		closeRequested: atomic.NewBool(false),
   134  		cache:          c,
   135  	}
   136  
   137  	// Note: PeekOrAdd does not "visit" the existing entry, so we need to call Get explicitly
   138  	// to mark the entry as "visited" and update the LRU order. Unfortunately, the lru library
   139  	// doesn't have a GetOrAdd method, so this is the simplest way to achieve atomic get-or-add
   140  	val, existed, _ := c.cache.PeekOrAdd(address, client)
   141  	if existed {
   142  		client = val
   143  		_, _ = c.cache.Get(address)
   144  		c.metrics.ConnectionFromPoolReused()
   145  	} else {
   146  		c.metrics.ConnectionAddedToPool()
   147  	}
   148  
   149  	client.mu.Lock()
   150  	defer client.mu.Unlock()
   151  
   152  	// after getting the lock, check if the connection is still active
   153  	if client.conn != nil && client.conn.GetState() != connectivity.Shutdown {
   154  		return client, nil
   155  	}
   156  
   157  	// if the connection is not setup yet or closed, create a new connection and cache it
   158  	conn, err := connectFn(client.address, client.timeout, networkPubKey, client)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  
   163  	c.metrics.NewConnectionEstablished()
   164  	c.metrics.TotalConnectionsInPool(uint(c.Len()), uint(c.MaxSize()))
   165  
   166  	client.conn = conn
   167  	return client, nil
   168  }
   169  
   170  // invalidate removes the CachedClient entry from the cache with the given address, and shuts
   171  // down the connection.
   172  func (c *Cache) invalidate(address string) {
   173  	if !c.cache.Remove(address) {
   174  		return
   175  	}
   176  
   177  	c.logger.Debug().Str("cached_client_invalidated", address).Msg("invalidating cached client")
   178  	c.metrics.ConnectionFromPoolInvalidated()
   179  }
   180  
   181  // Len returns the number of CachedClient entries in the cache.
   182  func (c *Cache) Len() int {
   183  	return c.cache.Len()
   184  }
   185  
   186  // MaxSize returns the maximum size of the cache.
   187  func (c *Cache) MaxSize() int {
   188  	return c.maxSize
   189  }