github.com/movsb/taorm@v0.0.0-20201209183410-91bafb0b22a6/taorm/stmt.go (about) 1 package taorm 2 3 import ( 4 "bytes" 5 "database/sql" 6 "errors" 7 "fmt" 8 "reflect" 9 "regexp" 10 "strings" 11 12 "github.com/movsb/taorm/filter" 13 ) 14 15 // _Where ... 16 type _Where struct { 17 query string 18 args []interface{} 19 } 20 21 func (w _Where) build() (query string, args []interface{}) { 22 sb := bytes.NewBuffer(nil) 23 sb.Grow(len(query)) // should we reserve capacity for slice too? 24 var i int 25 for _, c := range w.query { 26 switch c { 27 case '?': 28 if i >= len(w.args) { 29 panic(fmt.Errorf("err where args count")) 30 } 31 value := reflect.ValueOf(w.args[i]) 32 if value.Kind() == reflect.Slice { 33 n := value.Len() 34 sb.WriteString(createSQLInMarks(n)) 35 for j := 0; j < n; j++ { 36 args = append(args, value.Index(j).Interface()) 37 } 38 } else { 39 sb.WriteByte('?') 40 args = append(args, w.args[i]) 41 } 42 i++ 43 default: 44 sb.WriteRune(c) 45 } 46 } 47 if i != len(w.args) { 48 panic(fmt.Errorf("err where args count")) 49 } 50 return sb.String(), args 51 } 52 53 // _Expr is a raw SQL expression. 54 // 55 // e.g.: `UPDATE sth SET left = right`, here `right` is the expression. 56 // 57 // TODO expr args cannot be slice. 58 type _Expr _Where 59 60 // Expr creates an expression for Update* operations. 61 func Expr(expr string, args ...interface{}) _Expr { 62 return _Expr{ 63 query: expr, 64 args: args, 65 } 66 } 67 68 type _RawQuery struct { 69 query string 70 args []interface{} 71 } 72 73 // Stmt is an SQL statement. 74 type Stmt struct { 75 db *DB 76 raw _RawQuery // not set if query == "" 77 model interface{} 78 fromTable interface{} 79 info *_StructInfo 80 tableNames []string 81 innerJoinTables []string 82 fields []string 83 ands []_Where 84 groupBy string 85 having string 86 orderBy string 87 limit int64 88 offset int64 89 } 90 91 // From ... 92 // table can be either string or struct. 93 func (s *Stmt) From(table interface{}) *Stmt { 94 switch typed := table.(type) { 95 case string: 96 s.tableNames = append(s.tableNames, typed) 97 default: 98 name, err := s.tryFindTableName(table) 99 if err != nil { 100 panic(WrapError(err)) 101 } 102 s.tableNames = append(s.tableNames, name) 103 } 104 return s 105 } 106 107 // InnerJoin ... 108 func (s *Stmt) InnerJoin(table interface{}, on string) *Stmt { 109 name := "" 110 switch typed := table.(type) { 111 case string: 112 name = typed 113 default: 114 n, err := s.tryFindTableName(typed) 115 if err != nil { 116 panic(WrapError(err)) 117 } 118 name = n 119 } 120 121 q := " INNER JOIN " + name 122 if on != "" { 123 q += " ON " + on 124 } 125 s.innerJoinTables = append(s.innerJoinTables, q) 126 return s 127 } 128 129 // Select ... 130 func (s *Stmt) Select(fields string) *Stmt { 131 if len(fields) > 0 { 132 s.fields = append(s.fields, fields) 133 } 134 return s 135 } 136 137 // Where ... 138 func (s *Stmt) Where(query string, args ...interface{}) *Stmt { 139 w := _Where{ 140 query: query, 141 args: args, 142 } 143 s.ands = append(s.ands, w) 144 return s 145 } 146 147 // WhereIf ... 148 func (s *Stmt) WhereIf(cond bool, query string, args ...interface{}) *Stmt { 149 if cond { 150 s.Where(query, args...) 151 } 152 return s 153 } 154 155 // GroupBy ... 156 func (s *Stmt) GroupBy(groupBy string) *Stmt { 157 s.groupBy = groupBy 158 return s 159 } 160 161 // Having ... 162 func (s *Stmt) Having(having string) *Stmt { 163 s.having = having 164 return s 165 } 166 167 // OrderBy ... 168 // TODO multiple orderbys 169 func (s *Stmt) OrderBy(orderBy string) *Stmt { 170 s.orderBy = orderBy 171 return s 172 } 173 174 // Limit ... 175 func (s *Stmt) Limit(limit int64) *Stmt { 176 s.limit = limit 177 return s 178 } 179 180 // Offset ... 181 func (s *Stmt) Offset(offset int64) *Stmt { 182 s.offset = offset 183 return s 184 } 185 186 // Filter ... may throw exceptions 187 // Filter has to know whom to filter. So before filtering, call From(), Model() 188 // or pass the third argument. 189 func (s *Stmt) Filter(expr string, mapper filter.Mapper, _Struct ...interface{}) *Stmt { 190 var info *_StructInfo 191 192 if s.info != nil { 193 info = s.info 194 } else if s.model != nil { 195 inf, err := getRegistered(s.model) 196 if err != nil { 197 panic(WrapError(err)) 198 } 199 info = inf 200 } else if s.fromTable != nil { 201 inf, err := getRegistered(s.fromTable) 202 if err != nil { 203 panic(WrapError(err)) 204 } 205 info = inf 206 } else if len(_Struct) > 0 { // Warn: == 1 207 inf, err := getRegistered(_Struct[0]) 208 if err != nil { 209 panic(WrapError(err)) 210 } 211 info = inf 212 } else { 213 panic(WrapError(errors.New("cannot deduce what to filter"))) 214 } 215 216 query, args, err := filter.Filter( 217 func(field string) reflect.Type { 218 return info.fields[field]._type // maybe not exist 219 }, 220 expr, 221 mapper, 222 info.tableName, 223 ) 224 if err != nil { 225 panic(WrapError(err)) 226 } 227 s.WhereIf(query != "", query, args...) 228 return s 229 } 230 231 // noWheres returns true if no SQL conditions. 232 // Includes and, or. 233 func (s *Stmt) noWheres() bool { 234 return len(s.ands) <= 0 235 } 236 237 func (s *Stmt) buildWheres() (string, []interface{}) { 238 if s.model != nil { 239 id, ok := s.info.getPrimaryKey(s.model) 240 s.WhereIf(ok, "id=?", id) 241 } 242 243 if s.noWheres() { 244 return "", nil 245 } 246 247 var args []interface{} 248 sb := bytes.NewBuffer(nil) 249 sb.WriteString(" WHERE ") 250 for i, w := range s.ands { 251 if i > 0 { 252 sb.WriteString(" AND ") 253 } 254 query, xargs := w.build() 255 sb.WriteString("(" + query + ")") 256 args = append(args, xargs...) 257 } 258 return sb.String(), args 259 } 260 261 func (s *Stmt) buildCreate() (*_StructInfo, string, []interface{}, error) { 262 panicIf(len(s.tableNames) != 1, "model length is not 1") 263 panicIf(s.raw.query != "", "cannot use raw here") 264 info, err := getRegistered(s.model) 265 if err != nil { 266 return info, "", nil, err 267 } 268 args := info.ifacesOf(s.model) 269 if len(args) == 0 { 270 return info, "", nil, ErrNoFields 271 } 272 return info, info.insertstr, args, nil 273 } 274 275 func (s *Stmt) tryFindTableName(out interface{}) (string, error) { 276 info, err := getRegistered(out) 277 if err != nil { 278 return "", err 279 } 280 if info.tableName == "" { 281 return "", fmt.Errorf("trying to use auto-registered struct table name") 282 } 283 return info.tableName, nil 284 } 285 286 func (s *Stmt) buildSelect(out interface{}, isCount bool) (string, []interface{}, error) { 287 if s.raw.query != "" { 288 return s.raw.query, s.raw.args, nil 289 } 290 291 if len(s.tableNames) == 0 { 292 name, err := s.tryFindTableName(out) 293 if err != nil { 294 return "", nil, err 295 } 296 s.tableNames = append(s.tableNames, name) 297 } 298 299 panicIf(len(s.tableNames) == 0, "model is empty") 300 301 var strFields string 302 303 if isCount { 304 strFields = "COUNT(1)" 305 } else { 306 fields := []string{} 307 if len(s.fields) == 0 { 308 if len(s.innerJoinTables) == 0 { 309 fields = []string{"*"} 310 } else { 311 fields = []string{s.tableNames[0] + ".*"} 312 } 313 } else { 314 if len(s.innerJoinTables) == 0 || len(s.fields) == 1 && s.fields[0] == "*" { 315 fields = s.fields 316 } else { 317 for _, list := range s.fields { 318 slice := strings.Split(list, ",") 319 for _, field := range slice { 320 index := strings.IndexByte(field, '.') 321 if index == -1 { 322 f := s.tableNames[0] + "." + field 323 fields = append(fields, f) 324 } else { 325 fields = append(fields, field) 326 } 327 } 328 } 329 } 330 } 331 strFields = strings.Join(fields, ",") 332 } 333 334 query := `SELECT ` + strFields + ` FROM ` + strings.Join(s.tableNames, ",") 335 if len(s.innerJoinTables) > 0 { 336 query += strings.Join(s.innerJoinTables, " ") 337 } 338 339 var args []interface{} 340 341 whereQuery, whereArgs := s.buildWheres() 342 query += whereQuery 343 args = append(args, whereArgs...) 344 345 query += s.buildGroupBy() 346 query += s.buildHaving() 347 348 if orderBy, err := s.buildOrderBy(); err != nil { 349 return "", nil, err 350 } else { 351 if orderBy != "" { 352 query += orderBy 353 } 354 } 355 query += s.buildLimit() 356 357 return query, args, nil 358 } 359 360 func (s *Stmt) buildUpdateMap(fields map[string]interface{}) (string, []interface{}, error) { 361 panicIf(len(s.tableNames) == 0, "model is empty") 362 panicIf(s.raw.query != "", "cannot use raw here") 363 query := `UPDATE ` + strings.Join(s.tableNames, ",") + ` SET ` 364 365 if len(fields) == 0 { 366 return "", nil, ErrNoFields 367 } 368 369 updates := make([]string, 0, len(fields)) 370 args := make([]interface{}, 0, len(fields)) 371 372 for field, value := range fields { 373 switch tv := value.(type) { 374 case _Expr: 375 eq, ea := _Where(tv).build() 376 pair := field + "=" + eq 377 updates = append(updates, pair) 378 args = append(args, ea...) 379 default: 380 pair := field + "=?" 381 updates = append(updates, pair) 382 args = append(args, value) 383 } 384 } 385 386 query += strings.Join(updates, ",") 387 388 whereQuery, whereArgs := s.buildWheres() 389 query += whereQuery 390 args = append(args, whereArgs...) 391 392 query += s.buildLimit() 393 394 return query, args, nil 395 } 396 397 func (s *Stmt) buildUpdateModel(model interface{}) (string, []interface{}, error) { 398 panicIf(len(s.tableNames) == 0, "model is empty") 399 panicIf(s.raw.query != "", "cannot use raw here") 400 query := s.info.updatestr 401 args := s.info.ifacesOf(model) 402 whereQuery, whereArgs := s.buildWheres() 403 query += whereQuery 404 args = append(args, whereArgs...) 405 return query, args, nil 406 } 407 408 func (s *Stmt) buildDelete() (string, []interface{}, error) { 409 panicIf(len(s.tableNames) == 0, "model is empty") 410 panicIf(s.raw.query != "", "cannot use raw here") 411 var args []interface{} 412 query := `DELETE FROM ` + strings.Join(s.tableNames, ",") 413 414 whereQuery, whereArgs := s.buildWheres() 415 query += whereQuery 416 args = append(args, whereArgs...) 417 418 query += s.buildLimit() 419 420 return query, args, nil 421 } 422 423 func (s *Stmt) buildGroupBy() (groupBy string) { 424 if s.groupBy != "" { 425 groupBy = ` GROUP BY ` + s.groupBy 426 } 427 return 428 } 429 430 func (s *Stmt) buildHaving() (having string) { 431 if s.having != `` { 432 having = ` HAVING ` + s.having 433 } 434 return 435 } 436 437 var regexpOrderBy = regexp.MustCompile(`^ *((\w+\.)?(\w+)) *(\w+)? *$`) 438 439 func (s *Stmt) buildOrderBy() (string, error) { 440 orderBy := " ORDER BY " 441 if s.orderBy == "" { 442 return "", nil 443 } 444 parts := strings.Split(s.orderBy, ",") 445 orderBys := []string{} 446 for _, part := range parts { 447 if !regexpOrderBy.MatchString(part) { 448 return ``, fmt.Errorf(`invalid order_by: %s`, part) 449 } 450 orderBys = append(orderBys, part) 451 452 // these are for automatically adding table names to fields in order_by etc. 453 // they are commented out because of custom field name doesn't belong to some table. 454 // currently I don't know how to handle this correctly. 455 // 456 // matches := regexpOrderBy.FindStringSubmatch(part) 457 // if len(matches) != 5 { 458 // return "", errors.New("invalid orderby") 459 // } 460 // table := matches[2] 461 // column := matches[1] 462 // order := matches[4] 463 // // avoid column ambiguous 464 // // "Error 1052: Column 'created_at' in order clause is ambiguous" 465 // if table == "" && len(s.tableNames)+len(s.innerJoinTables) > 1 { 466 // column = s.tableNames[0] + "." + column 467 // } 468 // if order != "" { 469 // column += " " + order 470 // } 471 // orderBys = append(orderBys, column) 472 } 473 orderBy += strings.Join(orderBys, ",") 474 return orderBy, nil 475 } 476 477 func (s *Stmt) buildLimit() (limit string) { 478 if s.limit > 0 { 479 limit += ` LIMIT ` + fmt.Sprint(s.limit) 480 if s.offset >= 0 { 481 limit += ` OFFSET ` + fmt.Sprint(s.offset) 482 } 483 } 484 return 485 } 486 487 // Create ... 488 func (s *Stmt) Create() error { 489 info, query, args, err := s.buildCreate() 490 if err != nil { 491 return WrapError(err) 492 } 493 494 dumpSQL(query, args...) 495 496 result, err := s.db.Exec(query, args...) 497 if err != nil { 498 return WrapError(err) 499 } 500 501 id, err := result.LastInsertId() 502 if err != nil { 503 return WrapError(err) 504 } 505 506 info.setPrimaryKey(s.model, id) 507 508 return nil 509 } 510 511 // MustCreate ... 512 func (s *Stmt) MustCreate() { 513 if err := s.Create(); err != nil { 514 panic(err) 515 } 516 } 517 518 // CreateSQL ... 519 func (s *Stmt) CreateSQL() string { 520 _, query, args, err := s.buildCreate() 521 if err != nil { 522 panic(WrapError(err)) 523 } 524 return strSQL(query, args...) 525 } 526 527 // Find ... 528 func (s *Stmt) Find(out interface{}) error { 529 query, args, err := s.buildSelect(out, false) 530 if err != nil { 531 return WrapError(err) 532 } 533 534 dumpSQL(query, args...) 535 return ScanRows(out, s.db, query, args...) 536 } 537 538 // MustFind ... 539 func (s *Stmt) MustFind(out interface{}) { 540 if err := s.Find(out); err != nil { 541 panic(err) 542 } 543 } 544 545 // FindSQL ... 546 func (s *Stmt) FindSQL() string { 547 query, args, err := s.buildSelect(s.model, false) 548 if err != nil { 549 panic(WrapError(err)) 550 } 551 return strSQL(query, args...) 552 } 553 554 // Count ... 555 func (s *Stmt) Count(out interface{}) error { 556 query, args, err := s.buildSelect(s.fromTable, true) 557 if err != nil { 558 return WrapError(err) 559 } 560 561 dumpSQL(query, args...) 562 return ScanRows(out, s.db, query, args...) 563 } 564 565 // MustCount ... 566 func (s *Stmt) MustCount(out interface{}) { 567 if err := s.Count(out); err != nil { 568 panic(err) 569 } 570 } 571 572 // CountSQL ... 573 func (s *Stmt) CountSQL() string { 574 query, args, err := s.buildSelect(s.fromTable, true) 575 if err != nil { 576 panic(WrapError(err)) 577 } 578 return strSQL(query, args...) 579 } 580 581 func (s *Stmt) updateMap(fields M, anyway bool) (sql.Result, error) { 582 if len(fields) == 0 { 583 return nil, ErrNoFields 584 } 585 586 query, args, err := s.buildUpdateMap(fields) 587 if err != nil { 588 return nil, err 589 } 590 591 if !anyway && s.noWheres() { 592 return nil, ErrNoWhere 593 } 594 595 dumpSQL(query, args...) 596 597 res, err := s.db.Exec(query, args...) 598 if err != nil { 599 return nil, err 600 } 601 602 return res, nil 603 } 604 605 func (s *Stmt) updateModel(model interface{}) (sql.Result, error) { 606 query, args, err := s.buildUpdateModel(model) 607 if err != nil { 608 return nil, err 609 } 610 611 dumpSQL(query, args...) 612 613 res, err := s.db.Exec(query, args...) 614 if err != nil { 615 return nil, err 616 } 617 618 return res, nil 619 } 620 621 // UpdateMap ... 622 func (s *Stmt) UpdateMap(updates M) (sql.Result, error) { 623 res, err := s.updateMap(updates, false) 624 return res, WrapError(err) 625 } 626 627 // UpdateMapAnyway ... 628 func (s *Stmt) UpdateMapAnyway(updates M) (sql.Result, error) { 629 res, err := s.updateMap(updates, true) 630 return res, WrapError(err) 631 } 632 633 // UpdateModel ... 634 func (s *Stmt) UpdateModel(model interface{}) (sql.Result, error) { 635 res, err := s.updateModel(model) 636 return res, WrapError(err) 637 } 638 639 // MustUpdateMap ... 640 func (s *Stmt) MustUpdateMap(updates M) sql.Result { 641 res, err := s.updateMap(updates, false) 642 if err != nil { 643 panic(err) 644 } 645 return res 646 } 647 648 // MustUpdateMapAnyway ... 649 func (s *Stmt) MustUpdateMapAnyway(updates M) sql.Result { 650 res, err := s.updateMap(updates, true) 651 if err != nil { 652 panic(err) 653 } 654 return res 655 } 656 657 // MustUpdateModel ... 658 func (s *Stmt) MustUpdateModel(model interface{}) sql.Result { 659 res, err := s.updateModel(model) 660 if err != nil { 661 panic(err) 662 } 663 return res 664 } 665 666 // UpdateMapSQL ... 667 func (s *Stmt) UpdateMapSQL(updates M) string { 668 query, args, err := s.buildUpdateMap(updates) 669 if err != nil { 670 panic(WrapError(err)) 671 } 672 return strSQL(query, args...) 673 } 674 675 // UpdateModelSQL ... 676 func (s *Stmt) UpdateModelSQL(model interface{}) string { 677 query, args, err := s.buildUpdateModel(model) 678 if err != nil { 679 panic(WrapError(err)) 680 } 681 return strSQL(query, args...) 682 } 683 684 func (s *Stmt) _delete(anyway bool) error { 685 query, args, err := s.buildDelete() 686 if err != nil { 687 return err 688 } 689 690 if !anyway && s.noWheres() { 691 return ErrNoWhere 692 } 693 694 dumpSQL(query, args...) 695 696 _, err = s.db.Exec(query, args...) 697 if err != nil { 698 return err 699 } 700 701 return nil 702 } 703 704 // Delete ... 705 func (s *Stmt) Delete() error { 706 return WrapError(s._delete(false)) 707 } 708 709 // DeleteAnyway ... 710 func (s *Stmt) DeleteAnyway() error { 711 return WrapError(s._delete(true)) 712 } 713 714 // MustDelete ... 715 func (s *Stmt) MustDelete() { 716 if err := s.Delete(); err != nil { 717 panic(err) 718 } 719 } 720 721 // MustDeleteAnyway ... 722 func (s *Stmt) MustDeleteAnyway() { 723 if err := s.DeleteAnyway(); err != nil { 724 panic(err) 725 } 726 } 727 728 // DeleteSQL ... 729 func (s *Stmt) DeleteSQL() string { 730 query, args, err := s.buildDelete() 731 if err != nil { 732 panic(WrapError(err)) 733 } 734 return strSQL(query, args...) 735 }