golang.org/x/tools@v0.21.0/go/analysis/passes/testinggoroutine/testinggoroutine.go (about) 1 // Copyright 2020 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testinggoroutine 6 7 import ( 8 _ "embed" 9 "fmt" 10 "go/ast" 11 "go/token" 12 "go/types" 13 14 "golang.org/x/tools/go/analysis" 15 "golang.org/x/tools/go/analysis/passes/inspect" 16 "golang.org/x/tools/go/analysis/passes/internal/analysisutil" 17 "golang.org/x/tools/go/ast/astutil" 18 "golang.org/x/tools/go/ast/inspector" 19 "golang.org/x/tools/go/types/typeutil" 20 "golang.org/x/tools/internal/aliases" 21 ) 22 23 //go:embed doc.go 24 var doc string 25 26 var reportSubtest bool 27 28 func init() { 29 Analyzer.Flags.BoolVar(&reportSubtest, "subtest", false, "whether to check if t.Run subtest is terminated correctly; experimental") 30 } 31 32 var Analyzer = &analysis.Analyzer{ 33 Name: "testinggoroutine", 34 Doc: analysisutil.MustExtractDoc(doc, "testinggoroutine"), 35 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/testinggoroutine", 36 Requires: []*analysis.Analyzer{inspect.Analyzer}, 37 Run: run, 38 } 39 40 func run(pass *analysis.Pass) (interface{}, error) { 41 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) 42 43 if !analysisutil.Imports(pass.Pkg, "testing") { 44 return nil, nil 45 } 46 47 toDecl := localFunctionDecls(pass.TypesInfo, pass.Files) 48 49 // asyncs maps nodes whose statements will be executed concurrently 50 // with respect to some test function, to the call sites where they 51 // are invoked asynchronously. There may be multiple such call sites 52 // for e.g. test helpers. 53 asyncs := make(map[ast.Node][]*asyncCall) 54 var regions []ast.Node 55 addCall := func(c *asyncCall) { 56 if c != nil { 57 r := c.region 58 if asyncs[r] == nil { 59 regions = append(regions, r) 60 } 61 asyncs[r] = append(asyncs[r], c) 62 } 63 } 64 65 // Collect all of the go callee() and t.Run(name, callee) extents. 66 inspect.Nodes([]ast.Node{ 67 (*ast.FuncDecl)(nil), 68 (*ast.GoStmt)(nil), 69 (*ast.CallExpr)(nil), 70 }, func(node ast.Node, push bool) bool { 71 if !push { 72 return false 73 } 74 switch node := node.(type) { 75 case *ast.FuncDecl: 76 return hasBenchmarkOrTestParams(node) 77 78 case *ast.GoStmt: 79 c := goAsyncCall(pass.TypesInfo, node, toDecl) 80 addCall(c) 81 82 case *ast.CallExpr: 83 c := tRunAsyncCall(pass.TypesInfo, node) 84 addCall(c) 85 } 86 return true 87 }) 88 89 // Check for t.Forbidden() calls within each region r that is a 90 // callee in some go r() or a t.Run("name", r). 91 // 92 // Also considers a special case when r is a go t.Forbidden() call. 93 for _, region := range regions { 94 ast.Inspect(region, func(n ast.Node) bool { 95 if n == region { 96 return true // always descend into the region itself. 97 } else if asyncs[n] != nil { 98 return false // will be visited by another region. 99 } 100 101 call, ok := n.(*ast.CallExpr) 102 if !ok { 103 return true 104 } 105 x, sel, fn := forbiddenMethod(pass.TypesInfo, call) 106 if x == nil { 107 return true 108 } 109 110 for _, e := range asyncs[region] { 111 if !withinScope(e.scope, x) { 112 forbidden := formatMethod(sel, fn) // e.g. "(*testing.T).Forbidden 113 114 var context string 115 var where analysis.Range = e.async // Put the report at the go fun() or t.Run(name, fun). 116 if _, local := e.fun.(*ast.FuncLit); local { 117 where = call // Put the report at the t.Forbidden() call. 118 } else if id, ok := e.fun.(*ast.Ident); ok { 119 context = fmt.Sprintf(" (%s calls %s)", id.Name, forbidden) 120 } 121 if _, ok := e.async.(*ast.GoStmt); ok { 122 pass.ReportRangef(where, "call to %s from a non-test goroutine%s", forbidden, context) 123 } else if reportSubtest { 124 pass.ReportRangef(where, "call to %s on %s defined outside of the subtest%s", forbidden, x.Name(), context) 125 } 126 } 127 } 128 return true 129 }) 130 } 131 132 return nil, nil 133 } 134 135 func hasBenchmarkOrTestParams(fnDecl *ast.FuncDecl) bool { 136 // Check that the function's arguments include "*testing.T" or "*testing.B". 137 params := fnDecl.Type.Params.List 138 139 for _, param := range params { 140 if _, ok := typeIsTestingDotTOrB(param.Type); ok { 141 return true 142 } 143 } 144 145 return false 146 } 147 148 func typeIsTestingDotTOrB(expr ast.Expr) (string, bool) { 149 starExpr, ok := expr.(*ast.StarExpr) 150 if !ok { 151 return "", false 152 } 153 selExpr, ok := starExpr.X.(*ast.SelectorExpr) 154 if !ok { 155 return "", false 156 } 157 varPkg := selExpr.X.(*ast.Ident) 158 if varPkg.Name != "testing" { 159 return "", false 160 } 161 162 varTypeName := selExpr.Sel.Name 163 ok = varTypeName == "B" || varTypeName == "T" 164 return varTypeName, ok 165 } 166 167 // asyncCall describes a region of code that needs to be checked for 168 // t.Forbidden() calls as it is started asynchronously from an async 169 // node go fun() or t.Run(name, fun). 170 type asyncCall struct { 171 region ast.Node // region of code to check for t.Forbidden() calls. 172 async ast.Node // *ast.GoStmt or *ast.CallExpr (for t.Run) 173 scope ast.Node // Report t.Forbidden() if t is not declared within scope. 174 fun ast.Expr // fun in go fun() or t.Run(name, fun) 175 } 176 177 // withinScope returns true if x.Pos() is in [scope.Pos(), scope.End()]. 178 func withinScope(scope ast.Node, x *types.Var) bool { 179 if scope != nil { 180 return x.Pos() != token.NoPos && scope.Pos() <= x.Pos() && x.Pos() <= scope.End() 181 } 182 return false 183 } 184 185 // goAsyncCall returns the extent of a call from a go fun() statement. 186 func goAsyncCall(info *types.Info, goStmt *ast.GoStmt, toDecl func(*types.Func) *ast.FuncDecl) *asyncCall { 187 call := goStmt.Call 188 189 fun := astutil.Unparen(call.Fun) 190 if id := funcIdent(fun); id != nil { 191 if lit := funcLitInScope(id); lit != nil { 192 return &asyncCall{region: lit, async: goStmt, scope: nil, fun: fun} 193 } 194 } 195 196 if fn := typeutil.StaticCallee(info, call); fn != nil { // static call or method in the package? 197 if decl := toDecl(fn); decl != nil { 198 return &asyncCall{region: decl, async: goStmt, scope: nil, fun: fun} 199 } 200 } 201 202 // Check go statement for go t.Forbidden() or go func(){t.Forbidden()}(). 203 return &asyncCall{region: goStmt, async: goStmt, scope: nil, fun: fun} 204 } 205 206 // tRunAsyncCall returns the extent of a call from a t.Run("name", fun) expression. 207 func tRunAsyncCall(info *types.Info, call *ast.CallExpr) *asyncCall { 208 if len(call.Args) != 2 { 209 return nil 210 } 211 run := typeutil.Callee(info, call) 212 if run, ok := run.(*types.Func); !ok || !isMethodNamed(run, "testing", "Run") { 213 return nil 214 } 215 216 fun := astutil.Unparen(call.Args[1]) 217 if lit, ok := fun.(*ast.FuncLit); ok { // function lit? 218 return &asyncCall{region: lit, async: call, scope: lit, fun: fun} 219 } 220 221 if id := funcIdent(fun); id != nil { 222 if lit := funcLitInScope(id); lit != nil { // function lit in variable? 223 return &asyncCall{region: lit, async: call, scope: lit, fun: fun} 224 } 225 } 226 227 // Check within t.Run(name, fun) for calls to t.Forbidden, 228 // e.g. t.Run(name, func(t *testing.T){ t.Forbidden() }) 229 return &asyncCall{region: call, async: call, scope: fun, fun: fun} 230 } 231 232 var forbidden = []string{ 233 "FailNow", 234 "Fatal", 235 "Fatalf", 236 "Skip", 237 "Skipf", 238 "SkipNow", 239 } 240 241 // forbiddenMethod decomposes a call x.m() into (x, x.m, m) where 242 // x is a variable, x.m is a selection, and m is the static callee m. 243 // Returns (nil, nil, nil) if call is not of this form. 244 func forbiddenMethod(info *types.Info, call *ast.CallExpr) (*types.Var, *types.Selection, *types.Func) { 245 // Compare to typeutil.StaticCallee. 246 fun := astutil.Unparen(call.Fun) 247 selExpr, ok := fun.(*ast.SelectorExpr) 248 if !ok { 249 return nil, nil, nil 250 } 251 sel := info.Selections[selExpr] 252 if sel == nil { 253 return nil, nil, nil 254 } 255 256 var x *types.Var 257 if id, ok := astutil.Unparen(selExpr.X).(*ast.Ident); ok { 258 x, _ = info.Uses[id].(*types.Var) 259 } 260 if x == nil { 261 return nil, nil, nil 262 } 263 264 fn, _ := sel.Obj().(*types.Func) 265 if fn == nil || !isMethodNamed(fn, "testing", forbidden...) { 266 return nil, nil, nil 267 } 268 return x, sel, fn 269 } 270 271 func formatMethod(sel *types.Selection, fn *types.Func) string { 272 var ptr string 273 rtype := sel.Recv() 274 if p, ok := aliases.Unalias(rtype).(*types.Pointer); ok { 275 ptr = "*" 276 rtype = p.Elem() 277 } 278 return fmt.Sprintf("(%s%s).%s", ptr, rtype.String(), fn.Name()) 279 }