github.com/systematiccaos/gorm@v1.22.6/callbacks/create.go (about)

     1  package callbacks
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  
     7  	"github.com/systematiccaos/gorm"
     8  	"github.com/systematiccaos/gorm/clause"
     9  	"github.com/systematiccaos/gorm/schema"
    10  	"github.com/systematiccaos/gorm/utils"
    11  )
    12  
    13  func BeforeCreate(db *gorm.DB) {
    14  	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
    15  		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
    16  			if db.Statement.Schema.BeforeSave {
    17  				if i, ok := value.(BeforeSaveInterface); ok {
    18  					called = true
    19  					db.AddError(i.BeforeSave(tx))
    20  				}
    21  			}
    22  
    23  			if db.Statement.Schema.BeforeCreate {
    24  				if i, ok := value.(BeforeCreateInterface); ok {
    25  					called = true
    26  					db.AddError(i.BeforeCreate(tx))
    27  				}
    28  			}
    29  			return called
    30  		})
    31  	}
    32  }
    33  
    34  func Create(config *Config) func(db *gorm.DB) {
    35  	supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
    36  
    37  	return func(db *gorm.DB) {
    38  		if db.Error != nil {
    39  			return
    40  		}
    41  
    42  		if db.Statement.Schema != nil {
    43  			if !db.Statement.Unscoped {
    44  				for _, c := range db.Statement.Schema.CreateClauses {
    45  					db.Statement.AddClause(c)
    46  				}
    47  			}
    48  
    49  			if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
    50  				if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
    51  					fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
    52  					for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
    53  						fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
    54  					}
    55  					db.Statement.AddClause(clause.Returning{Columns: fromColumns})
    56  				}
    57  			}
    58  		}
    59  
    60  		if db.Statement.SQL.Len() == 0 {
    61  			db.Statement.SQL.Grow(180)
    62  			db.Statement.AddClauseIfNotExists(clause.Insert{})
    63  			db.Statement.AddClause(ConvertToCreateValues(db.Statement))
    64  
    65  			db.Statement.Build(db.Statement.BuildClauses...)
    66  		}
    67  
    68  		isDryRun := !db.DryRun && db.Error == nil
    69  		if !isDryRun {
    70  			return
    71  		}
    72  
    73  		ok, mode := hasReturning(db, supportReturning)
    74  		if ok {
    75  			if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
    76  				if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
    77  					mode |= gorm.ScanOnConflictDoNothing
    78  				}
    79  			}
    80  
    81  			rows, err := db.Statement.ConnPool.QueryContext(
    82  				db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
    83  			)
    84  			if db.AddError(err) == nil {
    85  				gorm.Scan(rows, db, mode)
    86  				db.AddError(rows.Close())
    87  			}
    88  
    89  			return
    90  		}
    91  
    92  		result, err := db.Statement.ConnPool.ExecContext(
    93  			db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
    94  		)
    95  		if err != nil {
    96  			db.AddError(err)
    97  			return
    98  		}
    99  
   100  		db.RowsAffected, _ = result.RowsAffected()
   101  		if db.RowsAffected != 0 && db.Statement.Schema != nil &&
   102  			db.Statement.Schema.PrioritizedPrimaryField != nil &&
   103  			db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
   104  			insertID, err := result.LastInsertId()
   105  			insertOk := err == nil && insertID > 0
   106  			if !insertOk {
   107  				db.AddError(err)
   108  				return
   109  			}
   110  
   111  			switch db.Statement.ReflectValue.Kind() {
   112  			case reflect.Slice, reflect.Array:
   113  				if config.LastInsertIDReversed {
   114  					for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
   115  						rv := db.Statement.ReflectValue.Index(i)
   116  						if reflect.Indirect(rv).Kind() != reflect.Struct {
   117  							break
   118  						}
   119  
   120  						_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv)
   121  						if isZero {
   122  							db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
   123  							insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
   124  						}
   125  					}
   126  				} else {
   127  					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
   128  						rv := db.Statement.ReflectValue.Index(i)
   129  						if reflect.Indirect(rv).Kind() != reflect.Struct {
   130  							break
   131  						}
   132  
   133  						if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(rv); isZero {
   134  							db.Statement.Schema.PrioritizedPrimaryField.Set(rv, insertID)
   135  							insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
   136  						}
   137  					}
   138  				}
   139  			case reflect.Struct:
   140  				_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.ReflectValue)
   141  				if isZero {
   142  					db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
   143  				}
   144  			}
   145  		}
   146  	}
   147  }
   148  
   149  func AfterCreate(db *gorm.DB) {
   150  	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
   151  		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
   152  			if db.Statement.Schema.AfterSave {
   153  				if i, ok := value.(AfterSaveInterface); ok {
   154  					called = true
   155  					db.AddError(i.AfterSave(tx))
   156  				}
   157  			}
   158  
   159  			if db.Statement.Schema.AfterCreate {
   160  				if i, ok := value.(AfterCreateInterface); ok {
   161  					called = true
   162  					db.AddError(i.AfterCreate(tx))
   163  				}
   164  			}
   165  			return called
   166  		})
   167  	}
   168  }
   169  
   170  // ConvertToCreateValues convert to create values
   171  func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
   172  	curTime := stmt.DB.NowFunc()
   173  
   174  	switch value := stmt.Dest.(type) {
   175  	case map[string]interface{}:
   176  		values = ConvertMapToValuesForCreate(stmt, value)
   177  	case *map[string]interface{}:
   178  		values = ConvertMapToValuesForCreate(stmt, *value)
   179  	case []map[string]interface{}:
   180  		values = ConvertSliceOfMapToValuesForCreate(stmt, value)
   181  	case *[]map[string]interface{}:
   182  		values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
   183  	default:
   184  		var (
   185  			selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
   186  			_, updateTrackTime        = stmt.Get("gorm:update_track_time")
   187  			isZero                    bool
   188  		)
   189  		stmt.Settings.Delete("gorm:update_track_time")
   190  
   191  		values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
   192  
   193  		for _, db := range stmt.Schema.DBNames {
   194  			if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
   195  				if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
   196  					values.Columns = append(values.Columns, clause.Column{Name: db})
   197  				}
   198  			}
   199  		}
   200  
   201  		switch stmt.ReflectValue.Kind() {
   202  		case reflect.Slice, reflect.Array:
   203  			rValLen := stmt.ReflectValue.Len()
   204  			stmt.SQL.Grow(rValLen * 18)
   205  			values.Values = make([][]interface{}, rValLen)
   206  			if rValLen == 0 {
   207  				stmt.AddError(gorm.ErrEmptySlice)
   208  				return
   209  			}
   210  
   211  			defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
   212  			for i := 0; i < rValLen; i++ {
   213  				rv := reflect.Indirect(stmt.ReflectValue.Index(i))
   214  				if !rv.IsValid() {
   215  					stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
   216  					return
   217  				}
   218  
   219  				values.Values[i] = make([]interface{}, len(values.Columns))
   220  				for idx, column := range values.Columns {
   221  					field := stmt.Schema.FieldsByDBName[column.Name]
   222  					if values.Values[i][idx], isZero = field.ValueOf(rv); isZero {
   223  						if field.DefaultValueInterface != nil {
   224  							values.Values[i][idx] = field.DefaultValueInterface
   225  							field.Set(rv, field.DefaultValueInterface)
   226  						} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
   227  							field.Set(rv, curTime)
   228  							values.Values[i][idx], _ = field.ValueOf(rv)
   229  						}
   230  					} else if field.AutoUpdateTime > 0 && updateTrackTime {
   231  						field.Set(rv, curTime)
   232  						values.Values[i][idx], _ = field.ValueOf(rv)
   233  					}
   234  				}
   235  
   236  				for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
   237  					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
   238  						if rvOfvalue, isZero := field.ValueOf(rv); !isZero {
   239  							if len(defaultValueFieldsHavingValue[field]) == 0 {
   240  								defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
   241  							}
   242  							defaultValueFieldsHavingValue[field][i] = rvOfvalue
   243  						}
   244  					}
   245  				}
   246  			}
   247  
   248  			for field, vs := range defaultValueFieldsHavingValue {
   249  				values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
   250  				for idx := range values.Values {
   251  					if vs[idx] == nil {
   252  						values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
   253  					} else {
   254  						values.Values[idx] = append(values.Values[idx], vs[idx])
   255  					}
   256  				}
   257  			}
   258  		case reflect.Struct:
   259  			values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
   260  			for idx, column := range values.Columns {
   261  				field := stmt.Schema.FieldsByDBName[column.Name]
   262  				if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
   263  					if field.DefaultValueInterface != nil {
   264  						values.Values[0][idx] = field.DefaultValueInterface
   265  						field.Set(stmt.ReflectValue, field.DefaultValueInterface)
   266  					} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
   267  						field.Set(stmt.ReflectValue, curTime)
   268  						values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
   269  					}
   270  				} else if field.AutoUpdateTime > 0 && updateTrackTime {
   271  					field.Set(stmt.ReflectValue, curTime)
   272  					values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
   273  				}
   274  			}
   275  
   276  			for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
   277  				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
   278  					if rvOfvalue, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
   279  						values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
   280  						values.Values[0] = append(values.Values[0], rvOfvalue)
   281  					}
   282  				}
   283  			}
   284  		default:
   285  			stmt.AddError(gorm.ErrInvalidData)
   286  		}
   287  	}
   288  
   289  	if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
   290  		if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
   291  			if stmt.Schema != nil && len(values.Columns) >= 1 {
   292  				selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)
   293  
   294  				columns := make([]string, 0, len(values.Columns)-1)
   295  				for _, column := range values.Columns {
   296  					if field := stmt.Schema.LookUpField(column.Name); field != nil {
   297  						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
   298  							if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 {
   299  								if field.AutoUpdateTime > 0 {
   300  									assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
   301  									switch field.AutoUpdateTime {
   302  									case schema.UnixNanosecond:
   303  										assignment.Value = curTime.UnixNano()
   304  									case schema.UnixMillisecond:
   305  										assignment.Value = curTime.UnixNano() / 1e6
   306  									case schema.UnixSecond:
   307  										assignment.Value = curTime.Unix()
   308  									}
   309  
   310  									onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
   311  								} else {
   312  									columns = append(columns, column.Name)
   313  								}
   314  							}
   315  						}
   316  					}
   317  				}
   318  
   319  				onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
   320  				if len(onConflict.DoUpdates) == 0 {
   321  					onConflict.DoNothing = true
   322  				}
   323  
   324  				// use primary fields as default OnConflict columns
   325  				if len(onConflict.Columns) == 0 {
   326  					for _, field := range stmt.Schema.PrimaryFields {
   327  						onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
   328  					}
   329  				}
   330  				stmt.AddClause(onConflict)
   331  			}
   332  		}
   333  	}
   334  
   335  	return values
   336  }