github.com/rudderlabs/rudder-go-kit@v0.30.0/throttling/throttling_test.go (about) 1 package throttling 2 3 import ( 4 "context" 5 "fmt" 6 "sync" 7 "sync/atomic" 8 "testing" 9 "time" 10 11 "github.com/ory/dockertest/v3" 12 "github.com/stretchr/testify/require" 13 14 "github.com/rudderlabs/rudder-go-kit/testhelper/rand" 15 ) 16 17 func TestThrottling(t *testing.T) { 18 pool, err := dockertest.NewPool("") 19 require.NoError(t, err) 20 21 type limiterSettings struct { 22 name string 23 limiter *Limiter 24 // concurrency has been introduced because the GCRA algorithms (both in-memory and Redis) tend to lose precision 25 // when the concurrency is too high. Until we fix it or come up with a guard to limit the amount of concurrent 26 // requests, we're limiting the concurrency to X in the tests (to avoid test flakiness). 27 concurrency int 28 } 29 30 var ( 31 ctx = context.Background() 32 rc = bootstrapRedis(ctx, t, pool) 33 limiters = []limiterSettings{ 34 { 35 name: "gcra", 36 limiter: newLimiter(t, WithInMemoryGCRA(0)), 37 concurrency: 100, 38 }, 39 { 40 name: "gcra redis", 41 limiter: newLimiter(t, WithRedisGCRA(rc, 100)), // TODO: this should work properly with burst = 0 as well (i.e. burst = rate) 42 concurrency: 100, 43 }, 44 { 45 name: "sorted sets redis", 46 limiter: newLimiter(t, WithRedisSortedSet(rc)), 47 concurrency: 5000, 48 }, 49 } 50 ) 51 52 flakinessRate := 1 // increase to run the tests multiple times in a row to debug flaky tests 53 for i := 0; i < flakinessRate; i++ { 54 for _, tc := range []testCase{ 55 // avoid rates that are too small (e.g. 10), that's where there is the most flakiness 56 {rate: 500, window: 1}, 57 {rate: 1000, window: 2}, 58 {rate: 2000, window: 3}, 59 } { 60 for _, l := range limiters { 61 t.Run(testName(l.name, tc.rate, tc.window), func(t *testing.T) { 62 expected := tc.rate 63 testLimiter(ctx, t, l.limiter, tc.rate, tc.window, expected, l.concurrency) 64 }) 65 } 66 } 67 } 68 } 69 70 func testLimiter( 71 ctx context.Context, t *testing.T, l *Limiter, rate, window, expected int64, concurrency int, 72 ) { 73 t.Helper() 74 var ( 75 wg sync.WaitGroup 76 passed int64 77 cost int64 = 1 78 key = rand.UniqueString(10) 79 maxRoutines = make(chan struct{}, concurrency) 80 // Time tracking variables 81 startTime int64 82 currentTime int64 83 timeMutex sync.Mutex 84 stop = make(chan struct{}, 1) 85 ) 86 87 run := func() (err error, allowed bool, redisTime time.Duration) { 88 switch { 89 case l.redisSpeaker != nil && l.useGCRA: 90 redisTime, allowed, _, _, err = l.redisGCRA(ctx, cost, rate, window, key) 91 case l.redisSpeaker != nil && !l.useGCRA: 92 redisTime, allowed, _, _, err = l.redisSortedSet(ctx, cost, rate, window, key) 93 default: 94 allowed, _, err = l.Allow(ctx, cost, rate, window, key) 95 } 96 return 97 } 98 if l.useGCRA { // warm up the GCRA algorithms 99 for i := 0; i < int(rate); i++ { 100 _, _, _ = run() 101 } 102 } 103 loop: 104 105 for { 106 select { 107 case <-stop: 108 // To decrease the error margin (mostly introduced for the Redis algorithms) I'm measuring time in two 109 // different ways depending on whether I'm using an in-memory algorithm or a Redis one. 110 // For the Redis algorithms I measure the elapsed time by keeping track of the timestamps returned by 111 // the Redis Lua scripts. 112 // This is because there is a bit of a drift between the time as we measure it here and the time as it 113 // is measured in the Lua scripts. 114 break loop 115 case maxRoutines <- struct{}{}: 116 wg.Add(1) 117 go func() { 118 defer wg.Done() 119 defer func() { <-maxRoutines }() 120 121 err, allowed, redisTime := run() 122 require.NoError(t, err) 123 now := time.Now().UnixNano() 124 125 timeMutex.Lock() 126 defer timeMutex.Unlock() 127 if redisTime > 0 { // Redis limiters have a getTime() function that returns the current time from Redis perspective 128 if redisTime > time.Duration(currentTime) { 129 currentTime = int64(redisTime) 130 } 131 } else if now > currentTime { // in-memory algorithm, let's use time.Now() 132 currentTime = now 133 } 134 if startTime == 0 { 135 startTime = currentTime 136 } 137 if currentTime-startTime > window*int64(time.Second) { 138 select { 139 case stop <- struct{}{}: 140 default: // one signal to stop is enough, don't block 141 } 142 return // do not increment "passed" because we're over the window 143 } 144 if allowed { 145 atomic.AddInt64(&passed, cost) 146 } 147 }() 148 } 149 } 150 151 wg.Wait() 152 153 diff := expected - passed 154 errorMargin := int64(0.1*float64(rate)) + 1 // ~10% error margin 155 if passed < 1 || diff < -errorMargin || diff > errorMargin { 156 t.Errorf("Expected %d, got %d (diff: %d, error margin: %d)", expected, passed, diff, errorMargin) 157 } 158 } 159 160 func TestGCRABurstAsRate(t *testing.T) { 161 pool, err := dockertest.NewPool("") 162 require.NoError(t, err) 163 164 var ( 165 ctx = context.Background() 166 rc = bootstrapRedis(ctx, t, pool) 167 l = newLimiter(t, WithRedisGCRA(rc, 0)) 168 // Configuration variables 169 key = "foo" 170 cost int64 = 1 171 rate int64 = 500 172 window int64 = 1 173 // Expectations 174 passed int64 175 start = time.Now() 176 ) 177 178 for i := int64(0); i < rate*2; i++ { 179 allowed, _, err := l.Allow(ctx, cost, rate, window, key) 180 require.NoError(t, err) 181 if allowed { 182 passed++ 183 } 184 } 185 186 require.GreaterOrEqual(t, passed, rate) 187 require.Less( 188 t, time.Since(start), time.Duration(window*int64(time.Second)), 189 "we should've been able to make the required request in less than the window duration due to the burst setting", 190 ) 191 } 192 193 func TestReturn(t *testing.T) { 194 pool, err := dockertest.NewPool("") 195 require.NoError(t, err) 196 197 type testCase struct { 198 name string 199 limiter *Limiter 200 } 201 202 var ( 203 rate int64 = 10 204 window int64 = 1 205 windowDur = time.Duration(window) * time.Second 206 ctx = context.Background() 207 rc = bootstrapRedis(ctx, t, pool) 208 testCases = []testCase{ 209 { 210 name: "sorted sets redis", 211 limiter: newLimiter(t, WithRedisSortedSet(rc)), 212 }, 213 } 214 ) 215 216 for _, tc := range testCases { 217 t.Run(testName(tc.name, rate, window), func(t *testing.T) { 218 var ( 219 passed int 220 key = rand.UniqueString(10) 221 tokens []func(context.Context) error 222 ) 223 for i := int64(0); i < rate*10; i++ { 224 allowed, returner, err := tc.limiter.Allow(ctx, 1, rate, window, key) 225 require.NoError(t, err) 226 if allowed { 227 passed++ 228 tokens = append(tokens, returner) 229 } 230 } 231 232 require.EqualValues(t, rate, passed) 233 234 allowed, _, err := tc.limiter.Allow(ctx, 1, rate, window, key) 235 require.NoError(t, err) 236 require.False(t, allowed) 237 238 // return one token and try again 239 require.NoError(t, tokens[0](ctx)) 240 require.Eventually(t, func() bool { 241 allowed, returner, err := tc.limiter.Allow(ctx, 1, rate, window, key) 242 return allowed && err == nil && returner != nil 243 }, windowDur/2, time.Millisecond) 244 }) 245 } 246 } 247 248 func TestBadData(t *testing.T) { 249 pool, err := dockertest.NewPool("") 250 require.NoError(t, err) 251 252 var ( 253 ctx = context.Background() 254 rc = bootstrapRedis(ctx, t, pool) 255 limiters = map[string]*Limiter{ 256 "gcra": newLimiter(t, WithInMemoryGCRA(0)), 257 "gcra redis": newLimiter(t, WithRedisGCRA(rc, 0)), 258 "sorted sets redis": newLimiter(t, WithRedisSortedSet(rc)), 259 } 260 ) 261 262 for name, l := range limiters { 263 t.Run(name+" cost", func(t *testing.T) { 264 allowed, ret, err := l.Allow(ctx, 0, 10, 1, "foo") 265 require.False(t, allowed) 266 require.Nil(t, ret) 267 require.Error(t, err, "cost must be greater than 0") 268 }) 269 t.Run(name+" rate", func(t *testing.T) { 270 allowed, ret, err := l.Allow(ctx, 1, 0, 1, "foo") 271 require.False(t, allowed) 272 require.Nil(t, ret) 273 require.Error(t, err, "rate must be greater than 0") 274 }) 275 t.Run(name+" window", func(t *testing.T) { 276 allowed, ret, err := l.Allow(ctx, 1, 10, 0, "foo") 277 require.False(t, allowed) 278 require.Nil(t, ret) 279 require.Error(t, err, "window must be greater than 0") 280 }) 281 t.Run(name+" key", func(t *testing.T) { 282 allowed, ret, err := l.Allow(ctx, 1, 10, 1, "") 283 require.False(t, allowed) 284 require.Nil(t, ret) 285 require.Error(t, err, "key must not be empty") 286 }) 287 } 288 } 289 290 func TestRetryAfter(t *testing.T) { 291 pool, err := dockertest.NewPool("") 292 require.NoError(t, err) 293 294 type testCase struct { 295 name string 296 limiter *Limiter 297 rate int64 298 window int64 299 runFor time.Duration 300 warmUp bool 301 expectedAllowedCount int 302 expectedAllowedCountDelta float64 303 expectedSleepsCount int 304 expectedSleepsCountDelta float64 305 } 306 307 var ( 308 ctx = context.Background() 309 rc = bootstrapRedis(ctx, t, pool) 310 testCases = []testCase{ 311 { 312 name: "gcra", 313 limiter: newLimiter(t, WithInMemoryGCRA(0)), 314 rate: 2, 315 window: 1, 316 runFor: 3 * time.Second, 317 warmUp: true, 318 expectedAllowedCount: 6, 319 expectedAllowedCountDelta: 3, 320 expectedSleepsCount: 3, 321 expectedSleepsCountDelta: 2, 322 }, 323 { 324 name: "gcra redis", 325 limiter: newLimiter(t, WithRedisGCRA(rc, 100)), 326 rate: 2, 327 window: 1, 328 runFor: 3 * time.Second, 329 warmUp: true, 330 expectedAllowedCount: 6, 331 expectedAllowedCountDelta: 3, 332 expectedSleepsCount: 3, 333 expectedSleepsCountDelta: 2, 334 }, 335 { 336 name: "sorted sets redis", 337 limiter: newLimiter(t, WithRedisSortedSet(rc)), 338 rate: 2, 339 window: 1, 340 runFor: 3 * time.Second, 341 warmUp: false, 342 expectedAllowedCount: 6, 343 expectedAllowedCountDelta: 0, // this algorithm is the most precise but requires more memory on Redis 344 expectedSleepsCount: 3, 345 expectedSleepsCountDelta: 0, // this algorithm is the most precise but requires more memory on Redis 346 }, 347 } 348 ) 349 350 flakinessRate := 1 // increase to run the tests multiple times in a row to debug flaky tests 351 for i := 0; i < flakinessRate; i++ { 352 for _, tc := range testCases { 353 t.Run(testName(tc.name, tc.rate, tc.window), func(t *testing.T) { 354 timeout := time.NewTimer(tc.runFor) 355 t.Cleanup(func() { 356 _ = timeout.Stop 357 }) 358 359 if tc.warmUp { 360 timer := time.NewTimer(time.Duration(tc.window) * time.Second) 361 warmUp: 362 for { 363 _, _, err := tc.limiter.Allow(ctx, 1, tc.rate, tc.window, t.Name()) 364 require.NoError(t, err) 365 select { 366 case <-timer.C: 367 break warmUp 368 default: 369 } 370 } 371 } 372 373 var ( 374 allowedCount int 375 sleepsCount int 376 ) 377 378 loop: 379 for { 380 allowed, retryAfter, _, err := tc.limiter.AllowAfter(ctx, 1, tc.rate, tc.window, t.Name()) 381 require.NoError(t, err) 382 383 t.Logf("allowed: %v, retryAfter: %v", allowed, retryAfter) 384 385 if allowed { 386 require.EqualValues(t, 0, retryAfter) 387 allowedCount++ 388 } else { 389 require.Greater(t, retryAfter, int64(0)) 390 sleepsCount++ 391 } 392 393 select { 394 case <-ctx.Done(): 395 break loop 396 case <-timeout.C: 397 break loop 398 case <-time.After(retryAfter): 399 t.Logf("slept for %v", retryAfter) 400 } 401 } 402 403 require.InDelta(t, tc.expectedAllowedCount, allowedCount, tc.expectedAllowedCountDelta) 404 require.InDelta(t, tc.expectedSleepsCount, sleepsCount, tc.expectedSleepsCountDelta) 405 }) 406 } 407 } 408 } 409 410 func testName(name string, rate, window int64) string { 411 return fmt.Sprintf("%s/%d tokens per %ds", name, rate, window) 412 }