github.com/llvm-mirror/llgo@v0.0.0-20190322182713-bf6f0a60fce1/cmd/llgoi/llgoi.go (about)

     1  //===- llgoi.go - llgo-based Go REPL --------------------------------------===//
     2  //
     3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
     4  // See https://llvm.org/LICENSE.txt for license information.
     5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
     6  //
     7  //===----------------------------------------------------------------------===//
     8  //
     9  // This is llgoi, a Go REPL based on llgo and the LLVM JIT.
    10  //
    11  //===----------------------------------------------------------------------===//
    12  
    13  package main
    14  
    15  import (
    16  	"bytes"
    17  	"errors"
    18  	"fmt"
    19  	"go/ast"
    20  	"go/build"
    21  	"go/parser"
    22  	"go/scanner"
    23  	"go/token"
    24  	"io"
    25  	"os"
    26  	"os/exec"
    27  	"path/filepath"
    28  	"runtime/debug"
    29  	"strconv"
    30  	"strings"
    31  	"unsafe"
    32  
    33  	"llvm.org/llgo/driver"
    34  	"llvm.org/llgo/irgen"
    35  	"llvm.org/llgo/third_party/gotools/go/types"
    36  	"llvm.org/llgo/third_party/liner"
    37  	"llvm.org/llvm/bindings/go/llvm"
    38  )
    39  
    40  // /* Force exporting __morestack if it's available, so that it is
    41  //    available to the engine when linking with libLLVM.so. */
    42  //
    43  // void *__morestack __attribute__((weak));
    44  import "C"
    45  
    46  func getInstPrefix() (string, error) {
    47  	path, err := exec.LookPath(os.Args[0])
    48  	if err != nil {
    49  		return "", err
    50  	}
    51  
    52  	path, err = filepath.EvalSymlinks(path)
    53  	if err != nil {
    54  		return "", err
    55  	}
    56  
    57  	prefix := filepath.Join(path, "..", "..")
    58  	return prefix, nil
    59  }
    60  
    61  func llvmVersion() string {
    62  	return strings.Replace(llvm.Version, "svn", "", 1)
    63  }
    64  
    65  type line struct {
    66  	line     string
    67  	isStmt   bool
    68  	declName string
    69  	assigns  []string
    70  
    71  	parens, bracks, braces int
    72  }
    73  
    74  type interp struct {
    75  	engine llvm.ExecutionEngine
    76  
    77  	liner       *liner.State
    78  	pendingLine line
    79  
    80  	copts irgen.CompilerOptions
    81  
    82  	imports []*types.Package
    83  	scope   map[string]types.Object
    84  
    85  	modules map[string]llvm.Module
    86  	pkgmap  map[string]*types.Package
    87  	pkgnum  int
    88  }
    89  
    90  func (in *interp) makeCompilerOptions() error {
    91  	prefix, err := getInstPrefix()
    92  	if err != nil {
    93  		return err
    94  	}
    95  
    96  	importPaths := []string{filepath.Join(prefix, "lib", "go", "llgo-"+llvmVersion())}
    97  	in.copts = irgen.CompilerOptions{
    98  		TargetTriple:  llvm.DefaultTargetTriple(),
    99  		ImportPaths:   importPaths,
   100  		GenerateDebug: true,
   101  		Packages:      in.pkgmap,
   102  	}
   103  	err = in.copts.MakeImporter()
   104  	if err != nil {
   105  		return err
   106  	}
   107  
   108  	origImporter := in.copts.Importer
   109  	in.copts.Importer = func(pkgmap map[string]*types.Package, pkgpath string) (*types.Package, error) {
   110  		if pkg, ok := pkgmap[pkgpath]; ok && pkg.Complete() {
   111  			return pkg, nil
   112  		}
   113  		return origImporter(pkgmap, pkgpath)
   114  	}
   115  	return nil
   116  }
   117  
   118  func (in *interp) init() error {
   119  	in.liner = liner.NewLiner()
   120  	in.scope = make(map[string]types.Object)
   121  	in.pkgmap = make(map[string]*types.Package)
   122  	in.modules = make(map[string]llvm.Module)
   123  
   124  	err := in.makeCompilerOptions()
   125  	if err != nil {
   126  		return err
   127  	}
   128  
   129  	return nil
   130  }
   131  
   132  func (in *interp) dispose() {
   133  	in.liner.Close()
   134  	in.engine.Dispose()
   135  }
   136  
   137  func (in *interp) loadSourcePackageFromCode(pkgcode, pkgpath string, copts irgen.CompilerOptions) (*types.Package, error) {
   138  	fset := token.NewFileSet()
   139  	file, err := parser.ParseFile(fset, "<input>", pkgcode, parser.DeclarationErrors|parser.ParseComments)
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  	files := []*ast.File{file}
   144  	return in.loadSourcePackage(fset, files, pkgpath, copts)
   145  }
   146  
   147  func (in *interp) loadSourcePackage(fset *token.FileSet, files []*ast.File, pkgpath string, copts irgen.CompilerOptions) (_ *types.Package, resultErr error) {
   148  	compiler, err := irgen.NewCompiler(copts)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	module, err := compiler.Compile(fset, files, pkgpath)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  	in.modules[pkgpath] = module.Module
   158  
   159  	if in.engine.C != nil {
   160  		in.engine.AddModule(module.Module)
   161  	} else {
   162  		options := llvm.NewMCJITCompilerOptions()
   163  		in.engine, err = llvm.NewMCJITCompiler(module.Module, options)
   164  		if err != nil {
   165  			return nil, err
   166  		}
   167  	}
   168  
   169  	var importFunc func()
   170  	importAddress := in.getPackageSymbol(pkgpath, ".import$descriptor")
   171  	*(*unsafe.Pointer)(unsafe.Pointer(&importFunc)) = importAddress
   172  
   173  	defer func() {
   174  		p := recover()
   175  		if p != nil {
   176  			resultErr = fmt.Errorf("panic: %v\n%v", p, string(debug.Stack()))
   177  		}
   178  	}()
   179  	importFunc()
   180  	in.pkgmap[pkgpath] = module.Package
   181  
   182  	return module.Package, nil
   183  }
   184  
   185  func (in *interp) getPackageSymbol(pkgpath, name string) unsafe.Pointer {
   186  	symbolName := irgen.ManglePackagePath(pkgpath) + "." + name
   187  	global := in.modules[pkgpath].NamedGlobal(symbolName)
   188  	if global.IsNil() {
   189  		return nil
   190  	}
   191  	return in.engine.PointerToGlobal(global)
   192  }
   193  
   194  func (in *interp) augmentPackageScope(pkg *types.Package) {
   195  	for _, obj := range in.scope {
   196  		pkg.Scope().Insert(obj)
   197  	}
   198  }
   199  
   200  func (l *line) append(str string, assigns []string) {
   201  	var s scanner.Scanner
   202  	fset := token.NewFileSet()
   203  	file := fset.AddFile("", fset.Base(), len(str))
   204  	s.Init(file, []byte(str), nil, 0)
   205  
   206  	_, tok, _ := s.Scan()
   207  	if l.line == "" {
   208  		switch tok {
   209  		case token.FOR, token.GO, token.IF, token.LBRACE, token.SELECT, token.SWITCH:
   210  			l.isStmt = true
   211  		case token.CONST, token.FUNC, token.TYPE, token.VAR:
   212  			var lit string
   213  			_, tok, lit = s.Scan()
   214  			if tok == token.IDENT {
   215  				l.declName = lit
   216  			}
   217  		}
   218  	}
   219  
   220  	for tok != token.EOF {
   221  		switch tok {
   222  		case token.LPAREN:
   223  			l.parens++
   224  		case token.RPAREN:
   225  			l.parens--
   226  		case token.LBRACE:
   227  			l.braces++
   228  		case token.RBRACE:
   229  			l.braces--
   230  		case token.LBRACK:
   231  			l.bracks++
   232  		case token.RBRACK:
   233  			l.bracks--
   234  		case token.DEC, token.INC,
   235  			token.ASSIGN, token.ADD_ASSIGN, token.SUB_ASSIGN,
   236  			token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN,
   237  			token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN,
   238  			token.SHL_ASSIGN, token.SHR_ASSIGN, token.AND_NOT_ASSIGN:
   239  			if l.parens == 0 && l.bracks == 0 && l.braces == 0 {
   240  				l.isStmt = true
   241  			}
   242  		}
   243  		_, tok, _ = s.Scan()
   244  	}
   245  
   246  	if l.line == "" {
   247  		l.assigns = assigns
   248  	}
   249  	l.line += str
   250  }
   251  
   252  func (l *line) ready() bool {
   253  	return l.parens <= 0 && l.bracks <= 0 && l.braces <= 0
   254  }
   255  
   256  func (in *interp) readExprLine(str string, assigns []string) ([]interface{}, error) {
   257  	in.pendingLine.append(str, assigns)
   258  	if !in.pendingLine.ready() {
   259  		return nil, nil
   260  	}
   261  	results, err := in.interpretLine(in.pendingLine)
   262  	in.pendingLine = line{}
   263  	return results, err
   264  }
   265  
   266  func (in *interp) interpretLine(l line) ([]interface{}, error) {
   267  	pkgname := fmt.Sprintf("input%05d", in.pkgnum)
   268  	in.pkgnum++
   269  
   270  	pkg := types.NewPackage(pkgname, pkgname)
   271  	scope := pkg.Scope()
   272  
   273  	for _, imppkg := range in.imports {
   274  		obj := types.NewPkgName(token.NoPos, pkg, imppkg.Name(), imppkg)
   275  		scope.Insert(obj)
   276  	}
   277  
   278  	in.augmentPackageScope(pkg)
   279  
   280  	var tv types.TypeAndValue
   281  	if l.declName == "" && !l.isStmt {
   282  		var err error
   283  		tv, err = types.Eval(l.line, pkg, scope)
   284  		if err != nil {
   285  			return nil, err
   286  		}
   287  	}
   288  
   289  	var code bytes.Buffer
   290  	fmt.Fprintf(&code, "package %s\n", pkgname)
   291  
   292  	for _, pkg := range in.imports {
   293  		fmt.Fprintf(&code, "import %q\n", pkg.Path())
   294  	}
   295  
   296  	if l.declName != "" {
   297  		code.WriteString(l.line)
   298  	} else if !l.isStmt && tv.IsValue() {
   299  		var typs []types.Type
   300  		if tuple, ok := tv.Type.(*types.Tuple); ok {
   301  			typs = make([]types.Type, tuple.Len())
   302  			for i := range typs {
   303  				typs[i] = tuple.At(i).Type()
   304  			}
   305  		} else {
   306  			typs = []types.Type{tv.Type}
   307  		}
   308  		if len(l.assigns) == 2 && tv.HasOk() {
   309  			typs = append(typs, types.Typ[types.Bool])
   310  		}
   311  		if len(l.assigns) != 0 && len(l.assigns) != len(typs) {
   312  			return nil, errors.New("return value mismatch")
   313  		}
   314  
   315  		code.WriteString("var ")
   316  		for i := range typs {
   317  			if i != 0 {
   318  				code.WriteString(", ")
   319  			}
   320  			if len(l.assigns) != 0 && l.assigns[i] != "" {
   321  				if _, ok := in.scope[l.assigns[i]]; ok {
   322  					fmt.Fprintf(&code, "__llgoiV%d", i)
   323  				} else {
   324  					code.WriteString(l.assigns[i])
   325  				}
   326  			} else {
   327  				fmt.Fprintf(&code, "__llgoiV%d", i)
   328  			}
   329  		}
   330  		fmt.Fprintf(&code, " = %s\n", l.line)
   331  
   332  		code.WriteString("func init() {\n")
   333  		varnames := make([]string, len(typs))
   334  		for i := range typs {
   335  			var varname string
   336  			if len(l.assigns) != 0 && l.assigns[i] != "" {
   337  				if _, ok := in.scope[l.assigns[i]]; ok {
   338  					fmt.Fprintf(&code, "\t%s = __llgoiV%d\n", l.assigns[i], i)
   339  				}
   340  				varname = l.assigns[i]
   341  			} else {
   342  				varname = fmt.Sprintf("__llgoiV%d", i)
   343  			}
   344  			varnames[i] = varname
   345  		}
   346  		code.WriteString("}\n\n")
   347  
   348  		code.WriteString("func __llgoiResults() []interface{} {\n")
   349  		code.WriteString("\treturn []interface{}{\n")
   350  		for _, varname := range varnames {
   351  			fmt.Fprintf(&code, "\t\t%s,\n", varname)
   352  		}
   353  		code.WriteString("\t}\n")
   354  		code.WriteString("}\n")
   355  	} else {
   356  		if len(l.assigns) != 0 {
   357  			return nil, errors.New("return value mismatch")
   358  		}
   359  
   360  		fmt.Fprintf(&code, "func init() {\n\t%s}", l.line)
   361  	}
   362  
   363  	copts := in.copts
   364  	copts.PackageCreated = in.augmentPackageScope
   365  	copts.DisableUnusedImportCheck = true
   366  	pkg, err := in.loadSourcePackageFromCode(code.String(), pkgname, copts)
   367  	if err != nil {
   368  		return nil, err
   369  	}
   370  	in.imports = append(in.imports, pkg)
   371  
   372  	var results []interface{}
   373  	llgoiResultsAddress := in.getPackageSymbol(pkgname, "__llgoiResults$descriptor")
   374  	if llgoiResultsAddress != nil {
   375  		var resultsFunc func() []interface{}
   376  		*(*unsafe.Pointer)(unsafe.Pointer(&resultsFunc)) = llgoiResultsAddress
   377  		results = resultsFunc()
   378  	}
   379  
   380  	for _, assign := range l.assigns {
   381  		if assign != "" {
   382  			if _, ok := in.scope[assign]; !ok {
   383  				in.scope[assign] = pkg.Scope().Lookup(assign)
   384  			}
   385  		}
   386  	}
   387  
   388  	if l.declName != "" {
   389  		in.scope[l.declName] = pkg.Scope().Lookup(l.declName)
   390  	}
   391  
   392  	return results, nil
   393  }
   394  
   395  func (in *interp) maybeReadAssignment(line string, s *scanner.Scanner, initial string, base int) (bool, error) {
   396  	if initial == "_" {
   397  		initial = ""
   398  	}
   399  	assigns := []string{initial}
   400  
   401  	pos, tok, lit := s.Scan()
   402  	for tok == token.COMMA {
   403  		pos, tok, lit = s.Scan()
   404  		if tok != token.IDENT {
   405  			return false, nil
   406  		}
   407  
   408  		if lit == "_" {
   409  			lit = ""
   410  		}
   411  		assigns = append(assigns, lit)
   412  
   413  		pos, tok, lit = s.Scan()
   414  	}
   415  
   416  	if tok != token.DEFINE {
   417  		return false, nil
   418  	}
   419  
   420  	// It's an assignment statement, there are no results.
   421  	_, err := in.readExprLine(line[int(pos)-base+2:], assigns)
   422  	return true, err
   423  }
   424  
   425  func (in *interp) loadPackage(pkgpath string) (*types.Package, error) {
   426  	pkg, err := in.copts.Importer(in.pkgmap, pkgpath)
   427  	if err == nil {
   428  		return pkg, nil
   429  	}
   430  
   431  	buildpkg, err := build.Import(pkgpath, ".", 0)
   432  	if err != nil {
   433  		return nil, err
   434  	}
   435  	if len(buildpkg.CgoFiles) != 0 {
   436  		return nil, fmt.Errorf("%s: cannot load cgo package", pkgpath)
   437  	}
   438  
   439  	for _, imp := range buildpkg.Imports {
   440  		_, err := in.loadPackage(imp)
   441  		if err != nil {
   442  			return nil, err
   443  		}
   444  	}
   445  
   446  	inputs := make([]string, len(buildpkg.GoFiles))
   447  	for i, file := range buildpkg.GoFiles {
   448  		inputs[i] = filepath.Join(buildpkg.Dir, file)
   449  	}
   450  
   451  	fset := token.NewFileSet()
   452  	files, err := driver.ParseFiles(fset, inputs)
   453  	if err != nil {
   454  		return nil, err
   455  	}
   456  
   457  	return in.loadSourcePackage(fset, files, pkgpath, in.copts)
   458  }
   459  
   460  // readLine accumulates lines of input, including trailing newlines,
   461  // executing statements as they are completed.
   462  func (in *interp) readLine(line string) ([]interface{}, error) {
   463  	if !in.pendingLine.ready() {
   464  		return in.readExprLine(line, nil)
   465  	}
   466  
   467  	var s scanner.Scanner
   468  	fset := token.NewFileSet()
   469  	file := fset.AddFile("", fset.Base(), len(line))
   470  	s.Init(file, []byte(line), nil, 0)
   471  
   472  	_, tok, lit := s.Scan()
   473  	switch tok {
   474  	case token.EOF:
   475  		return nil, nil
   476  
   477  	case token.IMPORT:
   478  		_, tok, lit = s.Scan()
   479  		if tok != token.STRING {
   480  			return nil, errors.New("expected string literal")
   481  		}
   482  		pkgpath, err := strconv.Unquote(lit)
   483  		if err != nil {
   484  			return nil, err
   485  		}
   486  		pkg, err := in.loadPackage(pkgpath)
   487  		if err != nil {
   488  			return nil, err
   489  		}
   490  		in.imports = append(in.imports, pkg)
   491  		return nil, nil
   492  
   493  	case token.IDENT:
   494  		ok, err := in.maybeReadAssignment(line, &s, lit, file.Base())
   495  		if err != nil {
   496  			return nil, err
   497  		}
   498  		if ok {
   499  			return nil, nil
   500  		}
   501  		fallthrough
   502  
   503  	default:
   504  		return in.readExprLine(line, nil)
   505  	}
   506  }
   507  
   508  // printResult prints a value that was the result of an expression evaluated
   509  // by the interpreter.
   510  func printResult(w io.Writer, v interface{}) {
   511  	// TODO the result should be formatted in Go syntax, without
   512  	// package qualifiers for types defined within the interpreter.
   513  	fmt.Fprintf(w, "%+v", v)
   514  }
   515  
   516  // formatHistory reformats the provided Go source by collapsing all lines
   517  // and adding semicolons where required, suitable for adding to line history.
   518  func formatHistory(input []byte) string {
   519  	var buf bytes.Buffer
   520  	var s scanner.Scanner
   521  	fset := token.NewFileSet()
   522  	file := fset.AddFile("", fset.Base(), len(input))
   523  	s.Init(file, input, nil, 0)
   524  	pos, tok, lit := s.Scan()
   525  	for tok != token.EOF {
   526  		if int(pos)-1 > buf.Len() {
   527  			n := int(pos) - 1 - buf.Len()
   528  			buf.WriteString(strings.Repeat(" ", n))
   529  		}
   530  		var semicolon bool
   531  		if tok == token.SEMICOLON {
   532  			semicolon = true
   533  		} else if lit != "" {
   534  			buf.WriteString(lit)
   535  		} else {
   536  			buf.WriteString(tok.String())
   537  		}
   538  		pos, tok, lit = s.Scan()
   539  		if semicolon {
   540  			switch tok {
   541  			case token.RBRACE, token.RPAREN, token.EOF:
   542  			default:
   543  				buf.WriteRune(';')
   544  			}
   545  		}
   546  	}
   547  	return buf.String()
   548  }
   549  
   550  func main() {
   551  	llvm.LinkInMCJIT()
   552  	llvm.InitializeNativeTarget()
   553  	llvm.InitializeNativeAsmPrinter()
   554  
   555  	var in interp
   556  	err := in.init()
   557  	if err != nil {
   558  		panic(err)
   559  	}
   560  	defer in.dispose()
   561  
   562  	var buf bytes.Buffer
   563  	for {
   564  		if in.pendingLine.ready() && buf.Len() > 0 {
   565  			history := formatHistory(buf.Bytes())
   566  			in.liner.AppendHistory(history)
   567  			buf.Reset()
   568  		}
   569  		prompt := "(llgo) "
   570  		if !in.pendingLine.ready() {
   571  			prompt = strings.Repeat(" ", len(prompt))
   572  		}
   573  		line, err := in.liner.Prompt(prompt)
   574  		if err == io.EOF {
   575  			break
   576  		} else if err != nil {
   577  			panic(err)
   578  		}
   579  		if line == "" {
   580  			continue
   581  		}
   582  		buf.WriteString(line + "\n")
   583  		results, err := in.readLine(line + "\n")
   584  		if err != nil {
   585  			fmt.Println(err)
   586  		}
   587  		for _, result := range results {
   588  			printResult(os.Stdout, result)
   589  			fmt.Println()
   590  		}
   591  	}
   592  
   593  	if liner.TerminalSupported() {
   594  		fmt.Println()
   595  	}
   596  }