gopkg.in/alecthomas/gometalinter.v3@v3.0.0/_linters/src/github.com/alexkohler/nakedret/nakedret.go (about)

     1  package main
     2  
     3  import (
     4  	"errors"
     5  	"flag"
     6  	"fmt"
     7  	"go/ast"
     8  	"go/build"
     9  	"go/parser"
    10  	"go/token"
    11  	"log"
    12  	"os"
    13  	"path/filepath"
    14  	"strings"
    15  )
    16  
    17  const (
    18  	pwd = "./"
    19  )
    20  
    21  func init() {
    22  	//TODO allow build tags
    23  	build.Default.UseAllFiles = true
    24  }
    25  
    26  func usage() {
    27  	log.Printf("Usage of %s:\n", os.Args[0])
    28  	log.Printf("\nnakedret [flags] # runs on package in current directory\n")
    29  	log.Printf("\nnakedret [flags] [packages]\n")
    30  	log.Printf("Flags:\n")
    31  	flag.PrintDefaults()
    32  }
    33  
    34  type returnsVisitor struct {
    35  	f         *token.FileSet
    36  	maxLength uint
    37  }
    38  
    39  func main() {
    40  
    41  	// Remove log timestamp
    42  	log.SetFlags(0)
    43  
    44  	maxLength := flag.Uint("l", 5, "maximum number of lines for a naked return function")
    45  	flag.Usage = usage
    46  	flag.Parse()
    47  
    48  	if err := checkNakedReturns(flag.Args(), maxLength); err != nil {
    49  		log.Println(err)
    50  	}
    51  }
    52  
    53  func checkNakedReturns(args []string, maxLength *uint) error {
    54  
    55  	fset := token.NewFileSet()
    56  
    57  	files, err := parseInput(args, fset)
    58  	if err != nil {
    59  		return fmt.Errorf("could not parse input %v", err)
    60  	}
    61  
    62  	if maxLength == nil {
    63  		return errors.New("max length nil")
    64  	}
    65  
    66  	retVis := &returnsVisitor{
    67  		f:         fset,
    68  		maxLength: *maxLength,
    69  	}
    70  
    71  	for _, f := range files {
    72  		ast.Walk(retVis, f)
    73  	}
    74  
    75  	return nil
    76  }
    77  
    78  func parseInput(args []string, fset *token.FileSet) ([]*ast.File, error) {
    79  	var directoryList []string
    80  	var fileMode bool
    81  	files := make([]*ast.File, 0)
    82  
    83  	if len(args) == 0 {
    84  		directoryList = append(directoryList, pwd)
    85  	} else {
    86  		for _, arg := range args {
    87  			if strings.HasSuffix(arg, "/...") && isDir(arg[:len(arg)-len("/...")]) {
    88  
    89  				for _, dirname := range allPackagesInFS(arg) {
    90  					directoryList = append(directoryList, dirname)
    91  				}
    92  
    93  			} else if isDir(arg) {
    94  				directoryList = append(directoryList, arg)
    95  
    96  			} else if exists(arg) {
    97  				if strings.HasSuffix(arg, ".go") {
    98  					fileMode = true
    99  					f, err := parser.ParseFile(fset, arg, nil, 0)
   100  					if err != nil {
   101  						return nil, err
   102  					}
   103  					files = append(files, f)
   104  				} else {
   105  					return nil, fmt.Errorf("invalid file %v specified", arg)
   106  				}
   107  			} else {
   108  
   109  				//TODO clean this up a bit
   110  				imPaths := importPaths([]string{arg})
   111  				for _, importPath := range imPaths {
   112  					pkg, err := build.Import(importPath, ".", 0)
   113  					if err != nil {
   114  						return nil, err
   115  					}
   116  					var stringFiles []string
   117  					stringFiles = append(stringFiles, pkg.GoFiles...)
   118  					// files = append(files, pkg.CgoFiles...)
   119  					stringFiles = append(stringFiles, pkg.TestGoFiles...)
   120  					if pkg.Dir != "." {
   121  						for i, f := range stringFiles {
   122  							stringFiles[i] = filepath.Join(pkg.Dir, f)
   123  						}
   124  					}
   125  
   126  					fileMode = true
   127  					for _, stringFile := range stringFiles {
   128  						f, err := parser.ParseFile(fset, stringFile, nil, 0)
   129  						if err != nil {
   130  							return nil, err
   131  						}
   132  						files = append(files, f)
   133  					}
   134  
   135  				}
   136  			}
   137  		}
   138  	}
   139  
   140  	// if we're not in file mode, then we need to grab each and every package in each directory
   141  	// we can to grab all the files
   142  	if !fileMode {
   143  		for _, fpath := range directoryList {
   144  			pkgs, err := parser.ParseDir(fset, fpath, nil, 0)
   145  			if err != nil {
   146  				return nil, err
   147  			}
   148  
   149  			for _, pkg := range pkgs {
   150  				for _, f := range pkg.Files {
   151  					files = append(files, f)
   152  				}
   153  			}
   154  		}
   155  	}
   156  
   157  	return files, nil
   158  }
   159  
   160  func isDir(filename string) bool {
   161  	fi, err := os.Stat(filename)
   162  	return err == nil && fi.IsDir()
   163  }
   164  
   165  func exists(filename string) bool {
   166  	_, err := os.Stat(filename)
   167  	return err == nil
   168  }
   169  
   170  func (v *returnsVisitor) Visit(node ast.Node) ast.Visitor {
   171  	var namedReturns []*ast.Ident
   172  
   173  	funcDecl, ok := node.(*ast.FuncDecl)
   174  	if !ok {
   175  		return v
   176  	}
   177  	var functionLineLength int
   178  	// We've found a function
   179  	if funcDecl.Type != nil && funcDecl.Type.Results != nil {
   180  		for _, field := range funcDecl.Type.Results.List {
   181  			for _, ident := range field.Names {
   182  				if ident != nil {
   183  					namedReturns = append(namedReturns, ident)
   184  				}
   185  			}
   186  		}
   187  		file := v.f.File(funcDecl.Pos())
   188  		functionLineLength = file.Position(funcDecl.End()).Line - file.Position(funcDecl.Pos()).Line
   189  	}
   190  
   191  	if len(namedReturns) > 0 && funcDecl.Body != nil {
   192  		// Scan the body for usage of the named returns
   193  		for _, stmt := range funcDecl.Body.List {
   194  
   195  			switch s := stmt.(type) {
   196  			case *ast.ReturnStmt:
   197  				if len(s.Results) == 0 {
   198  					file := v.f.File(s.Pos())
   199  					if file != nil && uint(functionLineLength) > v.maxLength {
   200  						if funcDecl.Name != nil {
   201  							log.Printf("%v:%v %v naked returns on %v line function \n", file.Name(), file.Position(s.Pos()).Line, funcDecl.Name.Name, functionLineLength)
   202  						}
   203  					}
   204  					continue
   205  				}
   206  
   207  			default:
   208  			}
   209  		}
   210  	}
   211  
   212  	return v
   213  }