github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/crdb/pool/health.go (about)

     1  package pool
     2  
     3  import (
     4  	"context"
     5  	"math/rand"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/jackc/pgx/v5"
    10  	"github.com/lthibault/jitterbug"
    11  	"github.com/prometheus/client_golang/prometheus"
    12  	"golang.org/x/time/rate"
    13  
    14  	pgxcommon "github.com/authzed/spicedb/internal/datastore/postgres/common"
    15  	log "github.com/authzed/spicedb/internal/logging"
    16  )
    17  
    18  const errorBurst = 2
    19  
    20  var healthyCRDBNodeCountGauge = prometheus.NewGauge(prometheus.GaugeOpts{
    21  	Name: "crdb_healthy_nodes",
    22  	Help: "the number of healthy crdb nodes detected by spicedb",
    23  })
    24  
    25  func init() {
    26  	prometheus.MustRegister(healthyCRDBNodeCountGauge)
    27  }
    28  
// NodeHealthTracker detects changes in the node pool by polling the cluster periodically and recording
// the node ids that are seen. This is used to detect new nodes that come online that have either previously
// been marked unhealthy due to connection errors or due to scale up.
//
// Consumers can manually mark a node healthy or unhealthy as well.
//
// The embedded RWMutex guards healthyNodes and nodesEverSeen. Because it is
// embedded, Lock/Unlock/RLock/RUnlock are part of the type's exported method set.
type NodeHealthTracker struct {
	sync.RWMutex
	// connConfig is the pgx configuration used to open each health-check connection.
	connConfig    *pgx.ConnConfig
	// healthyNodes is the set of node IDs currently considered healthy.
	healthyNodes  map[uint32]struct{}
	// nodesEverSeen maps every node ID observed so far to its error limiter;
	// SetNodeHealth drops a node from healthyNodes once its limiter denies.
	nodesEverSeen map[uint32]*rate.Limiter
	// newLimiter builds a fresh per-node error limiter.
	newLimiter    func() *rate.Limiter
}
    41  
    42  // NewNodeHealthChecker builds a health checker that polls the cluster at the given url.
    43  func NewNodeHealthChecker(url string) (*NodeHealthTracker, error) {
    44  	connConfig, err := pgxcommon.ParseConfigWithInstrumentation(url)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	return &NodeHealthTracker{
    50  		connConfig:    connConfig,
    51  		healthyNodes:  make(map[uint32]struct{}, 0),
    52  		nodesEverSeen: make(map[uint32]*rate.Limiter, 0),
    53  		newLimiter: func() *rate.Limiter {
    54  			return rate.NewLimiter(rate.Every(1*time.Minute), errorBurst)
    55  		},
    56  	}, nil
    57  }
    58  
    59  // Poll starts polling the cluster and recording the node IDs that it sees.
    60  func (t *NodeHealthTracker) Poll(ctx context.Context, interval time.Duration) {
    61  	ticker := jitterbug.New(interval, jitterbug.Uniform{
    62  		// nolint:gosec
    63  		// G404 use of non cryptographically secure random number generator is not concern here,
    64  		// as it's used for jittering the interval for health checks.
    65  		Source: rand.New(rand.NewSource(time.Now().Unix())),
    66  		Min:    interval,
    67  	})
    68  	defer ticker.Stop()
    69  	for {
    70  		select {
    71  		case <-ctx.Done():
    72  			return
    73  		case <-ticker.C:
    74  			t.tryConnect(interval)
    75  		}
    76  	}
    77  }
    78  
    79  // tryConnect attempts to connect to a node and ping it. If successful, that node is marked healthy.
    80  func (t *NodeHealthTracker) tryConnect(interval time.Duration) {
    81  	ctx, cancel := context.WithTimeout(context.Background(), interval)
    82  	defer cancel()
    83  	conn, err := pgx.ConnectConfig(ctx, t.connConfig)
    84  	if err != nil {
    85  		return
    86  	}
    87  	defer conn.Close(ctx)
    88  	if err = conn.Ping(ctx); err != nil {
    89  		return
    90  	}
    91  	log.Ctx(ctx).Trace().
    92  		Uint32("nodeID", nodeID(conn)).
    93  		Msg("health check connected to node")
    94  
    95  	// nodes are marked healthy after a successful connection
    96  	t.SetNodeHealth(nodeID(conn), true)
    97  	t.Lock()
    98  	defer t.Unlock()
    99  	t.nodesEverSeen[nodeID(conn)] = t.newLimiter()
   100  }
   101  
   102  // SetNodeHealth marks a node as either healthy or unhealthy.
   103  func (t *NodeHealthTracker) SetNodeHealth(nodeID uint32, healthy bool) {
   104  	t.Lock()
   105  	defer t.Unlock()
   106  	defer func() {
   107  		healthyCRDBNodeCountGauge.Set(float64(len(t.healthyNodes)))
   108  	}()
   109  
   110  	if _, ok := t.nodesEverSeen[nodeID]; !ok {
   111  		t.nodesEverSeen[nodeID] = t.newLimiter()
   112  	}
   113  
   114  	if healthy {
   115  		t.healthyNodes[nodeID] = struct{}{}
   116  		t.nodesEverSeen[nodeID] = t.newLimiter()
   117  		return
   118  	}
   119  
   120  	// If the limiter allows the request, it means we haven't seen more than
   121  	// 2 failures in the past 1m, so the node shouldn't be marked unhealthy yet.
   122  	// If the limiter denies the request, we've hit too many errors and the node
   123  	// is marked unhealthy.
   124  	if !t.nodesEverSeen[nodeID].Allow() {
   125  		delete(t.healthyNodes, nodeID)
   126  	}
   127  }
   128  
   129  // IsHealthy returns true if the given nodeID has been marked healthy.
   130  func (t *NodeHealthTracker) IsHealthy(nodeID uint32) bool {
   131  	t.RLock()
   132  	_, ok := t.healthyNodes[nodeID]
   133  	t.RUnlock()
   134  	return ok
   135  }
   136  
   137  // HealthyNodeCount returns the number of healthy nodes currently tracked.
   138  func (t *NodeHealthTracker) HealthyNodeCount() int {
   139  	t.RLock()
   140  	defer t.RUnlock()
   141  	return len(t.healthyNodes)
   142  }