gotest.tools/gotestsum@v1.11.0/cmd/tool/slowest/ast.go (about) 1 package slowest 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/format" 7 "go/parser" 8 "go/token" 9 "os" 10 "strings" 11 12 "golang.org/x/tools/go/packages" 13 "gotest.tools/gotestsum/internal/log" 14 "gotest.tools/gotestsum/testjson" 15 ) 16 17 func writeTestSkip(tcs []testjson.TestCase, skipStmt ast.Stmt) error { 18 fset := token.NewFileSet() 19 cfg := packages.Config{ 20 Mode: modeAll(), 21 Tests: true, 22 Fset: fset, 23 BuildFlags: buildFlags(), 24 } 25 pkgNames, index := testNamesByPkgName(tcs) 26 pkgs, err := packages.Load(&cfg, pkgNames...) 27 if err != nil { 28 return fmt.Errorf("failed to load packages: %v", err) 29 } 30 31 for _, pkg := range pkgs { 32 if len(pkg.Errors) > 0 { 33 return errPkgLoad(pkg) 34 } 35 tcs, ok := index[normalizePkgName(pkg.PkgPath)] 36 if !ok { 37 log.Debugf("skipping %v, no slow tests", pkg.PkgPath) 38 continue 39 } 40 41 log.Debugf("rewriting %v for %d test cases", pkg.PkgPath, len(tcs)) 42 for _, file := range pkg.Syntax { 43 path := fset.File(file.Pos()).Name() 44 log.Debugf("looking for test cases in: %v", path) 45 if !rewriteAST(file, tcs, skipStmt) { 46 continue 47 } 48 if err := writeFile(path, file, fset); err != nil { 49 return fmt.Errorf("failed to write ast to file %v: %v", path, err) 50 } 51 } 52 } 53 return errTestCasesNotFound(index) 54 } 55 56 // normalizePkgName removes the _test suffix from a package name. External test 57 // packages (those named package_test) may contain tests, but the test2json output 58 // always uses the non-external package name. The _test suffix must be removed 59 // so that any slow tests in an external test package can be found. 
func normalizePkgName(name string) string {
	return strings.TrimSuffix(name, "_test")
}

// writeFile formats file and overwrites path with the result.
//
// The source is rendered in memory before path is opened, so a formatting
// failure can not leave the file on disk truncated. An error from closing the
// file is returned when no earlier error occurred, because a failed close can
// mean the data was not written.
func writeFile(path string, file *ast.File, fset *token.FileSet) (err error) {
	var buf strings.Builder
	if err := format.Node(&buf, fset, file); err != nil {
		return err
	}

	fh, err := os.Create(path)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := fh.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()
	_, err = fh.WriteString(buf.String())
	return err
}

// parseSkipStatement parses text as a Go statement and returns its AST.
//
// The values "default" and "testing.Short" are shorthand for a statement that
// calls t.Skip when testing.Short() is true. Any other value is parsed as Go
// source. Only the first statement in text is returned. An error is returned
// when text is not valid Go or does not contain a statement.
func parseSkipStatement(text string) (ast.Stmt, error) {
	switch text {
	case "default", "testing.Short":
		text = `
	if testing.Short() {
		t.Skip("too slow for testing.Short")
	}
`
	}
	// Add some required boilerplate around the statement to make it a valid file
	src := "package stub\nfunc Stub() {\n" + text + "\n}\n"
	file, err := parser.ParseFile(token.NewFileSet(), "fragment", src, 0)
	if err != nil {
		return nil, err
	}
	if len(file.Decls) == 0 {
		return nil, fmt.Errorf("skip statement is empty: expected a Go statement")
	}
	stub, ok := file.Decls[0].(*ast.FuncDecl)
	// An empty or comment-only text parses successfully but leaves the stub
	// function without any statement to return.
	if !ok || stub.Body == nil || len(stub.Body.List) == 0 {
		return nil, fmt.Errorf("skip statement is empty: expected a Go statement")
	}
	return stub.Body.List[0], nil
}

// rewriteAST inserts skipStmt as the first statement of every top-level
// function in file whose name is in testNames. Each name that is rewritten is
// removed from testNames. It returns true when file was modified.
//
// Methods and declarations without a body are never rewritten: a method is
// not a test function even when it shares a name with one, and a declaration
// without a body has nowhere to insert the statement.
func rewriteAST(file *ast.File, testNames set, skipStmt ast.Stmt) bool {
	var modified bool
	for _, decl := range file.Decls {
		fd, ok := decl.(*ast.FuncDecl)
		if !ok {
			continue
		}
		if fd.Recv != nil || fd.Body == nil || fd.Name == nil {
			continue
		}
		name := fd.Name.Name
		if _, ok := testNames[name]; !ok {
			continue
		}

		fd.Body.List = append([]ast.Stmt{skipStmt}, fd.Body.List...)
		modified = true
		delete(testNames, name)
	}
	return modified
}

// set is a collection of unique strings.
type set map[string]struct{}

// testNamesByPkgName strips subtest names from test names, then builds
// and returns a slice of all the package names, and a mapping of package name
// to the set of test names in that package.
//
// Subtests are removed because the AST lookup currently only works for top-level
// functions, not t.Run subtests.
123 func testNamesByPkgName(tcs []testjson.TestCase) ([]string, map[string]set) { 124 var pkgs []string 125 index := make(map[string]set) 126 for _, tc := range tcs { 127 testName := tc.Test.Name() 128 if tc.Test.IsSubTest() { 129 root, _ := tc.Test.Split() 130 testName = root 131 } 132 if len(index[tc.Package]) == 0 { 133 pkgs = append(pkgs, tc.Package) 134 index[tc.Package] = make(map[string]struct{}) 135 } 136 index[tc.Package][testName] = struct{}{} 137 } 138 return pkgs, index 139 } 140 141 func errPkgLoad(pkg *packages.Package) error { 142 buf := new(strings.Builder) 143 for _, err := range pkg.Errors { 144 buf.WriteString("\n" + err.Error()) 145 } 146 return fmt.Errorf("failed to load package %v %v", pkg.PkgPath, buf.String()) 147 } 148 149 func errTestCasesNotFound(index map[string]set) error { 150 var missed []string 151 for pkg, tcs := range index { 152 for tc := range tcs { 153 missed = append(missed, fmt.Sprintf("%v.%v", pkg, tc)) 154 } 155 } 156 if len(missed) == 0 { 157 return nil 158 } 159 return fmt.Errorf("failed to find source for test cases:\n%v", strings.Join(missed, "\n")) 160 } 161 162 func modeAll() packages.LoadMode { 163 mode := packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles 164 mode = mode | packages.NeedImports | packages.NeedDeps 165 mode = mode | packages.NeedTypes | packages.NeedTypesSizes 166 mode = mode | packages.NeedSyntax | packages.NeedTypesInfo 167 return mode 168 } 169 170 func buildFlags() []string { 171 flags := os.Getenv("GOFLAGS") 172 if len(flags) == 0 { 173 return nil 174 } 175 return strings.Split(os.Getenv("GOFLAGS"), " ") 176 }