github.com/xmidt-org/webpa-common@v1.11.9/xhttp/retry_test.go (about) 1 package xhttp 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "io/ioutil" 8 "net" 9 "net/http" 10 "net/http/httptest" 11 "testing" 12 "time" 13 14 "github.com/go-kit/kit/metrics/generic" 15 "github.com/stretchr/testify/assert" 16 "github.com/stretchr/testify/mock" 17 "github.com/stretchr/testify/require" 18 "github.com/xmidt-org/webpa-common/logging" 19 ) 20 21 func testShouldRetry(t *testing.T, shouldRetry ShouldRetryFunc, candidate error, expected bool) { 22 assert := assert.New(t) 23 assert.Equal(expected, shouldRetry(candidate)) 24 } 25 26 func TestDefaultShouldRetry(t *testing.T) { 27 t.Run("Nil", func(t *testing.T) { 28 testShouldRetry(t, DefaultShouldRetry, nil, false) 29 }) 30 31 t.Run("DNSError", func(t *testing.T) { 32 testShouldRetry(t, DefaultShouldRetry, &net.DNSError{IsTemporary: false}, false) 33 testShouldRetry(t, DefaultShouldRetry, &net.DNSError{IsTemporary: true}, true) 34 }) 35 } 36 37 func testRetryTransactorDefaultLogger(t *testing.T) { 38 var ( 39 assert = assert.New(t) 40 require = require.New(t) 41 transactorCalled = false 42 43 transactor = func(*http.Request) (*http.Response, error) { 44 transactorCalled = true 45 return nil, nil 46 } 47 48 retry = RetryTransactor(RetryOptions{Retries: 1}, transactor) 49 ) 50 51 require.NotNil(retry) 52 retry(httptest.NewRequest("GET", "/", nil)) 53 assert.True(transactorCalled) 54 } 55 56 func testRetryTransactorNoRetries(t *testing.T) { 57 var ( 58 assert = assert.New(t) 59 require = require.New(t) 60 transactorCalled = false 61 62 transactor = func(*http.Request) (*http.Response, error) { 63 transactorCalled = true 64 return nil, nil 65 } 66 67 retry = RetryTransactor(RetryOptions{}, transactor) 68 ) 69 70 require.NotNil(retry) 71 retry(httptest.NewRequest("GET", "/", nil)) 72 assert.True(transactorCalled) 73 } 74 75 func testRetryTransactorStatus(t *testing.T) { 76 var ( 77 assert = assert.New(t) 78 require = require.New(t) 79 80 transactorCount = 0 81 statusCheck = 0 82 transactor = func(*http.Request) (*http.Response, error) { 83 response := http.Response{ 84 StatusCode: 429 + transactorCount, 85 } 86 transactorCount++ 87 return &response, nil 88 } 89 90 retry = RetryTransactor(RetryOptions{ 91 Retries: 5, 92 ShouldRetryStatus: func(status int) bool { 93 statusCheck++ 94 return status == 429 95 }, 96 }, transactor) 97 ) 98 99 require.NotNil(retry) 100 retry(httptest.NewRequest("GET", "/", nil)) 101 assert.Equal(2, transactorCount) 102 assert.Equal(2, statusCheck) 103 } 104 105 func testRetryTransactorAllRetriesFail(t *testing.T, expectedInterval, configuredInterval time.Duration, retryCount int) { 106 var ( 107 assert = assert.New(t) 108 require = require.New(t) 109 expectedRequest = httptest.NewRequest("GET", "/", nil) 110 expectedError = &net.DNSError{IsTemporary: true} 111 counter = generic.NewCounter("test") 112 urls = map[string]int{} 113 114 transactorCount = 0 115 transactor = func(actualRequest *http.Request) (*http.Response, error) { 116 if _, ok := urls[actualRequest.URL.Path]; ok { 117 urls[actualRequest.URL.Path]++ 118 } else { 119 urls[actualRequest.URL.Path] = 1 120 } 121 transactorCount++ 122 assert.True(expectedRequest == actualRequest) 123 return nil, expectedError 124 } 125 126 slept = 0 127 retry = RetryTransactor( 128 RetryOptions{ 129 Logger: logging.NewTestLogger(nil, t), 130 Retries: retryCount, 131 Counter: counter, 132 Interval: configuredInterval, 133 Sleep: func(actualInterval time.Duration) { 134 slept++ 135 assert.Equal(expectedInterval, actualInterval) 136 }, 137 UpdateRequest: func(request *http.Request) { 138 if _, ok := urls[request.URL.Path]; ok { 139 request.URL.Path += "a" 140 } 141 }, 142 }, 143 transactor, 144 ) 145 ) 146 147 require.NotNil(retry) 148 actualResponse, actualError := retry(expectedRequest) 149 assert.Nil(actualResponse) 150 assert.Equal(expectedError, actualError) 151 assert.Equal(1+retryCount, transactorCount) 152 assert.Equal(float64(retryCount), counter.Value()) 153 assert.Equal(retryCount, slept) 154 for _, v := range urls { 155 assert.Equal(1, v) 156 } 157 } 158 159 func testRetryTransactorFirstSucceeds(t *testing.T, retryCount int) { 160 var ( 161 assert = assert.New(t) 162 require = require.New(t) 163 expectedRequest = httptest.NewRequest("GET", "/", nil) 164 expectedResponse = new(http.Response) 165 counter = generic.NewCounter("test") 166 167 transactorCount = 0 168 transactor = func(actualRequest *http.Request) (*http.Response, error) { 169 transactorCount++ 170 assert.True(expectedRequest == actualRequest) 171 return expectedResponse, nil 172 } 173 174 retry = RetryTransactor( 175 RetryOptions{ 176 Logger: logging.NewTestLogger(nil, t), 177 Retries: retryCount, 178 Counter: counter, 179 Sleep: func(d time.Duration) { 180 assert.Fail("Sleep should not have been called") 181 }, 182 }, 183 transactor, 184 ) 185 ) 186 187 require.NotNil(retry) 188 actualResponse, actualError := retry(expectedRequest) 189 assert.True(expectedResponse == actualResponse) 190 assert.NoError(actualError) 191 assert.Equal(1, transactorCount) 192 assert.Zero(counter.Value()) 193 } 194 195 func testRetryTransactorNotRewindable(t *testing.T) { 196 var ( 197 assert = assert.New(t) 198 require = require.New(t) 199 body = new(mockReader) 200 expectedError = errors.New("expected") 201 202 retry = RetryTransactor( 203 RetryOptions{ 204 Logger: logging.NewTestLogger(nil, t), 205 Retries: 2, 206 }, 207 func(*http.Request) (*http.Response, error) { 208 assert.Fail("The decorated transactor should not have been called") 209 return nil, nil 210 }, 211 ) 212 ) 213 214 body.On("Read", mock.MatchedBy(func([]byte) bool { return true })).Return(0, expectedError).Once() 215 require.NotNil(retry) 216 response, actualError := retry(&http.Request{Body: ioutil.NopCloser(body)}) 217 assert.Nil(response) 218 assert.Equal(expectedError, actualError) 219 220 body.AssertExpectations(t) 221 } 222 223 func testRetryTransactorRewindError(t *testing.T) { 224 var ( 225 assert = assert.New(t) 226 require = require.New(t) 227 expectedError = errors.New("expected") 228 229 retry = RetryTransactor( 230 RetryOptions{ 231 Logger: logging.NewTestLogger(nil, t), 232 Retries: 2, 233 Sleep: func(time.Duration) {}, 234 }, 235 func(*http.Request) (*http.Response, error) { 236 return nil, &net.DNSError{IsTemporary: true} 237 }, 238 ) 239 240 r = httptest.NewRequest("POST", "/", nil) 241 ) 242 243 r.GetBody = func() (io.ReadCloser, error) { 244 return nil, expectedError 245 } 246 247 require.NotNil(retry) 248 response, actualError := retry(r) 249 assert.Nil(response) 250 assert.Equal(expectedError, actualError) 251 } 252 253 func TestRetryTransactor(t *testing.T) { 254 t.Run("DefaultLogger", testRetryTransactorDefaultLogger) 255 t.Run("NoRetries", testRetryTransactorNoRetries) 256 257 t.Run("AllRetriesFail", func(t *testing.T) { 258 for _, retryCount := range []int{1, 2, 5} { 259 t.Run(fmt.Sprintf("RetryCount=%d", retryCount), func(t *testing.T) { 260 testRetryTransactorAllRetriesFail(t, time.Second, 0, retryCount) 261 testRetryTransactorAllRetriesFail(t, 10*time.Minute, 10*time.Minute, retryCount) 262 }) 263 } 264 }) 265 266 t.Run("FirstSucceeds", func(t *testing.T) { 267 for _, retryCount := range []int{1, 2, 5} { 268 t.Run(fmt.Sprintf("RetryCount=%d", retryCount), func(t *testing.T) { testRetryTransactorFirstSucceeds(t, retryCount) }) 269 } 270 }) 271 272 t.Run("NotRewindable", testRetryTransactorNotRewindable) 273 t.Run("RewindError", testRetryTransactorRewindError) 274 t.Run("StatusRetry", testRetryTransactorStatus) 275 }