github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/simple/s1031/s1031.go (about)

     1  package s1031
     2  
     3  import (
     4  	"go/ast"
     5  	"go/types"
     6  
     7  	"github.com/amarpal/go-tools/analysis/code"
     8  	"github.com/amarpal/go-tools/analysis/facts/generated"
     9  	"github.com/amarpal/go-tools/analysis/lint"
    10  	"github.com/amarpal/go-tools/analysis/report"
    11  	"github.com/amarpal/go-tools/go/types/typeutil"
    12  	"github.com/amarpal/go-tools/pattern"
    13  
    14  	"golang.org/x/tools/go/analysis"
    15  	"golang.org/x/tools/go/analysis/passes/inspect"
    16  )
    17  
    18  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
    19  	Analyzer: &analysis.Analyzer{
    20  		Name:     "S1031",
    21  		Run:      run,
    22  		Requires: []*analysis.Analyzer{inspect.Analyzer, generated.Analyzer},
    23  	},
    24  	Doc: &lint.Documentation{
    25  		Title: `Omit redundant nil check around loop`,
    26  		Text: `You can use range on nil slices and maps, the loop will simply never
    27  execute. This makes an additional nil check around the loop
    28  unnecessary.`,
    29  		Before: `
    30  if s != nil {
    31      for _, x := range s {
    32          ...
    33      }
    34  }`,
    35  		After: `
    36  for _, x := range s {
    37      ...
    38  }`,
    39  		Since: "2017.1",
    40  		// MergeIfAll because x might be a channel under some build tags.
    41  		// you shouldn't write code like that…
    42  		MergeIf: lint.MergeIfAll,
    43  	},
    44  })
    45  
    46  var Analyzer = SCAnalyzer.Analyzer
    47  
// checkNilCheckAroundRangeQ matches an if statement whose condition is
// `x != nil` and whose body consists of exactly one range statement that
// ranges over that same x. The binding x is read back by run through
// m.State["x"] as a types.Object.
//
// NOTE(review): the positional fields presumably mirror go/ast.IfStmt
// (Init, Cond, Body, Else) and go/ast.RangeStmt (Key, Value, Tok, X, Body).
// Under that reading, the leading and trailing nil require no init
// statement and no else branch, and the second use of x must equal the
// first binding. Confirm against the pattern package. Only the `x != nil`
// operand order is written here; `nil != x` is not listed.
var checkNilCheckAroundRangeQ = pattern.MustParse(`
	(IfStmt
		nil
		(BinaryExpr x@(Object _) "!=" (Builtin "nil"))
		[(RangeStmt _ _ _ x _)]
		nil)`)
    54  
    55  func run(pass *analysis.Pass) (interface{}, error) {
    56  	fn := func(node ast.Node) {
    57  		m, ok := code.Match(pass, checkNilCheckAroundRangeQ, node)
    58  		if !ok {
    59  			return
    60  		}
    61  		ok = typeutil.All(m.State["x"].(types.Object).Type(), func(term *types.Term) bool {
    62  			switch term.Type().Underlying().(type) {
    63  			case *types.Slice, *types.Map:
    64  				return true
    65  			case *types.TypeParam, *types.Chan, *types.Pointer:
    66  				return false
    67  			default:
    68  				lint.ExhaustiveTypeSwitch(term.Type().Underlying())
    69  				return false
    70  			}
    71  		})
    72  		if !ok {
    73  			return
    74  		}
    75  		report.Report(pass, node, "unnecessary nil check around range", report.ShortRange(), report.FilterGenerated())
    76  	}
    77  	code.Preorder(pass, fn, (*ast.IfStmt)(nil))
    78  	return nil, nil
    79  }