github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/tx.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  )
     8  
     9  type (
    10  	beginnable func(*sql.DB) (trans, error)
    11  
    12  	trans interface {
    13  		Session
    14  		Commit() error
    15  		Rollback() error
    16  	}
    17  
    18  	txSession struct {
    19  		*sql.Tx
    20  	}
    21  )
    22  
    23  // NewSessionFromTx returns a Session with the given sql.Tx.
    24  // Use it with caution, it's provided for other ORM to interact with.
    25  func NewSessionFromTx(tx *sql.Tx) Session {
    26  	return txSession{Tx: tx}
    27  }
    28  
    29  func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
    30  	return t.ExecCtx(context.Background(), q, args...)
    31  }
    32  
    33  func (t txSession) ExecCtx(ctx context.Context, q string, args ...interface{}) (result sql.Result, err error) {
    34  	ctx, span := startSpan(ctx, "Exec")
    35  	defer func() {
    36  		endSpan(span, err)
    37  	}()
    38  
    39  	result, err = exec(ctx, t.Tx, q, args...)
    40  
    41  	return
    42  }
    43  
    44  func (t txSession) Prepare(q string) (StmtSession, error) {
    45  	return t.PrepareCtx(context.Background(), q)
    46  }
    47  
    48  func (t txSession) PrepareCtx(ctx context.Context, q string) (stmtSession StmtSession, err error) {
    49  	ctx, span := startSpan(ctx, "Prepare")
    50  	defer func() {
    51  		endSpan(span, err)
    52  	}()
    53  
    54  	stmt, err := t.Tx.PrepareContext(ctx, q)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  
    59  	return statement{
    60  		query: q,
    61  		stmt:  stmt,
    62  	}, nil
    63  }
    64  
    65  func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
    66  	return t.QueryRowCtx(context.Background(), v, q, args...)
    67  }
    68  
    69  func (t txSession) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) (err error) {
    70  	ctx, span := startSpan(ctx, "QueryRow")
    71  	defer func() {
    72  		endSpan(span, err)
    73  	}()
    74  
    75  	return query(ctx, t.Tx, func(rows *sql.Rows) error {
    76  		return unmarshalRow(v, rows, true)
    77  	}, q, args...)
    78  }
    79  
    80  func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
    81  	return t.QueryRowPartialCtx(context.Background(), v, q, args...)
    82  }
    83  
    84  func (t txSession) QueryRowPartialCtx(ctx context.Context, v interface{}, q string,
    85  	args ...interface{}) (err error) {
    86  	ctx, span := startSpan(ctx, "QueryRowPartial")
    87  	defer func() {
    88  		endSpan(span, err)
    89  	}()
    90  
    91  	return query(ctx, t.Tx, func(rows *sql.Rows) error {
    92  		return unmarshalRow(v, rows, false)
    93  	}, q, args...)
    94  }
    95  
    96  func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
    97  	return t.QueryRowsCtx(context.Background(), v, q, args...)
    98  }
    99  
   100  func (t txSession) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) (err error) {
   101  	ctx, span := startSpan(ctx, "QueryRows")
   102  	defer func() {
   103  		endSpan(span, err)
   104  	}()
   105  
   106  	return query(ctx, t.Tx, func(rows *sql.Rows) error {
   107  		return unmarshalRows(v, rows, true)
   108  	}, q, args...)
   109  }
   110  
   111  func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
   112  	return t.QueryRowsPartialCtx(context.Background(), v, q, args...)
   113  }
   114  
   115  func (t txSession) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string,
   116  	args ...interface{}) (err error) {
   117  	ctx, span := startSpan(ctx, "QueryRowsPartial")
   118  	defer func() {
   119  		endSpan(span, err)
   120  	}()
   121  
   122  	return query(ctx, t.Tx, func(rows *sql.Rows) error {
   123  		return unmarshalRows(v, rows, false)
   124  	}, q, args...)
   125  }
   126  
   127  func begin(db *sql.DB) (trans, error) {
   128  	tx, err := db.Begin()
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	return txSession{
   134  		Tx: tx,
   135  	}, nil
   136  }
   137  
   138  func transact(ctx context.Context, db *commonSqlConn, b beginnable,
   139  	fn func(context.Context, Session) error) (err error) {
   140  	conn, err := db.connProv()
   141  	if err != nil {
   142  		db.onError(err)
   143  		return err
   144  	}
   145  
   146  	return transactOnConn(ctx, conn, b, fn)
   147  }
   148  
   149  func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable,
   150  	fn func(context.Context, Session) error) (err error) {
   151  	var tx trans
   152  	tx, err = b(conn)
   153  	if err != nil {
   154  		return
   155  	}
   156  
   157  	defer func() {
   158  		if p := recover(); p != nil {
   159  			if e := tx.Rollback(); e != nil {
   160  				err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e)
   161  			} else {
   162  				err = fmt.Errorf("recoveer from %#v", p)
   163  			}
   164  		} else if err != nil {
   165  			if e := tx.Rollback(); e != nil {
   166  				err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e)
   167  			}
   168  		} else {
   169  			err = tx.Commit()
   170  		}
   171  	}()
   172  
   173  	return fn(ctx, tx)
   174  }