github.com/jhump/golang-x-tools@v0.0.0-20220218190644-4958d6d39439/internal/lsp/source/extract.go (about) 1 // Copyright 2020 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package source 6 7 import ( 8 "bytes" 9 "fmt" 10 "go/ast" 11 "go/format" 12 "go/parser" 13 "go/token" 14 "go/types" 15 "strings" 16 "unicode" 17 18 "github.com/jhump/golang-x-tools/go/analysis" 19 "github.com/jhump/golang-x-tools/go/ast/astutil" 20 "github.com/jhump/golang-x-tools/internal/analysisinternal" 21 "github.com/jhump/golang-x-tools/internal/span" 22 ) 23 24 func extractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { 25 expr, path, ok, err := CanExtractVariable(rng, file) 26 if !ok { 27 return nil, fmt.Errorf("extractVariable: cannot extract %s: %v", fset.Position(rng.Start), err) 28 } 29 30 // Create new AST node for extracted code. 31 var lhsNames []string 32 switch expr := expr.(type) { 33 // TODO: stricter rules for selectorExpr. 34 case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr, 35 *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: 36 lhsName, _ := generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0) 37 lhsNames = append(lhsNames, lhsName) 38 case *ast.CallExpr: 39 tup, ok := info.TypeOf(expr).(*types.Tuple) 40 if !ok { 41 // If the call expression only has one return value, we can treat it the 42 // same as our standard extract variable case. 43 lhsName, _ := generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0) 44 lhsNames = append(lhsNames, lhsName) 45 break 46 } 47 idx := 0 48 for i := 0; i < tup.Len(); i++ { 49 // Generate a unique variable for each return value. 50 var lhsName string 51 lhsName, idx = generateAvailableIdentifier(expr.Pos(), file, path, info, "x", idx) 52 lhsNames = append(lhsNames, lhsName) 53 } 54 default: 55 return nil, fmt.Errorf("cannot extract %T", expr) 56 } 57 58 insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path) 59 if insertBeforeStmt == nil { 60 return nil, fmt.Errorf("cannot find location to insert extraction") 61 } 62 tok := fset.File(expr.Pos()) 63 if tok == nil { 64 return nil, fmt.Errorf("no file for pos %v", fset.Position(file.Pos())) 65 } 66 indent, err := calculateIndentation(src, tok, insertBeforeStmt) 67 if err != nil { 68 return nil, err 69 } 70 newLineIndent := "\n" + indent 71 72 lhs := strings.Join(lhsNames, ", ") 73 assignStmt := &ast.AssignStmt{ 74 Lhs: []ast.Expr{ast.NewIdent(lhs)}, 75 Tok: token.DEFINE, 76 Rhs: []ast.Expr{expr}, 77 } 78 var buf bytes.Buffer 79 if err := format.Node(&buf, fset, assignStmt); err != nil { 80 return nil, err 81 } 82 assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent 83 84 return &analysis.SuggestedFix{ 85 TextEdits: []analysis.TextEdit{ 86 { 87 Pos: insertBeforeStmt.Pos(), 88 End: insertBeforeStmt.Pos(), 89 NewText: []byte(assignment), 90 }, 91 { 92 Pos: rng.Start, 93 End: rng.End, 94 NewText: []byte(lhs), 95 }, 96 }, 97 }, nil 98 } 99 100 // CanExtractVariable reports whether the code in the given range can be 101 // extracted to a variable. 102 func CanExtractVariable(rng span.Range, file *ast.File) (ast.Expr, []ast.Node, bool, error) { 103 if rng.Start == rng.End { 104 return nil, nil, false, fmt.Errorf("start and end are equal") 105 } 106 path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) 107 if len(path) == 0 { 108 return nil, nil, false, fmt.Errorf("no path enclosing interval") 109 } 110 for _, n := range path { 111 if _, ok := n.(*ast.ImportSpec); ok { 112 return nil, nil, false, fmt.Errorf("cannot extract variable in an import block") 113 } 114 } 115 node := path[0] 116 if rng.Start != node.Pos() || rng.End != node.End() { 117 return nil, nil, false, fmt.Errorf("range does not map to an AST node") 118 } 119 expr, ok := node.(ast.Expr) 120 if !ok { 121 return nil, nil, false, fmt.Errorf("node is not an expression") 122 } 123 switch expr.(type) { 124 case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr, 125 *ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: 126 return expr, path, true, nil 127 } 128 return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr) 129 } 130 131 // Calculate indentation for insertion. 132 // When inserting lines of code, we must ensure that the lines have consistent 133 // formatting (i.e. the proper indentation). To do so, we observe the indentation on the 134 // line of code on which the insertion occurs. 135 func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.Node) (string, error) { 136 line := tok.Line(insertBeforeStmt.Pos()) 137 lineOffset, err := Offset(tok, tok.LineStart(line)) 138 if err != nil { 139 return "", err 140 } 141 stmtOffset, err := Offset(tok, insertBeforeStmt.Pos()) 142 if err != nil { 143 return "", err 144 } 145 return string(content[lineOffset:stmtOffset]), nil 146 } 147 148 // generateAvailableIdentifier adjusts the new function name until there are no collisons in scope. 149 // Possible collisions include other function and variable names. Returns the next index to check for prefix. 150 func generateAvailableIdentifier(pos token.Pos, file *ast.File, path []ast.Node, info *types.Info, prefix string, idx int) (string, int) { 151 scopes := CollectScopes(info, path, pos) 152 return generateIdentifier(idx, prefix, func(name string) bool { 153 return file.Scope.Lookup(name) != nil || !isValidName(name, scopes) 154 }) 155 } 156 157 func generateIdentifier(idx int, prefix string, hasCollision func(string) bool) (string, int) { 158 name := prefix 159 if idx != 0 { 160 name += fmt.Sprintf("%d", idx) 161 } 162 for hasCollision(name) { 163 idx++ 164 name = fmt.Sprintf("%v%d", prefix, idx) 165 } 166 return name, idx + 1 167 } 168 169 // isValidName checks for variable collision in scope. 170 func isValidName(name string, scopes []*types.Scope) bool { 171 for _, scope := range scopes { 172 if scope == nil { 173 continue 174 } 175 if scope.Lookup(name) != nil { 176 return false 177 } 178 } 179 return true 180 } 181 182 // returnVariable keeps track of the information we need to properly introduce a new variable 183 // that we will return in the extracted function. 184 type returnVariable struct { 185 // name is the identifier that is used on the left-hand side of the call to 186 // the extracted function. 187 name ast.Expr 188 // decl is the declaration of the variable. It is used in the type signature of the 189 // extracted function and for variable declarations. 190 decl *ast.Field 191 // zeroVal is the "zero value" of the type of the variable. It is used in a return 192 // statement in the extracted function. 193 zeroVal ast.Expr 194 } 195 196 // extractMethod refactors the selected block of code into a new method. 197 func extractMethod(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { 198 return extractFunctionMethod(fset, rng, src, file, pkg, info, true) 199 } 200 201 // extractFunction refactors the selected block of code into a new function. 202 func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { 203 return extractFunctionMethod(fset, rng, src, file, pkg, info, false) 204 } 205 206 // extractFunctionMethod refactors the selected block of code into a new function/method. 207 // It also replaces the selected block of code with a call to the extracted 208 // function. First, we manually adjust the selection range. We remove trailing 209 // and leading whitespace characters to ensure the range is precisely bounded 210 // by AST nodes. Next, we determine the variables that will be the parameters 211 // and return values of the extracted function/method. Lastly, we construct the call 212 // of the function/method and insert this call as well as the extracted function/method into 213 // their proper locations. 214 func extractFunctionMethod(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info, isMethod bool) (*analysis.SuggestedFix, error) { 215 errorPrefix := "extractFunction" 216 if isMethod { 217 errorPrefix = "extractMethod" 218 } 219 p, ok, methodOk, err := CanExtractFunction(fset, rng, src, file) 220 if (!ok && !isMethod) || (!methodOk && isMethod) { 221 return nil, fmt.Errorf("%s: cannot extract %s: %v", errorPrefix, 222 fset.Position(rng.Start), err) 223 } 224 tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start 225 fileScope := info.Scopes[file] 226 if fileScope == nil { 227 return nil, fmt.Errorf("%s: file scope is empty", errorPrefix) 228 } 229 pkgScope := fileScope.Parent() 230 if pkgScope == nil { 231 return nil, fmt.Errorf("%s: package scope is empty", errorPrefix) 232 } 233 234 // A return statement is non-nested if its parent node is equal to the parent node 235 // of the first node in the selection. These cases must be handled separately because 236 // non-nested return statements are guaranteed to execute. 237 var retStmts []*ast.ReturnStmt 238 var hasNonNestedReturn bool 239 startParent := findParent(outer, start) 240 ast.Inspect(outer, func(n ast.Node) bool { 241 if n == nil { 242 return false 243 } 244 if n.Pos() < rng.Start || n.End() > rng.End { 245 return n.Pos() <= rng.End 246 } 247 ret, ok := n.(*ast.ReturnStmt) 248 if !ok { 249 return true 250 } 251 if findParent(outer, n) == startParent { 252 hasNonNestedReturn = true 253 } 254 retStmts = append(retStmts, ret) 255 return false 256 }) 257 containsReturnStatement := len(retStmts) > 0 258 259 // Now that we have determined the correct range for the selection block, 260 // we must determine the signature of the extracted function. We will then replace 261 // the block with an assignment statement that calls the extracted function with 262 // the appropriate parameters and return values. 263 variables, err := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0]) 264 if err != nil { 265 return nil, err 266 } 267 268 var ( 269 receiverUsed bool 270 receiver *ast.Field 271 receiverName string 272 receiverObj types.Object 273 ) 274 if isMethod { 275 if outer == nil || outer.Recv == nil || len(outer.Recv.List) == 0 { 276 return nil, fmt.Errorf("%s: cannot extract need method receiver", errorPrefix) 277 } 278 receiver = outer.Recv.List[0] 279 if len(receiver.Names) == 0 || receiver.Names[0] == nil { 280 return nil, fmt.Errorf("%s: cannot extract need method receiver name", errorPrefix) 281 } 282 recvName := receiver.Names[0] 283 receiverName = recvName.Name 284 receiverObj = info.ObjectOf(recvName) 285 } 286 287 var ( 288 params, returns []ast.Expr // used when calling the extracted function 289 paramTypes, returnTypes []*ast.Field // used in the signature of the extracted function 290 uninitialized []types.Object // vars we will need to initialize before the call 291 ) 292 293 // Avoid duplicates while traversing vars and uninitialzed. 294 seenVars := make(map[types.Object]ast.Expr) 295 seenUninitialized := make(map[types.Object]struct{}) 296 297 // Some variables on the left-hand side of our assignment statement may be free. If our 298 // selection begins in the same scope in which the free variable is defined, we can 299 // redefine it in our assignment statement. See the following example, where 'b' and 300 // 'err' (both free variables) can be redefined in the second funcCall() while maintaining 301 // correctness. 302 // 303 // 304 // Not Redefined: 305 // 306 // a, err := funcCall() 307 // var b int 308 // b, err = funcCall() 309 // 310 // Redefined: 311 // 312 // a, err := funcCall() 313 // b, err := funcCall() 314 // 315 // We track the number of free variables that can be redefined to maintain our preference 316 // of using "x, y, z := fn()" style assignment statements. 317 var canRedefineCount int 318 319 // Each identifier in the selected block must become (1) a parameter to the 320 // extracted function, (2) a return value of the extracted function, or (3) a local 321 // variable in the extracted function. Determine the outcome(s) for each variable 322 // based on whether it is free, altered within the selected block, and used outside 323 // of the selected block. 324 for _, v := range variables { 325 if _, ok := seenVars[v.obj]; ok { 326 continue 327 } 328 if v.obj.Name() == "_" { 329 // The blank identifier is always a local variable 330 continue 331 } 332 typ := analysisinternal.TypeExpr(fset, file, pkg, v.obj.Type()) 333 if typ == nil { 334 return nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name()) 335 } 336 seenVars[v.obj] = typ 337 identifier := ast.NewIdent(v.obj.Name()) 338 // An identifier must meet three conditions to become a return value of the 339 // extracted function. (1) its value must be defined or reassigned within 340 // the selection (isAssigned), (2) it must be used at least once after the 341 // selection (isUsed), and (3) its first use after the selection 342 // cannot be its own reassignment or redefinition (objOverriden). 343 if v.obj.Parent() == nil { 344 return nil, fmt.Errorf("parent nil") 345 } 346 isUsed, firstUseAfter := objUsed(info, span.NewRange(fset, rng.End, v.obj.Parent().End()), v.obj) 347 if v.assigned && isUsed && !varOverridden(info, firstUseAfter, v.obj, v.free, outer) { 348 returnTypes = append(returnTypes, &ast.Field{Type: typ}) 349 returns = append(returns, identifier) 350 if !v.free { 351 uninitialized = append(uninitialized, v.obj) 352 } else if v.obj.Parent().Pos() == startParent.Pos() { 353 canRedefineCount++ 354 } 355 } 356 // An identifier must meet two conditions to become a parameter of the 357 // extracted function. (1) it must be free (isFree), and (2) its first 358 // use within the selection cannot be its own definition (isDefined). 359 if v.free && !v.defined { 360 // Skip the selector for a method. 361 if isMethod && v.obj == receiverObj { 362 receiverUsed = true 363 continue 364 } 365 params = append(params, identifier) 366 paramTypes = append(paramTypes, &ast.Field{ 367 Names: []*ast.Ident{identifier}, 368 Type: typ, 369 }) 370 } 371 } 372 373 // Find the function literal that encloses the selection. The enclosing function literal 374 // may not be the enclosing function declaration (i.e. 'outer'). For example, in the 375 // following block: 376 // 377 // func main() { 378 // ast.Inspect(node, func(n ast.Node) bool { 379 // v := 1 // this line extracted 380 // return true 381 // }) 382 // } 383 // 384 // 'outer' is main(). However, the extracted selection most directly belongs to 385 // the anonymous function literal, the second argument of ast.Inspect(). We use the 386 // enclosing function literal to determine the proper return types for return statements 387 // within the selection. We still need the enclosing function declaration because this is 388 // the top-level declaration. We inspect the top-level declaration to look for variables 389 // as well as for code replacement. 390 enclosing := outer.Type 391 for _, p := range path { 392 if p == enclosing { 393 break 394 } 395 if fl, ok := p.(*ast.FuncLit); ok { 396 enclosing = fl.Type 397 break 398 } 399 } 400 401 // We put the selection in a constructed file. We can then traverse and edit 402 // the extracted selection without modifying the original AST. 403 startOffset, err := Offset(tok, rng.Start) 404 if err != nil { 405 return nil, err 406 } 407 endOffset, err := Offset(tok, rng.End) 408 if err != nil { 409 return nil, err 410 } 411 selection := src[startOffset:endOffset] 412 extractedBlock, err := parseBlockStmt(fset, selection) 413 if err != nil { 414 return nil, err 415 } 416 417 // We need to account for return statements in the selected block, as they will complicate 418 // the logical flow of the extracted function. See the following example, where ** denotes 419 // the range to be extracted. 420 // 421 // Before: 422 // 423 // func _() int { 424 // a := 1 425 // b := 2 426 // **if a == b { 427 // return a 428 // }** 429 // ... 430 // } 431 // 432 // After: 433 // 434 // func _() int { 435 // a := 1 436 // b := 2 437 // cond0, ret0 := x0(a, b) 438 // if cond0 { 439 // return ret0 440 // } 441 // ... 442 // } 443 // 444 // func x0(a int, b int) (bool, int) { 445 // if a == b { 446 // return true, a 447 // } 448 // return false, 0 449 // } 450 // 451 // We handle returns by adding an additional boolean return value to the extracted function. 452 // This bool reports whether the original function would have returned. Because the 453 // extracted selection contains a return statement, we must also add the types in the 454 // return signature of the enclosing function to the return signature of the 455 // extracted function. We then add an extra if statement checking this boolean value 456 // in the original function. If the condition is met, the original function should 457 // return a value, mimicking the functionality of the original return statement(s) 458 // in the selection. 459 // 460 // If there is a return that is guaranteed to execute (hasNonNestedReturns=true), then 461 // we don't need to include this additional condition check and can simply return. 462 // 463 // Before: 464 // 465 // func _() int { 466 // a := 1 467 // b := 2 468 // **if a == b { 469 // return a 470 // } 471 // return b** 472 // } 473 // 474 // After: 475 // 476 // func _() int { 477 // a := 1 478 // b := 2 479 // return x0(a, b) 480 // } 481 // 482 // func x0(a int, b int) int { 483 // if a == b { 484 // return a 485 // } 486 // return b 487 // } 488 489 var retVars []*returnVariable 490 var ifReturn *ast.IfStmt 491 if containsReturnStatement { 492 if !hasNonNestedReturn { 493 // The selected block contained return statements, so we have to modify the 494 // signature of the extracted function as described above. Adjust all of 495 // the return statements in the extracted function to reflect this change in 496 // signature. 497 if err := adjustReturnStatements(returnTypes, seenVars, fset, file, 498 pkg, extractedBlock); err != nil { 499 return nil, err 500 } 501 } 502 // Collect the additional return values and types needed to accommodate return 503 // statements in the selection. Update the type signature of the extracted 504 // function and construct the if statement that will be inserted in the enclosing 505 // function. 506 retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start, hasNonNestedReturn) 507 if err != nil { 508 return nil, err 509 } 510 } 511 512 // Add a return statement to the end of the new function. This return statement must include 513 // the values for the types of the original extracted function signature and (if a return 514 // statement is present in the selection) enclosing function signature. 515 // This only needs to be done if the selections does not have a non-nested return, otherwise 516 // it already terminates with a return statement. 517 hasReturnValues := len(returns)+len(retVars) > 0 518 if hasReturnValues && !hasNonNestedReturn { 519 extractedBlock.List = append(extractedBlock.List, &ast.ReturnStmt{ 520 Results: append(returns, getZeroVals(retVars)...), 521 }) 522 } 523 524 // Construct the appropriate call to the extracted function. 525 // We must meet two conditions to use ":=" instead of '='. (1) there must be at least 526 // one variable on the lhs that is uninitailized (non-free) prior to the assignment. 527 // (2) all of the initialized (free) variables on the lhs must be able to be redefined. 528 sym := token.ASSIGN 529 canDefineCount := len(uninitialized) + canRedefineCount 530 canDefine := len(uninitialized)+len(retVars) > 0 && canDefineCount == len(returns) 531 if canDefine { 532 sym = token.DEFINE 533 } 534 var name, funName string 535 if isMethod { 536 name = "newMethod" 537 // TODO(suzmue): generate a name that does not conflict for "newMethod". 538 funName = name 539 } else { 540 name = "newFunction" 541 funName, _ = generateAvailableIdentifier(rng.Start, file, path, info, name, 0) 542 } 543 extractedFunCall := generateFuncCall(hasNonNestedReturn, hasReturnValues, params, 544 append(returns, getNames(retVars)...), funName, sym, receiverName) 545 546 // Build the extracted function. 547 newFunc := &ast.FuncDecl{ 548 Name: ast.NewIdent(funName), 549 Type: &ast.FuncType{ 550 Params: &ast.FieldList{List: paramTypes}, 551 Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)}, 552 }, 553 Body: extractedBlock, 554 } 555 if isMethod { 556 var names []*ast.Ident 557 if receiverUsed { 558 names = append(names, ast.NewIdent(receiverName)) 559 } 560 newFunc.Recv = &ast.FieldList{ 561 List: []*ast.Field{{ 562 Names: names, 563 Type: receiver.Type, 564 }}, 565 } 566 } 567 568 // Create variable declarations for any identifiers that need to be initialized prior to 569 // calling the extracted function. We do not manually initialize variables if every return 570 // value is unitialized. We can use := to initialize the variables in this situation. 571 var declarations []ast.Stmt 572 if canDefineCount != len(returns) { 573 declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars) 574 } 575 576 var declBuf, replaceBuf, newFuncBuf, ifBuf, commentBuf bytes.Buffer 577 if err := format.Node(&declBuf, fset, declarations); err != nil { 578 return nil, err 579 } 580 if err := format.Node(&replaceBuf, fset, extractedFunCall); err != nil { 581 return nil, err 582 } 583 if ifReturn != nil { 584 if err := format.Node(&ifBuf, fset, ifReturn); err != nil { 585 return nil, err 586 } 587 } 588 if err := format.Node(&newFuncBuf, fset, newFunc); err != nil { 589 return nil, err 590 } 591 // Find all the comments within the range and print them to be put somewhere. 592 // TODO(suzmue): print these in the extracted function at the correct place. 593 for _, cg := range file.Comments { 594 if cg.Pos().IsValid() && cg.Pos() < rng.End && cg.Pos() >= rng.Start { 595 for _, c := range cg.List { 596 fmt.Fprintln(&commentBuf, c.Text) 597 } 598 } 599 } 600 601 // We're going to replace the whole enclosing function, 602 // so preserve the text before and after the selected block. 603 outerStart, err := Offset(tok, outer.Pos()) 604 if err != nil { 605 return nil, err 606 } 607 outerEnd, err := Offset(tok, outer.End()) 608 if err != nil { 609 return nil, err 610 } 611 before := src[outerStart:startOffset] 612 after := src[endOffset:outerEnd] 613 indent, err := calculateIndentation(src, tok, start) 614 if err != nil { 615 return nil, err 616 } 617 newLineIndent := "\n" + indent 618 619 var fullReplacement strings.Builder 620 fullReplacement.Write(before) 621 if commentBuf.Len() > 0 { 622 comments := strings.ReplaceAll(commentBuf.String(), "\n", newLineIndent) 623 fullReplacement.WriteString(comments) 624 } 625 if declBuf.Len() > 0 { // add any initializations, if needed 626 initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) + 627 newLineIndent 628 fullReplacement.WriteString(initializations) 629 } 630 fullReplacement.Write(replaceBuf.Bytes()) // call the extracted function 631 if ifBuf.Len() > 0 { // add the if statement below the function call, if needed 632 ifstatement := newLineIndent + 633 strings.ReplaceAll(ifBuf.String(), "\n", newLineIndent) 634 fullReplacement.WriteString(ifstatement) 635 } 636 fullReplacement.Write(after) 637 fullReplacement.WriteString("\n\n") // add newlines after the enclosing function 638 fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function 639 640 return &analysis.SuggestedFix{ 641 TextEdits: []analysis.TextEdit{{ 642 Pos: outer.Pos(), 643 End: outer.End(), 644 NewText: []byte(fullReplacement.String()), 645 }}, 646 }, nil 647 } 648 649 // adjustRangeForWhitespace adjusts the given range to exclude unnecessary leading or 650 // trailing whitespace characters from selection. In the following example, each line 651 // of the if statement is indented once. There are also two extra spaces after the 652 // closing bracket before the line break. 653 // 654 // \tif (true) { 655 // \t _ = 1 656 // \t} \n 657 // 658 // By default, a valid range begins at 'if' and ends at the first whitespace character 659 // after the '}'. But, users are likely to highlight full lines rather than adjusting 660 // their cursors for whitespace. To support this use case, we must manually adjust the 661 // ranges to match the correct AST node. In this particular example, we would adjust 662 // rng.Start forward by one byte, and rng.End backwards by two bytes. 663 func adjustRangeForWhitespace(rng span.Range, tok *token.File, content []byte) (span.Range, error) { 664 offset, err := Offset(tok, rng.Start) 665 if err != nil { 666 return span.Range{}, err 667 } 668 for offset < len(content) { 669 if !unicode.IsSpace(rune(content[offset])) { 670 break 671 } 672 // Move forwards one byte to find a non-whitespace character. 673 offset += 1 674 } 675 rng.Start = tok.Pos(offset) 676 677 // Move backwards to find a non-whitespace character. 678 offset, err = Offset(tok, rng.End) 679 if err != nil { 680 return span.Range{}, err 681 } 682 for o := offset - 1; 0 <= o && o < len(content); o-- { 683 if !unicode.IsSpace(rune(content[o])) { 684 break 685 } 686 offset = o 687 } 688 rng.End = tok.Pos(offset) 689 return rng, nil 690 } 691 692 // findParent finds the parent AST node of the given target node, if the target is a 693 // descendant of the starting node. 694 func findParent(start ast.Node, target ast.Node) ast.Node { 695 var parent ast.Node 696 analysisinternal.WalkASTWithParent(start, func(n, p ast.Node) bool { 697 if n == target { 698 parent = p 699 return false 700 } 701 return true 702 }) 703 return parent 704 } 705 706 // variable describes the status of a variable within a selection. 707 type variable struct { 708 obj types.Object 709 710 // free reports whether the variable is a free variable, meaning it should 711 // be a parameter to the extracted function. 712 free bool 713 714 // assigned reports whether the variable is assigned to in the selection. 715 assigned bool 716 717 // defined reports whether the variable is defined in the selection. 718 defined bool 719 } 720 721 // collectFreeVars maps each identifier in the given range to whether it is "free." 722 // Given a range, a variable in that range is defined as "free" if it is declared 723 // outside of the range and neither at the file scope nor package scope. These free 724 // variables will be used as arguments in the extracted function. It also returns a 725 // list of identifiers that may need to be returned by the extracted function. 726 // Some of the code in this function has been adapted from tools/cmd/guru/freevars.go. 727 func collectFreeVars(info *types.Info, file *ast.File, fileScope, pkgScope *types.Scope, rng span.Range, node ast.Node) ([]*variable, error) { 728 // id returns non-nil if n denotes an object that is referenced by the span 729 // and defined either within the span or in the lexical environment. The bool 730 // return value acts as an indicator for where it was defined. 731 id := func(n *ast.Ident) (types.Object, bool) { 732 obj := info.Uses[n] 733 if obj == nil { 734 return info.Defs[n], false 735 } 736 if obj.Name() == "_" { 737 return nil, false // exclude objects denoting '_' 738 } 739 if _, ok := obj.(*types.PkgName); ok { 740 return nil, false // imported package 741 } 742 if !(file.Pos() <= obj.Pos() && obj.Pos() <= file.End()) { 743 return nil, false // not defined in this file 744 } 745 scope := obj.Parent() 746 if scope == nil { 747 return nil, false // e.g. interface method, struct field 748 } 749 if scope == fileScope || scope == pkgScope { 750 return nil, false // defined at file or package scope 751 } 752 if rng.Start <= obj.Pos() && obj.Pos() <= rng.End { 753 return obj, false // defined within selection => not free 754 } 755 return obj, true 756 } 757 // sel returns non-nil if n denotes a selection o.x.y that is referenced by the 758 // span and defined either within the span or in the lexical environment. The bool 759 // return value acts as an indicator for where it was defined. 760 var sel func(n *ast.SelectorExpr) (types.Object, bool) 761 sel = func(n *ast.SelectorExpr) (types.Object, bool) { 762 switch x := astutil.Unparen(n.X).(type) { 763 case *ast.SelectorExpr: 764 return sel(x) 765 case *ast.Ident: 766 return id(x) 767 } 768 return nil, false 769 } 770 seen := make(map[types.Object]*variable) 771 firstUseIn := make(map[types.Object]token.Pos) 772 var vars []types.Object 773 ast.Inspect(node, func(n ast.Node) bool { 774 if n == nil { 775 return false 776 } 777 if rng.Start <= n.Pos() && n.End() <= rng.End { 778 var obj types.Object 779 var isFree, prune bool 780 switch n := n.(type) { 781 case *ast.Ident: 782 obj, isFree = id(n) 783 case *ast.SelectorExpr: 784 obj, isFree = sel(n) 785 prune = true 786 } 787 if obj != nil { 788 seen[obj] = &variable{ 789 obj: obj, 790 free: isFree, 791 } 792 vars = append(vars, obj) 793 // Find the first time that the object is used in the selection. 794 first, ok := firstUseIn[obj] 795 if !ok || n.Pos() < first { 796 firstUseIn[obj] = n.Pos() 797 } 798 if prune { 799 return false 800 } 801 } 802 } 803 return n.Pos() <= rng.End 804 }) 805 806 // Find identifiers that are initialized or whose values are altered at some 807 // point in the selected block. For example, in a selected block from lines 2-4, 808 // variables x, y, and z are included in assigned. However, in a selected block 809 // from lines 3-4, only variables y and z are included in assigned. 810 // 811 // 1: var a int 812 // 2: var x int 813 // 3: y := 3 814 // 4: z := x + a 815 // 816 ast.Inspect(node, func(n ast.Node) bool { 817 if n == nil { 818 return false 819 } 820 if n.Pos() < rng.Start || n.End() > rng.End { 821 return n.Pos() <= rng.End 822 } 823 switch n := n.(type) { 824 case *ast.AssignStmt: 825 for _, assignment := range n.Lhs { 826 lhs, ok := assignment.(*ast.Ident) 827 if !ok { 828 continue 829 } 830 obj, _ := id(lhs) 831 if obj == nil { 832 continue 833 } 834 if _, ok := seen[obj]; !ok { 835 continue 836 } 837 seen[obj].assigned = true 838 if n.Tok != token.DEFINE { 839 continue 840 } 841 // Find identifiers that are defined prior to being used 842 // elsewhere in the selection. 843 // TODO: Include identifiers that are assigned prior to being 844 // used elsewhere in the selection. Then, change the assignment 845 // to a definition in the extracted function. 846 if firstUseIn[obj] != lhs.Pos() { 847 continue 848 } 849 // Ensure that the object is not used in its own re-definition. 850 // For example: 851 // var f float64 852 // f, e := math.Frexp(f) 853 for _, expr := range n.Rhs { 854 if referencesObj(info, expr, obj) { 855 continue 856 } 857 if _, ok := seen[obj]; !ok { 858 continue 859 } 860 seen[obj].defined = true 861 break 862 } 863 } 864 return false 865 case *ast.DeclStmt: 866 gen, ok := n.Decl.(*ast.GenDecl) 867 if !ok { 868 return false 869 } 870 for _, spec := range gen.Specs { 871 vSpecs, ok := spec.(*ast.ValueSpec) 872 if !ok { 873 continue 874 } 875 for _, vSpec := range vSpecs.Names { 876 obj, _ := id(vSpec) 877 if obj == nil { 878 continue 879 } 880 if _, ok := seen[obj]; !ok { 881 continue 882 } 883 seen[obj].assigned = true 884 } 885 } 886 return false 887 case *ast.IncDecStmt: 888 if ident, ok := n.X.(*ast.Ident); !ok { 889 return false 890 } else if obj, _ := id(ident); obj == nil { 891 return false 892 } else { 893 if _, ok := seen[obj]; !ok { 894 return false 895 } 896 seen[obj].assigned = true 897 } 898 } 899 return true 900 }) 901 var variables []*variable 902 for _, obj := range vars { 903 v, ok := seen[obj] 904 if !ok { 905 return nil, fmt.Errorf("no seen types.Object for %v", obj) 906 } 907 variables = append(variables, v) 908 } 909 return variables, nil 910 } 911 912 // referencesObj checks whether the given object appears in the given expression. 913 func referencesObj(info *types.Info, expr ast.Expr, obj types.Object) bool { 914 var hasObj bool 915 ast.Inspect(expr, func(n ast.Node) bool { 916 if n == nil { 917 return false 918 } 919 ident, ok := n.(*ast.Ident) 920 if !ok { 921 return true 922 } 923 objUse := info.Uses[ident] 924 if obj == objUse { 925 hasObj = true 926 return false 927 } 928 return false 929 }) 930 return hasObj 931 } 932 933 type fnExtractParams struct { 934 tok *token.File 935 path []ast.Node 936 rng span.Range 937 outer *ast.FuncDecl 938 start ast.Node 939 } 940 941 // CanExtractFunction reports whether the code in the given range can be 942 // extracted to a function. 943 func CanExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File) (*fnExtractParams, bool, bool, error) { 944 if rng.Start == rng.End { 945 return nil, false, false, fmt.Errorf("start and end are equal") 946 } 947 tok := fset.File(file.Pos()) 948 if tok == nil { 949 return nil, false, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos())) 950 } 951 var err error 952 rng, err = adjustRangeForWhitespace(rng, tok, src) 953 if err != nil { 954 return nil, false, false, err 955 } 956 path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) 957 if len(path) == 0 { 958 return nil, false, false, fmt.Errorf("no path enclosing interval") 959 } 960 // Node that encloses the selection must be a statement. 961 // TODO: Support function extraction for an expression. 962 _, ok := path[0].(ast.Stmt) 963 if !ok { 964 return nil, false, false, fmt.Errorf("node is not a statement") 965 } 966 967 // Find the function declaration that encloses the selection. 968 var outer *ast.FuncDecl 969 for _, p := range path { 970 if p, ok := p.(*ast.FuncDecl); ok { 971 outer = p 972 break 973 } 974 } 975 if outer == nil { 976 return nil, false, false, fmt.Errorf("no enclosing function") 977 } 978 979 // Find the nodes at the start and end of the selection. 980 var start, end ast.Node 981 ast.Inspect(outer, func(n ast.Node) bool { 982 if n == nil { 983 return false 984 } 985 // Do not override 'start' with a node that begins at the same location 986 // but is nested further from 'outer'. 987 if start == nil && n.Pos() == rng.Start && n.End() <= rng.End { 988 start = n 989 } 990 if end == nil && n.End() == rng.End && n.Pos() >= rng.Start { 991 end = n 992 } 993 return n.Pos() <= rng.End 994 }) 995 if start == nil || end == nil { 996 return nil, false, false, fmt.Errorf("range does not map to AST nodes") 997 } 998 // If the region is a blockStmt, use the first and last nodes in the block 999 // statement. 1000 // <rng.start>{ ... }<rng.end> => { <rng.start>...<rng.end> } 1001 if blockStmt, ok := start.(*ast.BlockStmt); ok { 1002 if len(blockStmt.List) == 0 { 1003 return nil, false, false, fmt.Errorf("range maps to empty block statement") 1004 } 1005 start, end = blockStmt.List[0], blockStmt.List[len(blockStmt.List)-1] 1006 rng.Start, rng.End = start.Pos(), end.End() 1007 } 1008 return &fnExtractParams{ 1009 tok: tok, 1010 path: path, 1011 rng: rng, 1012 outer: outer, 1013 start: start, 1014 }, true, outer.Recv != nil, nil 1015 } 1016 1017 // objUsed checks if the object is used within the range. It returns the first 1018 // occurrence of the object in the range, if it exists. 1019 func objUsed(info *types.Info, rng span.Range, obj types.Object) (bool, *ast.Ident) { 1020 var firstUse *ast.Ident 1021 for id, objUse := range info.Uses { 1022 if obj != objUse { 1023 continue 1024 } 1025 if id.Pos() < rng.Start || id.End() > rng.End { 1026 continue 1027 } 1028 if firstUse == nil || id.Pos() < firstUse.Pos() { 1029 firstUse = id 1030 } 1031 } 1032 return firstUse != nil, firstUse 1033 } 1034 1035 // varOverridden traverses the given AST node until we find the given identifier. Then, we 1036 // examine the occurrence of the given identifier and check for (1) whether the identifier 1037 // is being redefined. If the identifier is free, we also check for (2) whether the identifier 1038 // is being reassigned. We will not include an identifier in the return statement of the 1039 // extracted function if it meets one of the above conditions. 1040 func varOverridden(info *types.Info, firstUse *ast.Ident, obj types.Object, isFree bool, node ast.Node) bool { 1041 var isOverriden bool 1042 ast.Inspect(node, func(n ast.Node) bool { 1043 if n == nil { 1044 return false 1045 } 1046 assignment, ok := n.(*ast.AssignStmt) 1047 if !ok { 1048 return true 1049 } 1050 // A free variable is initialized prior to the selection. We can always reassign 1051 // this variable after the selection because it has already been defined. 1052 // Conversely, a non-free variable is initialized within the selection. Thus, we 1053 // cannot reassign this variable after the selection unless it is initialized and 1054 // returned by the extracted function. 1055 if !isFree && assignment.Tok == token.ASSIGN { 1056 return false 1057 } 1058 for _, assigned := range assignment.Lhs { 1059 ident, ok := assigned.(*ast.Ident) 1060 // Check if we found the first use of the identifier. 1061 if !ok || ident != firstUse { 1062 continue 1063 } 1064 objUse := info.Uses[ident] 1065 if objUse == nil || objUse != obj { 1066 continue 1067 } 1068 // Ensure that the object is not used in its own definition. 1069 // For example: 1070 // var f float64 1071 // f, e := math.Frexp(f) 1072 for _, expr := range assignment.Rhs { 1073 if referencesObj(info, expr, obj) { 1074 return false 1075 } 1076 } 1077 isOverriden = true 1078 return false 1079 } 1080 return false 1081 }) 1082 return isOverriden 1083 } 1084 1085 // parseExtraction generates an AST file from the given text. We then return the portion of the 1086 // file that represents the text. 1087 func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) { 1088 text := "package main\nfunc _() { " + string(src) + " }" 1089 extract, err := parser.ParseFile(fset, "", text, 0) 1090 if err != nil { 1091 return nil, err 1092 } 1093 if len(extract.Decls) == 0 { 1094 return nil, fmt.Errorf("parsed file does not contain any declarations") 1095 } 1096 decl, ok := extract.Decls[0].(*ast.FuncDecl) 1097 if !ok { 1098 return nil, fmt.Errorf("parsed file does not contain expected function declaration") 1099 } 1100 if decl.Body == nil { 1101 return nil, fmt.Errorf("extracted function has no body") 1102 } 1103 return decl.Body, nil 1104 } 1105 1106 // generateReturnInfo generates the information we need to adjust the return statements and 1107 // signature of the extracted function. We prepare names, signatures, and "zero values" that 1108 // represent the new variables. We also use this information to construct the if statement that 1109 // is inserted below the call to the extracted function. 1110 func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, fset *token.FileSet, pos token.Pos, hasNonNestedReturns bool) ([]*returnVariable, *ast.IfStmt, error) { 1111 var retVars []*returnVariable 1112 var cond *ast.Ident 1113 if !hasNonNestedReturns { 1114 // Generate information for the added bool value. 1115 name, _ := generateAvailableIdentifier(pos, file, path, info, "shouldReturn", 0) 1116 cond = &ast.Ident{Name: name} 1117 retVars = append(retVars, &returnVariable{ 1118 name: cond, 1119 decl: &ast.Field{Type: ast.NewIdent("bool")}, 1120 zeroVal: ast.NewIdent("false"), 1121 }) 1122 } 1123 // Generate information for the values in the return signature of the enclosing function. 1124 if enclosing.Results != nil { 1125 idx := 0 1126 for _, field := range enclosing.Results.List { 1127 typ := info.TypeOf(field.Type) 1128 if typ == nil { 1129 return nil, nil, fmt.Errorf( 1130 "failed type conversion, AST expression: %T", field.Type) 1131 } 1132 expr := analysisinternal.TypeExpr(fset, file, pkg, typ) 1133 if expr == nil { 1134 return nil, nil, fmt.Errorf("nil AST expression") 1135 } 1136 var name string 1137 name, idx = generateAvailableIdentifier(pos, file, 1138 path, info, "returnValue", idx) 1139 retVars = append(retVars, &returnVariable{ 1140 name: ast.NewIdent(name), 1141 decl: &ast.Field{Type: expr}, 1142 zeroVal: analysisinternal.ZeroValue( 1143 fset, file, pkg, typ), 1144 }) 1145 } 1146 } 1147 var ifReturn *ast.IfStmt 1148 if !hasNonNestedReturns { 1149 // Create the return statement for the enclosing function. We must exclude the variable 1150 // for the condition of the if statement (cond) from the return statement. 1151 ifReturn = &ast.IfStmt{ 1152 Cond: cond, 1153 Body: &ast.BlockStmt{ 1154 List: []ast.Stmt{&ast.ReturnStmt{Results: getNames(retVars)[1:]}}, 1155 }, 1156 } 1157 } 1158 return retVars, ifReturn, nil 1159 } 1160 1161 // adjustReturnStatements adds "zero values" of the given types to each return statement 1162 // in the given AST node. 1163 func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]ast.Expr, fset *token.FileSet, file *ast.File, pkg *types.Package, extractedBlock *ast.BlockStmt) error { 1164 var zeroVals []ast.Expr 1165 // Create "zero values" for each type. 1166 for _, returnType := range returnTypes { 1167 var val ast.Expr 1168 for obj, typ := range seenVars { 1169 if typ != returnType.Type { 1170 continue 1171 } 1172 val = analysisinternal.ZeroValue(fset, file, pkg, obj.Type()) 1173 break 1174 } 1175 if val == nil { 1176 return fmt.Errorf( 1177 "could not find matching AST expression for %T", returnType.Type) 1178 } 1179 zeroVals = append(zeroVals, val) 1180 } 1181 // Add "zero values" to each return statement. 1182 // The bool reports whether the enclosing function should return after calling the 1183 // extracted function. We set the bool to 'true' because, if these return statements 1184 // execute, the extracted function terminates early, and the enclosing function must 1185 // return as well. 1186 zeroVals = append(zeroVals, ast.NewIdent("true")) 1187 ast.Inspect(extractedBlock, func(n ast.Node) bool { 1188 if n == nil { 1189 return false 1190 } 1191 if n, ok := n.(*ast.ReturnStmt); ok { 1192 n.Results = append(zeroVals, n.Results...) 1193 return false 1194 } 1195 return true 1196 }) 1197 return nil 1198 } 1199 1200 // generateFuncCall constructs a call expression for the extracted function, described by the 1201 // given parameters and return variables. 1202 func generateFuncCall(hasNonNestedReturn, hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token, selector string) ast.Node { 1203 var replace ast.Node 1204 callExpr := &ast.CallExpr{ 1205 Fun: ast.NewIdent(name), 1206 Args: params, 1207 } 1208 if selector != "" { 1209 callExpr = &ast.CallExpr{ 1210 Fun: &ast.SelectorExpr{ 1211 X: ast.NewIdent(selector), 1212 Sel: ast.NewIdent(name), 1213 }, 1214 Args: params, 1215 } 1216 } 1217 if hasReturnVals { 1218 if hasNonNestedReturn { 1219 // Create a return statement that returns the result of the function call. 1220 replace = &ast.ReturnStmt{ 1221 Return: 0, 1222 Results: []ast.Expr{callExpr}, 1223 } 1224 } else { 1225 // Assign the result of the function call. 1226 replace = &ast.AssignStmt{ 1227 Lhs: returns, 1228 Tok: token, 1229 Rhs: []ast.Expr{callExpr}, 1230 } 1231 } 1232 } else { 1233 replace = callExpr 1234 } 1235 return replace 1236 } 1237 1238 // initializeVars creates variable declarations, if needed. 1239 // Our preference is to replace the selected block with an "x, y, z := fn()" style 1240 // assignment statement. We can use this style when all of the variables in the 1241 // extracted function's return statement are either not defined prior to the extracted block 1242 // or can be safely redefined. However, for example, if z is already defined 1243 // in a different scope, we replace the selected block with: 1244 // 1245 // var x int 1246 // var y string 1247 // x, y, z = fn() 1248 func initializeVars(uninitialized []types.Object, retVars []*returnVariable, seenUninitialized map[types.Object]struct{}, seenVars map[types.Object]ast.Expr) []ast.Stmt { 1249 var declarations []ast.Stmt 1250 for _, obj := range uninitialized { 1251 if _, ok := seenUninitialized[obj]; ok { 1252 continue 1253 } 1254 seenUninitialized[obj] = struct{}{} 1255 valSpec := &ast.ValueSpec{ 1256 Names: []*ast.Ident{ast.NewIdent(obj.Name())}, 1257 Type: seenVars[obj], 1258 } 1259 genDecl := &ast.GenDecl{ 1260 Tok: token.VAR, 1261 Specs: []ast.Spec{valSpec}, 1262 } 1263 declarations = append(declarations, &ast.DeclStmt{Decl: genDecl}) 1264 } 1265 // Each variable added from a return statement in the selection 1266 // must be initialized. 1267 for i, retVar := range retVars { 1268 n := retVar.name.(*ast.Ident) 1269 valSpec := &ast.ValueSpec{ 1270 Names: []*ast.Ident{n}, 1271 Type: retVars[i].decl.Type, 1272 } 1273 genDecl := &ast.GenDecl{ 1274 Tok: token.VAR, 1275 Specs: []ast.Spec{valSpec}, 1276 } 1277 declarations = append(declarations, &ast.DeclStmt{Decl: genDecl}) 1278 } 1279 return declarations 1280 } 1281 1282 // getNames returns the names from the given list of returnVariable. 1283 func getNames(retVars []*returnVariable) []ast.Expr { 1284 var names []ast.Expr 1285 for _, retVar := range retVars { 1286 names = append(names, retVar.name) 1287 } 1288 return names 1289 } 1290 1291 // getZeroVals returns the "zero values" from the given list of returnVariable. 1292 func getZeroVals(retVars []*returnVariable) []ast.Expr { 1293 var zvs []ast.Expr 1294 for _, retVar := range retVars { 1295 zvs = append(zvs, retVar.zeroVal) 1296 } 1297 return zvs 1298 } 1299 1300 // getDecls returns the declarations from the given list of returnVariable. 1301 func getDecls(retVars []*returnVariable) []*ast.Field { 1302 var decls []*ast.Field 1303 for _, retVar := range retVars { 1304 decls = append(decls, retVar.decl) 1305 } 1306 return decls 1307 }