github.com/letsencrypt/trillian@v1.1.2-0.20180615153820-ae375a99d36a/quota/etcd/storage/quota_storage_test.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package storage
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"os"
    21  	"strings"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/coreos/etcd/clientv3"
    26  	"github.com/golang/protobuf/proto"
    27  	"github.com/google/trillian/quota"
    28  	"github.com/google/trillian/quota/etcd/storagepb"
    29  	"github.com/google/trillian/testonly/integration/etcd"
    30  	"github.com/google/trillian/util"
    31  	"github.com/kylelemons/godebug/pretty"
    32  )
    33  
const (
	// quotaMaxTokens is quota.MaxTokens as an int64. The tests below use it as
	// the expected token count for quotas that are disabled, deleted or
	// unknown.
	quotaMaxTokens = int64(quota.MaxTokens)
)

var (
	// cfgs is the baseline set of quota configs shared by the tests in this
	// file. The entries are referenced by position through globalRead,
	// globalWrite and userRead below, so their order matters.
	cfgs = &storagepb.Configs{
		Configs: []*storagepb.Config{
			// Global read quota: time-based, but DISABLED.
			{
				Name:      "quotas/global/read/config",
				State:     storagepb.Config_DISABLED,
				MaxTokens: 1,
				ReplenishmentStrategy: &storagepb.Config_TimeBased{
					TimeBased: &storagepb.TimeBasedStrategy{
						ReplenishIntervalSeconds: 100,
						TokensToReplenish:        10000,
					},
				},
			},
			// Global write quota: sequencing-based, ENABLED.
			{
				Name:      "quotas/global/write/config",
				State:     storagepb.Config_ENABLED,
				MaxTokens: 100,
				ReplenishmentStrategy: &storagepb.Config_SequencingBased{
					SequencingBased: &storagepb.SequencingBasedStrategy{},
				},
			},
			// Per-user read quota: time-based, ENABLED.
			{
				Name:      "quotas/users/llama/read/config",
				State:     storagepb.Config_ENABLED,
				MaxTokens: 1000,
				ReplenishmentStrategy: &storagepb.Config_TimeBased{
					TimeBased: &storagepb.TimeBasedStrategy{
						ReplenishIntervalSeconds: 50,
						TokensToReplenish:        500,
					},
				},
			},
		},
	}
	// Shortcuts to the individual entries of cfgs. These are pointers into
	// cfgs itself, not copies.
	globalRead  = cfgs.Configs[0]
	globalWrite = cfgs.Configs[1]
	userRead    = cfgs.Configs[2]

	// fixedTimeSource is a fake time source created from time.Now() when the
	// package is initialized. Tests that never advance the clock install it
	// through setupTimeSource.
	fixedTimeSource = util.NewFakeTimeSource(time.Now())

	// client is an etcd client.
	// Initialized by TestMain().
	client *clientv3.Client
)
    83  
    84  func TestMain(m *testing.M) {
    85  	_, c, cleanup, err := etcd.StartEtcd()
    86  	if err != nil {
    87  		panic(fmt.Sprintf("StartEtcd() returned err = %v", err))
    88  	}
    89  	client = c
    90  	exitCode := m.Run()
    91  	cleanup()
    92  	os.Exit(exitCode)
    93  }
    94  
    95  func TestIsNameValid(t *testing.T) {
    96  	tests := []struct {
    97  		name string
    98  		want bool
    99  	}{
   100  		{name: "quotas/global/read/config", want: true},
   101  		{name: "quotas/global/write/config", want: true},
   102  		{name: "quotas/trees/12356/read/config", want: true},
   103  		{name: "quotas/users/llama/write/config", want: true},
   104  
   105  		{name: "bad/quota/name"},
   106  		{name: "badprefix/quotas/global/read/config"},
   107  		{name: "quotas/global/read/config/badsuffix"},
   108  		{name: "quotas/bad/read/config"},
   109  		{name: "quotas/global/bad/config"},
   110  		{name: "quotas/trees/bad/read/config"},
   111  		{name: "quotas/trees/11111111111111111111/read/config"}, // ID > MaxInt64
   112  	}
   113  	for _, test := range tests {
   114  		if got := IsNameValid(test.name); got != test.want {
   115  			t.Errorf("IsNameValid(%q) = %v, want = %v", test.name, got, test.want)
   116  		}
   117  	}
   118  }
   119  
// TestQuotaStorage_UpdateConfigs applies a sequence of config updates and,
// after each one, checks three things: the configs returned by
// UpdateConfigs(), the configs read back through Configs(), and the token
// counts reported by peekAndDiff.
func TestQuotaStorage_UpdateConfigs(t *testing.T) {
	defer setupTimeSource(fixedTimeSource)()

	empty := &storagepb.Configs{}

	// cfgs2: the baseline without global/read and with modified ceilings.
	cfgs2 := deepCopy(cfgs)
	cfgs2.Configs = cfgs2.Configs[1:]  // Remove global/read
	cfgs2.Configs[0].MaxTokens = 50    // decrease global/write
	cfgs2.Configs[1].MaxTokens = 10000 // increase user/read

	// cfgs3: the baseline plus a new per-tree write quota.
	treeWriteName := "quotas/trees/12345/write/config"
	cfgs3 := deepCopy(cfgs)
	cfgs3.Configs = append(cfgs3.Configs, &storagepb.Config{
		Name:      treeWriteName,
		State:     storagepb.Config_ENABLED,
		MaxTokens: 200,
		ReplenishmentStrategy: &storagepb.Config_SequencingBased{
			SequencingBased: &storagepb.SequencingBasedStrategy{},
		},
	})

	// Note: the cases are incremental, not isolated. Each case starts from the
	// state left behind by the previous one, especially when reset is false,
	// so their order must not be changed.
	tests := []struct {
		desc       string
		reset      bool
		wantCfgs   *storagepb.Configs
		wantTokens map[string]int64
	}{
		{
			desc:     "empty",
			reset:    true,
			wantCfgs: empty,
		},
		{
			desc:     "cfgs",
			wantCfgs: cfgs,
			wantTokens: map[string]int64{
				globalRead.Name:  quotaMaxTokens, // disabled
				globalWrite.Name: 100,
				userRead.Name:    1000,
			},
		},
		{
			desc:     "cfgs2",
			wantCfgs: cfgs2,
			wantTokens: map[string]int64{
				globalWrite.Name: 50,   // correctly decreased
				userRead.Name:    1000, // unaltered
			},
		},
		{
			desc:     "cfgs3",
			wantCfgs: cfgs3,
			wantTokens: map[string]int64{
				globalWrite.Name: 50,   // unaltered due to reset = false
				userRead.Name:    1000, // unaltered
				treeWriteName:    200,  // new
			},
		},
		{
			desc:     "cfgs3-pt2",
			reset:    true,
			wantCfgs: cfgs3,
			wantTokens: map[string]int64{
				globalWrite.Name: 100, // correctly reset
				userRead.Name:    1000,
				treeWriteName:    200,
			},
		},
		{
			desc:     "cfgs-pt2",
			wantCfgs: cfgs,
			wantTokens: map[string]int64{
				globalWrite.Name: 100,
				userRead.Name:    1000,
				treeWriteName:    quotaMaxTokens, // deleted / infinite
			},
		},
	}

	ctx := context.Background()
	qs := &QuotaStorage{Client: client}
	for _, test := range tests {
		// NOTE: this local cfgs shadows the package-level cfgs for the rest of
		// the iteration; it holds the configs returned by UpdateConfigs().
		cfgs, err := qs.UpdateConfigs(ctx, test.reset, updater(test.wantCfgs))
		if err != nil {
			t.Errorf("%v: UpdateConfigs() returned err = %v", test.desc, err)
			continue
		}
		if got, want := cfgs, test.wantCfgs; !proto.Equal(got, want) {
			diff := pretty.Compare(got, want)
			t.Errorf("%v: post-UpdateConfigs() diff (-got +want)\n%v", test.desc, diff)
		}

		// The configs read back from storage must match what UpdateConfigs()
		// returned.
		stored, err := qs.Configs(ctx)
		if err != nil {
			t.Errorf("%v:Configs() returned err = %v", test.desc, err)
			continue
		}
		if got, want := stored, cfgs; !proto.Equal(got, want) {
			diff := pretty.Compare(got, want)
			t.Errorf("%v: post-Configs() diff (-got +want)\n%v", test.desc, diff)
		}

		if err := peekAndDiff(ctx, qs, test.wantTokens); err != nil {
			t.Errorf("%v: %v", test.desc, err)
		}
	}
}
   229  
   230  func TestQuotaStorage_UpdateConfigsErrors(t *testing.T) {
   231  	globalWriteCfgs := &storagepb.Configs{Configs: []*storagepb.Config{globalWrite}}
   232  
   233  	emptyName := deepCopy(globalWriteCfgs)
   234  	emptyName.Configs[0].Name = ""
   235  
   236  	invalidName1 := deepCopy(globalWriteCfgs)
   237  	invalidName1.Configs[0].Name = "invalid"
   238  
   239  	invalidName2 := deepCopy(globalWriteCfgs)
   240  	invalidName2.Configs[0].Name = "quotas/tree/1234/write" // should be "trees", plural
   241  
   242  	unknownState := deepCopy(globalWriteCfgs)
   243  	unknownState.Configs[0].State = storagepb.Config_UNKNOWN_CONFIG_STATE
   244  
   245  	zeroMaxTokens := deepCopy(globalWriteCfgs)
   246  	zeroMaxTokens.Configs[0].MaxTokens = 0
   247  
   248  	invalidMaxTokens := deepCopy(globalWriteCfgs)
   249  	invalidMaxTokens.Configs[0].MaxTokens = -1
   250  
   251  	noReplenishmentStrategy := deepCopy(globalWriteCfgs)
   252  	noReplenishmentStrategy.Configs[0].ReplenishmentStrategy = nil
   253  
   254  	zeroTimeBasedTokens := deepCopy(globalWriteCfgs)
   255  	zeroTimeBasedTokens.Configs[0].ReplenishmentStrategy = &storagepb.Config_TimeBased{
   256  		TimeBased: &storagepb.TimeBasedStrategy{
   257  			TokensToReplenish:        0,
   258  			ReplenishIntervalSeconds: 10,
   259  		},
   260  	}
   261  
   262  	invalidTimeBasedTokens := deepCopy(globalWriteCfgs)
   263  	invalidTimeBasedTokens.Configs[0].ReplenishmentStrategy = &storagepb.Config_TimeBased{
   264  		TimeBased: &storagepb.TimeBasedStrategy{
   265  			TokensToReplenish:        -1,
   266  			ReplenishIntervalSeconds: 10,
   267  		},
   268  	}
   269  
   270  	zeroReplenishInterval := deepCopy(globalWriteCfgs)
   271  	zeroReplenishInterval.Configs[0].ReplenishmentStrategy = &storagepb.Config_TimeBased{
   272  		TimeBased: &storagepb.TimeBasedStrategy{
   273  			TokensToReplenish:        1,
   274  			ReplenishIntervalSeconds: 0,
   275  		},
   276  	}
   277  
   278  	invalidReplenishInterval := deepCopy(globalWriteCfgs)
   279  	invalidReplenishInterval.Configs[0].ReplenishmentStrategy = &storagepb.Config_TimeBased{
   280  		TimeBased: &storagepb.TimeBasedStrategy{
   281  			TokensToReplenish:        1,
   282  			ReplenishIntervalSeconds: -1,
   283  		},
   284  	}
   285  
   286  	duplicateNames := &storagepb.Configs{Configs: []*storagepb.Config{globalRead, globalWrite, globalWrite}}
   287  
   288  	sequencingBasedStrategy := &storagepb.Config_SequencingBased{SequencingBased: &storagepb.SequencingBasedStrategy{}}
   289  	sequencingBasedUserQuota := &storagepb.Configs{
   290  		Configs: []*storagepb.Config{
   291  			{
   292  				Name:                  userRead.Name,
   293  				State:                 userRead.State,
   294  				MaxTokens:             userRead.MaxTokens,
   295  				ReplenishmentStrategy: sequencingBasedStrategy,
   296  			},
   297  		},
   298  	}
   299  
   300  	sequencingBasedReadQuota1 := deepCopy(globalWriteCfgs)
   301  	sequencingBasedReadQuota1.Configs[0].Name = globalRead.Name
   302  	sequencingBasedReadQuota1.Configs[0].ReplenishmentStrategy = sequencingBasedStrategy
   303  
   304  	sequencingBasedReadQuota2 := deepCopy(globalWriteCfgs)
   305  	sequencingBasedReadQuota2.Configs[0].Name = "quotas/trees/1234/read/config"
   306  	sequencingBasedReadQuota2.Configs[0].ReplenishmentStrategy = sequencingBasedStrategy
   307  
   308  	tests := []struct {
   309  		desc    string
   310  		update  func(*storagepb.Configs)
   311  		wantErr string
   312  	}{
   313  		{desc: "nil", wantErr: "function required"},
   314  		{
   315  			desc:    "emptyName",
   316  			update:  updater(emptyName),
   317  			wantErr: "name is required",
   318  		},
   319  		{
   320  			desc:    "invalidName1",
   321  			update:  updater(invalidName1),
   322  			wantErr: "name malformed",
   323  		},
   324  		{
   325  			desc:    "invalidName2",
   326  			update:  updater(invalidName2),
   327  			wantErr: "name malformed",
   328  		},
   329  		{
   330  			desc:    "unknownState",
   331  			update:  updater(unknownState),
   332  			wantErr: "state invalid",
   333  		},
   334  		{
   335  			desc:    "zeroMaxTokens",
   336  			update:  updater(zeroMaxTokens),
   337  			wantErr: "max tokens must be > 0",
   338  		},
   339  		{
   340  			desc:    "invalidMaxTokens",
   341  			update:  updater(invalidMaxTokens),
   342  			wantErr: "max tokens must be > 0",
   343  		},
   344  		{
   345  			desc:    "noReplenishmentStrategy",
   346  			update:  updater(noReplenishmentStrategy),
   347  			wantErr: "unsupported replenishment strategy",
   348  		},
   349  		{
   350  			desc:    "zeroTimeBasedTokens",
   351  			update:  updater(zeroTimeBasedTokens),
   352  			wantErr: "time based tokens must be > 0",
   353  		},
   354  		{
   355  			desc:    "invalidTimeBasedTokens",
   356  			update:  updater(invalidTimeBasedTokens),
   357  			wantErr: "time based tokens must be > 0",
   358  		},
   359  		{
   360  			desc:    "zeroReplenishInterval",
   361  			update:  updater(zeroReplenishInterval),
   362  			wantErr: "replenish interval must be > 0",
   363  		},
   364  		{
   365  			desc:    "invalidReplenishInterval",
   366  			update:  updater(invalidReplenishInterval),
   367  			wantErr: "replenish interval must be > 0",
   368  		},
   369  		{
   370  			desc:    "duplicateNames",
   371  			update:  updater(duplicateNames),
   372  			wantErr: "duplicate config name",
   373  		},
   374  		{
   375  			desc:    "sequencingBasedUserQuota",
   376  			update:  updater(sequencingBasedUserQuota),
   377  			wantErr: "cannot use sequencing-based replenishment",
   378  		},
   379  		{
   380  			desc:    "sequencingBasedReadQuota1",
   381  			update:  updater(sequencingBasedReadQuota1),
   382  			wantErr: "cannot use sequencing-based replenishment",
   383  		},
   384  		{
   385  			desc:    "sequencingBasedReadQuota2",
   386  			update:  updater(sequencingBasedReadQuota2),
   387  			wantErr: "cannot use sequencing-based replenishment",
   388  		},
   389  	}
   390  
   391  	ctx := context.Background()
   392  	qs := &QuotaStorage{Client: client}
   393  
   394  	want := &storagepb.Configs{} // default cfgs is empty
   395  	if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(want)); err != nil {
   396  		t.Fatalf("UpdateConfigs() returned err = %v", err)
   397  	}
   398  
   399  	for _, test := range tests {
   400  		if _, err := qs.UpdateConfigs(ctx, false /* reset */, test.update); !strings.Contains(err.Error(), test.wantErr) {
   401  			// Fatal because the config has been changed, which will break all following tests.
   402  			t.Fatalf("%v: UpdateConfigs() returned err = %v, want substring %q", test.desc, err, test.wantErr)
   403  		}
   404  
   405  		stored, err := qs.Configs(ctx)
   406  		if err != nil {
   407  			t.Errorf("%v:Configs() returned err = %v", test.desc, err)
   408  			continue
   409  		}
   410  		if got := stored; !proto.Equal(got, want) {
   411  			diff := pretty.Compare(got, want)
   412  			t.Fatalf("%v: post-Configs() diff (-got +want)\n%v", test.desc, diff)
   413  		}
   414  	}
   415  }
   416  
   417  func TestQuotaStorage_DeletedConfig(t *testing.T) {
   418  	defer setupTimeSource(fixedTimeSource)()
   419  
   420  	ctx := context.Background()
   421  	qs := &QuotaStorage{Client: client}
   422  
   423  	cfgs := deepCopy(cfgs)
   424  	cfgs.Configs = cfgs.Configs[1:2] // Only global/write
   425  	globalWrite := cfgs.Configs[0]
   426  	if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil {
   427  		t.Fatalf("UpdateConfigs() returned err = %v", err)
   428  	}
   429  
   430  	// Normal quota behavior
   431  	names := []string{globalWrite.Name}
   432  	_ = qs.Get(ctx, names, 100)
   433  	if err := peekAndDiff(ctx, qs, map[string]int64{globalWrite.Name: globalWrite.MaxTokens - 100}); err != nil {
   434  		t.Fatalf("peekAndDiff returned err = %v", err)
   435  	}
   436  
   437  	// Deleted: considered infinite
   438  	cfgs = &storagepb.Configs{}
   439  	if _, err := qs.UpdateConfigs(ctx, false /* reset */, updater(cfgs)); err != nil {
   440  		t.Fatalf("UpdateConfigs() returned err = %v", err)
   441  	}
   442  	if err := peekAndDiff(ctx, qs, map[string]int64{globalWrite.Name: quotaMaxTokens}); err != nil {
   443  		t.Fatalf("peekAndDiff returned err = %v", err)
   444  	}
   445  
   446  	// Restored: must behave as new (ie, doesn't "revive" the old token count)
   447  	cfgs = &storagepb.Configs{Configs: []*storagepb.Config{globalWrite}}
   448  	if _, err := qs.UpdateConfigs(ctx, false /* reset */, updater(cfgs)); err != nil {
   449  		t.Fatalf("UpdateConfigs() returned err = %v", err)
   450  	}
   451  	if err := peekAndDiff(ctx, qs, map[string]int64{globalWrite.Name: globalWrite.MaxTokens}); err != nil {
   452  		t.Fatalf("peekAndDiff returned err = %v", err)
   453  	}
   454  }
   455  
   456  func TestQuotaStorage_DisabledConfig(t *testing.T) {
   457  	defer setupTimeSource(fixedTimeSource)()
   458  
   459  	ctx := context.Background()
   460  	qs := &QuotaStorage{Client: client}
   461  
   462  	cfgs := deepCopy(cfgs)
   463  	cfgs.Configs = cfgs.Configs[0:1] // Only global/read
   464  	globalRead := cfgs.Configs[0]
   465  	globalRead.State = storagepb.Config_ENABLED
   466  	globalRead.MaxTokens = 1000
   467  	if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil {
   468  		t.Fatalf("UpdateConfigs() returned err = %v", err)
   469  	}
   470  
   471  	// Normal quota behavior
   472  	names := []string{globalRead.Name}
   473  	_ = qs.Get(ctx, names, 100)
   474  	if err := peekAndDiff(ctx, qs, map[string]int64{globalRead.Name: globalRead.MaxTokens - 100}); err != nil {
   475  		t.Fatalf("peekAndDiff returned err = %v", err)
   476  	}
   477  
   478  	// Disabled: cfg still exists, but is considered infinite
   479  	globalRead.State = storagepb.Config_DISABLED
   480  	if _, err := qs.UpdateConfigs(ctx, false /* reset */, updater(cfgs)); err != nil {
   481  		t.Fatalf("UpdateConfigs() returned err = %v", err)
   482  	}
   483  	if err := peekAndDiff(ctx, qs, map[string]int64{globalRead.Name: quotaMaxTokens}); err != nil {
   484  		t.Fatalf("peekAndDiff returned err = %v", err)
   485  	}
   486  
   487  	// Enabled: tokens restored to ceiling, even though reset = false
   488  	globalRead.State = storagepb.Config_ENABLED
   489  	if _, err := qs.UpdateConfigs(ctx, false /* reset */, updater(cfgs)); err != nil {
   490  		t.Fatalf("UpdateConfigs() returned err = %v", err)
   491  	}
   492  	if err := peekAndDiff(ctx, qs, map[string]int64{globalRead.Name: globalRead.MaxTokens}); err != nil {
   493  		t.Fatalf("peekAndDiff returned err = %v", err)
   494  	}
   495  }
   496  
   497  func TestQuotaStorage_Get(t *testing.T) {
   498  	fakeTime := util.NewFakeTimeSource(time.Now())
   499  	setupTimeSource(fakeTime)
   500  
   501  	tests := []struct {
   502  		desc                      string
   503  		names                     []string
   504  		tokens                    int64
   505  		nowIncrement              time.Duration
   506  		initialTokens, wantTokens map[string]int64
   507  	}{
   508  		{
   509  			desc:   "success",
   510  			names:  []string{globalRead.Name, globalWrite.Name, userRead.Name},
   511  			tokens: 5,
   512  			wantTokens: map[string]int64{
   513  				globalRead.Name:  quotaMaxTokens, // disabled
   514  				globalWrite.Name: globalWrite.MaxTokens - 5,
   515  				userRead.Name:    userRead.MaxTokens - 5,
   516  			},
   517  		},
   518  		{
   519  			desc:   "globalOnly",
   520  			names:  []string{globalWrite.Name},
   521  			tokens: 7,
   522  			wantTokens: map[string]int64{
   523  				globalWrite.Name: globalWrite.MaxTokens - 7,
   524  				userRead.Name:    userRead.MaxTokens,
   525  			},
   526  		},
   527  		{
   528  			desc:   "userOnly",
   529  			names:  []string{userRead.Name},
   530  			tokens: 7,
   531  			wantTokens: map[string]int64{
   532  				globalWrite.Name: globalWrite.MaxTokens,
   533  				userRead.Name:    userRead.MaxTokens - 7,
   534  			},
   535  		},
   536  		{
   537  			desc:   "zeroTokens",
   538  			names:  []string{globalWrite.Name, userRead.Name},
   539  			tokens: 0,
   540  			wantTokens: map[string]int64{
   541  				globalWrite.Name: globalWrite.MaxTokens,
   542  				userRead.Name:    userRead.MaxTokens,
   543  			},
   544  		},
   545  		{
   546  			desc:         "successWithReplenishment",
   547  			names:        []string{globalWrite.Name, userRead.Name},
   548  			tokens:       5,
   549  			nowIncrement: time.Duration(userRead.GetTimeBased().ReplenishIntervalSeconds) * time.Second,
   550  			wantTokens: map[string]int64{
   551  				globalWrite.Name: globalWrite.MaxTokens - 5,
   552  				userRead.Name:    userRead.MaxTokens - 5, // Replenished then deduced
   553  			},
   554  		},
   555  		{
   556  			desc:         "successDueToReplenishment",
   557  			names:        []string{globalWrite.Name, userRead.Name},
   558  			tokens:       1,
   559  			nowIncrement: time.Duration(userRead.GetTimeBased().ReplenishIntervalSeconds) * time.Second,
   560  			initialTokens: map[string]int64{
   561  				userRead.Name: 0,
   562  			},
   563  			wantTokens: map[string]int64{
   564  				globalWrite.Name: globalWrite.MaxTokens - 1,
   565  				userRead.Name:    userRead.GetTimeBased().TokensToReplenish - 1,
   566  			},
   567  		},
   568  	}
   569  
   570  	ctx := context.Background()
   571  	qs := &QuotaStorage{Client: client}
   572  	for _, test := range tests {
   573  		if err := setupTokens(ctx, qs, cfgs, test.initialTokens); err != nil {
   574  			t.Errorf("%v: setupTokens() returned err = %v", test.desc, err)
   575  			continue
   576  		}
   577  
   578  		fakeTime.Set(fakeTime.Now().Add(test.nowIncrement))
   579  		if err := qs.Get(ctx, test.names, test.tokens); err != nil {
   580  			t.Errorf("%v: Get() returned err = %v", test.desc, err)
   581  			continue
   582  		}
   583  
   584  		if err := peekAndDiff(ctx, qs, test.wantTokens); err != nil {
   585  			t.Errorf("%v: %v", test.desc, err)
   586  		}
   587  	}
   588  }
   589  
   590  func TestQuotaStorage_GetErrors(t *testing.T) {
   591  	tests := []struct {
   592  		desc   string
   593  		names  []string
   594  		tokens int64
   595  	}{
   596  		{
   597  			desc:   "invalidTokens",
   598  			names:  []string{globalWrite.Name, userRead.Name},
   599  			tokens: -1,
   600  		},
   601  		{
   602  			desc:   "insufficientTokens",
   603  			names:  []string{globalWrite.Name, userRead.Name},
   604  			tokens: globalWrite.MaxTokens + 10,
   605  		},
   606  	}
   607  
   608  	ctx := context.Background()
   609  	qs := &QuotaStorage{Client: client}
   610  	if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil {
   611  		t.Fatalf("UpdateConfigs() returned err = %v", err)
   612  	}
   613  
   614  	for _, test := range tests {
   615  		if err := qs.Get(ctx, test.names, test.tokens); err == nil {
   616  			t.Errorf("%v: Get() returned err = nil, want non-nil", test.desc)
   617  		}
   618  	}
   619  }
   620  
   621  func TestQuotaStorage_Peek(t *testing.T) {
   622  	fakeTime := util.NewFakeTimeSource(time.Now())
   623  	defer setupTimeSource(fakeTime)()
   624  
   625  	tests := []struct {
   626  		desc                      string
   627  		names                     []string
   628  		nowIncrement              time.Duration
   629  		initialTokens, wantTokens map[string]int64
   630  	}{
   631  		{
   632  			desc:  "success",
   633  			names: []string{globalRead.Name, globalWrite.Name, userRead.Name, "quotas/users/llama/write/config"},
   634  			wantTokens: map[string]int64{
   635  				globalRead.Name:                   quotaMaxTokens, // disabled
   636  				globalWrite.Name:                  globalWrite.MaxTokens,
   637  				userRead.Name:                     userRead.MaxTokens,
   638  				"quotas/users/llama/write/config": quotaMaxTokens, // unknown
   639  			},
   640  		},
   641  		{
   642  			desc:         "timeBasedReplenish",
   643  			names:        []string{globalWrite.Name, userRead.Name},
   644  			nowIncrement: time.Duration(userRead.GetTimeBased().ReplenishIntervalSeconds) * time.Second,
   645  			initialTokens: map[string]int64{
   646  				globalWrite.Name: 10,
   647  				userRead.Name:    10,
   648  			},
   649  			wantTokens: map[string]int64{
   650  				globalWrite.Name: 10,
   651  				userRead.Name:    10 + userRead.GetTimeBased().TokensToReplenish,
   652  			},
   653  		},
   654  	}
   655  
   656  	ctx := context.Background()
   657  	qs := &QuotaStorage{Client: client}
   658  	for _, test := range tests {
   659  		if err := setupTokens(ctx, qs, cfgs, test.initialTokens); err != nil {
   660  			t.Errorf("%v: setupTokens() returned err = %v", test.desc, err)
   661  			continue
   662  		}
   663  
   664  		fakeTime.Set(fakeTime.Now().Add(test.nowIncrement))
   665  		if err := peekAndDiff(ctx, qs, test.wantTokens); err != nil {
   666  			t.Errorf("%v: %v", test.desc, err)
   667  		}
   668  	}
   669  }
   670  
   671  func TestQuotaStorage_Put(t *testing.T) {
   672  	fakeTime := util.NewFakeTimeSource(time.Now())
   673  	defer setupTimeSource(fakeTime)()
   674  
   675  	tests := []struct {
   676  		desc                      string
   677  		names                     []string
   678  		tokens                    int64
   679  		nowIncrement              time.Duration
   680  		initialTokens, wantTokens map[string]int64
   681  	}{
   682  		{
   683  			desc:   "zero",
   684  			names:  []string{globalWrite.Name, userRead.Name},
   685  			tokens: 0,
   686  			wantTokens: map[string]int64{
   687  				globalWrite.Name: globalWrite.MaxTokens,
   688  				userRead.Name:    userRead.MaxTokens,
   689  			},
   690  		},
   691  		{
   692  			desc:   "success",
   693  			names:  []string{globalRead.Name, globalWrite.Name, userRead.Name},
   694  			tokens: 10,
   695  			initialTokens: map[string]int64{
   696  				globalWrite.Name: 10,
   697  				userRead.Name:    10,
   698  			},
   699  			wantTokens: map[string]int64{
   700  				globalRead.Name:  quotaMaxTokens, // disabled
   701  				globalWrite.Name: 20,
   702  				userRead.Name:    10, // Time-based quotas don't change on Put()
   703  			},
   704  		},
   705  		{
   706  			desc:   "fullQuota",
   707  			names:  []string{globalWrite.Name, userRead.Name},
   708  			tokens: 10,
   709  			wantTokens: map[string]int64{
   710  				globalWrite.Name: globalWrite.MaxTokens,
   711  				userRead.Name:    userRead.MaxTokens,
   712  			},
   713  		},
   714  		{
   715  			desc:         "replenishToFull",
   716  			names:        []string{userRead.Name},
   717  			tokens:       0,
   718  			nowIncrement: time.Duration(userRead.GetTimeBased().ReplenishIntervalSeconds) * time.Second,
   719  			initialTokens: map[string]int64{
   720  				userRead.Name: userRead.MaxTokens - 1,
   721  			},
   722  			wantTokens: map[string]int64{
   723  				userRead.Name: userRead.MaxTokens,
   724  			},
   725  		},
   726  		{
   727  			desc:         "partialReplenish",
   728  			names:        []string{userRead.Name},
   729  			tokens:       100,
   730  			nowIncrement: time.Duration(userRead.GetTimeBased().ReplenishIntervalSeconds) * time.Second,
   731  			initialTokens: map[string]int64{
   732  				userRead.Name: 0,
   733  			},
   734  			wantTokens: map[string]int64{
   735  				userRead.Name: userRead.GetTimeBased().TokensToReplenish,
   736  			},
   737  		},
   738  	}
   739  
   740  	ctx := context.Background()
   741  	qs := &QuotaStorage{Client: client}
   742  	for _, test := range tests {
   743  		if err := setupTokens(ctx, qs, cfgs, test.initialTokens); err != nil {
   744  			t.Errorf("%v: setupTokens() returned err = %v", test.desc, err)
   745  			continue
   746  		}
   747  
   748  		if err := qs.Put(ctx, test.names, test.tokens); err != nil {
   749  			t.Errorf("%v: Put() returned err = %v", test.desc, err)
   750  		}
   751  
   752  		fakeTime.Set(fakeTime.Now().Add(test.nowIncrement))
   753  		if err := peekAndDiff(ctx, qs, test.wantTokens); err != nil {
   754  			t.Errorf("%v: %v", test.desc, err)
   755  		}
   756  	}
   757  }
   758  
   759  func TestQuotaStorage_PutErrors(t *testing.T) {
   760  	tests := []struct {
   761  		desc   string
   762  		names  []string
   763  		tokens int64
   764  	}{
   765  		{desc: "invalidTokens", names: []string{globalWrite.Name}, tokens: -1},
   766  	}
   767  
   768  	ctx := context.Background()
   769  	qs := &QuotaStorage{Client: client}
   770  	if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil {
   771  		t.Fatalf("UpdateConfigs() returned err = %v", err)
   772  	}
   773  
   774  	for _, test := range tests {
   775  		if err := qs.Put(ctx, test.names, test.tokens); err == nil {
   776  			t.Errorf("%v: Put() returned err = nil, want non-nil", test.desc)
   777  		}
   778  	}
   779  }
   780  
   781  func TestQuotaStorage_Reset(t *testing.T) {
   782  	defer setupTimeSource(fixedTimeSource)()
   783  
   784  	tests := []struct {
   785  		desc                      string
   786  		names                     []string
   787  		initialTokens, wantTokens map[string]int64
   788  	}{
   789  		{
   790  			desc:  "success",
   791  			names: []string{globalRead.Name, globalWrite.Name, userRead.Name},
   792  			initialTokens: map[string]int64{
   793  				globalWrite.Name: 10,
   794  				userRead.Name:    10,
   795  			},
   796  			wantTokens: map[string]int64{
   797  				globalRead.Name:  quotaMaxTokens, // disabled
   798  				globalWrite.Name: globalWrite.MaxTokens,
   799  				userRead.Name:    userRead.MaxTokens,
   800  			},
   801  		},
   802  		{
   803  			desc:  "globalWrite",
   804  			names: []string{globalWrite.Name},
   805  			initialTokens: map[string]int64{
   806  				globalWrite.Name: 10,
   807  			},
   808  			wantTokens: map[string]int64{
   809  				globalWrite.Name: globalWrite.MaxTokens,
   810  			},
   811  		},
   812  		{
   813  			desc:  "userRead",
   814  			names: []string{userRead.Name},
   815  			initialTokens: map[string]int64{
   816  				userRead.Name: 10,
   817  			},
   818  			wantTokens: map[string]int64{
   819  				userRead.Name: userRead.MaxTokens,
   820  			},
   821  		},
   822  		{
   823  			desc:  "fullQuotas",
   824  			names: []string{globalWrite.Name, userRead.Name},
   825  			wantTokens: map[string]int64{
   826  				globalWrite.Name: globalWrite.MaxTokens,
   827  				userRead.Name:    userRead.MaxTokens,
   828  			},
   829  		},
   830  		{
   831  			desc:  "unknownQuota",
   832  			names: []string{"quotas/users/llama/write/config"},
   833  			wantTokens: map[string]int64{
   834  				"quotas/users/llama/write/config": quotaMaxTokens,
   835  			},
   836  		},
   837  	}
   838  
   839  	ctx := context.Background()
   840  	qs := &QuotaStorage{Client: client}
   841  	if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil {
   842  		t.Fatalf("UpdateConfigs() returned err = %v", err)
   843  	}
   844  
   845  	for _, test := range tests {
   846  		if err := setupTokens(ctx, qs, cfgs, test.initialTokens); err != nil {
   847  			t.Errorf("%v: setupTokens() returned err = %v", test.desc, err)
   848  			continue
   849  		}
   850  
   851  		if err := qs.Reset(ctx, test.names); err != nil {
   852  			t.Errorf("%v: Reset() returned err = %v", test.desc, err)
   853  		}
   854  
   855  		if err := peekAndDiff(ctx, qs, test.wantTokens); err != nil {
   856  			t.Errorf("%v: %v", test.desc, err)
   857  		}
   858  	}
   859  }
   860  
   861  func TestQuotaStorage_ValidateNames(t *testing.T) {
   862  	fns := []struct {
   863  		name string
   864  		run  func(context.Context, *QuotaStorage, []string) error
   865  	}{
   866  		{
   867  			name: "Get",
   868  			run: func(ctx context.Context, qs *QuotaStorage, names []string) error {
   869  				return qs.Get(ctx, names, 0)
   870  			},
   871  		},
   872  		{
   873  			name: "Peek",
   874  			run: func(ctx context.Context, qs *QuotaStorage, names []string) error {
   875  				_, err := qs.Peek(ctx, names)
   876  				return err
   877  			},
   878  		},
   879  		{
   880  			name: "Put",
   881  			run: func(ctx context.Context, qs *QuotaStorage, names []string) error {
   882  				return qs.Put(ctx, names, 0)
   883  			},
   884  		},
   885  		{
   886  			name: "Reset",
   887  			run: func(ctx context.Context, qs *QuotaStorage, names []string) error {
   888  				return qs.Reset(ctx, names)
   889  			},
   890  		},
   891  	}
   892  
   893  	tests := []struct {
   894  		names []string
   895  	}{
   896  		{names: []string{"bad/quota/name"}},
   897  		{names: []string{"quotas/bad/read/configs"}},
   898  		{names: []string{"quotas/global/read"}}, // missing "/configs"
   899  		{names: []string{"quotas/trees/1234/write"}},
   900  		{names: []string{"quotas/users/llama/write"}},
   901  		{names: []string{"quotas/tree/1234/read/configs"}},  // should be "trees"
   902  		{names: []string{"quotas/user/llama/read/configs"}}, // should be "users"
   903  		{names: []string{globalWrite.Name, "bad"}},
   904  	}
   905  
   906  	ctx := context.Background()
   907  	qs := &QuotaStorage{Client: client}
   908  	for _, test := range tests {
   909  		for _, fn := range fns {
   910  			if err := fn.run(ctx, qs, test.names); err == nil {
   911  				t.Errorf("%v(%v) returned err = nil, want non-nil", fn.name, test.names)
   912  			}
   913  		}
   914  	}
   915  }
   916  
   917  func peekAndDiff(ctx context.Context, qs *QuotaStorage, want map[string]int64) error {
   918  	got, err := qs.Peek(ctx, keys(want))
   919  	if err != nil {
   920  		return err
   921  	}
   922  	if diff := pretty.Compare(got, want); diff != "" {
   923  		return fmt.Errorf("post-Peek() diff (-got +want):\n%v", diff)
   924  	}
   925  	return nil
   926  }
   927  
   928  // setupTimeSource prepares timeSource for tests.
   929  // A cleanup function that restores timeSource to its initial value is returned and should be
   930  // defer-called.
   931  func setupTimeSource(ts util.TimeSource) func() {
   932  	prevTimeSource := timeSource
   933  	timeSource = ts
   934  	return func() { timeSource = prevTimeSource }
   935  }
   936  
   937  // setupTokens resets cfgs and gets tokens from each quota in order to make them match
   938  // initialTokens.
   939  func setupTokens(ctx context.Context, qs *QuotaStorage, cfgs *storagepb.Configs, initialTokens map[string]int64) error {
   940  	if _, err := qs.UpdateConfigs(ctx, true /* reset */, updater(cfgs)); err != nil {
   941  		return fmt.Errorf("UpdateConfigs() returned err = %v", err)
   942  	}
   943  	for name, wantTokens := range initialTokens {
   944  		names := []string{name}
   945  		tokens, err := qs.Peek(ctx, names)
   946  		if err != nil {
   947  			return fmt.Errorf("Peek() returned err = %v", err)
   948  		}
   949  		mod := tokens[name] - wantTokens
   950  		if err := qs.Get(ctx, names, mod); err != nil {
   951  			return fmt.Errorf("Get() returned err = %v", err)
   952  		}
   953  		if err := peekAndDiff(ctx, qs, map[string]int64{name: wantTokens}); err != nil {
   954  			return err
   955  		}
   956  	}
   957  	return nil
   958  }
   959  
   960  func deepCopy(c1 *storagepb.Configs) *storagepb.Configs {
   961  	c2 := &storagepb.Configs{
   962  		Configs: make([]*storagepb.Config, 0, len(c1.Configs)),
   963  	}
   964  	for _, cfg := range c1.Configs {
   965  		cp := *cfg
   966  		c2.Configs = append(c2.Configs, &cp)
   967  	}
   968  	return c2
   969  }
   970  
   971  func keys(m map[string]int64) []string {
   972  	keys := make([]string, 0, len(m))
   973  	for k := range m {
   974  		keys = append(keys, k)
   975  	}
   976  	return keys
   977  }
   978  
   979  func updater(cfgs *storagepb.Configs) func(*storagepb.Configs) {
   980  	return func(c *storagepb.Configs) {
   981  		*c = *cfgs
   982  	}
   983  }