github.com/shuguocloud/go-zero@v1.3.0/core/stores/sqlx/tx.go (about) 1 package sqlx 2 3 import ( 4 "database/sql" 5 "fmt" 6 ) 7 8 type ( 9 beginnable func(*sql.DB) (trans, error) 10 11 trans interface { 12 Session 13 Commit() error 14 Rollback() error 15 } 16 17 txSession struct { 18 *sql.Tx 19 } 20 ) 21 22 // NewSessionFromTx returns a Session with the given sql.Tx. 23 // Use it with caution, it's provided for other ORM to interact with. 24 func NewSessionFromTx(tx *sql.Tx) Session { 25 return txSession{Tx: tx} 26 } 27 28 func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) { 29 return exec(t.Tx, q, args...) 30 } 31 32 func (t txSession) Prepare(q string) (StmtSession, error) { 33 stmt, err := t.Tx.Prepare(q) 34 if err != nil { 35 return nil, err 36 } 37 38 return statement{ 39 query: q, 40 stmt: stmt, 41 }, nil 42 } 43 44 func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error { 45 return query(t.Tx, func(rows *sql.Rows) error { 46 return unmarshalRow(v, rows, true) 47 }, q, args...) 48 } 49 50 func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error { 51 return query(t.Tx, func(rows *sql.Rows) error { 52 return unmarshalRow(v, rows, false) 53 }, q, args...) 54 } 55 56 func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error { 57 return query(t.Tx, func(rows *sql.Rows) error { 58 return unmarshalRows(v, rows, true) 59 }, q, args...) 60 } 61 62 func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { 63 return query(t.Tx, func(rows *sql.Rows) error { 64 return unmarshalRows(v, rows, false) 65 }, q, args...) 66 } 67 68 func begin(db *sql.DB) (trans, error) { 69 tx, err := db.Begin() 70 if err != nil { 71 return nil, err 72 } 73 74 return txSession{ 75 Tx: tx, 76 }, nil 77 } 78 79 func transact(db *commonSqlConn, b beginnable, fn func(Session) error) (err error) { 80 conn, err := db.connProv() 81 if err != nil { 82 db.onError(err) 83 return err 84 } 85 86 return transactOnConn(conn, b, fn) 87 } 88 89 func transactOnConn(conn *sql.DB, b beginnable, fn func(Session) error) (err error) { 90 var tx trans 91 tx, err = b(conn) 92 if err != nil { 93 return 94 } 95 96 defer func() { 97 if p := recover(); p != nil { 98 if e := tx.Rollback(); e != nil { 99 err = fmt.Errorf("recover from %#v, rollback failed: %s", p, e) 100 } else { 101 err = fmt.Errorf("recoveer from %#v", p) 102 } 103 } else if err != nil { 104 if e := tx.Rollback(); e != nil { 105 err = fmt.Errorf("transaction failed: %s, rollback failed: %s", err, e) 106 } 107 } else { 108 err = tx.Commit() 109 } 110 }() 111 112 return fn(tx) 113 }