git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/retry/examples/custom_retry_function_test.go (about)

     1  package retry_test
     2  
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"strconv"
     9  	"testing"
    10  	"time"
    11  
    12  	"git.sr.ht/~pingoo/stdx/retry"
    13  	"github.com/stretchr/testify/assert"
    14  )
    15  
    16  // RetriableError is a custom error that contains a positive duration for the next retry
    17  type RetriableError struct {
    18  	Err        error
    19  	RetryAfter time.Duration
    20  }
    21  
    22  // Error returns error message and a Retry-After duration
    23  func (e *RetriableError) Error() string {
    24  	return fmt.Sprintf("%s (retry after %v)", e.Err.Error(), e.RetryAfter)
    25  }
    26  
    27  var _ error = (*RetriableError)(nil)
    28  
    29  // TestCustomRetryFunction shows how to use a custom retry function
    30  func TestCustomRetryFunction(t *testing.T) {
    31  	attempts := 5 // server succeeds after 5 attempts
    32  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    33  		if attempts > 0 {
    34  			// inform the client to retry after one second using standard
    35  			// HTTP 429 status code with Retry-After header in seconds
    36  			w.Header().Add("Retry-After", "1")
    37  			w.WriteHeader(http.StatusTooManyRequests)
    38  			w.Write([]byte("Server limit reached"))
    39  			attempts--
    40  			return
    41  		}
    42  		w.WriteHeader(http.StatusOK)
    43  		w.Write([]byte("hello"))
    44  	}))
    45  	defer ts.Close()
    46  
    47  	var body []byte
    48  
    49  	err := retry.Do(
    50  		func() error {
    51  			resp, err := http.Get(ts.URL)
    52  
    53  			if err == nil {
    54  				defer func() {
    55  					if err := resp.Body.Close(); err != nil {
    56  						panic(err)
    57  					}
    58  				}()
    59  				body, err = ioutil.ReadAll(resp.Body)
    60  				if resp.StatusCode != 200 {
    61  					err = fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
    62  					if resp.StatusCode == http.StatusTooManyRequests {
    63  						// check Retry-After header if it contains seconds to wait for the next retry
    64  						if retryAfter, e := strconv.ParseInt(resp.Header.Get("Retry-After"), 10, 32); e == nil {
    65  							// the server returns 0 to inform that the operation cannot be retried
    66  							if retryAfter <= 0 {
    67  								return retry.Unrecoverable(err)
    68  							}
    69  							return &RetriableError{
    70  								Err:        err,
    71  								RetryAfter: time.Duration(retryAfter) * time.Second,
    72  							}
    73  						}
    74  						// A real implementation should also try to http.Parse the retryAfter response header
    75  						// to conform with HTTP specification. Herein we know here that we return only seconds.
    76  					}
    77  				}
    78  			}
    79  
    80  			return err
    81  		},
    82  		retry.DelayType(func(n uint, err error, config *retry.Config) time.Duration {
    83  			fmt.Println("Server fails with: " + err.Error())
    84  			if retriable, ok := err.(*RetriableError); ok {
    85  				fmt.Printf("Client follows server recommendation to retry after %v\n", retriable.RetryAfter)
    86  				return retriable.RetryAfter
    87  			}
    88  			// apply a default exponential back off strategy
    89  			return retry.BackOffDelay(n, err, config)
    90  		}),
    91  	)
    92  
    93  	fmt.Println("Server responds with: " + string(body))
    94  
    95  	assert.NoError(t, err)
    96  	assert.Equal(t, "hello", string(body))
    97  }