github.com/inturn/pre-commit-gobuild@v1.0.12/internal/errchecker/errcheck.go (about) 1 // Package errcheck is the library used to implement the errcheck command-line tool. 2 // 3 // Note: The API of this package has not been finalized and may change at any point. 4 package errchecker 5 6 import ( 7 "bufio" 8 "errors" 9 "fmt" 10 "go/ast" 11 "go/token" 12 "go/types" 13 "os" 14 "regexp" 15 "sort" 16 "strings" 17 "sync" 18 19 "golang.org/x/tools/go/packages" 20 ) 21 22 var errorType *types.Interface 23 24 func init() { 25 errorType = types.Universe.Lookup("error").Type().Underlying().(*types.Interface) 26 27 } 28 29 var ( 30 // ErrNoGoFiles is returned when CheckPackage is run on a package with no Go source files 31 ErrNoGoFiles = errors.New("package contains no go source files") 32 ) 33 34 // UncheckedError indicates the position of an unchecked error return. 35 type UncheckedError struct { 36 Pos token.Position 37 Line string 38 FuncName string 39 } 40 41 // UncheckedErrors is returned from the CheckPackage function if the package contains 42 // any unchecked errors. 43 // Errors should be appended using the Append method, which is safe to use concurrently. 44 type UncheckedErrors struct { 45 mu sync.Mutex 46 47 // Errors is a list of all the unchecked errors in the package. 48 // Printing an error reports its position within the file and the contents of the line. 49 Errors []UncheckedError 50 } 51 52 func (e *UncheckedErrors) Append(errors ...UncheckedError) { 53 e.mu.Lock() 54 defer e.mu.Unlock() 55 e.Errors = append(e.Errors, errors...) 56 } 57 58 func (e *UncheckedErrors) Error() string { 59 return fmt.Sprintf("%d unchecked errors", len(e.Errors)) 60 } 61 62 // Len is the number of elements in the collection. 63 func (e *UncheckedErrors) Len() int { return len(e.Errors) } 64 65 // Swap swaps the elements with indexes i and j. 66 func (e *UncheckedErrors) Swap(i, j int) { e.Errors[i], e.Errors[j] = e.Errors[j], e.Errors[i] } 67 68 type byName struct{ *UncheckedErrors } 69 70 // Less reports whether the element with index i should sort before the element with index j. 71 func (e byName) Less(i, j int) bool { 72 ei, ej := e.Errors[i], e.Errors[j] 73 74 pi, pj := ei.Pos, ej.Pos 75 76 if pi.Filename != pj.Filename { 77 return pi.Filename < pj.Filename 78 } 79 if pi.Line != pj.Line { 80 return pi.Line < pj.Line 81 } 82 if pi.Column != pj.Column { 83 return pi.Column < pj.Column 84 } 85 86 return ei.Line < ej.Line 87 } 88 89 type Checker struct { 90 // ignore is a map of package names to regular expressions. Identifiers from a package are 91 // checked against its regular expressions and if any of the expressions match the call 92 // is not checked. 93 Ignore map[string]*regexp.Regexp 94 95 // If blank is true then assignments to the blank identifier are also considered to be 96 // ignored errors. 97 Blank bool 98 99 // If asserts is true then ignored type assertion results are also checked 100 Asserts bool 101 102 // build tags 103 Tags []string 104 105 Verbose bool 106 107 // If true, checking of _test.go files is disabled 108 WithoutTests bool 109 110 // If true, checking of files with generated code is disabled 111 WithoutGeneratedCode bool 112 113 exclude map[string]bool 114 } 115 116 func NewChecker() *Checker { 117 c := Checker{} 118 c.SetExclude(map[string]bool{}) 119 return &c 120 } 121 122 func (c *Checker) SetExclude(l map[string]bool) { 123 c.exclude = map[string]bool{} 124 125 // Default exclude for stdlib functions 126 for _, exc := range []string{ 127 // bytes 128 "(*bytes.Buffer).Write", 129 "(*bytes.Buffer).WriteByte", 130 "(*bytes.Buffer).WriteRune", 131 "(*bytes.Buffer).WriteString", 132 133 // fmt 134 "fmt.Errorf", 135 "fmt.Print", 136 "fmt.Printf", 137 "fmt.Println", 138 139 // math/rand 140 "math/rand.Read", 141 "(*math/rand.Rand).Read", 142 143 // strings 144 "(*strings.Builder).Write", 145 "(*strings.Builder).WriteByte", 146 "(*strings.Builder).WriteRune", 147 "(*strings.Builder).WriteString", 148 149 // hash 150 "(hash.Hash).Write", 151 } { 152 c.exclude[exc] = true 153 } 154 155 for k := range l { 156 c.exclude[k] = true 157 } 158 } 159 160 func (c *Checker) logf(msg string, args ...interface{}) { 161 if c.Verbose { 162 fmt.Fprintf(os.Stderr, msg+"\n", args...) 163 } 164 } 165 166 // loadPackages is used for testing. 167 var loadPackages = func(cfg *packages.Config, paths ...string) ([]*packages.Package, error) { 168 return packages.Load(cfg, paths...) 169 } 170 171 func (c *Checker) load(paths ...string) ([]*packages.Package, error) { 172 cfg := &packages.Config{ 173 Mode: packages.LoadAllSyntax, 174 Tests: !c.WithoutTests, 175 BuildFlags: []string{fmt.Sprintf("-tags=%s", strings.Join(c.Tags, " "))}, 176 } 177 return loadPackages(cfg, paths...) 178 } 179 180 var generatedCodeRegexp = regexp.MustCompile("^// Code generated .* DO NOT EDIT\\.$") 181 182 func (c *Checker) shouldSkipFile(file *ast.File) bool { 183 if !c.WithoutGeneratedCode { 184 return false 185 } 186 187 for _, cg := range file.Comments { 188 for _, comment := range cg.List { 189 if generatedCodeRegexp.MatchString(comment.Text) { 190 return true 191 } 192 } 193 } 194 195 return false 196 } 197 198 // CheckPackages checks packages for errors. 199 func (c *Checker) CheckPackages(paths ...string) error { 200 pkgs, err := c.load(paths...) 201 if err != nil { 202 return err 203 } 204 // Check for errors in the initial packages. 205 for _, pkg := range pkgs { 206 if len(pkg.Errors) > 0 { 207 return fmt.Errorf("errors while loading package %s: %v", pkg.ID, pkg.Errors) 208 } 209 } 210 211 var wg sync.WaitGroup 212 u := &UncheckedErrors{} 213 for _, pkg := range pkgs { 214 wg.Add(1) 215 216 go func(pkg *packages.Package) { 217 defer wg.Done() 218 c.logf("Checking %s", pkg.Types.Path()) 219 220 v := &visitor{ 221 pkg: pkg, 222 ignore: c.Ignore, 223 blank: c.Blank, 224 asserts: c.Asserts, 225 lines: make(map[string][]string), 226 exclude: c.exclude, 227 errors: []UncheckedError{}, 228 } 229 230 for _, astFile := range v.pkg.Syntax { 231 if c.shouldSkipFile(astFile) { 232 continue 233 } 234 ast.Walk(v, astFile) 235 } 236 u.Append(v.errors...) 237 }(pkg) 238 } 239 240 wg.Wait() 241 if u.Len() > 0 { 242 // Sort unchecked errors and remove duplicates. Duplicates may occur when a file 243 // containing an unchecked error belongs to > 1 package. 244 sort.Sort(byName{u}) 245 uniq := u.Errors[:0] // compact in-place 246 for i, err := range u.Errors { 247 if i == 0 || err != u.Errors[i-1] { 248 uniq = append(uniq, err) 249 } 250 } 251 u.Errors = uniq 252 return u 253 } 254 return nil 255 } 256 257 // visitor implements the errcheck algorithm 258 type visitor struct { 259 pkg *packages.Package 260 ignore map[string]*regexp.Regexp 261 blank bool 262 asserts bool 263 lines map[string][]string 264 exclude map[string]bool 265 266 errors []UncheckedError 267 } 268 269 // selectorAndFunc tries to get the selector and function from call expression. 270 // For example, given the call expression representing "a.b()", the selector 271 // is "a.b" and the function is "b" itself. 272 // 273 // The final return value will be true if it is able to do extract a selector 274 // from the call and look up the function object it refers to. 275 // 276 // If the call does not include a selector (like if it is a plain "f()" function call) 277 // then the final return value will be false. 278 func (v *visitor) selectorAndFunc(call *ast.CallExpr) (*ast.SelectorExpr, *types.Func, bool) { 279 sel, ok := call.Fun.(*ast.SelectorExpr) 280 if !ok { 281 return nil, nil, false 282 } 283 284 fn, ok := v.pkg.TypesInfo.ObjectOf(sel.Sel).(*types.Func) 285 if !ok { 286 // Shouldn't happen, but be paranoid 287 return nil, nil, false 288 } 289 290 return sel, fn, true 291 292 } 293 294 // fullName will return a package / receiver-type qualified name for a called function 295 // if the function is the result of a selector. Otherwise it will return 296 // the empty string. 297 // 298 // The name is fully qualified by the import path, possible type, 299 // function/method name and pointer receiver. 300 // 301 // For example, 302 // - for "fmt.Printf(...)" it will return "fmt.Printf" 303 // - for "base64.StdEncoding.Decode(...)" it will return "(*encoding/base64.Encoding).Decode" 304 // - for "myFunc()" it will return "" 305 func (v *visitor) fullName(call *ast.CallExpr) string { 306 _, fn, ok := v.selectorAndFunc(call) 307 if !ok { 308 return "" 309 } 310 311 // TODO(dh): vendored packages will have /vendor/ in their name, 312 // thus not matching vendored standard library packages. If we 313 // want to support vendored stdlib packages, we need to implement 314 // FullName with our own logic. 315 return fn.FullName() 316 } 317 318 // namesForExcludeCheck will return a list of fully-qualified function names 319 // from a function call that can be used to check against the exclusion list. 320 // 321 // If a function call is against a local function (like "myFunc()") then no 322 // names are returned. If the function is package-qualified (like "fmt.Printf()") 323 // then just that function's fullName is returned. 324 // 325 // Otherwise, we walk through all the potentially embeddded interfaces of the receiver 326 // the collect a list of type-qualified function names that we will check. 327 func (v *visitor) namesForExcludeCheck(call *ast.CallExpr) []string { 328 sel, fn, ok := v.selectorAndFunc(call) 329 if !ok { 330 return nil 331 } 332 333 name := v.fullName(call) 334 if name == "" { 335 return nil 336 } 337 338 // This will be missing for functions without a receiver (like fmt.Printf), 339 // so just fall back to the the function's fullName in that case. 340 selection, ok := v.pkg.TypesInfo.Selections[sel] 341 if !ok { 342 return []string{name} 343 } 344 345 // This will return with ok false if the function isn't defined 346 // on an interface, so just fall back to the fullName. 347 ts, ok := walkThroughEmbeddedInterfaces(selection) 348 if !ok { 349 return []string{name} 350 } 351 352 result := make([]string, len(ts)) 353 for i, t := range ts { 354 // Like in fullName, vendored packages will have /vendor/ in their name, 355 // thus not matching vendored standard library packages. If we 356 // want to support vendored stdlib packages, we need to implement 357 // additional logic here. 358 result[i] = fmt.Sprintf("(%s).%s", t.String(), fn.Name()) 359 } 360 return result 361 } 362 363 func (v *visitor) excludeCall(call *ast.CallExpr) bool { 364 for _, name := range v.namesForExcludeCheck(call) { 365 if v.exclude[name] { 366 return true 367 } 368 } 369 370 return false 371 } 372 373 func (v *visitor) ignoreCall(call *ast.CallExpr) bool { 374 if v.excludeCall(call) { 375 return true 376 } 377 378 // Try to get an identifier. 379 // Currently only supports simple expressions: 380 // 1. f() 381 // 2. x.y.f() 382 var id *ast.Ident 383 switch exp := call.Fun.(type) { 384 case (*ast.Ident): 385 id = exp 386 case (*ast.SelectorExpr): 387 id = exp.Sel 388 default: 389 // eg: *ast.SliceExpr, *ast.IndexExpr 390 } 391 392 if id == nil { 393 return false 394 } 395 396 // If we got an identifier for the function, see if it is ignored 397 if re, ok := v.ignore[""]; ok && re.MatchString(id.Name) { 398 return true 399 } 400 401 if obj := v.pkg.TypesInfo.Uses[id]; obj != nil { 402 if pkg := obj.Pkg(); pkg != nil { 403 if re, ok := v.ignore[pkg.Path()]; ok { 404 return re.MatchString(id.Name) 405 } 406 407 // if current package being considered is vendored, check to see if it should be ignored based 408 // on the unvendored path. 409 if nonVendoredPkg, ok := nonVendoredPkgPath(pkg.Path()); ok { 410 if re, ok := v.ignore[nonVendoredPkg]; ok { 411 return re.MatchString(id.Name) 412 } 413 } 414 } 415 } 416 417 return false 418 } 419 420 // nonVendoredPkgPath returns the unvendored version of the provided package path (or returns the provided path if it 421 // does not represent a vendored path). The second return value is true if the provided package was vendored, false 422 // otherwise. 423 func nonVendoredPkgPath(pkgPath string) (string, bool) { 424 lastVendorIndex := strings.LastIndex(pkgPath, "/vendor/") 425 if lastVendorIndex == -1 { 426 return pkgPath, false 427 } 428 return pkgPath[lastVendorIndex+len("/vendor/"):], true 429 } 430 431 // errorsByArg returns a slice s such that 432 // len(s) == number of return types of call 433 // s[i] == true iff return type at position i from left is an error type 434 func (v *visitor) errorsByArg(call *ast.CallExpr) []bool { 435 switch t := v.pkg.TypesInfo.Types[call].Type.(type) { 436 case *types.Named: 437 // Single return 438 return []bool{isErrorType(t)} 439 case *types.Pointer: 440 // Single return via pointer 441 return []bool{isErrorType(t)} 442 case *types.Tuple: 443 // Multiple returns 444 s := make([]bool, t.Len()) 445 for i := 0; i < t.Len(); i++ { 446 switch et := t.At(i).Type().(type) { 447 case *types.Named: 448 // Single return 449 s[i] = isErrorType(et) 450 case *types.Pointer: 451 // Single return via pointer 452 s[i] = isErrorType(et) 453 default: 454 s[i] = false 455 } 456 } 457 return s 458 } 459 return []bool{false} 460 } 461 462 func (v *visitor) callReturnsError(call *ast.CallExpr) bool { 463 if v.isRecover(call) { 464 return true 465 } 466 for _, isError := range v.errorsByArg(call) { 467 if isError { 468 return true 469 } 470 } 471 return false 472 } 473 474 // isRecover returns true if the given CallExpr is a call to the built-in recover() function. 475 func (v *visitor) isRecover(call *ast.CallExpr) bool { 476 if fun, ok := call.Fun.(*ast.Ident); ok { 477 if _, ok := v.pkg.TypesInfo.Uses[fun].(*types.Builtin); ok { 478 return fun.Name == "recover" 479 } 480 } 481 return false 482 } 483 484 func (v *visitor) addErrorAtPosition(position token.Pos, call *ast.CallExpr) { 485 pos := v.pkg.Fset.Position(position) 486 lines, ok := v.lines[pos.Filename] 487 if !ok { 488 lines = readfile(pos.Filename) 489 v.lines[pos.Filename] = lines 490 } 491 492 line := "??" 493 if pos.Line-1 < len(lines) { 494 line = strings.TrimSpace(lines[pos.Line-1]) 495 } 496 497 var name string 498 if call != nil { 499 name = v.fullName(call) 500 } 501 502 v.errors = append(v.errors, UncheckedError{pos, line, name}) 503 } 504 505 func readfile(filename string) []string { 506 var f, err = os.Open(filename) 507 if err != nil { 508 return nil 509 } 510 511 var lines []string 512 var scanner = bufio.NewScanner(f) 513 for scanner.Scan() { 514 lines = append(lines, scanner.Text()) 515 } 516 return lines 517 } 518 519 func (v *visitor) Visit(node ast.Node) ast.Visitor { 520 switch stmt := node.(type) { 521 case *ast.ExprStmt: 522 if call, ok := stmt.X.(*ast.CallExpr); ok { 523 if !v.ignoreCall(call) && v.callReturnsError(call) { 524 v.addErrorAtPosition(call.Lparen, call) 525 } 526 } 527 case *ast.GoStmt: 528 if !v.ignoreCall(stmt.Call) && v.callReturnsError(stmt.Call) { 529 v.addErrorAtPosition(stmt.Call.Lparen, stmt.Call) 530 } 531 case *ast.DeferStmt: 532 if !v.ignoreCall(stmt.Call) && v.callReturnsError(stmt.Call) { 533 v.addErrorAtPosition(stmt.Call.Lparen, stmt.Call) 534 } 535 case *ast.AssignStmt: 536 if len(stmt.Rhs) == 1 { 537 // single value on rhs; check against lhs identifiers 538 if call, ok := stmt.Rhs[0].(*ast.CallExpr); ok { 539 if !v.blank { 540 break 541 } 542 if v.ignoreCall(call) { 543 break 544 } 545 isError := v.errorsByArg(call) 546 for i := 0; i < len(stmt.Lhs); i++ { 547 if id, ok := stmt.Lhs[i].(*ast.Ident); ok { 548 // We shortcut calls to recover() because errorsByArg can't 549 // check its return types for errors since it returns interface{}. 550 if id.Name == "_" && (v.isRecover(call) || isError[i]) { 551 v.addErrorAtPosition(id.NamePos, call) 552 } 553 } 554 } 555 } else if assert, ok := stmt.Rhs[0].(*ast.TypeAssertExpr); ok { 556 if !v.asserts { 557 break 558 } 559 if assert.Type == nil { 560 // type switch 561 break 562 } 563 if len(stmt.Lhs) < 2 { 564 // assertion result not read 565 v.addErrorAtPosition(stmt.Rhs[0].Pos(), nil) 566 } else if id, ok := stmt.Lhs[1].(*ast.Ident); ok && v.blank && id.Name == "_" { 567 // assertion result ignored 568 v.addErrorAtPosition(id.NamePos, nil) 569 } 570 } 571 } else { 572 // multiple value on rhs; in this case a call can't return 573 // multiple values. Assume len(stmt.Lhs) == len(stmt.Rhs) 574 for i := 0; i < len(stmt.Lhs); i++ { 575 if id, ok := stmt.Lhs[i].(*ast.Ident); ok { 576 if call, ok := stmt.Rhs[i].(*ast.CallExpr); ok { 577 if !v.blank { 578 continue 579 } 580 if v.ignoreCall(call) { 581 continue 582 } 583 if id.Name == "_" && v.callReturnsError(call) { 584 v.addErrorAtPosition(id.NamePos, call) 585 } 586 } else if assert, ok := stmt.Rhs[i].(*ast.TypeAssertExpr); ok { 587 if !v.asserts { 588 continue 589 } 590 if assert.Type == nil { 591 // Shouldn't happen anyway, no multi assignment in type switches 592 continue 593 } 594 v.addErrorAtPosition(id.NamePos, nil) 595 } 596 } 597 } 598 } 599 default: 600 } 601 return v 602 } 603 604 func isErrorType(t types.Type) bool { 605 return types.Implements(t, errorType) 606 }