github.com/acoshift/pgsql@v0.15.3/tx.go (about)

     1  package pgsql
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  )
     8  
// ErrAbortTx is a sentinel error for the function passed to RunInTx or
// RunInTxContext: when that function returns it (or an error wrapping it),
// the transaction is rolled back and the call returns a nil error.
var ErrAbortTx = errors.New("pgsql: abort tx")
    11  
    12  // BeginTxer type
    13  type BeginTxer interface {
    14  	BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error)
    15  }
    16  
// TxOptions is the transaction options accepted by RunInTx and RunInTxContext.
type TxOptions struct {
	// TxOptions is handed to BeginTx for every attempt. An Isolation of
	// sql.LevelDefault is replaced with sql.LevelSerializable by RunInTxContext.
	sql.TxOptions
	// MaxAttempts is the maximum number of times the transaction is attempted.
	// Values less than or equal to zero select the default of 10.
	MaxAttempts int
}
    22  
const (
	// defaultMaxAttempts is the attempt limit RunInTxContext uses when no
	// options are given or TxOptions.MaxAttempts is not a positive number.
	defaultMaxAttempts = 10
)
    26  
    27  // RunInTx runs fn inside retryable transaction.
    28  //
    29  // see RunInTxContext for more info.
    30  func RunInTx(db BeginTxer, opts *TxOptions, fn func(*sql.Tx) error) error {
    31  	return RunInTxContext(context.Background(), db, opts, fn)
    32  }
    33  
    34  // RunInTxContext runs fn inside retryable transaction with context.
    35  // It use Serializable isolation level if tx options isolation is setted to sql.LevelDefault.
    36  //
    37  // RunInTxContext DO NOT handle panic.
    38  // But when panic, it will rollback the transaction.
    39  func RunInTxContext(ctx context.Context, db BeginTxer, opts *TxOptions, fn func(*sql.Tx) error) error {
    40  	option := TxOptions{
    41  		TxOptions: sql.TxOptions{
    42  			Isolation: sql.LevelSerializable,
    43  		},
    44  		MaxAttempts: defaultMaxAttempts,
    45  	}
    46  
    47  	if opts != nil {
    48  		if opts.MaxAttempts > 0 {
    49  			option.MaxAttempts = opts.MaxAttempts
    50  		}
    51  		option.TxOptions = opts.TxOptions
    52  
    53  		// override default isolation level to serializable
    54  		if opts.Isolation == sql.LevelDefault {
    55  			option.Isolation = sql.LevelSerializable
    56  		}
    57  	}
    58  
    59  	f := func() error {
    60  		tx, err := db.BeginTx(ctx, &option.TxOptions)
    61  		if err != nil {
    62  			return err
    63  		}
    64  		// use defer to also rollback when panic
    65  		defer tx.Rollback()
    66  
    67  		err = fn(tx)
    68  		if err != nil {
    69  			return err
    70  		}
    71  		return tx.Commit()
    72  	}
    73  
    74  	var err error
    75  	for i := 0; i < option.MaxAttempts; i++ {
    76  		err = f()
    77  		if err == nil || errors.Is(err, ErrAbortTx) {
    78  			return nil
    79  		}
    80  		if !IsSerializationFailure(err) {
    81  			return err
    82  		}
    83  	}
    84  
    85  	return err
    86  }