github.com/systematiccaos/gorm@v1.22.6/callbacks/update.go (about)

     1  package callbacks
     2  
     3  import (
     4  	"reflect"
     5  	"sort"
     6  
     7  	"github.com/systematiccaos/gorm"
     8  	"github.com/systematiccaos/gorm/clause"
     9  	"github.com/systematiccaos/gorm/schema"
    10  	"github.com/systematiccaos/gorm/utils"
    11  )
    12  
    13  func SetupUpdateReflectValue(db *gorm.DB) {
    14  	if db.Error == nil && db.Statement.Schema != nil {
    15  		if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
    16  			db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
    17  			for db.Statement.ReflectValue.Kind() == reflect.Ptr {
    18  				db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
    19  			}
    20  
    21  			if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
    22  				for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
    23  					if _, ok := dest[rel.Name]; ok {
    24  						rel.Field.Set(db.Statement.ReflectValue, dest[rel.Name])
    25  					}
    26  				}
    27  			}
    28  		}
    29  	}
    30  }
    31  
    32  func BeforeUpdate(db *gorm.DB) {
    33  	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
    34  		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
    35  			if db.Statement.Schema.BeforeSave {
    36  				if i, ok := value.(BeforeSaveInterface); ok {
    37  					called = true
    38  					db.AddError(i.BeforeSave(tx))
    39  				}
    40  			}
    41  
    42  			if db.Statement.Schema.BeforeUpdate {
    43  				if i, ok := value.(BeforeUpdateInterface); ok {
    44  					called = true
    45  					db.AddError(i.BeforeUpdate(tx))
    46  				}
    47  			}
    48  
    49  			return called
    50  		})
    51  	}
    52  }
    53  
    54  func Update(config *Config) func(db *gorm.DB) {
    55  	supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
    56  
    57  	return func(db *gorm.DB) {
    58  		if db.Error != nil {
    59  			return
    60  		}
    61  
    62  		if db.Statement.SQL.Len() == 0 {
    63  			db.Statement.SQL.Grow(180)
    64  			db.Statement.AddClauseIfNotExists(clause.Update{})
    65  			if set := ConvertToAssignments(db.Statement); len(set) != 0 {
    66  				db.Statement.AddClause(set)
    67  			} else if _, ok := db.Statement.Clauses["SET"]; !ok {
    68  				return
    69  			}
    70  
    71  		}
    72  
    73  		if db.Statement.Schema != nil {
    74  			for _, c := range db.Statement.Schema.UpdateClauses {
    75  				db.Statement.AddClause(c)
    76  			}
    77  		}
    78  
    79  		if db.Statement.SQL.Len() == 0 {
    80  			db.Statement.Build(db.Statement.BuildClauses...)
    81  		}
    82  
    83  		if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok {
    84  			db.AddError(gorm.ErrMissingWhereClause)
    85  			return
    86  		}
    87  
    88  		if !db.DryRun && db.Error == nil {
    89  			if ok, mode := hasReturning(db, supportReturning); ok {
    90  				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
    91  					dest := db.Statement.Dest
    92  					db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
    93  					gorm.Scan(rows, db, mode)
    94  					db.Statement.Dest = dest
    95  					db.AddError(rows.Close())
    96  				}
    97  			} else {
    98  				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
    99  
   100  				if db.AddError(err) == nil {
   101  					db.RowsAffected, _ = result.RowsAffected()
   102  				}
   103  			}
   104  		}
   105  	}
   106  }
   107  
   108  func AfterUpdate(db *gorm.DB) {
   109  	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
   110  		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
   111  			if db.Statement.Schema.AfterSave {
   112  				if i, ok := value.(AfterSaveInterface); ok {
   113  					called = true
   114  					db.AddError(i.AfterSave(tx))
   115  				}
   116  			}
   117  
   118  			if db.Statement.Schema.AfterUpdate {
   119  				if i, ok := value.(AfterUpdateInterface); ok {
   120  					called = true
   121  					db.AddError(i.AfterUpdate(tx))
   122  				}
   123  			}
   124  			return called
   125  		})
   126  	}
   127  }
   128  
   129  // ConvertToAssignments convert to update assignments
   130  func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
   131  	var (
   132  		selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
   133  		assignValue               func(field *schema.Field, value interface{})
   134  	)
   135  
   136  	switch stmt.ReflectValue.Kind() {
   137  	case reflect.Slice, reflect.Array:
   138  		assignValue = func(field *schema.Field, value interface{}) {
   139  			for i := 0; i < stmt.ReflectValue.Len(); i++ {
   140  				field.Set(stmt.ReflectValue.Index(i), value)
   141  			}
   142  		}
   143  	case reflect.Struct:
   144  		assignValue = func(field *schema.Field, value interface{}) {
   145  			if stmt.ReflectValue.CanAddr() {
   146  				field.Set(stmt.ReflectValue, value)
   147  			}
   148  		}
   149  	default:
   150  		assignValue = func(field *schema.Field, value interface{}) {
   151  		}
   152  	}
   153  
   154  	updatingValue := reflect.ValueOf(stmt.Dest)
   155  	for updatingValue.Kind() == reflect.Ptr {
   156  		updatingValue = updatingValue.Elem()
   157  	}
   158  
   159  	if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
   160  		switch stmt.ReflectValue.Kind() {
   161  		case reflect.Slice, reflect.Array:
   162  			if size := stmt.ReflectValue.Len(); size > 0 {
   163  				var primaryKeyExprs []clause.Expression
   164  				for i := 0; i < size; i++ {
   165  					var exprs = make([]clause.Expression, len(stmt.Schema.PrimaryFields))
   166  					var notZero bool
   167  					for idx, field := range stmt.Schema.PrimaryFields {
   168  						value, isZero := field.ValueOf(stmt.ReflectValue.Index(i))
   169  						exprs[idx] = clause.Eq{Column: field.DBName, Value: value}
   170  						notZero = notZero || !isZero
   171  					}
   172  					if notZero {
   173  						primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
   174  					}
   175  				}
   176  
   177  				stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(primaryKeyExprs...)}})
   178  			}
   179  		case reflect.Struct:
   180  			for _, field := range stmt.Schema.PrimaryFields {
   181  				if value, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
   182  					stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
   183  				}
   184  			}
   185  		}
   186  	}
   187  
   188  	switch value := updatingValue.Interface().(type) {
   189  	case map[string]interface{}:
   190  		set = make([]clause.Assignment, 0, len(value))
   191  
   192  		keys := make([]string, 0, len(value))
   193  		for k := range value {
   194  			keys = append(keys, k)
   195  		}
   196  		sort.Strings(keys)
   197  
   198  		for _, k := range keys {
   199  			kv := value[k]
   200  			if _, ok := kv.(*gorm.DB); ok {
   201  				kv = []interface{}{kv}
   202  			}
   203  
   204  			if stmt.Schema != nil {
   205  				if field := stmt.Schema.LookUpField(k); field != nil {
   206  					if field.DBName != "" {
   207  						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
   208  							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
   209  							assignValue(field, value[k])
   210  						}
   211  					} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
   212  						assignValue(field, value[k])
   213  					}
   214  					continue
   215  				}
   216  			}
   217  
   218  			if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
   219  				set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
   220  			}
   221  		}
   222  
   223  		if !stmt.SkipHooks && stmt.Schema != nil {
   224  			for _, dbName := range stmt.Schema.DBNames {
   225  				field := stmt.Schema.LookUpField(dbName)
   226  				if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
   227  					if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
   228  						now := stmt.DB.NowFunc()
   229  						assignValue(field, now)
   230  
   231  						if field.AutoUpdateTime == schema.UnixNanosecond {
   232  							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
   233  						} else if field.AutoUpdateTime == schema.UnixMillisecond {
   234  							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
   235  						} else if field.GORMDataType == schema.Time {
   236  							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
   237  						} else {
   238  							set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
   239  						}
   240  					}
   241  				}
   242  			}
   243  		}
   244  	default:
   245  		var updatingSchema = stmt.Schema
   246  		if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
   247  			// different schema
   248  			updatingStmt := &gorm.Statement{DB: stmt.DB}
   249  			if err := updatingStmt.Parse(stmt.Dest); err == nil {
   250  				updatingSchema = updatingStmt.Schema
   251  			}
   252  		}
   253  
   254  		switch updatingValue.Kind() {
   255  		case reflect.Struct:
   256  			set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
   257  			for _, dbName := range stmt.Schema.DBNames {
   258  				if field := updatingSchema.LookUpField(dbName); field != nil {
   259  					if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
   260  						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
   261  							value, isZero := field.ValueOf(updatingValue)
   262  							if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
   263  								if field.AutoUpdateTime == schema.UnixNanosecond {
   264  									value = stmt.DB.NowFunc().UnixNano()
   265  								} else if field.AutoUpdateTime == schema.UnixMillisecond {
   266  									value = stmt.DB.NowFunc().UnixNano() / 1e6
   267  								} else if field.GORMDataType == schema.Time {
   268  									value = stmt.DB.NowFunc()
   269  								} else {
   270  									value = stmt.DB.NowFunc().Unix()
   271  								}
   272  								isZero = false
   273  							}
   274  
   275  							if (ok || !isZero) && field.Updatable {
   276  								set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
   277  								assignValue(field, value)
   278  							}
   279  						}
   280  					} else {
   281  						if value, isZero := field.ValueOf(updatingValue); !isZero {
   282  							stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
   283  						}
   284  					}
   285  				}
   286  			}
   287  		default:
   288  			stmt.AddError(gorm.ErrInvalidData)
   289  		}
   290  	}
   291  
   292  	return
   293  }