github.com/gopherjs/gopherjs@v1.19.0-beta1.0.20240506212314-27071a8796e4/compiler/internal/typeparams/collect_test.go (about)

     1  package typeparams
     2  
     3  import (
     4  	"go/ast"
     5  	"go/types"
     6  	"testing"
     7  
     8  	"github.com/google/go-cmp/cmp"
     9  	"github.com/gopherjs/gopherjs/internal/srctesting"
    10  	"golang.org/x/tools/go/ast/astutil"
    11  )
    12  
    13  func TestVisitor(t *testing.T) {
    14  	// This test verifies that instance collector is able to discover
    15  	// instantiations of generic types and functions in all possible contexts.
    16  	const src = `package testcase
    17  
    18  	type A struct{}
    19  	type B struct{}
    20  	type C struct{}
    21  	type D struct{}
    22  	type E struct{}
    23  	type F struct{}
    24  	type G struct{}
    25  
    26  	type typ[T any, V any] []T
    27  	func (t *typ[T, V]) method(x T) {}
    28  	func fun[U any, W any](x U, y W) {}
    29  
    30  	func entry1(arg typ[int8, A]) (result typ[int16, A]) {
    31  		fun(1, A{})
    32  		fun[int8, A](1, A{})
    33  		println(fun[int16, A])
    34  
    35  		t := typ[int, A]{}
    36  		t.method(0)
    37  		(*typ[int32, A]).method(nil, 0)
    38  		type x struct{ T []typ[int64, A] }
    39  
    40  		return
    41  	}
    42  
    43  	func entry2[T any](arg typ[int8, T]) (result typ[int16, T]) {
    44  		var zeroT T
    45  		fun(1, zeroT)
    46  		fun[int8, T](1, zeroT)
    47  		println(fun[int16, T])
    48  
    49  		t := typ[int, T]{}
    50  		t.method(0)
    51  		(*typ[int32, T]).method(nil, 0)
    52  		type x struct{ T []typ[int64, T] }
    53  
    54  		return
    55  	}
    56  
    57  	type entry3[T any] struct{
    58  		typ[int, T]
    59  		field1 struct { field2 typ[int8, T] }
    60  	}
    61  	func (e entry3[T]) method(arg typ[int8, T]) (result typ[int16, T]) {
    62  		var zeroT T
    63  		fun(1, zeroT)
    64  		fun[int8, T](1, zeroT)
    65  		println(fun[int16, T])
    66  
    67  		t := typ[int, T]{}
    68  		t.method(0)
    69  		(*typ[int32, T]).method(nil, 0)
    70  		type x struct{ T []typ[int64, T] }
    71  
    72  		return
    73  	}
    74  
    75  	type entry4 struct{
    76  		typ[int, E]
    77  		field1 struct { field2 typ[int8, E] }
    78  	}
    79  
    80  	type entry5 = typ[int, F]
    81  	`
    82  	f := srctesting.New(t)
    83  	file := f.Parse("test.go", src)
    84  	info, pkg := f.Check("pkg/test", file)
    85  
    86  	lookupObj := func(name string) types.Object {
    87  		return srctesting.LookupObj(pkg, name)
    88  	}
    89  	lookupType := func(name string) types.Type { return lookupObj(name).Type() }
    90  	lookupDecl := func(name string) ast.Node {
    91  		obj := lookupObj(name)
    92  		path, _ := astutil.PathEnclosingInterval(file, obj.Pos(), obj.Pos())
    93  		for _, n := range path {
    94  			switch n.(type) {
    95  			case *ast.FuncDecl, *ast.TypeSpec:
    96  				return n
    97  			}
    98  		}
    99  		t.Fatalf("Could not find AST node representing %v", obj)
   100  		return nil
   101  	}
   102  
   103  	// Generates a list of instances we expect to discover from functions and
   104  	// methods. Sentinel type is a type parameter we use uniquely within one
   105  	// context, which allows us to make sure that collection is not being tested
   106  	// against a wrong part of AST.
   107  	instancesInFunc := func(sentinel types.Type) []Instance {
   108  		return []Instance{
   109  			{
   110  				// Called with type arguments inferred.
   111  				Object: lookupObj("fun"),
   112  				TArgs:  []types.Type{types.Typ[types.Int], sentinel},
   113  			}, {
   114  				// Called with type arguments explicitly specified.
   115  				Object: lookupObj("fun"),
   116  				TArgs:  []types.Type{types.Typ[types.Int8], sentinel},
   117  			}, {
   118  				// Passed as an argument.
   119  				Object: lookupObj("fun"),
   120  				TArgs:  []types.Type{types.Typ[types.Int16], sentinel},
   121  			}, {
   122  				// Literal expression.
   123  				Object: lookupObj("typ"),
   124  				TArgs:  []types.Type{types.Typ[types.Int], sentinel},
   125  			}, {
   126  				Object: lookupObj("typ.method"),
   127  				TArgs:  []types.Type{types.Typ[types.Int], sentinel},
   128  			}, {
   129  				// Function argument.
   130  				Object: lookupObj("typ"),
   131  				TArgs:  []types.Type{types.Typ[types.Int8], sentinel},
   132  			}, {
   133  				Object: lookupObj("typ.method"),
   134  				TArgs:  []types.Type{types.Typ[types.Int8], sentinel},
   135  			}, {
   136  				// Function return type.
   137  				Object: lookupObj("typ"),
   138  				TArgs:  []types.Type{types.Typ[types.Int16], sentinel},
   139  			}, {
   140  				Object: lookupObj("typ.method"),
   141  				TArgs:  []types.Type{types.Typ[types.Int16], sentinel},
   142  			}, {
   143  				// Method expression.
   144  				Object: lookupObj("typ"),
   145  				TArgs:  []types.Type{types.Typ[types.Int32], sentinel},
   146  			}, {
   147  				Object: lookupObj("typ.method"),
   148  				TArgs:  []types.Type{types.Typ[types.Int32], sentinel},
   149  			}, {
   150  				// Type decl statement.
   151  				Object: lookupObj("typ"),
   152  				TArgs:  []types.Type{types.Typ[types.Int64], sentinel},
   153  			}, {
   154  				Object: lookupObj("typ.method"),
   155  				TArgs:  []types.Type{types.Typ[types.Int64], sentinel},
   156  			},
   157  		}
   158  	}
   159  
   160  	// Generates a list of instances we expect to discover from type declarations.
   161  	// Sentinel type is a type parameter we use uniquely within one context, which
   162  	// allows us to make sure that collection is not being tested against a wrong
   163  	// part of AST.
   164  	instancesInType := func(sentinel types.Type) []Instance {
   165  		return []Instance{
   166  			{
   167  				Object: lookupObj("typ"),
   168  				TArgs:  []types.Type{types.Typ[types.Int], sentinel},
   169  			}, {
   170  				Object: lookupObj("typ.method"),
   171  				TArgs:  []types.Type{types.Typ[types.Int], sentinel},
   172  			}, {
   173  				Object: lookupObj("typ"),
   174  				TArgs:  []types.Type{types.Typ[types.Int8], sentinel},
   175  			}, {
   176  				Object: lookupObj("typ.method"),
   177  				TArgs:  []types.Type{types.Typ[types.Int8], sentinel},
   178  			},
   179  		}
   180  	}
   181  
   182  	tests := []struct {
   183  		descr    string
   184  		resolver *Resolver
   185  		node     ast.Node
   186  		want     []Instance
   187  	}{
   188  		{
   189  			descr:    "non-generic function",
   190  			resolver: nil,
   191  			node:     lookupDecl("entry1"),
   192  			want:     instancesInFunc(lookupType("A")),
   193  		}, {
   194  			descr: "generic function",
   195  			resolver: NewResolver(
   196  				types.NewContext(),
   197  				ToSlice(lookupType("entry2").(*types.Signature).TypeParams()),
   198  				[]types.Type{lookupType("B")},
   199  			),
   200  			node: lookupDecl("entry2"),
   201  			want: instancesInFunc(lookupType("B")),
   202  		}, {
   203  			descr: "generic method",
   204  			resolver: NewResolver(
   205  				types.NewContext(),
   206  				ToSlice(lookupType("entry3.method").(*types.Signature).RecvTypeParams()),
   207  				[]types.Type{lookupType("C")},
   208  			),
   209  			node: lookupDecl("entry3.method"),
   210  			want: append(
   211  				instancesInFunc(lookupType("C")),
   212  				Instance{
   213  					Object: lookupObj("entry3"),
   214  					TArgs:  []types.Type{lookupType("C")},
   215  				},
   216  				Instance{
   217  					Object: lookupObj("entry3.method"),
   218  					TArgs:  []types.Type{lookupType("C")},
   219  				},
   220  			),
   221  		}, {
   222  			descr: "generic type declaration",
   223  			resolver: NewResolver(
   224  				types.NewContext(),
   225  				ToSlice(lookupType("entry3").(*types.Named).TypeParams()),
   226  				[]types.Type{lookupType("D")},
   227  			),
   228  			node: lookupDecl("entry3"),
   229  			want: instancesInType(lookupType("D")),
   230  		}, {
   231  			descr:    "non-generic type declaration",
   232  			resolver: nil,
   233  			node:     lookupDecl("entry4"),
   234  			want:     instancesInType(lookupType("E")),
   235  		}, {
   236  			descr:    "non-generic type alias",
   237  			resolver: nil,
   238  			node:     lookupDecl("entry5"),
   239  			want: []Instance{
   240  				{
   241  					Object: lookupObj("typ"),
   242  					TArgs:  []types.Type{types.Typ[types.Int], lookupType("F")},
   243  				},
   244  				{
   245  					Object: lookupObj("typ.method"),
   246  					TArgs:  []types.Type{types.Typ[types.Int], lookupType("F")},
   247  				},
   248  			},
   249  		},
   250  	}
   251  
   252  	for _, test := range tests {
   253  		t.Run(test.descr, func(t *testing.T) {
   254  			v := visitor{
   255  				instances: &PackageInstanceSets{},
   256  				resolver:  test.resolver,
   257  				info:      info,
   258  			}
   259  			ast.Walk(&v, test.node)
   260  			got := v.instances.Pkg(pkg).Values()
   261  			if diff := cmp.Diff(test.want, got, instanceOpts()); diff != "" {
   262  				t.Errorf("Discovered instance diff (-want,+got):\n%s", diff)
   263  			}
   264  		})
   265  	}
   266  }
   267  
   268  func TestSeedVisitor(t *testing.T) {
   269  	src := `package test
   270  	type typ[T any] int
   271  	func (t typ[T]) method(arg T) { var x typ[string]; _ = x }
   272  	func fun[T any](arg T) { var y typ[string]; _ = y }
   273  
   274  	const a typ[int] = 1
   275  	var b typ[int]
   276  	type c struct { field typ[int8] }
   277  	func (_ c) method() { var _ typ[int16] }
   278  	type d = typ[int32]
   279  	func e() { var _ typ[int64] }
   280  	`
   281  
   282  	f := srctesting.New(t)
   283  	file := f.Parse("test.go", src)
   284  	info, pkg := f.Check("pkg/test", file)
   285  
   286  	sv := seedVisitor{
   287  		visitor: visitor{
   288  			instances: &PackageInstanceSets{},
   289  			resolver:  nil,
   290  			info:      info,
   291  		},
   292  		objMap: map[types.Object]ast.Node{},
   293  	}
   294  	ast.Walk(&sv, file)
   295  
   296  	tInst := func(tArg types.Type) Instance {
   297  		return Instance{
   298  			Object: pkg.Scope().Lookup("typ"),
   299  			TArgs:  []types.Type{tArg},
   300  		}
   301  	}
   302  	mInst := func(tArg types.Type) Instance {
   303  		return Instance{
   304  			Object: srctesting.LookupObj(pkg, "typ.method"),
   305  			TArgs:  []types.Type{tArg},
   306  		}
   307  	}
   308  	want := []Instance{
   309  		tInst(types.Typ[types.Int]),
   310  		mInst(types.Typ[types.Int]),
   311  		tInst(types.Typ[types.Int8]),
   312  		mInst(types.Typ[types.Int8]),
   313  		tInst(types.Typ[types.Int16]),
   314  		mInst(types.Typ[types.Int16]),
   315  		tInst(types.Typ[types.Int32]),
   316  		mInst(types.Typ[types.Int32]),
   317  		tInst(types.Typ[types.Int64]),
   318  		mInst(types.Typ[types.Int64]),
   319  	}
   320  	got := sv.instances.Pkg(pkg).Values()
   321  	if diff := cmp.Diff(want, got, instanceOpts()); diff != "" {
   322  		t.Errorf("Instances from initialSeeder contain diff (-want,+got):\n%s", diff)
   323  	}
   324  }
   325  
   326  func TestCollector(t *testing.T) {
   327  	src := `package test
   328  	type typ[T any] int
   329  	func (t typ[T]) method(arg T) { var _ typ[int]; fun[int8](0) }
   330  	func fun[T any](arg T) {
   331  		var _ typ[int16]
   332  
   333  		type nested[U any] struct{}
   334  		_ = nested[T]{}
   335  	}
   336  
   337  	type ignore = int
   338  
   339  	func a() {
   340  		var _ typ[int32]
   341  		fun[int64](0)
   342  	}
   343  	`
   344  
   345  	f := srctesting.New(t)
   346  	file := f.Parse("test.go", src)
   347  	info, pkg := f.Check("pkg/test", file)
   348  
   349  	c := Collector{
   350  		TContext:  types.NewContext(),
   351  		Info:      info,
   352  		Instances: &PackageInstanceSets{},
   353  	}
   354  	c.Scan(pkg, file)
   355  
   356  	inst := func(name string, tArg types.Type) Instance {
   357  		return Instance{
   358  			Object: srctesting.LookupObj(pkg, name),
   359  			TArgs:  []types.Type{tArg},
   360  		}
   361  	}
   362  	want := []Instance{
   363  		inst("typ", types.Typ[types.Int]),
   364  		inst("typ.method", types.Typ[types.Int]),
   365  		inst("fun", types.Typ[types.Int8]),
   366  		inst("fun.nested", types.Typ[types.Int8]),
   367  		inst("typ", types.Typ[types.Int16]),
   368  		inst("typ.method", types.Typ[types.Int16]),
   369  		inst("typ", types.Typ[types.Int32]),
   370  		inst("typ.method", types.Typ[types.Int32]),
   371  		inst("fun", types.Typ[types.Int64]),
   372  		inst("fun.nested", types.Typ[types.Int64]),
   373  	}
   374  	got := c.Instances.Pkg(pkg).Values()
   375  	if diff := cmp.Diff(want, got, instanceOpts()); diff != "" {
   376  		t.Errorf("Instances from initialSeeder contain diff (-want,+got):\n%s", diff)
   377  	}
   378  }
   379  
   380  func TestCollector_CrossPackage(t *testing.T) {
   381  	f := srctesting.New(t)
   382  	const src = `package foo
   383  	type X[T any] struct {Value T}
   384  
   385  	func F[G any](g G) {
   386  		x := X[G]{}
   387  		println(x)
   388  	}
   389  
   390  	func DoFoo() {
   391  		F(int8(8))
   392  	}
   393  	`
   394  	fooFile := f.Parse("foo.go", src)
   395  	_, fooPkg := f.Check("pkg/foo", fooFile)
   396  
   397  	const src2 = `package bar
   398  	import "pkg/foo"
   399  	func FProxy[T any](t T) {
   400  		foo.F[T](t)
   401  	}
   402  	func DoBar() {
   403  		FProxy(int16(16))
   404  	}
   405  	`
   406  	barFile := f.Parse("bar.go", src2)
   407  	_, barPkg := f.Check("pkg/bar", barFile)
   408  
   409  	c := Collector{
   410  		TContext:  types.NewContext(),
   411  		Info:      f.Info,
   412  		Instances: &PackageInstanceSets{},
   413  	}
   414  	c.Scan(barPkg, barFile)
   415  	c.Scan(fooPkg, fooFile)
   416  
   417  	inst := func(pkg *types.Package, name string, tArg types.BasicKind) Instance {
   418  		return Instance{
   419  			Object: srctesting.LookupObj(pkg, name),
   420  			TArgs:  []types.Type{types.Typ[tArg]},
   421  		}
   422  	}
   423  
   424  	wantFooInstances := []Instance{
   425  		inst(fooPkg, "F", types.Int16), // Found in "pkg/foo".
   426  		inst(fooPkg, "F", types.Int8),
   427  		inst(fooPkg, "X", types.Int16), // Found due to F[int16] found in "pkg/foo".
   428  		inst(fooPkg, "X", types.Int8),
   429  	}
   430  	gotFooInstances := c.Instances.Pkg(fooPkg).Values()
   431  	if diff := cmp.Diff(wantFooInstances, gotFooInstances, instanceOpts()); diff != "" {
   432  		t.Errorf("Instances from pkg/foo contain diff (-want,+got):\n%s", diff)
   433  	}
   434  
   435  	wantBarInstances := []Instance{
   436  		inst(barPkg, "FProxy", types.Int16),
   437  	}
   438  	gotBarInstances := c.Instances.Pkg(barPkg).Values()
   439  	if diff := cmp.Diff(wantBarInstances, gotBarInstances, instanceOpts()); diff != "" {
   440  		t.Errorf("Instances from pkg/foo contain diff (-want,+got):\n%s", diff)
   441  	}
   442  }
   443  
   444  func TestResolver_SubstituteSelection(t *testing.T) {
   445  	tests := []struct {
   446  		descr   string
   447  		src     string
   448  		wantObj string
   449  		wantSig string
   450  	}{{
   451  		descr: "type parameter method",
   452  		src: `package test
   453  		type stringer interface{ String() string }
   454  
   455  		type x struct{}
   456  		func (_ x) String() string { return "" }
   457  
   458  		type g[T stringer] struct{}
   459  		func (_ g[T]) Method(t T) string {
   460  			return t.String()
   461  		}`,
   462  		wantObj: "func (pkg/test.x).String() string",
   463  		wantSig: "func() string",
   464  	}, {
   465  		descr: "generic receiver type with type parameter",
   466  		src: `package test
   467  			type x struct{}
   468  
   469  			type g[T any] struct{}
   470  			func (_ g[T]) Method(t T) string {
   471  				return g[T]{}.Method(t)
   472  			}`,
   473  		wantObj: "func (pkg/test.g[pkg/test.x]).Method(t pkg/test.x) string",
   474  		wantSig: "func(t pkg/test.x) string",
   475  	}, {
   476  		descr: "method expression",
   477  		src: `package test
   478  				type x struct{}
   479  
   480  				type g[T any] struct{}
   481  				func (recv g[T]) Method(t T) string {
   482  					return g[T].Method(recv, t)
   483  				}`,
   484  		wantObj: "func (pkg/test.g[pkg/test.x]).Method(t pkg/test.x) string",
   485  		wantSig: "func(recv pkg/test.g[pkg/test.x], t pkg/test.x) string",
   486  	}}
   487  
   488  	for _, test := range tests {
   489  		t.Run(test.descr, func(t *testing.T) {
   490  			f := srctesting.New(t)
   491  			file := f.Parse("test.go", test.src)
   492  			info, pkg := f.Check("pkg/test", file)
   493  
   494  			method := srctesting.LookupObj(pkg, "g.Method").(*types.Func).Type().(*types.Signature)
   495  			resolver := NewResolver(nil, ToSlice(method.RecvTypeParams()), []types.Type{srctesting.LookupObj(pkg, "x").Type()})
   496  
   497  			if l := len(info.Selections); l != 1 {
   498  				t.Fatalf("Got: %d selections. Want: 1", l)
   499  			}
   500  			for _, sel := range info.Selections {
   501  				gotObj := types.ObjectString(resolver.SubstituteSelection(sel).Obj(), nil)
   502  				if gotObj != test.wantObj {
   503  					t.Fatalf("Got: resolver.SubstituteSelection().Obj() = %q. Want: %q.", gotObj, test.wantObj)
   504  				}
   505  				gotSig := types.TypeString(resolver.SubstituteSelection(sel).Type(), nil)
   506  				if gotSig != test.wantSig {
   507  					t.Fatalf("Got: resolver.SubstituteSelection().Type() = %q. Want: %q.", gotSig, test.wantSig)
   508  				}
   509  			}
   510  		})
   511  	}
   512  }