gitlab.com/infor-cloud/martian-cloud/tharsis/go-limiter@v0.0.0-20230411193226-3247984d5abc/memorystore/store_test.go (about)

     1  package memorystore
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/sha256"
     7  	"fmt"
     8  	"sort"
     9  	"testing"
    10  	"time"
    11  
    12  	"gitlab.com/infor-cloud/martian-cloud/tharsis/go-limiter/fasttime"
    13  )
    14  
    15  func TestFillRate(t *testing.T) {
    16  	t.Parallel()
    17  
    18  	t.Run("many_tokens_small_interval", func(t *testing.T) {
    19  		t.Parallel()
    20  
    21  		s, _ := New(&Config{
    22  			Tokens:   65525,
    23  			Interval: time.Second,
    24  		})
    25  
    26  		for i := 0; i < 20; i++ {
    27  			limit, remaining, _, _, _ := s.Take(context.Background(), "key")
    28  			if remaining < limit-uint64(i)-1 {
    29  				t.Errorf("invalid remaining: run: %d limit: %d remaining: %d", i, limit, remaining)
    30  			}
    31  			time.Sleep(100 * time.Millisecond)
    32  		}
    33  	})
    34  }
    35  
    36  func testKey(tb testing.TB) string {
    37  	tb.Helper()
    38  
    39  	var b [512]byte
    40  	if _, err := rand.Read(b[:]); err != nil {
    41  		tb.Fatalf("failed to generate random string: %v", err)
    42  	}
    43  	digest := fmt.Sprintf("%x", sha256.Sum256(b[:]))
    44  	return digest[:32]
    45  }
    46  
    47  func TestStore_Exercise(t *testing.T) {
    48  	t.Parallel()
    49  
    50  	ctx := context.Background()
    51  
    52  	s, err := New(&Config{
    53  		Tokens:        5,
    54  		Interval:      3 * time.Second,
    55  		SweepInterval: 24 * time.Hour,
    56  		SweepMinTTL:   24 * time.Hour,
    57  	})
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  	t.Cleanup(func() {
    62  		if err := s.Close(ctx); err != nil {
    63  			t.Fatal(err)
    64  		}
    65  	})
    66  
    67  	key := testKey(t)
    68  
    69  	// Get when no config exists
    70  	{
    71  		limit, remaining, err := s.(*store).Get(ctx, key)
    72  		if err != nil {
    73  			t.Fatal(err)
    74  		}
    75  
    76  		if got, want := limit, uint64(0); got != want {
    77  			t.Errorf("expected %v to be %v", got, want)
    78  		}
    79  		if got, want := remaining, uint64(0); got != want {
    80  			t.Errorf("expected %v to be %v", got, want)
    81  		}
    82  	}
    83  
    84  	// Take with no key configuration - this should use the default values
    85  	{
    86  		limit, remaining, reset, ok, err := s.Take(ctx, key)
    87  		if err != nil {
    88  			t.Fatal(err)
    89  		}
    90  		if !ok {
    91  			t.Errorf("expected ok")
    92  		}
    93  		if got, want := limit, uint64(5); got != want {
    94  			t.Errorf("expected %v to be %v", got, want)
    95  		}
    96  		if got, want := remaining, uint64(4); got != want {
    97  			t.Errorf("expected %v to be %v", got, want)
    98  		}
    99  		if got, want := time.Until(time.Unix(0, int64(reset))), 3*time.Second; got > want {
   100  			t.Errorf("expected %v to less than %v", got, want)
   101  		}
   102  	}
   103  
   104  	// Get the value
   105  	{
   106  		limit, remaining, err := s.(*store).Get(ctx, key)
   107  		if err != nil {
   108  			t.Fatal(err)
   109  		}
   110  		if got, want := limit, uint64(5); got != want {
   111  			t.Errorf("expected %v to be %v", got, want)
   112  		}
   113  		if got, want := remaining, uint64(4); got != want {
   114  			t.Errorf("expected %v to be %v", got, want)
   115  		}
   116  	}
   117  
   118  	// Now set a value
   119  	{
   120  		if err := s.Set(ctx, key, 11, 5*time.Second); err != nil {
   121  			t.Fatal(err)
   122  		}
   123  	}
   124  
   125  	// Get the value again
   126  	{
   127  		limit, remaining, err := s.(*store).Get(ctx, key)
   128  		if err != nil {
   129  			t.Fatal(err)
   130  		}
   131  		if got, want := limit, uint64(11); got != want {
   132  			t.Errorf("expected %v to be %v", got, want)
   133  		}
   134  		if got, want := remaining, uint64(11); got != want {
   135  			t.Errorf("expected %v to be %v", got, want)
   136  		}
   137  	}
   138  
   139  	// Take again, this should use the new values
   140  	{
   141  		limit, remaining, reset, ok, err := s.Take(ctx, key)
   142  		if err != nil {
   143  			t.Fatal(err)
   144  		}
   145  		if !ok {
   146  			t.Errorf("expected ok")
   147  		}
   148  		if got, want := limit, uint64(11); got != want {
   149  			t.Errorf("expected %v to be %v", got, want)
   150  		}
   151  		if got, want := remaining, uint64(10); got != want {
   152  			t.Errorf("expected %v to be %v", got, want)
   153  		}
   154  		if got, want := time.Until(time.Unix(0, int64(reset))), 5*time.Second; got > want {
   155  			t.Errorf("expected %v to less than %v", got, want)
   156  		}
   157  	}
   158  
   159  	// Get the value again
   160  	{
   161  		limit, remaining, err := s.(*store).Get(ctx, key)
   162  		if err != nil {
   163  			t.Fatal(err)
   164  		}
   165  		if got, want := limit, uint64(11); got != want {
   166  			t.Errorf("expected %v to be %v", got, want)
   167  		}
   168  		if got, want := remaining, uint64(10); got != want {
   169  			t.Errorf("expected %v to be %v", got, want)
   170  		}
   171  	}
   172  
   173  	// Burst and take
   174  	{
   175  		if err := s.Burst(ctx, key, 5); err != nil {
   176  			t.Fatal(err)
   177  		}
   178  
   179  		limit, remaining, reset, ok, err := s.Take(ctx, key)
   180  		if err != nil {
   181  			t.Fatal(err)
   182  		}
   183  		if !ok {
   184  			t.Errorf("expected ok")
   185  		}
   186  		if got, want := limit, uint64(11); got != want {
   187  			t.Errorf("expected %v to be %v", got, want)
   188  		}
   189  		if got, want := remaining, uint64(14); got != want {
   190  			t.Errorf("expected %v to be %v", got, want)
   191  		}
   192  		if got, want := time.Until(time.Unix(0, int64(reset))), 5*time.Second; got > want {
   193  			t.Errorf("expected %v to less than %v", got, want)
   194  		}
   195  	}
   196  
   197  	// Get the value one final time
   198  	{
   199  		limit, remaining, err := s.(*store).Get(ctx, key)
   200  		if err != nil {
   201  			t.Fatal(err)
   202  		}
   203  		if got, want := limit, uint64(11); got != want {
   204  			t.Errorf("expected %v to be %v", got, want)
   205  		}
   206  		if got, want := remaining, uint64(14); got != want {
   207  			t.Errorf("expected %v to be %v", got, want)
   208  		}
   209  	}
   210  }
   211  
   212  func TestStore_Take(t *testing.T) {
   213  	t.Parallel()
   214  
   215  	ctx := context.Background()
   216  
   217  	cases := []struct {
   218  		name     string
   219  		tokens   uint64
   220  		interval time.Duration
   221  	}{
   222  		{
   223  			name:     "milli",
   224  			tokens:   5,
   225  			interval: 500 * time.Millisecond,
   226  		},
   227  		{
   228  			name:     "second",
   229  			tokens:   10,
   230  			interval: 1 * time.Second,
   231  		},
   232  	}
   233  
   234  	for _, tc := range cases {
   235  		tc := tc
   236  
   237  		t.Run(tc.name, func(t *testing.T) {
   238  			t.Parallel()
   239  
   240  			key := testKey(t)
   241  
   242  			s, err := New(&Config{
   243  				Interval:      tc.interval,
   244  				Tokens:        tc.tokens,
   245  				SweepInterval: 24 * time.Hour,
   246  				SweepMinTTL:   24 * time.Hour,
   247  			})
   248  			if err != nil {
   249  				t.Fatal(err)
   250  			}
   251  			t.Cleanup(func() {
   252  				if err := s.Close(ctx); err != nil {
   253  					t.Fatal(err)
   254  				}
   255  			})
   256  
   257  			type result struct {
   258  				limit, remaining uint64
   259  				reset            time.Duration
   260  				ok               bool
   261  				err              error
   262  			}
   263  
   264  			// Take twice everything from the bucket.
   265  			takeCh := make(chan *result, 2*tc.tokens)
   266  			for i := uint64(1); i <= 2*tc.tokens; i++ {
   267  				go func() {
   268  					limit, remaining, reset, ok, err := s.Take(ctx, key)
   269  					takeCh <- &result{limit, remaining, time.Duration(fasttime.Now() - reset), ok, err}
   270  				}()
   271  			}
   272  
   273  			// Accumulate and sort results, since they could come in any order.
   274  			var results []*result
   275  			for i := uint64(1); i <= 2*tc.tokens; i++ {
   276  				select {
   277  				case result := <-takeCh:
   278  					results = append(results, result)
   279  				case <-time.After(5 * time.Second):
   280  					t.Fatal("timeout")
   281  				}
   282  			}
   283  			sort.Slice(results, func(i, j int) bool {
   284  				if results[i].remaining == results[j].remaining {
   285  					return !results[j].ok
   286  				}
   287  				return results[i].remaining > results[j].remaining
   288  			})
   289  
   290  			for i, result := range results {
   291  				if err := result.err; err != nil {
   292  					t.Fatal(err)
   293  				}
   294  
   295  				if got, want := result.limit, tc.tokens; got != want {
   296  					t.Errorf("limit: expected %d to be %d", got, want)
   297  				}
   298  				if got, want := result.reset, tc.interval; got > want {
   299  					t.Errorf("reset: expected %d to be less than %d", got, want)
   300  				}
   301  
   302  				// first half should pass, second half should fail
   303  				if uint64(i) < tc.tokens {
   304  					if got, want := result.remaining, tc.tokens-uint64(i)-1; got != want {
   305  						t.Errorf("remaining: expected %d to be %d", got, want)
   306  					}
   307  					if got, want := result.ok, true; got != want {
   308  						t.Errorf("ok: expected %t to be %t", got, want)
   309  					}
   310  				} else {
   311  					if got, want := result.remaining, uint64(0); got != want {
   312  						t.Errorf("remaining: expected %d to be %d", got, want)
   313  					}
   314  					if got, want := result.ok, false; got != want {
   315  						t.Errorf("ok: expected %t to be %t", got, want)
   316  					}
   317  				}
   318  			}
   319  
   320  			// Wait for the bucket to have entries again.
   321  			time.Sleep(tc.interval)
   322  
   323  			// Verify we can take once more.
   324  			_, _, _, ok, err := s.Take(ctx, key)
   325  			if err != nil {
   326  				t.Fatal(err)
   327  			}
   328  			if !ok {
   329  				t.Errorf("expected %t to be %t", ok, true)
   330  			}
   331  		})
   332  	}
   333  }
   334  
   335  func TestStore_TakeMany(t *testing.T) {
   336  	t.Parallel()
   337  
   338  	ctx := context.Background()
   339  
   340  	cases := []struct {
   341  		name       string
   342  		tokens     uint64
   343  		takeAmount uint64
   344  		interval   time.Duration
   345  	}{
   346  		{
   347  			name:       "milli",
   348  			tokens:     10,
   349  			takeAmount: 2,
   350  			interval:   500 * time.Millisecond,
   351  		},
   352  		{
   353  			name:       "second",
   354  			tokens:     10,
   355  			takeAmount: 4,
   356  			interval:   1 * time.Second,
   357  		},
   358  	}
   359  
   360  	for _, tc := range cases {
   361  		tc := tc
   362  
   363  		t.Run(tc.name, func(t *testing.T) {
   364  			t.Parallel()
   365  
   366  			key := testKey(t)
   367  
   368  			s, err := New(&Config{
   369  				Interval:      tc.interval,
   370  				Tokens:        tc.tokens,
   371  				SweepInterval: 24 * time.Hour,
   372  				SweepMinTTL:   24 * time.Hour,
   373  			})
   374  			if err != nil {
   375  				t.Fatal(err)
   376  			}
   377  			t.Cleanup(func() {
   378  				if err := s.Close(ctx); err != nil {
   379  					t.Fatal(err)
   380  				}
   381  			})
   382  
   383  			type result struct {
   384  				limit, remaining uint64
   385  				reset            time.Duration
   386  				ok               bool
   387  				err              error
   388  			}
   389  
   390  			// Take everything from the bucket.
   391  			//possibleIterations calculates possible iterations for test case
   392  			possibleIterations := tc.tokens / tc.takeAmount
   393  			takeCh := make(chan *result, 2*possibleIterations)
   394  			for i := uint64(1); i <= 2*possibleIterations; i++ {
   395  				go func() {
   396  					limit, remaining, reset, ok, err := s.TakeMany(ctx, key, tc.takeAmount)
   397  					takeCh <- &result{limit, remaining, time.Duration(fasttime.Now() - reset), ok, err}
   398  				}()
   399  			}
   400  
   401  			// Accumulate and sort results, since they could come in any order.
   402  			var results []*result
   403  			for i := uint64(1); i <= 2*possibleIterations; i++ {
   404  				select {
   405  				case result := <-takeCh:
   406  					results = append(results, result)
   407  				case <-time.After(5 * time.Second):
   408  					t.Fatal("timeout")
   409  				}
   410  			}
   411  			sort.Slice(results, func(i, j int) bool {
   412  				if results[i].remaining == results[j].remaining {
   413  					return !results[j].ok
   414  				}
   415  				return results[i].remaining > results[j].remaining
   416  			})
   417  
   418  			for i, result := range results {
   419  				if err := result.err; err != nil {
   420  					t.Fatal(err)
   421  				}
   422  				if got, want := result.limit, tc.tokens; got != want {
   423  					t.Errorf("limit: expected %d to be %d", got, want)
   424  				}
   425  				if got, want := result.reset, tc.interval; got > want {
   426  					t.Errorf("reset: expected %d to be less than %d", got, want)
   427  				}
   428  
   429  				// first half should pass, second half should fail
   430  				if uint64(i+1)*tc.takeAmount <= tc.tokens {
   431  					if got, want := result.remaining, tc.tokens-uint64(i+1)*tc.takeAmount; got != want {
   432  						t.Errorf("remaining: expected %d to be %d", got, want)
   433  					}
   434  					if got, want := result.ok, true; got != want {
   435  						t.Errorf("ok: expected %t to be %t", got, want)
   436  					}
   437  				} else {
   438  					if got, want := result.remaining, uint64(0); got != want {
   439  						t.Errorf("remaining: expected %d to be %d", got, want)
   440  					}
   441  					if got, want := result.ok, false; got != want {
   442  						t.Errorf("ok 0: expected %t to be %t", got, want)
   443  					}
   444  				}
   445  			}
   446  
   447  			// Wait for the bucket to have entries again.
   448  			time.Sleep(tc.interval)
   449  
   450  			// Verify we can take once more.
   451  			_, _, _, ok, err := s.TakeMany(ctx, key, tc.takeAmount)
   452  			if err != nil {
   453  				t.Fatal(err)
   454  			}
   455  			if !ok {
   456  				t.Errorf("expected %t to be %t", ok, true)
   457  			}
   458  		})
   459  	}
   460  }
   461  
   462  func TestBucketedLimiter_tick(t *testing.T) {
   463  	t.Parallel()
   464  
   465  	cases := []struct {
   466  		name     string
   467  		start    uint64
   468  		curr     uint64
   469  		interval time.Duration
   470  		exp      uint64
   471  	}{
   472  		{
   473  			name:     "no_diff",
   474  			start:    0,
   475  			curr:     0,
   476  			interval: time.Second,
   477  			exp:      0,
   478  		},
   479  		{
   480  			name:     "half",
   481  			start:    0,
   482  			curr:     uint64(500 * time.Millisecond),
   483  			interval: time.Second,
   484  			exp:      0,
   485  		},
   486  		{
   487  			name:     "almost",
   488  			start:    0,
   489  			curr:     uint64(1*time.Second - time.Nanosecond),
   490  			interval: time.Second,
   491  			exp:      0,
   492  		},
   493  		{
   494  			name:     "exact",
   495  			start:    0,
   496  			curr:     uint64(1 * time.Second),
   497  			interval: time.Second,
   498  			exp:      1,
   499  		},
   500  		{
   501  			name:     "multiple",
   502  			start:    0,
   503  			curr:     uint64(50*time.Second - 500*time.Millisecond),
   504  			interval: time.Second,
   505  			exp:      49,
   506  		},
   507  		{
   508  			name:     "short",
   509  			start:    0,
   510  			curr:     uint64(50*time.Second - 500*time.Millisecond),
   511  			interval: time.Millisecond,
   512  			exp:      49500,
   513  		},
   514  	}
   515  
   516  	for _, tc := range cases {
   517  		tc := tc
   518  
   519  		t.Run(tc.name, func(t *testing.T) {
   520  			t.Parallel()
   521  
   522  			if got, want := tick(tc.start, tc.curr, tc.interval), tc.exp; got != want {
   523  				t.Errorf("expected %v to be %v", got, want)
   524  			}
   525  		})
   526  	}
   527  }