github.com/letsencrypt/boulder@v0.20251208.0/ratelimits/limit.go (about)

     1  package ratelimits
     2  
import (
	"context"
	"encoding/csv"
	"errors"
	"fmt"
	"net/netip"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/prometheus/client_golang/prometheus"

	"github.com/letsencrypt/boulder/config"
	"github.com/letsencrypt/boulder/core"
	"github.com/letsencrypt/boulder/identifier"
	blog "github.com/letsencrypt/boulder/log"
	"github.com/letsencrypt/boulder/strictyaml"
)
    23  
    24  // errLimitDisabled indicates that the limit name specified is valid but is not
    25  // currently configured.
    26  var errLimitDisabled = errors.New("limit disabled")
    27  
// LimitConfig defines the exportable configuration for a rate limit or a rate
// limit override, without a `Limit`'s internal fields (name, comment, and the
// values derived by precompute).
//
// The zero value of this struct is invalid, because some of the fields must be
// greater than zero.
type LimitConfig struct {
	// Burst specifies the maximum number of requests that can be allowed at
	// once, i.e. the capacity of the bucket. It must be greater than zero.
	Burst int64

	// Count is the number of requests allowed per period. It must be greater
	// than zero.
	Count int64

	// Period is the duration of time in which the count (of requests) is
	// allowed. It must be greater than zero.
	Period config.Duration
}

// LimitConfigs maps a limit name string (as it appears as a key in the
// defaults YAML file decoded by loadDefaultsFromFile) to its configuration.
type LimitConfigs map[string]*LimitConfig
    48  
// Limit defines the configuration for a rate limit or a rate limit override.
//
// The zero value of this struct is invalid, because some of the fields must be
// greater than zero. It and several of its fields are exported to support admin
// tooling used during the migration from overrides.yaml to the overrides
// database table.
//
// The unexported emissionInterval and burstOffset fields are derived state:
// they are only meaningful after precompute has been called (parseDefaultLimits
// and limitRegistry.loadOverrides do this after the limit has been obtained).
type Limit struct {
	// Burst specifies the maximum number of requests that can be allowed at
	// once, i.e. the capacity of the bucket. It must be greater than zero.
	Burst int64

	// Count is the number of requests allowed per period. It must be greater
	// than zero.
	Count int64

	// Period is the duration of time in which the count (of requests) is
	// allowed. It must be greater than zero.
	Period config.Duration

	// Name is the name of the limit. It must be one of the Name enums defined
	// in this package.
	Name Name

	// Comment is an optional field that can be used to provide additional
	// context for an override. It is not used for default limits.
	Comment string

	// emissionInterval is the interval, in nanoseconds, at which tokens are
	// added to a bucket (period / count). This is also the steady-state rate at
	// which requests can be made without being denied even once the burst has
	// been exhausted. This is precomputed (by precompute) to avoid doing the
	// same calculation on every request.
	emissionInterval int64

	// burstOffset is the duration of time, in nanoseconds, it takes for a
	// bucket to go from empty to full (burst * (period / count)). This is
	// precomputed (by precompute) to avoid doing the same calculation on every
	// request.
	burstOffset int64

	// isOverride is true if the limit is an override. parseOverrideLimits sets
	// it; parseDefaultLimits leaves it false.
	isOverride bool
}
    91  
    92  // precompute calculates the emissionInterval and burstOffset for the limit.
    93  func (l *Limit) precompute() {
    94  	l.emissionInterval = l.Period.Nanoseconds() / l.Count
    95  	l.burstOffset = l.emissionInterval * l.Burst
    96  }
    97  
    98  func ValidateLimit(l *Limit) error {
    99  	if l.Burst <= 0 {
   100  		return fmt.Errorf("invalid burst '%d', must be > 0", l.Burst)
   101  	}
   102  	if l.Count <= 0 {
   103  		return fmt.Errorf("invalid count '%d', must be > 0", l.Count)
   104  	}
   105  	if l.Period.Duration <= 0 {
   106  		return fmt.Errorf("invalid period '%s', must be > 0", l.Period)
   107  	}
   108  	return nil
   109  }
   110  
// Limits maps a key to a *Limit. parseDefaultLimits keys default limits by
// Name.EnumString(); parseOverrideLimits keys overrides by
// joinWithColon(Name.EnumString(), id).
type Limits map[string]*Limit
   112  
   113  // loadDefaultsFromFile unmarshals the defaults YAML file at path into a map of
   114  // limits.
   115  func loadDefaultsFromFile(path string) (LimitConfigs, error) {
   116  	lm := make(LimitConfigs)
   117  	data, err := os.ReadFile(path)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	err = strictyaml.Unmarshal(data, &lm)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  	return lm, nil
   126  }
   127  
// overrideYAML is the value of one entry in the overrides YAML file: the limit
// parameters (LimitConfig, inlined so Burst/Count/Period sit alongside "ids")
// and the list of ids those parameters apply to.
type overrideYAML struct {
	LimitConfig `yaml:",inline"`
	// Ids is a list of ids that this override applies to.
	Ids []struct {
		// Id is the bucket id as written in the file; hydrateOverrideLimit
		// may rewrite it (e.g. to a CIDR prefix or a hash) for in-memory use.
		Id string `yaml:"id"`
		// Comment is an optional field that can be used to provide additional
		// context for the override.
		Comment string `yaml:"comment,omitempty"`
	} `yaml:"ids"`
}

// overridesYAML is the top-level shape of the overrides YAML file: a list of
// maps from a limit name string to an overrideYAML. Each map is intended to
// hold a single key, although parseOverrideLimits processes every key present.
type overridesYAML []map[string]overrideYAML
   140  
   141  // loadOverridesFromFile unmarshals the YAML file at path into a map of
   142  // overrides.
   143  func loadOverridesFromFile(path string) (overridesYAML, error) {
   144  	ov := overridesYAML{}
   145  	data, err := os.ReadFile(path)
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  	err = strictyaml.Unmarshal(data, &ov)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  	return ov, nil
   154  }
   155  
   156  // parseOverrideNameId is broken out for ease of testing.
   157  func parseOverrideNameId(key string) (Name, string, error) {
   158  	if !strings.Contains(key, ":") {
   159  		// Avoids a potential panic in strings.SplitN below.
   160  		return Unknown, "", fmt.Errorf("invalid override %q, must be formatted 'name:id'", key)
   161  	}
   162  	nameAndId := strings.SplitN(key, ":", 2)
   163  	nameStr := nameAndId[0]
   164  	if nameStr == "" {
   165  		return Unknown, "", fmt.Errorf("empty name in override %q, must be formatted 'name:id'", key)
   166  	}
   167  
   168  	name, ok := StringToName[nameStr]
   169  	if !ok {
   170  		return Unknown, "", fmt.Errorf("unrecognized name %q in override limit %q, must be one of %v", nameStr, key, LimitNames)
   171  	}
   172  	id := nameAndId[1]
   173  	if id == "" {
   174  		return Unknown, "", fmt.Errorf("empty id in override %q, must be formatted 'name:id'", key)
   175  	}
   176  	return name, id, nil
   177  }
   178  
   179  // parseOverrideNameEnumId is like parseOverrideNameId, but it expects the
   180  // key to be formatted as 'name:id', where 'name' is a Name enum string and 'id'
   181  // is a string identifier. It returns an error if either part is missing or invalid.
   182  func parseOverrideNameEnumId(key string) (Name, string, error) {
   183  	if !strings.Contains(key, ":") {
   184  		// Avoids a potential panic in strings.SplitN below.
   185  		return Unknown, "", fmt.Errorf("invalid override %q, must be formatted 'name:id'", key)
   186  	}
   187  	nameStrAndId := strings.SplitN(key, ":", 2)
   188  	if len(nameStrAndId) != 2 {
   189  		return Unknown, "", fmt.Errorf("invalid override %q, must be formatted 'name:id'", key)
   190  	}
   191  
   192  	nameInt, err := strconv.Atoi(nameStrAndId[0])
   193  	if err != nil {
   194  		return Unknown, "", fmt.Errorf("invalid name %q in override limit %q, must be an integer", nameStrAndId[0], key)
   195  	}
   196  	name := Name(nameInt)
   197  	if !name.isValid() {
   198  		return Unknown, "", fmt.Errorf("invalid name %q in override limit %q, must be one of %v", nameStrAndId[0], key, LimitNames)
   199  
   200  	}
   201  	id := nameStrAndId[1]
   202  	if id == "" {
   203  		return Unknown, "", fmt.Errorf("empty id in override %q, must be formatted 'name:id'", key)
   204  	}
   205  	return name, id, nil
   206  }
   207  
   208  // parseOverrideLimits validates a YAML list of override limits. It must be
   209  // formatted as a list of maps, where each map has a single key representing the
   210  // limit name and a value that is a map containing the limit fields and an
   211  // additional 'ids' field that is a list of ids that this override applies to.
   212  func parseOverrideLimits(newOverridesYAML overridesYAML) (Limits, error) {
   213  	parsed := make(Limits)
   214  
   215  	for _, ov := range newOverridesYAML {
   216  		for k, v := range ov {
   217  			name, ok := StringToName[k]
   218  			if !ok {
   219  				return nil, fmt.Errorf("unrecognized name %q in override limit, must be one of %v", k, LimitNames)
   220  			}
   221  
   222  			for _, entry := range v.Ids {
   223  				id, err := hydrateOverrideLimit(entry.Id, name)
   224  				if err != nil {
   225  					return nil, fmt.Errorf(
   226  						"validating name %s and id %q for override limit %q: %w", name, id, k, err)
   227  				}
   228  
   229  				lim := &Limit{
   230  					Burst:      v.Burst,
   231  					Count:      v.Count,
   232  					Period:     v.Period,
   233  					Name:       name,
   234  					Comment:    entry.Comment,
   235  					isOverride: true,
   236  				}
   237  
   238  				err = ValidateLimit(lim)
   239  				if err != nil {
   240  					return nil, fmt.Errorf(
   241  						"validating name %s and id %q for override limit %q: %w", name, id, k, err)
   242  				}
   243  
   244  				parsed[joinWithColon(name.EnumString(), id)] = lim
   245  			}
   246  		}
   247  	}
   248  	return parsed, nil
   249  }
   250  
   251  // hydrateOverrideLimit validates the limit Name and override bucket key. It
   252  // returns the correct bucket key to use in-memory.
   253  func hydrateOverrideLimit(bucketKey string, limitName Name) (string, error) {
   254  	if !limitName.isValid() {
   255  		return "", fmt.Errorf("unrecognized limit name %d", limitName)
   256  	}
   257  
   258  	err := validateIdForName(limitName, bucketKey)
   259  	if err != nil {
   260  		return "", err
   261  	}
   262  
   263  	// Interpret and compute a new in-memory bucket key for two rate limits,
   264  	// since their keys aren't nice to store in a config file or database entry.
   265  	switch limitName {
   266  	case CertificatesPerDomain:
   267  		// Convert IP addresses to their covering /32 (IPv4) or /64
   268  		// (IPv6) prefixes in CIDR notation.
   269  		ip, err := netip.ParseAddr(bucketKey)
   270  		if err == nil {
   271  			prefix, err := coveringIPPrefix(limitName, ip)
   272  			if err != nil {
   273  				return "", fmt.Errorf("computing prefix for IP address %q: %w", bucketKey, err)
   274  			}
   275  			bucketKey = prefix.String()
   276  		}
   277  	case CertificatesPerFQDNSet:
   278  		// Compute the hash of a comma-separated list of identifier values.
   279  		bucketKey = fmt.Sprintf("%x", core.HashIdentifiers(identifier.FromStringSlice(strings.Split(bucketKey, ","))))
   280  	}
   281  
   282  	return bucketKey, nil
   283  }
   284  
   285  // parseDefaultLimits validates a map of default limits and rekeys it by 'Name'.
   286  func parseDefaultLimits(newDefaultLimits LimitConfigs) (Limits, error) {
   287  	parsed := make(Limits)
   288  
   289  	for k, v := range newDefaultLimits {
   290  		name, ok := StringToName[k]
   291  		if !ok {
   292  			return nil, fmt.Errorf("unrecognized name %q in default limit, must be one of %v", k, LimitNames)
   293  		}
   294  
   295  		lim := &Limit{
   296  			Burst:  v.Burst,
   297  			Count:  v.Count,
   298  			Period: v.Period,
   299  			Name:   name,
   300  		}
   301  
   302  		err := ValidateLimit(lim)
   303  		if err != nil {
   304  			return nil, fmt.Errorf("parsing default limit %q: %w", k, err)
   305  		}
   306  
   307  		lim.precompute()
   308  		parsed[name.EnumString()] = lim
   309  	}
   310  	return parsed, nil
   311  }
   312  
// OverridesRefresher fetches the current set of override limits.
// limitRegistry.loadOverrides calls it with a context, the registry's
// overridesErrors gauge, and its logger, and stores the returned Limits as the
// registry's overrides (keyed by 'name:id'). loadOverrides calls precompute on
// every returned limit, so each must have a non-zero Count.
type OverridesRefresher func(context.Context, prometheus.Gauge, blog.Logger) (Limits, error)
   314  
   315  type limitRegistry struct {
   316  	// defaults stores default limits by 'name'.
   317  	defaults Limits
   318  
   319  	// overrides stores override limits by 'name:id'.
   320  	overrides       Limits
   321  	overridesLoaded bool
   322  
   323  	// refreshOverrides is a function to refresh override limits.
   324  	refreshOverrides OverridesRefresher
   325  
   326  	overridesTimestamp prometheus.Gauge
   327  	overridesErrors    prometheus.Gauge
   328  	overridesPerLimit  prometheus.GaugeVec
   329  
   330  	logger blog.Logger
   331  }
   332  
   333  // getLimit returns the limit for the specified by name and bucketKey, name is
   334  // required, bucketKey is optional. If bucketkey is empty, the default for the
   335  // limit specified by name is returned. If no default limit exists for the
   336  // specified name, errLimitDisabled is returned.
   337  func (l *limitRegistry) getLimit(name Name, bucketKey string) (*Limit, error) {
   338  	if !name.isValid() {
   339  		// This should never happen. Callers should only be specifying the limit
   340  		// Name enums defined in this package.
   341  		return nil, fmt.Errorf("specified name enum %q, is invalid", name)
   342  	}
   343  	if bucketKey != "" {
   344  		// Check for override.
   345  		ol, ok := l.overrides[bucketKey]
   346  		if ok {
   347  			return ol, nil
   348  		}
   349  	}
   350  	dl, ok := l.defaults[name.EnumString()]
   351  	if ok {
   352  		return dl, nil
   353  	}
   354  	return nil, errLimitDisabled
   355  }
   356  
   357  // loadOverrides replaces this registry's overrides with a new dataset.
   358  func (l *limitRegistry) loadOverrides(ctx context.Context) error {
   359  	newOverrides, err := l.refreshOverrides(ctx, l.overridesErrors, l.logger)
   360  	if err != nil {
   361  		return err
   362  	}
   363  	l.overridesLoaded = true
   364  
   365  	if len(newOverrides) < 1 {
   366  		l.logger.Warning("loading overrides: no valid overrides")
   367  		// If it's an empty set, don't replace any current overrides.
   368  		return nil
   369  	}
   370  
   371  	newOverridesPerLimit := make(map[Name]float64)
   372  	for _, override := range newOverrides {
   373  		override.precompute()
   374  		newOverridesPerLimit[override.Name]++
   375  	}
   376  
   377  	l.overrides = newOverrides
   378  	l.overridesTimestamp.SetToCurrentTime()
   379  	for rlName, rlString := range nameToString {
   380  		l.overridesPerLimit.WithLabelValues(rlString).Set(newOverridesPerLimit[rlName])
   381  	}
   382  
   383  	return nil
   384  }
   385  
   386  // loadOverridesWithRetry tries to loadOverrides, retrying at least every 30
   387  // seconds upon failure.
   388  func (l *limitRegistry) loadOverridesWithRetry(ctx context.Context) error {
   389  	retries := 0
   390  	for {
   391  		err := l.loadOverrides(ctx)
   392  		if err == nil {
   393  			return nil
   394  		}
   395  		l.logger.Errf("loading overrides: %v", err)
   396  		retries++
   397  		select {
   398  		case <-time.After(core.RetryBackoff(retries, time.Second/6, time.Second*15, 2)):
   399  		case <-ctx.Done():
   400  			return err
   401  		}
   402  	}
   403  }
   404  
   405  // NewRefresher loads, and periodically refreshes, overrides using this
   406  // registry's refreshOverrides function.
   407  func (l *limitRegistry) NewRefresher(interval time.Duration) context.CancelFunc {
   408  	ctx, cancel := context.WithCancel(context.Background())
   409  
   410  	go func() {
   411  		err := l.loadOverridesWithRetry(ctx)
   412  		if err != nil {
   413  			l.logger.Errf("loading overrides (initial): %v", err)
   414  		}
   415  
   416  		ticker := time.NewTicker(interval)
   417  		defer ticker.Stop()
   418  		for {
   419  			select {
   420  			case <-ticker.C:
   421  				err := l.loadOverridesWithRetry(ctx)
   422  				if err != nil {
   423  					l.logger.Errf("loading overrides (refresh): %v", err)
   424  				}
   425  			case <-ctx.Done():
   426  				return
   427  			}
   428  		}
   429  	}()
   430  
   431  	return cancel
   432  }
   433  
   434  // LoadOverridesByBucketKey loads the overrides YAML at the supplied path,
   435  // parses it with the existing helpers, and returns the resulting limits map
   436  // keyed by "<name>:<id>". This function is exported to support admin tooling
   437  // used during the migration from overrides.yaml to the overrides database
   438  // table.
   439  func LoadOverridesByBucketKey(path string) (Limits, error) {
   440  	ovs, err := loadOverridesFromFile(path)
   441  	if err != nil {
   442  		return nil, err
   443  	}
   444  	return parseOverrideLimits(ovs)
   445  }
   446  
   447  // DumpOverrides writes the provided overrides to CSV at the supplied path. Each
   448  // override is written as a single row, one per ID. Rows are sorted in the
   449  // following order:
   450  //   - Name    (ascending)
   451  //   - Count   (descending)
   452  //   - Burst   (descending)
   453  //   - Period  (ascending)
   454  //   - Comment (ascending)
   455  //   - ID      (ascending)
   456  //
   457  // This function supports admin tooling that routinely exports the overrides
   458  // table for investigation or auditing.
   459  func DumpOverrides(path string, overrides Limits) error {
   460  	type row struct {
   461  		name    string
   462  		id      string
   463  		count   int64
   464  		burst   int64
   465  		period  string
   466  		comment string
   467  	}
   468  
   469  	var rows []row
   470  	for bucketKey, limit := range overrides {
   471  		name, id, err := parseOverrideNameEnumId(bucketKey)
   472  		if err != nil {
   473  			return err
   474  		}
   475  
   476  		rows = append(rows, row{
   477  			name:    name.String(),
   478  			id:      id,
   479  			count:   limit.Count,
   480  			burst:   limit.Burst,
   481  			period:  limit.Period.Duration.String(),
   482  			comment: limit.Comment,
   483  		})
   484  	}
   485  
   486  	sort.Slice(rows, func(i, j int) bool {
   487  		// Sort by limit name in ascending order.
   488  		if rows[i].name != rows[j].name {
   489  			return rows[i].name < rows[j].name
   490  		}
   491  		// Sort by count in descending order (higher counts first).
   492  		if rows[i].count != rows[j].count {
   493  			return rows[i].count > rows[j].count
   494  		}
   495  		// Sort by burst in descending order (higher bursts first).
   496  		if rows[i].burst != rows[j].burst {
   497  			return rows[i].burst > rows[j].burst
   498  		}
   499  		// Sort by period in ascending order (shorter durations first).
   500  		if rows[i].period != rows[j].period {
   501  			return rows[i].period < rows[j].period
   502  		}
   503  		// Sort by comment in ascending order.
   504  		if rows[i].comment != rows[j].comment {
   505  			return rows[i].comment < rows[j].comment
   506  		}
   507  		// Sort by ID in ascending order.
   508  		return rows[i].id < rows[j].id
   509  	})
   510  
   511  	f, err := os.Create(path)
   512  	if err != nil {
   513  		return err
   514  	}
   515  	defer f.Close()
   516  
   517  	w := csv.NewWriter(f)
   518  	err = w.Write([]string{"name", "id", "count", "burst", "period", "comment"})
   519  	if err != nil {
   520  		return err
   521  	}
   522  
   523  	for _, r := range rows {
   524  		err := w.Write([]string{r.name, r.id, strconv.FormatInt(r.count, 10), strconv.FormatInt(r.burst, 10), r.period, r.comment})
   525  		if err != nil {
   526  			return err
   527  		}
   528  	}
   529  	w.Flush()
   530  
   531  	return w.Error()
   532  }