github.com/goplusjs/reflectx@v0.5.4/method.go (about)

     1  package reflectx
     2  
     3  import (
     4  	"fmt"
     5  	"go/token"
     6  	"log"
     7  	"reflect"
     8  	"sort"
     9  	"strings"
    10  )
    11  
    12  var (
    13  	EnableExportAllMethod = false
    14  )
    15  
    16  // MakeMethod make reflect.Method for MethodOf
    17  // - name: method name
    18  // - pointer: flag receiver struct or pointer
    19  // - typ: method func type without receiver
    20  // - fn: func with receiver as first argument
    21  func MakeMethod(name string, pkgpath string, pointer bool, typ reflect.Type, fn func(args []reflect.Value) (result []reflect.Value)) Method {
    22  	return Method{
    23  		Name:    name,
    24  		PkgPath: pkgpath,
    25  		Pointer: pointer,
    26  		Type:    typ,
    27  		Func:    fn,
    28  	}
    29  }
    30  
    31  // Method struct for MethodOf
    32  // - name: method name
    33  // - pointer: flag receiver struct or pointer
    34  // - typ: method func type without receiver
    35  // - fn: func with receiver as first argument
    36  type Method struct {
    37  	Name    string
    38  	PkgPath string
    39  	Pointer bool
    40  	Type    reflect.Type
    41  	Func    func([]reflect.Value) []reflect.Value
    42  }
    43  
    44  func extraFieldMethod(ifield int, typ reflect.Type, skip map[string]bool) (methods []Method) {
    45  	isPtr := typ.Kind() == reflect.Ptr
    46  	for i := 0; i < typ.NumMethod(); i++ {
    47  		m := MethodByIndex(typ, i)
    48  		if skip[m.Name] {
    49  			continue
    50  		}
    51  		in, out := parserFuncIO(m.Type)
    52  		mtyp := reflect.FuncOf(in[1:], out, m.Type.IsVariadic())
    53  		var fn func(args []reflect.Value) []reflect.Value
    54  		if isPtr {
    55  			fn = func(args []reflect.Value) []reflect.Value {
    56  				args[0] = args[0].Elem().Field(ifield).Addr()
    57  				return m.Func.Call(args)
    58  			}
    59  		} else {
    60  			fn = func(args []reflect.Value) []reflect.Value {
    61  				args[0] = args[0].Field(ifield)
    62  				if mtyp.IsVariadic() {
    63  					return m.Func.CallSlice(args)
    64  				}
    65  				return m.Func.Call(args)
    66  			}
    67  		}
    68  		methods = append(methods, Method{
    69  			Name:    m.Name,
    70  			Pointer: in[0].Kind() == reflect.Ptr,
    71  			Type:    mtyp,
    72  			Func:    fn,
    73  		})
    74  	}
    75  	return
    76  }
    77  
    78  func parserFuncIO(typ reflect.Type) (in, out []reflect.Type) {
    79  	for i := 0; i < typ.NumIn(); i++ {
    80  		in = append(in, typ.In(i))
    81  	}
    82  	for i := 0; i < typ.NumOut(); i++ {
    83  		out = append(out, typ.Out(i))
    84  	}
    85  	return
    86  }
    87  
    88  func extraPtrFieldMethod(ifield int, typ reflect.Type) (methods []Method) {
    89  	for i := 0; i < typ.NumMethod(); i++ {
    90  		m := typ.Method(i)
    91  		in, out := parserFuncIO(m.Type)
    92  		mtyp := reflect.FuncOf(in[1:], out, m.Type.IsVariadic())
    93  		imethod := i
    94  		methods = append(methods, Method{
    95  			Name: m.Name,
    96  			Type: mtyp,
    97  			Func: func(args []reflect.Value) []reflect.Value {
    98  				var recv = args[0]
    99  				if mtyp.IsVariadic() {
   100  					return recv.Field(ifield).Method(imethod).CallSlice(args[1:])
   101  				}
   102  				return recv.Field(ifield).Method(imethod).Call(args[1:])
   103  			},
   104  		})
   105  	}
   106  	return
   107  }
   108  
   109  func extraInterfaceFieldMethod(ifield int, typ reflect.Type) (methods []Method) {
   110  	for i := 0; i < typ.NumMethod(); i++ {
   111  		m := typ.Method(i)
   112  		in, out := parserFuncIO(m.Type)
   113  		mtyp := reflect.FuncOf(in, out, m.Type.IsVariadic())
   114  		imethod := i
   115  		methods = append(methods, Method{
   116  			Name: m.Name,
   117  			Type: mtyp,
   118  			Func: func(args []reflect.Value) []reflect.Value {
   119  				var recv = args[0]
   120  				return recv.Field(ifield).Method(imethod).Call(args[1:])
   121  			},
   122  		})
   123  	}
   124  	return
   125  }
   126  
   127  func extractEmbedMethod(styp reflect.Type) []Method {
   128  	var methods []Method
   129  	for i := 0; i < styp.NumField(); i++ {
   130  		sf := styp.Field(i)
   131  		if !sf.Anonymous {
   132  			continue
   133  		}
   134  		switch sf.Type.Kind() {
   135  		case reflect.Interface:
   136  			ms := extraInterfaceFieldMethod(i, sf.Type)
   137  			methods = append(methods, ms...)
   138  		case reflect.Ptr:
   139  			ms := extraPtrFieldMethod(i, sf.Type)
   140  			methods = append(methods, ms...)
   141  		default:
   142  			skip := make(map[string]bool)
   143  			ms := extraFieldMethod(i, sf.Type, skip)
   144  			for _, m := range ms {
   145  				skip[m.Name] = true
   146  			}
   147  			pms := extraFieldMethod(i, reflect.PtrTo(sf.Type), skip)
   148  			methods = append(methods, ms...)
   149  			methods = append(methods, pms...)
   150  		}
   151  	}
   152  	// ambiguous selector check
   153  	chk := make(map[string]int)
   154  	for _, m := range methods {
   155  		chk[m.Name]++
   156  	}
   157  	var ms []Method
   158  	for _, m := range methods {
   159  		if chk[m.Name] == 1 {
   160  			ms = append(ms, m)
   161  		}
   162  	}
   163  	return ms
   164  }
   165  
   166  func UpdateField(typ reflect.Type, rmap map[reflect.Type]reflect.Type) bool {
   167  	if rmap == nil || typ.Kind() != reflect.Struct {
   168  		return false
   169  	}
   170  	rt := totype(typ)
   171  	st := toStructType(rt)
   172  	for i := 0; i < len(st.fields); i++ {
   173  		t := replaceType(toType(st.fields[i].typ), rmap)
   174  		st.fields[i].typ = totype(t)
   175  	}
   176  	return true
   177  }
   178  
   179  // func UpdateMethod(typ reflect.Type, methods []Method, rmap map[reflect.Type]reflect.Type) bool {
   180  // 	chk := make(map[string]int)
   181  // 	for _, m := range methods {
   182  // 		chk[m.Name]++
   183  // 		if chk[m.Name] > 1 {
   184  // 			panic(fmt.Sprintf("method redeclared: %v", m.Name))
   185  // 		}
   186  // 	}
   187  // 	if typ.Kind() == reflect.Struct {
   188  // 		ms := extractEmbedMethod(typ)
   189  // 		for _, m := range ms {
   190  // 			if chk[m.Name] == 1 {
   191  // 				continue
   192  // 			}
   193  // 			methods = append(methods, m)
   194  // 		}
   195  // 	}
   196  // 	return updateMethod(typ, methods, rmap)
   197  // }
   198  
   199  func Reset() {
   200  	resetTypeList()
   201  	ntypeMap = make(map[reflect.Type]*Named)
   202  	embedLookupCache = make(map[reflect.Type]reflect.Type)
   203  	structLookupCache = make(map[string][]reflect.Type)
   204  	interfceLookupCache = make(map[string]reflect.Type)
   205  }
   206  
   207  var (
   208  	embedLookupCache = make(map[reflect.Type]reflect.Type)
   209  )
   210  
   211  // StructToMethodSet extract method form struct embed fields
   212  func StructToMethodSet(styp reflect.Type) reflect.Type {
   213  	if styp.Kind() != reflect.Struct {
   214  		return styp
   215  	}
   216  	ms := extractEmbedMethod(styp)
   217  	if len(ms) == 0 {
   218  		return styp
   219  	}
   220  	if typ, ok := embedLookupCache[styp]; ok {
   221  		return typ
   222  	}
   223  	var methods []Method
   224  	var mcout, pcount int
   225  	for _, m := range ms {
   226  		if !m.Pointer {
   227  			mcout++
   228  		}
   229  		pcount++
   230  		methods = append(methods, m)
   231  	}
   232  	typ := newMethodSet(styp, mcout, pcount)
   233  	err := setMethodSet(typ, methods)
   234  	if err != nil {
   235  		log.Panicln("error loadMethods", err)
   236  	}
   237  	embedLookupCache[styp] = typ
   238  	return typ
   239  }
   240  
   241  // func MethodOf(styp reflect.Type, methods []Method) reflect.Type {
   242  // 	chk := make(map[string]int)
   243  // 	for _, m := range methods {
   244  // 		chk[m.Name]++
   245  // 		if chk[m.Name] > 1 {
   246  // 			panic(fmt.Sprintf("method redeclared: %v", m.Name))
   247  // 		}
   248  // 	}
   249  // 	if styp.Kind() == reflect.Struct {
   250  // 		ms := extractEmbedMethod(styp)
   251  // 		for _, m := range ms {
   252  // 			if chk[m.Name] == 1 {
   253  // 				continue
   254  // 			}
   255  // 			methods = append(methods, m)
   256  // 		}
   257  // 	}
   258  // 	typ := methodSetOf(styp, len(methods), len(methods))
   259  // 	err := loadMethods(typ, methods)
   260  // 	if err != nil {
   261  // 		log.Panicln("error loadMethods", err)
   262  // 	}
   263  // 	return typ
   264  // }
   265  
   266  // NewMethodSet is pre define method set of styp
   267  // maxmfunc - set methodset of T max member func
   268  // maxpfunc - set methodset of *T + T max member func
   269  func NewMethodSet(styp reflect.Type, maxmfunc, maxpfunc int) reflect.Type {
   270  	if maxpfunc == 0 {
   271  		return StructToMethodSet(styp)
   272  	}
   273  	chk := make(map[string]int)
   274  	if styp.Kind() == reflect.Struct {
   275  		ms := extractEmbedMethod(styp)
   276  		for _, m := range ms {
   277  			if chk[m.Name] == 1 {
   278  				continue
   279  			}
   280  			maxpfunc++
   281  			if !m.Pointer {
   282  				maxmfunc++
   283  			}
   284  		}
   285  	}
   286  	typ := newMethodSet(styp, maxmfunc, maxpfunc)
   287  	return typ
   288  }
   289  
   290  func SetMethodSet(styp reflect.Type, methods []Method, extractStructEmbed bool) error {
   291  	chk := make(map[string]int)
   292  	for _, m := range methods {
   293  		chk[m.Name]++
   294  		if chk[m.Name] > 1 {
   295  			return fmt.Errorf("method redeclared: %v", m.Name)
   296  		}
   297  	}
   298  	if extractStructEmbed && styp.Kind() == reflect.Struct {
   299  		ms := extractEmbedMethod(styp)
   300  		for _, m := range ms {
   301  			if chk[m.Name] == 1 {
   302  				continue
   303  			}
   304  			methods = append(methods, m)
   305  		}
   306  	}
   307  	return setMethodSet(styp, methods)
   308  }
   309  
   310  func MakeEmptyInterface(pkgpath string, name string) reflect.Type {
   311  	return NamedTypeOf(pkgpath, name, tyEmptyInterface)
   312  }
   313  
   314  func NamedInterfaceOf(pkgpath string, name string, embedded []reflect.Type, methods []reflect.Method) reflect.Type {
   315  	typ := NewInterfaceType(pkgpath, name)
   316  	SetInterfaceType(typ, embedded, methods)
   317  	return typ
   318  }
   319  
   320  var (
   321  	interfceLookupCache = make(map[string]reflect.Type)
   322  )
   323  
   324  func NewInterfaceType(pkgpath string, name string) reflect.Type {
   325  	rt, _ := newType("", "", tyEmptyInterface, 0, 0)
   326  	setTypeName(rt, pkgpath, name)
   327  	return toType(rt)
   328  }
   329  
   330  func SetInterfaceType(typ reflect.Type, embedded []reflect.Type, methods []reflect.Method) error {
   331  	for _, e := range embedded {
   332  		if e.Kind() != reflect.Interface {
   333  			return fmt.Errorf("interface contains embedded non-interface %v", e)
   334  		}
   335  		for i := 0; i < e.NumMethod(); i++ {
   336  			m := e.Method(i)
   337  			methods = append(methods, reflect.Method{
   338  				Name: m.Name,
   339  				Type: m.Type,
   340  			})
   341  		}
   342  	}
   343  	sort.Slice(methods, func(i, j int) bool {
   344  		n := strings.Compare(methods[i].Name, methods[j].Name)
   345  		if n == 0 && methods[i].Type != methods[j].Type {
   346  			panic(fmt.Errorf("duplicate method %v", methods[j].Name))
   347  		}
   348  		return n < 0
   349  	})
   350  	rt := totype(typ)
   351  	st := (*interfaceType)(toKindType(rt))
   352  	st.methods = nil
   353  	var info []string
   354  	var lastname string
   355  	var unnamed bool
   356  	if typ.Name() == "" {
   357  		unnamed = true
   358  	}
   359  	for _, m := range methods {
   360  		if m.Name == lastname {
   361  			continue
   362  		}
   363  		lastname = m.Name
   364  		isexport := methodIsExported(m.Name)
   365  		var mname nameOff
   366  		if unnamed {
   367  			nm := newNameEx(m.Name, "", isexport, !isexport)
   368  			mname = resolveReflectName(nm)
   369  			if !isexport {
   370  				nm.setPkgPath(resolveReflectName(newName(m.PkgPath, "", false)))
   371  			}
   372  		} else {
   373  			mname = resolveReflectName(newName(m.Name, "", isexport))
   374  		}
   375  		st.methods = append(st.methods, imethod{
   376  			name: mname,
   377  			typ:  resolveReflectType(totype(m.Type)),
   378  		})
   379  		info = append(info, methodStr(m.Name, m.Type))
   380  	}
   381  	return nil
   382  }
   383  
   384  func InterfaceOf(embedded []reflect.Type, methods []reflect.Method) reflect.Type {
   385  	for _, e := range embedded {
   386  		if e.Kind() != reflect.Interface {
   387  			panic(fmt.Errorf("interface contains embedded non-interface %v", e))
   388  		}
   389  		for i := 0; i < e.NumMethod(); i++ {
   390  			m := e.Method(i)
   391  			methods = append(methods, reflect.Method{
   392  				Name: m.Name,
   393  				Type: m.Type,
   394  			})
   395  		}
   396  	}
   397  	sort.Slice(methods, func(i, j int) bool {
   398  		n := strings.Compare(methods[i].Name, methods[j].Name)
   399  		if n == 0 && methods[i].Type != methods[j].Type {
   400  			panic(fmt.Sprintf("duplicate method %v", methods[j].Name))
   401  		}
   402  		return n < 0
   403  	})
   404  	rt, _ := newType("", "", tyEmptyInterface, 0, 0)
   405  	st := (*interfaceType)(toKindType(rt))
   406  	st.methods = nil
   407  	var info []string
   408  	var lastname string
   409  	for _, m := range methods {
   410  		if m.Name == lastname {
   411  			continue
   412  		}
   413  		lastname = m.Name
   414  		isexport := methodIsExported(m.Name)
   415  		var mname nameOff
   416  		nm := newNameEx(m.Name, "", isexport, !isexport)
   417  		mname = resolveReflectName(nm)
   418  		if !isexport {
   419  			nm.setPkgPath(resolveReflectName(newName(m.PkgPath, "", false)))
   420  		}
   421  		st.methods = append(st.methods, imethod{
   422  			name: mname,
   423  			typ:  resolveReflectType(totype(m.Type)),
   424  		})
   425  		info = append(info, methodStr(m.Name, m.Type))
   426  	}
   427  	var str string
   428  	if len(info) > 0 {
   429  		str = fmt.Sprintf("*interface { %v }", strings.Join(info, "; "))
   430  	} else {
   431  		str = "*interface {}"
   432  	}
   433  	if t, ok := interfceLookupCache[str]; ok {
   434  		return t
   435  	}
   436  	rt.str = resolveReflectName(newName(str, "", false))
   437  	typ := toType(rt)
   438  	interfceLookupCache[str] = typ
   439  	return typ
   440  }
   441  
   442  func methodIsExported(name string) bool {
   443  	if EnableExportAllMethod {
   444  		return true
   445  	}
   446  	return token.IsExported(name)
   447  }
   448  
   449  func methodStr(name string, typ reflect.Type) string {
   450  	return strings.Replace(typ.String(), "func", name, 1)
   451  }
   452  
   453  func toElem(typ reflect.Type) reflect.Type {
   454  	if typ.Kind() == reflect.Ptr {
   455  		return typ.Elem()
   456  	}
   457  	return typ
   458  }
   459  
   460  func toElemValue(v reflect.Value) reflect.Value {
   461  	if v.Kind() == reflect.Ptr {
   462  		return v.Elem()
   463  	}
   464  	return v
   465  }
   466  
   467  func replaceType(typ reflect.Type, rmap map[reflect.Type]reflect.Type) reflect.Type {
   468  	var fnx func(t reflect.Type) (reflect.Type, bool)
   469  	fnx = func(t reflect.Type) (reflect.Type, bool) {
   470  		for k, v := range rmap {
   471  			if k.String() == t.String() {
   472  				return v, true
   473  			}
   474  		}
   475  		switch t.Kind() {
   476  		case reflect.Ptr:
   477  			if e, ok := fnx(t.Elem()); ok {
   478  				return reflect.PtrTo(e), true
   479  			}
   480  		case reflect.Slice:
   481  			if e, ok := fnx(t.Elem()); ok {
   482  				return reflect.SliceOf(e), true
   483  			}
   484  		case reflect.Array:
   485  			if e, ok := fnx(t.Elem()); ok {
   486  				return reflect.ArrayOf(t.Len(), e), true
   487  			}
   488  		case reflect.Map:
   489  			k, ok1 := fnx(t.Key())
   490  			v, ok2 := fnx(t.Elem())
   491  			if ok1 || ok2 {
   492  				return reflect.MapOf(k, v), true
   493  			}
   494  		}
   495  		return t, false
   496  	}
   497  	if r, ok := fnx(typ); ok {
   498  		return r
   499  	}
   500  	return typ
   501  }
   502  
   503  func parserMethodType(mtyp reflect.Type, rmap map[reflect.Type]reflect.Type) (in, out []reflect.Type, ntyp, inTyp, outTyp reflect.Type) {
   504  	var inFields []reflect.StructField
   505  	var outFields []reflect.StructField
   506  	for i := 0; i < mtyp.NumIn(); i++ {
   507  		t := mtyp.In(i)
   508  		if rmap != nil {
   509  			t = replaceType(t, rmap)
   510  		}
   511  		in = append(in, t)
   512  		inFields = append(inFields, reflect.StructField{
   513  			Name: fmt.Sprintf("Arg%v", i),
   514  			Type: t,
   515  		})
   516  	}
   517  	for i := 0; i < mtyp.NumOut(); i++ {
   518  		t := mtyp.Out(i)
   519  		if rmap != nil {
   520  			t = replaceType(t, rmap)
   521  		}
   522  		out = append(out, t)
   523  		outFields = append(outFields, reflect.StructField{
   524  			Name: fmt.Sprintf("Out%v", i),
   525  			Type: t,
   526  		})
   527  	}
   528  	if rmap == nil {
   529  		ntyp = mtyp
   530  	} else {
   531  		ntyp = reflect.FuncOf(in, out, mtyp.IsVariadic())
   532  	}
   533  	inTyp = reflect.StructOf(inFields)
   534  	outTyp = reflect.StructOf(outFields)
   535  	return
   536  }