github.com/livekit/protocol@v1.39.3/utils/rate_test.go (about)

     1  // Copyright (c) 2016,2020 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  //
    21  // SOURCE: https://github.com/uber-go/ratelimit/blob/main/ratelimit_test.go
    22  // EDIT: slight modification to allow setting rate limit on the fly
    23  // SCOPE: LeakyBucket
    24  package utils
    25  
    26  import (
    27  	"sync"
    28  	"testing"
    29  	"time"
    30  
    31  	"go.uber.org/atomic"
    32  
    33  	"github.com/benbjohnson/clock"
    34  	"github.com/stretchr/testify/require"
    35  )
    36  
    37  const UnstableTest = "UNSTABLE TEST"
    38  
// The option types below come from upstream, in a stripped-down form.
// Note: this file is inspired by:
// https://github.com/prashantv/go-bench/blob/master/ratelimit
    42  
    43  // Limiter is used to rate-limit some process, possibly across goroutines.
    44  // The process is expected to call Take() before every iteration, which
    45  // may block to throttle the goroutine.
    46  type Limiter interface {
    47  	// Take should block to make sure that the RPS is met.
    48  	Take() time.Time
    49  }
    50  
    51  // config configures a limiter.
    52  type config struct {
    53  	clock Clock
    54  	slack int
    55  	per   time.Duration
    56  }
    57  
    58  // buildConfig combines defaults with options.
    59  func buildConfig(opts []Option) config {
    60  	c := config{
    61  		clock: clock.New(),
    62  		slack: 10,
    63  		per:   time.Second,
    64  	}
    65  
    66  	for _, opt := range opts {
    67  		opt.apply(&c)
    68  	}
    69  	return c
    70  }
    71  
    72  // Option configures a Limiter.
    73  type Option interface {
    74  	apply(*config)
    75  }
    76  
    77  type clockOption struct {
    78  	clock Clock
    79  }
    80  
    81  func (o clockOption) apply(c *config) {
    82  	c.clock = o.clock
    83  }
    84  
    85  // WithClock returns an option for ratelimit.New that provides an alternate
    86  // Clock implementation, typically a mock Clock for testing.
    87  func WithClock(clock Clock) Option {
    88  	return clockOption{clock: clock}
    89  }
    90  
    91  type slackOption int
    92  
    93  func (o slackOption) apply(c *config) {
    94  	c.slack = int(o)
    95  }
    96  
    97  // WithoutSlack configures the limiter to be strict and not to accumulate
    98  // previously "unspent" requests for future bursts of traffic.
    99  var WithoutSlack Option = slackOption(0)
   100  
   101  // WithSlack configures custom slack.
   102  // Slack allows the limiter to accumulate "unspent" requests
   103  // for future bursts of traffic.
   104  func WithSlack(slack int) Option {
   105  	return slackOption(slack)
   106  }
   107  
   108  type perOption time.Duration
   109  
   110  func (p perOption) apply(c *config) {
   111  	c.per = time.Duration(p)
   112  }
   113  
   114  // Per allows configuring limits for different time windows.
   115  //
   116  // The default window is one second, so New(100) produces a one hundred per
   117  // second (100 Hz) rate limiter.
   118  //
   119  // New(2, Per(60*time.Second)) creates a 2 per minute rate limiter.
   120  func Per(per time.Duration) Option {
   121  	return perOption(per)
   122  }
   123  
   124  type testRunner interface {
   125  	// createLimiter builds a limiter with given options.
   126  	createLimiter(int, ...Option) Limiter
   127  	// takeOnceAfter attempts to Take at a specific time.
   128  	takeOnceAfter(time.Duration, Limiter)
   129  	// startTaking tries to Take() on passed in limiters in a loop/goroutine.
   130  	startTaking(rls ...Limiter)
   131  	// assertCountAt asserts the limiters have Taken() a number of times at the given time.
   132  	// It's a thin wrapper around afterFunc to reduce boilerplate code.
   133  	assertCountAt(d time.Duration, count int)
   134  	assertCountAtWithNoise(d time.Duration, count int, noise int)
   135  	// afterFunc executes a func at a given time.
   136  	// not using clock.AfterFunc because andres-erbsen/clock misses a nap there.
   137  	afterFunc(d time.Duration, fn func())
   138  	// some tests want raw access to the clock.
   139  	getClock() *clock.Mock
   140  }
   141  
   142  type runnerImpl struct {
   143  	t *testing.T
   144  
   145  	clock       *clock.Mock
   146  	constructor func(int, ...Option) Limiter
   147  	count       atomic.Int32
   148  	// maxDuration is the time we need to move into the future for a test.
   149  	// It's populated automatically based on assertCountAt/afterFunc.
   150  	maxDuration time.Duration
   151  	doneCh      chan struct{}
   152  	wg          sync.WaitGroup
   153  }
   154  
   155  func runTest(t *testing.T, fn func(testRunner)) {
   156  	impls := []struct {
   157  		name        string
   158  		constructor func(int, ...Option) Limiter
   159  	}{
   160  		{
   161  			name: "mutex",
   162  			constructor: func(rate int, opts ...Option) Limiter {
   163  				config := buildConfig(opts)
   164  				perRequest := config.per / time.Duration(rate)
   165  				cfg := leakyBucketConfig{
   166  					perRequest: perRequest,
   167  					maxSlack:   -1 * time.Duration(config.slack) * perRequest,
   168  				}
   169  				l := &LeakyBucket{
   170  					clock: config.clock,
   171  				}
   172  				l.cfg.Store(&cfg)
   173  				return l
   174  			},
   175  		},
   176  	}
   177  
   178  	for _, tt := range impls {
   179  		t.Run(tt.name, func(t *testing.T) {
   180  			// Set a non-default time.Time since some limiters (int64 in particular) use
   181  			// the default value as "non-initialized" state.
   182  			clockMock := clock.NewMock()
   183  			clockMock.Set(time.Now())
   184  			r := runnerImpl{
   185  				t:           t,
   186  				clock:       clockMock,
   187  				constructor: tt.constructor,
   188  				doneCh:      make(chan struct{}),
   189  			}
   190  			defer close(r.doneCh)
   191  			defer r.wg.Wait()
   192  
   193  			fn(&r)
   194  
   195  			// it's possible that there are some goroutines still waiting
   196  			// in taking the bandwidth. We need to keep moving the clock forward
   197  			// until all goroutines are finished
   198  			go func() {
   199  				ticker := time.NewTicker(5 * time.Millisecond)
   200  				defer ticker.Stop()
   201  
   202  				for {
   203  					select {
   204  					case <-ticker.C:
   205  						r.clock.Add(r.maxDuration)
   206  					case <-r.doneCh:
   207  					}
   208  				}
   209  			}()
   210  		})
   211  	}
   212  }
   213  
   214  // createLimiter builds a limiter with given options.
   215  func (r *runnerImpl) createLimiter(rate int, opts ...Option) Limiter {
   216  	opts = append(opts, WithClock(r.clock))
   217  	return r.constructor(rate, opts...)
   218  }
   219  
   220  func (r *runnerImpl) getClock() *clock.Mock {
   221  	return r.clock
   222  }
   223  
   224  // startTaking tries to Take() on passed in limiters in a loop/goroutine.
   225  func (r *runnerImpl) startTaking(rls ...Limiter) {
   226  	r.goWait(func() {
   227  		for {
   228  			for _, rl := range rls {
   229  				rl.Take()
   230  			}
   231  			r.count.Inc()
   232  			select {
   233  			case <-r.doneCh:
   234  				return
   235  			default:
   236  			}
   237  		}
   238  	})
   239  }
   240  
   241  // takeOnceAfter attempts to Take at a specific time.
   242  func (r *runnerImpl) takeOnceAfter(d time.Duration, rl Limiter) {
   243  	r.wg.Add(1)
   244  	r.afterFunc(d, func() {
   245  		rl.Take()
   246  		r.count.Inc()
   247  		r.wg.Done()
   248  	})
   249  }
   250  
   251  // assertCountAt asserts the limiters have Taken() a number of times at a given time.
   252  func (r *runnerImpl) assertCountAt(d time.Duration, count int) {
   253  	r.wg.Add(1)
   254  	r.afterFunc(d, func() {
   255  		defer r.wg.Done()
   256  		require.Equal(r.t, int32(count), r.count.Load(), "count not as expected")
   257  	})
   258  }
   259  
   260  // assertCountAtWithNoise like assertCountAt but also considers possible noise in CI
   261  func (r *runnerImpl) assertCountAtWithNoise(d time.Duration, count int, noise int) {
   262  	r.wg.Add(1)
   263  	r.afterFunc(d, func() {
   264  		defer r.wg.Done()
   265  		require.InDelta(r.t, count, int(r.count.Load()), float64(noise),
   266  			"expected count to be within noise tolerance")
   267  	})
   268  }
   269  
   270  // afterFunc executes a func at a given time.
   271  func (r *runnerImpl) afterFunc(d time.Duration, fn func()) {
   272  	if d > r.maxDuration {
   273  		r.maxDuration = d
   274  	}
   275  
   276  	r.goWait(func() {
   277  		select {
   278  		case <-r.doneCh:
   279  			return
   280  		case <-r.clock.After(d):
   281  		}
   282  		fn()
   283  	})
   284  }
   285  
   286  // goWait runs a function in a goroutine and makes sure the goroutine was scheduled.
   287  func (r *runnerImpl) goWait(fn func()) {
   288  	wg := sync.WaitGroup{}
   289  	wg.Add(1)
   290  	go func() {
   291  		wg.Done()
   292  		fn()
   293  	}()
   294  	wg.Wait()
   295  }
   296  
   297  func TestRateLimiter(t *testing.T) {
   298  	runTest(t, func(r testRunner) {
   299  		rl := r.createLimiter(100, WithoutSlack)
   300  
   301  		// Create copious counts concurrently.
   302  		r.startTaking(rl)
   303  		r.startTaking(rl)
   304  		r.startTaking(rl)
   305  		r.startTaking(rl)
   306  
   307  		r.assertCountAtWithNoise(1*time.Second, 100, 2)
   308  		r.assertCountAtWithNoise(2*time.Second, 200, 2)
   309  		r.assertCountAtWithNoise(3*time.Second, 300, 2)
   310  	})
   311  }
   312  
   313  func TestDelayedRateLimiter(t *testing.T) {
   314  	t.Skip(UnstableTest)
   315  	runTest(t, func(r testRunner) {
   316  		slow := r.createLimiter(10, WithoutSlack)
   317  		fast := r.createLimiter(100, WithoutSlack)
   318  
   319  		r.startTaking(slow, fast)
   320  
   321  		r.afterFunc(20*time.Second, func() {
   322  			r.startTaking(fast)
   323  			r.startTaking(fast)
   324  			r.startTaking(fast)
   325  			r.startTaking(fast)
   326  		})
   327  
   328  		r.assertCountAt(30*time.Second, 1200)
   329  	})
   330  }
   331  
   332  func TestPer(t *testing.T) {
   333  	runTest(t, func(r testRunner) {
   334  		rl := r.createLimiter(7, WithoutSlack, Per(time.Minute))
   335  
   336  		r.startTaking(rl)
   337  		r.startTaking(rl)
   338  
   339  		r.assertCountAt(1*time.Second, 1)
   340  		r.assertCountAt(1*time.Minute, 8)
   341  		r.assertCountAt(2*time.Minute, 15)
   342  	})
   343  }
   344  
   345  // TestInitial verifies that the initial sequence is scheduled as expected.
   346  func TestInitial(t *testing.T) {
   347  	tests := []struct {
   348  		msg  string
   349  		opts []Option
   350  	}{
   351  		{
   352  			msg: "With Slack",
   353  		},
   354  		{
   355  			msg:  "Without Slack",
   356  			opts: []Option{WithoutSlack},
   357  		},
   358  	}
   359  
   360  	for _, tt := range tests {
   361  		t.Run(tt.msg, func(t *testing.T) {
   362  			runTest(t, func(r testRunner) {
   363  				perRequest := 100 * time.Millisecond
   364  				rl := r.createLimiter(10, tt.opts...)
   365  
   366  				var (
   367  					clk  = r.getClock()
   368  					prev = clk.Now()
   369  
   370  					results = make(chan time.Time, 3)
   371  					have    []time.Duration
   372  				)
   373  
   374  				results <- rl.Take()
   375  				clk.Add(perRequest)
   376  
   377  				results <- rl.Take()
   378  				clk.Add(perRequest)
   379  
   380  				results <- rl.Take()
   381  				clk.Add(perRequest)
   382  
   383  				for i := 0; i < 3; i++ {
   384  					ts := <-results
   385  					have = append(have, ts.Sub(prev))
   386  					prev = ts
   387  				}
   388  
   389  				require.Equal(t,
   390  					[]time.Duration{
   391  						0,
   392  						perRequest,
   393  						perRequest,
   394  					},
   395  					have,
   396  					"bad timestamps for inital takes",
   397  				)
   398  			})
   399  		})
   400  	}
   401  }
   402  
   403  func TestMaxSlack(t *testing.T) {
   404  	runTest(t, func(r testRunner) {
   405  		clock := r.getClock()
   406  		rl := r.createLimiter(1, WithSlack(1))
   407  		rl.Take()
   408  		clock.Add(time.Second)
   409  		rl.Take()
   410  		clock.Add(time.Second)
   411  		rl.Take()
   412  
   413  		doneCh := make(chan struct{})
   414  		go func() {
   415  			rl.Take()
   416  			close(doneCh)
   417  		}()
   418  
   419  		select {
   420  		case <-doneCh:
   421  			require.Fail(t, "expect rate limiter to be waiting")
   422  		case <-time.After(time.Millisecond):
   423  			// clean up ratelimiter waiting for take
   424  			clock.Add(time.Second)
   425  		}
   426  	})
   427  }
   428  
   429  func TestSlack(t *testing.T) {
   430  	t.Skip(UnstableTest)
   431  
   432  	// To simulate slack, we combine two limiters.
   433  	// - First, we start a single goroutine with both of them,
   434  	//   during this time the slow limiter will dominate,
   435  	//   and allow the fast limiter to accumulate slack.
   436  	// - After 2 seconds, we start another goroutine with
   437  	//   only the faster limiter. This will allow it to max out,
   438  	//   and consume all the slack.
   439  	// - After 3 seconds, we look at the final result, and we expect,
   440  	//   a sum of:
   441  	//   - slower limiter running for 3 seconds
   442  	//   - faster limiter running for 1 second
   443  	//   - slack accumulated by the faster limiter during the two seconds.
   444  	//     it was blocked by slower limiter.
   445  	tests := []struct {
   446  		msg  string
   447  		opt  []Option
   448  		want int
   449  	}{
   450  		{
   451  			msg: "no option, defaults to 10",
   452  			// 2*10 + 1*100 + 1*10 (slack)
   453  			want: 130,
   454  		},
   455  		{
   456  			msg: "slack of 10, like default",
   457  			opt: []Option{WithSlack(10)},
   458  			// 2*10 + 1*100 + 1*10 (slack)
   459  			want: 130,
   460  		},
   461  		{
   462  			msg: "slack of 20",
   463  			opt: []Option{WithSlack(20)},
   464  			// 2*10 + 1*100 + 1*20 (slack)
   465  			want: 140,
   466  		},
   467  		{
   468  			// Note this is bigger then the rate of the limiter.
   469  			msg: "slack of 150",
   470  			opt: []Option{WithSlack(150)},
   471  			// 2*10 + 1*100 + 1*150 (slack)
   472  			want: 270,
   473  		},
   474  		{
   475  			msg: "no option, defaults to 10, with per",
   476  			// 2*(10*2) + 1*(100*2) + 1*10 (slack)
   477  			opt:  []Option{Per(500 * time.Millisecond)},
   478  			want: 230,
   479  		},
   480  		{
   481  			msg: "slack of 10, like default, with per",
   482  			opt: []Option{WithSlack(10), Per(500 * time.Millisecond)},
   483  			// 2*(10*2) + 1*(100*2) + 1*10 (slack)
   484  			want: 230,
   485  		},
   486  		{
   487  			msg: "slack of 20, with per",
   488  			opt: []Option{WithSlack(20), Per(500 * time.Millisecond)},
   489  			// 2*(10*2) + 1*(100*2) + 1*20 (slack)
   490  			want: 240,
   491  		},
   492  		{
   493  			// Note this is bigger then the rate of the limiter.
   494  			msg: "slack of 150, with per",
   495  			opt: []Option{WithSlack(150), Per(500 * time.Millisecond)},
   496  			// 2*(10*2) + 1*(100*2) + 1*150 (slack)
   497  			want: 370,
   498  		},
   499  	}
   500  
   501  	for _, tt := range tests {
   502  		t.Run(tt.msg, func(t *testing.T) {
   503  			runTest(t, func(r testRunner) {
   504  				slow := r.createLimiter(10, WithoutSlack)
   505  				fast := r.createLimiter(100, tt.opt...)
   506  
   507  				r.startTaking(slow, fast)
   508  
   509  				r.afterFunc(2*time.Second, func() {
   510  					r.startTaking(fast)
   511  					r.startTaking(fast)
   512  				})
   513  
   514  				// limiter with 10hz dominates here - we're always at 10.
   515  				r.assertCountAtWithNoise(1*time.Second, 10, 2)
   516  				r.assertCountAtWithNoise(3*time.Second, tt.want, 2)
   517  			})
   518  		})
   519  	}
   520  }
   521  
// TestSetRateLimitOnTheFly exercises LeakyBucket.Update while a taker goroutine
// is running. It raises the rate from 1 Hz to 2 Hz, drops it back to 1 Hz, and
// finally grants a slack of 3, checking the running Take count after each step.
//
// NOTE(review): each assertCountAt(time.Second, …) registers a timer measured
// from when its afterFunc goroutine calls clock.After, interleaved with the
// explicit clock.Add calls below. The expected counts therefore depend on
// goroutine scheduling, which is likely part of why this test is skipped.
func TestSetRateLimitOnTheFly(t *testing.T) {
	t.Skip(UnstableTest)
	runTest(t, func(r testRunner) {
		// Set rate to 1hz
		limiter, ok := r.createLimiter(1, WithoutSlack).(*LeakyBucket)
		if !ok {
			// NOTE(review): t is the parent test, but this callback runs on
			// runTest's subtest goroutine; SkipNow (like FailNow) must be
			// called from t's own goroutine. Revisit before un-skipping.
			t.Skip("Update is not supported")
		}

		r.startTaking(limiter)
		r.assertCountAt(time.Second, 2)

		r.getClock().Add(time.Second)
		r.assertCountAt(time.Second, 3)

		// increase to 2hz
		limiter.Update(2, 0)
		r.getClock().Add(time.Second)
		r.assertCountAt(time.Second, 4) // <- delayed due to paying sleepFor debt
		r.getClock().Add(time.Second)
		r.assertCountAt(time.Second, 6)

		// reduce to 1hz again
		limiter.Update(1, 0)
		r.getClock().Add(time.Second)
		r.assertCountAt(time.Second, 7)
		r.getClock().Add(time.Second)
		r.assertCountAt(time.Second, 8)

		// Grant a slack of 3 and advance by 3 seconds; the count should then
		// have grown by the slack amount.
		slack := 3
		// NOTE(review): sleepFor is read directly while the startTaking
		// goroutine may be inside Take; confirm LeakyBucket synchronizes it.
		require.GreaterOrEqual(t, limiter.sleepFor, time.Duration(0))
		limiter.Update(1, slack)
		r.getClock().Add(time.Second * time.Duration(slack))
		r.assertCountAt(time.Second, 8+slack)
	})
}