github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/exec.go (about) 1 package sqx 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "reflect" 9 "strconv" 10 "strings" 11 12 "github.com/bingoohuang/gg/pkg/mapstruct" 13 "github.com/bingoohuang/gg/pkg/sqlparse/sqlparser" 14 ) 15 16 // QueryAsNumber executes a query which only returns number like count(*) sql. 17 func (s SQL) QueryAsNumber(db SqxDB) (int64, error) { 18 str, err := s.QueryAsString(db) 19 if err != nil { 20 return 0, err 21 } 22 23 return strconv.ParseInt(str, 10, 64) 24 } 25 26 // QueryAsString executes a query which only returns number like count(*) sql. 27 func (s SQL) QueryAsString(db SqxDB) (string, error) { 28 row, err := s.QueryAsRow(db) 29 if err != nil { 30 return "", err 31 } 32 33 if len(row) == 0 { 34 return "", nil 35 } 36 37 return row[0], nil 38 } 39 40 // Update executes an update/delete query and returns rows affected. 41 func (s SQL) Update(db SqxDB) (int64, error) { 42 r, err := s.UpdateRaw(db) 43 if err != nil { 44 return 0, err 45 } 46 47 return r.RowsAffected() 48 } 49 50 func (s SQL) UpdateRaw(db SqxDB) (sql.Result, error) { 51 if dbTypeAware, ok := db.(DBTypeAware); ok { 52 dbType := dbTypeAware.GetDBType() 53 cr, err := dbType.Convert(s.Q, s.ConvertOptions...) 54 if err != nil { 55 return nil, err 56 } 57 58 s.Q, s.Vars = cr.PickArgs(s.Vars) 59 } 60 61 if !s.NoLog { 62 logQuery(s.Name, s.Q, s.Vars) 63 } 64 65 ctx, cancel := s.prepareContext() 66 defer cancel() 67 68 result, err := db.ExecContext(ctx, s.Q, s.Vars...) 69 logQueryError(s.NoLog, s.Name, result, err) 70 return result, err 71 } 72 73 type RowScannerInit interface { 74 InitRowScanner(columns []string) 75 } 76 77 type RowScanner interface { 78 ScanRow(columns []string, rows *sql.Rows, rowIndex int) (bool, error) 79 } 80 81 type ScanRowFn func(columns []string, rows *sql.Rows, rowIndex int) (bool, error) 82 83 func (s ScanRowFn) ScanRow(columns []string, rows *sql.Rows, rowIndex int) (bool, error) { 84 return s(columns, rows, rowIndex) 85 } 86 87 // QueryOption defines the query options. 88 type QueryOption struct { 89 MaxRows int 90 TagNames []string 91 Scanner RowScanner 92 LowerColumnNames bool 93 94 ConvertOptionOptions []sqlparser.ConvertOption 95 } 96 97 // QueryOptionFn define the prototype function to set QueryOption. 98 type QueryOptionFn func(o *QueryOption) 99 100 // QueryOptionFns is the slice of QueryOptionFn. 101 type QueryOptionFns []QueryOptionFn 102 103 func (q QueryOptionFns) Options() *QueryOption { 104 o := &QueryOption{ 105 TagNames: []string{"col", "db", "mapstruct", "field", "json", "yaml"}, 106 } 107 for _, fn := range q { 108 fn(o) 109 } 110 111 return o 112 } 113 114 // WithMaxRows set the max rows of QueryOption. 115 func WithMaxRows(maxRows int) QueryOptionFn { 116 return func(o *QueryOption) { o.MaxRows = maxRows } 117 } 118 119 // WithLowerColumnNames set the LowerColumnNames of QueryOption. 120 func WithLowerColumnNames(v bool) QueryOptionFn { 121 return func(o *QueryOption) { o.LowerColumnNames = v } 122 } 123 124 // WithTagNames set the tagNames for mapping struct fields to query Columns. 125 func WithTagNames(tagNames ...string) QueryOptionFn { 126 return func(o *QueryOption) { o.TagNames = tagNames } 127 } 128 129 // WithOptions apply the query option directly. 130 func WithOptions(v *QueryOption) QueryOptionFn { 131 return func(o *QueryOption) { *o = *v } 132 } 133 134 // WithScanRow set row scanner for the query result. 135 func WithScanRow(v ScanRowFn) QueryOptionFn { 136 return func(o *QueryOption) { o.Scanner = v } 137 } 138 139 // WithRowScanner set row scanner for the query result. 140 func WithRowScanner(v RowScanner) QueryOptionFn { 141 return func(o *QueryOption) { o.Scanner = v } 142 } 143 144 // allowRowNum test the current rowNum is allowed for MaxRows control. 145 func (o QueryOption) allowRowNum(rowNum int) bool { 146 return o.MaxRows == 0 || rowNum <= o.MaxRows 147 } 148 149 // Query queries return with result. 150 func (s SQL) Query(db SqxDB, result interface{}, optionFns ...QueryOptionFn) error { 151 err := s.query(db, result, optionFns...) 152 if !s.NoLog { 153 logQueryError(true, s.Name, nil, err) 154 logRows(s.Name, GetQueryRows(result)) 155 } 156 return err 157 } 158 159 func GetQueryRows(dest interface{}) int { 160 if dest == nil { 161 return 0 162 } 163 164 v := reflect.ValueOf(dest) 165 if v.Kind() == reflect.Ptr { 166 v = v.Elem() 167 } 168 169 switch v.Kind() { 170 case reflect.Slice, reflect.Array: 171 return v.Len() 172 default: 173 return 1 174 } 175 } 176 177 func (s SQL) query(db SqxDB, result interface{}, optionFns ...QueryOptionFn) error { 178 resultValue := reflect.ValueOf(result) 179 if resultValue.Kind() != reflect.Ptr { 180 return fmt.Errorf("result must be a pointer") 181 } 182 183 elem := resultValue.Elem() 184 elemKind := elem.Kind() 185 if elemKind == reflect.Ptr { // 如果依然是指针 186 typ := elem.Type().Elem() // 获取二级指针底层类型 187 val := reflect.New(typ) // 创新底层类型对象 188 err := s.Query(db, val.Interface(), optionFns...) 189 if err == nil { 190 elem.Set(val) // 赋予一级指针新对象地址 191 } 192 return err 193 } 194 195 option := QueryOptionFns(optionFns).Options() 196 197 var err error 198 var input interface{} 199 200 options := WithOptions(option) 201 switch elemKind { 202 case reflect.Struct: 203 input, err = s.QueryAsMap(db, options) 204 case reflect.Slice: 205 sliceElemType := elem.Type().Elem() 206 switch sliceElemType.Kind() { 207 case reflect.Struct: 208 input, err = s.QueryAsMaps(db, options) 209 case reflect.String, reflect.Bool, 210 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 211 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 212 scanner := &Col1Scanner{} 213 err = s.QueryRaw(db, options, WithRowScanner(scanner)) 214 input = scanner.Data 215 default: 216 return ErrNotSupported 217 } 218 case reflect.String, reflect.Bool, 219 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 220 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 221 scanner := &Col1Scanner{MaxRows: 1} 222 err = s.QueryRaw(db, options, WithRowScanner(scanner)) 223 if len(scanner.Data) > 0 { 224 input = scanner.Data[0] 225 } 226 default: 227 return ErrNotSupported 228 } 229 230 if err != nil { 231 return err 232 } 233 234 decoder, err := mapstruct.NewDecoder(&mapstruct.Config{ 235 Result: result, 236 TagNames: option.TagNames, 237 Squash: true, 238 WeakType: true, 239 }) 240 if err != nil { 241 return err 242 } 243 244 return decoder.Decode(input) 245 } 246 247 var ErrNotSupported = errors.New("sqx: Unsupported result type") 248 249 type Col1Scanner struct { 250 Data []string 251 MaxRows int 252 } 253 254 func (s *Col1Scanner) ScanRow(columns []string, rows *sql.Rows, _ int) (bool, error) { 255 if v, err := ScanSliceRow(rows, columns); err != nil { 256 return false, err 257 } else { 258 s.Data = append(s.Data, v[0]) 259 return s.MaxRows == 0 || len(s.Data) < s.MaxRows, nil 260 } 261 } 262 263 type MapScanner struct { 264 Data []map[string]string 265 MaxRows int 266 } 267 268 func (s *MapScanner) Data0() map[string]string { 269 if len(s.Data) == 0 { 270 return nil 271 } 272 273 return s.Data[0] 274 } 275 276 func (s *MapScanner) ScanRow(columns []string, rows *sql.Rows, _ int) (bool, error) { 277 if v, err := ScanMapRow(rows, columns); err != nil { 278 return false, err 279 } else { 280 s.Data = append(s.Data, v) 281 return s.MaxRows == 0 || len(s.Data) < s.MaxRows, nil 282 } 283 } 284 285 // QueryAsMaps query rows as map slice. 286 func (s SQL) QueryAsMaps(db SqxDB, optionFns ...QueryOptionFn) ([]map[string]string, error) { 287 scanner := &MapScanner{Data: make([]map[string]string, 0)} 288 err := s.QueryRaw(db, append(optionFns, WithRowScanner(scanner))...) 289 return scanner.Data, err 290 } 291 292 // QueryAsMap query a single row as a map return. 293 func (s SQL) QueryAsMap(db SqxDB, optionFns ...QueryOptionFn) (map[string]string, error) { 294 scanner := &MapScanner{Data: make([]map[string]string, 0), MaxRows: 1} 295 err := s.QueryRaw(db, append(optionFns, WithRowScanner(scanner))...) 296 return scanner.Data0(), err 297 } 298 299 func ScanSliceRow(rows *sql.Rows, columns []string) ([]string, error) { 300 holders, err := ScanRow(len(columns), rows) 301 if err != nil { 302 return nil, err 303 } 304 305 m := make([]string, len(columns)) 306 for i, h := range holders { 307 m[i] = h.String() 308 } 309 310 return m, nil 311 } 312 313 func ScanMapRow(rows *sql.Rows, columns []string) (map[string]string, error) { 314 holders, err := ScanRow(len(columns), rows) 315 if err != nil { 316 return nil, err 317 } 318 319 m := make(map[string]string) 320 for i, h := range holders { 321 m[columns[i]] = h.String() 322 } 323 324 return m, nil 325 } 326 327 type StringRowScanner struct { 328 Data [][]string 329 MaxRows int 330 } 331 332 func (r *StringRowScanner) ScanRow(columns []string, rows *sql.Rows, _ int) (bool, error) { 333 if m, err := ScanStringRow(rows, columns); err != nil { 334 return false, err 335 } else { 336 r.Data = append(r.Data, m) 337 return r.MaxRows == 0 || len(r.Data) < r.MaxRows, nil 338 } 339 } 340 341 func (r *StringRowScanner) Data0() []string { 342 if len(r.Data) == 0 { 343 return nil 344 } 345 346 return r.Data[0] 347 } 348 349 // QueryAsRow query a single row as a string slice return. 350 func (s SQL) QueryAsRow(db SqxDB, optionFns ...QueryOptionFn) ([]string, error) { 351 f := &StringRowScanner{MaxRows: 1} 352 if err := s.QueryRaw(db, append(optionFns, WithRowScanner(f))...); err != nil { 353 return nil, err 354 } 355 356 return f.Data0(), nil 357 } 358 359 // QueryAsRows query rows as [][]string. 360 func (s SQL) QueryAsRows(db SqxDB, optionFns ...QueryOptionFn) ([][]string, error) { 361 f := &StringRowScanner{} 362 if err := s.QueryRaw(db, append(optionFns, WithRowScanner(f))...); err != nil { 363 return nil, err 364 } 365 366 return f.Data, nil 367 } 368 369 func ScanStringRow(rows *sql.Rows, columns []string) ([]string, error) { 370 holders, err := ScanRow(len(columns), rows) 371 if err != nil { 372 return nil, err 373 } 374 375 m := make([]string, len(columns)) 376 for i, h := range holders { 377 m[i] = h.String() 378 } 379 return m, nil 380 } 381 382 // QueryRaw query rows for customized row scanner. 383 func (s SQL) QueryRaw(db SqxDB, optionFns ...QueryOptionFn) error { 384 option, r, columns, err := s.prepareQuery(db, optionFns...) 385 if err != nil { 386 return err 387 } 388 389 defer r.Close() 390 391 if initial, ok := option.Scanner.(RowScannerInit); ok { 392 initial.InitRowScanner(columns) 393 } 394 395 rows := 0 396 for rn := 0; r.Next() && option.allowRowNum(rn+1); rn++ { 397 rows++ 398 if continued, err := option.Scanner.ScanRow(columns, r, rn); err != nil { 399 return err 400 } else if !continued { 401 break 402 } 403 } 404 405 if rows == 0 { 406 return sql.ErrNoRows 407 } 408 409 return nil 410 } 411 412 func ScanRowValues(rows *sql.Rows) ([]interface{}, error) { 413 cols, err := rows.Columns() 414 if err != nil { 415 return nil, err 416 } 417 418 row, err := ScanRow(len(cols), rows) 419 if err != nil { 420 return nil, err 421 } 422 423 rowValues := make([]interface{}, len(cols)) 424 for i := range rowValues { 425 rowValues[i] = row[i].Get() 426 } 427 428 return rowValues, nil 429 } 430 431 func ScanRow(columnSize int, r *sql.Rows) ([]NullAny, error) { 432 holders := make([]NullAny, columnSize) 433 pointers := make([]interface{}, columnSize) 434 for i := 0; i < columnSize; i++ { 435 pointers[i] = &holders[i] 436 } 437 438 if err := r.Scan(pointers...); err != nil { 439 return nil, err 440 } 441 442 return holders, nil 443 } 444 445 func (s SQL) prepareContext() (ctx context.Context, cancel func()) { 446 ctx = s.Ctx 447 if ctx == nil { 448 ctx = context.Background() 449 } 450 if s.Timeout > 0 { 451 return context.WithTimeout(ctx, s.Timeout) 452 } 453 454 return ctx, func() {} 455 } 456 457 func (s *SQL) prepareQuery(db SqxDB, optionFns ...QueryOptionFn) (*QueryOption, *sql.Rows, []string, error) { 458 if err := s.adaptQuery(db); err != nil { 459 return nil, nil, nil, err 460 } 461 462 ctx, cancel := s.prepareContext() 463 defer cancel() 464 ctx = context.WithValue(ctx, AdaptedKey, s.adapted) 465 r, err := db.QueryContext(ctx, s.Q, s.Vars...) 466 if err != nil { 467 return nil, nil, nil, err 468 } 469 470 columns, err := r.Columns() 471 if err != nil { 472 return nil, nil, nil, err 473 } 474 475 option := QueryOptionFns(optionFns).Options() 476 if option.LowerColumnNames { 477 for i, col := range columns { 478 columns[i] = strings.ToLower(col) 479 } 480 } 481 482 return option, r, columns, nil 483 }