golang.org/x/tools@v0.21.1-0.20240520172518-788d39e776b1/go/ast/inspector/inspector_test.go (about) 1 // Copyright 2018 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package inspector_test 6 7 import ( 8 "go/ast" 9 "go/build" 10 "go/parser" 11 "go/token" 12 "log" 13 "path/filepath" 14 "reflect" 15 "strconv" 16 "strings" 17 "testing" 18 19 "golang.org/x/tools/go/ast/inspector" 20 ) 21 22 var netFiles []*ast.File 23 24 func init() { 25 files, err := parseNetFiles() 26 if err != nil { 27 log.Fatal(err) 28 } 29 netFiles = files 30 } 31 32 func parseNetFiles() ([]*ast.File, error) { 33 pkg, err := build.Default.Import("net", "", 0) 34 if err != nil { 35 return nil, err 36 } 37 fset := token.NewFileSet() 38 var files []*ast.File 39 for _, filename := range pkg.GoFiles { 40 filename = filepath.Join(pkg.Dir, filename) 41 f, err := parser.ParseFile(fset, filename, nil, 0) 42 if err != nil { 43 return nil, err 44 } 45 files = append(files, f) 46 } 47 return files, nil 48 } 49 50 // TestInspectAllNodes compares Inspector against ast.Inspect. 51 func TestInspectAllNodes(t *testing.T) { 52 inspect := inspector.New(netFiles) 53 54 var nodesA []ast.Node 55 inspect.Nodes(nil, func(n ast.Node, push bool) bool { 56 if push { 57 nodesA = append(nodesA, n) 58 } 59 return true 60 }) 61 var nodesB []ast.Node 62 for _, f := range netFiles { 63 ast.Inspect(f, func(n ast.Node) bool { 64 if n != nil { 65 nodesB = append(nodesB, n) 66 } 67 return true 68 }) 69 } 70 compare(t, nodesA, nodesB) 71 } 72 73 func TestInspectGenericNodes(t *testing.T) { 74 // src is using the 16 identifiers i0, i1, ... i15 so 75 // we can easily verify that we've found all of them. 76 const src = `package a 77 78 type I interface { ~i0|i1 } 79 80 type T[i2, i3 interface{ ~i4 }] struct {} 81 82 func f[i5, i6 any]() { 83 _ = f[i7, i8] 84 var x T[i9, i10] 85 } 86 87 func (*T[i11, i12]) m() 88 89 var _ i13[i14, i15] 90 ` 91 fset := token.NewFileSet() 92 f, _ := parser.ParseFile(fset, "a.go", src, 0) 93 inspect := inspector.New([]*ast.File{f}) 94 found := make([]bool, 16) 95 96 indexListExprs := make(map[*ast.IndexListExpr]bool) 97 98 // Verify that we reach all i* identifiers, and collect IndexListExpr nodes. 99 inspect.Preorder(nil, func(n ast.Node) { 100 switch n := n.(type) { 101 case *ast.Ident: 102 if n.Name[0] == 'i' { 103 index, err := strconv.Atoi(n.Name[1:]) 104 if err != nil { 105 t.Fatal(err) 106 } 107 found[index] = true 108 } 109 case *ast.IndexListExpr: 110 indexListExprs[n] = false 111 } 112 }) 113 for i, v := range found { 114 if !v { 115 t.Errorf("missed identifier i%d", i) 116 } 117 } 118 119 // Verify that we can filter to IndexListExprs that we found in the first 120 // step. 121 if len(indexListExprs) == 0 { 122 t.Fatal("no index list exprs found") 123 } 124 inspect.Preorder([]ast.Node{&ast.IndexListExpr{}}, func(n ast.Node) { 125 ix := n.(*ast.IndexListExpr) 126 indexListExprs[ix] = true 127 }) 128 for ix, v := range indexListExprs { 129 if !v { 130 t.Errorf("inspected node %v not filtered", ix) 131 } 132 } 133 } 134 135 // TestInspectPruning compares Inspector against ast.Inspect, 136 // pruning descent within ast.CallExpr nodes. 137 func TestInspectPruning(t *testing.T) { 138 inspect := inspector.New(netFiles) 139 140 var nodesA []ast.Node 141 inspect.Nodes(nil, func(n ast.Node, push bool) bool { 142 if push { 143 nodesA = append(nodesA, n) 144 _, isCall := n.(*ast.CallExpr) 145 return !isCall // don't descend into function calls 146 } 147 return false 148 }) 149 var nodesB []ast.Node 150 for _, f := range netFiles { 151 ast.Inspect(f, func(n ast.Node) bool { 152 if n != nil { 153 nodesB = append(nodesB, n) 154 _, isCall := n.(*ast.CallExpr) 155 return !isCall // don't descend into function calls 156 } 157 return false 158 }) 159 } 160 compare(t, nodesA, nodesB) 161 } 162 163 func compare(t *testing.T, nodesA, nodesB []ast.Node) { 164 if len(nodesA) != len(nodesB) { 165 t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB)) 166 } else { 167 for i := range nodesA { 168 if a, b := nodesA[i], nodesB[i]; a != b { 169 t.Errorf("node %d is inconsistent: %T, %T", i, a, b) 170 } 171 } 172 } 173 } 174 175 func TestTypeFiltering(t *testing.T) { 176 const src = `package a 177 func f() { 178 print("hi") 179 panic("oops") 180 } 181 ` 182 fset := token.NewFileSet() 183 f, _ := parser.ParseFile(fset, "a.go", src, 0) 184 inspect := inspector.New([]*ast.File{f}) 185 186 var got []string 187 fn := func(n ast.Node, push bool) bool { 188 if push { 189 got = append(got, typeOf(n)) 190 } 191 return true 192 } 193 194 // no type filtering 195 inspect.Nodes(nil, fn) 196 if want := strings.Fields("File Ident FuncDecl Ident FuncType FieldList BlockStmt ExprStmt CallExpr Ident BasicLit ExprStmt CallExpr Ident BasicLit"); !reflect.DeepEqual(got, want) { 197 t.Errorf("inspect: got %s, want %s", got, want) 198 } 199 200 // type filtering 201 nodeTypes := []ast.Node{ 202 (*ast.BasicLit)(nil), 203 (*ast.CallExpr)(nil), 204 } 205 got = nil 206 inspect.Nodes(nodeTypes, fn) 207 if want := strings.Fields("CallExpr BasicLit CallExpr BasicLit"); !reflect.DeepEqual(got, want) { 208 t.Errorf("inspect: got %s, want %s", got, want) 209 } 210 211 // inspect with stack 212 got = nil 213 inspect.WithStack(nodeTypes, func(n ast.Node, push bool, stack []ast.Node) bool { 214 if push { 215 var line []string 216 for _, n := range stack { 217 line = append(line, typeOf(n)) 218 } 219 got = append(got, strings.Join(line, " ")) 220 } 221 return true 222 }) 223 want := []string{ 224 "File FuncDecl BlockStmt ExprStmt CallExpr", 225 "File FuncDecl BlockStmt ExprStmt CallExpr BasicLit", 226 "File FuncDecl BlockStmt ExprStmt CallExpr", 227 "File FuncDecl BlockStmt ExprStmt CallExpr BasicLit", 228 } 229 if !reflect.DeepEqual(got, want) { 230 t.Errorf("inspect: got %s, want %s", got, want) 231 } 232 } 233 234 func typeOf(n ast.Node) string { 235 return strings.TrimPrefix(reflect.TypeOf(n).String(), "*ast.") 236 } 237 238 // The numbers show a marginal improvement (ASTInspect/Inspect) of 3.5x, 239 // but a break-even point (NewInspector/(ASTInspect-Inspect)) of about 5 240 // traversals. 241 // 242 // BenchmarkASTInspect 1.0 ms 243 // BenchmarkNewInspector 2.2 ms 244 // BenchmarkInspect 0.39ms 245 // BenchmarkInspectFilter 0.01ms 246 // BenchmarkInspectCalls 0.14ms 247 248 func BenchmarkNewInspector(b *testing.B) { 249 // Measure one-time construction overhead. 250 for i := 0; i < b.N; i++ { 251 inspector.New(netFiles) 252 } 253 } 254 255 func BenchmarkInspect(b *testing.B) { 256 b.StopTimer() 257 inspect := inspector.New(netFiles) 258 b.StartTimer() 259 260 // Measure marginal cost of traversal. 261 var ndecls, nlits int 262 for i := 0; i < b.N; i++ { 263 inspect.Preorder(nil, func(n ast.Node) { 264 switch n.(type) { 265 case *ast.FuncDecl: 266 ndecls++ 267 case *ast.FuncLit: 268 nlits++ 269 } 270 }) 271 } 272 } 273 274 func BenchmarkInspectFilter(b *testing.B) { 275 b.StopTimer() 276 inspect := inspector.New(netFiles) 277 b.StartTimer() 278 279 // Measure marginal cost of traversal. 280 nodeFilter := []ast.Node{(*ast.FuncDecl)(nil), (*ast.FuncLit)(nil)} 281 var ndecls, nlits int 282 for i := 0; i < b.N; i++ { 283 inspect.Preorder(nodeFilter, func(n ast.Node) { 284 switch n.(type) { 285 case *ast.FuncDecl: 286 ndecls++ 287 case *ast.FuncLit: 288 nlits++ 289 } 290 }) 291 } 292 } 293 294 func BenchmarkInspectCalls(b *testing.B) { 295 b.StopTimer() 296 inspect := inspector.New(netFiles) 297 b.StartTimer() 298 299 // Measure marginal cost of traversal. 300 nodeFilter := []ast.Node{(*ast.CallExpr)(nil)} 301 var ncalls int 302 for i := 0; i < b.N; i++ { 303 inspect.Preorder(nodeFilter, func(n ast.Node) { 304 _ = n.(*ast.CallExpr) 305 ncalls++ 306 }) 307 } 308 } 309 310 func BenchmarkASTInspect(b *testing.B) { 311 var ndecls, nlits int 312 for i := 0; i < b.N; i++ { 313 for _, f := range netFiles { 314 ast.Inspect(f, func(n ast.Node) bool { 315 switch n.(type) { 316 case *ast.FuncDecl: 317 ndecls++ 318 case *ast.FuncLit: 319 nlits++ 320 } 321 return true 322 }) 323 } 324 } 325 }