github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/go/ssa/builder_generic_test.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa_test
     6  
     7  import (
     8  	"fmt"
     9  	"go/parser"
    10  	"go/token"
    11  	"reflect"
    12  	"sort"
    13  	"testing"
    14  
    15  	"golang.org/x/tools/go/expect"
    16  	"golang.org/x/tools/go/loader"
    17  	"golang.org/x/tools/go/ssa"
    18  	"golang.org/x/tools/internal/typeparams"
    19  )
    20  
    21  // TestGenericBodies tests that bodies of generic functions and methods containing
    22  // different constructs can be built in BuilderMode(0).
    23  //
    24  // Each test specifies the contents of package containing a single go file.
    25  // Each call print(arg0, arg1, ...) to the builtin print function
    26  // in ssa is correlated a comment at the end of the line of the form:
    27  //
    28  //	//@ types(a, b, c)
    29  //
    30  // where a, b and c are the types of the arguments to the print call
    31  // serialized using go/types.Type.String().
    32  // See x/tools/go/expect for details on the syntax.
    33  func TestGenericBodies(t *testing.T) {
    34  	if !typeparams.Enabled {
    35  		t.Skip("TestGenericBodies requires type parameters")
    36  	}
    37  	for _, test := range []struct {
    38  		pkg      string // name of the package.
    39  		contents string // contents of the Go package.
    40  	}{
    41  		{
    42  			pkg: "p",
    43  			contents: `
    44  			package p
    45  
    46  			func f(x int) {
    47  				var i interface{}
    48  				print(i, 0) //@ types("interface{}", int)
    49  				print()     //@ types()
    50  				print(x)    //@ types(int)
    51  			}
    52  			`,
    53  		},
    54  		{
    55  			pkg: "q",
    56  			contents: `
    57  			package q
    58  
    59  			func f[T any](x T) {
    60  				print(x) //@ types(T)
    61  			}
    62  			`,
    63  		},
    64  		{
    65  			pkg: "r",
    66  			contents: `
    67  			package r
    68  
    69  			func f[T ~int]() {
    70  				var x T
    71  				print(x) //@ types(T)
    72  			}
    73  			`,
    74  		},
    75  		{
    76  			pkg: "s",
    77  			contents: `
    78  			package s
    79  
    80  			func a[T ~[4]byte](x T) {
    81  				for k, v := range x {
    82  					print(x, k, v) //@ types(T, int, byte)
    83  				}
    84  			}
    85  			func b[T ~*[4]byte](x T) {
    86  				for k, v := range x {
    87  					print(x, k, v) //@ types(T, int, byte)
    88  				}
    89  			}
    90  			func c[T ~[]byte](x T) {
    91  				for k, v := range x {
    92  					print(x, k, v) //@ types(T, int, byte)
    93  				}
    94  			}
    95  			func d[T ~string](x T) {
    96  				for k, v := range x {
    97  					print(x, k, v) //@ types(T, int, rune)
    98  				}
    99  			}
   100  			func e[T ~map[int]string](x T) {
   101  				for k, v := range x {
   102  					print(x, k, v) //@ types(T, int, string)
   103  				}
   104  			}
   105  			func f[T ~chan string](x T) {
   106  				for v := range x {
   107  					print(x, v) //@ types(T, string)
   108  				}
   109  			}
   110  
   111  			func From() {
   112  				type A [4]byte
   113  				print(a[A]) //@ types("func(x s.A)")
   114  
   115  				type B *[4]byte
   116  				print(b[B]) //@ types("func(x s.B)")
   117  
   118  				type C []byte
   119  				print(c[C]) //@ types("func(x s.C)")
   120  
   121  				type D string
   122  				print(d[D]) //@ types("func(x s.D)")
   123  
   124  				type E map[int]string
   125  				print(e[E]) //@ types("func(x s.E)")
   126  
   127  				type F chan string
   128  				print(f[F]) //@ types("func(x s.F)")
   129  			}
   130  			`,
   131  		},
   132  		{
   133  			pkg: "t",
   134  			contents: `
   135  			package t
   136  
   137  			func f[S any, T ~chan S](x T) {
   138  				for v := range x {
   139  					print(x, v) //@ types(T, S)
   140  				}
   141  			}
   142  
   143  			func From() {
   144  				type F chan string
   145  				print(f[string, F]) //@ types("func(x t.F)")
   146  			}
   147  			`,
   148  		},
   149  		{
   150  			pkg: "u",
   151  			contents: `
   152  			package u
   153  
   154  			func fibonacci[T ~chan int](c, quit T) {
   155  				x, y := 0, 1
   156  				for {
   157  					select {
   158  					case c <- x:
   159  						x, y = y, x+y
   160  					case <-quit:
   161  						print(c, quit, x, y) //@ types(T, T, int, int)
   162  						return
   163  					}
   164  				}
   165  			}
   166  			func start[T ~chan int](c, quit T) {
   167  				go func() {
   168  					for i := 0; i < 10; i++ {
   169  						print(<-c) //@ types(int)
   170  					}
   171  					quit <- 0
   172  				}()
   173  			}
   174  			func From() {
   175  				type F chan int
   176  				c := make(F)
   177  				quit := make(F)
   178  				print(start[F], c, quit)     //@ types("func(c u.F, quit u.F)", "u.F", "u.F")
   179  				print(fibonacci[F], c, quit) //@ types("func(c u.F, quit u.F)", "u.F", "u.F")
   180  			}
   181  			`,
   182  		},
   183  		{
   184  			pkg: "v",
   185  			contents: `
   186  			package v
   187  
   188  			func f[T ~struct{ x int; y string }](i int) T {
   189  				u := []T{ T{0, "lorem"},  T{1, "ipsum"}}
   190  				return u[i]
   191  			}
   192  			func From() {
   193  				type S struct{ x int; y string }
   194  				print(f[S])     //@ types("func(i int) v.S")
   195  			}
   196  			`,
   197  		},
   198  		{
   199  			pkg: "w",
   200  			contents: `
   201  			package w
   202  
   203  			func f[T ~[4]int8](x T, l, h int) []int8 {
   204  				return x[l:h]
   205  			}
   206  			func g[T ~*[4]int16](x T, l, h int) []int16 {
   207  				return x[l:h]
   208  			}
   209  			func h[T ~[]int32](x T, l, h int) T {
   210  				return x[l:h]
   211  			}
   212  			func From() {
   213  				type F [4]int8
   214  				type G *[4]int16
   215  				type H []int32
   216  				print(f[F](F{}, 0, 0))  //@ types("[]int8")
   217  				print(g[G](nil, 0, 0)) //@ types("[]int16")
   218  				print(h[H](nil, 0, 0)) //@ types("w.H")
   219  			}
   220  			`,
   221  		},
   222  		{
   223  			pkg: "x",
   224  			contents: `
   225  			package x
   226  
   227  			func h[E any, T ~[]E](x T, l, h int) []E {
   228  				s := x[l:h]
   229  				print(s) //@ types("T")
   230  				return s
   231  			}
   232  			func From() {
   233  				type H []int32
   234  				print(h[int32, H](nil, 0, 0)) //@ types("[]int32")
   235  			}
   236  			`,
   237  		},
   238  		{
   239  			pkg: "y",
   240  			contents: `
   241  			package y
   242  
   243  			// Test "make" builtin with different forms on core types and
   244  			// when capacities are constants or variable.
   245  			func h[E any, T ~[]E](m, n int) {
   246  				print(make(T, 3))    //@ types(T)
   247  				print(make(T, 3, 5)) //@ types(T)
   248  				print(make(T, m))    //@ types(T)
   249  				print(make(T, m, n)) //@ types(T)
   250  			}
   251  			func i[K comparable, E any, T ~map[K]E](m int) {
   252  				print(make(T))    //@ types(T)
   253  				print(make(T, 5)) //@ types(T)
   254  				print(make(T, m)) //@ types(T)
   255  			}
   256  			func j[E any, T ~chan E](m int) {
   257  				print(make(T))    //@ types(T)
   258  				print(make(T, 6)) //@ types(T)
   259  				print(make(T, m)) //@ types(T)
   260  			}
   261  			func From() {
   262  				type H []int32
   263  				h[int32, H](3, 4)
   264  				type I map[int8]H
   265  				i[int8, H, I](5)
   266  				type J chan I
   267  				j[I, J](6)
   268  			}
   269  			`,
   270  		},
   271  		{
   272  			pkg: "z",
   273  			contents: `
   274  			package z
   275  
   276  			func h[T ~[4]int](x T) {
   277  				print(len(x), cap(x)) //@ types(int, int)
   278  			}
   279  			func i[T ~[4]byte | []int | ~chan uint8](x T) {
   280  				print(len(x), cap(x)) //@ types(int, int)
   281  			}
   282  			func j[T ~[4]int | any | map[string]int]() {
   283  				print(new(T)) //@ types("*T")
   284  			}
   285  			func k[T ~[4]int | any | map[string]int](x T) {
   286  				print(x) //@ types(T)
   287  				panic(x)
   288  			}
   289  			`,
   290  		},
   291  		{
   292  			pkg: "a",
   293  			contents: `
   294  			package a
   295  
   296  			func f[E any, F ~func() E](x F) {
   297  				print(x, x()) //@ types(F, E)
   298  			}
   299  			func From() {
   300  				type T func() int
   301  				f[int, T](func() int { return 0 })
   302  				f[int, func() int](func() int { return 1 })
   303  			}
   304  			`,
   305  		},
   306  		{
   307  			pkg: "b",
   308  			contents: `
   309  			package b
   310  
   311  			func f[E any, M ~map[string]E](m M) {
   312  				y, ok := m["lorem"]
   313  				print(m, y, ok) //@ types(M, E, bool)
   314  			}
   315  			func From() {
   316  				type O map[string][]int
   317  				f(O{"lorem": []int{0, 1, 2, 3}})
   318  			}
   319  			`,
   320  		},
   321  		{
   322  			pkg: "c",
   323  			contents: `
   324  			package c
   325  
   326  			func a[T interface{ []int64 | [5]int64 }](x T) int64 {
   327  				print(x, x[2], x[3]) //@ types(T, int64, int64)
   328  				x[2] = 5
   329  				return x[3]
   330  			}
   331  			func b[T interface{ []byte | string }](x T) byte {
   332  				print(x, x[3]) //@ types(T, byte)
   333  		        return x[3]
   334  			}
   335  			func c[T interface{ []byte }](x T) byte {
   336  				print(x, x[2], x[3]) //@ types(T, byte, byte)
   337  				x[2] = 'b'
   338  				return x[3]
   339  			}
   340  			func d[T interface{ map[int]int64 }](x T) int64 {
   341  				print(x, x[2], x[3]) //@ types(T, int64, int64)
   342  				x[2] = 43
   343          		return x[3]
   344  			}
   345  			func e[T ~string](t T) {
   346  				print(t, t[0]) //@ types(T, uint8)
   347  			}
   348  			func f[T ~string|[]byte](t T) {
   349  				print(t, t[0]) //@ types(T, uint8)
   350  			}
   351  			func g[T []byte](t T) {
   352  				print(t, t[0]) //@ types(T, byte)
   353  			}
   354  			func h[T ~[4]int|[]int](t T) {
   355  				print(t, t[0]) //@ types(T, int)
   356  			}
   357  			func i[T ~[4]int|*[4]int|[]int](t T) {
   358  				print(t, t[0]) //@ types(T, int)
   359  			}
   360  			func j[T ~[4]int|*[4]int|[]int](t T) {
   361  				print(t, &t[0]) //@ types(T, "*int")
   362  			}
   363  			`,
   364  		},
   365  		{
   366  			pkg: "d",
   367  			contents: `
   368  			package d
   369  
   370  			type MyInt int
   371  			type Other int
   372  			type MyInterface interface{ foo() }
   373  
   374  			// ChangeType tests
   375  			func ct0(x int) { v := MyInt(x);  print(x, v) /*@ types(int, "d.MyInt")*/ }
   376  			func ct1[T MyInt | Other, S int ](x S) { v := T(x);  print(x, v) /*@ types(S, T)*/ }
   377  			func ct2[T int, S MyInt | int ](x S) { v := T(x); print(x, v) /*@ types(S, T)*/ }
   378  			func ct3[T MyInt | Other, S MyInt | int ](x S) { v := T(x) ; print(x, v) /*@ types(S, T)*/ }
   379  
   380  			// Convert tests
   381  			func co0[T int | int8](x MyInt) { v := T(x); print(x, v) /*@ types("d.MyInt", T)*/}
   382  			func co1[T int | int8](x T) { v := MyInt(x); print(x, v) /*@ types(T, "d.MyInt")*/ }
   383  			func co2[S, T int | int8](x T) { v := S(x); print(x, v) /*@ types(T, S)*/ }
   384  
   385  			// MakeInterface tests
   386  			func mi0[T MyInterface](x T) { v := MyInterface(x); print(x, v) /*@ types(T, "d.MyInterface")*/ }
   387  
   388  			// NewConst tests
   389  			func nc0[T any]() { v := (*T)(nil); print(v) /*@ types("*T")*/}
   390  
   391  			// SliceToArrayPointer
   392  			func sl0[T *[4]int | *[2]int](x []int) { v := T(x); print(x, v) /*@ types("[]int", T)*/ }
   393  			func sl1[T *[4]int | *[2]int, S []int](x S) { v := T(x); print(x, v) /*@ types(S, T)*/ }
   394  			`,
   395  		},
   396  		{
   397  			pkg: "e",
   398  			contents: `
   399  			package e
   400  
   401  			func c[T interface{ foo() string }](x T) {
   402  				print(x, x.foo, x.foo())  /*@ types(T, "func() string", string)*/
   403  			}
   404  			`,
   405  		},
   406  		{
   407  			pkg: "f",
   408  			contents: `package f
   409  
   410  			func eq[T comparable](t T, i interface{}) bool {
   411  				return t == i
   412  			}
   413  			`,
   414  		},
   415  		{
   416  			pkg: "g",
   417  			contents: `package g
   418  			type S struct{ f int }
   419  			func c[P *S]() []P { return []P{{f: 1}} }
   420  			`,
   421  		},
   422  		{
   423  			pkg: "h",
   424  			contents: `package h
   425  			func sign[bytes []byte | string](s bytes) (bool, bool) {
   426  				neg := false
   427  				if len(s) > 0 && (s[0] == '-' || s[0] == '+') {
   428  					neg = s[0] == '-'
   429  					s = s[1:]
   430  				}
   431  				return !neg, len(s) > 0
   432  			}`,
   433  		},
   434  		{
   435  			pkg: "i",
   436  			contents: `package i
   437  			func digits[bytes []byte | string](s bytes) bool {
   438  				for _, c := range []byte(s) {
   439  					if c < '0' || '9' < c {
   440  						return false
   441  					}
   442  				}
   443  				return true
   444  			}`,
   445  		},
   446  		{
   447  			pkg: "j",
   448  			contents: `
   449  			package j
   450  
   451  			type E interface{}
   452  
   453  			func Foo[T E, PT interface{ *T }]() T {
   454  				pt := PT(new(T))
   455  				x := *pt
   456  				print(x)  /*@ types(T)*/
   457  				return x
   458  			}
   459  			`,
   460  		},
   461  	} {
   462  		test := test
   463  		t.Run(test.pkg, func(t *testing.T) {
   464  			// Parse
   465  			conf := loader.Config{ParserMode: parser.ParseComments}
   466  			fname := test.pkg + ".go"
   467  			f, err := conf.ParseFile(fname, test.contents)
   468  			if err != nil {
   469  				t.Fatalf("parse: %v", err)
   470  			}
   471  			conf.CreateFromFiles(test.pkg, f)
   472  
   473  			// Load
   474  			lprog, err := conf.Load()
   475  			if err != nil {
   476  				t.Fatalf("Load: %v", err)
   477  			}
   478  
   479  			// Create and build SSA
   480  			prog := ssa.NewProgram(lprog.Fset, ssa.SanityCheckFunctions)
   481  			for _, info := range lprog.AllPackages {
   482  				if info.TransitivelyErrorFree {
   483  					prog.CreatePackage(info.Pkg, info.Files, &info.Info, info.Importable)
   484  				}
   485  			}
   486  			p := prog.Package(lprog.Package(test.pkg).Pkg)
   487  			p.Build()
   488  
   489  			// Collect calls to the builtin print function.
   490  			probes := make(map[*ssa.CallCommon]bool)
   491  			for _, mem := range p.Members {
   492  				if fn, ok := mem.(*ssa.Function); ok {
   493  					for _, bb := range fn.Blocks {
   494  						for _, i := range bb.Instrs {
   495  							if i, ok := i.(ssa.CallInstruction); ok {
   496  								call := i.Common()
   497  								if b, ok := call.Value.(*ssa.Builtin); ok && b.Name() == "print" {
   498  									probes[i.Common()] = true
   499  								}
   500  							}
   501  						}
   502  					}
   503  				}
   504  			}
   505  
   506  			// Collect all notes in f, i.e. comments starting with "//@ types".
   507  			notes, err := expect.ExtractGo(prog.Fset, f)
   508  			if err != nil {
   509  				t.Errorf("expect.ExtractGo: %v", err)
   510  			}
   511  
   512  			// Matches each probe with a note that has the same line.
   513  			sameLine := func(x, y token.Pos) bool {
   514  				xp := prog.Fset.Position(x)
   515  				yp := prog.Fset.Position(y)
   516  				return xp.Filename == yp.Filename && xp.Line == yp.Line
   517  			}
   518  			expectations := make(map[*ssa.CallCommon]*expect.Note)
   519  			for call := range probes {
   520  				var match *expect.Note
   521  				for _, note := range notes {
   522  					if note.Name == "types" && sameLine(call.Pos(), note.Pos) {
   523  						match = note // first match is good enough.
   524  						break
   525  					}
   526  				}
   527  				if match != nil {
   528  					expectations[call] = match
   529  				} else {
   530  					t.Errorf("Unmatched probe: %v", call)
   531  				}
   532  			}
   533  
   534  			// Check each expectation.
   535  			for call, note := range expectations {
   536  				var args []string
   537  				for _, a := range call.Args {
   538  					args = append(args, a.Type().String())
   539  				}
   540  				if got, want := fmt.Sprint(args), fmt.Sprint(note.Args); got != want {
   541  					t.Errorf("Arguments to print() were expected to be %q. got %q", want, got)
   542  				}
   543  			}
   544  		})
   545  	}
   546  }
   547  
   548  // TestInstructionString tests serializing instructions via Instruction.String().
   549  func TestInstructionString(t *testing.T) {
   550  	if !typeparams.Enabled {
   551  		t.Skip("TestInstructionString requires type parameters")
   552  	}
   553  	// Tests (ssa.Instruction).String(). Instructions are from a single go file.
   554  	// The Instructions tested are those that match a comment of the form:
   555  	//
   556  	//	//@ instrs(f, kind, strs...)
   557  	//
   558  	// where f is the name of the function, kind is the type of the instructions matched
   559  	// within the function, and tests that the String() value for all of the instructions
   560  	// matched of String() is strs (in some order).
   561  	// See x/tools/go/expect for details on the syntax.
   562  
   563  	const contents = `
   564  	package p
   565  
   566  	//@ instrs("f", "*ssa.TypeAssert")
   567  	//@ instrs("f", "*ssa.Call", "print(nil:interface{}, 0:int)")
   568  	func f(x int) { // non-generic smoke test.
   569  		var i interface{}
   570  		print(i, 0)
   571  	}
   572  
   573  	//@ instrs("h", "*ssa.Alloc", "local T (u)")
   574  	//@ instrs("h", "*ssa.FieldAddr", "&t0.x [#0]")
   575  	func h[T ~struct{ x string }]() T {
   576  		u := T{"lorem"}
   577  		return u
   578  	}
   579  
   580  	//@ instrs("c", "*ssa.TypeAssert", "typeassert t0.(interface{})")
   581  	//@ instrs("c", "*ssa.Call", "invoke x.foo()")
   582  	func c[T interface{ foo() string }](x T) {
   583  		_ = x.foo
   584  		_ = x.foo()
   585  	}
   586  
   587  	//@ instrs("d", "*ssa.TypeAssert", "typeassert t0.(interface{})")
   588  	//@ instrs("d", "*ssa.Call", "invoke x.foo()")
   589  	func d[T interface{ foo() string; comparable }](x T) {
   590  		_ = x.foo
   591  		_ = x.foo()
   592  	}
   593  	`
   594  
   595  	// Parse
   596  	conf := loader.Config{ParserMode: parser.ParseComments}
   597  	const fname = "p.go"
   598  	f, err := conf.ParseFile(fname, contents)
   599  	if err != nil {
   600  		t.Fatalf("parse: %v", err)
   601  	}
   602  	conf.CreateFromFiles("p", f)
   603  
   604  	// Load
   605  	lprog, err := conf.Load()
   606  	if err != nil {
   607  		t.Fatalf("Load: %v", err)
   608  	}
   609  
   610  	// Create and build SSA
   611  	prog := ssa.NewProgram(lprog.Fset, ssa.SanityCheckFunctions)
   612  	for _, info := range lprog.AllPackages {
   613  		if info.TransitivelyErrorFree {
   614  			prog.CreatePackage(info.Pkg, info.Files, &info.Info, info.Importable)
   615  		}
   616  	}
   617  	p := prog.Package(lprog.Package("p").Pkg)
   618  	p.Build()
   619  
   620  	// Collect all notes in f, i.e. comments starting with "//@ instr".
   621  	notes, err := expect.ExtractGo(prog.Fset, f)
   622  	if err != nil {
   623  		t.Errorf("expect.ExtractGo: %v", err)
   624  	}
   625  
   626  	// Expectation is a {function, type string} -> {want, matches}
   627  	// where matches is all Instructions.String() that match the key.
   628  	// Each expecation is that some permutation of matches is wants.
   629  	type expKey struct {
   630  		function string
   631  		kind     string
   632  	}
   633  	type expValue struct {
   634  		wants   []string
   635  		matches []string
   636  	}
   637  	expectations := make(map[expKey]*expValue)
   638  	for _, note := range notes {
   639  		if note.Name == "instrs" {
   640  			if len(note.Args) < 2 {
   641  				t.Error("Had @instrs annotation without at least 2 arguments")
   642  				continue
   643  			}
   644  			fn, kind := fmt.Sprint(note.Args[0]), fmt.Sprint(note.Args[1])
   645  			var wants []string
   646  			for _, arg := range note.Args[2:] {
   647  				wants = append(wants, fmt.Sprint(arg))
   648  			}
   649  			expectations[expKey{fn, kind}] = &expValue{wants, nil}
   650  		}
   651  	}
   652  
   653  	// Collect all Instructions that match the expectations.
   654  	for _, mem := range p.Members {
   655  		if fn, ok := mem.(*ssa.Function); ok {
   656  			for _, bb := range fn.Blocks {
   657  				for _, i := range bb.Instrs {
   658  					kind := fmt.Sprintf("%T", i)
   659  					if e := expectations[expKey{fn.Name(), kind}]; e != nil {
   660  						e.matches = append(e.matches, i.String())
   661  					}
   662  				}
   663  			}
   664  		}
   665  	}
   666  
   667  	// Check each expectation.
   668  	for key, value := range expectations {
   669  		if _, ok := p.Members[key.function]; !ok {
   670  			t.Errorf("Expectation on %s does not match a member in %s", key.function, p.Pkg.Name())
   671  		}
   672  		got, want := value.matches, value.wants
   673  		sort.Strings(got)
   674  		sort.Strings(want)
   675  		if !reflect.DeepEqual(want, got) {
   676  			t.Errorf("Within %s wanted instructions of kind %s: %q. got %q", key.function, key.kind, want, got)
   677  		}
   678  	}
   679  }