github.com/systematiccaos/gorm@v1.22.6/association.go (about) 1 package gorm 2 3 import ( 4 "fmt" 5 "reflect" 6 "strings" 7 8 "github.com/systematiccaos/gorm/clause" 9 "github.com/systematiccaos/gorm/schema" 10 "github.com/systematiccaos/gorm/utils" 11 ) 12 13 // Association Mode contains some helper methods to handle relationship things easily. 14 type Association struct { 15 DB *DB 16 Relationship *schema.Relationship 17 Error error 18 } 19 20 func (db *DB) Association(column string) *Association { 21 association := &Association{DB: db} 22 table := db.Statement.Table 23 24 if err := db.Statement.Parse(db.Statement.Model); err == nil { 25 db.Statement.Table = table 26 association.Relationship = db.Statement.Schema.Relationships.Relations[column] 27 28 if association.Relationship == nil { 29 association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column) 30 } 31 32 db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) 33 for db.Statement.ReflectValue.Kind() == reflect.Ptr { 34 db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() 35 } 36 } else { 37 association.Error = err 38 } 39 40 return association 41 } 42 43 func (association *Association) Find(out interface{}, conds ...interface{}) error { 44 if association.Error == nil { 45 association.Error = association.buildCondition().Find(out, conds...).Error 46 } 47 return association.Error 48 } 49 50 func (association *Association) Append(values ...interface{}) error { 51 if association.Error == nil { 52 switch association.Relationship.Type { 53 case schema.HasOne, schema.BelongsTo: 54 if len(values) > 0 { 55 association.Error = association.Replace(values...) 56 } 57 default: 58 association.saveAssociation( /*clear*/ false, values...) 59 } 60 } 61 62 return association.Error 63 } 64 65 func (association *Association) Replace(values ...interface{}) error { 66 if association.Error == nil { 67 // save associations 68 if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { 69 return association.Error 70 } 71 72 // set old associations's foreign key to null 73 reflectValue := association.DB.Statement.ReflectValue 74 rel := association.Relationship 75 switch rel.Type { 76 case schema.BelongsTo: 77 if len(values) == 0 { 78 updateMap := map[string]interface{}{} 79 switch reflectValue.Kind() { 80 case reflect.Slice, reflect.Array: 81 for i := 0; i < reflectValue.Len(); i++ { 82 association.Error = rel.Field.Set(reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()) 83 } 84 case reflect.Struct: 85 association.Error = rel.Field.Set(reflectValue, reflect.Zero(rel.Field.FieldType).Interface()) 86 } 87 88 for _, ref := range rel.References { 89 updateMap[ref.ForeignKey.DBName] = nil 90 } 91 92 association.Error = association.DB.UpdateColumns(updateMap).Error 93 } 94 case schema.HasOne, schema.HasMany: 95 var ( 96 primaryFields []*schema.Field 97 foreignKeys []string 98 updateMap = map[string]interface{}{} 99 relValues = schema.GetRelationsValues(reflectValue, []*schema.Relationship{rel}) 100 modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() 101 tx = association.DB.Model(modelValue) 102 ) 103 104 if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { 105 if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { 106 tx.Not(clause.IN{Column: column, Values: values}) 107 } 108 } 109 110 for _, ref := range rel.References { 111 if ref.OwnPrimaryKey { 112 primaryFields = append(primaryFields, ref.PrimaryKey) 113 foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) 114 updateMap[ref.ForeignKey.DBName] = nil 115 } else if ref.PrimaryValue != "" { 116 tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) 117 } 118 } 119 120 if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { 121 column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) 122 association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error 123 } 124 case schema.Many2Many: 125 var ( 126 primaryFields, relPrimaryFields []*schema.Field 127 joinPrimaryKeys, joinRelPrimaryKeys []string 128 modelValue = reflect.New(rel.JoinTable.ModelType).Interface() 129 tx = association.DB.Model(modelValue) 130 ) 131 132 for _, ref := range rel.References { 133 if ref.PrimaryValue == "" { 134 if ref.OwnPrimaryKey { 135 primaryFields = append(primaryFields, ref.PrimaryKey) 136 joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) 137 } else { 138 relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) 139 joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) 140 } 141 } else { 142 tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) 143 } 144 } 145 146 _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) 147 if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { 148 tx.Where(clause.IN{Column: column, Values: values}) 149 } else { 150 return ErrPrimaryKeyRequired 151 } 152 153 _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) 154 if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { 155 tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) 156 } 157 158 association.Error = tx.Delete(modelValue).Error 159 } 160 } 161 return association.Error 162 } 163 164 func (association *Association) Delete(values ...interface{}) error { 165 if association.Error == nil { 166 var ( 167 reflectValue = association.DB.Statement.ReflectValue 168 rel = association.Relationship 169 primaryFields []*schema.Field 170 foreignKeys []string 171 updateAttrs = map[string]interface{}{} 172 conds []clause.Expression 173 ) 174 175 for _, ref := range rel.References { 176 if ref.PrimaryValue == "" { 177 primaryFields = append(primaryFields, ref.PrimaryKey) 178 foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) 179 updateAttrs[ref.ForeignKey.DBName] = nil 180 } else { 181 conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) 182 } 183 } 184 185 switch rel.Type { 186 case schema.BelongsTo: 187 tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) 188 189 _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields) 190 pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs) 191 conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) 192 193 _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields) 194 relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) 195 conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) 196 197 association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error 198 case schema.HasOne, schema.HasMany: 199 tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) 200 201 _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) 202 pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) 203 conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) 204 205 _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) 206 relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) 207 conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) 208 209 association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error 210 case schema.Many2Many: 211 var ( 212 primaryFields, relPrimaryFields []*schema.Field 213 joinPrimaryKeys, joinRelPrimaryKeys []string 214 joinValue = reflect.New(rel.JoinTable.ModelType).Interface() 215 ) 216 217 for _, ref := range rel.References { 218 if ref.PrimaryValue == "" { 219 if ref.OwnPrimaryKey { 220 primaryFields = append(primaryFields, ref.PrimaryKey) 221 joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) 222 } else { 223 relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) 224 joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) 225 } 226 } else { 227 conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) 228 } 229 } 230 231 _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields) 232 pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs) 233 conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) 234 235 _, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields) 236 relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) 237 conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) 238 239 association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error 240 } 241 242 if association.Error == nil { 243 // clean up deleted values's foreign key 244 relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields) 245 246 cleanUpDeletedRelations := func(data reflect.Value) { 247 if _, zero := rel.Field.ValueOf(data); !zero { 248 fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(data)) 249 primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) 250 251 switch fieldValue.Kind() { 252 case reflect.Slice, reflect.Array: 253 validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) 254 for i := 0; i < fieldValue.Len(); i++ { 255 for idx, field := range rel.FieldSchema.PrimaryFields { 256 primaryValues[idx], _ = field.ValueOf(fieldValue.Index(i)) 257 } 258 259 if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { 260 validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i)) 261 } 262 } 263 264 association.Error = rel.Field.Set(data, validFieldValues.Interface()) 265 case reflect.Struct: 266 for idx, field := range rel.FieldSchema.PrimaryFields { 267 primaryValues[idx], _ = field.ValueOf(fieldValue) 268 } 269 270 if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { 271 if association.Error = rel.Field.Set(data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { 272 break 273 } 274 275 if rel.JoinTable == nil { 276 for _, ref := range rel.References { 277 if ref.OwnPrimaryKey || ref.PrimaryValue != "" { 278 association.Error = ref.ForeignKey.Set(fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) 279 } else { 280 association.Error = ref.ForeignKey.Set(data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) 281 } 282 } 283 } 284 } 285 } 286 } 287 } 288 289 switch reflectValue.Kind() { 290 case reflect.Slice, reflect.Array: 291 for i := 0; i < reflectValue.Len(); i++ { 292 cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i))) 293 } 294 case reflect.Struct: 295 cleanUpDeletedRelations(reflectValue) 296 } 297 } 298 } 299 300 return association.Error 301 } 302 303 func (association *Association) Clear() error { 304 return association.Replace() 305 } 306 307 func (association *Association) Count() (count int64) { 308 if association.Error == nil { 309 association.Error = association.buildCondition().Count(&count).Error 310 } 311 return 312 } 313 314 type assignBack struct { 315 Source reflect.Value 316 Index int 317 Dest reflect.Value 318 } 319 320 func (association *Association) saveAssociation(clear bool, values ...interface{}) { 321 var ( 322 reflectValue = association.DB.Statement.ReflectValue 323 assignBacks []assignBack // assign association values back to arguments after save 324 ) 325 326 appendToRelations := func(source, rv reflect.Value, clear bool) { 327 switch association.Relationship.Type { 328 case schema.HasOne, schema.BelongsTo: 329 switch rv.Kind() { 330 case reflect.Slice, reflect.Array: 331 if rv.Len() > 0 { 332 association.Error = association.Relationship.Field.Set(source, rv.Index(0).Addr().Interface()) 333 334 if association.Relationship.Field.FieldType.Kind() == reflect.Struct { 335 assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) 336 } 337 } 338 case reflect.Struct: 339 association.Error = association.Relationship.Field.Set(source, rv.Addr().Interface()) 340 341 if association.Relationship.Field.FieldType.Kind() == reflect.Struct { 342 assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) 343 } 344 } 345 case schema.HasMany, schema.Many2Many: 346 elemType := association.Relationship.Field.IndirectFieldType.Elem() 347 fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(source)) 348 if clear { 349 fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() 350 } 351 352 appendToFieldValues := func(ev reflect.Value) { 353 if ev.Type().AssignableTo(elemType) { 354 fieldValue = reflect.Append(fieldValue, ev) 355 } else if ev.Type().Elem().AssignableTo(elemType) { 356 fieldValue = reflect.Append(fieldValue, ev.Elem()) 357 } else { 358 association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name) 359 } 360 361 if elemType.Kind() == reflect.Struct { 362 assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()}) 363 } 364 } 365 366 switch rv.Kind() { 367 case reflect.Slice, reflect.Array: 368 for i := 0; i < rv.Len(); i++ { 369 appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr()) 370 } 371 case reflect.Struct: 372 appendToFieldValues(rv.Addr()) 373 } 374 375 if association.Error == nil { 376 association.Error = association.Relationship.Field.Set(source, fieldValue.Interface()) 377 } 378 } 379 } 380 381 selectedSaveColumns := []string{association.Relationship.Name} 382 omitColumns := []string{} 383 selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false) 384 for name, ok := range selectColumns { 385 columnName := "" 386 if strings.HasPrefix(name, association.Relationship.Name) { 387 if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" { 388 columnName = name 389 } 390 } else if strings.HasPrefix(name, clause.Associations) { 391 columnName = name 392 } 393 394 if columnName != "" { 395 if ok { 396 selectedSaveColumns = append(selectedSaveColumns, columnName) 397 } else { 398 omitColumns = append(omitColumns, columnName) 399 } 400 } 401 } 402 403 for _, ref := range association.Relationship.References { 404 if !ref.OwnPrimaryKey { 405 selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) 406 } 407 } 408 409 associationDB := association.DB.Session(&Session{}).Model(nil) 410 if !association.DB.FullSaveAssociations { 411 associationDB.Select(selectedSaveColumns) 412 } 413 if len(omitColumns) > 0 { 414 associationDB.Omit(omitColumns...) 415 } 416 associationDB = associationDB.Session(&Session{}) 417 418 switch reflectValue.Kind() { 419 case reflect.Slice, reflect.Array: 420 if len(values) != reflectValue.Len() { 421 // clear old data 422 if clear && len(values) == 0 { 423 for i := 0; i < reflectValue.Len(); i++ { 424 if err := association.Relationship.Field.Set(reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { 425 association.Error = err 426 break 427 } 428 429 if association.Relationship.JoinTable == nil { 430 for _, ref := range association.Relationship.References { 431 if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { 432 if err := ref.ForeignKey.Set(reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { 433 association.Error = err 434 break 435 } 436 } 437 } 438 } 439 } 440 break 441 } 442 443 association.Error = ErrInvalidValueOfLength 444 return 445 } 446 447 for i := 0; i < reflectValue.Len(); i++ { 448 appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) 449 450 // TODO support save slice data, sql with case? 451 association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error 452 } 453 case reflect.Struct: 454 // clear old data 455 if clear && len(values) == 0 { 456 association.Error = association.Relationship.Field.Set(reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) 457 458 if association.Relationship.JoinTable == nil && association.Error == nil { 459 for _, ref := range association.Relationship.References { 460 if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { 461 association.Error = ref.ForeignKey.Set(reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) 462 } 463 } 464 } 465 } 466 467 for idx, value := range values { 468 rv := reflect.Indirect(reflect.ValueOf(value)) 469 appendToRelations(reflectValue, rv, clear && idx == 0) 470 } 471 472 if len(values) > 0 { 473 association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error 474 } 475 } 476 477 for _, assignBack := range assignBacks { 478 fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(assignBack.Source)) 479 if assignBack.Index > 0 { 480 reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) 481 } else { 482 reflect.Indirect(assignBack.Dest).Set(fieldValue) 483 } 484 } 485 } 486 487 func (association *Association) buildCondition() *DB { 488 var ( 489 queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.ReflectValue) 490 modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() 491 tx = association.DB.Model(modelValue) 492 ) 493 494 if association.Relationship.JoinTable != nil { 495 if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { 496 joinStmt := Statement{DB: tx, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} 497 for _, queryClause := range association.Relationship.JoinTable.QueryClauses { 498 joinStmt.AddClause(queryClause) 499 } 500 joinStmt.Build("WHERE") 501 tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) 502 } 503 504 tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ 505 Table: clause.Table{Name: association.Relationship.JoinTable.Table}, 506 ON: clause.Where{Exprs: queryConds}, 507 }}}) 508 } else { 509 tx.Clauses(clause.Where{Exprs: queryConds}) 510 } 511 512 return tx 513 }