github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/retry/retry_test.go (about)

     1  package retry
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"strconv"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/require"
    12  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
    13  	grpcCodes "google.golang.org/grpc/codes"
    14  	grpcStatus "google.golang.org/grpc/status"
    15  
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
    17  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    18  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
    19  )
    20  
    21  func TestRetryModes(t *testing.T) {
    22  	for _, idempotentType := range []idempotency{
    23  		idempotent,
    24  		nonIdempotent,
    25  	} {
    26  		t.Run(idempotentType.String(), func(t *testing.T) {
    27  			for i, tt := range errsToCheck {
    28  				t.Run(strconv.Itoa(i)+"."+tt.err.Error(), func(t *testing.T) {
    29  					m := Check(tt.err)
    30  					require.Equal(t, tt.canRetry[idempotent], m.MustRetry(true))
    31  					require.Equal(t, tt.canRetry[nonIdempotent], m.MustRetry(false))
    32  					require.Equal(t, tt.backoff, m.BackoffType())
    33  					require.Equal(t, tt.deleteSession, m.MustDeleteSession())
    34  				})
    35  			}
    36  		})
    37  	}
    38  }
    39  
    40  type CustomError struct {
    41  	Err error
    42  }
    43  
    44  func (e *CustomError) Error() string {
    45  	return fmt.Sprintf("custom error: %v", e.Err)
    46  }
    47  
    48  func (e *CustomError) Unwrap() error {
    49  	return e.Err
    50  }
    51  
    52  func TestRetryWithCustomErrors(t *testing.T) {
    53  	var (
    54  		limit = 10
    55  		ctx   = context.Background()
    56  	)
    57  	for _, tt := range []struct {
    58  		error     error
    59  		retriable bool
    60  	}{
    61  		{
    62  			error: &CustomError{
    63  				Err: RetryableError(
    64  					fmt.Errorf("custom error"),
    65  					WithDeleteSession(),
    66  				),
    67  			},
    68  			retriable: true,
    69  		},
    70  		{
    71  			error: &CustomError{
    72  				Err: xerrors.Operation(
    73  					xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION),
    74  				),
    75  			},
    76  			retriable: true,
    77  		},
    78  		{
    79  			error: &CustomError{
    80  				Err: fmt.Errorf(
    81  					"wrapped error: %w",
    82  					xerrors.Operation(
    83  						xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION),
    84  					),
    85  				),
    86  			},
    87  			retriable: true,
    88  		},
    89  		{
    90  			error: &CustomError{
    91  				Err: fmt.Errorf(
    92  					"wrapped error: %w",
    93  					xerrors.Operation(
    94  						xerrors.WithStatusCode(Ydb.StatusIds_UNAUTHORIZED),
    95  					),
    96  				),
    97  			},
    98  			retriable: false,
    99  		},
   100  	} {
   101  		t.Run(tt.error.Error(), func(t *testing.T) {
   102  			i := 0
   103  			err := Retry(ctx, func(ctx context.Context) error {
   104  				i++
   105  				if i < limit {
   106  					return tt.error
   107  				}
   108  
   109  				return nil
   110  			})
   111  			if tt.retriable {
   112  				if i != limit {
   113  					t.Fatalf("unexpected i: %d, queryErr: %v", i, err)
   114  				}
   115  			} else {
   116  				if i != 1 {
   117  					t.Fatalf("unexpected i: %d, queryErr: %v", i, err)
   118  				}
   119  			}
   120  		})
   121  	}
   122  }
   123  
   124  func TestRetryTransportDeadlineExceeded(t *testing.T) {
   125  	cancelCounterValue := 5
   126  	for _, code := range []grpcCodes.Code{
   127  		grpcCodes.DeadlineExceeded,
   128  		grpcCodes.Canceled,
   129  	} {
   130  		t.Run(code.String(), func(t *testing.T) {
   131  			counter := 0
   132  			ctx, cancel := xcontext.WithTimeout(context.Background(), time.Hour)
   133  			err := Retry(ctx, func(ctx context.Context) error {
   134  				counter++
   135  				if !(counter < cancelCounterValue) {
   136  					cancel()
   137  				}
   138  
   139  				return xerrors.Transport(grpcStatus.Error(code, ""))
   140  			}, WithIdempotent(true))
   141  			require.ErrorIs(t, err, context.Canceled)
   142  			require.Equal(t, cancelCounterValue, counter)
   143  		})
   144  	}
   145  }
   146  
   147  func TestRetryTransportCancelled(t *testing.T) {
   148  	cancelCounterValue := 5
   149  	for _, code := range []grpcCodes.Code{
   150  		grpcCodes.DeadlineExceeded,
   151  		grpcCodes.Canceled,
   152  	} {
   153  		t.Run(code.String(), func(t *testing.T) {
   154  			t.Helper()
   155  			counter := 0
   156  			ctx, cancel := xcontext.WithCancel(context.Background())
   157  			err := Retry(ctx, func(ctx context.Context) error {
   158  				counter++
   159  				if !(counter < cancelCounterValue) {
   160  					cancel()
   161  				}
   162  
   163  				return xerrors.Transport(grpcStatus.Error(code, ""))
   164  			}, WithIdempotent(true))
   165  			require.ErrorIs(t, err, context.Canceled)
   166  			require.Equal(t, cancelCounterValue, counter)
   167  		})
   168  	}
   169  }
   170  
   171  type noQuota struct{}
   172  
   173  var errNoQuota = errors.New("no quota")
   174  
   175  func (noQuota) Acquire(ctx context.Context) error {
   176  	return errNoQuota
   177  }
   178  
   179  func TestRetryWithBudget(t *testing.T) {
   180  	xtest.TestManyTimes(t, func(t testing.TB) {
   181  		quota := noQuota{}
   182  		ctx, cancel := context.WithCancel(xtest.Context(t))
   183  		defer cancel()
   184  		err := Retry(ctx, func(ctx context.Context) (err error) {
   185  			return RetryableError(errors.New("custom error"))
   186  		}, WithBudget(quota))
   187  		require.ErrorIs(t, err, errNoQuota)
   188  	})
   189  }
   190  
   191  type MockPanicCallback struct {
   192  	called   bool
   193  	received interface{}
   194  }
   195  
   196  func (m *MockPanicCallback) Call(e interface{}) {
   197  	m.called = true
   198  	m.received = e
   199  }
   200  
   201  func TestOpWithRecover_NoPanic(t *testing.T) {
   202  	ctx := context.Background()
   203  	options := &retryOptions{
   204  		panicCallback: nil,
   205  	}
   206  	op := func(ctx context.Context) (*struct{}, error) {
   207  		return nil, nil //nolint:nilnil
   208  	}
   209  
   210  	_, err := opWithRecover(ctx, options, op)
   211  
   212  	require.NoError(t, err)
   213  }
   214  
   215  func TestOpWithRecover_WithPanic(t *testing.T) {
   216  	ctx := context.Background()
   217  	mockCallback := new(MockPanicCallback)
   218  	options := &retryOptions{
   219  		panicCallback: mockCallback.Call,
   220  	}
   221  	op := func(ctx context.Context) (*struct{}, error) {
   222  		panic("test panic")
   223  	}
   224  
   225  	_, err := opWithRecover(ctx, options, op)
   226  
   227  	require.Error(t, err)
   228  	require.Contains(t, err.Error(), "panic recovered: test panic")
   229  	require.True(t, mockCallback.called)
   230  	require.Equal(t, "test panic", mockCallback.received)
   231  }
   232  
   233  func TestRetryWithResult(t *testing.T) {
   234  	ctx := xtest.Context(t)
   235  	t.Run("HappyWay", func(t *testing.T) {
   236  		v, err := RetryWithResult(ctx, func(ctx context.Context) (*int, error) {
   237  			v := 123
   238  
   239  			return &v, nil
   240  		})
   241  		require.NoError(t, err)
   242  		require.NotNil(t, v)
   243  		require.EqualValues(t, 123, *v)
   244  	})
   245  	t.Run("RetryableError", func(t *testing.T) {
   246  		var counter int
   247  		v, err := RetryWithResult(ctx, func(ctx context.Context) (*int, error) {
   248  			counter++
   249  			if counter < 10 {
   250  				return nil, RetryableError(errors.New("test"))
   251  			}
   252  			v := counter * 123
   253  
   254  			return &v, nil
   255  		})
   256  		require.NoError(t, err)
   257  		require.NotNil(t, v)
   258  		require.EqualValues(t, 1230, *v)
   259  		require.EqualValues(t, 10, counter)
   260  	})
   261  	t.Run("Context", func(t *testing.T) {
   262  		t.Run("Cancelled", func(t *testing.T) {
   263  			childCtx, cancel := context.WithCancel(ctx)
   264  			v, err := RetryWithResult(childCtx, func(ctx context.Context) (*int, error) {
   265  				cancel()
   266  
   267  				return nil, ctx.Err()
   268  			})
   269  			require.ErrorIs(t, err, context.Canceled)
   270  			require.Nil(t, v)
   271  		})
   272  		t.Run("DeadlineExceeded", func(t *testing.T) {
   273  			childCtx, cancel := context.WithTimeout(ctx, 0)
   274  			v, err := RetryWithResult(childCtx, func(ctx context.Context) (*int, error) {
   275  				cancel()
   276  
   277  				return nil, ctx.Err()
   278  			})
   279  			require.ErrorIs(t, err, context.DeadlineExceeded)
   280  			require.Nil(t, v)
   281  		})
   282  	})
   283  }