gotest.tools/gotestsum@v1.11.0/cmd/tool/slowest/ast.go (about)

     1  package slowest
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/format"
     7  	"go/parser"
     8  	"go/token"
     9  	"os"
    10  	"strings"
    11  
    12  	"golang.org/x/tools/go/packages"
    13  	"gotest.tools/gotestsum/internal/log"
    14  	"gotest.tools/gotestsum/testjson"
    15  )
    16  
    17  func writeTestSkip(tcs []testjson.TestCase, skipStmt ast.Stmt) error {
    18  	fset := token.NewFileSet()
    19  	cfg := packages.Config{
    20  		Mode:       modeAll(),
    21  		Tests:      true,
    22  		Fset:       fset,
    23  		BuildFlags: buildFlags(),
    24  	}
    25  	pkgNames, index := testNamesByPkgName(tcs)
    26  	pkgs, err := packages.Load(&cfg, pkgNames...)
    27  	if err != nil {
    28  		return fmt.Errorf("failed to load packages: %v", err)
    29  	}
    30  
    31  	for _, pkg := range pkgs {
    32  		if len(pkg.Errors) > 0 {
    33  			return errPkgLoad(pkg)
    34  		}
    35  		tcs, ok := index[normalizePkgName(pkg.PkgPath)]
    36  		if !ok {
    37  			log.Debugf("skipping %v, no slow tests", pkg.PkgPath)
    38  			continue
    39  		}
    40  
    41  		log.Debugf("rewriting %v for %d test cases", pkg.PkgPath, len(tcs))
    42  		for _, file := range pkg.Syntax {
    43  			path := fset.File(file.Pos()).Name()
    44  			log.Debugf("looking for test cases in: %v", path)
    45  			if !rewriteAST(file, tcs, skipStmt) {
    46  				continue
    47  			}
    48  			if err := writeFile(path, file, fset); err != nil {
    49  				return fmt.Errorf("failed to write ast to file %v: %v", path, err)
    50  			}
    51  		}
    52  	}
    53  	return errTestCasesNotFound(index)
    54  }
    55  
    56  // normalizePkgName removes the _test suffix from a package name. External test
    57  // packages (those named package_test) may contain tests, but the test2json output
    58  // always uses the non-external package name. The _test suffix must be removed
    59  // so that any slow tests in an external test package can be found.
    60  func normalizePkgName(name string) string {
    61  	return strings.TrimSuffix(name, "_test")
    62  }
    63  
    64  func writeFile(path string, file *ast.File, fset *token.FileSet) error {
    65  	fh, err := os.Create(path)
    66  	if err != nil {
    67  		return err
    68  	}
    69  	defer func() {
    70  		if err := fh.Close(); err != nil {
    71  			log.Errorf("Failed to close file %v: %v", path, err)
    72  		}
    73  	}()
    74  	return format.Node(fh, fset, file)
    75  }
    76  
    77  func parseSkipStatement(text string) (ast.Stmt, error) {
    78  	switch text {
    79  	case "default", "testing.Short":
    80  		text = `
    81  	if testing.Short() {
    82  		t.Skip("too slow for testing.Short")
    83  	}
    84  `
    85  	}
    86  	// Add some required boilerplate around the statement to make it a valid file
    87  	text = "package stub\nfunc Stub() {\n" + text + "\n}\n"
    88  	file, err := parser.ParseFile(token.NewFileSet(), "fragment", text, 0)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	stmt := file.Decls[0].(*ast.FuncDecl).Body.List[0]
    93  	return stmt, nil
    94  }
    95  
    96  func rewriteAST(file *ast.File, testNames set, skipStmt ast.Stmt) bool {
    97  	var modified bool
    98  	for _, decl := range file.Decls {
    99  		fd, ok := decl.(*ast.FuncDecl)
   100  		if !ok {
   101  			continue
   102  		}
   103  		name := fd.Name.Name // TODO: can this be nil?
   104  		if _, ok := testNames[name]; !ok {
   105  			continue
   106  		}
   107  
   108  		fd.Body.List = append([]ast.Stmt{skipStmt}, fd.Body.List...)
   109  		modified = true
   110  		delete(testNames, name)
   111  	}
   112  	return modified
   113  }
   114  
   115  type set map[string]struct{}
   116  
   117  // testNamesByPkgName strips subtest names from test names, then builds
   118  // and returns a slice of all the packages names, and a mapping of package name
   119  // to set of failed tests in that package.
   120  //
   121  // subtests are removed because the AST lookup currently only works for top-level
   122  // functions, not t.Run subtests.
   123  func testNamesByPkgName(tcs []testjson.TestCase) ([]string, map[string]set) {
   124  	var pkgs []string
   125  	index := make(map[string]set)
   126  	for _, tc := range tcs {
   127  		testName := tc.Test.Name()
   128  		if tc.Test.IsSubTest() {
   129  			root, _ := tc.Test.Split()
   130  			testName = root
   131  		}
   132  		if len(index[tc.Package]) == 0 {
   133  			pkgs = append(pkgs, tc.Package)
   134  			index[tc.Package] = make(map[string]struct{})
   135  		}
   136  		index[tc.Package][testName] = struct{}{}
   137  	}
   138  	return pkgs, index
   139  }
   140  
   141  func errPkgLoad(pkg *packages.Package) error {
   142  	buf := new(strings.Builder)
   143  	for _, err := range pkg.Errors {
   144  		buf.WriteString("\n" + err.Error())
   145  	}
   146  	return fmt.Errorf("failed to load package %v %v", pkg.PkgPath, buf.String())
   147  }
   148  
   149  func errTestCasesNotFound(index map[string]set) error {
   150  	var missed []string
   151  	for pkg, tcs := range index {
   152  		for tc := range tcs {
   153  			missed = append(missed, fmt.Sprintf("%v.%v", pkg, tc))
   154  		}
   155  	}
   156  	if len(missed) == 0 {
   157  		return nil
   158  	}
   159  	return fmt.Errorf("failed to find source for test cases:\n%v", strings.Join(missed, "\n"))
   160  }
   161  
   162  func modeAll() packages.LoadMode {
   163  	mode := packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles
   164  	mode = mode | packages.NeedImports | packages.NeedDeps
   165  	mode = mode | packages.NeedTypes | packages.NeedTypesSizes
   166  	mode = mode | packages.NeedSyntax | packages.NeedTypesInfo
   167  	return mode
   168  }
   169  
   170  func buildFlags() []string {
   171  	flags := os.Getenv("GOFLAGS")
   172  	if len(flags) == 0 {
   173  		return nil
   174  	}
   175  	return strings.Split(os.Getenv("GOFLAGS"), " ")
   176  }