github.com/goplus/gossa@v0.3.25/context.go (about)

     1  package gossa
     2  
import (
	"bytes"
	"flag"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
	"io"
	"os"
	"path/filepath"
	"reflect"
	"sort"
	"strings"
	"time"

	"github.com/goplus/reflectx"

	"golang.org/x/tools/go/ssa"
	"golang.org/x/tools/go/ssa/ssautil"
)
    23  
// Mode is a bitmask of options affecting the interpreter.
// Flags are combined with | and tested with mode&Flag != 0.
type Mode uint

// Interpreter mode flags. Each flag is a distinct bit assigned via iota, so
// the order of the declarations below determines each flag's numeric value.
const (
	DisableRecover         Mode = 1 << iota // Disable recover() in target programs; show interpreter crash instead.
	DisableCustomBuiltin                    // Do not prepend the generated custom-builtin file when loading source (see parserBuiltin).
	DisableUnexportMethods                  // Disable unexported methods.
	EnableTracing                           // Print a trace of all instructions as they are interpreted.
	EnableDumpInstr                         // Print packages & SSA instruction code (NewContext adds ssa.PrintFunctions).
	EnablePrintAny                          // Enable builtin print for any type (struct/array).
)
    35  
    36  // Loader types loader interface
    37  type Loader interface {
    38  	Import(path string) (*types.Package, error)
    39  	Installed(path string) (*Package, bool)
    40  	Packages() []*types.Package
    41  	LookupReflect(typ types.Type) (reflect.Type, bool)
    42  	LookupTypes(typ reflect.Type) (types.Type, bool)
    43  }
    44  
// Context holds the settings used to load, type-check, build and run
// programs with the interpreter. Create it with NewContext, which fills in
// the unexported defaults (override map, function-pool threshold); the
// exported fields may then be adjusted before loading code.
type Context struct {
	Loader      Loader                   // types loader; passed to NewImporter and asked for Packages() in BuildPackage
	Mode        Mode                     // interpreter mode flags
	ParserMode  parser.Mode              // go/parser mode used by ParseFile and LoadDir
	BuilderMode ssa.BuilderMode          // ssa builder mode passed to ssa.NewProgram
	External    types.Importer           // external importer, passed to NewImporter together with Loader
	Sizes       types.Sizes              // type sizes for the type checker (e.g. for package unsafe)
	debugFunc   func(*DebugInfo)         // debug callback set by SetDebug
	override    map[string]reflect.Value // override functions keyed by full function name (SetOverrideFunction)
	output      io.Writer                // destination of builtin print/println output; nil means os.Stdout
	callForPool int                      // least call count for enabling the function pool (NewContext default: 64)
	evalMode    bool                     // eval mode; BuildPackage then disables the unused-import check
	evalInit    map[string]bool          // eval init check
	evalCallFn  func(call *ssa.Call, res ...interface{}) // NOTE(review): eval-mode call callback; not used in this file
}
    61  
    62  func (c *Context) IsEvalMode() bool {
    63  	return c.evalMode
    64  }
    65  
    66  // NewContext create a new Context
    67  func NewContext(mode Mode) *Context {
    68  	ctx := &Context{
    69  		Loader:      NewTypesLoader(mode),
    70  		Mode:        mode,
    71  		ParserMode:  parser.AllErrors,
    72  		BuilderMode: 0, //ssa.SanityCheckFunctions,
    73  		override:    make(map[string]reflect.Value),
    74  		callForPool: 64,
    75  	}
    76  	if mode&EnableDumpInstr != 0 {
    77  		ctx.BuilderMode |= ssa.PrintFunctions
    78  	}
    79  	return ctx
    80  }
    81  
    82  // SetLeastCallForEnablePool set least call count for enable function pool, default 64
    83  func (c *Context) SetLeastCallForEnablePool(count int) {
    84  	c.callForPool = count
    85  }
    86  
    87  func (c *Context) SetDebug(fn func(*DebugInfo)) {
    88  	c.BuilderMode |= ssa.GlobalDebug
    89  	c.debugFunc = fn
    90  }
    91  
    92  // SetOverrideFunction register external function to override function.
    93  // match func fullname and signature
    94  func (c *Context) SetOverrideFunction(key string, fn interface{}) {
    95  	c.override[key] = reflect.ValueOf(fn)
    96  }
    97  
    98  // ClearOverrideFunction reset override function
    99  func (c *Context) ClearOverrideFunction(key string) {
   100  	delete(c.override, key)
   101  }
   102  
   103  // set builtin print/println captured output
   104  func (c *Context) SetPrintOutput(output *bytes.Buffer) {
   105  	c.output = output
   106  }
   107  
   108  func (c *Context) writeOutput(data []byte) (n int, err error) {
   109  	if c.output != nil {
   110  		return c.output.Write(data)
   111  	}
   112  	return os.Stdout.Write(data)
   113  }
   114  
   115  func (c *Context) LoadDir(fset *token.FileSet, path string) (pkgs []*ssa.Package, first error) {
   116  	apkgs, err := parser.ParseDir(fset, path, nil, c.ParserMode)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	for _, apkg := range apkgs {
   121  		if pkg, err := c.LoadAstPackage(fset, apkg); err == nil {
   122  			pkgs = append(pkgs, pkg)
   123  		} else if first == nil {
   124  			first = err
   125  		}
   126  	}
   127  	return
   128  }
   129  
   130  func RegisterFileProcess(ext string, fn SourceProcessFunc) {
   131  	sourceProcessor[ext] = fn
   132  }
   133  
   134  type SourceProcessFunc func(ctx *Context, filename string, src interface{}) ([]byte, error)
   135  
   136  var (
   137  	sourceProcessor = make(map[string]SourceProcessFunc)
   138  )
   139  
   140  func (c *Context) LoadFile(fset *token.FileSet, filename string, src interface{}) (*ssa.Package, error) {
   141  	file, err := c.ParseFile(fset, filename, src)
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  	return c.LoadAstFile(fset, file)
   146  }
   147  
   148  func (c *Context) ParseFile(fset *token.FileSet, filename string, src interface{}) (*ast.File, error) {
   149  	if ext := filepath.Ext(filename); ext != "" {
   150  		if fn, ok := sourceProcessor[ext]; ok {
   151  			data, err := fn(c, filename, src)
   152  			if err != nil {
   153  				return nil, err
   154  			}
   155  			src = data
   156  		}
   157  	}
   158  	return parser.ParseFile(fset, filename, src, c.ParserMode)
   159  }
   160  
   161  func (c *Context) LoadAstFile(fset *token.FileSet, file *ast.File) (*ssa.Package, error) {
   162  	pkg := types.NewPackage(file.Name.Name, "")
   163  	files := []*ast.File{file}
   164  	if c.Mode&DisableCustomBuiltin == 0 {
   165  		if f, err := parserBuiltin(fset, file.Name.Name); err == nil {
   166  			files = []*ast.File{f, file}
   167  		}
   168  	}
   169  	ssapkg, _, err := c.BuildPackage(fset, pkg, files)
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  	ssapkg.Build()
   174  	return ssapkg, nil
   175  }
   176  
   177  func (c *Context) LoadAstPackage(fset *token.FileSet, apkg *ast.Package) (*ssa.Package, error) {
   178  	pkg := types.NewPackage(apkg.Name, "")
   179  	var files []*ast.File
   180  	for _, f := range apkg.Files {
   181  		files = append(files, f)
   182  	}
   183  	if c.Mode&DisableCustomBuiltin == 0 {
   184  		if f, err := parserBuiltin(fset, apkg.Name); err == nil {
   185  			files = append([]*ast.File{f}, files...)
   186  		}
   187  	}
   188  	ssapkg, _, err := c.BuildPackage(fset, pkg, files)
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  	ssapkg.Build()
   193  	return ssapkg, nil
   194  }
   195  
   196  func (c *Context) RunPkg(mainPkg *ssa.Package, input string, args []string) (exitCode int, err error) {
   197  	// reset os args and flag
   198  	os.Args = []string{input}
   199  	if args != nil {
   200  		os.Args = append(os.Args, args...)
   201  	}
   202  	flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
   203  
   204  	interp, err := c.NewInterp(mainPkg)
   205  	if err != nil {
   206  		return 2, err
   207  	}
   208  	if err = interp.RunInit(); err != nil {
   209  		return 2, err
   210  	}
   211  	return interp.RunMain()
   212  }
   213  
   214  func (c *Context) RunFunc(mainPkg *ssa.Package, fnname string, args ...Value) (ret Value, err error) {
   215  	interp, err := c.NewInterp(mainPkg)
   216  	if err != nil {
   217  		return nil, err
   218  	}
   219  	return interp.RunFunc(fnname, args...)
   220  }
   221  
   222  func (c *Context) NewInterp(mainPkg *ssa.Package) (*Interp, error) {
   223  	return NewInterp(c, mainPkg)
   224  }
   225  
   226  func (c *Context) TestPkg(pkgs []*ssa.Package, input string, args []string) error {
   227  	var failed bool
   228  	start := time.Now()
   229  	var testPkgs []*ssa.Package
   230  	for _, pkg := range pkgs {
   231  		p, err := CreateTestMainPackage(pkg)
   232  		if err != nil {
   233  			return err
   234  		}
   235  		if p != nil {
   236  			testPkgs = append(testPkgs, p)
   237  		}
   238  	}
   239  	defer func() {
   240  		sec := time.Since(start).Seconds()
   241  		if failed {
   242  			fmt.Printf("FAIL\t%s %0.3fs\n", input, sec)
   243  		} else {
   244  			fmt.Printf("ok\t%s %0.3fs\n", input, sec)
   245  		}
   246  	}()
   247  	if len(testPkgs) == 0 {
   248  		fmt.Println("testing: warning: no tests to run")
   249  	}
   250  	os.Args = []string{input}
   251  	if args != nil {
   252  		os.Args = append(os.Args, args...)
   253  	}
   254  	flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
   255  	for _, pkg := range testPkgs {
   256  		interp, err := NewInterp(c, pkg)
   257  		if err != nil {
   258  			failed = true
   259  			fmt.Printf("create interp failed: %v\n", err)
   260  			continue
   261  		}
   262  		if err = interp.RunInit(); err != nil {
   263  			failed = true
   264  			fmt.Printf("init error: %v\n", err)
   265  			continue
   266  		}
   267  		exitCode, _ := interp.RunMain()
   268  		if exitCode != 0 {
   269  			failed = true
   270  		}
   271  	}
   272  	if failed {
   273  		return ErrTestFailed
   274  	}
   275  	return nil
   276  }
   277  
   278  func (c *Context) RunFile(filename string, src interface{}, args []string) (exitCode int, err error) {
   279  	fset := token.NewFileSet()
   280  	pkg, err := c.LoadFile(fset, filename, src)
   281  	if err != nil {
   282  		return 2, err
   283  	}
   284  	return c.RunPkg(pkg, filename, args)
   285  }
   286  
   287  func (c *Context) Run(path string, args []string) (exitCode int, err error) {
   288  	if strings.HasSuffix(path, ".go") {
   289  		return c.RunFile(path, nil, args)
   290  	}
   291  	fset := token.NewFileSet()
   292  	pkgs, err := c.LoadDir(fset, path)
   293  	if err != nil {
   294  		return 2, err
   295  	}
   296  	mainPkgs := ssautil.MainPackages(pkgs)
   297  	if len(mainPkgs) == 0 {
   298  		return 2, ErrNotFoundMain
   299  	}
   300  	return c.RunPkg(mainPkgs[0], path, args)
   301  }
   302  
   303  func (c *Context) RunTest(path string, args []string) error {
   304  	fset := token.NewFileSet()
   305  	// preload regexp for create testing
   306  	c.Loader.Import("regexp")
   307  	pkgs, err := c.LoadDir(fset, path)
   308  	if err != nil {
   309  		return err
   310  	}
   311  	return c.TestPkg(pkgs, path, args)
   312  }
   313  
// BuildPackage type-checks files as package pkg and builds its SSA form.
//
// fset must be non-nil and pkg must have a non-empty import path. Either
// violation is treated as a programmer error and panics. Imports are
// resolved through NewImporter(ctx.Loader, ctx.External), and ctx.Sizes
// supplies type sizes. In eval mode the unused-import check is disabled.
//
// An SSA package is created from type information alone for every package
// pkg imports (directly or transitively) and for every package returned by
// ctx.Loader.Packages(). Incomplete packages are marked complete first. With
// EnableDumpInstr, each such package is printed as "# indirect" or
// "# imported", in traversal order.
//
// It returns the built SSA package and the types.Info recorded during type
// checking, or the type-checking error.
func (ctx *Context) BuildPackage(fset *token.FileSet, pkg *types.Package, files []*ast.File) (*ssa.Package, *types.Info, error) {
	if fset == nil {
		panic("no token.FileSet")
	}
	if pkg.Path() == "" {
		panic("package has no import path")
	}

	// Record the type-checker facts the SSA builder consumes.
	info := &types.Info{
		Types:      make(map[ast.Expr]types.TypeAndValue),
		Defs:       make(map[*ast.Ident]types.Object),
		Uses:       make(map[*ast.Ident]types.Object),
		Implicits:  make(map[ast.Node]types.Object),
		Scopes:     make(map[ast.Node]*types.Scope),
		Selections: make(map[*ast.SelectorExpr]*types.Selection),
	}

	tc := &types.Config{
		Importer: NewImporter(ctx.Loader, ctx.External),
		Sizes:    ctx.Sizes,
	}
	if ctx.evalMode {
		// Eval mode tolerates imports that are never used.
		tc.DisableUnusedImportCheck = true
	}
	if err := types.NewChecker(tc, fset, pkg, info).Files(files); err != nil {
		return nil, nil, err
	}

	prog := ssa.NewProgram(fset, ctx.BuilderMode)

	// Create SSA packages for all imports.
	// Order is not significant.
	// The created set guards against visiting a package twice while
	// recursing through Imports().
	created := make(map[*types.Package]bool)
	var createAll func(pkgs []*types.Package)
	createAll = func(pkgs []*types.Package) {
		for _, p := range pkgs {
			if !created[p] {
				created[p] = true
				if !p.Complete() {
					if ctx.Mode&EnableDumpInstr != 0 {
						fmt.Println("# indirect", p)
					}
					p.MarkComplete()
				} else {
					if ctx.Mode&EnableDumpInstr != 0 {
						fmt.Println("# imported", p)
					}
				}
				// No files or info: the package is built from type information only.
				prog.CreatePackage(p, nil, nil, true)
				createAll(p.Imports())
			}
		}
	}
	// create imports
	createAll(pkg.Imports())
	// create indirect depends
	createAll(ctx.Loader.Packages())

	// Create and build the primary package.
	ssapkg := prog.CreatePackage(pkg, files, info, false)
	ssapkg.Build()
	return ssapkg, info, nil
}
   377  
   378  func RunFile(filename string, src interface{}, args []string, mode Mode) (exitCode int, err error) {
   379  	reflectx.Reset()
   380  	ctx := NewContext(mode)
   381  	return ctx.RunFile(filename, src, args)
   382  }
   383  
   384  func Run(path string, args []string, mode Mode) (exitCode int, err error) {
   385  	reflectx.Reset()
   386  	ctx := NewContext(mode)
   387  	return ctx.Run(path, args)
   388  }
   389  
   390  func RunTest(path string, args []string, mode Mode) error {
   391  	reflectx.Reset()
   392  	ctx := NewContext(mode)
   393  	return ctx.RunTest(path, args)
   394  }
   395  
   396  var (
   397  	builtinPkg = &Package{
   398  		Name:          "builtin",
   399  		Path:          "github.com/goplus/gossa/builtin",
   400  		Deps:          make(map[string]string),
   401  		Interfaces:    map[string]reflect.Type{},
   402  		NamedTypes:    map[string]NamedType{},
   403  		AliasTypes:    map[string]reflect.Type{},
   404  		Vars:          map[string]reflect.Value{},
   405  		Funcs:         map[string]reflect.Value{},
   406  		TypedConsts:   map[string]TypedConst{},
   407  		UntypedConsts: map[string]UntypedConst{},
   408  	}
   409  	builtinPrefix = "Builtin_"
   410  )
   411  
   412  func init() {
   413  	RegisterPackage(builtinPkg)
   414  }
   415  
   416  func RegisterCustomBuiltin(key string, fn interface{}) error {
   417  	v := reflect.ValueOf(fn)
   418  	switch v.Kind() {
   419  	case reflect.Func:
   420  		if !strings.HasPrefix(key, builtinPrefix) {
   421  			key = builtinPrefix + key
   422  		}
   423  		builtinPkg.Funcs[key] = v
   424  		typ := v.Type()
   425  		for i := 0; i < typ.NumIn(); i++ {
   426  			checkBuiltinDeps(typ.In(i))
   427  		}
   428  		for i := 0; i < typ.NumOut(); i++ {
   429  			checkBuiltinDeps(typ.Out(i))
   430  		}
   431  		return nil
   432  	}
   433  	return ErrNoFunction
   434  }
   435  
   436  func checkBuiltinDeps(typ reflect.Type) {
   437  	if typ.PkgPath() != "" {
   438  		builtinPkg.Deps[typ.Name()] = typ.PkgPath()
   439  	}
   440  }
   441  
   442  var (
   443  	builtin_tmpl = `package main
   444  import "github.com/goplus/gossa/builtin"
   445  `
   446  )
   447  
   448  func parserBuiltin(fset *token.FileSet, pkg string) (*ast.File, error) {
   449  	var list []string
   450  	for k, _ := range builtinPkg.Funcs {
   451  		if strings.HasPrefix(k, builtinPrefix) {
   452  			list = append(list, k[len(builtinPrefix):]+"=builtin."+k)
   453  		}
   454  	}
   455  	if len(list) == 0 {
   456  		return nil, os.ErrInvalid
   457  	}
   458  	src := fmt.Sprintf(`package %v
   459  import "github.com/goplus/gossa/builtin"
   460  var (
   461  	%v
   462  )
   463  `, pkg, strings.Join(list, "\n"))
   464  	return parser.ParseFile(fset, "gossa_builtin.go", src, 0)
   465  }