github.com/wfusion/gofusion@v1.1.14/db/callbacks/auto_incr.go (about) 1 package callbacks 2 3 import ( 4 "reflect" 5 "strings" 6 7 "gorm.io/gorm" 8 "gorm.io/gorm/clause" 9 "gorm.io/gorm/utils" 10 11 comUtl "github.com/wfusion/gofusion/common/utils" 12 13 . "gorm.io/driver/mysql" 14 ) 15 16 func CreateAutoIncr(db *gorm.DB, gormDialector gorm.Dialector, autoIncrIncr int64) { 17 withReturning := false 18 dialector := gormDialector.(*Dialector) 19 if !dialector.Config.SkipInitializeWithVersion && strings.Contains(dialector.ServerVersion, "MariaDB") { 20 withReturning = checkVersion(dialector.ServerVersion, "10.5") 21 } 22 23 lastInsertIDReversed := false 24 if !dialector.Config.DisableWithReturning && withReturning { 25 lastInsertIDReversed = true 26 } 27 28 comUtl.MustSuccess( 29 db.Callback().Create().Replace("gorm:create", func(db *gorm.DB) { 30 if db.Error != nil { 31 return 32 } 33 34 BuildCreateSQL(db) 35 36 isDryRun := !db.DryRun && db.Error == nil 37 if !isDryRun { 38 return 39 } 40 41 ok, mode := hasReturning(db, utils.Contains(db.Callback().Create().Clauses, "RETURNING")) 42 if ok { 43 if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { 44 if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { 45 mode |= gorm.ScanOnConflictDoNothing 46 } 47 } 48 49 rows, err := db.Statement.ConnPool.QueryContext( 50 db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., 51 ) 52 if db.AddError(err) == nil { 53 defer func() { _ = db.AddError(rows.Close()) }() 54 gorm.Scan(rows, db, mode) 55 } 56 57 return 58 } 59 60 result, err := db.Statement.ConnPool.ExecContext( 61 db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., 62 ) 63 if err != nil { 64 _ = db.AddError(err) 65 return 66 } 67 68 db.RowsAffected, _ = result.RowsAffected() 69 if db.RowsAffected != 0 && db.Statement.Schema != nil && 70 db.Statement.Schema.PrioritizedPrimaryField != nil && 71 db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { 72 insertID, err := result.LastInsertId() 73 insertOk := err == nil && insertID > 0 74 if !insertOk { 75 _ = db.AddError(err) 76 return 77 } 78 79 switch db.Statement.ReflectValue.Kind() { 80 case reflect.Slice, reflect.Array: 81 if lastInsertIDReversed { 82 for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { 83 rv := db.Statement.ReflectValue.Index(i) 84 if reflect.Indirect(rv).Kind() != reflect.Struct { 85 break 86 } 87 88 _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) 89 if isZero { 90 _ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) 91 //insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement 92 insertID -= autoIncrIncr 93 } 94 } 95 } else { 96 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { 97 rv := db.Statement.ReflectValue.Index(i) 98 if reflect.Indirect(rv).Kind() != reflect.Struct { 99 break 100 } 101 102 if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero { 103 _ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID)) 104 //insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement 105 insertID += autoIncrIncr 106 } 107 } 108 } 109 case reflect.Struct: 110 _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue) 111 if isZero { 112 _ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID)) 113 } 114 } 115 } 116 }), 117 ) 118 }