github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/go/analysis/passes/loopclosure/loopclosure.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package loopclosure defines an Analyzer that checks for references to
     6  // enclosing loop variables from within nested functions.
     7  package loopclosure
     8  
     9  import (
    10  	"go/ast"
    11  	"go/types"
    12  
    13  	"golang.org/x/tools/go/analysis"
    14  	"golang.org/x/tools/go/analysis/passes/inspect"
    15  	"golang.org/x/tools/go/ast/inspector"
    16  	"golang.org/x/tools/go/types/typeutil"
    17  )
    18  
// Doc is the documentation for the loopclosure analyzer; it is used as the
// Doc field of Analyzer.
const Doc = `check references to loop variables from within nested functions

This analyzer reports places where a function literal references the
iteration variable of an enclosing loop, and the loop calls the function
in such a way (e.g. with go or defer) that it may outlive the loop
iteration and possibly observe the wrong value of the variable.

In this example, all the deferred functions run after the loop has
completed, so all observe the final value of v.

    for _, v := range list {
        defer func() {
            use(v) // incorrect
        }()
    }

One fix is to create a new variable for each iteration of the loop:

    for _, v := range list {
        v := v // new var per iteration
        defer func() {
            use(v) // ok
        }()
    }

The next example uses a go statement and has a similar problem.
In addition, it has a data race because the loop updates v
concurrent with the goroutines accessing it.

    for _, v := range elem {
        go func() {
            use(v)  // incorrect, and a data race
        }()
    }

A fix is the same as before. The checker also reports problems
in goroutines started by golang.org/x/sync/errgroup.Group.
A hard-to-spot variant of this form is common in parallel tests:

    func Test(t *testing.T) {
        for _, test := range tests {
            t.Run(test.name, func(t *testing.T) {
                t.Parallel()
                use(test) // incorrect, and a data race
            })
        }
    }

The t.Parallel() call causes the rest of the function to execute
concurrent with the loop.

The analyzer reports references only in the last statement,
as it is not deep enough to understand the effects of subsequent
statements that might render the reference benign.
("Last statement" is defined recursively in compound
statements such as if, switch, and select.)

See: https://golang.org/doc/go_faq.html#closures_and_goroutines`
    77  
    78  var Analyzer = &analysis.Analyzer{
    79  	Name:     "loopclosure",
    80  	Doc:      Doc,
    81  	Requires: []*analysis.Analyzer{inspect.Analyzer},
    82  	Run:      run,
    83  }
    84  
    85  func run(pass *analysis.Pass) (interface{}, error) {
    86  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    87  
    88  	nodeFilter := []ast.Node{
    89  		(*ast.RangeStmt)(nil),
    90  		(*ast.ForStmt)(nil),
    91  	}
    92  	inspect.Preorder(nodeFilter, func(n ast.Node) {
    93  		// Find the variables updated by the loop statement.
    94  		var vars []types.Object
    95  		addVar := func(expr ast.Expr) {
    96  			if id, _ := expr.(*ast.Ident); id != nil {
    97  				if obj := pass.TypesInfo.ObjectOf(id); obj != nil {
    98  					vars = append(vars, obj)
    99  				}
   100  			}
   101  		}
   102  		var body *ast.BlockStmt
   103  		switch n := n.(type) {
   104  		case *ast.RangeStmt:
   105  			body = n.Body
   106  			addVar(n.Key)
   107  			addVar(n.Value)
   108  		case *ast.ForStmt:
   109  			body = n.Body
   110  			switch post := n.Post.(type) {
   111  			case *ast.AssignStmt:
   112  				// e.g. for p = head; p != nil; p = p.next
   113  				for _, lhs := range post.Lhs {
   114  					addVar(lhs)
   115  				}
   116  			case *ast.IncDecStmt:
   117  				// e.g. for i := 0; i < n; i++
   118  				addVar(post.X)
   119  			}
   120  		}
   121  		if vars == nil {
   122  			return
   123  		}
   124  
   125  		// Inspect statements to find function literals that may be run outside of
   126  		// the current loop iteration.
   127  		//
   128  		// For go, defer, and errgroup.Group.Go, we ignore all but the last
   129  		// statement, because it's hard to prove go isn't followed by wait, or
   130  		// defer by return. "Last" is defined recursively.
   131  		//
   132  		// TODO: consider allowing the "last" go/defer/Go statement to be followed by
   133  		// N "trivial" statements, possibly under a recursive definition of "trivial"
   134  		// so that that checker could, for example, conclude that a go statement is
   135  		// followed by an if statement made of only trivial statements and trivial expressions,
   136  		// and hence the go statement could still be checked.
   137  		forEachLastStmt(body.List, func(last ast.Stmt) {
   138  			var stmts []ast.Stmt
   139  			switch s := last.(type) {
   140  			case *ast.GoStmt:
   141  				stmts = litStmts(s.Call.Fun)
   142  			case *ast.DeferStmt:
   143  				stmts = litStmts(s.Call.Fun)
   144  			case *ast.ExprStmt: // check for errgroup.Group.Go
   145  				if call, ok := s.X.(*ast.CallExpr); ok {
   146  					stmts = litStmts(goInvoke(pass.TypesInfo, call))
   147  				}
   148  			}
   149  			for _, stmt := range stmts {
   150  				reportCaptured(pass, vars, stmt)
   151  			}
   152  		})
   153  
   154  		// Also check for testing.T.Run (with T.Parallel).
   155  		// We consider every t.Run statement in the loop body, because there is
   156  		// no commonly used mechanism for synchronizing parallel subtests.
   157  		// It is of course theoretically possible to synchronize parallel subtests,
   158  		// though such a pattern is likely to be exceedingly rare as it would be
   159  		// fighting against the test runner.
   160  		for _, s := range body.List {
   161  			switch s := s.(type) {
   162  			case *ast.ExprStmt:
   163  				if call, ok := s.X.(*ast.CallExpr); ok {
   164  					for _, stmt := range parallelSubtest(pass.TypesInfo, call) {
   165  						reportCaptured(pass, vars, stmt)
   166  					}
   167  
   168  				}
   169  			}
   170  		}
   171  	})
   172  	return nil, nil
   173  }
   174  
   175  // reportCaptured reports a diagnostic stating a loop variable
   176  // has been captured by a func literal if checkStmt has escaping
   177  // references to vars. vars is expected to be variables updated by a loop statement,
   178  // and checkStmt is expected to be a statements from the body of a func literal in the loop.
   179  func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) {
   180  	ast.Inspect(checkStmt, func(n ast.Node) bool {
   181  		id, ok := n.(*ast.Ident)
   182  		if !ok {
   183  			return true
   184  		}
   185  		obj := pass.TypesInfo.Uses[id]
   186  		if obj == nil {
   187  			return true
   188  		}
   189  		for _, v := range vars {
   190  			if v == obj {
   191  				pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name)
   192  			}
   193  		}
   194  		return true
   195  	})
   196  }
   197  
   198  // forEachLastStmt calls onLast on each "last" statement in a list of statements.
   199  // "Last" is defined recursively so, for example, if the last statement is
   200  // a switch statement, then each switch case is also visited to examine
   201  // its last statements.
   202  func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) {
   203  	if len(stmts) == 0 {
   204  		return
   205  	}
   206  
   207  	s := stmts[len(stmts)-1]
   208  	switch s := s.(type) {
   209  	case *ast.IfStmt:
   210  	loop:
   211  		for {
   212  			forEachLastStmt(s.Body.List, onLast)
   213  			switch e := s.Else.(type) {
   214  			case *ast.BlockStmt:
   215  				forEachLastStmt(e.List, onLast)
   216  				break loop
   217  			case *ast.IfStmt:
   218  				s = e
   219  			case nil:
   220  				break loop
   221  			}
   222  		}
   223  	case *ast.ForStmt:
   224  		forEachLastStmt(s.Body.List, onLast)
   225  	case *ast.RangeStmt:
   226  		forEachLastStmt(s.Body.List, onLast)
   227  	case *ast.SwitchStmt:
   228  		for _, c := range s.Body.List {
   229  			cc := c.(*ast.CaseClause)
   230  			forEachLastStmt(cc.Body, onLast)
   231  		}
   232  	case *ast.TypeSwitchStmt:
   233  		for _, c := range s.Body.List {
   234  			cc := c.(*ast.CaseClause)
   235  			forEachLastStmt(cc.Body, onLast)
   236  		}
   237  	case *ast.SelectStmt:
   238  		for _, c := range s.Body.List {
   239  			cc := c.(*ast.CommClause)
   240  			forEachLastStmt(cc.Body, onLast)
   241  		}
   242  	default:
   243  		onLast(s)
   244  	}
   245  }
   246  
   247  // litStmts returns all statements from the function body of a function
   248  // literal.
   249  //
   250  // If fun is not a function literal, it returns nil.
   251  func litStmts(fun ast.Expr) []ast.Stmt {
   252  	lit, _ := fun.(*ast.FuncLit)
   253  	if lit == nil {
   254  		return nil
   255  	}
   256  	return lit.Body.List
   257  }
   258  
   259  // goInvoke returns a function expression that would be called asynchronously
   260  // (but not awaited) in another goroutine as a consequence of the call.
   261  // For example, given the g.Go call below, it returns the function literal expression.
   262  //
   263  //	import "sync/errgroup"
   264  //	var g errgroup.Group
   265  //	g.Go(func() error { ... })
   266  //
   267  // Currently only "golang.org/x/sync/errgroup.Group()" is considered.
   268  func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr {
   269  	if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") {
   270  		return nil
   271  	}
   272  	return call.Args[0]
   273  }
   274  
   275  // parallelSubtest returns statements that can be easily proven to execute
   276  // concurrently via the go test runner, as t.Run has been invoked with a
   277  // function literal that calls t.Parallel.
   278  //
   279  // In practice, users rely on the fact that statements before the call to
   280  // t.Parallel are synchronous. For example by declaring test := test inside the
   281  // function literal, but before the call to t.Parallel.
   282  //
   283  // Therefore, we only flag references in statements that are obviously
   284  // dominated by a call to t.Parallel. As a simple heuristic, we only consider
   285  // statements following the final labeled statement in the function body, to
   286  // avoid scenarios where a jump would cause either the call to t.Parallel or
   287  // the problematic reference to be skipped.
   288  //
   289  //	import "testing"
   290  //
   291  //	func TestFoo(t *testing.T) {
   292  //		tests := []int{0, 1, 2}
   293  //		for i, test := range tests {
   294  //			t.Run("subtest", func(t *testing.T) {
   295  //				println(i, test) // OK
   296  //		 		t.Parallel()
   297  //				println(i, test) // Not OK
   298  //			})
   299  //		}
   300  //	}
   301  func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt {
   302  	if !isMethodCall(info, call, "testing", "T", "Run") {
   303  		return nil
   304  	}
   305  
   306  	if len(call.Args) != 2 {
   307  		// Ignore calls such as t.Run(fn()).
   308  		return nil
   309  	}
   310  
   311  	lit, _ := call.Args[1].(*ast.FuncLit)
   312  	if lit == nil {
   313  		return nil
   314  	}
   315  
   316  	// Capture the *testing.T object for the first argument to the function
   317  	// literal.
   318  	if len(lit.Type.Params.List[0].Names) == 0 {
   319  		return nil
   320  	}
   321  
   322  	tObj := info.Defs[lit.Type.Params.List[0].Names[0]]
   323  	if tObj == nil {
   324  		return nil
   325  	}
   326  
   327  	// Match statements that occur after a call to t.Parallel following the final
   328  	// labeled statement in the function body.
   329  	//
   330  	// We iterate over lit.Body.List to have a simple, fast and "frequent enough"
   331  	// dominance relationship for t.Parallel(): lit.Body.List[i] dominates
   332  	// lit.Body.List[j] for i < j unless there is a jump.
   333  	var stmts []ast.Stmt
   334  	afterParallel := false
   335  	for _, stmt := range lit.Body.List {
   336  		stmt, labeled := unlabel(stmt)
   337  		if labeled {
   338  			// Reset: naively we don't know if a jump could have caused the
   339  			// previously considered statements to be skipped.
   340  			stmts = nil
   341  			afterParallel = false
   342  		}
   343  
   344  		if afterParallel {
   345  			stmts = append(stmts, stmt)
   346  			continue
   347  		}
   348  
   349  		// Check if stmt is a call to t.Parallel(), for the correct t.
   350  		exprStmt, ok := stmt.(*ast.ExprStmt)
   351  		if !ok {
   352  			continue
   353  		}
   354  		expr := exprStmt.X
   355  		if isMethodCall(info, expr, "testing", "T", "Parallel") {
   356  			call, _ := expr.(*ast.CallExpr)
   357  			if call == nil {
   358  				continue
   359  			}
   360  			x, _ := call.Fun.(*ast.SelectorExpr)
   361  			if x == nil {
   362  				continue
   363  			}
   364  			id, _ := x.X.(*ast.Ident)
   365  			if id == nil {
   366  				continue
   367  			}
   368  			if info.Uses[id] == tObj {
   369  				afterParallel = true
   370  			}
   371  		}
   372  	}
   373  
   374  	return stmts
   375  }
   376  
   377  // unlabel returns the inner statement for the possibly labeled statement stmt,
   378  // stripping any (possibly nested) *ast.LabeledStmt wrapper.
   379  //
   380  // The second result reports whether stmt was an *ast.LabeledStmt.
   381  func unlabel(stmt ast.Stmt) (ast.Stmt, bool) {
   382  	labeled := false
   383  	for {
   384  		labelStmt, ok := stmt.(*ast.LabeledStmt)
   385  		if !ok {
   386  			return stmt, labeled
   387  		}
   388  		labeled = true
   389  		stmt = labelStmt.Stmt
   390  	}
   391  }
   392  
   393  // isMethodCall reports whether expr is a method call of
   394  // <pkgPath>.<typeName>.<method>.
   395  func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool {
   396  	call, ok := expr.(*ast.CallExpr)
   397  	if !ok {
   398  		return false
   399  	}
   400  
   401  	// Check that we are calling a method <method>
   402  	f := typeutil.StaticCallee(info, call)
   403  	if f == nil || f.Name() != method {
   404  		return false
   405  	}
   406  	recv := f.Type().(*types.Signature).Recv()
   407  	if recv == nil {
   408  		return false
   409  	}
   410  
   411  	// Check that the receiver is a <pkgPath>.<typeName> or
   412  	// *<pkgPath>.<typeName>.
   413  	rtype := recv.Type()
   414  	if ptr, ok := recv.Type().(*types.Pointer); ok {
   415  		rtype = ptr.Elem()
   416  	}
   417  	named, ok := rtype.(*types.Named)
   418  	if !ok {
   419  		return false
   420  	}
   421  	if named.Obj().Name() != typeName {
   422  		return false
   423  	}
   424  	pkg := f.Pkg()
   425  	if pkg == nil {
   426  		return false
   427  	}
   428  	if pkg.Path() != pkgPath {
   429  		return false
   430  	}
   431  
   432  	return true
   433  }