golang.org/x/tools/gopls@v0.15.3/internal/analysis/fillreturns/fillreturns.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fillreturns
     6  
     7  import (
     8  	"bytes"
     9  	_ "embed"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/types"
    14  	"regexp"
    15  	"strings"
    16  
    17  	"golang.org/x/tools/go/analysis"
    18  	"golang.org/x/tools/go/ast/astutil"
    19  	"golang.org/x/tools/internal/analysisinternal"
    20  	"golang.org/x/tools/internal/fuzzy"
    21  )
    22  
    23  //go:embed doc.go
    24  var doc string
    25  
    26  var Analyzer = &analysis.Analyzer{
    27  	Name:             "fillreturns",
    28  	Doc:              analysisinternal.MustExtractDoc(doc, "fillreturns"),
    29  	Run:              run,
    30  	RunDespiteErrors: true,
    31  	URL:              "https://pkg.go.dev/golang.org/x/tools/gopls/internal/analysis/fillreturns",
    32  }
    33  
    34  func run(pass *analysis.Pass) (interface{}, error) {
    35  	info := pass.TypesInfo
    36  	if info == nil {
    37  		return nil, fmt.Errorf("nil TypeInfo")
    38  	}
    39  
    40  outer:
    41  	for _, typeErr := range pass.TypeErrors {
    42  		// Filter out the errors that are not relevant to this analyzer.
    43  		if !FixesError(typeErr) {
    44  			continue
    45  		}
    46  		var file *ast.File
    47  		for _, f := range pass.Files {
    48  			if f.Pos() <= typeErr.Pos && typeErr.Pos <= f.End() {
    49  				file = f
    50  				break
    51  			}
    52  		}
    53  		if file == nil {
    54  			continue
    55  		}
    56  
    57  		// Get the end position of the error.
    58  		// (This heuristic assumes that the buffer is formatted,
    59  		// at least up to the end position of the error.)
    60  		var buf bytes.Buffer
    61  		if err := format.Node(&buf, pass.Fset, file); err != nil {
    62  			continue
    63  		}
    64  		typeErrEndPos := analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), typeErr.Pos)
    65  
    66  		// TODO(rfindley): much of the error handling code below returns, when it
    67  		// should probably continue.
    68  
    69  		// Get the path for the relevant range.
    70  		path, _ := astutil.PathEnclosingInterval(file, typeErr.Pos, typeErrEndPos)
    71  		if len(path) == 0 {
    72  			return nil, nil
    73  		}
    74  
    75  		// Find the enclosing return statement.
    76  		var ret *ast.ReturnStmt
    77  		var retIdx int
    78  		for i, n := range path {
    79  			if r, ok := n.(*ast.ReturnStmt); ok {
    80  				ret = r
    81  				retIdx = i
    82  				break
    83  			}
    84  		}
    85  		if ret == nil {
    86  			return nil, nil
    87  		}
    88  
    89  		// Get the function type that encloses the ReturnStmt.
    90  		var enclosingFunc *ast.FuncType
    91  		for _, n := range path[retIdx+1:] {
    92  			switch node := n.(type) {
    93  			case *ast.FuncLit:
    94  				enclosingFunc = node.Type
    95  			case *ast.FuncDecl:
    96  				enclosingFunc = node.Type
    97  			}
    98  			if enclosingFunc != nil {
    99  				break
   100  			}
   101  		}
   102  		if enclosingFunc == nil || enclosingFunc.Results == nil {
   103  			continue
   104  		}
   105  
   106  		// Skip any generic enclosing functions, since type parameters don't
   107  		// have 0 values.
   108  		// TODO(rfindley): We should be able to handle this if the return
   109  		// values are all concrete types.
   110  		if tparams := enclosingFunc.TypeParams; tparams != nil && tparams.NumFields() > 0 {
   111  			return nil, nil
   112  		}
   113  
   114  		// Find the function declaration that encloses the ReturnStmt.
   115  		var outer *ast.FuncDecl
   116  		for _, p := range path {
   117  			if p, ok := p.(*ast.FuncDecl); ok {
   118  				outer = p
   119  				break
   120  			}
   121  		}
   122  		if outer == nil {
   123  			return nil, nil
   124  		}
   125  
   126  		// Skip any return statements that contain function calls with multiple
   127  		// return values.
   128  		for _, expr := range ret.Results {
   129  			e, ok := expr.(*ast.CallExpr)
   130  			if !ok {
   131  				continue
   132  			}
   133  			if tup, ok := info.TypeOf(e).(*types.Tuple); ok && tup.Len() > 1 {
   134  				continue outer
   135  			}
   136  		}
   137  
   138  		// Duplicate the return values to track which values have been matched.
   139  		remaining := make([]ast.Expr, len(ret.Results))
   140  		copy(remaining, ret.Results)
   141  
   142  		fixed := make([]ast.Expr, len(enclosingFunc.Results.List))
   143  
   144  		// For each value in the return function declaration, find the leftmost element
   145  		// in the return statement that has the desired type. If no such element exists,
   146  		// fill in the missing value with the appropriate "zero" value.
   147  		// Beware that type information may be incomplete.
   148  		var retTyps []types.Type
   149  		for _, ret := range enclosingFunc.Results.List {
   150  			retTyp := info.TypeOf(ret.Type)
   151  			if retTyp == nil {
   152  				return nil, nil
   153  			}
   154  			retTyps = append(retTyps, retTyp)
   155  		}
   156  		matches := analysisinternal.MatchingIdents(retTyps, file, ret.Pos(), info, pass.Pkg)
   157  		for i, retTyp := range retTyps {
   158  			var match ast.Expr
   159  			var idx int
   160  			for j, val := range remaining {
   161  				if t := info.TypeOf(val); t == nil || !matchingTypes(t, retTyp) {
   162  					continue
   163  				}
   164  				if !analysisinternal.IsZeroValue(val) {
   165  					match, idx = val, j
   166  					break
   167  				}
   168  				// If the current match is a "zero" value, we keep searching in
   169  				// case we find a non-"zero" value match. If we do not find a
   170  				// non-"zero" value, we will use the "zero" value.
   171  				match, idx = val, j
   172  			}
   173  
   174  			if match != nil {
   175  				fixed[i] = match
   176  				remaining = append(remaining[:idx], remaining[idx+1:]...)
   177  			} else {
   178  				names, ok := matches[retTyp]
   179  				if !ok {
   180  					return nil, fmt.Errorf("invalid return type: %v", retTyp)
   181  				}
   182  				// Find the identifier most similar to the return type.
   183  				// If no identifier matches the pattern, generate a zero value.
   184  				if best := fuzzy.BestMatch(retTyp.String(), names); best != "" {
   185  					fixed[i] = ast.NewIdent(best)
   186  				} else if zero := analysisinternal.ZeroValue(file, pass.Pkg, retTyp); zero != nil {
   187  					fixed[i] = zero
   188  				} else {
   189  					return nil, nil
   190  				}
   191  			}
   192  		}
   193  
   194  		// Remove any non-matching "zero values" from the leftover values.
   195  		var nonZeroRemaining []ast.Expr
   196  		for _, expr := range remaining {
   197  			if !analysisinternal.IsZeroValue(expr) {
   198  				nonZeroRemaining = append(nonZeroRemaining, expr)
   199  			}
   200  		}
   201  		// Append leftover return values to end of new return statement.
   202  		fixed = append(fixed, nonZeroRemaining...)
   203  
   204  		newRet := &ast.ReturnStmt{
   205  			Return:  ret.Pos(),
   206  			Results: fixed,
   207  		}
   208  
   209  		// Convert the new return statement AST to text.
   210  		var newBuf bytes.Buffer
   211  		if err := format.Node(&newBuf, pass.Fset, newRet); err != nil {
   212  			return nil, err
   213  		}
   214  
   215  		pass.Report(analysis.Diagnostic{
   216  			Pos:     typeErr.Pos,
   217  			End:     typeErrEndPos,
   218  			Message: typeErr.Msg,
   219  			SuggestedFixes: []analysis.SuggestedFix{{
   220  				Message: "Fill in return values",
   221  				TextEdits: []analysis.TextEdit{{
   222  					Pos:     ret.Pos(),
   223  					End:     ret.End(),
   224  					NewText: newBuf.Bytes(),
   225  				}},
   226  			}},
   227  		})
   228  	}
   229  	return nil, nil
   230  }
   231  
   232  func matchingTypes(want, got types.Type) bool {
   233  	if want == got || types.Identical(want, got) {
   234  		return true
   235  	}
   236  	// Code segment to help check for untyped equality from (golang/go#32146).
   237  	if rhs, ok := want.(*types.Basic); ok && rhs.Info()&types.IsUntyped > 0 {
   238  		if lhs, ok := got.Underlying().(*types.Basic); ok {
   239  			return rhs.Info()&types.IsConstType == lhs.Info()&types.IsConstType
   240  		}
   241  	}
   242  	return types.AssignableTo(want, got) || types.ConvertibleTo(want, got)
   243  }
   244  
   245  // Error messages have changed across Go versions. These regexps capture recent
   246  // incarnations.
   247  //
   248  // TODO(rfindley): once error codes are exported and exposed via go/packages,
   249  // use error codes rather than string matching here.
   250  var wrongReturnNumRegexes = []*regexp.Regexp{
   251  	regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`),
   252  	regexp.MustCompile(`too many return values`),
   253  	regexp.MustCompile(`not enough return values`),
   254  }
   255  
   256  func FixesError(err types.Error) bool {
   257  	msg := strings.TrimSpace(err.Msg)
   258  	for _, rx := range wrongReturnNumRegexes {
   259  		if rx.MatchString(msg) {
   260  			return true
   261  		}
   262  	}
   263  	return false
   264  }