github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/go/ast/astutil/rewrite.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package astutil
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"reflect"
    11  	"sort"
    12  
    13  	"github.com/powerman/golang-tools/internal/typeparams"
    14  )
    15  
    16  // An ApplyFunc is invoked by Apply for each node n, even if n is nil,
    17  // before and/or after the node's children, using a Cursor describing
    18  // the current node and providing operations on it.
    19  //
    20  // The return value of ApplyFunc controls the syntax tree traversal.
    21  // See Apply for details.
    22  type ApplyFunc func(*Cursor) bool
    23  
    24  // Apply traverses a syntax tree recursively, starting with root,
    25  // and calling pre and post for each node as described below.
    26  // Apply returns the syntax tree, possibly modified.
    27  //
    28  // If pre is not nil, it is called for each node before the node's
    29  // children are traversed (pre-order). If pre returns false, no
    30  // children are traversed, and post is not called for that node.
    31  //
    32  // If post is not nil, and a prior call of pre didn't return false,
    33  // post is called for each node after its children are traversed
    34  // (post-order). If post returns false, traversal is terminated and
    35  // Apply returns immediately.
    36  //
    37  // Only fields that refer to AST nodes are considered children;
    38  // i.e., token.Pos, Scopes, Objects, and fields of basic types
    39  // (strings, etc.) are ignored.
    40  //
    41  // Children are traversed in the order in which they appear in the
    42  // respective node's struct definition. A package's files are
    43  // traversed in the filenames' alphabetical order.
    44  //
    45  func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
    46  	parent := &struct{ ast.Node }{root}
    47  	defer func() {
    48  		if r := recover(); r != nil && r != abort {
    49  			panic(r)
    50  		}
    51  		result = parent.Node
    52  	}()
    53  	a := &application{pre: pre, post: post}
    54  	a.apply(parent, "Node", nil, root)
    55  	return
    56  }
    57  
    58  var abort = new(int) // singleton, to signal termination of Apply
    59  
    60  // A Cursor describes a node encountered during Apply.
    61  // Information about the node and its parent is available
    62  // from the Node, Parent, Name, and Index methods.
    63  //
    64  // If p is a variable of type and value of the current parent node
    65  // c.Parent(), and f is the field identifier with name c.Name(),
    66  // the following invariants hold:
    67  //
    68  //   p.f            == c.Node()  if c.Index() <  0
    69  //   p.f[c.Index()] == c.Node()  if c.Index() >= 0
    70  //
    71  // The methods Replace, Delete, InsertBefore, and InsertAfter
    72  // can be used to change the AST without disrupting Apply.
    73  type Cursor struct {
    74  	parent ast.Node
    75  	name   string
    76  	iter   *iterator // valid if non-nil
    77  	node   ast.Node
    78  }
    79  
    80  // Node returns the current Node.
    81  func (c *Cursor) Node() ast.Node { return c.node }
    82  
    83  // Parent returns the parent of the current Node.
    84  func (c *Cursor) Parent() ast.Node { return c.parent }
    85  
    86  // Name returns the name of the parent Node field that contains the current Node.
    87  // If the parent is a *ast.Package and the current Node is a *ast.File, Name returns
    88  // the filename for the current Node.
    89  func (c *Cursor) Name() string { return c.name }
    90  
    91  // Index reports the index >= 0 of the current Node in the slice of Nodes that
    92  // contains it, or a value < 0 if the current Node is not part of a slice.
    93  // The index of the current node changes if InsertBefore is called while
    94  // processing the current node.
    95  func (c *Cursor) Index() int {
    96  	if c.iter != nil {
    97  		return c.iter.index
    98  	}
    99  	return -1
   100  }
   101  
   102  // field returns the current node's parent field value.
   103  func (c *Cursor) field() reflect.Value {
   104  	return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
   105  }
   106  
   107  // Replace replaces the current Node with n.
   108  // The replacement node is not walked by Apply.
   109  func (c *Cursor) Replace(n ast.Node) {
   110  	if _, ok := c.node.(*ast.File); ok {
   111  		file, ok := n.(*ast.File)
   112  		if !ok {
   113  			panic("attempt to replace *ast.File with non-*ast.File")
   114  		}
   115  		c.parent.(*ast.Package).Files[c.name] = file
   116  		return
   117  	}
   118  
   119  	v := c.field()
   120  	if i := c.Index(); i >= 0 {
   121  		v = v.Index(i)
   122  	}
   123  	v.Set(reflect.ValueOf(n))
   124  }
   125  
   126  // Delete deletes the current Node from its containing slice.
   127  // If the current Node is not part of a slice, Delete panics.
   128  // As a special case, if the current node is a package file,
   129  // Delete removes it from the package's Files map.
   130  func (c *Cursor) Delete() {
   131  	if _, ok := c.node.(*ast.File); ok {
   132  		delete(c.parent.(*ast.Package).Files, c.name)
   133  		return
   134  	}
   135  
   136  	i := c.Index()
   137  	if i < 0 {
   138  		panic("Delete node not contained in slice")
   139  	}
   140  	v := c.field()
   141  	l := v.Len()
   142  	reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
   143  	v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
   144  	v.SetLen(l - 1)
   145  	c.iter.step--
   146  }
   147  
   148  // InsertAfter inserts n after the current Node in its containing slice.
   149  // If the current Node is not part of a slice, InsertAfter panics.
   150  // Apply does not walk n.
   151  func (c *Cursor) InsertAfter(n ast.Node) {
   152  	i := c.Index()
   153  	if i < 0 {
   154  		panic("InsertAfter node not contained in slice")
   155  	}
   156  	v := c.field()
   157  	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
   158  	l := v.Len()
   159  	reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
   160  	v.Index(i + 1).Set(reflect.ValueOf(n))
   161  	c.iter.step++
   162  }
   163  
   164  // InsertBefore inserts n before the current Node in its containing slice.
   165  // If the current Node is not part of a slice, InsertBefore panics.
   166  // Apply will not walk n.
   167  func (c *Cursor) InsertBefore(n ast.Node) {
   168  	i := c.Index()
   169  	if i < 0 {
   170  		panic("InsertBefore node not contained in slice")
   171  	}
   172  	v := c.field()
   173  	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
   174  	l := v.Len()
   175  	reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
   176  	v.Index(i).Set(reflect.ValueOf(n))
   177  	c.iter.index++
   178  }
   179  
   180  // application carries all the shared data so we can pass it around cheaply.
   181  type application struct {
   182  	pre, post ApplyFunc
   183  	cursor    Cursor
   184  	iter      iterator
   185  }
   186  
   187  func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
   188  	// convert typed nil into untyped nil
   189  	if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
   190  		n = nil
   191  	}
   192  
   193  	// avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
   194  	saved := a.cursor
   195  	a.cursor.parent = parent
   196  	a.cursor.name = name
   197  	a.cursor.iter = iter
   198  	a.cursor.node = n
   199  
   200  	if a.pre != nil && !a.pre(&a.cursor) {
   201  		a.cursor = saved
   202  		return
   203  	}
   204  
   205  	// walk children
   206  	// (the order of the cases matches the order of the corresponding node types in go/ast)
   207  	switch n := n.(type) {
   208  	case nil:
   209  		// nothing to do
   210  
   211  	// Comments and fields
   212  	case *ast.Comment:
   213  		// nothing to do
   214  
   215  	case *ast.CommentGroup:
   216  		if n != nil {
   217  			a.applyList(n, "List")
   218  		}
   219  
   220  	case *ast.Field:
   221  		a.apply(n, "Doc", nil, n.Doc)
   222  		a.applyList(n, "Names")
   223  		a.apply(n, "Type", nil, n.Type)
   224  		a.apply(n, "Tag", nil, n.Tag)
   225  		a.apply(n, "Comment", nil, n.Comment)
   226  
   227  	case *ast.FieldList:
   228  		a.applyList(n, "List")
   229  
   230  	// Expressions
   231  	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
   232  		// nothing to do
   233  
   234  	case *ast.Ellipsis:
   235  		a.apply(n, "Elt", nil, n.Elt)
   236  
   237  	case *ast.FuncLit:
   238  		a.apply(n, "Type", nil, n.Type)
   239  		a.apply(n, "Body", nil, n.Body)
   240  
   241  	case *ast.CompositeLit:
   242  		a.apply(n, "Type", nil, n.Type)
   243  		a.applyList(n, "Elts")
   244  
   245  	case *ast.ParenExpr:
   246  		a.apply(n, "X", nil, n.X)
   247  
   248  	case *ast.SelectorExpr:
   249  		a.apply(n, "X", nil, n.X)
   250  		a.apply(n, "Sel", nil, n.Sel)
   251  
   252  	case *ast.IndexExpr:
   253  		a.apply(n, "X", nil, n.X)
   254  		a.apply(n, "Index", nil, n.Index)
   255  
   256  	case *typeparams.IndexListExpr:
   257  		a.apply(n, "X", nil, n.X)
   258  		a.applyList(n, "Indices")
   259  
   260  	case *ast.SliceExpr:
   261  		a.apply(n, "X", nil, n.X)
   262  		a.apply(n, "Low", nil, n.Low)
   263  		a.apply(n, "High", nil, n.High)
   264  		a.apply(n, "Max", nil, n.Max)
   265  
   266  	case *ast.TypeAssertExpr:
   267  		a.apply(n, "X", nil, n.X)
   268  		a.apply(n, "Type", nil, n.Type)
   269  
   270  	case *ast.CallExpr:
   271  		a.apply(n, "Fun", nil, n.Fun)
   272  		a.applyList(n, "Args")
   273  
   274  	case *ast.StarExpr:
   275  		a.apply(n, "X", nil, n.X)
   276  
   277  	case *ast.UnaryExpr:
   278  		a.apply(n, "X", nil, n.X)
   279  
   280  	case *ast.BinaryExpr:
   281  		a.apply(n, "X", nil, n.X)
   282  		a.apply(n, "Y", nil, n.Y)
   283  
   284  	case *ast.KeyValueExpr:
   285  		a.apply(n, "Key", nil, n.Key)
   286  		a.apply(n, "Value", nil, n.Value)
   287  
   288  	// Types
   289  	case *ast.ArrayType:
   290  		a.apply(n, "Len", nil, n.Len)
   291  		a.apply(n, "Elt", nil, n.Elt)
   292  
   293  	case *ast.StructType:
   294  		a.apply(n, "Fields", nil, n.Fields)
   295  
   296  	case *ast.FuncType:
   297  		a.apply(n, "Params", nil, n.Params)
   298  		a.apply(n, "Results", nil, n.Results)
   299  
   300  	case *ast.InterfaceType:
   301  		a.apply(n, "Methods", nil, n.Methods)
   302  
   303  	case *ast.MapType:
   304  		a.apply(n, "Key", nil, n.Key)
   305  		a.apply(n, "Value", nil, n.Value)
   306  
   307  	case *ast.ChanType:
   308  		a.apply(n, "Value", nil, n.Value)
   309  
   310  	// Statements
   311  	case *ast.BadStmt:
   312  		// nothing to do
   313  
   314  	case *ast.DeclStmt:
   315  		a.apply(n, "Decl", nil, n.Decl)
   316  
   317  	case *ast.EmptyStmt:
   318  		// nothing to do
   319  
   320  	case *ast.LabeledStmt:
   321  		a.apply(n, "Label", nil, n.Label)
   322  		a.apply(n, "Stmt", nil, n.Stmt)
   323  
   324  	case *ast.ExprStmt:
   325  		a.apply(n, "X", nil, n.X)
   326  
   327  	case *ast.SendStmt:
   328  		a.apply(n, "Chan", nil, n.Chan)
   329  		a.apply(n, "Value", nil, n.Value)
   330  
   331  	case *ast.IncDecStmt:
   332  		a.apply(n, "X", nil, n.X)
   333  
   334  	case *ast.AssignStmt:
   335  		a.applyList(n, "Lhs")
   336  		a.applyList(n, "Rhs")
   337  
   338  	case *ast.GoStmt:
   339  		a.apply(n, "Call", nil, n.Call)
   340  
   341  	case *ast.DeferStmt:
   342  		a.apply(n, "Call", nil, n.Call)
   343  
   344  	case *ast.ReturnStmt:
   345  		a.applyList(n, "Results")
   346  
   347  	case *ast.BranchStmt:
   348  		a.apply(n, "Label", nil, n.Label)
   349  
   350  	case *ast.BlockStmt:
   351  		a.applyList(n, "List")
   352  
   353  	case *ast.IfStmt:
   354  		a.apply(n, "Init", nil, n.Init)
   355  		a.apply(n, "Cond", nil, n.Cond)
   356  		a.apply(n, "Body", nil, n.Body)
   357  		a.apply(n, "Else", nil, n.Else)
   358  
   359  	case *ast.CaseClause:
   360  		a.applyList(n, "List")
   361  		a.applyList(n, "Body")
   362  
   363  	case *ast.SwitchStmt:
   364  		a.apply(n, "Init", nil, n.Init)
   365  		a.apply(n, "Tag", nil, n.Tag)
   366  		a.apply(n, "Body", nil, n.Body)
   367  
   368  	case *ast.TypeSwitchStmt:
   369  		a.apply(n, "Init", nil, n.Init)
   370  		a.apply(n, "Assign", nil, n.Assign)
   371  		a.apply(n, "Body", nil, n.Body)
   372  
   373  	case *ast.CommClause:
   374  		a.apply(n, "Comm", nil, n.Comm)
   375  		a.applyList(n, "Body")
   376  
   377  	case *ast.SelectStmt:
   378  		a.apply(n, "Body", nil, n.Body)
   379  
   380  	case *ast.ForStmt:
   381  		a.apply(n, "Init", nil, n.Init)
   382  		a.apply(n, "Cond", nil, n.Cond)
   383  		a.apply(n, "Post", nil, n.Post)
   384  		a.apply(n, "Body", nil, n.Body)
   385  
   386  	case *ast.RangeStmt:
   387  		a.apply(n, "Key", nil, n.Key)
   388  		a.apply(n, "Value", nil, n.Value)
   389  		a.apply(n, "X", nil, n.X)
   390  		a.apply(n, "Body", nil, n.Body)
   391  
   392  	// Declarations
   393  	case *ast.ImportSpec:
   394  		a.apply(n, "Doc", nil, n.Doc)
   395  		a.apply(n, "Name", nil, n.Name)
   396  		a.apply(n, "Path", nil, n.Path)
   397  		a.apply(n, "Comment", nil, n.Comment)
   398  
   399  	case *ast.ValueSpec:
   400  		a.apply(n, "Doc", nil, n.Doc)
   401  		a.applyList(n, "Names")
   402  		a.apply(n, "Type", nil, n.Type)
   403  		a.applyList(n, "Values")
   404  		a.apply(n, "Comment", nil, n.Comment)
   405  
   406  	case *ast.TypeSpec:
   407  		a.apply(n, "Doc", nil, n.Doc)
   408  		a.apply(n, "Name", nil, n.Name)
   409  		a.apply(n, "Type", nil, n.Type)
   410  		a.apply(n, "Comment", nil, n.Comment)
   411  
   412  	case *ast.BadDecl:
   413  		// nothing to do
   414  
   415  	case *ast.GenDecl:
   416  		a.apply(n, "Doc", nil, n.Doc)
   417  		a.applyList(n, "Specs")
   418  
   419  	case *ast.FuncDecl:
   420  		a.apply(n, "Doc", nil, n.Doc)
   421  		a.apply(n, "Recv", nil, n.Recv)
   422  		a.apply(n, "Name", nil, n.Name)
   423  		a.apply(n, "Type", nil, n.Type)
   424  		a.apply(n, "Body", nil, n.Body)
   425  
   426  	// Files and packages
   427  	case *ast.File:
   428  		a.apply(n, "Doc", nil, n.Doc)
   429  		a.apply(n, "Name", nil, n.Name)
   430  		a.applyList(n, "Decls")
   431  		// Don't walk n.Comments; they have either been walked already if
   432  		// they are Doc comments, or they can be easily walked explicitly.
   433  
   434  	case *ast.Package:
   435  		// collect and sort names for reproducible behavior
   436  		var names []string
   437  		for name := range n.Files {
   438  			names = append(names, name)
   439  		}
   440  		sort.Strings(names)
   441  		for _, name := range names {
   442  			a.apply(n, name, nil, n.Files[name])
   443  		}
   444  
   445  	default:
   446  		panic(fmt.Sprintf("Apply: unexpected node type %T", n))
   447  	}
   448  
   449  	if a.post != nil && !a.post(&a.cursor) {
   450  		panic(abort)
   451  	}
   452  
   453  	a.cursor = saved
   454  }
   455  
   456  // An iterator controls iteration over a slice of nodes.
   457  type iterator struct {
   458  	index, step int
   459  }
   460  
   461  func (a *application) applyList(parent ast.Node, name string) {
   462  	// avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead
   463  	saved := a.iter
   464  	a.iter.index = 0
   465  	for {
   466  		// must reload parent.name each time, since cursor modifications might change it
   467  		v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
   468  		if a.iter.index >= v.Len() {
   469  			break
   470  		}
   471  
   472  		// element x may be nil in a bad AST - be cautious
   473  		var x ast.Node
   474  		if e := v.Index(a.iter.index); e.IsValid() {
   475  			x = e.Interface().(ast.Node)
   476  		}
   477  
   478  		a.iter.step = 1
   479  		a.apply(parent, name, &a.iter, x)
   480  		a.iter.index += a.iter.step
   481  	}
   482  	a.iter = saved
   483  }