github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/sqx.go (about) 1 package sqx 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "reflect" 8 "strconv" 9 "strings" 10 "time" 11 12 "github.com/bingoohuang/gg/pkg/sqlparse/sqlparser" 13 "github.com/bingoohuang/gg/pkg/ss" 14 "github.com/bingoohuang/gg/pkg/strcase" 15 ) 16 17 // ErrConditionKind tells that the condition kind should be struct or its pointer 18 var ErrConditionKind = errors.New("condition kind should be struct or its pointer") 19 20 // SQL is a structure for query and vars. 21 type SQL struct { 22 Name string 23 Q string 24 AppendQ string 25 Vars []interface{} 26 Ctx context.Context 27 NoLog bool 28 29 Timeout time.Duration 30 Limit int 31 ConvertOptions []sqlparser.ConvertOption 32 33 adapted bool 34 } 35 36 func (s *SQL) AppendIf(ok bool, sub string, args ...interface{}) *SQL { 37 if !ok { 38 return s 39 } 40 41 return s.Append(sub, args...) 42 } 43 44 // Append appends sub statement to the query. 45 func (s *SQL) Append(sub string, args ...interface{}) *SQL { 46 if sub == "" { 47 return s 48 } 49 50 if strings.HasPrefix(sub, " ") { 51 s.Q += sub 52 } else { 53 s.Q += " " + sub 54 } 55 56 s.Vars = append(s.Vars, args...) 57 58 return s 59 } 60 61 // NewSQL create s SQL object. 62 func NewSQL(query string, vars ...interface{}) *SQL { 63 return &SQL{Q: query, Vars: vars} 64 } 65 66 // WithVars replace vars. 67 func WithVars(vars ...interface{}) []interface{} { return vars } 68 69 // WithConvertOptions set SQL conversion options. 70 func (s *SQL) WithConvertOptions(convertOptions []sqlparser.ConvertOption) *SQL { 71 s.ConvertOptions = convertOptions 72 return s 73 } 74 75 // WithTimeout set sql execution timeout 76 func (s *SQL) WithTimeout(timeout time.Duration) *SQL { 77 s.Timeout = timeout 78 return s 79 } 80 81 // WithVars replace vars. 82 func (s *SQL) WithVars(vars ...interface{}) *SQL { 83 s.Vars = vars 84 return s 85 } 86 87 func (s *SQL) AndIf(ok bool, cond string, args ...interface{}) *SQL { 88 if !ok { 89 return s 90 } 91 92 return s.And(cond, args...) 93 } 94 95 func (s *SQL) And(cond string, args ...interface{}) *SQL { 96 switch len(args) { 97 case 0: 98 if !ss.ContainsFold(s.Q, "where") { 99 s.Q += " where " + cond 100 } else { 101 s.Q += " and " + cond 102 } 103 return s 104 case 1: 105 arg := reflect.ValueOf(args[0]) 106 if arg.IsZero() { 107 return s 108 } 109 110 isSlice := arg.Kind() == reflect.Slice 111 if isSlice && arg.Len() == 0 { 112 return s 113 } 114 if isSlice && arg.Len() > 1 && strings.Count(cond, "?") == 1 { 115 cond = strings.Replace(cond, "?", ss.Repeat("?", ",", arg.Len()), 1) 116 } 117 if !ss.ContainsFold(s.Q, "where") { 118 s.Q += " where " + cond 119 } else { 120 s.Q += " and " + cond 121 } 122 123 if isSlice { 124 for i := 0; i < arg.Len(); i++ { 125 s.Vars = append(s.Vars, arg.Index(i).Interface()) 126 } 127 } else { 128 s.Vars = append(s.Vars, args[0]) 129 } 130 return s 131 default: 132 panic("not supported") 133 } 134 } 135 136 func (s *SQL) adaptUpdate(db SqxDB) error { 137 if dbTypeAware, ok := db.(DBTypeAware); ok { 138 dbType := dbTypeAware.GetDBType() 139 options := s.ConvertOptions 140 cr, err := dbType.Convert(s.Q, options...) 141 if err != nil { 142 return err 143 } 144 145 s.Q, s.Vars = cr.PickArgs(s.Vars) 146 } 147 148 if !s.NoLog { 149 logQuery(s.Name, s.Q, s.Vars) 150 } 151 152 return nil 153 } 154 155 func (s *SQL) adaptQuery(db SqxDB) error { 156 if dbTypeAware, ok := db.(DBTypeAware); ok { 157 dbType := dbTypeAware.GetDBType() 158 options := s.ConvertOptions 159 if s.Limit > 0 { 160 options = append([]sqlparser.ConvertOption{sqlparser.WithLimit(s.Limit)}, options...) 161 } 162 cr, err := dbType.Convert(s.Q, options...) 163 if err != nil { 164 return err 165 } 166 167 s.Q, s.Vars = cr.PickArgs(s.Vars) 168 if s.AppendQ != "" { 169 s.Q += " " + s.AppendQ 170 } 171 172 s.adapted = true 173 } 174 175 if !s.NoLog { 176 logQuery(s.Name, s.Q, s.Vars) 177 } 178 179 return nil 180 } 181 182 // CreateSQL creates a composite SQL on base and condition cond. 183 func CreateSQL(base string, cond interface{}) (*SQL, error) { 184 result := &SQL{} 185 if cond == nil { 186 result.Q = base 187 return result, nil 188 } 189 190 vc, err := inferenceCondValue(cond) 191 if err != nil { 192 return nil, err 193 } 194 195 condSql, vars, err := iterateFields(vc) 196 if err != nil { 197 return nil, err 198 } 199 200 if condSql == "" { 201 result.Q = base 202 return result, nil 203 } 204 205 result.Vars = vars 206 207 parsed, err := sqlparser.Parse(base) 208 if err != nil { 209 return nil, err 210 } 211 212 iw, ok := parsed.(sqlparser.IWhere) 213 if !ok { 214 return result, nil 215 } 216 217 x := `select 1 from t where ` + createNewWhere(iw, condSql) 218 condParsed, err := sqlparser.Parse(x) 219 if err != nil { 220 return nil, err 221 } 222 223 iw.SetWhere(condParsed.(*sqlparser.Select).Where) 224 result.Q = sqlparser.String(parsed) 225 226 return result, nil 227 } 228 229 func createNewWhere(iw sqlparser.IWhere, condSql string) string { 230 where := iw.GetWhere() 231 if where == nil { 232 return condSql 233 } 234 235 whereString := sqlparser.String(where) 236 if _, ok := where.Expr.(*sqlparser.OrExpr); ok { 237 return `(` + whereString[7:] + `) and ` + condSql 238 } 239 240 return `` + whereString[7:] + ` and ` + condSql 241 } 242 243 func inferenceCondValue(cond interface{}) (reflect.Value, error) { 244 vc := reflect.ValueOf(cond) 245 if vc.Kind() == reflect.Ptr { 246 vc = vc.Elem() 247 } 248 249 if vc.Kind() != reflect.Struct { 250 return reflect.Value{}, ErrConditionKind 251 } 252 253 return vc, nil 254 } 255 256 const andPrefix = " and " 257 258 func iterateFields(vc reflect.Value) (string, []interface{}, error) { 259 condSql := "" 260 vars := make([]interface{}, 0) 261 t := vc.Type() 262 263 for i := 0; i < vc.NumField(); i++ { 264 f := t.Field(i) 265 if f.PkgPath != "" { // not exported 266 continue 267 } 268 269 cond := f.Tag.Get("cond") 270 if cond == "-" { // ignore as a condition field 271 continue 272 } 273 274 v := vc.Field(i) 275 if f.Anonymous { 276 embeddedSQL, embeddedVars, err := iterateFields(v) 277 if err != nil { 278 return "", nil, err 279 } 280 281 condSql += andPrefix + embeddedSQL 282 vars = append(vars, embeddedVars...) 283 continue 284 } 285 286 cond, fieldVars, err := processTag(f.Tag, f.Name, v) 287 if err != nil { 288 return "", nil, err 289 } 290 291 if cond != "" { 292 condSql += andPrefix + cond 293 vars = append(vars, fieldVars...) 294 } 295 } 296 297 if condSql != "" { 298 condSql = condSql[len(andPrefix):] 299 } 300 301 return condSql, vars, nil 302 } 303 304 func processTag(tag reflect.StructTag, fieldName string, v reflect.Value) (cond string, vars []interface{}, err error) { 305 cond = tag.Get("cond") 306 zero := tag.Get("zero") 307 if yes, err1 := isZero(v, zero); err1 != nil { 308 return "", nil, err1 309 } else if yes { // ignore zero field 310 return "", nil, nil 311 } 312 313 if cond == "" { 314 cond = strcase.ToSnake(fieldName) + "=?" 315 } 316 317 vi := v.Interface() 318 if modifier := tag.Get("modifier"); modifier != "" { 319 vi = strings.ReplaceAll(modifier, "v", fmt.Sprintf("%v", vi)) 320 } 321 322 for i := 0; i < strings.Count(cond, "?"); i++ { 323 vars = append(vars, vi) 324 } 325 return 326 } 327 328 func isZero(v reflect.Value, zero string) (bool, error) { 329 if zero == "" { 330 return v.IsZero(), nil 331 } 332 333 switch v.Kind() { 334 case reflect.String: 335 return zero == v.Interface(), nil 336 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 337 zeroV, err := strconv.ParseInt(zero, 10, 64) 338 if err != nil { 339 return false, err 340 } 341 return zeroV == v.Convert(TypeInt64).Interface(), nil 342 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 343 zeroV, err := strconv.ParseUint(zero, 10, 64) 344 if err != nil { 345 return false, err 346 } 347 return zeroV == v.Convert(TypeUint64).Interface(), nil 348 case reflect.Float32, reflect.Float64: 349 zeroV, err := strconv.ParseFloat(zero, 64) 350 if err != nil { 351 return false, err 352 } 353 354 return zeroV == v.Convert(TypeFloat64).Interface(), nil 355 case reflect.Bool: 356 zeroV, err := strconv.ParseBool(zero) 357 if err != nil { 358 return false, err 359 } 360 return zeroV == v.Interface(), nil 361 } 362 363 return false, nil 364 } 365 366 var ( 367 TypeInt64 = reflect.TypeOf(int64(0)) 368 TypeUint64 = reflect.TypeOf(uint64(0)) 369 TypeFloat64 = reflect.TypeOf(float64(0)) 370 )