github.phpd.cn/thought-machine/please@v12.2.0+incompatible/tools/please_go_test/gotest/write_test_main.go (about)

     1  package gotest
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/parser"
     7  	"go/token"
     8  	"os"
     9  	"os/exec"
    10  	"path"
    11  	"regexp"
    12  	"strconv"
    13  	"strings"
    14  	"text/template"
    15  	"unicode"
    16  	"unicode/utf8"
    17  )
    18  
    19  type testDescr struct {
    20  	Package   string
    21  	Main      string
    22  	Functions []string
    23  	CoverVars []CoverVar
    24  	Imports   []string
    25  	Version18 bool
    26  }
    27  
    28  // WriteTestMain templates a test main file from the given sources to the given output file.
    29  // This mimics what 'go test' does, although we do not currently support benchmarks or examples.
    30  func WriteTestMain(pkgDir string, version18 bool, sources []string, output string, coverVars []CoverVar) error {
    31  	testDescr, err := parseTestSources(sources)
    32  	if err != nil {
    33  		return err
    34  	}
    35  	testDescr.CoverVars = coverVars
    36  	testDescr.Version18 = version18
    37  	if len(testDescr.Functions) > 0 {
    38  		// Can't set this if there are no test functions, it'll be an unused import.
    39  		testDescr.Imports = extraImportPaths(testDescr.Package, pkgDir, coverVars)
    40  	}
    41  
    42  	f, err := os.Create(output)
    43  	if err != nil {
    44  		return err
    45  	}
    46  	defer f.Close()
    47  	// This might be consumed by other things.
    48  	fmt.Printf("Package: %s\n", testDescr.Package)
    49  	return testMainTmpl.Execute(f, testDescr)
    50  }
    51  
    52  // IsVersion18 returns true if the given Go tool is version 1.8 or greater.
    53  // This is needed because the test main signature has changed - it's not subject to the Go1 compatibility guarantee :(
    54  func IsVersion18(goTool string) bool {
    55  	cmd := exec.Command(goTool, "version")
    56  	out, err := cmd.Output()
    57  	if err != nil {
    58  		log.Fatalf("Can't determine Go version: %s", err)
    59  	}
    60  	return isVersion18(out)
    61  }
    62  
    63  func isVersion18(version []byte) bool {
    64  	r := regexp.MustCompile("go version go1.([0-9]+)[^0-9].*")
    65  	m := r.FindSubmatch(version)
    66  	if len(m) == 0 {
    67  		log.Warning("Failed to match %s", version)
    68  		return false
    69  	}
    70  	v, _ := strconv.Atoi(string(m[1]))
    71  	return v >= 8
    72  }
    73  
    74  // extraImportPaths returns the set of extra import paths that are needed.
    75  func extraImportPaths(pkg, pkgDir string, coverVars []CoverVar) []string {
    76  	pkgDir = collapseFinalDir(path.Join(strings.TrimPrefix(pkgDir, "src/"), pkg))
    77  	ret := []string{fmt.Sprintf("%s \"%s\"", pkg, pkgDir)}
    78  	for i, v := range coverVars {
    79  		name := fmt.Sprintf("_cover%d", i)
    80  		coverVars[i].ImportName = name
    81  		ret = append(ret, fmt.Sprintf("%s \"%s\"", name, v.ImportPath))
    82  	}
    83  	return ret
    84  }
    85  
    86  // parseTestSources parses the test sources and returns the package and set of test functions in them.
    87  func parseTestSources(sources []string) (testDescr, error) {
    88  	descr := testDescr{}
    89  	for _, source := range sources {
    90  		f, err := parser.ParseFile(token.NewFileSet(), source, nil, 0)
    91  		if err != nil {
    92  			log.Errorf("Error parsing %s: %s", source, err)
    93  			return descr, err
    94  		}
    95  		descr.Package = f.Name.Name
    96  		// If we're testing main, we will get errors from it clashing with func main.
    97  		if descr.Package == "main" {
    98  			descr.Package = "_main"
    99  		}
   100  		for _, d := range f.Decls {
   101  			if fd, ok := d.(*ast.FuncDecl); ok && fd.Recv == nil {
   102  				name := fd.Name.String()
   103  				if isTestMain(fd) {
   104  					descr.Main = name
   105  				} else if isTest(name, "Test") {
   106  					descr.Functions = append(descr.Functions, name)
   107  				}
   108  			}
   109  		}
   110  	}
   111  	return descr, nil
   112  }
   113  
   114  // isTestMain returns true if fn is a TestMain(m *testing.M) function.
   115  // Copied from Go sources.
   116  func isTestMain(fn *ast.FuncDecl) bool {
   117  	if fn.Name.String() != "TestMain" ||
   118  		fn.Type.Results != nil && len(fn.Type.Results.List) > 0 ||
   119  		fn.Type.Params == nil ||
   120  		len(fn.Type.Params.List) != 1 ||
   121  		len(fn.Type.Params.List[0].Names) > 1 {
   122  		return false
   123  	}
   124  	ptr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
   125  	if !ok {
   126  		return false
   127  	}
   128  	// We can't easily check that the type is *testing.M
   129  	// because we don't know how testing has been imported,
   130  	// but at least check that it's *M or *something.M.
   131  	if name, ok := ptr.X.(*ast.Ident); ok && name.Name == "M" {
   132  		return true
   133  	}
   134  	if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == "M" {
   135  		return true
   136  	}
   137  	return false
   138  }
   139  
   140  // isTest returns true if the given function looks like a test.
   141  // Copied from Go sources.
   142  func isTest(name, prefix string) bool {
   143  	if !strings.HasPrefix(name, prefix) {
   144  		return false
   145  	}
   146  	if len(name) == len(prefix) { // "Test" is ok
   147  		return true
   148  	}
   149  	rune, _ := utf8.DecodeRuneInString(name[len(prefix):])
   150  	return !unicode.IsLower(rune)
   151  }
   152  
// testMainTmpl is the template for our test main, copied from Go's builtin one.
// Some bits are excluded because we don't support them and/or do them differently.
//
// It is executed with a testDescr. The generated program:
//   - imports "os", "testing", "testing/internal/testdeps" (only when Version18 is set)
//     and each entry of Imports verbatim;
//   - registers every name in Functions as a test calling <Package>.<name>;
//   - when CoverVars is non-empty, registers their counters via testing.RegisterCover in
//     "set" mode and passes -test.coverprofile with the value of $COVERAGE_FILE;
//   - always passes -test.v, adds -test.run $TESTS when that variable is non-empty, and
//     then appends the program's own command-line arguments;
//   - calls <Package>.<Main>(m) when a TestMain was found, otherwise os.Exit(m.Run()).
//
// No benchmarks or examples are registered; both lists passed to MainStart are empty.
//
// NOTE(review): on pre-1.8 toolchains testDeps is plain `pat == str`. If MainStart uses it
// to match -test.run (which $TESTS feeds), patterns would only match exact test names
// there rather than acting as regexps — confirm this is intended.
var testMainTmpl = template.Must(template.New("main").Parse(`
package main

import (
	"os"
	"testing"
{{if .Version18}}
        "testing/internal/testdeps"
{{end}}

{{range .Imports}}
	{{.}}
{{end}}
)

var tests = []testing.InternalTest{
{{range .Functions}}
	{"{{.}}", {{$.Package}}.{{.}}},
{{end}}
}

{{if .CoverVars}}

// Only updated by init functions, so no need for atomicity.
var (
	coverCounters = make(map[string][]uint32)
	coverBlocks = make(map[string][]testing.CoverBlock)
)

func init() {
	{{range $i, $c := .CoverVars}}
	coverRegisterFile({{printf "%q" $c.File}}, {{$c.ImportName}}.{{$c.Var}}.Count[:], {{$c.ImportName}}.{{$c.Var}}.Pos[:], {{$c.ImportName}}.{{$c.Var}}.NumStmt[:])
	{{end}}
}

func coverRegisterFile(fileName string, counter []uint32, pos []uint32, numStmts []uint16) {
	if 3*len(counter) != len(pos) || len(counter) != len(numStmts) {
		panic("coverage: mismatched sizes")
	}
	if coverCounters[fileName] != nil {
		// Already registered.
		return
	}
	coverCounters[fileName] = counter
	block := make([]testing.CoverBlock, len(counter))
	for i := range counter {
		block[i] = testing.CoverBlock{
			Line0: pos[3*i+0],
			Col0: uint16(pos[3*i+2]),
			Line1: pos[3*i+1],
			Col1: uint16(pos[3*i+2]>>16),
			Stmts: numStmts[i],
		}
	}
	coverBlocks[fileName] = block
}
{{end}}

{{if .Version18}}
var testDeps = testdeps.TestDeps{}
{{else}}
func testDeps(pat, str string) (bool, error) {
    return pat == str, nil
}
{{end}}

func main() {
{{if .CoverVars}}
	testing.RegisterCover(testing.Cover{
		Mode: "set",
		Counters: coverCounters,
		Blocks: coverBlocks,
		CoveredPackages: "",
	})
    coverfile := os.Getenv("COVERAGE_FILE")
    args := []string{os.Args[0], "-test.v", "-test.coverprofile", coverfile}
{{else}}
    args := []string{os.Args[0], "-test.v"}
{{end}}
    testVar := os.Getenv("TESTS")
    if testVar != "" {
        args = append(args, "-test.run", testVar)
    }
    os.Args = append(args, os.Args[1:]...)
	benchmarks := []testing.InternalBenchmark{}
	var examples = []testing.InternalExample{}
	m := testing.MainStart(testDeps, tests, benchmarks, examples)
{{if .Main}}
	{{.Package}}.{{.Main}}(m)
{{else}}
	os.Exit(m.Run())
{{end}}
}
`))