github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/sqx.go (about)

     1  package sqx
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/bingoohuang/gg/pkg/sqlparse/sqlparser"
    13  	"github.com/bingoohuang/gg/pkg/ss"
    14  	"github.com/bingoohuang/gg/pkg/strcase"
    15  )
    16  
    17  // ErrConditionKind tells that the condition kind should be struct or its pointer
    18  var ErrConditionKind = errors.New("condition kind should be struct or its pointer")
    19  
    20  // SQL is a structure for query and vars.
    21  type SQL struct {
    22  	Name    string
    23  	Q       string
    24  	AppendQ string
    25  	Vars    []interface{}
    26  	Ctx     context.Context
    27  	NoLog   bool
    28  
    29  	Timeout        time.Duration
    30  	Limit          int
    31  	ConvertOptions []sqlparser.ConvertOption
    32  
    33  	adapted bool
    34  }
    35  
    36  func (s *SQL) AppendIf(ok bool, sub string, args ...interface{}) *SQL {
    37  	if !ok {
    38  		return s
    39  	}
    40  
    41  	return s.Append(sub, args...)
    42  }
    43  
    44  // Append appends sub statement to the query.
    45  func (s *SQL) Append(sub string, args ...interface{}) *SQL {
    46  	if sub == "" {
    47  		return s
    48  	}
    49  
    50  	if strings.HasPrefix(sub, " ") {
    51  		s.Q += sub
    52  	} else {
    53  		s.Q += " " + sub
    54  	}
    55  
    56  	s.Vars = append(s.Vars, args...)
    57  
    58  	return s
    59  }
    60  
    61  // NewSQL create s SQL object.
    62  func NewSQL(query string, vars ...interface{}) *SQL {
    63  	return &SQL{Q: query, Vars: vars}
    64  }
    65  
    66  // WithVars replace vars.
    67  func WithVars(vars ...interface{}) []interface{} { return vars }
    68  
    69  // WithConvertOptions set SQL conversion options.
    70  func (s *SQL) WithConvertOptions(convertOptions []sqlparser.ConvertOption) *SQL {
    71  	s.ConvertOptions = convertOptions
    72  	return s
    73  }
    74  
    75  // WithTimeout set sql execution timeout
    76  func (s *SQL) WithTimeout(timeout time.Duration) *SQL {
    77  	s.Timeout = timeout
    78  	return s
    79  }
    80  
    81  // WithVars replace vars.
    82  func (s *SQL) WithVars(vars ...interface{}) *SQL {
    83  	s.Vars = vars
    84  	return s
    85  }
    86  
    87  func (s *SQL) AndIf(ok bool, cond string, args ...interface{}) *SQL {
    88  	if !ok {
    89  		return s
    90  	}
    91  
    92  	return s.And(cond, args...)
    93  }
    94  
    95  func (s *SQL) And(cond string, args ...interface{}) *SQL {
    96  	switch len(args) {
    97  	case 0:
    98  		if !ss.ContainsFold(s.Q, "where") {
    99  			s.Q += " where " + cond
   100  		} else {
   101  			s.Q += " and " + cond
   102  		}
   103  		return s
   104  	case 1:
   105  		arg := reflect.ValueOf(args[0])
   106  		if arg.IsZero() {
   107  			return s
   108  		}
   109  
   110  		isSlice := arg.Kind() == reflect.Slice
   111  		if isSlice && arg.Len() == 0 {
   112  			return s
   113  		}
   114  		if isSlice && arg.Len() > 1 && strings.Count(cond, "?") == 1 {
   115  			cond = strings.Replace(cond, "?", ss.Repeat("?", ",", arg.Len()), 1)
   116  		}
   117  		if !ss.ContainsFold(s.Q, "where") {
   118  			s.Q += " where " + cond
   119  		} else {
   120  			s.Q += " and " + cond
   121  		}
   122  
   123  		if isSlice {
   124  			for i := 0; i < arg.Len(); i++ {
   125  				s.Vars = append(s.Vars, arg.Index(i).Interface())
   126  			}
   127  		} else {
   128  			s.Vars = append(s.Vars, args[0])
   129  		}
   130  		return s
   131  	default:
   132  		panic("not supported")
   133  	}
   134  }
   135  
   136  func (s *SQL) adaptUpdate(db SqxDB) error {
   137  	if dbTypeAware, ok := db.(DBTypeAware); ok {
   138  		dbType := dbTypeAware.GetDBType()
   139  		options := s.ConvertOptions
   140  		cr, err := dbType.Convert(s.Q, options...)
   141  		if err != nil {
   142  			return err
   143  		}
   144  
   145  		s.Q, s.Vars = cr.PickArgs(s.Vars)
   146  	}
   147  
   148  	if !s.NoLog {
   149  		logQuery(s.Name, s.Q, s.Vars)
   150  	}
   151  
   152  	return nil
   153  }
   154  
   155  func (s *SQL) adaptQuery(db SqxDB) error {
   156  	if dbTypeAware, ok := db.(DBTypeAware); ok {
   157  		dbType := dbTypeAware.GetDBType()
   158  		options := s.ConvertOptions
   159  		if s.Limit > 0 {
   160  			options = append([]sqlparser.ConvertOption{sqlparser.WithLimit(s.Limit)}, options...)
   161  		}
   162  		cr, err := dbType.Convert(s.Q, options...)
   163  		if err != nil {
   164  			return err
   165  		}
   166  
   167  		s.Q, s.Vars = cr.PickArgs(s.Vars)
   168  		if s.AppendQ != "" {
   169  			s.Q += " " + s.AppendQ
   170  		}
   171  
   172  		s.adapted = true
   173  	}
   174  
   175  	if !s.NoLog {
   176  		logQuery(s.Name, s.Q, s.Vars)
   177  	}
   178  
   179  	return nil
   180  }
   181  
   182  // CreateSQL creates a composite SQL on base and condition cond.
   183  func CreateSQL(base string, cond interface{}) (*SQL, error) {
   184  	result := &SQL{}
   185  	if cond == nil {
   186  		result.Q = base
   187  		return result, nil
   188  	}
   189  
   190  	vc, err := inferenceCondValue(cond)
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  
   195  	condSql, vars, err := iterateFields(vc)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  
   200  	if condSql == "" {
   201  		result.Q = base
   202  		return result, nil
   203  	}
   204  
   205  	result.Vars = vars
   206  
   207  	parsed, err := sqlparser.Parse(base)
   208  	if err != nil {
   209  		return nil, err
   210  	}
   211  
   212  	iw, ok := parsed.(sqlparser.IWhere)
   213  	if !ok {
   214  		return result, nil
   215  	}
   216  
   217  	x := `select 1 from t where ` + createNewWhere(iw, condSql)
   218  	condParsed, err := sqlparser.Parse(x)
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  
   223  	iw.SetWhere(condParsed.(*sqlparser.Select).Where)
   224  	result.Q = sqlparser.String(parsed)
   225  
   226  	return result, nil
   227  }
   228  
   229  func createNewWhere(iw sqlparser.IWhere, condSql string) string {
   230  	where := iw.GetWhere()
   231  	if where == nil {
   232  		return condSql
   233  	}
   234  
   235  	whereString := sqlparser.String(where)
   236  	if _, ok := where.Expr.(*sqlparser.OrExpr); ok {
   237  		return `(` + whereString[7:] + `) and ` + condSql
   238  	}
   239  
   240  	return `` + whereString[7:] + ` and ` + condSql
   241  }
   242  
   243  func inferenceCondValue(cond interface{}) (reflect.Value, error) {
   244  	vc := reflect.ValueOf(cond)
   245  	if vc.Kind() == reflect.Ptr {
   246  		vc = vc.Elem()
   247  	}
   248  
   249  	if vc.Kind() != reflect.Struct {
   250  		return reflect.Value{}, ErrConditionKind
   251  	}
   252  
   253  	return vc, nil
   254  }
   255  
   256  const andPrefix = " and "
   257  
   258  func iterateFields(vc reflect.Value) (string, []interface{}, error) {
   259  	condSql := ""
   260  	vars := make([]interface{}, 0)
   261  	t := vc.Type()
   262  
   263  	for i := 0; i < vc.NumField(); i++ {
   264  		f := t.Field(i)
   265  		if f.PkgPath != "" { // not exported
   266  			continue
   267  		}
   268  
   269  		cond := f.Tag.Get("cond")
   270  		if cond == "-" { // ignore as a condition field
   271  			continue
   272  		}
   273  
   274  		v := vc.Field(i)
   275  		if f.Anonymous {
   276  			embeddedSQL, embeddedVars, err := iterateFields(v)
   277  			if err != nil {
   278  				return "", nil, err
   279  			}
   280  
   281  			condSql += andPrefix + embeddedSQL
   282  			vars = append(vars, embeddedVars...)
   283  			continue
   284  		}
   285  
   286  		cond, fieldVars, err := processTag(f.Tag, f.Name, v)
   287  		if err != nil {
   288  			return "", nil, err
   289  		}
   290  
   291  		if cond != "" {
   292  			condSql += andPrefix + cond
   293  			vars = append(vars, fieldVars...)
   294  		}
   295  	}
   296  
   297  	if condSql != "" {
   298  		condSql = condSql[len(andPrefix):]
   299  	}
   300  
   301  	return condSql, vars, nil
   302  }
   303  
   304  func processTag(tag reflect.StructTag, fieldName string, v reflect.Value) (cond string, vars []interface{}, err error) {
   305  	cond = tag.Get("cond")
   306  	zero := tag.Get("zero")
   307  	if yes, err1 := isZero(v, zero); err1 != nil {
   308  		return "", nil, err1
   309  	} else if yes { // ignore zero field
   310  		return "", nil, nil
   311  	}
   312  
   313  	if cond == "" {
   314  		cond = strcase.ToSnake(fieldName) + "=?"
   315  	}
   316  
   317  	vi := v.Interface()
   318  	if modifier := tag.Get("modifier"); modifier != "" {
   319  		vi = strings.ReplaceAll(modifier, "v", fmt.Sprintf("%v", vi))
   320  	}
   321  
   322  	for i := 0; i < strings.Count(cond, "?"); i++ {
   323  		vars = append(vars, vi)
   324  	}
   325  	return
   326  }
   327  
   328  func isZero(v reflect.Value, zero string) (bool, error) {
   329  	if zero == "" {
   330  		return v.IsZero(), nil
   331  	}
   332  
   333  	switch v.Kind() {
   334  	case reflect.String:
   335  		return zero == v.Interface(), nil
   336  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   337  		zeroV, err := strconv.ParseInt(zero, 10, 64)
   338  		if err != nil {
   339  			return false, err
   340  		}
   341  		return zeroV == v.Convert(TypeInt64).Interface(), nil
   342  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   343  		zeroV, err := strconv.ParseUint(zero, 10, 64)
   344  		if err != nil {
   345  			return false, err
   346  		}
   347  		return zeroV == v.Convert(TypeUint64).Interface(), nil
   348  	case reflect.Float32, reflect.Float64:
   349  		zeroV, err := strconv.ParseFloat(zero, 64)
   350  		if err != nil {
   351  			return false, err
   352  		}
   353  
   354  		return zeroV == v.Convert(TypeFloat64).Interface(), nil
   355  	case reflect.Bool:
   356  		zeroV, err := strconv.ParseBool(zero)
   357  		if err != nil {
   358  			return false, err
   359  		}
   360  		return zeroV == v.Interface(), nil
   361  	}
   362  
   363  	return false, nil
   364  }
   365  
   366  var (
   367  	TypeInt64   = reflect.TypeOf(int64(0))
   368  	TypeUint64  = reflect.TypeOf(uint64(0))
   369  	TypeFloat64 = reflect.TypeOf(float64(0))
   370  )