github.com/lixvbnet/courtney@v0.0.0-20221025031132-0dcb02231211/scanner/scanner.go (about)

     1  package scanner
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/constant"
     7  	"go/token"
     8  	"go/types"
     9  	"strings"
    10  
    11  	"github.com/dave/astrid"
    12  	"github.com/dave/brenda"
    13  	"github.com/lixvbnet/courtney/shared"
    14  	"github.com/pkg/errors"
    15  	"golang.org/x/tools/go/packages"
    16  )
    17  
// CodeMap scans a number of packages for code to exclude from coverage.
//
// Excludes maps a source file name (as reported by token.Position.Filename)
// to the set of line numbers that should be excluded in that file.
type CodeMap struct {
	// setup is the configuration passed to New (packages and environment).
	setup *shared.Setup
	// pkgs holds the packages loaded by LoadProgram.
	pkgs     []*packages.Package
	Excludes map[string]map[int]bool
}
    24  
// PackageMap scans a single package for code to exclude. It embeds the
// parent CodeMap, so excludes recorded here land in the shared Excludes map.
type PackageMap struct {
	*CodeMap
	pkg *packages.Package
	// fset is the package's file set (set from pkg.Fset by ScanPackages),
	// used to convert token positions into file name / line positions.
	fset *token.FileSet
}
    31  
// FileMap scans a single file for code to exclude. It embeds the PackageMap
// for access to the package's type information, file set and the shared
// Excludes map.
type FileMap struct {
	*PackageMap
	file *ast.File
	// matcher is built from the package's TypesInfo.Uses and Defs (see
	// ScanPackage) and is used to test whether two expressions match.
	matcher *astrid.Matcher
}
    38  
// packageId pairs a package import path with its package name.
// NOTE(review): not referenced anywhere in this file; confirm whether other
// files in the package use it before removing.
type packageId struct {
	path string
	name string
}
    43  
    44  // New returns a CoseMap with the provided setup
    45  func New(setup *shared.Setup) *CodeMap {
    46  	return &CodeMap{
    47  		setup:    setup,
    48  		Excludes: make(map[string]map[int]bool),
    49  	}
    50  }
    51  
    52  func (c *CodeMap) addExclude(fpath string, line int) {
    53  	if c.Excludes[fpath] == nil {
    54  		c.Excludes[fpath] = make(map[int]bool)
    55  	}
    56  	c.Excludes[fpath][line] = true
    57  }
    58  
    59  // LoadProgram uses the loader package to load and process the source for a
    60  // number or packages.
    61  func (c *CodeMap) LoadProgram() error {
    62  	var patterns []string
    63  	for _, p := range c.setup.Packages {
    64  		patterns = append(patterns, p.Path)
    65  	}
    66  	wd, err := c.setup.Env.Getwd()
    67  	if err != nil {
    68  		return errors.WithStack(err)
    69  	}
    70  
    71  	cfg := &packages.Config{
    72  		Dir: wd,
    73  		Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles |
    74  			packages.NeedImports | packages.NeedTypes | packages.NeedTypesSizes |
    75  			packages.NeedSyntax | packages.NeedTypesInfo,
    76  		Env: c.setup.Env.Environ(),
    77  	}
    78  
    79  	// add a recover to catch a panic and add some context to the error
    80  	defer func() {
    81  		if panicErr := recover(); panicErr != nil {
    82  			panic(fmt.Sprintf("%+v", errors.Errorf("Panic in packages.Load: %s", panicErr)))
    83  		}
    84  	}()
    85  
    86  	pkgs, err := packages.Load(cfg, patterns...)
    87  	/*
    88  		ctxt := build.Default
    89  		ctxt.GOPATH = c.setup.Env.Getenv("GOPATH")
    90  
    91  		conf := loader.Config{Build: &ctxt, Cwd: wd, ParserMode: parser.ParseComments}
    92  
    93  		for _, p := range c.setup.Packages {
    94  			conf.Import(p.Path)
    95  		}
    96  		prog, err := conf.Load()
    97  	*/
    98  	if err != nil {
    99  		return errors.Wrap(err, "Error loading config")
   100  	}
   101  	c.pkgs = pkgs
   102  	return nil
   103  }
   104  
   105  // ScanPackages scans the imported packages
   106  func (c *CodeMap) ScanPackages() error {
   107  	for _, p := range c.pkgs {
   108  		pm := &PackageMap{
   109  			CodeMap: c,
   110  			pkg:     p,
   111  			fset:    p.Fset,
   112  		}
   113  		if err := pm.ScanPackage(); err != nil {
   114  			return errors.WithStack(err)
   115  		}
   116  	}
   117  	return nil
   118  }
   119  
   120  // ScanPackage scans a single package
   121  func (p *PackageMap) ScanPackage() error {
   122  	for _, f := range p.pkg.Syntax {
   123  
   124  		fm := &FileMap{
   125  			PackageMap: p,
   126  			file:       f,
   127  			matcher:    astrid.NewMatcher(p.pkg.TypesInfo.Uses, p.pkg.TypesInfo.Defs),
   128  		}
   129  		if err := fm.FindExcludes(); err != nil {
   130  			return errors.WithStack(err)
   131  		}
   132  	}
   133  	return nil
   134  }
   135  
   136  // FindExcludes scans a single file to find code to exclude from coverage files
   137  func (f *FileMap) FindExcludes() error {
   138  	var err error
   139  
   140  	ast.Inspect(f.file, func(node ast.Node) bool {
   141  		if err != nil {
   142  			// notest
   143  			return false
   144  		}
   145  		b, inner := f.inspectNode(node)
   146  		if inner != nil {
   147  			// notest
   148  			err = inner
   149  			return false
   150  		}
   151  		return b
   152  	})
   153  	if err != nil {
   154  		return errors.WithStack(err)
   155  	}
   156  	for _, cg := range f.file.Comments {
   157  		f.inspectComment(cg)
   158  	}
   159  	return nil
   160  }
   161  
   162  func (f *FileMap) findScope(node ast.Node, filter func(ast.Node) bool) ast.Node {
   163  	inside := func(node, holder ast.Node) bool {
   164  		return node != nil && holder != nil && node.Pos() > holder.Pos() && node.Pos() <= holder.End()
   165  	}
   166  	var scopes []ast.Node
   167  	ast.Inspect(f.file, func(scope ast.Node) bool {
   168  		if inside(node, scope) {
   169  			scopes = append(scopes, scope)
   170  			return true
   171  		}
   172  		return false
   173  	})
   174  	// find the last matching scope
   175  	for i := len(scopes) - 1; i >= 0; i-- {
   176  		if filter == nil || filter(scopes[i]) {
   177  			return scopes[i]
   178  		}
   179  	}
   180  	// notest
   181  	return nil
   182  }
   183  
   184  func (f *FileMap) inspectComment(cg *ast.CommentGroup) {
   185  	for _, cm := range cg.List {
   186  		if !strings.HasPrefix(cm.Text, "//notest") && !strings.HasPrefix(cm.Text, "// notest") {
   187  			continue
   188  		}
   189  
   190  		// get the parent scope
   191  		scope := f.findScope(cm, nil)
   192  
   193  		// scope can be nil if the comment is in an empty file... in that
   194  		// case we don't need any excludes.
   195  		if scope != nil {
   196  			comment := f.fset.Position(cm.Pos())
   197  			start := f.fset.Position(scope.Pos())
   198  			end := f.fset.Position(scope.End())
   199  			endLine := end.Line
   200  			if _, ok := scope.(*ast.CaseClause); ok {
   201  				// case block needs an extra line...
   202  				endLine++
   203  			}
   204  			for line := comment.Line; line < endLine; line++ {
   205  				f.addExclude(start.Filename, line)
   206  			}
   207  		}
   208  	}
   209  }
   210  
   211  func (f *FileMap) inspectNode(node ast.Node) (bool, error) {
   212  	if node == nil {
   213  		return true, nil
   214  	}
   215  	switch n := node.(type) {
   216  	case *ast.CallExpr:
   217  		if id, ok := n.Fun.(*ast.Ident); ok && id.Name == "panic" {
   218  			pos := f.fset.Position(n.Pos())
   219  			f.addExclude(pos.Filename, pos.Line)
   220  		}
   221  	case *ast.IfStmt:
   222  		if err := f.inspectIf(n); err != nil {
   223  			return false, err
   224  		}
   225  	case *ast.SwitchStmt:
   226  		if n.Tag != nil {
   227  			// we are only concerned with switch statements with no tag
   228  			// expression e.g. switch { ... }
   229  			return true, nil
   230  		}
   231  		var falseExpr []ast.Expr
   232  		var defaultClause *ast.CaseClause
   233  		for _, s := range n.Body.List {
   234  			cc := s.(*ast.CaseClause)
   235  			if cc.List == nil {
   236  				// save the default clause until the end
   237  				defaultClause = cc
   238  				continue
   239  			}
   240  			if err := f.inspectCase(cc, falseExpr...); err != nil {
   241  				return false, err
   242  			}
   243  			falseExpr = append(falseExpr, f.boolOr(cc.List))
   244  		}
   245  		if defaultClause != nil {
   246  			if err := f.inspectCase(defaultClause, falseExpr...); err != nil {
   247  				return false, err
   248  			}
   249  		}
   250  	}
   251  	return true, nil
   252  }
   253  
   254  func (f *FileMap) inspectCase(stmt *ast.CaseClause, falseExpr ...ast.Expr) error {
   255  	s := brenda.NewSolver(f.fset, f.pkg.TypesInfo.Uses, f.pkg.TypesInfo.Defs, f.boolOr(stmt.List), falseExpr...)
   256  	if err := s.SolveTrue(); err != nil {
   257  		return errors.WithStack(err)
   258  	}
   259  	f.processResults(s, &ast.BlockStmt{List: stmt.Body})
   260  	return nil
   261  }
   262  
   263  func (f *FileMap) boolOr(list []ast.Expr) ast.Expr {
   264  	if len(list) == 0 {
   265  		return nil
   266  	}
   267  	if len(list) == 1 {
   268  		return list[0]
   269  	}
   270  	current := list[0]
   271  	for i := 1; i < len(list); i++ {
   272  		current = &ast.BinaryExpr{X: current, Y: list[i], Op: token.LOR}
   273  	}
   274  	return current
   275  }
   276  
   277  func (f *FileMap) inspectIf(stmt *ast.IfStmt, falseExpr ...ast.Expr) error {
   278  
   279  	// main if block
   280  	s := brenda.NewSolver(f.fset, f.pkg.TypesInfo.Uses, f.pkg.TypesInfo.Defs, stmt.Cond, falseExpr...)
   281  	if err := s.SolveTrue(); err != nil {
   282  		return errors.WithStack(err)
   283  	}
   284  	f.processResults(s, stmt.Body)
   285  
   286  	switch e := stmt.Else.(type) {
   287  	case *ast.BlockStmt:
   288  
   289  		// else block
   290  		s := brenda.NewSolver(f.fset, f.pkg.TypesInfo.Uses, f.pkg.TypesInfo.Defs, stmt.Cond, falseExpr...)
   291  		if err := s.SolveFalse(); err != nil {
   292  			return errors.WithStack(err)
   293  		}
   294  		f.processResults(s, e)
   295  
   296  	case *ast.IfStmt:
   297  
   298  		// else if block
   299  		falseExpr = append(falseExpr, stmt.Cond)
   300  		if err := f.inspectIf(e, falseExpr...); err != nil {
   301  			return errors.WithStack(err)
   302  		}
   303  	}
   304  	return nil
   305  }
   306  
   307  func (f *FileMap) processResults(s *brenda.Solver, block *ast.BlockStmt) {
   308  	for expr, match := range s.Components {
   309  		if !match.Match && !match.Inverse {
   310  			continue
   311  		}
   312  
   313  		found, op, expr := f.isErrorComparison(expr)
   314  		if !found {
   315  			continue
   316  		}
   317  		if op == token.NEQ && match.Match || op == token.EQL && match.Inverse {
   318  			ast.Inspect(block, f.inspectNodeForReturn(expr))
   319  			ast.Inspect(block, f.inspectNodeForWrap(block, expr))
   320  		}
   321  	}
   322  }
   323  
   324  func (f *FileMap) isErrorComparison(e ast.Expr) (found bool, sign token.Token, expr ast.Expr) {
   325  	if b, ok := e.(*ast.BinaryExpr); ok {
   326  		if b.Op != token.NEQ && b.Op != token.EQL {
   327  			return
   328  		}
   329  		xErr := f.isError(b.X)
   330  		yNil := f.isNil(b.Y)
   331  
   332  		if xErr && yNil {
   333  			return true, b.Op, b.X
   334  		}
   335  		yErr := f.isError(b.Y)
   336  		xNil := f.isNil(b.X)
   337  		if yErr && xNil {
   338  			return true, b.Op, b.Y
   339  		}
   340  	}
   341  	return
   342  }
   343  
   344  func (f *FileMap) inspectNodeForReturn(search ast.Expr) func(node ast.Node) bool {
   345  	return func(node ast.Node) bool {
   346  		if node == nil {
   347  			return true
   348  		}
   349  		switch n := node.(type) {
   350  		case *ast.ReturnStmt:
   351  			if f.isErrorReturn(n, search) {
   352  				pos := f.fset.Position(n.Pos())
   353  				f.addExclude(pos.Filename, pos.Line)
   354  			}
   355  		}
   356  		return true
   357  	}
   358  }
   359  
// inspectNodeForWrap returns an ast.Inspect visitor that follows the error
// search through one level of wrapping inside block. It looks for either a
// single-name var declaration with one value, or a single-target assignment.
// If the right-hand side is an error-typed call that takes search as an
// argument (see isErrorCall), the whole block is re-scanned and returns of
// the newly bound name are excluded. Declarations and assignments of any
// other shape are skipped. The visitor always descends into child nodes.
func (f *FileMap) inspectNodeForWrap(block *ast.BlockStmt, search ast.Expr) func(node ast.Node) bool {
	return func(node ast.Node) bool {
		if node == nil {
			return true
		}
		switch n := node.(type) {
		case *ast.DeclStmt:
			// covers the case:
			// var e = foo()
			// and
			// var e error = foo()
			gd, ok := n.Decl.(*ast.GenDecl)
			if !ok {
				// notest
				return true
			}
			if gd.Tok != token.VAR {
				// notest
				return true
			}
			if len(gd.Specs) != 1 {
				// notest
				return true
			}
			spec, ok := gd.Specs[0].(*ast.ValueSpec)
			if !ok {
				// notest
				return true
			}
			if len(spec.Names) != 1 || len(spec.Values) != 1 {
				// notest
				return true
			}
			// The declared name becomes the error to look for in returns.
			newSearch := spec.Names[0]

			if f.isErrorCall(spec.Values[0], search) {
				ast.Inspect(block, f.inspectNodeForReturn(newSearch))
			}

		case *ast.AssignStmt:
			// covers e.g. `e := wrap(err)` and `e = wrap(err)` with exactly
			// one target and one value.
			if len(n.Lhs) != 1 || len(n.Rhs) != 1 {
				// notest
				return true
			}
			newSearch := n.Lhs[0]

			if f.isErrorCall(n.Rhs[0], search) {
				ast.Inspect(block, f.inspectNodeForReturn(newSearch))
			}
		}
		return true
	}
}
   413  
   414  func (f *FileMap) isErrorCall(expr, search ast.Expr) bool {
   415  	n, ok := expr.(*ast.CallExpr)
   416  	if !ok {
   417  		return false
   418  	}
   419  	if !f.isError(n) {
   420  		// never gets here, but leave it in for completeness
   421  		// notest
   422  		return false
   423  	}
   424  	for _, arg := range n.Args {
   425  		if f.matcher.Match(arg, search) {
   426  			return true
   427  		}
   428  	}
   429  	return false
   430  }
   431  
   432  func (f *FileMap) isErrorReturnNamedResultParameters(r *ast.ReturnStmt, search ast.Expr) bool {
   433  	// covers the syntax:
   434  	// func a() (err error) {
   435  	// 	if err != nil {
   436  	// 		return
   437  	// 	}
   438  	// }
   439  	scope := f.findScope(r, func(n ast.Node) bool {
   440  		switch n.(type) {
   441  		case *ast.FuncDecl, *ast.FuncLit:
   442  			return true
   443  		}
   444  		return false
   445  	})
   446  	var t *ast.FuncType
   447  	switch s := scope.(type) {
   448  	case *ast.FuncDecl:
   449  		t = s.Type
   450  	case *ast.FuncLit:
   451  		t = s.Type
   452  	}
   453  	if t.Results == nil {
   454  		return false
   455  	}
   456  	last := t.Results.List[len(t.Results.List)-1]
   457  	if last.Names == nil {
   458  		// anonymous returns - shouldn't be able to get here because a bare
   459  		// return statement with either have zero results or named results.
   460  		// notest
   461  		return false
   462  	}
   463  	id := last.Names[len(last.Names)-1]
   464  	return f.matcher.Match(id, search)
   465  }
   466  
   467  func (f *FileMap) isErrorReturn(r *ast.ReturnStmt, search ast.Expr) bool {
   468  	if len(r.Results) == 0 {
   469  		return f.isErrorReturnNamedResultParameters(r, search)
   470  	}
   471  
   472  	last := r.Results[len(r.Results)-1]
   473  
   474  	// check the last result is an error
   475  	if !f.isError(last) {
   476  		return false
   477  	}
   478  
   479  	// check all the other results are nil or zero
   480  	for i, v := range r.Results {
   481  		if i == len(r.Results)-1 {
   482  			// ignore the last item
   483  			break
   484  		}
   485  		if !f.isZero(v) {
   486  			return false
   487  		}
   488  	}
   489  
   490  	return f.matcher.Match(last, search) || f.isErrorCall(last, search)
   491  }
   492  
   493  func (f *FileMap) isError(v ast.Expr) bool {
   494  	if n, ok := f.pkg.TypesInfo.TypeOf(v).(*types.Named); ok {
   495  		o := n.Obj()
   496  		return o != nil && o.Pkg() == nil && o.Name() == "error"
   497  	}
   498  	return false
   499  }
   500  
   501  func (f *FileMap) isNil(v ast.Expr) bool {
   502  	t := f.pkg.TypesInfo.Types[v]
   503  	return t.IsNil()
   504  }
   505  
   506  func (f *FileMap) isZero(v ast.Expr) bool {
   507  	t := f.pkg.TypesInfo.Types[v]
   508  	if t.IsNil() {
   509  		return true
   510  	}
   511  	if t.Value != nil {
   512  		// constant
   513  		switch t.Value.Kind() {
   514  		case constant.Bool:
   515  			return constant.BoolVal(t.Value) == false
   516  		case constant.String:
   517  			return constant.StringVal(t.Value) == ""
   518  		case constant.Int, constant.Float, constant.Complex:
   519  			return constant.Sign(t.Value) == 0
   520  		default:
   521  			// notest
   522  			return false
   523  		}
   524  	}
   525  	if t.IsValue() {
   526  		if cl, ok := v.(*ast.CompositeLit); ok {
   527  			for _, e := range cl.Elts {
   528  				if kve, ok := e.(*ast.KeyValueExpr); ok {
   529  					e = kve.Value
   530  				}
   531  				if !f.isZero(e) {
   532  					return false
   533  				}
   534  			}
   535  		}
   536  	}
   537  	return true
   538  }