github.com/astaxie/beego@v1.12.3/orm/db.go (about) 1 // Copyright 2014 beego Author. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package orm 16 17 import ( 18 "database/sql" 19 "errors" 20 "fmt" 21 "reflect" 22 "strings" 23 "time" 24 ) 25 26 const ( 27 formatTime = "15:04:05" 28 formatDate = "2006-01-02" 29 formatDateTime = "2006-01-02 15:04:05" 30 ) 31 32 var ( 33 // ErrMissPK missing pk error 34 ErrMissPK = errors.New("missed pk value") 35 ) 36 37 var ( 38 operators = map[string]bool{ 39 "exact": true, 40 "iexact": true, 41 "strictexact": true, 42 "contains": true, 43 "icontains": true, 44 // "regex": true, 45 // "iregex": true, 46 "gt": true, 47 "gte": true, 48 "lt": true, 49 "lte": true, 50 "eq": true, 51 "nq": true, 52 "ne": true, 53 "startswith": true, 54 "endswith": true, 55 "istartswith": true, 56 "iendswith": true, 57 "in": true, 58 "between": true, 59 // "year": true, 60 // "month": true, 61 // "day": true, 62 // "week_day": true, 63 "isnull": true, 64 // "search": true, 65 } 66 ) 67 68 // an instance of dbBaser interface/ 69 type dbBase struct { 70 ins dbBaser 71 } 72 73 // check dbBase implements dbBaser interface. 74 var _ dbBaser = new(dbBase) 75 76 // get struct columns values as interface slice. 77 func (d *dbBase) collectValues(mi *modelInfo, ind reflect.Value, cols []string, skipAuto bool, insert bool, names *[]string, tz *time.Location) (values []interface{}, autoFields []string, err error) { 78 if names == nil { 79 ns := make([]string, 0, len(cols)) 80 names = &ns 81 } 82 values = make([]interface{}, 0, len(cols)) 83 84 for _, column := range cols { 85 var fi *fieldInfo 86 if fi, _ = mi.fields.GetByAny(column); fi != nil { 87 column = fi.column 88 } else { 89 panic(fmt.Errorf("wrong db field/column name `%s` for model `%s`", column, mi.fullName)) 90 } 91 if !fi.dbcol || fi.auto && skipAuto { 92 continue 93 } 94 value, err := d.collectFieldValue(mi, fi, ind, insert, tz) 95 if err != nil { 96 return nil, nil, err 97 } 98 99 // ignore empty value auto field 100 if insert && fi.auto { 101 if fi.fieldType&IsPositiveIntegerField > 0 { 102 if vu, ok := value.(uint64); !ok || vu == 0 { 103 continue 104 } 105 } else { 106 if vu, ok := value.(int64); !ok || vu == 0 { 107 continue 108 } 109 } 110 autoFields = append(autoFields, fi.column) 111 } 112 113 *names, values = append(*names, column), append(values, value) 114 } 115 116 return 117 } 118 119 // get one field value in struct column as interface. 120 func (d *dbBase) collectFieldValue(mi *modelInfo, fi *fieldInfo, ind reflect.Value, insert bool, tz *time.Location) (interface{}, error) { 121 var value interface{} 122 if fi.pk { 123 _, value, _ = getExistPk(mi, ind) 124 } else { 125 field := ind.FieldByIndex(fi.fieldIndex) 126 if fi.isFielder { 127 f := field.Addr().Interface().(Fielder) 128 value = f.RawValue() 129 } else { 130 switch fi.fieldType { 131 case TypeBooleanField: 132 if nb, ok := field.Interface().(sql.NullBool); ok { 133 value = nil 134 if nb.Valid { 135 value = nb.Bool 136 } 137 } else if field.Kind() == reflect.Ptr { 138 if field.IsNil() { 139 value = nil 140 } else { 141 value = field.Elem().Bool() 142 } 143 } else { 144 value = field.Bool() 145 } 146 case TypeVarCharField, TypeCharField, TypeTextField, TypeJSONField, TypeJsonbField: 147 if ns, ok := field.Interface().(sql.NullString); ok { 148 value = nil 149 if ns.Valid { 150 value = ns.String 151 } 152 } else if field.Kind() == reflect.Ptr { 153 if field.IsNil() { 154 value = nil 155 } else { 156 value = field.Elem().String() 157 } 158 } else { 159 value = field.String() 160 } 161 case TypeFloatField, TypeDecimalField: 162 if nf, ok := field.Interface().(sql.NullFloat64); ok { 163 value = nil 164 if nf.Valid { 165 value = nf.Float64 166 } 167 } else if field.Kind() == reflect.Ptr { 168 if field.IsNil() { 169 value = nil 170 } else { 171 value = field.Elem().Float() 172 } 173 } else { 174 vu := field.Interface() 175 if _, ok := vu.(float32); ok { 176 value, _ = StrTo(ToStr(vu)).Float64() 177 } else { 178 value = field.Float() 179 } 180 } 181 case TypeTimeField, TypeDateField, TypeDateTimeField: 182 value = field.Interface() 183 if t, ok := value.(time.Time); ok { 184 d.ins.TimeToDB(&t, tz) 185 if t.IsZero() { 186 value = nil 187 } else { 188 value = t 189 } 190 } 191 default: 192 switch { 193 case fi.fieldType&IsPositiveIntegerField > 0: 194 if field.Kind() == reflect.Ptr { 195 if field.IsNil() { 196 value = nil 197 } else { 198 value = field.Elem().Uint() 199 } 200 } else { 201 value = field.Uint() 202 } 203 case fi.fieldType&IsIntegerField > 0: 204 if ni, ok := field.Interface().(sql.NullInt64); ok { 205 value = nil 206 if ni.Valid { 207 value = ni.Int64 208 } 209 } else if field.Kind() == reflect.Ptr { 210 if field.IsNil() { 211 value = nil 212 } else { 213 value = field.Elem().Int() 214 } 215 } else { 216 value = field.Int() 217 } 218 case fi.fieldType&IsRelField > 0: 219 if field.IsNil() { 220 value = nil 221 } else { 222 if _, vu, ok := getExistPk(fi.relModelInfo, reflect.Indirect(field)); ok { 223 value = vu 224 } else { 225 value = nil 226 } 227 } 228 if !fi.null && value == nil { 229 return nil, fmt.Errorf("field `%s` cannot be NULL", fi.fullName) 230 } 231 } 232 } 233 } 234 switch fi.fieldType { 235 case TypeTimeField, TypeDateField, TypeDateTimeField: 236 if fi.autoNow || fi.autoNowAdd && insert { 237 if insert { 238 if t, ok := value.(time.Time); ok && !t.IsZero() { 239 break 240 } 241 } 242 tnow := time.Now() 243 d.ins.TimeToDB(&tnow, tz) 244 value = tnow 245 if fi.isFielder { 246 f := field.Addr().Interface().(Fielder) 247 f.SetRaw(tnow.In(DefaultTimeLoc)) 248 } else if field.Kind() == reflect.Ptr { 249 v := tnow.In(DefaultTimeLoc) 250 field.Set(reflect.ValueOf(&v)) 251 } else { 252 field.Set(reflect.ValueOf(tnow.In(DefaultTimeLoc))) 253 } 254 } 255 case TypeJSONField, TypeJsonbField: 256 if s, ok := value.(string); (ok && len(s) == 0) || value == nil { 257 if fi.colDefault && fi.initial.Exist() { 258 value = fi.initial.String() 259 } else { 260 value = nil 261 } 262 } 263 } 264 } 265 return value, nil 266 } 267 268 // create insert sql preparation statement object. 269 func (d *dbBase) PrepareInsert(q dbQuerier, mi *modelInfo) (stmtQuerier, string, error) { 270 Q := d.ins.TableQuote() 271 272 dbcols := make([]string, 0, len(mi.fields.dbcols)) 273 marks := make([]string, 0, len(mi.fields.dbcols)) 274 for _, fi := range mi.fields.fieldsDB { 275 if !fi.auto { 276 dbcols = append(dbcols, fi.column) 277 marks = append(marks, "?") 278 } 279 } 280 qmarks := strings.Join(marks, ", ") 281 sep := fmt.Sprintf("%s, %s", Q, Q) 282 columns := strings.Join(dbcols, sep) 283 284 query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) 285 286 d.ins.ReplaceMarks(&query) 287 288 d.ins.HasReturningID(mi, &query) 289 290 stmt, err := q.Prepare(query) 291 return stmt, query, err 292 } 293 294 // insert struct with prepared statement and given struct reflect value. 295 func (d *dbBase) InsertStmt(stmt stmtQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { 296 values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, nil, tz) 297 if err != nil { 298 return 0, err 299 } 300 301 if d.ins.HasReturningID(mi, nil) { 302 row := stmt.QueryRow(values...) 303 var id int64 304 err := row.Scan(&id) 305 return id, err 306 } 307 res, err := stmt.Exec(values...) 308 if err == nil { 309 return res.LastInsertId() 310 } 311 return 0, err 312 } 313 314 // query sql ,read records and persist in dbBaser. 315 func (d *dbBase) Read(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error { 316 var whereCols []string 317 var args []interface{} 318 319 // if specify cols length > 0, then use it for where condition. 320 if len(cols) > 0 { 321 var err error 322 whereCols = make([]string, 0, len(cols)) 323 args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) 324 if err != nil { 325 return err 326 } 327 } else { 328 // default use pk value as where condtion. 329 pkColumn, pkValue, ok := getExistPk(mi, ind) 330 if !ok { 331 return ErrMissPK 332 } 333 whereCols = []string{pkColumn} 334 args = append(args, pkValue) 335 } 336 337 Q := d.ins.TableQuote() 338 339 sep := fmt.Sprintf("%s, %s", Q, Q) 340 sels := strings.Join(mi.fields.dbcols, sep) 341 colsNum := len(mi.fields.dbcols) 342 343 sep = fmt.Sprintf("%s = ? AND %s", Q, Q) 344 wheres := strings.Join(whereCols, sep) 345 346 forUpdate := "" 347 if isForUpdate { 348 forUpdate = "FOR UPDATE" 349 } 350 351 query := fmt.Sprintf("SELECT %s%s%s FROM %s%s%s WHERE %s%s%s = ? %s", Q, sels, Q, Q, mi.table, Q, Q, wheres, Q, forUpdate) 352 353 refs := make([]interface{}, colsNum) 354 for i := range refs { 355 var ref interface{} 356 refs[i] = &ref 357 } 358 359 d.ins.ReplaceMarks(&query) 360 361 row := q.QueryRow(query, args...) 362 if err := row.Scan(refs...); err != nil { 363 if err == sql.ErrNoRows { 364 return ErrNoRows 365 } 366 return err 367 } 368 elm := reflect.New(mi.addrField.Elem().Type()) 369 mind := reflect.Indirect(elm) 370 d.setColsValues(mi, &mind, mi.fields.dbcols, refs, tz) 371 ind.Set(mind) 372 return nil 373 } 374 375 // execute insert sql dbQuerier with given struct reflect.Value. 376 func (d *dbBase) Insert(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location) (int64, error) { 377 names := make([]string, 0, len(mi.fields.dbcols)) 378 values, autoFields, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) 379 if err != nil { 380 return 0, err 381 } 382 383 id, err := d.InsertValue(q, mi, false, names, values) 384 if err != nil { 385 return 0, err 386 } 387 388 if len(autoFields) > 0 { 389 err = d.ins.setval(q, mi, autoFields) 390 } 391 return id, err 392 } 393 394 // multi-insert sql with given slice struct reflect.Value. 395 func (d *dbBase) InsertMulti(q dbQuerier, mi *modelInfo, sind reflect.Value, bulk int, tz *time.Location) (int64, error) { 396 var ( 397 cnt int64 398 nums int 399 values []interface{} 400 names []string 401 ) 402 403 // typ := reflect.Indirect(mi.addrField).Type() 404 405 length, autoFields := sind.Len(), make([]string, 0, 1) 406 407 for i := 1; i <= length; i++ { 408 409 ind := reflect.Indirect(sind.Index(i - 1)) 410 411 // Is this needed ? 412 // if !ind.Type().AssignableTo(typ) { 413 // return cnt, ErrArgs 414 // } 415 416 if i == 1 { 417 var ( 418 vus []interface{} 419 err error 420 ) 421 vus, autoFields, err = d.collectValues(mi, ind, mi.fields.dbcols, false, true, &names, tz) 422 if err != nil { 423 return cnt, err 424 } 425 values = make([]interface{}, bulk*len(vus)) 426 nums += copy(values, vus) 427 } else { 428 vus, _, err := d.collectValues(mi, ind, mi.fields.dbcols, false, true, nil, tz) 429 if err != nil { 430 return cnt, err 431 } 432 433 if len(vus) != len(names) { 434 return cnt, ErrArgs 435 } 436 437 nums += copy(values[nums:], vus) 438 } 439 440 if i > 1 && i%bulk == 0 || length == i { 441 num, err := d.InsertValue(q, mi, true, names, values[:nums]) 442 if err != nil { 443 return cnt, err 444 } 445 cnt += num 446 nums = 0 447 } 448 } 449 450 var err error 451 if len(autoFields) > 0 { 452 err = d.ins.setval(q, mi, autoFields) 453 } 454 455 return cnt, err 456 } 457 458 // execute insert sql with given struct and given values. 459 // insert the given values, not the field values in struct. 460 func (d *dbBase) InsertValue(q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) { 461 Q := d.ins.TableQuote() 462 463 marks := make([]string, len(names)) 464 for i := range marks { 465 marks[i] = "?" 466 } 467 468 sep := fmt.Sprintf("%s, %s", Q, Q) 469 qmarks := strings.Join(marks, ", ") 470 columns := strings.Join(names, sep) 471 472 multi := len(values) / len(names) 473 474 if isMulti && multi > 1 { 475 qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks 476 } 477 478 query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks) 479 480 d.ins.ReplaceMarks(&query) 481 482 if isMulti || !d.ins.HasReturningID(mi, &query) { 483 res, err := q.Exec(query, values...) 484 if err == nil { 485 if isMulti { 486 return res.RowsAffected() 487 } 488 return res.LastInsertId() 489 } 490 return 0, err 491 } 492 row := q.QueryRow(query, values...) 493 var id int64 494 err := row.Scan(&id) 495 return id, err 496 } 497 498 // InsertOrUpdate a row 499 // If your primary key or unique column conflict will update 500 // If no will insert 501 func (d *dbBase) InsertOrUpdate(q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) { 502 args0 := "" 503 iouStr := "" 504 argsMap := map[string]string{} 505 switch a.Driver { 506 case DRMySQL: 507 iouStr = "ON DUPLICATE KEY UPDATE" 508 case DRPostgres: 509 if len(args) == 0 { 510 return 0, fmt.Errorf("`%s` use InsertOrUpdate must have a conflict column", a.DriverName) 511 } 512 args0 = strings.ToLower(args[0]) 513 iouStr = fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", args0) 514 default: 515 return 0, fmt.Errorf("`%s` nonsupport InsertOrUpdate in beego", a.DriverName) 516 } 517 518 //Get on the key-value pairs 519 for _, v := range args { 520 kv := strings.Split(v, "=") 521 if len(kv) == 2 { 522 argsMap[strings.ToLower(kv[0])] = kv[1] 523 } 524 } 525 526 isMulti := false 527 names := make([]string, 0, len(mi.fields.dbcols)-1) 528 Q := d.ins.TableQuote() 529 values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ) 530 531 if err != nil { 532 return 0, err 533 } 534 535 marks := make([]string, len(names)) 536 updateValues := make([]interface{}, 0) 537 updates := make([]string, len(names)) 538 var conflitValue interface{} 539 for i, v := range names { 540 // identifier in database may not be case-sensitive, so quote it 541 v = fmt.Sprintf("%s%s%s", Q, v, Q) 542 marks[i] = "?" 543 valueStr := argsMap[strings.ToLower(v)] 544 if v == args0 { 545 conflitValue = values[i] 546 } 547 if valueStr != "" { 548 switch a.Driver { 549 case DRMySQL: 550 updates[i] = v + "=" + valueStr 551 case DRPostgres: 552 if conflitValue != nil { 553 //postgres ON CONFLICT DO UPDATE SET can`t use colu=colu+values 554 updates[i] = fmt.Sprintf("%s=(select %s from %s where %s = ? )", v, valueStr, mi.table, args0) 555 updateValues = append(updateValues, conflitValue) 556 } else { 557 return 0, fmt.Errorf("`%s` must be in front of `%s` in your struct", args0, v) 558 } 559 } 560 } else { 561 updates[i] = v + "=?" 562 updateValues = append(updateValues, values[i]) 563 } 564 } 565 566 values = append(values, updateValues...) 567 568 sep := fmt.Sprintf("%s, %s", Q, Q) 569 qmarks := strings.Join(marks, ", ") 570 qupdates := strings.Join(updates, ", ") 571 columns := strings.Join(names, sep) 572 573 multi := len(values) / len(names) 574 575 if isMulti { 576 qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks 577 } 578 //conflitValue maybe is a int,can`t use fmt.Sprintf 579 query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr) 580 581 d.ins.ReplaceMarks(&query) 582 583 if isMulti || !d.ins.HasReturningID(mi, &query) { 584 res, err := q.Exec(query, values...) 585 if err == nil { 586 if isMulti { 587 return res.RowsAffected() 588 } 589 return res.LastInsertId() 590 } 591 return 0, err 592 } 593 594 row := q.QueryRow(query, values...) 595 var id int64 596 err = row.Scan(&id) 597 if err != nil && err.Error() == `pq: syntax error at or near "ON"` { 598 err = fmt.Errorf("postgres version must 9.5 or higher") 599 } 600 return id, err 601 } 602 603 // execute update sql dbQuerier with given struct reflect.Value. 604 func (d *dbBase) Update(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { 605 pkName, pkValue, ok := getExistPk(mi, ind) 606 if !ok { 607 return 0, ErrMissPK 608 } 609 610 var setNames []string 611 612 // if specify cols length is zero, then commit all columns. 613 if len(cols) == 0 { 614 cols = mi.fields.dbcols 615 setNames = make([]string, 0, len(mi.fields.dbcols)-1) 616 } else { 617 setNames = make([]string, 0, len(cols)) 618 } 619 620 setValues, _, err := d.collectValues(mi, ind, cols, true, false, &setNames, tz) 621 if err != nil { 622 return 0, err 623 } 624 625 var findAutoNowAdd, findAutoNow bool 626 var index int 627 for i, col := range setNames { 628 if mi.fields.GetByColumn(col).autoNowAdd { 629 index = i 630 findAutoNowAdd = true 631 } 632 if mi.fields.GetByColumn(col).autoNow { 633 findAutoNow = true 634 } 635 } 636 if findAutoNowAdd { 637 setNames = append(setNames[0:index], setNames[index+1:]...) 638 setValues = append(setValues[0:index], setValues[index+1:]...) 639 } 640 641 if !findAutoNow { 642 for col, info := range mi.fields.columns { 643 if info.autoNow { 644 setNames = append(setNames, col) 645 setValues = append(setValues, time.Now()) 646 } 647 } 648 } 649 650 setValues = append(setValues, pkValue) 651 652 Q := d.ins.TableQuote() 653 654 sep := fmt.Sprintf("%s = ?, %s", Q, Q) 655 setColumns := strings.Join(setNames, sep) 656 657 query := fmt.Sprintf("UPDATE %s%s%s SET %s%s%s = ? WHERE %s%s%s = ?", Q, mi.table, Q, Q, setColumns, Q, Q, pkName, Q) 658 659 d.ins.ReplaceMarks(&query) 660 661 res, err := q.Exec(query, setValues...) 662 if err == nil { 663 return res.RowsAffected() 664 } 665 return 0, err 666 } 667 668 // execute delete sql dbQuerier with given struct reflect.Value. 669 // delete index is pk. 670 func (d *dbBase) Delete(q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string) (int64, error) { 671 var whereCols []string 672 var args []interface{} 673 // if specify cols length > 0, then use it for where condition. 674 if len(cols) > 0 { 675 var err error 676 whereCols = make([]string, 0, len(cols)) 677 args, _, err = d.collectValues(mi, ind, cols, false, false, &whereCols, tz) 678 if err != nil { 679 return 0, err 680 } 681 } else { 682 // default use pk value as where condtion. 683 pkColumn, pkValue, ok := getExistPk(mi, ind) 684 if !ok { 685 return 0, ErrMissPK 686 } 687 whereCols = []string{pkColumn} 688 args = append(args, pkValue) 689 } 690 691 Q := d.ins.TableQuote() 692 693 sep := fmt.Sprintf("%s = ? AND %s", Q, Q) 694 wheres := strings.Join(whereCols, sep) 695 696 query := fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s = ?", Q, mi.table, Q, Q, wheres, Q) 697 698 d.ins.ReplaceMarks(&query) 699 res, err := q.Exec(query, args...) 700 if err == nil { 701 num, err := res.RowsAffected() 702 if err != nil { 703 return 0, err 704 } 705 if num > 0 { 706 if mi.fields.pk.auto { 707 if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 { 708 ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(0) 709 } else { 710 ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(0) 711 } 712 } 713 err := d.deleteRels(q, mi, args, tz) 714 if err != nil { 715 return num, err 716 } 717 } 718 return num, err 719 } 720 return 0, err 721 } 722 723 // update table-related record by querySet. 724 // need querySet not struct reflect.Value to update related records. 725 func (d *dbBase) UpdateBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, params Params, tz *time.Location) (int64, error) { 726 columns := make([]string, 0, len(params)) 727 values := make([]interface{}, 0, len(params)) 728 for col, val := range params { 729 if fi, ok := mi.fields.GetByAny(col); !ok || !fi.dbcol { 730 panic(fmt.Errorf("wrong field/column name `%s`", col)) 731 } else { 732 columns = append(columns, fi.column) 733 values = append(values, val) 734 } 735 } 736 737 if len(columns) == 0 { 738 panic(fmt.Errorf("update params cannot empty")) 739 } 740 741 tables := newDbTables(mi, d.ins) 742 if qs != nil { 743 tables.parseRelated(qs.related, qs.relDepth) 744 } 745 746 where, args := tables.getCondSQL(cond, false, tz) 747 748 values = append(values, args...) 749 750 join := tables.getJoinSQL() 751 752 var query, T string 753 754 Q := d.ins.TableQuote() 755 756 if d.ins.SupportUpdateJoin() { 757 T = "T0." 758 } 759 760 cols := make([]string, 0, len(columns)) 761 762 for i, v := range columns { 763 col := fmt.Sprintf("%s%s%s%s", T, Q, v, Q) 764 if c, ok := values[i].(colValue); ok { 765 switch c.opt { 766 case ColAdd: 767 cols = append(cols, col+" = "+col+" + ?") 768 case ColMinus: 769 cols = append(cols, col+" = "+col+" - ?") 770 case ColMultiply: 771 cols = append(cols, col+" = "+col+" * ?") 772 case ColExcept: 773 cols = append(cols, col+" = "+col+" / ?") 774 case ColBitAnd: 775 cols = append(cols, col+" = "+col+" & ?") 776 case ColBitRShift: 777 cols = append(cols, col+" = "+col+" >> ?") 778 case ColBitLShift: 779 cols = append(cols, col+" = "+col+" << ?") 780 case ColBitXOR: 781 cols = append(cols, col+" = "+col+" ^ ?") 782 case ColBitOr: 783 cols = append(cols, col+" = "+col+" | ?") 784 } 785 values[i] = c.value 786 } else { 787 cols = append(cols, col+" = ?") 788 } 789 } 790 791 sets := strings.Join(cols, ", ") + " " 792 793 if d.ins.SupportUpdateJoin() { 794 query = fmt.Sprintf("UPDATE %s%s%s T0 %sSET %s%s", Q, mi.table, Q, join, sets, where) 795 } else { 796 supQuery := fmt.Sprintf("SELECT T0.%s%s%s FROM %s%s%s T0 %s%s", Q, mi.fields.pk.column, Q, Q, mi.table, Q, join, where) 797 query = fmt.Sprintf("UPDATE %s%s%s SET %sWHERE %s%s%s IN ( %s )", Q, mi.table, Q, sets, Q, mi.fields.pk.column, Q, supQuery) 798 } 799 800 d.ins.ReplaceMarks(&query) 801 var err error 802 var res sql.Result 803 if qs != nil && qs.forContext { 804 res, err = q.ExecContext(qs.ctx, query, values...) 805 } else { 806 res, err = q.Exec(query, values...) 807 } 808 if err == nil { 809 return res.RowsAffected() 810 } 811 return 0, err 812 } 813 814 // delete related records. 815 // do UpdateBanch or DeleteBanch by condition of tables' relationship. 816 func (d *dbBase) deleteRels(q dbQuerier, mi *modelInfo, args []interface{}, tz *time.Location) error { 817 for _, fi := range mi.fields.fieldsReverse { 818 fi = fi.reverseFieldInfo 819 switch fi.onDelete { 820 case odCascade: 821 cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) 822 _, err := d.DeleteBatch(q, nil, fi.mi, cond, tz) 823 if err != nil { 824 return err 825 } 826 case odSetDefault, odSetNULL: 827 cond := NewCondition().And(fmt.Sprintf("%s__in", fi.name), args...) 828 params := Params{fi.column: nil} 829 if fi.onDelete == odSetDefault { 830 params[fi.column] = fi.initial.String() 831 } 832 _, err := d.UpdateBatch(q, nil, fi.mi, cond, params, tz) 833 if err != nil { 834 return err 835 } 836 case odDoNothing: 837 } 838 } 839 return nil 840 } 841 842 // delete table-related records. 843 func (d *dbBase) DeleteBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (int64, error) { 844 tables := newDbTables(mi, d.ins) 845 tables.skipEnd = true 846 847 if qs != nil { 848 tables.parseRelated(qs.related, qs.relDepth) 849 } 850 851 if cond == nil || cond.IsEmpty() { 852 panic(fmt.Errorf("delete operation cannot execute without condition")) 853 } 854 855 Q := d.ins.TableQuote() 856 857 where, args := tables.getCondSQL(cond, false, tz) 858 join := tables.getJoinSQL() 859 860 cols := fmt.Sprintf("T0.%s%s%s", Q, mi.fields.pk.column, Q) 861 query := fmt.Sprintf("SELECT %s FROM %s%s%s T0 %s%s", cols, Q, mi.table, Q, join, where) 862 863 d.ins.ReplaceMarks(&query) 864 865 var rs *sql.Rows 866 r, err := q.Query(query, args...) 867 if err != nil { 868 return 0, err 869 } 870 rs = r 871 defer rs.Close() 872 873 var ref interface{} 874 args = make([]interface{}, 0) 875 cnt := 0 876 for rs.Next() { 877 if err := rs.Scan(&ref); err != nil { 878 return 0, err 879 } 880 pkValue, err := d.convertValueFromDB(mi.fields.pk, reflect.ValueOf(ref).Interface(), tz) 881 if err != nil { 882 return 0, err 883 } 884 args = append(args, pkValue) 885 cnt++ 886 } 887 888 if cnt == 0 { 889 return 0, nil 890 } 891 892 marks := make([]string, len(args)) 893 for i := range marks { 894 marks[i] = "?" 895 } 896 sqlIn := fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) 897 query = fmt.Sprintf("DELETE FROM %s%s%s WHERE %s%s%s %s", Q, mi.table, Q, Q, mi.fields.pk.column, Q, sqlIn) 898 899 d.ins.ReplaceMarks(&query) 900 var res sql.Result 901 if qs != nil && qs.forContext { 902 res, err = q.ExecContext(qs.ctx, query, args...) 903 } else { 904 res, err = q.Exec(query, args...) 905 } 906 if err == nil { 907 num, err := res.RowsAffected() 908 if err != nil { 909 return 0, err 910 } 911 if num > 0 { 912 err := d.deleteRels(q, mi, args, tz) 913 if err != nil { 914 return num, err 915 } 916 } 917 return num, nil 918 } 919 return 0, err 920 } 921 922 // read related records. 923 func (d *dbBase) ReadBatch(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, container interface{}, tz *time.Location, cols []string) (int64, error) { 924 925 val := reflect.ValueOf(container) 926 ind := reflect.Indirect(val) 927 928 errTyp := true 929 one := true 930 isPtr := true 931 932 if val.Kind() == reflect.Ptr { 933 fn := "" 934 if ind.Kind() == reflect.Slice { 935 one = false 936 typ := ind.Type().Elem() 937 switch typ.Kind() { 938 case reflect.Ptr: 939 fn = getFullName(typ.Elem()) 940 case reflect.Struct: 941 isPtr = false 942 fn = getFullName(typ) 943 } 944 } else { 945 fn = getFullName(ind.Type()) 946 } 947 errTyp = fn != mi.fullName 948 } 949 950 if errTyp { 951 if one { 952 panic(fmt.Errorf("wrong object type `%s` for rows scan, need *%s", val.Type(), mi.fullName)) 953 } else { 954 panic(fmt.Errorf("wrong object type `%s` for rows scan, need *[]*%s or *[]%s", val.Type(), mi.fullName, mi.fullName)) 955 } 956 } 957 958 rlimit := qs.limit 959 offset := qs.offset 960 961 Q := d.ins.TableQuote() 962 963 var tCols []string 964 if len(cols) > 0 { 965 hasRel := len(qs.related) > 0 || qs.relDepth > 0 966 tCols = make([]string, 0, len(cols)) 967 var maps map[string]bool 968 if hasRel { 969 maps = make(map[string]bool) 970 } 971 for _, col := range cols { 972 if fi, ok := mi.fields.GetByAny(col); ok { 973 tCols = append(tCols, fi.column) 974 if hasRel { 975 maps[fi.column] = true 976 } 977 } else { 978 return 0, fmt.Errorf("wrong field/column name `%s`", col) 979 } 980 } 981 if hasRel { 982 for _, fi := range mi.fields.fieldsDB { 983 if fi.fieldType&IsRelField > 0 { 984 if !maps[fi.column] { 985 tCols = append(tCols, fi.column) 986 } 987 } 988 } 989 } 990 } else { 991 tCols = mi.fields.dbcols 992 } 993 994 colsNum := len(tCols) 995 sep := fmt.Sprintf("%s, T0.%s", Q, Q) 996 sels := fmt.Sprintf("T0.%s%s%s", Q, strings.Join(tCols, sep), Q) 997 998 tables := newDbTables(mi, d.ins) 999 tables.parseRelated(qs.related, qs.relDepth) 1000 1001 where, args := tables.getCondSQL(cond, false, tz) 1002 groupBy := tables.getGroupSQL(qs.groups) 1003 orderBy := tables.getOrderSQL(qs.orders) 1004 limit := tables.getLimitSQL(mi, offset, rlimit) 1005 join := tables.getJoinSQL() 1006 1007 for _, tbl := range tables.tables { 1008 if tbl.sel { 1009 colsNum += len(tbl.mi.fields.dbcols) 1010 sep := fmt.Sprintf("%s, %s.%s", Q, tbl.index, Q) 1011 sels += fmt.Sprintf(", %s.%s%s%s", tbl.index, Q, strings.Join(tbl.mi.fields.dbcols, sep), Q) 1012 } 1013 } 1014 1015 sqlSelect := "SELECT" 1016 if qs.distinct { 1017 sqlSelect += " DISTINCT" 1018 } 1019 query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) 1020 1021 if qs.forupdate { 1022 query += " FOR UPDATE" 1023 } 1024 1025 d.ins.ReplaceMarks(&query) 1026 1027 var rs *sql.Rows 1028 var err error 1029 if qs != nil && qs.forContext { 1030 rs, err = q.QueryContext(qs.ctx, query, args...) 1031 if err != nil { 1032 return 0, err 1033 } 1034 } else { 1035 rs, err = q.Query(query, args...) 1036 if err != nil { 1037 return 0, err 1038 } 1039 } 1040 1041 refs := make([]interface{}, colsNum) 1042 for i := range refs { 1043 var ref interface{} 1044 refs[i] = &ref 1045 } 1046 1047 defer rs.Close() 1048 1049 slice := ind 1050 1051 var cnt int64 1052 for rs.Next() { 1053 if one && cnt == 0 || !one { 1054 if err := rs.Scan(refs...); err != nil { 1055 return 0, err 1056 } 1057 1058 elm := reflect.New(mi.addrField.Elem().Type()) 1059 mind := reflect.Indirect(elm) 1060 1061 cacheV := make(map[string]*reflect.Value) 1062 cacheM := make(map[string]*modelInfo) 1063 trefs := refs 1064 1065 d.setColsValues(mi, &mind, tCols, refs[:len(tCols)], tz) 1066 trefs = refs[len(tCols):] 1067 1068 for _, tbl := range tables.tables { 1069 // loop selected tables 1070 if tbl.sel { 1071 last := mind 1072 names := "" 1073 mmi := mi 1074 // loop cascade models 1075 for _, name := range tbl.names { 1076 names += name 1077 if val, ok := cacheV[names]; ok { 1078 last = *val 1079 mmi = cacheM[names] 1080 } else { 1081 fi := mmi.fields.GetByName(name) 1082 lastm := mmi 1083 mmi = fi.relModelInfo 1084 field := last 1085 if last.Kind() != reflect.Invalid { 1086 field = reflect.Indirect(last.FieldByIndex(fi.fieldIndex)) 1087 if field.IsValid() { 1088 d.setColsValues(mmi, &field, mmi.fields.dbcols, trefs[:len(mmi.fields.dbcols)], tz) 1089 for _, fi := range mmi.fields.fieldsReverse { 1090 if fi.inModel && fi.reverseFieldInfo.mi == lastm { 1091 if fi.reverseFieldInfo != nil { 1092 f := field.FieldByIndex(fi.fieldIndex) 1093 if f.Kind() == reflect.Ptr { 1094 f.Set(last.Addr()) 1095 } 1096 } 1097 } 1098 } 1099 last = field 1100 } 1101 } 1102 cacheV[names] = &field 1103 cacheM[names] = mmi 1104 } 1105 } 1106 trefs = trefs[len(mmi.fields.dbcols):] 1107 } 1108 } 1109 1110 if one { 1111 ind.Set(mind) 1112 } else { 1113 if cnt == 0 { 1114 // you can use a empty & caped container list 1115 // orm will not replace it 1116 if ind.Len() != 0 { 1117 // if container is not empty 1118 // create a new one 1119 slice = reflect.New(ind.Type()).Elem() 1120 } 1121 } 1122 1123 if isPtr { 1124 slice = reflect.Append(slice, mind.Addr()) 1125 } else { 1126 slice = reflect.Append(slice, mind) 1127 } 1128 } 1129 } 1130 cnt++ 1131 } 1132 1133 if !one { 1134 if cnt > 0 { 1135 ind.Set(slice) 1136 } else { 1137 // when a result is empty and container is nil 1138 // to set a empty container 1139 if ind.IsNil() { 1140 ind.Set(reflect.MakeSlice(ind.Type(), 0, 0)) 1141 } 1142 } 1143 } 1144 1145 return cnt, nil 1146 } 1147 1148 // excute count sql and return count result int64. 1149 func (d *dbBase) Count(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, tz *time.Location) (cnt int64, err error) { 1150 tables := newDbTables(mi, d.ins) 1151 tables.parseRelated(qs.related, qs.relDepth) 1152 1153 where, args := tables.getCondSQL(cond, false, tz) 1154 groupBy := tables.getGroupSQL(qs.groups) 1155 tables.getOrderSQL(qs.orders) 1156 join := tables.getJoinSQL() 1157 1158 Q := d.ins.TableQuote() 1159 1160 query := fmt.Sprintf("SELECT COUNT(*) FROM %s%s%s T0 %s%s%s", Q, mi.table, Q, join, where, groupBy) 1161 1162 if groupBy != "" { 1163 query = fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS T", query) 1164 } 1165 1166 d.ins.ReplaceMarks(&query) 1167 1168 var row *sql.Row 1169 if qs != nil && qs.forContext { 1170 row = q.QueryRowContext(qs.ctx, query, args...) 1171 } else { 1172 row = q.QueryRow(query, args...) 1173 } 1174 err = row.Scan(&cnt) 1175 return 1176 } 1177 1178 // generate sql with replacing operator string placeholders and replaced values. 1179 func (d *dbBase) GenerateOperatorSQL(mi *modelInfo, fi *fieldInfo, operator string, args []interface{}, tz *time.Location) (string, []interface{}) { 1180 var sql string 1181 params := getFlatParams(fi, args, tz) 1182 1183 if len(params) == 0 { 1184 panic(fmt.Errorf("operator `%s` need at least one args", operator)) 1185 } 1186 arg := params[0] 1187 1188 switch operator { 1189 case "in": 1190 marks := make([]string, len(params)) 1191 for i := range marks { 1192 marks[i] = "?" 1193 } 1194 sql = fmt.Sprintf("IN (%s)", strings.Join(marks, ", ")) 1195 case "between": 1196 if len(params) != 2 { 1197 panic(fmt.Errorf("operator `%s` need 2 args not %d", operator, len(params))) 1198 } 1199 sql = "BETWEEN ? AND ?" 1200 default: 1201 if len(params) > 1 { 1202 panic(fmt.Errorf("operator `%s` need 1 args not %d", operator, len(params))) 1203 } 1204 sql = d.ins.OperatorSQL(operator) 1205 switch operator { 1206 case "exact", "strictexact": 1207 if arg == nil { 1208 params[0] = "IS NULL" 1209 } 1210 case "iexact", "contains", "icontains", "startswith", "endswith", "istartswith", "iendswith": 1211 param := strings.Replace(ToStr(arg), `%`, `\%`, -1) 1212 switch operator { 1213 case "iexact": 1214 case "contains", "icontains": 1215 param = fmt.Sprintf("%%%s%%", param) 1216 case "startswith", "istartswith": 1217 param = fmt.Sprintf("%s%%", param) 1218 case "endswith", "iendswith": 1219 param = fmt.Sprintf("%%%s", param) 1220 } 1221 params[0] = param 1222 case "isnull": 1223 if b, ok := arg.(bool); ok { 1224 if b { 1225 sql = "IS NULL" 1226 } else { 1227 sql = "IS NOT NULL" 1228 } 1229 params = nil 1230 } else { 1231 panic(fmt.Errorf("operator `%s` need a bool value not `%T`", operator, arg)) 1232 } 1233 } 1234 } 1235 return sql, params 1236 } 1237 1238 // gernerate sql string with inner function, such as UPPER(text). 1239 func (d *dbBase) GenerateOperatorLeftCol(*fieldInfo, string, *string) { 1240 // default not use 1241 } 1242 1243 // set values to struct column. 1244 func (d *dbBase) setColsValues(mi *modelInfo, ind *reflect.Value, cols []string, values []interface{}, tz *time.Location) { 1245 for i, column := range cols { 1246 val := reflect.Indirect(reflect.ValueOf(values[i])).Interface() 1247 1248 fi := mi.fields.GetByColumn(column) 1249 1250 field := ind.FieldByIndex(fi.fieldIndex) 1251 1252 value, err := d.convertValueFromDB(fi, val, tz) 1253 if err != nil { 1254 panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) 1255 } 1256 1257 _, err = d.setFieldValue(fi, value, field) 1258 1259 if err != nil { 1260 panic(fmt.Errorf("Raw value: `%v` %s", val, err.Error())) 1261 } 1262 } 1263 } 1264 1265 // convert value from database result to value following in field type. 1266 func (d *dbBase) convertValueFromDB(fi *fieldInfo, val interface{}, tz *time.Location) (interface{}, error) { 1267 if val == nil { 1268 return nil, nil 1269 } 1270 1271 var value interface{} 1272 var tErr error 1273 1274 var str *StrTo 1275 switch v := val.(type) { 1276 case []byte: 1277 s := StrTo(string(v)) 1278 str = &s 1279 case string: 1280 s := StrTo(v) 1281 str = &s 1282 } 1283 1284 fieldType := fi.fieldType 1285 1286 setValue: 1287 switch { 1288 case fieldType == TypeBooleanField: 1289 if str == nil { 1290 switch v := val.(type) { 1291 case int64: 1292 b := v == 1 1293 value = b 1294 default: 1295 s := StrTo(ToStr(v)) 1296 str = &s 1297 } 1298 } 1299 if str != nil { 1300 b, err := str.Bool() 1301 if err != nil { 1302 tErr = err 1303 goto end 1304 } 1305 value = b 1306 } 1307 case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: 1308 if str == nil { 1309 value = ToStr(val) 1310 } else { 1311 value = str.String() 1312 } 1313 case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: 1314 if str == nil { 1315 switch t := val.(type) { 1316 case time.Time: 1317 d.ins.TimeFromDB(&t, tz) 1318 value = t 1319 default: 1320 s := StrTo(ToStr(t)) 1321 str = &s 1322 } 1323 } 1324 if str != nil { 1325 s := str.String() 1326 var ( 1327 t time.Time 1328 err error 1329 ) 1330 if len(s) >= 19 { 1331 s = s[:19] 1332 t, err = time.ParseInLocation(formatDateTime, s, tz) 1333 } else if len(s) >= 10 { 1334 if len(s) > 10 { 1335 s = s[:10] 1336 } 1337 t, err = time.ParseInLocation(formatDate, s, tz) 1338 } else if len(s) >= 8 { 1339 if len(s) > 8 { 1340 s = s[:8] 1341 } 1342 t, err = time.ParseInLocation(formatTime, s, tz) 1343 } 1344 t = t.In(DefaultTimeLoc) 1345 1346 if err != nil && s != "00:00:00" && s != "0000-00-00" && s != "0000-00-00 00:00:00" { 1347 tErr = err 1348 goto end 1349 } 1350 value = t 1351 } 1352 case fieldType&IsIntegerField > 0: 1353 if str == nil { 1354 s := StrTo(ToStr(val)) 1355 str = &s 1356 } 1357 if str != nil { 1358 var err error 1359 switch fieldType { 1360 case TypeBitField: 1361 _, err = str.Int8() 1362 case TypeSmallIntegerField: 1363 _, err = str.Int16() 1364 case TypeIntegerField: 1365 _, err = str.Int32() 1366 case TypeBigIntegerField: 1367 _, err = str.Int64() 1368 case TypePositiveBitField: 1369 _, err = str.Uint8() 1370 case TypePositiveSmallIntegerField: 1371 _, err = str.Uint16() 1372 case TypePositiveIntegerField: 1373 _, err = str.Uint32() 1374 case TypePositiveBigIntegerField: 1375 _, err = str.Uint64() 1376 } 1377 if err != nil { 1378 tErr = err 1379 goto end 1380 } 1381 if fieldType&IsPositiveIntegerField > 0 { 1382 v, _ := str.Uint64() 1383 value = v 1384 } else { 1385 v, _ := str.Int64() 1386 value = v 1387 } 1388 } 1389 case fieldType == TypeFloatField || fieldType == TypeDecimalField: 1390 if str == nil { 1391 switch v := val.(type) { 1392 case float64: 1393 value = v 1394 default: 1395 s := StrTo(ToStr(v)) 1396 str = &s 1397 } 1398 } 1399 if str != nil { 1400 v, err := str.Float64() 1401 if err != nil { 1402 tErr = err 1403 goto end 1404 } 1405 value = v 1406 } 1407 case fieldType&IsRelField > 0: 1408 fi = fi.relModelInfo.fields.pk 1409 fieldType = fi.fieldType 1410 goto setValue 1411 } 1412 1413 end: 1414 if tErr != nil { 1415 err := fmt.Errorf("convert to `%s` failed, field: %s err: %s", fi.addrValue.Type(), fi.fullName, tErr) 1416 return nil, err 1417 } 1418 1419 return value, nil 1420 1421 } 1422 1423 // set one value to struct column field. 1424 func (d *dbBase) setFieldValue(fi *fieldInfo, value interface{}, field reflect.Value) (interface{}, error) { 1425 1426 fieldType := fi.fieldType 1427 isNative := !fi.isFielder 1428 1429 setValue: 1430 switch { 1431 case fieldType == TypeBooleanField: 1432 if isNative { 1433 if nb, ok := field.Interface().(sql.NullBool); ok { 1434 if value == nil { 1435 nb.Valid = false 1436 } else { 1437 nb.Bool = value.(bool) 1438 nb.Valid = true 1439 } 1440 field.Set(reflect.ValueOf(nb)) 1441 } else if field.Kind() == reflect.Ptr { 1442 if value != nil { 1443 v := value.(bool) 1444 field.Set(reflect.ValueOf(&v)) 1445 } 1446 } else { 1447 if value == nil { 1448 value = false 1449 } 1450 field.SetBool(value.(bool)) 1451 } 1452 } 1453 case fieldType == TypeVarCharField || fieldType == TypeCharField || fieldType == TypeTextField || fieldType == TypeJSONField || fieldType == TypeJsonbField: 1454 if isNative { 1455 if ns, ok := field.Interface().(sql.NullString); ok { 1456 if value == nil { 1457 ns.Valid = false 1458 } else { 1459 ns.String = value.(string) 1460 ns.Valid = true 1461 } 1462 field.Set(reflect.ValueOf(ns)) 1463 } else if field.Kind() == reflect.Ptr { 1464 if value != nil { 1465 v := value.(string) 1466 field.Set(reflect.ValueOf(&v)) 1467 } 1468 } else { 1469 if value == nil { 1470 value = "" 1471 } 1472 field.SetString(value.(string)) 1473 } 1474 } 1475 case fieldType == TypeTimeField || fieldType == TypeDateField || fieldType == TypeDateTimeField: 1476 if isNative { 1477 if value == nil { 1478 value = time.Time{} 1479 } else if field.Kind() == reflect.Ptr { 1480 if value != nil { 1481 v := value.(time.Time) 1482 field.Set(reflect.ValueOf(&v)) 1483 } 1484 } else { 1485 field.Set(reflect.ValueOf(value)) 1486 } 1487 } 1488 case fieldType == TypePositiveBitField && field.Kind() == reflect.Ptr: 1489 if value != nil { 1490 v := uint8(value.(uint64)) 1491 field.Set(reflect.ValueOf(&v)) 1492 } 1493 case fieldType == TypePositiveSmallIntegerField && field.Kind() == reflect.Ptr: 1494 if value != nil { 1495 v := uint16(value.(uint64)) 1496 field.Set(reflect.ValueOf(&v)) 1497 } 1498 case fieldType == TypePositiveIntegerField && field.Kind() == reflect.Ptr: 1499 if value != nil { 1500 if field.Type() == reflect.TypeOf(new(uint)) { 1501 v := uint(value.(uint64)) 1502 field.Set(reflect.ValueOf(&v)) 1503 } else { 1504 v := uint32(value.(uint64)) 1505 field.Set(reflect.ValueOf(&v)) 1506 } 1507 } 1508 case fieldType == TypePositiveBigIntegerField && field.Kind() == reflect.Ptr: 1509 if value != nil { 1510 v := value.(uint64) 1511 field.Set(reflect.ValueOf(&v)) 1512 } 1513 case fieldType == TypeBitField && field.Kind() == reflect.Ptr: 1514 if value != nil { 1515 v := int8(value.(int64)) 1516 field.Set(reflect.ValueOf(&v)) 1517 } 1518 case fieldType == TypeSmallIntegerField && field.Kind() == reflect.Ptr: 1519 if value != nil { 1520 v := int16(value.(int64)) 1521 field.Set(reflect.ValueOf(&v)) 1522 } 1523 case fieldType == TypeIntegerField && field.Kind() == reflect.Ptr: 1524 if value != nil { 1525 if field.Type() == reflect.TypeOf(new(int)) { 1526 v := int(value.(int64)) 1527 field.Set(reflect.ValueOf(&v)) 1528 } else { 1529 v := int32(value.(int64)) 1530 field.Set(reflect.ValueOf(&v)) 1531 } 1532 } 1533 case fieldType == TypeBigIntegerField && field.Kind() == reflect.Ptr: 1534 if value != nil { 1535 v := value.(int64) 1536 field.Set(reflect.ValueOf(&v)) 1537 } 1538 case fieldType&IsIntegerField > 0: 1539 if fieldType&IsPositiveIntegerField > 0 { 1540 if isNative { 1541 if value == nil { 1542 value = uint64(0) 1543 } 1544 field.SetUint(value.(uint64)) 1545 } 1546 } else { 1547 if isNative { 1548 if ni, ok := field.Interface().(sql.NullInt64); ok { 1549 if value == nil { 1550 ni.Valid = false 1551 } else { 1552 ni.Int64 = value.(int64) 1553 ni.Valid = true 1554 } 1555 field.Set(reflect.ValueOf(ni)) 1556 } else { 1557 if value == nil { 1558 value = int64(0) 1559 } 1560 field.SetInt(value.(int64)) 1561 } 1562 } 1563 } 1564 case fieldType == TypeFloatField || fieldType == TypeDecimalField: 1565 if isNative { 1566 if nf, ok := field.Interface().(sql.NullFloat64); ok { 1567 if value == nil { 1568 nf.Valid = false 1569 } else { 1570 nf.Float64 = value.(float64) 1571 nf.Valid = true 1572 } 1573 field.Set(reflect.ValueOf(nf)) 1574 } else if field.Kind() == reflect.Ptr { 1575 if value != nil { 1576 if field.Type() == reflect.TypeOf(new(float32)) { 1577 v := float32(value.(float64)) 1578 field.Set(reflect.ValueOf(&v)) 1579 } else { 1580 v := value.(float64) 1581 field.Set(reflect.ValueOf(&v)) 1582 } 1583 } 1584 } else { 1585 1586 if value == nil { 1587 value = float64(0) 1588 } 1589 field.SetFloat(value.(float64)) 1590 } 1591 } 1592 case fieldType&IsRelField > 0: 1593 if value != nil { 1594 fieldType = fi.relModelInfo.fields.pk.fieldType 1595 mf := reflect.New(fi.relModelInfo.addrField.Elem().Type()) 1596 field.Set(mf) 1597 f := mf.Elem().FieldByIndex(fi.relModelInfo.fields.pk.fieldIndex) 1598 field = f 1599 goto setValue 1600 } 1601 } 1602 1603 if !isNative { 1604 fd := field.Addr().Interface().(Fielder) 1605 err := fd.SetRaw(value) 1606 if err != nil { 1607 err = fmt.Errorf("converted value `%v` set to Fielder `%s` failed, err: %s", value, fi.fullName, err) 1608 return nil, err 1609 } 1610 } 1611 1612 return value, nil 1613 } 1614 1615 // query sql, read values , save to *[]ParamList. 1616 func (d *dbBase) ReadValues(q dbQuerier, qs *querySet, mi *modelInfo, cond *Condition, exprs []string, container interface{}, tz *time.Location) (int64, error) { 1617 1618 var ( 1619 maps []Params 1620 lists []ParamsList 1621 list ParamsList 1622 ) 1623 1624 typ := 0 1625 switch v := container.(type) { 1626 case *[]Params: 1627 d := *v 1628 if len(d) == 0 { 1629 maps = d 1630 } 1631 typ = 1 1632 case *[]ParamsList: 1633 d := *v 1634 if len(d) == 0 { 1635 lists = d 1636 } 1637 typ = 2 1638 case *ParamsList: 1639 d := *v 1640 if len(d) == 0 { 1641 list = d 1642 } 1643 typ = 3 1644 default: 1645 panic(fmt.Errorf("unsupport read values type `%T`", container)) 1646 } 1647 1648 tables := newDbTables(mi, d.ins) 1649 1650 var ( 1651 cols []string 1652 infos []*fieldInfo 1653 ) 1654 1655 hasExprs := len(exprs) > 0 1656 1657 Q := d.ins.TableQuote() 1658 1659 if hasExprs { 1660 cols = make([]string, 0, len(exprs)) 1661 infos = make([]*fieldInfo, 0, len(exprs)) 1662 for _, ex := range exprs { 1663 index, name, fi, suc := tables.parseExprs(mi, strings.Split(ex, ExprSep)) 1664 if !suc { 1665 panic(fmt.Errorf("unknown field/column name `%s`", ex)) 1666 } 1667 cols = append(cols, fmt.Sprintf("%s.%s%s%s %s%s%s", index, Q, fi.column, Q, Q, name, Q)) 1668 infos = append(infos, fi) 1669 } 1670 } else { 1671 cols = make([]string, 0, len(mi.fields.dbcols)) 1672 infos = make([]*fieldInfo, 0, len(exprs)) 1673 for _, fi := range mi.fields.fieldsDB { 1674 cols = append(cols, fmt.Sprintf("T0.%s%s%s %s%s%s", Q, fi.column, Q, Q, fi.name, Q)) 1675 infos = append(infos, fi) 1676 } 1677 } 1678 1679 where, args := tables.getCondSQL(cond, false, tz) 1680 groupBy := tables.getGroupSQL(qs.groups) 1681 orderBy := tables.getOrderSQL(qs.orders) 1682 limit := tables.getLimitSQL(mi, qs.offset, qs.limit) 1683 join := tables.getJoinSQL() 1684 1685 sels := strings.Join(cols, ", ") 1686 1687 sqlSelect := "SELECT" 1688 if qs.distinct { 1689 sqlSelect += " DISTINCT" 1690 } 1691 query := fmt.Sprintf("%s %s FROM %s%s%s T0 %s%s%s%s%s", sqlSelect, sels, Q, mi.table, Q, join, where, groupBy, orderBy, limit) 1692 1693 d.ins.ReplaceMarks(&query) 1694 1695 rs, err := q.Query(query, args...) 1696 if err != nil { 1697 return 0, err 1698 } 1699 refs := make([]interface{}, len(cols)) 1700 for i := range refs { 1701 var ref interface{} 1702 refs[i] = &ref 1703 } 1704 1705 defer rs.Close() 1706 1707 var ( 1708 cnt int64 1709 columns []string 1710 ) 1711 for rs.Next() { 1712 if cnt == 0 { 1713 cols, err := rs.Columns() 1714 if err != nil { 1715 return 0, err 1716 } 1717 columns = cols 1718 } 1719 1720 if err := rs.Scan(refs...); err != nil { 1721 return 0, err 1722 } 1723 1724 switch typ { 1725 case 1: 1726 params := make(Params, len(cols)) 1727 for i, ref := range refs { 1728 fi := infos[i] 1729 1730 val := reflect.Indirect(reflect.ValueOf(ref)).Interface() 1731 1732 value, err := d.convertValueFromDB(fi, val, tz) 1733 if err != nil { 1734 panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) 1735 } 1736 1737 params[columns[i]] = value 1738 } 1739 maps = append(maps, params) 1740 case 2: 1741 params := make(ParamsList, 0, len(cols)) 1742 for i, ref := range refs { 1743 fi := infos[i] 1744 1745 val := reflect.Indirect(reflect.ValueOf(ref)).Interface() 1746 1747 value, err := d.convertValueFromDB(fi, val, tz) 1748 if err != nil { 1749 panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) 1750 } 1751 1752 params = append(params, value) 1753 } 1754 lists = append(lists, params) 1755 case 3: 1756 for i, ref := range refs { 1757 fi := infos[i] 1758 1759 val := reflect.Indirect(reflect.ValueOf(ref)).Interface() 1760 1761 value, err := d.convertValueFromDB(fi, val, tz) 1762 if err != nil { 1763 panic(fmt.Errorf("db value convert failed `%v` %s", val, err.Error())) 1764 } 1765 1766 list = append(list, value) 1767 } 1768 } 1769 1770 cnt++ 1771 } 1772 1773 switch v := container.(type) { 1774 case *[]Params: 1775 *v = maps 1776 case *[]ParamsList: 1777 *v = lists 1778 case *ParamsList: 1779 *v = list 1780 } 1781 1782 return cnt, nil 1783 } 1784 1785 func (d *dbBase) RowsTo(dbQuerier, *querySet, *modelInfo, *Condition, interface{}, string, string, *time.Location) (int64, error) { 1786 return 0, nil 1787 } 1788 1789 // flag of update joined record. 1790 func (d *dbBase) SupportUpdateJoin() bool { 1791 return true 1792 } 1793 1794 func (d *dbBase) MaxLimit() uint64 { 1795 return 18446744073709551615 1796 } 1797 1798 // return quote. 1799 func (d *dbBase) TableQuote() string { 1800 return "`" 1801 } 1802 1803 // replace value placeholder in parametered sql string. 1804 func (d *dbBase) ReplaceMarks(query *string) { 1805 // default use `?` as mark, do nothing 1806 } 1807 1808 // flag of RETURNING sql. 1809 func (d *dbBase) HasReturningID(*modelInfo, *string) bool { 1810 return false 1811 } 1812 1813 // sync auto key 1814 func (d *dbBase) setval(db dbQuerier, mi *modelInfo, autoFields []string) error { 1815 return nil 1816 } 1817 1818 // convert time from db. 1819 func (d *dbBase) TimeFromDB(t *time.Time, tz *time.Location) { 1820 *t = t.In(tz) 1821 } 1822 1823 // convert time to db. 1824 func (d *dbBase) TimeToDB(t *time.Time, tz *time.Location) { 1825 *t = t.In(tz) 1826 } 1827 1828 // get database types. 1829 func (d *dbBase) DbTypes() map[string]string { 1830 return nil 1831 } 1832 1833 // gt all tables. 1834 func (d *dbBase) GetTables(db dbQuerier) (map[string]bool, error) { 1835 tables := make(map[string]bool) 1836 query := d.ins.ShowTablesQuery() 1837 rows, err := db.Query(query) 1838 if err != nil { 1839 return tables, err 1840 } 1841 1842 defer rows.Close() 1843 1844 for rows.Next() { 1845 var table string 1846 err := rows.Scan(&table) 1847 if err != nil { 1848 return tables, err 1849 } 1850 if table != "" { 1851 tables[table] = true 1852 } 1853 } 1854 1855 return tables, nil 1856 } 1857 1858 // get all cloumns in table. 1859 func (d *dbBase) GetColumns(db dbQuerier, table string) (map[string][3]string, error) { 1860 columns := make(map[string][3]string) 1861 query := d.ins.ShowColumnsQuery(table) 1862 rows, err := db.Query(query) 1863 if err != nil { 1864 return columns, err 1865 } 1866 1867 defer rows.Close() 1868 1869 for rows.Next() { 1870 var ( 1871 name string 1872 typ string 1873 null string 1874 ) 1875 err := rows.Scan(&name, &typ, &null) 1876 if err != nil { 1877 return columns, err 1878 } 1879 columns[name] = [3]string{name, typ, null} 1880 } 1881 1882 return columns, nil 1883 } 1884 1885 // not implement. 1886 func (d *dbBase) OperatorSQL(operator string) string { 1887 panic(ErrNotImplement) 1888 } 1889 1890 // not implement. 1891 func (d *dbBase) ShowTablesQuery() string { 1892 panic(ErrNotImplement) 1893 } 1894 1895 // not implement. 1896 func (d *dbBase) ShowColumnsQuery(table string) string { 1897 panic(ErrNotImplement) 1898 } 1899 1900 // not implement. 1901 func (d *dbBase) IndexExists(dbQuerier, string, string) bool { 1902 panic(ErrNotImplement) 1903 }