github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/xsql/tx.go (about)

     1  package xsql
     2  
     3  import (
     4  	"context"
     5  	"database/sql/driver"
     6  	"fmt"
     7  
     8  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
     9  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/tx"
    10  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/badconn"
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/isolation"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/table"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    15  )
    16  
    17  type transaction struct {
    18  	tx.Identifier
    19  
    20  	conn *conn
    21  	ctx  context.Context //nolint:containedctx
    22  	tx   table.Transaction
    23  }
    24  
    25  var (
    26  	_ driver.Tx             = &transaction{}
    27  	_ driver.ExecerContext  = &transaction{}
    28  	_ driver.QueryerContext = &transaction{}
    29  	_ tx.Identifier         = &transaction{}
    30  )
    31  
    32  func (c *conn) beginTx(ctx context.Context, txOptions driver.TxOptions) (currentTx, error) {
    33  	if c.currentTx != nil {
    34  		return nil, badconn.Map(
    35  			xerrors.WithStackTrace(
    36  				fmt.Errorf("broken conn state: conn=%q already have current tx=%q",
    37  					c.ID(), c.currentTx.ID(),
    38  				),
    39  			),
    40  		)
    41  	}
    42  	txc, err := isolation.ToYDB(txOptions)
    43  	if err != nil {
    44  		return nil, xerrors.WithStackTrace(err)
    45  	}
    46  	nativeTx, err := c.session.BeginTransaction(ctx, table.TxSettings(txc))
    47  	if err != nil {
    48  		return nil, badconn.Map(xerrors.WithStackTrace(err))
    49  	}
    50  	c.currentTx = &transaction{
    51  		Identifier: tx.ID(nativeTx.ID()),
    52  		conn:       c,
    53  		ctx:        ctx,
    54  		tx:         nativeTx,
    55  	}
    56  
    57  	return c.currentTx, nil
    58  }
    59  
    60  func (tx *transaction) checkTxState() error {
    61  	if tx.conn.currentTx == tx {
    62  		return nil
    63  	}
    64  	if tx.conn.currentTx == nil {
    65  		return fmt.Errorf("broken conn state: tx=%q not related to conn=%q",
    66  			tx.ID(), tx.conn.ID(),
    67  		)
    68  	}
    69  
    70  	return fmt.Errorf("broken conn state: tx=%s not related to conn=%q (conn have current tx=%q)",
    71  		tx.conn.currentTx.ID(), tx.conn.ID(), tx.ID(),
    72  	)
    73  }
    74  
    75  func (tx *transaction) Commit() (finalErr error) {
    76  	var (
    77  		ctx    = tx.ctx
    78  		onDone = trace.DatabaseSQLOnTxCommit(tx.conn.trace, &ctx,
    79  			stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql.(*transaction).Commit"),
    80  			tx,
    81  		)
    82  	)
    83  	defer func() {
    84  		onDone(finalErr)
    85  	}()
    86  	if err := tx.checkTxState(); err != nil {
    87  		return badconn.Map(xerrors.WithStackTrace(err))
    88  	}
    89  	defer func() {
    90  		tx.conn.currentTx = nil
    91  	}()
    92  	if _, err := tx.tx.CommitTx(tx.ctx); err != nil {
    93  		return badconn.Map(xerrors.WithStackTrace(err))
    94  	}
    95  
    96  	return nil
    97  }
    98  
    99  func (tx *transaction) Rollback() (finalErr error) {
   100  	var (
   101  		ctx    = tx.ctx
   102  		onDone = trace.DatabaseSQLOnTxRollback(tx.conn.trace, &ctx,
   103  			stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql.(*transaction).Rollback"),
   104  			tx,
   105  		)
   106  	)
   107  	defer func() {
   108  		onDone(finalErr)
   109  	}()
   110  	if err := tx.checkTxState(); err != nil {
   111  		return badconn.Map(xerrors.WithStackTrace(err))
   112  	}
   113  	defer func() {
   114  		tx.conn.currentTx = nil
   115  	}()
   116  	err := tx.tx.Rollback(tx.ctx)
   117  	if err != nil {
   118  		return badconn.Map(xerrors.WithStackTrace(err))
   119  	}
   120  
   121  	return err
   122  }
   123  
   124  func (tx *transaction) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (
   125  	_ driver.Rows, finalErr error,
   126  ) {
   127  	onDone := trace.DatabaseSQLOnTxQuery(tx.conn.trace, &ctx,
   128  		stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql.(*transaction).QueryContext"),
   129  		tx.ctx, tx, query,
   130  	)
   131  	defer func() {
   132  		onDone(finalErr)
   133  	}()
   134  	m := queryModeFromContext(ctx, tx.conn.defaultQueryMode)
   135  	if m != DataQueryMode {
   136  		return nil, badconn.Map(
   137  			xerrors.WithStackTrace(
   138  				xerrors.Retryable(
   139  					fmt.Errorf("wrong query mode: %s", m.String()),
   140  					xerrors.InvalidObject(),
   141  					xerrors.WithName("WRONG_QUERY_MODE"),
   142  				),
   143  			),
   144  		)
   145  	}
   146  	query, parameters, err := tx.conn.normalize(query, args...)
   147  	if err != nil {
   148  		return nil, xerrors.WithStackTrace(err)
   149  	}
   150  	res, err := tx.tx.Execute(ctx,
   151  		query, &parameters, tx.conn.dataQueryOptions(ctx)...,
   152  	)
   153  	if err != nil {
   154  		return nil, badconn.Map(xerrors.WithStackTrace(err))
   155  	}
   156  	if err = res.Err(); err != nil {
   157  		return nil, badconn.Map(xerrors.WithStackTrace(err))
   158  	}
   159  
   160  	return &rows{
   161  		conn:   tx.conn,
   162  		result: res,
   163  	}, nil
   164  }
   165  
   166  func (tx *transaction) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (
   167  	_ driver.Result, finalErr error,
   168  ) {
   169  	onDone := trace.DatabaseSQLOnTxExec(tx.conn.trace, &ctx,
   170  		stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql.(*transaction).ExecContext"),
   171  		tx.ctx, tx, query,
   172  	)
   173  	defer func() {
   174  		onDone(finalErr)
   175  	}()
   176  	m := queryModeFromContext(ctx, tx.conn.defaultQueryMode)
   177  	if m != DataQueryMode {
   178  		return nil, badconn.Map(
   179  			xerrors.WithStackTrace(
   180  				xerrors.Retryable(
   181  					fmt.Errorf("wrong query mode: %s", m.String()),
   182  					xerrors.InvalidObject(),
   183  					xerrors.WithName("WRONG_QUERY_MODE"),
   184  				),
   185  			),
   186  		)
   187  	}
   188  	query, parameters, err := tx.conn.normalize(query, args...)
   189  	if err != nil {
   190  		return nil, xerrors.WithStackTrace(err)
   191  	}
   192  	_, err = tx.tx.Execute(ctx,
   193  		query, &parameters, tx.conn.dataQueryOptions(ctx)...,
   194  	)
   195  	if err != nil {
   196  		return nil, badconn.Map(xerrors.WithStackTrace(err))
   197  	}
   198  
   199  	return resultNoRows{}, nil
   200  }
   201  
   202  func (tx *transaction) PrepareContext(ctx context.Context, query string) (_ driver.Stmt, finalErr error) {
   203  	onDone := trace.DatabaseSQLOnTxPrepare(tx.conn.trace, &ctx,
   204  		stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql.(*transaction).PrepareContext"),
   205  		tx.ctx, tx, query,
   206  	)
   207  	defer func() {
   208  		onDone(finalErr)
   209  	}()
   210  	if !tx.conn.isReady() {
   211  		return nil, badconn.Map(xerrors.WithStackTrace(errNotReadyConn))
   212  	}
   213  
   214  	return &stmt{
   215  		conn:      tx.conn,
   216  		processor: tx,
   217  		ctx:       ctx,
   218  		query:     query,
   219  		trace:     tx.conn.trace,
   220  	}, nil
   221  }