github.com/rudderlabs/rudder-go-kit@v0.30.0/throttling/throttling.go (about) 1 package throttling 2 3 import ( 4 "context" 5 _ "embed" 6 "fmt" 7 "strconv" 8 "strings" 9 "time" 10 11 "github.com/go-redis/redis/v8" 12 13 "github.com/rudderlabs/rudder-go-kit/stats" 14 ) 15 16 /* 17 TODOs: 18 * guard against concurrency? according to benchmarks, Redis performs better if we have no more than 16 routines 19 * see https://github.com/rudderlabs/redis-throttling-playground/blob/main/Benchmarks.md#best-concurrency-setting-with-sortedset---save-1-1-and---appendonly-yes 20 */ 21 22 var ( 23 //go:embed lua/gcra.lua 24 gcraLua string 25 gcraRedisScript *redis.Script 26 //go:embed lua/sortedset.lua 27 sortedSetLua string 28 sortedSetScript *redis.Script 29 ) 30 31 func init() { 32 gcraRedisScript = redis.NewScript(gcraLua) 33 sortedSetScript = redis.NewScript(sortedSetLua) 34 } 35 36 type redisSpeaker interface { 37 redis.Scripter 38 redisSortedSetRemover 39 } 40 41 type statsCollector interface { 42 NewTaggedStat(name, statType string, tags stats.Tags) stats.Measurement 43 } 44 45 type Limiter struct { 46 // for Redis configurations 47 // a default redisSpeaker should always be provided for Redis configurations 48 redisSpeaker redisSpeaker 49 50 // for in-memory configurations 51 gcra *gcra 52 53 // other flags 54 useGCRA bool 55 gcraBurst int64 56 57 // metrics 58 statsCollector statsCollector 59 } 60 61 func New(options ...Option) (*Limiter, error) { 62 rl := &Limiter{} 63 for i := range options { 64 options[i](rl) 65 } 66 if rl.statsCollector == nil { 67 rl.statsCollector = stats.Default 68 } 69 if rl.redisSpeaker != nil { 70 return rl, nil 71 } 72 // Default to in-memory GCRA 73 rl.gcra = &gcra{} 74 rl.useGCRA = true 75 return rl, nil 76 } 77 78 // Allow returns true if the limit is not exceeded, false otherwise. 79 func (l *Limiter) Allow(ctx context.Context, cost, rate, window int64, key string) ( 80 bool, func(context.Context) error, error, 81 ) { 82 allowed, _, tr, err := l.allow(ctx, cost, rate, window, key) 83 return allowed, tr, err 84 } 85 86 // AllowAfter returns true if the limit is not exceeded, false otherwise. 87 // Additionally, it returns the time.Duration until the next allowed request. 88 func (l *Limiter) AllowAfter(ctx context.Context, cost, rate, window int64, key string) ( 89 bool, time.Duration, func(context.Context) error, error, 90 ) { 91 return l.allow(ctx, cost, rate, window, key) 92 } 93 94 func (l *Limiter) allow(ctx context.Context, cost, rate, window int64, key string) ( 95 bool, time.Duration, func(context.Context) error, error, 96 ) { 97 if cost < 1 { 98 return false, 0, nil, fmt.Errorf("cost must be greater than 0") 99 } 100 if rate < 1 { 101 return false, 0, nil, fmt.Errorf("rate must be greater than 0") 102 } 103 if window < 1 { 104 return false, 0, nil, fmt.Errorf("window must be greater than 0") 105 } 106 if key == "" { 107 return false, 0, nil, fmt.Errorf("key must not be empty") 108 } 109 110 if l.redisSpeaker != nil { 111 if l.useGCRA { 112 defer l.getTimer(key, "redis-gcra", rate, window)() 113 _, allowed, retryAfter, tr, err := l.redisGCRA(ctx, cost, rate, window, key) 114 return allowed, retryAfter, tr, err 115 } 116 117 defer l.getTimer(key, "redis-sorted-set", rate, window)() 118 _, allowed, retryAfter, tr, err := l.redisSortedSet(ctx, cost, rate, window, key) 119 return allowed, retryAfter, tr, err 120 } 121 122 defer l.getTimer(key, "gcra", rate, window)() 123 allowed, retryAfter, tr, err := l.gcraLimit(ctx, cost, rate, window, key) 124 return allowed, retryAfter, tr, err 125 } 126 127 func (l *Limiter) redisSortedSet(ctx context.Context, cost, rate, window int64, key string) ( 128 time.Duration, bool, time.Duration, func(context.Context) error, error, 129 ) { 130 res, err := sortedSetScript.Run(ctx, l.redisSpeaker, []string{key}, cost, rate, window).Result() 131 if err != nil { 132 return 0, false, 0, nil, fmt.Errorf("could not run SortedSet Redis script: %v", err) 133 } 134 135 result, ok := res.([]interface{}) 136 if !ok { 137 return 0, false, 0, nil, fmt.Errorf("unexpected result from SortedSet Redis script of type %T: %v", res, res) 138 } 139 if len(result) != 3 { 140 return 0, false, 0, nil, fmt.Errorf("unexpected result from SortedSet Redis script of length %d: %+v", len(result), result) 141 } 142 143 t, ok := result[0].(int64) 144 if !ok { 145 return 0, false, 0, nil, fmt.Errorf("unexpected result[0] from SortedSet Redis script of type %T: %v", result[0], result[0]) 146 } 147 redisTime := time.Duration(t) * time.Microsecond 148 149 members, ok := result[1].(string) 150 if !ok { 151 return redisTime, false, 0, nil, fmt.Errorf("unexpected result[1] from SortedSet Redis script of type %T: %v", result[1], result[1]) 152 } 153 if members == "" { // limit exceeded 154 retryAfter, ok := result[2].(int64) 155 if !ok { 156 return redisTime, false, 0, nil, fmt.Errorf("unexpected result[2] from SortedSet Redis script of type %T: %v", result[2], result[2]) 157 } 158 return redisTime, false, time.Duration(retryAfter) * time.Microsecond, nil, nil 159 } 160 161 r := &sortedSetRedisReturn{ 162 key: key, 163 members: strings.Split(members, ","), 164 remover: l.redisSpeaker, 165 } 166 return redisTime, true, 0, r.Return, nil 167 } 168 169 func (l *Limiter) redisGCRA(ctx context.Context, cost, rate, window int64, key string) ( 170 time.Duration, bool, time.Duration, func(context.Context) error, error, 171 ) { 172 burst := rate 173 if l.gcraBurst > 0 { 174 burst = l.gcraBurst 175 } 176 res, err := gcraRedisScript.Run(ctx, l.redisSpeaker, []string{key}, burst, rate, window, cost).Result() 177 if err != nil { 178 return 0, false, 0, nil, fmt.Errorf("could not run GCRA Redis script: %v", err) 179 } 180 181 result, ok := res.([]any) 182 if !ok { 183 return 0, false, 0, nil, fmt.Errorf("unexpected result from GCRA Redis script of type %T: %v", res, res) 184 } 185 if len(result) != 5 { 186 return 0, false, 0, nil, fmt.Errorf("unexpected result from GCRA Redis scrip of length %d: %+v", len(result), result) 187 } 188 189 t, ok := result[0].(int64) 190 if !ok { 191 return 0, false, 0, nil, fmt.Errorf("unexpected result[0] from GCRA Redis script of type %T: %v", result[0], result[0]) 192 } 193 redisTime := time.Duration(t) * time.Microsecond 194 195 allowed, ok := result[1].(int64) 196 if !ok { 197 return redisTime, false, 0, nil, fmt.Errorf("unexpected result[1] from GCRA Redis script of type %T: %v", result[1], result[1]) 198 } 199 if allowed < 1 { // limit exceeded 200 retryAfter, ok := result[3].(int64) 201 if !ok { 202 return redisTime, false, 0, nil, fmt.Errorf("unexpected result[3] from GCRA Redis script of type %T: %v", result[3], result[3]) 203 } 204 return redisTime, false, time.Duration(retryAfter) * time.Microsecond, nil, nil 205 } 206 207 r := &unsupportedReturn{} 208 return redisTime, true, 0, r.Return, nil 209 } 210 211 func (l *Limiter) gcraLimit(ctx context.Context, cost, rate, window int64, key string) ( 212 bool, time.Duration, func(context.Context) error, error, 213 ) { 214 burst := rate 215 if l.gcraBurst > 0 { 216 burst = l.gcraBurst 217 } 218 allowed, retryAfter, err := l.gcra.limit(ctx, key, cost, burst, rate, window) 219 if err != nil { 220 return false, 0, nil, fmt.Errorf("could not limit: %w", err) 221 } 222 if !allowed { 223 return false, retryAfter, nil, nil // limit exceeded 224 } 225 r := &unsupportedReturn{} 226 return true, 0, r.Return, nil 227 } 228 229 func (l *Limiter) getTimer(key, algo string, rate, window int64) func() { 230 m := l.statsCollector.NewTaggedStat("throttling", stats.TimerType, stats.Tags{ 231 "key": key, 232 "algo": algo, 233 "rate": strconv.FormatInt(rate, 10), 234 "window": strconv.FormatInt(window, 10), 235 }) 236 start := time.Now() 237 return func() { 238 m.Since(start) 239 } 240 }