github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/staticcheck/sa5001/sa5001.go (about) 1 package sa5001 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/types" 7 8 "github.com/amarpal/go-tools/analysis/code" 9 "github.com/amarpal/go-tools/analysis/lint" 10 "github.com/amarpal/go-tools/analysis/report" 11 12 "golang.org/x/tools/go/analysis" 13 "golang.org/x/tools/go/analysis/passes/inspect" 14 ) 15 16 var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{ 17 Analyzer: &analysis.Analyzer{ 18 Name: "SA5001", 19 Run: run, 20 Requires: []*analysis.Analyzer{inspect.Analyzer}, 21 }, 22 Doc: &lint.Documentation{ 23 Title: `Deferring \'Close\' before checking for a possible error`, 24 Since: "2017.1", 25 Severity: lint.SeverityWarning, 26 MergeIf: lint.MergeIfAny, 27 }, 28 }) 29 30 var Analyzer = SCAnalyzer.Analyzer 31 32 func run(pass *analysis.Pass) (interface{}, error) { 33 fn := func(node ast.Node) { 34 block := node.(*ast.BlockStmt) 35 if len(block.List) < 2 { 36 return 37 } 38 for i, stmt := range block.List { 39 if i == len(block.List)-1 { 40 break 41 } 42 assign, ok := stmt.(*ast.AssignStmt) 43 if !ok { 44 continue 45 } 46 if len(assign.Rhs) != 1 { 47 continue 48 } 49 if len(assign.Lhs) < 2 { 50 continue 51 } 52 if lhs, ok := assign.Lhs[len(assign.Lhs)-1].(*ast.Ident); ok && lhs.Name == "_" { 53 continue 54 } 55 call, ok := assign.Rhs[0].(*ast.CallExpr) 56 if !ok { 57 continue 58 } 59 sig, ok := pass.TypesInfo.TypeOf(call.Fun).(*types.Signature) 60 if !ok { 61 continue 62 } 63 if sig.Results().Len() < 2 { 64 continue 65 } 66 last := sig.Results().At(sig.Results().Len() - 1) 67 // FIXME(dh): check that it's error from universe, not 68 // another type of the same name 69 if last.Type().String() != "error" { 70 continue 71 } 72 lhs, ok := assign.Lhs[0].(*ast.Ident) 73 if !ok { 74 continue 75 } 76 def, ok := block.List[i+1].(*ast.DeferStmt) 77 if !ok { 78 continue 79 } 80 sel, ok := def.Call.Fun.(*ast.SelectorExpr) 81 if !ok { 82 continue 83 } 84 ident, ok := selectorX(sel).(*ast.Ident) 85 if !ok { 86 continue 87 } 88 if pass.TypesInfo.ObjectOf(ident) != pass.TypesInfo.ObjectOf(lhs) { 89 continue 90 } 91 if sel.Sel.Name != "Close" { 92 continue 93 } 94 report.Report(pass, def, fmt.Sprintf("should check returned error before deferring %s", report.Render(pass, def.Call))) 95 } 96 } 97 code.Preorder(pass, fn, (*ast.BlockStmt)(nil)) 98 return nil, nil 99 } 100 101 func selectorX(sel *ast.SelectorExpr) ast.Node { 102 switch x := sel.X.(type) { 103 case *ast.SelectorExpr: 104 return selectorX(x) 105 default: 106 return x 107 } 108 }