github.com/uber/kraken@v0.1.4/utils/httputil/httputil_test.go (about)

     1  // Copyright (c) 2016-2019 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package httputil
    15  
    16  import (
    17  	"context"
    18  	"errors"
    19  	"fmt"
    20  	"net/http"
    21  	"net/http/httptest"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/cenkalti/backoff"
    26  	"github.com/golang/mock/gomock"
    27  	"github.com/pressly/chi"
    28  	"github.com/stretchr/testify/require"
    29  
    30  	"github.com/uber/kraken/core"
    31  	"github.com/uber/kraken/mocks/utils/httputil"
    32  )
    33  
    34  const _testURL = "http://localhost:0/test"
    35  
    36  func newResponse(status int) *http.Response {
    37  	// We need to set a dummy request in the response so NewStatusError
    38  	// can access the "original" URL.
    39  	dummyReq, err := http.NewRequest("GET", _testURL, nil)
    40  	if err != nil {
    41  		panic(err)
    42  	}
    43  
    44  	rec := httptest.NewRecorder()
    45  	rec.WriteHeader(status)
    46  	resp := rec.Result()
    47  	resp.Request = dummyReq
    48  
    49  	return resp
    50  }
    51  
    52  func TestSendOptions(t *testing.T) {
    53  	require := require.New(t)
    54  
    55  	ctrl := gomock.NewController(t)
    56  	defer ctrl.Finish()
    57  
    58  	transport := mockhttputil.NewMockRoundTripper(ctrl)
    59  
    60  	transport.EXPECT().RoundTrip(gomock.Any()).Return(newResponse(499), nil)
    61  
    62  	_, err := Get(
    63  		_testURL,
    64  		SendTransport(transport),
    65  		SendAcceptedCodes(200, 499))
    66  	require.NoError(err)
    67  }
    68  
    69  func TestSendRetry(t *testing.T) {
    70  	require := require.New(t)
    71  
    72  	ctrl := gomock.NewController(t)
    73  	defer ctrl.Finish()
    74  
    75  	transport := mockhttputil.NewMockRoundTripper(ctrl)
    76  
    77  	for _, status := range []int{503, 500, 200} {
    78  		transport.EXPECT().RoundTrip(gomock.Any()).Return(newResponse(status), nil)
    79  	}
    80  
    81  	start := time.Now()
    82  	_, err := Get(
    83  		_testURL,
    84  		SendRetry(
    85  			RetryBackoff(backoff.WithMaxRetries(
    86  				backoff.NewConstantBackOff(200*time.Millisecond),
    87  				4))),
    88  		SendTransport(transport))
    89  	require.NoError(err)
    90  	require.InDelta(400*time.Millisecond, time.Since(start), float64(50*time.Millisecond))
    91  }
    92  
    93  func TestSendRetryOnTransportErrors(t *testing.T) {
    94  	require := require.New(t)
    95  
    96  	ctrl := gomock.NewController(t)
    97  	defer ctrl.Finish()
    98  
    99  	transport := mockhttputil.NewMockRoundTripper(ctrl)
   100  
   101  	transport.EXPECT().RoundTrip(gomock.Any()).Return(nil, errors.New("some network error")).Times(3)
   102  
   103  	start := time.Now()
   104  	_, err := Get(
   105  		_testURL,
   106  		SendRetry(
   107  			RetryBackoff(backoff.WithMaxRetries(
   108  				backoff.NewConstantBackOff(200*time.Millisecond),
   109  				2))),
   110  		SendTransport(transport))
   111  	require.Error(err)
   112  	require.InDelta(400*time.Millisecond, time.Since(start), float64(50*time.Millisecond))
   113  }
   114  
   115  func TestSendRetryOn5XX(t *testing.T) {
   116  	require := require.New(t)
   117  
   118  	ctrl := gomock.NewController(t)
   119  	defer ctrl.Finish()
   120  
   121  	transport := mockhttputil.NewMockRoundTripper(ctrl)
   122  
   123  	transport.EXPECT().RoundTrip(gomock.Any()).Return(newResponse(503), nil).Times(3)
   124  
   125  	start := time.Now()
   126  	_, err := Get(
   127  		_testURL,
   128  		SendRetry(
   129  			RetryBackoff(backoff.WithMaxRetries(
   130  				backoff.NewConstantBackOff(200*time.Millisecond),
   131  				2))),
   132  		SendTransport(transport))
   133  	require.Error(err)
   134  	require.Equal(503, err.(StatusError).Status)
   135  	require.InDelta(400*time.Millisecond, time.Since(start), float64(50*time.Millisecond))
   136  }
   137  
   138  func TestSendRetryWithCodes(t *testing.T) {
   139  	require := require.New(t)
   140  
   141  	ctrl := gomock.NewController(t)
   142  	defer ctrl.Finish()
   143  
   144  	transport := mockhttputil.NewMockRoundTripper(ctrl)
   145  
   146  	gomock.InOrder(
   147  		transport.EXPECT().RoundTrip(gomock.Any()).Return(newResponse(400), nil),
   148  		transport.EXPECT().RoundTrip(gomock.Any()).Return(newResponse(503), nil),
   149  		transport.EXPECT().RoundTrip(gomock.Any()).Return(newResponse(404), nil),
   150  	)
   151  
   152  	start := time.Now()
   153  	_, err := Get(
   154  		_testURL,
   155  		SendRetry(
   156  			RetryBackoff(backoff.WithMaxRetries(
   157  				backoff.NewConstantBackOff(200*time.Millisecond),
   158  				2)),
   159  			RetryCodes(400, 404)),
   160  		SendTransport(transport))
   161  	require.Error(err)
   162  	require.Equal(404, err.(StatusError).Status) // Last code returned.
   163  	require.InDelta(400*time.Millisecond, time.Since(start), float64(50*time.Millisecond))
   164  }
   165  
   166  func TestPollAccepted(t *testing.T) {
   167  	require := require.New(t)
   168  
   169  	ctrl := gomock.NewController(t)
   170  	defer ctrl.Finish()
   171  
   172  	transport := mockhttputil.NewMockRoundTripper(ctrl)
   173  
   174  	for _, status := range []int{202, 202, 200} {
   175  		transport.EXPECT().RoundTrip(gomock.Any()).Return(newResponse(status), nil)
   176  	}
   177  
   178  	start := time.Now()
   179  	_, err := PollAccepted(
   180  		_testURL,
   181  		backoff.NewConstantBackOff(200*time.Millisecond),
   182  		SendTransport(transport))
   183  	require.NoError(err)
   184  	require.InDelta(400*time.Millisecond, time.Since(start), float64(50*time.Millisecond))
   185  }
   186  
   187  func TestPollAcceptedStatusError(t *testing.T) {
   188  	require := require.New(t)
   189  
   190  	ctrl := gomock.NewController(t)
   191  	defer ctrl.Finish()
   192  
   193  	transport := mockhttputil.NewMockRoundTripper(ctrl)
   194  
   195  	for _, status := range []int{202, 202, 404} {
   196  		transport.EXPECT().RoundTrip(gomock.Any()).Return(newResponse(status), nil)
   197  	}
   198  
   199  	start := time.Now()
   200  	_, err := PollAccepted(
   201  		_testURL,
   202  		backoff.NewConstantBackOff(200*time.Millisecond),
   203  		SendTransport(transport))
   204  	require.Error(err)
   205  	require.Equal(404, err.(StatusError).Status)
   206  	require.InDelta(400*time.Millisecond, time.Since(start), float64(50*time.Millisecond))
   207  }
   208  
   209  func TestPollAcceptedBackoffTimeout(t *testing.T) {
   210  	require := require.New(t)
   211  
   212  	ctrl := gomock.NewController(t)
   213  	defer ctrl.Finish()
   214  
   215  	transport := mockhttputil.NewMockRoundTripper(ctrl)
   216  
   217  	transport.EXPECT().RoundTrip(gomock.Any()).Return(newResponse(202), nil).Times(3)
   218  
   219  	start := time.Now()
   220  	_, err := PollAccepted(
   221  		_testURL,
   222  		backoff.WithMaxRetries(backoff.NewConstantBackOff(200*time.Millisecond), 2),
   223  		SendTransport(transport))
   224  	require.Error(err)
   225  	require.InDelta(400*time.Millisecond, time.Since(start), float64(50*time.Millisecond))
   226  }
   227  
   228  func TestGetQueryArg(t *testing.T) {
   229  	require := require.New(t)
   230  	arg := "arg"
   231  	value := "value"
   232  	defaultVal := "defaultvalue"
   233  
   234  	r := httptest.NewRequest("GET", fmt.Sprintf("localhost:0/?%s=%s", arg, value), nil)
   235  	require.Equal(value, GetQueryArg(r, arg, defaultVal))
   236  }
   237  
   238  func TestGetQueryArgUseDefault(t *testing.T) {
   239  	require := require.New(t)
   240  	arg := "arg"
   241  	defaultVal := "defaultvalue"
   242  
   243  	r := httptest.NewRequest("GET", "localhost:0/", nil)
   244  	require.Equal(defaultVal, GetQueryArg(r, arg, defaultVal))
   245  }
   246  
   247  func TestParseParam(t *testing.T) {
   248  	require := require.New(t)
   249  
   250  	r := httptest.NewRequest("GET", "/", nil)
   251  
   252  	rctx := chi.NewRouteContext()
   253  	rctx.URLParams.Add("key", "a%2Fb")
   254  
   255  	r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
   256  
   257  	ret, err := ParseParam(r, "key")
   258  	require.NoError(err)
   259  	require.Equal("a/b", ret)
   260  }
   261  
   262  func TestParseParamNotFound(t *testing.T) {
   263  	require := require.New(t)
   264  
   265  	r := httptest.NewRequest("GET", "/", nil)
   266  	rctx := chi.NewRouteContext()
   267  
   268  	r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
   269  
   270  	_, err := ParseParam(r, "key")
   271  	require.Error(err)
   272  }
   273  
   274  func TestParseParamUnescapeError(t *testing.T) {
   275  	require := require.New(t)
   276  
   277  	r := httptest.NewRequest("GET", "/", nil)
   278  	rctx := chi.NewRouteContext()
   279  	rctx.URLParams.Add("key", "value%")
   280  
   281  	r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
   282  
   283  	_, err := ParseParam(r, "key")
   284  	require.Error(err)
   285  }
   286  
   287  func TestParseDigest(t *testing.T) {
   288  	require := require.New(t)
   289  
   290  	r := httptest.NewRequest("GET", "/", nil)
   291  
   292  	d := core.DigestFixture()
   293  	rctx := chi.NewRouteContext()
   294  	rctx.URLParams.Add("digest", d.String())
   295  
   296  	r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
   297  
   298  	ret, err := ParseDigest(r, "digest")
   299  	require.NoError(err)
   300  	require.Equal(d, ret)
   301  }
   302  
   303  func TestParseDigestInvalid(t *testing.T) {
   304  	require := require.New(t)
   305  
   306  	r := httptest.NewRequest("GET", "/", nil)
   307  
   308  	rctx := chi.NewRouteContext()
   309  	rctx.URLParams.Add("digest", "abc")
   310  
   311  	r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
   312  
   313  	_, err := ParseDigest(r, "digest")
   314  	require.Error(err)
   315  }