// Source: github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/db.go

// The original package is migrated from beego and modified; you can find the original at the following link:
// "github.com/beego/beego/"
//
// Copyright 2023 IAC. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package orm

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"reflect"
	"strings"
	"time"

	"github.com/mdaxf/iac/databases/orm/hints"
)

// Go reference-time layouts for time-only, date-only and date-time column
// values.
const (
	formatTime     = "15:04:05"
	formatDate     = "2006-01-02"
	formatDateTime = "2006-01-02 15:04:05"
)

// ErrMissPK is returned by primary-key based operations (Read, Update,
// Delete) when the model carries no primary-key value.
var ErrMissPK = errors.New("missed pk value")

// operators is the set of lookup operator names marked as supported;
// the commented-out entries are not supported.
// NOTE(review): the consumer of this map is outside this part of the file —
// presumably it validates the `field__operator` suffix of query expressions.
var operators = map[string]bool{
	"exact":       true,
	"iexact":      true,
	"strictexact": true,
	"contains":    true,
	"icontains":   true,
	// "regex":       true,
	// "iregex":      true,
	"gt":          true,
	"gte":         true,
	"lt":          true,
	"lte":         true,
	"eq":          true,
	"nq":          true,
	"ne":          true,
	"startswith":  true,
	"endswith":    true,
	"istartswith": true,
	"iendswith":   true,
	"in":          true,
	"between":     true,
	// "year":        true,
	// "month":       true,
	// "day":         true,
	// "week_day":    true,
	"isnull": true,
	// "search":      true,
}

// dbBase holds the SQL generation and execution logic shared by the database
// dialects. ins is the dbBaser whose methods (TableQuote, ReplaceMarks,
// TimeToDB, HasReturningID, ...) are called for dialect-specific behaviour;
// presumably it points at the dialect that embeds this dbBase — confirm.
type dbBase struct {
	ins dbBaser
}

// Compile-time check that dbBase implements the dbBaser interface.
76 var _ dbBaser = new(dbBase) 77 78 // get struct columns values as interface slice. 79 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { 80 if names == nil { 81 ns := make([]string, 0, len(cols)) 82 names = &ns 83 } 84 values = make([]interface{}, 0, len(cols)) 85 86 for _, column := range cols { 87 var fi *fieldInfo 88 if fi, _ = mi.fields.GetByAny(column); fi != nil { 89 column = fi.column 90 } else { 91 panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) 92 } 93 if !fi.dbcol || fi.auto && skipAuto { 94 continue 95 } 96 value, err := d.collectFieldValue(mi, fi, ind, insert, tz) 97 if err != nil { 98 return nil, nil, err 99 } 100 101 // ignore empty value auto field 102 if insert && fi.auto { 103 if fi.fieldType&IsPositiveIntegerField > 0 { 104 if vu, ok := value.(uint64); !ok || vu == 0 { 105 continue 106 } 107 } else { 108 if vu, ok := value.(int64); !ok || vu == 0 { 109 continue 110 } 111 } 112 autoFields = append(autoFields, fi.column) 113 } 114 115 *names, values = append(*names, column), append(values, value) 116 } 117 118 return 119 } 120 121 // get one field value in struct column as interface. 
//
// How the value is derived depends on fi:
//   - primary keys are read through getExistPk;
//   - Fielder implementations supply RawValue();
//   - sql.Null* wrappers that are not Valid, and nil pointer fields, give nil;
//   - time fields that hold a time.Time go through the dialect's TimeToDB, and
//     a zero result becomes nil;
//   - relation fields resolve to the related model's primary key; a missing
//     relation on a non-null field is returned as an error.
//
// For non-pk time fields with autoNow, or with autoNowAdd during an insert, the
// value is replaced by the current time. On insert, a non-zero time.Time that is
// already set is kept. The new time is also written back into the struct field
// in DefaultTimeLoc. Empty or nil JSON values fall back to the field's initial
// value when colDefault is set and an initial value exists; otherwise they become nil.
func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) {
	var value interface{}
	if fi.pk {
		// Primary keys are read through getExistPk rather than the raw field.
		_, value, _ = getExistPk(mi, ind)
	} else {
		field := ind.FieldByIndex(fi.fieldIndex)
		if fi.isFielder {
			// Custom field types provide their own database representation.
			f := field.Addr().Interface().(Fielder)
			value = f.RawValue()
		} else {
			switch fi.fieldType {
			case TypeBooleanField:
				if nb, ok := field.Interface().(sql.NullBool); ok {
					value = nil
					if nb.Valid {
						value = nb.Bool
					}
				} else if field.Kind() == reflect.Ptr {
					if field.IsNil() {
						value = nil
					} else {
						value = field.Elem().Bool()
					}
				} else {
					value = field.Bool()
				}
			case TypeVarCharField, TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField:
				if ns, ok := field.Interface().(sql.NullString); ok {
					value = nil
					if ns.Valid {
						value = ns.String
					}
				} else if field.Kind() == reflect.Ptr {
					if field.IsNil() {
						value = nil
					} else {
						value = field.Elem().String()
					}
				} else {
					value = field.String()
				}
			case TypeFloatField, TypeDecimalField:
				if nf, ok := field.Interface().(sql.NullFloat64); ok {
					value = nil
					if nf.Valid {
						value = nf.Float64
					}
				} else if field.Kind() == reflect.Ptr {
					if field.IsNil() {
						value = nil
					} else {
						value = field.Elem().Float()
					}
				} else {
					vu := field.Interface()
					if _, ok := vu.(float32); ok {
						// float32 is widened through its string form (parse error ignored);
						// NOTE(review): presumably to keep the short decimal representation
						// instead of float32 binary noise — confirm.
						value, _ = StrTo(ToStr(vu)).Float64()
					} else {
						value = field.Float()
					}
				}
			case TypeTimeField, TypeDateField, TypeDateTimeField:
				// Values that are not a plain time.Time (e.g. pointers) pass through unchanged.
				value = field.Interface()
				if t, ok := value.(time.Time); ok {
					d.ins.TimeToDB(&t, tz)
					if t.IsZero() {
						value = nil
					} else {
						value = t
					}
				}
			default:
				switch {
				case fi.fieldType&IsPositiveIntegerField > 0:
					if field.Kind() == reflect.Ptr {
						if field.IsNil() {
							value = nil
						} else {
							value = field.Elem().Uint()
						}
					} else {
						value = field.Uint()
					}
				case fi.fieldType&IsIntegerField > 0:
					if ni, ok := field.Interface().(sql.NullInt64); ok {
						value = nil
						if ni.Valid {
							value = ni.Int64
						}
					} else if field.Kind() == reflect.Ptr {
						if field.IsNil() {
							value = nil
						} else {
							value = field.Elem().Int()
						}
					} else {
						value = field.Int()
					}
				case fi.fieldType&IsRelField > 0:
					// A relation is stored as the related model's primary key.
					if field.IsNil() {
						value = nil
					} else {
						if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok {
							value = vu
						} else {
							value = nil
						}
					}
					if !fi.null && value == nil {
						return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName)
					}
				}
			}
		}
		switch fi.fieldType {
		case TypeTimeField, TypeDateField, TypeDateTimeField:
			if fi.autoNow || fi.autoNowAdd && insert {
				if insert {
					if t, ok := value.(time.Time); ok && !t.IsZero() {
						// Leaves the switch: a caller-supplied non-zero time is kept on insert.
						break
					}
				}
				tnow := time.Now()
				d.ins.TimeToDB(&tnow, tz)
				value = tnow
				// Write the generated time back into the struct so the caller sees it.
				if fi.isFielder {
					f := field.Addr().Interface().(Fielder)
					f.SetRaw(tnow.In(DefaultTimeLoc))
				} else if field.Kind() == reflect.Ptr {
					v := tnow.In(DefaultTimeLoc)
					field.Set(reflect.ValueOf(&v))
				} else {
					field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc)))
				}
			}
		case TypeJSONField, TypeJsonbField:
			// Empty JSON is stored as the column default (if declared) or NULL.
			if s, ok := value.(string); (ok && len(s) == 0) || value == nil {
				if fi.colDefault && fi.initial.Exist() {
					value = fi.initial.String()
				} else {
					value = nil
				}
			}
		}
	}
	return value, nil
}

// PrepareInsert create insert sql preparation statement object.
271 func (d *dbBase) PrepareInsert(ctx context.Context, q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { 272 Q := d.ins.TableQuote() 273 274 dbcols := make([]string, 0, len(mi.fields.dbcols)) 275 marks := make([]string, 0, len(mi.fields.dbcols)) 276 for _, fi := range mi.fields.fieldsDB { 277 if !fi.auto { 278 dbcols = append(dbcols, fi.column) 279 marks = append(marks, "?") 280 } 281 } 282 qmarks := strings.Join(marks, ", ") 283 sep := fmt.Sprintf("%s, %s", Q, Q) 284 columns := strings.Join(dbcols, sep) 285 286 query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) 287 288 d.ins.ReplaceMarks(&query) 289 290 d.ins.HasReturningID(mi, &query) 291 292 stmt, err := q.PrepareContext(ctx, query) 293 return stmt, query, err 294 } 295 296 // InsertStmt insert struct with prepared statement and given struct reflect value. 297 func (d *dbBase) InsertStmt(ctx context.Context, stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { 298 values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) 299 if err != nil { 300 return 0, err 301 } 302 303 if d.ins.HasReturningID(mi, nil) { 304 row := stmt.QueryRow(values...) 305 var id int64 306 err := row.Scan(&id) 307 return id, err 308 } 309 res, err := stmt.ExecContext(ctx, values...) 310 if err == nil { 311 return res.LastInsertId() 312 } 313 return 0, err 314 } 315 316 // query sql ,read records and persist in dbBaser. 317 func (d *dbBase) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { 318 var whereCols []string 319 var args []interface{} 320 321 // if specify cols length > 0, then use it for where condition. 
322 if len(cols) > 0 { 323 var err error 324 whereCols = make([]string, 0, len(cols)) 325 args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) 326 if err != nil { 327 return err 328 } 329 } else { 330 // default use pk value as where condtion. 331 pkColumn, pkValue, ok := getExistPk(mi, ind) 332 if !ok { 333 return ErrMissPK 334 } 335 whereCols = []string{pkColumn} 336 args = append(args, pkValue) 337 } 338 339 Q := d.ins.TableQuote() 340 341 sep := fmt.Sprintf("%s, %s", Q, Q) 342 sels := strings.Join(mi.fields.dbcols, sep) 343 colsNum := len(mi.fields.dbcols) 344 345 sep = fmt.Sprintf("%s = ? AND %s", Q, Q) 346 wheres := strings.Join(whereCols, sep) 347 348 forUpdate := "" 349 if isForUpdate { 350 forUpdate = "FOR UPDATE" 351 } 352 353 query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate) 354 355 refs := make([]interface{}, colsNum) 356 for i := range refs { 357 var ref interface{} 358 refs[i] = &ref 359 } 360 361 d.ins.ReplaceMarks(&query) 362 363 row := q.QueryRowContext(ctx, query, args...) 364 if err := row.Scan(refs...); err != nil { 365 if err == sql.ErrNoRows { 366 return ErrNoRows 367 } 368 return err 369 } 370 elm := reflect.New(mi.addrField.Elem().Type()) 371 mind := reflect.Indirect(elm) 372 d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) 373 ind.Set(mind) 374 return nil 375 } 376 377 // Insert execute insert sql dbQuerier with given struct reflect.Value. 
378 func (d *dbBase) Insert(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { 379 names := make([]string, 0, len(mi.fields.dbcols)) 380 values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) 381 if err != nil { 382 return 0, err 383 } 384 385 id, err := d.InsertValue(ctx, q, mi, false, names, values) 386 if err != nil { 387 return 0, err 388 } 389 390 if len(autoFields) > 0 { 391 err = d.ins.setval(ctx, q, mi, autoFields) 392 } 393 return id, err 394 } 395 396 // InsertMulti multi-insert sql with given slice struct reflect.Value. 397 func (d *dbBase) InsertMulti(ctx context.Context, q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { 398 var ( 399 cnt int64 400 nums int 401 values []interface{} 402 names []string 403 ) 404 405 // typ := reflect.Indirect(mi.addrField).Type() 406 407 length, autoFields := sind.Len(), make([]string, 0, 1) 408 409 for i := 1; i <= length; i++ { 410 411 ind := reflect.Indirect(sind.Index(i - 1)) 412 413 // Is this needed ? 
414 // if !ind.Type().AssignableTo(typ) { 415 // return cnt, ErrArgs 416 // } 417 418 if i == 1 { 419 var ( 420 vus []interface{} 421 err error 422 ) 423 vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) 424 if err != nil { 425 return cnt, err 426 } 427 values = make([]interface{}, bulk*len(vus)) 428 nums += copy(values, vus) 429 } else { 430 vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) 431 if err != nil { 432 return cnt, err 433 } 434 435 if len(vus) != len(names) { 436 return cnt, ErrArgs 437 } 438 439 nums += copy(values[nums:], vus) 440 } 441 442 if i > 1 && i%bulk == 0 || length == i { 443 num, err := d.InsertValue(ctx, q, mi, true, names, values[:nums]) 444 if err != nil { 445 return cnt, err 446 } 447 cnt += num 448 nums = 0 449 } 450 } 451 452 var err error 453 if len(autoFields) > 0 { 454 err = d.ins.setval(ctx, q, mi, autoFields) 455 } 456 457 return cnt, err 458 } 459 460 // InsertValue execute insert sql with given struct and given values. 461 // insert the given values, not the field values in struct. 462 func (d *dbBase) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { 463 Q := d.ins.TableQuote() 464 465 marks := make([]string, len(names)) 466 for i := range marks { 467 marks[i] = "?" 468 } 469 470 sep := fmt.Sprintf("%s, %s", Q, Q) 471 qmarks := strings.Join(marks, ", ") 472 columns := strings.Join(names, sep) 473 474 multi := len(values) / len(names) 475 476 if isMulti && multi > 1 { 477 qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks 478 } 479 480 query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) 481 482 d.ins.ReplaceMarks(&query) 483 484 if isMulti || !d.ins.HasReturningID(mi, &query) { 485 res, err := q.ExecContext(ctx, query, values...) 
486 if err == nil { 487 if isMulti { 488 return res.RowsAffected() 489 } 490 491 lastInsertId, err := res.LastInsertId() 492 if err != nil { 493 DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) 494 return lastInsertId, ErrLastInsertIdUnavailable 495 } else { 496 return lastInsertId, nil 497 } 498 } 499 return 0, err 500 } 501 row := q.QueryRowContext(ctx, query, values...) 502 var id int64 503 err := row.Scan(&id) 504 return id, err 505 } 506 507 // InsertOrUpdate a row 508 // If your primary key or unique column conflict will update 509 // If no will insert 510 func (d *dbBase) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { 511 args0 := "" 512 iouStr := "" 513 argsMap := map[string]string{} 514 switch a.Driver { 515 case DRMySQL: 516 iouStr = "ON DUPLICATE KEY UPDATE" 517 case DRPostgres: 518 if len(args) == 0 { 519 return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) 520 } 521 args0 = strings.ToLower(args[0]) 522 iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) 523 default: 524 return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) 525 } 526 527 // Get on the key-value pairs 528 for _, v := range args { 529 kv := strings.Split(v, "=") 530 if len(kv) == 2 { 531 argsMap[strings.ToLower(kv[0])] = kv[1] 532 } 533 } 534 535 isMulti := false 536 names := make([]string, 0, len(mi.fields.dbcols)-1) 537 Q := d.ins.TableQuote() 538 values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) 539 if err != nil { 540 return 0, err 541 } 542 543 marks := make([]string, len(names)) 544 updateValues := make([]interface{}, 0) 545 updates := make([]string, len(names)) 546 var conflitValue interface{} 547 for i, v := range names { 548 // identifier in database may not be case-sensitive, so quote it 549 v = fmt.Sprintf("%s%s%s", Q, v, Q) 550 marks[i] = "?" 
551 valueStr := argsMap[strings.ToLower(v)] 552 if v == args0 { 553 conflitValue = values[i] 554 } 555 if valueStr != "" { 556 switch a.Driver { 557 case DRMySQL: 558 updates[i] = v + "=" + valueStr 559 case DRPostgres: 560 if conflitValue != nil { 561 // postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values 562 updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) 563 updateValues = append(updateValues, conflitValue) 564 } else { 565 return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) 566 } 567 } 568 } else { 569 updates[i] = v + "=?" 570 updateValues = append(updateValues, values[i]) 571 } 572 } 573 574 values = append(values, updateValues...) 575 576 sep := fmt.Sprintf("%s, %s", Q, Q) 577 qmarks := strings.Join(marks, ", ") 578 qupdates := strings.Join(updates, ", ") 579 columns := strings.Join(names, sep) 580 581 multi := len(values) / len(names) 582 583 if isMulti { 584 qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks 585 } 586 // conflitValue maybe is a int,can`t use fmt.Sprintf 587 query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) 588 589 d.ins.ReplaceMarks(&query) 590 591 if isMulti || !d.ins.HasReturningID(mi, &query) { 592 res, err := q.ExecContext(ctx, query, values...) 593 if err == nil { 594 if isMulti { 595 return res.RowsAffected() 596 } 597 598 lastInsertId, err := res.LastInsertId() 599 if err != nil { 600 DebugLog.Println(ErrLastInsertIdUnavailable, ':', err) 601 return lastInsertId, ErrLastInsertIdUnavailable 602 } else { 603 return lastInsertId, nil 604 } 605 } 606 return 0, err 607 } 608 609 row := q.QueryRowContext(ctx, query, values...) 
610 var id int64 611 err = row.Scan(&id) 612 if err != nil && err.Error() == `pq: syntax error at or near "ON"` { 613 err = fmt.Errorf("postgres version must 9.5 or higher") 614 } 615 return id, err 616 } 617 618 // Update execute update sql dbQuerier with given struct reflect.Value. 619 func (d *dbBase) Update(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { 620 pkName, pkValue, ok := getExistPk(mi, ind) 621 if !ok { 622 return 0, ErrMissPK 623 } 624 625 var setNames []string 626 627 // if specify cols length is zero, then commit all columns. 628 if len(cols) == 0 { 629 cols = mi.fields.dbcols 630 setNames = make([]string, 0, len(mi.fields.dbcols)-1) 631 } else { 632 setNames = make([]string, 0, len(cols)) 633 } 634 635 setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) 636 if err != nil { 637 return 0, err 638 } 639 640 var findAutoNowAdd, findAutoNow bool 641 var index int 642 for i, col := range setNames { 643 if mi.fields.GetByColumn(col).autoNowAdd { 644 index = i 645 findAutoNowAdd = true 646 } 647 if mi.fields.GetByColumn(col).autoNow { 648 findAutoNow = true 649 } 650 } 651 if findAutoNowAdd { 652 setNames = append(setNames[0:index], setNames[index+1:]...) 653 setValues = append(setValues[0:index], setValues[index+1:]...) 654 } 655 656 if !findAutoNow { 657 for col, info := range mi.fields.columns { 658 if info.autoNow { 659 setNames = append(setNames, col) 660 setValues = append(setValues, time.Now()) 661 } 662 } 663 } 664 665 setValues = append(setValues, pkValue) 666 667 Q := d.ins.TableQuote() 668 669 sep := fmt.Sprintf("%s = ?, %s", Q, Q) 670 setColumns := strings.Join(setNames, sep) 671 672 query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q) 673 674 d.ins.ReplaceMarks(&query) 675 676 res, err := q.ExecContext(ctx, query, setValues...) 
677 if err == nil { 678 return res.RowsAffected() 679 } 680 return 0, err 681 } 682 683 // Delete execute delete sql dbQuerier with given struct reflect.Value. 684 // delete index is pk. 685 func (d *dbBase) Delete(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { 686 var whereCols []string 687 var args []interface{} 688 // if specify cols length > 0, then use it for where condition. 689 if len(cols) > 0 { 690 var err error 691 whereCols = make([]string, 0, len(cols)) 692 args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) 693 if err != nil { 694 return 0, err 695 } 696 } else { 697 // default use pk value as where condtion. 698 pkColumn, pkValue, ok := getExistPk(mi, ind) 699 if !ok { 700 return 0, ErrMissPK 701 } 702 whereCols = []string{pkColumn} 703 args = append(args, pkValue) 704 } 705 706 Q := d.ins.TableQuote() 707 708 sep := fmt.Sprintf("%s = ? AND %s", Q, Q) 709 wheres := strings.Join(whereCols, sep) 710 711 query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) 712 713 d.ins.ReplaceMarks(&query) 714 res, err := q.ExecContext(ctx, query, args...) 715 if err == nil { 716 num, err := res.RowsAffected() 717 if err != nil { 718 return 0, err 719 } 720 if num > 0 { 721 err := d.deleteRels(ctx, q, mi, args, tz) 722 if err != nil { 723 return num, err 724 } 725 } 726 return num, err 727 } 728 return 0, err 729 } 730 731 // UpdateBatch update table-related record by querySet. 732 // need querySet not struct reflect.Value to update related records. 
733 func (d *dbBase) UpdateBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { 734 columns := make([]string, 0, len(params)) 735 values := make([]interface{}, 0, len(params)) 736 for col, val := range params { 737 if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { 738 panic(fmt.Errorf("wrong field/column name `%s`", col)) 739 } else { 740 columns = append(columns, fi.column) 741 values = append(values, val) 742 } 743 } 744 745 if len(columns) == 0 { 746 panic(fmt.Errorf("update params cannot empty")) 747 } 748 749 tables := newDbTables(mi, d.ins) 750 var specifyIndexes string 751 if qs != nil { 752 tables.parseRelated(qs.related, qs.relDepth) 753 specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) 754 } 755 756 where, args := tables.getCondSQL(cond, false, tz) 757 758 values = append(values, args...) 759 760 join := tables.getJoinSQL() 761 762 var query, T string 763 764 Q := d.ins.TableQuote() 765 766 if d.ins.SupportUpdateJoin() { 767 T = "T0." 
768 } 769 770 cols := make([]string, 0, len(columns)) 771 772 for i, v := range columns { 773 col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) 774 if c, ok := values[i].(colValue); ok { 775 switch c.opt { 776 case ColAdd: 777 cols = append(cols, col+" = "+col+" + ?") 778 case ColMinus: 779 cols = append(cols, col+" = "+col+" - ?") 780 case ColMultiply: 781 cols = append(cols, col+" = "+col+" * ?") 782 case ColExcept: 783 cols = append(cols, col+" = "+col+" / ?") 784 case ColBitAnd: 785 cols = append(cols, col+" = "+col+" & ?") 786 case ColBitRShift: 787 cols = append(cols, col+" = "+col+" >> ?") 788 case ColBitLShift: 789 cols = append(cols, col+" = "+col+" << ?") 790 case ColBitXOR: 791 cols = append(cols, col+" = "+col+" ^ ?") 792 case ColBitOr: 793 cols = append(cols, col+" = "+col+" | ?") 794 } 795 values[i] = c.value 796 } else { 797 cols = append(cols, col+" = ?") 798 } 799 } 800 801 sets := strings.Join(cols, ", ") + " " 802 803 if d.ins.SupportUpdateJoin() { 804 query = fmt.Sprintf("UPDATE %s%s%s T0 %s%sSET %s%s", Q, mi.table, Q, specifyIndexes, join, sets, where) 805 } else { 806 supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s%s", 807 Q, mi.fields.pk.column, Q, 808 Q, mi.table, Q, 809 specifyIndexes, join, where) 810 query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) 811 } 812 813 d.ins.ReplaceMarks(&query) 814 res, err := q.ExecContext(ctx, query, values...) 815 if err == nil { 816 return res.RowsAffected() 817 } 818 return 0, err 819 } 820 821 // delete related records. 822 // do UpdateBanch or DeleteBanch by condition of tables' relationship. 823 func (d *dbBase) deleteRels(ctx context.Context, q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { 824 for _, fi := range mi.fields.fieldsReverse { 825 fi = fi.reverseFieldInfo 826 switch fi.onDelete { 827 case odCascade: 828 cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) 
829 _, err := d.DeleteBatch(ctx, q, nil, fi.mi, cond, tz) 830 if err != nil { 831 return err 832 } 833 case odSetDefault, odSetNULL: 834 cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) 835 params := Params{fi.column: nil} 836 if fi.onDelete == odSetDefault { 837 params[fi.column] = fi.initial.String() 838 } 839 _, err := d.UpdateBatch(ctx, q, nil, fi.mi, cond, params, tz) 840 if err != nil { 841 return err 842 } 843 case odDoNothing: 844 } 845 } 846 return nil 847 } 848 849 // DeleteBatch delete table-related records. 850 func (d *dbBase) DeleteBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { 851 tables := newDbTables(mi, d.ins) 852 tables.skipEnd = true 853 854 var specifyIndexes string 855 if qs != nil { 856 tables.parseRelated(qs.related, qs.relDepth) 857 specifyIndexes = tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) 858 } 859 860 if cond == nil || cond.IsEmpty() { 861 panic(fmt.Errorf("delete operation cannot execute without condition")) 862 } 863 864 Q := d.ins.TableQuote() 865 866 where, args := tables.getCondSQL(cond, false, tz) 867 join := tables.getJoinSQL() 868 869 cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) 870 query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s%s", cols, Q, mi.table, Q, specifyIndexes, join, where) 871 872 d.ins.ReplaceMarks(&query) 873 874 var rs *sql.Rows 875 r, err := q.QueryContext(ctx, query, args...) 
876 if err != nil { 877 return 0, err 878 } 879 rs = r 880 defer rs.Close() 881 882 var ref interface{} 883 args = make([]interface{}, 0) 884 cnt := 0 885 for rs.Next() { 886 if err := rs.Scan(&ref); err != nil { 887 return 0, err 888 } 889 pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) 890 if err != nil { 891 return 0, err 892 } 893 args = append(args, pkValue) 894 cnt++ 895 } 896 897 if cnt == 0 { 898 return 0, nil 899 } 900 901 marks := make([]string, len(args)) 902 for i := range marks { 903 marks[i] = "?" 904 } 905 sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) 906 query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) 907 908 d.ins.ReplaceMarks(&query) 909 res, err := q.ExecContext(ctx, query, args...) 910 if err == nil { 911 num, err := res.RowsAffected() 912 if err != nil { 913 return 0, err 914 } 915 if num > 0 { 916 err := d.deleteRels(ctx, q, mi, args, tz) 917 if err != nil { 918 return num, err 919 } 920 } 921 return num, nil 922 } 923 return 0, err 924 } 925 926 // ReadBatch read related records. 
927 func (d *dbBase) ReadBatch(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { 928 val := reflect.ValueOf(container) 929 ind := reflect.Indirect(val) 930 931 unregister := true 932 one := true 933 isPtr := true 934 name := "" 935 936 if val.Kind() == reflect.Ptr { 937 fn := "" 938 if ind.Kind() == reflect.Slice { 939 one = false 940 typ := ind.Type().Elem() 941 switch typ.Kind() { 942 case reflect.Ptr: 943 fn = getFullName(typ.Elem()) 944 case reflect.Struct: 945 isPtr = false 946 fn = getFullName(typ) 947 name = getTableName(reflect.New(typ)) 948 } 949 } else { 950 fn = getFullName(ind.Type()) 951 name = getTableName(ind) 952 } 953 unregister = fn != mi.fullName 954 } 955 956 if unregister { 957 RegisterModel(container) 958 } 959 960 rlimit := qs.limit 961 offset := qs.offset 962 963 Q := d.ins.TableQuote() 964 965 var tCols []string 966 if len(cols) > 0 { 967 hasRel := len(qs.related) > 0 || qs.relDepth > 0 968 tCols = make([]string, 0, len(cols)) 969 var maps map[string]bool 970 if hasRel { 971 maps = make(map[string]bool) 972 } 973 for _, col := range cols { 974 if fi, ok := mi.fields.GetByAny(col); ok { 975 tCols = append(tCols, fi.column) 976 if hasRel { 977 maps[fi.column] = true 978 } 979 } else { 980 return 0, fmt.Errorf("wrong field/column name `%s`", col) 981 } 982 } 983 if hasRel { 984 for _, fi := range mi.fields.fieldsDB { 985 if fi.fieldType&IsRelField > 0 { 986 if !maps[fi.column] { 987 tCols = append(tCols, fi.column) 988 } 989 } 990 } 991 } 992 } else { 993 tCols = mi.fields.dbcols 994 } 995 996 colsNum := len(tCols) 997 sep := fmt.Sprintf("%s, T0.%s", Q, Q) 998 sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q) 999 1000 tables := newDbTables(mi, d.ins) 1001 tables.parseRelated(qs.related, qs.relDepth) 1002 1003 where, args := tables.getCondSQL(cond, false, tz) 1004 groupBy := tables.getGroupSQL(qs.groups) 1005 orderBy := 
tables.getOrderSQL(qs.orders) 1006 limit := tables.getLimitSQL(mi, offset, rlimit) 1007 join := tables.getJoinSQL() 1008 specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) 1009 1010 for _, tbl := range tables.tables { 1011 if tbl.sel { 1012 colsNum += len(tbl.mi.fields.dbcols) 1013 sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) 1014 sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) 1015 } 1016 } 1017 1018 sqlSelect := "SELECT" 1019 if qs.distinct { 1020 sqlSelect += " DISTINCT" 1021 } 1022 if qs.aggregate != "" { 1023 sels = qs.aggregate 1024 } 1025 query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", 1026 sqlSelect, sels, Q, mi.table, Q, 1027 specifyIndexes, join, where, groupBy, orderBy, limit) 1028 1029 if qs.forUpdate { 1030 query += " FOR UPDATE" 1031 } 1032 1033 d.ins.ReplaceMarks(&query) 1034 1035 rs, err := q.QueryContext(ctx, query, args...) 1036 if err != nil { 1037 return 0, err 1038 } 1039 1040 defer rs.Close() 1041 1042 slice := ind 1043 if unregister { 1044 mi, _ = defaultModelCache.get(name) 1045 tCols = mi.fields.dbcols 1046 colsNum = len(tCols) 1047 } 1048 1049 refs := make([]interface{}, colsNum) 1050 for i := range refs { 1051 var ref interface{} 1052 refs[i] = &ref 1053 } 1054 var cnt int64 1055 for rs.Next() { 1056 if one && cnt == 0 || !one { 1057 if err := rs.Scan(refs...); err != nil { 1058 return 0, err 1059 } 1060 1061 elm := reflect.New(mi.addrField.Elem().Type()) 1062 mind := reflect.Indirect(elm) 1063 1064 cacheV := make(map[string]*reflect.Value) 1065 cacheM := make(map[string]*modelInfo) 1066 trefs := refs 1067 1068 d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz) 1069 trefs = refs[len(tCols):] 1070 1071 for _, tbl := range tables.tables { 1072 // loop selected tables 1073 if tbl.sel { 1074 last := mind 1075 names := "" 1076 mmi := mi 1077 // loop cascade models 1078 for _, name := range tbl.names { 1079 names += name 1080 if val, ok := cacheV[names]; ok { 
1081 last = *val 1082 mmi = cacheM[names] 1083 } else { 1084 fi := mmi.fields.GetByName(name) 1085 lastm := mmi 1086 mmi = fi.relModelInfo 1087 field := last 1088 if last.Kind() != reflect.Invalid { 1089 field = reflect.Indirect(last.FieldByIndex(fi.fieldIndex)) 1090 if field.IsValid() { 1091 d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) 1092 for _, fi := range mmi.fields.fieldsReverse { 1093 if fi.inModel && fi.reverseFieldInfo.mi == lastm { 1094 if fi.reverseFieldInfo != nil { 1095 f := field.FieldByIndex(fi.fieldIndex) 1096 if f.Kind() == reflect.Ptr { 1097 f.Set(last.Addr()) 1098 } 1099 } 1100 } 1101 } 1102 last = field 1103 } 1104 } 1105 cacheV[names] = &field 1106 cacheM[names] = mmi 1107 } 1108 } 1109 trefs = trefs[len(mmi.fields.dbcols):] 1110 } 1111 } 1112 1113 if one { 1114 ind.Set(mind) 1115 } else { 1116 if cnt == 0 { 1117 // you can use an empty & caped container list 1118 // orm will not replace it 1119 if ind.Len() != 0 { 1120 // if container is not empty 1121 // create a new one 1122 slice = reflect.New(ind.Type()).Elem() 1123 } 1124 } 1125 1126 if isPtr { 1127 slice = reflect.Append(slice, mind.Addr()) 1128 } else { 1129 slice = reflect.Append(slice, mind) 1130 } 1131 } 1132 } 1133 cnt++ 1134 } 1135 1136 if !one { 1137 if cnt > 0 { 1138 ind.Set(slice) 1139 } else { 1140 // when a result is empty and container is nil 1141 // to set an empty container 1142 if ind.IsNil() { 1143 ind.Set(reflect.MakeSlice(ind.Type(), 0, 0)) 1144 } 1145 } 1146 } 1147 1148 return cnt, nil 1149 } 1150 1151 // Count excute count sql and return count result int64. 
1152 func (d *dbBase) Count(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { 1153 tables := newDbTables(mi, d.ins) 1154 tables.parseRelated(qs.related, qs.relDepth) 1155 1156 where, args := tables.getCondSQL(cond, false, tz) 1157 groupBy := tables.getGroupSQL(qs.groups) 1158 tables.getOrderSQL(qs.orders) 1159 join := tables.getJoinSQL() 1160 specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) 1161 1162 Q := d.ins.TableQuote() 1163 1164 query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s%s", 1165 Q, mi.table, Q, 1166 specifyIndexes, join, where, groupBy) 1167 1168 if groupBy != "" { 1169 query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) 1170 } 1171 1172 d.ins.ReplaceMarks(&query) 1173 1174 row := q.QueryRowContext(ctx, query, args...) 1175 err = row.Scan(&cnt) 1176 return 1177 } 1178 1179 // GenerateOperatorSQL generate sql with replacing operator string placeholders and replaced values. 1180 func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { 1181 var sql string 1182 params := getFlatParams(fi, args, tz) 1183 1184 if len(params) == 0 { 1185 panic(fmt.Errorf("operator `%s` need at least one args", operator)) 1186 } 1187 arg := params[0] 1188 1189 switch operator { 1190 case "in": 1191 marks := make([]string, len(params)) 1192 for i := range marks { 1193 marks[i] = "?" 1194 } 1195 sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) 1196 case "between": 1197 if len(params) != 2 { 1198 panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params))) 1199 } 1200 sql = "BETWEEN ? AND ?" 
1201 default: 1202 if len(params) > 1 { 1203 panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) 1204 } 1205 sql = d.ins.OperatorSQL(operator) 1206 switch operator { 1207 case "exact": 1208 if arg == nil { 1209 params[0] = "IS NULL" 1210 } 1211 case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": 1212 param := strings.Replace(ToStr(arg), `%`, `\%`, -1) 1213 switch operator { 1214 case "iexact": 1215 case "contains", "icontains": 1216 param = fmt.Sprintf("%%%s%%", param) 1217 case "startswith", "istartswith": 1218 param = fmt.Sprintf("%s%%", param) 1219 case "endswith", "iendswith": 1220 param = fmt.Sprintf("%%%s", param) 1221 } 1222 params[0] = param 1223 case "isnull": 1224 if b, ok := arg.(bool); ok { 1225 if b { 1226 sql = "IS NULL" 1227 } else { 1228 sql = "IS NOT NULL" 1229 } 1230 params = nil 1231 } else { 1232 panic(fmt.Errorf("operator `%s` need a bool value not `%T`", operator, arg)) 1233 } 1234 } 1235 } 1236 return sql, params 1237 } 1238 1239 // GenerateOperatorLeftCol gernerate sql string with inner function, such as UPPER(text). 1240 func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { 1241 // default not use 1242 } 1243 1244 // set values to struct column. 1245 func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { 1246 for i, column := range cols { 1247 val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() 1248 1249 fi := mi.fields.GetByColumn(column) 1250 1251 field := ind.FieldByIndex(fi.fieldIndex) 1252 1253 value, err := d.convertValueFromDB(fi, val, tz) 1254 if err != nil { 1255 panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) 1256 } 1257 1258 _, err = d.setFieldValue(fi, value, field) 1259 1260 if err != nil { 1261 panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) 1262 } 1263 } 1264 } 1265 1266 // convert value from database result to value following in field type. 
1267 func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { 1268 if val == nil { 1269 return nil, nil 1270 } 1271 1272 var value interface{} 1273 var tErr error 1274 1275 var str *StrTo 1276 switch v := val.(type) { 1277 case []byte: 1278 s := StrTo(string(v)) 1279 str = &s 1280 case string: 1281 s := StrTo(v) 1282 str = &s 1283 } 1284 1285 fieldType := fi.fieldType 1286 1287 setValue: 1288 switch { 1289 case fieldType == TypeBooleanField: 1290 if str == nil { 1291 switch v := val.(type) { 1292 case int64: 1293 b := v == 1 1294 value = b 1295 default: 1296 s := StrTo(ToStr(v)) 1297 str = &s 1298 } 1299 } 1300 if str != nil { 1301 b, err := str.Bool() 1302 if err != nil { 1303 tErr = err 1304 goto end 1305 } 1306 value = b 1307 } 1308 case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: 1309 if str == nil { 1310 value = ToStr(val) 1311 } else { 1312 value = str.String() 1313 } 1314 case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: 1315 if str == nil { 1316 switch t := val.(type) { 1317 case time.Time: 1318 d.ins.TimeFromDB(&t, tz) 1319 value = t 1320 default: 1321 s := StrTo(ToStr(t)) 1322 str = &s 1323 } 1324 } 1325 if str != nil { 1326 s := str.String() 1327 var ( 1328 t time.Time 1329 err error 1330 ) 1331 1332 if fi.timePrecision != nil && len(s) >= (20+*fi.timePrecision) { 1333 layout := formatDateTime + "." 
1334 for i := 0; i < *fi.timePrecision; i++ { 1335 layout += "0" 1336 } 1337 t, err = time.ParseInLocation(layout, s[:20+*fi.timePrecision], tz) 1338 } else if len(s) >= 19 { 1339 s = s[:19] 1340 t, err = time.ParseInLocation(formatDateTime, s, tz) 1341 } else if len(s) >= 10 { 1342 if len(s) > 10 { 1343 s = s[:10] 1344 } 1345 t, err = time.ParseInLocation(formatDate, s, tz) 1346 } else if len(s) >= 8 { 1347 if len(s) > 8 { 1348 s = s[:8] 1349 } 1350 t, err = time.ParseInLocation(formatTime, s, tz) 1351 } 1352 t = t.In(DefaultTimeLoc) 1353 1354 if err != nil && s != "00:00:00" && s != "0000-00-00" && s != "0000-00-00 00:00:00" { 1355 tErr = err 1356 goto end 1357 } 1358 value = t 1359 } 1360 case fieldType&IsIntegerField > 0: 1361 if str == nil { 1362 s := StrTo(ToStr(val)) 1363 str = &s 1364 } 1365 if str != nil { 1366 var err error 1367 switch fieldType { 1368 case TypeBitField: 1369 _, err = str.Int8() 1370 case TypeSmallIntegerField: 1371 _, err = str.Int16() 1372 case TypeIntegerField: 1373 _, err = str.Int32() 1374 case TypeBigIntegerField: 1375 _, err = str.Int64() 1376 case TypePositiveBitField: 1377 _, err = str.Uint8() 1378 case TypePositiveSmallIntegerField: 1379 _, err = str.Uint16() 1380 case TypePositiveIntegerField: 1381 _, err = str.Uint32() 1382 case TypePositiveBigIntegerField: 1383 _, err = str.Uint64() 1384 } 1385 if err != nil { 1386 tErr = err 1387 goto end 1388 } 1389 if fieldType&IsPositiveIntegerField > 0 { 1390 v, _ := str.Uint64() 1391 value = v 1392 } else { 1393 v, _ := str.Int64() 1394 value = v 1395 } 1396 } 1397 case fieldType == TypeFloatField || fieldType == TypeDecimalField: 1398 if str == nil { 1399 switch v := val.(type) { 1400 case float64: 1401 value = v 1402 default: 1403 s := StrTo(ToStr(v)) 1404 str = &s 1405 } 1406 } 1407 if str != nil { 1408 v, err := str.Float64() 1409 if err != nil { 1410 tErr = err 1411 goto end 1412 } 1413 value = v 1414 } 1415 case fieldType&IsRelField > 0: 1416 fi = fi.relModelInfo.fields.pk 1417 
fieldType = fi.fieldType 1418 goto setValue 1419 } 1420 1421 end: 1422 if tErr != nil { 1423 err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr) 1424 return nil, err 1425 } 1426 1427 return value, nil 1428 } 1429 1430 // set one value to struct column field. 1431 func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { 1432 fieldType := fi.fieldType 1433 isNative := !fi.isFielder 1434 1435 setValue: 1436 switch { 1437 case fieldType == TypeBooleanField: 1438 if isNative { 1439 if nb, ok := field.Interface().(sql.NullBool); ok { 1440 if value == nil { 1441 nb.Valid = false 1442 } else { 1443 nb.Bool = value.(bool) 1444 nb.Valid = true 1445 } 1446 field.Set(reflect.ValueOf(nb)) 1447 } else if field.Kind() == reflect.Ptr { 1448 if value != nil { 1449 v := value.(bool) 1450 field.Set(reflect.ValueOf(&v)) 1451 } 1452 } else { 1453 if value == nil { 1454 value = false 1455 } 1456 field.SetBool(value.(bool)) 1457 } 1458 } 1459 case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: 1460 if isNative { 1461 if ns, ok := field.Interface().(sql.NullString); ok { 1462 if value == nil { 1463 ns.Valid = false 1464 } else { 1465 ns.String = value.(string) 1466 ns.Valid = true 1467 } 1468 field.Set(reflect.ValueOf(ns)) 1469 } else if field.Kind() == reflect.Ptr { 1470 if value != nil { 1471 v := value.(string) 1472 field.Set(reflect.ValueOf(&v)) 1473 } 1474 } else { 1475 if value == nil { 1476 value = "" 1477 } 1478 field.SetString(value.(string)) 1479 } 1480 } 1481 case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: 1482 if isNative { 1483 if value == nil { 1484 value = time.Time{} 1485 } else if field.Kind() == reflect.Ptr { 1486 if value != nil { 1487 v := value.(time.Time) 1488 field.Set(reflect.ValueOf(&v)) 1489 } 1490 } else { 
1491 field.Set(reflect.ValueOf(value)) 1492 } 1493 } 1494 case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr: 1495 if value != nil { 1496 v := uint8(value.(uint64)) 1497 field.Set(reflect.ValueOf(&v)) 1498 } 1499 case fieldType == TypePositiveSmallIntegerField && field.Kind() == reflect.Ptr: 1500 if value != nil { 1501 v := uint16(value.(uint64)) 1502 field.Set(reflect.ValueOf(&v)) 1503 } 1504 case fieldType == TypePositiveIntegerField && field.Kind() == reflect.Ptr: 1505 if value != nil { 1506 if field.Type() == reflect.TypeOf(new(uint)) { 1507 v := uint(value.(uint64)) 1508 field.Set(reflect.ValueOf(&v)) 1509 } else { 1510 v := uint32(value.(uint64)) 1511 field.Set(reflect.ValueOf(&v)) 1512 } 1513 } 1514 case fieldType == TypePositiveBigIntegerField && field.Kind() == reflect.Ptr: 1515 if value != nil { 1516 v := value.(uint64) 1517 field.Set(reflect.ValueOf(&v)) 1518 } 1519 case fieldType == TypeBitField && field.Kind() == reflect.Ptr: 1520 if value != nil { 1521 v := int8(value.(int64)) 1522 field.Set(reflect.ValueOf(&v)) 1523 } 1524 case fieldType == TypeSmallIntegerField && field.Kind() == reflect.Ptr: 1525 if value != nil { 1526 v := int16(value.(int64)) 1527 field.Set(reflect.ValueOf(&v)) 1528 } 1529 case fieldType == TypeIntegerField && field.Kind() == reflect.Ptr: 1530 if value != nil { 1531 if field.Type() == reflect.TypeOf(new(int)) { 1532 v := int(value.(int64)) 1533 field.Set(reflect.ValueOf(&v)) 1534 } else { 1535 v := int32(value.(int64)) 1536 field.Set(reflect.ValueOf(&v)) 1537 } 1538 } 1539 case fieldType == TypeBigIntegerField && field.Kind() == reflect.Ptr: 1540 if value != nil { 1541 v := value.(int64) 1542 field.Set(reflect.ValueOf(&v)) 1543 } 1544 case fieldType&IsIntegerField > 0: 1545 if fieldType&IsPositiveIntegerField > 0 { 1546 if isNative { 1547 if value == nil { 1548 value = uint64(0) 1549 } 1550 field.SetUint(value.(uint64)) 1551 } 1552 } else { 1553 if isNative { 1554 if ni, ok := field.Interface().(sql.NullInt64); 
ok { 1555 if value == nil { 1556 ni.Valid = false 1557 } else { 1558 ni.Int64 = value.(int64) 1559 ni.Valid = true 1560 } 1561 field.Set(reflect.ValueOf(ni)) 1562 } else { 1563 if value == nil { 1564 value = int64(0) 1565 } 1566 field.SetInt(value.(int64)) 1567 } 1568 } 1569 } 1570 case fieldType == TypeFloatField || fieldType == TypeDecimalField: 1571 if isNative { 1572 if nf, ok := field.Interface().(sql.NullFloat64); ok { 1573 if value == nil { 1574 nf.Valid = false 1575 } else { 1576 nf.Float64 = value.(float64) 1577 nf.Valid = true 1578 } 1579 field.Set(reflect.ValueOf(nf)) 1580 } else if field.Kind() == reflect.Ptr { 1581 if value != nil { 1582 if field.Type() == reflect.TypeOf(new(float32)) { 1583 v := float32(value.(float64)) 1584 field.Set(reflect.ValueOf(&v)) 1585 } else { 1586 v := value.(float64) 1587 field.Set(reflect.ValueOf(&v)) 1588 } 1589 } 1590 } else { 1591 1592 if value == nil { 1593 value = float64(0) 1594 } 1595 field.SetFloat(value.(float64)) 1596 } 1597 } 1598 case fieldType&IsRelField > 0: 1599 if value != nil { 1600 fieldType = fi.relModelInfo.fields.pk.fieldType 1601 mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) 1602 field.Set(mf) 1603 f := mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) 1604 field = f 1605 goto setValue 1606 } 1607 } 1608 1609 if !isNative { 1610 fd := field.Addr().Interface().(Fielder) 1611 err := fd.SetRaw(value) 1612 if err != nil { 1613 err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err) 1614 return nil, err 1615 } 1616 } 1617 1618 return value, nil 1619 } 1620 1621 // ReadValues query sql, read values , save to *[]ParamList. 
1622 func (d *dbBase) ReadValues(ctx context.Context, q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { 1623 var ( 1624 maps []Params 1625 lists []ParamsList 1626 list ParamsList 1627 ) 1628 1629 typ := 0 1630 switch v := container.(type) { 1631 case *[]Params: 1632 d := *v 1633 if len(d) == 0 { 1634 maps = d 1635 } 1636 typ = 1 1637 case *[]ParamsList: 1638 d := *v 1639 if len(d) == 0 { 1640 lists = d 1641 } 1642 typ = 2 1643 case *ParamsList: 1644 d := *v 1645 if len(d) == 0 { 1646 list = d 1647 } 1648 typ = 3 1649 default: 1650 panic(fmt.Errorf("unsupport read values type `%T`", container)) 1651 } 1652 1653 tables := newDbTables(mi, d.ins) 1654 1655 var ( 1656 cols []string 1657 infos []*fieldInfo 1658 ) 1659 1660 hasExprs := len(exprs) > 0 1661 1662 Q := d.ins.TableQuote() 1663 1664 if hasExprs { 1665 cols = make([]string, 0, len(exprs)) 1666 infos = make([]*fieldInfo, 0, len(exprs)) 1667 for _, ex := range exprs { 1668 index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) 1669 if !suc { 1670 panic(fmt.Errorf("unknown field/column name `%s`", ex)) 1671 } 1672 cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) 1673 infos = append(infos, fi) 1674 } 1675 } else { 1676 cols = make([]string, 0, len(mi.fields.dbcols)) 1677 infos = make([]*fieldInfo, 0, len(exprs)) 1678 for _, fi := range mi.fields.fieldsDB { 1679 cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q)) 1680 infos = append(infos, fi) 1681 } 1682 } 1683 1684 where, args := tables.getCondSQL(cond, false, tz) 1685 groupBy := tables.getGroupSQL(qs.groups) 1686 orderBy := tables.getOrderSQL(qs.orders) 1687 limit := tables.getLimitSQL(mi, qs.offset, qs.limit) 1688 join := tables.getJoinSQL() 1689 specifyIndexes := tables.getIndexSql(mi.table, qs.useIndex, qs.indexes) 1690 1691 sels := strings.Join(cols, ", ") 1692 1693 sqlSelect := "SELECT" 
1694 if qs.distinct { 1695 sqlSelect += " DISTINCT" 1696 } 1697 query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s%s", 1698 sqlSelect, sels, 1699 Q, mi.table, Q, 1700 specifyIndexes, join, where, groupBy, orderBy, limit) 1701 1702 d.ins.ReplaceMarks(&query) 1703 1704 rs, err := q.QueryContext(ctx, query, args...) 1705 if err != nil { 1706 return 0, err 1707 } 1708 refs := make([]interface{}, len(cols)) 1709 for i := range refs { 1710 var ref interface{} 1711 refs[i] = &ref 1712 } 1713 1714 defer rs.Close() 1715 1716 var ( 1717 cnt int64 1718 columns []string 1719 ) 1720 for rs.Next() { 1721 if cnt == 0 { 1722 cols, err := rs.Columns() 1723 if err != nil { 1724 return 0, err 1725 } 1726 columns = cols 1727 } 1728 1729 if err := rs.Scan(refs...); err != nil { 1730 return 0, err 1731 } 1732 1733 switch typ { 1734 case 1: 1735 params := make(Params, len(cols)) 1736 for i, ref := range refs { 1737 fi := infos[i] 1738 1739 val := reflect.Indirect(reflect.ValueOf(ref)).Interface() 1740 1741 value, err := d.convertValueFromDB(fi, val, tz) 1742 if err != nil { 1743 panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) 1744 } 1745 1746 params[columns[i]] = value 1747 } 1748 maps = append(maps, params) 1749 case 2: 1750 params := make(ParamsList, 0, len(cols)) 1751 for i, ref := range refs { 1752 fi := infos[i] 1753 1754 val := reflect.Indirect(reflect.ValueOf(ref)).Interface() 1755 1756 value, err := d.convertValueFromDB(fi, val, tz) 1757 if err != nil { 1758 panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) 1759 } 1760 1761 params = append(params, value) 1762 } 1763 lists = append(lists, params) 1764 case 3: 1765 for i, ref := range refs { 1766 fi := infos[i] 1767 1768 val := reflect.Indirect(reflect.ValueOf(ref)).Interface() 1769 1770 value, err := d.convertValueFromDB(fi, val, tz) 1771 if err != nil { 1772 panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) 1773 } 1774 1775 list = append(list, value) 1776 } 1777 
} 1778 1779 cnt++ 1780 } 1781 1782 switch v := container.(type) { 1783 case *[]Params: 1784 *v = maps 1785 case *[]ParamsList: 1786 *v = lists 1787 case *ParamsList: 1788 *v = list 1789 } 1790 1791 return cnt, nil 1792 } 1793 1794 // SupportUpdateJoin flag of update joined record. 1795 func (d *dbBase) SupportUpdateJoin() bool { 1796 return true 1797 } 1798 1799 func (d *dbBase) MaxLimit() uint64 { 1800 return 18446744073709551615 1801 } 1802 1803 // TableQuote return quote. 1804 func (d *dbBase) TableQuote() string { 1805 return "`" 1806 } 1807 1808 // ReplaceMarks replace value placeholder in parametered sql string. 1809 func (d *dbBase) ReplaceMarks(query *string) { 1810 // default use `?` as mark, do nothing 1811 } 1812 1813 // flag of RETURNING sql. 1814 func (d *dbBase) HasReturningID(*modelInfo, *string) bool { 1815 return false 1816 } 1817 1818 // sync auto key 1819 func (d *dbBase) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error { 1820 return nil 1821 } 1822 1823 // TimeFromDB convert time from db. 1824 func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { 1825 *t = t.In(tz) 1826 } 1827 1828 // TimeToDB convert time to db. 1829 func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { 1830 *t = t.In(tz) 1831 } 1832 1833 // DbTypes get database types. 1834 func (d *dbBase) DbTypes() map[string]string { 1835 return nil 1836 } 1837 1838 // GetTables gt all tables. 1839 func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { 1840 tables := make(map[string]bool) 1841 query := d.ins.ShowTablesQuery() 1842 rows, err := db.Query(query) 1843 if err != nil { 1844 return tables, err 1845 } 1846 1847 defer rows.Close() 1848 1849 for rows.Next() { 1850 var table string 1851 err := rows.Scan(&table) 1852 if err != nil { 1853 return tables, err 1854 } 1855 if table != "" { 1856 tables[table] = true 1857 } 1858 } 1859 1860 return tables, nil 1861 } 1862 1863 // GetColumns get all cloumns in table. 
1864 func (d *dbBase) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) { 1865 columns := make(map[string][3]string) 1866 query := d.ins.ShowColumnsQuery(table) 1867 rows, err := db.QueryContext(ctx, query) 1868 if err != nil { 1869 return columns, err 1870 } 1871 1872 defer rows.Close() 1873 1874 for rows.Next() { 1875 var ( 1876 name string 1877 typ string 1878 null string 1879 ) 1880 err := rows.Scan(&name, &typ, &null) 1881 if err != nil { 1882 return columns, err 1883 } 1884 columns[name] = [3]string{name, typ, null} 1885 } 1886 1887 return columns, nil 1888 } 1889 1890 // not implement. 1891 func (d *dbBase) OperatorSQL(operator string) string { 1892 panic(ErrNotImplement) 1893 } 1894 1895 // not implement. 1896 func (d *dbBase) ShowTablesQuery() string { 1897 panic(ErrNotImplement) 1898 } 1899 1900 // not implement. 1901 func (d *dbBase) ShowColumnsQuery(table string) string { 1902 panic(ErrNotImplement) 1903 } 1904 1905 // not implement. 1906 func (d *dbBase) IndexExists(context.Context, dbQuerier, string, string) bool { 1907 panic(ErrNotImplement) 1908 } 1909 1910 // GenerateSpecifyIndex return a specifying index clause 1911 func (d *dbBase) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string { 1912 var s []string 1913 Q := d.TableQuote() 1914 for _, index := range indexes { 1915 tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q) 1916 s = append(s, tmp) 1917 } 1918 1919 var useWay string 1920 1921 switch useIndex { 1922 case hints.KeyUseIndex: 1923 useWay = `USE` 1924 case hints.KeyForceIndex: 1925 useWay = `FORCE` 1926 case hints.KeyIgnoreIndex: 1927 useWay = `IGNORE` 1928 default: 1929 DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored") 1930 return `` 1931 } 1932 1933 return fmt.Sprintf(` %s INDEX(%s) `, useWay, strings.Join(s, `,`)) 1934 }