github.com/systematiccaos/gorm@v1.22.6/finisher_api.go (about) 1 package gorm 2 3 import ( 4 "database/sql" 5 "errors" 6 "fmt" 7 "reflect" 8 "strings" 9 10 "github.com/systematiccaos/gorm/clause" 11 "github.com/systematiccaos/gorm/logger" 12 "github.com/systematiccaos/gorm/schema" 13 "github.com/systematiccaos/gorm/utils" 14 ) 15 16 // Create insert the value into database 17 func (db *DB) Create(value interface{}) (tx *DB) { 18 if db.CreateBatchSize > 0 { 19 return db.CreateInBatches(value, db.CreateBatchSize) 20 } 21 22 tx = db.getInstance() 23 tx.Statement.Dest = value 24 return tx.callbacks.Create().Execute(tx) 25 } 26 27 // CreateInBatches insert the value in batches into database 28 func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { 29 reflectValue := reflect.Indirect(reflect.ValueOf(value)) 30 31 switch reflectValue.Kind() { 32 case reflect.Slice, reflect.Array: 33 var rowsAffected int64 34 tx = db.getInstance() 35 36 callFc := func(tx *DB) error { 37 // the reflection length judgment of the optimized value 38 reflectLen := reflectValue.Len() 39 for i := 0; i < reflectLen; i += batchSize { 40 ends := i + batchSize 41 if ends > reflectLen { 42 ends = reflectLen 43 } 44 45 subtx := tx.getInstance() 46 subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface() 47 subtx.callbacks.Create().Execute(subtx) 48 if subtx.Error != nil { 49 return subtx.Error 50 } 51 rowsAffected += subtx.RowsAffected 52 } 53 return nil 54 } 55 56 if tx.SkipDefaultTransaction { 57 tx.AddError(callFc(tx.Session(&Session{}))) 58 } else { 59 tx.AddError(tx.Transaction(callFc)) 60 } 61 62 tx.RowsAffected = rowsAffected 63 default: 64 tx = db.getInstance() 65 tx.Statement.Dest = value 66 tx = tx.callbacks.Create().Execute(tx) 67 } 68 return 69 } 70 71 // Save update value in database, if the value doesn't have primary key, will insert it 72 func (db *DB) Save(value interface{}) (tx *DB) { 73 tx = db.getInstance() 74 tx.Statement.Dest = value 75 76 reflectValue := reflect.Indirect(reflect.ValueOf(value)) 77 switch reflectValue.Kind() { 78 case reflect.Slice, reflect.Array: 79 if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { 80 tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) 81 } 82 tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true)) 83 case reflect.Struct: 84 if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { 85 for _, pf := range tx.Statement.Schema.PrimaryFields { 86 if _, isZero := pf.ValueOf(reflectValue); isZero { 87 return tx.callbacks.Create().Execute(tx) 88 } 89 } 90 } 91 92 fallthrough 93 default: 94 selectedUpdate := len(tx.Statement.Selects) != 0 95 // when updating, use all fields including those zero-value fields 96 if !selectedUpdate { 97 tx.Statement.Selects = append(tx.Statement.Selects, "*") 98 } 99 100 tx = tx.callbacks.Update().Execute(tx) 101 102 if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { 103 result := reflect.New(tx.Statement.Schema.ModelType).Interface() 104 if err := tx.Session(&Session{}).Take(result).Error; errors.Is(err, ErrRecordNotFound) { 105 return tx.Create(value) 106 } 107 } 108 } 109 110 return 111 } 112 113 // First find first record that match given conditions, order by primary key 114 func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { 115 tx = db.Limit(1).Order(clause.OrderByColumn{ 116 Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, 117 }) 118 if len(conds) > 0 { 119 if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { 120 tx.Statement.AddClause(clause.Where{Exprs: exprs}) 121 } 122 } 123 tx.Statement.RaiseErrorOnNotFound = true 124 tx.Statement.Dest = dest 125 return tx.callbacks.Query().Execute(tx) 126 } 127 128 // Take return a record that match given conditions, the order will depend on the database implementation 129 func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { 130 tx = db.Limit(1) 131 if len(conds) > 0 { 132 if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { 133 tx.Statement.AddClause(clause.Where{Exprs: exprs}) 134 } 135 } 136 tx.Statement.RaiseErrorOnNotFound = true 137 tx.Statement.Dest = dest 138 return tx.callbacks.Query().Execute(tx) 139 } 140 141 // Last find last record that match given conditions, order by primary key 142 func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { 143 tx = db.Limit(1).Order(clause.OrderByColumn{ 144 Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, 145 Desc: true, 146 }) 147 if len(conds) > 0 { 148 if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { 149 tx.Statement.AddClause(clause.Where{Exprs: exprs}) 150 } 151 } 152 tx.Statement.RaiseErrorOnNotFound = true 153 tx.Statement.Dest = dest 154 return tx.callbacks.Query().Execute(tx) 155 } 156 157 // Find find records that match given conditions 158 func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { 159 tx = db.getInstance() 160 if len(conds) > 0 { 161 if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { 162 tx.Statement.AddClause(clause.Where{Exprs: exprs}) 163 } 164 } 165 tx.Statement.Dest = dest 166 return tx.callbacks.Query().Execute(tx) 167 } 168 169 // FindInBatches find records in batches 170 func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { 171 var ( 172 tx = db.Order(clause.OrderByColumn{ 173 Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, 174 }).Session(&Session{}) 175 queryDB = tx 176 rowsAffected int64 177 batch int 178 ) 179 180 for { 181 result := queryDB.Limit(batchSize).Find(dest) 182 rowsAffected += result.RowsAffected 183 batch++ 184 185 if result.Error == nil && result.RowsAffected != 0 { 186 tx.AddError(fc(result, batch)) 187 } else if result.Error != nil { 188 tx.AddError(result.Error) 189 } 190 191 if tx.Error != nil || int(result.RowsAffected) < batchSize { 192 break 193 } 194 195 // Optimize for-break 196 resultsValue := reflect.Indirect(reflect.ValueOf(dest)) 197 if result.Statement.Schema.PrioritizedPrimaryField == nil { 198 tx.AddError(ErrPrimaryKeyRequired) 199 break 200 } 201 202 primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) 203 queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) 204 } 205 206 tx.RowsAffected = rowsAffected 207 return tx 208 } 209 210 func (tx *DB) assignInterfacesToValue(values ...interface{}) { 211 for _, value := range values { 212 switch v := value.(type) { 213 case []clause.Expression: 214 for _, expr := range v { 215 if eq, ok := expr.(clause.Eq); ok { 216 switch column := eq.Column.(type) { 217 case string: 218 if field := tx.Statement.Schema.LookUpField(column); field != nil { 219 tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) 220 } 221 case clause.Column: 222 if field := tx.Statement.Schema.LookUpField(column.Name); field != nil { 223 tx.AddError(field.Set(tx.Statement.ReflectValue, eq.Value)) 224 } 225 } 226 } else if andCond, ok := expr.(clause.AndConditions); ok { 227 tx.assignInterfacesToValue(andCond.Exprs) 228 } 229 } 230 case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: 231 if exprs := tx.Statement.BuildCondition(value); len(exprs) > 0 { 232 tx.assignInterfacesToValue(exprs) 233 } 234 default: 235 if s, err := schema.Parse(value, tx.cacheStore, tx.NamingStrategy); err == nil { 236 reflectValue := reflect.Indirect(reflect.ValueOf(value)) 237 switch reflectValue.Kind() { 238 case reflect.Struct: 239 for _, f := range s.Fields { 240 if f.Readable { 241 if v, isZero := f.ValueOf(reflectValue); !isZero { 242 if field := tx.Statement.Schema.LookUpField(f.Name); field != nil { 243 tx.AddError(field.Set(tx.Statement.ReflectValue, v)) 244 } 245 } 246 } 247 } 248 } 249 } else if len(values) > 0 { 250 if exprs := tx.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { 251 tx.assignInterfacesToValue(exprs) 252 } 253 return 254 } 255 } 256 } 257 } 258 259 func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { 260 queryTx := db.Limit(1).Order(clause.OrderByColumn{ 261 Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, 262 }) 263 264 if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { 265 if c, ok := tx.Statement.Clauses["WHERE"]; ok { 266 if where, ok := c.Expression.(clause.Where); ok { 267 tx.assignInterfacesToValue(where.Exprs) 268 } 269 } 270 271 // initialize with attrs, conds 272 if len(tx.Statement.attrs) > 0 { 273 tx.assignInterfacesToValue(tx.Statement.attrs...) 274 } 275 } 276 277 // initialize with attrs, conds 278 if len(tx.Statement.assigns) > 0 { 279 tx.assignInterfacesToValue(tx.Statement.assigns...) 280 } 281 return 282 } 283 284 func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { 285 queryTx := db.Limit(1).Order(clause.OrderByColumn{ 286 Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, 287 }) 288 if tx = queryTx.Find(dest, conds...); tx.Error == nil { 289 if tx.RowsAffected == 0 { 290 if c, ok := tx.Statement.Clauses["WHERE"]; ok { 291 if where, ok := c.Expression.(clause.Where); ok { 292 tx.assignInterfacesToValue(where.Exprs) 293 } 294 } 295 296 // initialize with attrs, conds 297 if len(tx.Statement.attrs) > 0 { 298 tx.assignInterfacesToValue(tx.Statement.attrs...) 299 } 300 301 // initialize with attrs, conds 302 if len(tx.Statement.assigns) > 0 { 303 tx.assignInterfacesToValue(tx.Statement.assigns...) 304 } 305 306 return tx.Create(dest) 307 } else if len(db.Statement.assigns) > 0 { 308 exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) 309 assigns := map[string]interface{}{} 310 for _, expr := range exprs { 311 if eq, ok := expr.(clause.Eq); ok { 312 switch column := eq.Column.(type) { 313 case string: 314 assigns[column] = eq.Value 315 case clause.Column: 316 assigns[column.Name] = eq.Value 317 default: 318 } 319 } 320 } 321 322 return tx.Model(dest).Updates(assigns) 323 } 324 } 325 return tx 326 } 327 328 // Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields 329 func (db *DB) Update(column string, value interface{}) (tx *DB) { 330 tx = db.getInstance() 331 tx.Statement.Dest = map[string]interface{}{column: value} 332 return tx.callbacks.Update().Execute(tx) 333 } 334 335 // Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields 336 func (db *DB) Updates(values interface{}) (tx *DB) { 337 tx = db.getInstance() 338 tx.Statement.Dest = values 339 return tx.callbacks.Update().Execute(tx) 340 } 341 342 func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { 343 tx = db.getInstance() 344 tx.Statement.Dest = map[string]interface{}{column: value} 345 tx.Statement.SkipHooks = true 346 return tx.callbacks.Update().Execute(tx) 347 } 348 349 func (db *DB) UpdateColumns(values interface{}) (tx *DB) { 350 tx = db.getInstance() 351 tx.Statement.Dest = values 352 tx.Statement.SkipHooks = true 353 return tx.callbacks.Update().Execute(tx) 354 } 355 356 // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition 357 func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { 358 tx = db.getInstance() 359 if len(conds) > 0 { 360 if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { 361 tx.Statement.AddClause(clause.Where{Exprs: exprs}) 362 } 363 } 364 tx.Statement.Dest = value 365 return tx.callbacks.Delete().Execute(tx) 366 } 367 368 func (db *DB) Count(count *int64) (tx *DB) { 369 tx = db.getInstance() 370 if tx.Statement.Model == nil { 371 tx.Statement.Model = tx.Statement.Dest 372 defer func() { 373 tx.Statement.Model = nil 374 }() 375 } 376 377 if selectClause, ok := db.Statement.Clauses["SELECT"]; ok { 378 defer func() { 379 tx.Statement.Clauses["SELECT"] = selectClause 380 }() 381 } else { 382 defer delete(tx.Statement.Clauses, "SELECT") 383 } 384 385 if len(tx.Statement.Selects) == 0 { 386 tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}}) 387 } else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") { 388 expr := clause.Expr{SQL: "count(*)"} 389 390 if len(tx.Statement.Selects) == 1 { 391 dbName := tx.Statement.Selects[0] 392 fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar) 393 if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) { 394 if tx.Statement.Parse(tx.Statement.Model) == nil { 395 if f := tx.Statement.Schema.LookUpField(dbName); f != nil { 396 dbName = f.DBName 397 } 398 } 399 400 if tx.Statement.Distinct { 401 expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}} 402 } else if dbName != "*" { 403 expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}} 404 } 405 } 406 } 407 408 tx.Statement.AddClause(clause.Select{Expression: expr}) 409 } 410 411 if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { 412 if _, ok := db.Statement.Clauses["GROUP BY"]; !ok { 413 delete(tx.Statement.Clauses, "ORDER BY") 414 defer func() { 415 tx.Statement.Clauses["ORDER BY"] = orderByClause 416 }() 417 } 418 } 419 420 tx.Statement.Dest = count 421 tx = tx.callbacks.Query().Execute(tx) 422 423 if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { 424 *count = tx.RowsAffected 425 } 426 427 return 428 } 429 430 func (db *DB) Row() *sql.Row { 431 tx := db.getInstance().Set("rows", false) 432 tx = tx.callbacks.Row().Execute(tx) 433 row, ok := tx.Statement.Dest.(*sql.Row) 434 if !ok && tx.DryRun { 435 db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error()) 436 } 437 return row 438 } 439 440 func (db *DB) Rows() (*sql.Rows, error) { 441 tx := db.getInstance().Set("rows", true) 442 tx = tx.callbacks.Row().Execute(tx) 443 rows, ok := tx.Statement.Dest.(*sql.Rows) 444 if !ok && tx.DryRun && tx.Error == nil { 445 tx.Error = ErrDryRunModeUnsupported 446 } 447 return rows, tx.Error 448 } 449 450 // Scan scan value to a struct 451 func (db *DB) Scan(dest interface{}) (tx *DB) { 452 config := *db.Config 453 currentLogger, newLogger := config.Logger, logger.Recorder.New() 454 config.Logger = newLogger 455 456 tx = db.getInstance() 457 tx.Config = &config 458 459 if rows, err := tx.Rows(); err == nil { 460 if rows.Next() { 461 tx.ScanRows(rows, dest) 462 } else { 463 tx.RowsAffected = 0 464 } 465 tx.AddError(rows.Close()) 466 } 467 468 currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) { 469 return newLogger.SQL, tx.RowsAffected 470 }, tx.Error) 471 tx.Logger = currentLogger 472 return 473 } 474 475 // Pluck used to query single column from a model as a map 476 // var ages []int64 477 // db.Model(&users).Pluck("age", &ages) 478 func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { 479 tx = db.getInstance() 480 if tx.Statement.Model != nil { 481 if tx.Statement.Parse(tx.Statement.Model) == nil { 482 if f := tx.Statement.Schema.LookUpField(column); f != nil { 483 column = f.DBName 484 } 485 } 486 } 487 488 if len(tx.Statement.Selects) != 1 { 489 fields := strings.FieldsFunc(column, utils.IsValidDBNameChar) 490 tx.Statement.AddClauseIfNotExists(clause.Select{ 491 Distinct: tx.Statement.Distinct, 492 Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, 493 }) 494 } 495 tx.Statement.Dest = dest 496 return tx.callbacks.Query().Execute(tx) 497 } 498 499 func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { 500 tx := db.getInstance() 501 if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) { 502 tx.AddError(err) 503 } 504 tx.Statement.Dest = dest 505 tx.Statement.ReflectValue = reflect.ValueOf(dest) 506 for tx.Statement.ReflectValue.Kind() == reflect.Ptr { 507 elem := tx.Statement.ReflectValue.Elem() 508 if !elem.IsValid() { 509 elem = reflect.New(tx.Statement.ReflectValue.Type().Elem()) 510 tx.Statement.ReflectValue.Set(elem) 511 } 512 tx.Statement.ReflectValue = elem 513 } 514 Scan(rows, tx, ScanInitialized) 515 return tx.Error 516 } 517 518 // Transaction start a transaction as a block, return error will rollback, otherwise to commit. 519 func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { 520 panicked := true 521 522 if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { 523 // nested transaction 524 if !db.DisableNestedTransaction { 525 err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error 526 defer func() { 527 // Make sure to rollback when panic, Block error or Commit error 528 if panicked || err != nil { 529 db.RollbackTo(fmt.Sprintf("sp%p", fc)) 530 } 531 }() 532 } 533 534 if err == nil { 535 err = fc(db.Session(&Session{})) 536 } 537 } else { 538 tx := db.Begin(opts...) 539 540 defer func() { 541 // Make sure to rollback when panic, Block error or Commit error 542 if panicked || err != nil { 543 tx.Rollback() 544 } 545 }() 546 547 if err = tx.Error; err == nil { 548 err = fc(tx) 549 } 550 551 if err == nil { 552 err = tx.Commit().Error 553 } 554 } 555 556 panicked = false 557 return 558 } 559 560 // Begin begins a transaction 561 func (db *DB) Begin(opts ...*sql.TxOptions) *DB { 562 var ( 563 // clone statement 564 tx = db.getInstance().Session(&Session{Context: db.Statement.Context}) 565 opt *sql.TxOptions 566 err error 567 ) 568 569 if len(opts) > 0 { 570 opt = opts[0] 571 } 572 573 if beginner, ok := tx.Statement.ConnPool.(TxBeginner); ok { 574 tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) 575 } else if beginner, ok := tx.Statement.ConnPool.(ConnPoolBeginner); ok { 576 tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt) 577 } else { 578 err = ErrInvalidTransaction 579 } 580 581 if err != nil { 582 tx.AddError(err) 583 } 584 585 return tx 586 } 587 588 // Commit commit a transaction 589 func (db *DB) Commit() *DB { 590 if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { 591 db.AddError(committer.Commit()) 592 } else { 593 db.AddError(ErrInvalidTransaction) 594 } 595 return db 596 } 597 598 // Rollback rollback a transaction 599 func (db *DB) Rollback() *DB { 600 if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { 601 if !reflect.ValueOf(committer).IsNil() { 602 db.AddError(committer.Rollback()) 603 } 604 } else { 605 db.AddError(ErrInvalidTransaction) 606 } 607 return db 608 } 609 610 func (db *DB) SavePoint(name string) *DB { 611 if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { 612 db.AddError(savePointer.SavePoint(db, name)) 613 } else { 614 db.AddError(ErrUnsupportedDriver) 615 } 616 return db 617 } 618 619 func (db *DB) RollbackTo(name string) *DB { 620 if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { 621 db.AddError(savePointer.RollbackTo(db, name)) 622 } else { 623 db.AddError(ErrUnsupportedDriver) 624 } 625 return db 626 } 627 628 // Exec execute raw sql 629 func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { 630 tx = db.getInstance() 631 tx.Statement.SQL = strings.Builder{} 632 633 if strings.Contains(sql, "@") { 634 clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) 635 } else { 636 clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) 637 } 638 639 return tx.callbacks.Raw().Execute(tx) 640 }