github.com/goplus/igop@v0.25.0/visit.go (about)

     1  /*
     2   * Copyright (c) 2022 The GoPlus Authors (goplus.org). All rights reserved.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package igop
    18  
    19  import (
    20  	"fmt"
    21  	"go/ast"
    22  	"go/token"
    23  	"go/types"
    24  	"log"
    25  	"reflect"
    26  	"strings"
    27  
    28  	"github.com/goplus/igop/load"
    29  	"github.com/visualfc/xtype"
    30  	"golang.org/x/tools/go/ssa"
    31  )
    32  
    33  const (
    34  	fnBase = 100
    35  )
    36  
    37  func checkPackages(intp *Interp, pkgs []*ssa.Package) (err error) {
    38  	if intp.ctx.Mode&DisableRecover == 0 {
    39  		defer func() {
    40  			if v := recover(); v != nil {
    41  				err = v.(error)
    42  			}
    43  		}()
    44  	}
    45  	visit := visitor{
    46  		intp: intp,
    47  		prog: intp.mainpkg.Prog,
    48  		pkgs: make(map[*ssa.Package]bool),
    49  		seen: make(map[*ssa.Function]bool),
    50  		base: fnBase,
    51  	}
    52  	for _, pkg := range pkgs {
    53  		visit.pkgs[pkg] = true
    54  	}
    55  	visit.program()
    56  	return
    57  }
    58  
// visitor walks the SSA functions reachable from a set of packages and
// compiles them for an Interp (see checkPackages).
type visitor struct {
	// intp is the interpreter being populated; its msets and ctx.override
	// maps are written during the walk.
	intp *Interp
	// prog is the SSA program; checkPackages sets it to intp.mainpkg.Prog.
	prog *ssa.Program
	// pkgs is the set of packages being checked. Their import paths are
	// treated as local by program, and body-less functions declared in them
	// must be resolved by function.
	pkgs map[*ssa.Package]bool
	// seen records functions already visited so each is processed once.
	seen map[*ssa.Function]bool
	// base is assigned to the next compiled function's pfn.base. It starts
	// at fnBase and advances by len(pfn.ssaInstrs)+2 per compiled function.
	base int
}
    66  
    67  func (visit *visitor) program() {
    68  	chks := make(map[string]bool)
    69  	chks[""] = true // anonymous struct embed named type
    70  	for pkg := range visit.pkgs {
    71  		chks[pkg.Pkg.Path()] = true
    72  	}
    73  
    74  	isExtern := func(typ reflect.Type) bool {
    75  		if typ.Kind() == reflect.Ptr {
    76  			typ = typ.Elem()
    77  		}
    78  		return !chks[typ.PkgPath()]
    79  	}
    80  
    81  	methodsOf := func(T types.Type) {
    82  		if types.IsInterface(T) {
    83  			return
    84  		}
    85  		typ := visit.intp.preToType(T)
    86  		// skip extern type
    87  		if isExtern(typ) {
    88  			return
    89  		}
    90  		mmap := make(map[string]*ssa.Function)
    91  		mset := visit.prog.MethodSets.MethodSet(T)
    92  		for i, n := 0, mset.Len(); i < n; i++ {
    93  			sel := mset.At(i)
    94  			obj := sel.Obj()
    95  			// skip embed extern type method
    96  			if pkg := obj.Pkg(); pkg != nil {
    97  				if !chks[pkg.Path()] {
    98  					continue
    99  				}
   100  				if visit.intp.ctx.Mode&CheckGopOverloadFunc != 0 && obj.Pos() == token.NoPos {
   101  					continue
   102  				}
   103  			}
   104  			fn := visit.prog.MethodValue(sel)
   105  			mmap[obj.Name()] = fn
   106  			visit.function(fn)
   107  		}
   108  		visit.intp.msets[typ] = mmap
   109  	}
   110  
   111  	exportedTypeHack := func(t *ssa.Type) {
   112  		if ast.IsExported(t.Name()) && !types.IsInterface(t.Type()) {
   113  			if named, ok := t.Type().(*types.Named); ok && !hasTypeParam(named) {
   114  				methodsOf(named)                   //  T
   115  				methodsOf(types.NewPointer(named)) // *T
   116  			}
   117  		}
   118  	}
   119  
   120  	for pkg := range visit.pkgs {
   121  		for _, mem := range pkg.Members {
   122  			switch mem := mem.(type) {
   123  			case *ssa.Function:
   124  				visit.function(mem)
   125  			case *ssa.Type:
   126  				exportedTypeHack(mem)
   127  			}
   128  		}
   129  	}
   130  
   131  	for _, T := range visit.prog.RuntimeTypes() {
   132  		methodsOf(T)
   133  	}
   134  }
   135  
   136  func (visit *visitor) findLinkSym(fn *ssa.Function) (*load.LinkSym, bool) {
   137  	if sp, ok := visit.intp.ctx.pkgs[fn.Pkg.Pkg.Path()]; ok {
   138  		for _, link := range sp.Links {
   139  			if link.Name == fn.Name() {
   140  				return link, true
   141  			}
   142  		}
   143  	}
   144  	return nil, false
   145  }
   146  
   147  func (visit *visitor) findFunction(sym *load.LinkSym) *ssa.Function {
   148  	for pkg := range visit.pkgs {
   149  		if pkg.Pkg.Path() == sym.Linkname.PkgPath {
   150  			if typ := sym.Linkname.Recv; typ != "" {
   151  				var star bool
   152  				if typ[0] == '*' {
   153  					star = true
   154  					typ = typ[1:]
   155  				}
   156  				if obj := pkg.Pkg.Scope().Lookup(typ); obj != nil {
   157  					t := obj.Type()
   158  					if star {
   159  						t = types.NewPointer(t)
   160  					}
   161  					return visit.prog.LookupMethod(t, pkg.Pkg, sym.Linkname.Method)
   162  				}
   163  			}
   164  			return pkg.Func(sym.Linkname.Name)
   165  		}
   166  	}
   167  	return nil
   168  }
   169  
   170  func wrapMethodType(sig *types.Signature) *types.Signature {
   171  	params := sig.Params()
   172  	n := params.Len()
   173  	list := make([]*types.Var, n+1)
   174  	list[0] = sig.Recv()
   175  	for i := 0; i < n; i++ {
   176  		list[i+1] = params.At(i)
   177  	}
   178  	return types.NewSignature(nil, types.NewTuple(list...), sig.Results(), sig.Variadic())
   179  }
   180  
   181  func (visit *visitor) findLinkFunc(sym *load.LinkSym) (ext reflect.Value, ok bool) {
   182  	ext, ok = findExternLinkFunc(visit.intp, &sym.Linkname)
   183  	if ok {
   184  		return
   185  	}
   186  	if link := visit.findFunction(sym); link != nil {
   187  		visit.function(link)
   188  		sig := link.Signature
   189  		if sig.Recv() != nil {
   190  			sig = wrapMethodType(sig)
   191  		}
   192  		typ := visit.intp.preToType(sig)
   193  		pfn := visit.intp.loadFunction(link)
   194  		ext = pfn.makeFunction(typ, nil)
   195  		ok = true
   196  	}
   197  	return
   198  }
   199  
// function compiles fn into the interpreter's instruction form and visits
// every function it references along the way.
//
// Each function is processed at most once (tracked in visit.seen).
// Functions whose type still has type parameters are skipped. If
// ctx.override already holds a replacement with the same reflect type, fn's
// blocks are dropped and nothing is compiled.
//
// A body-less function from one of the visited packages is resolved in this
// order:
//   - an extern function (nothing else to do);
//   - a link symbol whose target is installed in ctx.override, converted if
//     the function types differ;
//   - with EnableNoStrict, a stub that returns zero values, plus a warning.
//
// If none applies, function panics with a "missing function body" error.
//
// For a function with a body, every referenced type is loaded and the
// parameters and free variables are registered. Each SSA instruction is then
// translated by makeInstr. The translated steps are optionally wrapped for
// eval-mode init de-duplication, ctx.evalCallFn callbacks, and
// EnableTracing logging, then appended to pfn block by block. Finally
// pfn.base is assigned from visit.base.
func (visit *visitor) function(fn *ssa.Function) {
	if visit.seen[fn] {
		return
	}
	if hasTypeParam(fn.Type()) {
		return
	}
	visit.seen[fn] = true
	fnPath := fn.String()
	// An override with a matching reflect type replaces the body.
	// NOTE(review): clearing Blocks presumably keeps the SSA body from being
	// compiled or run — confirm.
	if f, ok := visit.intp.ctx.override[fnPath]; ok &&
		visit.intp.preToType(fn.Type()) == f.Type() {
		fn.Blocks = nil
		return
	}
	if fn.Blocks == nil {
		// Only body-less functions of the visited packages must be resolved;
		// others are left alone.
		if _, ok := visit.pkgs[fn.Pkg]; ok {
			if _, ok = findExternFunc(visit.intp, fn); !ok {
				// Resolve through a link symbol, converting the target to fn's
				// reflect type when the two differ.
				if sym, ok := visit.findLinkSym(fn); ok {
					if ext, ok := visit.findLinkFunc(sym); ok {
						typ := visit.intp.preToType(fn.Type())
						ftyp := ext.Type()
						if typ != ftyp {
							ext = xtype.ConvertFunc(ext, xtype.TypeOfType(typ))
						}
						visit.intp.ctx.override[fnPath] = ext
						return
					}
				}
				// Non-strict mode: install a stub that returns zero values
				// for every result, and print a warning instead of failing.
				if visit.intp.ctx.Mode&EnableNoStrict != 0 {
					typ := visit.intp.preToType(fn.Type())
					numOut := typ.NumOut()
					if numOut == 0 {
						visit.intp.ctx.override[fnPath] = reflect.MakeFunc(typ, func(args []reflect.Value) (results []reflect.Value) {
							return
						})
					} else {
						visit.intp.ctx.override[fnPath] = reflect.MakeFunc(typ, func(args []reflect.Value) (results []reflect.Value) {
							results = make([]reflect.Value, numOut)
							for i := 0; i < numOut; i++ {
								results[i] = reflect.New(typ.Out(i)).Elem()
							}
							return
						})
					}
					println(fmt.Sprintf("igop warning: %v: %v missing function body", visit.intp.ctx.FileSet.Position(fn.Pos()), fnPath))
					return
				}
				panic(fmt.Errorf("%v: %v missing function body", visit.intp.ctx.FileSet.Position(fn.Pos()), fnPath))
			}
		}
		return
	}
	// Generic instantiation: bracket compilation with Enter/LeaveInstance on
	// the interpreter's record.
	if len(fn.TypeArgs()) != 0 {
		visit.intp.record.EnterInstance(fn)
		defer visit.intp.record.LeaveInstance(fn)
	}
	visit.intp.loadType(fn.Type())
	for _, alloc := range fn.Locals {
		visit.intp.loadType(alloc.Type())
		visit.intp.loadType(deref(alloc.Type()))
	}
	pfn := visit.intp.loadFunction(fn)
	// Register parameters and free variables before any instruction.
	for _, p := range fn.Params {
		pfn.regIndex(p)
	}
	for _, p := range fn.FreeVars {
		pfn.regIndex(p)
	}
	var buf [32]*ssa.Value // avoid alloc in common case
	for _, b := range fn.Blocks {
		// Per-block scratch slices; only the first index entries are used
		// because makeInstr may return nil for some instructions.
		Instrs := make([]func(*frame), len(b.Instrs))
		ssaInstrs := make([]ssa.Instruction, len(b.Instrs))
		var index int
		n := len(b.Instrs)
		for i := 0; i < n; i++ {
			instr := b.Instrs[i]
			ops := instr.Operands(buf[:0])
			// Load the types some instructions need at run time.
			switch instr := instr.(type) {
			case *ssa.Alloc:
				visit.intp.loadType(instr.Type())
				visit.intp.loadType(deref(instr.Type()))
			case *ssa.Next:
				// skip *ssa.opaqueType: iter
				ops = nil
			case *ssa.Extract:
				// skip
				ops = nil
			case *ssa.TypeAssert:
				visit.intp.loadType(instr.AssertedType)
			case *ssa.MakeChan:
				visit.intp.loadType(instr.Type())
			case *ssa.MakeMap:
				visit.intp.loadType(instr.Type())
			case *ssa.MakeSlice:
				visit.intp.loadType(instr.Type())
			case *ssa.SliceToArrayPointer:
				visit.intp.loadType(instr.Type())
			case *ssa.Convert:
				visit.intp.loadType(instr.Type())
			case *ssa.ChangeType:
				visit.intp.loadType(instr.Type())
			case *ssa.MakeInterface:
				visit.intp.loadType(instr.Type())
			}
			// Visit referenced functions recursively; load operand types.
			for _, op := range ops {
				switch v := (*op).(type) {
				case *ssa.Function:
					visit.function(v)
				case nil:
					// skip
				default:
					visit.intp.loadType(v.Type())
				}
			}
			// The instruction being translated is recorded on pfn for the
			// duration of makeInstr; it is cleared after all blocks.
			pfn.makeInstr = instr
			ifn := makeInstr(visit.intp, pfn, instr)
			if ifn == nil {
				continue
			}
			// Eval mode: within main.init, only the first compiled call to a
			// given "init#..." callee runs. Later ones (tracked in
			// ctx.evalInit) become no-ops.
			if visit.intp.ctx.evalMode && fn.String() == "main.init" {
				if visit.intp.ctx.evalInit == nil {
					visit.intp.ctx.evalInit = make(map[string]bool)
				}
				if call, ok := instr.(*ssa.Call); ok {
					key := call.String()
					if strings.HasPrefix(key, "init#") {
						if visit.intp.ctx.evalInit[key] {
							ifn = func(fr *frame) {}
						} else {
							visit.intp.ctx.evalInit[key] = true
						}
					}
				}
			}
			// Report each call and its results (none, one, or an unpacked
			// tuple) to ctx.evalCallFn after the call executes.
			if visit.intp.ctx.evalCallFn != nil {
				if call, ok := instr.(*ssa.Call); ok {
					ir := pfn.regIndex(call)
					results := call.Call.Signature().Results()
					ofn := ifn
					switch results.Len() {
					case 0:
						ifn = func(fr *frame) {
							ofn(fr)
							visit.intp.ctx.evalCallFn(visit.intp, call)
						}
					case 1:
						ifn = func(fr *frame) {
							ofn(fr)
							visit.intp.ctx.evalCallFn(visit.intp, call, fr.reg(ir))
						}
					default:
						ifn = func(fr *frame) {
							ofn(fr)
							r := fr.reg(ir).(tuple)
							visit.intp.ctx.evalCallFn(visit.intp, call, r...)
						}
					}
				}
			}
			// Tracing: log each instruction. The first instruction of a
			// block also logs the block header, and the first instruction
			// of the entry block also logs function entry. Return
			// instructions log the exit and the resumed caller.
			if visit.intp.ctx.Mode&EnableTracing != 0 {
				ofn := ifn
				ifn = func(fr *frame) {
					if v, ok := instr.(ssa.Value); ok {
						log.Printf("\t%-20T %v = %-40v\t%v\n", instr, v.Name(), instr, v.Type())
					} else {
						log.Printf("\t%-20T %v\n", instr, instr)
					}
					ofn(fr)
				}
				if index == 0 {
					ofn := ifn
					bi := b.Index
					common := b.Comment
					ifn = func(fr *frame) {
						log.Printf(".%v %v\n", bi, common)
						ofn(fr)
					}
				}
				if index == 0 && b.Index == 0 {
					ofn := ifn
					ifn = func(fr *frame) {
						log.Printf("Entering %v%v.", fr.pfn.Fn, loc(fr.interp.ctx.FileSet, fr.pfn.Fn.Pos()))
						ofn(fr)
					}
				}
				if _, ok := instr.(*ssa.Return); ok {
					ofn := ifn
					ifn = func(fr *frame) {
						ofn(fr)
						var caller ssa.Instruction
						if fr.caller != nil {
							caller = fr.caller.pfn.InstrForPC(fr.caller.ipc - 1)
						}
						if caller == nil {
							log.Printf("Leaving %v.\n", fr.pfn.Fn)
						} else {
							log.Printf("Leaving %v, resuming %v call %v%v.\n",
								fr.pfn.Fn, fr.caller.pfn.Fn, caller, loc(fr.interp.ctx.FileSet, caller.Pos()))
						}
					}
				}
			}
			Instrs[index] = ifn
			ssaInstrs[index] = instr
			index++
		}
		// Append this block's steps to pfn and record where it starts.
		offset := len(pfn.Instrs)
		pfn.Blocks = append(pfn.Blocks, offset)
		pfn.Instrs = append(pfn.Instrs, Instrs[:index]...)
		pfn.ssaInstrs = append(pfn.ssaInstrs, ssaInstrs[:index]...)
		// Remember the recover block's steps unless recovery is disabled.
		if b == fn.Recover && visit.intp.ctx.Mode&DisableRecover == 0 {
			pfn.Recover = pfn.Instrs[offset:]
		}
	}
	pfn.makeInstr = nil
	// Give this function the current base, then advance the shared counter.
	// NOTE(review): the extra +2 presumably leaves a gap between consecutive
	// functions' ranges — confirm against users of pfn.base.
	pfn.base = visit.base
	visit.base += len(pfn.ssaInstrs) + 2
	pfn.initPool()
}
   419  
   420  func loc(fset *token.FileSet, pos token.Pos) string {
   421  	if pos == token.NoPos {
   422  		return ""
   423  	}
   424  	return " at " + fset.Position(pos).String()
   425  }