github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/go/types/typeutil/callee_test.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package typeutil_test
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"go/parser"
    11  	"go/token"
    12  	"go/types"
    13  	"strings"
    14  	"testing"
    15  
    16  	"golang.org/x/tools/go/types/typeutil"
    17  	"golang.org/x/tools/internal/typeparams"
    18  )
    19  
    20  func TestStaticCallee(t *testing.T) {
    21  	testStaticCallee(t, []string{
    22  		`package q;
    23  		func Abs(x int) int {
    24  			if x < 0 {
    25  				return -x
    26  			}
    27  			return x
    28  		}`,
    29  		`package p
    30  		import "q"
    31  
    32  		type T int
    33  
    34  		func g(int)
    35  
    36  		var f = g
    37  
    38  		var x int
    39  
    40  		type s struct{ f func(int) }
    41  		func (s) g(int)
    42  
    43  		type I interface{ f(int) }
    44  
    45  		var a struct{b struct{c s}}
    46  
    47  		var n map[int]func()
    48  		var m []func()
    49  
    50  		func calls() {
    51  			g(x)           // a declared func
    52  			s{}.g(x)       // a concrete method
    53  			a.b.c.g(x)     // same
    54  			_ = q.Abs(x)   // declared func, qualified identifier
    55  		}
    56  
    57  		func noncalls() {
    58  			_ = T(x)    // a type
    59  			f(x)        // a var
    60  			panic(x)    // a built-in
    61  			s{}.f(x)    // a field
    62  			I(nil).f(x) // interface method
    63  			m[0]()      // a map
    64  			n[0]()      // a slice
    65  		}
    66  		`})
    67  }
    68  
    69  func TestTypeParamStaticCallee(t *testing.T) {
    70  	if !typeparams.Enabled {
    71  		t.Skip("type parameters are not enabled")
    72  	}
    73  	testStaticCallee(t, []string{
    74  		`package q
    75  		func R[T any]() {}
    76  		`,
    77  		`package p
    78  		import "q"
    79  		type I interface{
    80  			i()
    81  		}
    82  
    83  		type G[T any] func() T
    84  		func F[T any]() T { var x T; return x }
    85  
    86  		type M[T I] struct{ t T }
    87  		func (m M[T]) noncalls() {
    88  			m.t.i()   // method on a type parameter
    89  		}
    90  
    91  		func (m M[T]) calls() {
    92  			m.calls() // method on a generic type
    93  		}
    94  
    95  		type Chain[T I] struct{ r struct { s M[T] } }
    96  
    97  		type S int
    98  		func (S) i() {}
    99  
   100  		func Multi[TP0, TP1 any](){}
   101  
   102  		func calls() {
   103  			_ = F[int]()            // instantiated function
   104  			_ = (F[int])()          // go through parens
   105  			M[S]{}.calls()          // instantiated method
   106  			Chain[S]{}.r.s.calls()  // same as above
   107  			Multi[int,string]()     // multiple type parameters
   108  			q.R[int]()              // different package
   109  		}
   110  
   111  		func noncalls() {
   112  			_ = G[int](nil)()  // instantiated function
   113  		}
   114  		`})
   115  }
   116  
   117  // testStaticCallee parses and type checks each file content in contents
   118  // as a single file package in order. Within functions that have the suffix
   119  // "calls" it checks that the CallExprs within have a static callee.
   120  // If the function's name == "calls" all calls must have static callees,
   121  // and if the name != "calls", the calls must not have static callees.
   122  // Failures are reported on t.
   123  func testStaticCallee(t *testing.T, contents []string) {
   124  	fset := token.NewFileSet()
   125  	packages := make(map[string]*types.Package)
   126  	cfg := &types.Config{Importer: closure(packages)}
   127  	info := &types.Info{
   128  		Uses:       make(map[*ast.Ident]types.Object),
   129  		Selections: make(map[*ast.SelectorExpr]*types.Selection),
   130  	}
   131  	typeparams.InitInstanceInfo(info)
   132  
   133  	var files []*ast.File
   134  	for i, content := range contents {
   135  		// parse
   136  		f, err := parser.ParseFile(fset, fmt.Sprintf("%d.go", i), content, 0)
   137  		if err != nil {
   138  			t.Fatal(err)
   139  		}
   140  		files = append(files, f)
   141  
   142  		// type-check
   143  		pkg, err := cfg.Check(f.Name.Name, fset, []*ast.File{f}, info)
   144  		if err != nil {
   145  			t.Fatal(err)
   146  		}
   147  		packages[pkg.Path()] = pkg
   148  	}
   149  
   150  	// check
   151  	for _, f := range files {
   152  		for _, decl := range f.Decls {
   153  			if decl, ok := decl.(*ast.FuncDecl); ok && strings.HasSuffix(decl.Name.Name, "calls") {
   154  				wantCallee := decl.Name.Name == "calls" // false within func noncalls()
   155  				ast.Inspect(decl.Body, func(n ast.Node) bool {
   156  					if call, ok := n.(*ast.CallExpr); ok {
   157  						fn := typeutil.StaticCallee(info, call)
   158  						if fn == nil && wantCallee {
   159  							t.Errorf("%s: StaticCallee returned nil",
   160  								fset.Position(call.Lparen))
   161  						} else if fn != nil && !wantCallee {
   162  							t.Errorf("%s: StaticCallee returned %s, want nil",
   163  								fset.Position(call.Lparen), fn)
   164  						}
   165  					}
   166  					return true
   167  				})
   168  			}
   169  		}
   170  	}
   171  }