github.com/v2fly/tools@v0.100.0/internal/lsp/analysis/fillreturns/fillreturns.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package fillreturns defines an Analyzer that will attempt to
     6  // automatically fill in a return statement that has missing
     7  // values with zero value elements.
     8  package fillreturns
     9  
    10  import (
    11  	"bytes"
    12  	"fmt"
    13  	"go/ast"
    14  	"go/format"
    15  	"go/types"
    16  	"regexp"
    17  	"strconv"
    18  	"strings"
    19  
    20  	"github.com/v2fly/tools/go/analysis"
    21  	"github.com/v2fly/tools/go/ast/astutil"
    22  	"github.com/v2fly/tools/internal/analysisinternal"
    23  )
    24  
// Doc is the documentation for the fillreturns analyzer. It is used as
// Analyzer.Doc.
const Doc = `suggested fixes for "wrong number of return values (want %d, got %d)"

This checker provides suggested fixes for type errors of the
type "wrong number of return values (want %d, got %d)". For example:
	func m() (int, string, *bool, error) {
		return
	}
will turn into
	func m() (int, string, *bool, error) {
		return 0, "", nil, nil
	}

This functionality is similar to https://github.com/sqs/goreturns.
`

// Analyzer suggests fixes that fill in the missing values of return
// statements reported by the type checker as having the wrong number of
// return values. RunDespiteErrors is set because the analyzer works off
// the package's type errors, so it must run even when type checking failed.
var Analyzer = &analysis.Analyzer{
	Name:             "fillreturns",
	Doc:              Doc,
	Requires:         []*analysis.Analyzer{},
	Run:              run,
	RunDespiteErrors: true,
}

// wrongReturnNumRegex matches the type checker message
// "wrong number of return values (want N, got M)", capturing N in group 1
// and M in group 2.
var wrongReturnNumRegex = regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`)
    49  
    50  func run(pass *analysis.Pass) (interface{}, error) {
    51  	info := pass.TypesInfo
    52  	if info == nil {
    53  		return nil, fmt.Errorf("nil TypeInfo")
    54  	}
    55  
    56  	errors := analysisinternal.GetTypeErrors(pass)
    57  outer:
    58  	for _, typeErr := range errors {
    59  		// Filter out the errors that are not relevant to this analyzer.
    60  		if !FixesError(typeErr.Msg) {
    61  			continue
    62  		}
    63  		var file *ast.File
    64  		for _, f := range pass.Files {
    65  			if f.Pos() <= typeErr.Pos && typeErr.Pos <= f.End() {
    66  				file = f
    67  				break
    68  			}
    69  		}
    70  		if file == nil {
    71  			continue
    72  		}
    73  
    74  		// Get the end position of the error.
    75  		var buf bytes.Buffer
    76  		if err := format.Node(&buf, pass.Fset, file); err != nil {
    77  			continue
    78  		}
    79  		typeErrEndPos := analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), typeErr.Pos)
    80  
    81  		// Get the path for the relevant range.
    82  		path, _ := astutil.PathEnclosingInterval(file, typeErr.Pos, typeErrEndPos)
    83  		if len(path) == 0 {
    84  			return nil, nil
    85  		}
    86  		// Check to make sure the node of interest is a ReturnStmt.
    87  		ret, ok := path[0].(*ast.ReturnStmt)
    88  		if !ok {
    89  			return nil, nil
    90  		}
    91  
    92  		// Get the function type that encloses the ReturnStmt.
    93  		var enclosingFunc *ast.FuncType
    94  		for _, n := range path {
    95  			switch node := n.(type) {
    96  			case *ast.FuncLit:
    97  				enclosingFunc = node.Type
    98  			case *ast.FuncDecl:
    99  				enclosingFunc = node.Type
   100  			}
   101  			if enclosingFunc != nil {
   102  				break
   103  			}
   104  		}
   105  		if enclosingFunc == nil {
   106  			continue
   107  		}
   108  
   109  		// Find the function declaration that encloses the ReturnStmt.
   110  		var outer *ast.FuncDecl
   111  		for _, p := range path {
   112  			if p, ok := p.(*ast.FuncDecl); ok {
   113  				outer = p
   114  				break
   115  			}
   116  		}
   117  		if outer == nil {
   118  			return nil, nil
   119  		}
   120  
   121  		// Skip any return statements that contain function calls with multiple return values.
   122  		for _, expr := range ret.Results {
   123  			e, ok := expr.(*ast.CallExpr)
   124  			if !ok {
   125  				continue
   126  			}
   127  			if tup, ok := info.TypeOf(e).(*types.Tuple); ok && tup.Len() > 1 {
   128  				continue outer
   129  			}
   130  		}
   131  
   132  		// Duplicate the return values to track which values have been matched.
   133  		remaining := make([]ast.Expr, len(ret.Results))
   134  		copy(remaining, ret.Results)
   135  
   136  		fixed := make([]ast.Expr, len(enclosingFunc.Results.List))
   137  
   138  		// For each value in the return function declaration, find the leftmost element
   139  		// in the return statement that has the desired type. If no such element exits,
   140  		// fill in the missing value with the appropriate "zero" value.
   141  		var retTyps []types.Type
   142  		for _, ret := range enclosingFunc.Results.List {
   143  			retTyps = append(retTyps, info.TypeOf(ret.Type))
   144  		}
   145  		matches :=
   146  			analysisinternal.FindMatchingIdents(retTyps, file, ret.Pos(), info, pass.Pkg)
   147  		for i, retTyp := range retTyps {
   148  			var match ast.Expr
   149  			var idx int
   150  			for j, val := range remaining {
   151  				if !matchingTypes(info.TypeOf(val), retTyp) {
   152  					continue
   153  				}
   154  				if !analysisinternal.IsZeroValue(val) {
   155  					match, idx = val, j
   156  					break
   157  				}
   158  				// If the current match is a "zero" value, we keep searching in
   159  				// case we find a non-"zero" value match. If we do not find a
   160  				// non-"zero" value, we will use the "zero" value.
   161  				match, idx = val, j
   162  			}
   163  
   164  			if match != nil {
   165  				fixed[i] = match
   166  				remaining = append(remaining[:idx], remaining[idx+1:]...)
   167  			} else {
   168  				idents, ok := matches[retTyp]
   169  				if !ok {
   170  					return nil, fmt.Errorf("invalid return type: %v", retTyp)
   171  				}
   172  				// Find the identifer whose name is most similar to the return type.
   173  				// If we do not find any identifer that matches the pattern,
   174  				// generate a zero value.
   175  				value := analysisinternal.FindBestMatch(retTyp.String(), idents)
   176  				if value == nil {
   177  					value = analysisinternal.ZeroValue(
   178  						pass.Fset, file, pass.Pkg, retTyp)
   179  				}
   180  				if value == nil {
   181  					return nil, nil
   182  				}
   183  				fixed[i] = value
   184  			}
   185  		}
   186  
   187  		// Remove any non-matching "zero values" from the leftover values.
   188  		var nonZeroRemaining []ast.Expr
   189  		for _, expr := range remaining {
   190  			if !analysisinternal.IsZeroValue(expr) {
   191  				nonZeroRemaining = append(nonZeroRemaining, expr)
   192  			}
   193  		}
   194  		// Append leftover return values to end of new return statement.
   195  		fixed = append(fixed, nonZeroRemaining...)
   196  
   197  		newRet := &ast.ReturnStmt{
   198  			Return:  ret.Pos(),
   199  			Results: fixed,
   200  		}
   201  
   202  		// Convert the new return statement AST to text.
   203  		var newBuf bytes.Buffer
   204  		if err := format.Node(&newBuf, pass.Fset, newRet); err != nil {
   205  			return nil, err
   206  		}
   207  
   208  		pass.Report(analysis.Diagnostic{
   209  			Pos:     typeErr.Pos,
   210  			End:     typeErrEndPos,
   211  			Message: typeErr.Msg,
   212  			SuggestedFixes: []analysis.SuggestedFix{{
   213  				Message: "Fill in return values",
   214  				TextEdits: []analysis.TextEdit{{
   215  					Pos:     ret.Pos(),
   216  					End:     ret.End(),
   217  					NewText: newBuf.Bytes(),
   218  				}},
   219  			}},
   220  		})
   221  	}
   222  	return nil, nil
   223  }
   224  
   225  func matchingTypes(want, got types.Type) bool {
   226  	if want == got || types.Identical(want, got) {
   227  		return true
   228  	}
   229  	// Code segment to help check for untyped equality from (golang/go#32146).
   230  	if rhs, ok := want.(*types.Basic); ok && rhs.Info()&types.IsUntyped > 0 {
   231  		if lhs, ok := got.Underlying().(*types.Basic); ok {
   232  			return rhs.Info()&types.IsConstType == lhs.Info()&types.IsConstType
   233  		}
   234  	}
   235  	return types.AssignableTo(want, got) || types.ConvertibleTo(want, got)
   236  }
   237  
   238  func FixesError(msg string) bool {
   239  	matches := wrongReturnNumRegex.FindStringSubmatch(strings.TrimSpace(msg))
   240  	if len(matches) < 3 {
   241  		return false
   242  	}
   243  	if _, err := strconv.Atoi(matches[1]); err != nil {
   244  		return false
   245  	}
   246  	if _, err := strconv.Atoi(matches[2]); err != nil {
   247  		return false
   248  	}
   249  	return true
   250  }