github.com/please-build/go-rules/tools/please_go@v0.0.0-20240319165128-ea27d6f5caba/test/write_test_main.go (about)

     1  package test
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/doc"
     7  	"go/parser"
     8  	"go/token"
     9  	"os"
    10  	"strings"
    11  	"text/template"
    12  	"unicode"
    13  	"unicode/utf8"
    14  )
    15  
// testDescr is the data passed to the test main templates.
type testDescr struct {
	// Package is the package name taken from the parsed sources ("_main" when
	// the sources declare package main); generated code uses it as the import alias.
	Package string
	// Main is the name of the TestMain function, or empty if none was found.
	Main string
	// TestFunctions, BenchFunctions and FuzzFunctions hold the names of the
	// discovered Test*, Benchmark* and Fuzz* functions respectively.
	TestFunctions  []string
	BenchFunctions []string
	FuzzFunctions  []string
	// Examples are the examples found in the sources by go/doc.
	Examples []*doc.Example
	// CoverVars are the coverage variables supplied by the caller of WriteTestMain.
	CoverVars []CoverVar
	// Imports are complete import specs (`alias "path"`) emitted verbatim into
	// the generated import block.
	Imports []string
	// Coverage enables the coverage-specific sections of the templates.
	Coverage bool
	// Benchmark selects benchmark mode instead of test mode in the generated main.
	Benchmark bool
	// HasFuzz controls whether the old template emits and passes fuzz targets.
	HasFuzz bool
}
    29  
    30  // WriteTestMain templates a test main file from the given sources to the given output file.
    31  func WriteTestMain(testPackage string, sources []string, output string, coverage bool, coverVars []CoverVar, benchmark, hasFuzz, coverageRedesign bool) error {
    32  	testDescr, err := parseTestSources(sources)
    33  	if err != nil {
    34  		return err
    35  	}
    36  	testDescr.Coverage = coverage
    37  	testDescr.CoverVars = coverVars
    38  	if len(testDescr.TestFunctions) > 0 || len(testDescr.BenchFunctions) > 0 || len(testDescr.Examples) > 0 || len(testDescr.FuzzFunctions) > 0 || testDescr.Main != "" {
    39  		// Can't set this if there are no test functions, it'll be an unused import.
    40  		if coverageRedesign {
    41  			testDescr.Imports = []string{fmt.Sprintf("%s \"%s\"", testDescr.Package, testPackage)}
    42  		} else {
    43  			testDescr.Imports = extraImportPaths(testPackage, testDescr.Package, testDescr.CoverVars)
    44  		}
    45  	}
    46  
    47  	testDescr.Benchmark = benchmark
    48  	testDescr.HasFuzz = hasFuzz
    49  
    50  	f, err := os.Create(output)
    51  	if err != nil {
    52  		return err
    53  	}
    54  	defer f.Close()
    55  	// This might be consumed by other things.
    56  	fmt.Printf("Package: %s\n", testDescr.Package)
    57  
    58  	if coverageRedesign {
    59  		return testMainTmpl.Execute(f, testDescr)
    60  	}
    61  	return oldTestMainTmpl.Execute(f, testDescr)
    62  }
    63  
    64  func extraImportPaths(testPackage, alias string, coverVars []CoverVar) []string {
    65  	ret := make([]string, 0, len(coverVars)+1)
    66  	ret = append(ret, fmt.Sprintf("%s \"%s\"", alias, testPackage))
    67  
    68  	for i, v := range coverVars {
    69  		name := fmt.Sprintf("_cover%d", i)
    70  		coverVars[i].ImportName = name
    71  		ret = append(ret, fmt.Sprintf("%s \"%s\"", name, v.ImportPath))
    72  	}
    73  	return ret
    74  }
    75  
    76  // parseTestSources parses the test sources and returns the package and set of test functions in them.
    77  func parseTestSources(sources []string) (testDescr, error) {
    78  	descr := testDescr{}
    79  	for _, source := range sources {
    80  		f, err := parser.ParseFile(token.NewFileSet(), source, nil, parser.ParseComments)
    81  		if err != nil {
    82  			fmt.Fprintf(os.Stderr, "Error parsing %s: %s\n", source, err)
    83  			return descr, err
    84  		}
    85  		descr.Package = f.Name.Name
    86  		// If we're testing main, we will get errors from it clashing with func main.
    87  		if descr.Package == "main" {
    88  			descr.Package = "_main"
    89  		}
    90  		for _, d := range f.Decls {
    91  			if fd, ok := d.(*ast.FuncDecl); ok && fd.Recv == nil {
    92  				name := fd.Name.String()
    93  				if isTestMain(fd) {
    94  					descr.Main = name
    95  				} else if isTest(fd, 1, name, "Test") {
    96  					descr.TestFunctions = append(descr.TestFunctions, name)
    97  				} else if isTest(fd, 1, name, "Benchmark") {
    98  					descr.BenchFunctions = append(descr.BenchFunctions, name)
    99  				} else if isTest(fd, 1, name, "Fuzz") {
   100  					descr.FuzzFunctions = append(descr.FuzzFunctions, name)
   101  				}
   102  			}
   103  		}
   104  		// Get doc to find the examples for us :)
   105  		descr.Examples = append(descr.Examples, doc.Examples(f)...)
   106  	}
   107  	return descr, nil
   108  }
   109  
   110  // isTestMain returns true if fn is a TestMain(m *testing.M) function.
   111  // Copied from Go sources.
   112  func isTestMain(fn *ast.FuncDecl) bool {
   113  	if fn.Name.String() != "TestMain" ||
   114  		fn.Type.Results != nil && len(fn.Type.Results.List) > 0 ||
   115  		fn.Type.Params == nil ||
   116  		len(fn.Type.Params.List) != 1 ||
   117  		len(fn.Type.Params.List[0].Names) > 1 {
   118  		return false
   119  	}
   120  	ptr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
   121  	if !ok {
   122  		return false
   123  	}
   124  	// We can't easily check that the type is *testing.M
   125  	// because we don't know how testing has been imported,
   126  	// but at least check that it's *M or *something.M.
   127  	if name, ok := ptr.X.(*ast.Ident); ok && name.Name == "M" {
   128  		return true
   129  	}
   130  	if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == "M" {
   131  		return true
   132  	}
   133  	return false
   134  }
   135  
   136  // isTest returns true if the given function looks like a test.
   137  // Copied from Go sources.
   138  func isTest(fd *ast.FuncDecl, argLen int, name, prefix string) bool {
   139  	if !strings.HasPrefix(name, prefix) || fd.Recv != nil || len(fd.Type.Params.List) != argLen {
   140  		return false
   141  	} else if len(name) == len(prefix) { // "Test" is ok
   142  		return true
   143  	}
   144  
   145  	rune, _ := utf8.DecodeRuneInString(name[len(prefix):])
   146  	return !unicode.IsLower(rune)
   147  }
   148  
   149  // testMainTmpl is the template for our test main, copied from Go's builtin one.
   150  // Some bits are excluded because we don't support them and/or do them differently.
   151  var testMainTmpl = template.Must(template.New("main").Parse(`
   152  package main
   153  
   154  import (
   155  	_gostdlib_os "os"
   156  	{{if not .Benchmark}}_gostdlib_strings "strings"{{end}}
   157  	_gostdlib_testing "testing"
   158  	_gostdlib_testdeps "testing/internal/testdeps"
   159  
   160  {{if .Coverage}}
   161  	_ "runtime/coverage"
   162  	_ "unsafe"
   163  {{end}}
   164  
   165  {{range .Imports}}
   166  	{{.}}
   167  {{end}}
   168  )
   169  
   170  var tests = []_gostdlib_testing.InternalTest{
   171  {{range .TestFunctions}}
   172  	{"{{.}}", {{$.Package}}.{{.}}},
   173  {{end}}
   174  }
   175  var examples = []_gostdlib_testing.InternalExample{
   176  {{range .Examples}}
   177  	{"{{.Name}}", {{$.Package}}.Example{{.Name}}, {{.Output | printf "%q"}}, {{.Unordered}}},
   178  {{end}}
   179  }
   180  
   181  var benchmarks = []_gostdlib_testing.InternalBenchmark{
   182  {{range .BenchFunctions}}
   183  	{"{{.}}", {{$.Package}}.{{.}}},
   184  {{end}}
   185  }
   186  
   187  var fuzzTargets = []_gostdlib_testing.InternalFuzzTarget{
   188  {{ range .FuzzFunctions }}
   189  	{"{{.}}", {{$.Package}}.{{.}}},
   190  {{ end }}
   191  }
   192  
   193  {{if .Coverage}}
   194  //go:linkname runtime_coverage_processCoverTestDir runtime/coverage.processCoverTestDir
   195  func runtime_coverage_processCoverTestDir(dir string, cfile string, cmode string, cpkgs string) error
   196  
   197  //go:linkname testing_registerCover2 testing.registerCover2
   198  func testing_registerCover2(mode string, tearDown func(coverprofile string, gocoverdir string) (string, error))
   199  
   200  //go:linkname runtime_coverage_markProfileEmitted runtime/coverage.markProfileEmitted
   201  func runtime_coverage_markProfileEmitted(val bool)
   202  
   203  func coverTearDown(coverprofile string, gocoverdir string) (string, error) {
   204  	var err error
   205  	if gocoverdir == "" {
   206  		gocoverdir, err = _gostdlib_os.MkdirTemp("", "gocoverdir")
   207  		if err != nil {
   208  			return "error setting GOCOVERDIR: bad os.MkdirTemp return", err
   209  		}
   210  		defer _gostdlib_os.RemoveAll(gocoverdir)
   211  	}
   212  	runtime_coverage_markProfileEmitted(true)
   213  	if err := runtime_coverage_processCoverTestDir(gocoverdir, coverprofile, "set", ""); err != nil {
   214  		return "error generating coverage report", err
   215  	}
   216  	return "", nil
   217  }
   218  {{end}}
   219  
   220  var testDeps = _gostdlib_testdeps.TestDeps{}
   221  
   222  func internalMain() int {
   223  
   224  {{if .Coverage}}
   225      coverfile := _gostdlib_os.Getenv("COVERAGE_FILE")
   226      args := []string{_gostdlib_os.Args[0], "-test.v", "-test.coverprofile", coverfile}
   227  	testing_registerCover2("set", coverTearDown)
   228  {{else}}
   229      args := []string{_gostdlib_os.Args[0], "-test.v"}
   230  {{end}}
   231  {{if not .Benchmark}}
   232      testVar := _gostdlib_os.Getenv("TESTS")
   233      if testVar != "" {
   234  		testVar = _gostdlib_strings.ReplaceAll(testVar, " ", "|")
   235  		args = append(args, "-test.run", testVar)
   236      }
   237      _gostdlib_os.Args = append(args, _gostdlib_os.Args[1:]...)
   238  	m := _gostdlib_testing.MainStart(testDeps, tests, nil, fuzzTargets, examples)
   239  {{else}}
   240  	args = append(args, "-test.bench", ".*")
   241  	_gostdlib_os.Args = append(args, _gostdlib_os.Args[1:]...)
   242  	m := _gostdlib_testing.MainStart(testDeps, nil, benchmarks, fuzzTargets, nil)
   243  {{end}}
   244  {{if .Main}}
   245  	{{.Package}}.{{.Main}}(m)
   246      return 0
   247  {{else}}
   248  	return m.Run()
   249  {{end}}
   250  }
   251  
   252  func main() {
   253  	_gostdlib_os.Exit(internalMain())
   254  }
   255  `))
   256  
   257  var oldTestMainTmpl = template.Must(template.New("oldmain").Parse(`
   258  package main
   259  
   260  import (
   261  	_gostdlib_os "os"
   262  	{{if not .Benchmark}}_gostdlib_strings "strings"{{end}}
   263  	_gostdlib_testing "testing"
   264  	_gostdlib_testdeps "testing/internal/testdeps"
   265  
   266  {{range .Imports}}
   267  	{{.}}
   268  {{end}}
   269  )
   270  
   271  var tests = []_gostdlib_testing.InternalTest{
   272  {{range .TestFunctions}}
   273  	{"{{.}}", {{$.Package}}.{{.}}},
   274  {{end}}
   275  }
   276  var examples = []_gostdlib_testing.InternalExample{
   277  {{range .Examples}}
   278  	{"{{.Name}}", {{$.Package}}.Example{{.Name}}, {{.Output | printf "%q"}}, {{.Unordered}}},
   279  {{end}}
   280  }
   281  
   282  var benchmarks = []_gostdlib_testing.InternalBenchmark{
   283  {{range .BenchFunctions}}
   284  	{"{{.}}", {{$.Package}}.{{.}}},
   285  {{end}}
   286  }
   287  
   288  {{ if .HasFuzz }}
   289  var fuzzTargets = []_gostdlib_testing.InternalFuzzTarget{
   290  {{ range .FuzzFunctions }}
   291  	{"{{.}}", {{$.Package}}.{{.}}},
   292  {{ end }}
   293  }
   294  {{ end }}
   295  
   296  {{if .Coverage}}
   297  
   298  // Only updated by init functions, so no need for atomicity.
   299  var (
   300  	coverCounters = make(map[string][]uint32)
   301  	coverBlocks = make(map[string][]_gostdlib_testing.CoverBlock)
   302  )
   303  
   304  func init() {
   305  	{{range $i, $c := .CoverVars}}
   306  		{{if $c.ImportName }}
   307  			coverRegisterFile({{printf "%q" $c.File}}, {{$c.ImportName}}.{{$c.Var}}.Count[:], {{$c.ImportName}}.{{$c.Var}}.Pos[:], {{$c.ImportName}}.{{$c.Var}}.NumStmt[:])
   308  		{{end}}
   309  	{{end}}
   310  }
   311  
   312  func coverRegisterFile(fileName string, counter []uint32, pos []uint32, numStmts []uint16) {
   313  	if 3*len(counter) != len(pos) || len(counter) != len(numStmts) {
   314  		panic("coverage: mismatched sizes")
   315  	}
   316  	if coverCounters[fileName] != nil {
   317  		// Already registered.
   318  		return
   319  	}
   320  	coverCounters[fileName] = counter
   321  	block := make([]_gostdlib_testing.CoverBlock, len(counter))
   322  	for i := range counter {
   323  		block[i] = _gostdlib_testing.CoverBlock{
   324  			Line0: pos[3*i+0],
   325  			Col0: uint16(pos[3*i+2]),
   326  			Line1: pos[3*i+1],
   327  			Col1: uint16(pos[3*i+2]>>16),
   328  			Stmts: numStmts[i],
   329  		}
   330  	}
   331  	coverBlocks[fileName] = block
   332  }
   333  {{end}}
   334  
   335  var testDeps = _gostdlib_testdeps.TestDeps{}
   336  
   337  func main() {
   338  {{if .Coverage}}
   339  	_gostdlib_testing.RegisterCover(_gostdlib_testing.Cover{
   340  		Mode: "set",
   341  		Counters: coverCounters,
   342  		Blocks: coverBlocks,
   343  		CoveredPackages: "",
   344  	})
   345      coverfile := _gostdlib_os.Getenv("COVERAGE_FILE")
   346      args := []string{_gostdlib_os.Args[0], "-test.v", "-test.coverprofile", coverfile}
   347  {{else}}
   348      args := []string{_gostdlib_os.Args[0], "-test.v"}
   349  {{end}}
   350  {{if not .Benchmark}}
   351      testVar := _gostdlib_os.Getenv("TESTS")
   352      if testVar != "" {
   353  		testVar = _gostdlib_strings.ReplaceAll(testVar, " ", "|")
   354  		args = append(args, "-test.run", testVar)
   355      }
   356      _gostdlib_os.Args = append(args, _gostdlib_os.Args[1:]...)
   357  	m := _gostdlib_testing.MainStart(testDeps, tests, nil,{{ if .HasFuzz }} fuzzTargets,{{ end }} examples)
   358  {{else}}
   359  	args = append(args, "-test.bench", ".*")
   360  	_gostdlib_os.Args = append(args, _gostdlib_os.Args[1:]...)
   361  	m := _gostdlib_testing.MainStart(testDeps, nil, benchmarks,{{ if .HasFuzz }} fuzzTargets,{{ end }} nil)
   362  {{end}}
   363  
   364  {{if .Main}}
   365  	{{.Package}}.{{.Main}}(m)
   366  {{else}}
   367  	_gostdlib_os.Exit(m.Run())
   368  {{end}}
   369  }
   370  `))