github.com/tiagovtristao/plz@v13.4.0+incompatible/tools/please_go_test/gotest/write_test_main.go (about)

     1  package gotest
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/doc"
     7  	"go/parser"
     8  	"go/token"
     9  	"os"
    10  	"os/exec"
    11  	"path"
    12  	"regexp"
    13  	"strconv"
    14  	"strings"
    15  	"text/template"
    16  	"unicode"
    17  	"unicode/utf8"
    18  )
    19  
    20  type testDescr struct { // data used to render testMainTmpl into a generated test main
    21  	Package   string // package name of the test sources; "_main" when they declare package main
    22  	Main      string // name of the TestMain function, empty if the sources define none
    23  	Functions []string // names of the test functions to register
    24  	Examples  []*doc.Example // examples found by go/doc in the test sources
    25  	CoverVars []CoverVar // coverage variables to register; set by WriteTestMain
    26  	Imports   []string // import specs (alias and quoted path) emitted into the generated file
    27  	Version18 bool // whether the Go tool is 1.8 or later, which changes the generated main
    28  }
    29  
    30  // WriteTestMain templates a test main file from the given sources to the given output file.
    31  // This mimics what 'go test' does, although we do not currently support benchmarks or examples.
    32  func WriteTestMain(pkgDir, importPath string, version18 bool, sources []string, output string, coverVars []CoverVar) error {
    33  	testDescr, err := parseTestSources(sources)
    34  	if err != nil {
    35  		return err
    36  	}
    37  	testDescr.CoverVars = coverVars
    38  	testDescr.Version18 = version18
    39  	if len(testDescr.Functions) > 0 || len(testDescr.Examples) > 0 {
    40  		// Can't set this if there are no test functions, it'll be an unused import.
    41  		testDescr.Imports = extraImportPaths(testDescr.Package, pkgDir, importPath, coverVars)
    42  	}
    43  
    44  	f, err := os.Create(output)
    45  	if err != nil {
    46  		return err
    47  	}
    48  	defer f.Close()
    49  	// This might be consumed by other things.
    50  	fmt.Printf("Package: %s\n", testDescr.Package)
    51  	return testMainTmpl.Execute(f, testDescr)
    52  }
    53  
    54  // IsVersion18 returns true if the given Go tool is version 1.8 or greater.
    55  // This is needed because the test main signature has changed - it's not subject to the Go1 compatibility guarantee :(
    56  func IsVersion18(goTool string) bool {
    57  	cmd := exec.Command(goTool, "version")
    58  	out, err := cmd.Output()
    59  	if err != nil {
    60  		log.Fatalf("Can't determine Go version: %s", err)
    61  	}
    62  	return isVersion18(out)
    63  }
    64  
    65  func isVersion18(version []byte) bool {
    66  	r := regexp.MustCompile("go version go1.([0-9]+)[^0-9].*")
    67  	m := r.FindSubmatch(version)
    68  	if len(m) == 0 {
    69  		log.Warning("Failed to match %s", version)
    70  		return false
    71  	}
    72  	v, _ := strconv.Atoi(string(m[1]))
    73  	return v >= 8
    74  }
    75  
    76  // extraImportPaths returns the set of extra import paths that are needed.
    77  func extraImportPaths(pkg, pkgDir, importPath string, coverVars []CoverVar) []string {
    78  	pkgDir = collapseFinalDir(path.Join(pkgDir, pkg), importPath)
    79  	ret := []string{fmt.Sprintf("%s \"%s\"", pkg, path.Join(importPath, pkgDir))}
    80  	for i, v := range coverVars {
    81  		name := fmt.Sprintf("_cover%d", i)
    82  		coverVars[i].ImportName = name
    83  		ret = append(ret, fmt.Sprintf("%s \"%s\"", name, path.Join(importPath, v.ImportPath)))
    84  	}
    85  	return ret
    86  }
    87  
    88  // parseTestSources parses the test sources and returns the package and set of test functions in them.
    89  func parseTestSources(sources []string) (testDescr, error) {
    90  	descr := testDescr{}
    91  	for _, source := range sources {
    92  		f, err := parser.ParseFile(token.NewFileSet(), source, nil, parser.ParseComments)
    93  		if err != nil {
    94  			log.Errorf("Error parsing %s: %s", source, err)
    95  			return descr, err
    96  		}
    97  		descr.Package = f.Name.Name
    98  		// If we're testing main, we will get errors from it clashing with func main.
    99  		if descr.Package == "main" {
   100  			descr.Package = "_main"
   101  		}
   102  		for _, d := range f.Decls {
   103  			if fd, ok := d.(*ast.FuncDecl); ok && fd.Recv == nil {
   104  				name := fd.Name.String()
   105  				if isTestMain(fd) {
   106  					descr.Main = name
   107  				} else if isTest(fd, 1, name, "Test") {
   108  					descr.Functions = append(descr.Functions, name)
   109  				}
   110  			}
   111  		}
   112  		// Get doc to find the examples for us :)
   113  		descr.Examples = append(descr.Examples, doc.Examples(f)...)
   114  	}
   115  	return descr, nil
   116  }
   117  
   118  // isTestMain returns true if fn is a TestMain(m *testing.M) function.
   119  // Copied from Go sources.
   120  func isTestMain(fn *ast.FuncDecl) bool {
   121  	if fn.Name.String() != "TestMain" ||
   122  		fn.Type.Results != nil && len(fn.Type.Results.List) > 0 ||
   123  		fn.Type.Params == nil ||
   124  		len(fn.Type.Params.List) != 1 ||
   125  		len(fn.Type.Params.List[0].Names) > 1 {
   126  		return false
   127  	}
   128  	ptr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
   129  	if !ok {
   130  		return false
   131  	}
   132  	// We can't easily check that the type is *testing.M
   133  	// because we don't know how testing has been imported,
   134  	// but at least check that it's *M or *something.M.
   135  	if name, ok := ptr.X.(*ast.Ident); ok && name.Name == "M" {
   136  		return true
   137  	}
   138  	if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == "M" {
   139  		return true
   140  	}
   141  	return false
   142  }
   143  
   144  // isTest returns true if the given function looks like a test.
   145  // Copied from Go sources.
   146  func isTest(fd *ast.FuncDecl, argLen int, name, prefix string) bool {
   147  	if !strings.HasPrefix(name, prefix) || fd.Recv != nil || len(fd.Type.Params.List) != argLen {
   148  		return false
   149  	} else if len(name) == len(prefix) { // "Test" is ok
   150  		return true
   151  	}
   152  	rune, _ := utf8.DecodeRuneInString(name[len(prefix):])
   153  	return !unicode.IsLower(rune)
   154  }
   155  
   156  // testMainTmpl is the template for our generated test main, copied from Go's builtin one.
   157  // Some bits are excluded because we don't support them and/or do them differently.
        //
        // It is executed with a testDescr and renders a 'package main' file that:
        //   - imports the entries of .Imports, plus testing/internal/testdeps when .Version18 is set;
        //   - lists .Functions and .Examples as testing.InternalTest / testing.InternalExample tables;
        //   - when .CoverVars is non-empty, registers each coverage variable from an init function and
        //     passes -test.coverprofile with the value of the COVERAGE_FILE environment variable;
        //   - always passes -test.v, plus -test.run with the TESTS environment variable when it is set;
        //   - passes nil for the benchmarks argument of testing.MainStart;
        //   - calls the package's .Main function if one was found, otherwise exits with m.Run().
   158  var testMainTmpl = template.Must(template.New("main").Parse(`
   159  package main
   160  
   161  import (
   162  	"os"
   163  	"testing"
   164  {{if .Version18}}
   165          "testing/internal/testdeps"
   166  {{end}}
   167  
   168  {{range .Imports}}
   169  	{{.}}
   170  {{end}}
   171  )
   172  
   173  var tests = []testing.InternalTest{
   174  {{range .Functions}}
   175  	{"{{.}}", {{$.Package}}.{{.}}},
   176  {{end}}
   177  }
   178  var examples = []testing.InternalExample{
   179  {{range .Examples}}
   180  	{"{{.Name}}", {{$.Package}}.Example{{.Name}}, {{.Output | printf "%q"}}, {{.Unordered}}},
   181  {{end}}
   182  }
   183  
   184  {{if .CoverVars}}
   185  
   186  // Only updated by init functions, so no need for atomicity.
   187  var (
   188  	coverCounters = make(map[string][]uint32)
   189  	coverBlocks = make(map[string][]testing.CoverBlock)
   190  )
   191  
   192  func init() {
   193  	{{range $i, $c := .CoverVars}}
   194  	coverRegisterFile({{printf "%q" $c.File}}, {{$c.ImportName}}.{{$c.Var}}.Count[:], {{$c.ImportName}}.{{$c.Var}}.Pos[:], {{$c.ImportName}}.{{$c.Var}}.NumStmt[:])
   195  	{{end}}
   196  }
   197  
   198  func coverRegisterFile(fileName string, counter []uint32, pos []uint32, numStmts []uint16) {
   199  	if 3*len(counter) != len(pos) || len(counter) != len(numStmts) {
   200  		panic("coverage: mismatched sizes")
   201  	}
   202  	if coverCounters[fileName] != nil {
   203  		// Already registered.
   204  		return
   205  	}
   206  	coverCounters[fileName] = counter
   207  	block := make([]testing.CoverBlock, len(counter))
   208  	for i := range counter {
   209  		block[i] = testing.CoverBlock{
   210  			Line0: pos[3*i+0],
   211  			Col0: uint16(pos[3*i+2]),
   212  			Line1: pos[3*i+1],
   213  			Col1: uint16(pos[3*i+2]>>16),
   214  			Stmts: numStmts[i],
   215  		}
   216  	}
   217  	coverBlocks[fileName] = block
   218  }
   219  {{end}}
   220  
   221  {{if .Version18}}
   222  var testDeps = testdeps.TestDeps{}
   223  {{else}}
   224  func testDeps(pat, str string) (bool, error) {
   225      return pat == str, nil
   226  }
   227  {{end}}
   228  
   229  func main() {
   230  {{if .CoverVars}}
   231  	testing.RegisterCover(testing.Cover{
   232  		Mode: "set",
   233  		Counters: coverCounters,
   234  		Blocks: coverBlocks,
   235  		CoveredPackages: "",
   236  	})
   237      coverfile := os.Getenv("COVERAGE_FILE")
   238      args := []string{os.Args[0], "-test.v", "-test.coverprofile", coverfile}
   239  {{else}}
   240      args := []string{os.Args[0], "-test.v"}
   241  {{end}}
   242      testVar := os.Getenv("TESTS")
   243      if testVar != "" {
   244          args = append(args, "-test.run", testVar)
   245      }
   246      os.Args = append(args, os.Args[1:]...)
   247  	m := testing.MainStart(testDeps, tests, nil, examples)
   248  {{if .Main}}
   249  	{{.Package}}.{{.Main}}(m)
   250  {{else}}
   251  	os.Exit(m.Run())
   252  {{end}}
   253  }
   254  `))