github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/quickfix/qf1002/qf1002.go (about)

     1  package qf1002
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/token"
     7  	"strings"
     8  
     9  	"github.com/amarpal/go-tools/analysis/code"
    10  	"github.com/amarpal/go-tools/analysis/edit"
    11  	"github.com/amarpal/go-tools/analysis/lint"
    12  	"github.com/amarpal/go-tools/analysis/report"
    13  	"github.com/amarpal/go-tools/go/ast/astutil"
    14  
    15  	"golang.org/x/tools/go/analysis"
    16  	"golang.org/x/tools/go/analysis/passes/inspect"
    17  )
    18  
    19  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
    20  	Analyzer: &analysis.Analyzer{
    21  		Name:     "QF1002",
    22  		Run:      run,
    23  		Requires: []*analysis.Analyzer{inspect.Analyzer},
    24  	},
    25  	Doc: &lint.Documentation{
    26  		Title: "Convert untagged switch to tagged switch",
    27  		Text: `
    28  An untagged switch that compares a single variable against a series of
    29  values can be replaced with a tagged switch.`,
    30  		Before: `
    31  switch {
    32  case x == 1 || x == 2, x == 3:
    33      ...
    34  case x == 4:
    35      ...
    36  default:
    37      ...
    38  }`,
    39  
    40  		After: `
    41  switch x {
    42  case 1, 2, 3:
    43      ...
    44  case 4:
    45      ...
    46  default:
    47      ...
    48  }`,
    49  		Since:    "2021.1",
    50  		Severity: lint.SeverityHint,
    51  	},
    52  })
    53  
    54  var Analyzer = SCAnalyzer.Analyzer
    55  
    56  func run(pass *analysis.Pass) (interface{}, error) {
    57  	fn := func(node ast.Node) {
    58  		swtch := node.(*ast.SwitchStmt)
    59  		if swtch.Tag != nil || len(swtch.Body.List) == 0 {
    60  			return
    61  		}
    62  
    63  		pairs := make([][]*ast.BinaryExpr, len(swtch.Body.List))
    64  		for i, stmt := range swtch.Body.List {
    65  			stmt := stmt.(*ast.CaseClause)
    66  			for _, cond := range stmt.List {
    67  				if !findSwitchPairs(pass, cond, &pairs[i]) {
    68  					return
    69  				}
    70  			}
    71  		}
    72  
    73  		var x ast.Expr
    74  		for _, pair := range pairs {
    75  			if len(pair) == 0 {
    76  				continue
    77  			}
    78  			if x == nil {
    79  				x = pair[0].X
    80  			} else {
    81  				if !astutil.Equal(x, pair[0].X) {
    82  					return
    83  				}
    84  			}
    85  		}
    86  		if x == nil {
    87  			// the switch only has a default case
    88  			if len(pairs) > 1 {
    89  				panic("found more than one case clause with no pairs")
    90  			}
    91  			return
    92  		}
    93  
    94  		edits := make([]analysis.TextEdit, 0, len(swtch.Body.List)+1)
    95  		for i, stmt := range swtch.Body.List {
    96  			stmt := stmt.(*ast.CaseClause)
    97  			if stmt.List == nil {
    98  				continue
    99  			}
   100  
   101  			var values []string
   102  			for _, binexpr := range pairs[i] {
   103  				y := binexpr.Y
   104  				if p, ok := y.(*ast.ParenExpr); ok {
   105  					y = p.X
   106  				}
   107  				values = append(values, report.Render(pass, y))
   108  			}
   109  
   110  			edits = append(edits, edit.ReplaceWithString(edit.Range{stmt.List[0].Pos(), stmt.Colon}, strings.Join(values, ", ")))
   111  		}
   112  		pos := swtch.Switch + token.Pos(len("switch"))
   113  		edits = append(edits, edit.ReplaceWithString(edit.Range{pos, pos}, " "+report.Render(pass, x)))
   114  		report.Report(pass, swtch, fmt.Sprintf("could use tagged switch on %s", report.Render(pass, x)),
   115  			report.Fixes(edit.Fix("Replace with tagged switch", edits...)))
   116  	}
   117  
   118  	code.Preorder(pass, fn, (*ast.SwitchStmt)(nil))
   119  	return nil, nil
   120  }
   121  
   122  func findSwitchPairs(pass *analysis.Pass, expr ast.Expr, pairs *[]*ast.BinaryExpr) bool {
   123  	binexpr, ok := astutil.Unparen(expr).(*ast.BinaryExpr)
   124  	if !ok {
   125  		return false
   126  	}
   127  	switch binexpr.Op {
   128  	case token.EQL:
   129  		if code.MayHaveSideEffects(pass, binexpr.X, nil) || code.MayHaveSideEffects(pass, binexpr.Y, nil) {
   130  			return false
   131  		}
   132  		// syntactic identity should suffice. we do not allow side
   133  		// effects in the case clauses, so there should be no way for
   134  		// values to change.
   135  		if len(*pairs) > 0 && !astutil.Equal(binexpr.X, (*pairs)[0].X) {
   136  			return false
   137  		}
   138  		*pairs = append(*pairs, binexpr)
   139  		return true
   140  	case token.LOR:
   141  		return findSwitchPairs(pass, binexpr.X, pairs) && findSwitchPairs(pass, binexpr.Y, pairs)
   142  	default:
   143  		return false
   144  	}
   145  }