package pool

import (
	"context"
	"hash/maphash"
	"math/rand"
	"slices"
	"strconv"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/exp/maps"
	"golang.org/x/sync/semaphore"

	log "github.com/authzed/spicedb/internal/logging"
	"github.com/authzed/spicedb/pkg/genutil"
)

// Prometheus metrics observing balancer behavior.
var (
	// connectionsPerCRDBNodeCountGauge tracks how many pool connections are
	// currently held against each CockroachDB node, labeled by pool and node id.
	connectionsPerCRDBNodeCountGauge = prometheus.NewGaugeVec(prometheus.GaugeOpts{
		Name: "crdb_connections_per_node",
		Help: "the number of connections spicedb has to each crdb node",
	}, []string{"pool", "node_id"})

	// pruningTimeHistogram records the wall-clock duration (in milliseconds)
	// of each pruning pass over the pool's idle connections.
	pruningTimeHistogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{
		Name:    "crdb_pruning_duration",
		Help:    "milliseconds spent on one iteration of pruning excess connections",
		Buckets: []float64{.1, .2, .5, 1, 2, 5, 10, 20, 50, 100},
	}, []string{"pool"})
)

func init() {
	prometheus.MustRegister(connectionsPerCRDBNodeCountGauge)
	prometheus.MustRegister(pruningTimeHistogram)
}

// balancePoolConn is the subset of pooled-connection behavior the balancer
// needs: access to the underlying connection, and release back to the pool.
type balancePoolConn[C balanceConn] interface {
	Conn() C
	Release()
}

// balanceConn is the minimal contract an underlying connection type must
// satisfy to be tracked and compared by the balancer.
type balanceConn interface {
	comparable
	IsClosed() bool
}

// balanceablePool is an interface that a pool must implement to allow its
// connections to be balanced by the balancer.
type balanceablePool[P balancePoolConn[C], C balanceConn] interface {
	// ID uniquely identifies the pool (used as a metric label).
	ID() string
	// AcquireAllIdle checks out every currently-idle connection.
	AcquireAllIdle(ctx context.Context) []P
	// Node reports the CRDB node id a connection is attached to.
	Node(conn C) uint32
	// GC marks a connection for removal so the pool will reconnect.
	GC(conn C)
	MaxConns() uint32
	// Range iterates over every connection in the pool with its node id.
	Range(func(conn C, nodeID uint32))
}

// NodeConnectionBalancer attempts to keep the connections managed by a RetryPool balanced between healthy nodes in
// a Cockroach cluster.
// It asynchronously processes idle connections, and kills any to nodes that have too many.
// When the pool reconnects, it will have a different balance of connections, and over time the balancer will bring the counts close to equal.
type NodeConnectionBalancer struct {
	nodeConnectionBalancer[*pgxpool.Conn, *pgx.Conn]
}

// NewNodeConnectionBalancer builds a new nodeConnectionBalancer for a given connection pool and health tracker.
func NewNodeConnectionBalancer(pool *RetryPool, healthTracker *NodeHealthTracker, interval time.Duration) *NodeConnectionBalancer {
	return &NodeConnectionBalancer{*newNodeConnectionBalancer[*pgxpool.Conn, *pgx.Conn](pool, healthTracker, interval)}
}

// nodeConnectionBalancer is generic over underlying connection types for
// testing purposes. Callers should use the exported NodeConnectionBalancer
type nodeConnectionBalancer[P balancePoolConn[C], C balanceConn] struct {
	// ticker drives the periodic pruning passes.
	ticker *time.Ticker
	// sem (weight 1) ensures at most one pruning pass runs at a time.
	sem           *semaphore.Weighted
	pool          balanceablePool[P, C]
	healthTracker *NodeHealthTracker
	rnd           *rand.Rand
	// seed is fixed at construction so node shuffles are stable for the
	// lifetime of the process.
	seed int64
}

// newNodeConnectionBalancer is generic over underlying connection types for
// testing purposes. Callers should use the exported NewNodeConnectionBalancer.
func newNodeConnectionBalancer[P balancePoolConn[C], C balanceConn](pool balanceablePool[P, C], healthTracker *NodeHealthTracker, interval time.Duration) *nodeConnectionBalancer[P, C] {
	// A maphash seed yields a value that is random per-process but constant
	// afterwards, giving deterministic shuffles for this balancer's lifetime.
	seed := int64(new(maphash.Hash).Sum64())
	return &nodeConnectionBalancer[P, C]{
		ticker:        time.NewTicker(interval),
		sem:           semaphore.NewWeighted(1),
		healthTracker: healthTracker,
		pool:          pool,
		seed:          seed,
		// nolint:gosec
		// use of non cryptographically secure random number generator is not concern here,
		// as it's used for shuffling the nodes to balance the connections when the number of
		// connections do not divide evenly.
		rnd: rand.New(rand.NewSource(seed)),
	}
}

// Prune starts periodically checking idle connections and killing ones that are determined to be unbalanced.
103 func (p *nodeConnectionBalancer[P, C]) Prune(ctx context.Context) { 104 for { 105 select { 106 case <-ctx.Done(): 107 p.ticker.Stop() 108 return 109 case <-p.ticker.C: 110 if p.sem.TryAcquire(1) { 111 ctx, cancel := context.WithTimeout(ctx, 10*time.Second) 112 p.mustPruneConnections(ctx) 113 cancel() 114 p.sem.Release(1) 115 } 116 } 117 } 118 } 119 120 // mustPruneConnections prunes connections to nodes that have more than MaxConns/(# of nodes) 121 // This causes the pool to reconnect, which over time will lead to a balanced number of connections 122 // across each node. 123 func (p *nodeConnectionBalancer[P, C]) mustPruneConnections(ctx context.Context) { 124 start := time.Now() 125 defer func() { 126 pruningTimeHistogram.WithLabelValues(p.pool.ID()).Observe(float64(time.Since(start).Milliseconds())) 127 }() 128 conns := p.pool.AcquireAllIdle(ctx) 129 defer func() { 130 // release all acquired idle conns back 131 for _, c := range conns { 132 c.Release() 133 } 134 }() 135 136 // bucket connections by healthy node 137 healthyConns := make(map[uint32][]P, 0) 138 for _, c := range conns { 139 id := p.pool.Node(c.Conn()) 140 if p.healthTracker.IsHealthy(id) { 141 if healthyConns[id] == nil { 142 healthyConns[id] = make([]P, 0, 1) 143 } 144 healthyConns[id] = append(healthyConns[id], c) 145 } else { 146 p.pool.GC(c.Conn()) 147 } 148 } 149 150 nodeCount := uint32(p.healthTracker.HealthyNodeCount()) 151 if nodeCount == 0 { 152 nodeCount = 1 153 } 154 155 connectionCounts := make(map[uint32]uint32) 156 p.pool.Range(func(conn C, nodeID uint32) { 157 connectionCounts[nodeID]++ 158 }) 159 160 log.Ctx(ctx).Trace(). 161 Str("pool", p.pool.ID()). 162 Any("counts", connectionCounts). 
163 Msg("connections per node") 164 165 // Delete metrics for nodes we no longer have connections for 166 p.healthTracker.RLock() 167 for node := range p.healthTracker.nodesEverSeen { 168 if _, ok := connectionCounts[node]; !ok { 169 connectionsPerCRDBNodeCountGauge.DeletePartialMatch(map[string]string{ 170 "pool": p.pool.ID(), 171 "node_id": strconv.FormatUint(uint64(node), 10), 172 }) 173 } 174 } 175 p.healthTracker.RUnlock() 176 177 nodes := maps.Keys(connectionCounts) 178 slices.Sort(nodes) 179 180 // Shuffle nodes in place deterministically based on the initial seed. 181 // This will always generate the same distribution for the life of the 182 // program, but prevents the same nodes from getting all the "extra" 183 // connections when they don't divide evenly over nodes. 184 p.rnd.Seed(p.seed) 185 p.rnd.Shuffle(len(nodes), func(i, j int) { 186 nodes[j], nodes[i] = nodes[i], nodes[j] 187 }) 188 189 initialPerNodeMax := p.pool.MaxConns() / nodeCount 190 for i, node := range nodes { 191 count := connectionCounts[node] 192 connectionsPerCRDBNodeCountGauge.WithLabelValues( 193 p.pool.ID(), 194 strconv.FormatUint(uint64(node), 10), 195 ).Set(float64(count)) 196 197 perNodeMax := initialPerNodeMax 198 199 // Assign MaxConns%(# of nodes) nodes an extra connection. This ensures that 200 // the sum of all perNodeMax values exactly equals the pool MaxConns. 201 // Without this, we will either over or underestimate the perNodeMax. 202 // If we underestimate, the balancer will fight the pool, and if we overestimate, 203 // it's possible for the difference in connections between nodes to differ by up to 204 // the number of nodes. 205 if p.healthTracker.HealthyNodeCount() == 0 || 206 uint32(i) < p.pool.MaxConns()%uint32(p.healthTracker.HealthyNodeCount()) { 207 perNodeMax++ 208 } 209 210 // Need to remove excess connections above the perNodeMax 211 numToPrune := count - perNodeMax 212 213 if count <= perNodeMax { 214 continue 215 } 216 log.Ctx(ctx).Trace(). 
217 Uint32("node", node). 218 Uint32("poolmaxconns", p.pool.MaxConns()). 219 Uint32("conncount", count). 220 Uint32("nodemaxconns", perNodeMax). 221 Msg("node connections require pruning") 222 223 // Prune half of the distance we're trying to cover. This will prune more connections if the gap between 224 // desired and target is large. 225 if numToPrune > 1 { 226 numToPrune >>= 1 227 } 228 229 healthyNodeCount := genutil.MustEnsureUInt32(len(healthyConns[node])) 230 if healthyNodeCount < numToPrune { 231 numToPrune = healthyNodeCount 232 } 233 if numToPrune == 0 { 234 continue 235 } 236 237 for _, c := range healthyConns[node][:numToPrune] { 238 log.Ctx(ctx).Trace().Str("pool", p.pool.ID()).Uint32("node", node).Msg("pruning connection") 239 p.pool.GC(c.Conn()) 240 } 241 242 log.Ctx(ctx).Trace().Str("pool", p.pool.ID()).Uint32("node", node).Uint32("prunedCount", numToPrune).Msg("pruned connections") 243 } 244 }