github.com/dkishere/pop/v6@v6.103.1/preload_associations.go (about) 1 package pop 2 3 import ( 4 "fmt" 5 "reflect" 6 "regexp" 7 "strings" 8 9 "github.com/gobuffalo/flect" 10 "github.com/dkishere/pop/v6/internal/defaults" 11 "github.com/dkishere/pop/v6/logging" 12 "github.com/jmoiron/sqlx" 13 "github.com/jmoiron/sqlx/reflectx" 14 ) 15 16 var validFieldRegexp = regexp.MustCompile(`^(([a-zA-Z0-9]*)(\.[a-zA-Z0-9]+)?)+$`) 17 18 // NewModelMetaInfo creates the meta info details for the model passed 19 // as a parameter. 20 func NewModelMetaInfo(model *Model) *ModelMetaInfo { 21 mmi := &ModelMetaInfo{} 22 mmi.Model = model 23 mmi.init() 24 return mmi 25 } 26 27 // NewAssociationMetaInfo creates the meta info details for the passed association. 28 func NewAssociationMetaInfo(fi *reflectx.FieldInfo) *AssociationMetaInfo { 29 ami := &AssociationMetaInfo{} 30 ami.FieldInfo = fi 31 ami.init() 32 return ami 33 } 34 35 // ModelMetaInfo a type to abstract all fields information regarding 36 // to a model. A model is representation of a table in the 37 // database. 38 type ModelMetaInfo struct { 39 *reflectx.StructMap 40 Model *Model 41 mapper *reflectx.Mapper 42 nestedFields map[string][]string 43 } 44 45 func (mmi *ModelMetaInfo) init() { 46 m := reflectx.NewMapper("") 47 mmi.mapper = m 48 49 t := reflectx.Deref(reflect.TypeOf(mmi.Model.Value)) 50 if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { 51 t = reflectx.Deref(t.Elem()) 52 } 53 54 mmi.StructMap = m.TypeMap(t) 55 mmi.nestedFields = make(map[string][]string) 56 } 57 58 func (mmi *ModelMetaInfo) iterate(fn func(reflect.Value)) { 59 modelValue := reflect.Indirect(reflect.ValueOf(mmi.Model.Value)) 60 if modelValue.Kind() == reflect.Slice || modelValue.Kind() == reflect.Array { 61 for i := 0; i < modelValue.Len(); i++ { 62 fn(modelValue.Index(i)) 63 } 64 return 65 } 66 fn(modelValue) 67 } 68 69 func (mmi *ModelMetaInfo) getDBFieldTaggedWith(value string) *reflectx.FieldInfo { 70 for _, fi := range mmi.Index { 71 if fi.Field.Tag.Get("db") == value { 72 if len(fi.Children) > 0 { 73 return fi.Children[0] 74 } 75 return fi 76 } 77 } 78 return nil 79 } 80 81 func (mmi *ModelMetaInfo) preloadFields(fields ...string) ([]*reflectx.FieldInfo, error) { 82 if len(fields) == 0 { 83 return mmi.Index, nil 84 } 85 86 var preloadFields []*reflectx.FieldInfo 87 for _, f := range fields { 88 if !validFieldRegexp.MatchString(f) { 89 return preloadFields, fmt.Errorf("association field '%s' does not match the format %s", f, "'<field>' or '<field>.<nested-field>'") 90 } 91 if strings.Contains(f, ".") { 92 fname := f[:strings.Index(f, ".")] 93 mmi.nestedFields[fname] = append(mmi.nestedFields[fname], f[strings.Index(f, ".")+1:]) 94 f = f[:strings.Index(f, ".")] 95 } 96 97 preloadField := mmi.GetByPath(f) 98 if preloadField == nil { 99 return preloadFields, fmt.Errorf("field %s does not exist in model %s", f, mmi.Model.TableName()) 100 } 101 102 var exist bool 103 for _, pf := range preloadFields { 104 if pf.Path == preloadField.Path { 105 exist = true 106 } 107 } 108 if !exist { 109 preloadFields = append(preloadFields, preloadField) 110 } 111 } 112 return preloadFields, nil 113 } 114 115 // AssociationMetaInfo a type to abstract all field information 116 // regarding to an association. An association is a field 117 // that has defined a tag like 'has_many', 'belongs_to', 118 // 'many_to_many' and 'has_one'. 119 type AssociationMetaInfo struct { 120 *reflectx.FieldInfo 121 *reflectx.StructMap 122 } 123 124 func (ami *AssociationMetaInfo) init() { 125 mapper := reflectx.NewMapper("") 126 t := reflectx.Deref(ami.FieldInfo.Field.Type) 127 if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { 128 t = reflectx.Deref(t.Elem()) 129 } 130 131 ami.StructMap = mapper.TypeMap(t) 132 } 133 134 func (ami *AssociationMetaInfo) toSlice() reflect.Value { 135 ft := reflectx.Deref(ami.Field.Type) 136 var vt reflect.Value 137 if ft.Kind() == reflect.Slice || ft.Kind() == reflect.Array { 138 vt = reflect.New(ft) 139 } else { 140 vt = reflect.New(reflect.SliceOf(ft)) 141 } 142 return vt 143 } 144 145 func (ami *AssociationMetaInfo) getDBFieldTaggedWith(value string) *reflectx.FieldInfo { 146 for _, fi := range ami.StructMap.Index { 147 if fi.Field.Tag.Get("db") == value { 148 if len(fi.Children) > 0 { 149 return fi.Children[0] 150 } 151 return fi 152 } 153 } 154 return nil 155 } 156 157 func (ami *AssociationMetaInfo) fkName() string { 158 t := ami.Field.Type 159 if t.Kind() == reflect.Slice || t.Kind() == reflect.Array { 160 t = reflectx.Deref(t.Elem()) 161 } 162 fkName := fmt.Sprintf("%s%s", flect.Underscore(flect.Singularize(t.Name())), "_id") 163 fkNameTag := flect.Underscore(ami.Field.Tag.Get("fk_id")) 164 return defaults.String(fkNameTag, fkName) 165 } 166 167 // preload is the query mode used to load associations from database 168 // similar to the active record default approach on Rails. 169 func preload(tx *Connection, model interface{}, fields ...string) error { 170 mmi := NewModelMetaInfo(NewModel(model, tx.Context())) 171 172 preloadFields, err := mmi.preloadFields(fields...) 173 if err != nil { 174 return err 175 } 176 177 var associations []*AssociationMetaInfo 178 for _, fieldInfo := range preloadFields { 179 if isFieldAssociation(fieldInfo.Field) && fieldInfo.Parent.Name == "" { 180 associations = append(associations, NewAssociationMetaInfo(fieldInfo)) 181 } 182 } 183 184 for _, asoc := range associations { 185 if asoc.Field.Tag.Get("has_many") != "" { 186 err := preloadHasMany(tx, asoc, mmi) 187 if err != nil { 188 return err 189 } 190 } 191 192 if asoc.Field.Tag.Get("has_one") != "" { 193 err := preloadHasOne(tx, asoc, mmi) 194 if err != nil { 195 return err 196 } 197 } 198 199 if asoc.Field.Tag.Get("belongs_to") != "" { 200 err := preloadBelongsTo(tx, asoc, mmi) 201 if err != nil { 202 return err 203 } 204 } 205 206 if asoc.Field.Tag.Get("many_to_many") != "" { 207 err := preloadManyToMany(tx, asoc, mmi) 208 if err != nil { 209 return err 210 } 211 } 212 } 213 return nil 214 } 215 216 func isFieldAssociation(field reflect.StructField) bool { 217 for _, associationLabel := range []string{"has_many", "has_one", "belongs_to", "many_to_many"} { 218 if field.Tag.Get(associationLabel) != "" { 219 return true 220 } 221 } 222 return false 223 } 224 225 func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { 226 // 1) get all associations ids. 227 // 1.1) In here I pick ids from model meta info directly. 228 ids := []interface{}{} 229 mmi.Model.iterate(func(m *Model) error { 230 ids = append(ids, m.ID()) 231 return nil 232 }) 233 234 if len(ids) == 0 { 235 return nil 236 } 237 238 // 2) load all associations constraint by model ids. 239 fk := asoc.Field.Tag.Get("fk_id") 240 if fk == "" { 241 fk = mmi.Model.associationName() 242 } 243 244 q := tx.Q() 245 q.eager = false 246 q.eagerFields = []string{} 247 248 slice := asoc.toSlice() 249 250 if strings.TrimSpace(asoc.Field.Tag.Get("order_by")) != "" { 251 q.Order(asoc.Field.Tag.Get("order_by")) 252 } 253 254 err := q.Where(fmt.Sprintf("%s in (?)", fk), ids).All(slice.Interface()) 255 if err != nil { 256 return err 257 } 258 259 // 2.1) load all nested associations from this assoc. 260 if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok { 261 for _, asocNestedField := range asocNestedFields { 262 if err := preload(tx, slice.Interface(), asocNestedField); err != nil { 263 return err 264 } 265 } 266 } 267 268 // 3) iterate over every model and fill it with the assoc. 269 foreignField := asoc.getDBFieldTaggedWith(fk) 270 mmi.iterate(func(mvalue reflect.Value) { 271 modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) 272 for i := 0; i < slice.Elem().Len(); i++ { 273 asocValue := slice.Elem().Index(i) 274 valueField := reflect.Indirect(mmi.mapper.FieldByName(asocValue, foreignField.Path)) 275 if mmi.mapper.FieldByName(mvalue, "ID").Interface() == valueField.Interface() || 276 reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), valueField) { 277 278 switch { 279 case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: 280 modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) 281 case modelAssociationField.Kind() == reflect.Ptr: 282 modelAssociationField.Elem().Set(reflect.Append(modelAssociationField.Elem(), asocValue)) 283 default: 284 modelAssociationField.Set(asocValue) 285 } 286 } 287 } 288 }) 289 290 return nil 291 } 292 293 func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { 294 // 1) get all associations ids. 295 ids := []interface{}{} 296 mmi.Model.iterate(func(m *Model) error { 297 ids = append(ids, m.ID()) 298 return nil 299 }) 300 301 if len(ids) == 0 { 302 return nil 303 } 304 305 // 2) load all associations constraint by model ids. 306 fk := asoc.Field.Tag.Get("fk_id") 307 if fk == "" { 308 fk = mmi.Model.associationName() 309 } 310 311 q := tx.Q() 312 q.eager = false 313 q.eagerFields = []string{} 314 315 slice := asoc.toSlice() 316 err := q.Where(fmt.Sprintf("%s in (?)", fk), ids).All(slice.Interface()) 317 if err != nil { 318 return err 319 } 320 321 // 2.1) load all nested associations from this assoc. 322 if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok { 323 for _, asocNestedField := range asocNestedFields { 324 if err := preload(tx, slice.Interface(), asocNestedField); err != nil { 325 return err 326 } 327 } 328 } 329 330 // 3) iterate over every model and fill it with the assoc. 331 foreignField := asoc.getDBFieldTaggedWith(fk) 332 mmi.iterate(func(mvalue reflect.Value) { 333 modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) 334 for i := 0; i < slice.Elem().Len(); i++ { 335 asocValue := slice.Elem().Index(i) 336 if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() || 337 reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) { 338 if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array { 339 modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) 340 continue 341 } 342 modelAssociationField.Set(asocValue) 343 } 344 } 345 }) 346 347 return nil 348 } 349 350 func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { 351 // 1) get all associations ids. 352 fi := mmi.getDBFieldTaggedWith(asoc.fkName()) 353 if fi == nil { 354 fi = mmi.getDBFieldTaggedWith(fmt.Sprintf("%s%s", flect.Underscore(asoc.Path), "_id")) 355 } 356 357 fkids := []interface{}{} 358 mmi.iterate(func(val reflect.Value) { 359 if !isFieldNilPtr(val, fi) { 360 fkids = append(fkids, mmi.mapper.FieldByName(val, fi.Path).Interface()) 361 } 362 }) 363 364 if len(fkids) == 0 { 365 return nil 366 } 367 368 // 2) load all associations constraint by association fields ids. 369 fk := "id" 370 371 q := tx.Q() 372 q.eager = false 373 q.eagerFields = []string{} 374 375 slice := asoc.toSlice() 376 err := q.Where(fmt.Sprintf("%s in (?)", fk), fkids).All(slice.Interface()) 377 if err != nil { 378 return err 379 } 380 381 // 2.1) load all nested associations from this assoc. 382 if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok { 383 for _, asocNestedField := range asocNestedFields { 384 if err := preload(tx, slice.Interface(), asocNestedField); err != nil { 385 return err 386 } 387 } 388 } 389 390 // 3) iterate over every model and fill it with the assoc. 391 mmi.iterate(func(mvalue reflect.Value) { 392 if isFieldNilPtr(mvalue, fi) { 393 return 394 } 395 modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) 396 for i := 0; i < slice.Elem().Len(); i++ { 397 asocValue := slice.Elem().Index(i) 398 fkField := reflect.Indirect(mmi.mapper.FieldByName(mvalue, fi.Path)) 399 if fkField.Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() || 400 reflect.DeepEqual(fkField, mmi.mapper.FieldByName(asocValue, "ID")) { 401 402 switch { 403 case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array: 404 modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) 405 case modelAssociationField.Kind() == reflect.Ptr: 406 modelAssociationField.Elem().Set(asocValue) 407 default: 408 modelAssociationField.Set(asocValue) 409 } 410 } 411 } 412 }) 413 414 return nil 415 } 416 417 func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error { 418 // 1) get all associations ids. 419 // 1.1) In here I pick ids from model meta info directly. 420 ids := []interface{}{} 421 mmi.Model.iterate(func(m *Model) error { 422 ids = append(ids, m.ID()) 423 return nil 424 }) 425 426 if len(ids) == 0 { 427 return nil 428 } 429 430 // 2) load all associations. 431 // 2.1) In here I pick the label name from association. 432 manyToManyTableName := asoc.Field.Tag.Get("many_to_many") 433 modelAssociationName := mmi.Model.associationName() 434 assocFkName := asoc.fkName() 435 436 if strings.Contains(manyToManyTableName, ":") { 437 modelAssociationName = strings.TrimSpace(manyToManyTableName[strings.Index(manyToManyTableName, ":")+1:]) 438 manyToManyTableName = strings.TrimSpace(manyToManyTableName[:strings.Index(manyToManyTableName, ":")]) 439 } 440 441 sql := fmt.Sprintf("SELECT %s, %s FROM %s WHERE %s in (?)", modelAssociationName, assocFkName, manyToManyTableName, modelAssociationName) 442 sql, args, _ := sqlx.In(sql, ids) 443 sql = tx.Dialect.TranslateSQL(sql) 444 log(logging.SQL, sql, args...) 445 446 cn, err := tx.Store.Transaction() 447 if err != nil { 448 return err 449 } 450 451 rows, err := cn.Queryx(sql, args...) 452 if err != nil { 453 return err 454 } 455 456 mapAssoc := map[string][]interface{}{} 457 fkids := []interface{}{} 458 for rows.Next() { 459 row, err := rows.SliceScan() 460 if err != nil { 461 return err 462 } 463 if len(row) > 0 { 464 if _, ok := row[0].([]uint8); ok { // -> it's UUID 465 row[0] = string(row[0].([]uint8)) 466 } 467 if _, ok := row[1].([]uint8); ok { // -> it's UUID 468 row[1] = string(row[1].([]uint8)) 469 } 470 key := fmt.Sprintf("%v", row[0]) 471 mapAssoc[key] = append(mapAssoc[key], row[1]) 472 fkids = append(fkids, row[1]) 473 } 474 } 475 476 q := tx.Q() 477 q.eager = false 478 q.eagerFields = []string{} 479 480 if strings.TrimSpace(asoc.Field.Tag.Get("order_by")) != "" { 481 q.Order(asoc.Field.Tag.Get("order_by")) 482 } 483 484 slice := asoc.toSlice() 485 q.Where("id in (?)", fkids).All(slice.Interface()) 486 487 // 2.2) load all nested associations from this assoc. 488 if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok { 489 for _, asocNestedField := range asocNestedFields { 490 if err := preload(tx, slice.Interface(), asocNestedField); err != nil { 491 return err 492 } 493 } 494 } 495 496 // 3) iterate over every model and fill it with the assoc. 497 mmi.iterate(func(mvalue reflect.Value) { 498 id := mmi.mapper.FieldByName(mvalue, "ID").Interface() 499 if assocFkIds, ok := mapAssoc[fmt.Sprintf("%v", id)]; ok { 500 modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name) 501 for i := 0; i < slice.Elem().Len(); i++ { 502 asocValue := slice.Elem().Index(i) 503 for _, fkid := range assocFkIds { 504 if fmt.Sprintf("%v", fkid) == fmt.Sprintf("%v", mmi.mapper.FieldByName(asocValue, "ID").Interface()) { 505 modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue)) 506 } 507 } 508 } 509 } 510 }) 511 512 return nil 513 } 514 515 func isFieldNilPtr(val reflect.Value, fi *reflectx.FieldInfo) bool { 516 fieldValue := reflectx.FieldByIndexesReadOnly(val, fi.Index) 517 return fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() 518 }