github.com/letsencrypt/boulder@v0.20251208.0/ratelimits/source_redis.go (about) 1 package ratelimits 2 3 import ( 4 "context" 5 "errors" 6 "net" 7 "time" 8 9 "github.com/jmhodges/clock" 10 "github.com/prometheus/client_golang/prometheus" 11 "github.com/prometheus/client_golang/prometheus/promauto" 12 "github.com/redis/go-redis/v9" 13 ) 14 15 // Compile-time check that RedisSource implements the source interface. 16 var _ Source = (*RedisSource)(nil) 17 18 // RedisSource is a ratelimits source backed by sharded Redis. 19 type RedisSource struct { 20 client *redis.Ring 21 clk clock.Clock 22 latency *prometheus.HistogramVec 23 } 24 25 // NewRedisSource returns a new Redis backed source using the provided 26 // *redis.Ring client. 27 func NewRedisSource(client *redis.Ring, clk clock.Clock, stats prometheus.Registerer) *RedisSource { 28 latency := promauto.With(stats).NewHistogramVec(prometheus.HistogramOpts{ 29 Name: "ratelimits_latency", 30 Help: "Histogram of Redis call latencies labeled by call=[set|get|delete|ping] and result=[success|error]", 31 // Exponential buckets ranging from 0.0005s to 3s. 32 Buckets: prometheus.ExponentialBucketsRange(0.0005, 3, 8), 33 }, []string{"call", "result"}) 34 35 return &RedisSource{ 36 client: client, 37 clk: clk, 38 latency: latency, 39 } 40 } 41 42 var errMixedSuccess = errors.New("some keys not found") 43 44 // resultForError returns a string representing the result of the operation 45 // based on the provided error. 46 func resultForError(err error) string { 47 if errors.Is(errMixedSuccess, err) { 48 // Indicates that some of the keys in a batchset operation were not found. 49 return "mixedSuccess" 50 } else if errors.Is(redis.Nil, err) { 51 // Bucket key does not exist. 52 return "notFound" 53 } else if errors.Is(err, context.DeadlineExceeded) { 54 // Client read or write deadline exceeded. 55 return "deadlineExceeded" 56 } else if errors.Is(err, context.Canceled) { 57 // Caller canceled the operation. 58 return "canceled" 59 } 60 var netErr net.Error 61 if errors.As(err, &netErr) && netErr.Timeout() { 62 // Dialer timed out connecting to Redis. 63 return "timeout" 64 } 65 var redisErr redis.Error 66 if errors.Is(err, redisErr) { 67 // An internal error was returned by the Redis server. 68 return "redisError" 69 } 70 return "failed" 71 } 72 73 func (r *RedisSource) observeLatency(call string, latency time.Duration, err error) { 74 result := "success" 75 if err != nil { 76 result = resultForError(err) 77 } 78 r.latency.With(prometheus.Labels{"call": call, "result": result}).Observe(latency.Seconds()) 79 } 80 81 // BatchSet stores TATs at the specified bucketKeys using a pipelined Redis 82 // Transaction in order to reduce the number of round-trips to each Redis shard. 83 func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time) error { 84 start := r.clk.Now() 85 86 pipeline := r.client.Pipeline() 87 for bucketKey, tat := range buckets { 88 // Set a TTL of TAT + 10 minutes to account for clock skew. 89 ttl := tat.UTC().Sub(r.clk.Now()) + 10*time.Minute 90 pipeline.Set(ctx, bucketKey, tat.UTC().UnixNano(), ttl) 91 } 92 _, err := pipeline.Exec(ctx) 93 if err != nil { 94 r.observeLatency("batchset", r.clk.Since(start), err) 95 return err 96 } 97 98 totalLatency := r.clk.Since(start) 99 100 r.observeLatency("batchset", totalLatency, nil) 101 return nil 102 } 103 104 // BatchSetNotExisting attempts to set TATs for the specified bucketKeys if they 105 // do not already exist. Returns a map indicating which keys already existed. 106 func (r *RedisSource) BatchSetNotExisting(ctx context.Context, buckets map[string]time.Time) (map[string]bool, error) { 107 start := r.clk.Now() 108 109 pipeline := r.client.Pipeline() 110 cmds := make(map[string]*redis.BoolCmd, len(buckets)) 111 for bucketKey, tat := range buckets { 112 // Set a TTL of TAT + 10 minutes to account for clock skew. 113 ttl := tat.UTC().Sub(r.clk.Now()) + 10*time.Minute 114 cmds[bucketKey] = pipeline.SetNX(ctx, bucketKey, tat.UTC().UnixNano(), ttl) 115 } 116 _, err := pipeline.Exec(ctx) 117 if err != nil { 118 r.observeLatency("batchsetnotexisting", r.clk.Since(start), err) 119 return nil, err 120 } 121 122 alreadyExists := make(map[string]bool, len(buckets)) 123 totalLatency := r.clk.Since(start) 124 for bucketKey, cmd := range cmds { 125 success, err := cmd.Result() 126 if err != nil { 127 return nil, err 128 } 129 if !success { 130 alreadyExists[bucketKey] = true 131 } 132 } 133 134 r.observeLatency("batchsetnotexisting", totalLatency, nil) 135 return alreadyExists, nil 136 } 137 138 // BatchIncrement updates TATs for the specified bucketKeys using a pipelined 139 // Redis Transaction in order to reduce the number of round-trips to each Redis 140 // shard. 141 func (r *RedisSource) BatchIncrement(ctx context.Context, buckets map[string]increment) error { 142 start := r.clk.Now() 143 144 pipeline := r.client.Pipeline() 145 for bucketKey, incr := range buckets { 146 pipeline.IncrBy(ctx, bucketKey, incr.cost.Nanoseconds()) 147 pipeline.Expire(ctx, bucketKey, incr.ttl) 148 } 149 _, err := pipeline.Exec(ctx) 150 if err != nil { 151 r.observeLatency("batchincrby", r.clk.Since(start), err) 152 return err 153 } 154 155 totalLatency := r.clk.Since(start) 156 r.observeLatency("batchincrby", totalLatency, nil) 157 return nil 158 } 159 160 // Get retrieves the TAT at the specified bucketKey. If the bucketKey does not 161 // exist, ErrBucketNotFound is returned. 162 func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, error) { 163 start := r.clk.Now() 164 165 tatNano, err := r.client.Get(ctx, bucketKey).Int64() 166 if err != nil { 167 if errors.Is(err, redis.Nil) { 168 // Bucket key does not exist. 169 r.observeLatency("get", r.clk.Since(start), err) 170 return time.Time{}, ErrBucketNotFound 171 } 172 // An error occurred while retrieving the TAT. 173 r.observeLatency("get", r.clk.Since(start), err) 174 return time.Time{}, err 175 } 176 177 r.observeLatency("get", r.clk.Since(start), nil) 178 return time.Unix(0, tatNano).UTC(), nil 179 } 180 181 // BatchGet retrieves the TATs at the specified bucketKeys using a pipelined 182 // Redis Transaction in order to reduce the number of round-trips to each Redis 183 // shard. If a bucketKey does not exist, it WILL NOT be included in the returned 184 // map. 185 func (r *RedisSource) BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error) { 186 start := r.clk.Now() 187 188 pipeline := r.client.Pipeline() 189 for _, bucketKey := range bucketKeys { 190 pipeline.Get(ctx, bucketKey) 191 } 192 results, err := pipeline.Exec(ctx) 193 if err != nil && !errors.Is(err, redis.Nil) { 194 r.observeLatency("batchget", r.clk.Since(start), err) 195 return nil, err 196 } 197 198 totalLatency := r.clk.Since(start) 199 200 tats := make(map[string]time.Time, len(bucketKeys)) 201 notFoundCount := 0 202 for i, result := range results { 203 tatNano, err := result.(*redis.StringCmd).Int64() 204 if err != nil { 205 if !errors.Is(err, redis.Nil) { 206 // This should never happen as any errors should have been 207 // caught after the pipeline.Exec() call. 208 r.observeLatency("batchget", r.clk.Since(start), err) 209 return nil, err 210 } 211 notFoundCount++ 212 continue 213 } 214 tats[bucketKeys[i]] = time.Unix(0, tatNano).UTC() 215 } 216 217 var batchErr error 218 if notFoundCount < len(results) { 219 // Some keys were not found. 220 batchErr = errMixedSuccess 221 } else if notFoundCount == len(results) { 222 // All keys were not found. 223 batchErr = redis.Nil 224 } 225 226 r.observeLatency("batchget", totalLatency, batchErr) 227 return tats, nil 228 } 229 230 // BatchDelete deletes the TATs at the specified bucketKeys ('name:id'). A nil 231 // return value does not indicate that the bucketKeys existed. 232 func (r *RedisSource) BatchDelete(ctx context.Context, bucketKeys []string) error { 233 start := r.clk.Now() 234 235 err := r.client.Del(ctx, bucketKeys...).Err() 236 if err != nil { 237 r.observeLatency("delete", r.clk.Since(start), err) 238 return err 239 } 240 241 r.observeLatency("delete", r.clk.Since(start), nil) 242 return nil 243 } 244 245 // Ping checks that each shard of the *redis.Ring is reachable using the PING 246 // command. 247 func (r *RedisSource) Ping(ctx context.Context) error { 248 start := r.clk.Now() 249 250 err := r.client.ForEachShard(ctx, func(ctx context.Context, shard *redis.Client) error { 251 return shard.Ping(ctx).Err() 252 }) 253 if err != nil { 254 r.observeLatency("ping", r.clk.Since(start), err) 255 return err 256 } 257 258 r.observeLatency("ping", r.clk.Since(start), nil) 259 return nil 260 }