github.com/rjgonzale/pop/v5@v5.1.3-dev/preload_associations.go (about) 1 package pop 2 3 import ( 4 "fmt" 5 "reflect" 6 "regexp" 7 "strings" 8 9 "github.com/gobuffalo/flect" 10 "github.com/gobuffalo/pop/v5/internal/defaults" 11 "github.com/gobuffalo/pop/v5/logging" 12 "github.com/jmoiron/sqlx" 13 "github.com/jmoiron/sqlx/reflectx" 14 ) 15 16 var validFieldRegexp = regexp.MustCompile(`^(([a-zA-Z0-9]*)(\.[a-zA-Z0-9]+)?)+$`) 17 18 // NewModelMetaInfo creates the meta info details for the model passed 19 // as a parameter. 20 func NewModelMetaInfo(model *Model) *ModelMetaInfo { 21 mmi := &ModelMetaInfo{} 22 mmi.Model = model 23 mmi.init() 24 return mmi 25 } 26 27 // NewAssociationMetaInfo creates the meta info details for the passed association. 28 func NewAssociationMetaInfo(fi *reflectx.FieldInfo) *AssociationMetaInfo { 29 ami := &AssociationMetaInfo{} 30 ami.FieldInfo = fi 31 ami.init() 32 return ami 33 } 34 35 // ModelMetaInfo a type to abstract all fields information regarding 36 // to a model. A model is representation of a table in the 37 // database. 38 type ModelMetaInfo struct { 39 *reflectx.StructMap 40 Model *Model 41 mapper *reflectx.Mapper 42 nestedFields map[string]string 43 } 44 45 func (mmi *ModelMetaInfo) init() { 46 m := reflectx.NewMapper("") 47 mmi.mapper = m 48 49 t := reflectx.Deref(reflect.TypeOf(mmi.Model.Value)) 50 if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { 51 t = reflectx.Deref(t.Elem()) 52 } 53 54 mmi.StructMap = m.TypeMap(t) 55 mmi.nestedFields = make(map[string]string) 56 } 57 58 func (mmi *ModelMetaInfo) iterate(fn func(reflect.Value)) { 59 modelValue := reflect.Indirect(reflect.ValueOf(mmi.Model.Value)) 60 if modelValue.Kind() == reflect.Slice || modelValue.Kind() == reflect.Array { 61 for i := 0; i < modelValue.Len(); i++ { 62 fn(modelValue.Index(i)) 63 } 64 return 65 } 66 fn(modelValue) 67 } 68 69 func (mmi *ModelMetaInfo) getDBFieldTaggedWith(value string) *reflectx.FieldInfo { 70 for _, fi := range mmi.Index { 71 if fi.Field.Tag.Get("db") == value { 72 if len(fi.Children) > 0 { 73 return fi.Children[0] 74 } 75 return fi 76 } 77 } 78 return nil 79 } 80 81 func (mmi *ModelMetaInfo) preloadFields(fields ...string) ([]*reflectx.FieldInfo, error) { 82 if len(fields) == 0 { 83 return mmi.Index, nil 84 } 85 86 var preloadFields []*reflectx.FieldInfo 87 for _, f := range fields { 88 if !validFieldRegexp.MatchString(f) { 89 return preloadFields, fmt.Errorf("association field '%s' does not match the format %s", f, "'<field>' or '<field>.<nested-field>'") 90 } 91 if strings.Contains(f, ".") { 92 mmi.nestedFields[f[:strings.Index(f, ".")]] = f[strings.Index(f, ".")+1:] 93 f = f[:strings.Index(f, ".")] 94 } 95 96 preloadField := mmi.GetByPath(f) 97 if preloadField == nil { 98 return preloadFields, fmt.Errorf("field %s does not exist in model %s", f, mmi.Model.TableName()) 99 } 100 101 var exist bool 102 for _, pf := range preloadFields { 103 if pf.Path == preloadField.Path { 104 exist = true 105 } 106 } 107 if !exist { 108 preloadFields = append(preloadFields, preloadField) 109 } 110 } 111 return preloadFields, nil 112 } 113 114 // AssociationMetaInfo a type to abstract all field information 115 // regarding to an association. An association is a field 116 // that has defined a tag like 'has_many', 'belongs_to', 117 // 'many_to_many' and 'has_one'. 118 type AssociationMetaInfo struct { 119 *reflectx.FieldInfo 120 *reflectx.StructMap 121 } 122 123 func (ami *AssociationMetaInfo) init() { 124 mapper := reflectx.NewMapper("") 125 t := reflectx.Deref(ami.FieldInfo.Field.Type) 126 if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { 127 t = reflectx.Deref(t.Elem()) 128 } 129 130 ami.StructMap = mapper.TypeMap(t) 131 } 132 133 func (ami *AssociationMetaInfo) toSlice() reflect.Value { 134 ft := reflectx.Deref(ami.Field.Type) 135 var vt reflect.Value 136 if ft.Kind() == reflect.Slice || ft.Kind() == reflect.Array { 137 vt = reflect.New(ft) 138 } else { 139 vt = reflect.New(reflect.SliceOf(ft)) 140 } 141 return vt 142 } 143 144 func (ami *AssociationMetaInfo) getDBFieldTaggedWith(value string) *reflectx.FieldInfo { 145 for _, fi := range ami.StructMap.Index { 146 if fi.Field.Tag.Get("db") == value { 147 if len(fi.Children) > 0 { 148 return fi.Children[0] 149 } 150 return fi 151 } 152 } 153 return nil 154 } 155 156 func (ami *AssociationMetaInfo) fkName() string { 157 t := ami.Field.Type 158 if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { 159 t = reflectx.Deref(t.Elem()) 160 } 161 fkName := fmt.Sprintf("%s%s", flect.Underscore(flect.Singularize(t.Name())), "_id") 162 fkNameTag := flect.Underscore(ami.Field.Tag.Get("fk_id")) 163 return defaults.String(fkNameTag, fkName) 164 } 165 166 // preload is the query mode used to load associations from database 167 // similar to the active record default approach on Rails. 168 func preload(tx *Connection, model interface{}, fields ...string) error { 169 mmi := NewModelMetaInfo(&Model{Value: model}) 170 171 preloadFields, err := mmi.preloadFields(fields...) 172 if err != nil { 173 return err 174 } 175 176 var associations []*AssociationMetaInfo 177 for _, fieldInfo := range preloadFields { 178 if isFieldAssociation(fieldInfo.Field) && fieldInfo.Parent.Name == "" { 179 associations = append(associations, NewAssociationMetaInfo(fieldInfo)) 180 } 181 } 182 183 for _, asoc := range associations { 184 if asoc.Field.Tag.Get("has_many") != "" { 185 err := preloadHasMany(tx, asoc, mmi) 186 if err != nil { 187 return err 188 } 189 } 190 191 if asoc.Field.Tag.Get("has_one") != "" { 192 err := preloadHasOne(tx, asoc, mmi) 193 if err != nil { 194 return err 195 } 196 } 197 198 if asoc.Field.Tag.Get("belongs_to") != "" { 199 err := preloadBelongsTo(tx, asoc, mmi) 200 if err != nil { 201 return err 202 } 203 } 204 205 if asoc.Field.Tag.Get("many_to_many") != "" { 206 err := preloadManyToMany(tx, asoc, mmi) 207 if err != nil { 208 return err 209 } 210 } 211 } 212 return nil 213 } 214 215 func isFieldAssociation(field reflect.StructField) bool { 216 for _, associationLabel := range []string{"has_many", "has_one", "belongs_to", "many_to_many"} { 217 if field.Tag.Get(associationLabel) != "" { 218 return true 219 } 220 } 221 return false 222 } 223 224 func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { 225 // 1) get all associations ids. 226 // 1.1) In here I pick ids from model meta info directly. 227 ids := []interface{}{} 228 mmi.Model.iterate(func(m *Model) error { 229 ids = append(ids, m.ID()) 230 return nil 231 }) 232 233 if len(ids) == 0 { 234 return nil 235 } 236 237 // 2) load all associations constraint by model ids. 238 fk := asoc.Field.Tag.Get("fk_id") 239 if fk == "" { 240 fk = mmi.Model.associationName() 241 } 242 243 q := tx.Q() 244 q.eager = false 245 q.eagerFields = []string{} 246 247 slice := asoc.toSlice() 248 249 if strings.TrimSpace(asoc.Field.Tag.Get("order_by")) != "" { 250 q.Order(asoc.Field.Tag.Get("order_by")) 251 } 252 253 err := q.Where(fmt.Sprintf("%s in (?)", fk), ids).All(slice.Interface()) 254 if err != nil { 255 return err 256 } 257 258 // 2.1) load all nested associations from this assoc. 259 if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok { 260 if err := preload(tx, slice.Interface(), asocNestedFields); err != nil { 261 return err 262 } 263 } 264 265 // 3) iterate over every model and fill it with the assoc. 266 foreignField := asoc.getDBFieldTaggedWith(fk) 267 mmi.iterate(func(mvalue reflect.Value) { 268 modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) 269 for i := 0; i < slice.Elem().Len(); i++ { 270 asocValue := slice.Elem().Index(i) 271 if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() || 272 reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) { 273 274 switch { 275 case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: 276 modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) 277 case modelAssociationField.Kind() == reflect.Ptr: 278 modelAssociationField.Elem().Set(reflect.Append(modelAssociationField.Elem(), asocValue)) 279 default: 280 modelAssociationField.Set(asocValue) 281 } 282 } 283 } 284 }) 285 286 return nil 287 } 288 289 func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { 290 // 1) get all associations ids. 291 ids := []interface{}{} 292 mmi.Model.iterate(func(m *Model) error { 293 ids = append(ids, m.ID()) 294 return nil 295 }) 296 297 if len(ids) == 0 { 298 return nil 299 } 300 301 // 2) load all associations constraint by model ids. 302 fk := asoc.Field.Tag.Get("fk_id") 303 if fk == "" { 304 fk = mmi.Model.associationName() 305 } 306 307 q := tx.Q() 308 q.eager = false 309 q.eagerFields = []string{} 310 311 slice := asoc.toSlice() 312 err := q.Where(fmt.Sprintf("%s in (?)", fk), ids).All(slice.Interface()) 313 if err != nil { 314 return err 315 } 316 317 // 2.1) load all nested associations from this assoc. 318 if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok { 319 if err := preload(tx, slice.Interface(), asocNestedFields); err != nil { 320 return err 321 } 322 } 323 324 // 3) iterate over every model and fill it with the assoc. 325 foreignField := asoc.getDBFieldTaggedWith(fk) 326 mmi.iterate(func(mvalue reflect.Value) { 327 modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) 328 for i := 0; i < slice.Elem().Len(); i++ { 329 asocValue := slice.Elem().Index(i) 330 if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() || 331 reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) { 332 if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array { 333 modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) 334 continue 335 } 336 modelAssociationField.Set(asocValue) 337 } 338 } 339 }) 340 341 return nil 342 } 343 344 func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { 345 // 1) get all associations ids. 346 fi := mmi.getDBFieldTaggedWith(asoc.fkName()) 347 if fi == nil { 348 fi = mmi.getDBFieldTaggedWith(fmt.Sprintf("%s%s", flect.Underscore(asoc.Path), "_id")) 349 } 350 351 fkids := []interface{}{} 352 mmi.iterate(func(val reflect.Value) { 353 fkids = append(fkids, mmi.mapper.FieldByName(val, fi.Path).Interface()) 354 }) 355 356 if len(fkids) == 0 { 357 return nil 358 } 359 360 // 2) load all associations constraint by association fields ids. 361 fk := "id" 362 363 q := tx.Q() 364 q.eager = false 365 q.eagerFields = []string{} 366 367 slice := asoc.toSlice() 368 err := q.Where(fmt.Sprintf("%s in (?)", fk), fkids).All(slice.Interface()) 369 if err != nil { 370 return err 371 } 372 373 // 2.1) load all nested associations from this assoc. 374 if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok { 375 if err := preload(tx, slice.Interface(), asocNestedFields); err != nil { 376 return err 377 } 378 } 379 380 // 3) iterate over every model and fill it with the assoc. 381 mmi.iterate(func(mvalue reflect.Value) { 382 modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) 383 for i := 0; i < slice.Elem().Len(); i++ { 384 asocValue := slice.Elem().Index(i) 385 if mmi.mapper.FieldByName(mvalue, fi.Path).Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() || 386 reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, fi.Path), mmi.mapper.FieldByName(asocValue, "ID")) { 387 388 switch { 389 case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: 390 modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) 391 case modelAssociationField.Kind() == reflect.Ptr: 392 modelAssociationField.Elem().Set(asocValue) 393 default: 394 modelAssociationField.Set(asocValue) 395 } 396 } 397 } 398 }) 399 400 return nil 401 } 402 403 func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { 404 // 1) get all associations ids. 405 // 1.1) In here I pick ids from model meta info directly. 406 ids := []interface{}{} 407 mmi.Model.iterate(func(m *Model) error { 408 ids = append(ids, m.ID()) 409 return nil 410 }) 411 412 if len(ids) == 0 { 413 return nil 414 } 415 416 // 2) load all associations. 417 // 2.1) In here I pick the label name from association. 418 manyToManyTableName := asoc.Field.Tag.Get("many_to_many") 419 modelAssociationName := mmi.Model.associationName() 420 assocFkName := asoc.fkName() 421 422 if strings.Contains(manyToManyTableName, ":") { 423 modelAssociationName = strings.TrimSpace(manyToManyTableName[strings.Index(manyToManyTableName, ":")+1:]) 424 manyToManyTableName = strings.TrimSpace(manyToManyTableName[:strings.Index(manyToManyTableName, ":")]) 425 } 426 427 if tx.TX != nil { 428 sql := fmt.Sprintf("SELECT %s, %s FROM %s WHERE %s in (?)", modelAssociationName, assocFkName, manyToManyTableName, modelAssociationName) 429 sql, args, _ := sqlx.In(sql, ids) 430 sql = tx.Dialect.TranslateSQL(sql) 431 log(logging.SQL, sql, args...) 432 rows, err := tx.TX.Queryx(sql, args...) 433 if err != nil { 434 return err 435 } 436 437 mapAssoc := map[string][]interface{}{} 438 fkids := []interface{}{} 439 for rows.Next() { 440 row, err := rows.SliceScan() 441 if err != nil { 442 return err 443 } 444 if len(row) > 0 { 445 if _, ok := row[0].([]uint8); ok { // -> it's UUID 446 row[0] = string(row[0].([]uint8)) 447 } 448 if _, ok := row[1].([]uint8); ok { // -> it's UUID 449 row[1] = string(row[1].([]uint8)) 450 } 451 key := fmt.Sprintf("%v", row[0]) 452 mapAssoc[key] = append(mapAssoc[key], row[1]) 453 fkids = append(fkids, row[1]) 454 } 455 } 456 457 q := tx.Q() 458 q.eager = false 459 q.eagerFields = []string{} 460 461 if strings.TrimSpace(asoc.Field.Tag.Get("order_by")) != "" { 462 q.Order(asoc.Field.Tag.Get("order_by")) 463 } 464 465 slice := asoc.toSlice() 466 q.Where("id in (?)", fkids).All(slice.Interface()) 467 468 // 2.2) load all nested associations from this assoc. 469 if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok { 470 if err := preload(tx, slice.Interface(), asocNestedFields); err != nil { 471 return err 472 } 473 } 474 475 // 3) iterate over every model and fill it with the assoc. 476 mmi.iterate(func(mvalue reflect.Value) { 477 id := mmi.mapper.FieldByName(mvalue, "ID").Interface() 478 if assocFkIds, ok := mapAssoc[fmt.Sprintf("%v", id)]; ok { 479 modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) 480 for i := 0; i < slice.Elem().Len(); i++ { 481 asocValue := slice.Elem().Index(i) 482 for _, fkid := range assocFkIds { 483 if fmt.Sprintf("%v", fkid) == fmt.Sprintf("%v", mmi.mapper.FieldByName(asocValue, "ID").Interface()) { 484 modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) 485 } 486 } 487 } 488 } 489 }) 490 } 491 return nil 492 }