github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/retry/retry_test.go (about) 1 package retry 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "strconv" 8 "testing" 9 "time" 10 11 "github.com/stretchr/testify/require" 12 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" 13 grpcCodes "google.golang.org/grpc/codes" 14 grpcStatus "google.golang.org/grpc/status" 15 16 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" 17 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 18 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest" 19 ) 20 21 func TestRetryModes(t *testing.T) { 22 for _, idempotentType := range []idempotency{ 23 idempotent, 24 nonIdempotent, 25 } { 26 t.Run(idempotentType.String(), func(t *testing.T) { 27 for i, tt := range errsToCheck { 28 t.Run(strconv.Itoa(i)+"."+tt.err.Error(), func(t *testing.T) { 29 m := Check(tt.err) 30 require.Equal(t, tt.canRetry[idempotent], m.MustRetry(true)) 31 require.Equal(t, tt.canRetry[nonIdempotent], m.MustRetry(false)) 32 require.Equal(t, tt.backoff, m.BackoffType()) 33 require.Equal(t, tt.deleteSession, m.MustDeleteSession()) 34 }) 35 } 36 }) 37 } 38 } 39 40 type CustomError struct { 41 Err error 42 } 43 44 func (e *CustomError) Error() string { 45 return fmt.Sprintf("custom error: %v", e.Err) 46 } 47 48 func (e *CustomError) Unwrap() error { 49 return e.Err 50 } 51 52 func TestRetryWithCustomErrors(t *testing.T) { 53 var ( 54 limit = 10 55 ctx = context.Background() 56 ) 57 for _, tt := range []struct { 58 error error 59 retriable bool 60 }{ 61 { 62 error: &CustomError{ 63 Err: RetryableError( 64 fmt.Errorf("custom error"), 65 WithDeleteSession(), 66 ), 67 }, 68 retriable: true, 69 }, 70 { 71 error: &CustomError{ 72 Err: xerrors.Operation( 73 xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION), 74 ), 75 }, 76 retriable: true, 77 }, 78 { 79 error: &CustomError{ 80 Err: fmt.Errorf( 81 "wrapped error: %w", 82 xerrors.Operation( 83 xerrors.WithStatusCode(Ydb.StatusIds_BAD_SESSION), 84 ), 85 ), 86 }, 87 retriable: true, 88 }, 89 { 90 error: &CustomError{ 91 Err: fmt.Errorf( 92 "wrapped error: %w", 93 xerrors.Operation( 94 xerrors.WithStatusCode(Ydb.StatusIds_UNAUTHORIZED), 95 ), 96 ), 97 }, 98 retriable: false, 99 }, 100 } { 101 t.Run(tt.error.Error(), func(t *testing.T) { 102 i := 0 103 err := Retry(ctx, func(ctx context.Context) error { 104 i++ 105 if i < limit { 106 return tt.error 107 } 108 109 return nil 110 }) 111 if tt.retriable { 112 if i != limit { 113 t.Fatalf("unexpected i: %d, queryErr: %v", i, err) 114 } 115 } else { 116 if i != 1 { 117 t.Fatalf("unexpected i: %d, queryErr: %v", i, err) 118 } 119 } 120 }) 121 } 122 } 123 124 func TestRetryTransportDeadlineExceeded(t *testing.T) { 125 cancelCounterValue := 5 126 for _, code := range []grpcCodes.Code{ 127 grpcCodes.DeadlineExceeded, 128 grpcCodes.Canceled, 129 } { 130 t.Run(code.String(), func(t *testing.T) { 131 counter := 0 132 ctx, cancel := xcontext.WithTimeout(context.Background(), time.Hour) 133 err := Retry(ctx, func(ctx context.Context) error { 134 counter++ 135 if !(counter < cancelCounterValue) { 136 cancel() 137 } 138 139 return xerrors.Transport(grpcStatus.Error(code, "")) 140 }, WithIdempotent(true)) 141 require.ErrorIs(t, err, context.Canceled) 142 require.Equal(t, cancelCounterValue, counter) 143 }) 144 } 145 } 146 147 func TestRetryTransportCancelled(t *testing.T) { 148 cancelCounterValue := 5 149 for _, code := range []grpcCodes.Code{ 150 grpcCodes.DeadlineExceeded, 151 grpcCodes.Canceled, 152 } { 153 t.Run(code.String(), func(t *testing.T) { 154 t.Helper() 155 counter := 0 156 ctx, cancel := xcontext.WithCancel(context.Background()) 157 err := Retry(ctx, func(ctx context.Context) error { 158 counter++ 159 if !(counter < cancelCounterValue) { 160 cancel() 161 } 162 163 return xerrors.Transport(grpcStatus.Error(code, "")) 164 }, WithIdempotent(true)) 165 require.ErrorIs(t, err, context.Canceled) 166 require.Equal(t, cancelCounterValue, counter) 167 }) 168 } 169 } 170 171 type noQuota struct{} 172 173 var errNoQuota = errors.New("no quota") 174 175 func (noQuota) Acquire(ctx context.Context) error { 176 return errNoQuota 177 } 178 179 func TestRetryWithBudget(t *testing.T) { 180 xtest.TestManyTimes(t, func(t testing.TB) { 181 quota := noQuota{} 182 ctx, cancel := context.WithCancel(xtest.Context(t)) 183 defer cancel() 184 err := Retry(ctx, func(ctx context.Context) (err error) { 185 return RetryableError(errors.New("custom error")) 186 }, WithBudget(quota)) 187 require.ErrorIs(t, err, errNoQuota) 188 }) 189 } 190 191 type MockPanicCallback struct { 192 called bool 193 received interface{} 194 } 195 196 func (m *MockPanicCallback) Call(e interface{}) { 197 m.called = true 198 m.received = e 199 } 200 201 func TestOpWithRecover_NoPanic(t *testing.T) { 202 ctx := context.Background() 203 options := &retryOptions{ 204 panicCallback: nil, 205 } 206 op := func(ctx context.Context) (*struct{}, error) { 207 return nil, nil //nolint:nilnil 208 } 209 210 _, err := opWithRecover(ctx, options, op) 211 212 require.NoError(t, err) 213 } 214 215 func TestOpWithRecover_WithPanic(t *testing.T) { 216 ctx := context.Background() 217 mockCallback := new(MockPanicCallback) 218 options := &retryOptions{ 219 panicCallback: mockCallback.Call, 220 } 221 op := func(ctx context.Context) (*struct{}, error) { 222 panic("test panic") 223 } 224 225 _, err := opWithRecover(ctx, options, op) 226 227 require.Error(t, err) 228 require.Contains(t, err.Error(), "panic recovered: test panic") 229 require.True(t, mockCallback.called) 230 require.Equal(t, "test panic", mockCallback.received) 231 } 232 233 func TestRetryWithResult(t *testing.T) { 234 ctx := xtest.Context(t) 235 t.Run("HappyWay", func(t *testing.T) { 236 v, err := RetryWithResult(ctx, func(ctx context.Context) (*int, error) { 237 v := 123 238 239 return &v, nil 240 }) 241 require.NoError(t, err) 242 require.NotNil(t, v) 243 require.EqualValues(t, 123, *v) 244 }) 245 t.Run("RetryableError", func(t *testing.T) { 246 var counter int 247 v, err := RetryWithResult(ctx, func(ctx context.Context) (*int, error) { 248 counter++ 249 if counter < 10 { 250 return nil, RetryableError(errors.New("test")) 251 } 252 v := counter * 123 253 254 return &v, nil 255 }) 256 require.NoError(t, err) 257 require.NotNil(t, v) 258 require.EqualValues(t, 1230, *v) 259 require.EqualValues(t, 10, counter) 260 }) 261 t.Run("Context", func(t *testing.T) { 262 t.Run("Cancelled", func(t *testing.T) { 263 childCtx, cancel := context.WithCancel(ctx) 264 v, err := RetryWithResult(childCtx, func(ctx context.Context) (*int, error) { 265 cancel() 266 267 return nil, ctx.Err() 268 }) 269 require.ErrorIs(t, err, context.Canceled) 270 require.Nil(t, v) 271 }) 272 t.Run("DeadlineExceeded", func(t *testing.T) { 273 childCtx, cancel := context.WithTimeout(ctx, 0) 274 v, err := RetryWithResult(childCtx, func(ctx context.Context) (*int, error) { 275 cancel() 276 277 return nil, ctx.Err() 278 }) 279 require.ErrorIs(t, err, context.DeadlineExceeded) 280 require.Nil(t, v) 281 }) 282 }) 283 }