github.com/sethvargo/go-limiter@v1.0.0/memorystore/store_test.go (about) 1 package memorystore 2 3 import ( 4 "context" 5 "crypto/rand" 6 "crypto/sha256" 7 "fmt" 8 "sort" 9 "testing" 10 "time" 11 12 "github.com/sethvargo/go-limiter/internal/fasttime" 13 ) 14 15 func TestFillRate(t *testing.T) { 16 t.Parallel() 17 18 t.Run("many_tokens_small_interval", func(t *testing.T) { 19 t.Parallel() 20 21 s, _ := New(&Config{ 22 Tokens: 65525, 23 Interval: time.Second, 24 }) 25 26 for i := 0; i < 20; i++ { 27 limit, remaining, _, _, _ := s.Take(context.Background(), "key") 28 if remaining < limit-uint64(i)-1 { 29 t.Errorf("invalid remaining: run: %d limit: %d remaining: %d", i, limit, remaining) 30 } 31 time.Sleep(100 * time.Millisecond) 32 } 33 }) 34 } 35 36 func testKey(tb testing.TB) string { 37 tb.Helper() 38 39 var b [512]byte 40 if _, err := rand.Read(b[:]); err != nil { 41 tb.Fatalf("failed to generate random string: %v", err) 42 } 43 digest := fmt.Sprintf("%x", sha256.Sum256(b[:])) 44 return digest[:32] 45 } 46 47 func TestStore_Exercise(t *testing.T) { 48 t.Parallel() 49 50 ctx := context.Background() 51 52 s, err := New(&Config{ 53 Tokens: 5, 54 Interval: 3 * time.Second, 55 SweepInterval: 24 * time.Hour, 56 SweepMinTTL: 24 * time.Hour, 57 }) 58 if err != nil { 59 t.Fatal(err) 60 } 61 t.Cleanup(func() { 62 if err := s.Close(ctx); err != nil { 63 t.Fatal(err) 64 } 65 }) 66 67 key := testKey(t) 68 69 // Get when no config exists 70 { 71 limit, remaining, err := s.(*store).Get(ctx, key) 72 if err != nil { 73 t.Fatal(err) 74 } 75 76 if got, want := limit, uint64(0); got != want { 77 t.Errorf("expected %v to be %v", got, want) 78 } 79 if got, want := remaining, uint64(0); got != want { 80 t.Errorf("expected %v to be %v", got, want) 81 } 82 } 83 84 // Take with no key configuration - this should use the default values 85 { 86 limit, remaining, reset, ok, err := s.Take(ctx, key) 87 if err != nil { 88 t.Fatal(err) 89 } 90 if !ok { 91 t.Errorf("expected ok") 92 } 93 if got, want := limit, uint64(5); got != want { 94 t.Errorf("expected %v to be %v", got, want) 95 } 96 if got, want := remaining, uint64(4); got != want { 97 t.Errorf("expected %v to be %v", got, want) 98 } 99 if got, want := time.Until(time.Unix(0, int64(reset))), 3*time.Second; got > want { 100 t.Errorf("expected %v to less than %v", got, want) 101 } 102 } 103 104 // Get the value 105 { 106 limit, remaining, err := s.(*store).Get(ctx, key) 107 if err != nil { 108 t.Fatal(err) 109 } 110 if got, want := limit, uint64(5); got != want { 111 t.Errorf("expected %v to be %v", got, want) 112 } 113 if got, want := remaining, uint64(4); got != want { 114 t.Errorf("expected %v to be %v", got, want) 115 } 116 } 117 118 // Now set a value 119 { 120 if err := s.Set(ctx, key, 11, 5*time.Second); err != nil { 121 t.Fatal(err) 122 } 123 } 124 125 // Get the value again 126 { 127 limit, remaining, err := s.(*store).Get(ctx, key) 128 if err != nil { 129 t.Fatal(err) 130 } 131 if got, want := limit, uint64(11); got != want { 132 t.Errorf("expected %v to be %v", got, want) 133 } 134 if got, want := remaining, uint64(11); got != want { 135 t.Errorf("expected %v to be %v", got, want) 136 } 137 } 138 139 // Take again, this should use the new values 140 { 141 limit, remaining, reset, ok, err := s.Take(ctx, key) 142 if err != nil { 143 t.Fatal(err) 144 } 145 if !ok { 146 t.Errorf("expected ok") 147 } 148 if got, want := limit, uint64(11); got != want { 149 t.Errorf("expected %v to be %v", got, want) 150 } 151 if got, want := remaining, uint64(10); got != want { 152 t.Errorf("expected %v to be %v", got, want) 153 } 154 if got, want := time.Until(time.Unix(0, int64(reset))), 5*time.Second; got > want { 155 t.Errorf("expected %v to less than %v", got, want) 156 } 157 } 158 159 // Get the value again 160 { 161 limit, remaining, err := s.(*store).Get(ctx, key) 162 if err != nil { 163 t.Fatal(err) 164 } 165 if got, want := limit, uint64(11); got != want { 166 t.Errorf("expected %v to be %v", got, want) 167 } 168 if got, want := remaining, uint64(10); got != want { 169 t.Errorf("expected %v to be %v", got, want) 170 } 171 } 172 173 // Burst and take 174 { 175 if err := s.Burst(ctx, key, 5); err != nil { 176 t.Fatal(err) 177 } 178 179 limit, remaining, reset, ok, err := s.Take(ctx, key) 180 if err != nil { 181 t.Fatal(err) 182 } 183 if !ok { 184 t.Errorf("expected ok") 185 } 186 if got, want := limit, uint64(11); got != want { 187 t.Errorf("expected %v to be %v", got, want) 188 } 189 if got, want := remaining, uint64(14); got != want { 190 t.Errorf("expected %v to be %v", got, want) 191 } 192 if got, want := time.Until(time.Unix(0, int64(reset))), 5*time.Second; got > want { 193 t.Errorf("expected %v to less than %v", got, want) 194 } 195 } 196 197 // Get the value one final time 198 { 199 limit, remaining, err := s.(*store).Get(ctx, key) 200 if err != nil { 201 t.Fatal(err) 202 } 203 if got, want := limit, uint64(11); got != want { 204 t.Errorf("expected %v to be %v", got, want) 205 } 206 if got, want := remaining, uint64(14); got != want { 207 t.Errorf("expected %v to be %v", got, want) 208 } 209 } 210 } 211 212 func TestStore_Take(t *testing.T) { 213 t.Parallel() 214 215 ctx := context.Background() 216 217 cases := []struct { 218 name string 219 tokens uint64 220 interval time.Duration 221 }{ 222 { 223 name: "milli", 224 tokens: 5, 225 interval: 500 * time.Millisecond, 226 }, 227 { 228 name: "second", 229 tokens: 10, 230 interval: 1 * time.Second, 231 }, 232 } 233 234 for _, tc := range cases { 235 tc := tc 236 237 t.Run(tc.name, func(t *testing.T) { 238 t.Parallel() 239 240 key := testKey(t) 241 242 s, err := New(&Config{ 243 Interval: tc.interval, 244 Tokens: tc.tokens, 245 SweepInterval: 24 * time.Hour, 246 SweepMinTTL: 24 * time.Hour, 247 }) 248 if err != nil { 249 t.Fatal(err) 250 } 251 t.Cleanup(func() { 252 if err := s.Close(ctx); err != nil { 253 t.Fatal(err) 254 } 255 }) 256 257 type result struct { 258 limit, remaining uint64 259 reset time.Duration 260 ok bool 261 err error 262 } 263 264 // Take twice everything from the bucket. 265 takeCh := make(chan *result, 2*tc.tokens) 266 for i := uint64(1); i <= 2*tc.tokens; i++ { 267 go func() { 268 limit, remaining, reset, ok, err := s.Take(ctx, key) 269 takeCh <- &result{limit, remaining, time.Duration(fasttime.Now() - reset), ok, err} 270 }() 271 } 272 273 // Accumulate and sort results, since they could come in any order. 274 var results []*result 275 for i := uint64(1); i <= 2*tc.tokens; i++ { 276 select { 277 case result := <-takeCh: 278 results = append(results, result) 279 case <-time.After(5 * time.Second): 280 t.Fatal("timeout") 281 } 282 } 283 sort.Slice(results, func(i, j int) bool { 284 if results[i].remaining == results[j].remaining { 285 return !results[j].ok 286 } 287 return results[i].remaining > results[j].remaining 288 }) 289 290 for i, result := range results { 291 if err := result.err; err != nil { 292 t.Fatal(err) 293 } 294 295 if got, want := result.limit, tc.tokens; got != want { 296 t.Errorf("limit: expected %d to be %d", got, want) 297 } 298 if got, want := result.reset, tc.interval; got > want { 299 t.Errorf("reset: expected %d to be less than %d", got, want) 300 } 301 302 // first half should pass, second half should fail 303 if uint64(i) < tc.tokens { 304 if got, want := result.remaining, tc.tokens-uint64(i)-1; got != want { 305 t.Errorf("remaining: expected %d to be %d", got, want) 306 } 307 if got, want := result.ok, true; got != want { 308 t.Errorf("ok: expected %t to be %t", got, want) 309 } 310 } else { 311 if got, want := result.remaining, uint64(0); got != want { 312 t.Errorf("remaining: expected %d to be %d", got, want) 313 } 314 if got, want := result.ok, false; got != want { 315 t.Errorf("ok: expected %t to be %t", got, want) 316 } 317 } 318 } 319 320 // Wait for the bucket to have entries again. 321 time.Sleep(tc.interval) 322 323 // Verify we can take once more. 324 _, _, _, ok, err := s.Take(ctx, key) 325 if err != nil { 326 t.Fatal(err) 327 } 328 if !ok { 329 t.Errorf("expected %t to be %t", ok, true) 330 } 331 }) 332 } 333 } 334 335 func TestBucketedLimiter_tick(t *testing.T) { 336 t.Parallel() 337 338 cases := []struct { 339 name string 340 start uint64 341 curr uint64 342 interval time.Duration 343 exp uint64 344 }{ 345 { 346 name: "no_diff", 347 start: 0, 348 curr: 0, 349 interval: time.Second, 350 exp: 0, 351 }, 352 { 353 name: "half", 354 start: 0, 355 curr: uint64(500 * time.Millisecond), 356 interval: time.Second, 357 exp: 0, 358 }, 359 { 360 name: "almost", 361 start: 0, 362 curr: uint64(1*time.Second - time.Nanosecond), 363 interval: time.Second, 364 exp: 0, 365 }, 366 { 367 name: "exact", 368 start: 0, 369 curr: uint64(1 * time.Second), 370 interval: time.Second, 371 exp: 1, 372 }, 373 { 374 name: "multiple", 375 start: 0, 376 curr: uint64(50*time.Second - 500*time.Millisecond), 377 interval: time.Second, 378 exp: 49, 379 }, 380 { 381 name: "short", 382 start: 0, 383 curr: uint64(50*time.Second - 500*time.Millisecond), 384 interval: time.Millisecond, 385 exp: 49500, 386 }, 387 } 388 389 for _, tc := range cases { 390 tc := tc 391 392 t.Run(tc.name, func(t *testing.T) { 393 t.Parallel() 394 395 if got, want := tick(tc.start, tc.curr, tc.interval), tc.exp; got != want { 396 t.Errorf("expected %v to be %v", got, want) 397 } 398 }) 399 } 400 }