github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/simple/s1002/s1002.go (about)

     1  package s1002
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/token"
     7  	"go/types"
     8  	"strings"
     9  
    10  	"github.com/amarpal/go-tools/analysis/code"
    11  	"github.com/amarpal/go-tools/analysis/edit"
    12  	"github.com/amarpal/go-tools/analysis/facts/generated"
    13  	"github.com/amarpal/go-tools/analysis/lint"
    14  	"github.com/amarpal/go-tools/analysis/report"
    15  	"github.com/amarpal/go-tools/go/types/typeutil"
    16  
    17  	"golang.org/x/tools/go/analysis"
    18  	"golang.org/x/tools/go/analysis/passes/inspect"
    19  )
    20  
    21  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
    22  	Analyzer: &analysis.Analyzer{
    23  		Name:     "S1002",
    24  		Run:      run,
    25  		Requires: []*analysis.Analyzer{inspect.Analyzer, generated.Analyzer},
    26  	},
    27  	Doc: &lint.Documentation{
    28  		Title:  `Omit comparison with boolean constant`,
    29  		Before: `if x == true {}`,
    30  		After:  `if x {}`,
    31  		Since:  "2017.1",
    32  		// MergeIfAll because 'true' might not be the builtin constant under all build tags.
    33  		// You shouldn't write code like that…
    34  		MergeIf: lint.MergeIfAll,
    35  	},
    36  })
    37  
    38  var Analyzer = SCAnalyzer.Analyzer
    39  
    40  func run(pass *analysis.Pass) (interface{}, error) {
    41  	fn := func(node ast.Node) {
    42  		if code.IsInTest(pass, node) {
    43  			return
    44  		}
    45  
    46  		expr := node.(*ast.BinaryExpr)
    47  		if expr.Op != token.EQL && expr.Op != token.NEQ {
    48  			return
    49  		}
    50  		x := code.IsBoolConst(pass, expr.X)
    51  		y := code.IsBoolConst(pass, expr.Y)
    52  		if !x && !y {
    53  			return
    54  		}
    55  		var other ast.Expr
    56  		var val bool
    57  		if x {
    58  			val = code.BoolConst(pass, expr.X)
    59  			other = expr.Y
    60  		} else {
    61  			val = code.BoolConst(pass, expr.Y)
    62  			other = expr.X
    63  		}
    64  
    65  		ok := typeutil.All(pass.TypesInfo.TypeOf(other), func(term *types.Term) bool {
    66  			basic, ok := term.Type().Underlying().(*types.Basic)
    67  			return ok && basic.Kind() == types.Bool
    68  		})
    69  		if !ok {
    70  			return
    71  		}
    72  		op := ""
    73  		if (expr.Op == token.EQL && !val) || (expr.Op == token.NEQ && val) {
    74  			op = "!"
    75  		}
    76  		r := op + report.Render(pass, other)
    77  		l1 := len(r)
    78  		r = strings.TrimLeft(r, "!")
    79  		if (l1-len(r))%2 == 1 {
    80  			r = "!" + r
    81  		}
    82  		report.Report(pass, expr, fmt.Sprintf("should omit comparison to bool constant, can be simplified to %s", r),
    83  			report.FilterGenerated(),
    84  			report.Fixes(edit.Fix("simplify bool comparison", edit.ReplaceWithString(expr, r))))
    85  	}
    86  	code.Preorder(pass, fn, (*ast.BinaryExpr)(nil))
    87  	return nil, nil
    88  }