github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/go/ast/inspector/inspector_test.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package inspector_test
     6  
     7  import (
     8  	"go/ast"
     9  	"go/build"
    10  	"go/parser"
    11  	"go/token"
    12  	"log"
    13  	"path/filepath"
    14  	"reflect"
    15  	"strconv"
    16  	"strings"
    17  	"testing"
    18  
    19  	"github.com/powerman/golang-tools/go/ast/inspector"
    20  	"github.com/powerman/golang-tools/internal/typeparams"
    21  )
    22  
    23  var netFiles []*ast.File
    24  
    25  func init() {
    26  	files, err := parseNetFiles()
    27  	if err != nil {
    28  		log.Fatal(err)
    29  	}
    30  	netFiles = files
    31  }
    32  
    33  func parseNetFiles() ([]*ast.File, error) {
    34  	pkg, err := build.Default.Import("net", "", 0)
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  	fset := token.NewFileSet()
    39  	var files []*ast.File
    40  	for _, filename := range pkg.GoFiles {
    41  		filename = filepath.Join(pkg.Dir, filename)
    42  		f, err := parser.ParseFile(fset, filename, nil, 0)
    43  		if err != nil {
    44  			return nil, err
    45  		}
    46  		files = append(files, f)
    47  	}
    48  	return files, nil
    49  }
    50  
    51  // TestAllNodes compares Inspector against ast.Inspect.
    52  func TestInspectAllNodes(t *testing.T) {
    53  	inspect := inspector.New(netFiles)
    54  
    55  	var nodesA []ast.Node
    56  	inspect.Nodes(nil, func(n ast.Node, push bool) bool {
    57  		if push {
    58  			nodesA = append(nodesA, n)
    59  		}
    60  		return true
    61  	})
    62  	var nodesB []ast.Node
    63  	for _, f := range netFiles {
    64  		ast.Inspect(f, func(n ast.Node) bool {
    65  			if n != nil {
    66  				nodesB = append(nodesB, n)
    67  			}
    68  			return true
    69  		})
    70  	}
    71  	compare(t, nodesA, nodesB)
    72  }
    73  
    74  func TestInspectGenericNodes(t *testing.T) {
    75  	if !typeparams.Enabled {
    76  		t.Skip("type parameters are not supported at this Go version")
    77  	}
    78  
    79  	// src is using the 16 identifiers i0, i1, ... i15 so
    80  	// we can easily verify that we've found all of them.
    81  	const src = `package a
    82  
    83  type I interface { ~i0|i1 }
    84  
    85  type T[i2, i3 interface{ ~i4 }] struct {}
    86  
    87  func f[i5, i6 any]() {
    88  	_ = f[i7, i8]
    89  	var x T[i9, i10]
    90  }
    91  
    92  func (*T[i11, i12]) m()
    93  
    94  var _ i13[i14, i15]
    95  `
    96  	fset := token.NewFileSet()
    97  	f, _ := parser.ParseFile(fset, "a.go", src, 0)
    98  	inspect := inspector.New([]*ast.File{f})
    99  	found := make([]bool, 16)
   100  
   101  	indexListExprs := make(map[*typeparams.IndexListExpr]bool)
   102  
   103  	// Verify that we reach all i* identifiers, and collect IndexListExpr nodes.
   104  	inspect.Preorder(nil, func(n ast.Node) {
   105  		switch n := n.(type) {
   106  		case *ast.Ident:
   107  			if n.Name[0] == 'i' {
   108  				index, err := strconv.Atoi(n.Name[1:])
   109  				if err != nil {
   110  					t.Fatal(err)
   111  				}
   112  				found[index] = true
   113  			}
   114  		case *typeparams.IndexListExpr:
   115  			indexListExprs[n] = false
   116  		}
   117  	})
   118  	for i, v := range found {
   119  		if !v {
   120  			t.Errorf("missed identifier i%d", i)
   121  		}
   122  	}
   123  
   124  	// Verify that we can filter to IndexListExprs that we found in the first
   125  	// step.
   126  	if len(indexListExprs) == 0 {
   127  		t.Fatal("no index list exprs found")
   128  	}
   129  	inspect.Preorder([]ast.Node{&typeparams.IndexListExpr{}}, func(n ast.Node) {
   130  		ix := n.(*typeparams.IndexListExpr)
   131  		indexListExprs[ix] = true
   132  	})
   133  	for ix, v := range indexListExprs {
   134  		if !v {
   135  			t.Errorf("inspected node %v not filtered", ix)
   136  		}
   137  	}
   138  }
   139  
   140  // TestPruning compares Inspector against ast.Inspect,
   141  // pruning descent within ast.CallExpr nodes.
   142  func TestInspectPruning(t *testing.T) {
   143  	inspect := inspector.New(netFiles)
   144  
   145  	var nodesA []ast.Node
   146  	inspect.Nodes(nil, func(n ast.Node, push bool) bool {
   147  		if push {
   148  			nodesA = append(nodesA, n)
   149  			_, isCall := n.(*ast.CallExpr)
   150  			return !isCall // don't descend into function calls
   151  		}
   152  		return false
   153  	})
   154  	var nodesB []ast.Node
   155  	for _, f := range netFiles {
   156  		ast.Inspect(f, func(n ast.Node) bool {
   157  			if n != nil {
   158  				nodesB = append(nodesB, n)
   159  				_, isCall := n.(*ast.CallExpr)
   160  				return !isCall // don't descend into function calls
   161  			}
   162  			return false
   163  		})
   164  	}
   165  	compare(t, nodesA, nodesB)
   166  }
   167  
   168  func compare(t *testing.T, nodesA, nodesB []ast.Node) {
   169  	if len(nodesA) != len(nodesB) {
   170  		t.Errorf("inconsistent node lists: %d vs %d", len(nodesA), len(nodesB))
   171  	} else {
   172  		for i := range nodesA {
   173  			if a, b := nodesA[i], nodesB[i]; a != b {
   174  				t.Errorf("node %d is inconsistent: %T, %T", i, a, b)
   175  			}
   176  		}
   177  	}
   178  }
   179  
   180  func TestTypeFiltering(t *testing.T) {
   181  	const src = `package a
   182  func f() {
   183  	print("hi")
   184  	panic("oops")
   185  }
   186  `
   187  	fset := token.NewFileSet()
   188  	f, _ := parser.ParseFile(fset, "a.go", src, 0)
   189  	inspect := inspector.New([]*ast.File{f})
   190  
   191  	var got []string
   192  	fn := func(n ast.Node, push bool) bool {
   193  		if push {
   194  			got = append(got, typeOf(n))
   195  		}
   196  		return true
   197  	}
   198  
   199  	// no type filtering
   200  	inspect.Nodes(nil, fn)
   201  	if want := strings.Fields("File Ident FuncDecl Ident FuncType FieldList BlockStmt ExprStmt CallExpr Ident BasicLit ExprStmt CallExpr Ident BasicLit"); !reflect.DeepEqual(got, want) {
   202  		t.Errorf("inspect: got %s, want %s", got, want)
   203  	}
   204  
   205  	// type filtering
   206  	nodeTypes := []ast.Node{
   207  		(*ast.BasicLit)(nil),
   208  		(*ast.CallExpr)(nil),
   209  	}
   210  	got = nil
   211  	inspect.Nodes(nodeTypes, fn)
   212  	if want := strings.Fields("CallExpr BasicLit CallExpr BasicLit"); !reflect.DeepEqual(got, want) {
   213  		t.Errorf("inspect: got %s, want %s", got, want)
   214  	}
   215  
   216  	// inspect with stack
   217  	got = nil
   218  	inspect.WithStack(nodeTypes, func(n ast.Node, push bool, stack []ast.Node) bool {
   219  		if push {
   220  			var line []string
   221  			for _, n := range stack {
   222  				line = append(line, typeOf(n))
   223  			}
   224  			got = append(got, strings.Join(line, " "))
   225  		}
   226  		return true
   227  	})
   228  	want := []string{
   229  		"File FuncDecl BlockStmt ExprStmt CallExpr",
   230  		"File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
   231  		"File FuncDecl BlockStmt ExprStmt CallExpr",
   232  		"File FuncDecl BlockStmt ExprStmt CallExpr BasicLit",
   233  	}
   234  	if !reflect.DeepEqual(got, want) {
   235  		t.Errorf("inspect: got %s, want %s", got, want)
   236  	}
   237  }
   238  
   239  func typeOf(n ast.Node) string {
   240  	return strings.TrimPrefix(reflect.TypeOf(n).String(), "*ast.")
   241  }
   242  
// The numbers show a marginal improvement (ASTInspect/Inspect) of 3.5x,
// but a break-even point (NewInspector/(ASTInspect-Inspect)) of about 5
// traversals.
//
// BenchmarkNewInspector   4.5  ms
// BenchmarkInspect        0.33 ms
// BenchmarkASTInspect     1.2  ms
   250  
   251  func BenchmarkNewInspector(b *testing.B) {
   252  	// Measure one-time construction overhead.
   253  	for i := 0; i < b.N; i++ {
   254  		inspector.New(netFiles)
   255  	}
   256  }
   257  
   258  func BenchmarkInspect(b *testing.B) {
   259  	b.StopTimer()
   260  	inspect := inspector.New(netFiles)
   261  	b.StartTimer()
   262  
   263  	// Measure marginal cost of traversal.
   264  	var ndecls, nlits int
   265  	for i := 0; i < b.N; i++ {
   266  		inspect.Preorder(nil, func(n ast.Node) {
   267  			switch n.(type) {
   268  			case *ast.FuncDecl:
   269  				ndecls++
   270  			case *ast.FuncLit:
   271  				nlits++
   272  			}
   273  		})
   274  	}
   275  }
   276  
   277  func BenchmarkASTInspect(b *testing.B) {
   278  	var ndecls, nlits int
   279  	for i := 0; i < b.N; i++ {
   280  		for _, f := range netFiles {
   281  			ast.Inspect(f, func(n ast.Node) bool {
   282  				switch n.(type) {
   283  				case *ast.FuncDecl:
   284  					ndecls++
   285  				case *ast.FuncLit:
   286  					nlits++
   287  				}
   288  				return true
   289  			})
   290  		}
   291  	}
   292  }