github.com/movsb/taorm@v0.0.0-20201209183410-91bafb0b22a6/taorm/db.go (about) 1 package taorm 2 3 import ( 4 "database/sql" 5 ) 6 7 // DB wraps sql.DB. 8 type DB struct { 9 rdb *sql.DB // raw db 10 _SQLCommon 11 } 12 13 // NewDB news a DB. 14 func NewDB(db *sql.DB) *DB { 15 t := &DB{ 16 rdb: db, 17 _SQLCommon: db, 18 } 19 return t 20 } 21 22 // TxCall calls callback within transaction. 23 // It automatically catches and rethrows exceptions. 24 func (db *DB) TxCall(callback func(tx *DB) error) error { 25 rtx, err := db.rdb.Begin() 26 if err != nil { 27 return WrapError(err) 28 } 29 30 tx := &DB{ 31 rdb: db.rdb, 32 _SQLCommon: rtx, 33 } 34 35 var exception struct { 36 caught bool // user callback threw an exception 37 what interface{} // user thrown exception 38 } 39 40 catchCall := func() (err error) { 41 called := false 42 defer func() { 43 exception.what = recover() 44 exception.caught = !called 45 }() 46 err = callback(tx) 47 called = true 48 return 49 } 50 51 if err := catchCall(); err != nil { 52 rtx.Rollback() 53 return err // user error, not wrapped 54 } 55 56 if exception.caught { 57 rtx.Rollback() 58 panic(exception.what) // user exception, not wrapped 59 } 60 61 if err = rtx.Commit(); err != nil { 62 rtx.Rollback() 63 return WrapError(err) 64 } 65 66 return nil 67 } 68 69 // MustTxCall ... 70 func (db *DB) MustTxCall(callback func(tx *DB)) { 71 if err := db.TxCall(func(tx *DB) error { 72 callback(tx) 73 return nil 74 }); err != nil { 75 panic(err) 76 } 77 } 78 79 // Model ... 80 func (db *DB) Model(model interface{}) *Stmt { 81 stmt := &Stmt{ 82 db: db, 83 model: model, 84 tableNames: []string{}, 85 limit: -1, 86 offset: -1, 87 } 88 89 info, err := getRegistered(model) 90 if err != nil { 91 panic(WrapError(err)) 92 } 93 94 stmt.tableNames = append(stmt.tableNames, info.tableName) 95 96 stmt.info = info 97 98 return stmt 99 } 100 101 // From ... 102 func (db *DB) From(table interface{}) *Stmt { 103 s := &Stmt{ 104 db: db, 105 limit: -1, 106 offset: -1, 107 } 108 name, err := s.tryFindTableName(table) 109 if err != nil { 110 panic(WrapError(err)) 111 } 112 s.tableNames = append(s.tableNames, name) 113 s.fromTable = table 114 return s 115 } 116 117 // Raw executes a raw SQL query that returns rows. 118 func (db *DB) Raw(query string, args ...interface{}) Finder { 119 stmt := &Stmt{ 120 db: db, 121 } 122 stmt.raw.query = query 123 stmt.raw.args = args 124 return stmt 125 } 126 127 // --- stmt impl. --- 128 // 129 // Below are some commonly used functions to begin a preparing. 130 131 // MustExec ... 132 func (db *DB) MustExec(query string, args ...interface{}) sql.Result { 133 result, err := db.Exec(query, args...) 134 if err != nil { 135 panic(WrapError(err)) 136 } 137 return result 138 } 139 140 func (db *DB) _New() *Stmt { 141 stmt := &Stmt{ 142 db: db, 143 limit: -1, 144 offset: -1, 145 } 146 return stmt 147 } 148 149 // Select ... 150 func (db *DB) Select(fields string) *Stmt { 151 return db._New().Select(fields) 152 } 153 154 // Where ... 155 func (db *DB) Where(query string, args ...interface{}) *Stmt { 156 return db._New().Where(query, args...) 157 } 158 159 // WhereIf ... 160 func (db *DB) WhereIf(cond bool, query string, args ...interface{}) *Stmt { 161 return db._New().WhereIf(cond, query, args...) 162 } 163 164 // Find ... 165 func (db *DB) Find(out interface{}) error { 166 return db._New().Find(out) 167 } 168 169 // MustFind ... 170 func (db *DB) MustFind(out interface{}) { 171 db._New().MustFind(out) 172 }