github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/testutils/lint/passes/errcmp/errcmp.go (about)

     1  // Copyright 2020 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  // Package errcmp defines an Analyzer which checks
    12  // for usage of errors.Is instead of direct ==/!= comparisons.
    13  package errcmp
    14  
    15  import (
    16  	"go/ast"
    17  	"go/token"
    18  	"go/types"
    19  	"strings"
    20  
    21  	"golang.org/x/tools/go/analysis"
    22  	"golang.org/x/tools/go/analysis/passes/inspect"
    23  	"golang.org/x/tools/go/ast/inspector"
    24  )
    25  
    26  // Doc documents this pass.
    27  const Doc = `check for comparison of error objects`
    28  
    29  var errorType = types.Universe.Lookup("error").Type()
    30  
    31  // Analyzer checks for usage of errors.Is instead of direct ==/!=
    32  // comparisons.
    33  var Analyzer = &analysis.Analyzer{
    34  	Name:     "errcmp",
    35  	Doc:      Doc,
    36  	Requires: []*analysis.Analyzer{inspect.Analyzer},
    37  	Run:      run,
    38  }
    39  
    40  func run(pass *analysis.Pass) (interface{}, error) {
    41  	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
    42  
    43  	// Our analyzer just wants to see comparisons and casts.
    44  	nodeFilter := []ast.Node{
    45  		(*ast.BinaryExpr)(nil),
    46  		(*ast.TypeAssertExpr)(nil),
    47  		(*ast.SwitchStmt)(nil),
    48  	}
    49  
    50  	// Now traverse the ASTs.
    51  	inspect.Preorder(nodeFilter, func(n ast.Node) {
    52  		// Catch-all for possible bugs in the linter code.
    53  		defer func() {
    54  			if r := recover(); r != nil {
    55  				if err, ok := r.(error); ok {
    56  					pass.Reportf(n.Pos(), "internal linter error: %v", err)
    57  					return
    58  				}
    59  				panic(r)
    60  			}
    61  		}()
    62  
    63  		if cmp, ok := n.(*ast.BinaryExpr); ok {
    64  			checkErrCmp(pass, cmp)
    65  			return
    66  		}
    67  		if cmp, ok := n.(*ast.TypeAssertExpr); ok {
    68  			checkErrCast(pass, cmp)
    69  			return
    70  		}
    71  		if cmp, ok := n.(*ast.SwitchStmt); ok {
    72  			checkErrSwitch(pass, cmp)
    73  			return
    74  		}
    75  	})
    76  
    77  	return nil, nil
    78  }
    79  
    80  func checkErrSwitch(pass *analysis.Pass, s *ast.SwitchStmt) {
    81  	if pass.TypesInfo.Types[s.Tag].Type == errorType {
    82  		pass.Reportf(s.Switch, escNl(`invalid direct comparison of error object
    83  Tip:
    84     switch err { case errRef:...
    85  -> switch { case errors.Is(err, errRef): ...
    86  `))
    87  	}
    88  }
    89  
    90  func checkErrCast(pass *analysis.Pass, texpr *ast.TypeAssertExpr) {
    91  	if pass.TypesInfo.Types[texpr.X].Type == errorType {
    92  		pass.Reportf(texpr.Lparen, escNl(`invalid direct cast on error object
    93  Alternatives:
    94     if _, ok := err.(*T); ok        ->   if errors.HasType(err, (*T)(nil)
    95     if _, ok := err.(I); ok         ->   if errors.HasInterface(err, (*I)(nil))
    96     if myErr, ok := err.(*T); ok    ->   if myErr := (*T)(nil); errors.As(err, &myErr)
    97     if myErr, ok := err.(I); ok     ->   if myErr := (I)(nil); errors.As(err, &myErr)
    98     switch err.(type) { case *T:... ->   switch { case errors.HasType(err, (*T)(nil): ...
    99  `))
   100  	}
   101  }
   102  
   103  func isEOFError(e ast.Expr) bool {
   104  	if s, ok := e.(*ast.SelectorExpr); ok {
   105  		if io, ok := s.X.(*ast.Ident); ok && io.Name == "io" && io.Obj == (*ast.Object)(nil) {
   106  			if s.Sel.Name == "EOF" {
   107  				return true
   108  			}
   109  		}
   110  	}
   111  	return false
   112  }
   113  
   114  func checkErrCmp(pass *analysis.Pass, binaryExpr *ast.BinaryExpr) {
   115  	switch binaryExpr.Op {
   116  	case token.NEQ, token.EQL:
   117  		if pass.TypesInfo.Types[binaryExpr.X].Type == errorType &&
   118  			!pass.TypesInfo.Types[binaryExpr.Y].IsNil() {
   119  			// We have a special case: when the RHS is io.EOF.
   120  			// This is nearly always used with APIs that return
   121  			// it undecorated.
   122  			if isEOFError(binaryExpr.Y) {
   123  				return
   124  			}
   125  
   126  			pass.Reportf(binaryExpr.OpPos, escNl(`use errors.Is instead of a direct comparison
   127  For example:
   128     if errors.Is(err, errMyOwnErrReference) {
   129       ...
   130     }
   131  `))
   132  		}
   133  	}
   134  }
   135  
   136  func escNl(msg string) string {
   137  	return strings.ReplaceAll(msg, "\n", "\\n++")
   138  }