github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/staticcheck/sa5001/sa5001.go (about)

     1  package sa5001
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/types"
     7  
     8  	"github.com/amarpal/go-tools/analysis/code"
     9  	"github.com/amarpal/go-tools/analysis/lint"
    10  	"github.com/amarpal/go-tools/analysis/report"
    11  
    12  	"golang.org/x/tools/go/analysis"
    13  	"golang.org/x/tools/go/analysis/passes/inspect"
    14  )
    15  
    16  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
    17  	Analyzer: &analysis.Analyzer{
    18  		Name:     "SA5001",
    19  		Run:      run,
    20  		Requires: []*analysis.Analyzer{inspect.Analyzer},
    21  	},
    22  	Doc: &lint.Documentation{
    23  		Title:    `Deferring \'Close\' before checking for a possible error`,
    24  		Since:    "2017.1",
    25  		Severity: lint.SeverityWarning,
    26  		MergeIf:  lint.MergeIfAny,
    27  	},
    28  })
    29  
    30  var Analyzer = SCAnalyzer.Analyzer
    31  
    32  func run(pass *analysis.Pass) (interface{}, error) {
    33  	fn := func(node ast.Node) {
    34  		block := node.(*ast.BlockStmt)
    35  		if len(block.List) < 2 {
    36  			return
    37  		}
    38  		for i, stmt := range block.List {
    39  			if i == len(block.List)-1 {
    40  				break
    41  			}
    42  			assign, ok := stmt.(*ast.AssignStmt)
    43  			if !ok {
    44  				continue
    45  			}
    46  			if len(assign.Rhs) != 1 {
    47  				continue
    48  			}
    49  			if len(assign.Lhs) < 2 {
    50  				continue
    51  			}
    52  			if lhs, ok := assign.Lhs[len(assign.Lhs)-1].(*ast.Ident); ok && lhs.Name == "_" {
    53  				continue
    54  			}
    55  			call, ok := assign.Rhs[0].(*ast.CallExpr)
    56  			if !ok {
    57  				continue
    58  			}
    59  			sig, ok := pass.TypesInfo.TypeOf(call.Fun).(*types.Signature)
    60  			if !ok {
    61  				continue
    62  			}
    63  			if sig.Results().Len() < 2 {
    64  				continue
    65  			}
    66  			last := sig.Results().At(sig.Results().Len() - 1)
    67  			// FIXME(dh): check that it's error from universe, not
    68  			// another type of the same name
    69  			if last.Type().String() != "error" {
    70  				continue
    71  			}
    72  			lhs, ok := assign.Lhs[0].(*ast.Ident)
    73  			if !ok {
    74  				continue
    75  			}
    76  			def, ok := block.List[i+1].(*ast.DeferStmt)
    77  			if !ok {
    78  				continue
    79  			}
    80  			sel, ok := def.Call.Fun.(*ast.SelectorExpr)
    81  			if !ok {
    82  				continue
    83  			}
    84  			ident, ok := selectorX(sel).(*ast.Ident)
    85  			if !ok {
    86  				continue
    87  			}
    88  			if pass.TypesInfo.ObjectOf(ident) != pass.TypesInfo.ObjectOf(lhs) {
    89  				continue
    90  			}
    91  			if sel.Sel.Name != "Close" {
    92  				continue
    93  			}
    94  			report.Report(pass, def, fmt.Sprintf("should check returned error before deferring %s", report.Render(pass, def.Call)))
    95  		}
    96  	}
    97  	code.Preorder(pass, fn, (*ast.BlockStmt)(nil))
    98  	return nil, nil
    99  }
   100  
   101  func selectorX(sel *ast.SelectorExpr) ast.Node {
   102  	switch x := sel.X.(type) {
   103  	case *ast.SelectorExpr:
   104  		return selectorX(x)
   105  	default:
   106  		return x
   107  	}
   108  }