github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/retry/sql_test.go (about)

     1  package retry
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"database/sql/driver"
     7  	"strconv"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/backoff"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn"
    17  )
    18  
    19  type mockConnector struct {
    20  	t        testing.TB
    21  	conns    uint32
    22  	queryErr error
    23  	execErr  error
    24  }
    25  
    26  var _ driver.Connector = &mockConnector{}
    27  
    28  func (m *mockConnector) Open(name string) (driver.Conn, error) {
    29  	m.t.Log(stack.Record(0))
    30  
    31  	return nil, driver.ErrSkip
    32  }
    33  
    34  func (m *mockConnector) Connect(ctx context.Context) (driver.Conn, error) {
    35  	m.t.Log(stack.Record(0))
    36  	m.conns++
    37  
    38  	return &mockConn{
    39  		t:        m.t,
    40  		queryErr: m.queryErr,
    41  		execErr:  m.execErr,
    42  	}, nil
    43  }
    44  
    45  func (m *mockConnector) Driver() driver.Driver {
    46  	m.t.Log(stack.Record(0))
    47  
    48  	return m
    49  }
    50  
    51  type mockConn struct {
    52  	t        testing.TB
    53  	queryErr error
    54  	execErr  error
    55  	closed   bool
    56  }
    57  
    58  var (
    59  	_ driver.Conn               = &mockConn{}
    60  	_ driver.ConnPrepareContext = &mockConn{}
    61  	_ driver.ConnBeginTx        = &mockConn{}
    62  	_ driver.ExecerContext      = &mockConn{}
    63  	_ driver.QueryerContext     = &mockConn{}
    64  )
    65  
    66  func (m *mockConn) Prepare(query string) (driver.Stmt, error) {
    67  	m.t.Log(stack.Record(0))
    68  
    69  	return nil, driver.ErrSkip
    70  }
    71  
    72  func (m *mockConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
    73  	m.t.Log(stack.Record(0))
    74  	if m.closed {
    75  		return nil, driver.ErrBadConn
    76  	}
    77  
    78  	return &mockStmt{
    79  		t:     m.t,
    80  		conn:  m,
    81  		query: query,
    82  	}, nil
    83  }
    84  
    85  func (m *mockConn) Close() error {
    86  	m.t.Log(stack.Record(0))
    87  	m.closed = true
    88  
    89  	return nil
    90  }
    91  
    92  func (m *mockConn) Begin() (driver.Tx, error) {
    93  	m.t.Log(stack.Record(0))
    94  
    95  	return nil, driver.ErrSkip
    96  }
    97  
    98  func (m *mockConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
    99  	m.t.Log(stack.Record(0))
   100  	if m.closed {
   101  		return nil, driver.ErrBadConn
   102  	}
   103  
   104  	return m, nil
   105  }
   106  
   107  func (m *mockConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   108  	m.t.Log(stack.Record(0))
   109  	if !xerrors.IsRetryObjectValid(m.execErr) {
   110  		m.closed = true
   111  	}
   112  
   113  	return nil, m.queryErr
   114  }
   115  
   116  func (m *mockConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   117  	m.t.Log(stack.Record(0))
   118  	if !xerrors.IsRetryObjectValid(m.execErr) {
   119  		m.closed = true
   120  	}
   121  
   122  	return nil, m.execErr
   123  }
   124  
   125  func (m *mockConn) Commit() error {
   126  	m.t.Log(stack.Record(0))
   127  
   128  	return nil
   129  }
   130  
   131  func (m *mockConn) Rollback() error {
   132  	m.t.Log(stack.Record(0))
   133  
   134  	return nil
   135  }
   136  
   137  type mockStmt struct {
   138  	t     testing.TB
   139  	conn  *mockConn
   140  	query string
   141  }
   142  
   143  var (
   144  	_ driver.Stmt             = &mockStmt{}
   145  	_ driver.StmtExecContext  = &mockStmt{}
   146  	_ driver.StmtQueryContext = &mockStmt{}
   147  )
   148  
   149  func (m *mockStmt) Close() error {
   150  	m.t.Log(stack.Record(0))
   151  
   152  	return nil
   153  }
   154  
   155  func (m *mockStmt) NumInput() int {
   156  	m.t.Log(stack.Record(0))
   157  
   158  	return -1
   159  }
   160  
   161  func (m *mockStmt) Exec(args []driver.Value) (driver.Result, error) {
   162  	m.t.Log(stack.Record(0))
   163  
   164  	return nil, driver.ErrSkip
   165  }
   166  
   167  func (m *mockStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   168  	m.t.Log(stack.Record(0))
   169  
   170  	return m.conn.ExecContext(ctx, m.query, args)
   171  }
   172  
   173  func (m *mockStmt) Query(args []driver.Value) (driver.Rows, error) {
   174  	m.t.Log(stack.Record(0))
   175  
   176  	return nil, driver.ErrSkip
   177  }
   178  
   179  func (m *mockStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   180  	m.t.Log(stack.Record(0))
   181  
   182  	return m.conn.QueryContext(ctx, m.query, args)
   183  }
   184  
   185  func TestDoTx(t *testing.T) {
   186  	for _, idempotentType := range []idempotency{
   187  		idempotent,
   188  		nonIdempotent,
   189  	} {
   190  		t.Run(idempotentType.String(), func(t *testing.T) {
   191  			for i, tt := range errsToCheck {
   192  				t.Run(strconv.Itoa(i)+"."+tt.err.Error(), func(t *testing.T) {
   193  					m := &mockConnector{
   194  						t:        t,
   195  						queryErr: badconn.Map(tt.err),
   196  						execErr:  badconn.Map(tt.err),
   197  					}
   198  					db := sql.OpenDB(m)
   199  					var attempts int
   200  					err := DoTx(context.Background(), db,
   201  						func(ctx context.Context, tx *sql.Tx) error {
   202  							attempts++
   203  							if attempts > 10 {
   204  								return nil
   205  							}
   206  							rows, err := tx.QueryContext(ctx, "SELECT 1")
   207  							if err != nil {
   208  								return err
   209  							}
   210  							defer func() {
   211  								_ = rows.Close()
   212  							}()
   213  
   214  							return rows.Err()
   215  						},
   216  						WithIdempotent(bool(idempotentType)),
   217  						WithFastBackoff(backoff.New(backoff.WithSlotDuration(time.Nanosecond))),
   218  						WithSlowBackoff(backoff.New(backoff.WithSlotDuration(time.Nanosecond))),
   219  					)
   220  					if tt.canRetry[idempotentType] {
   221  						require.NoError(t, err)
   222  						require.NotEmpty(t, attempts)
   223  						if tt.deleteSession {
   224  							require.Greater(t, m.conns, uint32(1))
   225  						} else {
   226  							require.Equal(t, uint32(1), m.conns)
   227  						}
   228  					} else {
   229  						require.Error(t, err)
   230  					}
   231  				})
   232  			}
   233  		})
   234  	}
   235  }