github.com/inturn/pre-commit-gobuild@v1.0.12/internal/errchecker/errcheck.go (about)

// Package errchecker is the library used to implement the errcheck command-line tool.
     2  //
     3  // Note: The API of this package has not been finalized and may change at any point.
     4  package errchecker
     5  
     6  import (
     7  	"bufio"
     8  	"errors"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/token"
    12  	"go/types"
    13  	"os"
    14  	"regexp"
    15  	"sort"
    16  	"strings"
    17  	"sync"
    18  
    19  	"golang.org/x/tools/go/packages"
    20  )
    21  
    22  var errorType *types.Interface
    23  
    24  func init() {
    25  	errorType = types.Universe.Lookup("error").Type().Underlying().(*types.Interface)
    26  
    27  }
    28  
    29  var (
    30  	// ErrNoGoFiles is returned when CheckPackage is run on a package with no Go source files
    31  	ErrNoGoFiles = errors.New("package contains no go source files")
    32  )
    33  
    34  // UncheckedError indicates the position of an unchecked error return.
    35  type UncheckedError struct {
    36  	Pos      token.Position
    37  	Line     string
    38  	FuncName string
    39  }
    40  
    41  // UncheckedErrors is returned from the CheckPackage function if the package contains
    42  // any unchecked errors.
    43  // Errors should be appended using the Append method, which is safe to use concurrently.
    44  type UncheckedErrors struct {
    45  	mu sync.Mutex
    46  
    47  	// Errors is a list of all the unchecked errors in the package.
    48  	// Printing an error reports its position within the file and the contents of the line.
    49  	Errors []UncheckedError
    50  }
    51  
    52  func (e *UncheckedErrors) Append(errors ...UncheckedError) {
    53  	e.mu.Lock()
    54  	defer e.mu.Unlock()
    55  	e.Errors = append(e.Errors, errors...)
    56  }
    57  
    58  func (e *UncheckedErrors) Error() string {
    59  	return fmt.Sprintf("%d unchecked errors", len(e.Errors))
    60  }
    61  
    62  // Len is the number of elements in the collection.
    63  func (e *UncheckedErrors) Len() int { return len(e.Errors) }
    64  
    65  // Swap swaps the elements with indexes i and j.
    66  func (e *UncheckedErrors) Swap(i, j int) { e.Errors[i], e.Errors[j] = e.Errors[j], e.Errors[i] }
    67  
    68  type byName struct{ *UncheckedErrors }
    69  
    70  // Less reports whether the element with index i should sort before the element with index j.
    71  func (e byName) Less(i, j int) bool {
    72  	ei, ej := e.Errors[i], e.Errors[j]
    73  
    74  	pi, pj := ei.Pos, ej.Pos
    75  
    76  	if pi.Filename != pj.Filename {
    77  		return pi.Filename < pj.Filename
    78  	}
    79  	if pi.Line != pj.Line {
    80  		return pi.Line < pj.Line
    81  	}
    82  	if pi.Column != pj.Column {
    83  		return pi.Column < pj.Column
    84  	}
    85  
    86  	return ei.Line < ej.Line
    87  }
    88  
    89  type Checker struct {
    90  	// ignore is a map of package names to regular expressions. Identifiers from a package are
    91  	// checked against its regular expressions and if any of the expressions match the call
    92  	// is not checked.
    93  	Ignore map[string]*regexp.Regexp
    94  
    95  	// If blank is true then assignments to the blank identifier are also considered to be
    96  	// ignored errors.
    97  	Blank bool
    98  
    99  	// If asserts is true then ignored type assertion results are also checked
   100  	Asserts bool
   101  
   102  	// build tags
   103  	Tags []string
   104  
   105  	Verbose bool
   106  
   107  	// If true, checking of _test.go files is disabled
   108  	WithoutTests bool
   109  
   110  	// If true, checking of files with generated code is disabled
   111  	WithoutGeneratedCode bool
   112  
   113  	exclude map[string]bool
   114  }
   115  
   116  func NewChecker() *Checker {
   117  	c := Checker{}
   118  	c.SetExclude(map[string]bool{})
   119  	return &c
   120  }
   121  
   122  func (c *Checker) SetExclude(l map[string]bool) {
   123  	c.exclude = map[string]bool{}
   124  
   125  	// Default exclude for stdlib functions
   126  	for _, exc := range []string{
   127  		// bytes
   128  		"(*bytes.Buffer).Write",
   129  		"(*bytes.Buffer).WriteByte",
   130  		"(*bytes.Buffer).WriteRune",
   131  		"(*bytes.Buffer).WriteString",
   132  
   133  		// fmt
   134  		"fmt.Errorf",
   135  		"fmt.Print",
   136  		"fmt.Printf",
   137  		"fmt.Println",
   138  
   139  		// math/rand
   140  		"math/rand.Read",
   141  		"(*math/rand.Rand).Read",
   142  
   143  		// strings
   144  		"(*strings.Builder).Write",
   145  		"(*strings.Builder).WriteByte",
   146  		"(*strings.Builder).WriteRune",
   147  		"(*strings.Builder).WriteString",
   148  
   149  		// hash
   150  		"(hash.Hash).Write",
   151  	} {
   152  		c.exclude[exc] = true
   153  	}
   154  
   155  	for k := range l {
   156  		c.exclude[k] = true
   157  	}
   158  }
   159  
   160  func (c *Checker) logf(msg string, args ...interface{}) {
   161  	if c.Verbose {
   162  		fmt.Fprintf(os.Stderr, msg+"\n", args...)
   163  	}
   164  }
   165  
   166  // loadPackages is used for testing.
   167  var loadPackages = func(cfg *packages.Config, paths ...string) ([]*packages.Package, error) {
   168  	return packages.Load(cfg, paths...)
   169  }
   170  
   171  func (c *Checker) load(paths ...string) ([]*packages.Package, error) {
   172  	cfg := &packages.Config{
   173  		Mode:       packages.LoadAllSyntax,
   174  		Tests:      !c.WithoutTests,
   175  		BuildFlags: []string{fmt.Sprintf("-tags=%s", strings.Join(c.Tags, " "))},
   176  	}
   177  	return loadPackages(cfg, paths...)
   178  }
   179  
   180  var generatedCodeRegexp = regexp.MustCompile("^// Code generated .* DO NOT EDIT\\.$")
   181  
   182  func (c *Checker) shouldSkipFile(file *ast.File) bool {
   183  	if !c.WithoutGeneratedCode {
   184  		return false
   185  	}
   186  
   187  	for _, cg := range file.Comments {
   188  		for _, comment := range cg.List {
   189  			if generatedCodeRegexp.MatchString(comment.Text) {
   190  				return true
   191  			}
   192  		}
   193  	}
   194  
   195  	return false
   196  }
   197  
   198  // CheckPackages checks packages for errors.
   199  func (c *Checker) CheckPackages(paths ...string) error {
   200  	pkgs, err := c.load(paths...)
   201  	if err != nil {
   202  		return err
   203  	}
   204  	// Check for errors in the initial packages.
   205  	for _, pkg := range pkgs {
   206  		if len(pkg.Errors) > 0 {
   207  			return fmt.Errorf("errors while loading package %s: %v", pkg.ID, pkg.Errors)
   208  		}
   209  	}
   210  
   211  	var wg sync.WaitGroup
   212  	u := &UncheckedErrors{}
   213  	for _, pkg := range pkgs {
   214  		wg.Add(1)
   215  
   216  		go func(pkg *packages.Package) {
   217  			defer wg.Done()
   218  			c.logf("Checking %s", pkg.Types.Path())
   219  
   220  			v := &visitor{
   221  				pkg:     pkg,
   222  				ignore:  c.Ignore,
   223  				blank:   c.Blank,
   224  				asserts: c.Asserts,
   225  				lines:   make(map[string][]string),
   226  				exclude: c.exclude,
   227  				errors:  []UncheckedError{},
   228  			}
   229  
   230  			for _, astFile := range v.pkg.Syntax {
   231  				if c.shouldSkipFile(astFile) {
   232  					continue
   233  				}
   234  				ast.Walk(v, astFile)
   235  			}
   236  			u.Append(v.errors...)
   237  		}(pkg)
   238  	}
   239  
   240  	wg.Wait()
   241  	if u.Len() > 0 {
   242  		// Sort unchecked errors and remove duplicates. Duplicates may occur when a file
   243  		// containing an unchecked error belongs to > 1 package.
   244  		sort.Sort(byName{u})
   245  		uniq := u.Errors[:0] // compact in-place
   246  		for i, err := range u.Errors {
   247  			if i == 0 || err != u.Errors[i-1] {
   248  				uniq = append(uniq, err)
   249  			}
   250  		}
   251  		u.Errors = uniq
   252  		return u
   253  	}
   254  	return nil
   255  }
   256  
   257  // visitor implements the errcheck algorithm
   258  type visitor struct {
   259  	pkg     *packages.Package
   260  	ignore  map[string]*regexp.Regexp
   261  	blank   bool
   262  	asserts bool
   263  	lines   map[string][]string
   264  	exclude map[string]bool
   265  
   266  	errors []UncheckedError
   267  }
   268  
   269  // selectorAndFunc tries to get the selector and function from call expression.
   270  // For example, given the call expression representing "a.b()", the selector
   271  // is "a.b" and the function is "b" itself.
   272  //
   273  // The final return value will be true if it is able to do extract a selector
   274  // from the call and look up the function object it refers to.
   275  //
   276  // If the call does not include a selector (like if it is a plain "f()" function call)
   277  // then the final return value will be false.
   278  func (v *visitor) selectorAndFunc(call *ast.CallExpr) (*ast.SelectorExpr, *types.Func, bool) {
   279  	sel, ok := call.Fun.(*ast.SelectorExpr)
   280  	if !ok {
   281  		return nil, nil, false
   282  	}
   283  
   284  	fn, ok := v.pkg.TypesInfo.ObjectOf(sel.Sel).(*types.Func)
   285  	if !ok {
   286  		// Shouldn't happen, but be paranoid
   287  		return nil, nil, false
   288  	}
   289  
   290  	return sel, fn, true
   291  
   292  }
   293  
   294  // fullName will return a package / receiver-type qualified name for a called function
   295  // if the function is the result of a selector. Otherwise it will return
   296  // the empty string.
   297  //
   298  // The name is fully qualified by the import path, possible type,
   299  // function/method name and pointer receiver.
   300  //
   301  // For example,
   302  //   - for "fmt.Printf(...)" it will return "fmt.Printf"
   303  //   - for "base64.StdEncoding.Decode(...)" it will return "(*encoding/base64.Encoding).Decode"
   304  //   - for "myFunc()" it will return ""
   305  func (v *visitor) fullName(call *ast.CallExpr) string {
   306  	_, fn, ok := v.selectorAndFunc(call)
   307  	if !ok {
   308  		return ""
   309  	}
   310  
   311  	// TODO(dh): vendored packages will have /vendor/ in their name,
   312  	// thus not matching vendored standard library packages. If we
   313  	// want to support vendored stdlib packages, we need to implement
   314  	// FullName with our own logic.
   315  	return fn.FullName()
   316  }
   317  
   318  // namesForExcludeCheck will return a list of fully-qualified function names
   319  // from a function call that can be used to check against the exclusion list.
   320  //
   321  // If a function call is against a local function (like "myFunc()") then no
   322  // names are returned. If the function is package-qualified (like "fmt.Printf()")
   323  // then just that function's fullName is returned.
   324  //
   325  // Otherwise, we walk through all the potentially embeddded interfaces of the receiver
   326  // the collect a list of type-qualified function names that we will check.
   327  func (v *visitor) namesForExcludeCheck(call *ast.CallExpr) []string {
   328  	sel, fn, ok := v.selectorAndFunc(call)
   329  	if !ok {
   330  		return nil
   331  	}
   332  
   333  	name := v.fullName(call)
   334  	if name == "" {
   335  		return nil
   336  	}
   337  
   338  	// This will be missing for functions without a receiver (like fmt.Printf),
   339  	// so just fall back to the the function's fullName in that case.
   340  	selection, ok := v.pkg.TypesInfo.Selections[sel]
   341  	if !ok {
   342  		return []string{name}
   343  	}
   344  
   345  	// This will return with ok false if the function isn't defined
   346  	// on an interface, so just fall back to the fullName.
   347  	ts, ok := walkThroughEmbeddedInterfaces(selection)
   348  	if !ok {
   349  		return []string{name}
   350  	}
   351  
   352  	result := make([]string, len(ts))
   353  	for i, t := range ts {
   354  		// Like in fullName, vendored packages will have /vendor/ in their name,
   355  		// thus not matching vendored standard library packages. If we
   356  		// want to support vendored stdlib packages, we need to implement
   357  		// additional logic here.
   358  		result[i] = fmt.Sprintf("(%s).%s", t.String(), fn.Name())
   359  	}
   360  	return result
   361  }
   362  
   363  func (v *visitor) excludeCall(call *ast.CallExpr) bool {
   364  	for _, name := range v.namesForExcludeCheck(call) {
   365  		if v.exclude[name] {
   366  			return true
   367  		}
   368  	}
   369  
   370  	return false
   371  }
   372  
   373  func (v *visitor) ignoreCall(call *ast.CallExpr) bool {
   374  	if v.excludeCall(call) {
   375  		return true
   376  	}
   377  
   378  	// Try to get an identifier.
   379  	// Currently only supports simple expressions:
   380  	//     1. f()
   381  	//     2. x.y.f()
   382  	var id *ast.Ident
   383  	switch exp := call.Fun.(type) {
   384  	case (*ast.Ident):
   385  		id = exp
   386  	case (*ast.SelectorExpr):
   387  		id = exp.Sel
   388  	default:
   389  		// eg: *ast.SliceExpr, *ast.IndexExpr
   390  	}
   391  
   392  	if id == nil {
   393  		return false
   394  	}
   395  
   396  	// If we got an identifier for the function, see if it is ignored
   397  	if re, ok := v.ignore[""]; ok && re.MatchString(id.Name) {
   398  		return true
   399  	}
   400  
   401  	if obj := v.pkg.TypesInfo.Uses[id]; obj != nil {
   402  		if pkg := obj.Pkg(); pkg != nil {
   403  			if re, ok := v.ignore[pkg.Path()]; ok {
   404  				return re.MatchString(id.Name)
   405  			}
   406  
   407  			// if current package being considered is vendored, check to see if it should be ignored based
   408  			// on the unvendored path.
   409  			if nonVendoredPkg, ok := nonVendoredPkgPath(pkg.Path()); ok {
   410  				if re, ok := v.ignore[nonVendoredPkg]; ok {
   411  					return re.MatchString(id.Name)
   412  				}
   413  			}
   414  		}
   415  	}
   416  
   417  	return false
   418  }
   419  
   420  // nonVendoredPkgPath returns the unvendored version of the provided package path (or returns the provided path if it
   421  // does not represent a vendored path). The second return value is true if the provided package was vendored, false
   422  // otherwise.
   423  func nonVendoredPkgPath(pkgPath string) (string, bool) {
   424  	lastVendorIndex := strings.LastIndex(pkgPath, "/vendor/")
   425  	if lastVendorIndex == -1 {
   426  		return pkgPath, false
   427  	}
   428  	return pkgPath[lastVendorIndex+len("/vendor/"):], true
   429  }
   430  
   431  // errorsByArg returns a slice s such that
   432  // len(s) == number of return types of call
   433  // s[i] == true iff return type at position i from left is an error type
   434  func (v *visitor) errorsByArg(call *ast.CallExpr) []bool {
   435  	switch t := v.pkg.TypesInfo.Types[call].Type.(type) {
   436  	case *types.Named:
   437  		// Single return
   438  		return []bool{isErrorType(t)}
   439  	case *types.Pointer:
   440  		// Single return via pointer
   441  		return []bool{isErrorType(t)}
   442  	case *types.Tuple:
   443  		// Multiple returns
   444  		s := make([]bool, t.Len())
   445  		for i := 0; i < t.Len(); i++ {
   446  			switch et := t.At(i).Type().(type) {
   447  			case *types.Named:
   448  				// Single return
   449  				s[i] = isErrorType(et)
   450  			case *types.Pointer:
   451  				// Single return via pointer
   452  				s[i] = isErrorType(et)
   453  			default:
   454  				s[i] = false
   455  			}
   456  		}
   457  		return s
   458  	}
   459  	return []bool{false}
   460  }
   461  
   462  func (v *visitor) callReturnsError(call *ast.CallExpr) bool {
   463  	if v.isRecover(call) {
   464  		return true
   465  	}
   466  	for _, isError := range v.errorsByArg(call) {
   467  		if isError {
   468  			return true
   469  		}
   470  	}
   471  	return false
   472  }
   473  
   474  // isRecover returns true if the given CallExpr is a call to the built-in recover() function.
   475  func (v *visitor) isRecover(call *ast.CallExpr) bool {
   476  	if fun, ok := call.Fun.(*ast.Ident); ok {
   477  		if _, ok := v.pkg.TypesInfo.Uses[fun].(*types.Builtin); ok {
   478  			return fun.Name == "recover"
   479  		}
   480  	}
   481  	return false
   482  }
   483  
   484  func (v *visitor) addErrorAtPosition(position token.Pos, call *ast.CallExpr) {
   485  	pos := v.pkg.Fset.Position(position)
   486  	lines, ok := v.lines[pos.Filename]
   487  	if !ok {
   488  		lines = readfile(pos.Filename)
   489  		v.lines[pos.Filename] = lines
   490  	}
   491  
   492  	line := "??"
   493  	if pos.Line-1 < len(lines) {
   494  		line = strings.TrimSpace(lines[pos.Line-1])
   495  	}
   496  
   497  	var name string
   498  	if call != nil {
   499  		name = v.fullName(call)
   500  	}
   501  
   502  	v.errors = append(v.errors, UncheckedError{pos, line, name})
   503  }
   504  
   505  func readfile(filename string) []string {
   506  	var f, err = os.Open(filename)
   507  	if err != nil {
   508  		return nil
   509  	}
   510  
   511  	var lines []string
   512  	var scanner = bufio.NewScanner(f)
   513  	for scanner.Scan() {
   514  		lines = append(lines, scanner.Text())
   515  	}
   516  	return lines
   517  }
   518  
// Visit implements ast.Visitor. It records an UncheckedError for:
//   - expression, go and defer statements whose call is not ignored and
//     returns an error (or is a call to the built-in recover);
//   - when v.blank is set, blank-identifier ("_") targets that receive an
//     error result, or the result of recover(), from a non-ignored call;
//   - when v.asserts is set, type assertions used in single-value form (no ok
//     result), and, when v.blank is also set, assertions whose ok result is
//     assigned to "_".
//
// It always returns v, so the walk descends into every child node.
func (v *visitor) Visit(node ast.Node) ast.Visitor {
	switch stmt := node.(type) {
	case *ast.ExprStmt:
		// A bare call statement discards all of its results.
		if call, ok := stmt.X.(*ast.CallExpr); ok {
			if !v.ignoreCall(call) && v.callReturnsError(call) {
				v.addErrorAtPosition(call.Lparen, call)
			}
		}
	case *ast.GoStmt:
		// The results of a call started with go are always discarded.
		if !v.ignoreCall(stmt.Call) && v.callReturnsError(stmt.Call) {
			v.addErrorAtPosition(stmt.Call.Lparen, stmt.Call)
		}
	case *ast.DeferStmt:
		// The results of a deferred call are always discarded.
		if !v.ignoreCall(stmt.Call) && v.callReturnsError(stmt.Call) {
			v.addErrorAtPosition(stmt.Call.Lparen, stmt.Call)
		}
	case *ast.AssignStmt:
		// Note: each "break" below leaves the enclosing switch, i.e. stops
		// processing this assignment.
		if len(stmt.Rhs) == 1 {
			// single value on rhs; check against lhs identifiers
			if call, ok := stmt.Rhs[0].(*ast.CallExpr); ok {
				if !v.blank {
					break
				}
				if v.ignoreCall(call) {
					break
				}
				// NOTE(review): indexing isError[i] assumes errorsByArg returns at
				// least len(stmt.Lhs) entries, which holds for type-checked code
				// (a multi-target assignment from one call needs a tuple result).
				isError := v.errorsByArg(call)
				for i := 0; i < len(stmt.Lhs); i++ {
					if id, ok := stmt.Lhs[i].(*ast.Ident); ok {
						// We shortcut calls to recover() because errorsByArg can't
						// check its return types for errors since it returns interface{}.
						if id.Name == "_" && (v.isRecover(call) || isError[i]) {
							v.addErrorAtPosition(id.NamePos, call)
						}
					}
				}
			} else if assert, ok := stmt.Rhs[0].(*ast.TypeAssertExpr); ok {
				if !v.asserts {
					break
				}
				if assert.Type == nil {
					// type switch
					break
				}
				if len(stmt.Lhs) < 2 {
					// assertion result not read
					v.addErrorAtPosition(stmt.Rhs[0].Pos(), nil)
				} else if id, ok := stmt.Lhs[1].(*ast.Ident); ok && v.blank && id.Name == "_" {
					// assertion result ignored
					v.addErrorAtPosition(id.NamePos, nil)
				}
			}
		} else {
			// multiple value on rhs; in this case a call can't return
			// multiple values. Assume len(stmt.Lhs) == len(stmt.Rhs)
			for i := 0; i < len(stmt.Lhs); i++ {
				if id, ok := stmt.Lhs[i].(*ast.Ident); ok {
					if call, ok := stmt.Rhs[i].(*ast.CallExpr); ok {
						if !v.blank {
							continue
						}
						if v.ignoreCall(call) {
							continue
						}
						if id.Name == "_" && v.callReturnsError(call) {
							v.addErrorAtPosition(id.NamePos, call)
						}
					} else if assert, ok := stmt.Rhs[i].(*ast.TypeAssertExpr); ok {
						if !v.asserts {
							continue
						}
						if assert.Type == nil {
							// Shouldn't happen anyway, no multi assignment in type switches
							continue
						}
						// In a multi-value assignment every assertion is in
						// single-value form, so its ok result is never read.
						v.addErrorAtPosition(id.NamePos, nil)
					}
				}
			}
		}
	default:
	}
	return v
}
   603  
   604  func isErrorType(t types.Type) bool {
   605  	return types.Implements(t, errorType)
   606  }