
     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     4  package rate
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"math"
    11  	"strconv"
    12  	"strings"
    14  	""
    15  	""
    16  	""
    17  	""
    19  	""
    20  	""
    21  	""
    22  	""
    23  )
    25  var (
    26  	log              = logging.DefaultLogger.WithField(logfields.LogSubsys, "rate")
    27  	ErrWaitCancelled = errors.New("request cancelled while waiting for rate limiting slot")
    28  )
    30  const (
    31  	defaultMeanOver                = 10
    32  	defaultDelayedAdjustmentFactor = 0.50
    33  	defaultMaxAdjustmentFactor     = 100.0
    35  	// waitSemaphoreWeight is the maximum resolution of the wait semaphore,
    36  	// the higher this value, the more accurate the ParallelRequests
    37  	// requirement is implemented
    38  	waitSemaphoreResolution = 10000000
    40  	// logUUID is the UUID of the request.
    41  	logUUID = "uuid"
    42  	// logAPICallName is the name of the underlying API call, such as
    43  	// "endpoint-create".
    44  	logAPICallName = "name"
    45  	// logProcessingDuration is the time taken to perform the actual underlying
    46  	// API call such as creating an endpoint or deleting an endpoint. This is
    47  	// the time between when the request has finished waiting (or being
    48  	// delayed), to when the underlying action has finished.
    49  	logProcessingDuration = "processingDuration"
    50  	// logParallelRequests is the number of allowed parallel requests. See
    51  	// APILimiter.parallelRequests.
    52  	logParallelRequests = "parallelRequests"
    53  	// logMinWaitDuration represents APILimiterParameters.MinWaitDuration.
    54  	logMinWaitDuration = "minWaitDuration"
    55  	// logMaxWaitDuration represents APILimiterParameters.MaxWaitDuration.
    56  	logMaxWaitDuration = "maxWaitDuration"
    57  	// logMaxWaitDurationLimiter is the actual / calculated maximum threshold
    58  	// for a request to wait. Any request exceeding this threshold will not be
    59  	// processed.
    60  	logMaxWaitDurationLimiter = "maxWaitDurationLimiter"
    61  	// logWaitDurationLimit is the actual / calculated amount of time
    62  	// determined by the underlying rate-limiting library that this request
    63  	// must wait before the rate limiter releases it, so that it can take the
    64  	// underlying action. See*Reservation).Delay().
    65  	logWaitDurationLimit = "waitDurationLimiter"
    66  	// logWaitDurationTotal is the actual total amount of time that this
    67  	// request spent waiting to be released by the rate limiter.
    68  	logWaitDurationTotal = "waitDurationTotal"
    69  	// logLimit is the rate limit. See APILimiterParameters.RateLimit.
    70  	logLimit = "limit"
    71  	// logLimit is the burst rate. See APILimiterParameters.RateBurst.
    72  	logBurst = "burst"
    73  	// logTotalDuration is the total time between when the request was first
    74  	// scheduled (entered the rate limiter) to when it completed processing of
    75  	// the underlying action. This is the absolute total time of the request
    76  	// from beginning to end.
    77  	logTotalDuration = "totalDuration"
    78  	// logSkipped represents whether the rate limiter will skip rate-limiting
    79  	// this request. See APILimiterParameters.SkipInitial.
    80  	logSkipped = "rateLimiterSkipped"
    81  )
    83  type outcome string
    85  const (
    86  	outcomeParallelMaxWait outcome = "fail-parallel-wait"
    87  	outcomeLimitMaxWait    outcome = "fail-limit-wait"
    88  	outcomeReqCancelled    outcome = "request-cancelled"
    89  	outcomeErrorCode       int     = 429
    90  	outcomeSuccessCode     int     = 200
    91  )
// APILimiter is an extension to x/time/rate.Limiter specifically for Cilium
// API calls. It allows to automatically adjust the rate, burst and maximum
// parallel API calls to stay as close as possible to an estimated processing
// time.
//
// A request passing through Wait is subject to, in this order: a weighted
// semaphore bounding the number of parallel requests, an optional
// rate.Limiter, and MaxWaitDuration, beyond which a waiting request is
// rejected.
type APILimiter struct {
	// name is the name of the API call. This field is immutable after
	// NewAPILimiter()
	name string

	// params is the parameters of the limiter, with the defaults of
	// NewAPILimiter() applied. This field is immutable after
	// NewAPILimiter()
	params APILimiterParameters

	// metrics points to the metrics implementation provided by the caller
	// of the APILimiter. It may be nil, in which case no metrics are
	// reported. This field is immutable after NewAPILimiter()
	metrics MetricsObserver

	// mutex protects the mutable fields below this line, with the
	// exception of parallelWaitSemaphore, which is acquired and released
	// without holding mutex.
	mutex lock.RWMutex

	// meanProcessingDuration is the latest mean processing duration in
	// seconds, calculated based on processingDurations
	meanProcessingDuration float64

	// processingDurations holds the processing durations of the last
	// params.MeanOver successful requests
	processingDurations []time.Duration

	// meanWaitDuration is the latest mean wait duration in seconds,
	// calculated based on waitDurations
	meanWaitDuration float64

	// waitDurations holds the wait durations of the last params.MeanOver
	// successful requests
	waitDurations []time.Duration

	// parallelRequests is the currently allowed maximum parallel
	// requests. It is initialized from params.ParallelRequests and is then
	// adjusted automatically if params.AutoAdjust is enabled. A value of 0
	// disables the parallel request limit.
	parallelRequests int

	// adjustmentFactor is the latest adjustment factor. It is the ratio
	// between params.EstimatedProcessingDuration and
	// meanProcessingDuration, clamped to
	// [1/params.MaxAdjustmentFactor, params.MaxAdjustmentFactor]. It stays
	// 0 until the first auto adjustment.
	adjustmentFactor float64

	// limiter is the rate limiter based on params.RateLimit and
	// params.RateBurst. It is nil while no rate limit is configured
	// (params.RateLimit == 0) until SetRateLimit or SetRateBurst creates
	// one.
	limiter *rate.Limiter

	// currentRequestsInFlight is the number of API requests that have been
	// released by Wait and not yet finished
	currentRequestsInFlight int

	// requestsProcessed is the total number of finished requests that had
	// been released by Wait (rejected requests are not counted)
	requestsProcessed int64

	// requestsScheduled is the total number of requests that entered Wait;
	// it is compared against params.SkipInitial
	requestsScheduled int64

	// parallelWaitSemaphore is the semaphore used to implement
	// parallelRequests. It is initialized with a capacity of
	// waitSemaphoreResolution and each API request that acquires it holds
	// waitSemaphoreResolution/parallelRequests tokens (parallelRequests as
	// sampled when the request entered Wait) until requestFinished
	// releases them.
	parallelWaitSemaphore *semaphore.Weighted
}
// APILimiterParameters is the configuration of an APILimiter. The structure
// may not be mutated after it has been passed into NewAPILimiter().
//
// NewAPILimiter replaces the following zero values with defaults: MeanOver
// (10), MinParallelRequests (1), RateBurst (1), DelayedAdjustmentFactor (0.5)
// and MaxAdjustmentFactor (100).
type APILimiterParameters struct {
	// EstimatedProcessingDuration is the estimated duration an API call
	// will take. This value is used if AutoAdjust is enabled to
	// automatically adjust rate limits to stay as close as possible to the
	// estimated processing duration. No adjustment happens while it is
	// zero.
	EstimatedProcessingDuration time.Duration

	// AutoAdjust enables automatic adjustment of the values
	// ParallelRequests, RateLimit, and RateBurst in order to keep the
	// mean processing duration close to EstimatedProcessingDuration. Only
	// successful requests feed into the adjustment.
	AutoAdjust bool

	// MeanOver is the number of entries to keep in order to calculate the
	// mean processing and wait duration
	MeanOver int

	// ParallelRequests is the parallel requests allowed. If AutoAdjust is
	// enabled, the value will adjust automatically. Zero disables the
	// parallel request limit.
	ParallelRequests int

	// MaxParallelRequests is the maximum parallel requests allowed. If
	// AutoAdjust is enabled, then the ParallelRequests will never grow
	// above MaxParallelRequests. Zero disables this bound.
	MaxParallelRequests int

	// MinParallelRequests is the minimum parallel requests allowed. If
	// AutoAdjust is enabled, then the ParallelRequests will never fall
	// below MinParallelRequests. It is also the lower bound used when
	// auto-adjusting RateBurst.
	MinParallelRequests int

	// RateLimit is the initial number of API requests allowed per second.
	// If AutoAdjust is enabled, the value will adjust automatically. Zero
	// means no rate limiter is created.
	RateLimit rate.Limit

	// RateBurst is the initial allowed burst of API requests allowed. If
	// AutoAdjust is enabled, the value will adjust automatically.
	RateBurst int

	// MinWaitDuration is the minimum time an API request is delayed before
	// Wait() releases it, even if the rate limiter would allow it earlier.
	// Requests skipped via SkipInitial are not delayed.
	MinWaitDuration time.Duration

	// MaxWaitDuration is the maximum time an API request is allowed to
	// wait (for a parallel request slot and for the rate limiter) before
	// the Wait() function returns an error. Zero means no maximum.
	MaxWaitDuration time.Duration

	// Log enables info logging of processed API requests and warning logs
	// for rejected ones. This should only be used for low frequency API
	// calls.
	Log bool

	// DelayedAdjustmentFactor is the fraction of the AdjustmentFactor to be
	// applied to RateBurst and ParallelRequests, defined as a value between
	// 0.0..1.0. This is used to steer a slower reaction of the RateBurst
	// and ParallelRequests compared to RateLimit.
	DelayedAdjustmentFactor float64

	// SkipInitial is the number of initial API calls for which to not
	// apply any rate limiting (neither the parallel request limit nor the
	// rate limiter). This is useful to define a learning phase in the
	// beginning to allow for auto adjustment before imposing wait
	// durations and rate limiting on API calls.
	SkipInitial int

	// MaxAdjustmentFactor is the maximum adjustment factor when AutoAdjust
	// is enabled. Base values will not adjust more than by this factor.
	MaxAdjustmentFactor float64
}
   227  // MergeUserConfig merges the provided user configuration into the existing
   228  // parameters and returns a new copy.
   229  func (p APILimiterParameters) MergeUserConfig(config string) (APILimiterParameters, error) {
   230  	if err := (&p).mergeUserConfig(config); err != nil {
   231  		return APILimiterParameters{}, err
   232  	}
   234  	return p, nil
   235  }
   237  // NewAPILimiter returns a new APILimiter based on the parameters and metrics implementation
   238  func NewAPILimiter(name string, p APILimiterParameters, metrics MetricsObserver) *APILimiter {
   239  	if p.MeanOver == 0 {
   240  		p.MeanOver = defaultMeanOver
   241  	}
   243  	if p.MinParallelRequests == 0 {
   244  		p.MinParallelRequests = 1
   245  	}
   247  	if p.RateBurst == 0 {
   248  		p.RateBurst = 1
   249  	}
   251  	if p.DelayedAdjustmentFactor == 0.0 {
   252  		p.DelayedAdjustmentFactor = defaultDelayedAdjustmentFactor
   253  	}
   255  	if p.MaxAdjustmentFactor == 0.0 {
   256  		p.MaxAdjustmentFactor = defaultMaxAdjustmentFactor
   257  	}
   259  	l := &APILimiter{
   260  		name:                  name,
   261  		params:                p,
   262  		parallelRequests:      p.ParallelRequests,
   263  		parallelWaitSemaphore: semaphore.NewWeighted(waitSemaphoreResolution),
   264  		metrics:               metrics,
   265  	}
   267  	if p.RateLimit != 0 {
   268  		l.limiter = rate.NewLimiter(p.RateLimit, p.RateBurst)
   269  	}
   271  	return l
   272  }
   274  // NewAPILimiterFromConfig returns a new APILimiter based on user configuration
   275  func NewAPILimiterFromConfig(name, config string, metrics MetricsObserver) (*APILimiter, error) {
   276  	p := &APILimiterParameters{}
   278  	if err := p.mergeUserConfig(config); err != nil {
   279  		return nil, err
   280  	}
   282  	return NewAPILimiter(name, *p, metrics), nil
   283  }
   285  func (p *APILimiterParameters) mergeUserConfigKeyValue(key, value string) error {
   286  	switch strings.ToLower(key) {
   287  	case "rate-limit":
   288  		limit, err := parseRate(value)
   289  		if err != nil {
   290  			return fmt.Errorf("unable to parse rate %q: %w", value, err)
   291  		}
   292  		p.RateLimit = limit
   293  	case "rate-burst":
   294  		burst, err := parsePositiveInt(value)
   295  		if err != nil {
   296  			return err
   297  		}
   298  		p.RateBurst = burst
   299  	case "min-wait-duration":
   300  		minWaitDuration, err := time.ParseDuration(value)
   301  		if err != nil {
   302  			return fmt.Errorf("unable to parse duration %q: %w", value, err)
   303  		}
   304  		p.MinWaitDuration = minWaitDuration
   305  	case "max-wait-duration":
   306  		maxWaitDuration, err := time.ParseDuration(value)
   307  		if err != nil {
   308  			return fmt.Errorf("unable to parse duration %q: %w", value, err)
   309  		}
   310  		p.MaxWaitDuration = maxWaitDuration
   311  	case "estimated-processing-duration":
   312  		estProcessingDuration, err := time.ParseDuration(value)
   313  		if err != nil {
   314  			return fmt.Errorf("unable to parse duration %q: %w", value, err)
   315  		}
   316  		p.EstimatedProcessingDuration = estProcessingDuration
   317  	case "auto-adjust":
   318  		v, err := strconv.ParseBool(value)
   319  		if err != nil {
   320  			return fmt.Errorf("unable to parse bool %q: %w", value, err)
   321  		}
   322  		p.AutoAdjust = v
   323  	case "parallel-requests":
   324  		parallel, err := parsePositiveInt(value)
   325  		if err != nil {
   326  			return err
   327  		}
   328  		p.ParallelRequests = parallel
   329  	case "min-parallel-requests":
   330  		minParallel, err := parsePositiveInt(value)
   331  		if err != nil {
   332  			return err
   333  		}
   334  		p.MinParallelRequests = minParallel
   335  	case "max-parallel-requests":
   336  		maxParallel, err := parsePositiveInt(value)
   337  		if err != nil {
   338  			return err
   339  		}
   340  		p.MaxParallelRequests = int(maxParallel)
   341  	case "mean-over":
   342  		meanOver, err := parsePositiveInt(value)
   343  		if err != nil {
   344  			return err
   345  		}
   346  		p.MeanOver = meanOver
   347  	case "log":
   348  		v, err := strconv.ParseBool(value)
   349  		if err != nil {
   350  			return fmt.Errorf("unable to parse bool %q: %w", value, err)
   351  		}
   352  		p.Log = v
   353  	case "delayed-adjustment-factor":
   354  		delayedAdjustmentFactor, err := strconv.ParseFloat(value, 64)
   355  		if err != nil {
   356  			return fmt.Errorf("unable to parse float %q: %w", value, err)
   357  		}
   358  		p.DelayedAdjustmentFactor = delayedAdjustmentFactor
   359  	case "max-adjustment-factor":
   360  		maxAdjustmentFactor, err := strconv.ParseFloat(value, 64)
   361  		if err != nil {
   362  			return fmt.Errorf("unable to parse float %q: %w", value, err)
   363  		}
   364  		p.MaxAdjustmentFactor = maxAdjustmentFactor
   365  	case "skip-initial":
   366  		skipInitial, err := parsePositiveInt(value)
   367  		if err != nil {
   368  			return err
   369  		}
   370  		p.SkipInitial = skipInitial
   371  	default:
   372  		return fmt.Errorf("unknown rate limiting option %q", key)
   373  	}
   375  	return nil
   376  }
   378  func (p *APILimiterParameters) mergeUserConfig(config string) error {
   379  	tokens := strings.Split(config, ",")
   380  	for _, token := range tokens {
   381  		if token == "" {
   382  			continue
   383  		}
   385  		t := strings.SplitN(token, ":", 2)
   386  		if len(t) != 2 {
   387  			return fmt.Errorf("unable to parse rate limit option %q, must in the form name=option:value[,option:value]", token)
   388  		}
   390  		if err := p.mergeUserConfigKeyValue(t[0], t[1]); err != nil {
   391  			return fmt.Errorf("unable to parse rate limit option %q with value %q: %w", t[0], t[1], err)
   392  		}
   393  	}
   395  	return nil
   396  }
   398  func (l *APILimiter) Parameters() APILimiterParameters {
   399  	return l.params
   400  }
   402  // SetRateLimit sets the rate limit of the limiter. If limiter is unset, a new
   403  // Limiter is created using the rate burst set in the parameters.
   404  func (l *APILimiter) SetRateLimit(limit rate.Limit) {
   405  	l.mutex.Lock()
   406  	defer l.mutex.Unlock()
   407  	if l.limiter != nil {
   408  		l.limiter.SetLimit(limit)
   409  	} else {
   410  		l.limiter = rate.NewLimiter(limit, l.params.RateBurst)
   411  	}
   412  }
   414  // SetRateBurst sets the rate burst of the limiter. If limiter is unset, a new
   415  // Limiter is created using the rate limit set in the parameters.
   416  func (l *APILimiter) SetRateBurst(burst int) {
   417  	l.mutex.Lock()
   418  	defer l.mutex.Unlock()
   419  	if l.limiter != nil {
   420  		l.limiter.SetBurst(burst)
   421  	} else {
   422  		l.limiter = rate.NewLimiter(l.params.RateLimit, burst)
   423  	}
   424  }
   426  func (l *APILimiter) delayedAdjustment(current, min, max float64) (n float64) {
   427  	n = current * l.adjustmentFactor
   428  	n = current + ((n - current) * l.params.DelayedAdjustmentFactor)
   429  	if min > 0.0 && n < min {
   430  		n = min
   431  	}
   432  	if max > 0.0 && n > max {
   433  		n = max
   434  	}
   435  	return
   436  }
   438  func (l *APILimiter) calculateAdjustmentFactor() float64 {
   439  	f := l.params.EstimatedProcessingDuration.Seconds() / l.meanProcessingDuration
   440  	if f > l.params.MaxAdjustmentFactor {
   441  		f = l.params.MaxAdjustmentFactor
   442  	}
   443  	if f < 1.0/l.params.MaxAdjustmentFactor {
   444  		f = 1.0 / l.params.MaxAdjustmentFactor
   445  	}
   446  	return f
   447  }
   449  func (l *APILimiter) adjustmentLimit(newValue, initialValue float64) float64 {
   450  	return math.Max(initialValue/l.params.MaxAdjustmentFactor, math.Min(initialValue*l.params.MaxAdjustmentFactor, newValue))
   451  }
   453  func (l *APILimiter) adjustedBurst() int {
   454  	newBurst := l.delayedAdjustment(float64(l.params.RateBurst), float64(l.params.MinParallelRequests), 0.0)
   455  	return int(math.Round(l.adjustmentLimit(newBurst, float64(l.params.RateBurst))))
   456  }
   458  func (l *APILimiter) adjustedLimit() rate.Limit {
   459  	newLimit := rate.Limit(float64(l.params.RateLimit) * l.adjustmentFactor)
   460  	return rate.Limit(l.adjustmentLimit(float64(newLimit), float64(l.params.RateLimit)))
   461  }
   463  func (l *APILimiter) adjustedParallelRequests() int {
   464  	newParallelRequests := l.delayedAdjustment(float64(l.params.ParallelRequests),
   465  		float64(l.params.MinParallelRequests), float64(l.params.MaxParallelRequests))
   466  	return int(l.adjustmentLimit(newParallelRequests, float64(l.params.ParallelRequests)))
   467  }
   469  func (l *APILimiter) requestFinished(r *limitedRequest, err error, code int) {
   470  	if r.finished {
   471  		return
   472  	}
   474  	r.finished = true
   476  	var processingDuration time.Duration
   477  	if !r.startTime.IsZero() {
   478  		processingDuration = time.Since(r.startTime)
   479  	}
   481  	totalDuration := time.Since(r.scheduleTime)
   483  	scopedLog := log.WithFields(logrus.Fields{
   484  		logAPICallName:,
   485  		logUUID:               r.uuid,
   486  		logProcessingDuration: processingDuration,
   487  		logTotalDuration:      totalDuration,
   488  		logWaitDurationTotal:  r.waitDuration,
   489  	})
   491  	if err != nil {
   492  		scopedLog = scopedLog.WithError(err)
   493  	}
   495  	if l.params.Log {
   496  		scopedLog.Info("API call has been processed")
   497  	} else {
   498  		scopedLog.Debug("API call has been processed")
   499  	}
   501  	if r.waitSemaphoreWeight != 0 {
   502  		l.parallelWaitSemaphore.Release(r.waitSemaphoreWeight)
   503  	}
   505  	l.mutex.Lock()
   507  	if !r.startTime.IsZero() {
   508  		l.requestsProcessed++
   509  		l.currentRequestsInFlight--
   510  	}
   512  	// Only auto-adjust ratelimiter using metrics from successful API requests
   513  	if err == nil {
   514  		l.processingDurations = append(l.processingDurations, processingDuration)
   515  		if exceed := len(l.processingDurations) - l.params.MeanOver; exceed > 0 {
   516  			l.processingDurations = l.processingDurations[exceed:]
   517  		}
   518  		l.meanProcessingDuration = calcMeanDuration(l.processingDurations)
   520  		l.waitDurations = append(l.waitDurations, r.waitDuration)
   521  		if exceed := len(l.waitDurations) - l.params.MeanOver; exceed > 0 {
   522  			l.waitDurations = l.waitDurations[exceed:]
   523  		}
   524  		l.meanWaitDuration = calcMeanDuration(l.waitDurations)
   526  		if l.params.AutoAdjust && l.params.EstimatedProcessingDuration != 0 {
   527  			l.adjustmentFactor = l.calculateAdjustmentFactor()
   528  			l.parallelRequests = l.adjustedParallelRequests()
   530  			if l.limiter != nil {
   531  				l.limiter.SetLimit(l.adjustedLimit())
   533  				newBurst := l.adjustedBurst()
   534  				l.limiter.SetBurst(newBurst)
   535  			}
   536  		}
   537  	}
   539  	values := MetricsValues{
   540  		EstimatedProcessingDuration: l.params.EstimatedProcessingDuration.Seconds(),
   541  		WaitDuration:                r.waitDuration,
   542  		MaxWaitDuration:             l.params.MaxWaitDuration,
   543  		MinWaitDuration:             l.params.MinWaitDuration,
   544  		MeanProcessingDuration:      l.meanProcessingDuration,
   545  		MeanWaitDuration:            l.meanWaitDuration,
   546  		ParallelRequests:            l.parallelRequests,
   547  		CurrentRequestsInFlight:     l.currentRequestsInFlight,
   548  		AdjustmentFactor:            l.adjustmentFactor,
   549  		Error:                       err,
   550  		Outcome:                     string(r.outcome),
   551  		ReturnCode:                  code,
   552  	}
   554  	if l.limiter != nil {
   555  		values.Limit = l.limiter.Limit()
   556  		values.Burst = l.limiter.Burst()
   557  	}
   558  	l.mutex.Unlock()
   560  	if l.metrics != nil {
   561  		l.metrics.ProcessedRequest(, values)
   562  	}
   563  }
   565  // calcMeanDuration returns the mean duration in seconds
   566  func calcMeanDuration(durations []time.Duration) float64 {
   567  	total := 0.0
   568  	for _, t := range durations {
   569  		total += t.Seconds()
   570  	}
   571  	return total / float64(len(durations))
   572  }
   574  // LimitedRequest represents a request that is being limited. It is returned
   575  // by Wait() and the caller of Wait() is responsible to call Done() or Error()
   576  // when the API call has been processed or resulted in an error. It is safe to
   577  // call Error() and then Done(). It is not safe to call Done(), Error(), or
   578  // WaitDuration() concurrently.
   579  type LimitedRequest interface {
   580  	Done()
   581  	Error(err error, code int)
   582  	WaitDuration() time.Duration
   583  }
// limitedRequest is the LimitedRequest implementation handed out by
// APILimiter.Wait. Its fields are written by wait() and read by
// requestFinished; they are not synchronized (see LimitedRequest).
type limitedRequest struct {
	limiter             *APILimiter   // limiter that scheduled this request
	startTime           time.Time     // set when wait() released the request; zero if it never was
	scheduleTime        time.Time     // when wait() was entered
	waitDuration        time.Duration // time spent waiting before release (or rejection)
	waitSemaphoreWeight int64         // parallel semaphore tokens held; 0 if none
	uuid                string        // random identifier used in log fields
	finished            bool          // set by the first requestFinished call
	outcome             outcome       // rejection reason; empty when released
}
   596  // WaitDuration returns the duration the request had to wait
   597  func (l *limitedRequest) WaitDuration() time.Duration {
   598  	return l.waitDuration
   599  }
   601  // Done must be called when the API request has been successfully processed
   602  func (l *limitedRequest) Done() {
   603  	l.limiter.requestFinished(l, nil, outcomeSuccessCode)
   604  }
   606  // Error must be called when the API request resulted in an error
   607  func (l *limitedRequest) Error(err error, code int) {
   608  	l.limiter.requestFinished(l, err, code)
   609  }
   611  // Wait blocks until the next API call is allowed to be processed. If the
   612  // configured MaxWaitDuration is exceeded, an error is returned. On success, a
   613  // LimitedRequest is returned on which Done() must be called when the API call
   614  // has completed or Error() if an error occurred.
   615  func (l *APILimiter) Wait(ctx context.Context) (LimitedRequest, error) {
   616  	req, err := l.wait(ctx)
   617  	if err != nil {
   618  		l.requestFinished(req, err, outcomeErrorCode)
   619  		return nil, err
   620  	}
   621  	return req, nil
   622  }
   624  // wait implements the API rate limiting delaying functionality. Every error
   625  // message and corresponding log message are documented in
   626  // Documentation/configuration/api-rate-limiting.rst. If any changes related to
   627  // errors or log messages are made to this function, please update the
   628  // aforementioned page as well.
   629  func (l *APILimiter) wait(ctx context.Context) (req *limitedRequest, err error) {
   630  	var (
   631  		limitWaitDuration time.Duration
   632  		r                 *rate.Reservation
   633  	)
   635  	req = &limitedRequest{
   636  		limiter:      l,
   637  		scheduleTime: time.Now(),
   638  		uuid:         uuid.New().String(),
   639  	}
   641  	l.mutex.Lock()
   643  	l.requestsScheduled++
   645  	scopedLog := log.WithFields(logrus.Fields{
   646  		logAPICallName:,
   647  		logUUID:             req.uuid,
   648  		logParallelRequests: l.parallelRequests,
   649  	})
   651  	if l.params.MaxWaitDuration > 0 {
   652  		scopedLog = scopedLog.WithField(logMaxWaitDuration, l.params.MaxWaitDuration)
   653  	}
   655  	if l.params.MinWaitDuration > 0 {
   656  		scopedLog = scopedLog.WithField(logMinWaitDuration, l.params.MinWaitDuration)
   657  	}
   659  	select {
   660  	case <-ctx.Done():
   661  		if l.params.Log {
   662  			scopedLog.Warning("Not processing API request due to cancelled context")
   663  		}
   664  		l.mutex.Unlock()
   665  		req.outcome = outcomeReqCancelled
   666  		err = fmt.Errorf("%w: %w", ErrWaitCancelled, ctx.Err())
   667  		return
   668  	default:
   669  	}
   671  	skip := l.params.SkipInitial > 0 && l.requestsScheduled <= int64(l.params.SkipInitial)
   672  	if skip {
   673  		scopedLog = scopedLog.WithField(logSkipped, skip)
   674  	}
   676  	parallelRequests := l.parallelRequests
   677  	meanProcessingDuration := l.meanProcessingDuration
   678  	l.mutex.Unlock()
   680  	if l.params.Log {
   681  		scopedLog.Info("Processing API request with rate limiter")
   682  	} else {
   683  		scopedLog.Debug("Processing API request with rate limiter")
   684  	}
   686  	if skip {
   687  		goto skipRateLimiter
   688  	}
   690  	if parallelRequests > 0 {
   691  		waitCtx := ctx
   692  		if l.params.MaxWaitDuration > 0 {
   693  			ctx2, cancel := context.WithTimeout(ctx, l.params.MaxWaitDuration)
   694  			defer cancel()
   695  			waitCtx = ctx2
   696  		}
   697  		w := int64(waitSemaphoreResolution / parallelRequests)
   698  		err2 := l.parallelWaitSemaphore.Acquire(waitCtx, w)
   699  		if err2 != nil {
   700  			if l.params.Log {
   701  				scopedLog.WithError(err2).Warning("Not processing API request. Wait duration for maximum parallel requests exceeds maximum")
   702  			}
   703  			req.outcome = outcomeParallelMaxWait
   704  			err = fmt.Errorf("timed out while waiting to be served with %d parallel requests: %w", parallelRequests, err2)
   705  			return
   706  		}
   707  		req.waitSemaphoreWeight = w
   708  	}
   709  	req.waitDuration = time.Since(req.scheduleTime)
   711  	l.mutex.Lock()
   712  	if l.limiter != nil {
   713  		r = l.limiter.Reserve()
   714  		limitWaitDuration = r.Delay()
   716  		scopedLog = scopedLog.WithFields(logrus.Fields{
   717  			logLimit:                  fmt.Sprintf("%.2f/s", l.limiter.Limit()),
   718  			logBurst:                  l.limiter.Burst(),
   719  			logWaitDurationLimit:      limitWaitDuration,
   720  			logMaxWaitDurationLimiter: l.params.MaxWaitDuration - req.waitDuration,
   721  		})
   722  	}
   723  	l.mutex.Unlock()
   725  	if l.params.MinWaitDuration > 0 && limitWaitDuration < l.params.MinWaitDuration {
   726  		limitWaitDuration = l.params.MinWaitDuration
   727  	}
   729  	if (l.params.MaxWaitDuration > 0 && (limitWaitDuration+req.waitDuration) > l.params.MaxWaitDuration) || limitWaitDuration == rate.InfDuration {
   730  		if l.params.Log {
   731  			scopedLog.Warning("Not processing API request. Wait duration exceeds maximum")
   732  		}
   734  		// The rate limiter should only consider a reservation valid if
   735  		// the request is actually processed. Cancellation of the
   736  		// reservation should happen before we sleep below.
   737  		if r != nil {
   738  			r.Cancel()
   739  		}
   741  		// Instead of returning immediately, pace the caller by
   742  		// sleeping for the mean processing duration. This helps
   743  		// against callers who disrespect 429 error codes and retry
   744  		// immediately.
   745  		if meanProcessingDuration > 0.0 {
   746  			time.Sleep(time.Duration(meanProcessingDuration * float64(time.Second)))
   747  		}
   749  		req.outcome = outcomeLimitMaxWait
   750  		err = fmt.Errorf("request would have to wait %v to be served (maximum wait duration: %v)",
   751  			limitWaitDuration, l.params.MaxWaitDuration-req.waitDuration)
   752  		return
   753  	}
   755  	if limitWaitDuration != 0 {
   756  		select {
   757  		case <-time.After(limitWaitDuration):
   758  		case <-ctx.Done():
   759  			if l.params.Log {
   760  				scopedLog.Warning("Not processing API request due to cancelled context while waiting")
   761  			}
   762  			// The rate limiter should only consider a reservation
   763  			// valid if the request is actually processed.
   764  			if r != nil {
   765  				r.Cancel()
   766  			}
   768  			req.outcome = outcomeReqCancelled
   769  			err = fmt.Errorf("%w: %w", ErrWaitCancelled, ctx.Err())
   770  			return
   771  		}
   772  	}
   774  	req.waitDuration = time.Since(req.scheduleTime)
   776  skipRateLimiter:
   778  	l.mutex.Lock()
   779  	l.currentRequestsInFlight++
   780  	l.mutex.Unlock()
   782  	scopedLog = scopedLog.WithField(logWaitDurationTotal, req.waitDuration)
   784  	if l.params.Log {
   785  		scopedLog.Info("API request released by rate limiter")
   786  	} else {
   787  		scopedLog.Debug("API request released by rate limiter")
   788  	}
   790  	req.startTime = time.Now()
   791  	return req, nil
   793  }
   795  func parseRate(r string) (rate.Limit, error) {
   796  	tokens := strings.SplitN(r, "/", 2)
   797  	if len(tokens) != 2 {
   798  		return 0, fmt.Errorf("not in the form number/interval")
   799  	}
   801  	f, err := strconv.ParseFloat(tokens[0], 64)
   802  	if err != nil {
   803  		return 0, fmt.Errorf("unable to parse float %q: %w", tokens[0], err)
   804  	}
   806  	// Reject rates such as 1/1 or 10/10 as it will default to nanoseconds
   807  	// which is likely unexpected to the user. Require an explicit suffix.
   808  	if _, err := strconv.ParseInt(string(tokens[1]), 10, 64); err == nil {
   809  		return 0, fmt.Errorf("interval %q must contain duration suffix", tokens[1])
   810  	}
   812  	// If duration is provided as "m" or "s", convert it into "1m" or "1s"
   813  	if _, err := strconv.ParseInt(string(tokens[1][0]), 10, 64); err != nil {
   814  		tokens[1] = "1" + tokens[1]
   815  	}
   817  	d, err := time.ParseDuration(tokens[1])
   818  	if err != nil {
   819  		return 0, fmt.Errorf("unable to parse duration %q: %w", tokens[1], err)
   820  	}
   822  	return rate.Limit(f / d.Seconds()), nil
   823  }
// APILimiterSet is a set of APILimiter indexed by name. Requests for names
// that are not in the set are not limited (see APILimiterSet.Wait).
type APILimiterSet struct {
	limiters map[string]*APILimiter // built by NewAPILimiterSet; only read afterwards
	metrics  MetricsObserver        // observer passed to NewAPILimiterSet
}
// MetricsValues is the snapshot of relevant values to feed into the
// MetricsObserver. The APILimiter assembles it whenever a request finishes,
// whether it succeeded, failed, or was rejected by the limiter. Durations
// stored as float64 are in seconds.
type MetricsValues struct {
	WaitDuration                time.Duration // time the request waited before release or rejection
	MinWaitDuration             time.Duration // APILimiterParameters.MinWaitDuration
	MaxWaitDuration             time.Duration // APILimiterParameters.MaxWaitDuration
	Outcome                     string        // rejection outcome; empty unless the limiter rejected the request
	MeanProcessingDuration      float64       // mean processing duration (s) over the last MeanOver successful requests
	MeanWaitDuration            float64       // mean wait duration (s) over the last MeanOver successful requests
	EstimatedProcessingDuration float64       // APILimiterParameters.EstimatedProcessingDuration in seconds
	ParallelRequests            int           // currently allowed parallel requests
	Limit                       rate.Limit    // current rate limit; 0 when no rate limiter exists
	Burst                       int           // current rate burst; 0 when no rate limiter exists
	CurrentRequestsInFlight     int           // requests released but not yet finished
	AdjustmentFactor            float64       // latest auto-adjustment factor; 0 until the first adjustment
	Error                       error         // error reported for the request; nil on success
	ReturnCode                  int           // return code reported for the request
}
// MetricsObserver is the interface that must be implemented to extract metrics
// from an APILimiter. Passing a nil MetricsObserver disables reporting.
type MetricsObserver interface {
	// ProcessedRequest is invoked after invocation of an API call, and
	// also after a request was rejected by the limiter. name identifies
	// the API call; values is a snapshot of the limiter's state. It is
	// called without holding the limiter's internal mutex.
	ProcessedRequest(name string, values MetricsValues)
}
   856  // NewAPILimiterSet creates a new APILimiterSet based on a set of rate limiting
   857  // configurations and the default configuration. Any rate limiter that is
   858  // configured in the config OR the defaults will be configured and made
   859  // available via the Limiter(name) and Wait() function.
   860  func NewAPILimiterSet(config map[string]string, defaults map[string]APILimiterParameters, metrics MetricsObserver) (*APILimiterSet, error) {
   861  	limiters := map[string]*APILimiter{}
   863  	for name, p := range defaults {
   864  		// Merge user config into defaults when provided
   865  		if userConfig, ok := config[name]; ok {
   866  			combinedParams, err := p.MergeUserConfig(userConfig)
   867  			if err != nil {
   868  				return nil, err
   869  			}
   870  			p = combinedParams
   871  		}
   873  		limiters[name] = NewAPILimiter(name, p, metrics)
   874  	}
   876  	for name, c := range config {
   877  		if _, ok := defaults[name]; !ok {
   878  			l, err := NewAPILimiterFromConfig(name, c, metrics)
   879  			if err != nil {
   880  				return nil, fmt.Errorf("unable to parse rate limiting configuration %s=%s: %w", name, c, err)
   881  			}
   883  			limiters[name] = l
   884  		}
   885  	}
   887  	return &APILimiterSet{
   888  		limiters: limiters,
   889  		metrics:  metrics,
   890  	}, nil
   891  }
   893  // Limiter returns the APILimiter with a given name
   894  func (s *APILimiterSet) Limiter(name string) *APILimiter {
   895  	return s.limiters[name]
   896  }
   898  type dummyRequest struct{}
   900  func (d dummyRequest) WaitDuration() time.Duration { return 0 }
   901  func (d dummyRequest) Done()                       {}
   902  func (d dummyRequest) Error(err error, code int)   {}
   904  // Wait invokes Wait() on the APILimiter with the given name. If the limiter
   905  // does not exist, a dummy limiter is used which will not impose any
   906  // restrictions.
   907  func (s *APILimiterSet) Wait(ctx context.Context, name string) (LimitedRequest, error) {
   908  	l, ok := s.limiters[name]
   909  	if !ok {
   910  		return dummyRequest{}, nil
   911  	}
   913  	return l.Wait(ctx)
   914  }
   916  // parsePositiveInt parses value as an int. It returns an error if value cannot
   917  // be parsed or is negative.
   918  func parsePositiveInt(value string) (int, error) {
   919  	switch i64, err := strconv.ParseInt(value, 10, 64); {
   920  	case err != nil:
   921  		return 0, fmt.Errorf("unable to parse positive integer %q: %w", value, err)
   922  	case i64 < 0:
   923  		return 0, fmt.Errorf("unable to parse positive integer %q: negative value", value)
   924  	case i64 > math.MaxInt:
   925  		return 0, fmt.Errorf("unable to parse positive integer %q: overflow", value)
   926  	default:
   927  		return int(i64), nil
   928  	}
   929  }