github.com/tiagovtristao/plz@v13.4.0+incompatible/tools/please_go_test/gotest/write_test_main.go (about) 1 package gotest 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/doc" 7 "go/parser" 8 "go/token" 9 "os" 10 "os/exec" 11 "path" 12 "regexp" 13 "strconv" 14 "strings" 15 "text/template" 16 "unicode" 17 "unicode/utf8" 18 ) 19 20 type testDescr struct { 21 Package string 22 Main string 23 Functions []string 24 Examples []*doc.Example 25 CoverVars []CoverVar 26 Imports []string 27 Version18 bool 28 } 29 30 // WriteTestMain templates a test main file from the given sources to the given output file. 31 // This mimics what 'go test' does, although we do not currently support benchmarks or examples. 32 func WriteTestMain(pkgDir, importPath string, version18 bool, sources []string, output string, coverVars []CoverVar) error { 33 testDescr, err := parseTestSources(sources) 34 if err != nil { 35 return err 36 } 37 testDescr.CoverVars = coverVars 38 testDescr.Version18 = version18 39 if len(testDescr.Functions) > 0 || len(testDescr.Examples) > 0 { 40 // Can't set this if there are no test functions, it'll be an unused import. 41 testDescr.Imports = extraImportPaths(testDescr.Package, pkgDir, importPath, coverVars) 42 } 43 44 f, err := os.Create(output) 45 if err != nil { 46 return err 47 } 48 defer f.Close() 49 // This might be consumed by other things. 50 fmt.Printf("Package: %s\n", testDescr.Package) 51 return testMainTmpl.Execute(f, testDescr) 52 } 53 54 // IsVersion18 returns true if the given Go tool is version 1.8 or greater. 
55 // This is needed because the test main signature has changed - it's not subject to the Go1 compatibility guarantee :( 56 func IsVersion18(goTool string) bool { 57 cmd := exec.Command(goTool, "version") 58 out, err := cmd.Output() 59 if err != nil { 60 log.Fatalf("Can't determine Go version: %s", err) 61 } 62 return isVersion18(out) 63 } 64 65 func isVersion18(version []byte) bool { 66 r := regexp.MustCompile("go version go1.([0-9]+)[^0-9].*") 67 m := r.FindSubmatch(version) 68 if len(m) == 0 { 69 log.Warning("Failed to match %s", version) 70 return false 71 } 72 v, _ := strconv.Atoi(string(m[1])) 73 return v >= 8 74 } 75 76 // extraImportPaths returns the set of extra import paths that are needed. 77 func extraImportPaths(pkg, pkgDir, importPath string, coverVars []CoverVar) []string { 78 pkgDir = collapseFinalDir(path.Join(pkgDir, pkg), importPath) 79 ret := []string{fmt.Sprintf("%s \"%s\"", pkg, path.Join(importPath, pkgDir))} 80 for i, v := range coverVars { 81 name := fmt.Sprintf("_cover%d", i) 82 coverVars[i].ImportName = name 83 ret = append(ret, fmt.Sprintf("%s \"%s\"", name, path.Join(importPath, v.ImportPath))) 84 } 85 return ret 86 } 87 88 // parseTestSources parses the test sources and returns the package and set of test functions in them. 89 func parseTestSources(sources []string) (testDescr, error) { 90 descr := testDescr{} 91 for _, source := range sources { 92 f, err := parser.ParseFile(token.NewFileSet(), source, nil, parser.ParseComments) 93 if err != nil { 94 log.Errorf("Error parsing %s: %s", source, err) 95 return descr, err 96 } 97 descr.Package = f.Name.Name 98 // If we're testing main, we will get errors from it clashing with func main. 
99 if descr.Package == "main" { 100 descr.Package = "_main" 101 } 102 for _, d := range f.Decls { 103 if fd, ok := d.(*ast.FuncDecl); ok && fd.Recv == nil { 104 name := fd.Name.String() 105 if isTestMain(fd) { 106 descr.Main = name 107 } else if isTest(fd, 1, name, "Test") { 108 descr.Functions = append(descr.Functions, name) 109 } 110 } 111 } 112 // Get doc to find the examples for us :) 113 descr.Examples = append(descr.Examples, doc.Examples(f)...) 114 } 115 return descr, nil 116 } 117 118 // isTestMain returns true if fn is a TestMain(m *testing.M) function. 119 // Copied from Go sources. 120 func isTestMain(fn *ast.FuncDecl) bool { 121 if fn.Name.String() != "TestMain" || 122 fn.Type.Results != nil && len(fn.Type.Results.List) > 0 || 123 fn.Type.Params == nil || 124 len(fn.Type.Params.List) != 1 || 125 len(fn.Type.Params.List[0].Names) > 1 { 126 return false 127 } 128 ptr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr) 129 if !ok { 130 return false 131 } 132 // We can't easily check that the type is *testing.M 133 // because we don't know how testing has been imported, 134 // but at least check that it's *M or *something.M. 135 if name, ok := ptr.X.(*ast.Ident); ok && name.Name == "M" { 136 return true 137 } 138 if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == "M" { 139 return true 140 } 141 return false 142 } 143 144 // isTest returns true if the given function looks like a test. 145 // Copied from Go sources. 146 func isTest(fd *ast.FuncDecl, argLen int, name, prefix string) bool { 147 if !strings.HasPrefix(name, prefix) || fd.Recv != nil || len(fd.Type.Params.List) != argLen { 148 return false 149 } else if len(name) == len(prefix) { // "Test" is ok 150 return true 151 } 152 rune, _ := utf8.DecodeRuneInString(name[len(prefix):]) 153 return !unicode.IsLower(rune) 154 } 155 156 // testMainTmpl is the template for our test main, copied from Go's builtin one. 
// Some bits are excluded because we don't support them and/or do them differently.
//
// It is executed with a testDescr. The generated main:
//   - registers Functions as tests and Examples as examples; benchmarks are passed as nil,
//   - when CoverVars is non-empty, registers each file's counters in "set" mode and
//     passes -test.coverprofile with the path taken from $COVERAGE_FILE,
//   - always passes -test.v, adds -test.run $TESTS when that variable is set, and
//     appends the binary's own command-line arguments after these defaults,
//   - calls the package's TestMain if one was found, otherwise os.Exit(m.Run()).
//
// With Version18 set, a testdeps.TestDeps value is passed to testing.MainStart;
// otherwise a plain string-equality match function is passed instead.
// NOTE(review): testing.MainStart is not covered by the Go1 compatibility
// guarantee, so this 4-argument call must be re-checked against the Go
// versions in use.
var testMainTmpl = template.Must(template.New("main").Parse(`
package main

import (
	"os"
	"testing"
{{if .Version18}}
	"testing/internal/testdeps"
{{end}}

{{range .Imports}}
	{{.}}
{{end}}
)

var tests = []testing.InternalTest{
{{range .Functions}}
	{"{{.}}", {{$.Package}}.{{.}}},
{{end}}
}
var examples = []testing.InternalExample{
{{range .Examples}}
	{"{{.Name}}", {{$.Package}}.Example{{.Name}}, {{.Output | printf "%q"}}, {{.Unordered}}},
{{end}}
}

{{if .CoverVars}}

// Only updated by init functions, so no need for atomicity.
var (
	coverCounters = make(map[string][]uint32)
	coverBlocks = make(map[string][]testing.CoverBlock)
)

func init() {
	{{range $i, $c := .CoverVars}}
	coverRegisterFile({{printf "%q" $c.File}}, {{$c.ImportName}}.{{$c.Var}}.Count[:], {{$c.ImportName}}.{{$c.Var}}.Pos[:], {{$c.ImportName}}.{{$c.Var}}.NumStmt[:])
	{{end}}
}

func coverRegisterFile(fileName string, counter []uint32, pos []uint32, numStmts []uint16) {
	if 3*len(counter) != len(pos) || len(counter) != len(numStmts) {
		panic("coverage: mismatched sizes")
	}
	if coverCounters[fileName] != nil {
		// Already registered.
		return
	}
	coverCounters[fileName] = counter
	block := make([]testing.CoverBlock, len(counter))
	for i := range counter {
		block[i] = testing.CoverBlock{
			Line0: pos[3*i+0],
			Col0: uint16(pos[3*i+2]),
			Line1: pos[3*i+1],
			Col1: uint16(pos[3*i+2]>>16),
			Stmts: numStmts[i],
		}
	}
	coverBlocks[fileName] = block
}
{{end}}

{{if .Version18}}
var testDeps = testdeps.TestDeps{}
{{else}}
func testDeps(pat, str string) (bool, error) {
	return pat == str, nil
}
{{end}}

func main() {
{{if .CoverVars}}
	testing.RegisterCover(testing.Cover{
		Mode: "set",
		Counters: coverCounters,
		Blocks: coverBlocks,
		CoveredPackages: "",
	})
	coverfile := os.Getenv("COVERAGE_FILE")
	args := []string{os.Args[0], "-test.v", "-test.coverprofile", coverfile}
{{else}}
	args := []string{os.Args[0], "-test.v"}
{{end}}
	testVar := os.Getenv("TESTS")
	if testVar != "" {
		args = append(args, "-test.run", testVar)
	}
	os.Args = append(args, os.Args[1:]...)
	m := testing.MainStart(testDeps, tests, nil, examples)
{{if .Main}}
	{{.Package}}.{{.Main}}(m)
{{else}}
	os.Exit(m.Run())
{{end}}
}
`))