github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/queries/eager_load.go (about)

     1  package queries
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"reflect"
     7  	"strings"
     8  
     9  	"github.com/friendsofgo/errors"
    10  	"github.com/volatiletech/sqlboiler/v4/boil"
    11  	"github.com/volatiletech/strmangle"
    12  )
    13  
// loadRelationshipState holds the state shared by the recursive helpers of a
// single eagerLoad call.
type loadRelationshipState struct {
	// ctx is passed to the generated Load methods when it is non-nil.
	ctx    context.Context
	// exec is handed to the generated Load methods for their queries; a nil
	// exec is replaced with a typed nil *sql.DB before the call.
	exec   boil.Executor
	// loaded records the dotted relationship paths (see buildKey) that were
	// already loaded, so later toLoad entries sharing a prefix skip them.
	loaded map[string]struct{}
	// toLoad is the relationship path currently being walked, split on ".".
	toLoad []string
	// mods maps a dotted relationship path to the Applicator passed to the
	// Load method for that path.
	mods   map[string]Applicator
}
    21  
    22  func (l loadRelationshipState) hasLoaded(depth int) bool {
    23  	_, ok := l.loaded[l.buildKey(depth)]
    24  	return ok
    25  }
    26  
    27  func (l loadRelationshipState) setLoaded(depth int) {
    28  	l.loaded[l.buildKey(depth)] = struct{}{}
    29  }
    30  
    31  func (l loadRelationshipState) buildKey(depth int) string {
    32  	buf := strmangle.GetBuffer()
    33  
    34  	for i, piece := range l.toLoad[:depth+1] {
    35  		if i != 0 {
    36  			buf.WriteByte('.')
    37  		}
    38  		buf.WriteString(piece)
    39  	}
    40  
    41  	str := buf.String()
    42  	strmangle.PutBuffer(buf)
    43  	return str
    44  }
    45  
    46  // eagerLoad loads all of the model's relationships
    47  //
    48  // toLoad should look like:
    49  // []string{"Relationship", "Relationship.NestedRelationship"} ... etc
    50  // obj should be one of:
    51  // *[]*struct or *struct
    52  // bkind should reflect what kind of thing it is above
    53  func eagerLoad(ctx context.Context, exec boil.Executor, toLoad []string, mods map[string]Applicator, obj interface{}, bkind bindKind) error {
    54  	state := loadRelationshipState{
    55  		ctx:    ctx, // defiant to the end, I know this is frowned upon
    56  		exec:   exec,
    57  		loaded: map[string]struct{}{},
    58  		mods:   mods,
    59  	}
    60  	for _, toLoad := range toLoad {
    61  		state.toLoad = strings.Split(toLoad, ".")
    62  		if err := state.loadRelationships(0, obj, bkind); err != nil {
    63  			return err
    64  		}
    65  	}
    66  
    67  	return nil
    68  }
    69  
    70  // loadRelationships dynamically calls the template generated eager load
    71  // functions of the form:
    72  //
    73  //   func (t *TableR) LoadRelationshipName(exec Executor, singular bool, obj interface{})
    74  //
    75  // The arguments to this function are:
    76  //   - t is not considered here, and is always passed nil. The function exists on a loaded
    77  //     struct to avoid a circular dependency with boil, and the receiver is ignored.
    78  //   - exec is used to perform additional queries that might be required for loading the relationships.
    79  //   - bkind is passed in to identify whether or not this was a single object
    80  //     or a slice that must be loaded into.
    81  //   - obj is the object or slice of objects, always of the type *obj or *[]*obj as per bind.
    82  //
    83  // We start with a normal select before eager loading anything: select * from a;
    84  // Then we start eager loading things, it can be represented by a DAG
    85  //          a1, a2           select id, a_id from b where id in (a1, a2)
    86  //         / |    \
    87  //        b1 b2    b3        select id, b_id from c where id in (b2, b3, b4)
    88  //       /   | \     \
    89  //      c1  c2 c3    c4
    90  //
    91  // That's to say that we descend the graph of relationships, and at each level
    92  // we gather all the things up we want to load into, load them, and then move
    93  // to the next level of the graph.
    94  func (l loadRelationshipState) loadRelationships(depth int, obj interface{}, bkind bindKind) error {
    95  	typ := reflect.TypeOf(obj).Elem()
    96  	if bkind == kindPtrSliceStruct {
    97  		typ = typ.Elem().Elem()
    98  	}
    99  
   100  	loadingFrom := reflect.ValueOf(obj)
   101  	if loadingFrom.IsNil() {
   102  		return nil
   103  	}
   104  
   105  	if !l.hasLoaded(depth) {
   106  		if err := l.callLoadFunction(depth, loadingFrom, typ, bkind); err != nil {
   107  			return err
   108  		}
   109  	}
   110  
   111  	// Check if we can stop
   112  	if depth+1 >= len(l.toLoad) {
   113  		return nil
   114  	}
   115  
   116  	// *[]*struct -> []*struct
   117  	// *struct -> struct
   118  	loadingFrom = reflect.Indirect(loadingFrom)
   119  
   120  	// If it's singular we can just immediately call without looping
   121  	if bkind == kindStruct {
   122  		return l.loadRelationshipsRecurse(depth, loadingFrom)
   123  	}
   124  
   125  	// If we were an empty slice to begin with, bail, probably a useless check
   126  	if loadingFrom.Len() == 0 {
   127  		return nil
   128  	}
   129  
   130  	// Collect eagerly loaded things to send into next eager load call
   131  	slice, nextBKind, err := collectLoaded(l.toLoad[depth], loadingFrom)
   132  	if err != nil {
   133  		return err
   134  	}
   135  
   136  	// If we could collect nothing we're done
   137  	if slice.Len() == 0 {
   138  		return nil
   139  	}
   140  
   141  	ptr := reflect.New(slice.Type())
   142  	ptr.Elem().Set(slice)
   143  
   144  	return l.loadRelationships(depth+1, ptr.Interface(), nextBKind)
   145  }
   146  
   147  // callLoadFunction finds the loader struct, finds the method that we need
   148  // to call and calls it.
   149  func (l loadRelationshipState) callLoadFunction(depth int, loadingFrom reflect.Value, typ reflect.Type, bkind bindKind) error {
   150  	current := l.toLoad[depth]
   151  	ln, found := typ.FieldByName(loaderStructName)
   152  	// It's possible a Loaders struct doesn't exist on the struct.
   153  	if !found {
   154  		return errors.Errorf("attempted to load %s but no L struct was found", current)
   155  	}
   156  
   157  	// Attempt to find the LoadRelationshipName function
   158  	loadMethod, found := ln.Type.MethodByName(loadMethodPrefix + current)
   159  	if !found {
   160  		return errors.Errorf("could not find %s%s method for eager loading", loadMethodPrefix, current)
   161  	}
   162  
   163  	ctxArg := reflect.ValueOf(l.ctx)
   164  	// Hack to allow nil executors
   165  	execArg := reflect.ValueOf(l.exec)
   166  	if !execArg.IsValid() {
   167  		execArg = reflect.ValueOf((*sql.DB)(nil))
   168  	}
   169  
   170  	// Get a loader instance from anything we have, *struct, or *[]*struct
   171  	val := reflect.Indirect(loadingFrom)
   172  	if bkind == kindPtrSliceStruct {
   173  		if val.Len() == 0 {
   174  			return nil
   175  		}
   176  		val = val.Index(0)
   177  		if val.IsNil() {
   178  			return nil
   179  		}
   180  		val = reflect.Indirect(val)
   181  	}
   182  
   183  	methodArgs := make([]reflect.Value, 0, 5)
   184  	methodArgs = append(methodArgs, val.FieldByName(loaderStructName))
   185  	if ctxArg.IsValid() {
   186  		methodArgs = append(methodArgs, ctxArg)
   187  	}
   188  	methodArgs = append(methodArgs, execArg, reflect.ValueOf(bkind == kindStruct), loadingFrom)
   189  	if mods, ok := l.mods[l.buildKey(depth)]; ok {
   190  		methodArgs = append(methodArgs, reflect.ValueOf(mods))
   191  	} else {
   192  		methodArgs = append(methodArgs, applicatorSentinelVal)
   193  	}
   194  
   195  	ret := loadMethod.Func.Call(methodArgs)
   196  	if intf := ret[0].Interface(); intf != nil {
   197  		return errors.Wrapf(intf.(error), "failed to eager load %s", current)
   198  	}
   199  
   200  	l.setLoaded(depth)
   201  	return nil
   202  }
   203  
   204  // loadRelationshipsRecurse is a helper function for taking a reflect.Value and
   205  // Basically calls loadRelationships with: obj.R.EagerLoadedObj
   206  // Called with an obj of *struct
   207  func (l loadRelationshipState) loadRelationshipsRecurse(depth int, obj reflect.Value) error {
   208  	key := l.toLoad[depth]
   209  	r, err := findRelationshipStruct(obj)
   210  	if err != nil {
   211  		return errors.Wrapf(err, "failed to append loaded %s", key)
   212  	}
   213  
   214  	loadedObject := reflect.Indirect(r).FieldByName(key)
   215  	if loadedObject.IsNil() {
   216  		return nil
   217  	}
   218  
   219  	bkind := kindStruct
   220  	if derefed := reflect.Indirect(loadedObject); derefed.Kind() != reflect.Struct {
   221  		bkind = kindPtrSliceStruct
   222  
   223  		// Convert away any helper slice types
   224  		// elemType is *elem (from []*elem or helperSliceType)
   225  		// sliceType is *[]*elem
   226  		elemType := derefed.Type().Elem()
   227  		sliceType := reflect.PtrTo(reflect.SliceOf(elemType))
   228  
   229  		loadedObject = loadedObject.Addr().Convert(sliceType)
   230  	}
   231  	return l.loadRelationships(depth+1, loadedObject.Interface(), bkind)
   232  }
   233  
   234  // collectLoaded traverses the next level of the graph and picks up all
   235  // the values that we need for the next eager load query.
   236  //
   237  // For example when loadingFrom is [parent1, parent2]
   238  //
   239  //   parent1 -> child1
   240  //          \-> child2
   241  //   parent2 -> child3
   242  //
   243  // This should return [child1, child2, child3]
   244  func collectLoaded(key string, loadingFrom reflect.Value) (reflect.Value, bindKind, error) {
   245  	// Pull the first one so we can get the types out of it in order to
   246  	// create the proper type of slice.
   247  	current := reflect.Indirect(loadingFrom.Index(0))
   248  	lnFrom := loadingFrom.Len()
   249  
   250  	r, err := findRelationshipStruct(current)
   251  	if err != nil {
   252  		return reflect.Value{}, 0, errors.Wrapf(err, "failed to collect loaded %s", key)
   253  	}
   254  
   255  	loadedObject := reflect.Indirect(r).FieldByName(key)
   256  	loadedType := loadedObject.Type() // Should be *obj or []*obj
   257  
   258  	bkind := kindPtrSliceStruct
   259  	if loadedType.Elem().Kind() == reflect.Struct {
   260  		bkind = kindStruct
   261  		loadedType = reflect.SliceOf(loadedType)
   262  	} else {
   263  		// Ensure that we get rid of all the helper "XSlice" types
   264  		loadedType = reflect.SliceOf(loadedType.Elem())
   265  	}
   266  
   267  	collection := reflect.MakeSlice(loadedType, 0, 0)
   268  
   269  	i := 0
   270  	for {
   271  		switch bkind {
   272  		case kindStruct:
   273  			if !loadedObject.IsNil() {
   274  				collection = reflect.Append(collection, loadedObject)
   275  			}
   276  		case kindPtrSliceStruct:
   277  			collection = reflect.AppendSlice(collection, loadedObject)
   278  		}
   279  
   280  		i++
   281  		if i >= lnFrom {
   282  			break
   283  		}
   284  
   285  		current = reflect.Indirect(loadingFrom.Index(i))
   286  		r, err = findRelationshipStruct(current)
   287  		if err != nil {
   288  			return reflect.Value{}, 0, errors.Wrapf(err, "failed to collect loaded %s", key)
   289  		}
   290  
   291  		loadedObject = reflect.Indirect(r).FieldByName(key)
   292  	}
   293  
   294  	return collection, kindPtrSliceStruct, nil
   295  }
   296  
   297  func findRelationshipStruct(obj reflect.Value) (reflect.Value, error) {
   298  	relationshipStruct := obj.FieldByName(relationshipStructName)
   299  	if !relationshipStruct.IsValid() {
   300  		return reflect.Value{}, errors.New("relationship struct was invalid")
   301  	} else if relationshipStruct.IsNil() {
   302  		return reflect.Value{}, errors.New("relationship struct was nil")
   303  	}
   304  
   305  	return relationshipStruct, nil
   306  }
   307  
// applicatorSentinel is a nil Applicator, and applicatorSentinelVal is a
// reflect.Value of interface type Applicator holding that nil. callLoadFunction
// passes applicatorSentinelVal as the mods argument when no mods are registered
// for a path. A plain reflect.ValueOf(nil) would be the invalid zero Value
// rather than a typed nil, which is why the sentinel exists.
var (
	applicatorSentinel    Applicator
	applicatorSentinelVal = reflect.ValueOf(&applicatorSentinel).Elem()
)
   312  
   313  // SetFromEmbeddedStruct sets `to` value from embedded struct
   314  // of the `from` struct or slice of structs.
   315  // Expects `to` and `from` to be a pair of pre-allocated **struct or *[]*struct.
   316  // Returns false if types do not match.
   317  func SetFromEmbeddedStruct(to interface{}, from interface{}) bool {
   318  	toPtrVal := reflect.ValueOf(to)
   319  	fromPtrVal := reflect.ValueOf(from)
   320  	if toPtrVal.Kind() != reflect.Ptr || fromPtrVal.Kind() != reflect.Ptr {
   321  		return false
   322  	}
   323  	toStructTyp, ok := singularStructType(to)
   324  	if !ok {
   325  		return false
   326  	}
   327  	fromStructTyp, ok := singularStructType(from)
   328  	if !ok {
   329  		return false
   330  	}
   331  	fieldNum, ok := embeddedStructFieldNum(fromStructTyp, toStructTyp)
   332  	if !ok {
   333  		return false
   334  	}
   335  	toVal := toPtrVal.Elem()
   336  	if toVal.Kind() == reflect.Interface {
   337  		toVal = reflect.ValueOf(toVal.Interface())
   338  	}
   339  	fromVal := fromPtrVal.Elem()
   340  	if fromVal.Kind() == reflect.Interface {
   341  		fromVal = reflect.ValueOf(fromVal.Interface())
   342  	}
   343  
   344  	if toVal.Kind() == reflect.Ptr && toVal.Elem().Kind() == reflect.Struct &&
   345  		fromVal.Kind() == reflect.Ptr && fromVal.Elem().Kind() == reflect.Struct {
   346  		toVal.Set(fromVal.Elem().Field(fieldNum).Addr())
   347  
   348  		return true
   349  	}
   350  
   351  	toKind := toPtrVal.Type().Elem().Kind()
   352  	fromKind := fromPtrVal.Type().Elem().Kind()
   353  
   354  	if toKind == reflect.Slice && fromKind == reflect.Slice {
   355  		toSlice := reflect.MakeSlice(toVal.Type(), fromVal.Len(), fromVal.Len())
   356  		for i := 0; i < fromVal.Len(); i++ {
   357  			toSlice.Index(i).Set(fromVal.Index(i).Elem().Field(fieldNum).Addr())
   358  		}
   359  		toVal.Set(toSlice)
   360  
   361  		return true
   362  	}
   363  
   364  	return false
   365  }
   366  
   367  // singularStructType returns singular struct type
   368  // from **struct or *[]*struct types.
   369  // Used for Load* methods during binding.
   370  func singularStructType(obj interface{}) (reflect.Type, bool) {
   371  	val := reflect.Indirect(reflect.ValueOf(obj))
   372  	if val.Kind() == reflect.Interface {
   373  		val = reflect.ValueOf(val.Interface())
   374  	}
   375  	typ := val.Type()
   376  	inSlice := false
   377  SWITCH:
   378  	switch typ.Kind() {
   379  	case reflect.Ptr:
   380  		typ = typ.Elem()
   381  
   382  		goto SWITCH
   383  	case reflect.Slice:
   384  		if inSlice {
   385  			// Slices inside other slices are not supported
   386  			return nil, false
   387  		}
   388  		inSlice = true
   389  		typ = typ.Elem()
   390  
   391  		goto SWITCH
   392  	case reflect.Struct:
   393  		return typ, true
   394  	default:
   395  		return nil, false
   396  	}
   397  }
   398  
   399  // embeddedStructFieldNum returns the index of embedded struct field of type `emb` inside `obj` struct.
   400  func embeddedStructFieldNum(obj reflect.Type, emb reflect.Type) (int, bool) {
   401  	for i := 0; i < obj.NumField(); i++ {
   402  		v := obj.Field(i)
   403  		if v.Type.Kind() == reflect.Struct &&
   404  			v.Anonymous && v.Type == emb {
   405  			return i, true
   406  		}
   407  	}
   408  	return 0, false
   409  }