github.com/movsb/taorm@v0.0.0-20201209183410-91bafb0b22a6/taorm/stmt.go (about)

     1  package taorm
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"reflect"
     9  	"regexp"
    10  	"strings"
    11  
    12  	"github.com/movsb/taorm/filter"
    13  )
    14  
// _Where is one condition fragment: a query template plus the arguments
// bound to its '?' placeholders, in order of appearance.
type _Where struct {
	// query is the template text; build treats every '?' rune as a placeholder.
	query string
	// args holds one entry per placeholder; build expands slice entries.
	args  []interface{}
}
    20  
    21  func (w _Where) build() (query string, args []interface{}) {
    22  	sb := bytes.NewBuffer(nil)
    23  	sb.Grow(len(query)) // should we reserve capacity for slice too?
    24  	var i int
    25  	for _, c := range w.query {
    26  		switch c {
    27  		case '?':
    28  			if i >= len(w.args) {
    29  				panic(fmt.Errorf("err where args count"))
    30  			}
    31  			value := reflect.ValueOf(w.args[i])
    32  			if value.Kind() == reflect.Slice {
    33  				n := value.Len()
    34  				sb.WriteString(createSQLInMarks(n))
    35  				for j := 0; j < n; j++ {
    36  					args = append(args, value.Index(j).Interface())
    37  				}
    38  			} else {
    39  				sb.WriteByte('?')
    40  				args = append(args, w.args[i])
    41  			}
    42  			i++
    43  		default:
    44  			sb.WriteRune(c)
    45  		}
    46  	}
    47  	if i != len(w.args) {
    48  		panic(fmt.Errorf("err where args count"))
    49  	}
    50  	return sb.String(), args
    51  }
    52  
// _Expr is a raw SQL expression with its bound arguments.
//
// e.g.: in `UPDATE sth SET left = right`, `right` is the expression.
//
// It has the same underlying type as _Where, so a value can be converted to
// _Where in order to be built.
//
// TODO expr args cannot be slice.
type _Expr _Where
    59  
    60  // Expr creates an expression for Update* operations.
    61  func Expr(expr string, args ...interface{}) _Expr {
    62  	return _Expr{
    63  		query: expr,
    64  		args:  args,
    65  	}
    66  }
    67  
// _RawQuery is a caller-supplied query and its arguments. A statement treats
// it as unset while query is empty.
type _RawQuery struct {
	query string
	args  []interface{}
}
    72  
// Stmt is an SQL statement under construction. Builder methods mutate it and
// return it for chaining; the build*/execute methods render and run it.
type Stmt struct {
	// db is the handle the rendered statement is executed against.
	db              *DB
	// raw is not set if query == "". When set, buildSelect returns it
	// verbatim and the create/update/delete builders panic.
	raw             _RawQuery
	// model is the value used for inserts and for the primary-key condition
	// added by buildWheres.
	model           interface{}
	// fromTable is consulted by Filter and passed to buildSelect by Count.
	fromTable       interface{}
	// info is the registered struct information used by buildWheres and
	// buildUpdateModel.
	info            *_StructInfo
	// tableNames are joined with "," in FROM / UPDATE / DELETE FROM.
	tableNames      []string
	// innerJoinTables holds pre-rendered " INNER JOIN ..." fragments.
	innerJoinTables []string
	// fields holds the select lists added by Select; each entry may itself
	// be a comma-separated list.
	fields          []string
	// ands are the conditions combined with AND by buildWheres.
	ands            []_Where
	// groupBy, having and orderBy are the raw clause bodies (without keyword).
	groupBy         string
	having          string
	orderBy         string
	// limit and offset feed buildLimit; nothing is rendered unless limit > 0.
	limit           int64
	offset          int64
}
    90  
    91  // From ...
    92  // table can be either string or struct.
    93  func (s *Stmt) From(table interface{}) *Stmt {
    94  	switch typed := table.(type) {
    95  	case string:
    96  		s.tableNames = append(s.tableNames, typed)
    97  	default:
    98  		name, err := s.tryFindTableName(table)
    99  		if err != nil {
   100  			panic(WrapError(err))
   101  		}
   102  		s.tableNames = append(s.tableNames, name)
   103  	}
   104  	return s
   105  }
   106  
   107  // InnerJoin ...
   108  func (s *Stmt) InnerJoin(table interface{}, on string) *Stmt {
   109  	name := ""
   110  	switch typed := table.(type) {
   111  	case string:
   112  		name = typed
   113  	default:
   114  		n, err := s.tryFindTableName(typed)
   115  		if err != nil {
   116  			panic(WrapError(err))
   117  		}
   118  		name = n
   119  	}
   120  
   121  	q := " INNER JOIN " + name
   122  	if on != "" {
   123  		q += " ON " + on
   124  	}
   125  	s.innerJoinTables = append(s.innerJoinTables, q)
   126  	return s
   127  }
   128  
   129  // Select ...
   130  func (s *Stmt) Select(fields string) *Stmt {
   131  	if len(fields) > 0 {
   132  		s.fields = append(s.fields, fields)
   133  	}
   134  	return s
   135  }
   136  
   137  // Where ...
   138  func (s *Stmt) Where(query string, args ...interface{}) *Stmt {
   139  	w := _Where{
   140  		query: query,
   141  		args:  args,
   142  	}
   143  	s.ands = append(s.ands, w)
   144  	return s
   145  }
   146  
   147  // WhereIf ...
   148  func (s *Stmt) WhereIf(cond bool, query string, args ...interface{}) *Stmt {
   149  	if cond {
   150  		s.Where(query, args...)
   151  	}
   152  	return s
   153  }
   154  
// GroupBy sets the body of the GROUP BY clause, replacing any previous value.
// The text is inserted into the query as is.
func (s *Stmt) GroupBy(groupBy string) *Stmt {
	s.groupBy = groupBy
	return s
}
   160  
// Having sets the body of the HAVING clause, replacing any previous value.
// The text is inserted into the query as is.
func (s *Stmt) Having(having string) *Stmt {
	s.having = having
	return s
}
   166  
// OrderBy sets the body of the ORDER BY clause, replacing any previous value.
// Several comma-separated terms may be given in one string; they are
// validated when the statement is built.
// TODO multiple orderbys
func (s *Stmt) OrderBy(orderBy string) *Stmt {
	s.orderBy = orderBy
	return s
}
   173  
// Limit sets the row limit. No LIMIT clause is rendered unless it is
// greater than zero.
func (s *Stmt) Limit(limit int64) *Stmt {
	s.limit = limit
	return s
}
   179  
// Offset sets the row offset. It is only rendered together with a positive
// limit.
func (s *Stmt) Offset(offset int64) *Stmt {
	s.offset = offset
	return s
}
   185  
   186  // Filter ... may throw exceptions
   187  // Filter has to know whom to filter. So before filtering, call From(), Model()
   188  // or pass the third argument.
   189  func (s *Stmt) Filter(expr string, mapper filter.Mapper, _Struct ...interface{}) *Stmt {
   190  	var info *_StructInfo
   191  
   192  	if s.info != nil {
   193  		info = s.info
   194  	} else if s.model != nil {
   195  		inf, err := getRegistered(s.model)
   196  		if err != nil {
   197  			panic(WrapError(err))
   198  		}
   199  		info = inf
   200  	} else if s.fromTable != nil {
   201  		inf, err := getRegistered(s.fromTable)
   202  		if err != nil {
   203  			panic(WrapError(err))
   204  		}
   205  		info = inf
   206  	} else if len(_Struct) > 0 { // Warn: == 1
   207  		inf, err := getRegistered(_Struct[0])
   208  		if err != nil {
   209  			panic(WrapError(err))
   210  		}
   211  		info = inf
   212  	} else {
   213  		panic(WrapError(errors.New("cannot deduce what to filter")))
   214  	}
   215  
   216  	query, args, err := filter.Filter(
   217  		func(field string) reflect.Type {
   218  			return info.fields[field]._type // maybe not exist
   219  		},
   220  		expr,
   221  		mapper,
   222  		info.tableName,
   223  	)
   224  	if err != nil {
   225  		panic(WrapError(err))
   226  	}
   227  	s.WhereIf(query != "", query, args...)
   228  	return s
   229  }
   230  
   231  // noWheres returns true if no SQL conditions.
   232  // Includes and, or.
   233  func (s *Stmt) noWheres() bool {
   234  	return len(s.ands) <= 0
   235  }
   236  
   237  func (s *Stmt) buildWheres() (string, []interface{}) {
   238  	if s.model != nil {
   239  		id, ok := s.info.getPrimaryKey(s.model)
   240  		s.WhereIf(ok, "id=?", id)
   241  	}
   242  
   243  	if s.noWheres() {
   244  		return "", nil
   245  	}
   246  
   247  	var args []interface{}
   248  	sb := bytes.NewBuffer(nil)
   249  	sb.WriteString(" WHERE ")
   250  	for i, w := range s.ands {
   251  		if i > 0 {
   252  			sb.WriteString(" AND ")
   253  		}
   254  		query, xargs := w.build()
   255  		sb.WriteString("(" + query + ")")
   256  		args = append(args, xargs...)
   257  	}
   258  	return sb.String(), args
   259  }
   260  
   261  func (s *Stmt) buildCreate() (*_StructInfo, string, []interface{}, error) {
   262  	panicIf(len(s.tableNames) != 1, "model length is not 1")
   263  	panicIf(s.raw.query != "", "cannot use raw here")
   264  	info, err := getRegistered(s.model)
   265  	if err != nil {
   266  		return info, "", nil, err
   267  	}
   268  	args := info.ifacesOf(s.model)
   269  	if len(args) == 0 {
   270  		return info, "", nil, ErrNoFields
   271  	}
   272  	return info, info.insertstr, args, nil
   273  }
   274  
   275  func (s *Stmt) tryFindTableName(out interface{}) (string, error) {
   276  	info, err := getRegistered(out)
   277  	if err != nil {
   278  		return "", err
   279  	}
   280  	if info.tableName == "" {
   281  		return "", fmt.Errorf("trying to use auto-registered struct table name")
   282  	}
   283  	return info.tableName, nil
   284  }
   285  
   286  func (s *Stmt) buildSelect(out interface{}, isCount bool) (string, []interface{}, error) {
   287  	if s.raw.query != "" {
   288  		return s.raw.query, s.raw.args, nil
   289  	}
   290  
   291  	if len(s.tableNames) == 0 {
   292  		name, err := s.tryFindTableName(out)
   293  		if err != nil {
   294  			return "", nil, err
   295  		}
   296  		s.tableNames = append(s.tableNames, name)
   297  	}
   298  
   299  	panicIf(len(s.tableNames) == 0, "model is empty")
   300  
   301  	var strFields string
   302  
   303  	if isCount {
   304  		strFields = "COUNT(1)"
   305  	} else {
   306  		fields := []string{}
   307  		if len(s.fields) == 0 {
   308  			if len(s.innerJoinTables) == 0 {
   309  				fields = []string{"*"}
   310  			} else {
   311  				fields = []string{s.tableNames[0] + ".*"}
   312  			}
   313  		} else {
   314  			if len(s.innerJoinTables) == 0 || len(s.fields) == 1 && s.fields[0] == "*" {
   315  				fields = s.fields
   316  			} else {
   317  				for _, list := range s.fields {
   318  					slice := strings.Split(list, ",")
   319  					for _, field := range slice {
   320  						index := strings.IndexByte(field, '.')
   321  						if index == -1 {
   322  							f := s.tableNames[0] + "." + field
   323  							fields = append(fields, f)
   324  						} else {
   325  							fields = append(fields, field)
   326  						}
   327  					}
   328  				}
   329  			}
   330  		}
   331  		strFields = strings.Join(fields, ",")
   332  	}
   333  
   334  	query := `SELECT ` + strFields + ` FROM ` + strings.Join(s.tableNames, ",")
   335  	if len(s.innerJoinTables) > 0 {
   336  		query += strings.Join(s.innerJoinTables, " ")
   337  	}
   338  
   339  	var args []interface{}
   340  
   341  	whereQuery, whereArgs := s.buildWheres()
   342  	query += whereQuery
   343  	args = append(args, whereArgs...)
   344  
   345  	query += s.buildGroupBy()
   346  	query += s.buildHaving()
   347  
   348  	if orderBy, err := s.buildOrderBy(); err != nil {
   349  		return "", nil, err
   350  	} else {
   351  		if orderBy != "" {
   352  			query += orderBy
   353  		}
   354  	}
   355  	query += s.buildLimit()
   356  
   357  	return query, args, nil
   358  }
   359  
   360  func (s *Stmt) buildUpdateMap(fields map[string]interface{}) (string, []interface{}, error) {
   361  	panicIf(len(s.tableNames) == 0, "model is empty")
   362  	panicIf(s.raw.query != "", "cannot use raw here")
   363  	query := `UPDATE ` + strings.Join(s.tableNames, ",") + ` SET `
   364  
   365  	if len(fields) == 0 {
   366  		return "", nil, ErrNoFields
   367  	}
   368  
   369  	updates := make([]string, 0, len(fields))
   370  	args := make([]interface{}, 0, len(fields))
   371  
   372  	for field, value := range fields {
   373  		switch tv := value.(type) {
   374  		case _Expr:
   375  			eq, ea := _Where(tv).build()
   376  			pair := field + "=" + eq
   377  			updates = append(updates, pair)
   378  			args = append(args, ea...)
   379  		default:
   380  			pair := field + "=?"
   381  			updates = append(updates, pair)
   382  			args = append(args, value)
   383  		}
   384  	}
   385  
   386  	query += strings.Join(updates, ",")
   387  
   388  	whereQuery, whereArgs := s.buildWheres()
   389  	query += whereQuery
   390  	args = append(args, whereArgs...)
   391  
   392  	query += s.buildLimit()
   393  
   394  	return query, args, nil
   395  }
   396  
   397  func (s *Stmt) buildUpdateModel(model interface{}) (string, []interface{}, error) {
   398  	panicIf(len(s.tableNames) == 0, "model is empty")
   399  	panicIf(s.raw.query != "", "cannot use raw here")
   400  	query := s.info.updatestr
   401  	args := s.info.ifacesOf(model)
   402  	whereQuery, whereArgs := s.buildWheres()
   403  	query += whereQuery
   404  	args = append(args, whereArgs...)
   405  	return query, args, nil
   406  }
   407  
   408  func (s *Stmt) buildDelete() (string, []interface{}, error) {
   409  	panicIf(len(s.tableNames) == 0, "model is empty")
   410  	panicIf(s.raw.query != "", "cannot use raw here")
   411  	var args []interface{}
   412  	query := `DELETE FROM ` + strings.Join(s.tableNames, ",")
   413  
   414  	whereQuery, whereArgs := s.buildWheres()
   415  	query += whereQuery
   416  	args = append(args, whereArgs...)
   417  
   418  	query += s.buildLimit()
   419  
   420  	return query, args, nil
   421  }
   422  
   423  func (s *Stmt) buildGroupBy() (groupBy string) {
   424  	if s.groupBy != "" {
   425  		groupBy = ` GROUP BY ` + s.groupBy
   426  	}
   427  	return
   428  }
   429  
   430  func (s *Stmt) buildHaving() (having string) {
   431  	if s.having != `` {
   432  		having = ` HAVING ` + s.having
   433  	}
   434  	return
   435  }
   436  
// regexpOrderBy validates a single ORDER BY term: an optionally
// table-qualified column name made of word characters, optionally followed by
// one more word (the direction), with optional surrounding spaces.
// Capture groups: 1 = qualified column, 2 = "table." prefix, 3 = bare column,
// 4 = trailing word.
var regexpOrderBy = regexp.MustCompile(`^ *((\w+\.)?(\w+)) *(\w+)? *$`)
   438  
   439  func (s *Stmt) buildOrderBy() (string, error) {
   440  	orderBy := " ORDER BY "
   441  	if s.orderBy == "" {
   442  		return "", nil
   443  	}
   444  	parts := strings.Split(s.orderBy, ",")
   445  	orderBys := []string{}
   446  	for _, part := range parts {
   447  		if !regexpOrderBy.MatchString(part) {
   448  			return ``, fmt.Errorf(`invalid order_by: %s`, part)
   449  		}
   450  		orderBys = append(orderBys, part)
   451  
   452  		// these are for automatically adding table names to fields in order_by etc.
   453  		// they are commented out because of custom field name doesn't belong to some table.
   454  		// currently I don't know how to handle this correctly.
   455  		//
   456  		// matches := regexpOrderBy.FindStringSubmatch(part)
   457  		// if len(matches) != 5 {
   458  		// 	return "", errors.New("invalid orderby")
   459  		// }
   460  		// table := matches[2]
   461  		// column := matches[1]
   462  		// order := matches[4]
   463  		// // avoid column ambiguous
   464  		// // "Error 1052: Column 'created_at' in order clause is ambiguous"
   465  		// if table == "" && len(s.tableNames)+len(s.innerJoinTables) > 1 {
   466  		// 	column = s.tableNames[0] + "." + column
   467  		// }
   468  		// if order != "" {
   469  		// 	column += " " + order
   470  		// }
   471  		// orderBys = append(orderBys, column)
   472  	}
   473  	orderBy += strings.Join(orderBys, ",")
   474  	return orderBy, nil
   475  }
   476  
   477  func (s *Stmt) buildLimit() (limit string) {
   478  	if s.limit > 0 {
   479  		limit += ` LIMIT ` + fmt.Sprint(s.limit)
   480  		if s.offset >= 0 {
   481  			limit += ` OFFSET ` + fmt.Sprint(s.offset)
   482  		}
   483  	}
   484  	return
   485  }
   486  
   487  // Create ...
   488  func (s *Stmt) Create() error {
   489  	info, query, args, err := s.buildCreate()
   490  	if err != nil {
   491  		return WrapError(err)
   492  	}
   493  
   494  	dumpSQL(query, args...)
   495  
   496  	result, err := s.db.Exec(query, args...)
   497  	if err != nil {
   498  		return WrapError(err)
   499  	}
   500  
   501  	id, err := result.LastInsertId()
   502  	if err != nil {
   503  		return WrapError(err)
   504  	}
   505  
   506  	info.setPrimaryKey(s.model, id)
   507  
   508  	return nil
   509  }
   510  
   511  // MustCreate ...
   512  func (s *Stmt) MustCreate() {
   513  	if err := s.Create(); err != nil {
   514  		panic(err)
   515  	}
   516  }
   517  
   518  // CreateSQL ...
   519  func (s *Stmt) CreateSQL() string {
   520  	_, query, args, err := s.buildCreate()
   521  	if err != nil {
   522  		panic(WrapError(err))
   523  	}
   524  	return strSQL(query, args...)
   525  }
   526  
   527  // Find ...
   528  func (s *Stmt) Find(out interface{}) error {
   529  	query, args, err := s.buildSelect(out, false)
   530  	if err != nil {
   531  		return WrapError(err)
   532  	}
   533  
   534  	dumpSQL(query, args...)
   535  	return ScanRows(out, s.db, query, args...)
   536  }
   537  
   538  // MustFind ...
   539  func (s *Stmt) MustFind(out interface{}) {
   540  	if err := s.Find(out); err != nil {
   541  		panic(err)
   542  	}
   543  }
   544  
   545  // FindSQL ...
   546  func (s *Stmt) FindSQL() string {
   547  	query, args, err := s.buildSelect(s.model, false)
   548  	if err != nil {
   549  		panic(WrapError(err))
   550  	}
   551  	return strSQL(query, args...)
   552  }
   553  
   554  // Count ...
   555  func (s *Stmt) Count(out interface{}) error {
   556  	query, args, err := s.buildSelect(s.fromTable, true)
   557  	if err != nil {
   558  		return WrapError(err)
   559  	}
   560  
   561  	dumpSQL(query, args...)
   562  	return ScanRows(out, s.db, query, args...)
   563  }
   564  
   565  // MustCount ...
   566  func (s *Stmt) MustCount(out interface{}) {
   567  	if err := s.Count(out); err != nil {
   568  		panic(err)
   569  	}
   570  }
   571  
   572  // CountSQL ...
   573  func (s *Stmt) CountSQL() string {
   574  	query, args, err := s.buildSelect(s.fromTable, true)
   575  	if err != nil {
   576  		panic(WrapError(err))
   577  	}
   578  	return strSQL(query, args...)
   579  }
   580  
   581  func (s *Stmt) updateMap(fields M, anyway bool) (sql.Result, error) {
   582  	if len(fields) == 0 {
   583  		return nil, ErrNoFields
   584  	}
   585  
   586  	query, args, err := s.buildUpdateMap(fields)
   587  	if err != nil {
   588  		return nil, err
   589  	}
   590  
   591  	if !anyway && s.noWheres() {
   592  		return nil, ErrNoWhere
   593  	}
   594  
   595  	dumpSQL(query, args...)
   596  
   597  	res, err := s.db.Exec(query, args...)
   598  	if err != nil {
   599  		return nil, err
   600  	}
   601  
   602  	return res, nil
   603  }
   604  
   605  func (s *Stmt) updateModel(model interface{}) (sql.Result, error) {
   606  	query, args, err := s.buildUpdateModel(model)
   607  	if err != nil {
   608  		return nil, err
   609  	}
   610  
   611  	dumpSQL(query, args...)
   612  
   613  	res, err := s.db.Exec(query, args...)
   614  	if err != nil {
   615  		return nil, err
   616  	}
   617  
   618  	return res, nil
   619  }
   620  
   621  // UpdateMap ...
   622  func (s *Stmt) UpdateMap(updates M) (sql.Result, error) {
   623  	res, err := s.updateMap(updates, false)
   624  	return res, WrapError(err)
   625  }
   626  
   627  // UpdateMapAnyway ...
   628  func (s *Stmt) UpdateMapAnyway(updates M) (sql.Result, error) {
   629  	res, err := s.updateMap(updates, true)
   630  	return res, WrapError(err)
   631  }
   632  
   633  // UpdateModel ...
   634  func (s *Stmt) UpdateModel(model interface{}) (sql.Result, error) {
   635  	res, err := s.updateModel(model)
   636  	return res, WrapError(err)
   637  }
   638  
   639  // MustUpdateMap ...
   640  func (s *Stmt) MustUpdateMap(updates M) sql.Result {
   641  	res, err := s.updateMap(updates, false)
   642  	if err != nil {
   643  		panic(err)
   644  	}
   645  	return res
   646  }
   647  
   648  // MustUpdateMapAnyway ...
   649  func (s *Stmt) MustUpdateMapAnyway(updates M) sql.Result {
   650  	res, err := s.updateMap(updates, true)
   651  	if err != nil {
   652  		panic(err)
   653  	}
   654  	return res
   655  }
   656  
   657  // MustUpdateModel ...
   658  func (s *Stmt) MustUpdateModel(model interface{}) sql.Result {
   659  	res, err := s.updateModel(model)
   660  	if err != nil {
   661  		panic(err)
   662  	}
   663  	return res
   664  }
   665  
   666  // UpdateMapSQL ...
   667  func (s *Stmt) UpdateMapSQL(updates M) string {
   668  	query, args, err := s.buildUpdateMap(updates)
   669  	if err != nil {
   670  		panic(WrapError(err))
   671  	}
   672  	return strSQL(query, args...)
   673  }
   674  
   675  // UpdateModelSQL ...
   676  func (s *Stmt) UpdateModelSQL(model interface{}) string {
   677  	query, args, err := s.buildUpdateModel(model)
   678  	if err != nil {
   679  		panic(WrapError(err))
   680  	}
   681  	return strSQL(query, args...)
   682  }
   683  
   684  func (s *Stmt) _delete(anyway bool) error {
   685  	query, args, err := s.buildDelete()
   686  	if err != nil {
   687  		return err
   688  	}
   689  
   690  	if !anyway && s.noWheres() {
   691  		return ErrNoWhere
   692  	}
   693  
   694  	dumpSQL(query, args...)
   695  
   696  	_, err = s.db.Exec(query, args...)
   697  	if err != nil {
   698  		return err
   699  	}
   700  
   701  	return nil
   702  }
   703  
   704  // Delete ...
   705  func (s *Stmt) Delete() error {
   706  	return WrapError(s._delete(false))
   707  }
   708  
   709  // DeleteAnyway ...
   710  func (s *Stmt) DeleteAnyway() error {
   711  	return WrapError(s._delete(true))
   712  }
   713  
   714  // MustDelete ...
   715  func (s *Stmt) MustDelete() {
   716  	if err := s.Delete(); err != nil {
   717  		panic(err)
   718  	}
   719  }
   720  
   721  // MustDeleteAnyway ...
   722  func (s *Stmt) MustDeleteAnyway() {
   723  	if err := s.DeleteAnyway(); err != nil {
   724  		panic(err)
   725  	}
   726  }
   727  
   728  // DeleteSQL ...
   729  func (s *Stmt) DeleteSQL() string {
   730  	query, args, err := s.buildDelete()
   731  	if err != nil {
   732  		panic(WrapError(err))
   733  	}
   734  	return strSQL(query, args...)
   735  }