github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/retry/sql_test.go (about) 1 package retry 2 3 import ( 4 "context" 5 "database/sql" 6 "database/sql/driver" 7 "testing" 8 "time" 9 10 "github.com/ydb-platform/ydb-go-sdk/v3/internal/backoff" 11 "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" 12 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn" 14 "github.com/ydb-platform/ydb-go-sdk/v3/trace" 15 ) 16 17 type mockConnector struct { 18 t testing.TB 19 conns uint32 20 queryErr error 21 execErr error 22 } 23 24 var _ driver.Connector = &mockConnector{} 25 26 func (m *mockConnector) Open(name string) (driver.Conn, error) { 27 m.t.Log(stack.Record(0)) 28 29 return nil, driver.ErrSkip 30 } 31 32 func (m *mockConnector) Connect(ctx context.Context) (driver.Conn, error) { 33 m.t.Log(stack.Record(0)) 34 m.conns++ 35 36 return &mockConn{ 37 t: m.t, 38 queryErr: m.queryErr, 39 execErr: m.execErr, 40 }, nil 41 } 42 43 func (m *mockConnector) Driver() driver.Driver { 44 m.t.Log(stack.Record(0)) 45 46 return m 47 } 48 49 type mockConn struct { 50 t testing.TB 51 queryErr error 52 execErr error 53 closed bool 54 } 55 56 var ( 57 _ driver.Conn = &mockConn{} 58 _ driver.ConnPrepareContext = &mockConn{} 59 _ driver.ConnBeginTx = &mockConn{} 60 _ driver.ExecerContext = &mockConn{} 61 _ driver.QueryerContext = &mockConn{} 62 ) 63 64 func (m *mockConn) Prepare(query string) (driver.Stmt, error) { 65 m.t.Log(stack.Record(0)) 66 67 return nil, driver.ErrSkip 68 } 69 70 func (m *mockConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 71 m.t.Log(stack.Record(0)) 72 if m.closed { 73 return nil, driver.ErrBadConn 74 } 75 76 return &mockStmt{ 77 t: m.t, 78 conn: m, 79 query: query, 80 }, nil 81 } 82 83 func (m *mockConn) Close() error { 84 m.t.Log(stack.Record(0)) 85 m.closed = true 86 87 return nil 88 } 89 90 func (m *mockConn) Begin() (driver.Tx, error) { 91 m.t.Log(stack.Record(0)) 92 93 return nil, driver.ErrSkip 94 } 95 96 func (m *mockConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 97 m.t.Log(stack.Record(0)) 98 if m.closed { 99 return nil, driver.ErrBadConn 100 } 101 102 return m, nil 103 } 104 105 func (m *mockConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 106 m.t.Log(stack.Record(0)) 107 if xerrors.MustDeleteSession(m.execErr) { 108 m.closed = true 109 } 110 111 return nil, m.queryErr 112 } 113 114 func (m *mockConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 115 m.t.Log(stack.Record(0)) 116 if xerrors.MustDeleteSession(m.execErr) { 117 m.closed = true 118 } 119 120 return nil, m.execErr 121 } 122 123 func (m *mockConn) Commit() error { 124 m.t.Log(stack.Record(0)) 125 126 return nil 127 } 128 129 func (m *mockConn) Rollback() error { 130 m.t.Log(stack.Record(0)) 131 132 return nil 133 } 134 135 type mockStmt struct { 136 t testing.TB 137 conn *mockConn 138 query string 139 } 140 141 var ( 142 _ driver.Stmt = &mockStmt{} 143 _ driver.StmtExecContext = &mockStmt{} 144 _ driver.StmtQueryContext = &mockStmt{} 145 ) 146 147 func (m *mockStmt) Close() error { 148 m.t.Log(stack.Record(0)) 149 150 return nil 151 } 152 153 func (m *mockStmt) NumInput() int { 154 m.t.Log(stack.Record(0)) 155 156 return -1 157 } 158 159 func (m *mockStmt) Exec(args []driver.Value) (driver.Result, error) { 160 m.t.Log(stack.Record(0)) 161 162 return nil, driver.ErrSkip 163 } 164 165 func (m *mockStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 166 m.t.Log(stack.Record(0)) 167 168 return m.conn.ExecContext(ctx, m.query, args) 169 } 170 171 func (m *mockStmt) Query(args []driver.Value) (driver.Rows, error) { 172 m.t.Log(stack.Record(0)) 173 174 return nil, driver.ErrSkip 175 } 176 177 func (m *mockStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 178 m.t.Log(stack.Record(0)) 179 180 return m.conn.QueryContext(ctx, m.query, args) 181 } 182 183 //nolint:nestif 184 func TestDoTx(t *testing.T) { 185 for _, idempotentType := range []idempotency{ 186 idempotent, 187 nonIdempotent, 188 } { 189 t.Run(idempotentType.String(), func(t *testing.T) { 190 for _, tt := range errsToCheck { 191 t.Run(tt.err.Error(), func(t *testing.T) { 192 m := &mockConnector{ 193 t: t, 194 queryErr: badconn.Map(tt.err), 195 execErr: badconn.Map(tt.err), 196 } 197 db := sql.OpenDB(m) 198 var attempts int 199 err := DoTx(context.Background(), db, 200 func(ctx context.Context, tx *sql.Tx) error { 201 attempts++ 202 if attempts > 10 { 203 return nil 204 } 205 rows, err := tx.QueryContext(ctx, "SELECT 1") 206 if err != nil { 207 return err 208 } 209 defer func() { 210 _ = rows.Close() 211 }() 212 213 return rows.Err() 214 }, 215 WithIdempotent(bool(idempotentType)), 216 WithFastBackoff(backoff.New(backoff.WithSlotDuration(time.Nanosecond))), 217 WithSlowBackoff(backoff.New(backoff.WithSlotDuration(time.Nanosecond))), 218 WithTrace(&trace.Retry{ 219 //nolint:lll 220 OnRetry: func(info trace.RetryLoopStartInfo) func(trace.RetryLoopIntermediateInfo) func(trace.RetryLoopDoneInfo) { 221 t.Logf("attempt %d, conn %d, mode: %+v", attempts, m.conns, Check(m.queryErr)) 222 223 return func(info trace.RetryLoopIntermediateInfo) func(trace.RetryLoopDoneInfo) { 224 t.Logf("attempt %d, conn %d, mode: %+v", attempts, m.conns, Check(m.queryErr)) 225 226 return nil 227 } 228 }, 229 }), 230 ) 231 if tt.canRetry[idempotentType] { 232 if err != nil { 233 t.Errorf("unexpected err after attempts=%d and driver conns=%d: %v)", attempts, m.conns, err) 234 } 235 if attempts <= 1 { 236 t.Errorf("must be attempts > 1 (actual=%d), driver conns=%d)", attempts, m.conns) 237 } 238 if tt.deleteSession { 239 if m.conns <= 1 { 240 t.Errorf("must be retry on different conns (attempts=%d, driver conns=%d)", attempts, m.conns) 241 } 242 } else { 243 if m.conns > 1 { 244 t.Errorf("must be retry on single conn (attempts=%d, driver conns=%d)", attempts, m.conns) 245 } 246 } 247 } else if err == nil { 248 t.Errorf("unexpected nil err (attempts=%d, driver conns=%d)", attempts, m.conns) 249 } 250 }) 251 } 252 }) 253 } 254 }