github.com/ltltlt/go-source-code@v0.0.0-20190830023027-95be009773aa/cmd/vet/lostcancel.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"cmd/vet/internal/cfg"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/types"
    12  	"strconv"
    13  )
    14  
    15  func init() {
    16  	register("lostcancel",
    17  		"check for failure to call cancelation function returned by context.WithCancel",
    18  		checkLostCancel,
    19  		funcDecl, funcLit)
    20  }
    21  
    22  const debugLostCancel = false
    23  
    24  var contextPackage = "context"
    25  
    26  // checkLostCancel reports a failure to the call the cancel function
    27  // returned by context.WithCancel, either because the variable was
    28  // assigned to the blank identifier, or because there exists a
    29  // control-flow path from the call to a return statement and that path
    30  // does not "use" the cancel function.  Any reference to the variable
    31  // counts as a use, even within a nested function literal.
    32  //
    33  // checkLostCancel analyzes a single named or literal function.
    34  func checkLostCancel(f *File, node ast.Node) {
    35  	// Fast path: bypass check if file doesn't use context.WithCancel.
    36  	if !hasImport(f.file, contextPackage) {
    37  		return
    38  	}
    39  
    40  	// Maps each cancel variable to its defining ValueSpec/AssignStmt.
    41  	cancelvars := make(map[*types.Var]ast.Node)
    42  
    43  	// Find the set of cancel vars to analyze.
    44  	stack := make([]ast.Node, 0, 32)
    45  	ast.Inspect(node, func(n ast.Node) bool {
    46  		switch n.(type) {
    47  		case *ast.FuncLit:
    48  			if len(stack) > 0 {
    49  				return false // don't stray into nested functions
    50  			}
    51  		case nil:
    52  			stack = stack[:len(stack)-1] // pop
    53  			return true
    54  		}
    55  		stack = append(stack, n) // push
    56  
    57  		// Look for [{AssignStmt,ValueSpec} CallExpr SelectorExpr]:
    58  		//
    59  		//   ctx, cancel    := context.WithCancel(...)
    60  		//   ctx, cancel     = context.WithCancel(...)
    61  		//   var ctx, cancel = context.WithCancel(...)
    62  		//
    63  		if isContextWithCancel(f, n) && isCall(stack[len(stack)-2]) {
    64  			var id *ast.Ident // id of cancel var
    65  			stmt := stack[len(stack)-3]
    66  			switch stmt := stmt.(type) {
    67  			case *ast.ValueSpec:
    68  				if len(stmt.Names) > 1 {
    69  					id = stmt.Names[1]
    70  				}
    71  			case *ast.AssignStmt:
    72  				if len(stmt.Lhs) > 1 {
    73  					id, _ = stmt.Lhs[1].(*ast.Ident)
    74  				}
    75  			}
    76  			if id != nil {
    77  				if id.Name == "_" {
    78  					f.Badf(id.Pos(), "the cancel function returned by context.%s should be called, not discarded, to avoid a context leak",
    79  						n.(*ast.SelectorExpr).Sel.Name)
    80  				} else if v, ok := f.pkg.uses[id].(*types.Var); ok {
    81  					cancelvars[v] = stmt
    82  				} else if v, ok := f.pkg.defs[id].(*types.Var); ok {
    83  					cancelvars[v] = stmt
    84  				}
    85  			}
    86  		}
    87  
    88  		return true
    89  	})
    90  
    91  	if len(cancelvars) == 0 {
    92  		return // no need to build CFG
    93  	}
    94  
    95  	// Tell the CFG builder which functions never return.
    96  	info := &types.Info{Uses: f.pkg.uses, Selections: f.pkg.selectors}
    97  	mayReturn := func(call *ast.CallExpr) bool {
    98  		name := callName(info, call)
    99  		return !noReturnFuncs[name]
   100  	}
   101  
   102  	// Build the CFG.
   103  	var g *cfg.CFG
   104  	var sig *types.Signature
   105  	switch node := node.(type) {
   106  	case *ast.FuncDecl:
   107  		obj := f.pkg.defs[node.Name]
   108  		if obj == nil {
   109  			return // type error (e.g. duplicate function declaration)
   110  		}
   111  		sig, _ = obj.Type().(*types.Signature)
   112  		g = cfg.New(node.Body, mayReturn)
   113  	case *ast.FuncLit:
   114  		sig, _ = f.pkg.types[node.Type].Type.(*types.Signature)
   115  		g = cfg.New(node.Body, mayReturn)
   116  	}
   117  
   118  	// Print CFG.
   119  	if debugLostCancel {
   120  		fmt.Println(g.Format(f.fset))
   121  	}
   122  
   123  	// Examine the CFG for each variable in turn.
   124  	// (It would be more efficient to analyze all cancelvars in a
   125  	// single pass over the AST, but seldom is there more than one.)
   126  	for v, stmt := range cancelvars {
   127  		if ret := lostCancelPath(f, g, v, stmt, sig); ret != nil {
   128  			lineno := f.fset.Position(stmt.Pos()).Line
   129  			f.Badf(stmt.Pos(), "the %s function is not used on all paths (possible context leak)", v.Name())
   130  			f.Badf(ret.Pos(), "this return statement may be reached without using the %s var defined on line %d", v.Name(), lineno)
   131  		}
   132  	}
   133  }
   134  
   135  func isCall(n ast.Node) bool { _, ok := n.(*ast.CallExpr); return ok }
   136  
   137  func hasImport(f *ast.File, path string) bool {
   138  	for _, imp := range f.Imports {
   139  		v, _ := strconv.Unquote(imp.Path.Value)
   140  		if v == path {
   141  			return true
   142  		}
   143  	}
   144  	return false
   145  }
   146  
   147  // isContextWithCancel reports whether n is one of the qualified identifiers
   148  // context.With{Cancel,Timeout,Deadline}.
   149  func isContextWithCancel(f *File, n ast.Node) bool {
   150  	if sel, ok := n.(*ast.SelectorExpr); ok {
   151  		switch sel.Sel.Name {
   152  		case "WithCancel", "WithTimeout", "WithDeadline":
   153  			if x, ok := sel.X.(*ast.Ident); ok {
   154  				if pkgname, ok := f.pkg.uses[x].(*types.PkgName); ok {
   155  					return pkgname.Imported().Path() == contextPackage
   156  				}
   157  				// Import failed, so we can't check package path.
   158  				// Just check the local package name (heuristic).
   159  				return x.Name == "context"
   160  			}
   161  		}
   162  	}
   163  	return false
   164  }
   165  
   166  // lostCancelPath finds a path through the CFG, from stmt (which defines
   167  // the 'cancel' variable v) to a return statement, that doesn't "use" v.
   168  // If it finds one, it returns the return statement (which may be synthetic).
   169  // sig is the function's type, if known.
   170  func lostCancelPath(f *File, g *cfg.CFG, v *types.Var, stmt ast.Node, sig *types.Signature) *ast.ReturnStmt {
   171  	vIsNamedResult := sig != nil && tupleContains(sig.Results(), v)
   172  
   173  	// uses reports whether stmts contain a "use" of variable v.
   174  	uses := func(f *File, v *types.Var, stmts []ast.Node) bool {
   175  		found := false
   176  		for _, stmt := range stmts {
   177  			ast.Inspect(stmt, func(n ast.Node) bool {
   178  				switch n := n.(type) {
   179  				case *ast.Ident:
   180  					if f.pkg.uses[n] == v {
   181  						found = true
   182  					}
   183  				case *ast.ReturnStmt:
   184  					// A naked return statement counts as a use
   185  					// of the named result variables.
   186  					if n.Results == nil && vIsNamedResult {
   187  						found = true
   188  					}
   189  				}
   190  				return !found
   191  			})
   192  		}
   193  		return found
   194  	}
   195  
   196  	// blockUses computes "uses" for each block, caching the result.
   197  	memo := make(map[*cfg.Block]bool)
   198  	blockUses := func(f *File, v *types.Var, b *cfg.Block) bool {
   199  		res, ok := memo[b]
   200  		if !ok {
   201  			res = uses(f, v, b.Nodes)
   202  			memo[b] = res
   203  		}
   204  		return res
   205  	}
   206  
   207  	// Find the var's defining block in the CFG,
   208  	// plus the rest of the statements of that block.
   209  	var defblock *cfg.Block
   210  	var rest []ast.Node
   211  outer:
   212  	for _, b := range g.Blocks {
   213  		for i, n := range b.Nodes {
   214  			if n == stmt {
   215  				defblock = b
   216  				rest = b.Nodes[i+1:]
   217  				break outer
   218  			}
   219  		}
   220  	}
   221  	if defblock == nil {
   222  		panic("internal error: can't find defining block for cancel var")
   223  	}
   224  
   225  	// Is v "used" in the remainder of its defining block?
   226  	if uses(f, v, rest) {
   227  		return nil
   228  	}
   229  
   230  	// Does the defining block return without using v?
   231  	if ret := defblock.Return(); ret != nil {
   232  		return ret
   233  	}
   234  
   235  	// Search the CFG depth-first for a path, from defblock to a
   236  	// return block, in which v is never "used".
   237  	seen := make(map[*cfg.Block]bool)
   238  	var search func(blocks []*cfg.Block) *ast.ReturnStmt
   239  	search = func(blocks []*cfg.Block) *ast.ReturnStmt {
   240  		for _, b := range blocks {
   241  			if !seen[b] {
   242  				seen[b] = true
   243  
   244  				// Prune the search if the block uses v.
   245  				if blockUses(f, v, b) {
   246  					continue
   247  				}
   248  
   249  				// Found path to return statement?
   250  				if ret := b.Return(); ret != nil {
   251  					if debugLostCancel {
   252  						fmt.Printf("found path to return in block %s\n", b)
   253  					}
   254  					return ret // found
   255  				}
   256  
   257  				// Recur
   258  				if ret := search(b.Succs); ret != nil {
   259  					if debugLostCancel {
   260  						fmt.Printf(" from block %s\n", b)
   261  					}
   262  					return ret
   263  				}
   264  			}
   265  		}
   266  		return nil
   267  	}
   268  	return search(defblock.Succs)
   269  }
   270  
   271  func tupleContains(tuple *types.Tuple, v *types.Var) bool {
   272  	for i := 0; i < tuple.Len(); i++ {
   273  		if tuple.At(i) == v {
   274  			return true
   275  		}
   276  	}
   277  	return false
   278  }
   279  
   280  var noReturnFuncs = map[string]bool{
   281  	"(*testing.common).FailNow": true,
   282  	"(*testing.common).Fatal":   true,
   283  	"(*testing.common).Fatalf":  true,
   284  	"(*testing.common).Skip":    true,
   285  	"(*testing.common).SkipNow": true,
   286  	"(*testing.common).Skipf":   true,
   287  	"log.Fatal":                 true,
   288  	"log.Fatalf":                true,
   289  	"log.Fatalln":               true,
   290  	"os.Exit":                   true,
   291  	"panic":                     true,
   292  	"runtime.Goexit":            true,
   293  }
   294  
   295  // callName returns the canonical name of the builtin, method, or
   296  // function called by call, if known.
   297  func callName(info *types.Info, call *ast.CallExpr) string {
   298  	switch fun := call.Fun.(type) {
   299  	case *ast.Ident:
   300  		// builtin, e.g. "panic"
   301  		if obj, ok := info.Uses[fun].(*types.Builtin); ok {
   302  			return obj.Name()
   303  		}
   304  	case *ast.SelectorExpr:
   305  		if sel, ok := info.Selections[fun]; ok && sel.Kind() == types.MethodVal {
   306  			// method call, e.g. "(*testing.common).Fatal"
   307  			meth := sel.Obj()
   308  			return fmt.Sprintf("(%s).%s",
   309  				meth.Type().(*types.Signature).Recv().Type(),
   310  				meth.Name())
   311  		}
   312  		if obj, ok := info.Uses[fun.Sel]; ok {
   313  			// qualified identifier, e.g. "os.Exit"
   314  			return fmt.Sprintf("%s.%s",
   315  				obj.Pkg().Path(),
   316  				obj.Name())
   317  		}
   318  	}
   319  
   320  	// function with no name, or defined in missing imported package
   321  	return ""
   322  }