github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/zdi/inject.go (about)

     1  package zdi
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  
     7  	"github.com/sohaha/zlsgo/zerror"
     8  	"github.com/sohaha/zlsgo/zreflect"
     9  )
    10  
    11  func (inj *injector) InvokeWithErrorOnly(f interface{}) (err error) {
    12  	v, err := inj.Invoke(f)
    13  	if err != nil {
    14  		return err
    15  	}
    16  
    17  	if len(v) == 0 {
    18  		return nil
    19  	}
    20  
    21  	for i := range v {
    22  		if err, ok := v[i].Interface().(error); ok {
    23  			return err
    24  		}
    25  	}
    26  
    27  	return nil
    28  }
    29  
    30  func (inj *injector) Invoke(f interface{}) (values []reflect.Value, err error) {
    31  	catch := zerror.TryCatch(func() error {
    32  		t := zreflect.TypeOf(f)
    33  		switch v := f.(type) {
    34  		case PreInvoker:
    35  			values, err = inj.fast(v, t, t.NumIn())
    36  		default:
    37  			values, err = inj.call(f, t, t.NumIn())
    38  		}
    39  		return nil
    40  	})
    41  
    42  	if catch != nil {
    43  		err = catch
    44  	}
    45  
    46  	return
    47  }
    48  
    49  func (inj *injector) call(f interface{}, t reflect.Type, numIn int) ([]reflect.Value, error) {
    50  	var in []reflect.Value
    51  	if numIn > 0 {
    52  		in = make([]reflect.Value, numIn)
    53  		var argType reflect.Type
    54  		for i := 0; i < numIn; i++ {
    55  			argType = t.In(i)
    56  			val, ok := inj.Get(argType)
    57  			if !ok {
    58  				return nil, fmt.Errorf("value not found for type %v", argType)
    59  			}
    60  
    61  			in[i] = val
    62  		}
    63  	}
    64  	return zreflect.ValueOf(f).Call(in), nil
    65  }
    66  
    67  func (inj *injector) Map(val interface{}, opt ...Option) (override reflect.Type) {
    68  	o := mapOption{}
    69  	for _, opt := range opt {
    70  		opt(&o)
    71  	}
    72  	if o.key == nil {
    73  		o.key = reflect.TypeOf(val)
    74  	}
    75  	if _, ok := inj.values[o.key]; ok {
    76  		override = o.key
    77  	}
    78  
    79  	inj.values[o.key] = zreflect.ValueOf(val)
    80  	return
    81  }
    82  
    83  func (inj *injector) Maps(values ...interface{}) (override []reflect.Type) {
    84  	for _, val := range values {
    85  		o := inj.Map(val)
    86  		if o != nil {
    87  			override = append(override, o)
    88  		}
    89  	}
    90  	return
    91  }
    92  
    93  func (inj *injector) Set(typ reflect.Type, val reflect.Value) {
    94  	inj.values[typ] = val
    95  }
    96  
    97  func (inj *injector) Get(t reflect.Type) (reflect.Value, bool) {
    98  	val := inj.values[t]
    99  	if val.IsValid() {
   100  		return val, true
   101  	}
   102  
   103  	if provider, ok := inj.providers[t]; ok {
   104  		results, err := inj.Invoke(provider.Interface())
   105  		if err != nil {
   106  			panic(err)
   107  		}
   108  		for _, result := range results {
   109  			resultType := result.Type()
   110  			inj.values[resultType] = result
   111  			delete(inj.providers, resultType)
   112  			if resultType == t {
   113  				val = result
   114  			}
   115  		}
   116  
   117  		if val.IsValid() {
   118  			return val, true
   119  		}
   120  	}
   121  
   122  	if t.Kind() == reflect.Interface {
   123  		for k, v := range inj.values {
   124  			if k.Implements(t) {
   125  				val = v
   126  				break
   127  			}
   128  		}
   129  	}
   130  
   131  	if val.IsValid() {
   132  		return val, true
   133  	}
   134  
   135  	var ok bool
   136  	if inj.parent != nil {
   137  		val, ok = inj.parent.Get(t)
   138  	}
   139  
   140  	return val, ok
   141  }