github.com/wfusion/gofusion@v1.1.14/common/di/dig.go (about)

     1  package di
     2  
import (
	"fmt"
	"reflect"
	"strings"

	"go.uber.org/dig"

	"github.com/wfusion/gofusion/common/constant"
	"github.com/wfusion/gofusion/common/utils"
	"github.com/wfusion/gofusion/common/utils/inspect"
)
    13  
    14  var (
    15  	Dig = NewDI()
    16  
    17  	inType     = reflect.TypeOf(In{})
    18  	outType    = reflect.TypeOf(Out{})
    19  	digInType  = reflect.TypeOf(dig.In{})
    20  	digOutType = reflect.TypeOf(dig.Out{})
    21  
    22  	Type = reflect.TypeOf(NewDI())
    23  )
    24  
// _dig is the DI implementation returned by NewDI, backed by a
// go.uber.org/dig container.
type _dig struct {
	// Container is the underlying dig container; its methods are promoted
	// unless _dig defines a method of the same name.
	*dig.Container
	// fields collects the entries appended by addFields for values
	// registered through Provide; Preload turns them into the fields of a
	// synthetic dig.In struct.
	fields []reflect.StructField
}
    29  
    30  func NewDI() DI {
    31  	return &_dig{Container: dig.New()}
    32  }
    33  
    34  type provideOption struct {
    35  	name  string
    36  	group string
    37  }
    38  
    39  func Name(name string) utils.OptionFunc[provideOption] {
    40  	return func(p *provideOption) {
    41  		p.name = name
    42  	}
    43  }
    44  
    45  func Group(group string) utils.OptionFunc[provideOption] {
    46  	return func(p *provideOption) {
    47  		p.group = group
    48  	}
    49  }
    50  
    51  func (d *_dig) Invoke(fn any) error { return d.Container.Invoke(fn) }
    52  func (d *_dig) MustInvoke(fn any)   { utils.MustSuccess(d.Container.Invoke(fn)) }
    53  func (d *_dig) Provide(ctor any, opts ...utils.OptionExtender) (err error) {
    54  	opt := utils.ApplyOptions[provideOption](opts...)
    55  	digOpts := make([]dig.ProvideOption, 0, 2)
    56  	if opt.name != "" {
    57  		digOpts = append(digOpts, dig.Name(opt.name))
    58  	}
    59  	if opt.group != "" {
    60  		digOpts = append(digOpts, dig.Group(opt.group))
    61  	}
    62  
    63  	defer d.addFields(ctor, opt)
    64  	return d.Container.Provide(ctor, digOpts...)
    65  }
    66  func (d *_dig) MustProvide(ctor any, opts ...utils.OptionExtender) DI {
    67  	utils.MustSuccess(d.Provide(ctor, opts...))
    68  	return d
    69  }
    70  func (d *_dig) Decorate(decorator any) error { return d.Container.Decorate(decorator) }
    71  func (d *_dig) MustDecorate(decorator any) DI {
    72  	utils.MustSuccess(d.Container.Decorate(decorator))
    73  	return d
    74  }
    75  
    76  func (d *_dig) String() string {
    77  	return d.Container.String()
    78  }
    79  
    80  func (d *_dig) Clear() {
    81  	d.Container = dig.New()
    82  }
    83  
    84  // Preload prevent invoke concurrently because invoke is not concurrent safe
    85  // base on: https://github.com/uber-go/dig/issues/241
    86  func (d *_dig) Preload() {
    87  	fields := make([]reflect.StructField, 0, 1+len(d.fields))
    88  	fields = append(fields, reflect.StructField{
    89  		Name:      "In",
    90  		PkgPath:   "",
    91  		Type:      digInType,
    92  		Tag:       "",
    93  		Offset:    0,
    94  		Index:     nil,
    95  		Anonymous: true,
    96  	})
    97  	for i := 0; i < len(d.fields); i++ {
    98  		fields = append(fields, reflect.StructField{
    99  			Name:      fmt.Sprintf("Arg%X", i+1),
   100  			PkgPath:   "",
   101  			Type:      d.fields[i].Type,
   102  			Tag:       d.fields[i].Tag,
   103  			Offset:    0,
   104  			Index:     nil,
   105  			Anonymous: false,
   106  		})
   107  	}
   108  	structType := reflect.StructOf(fields)
   109  
   110  	// FIXME: we cannot declare function param type dynamic now
   111  	scope := inspect.GetField[*dig.Scope](d.Container, "scope")
   112  	containerStoreType := inspect.TypeOf("go.uber.org/dig.containerStore")
   113  	containerStoreVal := reflect.ValueOf(scope).Convert(containerStoreType)
   114  
   115  	fakeParam := utils.Must(newParam(structType, scope))
   116  	paramType := inspect.TypeOf("go.uber.org/dig.param")
   117  	paramVal := reflect.ValueOf(fakeParam).Convert(paramType)
   118  	buildFn := paramVal.MethodByName("Build")
   119  	returnValList := buildFn.Call([]reflect.Value{containerStoreVal})
   120  	if errVal := returnValList[len(returnValList)-1].Interface(); errVal != nil {
   121  		if err, ok := errVal.(error); ok && err != nil {
   122  			panic(err)
   123  		}
   124  	}
   125  }
   126  
   127  func (d *_dig) addFields(ctor any, opt *provideOption) {
   128  	typ := reflect.TypeOf(ctor)
   129  	numOfOut := typ.NumOut()
   130  	for i := 0; i < numOfOut; i++ {
   131  		out := typ.Out(i)
   132  		// ignore error and non-interface nor non-struct out param
   133  		if out == constant.ErrorType ||
   134  			(out.Kind() != reflect.Interface &&
   135  				(out.Kind() != reflect.Struct && !(out.Kind() == reflect.Ptr && out.Elem().Kind() == reflect.Struct))) {
   136  			continue
   137  		}
   138  		if !utils.EmbedsType(out, digOutType) {
   139  			var tag reflect.StructTag
   140  			switch {
   141  			case opt.name != "":
   142  				tag = reflect.StructTag(fmt.Sprintf(`name:"%s"`, opt.name))
   143  			case opt.group != "":
   144  				tag = reflect.StructTag(fmt.Sprintf(`group:"%s"`, opt.group))
   145  				out = reflect.SliceOf(out)
   146  			}
   147  
   148  			d.fields = append(d.fields, reflect.StructField{Type: out, Tag: tag})
   149  			continue
   150  		}
   151  
   152  		// traverse all field
   153  		numOfFields := out.NumField()
   154  		for j := 0; j < numOfFields; j++ {
   155  			f := out.Field(j)
   156  
   157  			// ignore dig out
   158  			if f.Type == digOutType || f.Type == outType {
   159  				continue
   160  			}
   161  
   162  			d.fields = append(d.fields, f)
   163  		}
   164  	}
   165  }