github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/preload_associations.go (about)

     1  package pop
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"regexp"
     7  	"strings"
     8  
     9  	"github.com/Accefy/pop/internal/defaults"
    10  	"github.com/gobuffalo/flect"
    11  	"github.com/gobuffalo/pop/v6/logging"
    12  	"github.com/jmoiron/sqlx"
    13  	"github.com/jmoiron/sqlx/reflectx"
    14  )
    15  
    16  var validFieldRegexp = regexp.MustCompile(`^(([a-zA-Z0-9]*)(\.[a-zA-Z0-9]+)?)+$`)
    17  
    18  // NewModelMetaInfo creates the meta info details for the model passed
    19  // as a parameter.
    20  func NewModelMetaInfo(model *Model) *ModelMetaInfo {
    21  	mmi := &ModelMetaInfo{}
    22  	mmi.Model = model
    23  	mmi.init()
    24  	return mmi
    25  }
    26  
    27  // NewAssociationMetaInfo creates the meta info details for the passed association.
    28  func NewAssociationMetaInfo(fi *reflectx.FieldInfo) *AssociationMetaInfo {
    29  	ami := &AssociationMetaInfo{}
    30  	ami.FieldInfo = fi
    31  	ami.init()
    32  	return ami
    33  }
    34  
    35  // ModelMetaInfo a type to abstract all fields information regarding
    36  // to a model. A model is representation of a table in the
    37  // database.
    38  type ModelMetaInfo struct {
    39  	*reflectx.StructMap
    40  	Model        *Model
    41  	mapper       *reflectx.Mapper
    42  	nestedFields map[string][]string
    43  }
    44  
    45  func (mmi *ModelMetaInfo) init() {
    46  	m := reflectx.NewMapper("")
    47  	mmi.mapper = m
    48  
    49  	t := reflectx.Deref(reflect.TypeOf(mmi.Model.Value))
    50  	if t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
    51  		t = reflectx.Deref(t.Elem())
    52  	}
    53  
    54  	mmi.StructMap = m.TypeMap(t)
    55  	mmi.nestedFields = make(map[string][]string)
    56  }
    57  
    58  func (mmi *ModelMetaInfo) iterate(fn func(reflect.Value)) {
    59  	modelValue := reflect.Indirect(reflect.ValueOf(mmi.Model.Value))
    60  	if modelValue.Kind() == reflect.Slice || modelValue.Kind() == reflect.Array {
    61  		for i := 0; i < modelValue.Len(); i++ {
    62  			fn(modelValue.Index(i))
    63  		}
    64  		return
    65  	}
    66  	fn(modelValue)
    67  }
    68  
    69  func (mmi *ModelMetaInfo) getDBFieldTaggedWith(value string) *reflectx.FieldInfo {
    70  	for _, fi := range mmi.Index {
    71  		if fi.Field.Tag.Get("db") == value {
    72  			if len(fi.Children) > 0 {
    73  				return fi.Children[0]
    74  			}
    75  			return fi
    76  		}
    77  	}
    78  	return nil
    79  }
    80  
    81  func (mmi *ModelMetaInfo) preloadFields(fields ...string) ([]*reflectx.FieldInfo, error) {
    82  	if len(fields) == 0 {
    83  		return mmi.Index, nil
    84  	}
    85  
    86  	var preloadFields []*reflectx.FieldInfo
    87  	for _, f := range fields {
    88  		if !validFieldRegexp.MatchString(f) {
    89  			return preloadFields, fmt.Errorf("association field '%s' does not match the format %s", f, "'<field>' or '<field>.<nested-field>'")
    90  		}
    91  		if strings.Contains(f, ".") {
    92  			fname := f[:strings.Index(f, ".")]
    93  			mmi.nestedFields[fname] = append(mmi.nestedFields[fname], f[strings.Index(f, ".")+1:])
    94  			f = f[:strings.Index(f, ".")]
    95  		}
    96  
    97  		preloadField := mmi.GetByPath(f)
    98  		if preloadField == nil {
    99  			return preloadFields, fmt.Errorf("field %s does not exist in model %s", f, mmi.Model.TableName())
   100  		}
   101  
   102  		var exist bool
   103  		for _, pf := range preloadFields {
   104  			if pf.Path == preloadField.Path {
   105  				exist = true
   106  			}
   107  		}
   108  		if !exist {
   109  			preloadFields = append(preloadFields, preloadField)
   110  		}
   111  	}
   112  	return preloadFields, nil
   113  }
   114  
   115  // AssociationMetaInfo a type to abstract all field information
   116  // regarding to an association. An association is a field
   117  // that has defined a tag like 'has_many', 'belongs_to',
   118  // 'many_to_many' and 'has_one'.
   119  type AssociationMetaInfo struct {
   120  	*reflectx.FieldInfo
   121  	*reflectx.StructMap
   122  }
   123  
   124  func (ami *AssociationMetaInfo) init() {
   125  	mapper := reflectx.NewMapper("")
   126  	t := reflectx.Deref(ami.FieldInfo.Field.Type)
   127  	if t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
   128  		t = reflectx.Deref(t.Elem())
   129  	}
   130  
   131  	ami.StructMap = mapper.TypeMap(t)
   132  }
   133  
   134  func (ami *AssociationMetaInfo) toSlice() reflect.Value {
   135  	ft := reflectx.Deref(ami.Field.Type)
   136  	var vt reflect.Value
   137  	if ft.Kind() == reflect.Slice || ft.Kind() == reflect.Array {
   138  		vt = reflect.New(ft)
   139  	} else {
   140  		vt = reflect.New(reflect.SliceOf(ft))
   141  	}
   142  	return vt
   143  }
   144  
   145  func (ami *AssociationMetaInfo) getDBFieldTaggedWith(value string) *reflectx.FieldInfo {
   146  	for _, fi := range ami.StructMap.Index {
   147  		if fi.Field.Tag.Get("db") == value {
   148  			if len(fi.Children) > 0 {
   149  				return fi.Children[0]
   150  			}
   151  			return fi
   152  		}
   153  	}
   154  	return nil
   155  }
   156  
   157  func (ami *AssociationMetaInfo) fkName() string {
   158  	t := ami.Field.Type
   159  	if t.Kind() == reflect.Slice || t.Kind() == reflect.Array {
   160  		t = reflectx.Deref(t.Elem())
   161  	}
   162  	fkName := fmt.Sprintf("%s%s", flect.Underscore(flect.Singularize(t.Name())), "_id")
   163  	fkNameTag := flect.Underscore(ami.Field.Tag.Get("fk_id"))
   164  	return defaults.String(fkNameTag, fkName)
   165  }
   166  
   167  // preload is the query mode used to load associations from database
   168  // similar to the active record default approach on Rails.
   169  func preload(tx *Connection, model interface{}, fields ...string) error {
   170  	mmi := NewModelMetaInfo(NewModel(model, tx.Context()))
   171  
   172  	preloadFields, err := mmi.preloadFields(fields...)
   173  	if err != nil {
   174  		return err
   175  	}
   176  
   177  	var associations []*AssociationMetaInfo
   178  	for _, fieldInfo := range preloadFields {
   179  		if isFieldAssociation(fieldInfo.Field) && fieldInfo.Parent.Name == "" {
   180  			associations = append(associations, NewAssociationMetaInfo(fieldInfo))
   181  		}
   182  	}
   183  
   184  	for _, asoc := range associations {
   185  		if asoc.Field.Tag.Get("has_many") != "" {
   186  			err := preloadHasMany(tx, asoc, mmi)
   187  			if err != nil {
   188  				return err
   189  			}
   190  		}
   191  
   192  		if asoc.Field.Tag.Get("has_one") != "" {
   193  			err := preloadHasOne(tx, asoc, mmi)
   194  			if err != nil {
   195  				return err
   196  			}
   197  		}
   198  
   199  		if asoc.Field.Tag.Get("belongs_to") != "" {
   200  			err := preloadBelongsTo(tx, asoc, mmi)
   201  			if err != nil {
   202  				return err
   203  			}
   204  		}
   205  
   206  		if asoc.Field.Tag.Get("many_to_many") != "" {
   207  			err := preloadManyToMany(tx, asoc, mmi)
   208  			if err != nil {
   209  				return err
   210  			}
   211  		}
   212  	}
   213  	return nil
   214  }
   215  
   216  func isFieldAssociation(field reflect.StructField) bool {
   217  	for _, associationLabel := range []string{"has_many", "has_one", "belongs_to", "many_to_many"} {
   218  		if field.Tag.Get(associationLabel) != "" {
   219  			return true
   220  		}
   221  	}
   222  	return false
   223  }
   224  
   225  func preloadHasMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   226  	// 1) get all associations ids.
   227  	// 1.1) In here I pick ids from model meta info directly.
   228  	ids := []interface{}{}
   229  	mmi.Model.iterate(func(m *Model) error {
   230  		ids = append(ids, m.ID())
   231  		return nil
   232  	})
   233  
   234  	if len(ids) == 0 {
   235  		return nil
   236  	}
   237  
   238  	// 2) load all associations constraint by model ids.
   239  	fk := asoc.Field.Tag.Get("fk_id")
   240  	if fk == "" {
   241  		fk = mmi.Model.associationName()
   242  	}
   243  
   244  	q := tx.Q()
   245  	q.eager = false
   246  	q.eagerFields = []string{}
   247  
   248  	slice := asoc.toSlice()
   249  
   250  	if strings.TrimSpace(asoc.Field.Tag.Get("order_by")) != "" {
   251  		q.Order(asoc.Field.Tag.Get("order_by"))
   252  	}
   253  
   254  	err := q.Where(fmt.Sprintf("%s in (?)", fk), ids).All(slice.Interface())
   255  	if err != nil {
   256  		return err
   257  	}
   258  
   259  	// 2.1) load all nested associations from this assoc.
   260  	if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   261  		for _, asocNestedField := range asocNestedFields {
   262  			if err := preload(tx, slice.Interface(), asocNestedField); err != nil {
   263  				return err
   264  			}
   265  		}
   266  	}
   267  
   268  	// 3) iterate over every model and fill it with the assoc.
   269  	foreignField := asoc.getDBFieldTaggedWith(fk)
   270  	mmi.iterate(func(mvalue reflect.Value) {
   271  		for i := 0; i < slice.Elem().Len(); i++ {
   272  			asocValue := slice.Elem().Index(i)
   273  			valueField := reflect.Indirect(mmi.mapper.FieldByName(asocValue, foreignField.Path))
   274  			if mmi.mapper.FieldByName(mvalue, "ID").Interface() == valueField.Interface() ||
   275  				reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), valueField) {
   276  				// IMPORTANT
   277  				//
   278  				// FieldByName will initialize the value. It is important that this happens AFTER
   279  				// we checked whether the field should be set. Otherwise, we'll set a zero value!
   280  				//
   281  				// This is most likely the reason for https://github.com/gobuffalo/pop/issues/139
   282  				modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   283  				switch {
   284  				case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
   285  					modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   286  				case modelAssociationField.Kind() == reflect.Ptr:
   287  					modelAssociationField.Elem().Set(reflect.Append(modelAssociationField.Elem(), asocValue))
   288  				default:
   289  					modelAssociationField.Set(asocValue)
   290  				}
   291  			}
   292  		}
   293  	})
   294  
   295  	return nil
   296  }
   297  
   298  func preloadHasOne(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   299  	// 1) get all associations ids.
   300  	ids := []interface{}{}
   301  	mmi.Model.iterate(func(m *Model) error {
   302  		ids = append(ids, m.ID())
   303  		return nil
   304  	})
   305  
   306  	if len(ids) == 0 {
   307  		return nil
   308  	}
   309  
   310  	// 2) load all associations constraint by model ids.
   311  	fk := asoc.Field.Tag.Get("fk_id")
   312  	if fk == "" {
   313  		fk = mmi.Model.associationName()
   314  	}
   315  
   316  	q := tx.Q()
   317  	q.eager = false
   318  	q.eagerFields = []string{}
   319  
   320  	slice := asoc.toSlice()
   321  	err := q.Where(fmt.Sprintf("%s in (?)", fk), ids).All(slice.Interface())
   322  	if err != nil {
   323  		return err
   324  	}
   325  
   326  	// 2.1) load all nested associations from this assoc.
   327  	if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   328  		for _, asocNestedField := range asocNestedFields {
   329  			if err := preload(tx, slice.Interface(), asocNestedField); err != nil {
   330  				return err
   331  			}
   332  		}
   333  	}
   334  
   335  	//  3) iterate over every model and fill it with the assoc.
   336  	foreignField := asoc.getDBFieldTaggedWith(fk)
   337  	mmi.iterate(func(mvalue reflect.Value) {
   338  		for i := 0; i < slice.Elem().Len(); i++ {
   339  			asocValue := slice.Elem().Index(i)
   340  			if mmi.mapper.FieldByName(mvalue, "ID").Interface() == mmi.mapper.FieldByName(asocValue, foreignField.Path).Interface() ||
   341  				reflect.DeepEqual(mmi.mapper.FieldByName(mvalue, "ID"), mmi.mapper.FieldByName(asocValue, foreignField.Path)) {
   342  				// IMPORTANT
   343  				//
   344  				// FieldByName will initialize the value. It is important that this happens AFTER
   345  				// we checked whether the field should be set. Otherwise, we'll set a zero value!
   346  				//
   347  				// This is most likely the reason for https://github.com/gobuffalo/pop/issues/139
   348  				modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   349  				switch {
   350  				case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
   351  					modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   352  				case modelAssociationField.Kind() == reflect.Ptr:
   353  					modelAssociationField.Elem().Set(asocValue)
   354  				default:
   355  					modelAssociationField.Set(asocValue)
   356  				}
   357  			}
   358  		}
   359  	})
   360  
   361  	return nil
   362  }
   363  
   364  func preloadBelongsTo(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   365  	// 1) get all associations ids.
   366  	fi := mmi.getDBFieldTaggedWith(asoc.fkName())
   367  	if fi == nil {
   368  		fi = mmi.getDBFieldTaggedWith(fmt.Sprintf("%s%s", flect.Underscore(asoc.Path), "_id"))
   369  	}
   370  
   371  	fkids := []interface{}{}
   372  	mmi.iterate(func(val reflect.Value) {
   373  		if !isFieldNilPtr(val, fi) {
   374  			fkids = append(fkids, mmi.mapper.FieldByName(val, fi.Path).Interface())
   375  		}
   376  	})
   377  
   378  	if len(fkids) == 0 {
   379  		return nil
   380  	}
   381  
   382  	// 2) load all associations constraint by association fields ids.
   383  	fk := "id"
   384  
   385  	q := tx.Q()
   386  	q.eager = false
   387  	q.eagerFields = []string{}
   388  
   389  	slice := asoc.toSlice()
   390  	err := q.Where(fmt.Sprintf("%s in (?)", fk), fkids).All(slice.Interface())
   391  	if err != nil {
   392  		return err
   393  	}
   394  
   395  	// 2.1) load all nested associations from this assoc.
   396  	if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   397  		for _, asocNestedField := range asocNestedFields {
   398  			if err := preload(tx, slice.Interface(), asocNestedField); err != nil {
   399  				return err
   400  			}
   401  		}
   402  	}
   403  
   404  	// 3) iterate over every model and fill it with the assoc.
   405  	mmi.iterate(func(mvalue reflect.Value) {
   406  		if isFieldNilPtr(mvalue, fi) {
   407  			return
   408  		}
   409  		for i := 0; i < slice.Elem().Len(); i++ {
   410  			asocValue := slice.Elem().Index(i)
   411  			fkField := reflect.Indirect(mmi.mapper.FieldByName(mvalue, fi.Path))
   412  			field := mmi.mapper.FieldByName(asocValue, "ID")
   413  			if fkField.Interface() == field.Interface() || reflect.DeepEqual(fkField, field) {
   414  				// IMPORTANT
   415  				//
   416  				// FieldByName will initialize the value. It is important that this happens AFTER
   417  				// we checked whether the field should be set. Otherwise, we'll set a zero value!
   418  				//
   419  				// This is most likely the reason for https://github.com/gobuffalo/pop/issues/139
   420  				modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   421  				switch {
   422  				case modelAssociationField.Kind() == reflect.Slice || modelAssociationField.Kind() == reflect.Array:
   423  					modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   424  				case modelAssociationField.Kind() == reflect.Ptr:
   425  					modelAssociationField.Elem().Set(asocValue)
   426  				default:
   427  					modelAssociationField.Set(asocValue)
   428  				}
   429  			}
   430  		}
   431  	})
   432  
   433  	return nil
   434  }
   435  
   436  func preloadManyToMany(tx *Connection, asoc *AssociationMetaInfo, mmi *ModelMetaInfo) error {
   437  	// 1) get all associations ids.
   438  	// 1.1) In here I pick ids from model meta info directly.
   439  	ids := []interface{}{}
   440  	mmi.Model.iterate(func(m *Model) error {
   441  		ids = append(ids, m.ID())
   442  		return nil
   443  	})
   444  
   445  	if len(ids) == 0 {
   446  		return nil
   447  	}
   448  
   449  	// 2) load all associations.
   450  	// 2.1) In here I pick the label name from association.
   451  	manyToManyTableName := asoc.Field.Tag.Get("many_to_many")
   452  	modelAssociationName := mmi.Model.associationName()
   453  	assocFkName := asoc.fkName()
   454  
   455  	if strings.Contains(manyToManyTableName, ":") {
   456  		modelAssociationName = strings.TrimSpace(manyToManyTableName[strings.Index(manyToManyTableName, ":")+1:])
   457  		manyToManyTableName = strings.TrimSpace(manyToManyTableName[:strings.Index(manyToManyTableName, ":")])
   458  	}
   459  
   460  	sql := fmt.Sprintf("SELECT %s, %s FROM %s WHERE %s in (?)", modelAssociationName, assocFkName, manyToManyTableName, modelAssociationName)
   461  	sql, args, _ := sqlx.In(sql, ids)
   462  	sql = tx.Dialect.TranslateSQL(sql)
   463  
   464  	cn, err := tx.Store.Transaction()
   465  	if err != nil {
   466  		return err
   467  	}
   468  
   469  	txlog(logging.SQL, cn, sql, args...)
   470  	rows, err := cn.Queryx(sql, args...)
   471  	if err != nil {
   472  		return err
   473  	}
   474  
   475  	mapAssoc := map[string][]interface{}{}
   476  	fkids := []interface{}{}
   477  	for rows.Next() {
   478  		row, err := rows.SliceScan()
   479  		if err != nil {
   480  			return err
   481  		}
   482  		if len(row) > 0 {
   483  			if _, ok := row[0].([]uint8); ok { // -> it's UUID
   484  				row[0] = string(row[0].([]uint8))
   485  			}
   486  			if _, ok := row[1].([]uint8); ok { // -> it's UUID
   487  				row[1] = string(row[1].([]uint8))
   488  			}
   489  			key := fmt.Sprintf("%v", row[0])
   490  			mapAssoc[key] = append(mapAssoc[key], row[1])
   491  			fkids = append(fkids, row[1])
   492  		}
   493  	}
   494  	if err := rows.Err(); err != nil {
   495  		return err
   496  	}
   497  
   498  	q := tx.Q()
   499  	q.eager = false
   500  	q.eagerFields = []string{}
   501  
   502  	if strings.TrimSpace(asoc.Field.Tag.Get("order_by")) != "" {
   503  		q.Order(asoc.Field.Tag.Get("order_by"))
   504  	}
   505  
   506  	slice := asoc.toSlice()
   507  	q.Where("id in (?)", fkids).All(slice.Interface())
   508  
   509  	// 2.2) load all nested associations from this assoc.
   510  	if asocNestedFields, ok := mmi.nestedFields[asoc.Path]; ok {
   511  		for _, asocNestedField := range asocNestedFields {
   512  			if err := preload(tx, slice.Interface(), asocNestedField); err != nil {
   513  				return err
   514  			}
   515  		}
   516  	}
   517  
   518  	// 3) iterate over every model and fill it with the assoc.
   519  	mmi.iterate(func(mvalue reflect.Value) {
   520  		id := mmi.mapper.FieldByName(mvalue, "ID").Interface()
   521  		if assocFkIds, ok := mapAssoc[fmt.Sprintf("%v", id)]; ok {
   522  			for i := 0; i < slice.Elem().Len(); i++ {
   523  				asocValue := slice.Elem().Index(i)
   524  				for _, fkid := range assocFkIds {
   525  					if fmt.Sprintf("%v", fkid) == fmt.Sprintf("%v", mmi.mapper.FieldByName(asocValue, "ID").Interface()) {
   526  						// IMPORTANT
   527  						//
   528  						// FieldByName will initialize the value. It is important that this happens AFTER
   529  						// we checked whether the field should be set. Otherwise, we'll set a zero value!
   530  						//
   531  						// This is most likely the reason for https://github.com/gobuffalo/pop/issues/139
   532  						modelAssociationField := mmi.mapper.FieldByName(mvalue, asoc.Name)
   533  						modelAssociationField.Set(reflect.Append(modelAssociationField, asocValue))
   534  					}
   535  				}
   536  			}
   537  		}
   538  	})
   539  
   540  	return nil
   541  }
   542  
   543  func isFieldNilPtr(val reflect.Value, fi *reflectx.FieldInfo) bool {
   544  	fieldValue := reflectx.FieldByIndexesReadOnly(val, fi.Index)
   545  	return fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil()
   546  }