github.com/movsb/taorm@v0.0.0-20201209183410-91bafb0b22a6/taorm/db.go (about)

     1  package taorm
     2  
     3  import (
     4  	"database/sql"
     5  )
     6  
// DB wraps sql.DB.
//
// The same type is used both for the top-level handle returned by NewDB and
// for the transactional handle passed to a TxCall callback; the two differ
// only in what _SQLCommon holds.
type DB struct {
	// rdb is the raw database handle. TxCall always begins transactions on
	// it, including when TxCall is called on the *DB given to a callback.
	rdb *sql.DB

	// _SQLCommon holds the *sql.DB itself for a DB created by NewDB, or the
	// *sql.Tx begun by TxCall for the DB passed to its callback.
	// NOTE(review): presumably query methods such as Exec are promoted
	// through this embedding — confirm against its declaration.
	_SQLCommon
}
    12  
    13  // NewDB news a DB.
    14  func NewDB(db *sql.DB) *DB {
    15  	t := &DB{
    16  		rdb:        db,
    17  		_SQLCommon: db,
    18  	}
    19  	return t
    20  }
    21  
    22  // TxCall calls callback within transaction.
    23  // It automatically catches and rethrows exceptions.
    24  func (db *DB) TxCall(callback func(tx *DB) error) error {
    25  	rtx, err := db.rdb.Begin()
    26  	if err != nil {
    27  		return WrapError(err)
    28  	}
    29  
    30  	tx := &DB{
    31  		rdb:        db.rdb,
    32  		_SQLCommon: rtx,
    33  	}
    34  
    35  	var exception struct {
    36  		caught bool        // user callback threw an exception
    37  		what   interface{} // user thrown exception
    38  	}
    39  
    40  	catchCall := func() (err error) {
    41  		called := false
    42  		defer func() {
    43  			exception.what = recover()
    44  			exception.caught = !called
    45  		}()
    46  		err = callback(tx)
    47  		called = true
    48  		return
    49  	}
    50  
    51  	if err := catchCall(); err != nil {
    52  		rtx.Rollback()
    53  		return err // user error, not wrapped
    54  	}
    55  
    56  	if exception.caught {
    57  		rtx.Rollback()
    58  		panic(exception.what) // user exception, not wrapped
    59  	}
    60  
    61  	if err = rtx.Commit(); err != nil {
    62  		rtx.Rollback()
    63  		return WrapError(err)
    64  	}
    65  
    66  	return nil
    67  }
    68  
    69  // MustTxCall ...
    70  func (db *DB) MustTxCall(callback func(tx *DB)) {
    71  	if err := db.TxCall(func(tx *DB) error {
    72  		callback(tx)
    73  		return nil
    74  	}); err != nil {
    75  		panic(err)
    76  	}
    77  }
    78  
    79  // Model ...
    80  func (db *DB) Model(model interface{}) *Stmt {
    81  	stmt := &Stmt{
    82  		db:         db,
    83  		model:      model,
    84  		tableNames: []string{},
    85  		limit:      -1,
    86  		offset:     -1,
    87  	}
    88  
    89  	info, err := getRegistered(model)
    90  	if err != nil {
    91  		panic(WrapError(err))
    92  	}
    93  
    94  	stmt.tableNames = append(stmt.tableNames, info.tableName)
    95  
    96  	stmt.info = info
    97  
    98  	return stmt
    99  }
   100  
   101  // From ...
   102  func (db *DB) From(table interface{}) *Stmt {
   103  	s := &Stmt{
   104  		db:     db,
   105  		limit:  -1,
   106  		offset: -1,
   107  	}
   108  	name, err := s.tryFindTableName(table)
   109  	if err != nil {
   110  		panic(WrapError(err))
   111  	}
   112  	s.tableNames = append(s.tableNames, name)
   113  	s.fromTable = table
   114  	return s
   115  }
   116  
   117  // Raw executes a raw SQL query that returns rows.
   118  func (db *DB) Raw(query string, args ...interface{}) Finder {
   119  	stmt := &Stmt{
   120  		db: db,
   121  	}
   122  	stmt.raw.query = query
   123  	stmt.raw.args = args
   124  	return stmt
   125  }
   126  
// --- stmt impl. ---
//
// Below are commonly used shortcuts that run a query or start a new statement.
   130  
   131  // MustExec ...
   132  func (db *DB) MustExec(query string, args ...interface{}) sql.Result {
   133  	result, err := db.Exec(query, args...)
   134  	if err != nil {
   135  		panic(WrapError(err))
   136  	}
   137  	return result
   138  }
   139  
   140  func (db *DB) _New() *Stmt {
   141  	stmt := &Stmt{
   142  		db:     db,
   143  		limit:  -1,
   144  		offset: -1,
   145  	}
   146  	return stmt
   147  }
   148  
   149  // Select ...
   150  func (db *DB) Select(fields string) *Stmt {
   151  	return db._New().Select(fields)
   152  }
   153  
   154  // Where ...
   155  func (db *DB) Where(query string, args ...interface{}) *Stmt {
   156  	return db._New().Where(query, args...)
   157  }
   158  
   159  // WhereIf ...
   160  func (db *DB) WhereIf(cond bool, query string, args ...interface{}) *Stmt {
   161  	return db._New().WhereIf(cond, query, args...)
   162  }
   163  
   164  // Find ...
   165  func (db *DB) Find(out interface{}) error {
   166  	return db._New().Find(out)
   167  }
   168  
   169  // MustFind ...
   170  func (db *DB) MustFind(out interface{}) {
   171  	db._New().MustFind(out)
   172  }