github.com/please-build/go-rules/tools/please_go@v0.0.0-20240319165128-ea27d6f5caba/test/write_test_main.go (about) 1 package test 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/doc" 7 "go/parser" 8 "go/token" 9 "os" 10 "strings" 11 "text/template" 12 "unicode" 13 "unicode/utf8" 14 ) 15 16 type testDescr struct { 17 Package string 18 Main string 19 TestFunctions []string 20 BenchFunctions []string 21 FuzzFunctions []string 22 Examples []*doc.Example 23 CoverVars []CoverVar 24 Imports []string 25 Coverage bool 26 Benchmark bool 27 HasFuzz bool 28 } 29 30 // WriteTestMain templates a test main file from the given sources to the given output file. 31 func WriteTestMain(testPackage string, sources []string, output string, coverage bool, coverVars []CoverVar, benchmark, hasFuzz, coverageRedesign bool) error { 32 testDescr, err := parseTestSources(sources) 33 if err != nil { 34 return err 35 } 36 testDescr.Coverage = coverage 37 testDescr.CoverVars = coverVars 38 if len(testDescr.TestFunctions) > 0 || len(testDescr.BenchFunctions) > 0 || len(testDescr.Examples) > 0 || len(testDescr.FuzzFunctions) > 0 || testDescr.Main != "" { 39 // Can't set this if there are no test functions, it'll be an unused import. 40 if coverageRedesign { 41 testDescr.Imports = []string{fmt.Sprintf("%s \"%s\"", testDescr.Package, testPackage)} 42 } else { 43 testDescr.Imports = extraImportPaths(testPackage, testDescr.Package, testDescr.CoverVars) 44 } 45 } 46 47 testDescr.Benchmark = benchmark 48 testDescr.HasFuzz = hasFuzz 49 50 f, err := os.Create(output) 51 if err != nil { 52 return err 53 } 54 defer f.Close() 55 // This might be consumed by other things. 
56 fmt.Printf("Package: %s\n", testDescr.Package) 57 58 if coverageRedesign { 59 return testMainTmpl.Execute(f, testDescr) 60 } 61 return oldTestMainTmpl.Execute(f, testDescr) 62 } 63 64 func extraImportPaths(testPackage, alias string, coverVars []CoverVar) []string { 65 ret := make([]string, 0, len(coverVars)+1) 66 ret = append(ret, fmt.Sprintf("%s \"%s\"", alias, testPackage)) 67 68 for i, v := range coverVars { 69 name := fmt.Sprintf("_cover%d", i) 70 coverVars[i].ImportName = name 71 ret = append(ret, fmt.Sprintf("%s \"%s\"", name, v.ImportPath)) 72 } 73 return ret 74 } 75 76 // parseTestSources parses the test sources and returns the package and set of test functions in them. 77 func parseTestSources(sources []string) (testDescr, error) { 78 descr := testDescr{} 79 for _, source := range sources { 80 f, err := parser.ParseFile(token.NewFileSet(), source, nil, parser.ParseComments) 81 if err != nil { 82 fmt.Fprintf(os.Stderr, "Error parsing %s: %s\n", source, err) 83 return descr, err 84 } 85 descr.Package = f.Name.Name 86 // If we're testing main, we will get errors from it clashing with func main. 87 if descr.Package == "main" { 88 descr.Package = "_main" 89 } 90 for _, d := range f.Decls { 91 if fd, ok := d.(*ast.FuncDecl); ok && fd.Recv == nil { 92 name := fd.Name.String() 93 if isTestMain(fd) { 94 descr.Main = name 95 } else if isTest(fd, 1, name, "Test") { 96 descr.TestFunctions = append(descr.TestFunctions, name) 97 } else if isTest(fd, 1, name, "Benchmark") { 98 descr.BenchFunctions = append(descr.BenchFunctions, name) 99 } else if isTest(fd, 1, name, "Fuzz") { 100 descr.FuzzFunctions = append(descr.FuzzFunctions, name) 101 } 102 } 103 } 104 // Get doc to find the examples for us :) 105 descr.Examples = append(descr.Examples, doc.Examples(f)...) 106 } 107 return descr, nil 108 } 109 110 // isTestMain returns true if fn is a TestMain(m *testing.M) function. 111 // Copied from Go sources. 
func isTestMain(fn *ast.FuncDecl) bool {
	// Exactly one single-name parameter and no results.
	if fn.Name.String() != "TestMain" ||
		fn.Type.Results != nil && len(fn.Type.Results.List) > 0 ||
		fn.Type.Params == nil ||
		len(fn.Type.Params.List) != 1 ||
		len(fn.Type.Params.List[0].Names) > 1 {
		return false
	}
	ptr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
	if !ok {
		return false
	}
	// We can't easily check that the type is *testing.M
	// because we don't know how testing has been imported,
	// but at least check that it's *M or *something.M.
	if name, ok := ptr.X.(*ast.Ident); ok && name.Name == "M" {
		return true
	}
	if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == "M" {
		return true
	}
	return false
}

// isTest returns true if the given function looks like a test: its name is prefix,
// optionally followed by something that does not start with a lower-case letter
// ("TestFoo" and "Test_foo" match, "Testing" does not), it is not a method, it has
// no results, and it takes exactly argLen parameters.
//
// Parameters are counted individually, so func TestX(a, b *testing.T) has two of
// them even though they share one field. Functions that don't fit the shape are
// skipped, since the generated main could not reference them correctly.
// Based on the Go sources.
func isTest(fd *ast.FuncDecl, argLen int, name, prefix string) bool {
	if !strings.HasPrefix(name, prefix) || fd.Recv != nil {
		return false
	}
	// NumFields counts each name in a grouped field separately and is nil-safe.
	if fd.Type.Results.NumFields() > 0 || fd.Type.Params.NumFields() != argLen {
		return false
	}
	if len(name) == len(prefix) { // "Test" is ok
		return true
	}
	r, _ := utf8.DecodeRuneInString(name[len(prefix):])
	return !unicode.IsLower(r)
}

// testMainTmpl is the template for our test main, copied from Go's builtin one.
// Some bits are excluded because we don't support them and/or do them differently.
151 var testMainTmpl = template.Must(template.New("main").Parse(` 152 package main 153 154 import ( 155 _gostdlib_os "os" 156 {{if not .Benchmark}}_gostdlib_strings "strings"{{end}} 157 _gostdlib_testing "testing" 158 _gostdlib_testdeps "testing/internal/testdeps" 159 160 {{if .Coverage}} 161 _ "runtime/coverage" 162 _ "unsafe" 163 {{end}} 164 165 {{range .Imports}} 166 {{.}} 167 {{end}} 168 ) 169 170 var tests = []_gostdlib_testing.InternalTest{ 171 {{range .TestFunctions}} 172 {"{{.}}", {{$.Package}}.{{.}}}, 173 {{end}} 174 } 175 var examples = []_gostdlib_testing.InternalExample{ 176 {{range .Examples}} 177 {"{{.Name}}", {{$.Package}}.Example{{.Name}}, {{.Output | printf "%q"}}, {{.Unordered}}}, 178 {{end}} 179 } 180 181 var benchmarks = []_gostdlib_testing.InternalBenchmark{ 182 {{range .BenchFunctions}} 183 {"{{.}}", {{$.Package}}.{{.}}}, 184 {{end}} 185 } 186 187 var fuzzTargets = []_gostdlib_testing.InternalFuzzTarget{ 188 {{ range .FuzzFunctions }} 189 {"{{.}}", {{$.Package}}.{{.}}}, 190 {{ end }} 191 } 192 193 {{if .Coverage}} 194 //go:linkname runtime_coverage_processCoverTestDir runtime/coverage.processCoverTestDir 195 func runtime_coverage_processCoverTestDir(dir string, cfile string, cmode string, cpkgs string) error 196 197 //go:linkname testing_registerCover2 testing.registerCover2 198 func testing_registerCover2(mode string, tearDown func(coverprofile string, gocoverdir string) (string, error)) 199 200 //go:linkname runtime_coverage_markProfileEmitted runtime/coverage.markProfileEmitted 201 func runtime_coverage_markProfileEmitted(val bool) 202 203 func coverTearDown(coverprofile string, gocoverdir string) (string, error) { 204 var err error 205 if gocoverdir == "" { 206 gocoverdir, err = _gostdlib_os.MkdirTemp("", "gocoverdir") 207 if err != nil { 208 return "error setting GOCOVERDIR: bad os.MkdirTemp return", err 209 } 210 defer _gostdlib_os.RemoveAll(gocoverdir) 211 } 212 runtime_coverage_markProfileEmitted(true) 213 if err := 
runtime_coverage_processCoverTestDir(gocoverdir, coverprofile, "set", ""); err != nil { 214 return "error generating coverage report", err 215 } 216 return "", nil 217 } 218 {{end}} 219 220 var testDeps = _gostdlib_testdeps.TestDeps{} 221 222 func internalMain() int { 223 224 {{if .Coverage}} 225 coverfile := _gostdlib_os.Getenv("COVERAGE_FILE") 226 args := []string{_gostdlib_os.Args[0], "-test.v", "-test.coverprofile", coverfile} 227 testing_registerCover2("set", coverTearDown) 228 {{else}} 229 args := []string{_gostdlib_os.Args[0], "-test.v"} 230 {{end}} 231 {{if not .Benchmark}} 232 testVar := _gostdlib_os.Getenv("TESTS") 233 if testVar != "" { 234 testVar = _gostdlib_strings.ReplaceAll(testVar, " ", "|") 235 args = append(args, "-test.run", testVar) 236 } 237 _gostdlib_os.Args = append(args, _gostdlib_os.Args[1:]...) 238 m := _gostdlib_testing.MainStart(testDeps, tests, nil, fuzzTargets, examples) 239 {{else}} 240 args = append(args, "-test.bench", ".*") 241 _gostdlib_os.Args = append(args, _gostdlib_os.Args[1:]...) 
242 m := _gostdlib_testing.MainStart(testDeps, nil, benchmarks, fuzzTargets, nil) 243 {{end}} 244 {{if .Main}} 245 {{.Package}}.{{.Main}}(m) 246 return 0 247 {{else}} 248 return m.Run() 249 {{end}} 250 } 251 252 func main() { 253 _gostdlib_os.Exit(internalMain()) 254 } 255 `)) 256 257 var oldTestMainTmpl = template.Must(template.New("oldmain").Parse(` 258 package main 259 260 import ( 261 _gostdlib_os "os" 262 {{if not .Benchmark}}_gostdlib_strings "strings"{{end}} 263 _gostdlib_testing "testing" 264 _gostdlib_testdeps "testing/internal/testdeps" 265 266 {{range .Imports}} 267 {{.}} 268 {{end}} 269 ) 270 271 var tests = []_gostdlib_testing.InternalTest{ 272 {{range .TestFunctions}} 273 {"{{.}}", {{$.Package}}.{{.}}}, 274 {{end}} 275 } 276 var examples = []_gostdlib_testing.InternalExample{ 277 {{range .Examples}} 278 {"{{.Name}}", {{$.Package}}.Example{{.Name}}, {{.Output | printf "%q"}}, {{.Unordered}}}, 279 {{end}} 280 } 281 282 var benchmarks = []_gostdlib_testing.InternalBenchmark{ 283 {{range .BenchFunctions}} 284 {"{{.}}", {{$.Package}}.{{.}}}, 285 {{end}} 286 } 287 288 {{ if .HasFuzz }} 289 var fuzzTargets = []_gostdlib_testing.InternalFuzzTarget{ 290 {{ range .FuzzFunctions }} 291 {"{{.}}", {{$.Package}}.{{.}}}, 292 {{ end }} 293 } 294 {{ end }} 295 296 {{if .Coverage}} 297 298 // Only updated by init functions, so no need for atomicity. 
299 var ( 300 coverCounters = make(map[string][]uint32) 301 coverBlocks = make(map[string][]_gostdlib_testing.CoverBlock) 302 ) 303 304 func init() { 305 {{range $i, $c := .CoverVars}} 306 {{if $c.ImportName }} 307 coverRegisterFile({{printf "%q" $c.File}}, {{$c.ImportName}}.{{$c.Var}}.Count[:], {{$c.ImportName}}.{{$c.Var}}.Pos[:], {{$c.ImportName}}.{{$c.Var}}.NumStmt[:]) 308 {{end}} 309 {{end}} 310 } 311 312 func coverRegisterFile(fileName string, counter []uint32, pos []uint32, numStmts []uint16) { 313 if 3*len(counter) != len(pos) || len(counter) != len(numStmts) { 314 panic("coverage: mismatched sizes") 315 } 316 if coverCounters[fileName] != nil { 317 // Already registered. 318 return 319 } 320 coverCounters[fileName] = counter 321 block := make([]_gostdlib_testing.CoverBlock, len(counter)) 322 for i := range counter { 323 block[i] = _gostdlib_testing.CoverBlock{ 324 Line0: pos[3*i+0], 325 Col0: uint16(pos[3*i+2]), 326 Line1: pos[3*i+1], 327 Col1: uint16(pos[3*i+2]>>16), 328 Stmts: numStmts[i], 329 } 330 } 331 coverBlocks[fileName] = block 332 } 333 {{end}} 334 335 var testDeps = _gostdlib_testdeps.TestDeps{} 336 337 func main() { 338 {{if .Coverage}} 339 _gostdlib_testing.RegisterCover(_gostdlib_testing.Cover{ 340 Mode: "set", 341 Counters: coverCounters, 342 Blocks: coverBlocks, 343 CoveredPackages: "", 344 }) 345 coverfile := _gostdlib_os.Getenv("COVERAGE_FILE") 346 args := []string{_gostdlib_os.Args[0], "-test.v", "-test.coverprofile", coverfile} 347 {{else}} 348 args := []string{_gostdlib_os.Args[0], "-test.v"} 349 {{end}} 350 {{if not .Benchmark}} 351 testVar := _gostdlib_os.Getenv("TESTS") 352 if testVar != "" { 353 testVar = _gostdlib_strings.ReplaceAll(testVar, " ", "|") 354 args = append(args, "-test.run", testVar) 355 } 356 _gostdlib_os.Args = append(args, _gostdlib_os.Args[1:]...) 
357 m := _gostdlib_testing.MainStart(testDeps, tests, nil,{{ if .HasFuzz }} fuzzTargets,{{ end }} examples) 358 {{else}} 359 args = append(args, "-test.bench", ".*") 360 _gostdlib_os.Args = append(args, _gostdlib_os.Args[1:]...) 361 m := _gostdlib_testing.MainStart(testDeps, nil, benchmarks,{{ if .HasFuzz }} fuzzTargets,{{ end }} nil) 362 {{end}} 363 364 {{if .Main}} 365 {{.Package}}.{{.Main}}(m) 366 {{else}} 367 _gostdlib_os.Exit(m.Run()) 368 {{end}} 369 } 370 `))