github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/retry/retry_test.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package retry
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/stretchr/testify/require"
    14  
    15  	"github.com/Schaudge/grailbase/errors"
    16  )
    17  
    18  func TestBackoff(t *testing.T) {
    19  	policy := Backoff(time.Second, 10*time.Second, 2)
    20  	expect := []time.Duration{
    21  		time.Second,
    22  		2 * time.Second,
    23  		4 * time.Second,
    24  		8 * time.Second,
    25  		10 * time.Second,
    26  		10 * time.Second,
    27  	}
    28  	for retries, wait := range expect {
    29  		keepgoing, dur := policy.Retry(retries)
    30  		if !keepgoing {
    31  			t.Fatal("!keepgoing")
    32  		}
    33  		if got, want := dur, wait; got != want {
    34  			t.Errorf("retry %d: got %v, want %v", retries, got, want)
    35  		}
    36  	}
    37  }
    38  
    39  // TestBackoffOverflow tests the behavior of exponential backoff for large
    40  // numbers of retries.
    41  func TestBackoffOverflow(t *testing.T) {
    42  	policy := Backoff(time.Second, 10*time.Second, 2)
    43  	expect := []time.Duration{
    44  		10 * time.Second,
    45  		10 * time.Second,
    46  		10 * time.Second,
    47  		10 * time.Second,
    48  	}
    49  	for retries, wait := range expect {
    50  		// Use a large number of retries that might overflow exponential
    51  		// calculations.
    52  		keepgoing, dur := policy.Retry(1000 + retries)
    53  		if !keepgoing {
    54  			t.Fatal("!keepgoing")
    55  		}
    56  		if got, want := dur, wait; got != want {
    57  			t.Errorf("retry %d: got %v, want %v", retries, got, want)
    58  		}
    59  	}
    60  }
    61  
    62  func TestBackoffWithFullJitter(t *testing.T) {
    63  	policy := Jitter(Backoff(time.Second, 10*time.Second, 2), 1.0)
    64  	checkWithin := func(t *testing.T, wantMin, wantMax, got time.Duration) {
    65  		if got < wantMin || got > wantMax {
    66  			t.Errorf("got %v, want within (%v, %v)", got, wantMin, wantMax)
    67  		}
    68  	}
    69  	expect := []time.Duration{
    70  		time.Second,
    71  		2 * time.Second,
    72  		4 * time.Second,
    73  		8 * time.Second,
    74  		10 * time.Second,
    75  		10 * time.Second,
    76  	}
    77  	for retries, wait := range expect {
    78  		keepgoing, dur := policy.Retry(retries)
    79  		if !keepgoing {
    80  			t.Fatal("!keepgoing")
    81  		}
    82  		checkWithin(t, 0, wait, dur)
    83  	}
    84  }
    85  
    86  func TestBackoffWithEqualJitter(t *testing.T) {
    87  	policy := Jitter(Backoff(time.Second, 10*time.Second, 2), 0.5)
    88  	checkWithin := func(t *testing.T, wantMin, wantMax, got time.Duration) {
    89  		if got < wantMin || got > wantMax {
    90  			t.Errorf("got %v, want within (%v, %v)", got, wantMin, wantMax)
    91  		}
    92  	}
    93  	expect := []time.Duration{
    94  		time.Second,
    95  		2 * time.Second,
    96  		4 * time.Second,
    97  		8 * time.Second,
    98  		10 * time.Second,
    99  		10 * time.Second,
   100  	}
   101  	for retries, wait := range expect {
   102  		keepgoing, dur := policy.Retry(retries)
   103  		if !keepgoing {
   104  			t.Fatal("!keepgoing")
   105  		}
   106  		checkWithin(t, wait/2, wait, dur)
   107  	}
   108  }
   109  
   110  func TestBackoffWithTimeout(t *testing.T) {
   111  	policy := BackoffWithTimeout(time.Second, 10*time.Second, 2)
   112  	expect := []time.Duration{
   113  		time.Second,
   114  		2 * time.Second,
   115  		4 * time.Second,
   116  		8 * time.Second,
   117  	}
   118  	var retries = 0
   119  	for _, wait := range expect {
   120  		keepgoing, dur := policy.Retry(retries)
   121  		if !keepgoing {
   122  			t.Fatal("!keepgoing")
   123  		}
   124  		if got, want := dur, wait; got != want {
   125  			t.Errorf("retry %d: got %v, want %v", retries, got, want)
   126  		}
   127  		retries++
   128  	}
   129  	keepgoing, _ := policy.Retry(retries)
   130  	if keepgoing {
   131  		t.Errorf("keepgoing: got %v, want %v", keepgoing, false)
   132  	}
   133  
   134  }
   135  
   136  func TestWaitCancel(t *testing.T) {
   137  	ctx, cancel := context.WithCancel(context.Background())
   138  	policy := Backoff(time.Hour, time.Hour, 1)
   139  	cancel()
   140  	if got, want := Wait(ctx, policy, 0), context.Canceled; got != want {
   141  		t.Errorf("got %v, want %v", got, want)
   142  	}
   143  }
   144  
   145  func TestWaitDeadline(t *testing.T) {
   146  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   147  	defer cancel()
   148  	policy := Backoff(time.Hour, time.Hour, 1)
   149  	if got, want := Wait(ctx, policy, 0), errors.E(errors.Timeout); !errors.Match(want, got) {
   150  		t.Errorf("got %v, want %v", got, want)
   151  	}
   152  }
   153  
   154  func testWrapperHelper(i int) (int, error) {
   155  	if i == 0 {
   156  		return 0, fmt.Errorf("This is an Error")
   157  	}
   158  	return 9999, nil
   159  }
   160  
   161  func testWrapperHelperLong(i int) (int, int, error) {
   162  	if i == 0 {
   163  		return 0, 0, fmt.Errorf("This is an Error")
   164  	}
   165  	return 1, 2, nil
   166  }
   167  
   168  func TestWaitForFn(t *testing.T) {
   169  	ctx, cancel := context.WithCancel(context.Background())
   170  	policy := Backoff(time.Hour, time.Hour, 1)
   171  	cancel()
   172  
   173  	output := WaitForFn(ctx, policy, testWrapperHelper, 0)
   174  	require.EqualError(t, output[1].Interface().(error), "This is an Error")
   175  
   176  	output = WaitForFn(ctx, policy, testWrapperHelper, 55)
   177  	require.Equal(t, 9999, int(output[0].Int()))
   178  
   179  	var err error
   180  	defer func() {
   181  		if r := recover(); r != nil {
   182  			err = fmt.Errorf("wrong number of input, expected: 1, actual: 3")
   183  		}
   184  	}()
   185  	WaitForFn(ctx, policy, testWrapperHelper, 1, 2, 3)
   186  	require.EqualError(t, err, "wrong number of input, expected: 1, actual: 3")
   187  }
   188  
   189  func TestWaitForFnLong(t *testing.T) {
   190  	ctx, cancel := context.WithCancel(context.Background())
   191  	policy := Backoff(time.Hour, time.Hour, 1)
   192  	cancel()
   193  
   194  	output := WaitForFn(ctx, policy, testWrapperHelperLong, 0)
   195  	require.EqualError(t, output[2].Interface().(error), "This is an Error")
   196  
   197  	output = WaitForFn(ctx, policy, testWrapperHelperLong, 55)
   198  	require.Equal(t, 1, int(output[0].Int()))
   199  	require.Equal(t, 2, int(output[1].Int()))
   200  
   201  }
   202  
   203  func TestMaxRetries(t *testing.T) {
   204  	retryImmediately := Backoff(0, 0, 0)
   205  
   206  	type testArgs struct {
   207  		retryPolicy Policy
   208  		fn          func(*int) error
   209  	}
   210  	testCases := []struct {
   211  		testName string
   212  		args     testArgs
   213  		expected int
   214  	}{
   215  		{
   216  			testName: "function always fails",
   217  			args: testArgs{
   218  				retryPolicy: MaxRetries(retryImmediately, 1),
   219  				fn: func(callCount *int) error {
   220  					*callCount++
   221  
   222  					return fmt.Errorf("always fail")
   223  				},
   224  			},
   225  			expected: 2,
   226  		},
   227  	}
   228  
   229  	for _, tc := range testCases {
   230  		t.Run(tc.testName, func(t *testing.T) {
   231  			callCount := 0
   232  
   233  			WaitForFn(context.Background(), tc.args.retryPolicy, tc.args.fn, &callCount)
   234  
   235  			require.Equal(t, tc.expected, callCount)
   236  		})
   237  	}
   238  }