github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/simple/s1034/s1034.go (about)

     1  package s1034
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/types"
     7  
     8  	"github.com/amarpal/go-tools/analysis/code"
     9  	"github.com/amarpal/go-tools/analysis/edit"
    10  	"github.com/amarpal/go-tools/analysis/facts/generated"
    11  	"github.com/amarpal/go-tools/analysis/lint"
    12  	"github.com/amarpal/go-tools/analysis/report"
    13  	"github.com/amarpal/go-tools/pattern"
    14  
    15  	"golang.org/x/tools/go/analysis"
    16  	"golang.org/x/tools/go/analysis/passes/inspect"
    17  )
    18  
    19  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
    20  	Analyzer: &analysis.Analyzer{
    21  		Name:     "S1034",
    22  		Run:      run,
    23  		Requires: []*analysis.Analyzer{inspect.Analyzer, generated.Analyzer},
    24  	},
    25  	Doc: &lint.Documentation{
    26  		Title:   `Use result of type assertion to simplify cases`,
    27  		Since:   "2019.2",
    28  		MergeIf: lint.MergeIfAny,
    29  	},
    30  })
    31  
    32  var Analyzer = SCAnalyzer.Analyzer
    33  
    34  var (
    35  	checkSimplifyTypeSwitchQ = pattern.MustParse(`
    36  		(TypeSwitchStmt
    37  			nil
    38  			expr@(TypeAssertExpr ident@(Ident _) _)
    39  			body)`)
    40  	checkSimplifyTypeSwitchR = pattern.MustParse(`(AssignStmt ident ":=" expr)`)
    41  )
    42  
    43  func run(pass *analysis.Pass) (interface{}, error) {
    44  	fn := func(node ast.Node) {
    45  		m, ok := code.Match(pass, checkSimplifyTypeSwitchQ, node)
    46  		if !ok {
    47  			return
    48  		}
    49  		stmt := node.(*ast.TypeSwitchStmt)
    50  		expr := m.State["expr"].(ast.Node)
    51  		ident := m.State["ident"].(*ast.Ident)
    52  
    53  		x := pass.TypesInfo.ObjectOf(ident)
    54  		var allOffenders []*ast.TypeAssertExpr
    55  		canSuggestFix := true
    56  		for _, clause := range stmt.Body.List {
    57  			clause := clause.(*ast.CaseClause)
    58  			if len(clause.List) != 1 {
    59  				continue
    60  			}
    61  			hasUnrelatedAssertion := false
    62  			var offenders []*ast.TypeAssertExpr
    63  			ast.Inspect(clause, func(node ast.Node) bool {
    64  				assert2, ok := node.(*ast.TypeAssertExpr)
    65  				if !ok {
    66  					return true
    67  				}
    68  				ident, ok := assert2.X.(*ast.Ident)
    69  				if !ok {
    70  					hasUnrelatedAssertion = true
    71  					return false
    72  				}
    73  				if pass.TypesInfo.ObjectOf(ident) != x {
    74  					hasUnrelatedAssertion = true
    75  					return false
    76  				}
    77  
    78  				if !types.Identical(pass.TypesInfo.TypeOf(clause.List[0]), pass.TypesInfo.TypeOf(assert2.Type)) {
    79  					hasUnrelatedAssertion = true
    80  					return false
    81  				}
    82  				offenders = append(offenders, assert2)
    83  				return true
    84  			})
    85  			if !hasUnrelatedAssertion {
    86  				// don't flag cases that have other type assertions
    87  				// unrelated to the one in the case clause. often
    88  				// times, this is done for symmetry, when two
    89  				// different values have to be asserted to the same
    90  				// type.
    91  				allOffenders = append(allOffenders, offenders...)
    92  			}
    93  			canSuggestFix = canSuggestFix && !hasUnrelatedAssertion
    94  		}
    95  		if len(allOffenders) != 0 {
    96  			var opts []report.Option
    97  			for _, offender := range allOffenders {
    98  				opts = append(opts, report.Related(offender, "could eliminate this type assertion"))
    99  			}
   100  			opts = append(opts, report.FilterGenerated())
   101  
   102  			msg := fmt.Sprintf("assigning the result of this type assertion to a variable (switch %s := %s.(type)) could eliminate type assertions in switch cases",
   103  				report.Render(pass, ident), report.Render(pass, ident))
   104  			if canSuggestFix {
   105  				var edits []analysis.TextEdit
   106  				edits = append(edits, edit.ReplaceWithPattern(pass.Fset, expr, checkSimplifyTypeSwitchR, m.State))
   107  				for _, offender := range allOffenders {
   108  					edits = append(edits, edit.ReplaceWithNode(pass.Fset, offender, offender.X))
   109  				}
   110  				opts = append(opts, report.Fixes(edit.Fix("simplify type switch", edits...)))
   111  				report.Report(pass, expr, msg, opts...)
   112  			} else {
   113  				report.Report(pass, expr, msg, opts...)
   114  			}
   115  		}
   116  	}
   117  	code.Preorder(pass, fn, (*ast.TypeSwitchStmt)(nil))
   118  	return nil, nil
   119  }