github.com/goplus/reflectx@v1.2.2/method.go (about)

     1  package reflectx
     2  
     3  import (
     4  	"fmt"
     5  	"go/token"
     6  	"log"
     7  	"reflect"
     8  	"sort"
     9  	"strings"
    10  	"unsafe"
    11  )
    12  
    13  // MakeMethod make reflect.Method for MethodOf
    14  // - name: method name
    15  // - pointer: flag receiver struct or pointer
    16  // - typ: method func type without receiver
    17  // - fn: func with receiver as first argument
    18  func MakeMethod(name string, pkgpath string, pointer bool, typ reflect.Type, fn func(args []reflect.Value) (result []reflect.Value)) Method {
    19  	return Method{
    20  		Name:    name,
    21  		PkgPath: pkgpath,
    22  		Pointer: pointer,
    23  		Type:    typ,
    24  		Func:    fn,
    25  	}
    26  }
    27  
    28  // Method struct for MethodOf
    29  // - name: method name
    30  // - pointer: flag receiver struct or pointer
    31  // - typ: method func type without receiver
    32  // - fn: func with receiver as first argument
    33  type Method struct {
    34  	Name    string
    35  	PkgPath string
    36  	Pointer bool
    37  	Type    reflect.Type
    38  	Func    func([]reflect.Value) []reflect.Value
    39  }
    40  
    41  func extraFieldMethod(ifield int, typ reflect.Type, skip map[string]bool) (methods []Method) {
    42  	isPtr := typ.Kind() == reflect.Ptr
    43  	for i := 0; i < typ.NumMethod(); i++ {
    44  		m := MethodByIndex(typ, i)
    45  		if skip[m.Name] {
    46  			continue
    47  		}
    48  		in, out := parserFuncIO(m.Type)
    49  		mtyp := reflect.FuncOf(in[1:], out, m.Type.IsVariadic())
    50  		var fn func(args []reflect.Value) []reflect.Value
    51  		if isPtr {
    52  			fn = func(args []reflect.Value) []reflect.Value {
    53  				args[0] = args[0].Elem().Field(ifield).Addr()
    54  				return m.Func.Call(args)
    55  			}
    56  		} else {
    57  			fn = func(args []reflect.Value) []reflect.Value {
    58  				args[0] = args[0].Field(ifield)
    59  				if mtyp.IsVariadic() {
    60  					return m.Func.CallSlice(args)
    61  				}
    62  				return m.Func.Call(args)
    63  			}
    64  		}
    65  		methods = append(methods, Method{
    66  			Name:    m.Name,
    67  			Pointer: in[0].Kind() == reflect.Ptr,
    68  			Type:    mtyp,
    69  			Func:    fn,
    70  		})
    71  	}
    72  	return
    73  }
    74  
    75  func parserFuncIO(typ reflect.Type) (in, out []reflect.Type) {
    76  	numIn := typ.NumIn()
    77  	numOut := typ.NumOut()
    78  	for i := 0; i < numIn; i++ {
    79  		in = append(in, typ.In(i))
    80  	}
    81  	for i := 0; i < numOut; i++ {
    82  		out = append(out, typ.Out(i))
    83  	}
    84  	return
    85  }
    86  
    87  func extraPtrFieldMethod(ifield int, typ reflect.Type) (methods []Method) {
    88  	for i := 0; i < typ.NumMethod(); i++ {
    89  		m := typ.Method(i)
    90  		in, out := parserFuncIO(m.Type)
    91  		mtyp := reflect.FuncOf(in[1:], out, m.Type.IsVariadic())
    92  		imethod := i
    93  		methods = append(methods, Method{
    94  			Name: m.Name,
    95  			Type: mtyp,
    96  			Func: func(args []reflect.Value) []reflect.Value {
    97  				var recv = args[0]
    98  				if mtyp.IsVariadic() {
    99  					return recv.Field(ifield).Method(imethod).CallSlice(args[1:])
   100  				}
   101  				return recv.Field(ifield).Method(imethod).Call(args[1:])
   102  			},
   103  		})
   104  	}
   105  	return
   106  }
   107  
   108  func extraInterfaceFieldMethod(ifield int, typ reflect.Type) (methods []Method) {
   109  	for i := 0; i < typ.NumMethod(); i++ {
   110  		m := typ.Method(i)
   111  		in, out := parserFuncIO(m.Type)
   112  		mtyp := reflect.FuncOf(in, out, m.Type.IsVariadic())
   113  		imethod := i
   114  		methods = append(methods, Method{
   115  			Name: m.Name,
   116  			Type: mtyp,
   117  			Func: func(args []reflect.Value) []reflect.Value {
   118  				var recv = args[0]
   119  				return recv.Field(ifield).Method(imethod).Call(args[1:])
   120  			},
   121  		})
   122  	}
   123  	return
   124  }
   125  
   126  func extractEmbedMethod(styp reflect.Type) []Method {
   127  	var methods []Method
   128  	for i := 0; i < styp.NumField(); i++ {
   129  		sf := styp.Field(i)
   130  		if !sf.Anonymous {
   131  			continue
   132  		}
   133  		switch sf.Type.Kind() {
   134  		case reflect.Interface:
   135  			ms := extraInterfaceFieldMethod(i, sf.Type)
   136  			methods = append(methods, ms...)
   137  		case reflect.Ptr:
   138  			ms := extraPtrFieldMethod(i, sf.Type)
   139  			methods = append(methods, ms...)
   140  		default:
   141  			skip := make(map[string]bool)
   142  			ms := extraFieldMethod(i, sf.Type, skip)
   143  			for _, m := range ms {
   144  				skip[m.Name] = true
   145  			}
   146  			pms := extraFieldMethod(i, reflect.PtrTo(sf.Type), skip)
   147  			methods = append(methods, ms...)
   148  			methods = append(methods, pms...)
   149  		}
   150  	}
   151  	// ambiguous selector check
   152  	chk := make(map[string]int)
   153  	for _, m := range methods {
   154  		chk[m.Name]++
   155  	}
   156  	var ms []Method
   157  	for _, m := range methods {
   158  		if chk[m.Name] == 1 {
   159  			ms = append(ms, m)
   160  		}
   161  	}
   162  	return ms
   163  }
   164  
   165  func UpdateField(typ reflect.Type, rmap map[reflect.Type]reflect.Type) bool {
   166  	if rmap == nil || typ.Kind() != reflect.Struct {
   167  		return false
   168  	}
   169  	rt := totype(typ)
   170  	st := toStructType(rt)
   171  	for i := 0; i < len(st.fields); i++ {
   172  		t := replaceType(toType(st.fields[i].typ), rmap)
   173  		st.fields[i].typ = totype(t)
   174  	}
   175  	return true
   176  }
   177  
// Reset calls Reset on the package-level Default context.
func Reset() {
	Default.Reset()
}

// ResetAll delegates to the package-internal resetAll.
// NOTE(review): presumably resets more state than Reset does; confirm.
func ResetAll() {
	resetAll()
}

// StructToMethodSet is Context.StructToMethodSet applied to the
// package-level Default context.
func StructToMethodSet(styp reflect.Type) reflect.Type {
	return Default.StructToMethodSet(styp)
}
   189  
   190  // StructToMethodSet extract method form struct embed fields
   191  func (ctx *Context) StructToMethodSet(styp reflect.Type) reflect.Type {
   192  	if styp.Kind() != reflect.Struct {
   193  		return styp
   194  	}
   195  	ms := extractEmbedMethod(styp)
   196  	if len(ms) == 0 {
   197  		return styp
   198  	}
   199  	if typ, ok := ctx.embedLookupCache[styp]; ok {
   200  		return typ
   201  	}
   202  	var methods []Method
   203  	var mcout, pcount int
   204  	for _, m := range ms {
   205  		if !m.Pointer {
   206  			mcout++
   207  		}
   208  		pcount++
   209  		methods = append(methods, m)
   210  	}
   211  	typ := newMethodSet(styp, mcout, pcount)
   212  	err := ctx.setMethodSet(typ, methods)
   213  	if err != nil {
   214  		log.Panicln("error loadMethods", err)
   215  	}
   216  	ctx.embedLookupCache[styp] = typ
   217  	return typ
   218  }
   219  
// NewMethodSet pre-defines the method set of styp on the package-level
// Default context; see Context.NewMethodSet.
//   - maxmfunc: maximum number of methods in the method set of T
//   - maxpfunc: maximum number of methods in the method set of *T (which includes T's)
func NewMethodSet(styp reflect.Type, maxmfunc, maxpfunc int) reflect.Type {
	return Default.NewMethodSet(styp, maxmfunc, maxpfunc)
}
   226  
   227  func (ctx *Context) NewMethodSet(styp reflect.Type, maxmfunc, maxpfunc int) reflect.Type {
   228  	if maxpfunc == 0 {
   229  		return ctx.StructToMethodSet(styp)
   230  	}
   231  	chk := make(map[string]int)
   232  	if styp.Kind() == reflect.Struct {
   233  		ms := extractEmbedMethod(styp)
   234  		for _, m := range ms {
   235  			if chk[m.Name] == 1 {
   236  				continue
   237  			}
   238  			maxpfunc++
   239  			if !m.Pointer {
   240  				maxmfunc++
   241  			}
   242  		}
   243  	}
   244  	typ := newMethodSet(styp, maxmfunc, maxpfunc)
   245  	return typ
   246  }
   247  
// SetMethodSet installs methods as the method set of styp on the
// package-level Default context; see Context.SetMethodSet.
func SetMethodSet(styp reflect.Type, methods []Method, extractStructEmbed bool) error {
	return Default.SetMethodSet(styp, methods, extractStructEmbed)
}
   251  
   252  func (ctx *Context) SetMethodSet(styp reflect.Type, methods []Method, extractStructEmbed bool) error {
   253  	chk := make(map[string]Method)
   254  	for _, m := range methods {
   255  		if v, ok := chk[m.Name]; ok && v.PkgPath == m.PkgPath {
   256  			return fmt.Errorf("method redeclared: %v", m.Name)
   257  		}
   258  		chk[m.Name] = m
   259  	}
   260  	if extractStructEmbed && styp.Kind() == reflect.Struct {
   261  		ms := extractEmbedMethod(styp)
   262  		for _, m := range ms {
   263  			if _, ok := chk[m.Name]; ok {
   264  				continue
   265  			}
   266  			methods = append(methods, m)
   267  		}
   268  	}
   269  	return ctx.setMethodSet(styp, methods)
   270  }
   271  
// MakeEmptyInterface returns the named type pkgpath.name built by NamedTypeOf
// over tyEmptyInterface (presumably the empty interface type; confirm).
func MakeEmptyInterface(pkgpath string, name string) reflect.Type {
	return NamedTypeOf(pkgpath, name, tyEmptyInterface)
}
   275  
   276  func NamedInterfaceOf(pkgpath string, name string, embedded []reflect.Type, methods []reflect.Method) reflect.Type {
   277  	typ := NewInterfaceType(pkgpath, name)
   278  	SetInterfaceType(typ, embedded, methods)
   279  	return typ
   280  }
   281  
// NewInterfaceType creates a new interface type named pkgpath.name, built by
// newType from tyEmptyInterface and therefore starting without methods.
// Populate its method set afterwards with SetInterfaceType.
func NewInterfaceType(pkgpath string, name string) reflect.Type {
	// NOTE(review): the second result of newType is discarded; confirm it
	// cannot signal a failure here.
	rt, _ := newType("", "", tyEmptyInterface, 0, 0)
	setTypeName(rt, pkgpath, name)
	return toType(rt)
}
   287  
   288  func SetInterfaceType(typ reflect.Type, embedded []reflect.Type, methods []reflect.Method) error {
   289  	for _, e := range embedded {
   290  		if e.Kind() != reflect.Interface {
   291  			return fmt.Errorf("interface contains embedded non-interface %v", e)
   292  		}
   293  		for i := 0; i < e.NumMethod(); i++ {
   294  			m := e.Method(i)
   295  			methods = append(methods, reflect.Method{
   296  				Name: m.Name,
   297  				Type: m.Type,
   298  			})
   299  		}
   300  	}
   301  	sort.Slice(methods, func(i, j int) bool {
   302  		n := strings.Compare(methods[i].Name, methods[j].Name)
   303  		if n == 0 && methods[i].Type != methods[j].Type {
   304  			panic(fmt.Errorf("duplicate method %v", methods[j].Name))
   305  		}
   306  		return n < 0
   307  	})
   308  	rt := totype(typ)
   309  	st := (*interfaceType)(toKindType(rt))
   310  	st.methods = nil
   311  	var info []string
   312  	var lastname string
   313  	var unnamed bool
   314  	if typ.Name() == "" {
   315  		unnamed = true
   316  	}
   317  	for _, m := range methods {
   318  		if m.Name == lastname {
   319  			continue
   320  		}
   321  		lastname = m.Name
   322  		isexport := methodIsExported(m.Name)
   323  		var mname nameOff
   324  		if unnamed {
   325  			nm := newNameEx(m.Name, "", isexport, !isexport)
   326  			mname = resolveReflectName(nm)
   327  			if !isexport {
   328  				nm.setPkgPath(m.PkgPath)
   329  			}
   330  		} else {
   331  			mname = resolveReflectName(newName(m.Name, "", isexport))
   332  		}
   333  		st.methods = append(st.methods, imethod{
   334  			name: mname,
   335  			typ:  resolveReflectType(totype(m.Type)),
   336  		})
   337  		info = append(info, methodStr(m.Name, m.Type))
   338  	}
   339  	return nil
   340  }
   341  
// interequal is linked to runtime.interequal, the Go runtime's equality
// function for values of non-empty interface types. InterfaceOf installs it
// as the equal function of interface types that have at least one method.
//
//go:linkname interequal runtime.interequal
func interequal(p, q unsafe.Pointer) bool
   344  
// InterfaceOf returns an unnamed interface type built on the package-level
// Default context; see Context.InterfaceOf.
func InterfaceOf(embedded []reflect.Type, methods []reflect.Method) reflect.Type {
	return Default.InterfaceOf(embedded, methods)
}
   348  
   349  func (ctx *Context) InterfaceOf(embedded []reflect.Type, methods []reflect.Method) reflect.Type {
   350  	for _, e := range embedded {
   351  		if e.Kind() != reflect.Interface {
   352  			panic(fmt.Errorf("interface contains embedded non-interface %v", e))
   353  		}
   354  		for i := 0; i < e.NumMethod(); i++ {
   355  			m := e.Method(i)
   356  			methods = append(methods, reflect.Method{
   357  				Name: m.Name,
   358  				Type: m.Type,
   359  			})
   360  		}
   361  	}
   362  	sort.Slice(methods, func(i, j int) bool {
   363  		n := strings.Compare(methods[i].Name, methods[j].Name)
   364  		if n == 0 && methods[i].Type != methods[j].Type {
   365  			panic(fmt.Sprintf("duplicate method %v", methods[j].Name))
   366  		}
   367  		return n < 0
   368  	})
   369  	rt, _ := newType("", "", tyEmptyInterface, 0, 0)
   370  	st := (*interfaceType)(toKindType(rt))
   371  	st.methods = nil
   372  	var info []string
   373  	var lastname string
   374  	for _, m := range methods {
   375  		if m.Name == lastname {
   376  			continue
   377  		}
   378  		lastname = m.Name
   379  		isexport := methodIsExported(m.Name)
   380  		var mname nameOff
   381  		nm := newNameEx(m.Name, "", isexport, !isexport)
   382  		mname = resolveReflectName(nm)
   383  		if !isexport {
   384  			nm.setPkgPath(m.PkgPath)
   385  		}
   386  		st.methods = append(st.methods, imethod{
   387  			name: mname,
   388  			typ:  resolveReflectType(totype(m.Type)),
   389  		})
   390  		info = append(info, methodStr(m.Name, m.Type))
   391  	}
   392  	if len(st.methods) > 0 {
   393  		rt.equal = interequal
   394  	}
   395  	var str string
   396  	if len(info) > 0 {
   397  		str = fmt.Sprintf("*interface { %v }", strings.Join(info, "; "))
   398  	} else {
   399  		str = "*interface {}"
   400  	}
   401  	if t, ok := ctx.interfceLookupCache[str]; ok {
   402  		return t
   403  	}
   404  	rt.str = resolveReflectName(newName(str, "", false))
   405  	typ := toType(rt)
   406  	ctx.interfceLookupCache[str] = typ
   407  	return typ
   408  }
   409  
   410  func methodIsExported(name string) bool {
   411  	return token.IsExported(name)
   412  }
   413  
   414  func methodStr(name string, typ reflect.Type) string {
   415  	return strings.Replace(typ.String(), "func", name, 1)
   416  }
   417  
   418  func toElem(typ reflect.Type) reflect.Type {
   419  	if typ.Kind() == reflect.Ptr {
   420  		return typ.Elem()
   421  	}
   422  	return typ
   423  }
   424  
   425  func toElemValue(v reflect.Value) reflect.Value {
   426  	if v.Kind() == reflect.Ptr {
   427  		return v.Elem()
   428  	}
   429  	return v
   430  }
   431  
   432  func replaceType(typ reflect.Type, rmap map[reflect.Type]reflect.Type) reflect.Type {
   433  	var fnx func(t reflect.Type) (reflect.Type, bool)
   434  	fnx = func(t reflect.Type) (reflect.Type, bool) {
   435  		for k, v := range rmap {
   436  			if k.String() == t.String() {
   437  				return v, true
   438  			}
   439  		}
   440  		switch t.Kind() {
   441  		case reflect.Ptr:
   442  			if e, ok := fnx(t.Elem()); ok {
   443  				return reflect.PtrTo(e), true
   444  			}
   445  		case reflect.Slice:
   446  			if e, ok := fnx(t.Elem()); ok {
   447  				return reflect.SliceOf(e), true
   448  			}
   449  		case reflect.Array:
   450  			if e, ok := fnx(t.Elem()); ok {
   451  				return reflect.ArrayOf(t.Len(), e), true
   452  			}
   453  		case reflect.Map:
   454  			k, ok1 := fnx(t.Key())
   455  			v, ok2 := fnx(t.Elem())
   456  			if ok1 || ok2 {
   457  				return reflect.MapOf(k, v), true
   458  			}
   459  		}
   460  		return t, false
   461  	}
   462  	if r, ok := fnx(typ); ok {
   463  		return r
   464  	}
   465  	return typ
   466  }
   467  
   468  func parserMethodType(mtyp reflect.Type, rmap map[reflect.Type]reflect.Type) (in, out []reflect.Type, ntyp, inTyp, outTyp reflect.Type) {
   469  	var inFields []reflect.StructField
   470  	var outFields []reflect.StructField
   471  	numIn := mtyp.NumIn()
   472  	numOut := mtyp.NumOut()
   473  	for i := 0; i < numIn; i++ {
   474  		t := mtyp.In(i)
   475  		if rmap != nil {
   476  			t = replaceType(t, rmap)
   477  		}
   478  		in = append(in, t)
   479  		inFields = append(inFields, reflect.StructField{
   480  			Name: fmt.Sprintf("Arg%v", i),
   481  			Type: t,
   482  		})
   483  	}
   484  	for i := 0; i < numOut; i++ {
   485  		t := mtyp.Out(i)
   486  		if rmap != nil {
   487  			t = replaceType(t, rmap)
   488  		}
   489  		out = append(out, t)
   490  		outFields = append(outFields, reflect.StructField{
   491  			Name: fmt.Sprintf("Out%v", i),
   492  			Type: t,
   493  		})
   494  	}
   495  	if rmap == nil {
   496  		ntyp = mtyp
   497  	} else {
   498  		ntyp = reflect.FuncOf(in, out, mtyp.IsVariadic())
   499  	}
   500  	inTyp = reflect.StructOf(inFields)
   501  	outTyp = reflect.StructOf(outFields)
   502  	return
   503  }