github.com/systematiccaos/gorm@v1.22.6/callbacks/update.go (about) 1 package callbacks 2 3 import ( 4 "reflect" 5 "sort" 6 7 "github.com/systematiccaos/gorm" 8 "github.com/systematiccaos/gorm/clause" 9 "github.com/systematiccaos/gorm/schema" 10 "github.com/systematiccaos/gorm/utils" 11 ) 12 13 func SetupUpdateReflectValue(db *gorm.DB) { 14 if db.Error == nil && db.Statement.Schema != nil { 15 if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { 16 db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) 17 for db.Statement.ReflectValue.Kind() == reflect.Ptr { 18 db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() 19 } 20 21 if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { 22 for _, rel := range db.Statement.Schema.Relationships.BelongsTo { 23 if _, ok := dest[rel.Name]; ok { 24 rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name]) 25 } 26 } 27 } 28 } 29 } 30 } 31 32 func BeforeUpdate(db *gorm.DB) { 33 if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { 34 callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { 35 if db.Statement.Schema.BeforeSave { 36 if i, ok := value.(BeforeSaveInterface); ok { 37 called = true 38 db.AddError(i.BeforeSave(tx)) 39 } 40 } 41 42 if db.Statement.Schema.BeforeUpdate { 43 if i, ok := value.(BeforeUpdateInterface); ok { 44 called = true 45 db.AddError(i.BeforeUpdate(tx)) 46 } 47 } 48 49 return called 50 }) 51 } 52 } 53 54 func Update(config *Config) func(db *gorm.DB) { 55 supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") 56 57 return func(db *gorm.DB) { 58 if db.Error != nil { 59 return 60 } 61 62 if db.Statement.SQL.Len() == 0 { 63 db.Statement.SQL.Grow(180) 64 db.Statement.AddClauseIfNotExists(clause.Update{}) 65 if set := ConvertToAssignments(db.Statement); len(set) != 0 { 66 db.Statement.AddClause(set) 67 } else if _, ok := db.Statement.Clauses["SET"]; !ok { 68 return 69 } 70 71 } 72 73 if db.Statement.Schema != nil { 74 for _, c := range db.Statement.Schema.UpdateClauses { 75 db.Statement.AddClause(c) 76 } 77 } 78 79 if db.Statement.SQL.Len() == 0 { 80 db.Statement.Build(db.Statement.BuildClauses...) 81 } 82 83 if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok { 84 db.AddError(gorm.ErrMissingWhereClause) 85 return 86 } 87 88 if !db.DryRun && db.Error == nil { 89 if ok, mode := hasReturning(db, supportReturning); ok { 90 if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { 91 dest := db.Statement.Dest 92 db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() 93 gorm.Scan(rows, db, mode) 94 db.Statement.Dest = dest 95 db.AddError(rows.Close()) 96 } 97 } else { 98 result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 99 100 if db.AddError(err) == nil { 101 db.RowsAffected, _ = result.RowsAffected() 102 } 103 } 104 } 105 } 106 } 107 108 func AfterUpdate(db *gorm.DB) { 109 if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { 110 callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { 111 if db.Statement.Schema.AfterSave { 112 if i, ok := value.(AfterSaveInterface); ok { 113 called = true 114 db.AddError(i.AfterSave(tx)) 115 } 116 } 117 118 if db.Statement.Schema.AfterUpdate { 119 if i, ok := value.(AfterUpdateInterface); ok { 120 called = true 121 db.AddError(i.AfterUpdate(tx)) 122 } 123 } 124 return called 125 }) 126 } 127 } 128 129 // ConvertToAssignments convert to update assignments 130 func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { 131 var ( 132 selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) 133 assignValue func(field *schema.Field, value interface{}) 134 ) 135 136 switch stmt.ReflectValue.Kind() { 137 case reflect.Slice, reflect.Array: 138 assignValue = func(field *schema.Field, value interface{}) { 139 for i := 0; i < stmt.ReflectValue.Len(); i++ { 140 field.Set(stmt.ReflectValue.Index(i), value) 141 } 142 } 143 case reflect.Struct: 144 assignValue = func(field *schema.Field, value interface{}) { 145 if stmt.ReflectValue.CanAddr() { 146 field.Set(stmt.ReflectValue, value) 147 } 148 } 149 default: 150 assignValue = func(field *schema.Field, value interface{}) { 151 } 152 } 153 154 updatingValue := reflect.ValueOf(stmt.Dest) 155 for updatingValue.Kind() == reflect.Ptr { 156 updatingValue = updatingValue.Elem() 157 } 158 159 if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { 160 switch stmt.ReflectValue.Kind() { 161 case reflect.Slice, reflect.Array: 162 if size := stmt.ReflectValue.Len(); size > 0 { 163 var primaryKeyExprs []clause.Expression 164 for i := 0; i < size; i++ { 165 var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields)) 166 var notZero bool 167 for idx, field := range stmt.Schema.PrimaryFields { 168 value, isZero := field.ValueOf(stmt.ReflectValue.Index(i)) 169 exprs[idx] = clause.Eq{Column: field.DBName, Value: value} 170 notZero = notZero || !isZero 171 } 172 if notZero { 173 primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...)) 174 } 175 } 176 177 stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}}) 178 } 179 case reflect.Struct: 180 for _, field := range stmt.Schema.PrimaryFields { 181 if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero { 182 stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) 183 } 184 } 185 } 186 } 187 188 switch value := updatingValue.Interface().(type) { 189 case map[string]interface{}: 190 set = make([]clause.Assignment, 0, len(value)) 191 192 keys := make([]string, 0, len(value)) 193 for k := range value { 194 keys = append(keys, k) 195 } 196 sort.Strings(keys) 197 198 for _, k := range keys { 199 kv := value[k] 200 if _, ok := kv.(*gorm.DB); ok { 201 kv = []interface{}{kv} 202 } 203 204 if stmt.Schema != nil { 205 if field := stmt.Schema.LookUpField(k); field != nil { 206 if field.DBName != "" { 207 if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { 208 set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) 209 assignValue(field, value[k]) 210 } 211 } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { 212 assignValue(field, value[k]) 213 } 214 continue 215 } 216 } 217 218 if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { 219 set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) 220 } 221 } 222 223 if !stmt.SkipHooks && stmt.Schema != nil { 224 for _, dbName := range stmt.Schema.DBNames { 225 field := stmt.Schema.LookUpField(dbName) 226 if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { 227 if v, ok := selectColumns[field.DBName]; (ok && v) || !ok { 228 now := stmt.DB.NowFunc() 229 assignValue(field, now) 230 231 if field.AutoUpdateTime == schema.UnixNanosecond { 232 set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) 233 } else if field.AutoUpdateTime == schema.UnixMillisecond { 234 set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6}) 235 } else if field.GORMDataType == schema.Time { 236 set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) 237 } else { 238 set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) 239 } 240 } 241 } 242 } 243 } 244 default: 245 var updatingSchema = stmt.Schema 246 if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { 247 // different schema 248 updatingStmt := &gorm.Statement{DB: stmt.DB} 249 if err := updatingStmt.Parse(stmt.Dest); err == nil { 250 updatingSchema = updatingStmt.Schema 251 } 252 } 253 254 switch updatingValue.Kind() { 255 case reflect.Struct: 256 set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) 257 for _, dbName := range stmt.Schema.DBNames { 258 if field := updatingSchema.LookUpField(dbName); field != nil { 259 if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { 260 if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { 261 value, isZero := field.ValueOf(updatingValue) 262 if !stmt.SkipHooks && field.AutoUpdateTime > 0 { 263 if field.AutoUpdateTime == schema.UnixNanosecond { 264 value = stmt.DB.NowFunc().UnixNano() 265 } else if field.AutoUpdateTime == schema.UnixMillisecond { 266 value = stmt.DB.NowFunc().UnixNano() / 1e6 267 } else if field.GORMDataType == schema.Time { 268 value = stmt.DB.NowFunc() 269 } else { 270 value = stmt.DB.NowFunc().Unix() 271 } 272 isZero = false 273 } 274 275 if (ok || !isZero) && field.Updatable { 276 set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) 277 assignValue(field, value) 278 } 279 } 280 } else { 281 if value, isZero := field.ValueOf(updatingValue); !isZero { 282 stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) 283 } 284 } 285 } 286 } 287 default: 288 stmt.AddError(gorm.ErrInvalidData) 289 } 290 } 291 292 return 293 }