github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/retry/retry_test.go (about)

     1  package retry
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"testing"
     8  	"time"
     9  
    10  	"google.golang.org/grpc/codes"
    11  	"google.golang.org/grpc/status"
    12  )
    13  
    14  func alwaysRetry(error) bool { return true }
    15  
    16  // failer returns an error until its counter reaches 0, at which point it returns finalErr, which is
    17  // nil by default (no error).
    18  type failer struct {
    19  	attempts int
    20  	finalErr error
    21  }
    22  
    23  func (f *failer) run() error {
    24  	f.attempts--
    25  	if f.attempts < 0 {
    26  		return f.finalErr
    27  	}
    28  	return errors.New("failing")
    29  }
    30  
    31  func policyString(bp BackoffPolicy) string {
    32  	base := fmt.Sprintf("baseDelay: %v, maxDelay: %v", bp.baseDelay, bp.maxDelay)
    33  	if bp.maxAttempts == 0 {
    34  		return base + ": unlimited retries"
    35  	}
    36  	return base + fmt.Sprintf(": max %d attempts", bp.maxAttempts)
    37  }
    38  
    39  func TestRetries(t *testing.T) {
    40  	cases := []struct {
    41  		policy      BackoffPolicy
    42  		sr          ShouldRetry
    43  		attempts    int
    44  		finalErr    error
    45  		wantError   bool
    46  		wantErrCode codes.Code
    47  	}{
    48  		{
    49  			policy:   ExponentialBackoff(time.Millisecond, time.Millisecond, UnlimitedAttempts),
    50  			sr:       alwaysRetry,
    51  			attempts: 5,
    52  		},
    53  		{
    54  			policy:      ExponentialBackoff(time.Millisecond, time.Millisecond, 5),
    55  			sr:          alwaysRetry,
    56  			attempts:    5,
    57  			finalErr:    status.Error(codes.Unimplemented, "unimplemented!"),
    58  			wantError:   true,
    59  			wantErrCode: codes.Unimplemented,
    60  		},
    61  		{
    62  			policy:    ExponentialBackoff(time.Millisecond, time.Millisecond, 1),
    63  			sr:        alwaysRetry,
    64  			wantError: true,
    65  			attempts:  1,
    66  		},
    67  		{
    68  			policy:    ExponentialBackoff(time.Millisecond, time.Millisecond, 2),
    69  			sr:        alwaysRetry,
    70  			wantError: true,
    71  			attempts:  2,
    72  		},
    73  		{
    74  			policy:   ExponentialBackoff(time.Millisecond, time.Millisecond, 5),
    75  			sr:       alwaysRetry,
    76  			attempts: 5,
    77  		},
    78  	}
    79  	ctx := context.Background()
    80  	for _, c := range cases {
    81  		f := failer{
    82  			attempts: 4,
    83  			finalErr: c.finalErr,
    84  		}
    85  		err := WithPolicy(context.WithValue(ctx, TimeAfterContextKey, func(time.Duration) <-chan time.Time {
    86  			c := make(chan time.Time)
    87  			close(c) // Reading from the closed channel will immediately succeed.
    88  			return c
    89  		}), c.sr, c.policy, f.run)
    90  		attempts := 4 - f.attempts
    91  		if attempts != c.attempts {
    92  			t.Errorf("%s: expected %d attempts, got %d", policyString(c.policy), c.attempts, attempts)
    93  		}
    94  		switch {
    95  		case c.wantError:
    96  			if err == nil {
    97  				t.Errorf("%s: want error, got no error", policyString(c.policy))
    98  			}
    99  			if s, ok := status.FromError(err); c.wantErrCode != 0 && (!ok || s.Code() != c.wantErrCode) {
   100  				t.Errorf("%s: want error with code %v, got %v", policyString(c.policy), c.wantErrCode.String(), err)
   101  			}
   102  		case err != nil:
   103  			t.Errorf("%s: want success, got error: %v", policyString(c.policy), err)
   104  		}
   105  	}
   106  }