github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/exec.go (about)

     1  package sqx
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/bingoohuang/gg/pkg/mapstruct"
    13  	"github.com/bingoohuang/gg/pkg/sqlparse/sqlparser"
    14  )
    15  
    16  // QueryAsNumber executes a query which only returns number like count(*) sql.
    17  func (s SQL) QueryAsNumber(db SqxDB) (int64, error) {
    18  	str, err := s.QueryAsString(db)
    19  	if err != nil {
    20  		return 0, err
    21  	}
    22  
    23  	return strconv.ParseInt(str, 10, 64)
    24  }
    25  
    26  // QueryAsString executes a query which only returns number like count(*) sql.
    27  func (s SQL) QueryAsString(db SqxDB) (string, error) {
    28  	row, err := s.QueryAsRow(db)
    29  	if err != nil {
    30  		return "", err
    31  	}
    32  
    33  	if len(row) == 0 {
    34  		return "", nil
    35  	}
    36  
    37  	return row[0], nil
    38  }
    39  
    40  // Update executes an update/delete query and returns rows affected.
    41  func (s SQL) Update(db SqxDB) (int64, error) {
    42  	r, err := s.UpdateRaw(db)
    43  	if err != nil {
    44  		return 0, err
    45  	}
    46  
    47  	return r.RowsAffected()
    48  }
    49  
    50  func (s SQL) UpdateRaw(db SqxDB) (sql.Result, error) {
    51  	if dbTypeAware, ok := db.(DBTypeAware); ok {
    52  		dbType := dbTypeAware.GetDBType()
    53  		cr, err := dbType.Convert(s.Q, s.ConvertOptions...)
    54  		if err != nil {
    55  			return nil, err
    56  		}
    57  
    58  		s.Q, s.Vars = cr.PickArgs(s.Vars)
    59  	}
    60  
    61  	if !s.NoLog {
    62  		logQuery(s.Name, s.Q, s.Vars)
    63  	}
    64  
    65  	ctx, cancel := s.prepareContext()
    66  	defer cancel()
    67  
    68  	result, err := db.ExecContext(ctx, s.Q, s.Vars...)
    69  	logQueryError(s.NoLog, s.Name, result, err)
    70  	return result, err
    71  }
    72  
// RowScannerInit is an optional extension of RowScanner. If the scanner
// configured for QueryRaw implements it, InitRowScanner is called once with the
// result's column names (already lower-cased when LowerColumnNames is set)
// before any row is scanned.
type RowScannerInit interface {
	InitRowScanner(columns []string)
}

// RowScanner receives the rows produced by QueryRaw, one call per row.
// rowIndex is the 0-based index of the current row. Returning false stops the
// iteration; returning a non-nil error aborts QueryRaw with that error.
type RowScanner interface {
	ScanRow(columns []string, rows *sql.Rows, rowIndex int) (bool, error)
}
    80  
    81  type ScanRowFn func(columns []string, rows *sql.Rows, rowIndex int) (bool, error)
    82  
    83  func (s ScanRowFn) ScanRow(columns []string, rows *sql.Rows, rowIndex int) (bool, error) {
    84  	return s(columns, rows, rowIndex)
    85  }
    86  
// QueryOption defines the query options.
type QueryOption struct {
	// MaxRows caps how many rows QueryRaw scans; 0 means no limit.
	MaxRows int
	// TagNames lists the struct tag names handed to mapstruct when Query
	// decodes rows into the result value.
	TagNames []string
	// Scanner receives each result row in QueryRaw.
	Scanner RowScanner
	// LowerColumnNames lower-cases the result column names before scanning.
	LowerColumnNames bool

	// NOTE(review): ConvertOptionOptions is not read anywhere in this file;
	// presumably it is consumed elsewhere in the package - confirm.
	ConvertOptionOptions []sqlparser.ConvertOption
}

// QueryOptionFn defines the prototype of a function that sets QueryOption.
type QueryOptionFn func(o *QueryOption)

// QueryOptionFns is a slice of QueryOptionFn, applied in order by Options.
type QueryOptionFns []QueryOptionFn
   102  
   103  func (q QueryOptionFns) Options() *QueryOption {
   104  	o := &QueryOption{
   105  		TagNames: []string{"col", "db", "mapstruct", "field", "json", "yaml"},
   106  	}
   107  	for _, fn := range q {
   108  		fn(o)
   109  	}
   110  
   111  	return o
   112  }
   113  
   114  // WithMaxRows set the max rows of QueryOption.
   115  func WithMaxRows(maxRows int) QueryOptionFn {
   116  	return func(o *QueryOption) { o.MaxRows = maxRows }
   117  }
   118  
   119  // WithLowerColumnNames set the LowerColumnNames of QueryOption.
   120  func WithLowerColumnNames(v bool) QueryOptionFn {
   121  	return func(o *QueryOption) { o.LowerColumnNames = v }
   122  }
   123  
   124  // WithTagNames set the tagNames for mapping struct fields to query Columns.
   125  func WithTagNames(tagNames ...string) QueryOptionFn {
   126  	return func(o *QueryOption) { o.TagNames = tagNames }
   127  }
   128  
   129  // WithOptions apply the query option directly.
   130  func WithOptions(v *QueryOption) QueryOptionFn {
   131  	return func(o *QueryOption) { *o = *v }
   132  }
   133  
   134  // WithScanRow set row scanner for the query result.
   135  func WithScanRow(v ScanRowFn) QueryOptionFn {
   136  	return func(o *QueryOption) { o.Scanner = v }
   137  }
   138  
   139  // WithRowScanner set row scanner for the query result.
   140  func WithRowScanner(v RowScanner) QueryOptionFn {
   141  	return func(o *QueryOption) { o.Scanner = v }
   142  }
   143  
   144  // allowRowNum test the current rowNum is allowed for MaxRows control.
   145  func (o QueryOption) allowRowNum(rowNum int) bool {
   146  	return o.MaxRows == 0 || rowNum <= o.MaxRows
   147  }
   148  
   149  // Query queries return with result.
   150  func (s SQL) Query(db SqxDB, result interface{}, optionFns ...QueryOptionFn) error {
   151  	err := s.query(db, result, optionFns...)
   152  	if !s.NoLog {
   153  		logQueryError(true, s.Name, nil, err)
   154  		logRows(s.Name, GetQueryRows(result))
   155  	}
   156  	return err
   157  }
   158  
   159  func GetQueryRows(dest interface{}) int {
   160  	if dest == nil {
   161  		return 0
   162  	}
   163  
   164  	v := reflect.ValueOf(dest)
   165  	if v.Kind() == reflect.Ptr {
   166  		v = v.Elem()
   167  	}
   168  
   169  	switch v.Kind() {
   170  	case reflect.Slice, reflect.Array:
   171  		return v.Len()
   172  	default:
   173  		return 1
   174  	}
   175  }
   176  
   177  func (s SQL) query(db SqxDB, result interface{}, optionFns ...QueryOptionFn) error {
   178  	resultValue := reflect.ValueOf(result)
   179  	if resultValue.Kind() != reflect.Ptr {
   180  		return fmt.Errorf("result must be a pointer")
   181  	}
   182  
   183  	elem := resultValue.Elem()
   184  	elemKind := elem.Kind()
   185  	if elemKind == reflect.Ptr { // 如果依然是指针
   186  		typ := elem.Type().Elem() // 获取二级指针底层类型
   187  		val := reflect.New(typ)   // 创新底层类型对象
   188  		err := s.Query(db, val.Interface(), optionFns...)
   189  		if err == nil {
   190  			elem.Set(val) // 赋予一级指针新对象地址
   191  		}
   192  		return err
   193  	}
   194  
   195  	option := QueryOptionFns(optionFns).Options()
   196  
   197  	var err error
   198  	var input interface{}
   199  
   200  	options := WithOptions(option)
   201  	switch elemKind {
   202  	case reflect.Struct:
   203  		input, err = s.QueryAsMap(db, options)
   204  	case reflect.Slice:
   205  		sliceElemType := elem.Type().Elem()
   206  		switch sliceElemType.Kind() {
   207  		case reflect.Struct:
   208  			input, err = s.QueryAsMaps(db, options)
   209  		case reflect.String, reflect.Bool,
   210  			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   211  			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   212  			scanner := &Col1Scanner{}
   213  			err = s.QueryRaw(db, options, WithRowScanner(scanner))
   214  			input = scanner.Data
   215  		default:
   216  			return ErrNotSupported
   217  		}
   218  	case reflect.String, reflect.Bool,
   219  		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   220  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   221  		scanner := &Col1Scanner{MaxRows: 1}
   222  		err = s.QueryRaw(db, options, WithRowScanner(scanner))
   223  		if len(scanner.Data) > 0 {
   224  			input = scanner.Data[0]
   225  		}
   226  	default:
   227  		return ErrNotSupported
   228  	}
   229  
   230  	if err != nil {
   231  		return err
   232  	}
   233  
   234  	decoder, err := mapstruct.NewDecoder(&mapstruct.Config{
   235  		Result:   result,
   236  		TagNames: option.TagNames,
   237  		Squash:   true,
   238  		WeakType: true,
   239  	})
   240  	if err != nil {
   241  		return err
   242  	}
   243  
   244  	return decoder.Decode(input)
   245  }
   246  
// ErrNotSupported is returned by Query when result points to a kind of value
// it cannot fill (anything other than the kinds handled in query).
var ErrNotSupported = errors.New("sqx: Unsupported result type")
   248  
   249  type Col1Scanner struct {
   250  	Data    []string
   251  	MaxRows int
   252  }
   253  
   254  func (s *Col1Scanner) ScanRow(columns []string, rows *sql.Rows, _ int) (bool, error) {
   255  	if v, err := ScanSliceRow(rows, columns); err != nil {
   256  		return false, err
   257  	} else {
   258  		s.Data = append(s.Data, v[0])
   259  		return s.MaxRows == 0 || len(s.Data) < s.MaxRows, nil
   260  	}
   261  }
   262  
   263  type MapScanner struct {
   264  	Data    []map[string]string
   265  	MaxRows int
   266  }
   267  
   268  func (s *MapScanner) Data0() map[string]string {
   269  	if len(s.Data) == 0 {
   270  		return nil
   271  	}
   272  
   273  	return s.Data[0]
   274  }
   275  
   276  func (s *MapScanner) ScanRow(columns []string, rows *sql.Rows, _ int) (bool, error) {
   277  	if v, err := ScanMapRow(rows, columns); err != nil {
   278  		return false, err
   279  	} else {
   280  		s.Data = append(s.Data, v)
   281  		return s.MaxRows == 0 || len(s.Data) < s.MaxRows, nil
   282  	}
   283  }
   284  
   285  // QueryAsMaps query rows as map slice.
   286  func (s SQL) QueryAsMaps(db SqxDB, optionFns ...QueryOptionFn) ([]map[string]string, error) {
   287  	scanner := &MapScanner{Data: make([]map[string]string, 0)}
   288  	err := s.QueryRaw(db, append(optionFns, WithRowScanner(scanner))...)
   289  	return scanner.Data, err
   290  }
   291  
   292  // QueryAsMap query a single row as a map return.
   293  func (s SQL) QueryAsMap(db SqxDB, optionFns ...QueryOptionFn) (map[string]string, error) {
   294  	scanner := &MapScanner{Data: make([]map[string]string, 0), MaxRows: 1}
   295  	err := s.QueryRaw(db, append(optionFns, WithRowScanner(scanner))...)
   296  	return scanner.Data0(), err
   297  }
   298  
   299  func ScanSliceRow(rows *sql.Rows, columns []string) ([]string, error) {
   300  	holders, err := ScanRow(len(columns), rows)
   301  	if err != nil {
   302  		return nil, err
   303  	}
   304  
   305  	m := make([]string, len(columns))
   306  	for i, h := range holders {
   307  		m[i] = h.String()
   308  	}
   309  
   310  	return m, nil
   311  }
   312  
   313  func ScanMapRow(rows *sql.Rows, columns []string) (map[string]string, error) {
   314  	holders, err := ScanRow(len(columns), rows)
   315  	if err != nil {
   316  		return nil, err
   317  	}
   318  
   319  	m := make(map[string]string)
   320  	for i, h := range holders {
   321  		m[columns[i]] = h.String()
   322  	}
   323  
   324  	return m, nil
   325  }
   326  
   327  type StringRowScanner struct {
   328  	Data    [][]string
   329  	MaxRows int
   330  }
   331  
   332  func (r *StringRowScanner) ScanRow(columns []string, rows *sql.Rows, _ int) (bool, error) {
   333  	if m, err := ScanStringRow(rows, columns); err != nil {
   334  		return false, err
   335  	} else {
   336  		r.Data = append(r.Data, m)
   337  		return r.MaxRows == 0 || len(r.Data) < r.MaxRows, nil
   338  	}
   339  }
   340  
   341  func (r *StringRowScanner) Data0() []string {
   342  	if len(r.Data) == 0 {
   343  		return nil
   344  	}
   345  
   346  	return r.Data[0]
   347  }
   348  
   349  // QueryAsRow query a single row as a string slice return.
   350  func (s SQL) QueryAsRow(db SqxDB, optionFns ...QueryOptionFn) ([]string, error) {
   351  	f := &StringRowScanner{MaxRows: 1}
   352  	if err := s.QueryRaw(db, append(optionFns, WithRowScanner(f))...); err != nil {
   353  		return nil, err
   354  	}
   355  
   356  	return f.Data0(), nil
   357  }
   358  
   359  // QueryAsRows query rows as [][]string.
   360  func (s SQL) QueryAsRows(db SqxDB, optionFns ...QueryOptionFn) ([][]string, error) {
   361  	f := &StringRowScanner{}
   362  	if err := s.QueryRaw(db, append(optionFns, WithRowScanner(f))...); err != nil {
   363  		return nil, err
   364  	}
   365  
   366  	return f.Data, nil
   367  }
   368  
   369  func ScanStringRow(rows *sql.Rows, columns []string) ([]string, error) {
   370  	holders, err := ScanRow(len(columns), rows)
   371  	if err != nil {
   372  		return nil, err
   373  	}
   374  
   375  	m := make([]string, len(columns))
   376  	for i, h := range holders {
   377  		m[i] = h.String()
   378  	}
   379  	return m, nil
   380  }
   381  
   382  // QueryRaw query rows for customized row scanner.
   383  func (s SQL) QueryRaw(db SqxDB, optionFns ...QueryOptionFn) error {
   384  	option, r, columns, err := s.prepareQuery(db, optionFns...)
   385  	if err != nil {
   386  		return err
   387  	}
   388  
   389  	defer r.Close()
   390  
   391  	if initial, ok := option.Scanner.(RowScannerInit); ok {
   392  		initial.InitRowScanner(columns)
   393  	}
   394  
   395  	rows := 0
   396  	for rn := 0; r.Next() && option.allowRowNum(rn+1); rn++ {
   397  		rows++
   398  		if continued, err := option.Scanner.ScanRow(columns, r, rn); err != nil {
   399  			return err
   400  		} else if !continued {
   401  			break
   402  		}
   403  	}
   404  
   405  	if rows == 0 {
   406  		return sql.ErrNoRows
   407  	}
   408  
   409  	return nil
   410  }
   411  
   412  func ScanRowValues(rows *sql.Rows) ([]interface{}, error) {
   413  	cols, err := rows.Columns()
   414  	if err != nil {
   415  		return nil, err
   416  	}
   417  
   418  	row, err := ScanRow(len(cols), rows)
   419  	if err != nil {
   420  		return nil, err
   421  	}
   422  
   423  	rowValues := make([]interface{}, len(cols))
   424  	for i := range rowValues {
   425  		rowValues[i] = row[i].Get()
   426  	}
   427  
   428  	return rowValues, nil
   429  }
   430  
   431  func ScanRow(columnSize int, r *sql.Rows) ([]NullAny, error) {
   432  	holders := make([]NullAny, columnSize)
   433  	pointers := make([]interface{}, columnSize)
   434  	for i := 0; i < columnSize; i++ {
   435  		pointers[i] = &holders[i]
   436  	}
   437  
   438  	if err := r.Scan(pointers...); err != nil {
   439  		return nil, err
   440  	}
   441  
   442  	return holders, nil
   443  }
   444  
   445  func (s SQL) prepareContext() (ctx context.Context, cancel func()) {
   446  	ctx = s.Ctx
   447  	if ctx == nil {
   448  		ctx = context.Background()
   449  	}
   450  	if s.Timeout > 0 {
   451  		return context.WithTimeout(ctx, s.Timeout)
   452  	}
   453  
   454  	return ctx, func() {}
   455  }
   456  
   457  func (s *SQL) prepareQuery(db SqxDB, optionFns ...QueryOptionFn) (*QueryOption, *sql.Rows, []string, error) {
   458  	if err := s.adaptQuery(db); err != nil {
   459  		return nil, nil, nil, err
   460  	}
   461  
   462  	ctx, cancel := s.prepareContext()
   463  	defer cancel()
   464  	ctx = context.WithValue(ctx, AdaptedKey, s.adapted)
   465  	r, err := db.QueryContext(ctx, s.Q, s.Vars...)
   466  	if err != nil {
   467  		return nil, nil, nil, err
   468  	}
   469  
   470  	columns, err := r.Columns()
   471  	if err != nil {
   472  		return nil, nil, nil, err
   473  	}
   474  
   475  	option := QueryOptionFns(optionFns).Options()
   476  	if option.LowerColumnNames {
   477  		for i, col := range columns {
   478  			columns[i] = strings.ToLower(col)
   479  		}
   480  	}
   481  
   482  	return option, r, columns, nil
   483  }