github.com/sethvargo/go-limiter@v1.0.0/memorystore/store_test.go (about)

     1  package memorystore
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/sha256"
     7  	"fmt"
     8  	"sort"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/sethvargo/go-limiter/internal/fasttime"
    13  )
    14  
    15  func TestFillRate(t *testing.T) {
    16  	t.Parallel()
    17  
    18  	t.Run("many_tokens_small_interval", func(t *testing.T) {
    19  		t.Parallel()
    20  
    21  		s, _ := New(&Config{
    22  			Tokens:   65525,
    23  			Interval: time.Second,
    24  		})
    25  
    26  		for i := 0; i < 20; i++ {
    27  			limit, remaining, _, _, _ := s.Take(context.Background(), "key")
    28  			if remaining < limit-uint64(i)-1 {
    29  				t.Errorf("invalid remaining: run: %d limit: %d remaining: %d", i, limit, remaining)
    30  			}
    31  			time.Sleep(100 * time.Millisecond)
    32  		}
    33  	})
    34  }
    35  
    36  func testKey(tb testing.TB) string {
    37  	tb.Helper()
    38  
    39  	var b [512]byte
    40  	if _, err := rand.Read(b[:]); err != nil {
    41  		tb.Fatalf("failed to generate random string: %v", err)
    42  	}
    43  	digest := fmt.Sprintf("%x", sha256.Sum256(b[:]))
    44  	return digest[:32]
    45  }
    46  
    47  func TestStore_Exercise(t *testing.T) {
    48  	t.Parallel()
    49  
    50  	ctx := context.Background()
    51  
    52  	s, err := New(&Config{
    53  		Tokens:        5,
    54  		Interval:      3 * time.Second,
    55  		SweepInterval: 24 * time.Hour,
    56  		SweepMinTTL:   24 * time.Hour,
    57  	})
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  	t.Cleanup(func() {
    62  		if err := s.Close(ctx); err != nil {
    63  			t.Fatal(err)
    64  		}
    65  	})
    66  
    67  	key := testKey(t)
    68  
    69  	// Get when no config exists
    70  	{
    71  		limit, remaining, err := s.(*store).Get(ctx, key)
    72  		if err != nil {
    73  			t.Fatal(err)
    74  		}
    75  
    76  		if got, want := limit, uint64(0); got != want {
    77  			t.Errorf("expected %v to be %v", got, want)
    78  		}
    79  		if got, want := remaining, uint64(0); got != want {
    80  			t.Errorf("expected %v to be %v", got, want)
    81  		}
    82  	}
    83  
    84  	// Take with no key configuration - this should use the default values
    85  	{
    86  		limit, remaining, reset, ok, err := s.Take(ctx, key)
    87  		if err != nil {
    88  			t.Fatal(err)
    89  		}
    90  		if !ok {
    91  			t.Errorf("expected ok")
    92  		}
    93  		if got, want := limit, uint64(5); got != want {
    94  			t.Errorf("expected %v to be %v", got, want)
    95  		}
    96  		if got, want := remaining, uint64(4); got != want {
    97  			t.Errorf("expected %v to be %v", got, want)
    98  		}
    99  		if got, want := time.Until(time.Unix(0, int64(reset))), 3*time.Second; got > want {
   100  			t.Errorf("expected %v to less than %v", got, want)
   101  		}
   102  	}
   103  
   104  	// Get the value
   105  	{
   106  		limit, remaining, err := s.(*store).Get(ctx, key)
   107  		if err != nil {
   108  			t.Fatal(err)
   109  		}
   110  		if got, want := limit, uint64(5); got != want {
   111  			t.Errorf("expected %v to be %v", got, want)
   112  		}
   113  		if got, want := remaining, uint64(4); got != want {
   114  			t.Errorf("expected %v to be %v", got, want)
   115  		}
   116  	}
   117  
   118  	// Now set a value
   119  	{
   120  		if err := s.Set(ctx, key, 11, 5*time.Second); err != nil {
   121  			t.Fatal(err)
   122  		}
   123  	}
   124  
   125  	// Get the value again
   126  	{
   127  		limit, remaining, err := s.(*store).Get(ctx, key)
   128  		if err != nil {
   129  			t.Fatal(err)
   130  		}
   131  		if got, want := limit, uint64(11); got != want {
   132  			t.Errorf("expected %v to be %v", got, want)
   133  		}
   134  		if got, want := remaining, uint64(11); got != want {
   135  			t.Errorf("expected %v to be %v", got, want)
   136  		}
   137  	}
   138  
   139  	// Take again, this should use the new values
   140  	{
   141  		limit, remaining, reset, ok, err := s.Take(ctx, key)
   142  		if err != nil {
   143  			t.Fatal(err)
   144  		}
   145  		if !ok {
   146  			t.Errorf("expected ok")
   147  		}
   148  		if got, want := limit, uint64(11); got != want {
   149  			t.Errorf("expected %v to be %v", got, want)
   150  		}
   151  		if got, want := remaining, uint64(10); got != want {
   152  			t.Errorf("expected %v to be %v", got, want)
   153  		}
   154  		if got, want := time.Until(time.Unix(0, int64(reset))), 5*time.Second; got > want {
   155  			t.Errorf("expected %v to less than %v", got, want)
   156  		}
   157  	}
   158  
   159  	// Get the value again
   160  	{
   161  		limit, remaining, err := s.(*store).Get(ctx, key)
   162  		if err != nil {
   163  			t.Fatal(err)
   164  		}
   165  		if got, want := limit, uint64(11); got != want {
   166  			t.Errorf("expected %v to be %v", got, want)
   167  		}
   168  		if got, want := remaining, uint64(10); got != want {
   169  			t.Errorf("expected %v to be %v", got, want)
   170  		}
   171  	}
   172  
   173  	// Burst and take
   174  	{
   175  		if err := s.Burst(ctx, key, 5); err != nil {
   176  			t.Fatal(err)
   177  		}
   178  
   179  		limit, remaining, reset, ok, err := s.Take(ctx, key)
   180  		if err != nil {
   181  			t.Fatal(err)
   182  		}
   183  		if !ok {
   184  			t.Errorf("expected ok")
   185  		}
   186  		if got, want := limit, uint64(11); got != want {
   187  			t.Errorf("expected %v to be %v", got, want)
   188  		}
   189  		if got, want := remaining, uint64(14); got != want {
   190  			t.Errorf("expected %v to be %v", got, want)
   191  		}
   192  		if got, want := time.Until(time.Unix(0, int64(reset))), 5*time.Second; got > want {
   193  			t.Errorf("expected %v to less than %v", got, want)
   194  		}
   195  	}
   196  
   197  	// Get the value one final time
   198  	{
   199  		limit, remaining, err := s.(*store).Get(ctx, key)
   200  		if err != nil {
   201  			t.Fatal(err)
   202  		}
   203  		if got, want := limit, uint64(11); got != want {
   204  			t.Errorf("expected %v to be %v", got, want)
   205  		}
   206  		if got, want := remaining, uint64(14); got != want {
   207  			t.Errorf("expected %v to be %v", got, want)
   208  		}
   209  	}
   210  }
   211  
   212  func TestStore_Take(t *testing.T) {
   213  	t.Parallel()
   214  
   215  	ctx := context.Background()
   216  
   217  	cases := []struct {
   218  		name     string
   219  		tokens   uint64
   220  		interval time.Duration
   221  	}{
   222  		{
   223  			name:     "milli",
   224  			tokens:   5,
   225  			interval: 500 * time.Millisecond,
   226  		},
   227  		{
   228  			name:     "second",
   229  			tokens:   10,
   230  			interval: 1 * time.Second,
   231  		},
   232  	}
   233  
   234  	for _, tc := range cases {
   235  		tc := tc
   236  
   237  		t.Run(tc.name, func(t *testing.T) {
   238  			t.Parallel()
   239  
   240  			key := testKey(t)
   241  
   242  			s, err := New(&Config{
   243  				Interval:      tc.interval,
   244  				Tokens:        tc.tokens,
   245  				SweepInterval: 24 * time.Hour,
   246  				SweepMinTTL:   24 * time.Hour,
   247  			})
   248  			if err != nil {
   249  				t.Fatal(err)
   250  			}
   251  			t.Cleanup(func() {
   252  				if err := s.Close(ctx); err != nil {
   253  					t.Fatal(err)
   254  				}
   255  			})
   256  
   257  			type result struct {
   258  				limit, remaining uint64
   259  				reset            time.Duration
   260  				ok               bool
   261  				err              error
   262  			}
   263  
   264  			// Take twice everything from the bucket.
   265  			takeCh := make(chan *result, 2*tc.tokens)
   266  			for i := uint64(1); i <= 2*tc.tokens; i++ {
   267  				go func() {
   268  					limit, remaining, reset, ok, err := s.Take(ctx, key)
   269  					takeCh <- &result{limit, remaining, time.Duration(fasttime.Now() - reset), ok, err}
   270  				}()
   271  			}
   272  
   273  			// Accumulate and sort results, since they could come in any order.
   274  			var results []*result
   275  			for i := uint64(1); i <= 2*tc.tokens; i++ {
   276  				select {
   277  				case result := <-takeCh:
   278  					results = append(results, result)
   279  				case <-time.After(5 * time.Second):
   280  					t.Fatal("timeout")
   281  				}
   282  			}
   283  			sort.Slice(results, func(i, j int) bool {
   284  				if results[i].remaining == results[j].remaining {
   285  					return !results[j].ok
   286  				}
   287  				return results[i].remaining > results[j].remaining
   288  			})
   289  
   290  			for i, result := range results {
   291  				if err := result.err; err != nil {
   292  					t.Fatal(err)
   293  				}
   294  
   295  				if got, want := result.limit, tc.tokens; got != want {
   296  					t.Errorf("limit: expected %d to be %d", got, want)
   297  				}
   298  				if got, want := result.reset, tc.interval; got > want {
   299  					t.Errorf("reset: expected %d to be less than %d", got, want)
   300  				}
   301  
   302  				// first half should pass, second half should fail
   303  				if uint64(i) < tc.tokens {
   304  					if got, want := result.remaining, tc.tokens-uint64(i)-1; got != want {
   305  						t.Errorf("remaining: expected %d to be %d", got, want)
   306  					}
   307  					if got, want := result.ok, true; got != want {
   308  						t.Errorf("ok: expected %t to be %t", got, want)
   309  					}
   310  				} else {
   311  					if got, want := result.remaining, uint64(0); got != want {
   312  						t.Errorf("remaining: expected %d to be %d", got, want)
   313  					}
   314  					if got, want := result.ok, false; got != want {
   315  						t.Errorf("ok: expected %t to be %t", got, want)
   316  					}
   317  				}
   318  			}
   319  
   320  			// Wait for the bucket to have entries again.
   321  			time.Sleep(tc.interval)
   322  
   323  			// Verify we can take once more.
   324  			_, _, _, ok, err := s.Take(ctx, key)
   325  			if err != nil {
   326  				t.Fatal(err)
   327  			}
   328  			if !ok {
   329  				t.Errorf("expected %t to be %t", ok, true)
   330  			}
   331  		})
   332  	}
   333  }
   334  
   335  func TestBucketedLimiter_tick(t *testing.T) {
   336  	t.Parallel()
   337  
   338  	cases := []struct {
   339  		name     string
   340  		start    uint64
   341  		curr     uint64
   342  		interval time.Duration
   343  		exp      uint64
   344  	}{
   345  		{
   346  			name:     "no_diff",
   347  			start:    0,
   348  			curr:     0,
   349  			interval: time.Second,
   350  			exp:      0,
   351  		},
   352  		{
   353  			name:     "half",
   354  			start:    0,
   355  			curr:     uint64(500 * time.Millisecond),
   356  			interval: time.Second,
   357  			exp:      0,
   358  		},
   359  		{
   360  			name:     "almost",
   361  			start:    0,
   362  			curr:     uint64(1*time.Second - time.Nanosecond),
   363  			interval: time.Second,
   364  			exp:      0,
   365  		},
   366  		{
   367  			name:     "exact",
   368  			start:    0,
   369  			curr:     uint64(1 * time.Second),
   370  			interval: time.Second,
   371  			exp:      1,
   372  		},
   373  		{
   374  			name:     "multiple",
   375  			start:    0,
   376  			curr:     uint64(50*time.Second - 500*time.Millisecond),
   377  			interval: time.Second,
   378  			exp:      49,
   379  		},
   380  		{
   381  			name:     "short",
   382  			start:    0,
   383  			curr:     uint64(50*time.Second - 500*time.Millisecond),
   384  			interval: time.Millisecond,
   385  			exp:      49500,
   386  		},
   387  	}
   388  
   389  	for _, tc := range cases {
   390  		tc := tc
   391  
   392  		t.Run(tc.name, func(t *testing.T) {
   393  			t.Parallel()
   394  
   395  			if got, want := tick(tc.start, tc.curr, tc.interval), tc.exp; got != want {
   396  				t.Errorf("expected %v to be %v", got, want)
   397  			}
   398  		})
   399  	}
   400  }