github.com/letsencrypt/boulder@v0.20251208.0/ratelimits/source_redis.go (about)

     1  package ratelimits
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  	"time"
     8  
     9  	"github.com/jmhodges/clock"
    10  	"github.com/prometheus/client_golang/prometheus"
    11  	"github.com/prometheus/client_golang/prometheus/promauto"
    12  	"github.com/redis/go-redis/v9"
    13  )
    14  
// Compile-time check that RedisSource implements the Source interface; fails
// the build if a Source method is added or changed without updating RedisSource.
var _ Source = (*RedisSource)(nil)
    17  
// RedisSource is a ratelimits source backed by sharded Redis.
type RedisSource struct {
	client  *redis.Ring              // sharded Redis client used for all operations
	clk     clock.Clock              // clock used to compute TTLs and per-call latency
	latency *prometheus.HistogramVec // call latency, labeled by "call" and "result"
}
    24  
    25  // NewRedisSource returns a new Redis backed source using the provided
    26  // *redis.Ring client.
    27  func NewRedisSource(client *redis.Ring, clk clock.Clock, stats prometheus.Registerer) *RedisSource {
    28  	latency := promauto.With(stats).NewHistogramVec(prometheus.HistogramOpts{
    29  		Name: "ratelimits_latency",
    30  		Help: "Histogram of Redis call latencies labeled by call=[set|get|delete|ping] and result=[success|error]",
    31  		// Exponential buckets ranging from 0.0005s to 3s.
    32  		Buckets: prometheus.ExponentialBucketsRange(0.0005, 3, 8),
    33  	}, []string{"call", "result"})
    34  
    35  	return &RedisSource{
    36  		client:  client,
    37  		clk:     clk,
    38  		latency: latency,
    39  	}
    40  }
    41  
    42  var errMixedSuccess = errors.New("some keys not found")
    43  
    44  // resultForError returns a string representing the result of the operation
    45  // based on the provided error.
    46  func resultForError(err error) string {
    47  	if errors.Is(errMixedSuccess, err) {
    48  		// Indicates that some of the keys in a batchset operation were not found.
    49  		return "mixedSuccess"
    50  	} else if errors.Is(redis.Nil, err) {
    51  		// Bucket key does not exist.
    52  		return "notFound"
    53  	} else if errors.Is(err, context.DeadlineExceeded) {
    54  		// Client read or write deadline exceeded.
    55  		return "deadlineExceeded"
    56  	} else if errors.Is(err, context.Canceled) {
    57  		// Caller canceled the operation.
    58  		return "canceled"
    59  	}
    60  	var netErr net.Error
    61  	if errors.As(err, &netErr) && netErr.Timeout() {
    62  		// Dialer timed out connecting to Redis.
    63  		return "timeout"
    64  	}
    65  	var redisErr redis.Error
    66  	if errors.Is(err, redisErr) {
    67  		// An internal error was returned by the Redis server.
    68  		return "redisError"
    69  	}
    70  	return "failed"
    71  }
    72  
    73  func (r *RedisSource) observeLatency(call string, latency time.Duration, err error) {
    74  	result := "success"
    75  	if err != nil {
    76  		result = resultForError(err)
    77  	}
    78  	r.latency.With(prometheus.Labels{"call": call, "result": result}).Observe(latency.Seconds())
    79  }
    80  
    81  // BatchSet stores TATs at the specified bucketKeys using a pipelined Redis
    82  // Transaction in order to reduce the number of round-trips to each Redis shard.
    83  func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time) error {
    84  	start := r.clk.Now()
    85  
    86  	pipeline := r.client.Pipeline()
    87  	for bucketKey, tat := range buckets {
    88  		// Set a TTL of TAT + 10 minutes to account for clock skew.
    89  		ttl := tat.UTC().Sub(r.clk.Now()) + 10*time.Minute
    90  		pipeline.Set(ctx, bucketKey, tat.UTC().UnixNano(), ttl)
    91  	}
    92  	_, err := pipeline.Exec(ctx)
    93  	if err != nil {
    94  		r.observeLatency("batchset", r.clk.Since(start), err)
    95  		return err
    96  	}
    97  
    98  	totalLatency := r.clk.Since(start)
    99  
   100  	r.observeLatency("batchset", totalLatency, nil)
   101  	return nil
   102  }
   103  
   104  // BatchSetNotExisting attempts to set TATs for the specified bucketKeys if they
   105  // do not already exist. Returns a map indicating which keys already existed.
   106  func (r *RedisSource) BatchSetNotExisting(ctx context.Context, buckets map[string]time.Time) (map[string]bool, error) {
   107  	start := r.clk.Now()
   108  
   109  	pipeline := r.client.Pipeline()
   110  	cmds := make(map[string]*redis.BoolCmd, len(buckets))
   111  	for bucketKey, tat := range buckets {
   112  		// Set a TTL of TAT + 10 minutes to account for clock skew.
   113  		ttl := tat.UTC().Sub(r.clk.Now()) + 10*time.Minute
   114  		cmds[bucketKey] = pipeline.SetNX(ctx, bucketKey, tat.UTC().UnixNano(), ttl)
   115  	}
   116  	_, err := pipeline.Exec(ctx)
   117  	if err != nil {
   118  		r.observeLatency("batchsetnotexisting", r.clk.Since(start), err)
   119  		return nil, err
   120  	}
   121  
   122  	alreadyExists := make(map[string]bool, len(buckets))
   123  	totalLatency := r.clk.Since(start)
   124  	for bucketKey, cmd := range cmds {
   125  		success, err := cmd.Result()
   126  		if err != nil {
   127  			return nil, err
   128  		}
   129  		if !success {
   130  			alreadyExists[bucketKey] = true
   131  		}
   132  	}
   133  
   134  	r.observeLatency("batchsetnotexisting", totalLatency, nil)
   135  	return alreadyExists, nil
   136  }
   137  
   138  // BatchIncrement updates TATs for the specified bucketKeys using a pipelined
   139  // Redis Transaction in order to reduce the number of round-trips to each Redis
   140  // shard.
   141  func (r *RedisSource) BatchIncrement(ctx context.Context, buckets map[string]increment) error {
   142  	start := r.clk.Now()
   143  
   144  	pipeline := r.client.Pipeline()
   145  	for bucketKey, incr := range buckets {
   146  		pipeline.IncrBy(ctx, bucketKey, incr.cost.Nanoseconds())
   147  		pipeline.Expire(ctx, bucketKey, incr.ttl)
   148  	}
   149  	_, err := pipeline.Exec(ctx)
   150  	if err != nil {
   151  		r.observeLatency("batchincrby", r.clk.Since(start), err)
   152  		return err
   153  	}
   154  
   155  	totalLatency := r.clk.Since(start)
   156  	r.observeLatency("batchincrby", totalLatency, nil)
   157  	return nil
   158  }
   159  
   160  // Get retrieves the TAT at the specified bucketKey. If the bucketKey does not
   161  // exist, ErrBucketNotFound is returned.
   162  func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, error) {
   163  	start := r.clk.Now()
   164  
   165  	tatNano, err := r.client.Get(ctx, bucketKey).Int64()
   166  	if err != nil {
   167  		if errors.Is(err, redis.Nil) {
   168  			// Bucket key does not exist.
   169  			r.observeLatency("get", r.clk.Since(start), err)
   170  			return time.Time{}, ErrBucketNotFound
   171  		}
   172  		// An error occurred while retrieving the TAT.
   173  		r.observeLatency("get", r.clk.Since(start), err)
   174  		return time.Time{}, err
   175  	}
   176  
   177  	r.observeLatency("get", r.clk.Since(start), nil)
   178  	return time.Unix(0, tatNano).UTC(), nil
   179  }
   180  
   181  // BatchGet retrieves the TATs at the specified bucketKeys using a pipelined
   182  // Redis Transaction in order to reduce the number of round-trips to each Redis
   183  // shard. If a bucketKey does not exist, it WILL NOT be included in the returned
   184  // map.
   185  func (r *RedisSource) BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error) {
   186  	start := r.clk.Now()
   187  
   188  	pipeline := r.client.Pipeline()
   189  	for _, bucketKey := range bucketKeys {
   190  		pipeline.Get(ctx, bucketKey)
   191  	}
   192  	results, err := pipeline.Exec(ctx)
   193  	if err != nil && !errors.Is(err, redis.Nil) {
   194  		r.observeLatency("batchget", r.clk.Since(start), err)
   195  		return nil, err
   196  	}
   197  
   198  	totalLatency := r.clk.Since(start)
   199  
   200  	tats := make(map[string]time.Time, len(bucketKeys))
   201  	notFoundCount := 0
   202  	for i, result := range results {
   203  		tatNano, err := result.(*redis.StringCmd).Int64()
   204  		if err != nil {
   205  			if !errors.Is(err, redis.Nil) {
   206  				// This should never happen as any errors should have been
   207  				// caught after the pipeline.Exec() call.
   208  				r.observeLatency("batchget", r.clk.Since(start), err)
   209  				return nil, err
   210  			}
   211  			notFoundCount++
   212  			continue
   213  		}
   214  		tats[bucketKeys[i]] = time.Unix(0, tatNano).UTC()
   215  	}
   216  
   217  	var batchErr error
   218  	if notFoundCount < len(results) {
   219  		// Some keys were not found.
   220  		batchErr = errMixedSuccess
   221  	} else if notFoundCount == len(results) {
   222  		// All keys were not found.
   223  		batchErr = redis.Nil
   224  	}
   225  
   226  	r.observeLatency("batchget", totalLatency, batchErr)
   227  	return tats, nil
   228  }
   229  
   230  // BatchDelete deletes the TATs at the specified bucketKeys ('name:id'). A nil
   231  // return value does not indicate that the bucketKeys existed.
   232  func (r *RedisSource) BatchDelete(ctx context.Context, bucketKeys []string) error {
   233  	start := r.clk.Now()
   234  
   235  	err := r.client.Del(ctx, bucketKeys...).Err()
   236  	if err != nil {
   237  		r.observeLatency("delete", r.clk.Since(start), err)
   238  		return err
   239  	}
   240  
   241  	r.observeLatency("delete", r.clk.Since(start), nil)
   242  	return nil
   243  }
   244  
   245  // Ping checks that each shard of the *redis.Ring is reachable using the PING
   246  // command.
   247  func (r *RedisSource) Ping(ctx context.Context) error {
   248  	start := r.clk.Now()
   249  
   250  	err := r.client.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error {
   251  		return shard.Ping(ctx).Err()
   252  	})
   253  	if err != nil {
   254  		r.observeLatency("ping", r.clk.Since(start), err)
   255  		return err
   256  	}
   257  
   258  	r.observeLatency("ping", r.clk.Since(start), nil)
   259  	return nil
   260  }