github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/go-xorm/xorm/statement.go (about) 1 // Copyright 2015 The Xorm Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package xorm 6 7 import ( 8 "bytes" 9 "database/sql/driver" 10 "encoding/json" 11 "errors" 12 "fmt" 13 "reflect" 14 "strings" 15 "time" 16 17 "github.com/insionng/yougam/libraries/go-xorm/builder" 18 "github.com/insionng/yougam/libraries/go-xorm/core" 19 ) 20 21 type incrParam struct { 22 colName string 23 arg interface{} 24 } 25 26 type decrParam struct { 27 colName string 28 arg interface{} 29 } 30 31 type exprParam struct { 32 colName string 33 expr string 34 } 35 36 // Statement save all the sql info for executing SQL 37 type Statement struct { 38 RefTable *core.Table 39 Engine *Engine 40 Start int 41 LimitN int 42 IdParam *core.PK 43 OrderStr string 44 JoinStr string 45 joinArgs []interface{} 46 GroupByStr string 47 HavingStr string 48 ColumnStr string 49 selectStr string 50 columnMap map[string]bool 51 useAllCols bool 52 OmitStr string 53 AltTableName string 54 tableName string 55 RawSQL string 56 RawParams []interface{} 57 UseCascade bool 58 UseAutoJoin bool 59 StoreEngine string 60 Charset string 61 UseCache bool 62 UseAutoTime bool 63 noAutoCondition bool 64 IsDistinct bool 65 IsForUpdate bool 66 TableAlias string 67 allUseBool bool 68 checkVersion bool 69 unscoped bool 70 mustColumnMap map[string]bool 71 nullableMap map[string]bool 72 incrColumns map[string]incrParam 73 decrColumns map[string]decrParam 74 exprColumns map[string]exprParam 75 cond builder.Cond 76 } 77 78 // Init reset all the statment's fields 79 func (statement *Statement) Init() { 80 statement.RefTable = nil 81 statement.Start = 0 82 statement.LimitN = 0 83 statement.OrderStr = "" 84 statement.UseCascade = true 85 statement.JoinStr = "" 86 statement.joinArgs = make([]interface{}, 0) 87 statement.GroupByStr = "" 88 
statement.HavingStr = "" 89 statement.ColumnStr = "" 90 statement.OmitStr = "" 91 statement.columnMap = make(map[string]bool) 92 statement.AltTableName = "" 93 statement.tableName = "" 94 statement.IdParam = nil 95 statement.RawSQL = "" 96 statement.RawParams = make([]interface{}, 0) 97 statement.UseCache = true 98 statement.UseAutoTime = true 99 statement.noAutoCondition = false 100 statement.IsDistinct = false 101 statement.IsForUpdate = false 102 statement.TableAlias = "" 103 statement.selectStr = "" 104 statement.allUseBool = false 105 statement.useAllCols = false 106 statement.mustColumnMap = make(map[string]bool) 107 statement.nullableMap = make(map[string]bool) 108 statement.checkVersion = true 109 statement.unscoped = false 110 statement.incrColumns = make(map[string]incrParam) 111 statement.decrColumns = make(map[string]decrParam) 112 statement.exprColumns = make(map[string]exprParam) 113 statement.cond = builder.NewCond() 114 } 115 116 // NoAutoCondition if you do not want convert bean's field as query condition, then use this function 117 func (statement *Statement) NoAutoCondition(no ...bool) *Statement { 118 statement.noAutoCondition = true 119 if len(no) > 0 { 120 statement.noAutoCondition = no[0] 121 } 122 return statement 123 } 124 125 // Alias set the table alias 126 func (statement *Statement) Alias(alias string) *Statement { 127 statement.TableAlias = alias 128 return statement 129 } 130 131 // SQL adds raw sql statement 132 func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement { 133 switch query.(type) { 134 case (*builder.Builder): 135 var err error 136 statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL() 137 if err != nil { 138 statement.Engine.logger.Error(err) 139 } 140 case string: 141 statement.RawSQL = query.(string) 142 statement.RawParams = args 143 default: 144 statement.Engine.logger.Error("unsupported sql type") 145 } 146 147 return statement 148 } 149 150 // Where add Where 
statment 151 func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement { 152 return statement.And(query, args...) 153 } 154 155 // And add Where & and statment 156 func (statement *Statement) And(query interface{}, args ...interface{}) *Statement { 157 switch query.(type) { 158 case string: 159 cond := builder.Expr(query.(string), args...) 160 statement.cond = statement.cond.And(cond) 161 case builder.Cond: 162 cond := query.(builder.Cond) 163 statement.cond = statement.cond.And(cond) 164 for _, v := range args { 165 if vv, ok := v.(builder.Cond); ok { 166 statement.cond = statement.cond.And(vv) 167 } 168 } 169 default: 170 // TODO: not support condition type 171 } 172 173 return statement 174 } 175 176 // Or add Where & Or statment 177 func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement { 178 switch query.(type) { 179 case string: 180 cond := builder.Expr(query.(string), args...) 181 statement.cond = statement.cond.Or(cond) 182 case builder.Cond: 183 cond := query.(builder.Cond) 184 statement.cond = statement.cond.Or(cond) 185 for _, v := range args { 186 if vv, ok := v.(builder.Cond); ok { 187 statement.cond = statement.cond.Or(vv) 188 } 189 } 190 default: 191 // TODO: not support condition type 192 } 193 return statement 194 } 195 196 // In generate "Where column IN (?) " statment 197 func (statement *Statement) In(column string, args ...interface{}) *Statement { 198 if len(args) == 0 { 199 return statement 200 } 201 202 in := builder.In(column, args...) 203 statement.cond = statement.cond.And(in) 204 return statement 205 } 206 207 // NotIn generate "Where column NOT IN (?) " statment 208 func (statement *Statement) NotIn(column string, args ...interface{}) *Statement { 209 if len(args) == 0 { 210 return statement 211 } 212 213 in := builder.NotIn(column, args...) 
214 statement.cond = statement.cond.And(in) 215 return statement 216 } 217 218 func (statement *Statement) setRefValue(v reflect.Value) { 219 statement.RefTable = statement.Engine.autoMapType(reflect.Indirect(v)) 220 statement.tableName = statement.Engine.tbName(v) 221 } 222 223 // Table tempororily set table name, the parameter could be a string or a pointer of struct 224 func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { 225 v := rValue(tableNameOrBean) 226 t := v.Type() 227 if t.Kind() == reflect.String { 228 statement.AltTableName = tableNameOrBean.(string) 229 } else if t.Kind() == reflect.Struct { 230 statement.RefTable = statement.Engine.autoMapType(v) 231 statement.AltTableName = statement.Engine.tbName(v) 232 } 233 return statement 234 } 235 236 // Auto generating update columnes and values according a struct 237 func buildUpdates(engine *Engine, table *core.Table, bean interface{}, 238 includeVersion bool, includeUpdated bool, includeNil bool, 239 includeAutoIncr bool, allUseBool bool, useAllCols bool, 240 mustColumnMap map[string]bool, nullableMap map[string]bool, 241 columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) { 242 243 var colNames = make([]string, 0) 244 var args = make([]interface{}, 0) 245 for _, col := range table.Columns() { 246 if !includeVersion && col.IsVersion { 247 continue 248 } 249 if col.IsCreated { 250 continue 251 } 252 if !includeUpdated && col.IsUpdated { 253 continue 254 } 255 if !includeAutoIncr && col.IsAutoIncrement { 256 continue 257 } 258 if col.IsDeleted && !unscoped { 259 continue 260 } 261 if use, ok := columnMap[strings.ToLower(col.Name)]; ok && !use { 262 continue 263 } 264 265 fieldValuePtr, err := col.ValueOf(bean) 266 if err != nil { 267 engine.logger.Error(err) 268 continue 269 } 270 271 fieldValue := *fieldValuePtr 272 fieldType := reflect.TypeOf(fieldValue.Interface()) 273 274 requiredField := useAllCols 275 includeNil := useAllCols 276 277 if b, ok := 
getFlagForColumn(mustColumnMap, col); ok { 278 if b { 279 requiredField = true 280 } else { 281 continue 282 } 283 } 284 285 // !evalphobia! set fieldValue as nil when column is nullable and zero-value 286 if b, ok := getFlagForColumn(nullableMap, col); ok { 287 if b && col.Nullable && isZero(fieldValue.Interface()) { 288 var nilValue *int 289 fieldValue = reflect.ValueOf(nilValue) 290 fieldType = reflect.TypeOf(fieldValue.Interface()) 291 includeNil = true 292 } 293 } 294 295 var val interface{} 296 297 if fieldValue.CanAddr() { 298 if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok { 299 data, err := structConvert.ToDB() 300 if err != nil { 301 engine.logger.Error(err) 302 } else { 303 val = data 304 } 305 goto APPEND 306 } 307 } 308 309 if structConvert, ok := fieldValue.Interface().(core.Conversion); ok { 310 data, err := structConvert.ToDB() 311 if err != nil { 312 engine.logger.Error(err) 313 } else { 314 val = data 315 } 316 goto APPEND 317 } 318 319 if fieldType.Kind() == reflect.Ptr { 320 if fieldValue.IsNil() { 321 if includeNil { 322 args = append(args, nil) 323 colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name))) 324 } 325 continue 326 } else if !fieldValue.IsValid() { 327 continue 328 } else { 329 // dereference ptr type to instance type 330 fieldValue = fieldValue.Elem() 331 fieldType = reflect.TypeOf(fieldValue.Interface()) 332 requiredField = true 333 } 334 } 335 336 switch fieldType.Kind() { 337 case reflect.Bool: 338 if allUseBool || requiredField { 339 val = fieldValue.Interface() 340 } else { 341 // if a bool in a struct, it will not be as a condition because it default is false, 342 // please use Where() instead 343 continue 344 } 345 case reflect.String: 346 if !requiredField && fieldValue.String() == "" { 347 continue 348 } 349 // for MyString, should convert to string or panic 350 if fieldType.String() != reflect.String.String() { 351 val = fieldValue.String() 352 } else { 353 val = 
fieldValue.Interface() 354 } 355 case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: 356 if !requiredField && fieldValue.Int() == 0 { 357 continue 358 } 359 val = fieldValue.Interface() 360 case reflect.Float32, reflect.Float64: 361 if !requiredField && fieldValue.Float() == 0.0 { 362 continue 363 } 364 val = fieldValue.Interface() 365 case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: 366 if !requiredField && fieldValue.Uint() == 0 { 367 continue 368 } 369 t := int64(fieldValue.Uint()) 370 val = reflect.ValueOf(&t).Interface() 371 case reflect.Struct: 372 if fieldType.ConvertibleTo(core.TimeType) { 373 t := fieldValue.Convert(core.TimeType).Interface().(time.Time) 374 if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { 375 continue 376 } 377 val = engine.FormatTime(col.SQLType.Name, t) 378 } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { 379 val, _ = nulType.Value() 380 } else { 381 if !col.SQLType.IsJson() { 382 engine.autoMapType(fieldValue) 383 if table, ok := engine.Tables[fieldValue.Type()]; ok { 384 if len(table.PrimaryKeys) == 1 { 385 pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) 386 // fix non-int pk issues 387 if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) { 388 val = pkField.Interface() 389 } else { 390 continue 391 } 392 } else { 393 //TODO: how to handler? 
394 panic("not supported") 395 } 396 } else { 397 val = fieldValue.Interface() 398 } 399 } else { 400 // Blank struct could not be as update data 401 if requiredField || !isStructZero(fieldValue) { 402 bytes, err := json.Marshal(fieldValue.Interface()) 403 if err != nil { 404 panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface())) 405 } 406 if col.SQLType.IsText() { 407 val = string(bytes) 408 } else if col.SQLType.IsBlob() { 409 val = bytes 410 } 411 } else { 412 continue 413 } 414 } 415 } 416 case reflect.Array, reflect.Slice, reflect.Map: 417 if !requiredField { 418 if fieldValue == reflect.Zero(fieldType) { 419 continue 420 } 421 if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { 422 continue 423 } 424 } 425 426 if col.SQLType.IsText() { 427 bytes, err := json.Marshal(fieldValue.Interface()) 428 if err != nil { 429 engine.logger.Error(err) 430 continue 431 } 432 val = string(bytes) 433 } else if col.SQLType.IsBlob() { 434 var bytes []byte 435 var err error 436 if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && 437 fieldType.Elem().Kind() == reflect.Uint8 { 438 if fieldValue.Len() > 0 { 439 val = fieldValue.Bytes() 440 } else { 441 continue 442 } 443 } else { 444 bytes, err = json.Marshal(fieldValue.Interface()) 445 if err != nil { 446 engine.logger.Error(err) 447 continue 448 } 449 val = bytes 450 } 451 } else { 452 continue 453 } 454 default: 455 val = fieldValue.Interface() 456 } 457 458 APPEND: 459 args = append(args, val) 460 if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { 461 continue 462 } 463 colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name))) 464 } 465 466 return colNames, args 467 } 468 469 func (statement *Statement) needTableName() bool { 470 return len(statement.JoinStr) > 0 471 } 472 473 func (statement *Statement) colName(col *core.Column, tableName string) string { 474 if statement.needTableName() { 475 var nm = tableName 476 if len(statement.TableAlias) > 0 
{ 477 nm = statement.TableAlias 478 } 479 return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name) 480 } 481 return statement.Engine.Quote(col.Name) 482 } 483 484 func buildConds(engine *Engine, table *core.Table, bean interface{}, 485 includeVersion bool, includeUpdated bool, includeNil bool, 486 includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, 487 mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) (builder.Cond, error) { 488 var conds []builder.Cond 489 for _, col := range table.Columns() { 490 if !includeVersion && col.IsVersion { 491 continue 492 } 493 if !includeUpdated && col.IsUpdated { 494 continue 495 } 496 if !includeAutoIncr && col.IsAutoIncrement { 497 continue 498 } 499 500 if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { 501 continue 502 } 503 if col.SQLType.IsJson() { 504 continue 505 } 506 507 var colName string 508 if addedTableName { 509 var nm = tableName 510 if len(aliasName) > 0 { 511 nm = aliasName 512 } 513 colName = engine.Quote(nm) + "." 
+ engine.Quote(col.Name) 514 } else { 515 colName = engine.Quote(col.Name) 516 } 517 518 fieldValuePtr, err := col.ValueOf(bean) 519 if err != nil { 520 engine.logger.Error(err) 521 continue 522 } 523 524 if col.IsDeleted && !unscoped { // tag "deleted" is enabled 525 conds = append(conds, builder.IsNull{colName}.Or(builder.Eq{colName: "0001-01-01 00:00:00"})) 526 } 527 528 fieldValue := *fieldValuePtr 529 if fieldValue.Interface() == nil { 530 continue 531 } 532 533 fieldType := reflect.TypeOf(fieldValue.Interface()) 534 requiredField := useAllCols 535 536 if b, ok := getFlagForColumn(mustColumnMap, col); ok { 537 if b { 538 requiredField = true 539 } else { 540 continue 541 } 542 } 543 544 if fieldType.Kind() == reflect.Ptr { 545 if fieldValue.IsNil() { 546 if includeNil { 547 conds = append(conds, builder.Eq{colName: nil}) 548 } 549 continue 550 } else if !fieldValue.IsValid() { 551 continue 552 } else { 553 // dereference ptr type to instance type 554 fieldValue = fieldValue.Elem() 555 fieldType = reflect.TypeOf(fieldValue.Interface()) 556 requiredField = true 557 } 558 } 559 560 var val interface{} 561 switch fieldType.Kind() { 562 case reflect.Bool: 563 if allUseBool || requiredField { 564 val = fieldValue.Interface() 565 } else { 566 // if a bool in a struct, it will not be as a condition because it default is false, 567 // please use Where() instead 568 continue 569 } 570 case reflect.String: 571 if !requiredField && fieldValue.String() == "" { 572 continue 573 } 574 // for MyString, should convert to string or panic 575 if fieldType.String() != reflect.String.String() { 576 val = fieldValue.String() 577 } else { 578 val = fieldValue.Interface() 579 } 580 case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: 581 if !requiredField && fieldValue.Int() == 0 { 582 continue 583 } 584 val = fieldValue.Interface() 585 case reflect.Float32, reflect.Float64: 586 if !requiredField && fieldValue.Float() == 0.0 { 587 continue 588 } 589 val = 
fieldValue.Interface() 590 case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: 591 if !requiredField && fieldValue.Uint() == 0 { 592 continue 593 } 594 t := int64(fieldValue.Uint()) 595 val = reflect.ValueOf(&t).Interface() 596 case reflect.Struct: 597 if fieldType.ConvertibleTo(core.TimeType) { 598 t := fieldValue.Convert(core.TimeType).Interface().(time.Time) 599 if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { 600 continue 601 } 602 val = engine.FormatTime(col.SQLType.Name, t) 603 } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { 604 continue 605 } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { 606 val, _ = valNul.Value() 607 if val == nil { 608 continue 609 } 610 } else { 611 if col.SQLType.IsJson() { 612 if col.SQLType.IsText() { 613 bytes, err := json.Marshal(fieldValue.Interface()) 614 if err != nil { 615 engine.logger.Error(err) 616 continue 617 } 618 val = string(bytes) 619 } else if col.SQLType.IsBlob() { 620 var bytes []byte 621 var err error 622 bytes, err = json.Marshal(fieldValue.Interface()) 623 if err != nil { 624 engine.logger.Error(err) 625 continue 626 } 627 val = bytes 628 } 629 } else { 630 engine.autoMapType(fieldValue) 631 if table, ok := engine.Tables[fieldValue.Type()]; ok { 632 if len(table.PrimaryKeys) == 1 { 633 pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) 634 // fix non-int pk issues 635 //if pkField.Int() != 0 { 636 if pkField.IsValid() && !isZero(pkField.Interface()) { 637 val = pkField.Interface() 638 } else { 639 continue 640 } 641 } else { 642 //TODO: how to handler? 
643 panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys)) 644 } 645 } else { 646 val = fieldValue.Interface() 647 } 648 } 649 } 650 case reflect.Array, reflect.Slice, reflect.Map: 651 if fieldValue == reflect.Zero(fieldType) { 652 continue 653 } 654 if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { 655 continue 656 } 657 658 if col.SQLType.IsText() { 659 bytes, err := json.Marshal(fieldValue.Interface()) 660 if err != nil { 661 engine.logger.Error(err) 662 continue 663 } 664 val = string(bytes) 665 } else if col.SQLType.IsBlob() { 666 var bytes []byte 667 var err error 668 if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && 669 fieldType.Elem().Kind() == reflect.Uint8 { 670 if fieldValue.Len() > 0 { 671 val = fieldValue.Bytes() 672 } else { 673 continue 674 } 675 } else { 676 bytes, err = json.Marshal(fieldValue.Interface()) 677 if err != nil { 678 engine.logger.Error(err) 679 continue 680 } 681 val = bytes 682 } 683 } else { 684 continue 685 } 686 default: 687 val = fieldValue.Interface() 688 } 689 690 conds = append(conds, builder.Eq{colName: val}) 691 } 692 693 return builder.And(conds...), nil 694 } 695 696 // TableName return current tableName 697 func (statement *Statement) TableName() string { 698 if statement.AltTableName != "" { 699 return statement.AltTableName 700 } 701 702 return statement.tableName 703 } 704 705 // Id generate "where id = ? " statment or for composite key "where key1 = ? and key2 = ?" 
706 func (statement *Statement) Id(id interface{}) *Statement { 707 idValue := reflect.ValueOf(id) 708 idType := reflect.TypeOf(idValue.Interface()) 709 710 switch idType { 711 case ptrPkType: 712 if pkPtr, ok := (id).(*core.PK); ok { 713 statement.IdParam = pkPtr 714 return statement 715 } 716 case pkType: 717 if pk, ok := (id).(core.PK); ok { 718 statement.IdParam = &pk 719 return statement 720 } 721 } 722 723 switch idType.Kind() { 724 case reflect.String: 725 statement.IdParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()} 726 return statement 727 } 728 729 statement.IdParam = &core.PK{id} 730 return statement 731 } 732 733 // Incr Generate "Update ... Set column = column + arg" statment 734 func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { 735 k := strings.ToLower(column) 736 if len(arg) > 0 { 737 statement.incrColumns[k] = incrParam{column, arg[0]} 738 } else { 739 statement.incrColumns[k] = incrParam{column, 1} 740 } 741 return statement 742 } 743 744 // Decr Generate "Update ... Set column = column - arg" statment 745 func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { 746 k := strings.ToLower(column) 747 if len(arg) > 0 { 748 statement.decrColumns[k] = decrParam{column, arg[0]} 749 } else { 750 statement.decrColumns[k] = decrParam{column, 1} 751 } 752 return statement 753 } 754 755 // SetExpr Generate "Update ... Set column = {expression}" statment 756 func (statement *Statement) SetExpr(column string, expression string) *Statement { 757 k := strings.ToLower(column) 758 statement.exprColumns[k] = exprParam{column, expression} 759 return statement 760 } 761 762 // Generate "Update ... Set column = column + arg" statment 763 func (statement *Statement) getInc() map[string]incrParam { 764 return statement.incrColumns 765 } 766 767 // Generate "Update ... 
Set column = column - arg" statment 768 func (statement *Statement) getDec() map[string]decrParam { 769 return statement.decrColumns 770 } 771 772 // Generate "Update ... Set column = {expression}" statment 773 func (statement *Statement) getExpr() map[string]exprParam { 774 return statement.exprColumns 775 } 776 777 func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { 778 newColumns := make([]string, 0) 779 for _, col := range columns { 780 col = strings.Replace(col, "`", "", -1) 781 col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1) 782 ccols := strings.Split(col, ",") 783 for _, c := range ccols { 784 fields := strings.Split(strings.TrimSpace(c), ".") 785 if len(fields) == 1 { 786 newColumns = append(newColumns, statement.Engine.quote(fields[0])) 787 } else if len(fields) == 2 { 788 newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+ 789 statement.Engine.quote(fields[1])) 790 } else { 791 panic(errors.New("unwanted colnames")) 792 } 793 } 794 } 795 return newColumns 796 } 797 798 // Distinct generates "DISTINCT col1, col2 " statement 799 func (statement *Statement) Distinct(columns ...string) *Statement { 800 statement.IsDistinct = true 801 statement.Cols(columns...) 802 return statement 803 } 804 805 // ForUpdate generates "SELECT ... FOR UPDATE" statement 806 func (statement *Statement) ForUpdate() *Statement { 807 statement.IsForUpdate = true 808 return statement 809 } 810 811 // Select replace select 812 func (statement *Statement) Select(str string) *Statement { 813 statement.selectStr = str 814 return statement 815 } 816 817 // Cols generate "col1, col2" statement 818 func (statement *Statement) Cols(columns ...string) *Statement { 819 cols := col2NewCols(columns...) 820 for _, nc := range cols { 821 statement.columnMap[strings.ToLower(nc)] = true 822 } 823 824 newColumns := statement.col2NewColsWithQuote(columns...) 
825 statement.ColumnStr = strings.Join(newColumns, ", ") 826 statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1) 827 return statement 828 } 829 830 // AllCols update use only: update all columns 831 func (statement *Statement) AllCols() *Statement { 832 statement.useAllCols = true 833 return statement 834 } 835 836 // MustCols update use only: must update columns 837 func (statement *Statement) MustCols(columns ...string) *Statement { 838 newColumns := col2NewCols(columns...) 839 for _, nc := range newColumns { 840 statement.mustColumnMap[strings.ToLower(nc)] = true 841 } 842 return statement 843 } 844 845 // UseBool indicates that use bool fields as update contents and query contiditions 846 func (statement *Statement) UseBool(columns ...string) *Statement { 847 if len(columns) > 0 { 848 statement.MustCols(columns...) 849 } else { 850 statement.allUseBool = true 851 } 852 return statement 853 } 854 855 // Omit do not use the columns 856 func (statement *Statement) Omit(columns ...string) { 857 newColumns := col2NewCols(columns...) 858 for _, nc := range newColumns { 859 statement.columnMap[strings.ToLower(nc)] = false 860 } 861 statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) 862 } 863 864 // Nullable Update use only: update columns to null when value is nullable and zero-value 865 func (statement *Statement) Nullable(columns ...string) { 866 newColumns := col2NewCols(columns...) 
867 for _, nc := range newColumns { 868 statement.nullableMap[strings.ToLower(nc)] = true 869 } 870 } 871 872 // Top generate LIMIT limit statement 873 func (statement *Statement) Top(limit int) *Statement { 874 statement.Limit(limit) 875 return statement 876 } 877 878 // Limit generate LIMIT start, limit statement 879 func (statement *Statement) Limit(limit int, start ...int) *Statement { 880 statement.LimitN = limit 881 if len(start) > 0 { 882 statement.Start = start[0] 883 } 884 return statement 885 } 886 887 // OrderBy generate "Order By order" statement 888 func (statement *Statement) OrderBy(order string) *Statement { 889 if len(statement.OrderStr) > 0 { 890 statement.OrderStr += ", " 891 } 892 statement.OrderStr += order 893 return statement 894 } 895 896 // Desc generate `ORDER BY xx DESC` 897 func (statement *Statement) Desc(colNames ...string) *Statement { 898 var buf bytes.Buffer 899 fmt.Fprintf(&buf, statement.OrderStr) 900 if len(statement.OrderStr) > 0 { 901 fmt.Fprint(&buf, ", ") 902 } 903 newColNames := statement.col2NewColsWithQuote(colNames...) 904 fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) 905 statement.OrderStr = buf.String() 906 return statement 907 } 908 909 // Asc provide asc order by query condition, the input parameters are columns. 910 func (statement *Statement) Asc(colNames ...string) *Statement { 911 var buf bytes.Buffer 912 fmt.Fprintf(&buf, statement.OrderStr) 913 if len(statement.OrderStr) > 0 { 914 fmt.Fprint(&buf, ", ") 915 } 916 newColNames := statement.col2NewColsWithQuote(colNames...) 
917 fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) 918 statement.OrderStr = buf.String() 919 return statement 920 } 921 922 // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN 923 func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { 924 var buf bytes.Buffer 925 if len(statement.JoinStr) > 0 { 926 fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) 927 } else { 928 fmt.Fprintf(&buf, "%v JOIN ", joinOP) 929 } 930 931 switch tablename.(type) { 932 case []string: 933 t := tablename.([]string) 934 if len(t) > 1 { 935 fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1])) 936 } else if len(t) == 1 { 937 fmt.Fprintf(&buf, statement.Engine.Quote(t[0])) 938 } 939 case []interface{}: 940 t := tablename.([]interface{}) 941 l := len(t) 942 var table string 943 if l > 0 { 944 f := t[0] 945 v := rValue(f) 946 t := v.Type() 947 if t.Kind() == reflect.String { 948 table = f.(string) 949 } else if t.Kind() == reflect.Struct { 950 table = statement.Engine.tbName(v) 951 } 952 } 953 if l > 1 { 954 fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table), 955 statement.Engine.Quote(fmt.Sprintf("%v", t[1]))) 956 } else if l == 1 { 957 fmt.Fprintf(&buf, statement.Engine.Quote(table)) 958 } 959 default: 960 fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) 961 } 962 963 fmt.Fprintf(&buf, " ON %v", condition) 964 statement.JoinStr = buf.String() 965 statement.joinArgs = append(statement.joinArgs, args...) 
966 return statement 967 } 968 969 // GroupBy generate "Group By keys" statement 970 func (statement *Statement) GroupBy(keys string) *Statement { 971 statement.GroupByStr = keys 972 return statement 973 } 974 975 // Having generate "Having conditions" statement 976 func (statement *Statement) Having(conditions string) *Statement { 977 statement.HavingStr = fmt.Sprintf("HAVING %v", conditions) 978 return statement 979 } 980 981 // Unscoped always disable struct tag "deleted" 982 func (statement *Statement) Unscoped() *Statement { 983 statement.unscoped = true 984 return statement 985 } 986 987 func (statement *Statement) genColumnStr() string { 988 989 var buf bytes.Buffer 990 991 columns := statement.RefTable.Columns() 992 993 for _, col := range columns { 994 995 if statement.OmitStr != "" { 996 if _, ok := getFlagForColumn(statement.columnMap, col); ok { 997 continue 998 } 999 } 1000 1001 if col.MapType == core.ONLYTODB { 1002 continue 1003 } 1004 1005 if buf.Len() != 0 { 1006 buf.WriteString(", ") 1007 } 1008 1009 if col.IsPrimaryKey && statement.Engine.Dialect().DBType() == "ql" { 1010 buf.WriteString("id() AS ") 1011 } 1012 1013 if statement.JoinStr != "" { 1014 if statement.TableAlias != "" { 1015 buf.WriteString(statement.TableAlias) 1016 } else { 1017 buf.WriteString(statement.TableName()) 1018 } 1019 1020 buf.WriteString(".") 1021 } 1022 1023 statement.Engine.QuoteTo(&buf, col.Name) 1024 } 1025 1026 return buf.String() 1027 } 1028 1029 func (statement *Statement) genCreateTableSQL() string { 1030 return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(), 1031 statement.StoreEngine, statement.Charset) 1032 } 1033 1034 func (statement *Statement) genIndexSQL() []string { 1035 var sqls []string 1036 tbName := statement.TableName() 1037 quote := statement.Engine.Quote 1038 for idxName, index := range statement.RefTable.Indexes { 1039 if index.Type == core.IndexType { 1040 sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", 
quote(indexName(tbName, idxName)),
				quote(tbName), quote(strings.Join(index.Cols, quote(","))))
			sqls = append(sqls, sql)
		}
	}
	return sqls
}

// uniqueName builds the name of a unique index in the form
// "UQE_<tableName>_<uqeName>".
func uniqueName(tableName, uqeName string) string {
	return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
}

// genUniqueSQL returns one statement, produced by the dialect's
// CreateIndexSql, for every index of the reference table whose type is
// core.UniqueType.
func (statement *Statement) genUniqueSQL() []string {
	var sqls []string
	tbName := statement.TableName()
	for _, index := range statement.RefTable.Indexes {
		if index.Type == core.UniqueType {
			sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
			sqls = append(sqls, sql)
		}
	}
	return sqls
}

// genDelIndexSQL returns a DROP INDEX statement for every unique or regular
// index of the reference table. When the dialect reports IndexOnTable, the
// statement is suffixed with " ON <table>".
func (statement *Statement) genDelIndexSQL() []string {
	var sqls []string
	tbName := statement.TableName()
	for idxName, index := range statement.RefTable.Indexes {
		var rIdxName string
		switch index.Type {
		case core.UniqueType:
			rIdxName = uniqueName(tbName, idxName)
		case core.IndexType:
			rIdxName = indexName(tbName, idxName)
		default:
			// No name can be derived for this index type; dropping an index
			// with an empty identifier would never be valid SQL.
			continue
		}
		sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(rIdxName))
		if statement.Engine.dialect.IndexOnTable() {
			sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
		}
		sqls = append(sqls, sql)
	}
	return sqls
}

// genAddColumnStr returns an "ALTER TABLE <table> ADD <column definition>;"
// statement for col together with an empty (non-nil) argument list. The
// column definition is rendered by col.String with the engine's dialect.
func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
	quote := statement.Engine.Quote
	sql := fmt.Sprintf("ALTER TABLE %v ADD %v;", quote(statement.TableName()),
		col.String(statement.Engine.dialect))
	return sql, []interface{}{}
}

// buildConds forwards to the package-level buildConds, supplying the
// statement's own engine, allUseBool/useAllCols/unscoped flags, must-column
// map, table name and table alias alongside the caller's options.
func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
	return buildConds(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
		statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
}

// genConds renders the statement's accumulated condition as SQL text plus its
// arguments.
//
// Unless noAutoCondition is set, conditions derived from bean are ANDed into
// statement.cond first; column names are table-qualified when a join is
// present. The primary-key conditions from IdParam are then added via
// processIdParam. Note that statement.cond is modified by this call.
func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
	if !statement.noAutoCondition {
		addedTableName := len(statement.JoinStr) > 0
		autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
		if err != nil {
			return "", nil, err
		}
		statement.cond = statement.cond.And(autoCond)
	}

	statement.processIdParam()

	return builder.ToSQL(statement.cond)
}

// genGetSQL builds the SELECT statement and arguments used to fetch bean.
//
// The selected column list is chosen in this order: the explicit select
// string, the explicit column string, the quoted GROUP BY columns, the
// generated column list (no join), or "*" (with a join). Join arguments
// precede condition arguments in the returned slice.
func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}) {
	statement.setRefValue(rValue(bean))

	columnStr := statement.ColumnStr
	switch {
	case len(statement.selectStr) > 0:
		columnStr = statement.selectStr
	case len(columnStr) > 0:
		// An explicit column list was requested; use it unchanged.
	case len(statement.GroupByStr) > 0:
		columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1))
	case len(statement.JoinStr) == 0:
		columnStr = statement.genColumnStr()
	default:
		// TODO: always generate column names, not use * even if join
		columnStr = "*"
	}

	condSQL, condArgs, err := statement.genConds(bean)
	if err != nil {
		// The signature cannot carry the error, so record it instead of
		// dropping it silently.
		statement.Engine.logger.Error(err)
	}

	return statement.genSelectSQL(columnStr, condSQL), append(statement.joinArgs, condArgs...)
}

// genCountSQL builds the counting SELECT statement and arguments for bean.
//
// An explicit select string is used when present. Otherwise the expression is
// "count(DISTINCT <columns>)" for a distinct statement that names columns and
// "count(*)" in every other case.
func (statement *Statement) genCountSQL(bean interface{}) (string, []interface{}) {
	statement.setRefValue(rValue(bean))

	condSQL, condArgs, err := statement.genConds(bean)
	if err != nil {
		// The signature cannot carry the error, so record it instead of
		// dropping it silently.
		statement.Engine.logger.Error(err)
	}

	selectSQL := statement.selectStr
	if len(selectSQL) == 0 {
		// The distinct form needs a column list; without one fall back to
		// count(*) rather than emitting "count(DISTINCT )".
		if statement.IsDistinct && len(statement.ColumnStr) > 0 {
			selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
		} else {
			selectSQL = "count(*)"
		}
	}
	return statement.genSelectSQL(selectSQL, condSQL), append(statement.joinArgs, condArgs...)
}

// genSumSQL builds a SELECT statement that returns "COALESCE(sum(<col>),0)"
// for each of columns, together with its arguments. Column names are inserted
// into the SQL as given, without quoting.
func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}) {
	statement.setRefValue(rValue(bean))

	var sumStrs = make([]string, 0, len(columns))
	for _, colName := range columns {
		sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
	}

	condSQL, condArgs, err := statement.genConds(bean)
	if err != nil {
		// The signature cannot carry the error, so record it instead of
		// dropping it silently.
		statement.Engine.logger.Error(err)
	}

	return statement.genSelectSQL(strings.Join(sumStrs, ", "), condSQL), append(statement.joinArgs, condArgs...)
}

// genSelectSQL assembles a complete SELECT statement from columnStr, condSQL
// and the statement's table, alias, join, grouping, having, ordering and
// paging settings.
//
// Paging is dialect specific: MSSQL uses TOP plus a "NOT IN (SELECT TOP ...)"
// exclusion on a key column, Oracle wraps the query in ROWNUM filters, and
// every other dialect uses LIMIT/OFFSET. When IsForUpdate is set the result
// is passed through the dialect's ForUpdateSql.
func (statement *Statement) genSelectSQL(columnStr, condSQL string) (a string) {
	var distinct string
	if statement.IsDistinct {
		distinct = "DISTINCT "
	}

	var dialect = statement.Engine.Dialect()
	var quote = statement.Engine.Quote
	var top string
	var mssqlCondi string

	// NOTE(review): condSQL has already been rendered by the caller, so this
	// call only changes statement.cond, not the SQL built below.
	statement.processIdParam()

	var buf bytes.Buffer
	if len(condSQL) > 0 {
		fmt.Fprintf(&buf, " WHERE %v", condSQL)
	}
	var whereStr = buf.String()

	var fromStr = " FROM " + quote(statement.TableName())
	if statement.TableAlias != "" {
		// The Oracle branch omits the AS keyword before the table alias.
		if dialect.DBType() == core.ORACLE {
			fromStr += " " + quote(statement.TableAlias)
		} else {
			fromStr += " AS " + quote(statement.TableAlias)
		}
	}
	if statement.JoinStr != "" {
		fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
	}

	if dialect.DBType() == core.MSSQL {
		if statement.LimitN > 0 {
			top = fmt.Sprintf(" TOP %d ", statement.LimitN)
		}
		if statement.Start > 0 {
			// Pick the column used to exclude the first Start rows: the first
			// primary-key column, else the first single-column index, else
			// the first column of the table.
			var column string
			if len(statement.RefTable.PKColumns()) == 0 {
				for _, index := range statement.RefTable.Indexes {
					if len(index.Cols) == 1 {
						column = index.Cols[0]
						break
					}
				}
				if len(column) == 0 {
					column = statement.RefTable.ColumnsSeq()[0]
				}
			} else {
				column = statement.RefTable.PKColumns()[0].Name
			}
			if statement.needTableName() {
				if len(statement.TableAlias) > 0 {
					column = statement.TableAlias + "." + column
				} else {
					column = statement.TableName() + "." + column
				}
			}

			var orderStr string
			if len(statement.OrderStr) > 0 {
				orderStr = " ORDER BY " + statement.OrderStr
			}
			var groupStr string
			if len(statement.GroupByStr) > 0 {
				groupStr = " GROUP BY " + statement.GroupByStr
			}
			// GROUP BY has to come before ORDER BY inside the subquery.
			mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
				column, statement.Start, column, fromStr, whereStr, groupStr, orderStr)
		}
	}

	// !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern
	// DISTINCT is written before TOP, the order T-SQL requires; top is only
	// ever non-empty for MSSQL, so other dialects are unaffected.
	a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
	if len(mssqlCondi) > 0 {
		if len(whereStr) > 0 {
			a += " AND " + mssqlCondi
		} else {
			a += " WHERE " + mssqlCondi
		}
	}

	if statement.GroupByStr != "" {
		a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr)
	}
	if statement.HavingStr != "" {
		a = fmt.Sprintf("%v %v", a, statement.HavingStr)
	}
	if statement.OrderStr != "" {
		a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr)
	}
	if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
		if statement.Start > 0 {
			a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start)
		} else if statement.LimitN > 0 {
			a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN)
		}
	} else if dialect.DBType() == core.ORACLE {
		if statement.Start != 0 || statement.LimitN != 0 {
			a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start)
		}
	}
	if statement.IsForUpdate {
		a = dialect.ForUpdateSql(a)
	}

	return a
}

// processIdParam ANDs one equality condition per primary-key column of the
// reference table into statement.cond, taking the values from IdParam in
// order. Primary-key columns without a matching value are compared with the
// empty string. It does nothing when IdParam is nil.
func (statement *Statement) processIdParam() {
	if statement.IdParam == nil {
		return
	}

	for i, col := range statement.RefTable.PKColumns() {
		var colName = statement.colName(col, statement.TableName())
		if i < len(*(statement.IdParam)) {
			statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.IdParam))[i]})
		} else {
			statement.cond = statement.cond.And(builder.Eq{colName: ""})
		}
	}
}

// joinColumns returns the quoted names of cols joined by ", ". With
// includeTableName each name is prefixed by the quoted table name and a dot.
func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
	var colnames = make([]string, len(cols))
	for i, col := range cols {
		if includeTableName {
			colnames[i] = statement.Engine.Quote(statement.TableName()) +
				"." + statement.Engine.Quote(col.Name)
		} else {
			colnames[i] = statement.Engine.Quote(col.Name)
		}
	}
	return strings.Join(colnames, ", ")
}

// convertIDSQL rewrites sqlStr into a query that selects only the primary-key
// columns, keeping everything that follows the first " from " separator.
//
// It returns "" when there is no reference table, the table has no
// primary-key columns, or sqlStr cannot be split into exactly two parts.
func (statement *Statement) convertIDSQL(sqlStr string) string {
	if statement.RefTable == nil {
		return ""
	}

	cols := statement.RefTable.PKColumns()
	if len(cols) == 0 {
		return ""
	}

	colstrs := statement.joinColumns(cols, false)
	sqls := splitNNoCase(sqlStr, " from ", 2)
	if len(sqls) != 2 {
		return ""
	}

	return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1])
}

// convertUpdateSQL splits an update statement at its first "where" and returns
// the part before it together with a SELECT of the table-qualified primary key
// that reuses the WHERE clause.
//
// Both results are empty unless the reference table has exactly one primary
// key. When sqlStr has no WHERE clause, the SELECT covers the whole table.
// For Postgres ("$") and MSSQL (":") the placeholders in the WHERE clause are
// renumbered starting from 1.
func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
	if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
		return "", ""
	}

	colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
	sqls := splitNNoCase(sqlStr, "where", 2)
	if len(sqls) != 2 {
		if len(sqls) == 1 {
			return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
				colstrs, statement.Engine.Quote(statement.TableName()))
		}
		return "", ""
	}

	var whereStr = sqls[1]

	//TODO: for postgres only, if any other database?
	var paraStr string
	if statement.Engine.dialect.DBType() == core.POSTGRES {
		paraStr = "$"
	} else if statement.Engine.dialect.DBType() == core.MSSQL {
		paraStr = ":"
	}

	if paraStr != "" && strings.Contains(sqls[1], paraStr) {
		dollers := strings.Split(sqls[1], paraStr)
		whereStr = dollers[0]
		for i, c := range dollers[1:] {
			// c starts with the old placeholder number; whatever follows the
			// first space is kept after the new number.
			ccs := strings.SplitN(c, " ", 2)
			whereStr += fmt.Sprintf("%v%v", paraStr, i+1)
			// A placeholder that ends the clause has nothing after it, so
			// SplitN yields a single element.
			if len(ccs) == 2 {
				whereStr += " " + ccs[1]
			}
		}
	}

	return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
		colstrs, statement.Engine.Quote(statement.TableName()),
		whereStr)
}