github.com/m3db/m3@v1.5.1-0.20231129193456-75a402aa583b/src/x/retry/retry.go (about)

     1  // Copyright (c) 2016 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package retry
    22  
    23  import (
    24  	"context"
    25  	"errors"
    26  	"fmt"
    27  	"math"
    28  	"time"
    29  
    30  	xerrors "github.com/m3db/m3/src/x/errors"
    31  
    32  	"github.com/uber-go/tally"
    33  )
    34  
    35  var (
    36  	// ErrWhileConditionFalse is returned when the while condition to a while retry
    37  	// method evaluates false.
    38  	ErrWhileConditionFalse = errors.New("retry while condition evaluated to false")
    39  )
    40  
// retrier is the Retrier implementation returned by NewRetrier. The backoff
// related fields are copied out of the options once at construction time.
type retrier struct {
	// opts is the original options value, returned as-is by Options().
	opts Options
	// initialBackoff, backoffFactor and maxBackoff are forwarded to
	// BackoffNanos to compute the sleep before each retry.
	initialBackoff time.Duration
	backoffFactor  float64
	maxBackoff     time.Duration
	// maxRetries bounds the number of retries after the first attempt,
	// unless forever is set.
	maxRetries int
	// forever makes the retry loop ignore maxRetries.
	forever bool
	// jitter is forwarded to BackoffNanos to enable randomised backoff.
	jitter bool
	// rngFn is the random source forwarded to BackoffNanos.
	rngFn RngFn
	// sleepFn is called with the computed backoff before each retry;
	// NewRetrier sets it to time.Sleep.
	sleepFn func(t time.Duration)
	// metrics holds the instruments updated while attempting.
	metrics retrierMetrics
}
    53  
// retrierMetrics groups the tally instruments a retrier updates.
type retrierMetrics struct {
	// calls is incremented once per attempt() invocation.
	calls tally.Counter
	// attempts is incremented after every invocation of the retried function.
	attempts tally.Counter
	// success is incremented when the retried function returns nil.
	success tally.Counter
	// successLatency records the duration of invocations that returned nil.
	successLatency tally.Histogram
	// errors is the "errors" counter tagged type=retryable; incremented when
	// an invocation fails with an error that is not marked non-retryable.
	errors tally.Counter
	// errorsNotRetryable is the "errors" counter tagged type=not-retryable.
	errorsNotRetryable tally.Counter
	// errorsFinal is incremented when the retry loop ends without a success.
	errorsFinal tally.Counter
	// errorsLatency records the duration of invocations that returned an error.
	errorsLatency tally.Histogram
	// retries is incremented before each invocation after the first one.
	retries tally.Counter
}
    65  
    66  // NewRetrier creates a new retrier.
    67  func NewRetrier(opts Options) Retrier {
    68  	scope := opts.MetricsScope()
    69  	errorTags := struct {
    70  		retryable    map[string]string
    71  		notRetryable map[string]string
    72  	}{
    73  		map[string]string{
    74  			"type": "retryable",
    75  		},
    76  		map[string]string{
    77  			"type": "not-retryable",
    78  		},
    79  	}
    80  
    81  	return &retrier{
    82  		opts:           opts,
    83  		initialBackoff: opts.InitialBackoff(),
    84  		backoffFactor:  opts.BackoffFactor(),
    85  		maxBackoff:     opts.MaxBackoff(),
    86  		maxRetries:     opts.MaxRetries(),
    87  		forever:        opts.Forever(),
    88  		jitter:         opts.Jitter(),
    89  		rngFn:          opts.RngFn(),
    90  		sleepFn:        time.Sleep,
    91  		metrics: retrierMetrics{
    92  			calls:              scope.Counter("calls"),
    93  			attempts:           scope.Counter("attempts"),
    94  			success:            scope.Counter("success"),
    95  			successLatency:     histogramWithDurationBuckets(scope, "success-latency"),
    96  			errors:             scope.Tagged(errorTags.retryable).Counter("errors"),
    97  			errorsNotRetryable: scope.Tagged(errorTags.notRetryable).Counter("errors"),
    98  			errorsFinal:        scope.Counter("errors-final"),
    99  			errorsLatency:      histogramWithDurationBuckets(scope, "errors-latency"),
   100  			retries:            scope.Counter("retries"),
   101  		},
   102  	}
   103  }
   104  
   105  func (r *retrier) Options() Options {
   106  	return r.opts
   107  }
   108  
   109  func (r *retrier) Attempt(fn Fn) error {
   110  	return r.attempt(nil, fn)
   111  }
   112  
   113  func (r *retrier) AttemptWhile(continueFn ContinueFn, fn Fn) error {
   114  	return r.attempt(continueFn, fn)
   115  }
   116  
   117  func (r *retrier) AttemptContext(ctx context.Context, fn Fn) error {
   118  	contextNotCancelled := func(attempt int) bool {
   119  		select {
   120  		case <-ctx.Done():
   121  			return false
   122  		default:
   123  			return true
   124  		}
   125  	}
   126  	err := r.attempt(contextNotCancelled, fn)
   127  	if err != nil {
   128  		if errors.Is(err, ErrWhileConditionFalse) {
   129  			return fmt.Errorf("context canceled while retrying: %w", ctx.Err())
   130  		}
   131  		return err
   132  	}
   133  	return nil
   134  }
   135  
   136  func (r *retrier) attempt(continueFn ContinueFn, fn Fn) error {
   137  	// Always track a call, useful for counting number of total operations.
   138  	r.metrics.calls.Inc(1)
   139  
   140  	attempt := 0
   141  
   142  	if continueFn != nil && !continueFn(attempt) {
   143  		return ErrWhileConditionFalse
   144  	}
   145  
   146  	start := time.Now()
   147  	err := fn()
   148  	duration := time.Since(start)
   149  	r.metrics.attempts.Inc(1)
   150  	attempt++
   151  	if err == nil {
   152  		r.metrics.successLatency.RecordDuration(duration)
   153  		r.metrics.success.Inc(1)
   154  		return nil
   155  	}
   156  	r.metrics.errorsLatency.RecordDuration(duration)
   157  	if xerrors.IsNonRetryableError(err) {
   158  		r.metrics.errorsNotRetryable.Inc(1)
   159  		return err
   160  	}
   161  	r.metrics.errors.Inc(1)
   162  
   163  	for i := 1; r.forever || i <= r.maxRetries; i++ {
   164  		r.sleepFn(time.Duration(BackoffNanos(
   165  			i,
   166  			r.jitter,
   167  			r.backoffFactor,
   168  			r.initialBackoff,
   169  			r.maxBackoff,
   170  			r.rngFn,
   171  		)))
   172  
   173  		if continueFn != nil && !continueFn(attempt) {
   174  			return ErrWhileConditionFalse
   175  		}
   176  
   177  		r.metrics.retries.Inc(1)
   178  		start := time.Now()
   179  		err = fn()
   180  		duration := time.Since(start)
   181  		r.metrics.attempts.Inc(1)
   182  		attempt++
   183  		if err == nil {
   184  			r.metrics.successLatency.RecordDuration(duration)
   185  			r.metrics.success.Inc(1)
   186  			return nil
   187  		}
   188  		r.metrics.errorsLatency.RecordDuration(duration)
   189  		if xerrors.IsNonRetryableError(err) {
   190  			r.metrics.errorsNotRetryable.Inc(1)
   191  			return err
   192  		}
   193  		r.metrics.errors.Inc(1)
   194  	}
   195  	r.metrics.errorsFinal.Inc(1)
   196  
   197  	return err
   198  }
   199  
   200  // BackoffNanos calculates the backoff for a retry in nanoseconds.
   201  func BackoffNanos(
   202  	retry int,
   203  	jitter bool,
   204  	backoffFactor float64,
   205  	initialBackoff time.Duration,
   206  	maxBackoff time.Duration,
   207  	rngFn RngFn,
   208  ) int64 {
   209  	backoff := initialBackoff.Nanoseconds()
   210  	if retry >= 1 {
   211  		backoffFloat64 := float64(backoff) * math.Pow(backoffFactor, float64(retry-1))
   212  		// math.Inf is also larger than math.MaxInt64.
   213  		if backoffFloat64 > math.MaxInt64 {
   214  			return maxBackoff.Nanoseconds()
   215  		}
   216  		backoff = int64(backoffFloat64)
   217  	}
   218  	// Validate the value of backoff to make sure Int63n() does not panic.
   219  	if jitter && backoff >= 2 {
   220  		half := backoff / 2
   221  		backoff = half + rngFn(half)
   222  	}
   223  	if maxBackoff := maxBackoff.Nanoseconds(); backoff > maxBackoff {
   224  		backoff = maxBackoff
   225  	}
   226  	return backoff
   227  }
   228  
   229  // histogramWithDurationBuckets returns a histogram with the standard duration buckets.
   230  func histogramWithDurationBuckets(scope tally.Scope, name string) tally.Histogram {
   231  	sub := scope.Tagged(map[string]string{
   232  		// Bump the version if the histogram buckets need to be changed to avoid overlapping buckets
   233  		// in the same query causing errors.
   234  		"schema": "v1",
   235  	})
   236  	buckets := append(tally.DurationBuckets{0, time.Millisecond},
   237  		tally.MustMakeExponentialDurationBuckets(2*time.Millisecond, 1.5, 30)...)
   238  	return sub.Histogram(name, buckets)
   239  }