github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/retryutils/retry.go (about)

     1  /*
     2  Copyright 2019-2022 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package retryutils defines common retry and jitter logic.
    18  package retryutils
    19  
    20  import (
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"log/slog"
    25  	"time"
    26  
    27  	"github.com/gravitational/trace"
    28  	"github.com/jonboulle/clockwork"
    29  )
    30  
// Retry is an interface that provides retry logic.
type Retry interface {
	// Reset resets the retry state.
	Reset()
	// Inc increments the retry attempt.
	Inc()
	// Duration returns the retry duration,
	// which could be 0.
	Duration() time.Duration
	// After returns a time.Time channel
	// that fires after the Duration delay;
	// it could fire right away if Duration is 0.
	After() <-chan time.Time
	// Clone creates a copy of this retry in a
	// reset state.
	Clone() Retry
}
    48  
// LinearConfig sets up retry configuration
// using an arithmetic progression.
type LinearConfig struct {
	// First is the first element of the progression,
	// could be 0.
	First time.Duration
	// Step is the step of the progression, can't be 0.
	Step time.Duration
	// Max is the maximum value of the progression,
	// can't be 0.
	Max time.Duration
	// Jitter is an optional jitter function to be applied
	// to the delay.  Note that supplying a jitter means that
	// successive calls to Duration may return different results.
	Jitter Jitter `json:"-"`
	// AutoReset, if greater than zero, causes the linear retry to automatically
	// reset once more than Max * AutoReset has elapsed since the previous call
	// to Duration. The idle timer is refreshed by Duration itself, not by Inc.
	AutoReset int64
	// Clock to override clock in tests. CheckAndSetDefaults installs a real
	// clock when it is nil.
	Clock clockwork.Clock
}
    70  
    71  // CheckAndSetDefaults checks and sets defaults
    72  func (c *LinearConfig) CheckAndSetDefaults() error {
    73  	if c.Step == 0 {
    74  		return trace.BadParameter("missing parameter Step")
    75  	}
    76  	if c.Max == 0 {
    77  		return trace.BadParameter("missing parameter Max")
    78  	}
    79  	if c.Clock == nil {
    80  		c.Clock = clockwork.NewRealClock()
    81  	}
    82  	return nil
    83  }
    84  
    85  // NewLinear returns a new instance of linear retry
    86  func NewLinear(cfg LinearConfig) (*Linear, error) {
    87  	if err := cfg.CheckAndSetDefaults(); err != nil {
    88  		return nil, trace.Wrap(err)
    89  	}
    90  	return newLinear(cfg), nil
    91  }
    92  
    93  // newLinear creates an instance of Linear from a
    94  // previously verified configuration.
    95  func newLinear(cfg LinearConfig) *Linear {
    96  	closedChan := make(chan time.Time)
    97  	close(closedChan)
    98  	return &Linear{LinearConfig: cfg, closedChan: closedChan}
    99  }
   100  
   101  // NewConstant returns a new linear retry with constant interval.
   102  func NewConstant(interval time.Duration) (*Linear, error) {
   103  	return NewLinear(LinearConfig{Step: interval, Max: interval})
   104  }
   105  
// Linear is used to calculate a retry period that follows an arithmetic
// progression: the delay is First + attempt*Step, reported as 0 when that
// value is below 1 and capped at Max, with the optional Jitter applied to
// the capped value.
//
// Unexported state:
//   - lastUse is the clock time recorded by the most recent Duration call
//     when AutoReset is enabled.
//   - attempt is the number of Inc calls since the last Reset.
//   - closedChan is an already-closed channel that After returns when the
//     delay is zero.
type Linear struct {
	// LinearConfig is a linear retry config
	LinearConfig
	lastUse    time.Time
	attempt    int64
	closedChan chan time.Time
}
   118  
// Reset resets retry period to initial state by setting the attempt counter
// back to zero. The time recorded for AutoReset is left untouched.
func (r *Linear) Reset() {
	r.attempt = 0
}
   123  
// ResetToDelay resets retry period and increments the number of attempts,
// leaving the attempt counter at one instead of zero.
func (r *Linear) ResetToDelay() {
	r.Reset()
	r.Inc()
}
   129  
// Clone creates a new Linear from a copy of this retry's LinearConfig.
// The clone starts with fresh state: its attempt counter is zero regardless
// of how many times Inc was called on the receiver.
func (r *Linear) Clone() Retry {
	return newLinear(r.LinearConfig)
}
   134  
// Inc increments the attempt counter by one. Duration multiplies this
// counter by Step when computing the next delay.
func (r *Linear) Inc() {
	r.attempt++
}
   139  
   140  // Duration returns retry duration based on state
   141  func (r *Linear) Duration() time.Duration {
   142  	if r.AutoReset > 0 {
   143  		now := r.Clock.Now()
   144  		if now.After(r.lastUse.Add(r.Max * time.Duration(r.AutoReset))) {
   145  			r.Reset()
   146  		}
   147  		r.lastUse = now
   148  	}
   149  
   150  	a := r.First + time.Duration(r.attempt)*r.Step
   151  	if a < 1 {
   152  		return 0
   153  	}
   154  
   155  	if a > r.Max {
   156  		a = r.Max
   157  	}
   158  
   159  	if r.Jitter != nil {
   160  		a = r.Jitter(a)
   161  	}
   162  
   163  	return a
   164  }
   165  
   166  // After returns channel that fires with timeout
   167  // defined in Duration method, as a special case
   168  // if Duration is 0 returns a closed channel
   169  func (r *Linear) After() <-chan time.Time {
   170  	d := r.Duration()
   171  	if d < 1 {
   172  		return r.closedChan
   173  	}
   174  	return r.Clock.After(d)
   175  }
   176  
   177  // String returns user-friendly representation of the LinearPeriod
   178  func (r *Linear) String() string {
   179  	return fmt.Sprintf("Linear(attempt=%v, duration=%v)", r.attempt, r.Duration())
   180  }
   181  
   182  // For retries the provided function until it succeeds or the context expires.
   183  func (r *Linear) For(ctx context.Context, retryFn func() error) error {
   184  	for {
   185  		err := retryFn()
   186  		if err == nil {
   187  			return nil
   188  		}
   189  		var permanentRetryError *permanentRetryError
   190  		if errors.As(trace.Unwrap(err), &permanentRetryError) {
   191  			return trace.Wrap(err)
   192  		}
   193  		slog.DebugContext(ctx, "Waiting to retry operation again", "wait", r.Duration(), "error", err)
   194  		select {
   195  		case <-r.After():
   196  			r.Inc()
   197  		case <-ctx.Done():
   198  			return trace.LimitExceeded(ctx.Err().Error())
   199  		}
   200  	}
   201  }
   202  
   203  // PermanentRetryError returns a new instance of a permanent retry error.
   204  func PermanentRetryError(err error) error {
   205  	return &permanentRetryError{err: err}
   206  }
   207  
   208  // permanentRetryError indicates that retry loop should stop.
   209  type permanentRetryError struct {
   210  	err error
   211  }
   212  
   213  // Error returns the original error message.
   214  func (e *permanentRetryError) Error() string {
   215  	return e.err.Error()
   216  }
   217  
   218  // RetryFastFor retries a function repeatedly for a set amount of
   219  // time before returning an error.
   220  //
   221  // Intended mostly for tests.
   222  func RetryStaticFor(d time.Duration, w time.Duration, f func() error) error {
   223  	start := time.Now()
   224  	var err error
   225  
   226  	for time.Since(start) < d {
   227  		if err = f(); err == nil {
   228  			break
   229  		}
   230  
   231  		time.Sleep(w)
   232  	}
   233  
   234  	return err
   235  }