github.com/systematiccaos/gorm@v1.22.6/callbacks/preload.go (about)

     1  package callbacks
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  
     7  	"github.com/systematiccaos/gorm"
     8  	"github.com/systematiccaos/gorm/clause"
     9  	"github.com/systematiccaos/gorm/schema"
    10  	"github.com/systematiccaos/gorm/utils"
    11  )
    12  
    13  func preload(db *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) {
    14  	var (
    15  		reflectValue     = db.Statement.ReflectValue
    16  		tx               = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
    17  		relForeignKeys   []string
    18  		relForeignFields []*schema.Field
    19  		foreignFields    []*schema.Field
    20  		foreignValues    [][]interface{}
    21  		identityMap      = map[string][]reflect.Value{}
    22  		inlineConds      []interface{}
    23  	)
    24  
    25  	db.Statement.Settings.Range(func(k, v interface{}) bool {
    26  		tx.Statement.Settings.Store(k, v)
    27  		return true
    28  	})
    29  
    30  	if rel.JoinTable != nil {
    31  		var (
    32  			joinForeignFields    = make([]*schema.Field, 0, len(rel.References))
    33  			joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
    34  			joinForeignKeys      = make([]string, 0, len(rel.References))
    35  		)
    36  
    37  		for _, ref := range rel.References {
    38  			if ref.OwnPrimaryKey {
    39  				joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
    40  				joinForeignFields = append(joinForeignFields, ref.ForeignKey)
    41  				foreignFields = append(foreignFields, ref.PrimaryKey)
    42  			} else if ref.PrimaryValue != "" {
    43  				tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
    44  			} else {
    45  				joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
    46  				relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
    47  				relForeignFields = append(relForeignFields, ref.PrimaryKey)
    48  			}
    49  		}
    50  
    51  		joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
    52  		if len(joinForeignValues) == 0 {
    53  			return
    54  		}
    55  
    56  		joinResults := rel.JoinTable.MakeSlice().Elem()
    57  		column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
    58  		db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error)
    59  
    60  		// convert join identity map to relation identity map
    61  		fieldValues := make([]interface{}, len(joinForeignFields))
    62  		joinFieldValues := make([]interface{}, len(joinRelForeignFields))
    63  		for i := 0; i < joinResults.Len(); i++ {
    64  			joinIndexValue := joinResults.Index(i)
    65  			for idx, field := range joinForeignFields {
    66  				fieldValues[idx], _ = field.ValueOf(joinIndexValue)
    67  			}
    68  
    69  			for idx, field := range joinRelForeignFields {
    70  				joinFieldValues[idx], _ = field.ValueOf(joinIndexValue)
    71  			}
    72  
    73  			if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
    74  				joinKey := utils.ToStringKey(joinFieldValues...)
    75  				identityMap[joinKey] = append(identityMap[joinKey], results...)
    76  			}
    77  		}
    78  
    79  		_, foreignValues = schema.GetIdentityFieldValuesMap(joinResults, joinRelForeignFields)
    80  	} else {
    81  		for _, ref := range rel.References {
    82  			if ref.OwnPrimaryKey {
    83  				relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
    84  				relForeignFields = append(relForeignFields, ref.ForeignKey)
    85  				foreignFields = append(foreignFields, ref.PrimaryKey)
    86  			} else if ref.PrimaryValue != "" {
    87  				tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
    88  			} else {
    89  				relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
    90  				relForeignFields = append(relForeignFields, ref.PrimaryKey)
    91  				foreignFields = append(foreignFields, ref.ForeignKey)
    92  			}
    93  		}
    94  
    95  		identityMap, foreignValues = schema.GetIdentityFieldValuesMap(reflectValue, foreignFields)
    96  		if len(foreignValues) == 0 {
    97  			return
    98  		}
    99  	}
   100  
   101  	// nested preload
   102  	for p, pvs := range preloads {
   103  		tx = tx.Preload(p, pvs...)
   104  	}
   105  
   106  	reflectResults := rel.FieldSchema.MakeSlice().Elem()
   107  	column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
   108  
   109  	if len(values) != 0 {
   110  		for _, cond := range conds {
   111  			if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
   112  				tx = fc(tx)
   113  			} else {
   114  				inlineConds = append(inlineConds, cond)
   115  			}
   116  		}
   117  
   118  		db.AddError(tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error)
   119  	}
   120  
   121  	fieldValues := make([]interface{}, len(relForeignFields))
   122  
   123  	// clean up old values before preloading
   124  	switch reflectValue.Kind() {
   125  	case reflect.Struct:
   126  		switch rel.Type {
   127  		case schema.HasMany, schema.Many2Many:
   128  			rel.Field.Set(reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
   129  		default:
   130  			rel.Field.Set(reflectValue, reflect.New(rel.Field.FieldType).Interface())
   131  		}
   132  	case reflect.Slice, reflect.Array:
   133  		for i := 0; i < reflectValue.Len(); i++ {
   134  			switch rel.Type {
   135  			case schema.HasMany, schema.Many2Many:
   136  				rel.Field.Set(reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())
   137  			default:
   138  				rel.Field.Set(reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())
   139  			}
   140  		}
   141  	}
   142  
   143  	for i := 0; i < reflectResults.Len(); i++ {
   144  		elem := reflectResults.Index(i)
   145  		for idx, field := range relForeignFields {
   146  			fieldValues[idx], _ = field.ValueOf(elem)
   147  		}
   148  
   149  		datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
   150  		if !ok {
   151  			db.AddError(fmt.Errorf("failed to assign association %#v, make sure foreign fields exists",
   152  				elem.Interface()))
   153  			continue
   154  		}
   155  
   156  		for _, data := range datas {
   157  			reflectFieldValue := rel.Field.ReflectValueOf(data)
   158  			if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
   159  				reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
   160  			}
   161  
   162  			reflectFieldValue = reflect.Indirect(reflectFieldValue)
   163  			switch reflectFieldValue.Kind() {
   164  			case reflect.Struct:
   165  				rel.Field.Set(data, elem.Interface())
   166  			case reflect.Slice, reflect.Array:
   167  				if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
   168  					rel.Field.Set(data, reflect.Append(reflectFieldValue, elem).Interface())
   169  				} else {
   170  					rel.Field.Set(data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())
   171  				}
   172  			}
   173  		}
   174  	}
   175  }