github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/internal/gcimporter/shallow_test.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gcimporter_test
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/parser"
    11  	"go/token"
    12  	"go/types"
    13  	"os"
    14  	"strings"
    15  	"testing"
    16  
    17  	"golang.org/x/sync/errgroup"
    18  	"golang.org/x/tools/go/packages"
    19  	"golang.org/x/tools/internal/gcimporter"
    20  	"golang.org/x/tools/internal/testenv"
    21  )
    22  
    23  // TestStd type-checks the standard library using shallow export data.
    24  func TestShallowStd(t *testing.T) {
    25  	if testing.Short() {
    26  		t.Skip("skipping in short mode; too slow (https://golang.org/issue/14113)")
    27  	}
    28  	testenv.NeedsTool(t, "go")
    29  
    30  	// Load import graph of the standard library.
    31  	// (No parsing or type-checking.)
    32  	cfg := &packages.Config{
    33  		Mode: packages.NeedImports |
    34  			packages.NeedName |
    35  			packages.NeedFiles | // see https://github.com/golang/go/issues/56632
    36  			packages.NeedCompiledGoFiles,
    37  		Tests: false,
    38  	}
    39  	pkgs, err := packages.Load(cfg, "std")
    40  	if err != nil {
    41  		t.Fatalf("load: %v", err)
    42  	}
    43  	if len(pkgs) < 200 {
    44  		t.Fatalf("too few packages: %d", len(pkgs))
    45  	}
    46  
    47  	// Type check the packages in parallel postorder.
    48  	done := make(map[*packages.Package]chan struct{})
    49  	packages.Visit(pkgs, nil, func(p *packages.Package) {
    50  		done[p] = make(chan struct{})
    51  	})
    52  	packages.Visit(pkgs, nil,
    53  		func(pkg *packages.Package) {
    54  			go func() {
    55  				// Wait for all deps to be done.
    56  				for _, imp := range pkg.Imports {
    57  					<-done[imp]
    58  				}
    59  				typecheck(t, pkg)
    60  				close(done[pkg])
    61  			}()
    62  		})
    63  	for _, root := range pkgs {
    64  		<-done[root]
    65  	}
    66  }
    67  
    68  // typecheck reads, parses, and type-checks a package.
    69  // It squirrels the export data in the the ppkg.ExportFile field.
    70  func typecheck(t *testing.T, ppkg *packages.Package) {
    71  	if ppkg.PkgPath == "unsafe" {
    72  		return // unsafe is special
    73  	}
    74  
    75  	// Create a local FileSet just for this package.
    76  	fset := token.NewFileSet()
    77  
    78  	// Parse files in parallel.
    79  	syntax := make([]*ast.File, len(ppkg.CompiledGoFiles))
    80  	var group errgroup.Group
    81  	for i, filename := range ppkg.CompiledGoFiles {
    82  		i, filename := i, filename
    83  		group.Go(func() error {
    84  			f, err := parser.ParseFile(fset, filename, nil, parser.SkipObjectResolution)
    85  			if err != nil {
    86  				return err // e.g. missing file
    87  			}
    88  			syntax[i] = f
    89  			return nil
    90  		})
    91  	}
    92  	if err := group.Wait(); err != nil {
    93  		t.Fatal(err)
    94  	}
    95  	// Inv: all files were successfully parsed.
    96  
    97  	// Build map of dependencies by package path.
    98  	// (We don't compute this mapping for the entire
    99  	// packages graph because it is not globally consistent.)
   100  	depsByPkgPath := make(map[string]*packages.Package)
   101  	{
   102  		var visit func(*packages.Package)
   103  		visit = func(pkg *packages.Package) {
   104  			if depsByPkgPath[pkg.PkgPath] == nil {
   105  				depsByPkgPath[pkg.PkgPath] = pkg
   106  				for path := range pkg.Imports {
   107  					visit(pkg.Imports[path])
   108  				}
   109  			}
   110  		}
   111  		visit(ppkg)
   112  	}
   113  
   114  	// importer state
   115  	var (
   116  		insert    func(p *types.Package, name string)
   117  		importMap = make(map[string]*types.Package) // keys are PackagePaths
   118  	)
   119  	loadFromExportData := func(imp *packages.Package) (*types.Package, error) {
   120  		data := []byte(imp.ExportFile)
   121  		return gcimporter.IImportShallow(fset, importMap, data, imp.PkgPath, insert)
   122  	}
   123  	insert = func(p *types.Package, name string) {
   124  		imp, ok := depsByPkgPath[p.Path()]
   125  		if !ok {
   126  			t.Fatalf("can't find dependency: %q", p.Path())
   127  		}
   128  		imported, err := loadFromExportData(imp)
   129  		if err != nil {
   130  			t.Fatalf("unmarshal: %v", err)
   131  		}
   132  		if imported != p {
   133  			t.Fatalf("internal error: inconsistent packages")
   134  		}
   135  		if obj := imported.Scope().Lookup(name); obj == nil {
   136  			t.Fatalf("lookup %q.%s failed", imported.Path(), name)
   137  		}
   138  	}
   139  
   140  	cfg := &types.Config{
   141  		Error: func(e error) {
   142  			t.Error(e)
   143  		},
   144  		Importer: importerFunc(func(importPath string) (*types.Package, error) {
   145  			if importPath == "unsafe" {
   146  				return types.Unsafe, nil // unsafe has no exportdata
   147  			}
   148  			imp, ok := ppkg.Imports[importPath]
   149  			if !ok {
   150  				return nil, fmt.Errorf("missing import %q", importPath)
   151  			}
   152  			return loadFromExportData(imp)
   153  		}),
   154  	}
   155  
   156  	// Type-check the syntax trees.
   157  	tpkg, _ := cfg.Check(ppkg.PkgPath, fset, syntax, nil)
   158  	postTypeCheck(t, fset, tpkg)
   159  
   160  	// Save the export data.
   161  	data, err := gcimporter.IExportShallow(fset, tpkg)
   162  	if err != nil {
   163  		t.Fatalf("internal error marshalling export data: %v", err)
   164  	}
   165  	ppkg.ExportFile = string(data)
   166  }
   167  
   168  // postTypeCheck is called after a package is type checked.
   169  // We use it to assert additional correctness properties,
   170  // for example, that the apparent location of "fmt.Println"
   171  // corresponds to its source location: in other words,
   172  // export+import preserves high-fidelity positions.
   173  func postTypeCheck(t *testing.T, fset *token.FileSet, pkg *types.Package) {
   174  	// We hard-code a few interesting test-case objects.
   175  	var obj types.Object
   176  	switch pkg.Path() {
   177  	case "fmt":
   178  		// func fmt.Println
   179  		obj = pkg.Scope().Lookup("Println")
   180  	case "net/http":
   181  		// method (*http.Request).ParseForm
   182  		req := pkg.Scope().Lookup("Request")
   183  		obj, _, _ = types.LookupFieldOrMethod(req.Type(), true, pkg, "ParseForm")
   184  	default:
   185  		return
   186  	}
   187  	if obj == nil {
   188  		t.Errorf("object not found in package %s", pkg.Path())
   189  		return
   190  	}
   191  
   192  	// Now check the source fidelity of the object's position.
   193  	posn := fset.Position(obj.Pos())
   194  	data, err := os.ReadFile(posn.Filename)
   195  	if err != nil {
   196  		t.Errorf("can't read source file declaring %v: %v", obj, err)
   197  		return
   198  	}
   199  
   200  	// Check line and column denote a source interval containing the object's identifier.
   201  	line := strings.Split(string(data), "\n")[posn.Line-1]
   202  
   203  	if id := line[posn.Column-1 : posn.Column-1+len(obj.Name())]; id != obj.Name() {
   204  		t.Errorf("%+v: expected declaration of %v at this line, column; got %q", posn, obj, line)
   205  	}
   206  
   207  	// Check offset.
   208  	if id := string(data[posn.Offset : posn.Offset+len(obj.Name())]); id != obj.Name() {
   209  		t.Errorf("%+v: expected declaration of %v at this offset; got %q", posn, obj, id)
   210  	}
   211  
   212  	// Check commutativity of Position() and start+len(name) operations:
   213  	// Position(startPos+len(name)) == Position(startPos) + len(name).
   214  	// This important property is a consequence of the way in which the
   215  	// decoder fills the gaps in the sparse line-start offset table.
   216  	endPosn := fset.Position(obj.Pos() + token.Pos(len(obj.Name())))
   217  	wantEndPosn := token.Position{
   218  		Filename: posn.Filename,
   219  		Offset:   posn.Offset + len(obj.Name()),
   220  		Line:     posn.Line,
   221  		Column:   posn.Column + len(obj.Name()),
   222  	}
   223  	if endPosn != wantEndPosn {
   224  		t.Errorf("%+v: expected end Position of %v here; was at %+v", wantEndPosn, obj, endPosn)
   225  	}
   226  }