github.com/shuguocloud/go-zero@v1.3.0/core/stores/sqlx/tx.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  )
     7  
     8  type (
     9  	beginnable func(*sql.DB) (trans, error)
    10  
    11  	trans interface {
    12  		Session
    13  		Commit() error
    14  		Rollback() error
    15  	}
    16  
    17  	txSession struct {
    18  		*sql.Tx
    19  	}
    20  )
    21  
    22  // NewSessionFromTx returns a Session with the given sql.Tx.
    23  // Use it with caution, it's provided for other ORM to interact with.
    24  func NewSessionFromTx(tx *sql.Tx) Session {
    25  	return txSession{Tx: tx}
    26  }
    27  
    28  func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) {
    29  	return exec(t.Tx, q, args...)
    30  }
    31  
    32  func (t txSession) Prepare(q string) (StmtSession, error) {
    33  	stmt, err := t.Tx.Prepare(q)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  
    38  	return statement{
    39  		query: q,
    40  		stmt:  stmt,
    41  	}, nil
    42  }
    43  
    44  func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error {
    45  	return query(t.Tx, func(rows *sql.Rows) error {
    46  		return unmarshalRow(v, rows, true)
    47  	}, q, args...)
    48  }
    49  
    50  func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error {
    51  	return query(t.Tx, func(rows *sql.Rows) error {
    52  		return unmarshalRow(v, rows, false)
    53  	}, q, args...)
    54  }
    55  
    56  func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error {
    57  	return query(t.Tx, func(rows *sql.Rows) error {
    58  		return unmarshalRows(v, rows, true)
    59  	}, q, args...)
    60  }
    61  
    62  func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error {
    63  	return query(t.Tx, func(rows *sql.Rows) error {
    64  		return unmarshalRows(v, rows, false)
    65  	}, q, args...)
    66  }
    67  
    68  func begin(db *sql.DB) (trans, error) {
    69  	tx, err := db.Begin()
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	return txSession{
    75  		Tx: tx,
    76  	}, nil
    77  }
    78  
    79  func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) {
    80  	conn, err := db.connProv()
    81  	if err != nil {
    82  		db.onError(err)
    83  		return err
    84  	}
    85  
    86  	return transactOnConn(conn, b, fn)
    87  }
    88  
    89  func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) {
    90  	var tx trans
    91  	tx, err = b(conn)
    92  	if err != nil {
    93  		return
    94  	}
    95  
    96  	defer func() {
    97  		if p := recover(); p != nil {
    98  			if e := tx.Rollback(); e != nil {
    99  				err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e)
   100  			} else {
   101  				err = fmt.Errorf("recoveer from %#v", p)
   102  			}
   103  		} else if err != nil {
   104  			if e := tx.Rollback(); e != nil {
   105  				err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e)
   106  			}
   107  		} else {
   108  			err = tx.Commit()
   109  		}
   110  	}()
   111  
   112  	return fn(tx)
   113  }