github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/xsql/tx.go (about)

     1  package xsql
     2  
     3  import (
     4  	"context"
     5  	"database/sql/driver"
     6  	"fmt"
     7  
     8  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
     9  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    10  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn"
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/isolation"
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/table"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    14  )
    15  
    16  type tx struct {
    17  	conn  *conn
    18  	txCtx context.Context
    19  	tx    table.Transaction
    20  }
    21  
    22  var (
    23  	_ driver.Tx                   = &tx{}
    24  	_ driver.ExecerContext        = &tx{}
    25  	_ driver.QueryerContext       = &tx{}
    26  	_ table.TransactionIdentifier = &tx{}
    27  )
    28  
    29  func (c *conn) beginTx(ctx context.Context, txOptions driver.TxOptions) (currentTx, error) {
    30  	if c.currentTx != nil {
    31  		return nil, badconn.Map(
    32  			xerrors.WithStackTrace(
    33  				fmt.Errorf("broken conn state: conn=%q already have current tx=%q",
    34  					c.ID(), c.currentTx.ID(),
    35  				),
    36  			),
    37  		)
    38  	}
    39  	txc, err := isolation.ToYDB(txOptions)
    40  	if err != nil {
    41  		return nil, xerrors.WithStackTrace(err)
    42  	}
    43  	transaction, err := c.session.BeginTransaction(ctx, table.TxSettings(txc))
    44  	if err != nil {
    45  		return nil, badconn.Map(xerrors.WithStackTrace(err))
    46  	}
    47  	c.currentTx = &tx{
    48  		conn:  c,
    49  		txCtx: ctx,
    50  		tx:    transaction,
    51  	}
    52  
    53  	return c.currentTx, nil
    54  }
    55  
    56  func (tx *tx) ID() string {
    57  	return tx.tx.ID()
    58  }
    59  
    60  func (tx *tx) checkTxState() error {
    61  	if tx.conn.currentTx == tx {
    62  		return nil
    63  	}
    64  	if tx.conn.currentTx == nil {
    65  		return fmt.Errorf("broken conn state: tx=%q not related to conn=%q",
    66  			tx.ID(), tx.conn.ID(),
    67  		)
    68  	}
    69  
    70  	return fmt.Errorf("broken conn state: tx=%s not related to conn=%q (conn have current tx=%q)",
    71  		tx.conn.currentTx.ID(), tx.conn.ID(), tx.ID(),
    72  	)
    73  }
    74  
    75  func (tx *tx) Commit() (finalErr error) {
    76  	onDone := trace.DatabaseSQLOnTxCommit(tx.conn.trace, &tx.txCtx,
    77  		stack.FunctionID(""),
    78  		tx,
    79  	)
    80  	defer func() {
    81  		onDone(finalErr)
    82  	}()
    83  	if err := tx.checkTxState(); err != nil {
    84  		return badconn.Map(xerrors.WithStackTrace(err))
    85  	}
    86  	defer func() {
    87  		tx.conn.currentTx = nil
    88  	}()
    89  	_, err := tx.tx.CommitTx(tx.txCtx)
    90  	if err != nil {
    91  		return badconn.Map(xerrors.WithStackTrace(err))
    92  	}
    93  
    94  	return nil
    95  }
    96  
    97  func (tx *tx) Rollback() (finalErr error) {
    98  	onDone := trace.DatabaseSQLOnTxRollback(tx.conn.trace, &tx.txCtx,
    99  		stack.FunctionID(""),
   100  		tx,
   101  	)
   102  	defer func() {
   103  		onDone(finalErr)
   104  	}()
   105  	if err := tx.checkTxState(); err != nil {
   106  		return badconn.Map(xerrors.WithStackTrace(err))
   107  	}
   108  	defer func() {
   109  		tx.conn.currentTx = nil
   110  	}()
   111  	err := tx.tx.Rollback(tx.txCtx)
   112  	if err != nil {
   113  		return badconn.Map(xerrors.WithStackTrace(err))
   114  	}
   115  
   116  	return err
   117  }
   118  
   119  func (tx *tx) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (
   120  	_ driver.Rows, finalErr error,
   121  ) {
   122  	onDone := trace.DatabaseSQLOnTxQuery(tx.conn.trace, &ctx,
   123  		stack.FunctionID(""),
   124  		tx.txCtx, tx, query, true,
   125  	)
   126  	defer func() {
   127  		onDone(finalErr)
   128  	}()
   129  	m := queryModeFromContext(ctx, tx.conn.defaultQueryMode)
   130  	if m != DataQueryMode {
   131  		return nil, badconn.Map(
   132  			xerrors.WithStackTrace(
   133  				xerrors.Retryable(
   134  					fmt.Errorf("wrong query mode: %s", m.String()),
   135  					xerrors.WithDeleteSession(),
   136  					xerrors.WithName("WRONG_QUERY_MODE"),
   137  				),
   138  			),
   139  		)
   140  	}
   141  	query, parameters, err := tx.conn.normalize(query, args...)
   142  	if err != nil {
   143  		return nil, xerrors.WithStackTrace(err)
   144  	}
   145  	res, err := tx.tx.Execute(ctx,
   146  		query, &parameters, tx.conn.dataQueryOptions(ctx)...,
   147  	)
   148  	if err != nil {
   149  		return nil, badconn.Map(xerrors.WithStackTrace(err))
   150  	}
   151  	if err = res.Err(); err != nil {
   152  		return nil, badconn.Map(xerrors.WithStackTrace(err))
   153  	}
   154  
   155  	return &rows{
   156  		conn:   tx.conn,
   157  		result: res,
   158  	}, nil
   159  }
   160  
   161  func (tx *tx) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (
   162  	_ driver.Result, finalErr error,
   163  ) {
   164  	onDone := trace.DatabaseSQLOnTxExec(tx.conn.trace, &ctx,
   165  		stack.FunctionID(""),
   166  		tx.txCtx, tx, query, true,
   167  	)
   168  	defer func() {
   169  		onDone(finalErr)
   170  	}()
   171  	m := queryModeFromContext(ctx, tx.conn.defaultQueryMode)
   172  	if m != DataQueryMode {
   173  		return nil, badconn.Map(
   174  			xerrors.WithStackTrace(
   175  				xerrors.Retryable(
   176  					fmt.Errorf("wrong query mode: %s", m.String()),
   177  					xerrors.WithDeleteSession(),
   178  					xerrors.WithName("WRONG_QUERY_MODE"),
   179  				),
   180  			),
   181  		)
   182  	}
   183  	query, parameters, err := tx.conn.normalize(query, args...)
   184  	if err != nil {
   185  		return nil, xerrors.WithStackTrace(err)
   186  	}
   187  	_, err = tx.tx.Execute(ctx,
   188  		query, &parameters, tx.conn.dataQueryOptions(ctx)...,
   189  	)
   190  	if err != nil {
   191  		return nil, badconn.Map(xerrors.WithStackTrace(err))
   192  	}
   193  
   194  	return resultNoRows{}, nil
   195  }
   196  
   197  func (tx *tx) PrepareContext(ctx context.Context, query string) (_ driver.Stmt, finalErr error) {
   198  	onDone := trace.DatabaseSQLOnTxPrepare(tx.conn.trace, &ctx,
   199  		stack.FunctionID(""),
   200  		&tx.txCtx, tx, query,
   201  	)
   202  	defer func() {
   203  		onDone(finalErr)
   204  	}()
   205  	if !tx.conn.isReady() {
   206  		return nil, badconn.Map(xerrors.WithStackTrace(errNotReadyConn))
   207  	}
   208  
   209  	return &stmt{
   210  		conn:      tx.conn,
   211  		processor: tx,
   212  		stmtCtx:   ctx,
   213  		query:     query,
   214  		trace:     tx.conn.trace,
   215  	}, nil
   216  }