github.com/golangci/go-tools@v0.0.0-20190318060251-af6baa5dc196/ssa/source_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa_test
     6  
     7  // This file defines tests of source-level debugging utilities.
     8  
     9  import (
    10  	"fmt"
    11  	"go/ast"
    12  	exact "go/constant"
    13  	"go/parser"
    14  	"go/token"
    15  	"go/types"
    16  	"os"
    17  	"regexp"
    18  	"runtime"
    19  	"strings"
    20  	"testing"
    21  
    22  	"golang.org/x/tools/go/ast/astutil"
    23  	"golang.org/x/tools/go/loader"
    24  	"github.com/golangci/go-tools/ssa"
    25  	"github.com/golangci/go-tools/ssa/ssautil"
    26  )
    27  
    28  func TestObjValueLookup(t *testing.T) {
    29  	if runtime.GOOS == "android" {
    30  		t.Skipf("no testdata directory on %s", runtime.GOOS)
    31  	}
    32  
    33  	conf := loader.Config{ParserMode: parser.ParseComments}
    34  	f, err := conf.ParseFile("testdata/objlookup.go", nil)
    35  	if err != nil {
    36  		t.Error(err)
    37  		return
    38  	}
    39  	conf.CreateFromFiles("main", f)
    40  
    41  	// Maps each var Ident (represented "name:linenum") to the
    42  	// kind of ssa.Value we expect (represented "Constant", "&Alloc").
    43  	expectations := make(map[string]string)
    44  
    45  	// Find all annotations of form x::BinOp, &y::Alloc, etc.
    46  	re := regexp.MustCompile(`(\b|&)?(\w*)::(\w*)\b`)
    47  	for _, c := range f.Comments {
    48  		text := c.Text()
    49  		pos := conf.Fset.Position(c.Pos())
    50  		for _, m := range re.FindAllStringSubmatch(text, -1) {
    51  			key := fmt.Sprintf("%s:%d", m[2], pos.Line)
    52  			value := m[1] + m[3]
    53  			expectations[key] = value
    54  		}
    55  	}
    56  
    57  	iprog, err := conf.Load()
    58  	if err != nil {
    59  		t.Error(err)
    60  		return
    61  	}
    62  
    63  	prog := ssautil.CreateProgram(iprog, 0 /*|ssa.PrintFunctions*/)
    64  	mainInfo := iprog.Created[0]
    65  	mainPkg := prog.Package(mainInfo.Pkg)
    66  	mainPkg.SetDebugMode(true)
    67  	mainPkg.Build()
    68  
    69  	var varIds []*ast.Ident
    70  	var varObjs []*types.Var
    71  	for id, obj := range mainInfo.Defs {
    72  		// Check invariants for func and const objects.
    73  		switch obj := obj.(type) {
    74  		case *types.Func:
    75  			checkFuncValue(t, prog, obj)
    76  
    77  		case *types.Const:
    78  			checkConstValue(t, prog, obj)
    79  
    80  		case *types.Var:
    81  			if id.Name == "_" {
    82  				continue
    83  			}
    84  			varIds = append(varIds, id)
    85  			varObjs = append(varObjs, obj)
    86  		}
    87  	}
    88  	for id, obj := range mainInfo.Uses {
    89  		if obj, ok := obj.(*types.Var); ok {
    90  			varIds = append(varIds, id)
    91  			varObjs = append(varObjs, obj)
    92  		}
    93  	}
    94  
    95  	// Check invariants for var objects.
    96  	// The result varies based on the specific Ident.
    97  	for i, id := range varIds {
    98  		obj := varObjs[i]
    99  		ref, _ := astutil.PathEnclosingInterval(f, id.Pos(), id.Pos())
   100  		pos := prog.Fset.Position(id.Pos())
   101  		exp := expectations[fmt.Sprintf("%s:%d", id.Name, pos.Line)]
   102  		if exp == "" {
   103  			t.Errorf("%s: no expectation for var ident %s ", pos, id.Name)
   104  			continue
   105  		}
   106  		wantAddr := false
   107  		if exp[0] == '&' {
   108  			wantAddr = true
   109  			exp = exp[1:]
   110  		}
   111  		checkVarValue(t, prog, mainPkg, ref, obj, exp, wantAddr)
   112  	}
   113  }
   114  
   115  func checkFuncValue(t *testing.T, prog *ssa.Program, obj *types.Func) {
   116  	fn := prog.FuncValue(obj)
   117  	// fmt.Printf("FuncValue(%s) = %s\n", obj, fn) // debugging
   118  	if fn == nil {
   119  		if obj.Name() != "interfaceMethod" {
   120  			t.Errorf("FuncValue(%s) == nil", obj)
   121  		}
   122  		return
   123  	}
   124  	if fnobj := fn.Object(); fnobj != obj {
   125  		t.Errorf("FuncValue(%s).Object() == %s; value was %s",
   126  			obj, fnobj, fn.Name())
   127  		return
   128  	}
   129  	if !types.Identical(fn.Type(), obj.Type()) {
   130  		t.Errorf("FuncValue(%s).Type() == %s", obj, fn.Type())
   131  		return
   132  	}
   133  }
   134  
   135  func checkConstValue(t *testing.T, prog *ssa.Program, obj *types.Const) {
   136  	c := prog.ConstValue(obj)
   137  	// fmt.Printf("ConstValue(%s) = %s\n", obj, c) // debugging
   138  	if c == nil {
   139  		t.Errorf("ConstValue(%s) == nil", obj)
   140  		return
   141  	}
   142  	if !types.Identical(c.Type(), obj.Type()) {
   143  		t.Errorf("ConstValue(%s).Type() == %s", obj, c.Type())
   144  		return
   145  	}
   146  	if obj.Name() != "nil" {
   147  		if !exact.Compare(c.Value, token.EQL, obj.Val()) {
   148  			t.Errorf("ConstValue(%s).Value (%s) != %s",
   149  				obj, c.Value, obj.Val())
   150  			return
   151  		}
   152  	}
   153  }
   154  
   155  func checkVarValue(t *testing.T, prog *ssa.Program, pkg *ssa.Package, ref []ast.Node, obj *types.Var, expKind string, wantAddr bool) {
   156  	// The prefix of all assertions messages.
   157  	prefix := fmt.Sprintf("VarValue(%s @ L%d)",
   158  		obj, prog.Fset.Position(ref[0].Pos()).Line)
   159  
   160  	v, gotAddr := prog.VarValue(obj, pkg, ref)
   161  
   162  	// Kind is the concrete type of the ssa Value.
   163  	gotKind := "nil"
   164  	if v != nil {
   165  		gotKind = fmt.Sprintf("%T", v)[len("*ssa."):]
   166  	}
   167  
   168  	// fmt.Printf("%s = %v (kind %q; expect %q) wantAddr=%t gotAddr=%t\n", prefix, v, gotKind, expKind, wantAddr, gotAddr) // debugging
   169  
   170  	// Check the kinds match.
   171  	// "nil" indicates expected failure (e.g. optimized away).
   172  	if expKind != gotKind {
   173  		t.Errorf("%s concrete type == %s, want %s", prefix, gotKind, expKind)
   174  	}
   175  
   176  	// Check the types match.
   177  	// If wantAddr, the expected type is the object's address.
   178  	if v != nil {
   179  		expType := obj.Type()
   180  		if wantAddr {
   181  			expType = types.NewPointer(expType)
   182  			if !gotAddr {
   183  				t.Errorf("%s: got value, want address", prefix)
   184  			}
   185  		} else if gotAddr {
   186  			t.Errorf("%s: got address, want value", prefix)
   187  		}
   188  		if !types.Identical(v.Type(), expType) {
   189  			t.Errorf("%s.Type() == %s, want %s", prefix, v.Type(), expType)
   190  		}
   191  	}
   192  }
   193  
   194  // Ensure that, in debug mode, we can determine the ssa.Value
   195  // corresponding to every ast.Expr.
   196  func TestValueForExpr(t *testing.T) {
   197  	testValueForExpr(t, "testdata/valueforexpr.go")
   198  }
   199  
   200  func testValueForExpr(t *testing.T, testfile string) {
   201  	if runtime.GOOS == "android" {
   202  		t.Skipf("no testdata dir on %s", runtime.GOOS)
   203  	}
   204  
   205  	conf := loader.Config{ParserMode: parser.ParseComments}
   206  	f, err := conf.ParseFile(testfile, nil)
   207  	if err != nil {
   208  		t.Error(err)
   209  		return
   210  	}
   211  	conf.CreateFromFiles("main", f)
   212  
   213  	iprog, err := conf.Load()
   214  	if err != nil {
   215  		t.Error(err)
   216  		return
   217  	}
   218  
   219  	mainInfo := iprog.Created[0]
   220  
   221  	prog := ssautil.CreateProgram(iprog, 0)
   222  	mainPkg := prog.Package(mainInfo.Pkg)
   223  	mainPkg.SetDebugMode(true)
   224  	mainPkg.Build()
   225  
   226  	if false {
   227  		// debugging
   228  		for _, mem := range mainPkg.Members {
   229  			if fn, ok := mem.(*ssa.Function); ok {
   230  				fn.WriteTo(os.Stderr)
   231  			}
   232  		}
   233  	}
   234  
   235  	// Find the actual AST node for each canonical position.
   236  	parenExprByPos := make(map[token.Pos]*ast.ParenExpr)
   237  	ast.Inspect(f, func(n ast.Node) bool {
   238  		if n != nil {
   239  			if e, ok := n.(*ast.ParenExpr); ok {
   240  				parenExprByPos[e.Pos()] = e
   241  			}
   242  		}
   243  		return true
   244  	})
   245  
   246  	// Find all annotations of form /*@kind*/.
   247  	for _, c := range f.Comments {
   248  		text := strings.TrimSpace(c.Text())
   249  		if text == "" || text[0] != '@' {
   250  			continue
   251  		}
   252  		text = text[1:]
   253  		pos := c.End() + 1
   254  		position := prog.Fset.Position(pos)
   255  		var e ast.Expr
   256  		if target := parenExprByPos[pos]; target == nil {
   257  			t.Errorf("%s: annotation doesn't precede ParenExpr: %q", position, text)
   258  			continue
   259  		} else {
   260  			e = target.X
   261  		}
   262  
   263  		path, _ := astutil.PathEnclosingInterval(f, pos, pos)
   264  		if path == nil {
   265  			t.Errorf("%s: can't find AST path from root to comment: %s", position, text)
   266  			continue
   267  		}
   268  
   269  		fn := ssa.EnclosingFunction(mainPkg, path)
   270  		if fn == nil {
   271  			t.Errorf("%s: can't find enclosing function", position)
   272  			continue
   273  		}
   274  
   275  		v, gotAddr := fn.ValueForExpr(e) // (may be nil)
   276  		got := strings.TrimPrefix(fmt.Sprintf("%T", v), "*ssa.")
   277  		if want := text; got != want {
   278  			t.Errorf("%s: got value %q, want %q", position, got, want)
   279  		}
   280  		if v != nil {
   281  			T := v.Type()
   282  			if gotAddr {
   283  				T = T.Underlying().(*types.Pointer).Elem() // deref
   284  			}
   285  			if !types.Identical(T, mainInfo.TypeOf(e)) {
   286  				t.Errorf("%s: got type %s, want %s", position, mainInfo.TypeOf(e), T)
   287  			}
   288  		}
   289  	}
   290  }
   291  
   292  // findInterval parses input and returns the [start, end) positions of
   293  // the first occurrence of substr in input.  f==nil indicates failure;
   294  // an error has already been reported in that case.
   295  //
   296  func findInterval(t *testing.T, fset *token.FileSet, input, substr string) (f *ast.File, start, end token.Pos) {
   297  	f, err := parser.ParseFile(fset, "<input>", input, 0)
   298  	if err != nil {
   299  		t.Errorf("parse error: %s", err)
   300  		return
   301  	}
   302  
   303  	i := strings.Index(input, substr)
   304  	if i < 0 {
   305  		t.Errorf("%q is not a substring of input", substr)
   306  		f = nil
   307  		return
   308  	}
   309  
   310  	filePos := fset.File(f.Package)
   311  	return f, filePos.Pos(i), filePos.Pos(i + len(substr))
   312  }
   313  
   314  func TestEnclosingFunction(t *testing.T) {
   315  	tests := []struct {
   316  		input  string // the input file
   317  		substr string // first occurrence of this string denotes interval
   318  		fn     string // name of expected containing function
   319  	}{
   320  		// We use distinctive numbers as syntactic landmarks.
   321  
   322  		// Ordinary function:
   323  		{`package main
   324  		  func f() { println(1003) }`,
   325  			"100", "main.f"},
   326  		// Methods:
   327  		{`package main
   328                    type T int
   329  		  func (t T) f() { println(200) }`,
   330  			"200", "(main.T).f"},
   331  		// Function literal:
   332  		{`package main
   333  		  func f() { println(func() { print(300) }) }`,
   334  			"300", "main.f$1"},
   335  		// Doubly nested
   336  		{`package main
   337  		  func f() { println(func() { print(func() { print(350) })})}`,
   338  			"350", "main.f$1$1"},
   339  		// Implicit init for package-level var initializer.
   340  		{"package main; var a = 400", "400", "main.init"},
   341  		// No code for constants:
   342  		{"package main; const a = 500", "500", "(none)"},
   343  		// Explicit init()
   344  		{"package main; func init() { println(600) }", "600", "main.init#1"},
   345  		// Multiple explicit init functions:
   346  		{`package main
   347  		  func init() { println("foo") }
   348  		  func init() { println(800) }`,
   349  			"800", "main.init#2"},
   350  		// init() containing FuncLit.
   351  		{`package main
   352  		  func init() { println(func(){print(900)}) }`,
   353  			"900", "main.init#1$1"},
   354  	}
   355  	for _, test := range tests {
   356  		conf := loader.Config{Fset: token.NewFileSet()}
   357  		f, start, end := findInterval(t, conf.Fset, test.input, test.substr)
   358  		if f == nil {
   359  			continue
   360  		}
   361  		path, exact := astutil.PathEnclosingInterval(f, start, end)
   362  		if !exact {
   363  			t.Errorf("EnclosingFunction(%q) not exact", test.substr)
   364  			continue
   365  		}
   366  
   367  		conf.CreateFromFiles("main", f)
   368  
   369  		iprog, err := conf.Load()
   370  		if err != nil {
   371  			t.Error(err)
   372  			continue
   373  		}
   374  		prog := ssautil.CreateProgram(iprog, 0)
   375  		pkg := prog.Package(iprog.Created[0].Pkg)
   376  		pkg.Build()
   377  
   378  		name := "(none)"
   379  		fn := ssa.EnclosingFunction(pkg, path)
   380  		if fn != nil {
   381  			name = fn.String()
   382  		}
   383  
   384  		if name != test.fn {
   385  			t.Errorf("EnclosingFunction(%q in %q) got %s, want %s",
   386  				test.substr, test.input, name, test.fn)
   387  			continue
   388  		}
   389  
   390  		// While we're here: test HasEnclosingFunction.
   391  		if has := ssa.HasEnclosingFunction(pkg, path); has != (fn != nil) {
   392  			t.Errorf("HasEnclosingFunction(%q in %q) got %v, want %v",
   393  				test.substr, test.input, has, fn != nil)
   394  			continue
   395  		}
   396  	}
   397  }