git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/retry/retry_test.go (about)

     1  package retry
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  )
    12  
    13  func TestDoAllFailed(t *testing.T) {
    14  	var retrySum uint
    15  	err := Do(
    16  		func() error { return errors.New("test") },
    17  		OnRetry(func(n uint, err error) { retrySum += n }),
    18  		Delay(time.Nanosecond),
    19  	)
    20  	assert.Error(t, err)
    21  
    22  	expectedErrorFormat := `All attempts fail:
    23  #1: test
    24  #2: test
    25  #3: test
    26  #4: test
    27  #5: test
    28  #6: test
    29  #7: test
    30  #8: test
    31  #9: test
    32  #10: test`
    33  	assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format")
    34  	assert.Equal(t, uint(45), retrySum, "right count of retry")
    35  }
    36  
    37  func TestDoFirstOk(t *testing.T) {
    38  	var retrySum uint
    39  	err := Do(
    40  		func() error { return nil },
    41  		OnRetry(func(n uint, err error) { retrySum += n }),
    42  	)
    43  	assert.NoError(t, err)
    44  	assert.Equal(t, uint(0), retrySum, "no retry")
    45  
    46  }
    47  
    48  func TestRetryIf(t *testing.T) {
    49  	var retryCount uint
    50  	err := Do(
    51  		func() error {
    52  			if retryCount >= 2 {
    53  				return errors.New("special")
    54  			} else {
    55  				return errors.New("test")
    56  			}
    57  		},
    58  		OnRetry(func(n uint, err error) { retryCount++ }),
    59  		RetryIf(func(err error) bool {
    60  			return err.Error() != "special"
    61  		}),
    62  		Delay(time.Nanosecond),
    63  	)
    64  	assert.Error(t, err)
    65  
    66  	expectedErrorFormat := `All attempts fail:
    67  #1: test
    68  #2: test
    69  #3: special`
    70  	assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format")
    71  	assert.Equal(t, uint(2), retryCount, "right count of retry")
    72  
    73  }
    74  
    75  func TestZeroAttemptsWithError(t *testing.T) {
    76  	const maxErrors = 999
    77  	count := 0
    78  
    79  	err := Do(
    80  		func() error {
    81  			if count < maxErrors {
    82  				count += 1
    83  				return errors.New("test")
    84  			}
    85  
    86  			return nil
    87  		},
    88  		Attempts(0),
    89  		MaxDelay(time.Nanosecond),
    90  	)
    91  	assert.NoError(t, err)
    92  
    93  	assert.Equal(t, count, maxErrors)
    94  }
    95  
    96  func TestZeroAttemptsWithoutError(t *testing.T) {
    97  	count := 0
    98  
    99  	err := Do(
   100  		func() error {
   101  			count++
   102  
   103  			return nil
   104  		},
   105  		Attempts(0),
   106  	)
   107  	assert.NoError(t, err)
   108  
   109  	assert.Equal(t, count, 1)
   110  }
   111  
   112  func TestDefaultSleep(t *testing.T) {
   113  	start := time.Now()
   114  	err := Do(
   115  		func() error { return errors.New("test") },
   116  		Attempts(3),
   117  	)
   118  	dur := time.Since(start)
   119  	assert.Error(t, err)
   120  	assert.True(t, dur > 300*time.Millisecond, "3 times default retry is longer then 300ms")
   121  }
   122  
   123  func TestFixedSleep(t *testing.T) {
   124  	start := time.Now()
   125  	err := Do(
   126  		func() error { return errors.New("test") },
   127  		Attempts(3),
   128  		DelayType(FixedDelay),
   129  	)
   130  	dur := time.Since(start)
   131  	assert.Error(t, err)
   132  	assert.True(t, dur < 500*time.Millisecond, "3 times default retry is shorter then 500ms")
   133  }
   134  
   135  func TestLastErrorOnly(t *testing.T) {
   136  	var retrySum uint
   137  	err := Do(
   138  		func() error { return fmt.Errorf("%d", retrySum) },
   139  		OnRetry(func(n uint, err error) { retrySum += 1 }),
   140  		Delay(time.Nanosecond),
   141  		LastErrorOnly(true),
   142  	)
   143  	assert.Error(t, err)
   144  	assert.Equal(t, "9", err.Error())
   145  }
   146  
   147  func TestUnrecoverableError(t *testing.T) {
   148  	attempts := 0
   149  	expectedErr := errors.New("error")
   150  	err := Do(
   151  		func() error {
   152  			attempts++
   153  			return Unrecoverable(expectedErr)
   154  		},
   155  		Attempts(2),
   156  		LastErrorOnly(true),
   157  	)
   158  	assert.Equal(t, expectedErr, err)
   159  	assert.Equal(t, 1, attempts, "unrecoverable error broke the loop")
   160  }
   161  
   162  func TestCombineFixedDelays(t *testing.T) {
   163  	start := time.Now()
   164  	err := Do(
   165  		func() error { return errors.New("test") },
   166  		Attempts(3),
   167  		DelayType(CombineDelay(FixedDelay, FixedDelay)),
   168  	)
   169  	dur := time.Since(start)
   170  	assert.Error(t, err)
   171  	assert.True(t, dur > 400*time.Millisecond, "3 times combined, fixed retry is longer then 400ms")
   172  	assert.True(t, dur < 500*time.Millisecond, "3 times combined, fixed retry is shorter then 500ms")
   173  }
   174  
   175  func TestRandomDelay(t *testing.T) {
   176  	start := time.Now()
   177  	err := Do(
   178  		func() error { return errors.New("test") },
   179  		Attempts(3),
   180  		DelayType(RandomDelay),
   181  		MaxJitter(50*time.Millisecond),
   182  	)
   183  	dur := time.Since(start)
   184  	assert.Error(t, err)
   185  	assert.True(t, dur > 2*time.Millisecond, "3 times random retry is longer then 2ms")
   186  	assert.True(t, dur < 100*time.Millisecond, "3 times random retry is shorter then 100ms")
   187  }
   188  
   189  func TestMaxDelay(t *testing.T) {
   190  	start := time.Now()
   191  	err := Do(
   192  		func() error { return errors.New("test") },
   193  		Attempts(5),
   194  		Delay(10*time.Millisecond),
   195  		MaxDelay(50*time.Millisecond),
   196  	)
   197  	dur := time.Since(start)
   198  	assert.Error(t, err)
   199  	assert.True(t, dur > 120*time.Millisecond, "5 times with maximum delay retry is longer than 120ms")
   200  	assert.True(t, dur < 205*time.Millisecond, "5 times with maximum delay retry is shorter than 205ms")
   201  }
   202  
   203  func TestBackOffDelay(t *testing.T) {
   204  	for _, c := range []struct {
   205  		label         string
   206  		delay         time.Duration
   207  		expectedMaxN  uint
   208  		n             uint
   209  		expectedDelay time.Duration
   210  	}{
   211  		{
   212  			label:         "negative-delay",
   213  			delay:         -1,
   214  			expectedMaxN:  62,
   215  			n:             2,
   216  			expectedDelay: 4,
   217  		},
   218  		{
   219  			label:         "zero-delay",
   220  			delay:         0,
   221  			expectedMaxN:  62,
   222  			n:             65,
   223  			expectedDelay: 1 << 62,
   224  		},
   225  		{
   226  			label:         "one-second",
   227  			delay:         time.Second,
   228  			expectedMaxN:  33,
   229  			n:             62,
   230  			expectedDelay: time.Second << 33,
   231  		},
   232  	} {
   233  		t.Run(
   234  			c.label,
   235  			func(t *testing.T) {
   236  				config := Config{
   237  					delay: c.delay,
   238  				}
   239  				delay := BackOffDelay(c.n, nil, &config)
   240  				assert.Equal(t, c.expectedMaxN, config.maxBackOffN, "max n mismatch")
   241  				assert.Equal(t, c.expectedDelay, delay, "delay duration mismatch")
   242  			},
   243  		)
   244  	}
   245  }
   246  
   247  func TestCombineDelay(t *testing.T) {
   248  	f := func(d time.Duration) DelayTypeFunc {
   249  		return func(_ uint, _ error, _ *Config) time.Duration {
   250  			return d
   251  		}
   252  	}
   253  	const max = time.Duration(1<<63 - 1)
   254  	for _, c := range []struct {
   255  		label    string
   256  		delays   []time.Duration
   257  		expected time.Duration
   258  	}{
   259  		{
   260  			label: "empty",
   261  		},
   262  		{
   263  			label: "single",
   264  			delays: []time.Duration{
   265  				time.Second,
   266  			},
   267  			expected: time.Second,
   268  		},
   269  		{
   270  			label: "negative",
   271  			delays: []time.Duration{
   272  				time.Second,
   273  				-time.Millisecond,
   274  			},
   275  			expected: time.Second - time.Millisecond,
   276  		},
   277  		{
   278  			label: "overflow",
   279  			delays: []time.Duration{
   280  				max,
   281  				time.Second,
   282  				time.Millisecond,
   283  			},
   284  			expected: max,
   285  		},
   286  	} {
   287  		t.Run(
   288  			c.label,
   289  			func(t *testing.T) {
   290  				funcs := make([]DelayTypeFunc, len(c.delays))
   291  				for i, d := range c.delays {
   292  					funcs[i] = f(d)
   293  				}
   294  				actual := CombineDelay(funcs...)(0, nil, nil)
   295  				assert.Equal(t, c.expected, actual, "delay duration mismatch")
   296  			},
   297  		)
   298  	}
   299  }
   300  
   301  func TestContext(t *testing.T) {
   302  	const defaultDelay = 100 * time.Millisecond
   303  	t.Run("cancel before", func(t *testing.T) {
   304  		ctx, cancel := context.WithCancel(context.Background())
   305  		cancel()
   306  
   307  		retrySum := 0
   308  		start := time.Now()
   309  		err := Do(
   310  			func() error { return errors.New("test") },
   311  			OnRetry(func(n uint, err error) { retrySum += 1 }),
   312  			Context(ctx),
   313  		)
   314  		dur := time.Since(start)
   315  		assert.Error(t, err)
   316  		assert.True(t, dur < defaultDelay, "immediately cancellation")
   317  		assert.Equal(t, 0, retrySum, "called at most once")
   318  	})
   319  
   320  	t.Run("cancel in retry progress", func(t *testing.T) {
   321  		ctx, cancel := context.WithCancel(context.Background())
   322  
   323  		retrySum := 0
   324  		err := Do(
   325  			func() error { return errors.New("test") },
   326  			OnRetry(func(n uint, err error) {
   327  				retrySum += 1
   328  				if retrySum > 1 {
   329  					cancel()
   330  				}
   331  			}),
   332  			Context(ctx),
   333  		)
   334  		assert.Error(t, err)
   335  
   336  		expectedErrorFormat := `All attempts fail:
   337  #1: test
   338  #2: context canceled`
   339  		assert.Equal(t, expectedErrorFormat, err.Error(), "retry error format")
   340  		assert.Equal(t, 2, retrySum, "called at most once")
   341  	})
   342  
   343  	t.Run("cancel in retry progress - last error only", func(t *testing.T) {
   344  		ctx, cancel := context.WithCancel(context.Background())
   345  
   346  		retrySum := 0
   347  		err := Do(
   348  			func() error { return errors.New("test") },
   349  			OnRetry(func(n uint, err error) {
   350  				retrySum += 1
   351  				if retrySum > 1 {
   352  					cancel()
   353  				}
   354  			}),
   355  			Context(ctx),
   356  			LastErrorOnly(true),
   357  		)
   358  		assert.Equal(t, context.Canceled, err)
   359  
   360  		assert.Equal(t, 2, retrySum, "called at most once")
   361  	})
   362  }