github.com/paweljw/pop/v5@v5.4.6/preload_associations.go (about)

     1  package pop
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"regexp"
     7  	"strings"
     8  
     9  	"github.com/gobuffalo/flect"
    10  	"github.com/gobuffalo/pop/v5/internal/defaults"
    11  	"github.com/gobuffalo/pop/v5/logging"
    12  	"github.com/jmoiron/sqlx"
    13  	"github.com/jmoiron/sqlx/reflectx"
    14  )
    15  
// validFieldRegexp validates the association field names handed to
// preloadFields. It is compiled once at package scope. The pattern accepts
// alphanumeric segments optionally joined by dots (for example "Books" or
// "Books.Writers"); note that it also accepts the empty string.
var validFieldRegexp = regexp.MustCompile(`^(([a-zA-Z0-9]*)(\.[a-zA-Z0-9]+)?)+$`)
    17  
    18  // NewModelMetaInfo creates the meta info details for the model passed
    19  // as a parameter.
    20  func NewModelMetaInfo(model *Model) *ModelMetaInfo {
    21  	mmi := &ModelMetaInfo{}
    22  	mmi.Model = model
    23  	mmi.init()
    24  	return mmi
    25  }
    26  
    27  // NewAssociationMetaInfo creates the meta info details for the passed association.
    28  func NewAssociationMetaInfo(fi *reflectx.FieldInfo) *AssociationMetaInfo {
    29  	ami := &AssociationMetaInfo{}
    30  	ami.FieldInfo = fi
    31  	ami.init()
    32  	return ami
    33  }
    34  
// ModelMetaInfo a type to abstract all fields information regarding
// to a model. A model is representation of a table in the
// database.
type ModelMetaInfo struct {
	// StructMap is the reflectx mapping of the model's struct type
	// (the element type when the model value is a slice or array);
	// it is populated by init.
	*reflectx.StructMap
	// Model is the model whose associations are being preloaded.
	Model *Model
	// mapper is the reflectx mapper used to build StructMap and to
	// look up fields by name on model and association values.
	mapper *reflectx.Mapper
	// nestedFields maps a top-level field name to the nested paths
	// requested beneath it; it is filled in by preloadFields.
	nestedFields map[string][]string
}
    44  
    45  func (mmi *ModelMetaInfo) init() {
    46  	m := reflectx.NewMapper("")
    47  	mmi.mapper = m
    48  
    49  	t := reflectx.Deref(reflect.TypeOf(mmi.Model.Value))
    50  	if t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
    51  		t = reflectx.Deref(t.Elem())
    52  	}
    53  
    54  	mmi.StructMap = m.TypeMap(t)
    55  	mmi.nestedFields = make(map[string][]string)
    56  }
    57  
    58  func (mmi *ModelMetaInfo) iterate(fn func(reflect.Value)) {
    59  	modelValue := reflect.Indirect(reflect.ValueOf(mmi.Model.Value))
    60  	if modelValue.Kind() == reflect.Slice || modelValue.Kind() == reflect.Array {
    61  		for i := 0; i < modelValue.Len(); i++ {
    62  			fn(modelValue.Index(i))
    63  		}
    64  		return
    65  	}
    66  	fn(modelValue)
    67  }
    68  
    69  func (mmi *ModelMetaInfo) getDBFieldTaggedWith(value string) *reflectx.FieldInfo {
    70  	for _, fi := range mmi.Index {
    71  		if fi.Field.Tag.Get("db") == value {
    72  			if len(fi.Children) > 0 {
    73  				return fi.Children[0]
    74  			}
    75  			return fi
    76  		}
    77  	}
    78  	return nil
    79  }
    80  
    81  func (mmi *ModelMetaInfo) preloadFields(fields ...string) ([]*reflectx.FieldInfo, error) {
    82  	if len(fields) == 0 {
    83  		return mmi.Index, nil
    84  	}
    85  
    86  	var preloadFields []*reflectx.FieldInfo
    87  	for _, f := range fields {
    88  		if !validFieldRegexp.MatchString(f) {
    89  			return preloadFields, fmt.Errorf("association field '%s' does not match the format %s", f, "'<field>' or '<field>.<nested-field>'")
    90  		}
    91  		if strings.Contains(f, ".") {
    92  			fname := f[:strings.Index(f, ".")]
    93  			mmi.nestedFields[fname] = append(mmi.nestedFields[fname], f[strings.Index(f, ".")+1:])
    94  			f = f[:strings.Index(f, ".")]
    95  		}
    96  
    97  		preloadField := mmi.GetByPath(f)
    98  		if preloadField == nil {
    99  			return preloadFields, fmt.Errorf("field %s does not exist in model %s", f, mmi.Model.TableName())
   100  		}
   101  
   102  		var exist bool
   103  		for _, pf := range preloadFields {
   104  			if pf.Path == preloadField.Path {
   105  				exist = true
   106  			}
   107  		}
   108  		if !exist {
   109  			preloadFields = append(preloadFields, preloadField)
   110  		}
   111  	}
   112  	return preloadFields, nil
   113  }
   114  
// AssociationMetaInfo a type to abstract all field information
// regarding to an association. An association is a field
// that has defined a tag like 'has_many', 'belongs_to',
// 'many_to_many' and 'has_one'.
type AssociationMetaInfo struct {
	// FieldInfo describes the association field on the owning model.
	*reflectx.FieldInfo
	// StructMap is the reflectx mapping of the associated struct type
	// (the element type for slice or array fields); it is populated
	// by init.
	*reflectx.StructMap
}
   123  
   124  func (ami *AssociationMetaInfo) init() {
   125  	mapper := reflectx.NewMapper("")
   126  	t := reflectx.Deref(ami.FieldInfo.Field.Type)
   127  	if t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
   128  		t = reflectx.Deref(t.Elem())
   129  	}
   130  
   131  	ami.StructMap = mapper.TypeMap(t)
   132  }
   133  
   134  func (ami *AssociationMetaInfo) toSlice() reflect.Value {
   135  	ft := reflectx.Deref(ami.Field.Type)
   136  	var vt reflect.Value
   137  	if ft.Kind() == reflect.Slice || ft.Kind() == reflect.Array {
   138  		vt = reflect.New(ft)
   139  	} else {
   140  		vt = reflect.New(reflect.SliceOf(ft))
   141  	}
   142  	return vt
   143  }
   144  
   145  func (ami *AssociationMetaInfo) getDBFieldTaggedWith(value string) *reflectx.FieldInfo {
   146  	for _, fi := range ami.StructMap.Index {
   147  		if fi.Field.Tag.Get("db") == value {
   148  			if len(fi.Children) > 0 {
   149  				return fi.Children[0]
   150  			}
   151  			return fi
   152  		}
   153  	}
   154  	return nil
   155  }
   156  
   157  func (ami *AssociationMetaInfo) fkName() string {
   158  	t := ami.Field.Type
   159  	if t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
   160  		t = reflectx.Deref(t.Elem())
   161  	}
   162  	fkName := fmt.Sprintf("%s%s", flect.Underscore(flect.Singularize(t.Name())), "_id")
   163  	fkNameTag := flect.Underscore(ami.Field.Tag.Get("fk_id"))
   164  	return defaults.String(fkNameTag, fkName)
   165  }
   166  
   167  // preload is the query mode used to load associations from database
   168  // similar to the active record default approach on Rails.
   169  func preload(tx *Connection, model interface{}, fields ...string) error {
   170  	mmi := NewModelMetaInfo(NewModel(model, tx.Context()))
   171  
   172  	preloadFields, err := mmi.preloadFields(fields...)
   173  	if err != nil {
   174  		return err
   175  	}
   176  
   177  	var associations []*AssociationMetaInfo
   178  	for _, fieldInfo := range preloadFields {
   179  		if isFieldAssociation(fieldInfo.Field) && fieldInfo.Parent.Name == "" {
   180  			associations = append(associations, NewAssociationMetaInfo(fieldInfo))
   181  		}
   182  	}
   183  
   184  	for _, asoc := range associations {
   185  		if asoc.Field.Tag.Get("has_many") != "" {
   186  			err := preloadHasMany(tx, asoc, mmi)
   187  			if err != nil {
   188  				return err
   189  			}
   190  		}
   191  
   192  		if asoc.Field.Tag.Get("has_one") != "" {
   193  			err := preloadHasOne(tx, asoc, mmi)
   194  			if err != nil {
   195  				return err
   196  			}
   197  		}
   198  
   199  		if asoc.Field.Tag.Get("belongs_to") != "" {
   200  			err := preloadBelongsTo(tx, asoc, mmi)
   201  			if err != nil {
   202  				return err
   203  			}
   204  		}
   205  
   206  		if asoc.Field.Tag.Get("many_to_many") != "" {
   207  			err := preloadManyToMany(tx, asoc, mmi)
   208  			if err != nil {
   209  				return err
   210  			}
   211  		}
   212  	}
   213  	return nil
   214  }
   215  
   216  func isFieldAssociation(field reflect.StructField) bool {
   217  	for _, associationLabel := range []string{"has_many", "has_one", "belongs_to", "many_to_many"} {
   218  		if field.Tag.Get(associationLabel) != "" {
   219  			return true
   220  		}
   221  	}
   222  	return false
   223  }
   224  
   225  func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   226  	// 1) get all associations ids.
   227  	// 1.1) In here I pick ids from model meta info directly.
   228  	ids := []interface{}{}
   229  	mmi.Model.iterate(func(m *Model) error {
   230  		ids = append(ids, m.ID())
   231  		return nil
   232  	})
   233  
   234  	if len(ids) == 0 {
   235  		return nil
   236  	}
   237  
   238  	// 2) load all associations constraint by model ids.
   239  	fk := asoc.Field.Tag.Get("fk_id")
   240  	if fk == "" {
   241  		fk = mmi.Model.associationName()
   242  	}
   243  
   244  	q := tx.Q()
   245  	q.eager = false
   246  	q.eagerFields = []string{}
   247  
   248  	slice := asoc.toSlice()
   249  
   250  	if strings.TrimSpace(asoc.Field.Tag.Get("order_by")) != "" {
   251  		q.Order(asoc.Field.Tag.Get("order_by"))
   252  	}
   253  
   254  	err := q.Where(fmt.Sprintf("%s in (?)", fk), ids).All(slice.Interface())
   255  	if err != nil {
   256  		return err
   257  	}
   258  
   259  	// 2.1) load all nested associations from this assoc.
   260  	if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   261  		for _, asocNestedField := range asocNestedFields {
   262  			if err := preload(tx, slice.Interface(), asocNestedField); err != nil {
   263  				return err
   264  			}
   265  		}
   266  	}
   267  
   268  	// 3) iterate over every model and fill it with the assoc.
   269  	foreignField := asoc.getDBFieldTaggedWith(fk)
   270  	mmi.iterate(func(mvalue reflect.Value) {
   271  		modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   272  		for i := 0; i < slice.Elem().Len(); i++ {
   273  			asocValue := slice.Elem().Index(i)
   274  			valueField := reflect.Indirect(mmi.mapper.FieldByName(asocValue, foreignField.Path))
   275  			if mmi.mapper.FieldByName(mvalue, "ID").Interface() == valueField.Interface() ||
   276  				reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), valueField) {
   277  
   278  				switch {
   279  				case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
   280  					modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   281  				case modelAssociationField.Kind() == reflect.Ptr:
   282  					modelAssociationField.Elem().Set(reflect.Append(modelAssociationField.Elem(), asocValue))
   283  				default:
   284  					modelAssociationField.Set(asocValue)
   285  				}
   286  			}
   287  		}
   288  	})
   289  
   290  	return nil
   291  }
   292  
   293  func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   294  	// 1) get all associations ids.
   295  	ids := []interface{}{}
   296  	mmi.Model.iterate(func(m *Model) error {
   297  		ids = append(ids, m.ID())
   298  		return nil
   299  	})
   300  
   301  	if len(ids) == 0 {
   302  		return nil
   303  	}
   304  
   305  	// 2) load all associations constraint by model ids.
   306  	fk := asoc.Field.Tag.Get("fk_id")
   307  	if fk == "" {
   308  		fk = mmi.Model.associationName()
   309  	}
   310  
   311  	q := tx.Q()
   312  	q.eager = false
   313  	q.eagerFields = []string{}
   314  
   315  	slice := asoc.toSlice()
   316  	err := q.Where(fmt.Sprintf("%s in (?)", fk), ids).All(slice.Interface())
   317  	if err != nil {
   318  		return err
   319  	}
   320  
   321  	// 2.1) load all nested associations from this assoc.
   322  	if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   323  		for _, asocNestedField := range asocNestedFields {
   324  			if err := preload(tx, slice.Interface(), asocNestedField); err != nil {
   325  				return err
   326  			}
   327  		}
   328  	}
   329  
   330  	//  3) iterate over every model and fill it with the assoc.
   331  	foreignField := asoc.getDBFieldTaggedWith(fk)
   332  	mmi.iterate(func(mvalue reflect.Value) {
   333  		modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   334  		for i := 0; i < slice.Elem().Len(); i++ {
   335  			asocValue := slice.Elem().Index(i)
   336  			if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() ||
   337  				reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) {
   338  				if modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array {
   339  					modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   340  					continue
   341  				}
   342  				modelAssociationField.Set(asocValue)
   343  			}
   344  		}
   345  	})
   346  
   347  	return nil
   348  }
   349  
   350  func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   351  	// 1) get all associations ids.
   352  	fi := mmi.getDBFieldTaggedWith(asoc.fkName())
   353  	if fi == nil {
   354  		fi = mmi.getDBFieldTaggedWith(fmt.Sprintf("%s%s", flect.Underscore(asoc.Path), "_id"))
   355  	}
   356  
   357  	fkids := []interface{}{}
   358  	mmi.iterate(func(val reflect.Value) {
   359  		if !isFieldNilPtr(val, fi) {
   360  			fkids = append(fkids, mmi.mapper.FieldByName(val, fi.Path).Interface())
   361  		}
   362  	})
   363  
   364  	if len(fkids) == 0 {
   365  		return nil
   366  	}
   367  
   368  	// 2) load all associations constraint by association fields ids.
   369  	fk := "id"
   370  
   371  	q := tx.Q()
   372  	q.eager = false
   373  	q.eagerFields = []string{}
   374  
   375  	slice := asoc.toSlice()
   376  	err := q.Where(fmt.Sprintf("%s in (?)", fk), fkids).All(slice.Interface())
   377  	if err != nil {
   378  		return err
   379  	}
   380  
   381  	// 2.1) load all nested associations from this assoc.
   382  	if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   383  		for _, asocNestedField := range asocNestedFields {
   384  			if err := preload(tx, slice.Interface(), asocNestedField); err != nil {
   385  				return err
   386  			}
   387  		}
   388  	}
   389  
   390  	// 3) iterate over every model and fill it with the assoc.
   391  	mmi.iterate(func(mvalue reflect.Value) {
   392  		if isFieldNilPtr(mvalue, fi) {
   393  			return
   394  		}
   395  		modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   396  		for i := 0; i < slice.Elem().Len(); i++ {
   397  			asocValue := slice.Elem().Index(i)
   398  			fkField := reflect.Indirect(mmi.mapper.FieldByName(mvalue, fi.Path))
   399  			if fkField.Interface() == mmi.mapper.FieldByName(asocValue, "ID").Interface() ||
   400  				reflect.DeepEqual(fkField, mmi.mapper.FieldByName(asocValue, "ID")) {
   401  
   402  				switch {
   403  				case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
   404  					modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   405  				case modelAssociationField.Kind() == reflect.Ptr:
   406  					modelAssociationField.Elem().Set(asocValue)
   407  				default:
   408  					modelAssociationField.Set(asocValue)
   409  				}
   410  			}
   411  		}
   412  	})
   413  
   414  	return nil
   415  }
   416  
   417  func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   418  	// 1) get all associations ids.
   419  	// 1.1) In here I pick ids from model meta info directly.
   420  	ids := []interface{}{}
   421  	mmi.Model.iterate(func(m *Model) error {
   422  		ids = append(ids, m.ID())
   423  		return nil
   424  	})
   425  
   426  	if len(ids) == 0 {
   427  		return nil
   428  	}
   429  
   430  	// 2) load all associations.
   431  	// 2.1) In here I pick the label name from association.
   432  	manyToManyTableName := asoc.Field.Tag.Get("many_to_many")
   433  	modelAssociationName := mmi.Model.associationName()
   434  	assocFkName := asoc.fkName()
   435  
   436  	if strings.Contains(manyToManyTableName, ":") {
   437  		modelAssociationName = strings.TrimSpace(manyToManyTableName[strings.Index(manyToManyTableName, ":")+1:])
   438  		manyToManyTableName = strings.TrimSpace(manyToManyTableName[:strings.Index(manyToManyTableName, ":")])
   439  	}
   440  
   441  	sql := fmt.Sprintf("SELECT %s, %s FROM %s WHERE %s in (?)", modelAssociationName, assocFkName, manyToManyTableName, modelAssociationName)
   442  	sql, args, _ := sqlx.In(sql, ids)
   443  	sql = tx.Dialect.TranslateSQL(sql)
   444  	log(logging.SQL, sql, args...)
   445  
   446  	cn, err := tx.Store.Transaction()
   447  	if err != nil {
   448  		return err
   449  	}
   450  
   451  	rows, err := cn.Queryx(sql, args...)
   452  	if err != nil {
   453  		return err
   454  	}
   455  
   456  	mapAssoc := map[string][]interface{}{}
   457  	fkids := []interface{}{}
   458  	for rows.Next() {
   459  		row, err := rows.SliceScan()
   460  		if err != nil {
   461  			return err
   462  		}
   463  		if len(row) > 0 {
   464  			if _, ok := row[0].([]uint8); ok { // -> it's UUID
   465  				row[0] = string(row[0].([]uint8))
   466  			}
   467  			if _, ok := row[1].([]uint8); ok { // -> it's UUID
   468  				row[1] = string(row[1].([]uint8))
   469  			}
   470  			key := fmt.Sprintf("%v", row[0])
   471  			mapAssoc[key] = append(mapAssoc[key], row[1])
   472  			fkids = append(fkids, row[1])
   473  		}
   474  	}
   475  
   476  	q := tx.Q()
   477  	q.eager = false
   478  	q.eagerFields = []string{}
   479  
   480  	if strings.TrimSpace(asoc.Field.Tag.Get("order_by")) != "" {
   481  		q.Order(asoc.Field.Tag.Get("order_by"))
   482  	}
   483  
   484  	slice := asoc.toSlice()
   485  	q.Where("id in (?)", fkids).All(slice.Interface())
   486  
   487  	// 2.2) load all nested associations from this assoc.
   488  	if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   489  		for _, asocNestedField := range asocNestedFields {
   490  			if err := preload(tx, slice.Interface(), asocNestedField); err != nil {
   491  				return err
   492  			}
   493  		}
   494  	}
   495  
   496  	// 3) iterate over every model and fill it with the assoc.
   497  	mmi.iterate(func(mvalue reflect.Value) {
   498  		id := mmi.mapper.FieldByName(mvalue, "ID").Interface()
   499  		if assocFkIds, ok := mapAssoc[fmt.Sprintf("%v", id)]; ok {
   500  			modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   501  			for i := 0; i < slice.Elem().Len(); i++ {
   502  				asocValue := slice.Elem().Index(i)
   503  				for _, fkid := range assocFkIds {
   504  					if fmt.Sprintf("%v", fkid) == fmt.Sprintf("%v", mmi.mapper.FieldByName(asocValue, "ID").Interface()) {
   505  						modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   506  					}
   507  				}
   508  			}
   509  		}
   510  	})
   511  
   512  	return nil
   513  }
   514  
   515  func isFieldNilPtr(val reflect.Value, fi *reflectx.FieldInfo) bool {
   516  	fieldValue := reflectx.FieldByIndexesReadOnly(val, fi.Index)
   517  	return fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil()
   518  }