github.com/livekit/protocol@v1.16.1-0.20240517185851-47e4c6bba773/utils/rate_test.go (about) 1 // Copyright (c) 2016,2020 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 // 21 // SOURCE: https://github.com/uber-go/ratelimit/blob/main/ratelimit_test.go 22 // EDIT: slight modification to allow setting rate limit on the fly 23 // SCOPE: LeakyBucket 24 package utils 25 26 import ( 27 "sync" 28 "testing" 29 "time" 30 31 "go.uber.org/atomic" 32 33 "github.com/benbjohnson/clock" 34 "github.com/stretchr/testify/assert" 35 "github.com/stretchr/testify/require" 36 ) 37 38 const UnstableTest = "UNSTABLE TEST" 39 40 // options from upstream, but stripped these 41 // Note: This file is inspired by: 42 // https://github.com/prashantv/go-bench/blob/master/ratelimit 43 44 // Limiter is used to rate-limit some process, possibly across goroutines. 45 // The process is expected to call Take() before every iteration, which 46 // may block to throttle the goroutine. 47 type Limiter interface { 48 // Take should block to make sure that the RPS is met. 49 Take() time.Time 50 } 51 52 // config configures a limiter. 53 type config struct { 54 clock Clock 55 slack int 56 per time.Duration 57 } 58 59 // buildConfig combines defaults with options. 60 func buildConfig(opts []Option) config { 61 c := config{ 62 clock: clock.New(), 63 slack: 10, 64 per: time.Second, 65 } 66 67 for _, opt := range opts { 68 opt.apply(&c) 69 } 70 return c 71 } 72 73 // Option configures a Limiter. 74 type Option interface { 75 apply(*config) 76 } 77 78 type clockOption struct { 79 clock Clock 80 } 81 82 func (o clockOption) apply(c *config) { 83 c.clock = o.clock 84 } 85 86 // WithClock returns an option for ratelimit.New that provides an alternate 87 // Clock implementation, typically a mock Clock for testing. 88 func WithClock(clock Clock) Option { 89 return clockOption{clock: clock} 90 } 91 92 type slackOption int 93 94 func (o slackOption) apply(c *config) { 95 c.slack = int(o) 96 } 97 98 // WithoutSlack configures the limiter to be strict and not to accumulate 99 // previously "unspent" requests for future bursts of traffic. 100 var WithoutSlack Option = slackOption(0) 101 102 // WithSlack configures custom slack. 103 // Slack allows the limiter to accumulate "unspent" requests 104 // for future bursts of traffic. 105 func WithSlack(slack int) Option { 106 return slackOption(slack) 107 } 108 109 type perOption time.Duration 110 111 func (p perOption) apply(c *config) { 112 c.per = time.Duration(p) 113 } 114 115 // Per allows configuring limits for different time windows. 116 // 117 // The default window is one second, so New(100) produces a one hundred per 118 // second (100 Hz) rate limiter. 119 // 120 // New(2, Per(60*time.Second)) creates a 2 per minute rate limiter. 121 func Per(per time.Duration) Option { 122 return perOption(per) 123 } 124 125 type testRunner interface { 126 // createLimiter builds a limiter with given options. 127 createLimiter(int, ...Option) Limiter 128 // takeOnceAfter attempts to Take at a specific time. 129 takeOnceAfter(time.Duration, Limiter) 130 // startTaking tries to Take() on passed in limiters in a loop/goroutine. 131 startTaking(rls ...Limiter) 132 // assertCountAt asserts the limiters have Taken() a number of times at the given time. 133 // It's a thin wrapper around afterFunc to reduce boilerplate code. 134 assertCountAt(d time.Duration, count int) 135 // afterFunc executes a func at a given time. 136 // not using clock.AfterFunc because andres-erbsen/clock misses a nap there. 137 afterFunc(d time.Duration, fn func()) 138 // some tests want raw access to the clock. 139 getClock() *clock.Mock 140 } 141 142 type runnerImpl struct { 143 t *testing.T 144 145 clock *clock.Mock 146 constructor func(int, ...Option) Limiter 147 count atomic.Int32 148 // maxDuration is the time we need to move into the future for a test. 149 // It's populated automatically based on assertCountAt/afterFunc. 150 maxDuration time.Duration 151 doneCh chan struct{} 152 wg sync.WaitGroup 153 } 154 155 func runTest(t *testing.T, fn func(testRunner)) { 156 impls := []struct { 157 name string 158 constructor func(int, ...Option) Limiter 159 }{ 160 { 161 name: "mutex", 162 constructor: func(rate int, opts ...Option) Limiter { 163 config := buildConfig(opts) 164 perRequest := config.per / time.Duration(rate) 165 cfg := leakyBucketConfig{ 166 perRequest: perRequest, 167 maxSlack: -1 * time.Duration(config.slack) * perRequest, 168 } 169 l := &LeakyBucket{ 170 clock: config.clock, 171 } 172 l.cfg.Store(&cfg) 173 return l 174 }, 175 }, 176 } 177 178 for _, tt := range impls { 179 t.Run(tt.name, func(t *testing.T) { 180 // Set a non-default time.Time since some limiters (int64 in particular) use 181 // the default value as "non-initialized" state. 182 clockMock := clock.NewMock() 183 clockMock.Set(time.Now()) 184 r := runnerImpl{ 185 t: t, 186 clock: clockMock, 187 constructor: tt.constructor, 188 doneCh: make(chan struct{}), 189 } 190 defer close(r.doneCh) 191 defer r.wg.Wait() 192 193 fn(&r) 194 r.clock.Add(r.maxDuration) 195 }) 196 } 197 } 198 199 // createLimiter builds a limiter with given options. 200 func (r *runnerImpl) createLimiter(rate int, opts ...Option) Limiter { 201 opts = append(opts, WithClock(r.clock)) 202 return r.constructor(rate, opts...) 203 } 204 205 func (r *runnerImpl) getClock() *clock.Mock { 206 return r.clock 207 } 208 209 // startTaking tries to Take() on passed in limiters in a loop/goroutine. 210 func (r *runnerImpl) startTaking(rls ...Limiter) { 211 r.goWait(func() { 212 for { 213 for _, rl := range rls { 214 rl.Take() 215 } 216 r.count.Inc() 217 select { 218 case <-r.doneCh: 219 return 220 default: 221 } 222 } 223 }) 224 } 225 226 // takeOnceAfter attempts to Take at a specific time. 227 func (r *runnerImpl) takeOnceAfter(d time.Duration, rl Limiter) { 228 r.wg.Add(1) 229 r.afterFunc(d, func() { 230 rl.Take() 231 r.count.Inc() 232 r.wg.Done() 233 }) 234 } 235 236 // assertCountAt asserts the limiters have Taken() a number of times at a given time. 237 func (r *runnerImpl) assertCountAt(d time.Duration, count int) { 238 r.wg.Add(1) 239 r.afterFunc(d, func() { 240 assert.Equal(r.t, int32(count), r.count.Load(), "count not as expected") 241 r.wg.Done() 242 }) 243 } 244 245 // afterFunc executes a func at a given time. 246 func (r *runnerImpl) afterFunc(d time.Duration, fn func()) { 247 if d > r.maxDuration { 248 r.maxDuration = d 249 } 250 251 r.goWait(func() { 252 select { 253 case <-r.doneCh: 254 return 255 case <-r.clock.After(d): 256 } 257 fn() 258 }) 259 } 260 261 // goWait runs a function in a goroutine and makes sure the goroutine was scheduled. 262 func (r *runnerImpl) goWait(fn func()) { 263 wg := sync.WaitGroup{} 264 wg.Add(1) 265 go func() { 266 wg.Done() 267 fn() 268 }() 269 wg.Wait() 270 } 271 272 func TestRateLimiter(t *testing.T) { 273 t.Parallel() 274 runTest(t, func(r testRunner) { 275 rl := r.createLimiter(100, WithoutSlack) 276 277 // Create copious counts concurrently. 278 r.startTaking(rl) 279 r.startTaking(rl) 280 r.startTaking(rl) 281 r.startTaking(rl) 282 283 r.assertCountAt(1*time.Second, 100) 284 r.assertCountAt(2*time.Second, 200) 285 r.assertCountAt(3*time.Second, 300) 286 }) 287 } 288 289 func TestDelayedRateLimiter(t *testing.T) { 290 t.Skip(UnstableTest) 291 t.Parallel() 292 runTest(t, func(r testRunner) { 293 slow := r.createLimiter(10, WithoutSlack) 294 fast := r.createLimiter(100, WithoutSlack) 295 296 r.startTaking(slow, fast) 297 298 r.afterFunc(20*time.Second, func() { 299 r.startTaking(fast) 300 r.startTaking(fast) 301 r.startTaking(fast) 302 r.startTaking(fast) 303 }) 304 305 r.assertCountAt(30*time.Second, 1200) 306 }) 307 } 308 309 func TestPer(t *testing.T) { 310 t.Parallel() 311 runTest(t, func(r testRunner) { 312 rl := r.createLimiter(7, WithoutSlack, Per(time.Minute)) 313 314 r.startTaking(rl) 315 r.startTaking(rl) 316 317 r.assertCountAt(1*time.Second, 1) 318 r.assertCountAt(1*time.Minute, 8) 319 r.assertCountAt(2*time.Minute, 15) 320 }) 321 } 322 323 // TestInitial verifies that the initial sequence is scheduled as expected. 324 func TestInitial(t *testing.T) { 325 t.Parallel() 326 tests := []struct { 327 msg string 328 opts []Option 329 }{ 330 { 331 msg: "With Slack", 332 }, 333 { 334 msg: "Without Slack", 335 opts: []Option{WithoutSlack}, 336 }, 337 } 338 339 for _, tt := range tests { 340 t.Run(tt.msg, func(t *testing.T) { 341 runTest(t, func(r testRunner) { 342 rl := r.createLimiter(10, tt.opts...) 343 344 var ( 345 clk = r.getClock() 346 prev = clk.Now() 347 348 results = make(chan time.Time) 349 have []time.Duration 350 startWg sync.WaitGroup 351 ) 352 startWg.Add(3) 353 354 for i := 0; i < 3; i++ { 355 go func() { 356 startWg.Done() 357 results <- rl.Take() 358 }() 359 } 360 361 startWg.Wait() 362 clk.Add(time.Second) 363 364 for i := 0; i < 3; i++ { 365 ts := <-results 366 have = append(have, ts.Sub(prev)) 367 prev = ts 368 } 369 370 assert.Equal(t, 371 []time.Duration{ 372 0, 373 time.Millisecond * 100, 374 time.Millisecond * 100, 375 }, 376 have, 377 "bad timestamps for inital takes", 378 ) 379 }) 380 }) 381 } 382 } 383 384 func TestMaxSlack(t *testing.T) { 385 t.Parallel() 386 runTest(t, func(r testRunner) { 387 rl := r.createLimiter(1, WithSlack(1)) 388 389 r.takeOnceAfter(time.Nanosecond, rl) 390 r.takeOnceAfter(2*time.Second+1*time.Nanosecond, rl) 391 r.takeOnceAfter(2*time.Second+2*time.Nanosecond, rl) 392 r.takeOnceAfter(2*time.Second+3*time.Nanosecond, rl) 393 r.takeOnceAfter(2*time.Second+4*time.Nanosecond, rl) 394 395 r.assertCountAt(3*time.Second, 3) 396 r.assertCountAt(10*time.Second, 5) 397 }) 398 } 399 400 func TestSlack(t *testing.T) { 401 t.Parallel() 402 // To simulate slack, we combine two limiters. 403 // - First, we start a single goroutine with both of them, 404 // during this time the slow limiter will dominate, 405 // and allow the fast limiter to accumulate slack. 406 // - After 2 seconds, we start another goroutine with 407 // only the faster limiter. This will allow it to max out, 408 // and consume all the slack. 409 // - After 3 seconds, we look at the final result, and we expect, 410 // a sum of: 411 // - slower limiter running for 3 seconds 412 // - faster limiter running for 1 second 413 // - slack accumulated by the faster limiter during the two seconds. 414 // it was blocked by slower limiter. 415 tests := []struct { 416 msg string 417 opt []Option 418 want int 419 }{ 420 { 421 msg: "no option, defaults to 10", 422 // 2*10 + 1*100 + 1*10 (slack) 423 want: 130, 424 }, 425 { 426 msg: "slack of 10, like default", 427 opt: []Option{WithSlack(10)}, 428 // 2*10 + 1*100 + 1*10 (slack) 429 want: 130, 430 }, 431 { 432 msg: "slack of 20", 433 opt: []Option{WithSlack(20)}, 434 // 2*10 + 1*100 + 1*20 (slack) 435 want: 140, 436 }, 437 { 438 // Note this is bigger then the rate of the limiter. 439 msg: "slack of 150", 440 opt: []Option{WithSlack(150)}, 441 // 2*10 + 1*100 + 1*150 (slack) 442 want: 270, 443 }, 444 { 445 msg: "no option, defaults to 10, with per", 446 // 2*(10*2) + 1*(100*2) + 1*10 (slack) 447 opt: []Option{Per(500 * time.Millisecond)}, 448 want: 230, 449 }, 450 { 451 msg: "slack of 10, like default, with per", 452 opt: []Option{WithSlack(10), Per(500 * time.Millisecond)}, 453 // 2*(10*2) + 1*(100*2) + 1*10 (slack) 454 want: 230, 455 }, 456 { 457 msg: "slack of 20, with per", 458 opt: []Option{WithSlack(20), Per(500 * time.Millisecond)}, 459 // 2*(10*2) + 1*(100*2) + 1*20 (slack) 460 want: 240, 461 }, 462 { 463 // Note this is bigger then the rate of the limiter. 464 msg: "slack of 150, with per", 465 opt: []Option{WithSlack(150), Per(500 * time.Millisecond)}, 466 // 2*(10*2) + 1*(100*2) + 1*150 (slack) 467 want: 370, 468 }, 469 } 470 471 for _, tt := range tests { 472 cfg := buildConfig(tt.opt) 473 if cfg.slack >= 100 { 474 t.Skip(UnstableTest) 475 } 476 477 t.Run(tt.msg, func(t *testing.T) { 478 runTest(t, func(r testRunner) { 479 slow := r.createLimiter(10, WithoutSlack) 480 fast := r.createLimiter(100, tt.opt...) 481 482 r.startTaking(slow, fast) 483 484 r.afterFunc(2*time.Second, func() { 485 r.startTaking(fast) 486 r.startTaking(fast) 487 }) 488 489 // limiter with 10hz dominates here - we're always at 10. 490 r.assertCountAt(1*time.Second, 10) 491 r.assertCountAt(3*time.Second, tt.want) 492 }) 493 }) 494 } 495 } 496 497 func TestSetRateLimitOnTheFly(t *testing.T) { 498 runTest(t, func(r testRunner) { 499 // Set rate to 1hz 500 limiter, ok := r.createLimiter(1, WithoutSlack).(*LeakyBucket) 501 if !ok { 502 t.Skip("Update is not supported") 503 } 504 505 r.startTaking(limiter) 506 r.assertCountAt(time.Second, 2) 507 508 r.getClock().Add(time.Second) 509 r.assertCountAt(time.Second, 3) 510 511 // increase to 2hz 512 limiter.Update(2, 0) 513 r.getClock().Add(time.Second) 514 r.assertCountAt(time.Second, 4) // <- delayed due to paying sleepFor debt 515 r.getClock().Add(time.Second) 516 r.assertCountAt(time.Second, 6) 517 518 // reduce to 1hz again 519 limiter.Update(1, 0) 520 r.getClock().Add(time.Second) 521 r.assertCountAt(time.Second, 7) 522 r.getClock().Add(time.Second) 523 r.assertCountAt(time.Second, 8) 524 525 slack := 3 526 require.GreaterOrEqual(t, limiter.sleepFor, time.Duration(0)) 527 limiter.Update(1, slack) 528 r.getClock().Add(time.Second * time.Duration(slack)) 529 r.assertCountAt(time.Second, 8+slack) 530 }) 531 }