github.com/rudderlabs/rudder-go-kit@v0.30.0/throttling/throttling_test.go

package throttling

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ory/dockertest/v3"
	"github.com/stretchr/testify/require"

	"github.com/rudderlabs/rudder-go-kit/testhelper/rand"
)

func TestThrottling(t *testing.T) {
	pool, err := dockertest.NewPool("")
	require.NoError(t, err)

	type limiterSettings struct {
		name    string
		limiter *Limiter
		// concurrency has been introduced because the GCRA algorithms (both in-memory and Redis) tend to lose precision
		// when the concurrency is too high. Until we fix it or come up with a guard to limit the amount of concurrent
		// requests, we limit the per-limiter concurrency in these tests (to avoid test flakiness).
		concurrency int
	}

	var (
		ctx      = context.Background()
		rc       = bootstrapRedis(ctx, t, pool)
		limiters = []limiterSettings{
			{
				name:        "gcra",
				limiter:     newLimiter(t, WithInMemoryGCRA(0)),
				concurrency: 100,
			},
			{
				name:        "gcra redis",
				limiter:     newLimiter(t, WithRedisGCRA(rc, 100)), // TODO: this should work properly with burst = 0 as well (i.e. burst = rate)
				concurrency: 100,
			},
			{
				name:        "sorted sets redis",
				limiter:     newLimiter(t, WithRedisSortedSet(rc)),
				concurrency: 5000,
			},
		}
	)

	flakinessRate := 1 // increase to run the tests multiple times in a row to debug flaky tests
	for i := 0; i < flakinessRate; i++ {
		for _, tc := range []testCase{
			// avoid rates that are too small (e.g. 10); that's where there is the most flakiness
			{rate: 500, window: 1},
			{rate: 1000, window: 2},
			{rate: 2000, window: 3},
		} {
			for _, l := range limiters {
				t.Run(testName(l.name, tc.rate, tc.window), func(t *testing.T) {
					expected := tc.rate
					testLimiter(ctx, t, l.limiter, tc.rate, tc.window, expected, l.concurrency)
				})
			}
		}
	}
}
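
// allowOnceSketch is a minimal usage sketch of the API exercised by the tests above. It
// assumes a *Limiter built via newLimiter with any of WithInMemoryGCRA, WithRedisGCRA or
// WithRedisSortedSet; the key and the limits are illustrative. Allow consumes "cost"
// tokens from the budget of "rate" tokens per "window" seconds for "key".
func allowOnceSketch(ctx context.Context, l *Limiter) (bool, error) {
	allowed, _, err := l.Allow(ctx, 1, 100, 1, "example-key") // cost=1, rate=100, window=1s
	if err != nil {
		return false, err
	}
	return allowed, nil
}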

func testLimiter(
	ctx context.Context, t *testing.T, l *Limiter, rate, window, expected int64, concurrency int,
) {
	t.Helper()
	var (
		wg          sync.WaitGroup
		passed      int64
		cost        int64 = 1
		key               = rand.UniqueString(10)
		maxRoutines       = make(chan struct{}, concurrency)
		// Time tracking variables
		startTime   int64
		currentTime int64
		timeMutex   sync.Mutex
		stop        = make(chan struct{}, 1)
	)

	run := func() (err error, allowed bool, redisTime time.Duration) {
		switch {
		case l.redisSpeaker != nil && l.useGCRA:
			redisTime, allowed, _, _, err = l.redisGCRA(ctx, cost, rate, window, key)
		case l.redisSpeaker != nil && !l.useGCRA:
			redisTime, allowed, _, _, err = l.redisSortedSet(ctx, cost, rate, window, key)
		default:
			allowed, _, err = l.Allow(ctx, cost, rate, window, key)
		}
		return
	}
	if l.useGCRA { // warm up the GCRA algorithms
		for i := 0; i < int(rate); i++ {
			_, _, _ = run()
		}
	}

loop:
	for {
		select {
		case <-stop:
			// To decrease the error margin (mostly introduced by the Redis algorithms) time is measured in two
			// different ways depending on whether the algorithm is in-memory or Redis-based.
			// For the Redis algorithms the elapsed time is derived from the timestamps returned by the Redis Lua
			// scripts, because there is a bit of drift between the time as measured here and the time as measured
			// inside the Lua scripts.
			break loop
		case maxRoutines <- struct{}{}:
			wg.Add(1)
			go func() {
				defer wg.Done()
				defer func() { <-maxRoutines }()

				err, allowed, redisTime := run()
				require.NoError(t, err)
				now := time.Now().UnixNano()

				timeMutex.Lock()
				defer timeMutex.Unlock()
				if redisTime > 0 { // Redis limiters have a getTime() function that returns the current time from Redis perspective
					if redisTime > time.Duration(currentTime) {
						currentTime = int64(redisTime)
					}
				} else if now > currentTime { // in-memory algorithm, let's use time.Now()
					currentTime = now
				}
				if startTime == 0 {
					startTime = currentTime
				}
				if currentTime-startTime > window*int64(time.Second) {
					select {
					case stop <- struct{}{}:
					default: // one signal to stop is enough, don't block
					}
					return // do not increment "passed" because we're over the window
				}
				if allowed {
					atomic.AddInt64(&passed, cost)
				}
			}()
		}
	}

	wg.Wait()

	diff := expected - passed
	errorMargin := int64(0.1*float64(rate)) + 1 // ~10% error margin
	if passed < 1 || diff < -errorMargin || diff > errorMargin {
		t.Errorf("Expected %d, got %d (diff: %d, error margin: %d)", expected, passed, diff, errorMargin)
	}
}
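
// withinMarginSketch restates the tolerance check at the end of testLimiter as a
// standalone helper, for clarity only: the number of allowed requests must be within
// roughly 10% of the expected count, and at least one request must have passed.
func withinMarginSketch(expected, passed, rate int64) bool {
	errorMargin := int64(0.1*float64(rate)) + 1 // ~10% error margin, same formula as above
	diff := expected - passed
	return passed >= 1 && diff >= -errorMargin && diff <= errorMargin
}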

func TestGCRABurstAsRate(t *testing.T) {
	pool, err := dockertest.NewPool("")
	require.NoError(t, err)

	var (
		ctx = context.Background()
		rc  = bootstrapRedis(ctx, t, pool)
		l   = newLimiter(t, WithRedisGCRA(rc, 0))
		// Configuration variables
		key          = "foo"
		cost   int64 = 1
		rate   int64 = 500
		window int64 = 1
		// Expectations
		passed int64
		start  = time.Now()
	)

	for i := int64(0); i < rate*2; i++ {
		allowed, _, err := l.Allow(ctx, cost, rate, window, key)
		require.NoError(t, err)
		if allowed {
			passed++
		}
	}

	require.GreaterOrEqual(t, passed, rate)
	require.Less(
		t, time.Since(start), time.Duration(window*int64(time.Second)),
		"we should've been able to make the required requests in less than the window duration due to the burst setting",
	)
}
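
// drainBurstSketch is a minimal sketch of the pattern behind TestGCRABurstAsRate,
// assuming, as that test does, a Redis GCRA limiter built with burst 0 (i.e. burst = rate):
// calling Allow in a tight loop should let roughly the first "rate" requests through
// without waiting for the window to elapse. The loop bound of rate*2 mirrors the test.
func drainBurstSketch(ctx context.Context, l *Limiter, rate, window int64, key string) (int64, error) {
	var passed int64
	for i := int64(0); i < rate*2; i++ {
		allowed, _, err := l.Allow(ctx, 1, rate, window, key)
		if err != nil {
			return passed, err
		}
		if allowed {
			passed++
		}
	}
	return passed, nil
}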

func TestReturn(t *testing.T) {
	pool, err := dockertest.NewPool("")
	require.NoError(t, err)

	type testCase struct {
		name    string
		limiter *Limiter
	}

	var (
		rate      int64 = 10
		window    int64 = 1
		windowDur       = time.Duration(window) * time.Second
		ctx             = context.Background()
		rc              = bootstrapRedis(ctx, t, pool)
		testCases       = []testCase{
			{
				name:    "sorted sets redis",
				limiter: newLimiter(t, WithRedisSortedSet(rc)),
			},
		}
	)

	for _, tc := range testCases {
		t.Run(testName(tc.name, rate, window), func(t *testing.T) {
			var (
				passed int
				key    = rand.UniqueString(10)
				tokens []func(context.Context) error
			)
			for i := int64(0); i < rate*10; i++ {
				allowed, returner, err := tc.limiter.Allow(ctx, 1, rate, window, key)
				require.NoError(t, err)
				if allowed {
					passed++
					tokens = append(tokens, returner)
				}
			}

			require.EqualValues(t, rate, passed)

			allowed, _, err := tc.limiter.Allow(ctx, 1, rate, window, key)
			require.NoError(t, err)
			require.False(t, allowed)

			// return one token and try again
			require.NoError(t, tokens[0](ctx))
			require.Eventually(t, func() bool {
				allowed, returner, err := tc.limiter.Allow(ctx, 1, rate, window, key)
				return allowed && err == nil && returner != nil
			}, windowDur/2, time.Millisecond)
		})
	}
}
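
// returnTokenSketch mirrors the pattern exercised in TestReturn, assuming a sorted-set
// Redis limiter as above: when Allow grants a token it also returns a function that
// hands the token back, freeing capacity before the window expires so another caller
// can use it.
func returnTokenSketch(ctx context.Context, l *Limiter, rate, window int64, key string) error {
	allowed, returnToken, err := l.Allow(ctx, 1, rate, window, key)
	if err != nil || !allowed {
		return err
	}
	// ... do the rate-limited work here ...
	return returnToken(ctx) // give the token back to the limiter
}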

func TestBadData(t *testing.T) {
	pool, err := dockertest.NewPool("")
	require.NoError(t, err)

	var (
		ctx      = context.Background()
		rc       = bootstrapRedis(ctx, t, pool)
		limiters = map[string]*Limiter{
			"gcra":              newLimiter(t, WithInMemoryGCRA(0)),
			"gcra redis":        newLimiter(t, WithRedisGCRA(rc, 0)),
			"sorted sets redis": newLimiter(t, WithRedisSortedSet(rc)),
		}
	)

	for name, l := range limiters {
		t.Run(name+" cost", func(t *testing.T) {
			allowed, ret, err := l.Allow(ctx, 0, 10, 1, "foo")
			require.False(t, allowed)
			require.Nil(t, ret)
			require.Error(t, err, "cost must be greater than 0")
		})
		t.Run(name+" rate", func(t *testing.T) {
			allowed, ret, err := l.Allow(ctx, 1, 0, 1, "foo")
			require.False(t, allowed)
			require.Nil(t, ret)
			require.Error(t, err, "rate must be greater than 0")
		})
		t.Run(name+" window", func(t *testing.T) {
			allowed, ret, err := l.Allow(ctx, 1, 10, 0, "foo")
			require.False(t, allowed)
			require.Nil(t, ret)
			require.Error(t, err, "window must be greater than 0")
		})
		t.Run(name+" key", func(t *testing.T) {
			allowed, ret, err := l.Allow(ctx, 1, 10, 1, "")
			require.False(t, allowed)
			require.Nil(t, ret)
			require.Error(t, err, "key must not be empty")
		})
	}
}
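
// validateArgsSketch restates the input contract asserted in TestBadData as a sketch
// (the limiter's own validation may differ in wording and error types): cost, rate and
// window must be greater than 0, and key must not be empty.
func validateArgsSketch(cost, rate, window int64, key string) error {
	switch {
	case cost <= 0:
		return fmt.Errorf("cost must be greater than 0")
	case rate <= 0:
		return fmt.Errorf("rate must be greater than 0")
	case window <= 0:
		return fmt.Errorf("window must be greater than 0")
	case key == "":
		return fmt.Errorf("key must not be empty")
	}
	return nil
}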

func TestRetryAfter(t *testing.T) {
	pool, err := dockertest.NewPool("")
	require.NoError(t, err)

	type testCase struct {
		name                      string
		limiter                   *Limiter
		rate                      int64
		window                    int64
		runFor                    time.Duration
		warmUp                    bool
		expectedAllowedCount      int
		expectedAllowedCountDelta float64
		expectedSleepsCount       int
		expectedSleepsCountDelta  float64
	}

	var (
		ctx       = context.Background()
		rc        = bootstrapRedis(ctx, t, pool)
		testCases = []testCase{
			{
				name:                      "gcra",
				limiter:                   newLimiter(t, WithInMemoryGCRA(0)),
				rate:                      2,
				window:                    1,
				runFor:                    3 * time.Second,
				warmUp:                    true,
				expectedAllowedCount:      6,
				expectedAllowedCountDelta: 3,
				expectedSleepsCount:       3,
				expectedSleepsCountDelta:  2,
			},
			{
				name:                      "gcra redis",
				limiter:                   newLimiter(t, WithRedisGCRA(rc, 100)),
				rate:                      2,
				window:                    1,
				runFor:                    3 * time.Second,
				warmUp:                    true,
				expectedAllowedCount:      6,
				expectedAllowedCountDelta: 3,
				expectedSleepsCount:       3,
				expectedSleepsCountDelta:  2,
			},
			{
				name:                      "sorted sets redis",
				limiter:                   newLimiter(t, WithRedisSortedSet(rc)),
				rate:                      2,
				window:                    1,
				runFor:                    3 * time.Second,
				warmUp:                    false,
				expectedAllowedCount:      6,
				expectedAllowedCountDelta: 0, // this algorithm is the most precise but requires more memory on Redis
				expectedSleepsCount:       3,
				expectedSleepsCountDelta:  0, // this algorithm is the most precise but requires more memory on Redis
			},
		}
	)

	flakinessRate := 1 // increase to run the tests multiple times in a row to debug flaky tests
	for i := 0; i < flakinessRate; i++ {
		for _, tc := range testCases {
			t.Run(testName(tc.name, tc.rate, tc.window), func(t *testing.T) {
				timeout := time.NewTimer(tc.runFor)
				t.Cleanup(func() {
					timeout.Stop()
				})

				if tc.warmUp {
					timer := time.NewTimer(time.Duration(tc.window) * time.Second)
				warmUp:
					for {
						_, _, err := tc.limiter.Allow(ctx, 1, tc.rate, tc.window, t.Name())
						require.NoError(t, err)
						select {
						case <-timer.C:
							break warmUp
						default:
						}
					}
				}

				var (
					allowedCount int
					sleepsCount  int
				)

			loop:
				for {
					allowed, retryAfter, _, err := tc.limiter.AllowAfter(ctx, 1, tc.rate, tc.window, t.Name())
					require.NoError(t, err)

					t.Logf("allowed: %v, retryAfter: %v", allowed, retryAfter)

					if allowed {
						require.EqualValues(t, 0, retryAfter)
						allowedCount++
					} else {
						require.Greater(t, retryAfter, int64(0))
						sleepsCount++
					}

					select {
					case <-ctx.Done():
						break loop
					case <-timeout.C:
						break loop
					case <-time.After(retryAfter):
						t.Logf("slept for %v", retryAfter)
					}
				}

				require.InDelta(t, tc.expectedAllowedCount, allowedCount, tc.expectedAllowedCountDelta)
				require.InDelta(t, tc.expectedSleepsCount, sleepsCount, tc.expectedSleepsCountDelta)
			})
		}
	}
}
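
// allowAfterSketch is a minimal sketch of the retry-after pattern driven by
// TestRetryAfter, assuming AllowAfter's retry hint has the time.Duration type used with
// time.After above: when a request is not allowed, the caller sleeps for the hinted
// duration instead of busy-polling. The attempt limit is illustrative.
func allowAfterSketch(ctx context.Context, l *Limiter, rate, window int64, key string) (bool, error) {
	for attempt := 0; attempt < 10; attempt++ {
		allowed, retryAfter, _, err := l.AllowAfter(ctx, 1, rate, window, key)
		if err != nil {
			return false, err
		}
		if allowed {
			return true, nil
		}
		select {
		case <-ctx.Done():
			return false, ctx.Err()
		case <-time.After(retryAfter): // wait for the limiter's hint before retrying
		}
	}
	return false, nil
}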

func testName(name string, rate, window int64) string {
	return fmt.Sprintf("%s/%d tokens per %ds", name, rate, window)
}