github.com/m-lab/locate@v0.17.6/limits/ratelimiter.go (about)

     1  package limits
     2  
     3  import (
     4  	"fmt"
     5  	"strconv"
     6  	"time"
     7  
     8  	"github.com/gomodule/redigo/redis"
     9  )
    10  
    11  // LimitStatus indicates the result of a rate limit check
    12  type LimitStatus struct {
    13  	// IsLimited indicates if the request should be rate limited
    14  	IsLimited bool
    15  	// LimitType indicates which limit was exceeded ("ip" or "ipua" or "")
    16  	LimitType string
    17  }
    18  
    19  // LimitConfig holds the configuration for a single rate limit type
    20  type LimitConfig struct {
    21  	// Interval defines the duration of the sliding window
    22  	Interval time.Duration
    23  
    24  	// MaxEvents defines the maximum number of events allowed in the interval
    25  	MaxEvents int
    26  }
    27  
    28  // RateLimitConfig holds the configuration for both IP-only and IP+UA rate limiting.
    29  type RateLimitConfig struct {
    30  	// IPConfig defines the rate limiting configuration for IP-only checks
    31  	IPConfig LimitConfig
    32  
    33  	// IPUAConfig defines the rate limiting configuration for IP+UA checks
    34  	IPUAConfig LimitConfig
    35  
    36  	// KeyPrefix is the prefix for Redis keys
    37  	KeyPrefix string
    38  }
    39  
    40  // RateLimiter implements a distributed rate limiter using Redis sorted sets (ZSET).
    41  // It maintains sliding windows for both IP-only and IP+UA combinations, where:
    42  //   - Each event is stored in a ZSET with the timestamp as score
    43  //   - Old events (outside the window) are automatically removed
    44  //   - Keys automatically expire after the configured interval
    45  //
    46  // The limiter considers a request to be rate-limited if the number of events
    47  // in either window exceeds their respective MaxEvents.
    48  type RateLimiter struct {
    49  	pool       *redis.Pool
    50  	ipConfig   LimitConfig
    51  	ipuaConfig LimitConfig
    52  	keyPrefix  string
    53  }
    54  
    55  // NewRateLimiter creates a new rate limiter.
    56  func NewRateLimiter(pool *redis.Pool, config RateLimitConfig) *RateLimiter {
    57  	return &RateLimiter{
    58  		pool:       pool,
    59  		ipConfig:   config.IPConfig,
    60  		ipuaConfig: config.IPUAConfig,
    61  		keyPrefix:  config.KeyPrefix,
    62  	}
    63  }
    64  
    65  // generateIPKey creates a Redis key from IP only.
    66  func (rl *RateLimiter) generateIPKey(ip string) string {
    67  	return fmt.Sprintf("%s:%s", rl.keyPrefix, ip)
    68  }
    69  
    70  // generateIPUAKey creates a Redis key from IP and User-Agent.
    71  func (rl *RateLimiter) generateIPUAKey(ip, ua string) string {
    72  	// If User-Agent is empty, use "none" as the value. This allows to distinguish
    73  	// between IP-only keys and IPUA keys with an empty User-Agent.
    74  	if ua == "" {
    75  		ua = "none"
    76  	}
    77  	return fmt.Sprintf("%s:%s:%s", rl.keyPrefix, ip, ua)
    78  }
    79  
    80  // IsLimited checks if the given IP and User-Agent combination should be rate limited.
    81  // It first checks the IP-only limit, then the IP+UA limit if the IP-only check passes.
    82  func (rl *RateLimiter) IsLimited(ip, ua string) (LimitStatus, error) {
    83  	conn := rl.pool.Get()
    84  	defer conn.Close()
    85  
    86  	now := time.Now().UnixMicro()
    87  	ipKey := rl.generateIPKey(ip)
    88  	ipuaKey := rl.generateIPUAKey(ip, ua)
    89  
    90  	// Start pipeline for both checks
    91  	// 1. IP-only check
    92  	conn.Send("ZREMRANGEBYSCORE", ipKey, "-inf", now-rl.ipConfig.Interval.Microseconds())
    93  	conn.Send("ZADD", ipKey, now, strconv.FormatInt(now, 10))
    94  	conn.Send("EXPIRE", ipKey, int64(rl.ipConfig.Interval.Seconds()))
    95  	conn.Send("ZCARD", ipKey)
    96  
    97  	// 2. IP+UA limit check
    98  	conn.Send("ZREMRANGEBYSCORE", ipuaKey, "-inf", now-rl.ipuaConfig.Interval.Microseconds())
    99  	conn.Send("ZADD", ipuaKey, now, strconv.FormatInt(now, 10))
   100  	conn.Send("EXPIRE", ipuaKey, int64(rl.ipuaConfig.Interval.Seconds()))
   101  	conn.Send("ZCARD", ipuaKey)
   102  
   103  	// Flush pipeline
   104  	if err := conn.Flush(); err != nil {
   105  		return LimitStatus{}, fmt.Errorf("failed to flush pipeline: %w", err)
   106  	}
   107  
   108  	// Receive first 3 replies for IP limit (ZREMRANGEBYSCORE, ZADD, EXPIRE)
   109  	for i := 0; i < 3; i++ {
   110  		if _, err := conn.Receive(); err != nil {
   111  			return LimitStatus{}, fmt.Errorf("failed to receive IP limit reply %d: %w", i, err)
   112  		}
   113  	}
   114  
   115  	// Receive IP limit count
   116  	ipCount, err := redis.Int64(conn.Receive())
   117  	if err != nil {
   118  		return LimitStatus{}, fmt.Errorf("failed to receive IP limit count: %w", err)
   119  	}
   120  
   121  	// Check IP-only limit first
   122  	if ipCount > int64(rl.ipConfig.MaxEvents) {
   123  		return LimitStatus{
   124  			IsLimited: true,
   125  			LimitType: "ip",
   126  		}, nil
   127  	}
   128  
   129  	// Receive next 3 replies for IP+UA limit (ZREMRANGEBYSCORE, ZADD, EXPIRE)
   130  	for i := 0; i < 3; i++ {
   131  		if _, err := conn.Receive(); err != nil {
   132  			return LimitStatus{}, fmt.Errorf("failed to receive IP+UA limit reply %d: %w", i, err)
   133  		}
   134  	}
   135  
   136  	// Receive IP+UA limit count
   137  	ipuaCount, err := redis.Int64(conn.Receive())
   138  	if err != nil {
   139  		return LimitStatus{}, fmt.Errorf("failed to receive IP+UA limit count: %w", err)
   140  	}
   141  
   142  	// Check IP+UA limit
   143  	if ipuaCount > int64(rl.ipuaConfig.MaxEvents) {
   144  		return LimitStatus{
   145  			IsLimited: true,
   146  			LimitType: "ipua",
   147  		}, nil
   148  	}
   149  
   150  	return LimitStatus{
   151  		IsLimited: false,
   152  		LimitType: "",
   153  	}, nil
   154  }