github.com/m-lab/locate@v0.17.6/limits/ratelimiter.go (about) 1 package limits 2 3 import ( 4 "fmt" 5 "strconv" 6 "time" 7 8 "github.com/gomodule/redigo/redis" 9 ) 10 11 // LimitStatus indicates the result of a rate limit check 12 type LimitStatus struct { 13 // IsLimited indicates if the request should be rate limited 14 IsLimited bool 15 // LimitType indicates which limit was exceeded ("ip" or "ipua" or "") 16 LimitType string 17 } 18 19 // LimitConfig holds the configuration for a single rate limit type 20 type LimitConfig struct { 21 // Interval defines the duration of the sliding window 22 Interval time.Duration 23 24 // MaxEvents defines the maximum number of events allowed in the interval 25 MaxEvents int 26 } 27 28 // RateLimitConfig holds the configuration for both IP-only and IP+UA rate limiting. 29 type RateLimitConfig struct { 30 // IPConfig defines the rate limiting configuration for IP-only checks 31 IPConfig LimitConfig 32 33 // IPUAConfig defines the rate limiting configuration for IP+UA checks 34 IPUAConfig LimitConfig 35 36 // KeyPrefix is the prefix for Redis keys 37 KeyPrefix string 38 } 39 40 // RateLimiter implements a distributed rate limiter using Redis sorted sets (ZSET). 41 // It maintains sliding windows for both IP-only and IP+UA combinations, where: 42 // - Each event is stored in a ZSET with the timestamp as score 43 // - Old events (outside the window) are automatically removed 44 // - Keys automatically expire after the configured interval 45 // 46 // The limiter considers a request to be rate-limited if the number of events 47 // in either window exceeds their respective MaxEvents. 48 type RateLimiter struct { 49 pool *redis.Pool 50 ipConfig LimitConfig 51 ipuaConfig LimitConfig 52 keyPrefix string 53 } 54 55 // NewRateLimiter creates a new rate limiter. 56 func NewRateLimiter(pool *redis.Pool, config RateLimitConfig) *RateLimiter { 57 return &RateLimiter{ 58 pool: pool, 59 ipConfig: config.IPConfig, 60 ipuaConfig: config.IPUAConfig, 61 keyPrefix: config.KeyPrefix, 62 } 63 } 64 65 // generateIPKey creates a Redis key from IP only. 66 func (rl *RateLimiter) generateIPKey(ip string) string { 67 return fmt.Sprintf("%s:%s", rl.keyPrefix, ip) 68 } 69 70 // generateIPUAKey creates a Redis key from IP and User-Agent. 71 func (rl *RateLimiter) generateIPUAKey(ip, ua string) string { 72 // If User-Agent is empty, use "none" as the value. This allows to distinguish 73 // between IP-only keys and IPUA keys with an empty User-Agent. 74 if ua == "" { 75 ua = "none" 76 } 77 return fmt.Sprintf("%s:%s:%s", rl.keyPrefix, ip, ua) 78 } 79 80 // IsLimited checks if the given IP and User-Agent combination should be rate limited. 81 // It first checks the IP-only limit, then the IP+UA limit if the IP-only check passes. 82 func (rl *RateLimiter) IsLimited(ip, ua string) (LimitStatus, error) { 83 conn := rl.pool.Get() 84 defer conn.Close() 85 86 now := time.Now().UnixMicro() 87 ipKey := rl.generateIPKey(ip) 88 ipuaKey := rl.generateIPUAKey(ip, ua) 89 90 // Start pipeline for both checks 91 // 1. IP-only check 92 conn.Send("ZREMRANGEBYSCORE", ipKey, "-inf", now-rl.ipConfig.Interval.Microseconds()) 93 conn.Send("ZADD", ipKey, now, strconv.FormatInt(now, 10)) 94 conn.Send("EXPIRE", ipKey, int64(rl.ipConfig.Interval.Seconds())) 95 conn.Send("ZCARD", ipKey) 96 97 // 2. IP+UA limit check 98 conn.Send("ZREMRANGEBYSCORE", ipuaKey, "-inf", now-rl.ipuaConfig.Interval.Microseconds()) 99 conn.Send("ZADD", ipuaKey, now, strconv.FormatInt(now, 10)) 100 conn.Send("EXPIRE", ipuaKey, int64(rl.ipuaConfig.Interval.Seconds())) 101 conn.Send("ZCARD", ipuaKey) 102 103 // Flush pipeline 104 if err := conn.Flush(); err != nil { 105 return LimitStatus{}, fmt.Errorf("failed to flush pipeline: %w", err) 106 } 107 108 // Receive first 3 replies for IP limit (ZREMRANGEBYSCORE, ZADD, EXPIRE) 109 for i := 0; i < 3; i++ { 110 if _, err := conn.Receive(); err != nil { 111 return LimitStatus{}, fmt.Errorf("failed to receive IP limit reply %d: %w", i, err) 112 } 113 } 114 115 // Receive IP limit count 116 ipCount, err := redis.Int64(conn.Receive()) 117 if err != nil { 118 return LimitStatus{}, fmt.Errorf("failed to receive IP limit count: %w", err) 119 } 120 121 // Check IP-only limit first 122 if ipCount > int64(rl.ipConfig.MaxEvents) { 123 return LimitStatus{ 124 IsLimited: true, 125 LimitType: "ip", 126 }, nil 127 } 128 129 // Receive next 3 replies for IP+UA limit (ZREMRANGEBYSCORE, ZADD, EXPIRE) 130 for i := 0; i < 3; i++ { 131 if _, err := conn.Receive(); err != nil { 132 return LimitStatus{}, fmt.Errorf("failed to receive IP+UA limit reply %d: %w", i, err) 133 } 134 } 135 136 // Receive IP+UA limit count 137 ipuaCount, err := redis.Int64(conn.Receive()) 138 if err != nil { 139 return LimitStatus{}, fmt.Errorf("failed to receive IP+UA limit count: %w", err) 140 } 141 142 // Check IP+UA limit 143 if ipuaCount > int64(rl.ipuaConfig.MaxEvents) { 144 return LimitStatus{ 145 IsLimited: true, 146 LimitType: "ipua", 147 }, nil 148 } 149 150 return LimitStatus{ 151 IsLimited: false, 152 LimitType: "", 153 }, nil 154 }