github.com/lingyao2333/mo-zero@v1.4.1/core/stores/sqlx/tx.go (about) 1 package sqlx 2 3 import ( 4 "context" 5 "database/sql" 6 "fmt" 7 ) 8 9 type ( 10 beginnable func(*sql.DB) (trans, error) 11 12 trans interface { 13 Session 14 Commit() error 15 Rollback() error 16 } 17 18 txSession struct { 19 *sql.Tx 20 } 21 ) 22 23 // NewSessionFromTx returns a Session with the given sql.Tx. 24 // Use it with caution, it's provided for other ORM to interact with. 25 func NewSessionFromTx(tx *sql.Tx) Session { 26 return txSession{Tx: tx} 27 } 28 29 func (t txSession) Exec(q string, args ...interface{}) (sql.Result, error) { 30 return t.ExecCtx(context.Background(), q, args...) 31 } 32 33 func (t txSession) ExecCtx(ctx context.Context, q string, args ...interface{}) (result sql.Result, err error) { 34 ctx, span := startSpan(ctx, "Exec") 35 defer func() { 36 endSpan(span, err) 37 }() 38 39 result, err = exec(ctx, t.Tx, q, args...) 40 41 return 42 } 43 44 func (t txSession) Prepare(q string) (StmtSession, error) { 45 return t.PrepareCtx(context.Background(), q) 46 } 47 48 func (t txSession) PrepareCtx(ctx context.Context, q string) (stmtSession StmtSession, err error) { 49 ctx, span := startSpan(ctx, "Prepare") 50 defer func() { 51 endSpan(span, err) 52 }() 53 54 stmt, err := t.Tx.PrepareContext(ctx, q) 55 if err != nil { 56 return nil, err 57 } 58 59 return statement{ 60 query: q, 61 stmt: stmt, 62 }, nil 63 } 64 65 func (t txSession) QueryRow(v interface{}, q string, args ...interface{}) error { 66 return t.QueryRowCtx(context.Background(), v, q, args...) 67 } 68 69 func (t txSession) QueryRowCtx(ctx context.Context, v interface{}, q string, args ...interface{}) (err error) { 70 ctx, span := startSpan(ctx, "QueryRow") 71 defer func() { 72 endSpan(span, err) 73 }() 74 75 return query(ctx, t.Tx, func(rows *sql.Rows) error { 76 return unmarshalRow(v, rows, true) 77 }, q, args...) 78 } 79 80 func (t txSession) QueryRowPartial(v interface{}, q string, args ...interface{}) error { 81 return t.QueryRowPartialCtx(context.Background(), v, q, args...) 82 } 83 84 func (t txSession) QueryRowPartialCtx(ctx context.Context, v interface{}, q string, 85 args ...interface{}) (err error) { 86 ctx, span := startSpan(ctx, "QueryRowPartial") 87 defer func() { 88 endSpan(span, err) 89 }() 90 91 return query(ctx, t.Tx, func(rows *sql.Rows) error { 92 return unmarshalRow(v, rows, false) 93 }, q, args...) 94 } 95 96 func (t txSession) QueryRows(v interface{}, q string, args ...interface{}) error { 97 return t.QueryRowsCtx(context.Background(), v, q, args...) 98 } 99 100 func (t txSession) QueryRowsCtx(ctx context.Context, v interface{}, q string, args ...interface{}) (err error) { 101 ctx, span := startSpan(ctx, "QueryRows") 102 defer func() { 103 endSpan(span, err) 104 }() 105 106 return query(ctx, t.Tx, func(rows *sql.Rows) error { 107 return unmarshalRows(v, rows, true) 108 }, q, args...) 109 } 110 111 func (t txSession) QueryRowsPartial(v interface{}, q string, args ...interface{}) error { 112 return t.QueryRowsPartialCtx(context.Background(), v, q, args...) 113 } 114 115 func (t txSession) QueryRowsPartialCtx(ctx context.Context, v interface{}, q string, 116 args ...interface{}) (err error) { 117 ctx, span := startSpan(ctx, "QueryRowsPartial") 118 defer func() { 119 endSpan(span, err) 120 }() 121 122 return query(ctx, t.Tx, func(rows *sql.Rows) error { 123 return unmarshalRows(v, rows, false) 124 }, q, args...) 125 } 126 127 func begin(db *sql.DB) (trans, error) { 128 tx, err := db.Begin() 129 if err != nil { 130 return nil, err 131 } 132 133 return txSession{ 134 Tx: tx, 135 }, nil 136 } 137 138 func transact(ctx context.Context, db *commonSqlConn, b beginnable, 139 fn func(context.Context, Session) error) (err error) { 140 conn, err := db.connProv() 141 if err != nil { 142 db.onError(err) 143 return err 144 } 145 146 return transactOnConn(ctx, conn, b, fn) 147 } 148 149 func transactOnConn(ctx context.Context, conn *sql.DB, b beginnable, 150 fn func(context.Context, Session) error) (err error) { 151 var tx trans 152 tx, err = b(conn) 153 if err != nil { 154 return 155 } 156 157 defer func() { 158 if p := recover(); p != nil { 159 if e := tx.Rollback(); e != nil { 160 err = fmt.Errorf("recover from %#v, rollback failed: %w", p, e) 161 } else { 162 err = fmt.Errorf("recoveer from %#v", p) 163 } 164 } else if err != nil { 165 if e := tx.Rollback(); e != nil { 166 err = fmt.Errorf("transaction failed: %s, rollback failed: %w", err, e) 167 } 168 } else { 169 err = tx.Commit() 170 } 171 }() 172 173 return fn(ctx, tx) 174 }