github.com/systematiccaos/gorm@v1.22.6/callbacks/associations.go (about)

     1  package callbacks
     2  
     3  import (
     4  	"reflect"
     5  	"strings"
     6  
     7  	"github.com/systematiccaos/gorm"
     8  	"github.com/systematiccaos/gorm/clause"
     9  	"github.com/systematiccaos/gorm/schema"
    10  	"github.com/systematiccaos/gorm/utils"
    11  )
    12  
    13  func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
    14  	return func(db *gorm.DB) {
    15  		if db.Error == nil && db.Statement.Schema != nil {
    16  			selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
    17  
    18  			// Save Belongs To associations
    19  			for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
    20  				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
    21  					continue
    22  				}
    23  
    24  				setupReferences := func(obj reflect.Value, elem reflect.Value) {
    25  					for _, ref := range rel.References {
    26  						if !ref.OwnPrimaryKey {
    27  							pv, _ := ref.PrimaryKey.ValueOf(elem)
    28  							db.AddError(ref.ForeignKey.Set(obj, pv))
    29  
    30  							if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
    31  								dest[ref.ForeignKey.DBName] = pv
    32  								if _, ok := dest[rel.Name]; ok {
    33  									dest[rel.Name] = elem.Interface()
    34  								}
    35  							}
    36  						}
    37  					}
    38  				}
    39  
    40  				switch db.Statement.ReflectValue.Kind() {
    41  				case reflect.Slice, reflect.Array:
    42  					var (
    43  						rValLen   = db.Statement.ReflectValue.Len()
    44  						objs      = make([]reflect.Value, 0, rValLen)
    45  						fieldType = rel.Field.FieldType
    46  						isPtr     = fieldType.Kind() == reflect.Ptr
    47  					)
    48  
    49  					if !isPtr {
    50  						fieldType = reflect.PtrTo(fieldType)
    51  					}
    52  
    53  					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
    54  					for i := 0; i < rValLen; i++ {
    55  						obj := db.Statement.ReflectValue.Index(i)
    56  						if reflect.Indirect(obj).Kind() != reflect.Struct {
    57  							break
    58  						}
    59  
    60  						if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value
    61  							rv := rel.Field.ReflectValueOf(obj) // relation reflect value
    62  							objs = append(objs, obj)
    63  							if isPtr {
    64  								elems = reflect.Append(elems, rv)
    65  							} else {
    66  								elems = reflect.Append(elems, rv.Addr())
    67  							}
    68  						}
    69  					}
    70  
    71  					if elems.Len() > 0 {
    72  						if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
    73  							for i := 0; i < elems.Len(); i++ {
    74  								setupReferences(objs[i], elems.Index(i))
    75  							}
    76  						}
    77  					}
    78  				case reflect.Struct:
    79  					if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
    80  						rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value
    81  						if rv.Kind() != reflect.Ptr {
    82  							rv = rv.Addr()
    83  						}
    84  
    85  						if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
    86  							setupReferences(db.Statement.ReflectValue, rv)
    87  						}
    88  					}
    89  				}
    90  			}
    91  		}
    92  	}
    93  }
    94  
    95  func SaveAfterAssociations(create bool) func(db *gorm.DB) {
    96  	return func(db *gorm.DB) {
    97  		if db.Error == nil && db.Statement.Schema != nil {
    98  			selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
    99  
   100  			// Save Has One associations
   101  			for _, rel := range db.Statement.Schema.Relationships.HasOne {
   102  				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
   103  					continue
   104  				}
   105  
   106  				switch db.Statement.ReflectValue.Kind() {
   107  				case reflect.Slice, reflect.Array:
   108  					var (
   109  						fieldType = rel.Field.FieldType
   110  						isPtr     = fieldType.Kind() == reflect.Ptr
   111  					)
   112  
   113  					if !isPtr {
   114  						fieldType = reflect.PtrTo(fieldType)
   115  					}
   116  
   117  					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
   118  
   119  					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
   120  						obj := db.Statement.ReflectValue.Index(i)
   121  
   122  						if reflect.Indirect(obj).Kind() == reflect.Struct {
   123  							if _, zero := rel.Field.ValueOf(obj); !zero {
   124  								rv := rel.Field.ReflectValueOf(obj)
   125  								if rv.Kind() != reflect.Ptr {
   126  									rv = rv.Addr()
   127  								}
   128  
   129  								for _, ref := range rel.References {
   130  									if ref.OwnPrimaryKey {
   131  										fv, _ := ref.PrimaryKey.ValueOf(obj)
   132  										db.AddError(ref.ForeignKey.Set(rv, fv))
   133  									} else if ref.PrimaryValue != "" {
   134  										db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue))
   135  									}
   136  								}
   137  
   138  								elems = reflect.Append(elems, rv)
   139  							}
   140  						}
   141  					}
   142  
   143  					if elems.Len() > 0 {
   144  						assignmentColumns := make([]string, 0, len(rel.References))
   145  						for _, ref := range rel.References {
   146  							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
   147  						}
   148  
   149  						saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
   150  					}
   151  				case reflect.Struct:
   152  					if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
   153  						f := rel.Field.ReflectValueOf(db.Statement.ReflectValue)
   154  						if f.Kind() != reflect.Ptr {
   155  							f = f.Addr()
   156  						}
   157  
   158  						assignmentColumns := make([]string, 0, len(rel.References))
   159  						for _, ref := range rel.References {
   160  							if ref.OwnPrimaryKey {
   161  								fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue)
   162  								ref.ForeignKey.Set(f, fv)
   163  							} else if ref.PrimaryValue != "" {
   164  								ref.ForeignKey.Set(f, ref.PrimaryValue)
   165  							}
   166  							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
   167  						}
   168  
   169  						saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
   170  					}
   171  				}
   172  			}
   173  
   174  			// Save Has Many associations
   175  			for _, rel := range db.Statement.Schema.Relationships.HasMany {
   176  				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
   177  					continue
   178  				}
   179  
   180  				fieldType := rel.Field.IndirectFieldType.Elem()
   181  				isPtr := fieldType.Kind() == reflect.Ptr
   182  				if !isPtr {
   183  					fieldType = reflect.PtrTo(fieldType)
   184  				}
   185  				elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
   186  				identityMap := map[string]bool{}
   187  				appendToElems := func(v reflect.Value) {
   188  					if _, zero := rel.Field.ValueOf(v); !zero {
   189  						f := reflect.Indirect(rel.Field.ReflectValueOf(v))
   190  
   191  						for i := 0; i < f.Len(); i++ {
   192  							elem := f.Index(i)
   193  							for _, ref := range rel.References {
   194  								if ref.OwnPrimaryKey {
   195  									pv, _ := ref.PrimaryKey.ValueOf(v)
   196  									ref.ForeignKey.Set(elem, pv)
   197  								} else if ref.PrimaryValue != "" {
   198  									ref.ForeignKey.Set(elem, ref.PrimaryValue)
   199  								}
   200  							}
   201  
   202  							relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
   203  							for _, pf := range rel.FieldSchema.PrimaryFields {
   204  								if pfv, ok := pf.ValueOf(elem); !ok {
   205  									relPrimaryValues = append(relPrimaryValues, pfv)
   206  								}
   207  							}
   208  
   209  							cacheKey := utils.ToStringKey(relPrimaryValues)
   210  							if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
   211  								identityMap[cacheKey] = true
   212  								if isPtr {
   213  									elems = reflect.Append(elems, elem)
   214  								} else {
   215  									elems = reflect.Append(elems, elem.Addr())
   216  								}
   217  							}
   218  						}
   219  					}
   220  				}
   221  
   222  				switch db.Statement.ReflectValue.Kind() {
   223  				case reflect.Slice, reflect.Array:
   224  					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
   225  						obj := db.Statement.ReflectValue.Index(i)
   226  						if reflect.Indirect(obj).Kind() == reflect.Struct {
   227  							appendToElems(obj)
   228  						}
   229  					}
   230  				case reflect.Struct:
   231  					appendToElems(db.Statement.ReflectValue)
   232  				}
   233  
   234  				if elems.Len() > 0 {
   235  					assignmentColumns := make([]string, 0, len(rel.References))
   236  					for _, ref := range rel.References {
   237  						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
   238  					}
   239  
   240  					saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
   241  				}
   242  			}
   243  
   244  			// Save Many2Many associations
   245  			for _, rel := range db.Statement.Schema.Relationships.Many2Many {
   246  				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
   247  					continue
   248  				}
   249  
   250  				fieldType := rel.Field.IndirectFieldType.Elem()
   251  				isPtr := fieldType.Kind() == reflect.Ptr
   252  				if !isPtr {
   253  					fieldType = reflect.PtrTo(fieldType)
   254  				}
   255  				elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
   256  				joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
   257  				objs := []reflect.Value{}
   258  
   259  				appendToJoins := func(obj reflect.Value, elem reflect.Value) {
   260  					joinValue := reflect.New(rel.JoinTable.ModelType)
   261  					for _, ref := range rel.References {
   262  						if ref.OwnPrimaryKey {
   263  							fv, _ := ref.PrimaryKey.ValueOf(obj)
   264  							ref.ForeignKey.Set(joinValue, fv)
   265  						} else if ref.PrimaryValue != "" {
   266  							ref.ForeignKey.Set(joinValue, ref.PrimaryValue)
   267  						} else {
   268  							fv, _ := ref.PrimaryKey.ValueOf(elem)
   269  							ref.ForeignKey.Set(joinValue, fv)
   270  						}
   271  					}
   272  					joins = reflect.Append(joins, joinValue)
   273  				}
   274  
   275  				appendToElems := func(v reflect.Value) {
   276  					if _, zero := rel.Field.ValueOf(v); !zero {
   277  						f := reflect.Indirect(rel.Field.ReflectValueOf(v))
   278  
   279  						for i := 0; i < f.Len(); i++ {
   280  							elem := f.Index(i)
   281  
   282  							objs = append(objs, v)
   283  							if isPtr {
   284  								elems = reflect.Append(elems, elem)
   285  							} else {
   286  								elems = reflect.Append(elems, elem.Addr())
   287  							}
   288  						}
   289  					}
   290  				}
   291  
   292  				switch db.Statement.ReflectValue.Kind() {
   293  				case reflect.Slice, reflect.Array:
   294  					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
   295  						obj := db.Statement.ReflectValue.Index(i)
   296  						if reflect.Indirect(obj).Kind() == reflect.Struct {
   297  							appendToElems(obj)
   298  						}
   299  					}
   300  				case reflect.Struct:
   301  					appendToElems(db.Statement.ReflectValue)
   302  				}
   303  
   304  				// optimize elems of reflect value length
   305  				if elemLen := elems.Len(); elemLen > 0 {
   306  					if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
   307  						saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
   308  					}
   309  
   310  					for i := 0; i < elemLen; i++ {
   311  						appendToJoins(objs[i], elems.Index(i))
   312  					}
   313  				}
   314  
   315  				if joins.Len() > 0 {
   316  					db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
   317  						SkipHooks:                db.Statement.SkipHooks,
   318  						DisableNestedTransaction: true,
   319  					}).Create(joins.Interface()).Error)
   320  				}
   321  			}
   322  		}
   323  	}
   324  }
   325  
   326  func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
   327  	if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
   328  		onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
   329  		for _, dbName := range s.PrimaryFieldDBNames {
   330  			onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})
   331  		}
   332  
   333  		onConflict.UpdateAll = stmt.DB.FullSaveAssociations
   334  		if !onConflict.UpdateAll {
   335  			onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
   336  		}
   337  	} else {
   338  		onConflict.DoNothing = true
   339  	}
   340  
   341  	return
   342  }
   343  
   344  func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
   345  	var (
   346  		selects, omits []string
   347  		onConflict     = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
   348  		refName        = rel.Name + "."
   349  	)
   350  
   351  	for name, ok := range selectColumns {
   352  		columnName := ""
   353  		if strings.HasPrefix(name, refName) {
   354  			columnName = strings.TrimPrefix(name, refName)
   355  		}
   356  
   357  		if columnName != "" {
   358  			if ok {
   359  				selects = append(selects, columnName)
   360  			} else {
   361  				omits = append(omits, columnName)
   362  			}
   363  		}
   364  	}
   365  
   366  	tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
   367  		FullSaveAssociations:     db.FullSaveAssociations,
   368  		SkipHooks:                db.Statement.SkipHooks,
   369  		DisableNestedTransaction: true,
   370  	})
   371  
   372  	db.Statement.Settings.Range(func(k, v interface{}) bool {
   373  		tx.Statement.Settings.Store(k, v)
   374  		return true
   375  	})
   376  
   377  	if tx.Statement.FullSaveAssociations {
   378  		tx = tx.Set("gorm:update_track_time", true)
   379  	}
   380  
   381  	if len(selects) > 0 {
   382  		tx = tx.Select(selects)
   383  	} else if restricted && len(omits) == 0 {
   384  		tx = tx.Omit(clause.Associations)
   385  	}
   386  
   387  	if len(omits) > 0 {
   388  		tx = tx.Omit(omits...)
   389  	}
   390  
   391  	return db.AddError(tx.Create(values).Error)
   392  }