github.com/goplus/gogen@v1.16.0/typeparams.go (about)

     1  /*
     2   Copyright 2022 The GoPlus Authors (goplus.org)
     3   Licensed under the Apache License, Version 2.0 (the "License");
     4   you may not use this file except in compliance with the License.
     5   You may obtain a copy of the License at
     6       http://www.apache.org/licenses/LICENSE-2.0
     7   Unless required by applicable law or agreed to in writing, software
     8   distributed under the License is distributed on an "AS IS" BASIS,
     9   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10   See the License for the specific language governing permissions and
    11   limitations under the License.
    12  */
    13  
    14  package gogen
    15  
    16  import (
    17  	"fmt"
    18  	"go/ast"
    19  	"go/constant"
    20  	"go/token"
    21  	"go/types"
    22  	"log"
    23  	"strings"
    24  	_ "unsafe"
    25  
    26  	"github.com/goplus/gogen/internal"
    27  	"github.com/goplus/gogen/internal/goxdbg"
    28  )
    29  
    30  // ----------------------------------------------------------------------------
    31  
    32  type TypeParam = types.TypeParam
    33  type Union = types.Union
    34  type Term = types.Term
    35  type TypeParamList = types.TypeParamList
    36  
    37  type inferFuncType struct {
    38  	pkg   *Package
    39  	fn    *internal.Elem
    40  	typ   *types.Signature
    41  	targs []types.Type
    42  	src   ast.Node
    43  }
    44  
    45  func newInferFuncType(pkg *Package, fn *internal.Elem, typ *types.Signature, targs []types.Type, src ast.Node) *inferFuncType {
    46  	return &inferFuncType{pkg: pkg, fn: fn, typ: typ, targs: targs, src: src}
    47  }
    48  
    49  func (p *inferFuncType) Type() types.Type {
    50  	fatal("infer of type")
    51  	return p.typ
    52  }
    53  
    54  func (p *inferFuncType) Underlying() types.Type {
    55  	fatal("infer of type")
    56  	return nil
    57  }
    58  
    59  func (p *inferFuncType) String() string {
    60  	return fmt.Sprintf("inferFuncType{typ: %v, targs: %v}", p.typ, p.targs)
    61  }
    62  
    63  func (p *inferFuncType) Instance() *types.Signature {
    64  	tyRet, err := inferFuncTargs(p.pkg, p.fn, p.typ, p.targs)
    65  	if err != nil {
    66  		pos := getSrcPos(p.src)
    67  		p.pkg.cb.panicCodeErrorf(pos, "%v", err)
    68  	}
    69  	return tyRet.(*types.Signature)
    70  }
    71  
    72  func (p *inferFuncType) InstanceWithArgs(args []*internal.Elem, flags InstrFlags) *types.Signature {
    73  	tyRet, err := InferFunc(p.pkg, p.fn, p.typ, p.targs, args, flags)
    74  	if err != nil {
    75  		pos := getSrcPos(p.src)
    76  		p.pkg.cb.panicCodeErrorf(pos, "%v", err)
    77  	}
    78  	return tyRet.(*types.Signature)
    79  }
    80  
    81  func isGenericType(typ types.Type) bool {
    82  	switch t := typ.(type) {
    83  	case *types.Named:
    84  		return t.Obj() != nil && t.TypeArgs() == nil && t.TypeParams() != nil
    85  	case *types.Signature:
    86  		return t.TypeParams() != nil
    87  	}
    88  	return false
    89  }
    90  
    91  func (p *CodeBuilder) instantiate(nidx int, args []*internal.Elem, src ...ast.Node) *CodeBuilder {
    92  	typ := args[0].Type
    93  	var tt bool
    94  	if t, ok := typ.(*TypeType); ok {
    95  		typ = t.Type()
    96  		tt = true
    97  	}
    98  	p.ensureLoaded(typ)
    99  	srcExpr := getSrc(src)
   100  	if !isGenericType(typ) {
   101  		pos := getSrcPos(srcExpr)
   102  		if tt {
   103  			p.panicCodeErrorf(pos, "%v is not a generic type", typ)
   104  		} else {
   105  			p.panicCodeErrorf(pos, "invalid operation: cannot index %v (value of type %v)", types.ExprString(args[0].Val), typ)
   106  		}
   107  	}
   108  	targs := make([]types.Type, nidx)
   109  	indices := make([]ast.Expr, nidx)
   110  	for i := 0; i < nidx; i++ {
   111  		targs[i] = args[i+1].Type.(*TypeType).Type()
   112  		p.ensureLoaded(targs[i])
   113  		indices[i] = args[i+1].Val
   114  	}
   115  	var tyRet types.Type
   116  	var err error
   117  	if tt {
   118  		tyRet, err = types.Instantiate(p.ctxt, typ, targs, true)
   119  		if err == nil {
   120  			tyRet = NewTypeType(tyRet)
   121  		}
   122  	} else {
   123  		sig := typ.(*types.Signature)
   124  		if nidx >= sig.TypeParams().Len() {
   125  			tyRet, err = types.Instantiate(p.ctxt, typ, targs, true)
   126  		} else {
   127  			tyRet = newInferFuncType(p.pkg, args[0], sig, targs, srcExpr)
   128  		}
   129  	}
   130  	if err != nil {
   131  		pos := getSrcPos(srcExpr)
   132  		p.panicCodeErrorf(pos, "%v", err)
   133  	}
   134  	if debugMatch {
   135  		log.Println("==> InferType", tyRet)
   136  	}
   137  	elem := &internal.Elem{
   138  		Type: tyRet, Src: srcExpr,
   139  	}
   140  	if nidx == 1 {
   141  		elem.Val = &ast.IndexExpr{X: args[0].Val, Index: indices[0]}
   142  	} else {
   143  		elem.Val = &ast.IndexListExpr{X: args[0].Val, Indices: indices}
   144  	}
   145  	p.stk.Ret(nidx+1, elem)
   146  	return p
   147  }
   148  
   149  type typesContext = types.Context
   150  
   151  func newTypesContext() *typesContext {
   152  	return types.NewContext()
   153  }
   154  
   155  func toRecvType(pkg *Package, typ types.Type) ast.Expr {
   156  	var star bool
   157  	if t, ok := typ.(*types.Pointer); ok {
   158  		typ = t.Elem()
   159  		star = true
   160  	}
   161  	t, ok := typ.(*types.Named)
   162  	if !ok {
   163  		panic("unexpected: recv type must types.Named")
   164  	}
   165  	expr := toObjectExpr(pkg, t.Obj())
   166  	if tparams := t.TypeParams(); tparams != nil {
   167  		n := tparams.Len()
   168  		indices := make([]ast.Expr, n)
   169  		for i := 0; i < n; i++ {
   170  			indices[i] = toObjectExpr(pkg, tparams.At(i).Obj())
   171  		}
   172  		if n == 1 {
   173  			expr = &ast.IndexExpr{
   174  				X:     expr,
   175  				Index: indices[0],
   176  			}
   177  		} else {
   178  			expr = &ast.IndexListExpr{
   179  				X:       expr,
   180  				Indices: indices,
   181  			}
   182  		}
   183  	}
   184  	if star {
   185  		expr = &ast.StarExpr{X: expr}
   186  	}
   187  	return expr
   188  }
   189  
   190  func toNamedType(pkg *Package, t *types.Named) ast.Expr {
   191  	expr := toObjectExpr(pkg, t.Obj())
   192  	if targs := t.TypeArgs(); targs != nil {
   193  		n := targs.Len()
   194  		indices := make([]ast.Expr, n)
   195  		for i := 0; i < n; i++ {
   196  			indices[i] = toType(pkg, targs.At(i))
   197  		}
   198  		if n == 1 {
   199  			expr = &ast.IndexExpr{
   200  				X:     expr,
   201  				Index: indices[0],
   202  			}
   203  		} else {
   204  			expr = &ast.IndexListExpr{
   205  				X:       expr,
   206  				Indices: indices,
   207  			}
   208  		}
   209  	}
   210  	return expr
   211  }
   212  
// operandMode classifies an operand. Its values are passed to go/types via
// the linknamed checker_infer, so it is presumably meant to mirror go/types'
// unexported operandMode.
// NOTE(review): the constant order below must match the go/types version in
// use; do not reorder or insert values.
type operandMode byte

//type builtinId int

// operand is a local copy of the operand shape that go/types' infer consumes
// (a slice of *operand is passed straight through checker_infer).
// NOTE(review): the field order and types presumably must match go/types'
// unexported operand struct for the Go toolchain in use.
// The trailing id field of the upstream struct is left out (see the
// commented-out line). That is presumably safe only as long as infer never
// reads it. TODO confirm when upgrading Go.
type operand struct {
	mode operandMode
	expr ast.Expr
	typ  types.Type
	val  constant.Value
	//id   builtinId
}

// Operand modes, in the same iota order as go/types (see the NOTE on
// operandMode).
const (
	invalid   operandMode = iota // operand is invalid
	novalue                      // operand represents no value (result of a function call w/o result)
	builtin                      // operand is a built-in function
	typexpr                      // operand is a type
	constant_                    // operand is a constant; the operand's typ is a Basic type
	variable                     // operand is an addressable variable
	mapindex                     // operand is a map index expression (acts like a variable on lhs, commaok on rhs of an assignment)
	value                        // operand is a computed value
	commaok                      // like value, but operand may be used in a comma,ok expression
	commaerr                     // like commaok, but second value is error, not boolean
	cgofunc                      // operand is a cgo function
)
   238  
// positioner is the type of the posn argument of checker_infer: anything that
// reports a source position (infer and inferFuncTargs pass an ast.Expr).
type positioner interface {
	Pos() token.Pos
}

// checker_infer has no Go body here. The go:linkname directive below binds it
// to the unexported method go/types.(*Checker).infer, with the receiver passed
// as the first parameter. This relies on the blank "unsafe" import.
// NOTE(review): the parameter list must match that method's signature in the
// Go toolchain this file is built with; confirm when upgrading Go.
//
//go:linkname checker_infer go/types.(*Checker).infer
func checker_infer(check *types.Checker, posn positioner, tparams []*types.TypeParam, targs []types.Type, params *types.Tuple, args []*operand) (result []types.Type)
   245  
   246  func infer(pkg *Package, posn positioner, tparams []*types.TypeParam, targs []types.Type, params *types.Tuple, args []*operand) (result []types.Type, err error) {
   247  	conf := &types.Config{
   248  		Error: func(e error) {
   249  			err = e
   250  			if terr, ok := e.(types.Error); ok {
   251  				err = fmt.Errorf("%s", terr.Msg)
   252  			}
   253  		},
   254  	}
   255  	checker := types.NewChecker(conf, pkg.Fset, pkg.Types, nil)
   256  	result = checker_infer(checker, posn, tparams, targs, params, args)
   257  	return
   258  }
   259  
   260  func getParamsTypes(tuple *types.Tuple, variadic bool) string {
   261  	n := tuple.Len()
   262  	if n == 0 {
   263  		return ""
   264  	}
   265  	typs := make([]string, n)
   266  	for i := 0; i < n; i++ {
   267  		typs[i] = tuple.At(i).Type().String()
   268  	}
   269  	if variadic {
   270  		typs[n-1] = "..." + tuple.At(n-1).Type().(*types.Slice).Elem().String()
   271  	}
   272  	return strings.Join(typs, ", ")
   273  }
   274  
   275  func checkInferArgs(pkg *Package, fn *internal.Elem, sig *types.Signature, args []*internal.Elem, flags InstrFlags) ([]*internal.Elem, error) {
   276  	nargs := len(args)
   277  	nreq := sig.Params().Len()
   278  	if sig.Variadic() {
   279  		if nargs < nreq-1 {
   280  			caller := types.ExprString(fn.Val)
   281  			return nil, fmt.Errorf(
   282  				"not enough arguments in call to %s\n\thave (%v)\n\twant (%v)", caller, getTypes(args), getParamsTypes(sig.Params(), true))
   283  		}
   284  		if flags&InstrFlagEllipsis != 0 {
   285  			return args, nil
   286  		}
   287  		var typ types.Type
   288  		if nargs < nreq {
   289  			typ = sig.Params().At(nreq - 1).Type()
   290  			elem := typ.(*types.Slice).Elem()
   291  			if t, ok := elem.(*types.TypeParam); ok {
   292  				return nil, fmt.Errorf("cannot infer %v (%v)", elem, pkg.cb.fset.Position(t.Obj().Pos()))
   293  			}
   294  		} else {
   295  			typ = types.NewSlice(types.Default(args[nreq-1].Type))
   296  		}
   297  		res := make([]*internal.Elem, nreq)
   298  		for i := 0; i < nreq-1; i++ {
   299  			res[i] = args[i]
   300  		}
   301  		res[nreq-1] = &internal.Elem{Type: typ}
   302  		return res, nil
   303  	} else if nreq != nargs {
   304  		fewOrMany := "not enough"
   305  		if nargs > nreq {
   306  			fewOrMany = "too many"
   307  		}
   308  		caller := types.ExprString(fn.Val)
   309  		return nil, fmt.Errorf(
   310  			"%s arguments in call to %s\n\thave (%v)\n\twant (%v)", fewOrMany, caller, getTypes(args), getParamsTypes(sig.Params(), false))
   311  	}
   312  	return args, nil
   313  }
   314  
   315  // InferFunc attempts to infer and instantiates the generic function instantiation with given func and args.
   316  func InferFunc(pkg *Package, fn *Element, sig *types.Signature, targs []types.Type, args []*Element, flags InstrFlags) (types.Type, error) {
   317  	_, typ, err := inferFunc(pkg, fn, sig, targs, args, flags)
   318  	return typ, err
   319  }
   320  
   321  func inferFuncTargs(pkg *Package, fn *internal.Elem, sig *types.Signature, targs []types.Type) (types.Type, error) {
   322  	tp := sig.TypeParams()
   323  	n := tp.Len()
   324  	tparams := make([]*types.TypeParam, n)
   325  	for i := 0; i < n; i++ {
   326  		tparams[i] = tp.At(i)
   327  	}
   328  	targs, err := infer(pkg, fn.Val, tparams, targs, nil, nil)
   329  	if err != nil {
   330  		return nil, err
   331  	}
   332  	return types.Instantiate(pkg.cb.ctxt, sig, targs, true)
   333  }
   334  
   335  func toFieldListX(pkg *Package, t *types.TypeParamList) *ast.FieldList {
   336  	if t == nil {
   337  		return nil
   338  	}
   339  	n := t.Len()
   340  	flds := make([]*ast.Field, n)
   341  	for i := 0; i < n; i++ {
   342  		item := t.At(i)
   343  		names := []*ast.Ident{ast.NewIdent(item.Obj().Name())}
   344  		typ := toType(pkg, item.Constraint())
   345  		flds[i] = &ast.Field{Names: names, Type: typ}
   346  	}
   347  	return &ast.FieldList{
   348  		List: flds,
   349  	}
   350  }
   351  
   352  func toFuncType(pkg *Package, sig *types.Signature) *ast.FuncType {
   353  	params := toFieldList(pkg, sig.Params())
   354  	results := toFieldList(pkg, sig.Results())
   355  	if sig.Variadic() {
   356  		n := len(params)
   357  		if n == 0 {
   358  			panic("TODO: toFuncType error")
   359  		}
   360  		toVariadic(params[n-1])
   361  	}
   362  	return &ast.FuncType{
   363  		TypeParams: toFieldListX(pkg, sig.TypeParams()),
   364  		Params:     &ast.FieldList{List: params},
   365  		Results:    &ast.FieldList{List: results},
   366  	}
   367  }
   368  
   369  func toUnionType(pkg *Package, t *types.Union) ast.Expr {
   370  	var v ast.Expr
   371  	n := t.Len()
   372  	for i := 0; i < n; i++ {
   373  		term := t.Term(i)
   374  		typ := toType(pkg, term.Type())
   375  		if term.Tilde() {
   376  			typ = &ast.UnaryExpr{
   377  				Op: token.TILDE,
   378  				X:  typ,
   379  			}
   380  		}
   381  		if n == 1 {
   382  			return typ
   383  		}
   384  		if i == 0 {
   385  			v = typ
   386  		} else {
   387  			v = &ast.BinaryExpr{
   388  				X:  v,
   389  				Op: token.OR,
   390  				Y:  typ,
   391  			}
   392  		}
   393  	}
   394  	return v
   395  }
   396  
   397  func setTypeParams(pkg *Package, typ *types.Named, spec *ast.TypeSpec, tparams []*TypeParam) {
   398  	typ.SetTypeParams(tparams)
   399  	n := len(tparams)
   400  	if n == 0 {
   401  		spec.TypeParams = nil
   402  		return
   403  	}
   404  	flds := make([]*ast.Field, n)
   405  	for i := 0; i < n; i++ {
   406  		item := tparams[i]
   407  		names := []*ast.Ident{ast.NewIdent(item.Obj().Name())}
   408  		typ := toType(pkg, item.Constraint())
   409  		flds[i] = &ast.Field{Names: names, Type: typ}
   410  	}
   411  	spec.TypeParams = &ast.FieldList{List: flds}
   412  }
   413  
   414  func interfaceIsImplicit(t *types.Interface) bool {
   415  	return t.IsImplicit()
   416  }
   417  
   418  // ----------------------------------------------------------------------------
   419  
   420  func boundTypeParams(p *Package, fn *Element, sig *types.Signature, args []*Element, flags InstrFlags) (*Element, *types.Signature, []*Element, error) {
   421  	if debugMatch {
   422  		log.Println("boundTypeParams:", goxdbg.Format(p.Fset, fn.Val), "sig:", sig, "args:", len(args), "flags:", flags)
   423  	}
   424  	params := sig.TypeParams()
   425  	if n := params.Len(); n > 0 {
   426  		from := 0
   427  		if (flags & instrFlagGoptFunc) != 0 {
   428  			from = 1
   429  		}
   430  		targs := make([]types.Type, n)
   431  		for i := 0; i < n; i++ {
   432  			arg := args[from+i]
   433  			t, ok := arg.Type.(*TypeType)
   434  			if !ok {
   435  				src, pos := p.cb.loadExpr(arg.Src)
   436  				err := p.cb.newCodeErrorf(pos, "%s (type %v) is not a type", src, arg.Type)
   437  				return fn, sig, args, err
   438  			}
   439  			targs[i] = t.typ
   440  		}
   441  		ret, err := types.Instantiate(p.cb.ctxt, sig, targs, true)
   442  		if err != nil {
   443  			return fn, sig, args, err
   444  		}
   445  		indices := make([]ast.Expr, n)
   446  		for i := 0; i < n; i++ {
   447  			indices[i] = args[from+i].Val
   448  		}
   449  		fn = &Element{Val: &ast.IndexListExpr{X: fn.Val, Indices: indices}, Type: ret, Src: fn.Src}
   450  		sig = ret.(*types.Signature)
   451  		if from > 0 {
   452  			args = append([]*Element{args[0]}, args[n+1:]...)
   453  		} else {
   454  			args = args[n:]
   455  		}
   456  	}
   457  	return fn, sig, args, nil
   458  }
   459  
   460  func (p *Package) Instantiate(orig types.Type, targs []types.Type, src ...ast.Node) types.Type {
   461  	p.cb.ensureLoaded(orig)
   462  	for _, targ := range targs {
   463  		p.cb.ensureLoaded(targ)
   464  	}
   465  	ctxt := p.cb.ctxt
   466  	if on, ok := CheckOverloadNamed(orig); ok {
   467  		var first error
   468  		for _, t := range on.Types {
   469  			ret, err := types.Instantiate(ctxt, t, targs, true)
   470  			if err == nil {
   471  				return ret
   472  			}
   473  			if first == nil {
   474  				first = err
   475  			}
   476  		}
   477  		p.cb.handleCodeError(getPos(src), first.Error())
   478  		return types.Typ[types.Invalid]
   479  	}
   480  	if !isGenericType(orig) {
   481  		p.cb.handleCodeErrorf(getPos(src), "%v is not a generic type", orig)
   482  		return types.Typ[types.Invalid]
   483  	}
   484  	ret, err := types.Instantiate(ctxt, orig, targs, true)
   485  	if err != nil {
   486  		p.cb.handleCodeError(getPos(src), err.Error())
   487  	}
   488  	return ret
   489  }
   490  
   491  func paramsToArgs(sig *types.Signature) []*internal.Elem {
   492  	n := sig.Params().Len()
   493  	args := make([]*internal.Elem, n)
   494  	for i := 0; i < n; i++ {
   495  		p := sig.Params().At(i)
   496  		args[i] = &internal.Elem{
   497  			Val:  ast.NewIdent(p.Name()),
   498  			Type: p.Type(),
   499  		}
   500  	}
   501  	return args
   502  }
   503  
   504  func instanceInferFunc(pkg *Package, arg *internal.Elem, tsig *inferFuncType, sig *types.Signature) error {
   505  	args := paramsToArgs(sig)
   506  	targs, _, err := inferFunc(tsig.pkg, tsig.fn, tsig.typ, tsig.targs, args, 0)
   507  	if err != nil {
   508  		return err
   509  	}
   510  	arg.Type = sig
   511  	index := make([]ast.Expr, len(targs))
   512  	for i, a := range targs {
   513  		index[i] = toType(pkg, a)
   514  	}
   515  	var x ast.Expr
   516  	switch v := arg.Val.(type) {
   517  	case *ast.IndexExpr:
   518  		x = v.X
   519  	case *ast.IndexListExpr:
   520  		x = v.X
   521  	}
   522  	arg.Val = &ast.IndexListExpr{
   523  		Indices: index,
   524  		X:       x,
   525  	}
   526  	return nil
   527  }
   528  
   529  func instanceFunc(pkg *Package, arg *internal.Elem, tsig *types.Signature, sig *types.Signature) error {
   530  	args := paramsToArgs(sig)
   531  	targs, _, err := inferFunc(pkg, &internal.Elem{Val: arg.Val}, tsig, nil, args, 0)
   532  	if err != nil {
   533  		return err
   534  	}
   535  	arg.Type = sig
   536  	if len(targs) == 1 {
   537  		arg.Val = &ast.IndexExpr{
   538  			Index: toType(pkg, targs[0]),
   539  			X:     arg.Val,
   540  		}
   541  	} else {
   542  		index := make([]ast.Expr, len(targs))
   543  		for i, a := range targs {
   544  			index[i] = toType(pkg, a)
   545  		}
   546  		arg.Val = &ast.IndexListExpr{
   547  			Indices: index,
   548  			X:       arg.Val,
   549  		}
   550  	}
   551  	return nil
   552  }
   553  
   554  // ----------------------------------------------------------------------------