github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/staticcheck/sa2002/sa2002.go (about)

     1  package sa2002
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  
     7  	"github.com/amarpal/go-tools/analysis/lint"
     8  	"github.com/amarpal/go-tools/analysis/report"
     9  	"github.com/amarpal/go-tools/go/ir"
    10  	"github.com/amarpal/go-tools/go/types/typeutil"
    11  	"github.com/amarpal/go-tools/internal/passes/buildir"
    12  
    13  	"golang.org/x/tools/go/analysis"
    14  )
    15  
    16  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
    17  	Analyzer: &analysis.Analyzer{
    18  		Name:     "SA2002",
    19  		Run:      run,
    20  		Requires: []*analysis.Analyzer{buildir.Analyzer},
    21  	},
    22  	Doc: &lint.Documentation{
    23  		Title:    `Called \'testing.T.FailNow\' or \'SkipNow\' in a goroutine, which isn't allowed`,
    24  		Since:    "2017.1",
    25  		Severity: lint.SeverityError,
    26  		MergeIf:  lint.MergeIfAny,
    27  	},
    28  })
    29  
    30  var Analyzer = SCAnalyzer.Analyzer
    31  
    32  func run(pass *analysis.Pass) (interface{}, error) {
    33  	for _, fn := range pass.ResultOf[buildir.Analyzer].(*buildir.IR).SrcFuncs {
    34  		for _, block := range fn.Blocks {
    35  			for _, ins := range block.Instrs {
    36  				gostmt, ok := ins.(*ir.Go)
    37  				if !ok {
    38  					continue
    39  				}
    40  				var fn *ir.Function
    41  				switch val := gostmt.Call.Value.(type) {
    42  				case *ir.Function:
    43  					fn = val
    44  				case *ir.MakeClosure:
    45  					fn = val.Fn.(*ir.Function)
    46  				default:
    47  					continue
    48  				}
    49  				if fn.Blocks == nil {
    50  					continue
    51  				}
    52  				for _, block := range fn.Blocks {
    53  					for _, ins := range block.Instrs {
    54  						call, ok := ins.(*ir.Call)
    55  						if !ok {
    56  							continue
    57  						}
    58  						if call.Call.IsInvoke() {
    59  							continue
    60  						}
    61  						callee := call.Call.StaticCallee()
    62  						if callee == nil {
    63  							continue
    64  						}
    65  						recv := callee.Signature.Recv()
    66  						if recv == nil {
    67  							continue
    68  						}
    69  						if !typeutil.IsType(recv.Type(), "*testing.common") {
    70  							continue
    71  						}
    72  						fn, ok := call.Call.StaticCallee().Object().(*types.Func)
    73  						if !ok {
    74  							continue
    75  						}
    76  						name := fn.Name()
    77  						switch name {
    78  						case "FailNow", "Fatal", "Fatalf", "SkipNow", "Skip", "Skipf":
    79  						default:
    80  							continue
    81  						}
    82  						// TODO(dh): don't report multiple diagnostics
    83  						// for multiple calls to T.Fatal, but do
    84  						// collect all of them as related information
    85  						report.Report(pass, gostmt, fmt.Sprintf("the goroutine calls T.%s, which must be called in the same goroutine as the test", name),
    86  							report.Related(call, fmt.Sprintf("call to T.%s", name)))
    87  					}
    88  				}
    89  			}
    90  		}
    91  	}
    92  	return nil, nil
    93  }