github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/query_gen/accumulator.go (about)

/* WIP: A version of the builder that accumulates errors. We'll see whether we can unify the implementations at some point. */
     2  package qgen
     3  
     4  import (
     5  	"database/sql"
     6  	"log"
     7  	"strings"
     8  )
     9  
    10  var LogPrepares = true
    11  
    12  // So we don't have to do the qgen.Builder.Accumulator() boilerplate all the time
    13  func NewAcc() *Accumulator {
    14  	return Builder.Accumulator()
    15  }
    16  
// Accumulator generates SQL through an Adapter and prepares or executes it on a
// database connection. Instead of returning errors from every call, it records
// only the first error seen (see RecordError), so callers can build many
// statements and check FirstError once at the end.
//
// NOTE(review): if other code in the package builds Accumulator with a
// positional composite literal, this field order is load-bearing; verify
// before reordering.
type Accumulator struct {
	conn     *sql.DB // used by prepare, query, exec and Tx; set via SetConn
	adapter  Adapter // dialect-specific SQL generator; set via SetAdapter
	firstErr error   // first non-nil error passed to RecordError
}
    22  
    23  func (acc *Accumulator) SetConn(conn *sql.DB) {
    24  	acc.conn = conn
    25  }
    26  
    27  func (acc *Accumulator) SetAdapter(name string) error {
    28  	adap, err := GetAdapter(name)
    29  	if err != nil {
    30  		return err
    31  	}
    32  	acc.adapter = adap
    33  	return nil
    34  }
    35  
    36  func (acc *Accumulator) GetAdapter() Adapter {
    37  	return acc.adapter
    38  }
    39  
    40  func (acc *Accumulator) FirstError() error {
    41  	return acc.firstErr
    42  }
    43  
    44  func (acc *Accumulator) RecordError(err error) {
    45  	if err == nil {
    46  		return
    47  	}
    48  	if acc.firstErr == nil {
    49  		acc.firstErr = err
    50  	}
    51  }
    52  
    53  func (acc *Accumulator) prepare(res string, err error) *sql.Stmt {
    54  	// TODO: Can we make this less noisy on debug mode?
    55  	if LogPrepares {
    56  		log.Print("res: ", res)
    57  	}
    58  	if err != nil {
    59  		acc.RecordError(err)
    60  		return nil
    61  	}
    62  	stmt, err := acc.conn.Prepare(res)
    63  	acc.RecordError(err)
    64  	return stmt
    65  }
    66  
    67  func (acc *Accumulator) RawPrepare(res string) *sql.Stmt {
    68  	return acc.prepare(res, nil)
    69  }
    70  
    71  func (acc *Accumulator) query(q string, args ...interface{}) (rows *sql.Rows, err error) {
    72  	err = acc.FirstError()
    73  	if err != nil {
    74  		return rows, err
    75  	}
    76  	return acc.conn.Query(q, args...)
    77  }
    78  
    79  func (acc *Accumulator) exec(q string, args ...interface{}) (res sql.Result, err error) {
    80  	err = acc.FirstError()
    81  	if err != nil {
    82  		return res, err
    83  	}
    84  	return acc.conn.Exec(q, args...)
    85  }
    86  
    87  func (acc *Accumulator) Tx(handler func(*TransactionBuilder) error) {
    88  	tx, err := acc.conn.Begin()
    89  	if err != nil {
    90  		acc.RecordError(err)
    91  		return
    92  	}
    93  	err = handler(&TransactionBuilder{tx, acc.adapter, nil})
    94  	if err != nil {
    95  		tx.Rollback()
    96  		acc.RecordError(err)
    97  		return
    98  	}
    99  	acc.RecordError(tx.Commit())
   100  }
   101  
   102  func (acc *Accumulator) SimpleSelect(table, columns, where, orderby, limit string) *sql.Stmt {
   103  	return acc.prepare(acc.adapter.SimpleSelect("", table, columns, where, orderby, limit))
   104  }
   105  
   106  func (acc *Accumulator) SimpleCount(table, where, limit string) *sql.Stmt {
   107  	return acc.prepare(acc.adapter.SimpleCount("", table, where, limit))
   108  }
   109  
   110  func (acc *Accumulator) SimpleLeftJoin(table1, table2, columns, joiners, where, orderby, limit string) *sql.Stmt {
   111  	return acc.prepare(acc.adapter.SimpleLeftJoin("", table1, table2, columns, joiners, where, orderby, limit))
   112  }
   113  
   114  func (acc *Accumulator) SimpleInnerJoin(table1, table2, columns, joiners, where, orderby, limit string) *sql.Stmt {
   115  	return acc.prepare(acc.adapter.SimpleInnerJoin("", table1, table2, columns, joiners, where, orderby, limit))
   116  }
   117  
   118  func (acc *Accumulator) CreateTable(table, charset, collation string, columns []DBTableColumn, keys []DBTableKey) *sql.Stmt {
   119  	return acc.prepare(acc.adapter.CreateTable("", table, charset, collation, columns, keys))
   120  }
   121  
   122  func (acc *Accumulator) SimpleInsert(table, columns, fields string) *sql.Stmt {
   123  	return acc.prepare(acc.adapter.SimpleInsert("", table, columns, fields))
   124  }
   125  
   126  func (acc *Accumulator) SimpleBulkInsert(table, cols string, fieldSet []string) *sql.Stmt {
   127  	return acc.prepare(acc.adapter.SimpleBulkInsert("", table, cols, fieldSet))
   128  }
   129  
   130  func (acc *Accumulator) SimpleInsertSelect(ins DBInsert, sel DBSelect) *sql.Stmt {
   131  	return acc.prepare(acc.adapter.SimpleInsertSelect("", ins, sel))
   132  }
   133  
   134  func (acc *Accumulator) SimpleInsertLeftJoin(ins DBInsert, sel DBJoin) *sql.Stmt {
   135  	return acc.prepare(acc.adapter.SimpleInsertLeftJoin("", ins, sel))
   136  }
   137  
   138  func (acc *Accumulator) SimpleInsertInnerJoin(ins DBInsert, sel DBJoin) *sql.Stmt {
   139  	return acc.prepare(acc.adapter.SimpleInsertInnerJoin("", ins, sel))
   140  }
   141  
   142  func (acc *Accumulator) SimpleUpdate(table, set, where string) *sql.Stmt {
   143  	return acc.prepare(acc.adapter.SimpleUpdate(qUpdate(table, set, where)))
   144  }
   145  
   146  func (acc *Accumulator) SimpleUpdateSelect(table, set, table2, cols, where, orderby, limit string) *sql.Stmt {
   147  	pre := qUpdate(table, set, "").WhereQ(acc.GetAdapter().Builder().Select().Table(table2).Columns(cols).Where(where).Orderby(orderby).Limit(limit))
   148  	return acc.prepare(acc.adapter.SimpleUpdateSelect(pre))
   149  }
   150  
   151  func (acc *Accumulator) SimpleDelete(table, where string) *sql.Stmt {
   152  	return acc.prepare(acc.adapter.SimpleDelete("", table, where))
   153  }
   154  
   155  // I don't know why you need this, but here it is x.x
   156  func (acc *Accumulator) Purge(table string) *sql.Stmt {
   157  	return acc.prepare(acc.adapter.Purge("", table))
   158  }
   159  
   160  func (acc *Accumulator) prepareTx(tx *sql.Tx, res string, err error) (stmt *sql.Stmt) {
   161  	if err != nil {
   162  		acc.RecordError(err)
   163  		return nil
   164  	}
   165  	stmt, err = tx.Prepare(res)
   166  	acc.RecordError(err)
   167  	return stmt
   168  }
   169  
   170  // These ones support transactions
   171  func (acc *Accumulator) SimpleSelectTx(tx *sql.Tx, table, columns, where, orderby, limit string) (stmt *sql.Stmt) {
   172  	res, err := acc.adapter.SimpleSelect("", table, columns, where, orderby, limit)
   173  	return acc.prepareTx(tx, res, err)
   174  }
   175  
   176  func (acc *Accumulator) SimpleCountTx(tx *sql.Tx, table, where, limit string) (stmt *sql.Stmt) {
   177  	res, err := acc.adapter.SimpleCount("", table, where, limit)
   178  	return acc.prepareTx(tx, res, err)
   179  }
   180  
   181  func (acc *Accumulator) SimpleLeftJoinTx(tx *sql.Tx, table1, table2, columns, joiners, where, orderby, limit string) (stmt *sql.Stmt) {
   182  	res, err := acc.adapter.SimpleLeftJoin("", table1, table2, columns, joiners, where, orderby, limit)
   183  	return acc.prepareTx(tx, res, err)
   184  }
   185  
   186  func (acc *Accumulator) SimpleInnerJoinTx(tx *sql.Tx, table1, table2, columns, joiners, where, orderby, limit string) (stmt *sql.Stmt) {
   187  	res, err := acc.adapter.SimpleInnerJoin("", table1, table2, columns, joiners, where, orderby, limit)
   188  	return acc.prepareTx(tx, res, err)
   189  }
   190  
   191  func (acc *Accumulator) CreateTableTx(tx *sql.Tx, table, charset, collation string, columns []DBTableColumn, keys []DBTableKey) (stmt *sql.Stmt) {
   192  	res, err := acc.adapter.CreateTable("", table, charset, collation, columns, keys)
   193  	return acc.prepareTx(tx, res, err)
   194  }
   195  
   196  func (acc *Accumulator) SimpleInsertTx(tx *sql.Tx, table, columns, fields string) (stmt *sql.Stmt) {
   197  	res, err := acc.adapter.SimpleInsert("", table, columns, fields)
   198  	return acc.prepareTx(tx, res, err)
   199  }
   200  
   201  func (acc *Accumulator) SimpleInsertSelectTx(tx *sql.Tx, ins DBInsert, sel DBSelect) (stmt *sql.Stmt) {
   202  	res, err := acc.adapter.SimpleInsertSelect("", ins, sel)
   203  	return acc.prepareTx(tx, res, err)
   204  }
   205  
   206  func (acc *Accumulator) SimpleInsertLeftJoinTx(tx *sql.Tx, ins DBInsert, sel DBJoin) (stmt *sql.Stmt) {
   207  	res, err := acc.adapter.SimpleInsertLeftJoin("", ins, sel)
   208  	return acc.prepareTx(tx, res, err)
   209  }
   210  
   211  func (acc *Accumulator) SimpleInsertInnerJoinTx(tx *sql.Tx, ins DBInsert, sel DBJoin) (stmt *sql.Stmt) {
   212  	res, err := acc.adapter.SimpleInsertInnerJoin("", ins, sel)
   213  	return acc.prepareTx(tx, res, err)
   214  }
   215  
   216  func (acc *Accumulator) SimpleUpdateTx(tx *sql.Tx, table, set, where string) (stmt *sql.Stmt) {
   217  	res, err := acc.adapter.SimpleUpdate(qUpdate(table, set, where))
   218  	return acc.prepareTx(tx, res, err)
   219  }
   220  
   221  func (acc *Accumulator) SimpleDeleteTx(tx *sql.Tx, table, where string) (stmt *sql.Stmt) {
   222  	res, err := acc.adapter.SimpleDelete("", table, where)
   223  	return acc.prepareTx(tx, res, err)
   224  }
   225  
   226  // I don't know why you need this, but here it is x.x
   227  func (acc *Accumulator) PurgeTx(tx *sql.Tx, table string) (stmt *sql.Stmt) {
   228  	res, err := acc.adapter.Purge("", table)
   229  	return acc.prepareTx(tx, res, err)
   230  }
   231  
   232  func (acc *Accumulator) Delete(table string) *accDeleteBuilder {
   233  	return &accDeleteBuilder{table, "", nil, acc}
   234  }
   235  
   236  func (acc *Accumulator) Update(table string) *accUpdateBuilder {
   237  	return &accUpdateBuilder{qUpdate(table, "", ""), acc}
   238  }
   239  
   240  func (acc *Accumulator) Select(table string) *AccSelectBuilder {
   241  	return &AccSelectBuilder{table, "", "", "", "", nil, nil, "", acc}
   242  }
   243  
   244  func (acc *Accumulator) Exists(tbl, col string) *AccSelectBuilder {
   245  	return acc.Select(tbl).Columns(col).Where(col + "=?")
   246  }
   247  
   248  func (acc *Accumulator) Insert(table string) *accInsertBuilder {
   249  	return &accInsertBuilder{table, "", "", acc}
   250  }
   251  
   252  func (acc *Accumulator) BulkInsert(table string) *accBulkInsertBuilder {
   253  	return &accBulkInsertBuilder{table, "", nil, acc}
   254  }
   255  
   256  func (acc *Accumulator) Count(table string) *accCountBuilder {
   257  	return &accCountBuilder{table, "", "", nil, nil, "", acc}
   258  }
   259  
// SimpleModel bundles the three prepared statements built by
// Accumulator.SimpleModel for a single table. A statement may be nil if its
// preparation failed; check the accumulator's FirstError before use.
type SimpleModel struct {
	delete *sql.Stmt // DELETE ... WHERE primary=?
	create *sql.Stmt // INSERT of every model column
	update *sql.Stmt // UPDATE setting every model column, WHERE primary=?
}
   265  
   266  func (acc *Accumulator) SimpleModel(tbl, colstr, primary string) SimpleModel {
   267  	var qlist, uplist string
   268  	for _, col := range strings.Split(colstr, ",") {
   269  		qlist += "?,"
   270  		uplist += col + "=?,"
   271  	}
   272  	if len(qlist) > 0 {
   273  		qlist = qlist[0 : len(qlist)-1]
   274  		uplist = uplist[0 : len(uplist)-1]
   275  	}
   276  
   277  	where := primary + "=?"
   278  	return SimpleModel{
   279  		delete: acc.Delete(tbl).Where(where).Prepare(),
   280  		create: acc.Insert(tbl).Columns(colstr).Fields(qlist).Prepare(),
   281  		update: acc.Update(tbl).Set(uplist).Where(where).Prepare(),
   282  	}
   283  }
   284  
   285  func (m SimpleModel) Delete(keyVal interface{}) error {
   286  	_, err := m.delete.Exec(keyVal)
   287  	return err
   288  }
   289  
   290  func (m SimpleModel) Update(args ...interface{}) error {
   291  	_, err := m.update.Exec(args...)
   292  	return err
   293  }
   294  
   295  func (m SimpleModel) Create(args ...interface{}) error {
   296  	_, err := m.create.Exec(args...)
   297  	return err
   298  }
   299  
   300  func (m SimpleModel) CreateID(args ...interface{}) (int, error) {
   301  	res, err := m.create.Exec(args...)
   302  	if err != nil {
   303  		return 0, err
   304  	}
   305  	lastID, err := res.LastInsertId()
   306  	return int(lastID), err
   307  }
   308  
   309  func (acc *Accumulator) Model(table string) *accModelBuilder {
   310  	return &accModelBuilder{table, "", acc}
   311  }
   312  
// accModelBuilder collects the configuration for a table model before it is
// built. Model constructs it with a positional literal, so field order matters.
type accModelBuilder struct {
	table   string // table the model targets
	primary string // primary key column; set via Primary

	build *Accumulator // accumulator that created this builder
}
   319  
   320  func (b *accModelBuilder) Primary(col string) *accModelBuilder {
   321  	b.primary = col
   322  	return b
   323  }