github.com/gernest/nezuko@v0.1.2/internal/load/test.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package load
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/doc"
    13  	"go/parser"
    14  	"go/token"
    15  	"path/filepath"
    16  	"sort"
    17  	"strings"
    18  	"text/template"
    19  	"unicode"
    20  	"unicode/utf8"
    21  
    22  	"github.com/gernest/nezuko/internal/base"
    23  	"github.com/gernest/nezuko/internal/str"
    24  )
    25  
    26  var TestMainDeps = []string{
    27  	// Dependencies for testmain.
    28  	"os",
    29  	"testing",
    30  	"testing/internal/testdeps",
    31  }
    32  
// TestCover carries the coverage configuration consumed when generating the
// test main program. Fields are documented from how this file uses them;
// fields this file never reads are described tentatively.
type TestCover struct {
	// Mode is the coverage mode; it is quoted into the generated program's
	// testing.RegisterCover call (see testmainTmpl).
	Mode string

	// Local is not read in this file.
	// NOTE(review): presumably limits coverage to local packages — confirm.
	Local bool

	// Pkgs is not read in this file.
	// NOTE(review): presumably the packages selected for coverage — confirm.
	Pkgs []*Package

	// Paths lists the package paths reported by testFuncs.Covered.
	Paths []string

	// Vars holds one entry per coverage package; the generated program
	// imports entry i's Package as _cover<i> and registers its variables.
	Vars []coverInfo

	// DeclVars is not called in this file.
	// NOTE(review): presumably declares coverage variables for a package and
	// the given file names — confirm against callers.
	DeclVars func(*Package, ...string) map[string]*CoverVar
}
    41  
    42  func testImportStack(top string, p *Package, target string) []string {
    43  	stk := []string{top, p.ImportPath}
    44  Search:
    45  	for p.ImportPath != target {
    46  		for _, p1 := range p.Internal.Imports {
    47  			if p1.ImportPath == target || str.Contains(p1.Deps, target) {
    48  				stk = append(stk, p1.ImportPath)
    49  				p = p1
    50  				continue Search
    51  			}
    52  		}
    53  		// Can't happen, but in case it does...
    54  		stk = append(stk, "<lost path to cycle>")
    55  		break
    56  	}
    57  	return stk
    58  }
    59  
// recompileForTest copies and replaces certain packages in pmain's dependency
// graph. This is necessary for two reasons. First, if ptest is different than
// preal, packages that import the package under test should get ptest instead
// of preal. This is particularly important if pxtest depends on functionality
// exposed in test sources in ptest. Second, if there is a main package
// (other than pmain) anywhere, we need to set p.Internal.ForceLibrary and
// clear p.Internal.BuildInfo in the test copy to prevent link conflicts.
// This may happen if both -coverpkg and the command line patterns include
// multiple main packages.
//
// Copies are made lazily ("copy on write"). A package is duplicated only when
// one of its imports has been replaced, or when it carries BuildInfo and is
// not pmain. pmain and pxtest are never duplicated; their import lists are
// rewritten in place. Each copy has ForTest set to preal's import path, an
// empty Target, and its own Imports and Internal.Imports slices, so rewriting
// a copy's imports leaves the original package untouched.
func recompileForTest(pmain, preal, ptest, pxtest *Package) {
	// The "test copy" of preal is ptest.
	// For each package that depends on preal, make a "test copy"
	// that depends on ptest. And so on, up the dependency tree.
	// NOTE(review): this relies on PackageList yielding each package after
	// its imports, so testCopy already holds any replacement for an import
	// by the time its importer is visited — confirm PackageList's ordering.
	testCopy := map[*Package]*Package{preal: ptest}
	for _, p := range PackageList([]*Package{pmain}) {
		if p == preal {
			continue
		}
		// Copy on write.
		didSplit := p == pmain || p == pxtest
		// split replaces the loop variable p (captured by this closure) with
		// a new copy registered in testCopy. It does nothing if p was already
		// split in this iteration, or if p is pmain/pxtest.
		split := func() {
			if didSplit {
				return
			}
			didSplit = true
			// testCopy[p] is only set by split, so an existing entry means
			// this package was visited twice.
			if testCopy[p] != nil {
				panic("recompileForTest loop")
			}
			p1 := new(Package)
			testCopy[p] = p1
			*p1 = *p
			p1.ForTest = preal.ImportPath
			// Clone the import slices: the shallow struct copy above would
			// otherwise share them with the original package.
			p1.Internal.Imports = make([]*Package, len(p.Internal.Imports))
			copy(p1.Internal.Imports, p.Internal.Imports)
			p1.Imports = make([]string, len(p.Imports))
			copy(p1.Imports, p.Imports)
			// From here on, code in this iteration operates on the copy.
			p = p1
			p.Target = ""
			p.Internal.BuildInfo = ""
			p.Internal.ForceLibrary = true
		}

		// Update p.Internal.Imports to use test copies.
		// The range expression is evaluated once, before any split, so the
		// loop reads the original slice while the assignment writes into p's
		// slice, which after split() is the copy's cloned slice.
		for i, imp := range p.Internal.Imports {
			if p1 := testCopy[imp]; p1 != nil && p1 != imp {
				split()
				p.Internal.Imports[i] = p1
			}
		}

		// Don't compile build info from a main package. This can happen
		// if -coverpkg patterns include main packages, since those packages
		// are imported by pmain. See golang.org/issue/30907.
		if p.Internal.BuildInfo != "" && p != pmain {
			split()
		}
	}
}
   118  
   119  // isTestFunc tells whether fn has the type of a testing function. arg
   120  // specifies the parameter type we look for: B, M or T.
   121  func isTestFunc(fn *ast.FuncDecl, arg string) bool {
   122  	if fn.Type.Results != nil && len(fn.Type.Results.List) > 0 ||
   123  		fn.Type.Params.List == nil ||
   124  		len(fn.Type.Params.List) != 1 ||
   125  		len(fn.Type.Params.List[0].Names) > 1 {
   126  		return false
   127  	}
   128  	ptr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
   129  	if !ok {
   130  		return false
   131  	}
   132  	// We can't easily check that the type is *testing.M
   133  	// because we don't know how testing has been imported,
   134  	// but at least check that it's *M or *something.M.
   135  	// Same applies for B and T.
   136  	if name, ok := ptr.X.(*ast.Ident); ok && name.Name == arg {
   137  		return true
   138  	}
   139  	if sel, ok := ptr.X.(*ast.SelectorExpr); ok && sel.Sel.Name == arg {
   140  		return true
   141  	}
   142  	return false
   143  }
   144  
   145  // isTest tells whether name looks like a test (or benchmark, according to prefix).
   146  // It is a Test (say) if there is a character after Test that is not a lower-case letter.
   147  // We don't want TesticularCancer.
   148  func isTest(name, prefix string) bool {
   149  	if !strings.HasPrefix(name, prefix) {
   150  		return false
   151  	}
   152  	if len(name) == len(prefix) { // "Test" is ok
   153  		return true
   154  	}
   155  	rune, _ := utf8.DecodeRuneInString(name[len(prefix):])
   156  	return !unicode.IsLower(rune)
   157  }
   158  
// coverInfo pairs a coverage package with its coverage variables. The
// generated test main program imports Package as _cover<i> (i being the
// entry's index in TestCover.Vars) and, for every value in Vars, registers
// that variable's counters via coverRegisterFile (see testmainTmpl).
type coverInfo struct {
	// Package is the package whose coverage counters are registered.
	Package *Package

	// Vars holds the package's coverage variables. The template ranges over
	// it with the key bound to $file.
	// NOTE(review): so the key is presumably a source file name — confirm.
	Vars map[string]*CoverVar
}
   163  
   164  // loadTestFuncs returns the testFuncs describing the tests that will be run.
   165  func loadTestFuncs(ptest *Package) (*testFuncs, error) {
   166  	t := &testFuncs{
   167  		Package: ptest,
   168  	}
   169  	for _, file := range ptest.TestGoFiles {
   170  		if err := t.load(filepath.Join(ptest.Dir, file), "_test", &t.ImportTest, &t.NeedTest); err != nil {
   171  			return nil, err
   172  		}
   173  	}
   174  	for _, file := range ptest.XTestGoFiles {
   175  		if err := t.load(filepath.Join(ptest.Dir, file), "_xtest", &t.ImportXtest, &t.NeedXtest); err != nil {
   176  			return nil, err
   177  		}
   178  	}
   179  	return t, nil
   180  }
   181  
   182  // formatTestmain returns the content of the _testmain.go file for t.
   183  func formatTestmain(t *testFuncs) ([]byte, error) {
   184  	var buf bytes.Buffer
   185  	if err := testmainTmpl.Execute(&buf, t); err != nil {
   186  		return nil, err
   187  	}
   188  	return buf.Bytes(), nil
   189  }
   190  
// testFuncs collects what loadTestFuncs finds in a package's test files. It
// is the data value that formatTestmain passes to testmainTmpl.
type testFuncs struct {
	Tests       []testFunc // TestXxx functions, in the order they were found
	Benchmarks  []testFunc // BenchmarkXxx functions, in the order they were found
	Examples    []testFunc // examples with output to check; others are skipped
	TestMain    *testFunc  // the TestMain(m *testing.M) function, or nil
	Package     *Package   // the package under test
	ImportTest  bool       // generated main imports the in-package test variant
	NeedTest    bool       // ...and names that import _test instead of _
	ImportXtest bool       // generated main imports the external _test package
	NeedXtest   bool       // ...and names that import _xtest instead of _
	Cover       *TestCover // coverage settings; nil omits all coverage code
}
   203  
   204  // ImportPath returns the import path of the package being tested, if it is within GOPATH.
   205  // This is printed by the testing package when running benchmarks.
   206  func (t *testFuncs) ImportPath() string {
   207  	pkg := t.Package.ImportPath
   208  	if strings.HasPrefix(pkg, "_/") {
   209  		return ""
   210  	}
   211  	if pkg == "command-line-arguments" {
   212  		return ""
   213  	}
   214  	return pkg
   215  }
   216  
   217  // Covered returns a string describing which packages are being tested for coverage.
   218  // If the covered package is the same as the tested package, it returns the empty string.
   219  // Otherwise it is a comma-separated human-readable list of packages beginning with
   220  // " in", ready for use in the coverage message.
   221  func (t *testFuncs) Covered() string {
   222  	if t.Cover == nil || t.Cover.Paths == nil {
   223  		return ""
   224  	}
   225  	return " in " + strings.Join(t.Cover.Paths, ", ")
   226  }
   227  
   228  // Tested returns the name of the package being tested.
   229  func (t *testFuncs) Tested() string {
   230  	return t.Package.Name
   231  }
   232  
// testFunc describes one function that the generated test main program
// references as Package.Name (see testmainTmpl).
type testFunc struct {
	Package   string // imported package name (_test or _xtest)
	Name      string // function name
	Output    string // expected output, for examples
	Unordered bool   // for examples: output lines may appear in any order
}
   239  
// testFileSet is the package-level token.FileSet into which load parses test
// files. checkTestFunc resolves positions against it when reporting a wrong
// signature.
var testFileSet = token.NewFileSet()
   241  
   242  func (t *testFuncs) load(filename, pkg string, doImport, seen *bool) error {
   243  	f, err := parser.ParseFile(testFileSet, filename, nil, parser.ParseComments)
   244  	if err != nil {
   245  		return base.ExpandScanner(err)
   246  	}
   247  	for _, d := range f.Decls {
   248  		n, ok := d.(*ast.FuncDecl)
   249  		if !ok {
   250  			continue
   251  		}
   252  		if n.Recv != nil {
   253  			continue
   254  		}
   255  		name := n.Name.String()
   256  		switch {
   257  		case name == "TestMain":
   258  			if isTestFunc(n, "T") {
   259  				t.Tests = append(t.Tests, testFunc{pkg, name, "", false})
   260  				*doImport, *seen = true, true
   261  				continue
   262  			}
   263  			err := checkTestFunc(n, "M")
   264  			if err != nil {
   265  				return err
   266  			}
   267  			if t.TestMain != nil {
   268  				return errors.New("multiple definitions of TestMain")
   269  			}
   270  			t.TestMain = &testFunc{pkg, name, "", false}
   271  			*doImport, *seen = true, true
   272  		case isTest(name, "Test"):
   273  			err := checkTestFunc(n, "T")
   274  			if err != nil {
   275  				return err
   276  			}
   277  			t.Tests = append(t.Tests, testFunc{pkg, name, "", false})
   278  			*doImport, *seen = true, true
   279  		case isTest(name, "Benchmark"):
   280  			err := checkTestFunc(n, "B")
   281  			if err != nil {
   282  				return err
   283  			}
   284  			t.Benchmarks = append(t.Benchmarks, testFunc{pkg, name, "", false})
   285  			*doImport, *seen = true, true
   286  		}
   287  	}
   288  	ex := doc.Examples(f)
   289  	sort.Slice(ex, func(i, j int) bool { return ex[i].Order < ex[j].Order })
   290  	for _, e := range ex {
   291  		*doImport = true // import test file whether executed or not
   292  		if e.Output == "" && !e.EmptyOutput {
   293  			// Don't run examples with no output.
   294  			continue
   295  		}
   296  		t.Examples = append(t.Examples, testFunc{pkg, "Example" + e.Name, e.Output, e.Unordered})
   297  		*seen = true
   298  	}
   299  	return nil
   300  }
   301  
   302  func checkTestFunc(fn *ast.FuncDecl, arg string) error {
   303  	if !isTestFunc(fn, arg) {
   304  		name := fn.Name.String()
   305  		pos := testFileSet.Position(fn.Pos())
   306  		return fmt.Errorf("%s: wrong signature for %s, must be: func %s(%s *testing.%s)", pos, name, name, strings.ToLower(arg), arg)
   307  	}
   308  	return nil
   309  }
   310  
// testmainTmpl is the text/template for the generated _testmain.go source.
// formatTestmain executes it with a *testFuncs. The generated program:
//   - imports "os" only when the package defines no TestMain;
//   - imports the in-package and external test packages as _test/_xtest
//     (or blank when nothing in them is referenced by name), plus each
//     coverage package as _cover<i>;
//   - lists the tests, benchmarks and examples and sets testdeps.ImportPath;
//   - when Cover is set, registers every coverage file's counters and passes
//     them to testing.RegisterCover;
//   - calls the user's TestMain(m) if present, otherwise os.Exit(m.Run()).
//
// In the generated coverRegisterFile, counter i owns three position words:
// pos[3*i] is the start line, pos[3*i+1] the end line, and pos[3*i+2] packs
// the start column into its low 16 bits and the end column into its high
// 16 bits.
var testmainTmpl = template.Must(template.New("main").Parse(`
package main

import (
{{if not .TestMain}}
	"os"
{{end}}
	"testing"
	"testing/internal/testdeps"

{{if .ImportTest}}
	{{if .NeedTest}}_test{{else}}_{{end}} {{.Package.ImportPath | printf "%q"}}
{{end}}
{{if .ImportXtest}}
	{{if .NeedXtest}}_xtest{{else}}_{{end}} {{.Package.ImportPath | printf "%s_test" | printf "%q"}}
{{end}}
{{if .Cover}}
{{range $i, $p := .Cover.Vars}}
	_cover{{$i}} {{$p.Package.ImportPath | printf "%q"}}
{{end}}
{{end}}
)

var tests = []testing.InternalTest{
{{range .Tests}}
	{"{{.Name}}", {{.Package}}.{{.Name}}},
{{end}}
}

var benchmarks = []testing.InternalBenchmark{
{{range .Benchmarks}}
	{"{{.Name}}", {{.Package}}.{{.Name}}},
{{end}}
}

var examples = []testing.InternalExample{
{{range .Examples}}
	{"{{.Name}}", {{.Package}}.{{.Name}}, {{.Output | printf "%q"}}, {{.Unordered}}},
{{end}}
}

func init() {
	testdeps.ImportPath = {{.ImportPath | printf "%q"}}
}

{{if .Cover}}

// Only updated by init functions, so no need for atomicity.
var (
	coverCounters = make(map[string][]uint32)
	coverBlocks = make(map[string][]testing.CoverBlock)
)

func init() {
	{{range $i, $p := .Cover.Vars}}
	{{range $file, $cover := $p.Vars}}
	coverRegisterFile({{printf "%q" $cover.File}}, _cover{{$i}}.{{$cover.Var}}.Count[:], _cover{{$i}}.{{$cover.Var}}.Pos[:], _cover{{$i}}.{{$cover.Var}}.NumStmt[:])
	{{end}}
	{{end}}
}

func coverRegisterFile(fileName string, counter []uint32, pos []uint32, numStmts []uint16) {
	if 3*len(counter) != len(pos) || len(counter) != len(numStmts) {
		panic("coverage: mismatched sizes")
	}
	if coverCounters[fileName] != nil {
		// Already registered.
		return
	}
	coverCounters[fileName] = counter
	block := make([]testing.CoverBlock, len(counter))
	for i := range counter {
		block[i] = testing.CoverBlock{
			Line0: pos[3*i+0],
			Col0: uint16(pos[3*i+2]),
			Line1: pos[3*i+1],
			Col1: uint16(pos[3*i+2]>>16),
			Stmts: numStmts[i],
		}
	}
	coverBlocks[fileName] = block
}
{{end}}

func main() {
{{if .Cover}}
	testing.RegisterCover(testing.Cover{
		Mode: {{printf "%q" .Cover.Mode}},
		Counters: coverCounters,
		Blocks: coverBlocks,
		CoveredPackages: {{printf "%q" .Covered}},
	})
{{end}}
	m := testing.MainStart(testdeps.TestDeps{}, tests, benchmarks, examples)
{{with .TestMain}}
	{{.Package}}.{{.Name}}(m)
{{else}}
	os.Exit(m.Run())
{{end}}
}

`))