github.com/systematiccaos/gorm@v1.22.6/statement.go (about) 1 package gorm 2 3 import ( 4 "context" 5 "database/sql" 6 "database/sql/driver" 7 "fmt" 8 "reflect" 9 "regexp" 10 "sort" 11 "strconv" 12 "strings" 13 "sync" 14 15 "github.com/systematiccaos/gorm/clause" 16 "github.com/systematiccaos/gorm/logger" 17 "github.com/systematiccaos/gorm/schema" 18 "github.com/systematiccaos/gorm/utils" 19 ) 20 21 // Statement statement 22 type Statement struct { 23 *DB 24 TableExpr *clause.Expr 25 Table string 26 Model interface{} 27 Unscoped bool 28 Dest interface{} 29 ReflectValue reflect.Value 30 Clauses map[string]clause.Clause 31 BuildClauses []string 32 Distinct bool 33 Selects []string // selected columns 34 Omits []string // omit columns 35 Joins []join 36 Preloads map[string][]interface{} 37 Settings sync.Map 38 ConnPool ConnPool 39 Schema *schema.Schema 40 Context context.Context 41 RaiseErrorOnNotFound bool 42 SkipHooks bool 43 SQL strings.Builder 44 Vars []interface{} 45 CurDestIndex int 46 attrs []interface{} 47 assigns []interface{} 48 scopes []func(*DB) *DB 49 } 50 51 type join struct { 52 Name string 53 Conds []interface{} 54 On *clause.Where 55 } 56 57 // StatementModifier statement modifier interface 58 type StatementModifier interface { 59 ModifyStatement(*Statement) 60 } 61 62 // WriteString write string 63 func (stmt *Statement) WriteString(str string) (int, error) { 64 return stmt.SQL.WriteString(str) 65 } 66 67 // WriteByte write byte 68 func (stmt *Statement) WriteByte(c byte) error { 69 return stmt.SQL.WriteByte(c) 70 } 71 72 // WriteQuoted write quoted value 73 func (stmt *Statement) WriteQuoted(value interface{}) { 74 stmt.QuoteTo(&stmt.SQL, value) 75 } 76 77 // QuoteTo write quoted value to writer 78 func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { 79 write := func(raw bool, str string) { 80 if raw { 81 writer.WriteString(str) 82 } else { 83 stmt.DB.Dialector.QuoteTo(writer, str) 84 } 85 } 86 87 switch v := field.(type) { 88 case clause.Table: 89 if v.Name == clause.CurrentTable { 90 if stmt.TableExpr != nil { 91 stmt.TableExpr.Build(stmt) 92 } else { 93 write(v.Raw, stmt.Table) 94 } 95 } else { 96 write(v.Raw, v.Name) 97 } 98 99 if v.Alias != "" { 100 writer.WriteByte(' ') 101 write(v.Raw, v.Alias) 102 } 103 case clause.Column: 104 if v.Table != "" { 105 if v.Table == clause.CurrentTable { 106 write(v.Raw, stmt.Table) 107 } else { 108 write(v.Raw, v.Table) 109 } 110 writer.WriteByte('.') 111 } 112 113 if v.Name == clause.PrimaryKey { 114 if stmt.Schema == nil { 115 stmt.DB.AddError(ErrModelValueRequired) 116 } else if stmt.Schema.PrioritizedPrimaryField != nil { 117 write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) 118 } else if len(stmt.Schema.DBNames) > 0 { 119 write(v.Raw, stmt.Schema.DBNames[0]) 120 } 121 } else { 122 write(v.Raw, v.Name) 123 } 124 125 if v.Alias != "" { 126 writer.WriteString(" AS ") 127 write(v.Raw, v.Alias) 128 } 129 case []clause.Column: 130 writer.WriteByte('(') 131 for idx, d := range v { 132 if idx > 0 { 133 writer.WriteString(",") 134 } 135 stmt.QuoteTo(writer, d) 136 } 137 writer.WriteByte(')') 138 case clause.Expr: 139 v.Build(stmt) 140 case string: 141 stmt.DB.Dialector.QuoteTo(writer, v) 142 case []string: 143 writer.WriteByte('(') 144 for idx, d := range v { 145 if idx > 0 { 146 writer.WriteString(",") 147 } 148 stmt.DB.Dialector.QuoteTo(writer, d) 149 } 150 writer.WriteByte(')') 151 default: 152 stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) 153 } 154 } 155 156 // Quote returns quoted value 157 func (stmt *Statement) Quote(field interface{}) string { 158 var builder strings.Builder 159 stmt.QuoteTo(&builder, field) 160 return builder.String() 161 } 162 163 // AddVar add var 164 func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { 165 for idx, v := range vars { 166 if idx > 0 { 167 writer.WriteByte(',') 168 } 169 170 switch v := v.(type) { 171 case sql.NamedArg: 172 stmt.Vars = append(stmt.Vars, v.Value) 173 case clause.Column, clause.Table: 174 stmt.QuoteTo(writer, v) 175 case Valuer: 176 reflectValue := reflect.ValueOf(v) 177 if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() { 178 stmt.AddVar(writer, nil) 179 } else { 180 stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) 181 } 182 case clause.Expr: 183 v.Build(stmt) 184 case *clause.Expr: 185 v.Build(stmt) 186 case driver.Valuer: 187 stmt.Vars = append(stmt.Vars, v) 188 stmt.DB.Dialector.BindVarTo(writer, stmt, v) 189 case []byte: 190 stmt.Vars = append(stmt.Vars, v) 191 stmt.DB.Dialector.BindVarTo(writer, stmt, v) 192 case []interface{}: 193 if len(v) > 0 { 194 writer.WriteByte('(') 195 stmt.AddVar(writer, v...) 196 writer.WriteByte(')') 197 } else { 198 writer.WriteString("(NULL)") 199 } 200 case *DB: 201 subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() 202 if v.Statement.SQL.Len() > 0 { 203 var ( 204 vars = subdb.Statement.Vars 205 sql = v.Statement.SQL.String() 206 ) 207 208 subdb.Statement.Vars = make([]interface{}, 0, len(vars)) 209 for _, vv := range vars { 210 subdb.Statement.Vars = append(subdb.Statement.Vars, vv) 211 bindvar := strings.Builder{} 212 v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv) 213 sql = strings.Replace(sql, bindvar.String(), "?", 1) 214 } 215 216 subdb.Statement.SQL.Reset() 217 subdb.Statement.Vars = stmt.Vars 218 if strings.Contains(sql, "@") { 219 clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) 220 } else { 221 clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) 222 } 223 } else { 224 subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) 225 subdb.callbacks.Query().Execute(subdb) 226 } 227 228 writer.WriteString(subdb.Statement.SQL.String()) 229 stmt.Vars = subdb.Statement.Vars 230 default: 231 switch rv := reflect.ValueOf(v); rv.Kind() { 232 case reflect.Slice, reflect.Array: 233 if rv.Len() == 0 { 234 writer.WriteString("(NULL)") 235 } else { 236 writer.WriteByte('(') 237 for i := 0; i < rv.Len(); i++ { 238 if i > 0 { 239 writer.WriteByte(',') 240 } 241 stmt.AddVar(writer, rv.Index(i).Interface()) 242 } 243 writer.WriteByte(')') 244 } 245 default: 246 stmt.Vars = append(stmt.Vars, v) 247 stmt.DB.Dialector.BindVarTo(writer, stmt, v) 248 } 249 } 250 } 251 } 252 253 // AddClause add clause 254 func (stmt *Statement) AddClause(v clause.Interface) { 255 if optimizer, ok := v.(StatementModifier); ok { 256 optimizer.ModifyStatement(stmt) 257 } else { 258 name := v.Name() 259 c := stmt.Clauses[name] 260 c.Name = name 261 v.MergeClause(&c) 262 stmt.Clauses[name] = c 263 } 264 } 265 266 // AddClauseIfNotExists add clause if not exists 267 func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { 268 if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil { 269 stmt.AddClause(v) 270 } 271 } 272 273 // BuildCondition build condition 274 func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression { 275 if s, ok := query.(string); ok { 276 // if it is a number, then treats it as primary key 277 if _, err := strconv.Atoi(s); err != nil { 278 if s == "" && len(args) == 0 { 279 return nil 280 } 281 282 if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { 283 // looks like a where condition 284 return []clause.Expression{clause.Expr{SQL: s, Vars: args}} 285 } 286 287 if len(args) > 0 && strings.Contains(s, "@") { 288 // looks like a named query 289 return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} 290 } 291 292 if strings.Contains(strings.TrimSpace(s), " ") { 293 // looks like a where condition 294 return []clause.Expression{clause.Expr{SQL: s, Vars: args}} 295 } 296 297 if len(args) == 1 { 298 return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} 299 } 300 } 301 } 302 303 conds := make([]clause.Expression, 0, 4) 304 args = append([]interface{}{query}, args...) 305 for idx, arg := range args { 306 if valuer, ok := arg.(driver.Valuer); ok { 307 arg, _ = valuer.Value() 308 } 309 310 switch v := arg.(type) { 311 case clause.Expression: 312 conds = append(conds, v) 313 case *DB: 314 if cs, ok := v.Statement.Clauses["WHERE"]; ok { 315 if where, ok := cs.Expression.(clause.Where); ok { 316 if len(where.Exprs) == 1 { 317 if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { 318 where.Exprs[0] = clause.AndConditions(orConds) 319 } 320 } 321 conds = append(conds, clause.And(where.Exprs...)) 322 } else if cs.Expression != nil { 323 conds = append(conds, cs.Expression) 324 } 325 } 326 case map[interface{}]interface{}: 327 for i, j := range v { 328 conds = append(conds, clause.Eq{Column: i, Value: j}) 329 } 330 case map[string]string: 331 var keys = make([]string, 0, len(v)) 332 for i := range v { 333 keys = append(keys, i) 334 } 335 sort.Strings(keys) 336 337 for _, key := range keys { 338 conds = append(conds, clause.Eq{Column: key, Value: v[key]}) 339 } 340 case map[string]interface{}: 341 var keys = make([]string, 0, len(v)) 342 for i := range v { 343 keys = append(keys, i) 344 } 345 sort.Strings(keys) 346 347 for _, key := range keys { 348 reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) 349 switch reflectValue.Kind() { 350 case reflect.Slice, reflect.Array: 351 if _, ok := v[key].(driver.Valuer); ok { 352 conds = append(conds, clause.Eq{Column: key, Value: v[key]}) 353 } else if _, ok := v[key].(Valuer); ok { 354 conds = append(conds, clause.Eq{Column: key, Value: v[key]}) 355 } else { 356 // optimize reflect value length 357 valueLen := reflectValue.Len() 358 values := make([]interface{}, valueLen) 359 for i := 0; i < valueLen; i++ { 360 values[i] = reflectValue.Index(i).Interface() 361 } 362 363 conds = append(conds, clause.IN{Column: key, Values: values}) 364 } 365 default: 366 conds = append(conds, clause.Eq{Column: key, Value: v[key]}) 367 } 368 } 369 default: 370 reflectValue := reflect.Indirect(reflect.ValueOf(arg)) 371 for reflectValue.Kind() == reflect.Ptr { 372 reflectValue = reflectValue.Elem() 373 } 374 375 if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { 376 selectedColumns := map[string]bool{} 377 if idx == 0 { 378 for _, v := range args[1:] { 379 if vs, ok := v.(string); ok { 380 selectedColumns[vs] = true 381 } 382 } 383 } 384 restricted := len(selectedColumns) != 0 385 386 switch reflectValue.Kind() { 387 case reflect.Struct: 388 for _, field := range s.Fields { 389 selected := selectedColumns[field.DBName] || selectedColumns[field.Name] 390 if selected || (!restricted && field.Readable) { 391 if v, isZero := field.ValueOf(reflectValue); !isZero || selected { 392 if field.DBName != "" { 393 conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) 394 } else if field.DataType != "" { 395 conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) 396 } 397 } 398 } 399 } 400 case reflect.Slice, reflect.Array: 401 for i := 0; i < reflectValue.Len(); i++ { 402 for _, field := range s.Fields { 403 selected := selectedColumns[field.DBName] || selectedColumns[field.Name] 404 if selected || (!restricted && field.Readable) { 405 if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected { 406 if field.DBName != "" { 407 conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v}) 408 } else if field.DataType != "" { 409 conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v}) 410 } 411 } 412 } 413 } 414 } 415 } 416 417 if restricted { 418 break 419 } 420 } else if !reflectValue.IsValid() { 421 stmt.AddError(ErrInvalidData) 422 } else if len(conds) == 0 { 423 if len(args) == 1 { 424 switch reflectValue.Kind() { 425 case reflect.Slice, reflect.Array: 426 // optimize reflect value length 427 valueLen := reflectValue.Len() 428 values := make([]interface{}, valueLen) 429 for i := 0; i < valueLen; i++ { 430 values[i] = reflectValue.Index(i).Interface() 431 } 432 433 if len(values) > 0 { 434 conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values}) 435 } 436 return conds 437 } 438 } 439 440 conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args}) 441 } 442 } 443 } 444 445 return conds 446 } 447 448 // Build build sql with clauses names 449 func (stmt *Statement) Build(clauses ...string) { 450 var firstClauseWritten bool 451 452 for _, name := range clauses { 453 if c, ok := stmt.Clauses[name]; ok { 454 if firstClauseWritten { 455 stmt.WriteByte(' ') 456 } 457 458 firstClauseWritten = true 459 if b, ok := stmt.DB.ClauseBuilders[name]; ok { 460 b(c, stmt) 461 } else { 462 c.Build(stmt) 463 } 464 } 465 } 466 } 467 468 func (stmt *Statement) Parse(value interface{}) (err error) { 469 return stmt.ParseWithSpecialTableName(value, "") 470 } 471 472 func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { 473 if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { 474 if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { 475 stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} 476 stmt.Table = tables[1] 477 return 478 } 479 480 stmt.Table = stmt.Schema.Table 481 } 482 return err 483 } 484 485 func (stmt *Statement) clone() *Statement { 486 newStmt := &Statement{ 487 TableExpr: stmt.TableExpr, 488 Table: stmt.Table, 489 Model: stmt.Model, 490 Unscoped: stmt.Unscoped, 491 Dest: stmt.Dest, 492 ReflectValue: stmt.ReflectValue, 493 Clauses: map[string]clause.Clause{}, 494 Distinct: stmt.Distinct, 495 Selects: stmt.Selects, 496 Omits: stmt.Omits, 497 Preloads: map[string][]interface{}{}, 498 ConnPool: stmt.ConnPool, 499 Schema: stmt.Schema, 500 Context: stmt.Context, 501 RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, 502 SkipHooks: stmt.SkipHooks, 503 } 504 505 if stmt.SQL.Len() > 0 { 506 newStmt.SQL.WriteString(stmt.SQL.String()) 507 newStmt.Vars = make([]interface{}, 0, len(stmt.Vars)) 508 newStmt.Vars = append(newStmt.Vars, stmt.Vars...) 509 } 510 511 for k, c := range stmt.Clauses { 512 newStmt.Clauses[k] = c 513 } 514 515 for k, p := range stmt.Preloads { 516 newStmt.Preloads[k] = p 517 } 518 519 if len(stmt.Joins) > 0 { 520 newStmt.Joins = make([]join, len(stmt.Joins)) 521 copy(newStmt.Joins, stmt.Joins) 522 } 523 524 if len(stmt.scopes) > 0 { 525 newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes)) 526 copy(newStmt.scopes, stmt.scopes) 527 } 528 529 stmt.Settings.Range(func(k, v interface{}) bool { 530 newStmt.Settings.Store(k, v) 531 return true 532 }) 533 534 return newStmt 535 } 536 537 // SetColumn set column's value 538 // stmt.SetColumn("Name", "jinzhu") // Hooks Method 539 // stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method 540 func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { 541 if v, ok := stmt.Dest.(map[string]interface{}); ok { 542 v[name] = value 543 } else if v, ok := stmt.Dest.([]map[string]interface{}); ok { 544 for _, m := range v { 545 m[name] = value 546 } 547 } else if stmt.Schema != nil { 548 if field := stmt.Schema.LookUpField(name); field != nil { 549 destValue := reflect.ValueOf(stmt.Dest) 550 for destValue.Kind() == reflect.Ptr { 551 destValue = destValue.Elem() 552 } 553 554 if stmt.ReflectValue != destValue { 555 if !destValue.CanAddr() { 556 destValueCanAddr := reflect.New(destValue.Type()) 557 destValueCanAddr.Elem().Set(destValue) 558 stmt.Dest = destValueCanAddr.Interface() 559 destValue = destValueCanAddr.Elem() 560 } 561 562 switch destValue.Kind() { 563 case reflect.Struct: 564 field.Set(destValue, value) 565 default: 566 stmt.AddError(ErrInvalidData) 567 } 568 } 569 570 switch stmt.ReflectValue.Kind() { 571 case reflect.Slice, reflect.Array: 572 if len(fromCallbacks) > 0 { 573 for i := 0; i < stmt.ReflectValue.Len(); i++ { 574 field.Set(stmt.ReflectValue.Index(i), value) 575 } 576 } else { 577 field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value) 578 } 579 case reflect.Struct: 580 if !stmt.ReflectValue.CanAddr() { 581 stmt.AddError(ErrInvalidValue) 582 return 583 } 584 585 field.Set(stmt.ReflectValue, value) 586 } 587 } else { 588 stmt.AddError(ErrInvalidField) 589 } 590 } else { 591 stmt.AddError(ErrInvalidField) 592 } 593 } 594 595 // Changed check model changed or not when updating 596 func (stmt *Statement) Changed(fields ...string) bool { 597 modelValue := stmt.ReflectValue 598 switch modelValue.Kind() { 599 case reflect.Slice, reflect.Array: 600 modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) 601 } 602 603 selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) 604 changed := func(field *schema.Field) bool { 605 fieldValue, _ := field.ValueOf(modelValue) 606 if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { 607 if v, ok := stmt.Dest.(map[string]interface{}); ok { 608 if fv, ok := v[field.Name]; ok { 609 return !utils.AssertEqual(fv, fieldValue) 610 } else if fv, ok := v[field.DBName]; ok { 611 return !utils.AssertEqual(fv, fieldValue) 612 } 613 } else { 614 destValue := reflect.ValueOf(stmt.Dest) 615 for destValue.Kind() == reflect.Ptr { 616 destValue = destValue.Elem() 617 } 618 619 changedValue, zero := field.ValueOf(destValue) 620 return !zero && !utils.AssertEqual(changedValue, fieldValue) 621 } 622 } 623 return false 624 } 625 626 if len(fields) == 0 { 627 for _, field := range stmt.Schema.FieldsByDBName { 628 if changed(field) { 629 return true 630 } 631 } 632 } else { 633 for _, name := range fields { 634 if field := stmt.Schema.LookUpField(name); field != nil { 635 if changed(field) { 636 return true 637 } 638 } 639 } 640 } 641 642 return false 643 } 644 645 var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`) 646 647 // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false 648 func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { 649 results := map[string]bool{} 650 notRestricted := false 651 652 // select columns 653 for _, column := range stmt.Selects { 654 if stmt.Schema == nil { 655 results[column] = true 656 } else if column == "*" { 657 notRestricted = true 658 for _, dbName := range stmt.Schema.DBNames { 659 results[dbName] = true 660 } 661 } else if column == clause.Associations { 662 for _, rel := range stmt.Schema.Relationships.Relations { 663 results[rel.Name] = true 664 } 665 } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { 666 results[field.DBName] = true 667 } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 { 668 results[matches[1]] = true 669 } else { 670 results[column] = true 671 } 672 } 673 674 // omit columns 675 for _, omit := range stmt.Omits { 676 if stmt.Schema == nil { 677 results[omit] = false 678 } else if omit == "*" { 679 for _, dbName := range stmt.Schema.DBNames { 680 results[dbName] = false 681 } 682 } else if omit == clause.Associations { 683 for _, rel := range stmt.Schema.Relationships.Relations { 684 results[rel.Name] = false 685 } 686 } else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" { 687 results[field.DBName] = false 688 } else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 { 689 results[matches[1]] = false 690 } else { 691 results[omit] = false 692 } 693 } 694 695 if stmt.Schema != nil { 696 for _, field := range stmt.Schema.FieldsByName { 697 name := field.DBName 698 if name == "" { 699 name = field.Name 700 } 701 702 if requireCreate && !field.Creatable { 703 results[name] = false 704 } else if requireUpdate && !field.Updatable { 705 results[name] = false 706 } 707 } 708 } 709 710 return results, !notRestricted && len(stmt.Selects) > 0 711 }