github.com/krishnamiriyala/courtney@v0.3.2/scanner/scanner.go (about)

     1  package scanner
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/constant"
     7  	"go/token"
     8  	"go/types"
     9  
    10  	"github.com/krishnamiriyala/astrid"
    11  	"github.com/krishnamiriyala/brenda"
    12  	"github.com/krishnamiriyala/courtney/shared"
    13  	"github.com/pkg/errors"
    14  	"golang.org/x/tools/go/packages"
    15  )
    16  
    17  // CodeMap scans a number of packages for code to exclude
    18  type CodeMap struct {
    19  	setup    *shared.Setup
    20  	pkgs     []*packages.Package
    21  	Excludes map[string]map[int]bool
    22  }
    23  
    24  // PackageMap scans a single package for code to exclude
    25  type PackageMap struct {
    26  	*CodeMap
    27  	pkg  *packages.Package
    28  	fset *token.FileSet
    29  }
    30  
    31  // FileMap scans a single file for code to exclude
    32  type FileMap struct {
    33  	*PackageMap
    34  	file    *ast.File
    35  	matcher *astrid.Matcher
    36  }
    37  
    38  type packageId struct {
    39  	path string
    40  	name string
    41  }
    42  
    43  // New returns a CoseMap with the provided setup
    44  func New(setup *shared.Setup) *CodeMap {
    45  	return &CodeMap{
    46  		setup:    setup,
    47  		Excludes: make(map[string]map[int]bool),
    48  	}
    49  }
    50  
    51  func (c *CodeMap) addExclude(fpath string, line int) {
    52  	if c.Excludes[fpath] == nil {
    53  		c.Excludes[fpath] = make(map[int]bool)
    54  	}
    55  	c.Excludes[fpath][line] = true
    56  }
    57  
    58  // LoadProgram uses the loader package to load and process the source for a
    59  // number or packages.
    60  func (c *CodeMap) LoadProgram() error {
    61  	var patterns []string
    62  	for _, p := range c.setup.Packages {
    63  		patterns = append(patterns, p.Path)
    64  	}
    65  	wd, err := c.setup.Env.Getwd()
    66  	if err != nil {
    67  		return errors.WithStack(err)
    68  	}
    69  
    70  	cfg := &packages.Config{
    71  		Dir: wd,
    72  		Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles |
    73  			packages.NeedImports | packages.NeedTypes | packages.NeedTypesSizes |
    74  			packages.NeedSyntax | packages.NeedTypesInfo,
    75  		Env: c.setup.Env.Environ(),
    76  	}
    77  
    78  	// add a recover to catch a panic and add some context to the error
    79  	defer func() {
    80  		if panicErr := recover(); panicErr != nil {
    81  			panic(fmt.Sprintf("%+v", errors.Errorf("Panic in packages.Load: %s", panicErr)))
    82  		}
    83  	}()
    84  
    85  	pkgs, err := packages.Load(cfg, patterns...)
    86  	/*
    87  		ctxt := build.Default
    88  		ctxt.GOPATH = c.setup.Env.Getenv("GOPATH")
    89  
    90  		conf := loader.Config{Build: &ctxt, Cwd: wd, ParserMode: parser.ParseComments}
    91  
    92  		for _, p := range c.setup.Packages {
    93  			conf.Import(p.Path)
    94  		}
    95  		prog, err := conf.Load()
    96  	*/
    97  	if err != nil {
    98  		return errors.Wrap(err, "Error loading config")
    99  	}
   100  	c.pkgs = pkgs
   101  	return nil
   102  }
   103  
   104  // ScanPackages scans the imported packages
   105  func (c *CodeMap) ScanPackages() error {
   106  	for _, p := range c.pkgs {
   107  		pm := &PackageMap{
   108  			CodeMap: c,
   109  			pkg:     p,
   110  			fset:    p.Fset,
   111  		}
   112  		if err := pm.ScanPackage(); err != nil {
   113  			return errors.WithStack(err)
   114  		}
   115  	}
   116  	return nil
   117  }
   118  
   119  // ScanPackage scans a single package
   120  func (p *PackageMap) ScanPackage() error {
   121  	for _, f := range p.pkg.Syntax {
   122  
   123  		fm := &FileMap{
   124  			PackageMap: p,
   125  			file:       f,
   126  			matcher:    astrid.NewMatcher(p.pkg.TypesInfo.Uses, p.pkg.TypesInfo.Defs),
   127  		}
   128  		if err := fm.FindExcludes(); err != nil {
   129  			return errors.WithStack(err)
   130  		}
   131  	}
   132  	return nil
   133  }
   134  
   135  // FindExcludes scans a single file to find code to exclude from coverage files
   136  func (f *FileMap) FindExcludes() error {
   137  	var err error
   138  
   139  	ast.Inspect(f.file, func(node ast.Node) bool {
   140  		if err != nil {
   141  			return false
   142  		}
   143  		b, inner := f.inspectNode(node)
   144  		if inner != nil {
   145  			err = inner
   146  			return false
   147  		}
   148  		return b
   149  	})
   150  	if err != nil {
   151  		return errors.WithStack(err)
   152  	}
   153  	for _, cg := range f.file.Comments {
   154  		f.inspectComment(cg)
   155  	}
   156  	return nil
   157  }
   158  
   159  func (f *FileMap) findScope(node ast.Node, filter func(ast.Node) bool) ast.Node {
   160  	inside := func(node, holder ast.Node) bool {
   161  		return node != nil && holder != nil && node.Pos() > holder.Pos() && node.Pos() <= holder.End()
   162  	}
   163  	var scopes []ast.Node
   164  	ast.Inspect(f.file, func(scope ast.Node) bool {
   165  		if inside(node, scope) {
   166  			scopes = append(scopes, scope)
   167  			return true
   168  		}
   169  		return false
   170  	})
   171  	// find the last matching scope
   172  	for i := len(scopes) - 1; i >= 0; i-- {
   173  		if filter == nil || filter(scopes[i]) {
   174  			return scopes[i]
   175  		}
   176  	}
   177  	return nil
   178  }
   179  
   180  func (f *FileMap) inspectComment(cg *ast.CommentGroup) {
   181  	for _, cm := range cg.List {
   182  		if cm.Text != "// TODO: notest" {
   183  			continue
   184  		}
   185  
   186  		// get the parent scope
   187  		scope := f.findScope(cm, nil)
   188  
   189  		// scope can be nil if the comment is in an empty file... in that
   190  		// case we don't need any excludes.
   191  		if scope != nil {
   192  			comment := f.fset.Position(cm.Pos())
   193  			start := f.fset.Position(scope.Pos())
   194  			end := f.fset.Position(scope.End())
   195  			endLine := end.Line
   196  			if _, ok := scope.(*ast.CaseClause); ok {
   197  				// case block needs an extra line...
   198  				endLine++
   199  			}
   200  			for line := comment.Line; line < endLine; line++ {
   201  				f.addExclude(start.Filename, line)
   202  			}
   203  		}
   204  	}
   205  }
   206  
   207  func (f *FileMap) inspectNode(node ast.Node) (bool, error) {
   208  	if node == nil {
   209  		return true, nil
   210  	}
   211  	switch n := node.(type) {
   212  	case *ast.CallExpr:
   213  		if id, ok := n.Fun.(*ast.Ident); ok && id.Name == "panic" {
   214  			pos := f.fset.Position(n.Pos())
   215  			f.addExclude(pos.Filename, pos.Line)
   216  		}
   217  	case *ast.IfStmt:
   218  		if err := f.inspectIf(n); err != nil {
   219  			return false, err
   220  		}
   221  	case *ast.SwitchStmt:
   222  		if n.Tag != nil {
   223  			// we are only concerned with switch statements with no tag
   224  			// expression e.g. switch { ... }
   225  			return true, nil
   226  		}
   227  		var falseExpr []ast.Expr
   228  		var defaultClause *ast.CaseClause
   229  		for _, s := range n.Body.List {
   230  			cc := s.(*ast.CaseClause)
   231  			if cc.List == nil {
   232  				// save the default clause until the end
   233  				defaultClause = cc
   234  				continue
   235  			}
   236  			if err := f.inspectCase(cc, falseExpr...); err != nil {
   237  				return false, err
   238  			}
   239  			falseExpr = append(falseExpr, f.boolOr(cc.List))
   240  		}
   241  		if defaultClause != nil {
   242  			if err := f.inspectCase(defaultClause, falseExpr...); err != nil {
   243  				return false, err
   244  			}
   245  		}
   246  	}
   247  	return true, nil
   248  }
   249  
   250  func (f *FileMap) inspectCase(stmt *ast.CaseClause, falseExpr ...ast.Expr) error {
   251  	s := brenda.NewSolver(f.fset, f.pkg.TypesInfo.Uses, f.pkg.TypesInfo.Defs, f.boolOr(stmt.List), falseExpr...)
   252  	if err := s.SolveTrue(); err != nil {
   253  		return errors.WithStack(err)
   254  	}
   255  	f.processResults(s, &ast.BlockStmt{List: stmt.Body})
   256  	return nil
   257  }
   258  
   259  func (f *FileMap) boolOr(list []ast.Expr) ast.Expr {
   260  	if len(list) == 0 {
   261  		return nil
   262  	}
   263  	if len(list) == 1 {
   264  		return list[0]
   265  	}
   266  	current := list[0]
   267  	for i := 1; i < len(list); i++ {
   268  		current = &ast.BinaryExpr{X: current, Y: list[i], Op: token.LOR}
   269  	}
   270  	return current
   271  }
   272  
   273  func (f *FileMap) inspectIf(stmt *ast.IfStmt, falseExpr ...ast.Expr) error {
   274  
   275  	// main if block
   276  	s := brenda.NewSolver(f.fset, f.pkg.TypesInfo.Uses, f.pkg.TypesInfo.Defs, stmt.Cond, falseExpr...)
   277  	if err := s.SolveTrue(); err != nil {
   278  		return errors.WithStack(err)
   279  	}
   280  	f.processResults(s, stmt.Body)
   281  
   282  	switch e := stmt.Else.(type) {
   283  	case *ast.BlockStmt:
   284  
   285  		// else block
   286  		s := brenda.NewSolver(f.fset, f.pkg.TypesInfo.Uses, f.pkg.TypesInfo.Defs, stmt.Cond, falseExpr...)
   287  		if err := s.SolveFalse(); err != nil {
   288  			return errors.WithStack(err)
   289  		}
   290  		f.processResults(s, e)
   291  
   292  	case *ast.IfStmt:
   293  
   294  		// else if block
   295  		falseExpr = append(falseExpr, stmt.Cond)
   296  		if err := f.inspectIf(e, falseExpr...); err != nil {
   297  			return errors.WithStack(err)
   298  		}
   299  	}
   300  	return nil
   301  }
   302  
   303  func (f *FileMap) processResults(s *brenda.Solver, block *ast.BlockStmt) {
   304  	for expr, match := range s.Components {
   305  		if !match.Match && !match.Inverse {
   306  			continue
   307  		}
   308  
   309  		found, op, expr := f.isErrorComparison(expr)
   310  		if !found {
   311  			continue
   312  		}
   313  		if op == token.NEQ && match.Match || op == token.EQL && match.Inverse {
   314  			ast.Inspect(block, f.inspectNodeForReturn(expr))
   315  			ast.Inspect(block, f.inspectNodeForWrap(block, expr))
   316  		}
   317  	}
   318  }
   319  
   320  func (f *FileMap) isErrorComparison(e ast.Expr) (found bool, sign token.Token, expr ast.Expr) {
   321  	if b, ok := e.(*ast.BinaryExpr); ok {
   322  		if b.Op != token.NEQ && b.Op != token.EQL {
   323  			return
   324  		}
   325  		xErr := f.isError(b.X)
   326  		yNil := f.isNil(b.Y)
   327  
   328  		if xErr && yNil {
   329  			return true, b.Op, b.X
   330  		}
   331  		yErr := f.isError(b.Y)
   332  		xNil := f.isNil(b.X)
   333  		if yErr && xNil {
   334  			return true, b.Op, b.Y
   335  		}
   336  	}
   337  	return
   338  }
   339  
   340  func (f *FileMap) inspectNodeForReturn(search ast.Expr) func(node ast.Node) bool {
   341  	return func(node ast.Node) bool {
   342  		if node == nil {
   343  			return true
   344  		}
   345  		switch n := node.(type) {
   346  		case *ast.ReturnStmt:
   347  			if f.isErrorReturn(n, search) {
   348  				pos := f.fset.Position(n.Pos())
   349  				f.addExclude(pos.Filename, pos.Line)
   350  			}
   351  		}
   352  		return true
   353  	}
   354  }
   355  
   356  func (f *FileMap) inspectNodeForWrap(block *ast.BlockStmt, search ast.Expr) func(node ast.Node) bool {
   357  	return func(node ast.Node) bool {
   358  		if node == nil {
   359  			return true
   360  		}
   361  		switch n := node.(type) {
   362  		case *ast.DeclStmt:
   363  			// covers the case:
   364  			// var e = foo()
   365  			// and
   366  			// var e error = foo()
   367  			gd, ok := n.Decl.(*ast.GenDecl)
   368  			if !ok {
   369  				return true
   370  			}
   371  			if gd.Tok != token.VAR {
   372  				return true
   373  			}
   374  			if len(gd.Specs) != 1 {
   375  				return true
   376  			}
   377  			spec, ok := gd.Specs[0].(*ast.ValueSpec)
   378  			if !ok {
   379  				return true
   380  			}
   381  			if len(spec.Names) != 1 || len(spec.Values) != 1 {
   382  				return true
   383  			}
   384  			newSearch := spec.Names[0]
   385  
   386  			if f.isErrorCall(spec.Values[0], search) {
   387  				ast.Inspect(block, f.inspectNodeForReturn(newSearch))
   388  			}
   389  
   390  		case *ast.AssignStmt:
   391  			if len(n.Lhs) != 1 || len(n.Rhs) != 1 {
   392  				return true
   393  			}
   394  			newSearch := n.Lhs[0]
   395  
   396  			if f.isErrorCall(n.Rhs[0], search) {
   397  				ast.Inspect(block, f.inspectNodeForReturn(newSearch))
   398  			}
   399  		}
   400  		return true
   401  	}
   402  }
   403  
   404  func (f *FileMap) isErrorCall(expr, search ast.Expr) bool {
   405  	n, ok := expr.(*ast.CallExpr)
   406  	if !ok {
   407  		return false
   408  	}
   409  	if !f.isError(n) {
   410  		// never gets here, but leave it in for completeness
   411  		return false
   412  	}
   413  	for _, arg := range n.Args {
   414  		if f.matcher.Match(arg, search) {
   415  			return true
   416  		}
   417  	}
   418  	return false
   419  }
   420  
   421  func (f *FileMap) isErrorReturnNamedResultParameters(r *ast.ReturnStmt, search ast.Expr) bool {
   422  	// covers the syntax:
   423  	// func a() (err error) {
   424  	// 	if err != nil {
   425  	// 		return
   426  	// 	}
   427  	// }
   428  	scope := f.findScope(r, func(n ast.Node) bool {
   429  		switch n.(type) {
   430  		case *ast.FuncDecl, *ast.FuncLit:
   431  			return true
   432  		}
   433  		return false
   434  	})
   435  	var t *ast.FuncType
   436  	switch s := scope.(type) {
   437  	case *ast.FuncDecl:
   438  		t = s.Type
   439  	case *ast.FuncLit:
   440  		t = s.Type
   441  	}
   442  	if t.Results == nil {
   443  		return false
   444  	}
   445  	last := t.Results.List[len(t.Results.List)-1]
   446  	if last.Names == nil {
   447  		// anonymous returns - shouldn't be able to get here because a bare
   448  		// return statement with either have zero results or named results.
   449  		return false
   450  	}
   451  	id := last.Names[len(last.Names)-1]
   452  	return f.matcher.Match(id, search)
   453  }
   454  
   455  func (f *FileMap) isErrorReturn(r *ast.ReturnStmt, search ast.Expr) bool {
   456  	if len(r.Results) == 0 {
   457  		return f.isErrorReturnNamedResultParameters(r, search)
   458  	}
   459  
   460  	last := r.Results[len(r.Results)-1]
   461  
   462  	// check the last result is an error
   463  	if !f.isError(last) {
   464  		return false
   465  	}
   466  
   467  	// check all the other results are nil or zero
   468  	for i, v := range r.Results {
   469  		if i == len(r.Results)-1 {
   470  			// ignore the last item
   471  			break
   472  		}
   473  		if !f.isZero(v) {
   474  			return false
   475  		}
   476  	}
   477  
   478  	return f.matcher.Match(last, search) || f.isErrorCall(last, search)
   479  }
   480  
   481  func (f *FileMap) isError(v ast.Expr) bool {
   482  	if n, ok := f.pkg.TypesInfo.TypeOf(v).(*types.Named); ok {
   483  		o := n.Obj()
   484  		return o != nil && o.Pkg() == nil && o.Name() == "error"
   485  	}
   486  	return false
   487  }
   488  
   489  func (f *FileMap) isNil(v ast.Expr) bool {
   490  	t := f.pkg.TypesInfo.Types[v]
   491  	return t.IsNil()
   492  }
   493  
   494  func (f *FileMap) isZero(v ast.Expr) bool {
   495  	t := f.pkg.TypesInfo.Types[v]
   496  	if t.IsNil() {
   497  		return true
   498  	}
   499  	if t.Value != nil {
   500  		// constant
   501  		switch t.Value.Kind() {
   502  		case constant.Bool:
   503  			return constant.BoolVal(t.Value) == false
   504  		case constant.String:
   505  			return constant.StringVal(t.Value) == ""
   506  		case constant.Int, constant.Float, constant.Complex:
   507  			return constant.Sign(t.Value) == 0
   508  		default:
   509  			return false
   510  		}
   511  	}
   512  	if t.IsValue() {
   513  		if cl, ok := v.(*ast.CompositeLit); ok {
   514  			for _, e := range cl.Elts {
   515  				if kve, ok := e.(*ast.KeyValueExpr); ok {
   516  					e = kve.Value
   517  				}
   518  				if !f.isZero(e) {
   519  					return false
   520  				}
   521  			}
   522  		}
   523  	}
   524  	return true
   525  }