gitlab.com/infor-cloud/martian-cloud/tharsis/go-limiter@v0.0.0-20230411193226-3247984d5abc/memorystore/store_test.go (about) 1 package memorystore 2 3 import ( 4 "context" 5 "crypto/rand" 6 "crypto/sha256" 7 "fmt" 8 "sort" 9 "testing" 10 "time" 11 12 "gitlab.com/infor-cloud/martian-cloud/tharsis/go-limiter/fasttime" 13 ) 14 15 func TestFillRate(t *testing.T) { 16 t.Parallel() 17 18 t.Run("many_tokens_small_interval", func(t *testing.T) { 19 t.Parallel() 20 21 s, _ := New(&Config{ 22 Tokens: 65525, 23 Interval: time.Second, 24 }) 25 26 for i := 0; i < 20; i++ { 27 limit, remaining, _, _, _ := s.Take(context.Background(), "key") 28 if remaining < limit-uint64(i)-1 { 29 t.Errorf("invalid remaining: run: %d limit: %d remaining: %d", i, limit, remaining) 30 } 31 time.Sleep(100 * time.Millisecond) 32 } 33 }) 34 } 35 36 func testKey(tb testing.TB) string { 37 tb.Helper() 38 39 var b [512]byte 40 if _, err := rand.Read(b[:]); err != nil { 41 tb.Fatalf("failed to generate random string: %v", err) 42 } 43 digest := fmt.Sprintf("%x", sha256.Sum256(b[:])) 44 return digest[:32] 45 } 46 47 func TestStore_Exercise(t *testing.T) { 48 t.Parallel() 49 50 ctx := context.Background() 51 52 s, err := New(&Config{ 53 Tokens: 5, 54 Interval: 3 * time.Second, 55 SweepInterval: 24 * time.Hour, 56 SweepMinTTL: 24 * time.Hour, 57 }) 58 if err != nil { 59 t.Fatal(err) 60 } 61 t.Cleanup(func() { 62 if err := s.Close(ctx); err != nil { 63 t.Fatal(err) 64 } 65 }) 66 67 key := testKey(t) 68 69 // Get when no config exists 70 { 71 limit, remaining, err := s.(*store).Get(ctx, key) 72 if err != nil { 73 t.Fatal(err) 74 } 75 76 if got, want := limit, uint64(0); got != want { 77 t.Errorf("expected %v to be %v", got, want) 78 } 79 if got, want := remaining, uint64(0); got != want { 80 t.Errorf("expected %v to be %v", got, want) 81 } 82 } 83 84 // Take with no key configuration - this should use the default values 85 { 86 limit, remaining, reset, ok, err := s.Take(ctx, key) 87 if err != nil { 88 t.Fatal(err) 89 } 90 if !ok { 91 t.Errorf("expected ok") 92 } 93 if got, want := limit, uint64(5); got != want { 94 t.Errorf("expected %v to be %v", got, want) 95 } 96 if got, want := remaining, uint64(4); got != want { 97 t.Errorf("expected %v to be %v", got, want) 98 } 99 if got, want := time.Until(time.Unix(0, int64(reset))), 3*time.Second; got > want { 100 t.Errorf("expected %v to less than %v", got, want) 101 } 102 } 103 104 // Get the value 105 { 106 limit, remaining, err := s.(*store).Get(ctx, key) 107 if err != nil { 108 t.Fatal(err) 109 } 110 if got, want := limit, uint64(5); got != want { 111 t.Errorf("expected %v to be %v", got, want) 112 } 113 if got, want := remaining, uint64(4); got != want { 114 t.Errorf("expected %v to be %v", got, want) 115 } 116 } 117 118 // Now set a value 119 { 120 if err := s.Set(ctx, key, 11, 5*time.Second); err != nil { 121 t.Fatal(err) 122 } 123 } 124 125 // Get the value again 126 { 127 limit, remaining, err := s.(*store).Get(ctx, key) 128 if err != nil { 129 t.Fatal(err) 130 } 131 if got, want := limit, uint64(11); got != want { 132 t.Errorf("expected %v to be %v", got, want) 133 } 134 if got, want := remaining, uint64(11); got != want { 135 t.Errorf("expected %v to be %v", got, want) 136 } 137 } 138 139 // Take again, this should use the new values 140 { 141 limit, remaining, reset, ok, err := s.Take(ctx, key) 142 if err != nil { 143 t.Fatal(err) 144 } 145 if !ok { 146 t.Errorf("expected ok") 147 } 148 if got, want := limit, uint64(11); got != want { 149 t.Errorf("expected %v to be %v", got, want) 150 } 151 if got, want := remaining, uint64(10); got != want { 152 t.Errorf("expected %v to be %v", got, want) 153 } 154 if got, want := time.Until(time.Unix(0, int64(reset))), 5*time.Second; got > want { 155 t.Errorf("expected %v to less than %v", got, want) 156 } 157 } 158 159 // Get the value again 160 { 161 limit, remaining, err := s.(*store).Get(ctx, key) 162 if err != nil { 163 t.Fatal(err) 164 } 165 if got, want := limit, uint64(11); got != want { 166 t.Errorf("expected %v to be %v", got, want) 167 } 168 if got, want := remaining, uint64(10); got != want { 169 t.Errorf("expected %v to be %v", got, want) 170 } 171 } 172 173 // Burst and take 174 { 175 if err := s.Burst(ctx, key, 5); err != nil { 176 t.Fatal(err) 177 } 178 179 limit, remaining, reset, ok, err := s.Take(ctx, key) 180 if err != nil { 181 t.Fatal(err) 182 } 183 if !ok { 184 t.Errorf("expected ok") 185 } 186 if got, want := limit, uint64(11); got != want { 187 t.Errorf("expected %v to be %v", got, want) 188 } 189 if got, want := remaining, uint64(14); got != want { 190 t.Errorf("expected %v to be %v", got, want) 191 } 192 if got, want := time.Until(time.Unix(0, int64(reset))), 5*time.Second; got > want { 193 t.Errorf("expected %v to less than %v", got, want) 194 } 195 } 196 197 // Get the value one final time 198 { 199 limit, remaining, err := s.(*store).Get(ctx, key) 200 if err != nil { 201 t.Fatal(err) 202 } 203 if got, want := limit, uint64(11); got != want { 204 t.Errorf("expected %v to be %v", got, want) 205 } 206 if got, want := remaining, uint64(14); got != want { 207 t.Errorf("expected %v to be %v", got, want) 208 } 209 } 210 } 211 212 func TestStore_Take(t *testing.T) { 213 t.Parallel() 214 215 ctx := context.Background() 216 217 cases := []struct { 218 name string 219 tokens uint64 220 interval time.Duration 221 }{ 222 { 223 name: "milli", 224 tokens: 5, 225 interval: 500 * time.Millisecond, 226 }, 227 { 228 name: "second", 229 tokens: 10, 230 interval: 1 * time.Second, 231 }, 232 } 233 234 for _, tc := range cases { 235 tc := tc 236 237 t.Run(tc.name, func(t *testing.T) { 238 t.Parallel() 239 240 key := testKey(t) 241 242 s, err := New(&Config{ 243 Interval: tc.interval, 244 Tokens: tc.tokens, 245 SweepInterval: 24 * time.Hour, 246 SweepMinTTL: 24 * time.Hour, 247 }) 248 if err != nil { 249 t.Fatal(err) 250 } 251 t.Cleanup(func() { 252 if err := s.Close(ctx); err != nil { 253 t.Fatal(err) 254 } 255 }) 256 257 type result struct { 258 limit, remaining uint64 259 reset time.Duration 260 ok bool 261 err error 262 } 263 264 // Take twice everything from the bucket. 265 takeCh := make(chan *result, 2*tc.tokens) 266 for i := uint64(1); i <= 2*tc.tokens; i++ { 267 go func() { 268 limit, remaining, reset, ok, err := s.Take(ctx, key) 269 takeCh <- &result{limit, remaining, time.Duration(fasttime.Now() - reset), ok, err} 270 }() 271 } 272 273 // Accumulate and sort results, since they could come in any order. 274 var results []*result 275 for i := uint64(1); i <= 2*tc.tokens; i++ { 276 select { 277 case result := <-takeCh: 278 results = append(results, result) 279 case <-time.After(5 * time.Second): 280 t.Fatal("timeout") 281 } 282 } 283 sort.Slice(results, func(i, j int) bool { 284 if results[i].remaining == results[j].remaining { 285 return !results[j].ok 286 } 287 return results[i].remaining > results[j].remaining 288 }) 289 290 for i, result := range results { 291 if err := result.err; err != nil { 292 t.Fatal(err) 293 } 294 295 if got, want := result.limit, tc.tokens; got != want { 296 t.Errorf("limit: expected %d to be %d", got, want) 297 } 298 if got, want := result.reset, tc.interval; got > want { 299 t.Errorf("reset: expected %d to be less than %d", got, want) 300 } 301 302 // first half should pass, second half should fail 303 if uint64(i) < tc.tokens { 304 if got, want := result.remaining, tc.tokens-uint64(i)-1; got != want { 305 t.Errorf("remaining: expected %d to be %d", got, want) 306 } 307 if got, want := result.ok, true; got != want { 308 t.Errorf("ok: expected %t to be %t", got, want) 309 } 310 } else { 311 if got, want := result.remaining, uint64(0); got != want { 312 t.Errorf("remaining: expected %d to be %d", got, want) 313 } 314 if got, want := result.ok, false; got != want { 315 t.Errorf("ok: expected %t to be %t", got, want) 316 } 317 } 318 } 319 320 // Wait for the bucket to have entries again. 321 time.Sleep(tc.interval) 322 323 // Verify we can take once more. 324 _, _, _, ok, err := s.Take(ctx, key) 325 if err != nil { 326 t.Fatal(err) 327 } 328 if !ok { 329 t.Errorf("expected %t to be %t", ok, true) 330 } 331 }) 332 } 333 } 334 335 func TestStore_TakeMany(t *testing.T) { 336 t.Parallel() 337 338 ctx := context.Background() 339 340 cases := []struct { 341 name string 342 tokens uint64 343 takeAmount uint64 344 interval time.Duration 345 }{ 346 { 347 name: "milli", 348 tokens: 10, 349 takeAmount: 2, 350 interval: 500 * time.Millisecond, 351 }, 352 { 353 name: "second", 354 tokens: 10, 355 takeAmount: 4, 356 interval: 1 * time.Second, 357 }, 358 } 359 360 for _, tc := range cases { 361 tc := tc 362 363 t.Run(tc.name, func(t *testing.T) { 364 t.Parallel() 365 366 key := testKey(t) 367 368 s, err := New(&Config{ 369 Interval: tc.interval, 370 Tokens: tc.tokens, 371 SweepInterval: 24 * time.Hour, 372 SweepMinTTL: 24 * time.Hour, 373 }) 374 if err != nil { 375 t.Fatal(err) 376 } 377 t.Cleanup(func() { 378 if err := s.Close(ctx); err != nil { 379 t.Fatal(err) 380 } 381 }) 382 383 type result struct { 384 limit, remaining uint64 385 reset time.Duration 386 ok bool 387 err error 388 } 389 390 // Take everything from the bucket. 391 //possibleIterations calculates possible iterations for test case 392 possibleIterations := tc.tokens / tc.takeAmount 393 takeCh := make(chan *result, 2*possibleIterations) 394 for i := uint64(1); i <= 2*possibleIterations; i++ { 395 go func() { 396 limit, remaining, reset, ok, err := s.TakeMany(ctx, key, tc.takeAmount) 397 takeCh <- &result{limit, remaining, time.Duration(fasttime.Now() - reset), ok, err} 398 }() 399 } 400 401 // Accumulate and sort results, since they could come in any order. 402 var results []*result 403 for i := uint64(1); i <= 2*possibleIterations; i++ { 404 select { 405 case result := <-takeCh: 406 results = append(results, result) 407 case <-time.After(5 * time.Second): 408 t.Fatal("timeout") 409 } 410 } 411 sort.Slice(results, func(i, j int) bool { 412 if results[i].remaining == results[j].remaining { 413 return !results[j].ok 414 } 415 return results[i].remaining > results[j].remaining 416 }) 417 418 for i, result := range results { 419 if err := result.err; err != nil { 420 t.Fatal(err) 421 } 422 if got, want := result.limit, tc.tokens; got != want { 423 t.Errorf("limit: expected %d to be %d", got, want) 424 } 425 if got, want := result.reset, tc.interval; got > want { 426 t.Errorf("reset: expected %d to be less than %d", got, want) 427 } 428 429 // first half should pass, second half should fail 430 if uint64(i+1)*tc.takeAmount <= tc.tokens { 431 if got, want := result.remaining, tc.tokens-uint64(i+1)*tc.takeAmount; got != want { 432 t.Errorf("remaining: expected %d to be %d", got, want) 433 } 434 if got, want := result.ok, true; got != want { 435 t.Errorf("ok: expected %t to be %t", got, want) 436 } 437 } else { 438 if got, want := result.remaining, uint64(0); got != want { 439 t.Errorf("remaining: expected %d to be %d", got, want) 440 } 441 if got, want := result.ok, false; got != want { 442 t.Errorf("ok 0: expected %t to be %t", got, want) 443 } 444 } 445 } 446 447 // Wait for the bucket to have entries again. 448 time.Sleep(tc.interval) 449 450 // Verify we can take once more. 451 _, _, _, ok, err := s.TakeMany(ctx, key, tc.takeAmount) 452 if err != nil { 453 t.Fatal(err) 454 } 455 if !ok { 456 t.Errorf("expected %t to be %t", ok, true) 457 } 458 }) 459 } 460 } 461 462 func TestBucketedLimiter_tick(t *testing.T) { 463 t.Parallel() 464 465 cases := []struct { 466 name string 467 start uint64 468 curr uint64 469 interval time.Duration 470 exp uint64 471 }{ 472 { 473 name: "no_diff", 474 start: 0, 475 curr: 0, 476 interval: time.Second, 477 exp: 0, 478 }, 479 { 480 name: "half", 481 start: 0, 482 curr: uint64(500 * time.Millisecond), 483 interval: time.Second, 484 exp: 0, 485 }, 486 { 487 name: "almost", 488 start: 0, 489 curr: uint64(1*time.Second - time.Nanosecond), 490 interval: time.Second, 491 exp: 0, 492 }, 493 { 494 name: "exact", 495 start: 0, 496 curr: uint64(1 * time.Second), 497 interval: time.Second, 498 exp: 1, 499 }, 500 { 501 name: "multiple", 502 start: 0, 503 curr: uint64(50*time.Second - 500*time.Millisecond), 504 interval: time.Second, 505 exp: 49, 506 }, 507 { 508 name: "short", 509 start: 0, 510 curr: uint64(50*time.Second - 500*time.Millisecond), 511 interval: time.Millisecond, 512 exp: 49500, 513 }, 514 } 515 516 for _, tc := range cases { 517 tc := tc 518 519 t.Run(tc.name, func(t *testing.T) { 520 t.Parallel() 521 522 if got, want := tick(tc.start, tc.curr, tc.interval), tc.exp; got != want { 523 t.Errorf("expected %v to be %v", got, want) 524 } 525 }) 526 } 527 }