github.com/goplus/gossa@v0.3.25/xtypes.go (about)

     1  package gossa
     2  
     3  import (
     4  	"fmt"
     5  	"go/token"
     6  	"go/types"
     7  	"log"
     8  	"reflect"
     9  	"strconv"
    10  	"unsafe"
    11  
    12  	"github.com/goplus/reflectx"
    13  	"golang.org/x/tools/go/ssa"
    14  	"golang.org/x/tools/go/types/typeutil"
    15  )
    16  
    17  var (
    18  	tyEmptyStruct = reflect.TypeOf((*struct{})(nil)).Elem()
    19  	tyEmptyPtr    = reflect.TypeOf((*struct{})(nil))
    20  	tyEmptyMap    = reflect.TypeOf((*map[struct{}]struct{})(nil)).Elem()
    21  	tyEmptySlice  = reflect.TypeOf((*[]struct{})(nil)).Elem()
    22  	tyEmptyArray  = reflect.TypeOf((*[0]struct{})(nil)).Elem()
    23  	tyEmptyChan   = reflect.TypeOf((*chan struct{})(nil)).Elem()
    24  	tyEmptyFunc   = reflect.TypeOf((*func())(nil)).Elem()
    25  )
    26  
    27  /*
    28  	Array
    29  	Chan
    30  	Func
    31  	Interface
    32  	Map
    33  	Ptr
    34  	Slice
    35  	String
    36  	Struct
    37  	UnsafePointer
    38  */
    39  
    40  func emptyType(kind reflect.Kind) reflect.Type {
    41  	switch kind {
    42  	case reflect.Array:
    43  		return tyEmptyArray
    44  	case reflect.Chan:
    45  		return tyEmptyChan
    46  	case reflect.Func:
    47  		return tyEmptyFunc
    48  	case reflect.Interface:
    49  		return tyEmptyInterface
    50  	case reflect.Map:
    51  		return tyEmptyMap
    52  	case reflect.Ptr:
    53  		return tyEmptyPtr
    54  	case reflect.Slice:
    55  		return tyEmptySlice
    56  	case reflect.Struct:
    57  		return tyEmptyStruct
    58  	default:
    59  		return xtypeTypes[kind]
    60  	}
    61  	panic(fmt.Errorf("emptyType: unreachable kind %v", kind))
    62  }
    63  
    64  func toMockType(typ types.Type) reflect.Type {
    65  	switch t := typ.(type) {
    66  	case *types.Basic:
    67  		kind := t.Kind()
    68  		if kind > types.Invalid && kind < types.UntypedNil {
    69  			return xtypeTypes[kind]
    70  		}
    71  		panic(fmt.Errorf("toMockType: invalid type %v", typ))
    72  	case *types.Pointer:
    73  		return tyEmptyPtr
    74  	case *types.Slice:
    75  		return tyEmptySlice
    76  	case *types.Array:
    77  		e := toMockType(t.Elem())
    78  		return reflect.ArrayOf(int(t.Len()), e)
    79  	case *types.Map:
    80  		return tyEmptyMap
    81  	case *types.Chan:
    82  		return tyEmptyChan
    83  	case *types.Struct:
    84  		n := t.NumFields()
    85  		fs := make([]reflect.StructField, n, n)
    86  		for i := 0; i < n; i++ {
    87  			ft := t.Field(i)
    88  			fs[i].Name = "F" + strconv.Itoa(i)
    89  			fs[i].Type = toMockType(ft.Type())
    90  			fs[i].Anonymous = ft.Embedded()
    91  		}
    92  		return reflect.StructOf(fs)
    93  	case *types.Named:
    94  		return toMockType(typ.Underlying())
    95  	case *types.Interface:
    96  		return tyEmptyInterface
    97  	case *types.Signature:
    98  		in := t.Params().Len()
    99  		out := t.Results().Len()
   100  		if in+out == 0 {
   101  			return tyEmptyFunc
   102  		}
   103  		ins := make([]reflect.Type, in, in)
   104  		outs := make([]reflect.Type, out, out)
   105  		for i := 0; i < in; i++ {
   106  			ins[i] = tyEmptyStruct
   107  		}
   108  		for i := 0; i < out; i++ {
   109  			outs[i] = tyEmptyStruct
   110  		}
   111  		return reflect.FuncOf(ins, outs, t.Variadic())
   112  	default:
   113  		panic(fmt.Errorf("toEmptyType: unreachable %v", typ))
   114  	}
   115  }
   116  
   117  var xtypeTypes = [...]reflect.Type{
   118  	types.Bool:          reflect.TypeOf(false),
   119  	types.Int:           reflect.TypeOf(0),
   120  	types.Int8:          reflect.TypeOf(int8(0)),
   121  	types.Int16:         reflect.TypeOf(int16(0)),
   122  	types.Int32:         reflect.TypeOf(int32(0)),
   123  	types.Int64:         reflect.TypeOf(int64(0)),
   124  	types.Uint:          reflect.TypeOf(uint(0)),
   125  	types.Uint8:         reflect.TypeOf(uint8(0)),
   126  	types.Uint16:        reflect.TypeOf(uint16(0)),
   127  	types.Uint32:        reflect.TypeOf(uint32(0)),
   128  	types.Uint64:        reflect.TypeOf(uint64(0)),
   129  	types.Uintptr:       reflect.TypeOf(uintptr(0)),
   130  	types.Float32:       reflect.TypeOf(float32(0)),
   131  	types.Float64:       reflect.TypeOf(float64(0)),
   132  	types.Complex64:     reflect.TypeOf(complex64(0)),
   133  	types.Complex128:    reflect.TypeOf(complex128(0)),
   134  	types.String:        reflect.TypeOf(""),
   135  	types.UnsafePointer: reflect.TypeOf(unsafe.Pointer(nil)),
   136  
   137  	types.UntypedBool:    reflect.TypeOf(false),
   138  	types.UntypedInt:     reflect.TypeOf(0),
   139  	types.UntypedRune:    reflect.TypeOf('a'),
   140  	types.UntypedFloat:   reflect.TypeOf(0.1),
   141  	types.UntypedComplex: reflect.TypeOf(0 + 1i),
   142  	types.UntypedString:  reflect.TypeOf(""),
   143  }
   144  
// FindMethod supplies the implementation of a method declared directly on a
// type. setMethods consults it for every method whose selection index has a
// single element (i.e. not promoted through an embedded field).
//
// mtyp is the method's reflect signature as built by ToType; ToType converts
// only the parameters and results, so the receiver is not included. fn is
// the go/types method object. The returned function is used as the method
// body passed to reflectx.MakeMethod. NOTE(review): presumably args[0] is the
// receiver value, as in the promoted-method wrapper in setMethods — confirm
// against implementations.
type FindMethod interface {
	FindMethod(mtyp reflect.Type, fn *types.Func) func([]reflect.Value) []reflect.Value
}
   148  
// TypesRecord converts go/types types into reflect types and records every
// type it builds in both directions. Lookups consult loader first and fall
// back to these local caches.
type TypesRecord struct {
	// loader is asked first by LookupReflect and LookupTypes.
	loader Loader
	// finder supplies implementations for methods declared directly on a
	// type (used by setMethods).
	finder FindMethod
	// rcache maps reflect types built by this record back to their go/types type.
	rcache map[reflect.Type]types.Type
	// tcache maps go/types types to the reflect.Type built for them;
	// typeutil.Map keys entries by type identity (types.Identical).
	tcache *typeutil.Map
}
   155  
   156  func NewTypesRecord(loader Loader, finder FindMethod) *TypesRecord {
   157  	return &TypesRecord{
   158  		loader: loader,
   159  		finder: finder,
   160  		rcache: make(map[reflect.Type]types.Type),
   161  		tcache: &typeutil.Map{},
   162  	}
   163  }
   164  
   165  func (r *TypesRecord) LookupReflect(typ types.Type) (rt reflect.Type, ok bool) {
   166  	rt, ok = r.loader.LookupReflect(typ)
   167  	if !ok {
   168  		if rt := r.tcache.At(typ); rt != nil {
   169  			return rt.(reflect.Type), true
   170  		}
   171  	}
   172  	return
   173  }
   174  
   175  func (r *TypesRecord) LookupLocalTypes(rt reflect.Type) (typ types.Type, ok bool) {
   176  	typ, ok = r.rcache[rt]
   177  	return
   178  }
   179  
   180  func (r *TypesRecord) LookupTypes(rt reflect.Type) (typ types.Type, ok bool) {
   181  	typ, ok = r.loader.LookupTypes(rt)
   182  	if !ok {
   183  		typ, ok = r.rcache[rt]
   184  	}
   185  	return
   186  }
   187  
   188  func (r *TypesRecord) saveType(typ types.Type, rt reflect.Type) {
   189  	r.tcache.Set(typ, rt)
   190  	r.rcache[rt] = typ
   191  }
   192  
   193  func (r *TypesRecord) ToType(typ types.Type) reflect.Type {
   194  	if rt, ok := r.LookupReflect(typ); ok {
   195  		return rt
   196  	}
   197  	var rt reflect.Type
   198  	switch t := typ.(type) {
   199  	case *types.Basic:
   200  		kind := t.Kind()
   201  		if kind > types.Invalid && kind < types.UntypedNil {
   202  			rt = xtypeTypes[kind]
   203  		}
   204  	case *types.Pointer:
   205  		elem := r.ToType(t.Elem())
   206  		rt = reflect.PtrTo(elem)
   207  	case *types.Slice:
   208  		elem := r.ToType(t.Elem())
   209  		rt = reflect.SliceOf(elem)
   210  	case *types.Array:
   211  		elem := r.ToType(t.Elem())
   212  		rt = reflect.ArrayOf(int(t.Len()), elem)
   213  	case *types.Map:
   214  		key := r.ToType(t.Key())
   215  		elem := r.ToType(t.Elem())
   216  		rt = reflect.MapOf(key, elem)
   217  	case *types.Chan:
   218  		elem := r.ToType(t.Elem())
   219  		rt = reflect.ChanOf(toReflectChanDir(t.Dir()), elem)
   220  	case *types.Struct:
   221  		rt = r.toStructType(t)
   222  	case *types.Named:
   223  		rt = r.toNamedType(t)
   224  	case *types.Interface:
   225  		rt = r.toInterfaceType(t)
   226  	case *types.Signature:
   227  		in := r.ToTypeList(t.Params())
   228  		out := r.ToTypeList(t.Results())
   229  		b := t.Variadic()
   230  		if b && len(in) > 0 {
   231  			last := in[len(in)-1]
   232  			if last.Kind() == reflect.String {
   233  				in[len(in)-1] = reflect.TypeOf([]byte{})
   234  			}
   235  		}
   236  		rt = reflect.FuncOf(in, out, b)
   237  	case *types.Tuple:
   238  		r.ToTypeList(t)
   239  		rt = reflect.TypeOf((*_tuple)(nil)).Elem()
   240  	default:
   241  		panic(fmt.Errorf("ToType: not handled %v\n", typ))
   242  	}
   243  	r.saveType(typ, rt)
   244  	return rt
   245  }
   246  
// _tuple is the placeholder reflect type that ToType returns for a
// *types.Tuple. It carries no information about the tuple's elements.
type _tuple struct{}
   248  
   249  func (r *TypesRecord) toInterfaceType(t *types.Interface) reflect.Type {
   250  	n := t.NumMethods()
   251  	if n == 0 {
   252  		return tyEmptyInterface
   253  	}
   254  	ms := make([]reflect.Method, n, n)
   255  	for i := 0; i < n; i++ {
   256  		fn := t.Method(i)
   257  		mtyp := r.ToType(fn.Type())
   258  		ms[i] = reflect.Method{
   259  			Name: fn.Name(),
   260  			Type: mtyp,
   261  		}
   262  		if pkg := fn.Pkg(); pkg != nil {
   263  			ms[i].PkgPath = pkg.Path()
   264  		}
   265  	}
   266  	return reflectx.InterfaceOf(nil, ms)
   267  }
   268  
   269  func (r *TypesRecord) toNamedType(t *types.Named) reflect.Type {
   270  	ut := t.Underlying()
   271  	name := t.Obj()
   272  	if name.Pkg() == nil {
   273  		if name.Name() == "error" {
   274  			return tyErrorInterface
   275  		}
   276  		return r.ToType(ut)
   277  	}
   278  	methods := IntuitiveMethodSet(t)
   279  	numMethods := len(methods)
   280  	if numMethods == 0 {
   281  		styp := toMockType(t.Underlying())
   282  		typ := reflectx.NamedTypeOf(name.Pkg().Path(), name.Name(), styp)
   283  		r.saveType(t, typ)
   284  		utype := r.ToType(ut)
   285  		reflectx.SetUnderlying(typ, utype)
   286  		return typ
   287  	} else {
   288  		var mcount, pcount int
   289  		for i := 0; i < numMethods; i++ {
   290  			sig := methods[i].Type().(*types.Signature)
   291  			if !isPointer(sig.Recv().Type()) {
   292  				mcount++
   293  			}
   294  			pcount++
   295  		}
   296  		// toMockType for size/align
   297  		etyp := toMockType(ut)
   298  		styp := reflectx.NamedTypeOf(name.Pkg().Path(), name.Name(), etyp)
   299  		typ := reflectx.NewMethodSet(styp, mcount, pcount)
   300  		r.saveType(t, typ)
   301  		utype := r.ToType(ut)
   302  		reflectx.SetUnderlying(typ, utype)
   303  		if typ.Kind() != reflect.Interface {
   304  			r.setMethods(typ, methods)
   305  		}
   306  		return typ
   307  	}
   308  }
   309  
   310  func (r *TypesRecord) toStructType(t *types.Struct) reflect.Type {
   311  	n := t.NumFields()
   312  	if n == 0 {
   313  		return tyEmptyStruct
   314  	}
   315  	flds := make([]reflect.StructField, n)
   316  	for i := 0; i < n; i++ {
   317  		flds[i] = r.toStructField(t.Field(i), t.Tag(i))
   318  	}
   319  	typ := reflectx.StructOf(flds)
   320  	methods := IntuitiveMethodSet(t)
   321  	if numMethods := len(methods); numMethods != 0 {
   322  		// anonymous structs with methods. struct { T }
   323  		var mcount, pcount int
   324  		for i := 0; i < numMethods; i++ {
   325  			sig := methods[i].Type().(*types.Signature)
   326  			if !isPointer(sig.Recv().Type()) {
   327  				mcount++
   328  			}
   329  			pcount++
   330  		}
   331  		typ = reflectx.NewMethodSet(typ, mcount, pcount)
   332  		r.setMethods(typ, methods)
   333  	}
   334  	return typ
   335  }
   336  
   337  func (r *TypesRecord) toStructField(v *types.Var, tag string) reflect.StructField {
   338  	name := v.Name()
   339  	typ := r.ToType(v.Type())
   340  	fld := reflect.StructField{
   341  		Name:      name,
   342  		Type:      typ,
   343  		Tag:       reflect.StructTag(tag),
   344  		Anonymous: v.Anonymous(),
   345  	}
   346  	if !token.IsExported(name) {
   347  		fld.PkgPath = v.Pkg().Path()
   348  	}
   349  	return fld
   350  }
   351  
   352  func (r *TypesRecord) ToTypeList(tuple *types.Tuple) []reflect.Type {
   353  	n := tuple.Len()
   354  	if n == 0 {
   355  		return nil
   356  	}
   357  	list := make([]reflect.Type, n, n)
   358  	for i := 0; i < n; i++ {
   359  		list[i] = r.ToType(tuple.At(i).Type())
   360  	}
   361  	return list
   362  }
   363  
   364  func isPointer(typ types.Type) bool {
   365  	_, ok := typ.Underlying().(*types.Pointer)
   366  	return ok
   367  }
   368  
   369  func (r *TypesRecord) setMethods(typ reflect.Type, methods []*types.Selection) {
   370  	numMethods := len(methods)
   371  	var ms []reflectx.Method
   372  	for i := 0; i < numMethods; i++ {
   373  		fn := methods[i].Obj().(*types.Func)
   374  		sig := methods[i].Type().(*types.Signature)
   375  		pointer := isPointer(sig.Recv().Type())
   376  		mtyp := r.ToType(sig)
   377  		var mfn func(args []reflect.Value) []reflect.Value
   378  		idx := methods[i].Index()
   379  		if len(idx) > 1 {
   380  			isptr := isPointer(fn.Type().Underlying().(*types.Signature).Recv().Type())
   381  			mfn = func(args []reflect.Value) []reflect.Value {
   382  				v := args[0]
   383  				for v.Kind() == reflect.Ptr {
   384  					v = v.Elem()
   385  				}
   386  				v = reflectx.FieldByIndexX(v, idx[:len(idx)-1])
   387  				if isptr && v.Kind() != reflect.Ptr {
   388  					v = v.Addr()
   389  				}
   390  				m, _ := reflectx.MethodByName(v.Type(), fn.Name())
   391  				args[0] = v
   392  				return m.Func.Call(args)
   393  			}
   394  		} else {
   395  			mfn = r.finder.FindMethod(mtyp, fn)
   396  		}
   397  		var pkgpath string
   398  		if pkg := fn.Pkg(); pkg != nil {
   399  			pkgpath = pkg.Path()
   400  		}
   401  		ms = append(ms, reflectx.MakeMethod(fn.Name(), pkgpath, pointer, mtyp, mfn))
   402  	}
   403  	err := reflectx.SetMethodSet(typ, ms, false)
   404  	if err != nil {
   405  		log.Fatalf("SetMethodSet %v err, %v\n", typ, err)
   406  	}
   407  }
   408  
   409  func toReflectChanDir(d types.ChanDir) reflect.ChanDir {
   410  	switch d {
   411  	case types.SendRecv:
   412  		return reflect.BothDir
   413  	case types.SendOnly:
   414  		return reflect.SendDir
   415  	case types.RecvOnly:
   416  		return reflect.RecvDir
   417  	}
   418  	return 0
   419  }
   420  
   421  func (r *TypesRecord) LoadType(typ types.Type) reflect.Type {
   422  	return r.ToType(typ)
   423  }
   424  
   425  func (r *TypesRecord) Load(pkg *ssa.Package) {
   426  	checked := make(map[types.Type]bool)
   427  	for _, v := range pkg.Members {
   428  		typ := v.Type()
   429  		if checked[typ] {
   430  			continue
   431  		}
   432  		checked[typ] = true
   433  		r.LoadType(typ)
   434  	}
   435  }
   436  
   437  // golang.org/x/tools/go/types/typeutil.IntuitiveMethodSet
   438  func IntuitiveMethodSet(T types.Type) []*types.Selection {
   439  	isPointerToConcrete := func(T types.Type) bool {
   440  		ptr, ok := T.(*types.Pointer)
   441  		return ok && !types.IsInterface(ptr.Elem())
   442  	}
   443  
   444  	var result []*types.Selection
   445  	mset := types.NewMethodSet(T)
   446  	if types.IsInterface(T) || isPointerToConcrete(T) {
   447  		for i, n := 0, mset.Len(); i < n; i++ {
   448  			result = append(result, mset.At(i))
   449  		}
   450  	} else {
   451  		// T is some other concrete type.
   452  		// Report methods of T and *T, preferring those of T.
   453  		pmset := types.NewMethodSet(types.NewPointer(T))
   454  		for i, n := 0, pmset.Len(); i < n; i++ {
   455  			meth := pmset.At(i)
   456  			if m := mset.Lookup(meth.Obj().Pkg(), meth.Obj().Name()); m != nil {
   457  				meth = m
   458  			}
   459  			result = append(result, meth)
   460  		}
   461  	}
   462  	return result
   463  }