github.com/axw/llgo@v0.0.0-20160805011314-95b5fe4dca20/cmd/llgoi/llgoi.go (about)

     1  //===- llgoi.go - llgo-based Go REPL --------------------------------------===//
     2  //
     3  //                     The LLVM Compiler Infrastructure
     4  //
     5  // This file is distributed under the University of Illinois Open Source
     6  // License. See LICENSE.TXT for details.
     7  //
     8  //===----------------------------------------------------------------------===//
     9  //
    10  // This is llgoi, a Go REPL based on llgo and the LLVM JIT.
    11  //
    12  //===----------------------------------------------------------------------===//
    13  
    14  package main
    15  
    16  import (
    17  	"bytes"
    18  	"errors"
    19  	"fmt"
    20  	"go/ast"
    21  	"go/build"
    22  	"go/parser"
    23  	"go/scanner"
    24  	"go/token"
    25  	"io"
    26  	"os"
    27  	"os/exec"
    28  	"path/filepath"
    29  	"runtime/debug"
    30  	"strconv"
    31  	"strings"
    32  	"unsafe"
    33  
    34  	"llvm.org/llgo/driver"
    35  	"llvm.org/llgo/irgen"
    36  	"llvm.org/llgo/third_party/gotools/go/types"
    37  	"llvm.org/llgo/third_party/liner"
    38  	"llvm.org/llvm/bindings/go/llvm"
    39  )
    40  
    41  // /* Force exporting __morestack if it's available, so that it is
    42  //    available to the engine when linking with libLLVM.so. */
    43  //
    44  // void *__morestack __attribute__((weak));
    45  import "C"
    46  
    47  func getInstPrefix() (string, error) {
    48  	path, err := exec.LookPath(os.Args[0])
    49  	if err != nil {
    50  		return "", err
    51  	}
    52  
    53  	path, err = filepath.EvalSymlinks(path)
    54  	if err != nil {
    55  		return "", err
    56  	}
    57  
    58  	prefix := filepath.Join(path, "..", "..")
    59  	return prefix, nil
    60  }
    61  
    62  func llvmVersion() string {
    63  	return strings.Replace(llvm.Version, "svn", "", 1)
    64  }
    65  
// line is one logical REPL input being accumulated, possibly over several
// physical lines, together with what append has learned about it so far.
type line struct {
	// line is the accumulated source text, trailing newlines included.
	line     string
	// isStmt reports that the input has to be compiled as a statement
	// (inside an init function) rather than evaluated as an expression.
	isStmt   bool
	// declName is the identifier that follows a leading const/func/type/var
	// keyword, or "" if there is none.
	declName string
	// assigns holds the left-hand names of a leading "a, b :=" definition;
	// "" stands for the blank identifier "_".
	assigns  []string

	// parens, bracks and braces are the current nesting depths of (), []
	// and {}; the input is complete once none of them is positive.
	parens, bracks, braces int
}
    74  
// interp holds the state of one REPL session.
type interp struct {
	// engine is the MCJIT execution engine. It stays zero (engine.C == nil)
	// until the first package is loaded; later modules are added to it.
	engine llvm.ExecutionEngine

	// liner prompts for input and keeps the line history.
	liner       *liner.State
	// pendingLine accumulates a multi-line input until its brackets balance.
	pendingLine line

	// copts are the base compiler options built by makeCompilerOptions.
	copts irgen.CompilerOptions

	// imports lists the packages imported into every new input: packages
	// named by `import` lines and every previously interpreted input package.
	imports []*types.Package
	// scope maps names defined by earlier inputs to their objects; they are
	// inserted into the scope of each new input package.
	scope   map[string]types.Object

	// modules maps a package path to the LLVM module compiled for it.
	modules map[string]llvm.Module
	// pkgmap maps a package path to its loaded package; it is also handed to
	// the compiler as CompilerOptions.Packages.
	pkgmap  map[string]*types.Package
	// pkgnum numbers the synthesized input packages ("input%05d").
	pkgnum  int
}
    90  
    91  func (in *interp) makeCompilerOptions() error {
    92  	prefix, err := getInstPrefix()
    93  	if err != nil {
    94  		return err
    95  	}
    96  
    97  	importPaths := []string{filepath.Join(prefix, "lib", "go", "llgo-"+llvmVersion())}
    98  	in.copts = irgen.CompilerOptions{
    99  		TargetTriple:  llvm.DefaultTargetTriple(),
   100  		ImportPaths:   importPaths,
   101  		GenerateDebug: true,
   102  		Packages:      in.pkgmap,
   103  	}
   104  	err = in.copts.MakeImporter()
   105  	if err != nil {
   106  		return err
   107  	}
   108  
   109  	origImporter := in.copts.Importer
   110  	in.copts.Importer = func(pkgmap map[string]*types.Package, pkgpath string) (*types.Package, error) {
   111  		if pkg, ok := pkgmap[pkgpath]; ok && pkg.Complete() {
   112  			return pkg, nil
   113  		}
   114  		return origImporter(pkgmap, pkgpath)
   115  	}
   116  	return nil
   117  }
   118  
   119  func (in *interp) init() error {
   120  	in.liner = liner.NewLiner()
   121  	in.scope = make(map[string]types.Object)
   122  	in.pkgmap = make(map[string]*types.Package)
   123  	in.modules = make(map[string]llvm.Module)
   124  
   125  	err := in.makeCompilerOptions()
   126  	if err != nil {
   127  		return err
   128  	}
   129  
   130  	return nil
   131  }
   132  
   133  func (in *interp) dispose() {
   134  	in.liner.Close()
   135  	in.engine.Dispose()
   136  }
   137  
   138  func (in *interp) loadSourcePackageFromCode(pkgcode, pkgpath string, copts irgen.CompilerOptions) (*types.Package, error) {
   139  	fset := token.NewFileSet()
   140  	file, err := parser.ParseFile(fset, "<input>", pkgcode, parser.DeclarationErrors|parser.ParseComments)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  	files := []*ast.File{file}
   145  	return in.loadSourcePackage(fset, files, pkgpath, copts)
   146  }
   147  
// loadSourcePackage compiles files as package pkgpath using copts and adds the
// resulting LLVM module to the JIT. If no engine exists yet, it creates the
// MCJIT engine. It then runs the package's ".import" entry point and, if that
// returns normally, records the package in in.pkgmap.
//
// A panic raised while running that entry point is recovered. It is returned
// as resultErr together with a stack trace instead of crashing the REPL.
func (in *interp) loadSourcePackage(fset *token.FileSet, files []*ast.File, pkgpath string, copts irgen.CompilerOptions) (_ *types.Package, resultErr error) {
	compiler, err := irgen.NewCompiler(copts)
	if err != nil {
		return nil, err
	}

	module, err := compiler.Compile(fset, files, pkgpath)
	if err != nil {
		return nil, err
	}
	// Remembered so that getPackageSymbol can find this package's globals.
	in.modules[pkgpath] = module.Module

	// The execution engine is created lazily from the first module loaded;
	// every later module is added to that same engine.
	if in.engine.C != nil {
		in.engine.AddModule(module.Module)
	} else {
		options := llvm.NewMCJITCompilerOptions()
		in.engine, err = llvm.NewMCJITCompiler(module.Module, options)
		if err != nil {
			return nil, err
		}
	}

	// Turn the address of the ".import$descriptor" global into a callable Go
	// func value by overwriting the func variable's representation.
	// NOTE(review): this assumes a func value is a pointer to such a
	// descriptor (llgo/gccgo ABI) — confirm if the ABI changes.
	var importFunc func()
	importAddress := in.getPackageSymbol(pkgpath, ".import$descriptor")
	*(*unsafe.Pointer)(unsafe.Pointer(&importFunc)) = importAddress

	// Installed only at this point, so only panics raised from here on
	// (i.e. by the JIT-compiled import code) become an error.
	defer func() {
		p := recover()
		if p != nil {
			resultErr = fmt.Errorf("panic: %v\n%v", p, string(debug.Stack()))
		}
	}()
	importFunc()
	// Reached only if the import function returned normally.
	in.pkgmap[pkgpath] = module.Package

	return module.Package, nil
}
   185  
   186  func (in *interp) getPackageSymbol(pkgpath, name string) unsafe.Pointer {
   187  	symbolName := irgen.ManglePackagePath(pkgpath) + "." + name
   188  	global := in.modules[pkgpath].NamedGlobal(symbolName)
   189  	if global.IsNil() {
   190  		return nil
   191  	}
   192  	return in.engine.PointerToGlobal(global)
   193  }
   194  
   195  func (in *interp) augmentPackageScope(pkg *types.Package) {
   196  	for _, obj := range in.scope {
   197  		pkg.Scope().Insert(obj)
   198  	}
   199  }
   200  
   201  func (l *line) append(str string, assigns []string) {
   202  	var s scanner.Scanner
   203  	fset := token.NewFileSet()
   204  	file := fset.AddFile("", fset.Base(), len(str))
   205  	s.Init(file, []byte(str), nil, 0)
   206  
   207  	_, tok, _ := s.Scan()
   208  	if l.line == "" {
   209  		switch tok {
   210  		case token.FOR, token.GO, token.IF, token.LBRACE, token.SELECT, token.SWITCH:
   211  			l.isStmt = true
   212  		case token.CONST, token.FUNC, token.TYPE, token.VAR:
   213  			var lit string
   214  			_, tok, lit = s.Scan()
   215  			if tok == token.IDENT {
   216  				l.declName = lit
   217  			}
   218  		}
   219  	}
   220  
   221  	for tok != token.EOF {
   222  		switch tok {
   223  		case token.LPAREN:
   224  			l.parens++
   225  		case token.RPAREN:
   226  			l.parens--
   227  		case token.LBRACE:
   228  			l.braces++
   229  		case token.RBRACE:
   230  			l.braces--
   231  		case token.LBRACK:
   232  			l.bracks++
   233  		case token.RBRACK:
   234  			l.bracks--
   235  		case token.DEC, token.INC,
   236  			token.ASSIGN, token.ADD_ASSIGN, token.SUB_ASSIGN,
   237  			token.MUL_ASSIGN, token.QUO_ASSIGN, token.REM_ASSIGN,
   238  			token.AND_ASSIGN, token.OR_ASSIGN, token.XOR_ASSIGN,
   239  			token.SHL_ASSIGN, token.SHR_ASSIGN, token.AND_NOT_ASSIGN:
   240  			if l.parens == 0 && l.bracks == 0 && l.braces == 0 {
   241  				l.isStmt = true
   242  			}
   243  		}
   244  		_, tok, _ = s.Scan()
   245  	}
   246  
   247  	if l.line == "" {
   248  		l.assigns = assigns
   249  	}
   250  	l.line += str
   251  }
   252  
   253  func (l *line) ready() bool {
   254  	return l.parens <= 0 && l.bracks <= 0 && l.braces <= 0
   255  }
   256  
   257  func (in *interp) readExprLine(str string, assigns []string) ([]interface{}, error) {
   258  	in.pendingLine.append(str, assigns)
   259  	if !in.pendingLine.ready() {
   260  		return nil, nil
   261  	}
   262  	results, err := in.interpretLine(in.pendingLine)
   263  	in.pendingLine = line{}
   264  	return results, err
   265  }
   266  
   267  func (in *interp) interpretLine(l line) ([]interface{}, error) {
   268  	pkgname := fmt.Sprintf("input%05d", in.pkgnum)
   269  	in.pkgnum++
   270  
   271  	pkg := types.NewPackage(pkgname, pkgname)
   272  	scope := pkg.Scope()
   273  
   274  	for _, imppkg := range in.imports {
   275  		obj := types.NewPkgName(token.NoPos, pkg, imppkg.Name(), imppkg)
   276  		scope.Insert(obj)
   277  	}
   278  
   279  	in.augmentPackageScope(pkg)
   280  
   281  	var tv types.TypeAndValue
   282  	if l.declName == "" && !l.isStmt {
   283  		var err error
   284  		tv, err = types.Eval(l.line, pkg, scope)
   285  		if err != nil {
   286  			return nil, err
   287  		}
   288  	}
   289  
   290  	var code bytes.Buffer
   291  	fmt.Fprintf(&code, "package %s\n", pkgname)
   292  
   293  	for _, pkg := range in.imports {
   294  		fmt.Fprintf(&code, "import %q\n", pkg.Path())
   295  	}
   296  
   297  	if l.declName != "" {
   298  		code.WriteString(l.line)
   299  	} else if !l.isStmt && tv.IsValue() {
   300  		var typs []types.Type
   301  		if tuple, ok := tv.Type.(*types.Tuple); ok {
   302  			typs = make([]types.Type, tuple.Len())
   303  			for i := range typs {
   304  				typs[i] = tuple.At(i).Type()
   305  			}
   306  		} else {
   307  			typs = []types.Type{tv.Type}
   308  		}
   309  		if len(l.assigns) == 2 && tv.HasOk() {
   310  			typs = append(typs, types.Typ[types.Bool])
   311  		}
   312  		if len(l.assigns) != 0 && len(l.assigns) != len(typs) {
   313  			return nil, errors.New("return value mismatch")
   314  		}
   315  
   316  		code.WriteString("var ")
   317  		for i := range typs {
   318  			if i != 0 {
   319  				code.WriteString(", ")
   320  			}
   321  			if len(l.assigns) != 0 && l.assigns[i] != "" {
   322  				if _, ok := in.scope[l.assigns[i]]; ok {
   323  					fmt.Fprintf(&code, "__llgoiV%d", i)
   324  				} else {
   325  					code.WriteString(l.assigns[i])
   326  				}
   327  			} else {
   328  				fmt.Fprintf(&code, "__llgoiV%d", i)
   329  			}
   330  		}
   331  		fmt.Fprintf(&code, " = %s\n", l.line)
   332  
   333  		code.WriteString("func init() {\n")
   334  		varnames := make([]string, len(typs))
   335  		for i := range typs {
   336  			var varname string
   337  			if len(l.assigns) != 0 && l.assigns[i] != "" {
   338  				if _, ok := in.scope[l.assigns[i]]; ok {
   339  					fmt.Fprintf(&code, "\t%s = __llgoiV%d\n", l.assigns[i], i)
   340  				}
   341  				varname = l.assigns[i]
   342  			} else {
   343  				varname = fmt.Sprintf("__llgoiV%d", i)
   344  			}
   345  			varnames[i] = varname
   346  		}
   347  		code.WriteString("}\n\n")
   348  
   349  		code.WriteString("func __llgoiResults() []interface{} {\n")
   350  		code.WriteString("\treturn []interface{}{\n")
   351  		for _, varname := range varnames {
   352  			fmt.Fprintf(&code, "\t\t%s,\n", varname)
   353  		}
   354  		code.WriteString("\t}\n")
   355  		code.WriteString("}\n")
   356  	} else {
   357  		if len(l.assigns) != 0 {
   358  			return nil, errors.New("return value mismatch")
   359  		}
   360  
   361  		fmt.Fprintf(&code, "func init() {\n\t%s}", l.line)
   362  	}
   363  
   364  	copts := in.copts
   365  	copts.PackageCreated = in.augmentPackageScope
   366  	copts.DisableUnusedImportCheck = true
   367  	pkg, err := in.loadSourcePackageFromCode(code.String(), pkgname, copts)
   368  	if err != nil {
   369  		return nil, err
   370  	}
   371  	in.imports = append(in.imports, pkg)
   372  
   373  	var results []interface{}
   374  	llgoiResultsAddress := in.getPackageSymbol(pkgname, "__llgoiResults$descriptor")
   375  	if llgoiResultsAddress != nil {
   376  		var resultsFunc func() []interface{}
   377  		*(*unsafe.Pointer)(unsafe.Pointer(&resultsFunc)) = llgoiResultsAddress
   378  		results = resultsFunc()
   379  	}
   380  
   381  	for _, assign := range l.assigns {
   382  		if assign != "" {
   383  			if _, ok := in.scope[assign]; !ok {
   384  				in.scope[assign] = pkg.Scope().Lookup(assign)
   385  			}
   386  		}
   387  	}
   388  
   389  	if l.declName != "" {
   390  		in.scope[l.declName] = pkg.Scope().Lookup(l.declName)
   391  	}
   392  
   393  	return results, nil
   394  }
   395  
   396  func (in *interp) maybeReadAssignment(line string, s *scanner.Scanner, initial string, base int) (bool, error) {
   397  	if initial == "_" {
   398  		initial = ""
   399  	}
   400  	assigns := []string{initial}
   401  
   402  	pos, tok, lit := s.Scan()
   403  	for tok == token.COMMA {
   404  		pos, tok, lit = s.Scan()
   405  		if tok != token.IDENT {
   406  			return false, nil
   407  		}
   408  
   409  		if lit == "_" {
   410  			lit = ""
   411  		}
   412  		assigns = append(assigns, lit)
   413  
   414  		pos, tok, lit = s.Scan()
   415  	}
   416  
   417  	if tok != token.DEFINE {
   418  		return false, nil
   419  	}
   420  
   421  	// It's an assignment statement, there are no results.
   422  	_, err := in.readExprLine(line[int(pos)-base+2:], assigns)
   423  	return true, err
   424  }
   425  
   426  func (in *interp) loadPackage(pkgpath string) (*types.Package, error) {
   427  	pkg, err := in.copts.Importer(in.pkgmap, pkgpath)
   428  	if err == nil {
   429  		return pkg, nil
   430  	}
   431  
   432  	buildpkg, err := build.Import(pkgpath, ".", 0)
   433  	if err != nil {
   434  		return nil, err
   435  	}
   436  	if len(buildpkg.CgoFiles) != 0 {
   437  		return nil, fmt.Errorf("%s: cannot load cgo package", pkgpath)
   438  	}
   439  
   440  	for _, imp := range buildpkg.Imports {
   441  		_, err := in.loadPackage(imp)
   442  		if err != nil {
   443  			return nil, err
   444  		}
   445  	}
   446  
   447  	inputs := make([]string, len(buildpkg.GoFiles))
   448  	for i, file := range buildpkg.GoFiles {
   449  		inputs[i] = filepath.Join(buildpkg.Dir, file)
   450  	}
   451  
   452  	fset := token.NewFileSet()
   453  	files, err := driver.ParseFiles(fset, inputs)
   454  	if err != nil {
   455  		return nil, err
   456  	}
   457  
   458  	return in.loadSourcePackage(fset, files, pkgpath, in.copts)
   459  }
   460  
   461  // readLine accumulates lines of input, including trailing newlines,
   462  // executing statements as they are completed.
   463  func (in *interp) readLine(line string) ([]interface{}, error) {
   464  	if !in.pendingLine.ready() {
   465  		return in.readExprLine(line, nil)
   466  	}
   467  
   468  	var s scanner.Scanner
   469  	fset := token.NewFileSet()
   470  	file := fset.AddFile("", fset.Base(), len(line))
   471  	s.Init(file, []byte(line), nil, 0)
   472  
   473  	_, tok, lit := s.Scan()
   474  	switch tok {
   475  	case token.EOF:
   476  		return nil, nil
   477  
   478  	case token.IMPORT:
   479  		_, tok, lit = s.Scan()
   480  		if tok != token.STRING {
   481  			return nil, errors.New("expected string literal")
   482  		}
   483  		pkgpath, err := strconv.Unquote(lit)
   484  		if err != nil {
   485  			return nil, err
   486  		}
   487  		pkg, err := in.loadPackage(pkgpath)
   488  		if err != nil {
   489  			return nil, err
   490  		}
   491  		in.imports = append(in.imports, pkg)
   492  		return nil, nil
   493  
   494  	case token.IDENT:
   495  		ok, err := in.maybeReadAssignment(line, &s, lit, file.Base())
   496  		if err != nil {
   497  			return nil, err
   498  		}
   499  		if ok {
   500  			return nil, nil
   501  		}
   502  		fallthrough
   503  
   504  	default:
   505  		return in.readExprLine(line, nil)
   506  	}
   507  }
   508  
   509  // printResult prints a value that was the result of an expression evaluated
   510  // by the interpreter.
   511  func printResult(w io.Writer, v interface{}) {
   512  	// TODO the result should be formatted in Go syntax, without
   513  	// package qualifiers for types defined within the interpreter.
   514  	fmt.Fprintf(w, "%+v", v)
   515  }
   516  
   517  // formatHistory reformats the provided Go source by collapsing all lines
   518  // and adding semicolons where required, suitable for adding to line history.
   519  func formatHistory(input []byte) string {
   520  	var buf bytes.Buffer
   521  	var s scanner.Scanner
   522  	fset := token.NewFileSet()
   523  	file := fset.AddFile("", fset.Base(), len(input))
   524  	s.Init(file, input, nil, 0)
   525  	pos, tok, lit := s.Scan()
   526  	for tok != token.EOF {
   527  		if int(pos)-1 > buf.Len() {
   528  			n := int(pos) - 1 - buf.Len()
   529  			buf.WriteString(strings.Repeat(" ", n))
   530  		}
   531  		var semicolon bool
   532  		if tok == token.SEMICOLON {
   533  			semicolon = true
   534  		} else if lit != "" {
   535  			buf.WriteString(lit)
   536  		} else {
   537  			buf.WriteString(tok.String())
   538  		}
   539  		pos, tok, lit = s.Scan()
   540  		if semicolon {
   541  			switch tok {
   542  			case token.RBRACE, token.RPAREN, token.EOF:
   543  			default:
   544  				buf.WriteRune(';')
   545  			}
   546  		}
   547  	}
   548  	return buf.String()
   549  }
   550  
   551  func main() {
   552  	llvm.LinkInMCJIT()
   553  	llvm.InitializeNativeTarget()
   554  	llvm.InitializeNativeAsmPrinter()
   555  
   556  	var in interp
   557  	err := in.init()
   558  	if err != nil {
   559  		panic(err)
   560  	}
   561  	defer in.dispose()
   562  
   563  	var buf bytes.Buffer
   564  	for {
   565  		if in.pendingLine.ready() && buf.Len() > 0 {
   566  			history := formatHistory(buf.Bytes())
   567  			in.liner.AppendHistory(history)
   568  			buf.Reset()
   569  		}
   570  		prompt := "(llgo) "
   571  		if !in.pendingLine.ready() {
   572  			prompt = strings.Repeat(" ", len(prompt))
   573  		}
   574  		line, err := in.liner.Prompt(prompt)
   575  		if err == io.EOF {
   576  			break
   577  		} else if err != nil {
   578  			panic(err)
   579  		}
   580  		if line == "" {
   581  			continue
   582  		}
   583  		buf.WriteString(line + "\n")
   584  		results, err := in.readLine(line + "\n")
   585  		if err != nil {
   586  			fmt.Println(err)
   587  		}
   588  		for _, result := range results {
   589  			printResult(os.Stdout, result)
   590  			fmt.Println()
   591  		}
   592  	}
   593  
   594  	if liner.TerminalSupported() {
   595  		fmt.Println()
   596  	}
   597  }