github.com/livekit/protocol@v1.16.1-0.20240517185851-47e4c6bba773/utils/rate_test.go (about)

     1  // Copyright (c) 2016,2020 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  //
    21  // SOURCE: https://github.com/uber-go/ratelimit/blob/main/ratelimit_test.go
    22  // EDIT: slight modification to allow setting rate limit on the fly
    23  // SCOPE: LeakyBucket
    24  package utils
    25  
    26  import (
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	"go.uber.org/atomic"
    32  
    33  	"github.com/benbjohnson/clock"
    34  	"github.com/stretchr/testify/assert"
    35  	"github.com/stretchr/testify/require"
    36  )
    37  
    38  const UnstableTest = "UNSTABLE TEST"
    39  
// The option types below are taken from upstream (uber-go/ratelimit) and
// trimmed down to what these tests need.
// Note: this file is inspired by:
// https://github.com/prashantv/go-bench/blob/master/ratelimit
    43  
    44  // Limiter is used to rate-limit some process, possibly across goroutines.
    45  // The process is expected to call Take() before every iteration, which
    46  // may block to throttle the goroutine.
    47  type Limiter interface {
    48  	// Take should block to make sure that the RPS is met.
    49  	Take() time.Time
    50  }
    51  
    52  // config configures a limiter.
    53  type config struct {
    54  	clock Clock
    55  	slack int
    56  	per   time.Duration
    57  }
    58  
    59  // buildConfig combines defaults with options.
    60  func buildConfig(opts []Option) config {
    61  	c := config{
    62  		clock: clock.New(),
    63  		slack: 10,
    64  		per:   time.Second,
    65  	}
    66  
    67  	for _, opt := range opts {
    68  		opt.apply(&c)
    69  	}
    70  	return c
    71  }
    72  
    73  // Option configures a Limiter.
    74  type Option interface {
    75  	apply(*config)
    76  }
    77  
    78  type clockOption struct {
    79  	clock Clock
    80  }
    81  
    82  func (o clockOption) apply(c *config) {
    83  	c.clock = o.clock
    84  }
    85  
    86  // WithClock returns an option for ratelimit.New that provides an alternate
    87  // Clock implementation, typically a mock Clock for testing.
    88  func WithClock(clock Clock) Option {
    89  	return clockOption{clock: clock}
    90  }
    91  
    92  type slackOption int
    93  
    94  func (o slackOption) apply(c *config) {
    95  	c.slack = int(o)
    96  }
    97  
    98  // WithoutSlack configures the limiter to be strict and not to accumulate
    99  // previously "unspent" requests for future bursts of traffic.
   100  var WithoutSlack Option = slackOption(0)
   101  
   102  // WithSlack configures custom slack.
   103  // Slack allows the limiter to accumulate "unspent" requests
   104  // for future bursts of traffic.
   105  func WithSlack(slack int) Option {
   106  	return slackOption(slack)
   107  }
   108  
   109  type perOption time.Duration
   110  
   111  func (p perOption) apply(c *config) {
   112  	c.per = time.Duration(p)
   113  }
   114  
   115  // Per allows configuring limits for different time windows.
   116  //
   117  // The default window is one second, so New(100) produces a one hundred per
   118  // second (100 Hz) rate limiter.
   119  //
   120  // New(2, Per(60*time.Second)) creates a 2 per minute rate limiter.
   121  func Per(per time.Duration) Option {
   122  	return perOption(per)
   123  }
   124  
   125  type testRunner interface {
   126  	// createLimiter builds a limiter with given options.
   127  	createLimiter(int, ...Option) Limiter
   128  	// takeOnceAfter attempts to Take at a specific time.
   129  	takeOnceAfter(time.Duration, Limiter)
   130  	// startTaking tries to Take() on passed in limiters in a loop/goroutine.
   131  	startTaking(rls ...Limiter)
   132  	// assertCountAt asserts the limiters have Taken() a number of times at the given time.
   133  	// It's a thin wrapper around afterFunc to reduce boilerplate code.
   134  	assertCountAt(d time.Duration, count int)
   135  	// afterFunc executes a func at a given time.
   136  	// not using clock.AfterFunc because andres-erbsen/clock misses a nap there.
   137  	afterFunc(d time.Duration, fn func())
   138  	// some tests want raw access to the clock.
   139  	getClock() *clock.Mock
   140  }
   141  
   142  type runnerImpl struct {
   143  	t *testing.T
   144  
   145  	clock       *clock.Mock
   146  	constructor func(int, ...Option) Limiter
   147  	count       atomic.Int32
   148  	// maxDuration is the time we need to move into the future for a test.
   149  	// It's populated automatically based on assertCountAt/afterFunc.
   150  	maxDuration time.Duration
   151  	doneCh      chan struct{}
   152  	wg          sync.WaitGroup
   153  }
   154  
   155  func runTest(t *testing.T, fn func(testRunner)) {
   156  	impls := []struct {
   157  		name        string
   158  		constructor func(int, ...Option) Limiter
   159  	}{
   160  		{
   161  			name: "mutex",
   162  			constructor: func(rate int, opts ...Option) Limiter {
   163  				config := buildConfig(opts)
   164  				perRequest := config.per / time.Duration(rate)
   165  				cfg := leakyBucketConfig{
   166  					perRequest: perRequest,
   167  					maxSlack:   -1 * time.Duration(config.slack) * perRequest,
   168  				}
   169  				l := &LeakyBucket{
   170  					clock: config.clock,
   171  				}
   172  				l.cfg.Store(&cfg)
   173  				return l
   174  			},
   175  		},
   176  	}
   177  
   178  	for _, tt := range impls {
   179  		t.Run(tt.name, func(t *testing.T) {
   180  			// Set a non-default time.Time since some limiters (int64 in particular) use
   181  			// the default value as "non-initialized" state.
   182  			clockMock := clock.NewMock()
   183  			clockMock.Set(time.Now())
   184  			r := runnerImpl{
   185  				t:           t,
   186  				clock:       clockMock,
   187  				constructor: tt.constructor,
   188  				doneCh:      make(chan struct{}),
   189  			}
   190  			defer close(r.doneCh)
   191  			defer r.wg.Wait()
   192  
   193  			fn(&r)
   194  			r.clock.Add(r.maxDuration)
   195  		})
   196  	}
   197  }
   198  
   199  // createLimiter builds a limiter with given options.
   200  func (r *runnerImpl) createLimiter(rate int, opts ...Option) Limiter {
   201  	opts = append(opts, WithClock(r.clock))
   202  	return r.constructor(rate, opts...)
   203  }
   204  
   205  func (r *runnerImpl) getClock() *clock.Mock {
   206  	return r.clock
   207  }
   208  
   209  // startTaking tries to Take() on passed in limiters in a loop/goroutine.
   210  func (r *runnerImpl) startTaking(rls ...Limiter) {
   211  	r.goWait(func() {
   212  		for {
   213  			for _, rl := range rls {
   214  				rl.Take()
   215  			}
   216  			r.count.Inc()
   217  			select {
   218  			case <-r.doneCh:
   219  				return
   220  			default:
   221  			}
   222  		}
   223  	})
   224  }
   225  
   226  // takeOnceAfter attempts to Take at a specific time.
   227  func (r *runnerImpl) takeOnceAfter(d time.Duration, rl Limiter) {
   228  	r.wg.Add(1)
   229  	r.afterFunc(d, func() {
   230  		rl.Take()
   231  		r.count.Inc()
   232  		r.wg.Done()
   233  	})
   234  }
   235  
   236  // assertCountAt asserts the limiters have Taken() a number of times at a given time.
   237  func (r *runnerImpl) assertCountAt(d time.Duration, count int) {
   238  	r.wg.Add(1)
   239  	r.afterFunc(d, func() {
   240  		assert.Equal(r.t, int32(count), r.count.Load(), "count not as expected")
   241  		r.wg.Done()
   242  	})
   243  }
   244  
   245  // afterFunc executes a func at a given time.
   246  func (r *runnerImpl) afterFunc(d time.Duration, fn func()) {
   247  	if d > r.maxDuration {
   248  		r.maxDuration = d
   249  	}
   250  
   251  	r.goWait(func() {
   252  		select {
   253  		case <-r.doneCh:
   254  			return
   255  		case <-r.clock.After(d):
   256  		}
   257  		fn()
   258  	})
   259  }
   260  
   261  // goWait runs a function in a goroutine and makes sure the goroutine was scheduled.
   262  func (r *runnerImpl) goWait(fn func()) {
   263  	wg := sync.WaitGroup{}
   264  	wg.Add(1)
   265  	go func() {
   266  		wg.Done()
   267  		fn()
   268  	}()
   269  	wg.Wait()
   270  }
   271  
   272  func TestRateLimiter(t *testing.T) {
   273  	t.Parallel()
   274  	runTest(t, func(r testRunner) {
   275  		rl := r.createLimiter(100, WithoutSlack)
   276  
   277  		// Create copious counts concurrently.
   278  		r.startTaking(rl)
   279  		r.startTaking(rl)
   280  		r.startTaking(rl)
   281  		r.startTaking(rl)
   282  
   283  		r.assertCountAt(1*time.Second, 100)
   284  		r.assertCountAt(2*time.Second, 200)
   285  		r.assertCountAt(3*time.Second, 300)
   286  	})
   287  }
   288  
   289  func TestDelayedRateLimiter(t *testing.T) {
   290  	t.Skip(UnstableTest)
   291  	t.Parallel()
   292  	runTest(t, func(r testRunner) {
   293  		slow := r.createLimiter(10, WithoutSlack)
   294  		fast := r.createLimiter(100, WithoutSlack)
   295  
   296  		r.startTaking(slow, fast)
   297  
   298  		r.afterFunc(20*time.Second, func() {
   299  			r.startTaking(fast)
   300  			r.startTaking(fast)
   301  			r.startTaking(fast)
   302  			r.startTaking(fast)
   303  		})
   304  
   305  		r.assertCountAt(30*time.Second, 1200)
   306  	})
   307  }
   308  
   309  func TestPer(t *testing.T) {
   310  	t.Parallel()
   311  	runTest(t, func(r testRunner) {
   312  		rl := r.createLimiter(7, WithoutSlack, Per(time.Minute))
   313  
   314  		r.startTaking(rl)
   315  		r.startTaking(rl)
   316  
   317  		r.assertCountAt(1*time.Second, 1)
   318  		r.assertCountAt(1*time.Minute, 8)
   319  		r.assertCountAt(2*time.Minute, 15)
   320  	})
   321  }
   322  
   323  // TestInitial verifies that the initial sequence is scheduled as expected.
   324  func TestInitial(t *testing.T) {
   325  	t.Parallel()
   326  	tests := []struct {
   327  		msg  string
   328  		opts []Option
   329  	}{
   330  		{
   331  			msg: "With Slack",
   332  		},
   333  		{
   334  			msg:  "Without Slack",
   335  			opts: []Option{WithoutSlack},
   336  		},
   337  	}
   338  
   339  	for _, tt := range tests {
   340  		t.Run(tt.msg, func(t *testing.T) {
   341  			runTest(t, func(r testRunner) {
   342  				rl := r.createLimiter(10, tt.opts...)
   343  
   344  				var (
   345  					clk  = r.getClock()
   346  					prev = clk.Now()
   347  
   348  					results = make(chan time.Time)
   349  					have    []time.Duration
   350  					startWg sync.WaitGroup
   351  				)
   352  				startWg.Add(3)
   353  
   354  				for i := 0; i < 3; i++ {
   355  					go func() {
   356  						startWg.Done()
   357  						results <- rl.Take()
   358  					}()
   359  				}
   360  
   361  				startWg.Wait()
   362  				clk.Add(time.Second)
   363  
   364  				for i := 0; i < 3; i++ {
   365  					ts := <-results
   366  					have = append(have, ts.Sub(prev))
   367  					prev = ts
   368  				}
   369  
   370  				assert.Equal(t,
   371  					[]time.Duration{
   372  						0,
   373  						time.Millisecond * 100,
   374  						time.Millisecond * 100,
   375  					},
   376  					have,
   377  					"bad timestamps for inital takes",
   378  				)
   379  			})
   380  		})
   381  	}
   382  }
   383  
   384  func TestMaxSlack(t *testing.T) {
   385  	t.Parallel()
   386  	runTest(t, func(r testRunner) {
   387  		rl := r.createLimiter(1, WithSlack(1))
   388  
   389  		r.takeOnceAfter(time.Nanosecond, rl)
   390  		r.takeOnceAfter(2*time.Second+1*time.Nanosecond, rl)
   391  		r.takeOnceAfter(2*time.Second+2*time.Nanosecond, rl)
   392  		r.takeOnceAfter(2*time.Second+3*time.Nanosecond, rl)
   393  		r.takeOnceAfter(2*time.Second+4*time.Nanosecond, rl)
   394  
   395  		r.assertCountAt(3*time.Second, 3)
   396  		r.assertCountAt(10*time.Second, 5)
   397  	})
   398  }
   399  
   400  func TestSlack(t *testing.T) {
   401  	t.Parallel()
   402  	// To simulate slack, we combine two limiters.
   403  	// - First, we start a single goroutine with both of them,
   404  	//   during this time the slow limiter will dominate,
   405  	//   and allow the fast limiter to accumulate slack.
   406  	// - After 2 seconds, we start another goroutine with
   407  	//   only the faster limiter. This will allow it to max out,
   408  	//   and consume all the slack.
   409  	// - After 3 seconds, we look at the final result, and we expect,
   410  	//   a sum of:
   411  	//   - slower limiter running for 3 seconds
   412  	//   - faster limiter running for 1 second
   413  	//   - slack accumulated by the faster limiter during the two seconds.
   414  	//     it was blocked by slower limiter.
   415  	tests := []struct {
   416  		msg  string
   417  		opt  []Option
   418  		want int
   419  	}{
   420  		{
   421  			msg: "no option, defaults to 10",
   422  			// 2*10 + 1*100 + 1*10 (slack)
   423  			want: 130,
   424  		},
   425  		{
   426  			msg: "slack of 10, like default",
   427  			opt: []Option{WithSlack(10)},
   428  			// 2*10 + 1*100 + 1*10 (slack)
   429  			want: 130,
   430  		},
   431  		{
   432  			msg: "slack of 20",
   433  			opt: []Option{WithSlack(20)},
   434  			// 2*10 + 1*100 + 1*20 (slack)
   435  			want: 140,
   436  		},
   437  		{
   438  			// Note this is bigger then the rate of the limiter.
   439  			msg: "slack of 150",
   440  			opt: []Option{WithSlack(150)},
   441  			// 2*10 + 1*100 + 1*150 (slack)
   442  			want: 270,
   443  		},
   444  		{
   445  			msg: "no option, defaults to 10, with per",
   446  			// 2*(10*2) + 1*(100*2) + 1*10 (slack)
   447  			opt:  []Option{Per(500 * time.Millisecond)},
   448  			want: 230,
   449  		},
   450  		{
   451  			msg: "slack of 10, like default, with per",
   452  			opt: []Option{WithSlack(10), Per(500 * time.Millisecond)},
   453  			// 2*(10*2) + 1*(100*2) + 1*10 (slack)
   454  			want: 230,
   455  		},
   456  		{
   457  			msg: "slack of 20, with per",
   458  			opt: []Option{WithSlack(20), Per(500 * time.Millisecond)},
   459  			// 2*(10*2) + 1*(100*2) + 1*20 (slack)
   460  			want: 240,
   461  		},
   462  		{
   463  			// Note this is bigger then the rate of the limiter.
   464  			msg: "slack of 150, with per",
   465  			opt: []Option{WithSlack(150), Per(500 * time.Millisecond)},
   466  			// 2*(10*2) + 1*(100*2) + 1*150 (slack)
   467  			want: 370,
   468  		},
   469  	}
   470  
   471  	for _, tt := range tests {
   472  		cfg := buildConfig(tt.opt)
   473  		if cfg.slack >= 100 {
   474  			t.Skip(UnstableTest)
   475  		}
   476  
   477  		t.Run(tt.msg, func(t *testing.T) {
   478  			runTest(t, func(r testRunner) {
   479  				slow := r.createLimiter(10, WithoutSlack)
   480  				fast := r.createLimiter(100, tt.opt...)
   481  
   482  				r.startTaking(slow, fast)
   483  
   484  				r.afterFunc(2*time.Second, func() {
   485  					r.startTaking(fast)
   486  					r.startTaking(fast)
   487  				})
   488  
   489  				// limiter with 10hz dominates here - we're always at 10.
   490  				r.assertCountAt(1*time.Second, 10)
   491  				r.assertCountAt(3*time.Second, tt.want)
   492  			})
   493  		})
   494  	}
   495  }
   496  
   497  func TestSetRateLimitOnTheFly(t *testing.T) {
   498  	runTest(t, func(r testRunner) {
   499  		// Set rate to 1hz
   500  		limiter, ok := r.createLimiter(1, WithoutSlack).(*LeakyBucket)
   501  		if !ok {
   502  			t.Skip("Update is not supported")
   503  		}
   504  
   505  		r.startTaking(limiter)
   506  		r.assertCountAt(time.Second, 2)
   507  
   508  		r.getClock().Add(time.Second)
   509  		r.assertCountAt(time.Second, 3)
   510  
   511  		// increase to 2hz
   512  		limiter.Update(2, 0)
   513  		r.getClock().Add(time.Second)
   514  		r.assertCountAt(time.Second, 4) // <- delayed due to paying sleepFor debt
   515  		r.getClock().Add(time.Second)
   516  		r.assertCountAt(time.Second, 6)
   517  
   518  		// reduce to 1hz again
   519  		limiter.Update(1, 0)
   520  		r.getClock().Add(time.Second)
   521  		r.assertCountAt(time.Second, 7)
   522  		r.getClock().Add(time.Second)
   523  		r.assertCountAt(time.Second, 8)
   524  
   525  		slack := 3
   526  		require.GreaterOrEqual(t, limiter.sleepFor, time.Duration(0))
   527  		limiter.Update(1, slack)
   528  		r.getClock().Add(time.Second * time.Duration(slack))
   529  		r.assertCountAt(time.Second, 8+slack)
   530  	})
   531  }