github.com/Azure/aad-pod-identity@v1.8.17/pkg/retry/retry.go (about) 1 package retry 2 3 import ( 4 "strings" 5 "time" 6 ) 7 8 // Func is a function that is being retried. 9 type Func func() error 10 11 // ShouldRetryFunc is a function that consumes the last-known error 12 // from the targeted function and determine if we should run it again 13 type ShouldRetryFunc func(error) bool 14 15 // RetriableError is an error that when occurred, 16 // we should retry targeted function. 17 type RetriableError string 18 19 // ClientInt is an abstraction that retries running a 20 // function based on what type of error has occurred 21 type ClientInt interface { 22 Do(f Func, shouldRetry ShouldRetryFunc) error 23 RegisterRetriableErrors(rerrs ...RetriableError) 24 UnregisterRetriableErrors(rerrs ...RetriableError) 25 } 26 27 type client struct { 28 retriableErrors map[RetriableError]bool 29 maxRetry int 30 retryInterval time.Duration 31 } 32 33 var _ ClientInt = &client{} 34 35 // NewRetryClient returns an implementation of ClientInt that retries 36 // running a given function based on the parameters provided. 37 func NewRetryClient(maxRetry int, retryInterval time.Duration) ClientInt { 38 return &client{ 39 retriableErrors: make(map[RetriableError]bool), 40 maxRetry: maxRetry, 41 retryInterval: retryInterval, 42 } 43 } 44 45 // Do runs the targeted function f and will retry running 46 // it if it returns an error and shouldRetry returns true. 47 func (c *client) Do(f Func, shouldRetry ShouldRetryFunc) error { 48 // The original error 49 err := f() 50 if err == nil { 51 return nil 52 } 53 54 // Error occurred when retrying 55 rerr := err 56 for i := 0; i < c.maxRetry; i++ { 57 if rerr == nil || !c.isRetriable(rerr) || !shouldRetry(rerr) { 58 break 59 } 60 61 time.Sleep(c.retryInterval) 62 // We should retry if: 63 // 1) the last known error is not nil 64 // 2) the error is retriable 65 // 3) shouldRetry returns true 66 rerr = f() 67 } 68 69 // Return the original error from the first run, 70 // indicating that we retried running the function 71 return err 72 } 73 74 // RegisterRetriableErrors registers a retriable error to the retrier. 75 func (c *client) RegisterRetriableErrors(rerrs ...RetriableError) { 76 for _, rerr := range rerrs { 77 c.retriableErrors[rerr] = true 78 } 79 } 80 81 // UnregisterRetriableErrors unregisters an error from the retrier. 82 func (c *client) UnregisterRetriableErrors(rerrs ...RetriableError) { 83 for _, rerr := range rerrs { 84 delete(c.retriableErrors, rerr) 85 } 86 } 87 88 // isRetriable returns true if an error is retriable. 89 func (c *client) isRetriable(err error) bool { 90 if err == nil { 91 return false 92 } 93 94 for rerr := range c.retriableErrors { 95 if strings.Contains(err.Error(), string(rerr)) { 96 return true 97 } 98 } 99 100 return false 101 }