github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/simple/s1008/s1008.go (about) 1 package s1008 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/constant" 7 "go/token" 8 9 "github.com/amarpal/go-tools/analysis/code" 10 "github.com/amarpal/go-tools/analysis/facts/generated" 11 "github.com/amarpal/go-tools/analysis/lint" 12 "github.com/amarpal/go-tools/analysis/report" 13 "github.com/amarpal/go-tools/pattern" 14 15 "golang.org/x/tools/go/analysis" 16 "golang.org/x/tools/go/analysis/passes/inspect" 17 ) 18 19 var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{ 20 Analyzer: &analysis.Analyzer{ 21 Name: "S1008", 22 Run: run, 23 Requires: []*analysis.Analyzer{inspect.Analyzer, generated.Analyzer}, 24 }, 25 Doc: &lint.Documentation{ 26 Title: `Simplify returning boolean expression`, 27 Before: ` 28 if <expr> { 29 return true 30 } 31 return false`, 32 After: `return <expr>`, 33 Since: "2017.1", 34 MergeIf: lint.MergeIfAny, 35 }, 36 }) 37 38 var Analyzer = SCAnalyzer.Analyzer 39 40 var ( 41 checkIfReturnQIf = pattern.MustParse(`(IfStmt nil cond [(ReturnStmt [ret@(Builtin (Or "true" "false"))])] nil)`) 42 checkIfReturnQRet = pattern.MustParse(`(ReturnStmt [ret@(Builtin (Or "true" "false"))])`) 43 ) 44 45 func run(pass *analysis.Pass) (interface{}, error) { 46 fn := func(node ast.Node) { 47 block := node.(*ast.BlockStmt) 48 l := len(block.List) 49 if l < 2 { 50 return 51 } 52 n1, n2 := block.List[l-2], block.List[l-1] 53 54 if len(block.List) >= 3 { 55 if _, ok := block.List[l-3].(*ast.IfStmt); ok { 56 // Do not flag a series of if statements 57 return 58 } 59 } 60 m1, ok := code.Match(pass, checkIfReturnQIf, n1) 61 if !ok { 62 return 63 } 64 m2, ok := code.Match(pass, checkIfReturnQRet, n2) 65 if !ok { 66 return 67 } 68 69 if op, ok := m1.State["cond"].(*ast.BinaryExpr); ok { 70 switch op.Op { 71 case token.EQL, token.LSS, token.GTR, token.NEQ, token.LEQ, token.GEQ: 72 default: 73 return 74 } 75 } 76 77 ret1 := m1.State["ret"].(*ast.Ident) 78 ret2 := m2.State["ret"].(*ast.Ident) 79 80 if ret1.Name == ret2.Name { 81 // we want the function to return true and false, not the 82 // same value both times. 83 return 84 } 85 86 cond := m1.State["cond"].(ast.Expr) 87 origCond := cond 88 if ret1.Name == "false" { 89 cond = negate(pass, cond) 90 } 91 report.Report(pass, n1, 92 fmt.Sprintf("should use 'return %s' instead of 'if %s { return %s }; return %s'", 93 report.Render(pass, cond), 94 report.Render(pass, origCond), report.Render(pass, ret1), report.Render(pass, ret2)), 95 report.FilterGenerated()) 96 } 97 code.Preorder(pass, fn, (*ast.BlockStmt)(nil)) 98 return nil, nil 99 } 100 101 func negate(pass *analysis.Pass, expr ast.Expr) ast.Expr { 102 switch expr := expr.(type) { 103 case *ast.BinaryExpr: 104 out := *expr 105 switch expr.Op { 106 case token.EQL: 107 out.Op = token.NEQ 108 case token.LSS: 109 out.Op = token.GEQ 110 case token.GTR: 111 // Some builtins never return negative ints; "len(x) <= 0" should be "len(x) == 0". 112 if call, ok := expr.X.(*ast.CallExpr); ok && 113 code.IsCallToAny(pass, call, "len", "cap", "copy") && 114 code.IsIntegerLiteral(pass, expr.Y, constant.MakeInt64(0)) { 115 out.Op = token.EQL 116 } else { 117 out.Op = token.LEQ 118 } 119 case token.NEQ: 120 out.Op = token.EQL 121 case token.LEQ: 122 out.Op = token.GTR 123 case token.GEQ: 124 out.Op = token.LSS 125 } 126 return &out 127 case *ast.Ident, *ast.CallExpr, *ast.IndexExpr, *ast.StarExpr: 128 return &ast.UnaryExpr{ 129 Op: token.NOT, 130 X: expr, 131 } 132 case *ast.UnaryExpr: 133 if expr.Op == token.NOT { 134 return expr.X 135 } 136 return &ast.UnaryExpr{ 137 Op: token.NOT, 138 X: expr, 139 } 140 default: 141 return &ast.UnaryExpr{ 142 Op: token.NOT, 143 X: &ast.ParenExpr{ 144 X: expr, 145 }, 146 } 147 } 148 }