github.com/systematiccaos/gorm@v1.22.6/callbacks/create.go (about) 1 package callbacks 2 3 import ( 4 "fmt" 5 "reflect" 6 7 "github.com/systematiccaos/gorm" 8 "github.com/systematiccaos/gorm/clause" 9 "github.com/systematiccaos/gorm/schema" 10 "github.com/systematiccaos/gorm/utils" 11 ) 12 13 func BeforeCreate(db *gorm.DB) { 14 if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { 15 callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { 16 if db.Statement.Schema.BeforeSave { 17 if i, ok := value.(BeforeSaveInterface); ok { 18 called = true 19 db.AddError(i.BeforeSave(tx)) 20 } 21 } 22 23 if db.Statement.Schema.BeforeCreate { 24 if i, ok := value.(BeforeCreateInterface); ok { 25 called = true 26 db.AddError(i.BeforeCreate(tx)) 27 } 28 } 29 return called 30 }) 31 } 32 } 33 34 func Create(config *Config) func(db *gorm.DB) { 35 supportReturning := utils.Contains(config.CreateClauses, "RETURNING") 36 37 return func(db *gorm.DB) { 38 if db.Error != nil { 39 return 40 } 41 42 if db.Statement.Schema != nil { 43 if !db.Statement.Unscoped { 44 for _, c := range db.Statement.Schema.CreateClauses { 45 db.Statement.AddClause(c) 46 } 47 } 48 49 if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 { 50 if _, ok := db.Statement.Clauses["RETURNING"]; !ok { 51 fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue)) 52 for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue { 53 fromColumns = append(fromColumns, clause.Column{Name: field.DBName}) 54 } 55 db.Statement.AddClause(clause.Returning{Columns: fromColumns}) 56 } 57 } 58 } 59 60 if db.Statement.SQL.Len() == 0 { 61 db.Statement.SQL.Grow(180) 62 db.Statement.AddClauseIfNotExists(clause.Insert{}) 63 db.Statement.AddClause(ConvertToCreateValues(db.Statement)) 64 65 db.Statement.Build(db.Statement.BuildClauses...) 66 } 67 68 isDryRun := !db.DryRun && db.Error == nil 69 if !isDryRun { 70 return 71 } 72 73 ok, mode := hasReturning(db, supportReturning) 74 if ok { 75 if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok { 76 if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing { 77 mode |= gorm.ScanOnConflictDoNothing 78 } 79 } 80 81 rows, err := db.Statement.ConnPool.QueryContext( 82 db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., 83 ) 84 if db.AddError(err) == nil { 85 gorm.Scan(rows, db, mode) 86 db.AddError(rows.Close()) 87 } 88 89 return 90 } 91 92 result, err := db.Statement.ConnPool.ExecContext( 93 db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars..., 94 ) 95 if err != nil { 96 db.AddError(err) 97 return 98 } 99 100 db.RowsAffected, _ = result.RowsAffected() 101 if db.RowsAffected != 0 && db.Statement.Schema != nil && 102 db.Statement.Schema.PrioritizedPrimaryField != nil && 103 db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { 104 insertID, err := result.LastInsertId() 105 insertOk := err == nil && insertID > 0 106 if !insertOk { 107 db.AddError(err) 108 return 109 } 110 111 switch db.Statement.ReflectValue.Kind() { 112 case reflect.Slice, reflect.Array: 113 if config.LastInsertIDReversed { 114 for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { 115 rv := db.Statement.ReflectValue.Index(i) 116 if reflect.Indirect(rv).Kind() != reflect.Struct { 117 break 118 } 119 120 _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv) 121 if isZero { 122 db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) 123 insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement 124 } 125 } 126 } else { 127 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { 128 rv := db.Statement.ReflectValue.Index(i) 129 if reflect.Indirect(rv).Kind() != reflect.Struct { 130 break 131 } 132 133 if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero { 134 db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID) 135 insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement 136 } 137 } 138 } 139 case reflect.Struct: 140 _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue) 141 if isZero { 142 db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) 143 } 144 } 145 } 146 } 147 } 148 149 func AfterCreate(db *gorm.DB) { 150 if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { 151 callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { 152 if db.Statement.Schema.AfterSave { 153 if i, ok := value.(AfterSaveInterface); ok { 154 called = true 155 db.AddError(i.AfterSave(tx)) 156 } 157 } 158 159 if db.Statement.Schema.AfterCreate { 160 if i, ok := value.(AfterCreateInterface); ok { 161 called = true 162 db.AddError(i.AfterCreate(tx)) 163 } 164 } 165 return called 166 }) 167 } 168 } 169 170 // ConvertToCreateValues convert to create values 171 func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { 172 curTime := stmt.DB.NowFunc() 173 174 switch value := stmt.Dest.(type) { 175 case map[string]interface{}: 176 values = ConvertMapToValuesForCreate(stmt, value) 177 case *map[string]interface{}: 178 values = ConvertMapToValuesForCreate(stmt, *value) 179 case []map[string]interface{}: 180 values = ConvertSliceOfMapToValuesForCreate(stmt, value) 181 case *[]map[string]interface{}: 182 values = ConvertSliceOfMapToValuesForCreate(stmt, *value) 183 default: 184 var ( 185 selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) 186 _, updateTrackTime = stmt.Get("gorm:update_track_time") 187 isZero bool 188 ) 189 stmt.Settings.Delete("gorm:update_track_time") 190 191 values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))} 192 193 for _, db := range stmt.Schema.DBNames { 194 if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil { 195 if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) { 196 values.Columns = append(values.Columns, clause.Column{Name: db}) 197 } 198 } 199 } 200 201 switch stmt.ReflectValue.Kind() { 202 case reflect.Slice, reflect.Array: 203 rValLen := stmt.ReflectValue.Len() 204 stmt.SQL.Grow(rValLen * 18) 205 values.Values = make([][]interface{}, rValLen) 206 if rValLen == 0 { 207 stmt.AddError(gorm.ErrEmptySlice) 208 return 209 } 210 211 defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{} 212 for i := 0; i < rValLen; i++ { 213 rv := reflect.Indirect(stmt.ReflectValue.Index(i)) 214 if !rv.IsValid() { 215 stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData)) 216 return 217 } 218 219 values.Values[i] = make([]interface{}, len(values.Columns)) 220 for idx, column := range values.Columns { 221 field := stmt.Schema.FieldsByDBName[column.Name] 222 if values.Values[i][idx], isZero = field.ValueOf(rv); isZero { 223 if field.DefaultValueInterface != nil { 224 values.Values[i][idx] = field.DefaultValueInterface 225 field.Set(rv, field.DefaultValueInterface) 226 } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { 227 field.Set(rv, curTime) 228 values.Values[i][idx], _ = field.ValueOf(rv) 229 } 230 } else if field.AutoUpdateTime > 0 && updateTrackTime { 231 field.Set(rv, curTime) 232 values.Values[i][idx], _ = field.ValueOf(rv) 233 } 234 } 235 236 for _, field := range stmt.Schema.FieldsWithDefaultDBValue { 237 if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { 238 if rvOfvalue, isZero := field.ValueOf(rv); !isZero { 239 if len(defaultValueFieldsHavingValue[field]) == 0 { 240 defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen) 241 } 242 defaultValueFieldsHavingValue[field][i] = rvOfvalue 243 } 244 } 245 } 246 } 247 248 for field, vs := range defaultValueFieldsHavingValue { 249 values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) 250 for idx := range values.Values { 251 if vs[idx] == nil { 252 values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field)) 253 } else { 254 values.Values[idx] = append(values.Values[idx], vs[idx]) 255 } 256 } 257 } 258 case reflect.Struct: 259 values.Values = [][]interface{}{make([]interface{}, len(values.Columns))} 260 for idx, column := range values.Columns { 261 field := stmt.Schema.FieldsByDBName[column.Name] 262 if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero { 263 if field.DefaultValueInterface != nil { 264 values.Values[0][idx] = field.DefaultValueInterface 265 field.Set(stmt.ReflectValue, field.DefaultValueInterface) 266 } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 { 267 field.Set(stmt.ReflectValue, curTime) 268 values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) 269 } 270 } else if field.AutoUpdateTime > 0 && updateTrackTime { 271 field.Set(stmt.ReflectValue, curTime) 272 values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue) 273 } 274 } 275 276 for _, field := range stmt.Schema.FieldsWithDefaultDBValue { 277 if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { 278 if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero { 279 values.Columns = append(values.Columns, clause.Column{Name: field.DBName}) 280 values.Values[0] = append(values.Values[0], rvOfvalue) 281 } 282 } 283 } 284 default: 285 stmt.AddError(gorm.ErrInvalidData) 286 } 287 } 288 289 if c, ok := stmt.Clauses["ON CONFLICT"]; ok { 290 if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { 291 if stmt.Schema != nil && len(values.Columns) >= 1 { 292 selectColumns, restricted := stmt.SelectAndOmitColumns(true, true) 293 294 columns := make([]string, 0, len(values.Columns)-1) 295 for _, column := range values.Columns { 296 if field := stmt.Schema.LookUpField(column.Name); field != nil { 297 if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { 298 if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { 299 if field.AutoUpdateTime > 0 { 300 assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} 301 switch field.AutoUpdateTime { 302 case schema.UnixNanosecond: 303 assignment.Value = curTime.UnixNano() 304 case schema.UnixMillisecond: 305 assignment.Value = curTime.UnixNano() / 1e6 306 case schema.UnixSecond: 307 assignment.Value = curTime.Unix() 308 } 309 310 onConflict.DoUpdates = append(onConflict.DoUpdates, assignment) 311 } else { 312 columns = append(columns, column.Name) 313 } 314 } 315 } 316 } 317 } 318 319 onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...) 320 if len(onConflict.DoUpdates) == 0 { 321 onConflict.DoNothing = true 322 } 323 324 // use primary fields as default OnConflict columns 325 if len(onConflict.Columns) == 0 { 326 for _, field := range stmt.Schema.PrimaryFields { 327 onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName}) 328 } 329 } 330 stmt.AddClause(onConflict) 331 } 332 } 333 } 334 335 return values 336 }