github.com/cenkalti/backoff/v4@v4.2.1/retry_test.go (about)

     1  package backoff
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"log"
     9  	"testing"
    10  	"time"
    11  )
    12  
    13  type testTimer struct {
    14  	timer *time.Timer
    15  }
    16  
    17  func (t *testTimer) Start(duration time.Duration) {
    18  	t.timer = time.NewTimer(0)
    19  }
    20  
    21  func (t *testTimer) Stop() {
    22  	if t.timer != nil {
    23  		t.timer.Stop()
    24  	}
    25  }
    26  
    27  func (t *testTimer) C() <-chan time.Time {
    28  	return t.timer.C
    29  }
    30  
    31  func TestRetry(t *testing.T) {
    32  	const successOn = 3
    33  	var i = 0
    34  
    35  	// This function is successful on "successOn" calls.
    36  	f := func() error {
    37  		i++
    38  		log.Printf("function is called %d. time\n", i)
    39  
    40  		if i == successOn {
    41  			log.Println("OK")
    42  			return nil
    43  		}
    44  
    45  		log.Println("error")
    46  		return errors.New("error")
    47  	}
    48  
    49  	err := RetryNotifyWithTimer(f, NewExponentialBackOff(), nil, &testTimer{})
    50  	if err != nil {
    51  		t.Errorf("unexpected error: %s", err.Error())
    52  	}
    53  	if i != successOn {
    54  		t.Errorf("invalid number of retries: %d", i)
    55  	}
    56  }
    57  
    58  func TestRetryWithData(t *testing.T) {
    59  	const successOn = 3
    60  	var i = 0
    61  
    62  	// This function is successful on "successOn" calls.
    63  	f := func() (int, error) {
    64  		i++
    65  		log.Printf("function is called %d. time\n", i)
    66  
    67  		if i == successOn {
    68  			log.Println("OK")
    69  			return 42, nil
    70  		}
    71  
    72  		log.Println("error")
    73  		return 1, errors.New("error")
    74  	}
    75  
    76  	res, err := RetryNotifyWithTimerAndData(f, NewExponentialBackOff(), nil, &testTimer{})
    77  	if err != nil {
    78  		t.Errorf("unexpected error: %s", err.Error())
    79  	}
    80  	if i != successOn {
    81  		t.Errorf("invalid number of retries: %d", i)
    82  	}
    83  	if res != 42 {
    84  		t.Errorf("invalid data in response: %d, expected 42", res)
    85  	}
    86  }
    87  
    88  func TestRetryContext(t *testing.T) {
    89  	var cancelOn = 3
    90  	var i = 0
    91  
    92  	ctx, cancel := context.WithCancel(context.Background())
    93  	defer cancel()
    94  
    95  	// This function cancels context on "cancelOn" calls.
    96  	f := func() error {
    97  		i++
    98  		log.Printf("function is called %d. time\n", i)
    99  
   100  		// cancelling the context in the operation function is not a typical
   101  		// use-case, however it allows to get predictable test results.
   102  		if i == cancelOn {
   103  			cancel()
   104  		}
   105  
   106  		log.Println("error")
   107  		return fmt.Errorf("error (%d)", i)
   108  	}
   109  
   110  	err := RetryNotifyWithTimer(f, WithContext(NewConstantBackOff(time.Millisecond), ctx), nil, &testTimer{})
   111  	if err == nil {
   112  		t.Errorf("error is unexpectedly nil")
   113  	}
   114  	if !errors.Is(err, context.Canceled) {
   115  		t.Errorf("unexpected error: %s", err.Error())
   116  	}
   117  	if i != cancelOn {
   118  		t.Errorf("invalid number of retries: %d", i)
   119  	}
   120  }
   121  
   122  func TestRetryPermanent(t *testing.T) {
   123  	ensureRetries := func(test string, shouldRetry bool, f func() (int, error), expectRes int) {
   124  		numRetries := -1
   125  		maxRetries := 1
   126  
   127  		res, _ := RetryNotifyWithTimerAndData(
   128  			func() (int, error) {
   129  				numRetries++
   130  				if numRetries >= maxRetries {
   131  					return -1, Permanent(errors.New("forced"))
   132  				}
   133  				return f()
   134  			},
   135  			NewExponentialBackOff(),
   136  			nil,
   137  			&testTimer{},
   138  		)
   139  
   140  		if shouldRetry && numRetries == 0 {
   141  			t.Errorf("Test: '%s', backoff should have retried", test)
   142  		}
   143  
   144  		if !shouldRetry && numRetries > 0 {
   145  			t.Errorf("Test: '%s', backoff should not have retried", test)
   146  		}
   147  
   148  		if res != expectRes {
   149  			t.Errorf("Test: '%s', got res %d but expected %d", test, res, expectRes)
   150  		}
   151  	}
   152  
   153  	for _, testCase := range []struct {
   154  		name        string
   155  		f           func() (int, error)
   156  		shouldRetry bool
   157  		res         int
   158  	}{
   159  		{
   160  			"nil test",
   161  			func() (int, error) {
   162  				return 1, nil
   163  			},
   164  			false,
   165  			1,
   166  		},
   167  		{
   168  			"io.EOF",
   169  			func() (int, error) {
   170  				return 2, io.EOF
   171  			},
   172  			true,
   173  			-1,
   174  		},
   175  		{
   176  			"Permanent(io.EOF)",
   177  			func() (int, error) {
   178  				return 3, Permanent(io.EOF)
   179  			},
   180  			false,
   181  			3,
   182  		},
   183  		{
   184  			"Wrapped: Permanent(io.EOF)",
   185  			func() (int, error) {
   186  				return 4, fmt.Errorf("Wrapped error: %w", Permanent(io.EOF))
   187  			},
   188  			false,
   189  			4,
   190  		},
   191  	} {
   192  		ensureRetries(testCase.name, testCase.shouldRetry, testCase.f, testCase.res)
   193  	}
   194  }
   195  
   196  func TestPermanent(t *testing.T) {
   197  	want := errors.New("foo")
   198  	other := errors.New("bar")
   199  	var err error = Permanent(want)
   200  
   201  	got := errors.Unwrap(err)
   202  	if got != want {
   203  		t.Errorf("got %v, want %v", got, want)
   204  	}
   205  
   206  	if is := errors.Is(err, want); !is {
   207  		t.Errorf("err: %v is not %v", err, want)
   208  	}
   209  
   210  	if is := errors.Is(err, other); is {
   211  		t.Errorf("err: %v is %v", err, other)
   212  	}
   213  
   214  	wrapped := fmt.Errorf("wrapped: %w", err)
   215  	var permanent *PermanentError
   216  	if !errors.As(wrapped, &permanent) {
   217  		t.Errorf("errors.As(%v, %v)", wrapped, permanent)
   218  	}
   219  
   220  	err = Permanent(nil)
   221  	if err != nil {
   222  		t.Errorf("got %v, want nil", err)
   223  	}
   224  }