github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/go/callgraph/vta/graph_test.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package vta
     6  
     7  import (
     8  	"fmt"
     9  	"go/types"
    10  	"reflect"
    11  	"sort"
    12  	"strings"
    13  	"testing"
    14  
    15  	"golang.org/x/tools/go/callgraph/cha"
    16  	"golang.org/x/tools/go/ssa"
    17  	"golang.org/x/tools/go/ssa/ssautil"
    18  )
    19  
    20  func TestNodeInterface(t *testing.T) {
    21  	// Since ssa package does not allow explicit creation of ssa
    22  	// values, we use the values from the program testdata/src/simple.go:
    23  	//   - basic type int
    24  	//   - struct X with two int fields a and b
    25  	//   - global variable "gl"
    26  	//   - "main" function and its
    27  	//   - first register instruction t0 := *gl
    28  	prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0))
    29  	if err != nil {
    30  		t.Fatalf("couldn't load testdata/src/simple.go program: %v", err)
    31  	}
    32  
    33  	pkg := prog.AllPackages()[0]
    34  	main := pkg.Func("main")
    35  	reg := firstRegInstr(main) // t0 := *gl
    36  	X := pkg.Type("X").Type()
    37  	gl := pkg.Var("gl")
    38  	glPtrType, ok := gl.Type().(*types.Pointer)
    39  	if !ok {
    40  		t.Fatalf("could not cast gl variable to pointer type")
    41  	}
    42  	bint := glPtrType.Elem()
    43  
    44  	pint := types.NewPointer(bint)
    45  	i := types.NewInterface(nil, nil)
    46  
    47  	voidFunc := main.Signature.Underlying()
    48  
    49  	for _, test := range []struct {
    50  		n node
    51  		s string
    52  		t types.Type
    53  	}{
    54  		{constant{typ: bint}, "Constant(int)", bint},
    55  		{pointer{typ: pint}, "Pointer(*int)", pint},
    56  		{mapKey{typ: bint}, "MapKey(int)", bint},
    57  		{mapValue{typ: pint}, "MapValue(*int)", pint},
    58  		{sliceElem{typ: bint}, "Slice([]int)", bint},
    59  		{channelElem{typ: pint}, "Channel(chan *int)", pint},
    60  		{field{StructType: X, index: 0}, "Field(testdata.X:a)", bint},
    61  		{field{StructType: X, index: 1}, "Field(testdata.X:b)", bint},
    62  		{global{val: gl}, "Global(gl)", gl.Type()},
    63  		{local{val: reg}, "Local(t0)", bint},
    64  		{indexedLocal{val: reg, typ: X, index: 0}, "Local(t0[0])", X},
    65  		{function{f: main}, "Function(main)", voidFunc},
    66  		{nestedPtrInterface{typ: i}, "PtrInterface(interface{})", i},
    67  		{nestedPtrFunction{typ: voidFunc}, "PtrFunction(func())", voidFunc},
    68  		{panicArg{}, "Panic", nil},
    69  		{recoverReturn{}, "Recover", nil},
    70  	} {
    71  		if test.s != test.n.String() {
    72  			t.Errorf("want %s; got %s", test.s, test.n.String())
    73  		}
    74  		if test.t != test.n.Type() {
    75  			t.Errorf("want %s; got %s", test.t, test.n.Type())
    76  		}
    77  	}
    78  }
    79  
    80  func TestVtaGraph(t *testing.T) {
    81  	// Get the basic type int from a real program.
    82  	prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0))
    83  	if err != nil {
    84  		t.Fatalf("couldn't load testdata/src/simple.go program: %v", err)
    85  	}
    86  
    87  	glPtrType, ok := prog.AllPackages()[0].Var("gl").Type().(*types.Pointer)
    88  	if !ok {
    89  		t.Fatalf("could not cast gl variable to pointer type")
    90  	}
    91  	bint := glPtrType.Elem()
    92  
    93  	n1 := constant{typ: bint}
    94  	n2 := pointer{typ: types.NewPointer(bint)}
    95  	n3 := mapKey{typ: types.NewMap(bint, bint)}
    96  	n4 := mapValue{typ: types.NewMap(bint, bint)}
    97  
    98  	// Create graph
    99  	//   n1   n2
   100  	//    \  / /
   101  	//     n3 /
   102  	//     | /
   103  	//     n4
   104  	g := make(vtaGraph)
   105  	g.addEdge(n1, n3)
   106  	g.addEdge(n2, n3)
   107  	g.addEdge(n3, n4)
   108  	g.addEdge(n2, n4)
   109  	// for checking duplicates
   110  	g.addEdge(n1, n3)
   111  
   112  	want := vtaGraph{
   113  		n1: map[node]bool{n3: true},
   114  		n2: map[node]bool{n3: true, n4: true},
   115  		n3: map[node]bool{n4: true},
   116  	}
   117  
   118  	if !reflect.DeepEqual(want, g) {
   119  		t.Errorf("want %v; got %v", want, g)
   120  	}
   121  
   122  	for _, test := range []struct {
   123  		n node
   124  		l int
   125  	}{
   126  		{n1, 1},
   127  		{n2, 2},
   128  		{n3, 1},
   129  		{n4, 0},
   130  	} {
   131  		if sl := len(g.successors(test.n)); sl != test.l {
   132  			t.Errorf("want %d successors; got %d", test.l, sl)
   133  		}
   134  	}
   135  }
   136  
   137  // vtaGraphStr stringifies vtaGraph into a list of strings
   138  // where each string represents an edge set of the format
   139  // node -> succ_1, ..., succ_n. succ_1, ..., succ_n are
   140  // sorted in alphabetical order.
   141  func vtaGraphStr(g vtaGraph) []string {
   142  	var vgs []string
   143  	for n, succ := range g {
   144  		var succStr []string
   145  		for s := range succ {
   146  			succStr = append(succStr, s.String())
   147  		}
   148  		sort.Strings(succStr)
   149  		entry := fmt.Sprintf("%v -> %v", n.String(), strings.Join(succStr, ", "))
   150  		vgs = append(vgs, entry)
   151  	}
   152  	return vgs
   153  }
   154  
   155  // subGraph checks if a graph `g1` is a subgraph of graph `g2`.
   156  // Assumes that each element in `g1` and `g2` is an edge set
   157  // for a particular node in a fixed yet arbitrary format.
   158  func subGraph(g1, g2 []string) bool {
   159  	m := make(map[string]bool)
   160  	for _, s := range g2 {
   161  		m[s] = true
   162  	}
   163  
   164  	for _, s := range g1 {
   165  		if _, ok := m[s]; !ok {
   166  			return false
   167  		}
   168  	}
   169  	return true
   170  }
   171  
   172  func TestVTAGraphConstruction(t *testing.T) {
   173  	for _, file := range []string{
   174  		"testdata/src/store.go",
   175  		"testdata/src/phi.go",
   176  		"testdata/src/type_conversions.go",
   177  		"testdata/src/type_assertions.go",
   178  		"testdata/src/fields.go",
   179  		"testdata/src/node_uniqueness.go",
   180  		"testdata/src/store_load_alias.go",
   181  		"testdata/src/phi_alias.go",
   182  		"testdata/src/channels.go",
   183  		"testdata/src/select.go",
   184  		"testdata/src/stores_arrays.go",
   185  		"testdata/src/maps.go",
   186  		"testdata/src/ranges.go",
   187  		"testdata/src/closures.go",
   188  		"testdata/src/function_alias.go",
   189  		"testdata/src/static_calls.go",
   190  		"testdata/src/dynamic_calls.go",
   191  		"testdata/src/returns.go",
   192  		"testdata/src/panic.go",
   193  	} {
   194  		t.Run(file, func(t *testing.T) {
   195  			prog, want, err := testProg(file, ssa.BuilderMode(0))
   196  			if err != nil {
   197  				t.Fatalf("couldn't load test file '%s': %s", file, err)
   198  			}
   199  			if len(want) == 0 {
   200  				t.Fatalf("couldn't find want in `%s`", file)
   201  			}
   202  
   203  			g, _ := typePropGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog))
   204  			if gs := vtaGraphStr(g); !subGraph(want, gs) {
   205  				t.Errorf("`%s`: want superset of %v;\n got %v", file, want, gs)
   206  			}
   207  		})
   208  	}
   209  }