github.com/prysmaticlabs/prysm@v1.4.4/tools/analyzers/errcheck/analyzer.go (about)

// Package errcheck implements a static analysis analyzer to ensure that errors are handled in Go
// code. This analyzer was adapted from https://github.com/kisielk/errcheck (MIT License).
     3  package errcheck
     4  
     5  import (
     6  	"errors"
     7  	"fmt"
     8  	"go/ast"
     9  	"go/token"
    10  	"go/types"
    11  
    12  	"golang.org/x/tools/go/analysis"
    13  	"golang.org/x/tools/go/analysis/passes/inspect"
    14  	"golang.org/x/tools/go/ast/inspector"
    15  )
    16  
    17  // Doc explaining the tool.
    18  const Doc = "This tool enforces all errors must be handled and that type assertions test that " +
    19  	"the type implements the given interface to prevent runtime panics."
    20  
    21  // Analyzer runs static analysis.
    22  var Analyzer = &analysis.Analyzer{
    23  	Name:     "errcheck",
    24  	Doc:      Doc,
    25  	Requires: []*analysis.Analyzer{inspect.Analyzer},
    26  	Run:      run,
    27  }
    28  
    29  var exclusions = make(map[string]bool)
    30  
    31  func init() {
    32  	for _, exc := range [...]string{
    33  		// bytes
    34  		"(*bytes.Buffer).Write",
    35  		"(*bytes.Buffer).WriteByte",
    36  		"(*bytes.Buffer).WriteRune",
    37  		"(*bytes.Buffer).WriteString",
    38  
    39  		// fmt
    40  		"fmt.Errorf",
    41  		"fmt.Print",
    42  		"fmt.Printf",
    43  		"fmt.Println",
    44  		"fmt.Fprint(*bytes.Buffer)",
    45  		"fmt.Fprintf(*bytes.Buffer)",
    46  		"fmt.Fprintln(*bytes.Buffer)",
    47  		"fmt.Fprint(*strings.Builder)",
    48  		"fmt.Fprintf(*strings.Builder)",
    49  		"fmt.Fprintln(*strings.Builder)",
    50  		"fmt.Fprint(os.Stderr)",
    51  		"fmt.Fprintf(os.Stderr)",
    52  		"fmt.Fprintln(os.Stderr)",
    53  
    54  		// math/rand
    55  		"math/rand.Read",
    56  		"(*math/rand.Rand).Read",
    57  
    58  		// hash
    59  		"(hash.Hash).Write",
    60  	} {
    61  		exclusions[exc] = true
    62  	}
    63  }
    64  
    65  func run(pass *analysis.Pass) (interface{}, error) {
    66  	inspection, ok := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    67  	if !ok {
    68  		return nil, errors.New("analyzer is not type *inspector.Inspector")
    69  	}
    70  
    71  	nodeFilter := []ast.Node{
    72  		(*ast.CallExpr)(nil),
    73  		(*ast.ExprStmt)(nil),
    74  		(*ast.GoStmt)(nil),
    75  		(*ast.DeferStmt)(nil),
    76  		(*ast.AssignStmt)(nil),
    77  	}
    78  
    79  	inspection.Preorder(nodeFilter, func(node ast.Node) {
    80  		switch stmt := node.(type) {
    81  		case *ast.ExprStmt:
    82  			if call, ok := stmt.X.(*ast.CallExpr); ok {
    83  				if !ignoreCall(pass, call) && callReturnsError(pass, call) {
    84  					reportUnhandledError(pass, call.Lparen, call)
    85  				}
    86  			}
    87  		case *ast.GoStmt:
    88  			if !ignoreCall(pass, stmt.Call) && callReturnsError(pass, stmt.Call) {
    89  				reportUnhandledError(pass, stmt.Call.Lparen, stmt.Call)
    90  			}
    91  		case *ast.DeferStmt:
    92  			if !ignoreCall(pass, stmt.Call) && callReturnsError(pass, stmt.Call) {
    93  				reportUnhandledError(pass, stmt.Call.Lparen, stmt.Call)
    94  			}
    95  		case *ast.AssignStmt:
    96  			if len(stmt.Rhs) == 1 {
    97  				// single value on rhs; check against lhs identifiers
    98  				if call, ok := stmt.Rhs[0].(*ast.CallExpr); ok {
    99  					if ignoreCall(pass, call) {
   100  						break
   101  					}
   102  					isError := errorsByArg(pass, call)
   103  					for i := 0; i < len(stmt.Lhs); i++ {
   104  						if id, ok := stmt.Lhs[i].(*ast.Ident); ok {
   105  							// We shortcut calls to recover() because errorsByArg can't
   106  							// check its return types for errors since it returns interface{}.
   107  							if id.Name == "_" && (isRecover(pass, call) || isError[i]) {
   108  								reportUnhandledError(pass, id.NamePos, call)
   109  							}
   110  						}
   111  					}
   112  				} else if assert, ok := stmt.Rhs[0].(*ast.TypeAssertExpr); ok {
   113  					if assert.Type == nil {
   114  						// type switch
   115  						break
   116  					}
   117  					if len(stmt.Lhs) < 2 {
   118  						// assertion result not read
   119  						reportUnhandledTypeAssertion(pass, stmt.Rhs[0].Pos())
   120  					} else if id, ok := stmt.Lhs[1].(*ast.Ident); ok && id.Name == "_" {
   121  						// assertion result ignored
   122  						reportUnhandledTypeAssertion(pass, id.NamePos)
   123  					}
   124  				}
   125  			} else {
   126  				// multiple value on rhs; in this case a call can't return
   127  				// multiple values. Assume len(stmt.Lhs) == len(stmt.Rhs)
   128  				for i := 0; i < len(stmt.Lhs); i++ {
   129  					if id, ok := stmt.Lhs[i].(*ast.Ident); ok {
   130  						if call, ok := stmt.Rhs[i].(*ast.CallExpr); ok {
   131  							if ignoreCall(pass, call) {
   132  								continue
   133  							}
   134  							if id.Name == "_" && callReturnsError(pass, call) {
   135  								reportUnhandledError(pass, id.NamePos, call)
   136  							}
   137  						} else if assert, ok := stmt.Rhs[i].(*ast.TypeAssertExpr); ok {
   138  							if assert.Type == nil {
   139  								// Shouldn't happen anyway, no multi assignment in type switches
   140  								continue
   141  							}
   142  							reportUnhandledError(pass, id.NamePos, nil)
   143  						}
   144  					}
   145  				}
   146  			}
   147  		default:
   148  		}
   149  	})
   150  
   151  	return nil, nil
   152  }
   153  
   154  func reportUnhandledError(pass *analysis.Pass, pos token.Pos, call *ast.CallExpr) {
   155  	pass.Reportf(pos, "Unhandled error for function call %s", fullName(pass, call))
   156  }
   157  
   158  func reportUnhandledTypeAssertion(pass *analysis.Pass, pos token.Pos) {
   159  	pass.Reportf(pos, "Unhandled type assertion check. You must test whether or not an "+
   160  		"interface implements the asserted type.")
   161  }
   162  
   163  func fullName(pass *analysis.Pass, call *ast.CallExpr) string {
   164  	_, fn, ok := selectorAndFunc(pass, call)
   165  	if !ok {
   166  		return ""
   167  	}
   168  	return fn.FullName()
   169  }
   170  
   171  // selectorAndFunc tries to get the selector and function from call expression.
   172  // For example, given the call expression representing "a.b()", the selector
   173  // is "a.b" and the function is "b" itself.
   174  //
   175  // The final return value will be true if it is able to do extract a selector
   176  // from the call and look up the function object it refers to.
   177  //
   178  // If the call does not include a selector (like if it is a plain "f()" function call)
   179  // then the final return value will be false.
   180  func selectorAndFunc(pass *analysis.Pass, call *ast.CallExpr) (*ast.SelectorExpr, *types.Func, bool) {
   181  	if call == nil || call.Fun == nil {
   182  		return nil, nil, false
   183  	}
   184  	sel, ok := call.Fun.(*ast.SelectorExpr)
   185  	if !ok {
   186  		return nil, nil, false
   187  	}
   188  
   189  	fn, ok := pass.TypesInfo.ObjectOf(sel.Sel).(*types.Func)
   190  	if !ok {
   191  		return nil, nil, false
   192  	}
   193  
   194  	return sel, fn, true
   195  
   196  }
   197  
   198  func ignoreCall(pass *analysis.Pass, call *ast.CallExpr) bool {
   199  	for _, name := range namesForExcludeCheck(pass, call) {
   200  		if exclusions[name] {
   201  			return true
   202  		}
   203  	}
   204  	return false
   205  }
   206  
   207  var errorType = types.Universe.Lookup("error").Type().Underlying().(*types.Interface)
   208  
   209  func isErrorType(t types.Type) bool {
   210  	return types.Implements(t, errorType)
   211  }
   212  
   213  func callReturnsError(pass *analysis.Pass, call *ast.CallExpr) bool {
   214  	if isRecover(pass, call) {
   215  		return true
   216  	}
   217  
   218  	for _, isError := range errorsByArg(pass, call) {
   219  		if isError {
   220  			return true
   221  		}
   222  	}
   223  
   224  	return false
   225  }
   226  
   227  // errorsByArg returns a slice s such that
   228  // len(s) == number of return types of call
   229  // s[i] == true iff return type at position i from left is an error type
   230  func errorsByArg(pass *analysis.Pass, call *ast.CallExpr) []bool {
   231  	switch t := pass.TypesInfo.Types[call].Type.(type) {
   232  	case *types.Named:
   233  		// Single return
   234  		return []bool{isErrorType(t)}
   235  	case *types.Pointer:
   236  		// Single return via pointer
   237  		return []bool{isErrorType(t)}
   238  	case *types.Tuple:
   239  		// Multiple returns
   240  		s := make([]bool, t.Len())
   241  		for i := 0; i < t.Len(); i++ {
   242  			switch et := t.At(i).Type().(type) {
   243  			case *types.Named:
   244  				// Single return
   245  				s[i] = isErrorType(et)
   246  			case *types.Pointer:
   247  				// Single return via pointer
   248  				s[i] = isErrorType(et)
   249  			default:
   250  				s[i] = false
   251  			}
   252  		}
   253  		return s
   254  	}
   255  	return []bool{false}
   256  }
   257  
   258  func isRecover(pass *analysis.Pass, call *ast.CallExpr) bool {
   259  	if fun, ok := call.Fun.(*ast.Ident); ok {
   260  		if _, ok := pass.TypesInfo.Uses[fun].(*types.Builtin); ok {
   261  			return fun.Name == "recover"
   262  		}
   263  	}
   264  	return false
   265  }
   266  
   267  func namesForExcludeCheck(pass *analysis.Pass, call *ast.CallExpr) []string {
   268  	sel, fn, ok := selectorAndFunc(pass, call)
   269  	if !ok {
   270  		return nil
   271  	}
   272  
   273  	name := fullName(pass, call)
   274  	if name == "" {
   275  		return nil
   276  	}
   277  
   278  	// This will be missing for functions without a receiver (like fmt.Printf),
   279  	// so just fall back to the the function's fullName in that case.
   280  	selection, ok := pass.TypesInfo.Selections[sel]
   281  	if !ok {
   282  		return []string{name}
   283  	}
   284  
   285  	// This will return with ok false if the function isn't defined
   286  	// on an interface, so just fall back to the fullName.
   287  	ts, ok := walkThroughEmbeddedInterfaces(selection)
   288  	if !ok {
   289  		return []string{name}
   290  	}
   291  
   292  	result := make([]string, len(ts))
   293  	for i, t := range ts {
   294  		// Like in fullName, vendored packages will have /vendor/ in their name,
   295  		// thus not matching vendored standard library packages. If we
   296  		// want to support vendored stdlib packages, we need to implement
   297  		// additional logic here.
   298  		result[i] = fmt.Sprintf("(%s).%s", t.String(), fn.Name())
   299  	}
   300  	return result
   301  }
   302  
   303  // walkThroughEmbeddedInterfaces returns a slice of Interfaces that
   304  // we need to walk through in order to reach the actual definition,
   305  // in an Interface, of the method selected by the given selection.
   306  //
   307  // false will be returned in the second return value if:
   308  //   - the right side of the selection is not a function
   309  //   - the actual definition of the function is not in an Interface
   310  //
   311  // The returned slice will contain all the interface types that need
   312  // to be walked through to reach the actual definition.
   313  //
   314  // For example, say we have:
   315  //
   316  //    type Inner interface {Method()}
   317  //    type Middle interface {Inner}
   318  //    type Outer interface {Middle}
   319  //    type T struct {Outer}
   320  //    type U struct {T}
   321  //    type V struct {U}
   322  //
   323  // And then the selector:
   324  //
   325  //    V.Method
   326  //
   327  // We'll return [Outer, Middle, Inner] by first walking through the embedded structs
   328  // until we reach the Outer interface, then descending through the embedded interfaces
   329  // until we find the one that actually explicitly defines Method.
   330  func walkThroughEmbeddedInterfaces(sel *types.Selection) ([]types.Type, bool) {
   331  	fn, ok := sel.Obj().(*types.Func)
   332  	if !ok {
   333  		return nil, false
   334  	}
   335  
   336  	// Start off at the receiver.
   337  	currentT := sel.Recv()
   338  
   339  	// First, we can walk through any Struct fields provided
   340  	// by the selection Index() method. We ignore the last
   341  	// index because it would give the method itself.
   342  	indexes := sel.Index()
   343  	for _, fieldIndex := range indexes[:len(indexes)-1] {
   344  		currentT = typeAtFieldIndex(currentT, fieldIndex)
   345  	}
   346  
   347  	// Now currentT is either a type implementing the actual function,
   348  	// an Invalid type (if the receiver is a package), or an interface.
   349  	//
   350  	// If it's not an Interface, then we're done, as this function
   351  	// only cares about Interface-defined functions.
   352  	//
   353  	// If it is an Interface, we potentially need to continue digging until
   354  	// we find the Interface that actually explicitly defines the function.
   355  	interfaceT, ok := maybeUnname(currentT).(*types.Interface)
   356  	if !ok {
   357  		return nil, false
   358  	}
   359  
   360  	// The first interface we pass through is this one we've found. We return the possibly
   361  	// wrapping types.Named because it is more useful to work with for callers.
   362  	result := []types.Type{currentT}
   363  
   364  	// If this interface itself explicitly defines the given method
   365  	// then we're done digging.
   366  	for !explicitlyDefinesMethod(interfaceT, fn) {
   367  		// Otherwise, we find which of the embedded interfaces _does_
   368  		// define the method, add it to our list, and loop.
   369  		namedInterfaceT, ok := embeddedInterfaceDefiningMethod(interfaceT, fn)
   370  		if !ok {
   371  			// This should be impossible as long as we type-checked: either the
   372  			// interface or one of its embedded ones must implement the method...
   373  			panic(fmt.Sprintf("either %v or one of its embedded interfaces must implement %v", currentT, fn))
   374  		}
   375  		result = append(result, namedInterfaceT)
   376  		interfaceT, ok = namedInterfaceT.Underlying().(*types.Interface)
   377  		if !ok {
   378  			panic(fmt.Sprintf("either %v or one of its embedded interfaces must implement %v", currentT, fn))
   379  		}
   380  	}
   381  
   382  	return result, true
   383  }
   384  
   385  func typeAtFieldIndex(startingAt types.Type, fieldIndex int) types.Type {
   386  	t := maybeUnname(maybeDereference(startingAt))
   387  	s, ok := t.(*types.Struct)
   388  	if !ok {
   389  		panic(fmt.Sprintf("cannot get Field of a type that is not a struct, got a %T", t))
   390  	}
   391  
   392  	return s.Field(fieldIndex).Type()
   393  }
   394  
   395  // embeddedInterfaceDefiningMethod searches through any embedded interfaces of the
   396  // passed interface searching for one that defines the given function. If found, the
   397  // types.Named wrapping that interface will be returned along with true in the second value.
   398  //
   399  // If no such embedded interface is found, nil and false are returned.
   400  func embeddedInterfaceDefiningMethod(interfaceT *types.Interface, fn *types.Func) (*types.Named, bool) {
   401  	for i := 0; i < interfaceT.NumEmbeddeds(); i++ {
   402  		embedded, ok := interfaceT.EmbeddedType(i).(*types.Named)
   403  		if !ok {
   404  			return nil, false
   405  		}
   406  		if definesMethod(embedded.Underlying().(*types.Interface), fn) {
   407  			return embedded, true
   408  		}
   409  	}
   410  	return nil, false
   411  }
   412  
   413  func explicitlyDefinesMethod(interfaceT *types.Interface, fn *types.Func) bool {
   414  	for i := 0; i < interfaceT.NumExplicitMethods(); i++ {
   415  		if interfaceT.ExplicitMethod(i) == fn {
   416  			return true
   417  		}
   418  	}
   419  	return false
   420  }
   421  
   422  func definesMethod(interfaceT *types.Interface, fn *types.Func) bool {
   423  	for i := 0; i < interfaceT.NumMethods(); i++ {
   424  		if interfaceT.Method(i) == fn {
   425  			return true
   426  		}
   427  	}
   428  	return false
   429  }
   430  
   431  func maybeDereference(t types.Type) types.Type {
   432  	p, ok := t.(*types.Pointer)
   433  	if ok {
   434  		return p.Elem()
   435  	}
   436  	return t
   437  }
   438  
   439  func maybeUnname(t types.Type) types.Type {
   440  	n, ok := t.(*types.Named)
   441  	if ok {
   442  		return n.Underlying()
   443  	}
   444  	return t
   445  }