github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/dao.go (about) 1 package sqx 2 3 import ( 4 "database/sql" 5 "errors" 6 "fmt" 7 "log" 8 "reflect" 9 "strconv" 10 "strings" 11 12 "github.com/bingoohuang/gg/pkg/mathx" 13 "github.com/bingoohuang/gg/pkg/reflector" 14 "github.com/bingoohuang/gg/pkg/sqlparse/sqlparser" 15 "github.com/bingoohuang/gg/pkg/ss" 16 "github.com/bingoohuang/gg/pkg/strcase" 17 ) 18 19 type Limit struct { 20 Offset int64 21 Length int64 22 } 23 24 type Count int64 25 26 var ( 27 LimitType = reflect.TypeOf((*Limit)(nil)).Elem() 28 CountType = reflect.TypeOf((*Count)(nil)).Elem() 29 ) 30 31 // DBGetter is the interface to get a sql.DBGetter. 32 type DBGetter interface{ GetDB() *sql.DB } 33 34 // StdDB is the wrapper for sql.DBGetter. 35 type StdDB struct{ db *sql.DB } 36 37 // GetDB returns a sql.DBGetter. 38 func (f StdDB) GetDB() *sql.DB { return f.db } 39 40 // DB is the global sql.DB for convenience. 41 var DB *sql.DB 42 43 // CreateDao fulfils the dao (should be pointer). 44 func CreateDao(dao interface{}, createDaoOpts ...CreateDaoOpter) error { 45 daov := reflect.ValueOf(dao) 46 if daov.Kind() != reflect.Ptr || daov.Elem().Kind() != reflect.Struct { 47 return fmt.Errorf("dao should be pointer to struct") // nolint:goerr113 48 } 49 50 option, err := applyCreateDaoOption(createDaoOpts) 51 if err != nil { 52 return err 53 } 54 55 v := reflect.Indirect(daov) 56 createDBGetter(v, option) 57 createLogger(v, option) 58 createErrorSetter(v, option) 59 60 structValue := MakeStructValue(v) 61 for i := 0; i < structValue.NumField; i++ { 62 f := structValue.FieldByIndex(i) 63 64 if f.PkgPath != "" /* not exportable */ || f.Kind != reflect.Func { 65 continue 66 } 67 68 tags, err := reflector.ParseTags(string(f.Tag)) 69 if err != nil { 70 return err 71 } 72 73 sqlStmt, sqlName := option.getSQLStmt(f, tags, 0) 74 if sqlStmt == nil { 75 return fmt.Errorf("failed to find sqlName %s", f.Name) // nolint:goerr113 76 } 77 78 parsed := &SQLParsed{ 79 ID: sqlName, 80 SQL: sqlStmt, 81 opt: option, 82 } 83 84 if err := parsed.fastParseSQL(sqlStmt.Raw()); err != nil { 85 return err 86 } 87 88 r := sqlRun{SQLParsed: parsed} 89 if err := r.createFn(f); err != nil { 90 return err 91 } 92 } 93 94 return nil 95 } 96 97 func (option *CreateDaoOpt) getSQLStmt(field StructField, tags reflector.Tags, stack int) (SQLPart, string) { 98 if stack > 10 { 99 return nil, "" 100 } 101 102 if sqlStmt := field.GetTag("sql"); sqlStmt != "" { 103 dsi := DotItem{ 104 Name: field.Name, 105 Content: []string{sqlStmt}, 106 Attrs: tags.Map(), 107 } 108 part, err := dsi.DynamicSQL() 109 if err != nil { 110 option.Logger.LogError(err) 111 } 112 113 return part, field.Name 114 } 115 116 sqlName := field.GetTagOr("sqlName", field.Name) 117 if part, err := option.DotSQL(sqlName); err != nil { 118 option.Logger.LogError(err) 119 } else if part != nil { 120 return part, sqlName 121 } 122 123 if sqlName == field.Name { 124 return nil, "" 125 } 126 127 if field, ok := field.Parent.FieldByName(sqlName); ok { 128 return option.getSQLStmt(field, nil, stack+1) 129 } 130 131 return nil, sqlName 132 } 133 134 func (r *sqlRun) createFn(f StructField) error { 135 numIn := f.Type.NumIn() 136 numOut := f.Type.NumOut() 137 138 lastOutError := numOut > 0 && reflector.IsError(f.Type.Out(numOut-1)) 139 if lastOutError { 140 numOut-- 141 } 142 143 fn := r.MakeFunc(f, numIn, numOut) 144 if fn == nil { 145 err := fmt.Errorf("unsupportd func %s %v", f.Name, f.Type) // nolint:goerr113 146 r.logError(err) 147 148 return err 149 } 150 151 f.Field.Set(reflect.MakeFunc(f.Type, func(args []reflect.Value) []reflect.Value { 152 r.opt.ErrSetter(nil) 153 values, err := fn(args) 154 if err != nil { 155 r.opt.ErrSetter(err) 156 r.logError(err) 157 158 values = make([]reflect.Value, numOut, numOut+1) 159 for i := 0; i < numOut; i++ { 160 values[i] = reflect.Zero(f.Type.Out(i)) 161 } 162 } 163 164 if lastOutError { 165 if err != nil { 166 values = append(values, reflect.ValueOf(err)) 167 } else { 168 values = append(values, reflect.Zero(reflector.ErrType)) 169 } 170 } 171 172 return values 173 })) 174 175 return nil 176 } 177 178 func (r *sqlRun) MakeFunc(f StructField, numIn, numOut int) func([]reflect.Value) ([]reflect.Value, error) { 179 fn := r.getExecFn() 180 return func(args []reflect.Value) ([]reflect.Value, error) { 181 return fn(numIn, f, makeOutTypes(f.Type, numOut), args) 182 } 183 } 184 185 func (r *sqlRun) getExecFn() func(int, StructField, []reflect.Type, []reflect.Value) ([]reflect.Value, error) { 186 switch isBindByName := r.isBindBy(ByName); { 187 case !r.IsQuery && isBindByName: 188 return r.execByName 189 case !r.IsQuery && !isBindByName: 190 return r.execBySeq 191 case r.IsQuery && isBindByName: 192 return r.queryByName 193 default: // isQuery && !isBindByName: 194 return r.queryBySeq 195 } 196 } 197 198 func makeOutTypes(outType reflect.Type, numOut int) []reflect.Type { 199 rt := make([]reflect.Type, numOut) 200 for i := 0; i < numOut; i++ { 201 rt[i] = outType.Out(i) 202 } 203 204 return rt 205 } 206 207 type sqlRun struct { 208 *SQLParsed 209 } 210 211 func (p *SQLParsed) evalSeq(numIn int, f StructField, args []reflect.Value) error { 212 env := make(map[string]interface{}) 213 for i, arg := range args { 214 env[fmt.Sprintf("_%d", i+1)] = arg.Interface() 215 } 216 217 if len(args) > 0 { 218 env = p.createFieldSqlParts(env, args[0]) 219 } 220 221 return p.eval(numIn, f, env) 222 } 223 224 func (p *SQLParsed) eval(numIn int, f StructField, env map[string]interface{}) error { 225 runSQL, err := p.SQL.Eval(env) 226 if err != nil { 227 return err 228 } 229 230 if err := p.parseSQL(runSQL); err != nil { 231 return err 232 } 233 234 if err := p.checkFuncInOut(numIn, f); err != nil { 235 return err 236 } 237 238 return nil 239 } 240 241 func (r *sqlRun) queryByName(numIn int, f StructField, 242 outTypes []reflect.Type, args []reflect.Value, 243 ) ([]reflect.Value, error) { 244 var bean reflect.Value 245 246 if numIn > 0 { 247 bean = args[0] 248 } 249 250 parsed := *r.SQLParsed 251 env := parsed.createNamedMap(bean) 252 253 if err := parsed.eval(numIn, f, env); err != nil { 254 return nil, err 255 } 256 257 vars, err := parsed.createNamedVars(bean) 258 if err != nil { 259 return nil, err 260 } 261 262 counterIndex := indexOfTypes(outTypes, CountType) 263 db := r.opt.DBGetter.GetDB() 264 rows, counter, err := parsed.doQueryDirectVars(db, vars, counterIndex >= 0) 265 if err != nil { 266 return nil, err 267 } 268 269 return parsed.wrapCounter(rows, outTypes, counterIndex, counter) 270 } 271 272 func (p *SQLParsed) wrapCounter(rows *sql.Rows, outTypes []reflect.Type, counterIndex int, counterFn func() (int64, error)) ([]reflect.Value, error) { 273 values, err := p.processQueryRows(rows, remove(outTypes, counterIndex)) 274 _ = rows.Close() 275 if err != nil || counterFn == nil { 276 return values, err 277 } 278 279 counter, err := counterFn() 280 if err != nil { 281 return values, err 282 } 283 284 return insert(values, counterIndex, reflect.ValueOf(Count(counter))), nil 285 } 286 287 func remove(slice []reflect.Type, s int) []reflect.Type { 288 if s < 0 { 289 return slice 290 } 291 292 return append(slice[:s], slice[s+1:]...) 293 } 294 295 func insert(a []reflect.Value, index int, value reflect.Value) []reflect.Value { 296 if len(a) == index { // nil or empty slice or after last element 297 return append(a, value) 298 } 299 300 a = append(a[:index+1], a[index:]...) // index < len(a) 301 a[index] = value 302 return a 303 } 304 305 func indexOfTypes(types []reflect.Type, typ reflect.Type) int { 306 for i, t := range types { 307 if t == typ { 308 return i 309 } 310 } 311 312 return -1 313 } 314 315 func (r *sqlRun) execByName(numIn int, f StructField, outTypes []reflect.Type, args []reflect.Value) ([]reflect.Value, error) { 316 var bean reflect.Value 317 318 if numIn > 0 { 319 bean = args[0] 320 } 321 322 item0 := bean 323 itemSize := 1 324 isBeanSlice := bean.IsValid() && bean.Type().Kind() == reflect.Slice 325 326 if isBeanSlice { 327 if bean.IsNil() || bean.Len() == 0 { 328 return []reflect.Value{}, nil 329 } 330 331 item0 = bean.Index(0) 332 itemSize = bean.Len() 333 } 334 335 var ( 336 err error 337 pr *sql.Stmt 338 lastResult sql.Result 339 lastSQL string 340 ) 341 342 parsed := *r.SQLParsed 343 db := r.opt.DBGetter.GetDB() 344 tx, err := db.BeginTx(parsed.opt.Ctx, nil) 345 if err != nil { 346 return nil, fmt.Errorf("failed to begin tx %w", err) 347 } 348 349 for ii := 0; ii < itemSize; ii++ { 350 if ii > 0 { 351 item0 = bean.Index(ii) 352 } 353 354 namedMap := parsed.createNamedMap(item0) 355 if err := parsed.eval(numIn, f, namedMap); err != nil { 356 return nil, err 357 } 358 vars, err := parsed.createNamedVars(item0) 359 if err != nil { 360 return nil, err 361 } 362 363 if lastSQL != parsed.runSQL { 364 lastSQL = parsed.runSQL 365 366 query, err := r.replaceQuery(db, parsed.runSQL) 367 if err != nil { 368 return nil, fmt.Errorf("replaceQuery %s error %w", parsed.runSQL, err) 369 } 370 371 log.Printf("exec %s [%s] with %v", parsed.ID, query, vars) 372 if pr, err = tx.PrepareContext(parsed.opt.Ctx, query); err != nil { 373 return nil, fmt.Errorf("failed to prepare sql [%s] error %w", r.RawStmt, err) 374 } 375 } 376 377 lastResult, err = pr.ExecContext(parsed.opt.Ctx, vars...) 378 if err != nil { 379 return nil, fmt.Errorf("failed to execute %s with vars %v error %w", parsed.runSQL, vars, err) 380 } 381 382 LogSqlResult(lastResult) 383 } 384 385 if err := tx.Commit(); err != nil { 386 return nil, fmt.Errorf("failed to commiterror %w", err) 387 } 388 389 return convertExecResult(lastResult, lastSQL, outTypes) 390 } 391 392 func LogSqlResult(lastResult sql.Result) { 393 lastInsertId, _ := lastResult.LastInsertId() 394 rowsAffected, _ := lastResult.RowsAffected() 395 log.Printf("Result lastInsertId: %d, rowsAffected: %d", lastInsertId, rowsAffected) 396 } 397 398 func (p *SQLParsed) createFieldSqlParts(m map[string]interface{}, bean reflect.Value) map[string]interface{} { 399 if !bean.IsValid() || bean.Type().Kind() != reflect.Struct { 400 return m 401 } 402 403 structValue := MakeStructValue(bean) 404 for i, f := range structValue.FieldTypes { 405 if sqlPart := f.Tag.Get("sql"); sqlPart != "" { 406 if bean.Field(i).IsZero() { 407 continue 408 } 409 410 if f.Type.AssignableTo(LimitType) { 411 l := bean.Field(i).Interface().(Limit) 412 p.fp.AddFieldSqlPart(sqlPart, []interface{}{l.Offset, l.Length}, false) 413 } else { 414 p.fp.AddFieldSqlPart(sqlPart, []interface{}{bean.Field(i).Interface()}, true) 415 } 416 } 417 } 418 419 return m 420 } 421 422 func (p *SQLParsed) createNamedMap(bean reflect.Value) map[string]interface{} { 423 m := make(map[string]interface{}) 424 if !bean.IsValid() { 425 return m 426 } 427 428 switch bean.Type().Kind() { 429 case reflect.Struct: 430 structValue := MakeStructValue(bean) 431 for i, f := range structValue.FieldTypes { 432 if tagName := f.Tag.Get("name"); tagName != "" { 433 m[tagName] = bean.Field(i).Interface() 434 } else { 435 name := strcase.ToCamelLower(f.Name) 436 m[name] = bean.Field(i).Interface() 437 } 438 } 439 case reflect.Map: 440 for _, k := range bean.MapKeys() { 441 if ks, ok := k.Interface().(string); ok { 442 m[ks] = bean.MapIndex(k).Interface() 443 } 444 } 445 } 446 447 return m 448 } 449 450 func (p *SQLParsed) createNamedVars(bean reflect.Value) ([]interface{}, error) { 451 itemType := bean.Type() 452 453 var namedValueParser func(name string, item reflect.Value, itemType reflect.Type) interface{} 454 455 switch itemType.Kind() { 456 case reflect.Struct: 457 namedValueParser = func(name string, item reflect.Value, itemType reflect.Type) interface{} { 458 return item.FieldByNameFunc(func(f string) bool { 459 return matchesField2Col(itemType, f, name) 460 }).Interface() 461 } 462 case reflect.Map: 463 namedValueParser = func(name string, item reflect.Value, itemType reflect.Type) interface{} { 464 return item.MapIndex(reflect.ValueOf(name)).Interface() 465 } 466 } 467 468 if namedValueParser == nil { 469 // nolint:goerr113 470 return nil, fmt.Errorf("named vars should use struct/map, unsupported type %v", itemType) 471 } 472 473 vars := make([]interface{}, len(p.Vars)) 474 475 for i, name := range p.Vars { 476 vars[i] = namedValueParser(name, bean, itemType) 477 } 478 479 return vars, nil 480 } 481 482 func (r *sqlRun) execBySeq(numIn int, f StructField, 483 outTypes []reflect.Type, args []reflect.Value, 484 ) ([]reflect.Value, error) { 485 parsed := *r.SQLParsed 486 487 if err := parsed.evalSeq(numIn, f, args); err != nil { 488 return nil, err 489 } 490 491 vars := parsed.makeVars(args) 492 db := r.opt.DBGetter.GetDB() 493 query, err := r.replaceQuery(db, parsed.runSQL) 494 if err != nil { 495 return nil, fmt.Errorf("replaceQuery %s error %w", parsed.runSQL, err) 496 } 497 498 log.Printf("exec query %s [%s] with %v", r.ID, query, vars) 499 500 result, err := db.ExecContext(parsed.opt.Ctx, query, vars...) 501 if err != nil { 502 return nil, fmt.Errorf("execute %s error %w", r.SQL, err) 503 } 504 505 LogSqlResult(result) 506 507 results, err := convertExecResult(result, query, outTypes) 508 if err != nil { 509 return nil, fmt.Errorf("execute %s error %w", r.SQL, err) 510 } 511 512 return results, nil 513 } 514 515 func (r *sqlRun) queryBySeq(numIn int, f StructField, 516 outTypes []reflect.Type, args []reflect.Value, 517 ) ([]reflect.Value, error) { 518 parsed := *r.SQLParsed 519 if err := parsed.evalSeq(numIn, f, args); err != nil { 520 return nil, err 521 } 522 523 db := r.opt.DBGetter.GetDB() 524 counterIndex := indexOfTypes(outTypes, CountType) 525 526 rows, counterFn, err := parsed.doQuery(db, args, counterIndex >= 0) 527 if err != nil { 528 return nil, err 529 } 530 531 defer rows.Close() 532 533 return parsed.wrapCounter(rows, outTypes, counterIndex, counterFn) 534 } 535 536 func (p *SQLParsed) processQueryRows(rows *sql.Rows, outTypes []reflect.Type) ([]reflect.Value, error) { 537 columns, err := rows.Columns() 538 if err != nil { 539 return nil, fmt.Errorf("get columns %s error %w", p.SQL, err) 540 } 541 542 out0Type := outTypes[0] 543 outSlice := reflect.Value{} 544 out0TypePtr := out0Type.Kind() == reflect.Ptr 545 546 switch out0Type.Kind() { 547 case reflect.Slice: 548 outSlice = reflect.MakeSlice(out0Type, 0, 0) 549 out0Type = out0Type.Elem() 550 case reflect.Ptr: 551 out0Type = out0Type.Elem() 552 } 553 554 interceptorFn := p.getRowScanInterceptorFn() 555 mapFields, err := p.createMapFields(columns, out0Type, outTypes) 556 if err != nil { 557 return nil, err 558 } 559 ri := 0 560 561 defer func() { 562 log.Printf("query got %d rows", ri) 563 }() 564 565 for ; rows.Next() && (p.opt.QueryMaxRows <= 0 || ri < p.opt.QueryMaxRows); ri++ { 566 pointers, out := resetDests(out0Type, out0TypePtr, outTypes, mapFields) 567 if err := rows.Scan(pointers[:len(columns)]...); err != nil { 568 return nil, fmt.Errorf("scan rows %s error %w", p.SQL, err) 569 } 570 571 fillFields(mapFields, pointers) 572 573 if interceptorFn != nil { 574 outValues := make([]interface{}, len(out)) 575 for i, outVal := range out { 576 outValues[i] = outVal.Interface() 577 } 578 579 if goon, err := interceptorFn(ri, outValues...); err != nil { 580 return nil, err 581 } else if !goon { 582 break 583 } 584 } 585 586 if !outSlice.IsValid() { 587 return out[:len(outTypes)], nil 588 } 589 590 outSlice = reflect.Append(outSlice, out[0]) 591 } 592 593 if outSlice.IsValid() { 594 return []reflect.Value{outSlice}, nil 595 } 596 597 return noRows(out0Type, out0TypePtr, outTypes) 598 } 599 600 func noRows(out0Type reflect.Type, out0TypePtr bool, outTypes []reflect.Type) ([]reflect.Value, error) { 601 switch out0Type.Kind() { 602 case reflect.Map: 603 out := reflect.MakeMap(reflect.MapOf(out0Type.Key(), out0Type.Elem())) 604 return []reflect.Value{out}, nil 605 case reflect.Struct: 606 if out0TypePtr { 607 return []reflect.Value{reflect.Zero(outTypes[0])}, nil 608 } 609 610 return []reflect.Value{reflect.Indirect(reflect.New(out0Type))}, nil 611 } 612 613 outValues := make([]reflect.Value, len(outTypes)) 614 for i := range outTypes { 615 outValues[i] = reflect.Indirect(reflect.New(outTypes[i])) 616 } 617 618 return outValues, sql.ErrNoRows 619 } 620 621 func (p *SQLParsed) getRowScanInterceptorFn() RowScanInterceptorFn { 622 if p.opt.RowScanInterceptor != nil { 623 return p.opt.RowScanInterceptor.After 624 } 625 626 return nil 627 } 628 629 func (p *SQLParsed) doQuery(db *sql.DB, args []reflect.Value, counting bool) (*sql.Rows, func() (int64, error), error) { 630 vars := p.makeVars(args) 631 return p.doQueryDirectVars(db, vars, counting) 632 } 633 634 func (p *SQLParsed) doQueryDirectVars(db *sql.DB, vars []interface{}, counting bool) (*sql.Rows, func() (int64, error), error) { 635 query, err := p.replaceQuery(db, p.runSQL) 636 if err != nil { 637 return nil, nil, fmt.Errorf("replaceQuery %s error %w", query, err) 638 } 639 640 log.Printf("exec query %s [%s] with %v", p.ID, query, vars) 641 642 rows, err := db.QueryContext(p.opt.Ctx, query, vars...) 643 if err != nil || rows.Err() != nil { 644 if err == nil { 645 err = rows.Err() 646 } 647 648 return nil, nil, fmt.Errorf("execute %s error %w", query, err) 649 } 650 651 if counting { 652 return rows, func() (int64, error) { 653 count, err := p.pagingCount(db, query, vars) 654 return count, err 655 }, nil 656 } 657 658 return rows, nil, nil 659 } 660 661 var countStarExprs = func() sqlparser.SelectExprs { 662 p, _ := sqlparser.Parse(`select count(*)`) 663 return p.(*sqlparser.Select).SelectExprs 664 }() 665 666 func (p *SQLParsed) pagingCount(db *sql.DB, query string, vars []interface{}) (int64, error) { 667 parsed, err := sqlparser.Parse(query) 668 if err != nil { 669 return 0, err 670 } 671 672 selectQuery, ok := parsed.(*sqlparser.Select) 673 if !ok { 674 return 0, errors.New("not select query") 675 } 676 677 selectQuery.SelectExprs = countStarExprs 678 selectQuery.OrderBy = nil 679 selectQuery.Having = nil 680 oldLimit := selectQuery.Limit 681 selectQuery.Limit = nil 682 683 limitVarsCount := 0 684 if oldLimit != nil { 685 limitVarsCount++ 686 if oldLimit.Offset != nil { 687 limitVarsCount++ 688 } 689 } 690 691 countQuery := sqlparser.String(selectQuery) 692 vars = vars[:len(vars)-limitVarsCount] 693 694 log.Printf("I! execute query %s [%s] with args %v", p.ID, countQuery, vars) 695 696 countQuery, err = p.replaceQuery(db, countQuery) 697 if err != nil { 698 return 0, fmt.Errorf("replaceQuery %s error %w", countQuery, err) 699 } 700 701 rows, err := db.QueryContext(p.opt.Ctx, countQuery, vars...) 702 if err != nil || rows.Err() != nil { 703 if err == nil { 704 err = rows.Err() 705 } 706 707 return 0, fmt.Errorf("execute %s error %w", countQuery, err) 708 } 709 710 defer rows.Close() 711 712 rows.Next() 713 var count int64 714 if err := rows.Scan(&count); err != nil { 715 return 0, err 716 } 717 718 return count, nil 719 } 720 721 func (p *SQLParsed) createMapFields(columns []string, out0Type reflect.Type, 722 outTypes []reflect.Type, 723 ) ([]selectItem, error) { 724 switch out0Type.Kind() { 725 case reflect.Struct, reflect.Map: 726 if len(outTypes) != 1 { 727 // nolint:goerr113 728 return nil, fmt.Errorf("unsupported return type %v for current sql %v", out0Type, p.SQL) 729 } 730 } 731 732 lenCol := len(columns) 733 switch out0Type.Kind() { 734 case reflect.Struct: 735 mapFields := make([]selectItem, lenCol) 736 for i, col := range columns { 737 mapFields[i] = p.makeStructField(col, out0Type) 738 } 739 740 return mapFields, nil 741 case reflect.Map: 742 mapFields := make([]selectItem, lenCol) 743 for i, col := range columns { 744 mapFields[i] = p.makeMapField(col, out0Type) 745 } 746 747 return mapFields, nil 748 } 749 750 mapFields := make([]selectItem, mathx.Max(lenCol, len(outTypes))) 751 for i := range columns { 752 if i < len(outTypes) { 753 vType := out0Type 754 if i > 0 { 755 vType = outTypes[i] 756 } 757 758 ptr := vType.Kind() == reflect.Ptr 759 if ptr { 760 vType = vType.Elem() 761 } 762 763 mapFields[i] = &singleValue{vType: vType, ptr: ptr} 764 } else { 765 mapFields[i] = &singleValue{vType: reflect.TypeOf("")} 766 } 767 } 768 769 for i := lenCol; i < len(outTypes); i++ { 770 mapFields[i] = &singleValue{vType: outTypes[i]} 771 } 772 773 return mapFields, nil 774 } 775 776 func (p *SQLParsed) makeMapField(col string, outType reflect.Type) selectItem { 777 return &mapItem{k: reflect.ValueOf(col), vType: outType.Elem()} 778 } 779 780 func (p *SQLParsed) makeStructField(col string, outType reflect.Type) selectItem { 781 fv, ok := outType.FieldByNameFunc(func(field string) bool { 782 return matchesField2Col(outType, field, col) 783 }) 784 785 if ok { 786 return &structItem{StructField: &fv} 787 } 788 789 return nil 790 } 791 792 func matchesField2Col(structType reflect.Type, field, col string) bool { 793 f, _ := structType.FieldByName(field) 794 if tagName := f.Tag.Get("name"); tagName != "" { 795 return tagName == col 796 } 797 798 return ss.AnyOfFold(field, col, strcase.ToCamel(col)) 799 } 800 801 func (p *SQLParsed) makeVars(args []reflect.Value) []interface{} { 802 vars := make([]interface{}, 0, len(p.Vars)) 803 804 for i, name := range p.Vars[:len(p.Vars)-len(p.fp.fieldVars)] { 805 if p.BindBy == ByAuto { 806 vars = append(vars, args[i].Interface()) 807 } else { 808 seq, _ := strconv.Atoi(name) 809 vars = append(vars, args[seq-1].Interface()) 810 } 811 } 812 813 if len(p.fp.fieldVars) > 0 { 814 vars = append(vars, p.fp.fieldVars...) 815 } 816 817 return vars 818 } 819 820 func (p *SQLParsed) logError(err error) { 821 log.Printf("E! error: %v", err) 822 p.opt.Logger.LogError(err) 823 } 824 825 func convertExecResult(result sql.Result, query string, outTypes []reflect.Type) ([]reflect.Value, error) { 826 if len(outTypes) == 0 { 827 return []reflect.Value{}, nil 828 } 829 830 lastInsertIDVal, _ := result.LastInsertId() 831 rowsAffectedVal, _ := result.RowsAffected() 832 833 firstWord := strings.ToUpper(ss.FirstWord(query)) 834 results := make([]reflect.Value, 0) 835 836 if len(outTypes) == 1 { 837 if firstWord == "INSERT" { 838 return append(results, reflect.ValueOf(lastInsertIDVal).Convert(outTypes[0])), nil 839 } 840 841 return append(results, reflect.ValueOf(rowsAffectedVal).Convert(outTypes[0])), nil 842 } 843 844 results = append(results, reflect.ValueOf(rowsAffectedVal).Convert(outTypes[0]), 845 reflect.ValueOf(lastInsertIDVal).Convert(outTypes[1])) 846 847 for i := 2; i < len(outTypes); i++ { 848 results = append(results, reflect.Zero(outTypes[i])) 849 } 850 851 return results, nil 852 } 853 854 type selectItem interface { 855 Type() reflect.Type 856 Set(val reflect.Value) 857 ResetParent(parent reflect.Value) 858 } 859 860 type structItem struct { 861 *reflect.StructField 862 parent reflect.Value 863 } 864 865 func (s *structItem) Type() reflect.Type { return s.StructField.Type } 866 func (s *structItem) ResetParent(parent reflect.Value) { s.parent = parent } 867 func (s *structItem) Set(val reflect.Value) { 868 f := s.parent.FieldByName(s.StructField.Name) 869 f.Set(val.Convert(f.Type())) 870 } 871 872 type mapItem struct { 873 k reflect.Value 874 vType reflect.Type 875 parent reflect.Value 876 } 877 878 func (s *mapItem) Type() reflect.Type { return s.vType } 879 func (s *mapItem) ResetParent(parent reflect.Value) { s.parent = parent } 880 func (s *mapItem) Set(val reflect.Value) { s.parent.SetMapIndex(s.k, val) } 881 882 type singleValue struct { 883 ptr bool 884 parent reflect.Value 885 vType reflect.Type 886 } 887 888 func (s *singleValue) Type() reflect.Type { return s.vType } 889 func (s *singleValue) ResetParent(parent reflect.Value) { s.parent = parent } 890 func (s *singleValue) Set(val reflect.Value) { 891 if !s.parent.IsValid() { 892 s.parent = reflect.Indirect(reflect.New(s.vType)) 893 } 894 895 s.parent.Set(val) 896 } 897 898 func resetDests(out0Type reflect.Type, out0TypePtr bool, 899 outTypes []reflect.Type, mapFields []selectItem, 900 ) ([]interface{}, []reflect.Value) { 901 pointers := make([]interface{}, len(mapFields)) 902 903 var out0 reflect.Value 904 905 out := make([]reflect.Value, len(outTypes)) 906 907 out0Kind := out0Type.Kind() 908 hasParent := false 909 switch out0Kind { 910 case reflect.Map, reflect.Struct: 911 hasParent = true 912 } 913 914 switch out0Kind { 915 case reflect.Map: 916 out0 = reflect.MakeMap(reflect.MapOf(out0Type.Key(), out0Type.Elem())) 917 out[0] = out0 918 default: 919 out0Ptr := reflect.New(out0Type) 920 out0 = reflect.Indirect(out0Ptr) 921 922 if out0TypePtr { 923 out[0] = out0Ptr 924 } else { 925 out[0] = out0 926 } 927 } 928 929 for i, fv := range mapFields { 930 if fv == nil { 931 pointers[i] = &NullAny{Type: nil} 932 continue 933 } 934 935 if hasParent { 936 fv.ResetParent(out0) 937 } else if i == 0 { 938 fv.ResetParent(out[0]) 939 } else if i < len(outTypes) { 940 out[i] = reflect.Indirect(reflect.New(outTypes[i])) 941 fv.ResetParent(out[i]) 942 } 943 944 if ImplSQLScanner(fv.Type()) { 945 pointers[i] = reflect.New(fv.Type()).Interface() 946 } else { 947 pointers[i] = &NullAny{Type: fv.Type()} 948 } 949 } 950 951 return pointers, out 952 } 953 954 func fillFields(mapFields []selectItem, pointers []interface{}) { 955 for i, field := range mapFields { 956 if field == nil { 957 continue 958 } 959 960 if p, ok := pointers[i].(*NullAny); ok { 961 field.Set(p.GetVal()) 962 } else { 963 field.Set(reflect.ValueOf(pointers[i]).Elem()) 964 } 965 } 966 }