github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/simple/s1008/s1008.go (about)

     1  package s1008
     2  
import (
	"fmt"
	"go/ast"
	"go/constant"
	"go/token"
	"go/types"

	"github.com/amarpal/go-tools/analysis/code"
	"github.com/amarpal/go-tools/analysis/facts/generated"
	"github.com/amarpal/go-tools/analysis/lint"
	"github.com/amarpal/go-tools/analysis/report"
	"github.com/amarpal/go-tools/pattern"

	"golang.org/x/tools/go/analysis"
	"golang.org/x/tools/go/analysis/passes/inspect"
)
    18  
    19  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
    20  	Analyzer: &analysis.Analyzer{
    21  		Name:     "S1008",
    22  		Run:      run,
    23  		Requires: []*analysis.Analyzer{inspect.Analyzer, generated.Analyzer},
    24  	},
    25  	Doc: &lint.Documentation{
    26  		Title: `Simplify returning boolean expression`,
    27  		Before: `
    28  if <expr> {
    29      return true
    30  }
    31  return false`,
    32  		After:   `return <expr>`,
    33  		Since:   "2017.1",
    34  		MergeIf: lint.MergeIfAny,
    35  	},
    36  })
    37  
    38  var Analyzer = SCAnalyzer.Analyzer
    39  
    40  var (
    41  	checkIfReturnQIf  = pattern.MustParse(`(IfStmt nil cond [(ReturnStmt [ret@(Builtin (Or "true" "false"))])] nil)`)
    42  	checkIfReturnQRet = pattern.MustParse(`(ReturnStmt [ret@(Builtin (Or "true" "false"))])`)
    43  )
    44  
    45  func run(pass *analysis.Pass) (interface{}, error) {
    46  	fn := func(node ast.Node) {
    47  		block := node.(*ast.BlockStmt)
    48  		l := len(block.List)
    49  		if l < 2 {
    50  			return
    51  		}
    52  		n1, n2 := block.List[l-2], block.List[l-1]
    53  
    54  		if len(block.List) >= 3 {
    55  			if _, ok := block.List[l-3].(*ast.IfStmt); ok {
    56  				// Do not flag a series of if statements
    57  				return
    58  			}
    59  		}
    60  		m1, ok := code.Match(pass, checkIfReturnQIf, n1)
    61  		if !ok {
    62  			return
    63  		}
    64  		m2, ok := code.Match(pass, checkIfReturnQRet, n2)
    65  		if !ok {
    66  			return
    67  		}
    68  
    69  		if op, ok := m1.State["cond"].(*ast.BinaryExpr); ok {
    70  			switch op.Op {
    71  			case token.EQL, token.LSS, token.GTR, token.NEQ, token.LEQ, token.GEQ:
    72  			default:
    73  				return
    74  			}
    75  		}
    76  
    77  		ret1 := m1.State["ret"].(*ast.Ident)
    78  		ret2 := m2.State["ret"].(*ast.Ident)
    79  
    80  		if ret1.Name == ret2.Name {
    81  			// we want the function to return true and false, not the
    82  			// same value both times.
    83  			return
    84  		}
    85  
    86  		cond := m1.State["cond"].(ast.Expr)
    87  		origCond := cond
    88  		if ret1.Name == "false" {
    89  			cond = negate(pass, cond)
    90  		}
    91  		report.Report(pass, n1,
    92  			fmt.Sprintf("should use 'return %s' instead of 'if %s { return %s }; return %s'",
    93  				report.Render(pass, cond),
    94  				report.Render(pass, origCond), report.Render(pass, ret1), report.Render(pass, ret2)),
    95  			report.FilterGenerated())
    96  	}
    97  	code.Preorder(pass, fn, (*ast.BlockStmt)(nil))
    98  	return nil, nil
    99  }
   100  
   101  func negate(pass *analysis.Pass, expr ast.Expr) ast.Expr {
   102  	switch expr := expr.(type) {
   103  	case *ast.BinaryExpr:
   104  		out := *expr
   105  		switch expr.Op {
   106  		case token.EQL:
   107  			out.Op = token.NEQ
   108  		case token.LSS:
   109  			out.Op = token.GEQ
   110  		case token.GTR:
   111  			// Some builtins never return negative ints; "len(x) <= 0" should be "len(x) == 0".
   112  			if call, ok := expr.X.(*ast.CallExpr); ok &&
   113  				code.IsCallToAny(pass, call, "len", "cap", "copy") &&
   114  				code.IsIntegerLiteral(pass, expr.Y, constant.MakeInt64(0)) {
   115  				out.Op = token.EQL
   116  			} else {
   117  				out.Op = token.LEQ
   118  			}
   119  		case token.NEQ:
   120  			out.Op = token.EQL
   121  		case token.LEQ:
   122  			out.Op = token.GTR
   123  		case token.GEQ:
   124  			out.Op = token.LSS
   125  		}
   126  		return &out
   127  	case *ast.Ident, *ast.CallExpr, *ast.IndexExpr, *ast.StarExpr:
   128  		return &ast.UnaryExpr{
   129  			Op: token.NOT,
   130  			X:  expr,
   131  		}
   132  	case *ast.UnaryExpr:
   133  		if expr.Op == token.NOT {
   134  			return expr.X
   135  		}
   136  		return &ast.UnaryExpr{
   137  			Op: token.NOT,
   138  			X:  expr,
   139  		}
   140  	default:
   141  		return &ast.UnaryExpr{
   142  			Op: token.NOT,
   143  			X: &ast.ParenExpr{
   144  				X: expr,
   145  			},
   146  		}
   147  	}
   148  }