github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/crdb/pool/balancer.go

package pool

import (
	"context"
	"hash/maphash"
	"math/rand"
	"slices"
	"strconv"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/exp/maps"
	"golang.org/x/sync/semaphore"

	log "github.com/authzed/spicedb/internal/logging"
	"github.com/authzed/spicedb/pkg/genutil"
)

var (
	connectionsPerCRDBNodeCountGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{
		Name: "crdb_connections_per_node",
		Help: "the number of connections spicedb has to each crdb node",
	}, []string{"pool", "node_id"})

	pruningTimeHistogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{
		Name:    "crdb_pruning_duration",
		Help:    "milliseconds spent on one iteration of pruning excess connections",
		Buckets: []float64{.1, .2, .5, 1, 2, 5, 10, 20, 50, 100},
	}, []string{"pool"})
)

func init() {
	prometheus.MustRegister(connectionsPerCRDBNodeCountGauge)
	prometheus.MustRegister(pruningTimeHistogram)
}

// balancePoolConn is the subset of a pooled connection wrapper that the
// balancer needs: access to the underlying connection and release back to
// the pool.
type balancePoolConn[C balanceConn] interface {
	Conn() C
	Release()
}

// balanceConn is the subset of an underlying connection that the balancer
// needs; it must be comparable so it can be used as a map key.
type balanceConn interface {
	comparable
	IsClosed() bool
}

// balanceablePool is an interface that a pool must implement to allow its
// connections to be balanced by the balancer.
type balanceablePool[P balancePoolConn[C], C balanceConn] interface {
	ID() string
	AcquireAllIdle(ctx context.Context) []P
	Node(conn C) uint32
	GC(conn C)
	MaxConns() uint32
	Range(func(conn C, nodeID uint32))
}
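
// A minimal test-double sketch (illustrative; fakeConn and fakePoolConn are
// hypothetical names, not part of this package) showing what the
// connection-side interfaces above require:
//
//	type fakeConn struct{ closed bool }
//
//	func (c *fakeConn) IsClosed() bool { return c.closed }
//
//	type fakePoolConn struct{ conn *fakeConn }
//
//	func (p *fakePoolConn) Conn() *fakeConn { return p.conn }
//	func (p *fakePoolConn) Release()        {}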

// NodeConnectionBalancer attempts to keep the connections managed by a RetryPool balanced across the
// healthy nodes in a Cockroach cluster.
// It asynchronously inspects idle connections and kills those to nodes that hold more than their
// share. When the pool reconnects, the new connections land with a different distribution, so over
// time the balancer brings the per-node counts close to equal.
type NodeConnectionBalancer struct {
	nodeConnectionBalancer[*pgxpool.Conn, *pgx.Conn]
}

// NewNodeConnectionBalancer builds a new NodeConnectionBalancer for a given connection pool and health tracker.
func NewNodeConnectionBalancer(pool *RetryPool, healthTracker *NodeHealthTracker, interval time.Duration) *NodeConnectionBalancer {
	return &NodeConnectionBalancer{*newNodeConnectionBalancer[*pgxpool.Conn, *pgx.Conn](pool, healthTracker, interval)}
}
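
// A minimal usage sketch (illustrative only; retryPool, healthTracker, ctx,
// and the 5s interval are assumptions of this example, not values prescribed
// by this package):
//
//	balancer := NewNodeConnectionBalancer(retryPool, healthTracker, 5*time.Second)
//	go balancer.Prune(ctx) // blocks until ctx is canceled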

// nodeConnectionBalancer is generic over the underlying connection types for
// testing purposes. Callers should use the exported NodeConnectionBalancer.
type nodeConnectionBalancer[P balancePoolConn[C], C balanceConn] struct {
	ticker        *time.Ticker
	sem           *semaphore.Weighted
	pool          balanceablePool[P, C]
	healthTracker *NodeHealthTracker
	rnd           *rand.Rand
	seed          int64
}

// newNodeConnectionBalancer is generic over the underlying connection types for
// testing purposes. Callers should use the exported NewNodeConnectionBalancer.
func newNodeConnectionBalancer[P balancePoolConn[C], C balanceConn](pool balanceablePool[P, C], healthTracker *NodeHealthTracker, interval time.Duration) *nodeConnectionBalancer[P, C] {
	// maphash initializes a random seed per Hash instance, so this yields a
	// cheap process-local random value.
	seed := int64(new(maphash.Hash).Sum64())
	return &nodeConnectionBalancer[P, C]{
		ticker:        time.NewTicker(interval),
		sem:           semaphore.NewWeighted(1),
		healthTracker: healthTracker,
		pool:          pool,
		seed:          seed,
		// nolint:gosec
		// A non-cryptographically-secure random number generator is not a
		// concern here; it is only used to shuffle nodes so that the same
		// nodes don't always receive the extra connections when the pool
		// size does not divide evenly across nodes.
		rnd: rand.New(rand.NewSource(seed)),
	}
}

// Prune starts periodically checking idle connections and killing ones that are determined to be unbalanced.
// It blocks until ctx is canceled.
func (p *nodeConnectionBalancer[P, C]) Prune(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			p.ticker.Stop()
			return
		case <-p.ticker.C:
			// The semaphore ensures only one pruning pass runs at a time;
			// if the previous pass is still in flight, skip this tick.
			if p.sem.TryAcquire(1) {
				ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
				p.mustPruneConnections(ctx)
				cancel()
				p.sem.Release(1)
			}
		}
	}
}

// mustPruneConnections prunes connections to nodes that have more than MaxConns/(# of nodes)
// connections. This causes the pool to reconnect, which over time will lead to a balanced number
// of connections across the nodes.
func (p *nodeConnectionBalancer[P, C]) mustPruneConnections(ctx context.Context) {
	start := time.Now()
	defer func() {
		pruningTimeHistogram.WithLabelValues(p.pool.ID()).Observe(float64(time.Since(start).Milliseconds()))
	}()
	conns := p.pool.AcquireAllIdle(ctx)
	defer func() {
		// release all acquired idle conns back to the pool
		for _, c := range conns {
			c.Release()
		}
	}()

	// Bucket the idle connections by healthy node; connections to unhealthy
	// nodes are marked for garbage collection right away.
	healthyConns := make(map[uint32][]P)
	for _, c := range conns {
		id := p.pool.Node(c.Conn())
		if p.healthTracker.IsHealthy(id) {
			if healthyConns[id] == nil {
				healthyConns[id] = make([]P, 0, 1)
			}
			healthyConns[id] = append(healthyConns[id], c)
		} else {
			p.pool.GC(c.Conn())
		}
	}

	// Guard against division by zero below when no nodes are currently healthy.
	nodeCount := uint32(p.healthTracker.HealthyNodeCount())
	if nodeCount == 0 {
		nodeCount = 1
	}

	// Count every connection tracked by the pool (idle or in use), per node.
	connectionCounts := make(map[uint32]uint32)
	p.pool.Range(func(conn C, nodeID uint32) {
		connectionCounts[nodeID]++
	})

	log.Ctx(ctx).Trace().
		Str("pool", p.pool.ID()).
		Any("counts", connectionCounts).
		Msg("connections per node")

	// Delete metrics for nodes we no longer have connections for.
	p.healthTracker.RLock()
	for node := range p.healthTracker.nodesEverSeen {
		if _, ok := connectionCounts[node]; !ok {
			connectionsPerCRDBNodeCountGauge.DeletePartialMatch(map[string]string{
				"pool":    p.pool.ID(),
				"node_id": strconv.FormatUint(uint64(node), 10),
			})
		}
	}
	p.healthTracker.RUnlock()

	// Sort the node IDs so the shuffle below starts from a stable order.
	nodes := maps.Keys(connectionCounts)
	slices.Sort(nodes)

	// Shuffle nodes in place deterministically based on the initial seed.
	// This always generates the same distribution for the life of the
	// program, but prevents the same nodes from getting all the "extra"
	// connections when the connections don't divide evenly across nodes.
	p.rnd.Seed(p.seed)
	p.rnd.Shuffle(len(nodes), func(i, j int) {
		nodes[i], nodes[j] = nodes[j], nodes[i]
	})
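
	// Because rnd is reseeded with the same seed before every shuffle, a given
	// set of nodes is permuted identically on every pruning pass; e.g.
	// (illustrative) [1 2 3] might always shuffle to [3 1 2] for the life of
	// the process.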

	initialPerNodeMax := p.pool.MaxConns() / nodeCount
	for i, node := range nodes {
		count := connectionCounts[node]
		connectionsPerCRDBNodeCountGauge.WithLabelValues(
			p.pool.ID(),
			strconv.FormatUint(uint64(node), 10),
		).Set(float64(count))

		perNodeMax := initialPerNodeMax

		// Assign MaxConns%(# of nodes) nodes one extra connection. This ensures that
		// the sum of all perNodeMax values exactly equals the pool's MaxConns.
		// Without this we would either over- or underestimate perNodeMax:
		// if we underestimate, the balancer fights the pool, and if we overestimate,
		// the connection counts between nodes can differ by up to the number of nodes.
		if p.healthTracker.HealthyNodeCount() == 0 ||
			uint32(i) < p.pool.MaxConns()%uint32(p.healthTracker.HealthyNodeCount()) {
			perNodeMax++
		}
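
		// Worked example (illustrative numbers, not from the original source):
		// with MaxConns=10 and 3 healthy nodes, initialPerNodeMax is 10/3 = 3
		// and 10%3 = 1, so the first node in the shuffled order gets
		// perNodeMax = 4; the per-node targets then sum to 4+3+3 = 10 = MaxConns.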

		// Skip nodes that are already at or below their per-node target.
		if count <= perNodeMax {
			continue
		}

		// Excess connections above perNodeMax need to be removed; computing
		// this after the check above keeps the uint32 subtraction from
		// underflowing.
		numToPrune := count - perNodeMax

		log.Ctx(ctx).Trace().
			Uint32("node", node).
			Uint32("poolmaxconns", p.pool.MaxConns()).
			Uint32("conncount", count).
			Uint32("nodemaxconns", perNodeMax).
			Msg("node connections require pruning")

		// Prune only half of the distance we're trying to cover per pass; this
		// still prunes many connections when the gap between the actual and
		// target counts is large.
		if numToPrune > 1 {
			numToPrune >>= 1
		}
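
		// For example (illustrative): a node that is 5 connections over its
		// target is pruned by 5>>1 = 2 on this pass; the remaining excess is
		// handled on subsequent ticks.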

		// We can only prune connections that are currently idle and healthy;
		// clamp numToPrune to what is actually available for this node.
		healthyConnCount := genutil.MustEnsureUInt32(len(healthyConns[node]))
		if healthyConnCount < numToPrune {
			numToPrune = healthyConnCount
		}
		if numToPrune == 0 {
			continue
		}

		for _, c := range healthyConns[node][:numToPrune] {
			log.Ctx(ctx).Trace().Str("pool", p.pool.ID()).Uint32("node", node).Msg("pruning connection")
			p.pool.GC(c.Conn())
		}

		log.Ctx(ctx).Trace().Str("pool", p.pool.ID()).Uint32("node", node).Uint32("prunedCount", numToPrune).Msg("pruned connections")
	}
}