github.com/rjgonzale/pop/v5@v5.1.3-dev/preload_associations.go (about)

     1  package pop
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"regexp"
     7  	"strings"
     8  
     9  	"github.com/gobuffalo/flect"
    10  	"github.com/gobuffalo/pop/v5/internal/defaults"
    11  	"github.com/gobuffalo/pop/v5/logging"
    12  	"github.com/jmoiron/sqlx"
    13  	"github.com/jmoiron/sqlx/reflectx"
    14  )
    15  
    16  var validFieldRegexp = regexp.MustCompile(`^(([a-zA-Z0-9]*)(\.[a-zA-Z0-9]+)?)+$`)
    17  
    18  // NewModelMetaInfo creates the meta info details for the model passed
    19  // as a parameter.
    20  func NewModelMetaInfo(model *Model) *ModelMetaInfo {
    21  	mmi := &ModelMetaInfo{}
    22  	mmi.Model = model
    23  	mmi.init()
    24  	return mmi
    25  }
    26  
    27  // NewAssociationMetaInfo creates the meta info details for the passed association.
    28  func NewAssociationMetaInfo(fi *reflectx.FieldInfo) *AssociationMetaInfo {
    29  	ami := &AssociationMetaInfo{}
    30  	ami.FieldInfo = fi
    31  	ami.init()
    32  	return ami
    33  }
    34  
// ModelMetaInfo abstracts all field information regarding a model. A model
// is the representation of a table in the database.
//
// The embedded StructMap describes the model's struct type; when the model
// value is a slice or array, it describes the element type instead.
type ModelMetaInfo struct {
	*reflectx.StructMap
	// Model is the model whose fields are described.
	Model        *Model
	// mapper builds StructMap and reads model field values by name.
	mapper       *reflectx.Mapper
	// nestedFields maps a top-level association field name to the nested
	// path requested for it: "Books.Writer" is stored as "Books" -> "Writer".
	nestedFields map[string]string
}
    44  
    45  func (mmi *ModelMetaInfo) init() {
    46  	m := reflectx.NewMapper("")
    47  	mmi.mapper = m
    48  
    49  	t := reflectx.Deref(reflect.TypeOf(mmi.Model.Value))
    50  	if t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
    51  		t = reflectx.Deref(t.Elem())
    52  	}
    53  
    54  	mmi.StructMap = m.TypeMap(t)
    55  	mmi.nestedFields = make(map[string]string)
    56  }
    57  
    58  func (mmi *ModelMetaInfo) iterate(fn func(reflect.Value)) {
    59  	modelValue := reflect.Indirect(reflect.ValueOf(mmi.Model.Value))
    60  	if modelValue.Kind() == reflect.Slice || modelValue.Kind() == reflect.Array {
    61  		for i := 0; i < modelValue.Len(); i++ {
    62  			fn(modelValue.Index(i))
    63  		}
    64  		return
    65  	}
    66  	fn(modelValue)
    67  }
    68  
    69  func (mmi *ModelMetaInfo) getDBFieldTaggedWith(value string) *reflectx.FieldInfo {
    70  	for _, fi := range mmi.Index {
    71  		if fi.Field.Tag.Get("db") == value {
    72  			if len(fi.Children) > 0 {
    73  				return fi.Children[0]
    74  			}
    75  			return fi
    76  		}
    77  	}
    78  	return nil
    79  }
    80  
    81  func (mmi *ModelMetaInfo) preloadFields(fields ...string) ([]*reflectx.FieldInfo, error) {
    82  	if len(fields) == 0 {
    83  		return mmi.Index, nil
    84  	}
    85  
    86  	var preloadFields []*reflectx.FieldInfo
    87  	for _, f := range fields {
    88  		if !validFieldRegexp.MatchString(f) {
    89  			return preloadFields, fmt.Errorf("association field '%s' does not match the format %s", f, "'<field>' or '<field>.<nested-field>'")
    90  		}
    91  		if strings.Contains(f, ".") {
    92  			mmi.nestedFields[f[:strings.Index(f, ".")]] = f[strings.Index(f, ".")+1:]
    93  			f = f[:strings.Index(f, ".")]
    94  		}
    95  
    96  		preloadField := mmi.GetByPath(f)
    97  		if preloadField == nil {
    98  			return preloadFields, fmt.Errorf("field %s does not exist in model %s", f, mmi.Model.TableName())
    99  		}
   100  
   101  		var exist bool
   102  		for _, pf := range preloadFields {
   103  			if pf.Path == preloadField.Path {
   104  				exist = true
   105  			}
   106  		}
   107  		if !exist {
   108  			preloadFields = append(preloadFields, preloadField)
   109  		}
   110  	}
   111  	return preloadFields, nil
   112  }
   113  
// AssociationMetaInfo abstracts all field information regarding an
// association. An association is a field that defines a tag like
// 'has_many', 'belongs_to', 'many_to_many' or 'has_one'.
type AssociationMetaInfo struct {
	// FieldInfo describes the association field itself (its name, path and
	// struct tags).
	*reflectx.FieldInfo
	// StructMap describes the associated type; for slice or array fields it
	// describes the element type. Its Index is referenced as
	// ami.StructMap.Index — NOTE(review): presumably because
	// reflectx.FieldInfo also declares an Index field, which would make the
	// bare selector ambiguous.
	*reflectx.StructMap
}
   122  
   123  func (ami *AssociationMetaInfo) init() {
   124  	mapper := reflectx.NewMapper("")
   125  	t := reflectx.Deref(ami.FieldInfo.Field.Type)
   126  	if t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
   127  		t = reflectx.Deref(t.Elem())
   128  	}
   129  
   130  	ami.StructMap = mapper.TypeMap(t)
   131  }
   132  
   133  func (ami *AssociationMetaInfo) toSlice() reflect.Value {
   134  	ft := reflectx.Deref(ami.Field.Type)
   135  	var vt reflect.Value
   136  	if ft.Kind() == reflect.Slice || ft.Kind() == reflect.Array {
   137  		vt = reflect.New(ft)
   138  	} else {
   139  		vt = reflect.New(reflect.SliceOf(ft))
   140  	}
   141  	return vt
   142  }
   143  
   144  func (ami *AssociationMetaInfo) getDBFieldTaggedWith(value string) *reflectx.FieldInfo {
   145  	for _, fi := range ami.StructMap.Index {
   146  		if fi.Field.Tag.Get("db") == value {
   147  			if len(fi.Children) > 0 {
   148  				return fi.Children[0]
   149  			}
   150  			return fi
   151  		}
   152  	}
   153  	return nil
   154  }
   155  
   156  func (ami *AssociationMetaInfo) fkName() string {
   157  	t := ami.Field.Type
   158  	if t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
   159  		t = reflectx.Deref(t.Elem())
   160  	}
   161  	fkName := fmt.Sprintf("%s%s", flect.Underscore(flect.Singularize(t.Name())), "_id")
   162  	fkNameTag := flect.Underscore(ami.Field.Tag.Get("fk_id"))
   163  	return defaults.String(fkNameTag, fkName)
   164  }
   165  
   166  // preload is the query mode used to load associations from database
   167  // similar to the active record default approach on Rails.
   168  func preload(tx *Connection, model interface{}, fields ...string) error {
   169  	mmi := NewModelMetaInfo(&Model{Value: model})
   170  
   171  	preloadFields, err := mmi.preloadFields(fields...)
   172  	if err != nil {
   173  		return err
   174  	}
   175  
   176  	var associations []*AssociationMetaInfo
   177  	for _, fieldInfo := range preloadFields {
   178  		if isFieldAssociation(fieldInfo.Field) && fieldInfo.Parent.Name == "" {
   179  			associations = append(associations, NewAssociationMetaInfo(fieldInfo))
   180  		}
   181  	}
   182  
   183  	for _, asoc := range associations {
   184  		if asoc.Field.Tag.Get("has_many") != "" {
   185  			err := preloadHasMany(tx, asoc, mmi)
   186  			if err != nil {
   187  				return err
   188  			}
   189  		}
   190  
   191  		if asoc.Field.Tag.Get("has_one") != "" {
   192  			err := preloadHasOne(tx, asoc, mmi)
   193  			if err != nil {
   194  				return err
   195  			}
   196  		}
   197  
   198  		if asoc.Field.Tag.Get("belongs_to") != "" {
   199  			err := preloadBelongsTo(tx, asoc, mmi)
   200  			if err != nil {
   201  				return err
   202  			}
   203  		}
   204  
   205  		if asoc.Field.Tag.Get("many_to_many") != "" {
   206  			err := preloadManyToMany(tx, asoc, mmi)
   207  			if err != nil {
   208  				return err
   209  			}
   210  		}
   211  	}
   212  	return nil
   213  }
   214  
   215  func isFieldAssociation(field reflect.StructField) bool {
   216  	for _, associationLabel := range []string{"has_many", "has_one", "belongs_to", "many_to_many"} {
   217  		if field.Tag.Get(associationLabel) != "" {
   218  			return true
   219  		}
   220  	}
   221  	return false
   222  }
   223  
   224  func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   225  	// 1) get all associations ids.
   226  	// 1.1) In here I pick ids from model meta info directly.
   227  	ids := []interface{}{}
   228  	mmi.Model.iterate(func(m *Model) error {
   229  		ids = append(ids, m.ID())
   230  		return nil
   231  	})
   232  
   233  	if len(ids) == 0 {
   234  		return nil
   235  	}
   236  
   237  	// 2) load all associations constraint by model ids.
   238  	fk := asoc.Field.Tag.Get("fk_id")
   239  	if fk == "" {
   240  		fk = mmi.Model.associationName()
   241  	}
   242  
   243  	q := tx.Q()
   244  	q.eager = false
   245  	q.eagerFields = []string{}
   246  
   247  	slice := asoc.toSlice()
   248  
   249  	if strings.TrimSpace(asoc.Field.Tag.Get("order_by")) != "" {
   250  		q.Order(asoc.Field.Tag.Get("order_by"))
   251  	}
   252  
   253  	err := q.Where(fmt.Sprintf("%s in (?)", fk), ids).All(slice.Interface())
   254  	if err != nil {
   255  		return err
   256  	}
   257  
   258  	// 2.1) load all nested associations from this assoc.
   259  	if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   260  		if err := preload(tx, slice.Interface(), asocNestedFields); err != nil {
   261  			return err
   262  		}
   263  	}
   264  
   265  	// 3) iterate over every model and fill it with the assoc.
   266  	foreignField := asoc.getDBFieldTaggedWith(fk)
   267  	mmi.iterate(func(mvalue reflect.Value) {
   268  		modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   269  		for i := 0; i < slice.Elem().Len(); i++ {
   270  			asocValue := slice.Elem().Index(i)
   271  			if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() ||
   272  				reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) {
   273  
   274  				switch {
   275  				case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
   276  					modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   277  				case modelAssociationField.Kind() == reflect.Ptr:
   278  					modelAssociationField.Elem().Set(reflect.Append(modelAssociationField.Elem(), asocValue))
   279  				default:
   280  					modelAssociationField.Set(asocValue)
   281  				}
   282  			}
   283  		}
   284  	})
   285  
   286  	return nil
   287  }
   288  
   289  func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   290  	// 1) get all associations ids.
   291  	ids := []interface{}{}
   292  	mmi.Model.iterate(func(m *Model) error {
   293  		ids = append(ids, m.ID())
   294  		return nil
   295  	})
   296  
   297  	if len(ids) == 0 {
   298  		return nil
   299  	}
   300  
   301  	// 2) load all associations constraint by model ids.
   302  	fk := asoc.Field.Tag.Get("fk_id")
   303  	if fk == "" {
   304  		fk = mmi.Model.associationName()
   305  	}
   306  
   307  	q := tx.Q()
   308  	q.eager = false
   309  	q.eagerFields = []string{}
   310  
   311  	slice := asoc.toSlice()
   312  	err := q.Where(fmt.Sprintf("%s in (?)", fk), ids).All(slice.Interface())
   313  	if err != nil {
   314  		return err
   315  	}
   316  
   317  	// 2.1) load all nested associations from this assoc.
   318  	if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   319  		if err := preload(tx, slice.Interface(), asocNestedFields); err != nil {
   320  			return err
   321  		}
   322  	}
   323  
   324  	//  3) iterate over every model and fill it with the assoc.
   325  	foreignField := asoc.getDBFieldTaggedWith(fk)
   326  	mmi.iterate(func(mvalue reflect.Value) {
   327  		modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   328  		for i := 0; i < slice.Elem().Len(); i++ {
   329  			asocValue := slice.Elem().Index(i)
   330  			if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() ||
   331  				reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) {
   332  				if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array {
   333  					modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   334  					continue
   335  				}
   336  				modelAssociationField.Set(asocValue)
   337  			}
   338  		}
   339  	})
   340  
   341  	return nil
   342  }
   343  
   344  func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   345  	// 1) get all associations ids.
   346  	fi := mmi.getDBFieldTaggedWith(asoc.fkName())
   347  	if fi == nil {
   348  		fi = mmi.getDBFieldTaggedWith(fmt.Sprintf("%s%s", flect.Underscore(asoc.Path), "_id"))
   349  	}
   350  
   351  	fkids := []interface{}{}
   352  	mmi.iterate(func(val reflect.Value) {
   353  		fkids = append(fkids, mmi.mapper.FieldByName(val, fi.Path).Interface())
   354  	})
   355  
   356  	if len(fkids) == 0 {
   357  		return nil
   358  	}
   359  
   360  	// 2) load all associations constraint by association fields ids.
   361  	fk := "id"
   362  
   363  	q := tx.Q()
   364  	q.eager = false
   365  	q.eagerFields = []string{}
   366  
   367  	slice := asoc.toSlice()
   368  	err := q.Where(fmt.Sprintf("%s in (?)", fk), fkids).All(slice.Interface())
   369  	if err != nil {
   370  		return err
   371  	}
   372  
   373  	// 2.1) load all nested associations from this assoc.
   374  	if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   375  		if err := preload(tx, slice.Interface(), asocNestedFields); err != nil {
   376  			return err
   377  		}
   378  	}
   379  
   380  	// 3) iterate over every model and fill it with the assoc.
   381  	mmi.iterate(func(mvalue reflect.Value) {
   382  		modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   383  		for i := 0; i < slice.Elem().Len(); i++ {
   384  			asocValue := slice.Elem().Index(i)
   385  			if mmi.mapper.FieldByName(mvalue, fi.Path).Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() ||
   386  				reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, fi.Path), mmi.mapper.FieldByName(asocValue, "ID")) {
   387  
   388  				switch {
   389  				case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
   390  					modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   391  				case modelAssociationField.Kind() == reflect.Ptr:
   392  					modelAssociationField.Elem().Set(asocValue)
   393  				default:
   394  					modelAssociationField.Set(asocValue)
   395  				}
   396  			}
   397  		}
   398  	})
   399  
   400  	return nil
   401  }
   402  
   403  func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   404  	// 1) get all associations ids.
   405  	// 1.1) In here I pick ids from model meta info directly.
   406  	ids := []interface{}{}
   407  	mmi.Model.iterate(func(m *Model) error {
   408  		ids = append(ids, m.ID())
   409  		return nil
   410  	})
   411  
   412  	if len(ids) == 0 {
   413  		return nil
   414  	}
   415  
   416  	// 2) load all associations.
   417  	// 2.1) In here I pick the label name from association.
   418  	manyToManyTableName := asoc.Field.Tag.Get("many_to_many")
   419  	modelAssociationName := mmi.Model.associationName()
   420  	assocFkName := asoc.fkName()
   421  
   422  	if strings.Contains(manyToManyTableName, ":") {
   423  		modelAssociationName = strings.TrimSpace(manyToManyTableName[strings.Index(manyToManyTableName, ":")+1:])
   424  		manyToManyTableName = strings.TrimSpace(manyToManyTableName[:strings.Index(manyToManyTableName, ":")])
   425  	}
   426  
   427  	if tx.TX != nil {
   428  		sql := fmt.Sprintf("SELECT %s, %s FROM %s WHERE %s in (?)", modelAssociationName, assocFkName, manyToManyTableName, modelAssociationName)
   429  		sql, args, _ := sqlx.In(sql, ids)
   430  		sql = tx.Dialect.TranslateSQL(sql)
   431  		log(logging.SQL, sql, args...)
   432  		rows, err := tx.TX.Queryx(sql, args...)
   433  		if err != nil {
   434  			return err
   435  		}
   436  
   437  		mapAssoc := map[string][]interface{}{}
   438  		fkids := []interface{}{}
   439  		for rows.Next() {
   440  			row, err := rows.SliceScan()
   441  			if err != nil {
   442  				return err
   443  			}
   444  			if len(row) > 0 {
   445  				if _, ok := row[0].([]uint8); ok { // -> it's UUID
   446  					row[0] = string(row[0].([]uint8))
   447  				}
   448  				if _, ok := row[1].([]uint8); ok { // -> it's UUID
   449  					row[1] = string(row[1].([]uint8))
   450  				}
   451  				key := fmt.Sprintf("%v", row[0])
   452  				mapAssoc[key] = append(mapAssoc[key], row[1])
   453  				fkids = append(fkids, row[1])
   454  			}
   455  		}
   456  
   457  		q := tx.Q()
   458  		q.eager = false
   459  		q.eagerFields = []string{}
   460  
   461  		if strings.TrimSpace(asoc.Field.Tag.Get("order_by")) != "" {
   462  			q.Order(asoc.Field.Tag.Get("order_by"))
   463  		}
   464  
   465  		slice := asoc.toSlice()
   466  		q.Where("id in (?)", fkids).All(slice.Interface())
   467  
   468  		// 2.2) load all nested associations from this assoc.
   469  		if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   470  			if err := preload(tx, slice.Interface(), asocNestedFields); err != nil {
   471  				return err
   472  			}
   473  		}
   474  
   475  		// 3) iterate over every model and fill it with the assoc.
   476  		mmi.iterate(func(mvalue reflect.Value) {
   477  			id := mmi.mapper.FieldByName(mvalue, "ID").Interface()
   478  			if assocFkIds, ok := mapAssoc[fmt.Sprintf("%v", id)]; ok {
   479  				modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   480  				for i := 0; i < slice.Elem().Len(); i++ {
   481  					asocValue := slice.Elem().Index(i)
   482  					for _, fkid := range assocFkIds {
   483  						if fmt.Sprintf("%v", fkid) == fmt.Sprintf("%v", mmi.mapper.FieldByName(asocValue, "ID").Interface()) {
   484  							modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   485  						}
   486  					}
   487  				}
   488  			}
   489  		})
   490  	}
   491  	return nil
   492  }