github.com/systematiccaos/gorm@v1.22.6/callbacks/associations.go (about) 1 package callbacks 2 3 import ( 4 "reflect" 5 "strings" 6 7 "github.com/systematiccaos/gorm" 8 "github.com/systematiccaos/gorm/clause" 9 "github.com/systematiccaos/gorm/schema" 10 "github.com/systematiccaos/gorm/utils" 11 ) 12 13 func SaveBeforeAssociations(create bool) func(db *gorm.DB) { 14 return func(db *gorm.DB) { 15 if db.Error == nil && db.Statement.Schema != nil { 16 selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) 17 18 // Save Belongs To associations 19 for _, rel := range db.Statement.Schema.Relationships.BelongsTo { 20 if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { 21 continue 22 } 23 24 setupReferences := func(obj reflect.Value, elem reflect.Value) { 25 for _, ref := range rel.References { 26 if !ref.OwnPrimaryKey { 27 pv, _ := ref.PrimaryKey.ValueOf(elem) 28 db.AddError(ref.ForeignKey.Set(obj, pv)) 29 30 if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { 31 dest[ref.ForeignKey.DBName] = pv 32 if _, ok := dest[rel.Name]; ok { 33 dest[rel.Name] = elem.Interface() 34 } 35 } 36 } 37 } 38 } 39 40 switch db.Statement.ReflectValue.Kind() { 41 case reflect.Slice, reflect.Array: 42 var ( 43 rValLen = db.Statement.ReflectValue.Len() 44 objs = make([]reflect.Value, 0, rValLen) 45 fieldType = rel.Field.FieldType 46 isPtr = fieldType.Kind() == reflect.Ptr 47 ) 48 49 if !isPtr { 50 fieldType = reflect.PtrTo(fieldType) 51 } 52 53 elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) 54 for i := 0; i < rValLen; i++ { 55 obj := db.Statement.ReflectValue.Index(i) 56 if reflect.Indirect(obj).Kind() != reflect.Struct { 57 break 58 } 59 60 if _, zero := rel.Field.ValueOf(obj); !zero { // check belongs to relation value 61 rv := rel.Field.ReflectValueOf(obj) // relation reflect value 62 objs = append(objs, obj) 63 if isPtr { 64 elems = reflect.Append(elems, rv) 65 } else { 66 elems = reflect.Append(elems, rv.Addr()) 67 } 68 } 69 } 70 71 if elems.Len() > 0 { 72 if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { 73 for i := 0; i < elems.Len(); i++ { 74 setupReferences(objs[i], elems.Index(i)) 75 } 76 } 77 } 78 case reflect.Struct: 79 if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { 80 rv := rel.Field.ReflectValueOf(db.Statement.ReflectValue) // relation reflect value 81 if rv.Kind() != reflect.Ptr { 82 rv = rv.Addr() 83 } 84 85 if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { 86 setupReferences(db.Statement.ReflectValue, rv) 87 } 88 } 89 } 90 } 91 } 92 } 93 } 94 95 func SaveAfterAssociations(create bool) func(db *gorm.DB) { 96 return func(db *gorm.DB) { 97 if db.Error == nil && db.Statement.Schema != nil { 98 selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) 99 100 // Save Has One associations 101 for _, rel := range db.Statement.Schema.Relationships.HasOne { 102 if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { 103 continue 104 } 105 106 switch db.Statement.ReflectValue.Kind() { 107 case reflect.Slice, reflect.Array: 108 var ( 109 fieldType = rel.Field.FieldType 110 isPtr = fieldType.Kind() == reflect.Ptr 111 ) 112 113 if !isPtr { 114 fieldType = reflect.PtrTo(fieldType) 115 } 116 117 elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) 118 119 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { 120 obj := db.Statement.ReflectValue.Index(i) 121 122 if reflect.Indirect(obj).Kind() == reflect.Struct { 123 if _, zero := rel.Field.ValueOf(obj); !zero { 124 rv := rel.Field.ReflectValueOf(obj) 125 if rv.Kind() != reflect.Ptr { 126 rv = rv.Addr() 127 } 128 129 for _, ref := range rel.References { 130 if ref.OwnPrimaryKey { 131 fv, _ := ref.PrimaryKey.ValueOf(obj) 132 db.AddError(ref.ForeignKey.Set(rv, fv)) 133 } else if ref.PrimaryValue != "" { 134 db.AddError(ref.ForeignKey.Set(rv, ref.PrimaryValue)) 135 } 136 } 137 138 elems = reflect.Append(elems, rv) 139 } 140 } 141 } 142 143 if elems.Len() > 0 { 144 assignmentColumns := make([]string, 0, len(rel.References)) 145 for _, ref := range rel.References { 146 assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) 147 } 148 149 saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) 150 } 151 case reflect.Struct: 152 if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { 153 f := rel.Field.ReflectValueOf(db.Statement.ReflectValue) 154 if f.Kind() != reflect.Ptr { 155 f = f.Addr() 156 } 157 158 assignmentColumns := make([]string, 0, len(rel.References)) 159 for _, ref := range rel.References { 160 if ref.OwnPrimaryKey { 161 fv, _ := ref.PrimaryKey.ValueOf(db.Statement.ReflectValue) 162 ref.ForeignKey.Set(f, fv) 163 } else if ref.PrimaryValue != "" { 164 ref.ForeignKey.Set(f, ref.PrimaryValue) 165 } 166 assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) 167 } 168 169 saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) 170 } 171 } 172 } 173 174 // Save Has Many associations 175 for _, rel := range db.Statement.Schema.Relationships.HasMany { 176 if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { 177 continue 178 } 179 180 fieldType := rel.Field.IndirectFieldType.Elem() 181 isPtr := fieldType.Kind() == reflect.Ptr 182 if !isPtr { 183 fieldType = reflect.PtrTo(fieldType) 184 } 185 elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) 186 identityMap := map[string]bool{} 187 appendToElems := func(v reflect.Value) { 188 if _, zero := rel.Field.ValueOf(v); !zero { 189 f := reflect.Indirect(rel.Field.ReflectValueOf(v)) 190 191 for i := 0; i < f.Len(); i++ { 192 elem := f.Index(i) 193 for _, ref := range rel.References { 194 if ref.OwnPrimaryKey { 195 pv, _ := ref.PrimaryKey.ValueOf(v) 196 ref.ForeignKey.Set(elem, pv) 197 } else if ref.PrimaryValue != "" { 198 ref.ForeignKey.Set(elem, ref.PrimaryValue) 199 } 200 } 201 202 relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) 203 for _, pf := range rel.FieldSchema.PrimaryFields { 204 if pfv, ok := pf.ValueOf(elem); !ok { 205 relPrimaryValues = append(relPrimaryValues, pfv) 206 } 207 } 208 209 cacheKey := utils.ToStringKey(relPrimaryValues) 210 if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { 211 identityMap[cacheKey] = true 212 if isPtr { 213 elems = reflect.Append(elems, elem) 214 } else { 215 elems = reflect.Append(elems, elem.Addr()) 216 } 217 } 218 } 219 } 220 } 221 222 switch db.Statement.ReflectValue.Kind() { 223 case reflect.Slice, reflect.Array: 224 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { 225 obj := db.Statement.ReflectValue.Index(i) 226 if reflect.Indirect(obj).Kind() == reflect.Struct { 227 appendToElems(obj) 228 } 229 } 230 case reflect.Struct: 231 appendToElems(db.Statement.ReflectValue) 232 } 233 234 if elems.Len() > 0 { 235 assignmentColumns := make([]string, 0, len(rel.References)) 236 for _, ref := range rel.References { 237 assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) 238 } 239 240 saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) 241 } 242 } 243 244 // Save Many2Many associations 245 for _, rel := range db.Statement.Schema.Relationships.Many2Many { 246 if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { 247 continue 248 } 249 250 fieldType := rel.Field.IndirectFieldType.Elem() 251 isPtr := fieldType.Kind() == reflect.Ptr 252 if !isPtr { 253 fieldType = reflect.PtrTo(fieldType) 254 } 255 elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) 256 joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10) 257 objs := []reflect.Value{} 258 259 appendToJoins := func(obj reflect.Value, elem reflect.Value) { 260 joinValue := reflect.New(rel.JoinTable.ModelType) 261 for _, ref := range rel.References { 262 if ref.OwnPrimaryKey { 263 fv, _ := ref.PrimaryKey.ValueOf(obj) 264 ref.ForeignKey.Set(joinValue, fv) 265 } else if ref.PrimaryValue != "" { 266 ref.ForeignKey.Set(joinValue, ref.PrimaryValue) 267 } else { 268 fv, _ := ref.PrimaryKey.ValueOf(elem) 269 ref.ForeignKey.Set(joinValue, fv) 270 } 271 } 272 joins = reflect.Append(joins, joinValue) 273 } 274 275 appendToElems := func(v reflect.Value) { 276 if _, zero := rel.Field.ValueOf(v); !zero { 277 f := reflect.Indirect(rel.Field.ReflectValueOf(v)) 278 279 for i := 0; i < f.Len(); i++ { 280 elem := f.Index(i) 281 282 objs = append(objs, v) 283 if isPtr { 284 elems = reflect.Append(elems, elem) 285 } else { 286 elems = reflect.Append(elems, elem.Addr()) 287 } 288 } 289 } 290 } 291 292 switch db.Statement.ReflectValue.Kind() { 293 case reflect.Slice, reflect.Array: 294 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { 295 obj := db.Statement.ReflectValue.Index(i) 296 if reflect.Indirect(obj).Kind() == reflect.Struct { 297 appendToElems(obj) 298 } 299 } 300 case reflect.Struct: 301 appendToElems(db.Statement.ReflectValue) 302 } 303 304 // optimize elems of reflect value length 305 if elemLen := elems.Len(); elemLen > 0 { 306 if v, ok := selectColumns[rel.Name+".*"]; !ok || v { 307 saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) 308 } 309 310 for i := 0; i < elemLen; i++ { 311 appendToJoins(objs[i], elems.Index(i)) 312 } 313 } 314 315 if joins.Len() > 0 { 316 db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{ 317 SkipHooks: db.Statement.SkipHooks, 318 DisableNestedTransaction: true, 319 }).Create(joins.Interface()).Error) 320 } 321 } 322 } 323 } 324 } 325 326 func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { 327 if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { 328 onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) 329 for _, dbName := range s.PrimaryFieldDBNames { 330 onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName}) 331 } 332 333 onConflict.UpdateAll = stmt.DB.FullSaveAssociations 334 if !onConflict.UpdateAll { 335 onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns) 336 } 337 } else { 338 onConflict.DoNothing = true 339 } 340 341 return 342 } 343 344 func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { 345 var ( 346 selects, omits []string 347 onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) 348 refName = rel.Name + "." 349 ) 350 351 for name, ok := range selectColumns { 352 columnName := "" 353 if strings.HasPrefix(name, refName) { 354 columnName = strings.TrimPrefix(name, refName) 355 } 356 357 if columnName != "" { 358 if ok { 359 selects = append(selects, columnName) 360 } else { 361 omits = append(omits, columnName) 362 } 363 } 364 } 365 366 tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{ 367 FullSaveAssociations: db.FullSaveAssociations, 368 SkipHooks: db.Statement.SkipHooks, 369 DisableNestedTransaction: true, 370 }) 371 372 db.Statement.Settings.Range(func(k, v interface{}) bool { 373 tx.Statement.Settings.Store(k, v) 374 return true 375 }) 376 377 if tx.Statement.FullSaveAssociations { 378 tx = tx.Set("gorm:update_track_time", true) 379 } 380 381 if len(selects) > 0 { 382 tx = tx.Select(selects) 383 } else if restricted && len(omits) == 0 { 384 tx = tx.Omit(clause.Associations) 385 } 386 387 if len(omits) > 0 { 388 tx = tx.Omit(omits...) 389 } 390 391 return db.AddError(tx.Create(values).Error) 392 }