github.com/systematiccaos/gorm@v1.22.6/association.go (about)

     1  package gorm
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"strings"
     7  
     8  	"github.com/systematiccaos/gorm/clause"
     9  	"github.com/systematiccaos/gorm/schema"
    10  	"github.com/systematiccaos/gorm/utils"
    11  )
    12  
    13  // Association Mode contains some helper methods to handle relationship things easily.
    14  type Association struct {
    15  	DB           *DB
    16  	Relationship *schema.Relationship
    17  	Error        error
    18  }
    19  
    20  func (db *DB) Association(column string) *Association {
    21  	association := &Association{DB: db}
    22  	table := db.Statement.Table
    23  
    24  	if err := db.Statement.Parse(db.Statement.Model); err == nil {
    25  		db.Statement.Table = table
    26  		association.Relationship = db.Statement.Schema.Relationships.Relations[column]
    27  
    28  		if association.Relationship == nil {
    29  			association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column)
    30  		}
    31  
    32  		db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
    33  		for db.Statement.ReflectValue.Kind() == reflect.Ptr {
    34  			db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
    35  		}
    36  	} else {
    37  		association.Error = err
    38  	}
    39  
    40  	return association
    41  }
    42  
    43  func (association *Association) Find(out interface{}, conds ...interface{}) error {
    44  	if association.Error == nil {
    45  		association.Error = association.buildCondition().Find(out, conds...).Error
    46  	}
    47  	return association.Error
    48  }
    49  
    50  func (association *Association) Append(values ...interface{}) error {
    51  	if association.Error == nil {
    52  		switch association.Relationship.Type {
    53  		case schema.HasOne, schema.BelongsTo:
    54  			if len(values) > 0 {
    55  				association.Error = association.Replace(values...)
    56  			}
    57  		default:
    58  			association.saveAssociation( /*clear*/ false, values...)
    59  		}
    60  	}
    61  
    62  	return association.Error
    63  }
    64  
    65  func (association *Association) Replace(values ...interface{}) error {
    66  	if association.Error == nil {
    67  		// save associations
    68  		if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
    69  			return association.Error
    70  		}
    71  
    72  		// set old associations's foreign key to null
    73  		reflectValue := association.DB.Statement.ReflectValue
    74  		rel := association.Relationship
    75  		switch rel.Type {
    76  		case schema.BelongsTo:
    77  			if len(values) == 0 {
    78  				updateMap := map[string]interface{}{}
    79  				switch reflectValue.Kind() {
    80  				case reflect.Slice, reflect.Array:
    81  					for i := 0; i < reflectValue.Len(); i++ {
    82  						association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
    83  					}
    84  				case reflect.Struct:
    85  					association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
    86  				}
    87  
    88  				for _, ref := range rel.References {
    89  					updateMap[ref.ForeignKey.DBName] = nil
    90  				}
    91  
    92  				association.Error = association.DB.UpdateColumns(updateMap).Error
    93  			}
    94  		case schema.HasOne, schema.HasMany:
    95  			var (
    96  				primaryFields []*schema.Field
    97  				foreignKeys   []string
    98  				updateMap     = map[string]interface{}{}
    99  				relValues     = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel})
   100  				modelValue    = reflect.New(rel.FieldSchema.ModelType).Interface()
   101  				tx            = association.DB.Model(modelValue)
   102  			)
   103  
   104  			if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
   105  				if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
   106  					tx.Not(clause.IN{Column: column, Values: values})
   107  				}
   108  			}
   109  
   110  			for _, ref := range rel.References {
   111  				if ref.OwnPrimaryKey {
   112  					primaryFields = append(primaryFields, ref.PrimaryKey)
   113  					foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
   114  					updateMap[ref.ForeignKey.DBName] = nil
   115  				} else if ref.PrimaryValue != "" {
   116  					tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
   117  				}
   118  			}
   119  
   120  			if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
   121  				column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
   122  				association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
   123  			}
   124  		case schema.Many2Many:
   125  			var (
   126  				primaryFields, relPrimaryFields     []*schema.Field
   127  				joinPrimaryKeys, joinRelPrimaryKeys []string
   128  				modelValue                          = reflect.New(rel.JoinTable.ModelType).Interface()
   129  				tx                                  = association.DB.Model(modelValue)
   130  			)
   131  
   132  			for _, ref := range rel.References {
   133  				if ref.PrimaryValue == "" {
   134  					if ref.OwnPrimaryKey {
   135  						primaryFields = append(primaryFields, ref.PrimaryKey)
   136  						joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
   137  					} else {
   138  						relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
   139  						joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
   140  					}
   141  				} else {
   142  					tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
   143  				}
   144  			}
   145  
   146  			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
   147  			if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
   148  				tx.Where(clause.IN{Column: column, Values: values})
   149  			} else {
   150  				return ErrPrimaryKeyRequired
   151  			}
   152  
   153  			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
   154  			if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
   155  				tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
   156  			}
   157  
   158  			association.Error = tx.Delete(modelValue).Error
   159  		}
   160  	}
   161  	return association.Error
   162  }
   163  
   164  func (association *Association) Delete(values ...interface{}) error {
   165  	if association.Error == nil {
   166  		var (
   167  			reflectValue  = association.DB.Statement.ReflectValue
   168  			rel           = association.Relationship
   169  			primaryFields []*schema.Field
   170  			foreignKeys   []string
   171  			updateAttrs   = map[string]interface{}{}
   172  			conds         []clause.Expression
   173  		)
   174  
   175  		for _, ref := range rel.References {
   176  			if ref.PrimaryValue == "" {
   177  				primaryFields = append(primaryFields, ref.PrimaryKey)
   178  				foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
   179  				updateAttrs[ref.ForeignKey.DBName] = nil
   180  			} else {
   181  				conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
   182  			}
   183  		}
   184  
   185  		switch rel.Type {
   186  		case schema.BelongsTo:
   187  			tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())
   188  
   189  			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
   190  			pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
   191  			conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
   192  
   193  			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
   194  			relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
   195  			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
   196  
   197  			association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
   198  		case schema.HasOne, schema.HasMany:
   199  			tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())
   200  
   201  			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
   202  			pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
   203  			conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
   204  
   205  			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
   206  			relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
   207  			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
   208  
   209  			association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
   210  		case schema.Many2Many:
   211  			var (
   212  				primaryFields, relPrimaryFields     []*schema.Field
   213  				joinPrimaryKeys, joinRelPrimaryKeys []string
   214  				joinValue                           = reflect.New(rel.JoinTable.ModelType).Interface()
   215  			)
   216  
   217  			for _, ref := range rel.References {
   218  				if ref.PrimaryValue == "" {
   219  					if ref.OwnPrimaryKey {
   220  						primaryFields = append(primaryFields, ref.PrimaryKey)
   221  						joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
   222  					} else {
   223  						relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
   224  						joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
   225  					}
   226  				} else {
   227  					conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
   228  				}
   229  			}
   230  
   231  			_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
   232  			pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
   233  			conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
   234  
   235  			_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
   236  			relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
   237  			conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
   238  
   239  			association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
   240  		}
   241  
   242  		if association.Error == nil {
   243  			// clean up deleted values's foreign key
   244  			relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
   245  
   246  			cleanUpDeletedRelations := func(data reflect.Value) {
   247  				if _, zero := rel.Field.ValueOf(data); !zero {
   248  					fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data))
   249  					primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
   250  
   251  					switch fieldValue.Kind() {
   252  					case reflect.Slice, reflect.Array:
   253  						validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
   254  						for i := 0; i < fieldValue.Len(); i++ {
   255  							for idx, field := range rel.FieldSchema.PrimaryFields {
   256  								primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i))
   257  							}
   258  
   259  							if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
   260  								validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i))
   261  							}
   262  						}
   263  
   264  						association.Error = rel.Field.Set(data, validFieldValues.Interface())
   265  					case reflect.Struct:
   266  						for idx, field := range rel.FieldSchema.PrimaryFields {
   267  							primaryValues[idx], _ = field.ValueOf(fieldValue)
   268  						}
   269  
   270  						if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
   271  							if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
   272  								break
   273  							}
   274  
   275  							if rel.JoinTable == nil {
   276  								for _, ref := range rel.References {
   277  									if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
   278  										association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
   279  									} else {
   280  										association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
   281  									}
   282  								}
   283  							}
   284  						}
   285  					}
   286  				}
   287  			}
   288  
   289  			switch reflectValue.Kind() {
   290  			case reflect.Slice, reflect.Array:
   291  				for i := 0; i < reflectValue.Len(); i++ {
   292  					cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i)))
   293  				}
   294  			case reflect.Struct:
   295  				cleanUpDeletedRelations(reflectValue)
   296  			}
   297  		}
   298  	}
   299  
   300  	return association.Error
   301  }
   302  
   303  func (association *Association) Clear() error {
   304  	return association.Replace()
   305  }
   306  
   307  func (association *Association) Count() (count int64) {
   308  	if association.Error == nil {
   309  		association.Error = association.buildCondition().Count(&count).Error
   310  	}
   311  	return
   312  }
   313  
   314  type assignBack struct {
   315  	Source reflect.Value
   316  	Index  int
   317  	Dest   reflect.Value
   318  }
   319  
   320  func (association *Association) saveAssociation(clear bool, values ...interface{}) {
   321  	var (
   322  		reflectValue = association.DB.Statement.ReflectValue
   323  		assignBacks  []assignBack // assign association values back to arguments after save
   324  	)
   325  
   326  	appendToRelations := func(source, rv reflect.Value, clear bool) {
   327  		switch association.Relationship.Type {
   328  		case schema.HasOne, schema.BelongsTo:
   329  			switch rv.Kind() {
   330  			case reflect.Slice, reflect.Array:
   331  				if rv.Len() > 0 {
   332  					association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface())
   333  
   334  					if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
   335  						assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
   336  					}
   337  				}
   338  			case reflect.Struct:
   339  				association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface())
   340  
   341  				if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
   342  					assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
   343  				}
   344  			}
   345  		case schema.HasMany, schema.Many2Many:
   346  			elemType := association.Relationship.Field.IndirectFieldType.Elem()
   347  			fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source))
   348  			if clear {
   349  				fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem()
   350  			}
   351  
   352  			appendToFieldValues := func(ev reflect.Value) {
   353  				if ev.Type().AssignableTo(elemType) {
   354  					fieldValue = reflect.Append(fieldValue, ev)
   355  				} else if ev.Type().Elem().AssignableTo(elemType) {
   356  					fieldValue = reflect.Append(fieldValue, ev.Elem())
   357  				} else {
   358  					association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name)
   359  				}
   360  
   361  				if elemType.Kind() == reflect.Struct {
   362  					assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()})
   363  				}
   364  			}
   365  
   366  			switch rv.Kind() {
   367  			case reflect.Slice, reflect.Array:
   368  				for i := 0; i < rv.Len(); i++ {
   369  					appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
   370  				}
   371  			case reflect.Struct:
   372  				appendToFieldValues(rv.Addr())
   373  			}
   374  
   375  			if association.Error == nil {
   376  				association.Error = association.Relationship.Field.Set(source, fieldValue.Interface())
   377  			}
   378  		}
   379  	}
   380  
   381  	selectedSaveColumns := []string{association.Relationship.Name}
   382  	omitColumns := []string{}
   383  	selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false)
   384  	for name, ok := range selectColumns {
   385  		columnName := ""
   386  		if strings.HasPrefix(name, association.Relationship.Name) {
   387  			if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" {
   388  				columnName = name
   389  			}
   390  		} else if strings.HasPrefix(name, clause.Associations) {
   391  			columnName = name
   392  		}
   393  
   394  		if columnName != "" {
   395  			if ok {
   396  				selectedSaveColumns = append(selectedSaveColumns, columnName)
   397  			} else {
   398  				omitColumns = append(omitColumns, columnName)
   399  			}
   400  		}
   401  	}
   402  
   403  	for _, ref := range association.Relationship.References {
   404  		if !ref.OwnPrimaryKey {
   405  			selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
   406  		}
   407  	}
   408  
   409  	associationDB := association.DB.Session(&Session{}).Model(nil)
   410  	if !association.DB.FullSaveAssociations {
   411  		associationDB.Select(selectedSaveColumns)
   412  	}
   413  	if len(omitColumns) > 0 {
   414  		associationDB.Omit(omitColumns...)
   415  	}
   416  	associationDB = associationDB.Session(&Session{})
   417  
   418  	switch reflectValue.Kind() {
   419  	case reflect.Slice, reflect.Array:
   420  		if len(values) != reflectValue.Len() {
   421  			// clear old data
   422  			if clear && len(values) == 0 {
   423  				for i := 0; i < reflectValue.Len(); i++ {
   424  					if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
   425  						association.Error = err
   426  						break
   427  					}
   428  
   429  					if association.Relationship.JoinTable == nil {
   430  						for _, ref := range association.Relationship.References {
   431  							if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
   432  								if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
   433  									association.Error = err
   434  									break
   435  								}
   436  							}
   437  						}
   438  					}
   439  				}
   440  				break
   441  			}
   442  
   443  			association.Error = ErrInvalidValueOfLength
   444  			return
   445  		}
   446  
   447  		for i := 0; i < reflectValue.Len(); i++ {
   448  			appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
   449  
   450  			// TODO support save slice data, sql with case?
   451  			association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
   452  		}
   453  	case reflect.Struct:
   454  		// clear old data
   455  		if clear && len(values) == 0 {
   456  			association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
   457  
   458  			if association.Relationship.JoinTable == nil && association.Error == nil {
   459  				for _, ref := range association.Relationship.References {
   460  					if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
   461  						association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
   462  					}
   463  				}
   464  			}
   465  		}
   466  
   467  		for idx, value := range values {
   468  			rv := reflect.Indirect(reflect.ValueOf(value))
   469  			appendToRelations(reflectValue, rv, clear && idx == 0)
   470  		}
   471  
   472  		if len(values) > 0 {
   473  			association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error
   474  		}
   475  	}
   476  
   477  	for _, assignBack := range assignBacks {
   478  		fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source))
   479  		if assignBack.Index > 0 {
   480  			reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
   481  		} else {
   482  			reflect.Indirect(assignBack.Dest).Set(fieldValue)
   483  		}
   484  	}
   485  }
   486  
   487  func (association *Association) buildCondition() *DB {
   488  	var (
   489  		queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue)
   490  		modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
   491  		tx         = association.DB.Model(modelValue)
   492  	)
   493  
   494  	if association.Relationship.JoinTable != nil {
   495  		if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
   496  			joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
   497  			for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
   498  				joinStmt.AddClause(queryClause)
   499  			}
   500  			joinStmt.Build("WHERE")
   501  			tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
   502  		}
   503  
   504  		tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
   505  			Table: clause.Table{Name: association.Relationship.JoinTable.Table},
   506  			ON:    clause.Where{Exprs: queryConds},
   507  		}}})
   508  	} else {
   509  		tx.Clauses(clause.Where{Exprs: queryConds})
   510  	}
   511  
   512  	return tx
   513  }