github.com/v2fly/tools@v0.100.0/internal/lsp/source/extract.go (about) 1 // Copyright 2020 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package source 6 7 import ( 8 "bytes" 9 "fmt" 10 "go/ast" 11 "go/format" 12 "go/parser" 13 "go/token" 14 "go/types" 15 "strings" 16 "unicode" 17 18 "github.com/v2fly/tools/go/analysis" 19 "github.com/v2fly/tools/go/ast/astutil" 20 "github.com/v2fly/tools/internal/analysisinternal" 21 "github.com/v2fly/tools/internal/span" 22 ) 23 24 func extractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { 25 expr, path, ok, err := CanExtractVariable(rng, file) 26 if !ok { 27 return nil, fmt.Errorf("extractVariable: cannot extract %s: %v", fset.Position(rng.Start), err) 28 } 29 30 // Create new AST node for extracted code. 31 var lhsNames []string 32 switch expr := expr.(type) { 33 // TODO: stricter rules for selectorExpr. 34 case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr, 35 *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: 36 lhsNames = append(lhsNames, generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0)) 37 case *ast.CallExpr: 38 tup, ok := info.TypeOf(expr).(*types.Tuple) 39 if !ok { 40 // If the call expression only has one return value, we can treat it the 41 // same as our standard extract variable case. 42 lhsNames = append(lhsNames, 43 generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0)) 44 break 45 } 46 for i := 0; i < tup.Len(); i++ { 47 // Generate a unique variable for each return value. 
48 lhsNames = append(lhsNames, 49 generateAvailableIdentifier(expr.Pos(), file, path, info, "x", i)) 50 } 51 default: 52 return nil, fmt.Errorf("cannot extract %T", expr) 53 } 54 55 insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path) 56 if insertBeforeStmt == nil { 57 return nil, fmt.Errorf("cannot find location to insert extraction") 58 } 59 tok := fset.File(expr.Pos()) 60 if tok == nil { 61 return nil, fmt.Errorf("no file for pos %v", fset.Position(file.Pos())) 62 } 63 newLineIndent := "\n" + calculateIndentation(src, tok, insertBeforeStmt) 64 65 lhs := strings.Join(lhsNames, ", ") 66 assignStmt := &ast.AssignStmt{ 67 Lhs: []ast.Expr{ast.NewIdent(lhs)}, 68 Tok: token.DEFINE, 69 Rhs: []ast.Expr{expr}, 70 } 71 var buf bytes.Buffer 72 if err := format.Node(&buf, fset, assignStmt); err != nil { 73 return nil, err 74 } 75 assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent 76 77 return &analysis.SuggestedFix{ 78 TextEdits: []analysis.TextEdit{ 79 { 80 Pos: rng.Start, 81 End: rng.End, 82 NewText: []byte(lhs), 83 }, 84 { 85 Pos: insertBeforeStmt.Pos(), 86 End: insertBeforeStmt.Pos(), 87 NewText: []byte(assignment), 88 }, 89 }, 90 }, nil 91 } 92 93 // CanExtractVariable reports whether the code in the given range can be 94 // extracted to a variable. 
95 func CanExtractVariable(rng span.Range, file *ast.File) (ast.Expr, []ast.Node, bool, error) { 96 if rng.Start == rng.End { 97 return nil, nil, false, fmt.Errorf("start and end are equal") 98 } 99 path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) 100 if len(path) == 0 { 101 return nil, nil, false, fmt.Errorf("no path enclosing interval") 102 } 103 for _, n := range path { 104 if _, ok := n.(*ast.ImportSpec); ok { 105 return nil, nil, false, fmt.Errorf("cannot extract variable in an import block") 106 } 107 } 108 node := path[0] 109 if rng.Start != node.Pos() || rng.End != node.End() { 110 return nil, nil, false, fmt.Errorf("range does not map to an AST node") 111 } 112 expr, ok := node.(ast.Expr) 113 if !ok { 114 return nil, nil, false, fmt.Errorf("node is not an expression") 115 } 116 switch expr.(type) { 117 case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr, 118 *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: 119 return expr, path, true, nil 120 } 121 return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr) 122 } 123 124 // Calculate indentation for insertion. 125 // When inserting lines of code, we must ensure that the lines have consistent 126 // formatting (i.e. the proper indentation). To do so, we observe the indentation on the 127 // line of code on which the insertion occurs. 128 func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.Node) string { 129 line := tok.Line(insertBeforeStmt.Pos()) 130 lineOffset := tok.Offset(tok.LineStart(line)) 131 stmtOffset := tok.Offset(insertBeforeStmt.Pos()) 132 return string(content[lineOffset:stmtOffset]) 133 } 134 135 // generateAvailableIdentifier adjusts the new function name until there are no collisons in scope. 136 // Possible collisions include other function and variable names. 
137 func generateAvailableIdentifier(pos token.Pos, file *ast.File, path []ast.Node, info *types.Info, prefix string, idx int) string { 138 scopes := CollectScopes(info, path, pos) 139 name := prefix + fmt.Sprintf("%d", idx) 140 for file.Scope.Lookup(name) != nil || !isValidName(name, scopes) { 141 idx++ 142 name = fmt.Sprintf("%v%d", prefix, idx) 143 } 144 return name 145 } 146 147 // isValidName checks for variable collision in scope. 148 func isValidName(name string, scopes []*types.Scope) bool { 149 for _, scope := range scopes { 150 if scope == nil { 151 continue 152 } 153 if scope.Lookup(name) != nil { 154 return false 155 } 156 } 157 return true 158 } 159 160 // returnVariable keeps track of the information we need to properly introduce a new variable 161 // that we will return in the extracted function. 162 type returnVariable struct { 163 // name is the identifier that is used on the left-hand side of the call to 164 // the extracted function. 165 name ast.Expr 166 // decl is the declaration of the variable. It is used in the type signature of the 167 // extracted function and for variable declarations. 168 decl *ast.Field 169 // zeroVal is the "zero value" of the type of the variable. It is used in a return 170 // statement in the extracted function. 171 zeroVal ast.Expr 172 } 173 174 // extractFunction refactors the selected block of code into a new function. 175 // It also replaces the selected block of code with a call to the extracted 176 // function. First, we manually adjust the selection range. We remove trailing 177 // and leading whitespace characters to ensure the range is precisely bounded 178 // by AST nodes. Next, we determine the variables that will be the parameters 179 // and return values of the extracted function. Lastly, we construct the call 180 // of the function and insert this call as well as the extracted function into 181 // their proper locations. 
func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
	p, ok, err := CanExtractFunction(fset, rng, src, file)
	if !ok {
		return nil, fmt.Errorf("extractFunction: cannot extract %s: %v",
			fset.Position(rng.Start), err)
	}
	// From here on rng is the whitespace-adjusted range computed by
	// CanExtractFunction, not the range that was passed in.
	tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start
	fileScope := info.Scopes[file]
	if fileScope == nil {
		return nil, fmt.Errorf("extractFunction: file scope is empty")
	}
	pkgScope := fileScope.Parent()
	if pkgScope == nil {
		return nil, fmt.Errorf("extractFunction: package scope is empty")
	}

	// A return statement is non-nested if its parent node is equal to the parent node
	// of the first node in the selection. These cases must be handled separately because
	// non-nested return statements are guaranteed to execute.
	var retStmts []*ast.ReturnStmt
	var hasNonNestedReturn bool
	startParent := findParent(outer, start)
	ast.Inspect(outer, func(n ast.Node) bool {
		if n == nil {
			return false
		}
		// Skip nodes that are not fully inside the selection, but keep
		// descending as long as the node starts no later than the selection end.
		if n.Pos() < rng.Start || n.End() > rng.End {
			return n.Pos() <= rng.End
		}
		ret, ok := n.(*ast.ReturnStmt)
		if !ok {
			return true
		}
		if findParent(outer, n) == startParent {
			hasNonNestedReturn = true
		}
		retStmts = append(retStmts, ret)
		return false
	})
	containsReturnStatement := len(retStmts) > 0

	// Now that we have determined the correct range for the selection block,
	// we must determine the signature of the extracted function. We will then replace
	// the block with an assignment statement that calls the extracted function with
	// the appropriate parameters and return values.
	variables, err := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0])
	if err != nil {
		return nil, err
	}

	var (
		params, returns         []ast.Expr     // used when calling the extracted function
		paramTypes, returnTypes []*ast.Field   // used in the signature of the extracted function
		uninitialized           []types.Object // vars we will need to initialize before the call
	)

	// Avoid duplicates while traversing vars and uninitialized.
	seenVars := make(map[types.Object]ast.Expr)
	seenUninitialized := make(map[types.Object]struct{})

	// Some variables on the left-hand side of our assignment statement may be free. If our
	// selection begins in the same scope in which the free variable is defined, we can
	// redefine it in our assignment statement. See the following example, where 'b' and
	// 'err' (both free variables) can be redefined in the second funcCall() while maintaining
	// correctness.
	//
	//
	// Not Redefined:
	//
	//	a, err := funcCall()
	//	var b int
	//	b, err = funcCall()
	//
	// Redefined:
	//
	//	a, err := funcCall()
	//	b, err := funcCall()
	//
	// We track the number of free variables that can be redefined to maintain our preference
	// of using "x, y, z := fn()" style assignment statements.
	var canRedefineCount int

	// Each identifier in the selected block must become (1) a parameter to the
	// extracted function, (2) a return value of the extracted function, or (3) a local
	// variable in the extracted function. Determine the outcome(s) for each variable
	// based on whether it is free, altered within the selected block, and used outside
	// of the selected block.
	for _, v := range variables {
		if _, ok := seenVars[v.obj]; ok {
			continue
		}
		if v.obj.Name() == "_" {
			// The blank identifier is always a local variable
			continue
		}
		typ := analysisinternal.TypeExpr(fset, file, pkg, v.obj.Type())
		if typ == nil {
			return nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name())
		}
		seenVars[v.obj] = typ
		identifier := ast.NewIdent(v.obj.Name())
		// An identifier must meet three conditions to become a return value of the
		// extracted function. (1) its value must be defined or reassigned within
		// the selection (v.assigned), (2) it must be used at least once after the
		// selection (isUsed), and (3) its first use after the selection
		// cannot be its own reassignment or redefinition (varOverridden).
		if v.obj.Parent() == nil {
			return nil, fmt.Errorf("parent nil")
		}
		// Look for uses between the end of the selection and the end of the
		// scope that declares the object.
		isUsed, firstUseAfter := objUsed(info, span.NewRange(fset, rng.End, v.obj.Parent().End()), v.obj)
		if v.assigned && isUsed && !varOverridden(info, firstUseAfter, v.obj, v.free, outer) {
			returnTypes = append(returnTypes, &ast.Field{Type: typ})
			returns = append(returns, identifier)
			if !v.free {
				uninitialized = append(uninitialized, v.obj)
			} else if v.obj.Parent().Pos() == startParent.Pos() {
				canRedefineCount++
			}
		}
		// An identifier must meet two conditions to become a parameter of the
		// extracted function. (1) it must be free (v.free), and (2) its first
		// use within the selection cannot be its own definition (v.defined).
		if v.free && !v.defined {
			params = append(params, identifier)
			paramTypes = append(paramTypes, &ast.Field{
				Names: []*ast.Ident{identifier},
				Type:  typ,
			})
		}
	}

	// Find the function literal that encloses the selection. The enclosing function literal
	// may not be the enclosing function declaration (i.e. 'outer'). For example, in the
	// following block:
	//
	//	func main() {
	//		ast.Inspect(node, func(n ast.Node) bool {
	//			v := 1 // this line extracted
	//			return true
	//		})
	//	}
	//
	// 'outer' is main(). However, the extracted selection most directly belongs to
	// the anonymous function literal, the second argument of ast.Inspect(). We use the
	// enclosing function literal to determine the proper return types for return statements
	// within the selection. We still need the enclosing function declaration because this is
	// the top-level declaration. We inspect the top-level declaration to look for variables
	// as well as for code replacement.
	enclosing := outer.Type
	for _, p := range path {
		if p == enclosing {
			break
		}
		if fl, ok := p.(*ast.FuncLit); ok {
			enclosing = fl.Type
			break
		}
	}

	// We put the selection in a constructed file. We can then traverse and edit
	// the extracted selection without modifying the original AST.
	startOffset := tok.Offset(rng.Start)
	endOffset := tok.Offset(rng.End)
	selection := src[startOffset:endOffset]
	extractedBlock, err := parseBlockStmt(fset, selection)
	if err != nil {
		return nil, err
	}

	// We need to account for return statements in the selected block, as they will complicate
	// the logical flow of the extracted function. See the following example, where ** denotes
	// the range to be extracted.
	//
	// Before:
	//
	//	func _() int {
	//		a := 1
	//		b := 2
	//		**if a == b {
	//			return a
	//		}**
	//		...
	//	}
	//
	// After:
	//
	//	func _() int {
	//		a := 1
	//		b := 2
	//		cond0, ret0 := fn0(a, b)
	//		if cond0 {
	//			return ret0
	//		}
	//		...
	//	}
	//
	//	func fn0(a int, b int) (bool, int) {
	//		if a == b {
	//			return true, a
	//		}
	//		return false, 0
	//	}
	//
	// We handle returns by adding an additional boolean return value to the extracted function.
	// This bool reports whether the original function would have returned. Because the
	// extracted selection contains a return statement, we must also add the types in the
	// return signature of the enclosing function to the return signature of the
	// extracted function. We then add an extra if statement checking this boolean value
	// in the original function. If the condition is met, the original function should
	// return a value, mimicking the functionality of the original return statement(s)
	// in the selection.
	//
	// If there is a return that is guaranteed to execute (hasNonNestedReturn=true), then
	// we don't need to include this additional condition check and can simply return.
	//
	// Before:
	//
	//	func _() int {
	//		a := 1
	//		b := 2
	//		**if a == b {
	//			return a
	//		}
	//		return b**
	//	}
	//
	// After:
	//
	//	func _() int {
	//		a := 1
	//		b := 2
	//		return fn0(a, b)
	//	}
	//
	//	func fn0(a int, b int) int {
	//		if a == b {
	//			return a
	//		}
	//		return b
	//	}

	var retVars []*returnVariable
	var ifReturn *ast.IfStmt
	if containsReturnStatement {
		if !hasNonNestedReturn {
			// The selected block contained return statements, so we have to modify the
			// signature of the extracted function as described above. Adjust all of
			// the return statements in the extracted function to reflect this change in
			// signature.
			if err := adjustReturnStatements(returnTypes, seenVars, fset, file,
				pkg, extractedBlock); err != nil {
				return nil, err
			}
		}
		// Collect the additional return values and types needed to accommodate return
		// statements in the selection. Update the type signature of the extracted
		// function and construct the if statement that will be inserted in the enclosing
		// function.
		retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start, hasNonNestedReturn)
		if err != nil {
			return nil, err
		}
	}

	// Add a return statement to the end of the new function. This return statement must include
	// the values for the types of the original extracted function signature and (if a return
	// statement is present in the selection) enclosing function signature.
	// This only needs to be done if the selection does not have a non-nested return, otherwise
	// it already terminates with a return statement.
	hasReturnValues := len(returns)+len(retVars) > 0
	if hasReturnValues && !hasNonNestedReturn {
		extractedBlock.List = append(extractedBlock.List, &ast.ReturnStmt{
			Results: append(returns, getZeroVals(retVars)...),
		})
	}

	// Construct the appropriate call to the extracted function.
	// We must meet two conditions to use ":=" instead of '='. (1) there must be at least
	// one variable on the lhs that is uninitialized (non-free) prior to the assignment.
	// (2) all of the initialized (free) variables on the lhs must be able to be redefined.
	sym := token.ASSIGN
	canDefineCount := len(uninitialized) + canRedefineCount
	canDefine := len(uninitialized)+len(retVars) > 0 && canDefineCount == len(returns)
	if canDefine {
		sym = token.DEFINE
	}
	funName := generateAvailableIdentifier(rng.Start, file, path, info, "fn", 0)
	extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params,
		append(returns, getNames(retVars)...), funName, sym)

	// Build the extracted function.
	newFunc := &ast.FuncDecl{
		Name: ast.NewIdent(funName),
		Type: &ast.FuncType{
			Params:  &ast.FieldList{List: paramTypes},
			Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)},
		},
		Body: extractedBlock,
	}

	// Create variable declarations for any identifiers that need to be initialized prior to
	// calling the extracted function. We do not manually initialize variables if every return
	// value is uninitialized. We can use := to initialize the variables in this situation.
	var declarations []ast.Stmt
	if canDefineCount != len(returns) {
		declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars)
	}

	// Print each generated piece separately so the pieces can be re-indented
	// and stitched into the enclosing function's source text below.
	var declBuf, replaceBuf, newFuncBuf, ifBuf, commentBuf bytes.Buffer
	if err := format.Node(&declBuf, fset, declarations); err != nil {
		return nil, err
	}
	if err := format.Node(&replaceBuf, fset, extractedFunCall); err != nil {
		return nil, err
	}
	if ifReturn != nil {
		if err := format.Node(&ifBuf, fset, ifReturn); err != nil {
			return nil, err
		}
	}
	if err := format.Node(&newFuncBuf, fset, newFunc); err != nil {
		return nil, err
	}
	// Find all the comments within the range and print them to be put somewhere.
	// TODO(suzmue): print these in the extracted function at the correct place.
	for _, cg := range file.Comments {
		if cg.Pos().IsValid() && cg.Pos() < rng.End && cg.Pos() >= rng.Start {
			for _, c := range cg.List {
				fmt.Fprintln(&commentBuf, c.Text)
			}
		}
	}

	// We're going to replace the whole enclosing function,
	// so preserve the text before and after the selected block.
	outerStart := tok.Offset(outer.Pos())
	outerEnd := tok.Offset(outer.End())
	before := src[outerStart:startOffset]
	after := src[endOffset:outerEnd]
	newLineIndent := "\n" + calculateIndentation(src, tok, start)

	// Replacement layout: text before the selection, collected comments,
	// declarations, the call, the optional if statement, text after the
	// selection, and finally the extracted function itself.
	var fullReplacement strings.Builder
	fullReplacement.Write(before)
	if commentBuf.Len() > 0 {
		comments := strings.ReplaceAll(commentBuf.String(), "\n", newLineIndent)
		fullReplacement.WriteString(comments)
	}
	if declBuf.Len() > 0 { // add any initializations, if needed
		initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) +
			newLineIndent
		fullReplacement.WriteString(initializations)
	}
	fullReplacement.Write(replaceBuf.Bytes()) // call the extracted function
	if ifBuf.Len() > 0 {                      // add the if statement below the function call, if needed
		ifstatement := newLineIndent +
			strings.ReplaceAll(ifBuf.String(), "\n", newLineIndent)
		fullReplacement.WriteString(ifstatement)
	}
	fullReplacement.Write(after)
	fullReplacement.WriteString("\n\n")       // add newlines after the enclosing function
	fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function

	return &analysis.SuggestedFix{
		TextEdits: []analysis.TextEdit{{
			Pos:     outer.Pos(),
			End:     outer.End(),
			NewText: []byte(fullReplacement.String()),
		}},
	}, nil
}

// adjustRangeForWhitespace adjusts the given range to exclude unnecessary leading or
// trailing whitespace characters from selection. In the following example, each line
// of the if statement is indented once. There are also two extra spaces after the
// closing bracket before the line break.
//
// \tif (true) {
// \t _ = 1
// \t}  \n
//
// By default, a valid range begins at 'if' and ends at the first whitespace character
// after the '}'. But, users are likely to highlight full lines rather than adjusting
// their cursors for whitespace.
To support this use case, we must manually adjust the 565 // ranges to match the correct AST node. In this particular example, we would adjust 566 // rng.Start forward by one byte, and rng.End backwards by two bytes. 567 func adjustRangeForWhitespace(rng span.Range, tok *token.File, content []byte) span.Range { 568 offset := tok.Offset(rng.Start) 569 for offset < len(content) { 570 if !unicode.IsSpace(rune(content[offset])) { 571 break 572 } 573 // Move forwards one byte to find a non-whitespace character. 574 offset += 1 575 } 576 rng.Start = tok.Pos(offset) 577 578 // Move backwards to find a non-whitespace character. 579 offset = tok.Offset(rng.End) 580 for o := offset - 1; 0 <= o && o < len(content); o-- { 581 if !unicode.IsSpace(rune(content[o])) { 582 break 583 } 584 offset = o 585 } 586 rng.End = tok.Pos(offset) 587 return rng 588 } 589 590 // findParent finds the parent AST node of the given target node, if the target is a 591 // descendant of the starting node. 592 func findParent(start ast.Node, target ast.Node) ast.Node { 593 var parent ast.Node 594 analysisinternal.WalkASTWithParent(start, func(n, p ast.Node) bool { 595 if n == target { 596 parent = p 597 return false 598 } 599 return true 600 }) 601 return parent 602 } 603 604 // variable describes the status of a variable within a selection. 605 type variable struct { 606 obj types.Object 607 608 // free reports whether the variable is a free variable, meaning it should 609 // be a parameter to the extracted function. 610 free bool 611 612 // assigned reports whether the variable is assigned to in the selection. 613 assigned bool 614 615 // defined reports whether the variable is defined in the selection. 616 defined bool 617 } 618 619 // collectFreeVars maps each identifier in the given range to whether it is "free." 620 // Given a range, a variable in that range is defined as "free" if it is declared 621 // outside of the range and neither at the file scope nor package scope. 
These free 622 // variables will be used as arguments in the extracted function. It also returns a 623 // list of identifiers that may need to be returned by the extracted function. 624 // Some of the code in this function has been adapted from tools/cmd/guru/freevars.go. 625 func collectFreeVars(info *types.Info, file *ast.File, fileScope, pkgScope *types.Scope, rng span.Range, node ast.Node) ([]*variable, error) { 626 // id returns non-nil if n denotes an object that is referenced by the span 627 // and defined either within the span or in the lexical environment. The bool 628 // return value acts as an indicator for where it was defined. 629 id := func(n *ast.Ident) (types.Object, bool) { 630 obj := info.Uses[n] 631 if obj == nil { 632 return info.Defs[n], false 633 } 634 if obj.Name() == "_" { 635 return nil, false // exclude objects denoting '_' 636 } 637 if _, ok := obj.(*types.PkgName); ok { 638 return nil, false // imported package 639 } 640 if !(file.Pos() <= obj.Pos() && obj.Pos() <= file.End()) { 641 return nil, false // not defined in this file 642 } 643 scope := obj.Parent() 644 if scope == nil { 645 return nil, false // e.g. interface method, struct field 646 } 647 if scope == fileScope || scope == pkgScope { 648 return nil, false // defined at file or package scope 649 } 650 if rng.Start <= obj.Pos() && obj.Pos() <= rng.End { 651 return obj, false // defined within selection => not free 652 } 653 return obj, true 654 } 655 // sel returns non-nil if n denotes a selection o.x.y that is referenced by the 656 // span and defined either within the span or in the lexical environment. The bool 657 // return value acts as an indicator for where it was defined. 
658 var sel func(n *ast.SelectorExpr) (types.Object, bool) 659 sel = func(n *ast.SelectorExpr) (types.Object, bool) { 660 switch x := astutil.Unparen(n.X).(type) { 661 case *ast.SelectorExpr: 662 return sel(x) 663 case *ast.Ident: 664 return id(x) 665 } 666 return nil, false 667 } 668 seen := make(map[types.Object]*variable) 669 firstUseIn := make(map[types.Object]token.Pos) 670 var vars []types.Object 671 ast.Inspect(node, func(n ast.Node) bool { 672 if n == nil { 673 return false 674 } 675 if rng.Start <= n.Pos() && n.End() <= rng.End { 676 var obj types.Object 677 var isFree, prune bool 678 switch n := n.(type) { 679 case *ast.Ident: 680 obj, isFree = id(n) 681 case *ast.SelectorExpr: 682 obj, isFree = sel(n) 683 prune = true 684 } 685 if obj != nil { 686 seen[obj] = &variable{ 687 obj: obj, 688 free: isFree, 689 } 690 vars = append(vars, obj) 691 // Find the first time that the object is used in the selection. 692 first, ok := firstUseIn[obj] 693 if !ok || n.Pos() < first { 694 firstUseIn[obj] = n.Pos() 695 } 696 if prune { 697 return false 698 } 699 } 700 } 701 return n.Pos() <= rng.End 702 }) 703 704 // Find identifiers that are initialized or whose values are altered at some 705 // point in the selected block. For example, in a selected block from lines 2-4, 706 // variables x, y, and z are included in assigned. However, in a selected block 707 // from lines 3-4, only variables y and z are included in assigned. 
708 // 709 // 1: var a int 710 // 2: var x int 711 // 3: y := 3 712 // 4: z := x + a 713 // 714 ast.Inspect(node, func(n ast.Node) bool { 715 if n == nil { 716 return false 717 } 718 if n.Pos() < rng.Start || n.End() > rng.End { 719 return n.Pos() <= rng.End 720 } 721 switch n := n.(type) { 722 case *ast.AssignStmt: 723 for _, assignment := range n.Lhs { 724 lhs, ok := assignment.(*ast.Ident) 725 if !ok { 726 continue 727 } 728 obj, _ := id(lhs) 729 if obj == nil { 730 continue 731 } 732 if _, ok := seen[obj]; !ok { 733 continue 734 } 735 seen[obj].assigned = true 736 if n.Tok != token.DEFINE { 737 continue 738 } 739 // Find identifiers that are defined prior to being used 740 // elsewhere in the selection. 741 // TODO: Include identifiers that are assigned prior to being 742 // used elsewhere in the selection. Then, change the assignment 743 // to a definition in the extracted function. 744 if firstUseIn[obj] != lhs.Pos() { 745 continue 746 } 747 // Ensure that the object is not used in its own re-definition. 
748 // For example: 749 // var f float64 750 // f, e := math.Frexp(f) 751 for _, expr := range n.Rhs { 752 if referencesObj(info, expr, obj) { 753 continue 754 } 755 if _, ok := seen[obj]; !ok { 756 continue 757 } 758 seen[obj].defined = true 759 break 760 } 761 } 762 return false 763 case *ast.DeclStmt: 764 gen, ok := n.Decl.(*ast.GenDecl) 765 if !ok { 766 return false 767 } 768 for _, spec := range gen.Specs { 769 vSpecs, ok := spec.(*ast.ValueSpec) 770 if !ok { 771 continue 772 } 773 for _, vSpec := range vSpecs.Names { 774 obj, _ := id(vSpec) 775 if obj == nil { 776 continue 777 } 778 if _, ok := seen[obj]; !ok { 779 continue 780 } 781 seen[obj].assigned = true 782 } 783 } 784 return false 785 case *ast.IncDecStmt: 786 if ident, ok := n.X.(*ast.Ident); !ok { 787 return false 788 } else if obj, _ := id(ident); obj == nil { 789 return false 790 } else { 791 if _, ok := seen[obj]; !ok { 792 return false 793 } 794 seen[obj].assigned = true 795 } 796 } 797 return true 798 }) 799 var variables []*variable 800 for _, obj := range vars { 801 v, ok := seen[obj] 802 if !ok { 803 return nil, fmt.Errorf("no seen types.Object for %v", obj) 804 } 805 variables = append(variables, v) 806 } 807 return variables, nil 808 } 809 810 // referencesObj checks whether the given object appears in the given expression. 811 func referencesObj(info *types.Info, expr ast.Expr, obj types.Object) bool { 812 var hasObj bool 813 ast.Inspect(expr, func(n ast.Node) bool { 814 if n == nil { 815 return false 816 } 817 ident, ok := n.(*ast.Ident) 818 if !ok { 819 return true 820 } 821 objUse := info.Uses[ident] 822 if obj == objUse { 823 hasObj = true 824 return false 825 } 826 return false 827 }) 828 return hasObj 829 } 830 831 type fnExtractParams struct { 832 tok *token.File 833 path []ast.Node 834 rng span.Range 835 outer *ast.FuncDecl 836 start ast.Node 837 } 838 839 // CanExtractFunction reports whether the code in the given range can be 840 // extracted to a function. 
841 func CanExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File) (*fnExtractParams, bool, error) { 842 if rng.Start == rng.End { 843 return nil, false, fmt.Errorf("start and end are equal") 844 } 845 tok := fset.File(file.Pos()) 846 if tok == nil { 847 return nil, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos())) 848 } 849 rng = adjustRangeForWhitespace(rng, tok, src) 850 path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) 851 if len(path) == 0 { 852 return nil, false, fmt.Errorf("no path enclosing interval") 853 } 854 // Node that encloses the selection must be a statement. 855 // TODO: Support function extraction for an expression. 856 _, ok := path[0].(ast.Stmt) 857 if !ok { 858 return nil, false, fmt.Errorf("node is not a statement") 859 } 860 861 // Find the function declaration that encloses the selection. 862 var outer *ast.FuncDecl 863 for _, p := range path { 864 if p, ok := p.(*ast.FuncDecl); ok { 865 outer = p 866 break 867 } 868 } 869 if outer == nil { 870 return nil, false, fmt.Errorf("no enclosing function") 871 } 872 873 // Find the nodes at the start and end of the selection. 874 var start, end ast.Node 875 ast.Inspect(outer, func(n ast.Node) bool { 876 if n == nil { 877 return false 878 } 879 // Do not override 'start' with a node that begins at the same location 880 // but is nested further from 'outer'. 881 if start == nil && n.Pos() == rng.Start && n.End() <= rng.End { 882 start = n 883 } 884 if end == nil && n.End() == rng.End && n.Pos() >= rng.Start { 885 end = n 886 } 887 return n.Pos() <= rng.End 888 }) 889 if start == nil || end == nil { 890 return nil, false, fmt.Errorf("range does not map to AST nodes") 891 } 892 return &fnExtractParams{ 893 tok: tok, 894 path: path, 895 rng: rng, 896 outer: outer, 897 start: start, 898 }, true, nil 899 } 900 901 // objUsed checks if the object is used within the range. 
It returns the first 902 // occurrence of the object in the range, if it exists. 903 func objUsed(info *types.Info, rng span.Range, obj types.Object) (bool, *ast.Ident) { 904 var firstUse *ast.Ident 905 for id, objUse := range info.Uses { 906 if obj != objUse { 907 continue 908 } 909 if id.Pos() < rng.Start || id.End() > rng.End { 910 continue 911 } 912 if firstUse == nil || id.Pos() < firstUse.Pos() { 913 firstUse = id 914 } 915 } 916 return firstUse != nil, firstUse 917 } 918 919 // varOverridden traverses the given AST node until we find the given identifier. Then, we 920 // examine the occurrence of the given identifier and check for (1) whether the identifier 921 // is being redefined. If the identifier is free, we also check for (2) whether the identifier 922 // is being reassigned. We will not include an identifier in the return statement of the 923 // extracted function if it meets one of the above conditions. 924 func varOverridden(info *types.Info, firstUse *ast.Ident, obj types.Object, isFree bool, node ast.Node) bool { 925 var isOverriden bool 926 ast.Inspect(node, func(n ast.Node) bool { 927 if n == nil { 928 return false 929 } 930 assignment, ok := n.(*ast.AssignStmt) 931 if !ok { 932 return true 933 } 934 // A free variable is initialized prior to the selection. We can always reassign 935 // this variable after the selection because it has already been defined. 936 // Conversely, a non-free variable is initialized within the selection. Thus, we 937 // cannot reassign this variable after the selection unless it is initialized and 938 // returned by the extracted function. 939 if !isFree && assignment.Tok == token.ASSIGN { 940 return false 941 } 942 for _, assigned := range assignment.Lhs { 943 ident, ok := assigned.(*ast.Ident) 944 // Check if we found the first use of the identifier. 
945 if !ok || ident != firstUse { 946 continue 947 } 948 objUse := info.Uses[ident] 949 if objUse == nil || objUse != obj { 950 continue 951 } 952 // Ensure that the object is not used in its own definition. 953 // For example: 954 // var f float64 955 // f, e := math.Frexp(f) 956 for _, expr := range assignment.Rhs { 957 if referencesObj(info, expr, obj) { 958 return false 959 } 960 } 961 isOverriden = true 962 return false 963 } 964 return false 965 }) 966 return isOverriden 967 } 968 969 // parseExtraction generates an AST file from the given text. We then return the portion of the 970 // file that represents the text. 971 func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) { 972 text := "package main\nfunc _() { " + string(src) + " }" 973 extract, err := parser.ParseFile(fset, "", text, 0) 974 if err != nil { 975 return nil, err 976 } 977 if len(extract.Decls) == 0 { 978 return nil, fmt.Errorf("parsed file does not contain any declarations") 979 } 980 decl, ok := extract.Decls[0].(*ast.FuncDecl) 981 if !ok { 982 return nil, fmt.Errorf("parsed file does not contain expected function declaration") 983 } 984 if decl.Body == nil { 985 return nil, fmt.Errorf("extracted function has no body") 986 } 987 return decl.Body, nil 988 } 989 990 // generateReturnInfo generates the information we need to adjust the return statements and 991 // signature of the extracted function. We prepare names, signatures, and "zero values" that 992 // represent the new variables. We also use this information to construct the if statement that 993 // is inserted below the call to the extracted function. 
994 func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, fset *token.FileSet, pos token.Pos, hasNonNestedReturns bool) ([]*returnVariable, *ast.IfStmt, error) { 995 var retVars []*returnVariable 996 var cond *ast.Ident 997 if !hasNonNestedReturns { 998 // Generate information for the added bool value. 999 cond = &ast.Ident{Name: generateAvailableIdentifier(pos, file, path, info, "cond", 0)} 1000 retVars = append(retVars, &returnVariable{ 1001 name: cond, 1002 decl: &ast.Field{Type: ast.NewIdent("bool")}, 1003 zeroVal: ast.NewIdent("false"), 1004 }) 1005 } 1006 // Generate information for the values in the return signature of the enclosing function. 1007 if enclosing.Results != nil { 1008 for i, field := range enclosing.Results.List { 1009 typ := info.TypeOf(field.Type) 1010 if typ == nil { 1011 return nil, nil, fmt.Errorf( 1012 "failed type conversion, AST expression: %T", field.Type) 1013 } 1014 expr := analysisinternal.TypeExpr(fset, file, pkg, typ) 1015 if expr == nil { 1016 return nil, nil, fmt.Errorf("nil AST expression") 1017 } 1018 retVars = append(retVars, &returnVariable{ 1019 name: ast.NewIdent(generateAvailableIdentifier(pos, file, 1020 path, info, "ret", i)), 1021 decl: &ast.Field{Type: expr}, 1022 zeroVal: analysisinternal.ZeroValue( 1023 fset, file, pkg, typ), 1024 }) 1025 } 1026 } 1027 var ifReturn *ast.IfStmt 1028 if !hasNonNestedReturns { 1029 // Create the return statement for the enclosing function. We must exclude the variable 1030 // for the condition of the if statement (cond) from the return statement. 1031 ifReturn = &ast.IfStmt{ 1032 Cond: cond, 1033 Body: &ast.BlockStmt{ 1034 List: []ast.Stmt{&ast.ReturnStmt{Results: getNames(retVars)[1:]}}, 1035 }, 1036 } 1037 } 1038 return retVars, ifReturn, nil 1039 } 1040 1041 // adjustReturnStatements adds "zero values" of the given types to each return statement 1042 // in the given AST node. 
1043 func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]ast.Expr, fset *token.FileSet, file *ast.File, pkg *types.Package, extractedBlock *ast.BlockStmt) error { 1044 var zeroVals []ast.Expr 1045 // Create "zero values" for each type. 1046 for _, returnType := range returnTypes { 1047 var val ast.Expr 1048 for obj, typ := range seenVars { 1049 if typ != returnType.Type { 1050 continue 1051 } 1052 val = analysisinternal.ZeroValue(fset, file, pkg, obj.Type()) 1053 break 1054 } 1055 if val == nil { 1056 return fmt.Errorf( 1057 "could not find matching AST expression for %T", returnType.Type) 1058 } 1059 zeroVals = append(zeroVals, val) 1060 } 1061 // Add "zero values" to each return statement. 1062 // The bool reports whether the enclosing function should return after calling the 1063 // extracted function. We set the bool to 'true' because, if these return statements 1064 // execute, the extracted function terminates early, and the enclosing function must 1065 // return as well. 1066 zeroVals = append(zeroVals, ast.NewIdent("true")) 1067 ast.Inspect(extractedBlock, func(n ast.Node) bool { 1068 if n == nil { 1069 return false 1070 } 1071 if n, ok := n.(*ast.ReturnStmt); ok { 1072 n.Results = append(zeroVals, n.Results...) 1073 return false 1074 } 1075 return true 1076 }) 1077 return nil 1078 } 1079 1080 // generateFuncCall constructs a call expression for the extracted function, described by the 1081 // given parameters and return variables. 1082 func generateFuncCall(hasNonNestedReturn, hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token) ast.Node { 1083 var replace ast.Node 1084 if hasReturnVals { 1085 callExpr := &ast.CallExpr{ 1086 Fun: ast.NewIdent(name), 1087 Args: params, 1088 } 1089 if hasNonNestedReturn { 1090 // Create a return statement that returns the result of the function call. 
1091 replace = &ast.ReturnStmt{ 1092 Return: 0, 1093 Results: []ast.Expr{callExpr}, 1094 } 1095 } else { 1096 // Assign the result of the function call. 1097 replace = &ast.AssignStmt{ 1098 Lhs: returns, 1099 Tok: token, 1100 Rhs: []ast.Expr{callExpr}, 1101 } 1102 } 1103 } else { 1104 replace = &ast.CallExpr{ 1105 Fun: ast.NewIdent(name), 1106 Args: params, 1107 } 1108 } 1109 return replace 1110 } 1111 1112 // initializeVars creates variable declarations, if needed. 1113 // Our preference is to replace the selected block with an "x, y, z := fn()" style 1114 // assignment statement. We can use this style when all of the variables in the 1115 // extracted function's return statement are either not defined prior to the extracted block 1116 // or can be safely redefined. However, for example, if z is already defined 1117 // in a different scope, we replace the selected block with: 1118 // 1119 // var x int 1120 // var y string 1121 // x, y, z = fn() 1122 func initializeVars(uninitialized []types.Object, retVars []*returnVariable, seenUninitialized map[types.Object]struct{}, seenVars map[types.Object]ast.Expr) []ast.Stmt { 1123 var declarations []ast.Stmt 1124 for _, obj := range uninitialized { 1125 if _, ok := seenUninitialized[obj]; ok { 1126 continue 1127 } 1128 seenUninitialized[obj] = struct{}{} 1129 valSpec := &ast.ValueSpec{ 1130 Names: []*ast.Ident{ast.NewIdent(obj.Name())}, 1131 Type: seenVars[obj], 1132 } 1133 genDecl := &ast.GenDecl{ 1134 Tok: token.VAR, 1135 Specs: []ast.Spec{valSpec}, 1136 } 1137 declarations = append(declarations, &ast.DeclStmt{Decl: genDecl}) 1138 } 1139 // Each variable added from a return statement in the selection 1140 // must be initialized. 
1141 for i, retVar := range retVars { 1142 n := retVar.name.(*ast.Ident) 1143 valSpec := &ast.ValueSpec{ 1144 Names: []*ast.Ident{n}, 1145 Type: retVars[i].decl.Type, 1146 } 1147 genDecl := &ast.GenDecl{ 1148 Tok: token.VAR, 1149 Specs: []ast.Spec{valSpec}, 1150 } 1151 declarations = append(declarations, &ast.DeclStmt{Decl: genDecl}) 1152 } 1153 return declarations 1154 } 1155 1156 // getNames returns the names from the given list of returnVariable. 1157 func getNames(retVars []*returnVariable) []ast.Expr { 1158 var names []ast.Expr 1159 for _, retVar := range retVars { 1160 names = append(names, retVar.name) 1161 } 1162 return names 1163 } 1164 1165 // getZeroVals returns the "zero values" from the given list of returnVariable. 1166 func getZeroVals(retVars []*returnVariable) []ast.Expr { 1167 var zvs []ast.Expr 1168 for _, retVar := range retVars { 1169 zvs = append(zvs, retVar.zeroVal) 1170 } 1171 return zvs 1172 } 1173 1174 // getDecls returns the declarations from the given list of returnVariable. 1175 func getDecls(retVars []*returnVariable) []*ast.Field { 1176 var decls []*ast.Field 1177 for _, retVar := range retVars { 1178 decls = append(decls, retVar.decl) 1179 } 1180 return decls 1181 }