github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/internal/lsp/analysis/fillreturns/fillreturns.go (about) 1 // Copyright 2020 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package fillreturns defines an Analyzer that will attempt to 6 // automatically fill in a return statement that has missing 7 // values with zero value elements. 8 package fillreturns 9 10 import ( 11 "bytes" 12 "fmt" 13 "go/ast" 14 "go/format" 15 "go/types" 16 "regexp" 17 "strings" 18 19 "github.com/powerman/golang-tools/go/analysis" 20 "github.com/powerman/golang-tools/go/ast/astutil" 21 "github.com/powerman/golang-tools/internal/analysisinternal" 22 "github.com/powerman/golang-tools/internal/typeparams" 23 ) 24 25 const Doc = `suggest fixes for errors due to an incorrect number of return values 26 27 This checker provides suggested fixes for type errors of the 28 type "wrong number of return values (want %d, got %d)". For example: 29 func m() (int, string, *bool, error) { 30 return 31 } 32 will turn into 33 func m() (int, string, *bool, error) { 34 return 0, "", nil, nil 35 } 36 37 This functionality is similar to https://github.com/sqs/goreturns. 38 ` 39 40 var Analyzer = &analysis.Analyzer{ 41 Name: "fillreturns", 42 Doc: Doc, 43 Requires: []*analysis.Analyzer{}, 44 Run: run, 45 RunDespiteErrors: true, 46 } 47 48 func run(pass *analysis.Pass) (interface{}, error) { 49 info := pass.TypesInfo 50 if info == nil { 51 return nil, fmt.Errorf("nil TypeInfo") 52 } 53 54 errors := analysisinternal.GetTypeErrors(pass) 55 outer: 56 for _, typeErr := range errors { 57 // Filter out the errors that are not relevant to this analyzer. 
58 if !FixesError(typeErr) { 59 continue 60 } 61 var file *ast.File 62 for _, f := range pass.Files { 63 if f.Pos() <= typeErr.Pos && typeErr.Pos <= f.End() { 64 file = f 65 break 66 } 67 } 68 if file == nil { 69 continue 70 } 71 72 // Get the end position of the error. 73 var buf bytes.Buffer 74 if err := format.Node(&buf, pass.Fset, file); err != nil { 75 continue 76 } 77 typeErrEndPos := analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), typeErr.Pos) 78 79 // TODO(rfindley): much of the error handling code below returns, when it 80 // should probably continue. 81 82 // Get the path for the relevant range. 83 path, _ := astutil.PathEnclosingInterval(file, typeErr.Pos, typeErrEndPos) 84 if len(path) == 0 { 85 return nil, nil 86 } 87 88 // Find the enclosing return statement. 89 var ret *ast.ReturnStmt 90 var retIdx int 91 for i, n := range path { 92 if r, ok := n.(*ast.ReturnStmt); ok { 93 ret = r 94 retIdx = i 95 break 96 } 97 } 98 if ret == nil { 99 return nil, nil 100 } 101 102 // Get the function type that encloses the ReturnStmt. 103 var enclosingFunc *ast.FuncType 104 for _, n := range path[retIdx+1:] { 105 switch node := n.(type) { 106 case *ast.FuncLit: 107 enclosingFunc = node.Type 108 case *ast.FuncDecl: 109 enclosingFunc = node.Type 110 } 111 if enclosingFunc != nil { 112 break 113 } 114 } 115 if enclosingFunc == nil { 116 continue 117 } 118 119 // Skip any generic enclosing functions, since type parameters don't 120 // have 0 values. 121 // TODO(rfindley): We should be able to handle this if the return 122 // values are all concrete types. 123 if tparams := typeparams.ForFuncType(enclosingFunc); tparams != nil && tparams.NumFields() > 0 { 124 return nil, nil 125 } 126 127 // Find the function declaration that encloses the ReturnStmt. 
128 var outer *ast.FuncDecl 129 for _, p := range path { 130 if p, ok := p.(*ast.FuncDecl); ok { 131 outer = p 132 break 133 } 134 } 135 if outer == nil { 136 return nil, nil 137 } 138 139 // Skip any return statements that contain function calls with multiple 140 // return values. 141 for _, expr := range ret.Results { 142 e, ok := expr.(*ast.CallExpr) 143 if !ok { 144 continue 145 } 146 if tup, ok := info.TypeOf(e).(*types.Tuple); ok && tup.Len() > 1 { 147 continue outer 148 } 149 } 150 151 // Duplicate the return values to track which values have been matched. 152 remaining := make([]ast.Expr, len(ret.Results)) 153 copy(remaining, ret.Results) 154 155 fixed := make([]ast.Expr, len(enclosingFunc.Results.List)) 156 157 // For each value in the return function declaration, find the leftmost element 158 // in the return statement that has the desired type. If no such element exits, 159 // fill in the missing value with the appropriate "zero" value. 160 var retTyps []types.Type 161 for _, ret := range enclosingFunc.Results.List { 162 retTyps = append(retTyps, info.TypeOf(ret.Type)) 163 } 164 matches := 165 analysisinternal.FindMatchingIdents(retTyps, file, ret.Pos(), info, pass.Pkg) 166 for i, retTyp := range retTyps { 167 var match ast.Expr 168 var idx int 169 for j, val := range remaining { 170 if !matchingTypes(info.TypeOf(val), retTyp) { 171 continue 172 } 173 if !analysisinternal.IsZeroValue(val) { 174 match, idx = val, j 175 break 176 } 177 // If the current match is a "zero" value, we keep searching in 178 // case we find a non-"zero" value match. If we do not find a 179 // non-"zero" value, we will use the "zero" value. 180 match, idx = val, j 181 } 182 183 if match != nil { 184 fixed[i] = match 185 remaining = append(remaining[:idx], remaining[idx+1:]...) 186 } else { 187 idents, ok := matches[retTyp] 188 if !ok { 189 return nil, fmt.Errorf("invalid return type: %v", retTyp) 190 } 191 // Find the identifier whose name is most similar to the return type. 
192 // If we do not find any identifier that matches the pattern, 193 // generate a zero value. 194 value := analysisinternal.FindBestMatch(retTyp.String(), idents) 195 if value == nil { 196 value = analysisinternal.ZeroValue( 197 pass.Fset, file, pass.Pkg, retTyp) 198 } 199 if value == nil { 200 return nil, nil 201 } 202 fixed[i] = value 203 } 204 } 205 206 // Remove any non-matching "zero values" from the leftover values. 207 var nonZeroRemaining []ast.Expr 208 for _, expr := range remaining { 209 if !analysisinternal.IsZeroValue(expr) { 210 nonZeroRemaining = append(nonZeroRemaining, expr) 211 } 212 } 213 // Append leftover return values to end of new return statement. 214 fixed = append(fixed, nonZeroRemaining...) 215 216 newRet := &ast.ReturnStmt{ 217 Return: ret.Pos(), 218 Results: fixed, 219 } 220 221 // Convert the new return statement AST to text. 222 var newBuf bytes.Buffer 223 if err := format.Node(&newBuf, pass.Fset, newRet); err != nil { 224 return nil, err 225 } 226 227 pass.Report(analysis.Diagnostic{ 228 Pos: typeErr.Pos, 229 End: typeErrEndPos, 230 Message: typeErr.Msg, 231 SuggestedFixes: []analysis.SuggestedFix{{ 232 Message: "Fill in return values", 233 TextEdits: []analysis.TextEdit{{ 234 Pos: ret.Pos(), 235 End: ret.End(), 236 NewText: newBuf.Bytes(), 237 }}, 238 }}, 239 }) 240 } 241 return nil, nil 242 } 243 244 func matchingTypes(want, got types.Type) bool { 245 if want == got || types.Identical(want, got) { 246 return true 247 } 248 // Code segment to help check for untyped equality from (golang/go#32146). 249 if rhs, ok := want.(*types.Basic); ok && rhs.Info()&types.IsUntyped > 0 { 250 if lhs, ok := got.Underlying().(*types.Basic); ok { 251 return rhs.Info()&types.IsConstType == lhs.Info()&types.IsConstType 252 } 253 } 254 return types.AssignableTo(want, got) || types.ConvertibleTo(want, got) 255 } 256 257 // Error messages have changed across Go versions. These regexps capture recent 258 // incarnations. 
//
// TODO(rfindley): once error codes are exported and exposed via go/packages,
// use error codes rather than string matching here.
var wrongReturnNumRegexes = []*regexp.Regexp{
	regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`),
	regexp.MustCompile(`too many return values`),
	regexp.MustCompile(`not enough return values`),
}

// FixesError reports whether err is a type error that this analyzer can offer
// a fix for. This is the case when its message, with surrounding whitespace
// trimmed, matches one of wrongReturnNumRegexes.
func FixesError(err types.Error) bool {
	msg := strings.TrimSpace(err.Msg)
	matched := false
	for i := 0; i < len(wrongReturnNumRegexes) && !matched; i++ {
		matched = wrongReturnNumRegexes[i].MatchString(msg)
	}
	return matched
}