github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/backoff/backoff_test.go (about)

     1  package backoff
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"log"
     9  	"math"
    10  	"math/rand"
    11  	"testing"
    12  	"time"
    13  )
    14  
    15  func ExampleRetry() {
    16  	// An operation that may fail.
    17  	operation := func(retryTimes int) error {
    18  		return nil // or an error
    19  	}
    20  
    21  	err := Retry(operation, NewExponentialBackOff())
    22  	if err != nil {
    23  		// Handle error.
    24  		return
    25  	}
    26  
    27  	// Operation is successful.
    28  }
    29  
    30  func ExampleWithContext() { // nolint: govet
    31  	// A context
    32  	ctx := context.Background()
    33  
    34  	// An operation that may fail.
    35  	operation := func(retryTimes int) error {
    36  		return nil // or an error
    37  	}
    38  
    39  	b := WithContext(NewExponentialBackOff(), ctx)
    40  
    41  	err := Retry(operation, b)
    42  	if err != nil {
    43  		// Handle error.
    44  		return
    45  	}
    46  
    47  	// Operation is successful.
    48  }
    49  
    50  func ExampleTicker() {
    51  	// An operation that may fail.
    52  	operation := func() error {
    53  		return nil // or an error
    54  	}
    55  
    56  	ticker := NewTicker(NewExponentialBackOff())
    57  
    58  	var err error
    59  
    60  	// Ticks will continue to arrive when the previous operation is still running,
    61  	// so operations that take a while to fail could run in quick succession.
    62  	for range ticker.C {
    63  		if err = operation(); err != nil {
    64  			log.Println(err, "will retry...")
    65  			continue
    66  		}
    67  
    68  		ticker.Stop()
    69  		break
    70  	}
    71  
    72  	if err != nil {
    73  		// Operation has failed.
    74  		return
    75  	}
    76  
    77  	// Operation is successful.
    78  }
    79  
    80  func TestNextBackOffMillis(t *testing.T) {
    81  	subtestNextBackOff(t, 0, new(ZeroBackOff))
    82  	subtestNextBackOff(t, Stop, new(StopBackOff))
    83  }
    84  
    85  func subtestNextBackOff(t *testing.T, expectedValue time.Duration, backOffPolicy BackOff) {
    86  	for i := 0; i < 10; i++ {
    87  		next := backOffPolicy.NextBackOff()
    88  		if next != expectedValue {
    89  			t.Errorf("got: %d expected: %d", next, expectedValue)
    90  		}
    91  	}
    92  }
    93  
    94  func TestConstantBackOff(t *testing.T) {
    95  	backoff := NewConstantBackOff(time.Second)
    96  	if backoff.NextBackOff() != time.Second {
    97  		t.Error("invalid interval")
    98  	}
    99  }
   100  
   101  func TestContext(t *testing.T) {
   102  	b := NewConstantBackOff(time.Millisecond)
   103  	ctx, cancel := context.WithCancel(context.Background())
   104  	defer cancel()
   105  
   106  	cb := WithContext(b, ctx)
   107  
   108  	if cb.Context() != ctx {
   109  		t.Error("invalid context")
   110  	}
   111  
   112  	cancel()
   113  
   114  	if cb.NextBackOff() != Stop {
   115  		t.Error("invalid next back off")
   116  	}
   117  }
   118  
   119  func TestBackOff(t *testing.T) {
   120  	var (
   121  		testInitialInterval     = 500 * time.Millisecond
   122  		testRandomizationFactor = 0.1
   123  		testMultiplier          = 2.0
   124  		testMaxInterval         = 5 * time.Second
   125  		testMaxElapsedTime      = 15 * time.Minute
   126  	)
   127  
   128  	exp := NewExponentialBackOff()
   129  	exp.InitialInterval = testInitialInterval
   130  	exp.RandomizationFactor = testRandomizationFactor
   131  	exp.Multiplier = testMultiplier
   132  	exp.MaxInterval = testMaxInterval
   133  	exp.MaxElapsedTime = testMaxElapsedTime
   134  	exp.Reset()
   135  
   136  	expectedResults := []time.Duration{500, 1000, 2000, 4000, 5000, 5000, 5000, 5000, 5000, 5000}
   137  	for i, d := range expectedResults {
   138  		expectedResults[i] = d * time.Millisecond
   139  	}
   140  
   141  	for _, expected := range expectedResults {
   142  		assertEquals(t, expected, exp.currentInterval)
   143  		// Assert that the next backoff falls in the expected range.
   144  		minInterval := expected - time.Duration(testRandomizationFactor*float64(expected))
   145  		maxInterval := expected + time.Duration(testRandomizationFactor*float64(expected))
   146  		actualInterval := exp.NextBackOff()
   147  		if !(minInterval <= actualInterval && actualInterval <= maxInterval) {
   148  			t.Error("error")
   149  		}
   150  	}
   151  }
   152  
   153  func TestGetRandomizedInterval(t *testing.T) {
   154  	// 33% chance of being 1.
   155  	assertEquals(t, 1, getRandomValueFromInterval(0.5, 0, 2))
   156  	assertEquals(t, 1, getRandomValueFromInterval(0.5, 0.33, 2))
   157  	// 33% chance of being 2.
   158  	assertEquals(t, 2, getRandomValueFromInterval(0.5, 0.34, 2))
   159  	assertEquals(t, 2, getRandomValueFromInterval(0.5, 0.66, 2))
   160  	// 33% chance of being 3.
   161  	assertEquals(t, 3, getRandomValueFromInterval(0.5, 0.67, 2))
   162  	assertEquals(t, 3, getRandomValueFromInterval(0.5, 0.99, 2))
   163  }
   164  
   165  type TestClock struct {
   166  	i     time.Duration
   167  	start time.Time
   168  }
   169  
   170  func (c *TestClock) Now() time.Time {
   171  	t := c.start.Add(c.i)
   172  	c.i += time.Second
   173  	return t
   174  }
   175  
   176  func TestGetElapsedTime(t *testing.T) {
   177  	exp := NewExponentialBackOff()
   178  	exp.Clock = &TestClock{}
   179  	exp.Reset()
   180  
   181  	elapsedTime := exp.GetElapsedTime()
   182  	if elapsedTime != time.Second {
   183  		t.Errorf("elapsedTime=%d", elapsedTime)
   184  	}
   185  }
   186  
   187  func TestMaxElapsedTime(t *testing.T) {
   188  	exp := NewExponentialBackOff()
   189  	exp.Clock = &TestClock{start: time.Time{}.Add(10000 * time.Second)}
   190  	// Change the currentElapsedTime to be 0 ensuring that the elapsed time will be greater
   191  	// than the max elapsed time.
   192  	exp.startTime = time.Time{}
   193  	assertEquals(t, Stop, exp.NextBackOff())
   194  }
   195  
   196  func TestCustomStop(t *testing.T) {
   197  	exp := NewExponentialBackOff()
   198  	customStop := time.Minute
   199  	exp.Stop = customStop
   200  	exp.Clock = &TestClock{start: time.Time{}.Add(10000 * time.Second)}
   201  	// Change the currentElapsedTime to be 0 ensuring that the elapsed time will be greater
   202  	// than the max elapsed time.
   203  	exp.startTime = time.Time{}
   204  	assertEquals(t, customStop, exp.NextBackOff())
   205  }
   206  
   207  func TestBackOffOverflow(t *testing.T) {
   208  	var (
   209  		testInitialInterval time.Duration = math.MaxInt64 / 2
   210  		testMaxInterval     time.Duration = math.MaxInt64
   211  		testMultiplier                    = 2.1
   212  	)
   213  
   214  	exp := NewExponentialBackOff()
   215  	exp.InitialInterval = testInitialInterval
   216  	exp.Multiplier = testMultiplier
   217  	exp.MaxInterval = testMaxInterval
   218  	exp.Reset()
   219  
   220  	exp.NextBackOff()
   221  	// Assert that when an overflow is possible, the current varerval time.Duration is set to the max varerval time.Duration.
   222  	assertEquals(t, testMaxInterval, exp.currentInterval)
   223  }
   224  
   225  func assertEquals(t *testing.T, expected, value time.Duration) {
   226  	if expected != value {
   227  		t.Errorf("got: %d, expected: %d", value, expected)
   228  	}
   229  }
   230  
   231  type testTimer struct {
   232  	timer *time.Timer
   233  }
   234  
   235  func (t *testTimer) Start(_ time.Duration) {
   236  	t.timer = time.NewTimer(0)
   237  }
   238  
   239  func (t *testTimer) Stop() {
   240  	if t.timer != nil {
   241  		t.timer.Stop()
   242  	}
   243  }
   244  
   245  func (t *testTimer) C() <-chan time.Time {
   246  	return t.timer.C
   247  }
   248  
   249  func TestRetry(t *testing.T) {
   250  	const successOn = 3
   251  	i := 0
   252  
   253  	// This function is successful on "successOn" calls.
   254  	f := func(retryTimes int) error {
   255  		i++
   256  		log.Printf("function is called %d. time\n", i)
   257  
   258  		if i == successOn {
   259  			log.Println("OK")
   260  			return nil
   261  		}
   262  
   263  		log.Println("error")
   264  		return errors.New("error")
   265  	}
   266  
   267  	err := RetryNotifyWithTimer(f, NewExponentialBackOff(), nil, &testTimer{})
   268  	if err != nil {
   269  		t.Errorf("unexpected error: %s", err.Error())
   270  	}
   271  	if i != successOn {
   272  		t.Errorf("invalid number of retries: %d", i)
   273  	}
   274  }
   275  
   276  func TestRetryContext(t *testing.T) {
   277  	cancelOn := 3
   278  	i := 0
   279  
   280  	ctx, cancel := context.WithCancel(context.Background())
   281  	defer cancel()
   282  
   283  	// This function cancels context on "cancelOn" calls.
   284  	f := func(retryTimes int) error {
   285  		i++
   286  		log.Printf("function is called %d. time\n", i)
   287  
   288  		// cancelling the context in the operation function is not a typical
   289  		// use-case, however it allows to get predictable test results.
   290  		if i == cancelOn {
   291  			cancel()
   292  		}
   293  
   294  		log.Println("error")
   295  		return fmt.Errorf("error (%d)", i)
   296  	}
   297  
   298  	err := RetryNotifyWithTimer(f, WithContext(NewConstantBackOff(time.Millisecond), ctx), nil, &testTimer{})
   299  	if err == nil {
   300  		t.Errorf("error is unexpectedly nil")
   301  	}
   302  	if !errors.Is(err, context.Canceled) {
   303  		t.Errorf("unexpected error: %s", err.Error())
   304  	}
   305  	if i != cancelOn {
   306  		t.Errorf("invalid number of retries: %d", i)
   307  	}
   308  }
   309  
   310  func TestRetryPermanent(t *testing.T) {
   311  	ensureRetries := func(test string, shouldRetry bool, f func() error) {
   312  		numRetries := -1
   313  		maxRetries := 1
   314  
   315  		_ = RetryNotifyWithTimer(
   316  			func(retryTimes int) error {
   317  				numRetries++
   318  				if numRetries >= maxRetries {
   319  					return Permanent(errors.New("forced"))
   320  				}
   321  				return f()
   322  			},
   323  			NewExponentialBackOff(),
   324  			nil,
   325  			&testTimer{},
   326  		)
   327  
   328  		if shouldRetry && numRetries == 0 {
   329  			t.Errorf("Test: '%s', backoff should have retried", test)
   330  		}
   331  
   332  		if !shouldRetry && numRetries > 0 {
   333  			t.Errorf("Test: '%s', backoff should not have retried", test)
   334  		}
   335  	}
   336  
   337  	for _, testCase := range []struct {
   338  		name        string
   339  		f           func() error
   340  		shouldRetry bool
   341  	}{
   342  		{
   343  			"nil test",
   344  			func() error {
   345  				return nil
   346  			},
   347  			false,
   348  		},
   349  		{
   350  			"io.EOF",
   351  			func() error {
   352  				return io.EOF
   353  			},
   354  			true,
   355  		},
   356  		{
   357  			"Permanent(io.EOF)",
   358  			func() error {
   359  				return Permanent(io.EOF)
   360  			},
   361  			false,
   362  		},
   363  		{
   364  			"Wrapped: Permanent(io.EOF)",
   365  			func() error {
   366  				return fmt.Errorf("wrapped error: %w", Permanent(io.EOF))
   367  			},
   368  			false,
   369  		},
   370  	} {
   371  		ensureRetries(testCase.name, testCase.shouldRetry, testCase.f)
   372  	}
   373  }
   374  
   375  func TestPermanent(t *testing.T) {
   376  	want := errors.New("foo")
   377  	other := errors.New("bar")
   378  	err := Permanent(want)
   379  
   380  	got := errors.Unwrap(err)
   381  	if got != want {
   382  		t.Errorf("got %v, want %v", got, want)
   383  	}
   384  
   385  	if is := errors.Is(err, want); !is {
   386  		t.Errorf("err: %v is not %v", err, want)
   387  	}
   388  
   389  	if is := errors.Is(err, other); is {
   390  		t.Errorf("err: %v is %v", err, other)
   391  	}
   392  
   393  	wrapped := fmt.Errorf("wrapped: %w", err)
   394  	var permanent *PermanentError
   395  	if !errors.As(wrapped, &permanent) {
   396  		t.Errorf("errors.As(%v, %v)", wrapped, permanent)
   397  	}
   398  
   399  	err = Permanent(nil)
   400  	if err != nil {
   401  		t.Errorf("got %v, want nil", err)
   402  	}
   403  }
   404  
   405  func TestTicker(t *testing.T) {
   406  	const successOn = 3
   407  	i := 0
   408  
   409  	// This function is successful on "successOn" calls.
   410  	f := func() error {
   411  		i++
   412  		log.Printf("function is called %d. time\n", i)
   413  
   414  		if i == successOn {
   415  			log.Println("OK")
   416  			return nil
   417  		}
   418  
   419  		log.Println("error")
   420  		return errors.New("error")
   421  	}
   422  
   423  	b := NewExponentialBackOff()
   424  	ticker := NewTickerWithTimer(b, &testTimer{})
   425  
   426  	var err error
   427  	for range ticker.C {
   428  		if err = f(); err != nil {
   429  			t.Log(err)
   430  			continue
   431  		}
   432  
   433  		break
   434  	}
   435  	if err != nil {
   436  		t.Errorf("unexpected error: %s", err.Error())
   437  	}
   438  	if i != successOn {
   439  		t.Errorf("invalid number of retries: %d", i)
   440  	}
   441  }
   442  
   443  func TestTickerContext(t *testing.T) {
   444  	i := 0
   445  
   446  	ctx, cancel := context.WithCancel(context.Background())
   447  
   448  	// Cancel context as soon as it is created.
   449  	// Ticker must stop after first tick.
   450  	cancel()
   451  
   452  	// This function cancels context on "cancelOn" calls.
   453  	f := func() error {
   454  		i++
   455  		log.Printf("function is called %d. time\n", i)
   456  		log.Println("error")
   457  		return fmt.Errorf("error (%d)", i)
   458  	}
   459  
   460  	b := WithContext(NewConstantBackOff(0), ctx)
   461  	ticker := NewTickerWithTimer(b, &testTimer{})
   462  
   463  	var err error
   464  	for range ticker.C {
   465  		if err = f(); err != nil {
   466  			t.Log(err)
   467  			continue
   468  		}
   469  
   470  		ticker.Stop()
   471  		break
   472  	}
   473  	// Ticker is guaranteed to tick at least once.
   474  	if err == nil {
   475  		t.Errorf("error is unexpectedly nil")
   476  	}
   477  	if err.Error() != "error (1)" {
   478  		t.Errorf("unexpected error: %s", err)
   479  	}
   480  	if i != 1 {
   481  		t.Errorf("invalid number of retries: %d", i)
   482  	}
   483  }
   484  
   485  func TestTickerDefaultTimer(t *testing.T) {
   486  	b := NewExponentialBackOff()
   487  	ticker := NewTickerWithTimer(b, nil)
   488  	// ensure a timer was actually assigned, instead of remaining as nil.
   489  	<-ticker.C
   490  }
   491  
   492  func TestMaxTriesHappy(t *testing.T) {
   493  	r := rand.New(rand.NewSource(time.Now().UnixNano()))
   494  	max := 17 + r.Intn(13)
   495  	bo := WithMaxRetries(&ZeroBackOff{}, uint64(max))
   496  
   497  	// Load up the tries count, but reset should clear the record
   498  	for ix := 0; ix < max/2; ix++ {
   499  		bo.NextBackOff()
   500  	}
   501  	bo.Reset()
   502  
   503  	// Now fill the tries count all the way up
   504  	for ix := 0; ix < max; ix++ {
   505  		d := bo.NextBackOff()
   506  		if d == Stop {
   507  			t.Errorf("returned Stop on try %d", ix)
   508  		}
   509  	}
   510  
   511  	// We have now called the BackOff max number of times, we expect
   512  	// the next result to be Stop, even if we try it multiple times
   513  	for ix := 0; ix < 7; ix++ {
   514  		d := bo.NextBackOff()
   515  		if d != Stop {
   516  			t.Error("invalid next back off")
   517  		}
   518  	}
   519  
   520  	// Reset makes it all work again
   521  	bo.Reset()
   522  	d := bo.NextBackOff()
   523  	if d == Stop {
   524  		t.Error("returned Stop after reset")
   525  	}
   526  }
   527  
   528  // https://github.com/cenkalti/backoff/issues/80
   529  func TestMaxTriesZero(t *testing.T) {
   530  	var called int
   531  
   532  	b := WithMaxRetries(&ZeroBackOff{}, 0)
   533  
   534  	err := Retry(func(retryTimes int) error {
   535  		called++
   536  		return errors.New("err")
   537  	}, b)
   538  
   539  	if err == nil {
   540  		t.Errorf("error expected, nil founc")
   541  	}
   542  	if called != 1 {
   543  		t.Errorf("operation is called %d times", called)
   544  	}
   545  }