github.com/jd-ly/tools@v0.5.7/internal/lsp/source/extract.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/format"
    12  	"go/parser"
    13  	"go/token"
    14  	"go/types"
    15  	"strings"
    16  	"unicode"
    17  
    18  	"github.com/jd-ly/tools/go/analysis"
    19  	"github.com/jd-ly/tools/go/ast/astutil"
    20  	"github.com/jd-ly/tools/internal/analysisinternal"
    21  	"github.com/jd-ly/tools/internal/span"
    22  )
    23  
    24  func extractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
    25  	expr, path, ok, err := canExtractVariable(rng, file)
    26  	if !ok {
    27  		return nil, fmt.Errorf("extractVariable: cannot extract %s: %v", fset.Position(rng.Start), err)
    28  	}
    29  
    30  	// Create new AST node for extracted code.
    31  	var lhsNames []string
    32  	switch expr := expr.(type) {
    33  	// TODO: stricter rules for selectorExpr.
    34  	case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr,
    35  		*ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
    36  		lhsNames = append(lhsNames, generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0))
    37  	case *ast.CallExpr:
    38  		tup, ok := info.TypeOf(expr).(*types.Tuple)
    39  		if !ok {
    40  			// If the call expression only has one return value, we can treat it the
    41  			// same as our standard extract variable case.
    42  			lhsNames = append(lhsNames,
    43  				generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0))
    44  			break
    45  		}
    46  		for i := 0; i < tup.Len(); i++ {
    47  			// Generate a unique variable for each return value.
    48  			lhsNames = append(lhsNames,
    49  				generateAvailableIdentifier(expr.Pos(), file, path, info, "x", i))
    50  		}
    51  	default:
    52  		return nil, fmt.Errorf("cannot extract %T", expr)
    53  	}
    54  
    55  	insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path)
    56  	if insertBeforeStmt == nil {
    57  		return nil, fmt.Errorf("cannot find location to insert extraction")
    58  	}
    59  	tok := fset.File(expr.Pos())
    60  	if tok == nil {
    61  		return nil, fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
    62  	}
    63  	newLineIndent := "\n" + calculateIndentation(src, tok, insertBeforeStmt)
    64  
    65  	lhs := strings.Join(lhsNames, ", ")
    66  	assignStmt := &ast.AssignStmt{
    67  		Lhs: []ast.Expr{ast.NewIdent(lhs)},
    68  		Tok: token.DEFINE,
    69  		Rhs: []ast.Expr{expr},
    70  	}
    71  	var buf bytes.Buffer
    72  	if err := format.Node(&buf, fset, assignStmt); err != nil {
    73  		return nil, err
    74  	}
    75  	assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent
    76  
    77  	return &analysis.SuggestedFix{
    78  		TextEdits: []analysis.TextEdit{
    79  			{
    80  				Pos:     rng.Start,
    81  				End:     rng.End,
    82  				NewText: []byte(lhs),
    83  			},
    84  			{
    85  				Pos:     insertBeforeStmt.Pos(),
    86  				End:     insertBeforeStmt.Pos(),
    87  				NewText: []byte(assignment),
    88  			},
    89  		},
    90  	}, nil
    91  }
    92  
    93  // canExtractVariable reports whether the code in the given range can be
    94  // extracted to a variable.
    95  func canExtractVariable(rng span.Range, file *ast.File) (ast.Expr, []ast.Node, bool, error) {
    96  	if rng.Start == rng.End {
    97  		return nil, nil, false, fmt.Errorf("start and end are equal")
    98  	}
    99  	path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
   100  	if len(path) == 0 {
   101  		return nil, nil, false, fmt.Errorf("no path enclosing interval")
   102  	}
   103  	for _, n := range path {
   104  		if _, ok := n.(*ast.ImportSpec); ok {
   105  			return nil, nil, false, fmt.Errorf("cannot extract variable in an import block")
   106  		}
   107  	}
   108  	node := path[0]
   109  	if rng.Start != node.Pos() || rng.End != node.End() {
   110  		return nil, nil, false, fmt.Errorf("range does not map to an AST node")
   111  	}
   112  	expr, ok := node.(ast.Expr)
   113  	if !ok {
   114  		return nil, nil, false, fmt.Errorf("node is not an expression")
   115  	}
   116  	switch expr.(type) {
   117  	case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr,
   118  		*ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
   119  		return expr, path, true, nil
   120  	}
   121  	return nil, nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
   122  }
   123  
   124  // Calculate indentation for insertion.
   125  // When inserting lines of code, we must ensure that the lines have consistent
   126  // formatting (i.e. the proper indentation). To do so, we observe the indentation on the
   127  // line of code on which the insertion occurs.
   128  func calculateIndentation(content []byte, tok *token.File, insertBeforeStmt ast.Node) string {
   129  	line := tok.Line(insertBeforeStmt.Pos())
   130  	lineOffset := tok.Offset(tok.LineStart(line))
   131  	stmtOffset := tok.Offset(insertBeforeStmt.Pos())
   132  	return string(content[lineOffset:stmtOffset])
   133  }
   134  
   135  // generateAvailableIdentifier adjusts the new function name until there are no collisons in scope.
   136  // Possible collisions include other function and variable names.
   137  func generateAvailableIdentifier(pos token.Pos, file *ast.File, path []ast.Node, info *types.Info, prefix string, idx int) string {
   138  	scopes := CollectScopes(info, path, pos)
   139  	name := prefix + fmt.Sprintf("%d", idx)
   140  	for file.Scope.Lookup(name) != nil || !isValidName(name, scopes) {
   141  		idx++
   142  		name = fmt.Sprintf("%v%d", prefix, idx)
   143  	}
   144  	return name
   145  }
   146  
   147  // isValidName checks for variable collision in scope.
   148  func isValidName(name string, scopes []*types.Scope) bool {
   149  	for _, scope := range scopes {
   150  		if scope == nil {
   151  			continue
   152  		}
   153  		if scope.Lookup(name) != nil {
   154  			return false
   155  		}
   156  	}
   157  	return true
   158  }
   159  
// returnVariable keeps track of the information we need to properly introduce a new variable
// that we will return in the extracted function. extractFunction obtains a slice of these
// from generateReturnInfo when the selection contains return statements.
type returnVariable struct {
	// name is the identifier that is used on the left-hand side of the call to
	// the extracted function.
	name ast.Expr
	// decl is the declaration of the variable. It is used in the type signature of the
	// extracted function and for variable declarations.
	decl *ast.Field
	// zeroVal is the "zero value" of the type of the variable. It is used in a return
	// statement in the extracted function.
	zeroVal ast.Expr
}
   173  
// extractFunction refactors the selected block of code into a new function.
// It also replaces the selected block of code with a call to the extracted
// function. First, we manually adjust the selection range. We remove trailing
// and leading whitespace characters to ensure the range is precisely bounded
// by AST nodes. Next, we determine the variables that will be the parameters
// and return values of the extracted function. Lastly, we construct the call
// of the function and insert this call as well as the extracted function into
// their proper locations.
//
// The returned fix contains a single text edit that replaces the whole
// enclosing function declaration with its rewritten text followed by the
// newly extracted function. An error is returned if the selection cannot be
// extracted, if it contains a non-nested return statement, or if any of the
// generated nodes cannot be built or formatted.
func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) {
	p, ok, err := canExtractFunction(fset, rng, src, file, info)
	if !ok {
		return nil, fmt.Errorf("extractFunction: cannot extract %s: %v",
			fset.Position(rng.Start), err)
	}
	// Note that rng is overwritten here: from now on it is the range reported
	// by canExtractFunction, not the range passed in by the caller.
	tok, path, rng, outer, start := p.tok, p.path, p.rng, p.outer, p.start
	fileScope := info.Scopes[file]
	if fileScope == nil {
		return nil, fmt.Errorf("extractFunction: file scope is empty")
	}
	pkgScope := fileScope.Parent()
	if pkgScope == nil {
		return nil, fmt.Errorf("extractFunction: package scope is empty")
	}

	// TODO: Support non-nested return statements.
	// A return statement is non-nested if its parent node is equal to the parent node
	// of the first node in the selection. These cases must be handled separately because
	// non-nested return statements are guaranteed to execute. Our control flow does not
	// properly consider these situations yet.
	var retStmts []*ast.ReturnStmt
	var hasNonNestedReturn bool
	startParent := findParent(outer, start)
	ast.Inspect(outer, func(n ast.Node) bool {
		if n == nil {
			return false
		}
		// Skip nodes that are not fully inside the selection, and stop
		// descending once a node starts after the end of the selection.
		if n.Pos() < rng.Start || n.End() > rng.End {
			return n.Pos() <= rng.End
		}
		ret, ok := n.(*ast.ReturnStmt)
		if !ok {
			return true
		}
		if findParent(outer, n) == startParent {
			hasNonNestedReturn = true
			return false
		}
		retStmts = append(retStmts, ret)
		return false
	})
	if hasNonNestedReturn {
		return nil, fmt.Errorf("extractFunction: selected block contains non-nested return")
	}
	containsReturnStatement := len(retStmts) > 0

	// Now that we have determined the correct range for the selection block,
	// we must determine the signature of the extracted function. We will then replace
	// the block with an assignment statement that calls the extracted function with
	// the appropriate parameters and return values.
	variables, err := collectFreeVars(info, file, fileScope, pkgScope, rng, path[0])
	if err != nil {
		return nil, err
	}

	var (
		params, returns         []ast.Expr     // used when calling the extracted function
		paramTypes, returnTypes []*ast.Field   // used in the signature of the extracted function
		uninitialized           []types.Object // vars we will need to initialize before the call
	)

	// Avoid duplicates while traversing vars and uninitialized.
	seenVars := make(map[types.Object]ast.Expr)
	seenUninitialized := make(map[types.Object]struct{})

	// Some variables on the left-hand side of our assignment statement may be free. If our
	// selection begins in the same scope in which the free variable is defined, we can
	// redefine it in our assignment statement. See the following example, where 'b' and
	// 'err' (both free variables) can be redefined in the second funcCall() while maintaining
	// correctness.
	//
	//
	// Not Redefined:
	//
	// a, err := funcCall()
	// var b int
	// b, err = funcCall()
	//
	// Redefined:
	//
	// a, err := funcCall()
	// b, err := funcCall()
	//
	// We track the number of free variables that can be redefined to maintain our preference
	// of using "x, y, z := fn()" style assignment statements.
	var canRedefineCount int

	// Each identifier in the selected block must become (1) a parameter to the
	// extracted function, (2) a return value of the extracted function, or (3) a local
	// variable in the extracted function. Determine the outcome(s) for each variable
	// based on whether it is free, altered within the selected block, and used outside
	// of the selected block.
	for _, v := range variables {
		if _, ok := seenVars[v.obj]; ok {
			continue
		}
		typ := analysisinternal.TypeExpr(fset, file, pkg, v.obj.Type())
		if typ == nil {
			return nil, fmt.Errorf("nil AST expression for type: %v", v.obj.Name())
		}
		seenVars[v.obj] = typ
		identifier := ast.NewIdent(v.obj.Name())
		// An identifier must meet three conditions to become a return value of the
		// extracted function. (1) its value must be defined or reassigned within
		// the selection (v.assigned), (2) it must be used at least once after the
		// selection (isUsed), and (3) its first use after the selection
		// cannot be its own reassignment or redefinition (varOverridden).
		if v.obj.Parent() == nil {
			return nil, fmt.Errorf("parent nil")
		}
		// Look for uses between the end of the selection and the end of the
		// scope in which the object is declared.
		isUsed, firstUseAfter := objUsed(info, span.NewRange(fset, rng.End, v.obj.Parent().End()), v.obj)
		if v.assigned && isUsed && !varOverridden(info, firstUseAfter, v.obj, v.free, outer) {
			returnTypes = append(returnTypes, &ast.Field{Type: typ})
			returns = append(returns, identifier)
			if !v.free {
				uninitialized = append(uninitialized, v.obj)
			} else if v.obj.Parent().Pos() == startParent.Pos() {
				canRedefineCount++
			}
		}
		// An identifier must meet two conditions to become a parameter of the
		// extracted function. (1) it must be free (v.free), and (2) its first
		// use within the selection cannot be its own definition (v.defined).
		if v.free && !v.defined {
			params = append(params, identifier)
			paramTypes = append(paramTypes, &ast.Field{
				Names: []*ast.Ident{identifier},
				Type:  typ,
			})
		}
	}

	// Find the function literal that encloses the selection. The enclosing function literal
	// may not be the enclosing function declaration (i.e. 'outer'). For example, in the
	// following block:
	//
	// func main() {
	//     ast.Inspect(node, func(n ast.Node) bool {
	//         v := 1 // this line extracted
	//         return true
	//     })
	// }
	//
	// 'outer' is main(). However, the extracted selection most directly belongs to
	// the anonymous function literal, the second argument of ast.Inspect(). We use the
	// enclosing function literal to determine the proper return types for return statements
	// within the selection. We still need the enclosing function declaration because this is
	// the top-level declaration. We inspect the top-level declaration to look for variables
	// as well as for code replacement.
	enclosing := outer.Type
	for _, p := range path {
		if p == enclosing {
			break
		}
		if fl, ok := p.(*ast.FuncLit); ok {
			enclosing = fl.Type
			break
		}
	}

	// We put the selection in a constructed file. We can then traverse and edit
	// the extracted selection without modifying the original AST.
	startOffset := tok.Offset(rng.Start)
	endOffset := tok.Offset(rng.End)
	selection := src[startOffset:endOffset]
	extractedBlock, err := parseBlockStmt(fset, selection)
	if err != nil {
		return nil, err
	}

	// We need to account for return statements in the selected block, as they will complicate
	// the logical flow of the extracted function. See the following example, where ** denotes
	// the range to be extracted.
	//
	// Before:
	//
	// func _() int {
	//     a := 1
	//     b := 2
	//     **if a == b {
	//         return a
	//     }**
	//     ...
	// }
	//
	// After:
	//
	// func _() int {
	//     a := 1
	//     b := 2
	//     cond0, ret0 := x0(a, b)
	//     if cond0 {
	//         return ret0
	//     }
	//     ...
	// }
	//
	// func x0(a int, b int) (bool, int) {
	//     if a == b {
	//         return true, a
	//     }
	//     return false, 0
	// }
	//
	// We handle returns by adding an additional boolean return value to the extracted function.
	// This bool reports whether the original function would have returned. Because the
	// extracted selection contains a return statement, we must also add the types in the
	// return signature of the enclosing function to the return signature of the
	// extracted function. We then add an extra if statement checking this boolean value
	// in the original function. If the condition is met, the original function should
	// return a value, mimicking the functionality of the original return statement(s)
	// in the selection.

	var retVars []*returnVariable
	var ifReturn *ast.IfStmt
	if containsReturnStatement {
		// The selected block contained return statements, so we have to modify the
		// signature of the extracted function as described above. Adjust all of
		// the return statements in the extracted function to reflect this change in
		// signature.
		if err := adjustReturnStatements(returnTypes, seenVars, fset, file,
			pkg, extractedBlock); err != nil {
			return nil, err
		}
		// Collect the additional return values and types needed to accommodate return
		// statements in the selection. Update the type signature of the extracted
		// function and construct the if statement that will be inserted in the enclosing
		// function.
		retVars, ifReturn, err = generateReturnInfo(enclosing, pkg, path, file, info, fset, rng.Start)
		if err != nil {
			return nil, err
		}
	}

	// Add a return statement to the end of the new function. This return statement must include
	// the values for the types of the original extracted function signature and (if a return
	// statement is present in the selection) enclosing function signature.
	hasReturnValues := len(returns)+len(retVars) > 0
	if hasReturnValues {
		extractedBlock.List = append(extractedBlock.List, &ast.ReturnStmt{
			Results: append(returns, getZeroVals(retVars)...),
		})
	}

	// Construct the appropriate call to the extracted function.
	// We must meet two conditions to use ":=" instead of '='. (1) there must be at least
	// one variable on the lhs that is uninitialized (non-free) prior to the assignment.
	// (2) all of the initialized (free) variables on the lhs must be able to be redefined.
	sym := token.ASSIGN
	canDefineCount := len(uninitialized) + canRedefineCount
	canDefine := len(uninitialized)+len(retVars) > 0 && canDefineCount == len(returns)
	if canDefine {
		sym = token.DEFINE
	}
	funName := generateAvailableIdentifier(rng.Start, file, path, info, "fn", 0)
	extractedFunCall := generateFuncCall(hasReturnValues, params,
		append(returns, getNames(retVars)...), funName, sym)

	// Build the extracted function.
	newFunc := &ast.FuncDecl{
		Name: ast.NewIdent(funName),
		Type: &ast.FuncType{
			Params:  &ast.FieldList{List: paramTypes},
			Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)},
		},
		Body: extractedBlock,
	}

	// Create variable declarations for any identifiers that need to be initialized prior to
	// calling the extracted function. We do not manually initialize variables if every return
	// value is uninitialized. We can use := to initialize the variables in this situation.
	var declarations []ast.Stmt
	if canDefineCount != len(returns) {
		declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars)
	}

	// Render every generated node to text; the pieces are stitched together
	// with the untouched source of the enclosing function below.
	var declBuf, replaceBuf, newFuncBuf, ifBuf bytes.Buffer
	if err := format.Node(&declBuf, fset, declarations); err != nil {
		return nil, err
	}
	if err := format.Node(&replaceBuf, fset, extractedFunCall); err != nil {
		return nil, err
	}
	if ifReturn != nil {
		if err := format.Node(&ifBuf, fset, ifReturn); err != nil {
			return nil, err
		}
	}
	if err := format.Node(&newFuncBuf, fset, newFunc); err != nil {
		return nil, err
	}

	// We're going to replace the whole enclosing function,
	// so preserve the text before and after the selected block.
	outerStart := tok.Offset(outer.Pos())
	outerEnd := tok.Offset(outer.End())
	before := src[outerStart:startOffset]
	after := src[endOffset:outerEnd]
	newLineIndent := "\n" + calculateIndentation(src, tok, start)

	var fullReplacement strings.Builder
	fullReplacement.Write(before)
	if declBuf.Len() > 0 { // add any initializations, if needed
		initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) +
			newLineIndent
		fullReplacement.WriteString(initializations)
	}
	fullReplacement.Write(replaceBuf.Bytes()) // call the extracted function
	if ifBuf.Len() > 0 {                      // add the if statement below the function call, if needed
		ifstatement := newLineIndent +
			strings.ReplaceAll(ifBuf.String(), "\n", newLineIndent)
		fullReplacement.WriteString(ifstatement)
	}
	fullReplacement.Write(after)
	fullReplacement.WriteString("\n\n")       // add newlines after the enclosing function
	fullReplacement.Write(newFuncBuf.Bytes()) // insert the extracted function

	return &analysis.SuggestedFix{
		TextEdits: []analysis.TextEdit{{
			Pos:     outer.Pos(),
			End:     outer.End(),
			NewText: []byte(fullReplacement.String()),
		}},
	}, nil
}
   508  
   509  // adjustRangeForWhitespace adjusts the given range to exclude unnecessary leading or
   510  // trailing whitespace characters from selection. In the following example, each line
   511  // of the if statement is indented once. There are also two extra spaces after the
   512  // closing bracket before the line break.
   513  //
   514  // \tif (true) {
   515  // \t    _ = 1
   516  // \t}  \n
   517  //
   518  // By default, a valid range begins at 'if' and ends at the first whitespace character
   519  // after the '}'. But, users are likely to highlight full lines rather than adjusting
   520  // their cursors for whitespace. To support this use case, we must manually adjust the
   521  // ranges to match the correct AST node. In this particular example, we would adjust
   522  // rng.Start forward by one byte, and rng.End backwards by two bytes.
   523  func adjustRangeForWhitespace(rng span.Range, tok *token.File, content []byte) span.Range {
   524  	offset := tok.Offset(rng.Start)
   525  	for offset < len(content) {
   526  		if !unicode.IsSpace(rune(content[offset])) {
   527  			break
   528  		}
   529  		// Move forwards one byte to find a non-whitespace character.
   530  		offset += 1
   531  	}
   532  	rng.Start = tok.Pos(offset)
   533  
   534  	// Move backwards to find a non-whitespace character.
   535  	offset = tok.Offset(rng.End)
   536  	for o := offset - 1; 0 <= o && o < len(content); o-- {
   537  		if !unicode.IsSpace(rune(content[o])) {
   538  			break
   539  		}
   540  		offset = o
   541  	}
   542  	rng.End = tok.Pos(offset)
   543  	return rng
   544  }
   545  
   546  // findParent finds the parent AST node of the given target node, if the target is a
   547  // descendant of the starting node.
   548  func findParent(start ast.Node, target ast.Node) ast.Node {
   549  	var parent ast.Node
   550  	analysisinternal.WalkASTWithParent(start, func(n, p ast.Node) bool {
   551  		if n == target {
   552  			parent = p
   553  			return false
   554  		}
   555  		return true
   556  	})
   557  	return parent
   558  }
   559  
// variable describes the status of a variable within a selection. Values of
// this type are produced by collectFreeVars and consumed by extractFunction to
// decide whether the variable becomes a parameter and/or a return value of the
// extracted function.
type variable struct {
	// obj is the type-checker object that the variable denotes.
	obj types.Object

	// free reports whether the variable is a free variable, meaning it should
	// be a parameter to the extracted function.
	free bool

	// assigned reports whether the variable is assigned to in the selection.
	assigned bool

	// defined reports whether the variable is defined in the selection.
	defined bool
}
   574  
// collectFreeVars maps each identifier in the given range to whether it is "free."
// Given a range, a variable in that range is defined as "free" if it is declared
// outside of the range and neither at the file scope nor package scope. These free
// variables will be used as arguments in the extracted function. It also returns a
// list of identifiers that may need to be returned by the extracted function.
// Some of the code in this function has been adapted from tools/cmd/guru/freevars.go.
//
// The result holds one entry per occurrence found inside the range, in
// traversal order, so the same *variable may appear more than once. Each entry
// additionally records whether the variable is assigned and whether it is
// defined within the range.
func collectFreeVars(info *types.Info, file *ast.File, fileScope, pkgScope *types.Scope, rng span.Range, node ast.Node) ([]*variable, error) {
	// id returns non-nil if n denotes an object that is referenced by the span
	// and defined either within the span or in the lexical environment. The bool
	// return value acts as an indicator for where it was defined: it is true
	// only when the object is declared outside of the selection (i.e. it is free).
	id := func(n *ast.Ident) (types.Object, bool) {
		obj := info.Uses[n]
		if obj == nil {
			// n is not a use; report the object it defines (if any) as not free.
			return info.Defs[n], false
		}
		if obj.Name() == "_" {
			return nil, false // exclude objects denoting '_'
		}
		if _, ok := obj.(*types.PkgName); ok {
			return nil, false // imported package
		}
		if !(file.Pos() <= obj.Pos() && obj.Pos() <= file.End()) {
			return nil, false // not defined in this file
		}
		scope := obj.Parent()
		if scope == nil {
			return nil, false // e.g. interface method, struct field
		}
		if scope == fileScope || scope == pkgScope {
			return nil, false // defined at file or package scope
		}
		if rng.Start <= obj.Pos() && obj.Pos() <= rng.End {
			return obj, false // defined within selection => not free
		}
		return obj, true
	}
	// sel returns non-nil if n denotes a selection o.x.y that is referenced by the
	// span and defined either within the span or in the lexical environment. The bool
	// return value acts as an indicator for where it was defined. Only the
	// leftmost identifier of the selector chain is examined.
	var sel func(n *ast.SelectorExpr) (types.Object, bool)
	sel = func(n *ast.SelectorExpr) (types.Object, bool) {
		switch x := astutil.Unparen(n.X).(type) {
		case *ast.SelectorExpr:
			return sel(x)
		case *ast.Ident:
			return id(x)
		}
		return nil, false
	}
	// seen holds the status record of every object found in the selection,
	// firstUseIn the position of its earliest occurrence there, and vars the
	// objects in the order in which they were encountered.
	seen := make(map[types.Object]*variable)
	firstUseIn := make(map[types.Object]token.Pos)
	var vars []types.Object
	ast.Inspect(node, func(n ast.Node) bool {
		if n == nil {
			return false
		}
		if rng.Start <= n.Pos() && n.End() <= rng.End {
			var obj types.Object
			var isFree, prune bool
			switch n := n.(type) {
			case *ast.Ident:
				obj, isFree = id(n)
			case *ast.SelectorExpr:
				obj, isFree = sel(n)
				// The selector has been resolved as a whole; do not visit its parts.
				prune = true
			}
			if obj != nil {
				seen[obj] = &variable{
					obj:  obj,
					free: isFree,
				}
				vars = append(vars, obj)
				// Find the first time that the object is used in the selection.
				first, ok := firstUseIn[obj]
				if !ok || n.Pos() < first {
					firstUseIn[obj] = n.Pos()
				}
				if prune {
					return false
				}
			}
		}
		// Stop descending into nodes that begin after the selection.
		return n.Pos() <= rng.End
	})

	// Find identifiers that are initialized or whose values are altered at some
	// point in the selected block. For example, in a selected block from lines 2-4,
	// variables x, y, and z are included in assigned. However, in a selected block
	// from lines 3-4, only variables y and z are included in assigned.
	//
	// 1: var a int
	// 2: var x int
	// 3: y := 3
	// 4: z := x + a
	//
	ast.Inspect(node, func(n ast.Node) bool {
		if n == nil {
			return false
		}
		// Only nodes that lie completely inside the selection are considered.
		if n.Pos() < rng.Start || n.End() > rng.End {
			return n.Pos() <= rng.End
		}
		switch n := n.(type) {
		case *ast.AssignStmt:
			for _, assignment := range n.Lhs {
				lhs, ok := assignment.(*ast.Ident)
				if !ok {
					continue
				}
				obj, _ := id(lhs)
				if obj == nil {
					continue
				}
				if _, ok := seen[obj]; !ok {
					continue
				}
				seen[obj].assigned = true
				if n.Tok != token.DEFINE {
					continue
				}
				// Find identifiers that are defined prior to being used
				// elsewhere in the selection.
				// TODO: Include identifiers that are assigned prior to being
				// used elsewhere in the selection. Then, change the assignment
				// to a definition in the extracted function.
				if firstUseIn[obj] != lhs.Pos() {
					continue
				}
				// Ensure that the object is not used in its own re-definition.
				// For example:
				// var f float64
				// f, e := math.Frexp(f)
				for _, expr := range n.Rhs {
					if referencesObj(info, expr, obj) {
						continue
					}
					if _, ok := seen[obj]; !ok {
						continue
					}
					seen[obj].defined = true
					break
				}
			}
			return false
		case *ast.DeclStmt:
			gen, ok := n.Decl.(*ast.GenDecl)
			if !ok {
				return false
			}
			// Every name introduced by a value spec counts as assigned.
			for _, spec := range gen.Specs {
				vSpecs, ok := spec.(*ast.ValueSpec)
				if !ok {
					continue
				}
				for _, vSpec := range vSpecs.Names {
					obj, _ := id(vSpec)
					if obj == nil {
						continue
					}
					if _, ok := seen[obj]; !ok {
						continue
					}
					seen[obj].assigned = true
				}
			}
			return false
		case *ast.IncDecStmt:
			// x++ and x-- alter x, provided that x is a plain identifier.
			if ident, ok := n.X.(*ast.Ident); !ok {
				return false
			} else if obj, _ := id(ident); obj == nil {
				return false
			} else {
				if _, ok := seen[obj]; !ok {
					return false
				}
				seen[obj].assigned = true
			}
		}
		return true
	})
	// Report the status records in the order in which the objects were found.
	var variables []*variable
	for _, obj := range vars {
		v, ok := seen[obj]
		if !ok {
			return nil, fmt.Errorf("no seen types.Object for %v", obj)
		}
		variables = append(variables, v)
	}
	return variables, nil
}
   765  
   766  // referencesObj checks whether the given object appears in the given expression.
   767  func referencesObj(info *types.Info, expr ast.Expr, obj types.Object) bool {
   768  	var hasObj bool
   769  	ast.Inspect(expr, func(n ast.Node) bool {
   770  		if n == nil {
   771  			return false
   772  		}
   773  		ident, ok := n.(*ast.Ident)
   774  		if !ok {
   775  			return true
   776  		}
   777  		objUse := info.Uses[ident]
   778  		if obj == objUse {
   779  			hasObj = true
   780  			return false
   781  		}
   782  		return false
   783  	})
   784  	return hasObj
   785  }
   786  
   787  type fnExtractParams struct {
   788  	tok   *token.File
   789  	path  []ast.Node
   790  	rng   span.Range
   791  	outer *ast.FuncDecl
   792  	start ast.Node
   793  }
   794  
   795  // canExtractFunction reports whether the code in the given range can be
   796  // extracted to a function.
   797  func canExtractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, _ *types.Info) (*fnExtractParams, bool, error) {
   798  	if rng.Start == rng.End {
   799  		return nil, false, fmt.Errorf("start and end are equal")
   800  	}
   801  	tok := fset.File(file.Pos())
   802  	if tok == nil {
   803  		return nil, false, fmt.Errorf("no file for pos %v", fset.Position(file.Pos()))
   804  	}
   805  	rng = adjustRangeForWhitespace(rng, tok, src)
   806  	path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
   807  	if len(path) == 0 {
   808  		return nil, false, fmt.Errorf("no path enclosing interval")
   809  	}
   810  	// Node that encloses the selection must be a statement.
   811  	// TODO: Support function extraction for an expression.
   812  	_, ok := path[0].(ast.Stmt)
   813  	if !ok {
   814  		return nil, false, fmt.Errorf("node is not a statement")
   815  	}
   816  
   817  	// Find the function declaration that encloses the selection.
   818  	var outer *ast.FuncDecl
   819  	for _, p := range path {
   820  		if p, ok := p.(*ast.FuncDecl); ok {
   821  			outer = p
   822  			break
   823  		}
   824  	}
   825  	if outer == nil {
   826  		return nil, false, fmt.Errorf("no enclosing function")
   827  	}
   828  
   829  	// Find the nodes at the start and end of the selection.
   830  	var start, end ast.Node
   831  	ast.Inspect(outer, func(n ast.Node) bool {
   832  		if n == nil {
   833  			return false
   834  		}
   835  		// Do not override 'start' with a node that begins at the same location
   836  		// but is nested further from 'outer'.
   837  		if start == nil && n.Pos() == rng.Start && n.End() <= rng.End {
   838  			start = n
   839  		}
   840  		if end == nil && n.End() == rng.End && n.Pos() >= rng.Start {
   841  			end = n
   842  		}
   843  		return n.Pos() <= rng.End
   844  	})
   845  	if start == nil || end == nil {
   846  		return nil, false, fmt.Errorf("range does not map to AST nodes")
   847  	}
   848  	return &fnExtractParams{
   849  		tok:   tok,
   850  		path:  path,
   851  		rng:   rng,
   852  		outer: outer,
   853  		start: start,
   854  	}, true, nil
   855  }
   856  
   857  // objUsed checks if the object is used within the range. It returns the first
   858  // occurrence of the object in the range, if it exists.
   859  func objUsed(info *types.Info, rng span.Range, obj types.Object) (bool, *ast.Ident) {
   860  	var firstUse *ast.Ident
   861  	for id, objUse := range info.Uses {
   862  		if obj != objUse {
   863  			continue
   864  		}
   865  		if id.Pos() < rng.Start || id.End() > rng.End {
   866  			continue
   867  		}
   868  		if firstUse == nil || id.Pos() < firstUse.Pos() {
   869  			firstUse = id
   870  		}
   871  	}
   872  	return firstUse != nil, firstUse
   873  }
   874  
   875  // varOverridden traverses the given AST node until we find the given identifier. Then, we
   876  // examine the occurrence of the given identifier and check for (1) whether the identifier
   877  // is being redefined. If the identifier is free, we also check for (2) whether the identifier
   878  // is being reassigned. We will not include an identifier in the return statement of the
   879  // extracted function if it meets one of the above conditions.
   880  func varOverridden(info *types.Info, firstUse *ast.Ident, obj types.Object, isFree bool, node ast.Node) bool {
   881  	var isOverriden bool
   882  	ast.Inspect(node, func(n ast.Node) bool {
   883  		if n == nil {
   884  			return false
   885  		}
   886  		assignment, ok := n.(*ast.AssignStmt)
   887  		if !ok {
   888  			return true
   889  		}
   890  		// A free variable is initialized prior to the selection. We can always reassign
   891  		// this variable after the selection because it has already been defined.
   892  		// Conversely, a non-free variable is initialized within the selection. Thus, we
   893  		// cannot reassign this variable after the selection unless it is initialized and
   894  		// returned by the extracted function.
   895  		if !isFree && assignment.Tok == token.ASSIGN {
   896  			return false
   897  		}
   898  		for _, assigned := range assignment.Lhs {
   899  			ident, ok := assigned.(*ast.Ident)
   900  			// Check if we found the first use of the identifier.
   901  			if !ok || ident != firstUse {
   902  				continue
   903  			}
   904  			objUse := info.Uses[ident]
   905  			if objUse == nil || objUse != obj {
   906  				continue
   907  			}
   908  			// Ensure that the object is not used in its own definition.
   909  			// For example:
   910  			// var f float64
   911  			// f, e := math.Frexp(f)
   912  			for _, expr := range assignment.Rhs {
   913  				if referencesObj(info, expr, obj) {
   914  					return false
   915  				}
   916  			}
   917  			isOverriden = true
   918  			return false
   919  		}
   920  		return false
   921  	})
   922  	return isOverriden
   923  }
   924  
   925  // parseExtraction generates an AST file from the given text. We then return the portion of the
   926  // file that represents the text.
   927  func parseBlockStmt(fset *token.FileSet, src []byte) (*ast.BlockStmt, error) {
   928  	text := "package main\nfunc _() { " + string(src) + " }"
   929  	extract, err := parser.ParseFile(fset, "", text, 0)
   930  	if err != nil {
   931  		return nil, err
   932  	}
   933  	if len(extract.Decls) == 0 {
   934  		return nil, fmt.Errorf("parsed file does not contain any declarations")
   935  	}
   936  	decl, ok := extract.Decls[0].(*ast.FuncDecl)
   937  	if !ok {
   938  		return nil, fmt.Errorf("parsed file does not contain expected function declaration")
   939  	}
   940  	if decl.Body == nil {
   941  		return nil, fmt.Errorf("extracted function has no body")
   942  	}
   943  	return decl.Body, nil
   944  }
   945  
   946  // generateReturnInfo generates the information we need to adjust the return statements and
   947  // signature of the extracted function. We prepare names, signatures, and "zero values" that
   948  // represent the new variables. We also use this information to construct the if statement that
   949  // is inserted below the call to the extracted function.
   950  func generateReturnInfo(enclosing *ast.FuncType, pkg *types.Package, path []ast.Node, file *ast.File, info *types.Info, fset *token.FileSet, pos token.Pos) ([]*returnVariable, *ast.IfStmt, error) {
   951  	// Generate information for the added bool value.
   952  	cond := &ast.Ident{Name: generateAvailableIdentifier(pos, file, path, info, "cond", 0)}
   953  	retVars := []*returnVariable{
   954  		{
   955  			name:    cond,
   956  			decl:    &ast.Field{Type: ast.NewIdent("bool")},
   957  			zeroVal: ast.NewIdent("false"),
   958  		},
   959  	}
   960  	// Generate information for the values in the return signature of the enclosing function.
   961  	if enclosing.Results != nil {
   962  		for i, field := range enclosing.Results.List {
   963  			typ := info.TypeOf(field.Type)
   964  			if typ == nil {
   965  				return nil, nil, fmt.Errorf(
   966  					"failed type conversion, AST expression: %T", field.Type)
   967  			}
   968  			expr := analysisinternal.TypeExpr(fset, file, pkg, typ)
   969  			if expr == nil {
   970  				return nil, nil, fmt.Errorf("nil AST expression")
   971  			}
   972  			retVars = append(retVars, &returnVariable{
   973  				name: ast.NewIdent(generateAvailableIdentifier(pos, file,
   974  					path, info, "ret", i)),
   975  				decl: &ast.Field{Type: expr},
   976  				zeroVal: analysisinternal.ZeroValue(
   977  					fset, file, pkg, typ),
   978  			})
   979  		}
   980  	}
   981  	// Create the return statement for the enclosing function. We must exclude the variable
   982  	// for the condition of the if statement (cond) from the return statement.
   983  	ifReturn := &ast.IfStmt{
   984  		Cond: cond,
   985  		Body: &ast.BlockStmt{
   986  			List: []ast.Stmt{&ast.ReturnStmt{Results: getNames(retVars)[1:]}},
   987  		},
   988  	}
   989  	return retVars, ifReturn, nil
   990  }
   991  
   992  // adjustReturnStatements adds "zero values" of the given types to each return statement
   993  // in the given AST node.
   994  func adjustReturnStatements(returnTypes []*ast.Field, seenVars map[types.Object]ast.Expr, fset *token.FileSet, file *ast.File, pkg *types.Package, extractedBlock *ast.BlockStmt) error {
   995  	var zeroVals []ast.Expr
   996  	// Create "zero values" for each type.
   997  	for _, returnType := range returnTypes {
   998  		var val ast.Expr
   999  		for obj, typ := range seenVars {
  1000  			if typ != returnType.Type {
  1001  				continue
  1002  			}
  1003  			val = analysisinternal.ZeroValue(fset, file, pkg, obj.Type())
  1004  			break
  1005  		}
  1006  		if val == nil {
  1007  			return fmt.Errorf(
  1008  				"could not find matching AST expression for %T", returnType.Type)
  1009  		}
  1010  		zeroVals = append(zeroVals, val)
  1011  	}
  1012  	// Add "zero values" to each return statement.
  1013  	// The bool reports whether the enclosing function should return after calling the
  1014  	// extracted function. We set the bool to 'true' because, if these return statements
  1015  	// execute, the extracted function terminates early, and the enclosing function must
  1016  	// return as well.
  1017  	zeroVals = append(zeroVals, ast.NewIdent("true"))
  1018  	ast.Inspect(extractedBlock, func(n ast.Node) bool {
  1019  		if n == nil {
  1020  			return false
  1021  		}
  1022  		if n, ok := n.(*ast.ReturnStmt); ok {
  1023  			n.Results = append(zeroVals, n.Results...)
  1024  			return false
  1025  		}
  1026  		return true
  1027  	})
  1028  	return nil
  1029  }
  1030  
  1031  // generateFuncCall constructs a call expression for the extracted function, described by the
  1032  // given parameters and return variables.
  1033  func generateFuncCall(hasReturnVals bool, params, returns []ast.Expr, name string, token token.Token) ast.Node {
  1034  	var replace ast.Node
  1035  	if hasReturnVals {
  1036  		callExpr := &ast.CallExpr{
  1037  			Fun:  ast.NewIdent(name),
  1038  			Args: params,
  1039  		}
  1040  		replace = &ast.AssignStmt{
  1041  			Lhs: returns,
  1042  			Tok: token,
  1043  			Rhs: []ast.Expr{callExpr},
  1044  		}
  1045  	} else {
  1046  		replace = &ast.CallExpr{
  1047  			Fun:  ast.NewIdent(name),
  1048  			Args: params,
  1049  		}
  1050  	}
  1051  	return replace
  1052  }
  1053  
  1054  // initializeVars creates variable declarations, if needed.
  1055  // Our preference is to replace the selected block with an "x, y, z := fn()" style
  1056  // assignment statement. We can use this style when all of the variables in the
  1057  // extracted function's return statement are either not defined prior to the extracted block
  1058  // or can be safely redefined. However, for example, if z is already defined
  1059  // in a different scope, we replace the selected block with:
  1060  //
  1061  // var x int
  1062  // var y string
  1063  // x, y, z = fn()
  1064  func initializeVars(uninitialized []types.Object, retVars []*returnVariable, seenUninitialized map[types.Object]struct{}, seenVars map[types.Object]ast.Expr) []ast.Stmt {
  1065  	var declarations []ast.Stmt
  1066  	for _, obj := range uninitialized {
  1067  		if _, ok := seenUninitialized[obj]; ok {
  1068  			continue
  1069  		}
  1070  		seenUninitialized[obj] = struct{}{}
  1071  		valSpec := &ast.ValueSpec{
  1072  			Names: []*ast.Ident{ast.NewIdent(obj.Name())},
  1073  			Type:  seenVars[obj],
  1074  		}
  1075  		genDecl := &ast.GenDecl{
  1076  			Tok:   token.VAR,
  1077  			Specs: []ast.Spec{valSpec},
  1078  		}
  1079  		declarations = append(declarations, &ast.DeclStmt{Decl: genDecl})
  1080  	}
  1081  	// Each variable added from a return statement in the selection
  1082  	// must be initialized.
  1083  	for i, retVar := range retVars {
  1084  		n := retVar.name.(*ast.Ident)
  1085  		valSpec := &ast.ValueSpec{
  1086  			Names: []*ast.Ident{n},
  1087  			Type:  retVars[i].decl.Type,
  1088  		}
  1089  		genDecl := &ast.GenDecl{
  1090  			Tok:   token.VAR,
  1091  			Specs: []ast.Spec{valSpec},
  1092  		}
  1093  		declarations = append(declarations, &ast.DeclStmt{Decl: genDecl})
  1094  	}
  1095  	return declarations
  1096  }
  1097  
  1098  // getNames returns the names from the given list of returnVariable.
  1099  func getNames(retVars []*returnVariable) []ast.Expr {
  1100  	var names []ast.Expr
  1101  	for _, retVar := range retVars {
  1102  		names = append(names, retVar.name)
  1103  	}
  1104  	return names
  1105  }
  1106  
  1107  // getZeroVals returns the "zero values" from the given list of returnVariable.
  1108  func getZeroVals(retVars []*returnVariable) []ast.Expr {
  1109  	var zvs []ast.Expr
  1110  	for _, retVar := range retVars {
  1111  		zvs = append(zvs, retVar.zeroVal)
  1112  	}
  1113  	return zvs
  1114  }
  1115  
  1116  // getDecls returns the declarations from the given list of returnVariable.
  1117  func getDecls(retVars []*returnVariable) []*ast.Field {
  1118  	var decls []*ast.Field
  1119  	for _, retVar := range retVars {
  1120  		decls = append(decls, retVar.decl)
  1121  	}
  1122  	return decls
  1123  }