github.com/thiagoyeds/go-cloud@v0.26.0/internal/retry/retry_test.go (about)

     1  // Copyright 2018 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package retry
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"os"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/googleapis/gax-go/v2"
    25  	"golang.org/x/xerrors"
    26  )
    27  
    28  // Errors to distinguish retryable and non-retryable cases.
    29  var (
    30  	errRetry   = errors.New("retry")
    31  	errNoRetry = errors.New("no retry")
    32  )
    33  
    34  func retryable(err error) bool {
    35  	return err == errRetry
    36  }
    37  
    38  func TestCall(t *testing.T) {
    39  	for _, test := range []struct {
    40  		desc        string
    41  		isRetryable func(error) bool
    42  		f           func(int) error // passed the number of calls so far
    43  		wantErr     error           // the return value of call
    44  		wantCount   int             // number of times f is called
    45  	}{
    46  		{
    47  			desc:        "f returns nil",
    48  			isRetryable: retryable,
    49  			f:           func(int) error { return nil },
    50  			wantCount:   1,
    51  			wantErr:     nil,
    52  		},
    53  		{
    54  			desc:        "f returns non-retryable error",
    55  			isRetryable: retryable,
    56  			f:           func(int) error { return errNoRetry },
    57  			wantCount:   1,
    58  			wantErr:     errNoRetry,
    59  		},
    60  		{
    61  			desc:        "f returns retryable error",
    62  			isRetryable: retryable,
    63  			f: func(n int) error {
    64  				if n < 2 {
    65  					return errRetry
    66  				}
    67  				return errNoRetry
    68  			},
    69  			wantCount: 3,
    70  			wantErr:   errNoRetry,
    71  		},
    72  		{
    73  			desc:        "f returns context error", // same as non-retryable
    74  			isRetryable: retryable,
    75  			f:           func(int) error { return context.Canceled },
    76  			wantCount:   1,
    77  			wantErr:     context.Canceled,
    78  		},
    79  	} {
    80  		t.Run(test.desc, func(t *testing.T) {
    81  			sleep := func(context.Context, time.Duration) error { return nil }
    82  			gotCount := 0
    83  			f := func() error { gotCount++; return test.f(gotCount - 1) }
    84  			gotErr := call(context.Background(), gax.Backoff{}, test.isRetryable, f, sleep)
    85  			if gotErr != test.wantErr {
    86  				t.Errorf("error: got %v, want %v", gotErr, test.wantErr)
    87  			}
    88  			if gotCount != test.wantCount {
    89  				t.Errorf("retry count: got %d, want %d", gotCount, test.wantCount)
    90  			}
    91  		})
    92  	}
    93  }
    94  
    95  func TestCallCancel(t *testing.T) {
    96  	t.Run("done on entry", func(t *testing.T) {
    97  		// If the context is done on entry, f is never called.
    98  		ctx, cancel := context.WithCancel(context.Background())
    99  		cancel()
   100  		gotCount := 0
   101  		f := func() error { gotCount++; return nil }
   102  		gotErr := call(ctx, gax.Backoff{}, retryable, f, nil)
   103  		if gotCount != 0 {
   104  			t.Errorf("retry count: got %d, want 0", gotCount)
   105  		}
   106  		wantErr := &ContextError{CtxErr: context.Canceled}
   107  		if !equalContextError(gotErr, wantErr) {
   108  			t.Errorf("error: got %v, want %v", gotErr, wantErr)
   109  		}
   110  	})
   111  	t.Run("done in sleep", func(t *testing.T) {
   112  		// If the context is done during sleep, we get a ContextError.
   113  		gotCount := 0
   114  		f := func() error { gotCount++; return errRetry }
   115  		sleep := func(context.Context, time.Duration) error { return context.Canceled }
   116  		gotErr := call(context.Background(), gax.Backoff{}, retryable, f, sleep)
   117  		if gotCount != 1 {
   118  			t.Errorf("retry count: got %d, want 1", gotCount)
   119  		}
   120  		wantErr := &ContextError{CtxErr: context.Canceled, FuncErr: errRetry}
   121  		if !equalContextError(gotErr, wantErr) {
   122  			t.Errorf("error: got %v, want %v", gotErr, wantErr)
   123  		}
   124  	})
   125  }
   126  
   127  func equalContextError(got error, want *ContextError) bool {
   128  	cerr, ok := got.(*ContextError)
   129  	if !ok {
   130  		return false
   131  	}
   132  	return cerr.CtxErr == want.CtxErr && cerr.FuncErr == want.FuncErr
   133  }
   134  
   135  func TestErrorsIs(t *testing.T) {
   136  	err := &ContextError{
   137  		CtxErr:  context.Canceled,
   138  		FuncErr: os.ErrExist,
   139  	}
   140  	for _, target := range []error{err, context.Canceled, os.ErrExist} {
   141  		if !xerrors.Is(err, target) {
   142  			t.Errorf("xerrors.Is(%v) == false, want true", target)
   143  		}
   144  	}
   145  }