github.com/yunabe/lgo@v0.0.0-20190709125917-42c42d410fdf/converter/autoexit.go (about)

// This file defines injectAutoExitToFile, which injects core.ExitIfCtxDone calls so that running lgo code can be interrupted.
     2  //
     3  // Basic rules:
// - Injects ExitIfCtxDone between two heavy statements (statements that contain function calls).
     5  // - Does not inject ExitIfCtxDone in functions under defer statements.
     6  // - Injects ExitIfCtxDone at the top of a function.
     7  // - Injects ExitIfCtxDone at the top of a for-loop body.
     8  
     9  package converter
    10  
    11  import (
    12  	"fmt"
    13  	"go/ast"
    14  	"go/token"
    15  	"go/types"
    16  
    17  	"github.com/yunabe/lgo/core"
    18  	"github.com/yunabe/lgo/parser"
    19  )
    20  
    21  func containsCall(expr ast.Expr) bool {
    22  	var v containsCallVisitor
    23  	ast.Walk(&v, expr)
    24  	return v.contains
    25  }
    26  
    27  type containsCallVisitor struct {
    28  	contains bool
    29  }
    30  
    31  func (v *containsCallVisitor) Visit(node ast.Node) ast.Visitor {
    32  	if node == nil {
    33  		return nil
    34  	}
    35  	if _, ok := node.(*ast.CallExpr); ok {
    36  		v.contains = true
    37  	}
    38  	if _, ok := node.(*ast.FuncLit); ok {
    39  		return nil
    40  	}
    41  	if v.contains {
    42  		return nil
    43  	}
    44  	return v
    45  }
    46  
    47  func isHeavyStmt(stm ast.Stmt) bool {
    48  	switch stm := stm.(type) {
    49  	case *ast.AssignStmt:
    50  		for _, e := range stm.Lhs {
    51  			if containsCall(e) {
    52  				return true
    53  			}
    54  		}
    55  		for _, e := range stm.Rhs {
    56  			if containsCall(e) {
    57  				return true
    58  			}
    59  		}
    60  	case *ast.ExprStmt:
    61  		return containsCall(stm.X)
    62  	case *ast.ForStmt:
    63  		return isHeavyStmt(stm.Init)
    64  	case *ast.GoStmt:
    65  		if containsCall(stm.Call.Fun) {
    66  			return true
    67  		}
    68  		for _, arg := range stm.Call.Args {
    69  			if containsCall(arg) {
    70  				return true
    71  			}
    72  		}
    73  	case *ast.IfStmt:
    74  		if isHeavyStmt(stm.Init) {
    75  			return true
    76  		}
    77  		if containsCall(stm.Cond) {
    78  			return true
    79  		}
    80  	case *ast.ReturnStmt:
    81  		for _, r := range stm.Results {
    82  			if containsCall(r) {
    83  				return true
    84  			}
    85  		}
    86  	case *ast.SwitchStmt:
    87  		if isHeavyStmt(stm.Init) {
    88  			return true
    89  		}
    90  		if containsCall(stm.Tag) {
    91  			return true
    92  		}
    93  		// Return true if one of case clause contains a function call.
    94  		for _, l := range stm.Body.List {
    95  			cas := l.(*ast.CaseClause)
    96  			for _, e := range cas.List {
    97  				if containsCall(e) {
    98  					return true
    99  				}
   100  			}
   101  		}
   102  	}
   103  	return false
   104  }
   105  
   106  func injectAutoExitBlock(block *ast.BlockStmt, injectHead bool, defaultFlag bool, importCore func() string) {
   107  	injectAutoExitToBlockStmtList(&block.List, injectHead, defaultFlag, importCore)
   108  }
   109  
   110  func makeExitIfDoneCommClause(importCore func() string) *ast.CommClause {
   111  	ctx := &ast.CallExpr{
   112  		Fun: &ast.SelectorExpr{
   113  			X:   &ast.Ident{Name: importCore()},
   114  			Sel: &ast.Ident{Name: "GetExecContext"},
   115  		},
   116  	}
   117  	done := &ast.CallExpr{
   118  		Fun: &ast.SelectorExpr{
   119  			X:   ctx,
   120  			Sel: &ast.Ident{Name: "Done"},
   121  		},
   122  	}
   123  	return &ast.CommClause{
   124  		Comm: &ast.ExprStmt{X: &ast.UnaryExpr{Op: token.ARROW, X: done}},
   125  		Body: []ast.Stmt{&ast.ExprStmt{
   126  			X: &ast.CallExpr{
   127  				Fun: ast.NewIdent("panic"),
   128  				Args: []ast.Expr{&ast.SelectorExpr{
   129  					X:   &ast.Ident{Name: importCore()},
   130  					Sel: &ast.Ident{Name: "Bailout"},
   131  				}},
   132  			},
   133  		}},
   134  	}
   135  }
   136  
   137  func injectAutoExitToBlockStmtList(lst *[]ast.Stmt, injectHead bool, defaultFlag bool, importCore func() string) {
   138  	newList := make([]ast.Stmt, 0, 2*len(*lst)+1)
   139  	flag := defaultFlag
   140  	appendAutoExpt := func() {
   141  		newList = append(newList, &ast.ExprStmt{
   142  			X: &ast.CallExpr{
   143  				Fun: ast.NewIdent(importCore() + ".ExitIfCtxDone"),
   144  			},
   145  		})
   146  	}
   147  	if injectHead {
   148  		appendAutoExpt()
   149  		flag = false
   150  	}
   151  	for i := 0; i < len(*lst); i++ {
   152  		stmt := (*lst)[i]
   153  		if stmt, ok := stmt.(*ast.SendStmt); ok {
   154  			newList = append(newList, &ast.SelectStmt{
   155  				Body: &ast.BlockStmt{
   156  					List: []ast.Stmt{
   157  						&ast.CommClause{Comm: stmt},
   158  						makeExitIfDoneCommClause(importCore),
   159  					},
   160  				},
   161  			})
   162  			flag = false
   163  			continue
   164  		}
   165  
   166  		heavy := injectAutoExitToStmt(stmt, importCore, flag)
   167  		if heavy {
   168  			if flag {
   169  				appendAutoExpt()
   170  			}
   171  			flag = true
   172  		}
   173  		newList = append(newList, stmt)
   174  	}
   175  	*lst = newList
   176  }
   177  
   178  func injectAutoExitToStmt(stm ast.Stmt, importCore func() string, prevHeavy bool) bool {
   179  	heavy := isHeavyStmt(stm)
   180  	switch stm := stm.(type) {
   181  	case *ast.ForStmt:
   182  		injectAutoExit(stm.Init, importCore)
   183  		injectAutoExit(stm.Cond, importCore)
   184  		injectAutoExit(stm.Post, importCore)
   185  		injectAutoExitBlock(stm.Body, true, heavy, importCore)
   186  	case *ast.IfStmt:
   187  		injectAutoExit(stm.Init, importCore)
   188  		injectAutoExit(stm.Cond, importCore)
   189  		injectAutoExitBlock(stm.Body, false, heavy, importCore)
   190  	case *ast.SwitchStmt:
   191  		injectAutoExit(stm.Init, importCore)
   192  		injectAutoExit(stm.Tag, importCore)
   193  		for _, l := range stm.Body.List {
   194  			cas := l.(*ast.CaseClause)
   195  			injectAutoExitToBlockStmtList(&cas.Body, false, heavy, importCore)
   196  		}
   197  	case *ast.BlockStmt:
   198  		injectAutoExitBlock(stm, true, prevHeavy, importCore)
   199  	case *ast.DeferStmt:
   200  		// Do not inject autoexit code inside a function defined in a defer statement.
   201  	default:
   202  		injectAutoExit(stm, importCore)
   203  	}
   204  	return heavy
   205  }
   206  
   207  func injectAutoExitToSelectStmt(sel *ast.SelectStmt, importCore func() string) {
   208  	for _, stmt := range sel.Body.List {
   209  		stmt := stmt.(*ast.CommClause)
   210  		if stmt.Comm == nil {
   211  			// stmt is default case. Do nothing if there is default-case.
   212  			return
   213  		}
   214  	}
   215  	sel.Body.List = append(sel.Body.List, makeExitIfDoneCommClause(importCore))
   216  }
   217  
   218  func injectAutoExit(node ast.Node, importCore func() string) {
   219  	if node == nil {
   220  		return
   221  	}
   222  	v := autoExitInjector{importCore: importCore}
   223  	ast.Walk(&v, node)
   224  }
   225  
   226  type autoExitInjector struct {
   227  	importCore func() string
   228  }
   229  
   230  func (v *autoExitInjector) Visit(node ast.Node) ast.Visitor {
   231  	if decl, ok := node.(*ast.FuncDecl); ok {
   232  		injectAutoExitBlock(decl.Body, true, false, v.importCore)
   233  		return nil
   234  	}
   235  	if fn, ok := node.(*ast.FuncLit); ok {
   236  		injectAutoExitBlock(fn.Body, true, false, v.importCore)
   237  		return nil
   238  	}
   239  	if node, ok := node.(*ast.SelectStmt); ok {
   240  		injectAutoExitToSelectStmt(node, v.importCore)
   241  	}
   242  	return v
   243  }
   244  
   245  func injectAutoExitToFile(file *ast.File, immg *importManager) {
   246  	importCore := func() string {
   247  		corePkg, _ := lgoImporter.Import(core.SelfPkgPath)
   248  		return immg.shortName(corePkg)
   249  	}
   250  	injectAutoExit(file, importCore)
   251  }
   252  
   253  // selectCommExprRecorder collects `<-ch` expressions that are used inside select-case clauses.
   254  // The result of this recorder is used from mayWrapRecvOp so that it does not modify `<-ch` operations in select-case clauses.
   255  type selectCommExprRecorder struct {
   256  	m map[ast.Expr]bool
   257  }
   258  
   259  func (v *selectCommExprRecorder) Visit(node ast.Node) ast.Visitor {
   260  	sel, ok := node.(*ast.SelectStmt)
   261  	if !ok {
   262  		return v
   263  	}
   264  	for _, cas := range sel.Body.List {
   265  		comm := cas.(*ast.CommClause).Comm
   266  		if comm == nil {
   267  			continue
   268  		}
   269  		if assign, ok := comm.(*ast.AssignStmt); ok {
   270  			for _, expr := range assign.Rhs {
   271  				v.m[expr] = true
   272  			}
   273  		}
   274  		if expr, ok := comm.(*ast.ExprStmt); ok {
   275  			v.m[expr.X] = true
   276  		}
   277  	}
   278  	return v
   279  }
   280  
   281  func mayWrapRecvOp(conf *Config, file *ast.File, fset *token.FileSet, checker *types.Checker, pkg *types.Package, runctx types.Object, oldImports []*types.PkgName) (*types.Checker, *types.Package, types.Object, []*types.PkgName) {
   282  	rec := selectCommExprRecorder{make(map[ast.Expr]bool)}
   283  	ast.Walk(&rec, file)
   284  	picker := newNamePicker(checker.Defs)
   285  	immg := newImportManager(pkg, file, checker)
   286  	importCore := func() string {
   287  		corePkg, _ := lgoImporter.Import(core.SelfPkgPath)
   288  		return immg.shortName(corePkg)
   289  	}
   290  
   291  	var rewritten bool
   292  	rewriteExpr(file, func(expr ast.Expr) ast.Expr {
   293  		if rec.m[expr] {
   294  			return expr
   295  		}
   296  		ue, ok := expr.(*ast.UnaryExpr)
   297  		if !ok || ue.Op != token.ARROW {
   298  			return expr
   299  		}
   300  		rewritten = true
   301  		wrapperName := picker.NewName("recvChan")
   302  		decl := makeChanRecvWrapper(ue, wrapperName, immg, importCore, checker)
   303  		file.Decls = append(file.Decls, decl)
   304  		return &ast.CallExpr{
   305  			Fun:  ast.NewIdent(wrapperName),
   306  			Args: []ast.Expr{ue.X},
   307  		}
   308  	})
   309  	if !rewritten {
   310  		return checker, pkg, runctx, oldImports
   311  	}
   312  	if len(immg.injectedImports) > 0 {
   313  		var newDecls []ast.Decl
   314  		for _, im := range immg.injectedImports {
   315  			newDecls = append(newDecls, im)
   316  		}
   317  		newDecls = append(newDecls, file.Decls...)
   318  		file.Decls = newDecls
   319  	}
   320  	var err error
   321  	checker, pkg, runctx, oldImports, err = checkFileInPhase2(conf, file, fset)
   322  	if err != nil {
   323  		panic(fmt.Errorf("No error expected but got: %v", err))
   324  	}
   325  	return checker, pkg, runctx, oldImports
   326  }
   327  
   328  // makeChanRecvWrapper makes ast.Node trees of a function declaration like
   329  // func recvChan1(c chan int) (x int, ok bool) {
   330  //    select {
   331  //    case x, ok = <-c:
   332  //        return
   333  //    }
   334  // }
   335  // The select statement is equivalent to <-c in this phase. The code to interrupt the select statement
   336  // is injected in the latter phase by autoExitInjector.
   337  func makeChanRecvWrapper(expr *ast.UnaryExpr, funcName string, immg *importManager, importCore func() string, checker *types.Checker) *ast.FuncDecl {
   338  	typeExpr := func(typ types.Type) ast.Expr {
   339  		s := types.TypeString(typ, func(pkg *types.Package) string {
   340  			return immg.shortName(pkg)
   341  		})
   342  		expr, err := parser.ParseExpr(s)
   343  		if err != nil {
   344  			panic(fmt.Sprintf("Failed to parse type expr %q: %v", s, err))
   345  		}
   346  		return expr
   347  	}
   348  
   349  	var results []*ast.Field
   350  	typ := checker.Types[expr].Type
   351  	if tup, ok := typ.(*types.Tuple); ok {
   352  		results = []*ast.Field{
   353  			{
   354  				Names: []*ast.Ident{ast.NewIdent("x")},
   355  				Type:  typeExpr(tup.At(0).Type()),
   356  			}, {
   357  				Names: []*ast.Ident{ast.NewIdent("ok")},
   358  				Type:  typeExpr(tup.At(1).Type()),
   359  			},
   360  		}
   361  	} else {
   362  		results = []*ast.Field{
   363  			{
   364  				Names: []*ast.Ident{ast.NewIdent("x")},
   365  				Type:  typeExpr(typ),
   366  			},
   367  		}
   368  	}
   369  	chType := checker.Types[expr.X].Type.(*types.Chan)
   370  	var lhs []ast.Expr
   371  	if len(results) == 1 {
   372  		lhs = []ast.Expr{ast.NewIdent("x")}
   373  	} else {
   374  		lhs = []ast.Expr{ast.NewIdent("x"), ast.NewIdent("ok")}
   375  	}
   376  	return &ast.FuncDecl{
   377  		Name: ast.NewIdent(funcName),
   378  		Type: &ast.FuncType{
   379  			Params: &ast.FieldList{
   380  				List: []*ast.Field{
   381  					{
   382  						Names: []*ast.Ident{ast.NewIdent("c")},
   383  						Type:  typeExpr(chType),
   384  					},
   385  				},
   386  			},
   387  			Results: &ast.FieldList{List: results},
   388  		},
   389  		Body: &ast.BlockStmt{
   390  			List: []ast.Stmt{
   391  				&ast.SelectStmt{
   392  					Body: &ast.BlockStmt{
   393  						List: []ast.Stmt{
   394  							&ast.CommClause{
   395  								Comm: &ast.AssignStmt{
   396  									Tok: token.ASSIGN,
   397  									Lhs: lhs,
   398  									Rhs: []ast.Expr{&ast.UnaryExpr{
   399  										Op: token.ARROW,
   400  										X:  ast.NewIdent("c"),
   401  									}},
   402  								},
   403  								Body: []ast.Stmt{&ast.ReturnStmt{}},
   404  							},
   405  						},
   406  					},
   407  				},
   408  			},
   409  		},
   410  	}
   411  }