github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/retry/sql_test.go (about) 1 package retry 2 3 import ( 4 "context" 5 "database/sql" 6 "database/sql/driver" 7 "strconv" 8 "testing" 9 "time" 10 11 "github.com/stretchr/testify/require" 12 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/backoff" 14 "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" 15 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 16 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn" 17 ) 18 19 type mockConnector struct { 20 t testing.TB 21 conns uint32 22 queryErr error 23 execErr error 24 } 25 26 var _ driver.Connector = &mockConnector{} 27 28 func (m *mockConnector) Open(name string) (driver.Conn, error) { 29 m.t.Log(stack.Record(0)) 30 31 return nil, driver.ErrSkip 32 } 33 34 func (m *mockConnector) Connect(ctx context.Context) (driver.Conn, error) { 35 m.t.Log(stack.Record(0)) 36 m.conns++ 37 38 return &mockConn{ 39 t: m.t, 40 queryErr: m.queryErr, 41 execErr: m.execErr, 42 }, nil 43 } 44 45 func (m *mockConnector) Driver() driver.Driver { 46 m.t.Log(stack.Record(0)) 47 48 return m 49 } 50 51 type mockConn struct { 52 t testing.TB 53 queryErr error 54 execErr error 55 closed bool 56 } 57 58 var ( 59 _ driver.Conn = &mockConn{} 60 _ driver.ConnPrepareContext = &mockConn{} 61 _ driver.ConnBeginTx = &mockConn{} 62 _ driver.ExecerContext = &mockConn{} 63 _ driver.QueryerContext = &mockConn{} 64 ) 65 66 func (m *mockConn) Prepare(query string) (driver.Stmt, error) { 67 m.t.Log(stack.Record(0)) 68 69 return nil, driver.ErrSkip 70 } 71 72 func (m *mockConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 73 m.t.Log(stack.Record(0)) 74 if m.closed { 75 return nil, driver.ErrBadConn 76 } 77 78 return &mockStmt{ 79 t: m.t, 80 conn: m, 81 query: query, 82 }, nil 83 } 84 85 func (m *mockConn) Close() error { 86 m.t.Log(stack.Record(0)) 87 m.closed = true 88 89 return nil 90 } 91 92 func (m *mockConn) Begin() (driver.Tx, error) { 93 m.t.Log(stack.Record(0)) 94 95 return nil, driver.ErrSkip 96 } 97 98 func (m *mockConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 99 m.t.Log(stack.Record(0)) 100 if m.closed { 101 return nil, driver.ErrBadConn 102 } 103 104 return m, nil 105 } 106 107 func (m *mockConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 108 m.t.Log(stack.Record(0)) 109 if !xerrors.IsRetryObjectValid(m.execErr) { 110 m.closed = true 111 } 112 113 return nil, m.queryErr 114 } 115 116 func (m *mockConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 117 m.t.Log(stack.Record(0)) 118 if !xerrors.IsRetryObjectValid(m.execErr) { 119 m.closed = true 120 } 121 122 return nil, m.execErr 123 } 124 125 func (m *mockConn) Commit() error { 126 m.t.Log(stack.Record(0)) 127 128 return nil 129 } 130 131 func (m *mockConn) Rollback() error { 132 m.t.Log(stack.Record(0)) 133 134 return nil 135 } 136 137 type mockStmt struct { 138 t testing.TB 139 conn *mockConn 140 query string 141 } 142 143 var ( 144 _ driver.Stmt = &mockStmt{} 145 _ driver.StmtExecContext = &mockStmt{} 146 _ driver.StmtQueryContext = &mockStmt{} 147 ) 148 149 func (m *mockStmt) Close() error { 150 m.t.Log(stack.Record(0)) 151 152 return nil 153 } 154 155 func (m *mockStmt) NumInput() int { 156 m.t.Log(stack.Record(0)) 157 158 return -1 159 } 160 161 func (m *mockStmt) Exec(args []driver.Value) (driver.Result, error) { 162 m.t.Log(stack.Record(0)) 163 164 return nil, driver.ErrSkip 165 } 166 167 func (m *mockStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 168 m.t.Log(stack.Record(0)) 169 170 return m.conn.ExecContext(ctx, m.query, args) 171 } 172 173 func (m *mockStmt) Query(args []driver.Value) (driver.Rows, error) { 174 m.t.Log(stack.Record(0)) 175 176 return nil, driver.ErrSkip 177 } 178 179 func (m *mockStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 180 m.t.Log(stack.Record(0)) 181 182 return m.conn.QueryContext(ctx, m.query, args) 183 } 184 185 func TestDoTx(t *testing.T) { 186 for _, idempotentType := range []idempotency{ 187 idempotent, 188 nonIdempotent, 189 } { 190 t.Run(idempotentType.String(), func(t *testing.T) { 191 for i, tt := range errsToCheck { 192 t.Run(strconv.Itoa(i)+"."+tt.err.Error(), func(t *testing.T) { 193 m := &mockConnector{ 194 t: t, 195 queryErr: badconn.Map(tt.err), 196 execErr: badconn.Map(tt.err), 197 } 198 db := sql.OpenDB(m) 199 var attempts int 200 err := DoTx(context.Background(), db, 201 func(ctx context.Context, tx *sql.Tx) error { 202 attempts++ 203 if attempts > 10 { 204 return nil 205 } 206 rows, err := tx.QueryContext(ctx, "SELECT 1") 207 if err != nil { 208 return err 209 } 210 defer func() { 211 _ = rows.Close() 212 }() 213 214 return rows.Err() 215 }, 216 WithIdempotent(bool(idempotentType)), 217 WithFastBackoff(backoff.New(backoff.WithSlotDuration(time.Nanosecond))), 218 WithSlowBackoff(backoff.New(backoff.WithSlotDuration(time.Nanosecond))), 219 ) 220 if tt.canRetry[idempotentType] { 221 require.NoError(t, err) 222 require.NotEmpty(t, attempts) 223 if tt.deleteSession { 224 require.Greater(t, m.conns, uint32(1)) 225 } else { 226 require.Equal(t, uint32(1), m.conns) 227 } 228 } else { 229 require.Error(t, err) 230 } 231 }) 232 } 233 }) 234 } 235 }