git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/retry/retry_test.go (about) 1 package retry 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "testing" 8 "time" 9 10 "github.com/stretchr/testify/assert" 11 ) 12 13 func TestDoAllFailed(t *testing.T) { 14 var retrySum uint 15 err := Do( 16 func() error { return errors.New("test") }, 17 OnRetry(func(n uint, err error) { retrySum += n }), 18 Delay(time.Nanosecond), 19 ) 20 assert.Error(t, err) 21 22 expectedErrorFormat := `All attempts fail: 23 #1: test 24 #2: test 25 #3: test 26 #4: test 27 #5: test 28 #6: test 29 #7: test 30 #8: test 31 #9: test 32 #10: test` 33 assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format") 34 assert.Equal(t, uint(45), retrySum, "right count of retry") 35 } 36 37 func TestDoFirstOk(t *testing.T) { 38 var retrySum uint 39 err := Do( 40 func() error { return nil }, 41 OnRetry(func(n uint, err error) { retrySum += n }), 42 ) 43 assert.NoError(t, err) 44 assert.Equal(t, uint(0), retrySum, "no retry") 45 46 } 47 48 func TestRetryIf(t *testing.T) { 49 var retryCount uint 50 err := Do( 51 func() error { 52 if retryCount >= 2 { 53 return errors.New("special") 54 } else { 55 return errors.New("test") 56 } 57 }, 58 OnRetry(func(n uint, err error) { retryCount++ }), 59 RetryIf(func(err error) bool { 60 return err.Error() != "special" 61 }), 62 Delay(time.Nanosecond), 63 ) 64 assert.Error(t, err) 65 66 expectedErrorFormat := `All attempts fail: 67 #1: test 68 #2: test 69 #3: special` 70 assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format") 71 assert.Equal(t, uint(2), retryCount, "right count of retry") 72 73 } 74 75 func TestZeroAttemptsWithError(t *testing.T) { 76 const maxErrors = 999 77 count := 0 78 79 err := Do( 80 func() error { 81 if count < maxErrors { 82 count += 1 83 return errors.New("test") 84 } 85 86 return nil 87 }, 88 Attempts(0), 89 MaxDelay(time.Nanosecond), 90 ) 91 assert.NoError(t, err) 92 93 assert.Equal(t, count, maxErrors) 94 } 95 96 func TestZeroAttemptsWithoutError(t *testing.T) { 97 count := 0 98 99 err := Do( 100 func() error { 101 count++ 102 103 return nil 104 }, 105 Attempts(0), 106 ) 107 assert.NoError(t, err) 108 109 assert.Equal(t, count, 1) 110 } 111 112 func TestDefaultSleep(t *testing.T) { 113 start := time.Now() 114 err := Do( 115 func() error { return errors.New("test") }, 116 Attempts(3), 117 ) 118 dur := time.Since(start) 119 assert.Error(t, err) 120 assert.True(t, dur > 300*time.Millisecond, "3 times default retry is longer then 300ms") 121 } 122 123 func TestFixedSleep(t *testing.T) { 124 start := time.Now() 125 err := Do( 126 func() error { return errors.New("test") }, 127 Attempts(3), 128 DelayType(FixedDelay), 129 ) 130 dur := time.Since(start) 131 assert.Error(t, err) 132 assert.True(t, dur < 500*time.Millisecond, "3 times default retry is shorter then 500ms") 133 } 134 135 func TestLastErrorOnly(t *testing.T) { 136 var retrySum uint 137 err := Do( 138 func() error { return fmt.Errorf("%d", retrySum) }, 139 OnRetry(func(n uint, err error) { retrySum += 1 }), 140 Delay(time.Nanosecond), 141 LastErrorOnly(true), 142 ) 143 assert.Error(t, err) 144 assert.Equal(t, "9", err.Error()) 145 } 146 147 func TestUnrecoverableError(t *testing.T) { 148 attempts := 0 149 expectedErr := errors.New("error") 150 err := Do( 151 func() error { 152 attempts++ 153 return Unrecoverable(expectedErr) 154 }, 155 Attempts(2), 156 LastErrorOnly(true), 157 ) 158 assert.Equal(t, expectedErr, err) 159 assert.Equal(t, 1, attempts, "unrecoverable error broke the loop") 160 } 161 162 func TestCombineFixedDelays(t *testing.T) { 163 start := time.Now() 164 err := Do( 165 func() error { return errors.New("test") }, 166 Attempts(3), 167 DelayType(CombineDelay(FixedDelay, FixedDelay)), 168 ) 169 dur := time.Since(start) 170 assert.Error(t, err) 171 assert.True(t, dur > 400*time.Millisecond, "3 times combined, fixed retry is longer then 400ms") 172 assert.True(t, dur < 500*time.Millisecond, "3 times combined, fixed retry is shorter then 500ms") 173 } 174 175 func TestRandomDelay(t *testing.T) { 176 start := time.Now() 177 err := Do( 178 func() error { return errors.New("test") }, 179 Attempts(3), 180 DelayType(RandomDelay), 181 MaxJitter(50*time.Millisecond), 182 ) 183 dur := time.Since(start) 184 assert.Error(t, err) 185 assert.True(t, dur > 2*time.Millisecond, "3 times random retry is longer then 2ms") 186 assert.True(t, dur < 100*time.Millisecond, "3 times random retry is shorter then 100ms") 187 } 188 189 func TestMaxDelay(t *testing.T) { 190 start := time.Now() 191 err := Do( 192 func() error { return errors.New("test") }, 193 Attempts(5), 194 Delay(10*time.Millisecond), 195 MaxDelay(50*time.Millisecond), 196 ) 197 dur := time.Since(start) 198 assert.Error(t, err) 199 assert.True(t, dur > 120*time.Millisecond, "5 times with maximum delay retry is longer than 120ms") 200 assert.True(t, dur < 205*time.Millisecond, "5 times with maximum delay retry is shorter than 205ms") 201 } 202 203 func TestBackOffDelay(t *testing.T) { 204 for _, c := range []struct { 205 label string 206 delay time.Duration 207 expectedMaxN uint 208 n uint 209 expectedDelay time.Duration 210 }{ 211 { 212 label: "negative-delay", 213 delay: -1, 214 expectedMaxN: 62, 215 n: 2, 216 expectedDelay: 4, 217 }, 218 { 219 label: "zero-delay", 220 delay: 0, 221 expectedMaxN: 62, 222 n: 65, 223 expectedDelay: 1 << 62, 224 }, 225 { 226 label: "one-second", 227 delay: time.Second, 228 expectedMaxN: 33, 229 n: 62, 230 expectedDelay: time.Second << 33, 231 }, 232 } { 233 t.Run( 234 c.label, 235 func(t *testing.T) { 236 config := Config{ 237 delay: c.delay, 238 } 239 delay := BackOffDelay(c.n, nil, &config) 240 assert.Equal(t, c.expectedMaxN, config.maxBackOffN, "max n mismatch") 241 assert.Equal(t, c.expectedDelay, delay, "delay duration mismatch") 242 }, 243 ) 244 } 245 } 246 247 func TestCombineDelay(t *testing.T) { 248 f := func(d time.Duration) DelayTypeFunc { 249 return func(_ uint, _ error, _ *Config) time.Duration { 250 return d 251 } 252 } 253 const max = time.Duration(1<<63 - 1) 254 for _, c := range []struct { 255 label string 256 delays []time.Duration 257 expected time.Duration 258 }{ 259 { 260 label: "empty", 261 }, 262 { 263 label: "single", 264 delays: []time.Duration{ 265 time.Second, 266 }, 267 expected: time.Second, 268 }, 269 { 270 label: "negative", 271 delays: []time.Duration{ 272 time.Second, 273 -time.Millisecond, 274 }, 275 expected: time.Second - time.Millisecond, 276 }, 277 { 278 label: "overflow", 279 delays: []time.Duration{ 280 max, 281 time.Second, 282 time.Millisecond, 283 }, 284 expected: max, 285 }, 286 } { 287 t.Run( 288 c.label, 289 func(t *testing.T) { 290 funcs := make([]DelayTypeFunc, len(c.delays)) 291 for i, d := range c.delays { 292 funcs[i] = f(d) 293 } 294 actual := CombineDelay(funcs...)(0, nil, nil) 295 assert.Equal(t, c.expected, actual, "delay duration mismatch") 296 }, 297 ) 298 } 299 } 300 301 func TestContext(t *testing.T) { 302 const defaultDelay = 100 * time.Millisecond 303 t.Run("cancel before", func(t *testing.T) { 304 ctx, cancel := context.WithCancel(context.Background()) 305 cancel() 306 307 retrySum := 0 308 start := time.Now() 309 err := Do( 310 func() error { return errors.New("test") }, 311 OnRetry(func(n uint, err error) { retrySum += 1 }), 312 Context(ctx), 313 ) 314 dur := time.Since(start) 315 assert.Error(t, err) 316 assert.True(t, dur < defaultDelay, "immediately cancellation") 317 assert.Equal(t, 0, retrySum, "called at most once") 318 }) 319 320 t.Run("cancel in retry progress", func(t *testing.T) { 321 ctx, cancel := context.WithCancel(context.Background()) 322 323 retrySum := 0 324 err := Do( 325 func() error { return errors.New("test") }, 326 OnRetry(func(n uint, err error) { 327 retrySum += 1 328 if retrySum > 1 { 329 cancel() 330 } 331 }), 332 Context(ctx), 333 ) 334 assert.Error(t, err) 335 336 expectedErrorFormat := `All attempts fail: 337 #1: test 338 #2: context canceled` 339 assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format") 340 assert.Equal(t, 2, retrySum, "called at most once") 341 }) 342 343 t.Run("cancel in retry progress - last error only", func(t *testing.T) { 344 ctx, cancel := context.WithCancel(context.Background()) 345 346 retrySum := 0 347 err := Do( 348 func() error { return errors.New("test") }, 349 OnRetry(func(n uint, err error) { 350 retrySum += 1 351 if retrySum > 1 { 352 cancel() 353 } 354 }), 355 Context(ctx), 356 LastErrorOnly(true), 357 ) 358 assert.Equal(t, context.Canceled, err) 359 360 assert.Equal(t, 2, retrySum, "called at most once") 361 }) 362 }