github.com/jhump/golang-x-tools@v0.0.0-20220218190644-4958d6d39439/go/analysis/passes/loopclosure/loopclosure.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package loopclosure defines an Analyzer that checks for references to 6 // enclosing loop variables from within nested functions. 7 package loopclosure 8 9 import ( 10 "go/ast" 11 "go/types" 12 13 "github.com/jhump/golang-x-tools/go/analysis" 14 "github.com/jhump/golang-x-tools/go/analysis/passes/inspect" 15 "github.com/jhump/golang-x-tools/go/ast/inspector" 16 "github.com/jhump/golang-x-tools/go/types/typeutil" 17 ) 18 19 const Doc = `check references to loop variables from within nested functions 20 21 This analyzer checks for references to loop variables from within a 22 function literal inside the loop body. It checks only instances where 23 the function literal is called in a defer or go statement that is the 24 last statement in the loop body, as otherwise we would need whole 25 program analysis. 26 27 For example: 28 29 for i, v := range s { 30 go func() { 31 println(i, v) // not what you might expect 32 }() 33 } 34 35 See: https://golang.org/doc/go_faq.html#closures_and_goroutines` 36 37 var Analyzer = &analysis.Analyzer{ 38 Name: "loopclosure", 39 Doc: Doc, 40 Requires: []*analysis.Analyzer{inspect.Analyzer}, 41 Run: run, 42 } 43 44 func run(pass *analysis.Pass) (interface{}, error) { 45 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) 46 47 nodeFilter := []ast.Node{ 48 (*ast.RangeStmt)(nil), 49 (*ast.ForStmt)(nil), 50 } 51 inspect.Preorder(nodeFilter, func(n ast.Node) { 52 // Find the variables updated by the loop statement. 53 var vars []*ast.Ident 54 addVar := func(expr ast.Expr) { 55 if id, ok := expr.(*ast.Ident); ok { 56 vars = append(vars, id) 57 } 58 } 59 var body *ast.BlockStmt 60 switch n := n.(type) { 61 case *ast.RangeStmt: 62 body = n.Body 63 addVar(n.Key) 64 addVar(n.Value) 65 case *ast.ForStmt: 66 body = n.Body 67 switch post := n.Post.(type) { 68 case *ast.AssignStmt: 69 // e.g. for p = head; p != nil; p = p.next 70 for _, lhs := range post.Lhs { 71 addVar(lhs) 72 } 73 case *ast.IncDecStmt: 74 // e.g. for i := 0; i < n; i++ 75 addVar(post.X) 76 } 77 } 78 if vars == nil { 79 return 80 } 81 82 // Inspect a go or defer statement 83 // if it's the last one in the loop body. 84 // (We give up if there are following statements, 85 // because it's hard to prove go isn't followed by wait, 86 // or defer by return.) 87 if len(body.List) == 0 { 88 return 89 } 90 // The function invoked in the last return statement. 91 var fun ast.Expr 92 switch s := body.List[len(body.List)-1].(type) { 93 case *ast.GoStmt: 94 fun = s.Call.Fun 95 case *ast.DeferStmt: 96 fun = s.Call.Fun 97 case *ast.ExprStmt: // check for errgroup.Group.Go() 98 if call, ok := s.X.(*ast.CallExpr); ok { 99 fun = goInvokes(pass.TypesInfo, call) 100 } 101 } 102 lit, ok := fun.(*ast.FuncLit) 103 if !ok { 104 return 105 } 106 ast.Inspect(lit.Body, func(n ast.Node) bool { 107 id, ok := n.(*ast.Ident) 108 if !ok || id.Obj == nil { 109 return true 110 } 111 if pass.TypesInfo.Types[id].Type == nil { 112 // Not referring to a variable (e.g. struct field name) 113 return true 114 } 115 for _, v := range vars { 116 if v.Obj == id.Obj { 117 pass.ReportRangef(id, "loop variable %s captured by func literal", 118 id.Name) 119 } 120 } 121 return true 122 }) 123 }) 124 return nil, nil 125 } 126 127 // goInvokes returns a function expression that would be called asynchronously 128 // (but not awaited) in another goroutine as a consequence of the call. 129 // For example, given the g.Go call below, it returns the function literal expression. 130 // 131 // import "sync/errgroup" 132 // var g errgroup.Group 133 // g.Go(func() error { ... }) 134 // 135 // Currently only "golang.org/x/sync/errgroup.Group()" is considered. 136 func goInvokes(info *types.Info, call *ast.CallExpr) ast.Expr { 137 f := typeutil.StaticCallee(info, call) 138 // Note: Currently only supports: golang.org/x/sync/errgroup.Go. 139 if f == nil || f.Name() != "Go" { 140 return nil 141 } 142 recv := f.Type().(*types.Signature).Recv() 143 if recv == nil { 144 return nil 145 } 146 rtype, ok := recv.Type().(*types.Pointer) 147 if !ok { 148 return nil 149 } 150 named, ok := rtype.Elem().(*types.Named) 151 if !ok { 152 return nil 153 } 154 if named.Obj().Name() != "Group" { 155 return nil 156 } 157 pkg := f.Pkg() 158 if pkg == nil { 159 return nil 160 } 161 if pkg.Path() != "golang.org/x/sync/errgroup" { 162 return nil 163 } 164 return call.Args[0] 165 }