github.com/v2fly/tools@v0.100.0/internal/lsp/analysis/fillreturns/fillreturns.go (about) 1 // Copyright 2020 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package fillreturns defines an Analyzer that will attempt to 6 // automatically fill in a return statement that has missing 7 // values with zero value elements. 8 package fillreturns 9 10 import ( 11 "bytes" 12 "fmt" 13 "go/ast" 14 "go/format" 15 "go/types" 16 "regexp" 17 "strconv" 18 "strings" 19 20 "github.com/v2fly/tools/go/analysis" 21 "github.com/v2fly/tools/go/ast/astutil" 22 "github.com/v2fly/tools/internal/analysisinternal" 23 ) 24 25 const Doc = `suggested fixes for "wrong number of return values (want %d, got %d)" 26 27 This checker provides suggested fixes for type errors of the 28 type "wrong number of return values (want %d, got %d)". For example: 29 func m() (int, string, *bool, error) { 30 return 31 } 32 will turn into 33 func m() (int, string, *bool, error) { 34 return 0, "", nil, nil 35 } 36 37 This functionality is similar to https://github.com/sqs/goreturns. 38 ` 39 40 var Analyzer = &analysis.Analyzer{ 41 Name: "fillreturns", 42 Doc: Doc, 43 Requires: []*analysis.Analyzer{}, 44 Run: run, 45 RunDespiteErrors: true, 46 } 47 48 var wrongReturnNumRegex = regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`) 49 50 func run(pass *analysis.Pass) (interface{}, error) { 51 info := pass.TypesInfo 52 if info == nil { 53 return nil, fmt.Errorf("nil TypeInfo") 54 } 55 56 errors := analysisinternal.GetTypeErrors(pass) 57 outer: 58 for _, typeErr := range errors { 59 // Filter out the errors that are not relevant to this analyzer. 60 if !FixesError(typeErr.Msg) { 61 continue 62 } 63 var file *ast.File 64 for _, f := range pass.Files { 65 if f.Pos() <= typeErr.Pos && typeErr.Pos <= f.End() { 66 file = f 67 break 68 } 69 } 70 if file == nil { 71 continue 72 } 73 74 // Get the end position of the error. 75 var buf bytes.Buffer 76 if err := format.Node(&buf, pass.Fset, file); err != nil { 77 continue 78 } 79 typeErrEndPos := analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), typeErr.Pos) 80 81 // Get the path for the relevant range. 82 path, _ := astutil.PathEnclosingInterval(file, typeErr.Pos, typeErrEndPos) 83 if len(path) == 0 { 84 return nil, nil 85 } 86 // Check to make sure the node of interest is a ReturnStmt. 87 ret, ok := path[0].(*ast.ReturnStmt) 88 if !ok { 89 return nil, nil 90 } 91 92 // Get the function type that encloses the ReturnStmt. 93 var enclosingFunc *ast.FuncType 94 for _, n := range path { 95 switch node := n.(type) { 96 case *ast.FuncLit: 97 enclosingFunc = node.Type 98 case *ast.FuncDecl: 99 enclosingFunc = node.Type 100 } 101 if enclosingFunc != nil { 102 break 103 } 104 } 105 if enclosingFunc == nil { 106 continue 107 } 108 109 // Find the function declaration that encloses the ReturnStmt. 110 var outer *ast.FuncDecl 111 for _, p := range path { 112 if p, ok := p.(*ast.FuncDecl); ok { 113 outer = p 114 break 115 } 116 } 117 if outer == nil { 118 return nil, nil 119 } 120 121 // Skip any return statements that contain function calls with multiple return values. 122 for _, expr := range ret.Results { 123 e, ok := expr.(*ast.CallExpr) 124 if !ok { 125 continue 126 } 127 if tup, ok := info.TypeOf(e).(*types.Tuple); ok && tup.Len() > 1 { 128 continue outer 129 } 130 } 131 132 // Duplicate the return values to track which values have been matched. 133 remaining := make([]ast.Expr, len(ret.Results)) 134 copy(remaining, ret.Results) 135 136 fixed := make([]ast.Expr, len(enclosingFunc.Results.List)) 137 138 // For each value in the return function declaration, find the leftmost element 139 // in the return statement that has the desired type. If no such element exits, 140 // fill in the missing value with the appropriate "zero" value. 141 var retTyps []types.Type 142 for _, ret := range enclosingFunc.Results.List { 143 retTyps = append(retTyps, info.TypeOf(ret.Type)) 144 } 145 matches := 146 analysisinternal.FindMatchingIdents(retTyps, file, ret.Pos(), info, pass.Pkg) 147 for i, retTyp := range retTyps { 148 var match ast.Expr 149 var idx int 150 for j, val := range remaining { 151 if !matchingTypes(info.TypeOf(val), retTyp) { 152 continue 153 } 154 if !analysisinternal.IsZeroValue(val) { 155 match, idx = val, j 156 break 157 } 158 // If the current match is a "zero" value, we keep searching in 159 // case we find a non-"zero" value match. If we do not find a 160 // non-"zero" value, we will use the "zero" value. 161 match, idx = val, j 162 } 163 164 if match != nil { 165 fixed[i] = match 166 remaining = append(remaining[:idx], remaining[idx+1:]...) 167 } else { 168 idents, ok := matches[retTyp] 169 if !ok { 170 return nil, fmt.Errorf("invalid return type: %v", retTyp) 171 } 172 // Find the identifer whose name is most similar to the return type. 173 // If we do not find any identifer that matches the pattern, 174 // generate a zero value. 175 value := analysisinternal.FindBestMatch(retTyp.String(), idents) 176 if value == nil { 177 value = analysisinternal.ZeroValue( 178 pass.Fset, file, pass.Pkg, retTyp) 179 } 180 if value == nil { 181 return nil, nil 182 } 183 fixed[i] = value 184 } 185 } 186 187 // Remove any non-matching "zero values" from the leftover values. 188 var nonZeroRemaining []ast.Expr 189 for _, expr := range remaining { 190 if !analysisinternal.IsZeroValue(expr) { 191 nonZeroRemaining = append(nonZeroRemaining, expr) 192 } 193 } 194 // Append leftover return values to end of new return statement. 195 fixed = append(fixed, nonZeroRemaining...) 196 197 newRet := &ast.ReturnStmt{ 198 Return: ret.Pos(), 199 Results: fixed, 200 } 201 202 // Convert the new return statement AST to text. 203 var newBuf bytes.Buffer 204 if err := format.Node(&newBuf, pass.Fset, newRet); err != nil { 205 return nil, err 206 } 207 208 pass.Report(analysis.Diagnostic{ 209 Pos: typeErr.Pos, 210 End: typeErrEndPos, 211 Message: typeErr.Msg, 212 SuggestedFixes: []analysis.SuggestedFix{{ 213 Message: "Fill in return values", 214 TextEdits: []analysis.TextEdit{{ 215 Pos: ret.Pos(), 216 End: ret.End(), 217 NewText: newBuf.Bytes(), 218 }}, 219 }}, 220 }) 221 } 222 return nil, nil 223 } 224 225 func matchingTypes(want, got types.Type) bool { 226 if want == got || types.Identical(want, got) { 227 return true 228 } 229 // Code segment to help check for untyped equality from (golang/go#32146). 230 if rhs, ok := want.(*types.Basic); ok && rhs.Info()&types.IsUntyped > 0 { 231 if lhs, ok := got.Underlying().(*types.Basic); ok { 232 return rhs.Info()&types.IsConstType == lhs.Info()&types.IsConstType 233 } 234 } 235 return types.AssignableTo(want, got) || types.ConvertibleTo(want, got) 236 } 237 238 func FixesError(msg string) bool { 239 matches := wrongReturnNumRegex.FindStringSubmatch(strings.TrimSpace(msg)) 240 if len(matches) < 3 { 241 return false 242 } 243 if _, err := strconv.Atoi(matches[1]); err != nil { 244 return false 245 } 246 if _, err := strconv.Atoi(matches[2]); err != nil { 247 return false 248 } 249 return true 250 }