github.com/0xKiwi/rules_go@v0.24.3/go/tools/builders/generate_test_main.go (about)

     1  /* Copyright 2016 The Bazel Authors. All rights reserved.
     2  
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7     http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  // Bare bones Go testing support for Bazel.
    17  
    18  package main
    19  
    20  import (
    21  	"flag"
    22  	"fmt"
    23  	"go/ast"
    24  	"go/doc"
    25  	"go/parser"
    26  	"go/token"
    27  	"os"
    28  	"path/filepath"
    29  	"sort"
    30  	"strings"
    31  	"text/template"
    32  )
    33  
    34  type Import struct {
    35  	Name string
    36  	Path string
    37  }
    38  
    39  type TestCase struct {
    40  	Package string
    41  	Name    string
    42  }
    43  
    44  type Example struct {
    45  	Package   string
    46  	Name      string
    47  	Output    string
    48  	Unordered bool
    49  }
    50  
    51  // Cases holds template data.
    52  type Cases struct {
    53  	RunDir     string
    54  	Imports    []*Import
    55  	Tests      []TestCase
    56  	Benchmarks []TestCase
    57  	Examples   []Example
    58  	TestMain   string
    59  	Coverage   bool
    60  	Pkgname    string
    61  }
    62  
    63  const testMainTpl = `
    64  package main
    65  import (
    66  	"flag"
    67  	"log"
    68  	"os"
    69  	"os/exec"
    70  	"path/filepath"
    71  	"runtime"
    72  	"strconv"
    73  	"testing"
    74  	"testing/internal/testdeps"
    75  
    76  {{if .Coverage}}
    77  	"github.com/bazelbuild/rules_go/go/tools/coverdata"
    78  {{end}}
    79  
    80  {{range $p := .Imports}}
    81  	{{$p.Name}} "{{$p.Path}}"
    82  {{end}}
    83  )
    84  
    85  var allTests = []testing.InternalTest{
    86  {{range .Tests}}
    87  	{"{{.Name}}", {{.Package}}.{{.Name}} },
    88  {{end}}
    89  }
    90  
    91  var benchmarks = []testing.InternalBenchmark{
    92  {{range .Benchmarks}}
    93  	{"{{.Name}}", {{.Package}}.{{.Name}} },
    94  {{end}}
    95  }
    96  
    97  var examples = []testing.InternalExample{
    98  {{range .Examples}}
    99  	{Name: "{{.Name}}", F: {{.Package}}.{{.Name}}, Output: {{printf "%q" .Output}}, Unordered: {{.Unordered}} },
   100  {{end}}
   101  }
   102  
   103  func testsInShard() []testing.InternalTest {
   104  	totalShards, err := strconv.Atoi(os.Getenv("TEST_TOTAL_SHARDS"))
   105  	if err != nil || totalShards <= 1 {
   106  		return allTests
   107  	}
   108  	shardIndex, err := strconv.Atoi(os.Getenv("TEST_SHARD_INDEX"))
   109  	if err != nil || shardIndex < 0 {
   110  		return allTests
   111  	}
   112  	tests := []testing.InternalTest{}
   113  	for i, t := range allTests {
   114  		if i % totalShards == shardIndex {
   115  			tests = append(tests, t)
   116  		}
   117  	}
   118  	return tests
   119  }
   120  
   121  func main() {
   122  	if shouldWrap() {
   123  		err := wrap("{{.Pkgname}}")
   124  		if xerr, ok := err.(*exec.ExitError); ok {
   125  			os.Exit(xerr.ExitCode())
   126  		} else if err != nil {
   127  			log.Print(err)
   128  			os.Exit(testWrapperAbnormalExit)
   129  		} else {
   130  			os.Exit(0)
   131  		}
   132  	}
   133  
   134  	// Check if we're being run by Bazel and change directories if so.
   135  	// TEST_SRCDIR and TEST_WORKSPACE are set by the Bazel test runner, so that makes a decent proxy.
   136  	testSrcdir := os.Getenv("TEST_SRCDIR")
   137  	testWorkspace := os.Getenv("TEST_WORKSPACE")
   138  	if testSrcdir != "" && testWorkspace != "" {
   139  		abs := filepath.Join(testSrcdir, testWorkspace, {{printf "%q" .RunDir}})
   140  		err := os.Chdir(abs)
   141  		// Ignore the Chdir err when on Windows, since it might have have runfiles symlinks.
   142  		// https://github.com/bazelbuild/rules_go/pull/1721#issuecomment-422145904
   143  		if err != nil && runtime.GOOS != "windows" {
   144  			log.Fatalf("could not change to test directory: %v", err)
   145  		}
   146  		if err == nil {
   147  			os.Setenv("PWD", abs)
   148  		}
   149  	}
   150  
   151  	m := testing.MainStart(testdeps.TestDeps{}, testsInShard(), benchmarks, examples)
   152  
   153  	if filter := os.Getenv("TESTBRIDGE_TEST_ONLY"); filter != "" {
   154  		flag.Lookup("test.run").Value.Set(filter)
   155  	}
   156  
   157  	{{if .Coverage}}
   158  	if len(coverdata.Cover.Counters) > 0 {
   159  		testing.RegisterCover(coverdata.Cover)
   160  	}
   161  	if coverageDat, ok := os.LookupEnv("COVERAGE_OUTPUT_FILE"); ok {
   162  		if testing.CoverMode() != "" {
   163  			flag.Lookup("test.coverprofile").Value.Set(coverageDat)
   164  		}
   165  	}
   166  	{{end}}
   167  
   168  	{{if not .TestMain}}
   169  	os.Exit(m.Run())
   170  	{{else}}
   171  	{{.TestMain}}(m)
   172  	{{end}}
   173  }
   174  `
   175  
   176  func genTestMain(args []string) error {
   177  	// Prepare our flags
   178  	args, err := expandParamsFiles(args)
   179  	if err != nil {
   180  		return err
   181  	}
   182  	imports := multiFlag{}
   183  	sources := multiFlag{}
   184  	flags := flag.NewFlagSet("GoTestGenTest", flag.ExitOnError)
   185  	goenv := envFlags(flags)
   186  	runDir := flags.String("rundir", ".", "Path to directory where tests should run.")
   187  	out := flags.String("output", "", "output file to write. Defaults to stdout.")
   188  	coverage := flags.Bool("coverage", false, "whether coverage is supported")
   189  	pkgname := flags.String("pkgname", "", "package name of test")
   190  	flags.Var(&imports, "import", "Packages to import")
   191  	flags.Var(&sources, "src", "Sources to process for tests")
   192  	if err := flags.Parse(args); err != nil {
   193  		return err
   194  	}
   195  	if err := goenv.checkFlags(); err != nil {
   196  		return err
   197  	}
   198  	// Process import args
   199  	importMap := map[string]*Import{}
   200  	for _, imp := range imports {
   201  		parts := strings.Split(imp, "=")
   202  		if len(parts) != 2 {
   203  			return fmt.Errorf("Invalid import %q specified", imp)
   204  		}
   205  		i := &Import{Name: parts[0], Path: parts[1]}
   206  		importMap[i.Name] = i
   207  	}
   208  	// Process source args
   209  	sourceList := []string{}
   210  	sourceMap := map[string]string{}
   211  	for _, s := range sources {
   212  		parts := strings.Split(s, "=")
   213  		if len(parts) != 2 {
   214  			return fmt.Errorf("Invalid source %q specified", s)
   215  		}
   216  		sourceList = append(sourceList, parts[1])
   217  		sourceMap[parts[1]] = parts[0]
   218  	}
   219  
   220  	// filter our input file list
   221  	filteredSrcs, err := filterAndSplitFiles(sourceList)
   222  	if err != nil {
   223  		return err
   224  	}
   225  	goSrcs := filteredSrcs.goSrcs
   226  
   227  	outFile := os.Stdout
   228  	if *out != "" {
   229  		var err error
   230  		outFile, err = os.Create(*out)
   231  		if err != nil {
   232  			return fmt.Errorf("os.Create(%q): %v", *out, err)
   233  		}
   234  		defer outFile.Close()
   235  	}
   236  
   237  	cases := Cases{
   238  		RunDir:   strings.Replace(filepath.FromSlash(*runDir), `\`, `\\`, -1),
   239  		Coverage: *coverage,
   240  		Pkgname:  *pkgname,
   241  	}
   242  
   243  	testFileSet := token.NewFileSet()
   244  	pkgs := map[string]bool{}
   245  	for _, f := range goSrcs {
   246  		parse, err := parser.ParseFile(testFileSet, f.filename, nil, parser.ParseComments)
   247  		if err != nil {
   248  			return fmt.Errorf("ParseFile(%q): %v", f.filename, err)
   249  		}
   250  		pkg := sourceMap[f.filename]
   251  		if strings.HasSuffix(parse.Name.String(), "_test") {
   252  			pkg += "_test"
   253  		}
   254  		for _, e := range doc.Examples(parse) {
   255  			if e.Output == "" && !e.EmptyOutput {
   256  				continue
   257  			}
   258  			cases.Examples = append(cases.Examples, Example{
   259  				Name:      "Example" + e.Name,
   260  				Package:   pkg,
   261  				Output:    e.Output,
   262  				Unordered: e.Unordered,
   263  			})
   264  			pkgs[pkg] = true
   265  		}
   266  		for _, d := range parse.Decls {
   267  			fn, ok := d.(*ast.FuncDecl)
   268  			if !ok {
   269  				continue
   270  			}
   271  			if fn.Recv != nil {
   272  				continue
   273  			}
   274  			if fn.Name.Name == "TestMain" {
   275  				// TestMain is not, itself, a test
   276  				pkgs[pkg] = true
   277  				cases.TestMain = fmt.Sprintf("%s.%s", pkg, fn.Name.Name)
   278  				continue
   279  			}
   280  
   281  			// Here we check the signature of the Test* function. To
   282  			// be considered a test:
   283  
   284  			// 1. The function should have a single argument.
   285  			if len(fn.Type.Params.List) != 1 {
   286  				continue
   287  			}
   288  
   289  			// 2. The function should return nothing.
   290  			if fn.Type.Results != nil {
   291  				continue
   292  			}
   293  
   294  			// 3. The only parameter should have a type identified as
   295  			//    *<something>.T
   296  			starExpr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
   297  			if !ok {
   298  				continue
   299  			}
   300  			selExpr, ok := starExpr.X.(*ast.SelectorExpr)
   301  			if !ok {
   302  				continue
   303  			}
   304  
   305  			// We do not descriminate on the referenced type of the
   306  			// parameter being *testing.T. Instead we assert that it
   307  			// should be *<something>.T. This is because the import
   308  			// could have been aliased as a different identifier.
   309  
   310  			if strings.HasPrefix(fn.Name.Name, "Test") {
   311  				if selExpr.Sel.Name != "T" {
   312  					continue
   313  				}
   314  				pkgs[pkg] = true
   315  				cases.Tests = append(cases.Tests, TestCase{
   316  					Package: pkg,
   317  					Name:    fn.Name.Name,
   318  				})
   319  			}
   320  			if strings.HasPrefix(fn.Name.Name, "Benchmark") {
   321  				if selExpr.Sel.Name != "B" {
   322  					continue
   323  				}
   324  				pkgs[pkg] = true
   325  				cases.Benchmarks = append(cases.Benchmarks, TestCase{
   326  					Package: pkg,
   327  					Name:    fn.Name.Name,
   328  				})
   329  			}
   330  		}
   331  	}
   332  
   333  	for name := range importMap {
   334  		// Set the names for all unused imports to "_"
   335  		if !pkgs[name] {
   336  			importMap[name].Name = "_"
   337  		}
   338  		cases.Imports = append(cases.Imports, importMap[name])
   339  	}
   340  	sort.Slice(cases.Imports, func(i, j int) bool {
   341  		return cases.Imports[i].Name < cases.Imports[j].Name
   342  	})
   343  	tpl := template.Must(template.New("source").Parse(testMainTpl))
   344  	if err := tpl.Execute(outFile, &cases); err != nil {
   345  		return fmt.Errorf("template.Execute(%v): %v", cases, err)
   346  	}
   347  	return nil
   348  }