github.com/letsencrypt/trillian@v1.1.2-0.20180615153820-ae375a99d36a/quota/etcd/storage/quota_storage.go (about)

     1  // Copyright 2017 Google Inc. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package storage contains storage classes for etcd-based quotas.
    16  package storage
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  	"regexp"
    22  	"strconv"
    23  	"strings"
    24  	"time"
    25  
    26  	"github.com/coreos/etcd/clientv3"
    27  	"github.com/coreos/etcd/clientv3/concurrency"
    28  	"github.com/golang/glog"
    29  	"github.com/golang/protobuf/proto"
    30  	"github.com/google/trillian/quota"
    31  	"github.com/google/trillian/quota/etcd/storagepb"
    32  	"github.com/google/trillian/util"
    33  	"google.golang.org/grpc/codes"
    34  	"google.golang.org/grpc/status"
    35  )
    36  
    37  const (
    38  	configsKey = "quotas/configs"
    39  )
    40  
    41  var (
    42  	timeSource util.TimeSource = &util.SystemTimeSource{}
    43  
    44  	globalPattern *regexp.Regexp
    45  	treesPattern  *regexp.Regexp
    46  	usersPattern  *regexp.Regexp
    47  )
    48  
    49  func init() {
    50  	var err error
    51  	globalPattern, err = regexp.Compile("^quotas/global/(read|write)/config$")
    52  	if err != nil {
    53  		glog.Fatalf("bad global pattern: %v", err)
    54  	}
    55  	treesPattern, err = regexp.Compile(`^quotas/trees/\d+/(read|write)/config$`)
    56  	if err != nil {
    57  		glog.Fatalf("bad trees pattern: %v", err)
    58  	}
    59  	usersPattern, err = regexp.Compile("^quotas/users/[^/]+/(read|write)/config$")
    60  	if err != nil {
    61  		glog.Fatalf("bad users pattern: %v", err)
    62  	}
    63  }
    64  
    65  // IsNameValid returns true if name is a valid quota name.
    66  func IsNameValid(name string) bool {
    67  	switch {
    68  	case globalPattern.MatchString(name):
    69  		return true
    70  	case usersPattern.MatchString(name):
    71  		return true
    72  	case treesPattern.MatchString(name):
    73  		// Tree ID must fit on an int64
    74  		id := strings.Split(name, "/")[2]
    75  		_, err := strconv.ParseInt(id, 10, 64)
    76  		return err == nil
    77  	}
    78  	return false
    79  }
    80  
    81  // QuotaStorage is the interface between the etcd-based quota implementations (quota.Manager and
    82  // RPCs) and etcd itself.
    83  type QuotaStorage struct {
    84  	Client *clientv3.Client
    85  }
    86  
    87  // UpdateConfigs creates or updates the supplied configs in etcd.
    88  // If no config exists, the current config is assumed to be an empty storagepb.Configs proto.
    89  // The update function allows for mask-based updates and ensures a single-transaction
    90  // read-modify-write operation.
    91  // If reset is true, all specified configs will be set to their max number of tokens. If false,
    92  // existing quotas won't be modified, unless the max number of tokens is lowered, in which case
    93  // the new ceiling is enforced.
    94  // Newly created quotas are always set to max tokens, regardless of the reset parameter.
    95  func (qs *QuotaStorage) UpdateConfigs(ctx context.Context, reset bool, update func(*storagepb.Configs)) (*storagepb.Configs, error) {
    96  	if update == nil {
    97  		return nil, status.Error(codes.Internal, "update function required")
    98  	}
    99  
   100  	var updated *storagepb.Configs
   101  	_, err := concurrency.NewSTMSerializable(ctx, qs.Client, func(s concurrency.STM) error {
   102  		previous, err := getConfigs(s)
   103  		if err != nil {
   104  			return err
   105  		}
   106  		// Take a deep copy of "previous". It's pointers all the way down, so it's easier to just
   107  		// unmarshal it again. STM has the key we just read and it should be exactly the same as
   108  		// previous...
   109  		updated, err = getConfigs(s)
   110  		if err != nil {
   111  			return err
   112  		}
   113  
   114  		// ... but let's sanity check that the configs match, just in case.
   115  		if !proto.Equal(previous, updated) {
   116  			return status.Error(codes.Internal, "verification failed: previous quota config != updated")
   117  		}
   118  
   119  		update(updated)
   120  		if err := validate(updated); err != nil {
   121  			return err
   122  		}
   123  
   124  		if !proto.Equal(previous, updated) {
   125  			pb, err := proto.Marshal(updated)
   126  			if err != nil {
   127  				return err
   128  			}
   129  			s.Put(configsKey, string(pb))
   130  		}
   131  
   132  		now := timeSource.Now()
   133  		for _, cfg := range updated.Configs {
   134  			// Make no distinction between enabled and disabled configs here. Get/Peek/Put are
   135  			// prepared to handle it, and recording the bucket as if it were enabled allows us to
   136  			// take advantage of the already-existing reset and lowering logic.
   137  			key := bucketKey(cfg)
   138  
   139  			var prev *storagepb.Config
   140  			for _, p := range previous.Configs {
   141  				if cfg.Name == p.Name {
   142  					prev = p
   143  					break
   144  				}
   145  			}
   146  
   147  			switch {
   148  			case prev == nil || prev.State == storagepb.Config_DISABLED || reset: // new bucket
   149  				bucket := &storagepb.Bucket{
   150  					Tokens: cfg.MaxTokens,
   151  					LastReplenishMillisSinceEpoch: now.UnixNano() / 1e6,
   152  				}
   153  				pb, err := proto.Marshal(bucket)
   154  				if err != nil {
   155  					return err
   156  				}
   157  				s.Put(key, string(pb))
   158  			case prev != nil && cfg.MaxTokens < prev.MaxTokens: // lowered bucket
   159  				// modBucket will coerce tokens to cfg.MaxTokens, if necessary
   160  				if _, err := modBucket(s, cfg, now, 0 /* add */); err != nil {
   161  					return err
   162  				}
   163  			}
   164  		}
   165  
   166  		return nil
   167  	})
   168  	return updated, err
   169  }
   170  
   171  func validate(cfgs *storagepb.Configs) error {
   172  	names := make(map[string]bool)
   173  	for i, cfg := range cfgs.Configs {
   174  		switch n := cfg.Name; {
   175  		case n == "":
   176  			return status.Errorf(codes.InvalidArgument, "config name is required (Configs[%v].Name is empty)", i)
   177  		case !IsNameValid(cfg.Name):
   178  			return status.Errorf(codes.InvalidArgument, "config name malformed (Configs[%v].Name = %q)", i, n)
   179  		}
   180  		if s := cfg.State; s == storagepb.Config_UNKNOWN_CONFIG_STATE {
   181  			return status.Errorf(codes.InvalidArgument, "config state invalid (Configs[%v].State = %s)", i, s)
   182  		}
   183  		if t := cfg.MaxTokens; t <= 0 {
   184  			return status.Errorf(codes.InvalidArgument, "config max tokens must be > 0 (Configs[%v].MaxTokens = %v)", i, t)
   185  		}
   186  		switch s := cfg.ReplenishmentStrategy.(type) {
   187  		case *storagepb.Config_SequencingBased:
   188  			if usersPattern.MatchString(cfg.Name) {
   189  				return status.Errorf(codes.InvalidArgument, "user quotas cannot use sequencing-based replenishment (Configs[%v].ReplenishmentStrategy)", i)
   190  			}
   191  			if strings.HasSuffix(cfg.Name, "/read/config") {
   192  				return status.Errorf(codes.InvalidArgument, "read quotas cannot use sequencing-based replenishment (Configs[%v].ReplenishmentStrategy)", i)
   193  			}
   194  		case *storagepb.Config_TimeBased:
   195  			if t := s.TimeBased.TokensToReplenish; t <= 0 {
   196  				return status.Errorf(codes.InvalidArgument, "time based tokens must be > 0 (Configs[%v].TimeBased.TokensToReplenish = %v)", i, t)
   197  			}
   198  			if r := s.TimeBased.ReplenishIntervalSeconds; r <= 0 {
   199  				return status.Errorf(codes.InvalidArgument, "time based replenish interval must be > 0 (Configs[%v].TimeBased.ReplenishIntervalSeconds = %v)", i, r)
   200  			}
   201  		default:
   202  			return status.Errorf(codes.InvalidArgument, "unsupported replenishment strategy (Configs[%v].ReplenishmentStrategy = %T)", i, s)
   203  		}
   204  		if names[cfg.Name] {
   205  			return status.Errorf(codes.InvalidArgument, "duplicate config name found at Configs[%v].Name", i)
   206  		}
   207  		names[cfg.Name] = true
   208  	}
   209  	return nil
   210  }
   211  
   212  // Configs returns the currently known quota configs.
   213  // If no config was explicitly created, then an empty storage.Configs proto is returned.
   214  func (qs *QuotaStorage) Configs(ctx context.Context) (*storagepb.Configs, error) {
   215  	var cfgs *storagepb.Configs
   216  	_, err := concurrency.NewSTMSerializable(ctx, qs.Client, func(s concurrency.STM) error {
   217  		var err error
   218  		cfgs, err = getConfigs(s)
   219  		return err
   220  
   221  	})
   222  	return cfgs, err
   223  }
   224  
   225  // Get acquires "tokens" tokens from the named quotas.
   226  // If one of the specified quotas doesn't have enough tokens, the entire operation fails. Unknown or
   227  // disabled quotas are considered infinite, therefore get requests will always succeed for them.
   228  func (qs *QuotaStorage) Get(ctx context.Context, names []string, tokens int64) error {
   229  	if tokens < 0 {
   230  		return fmt.Errorf("invalid number of tokens: %v", tokens)
   231  	}
   232  	return qs.mod(ctx, names, -tokens)
   233  }
   234  
   235  func (qs *QuotaStorage) mod(ctx context.Context, names []string, add int64) error {
   236  	now := timeSource.Now()
   237  	return qs.forNames(ctx, names, defaultMode, func(s concurrency.STM, name string, cfg *storagepb.Config) error {
   238  		_, err := modBucket(s, cfg, now, add)
   239  		return err
   240  	})
   241  }
   242  
   243  // Peek returns a map of quota name to tokens for the named quotas.
   244  // Unknown or disabled quotas are considered infinite and returned as having quota.MaxTokens tokens,
   245  // therefore all requested names are guaranteed to be in the resulting map
   246  func (qs *QuotaStorage) Peek(ctx context.Context, names []string) (map[string]int64, error) {
   247  	now := timeSource.Now()
   248  	tokens := make(map[string]int64)
   249  	err := qs.forNames(ctx, names, emitInfinite, func(s concurrency.STM, name string, cfg *storagepb.Config) error {
   250  		var err error
   251  		var t int64
   252  		if cfg == nil {
   253  			t = int64(quota.MaxTokens)
   254  		} else {
   255  			t, err = modBucket(s, cfg, now, 0 /* add */)
   256  		}
   257  		tokens[name] = t
   258  		return err
   259  	})
   260  	return tokens, err
   261  }
   262  
   263  // Put adds "tokens" tokens to the named quotas.
   264  // Time-based quotas cannot be replenished this way, therefore put requests for them are ignored.
   265  // Unknown or disabled quotas are considered infinite and also ignored.
   266  func (qs *QuotaStorage) Put(ctx context.Context, names []string, tokens int64) error {
   267  	if tokens < 0 {
   268  		return fmt.Errorf("invalid number of tokens: %v", tokens)
   269  	}
   270  	return qs.mod(ctx, names, tokens)
   271  }
   272  
   273  // Reset resets the named quotas to their maximum number of tokens.
   274  // Unknown or disabled quotas are considered infinite and ignored.
   275  func (qs *QuotaStorage) Reset(ctx context.Context, names []string) error {
   276  	now := timeSource.Now()
   277  	return qs.forNames(ctx, names, defaultMode, func(s concurrency.STM, name string, cfg *storagepb.Config) error {
   278  		bucket := &storagepb.Bucket{
   279  			Tokens: cfg.MaxTokens,
   280  			LastReplenishMillisSinceEpoch: now.UnixNano() / 1e6,
   281  		}
   282  		pb, err := proto.Marshal(bucket)
   283  		if err != nil {
   284  			return err
   285  		}
   286  		s.Put(bucketKey(cfg), string(pb))
   287  		return nil
   288  	})
   289  }
   290  
   291  // forNamesMode specifies how forNames handles disabled and infinite quotas.
   292  type forNamesMode int
   293  
   294  const (
   295  	// defaultMode emits only known, enabled configs.
   296  	defaultMode forNamesMode = iota
   297  
   298  	// emitInfinite emits all known, enabled configs and all infinite configs.
   299  	// Infinite configs are emitted with a nil cfg value.
   300  	// emitInfinite emits disabled configs with a nil cfg value as well.
   301  	emitInfinite
   302  )
   303  
   304  // forNames calls fn for all configs specified by names. Execution is performed in a single etcd
   305  // transaction.
   306  // By default, fn is only called for known, enabled configs. See forNamesMode for other behaviors.
   307  // Names are validated and de-duped automatically.
   308  func (qs *QuotaStorage) forNames(ctx context.Context, names []string, mode forNamesMode, fn func(concurrency.STM, string, *storagepb.Config) error) error {
   309  	for _, name := range names {
   310  		if !IsNameValid(name) {
   311  			return fmt.Errorf("invalid name: %q", name)
   312  		}
   313  	}
   314  
   315  	_, err := concurrency.NewSTMSerializable(ctx, qs.Client, func(s concurrency.STM) error {
   316  		cfgs, err := getConfigs(s)
   317  		if err != nil {
   318  			return err
   319  		}
   320  
   321  		seenNames := make(map[string]bool)
   322  		for _, name := range names {
   323  			if seenNames[name] {
   324  				continue
   325  			}
   326  			seenNames[name] = true
   327  
   328  			emitted := false
   329  			for _, cfg := range cfgs.Configs {
   330  				if cfg.Name == name {
   331  					if cfg.State == storagepb.Config_ENABLED {
   332  						if err := fn(s, name, cfg); err != nil {
   333  							return err
   334  						}
   335  						emitted = true
   336  					}
   337  					break
   338  				}
   339  			}
   340  			if !emitted && mode == emitInfinite {
   341  				if err := fn(s, name, nil); err != nil {
   342  					return err
   343  				}
   344  			}
   345  		}
   346  		return nil
   347  	})
   348  	return err
   349  }
   350  
   351  func getConfigs(s concurrency.STM) (*storagepb.Configs, error) {
   352  	// TODO(codingllama): Consider watching configs instead of re-reading
   353  	cfgs := &storagepb.Configs{}
   354  	val := s.Get(configsKey)
   355  	if val == "" {
   356  		// Empty value means no config was explicitly created yet.
   357  		// Use the default (empty) configs in this case.
   358  		return cfgs, nil
   359  	}
   360  	if err := proto.Unmarshal([]byte(val), cfgs); err != nil {
   361  		return nil, fmt.Errorf("error unmarshaling %v: %v", configsKey, err)
   362  	}
   363  	return cfgs, nil
   364  }
   365  
   366  // modBucket adds "add" tokens to the specified quota. Add may be negative or zero.
   367  // Time-based quotas that are due replenishment will be replenished before the add operation. Quotas
   368  // that are above ceiling (eg, due to lowered max tokens) will also be constrained to the
   369  // appropriate ceiling. As a consequence, calls with add = 0 are still useful for peeking and the
   370  // explained side-effects.
   371  // modBucket returns the current token count for cfg.
   372  func modBucket(s concurrency.STM, cfg *storagepb.Config, now time.Time, add int64) (int64, error) {
   373  	key := bucketKey(cfg)
   374  
   375  	val := s.Get(key)
   376  	prevBucket := storagepb.Bucket{}
   377  	if err := proto.Unmarshal([]byte(val), &prevBucket); err != nil {
   378  		return 0, fmt.Errorf("error unmarshaling %v: %v", key, err)
   379  	}
   380  	newBucket := prevBucket
   381  
   382  	if tb := cfg.GetTimeBased(); tb != nil {
   383  		if now.Unix() >= newBucket.LastReplenishMillisSinceEpoch/1e3+tb.ReplenishIntervalSeconds {
   384  			newBucket.Tokens += tb.TokensToReplenish
   385  			if newBucket.Tokens > cfg.MaxTokens {
   386  				newBucket.Tokens = cfg.MaxTokens
   387  			}
   388  			newBucket.LastReplenishMillisSinceEpoch = now.UnixNano() / 1e6
   389  		}
   390  		if add > 0 {
   391  			add = 0 // Do not replenish time-based quotas
   392  		}
   393  	}
   394  
   395  	newBucket.Tokens += add
   396  	if newBucket.Tokens < 0 {
   397  		return 0, fmt.Errorf("insufficient tokens on %v (%v vs %v)", key, prevBucket.Tokens, -add)
   398  	}
   399  	if newBucket.Tokens > cfg.MaxTokens {
   400  		newBucket.Tokens = cfg.MaxTokens
   401  	}
   402  
   403  	if !proto.Equal(&prevBucket, &newBucket) {
   404  		pb, err := proto.Marshal(&newBucket)
   405  		if err != nil {
   406  			return 0, err
   407  		}
   408  		s.Put(key, string(pb))
   409  	}
   410  
   411  	return newBucket.Tokens, nil
   412  }
   413  
   414  func bucketKey(cfg *storagepb.Config) string {
   415  	return fmt.Sprintf("%v/0", cfg.Name)
   416  }