github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/zdi/inject.go (about) 1 package zdi 2 3 import ( 4 "fmt" 5 "reflect" 6 7 "github.com/sohaha/zlsgo/zerror" 8 "github.com/sohaha/zlsgo/zreflect" 9 ) 10 11 func (inj *injector) InvokeWithErrorOnly(f interface{}) (err error) { 12 v, err := inj.Invoke(f) 13 if err != nil { 14 return err 15 } 16 17 if len(v) == 0 { 18 return nil 19 } 20 21 for i := range v { 22 if err, ok := v[i].Interface().(error); ok { 23 return err 24 } 25 } 26 27 return nil 28 } 29 30 func (inj *injector) Invoke(f interface{}) (values []reflect.Value, err error) { 31 catch := zerror.TryCatch(func() error { 32 t := zreflect.TypeOf(f) 33 switch v := f.(type) { 34 case PreInvoker: 35 values, err = inj.fast(v, t, t.NumIn()) 36 default: 37 values, err = inj.call(f, t, t.NumIn()) 38 } 39 return nil 40 }) 41 42 if catch != nil { 43 err = catch 44 } 45 46 return 47 } 48 49 func (inj *injector) call(f interface{}, t reflect.Type, numIn int) ([]reflect.Value, error) { 50 var in []reflect.Value 51 if numIn > 0 { 52 in = make([]reflect.Value, numIn) 53 var argType reflect.Type 54 for i := 0; i < numIn; i++ { 55 argType = t.In(i) 56 val, ok := inj.Get(argType) 57 if !ok { 58 return nil, fmt.Errorf("value not found for type %v", argType) 59 } 60 61 in[i] = val 62 } 63 } 64 return zreflect.ValueOf(f).Call(in), nil 65 } 66 67 func (inj *injector) Map(val interface{}, opt ...Option) (override reflect.Type) { 68 o := mapOption{} 69 for _, opt := range opt { 70 opt(&o) 71 } 72 if o.key == nil { 73 o.key = reflect.TypeOf(val) 74 } 75 if _, ok := inj.values[o.key]; ok { 76 override = o.key 77 } 78 79 inj.values[o.key] = zreflect.ValueOf(val) 80 return 81 } 82 83 func (inj *injector) Maps(values ...interface{}) (override []reflect.Type) { 84 for _, val := range values { 85 o := inj.Map(val) 86 if o != nil { 87 override = append(override, o) 88 } 89 } 90 return 91 } 92 93 func (inj *injector) Set(typ reflect.Type, val reflect.Value) { 94 inj.values[typ] = val 95 } 96 97 func (inj *injector) Get(t reflect.Type) (reflect.Value, bool) { 98 val := inj.values[t] 99 if val.IsValid() { 100 return val, true 101 } 102 103 if provider, ok := inj.providers[t]; ok { 104 results, err := inj.Invoke(provider.Interface()) 105 if err != nil { 106 panic(err) 107 } 108 for _, result := range results { 109 resultType := result.Type() 110 inj.values[resultType] = result 111 delete(inj.providers, resultType) 112 if resultType == t { 113 val = result 114 } 115 } 116 117 if val.IsValid() { 118 return val, true 119 } 120 } 121 122 if t.Kind() == reflect.Interface { 123 for k, v := range inj.values { 124 if k.Implements(t) { 125 val = v 126 break 127 } 128 } 129 } 130 131 if val.IsValid() { 132 return val, true 133 } 134 135 var ok bool 136 if inj.parent != nil { 137 val, ok = inj.parent.Get(t) 138 } 139 140 return val, ok 141 }