github.com/go-board/x-go@v0.1.2-0.20220610024734-db1323f6cb15/xnet/xhttp/xrequest/interceptor.go (about)

     1  package xrequest
     2  
     3  import (
     4  	"compress/gzip"
     5  	"log"
     6  	"math/rand"
     7  	"net/http"
     8  	"strconv"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/go-board/x-go/xnet/xhttp"
    13  	"github.com/go-board/x-go/xslice"
    14  )
    15  
    16  // RoundTripperFn implement http.RoundTripper for convenient usage.
    17  type RoundTripperFn func(request *http.Request) (*http.Response, error)
    18  
    19  func (fn RoundTripperFn) RoundTrip(request *http.Request) (*http.Response, error) { return fn(request) }
    20  
    21  // Interceptor is interceptor that can do more work before/after an request
    22  type Interceptor interface {
    23  	Next(fn http.RoundTripper) http.RoundTripper
    24  }
    25  
    26  // InterceptorFn implement Interceptor for convenient usage.
    27  type InterceptorFn func(rt http.RoundTripper) http.RoundTripper
    28  
    29  func (fn InterceptorFn) Next(rt http.RoundTripper) http.RoundTripper { return fn(rt) }
    30  
    31  // ComposeInterceptor compose interceptors to given http.RoundTripper
    32  func ComposeInterceptor(rt http.RoundTripper, interceptors ...Interceptor) http.RoundTripper {
    33  	if len(interceptors) == 0 {
    34  		return rt
    35  	}
    36  	return ComposeInterceptor(interceptors[0].Next(rt), interceptors[1:]...)
    37  }
    38  
    39  // InjectHeader inject given header into request.
    40  func InjectHeader(h http.Header) InterceptorFn {
    41  	return func(rt http.RoundTripper) http.RoundTripper {
    42  		return RoundTripperFn(func(req *http.Request) (*http.Response, error) {
    43  			for k, v := range h {
    44  				for _, vv := range v {
    45  					req.Header.Add(k, vv)
    46  				}
    47  			}
    48  			return rt.RoundTrip(req)
    49  		})
    50  	}
    51  }
    52  
    53  // Logging is Interceptor that log http request stats
    54  func Logging(rt http.RoundTripper) http.RoundTripper {
    55  	return RoundTripperFn(func(request *http.Request) (*http.Response, error) {
    56  		before := time.Now()
    57  		response, err := rt.RoundTrip(request)
    58  		if err != nil {
    59  			log.Printf("%s %s, latency: %s, status: %s\n", request.Method, request.URL.Path, time.Since(before), err)
    60  		} else {
    61  			log.Printf("%s %s, latency: %s, status: %s\n", request.Method, request.URL.Path, time.Since(before), response.Status)
    62  		}
    63  		return response, err
    64  	})
    65  }
    66  
    67  // RetryOnStatusCode retry on return codes...
    68  func RetryOnStatusCode(codes ...int) InterceptorFn {
    69  	return func(rt http.RoundTripper) http.RoundTripper {
    70  		return RoundTripperFn(func(request *http.Request) (response *http.Response, err error) {
    71  			for i := 0; i < 3; i++ {
    72  				response, err = rt.RoundTrip(request)
    73  				if err != nil || (response != nil && xslice.ContainsInt(codes, response.StatusCode)) {
    74  					continue
    75  				}
    76  				return
    77  			}
    78  			return
    79  		})
    80  	}
    81  }
    82  
    83  // RetryStrategy is strategy for http request
    84  type RetryStrategy struct {
    85  	Backoff     func(r *http.Request, i int) time.Duration
    86  	MaxRetries  func(r *http.Request) int
    87  	ShouldRetry func(r *http.Request, resp *http.Response, err error) bool
    88  }
    89  
    90  // RetryWithStrategy retry with given strategy.
    91  func RetryWithStrategy(strategy RetryStrategy) InterceptorFn {
    92  	if strategy.Backoff == nil {
    93  		strategy.Backoff = func(r *http.Request, i int) time.Duration { return 0 }
    94  	}
    95  	if strategy.MaxRetries == nil {
    96  		strategy.MaxRetries = func(r *http.Request) int {
    97  			if retryStr := r.Header.Get("X-Max-Retries"); retryStr != "" {
    98  				if retries, err := strconv.ParseInt(retryStr, 10, 64); err == nil {
    99  					return int(retries)
   100  				}
   101  			}
   102  			return 3
   103  		}
   104  	}
   105  	if strategy.ShouldRetry == nil {
   106  		strategy.ShouldRetry = func(r *http.Request, resp *http.Response, err error) bool {
   107  			return err != nil || xslice.ContainsInt([]int{500}, resp.StatusCode)
   108  		}
   109  	}
   110  	return func(rt http.RoundTripper) http.RoundTripper {
   111  		return RoundTripperFn(func(request *http.Request) (response *http.Response, err error) {
   112  			maxRetries := strategy.MaxRetries(request)
   113  			for i := 0; i < maxRetries; i++ {
   114  				response, err = rt.RoundTrip(request)
   115  				if strategy.ShouldRetry(request, response, err) {
   116  					continue
   117  				}
   118  				return
   119  			}
   120  			return
   121  		})
   122  	}
   123  }
   124  
   125  // RoundRobinProxy proxy request with round robin strategy to different server.
   126  func RoundRobinProxy(hosts ...string) InterceptorFn {
   127  	if len(hosts) == 0 {
   128  		panic("empty hosts list")
   129  	}
   130  	var term uint64 = 0
   131  	return func(rt http.RoundTripper) http.RoundTripper {
   132  		return RoundTripperFn(func(request *http.Request) (*http.Response, error) {
   133  			host := hosts[atomic.AddUint64(&term, 1)%uint64(len(hosts))]
   134  			request.Host = host
   135  			return rt.RoundTrip(request)
   136  		})
   137  	}
   138  }
   139  
   140  // RandomProxy proxy request with random strategy to different server.
   141  func RandomProxy(hosts ...string) InterceptorFn {
   142  	if len(hosts) == 0 {
   143  		panic("empty hosts list")
   144  	}
   145  	return func(rt http.RoundTripper) http.RoundTripper {
   146  		return RoundTripperFn(func(request *http.Request) (*http.Response, error) {
   147  			request.Host = hosts[rand.Intn(len(hosts))]
   148  			return rt.RoundTrip(request)
   149  		})
   150  	}
   151  }
   152  
   153  // GzipDecompressResponse decompress response body if possible.
   154  func GzipDecompressResponse() InterceptorFn {
   155  	return func(rt http.RoundTripper) http.RoundTripper {
   156  		return RoundTripperFn(func(request *http.Request) (*http.Response, error) {
   157  			response, err := rt.RoundTrip(request)
   158  			if err != nil {
   159  				return nil, err
   160  			}
   161  			if response.Header.Get(xhttp.HeaderContentEncoding) == "gzip" {
   162  				r, err := gzip.NewReader(response.Body)
   163  				if err != nil {
   164  					return nil, err
   165  				}
   166  				response.Body = r
   167  			}
   168  			return response, err
   169  		})
   170  	}
   171  }