github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/internal/gcimporter/bexport_test.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gcimporter_test
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/build"
    11  	"go/constant"
    12  	"go/parser"
    13  	"go/token"
    14  	"go/types"
    15  	"path/filepath"
    16  	"reflect"
    17  	"runtime"
    18  	"sort"
    19  	"strings"
    20  	"testing"
    21  
    22  	"golang.org/x/tools/go/ast/inspector"
    23  	"golang.org/x/tools/go/buildutil"
    24  	"golang.org/x/tools/go/loader"
    25  	"golang.org/x/tools/internal/gcimporter"
    26  	"golang.org/x/tools/internal/typeparams"
    27  	"golang.org/x/tools/internal/typeparams/genericfeatures"
    28  )
    29  
    30  var isRace = false
    31  
    32  func TestBExportData_stdlib(t *testing.T) {
    33  	if runtime.Compiler == "gccgo" {
    34  		t.Skip("gccgo standard library is inaccessible")
    35  	}
    36  	if runtime.GOOS == "android" {
    37  		t.Skipf("incomplete std lib on %s", runtime.GOOS)
    38  	}
    39  	if isRace {
    40  		t.Skipf("stdlib tests take too long in race mode and flake on builders")
    41  	}
    42  	if testing.Short() {
    43  		t.Skip("skipping RAM hungry test in -short mode")
    44  	}
    45  
    46  	// Load, parse and type-check the program.
    47  	ctxt := build.Default // copy
    48  	ctxt.GOPATH = ""      // disable GOPATH
    49  	conf := loader.Config{
    50  		Build:       &ctxt,
    51  		AllowErrors: true,
    52  		TypeChecker: types.Config{
    53  			Error: func(err error) { t.Log(err) },
    54  		},
    55  	}
    56  	for _, path := range buildutil.AllPackages(conf.Build) {
    57  		conf.Import(path)
    58  	}
    59  
    60  	// Create a package containing type and value errors to ensure
    61  	// they are properly encoded/decoded.
    62  	f, err := conf.ParseFile("haserrors/haserrors.go", `package haserrors
    63  const UnknownValue = "" + 0
    64  type UnknownType undefined
    65  `)
    66  	if err != nil {
    67  		t.Fatal(err)
    68  	}
    69  	conf.CreateFromFiles("haserrors", f)
    70  
    71  	prog, err := conf.Load()
    72  	if err != nil {
    73  		t.Fatalf("Load failed: %v", err)
    74  	}
    75  
    76  	numPkgs := len(prog.AllPackages)
    77  	if want := minStdlibPackages; numPkgs < want {
    78  		t.Errorf("Loaded only %d packages, want at least %d", numPkgs, want)
    79  	}
    80  
    81  	checked := 0
    82  	for pkg, info := range prog.AllPackages {
    83  		if info.Files == nil {
    84  			continue // empty directory
    85  		}
    86  		// Binary export does not support generic code.
    87  		inspect := inspector.New(info.Files)
    88  		if genericfeatures.ForPackage(inspect, &info.Info) != 0 {
    89  			t.Logf("skipping package %q which uses generics", pkg.Path())
    90  			continue
    91  		}
    92  		checked++
    93  		exportdata, err := gcimporter.BExportData(conf.Fset, pkg)
    94  		if err != nil {
    95  			t.Fatal(err)
    96  		}
    97  
    98  		imports := make(map[string]*types.Package)
    99  		fset2 := token.NewFileSet()
   100  		n, pkg2, err := gcimporter.BImportData(fset2, imports, exportdata, pkg.Path())
   101  		if err != nil {
   102  			t.Errorf("BImportData(%s): %v", pkg.Path(), err)
   103  			continue
   104  		}
   105  		if n != len(exportdata) {
   106  			t.Errorf("BImportData(%s) decoded %d bytes, want %d",
   107  				pkg.Path(), n, len(exportdata))
   108  		}
   109  
   110  		// Compare the packages' corresponding members.
   111  		for _, name := range pkg.Scope().Names() {
   112  			if !token.IsExported(name) {
   113  				continue
   114  			}
   115  			obj1 := pkg.Scope().Lookup(name)
   116  			obj2 := pkg2.Scope().Lookup(name)
   117  			if obj2 == nil {
   118  				t.Errorf("%s.%s not found, want %s", pkg.Path(), name, obj1)
   119  				continue
   120  			}
   121  
   122  			fl1 := fileLine(conf.Fset, obj1)
   123  			fl2 := fileLine(fset2, obj2)
   124  			if fl1 != fl2 {
   125  				t.Errorf("%s.%s: got posn %s, want %s",
   126  					pkg.Path(), name, fl2, fl1)
   127  			}
   128  
   129  			if err := equalObj(obj1, obj2); err != nil {
   130  				t.Errorf("%s.%s: %s\ngot:  %s\nwant: %s",
   131  					pkg.Path(), name, err, obj2, obj1)
   132  			}
   133  		}
   134  	}
   135  	if want := minStdlibPackages; checked < want {
   136  		t.Errorf("Checked only %d packages, want at least %d", checked, want)
   137  	}
   138  }
   139  
   140  func fileLine(fset *token.FileSet, obj types.Object) string {
   141  	posn := fset.Position(obj.Pos())
   142  	filename := filepath.Clean(strings.ReplaceAll(posn.Filename, "$GOROOT", runtime.GOROOT()))
   143  	return fmt.Sprintf("%s:%d", filename, posn.Line)
   144  }
   145  
   146  // equalObj reports how x and y differ.  They are assumed to belong to
   147  // different universes so cannot be compared directly.
   148  func equalObj(x, y types.Object) error {
   149  	if reflect.TypeOf(x) != reflect.TypeOf(y) {
   150  		return fmt.Errorf("%T vs %T", x, y)
   151  	}
   152  	xt := x.Type()
   153  	yt := y.Type()
   154  	switch x.(type) {
   155  	case *types.Var, *types.Func:
   156  		// ok
   157  	case *types.Const:
   158  		xval := x.(*types.Const).Val()
   159  		yval := y.(*types.Const).Val()
   160  		// Use string comparison for floating-point values since rounding is permitted.
   161  		if constant.Compare(xval, token.NEQ, yval) &&
   162  			!(xval.Kind() == constant.Float && xval.String() == yval.String()) {
   163  			return fmt.Errorf("unequal constants %s vs %s", xval, yval)
   164  		}
   165  	case *types.TypeName:
   166  		xt = xt.Underlying()
   167  		yt = yt.Underlying()
   168  	default:
   169  		return fmt.Errorf("unexpected %T", x)
   170  	}
   171  	return equalType(xt, yt)
   172  }
   173  
   174  func equalType(x, y types.Type) error {
   175  	if reflect.TypeOf(x) != reflect.TypeOf(y) {
   176  		return fmt.Errorf("unequal kinds: %T vs %T", x, y)
   177  	}
   178  	switch x := x.(type) {
   179  	case *types.Interface:
   180  		y := y.(*types.Interface)
   181  		// TODO(gri): enable separate emission of Embedded interfaces
   182  		// and ExplicitMethods then use this logic.
   183  		// if x.NumEmbeddeds() != y.NumEmbeddeds() {
   184  		// 	return fmt.Errorf("unequal number of embedded interfaces: %d vs %d",
   185  		// 		x.NumEmbeddeds(), y.NumEmbeddeds())
   186  		// }
   187  		// for i := 0; i < x.NumEmbeddeds(); i++ {
   188  		// 	xi := x.Embedded(i)
   189  		// 	yi := y.Embedded(i)
   190  		// 	if xi.String() != yi.String() {
   191  		// 		return fmt.Errorf("mismatched %th embedded interface: %s vs %s",
   192  		// 			i, xi, yi)
   193  		// 	}
   194  		// }
   195  		// if x.NumExplicitMethods() != y.NumExplicitMethods() {
   196  		// 	return fmt.Errorf("unequal methods: %d vs %d",
   197  		// 		x.NumExplicitMethods(), y.NumExplicitMethods())
   198  		// }
   199  		// for i := 0; i < x.NumExplicitMethods(); i++ {
   200  		// 	xm := x.ExplicitMethod(i)
   201  		// 	ym := y.ExplicitMethod(i)
   202  		// 	if xm.Name() != ym.Name() {
   203  		// 		return fmt.Errorf("mismatched %th method: %s vs %s", i, xm, ym)
   204  		// 	}
   205  		// 	if err := equalType(xm.Type(), ym.Type()); err != nil {
   206  		// 		return fmt.Errorf("mismatched %s method: %s", xm.Name(), err)
   207  		// 	}
   208  		// }
   209  		if x.NumMethods() != y.NumMethods() {
   210  			return fmt.Errorf("unequal methods: %d vs %d",
   211  				x.NumMethods(), y.NumMethods())
   212  		}
   213  		for i := 0; i < x.NumMethods(); i++ {
   214  			xm := x.Method(i)
   215  			ym := y.Method(i)
   216  			if xm.Name() != ym.Name() {
   217  				return fmt.Errorf("mismatched %dth method: %s vs %s", i, xm, ym)
   218  			}
   219  			if err := equalType(xm.Type(), ym.Type()); err != nil {
   220  				return fmt.Errorf("mismatched %s method: %s", xm.Name(), err)
   221  			}
   222  		}
   223  		// Constraints are handled explicitly in the *TypeParam case below, so we
   224  		// don't yet need to consider embeddeds here.
   225  		// TODO(rfindley): consider the type set here.
   226  	case *types.Array:
   227  		y := y.(*types.Array)
   228  		if x.Len() != y.Len() {
   229  			return fmt.Errorf("unequal array lengths: %d vs %d", x.Len(), y.Len())
   230  		}
   231  		if err := equalType(x.Elem(), y.Elem()); err != nil {
   232  			return fmt.Errorf("array elements: %s", err)
   233  		}
   234  	case *types.Basic:
   235  		y := y.(*types.Basic)
   236  		if x.Kind() != y.Kind() {
   237  			return fmt.Errorf("unequal basic types: %s vs %s", x, y)
   238  		}
   239  	case *types.Chan:
   240  		y := y.(*types.Chan)
   241  		if x.Dir() != y.Dir() {
   242  			return fmt.Errorf("unequal channel directions: %d vs %d", x.Dir(), y.Dir())
   243  		}
   244  		if err := equalType(x.Elem(), y.Elem()); err != nil {
   245  			return fmt.Errorf("channel elements: %s", err)
   246  		}
   247  	case *types.Map:
   248  		y := y.(*types.Map)
   249  		if err := equalType(x.Key(), y.Key()); err != nil {
   250  			return fmt.Errorf("map keys: %s", err)
   251  		}
   252  		if err := equalType(x.Elem(), y.Elem()); err != nil {
   253  			return fmt.Errorf("map values: %s", err)
   254  		}
   255  	case *types.Named:
   256  		y := y.(*types.Named)
   257  		return cmpNamed(x, y)
   258  	case *types.Pointer:
   259  		y := y.(*types.Pointer)
   260  		if err := equalType(x.Elem(), y.Elem()); err != nil {
   261  			return fmt.Errorf("pointer elements: %s", err)
   262  		}
   263  	case *types.Signature:
   264  		y := y.(*types.Signature)
   265  		if err := equalType(x.Params(), y.Params()); err != nil {
   266  			return fmt.Errorf("parameters: %s", err)
   267  		}
   268  		if err := equalType(x.Results(), y.Results()); err != nil {
   269  			return fmt.Errorf("results: %s", err)
   270  		}
   271  		if x.Variadic() != y.Variadic() {
   272  			return fmt.Errorf("unequal variadicity: %t vs %t",
   273  				x.Variadic(), y.Variadic())
   274  		}
   275  		if (x.Recv() != nil) != (y.Recv() != nil) {
   276  			return fmt.Errorf("unequal receivers: %s vs %s", x.Recv(), y.Recv())
   277  		}
   278  		if x.Recv() != nil {
   279  			// TODO(adonovan): fix: this assertion fires for interface methods.
   280  			// The type of the receiver of an interface method is a named type
   281  			// if the Package was loaded from export data, or an unnamed (interface)
   282  			// type if the Package was produced by type-checking ASTs.
   283  			// if err := equalType(x.Recv().Type(), y.Recv().Type()); err != nil {
   284  			// 	return fmt.Errorf("receiver: %s", err)
   285  			// }
   286  		}
   287  		if err := equalTypeParams(typeparams.ForSignature(x), typeparams.ForSignature(y)); err != nil {
   288  			return fmt.Errorf("type params: %s", err)
   289  		}
   290  		if err := equalTypeParams(typeparams.RecvTypeParams(x), typeparams.RecvTypeParams(y)); err != nil {
   291  			return fmt.Errorf("recv type params: %s", err)
   292  		}
   293  	case *types.Slice:
   294  		y := y.(*types.Slice)
   295  		if err := equalType(x.Elem(), y.Elem()); err != nil {
   296  			return fmt.Errorf("slice elements: %s", err)
   297  		}
   298  	case *types.Struct:
   299  		y := y.(*types.Struct)
   300  		if x.NumFields() != y.NumFields() {
   301  			return fmt.Errorf("unequal struct fields: %d vs %d",
   302  				x.NumFields(), y.NumFields())
   303  		}
   304  		for i := 0; i < x.NumFields(); i++ {
   305  			xf := x.Field(i)
   306  			yf := y.Field(i)
   307  			if xf.Name() != yf.Name() {
   308  				return fmt.Errorf("mismatched fields: %s vs %s", xf, yf)
   309  			}
   310  			if err := equalType(xf.Type(), yf.Type()); err != nil {
   311  				return fmt.Errorf("struct field %s: %s", xf.Name(), err)
   312  			}
   313  			if x.Tag(i) != y.Tag(i) {
   314  				return fmt.Errorf("struct field %s has unequal tags: %q vs %q",
   315  					xf.Name(), x.Tag(i), y.Tag(i))
   316  			}
   317  		}
   318  	case *types.Tuple:
   319  		y := y.(*types.Tuple)
   320  		if x.Len() != y.Len() {
   321  			return fmt.Errorf("unequal tuple lengths: %d vs %d", x.Len(), y.Len())
   322  		}
   323  		for i := 0; i < x.Len(); i++ {
   324  			if err := equalType(x.At(i).Type(), y.At(i).Type()); err != nil {
   325  				return fmt.Errorf("tuple element %d: %s", i, err)
   326  			}
   327  		}
   328  	case *typeparams.TypeParam:
   329  		y := y.(*typeparams.TypeParam)
   330  		if x.String() != y.String() {
   331  			return fmt.Errorf("unequal named types: %s vs %s", x, y)
   332  		}
   333  		// For now, just compare constraints by type string to short-circuit
   334  		// cycles. We have to make interfaces explicit as export data currently
   335  		// doesn't support marking interfaces as implicit.
   336  		// TODO(rfindley): remove makeExplicit once export data contains an
   337  		// implicit bit.
   338  		xc := makeExplicit(x.Constraint()).String()
   339  		yc := makeExplicit(y.Constraint()).String()
   340  		if xc != yc {
   341  			return fmt.Errorf("unequal constraints: %s vs %s", xc, yc)
   342  		}
   343  
   344  	default:
   345  		panic(fmt.Sprintf("unexpected %T type", x))
   346  	}
   347  	return nil
   348  }
   349  
   350  // cmpNamed compares two named types x and y, returning an error for any
   351  // discrepancies. It does not compare their underlying types.
   352  func cmpNamed(x, y *types.Named) error {
   353  	xOrig := typeparams.NamedTypeOrigin(x)
   354  	yOrig := typeparams.NamedTypeOrigin(y)
   355  	if xOrig.String() != yOrig.String() {
   356  		return fmt.Errorf("unequal named types: %s vs %s", x, y)
   357  	}
   358  	if err := equalTypeParams(typeparams.ForNamed(x), typeparams.ForNamed(y)); err != nil {
   359  		return fmt.Errorf("type parameters: %s", err)
   360  	}
   361  	if err := equalTypeArgs(typeparams.NamedTypeArgs(x), typeparams.NamedTypeArgs(y)); err != nil {
   362  		return fmt.Errorf("type arguments: %s", err)
   363  	}
   364  	if x.NumMethods() != y.NumMethods() {
   365  		return fmt.Errorf("unequal methods: %d vs %d",
   366  			x.NumMethods(), y.NumMethods())
   367  	}
   368  	// Unfortunately method sorting is not canonical, so sort before comparing.
   369  	var xms, yms []*types.Func
   370  	for i := 0; i < x.NumMethods(); i++ {
   371  		xms = append(xms, x.Method(i))
   372  		yms = append(yms, y.Method(i))
   373  	}
   374  	for _, ms := range [][]*types.Func{xms, yms} {
   375  		sort.Slice(ms, func(i, j int) bool {
   376  			return ms[i].Name() < ms[j].Name()
   377  		})
   378  	}
   379  	for i, xm := range xms {
   380  		ym := yms[i]
   381  		if xm.Name() != ym.Name() {
   382  			return fmt.Errorf("mismatched %dth method: %s vs %s", i, xm, ym)
   383  		}
   384  		// Calling equalType here leads to infinite recursion, so just compare
   385  		// strings.
   386  		if xm.String() != ym.String() {
   387  			return fmt.Errorf("unequal methods: %s vs %s", x, y)
   388  		}
   389  	}
   390  	return nil
   391  }
   392  
   393  // makeExplicit returns an explicit version of typ, if typ is an implicit
   394  // interface. Otherwise it returns typ unmodified.
   395  func makeExplicit(typ types.Type) types.Type {
   396  	if iface, _ := typ.(*types.Interface); iface != nil && typeparams.IsImplicit(iface) {
   397  		var methods []*types.Func
   398  		for i := 0; i < iface.NumExplicitMethods(); i++ {
   399  			methods = append(methods, iface.Method(i))
   400  		}
   401  		var embeddeds []types.Type
   402  		for i := 0; i < iface.NumEmbeddeds(); i++ {
   403  			embeddeds = append(embeddeds, iface.EmbeddedType(i))
   404  		}
   405  		return types.NewInterfaceType(methods, embeddeds)
   406  	}
   407  	return typ
   408  }
   409  
   410  func equalTypeArgs(x, y *typeparams.TypeList) error {
   411  	if x.Len() != y.Len() {
   412  		return fmt.Errorf("unequal lengths: %d vs %d", x.Len(), y.Len())
   413  	}
   414  	for i := 0; i < x.Len(); i++ {
   415  		if err := equalType(x.At(i), y.At(i)); err != nil {
   416  			return fmt.Errorf("type %d: %s", i, err)
   417  		}
   418  	}
   419  	return nil
   420  }
   421  
   422  func equalTypeParams(x, y *typeparams.TypeParamList) error {
   423  	if x.Len() != y.Len() {
   424  		return fmt.Errorf("unequal lengths: %d vs %d", x.Len(), y.Len())
   425  	}
   426  	for i := 0; i < x.Len(); i++ {
   427  		if err := equalType(x.At(i), y.At(i)); err != nil {
   428  			return fmt.Errorf("type parameter %d: %s", i, err)
   429  		}
   430  	}
   431  	return nil
   432  }
   433  
   434  // TestVeryLongFile tests the position of an import object declared in
   435  // a very long input file.  Line numbers greater than maxlines are
   436  // reported as line 1, not garbage or token.NoPos.
   437  func TestVeryLongFile(t *testing.T) {
   438  	// parse and typecheck
   439  	longFile := "package foo" + strings.Repeat("\n", 123456) + "var X int"
   440  	fset1 := token.NewFileSet()
   441  	f, err := parser.ParseFile(fset1, "foo.go", longFile, 0)
   442  	if err != nil {
   443  		t.Fatal(err)
   444  	}
   445  	var conf types.Config
   446  	pkg, err := conf.Check("foo", fset1, []*ast.File{f}, nil)
   447  	if err != nil {
   448  		t.Fatal(err)
   449  	}
   450  
   451  	// export
   452  	exportdata, err := gcimporter.BExportData(fset1, pkg)
   453  	if err != nil {
   454  		t.Fatal(err)
   455  	}
   456  
   457  	// import
   458  	imports := make(map[string]*types.Package)
   459  	fset2 := token.NewFileSet()
   460  	_, pkg2, err := gcimporter.BImportData(fset2, imports, exportdata, pkg.Path())
   461  	if err != nil {
   462  		t.Fatalf("BImportData(%s): %v", pkg.Path(), err)
   463  	}
   464  
   465  	// compare
   466  	posn1 := fset1.Position(pkg.Scope().Lookup("X").Pos())
   467  	posn2 := fset2.Position(pkg2.Scope().Lookup("X").Pos())
   468  	if want := "foo.go:1:1"; posn2.String() != want {
   469  		t.Errorf("X position = %s, want %s (orig was %s)",
   470  			posn2, want, posn1)
   471  	}
   472  }
   473  
// src is the input for TestTypeAliases. It declares package p with three
// well-formed type aliases (T0, T1, T2) and one alias, Invalid, whose target
// foo is undeclared. Type-checking it is therefore expected to report an
// error while still yielding a package.
const src = `
package p

type (
	T0 = int32
	T1 = struct{}
	T2 = struct{ T1 }
	Invalid = foo // foo is undeclared
)
`
   484  
   485  func checkPkg(t *testing.T, pkg *types.Package, label string) {
   486  	T1 := types.NewStruct(nil, nil)
   487  	T2 := types.NewStruct([]*types.Var{types.NewField(0, pkg, "T1", T1, true)}, nil)
   488  
   489  	for _, test := range []struct {
   490  		name string
   491  		typ  types.Type
   492  	}{
   493  		{"T0", types.Typ[types.Int32]},
   494  		{"T1", T1},
   495  		{"T2", T2},
   496  		{"Invalid", types.Typ[types.Invalid]},
   497  	} {
   498  		obj := pkg.Scope().Lookup(test.name)
   499  		if obj == nil {
   500  			t.Errorf("%s: %s not found", label, test.name)
   501  			continue
   502  		}
   503  		tname, _ := obj.(*types.TypeName)
   504  		if tname == nil {
   505  			t.Errorf("%s: %v not a type name", label, obj)
   506  			continue
   507  		}
   508  		if !tname.IsAlias() {
   509  			t.Errorf("%s: %v: not marked as alias", label, tname)
   510  			continue
   511  		}
   512  		if got := tname.Type(); !types.Identical(got, test.typ) {
   513  			t.Errorf("%s: %v: got %v; want %v", label, tname, got, test.typ)
   514  		}
   515  	}
   516  }
   517  
   518  func TestTypeAliases(t *testing.T) {
   519  	// parse and typecheck
   520  	fset1 := token.NewFileSet()
   521  	f, err := parser.ParseFile(fset1, "p.go", src, 0)
   522  	if err != nil {
   523  		t.Fatal(err)
   524  	}
   525  	var conf types.Config
   526  	pkg1, err := conf.Check("p", fset1, []*ast.File{f}, nil)
   527  	if err == nil {
   528  		// foo in undeclared in src; we should see an error
   529  		t.Fatal("invalid source type-checked without error")
   530  	}
   531  	if pkg1 == nil {
   532  		// despite incorrect src we should see a (partially) type-checked package
   533  		t.Fatal("nil package returned")
   534  	}
   535  	checkPkg(t, pkg1, "export")
   536  
   537  	// export
   538  	exportdata, err := gcimporter.BExportData(fset1, pkg1)
   539  	if err != nil {
   540  		t.Fatal(err)
   541  	}
   542  
   543  	// import
   544  	imports := make(map[string]*types.Package)
   545  	fset2 := token.NewFileSet()
   546  	_, pkg2, err := gcimporter.BImportData(fset2, imports, exportdata, pkg1.Path())
   547  	if err != nil {
   548  		t.Fatalf("BImportData(%s): %v", pkg1.Path(), err)
   549  	}
   550  	checkPkg(t, pkg2, "import")
   551  }