github.com/jackc/pgx/v5@v5.5.5/tx.go (about)

     1  package pgx
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"strconv"
     8  	"strings"
     9  
    10  	"github.com/jackc/pgx/v5/pgconn"
    11  )
    12  
    13  // TxIsoLevel is the transaction isolation level (serializable, repeatable read, read committed or read uncommitted)
    14  type TxIsoLevel string
    15  
    16  // Transaction isolation levels
    17  const (
    18  	Serializable    TxIsoLevel = "serializable"
    19  	RepeatableRead  TxIsoLevel = "repeatable read"
    20  	ReadCommitted   TxIsoLevel = "read committed"
    21  	ReadUncommitted TxIsoLevel = "read uncommitted"
    22  )
    23  
    24  // TxAccessMode is the transaction access mode (read write or read only)
    25  type TxAccessMode string
    26  
    27  // Transaction access modes
    28  const (
    29  	ReadWrite TxAccessMode = "read write"
    30  	ReadOnly  TxAccessMode = "read only"
    31  )
    32  
    33  // TxDeferrableMode is the transaction deferrable mode (deferrable or not deferrable)
    34  type TxDeferrableMode string
    35  
    36  // Transaction deferrable modes
    37  const (
    38  	Deferrable    TxDeferrableMode = "deferrable"
    39  	NotDeferrable TxDeferrableMode = "not deferrable"
    40  )
    41  
    42  // TxOptions are transaction modes within a transaction block
    43  type TxOptions struct {
    44  	IsoLevel       TxIsoLevel
    45  	AccessMode     TxAccessMode
    46  	DeferrableMode TxDeferrableMode
    47  
    48  	// BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax
    49  	// such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings.
    50  	BeginQuery string
    51  }
    52  
    53  var emptyTxOptions TxOptions
    54  
    55  func (txOptions TxOptions) beginSQL() string {
    56  	if txOptions == emptyTxOptions {
    57  		return "begin"
    58  	}
    59  
    60  	if txOptions.BeginQuery != "" {
    61  		return txOptions.BeginQuery
    62  	}
    63  
    64  	var buf strings.Builder
    65  	buf.Grow(64) // 64 - maximum length of string with available options
    66  	buf.WriteString("begin")
    67  
    68  	if txOptions.IsoLevel != "" {
    69  		buf.WriteString(" isolation level ")
    70  		buf.WriteString(string(txOptions.IsoLevel))
    71  	}
    72  	if txOptions.AccessMode != "" {
    73  		buf.WriteByte(' ')
    74  		buf.WriteString(string(txOptions.AccessMode))
    75  	}
    76  	if txOptions.DeferrableMode != "" {
    77  		buf.WriteByte(' ')
    78  		buf.WriteString(string(txOptions.DeferrableMode))
    79  	}
    80  
    81  	return buf.String()
    82  }
    83  
    84  var ErrTxClosed = errors.New("tx is closed")
    85  
    86  // ErrTxCommitRollback occurs when an error has occurred in a transaction and
    87  // Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but
    88  // it is treated as ROLLBACK.
    89  var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback")
    90  
    91  // Begin starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no
    92  // auto-rollback on context cancellation.
    93  func (c *Conn) Begin(ctx context.Context) (Tx, error) {
    94  	return c.BeginTx(ctx, TxOptions{})
    95  }
    96  
    97  // BeginTx starts a transaction with txOptions determining the transaction mode. Unlike database/sql, the context only
    98  // affects the begin command. i.e. there is no auto-rollback on context cancellation.
    99  func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
   100  	_, err := c.Exec(ctx, txOptions.beginSQL())
   101  	if err != nil {
   102  		// begin should never fail unless there is an underlying connection issue or
   103  		// a context timeout. In either case, the connection is possibly broken.
   104  		c.die(errors.New("failed to begin transaction"))
   105  		return nil, err
   106  	}
   107  
   108  	return &dbTx{conn: c}, nil
   109  }
   110  
   111  // Tx represents a database transaction.
   112  //
   113  // Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx
   114  // state, to support pseudo-nested transactions with savepoints, and to allow tests to mock transactions. However,
   115  // adding a method to an interface is technically a breaking change. If new methods are added to Conn it may be
   116  // desirable to add them to Tx as well. Because of this the Tx interface is partially excluded from semantic version
   117  // requirements. Methods will not be removed or changed, but new methods may be added.
   118  type Tx interface {
   119  	// Begin starts a pseudo nested transaction.
   120  	Begin(ctx context.Context) (Tx, error)
   121  
   122  	// Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested
   123  	// transaction. Commit will return an error where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is
   124  	// otherwise safe to call multiple times. If the commit fails with a rollback status (e.g. the transaction was already
   125  	// in a broken state) then an error where errors.Is(ErrTxCommitRollback) is true will be returned.
   126  	Commit(ctx context.Context) error
   127  
   128  	// Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a
   129  	// pseudo nested transaction. Rollback will return an error where errors.Is(ErrTxClosed) is true if the Tx is already
   130  	// closed, but is otherwise safe to call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will
   131  	// be called first in a non-error condition. Any other failure of a real transaction will result in the connection
   132  	// being closed.
   133  	Rollback(ctx context.Context) error
   134  
   135  	CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error)
   136  	SendBatch(ctx context.Context, b *Batch) BatchResults
   137  	LargeObjects() LargeObjects
   138  
   139  	Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error)
   140  
   141  	Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error)
   142  	Query(ctx context.Context, sql string, args ...any) (Rows, error)
   143  	QueryRow(ctx context.Context, sql string, args ...any) Row
   144  
   145  	// Conn returns the underlying *Conn that on which this transaction is executing.
   146  	Conn() *Conn
   147  }
   148  
// dbTx represents a real (top-level) database transaction.
//
// Once Commit or Rollback has been called, the methods that issue statements
// (Begin, Commit, Rollback, Exec, Prepare, Query, QueryRow, CopyFrom,
// SendBatch) report ErrTxClosed. LargeObjects and Conn do not check the
// closed flag.
type dbTx struct {
	conn         *Conn // connection the transaction runs on
	savepointNum int64 // last savepoint number handed out by Begin (names are sp_<n>)
	closed       bool  // set once Commit or Rollback has executed
}
   158  
   159  // Begin starts a pseudo nested transaction implemented with a savepoint.
   160  func (tx *dbTx) Begin(ctx context.Context) (Tx, error) {
   161  	if tx.closed {
   162  		return nil, ErrTxClosed
   163  	}
   164  
   165  	tx.savepointNum++
   166  	_, err := tx.conn.Exec(ctx, "savepoint sp_"+strconv.FormatInt(tx.savepointNum, 10))
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil
   172  }
   173  
   174  // Commit commits the transaction.
   175  func (tx *dbTx) Commit(ctx context.Context) error {
   176  	if tx.closed {
   177  		return ErrTxClosed
   178  	}
   179  
   180  	commandTag, err := tx.conn.Exec(ctx, "commit")
   181  	tx.closed = true
   182  	if err != nil {
   183  		if tx.conn.PgConn().TxStatus() != 'I' {
   184  			_ = tx.conn.Close(ctx) // already have error to return
   185  		}
   186  		return err
   187  	}
   188  	if commandTag.String() == "ROLLBACK" {
   189  		return ErrTxCommitRollback
   190  	}
   191  
   192  	return nil
   193  }
   194  
   195  // Rollback rolls back the transaction. Rollback will return ErrTxClosed if the
   196  // Tx is already closed, but is otherwise safe to call multiple times. Hence, a
   197  // defer tx.Rollback() is safe even if tx.Commit() will be called first in a
   198  // non-error condition.
   199  func (tx *dbTx) Rollback(ctx context.Context) error {
   200  	if tx.closed {
   201  		return ErrTxClosed
   202  	}
   203  
   204  	_, err := tx.conn.Exec(ctx, "rollback")
   205  	tx.closed = true
   206  	if err != nil {
   207  		// A rollback failure leaves the connection in an undefined state
   208  		tx.conn.die(fmt.Errorf("rollback failed: %w", err))
   209  		return err
   210  	}
   211  
   212  	return nil
   213  }
   214  
   215  // Exec delegates to the underlying *Conn
   216  func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
   217  	if tx.closed {
   218  		return pgconn.CommandTag{}, ErrTxClosed
   219  	}
   220  
   221  	return tx.conn.Exec(ctx, sql, arguments...)
   222  }
   223  
   224  // Prepare delegates to the underlying *Conn
   225  func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
   226  	if tx.closed {
   227  		return nil, ErrTxClosed
   228  	}
   229  
   230  	return tx.conn.Prepare(ctx, name, sql)
   231  }
   232  
   233  // Query delegates to the underlying *Conn
   234  func (tx *dbTx) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
   235  	if tx.closed {
   236  		// Because checking for errors can be deferred to the *Rows, build one with the error
   237  		err := ErrTxClosed
   238  		return &baseRows{closed: true, err: err}, err
   239  	}
   240  
   241  	return tx.conn.Query(ctx, sql, args...)
   242  }
   243  
   244  // QueryRow delegates to the underlying *Conn
   245  func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...any) Row {
   246  	rows, _ := tx.Query(ctx, sql, args...)
   247  	return (*connRow)(rows.(*baseRows))
   248  }
   249  
   250  // CopyFrom delegates to the underlying *Conn
   251  func (tx *dbTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
   252  	if tx.closed {
   253  		return 0, ErrTxClosed
   254  	}
   255  
   256  	return tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc)
   257  }
   258  
   259  // SendBatch delegates to the underlying *Conn
   260  func (tx *dbTx) SendBatch(ctx context.Context, b *Batch) BatchResults {
   261  	if tx.closed {
   262  		return &batchResults{err: ErrTxClosed}
   263  	}
   264  
   265  	return tx.conn.SendBatch(ctx, b)
   266  }
   267  
   268  // LargeObjects returns a LargeObjects instance for the transaction.
   269  func (tx *dbTx) LargeObjects() LargeObjects {
   270  	return LargeObjects{tx: tx}
   271  }
   272  
   273  func (tx *dbTx) Conn() *Conn {
   274  	return tx.conn
   275  }
   276  
// dbSimulatedNestedTx represents a simulated nested transaction implemented by a savepoint.
//
// Its savepoint is named sp_<savepointNum>. In this file it is only constructed
// by dbTx.Begin, which allocates the number.
type dbSimulatedNestedTx struct {
	tx           Tx    // enclosing transaction that statements are delegated to
	savepointNum int64 // suffix of this savepoint's name (sp_<savepointNum>)
	closed       bool  // set once Commit or Rollback has executed
}
   283  
   284  // Begin starts a pseudo nested transaction implemented with a savepoint.
   285  func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) {
   286  	if sp.closed {
   287  		return nil, ErrTxClosed
   288  	}
   289  
   290  	return sp.tx.Begin(ctx)
   291  }
   292  
   293  // Commit releases the savepoint essentially committing the pseudo nested transaction.
   294  func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error {
   295  	if sp.closed {
   296  		return ErrTxClosed
   297  	}
   298  
   299  	_, err := sp.Exec(ctx, "release savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10))
   300  	sp.closed = true
   301  	return err
   302  }
   303  
   304  // Rollback rolls back to the savepoint essentially rolling back the pseudo nested transaction. Rollback will return
   305  // ErrTxClosed if the dbSavepoint is already closed, but is otherwise safe to call multiple times. Hence, a defer sp.Rollback()
   306  // is safe even if sp.Commit() will be called first in a non-error condition.
   307  func (sp *dbSimulatedNestedTx) Rollback(ctx context.Context) error {
   308  	if sp.closed {
   309  		return ErrTxClosed
   310  	}
   311  
   312  	_, err := sp.Exec(ctx, "rollback to savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10))
   313  	sp.closed = true
   314  	return err
   315  }
   316  
   317  // Exec delegates to the underlying Tx
   318  func (sp *dbSimulatedNestedTx) Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
   319  	if sp.closed {
   320  		return pgconn.CommandTag{}, ErrTxClosed
   321  	}
   322  
   323  	return sp.tx.Exec(ctx, sql, arguments...)
   324  }
   325  
   326  // Prepare delegates to the underlying Tx
   327  func (sp *dbSimulatedNestedTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) {
   328  	if sp.closed {
   329  		return nil, ErrTxClosed
   330  	}
   331  
   332  	return sp.tx.Prepare(ctx, name, sql)
   333  }
   334  
   335  // Query delegates to the underlying Tx
   336  func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
   337  	if sp.closed {
   338  		// Because checking for errors can be deferred to the *Rows, build one with the error
   339  		err := ErrTxClosed
   340  		return &baseRows{closed: true, err: err}, err
   341  	}
   342  
   343  	return sp.tx.Query(ctx, sql, args...)
   344  }
   345  
   346  // QueryRow delegates to the underlying Tx
   347  func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args ...any) Row {
   348  	rows, _ := sp.Query(ctx, sql, args...)
   349  	return (*connRow)(rows.(*baseRows))
   350  }
   351  
   352  // CopyFrom delegates to the underlying *Conn
   353  func (sp *dbSimulatedNestedTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
   354  	if sp.closed {
   355  		return 0, ErrTxClosed
   356  	}
   357  
   358  	return sp.tx.CopyFrom(ctx, tableName, columnNames, rowSrc)
   359  }
   360  
   361  // SendBatch delegates to the underlying *Conn
   362  func (sp *dbSimulatedNestedTx) SendBatch(ctx context.Context, b *Batch) BatchResults {
   363  	if sp.closed {
   364  		return &batchResults{err: ErrTxClosed}
   365  	}
   366  
   367  	return sp.tx.SendBatch(ctx, b)
   368  }
   369  
   370  func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects {
   371  	return LargeObjects{tx: sp}
   372  }
   373  
   374  func (sp *dbSimulatedNestedTx) Conn() *Conn {
   375  	return sp.tx.Conn()
   376  }
   377  
   378  // BeginFunc calls Begin on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn
   379  // returns an error it calls Rollback on db. The context will be used when executing the transaction control statements
   380  // (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn.
   381  func BeginFunc(
   382  	ctx context.Context,
   383  	db interface {
   384  		Begin(ctx context.Context) (Tx, error)
   385  	},
   386  	fn func(Tx) error,
   387  ) (err error) {
   388  	var tx Tx
   389  	tx, err = db.Begin(ctx)
   390  	if err != nil {
   391  		return err
   392  	}
   393  
   394  	return beginFuncExec(ctx, tx, fn)
   395  }
   396  
   397  // BeginTxFunc calls BeginTx on db and then calls fn. If fn does not return an error then it calls Commit on db. If fn
   398  // returns an error it calls Rollback on db. The context will be used when executing the transaction control statements
   399  // (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of fn.
   400  func BeginTxFunc(
   401  	ctx context.Context,
   402  	db interface {
   403  		BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error)
   404  	},
   405  	txOptions TxOptions,
   406  	fn func(Tx) error,
   407  ) (err error) {
   408  	var tx Tx
   409  	tx, err = db.BeginTx(ctx, txOptions)
   410  	if err != nil {
   411  		return err
   412  	}
   413  
   414  	return beginFuncExec(ctx, tx, fn)
   415  }
   416  
   417  func beginFuncExec(ctx context.Context, tx Tx, fn func(Tx) error) (err error) {
   418  	defer func() {
   419  		rollbackErr := tx.Rollback(ctx)
   420  		if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) {
   421  			err = rollbackErr
   422  		}
   423  	}()
   424  
   425  	fErr := fn(tx)
   426  	if fErr != nil {
   427  		_ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return
   428  		return fErr
   429  	}
   430  
   431  	return tx.Commit(ctx)
   432  }