github.com/zorawar87/trillian@v1.2.1/quota/etcd/storage/quota_storage_test.go (about) 1 // Copyright 2017 Google Inc. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package storage 16 17 import ( 18 "context" 19 "fmt" 20 "os" 21 "strings" 22 "testing" 23 "time" 24 25 "github.com/coreos/etcd/clientv3" 26 "github.com/golang/protobuf/proto" 27 "github.com/google/trillian/quota" 28 "github.com/google/trillian/quota/etcd/storagepb" 29 "github.com/google/trillian/testonly/integration/etcd" 30 "github.com/google/trillian/util" 31 "github.com/kylelemons/godebug/pretty" 32 ) 33 34 const ( 35 quotaMaxTokens = int64(quota.MaxTokens) 36 ) 37 38 var ( 39 cfgs = &storagepb.Configs{ 40 Configs: []*storagepb.Config{ 41 { 42 Name: "quotas/global/read/config", 43 State: storagepb.Config_DISABLED, 44 MaxTokens: 1, 45 ReplenishmentStrategy: &storagepb.Config_TimeBased{ 46 TimeBased: &storagepb.TimeBasedStrategy{ 47 ReplenishIntervalSeconds: 100, 48 TokensToReplenish: 10000, 49 }, 50 }, 51 }, 52 { 53 Name: "quotas/global/write/config", 54 State: storagepb.Config_ENABLED, 55 MaxTokens: 100, 56 ReplenishmentStrategy: &storagepb.Config_SequencingBased{ 57 SequencingBased: &storagepb.SequencingBasedStrategy{}, 58 }, 59 }, 60 { 61 Name: "quotas/users/llama/read/config", 62 State: storagepb.Config_ENABLED, 63 MaxTokens: 1000, 64 ReplenishmentStrategy: &storagepb.Config_TimeBased{ 65 TimeBased: &storagepb.TimeBasedStrategy{ 66 ReplenishIntervalSeconds: 50, 67 TokensToReplenish: 500, 68 }, 69 }, 70 }, 71 }, 72 } 73 globalRead = cfgs.Configs[0] 74 globalWrite = cfgs.Configs[1] 75 userRead = cfgs.Configs[2] 76 77 fixedTimeSource = util.NewFakeTimeSource(time.Now()) 78 79 // client is an etcd client. 80 // Initialized by TestMain(). 81 client *clientv3.Client 82 ) 83 84 func TestMain(m *testing.M) { 85 _, c, cleanup, err := etcd.StartEtcd() 86 if err != nil { 87 panic(fmt.Sprintf("StartEtcd() returned err = %v", err)) 88 } 89 client = c 90 exitCode := m.Run() 91 cleanup() 92 os.Exit(exitCode) 93 } 94 95 func TestIsNameValid(t *testing.T) { 96 tests := []struct { 97 name string 98 want bool 99 }{ 100 {name: "quotas/global/read/config", want: true}, 101 {name: "quotas/global/write/config", want: true}, 102 {name: "quotas/trees/12356/read/config", want: true}, 103 {name: "quotas/users/llama/write/config", want: true}, 104 105 {name: "bad/quota/name"}, 106 {name: "badprefix/quotas/global/read/config"}, 107 {name: "quotas/global/read/config/badsuffix"}, 108 {name: "quotas/bad/read/config"}, 109 {name: "quotas/global/bad/config"}, 110 {name: "quotas/trees/bad/read/config"}, 111 {name: "quotas/trees/11111111111111111111/read/config"}, // ID > MaxInt64 112 } 113 for _, test := range tests { 114 if got := IsNameValid(test.name); got != test.want { 115 t.Errorf("IsNameValid(%q) = %v, want = %v", test.name, got, test.want) 116 } 117 } 118 } 119 120 func TestQuotaStorage_UpdateConfigs(t *testing.T) { 121 defer setupTimeSource(fixedTimeSource)() 122 123 empty := &storagepb.Configs{} 124 125 cfgs2 := deepCopy(cfgs) 126 cfgs2.Configs = cfgs2.Configs[1:] // Remove global/read 127 cfgs2.Configs[0].MaxTokens = 50 // decrease global/write 128 cfgs2.Configs[1].MaxTokens = 10000 // increase user/read 129 130 treeWriteName := "quotas/trees/12345/write/config" 131 cfgs3 := deepCopy(cfgs) 132 cfgs3.Configs = append(cfgs3.Configs, &storagepb.Config{ 133 Name: treeWriteName, 134 State: storagepb.Config_ENABLED, 135 MaxTokens: 200, 136 ReplenishmentStrategy: &storagepb.Config_SequencingBased{ 137 SequencingBased: &storagepb.SequencingBasedStrategy{}, 138 }, 139 }) 140 141 // Note: tests are incremental, not isolated. The preceding test will have impact on the 142 // next, specially if reset is set to false. 143 tests := []struct { 144 desc string 145 reset bool 146 wantCfgs *storagepb.Configs 147 wantTokens map[string]int64 148 }{ 149 { 150 desc: "empty", 151 reset: true, 152 wantCfgs: empty, 153 }, 154 { 155 desc: "cfgs", 156 wantCfgs: cfgs, 157 wantTokens: map[string]int64{ 158 globalRead.Name: quotaMaxTokens, // disabled 159 globalWrite.Name: 100, 160 userRead.Name: 1000, 161 }, 162 }, 163 { 164 desc: "cfgs2", 165 wantCfgs: cfgs2, 166 wantTokens: map[string]int64{ 167 globalWrite.Name: 50, // correctly decreased 168 userRead.Name: 1000, // unaltered 169 }, 170 }, 171 { 172 desc: "cfgs3", 173 wantCfgs: cfgs3, 174 wantTokens: map[string]int64{ 175 globalWrite.Name: 50, // unaltered due to reset = false 176 userRead.Name: 1000, // unaltered 177 treeWriteName: 200, // new 178 }, 179 }, 180 { 181 desc: "cfgs3-pt2", 182 reset: true, 183 wantCfgs: cfgs3, 184 wantTokens: map[string]int64{ 185 globalWrite.Name: 100, // correctly reset 186 userRead.Name: 1000, 187 treeWriteName: 200, 188 }, 189 }, 190 { 191 desc: "cfgs-pt2", 192 wantCfgs: cfgs, 193 wantTokens: map[string]int64{ 194 globalWrite.Name: 100, 195 userRead.Name: 1000, 196 treeWriteName: quotaMaxTokens, // deleted / infinite 197 }, 198 }, 199 } 200 201 ctx := context.Background() 202 qs := &QuotaStorage{Client: client} 203 for _, test := range tests { 204 cfgs, err := qs.UpdateConfigs(ctx, test.reset, updater(test.wantCfgs)) 205 if err != nil { 206 t.Errorf("%v: UpdateConfigs() returned err = %v", test.desc, err) 207 continue 208 } 209 if got, want := cfgs, test.wantCfgs; !proto.Equal(got, want) { 210 diff := pretty.Compare(got, want) 211 t.Errorf("%v: post-UpdateConfigs() diff (-got +want)\n%v", test.desc, diff) 212 } 213 214 stored, err := qs.Configs(ctx) 215 if err != nil { 216 t.Errorf("%v:Configs() returned err = %v", test.desc, err) 217 continue 218 } 219 if got, want := stored, cfgs; !proto.Equal(got, want) { 220 diff := pretty.Compare(got, want) 221 t.Errorf("%v: post-Configs() diff (-got +want)\n%v", test.desc, diff) 222 } 223 224 if err := peekAndDiff(ctx, qs, test.wantTokens); err != nil { 225 t.Errorf("%v: %v", test.desc, err) 226 } 227 } 228 } 229 230 func TestQuotaStorage_UpdateConfigsErrors(t *testing.T) { 231 globalWriteCfgs := &storagepb.Configs{Configs: []*storagepb.Config{globalWrite}} 232 233 emptyName := deepCopy(globalWriteCfgs) 234 emptyName.Configs[0].Name = "" 235 236 invalidName1 := deepCopy(globalWriteCfgs) 237 invalidName1.Configs[0].Name = "invalid" 238 239 invalidName2 := deepCopy(globalWriteCfgs) 240 invalidName2.Configs[0].Name = "quotas/tree/1234/write" // should be "trees", plural 241 242 unknownState := deepCopy(globalWriteCfgs) 243 unknownState.Configs[0].State = storagepb.Config_UNKNOWN_CONFIG_STATE 244 245 zeroMaxTokens := deepCopy(globalWriteCfgs) 246 zeroMaxTokens.Configs[0].MaxTokens = 0 247 248 invalidMaxTokens := deepCopy(globalWriteCfgs) 249 invalidMaxTokens.Configs[0].MaxTokens = -1 250 251 noReplenishmentStrategy := deepCopy(globalWriteCfgs) 252 noReplenishmentStrategy.Configs[0].ReplenishmentStrategy = nil 253 254 zeroTimeBasedTokens := deepCopy(globalWriteCfgs) 255 zeroTimeBasedTokens.Configs[0].ReplenishmentStrategy = &storagepb.Config_TimeBased{ 256 TimeBased: &storagepb.TimeBasedStrategy{ 257 TokensToReplenish: 0, 258 ReplenishIntervalSeconds: 10, 259 }, 260 } 261 262 invalidTimeBasedTokens := deepCopy(globalWriteCfgs) 263 invalidTimeBasedTokens.Configs[0].ReplenishmentStrategy = &storagepb.Config_TimeBased{ 264 TimeBased: &storagepb.TimeBasedStrategy{ 265 TokensToReplenish: -1, 266 ReplenishIntervalSeconds: 10, 267 }, 268 } 269 270 zeroReplenishInterval := deepCopy(globalWriteCfgs) 271 zeroReplenishInterval.Configs[0].ReplenishmentStrategy = &storagepb.Config_TimeBased{ 272 TimeBased: &storagepb.TimeBasedStrategy{ 273 TokensToReplenish: 1, 274 ReplenishIntervalSeconds: 0, 275 }, 276 } 277 278 invalidReplenishInterval := deepCopy(globalWriteCfgs) 279 invalidReplenishInterval.Configs[0].ReplenishmentStrategy = &storagepb.Config_TimeBased{ 280 TimeBased: &storagepb.TimeBasedStrategy{ 281 TokensToReplenish: 1, 282 ReplenishIntervalSeconds: -1, 283 }, 284 } 285 286 duplicateNames := &storagepb.Configs{Configs: []*storagepb.Config{globalRead, globalWrite, globalWrite}} 287 288 sequencingBasedStrategy := &storagepb.Config_SequencingBased{SequencingBased: &storagepb.SequencingBasedStrategy{}} 289 sequencingBasedUserQuota := &storagepb.Configs{ 290 Configs: []*storagepb.Config{ 291 { 292 Name: userRead.Name, 293 State: userRead.State, 294 MaxTokens: userRead.MaxTokens, 295 ReplenishmentStrategy: sequencingBasedStrategy, 296 }, 297 }, 298 } 299 300 sequencingBasedReadQuota1 := deepCopy(globalWriteCfgs) 301 sequencingBasedReadQuota1.Configs[0].Name = globalRead.Name 302 sequencingBasedReadQuota1.Configs[0].ReplenishmentStrategy = sequencingBasedStrategy 303 304 sequencingBasedReadQuota2 := deepCopy(globalWriteCfgs) 305 sequencingBasedReadQuota2.Configs[0].Name = "quotas/trees/1234/read/config" 306 sequencingBasedReadQuota2.Configs[0].ReplenishmentStrategy = sequencingBasedStrategy 307 308 tests := []struct { 309 desc string 310 update func(*storagepb.Configs) 311 wantErr string 312 }{ 313 {desc: "nil", wantErr: "function required"}, 314 { 315 desc: "emptyName", 316 update: updater(emptyName), 317 wantErr: "name is required", 318 }, 319 { 320 desc: "invalidName1", 321 update: updater(invalidName1), 322 wantErr: "name malformed", 323 }, 324 { 325 desc: "invalidName2", 326 update: updater(invalidName2), 327 wantErr: "name malformed", 328 }, 329 { 330 desc: "unknownState", 331 update: updater(unknownState), 332 wantErr: "state invalid", 333 }, 334 { 335 desc: "zeroMaxTokens", 336 update: updater(zeroMaxTokens), 337 wantErr: "max tokens must be > 0", 338 }, 339 { 340 desc: "invalidMaxTokens", 341 update: updater(invalidMaxTokens), 342 wantErr: "max tokens must be > 0", 343 }, 344 { 345 desc: "noReplenishmentStrategy", 346 update: updater(noReplenishmentStrategy), 347 wantErr: "unsupported replenishment strategy", 348 }, 349 { 350 desc: "zeroTimeBasedTokens", 351 update: updater(zeroTimeBasedTokens), 352 wantErr: "time based tokens must be > 0", 353 }, 354 { 355 desc: "invalidTimeBasedTokens", 356 update: updater(invalidTimeBasedTokens), 357 wantErr: "time based tokens must be > 0", 358 }, 359 { 360 desc: "zeroReplenishInterval", 361 update: updater(zeroReplenishInterval), 362 wantErr: "replenish interval must be > 0", 363 }, 364 { 365 desc: "invalidReplenishInterval", 366 update: updater(invalidReplenishInterval), 367 wantErr: "replenish interval must be > 0", 368 }, 369 { 370 desc: "duplicateNames", 371 update: updater(duplicateNames), 372 wantErr: "duplicate config name", 373 }, 374 { 375 desc: "sequencingBasedUserQuota", 376 update: updater(sequencingBasedUserQuota), 377 wantErr: "cannot use sequencing-based replenishment", 378 }, 379 { 380 desc: "sequencingBasedReadQuota1", 381 update: updater(sequencingBasedReadQuota1), 382 wantErr: "cannot use sequencing-based replenishment", 383 }, 384 { 385 desc: "sequencingBasedReadQuota2", 386 update: updater(sequencingBasedReadQuota2), 387 wantErr: "cannot use sequencing-based replenishment", 388 }, 389 } 390 391 ctx := context.Background() 392 qs := &QuotaStorage{Client: client} 393 394 want := &storagepb.Configs{} // default cfgs is empty 395 if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(want)); err != nil { 396 t.Fatalf("UpdateConfigs() returned err = %v", err) 397 } 398 399 for _, test := range tests { 400 if _, err := qs.UpdateConfigs(ctx, false /* reset */, test.update); !strings.Contains(err.Error(), test.wantErr) { 401 // Fatal because the config has been changed, which will break all following tests. 402 t.Fatalf("%v: UpdateConfigs() returned err = %v, want substring %q", test.desc, err, test.wantErr) 403 } 404 405 stored, err := qs.Configs(ctx) 406 if err != nil { 407 t.Errorf("%v:Configs() returned err = %v", test.desc, err) 408 continue 409 } 410 if got := stored; !proto.Equal(got, want) { 411 diff := pretty.Compare(got, want) 412 t.Fatalf("%v: post-Configs() diff (-got +want)\n%v", test.desc, diff) 413 } 414 } 415 } 416 417 func TestQuotaStorage_DeletedConfig(t *testing.T) { 418 defer setupTimeSource(fixedTimeSource)() 419 420 ctx := context.Background() 421 qs := &QuotaStorage{Client: client} 422 423 cfgs := deepCopy(cfgs) 424 cfgs.Configs = cfgs.Configs[1:2] // Only global/write 425 globalWrite := cfgs.Configs[0] 426 if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil { 427 t.Fatalf("UpdateConfigs() returned err = %v", err) 428 } 429 430 // Normal quota behavior 431 names := []string{globalWrite.Name} 432 _ = qs.Get(ctx, names, 100) 433 if err := peekAndDiff(ctx, qs, map[string]int64{globalWrite.Name: globalWrite.MaxTokens - 100}); err != nil { 434 t.Fatalf("peekAndDiff returned err = %v", err) 435 } 436 437 // Deleted: considered infinite 438 cfgs = &storagepb.Configs{} 439 if _, err := qs.UpdateConfigs(ctx, false /* reset */, updater(cfgs)); err != nil { 440 t.Fatalf("UpdateConfigs() returned err = %v", err) 441 } 442 if err := peekAndDiff(ctx, qs, map[string]int64{globalWrite.Name: quotaMaxTokens}); err != nil { 443 t.Fatalf("peekAndDiff returned err = %v", err) 444 } 445 446 // Restored: must behave as new (ie, doesn't "revive" the old token count) 447 cfgs = &storagepb.Configs{Configs: []*storagepb.Config{globalWrite}} 448 if _, err := qs.UpdateConfigs(ctx, false /* reset */, updater(cfgs)); err != nil { 449 t.Fatalf("UpdateConfigs() returned err = %v", err) 450 } 451 if err := peekAndDiff(ctx, qs, map[string]int64{globalWrite.Name: globalWrite.MaxTokens}); err != nil { 452 t.Fatalf("peekAndDiff returned err = %v", err) 453 } 454 } 455 456 func TestQuotaStorage_DisabledConfig(t *testing.T) { 457 defer setupTimeSource(fixedTimeSource)() 458 459 ctx := context.Background() 460 qs := &QuotaStorage{Client: client} 461 462 cfgs := deepCopy(cfgs) 463 cfgs.Configs = cfgs.Configs[0:1] // Only global/read 464 globalRead := cfgs.Configs[0] 465 globalRead.State = storagepb.Config_ENABLED 466 globalRead.MaxTokens = 1000 467 if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil { 468 t.Fatalf("UpdateConfigs() returned err = %v", err) 469 } 470 471 // Normal quota behavior 472 names := []string{globalRead.Name} 473 _ = qs.Get(ctx, names, 100) 474 if err := peekAndDiff(ctx, qs, map[string]int64{globalRead.Name: globalRead.MaxTokens - 100}); err != nil { 475 t.Fatalf("peekAndDiff returned err = %v", err) 476 } 477 478 // Disabled: cfg still exists, but is considered infinite 479 globalRead.State = storagepb.Config_DISABLED 480 if _, err := qs.UpdateConfigs(ctx, false /* reset */, updater(cfgs)); err != nil { 481 t.Fatalf("UpdateConfigs() returned err = %v", err) 482 } 483 if err := peekAndDiff(ctx, qs, map[string]int64{globalRead.Name: quotaMaxTokens}); err != nil { 484 t.Fatalf("peekAndDiff returned err = %v", err) 485 } 486 487 // Enabled: tokens restored to ceiling, even though reset = false 488 globalRead.State = storagepb.Config_ENABLED 489 if _, err := qs.UpdateConfigs(ctx, false /* reset */, updater(cfgs)); err != nil { 490 t.Fatalf("UpdateConfigs() returned err = %v", err) 491 } 492 if err := peekAndDiff(ctx, qs, map[string]int64{globalRead.Name: globalRead.MaxTokens}); err != nil { 493 t.Fatalf("peekAndDiff returned err = %v", err) 494 } 495 } 496 497 func TestQuotaStorage_Get(t *testing.T) { 498 fakeTime := util.NewFakeTimeSource(time.Now()) 499 setupTimeSource(fakeTime) 500 501 tests := []struct { 502 desc string 503 names []string 504 tokens int64 505 nowIncrement time.Duration 506 initialTokens, wantTokens map[string]int64 507 }{ 508 { 509 desc: "success", 510 names: []string{globalRead.Name, globalWrite.Name, userRead.Name}, 511 tokens: 5, 512 wantTokens: map[string]int64{ 513 globalRead.Name: quotaMaxTokens, // disabled 514 globalWrite.Name: globalWrite.MaxTokens - 5, 515 userRead.Name: userRead.MaxTokens - 5, 516 }, 517 }, 518 { 519 desc: "globalOnly", 520 names: []string{globalWrite.Name}, 521 tokens: 7, 522 wantTokens: map[string]int64{ 523 globalWrite.Name: globalWrite.MaxTokens - 7, 524 userRead.Name: userRead.MaxTokens, 525 }, 526 }, 527 { 528 desc: "userOnly", 529 names: []string{userRead.Name}, 530 tokens: 7, 531 wantTokens: map[string]int64{ 532 globalWrite.Name: globalWrite.MaxTokens, 533 userRead.Name: userRead.MaxTokens - 7, 534 }, 535 }, 536 { 537 desc: "zeroTokens", 538 names: []string{globalWrite.Name, userRead.Name}, 539 tokens: 0, 540 wantTokens: map[string]int64{ 541 globalWrite.Name: globalWrite.MaxTokens, 542 userRead.Name: userRead.MaxTokens, 543 }, 544 }, 545 { 546 desc: "successWithReplenishment", 547 names: []string{globalWrite.Name, userRead.Name}, 548 tokens: 5, 549 nowIncrement: time.Duration(userRead.GetTimeBased().ReplenishIntervalSeconds) * time.Second, 550 wantTokens: map[string]int64{ 551 globalWrite.Name: globalWrite.MaxTokens - 5, 552 userRead.Name: userRead.MaxTokens - 5, // Replenished then deduced 553 }, 554 }, 555 { 556 desc: "successDueToReplenishment", 557 names: []string{globalWrite.Name, userRead.Name}, 558 tokens: 1, 559 nowIncrement: time.Duration(userRead.GetTimeBased().ReplenishIntervalSeconds) * time.Second, 560 initialTokens: map[string]int64{ 561 userRead.Name: 0, 562 }, 563 wantTokens: map[string]int64{ 564 globalWrite.Name: globalWrite.MaxTokens - 1, 565 userRead.Name: userRead.GetTimeBased().TokensToReplenish - 1, 566 }, 567 }, 568 } 569 570 ctx := context.Background() 571 qs := &QuotaStorage{Client: client} 572 for _, test := range tests { 573 if err := setupTokens(ctx, qs, cfgs, test.initialTokens); err != nil { 574 t.Errorf("%v: setupTokens() returned err = %v", test.desc, err) 575 continue 576 } 577 578 fakeTime.Set(fakeTime.Now().Add(test.nowIncrement)) 579 if err := qs.Get(ctx, test.names, test.tokens); err != nil { 580 t.Errorf("%v: Get() returned err = %v", test.desc, err) 581 continue 582 } 583 584 if err := peekAndDiff(ctx, qs, test.wantTokens); err != nil { 585 t.Errorf("%v: %v", test.desc, err) 586 } 587 } 588 } 589 590 func TestQuotaStorage_GetErrors(t *testing.T) { 591 tests := []struct { 592 desc string 593 names []string 594 tokens int64 595 }{ 596 { 597 desc: "invalidTokens", 598 names: []string{globalWrite.Name, userRead.Name}, 599 tokens: -1, 600 }, 601 { 602 desc: "insufficientTokens", 603 names: []string{globalWrite.Name, userRead.Name}, 604 tokens: globalWrite.MaxTokens + 10, 605 }, 606 } 607 608 ctx := context.Background() 609 qs := &QuotaStorage{Client: client} 610 if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil { 611 t.Fatalf("UpdateConfigs() returned err = %v", err) 612 } 613 614 for _, test := range tests { 615 if err := qs.Get(ctx, test.names, test.tokens); err == nil { 616 t.Errorf("%v: Get() returned err = nil, want non-nil", test.desc) 617 } 618 } 619 } 620 621 func TestQuotaStorage_Peek(t *testing.T) { 622 fakeTime := util.NewFakeTimeSource(time.Now()) 623 defer setupTimeSource(fakeTime)() 624 625 tests := []struct { 626 desc string 627 names []string 628 nowIncrement time.Duration 629 initialTokens, wantTokens map[string]int64 630 }{ 631 { 632 desc: "success", 633 names: []string{globalRead.Name, globalWrite.Name, userRead.Name, "quotas/users/llama/write/config"}, 634 wantTokens: map[string]int64{ 635 globalRead.Name: quotaMaxTokens, // disabled 636 globalWrite.Name: globalWrite.MaxTokens, 637 userRead.Name: userRead.MaxTokens, 638 "quotas/users/llama/write/config": quotaMaxTokens, // unknown 639 }, 640 }, 641 { 642 desc: "timeBasedReplenish", 643 names: []string{globalWrite.Name, userRead.Name}, 644 nowIncrement: time.Duration(userRead.GetTimeBased().ReplenishIntervalSeconds) * time.Second, 645 initialTokens: map[string]int64{ 646 globalWrite.Name: 10, 647 userRead.Name: 10, 648 }, 649 wantTokens: map[string]int64{ 650 globalWrite.Name: 10, 651 userRead.Name: 10 + userRead.GetTimeBased().TokensToReplenish, 652 }, 653 }, 654 } 655 656 ctx := context.Background() 657 qs := &QuotaStorage{Client: client} 658 for _, test := range tests { 659 if err := setupTokens(ctx, qs, cfgs, test.initialTokens); err != nil { 660 t.Errorf("%v: setupTokens() returned err = %v", test.desc, err) 661 continue 662 } 663 664 fakeTime.Set(fakeTime.Now().Add(test.nowIncrement)) 665 if err := peekAndDiff(ctx, qs, test.wantTokens); err != nil { 666 t.Errorf("%v: %v", test.desc, err) 667 } 668 } 669 } 670 671 func TestQuotaStorage_Put(t *testing.T) { 672 fakeTime := util.NewFakeTimeSource(time.Now()) 673 defer setupTimeSource(fakeTime)() 674 675 tests := []struct { 676 desc string 677 names []string 678 tokens int64 679 nowIncrement time.Duration 680 initialTokens, wantTokens map[string]int64 681 }{ 682 { 683 desc: "zero", 684 names: []string{globalWrite.Name, userRead.Name}, 685 tokens: 0, 686 wantTokens: map[string]int64{ 687 globalWrite.Name: globalWrite.MaxTokens, 688 userRead.Name: userRead.MaxTokens, 689 }, 690 }, 691 { 692 desc: "success", 693 names: []string{globalRead.Name, globalWrite.Name, userRead.Name}, 694 tokens: 10, 695 initialTokens: map[string]int64{ 696 globalWrite.Name: 10, 697 userRead.Name: 10, 698 }, 699 wantTokens: map[string]int64{ 700 globalRead.Name: quotaMaxTokens, // disabled 701 globalWrite.Name: 20, 702 userRead.Name: 10, // Time-based quotas don't change on Put() 703 }, 704 }, 705 { 706 desc: "fullQuota", 707 names: []string{globalWrite.Name, userRead.Name}, 708 tokens: 10, 709 wantTokens: map[string]int64{ 710 globalWrite.Name: globalWrite.MaxTokens, 711 userRead.Name: userRead.MaxTokens, 712 }, 713 }, 714 { 715 desc: "replenishToFull", 716 names: []string{userRead.Name}, 717 tokens: 0, 718 nowIncrement: time.Duration(userRead.GetTimeBased().ReplenishIntervalSeconds) * time.Second, 719 initialTokens: map[string]int64{ 720 userRead.Name: userRead.MaxTokens - 1, 721 }, 722 wantTokens: map[string]int64{ 723 userRead.Name: userRead.MaxTokens, 724 }, 725 }, 726 { 727 desc: "partialReplenish", 728 names: []string{userRead.Name}, 729 tokens: 100, 730 nowIncrement: time.Duration(userRead.GetTimeBased().ReplenishIntervalSeconds) * time.Second, 731 initialTokens: map[string]int64{ 732 userRead.Name: 0, 733 }, 734 wantTokens: map[string]int64{ 735 userRead.Name: userRead.GetTimeBased().TokensToReplenish, 736 }, 737 }, 738 } 739 740 ctx := context.Background() 741 qs := &QuotaStorage{Client: client} 742 for _, test := range tests { 743 if err := setupTokens(ctx, qs, cfgs, test.initialTokens); err != nil { 744 t.Errorf("%v: setupTokens() returned err = %v", test.desc, err) 745 continue 746 } 747 748 if err := qs.Put(ctx, test.names, test.tokens); err != nil { 749 t.Errorf("%v: Put() returned err = %v", test.desc, err) 750 } 751 752 fakeTime.Set(fakeTime.Now().Add(test.nowIncrement)) 753 if err := peekAndDiff(ctx, qs, test.wantTokens); err != nil { 754 t.Errorf("%v: %v", test.desc, err) 755 } 756 } 757 } 758 759 func TestQuotaStorage_PutErrors(t *testing.T) { 760 tests := []struct { 761 desc string 762 names []string 763 tokens int64 764 }{ 765 {desc: "invalidTokens", names: []string{globalWrite.Name}, tokens: -1}, 766 } 767 768 ctx := context.Background() 769 qs := &QuotaStorage{Client: client} 770 if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil { 771 t.Fatalf("UpdateConfigs() returned err = %v", err) 772 } 773 774 for _, test := range tests { 775 if err := qs.Put(ctx, test.names, test.tokens); err == nil { 776 t.Errorf("%v: Put() returned err = nil, want non-nil", test.desc) 777 } 778 } 779 } 780 781 func TestQuotaStorage_Reset(t *testing.T) { 782 defer setupTimeSource(fixedTimeSource)() 783 784 tests := []struct { 785 desc string 786 names []string 787 initialTokens, wantTokens map[string]int64 788 }{ 789 { 790 desc: "success", 791 names: []string{globalRead.Name, globalWrite.Name, userRead.Name}, 792 initialTokens: map[string]int64{ 793 globalWrite.Name: 10, 794 userRead.Name: 10, 795 }, 796 wantTokens: map[string]int64{ 797 globalRead.Name: quotaMaxTokens, // disabled 798 globalWrite.Name: globalWrite.MaxTokens, 799 userRead.Name: userRead.MaxTokens, 800 }, 801 }, 802 { 803 desc: "globalWrite", 804 names: []string{globalWrite.Name}, 805 initialTokens: map[string]int64{ 806 globalWrite.Name: 10, 807 }, 808 wantTokens: map[string]int64{ 809 globalWrite.Name: globalWrite.MaxTokens, 810 }, 811 }, 812 { 813 desc: "userRead", 814 names: []string{userRead.Name}, 815 initialTokens: map[string]int64{ 816 userRead.Name: 10, 817 }, 818 wantTokens: map[string]int64{ 819 userRead.Name: userRead.MaxTokens, 820 }, 821 }, 822 { 823 desc: "fullQuotas", 824 names: []string{globalWrite.Name, userRead.Name}, 825 wantTokens: map[string]int64{ 826 globalWrite.Name: globalWrite.MaxTokens, 827 userRead.Name: userRead.MaxTokens, 828 }, 829 }, 830 { 831 desc: "unknownQuota", 832 names: []string{"quotas/users/llama/write/config"}, 833 wantTokens: map[string]int64{ 834 "quotas/users/llama/write/config": quotaMaxTokens, 835 }, 836 }, 837 } 838 839 ctx := context.Background() 840 qs := &QuotaStorage{Client: client} 841 if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil { 842 t.Fatalf("UpdateConfigs() returned err = %v", err) 843 } 844 845 for _, test := range tests { 846 if err := setupTokens(ctx, qs, cfgs, test.initialTokens); err != nil { 847 t.Errorf("%v: setupTokens() returned err = %v", test.desc, err) 848 continue 849 } 850 851 if err := qs.Reset(ctx, test.names); err != nil { 852 t.Errorf("%v: Reset() returned err = %v", test.desc, err) 853 } 854 855 if err := peekAndDiff(ctx, qs, test.wantTokens); err != nil { 856 t.Errorf("%v: %v", test.desc, err) 857 } 858 } 859 } 860 861 func TestQuotaStorage_ValidateNames(t *testing.T) { 862 fns := []struct { 863 name string 864 run func(context.Context, *QuotaStorage, []string) error 865 }{ 866 { 867 name: "Get", 868 run: func(ctx context.Context, qs *QuotaStorage, names []string) error { 869 return qs.Get(ctx, names, 0) 870 }, 871 }, 872 { 873 name: "Peek", 874 run: func(ctx context.Context, qs *QuotaStorage, names []string) error { 875 _, err := qs.Peek(ctx, names) 876 return err 877 }, 878 }, 879 { 880 name: "Put", 881 run: func(ctx context.Context, qs *QuotaStorage, names []string) error { 882 return qs.Put(ctx, names, 0) 883 }, 884 }, 885 { 886 name: "Reset", 887 run: func(ctx context.Context, qs *QuotaStorage, names []string) error { 888 return qs.Reset(ctx, names) 889 }, 890 }, 891 } 892 893 tests := []struct { 894 names []string 895 }{ 896 {names: []string{"bad/quota/name"}}, 897 {names: []string{"quotas/bad/read/configs"}}, 898 {names: []string{"quotas/global/read"}}, // missing "/configs" 899 {names: []string{"quotas/trees/1234/write"}}, 900 {names: []string{"quotas/users/llama/write"}}, 901 {names: []string{"quotas/tree/1234/read/configs"}}, // should be "trees" 902 {names: []string{"quotas/user/llama/read/configs"}}, // should be "users" 903 {names: []string{globalWrite.Name, "bad"}}, 904 } 905 906 ctx := context.Background() 907 qs := &QuotaStorage{Client: client} 908 for _, test := range tests { 909 for _, fn := range fns { 910 if err := fn.run(ctx, qs, test.names); err == nil { 911 t.Errorf("%v(%v) returned err = nil, want non-nil", fn.name, test.names) 912 } 913 } 914 } 915 } 916 917 func peekAndDiff(ctx context.Context, qs *QuotaStorage, want map[string]int64) error { 918 got, err := qs.Peek(ctx, keys(want)) 919 if err != nil { 920 return err 921 } 922 if diff := pretty.Compare(got, want); diff != "" { 923 return fmt.Errorf("post-Peek() diff (-got +want):\n%v", diff) 924 } 925 return nil 926 } 927 928 // setupTimeSource prepares timeSource for tests. 929 // A cleanup function that restores timeSource to its initial value is returned and should be 930 // defer-called. 931 func setupTimeSource(ts util.TimeSource) func() { 932 prevTimeSource := timeSource 933 timeSource = ts 934 return func() { timeSource = prevTimeSource } 935 } 936 937 // setupTokens resets cfgs and gets tokens from each quota in order to make them match 938 // initialTokens. 939 func setupTokens(ctx context.Context, qs *QuotaStorage, cfgs *storagepb.Configs, initialTokens map[string]int64) error { 940 if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil { 941 return fmt.Errorf("UpdateConfigs() returned err = %v", err) 942 } 943 for name, wantTokens := range initialTokens { 944 names := []string{name} 945 tokens, err := qs.Peek(ctx, names) 946 if err != nil { 947 return fmt.Errorf("Peek() returned err = %v", err) 948 } 949 mod := tokens[name] - wantTokens 950 if err := qs.Get(ctx, names, mod); err != nil { 951 return fmt.Errorf("Get() returned err = %v", err) 952 } 953 if err := peekAndDiff(ctx, qs, map[string]int64{name: wantTokens}); err != nil { 954 return err 955 } 956 } 957 return nil 958 } 959 960 func deepCopy(c1 *storagepb.Configs) *storagepb.Configs { 961 c2 := &storagepb.Configs{ 962 Configs: make([]*storagepb.Config, 0, len(c1.Configs)), 963 } 964 for _, cfg := range c1.Configs { 965 cp := *cfg 966 c2.Configs = append(c2.Configs, &cp) 967 } 968 return c2 969 } 970 971 func keys(m map[string]int64) []string { 972 keys := make([]string, 0, len(m)) 973 for k := range m { 974 keys = append(keys, k) 975 } 976 return keys 977 } 978 979 func updater(cfgs *storagepb.Configs) func(*storagepb.Configs) { 980 return func(c *storagepb.Configs) { 981 *c = *cfgs 982 } 983 }