github.com/kotovmak/go-admin@v1.1.1/modules/db/statement.go (about)

     1  // Copyright 2019 GoAdmin Core Team. All rights reserved.
     2  // Use of this source code is governed by a Apache-2.0 style
     3  // license that can be found in the LICENSE file.
     4  
     5  package db
     6  
     7  import (
     8  	dbsql "database/sql"
     9  	"errors"
    10  	"regexp"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  
    15  	"github.com/kotovmak/go-admin/modules/db/dialect"
    16  	"github.com/kotovmak/go-admin/modules/logger"
    17  )
    18  
    19  // SQL wraps the Connection and driver dialect methods.
    20  type SQL struct {
    21  	dialect.SQLComponent
    22  	diver   Connection
    23  	dialect dialect.Dialect
    24  	conn    string
    25  	tx      *dbsql.Tx
    26  }
    27  
    28  // SQLPool is a object pool of SQL.
    29  var SQLPool = sync.Pool{
    30  	New: func() interface{} {
    31  		return &SQL{
    32  			SQLComponent: dialect.SQLComponent{
    33  				Fields:     make([]string, 0),
    34  				TableName:  "",
    35  				Args:       make([]interface{}, 0),
    36  				Wheres:     make([]dialect.Where, 0),
    37  				Leftjoins:  make([]dialect.Join, 0),
    38  				UpdateRaws: make([]dialect.RawUpdate, 0),
    39  				WhereRaws:  "",
    40  				Order:      "",
    41  				Group:      "",
    42  				Limit:      "",
    43  			},
    44  			diver:   nil,
    45  			dialect: nil,
    46  		}
    47  	},
    48  }
    49  
    50  // H is a shorthand of map.
    51  type H map[string]interface{}
    52  
    53  // newSQL get a new SQL from SQLPool.
    54  func newSQL() *SQL {
    55  	return SQLPool.Get().(*SQL)
    56  }
    57  
    58  // *******************************
    59  // process method
    60  // *******************************
    61  
    62  // TableName return a SQL with given table and default connection.
    63  func Table(table string) *SQL {
    64  	sql := newSQL()
    65  	sql.TableName = table
    66  	sql.conn = "default"
    67  	return sql
    68  }
    69  
    70  // WithDriver return a SQL with given driver.
    71  func WithDriver(conn Connection) *SQL {
    72  	sql := newSQL()
    73  	sql.diver = conn
    74  	sql.dialect = dialect.GetDialectByDriver(conn.Name())
    75  	sql.conn = "default"
    76  	return sql
    77  }
    78  
    79  // WithDriverAndConnection return a SQL with given driver and connection name.
    80  func WithDriverAndConnection(connName string, conn Connection) *SQL {
    81  	sql := newSQL()
    82  	sql.diver = conn
    83  	sql.dialect = dialect.GetDialectByDriver(conn.Name())
    84  	sql.conn = connName
    85  	return sql
    86  }
    87  
    88  // WithDriver return a SQL with given driver.
    89  func (sql *SQL) WithDriver(conn Connection) *SQL {
    90  	sql.diver = conn
    91  	sql.dialect = dialect.GetDialectByDriver(conn.Name())
    92  	return sql
    93  }
    94  
    95  // WithConnection set the connection name of SQL.
    96  func (sql *SQL) WithConnection(conn string) *SQL {
    97  	sql.conn = conn
    98  	return sql
    99  }
   100  
   101  // WithTx set the database transaction object of SQL.
   102  func (sql *SQL) WithTx(tx *dbsql.Tx) *SQL {
   103  	sql.tx = tx
   104  	return sql
   105  }
   106  
   107  // TableName set table of SQL.
   108  func (sql *SQL) Table(table string) *SQL {
   109  	sql.clean()
   110  	sql.TableName = table
   111  	return sql
   112  }
   113  
   114  // Select set select fields.
   115  func (sql *SQL) Select(fields ...string) *SQL {
   116  	sql.Fields = fields
   117  	sql.Functions = make([]string, len(fields))
   118  	reg, _ := regexp.Compile(`(.*?)\((.*?)\)`)
   119  	for k, field := range fields {
   120  		res := reg.FindAllStringSubmatch(field, -1)
   121  		if len(res) > 0 && len(res[0]) > 2 {
   122  			sql.Functions[k] = res[0][1]
   123  			sql.Fields[k] = res[0][2]
   124  		}
   125  	}
   126  	return sql
   127  }
   128  
   129  // OrderBy set order fields.
   130  func (sql *SQL) OrderBy(fields ...string) *SQL {
   131  	if len(fields) == 0 {
   132  		panic("wrong order field")
   133  	}
   134  	for i := 0; i < len(fields); i++ {
   135  		if i == len(fields)-2 {
   136  			sql.Order += " " + sql.wrap(fields[i]) + " " + fields[i+1]
   137  			return sql
   138  		}
   139  		sql.Order += " " + sql.wrap(fields[i]) + " and "
   140  	}
   141  	return sql
   142  }
   143  
   144  // OrderByRaw set order by.
   145  func (sql *SQL) OrderByRaw(order string) *SQL {
   146  	if order != "" {
   147  		sql.Order += " " + order
   148  	}
   149  	return sql
   150  }
   151  
   152  func (sql *SQL) GroupBy(fields ...string) *SQL {
   153  	if len(fields) == 0 {
   154  		panic("wrong group by field")
   155  	}
   156  	for i := 0; i < len(fields); i++ {
   157  		if i == len(fields)-1 {
   158  			sql.Group += " " + sql.wrap(fields[i])
   159  		} else {
   160  			sql.Group += " " + sql.wrap(fields[i]) + ","
   161  		}
   162  	}
   163  	return sql
   164  }
   165  
   166  // GroupByRaw set group by.
   167  func (sql *SQL) GroupByRaw(group string) *SQL {
   168  	if group != "" {
   169  		sql.Group += " " + group
   170  	}
   171  	return sql
   172  }
   173  
   174  // Skip set offset value.
   175  func (sql *SQL) Skip(offset int) *SQL {
   176  	sql.Offset = strconv.Itoa(offset)
   177  	return sql
   178  }
   179  
   180  // Take set limit value.
   181  func (sql *SQL) Take(take int) *SQL {
   182  	sql.Limit = strconv.Itoa(take)
   183  	return sql
   184  }
   185  
   186  // Where add the where operation and argument value.
   187  func (sql *SQL) Where(field string, operation string, arg interface{}) *SQL {
   188  	sql.Wheres = append(sql.Wheres, dialect.Where{
   189  		Field:     field,
   190  		Operation: operation,
   191  		Qmark:     "?",
   192  	})
   193  	sql.Args = append(sql.Args, arg)
   194  	return sql
   195  }
   196  
   197  // WhereIn add the where operation of "in" and argument values.
   198  func (sql *SQL) WhereIn(field string, arg []interface{}) *SQL {
   199  	if len(arg) == 0 {
   200  		panic("wrong parameter")
   201  	}
   202  	sql.Wheres = append(sql.Wheres, dialect.Where{
   203  		Field:     field,
   204  		Operation: "in",
   205  		Qmark:     "(" + strings.Repeat("?,", len(arg)-1) + "?)",
   206  	})
   207  	sql.Args = append(sql.Args, arg...)
   208  	return sql
   209  }
   210  
   211  // WhereNotIn add the where operation of "not in" and argument values.
   212  func (sql *SQL) WhereNotIn(field string, arg []interface{}) *SQL {
   213  	if len(arg) == 0 {
   214  		panic("wrong parameter")
   215  	}
   216  	sql.Wheres = append(sql.Wheres, dialect.Where{
   217  		Field:     field,
   218  		Operation: "not in",
   219  		Qmark:     "(" + strings.Repeat("?,", len(arg)-1) + "?)",
   220  	})
   221  	sql.Args = append(sql.Args, arg...)
   222  	return sql
   223  }
   224  
   225  // Find query the sql result with given id assuming that primary key name is "id".
   226  func (sql *SQL) Find(arg interface{}) (map[string]interface{}, error) {
   227  	return sql.Where("id", "=", arg).First()
   228  }
   229  
   230  // Count query the count of query results.
   231  func (sql *SQL) Count() (int64, error) {
   232  	var (
   233  		res    map[string]interface{}
   234  		err    error
   235  		driver = sql.diver.Name()
   236  	)
   237  
   238  	if res, err = sql.Select("count(*)").First(); err != nil {
   239  		return 0, err
   240  	}
   241  
   242  	if driver == DriverPostgresql {
   243  		return res["count"].(int64), nil
   244  	} else if driver == DriverMssql {
   245  		return res[""].(int64), nil
   246  	}
   247  
   248  	return res["count(*)"].(int64), nil
   249  }
   250  
   251  // Sum sum the value of given field.
   252  func (sql *SQL) Sum(field string) (float64, error) {
   253  	var (
   254  		res map[string]interface{}
   255  		err error
   256  		key = "sum(" + sql.wrap(field) + ")"
   257  	)
   258  	if res, err = sql.Select("sum(" + field + ")").First(); err != nil {
   259  		return 0, err
   260  	}
   261  
   262  	if res == nil {
   263  		return 0, nil
   264  	}
   265  
   266  	if r, ok := res[key].(float64); ok {
   267  		return r, nil
   268  	} else if r, ok := res[key].([]uint8); ok {
   269  		return strconv.ParseFloat(string(r), 64)
   270  	} else {
   271  		return 0, nil
   272  	}
   273  }
   274  
   275  // Max find the maximal value of given field.
   276  func (sql *SQL) Max(field string) (interface{}, error) {
   277  	var (
   278  		res map[string]interface{}
   279  		err error
   280  		key = "max(" + sql.wrap(field) + ")"
   281  	)
   282  	if res, err = sql.Select("max(" + field + ")").First(); err != nil {
   283  		return 0, err
   284  	}
   285  
   286  	if res == nil {
   287  		return 0, nil
   288  	}
   289  
   290  	return res[key], nil
   291  }
   292  
   293  // Min find the minimal value of given field.
   294  func (sql *SQL) Min(field string) (interface{}, error) {
   295  	var (
   296  		res map[string]interface{}
   297  		err error
   298  		key = "min(" + sql.wrap(field) + ")"
   299  	)
   300  	if res, err = sql.Select("min(" + field + ")").First(); err != nil {
   301  		return 0, err
   302  	}
   303  
   304  	if res == nil {
   305  		return 0, nil
   306  	}
   307  
   308  	return res[key], nil
   309  }
   310  
   311  // Avg find the average value of given field.
   312  func (sql *SQL) Avg(field string) (interface{}, error) {
   313  	var (
   314  		res map[string]interface{}
   315  		err error
   316  		key = "avg(" + sql.wrap(field) + ")"
   317  	)
   318  	if res, err = sql.Select("avg(" + field + ")").First(); err != nil {
   319  		return 0, err
   320  	}
   321  
   322  	if res == nil {
   323  		return 0, nil
   324  	}
   325  
   326  	return res[key], nil
   327  }
   328  
   329  // WhereRaw set WhereRaws and arguments.
   330  func (sql *SQL) WhereRaw(raw string, args ...interface{}) *SQL {
   331  	sql.WhereRaws = raw
   332  	sql.Args = append(sql.Args, args...)
   333  	return sql
   334  }
   335  
   336  // UpdateRaw set UpdateRaw.
   337  func (sql *SQL) UpdateRaw(raw string, args ...interface{}) *SQL {
   338  	sql.UpdateRaws = append(sql.UpdateRaws, dialect.RawUpdate{
   339  		Expression: raw,
   340  		Args:       args,
   341  	})
   342  	return sql
   343  }
   344  
   345  // LeftJoin add a left join info.
   346  func (sql *SQL) LeftJoin(table string, fieldA string, operation string, fieldB string) *SQL {
   347  	sql.Leftjoins = append(sql.Leftjoins, dialect.Join{
   348  		FieldA:    fieldA,
   349  		FieldB:    fieldB,
   350  		Table:     table,
   351  		Operation: operation,
   352  	})
   353  	return sql
   354  }
   355  
   356  // *******************************
   357  // Transaction method
   358  // *******************************
   359  
   360  // TxFn is the transaction callback function.
   361  type TxFn func(tx *dbsql.Tx) (error, map[string]interface{})
   362  
   363  // WithTransaction call the callback function within the transaction and
   364  // catch the error.
   365  func (sql *SQL) WithTransaction(fn TxFn) (res map[string]interface{}, err error) {
   366  
   367  	tx := sql.diver.BeginTxAndConnection(sql.conn)
   368  
   369  	defer func() {
   370  		if p := recover(); p != nil {
   371  			// a panic occurred, rollback and repanic
   372  			_ = tx.Rollback()
   373  			panic(p)
   374  		} else if err != nil {
   375  			// something went wrong, rollback
   376  			_ = tx.Rollback()
   377  		} else {
   378  			// all good, commit
   379  			err = tx.Commit()
   380  		}
   381  	}()
   382  
   383  	err, res = fn(tx)
   384  	return
   385  }
   386  
   387  // WithTransactionByLevel call the callback function within the transaction
   388  // of given transaction level and catch the error.
   389  func (sql *SQL) WithTransactionByLevel(level dbsql.IsolationLevel, fn TxFn) (res map[string]interface{}, err error) {
   390  
   391  	tx := sql.diver.BeginTxWithLevelAndConnection(sql.conn, level)
   392  
   393  	defer func() {
   394  		if p := recover(); p != nil {
   395  			// a panic occurred, rollback and repanic
   396  			_ = tx.Rollback()
   397  			panic(p)
   398  		} else if err != nil {
   399  			// something went wrong, rollback
   400  			_ = tx.Rollback()
   401  		} else {
   402  			// all good, commit
   403  			err = tx.Commit()
   404  		}
   405  	}()
   406  
   407  	err, res = fn(tx)
   408  	return
   409  }
   410  
   411  // *******************************
   412  // terminal method
   413  // -------------------------------
   414  // sql args order:
   415  // update ... => where ...
   416  // *******************************
   417  
   418  // First query the result and return the first row.
   419  func (sql *SQL) First() (map[string]interface{}, error) {
   420  	defer RecycleSQL(sql)
   421  
   422  	sql.dialect.Select(&sql.SQLComponent)
   423  
   424  	res, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...)
   425  
   426  	if err != nil {
   427  		return nil, err
   428  	}
   429  
   430  	if len(res) < 1 {
   431  		return nil, errors.New("out of index")
   432  	}
   433  	return res[0], nil
   434  }
   435  
   436  // All query all the result and return.
   437  func (sql *SQL) All() ([]map[string]interface{}, error) {
   438  	defer RecycleSQL(sql)
   439  
   440  	sql.dialect.Select(&sql.SQLComponent)
   441  
   442  	return sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...)
   443  }
   444  
   445  // ShowColumns show columns info.
   446  func (sql *SQL) ShowColumns() ([]map[string]interface{}, error) {
   447  	defer RecycleSQL(sql)
   448  
   449  	return sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowColumns(sql.TableName))
   450  }
   451  
   452  // ShowTables show table info.
   453  func (sql *SQL) ShowTables() ([]string, error) {
   454  	defer RecycleSQL(sql)
   455  
   456  	models, err := sql.diver.QueryWithConnection(sql.conn, sql.dialect.ShowTables())
   457  
   458  	if err != nil {
   459  		return []string{}, err
   460  	}
   461  
   462  	tables := make([]string, 0)
   463  	if len(models) == 0 {
   464  		return tables, nil
   465  	}
   466  
   467  	key := "Tables_in_" + sql.TableName
   468  	if sql.diver.Name() == DriverPostgresql || sql.diver.Name() == DriverSqlite {
   469  		key = "tablename"
   470  	} else if sql.diver.Name() == DriverMssql {
   471  		key = "TABLE_NAME"
   472  	} else if _, ok := models[0][key].(string); !ok {
   473  		key = "Tables_in_" + strings.ToLower(sql.TableName)
   474  	}
   475  
   476  	for i := 0; i < len(models); i++ {
   477  		// skip sqlite system tables
   478  		if sql.diver.Name() == DriverSqlite && models[i][key].(string) == "sqlite_sequence" {
   479  			continue
   480  		}
   481  
   482  		tables = append(tables, models[i][key].(string))
   483  	}
   484  
   485  	return tables, nil
   486  }
   487  
   488  // Update exec the update method of given key/value pairs.
   489  func (sql *SQL) Update(values dialect.H) (int64, error) {
   490  	defer RecycleSQL(sql)
   491  
   492  	sql.Values = values
   493  
   494  	sql.dialect.Update(&sql.SQLComponent)
   495  
   496  	res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...)
   497  
   498  	if err != nil {
   499  		return 0, err
   500  	}
   501  
   502  	if affectRow, _ := res.RowsAffected(); affectRow < 1 {
   503  		return 0, errors.New("no affect row")
   504  	}
   505  
   506  	return res.LastInsertId()
   507  }
   508  
   509  // Delete exec the delete method.
   510  func (sql *SQL) Delete() error {
   511  	defer RecycleSQL(sql)
   512  
   513  	sql.dialect.Delete(&sql.SQLComponent)
   514  
   515  	res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...)
   516  
   517  	if err != nil {
   518  		return err
   519  	}
   520  
   521  	if affectRow, _ := res.RowsAffected(); affectRow < 1 {
   522  		return errors.New("no affect row")
   523  	}
   524  
   525  	return nil
   526  }
   527  
   528  // Exec exec the exec method.
   529  func (sql *SQL) Exec() (int64, error) {
   530  	defer RecycleSQL(sql)
   531  
   532  	sql.dialect.Update(&sql.SQLComponent)
   533  
   534  	res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...)
   535  
   536  	if err != nil {
   537  		return 0, err
   538  	}
   539  
   540  	if affectRow, _ := res.RowsAffected(); affectRow < 1 {
   541  		return 0, errors.New("no affect row")
   542  	}
   543  
   544  	return res.LastInsertId()
   545  }
   546  
   547  const postgresInsertCheckTableName = "goadmin_menu|goadmin_permissions|goadmin_roles|goadmin_users"
   548  
   549  // Insert exec the insert method of given key/value pairs.
   550  func (sql *SQL) Insert(values dialect.H) (int64, error) {
   551  	defer RecycleSQL(sql)
   552  
   553  	sql.Values = values
   554  
   555  	sql.dialect.Insert(&sql.SQLComponent)
   556  
   557  	if sql.diver.Name() == DriverPostgresql && (strings.Contains(postgresInsertCheckTableName, sql.TableName)) {
   558  
   559  		resMap, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement+" RETURNING id", sql.Args...)
   560  
   561  		if err != nil {
   562  
   563  			// Fixed java h2 database postgresql mode
   564  			_, err := sql.diver.QueryWith(sql.tx, sql.conn, sql.Statement, sql.Args...)
   565  
   566  			if err != nil {
   567  				return 0, err
   568  			}
   569  
   570  			res, err := sql.diver.QueryWithConnection(sql.conn, `SELECT max("id") as "id" FROM "`+sql.TableName+`"`)
   571  
   572  			if err != nil {
   573  				return 0, err
   574  			}
   575  
   576  			if len(res) != 0 {
   577  				return res[0]["id"].(int64), nil
   578  			}
   579  
   580  			return 0, err
   581  		}
   582  
   583  		if len(resMap) == 0 {
   584  			return 0, errors.New("no affect row")
   585  		}
   586  
   587  		return resMap[0]["id"].(int64), nil
   588  	}
   589  
   590  	res, err := sql.diver.ExecWith(sql.tx, sql.conn, sql.Statement, sql.Args...)
   591  
   592  	if err != nil {
   593  		return 0, err
   594  	}
   595  
   596  	if affectRow, _ := res.RowsAffected(); affectRow < 1 {
   597  		return 0, errors.New("no affect row")
   598  	}
   599  
   600  	return res.LastInsertId()
   601  }
   602  
   603  func (sql *SQL) wrap(field string) string {
   604  	return sql.diver.GetDelimiter() + field + sql.diver.GetDelimiter2()
   605  }
   606  
   607  func (sql *SQL) clean() {
   608  	sql.Functions = make([]string, 0)
   609  	sql.Group = ""
   610  	sql.Values = make(map[string]interface{})
   611  	sql.Fields = make([]string, 0)
   612  	sql.TableName = ""
   613  	sql.Wheres = make([]dialect.Where, 0)
   614  	sql.Leftjoins = make([]dialect.Join, 0)
   615  	sql.Args = make([]interface{}, 0)
   616  	sql.Order = ""
   617  	sql.Offset = ""
   618  	sql.Limit = ""
   619  	sql.WhereRaws = ""
   620  	sql.UpdateRaws = make([]dialect.RawUpdate, 0)
   621  	sql.Statement = ""
   622  }
   623  
   624  // RecycleSQL clear the SQL and put into the pool.
   625  func RecycleSQL(sql *SQL) {
   626  
   627  	logger.LogSQL(sql.Statement, sql.Args)
   628  
   629  	sql.clean()
   630  
   631  	sql.conn = ""
   632  	sql.diver = nil
   633  	sql.tx = nil
   634  	sql.dialect = nil
   635  
   636  	SQLPool.Put(sql)
   637  }