github.com/go-playground/pkg/v5@v5.29.1/net/http/retrier.go (about)

     1  //go:build go1.18
     2  // +build go1.18
     3  
     4  package httpext
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"io"
    10  	"net/http"
    11  	"strconv"
    12  	"time"
    13  
    14  	bytesext "github.com/go-playground/pkg/v5/bytes"
    15  	errorsext "github.com/go-playground/pkg/v5/errors"
    16  	ioext "github.com/go-playground/pkg/v5/io"
    17  	typesext "github.com/go-playground/pkg/v5/types"
    18  	valuesext "github.com/go-playground/pkg/v5/values"
    19  	. "github.com/go-playground/pkg/v5/values/result"
    20  )
    21  
// ErrStatusCode is an error describing an HTTP response whose status code was treated as an error, together with
// whether that status code is considered retryable. It can be used to treat/indicate a status code as an error.
//
// `DoResponse` and `Do` produce it by value (not as a pointer) when a response's status code is not one of the
// expected codes, so match it with `errors.As` into an `ErrStatusCode` variable (not `*ErrStatusCode`).
type ErrStatusCode struct {
	// StatusCode is the HTTP response status code that was encountered.
	StatusCode int

	// IsRetryableStatusCode indicates if the status code is considered retryable; `DoResponse` and `Do` fill it
	// from the `Retryer`'s configured IsRetryableStatusCodeFn.
	IsRetryableStatusCode bool

	// Headers contains the headers from the HTTP response (the default backoff in `NewRetryer` inspects these for
	// a `Retry-After` value on 429/503 responses).
	Headers http.Header

	// Body is the optional body of the HTTP response. When produced by `DoResponse`/`Do` it is read through a
	// reader limited to the `Retryer`'s MaxBytes and read errors are ignored, so it may be truncated or empty.
	Body []byte
}
    36  
    37  // Error returns the error message for the status code.
    38  func (e ErrStatusCode) Error() string {
    39  	return "status code encountered: " + strconv.Itoa(e.StatusCode)
    40  }
    41  
    42  // IsRetryable returns if the provided status code is considered retryable.
    43  func (e ErrStatusCode) IsRetryable() bool {
    44  	return e.IsRetryableStatusCode
    45  }
    46  
    47  // BuildRequestFn2 is a function used to rebuild an HTTP request for use in retryable code.
    48  type BuildRequestFn2 func(ctx context.Context) Result[*http.Request, error]
    49  
    50  // DecodeAnyFn is a function used to decode the response body into the desired type.
    51  type DecodeAnyFn func(ctx context.Context, resp *http.Response, maxMemory bytesext.Bytes, v any) error
    52  
    53  // IsRetryableStatusCodeFn2 is a function used to determine if the provided status code is considered retryable.
    54  type IsRetryableStatusCodeFn2 func(ctx context.Context, code int) bool
    55  
// Retryer is used to retry HTTP requests (see `DoResponse` and `Do`), or more generally any fallible request
// operation, according to its configuration.
//
// The `Retryer` is designed to be stateless and reusable. Every configuration method has a value receiver and
// returns a modified copy, so a base `Retryer` can be changed for one-off requests, eg. changing max attempts,
// resulting in a new `Retryer` for that request while the base stays untouched.
//
// Construct it with `NewRetryer`: the zero value has a nil client and nil callback functions, which panic when a
// request is executed.
type Retryer struct {
	// isRetryableFn decides whether an attempt's error is retried; handed to the errorsext retryer.
	isRetryableFn           errorsext.IsRetryableFn2[error]
	// isRetryableStatusCodeFn classifies unexpected status codes (ErrStatusCode.IsRetryableStatusCode).
	isRetryableStatusCodeFn IsRetryableStatusCodeFn2
	// isEarlyReturnFn is handed to the errorsext retryer's IsEarlyReturnFn.
	isEarlyReturnFn         errorsext.EarlyReturnFn[error]
	// decodeFn decodes accepted responses in Do.
	decodeFn                DecodeAnyFn
	// backoffFn is handed to the errorsext retryer as its backoff function.
	backoffFn               errorsext.BackoffFn[error]
	// client executes every request.
	client                  *http.Client
	// timeout is the per-attempt timeout; 0 disables it.
	timeout                 time.Duration
	// maxBytes caps how much of a response body is read when capturing, decoding or draining it.
	maxBytes                bytesext.Bytes
	// mode and maxAttempts are handed to the errorsext retryer's MaxAttempts.
	mode                    errorsext.MaxAttemptsMode
	maxAttempts             uint8
}
    72  
    73  // NewRetryer returns a new `Retryer` with sane default values.
    74  //
    75  // The default values are:
    76  //   - `IsRetryableFn` uses the existing `errorsext.IsRetryableHTTP` function.
    77  //   - `MaxAttemptsMode` is `MaxAttemptsNonRetryableReset`.
    78  //   - `MaxAttempts` is 5.
    79  //   - `BackoffFn` will sleep for 200ms or is successful `Retry-After` header can be parsed. It's recommended to use
    80  //     exponential backoff for production with a quick copy-paste-modify of the default function
    81  //   - `Timeout` is 0.
    82  //   - `IsRetryableStatusCodeFn` is set to the existing `IsRetryableStatusCode` function.
    83  //   - `IsEarlyReturnFn` is set to check if the error is an `ErrStatusCode` and if the status code is non-retryable.
    84  //   - `Client` is set to `http.DefaultClient`.
    85  //   - `MaxBytes` is set to 2MiB.
    86  //   - `DecodeAnyFn` is set to the existing `DecodeResponseAny` function that supports JSON and XML.
    87  //
    88  // WARNING: The default functions may receive enhancements or fixes in the future which could change their behavior,
    89  // however every attempt will be made to maintain backwards compatibility or made additive-only if possible.
    90  func NewRetryer() Retryer {
    91  	return Retryer{
    92  		client:      http.DefaultClient,
    93  		maxBytes:    2 * bytesext.MiB,
    94  		mode:        errorsext.MaxAttemptsNonRetryableReset,
    95  		maxAttempts: 5,
    96  		isRetryableFn: func(ctx context.Context, err error) (isRetryable bool) {
    97  			_, isRetryable = errorsext.IsRetryableHTTP(err)
    98  			return
    99  		},
   100  		isRetryableStatusCodeFn: func(_ context.Context, code int) bool { return IsRetryableStatusCode(code) },
   101  		isEarlyReturnFn: func(_ context.Context, err error) bool {
   102  			var sce ErrStatusCode
   103  			if errors.As(err, &sce) {
   104  				return IsNonRetryableStatusCode(sce.StatusCode)
   105  			}
   106  			return false
   107  		},
   108  		decodeFn: func(ctx context.Context, resp *http.Response, maxMemory bytesext.Bytes, v any) error {
   109  			err := DecodeResponseAny(resp, maxMemory, v)
   110  			if err != nil {
   111  				return err
   112  			}
   113  			return nil
   114  		},
   115  		backoffFn: func(ctx context.Context, attempt int, err error) {
   116  
   117  			wait := time.Millisecond * 200
   118  
   119  			var sce ErrStatusCode
   120  			if errors.As(err, &sce) {
   121  				if sce.Headers != nil && (sce.StatusCode == http.StatusTooManyRequests || sce.StatusCode == http.StatusServiceUnavailable) {
   122  					if ra := HasRetryAfter(sce.Headers); ra.IsSome() {
   123  						wait = ra.Unwrap()
   124  					}
   125  				}
   126  			}
   127  
   128  			t := time.NewTimer(wait)
   129  			defer t.Stop()
   130  			select {
   131  			case <-ctx.Done():
   132  			case <-t.C:
   133  			}
   134  		},
   135  	}
   136  }
   137  
   138  // Client sets the `http.Client` for the `Retryer`.
   139  func (r Retryer) Client(client *http.Client) Retryer {
   140  	r.client = client
   141  	return r
   142  }
   143  
   144  // IsRetryableFn sets the `IsRetryableFn` for the `Retryer`.
   145  func (r Retryer) IsRetryableFn(fn errorsext.IsRetryableFn2[error]) Retryer {
   146  	r.isRetryableFn = fn
   147  	return r
   148  }
   149  
   150  // IsRetryableStatusCodeFn is called to determine if the status code is retryable.
   151  func (r Retryer) IsRetryableStatusCodeFn(fn IsRetryableStatusCodeFn2) Retryer {
   152  	if fn == nil {
   153  		fn = func(_ context.Context, _ int) bool { return false }
   154  	}
   155  	r.isRetryableStatusCodeFn = fn
   156  	return r
   157  }
   158  
   159  // IsEarlyReturnFn sets the `EarlyReturnFn` for the `Retryer`.
   160  func (r Retryer) IsEarlyReturnFn(fn errorsext.EarlyReturnFn[error]) Retryer {
   161  	r.isEarlyReturnFn = fn
   162  	return r
   163  }
   164  
   165  // DecodeFn sets the decode function for the `Retryer`.
   166  func (r Retryer) DecodeFn(fn DecodeAnyFn) Retryer {
   167  	if fn == nil {
   168  		fn = func(_ context.Context, _ *http.Response, _ bytesext.Bytes, _ any) error { return nil }
   169  	}
   170  	r.decodeFn = fn
   171  	return r
   172  }
   173  
   174  // MaxAttempts sets the maximum number of attempts for the `Retryer`.
   175  //
   176  // NOTE: Max attempts is optional and if not set will retry indefinitely on retryable errors.
   177  func (r Retryer) MaxAttempts(mode errorsext.MaxAttemptsMode, maxAttempts uint8) Retryer {
   178  	r.mode, r.maxAttempts = mode, maxAttempts
   179  	return r
   180  }
   181  
   182  // Backoff sets the backoff function for the `Retryer`.
   183  func (r Retryer) Backoff(fn errorsext.BackoffFn[error]) Retryer {
   184  	r.backoffFn = fn
   185  	return r
   186  }
   187  
   188  // MaxBytes sets the maximum memory to use when decoding the response body including:
   189  // - upon unexpected status codes.
   190  // - when decoding the response body.
   191  // - when draining the response body before closing allowing connection re-use.
   192  func (r Retryer) MaxBytes(i bytesext.Bytes) Retryer {
   193  	r.maxBytes = i
   194  	return r
   195  
   196  }
   197  
   198  // Timeout sets the timeout for the `Retryer`. This is the timeout per `RetyableFn` attempt and not the entirety
   199  // of the `Retryer` execution.
   200  //
   201  // A timeout of 0 will disable the timeout and is the default.
   202  func (r Retryer) Timeout(timeout time.Duration) Retryer {
   203  	r.timeout = timeout
   204  	return r
   205  }
   206  
   207  // DoResponse will execute the provided functions code and automatically retry before returning the *http.Response
   208  // based on HTTP status code, if defined, and can be used when processing of the response body may not be necessary
   209  // or something custom is required.
   210  //
   211  // NOTE: it is up to the caller to close the response body if a successful request is made.
   212  func (r Retryer) DoResponse(ctx context.Context, fn BuildRequestFn2, expectedResponseCodes ...int) Result[*http.Response, error] {
   213  	return errorsext.NewRetryer[*http.Response, error]().
   214  		IsRetryableFn(r.isRetryableFn).
   215  		MaxAttempts(r.mode, r.maxAttempts).
   216  		Backoff(r.backoffFn).
   217  		Timeout(r.timeout).
   218  		IsEarlyReturnFn(r.isEarlyReturnFn).
   219  		Do(ctx, func(ctx context.Context) Result[*http.Response, error] {
   220  			req := fn(ctx)
   221  			if req.IsErr() {
   222  				return Err[*http.Response, error](req.Err())
   223  			}
   224  
   225  			resp, err := r.client.Do(req.Unwrap())
   226  			if err != nil {
   227  				return Err[*http.Response, error](err)
   228  			}
   229  
   230  			if len(expectedResponseCodes) > 0 {
   231  				for _, code := range expectedResponseCodes {
   232  					if resp.StatusCode == code {
   233  						goto RETURN
   234  					}
   235  				}
   236  				b, _ := io.ReadAll(ioext.LimitReader(resp.Body, r.maxBytes))
   237  				_ = resp.Body.Close()
   238  				return Err[*http.Response, error](ErrStatusCode{
   239  					StatusCode:            resp.StatusCode,
   240  					IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode),
   241  					Headers:               resp.Header,
   242  					Body:                  b,
   243  				})
   244  			}
   245  
   246  		RETURN:
   247  			return Ok[*http.Response, error](resp)
   248  		})
   249  }
   250  
   251  // Do will execute the provided functions code and automatically retry using the provided retry function decoding
   252  // the response body into the desired type `v`, which must be passed as mutable.
   253  func (r Retryer) Do(ctx context.Context, fn BuildRequestFn2, v any, expectedResponseCodes ...int) error {
   254  	result := errorsext.NewRetryer[typesext.Nothing, error]().
   255  		IsRetryableFn(r.isRetryableFn).
   256  		MaxAttempts(r.mode, r.maxAttempts).
   257  		Backoff(r.backoffFn).
   258  		Timeout(r.timeout).
   259  		IsEarlyReturnFn(r.isEarlyReturnFn).
   260  		Do(ctx, func(ctx context.Context) Result[typesext.Nothing, error] {
   261  			req := fn(ctx)
   262  			if req.IsErr() {
   263  				return Err[typesext.Nothing, error](req.Err())
   264  			}
   265  
   266  			resp, err := r.client.Do(req.Unwrap())
   267  			if err != nil {
   268  				return Err[typesext.Nothing, error](err)
   269  			}
   270  			defer func() {
   271  				_, _ = io.Copy(io.Discard, ioext.LimitReader(resp.Body, r.maxBytes))
   272  				_ = resp.Body.Close()
   273  			}()
   274  
   275  			if len(expectedResponseCodes) > 0 {
   276  				for _, code := range expectedResponseCodes {
   277  					if resp.StatusCode == code {
   278  						goto DECODE
   279  					}
   280  				}
   281  
   282  				b, _ := io.ReadAll(ioext.LimitReader(resp.Body, r.maxBytes))
   283  				return Err[typesext.Nothing, error](ErrStatusCode{
   284  					StatusCode:            resp.StatusCode,
   285  					IsRetryableStatusCode: r.isRetryableStatusCodeFn(ctx, resp.StatusCode),
   286  					Headers:               resp.Header,
   287  					Body:                  b,
   288  				})
   289  			}
   290  
   291  		DECODE:
   292  			if err = r.decodeFn(ctx, resp, r.maxBytes, v); err != nil {
   293  				return Err[typesext.Nothing, error](err)
   294  			}
   295  			return Ok[typesext.Nothing, error](valuesext.Nothing)
   296  		})
   297  	if result.IsErr() {
   298  		return result.Err()
   299  	}
   300  	return nil
   301  }