github.com/xmidt-org/webpa-common@v1.11.9/xhttp/retry_test.go (about)

     1  package xhttp
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"net"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/go-kit/kit/metrics/generic"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/mock"
    17  	"github.com/stretchr/testify/require"
    18  	"github.com/xmidt-org/webpa-common/logging"
    19  )
    20  
    21  func testShouldRetry(t *testing.T, shouldRetry ShouldRetryFunc, candidate error, expected bool) {
    22  	assert := assert.New(t)
    23  	assert.Equal(expected, shouldRetry(candidate))
    24  }
    25  
    26  func TestDefaultShouldRetry(t *testing.T) {
    27  	t.Run("Nil", func(t *testing.T) {
    28  		testShouldRetry(t, DefaultShouldRetry, nil, false)
    29  	})
    30  
    31  	t.Run("DNSError", func(t *testing.T) {
    32  		testShouldRetry(t, DefaultShouldRetry, &net.DNSError{IsTemporary: false}, false)
    33  		testShouldRetry(t, DefaultShouldRetry, &net.DNSError{IsTemporary: true}, true)
    34  	})
    35  }
    36  
    37  func testRetryTransactorDefaultLogger(t *testing.T) {
    38  	var (
    39  		assert           = assert.New(t)
    40  		require          = require.New(t)
    41  		transactorCalled = false
    42  
    43  		transactor = func(*http.Request) (*http.Response, error) {
    44  			transactorCalled = true
    45  			return nil, nil
    46  		}
    47  
    48  		retry = RetryTransactor(RetryOptions{Retries: 1}, transactor)
    49  	)
    50  
    51  	require.NotNil(retry)
    52  	retry(httptest.NewRequest("GET", "/", nil))
    53  	assert.True(transactorCalled)
    54  }
    55  
    56  func testRetryTransactorNoRetries(t *testing.T) {
    57  	var (
    58  		assert           = assert.New(t)
    59  		require          = require.New(t)
    60  		transactorCalled = false
    61  
    62  		transactor = func(*http.Request) (*http.Response, error) {
    63  			transactorCalled = true
    64  			return nil, nil
    65  		}
    66  
    67  		retry = RetryTransactor(RetryOptions{}, transactor)
    68  	)
    69  
    70  	require.NotNil(retry)
    71  	retry(httptest.NewRequest("GET", "/", nil))
    72  	assert.True(transactorCalled)
    73  }
    74  
    75  func testRetryTransactorStatus(t *testing.T) {
    76  	var (
    77  		assert  = assert.New(t)
    78  		require = require.New(t)
    79  
    80  		transactorCount = 0
    81  		statusCheck     = 0
    82  		transactor      = func(*http.Request) (*http.Response, error) {
    83  			response := http.Response{
    84  				StatusCode: 429 + transactorCount,
    85  			}
    86  			transactorCount++
    87  			return &response, nil
    88  		}
    89  
    90  		retry = RetryTransactor(RetryOptions{
    91  			Retries: 5,
    92  			ShouldRetryStatus: func(status int) bool {
    93  				statusCheck++
    94  				return status == 429
    95  			},
    96  		}, transactor)
    97  	)
    98  
    99  	require.NotNil(retry)
   100  	retry(httptest.NewRequest("GET", "/", nil))
   101  	assert.Equal(2, transactorCount)
   102  	assert.Equal(2, statusCheck)
   103  }
   104  
   105  func testRetryTransactorAllRetriesFail(t *testing.T, expectedInterval, configuredInterval time.Duration, retryCount int) {
   106  	var (
   107  		assert          = assert.New(t)
   108  		require         = require.New(t)
   109  		expectedRequest = httptest.NewRequest("GET", "/", nil)
   110  		expectedError   = &net.DNSError{IsTemporary: true}
   111  		counter         = generic.NewCounter("test")
   112  		urls            = map[string]int{}
   113  
   114  		transactorCount = 0
   115  		transactor      = func(actualRequest *http.Request) (*http.Response, error) {
   116  			if _, ok := urls[actualRequest.URL.Path]; ok {
   117  				urls[actualRequest.URL.Path]++
   118  			} else {
   119  				urls[actualRequest.URL.Path] = 1
   120  			}
   121  			transactorCount++
   122  			assert.True(expectedRequest == actualRequest)
   123  			return nil, expectedError
   124  		}
   125  
   126  		slept = 0
   127  		retry = RetryTransactor(
   128  			RetryOptions{
   129  				Logger:   logging.NewTestLogger(nil, t),
   130  				Retries:  retryCount,
   131  				Counter:  counter,
   132  				Interval: configuredInterval,
   133  				Sleep: func(actualInterval time.Duration) {
   134  					slept++
   135  					assert.Equal(expectedInterval, actualInterval)
   136  				},
   137  				UpdateRequest: func(request *http.Request) {
   138  					if _, ok := urls[request.URL.Path]; ok {
   139  						request.URL.Path += "a"
   140  					}
   141  				},
   142  			},
   143  			transactor,
   144  		)
   145  	)
   146  
   147  	require.NotNil(retry)
   148  	actualResponse, actualError := retry(expectedRequest)
   149  	assert.Nil(actualResponse)
   150  	assert.Equal(expectedError, actualError)
   151  	assert.Equal(1+retryCount, transactorCount)
   152  	assert.Equal(float64(retryCount), counter.Value())
   153  	assert.Equal(retryCount, slept)
   154  	for _, v := range urls {
   155  		assert.Equal(1, v)
   156  	}
   157  }
   158  
   159  func testRetryTransactorFirstSucceeds(t *testing.T, retryCount int) {
   160  	var (
   161  		assert           = assert.New(t)
   162  		require          = require.New(t)
   163  		expectedRequest  = httptest.NewRequest("GET", "/", nil)
   164  		expectedResponse = new(http.Response)
   165  		counter          = generic.NewCounter("test")
   166  
   167  		transactorCount = 0
   168  		transactor      = func(actualRequest *http.Request) (*http.Response, error) {
   169  			transactorCount++
   170  			assert.True(expectedRequest == actualRequest)
   171  			return expectedResponse, nil
   172  		}
   173  
   174  		retry = RetryTransactor(
   175  			RetryOptions{
   176  				Logger:  logging.NewTestLogger(nil, t),
   177  				Retries: retryCount,
   178  				Counter: counter,
   179  				Sleep: func(d time.Duration) {
   180  					assert.Fail("Sleep should not have been called")
   181  				},
   182  			},
   183  			transactor,
   184  		)
   185  	)
   186  
   187  	require.NotNil(retry)
   188  	actualResponse, actualError := retry(expectedRequest)
   189  	assert.True(expectedResponse == actualResponse)
   190  	assert.NoError(actualError)
   191  	assert.Equal(1, transactorCount)
   192  	assert.Zero(counter.Value())
   193  }
   194  
   195  func testRetryTransactorNotRewindable(t *testing.T) {
   196  	var (
   197  		assert        = assert.New(t)
   198  		require       = require.New(t)
   199  		body          = new(mockReader)
   200  		expectedError = errors.New("expected")
   201  
   202  		retry = RetryTransactor(
   203  			RetryOptions{
   204  				Logger:  logging.NewTestLogger(nil, t),
   205  				Retries: 2,
   206  			},
   207  			func(*http.Request) (*http.Response, error) {
   208  				assert.Fail("The decorated transactor should not have been called")
   209  				return nil, nil
   210  			},
   211  		)
   212  	)
   213  
   214  	body.On("Read", mock.MatchedBy(func([]byte) bool { return true })).Return(0, expectedError).Once()
   215  	require.NotNil(retry)
   216  	response, actualError := retry(&http.Request{Body: ioutil.NopCloser(body)})
   217  	assert.Nil(response)
   218  	assert.Equal(expectedError, actualError)
   219  
   220  	body.AssertExpectations(t)
   221  }
   222  
   223  func testRetryTransactorRewindError(t *testing.T) {
   224  	var (
   225  		assert        = assert.New(t)
   226  		require       = require.New(t)
   227  		expectedError = errors.New("expected")
   228  
   229  		retry = RetryTransactor(
   230  			RetryOptions{
   231  				Logger:  logging.NewTestLogger(nil, t),
   232  				Retries: 2,
   233  				Sleep:   func(time.Duration) {},
   234  			},
   235  			func(*http.Request) (*http.Response, error) {
   236  				return nil, &net.DNSError{IsTemporary: true}
   237  			},
   238  		)
   239  
   240  		r = httptest.NewRequest("POST", "/", nil)
   241  	)
   242  
   243  	r.GetBody = func() (io.ReadCloser, error) {
   244  		return nil, expectedError
   245  	}
   246  
   247  	require.NotNil(retry)
   248  	response, actualError := retry(r)
   249  	assert.Nil(response)
   250  	assert.Equal(expectedError, actualError)
   251  }
   252  
   253  func TestRetryTransactor(t *testing.T) {
   254  	t.Run("DefaultLogger", testRetryTransactorDefaultLogger)
   255  	t.Run("NoRetries", testRetryTransactorNoRetries)
   256  
   257  	t.Run("AllRetriesFail", func(t *testing.T) {
   258  		for _, retryCount := range []int{1, 2, 5} {
   259  			t.Run(fmt.Sprintf("RetryCount=%d", retryCount), func(t *testing.T) {
   260  				testRetryTransactorAllRetriesFail(t, time.Second, 0, retryCount)
   261  				testRetryTransactorAllRetriesFail(t, 10*time.Minute, 10*time.Minute, retryCount)
   262  			})
   263  		}
   264  	})
   265  
   266  	t.Run("FirstSucceeds", func(t *testing.T) {
   267  		for _, retryCount := range []int{1, 2, 5} {
   268  			t.Run(fmt.Sprintf("RetryCount=%d", retryCount), func(t *testing.T) { testRetryTransactorFirstSucceeds(t, retryCount) })
   269  		}
   270  	})
   271  
   272  	t.Run("NotRewindable", testRetryTransactorNotRewindable)
   273  	t.Run("RewindError", testRetryTransactorRewindError)
   274  	t.Run("StatusRetry", testRetryTransactorStatus)
   275  }