github.com/Azure/aad-pod-identity@v1.8.17/pkg/retry/retry.go (about)

     1  package retry
     2  
     3  import (
     4  	"strings"
     5  	"time"
     6  )
     7  
     8  // Func is a function that is being retried.
     9  type Func func() error
    10  
    11  // ShouldRetryFunc is a function that consumes the last-known error
    12  // from the targeted function and determine if we should run it again
    13  type ShouldRetryFunc func(error) bool
    14  
    15  // RetriableError is an error that when occurred,
    16  // we should retry targeted function.
    17  type RetriableError string
    18  
    19  // ClientInt is an abstraction that retries running a
    20  // function based on what type of error has occurred
    21  type ClientInt interface {
    22  	Do(f Func, shouldRetry ShouldRetryFunc) error
    23  	RegisterRetriableErrors(rerrs ...RetriableError)
    24  	UnregisterRetriableErrors(rerrs ...RetriableError)
    25  }
    26  
    27  type client struct {
    28  	retriableErrors map[RetriableError]bool
    29  	maxRetry        int
    30  	retryInterval   time.Duration
    31  }
    32  
    33  var _ ClientInt = &client{}
    34  
    35  // NewRetryClient returns an implementation of ClientInt that retries
    36  // running a given function based on the parameters provided.
    37  func NewRetryClient(maxRetry int, retryInterval time.Duration) ClientInt {
    38  	return &client{
    39  		retriableErrors: make(map[RetriableError]bool),
    40  		maxRetry:        maxRetry,
    41  		retryInterval:   retryInterval,
    42  	}
    43  }
    44  
    45  // Do runs the targeted function f and will retry running
    46  // it if it returns an error and shouldRetry returns true.
    47  func (c *client) Do(f Func, shouldRetry ShouldRetryFunc) error {
    48  	// The original error
    49  	err := f()
    50  	if err == nil {
    51  		return nil
    52  	}
    53  
    54  	// Error occurred when retrying
    55  	rerr := err
    56  	for i := 0; i < c.maxRetry; i++ {
    57  		if rerr == nil || !c.isRetriable(rerr) || !shouldRetry(rerr) {
    58  			break
    59  		}
    60  
    61  		time.Sleep(c.retryInterval)
    62  		// We should retry if:
    63  		// 1) the last known error is not nil
    64  		// 2) the error is retriable
    65  		// 3) shouldRetry returns true
    66  		rerr = f()
    67  	}
    68  
    69  	// Return the original error from the first run,
    70  	// indicating that we retried running the function
    71  	return err
    72  }
    73  
    74  // RegisterRetriableErrors registers a retriable error to the retrier.
    75  func (c *client) RegisterRetriableErrors(rerrs ...RetriableError) {
    76  	for _, rerr := range rerrs {
    77  		c.retriableErrors[rerr] = true
    78  	}
    79  }
    80  
    81  // UnregisterRetriableErrors unregisters an error from the retrier.
    82  func (c *client) UnregisterRetriableErrors(rerrs ...RetriableError) {
    83  	for _, rerr := range rerrs {
    84  		delete(c.retriableErrors, rerr)
    85  	}
    86  }
    87  
    88  // isRetriable returns true if an error is retriable.
    89  func (c *client) isRetriable(err error) bool {
    90  	if err == nil {
    91  		return false
    92  	}
    93  
    94  	for rerr := range c.retriableErrors {
    95  		if strings.Contains(err.Error(), string(rerr)) {
    96  			return true
    97  		}
    98  	}
    99  
   100  	return false
   101  }