github.phpd.cn/thought-machine/please@v12.2.0+incompatible/tools/please_go_test/gotest/write_test_main.go (about) 1 package gotest 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/parser" 7 "go/token" 8 "os" 9 "os/exec" 10 "path" 11 "regexp" 12 "strconv" 13 "strings" 14 "text/template" 15 "unicode" 16 "unicode/utf8" 17 ) 18 19 type testDescr struct { 20 Package string 21 Main string 22 Functions []string 23 CoverVars []CoverVar 24 Imports []string 25 Version18 bool 26 } 27 28 // WriteTestMain templates a test main file from the given sources to the given output file. 29 // This mimics what 'go test' does, although we do not currently support benchmarks or examples. 30 func WriteTestMain(pkgDir string, version18 bool, sources []string, output string, coverVars []CoverVar) error { 31 testDescr, err := parseTestSources(sources) 32 if err != nil { 33 return err 34 } 35 testDescr.CoverVars = coverVars 36 testDescr.Version18 = version18 37 if len(testDescr.Functions) > 0 { 38 // Can't set this if there are no test functions, it'll be an unused import. 39 testDescr.Imports = extraImportPaths(testDescr.Package, pkgDir, coverVars) 40 } 41 42 f, err := os.Create(output) 43 if err != nil { 44 return err 45 } 46 defer f.Close() 47 // This might be consumed by other things. 48 fmt.Printf("Package: %s\n", testDescr.Package) 49 return testMainTmpl.Execute(f, testDescr) 50 } 51 52 // IsVersion18 returns true if the given Go tool is version 1.8 or greater. 53 // This is needed because the test main signature has changed - it's not subject to the Go1 compatibility guarantee :( 54 func IsVersion18(goTool string) bool { 55 cmd := exec.Command(goTool, "version") 56 out, err := cmd.Output() 57 if err != nil { 58 log.Fatalf("Can't determine Go version: %s", err) 59 } 60 return isVersion18(out) 61 } 62 63 func isVersion18(version []byte) bool { 64 r := regexp.MustCompile("go version go1.([0-9]+)[^0-9].*") 65 m := r.FindSubmatch(version) 66 if len(m) == 0 { 67 log.Warning("Failed to match %s", version) 68 return false 69 } 70 v, _ := strconv.Atoi(string(m[1])) 71 return v >= 8 72 } 73 74 // extraImportPaths returns the set of extra import paths that are needed. 75 func extraImportPaths(pkg, pkgDir string, coverVars []CoverVar) []string { 76 pkgDir = collapseFinalDir(path.Join(strings.TrimPrefix(pkgDir, "src/"), pkg)) 77 ret := []string{fmt.Sprintf("%s \"%s\"", pkg, pkgDir)} 78 for i, v := range coverVars { 79 name := fmt.Sprintf("_cover%d", i) 80 coverVars[i].ImportName = name 81 ret = append(ret, fmt.Sprintf("%s \"%s\"", name, v.ImportPath)) 82 } 83 return ret 84 } 85 86 // parseTestSources parses the test sources and returns the package and set of test functions in them. 87 func parseTestSources(sources []string) (testDescr, error) { 88 descr := testDescr{} 89 for _, source := range sources { 90 f, err := parser.ParseFile(token.NewFileSet(), source, nil, 0) 91 if err != nil { 92 log.Errorf("Error parsing %s: %s", source, err) 93 return descr, err 94 } 95 descr.Package = f.Name.Name 96 // If we're testing main, we will get errors from it clashing with func main. 97 if descr.Package == "main" { 98 descr.Package = "_main" 99 } 100 for _, d := range f.Decls { 101 if fd, ok := d.(*ast.FuncDecl); ok && fd.Recv == nil { 102 name := fd.Name.String() 103 if isTestMain(fd) { 104 descr.Main = name 105 } else if isTest(name, "Test") { 106 descr.Functions = append(descr.Functions, name) 107 } 108 } 109 } 110 } 111 return descr, nil 112 } 113 114 // isTestMain returns true if fn is a TestMain(m *testing.M) function. 115 // Copied from Go sources. 116 func isTestMain(fn *ast.FuncDecl) bool { 117 if fn.Name.String() != "TestMain" || 118 fn.Type.Results != nil && len(fn.Type.Results.List) > 0 || 119 fn.Type.Params == nil || 120 len(fn.Type.Params.List) != 1 || 121 len(fn.Type.Params.List[0].Names) > 1 { 122 return false 123 } 124 ptr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr) 125 if !ok { 126 return false 127 } 128 // We can't easily check that the type is *testing.M 129 // because we don't know how testing has been imported, 130 // but at least check that it's *M or *something.M. 131 if name, ok := ptr.X.(*ast.Ident); ok && name.Name == "M" { 132 return true 133 } 134 if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == "M" { 135 return true 136 } 137 return false 138 } 139 140 // isTest returns true if the given function looks like a test. 141 // Copied from Go sources. 142 func isTest(name, prefix string) bool { 143 if !strings.HasPrefix(name, prefix) { 144 return false 145 } 146 if len(name) == len(prefix) { // "Test" is ok 147 return true 148 } 149 rune, _ := utf8.DecodeRuneInString(name[len(prefix):]) 150 return !unicode.IsLower(rune) 151 } 152 153 // testMainTmpl is the template for our test main, copied from Go's builtin one. 154 // Some bits are excluded because we don't support them and/or do them differently. 155 var testMainTmpl = template.Must(template.New("main").Parse(` 156 package main 157 158 import ( 159 "os" 160 "testing" 161 {{if .Version18}} 162 "testing/internal/testdeps" 163 {{end}} 164 165 {{range .Imports}} 166 {{.}} 167 {{end}} 168 ) 169 170 var tests = []testing.InternalTest{ 171 {{range .Functions}} 172 {"{{.}}", {{$.Package}}.{{.}}}, 173 {{end}} 174 } 175 176 {{if .CoverVars}} 177 178 // Only updated by init functions, so no need for atomicity. 179 var ( 180 coverCounters = make(map[string][]uint32) 181 coverBlocks = make(map[string][]testing.CoverBlock) 182 ) 183 184 func init() { 185 {{range $i, $c := .CoverVars}} 186 coverRegisterFile({{printf "%q" $c.File}}, {{$c.ImportName}}.{{$c.Var}}.Count[:], {{$c.ImportName}}.{{$c.Var}}.Pos[:], {{$c.ImportName}}.{{$c.Var}}.NumStmt[:]) 187 {{end}} 188 } 189 190 func coverRegisterFile(fileName string, counter []uint32, pos []uint32, numStmts []uint16) { 191 if 3*len(counter) != len(pos) || len(counter) != len(numStmts) { 192 panic("coverage: mismatched sizes") 193 } 194 if coverCounters[fileName] != nil { 195 // Already registered. 196 return 197 } 198 coverCounters[fileName] = counter 199 block := make([]testing.CoverBlock, len(counter)) 200 for i := range counter { 201 block[i] = testing.CoverBlock{ 202 Line0: pos[3*i+0], 203 Col0: uint16(pos[3*i+2]), 204 Line1: pos[3*i+1], 205 Col1: uint16(pos[3*i+2]>>16), 206 Stmts: numStmts[i], 207 } 208 } 209 coverBlocks[fileName] = block 210 } 211 {{end}} 212 213 {{if .Version18}} 214 var testDeps = testdeps.TestDeps{} 215 {{else}} 216 func testDeps(pat, str string) (bool, error) { 217 return pat == str, nil 218 } 219 {{end}} 220 221 func main() { 222 {{if .CoverVars}} 223 testing.RegisterCover(testing.Cover{ 224 Mode: "set", 225 Counters: coverCounters, 226 Blocks: coverBlocks, 227 CoveredPackages: "", 228 }) 229 coverfile := os.Getenv("COVERAGE_FILE") 230 args := []string{os.Args[0], "-test.v", "-test.coverprofile", coverfile} 231 {{else}} 232 args := []string{os.Args[0], "-test.v"} 233 {{end}} 234 testVar := os.Getenv("TESTS") 235 if testVar != "" { 236 args = append(args, "-test.run", testVar) 237 } 238 os.Args = append(args, os.Args[1:]...) 239 benchmarks := []testing.InternalBenchmark{} 240 var examples = []testing.InternalExample{} 241 m := testing.MainStart(testDeps, tests, benchmarks, examples) 242 {{if .Main}} 243 {{.Package}}.{{.Main}}(m) 244 {{else}} 245 os.Exit(m.Run()) 246 {{end}} 247 } 248 `))