github.com/wfusion/gofusion@v1.1.14/common/di/dig.go (about) 1 package di 2 3 import ( 4 "fmt" 5 "reflect" 6 7 "go.uber.org/dig" 8 9 "github.com/wfusion/gofusion/common/constant" 10 "github.com/wfusion/gofusion/common/utils" 11 "github.com/wfusion/gofusion/common/utils/inspect" 12 ) 13 14 var ( 15 Dig = NewDI() 16 17 inType = reflect.TypeOf(In{}) 18 outType = reflect.TypeOf(Out{}) 19 digInType = reflect.TypeOf(dig.In{}) 20 digOutType = reflect.TypeOf(dig.Out{}) 21 22 Type = reflect.TypeOf(NewDI()) 23 ) 24 25 type _dig struct { 26 *dig.Container 27 fields []reflect.StructField 28 } 29 30 func NewDI() DI { 31 return &_dig{Container: dig.New()} 32 } 33 34 type provideOption struct { 35 name string 36 group string 37 } 38 39 func Name(name string) utils.OptionFunc[provideOption] { 40 return func(p *provideOption) { 41 p.name = name 42 } 43 } 44 45 func Group(group string) utils.OptionFunc[provideOption] { 46 return func(p *provideOption) { 47 p.group = group 48 } 49 } 50 51 func (d *_dig) Invoke(fn any) error { return d.Container.Invoke(fn) } 52 func (d *_dig) MustInvoke(fn any) { utils.MustSuccess(d.Container.Invoke(fn)) } 53 func (d *_dig) Provide(ctor any, opts ...utils.OptionExtender) (err error) { 54 opt := utils.ApplyOptions[provideOption](opts...) 55 digOpts := make([]dig.ProvideOption, 0, 2) 56 if opt.name != "" { 57 digOpts = append(digOpts, dig.Name(opt.name)) 58 } 59 if opt.group != "" { 60 digOpts = append(digOpts, dig.Group(opt.group)) 61 } 62 63 defer d.addFields(ctor, opt) 64 return d.Container.Provide(ctor, digOpts...) 65 } 66 func (d *_dig) MustProvide(ctor any, opts ...utils.OptionExtender) DI { 67 utils.MustSuccess(d.Provide(ctor, opts...)) 68 return d 69 } 70 func (d *_dig) Decorate(decorator any) error { return d.Container.Decorate(decorator) } 71 func (d *_dig) MustDecorate(decorator any) DI { 72 utils.MustSuccess(d.Container.Decorate(decorator)) 73 return d 74 } 75 76 func (d *_dig) String() string { 77 return d.Container.String() 78 } 79 80 func (d *_dig) Clear() { 81 d.Container = dig.New() 82 } 83 84 // Preload prevent invoke concurrently because invoke is not concurrent safe 85 // base on: https://github.com/uber-go/dig/issues/241 86 func (d *_dig) Preload() { 87 fields := make([]reflect.StructField, 0, 1+len(d.fields)) 88 fields = append(fields, reflect.StructField{ 89 Name: "In", 90 PkgPath: "", 91 Type: digInType, 92 Tag: "", 93 Offset: 0, 94 Index: nil, 95 Anonymous: true, 96 }) 97 for i := 0; i < len(d.fields); i++ { 98 fields = append(fields, reflect.StructField{ 99 Name: fmt.Sprintf("Arg%X", i+1), 100 PkgPath: "", 101 Type: d.fields[i].Type, 102 Tag: d.fields[i].Tag, 103 Offset: 0, 104 Index: nil, 105 Anonymous: false, 106 }) 107 } 108 structType := reflect.StructOf(fields) 109 110 // FIXME: we cannot declare function param type dynamic now 111 scope := inspect.GetField[*dig.Scope](d.Container, "scope") 112 containerStoreType := inspect.TypeOf("go.uber.org/dig.containerStore") 113 containerStoreVal := reflect.ValueOf(scope).Convert(containerStoreType) 114 115 fakeParam := utils.Must(newParam(structType, scope)) 116 paramType := inspect.TypeOf("go.uber.org/dig.param") 117 paramVal := reflect.ValueOf(fakeParam).Convert(paramType) 118 buildFn := paramVal.MethodByName("Build") 119 returnValList := buildFn.Call([]reflect.Value{containerStoreVal}) 120 if errVal := returnValList[len(returnValList)-1].Interface(); errVal != nil { 121 if err, ok := errVal.(error); ok && err != nil { 122 panic(err) 123 } 124 } 125 } 126 127 func (d *_dig) addFields(ctor any, opt *provideOption) { 128 typ := reflect.TypeOf(ctor) 129 numOfOut := typ.NumOut() 130 for i := 0; i < numOfOut; i++ { 131 out := typ.Out(i) 132 // ignore error and non-interface nor non-struct out param 133 if out == constant.ErrorType || 134 (out.Kind() != reflect.Interface && 135 (out.Kind() != reflect.Struct && !(out.Kind() == reflect.Ptr && out.Elem().Kind() == reflect.Struct))) { 136 continue 137 } 138 if !utils.EmbedsType(out, digOutType) { 139 var tag reflect.StructTag 140 switch { 141 case opt.name != "": 142 tag = reflect.StructTag(fmt.Sprintf(`name:"%s"`, opt.name)) 143 case opt.group != "": 144 tag = reflect.StructTag(fmt.Sprintf(`group:"%s"`, opt.group)) 145 out = reflect.SliceOf(out) 146 } 147 148 d.fields = append(d.fields, reflect.StructField{Type: out, Tag: tag}) 149 continue 150 } 151 152 // traverse all field 153 numOfFields := out.NumField() 154 for j := 0; j < numOfFields; j++ { 155 f := out.Field(j) 156 157 // ignore dig out 158 if f.Type == digOutType || f.Type == outType { 159 continue 160 } 161 162 d.fields = append(d.fields, f) 163 } 164 } 165 }