github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/go/ast/astutil/rewrite_test.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package astutil_test
     6  
     7  import (
     8  	"bytes"
     9  	"go/ast"
    10  	"go/format"
    11  	"go/parser"
    12  	"go/token"
    13  	"testing"
    14  
    15  	"github.com/powerman/golang-tools/go/ast/astutil"
    16  	"github.com/powerman/golang-tools/internal/typeparams"
    17  )
    18  
    19  type rewriteTest struct {
    20  	name       string
    21  	orig, want string
    22  	pre, post  astutil.ApplyFunc
    23  }
    24  
    25  var rewriteTests = []rewriteTest{
    26  	{name: "nop", orig: "package p\n", want: "package p\n"},
    27  
    28  	{name: "replace",
    29  		orig: `package p
    30  
    31  var x int
    32  `,
    33  		want: `package p
    34  
    35  var t T
    36  `,
    37  		post: func(c *astutil.Cursor) bool {
    38  			if _, ok := c.Node().(*ast.ValueSpec); ok {
    39  				c.Replace(valspec("t", "T"))
    40  				return false
    41  			}
    42  			return true
    43  		},
    44  	},
    45  
    46  	{name: "set doc strings",
    47  		orig: `package p
    48  
    49  const z = 0
    50  
    51  type T struct{}
    52  
    53  var x int
    54  `,
    55  		want: `package p
    56  // a foo is a foo
    57  const z = 0
    58  // a foo is a foo
    59  type T struct{}
    60  // a foo is a foo
    61  var x int
    62  `,
    63  		post: func(c *astutil.Cursor) bool {
    64  			if _, ok := c.Parent().(*ast.GenDecl); ok && c.Name() == "Doc" && c.Node() == nil {
    65  				c.Replace(&ast.CommentGroup{List: []*ast.Comment{{Text: "// a foo is a foo"}}})
    66  			}
    67  			return true
    68  		},
    69  	},
    70  
    71  	{name: "insert names",
    72  		orig: `package p
    73  
    74  const a = 1
    75  `,
    76  		want: `package p
    77  
    78  const a, b, c = 1, 2, 3
    79  `,
    80  		pre: func(c *astutil.Cursor) bool {
    81  			if _, ok := c.Parent().(*ast.ValueSpec); ok {
    82  				switch c.Name() {
    83  				case "Names":
    84  					c.InsertAfter(ast.NewIdent("c"))
    85  					c.InsertAfter(ast.NewIdent("b"))
    86  				case "Values":
    87  					c.InsertAfter(&ast.BasicLit{Kind: token.INT, Value: "3"})
    88  					c.InsertAfter(&ast.BasicLit{Kind: token.INT, Value: "2"})
    89  				}
    90  			}
    91  			return true
    92  		},
    93  	},
    94  
    95  	{name: "insert",
    96  		orig: `package p
    97  
    98  var (
    99  	x int
   100  	y int
   101  )
   102  `,
   103  		want: `package p
   104  
   105  var before1 int
   106  var before2 int
   107  
   108  var (
   109  	x int
   110  	y int
   111  )
   112  var after2 int
   113  var after1 int
   114  `,
   115  		pre: func(c *astutil.Cursor) bool {
   116  			if _, ok := c.Node().(*ast.GenDecl); ok {
   117  				c.InsertBefore(vardecl("before1", "int"))
   118  				c.InsertAfter(vardecl("after1", "int"))
   119  				c.InsertAfter(vardecl("after2", "int"))
   120  				c.InsertBefore(vardecl("before2", "int"))
   121  			}
   122  			return true
   123  		},
   124  	},
   125  
   126  	{name: "delete",
   127  		orig: `package p
   128  
   129  var x int
   130  var y int
   131  var z int
   132  `,
   133  		want: `package p
   134  
   135  var y int
   136  var z int
   137  `,
   138  		pre: func(c *astutil.Cursor) bool {
   139  			n := c.Node()
   140  			if d, ok := n.(*ast.GenDecl); ok && d.Specs[0].(*ast.ValueSpec).Names[0].Name == "x" {
   141  				c.Delete()
   142  			}
   143  			return true
   144  		},
   145  	},
   146  
   147  	{name: "insertafter-delete",
   148  		orig: `package p
   149  
   150  var x int
   151  var y int
   152  var z int
   153  `,
   154  		want: `package p
   155  
   156  var x1 int
   157  
   158  var y int
   159  var z int
   160  `,
   161  		pre: func(c *astutil.Cursor) bool {
   162  			n := c.Node()
   163  			if d, ok := n.(*ast.GenDecl); ok && d.Specs[0].(*ast.ValueSpec).Names[0].Name == "x" {
   164  				c.InsertAfter(vardecl("x1", "int"))
   165  				c.Delete()
   166  			}
   167  			return true
   168  		},
   169  	},
   170  
   171  	{name: "delete-insertafter",
   172  		orig: `package p
   173  
   174  var x int
   175  var y int
   176  var z int
   177  `,
   178  		want: `package p
   179  
   180  var y int
   181  var x1 int
   182  var z int
   183  `,
   184  		pre: func(c *astutil.Cursor) bool {
   185  			n := c.Node()
   186  			if d, ok := n.(*ast.GenDecl); ok && d.Specs[0].(*ast.ValueSpec).Names[0].Name == "x" {
   187  				c.Delete()
   188  				// The cursor is now effectively atop the 'var y int' node.
   189  				c.InsertAfter(vardecl("x1", "int"))
   190  			}
   191  			return true
   192  		},
   193  	},
   194  }
   195  
   196  func init() {
   197  	if typeparams.Enabled {
   198  		rewriteTests = append(rewriteTests, rewriteTest{
   199  			name: "replace",
   200  			orig: `package p
   201  
   202  type T[P1, P2 any] int
   203  
   204  type R T[int, string]
   205  `,
   206  			want: `package p
   207  
   208  type S[P1, P2 any] int32
   209  
   210  type R S[int32, string]
   211  `,
   212  			post: func(c *astutil.Cursor) bool {
   213  				if ident, ok := c.Node().(*ast.Ident); ok {
   214  					if ident.Name == "int" {
   215  						c.Replace(ast.NewIdent("int32"))
   216  					}
   217  					if ident.Name == "T" {
   218  						c.Replace(ast.NewIdent("S"))
   219  					}
   220  				}
   221  				return true
   222  			},
   223  		})
   224  	}
   225  }
   226  
   227  func valspec(name, typ string) *ast.ValueSpec {
   228  	return &ast.ValueSpec{Names: []*ast.Ident{ast.NewIdent(name)},
   229  		Type: ast.NewIdent(typ),
   230  	}
   231  }
   232  
   233  func vardecl(name, typ string) *ast.GenDecl {
   234  	return &ast.GenDecl{
   235  		Tok:   token.VAR,
   236  		Specs: []ast.Spec{valspec(name, typ)},
   237  	}
   238  }
   239  
   240  func TestRewrite(t *testing.T) {
   241  	t.Run("*", func(t *testing.T) {
   242  		for _, test := range rewriteTests {
   243  			test := test
   244  			t.Run(test.name, func(t *testing.T) {
   245  				t.Parallel()
   246  				fset := token.NewFileSet()
   247  				f, err := parser.ParseFile(fset, test.name, test.orig, parser.ParseComments)
   248  				if err != nil {
   249  					t.Fatal(err)
   250  				}
   251  				n := astutil.Apply(f, test.pre, test.post)
   252  				var buf bytes.Buffer
   253  				if err := format.Node(&buf, fset, n); err != nil {
   254  					t.Fatal(err)
   255  				}
   256  				got := buf.String()
   257  				if got != test.want {
   258  					t.Errorf("got:\n\n%s\nwant:\n\n%s\n", got, test.want)
   259  				}
   260  			})
   261  		}
   262  	})
   263  }
   264  
   265  var sink ast.Node
   266  
   267  func BenchmarkRewrite(b *testing.B) {
   268  	for _, test := range rewriteTests {
   269  		b.Run(test.name, func(b *testing.B) {
   270  			for i := 0; i < b.N; i++ {
   271  				b.StopTimer()
   272  				fset := token.NewFileSet()
   273  				f, err := parser.ParseFile(fset, test.name, test.orig, parser.ParseComments)
   274  				if err != nil {
   275  					b.Fatal(err)
   276  				}
   277  				b.StartTimer()
   278  				sink = astutil.Apply(f, test.pre, test.post)
   279  			}
   280  		})
   281  	}
   282  }