golang.org/x/tools@v0.21.0/go/analysis/passes/loopclosure/loopclosure.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package loopclosure 6 7 import ( 8 _ "embed" 9 "go/ast" 10 "go/types" 11 12 "golang.org/x/tools/go/analysis" 13 "golang.org/x/tools/go/analysis/passes/inspect" 14 "golang.org/x/tools/go/analysis/passes/internal/analysisutil" 15 "golang.org/x/tools/go/ast/inspector" 16 "golang.org/x/tools/go/types/typeutil" 17 "golang.org/x/tools/internal/typesinternal" 18 "golang.org/x/tools/internal/versions" 19 ) 20 21 //go:embed doc.go 22 var doc string 23 24 var Analyzer = &analysis.Analyzer{ 25 Name: "loopclosure", 26 Doc: analysisutil.MustExtractDoc(doc, "loopclosure"), 27 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/loopclosure", 28 Requires: []*analysis.Analyzer{inspect.Analyzer}, 29 Run: run, 30 } 31 32 func run(pass *analysis.Pass) (interface{}, error) { 33 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) 34 35 nodeFilter := []ast.Node{ 36 (*ast.File)(nil), 37 (*ast.RangeStmt)(nil), 38 (*ast.ForStmt)(nil), 39 } 40 inspect.Nodes(nodeFilter, func(n ast.Node, push bool) bool { 41 if !push { 42 // inspect.Nodes is slightly suboptimal as we only use push=true. 43 return true 44 } 45 // Find the variables updated by the loop statement. 46 var vars []types.Object 47 addVar := func(expr ast.Expr) { 48 if id, _ := expr.(*ast.Ident); id != nil { 49 if obj := pass.TypesInfo.ObjectOf(id); obj != nil { 50 vars = append(vars, obj) 51 } 52 } 53 } 54 var body *ast.BlockStmt 55 switch n := n.(type) { 56 case *ast.File: 57 // Only traverse the file if its goversion is strictly before go1.22. 
58 goversion := versions.FileVersion(pass.TypesInfo, n) 59 return versions.Before(goversion, versions.Go1_22) 60 case *ast.RangeStmt: 61 body = n.Body 62 addVar(n.Key) 63 addVar(n.Value) 64 case *ast.ForStmt: 65 body = n.Body 66 switch post := n.Post.(type) { 67 case *ast.AssignStmt: 68 // e.g. for p = head; p != nil; p = p.next 69 for _, lhs := range post.Lhs { 70 addVar(lhs) 71 } 72 case *ast.IncDecStmt: 73 // e.g. for i := 0; i < n; i++ 74 addVar(post.X) 75 } 76 } 77 if vars == nil { 78 return true 79 } 80 81 // Inspect statements to find function literals that may be run outside of 82 // the current loop iteration. 83 // 84 // For go, defer, and errgroup.Group.Go, we ignore all but the last 85 // statement, because it's hard to prove go isn't followed by wait, or 86 // defer by return. "Last" is defined recursively. 87 // 88 // TODO: consider allowing the "last" go/defer/Go statement to be followed by 89 // N "trivial" statements, possibly under a recursive definition of "trivial" 90 // so that that checker could, for example, conclude that a go statement is 91 // followed by an if statement made of only trivial statements and trivial expressions, 92 // and hence the go statement could still be checked. 93 forEachLastStmt(body.List, func(last ast.Stmt) { 94 var stmts []ast.Stmt 95 switch s := last.(type) { 96 case *ast.GoStmt: 97 stmts = litStmts(s.Call.Fun) 98 case *ast.DeferStmt: 99 stmts = litStmts(s.Call.Fun) 100 case *ast.ExprStmt: // check for errgroup.Group.Go 101 if call, ok := s.X.(*ast.CallExpr); ok { 102 stmts = litStmts(goInvoke(pass.TypesInfo, call)) 103 } 104 } 105 for _, stmt := range stmts { 106 reportCaptured(pass, vars, stmt) 107 } 108 }) 109 110 // Also check for testing.T.Run (with T.Parallel). 111 // We consider every t.Run statement in the loop body, because there is 112 // no commonly used mechanism for synchronizing parallel subtests. 
113 // It is of course theoretically possible to synchronize parallel subtests, 114 // though such a pattern is likely to be exceedingly rare as it would be 115 // fighting against the test runner. 116 for _, s := range body.List { 117 switch s := s.(type) { 118 case *ast.ExprStmt: 119 if call, ok := s.X.(*ast.CallExpr); ok { 120 for _, stmt := range parallelSubtest(pass.TypesInfo, call) { 121 reportCaptured(pass, vars, stmt) 122 } 123 124 } 125 } 126 } 127 return true 128 }) 129 return nil, nil 130 } 131 132 // reportCaptured reports a diagnostic stating a loop variable 133 // has been captured by a func literal if checkStmt has escaping 134 // references to vars. vars is expected to be variables updated by a loop statement, 135 // and checkStmt is expected to be a statements from the body of a func literal in the loop. 136 func reportCaptured(pass *analysis.Pass, vars []types.Object, checkStmt ast.Stmt) { 137 ast.Inspect(checkStmt, func(n ast.Node) bool { 138 id, ok := n.(*ast.Ident) 139 if !ok { 140 return true 141 } 142 obj := pass.TypesInfo.Uses[id] 143 if obj == nil { 144 return true 145 } 146 for _, v := range vars { 147 if v == obj { 148 pass.ReportRangef(id, "loop variable %s captured by func literal", id.Name) 149 } 150 } 151 return true 152 }) 153 } 154 155 // forEachLastStmt calls onLast on each "last" statement in a list of statements. 156 // "Last" is defined recursively so, for example, if the last statement is 157 // a switch statement, then each switch case is also visited to examine 158 // its last statements. 
func forEachLastStmt(stmts []ast.Stmt, onLast func(last ast.Stmt)) {
	if len(stmts) == 0 {
		return
	}

	s := stmts[len(stmts)-1]
	// A label does not change which statement executes last, so look
	// through (possibly nested) labels. Otherwise a labeled loop or switch
	// in last position would be handed to onLast as-is instead of being
	// recursed into like its unlabeled form.
	for {
		labeled, ok := s.(*ast.LabeledStmt)
		if !ok {
			break
		}
		s = labeled.Stmt
	}

	switch s := s.(type) {
	case *ast.BlockStmt:
		// A bare block in last position: its own last statement runs last.
		forEachLastStmt(s.List, onLast)
	case *ast.IfStmt:
		// Visit the then-branch of each if in an if/else-if chain, plus the
		// final else block, if any.
	loop:
		for {
			forEachLastStmt(s.Body.List, onLast)
			switch e := s.Else.(type) {
			case *ast.BlockStmt:
				forEachLastStmt(e.List, onLast)
				break loop
			case *ast.IfStmt:
				s = e
			default:
				// No else branch (nil). Stopping on any other value also
				// rules out spinning forever on an unexpected AST.
				break loop
			}
		}
	case *ast.ForStmt:
		forEachLastStmt(s.Body.List, onLast)
	case *ast.RangeStmt:
		forEachLastStmt(s.Body.List, onLast)
	case *ast.SwitchStmt:
		for _, c := range s.Body.List {
			cc := c.(*ast.CaseClause)
			forEachLastStmt(cc.Body, onLast)
		}
	case *ast.TypeSwitchStmt:
		for _, c := range s.Body.List {
			cc := c.(*ast.CaseClause)
			forEachLastStmt(cc.Body, onLast)
		}
	case *ast.SelectStmt:
		for _, c := range s.Body.List {
			cc := c.(*ast.CommClause)
			forEachLastStmt(cc.Body, onLast)
		}
	default:
		onLast(s)
	}
}

// litStmts returns all statements from the function body of a function
// literal.
//
// If fun is not a function literal (including when fun is nil), it returns
// nil.
func litStmts(fun ast.Expr) []ast.Stmt {
	lit, _ := fun.(*ast.FuncLit)
	if lit == nil {
		return nil
	}
	return lit.Body.List
}

// goInvoke returns a function expression that would be called asynchronously
// (but not awaited) in another goroutine as a consequence of the call.
// For example, given the g.Go call below, it returns the function literal
// expression.
//
//	import "golang.org/x/sync/errgroup"
//	var g errgroup.Group
//	g.Go(func() error { ... })
//
// Currently only calls of the method Go on golang.org/x/sync/errgroup.Group
// are considered; for any other call goInvoke returns nil.
225 func goInvoke(info *types.Info, call *ast.CallExpr) ast.Expr { 226 if !isMethodCall(info, call, "golang.org/x/sync/errgroup", "Group", "Go") { 227 return nil 228 } 229 return call.Args[0] 230 } 231 232 // parallelSubtest returns statements that can be easily proven to execute 233 // concurrently via the go test runner, as t.Run has been invoked with a 234 // function literal that calls t.Parallel. 235 // 236 // In practice, users rely on the fact that statements before the call to 237 // t.Parallel are synchronous. For example by declaring test := test inside the 238 // function literal, but before the call to t.Parallel. 239 // 240 // Therefore, we only flag references in statements that are obviously 241 // dominated by a call to t.Parallel. As a simple heuristic, we only consider 242 // statements following the final labeled statement in the function body, to 243 // avoid scenarios where a jump would cause either the call to t.Parallel or 244 // the problematic reference to be skipped. 245 // 246 // import "testing" 247 // 248 // func TestFoo(t *testing.T) { 249 // tests := []int{0, 1, 2} 250 // for i, test := range tests { 251 // t.Run("subtest", func(t *testing.T) { 252 // println(i, test) // OK 253 // t.Parallel() 254 // println(i, test) // Not OK 255 // }) 256 // } 257 // } 258 func parallelSubtest(info *types.Info, call *ast.CallExpr) []ast.Stmt { 259 if !isMethodCall(info, call, "testing", "T", "Run") { 260 return nil 261 } 262 263 if len(call.Args) != 2 { 264 // Ignore calls such as t.Run(fn()). 265 return nil 266 } 267 268 lit, _ := call.Args[1].(*ast.FuncLit) 269 if lit == nil { 270 return nil 271 } 272 273 // Capture the *testing.T object for the first argument to the function 274 // literal. 
275 if len(lit.Type.Params.List[0].Names) == 0 { 276 return nil 277 } 278 279 tObj := info.Defs[lit.Type.Params.List[0].Names[0]] 280 if tObj == nil { 281 return nil 282 } 283 284 // Match statements that occur after a call to t.Parallel following the final 285 // labeled statement in the function body. 286 // 287 // We iterate over lit.Body.List to have a simple, fast and "frequent enough" 288 // dominance relationship for t.Parallel(): lit.Body.List[i] dominates 289 // lit.Body.List[j] for i < j unless there is a jump. 290 var stmts []ast.Stmt 291 afterParallel := false 292 for _, stmt := range lit.Body.List { 293 stmt, labeled := unlabel(stmt) 294 if labeled { 295 // Reset: naively we don't know if a jump could have caused the 296 // previously considered statements to be skipped. 297 stmts = nil 298 afterParallel = false 299 } 300 301 if afterParallel { 302 stmts = append(stmts, stmt) 303 continue 304 } 305 306 // Check if stmt is a call to t.Parallel(), for the correct t. 307 exprStmt, ok := stmt.(*ast.ExprStmt) 308 if !ok { 309 continue 310 } 311 expr := exprStmt.X 312 if isMethodCall(info, expr, "testing", "T", "Parallel") { 313 call, _ := expr.(*ast.CallExpr) 314 if call == nil { 315 continue 316 } 317 x, _ := call.Fun.(*ast.SelectorExpr) 318 if x == nil { 319 continue 320 } 321 id, _ := x.X.(*ast.Ident) 322 if id == nil { 323 continue 324 } 325 if info.Uses[id] == tObj { 326 afterParallel = true 327 } 328 } 329 } 330 331 return stmts 332 } 333 334 // unlabel returns the inner statement for the possibly labeled statement stmt, 335 // stripping any (possibly nested) *ast.LabeledStmt wrapper. 336 // 337 // The second result reports whether stmt was an *ast.LabeledStmt. 
338 func unlabel(stmt ast.Stmt) (ast.Stmt, bool) { 339 labeled := false 340 for { 341 labelStmt, ok := stmt.(*ast.LabeledStmt) 342 if !ok { 343 return stmt, labeled 344 } 345 labeled = true 346 stmt = labelStmt.Stmt 347 } 348 } 349 350 // isMethodCall reports whether expr is a method call of 351 // <pkgPath>.<typeName>.<method>. 352 func isMethodCall(info *types.Info, expr ast.Expr, pkgPath, typeName, method string) bool { 353 call, ok := expr.(*ast.CallExpr) 354 if !ok { 355 return false 356 } 357 358 // Check that we are calling a method <method> 359 f := typeutil.StaticCallee(info, call) 360 if f == nil || f.Name() != method { 361 return false 362 } 363 recv := f.Type().(*types.Signature).Recv() 364 if recv == nil { 365 return false 366 } 367 368 // Check that the receiver is a <pkgPath>.<typeName> or 369 // *<pkgPath>.<typeName>. 370 _, named := typesinternal.ReceiverNamed(recv) 371 return analysisutil.IsNamedType(named, pkgPath, typeName) 372 }