github.com/karrick/go@v0.0.0-20170817181416-d5b0ec858b37/src/cmd/vet/httpresponse.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // This file contains the check for http.Response values being used before
     6  // checking for errors.
     7  
     8  package main
     9  
    10  import (
    11  	"go/ast"
    12  	"go/types"
    13  )
    14  
    15  func init() {
    16  	register("httpresponse",
    17  		"check errors are checked before using an http Response",
    18  		checkHTTPResponse, callExpr)
    19  }
    20  
    21  func checkHTTPResponse(f *File, node ast.Node) {
    22  	// If http.Response or http.Client are not defined, skip this check.
    23  	if httpResponseType == nil || httpClientType == nil {
    24  		return
    25  	}
    26  	call := node.(*ast.CallExpr)
    27  	if !isHTTPFuncOrMethodOnClient(f, call) {
    28  		return // the function call is not related to this check.
    29  	}
    30  
    31  	finder := &blockStmtFinder{node: call}
    32  	ast.Walk(finder, f.file)
    33  	stmts := finder.stmts()
    34  	if len(stmts) < 2 {
    35  		return // the call to the http function is the last statement of the block.
    36  	}
    37  
    38  	asg, ok := stmts[0].(*ast.AssignStmt)
    39  	if !ok {
    40  		return // the first statement is not assignment.
    41  	}
    42  	resp := rootIdent(asg.Lhs[0])
    43  	if resp == nil {
    44  		return // could not find the http.Response in the assignment.
    45  	}
    46  
    47  	def, ok := stmts[1].(*ast.DeferStmt)
    48  	if !ok {
    49  		return // the following statement is not a defer.
    50  	}
    51  	root := rootIdent(def.Call.Fun)
    52  	if root == nil {
    53  		return // could not find the receiver of the defer call.
    54  	}
    55  
    56  	if resp.Obj == root.Obj {
    57  		f.Badf(root.Pos(), "using %s before checking for errors", resp.Name)
    58  	}
    59  }
    60  
    61  // isHTTPFuncOrMethodOnClient checks whether the given call expression is on
    62  // either a function of the net/http package or a method of http.Client that
    63  // returns (*http.Response, error).
    64  func isHTTPFuncOrMethodOnClient(f *File, expr *ast.CallExpr) bool {
    65  	fun, _ := expr.Fun.(*ast.SelectorExpr)
    66  	sig, _ := f.pkg.types[fun].Type.(*types.Signature)
    67  	if sig == nil {
    68  		return false // the call is not on of the form x.f()
    69  	}
    70  
    71  	res := sig.Results()
    72  	if res.Len() != 2 {
    73  		return false // the function called does not return two values.
    74  	}
    75  	if ptr, ok := res.At(0).Type().(*types.Pointer); !ok || !types.Identical(ptr.Elem(), httpResponseType) {
    76  		return false // the first return type is not *http.Response.
    77  	}
    78  	if !types.Identical(res.At(1).Type().Underlying(), errorType) {
    79  		return false // the second return type is not error
    80  	}
    81  
    82  	typ := f.pkg.types[fun.X].Type
    83  	if typ == nil {
    84  		id, ok := fun.X.(*ast.Ident)
    85  		return ok && id.Name == "http" // function in net/http package.
    86  	}
    87  
    88  	if types.Identical(typ, httpClientType) {
    89  		return true // method on http.Client.
    90  	}
    91  	ptr, ok := typ.(*types.Pointer)
    92  	return ok && types.Identical(ptr.Elem(), httpClientType) // method on *http.Client.
    93  }
    94  
    95  // blockStmtFinder is an ast.Visitor that given any ast node can find the
    96  // statement containing it and its succeeding statements in the same block.
    97  type blockStmtFinder struct {
    98  	node  ast.Node       // target of search
    99  	stmt  ast.Stmt       // innermost statement enclosing argument to Visit
   100  	block *ast.BlockStmt // innermost block enclosing argument to Visit.
   101  }
   102  
   103  // Visit finds f.node performing a search down the ast tree.
   104  // It keeps the last block statement and statement seen for later use.
   105  func (f *blockStmtFinder) Visit(node ast.Node) ast.Visitor {
   106  	if node == nil || f.node.Pos() < node.Pos() || f.node.End() > node.End() {
   107  		return nil // not here
   108  	}
   109  	switch n := node.(type) {
   110  	case *ast.BlockStmt:
   111  		f.block = n
   112  	case ast.Stmt:
   113  		f.stmt = n
   114  	}
   115  	if f.node.Pos() == node.Pos() && f.node.End() == node.End() {
   116  		return nil // found
   117  	}
   118  	return f // keep looking
   119  }
   120  
   121  // stmts returns the statements of f.block starting from the one including f.node.
   122  func (f *blockStmtFinder) stmts() []ast.Stmt {
   123  	for i, v := range f.block.List {
   124  		if f.stmt == v {
   125  			return f.block.List[i:]
   126  		}
   127  	}
   128  	return nil
   129  }
   130  
   131  // rootIdent finds the root identifier x in a chain of selections x.y.z, or nil if not found.
   132  func rootIdent(n ast.Node) *ast.Ident {
   133  	switch n := n.(type) {
   134  	case *ast.SelectorExpr:
   135  		return rootIdent(n.X)
   136  	case *ast.Ident:
   137  		return n
   138  	default:
   139  		return nil
   140  	}
   141  }