golang.org/x/tools@v0.21.0/go/analysis/passes/testinggoroutine/testinggoroutine.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testinggoroutine
     6  
     7  import (
     8  	_ "embed"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/token"
    12  	"go/types"
    13  
    14  	"golang.org/x/tools/go/analysis"
    15  	"golang.org/x/tools/go/analysis/passes/inspect"
    16  	"golang.org/x/tools/go/analysis/passes/internal/analysisutil"
    17  	"golang.org/x/tools/go/ast/astutil"
    18  	"golang.org/x/tools/go/ast/inspector"
    19  	"golang.org/x/tools/go/types/typeutil"
    20  	"golang.org/x/tools/internal/aliases"
    21  )
    22  
// doc holds the contents of doc.go, embedded at build time. It is passed to
// analysisutil.MustExtractDoc to produce the Analyzer's Doc string.
//
//go:embed doc.go
var doc string

// reportSubtest is bound to the -subtest flag (default false). When it is
// false, run reports only go-statement violations and stays silent about
// t.Run subtests.
var reportSubtest bool

// init registers the experimental -subtest flag on the Analyzer.
func init() {
	Analyzer.Flags.BoolVar(&reportSubtest, "subtest", false, "whether to check if t.Run subtest is terminated correctly; experimental")
}

// Analyzer is the testinggoroutine analysis. It requires the result of
// inspect.Analyzer and performs its checks in run.
var Analyzer = &analysis.Analyzer{
	Name:     "testinggoroutine",
	Doc:      analysisutil.MustExtractDoc(doc, "testinggoroutine"),
	URL:      "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/testinggoroutine",
	Requires: []*analysis.Analyzer{inspect.Analyzer},
	Run:      run,
}
    39  
// run implements the testinggoroutine analysis for a single package.
//
// It returns immediately if the package does not import "testing".
// Otherwise it works in two phases:
//
//  1. Collect every region of code started by a go statement or passed to
//     t.Run, looking only inside function declarations that have a
//     *testing.T or *testing.B parameter.
//  2. Inspect each region for calls to one of the forbidden testing methods
//     and report those whose receiver variable is not declared within the
//     scope recorded for the asynchronous call site.
//
// Diagnostics for t.Run call sites are emitted only when the -subtest flag
// is set. The result is always (nil, nil).
func run(pass *analysis.Pass) (interface{}, error) {
	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

	if !analysisutil.Imports(pass.Pkg, "testing") {
		return nil, nil
	}

	toDecl := localFunctionDecls(pass.TypesInfo, pass.Files)

	// asyncs maps nodes whose statements will be executed concurrently
	// with respect to some test function, to the call sites where they
	// are invoked asynchronously. There may be multiple such call sites
	// for e.g. test helpers.
	asyncs := make(map[ast.Node][]*asyncCall)
	// regions lists the keys of asyncs in first-seen order, so that the
	// reporting loop below does not depend on map iteration order.
	var regions []ast.Node
	addCall := func(c *asyncCall) {
		if c != nil {
			r := c.region
			if asyncs[r] == nil {
				regions = append(regions, r)
			}
			asyncs[r] = append(asyncs[r], c)
		}
	}

	// Collect all of the go callee() and t.Run(name, callee) extents.
	inspect.Nodes([]ast.Node{
		(*ast.FuncDecl)(nil),
		(*ast.GoStmt)(nil),
		(*ast.CallExpr)(nil),
	}, func(node ast.Node, push bool) bool {
		// All work happens on the way down; nothing to do when leaving a node.
		if !push {
			return false
		}
		switch node := node.(type) {
		case *ast.FuncDecl:
			// Only descend into functions that take a *testing.T or *testing.B.
			return hasBenchmarkOrTestParams(node)

		case *ast.GoStmt:
			c := goAsyncCall(pass.TypesInfo, node, toDecl)
			addCall(c)

		case *ast.CallExpr:
			// tRunAsyncCall returns nil for calls that are not t.Run(name, fun);
			// addCall ignores nil.
			c := tRunAsyncCall(pass.TypesInfo, node)
			addCall(c)
		}
		return true
	})

	// Check for t.Forbidden() calls within each region r that is a
	// callee in some go r() or a t.Run("name", r).
	//
	// Also considers a special case when r is a go t.Forbidden() call.
	for _, region := range regions {
		ast.Inspect(region, func(n ast.Node) bool {
			if n == region {
				return true // always descend into the region itself.
			} else if asyncs[n] != nil {
				return false // will be visited by another region.
			}

			call, ok := n.(*ast.CallExpr)
			if !ok {
				return true
			}
			x, sel, fn := forbiddenMethod(pass.TypesInfo, call)
			if x == nil {
				return true
			}

			// The same region may be started from several call sites;
			// each one is judged (and reported) independently.
			for _, e := range asyncs[region] {
				if !withinScope(e.scope, x) {
					forbidden := formatMethod(sel, fn) // e.g. "(*testing.T).Forbidden"

					var context string
					var where analysis.Range = e.async // Put the report at the go fun() or t.Run(name, fun).
					if _, local := e.fun.(*ast.FuncLit); local {
						where = call // Put the report at the t.Forbidden() call.
					} else if id, ok := e.fun.(*ast.Ident); ok {
						context = fmt.Sprintf(" (%s calls %s)", id.Name, forbidden)
					}
					if _, ok := e.async.(*ast.GoStmt); ok {
						pass.ReportRangef(where, "call to %s from a non-test goroutine%s", forbidden, context)
					} else if reportSubtest {
						// e.async is a t.Run call; only reported under -subtest.
						pass.ReportRangef(where, "call to %s on %s defined outside of the subtest%s", forbidden, x.Name(), context)
					}
				}
			}
			return true
		})
	}

	return nil, nil
}
   134  
   135  func hasBenchmarkOrTestParams(fnDecl *ast.FuncDecl) bool {
   136  	// Check that the function's arguments include "*testing.T" or "*testing.B".
   137  	params := fnDecl.Type.Params.List
   138  
   139  	for _, param := range params {
   140  		if _, ok := typeIsTestingDotTOrB(param.Type); ok {
   141  			return true
   142  		}
   143  	}
   144  
   145  	return false
   146  }
   147  
   148  func typeIsTestingDotTOrB(expr ast.Expr) (string, bool) {
   149  	starExpr, ok := expr.(*ast.StarExpr)
   150  	if !ok {
   151  		return "", false
   152  	}
   153  	selExpr, ok := starExpr.X.(*ast.SelectorExpr)
   154  	if !ok {
   155  		return "", false
   156  	}
   157  	varPkg := selExpr.X.(*ast.Ident)
   158  	if varPkg.Name != "testing" {
   159  		return "", false
   160  	}
   161  
   162  	varTypeName := selExpr.Sel.Name
   163  	ok = varTypeName == "B" || varTypeName == "T"
   164  	return varTypeName, ok
   165  }
   166  
// asyncCall describes a region of code that needs to be checked for
// t.Forbidden() calls as it is started asynchronously from an async
// node go fun() or t.Run(name, fun).
//
// Values are built by goAsyncCall and tRunAsyncCall.
type asyncCall struct {
	// region is the code to check for t.Forbidden() calls: a function
	// literal, a function declaration, or the async node itself when the
	// callee's body could not be resolved.
	region ast.Node
	// async is the *ast.GoStmt or *ast.CallExpr (for t.Run) that starts region.
	async ast.Node
	// scope bounds where t may be declared: t.Forbidden() is reported if t
	// is not declared within scope. It is nil for go statements, in which
	// case withinScope is always false.
	scope ast.Node
	// fun is the fun in go fun() or t.Run(name, fun), with parentheses removed.
	fun ast.Expr
}
   176  
   177  // withinScope returns true if x.Pos() is in [scope.Pos(), scope.End()].
   178  func withinScope(scope ast.Node, x *types.Var) bool {
   179  	if scope != nil {
   180  		return x.Pos() != token.NoPos && scope.Pos() <= x.Pos() && x.Pos() <= scope.End()
   181  	}
   182  	return false
   183  }
   184  
   185  // goAsyncCall returns the extent of a call from a go fun() statement.
   186  func goAsyncCall(info *types.Info, goStmt *ast.GoStmt, toDecl func(*types.Func) *ast.FuncDecl) *asyncCall {
   187  	call := goStmt.Call
   188  
   189  	fun := astutil.Unparen(call.Fun)
   190  	if id := funcIdent(fun); id != nil {
   191  		if lit := funcLitInScope(id); lit != nil {
   192  			return &asyncCall{region: lit, async: goStmt, scope: nil, fun: fun}
   193  		}
   194  	}
   195  
   196  	if fn := typeutil.StaticCallee(info, call); fn != nil { // static call or method in the package?
   197  		if decl := toDecl(fn); decl != nil {
   198  			return &asyncCall{region: decl, async: goStmt, scope: nil, fun: fun}
   199  		}
   200  	}
   201  
   202  	// Check go statement for go t.Forbidden() or go func(){t.Forbidden()}().
   203  	return &asyncCall{region: goStmt, async: goStmt, scope: nil, fun: fun}
   204  }
   205  
   206  // tRunAsyncCall returns the extent of a call from a t.Run("name", fun) expression.
   207  func tRunAsyncCall(info *types.Info, call *ast.CallExpr) *asyncCall {
   208  	if len(call.Args) != 2 {
   209  		return nil
   210  	}
   211  	run := typeutil.Callee(info, call)
   212  	if run, ok := run.(*types.Func); !ok || !isMethodNamed(run, "testing", "Run") {
   213  		return nil
   214  	}
   215  
   216  	fun := astutil.Unparen(call.Args[1])
   217  	if lit, ok := fun.(*ast.FuncLit); ok { // function lit?
   218  		return &asyncCall{region: lit, async: call, scope: lit, fun: fun}
   219  	}
   220  
   221  	if id := funcIdent(fun); id != nil {
   222  		if lit := funcLitInScope(id); lit != nil { // function lit in variable?
   223  			return &asyncCall{region: lit, async: call, scope: lit, fun: fun}
   224  		}
   225  	}
   226  
   227  	// Check within t.Run(name, fun) for calls to t.Forbidden,
   228  	// e.g. t.Run(name, func(t *testing.T){ t.Forbidden() })
   229  	return &asyncCall{region: call, async: call, scope: fun, fun: fun}
   230  }
   231  
// forbidden lists the method names that forbiddenMethod looks for, via
// isMethodNamed, on methods from package testing.
var forbidden = []string{
	"FailNow",
	"Fatal",
	"Fatalf",
	"Skip",
	"Skipf",
	"SkipNow",
}
   240  
   241  // forbiddenMethod decomposes a call x.m() into (x, x.m, m) where
   242  // x is a variable, x.m is a selection, and m is the static callee m.
   243  // Returns (nil, nil, nil) if call is not of this form.
   244  func forbiddenMethod(info *types.Info, call *ast.CallExpr) (*types.Var, *types.Selection, *types.Func) {
   245  	// Compare to typeutil.StaticCallee.
   246  	fun := astutil.Unparen(call.Fun)
   247  	selExpr, ok := fun.(*ast.SelectorExpr)
   248  	if !ok {
   249  		return nil, nil, nil
   250  	}
   251  	sel := info.Selections[selExpr]
   252  	if sel == nil {
   253  		return nil, nil, nil
   254  	}
   255  
   256  	var x *types.Var
   257  	if id, ok := astutil.Unparen(selExpr.X).(*ast.Ident); ok {
   258  		x, _ = info.Uses[id].(*types.Var)
   259  	}
   260  	if x == nil {
   261  		return nil, nil, nil
   262  	}
   263  
   264  	fn, _ := sel.Obj().(*types.Func)
   265  	if fn == nil || !isMethodNamed(fn, "testing", forbidden...) {
   266  		return nil, nil, nil
   267  	}
   268  	return x, sel, fn
   269  }
   270  
   271  func formatMethod(sel *types.Selection, fn *types.Func) string {
   272  	var ptr string
   273  	rtype := sel.Recv()
   274  	if p, ok := aliases.Unalias(rtype).(*types.Pointer); ok {
   275  		ptr = "*"
   276  		rtype = p.Elem()
   277  	}
   278  	return fmt.Sprintf("(%s%s).%s", ptr, rtype.String(), fn.Name())
   279  }