github.com/yunabe/lgo@v0.0.0-20190709125917-42c42d410fdf/converter/autoexit.go (about) 1 // This file defines injectAutoExitToFile, which injects core.ExitIfCtxDone to interrupt lgo code 2 // 3 // Basic rules: 4 // - Injects ExitIfCtxDone between two heavy statements (== function calls). 5 // - Does not inject ExitIfCtxDone in functions under defer statements. 6 // - Injects ExitIfCtxDone at the top of a function. 7 // - Injects ExitIfCtxDone at the top of a for-loop body. 8 9 package converter 10 11 import ( 12 "fmt" 13 "go/ast" 14 "go/token" 15 "go/types" 16 17 "github.com/yunabe/lgo/core" 18 "github.com/yunabe/lgo/parser" 19 ) 20 21 func containsCall(expr ast.Expr) bool { 22 var v containsCallVisitor 23 ast.Walk(&v, expr) 24 return v.contains 25 } 26 27 type containsCallVisitor struct { 28 contains bool 29 } 30 31 func (v *containsCallVisitor) Visit(node ast.Node) ast.Visitor { 32 if node == nil { 33 return nil 34 } 35 if _, ok := node.(*ast.CallExpr); ok { 36 v.contains = true 37 } 38 if _, ok := node.(*ast.FuncLit); ok { 39 return nil 40 } 41 if v.contains { 42 return nil 43 } 44 return v 45 } 46 47 func isHeavyStmt(stm ast.Stmt) bool { 48 switch stm := stm.(type) { 49 case *ast.AssignStmt: 50 for _, e := range stm.Lhs { 51 if containsCall(e) { 52 return true 53 } 54 } 55 for _, e := range stm.Rhs { 56 if containsCall(e) { 57 return true 58 } 59 } 60 case *ast.ExprStmt: 61 return containsCall(stm.X) 62 case *ast.ForStmt: 63 return isHeavyStmt(stm.Init) 64 case *ast.GoStmt: 65 if containsCall(stm.Call.Fun) { 66 return true 67 } 68 for _, arg := range stm.Call.Args { 69 if containsCall(arg) { 70 return true 71 } 72 } 73 case *ast.IfStmt: 74 if isHeavyStmt(stm.Init) { 75 return true 76 } 77 if containsCall(stm.Cond) { 78 return true 79 } 80 case *ast.ReturnStmt: 81 for _, r := range stm.Results { 82 if containsCall(r) { 83 return true 84 } 85 } 86 case *ast.SwitchStmt: 87 if isHeavyStmt(stm.Init) { 88 return true 89 } 90 if containsCall(stm.Tag) { 91 return true 92 } 93 // Return true if one of case clause contains a function call. 94 for _, l := range stm.Body.List { 95 cas := l.(*ast.CaseClause) 96 for _, e := range cas.List { 97 if containsCall(e) { 98 return true 99 } 100 } 101 } 102 } 103 return false 104 } 105 106 func injectAutoExitBlock(block *ast.BlockStmt, injectHead bool, defaultFlag bool, importCore func() string) { 107 injectAutoExitToBlockStmtList(&block.List, injectHead, defaultFlag, importCore) 108 } 109 110 func makeExitIfDoneCommClause(importCore func() string) *ast.CommClause { 111 ctx := &ast.CallExpr{ 112 Fun: &ast.SelectorExpr{ 113 X: &ast.Ident{Name: importCore()}, 114 Sel: &ast.Ident{Name: "GetExecContext"}, 115 }, 116 } 117 done := &ast.CallExpr{ 118 Fun: &ast.SelectorExpr{ 119 X: ctx, 120 Sel: &ast.Ident{Name: "Done"}, 121 }, 122 } 123 return &ast.CommClause{ 124 Comm: &ast.ExprStmt{X: &ast.UnaryExpr{Op: token.ARROW, X: done}}, 125 Body: []ast.Stmt{&ast.ExprStmt{ 126 X: &ast.CallExpr{ 127 Fun: ast.NewIdent("panic"), 128 Args: []ast.Expr{&ast.SelectorExpr{ 129 X: &ast.Ident{Name: importCore()}, 130 Sel: &ast.Ident{Name: "Bailout"}, 131 }}, 132 }, 133 }}, 134 } 135 } 136 137 func injectAutoExitToBlockStmtList(lst *[]ast.Stmt, injectHead bool, defaultFlag bool, importCore func() string) { 138 newList := make([]ast.Stmt, 0, 2*len(*lst)+1) 139 flag := defaultFlag 140 appendAutoExpt := func() { 141 newList = append(newList, &ast.ExprStmt{ 142 X: &ast.CallExpr{ 143 Fun: ast.NewIdent(importCore() + ".ExitIfCtxDone"), 144 }, 145 }) 146 } 147 if injectHead { 148 appendAutoExpt() 149 flag = false 150 } 151 for i := 0; i < len(*lst); i++ { 152 stmt := (*lst)[i] 153 if stmt, ok := stmt.(*ast.SendStmt); ok { 154 newList = append(newList, &ast.SelectStmt{ 155 Body: &ast.BlockStmt{ 156 List: []ast.Stmt{ 157 &ast.CommClause{Comm: stmt}, 158 makeExitIfDoneCommClause(importCore), 159 }, 160 }, 161 }) 162 flag = false 163 continue 164 } 165 166 heavy := injectAutoExitToStmt(stmt, importCore, flag) 167 if heavy { 168 if flag { 169 appendAutoExpt() 170 } 171 flag = true 172 } 173 newList = append(newList, stmt) 174 } 175 *lst = newList 176 } 177 178 func injectAutoExitToStmt(stm ast.Stmt, importCore func() string, prevHeavy bool) bool { 179 heavy := isHeavyStmt(stm) 180 switch stm := stm.(type) { 181 case *ast.ForStmt: 182 injectAutoExit(stm.Init, importCore) 183 injectAutoExit(stm.Cond, importCore) 184 injectAutoExit(stm.Post, importCore) 185 injectAutoExitBlock(stm.Body, true, heavy, importCore) 186 case *ast.IfStmt: 187 injectAutoExit(stm.Init, importCore) 188 injectAutoExit(stm.Cond, importCore) 189 injectAutoExitBlock(stm.Body, false, heavy, importCore) 190 case *ast.SwitchStmt: 191 injectAutoExit(stm.Init, importCore) 192 injectAutoExit(stm.Tag, importCore) 193 for _, l := range stm.Body.List { 194 cas := l.(*ast.CaseClause) 195 injectAutoExitToBlockStmtList(&cas.Body, false, heavy, importCore) 196 } 197 case *ast.BlockStmt: 198 injectAutoExitBlock(stm, true, prevHeavy, importCore) 199 case *ast.DeferStmt: 200 // Do not inject autoexit code inside a function defined in a defer statement. 201 default: 202 injectAutoExit(stm, importCore) 203 } 204 return heavy 205 } 206 207 func injectAutoExitToSelectStmt(sel *ast.SelectStmt, importCore func() string) { 208 for _, stmt := range sel.Body.List { 209 stmt := stmt.(*ast.CommClause) 210 if stmt.Comm == nil { 211 // stmt is default case. Do nothing if there is default-case. 212 return 213 } 214 } 215 sel.Body.List = append(sel.Body.List, makeExitIfDoneCommClause(importCore)) 216 } 217 218 func injectAutoExit(node ast.Node, importCore func() string) { 219 if node == nil { 220 return 221 } 222 v := autoExitInjector{importCore: importCore} 223 ast.Walk(&v, node) 224 } 225 226 type autoExitInjector struct { 227 importCore func() string 228 } 229 230 func (v *autoExitInjector) Visit(node ast.Node) ast.Visitor { 231 if decl, ok := node.(*ast.FuncDecl); ok { 232 injectAutoExitBlock(decl.Body, true, false, v.importCore) 233 return nil 234 } 235 if fn, ok := node.(*ast.FuncLit); ok { 236 injectAutoExitBlock(fn.Body, true, false, v.importCore) 237 return nil 238 } 239 if node, ok := node.(*ast.SelectStmt); ok { 240 injectAutoExitToSelectStmt(node, v.importCore) 241 } 242 return v 243 } 244 245 func injectAutoExitToFile(file *ast.File, immg *importManager) { 246 importCore := func() string { 247 corePkg, _ := lgoImporter.Import(core.SelfPkgPath) 248 return immg.shortName(corePkg) 249 } 250 injectAutoExit(file, importCore) 251 } 252 253 // selectCommExprRecorder collects `<-ch` expressions that are used inside select-case clauses. 254 // The result of this recorder is used from mayWrapRecvOp so that it does not modify `<-ch` operations in select-case clauses. 255 type selectCommExprRecorder struct { 256 m map[ast.Expr]bool 257 } 258 259 func (v *selectCommExprRecorder) Visit(node ast.Node) ast.Visitor { 260 sel, ok := node.(*ast.SelectStmt) 261 if !ok { 262 return v 263 } 264 for _, cas := range sel.Body.List { 265 comm := cas.(*ast.CommClause).Comm 266 if comm == nil { 267 continue 268 } 269 if assign, ok := comm.(*ast.AssignStmt); ok { 270 for _, expr := range assign.Rhs { 271 v.m[expr] = true 272 } 273 } 274 if expr, ok := comm.(*ast.ExprStmt); ok { 275 v.m[expr.X] = true 276 } 277 } 278 return v 279 } 280 281 func mayWrapRecvOp(conf *Config, file *ast.File, fset *token.FileSet, checker *types.Checker, pkg *types.Package, runctx types.Object, oldImports []*types.PkgName) (*types.Checker, *types.Package, types.Object, []*types.PkgName) { 282 rec := selectCommExprRecorder{make(map[ast.Expr]bool)} 283 ast.Walk(&rec, file) 284 picker := newNamePicker(checker.Defs) 285 immg := newImportManager(pkg, file, checker) 286 importCore := func() string { 287 corePkg, _ := lgoImporter.Import(core.SelfPkgPath) 288 return immg.shortName(corePkg) 289 } 290 291 var rewritten bool 292 rewriteExpr(file, func(expr ast.Expr) ast.Expr { 293 if rec.m[expr] { 294 return expr 295 } 296 ue, ok := expr.(*ast.UnaryExpr) 297 if !ok || ue.Op != token.ARROW { 298 return expr 299 } 300 rewritten = true 301 wrapperName := picker.NewName("recvChan") 302 decl := makeChanRecvWrapper(ue, wrapperName, immg, importCore, checker) 303 file.Decls = append(file.Decls, decl) 304 return &ast.CallExpr{ 305 Fun: ast.NewIdent(wrapperName), 306 Args: []ast.Expr{ue.X}, 307 } 308 }) 309 if !rewritten { 310 return checker, pkg, runctx, oldImports 311 } 312 if len(immg.injectedImports) > 0 { 313 var newDecls []ast.Decl 314 for _, im := range immg.injectedImports { 315 newDecls = append(newDecls, im) 316 } 317 newDecls = append(newDecls, file.Decls...) 318 file.Decls = newDecls 319 } 320 var err error 321 checker, pkg, runctx, oldImports, err = checkFileInPhase2(conf, file, fset) 322 if err != nil { 323 panic(fmt.Errorf("No error expected but got: %v", err)) 324 } 325 return checker, pkg, runctx, oldImports 326 } 327 328 // makeChanRecvWrapper makes ast.Node trees of a function declaration like 329 // func recvChan1(c chan int) (x int, ok bool) { 330 // select { 331 // case x, ok = <-c: 332 // return 333 // } 334 // } 335 // The select statement is equivalent to <-c in this phase. The code to interrupt the select statement 336 // is injected in the latter phase by autoExitInjector. 337 func makeChanRecvWrapper(expr *ast.UnaryExpr, funcName string, immg *importManager, importCore func() string, checker *types.Checker) *ast.FuncDecl { 338 typeExpr := func(typ types.Type) ast.Expr { 339 s := types.TypeString(typ, func(pkg *types.Package) string { 340 return immg.shortName(pkg) 341 }) 342 expr, err := parser.ParseExpr(s) 343 if err != nil { 344 panic(fmt.Sprintf("Failed to parse type expr %q: %v", s, err)) 345 } 346 return expr 347 } 348 349 var results []*ast.Field 350 typ := checker.Types[expr].Type 351 if tup, ok := typ.(*types.Tuple); ok { 352 results = []*ast.Field{ 353 { 354 Names: []*ast.Ident{ast.NewIdent("x")}, 355 Type: typeExpr(tup.At(0).Type()), 356 }, { 357 Names: []*ast.Ident{ast.NewIdent("ok")}, 358 Type: typeExpr(tup.At(1).Type()), 359 }, 360 } 361 } else { 362 results = []*ast.Field{ 363 { 364 Names: []*ast.Ident{ast.NewIdent("x")}, 365 Type: typeExpr(typ), 366 }, 367 } 368 } 369 chType := checker.Types[expr.X].Type.(*types.Chan) 370 var lhs []ast.Expr 371 if len(results) == 1 { 372 lhs = []ast.Expr{ast.NewIdent("x")} 373 } else { 374 lhs = []ast.Expr{ast.NewIdent("x"), ast.NewIdent("ok")} 375 } 376 return &ast.FuncDecl{ 377 Name: ast.NewIdent(funcName), 378 Type: &ast.FuncType{ 379 Params: &ast.FieldList{ 380 List: []*ast.Field{ 381 { 382 Names: []*ast.Ident{ast.NewIdent("c")}, 383 Type: typeExpr(chType), 384 }, 385 }, 386 }, 387 Results: &ast.FieldList{List: results}, 388 }, 389 Body: &ast.BlockStmt{ 390 List: []ast.Stmt{ 391 &ast.SelectStmt{ 392 Body: &ast.BlockStmt{ 393 List: []ast.Stmt{ 394 &ast.CommClause{ 395 Comm: &ast.AssignStmt{ 396 Tok: token.ASSIGN, 397 Lhs: lhs, 398 Rhs: []ast.Expr{&ast.UnaryExpr{ 399 Op: token.ARROW, 400 X: ast.NewIdent("c"), 401 }}, 402 }, 403 Body: []ast.Stmt{&ast.ReturnStmt{}}, 404 }, 405 }, 406 }, 407 }, 408 }, 409 }, 410 } 411 }