golang.org/x/tools@v0.21.1-0.20240520172518-788d39e776b1/go/callgraph/vta/graph_test.go (about) 1 // Copyright 2021 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package vta 6 7 import ( 8 "fmt" 9 "go/types" 10 "reflect" 11 "sort" 12 "strings" 13 "testing" 14 15 "golang.org/x/tools/go/callgraph/cha" 16 "golang.org/x/tools/go/ssa" 17 "golang.org/x/tools/go/ssa/ssautil" 18 "golang.org/x/tools/internal/aliases" 19 ) 20 21 func TestNodeInterface(t *testing.T) { 22 // Since ssa package does not allow explicit creation of ssa 23 // values, we use the values from the program testdata/src/simple.go: 24 // - basic type int 25 // - struct X with two int fields a and b 26 // - global variable "gl" 27 // - "main" function and its 28 // - first register instruction t0 := *gl 29 prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0)) 30 if err != nil { 31 t.Fatalf("couldn't load testdata/src/simple.go program: %v", err) 32 } 33 34 pkg := prog.AllPackages()[0] 35 main := pkg.Func("main") 36 reg := firstRegInstr(main) // t0 := *gl 37 X := pkg.Type("X").Type() 38 gl := pkg.Var("gl") 39 glPtrType, ok := aliases.Unalias(gl.Type()).(*types.Pointer) 40 if !ok { 41 t.Fatalf("could not cast gl variable to pointer type") 42 } 43 bint := glPtrType.Elem() 44 45 pint := types.NewPointer(bint) 46 i := types.NewInterface(nil, nil) 47 48 voidFunc := main.Signature.Underlying() 49 50 for _, test := range []struct { 51 n node 52 s string 53 t types.Type 54 }{ 55 {constant{typ: bint}, "Constant(int)", bint}, 56 {pointer{typ: pint}, "Pointer(*int)", pint}, 57 {mapKey{typ: bint}, "MapKey(int)", bint}, 58 {mapValue{typ: pint}, "MapValue(*int)", pint}, 59 {sliceElem{typ: bint}, "Slice([]int)", bint}, 60 {channelElem{typ: pint}, "Channel(chan *int)", pint}, 61 {field{StructType: X, index: 0}, "Field(testdata.X:a)", bint}, 62 {field{StructType: X, index: 1}, "Field(testdata.X:b)", bint}, 63 {global{val: gl}, "Global(gl)", gl.Type()}, 64 {local{val: reg}, "Local(t0)", bint}, 65 {indexedLocal{val: reg, typ: X, index: 0}, "Local(t0[0])", X}, 66 {function{f: main}, "Function(main)", voidFunc}, 67 {nestedPtrInterface{typ: i}, "PtrInterface(interface{})", i}, 68 {nestedPtrFunction{typ: voidFunc}, "PtrFunction(func())", voidFunc}, 69 {panicArg{}, "Panic", nil}, 70 {recoverReturn{}, "Recover", nil}, 71 } { 72 if test.s != test.n.String() { 73 t.Errorf("want %s; got %s", test.s, test.n.String()) 74 } 75 if test.t != test.n.Type() { 76 t.Errorf("want %s; got %s", test.t, test.n.Type()) 77 } 78 } 79 } 80 81 func TestVtaGraph(t *testing.T) { 82 // Get the basic type int from a real program. 83 prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0)) 84 if err != nil { 85 t.Fatalf("couldn't load testdata/src/simple.go program: %v", err) 86 } 87 88 glPtrType, ok := prog.AllPackages()[0].Var("gl").Type().(*types.Pointer) 89 if !ok { 90 t.Fatalf("could not cast gl variable to pointer type") 91 } 92 bint := glPtrType.Elem() 93 94 n1 := constant{typ: bint} 95 n2 := pointer{typ: types.NewPointer(bint)} 96 n3 := mapKey{typ: types.NewMap(bint, bint)} 97 n4 := mapValue{typ: types.NewMap(bint, bint)} 98 99 // Create graph 100 // n1 n2 101 // \ / / 102 // n3 / 103 // | / 104 // n4 105 g := make(vtaGraph) 106 g.addEdge(n1, n3) 107 g.addEdge(n2, n3) 108 g.addEdge(n3, n4) 109 g.addEdge(n2, n4) 110 // for checking duplicates 111 g.addEdge(n1, n3) 112 113 want := vtaGraph{ 114 n1: map[node]bool{n3: true}, 115 n2: map[node]bool{n3: true, n4: true}, 116 n3: map[node]bool{n4: true}, 117 } 118 119 if !reflect.DeepEqual(want, g) { 120 t.Errorf("want %v; got %v", want, g) 121 } 122 123 for _, test := range []struct { 124 n node 125 l int 126 }{ 127 {n1, 1}, 128 {n2, 2}, 129 {n3, 1}, 130 {n4, 0}, 131 } { 132 if sl := len(g.successors(test.n)); sl != test.l { 133 t.Errorf("want %d successors; got %d", test.l, sl) 134 } 135 } 136 } 137 138 // vtaGraphStr stringifies vtaGraph into a list of strings 139 // where each string represents an edge set of the format 140 // node -> succ_1, ..., succ_n. succ_1, ..., succ_n are 141 // sorted in alphabetical order. 142 func vtaGraphStr(g vtaGraph) []string { 143 var vgs []string 144 for n, succ := range g { 145 var succStr []string 146 for s := range succ { 147 succStr = append(succStr, s.String()) 148 } 149 sort.Strings(succStr) 150 entry := fmt.Sprintf("%v -> %v", n.String(), strings.Join(succStr, ", ")) 151 vgs = append(vgs, entry) 152 } 153 return vgs 154 } 155 156 // setdiff returns the set difference of `X-Y` or {s | s ∈ X, s ∉ Y }. 157 func setdiff(X, Y []string) []string { 158 y := make(map[string]bool) 159 var delta []string 160 for _, s := range Y { 161 y[s] = true 162 } 163 164 for _, s := range X { 165 if _, ok := y[s]; !ok { 166 delta = append(delta, s) 167 } 168 } 169 sort.Strings(delta) 170 return delta 171 } 172 173 func TestVTAGraphConstruction(t *testing.T) { 174 for _, file := range []string{ 175 "testdata/src/store.go", 176 "testdata/src/phi.go", 177 "testdata/src/type_conversions.go", 178 "testdata/src/type_assertions.go", 179 "testdata/src/fields.go", 180 "testdata/src/node_uniqueness.go", 181 "testdata/src/store_load_alias.go", 182 "testdata/src/phi_alias.go", 183 "testdata/src/channels.go", 184 "testdata/src/generic_channels.go", 185 "testdata/src/select.go", 186 "testdata/src/stores_arrays.go", 187 "testdata/src/maps.go", 188 "testdata/src/ranges.go", 189 "testdata/src/closures.go", 190 "testdata/src/function_alias.go", 191 "testdata/src/static_calls.go", 192 "testdata/src/dynamic_calls.go", 193 "testdata/src/returns.go", 194 "testdata/src/panic.go", 195 } { 196 t.Run(file, func(t *testing.T) { 197 prog, want, err := testProg(file, ssa.BuilderMode(0)) 198 if err != nil { 199 t.Fatalf("couldn't load test file '%s': %s", file, err) 200 } 201 if len(want) == 0 { 202 t.Fatalf("couldn't find want in `%s`", file) 203 } 204 205 g, _ := typePropGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog)) 206 got := vtaGraphStr(g) 207 if diff := setdiff(want, got); len(diff) > 0 { 208 t.Errorf("`%s`: want superset of %v;\n got %v\ndiff: %v", file, want, got, diff) 209 } 210 }) 211 } 212 }