github.com/0xKiwi/rules_go@v0.24.3/go/tools/builders/generate_test_main.go (about) 1 /* Copyright 2016 The Bazel Authors. All rights reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 */ 15 16 // Bare bones Go testing support for Bazel. 17 18 package main 19 20 import ( 21 "flag" 22 "fmt" 23 "go/ast" 24 "go/doc" 25 "go/parser" 26 "go/token" 27 "os" 28 "path/filepath" 29 "sort" 30 "strings" 31 "text/template" 32 ) 33 34 type Import struct { 35 Name string 36 Path string 37 } 38 39 type TestCase struct { 40 Package string 41 Name string 42 } 43 44 type Example struct { 45 Package string 46 Name string 47 Output string 48 Unordered bool 49 } 50 51 // Cases holds template data. 52 type Cases struct { 53 RunDir string 54 Imports []*Import 55 Tests []TestCase 56 Benchmarks []TestCase 57 Examples []Example 58 TestMain string 59 Coverage bool 60 Pkgname string 61 } 62 63 const testMainTpl = ` 64 package main 65 import ( 66 "flag" 67 "log" 68 "os" 69 "os/exec" 70 "path/filepath" 71 "runtime" 72 "strconv" 73 "testing" 74 "testing/internal/testdeps" 75 76 {{if .Coverage}} 77 "github.com/bazelbuild/rules_go/go/tools/coverdata" 78 {{end}} 79 80 {{range $p := .Imports}} 81 {{$p.Name}} "{{$p.Path}}" 82 {{end}} 83 ) 84 85 var allTests = []testing.InternalTest{ 86 {{range .Tests}} 87 {"{{.Name}}", {{.Package}}.{{.Name}} }, 88 {{end}} 89 } 90 91 var benchmarks = []testing.InternalBenchmark{ 92 {{range .Benchmarks}} 93 {"{{.Name}}", {{.Package}}.{{.Name}} }, 94 {{end}} 95 } 96 97 var examples = []testing.InternalExample{ 98 {{range .Examples}} 99 {Name: "{{.Name}}", F: {{.Package}}.{{.Name}}, Output: {{printf "%q" .Output}}, Unordered: {{.Unordered}} }, 100 {{end}} 101 } 102 103 func testsInShard() []testing.InternalTest { 104 totalShards, err := strconv.Atoi(os.Getenv("TEST_TOTAL_SHARDS")) 105 if err != nil || totalShards <= 1 { 106 return allTests 107 } 108 shardIndex, err := strconv.Atoi(os.Getenv("TEST_SHARD_INDEX")) 109 if err != nil || shardIndex < 0 { 110 return allTests 111 } 112 tests := []testing.InternalTest{} 113 for i, t := range allTests { 114 if i % totalShards == shardIndex { 115 tests = append(tests, t) 116 } 117 } 118 return tests 119 } 120 121 func main() { 122 if shouldWrap() { 123 err := wrap("{{.Pkgname}}") 124 if xerr, ok := err.(*exec.ExitError); ok { 125 os.Exit(xerr.ExitCode()) 126 } else if err != nil { 127 log.Print(err) 128 os.Exit(testWrapperAbnormalExit) 129 } else { 130 os.Exit(0) 131 } 132 } 133 134 // Check if we're being run by Bazel and change directories if so. 135 // TEST_SRCDIR and TEST_WORKSPACE are set by the Bazel test runner, so that makes a decent proxy. 136 testSrcdir := os.Getenv("TEST_SRCDIR") 137 testWorkspace := os.Getenv("TEST_WORKSPACE") 138 if testSrcdir != "" && testWorkspace != "" { 139 abs := filepath.Join(testSrcdir, testWorkspace, {{printf "%q" .RunDir}}) 140 err := os.Chdir(abs) 141 // Ignore the Chdir err when on Windows, since it might have have runfiles symlinks. 142 // https://github.com/bazelbuild/rules_go/pull/1721#issuecomment-422145904 143 if err != nil && runtime.GOOS != "windows" { 144 log.Fatalf("could not change to test directory: %v", err) 145 } 146 if err == nil { 147 os.Setenv("PWD", abs) 148 } 149 } 150 151 m := testing.MainStart(testdeps.TestDeps{}, testsInShard(), benchmarks, examples) 152 153 if filter := os.Getenv("TESTBRIDGE_TEST_ONLY"); filter != "" { 154 flag.Lookup("test.run").Value.Set(filter) 155 } 156 157 {{if .Coverage}} 158 if len(coverdata.Cover.Counters) > 0 { 159 testing.RegisterCover(coverdata.Cover) 160 } 161 if coverageDat, ok := os.LookupEnv("COVERAGE_OUTPUT_FILE"); ok { 162 if testing.CoverMode() != "" { 163 flag.Lookup("test.coverprofile").Value.Set(coverageDat) 164 } 165 } 166 {{end}} 167 168 {{if not .TestMain}} 169 os.Exit(m.Run()) 170 {{else}} 171 {{.TestMain}}(m) 172 {{end}} 173 } 174 ` 175 176 func genTestMain(args []string) error { 177 // Prepare our flags 178 args, err := expandParamsFiles(args) 179 if err != nil { 180 return err 181 } 182 imports := multiFlag{} 183 sources := multiFlag{} 184 flags := flag.NewFlagSet("GoTestGenTest", flag.ExitOnError) 185 goenv := envFlags(flags) 186 runDir := flags.String("rundir", ".", "Path to directory where tests should run.") 187 out := flags.String("output", "", "output file to write. Defaults to stdout.") 188 coverage := flags.Bool("coverage", false, "whether coverage is supported") 189 pkgname := flags.String("pkgname", "", "package name of test") 190 flags.Var(&imports, "import", "Packages to import") 191 flags.Var(&sources, "src", "Sources to process for tests") 192 if err := flags.Parse(args); err != nil { 193 return err 194 } 195 if err := goenv.checkFlags(); err != nil { 196 return err 197 } 198 // Process import args 199 importMap := map[string]*Import{} 200 for _, imp := range imports { 201 parts := strings.Split(imp, "=") 202 if len(parts) != 2 { 203 return fmt.Errorf("Invalid import %q specified", imp) 204 } 205 i := &Import{Name: parts[0], Path: parts[1]} 206 importMap[i.Name] = i 207 } 208 // Process source args 209 sourceList := []string{} 210 sourceMap := map[string]string{} 211 for _, s := range sources { 212 parts := strings.Split(s, "=") 213 if len(parts) != 2 { 214 return fmt.Errorf("Invalid source %q specified", s) 215 } 216 sourceList = append(sourceList, parts[1]) 217 sourceMap[parts[1]] = parts[0] 218 } 219 220 // filter our input file list 221 filteredSrcs, err := filterAndSplitFiles(sourceList) 222 if err != nil { 223 return err 224 } 225 goSrcs := filteredSrcs.goSrcs 226 227 outFile := os.Stdout 228 if *out != "" { 229 var err error 230 outFile, err = os.Create(*out) 231 if err != nil { 232 return fmt.Errorf("os.Create(%q): %v", *out, err) 233 } 234 defer outFile.Close() 235 } 236 237 cases := Cases{ 238 RunDir: strings.Replace(filepath.FromSlash(*runDir), `\`, `\\`, -1), 239 Coverage: *coverage, 240 Pkgname: *pkgname, 241 } 242 243 testFileSet := token.NewFileSet() 244 pkgs := map[string]bool{} 245 for _, f := range goSrcs { 246 parse, err := parser.ParseFile(testFileSet, f.filename, nil, parser.ParseComments) 247 if err != nil { 248 return fmt.Errorf("ParseFile(%q): %v", f.filename, err) 249 } 250 pkg := sourceMap[f.filename] 251 if strings.HasSuffix(parse.Name.String(), "_test") { 252 pkg += "_test" 253 } 254 for _, e := range doc.Examples(parse) { 255 if e.Output == "" && !e.EmptyOutput { 256 continue 257 } 258 cases.Examples = append(cases.Examples, Example{ 259 Name: "Example" + e.Name, 260 Package: pkg, 261 Output: e.Output, 262 Unordered: e.Unordered, 263 }) 264 pkgs[pkg] = true 265 } 266 for _, d := range parse.Decls { 267 fn, ok := d.(*ast.FuncDecl) 268 if !ok { 269 continue 270 } 271 if fn.Recv != nil { 272 continue 273 } 274 if fn.Name.Name == "TestMain" { 275 // TestMain is not, itself, a test 276 pkgs[pkg] = true 277 cases.TestMain = fmt.Sprintf("%s.%s", pkg, fn.Name.Name) 278 continue 279 } 280 281 // Here we check the signature of the Test* function. To 282 // be considered a test: 283 284 // 1. The function should have a single argument. 285 if len(fn.Type.Params.List) != 1 { 286 continue 287 } 288 289 // 2. The function should return nothing. 290 if fn.Type.Results != nil { 291 continue 292 } 293 294 // 3. The only parameter should have a type identified as 295 // *<something>.T 296 starExpr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr) 297 if !ok { 298 continue 299 } 300 selExpr, ok := starExpr.X.(*ast.SelectorExpr) 301 if !ok { 302 continue 303 } 304 305 // We do not descriminate on the referenced type of the 306 // parameter being *testing.T. Instead we assert that it 307 // should be *<something>.T. This is because the import 308 // could have been aliased as a different identifier. 309 310 if strings.HasPrefix(fn.Name.Name, "Test") { 311 if selExpr.Sel.Name != "T" { 312 continue 313 } 314 pkgs[pkg] = true 315 cases.Tests = append(cases.Tests, TestCase{ 316 Package: pkg, 317 Name: fn.Name.Name, 318 }) 319 } 320 if strings.HasPrefix(fn.Name.Name, "Benchmark") { 321 if selExpr.Sel.Name != "B" { 322 continue 323 } 324 pkgs[pkg] = true 325 cases.Benchmarks = append(cases.Benchmarks, TestCase{ 326 Package: pkg, 327 Name: fn.Name.Name, 328 }) 329 } 330 } 331 } 332 333 for name := range importMap { 334 // Set the names for all unused imports to "_" 335 if !pkgs[name] { 336 importMap[name].Name = "_" 337 } 338 cases.Imports = append(cases.Imports, importMap[name]) 339 } 340 sort.Slice(cases.Imports, func(i, j int) bool { 341 return cases.Imports[i].Name < cases.Imports[j].Name 342 }) 343 tpl := template.Must(template.New("source").Parse(testMainTpl)) 344 if err := tpl.Execute(outFile, &cases); err != nil { 345 return fmt.Errorf("template.Execute(%v): %v", cases, err) 346 } 347 return nil 348 }