github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/retry/sql_test.go (about)

     1  package retry
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"database/sql/driver"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/backoff"
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    15  )
    16  
    17  type mockConnector struct {
    18  	t        testing.TB
    19  	conns    uint32
    20  	queryErr error
    21  	execErr  error
    22  }
    23  
    24  var _ driver.Connector = &mockConnector{}
    25  
    26  func (m *mockConnector) Open(name string) (driver.Conn, error) {
    27  	m.t.Log(stack.Record(0))
    28  
    29  	return nil, driver.ErrSkip
    30  }
    31  
    32  func (m *mockConnector) Connect(ctx context.Context) (driver.Conn, error) {
    33  	m.t.Log(stack.Record(0))
    34  	m.conns++
    35  
    36  	return &mockConn{
    37  		t:        m.t,
    38  		queryErr: m.queryErr,
    39  		execErr:  m.execErr,
    40  	}, nil
    41  }
    42  
    43  func (m *mockConnector) Driver() driver.Driver {
    44  	m.t.Log(stack.Record(0))
    45  
    46  	return m
    47  }
    48  
    49  type mockConn struct {
    50  	t        testing.TB
    51  	queryErr error
    52  	execErr  error
    53  	closed   bool
    54  }
    55  
    56  var (
    57  	_ driver.Conn               = &mockConn{}
    58  	_ driver.ConnPrepareContext = &mockConn{}
    59  	_ driver.ConnBeginTx        = &mockConn{}
    60  	_ driver.ExecerContext      = &mockConn{}
    61  	_ driver.QueryerContext     = &mockConn{}
    62  )
    63  
    64  func (m *mockConn) Prepare(query string) (driver.Stmt, error) {
    65  	m.t.Log(stack.Record(0))
    66  
    67  	return nil, driver.ErrSkip
    68  }
    69  
    70  func (m *mockConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
    71  	m.t.Log(stack.Record(0))
    72  	if m.closed {
    73  		return nil, driver.ErrBadConn
    74  	}
    75  
    76  	return &mockStmt{
    77  		t:     m.t,
    78  		conn:  m,
    79  		query: query,
    80  	}, nil
    81  }
    82  
    83  func (m *mockConn) Close() error {
    84  	m.t.Log(stack.Record(0))
    85  	m.closed = true
    86  
    87  	return nil
    88  }
    89  
    90  func (m *mockConn) Begin() (driver.Tx, error) {
    91  	m.t.Log(stack.Record(0))
    92  
    93  	return nil, driver.ErrSkip
    94  }
    95  
    96  func (m *mockConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
    97  	m.t.Log(stack.Record(0))
    98  	if m.closed {
    99  		return nil, driver.ErrBadConn
   100  	}
   101  
   102  	return m, nil
   103  }
   104  
   105  func (m *mockConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   106  	m.t.Log(stack.Record(0))
   107  	if xerrors.MustDeleteSession(m.execErr) {
   108  		m.closed = true
   109  	}
   110  
   111  	return nil, m.queryErr
   112  }
   113  
   114  func (m *mockConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   115  	m.t.Log(stack.Record(0))
   116  	if xerrors.MustDeleteSession(m.execErr) {
   117  		m.closed = true
   118  	}
   119  
   120  	return nil, m.execErr
   121  }
   122  
   123  func (m *mockConn) Commit() error {
   124  	m.t.Log(stack.Record(0))
   125  
   126  	return nil
   127  }
   128  
   129  func (m *mockConn) Rollback() error {
   130  	m.t.Log(stack.Record(0))
   131  
   132  	return nil
   133  }
   134  
   135  type mockStmt struct {
   136  	t     testing.TB
   137  	conn  *mockConn
   138  	query string
   139  }
   140  
   141  var (
   142  	_ driver.Stmt             = &mockStmt{}
   143  	_ driver.StmtExecContext  = &mockStmt{}
   144  	_ driver.StmtQueryContext = &mockStmt{}
   145  )
   146  
   147  func (m *mockStmt) Close() error {
   148  	m.t.Log(stack.Record(0))
   149  
   150  	return nil
   151  }
   152  
   153  func (m *mockStmt) NumInput() int {
   154  	m.t.Log(stack.Record(0))
   155  
   156  	return -1
   157  }
   158  
   159  func (m *mockStmt) Exec(args []driver.Value) (driver.Result, error) {
   160  	m.t.Log(stack.Record(0))
   161  
   162  	return nil, driver.ErrSkip
   163  }
   164  
   165  func (m *mockStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   166  	m.t.Log(stack.Record(0))
   167  
   168  	return m.conn.ExecContext(ctx, m.query, args)
   169  }
   170  
   171  func (m *mockStmt) Query(args []driver.Value) (driver.Rows, error) {
   172  	m.t.Log(stack.Record(0))
   173  
   174  	return nil, driver.ErrSkip
   175  }
   176  
   177  func (m *mockStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   178  	m.t.Log(stack.Record(0))
   179  
   180  	return m.conn.QueryContext(ctx, m.query, args)
   181  }
   182  
   183  //nolint:nestif
   184  func TestDoTx(t *testing.T) {
   185  	for _, idempotentType := range []idempotency{
   186  		idempotent,
   187  		nonIdempotent,
   188  	} {
   189  		t.Run(idempotentType.String(), func(t *testing.T) {
   190  			for _, tt := range errsToCheck {
   191  				t.Run(tt.err.Error(), func(t *testing.T) {
   192  					m := &mockConnector{
   193  						t:        t,
   194  						queryErr: badconn.Map(tt.err),
   195  						execErr:  badconn.Map(tt.err),
   196  					}
   197  					db := sql.OpenDB(m)
   198  					var attempts int
   199  					err := DoTx(context.Background(), db,
   200  						func(ctx context.Context, tx *sql.Tx) error {
   201  							attempts++
   202  							if attempts > 10 {
   203  								return nil
   204  							}
   205  							rows, err := tx.QueryContext(ctx, "SELECT 1")
   206  							if err != nil {
   207  								return err
   208  							}
   209  							defer func() {
   210  								_ = rows.Close()
   211  							}()
   212  
   213  							return rows.Err()
   214  						},
   215  						WithIdempotent(bool(idempotentType)),
   216  						WithFastBackoff(backoff.New(backoff.WithSlotDuration(time.Nanosecond))),
   217  						WithSlowBackoff(backoff.New(backoff.WithSlotDuration(time.Nanosecond))),
   218  						WithTrace(&trace.Retry{
   219  							//nolint:lll
   220  							OnRetry: func(info trace.RetryLoopStartInfo) func(trace.RetryLoopIntermediateInfo) func(trace.RetryLoopDoneInfo) {
   221  								t.Logf("attempt %d, conn %d, mode: %+v", attempts, m.conns, Check(m.queryErr))
   222  
   223  								return func(info trace.RetryLoopIntermediateInfo) func(trace.RetryLoopDoneInfo) {
   224  									t.Logf("attempt %d, conn %d, mode: %+v", attempts, m.conns, Check(m.queryErr))
   225  
   226  									return nil
   227  								}
   228  							},
   229  						}),
   230  					)
   231  					if tt.canRetry[idempotentType] {
   232  						if err != nil {
   233  							t.Errorf("unexpected err after attempts=%d and driver conns=%d: %v)", attempts, m.conns, err)
   234  						}
   235  						if attempts <= 1 {
   236  							t.Errorf("must be attempts > 1 (actual=%d), driver conns=%d)", attempts, m.conns)
   237  						}
   238  						if tt.deleteSession {
   239  							if m.conns <= 1 {
   240  								t.Errorf("must be retry on different conns (attempts=%d, driver conns=%d)", attempts, m.conns)
   241  							}
   242  						} else {
   243  							if m.conns > 1 {
   244  								t.Errorf("must be retry on single conn (attempts=%d, driver conns=%d)", attempts, m.conns)
   245  							}
   246  						}
   247  					} else if err == nil {
   248  						t.Errorf("unexpected nil err (attempts=%d, driver conns=%d)", attempts, m.conns)
   249  					}
   250  				})
   251  			}
   252  		})
   253  	}
   254  }