github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/internal/sql/adapter/db.go (about) 1 package adapter 2 3 import ( 4 "context" 5 "database/sql" 6 "runtime" 7 8 "github.com/octohelm/storage/pkg/sqlbuilder" 9 ) 10 11 func Wrap(d *sql.DB, convertErr func(err error) error) DB { 12 return &db{ 13 DB: d, 14 option: option{ 15 convertErr: convertErr, 16 }, 17 } 18 } 19 20 type option struct { 21 convertErr func(err error) error 22 } 23 24 type db struct { 25 option 26 *sql.DB 27 } 28 29 func (d *db) Exec(ctx context.Context, expr sqlbuilder.SqlExpr) (sql.Result, error) { 30 e := sqlbuilder.ResolveExprContext(ctx, expr) 31 if sqlbuilder.IsNilExpr(e) { 32 return nil, nil 33 } 34 if err := e.Err(); err != nil { 35 return nil, d.convertErr(err) 36 } 37 38 if sqlDo := SqlDoFromContext(ctx); sqlDo != nil { 39 result, err := sqlDo.ExecContext(ctx, e.Query(), e.Args()...) 40 if err != nil { 41 return nil, d.convertErr(err) 42 } 43 return result, nil 44 } 45 46 result, err := d.ExecContext(ctx, e.Query(), e.Args()...) 47 if err != nil { 48 return nil, d.convertErr(err) 49 } 50 return result, nil 51 } 52 53 func (d *db) Query(ctx context.Context, expr sqlbuilder.SqlExpr) (*sql.Rows, error) { 54 e := sqlbuilder.ResolveExprContext(ctx, expr) 55 if sqlbuilder.IsNilExpr(e) { 56 return nil, nil 57 } 58 if err := e.Err(); err != nil { 59 return nil, err 60 } 61 if sqlDo := SqlDoFromContext(ctx); sqlDo != nil { 62 return sqlDo.QueryContext(ctx, e.Query(), e.Args()...) 63 } 64 return d.QueryContext(ctx, e.Query(), e.Args()...) 65 } 66 67 func (d *db) Transaction(ctx context.Context, action func(ctx context.Context) error) (err error) { 68 var inScopeOfTxnCreated = false 69 var txn *sql.Tx 70 71 if sqlDo := SqlDoFromContext(ctx); sqlDo != nil { 72 if tx, ok := sqlDo.(*sql.Tx); ok { 73 txn = tx 74 } 75 } 76 77 if txn == nil { 78 tx, err := d.BeginTx(ctx, nil) 79 if err != nil { 80 return err 81 } 82 inScopeOfTxnCreated = true 83 txn = tx 84 } 85 86 defer func() { 87 if p := recover(); p != nil { 88 // make sure rollack 89 _ = txn.Rollback() 90 91 switch e := p.(type) { 92 case runtime.Error: 93 panic(e) 94 case error: 95 err = e 96 default: 97 panic(e) 98 } 99 } else if inScopeOfTxnCreated { 100 if err != nil { 101 _ = txn.Rollback() 102 } else { 103 err = txn.Commit() 104 } 105 } 106 }() 107 108 return action(ContextWithSqlDo(ctx, txn)) 109 }