github.com/aca02djr/gb@v0.4.1/test/gotest.go (about)

     1  // Copyright 2011 The Go Authors.  All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package test
     6  
     7  // imported from $GOROOT/src/cmd/go/test.go
     8  
     9  import (
    10  	"bytes"
    11  	"errors"
    12  	"go/ast"
    13  	"go/doc"
    14  	"go/parser"
    15  	"go/scanner"
    16  	"go/token"
    17  	"os"
    18  	"path/filepath"
    19  	"sort"
    20  	"strings"
    21  	"text/template"
    22  	"unicode"
    23  	"unicode/utf8"
    24  
    25  	"github.com/constabulary/gb"
    26  	"github.com/constabulary/gb/debug"
    27  	"github.com/constabulary/gb/importer"
    28  )
    29  
// coverInfo describes one package instrumented for coverage, as consumed by
// testmainTmpl (each entry is imported by the generated main as _cover<i>).
// NOTE(review): nothing visible in this file populates it, and CoverEnabled
// always returns false, so coverage appears to be unsupported here.
type coverInfo struct {
	Package *gb.Package          // the instrumented package; its ImportPath is imported by _testmain.go
	Vars    map[string]*CoverVar // generated coverage variables; presumably keyed by file name (the template ranges over it as $file) — TODO confirm
}
    34  
// CoverVar holds the name of the generated coverage variable targeting the
// named file. The generated test main refers to its Count, Pos and NumStmt
// fields as _cover<i>.<Var>.Count and so on, and registers them under File.
type CoverVar struct {
	File string // local file name, passed to coverRegisterFile in the generated code
	Var  string // name of the generated count struct variable
}
    40  
    41  var cwd, _ = os.Getwd()
    42  
    43  // shortPath returns an absolute or relative name for path, whatever is shorter.
    44  func shortPath(path string) string {
    45  	if rel, err := filepath.Rel(cwd, path); err == nil && len(rel) < len(path) {
    46  		return rel
    47  	}
    48  	return path
    49  }
    50  
    51  // isTestMain tells whether fn is a TestMain(m *testing.M) function.
    52  func isTestMain(fn *ast.FuncDecl) bool {
    53  	if fn.Name.String() != "TestMain" ||
    54  		fn.Type.Results != nil && len(fn.Type.Results.List) > 0 ||
    55  		fn.Type.Params == nil ||
    56  		len(fn.Type.Params.List) != 1 ||
    57  		len(fn.Type.Params.List[0].Names) > 1 {
    58  		return false
    59  	}
    60  	ptr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
    61  	if !ok {
    62  		return false
    63  	}
    64  	// We can't easily check that the type is *testing.M
    65  	// because we don't know how testing has been imported,
    66  	// but at least check that it's *M or *something.M.
    67  	if name, ok := ptr.X.(*ast.Ident); ok && name.Name == "M" {
    68  		return true
    69  	}
    70  	if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == "M" {
    71  		return true
    72  	}
    73  	return false
    74  }
    75  
    76  // isTest tells whether name looks like a test (or benchmark, according to prefix).
    77  // It is a Test (say) if there is a character after Test that is not a lower-case letter.
    78  // We don't want TesticularCancer.
    79  func isTest(name, prefix string) bool {
    80  	if !strings.HasPrefix(name, prefix) {
    81  		return false
    82  	}
    83  	if len(name) == len(prefix) { // "Test" is ok
    84  		return true
    85  	}
    86  	rune, _ := utf8.DecodeRuneInString(name[len(prefix):])
    87  	return !unicode.IsLower(rune)
    88  }
    89  
    90  // loadTestFuncs returns the testFuncs describing the tests that will be run.
    91  func loadTestFuncs(ptest *importer.Package) (*testFuncs, error) {
    92  	t := &testFuncs{
    93  		Package: ptest,
    94  	}
    95  	debug.Debugf("loadTestFuncs: %v, %v", ptest.TestGoFiles, ptest.XTestGoFiles)
    96  	for _, file := range ptest.TestGoFiles {
    97  		if err := t.load(filepath.Join(ptest.Dir, file), "_test", &t.ImportTest, &t.NeedTest); err != nil {
    98  			return nil, err
    99  		}
   100  	}
   101  	for _, file := range ptest.XTestGoFiles {
   102  		if err := t.load(filepath.Join(ptest.Dir, file), "_xtest", &t.ImportXtest, &t.NeedXtest); err != nil {
   103  			return nil, err
   104  		}
   105  	}
   106  	return t, nil
   107  }
   108  
   109  // writeTestmain writes the _testmain.go file for t to the file named out.
   110  func writeTestmain(out string, t *testFuncs) error {
   111  	f, err := os.Create(out)
   112  	if err != nil {
   113  		return err
   114  	}
   115  	defer f.Close()
   116  
   117  	if err := testmainTmpl.Execute(f, t); err != nil {
   118  		return err
   119  	}
   120  
   121  	return nil
   122  }
   123  
   124  // expandScanner expands a scanner.List error into all the errors in the list.
   125  // The default Error method only shows the first error.
   126  func expandScanner(err error) error {
   127  	// Look for parser errors.
   128  	if err, ok := err.(scanner.ErrorList); ok {
   129  		// Prepare error with \n before each message.
   130  		// When printed in something like context: %v
   131  		// this will put the leading file positions each on
   132  		// its own line.  It will also show all the errors
   133  		// instead of just the first, as err.Error does.
   134  		var buf bytes.Buffer
   135  		for _, e := range err {
   136  			e.Pos.Filename = shortPath(e.Pos.Filename)
   137  			buf.WriteString("\n")
   138  			buf.WriteString(e.Error())
   139  		}
   140  		return errors.New(buf.String())
   141  	}
   142  	return err
   143  }
   144  
// testFuncs collects everything testmainTmpl needs to render the generated
// _testmain.go for one package under test. It is filled in by
// loadTestFuncs / load.
type testFuncs struct {
	Tests       []testFunc        // Test* functions from both test packages, in the order found
	Benchmarks  []testFunc        // Benchmark* functions, in the order found
	Examples    []testFunc        // examples with an output comment, i.e. the ones that will be run
	TestMain    *testFunc         // the package's TestMain, or nil if it has none
	Package     *importer.Package // the package under test
	ImportTest  bool              // whether the in-package test files must be imported at all
	NeedTest    bool              // whether that import is named _test (functions referenced) rather than blank
	ImportXtest bool              // like ImportTest, for the external <importpath>_test package
	NeedXtest   bool              // like NeedTest; the import is named _xtest when set
	NeedCgo     bool              // when true, the template adds a blank import of runtime/cgo
	Cover       []coverInfo       // packages instrumented for coverage, imported as _cover<i>
}
   158  
   159  func (t *testFuncs) CoverMode() string {
   160  	return ""
   161  }
   162  
   163  func (t *testFuncs) CoverEnabled() bool {
   164  	return false
   165  }
   166  
   167  // Covered returns a string describing which packages are being tested for coverage.
   168  // If the covered package is the same as the tested package, it returns the empty string.
   169  // Otherwise it is a comma-separated human-readable list of packages beginning with
   170  // " in", ready for use in the coverage message.
   171  func (t *testFuncs) Covered() string {
   172  	return ""
   173  }
   174  
   175  // Tested returns the name of the package being tested.
   176  func (t *testFuncs) Tested() string {
   177  	return t.Package.Name
   178  }
   179  
// testFunc names one function that the generated test main refers to as
// <Package>.<Name>.
type testFunc struct {
	Package string // imported package name (_test or _xtest)
	Name    string // function name
	Output  string // expected output, for examples; empty for tests, benchmarks and TestMain
}
   185  
   186  var testFileSet = token.NewFileSet()
   187  
   188  func (t *testFuncs) load(filename, pkg string, doImport, seen *bool) error {
   189  	f, err := parser.ParseFile(testFileSet, filename, nil, parser.ParseComments)
   190  	if err != nil {
   191  		return expandScanner(err)
   192  	}
   193  	for _, d := range f.Decls {
   194  		n, ok := d.(*ast.FuncDecl)
   195  		if !ok {
   196  			continue
   197  		}
   198  		if n.Recv != nil {
   199  			continue
   200  		}
   201  		name := n.Name.String()
   202  		switch {
   203  		case isTestMain(n):
   204  			if t.TestMain != nil {
   205  				return errors.New("multiple definitions of TestMain")
   206  			}
   207  			t.TestMain = &testFunc{pkg, name, ""}
   208  			*doImport, *seen = true, true
   209  		case isTest(name, "Test"):
   210  			t.Tests = append(t.Tests, testFunc{pkg, name, ""})
   211  			*doImport, *seen = true, true
   212  		case isTest(name, "Benchmark"):
   213  			t.Benchmarks = append(t.Benchmarks, testFunc{pkg, name, ""})
   214  			*doImport, *seen = true, true
   215  		}
   216  	}
   217  	ex := doc.Examples(f)
   218  	sort.Sort(byOrder(ex))
   219  	for _, e := range ex {
   220  		*doImport = true // import test file whether executed or not
   221  		if e.Output == "" && !e.EmptyOutput {
   222  			// Don't run examples with no output.
   223  			continue
   224  		}
   225  		t.Examples = append(t.Examples, testFunc{pkg, "Example" + e.Name, e.Output})
   226  		*seen = true
   227  	}
   228  	return nil
   229  }
   230  
   231  type byOrder []*doc.Example
   232  
   233  func (x byOrder) Len() int           { return len(x) }
   234  func (x byOrder) Swap(i, j int)      { x[i], x[j] = x[j], x[i] }
   235  func (x byOrder) Less(i, j int) bool { return x[i].Order < x[j].Order }
   236  
// testmainTmpl generates the _testmain.go program for a package under test
// from a *testFuncs (rendered by writeTestmain). The generated program:
//   - imports the in-package (_test) and external (_xtest) test packages,
//     as blank imports when none of their functions are referenced;
//   - lists the discovered tests, benchmarks and runnable examples;
//   - registers coverage counters only when .CoverEnabled is true;
//   - calls the package's TestMain if there is one, and otherwise exits with
//     the result of m.Run() (which is why "os" is imported only then).
//
// NOTE(review): testing.MainStart is an internal hook whose signature has
// changed between Go releases (Go 1.8 replaced the matchString argument with
// a testDeps interface, and later releases added further parameters). This
// template emits the older four-argument call; confirm it matches the Go
// toolchain in use.
//
// Everything inside the backquoted string is output text, not commentary:
// editing it changes the generated program.
var testmainTmpl = template.Must(template.New("main").Parse(`
package main

import (
{{if not .TestMain}}
	"os"
{{end}}
	"regexp"
	"testing"

{{if .ImportTest}}
	{{if .NeedTest}}_test{{else}}_{{end}} {{.Package.ImportPath | printf "%q"}}
{{end}}
{{if .ImportXtest}}
	{{if .NeedXtest}}_xtest{{else}}_{{end}} {{.Package.ImportPath | printf "%s_test" | printf "%q"}}
{{end}}
{{range $i, $p := .Cover}}
	_cover{{$i}} {{$p.Package.ImportPath | printf "%q"}}
{{end}}

{{if .NeedCgo}}
	_ "runtime/cgo"
{{end}}
)

var tests = []testing.InternalTest{
{{range .Tests}}
	{"{{.Name}}", {{.Package}}.{{.Name}}},
{{end}}
}

var benchmarks = []testing.InternalBenchmark{
{{range .Benchmarks}}
	{"{{.Name}}", {{.Package}}.{{.Name}}},
{{end}}
}

var examples = []testing.InternalExample{
{{range .Examples}}
	{Name: "{{.Name}}", F: {{.Package}}.{{.Name}}, Output: {{.Output | printf "%q"}}},
{{end}}
}

var matchPat string
var matchRe *regexp.Regexp

func matchString(pat, str string) (result bool, err error) {
	if matchRe == nil || matchPat != pat {
		matchPat = pat
		matchRe, err = regexp.Compile(matchPat)
		if err != nil {
			return
		}
	}
	return matchRe.MatchString(str), nil
}

{{if .CoverEnabled}}

// Only updated by init functions, so no need for atomicity.
var (
	coverCounters = make(map[string][]uint32)
	coverBlocks = make(map[string][]testing.CoverBlock)
)

func init() {
	{{range $i, $p := .Cover}}
	{{range $file, $cover := $p.Vars}}
	coverRegisterFile({{printf "%q" $cover.File}}, _cover{{$i}}.{{$cover.Var}}.Count[:], _cover{{$i}}.{{$cover.Var}}.Pos[:], _cover{{$i}}.{{$cover.Var}}.NumStmt[:])
	{{end}}
	{{end}}
}

func coverRegisterFile(fileName string, counter []uint32, pos []uint32, numStmts []uint16) {
	if 3*len(counter) != len(pos) || len(counter) != len(numStmts) {
		panic("coverage: mismatched sizes")
	}
	if coverCounters[fileName] != nil {
		// Already registered.
		return
	}
	coverCounters[fileName] = counter
	block := make([]testing.CoverBlock, len(counter))
	for i := range counter {
		block[i] = testing.CoverBlock{
			Line0: pos[3*i+0],
			Col0: uint16(pos[3*i+2]),
			Line1: pos[3*i+1],
			Col1: uint16(pos[3*i+2]>>16),
			Stmts: numStmts[i],
		}
	}
	coverBlocks[fileName] = block
}
{{end}}

func main() {
{{if .CoverEnabled}}
	testing.RegisterCover(testing.Cover{
		Mode: {{printf "%q" .CoverMode}},
		Counters: coverCounters,
		Blocks: coverBlocks,
		CoveredPackages: {{printf "%q" .Covered}},
	})
{{end}}
	m := testing.MainStart(matchString, tests, benchmarks, examples)
{{with .TestMain}}
	{{.Package}}.{{.Name}}(m)
{{else}}
	os.Exit(m.Run())
{{end}}
}

`))