github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/go/analysis/passes/loopclosure/loopclosure.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package loopclosure defines an Analyzer that checks for references to 6 // enclosing loop variables from within nested functions. 7 package loopclosure 8 9 import ( 10 "go/ast" 11 "go/types" 12 13 "golang.org/x/tools/go/analysis" 14 "golang.org/x/tools/go/analysis/passes/inspect" 15 "golang.org/x/tools/go/ast/inspector" 16 "golang.org/x/tools/go/types/typeutil" 17 ) 18 19 const Doc = `check references to loop variables from within nested functions 20 21 This analyzer reports places where a function literal references the 22 iteration variable of an enclosing loop, and the loop calls the function 23 in such a way (e.g. with go or defer) that it may outlive the loop 24 iteration and possibly observe the wrong value of the variable. 25 26 In this example, all the deferred functions run after the loop has 27 completed, so all observe the final value of v. 28 29 for _, v := range list { 30 defer func() { 31 use(v) // incorrect 32 }() 33 } 34 35 One fix is to create a new variable for each iteration of the loop: 36 37 for _, v := range list { 38 v := v // new var per iteration 39 defer func() { 40 use(v) // ok 41 }() 42 } 43 44 The next example uses a go statement and has a similar problem. 45 In addition, it has a data race because the loop updates v 46 concurrent with the goroutines accessing it. 47 48 for _, v := range elem { 49 go func() { 50 use(v) // incorrect, and a data race 51 }() 52 } 53 54 A fix is the same as before. The checker also reports problems 55 in goroutines started by golang.org/x/sync/errgroup.Group. 56 A hard-to-spot variant of this form is common in parallel tests: 57 58 func Test(t *testing.T) { 59 for _, test := range tests { 60 t.Run(test.name, func(t *testing.T) { 61 t.Parallel() 62 use(test) // incorrect, and a data race 63 }) 64 } 65 } 66 67 The t.Parallel() call causes the rest of the function to execute 68 concurrent with the loop. 69 70 The analyzer reports references only in the last statement, 71 as it is not deep enough to understand the effects of subsequent 72 statements that might render the reference benign. 73 ("Last statement" is defined recursively in compound 74 statements such as if, switch, and select.) 75 76 See: https://golang.org/doc/go_faq.html#closures_and_goroutines` 77 78 var Analyzer = &analysis.Analyzer{ 79 Name: "loopclosure", 80 Doc: Doc, 81 Requires: []*analysis.Analyzer{inspect.Analyzer}, 82 Run: run, 83 } 84 85 func run(pass *analysis.Pass) (interface{}, error) { 86 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) 87 88 nodeFilter := []ast.Node{ 89 (*ast.RangeStmt)(nil), 90 (*ast.ForStmt)(nil), 91 } 92 inspect.Preorder(nodeFilter, func(n ast.Node) { 93 // Find the variables updated by the loop statement. 94 var vars []types.Object 95 addVar := func(expr ast.Expr) { 96 if id, _ := expr.(*ast.Ident); id != nil { 97 if obj := pass.TypesInfo.ObjectOf(id); obj != nil { 98 vars = append(vars, obj) 99 } 100 } 101 } 102 var body *ast.BlockStmt 103 switch n := n.(type) { 104 case *ast.RangeStmt: 105 body = n.Body 106 addVar(n.Key) 107 addVar(n.Value) 108 case *ast.ForStmt: 109 body = n.Body 110 switch post := n.Post.(type) { 111 case *ast.AssignStmt: 112 // e.g. for p = head; p != nil; p = p.next 113 for _, lhs := range post.Lhs { 114 addVar(lhs) 115 } 116 case *ast.IncDecStmt: 117 // e.g. for i := 0; i < n; i++ 118 addVar(post.X) 119 } 120 } 121 if vars == nil { 122 return 123 } 124 125 // Inspect statements to find function literals that may be run outside of 126 // the current loop iteration. 127 // 128 // For go, defer, and errgroup.Group.Go, we ignore all but the last 129 // statement, because it's hard to prove go isn't followed by wait, or 130 // defer by return. "Last" is defined recursively. 131 // 132 // TODO: consider allowing the "last" go/defer/Go statement to be followed by 133 // N "trivial" statements, possibly under a recursive definition of "trivial" 134 // so that that checker could, for example, conclude that a go statement is 135 // followed by an if statement made of only trivial statements and trivial expressions, 136 // and hence the go statement could still be checked. 137 forEachLastStmt(body.List, func(last ast.Stmt) { 138 var stmts []ast.Stmt 139 switch s := last.(type) { 140 case *ast.GoStmt: 141 stmts = litStmts(s.Call.Fun) 142 case *ast.DeferStmt: 143 stmts = litStmts(s.Call.Fun) 144 case *ast.ExprStmt: // check for errgroup.Group.Go 145 if call, ok := s.X.(*ast.CallExpr); ok { 146 stmts = litStmts(goInvoke(pass.TypesInfo, call)) 147 } 148 } 149 for _, stmt := range stmts { 150 reportCaptured(pass, vars, stmt) 151 } 152 }) 153 154 // Also check for testing.T.Run (with T.Parallel). 155 // We consider every t.Run statement in the loop body, because there is 156 // no commonly used mechanism for synchronizing parallel subtests. 157 // It is of course theoretically possible to synchronize parallel subtests, 158 // though such a pattern is likely to be exceedingly rare as it would be 159 // fighting against the test runner. 160 for _, s := range body.List { 161 switch s := s.(type) { 162 case *ast.ExprStmt: 163 if call, ok := s.X.(*ast.CallExpr); ok { 164 for _, stmt := range parallelSubtest(pass.TypesInfo, call) { 165 reportCaptured(pass, vars, stmt) 166 } 167 168 } 169 } 170 } 171 }) 172 return nil, nil 173 } 174 175 // reportCaptured reports a diagnostic stating a loop variable 176 // has been captured by a func literal if checkStmt has escaping 177 // references to vars. vars is expected to be variables updated by a loop statement, 178 // and checkStmt is expected to be a statements from the body of a func literal in the loop. 179 func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) { 180 ast.Inspect(checkStmt, func(n ast.Node) bool { 181 id, ok := n.(*ast.Ident) 182 if !ok { 183 return true 184 } 185 obj := pass.TypesInfo.Uses[id] 186 if obj == nil { 187 return true 188 } 189 for _, v := range vars { 190 if v == obj { 191 pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name) 192 } 193 } 194 return true 195 }) 196 } 197 198 // forEachLastStmt calls onLast on each "last" statement in a list of statements. 199 // "Last" is defined recursively so, for example, if the last statement is 200 // a switch statement, then each switch case is also visited to examine 201 // its last statements. 202 func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) { 203 if len(stmts) == 0 { 204 return 205 } 206 207 s := stmts[len(stmts)-1] 208 switch s := s.(type) { 209 case *ast.IfStmt: 210 loop: 211 for { 212 forEachLastStmt(s.Body.List, onLast) 213 switch e := s.Else.(type) { 214 case *ast.BlockStmt: 215 forEachLastStmt(e.List, onLast) 216 break loop 217 case *ast.IfStmt: 218 s = e 219 case nil: 220 break loop 221 } 222 } 223 case *ast.ForStmt: 224 forEachLastStmt(s.Body.List, onLast) 225 case *ast.RangeStmt: 226 forEachLastStmt(s.Body.List, onLast) 227 case *ast.SwitchStmt: 228 for _, c := range s.Body.List { 229 cc := c.(*ast.CaseClause) 230 forEachLastStmt(cc.Body, onLast) 231 } 232 case *ast.TypeSwitchStmt: 233 for _, c := range s.Body.List { 234 cc := c.(*ast.CaseClause) 235 forEachLastStmt(cc.Body, onLast) 236 } 237 case *ast.SelectStmt: 238 for _, c := range s.Body.List { 239 cc := c.(*ast.CommClause) 240 forEachLastStmt(cc.Body, onLast) 241 } 242 default: 243 onLast(s) 244 } 245 } 246 247 // litStmts returns all statements from the function body of a function 248 // literal. 249 // 250 // If fun is not a function literal, it returns nil. 251 func litStmts(fun ast.Expr) []ast.Stmt { 252 lit, _ := fun.(*ast.FuncLit) 253 if lit == nil { 254 return nil 255 } 256 return lit.Body.List 257 } 258 259 // goInvoke returns a function expression that would be called asynchronously 260 // (but not awaited) in another goroutine as a consequence of the call. 261 // For example, given the g.Go call below, it returns the function literal expression. 262 // 263 // import "sync/errgroup" 264 // var g errgroup.Group 265 // g.Go(func() error { ... }) 266 // 267 // Currently only "golang.org/x/sync/errgroup.Group()" is considered. 268 func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr { 269 if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") { 270 return nil 271 } 272 return call.Args[0] 273 } 274 275 // parallelSubtest returns statements that can be easily proven to execute 276 // concurrently via the go test runner, as t.Run has been invoked with a 277 // function literal that calls t.Parallel. 278 // 279 // In practice, users rely on the fact that statements before the call to 280 // t.Parallel are synchronous. For example by declaring test := test inside the 281 // function literal, but before the call to t.Parallel. 282 // 283 // Therefore, we only flag references in statements that are obviously 284 // dominated by a call to t.Parallel. As a simple heuristic, we only consider 285 // statements following the final labeled statement in the function body, to 286 // avoid scenarios where a jump would cause either the call to t.Parallel or 287 // the problematic reference to be skipped. 288 // 289 // import "testing" 290 // 291 // func TestFoo(t *testing.T) { 292 // tests := []int{0, 1, 2} 293 // for i, test := range tests { 294 // t.Run("subtest", func(t *testing.T) { 295 // println(i, test) // OK 296 // t.Parallel() 297 // println(i, test) // Not OK 298 // }) 299 // } 300 // } 301 func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt { 302 if !isMethodCall(info, call, "testing", "T", "Run") { 303 return nil 304 } 305 306 if len(call.Args) != 2 { 307 // Ignore calls such as t.Run(fn()). 308 return nil 309 } 310 311 lit, _ := call.Args[1].(*ast.FuncLit) 312 if lit == nil { 313 return nil 314 } 315 316 // Capture the *testing.T object for the first argument to the function 317 // literal. 318 if len(lit.Type.Params.List[0].Names) == 0 { 319 return nil 320 } 321 322 tObj := info.Defs[lit.Type.Params.List[0].Names[0]] 323 if tObj == nil { 324 return nil 325 } 326 327 // Match statements that occur after a call to t.Parallel following the final 328 // labeled statement in the function body. 329 // 330 // We iterate over lit.Body.List to have a simple, fast and "frequent enough" 331 // dominance relationship for t.Parallel(): lit.Body.List[i] dominates 332 // lit.Body.List[j] for i < j unless there is a jump. 333 var stmts []ast.Stmt 334 afterParallel := false 335 for _, stmt := range lit.Body.List { 336 stmt, labeled := unlabel(stmt) 337 if labeled { 338 // Reset: naively we don't know if a jump could have caused the 339 // previously considered statements to be skipped. 340 stmts = nil 341 afterParallel = false 342 } 343 344 if afterParallel { 345 stmts = append(stmts, stmt) 346 continue 347 } 348 349 // Check if stmt is a call to t.Parallel(), for the correct t. 350 exprStmt, ok := stmt.(*ast.ExprStmt) 351 if !ok { 352 continue 353 } 354 expr := exprStmt.X 355 if isMethodCall(info, expr, "testing", "T", "Parallel") { 356 call, _ := expr.(*ast.CallExpr) 357 if call == nil { 358 continue 359 } 360 x, _ := call.Fun.(*ast.SelectorExpr) 361 if x == nil { 362 continue 363 } 364 id, _ := x.X.(*ast.Ident) 365 if id == nil { 366 continue 367 } 368 if info.Uses[id] == tObj { 369 afterParallel = true 370 } 371 } 372 } 373 374 return stmts 375 } 376 377 // unlabel returns the inner statement for the possibly labeled statement stmt, 378 // stripping any (possibly nested) *ast.LabeledStmt wrapper. 379 // 380 // The second result reports whether stmt was an *ast.LabeledStmt. 381 func unlabel(stmt ast.Stmt) (ast.Stmt, bool) { 382 labeled := false 383 for { 384 labelStmt, ok := stmt.(*ast.LabeledStmt) 385 if !ok { 386 return stmt, labeled 387 } 388 labeled = true 389 stmt = labelStmt.Stmt 390 } 391 } 392 393 // isMethodCall reports whether expr is a method call of 394 // <pkgPath>.<typeName>.<method>. 395 func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool { 396 call, ok := expr.(*ast.CallExpr) 397 if !ok { 398 return false 399 } 400 401 // Check that we are calling a method <method> 402 f := typeutil.StaticCallee(info, call) 403 if f == nil || f.Name() != method { 404 return false 405 } 406 recv := f.Type().(*types.Signature).Recv() 407 if recv == nil { 408 return false 409 } 410 411 // Check that the receiver is a <pkgPath>.<typeName> or 412 // *<pkgPath>.<typeName>. 413 rtype := recv.Type() 414 if ptr, ok := recv.Type().(*types.Pointer); ok { 415 rtype = ptr.Elem() 416 } 417 named, ok := rtype.(*types.Named) 418 if !ok { 419 return false 420 } 421 if named.Obj().Name() != typeName { 422 return false 423 } 424 pkg := f.Pkg() 425 if pkg == nil { 426 return false 427 } 428 if pkg.Path() != pkgPath { 429 return false 430 } 431 432 return true 433 }