github.com/uber/kraken@v0.1.4/utils/httputil/httputil.go (about)

     1  // Copyright (c) 2016-2019 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package httputil
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"io/ioutil"
    23  	"net/http"
    24  	"net/url"
    25  	"time"
    26  
    27  	"github.com/cenkalti/backoff"
    28  	"github.com/pressly/chi"
    29  
    30  	"github.com/uber/kraken/core"
    31  	"github.com/uber/kraken/utils/handler"
    32  )
    33  
// RoundTripper is an alias of http.RoundTripper, declared here for mocking
// purposes. Because it is a type alias (not a new type), any
// http.RoundTripper value is a RoundTripper and vice versa.
type RoundTripper = http.RoundTripper
    36  
    37  // StatusError occurs if an HTTP response has an unexpected status code.
    38  type StatusError struct {
    39  	Method       string
    40  	URL          string
    41  	Status       int
    42  	Header       http.Header
    43  	ResponseDump string
    44  }
    45  
    46  // NewStatusError returns a new StatusError.
    47  func NewStatusError(resp *http.Response) StatusError {
    48  	defer resp.Body.Close()
    49  	respBytes, err := ioutil.ReadAll(resp.Body)
    50  	respDump := string(respBytes)
    51  	if err != nil {
    52  		respDump = fmt.Sprintf("failed to dump response: %s", err)
    53  	}
    54  	return StatusError{
    55  		Method:       resp.Request.Method,
    56  		URL:          resp.Request.URL.String(),
    57  		Status:       resp.StatusCode,
    58  		Header:       resp.Header,
    59  		ResponseDump: respDump,
    60  	}
    61  }
    62  
    63  func (e StatusError) Error() string {
    64  	if e.ResponseDump == "" {
    65  		return fmt.Sprintf("%s %s %d", e.Method, e.URL, e.Status)
    66  	}
    67  	return fmt.Sprintf("%s %s %d: %s", e.Method, e.URL, e.Status, e.ResponseDump)
    68  }
    69  
    70  // IsStatus returns true if err is a StatusError of the given status.
    71  func IsStatus(err error, status int) bool {
    72  	statusErr, ok := err.(StatusError)
    73  	return ok && statusErr.Status == status
    74  }
    75  
    76  // IsCreated returns true if err is a "created", 201
    77  func IsCreated(err error) bool {
    78  	return IsStatus(err, http.StatusCreated)
    79  }
    80  
    81  // IsNotFound returns true if err is a "not found" StatusError.
    82  func IsNotFound(err error) bool {
    83  	return IsStatus(err, http.StatusNotFound)
    84  }
    85  
    86  // IsConflict returns true if err is a "status conflict" StatusError.
    87  func IsConflict(err error) bool {
    88  	return IsStatus(err, http.StatusConflict)
    89  }
    90  
    91  // IsAccepted returns true if err is a "status accepted" StatusError.
    92  func IsAccepted(err error) bool {
    93  	return IsStatus(err, http.StatusAccepted)
    94  }
    95  
    96  // IsForbidden returns true if statis code is 403 "forbidden"
    97  func IsForbidden(err error) bool {
    98  	return IsStatus(err, http.StatusForbidden)
    99  }
   100  
   101  // NetworkError occurs on any Send error which occurred while trying to send
   102  // the HTTP request, e.g. the given host is unresponsive.
   103  type NetworkError struct {
   104  	err error
   105  }
   106  
   107  func (e NetworkError) Error() string {
   108  	return fmt.Sprintf("network error: %s", e.err)
   109  }
   110  
   111  // IsNetworkError returns true if err is a NetworkError.
   112  func IsNetworkError(err error) bool {
   113  	_, ok := err.(NetworkError)
   114  	return ok
   115  }
   116  
// sendOptions is the per-call configuration Send builds from its defaults and
// then mutates by applying each SendOption in order.
type sendOptions struct {
	// Request body handed to http.NewRequest; when nil, newRequest forces
	// ContentLength to 0.
	body          io.Reader
	// Used as http.Client.Timeout.
	timeout       time.Duration
	// Status codes Send treats as success; any other final status code is
	// returned as a StatusError.
	acceptedCodes map[int]bool
	// Applied to the request with Header.Set, replacing same-named headers.
	headers       map[string]string
	// Used as http.Client.CheckRedirect; nil keeps net/http's default policy.
	redirect      func(req *http.Request, via []*http.Request) error
	// Backoff policy and extra retryable status codes (see SendRetry).
	retry         retryOptions
	// Used as http.Client.Transport; nil means net/http's default transport.
	transport     http.RoundTripper
	// Attached to every request via Request.WithContext.
	ctx           context.Context

	// This is not a valid http option. It provides a way to override
	// parts of the url. For example, url.Scheme can be changed from
	// http to https.
	url *url.URL

	// This is not a valid http option. HTTP fallback is added to allow
	// easier migration from http to https.
	// In go1.11 and go1.12, the responses returned when an http request is
	// sent to an https server differ in the fallback mode:
	// go1.11 returns a network error whereas go1.12 returns BadRequest.
	// This causes TestTLSClientBadAuth to fail because the test checks
	// the retry error.
	// This flag is added to allow disabling http fallback in unit tests.
	// NOTE: it does not impact how it runs in production.
	httpFallbackDisabled bool
}
   143  
   144  // SendOption allows overriding defaults for the Send function.
   145  type SendOption func(*sendOptions)
   146  
   147  // SendNoop returns a no-op option.
   148  func SendNoop() SendOption {
   149  	return func(o *sendOptions) {}
   150  }
   151  
   152  // SendBody specifies a body for http request
   153  func SendBody(body io.Reader) SendOption {
   154  	return func(o *sendOptions) { o.body = body }
   155  }
   156  
   157  // SendTimeout specifies timeout for http request
   158  func SendTimeout(timeout time.Duration) SendOption {
   159  	return func(o *sendOptions) { o.timeout = timeout }
   160  }
   161  
   162  // SendHeaders specifies headers for http request
   163  func SendHeaders(headers map[string]string) SendOption {
   164  	return func(o *sendOptions) { o.headers = headers }
   165  }
   166  
   167  // SendAcceptedCodes specifies accepted codes for http request
   168  func SendAcceptedCodes(codes ...int) SendOption {
   169  	m := make(map[int]bool)
   170  	for _, c := range codes {
   171  		m[c] = true
   172  	}
   173  	return func(o *sendOptions) { o.acceptedCodes = m }
   174  }
   175  
   176  // SendRedirect specifies a redirect policy for http request
   177  func SendRedirect(redirect func(req *http.Request, via []*http.Request) error) SendOption {
   178  	return func(o *sendOptions) { o.redirect = redirect }
   179  }
   180  
   181  type retryOptions struct {
   182  	backoff    backoff.BackOff
   183  	extraCodes map[int]bool
   184  }
   185  
   186  // RetryOption allows overriding defaults for the SendRetry option.
   187  type RetryOption func(*retryOptions)
   188  
   189  // RetryBackoff adds exponential backoff between retries.
   190  func RetryBackoff(b backoff.BackOff) RetryOption {
   191  	return func(o *retryOptions) { o.backoff = b }
   192  }
   193  
   194  // RetryCodes adds more status codes to be retried (in addition to the default
   195  // 5XX codes).
   196  //
   197  // WARNING: You better know what you're doing to retry anything non-5XX.
   198  func RetryCodes(codes ...int) RetryOption {
   199  	return func(o *retryOptions) {
   200  		for _, c := range codes {
   201  			o.extraCodes[c] = true
   202  		}
   203  	}
   204  }
   205  
   206  // SendRetry will we retry the request on network / 5XX errors.
   207  func SendRetry(options ...RetryOption) SendOption {
   208  	retry := retryOptions{
   209  		backoff: backoff.WithMaxRetries(
   210  			backoff.NewConstantBackOff(250*time.Millisecond),
   211  			2),
   212  		extraCodes: make(map[int]bool),
   213  	}
   214  	for _, o := range options {
   215  		o(&retry)
   216  	}
   217  	return func(o *sendOptions) { o.retry = retry }
   218  }
   219  
   220  // DisableHTTPFallback disables http fallback when https request fails.
   221  func DisableHTTPFallback() SendOption {
   222  	return func(o *sendOptions) {
   223  		o.httpFallbackDisabled = true
   224  	}
   225  }
   226  
   227  // SendTLS sets the transport with TLS config for the HTTP client.
   228  func SendTLS(config *tls.Config) SendOption {
   229  	return func(o *sendOptions) {
   230  		if config == nil {
   231  			return
   232  		}
   233  		o.transport = &http.Transport{TLSClientConfig: config}
   234  		o.url.Scheme = "https"
   235  	}
   236  }
   237  
   238  // SendTLSTransport sets the transport with TLS config for the HTTP client.
   239  func SendTLSTransport(transport http.RoundTripper) SendOption {
   240  	return func(o *sendOptions) {
   241  		o.transport = transport
   242  		o.url.Scheme = "https"
   243  	}
   244  }
   245  
   246  // SendTransport sets the transport for the HTTP client.
   247  func SendTransport(transport http.RoundTripper) SendOption {
   248  	return func(o *sendOptions) { o.transport = transport }
   249  }
   250  
   251  // SendContext sets the context for the HTTP client.
   252  func SendContext(ctx context.Context) SendOption {
   253  	return func(o *sendOptions) { o.ctx = ctx }
   254  }
   255  
   256  // Send sends an HTTP request. May return NetworkError or StatusError (see above).
   257  func Send(method, rawurl string, options ...SendOption) (*http.Response, error) {
   258  	u, err := url.Parse(rawurl)
   259  	if err != nil {
   260  		return nil, fmt.Errorf("parse url: %s", err)
   261  	}
   262  	opts := &sendOptions{
   263  		body:                 nil,
   264  		timeout:              60 * time.Second,
   265  		acceptedCodes:        map[int]bool{http.StatusOK: true},
   266  		headers:              map[string]string{},
   267  		retry:                retryOptions{backoff: &backoff.StopBackOff{}},
   268  		transport:            nil, // Use HTTP default.
   269  		ctx:                  context.Background(),
   270  		url:                  u,
   271  		httpFallbackDisabled: false,
   272  	}
   273  	for _, o := range options {
   274  		o(opts)
   275  	}
   276  
   277  	req, err := newRequest(method, opts)
   278  	if err != nil {
   279  		return nil, err
   280  	}
   281  
   282  	client := &http.Client{
   283  		Timeout:       opts.timeout,
   284  		CheckRedirect: opts.redirect,
   285  		Transport:     opts.transport,
   286  	}
   287  
   288  	var resp *http.Response
   289  	for {
   290  		resp, err = client.Do(req)
   291  		// Retry without tls. During migration there would be a time when the
   292  		// component receiving the tls request does not serve https response.
   293  		// TODO (@evelynl): disable retry after tls migration.
   294  		if err != nil && req.URL.Scheme == "https" && !opts.httpFallbackDisabled {
   295  			originalErr := err
   296  			resp, err = fallbackToHTTP(client, method, opts)
   297  			if err != nil {
   298  				// Sometimes the request fails for a reason unrelated to https.
   299  				// To keep this reason visible, we always include the original
   300  				// error.
   301  				err = fmt.Errorf(
   302  					"failed to fallback https to http, original https error: %s,\n"+
   303  						"fallback http error: %s", originalErr, err)
   304  			}
   305  		}
   306  		if err != nil ||
   307  			(resp.StatusCode >= 500 && !opts.acceptedCodes[resp.StatusCode]) ||
   308  			(opts.retry.extraCodes[resp.StatusCode]) {
   309  			d := opts.retry.backoff.NextBackOff()
   310  			if d == backoff.Stop {
   311  				break // Backoff timed out.
   312  			}
   313  			time.Sleep(d)
   314  			continue
   315  		}
   316  		break
   317  	}
   318  	if err != nil {
   319  		return nil, NetworkError{err}
   320  	}
   321  	if !opts.acceptedCodes[resp.StatusCode] {
   322  		return nil, NewStatusError(resp)
   323  	}
   324  	return resp, nil
   325  }
   326  
   327  // Get sends a GET http request.
   328  func Get(url string, options ...SendOption) (*http.Response, error) {
   329  	return Send("GET", url, options...)
   330  }
   331  
   332  // Head sends a HEAD http request.
   333  func Head(url string, options ...SendOption) (*http.Response, error) {
   334  	return Send("HEAD", url, options...)
   335  }
   336  
   337  // Post sends a POST http request.
   338  func Post(url string, options ...SendOption) (*http.Response, error) {
   339  	return Send("POST", url, options...)
   340  }
   341  
   342  // Put sends a PUT http request.
   343  func Put(url string, options ...SendOption) (*http.Response, error) {
   344  	return Send("PUT", url, options...)
   345  }
   346  
   347  // Patch sends a PATCH http request.
   348  func Patch(url string, options ...SendOption) (*http.Response, error) {
   349  	return Send("PATCH", url, options...)
   350  }
   351  
   352  // Delete sends a DELETE http request.
   353  func Delete(url string, options ...SendOption) (*http.Response, error) {
   354  	return Send("DELETE", url, options...)
   355  }
   356  
   357  // PollAccepted wraps GET requests for endpoints which require 202-polling.
   358  func PollAccepted(
   359  	url string, b backoff.BackOff, options ...SendOption) (*http.Response, error) {
   360  
   361  	b.Reset()
   362  	for {
   363  		resp, err := Get(url, options...)
   364  		if err != nil {
   365  			if IsAccepted(err) {
   366  				d := b.NextBackOff()
   367  				if d == backoff.Stop {
   368  					break // Backoff timed out.
   369  				}
   370  				time.Sleep(d)
   371  				continue
   372  			}
   373  			return nil, err
   374  		}
   375  		return resp, nil
   376  	}
   377  	return nil, errors.New("backoff timed out on 202 responses")
   378  }
   379  
   380  // GetQueryArg gets an argument from http.Request by name.
   381  // When the argument is not specified, it returns a default value.
   382  func GetQueryArg(r *http.Request, name string, defaultVal string) string {
   383  	v := r.URL.Query().Get(name)
   384  	if v == "" {
   385  		v = defaultVal
   386  	}
   387  
   388  	return v
   389  }
   390  
   391  // ParseParam parses a parameter from url.
   392  func ParseParam(r *http.Request, name string) (string, error) {
   393  	param := chi.URLParam(r, name)
   394  	if param == "" {
   395  		return "", handler.Errorf("param %s is required", name).Status(http.StatusBadRequest)
   396  	}
   397  	val, err := url.PathUnescape(param)
   398  	if err != nil {
   399  		return "", handler.Errorf("path unescape %s: %s", name, err).Status(http.StatusBadRequest)
   400  	}
   401  	return val, nil
   402  }
   403  
   404  // ParseDigest parses a digest from url.
   405  func ParseDigest(r *http.Request, name string) (core.Digest, error) {
   406  	raw, err := ParseParam(r, name)
   407  	if err != nil {
   408  		return core.Digest{}, err
   409  	}
   410  
   411  	d, err := core.ParseSHA256Digest(raw)
   412  	if err != nil {
   413  		return core.Digest{}, handler.Errorf("parse digest: %s", err).Status(http.StatusBadRequest)
   414  	}
   415  	return d, nil
   416  }
   417  
   418  func newRequest(method string, opts *sendOptions) (*http.Request, error) {
   419  	req, err := http.NewRequest(method, opts.url.String(), opts.body)
   420  	if err != nil {
   421  		return nil, fmt.Errorf("new request: %s", err)
   422  	}
   423  	req = req.WithContext(opts.ctx)
   424  	if opts.body == nil {
   425  		req.ContentLength = 0
   426  	}
   427  	for key, val := range opts.headers {
   428  		req.Header.Set(key, val)
   429  	}
   430  	return req, nil
   431  }
   432  
   433  func fallbackToHTTP(
   434  	client *http.Client, method string, opts *sendOptions) (*http.Response, error) {
   435  
   436  	req, err := newRequest(method, opts)
   437  	if err != nil {
   438  		return nil, err
   439  	}
   440  	req.URL.Scheme = "http"
   441  
   442  	return client.Do(req)
   443  }
   444  
   445  func min(a, b time.Duration) time.Duration {
   446  	if a < b {
   447  		return a
   448  	}
   449  	return b
   450  }