github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/dao.go (about)

     1  package sqx
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	"log"
     8  	"reflect"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/bingoohuang/gg/pkg/mathx"
    13  	"github.com/bingoohuang/gg/pkg/reflector"
    14  	"github.com/bingoohuang/gg/pkg/sqlparse/sqlparser"
    15  	"github.com/bingoohuang/gg/pkg/ss"
    16  	"github.com/bingoohuang/gg/pkg/strcase"
    17  )
    18  
    19  type Limit struct {
    20  	Offset int64
    21  	Length int64
    22  }
    23  
    24  type Count int64
    25  
    26  var (
    27  	LimitType = reflect.TypeOf((*Limit)(nil)).Elem()
    28  	CountType = reflect.TypeOf((*Count)(nil)).Elem()
    29  )
    30  
    31  // DBGetter is the interface to get a sql.DBGetter.
    32  type DBGetter interface{ GetDB() *sql.DB }
    33  
    34  // StdDB is the wrapper for sql.DBGetter.
    35  type StdDB struct{ db *sql.DB }
    36  
    37  // GetDB returns a sql.DBGetter.
    38  func (f StdDB) GetDB() *sql.DB { return f.db }
    39  
    40  // DB is the global sql.DB for convenience.
    41  var DB *sql.DB
    42  
    43  // CreateDao fulfils the dao (should be pointer).
    44  func CreateDao(dao interface{}, createDaoOpts ...CreateDaoOpter) error {
    45  	daov := reflect.ValueOf(dao)
    46  	if daov.Kind() != reflect.Ptr || daov.Elem().Kind() != reflect.Struct {
    47  		return fmt.Errorf("dao should be pointer to struct") // nolint:goerr113
    48  	}
    49  
    50  	option, err := applyCreateDaoOption(createDaoOpts)
    51  	if err != nil {
    52  		return err
    53  	}
    54  
    55  	v := reflect.Indirect(daov)
    56  	createDBGetter(v, option)
    57  	createLogger(v, option)
    58  	createErrorSetter(v, option)
    59  
    60  	structValue := MakeStructValue(v)
    61  	for i := 0; i < structValue.NumField; i++ {
    62  		f := structValue.FieldByIndex(i)
    63  
    64  		if f.PkgPath != "" /* not exportable */ || f.Kind != reflect.Func {
    65  			continue
    66  		}
    67  
    68  		tags, err := reflector.ParseTags(string(f.Tag))
    69  		if err != nil {
    70  			return err
    71  		}
    72  
    73  		sqlStmt, sqlName := option.getSQLStmt(f, tags, 0)
    74  		if sqlStmt == nil {
    75  			return fmt.Errorf("failed to find sqlName %s", f.Name) // nolint:goerr113
    76  		}
    77  
    78  		parsed := &SQLParsed{
    79  			ID:  sqlName,
    80  			SQL: sqlStmt,
    81  			opt: option,
    82  		}
    83  
    84  		if err := parsed.fastParseSQL(sqlStmt.Raw()); err != nil {
    85  			return err
    86  		}
    87  
    88  		r := sqlRun{SQLParsed: parsed}
    89  		if err := r.createFn(f); err != nil {
    90  			return err
    91  		}
    92  	}
    93  
    94  	return nil
    95  }
    96  
    97  func (option *CreateDaoOpt) getSQLStmt(field StructField, tags reflector.Tags, stack int) (SQLPart, string) {
    98  	if stack > 10 {
    99  		return nil, ""
   100  	}
   101  
   102  	if sqlStmt := field.GetTag("sql"); sqlStmt != "" {
   103  		dsi := DotItem{
   104  			Name:    field.Name,
   105  			Content: []string{sqlStmt},
   106  			Attrs:   tags.Map(),
   107  		}
   108  		part, err := dsi.DynamicSQL()
   109  		if err != nil {
   110  			option.Logger.LogError(err)
   111  		}
   112  
   113  		return part, field.Name
   114  	}
   115  
   116  	sqlName := field.GetTagOr("sqlName", field.Name)
   117  	if part, err := option.DotSQL(sqlName); err != nil {
   118  		option.Logger.LogError(err)
   119  	} else if part != nil {
   120  		return part, sqlName
   121  	}
   122  
   123  	if sqlName == field.Name {
   124  		return nil, ""
   125  	}
   126  
   127  	if field, ok := field.Parent.FieldByName(sqlName); ok {
   128  		return option.getSQLStmt(field, nil, stack+1)
   129  	}
   130  
   131  	return nil, sqlName
   132  }
   133  
   134  func (r *sqlRun) createFn(f StructField) error {
   135  	numIn := f.Type.NumIn()
   136  	numOut := f.Type.NumOut()
   137  
   138  	lastOutError := numOut > 0 && reflector.IsError(f.Type.Out(numOut-1))
   139  	if lastOutError {
   140  		numOut--
   141  	}
   142  
   143  	fn := r.MakeFunc(f, numIn, numOut)
   144  	if fn == nil {
   145  		err := fmt.Errorf("unsupportd func %s %v", f.Name, f.Type) // nolint:goerr113
   146  		r.logError(err)
   147  
   148  		return err
   149  	}
   150  
   151  	f.Field.Set(reflect.MakeFunc(f.Type, func(args []reflect.Value) []reflect.Value {
   152  		r.opt.ErrSetter(nil)
   153  		values, err := fn(args)
   154  		if err != nil {
   155  			r.opt.ErrSetter(err)
   156  			r.logError(err)
   157  
   158  			values = make([]reflect.Value, numOut, numOut+1)
   159  			for i := 0; i < numOut; i++ {
   160  				values[i] = reflect.Zero(f.Type.Out(i))
   161  			}
   162  		}
   163  
   164  		if lastOutError {
   165  			if err != nil {
   166  				values = append(values, reflect.ValueOf(err))
   167  			} else {
   168  				values = append(values, reflect.Zero(reflector.ErrType))
   169  			}
   170  		}
   171  
   172  		return values
   173  	}))
   174  
   175  	return nil
   176  }
   177  
   178  func (r *sqlRun) MakeFunc(f StructField, numIn, numOut int) func([]reflect.Value) ([]reflect.Value, error) {
   179  	fn := r.getExecFn()
   180  	return func(args []reflect.Value) ([]reflect.Value, error) {
   181  		return fn(numIn, f, makeOutTypes(f.Type, numOut), args)
   182  	}
   183  }
   184  
   185  func (r *sqlRun) getExecFn() func(int, StructField, []reflect.Type, []reflect.Value) ([]reflect.Value, error) {
   186  	switch isBindByName := r.isBindBy(ByName); {
   187  	case !r.IsQuery && isBindByName:
   188  		return r.execByName
   189  	case !r.IsQuery && !isBindByName:
   190  		return r.execBySeq
   191  	case r.IsQuery && isBindByName:
   192  		return r.queryByName
   193  	default: // isQuery && !isBindByName:
   194  		return r.queryBySeq
   195  	}
   196  }
   197  
   198  func makeOutTypes(outType reflect.Type, numOut int) []reflect.Type {
   199  	rt := make([]reflect.Type, numOut)
   200  	for i := 0; i < numOut; i++ {
   201  		rt[i] = outType.Out(i)
   202  	}
   203  
   204  	return rt
   205  }
   206  
   207  type sqlRun struct {
   208  	*SQLParsed
   209  }
   210  
   211  func (p *SQLParsed) evalSeq(numIn int, f StructField, args []reflect.Value) error {
   212  	env := make(map[string]interface{})
   213  	for i, arg := range args {
   214  		env[fmt.Sprintf("_%d", i+1)] = arg.Interface()
   215  	}
   216  
   217  	if len(args) > 0 {
   218  		env = p.createFieldSqlParts(env, args[0])
   219  	}
   220  
   221  	return p.eval(numIn, f, env)
   222  }
   223  
   224  func (p *SQLParsed) eval(numIn int, f StructField, env map[string]interface{}) error {
   225  	runSQL, err := p.SQL.Eval(env)
   226  	if err != nil {
   227  		return err
   228  	}
   229  
   230  	if err := p.parseSQL(runSQL); err != nil {
   231  		return err
   232  	}
   233  
   234  	if err := p.checkFuncInOut(numIn, f); err != nil {
   235  		return err
   236  	}
   237  
   238  	return nil
   239  }
   240  
   241  func (r *sqlRun) queryByName(numIn int, f StructField,
   242  	outTypes []reflect.Type, args []reflect.Value,
   243  ) ([]reflect.Value, error) {
   244  	var bean reflect.Value
   245  
   246  	if numIn > 0 {
   247  		bean = args[0]
   248  	}
   249  
   250  	parsed := *r.SQLParsed
   251  	env := parsed.createNamedMap(bean)
   252  
   253  	if err := parsed.eval(numIn, f, env); err != nil {
   254  		return nil, err
   255  	}
   256  
   257  	vars, err := parsed.createNamedVars(bean)
   258  	if err != nil {
   259  		return nil, err
   260  	}
   261  
   262  	counterIndex := indexOfTypes(outTypes, CountType)
   263  	db := r.opt.DBGetter.GetDB()
   264  	rows, counter, err := parsed.doQueryDirectVars(db, vars, counterIndex >= 0)
   265  	if err != nil {
   266  		return nil, err
   267  	}
   268  
   269  	return parsed.wrapCounter(rows, outTypes, counterIndex, counter)
   270  }
   271  
   272  func (p *SQLParsed) wrapCounter(rows *sql.Rows, outTypes []reflect.Type, counterIndex int, counterFn func() (int64, error)) ([]reflect.Value, error) {
   273  	values, err := p.processQueryRows(rows, remove(outTypes, counterIndex))
   274  	_ = rows.Close()
   275  	if err != nil || counterFn == nil {
   276  		return values, err
   277  	}
   278  
   279  	counter, err := counterFn()
   280  	if err != nil {
   281  		return values, err
   282  	}
   283  
   284  	return insert(values, counterIndex, reflect.ValueOf(Count(counter))), nil
   285  }
   286  
   287  func remove(slice []reflect.Type, s int) []reflect.Type {
   288  	if s < 0 {
   289  		return slice
   290  	}
   291  
   292  	return append(slice[:s], slice[s+1:]...)
   293  }
   294  
   295  func insert(a []reflect.Value, index int, value reflect.Value) []reflect.Value {
   296  	if len(a) == index { // nil or empty slice or after last element
   297  		return append(a, value)
   298  	}
   299  
   300  	a = append(a[:index+1], a[index:]...) // index < len(a)
   301  	a[index] = value
   302  	return a
   303  }
   304  
   305  func indexOfTypes(types []reflect.Type, typ reflect.Type) int {
   306  	for i, t := range types {
   307  		if t == typ {
   308  			return i
   309  		}
   310  	}
   311  
   312  	return -1
   313  }
   314  
   315  func (r *sqlRun) execByName(numIn int, f StructField, outTypes []reflect.Type, args []reflect.Value) ([]reflect.Value, error) {
   316  	var bean reflect.Value
   317  
   318  	if numIn > 0 {
   319  		bean = args[0]
   320  	}
   321  
   322  	item0 := bean
   323  	itemSize := 1
   324  	isBeanSlice := bean.IsValid() && bean.Type().Kind() == reflect.Slice
   325  
   326  	if isBeanSlice {
   327  		if bean.IsNil() || bean.Len() == 0 {
   328  			return []reflect.Value{}, nil
   329  		}
   330  
   331  		item0 = bean.Index(0)
   332  		itemSize = bean.Len()
   333  	}
   334  
   335  	var (
   336  		err        error
   337  		pr         *sql.Stmt
   338  		lastResult sql.Result
   339  		lastSQL    string
   340  	)
   341  
   342  	parsed := *r.SQLParsed
   343  	db := r.opt.DBGetter.GetDB()
   344  	tx, err := db.BeginTx(parsed.opt.Ctx, nil)
   345  	if err != nil {
   346  		return nil, fmt.Errorf("failed to begin tx %w", err)
   347  	}
   348  
   349  	for ii := 0; ii < itemSize; ii++ {
   350  		if ii > 0 {
   351  			item0 = bean.Index(ii)
   352  		}
   353  
   354  		namedMap := parsed.createNamedMap(item0)
   355  		if err := parsed.eval(numIn, f, namedMap); err != nil {
   356  			return nil, err
   357  		}
   358  		vars, err := parsed.createNamedVars(item0)
   359  		if err != nil {
   360  			return nil, err
   361  		}
   362  
   363  		if lastSQL != parsed.runSQL {
   364  			lastSQL = parsed.runSQL
   365  
   366  			query, err := r.replaceQuery(db, parsed.runSQL)
   367  			if err != nil {
   368  				return nil, fmt.Errorf("replaceQuery %s error %w", parsed.runSQL, err)
   369  			}
   370  
   371  			log.Printf("exec %s [%s] with %v", parsed.ID, query, vars)
   372  			if pr, err = tx.PrepareContext(parsed.opt.Ctx, query); err != nil {
   373  				return nil, fmt.Errorf("failed to prepare sql [%s] error %w", r.RawStmt, err)
   374  			}
   375  		}
   376  
   377  		lastResult, err = pr.ExecContext(parsed.opt.Ctx, vars...)
   378  		if err != nil {
   379  			return nil, fmt.Errorf("failed to execute %s with vars %v error %w", parsed.runSQL, vars, err)
   380  		}
   381  
   382  		LogSqlResult(lastResult)
   383  	}
   384  
   385  	if err := tx.Commit(); err != nil {
   386  		return nil, fmt.Errorf("failed to commiterror %w", err)
   387  	}
   388  
   389  	return convertExecResult(lastResult, lastSQL, outTypes)
   390  }
   391  
   392  func LogSqlResult(lastResult sql.Result) {
   393  	lastInsertId, _ := lastResult.LastInsertId()
   394  	rowsAffected, _ := lastResult.RowsAffected()
   395  	log.Printf("Result lastInsertId: %d, rowsAffected: %d", lastInsertId, rowsAffected)
   396  }
   397  
   398  func (p *SQLParsed) createFieldSqlParts(m map[string]interface{}, bean reflect.Value) map[string]interface{} {
   399  	if !bean.IsValid() || bean.Type().Kind() != reflect.Struct {
   400  		return m
   401  	}
   402  
   403  	structValue := MakeStructValue(bean)
   404  	for i, f := range structValue.FieldTypes {
   405  		if sqlPart := f.Tag.Get("sql"); sqlPart != "" {
   406  			if bean.Field(i).IsZero() {
   407  				continue
   408  			}
   409  
   410  			if f.Type.AssignableTo(LimitType) {
   411  				l := bean.Field(i).Interface().(Limit)
   412  				p.fp.AddFieldSqlPart(sqlPart, []interface{}{l.Offset, l.Length}, false)
   413  			} else {
   414  				p.fp.AddFieldSqlPart(sqlPart, []interface{}{bean.Field(i).Interface()}, true)
   415  			}
   416  		}
   417  	}
   418  
   419  	return m
   420  }
   421  
   422  func (p *SQLParsed) createNamedMap(bean reflect.Value) map[string]interface{} {
   423  	m := make(map[string]interface{})
   424  	if !bean.IsValid() {
   425  		return m
   426  	}
   427  
   428  	switch bean.Type().Kind() {
   429  	case reflect.Struct:
   430  		structValue := MakeStructValue(bean)
   431  		for i, f := range structValue.FieldTypes {
   432  			if tagName := f.Tag.Get("name"); tagName != "" {
   433  				m[tagName] = bean.Field(i).Interface()
   434  			} else {
   435  				name := strcase.ToCamelLower(f.Name)
   436  				m[name] = bean.Field(i).Interface()
   437  			}
   438  		}
   439  	case reflect.Map:
   440  		for _, k := range bean.MapKeys() {
   441  			if ks, ok := k.Interface().(string); ok {
   442  				m[ks] = bean.MapIndex(k).Interface()
   443  			}
   444  		}
   445  	}
   446  
   447  	return m
   448  }
   449  
   450  func (p *SQLParsed) createNamedVars(bean reflect.Value) ([]interface{}, error) {
   451  	itemType := bean.Type()
   452  
   453  	var namedValueParser func(name string, item reflect.Value, itemType reflect.Type) interface{}
   454  
   455  	switch itemType.Kind() {
   456  	case reflect.Struct:
   457  		namedValueParser = func(name string, item reflect.Value, itemType reflect.Type) interface{} {
   458  			return item.FieldByNameFunc(func(f string) bool {
   459  				return matchesField2Col(itemType, f, name)
   460  			}).Interface()
   461  		}
   462  	case reflect.Map:
   463  		namedValueParser = func(name string, item reflect.Value, itemType reflect.Type) interface{} {
   464  			return item.MapIndex(reflect.ValueOf(name)).Interface()
   465  		}
   466  	}
   467  
   468  	if namedValueParser == nil {
   469  		// nolint:goerr113
   470  		return nil, fmt.Errorf("named vars should use struct/map, unsupported type %v", itemType)
   471  	}
   472  
   473  	vars := make([]interface{}, len(p.Vars))
   474  
   475  	for i, name := range p.Vars {
   476  		vars[i] = namedValueParser(name, bean, itemType)
   477  	}
   478  
   479  	return vars, nil
   480  }
   481  
   482  func (r *sqlRun) execBySeq(numIn int, f StructField,
   483  	outTypes []reflect.Type, args []reflect.Value,
   484  ) ([]reflect.Value, error) {
   485  	parsed := *r.SQLParsed
   486  
   487  	if err := parsed.evalSeq(numIn, f, args); err != nil {
   488  		return nil, err
   489  	}
   490  
   491  	vars := parsed.makeVars(args)
   492  	db := r.opt.DBGetter.GetDB()
   493  	query, err := r.replaceQuery(db, parsed.runSQL)
   494  	if err != nil {
   495  		return nil, fmt.Errorf("replaceQuery %s error %w", parsed.runSQL, err)
   496  	}
   497  
   498  	log.Printf("exec query %s [%s] with %v", r.ID, query, vars)
   499  
   500  	result, err := db.ExecContext(parsed.opt.Ctx, query, vars...)
   501  	if err != nil {
   502  		return nil, fmt.Errorf("execute %s error %w", r.SQL, err)
   503  	}
   504  
   505  	LogSqlResult(result)
   506  
   507  	results, err := convertExecResult(result, query, outTypes)
   508  	if err != nil {
   509  		return nil, fmt.Errorf("execute %s error %w", r.SQL, err)
   510  	}
   511  
   512  	return results, nil
   513  }
   514  
   515  func (r *sqlRun) queryBySeq(numIn int, f StructField,
   516  	outTypes []reflect.Type, args []reflect.Value,
   517  ) ([]reflect.Value, error) {
   518  	parsed := *r.SQLParsed
   519  	if err := parsed.evalSeq(numIn, f, args); err != nil {
   520  		return nil, err
   521  	}
   522  
   523  	db := r.opt.DBGetter.GetDB()
   524  	counterIndex := indexOfTypes(outTypes, CountType)
   525  
   526  	rows, counterFn, err := parsed.doQuery(db, args, counterIndex >= 0)
   527  	if err != nil {
   528  		return nil, err
   529  	}
   530  
   531  	defer rows.Close()
   532  
   533  	return parsed.wrapCounter(rows, outTypes, counterIndex, counterFn)
   534  }
   535  
   536  func (p *SQLParsed) processQueryRows(rows *sql.Rows, outTypes []reflect.Type) ([]reflect.Value, error) {
   537  	columns, err := rows.Columns()
   538  	if err != nil {
   539  		return nil, fmt.Errorf("get columns %s error %w", p.SQL, err)
   540  	}
   541  
   542  	out0Type := outTypes[0]
   543  	outSlice := reflect.Value{}
   544  	out0TypePtr := out0Type.Kind() == reflect.Ptr
   545  
   546  	switch out0Type.Kind() {
   547  	case reflect.Slice:
   548  		outSlice = reflect.MakeSlice(out0Type, 0, 0)
   549  		out0Type = out0Type.Elem()
   550  	case reflect.Ptr:
   551  		out0Type = out0Type.Elem()
   552  	}
   553  
   554  	interceptorFn := p.getRowScanInterceptorFn()
   555  	mapFields, err := p.createMapFields(columns, out0Type, outTypes)
   556  	if err != nil {
   557  		return nil, err
   558  	}
   559  	ri := 0
   560  
   561  	defer func() {
   562  		log.Printf("query got %d rows", ri)
   563  	}()
   564  
   565  	for ; rows.Next() && (p.opt.QueryMaxRows <= 0 || ri < p.opt.QueryMaxRows); ri++ {
   566  		pointers, out := resetDests(out0Type, out0TypePtr, outTypes, mapFields)
   567  		if err := rows.Scan(pointers[:len(columns)]...); err != nil {
   568  			return nil, fmt.Errorf("scan rows %s error %w", p.SQL, err)
   569  		}
   570  
   571  		fillFields(mapFields, pointers)
   572  
   573  		if interceptorFn != nil {
   574  			outValues := make([]interface{}, len(out))
   575  			for i, outVal := range out {
   576  				outValues[i] = outVal.Interface()
   577  			}
   578  
   579  			if goon, err := interceptorFn(ri, outValues...); err != nil {
   580  				return nil, err
   581  			} else if !goon {
   582  				break
   583  			}
   584  		}
   585  
   586  		if !outSlice.IsValid() {
   587  			return out[:len(outTypes)], nil
   588  		}
   589  
   590  		outSlice = reflect.Append(outSlice, out[0])
   591  	}
   592  
   593  	if outSlice.IsValid() {
   594  		return []reflect.Value{outSlice}, nil
   595  	}
   596  
   597  	return noRows(out0Type, out0TypePtr, outTypes)
   598  }
   599  
   600  func noRows(out0Type reflect.Type, out0TypePtr bool, outTypes []reflect.Type) ([]reflect.Value, error) {
   601  	switch out0Type.Kind() {
   602  	case reflect.Map:
   603  		out := reflect.MakeMap(reflect.MapOf(out0Type.Key(), out0Type.Elem()))
   604  		return []reflect.Value{out}, nil
   605  	case reflect.Struct:
   606  		if out0TypePtr {
   607  			return []reflect.Value{reflect.Zero(outTypes[0])}, nil
   608  		}
   609  
   610  		return []reflect.Value{reflect.Indirect(reflect.New(out0Type))}, nil
   611  	}
   612  
   613  	outValues := make([]reflect.Value, len(outTypes))
   614  	for i := range outTypes {
   615  		outValues[i] = reflect.Indirect(reflect.New(outTypes[i]))
   616  	}
   617  
   618  	return outValues, sql.ErrNoRows
   619  }
   620  
   621  func (p *SQLParsed) getRowScanInterceptorFn() RowScanInterceptorFn {
   622  	if p.opt.RowScanInterceptor != nil {
   623  		return p.opt.RowScanInterceptor.After
   624  	}
   625  
   626  	return nil
   627  }
   628  
   629  func (p *SQLParsed) doQuery(db *sql.DB, args []reflect.Value, counting bool) (*sql.Rows, func() (int64, error), error) {
   630  	vars := p.makeVars(args)
   631  	return p.doQueryDirectVars(db, vars, counting)
   632  }
   633  
   634  func (p *SQLParsed) doQueryDirectVars(db *sql.DB, vars []interface{}, counting bool) (*sql.Rows, func() (int64, error), error) {
   635  	query, err := p.replaceQuery(db, p.runSQL)
   636  	if err != nil {
   637  		return nil, nil, fmt.Errorf("replaceQuery %s error %w", query, err)
   638  	}
   639  
   640  	log.Printf("exec query %s [%s] with %v", p.ID, query, vars)
   641  
   642  	rows, err := db.QueryContext(p.opt.Ctx, query, vars...)
   643  	if err != nil || rows.Err() != nil {
   644  		if err == nil {
   645  			err = rows.Err()
   646  		}
   647  
   648  		return nil, nil, fmt.Errorf("execute %s error %w", query, err)
   649  	}
   650  
   651  	if counting {
   652  		return rows, func() (int64, error) {
   653  			count, err := p.pagingCount(db, query, vars)
   654  			return count, err
   655  		}, nil
   656  	}
   657  
   658  	return rows, nil, nil
   659  }
   660  
   661  var countStarExprs = func() sqlparser.SelectExprs {
   662  	p, _ := sqlparser.Parse(`select count(*)`)
   663  	return p.(*sqlparser.Select).SelectExprs
   664  }()
   665  
   666  func (p *SQLParsed) pagingCount(db *sql.DB, query string, vars []interface{}) (int64, error) {
   667  	parsed, err := sqlparser.Parse(query)
   668  	if err != nil {
   669  		return 0, err
   670  	}
   671  
   672  	selectQuery, ok := parsed.(*sqlparser.Select)
   673  	if !ok {
   674  		return 0, errors.New("not select query")
   675  	}
   676  
   677  	selectQuery.SelectExprs = countStarExprs
   678  	selectQuery.OrderBy = nil
   679  	selectQuery.Having = nil
   680  	oldLimit := selectQuery.Limit
   681  	selectQuery.Limit = nil
   682  
   683  	limitVarsCount := 0
   684  	if oldLimit != nil {
   685  		limitVarsCount++
   686  		if oldLimit.Offset != nil {
   687  			limitVarsCount++
   688  		}
   689  	}
   690  
   691  	countQuery := sqlparser.String(selectQuery)
   692  	vars = vars[:len(vars)-limitVarsCount]
   693  
   694  	log.Printf("I! execute query %s [%s] with args %v", p.ID, countQuery, vars)
   695  
   696  	countQuery, err = p.replaceQuery(db, countQuery)
   697  	if err != nil {
   698  		return 0, fmt.Errorf("replaceQuery %s error %w", countQuery, err)
   699  	}
   700  
   701  	rows, err := db.QueryContext(p.opt.Ctx, countQuery, vars...)
   702  	if err != nil || rows.Err() != nil {
   703  		if err == nil {
   704  			err = rows.Err()
   705  		}
   706  
   707  		return 0, fmt.Errorf("execute %s error %w", countQuery, err)
   708  	}
   709  
   710  	defer rows.Close()
   711  
   712  	rows.Next()
   713  	var count int64
   714  	if err := rows.Scan(&count); err != nil {
   715  		return 0, err
   716  	}
   717  
   718  	return count, nil
   719  }
   720  
// createMapFields builds the selectItem list that tells resetDests/fillFields
// where each scanned column value goes.
//
// out0Type is the first result type with any slice or pointer already
// unwrapped by the caller (processQueryRows). Struct and map shapes require
// exactly one result. Each column then maps to a struct field (makeStructField,
// nil when no field matches) or to a map entry (makeMapField). For any other
// shape the columns map positionally onto outTypes, and the returned slice has
// max(len(columns), len(outTypes)) entries.
func (p *SQLParsed) createMapFields(columns []string, out0Type reflect.Type,
	outTypes []reflect.Type,
) ([]selectItem, error) {
	// A struct/map row absorbs every column, so it cannot be combined with
	// further positional results.
	switch out0Type.Kind() {
	case reflect.Struct, reflect.Map:
		if len(outTypes) != 1 {
			// nolint:goerr113
			return nil, fmt.Errorf("unsupported return type  %v for current sql %v", out0Type, p.SQL)
		}
	}

	lenCol := len(columns)
	switch out0Type.Kind() {
	case reflect.Struct:
		mapFields := make([]selectItem, lenCol)
		for i, col := range columns {
			mapFields[i] = p.makeStructField(col, out0Type)
		}

		return mapFields, nil
	case reflect.Map:
		mapFields := make([]selectItem, lenCol)
		for i, col := range columns {
			mapFields[i] = p.makeMapField(col, out0Type)
		}

		return mapFields, nil
	}

	// Scalar shape: column i feeds result i.
	mapFields := make([]selectItem, mathx.Max(lenCol, len(outTypes)))
	for i := range columns {
		if i < len(outTypes) {
			// Slot 0 uses the already-unwrapped out0Type; later slots use
			// their declared result type.
			vType := out0Type
			if i > 0 {
				vType = outTypes[i]
			}

			// Scan into the pointee type and remember that the result was a pointer.
			ptr := vType.Kind() == reflect.Ptr
			if ptr {
				vType = vType.Elem()
			}

			mapFields[i] = &singleValue{vType: vType, ptr: ptr}
		} else {
			// Surplus column with no result slot: scanned as a string, not returned.
			mapFields[i] = &singleValue{vType: reflect.TypeOf("")}
		}
	}

	// Result slots with no backing column keep their declared type. Unlike
	// the loop above, pointer types are not unwrapped here.
	for i := lenCol; i < len(outTypes); i++ {
		mapFields[i] = &singleValue{vType: outTypes[i]}
	}

	return mapFields, nil
}
   775  
   776  func (p *SQLParsed) makeMapField(col string, outType reflect.Type) selectItem {
   777  	return &mapItem{k: reflect.ValueOf(col), vType: outType.Elem()}
   778  }
   779  
   780  func (p *SQLParsed) makeStructField(col string, outType reflect.Type) selectItem {
   781  	fv, ok := outType.FieldByNameFunc(func(field string) bool {
   782  		return matchesField2Col(outType, field, col)
   783  	})
   784  
   785  	if ok {
   786  		return &structItem{StructField: &fv}
   787  	}
   788  
   789  	return nil
   790  }
   791  
   792  func matchesField2Col(structType reflect.Type, field, col string) bool {
   793  	f, _ := structType.FieldByName(field)
   794  	if tagName := f.Tag.Get("name"); tagName != "" {
   795  		return tagName == col
   796  	}
   797  
   798  	return ss.AnyOfFold(field, col, strcase.ToCamel(col))
   799  }
   800  
   801  func (p *SQLParsed) makeVars(args []reflect.Value) []interface{} {
   802  	vars := make([]interface{}, 0, len(p.Vars))
   803  
   804  	for i, name := range p.Vars[:len(p.Vars)-len(p.fp.fieldVars)] {
   805  		if p.BindBy == ByAuto {
   806  			vars = append(vars, args[i].Interface())
   807  		} else {
   808  			seq, _ := strconv.Atoi(name)
   809  			vars = append(vars, args[seq-1].Interface())
   810  		}
   811  	}
   812  
   813  	if len(p.fp.fieldVars) > 0 {
   814  		vars = append(vars, p.fp.fieldVars...)
   815  	}
   816  
   817  	return vars
   818  }
   819  
// logError reports err twice: to the standard library logger with an "E!"
// prefix, and to the Logger configured in the dao options.
func (p *SQLParsed) logError(err error) {
	log.Printf("E! error: %v", err)
	p.opt.Logger.LogError(err)
}
   824  
   825  func convertExecResult(result sql.Result, query string, outTypes []reflect.Type) ([]reflect.Value, error) {
   826  	if len(outTypes) == 0 {
   827  		return []reflect.Value{}, nil
   828  	}
   829  
   830  	lastInsertIDVal, _ := result.LastInsertId()
   831  	rowsAffectedVal, _ := result.RowsAffected()
   832  
   833  	firstWord := strings.ToUpper(ss.FirstWord(query))
   834  	results := make([]reflect.Value, 0)
   835  
   836  	if len(outTypes) == 1 {
   837  		if firstWord == "INSERT" {
   838  			return append(results, reflect.ValueOf(lastInsertIDVal).Convert(outTypes[0])), nil
   839  		}
   840  
   841  		return append(results, reflect.ValueOf(rowsAffectedVal).Convert(outTypes[0])), nil
   842  	}
   843  
   844  	results = append(results, reflect.ValueOf(rowsAffectedVal).Convert(outTypes[0]),
   845  		reflect.ValueOf(lastInsertIDVal).Convert(outTypes[1]))
   846  
   847  	for i := 2; i < len(outTypes); i++ {
   848  		results = append(results, reflect.Zero(outTypes[i]))
   849  	}
   850  
   851  	return results, nil
   852  }
   853  
   854  type selectItem interface {
   855  	Type() reflect.Type
   856  	Set(val reflect.Value)
   857  	ResetParent(parent reflect.Value)
   858  }
   859  
   860  type structItem struct {
   861  	*reflect.StructField
   862  	parent reflect.Value
   863  }
   864  
   865  func (s *structItem) Type() reflect.Type               { return s.StructField.Type }
   866  func (s *structItem) ResetParent(parent reflect.Value) { s.parent = parent }
   867  func (s *structItem) Set(val reflect.Value) {
   868  	f := s.parent.FieldByName(s.StructField.Name)
   869  	f.Set(val.Convert(f.Type()))
   870  }
   871  
   872  type mapItem struct {
   873  	k      reflect.Value
   874  	vType  reflect.Type
   875  	parent reflect.Value
   876  }
   877  
   878  func (s *mapItem) Type() reflect.Type               { return s.vType }
   879  func (s *mapItem) ResetParent(parent reflect.Value) { s.parent = parent }
   880  func (s *mapItem) Set(val reflect.Value)            { s.parent.SetMapIndex(s.k, val) }
   881  
   882  type singleValue struct {
   883  	ptr    bool
   884  	parent reflect.Value
   885  	vType  reflect.Type
   886  }
   887  
   888  func (s *singleValue) Type() reflect.Type               { return s.vType }
   889  func (s *singleValue) ResetParent(parent reflect.Value) { s.parent = parent }
   890  func (s *singleValue) Set(val reflect.Value) {
   891  	if !s.parent.IsValid() {
   892  		s.parent = reflect.Indirect(reflect.New(s.vType))
   893  	}
   894  
   895  	s.parent.Set(val)
   896  }
   897  
   898  func resetDests(out0Type reflect.Type, out0TypePtr bool,
   899  	outTypes []reflect.Type, mapFields []selectItem,
   900  ) ([]interface{}, []reflect.Value) {
   901  	pointers := make([]interface{}, len(mapFields))
   902  
   903  	var out0 reflect.Value
   904  
   905  	out := make([]reflect.Value, len(outTypes))
   906  
   907  	out0Kind := out0Type.Kind()
   908  	hasParent := false
   909  	switch out0Kind {
   910  	case reflect.Map, reflect.Struct:
   911  		hasParent = true
   912  	}
   913  
   914  	switch out0Kind {
   915  	case reflect.Map:
   916  		out0 = reflect.MakeMap(reflect.MapOf(out0Type.Key(), out0Type.Elem()))
   917  		out[0] = out0
   918  	default:
   919  		out0Ptr := reflect.New(out0Type)
   920  		out0 = reflect.Indirect(out0Ptr)
   921  
   922  		if out0TypePtr {
   923  			out[0] = out0Ptr
   924  		} else {
   925  			out[0] = out0
   926  		}
   927  	}
   928  
   929  	for i, fv := range mapFields {
   930  		if fv == nil {
   931  			pointers[i] = &NullAny{Type: nil}
   932  			continue
   933  		}
   934  
   935  		if hasParent {
   936  			fv.ResetParent(out0)
   937  		} else if i == 0 {
   938  			fv.ResetParent(out[0])
   939  		} else if i < len(outTypes) {
   940  			out[i] = reflect.Indirect(reflect.New(outTypes[i]))
   941  			fv.ResetParent(out[i])
   942  		}
   943  
   944  		if ImplSQLScanner(fv.Type()) {
   945  			pointers[i] = reflect.New(fv.Type()).Interface()
   946  		} else {
   947  			pointers[i] = &NullAny{Type: fv.Type()}
   948  		}
   949  	}
   950  
   951  	return pointers, out
   952  }
   953  
   954  func fillFields(mapFields []selectItem, pointers []interface{}) {
   955  	for i, field := range mapFields {
   956  		if field == nil {
   957  			continue
   958  		}
   959  
   960  		if p, ok := pointers[i].(*NullAny); ok {
   961  			field.Set(p.GetVal())
   962  		} else {
   963  			field.Set(reflect.ValueOf(pointers[i]).Elem())
   964  		}
   965  	}
   966  }