github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/go/ast/inspector/inspector_test.go (about) 1 // Copyright 2018 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package inspector_test 6 7 import ( 8 "go/ast" 9 "go/build" 10 "go/parser" 11 "go/token" 12 "log" 13 "path/filepath" 14 "reflect" 15 "strconv" 16 "strings" 17 "testing" 18 19 "github.com/powerman/golang-tools/go/ast/inspector" 20 "github.com/powerman/golang-tools/internal/typeparams" 21 ) 22 23 var netFiles []*ast.File 24 25 func init() { 26 files, err := parseNetFiles() 27 if err != nil { 28 log.Fatal(err) 29 } 30 netFiles = files 31 } 32 33 func parseNetFiles() ([]*ast.File, error) { 34 pkg, err := build.Default.Import("net", "", 0) 35 if err != nil { 36 return nil, err 37 } 38 fset := token.NewFileSet() 39 var files []*ast.File 40 for _, filename := range pkg.GoFiles { 41 filename = filepath.Join(pkg.Dir, filename) 42 f, err := parser.ParseFile(fset, filename, nil, 0) 43 if err != nil { 44 return nil, err 45 } 46 files = append(files, f) 47 } 48 return files, nil 49 } 50 51 // TestAllNodes compares Inspector against ast.Inspect. 52 func TestInspectAllNodes(t *testing.T) { 53 inspect := inspector.New(netFiles) 54 55 var nodesA []ast.Node 56 inspect.Nodes(nil, func(n ast.Node, push bool) bool { 57 if push { 58 nodesA = append(nodesA, n) 59 } 60 return true 61 }) 62 var nodesB []ast.Node 63 for _, f := range netFiles { 64 ast.Inspect(f, func(n ast.Node) bool { 65 if n != nil { 66 nodesB = append(nodesB, n) 67 } 68 return true 69 }) 70 } 71 compare(t, nodesA, nodesB) 72 } 73 74 func TestInspectGenericNodes(t *testing.T) { 75 if !typeparams.Enabled { 76 t.Skip("type parameters are not supported at this Go version") 77 } 78 79 // src is using the 16 identifiers i0, i1, ... i15 so 80 // we can easily verify that we've found all of them. 81 const src = `package a 82 83 type I interface { ~i0|i1 } 84 85 type T[i2, i3 interface{ ~i4 }] struct {} 86 87 func f[i5, i6 any]() { 88 _ = f[i7, i8] 89 var x T[i9, i10] 90 } 91 92 func (*T[i11, i12]) m() 93 94 var _ i13[i14, i15] 95 ` 96 fset := token.NewFileSet() 97 f, _ := parser.ParseFile(fset, "a.go", src, 0) 98 inspect := inspector.New([]*ast.File{f}) 99 found := make([]bool, 16) 100 101 indexListExprs := make(map[*typeparams.IndexListExpr]bool) 102 103 // Verify that we reach all i* identifiers, and collect IndexListExpr nodes. 104 inspect.Preorder(nil, func(n ast.Node) { 105 switch n := n.(type) { 106 case *ast.Ident: 107 if n.Name[0] == 'i' { 108 index, err := strconv.Atoi(n.Name[1:]) 109 if err != nil { 110 t.Fatal(err) 111 } 112 found[index] = true 113 } 114 case *typeparams.IndexListExpr: 115 indexListExprs[n] = false 116 } 117 }) 118 for i, v := range found { 119 if !v { 120 t.Errorf("missed identifier i%d", i) 121 } 122 } 123 124 // Verify that we can filter to IndexListExprs that we found in the first 125 // step. 126 if len(indexListExprs) == 0 { 127 t.Fatal("no index list exprs found") 128 } 129 inspect.Preorder([]ast.Node{&typeparams.IndexListExpr{}}, func(n ast.Node) { 130 ix := n.(*typeparams.IndexListExpr) 131 indexListExprs[ix] = true 132 }) 133 for ix, v := range indexListExprs { 134 if !v { 135 t.Errorf("inspected node %v not filtered", ix) 136 } 137 } 138 } 139 140 // TestPruning compares Inspector against ast.Inspect, 141 // pruning descent within ast.CallExpr nodes. 142 func TestInspectPruning(t *testing.T) { 143 inspect := inspector.New(netFiles) 144 145 var nodesA []ast.Node 146 inspect.Nodes(nil, func(n ast.Node, push bool) bool { 147 if push { 148 nodesA = append(nodesA, n) 149 _, isCall := n.(*ast.CallExpr) 150 return !isCall // don't descend into function calls 151 } 152 return false 153 }) 154 var nodesB []ast.Node 155 for _, f := range netFiles { 156 ast.Inspect(f, func(n ast.Node) bool { 157 if n != nil { 158 nodesB = append(nodesB, n) 159 _, isCall := n.(*ast.CallExpr) 160 return !isCall // don't descend into function calls 161 } 162 return false 163 }) 164 } 165 compare(t, nodesA, nodesB) 166 } 167 168 func compare(t *testing.T, nodesA, nodesB []ast.Node) { 169 if len(nodesA) != len(nodesB) { 170 t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB)) 171 } else { 172 for i := range nodesA { 173 if a, b := nodesA[i], nodesB[i]; a != b { 174 t.Errorf("node %d is inconsistent: %T, %T", i, a, b) 175 } 176 } 177 } 178 } 179 180 func TestTypeFiltering(t *testing.T) { 181 const src = `package a 182 func f() { 183 print("hi") 184 panic("oops") 185 } 186 ` 187 fset := token.NewFileSet() 188 f, _ := parser.ParseFile(fset, "a.go", src, 0) 189 inspect := inspector.New([]*ast.File{f}) 190 191 var got []string 192 fn := func(n ast.Node, push bool) bool { 193 if push { 194 got = append(got, typeOf(n)) 195 } 196 return true 197 } 198 199 // no type filtering 200 inspect.Nodes(nil, fn) 201 if want := strings.Fields("File Ident FuncDecl Ident FuncType FieldList BlockStmt ExprStmt CallExpr Ident BasicLit ExprStmt CallExpr Ident BasicLit"); !reflect.DeepEqual(got, want) { 202 t.Errorf("inspect: got %s, want %s", got, want) 203 } 204 205 // type filtering 206 nodeTypes := []ast.Node{ 207 (*ast.BasicLit)(nil), 208 (*ast.CallExpr)(nil), 209 } 210 got = nil 211 inspect.Nodes(nodeTypes, fn) 212 if want := strings.Fields("CallExpr BasicLit CallExpr BasicLit"); !reflect.DeepEqual(got, want) { 213 t.Errorf("inspect: got %s, want %s", got, want) 214 } 215 216 // inspect with stack 217 got = nil 218 inspect.WithStack(nodeTypes, func(n ast.Node, push bool, stack []ast.Node) bool { 219 if push { 220 var line []string 221 for _, n := range stack { 222 line = append(line, typeOf(n)) 223 } 224 got = append(got, strings.Join(line, " ")) 225 } 226 return true 227 }) 228 want := []string{ 229 "File FuncDecl BlockStmt ExprStmt CallExpr", 230 "File FuncDecl BlockStmt ExprStmt CallExpr BasicLit", 231 "File FuncDecl BlockStmt ExprStmt CallExpr", 232 "File FuncDecl BlockStmt ExprStmt CallExpr BasicLit", 233 } 234 if !reflect.DeepEqual(got, want) { 235 t.Errorf("inspect: got %s, want %s", got, want) 236 } 237 } 238 239 func typeOf(n ast.Node) string { 240 return strings.TrimPrefix(reflect.TypeOf(n).String(), "*ast.") 241 } 242 243 // The numbers show a marginal improvement (ASTInspect/Inspect) of 3.5x, 244 // but a break-even point (NewInspector/(ASTInspect-Inspect)) of about 5 245 // traversals. 246 // 247 // BenchmarkNewInspector 4.5 ms 248 // BenchmarkNewInspect 0.33ms 249 // BenchmarkASTInspect 1.2 ms 250 251 func BenchmarkNewInspector(b *testing.B) { 252 // Measure one-time construction overhead. 253 for i := 0; i < b.N; i++ { 254 inspector.New(netFiles) 255 } 256 } 257 258 func BenchmarkInspect(b *testing.B) { 259 b.StopTimer() 260 inspect := inspector.New(netFiles) 261 b.StartTimer() 262 263 // Measure marginal cost of traversal. 264 var ndecls, nlits int 265 for i := 0; i < b.N; i++ { 266 inspect.Preorder(nil, func(n ast.Node) { 267 switch n.(type) { 268 case *ast.FuncDecl: 269 ndecls++ 270 case *ast.FuncLit: 271 nlits++ 272 } 273 }) 274 } 275 } 276 277 func BenchmarkASTInspect(b *testing.B) { 278 var ndecls, nlits int 279 for i := 0; i < b.N; i++ { 280 for _, f := range netFiles { 281 ast.Inspect(f, func(n ast.Node) bool { 282 switch n.(type) { 283 case *ast.FuncDecl: 284 ndecls++ 285 case *ast.FuncLit: 286 nlits++ 287 } 288 return true 289 }) 290 } 291 } 292 }