github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/go/ir/source_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //lint:file-ignore SA1019 go/ssa's test suite is built around the deprecated go/loader. We'll leave fixing that to upstream.
     6  
     7  package ir_test
     8  
     9  // This file defines tests of source-level debugging utilities.
    10  
    11  import (
    12  	"fmt"
    13  	"go/ast"
    14  	"go/constant"
    15  	"go/parser"
    16  	"go/token"
    17  	"go/types"
    18  	"io/ioutil"
    19  	"os"
    20  	"path/filepath"
    21  	"runtime"
    22  	"strings"
    23  	"testing"
    24  
    25  	"github.com/amarpal/go-tools/go/ast/astutil"
    26  	"github.com/amarpal/go-tools/go/ir"
    27  	"github.com/amarpal/go-tools/go/ir/irutil"
    28  
    29  	"golang.org/x/tools/go/analysis/analysistest"
    30  	"golang.org/x/tools/go/expect"
    31  	"golang.org/x/tools/go/loader"
    32  )
    33  
    34  func TestObjValueLookup(t *testing.T) {
    35  	if runtime.GOOS == "android" {
    36  		t.Skipf("no testdata directory on %s", runtime.GOOS)
    37  	}
    38  
    39  	conf := loader.Config{ParserMode: parser.ParseComments}
    40  	src, err := ioutil.ReadFile(filepath.Join(analysistest.TestData(), "objlookup.go"))
    41  	if err != nil {
    42  		t.Fatal(err)
    43  	}
    44  	readFile := func(filename string) ([]byte, error) { return src, nil }
    45  	f, err := conf.ParseFile(filepath.Join(analysistest.TestData(), "objlookup.go"), src)
    46  	if err != nil {
    47  		t.Fatal(err)
    48  	}
    49  	conf.CreateFromFiles("main", f)
    50  
    51  	// Maps each var Ident (represented "name:linenum") to the
    52  	// kind of ir.Value we expect (represented "Constant", "&Alloc").
    53  	expectations := make(map[string]string)
    54  
    55  	// Each note of the form @ir(x, "BinOp") in testdata/objlookup.go
    56  	// specifies an expectation that an object named x declared on the
    57  	// same line is associated with an an ir.Value of type *ir.BinOp.
    58  	notes, err := expect.ExtractGo(conf.Fset, f)
    59  	if err != nil {
    60  		t.Fatal(err)
    61  	}
    62  	for _, n := range notes {
    63  		if n.Name != "ir" {
    64  			t.Errorf("%v: unexpected note type %q, want \"ir\"", conf.Fset.Position(n.Pos), n.Name)
    65  			continue
    66  		}
    67  		if len(n.Args) != 2 {
    68  			t.Errorf("%v: ir has %d args, want 2", conf.Fset.Position(n.Pos), len(n.Args))
    69  			continue
    70  		}
    71  		ident, ok := n.Args[0].(expect.Identifier)
    72  		if !ok {
    73  			t.Errorf("%v: got %v for arg 1, want identifier", conf.Fset.Position(n.Pos), n.Args[0])
    74  			continue
    75  		}
    76  		exp, ok := n.Args[1].(string)
    77  		if !ok {
    78  			t.Errorf("%v: got %v for arg 2, want string", conf.Fset.Position(n.Pos), n.Args[1])
    79  			continue
    80  		}
    81  		p, _, err := expect.MatchBefore(conf.Fset, readFile, n.Pos, string(ident))
    82  		if err != nil {
    83  			t.Error(err)
    84  			continue
    85  		}
    86  		pos := conf.Fset.Position(p)
    87  		key := fmt.Sprintf("%s:%d", ident, pos.Line)
    88  		expectations[key] = exp
    89  	}
    90  
    91  	iprog, err := conf.Load()
    92  	if err != nil {
    93  		t.Error(err)
    94  		return
    95  	}
    96  
    97  	prog := irutil.CreateProgram(iprog, 0 /*|ir.PrintFunctions*/)
    98  	mainInfo := iprog.Created[0]
    99  	mainPkg := prog.Package(mainInfo.Pkg)
   100  	mainPkg.SetDebugMode(true)
   101  	mainPkg.Build()
   102  
   103  	var varIds []*ast.Ident
   104  	var varObjs []*types.Var
   105  	for id, obj := range mainInfo.Defs {
   106  		// Check invariants for func and const objects.
   107  		switch obj := obj.(type) {
   108  		case *types.Func:
   109  			checkFuncValue(t, prog, obj)
   110  
   111  		case *types.Const:
   112  			checkConstValue(t, prog, obj)
   113  
   114  		case *types.Var:
   115  			if id.Name == "_" {
   116  				continue
   117  			}
   118  			varIds = append(varIds, id)
   119  			varObjs = append(varObjs, obj)
   120  		}
   121  	}
   122  	for id, obj := range mainInfo.Uses {
   123  		if obj, ok := obj.(*types.Var); ok {
   124  			varIds = append(varIds, id)
   125  			varObjs = append(varObjs, obj)
   126  		}
   127  	}
   128  
   129  	// Check invariants for var objects.
   130  	// The result varies based on the specific Ident.
   131  	for i, id := range varIds {
   132  		obj := varObjs[i]
   133  		ref, _ := astutil.PathEnclosingInterval(f, id.Pos(), id.Pos())
   134  		pos := prog.Fset.Position(id.Pos())
   135  		exp := expectations[fmt.Sprintf("%s:%d", id.Name, pos.Line)]
   136  		if exp == "" {
   137  			t.Errorf("%s: no expectation for var ident %s ", pos, id.Name)
   138  			continue
   139  		}
   140  		wantAddr := false
   141  		if exp[0] == '&' {
   142  			wantAddr = true
   143  			exp = exp[1:]
   144  		}
   145  		checkVarValue(t, prog, mainPkg, ref, obj, exp, wantAddr)
   146  	}
   147  }
   148  
   149  func checkFuncValue(t *testing.T, prog *ir.Program, obj *types.Func) {
   150  	fn := prog.FuncValue(obj)
   151  	// fmt.Printf("FuncValue(%s) = %s\n", obj, fn) // debugging
   152  	if fn == nil {
   153  		if obj.Name() != "interfaceMethod" {
   154  			t.Errorf("FuncValue(%s) == nil", obj)
   155  		}
   156  		return
   157  	}
   158  	if fnobj := fn.Object(); fnobj != obj {
   159  		t.Errorf("FuncValue(%s).Object() == %s; value was %s",
   160  			obj, fnobj, fn.Name())
   161  		return
   162  	}
   163  	if !types.Identical(fn.Type(), obj.Type()) {
   164  		t.Errorf("FuncValue(%s).Type() == %s", obj, fn.Type())
   165  		return
   166  	}
   167  }
   168  
   169  func checkConstValue(t *testing.T, prog *ir.Program, obj *types.Const) {
   170  	c := prog.ConstValue(obj)
   171  	// fmt.Printf("ConstValue(%s) = %s\n", obj, c) // debugging
   172  	if c == nil {
   173  		t.Errorf("ConstValue(%s) == nil", obj)
   174  		return
   175  	}
   176  	if !types.Identical(c.Type(), obj.Type()) {
   177  		t.Errorf("ConstValue(%s).Type() == %s", obj, c.Type())
   178  		return
   179  	}
   180  	if obj.Name() != "nil" {
   181  		if !constant.Compare(c.Value, token.EQL, obj.Val()) {
   182  			t.Errorf("ConstValue(%s).Value (%s) != %s",
   183  				obj, c.Value, obj.Val())
   184  			return
   185  		}
   186  	}
   187  }
   188  
   189  func checkVarValue(t *testing.T, prog *ir.Program, pkg *ir.Package, ref []ast.Node, obj *types.Var, expKind string, wantAddr bool) {
   190  	// The prefix of all assertions messages.
   191  	prefix := fmt.Sprintf("VarValue(%s @ L%d)",
   192  		obj, prog.Fset.Position(ref[0].Pos()).Line)
   193  
   194  	v, gotAddr := prog.VarValue(obj, pkg, ref)
   195  
   196  	// Kind is the concrete type of the ir Value.
   197  	gotKind := "nil"
   198  	if v != nil {
   199  		gotKind = fmt.Sprintf("%T", v)[len("*ir."):]
   200  	}
   201  
   202  	// fmt.Printf("%s = %v (kind %q; expect %q) wantAddr=%t gotAddr=%t\n", prefix, v, gotKind, expKind, wantAddr, gotAddr) // debugging
   203  
   204  	// Check the kinds match.
   205  	// "nil" indicates expected failure (e.g. optimized away).
   206  	if expKind != gotKind {
   207  		t.Errorf("%s concrete type == %s, want %s", prefix, gotKind, expKind)
   208  	}
   209  
   210  	// Check the types match.
   211  	// If wantAddr, the expected type is the object's address.
   212  	if v != nil {
   213  		expType := obj.Type()
   214  		if wantAddr {
   215  			expType = types.NewPointer(expType)
   216  			if !gotAddr {
   217  				t.Errorf("%s: got value, want address", prefix)
   218  			}
   219  		} else if gotAddr {
   220  			t.Errorf("%s: got address, want value", prefix)
   221  		}
   222  		if !types.Identical(v.Type(), expType) {
   223  			t.Errorf("%s.Type() == %s, want %s", prefix, v.Type(), expType)
   224  		}
   225  	}
   226  }
   227  
   228  // Ensure that, in debug mode, we can determine the ir.Value
   229  // corresponding to every ast.Expr.
   230  func TestValueForExpr(t *testing.T) {
   231  	testValueForExpr(t, filepath.Join(analysistest.TestData(), "valueforexpr.go"))
   232  }
   233  
   234  func testValueForExpr(t *testing.T, testfile string) {
   235  	if runtime.GOOS == "android" {
   236  		t.Skipf("no testdata dir on %s", runtime.GOOS)
   237  	}
   238  
   239  	conf := loader.Config{ParserMode: parser.ParseComments}
   240  	f, err := conf.ParseFile(testfile, nil)
   241  	if err != nil {
   242  		t.Error(err)
   243  		return
   244  	}
   245  	conf.CreateFromFiles("main", f)
   246  
   247  	iprog, err := conf.Load()
   248  	if err != nil {
   249  		t.Error(err)
   250  		return
   251  	}
   252  
   253  	mainInfo := iprog.Created[0]
   254  
   255  	prog := irutil.CreateProgram(iprog, 0)
   256  	mainPkg := prog.Package(mainInfo.Pkg)
   257  	mainPkg.SetDebugMode(true)
   258  	mainPkg.Build()
   259  
   260  	if false {
   261  		// debugging
   262  		for _, mem := range mainPkg.Members {
   263  			if fn, ok := mem.(*ir.Function); ok {
   264  				fn.WriteTo(os.Stderr)
   265  			}
   266  		}
   267  	}
   268  
   269  	var parenExprs []*ast.ParenExpr
   270  	ast.Inspect(f, func(n ast.Node) bool {
   271  		if n != nil {
   272  			if e, ok := n.(*ast.ParenExpr); ok {
   273  				parenExprs = append(parenExprs, e)
   274  			}
   275  		}
   276  		return true
   277  	})
   278  
   279  	notes, err := expect.ExtractGo(prog.Fset, f)
   280  	if err != nil {
   281  		t.Fatal(err)
   282  	}
   283  	for _, n := range notes {
   284  		want := n.Name
   285  		if want == "nil" {
   286  			want = "<nil>"
   287  		}
   288  		position := prog.Fset.Position(n.Pos)
   289  		var e ast.Expr
   290  		for _, paren := range parenExprs {
   291  			if paren.Pos() > n.Pos {
   292  				e = paren.X
   293  				break
   294  			}
   295  		}
   296  		if e == nil {
   297  			t.Errorf("%s: note doesn't precede ParenExpr: %q", position, want)
   298  			continue
   299  		}
   300  
   301  		path, _ := astutil.PathEnclosingInterval(f, n.Pos, n.Pos)
   302  		if path == nil {
   303  			t.Errorf("%s: can't find AST path from root to comment: %s", position, want)
   304  			continue
   305  		}
   306  
   307  		fn := ir.EnclosingFunction(mainPkg, path)
   308  		if fn == nil {
   309  			t.Errorf("%s: can't find enclosing function", position)
   310  			continue
   311  		}
   312  
   313  		v, gotAddr := fn.ValueForExpr(e) // (may be nil)
   314  		got := strings.TrimPrefix(fmt.Sprintf("%T", v), "*ir.")
   315  		if got != want {
   316  			t.Errorf("%s: got value %q, want %q", position, got, want)
   317  		}
   318  		if v != nil {
   319  			T := v.Type()
   320  			if gotAddr {
   321  				T = T.Underlying().(*types.Pointer).Elem() // deref
   322  			}
   323  			if !types.Identical(T, mainInfo.TypeOf(e)) {
   324  				t.Errorf("%s: got type %s, want %s", position, mainInfo.TypeOf(e), T)
   325  			}
   326  		}
   327  	}
   328  }
   329  
   330  // findInterval parses input and returns the [start, end) positions of
   331  // the first occurrence of substr in input.  f==nil indicates failure;
   332  // an error has already been reported in that case.
   333  func findInterval(t *testing.T, fset *token.FileSet, input, substr string) (f *ast.File, start, end token.Pos) {
   334  	f, err := parser.ParseFile(fset, "<input>", input, parser.SkipObjectResolution)
   335  	if err != nil {
   336  		t.Errorf("parse error: %s", err)
   337  		return
   338  	}
   339  
   340  	i := strings.Index(input, substr)
   341  	if i < 0 {
   342  		t.Errorf("%q is not a substring of input", substr)
   343  		f = nil
   344  		return
   345  	}
   346  
   347  	filePos := fset.File(f.Package)
   348  	return f, filePos.Pos(i), filePos.Pos(i + len(substr))
   349  }
   350  
   351  func TestEnclosingFunction(t *testing.T) {
   352  	tests := []struct {
   353  		input  string // the input file
   354  		substr string // first occurrence of this string denotes interval
   355  		fn     string // name of expected containing function
   356  	}{
   357  		// We use distinctive numbers as syntactic landmarks.
   358  
   359  		// Ordinary function:
   360  		{`package main
   361  		  func f() { println(1003) }`,
   362  			"100", "main.f"},
   363  		// Methods:
   364  		{`package main
   365                    type T int
   366  		  func (t T) f() { println(200) }`,
   367  			"200", "(main.T).f"},
   368  		// Function literal:
   369  		{`package main
   370  		  func f() { println(func() { print(300) }) }`,
   371  			"300", "main.f$1"},
   372  		// Doubly nested
   373  		{`package main
   374  		  func f() { println(func() { print(func() { print(350) })})}`,
   375  			"350", "main.f$1$1"},
   376  		// Implicit init for package-level var initializer.
   377  		{"package main; var a = 400", "400", "main.init"},
   378  		// No code for constants:
   379  		{"package main; const a = 500", "500", "(none)"},
   380  		// Explicit init()
   381  		{"package main; func init() { println(600) }", "600", "main.init#1"},
   382  		// Multiple explicit init functions:
   383  		{`package main
   384  		  func init() { println("foo") }
   385  		  func init() { println(800) }`,
   386  			"800", "main.init#2"},
   387  		// init() containing FuncLit.
   388  		{`package main
   389  		  func init() { println(func(){print(900)}) }`,
   390  			"900", "main.init#1$1"},
   391  	}
   392  	for _, test := range tests {
   393  		conf := loader.Config{Fset: token.NewFileSet()}
   394  		f, start, end := findInterval(t, conf.Fset, test.input, test.substr)
   395  		if f == nil {
   396  			continue
   397  		}
   398  		path, exact := astutil.PathEnclosingInterval(f, start, end)
   399  		if !exact {
   400  			t.Errorf("EnclosingFunction(%q) not exact", test.substr)
   401  			continue
   402  		}
   403  
   404  		conf.CreateFromFiles("main", f)
   405  
   406  		iprog, err := conf.Load()
   407  		if err != nil {
   408  			t.Error(err)
   409  			continue
   410  		}
   411  		prog := irutil.CreateProgram(iprog, 0)
   412  		pkg := prog.Package(iprog.Created[0].Pkg)
   413  		pkg.Build()
   414  
   415  		name := "(none)"
   416  		fn := ir.EnclosingFunction(pkg, path)
   417  		if fn != nil {
   418  			name = fn.String()
   419  		}
   420  
   421  		if name != test.fn {
   422  			t.Errorf("EnclosingFunction(%q in %q) got %s, want %s",
   423  				test.substr, test.input, name, test.fn)
   424  			continue
   425  		}
   426  
   427  		// While we're here: test HasEnclosingFunction.
   428  		if has := ir.HasEnclosingFunction(pkg, path); has != (fn != nil) {
   429  			t.Errorf("HasEnclosingFunction(%q in %q) got %v, want %v",
   430  				test.substr, test.input, has, fn != nil)
   431  			continue
   432  		}
   433  	}
   434  }