github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/internal/lsp/analysis/fillreturns/fillreturns.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package fillreturns defines an Analyzer that will attempt to
     6  // automatically fill in a return statement that has missing
     7  // values with zero value elements.
     8  package fillreturns
     9  
    10  import (
    11  	"bytes"
    12  	"fmt"
    13  	"go/ast"
    14  	"go/format"
    15  	"go/types"
    16  	"regexp"
    17  	"strings"
    18  
    19  	"github.com/powerman/golang-tools/go/analysis"
    20  	"github.com/powerman/golang-tools/go/ast/astutil"
    21  	"github.com/powerman/golang-tools/internal/analysisinternal"
    22  	"github.com/powerman/golang-tools/internal/typeparams"
    23  )
    24  
// Doc is the documentation for the fillreturns analyzer; it is used as the
// Doc field of Analyzer.
const Doc = `suggest fixes for errors due to an incorrect number of return values

This checker provides suggested fixes for type errors of the
type "wrong number of return values (want %d, got %d)". For example:
	func m() (int, string, *bool, error) {
		return
	}
will turn into
	func m() (int, string, *bool, error) {
		return 0, "", nil, nil
	}

This functionality is similar to https://github.com/sqs/goreturns.
`
    39  
    40  var Analyzer = &analysis.Analyzer{
    41  	Name:             "fillreturns",
    42  	Doc:              Doc,
    43  	Requires:         []*analysis.Analyzer{},
    44  	Run:              run,
    45  	RunDespiteErrors: true,
    46  }
    47  
    48  func run(pass *analysis.Pass) (interface{}, error) {
    49  	info := pass.TypesInfo
    50  	if info == nil {
    51  		return nil, fmt.Errorf("nil TypeInfo")
    52  	}
    53  
    54  	errors := analysisinternal.GetTypeErrors(pass)
    55  outer:
    56  	for _, typeErr := range errors {
    57  		// Filter out the errors that are not relevant to this analyzer.
    58  		if !FixesError(typeErr) {
    59  			continue
    60  		}
    61  		var file *ast.File
    62  		for _, f := range pass.Files {
    63  			if f.Pos() <= typeErr.Pos && typeErr.Pos <= f.End() {
    64  				file = f
    65  				break
    66  			}
    67  		}
    68  		if file == nil {
    69  			continue
    70  		}
    71  
    72  		// Get the end position of the error.
    73  		var buf bytes.Buffer
    74  		if err := format.Node(&buf, pass.Fset, file); err != nil {
    75  			continue
    76  		}
    77  		typeErrEndPos := analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), typeErr.Pos)
    78  
    79  		// TODO(rfindley): much of the error handling code below returns, when it
    80  		// should probably continue.
    81  
    82  		// Get the path for the relevant range.
    83  		path, _ := astutil.PathEnclosingInterval(file, typeErr.Pos, typeErrEndPos)
    84  		if len(path) == 0 {
    85  			return nil, nil
    86  		}
    87  
    88  		// Find the enclosing return statement.
    89  		var ret *ast.ReturnStmt
    90  		var retIdx int
    91  		for i, n := range path {
    92  			if r, ok := n.(*ast.ReturnStmt); ok {
    93  				ret = r
    94  				retIdx = i
    95  				break
    96  			}
    97  		}
    98  		if ret == nil {
    99  			return nil, nil
   100  		}
   101  
   102  		// Get the function type that encloses the ReturnStmt.
   103  		var enclosingFunc *ast.FuncType
   104  		for _, n := range path[retIdx+1:] {
   105  			switch node := n.(type) {
   106  			case *ast.FuncLit:
   107  				enclosingFunc = node.Type
   108  			case *ast.FuncDecl:
   109  				enclosingFunc = node.Type
   110  			}
   111  			if enclosingFunc != nil {
   112  				break
   113  			}
   114  		}
   115  		if enclosingFunc == nil {
   116  			continue
   117  		}
   118  
   119  		// Skip any generic enclosing functions, since type parameters don't
   120  		// have 0 values.
   121  		// TODO(rfindley): We should be able to handle this if the return
   122  		// values are all concrete types.
   123  		if tparams := typeparams.ForFuncType(enclosingFunc); tparams != nil && tparams.NumFields() > 0 {
   124  			return nil, nil
   125  		}
   126  
   127  		// Find the function declaration that encloses the ReturnStmt.
   128  		var outer *ast.FuncDecl
   129  		for _, p := range path {
   130  			if p, ok := p.(*ast.FuncDecl); ok {
   131  				outer = p
   132  				break
   133  			}
   134  		}
   135  		if outer == nil {
   136  			return nil, nil
   137  		}
   138  
   139  		// Skip any return statements that contain function calls with multiple
   140  		// return values.
   141  		for _, expr := range ret.Results {
   142  			e, ok := expr.(*ast.CallExpr)
   143  			if !ok {
   144  				continue
   145  			}
   146  			if tup, ok := info.TypeOf(e).(*types.Tuple); ok && tup.Len() > 1 {
   147  				continue outer
   148  			}
   149  		}
   150  
   151  		// Duplicate the return values to track which values have been matched.
   152  		remaining := make([]ast.Expr, len(ret.Results))
   153  		copy(remaining, ret.Results)
   154  
   155  		fixed := make([]ast.Expr, len(enclosingFunc.Results.List))
   156  
   157  		// For each value in the return function declaration, find the leftmost element
   158  		// in the return statement that has the desired type. If no such element exits,
   159  		// fill in the missing value with the appropriate "zero" value.
   160  		var retTyps []types.Type
   161  		for _, ret := range enclosingFunc.Results.List {
   162  			retTyps = append(retTyps, info.TypeOf(ret.Type))
   163  		}
   164  		matches :=
   165  			analysisinternal.FindMatchingIdents(retTyps, file, ret.Pos(), info, pass.Pkg)
   166  		for i, retTyp := range retTyps {
   167  			var match ast.Expr
   168  			var idx int
   169  			for j, val := range remaining {
   170  				if !matchingTypes(info.TypeOf(val), retTyp) {
   171  					continue
   172  				}
   173  				if !analysisinternal.IsZeroValue(val) {
   174  					match, idx = val, j
   175  					break
   176  				}
   177  				// If the current match is a "zero" value, we keep searching in
   178  				// case we find a non-"zero" value match. If we do not find a
   179  				// non-"zero" value, we will use the "zero" value.
   180  				match, idx = val, j
   181  			}
   182  
   183  			if match != nil {
   184  				fixed[i] = match
   185  				remaining = append(remaining[:idx], remaining[idx+1:]...)
   186  			} else {
   187  				idents, ok := matches[retTyp]
   188  				if !ok {
   189  					return nil, fmt.Errorf("invalid return type: %v", retTyp)
   190  				}
   191  				// Find the identifier whose name is most similar to the return type.
   192  				// If we do not find any identifier that matches the pattern,
   193  				// generate a zero value.
   194  				value := analysisinternal.FindBestMatch(retTyp.String(), idents)
   195  				if value == nil {
   196  					value = analysisinternal.ZeroValue(
   197  						pass.Fset, file, pass.Pkg, retTyp)
   198  				}
   199  				if value == nil {
   200  					return nil, nil
   201  				}
   202  				fixed[i] = value
   203  			}
   204  		}
   205  
   206  		// Remove any non-matching "zero values" from the leftover values.
   207  		var nonZeroRemaining []ast.Expr
   208  		for _, expr := range remaining {
   209  			if !analysisinternal.IsZeroValue(expr) {
   210  				nonZeroRemaining = append(nonZeroRemaining, expr)
   211  			}
   212  		}
   213  		// Append leftover return values to end of new return statement.
   214  		fixed = append(fixed, nonZeroRemaining...)
   215  
   216  		newRet := &ast.ReturnStmt{
   217  			Return:  ret.Pos(),
   218  			Results: fixed,
   219  		}
   220  
   221  		// Convert the new return statement AST to text.
   222  		var newBuf bytes.Buffer
   223  		if err := format.Node(&newBuf, pass.Fset, newRet); err != nil {
   224  			return nil, err
   225  		}
   226  
   227  		pass.Report(analysis.Diagnostic{
   228  			Pos:     typeErr.Pos,
   229  			End:     typeErrEndPos,
   230  			Message: typeErr.Msg,
   231  			SuggestedFixes: []analysis.SuggestedFix{{
   232  				Message: "Fill in return values",
   233  				TextEdits: []analysis.TextEdit{{
   234  					Pos:     ret.Pos(),
   235  					End:     ret.End(),
   236  					NewText: newBuf.Bytes(),
   237  				}},
   238  			}},
   239  		})
   240  	}
   241  	return nil, nil
   242  }
   243  
   244  func matchingTypes(want, got types.Type) bool {
   245  	if want == got || types.Identical(want, got) {
   246  		return true
   247  	}
   248  	// Code segment to help check for untyped equality from (golang/go#32146).
   249  	if rhs, ok := want.(*types.Basic); ok && rhs.Info()&types.IsUntyped > 0 {
   250  		if lhs, ok := got.Underlying().(*types.Basic); ok {
   251  			return rhs.Info()&types.IsConstType == lhs.Info()&types.IsConstType
   252  		}
   253  	}
   254  	return types.AssignableTo(want, got) || types.ConvertibleTo(want, got)
   255  }
   256  
   257  // Error messages have changed across Go versions. These regexps capture recent
   258  // incarnations.
   259  //
   260  // TODO(rfindley): once error codes are exported and exposed via go/packages,
   261  // use error codes rather than string matching here.
   262  var wrongReturnNumRegexes = []*regexp.Regexp{
   263  	regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`),
   264  	regexp.MustCompile(`too many return values`),
   265  	regexp.MustCompile(`not enough return values`),
   266  }
   267  
   268  func FixesError(err types.Error) bool {
   269  	msg := strings.TrimSpace(err.Msg)
   270  	for _, rx := range wrongReturnNumRegexes {
   271  		if rx.MatchString(msg) {
   272  			return true
   273  		}
   274  	}
   275  	return false
   276  }