git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/retry/examples/custom_retry_function_test.go (about) 1 package retry_test 2 3 import ( 4 "fmt" 5 "io/ioutil" 6 "net/http" 7 "net/http/httptest" 8 "strconv" 9 "testing" 10 "time" 11 12 "git.sr.ht/~pingoo/stdx/retry" 13 "github.com/stretchr/testify/assert" 14 ) 15 16 // RetriableError is a custom error that contains a positive duration for the next retry 17 type RetriableError struct { 18 Err error 19 RetryAfter time.Duration 20 } 21 22 // Error returns error message and a Retry-After duration 23 func (e *RetriableError) Error() string { 24 return fmt.Sprintf("%s (retry after %v)", e.Err.Error(), e.RetryAfter) 25 } 26 27 var _ error = (*RetriableError)(nil) 28 29 // TestCustomRetryFunction shows how to use a custom retry function 30 func TestCustomRetryFunction(t *testing.T) { 31 attempts := 5 // server succeeds after 5 attempts 32 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 33 if attempts > 0 { 34 // inform the client to retry after one second using standard 35 // HTTP 429 status code with Retry-After header in seconds 36 w.Header().Add("Retry-After", "1") 37 w.WriteHeader(http.StatusTooManyRequests) 38 w.Write([]byte("Server limit reached")) 39 attempts-- 40 return 41 } 42 w.WriteHeader(http.StatusOK) 43 w.Write([]byte("hello")) 44 })) 45 defer ts.Close() 46 47 var body []byte 48 49 err := retry.Do( 50 func() error { 51 resp, err := http.Get(ts.URL) 52 53 if err == nil { 54 defer func() { 55 if err := resp.Body.Close(); err != nil { 56 panic(err) 57 } 58 }() 59 body, err = ioutil.ReadAll(resp.Body) 60 if resp.StatusCode != 200 { 61 err = fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) 62 if resp.StatusCode == http.StatusTooManyRequests { 63 // check Retry-After header if it contains seconds to wait for the next retry 64 if retryAfter, e := strconv.ParseInt(resp.Header.Get("Retry-After"), 10, 32); e == nil { 65 // the server returns 0 to inform that the operation cannot be retried 66 if retryAfter <= 0 { 67 return retry.Unrecoverable(err) 68 } 69 return &RetriableError{ 70 Err: err, 71 RetryAfter: time.Duration(retryAfter) * time.Second, 72 } 73 } 74 // A real implementation should also try to http.Parse the retryAfter response header 75 // to conform with HTTP specification. Herein we know here that we return only seconds. 76 } 77 } 78 } 79 80 return err 81 }, 82 retry.DelayType(func(n uint, err error, config *retry.Config) time.Duration { 83 fmt.Println("Server fails with: " + err.Error()) 84 if retriable, ok := err.(*RetriableError); ok { 85 fmt.Printf("Client follows server recommendation to retry after %v\n", retriable.RetryAfter) 86 return retriable.RetryAfter 87 } 88 // apply a default exponential back off strategy 89 return retry.BackOffDelay(n, err, config) 90 }), 91 ) 92 93 fmt.Println("Server responds with: " + string(body)) 94 95 assert.NoError(t, err) 96 assert.Equal(t, "hello", string(body)) 97 }