github.com/joomcode/cue@v0.4.4-0.20221111115225-539fe3512047/cue/ast/astutil/apply.go (about)

     1  // Copyright 2018 The CUE Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package astutil
    16  
    17  import (
    18  	"encoding/hex"
    19  	"fmt"
    20  	"hash/fnv"
    21  	"reflect"
    22  
    23  	"github.com/joomcode/cue/cue/ast"
    24  )
    25  
    26  // A Cursor describes a node encountered during Apply.
    27  // Information about the node and its parent is available
    28  // from the Node, Parent, and Index methods.
    29  //
    30  // The methods Replace, Delete, InsertBefore, and InsertAfter
    31  // can be used to change the AST without disrupting Apply.
    32  // Delete, InsertBefore, and InsertAfter are only defined for modifying
    33  // a StructLit and will panic in any other context.
    34  type Cursor interface {
    35  	// Node returns the current Node.
    36  	Node() ast.Node
    37  
    38  	// Parent returns the parent of the current Node.
    39  	Parent() Cursor
    40  
    41  	// Index reports the index >= 0 of the current Node in the slice of Nodes
    42  	// that contains it, or a value < 0 if the current Node is not part of a
    43  	// list.
    44  	Index() int
    45  
    46  	// Import reports an opaque identifier that refers to the given package. It
    47  	// may only be called if the input to apply was an ast.File. If the import
    48  	// does not exist, it will be added.
    49  	Import(path string) *ast.Ident
    50  
    51  	// Replace replaces the current Node with n.
    52  	// The replacement node is not walked by Apply. Comments of the old node
    53  	// are copied to the new node if it has not yet an comments associated
    54  	// with it.
    55  	Replace(n ast.Node)
    56  
    57  	// Delete deletes the current Node from its containing struct.
    58  	// If the current Node is not part of a struct, Delete panics.
    59  	Delete()
    60  
    61  	// InsertAfter inserts n after the current Node in its containing struct.
    62  	// If the current Node is not part of a struct, InsertAfter panics.
    63  	// Unless n is wrapped by ApplyRecursively, Apply does not walk n.
    64  	InsertAfter(n ast.Node)
    65  
    66  	// InsertBefore inserts n before the current Node in its containing struct.
    67  	// If the current Node is not part of a struct, InsertBefore panics.
    68  	// Unless n is wrapped by ApplyRecursively, Apply does not walk n.
    69  	InsertBefore(n ast.Node)
    70  
    71  	self() *cursor
    72  }
    73  
    74  // ApplyRecursively indicates that a node inserted with InsertBefore,
    75  // or InsertAfter should be processed recursively.
    76  func ApplyRecursively(n ast.Node) ast.Node {
    77  	return recursive{n}
    78  }
    79  
    80  type recursive struct {
    81  	ast.Node
    82  }
    83  
    84  type info struct {
    85  	f       *ast.File
    86  	current *declsCursor
    87  
    88  	importPatch []*ast.Ident
    89  }
    90  
    91  type cursor struct {
    92  	file     *info
    93  	parent   Cursor
    94  	node     ast.Node
    95  	typ      interface{} // the type of the node
    96  	index    int         // position of any of the sub types.
    97  	replaced bool
    98  }
    99  
   100  func newCursor(parent Cursor, n ast.Node, typ interface{}) *cursor {
   101  	return &cursor{
   102  		parent: parent,
   103  		typ:    typ,
   104  		node:   n,
   105  		index:  -1,
   106  	}
   107  }
   108  
   109  func fileInfo(c Cursor) (info *info) {
   110  	for ; c != nil; c = c.Parent() {
   111  		if i := c.self().file; i != nil {
   112  			return i
   113  		}
   114  	}
   115  	return nil
   116  }
   117  
   118  func (c *cursor) self() *cursor  { return c }
   119  func (c *cursor) Parent() Cursor { return c.parent }
   120  func (c *cursor) Index() int     { return c.index }
   121  func (c *cursor) Node() ast.Node { return c.node }
   122  
   123  func (c *cursor) Import(importPath string) *ast.Ident {
   124  	info := fileInfo(c)
   125  	if info == nil {
   126  		return nil
   127  	}
   128  
   129  	name := ImportPathName(importPath)
   130  
   131  	// TODO: come up with something much better.
   132  	// For instance, hoist the uniquer form cue/export.go to
   133  	// here and make export.go use this.
   134  	hash := fnv.New32()
   135  	name += hex.EncodeToString(hash.Sum([]byte(importPath)))[:6]
   136  
   137  	spec := insertImport(&info.current.decls, &ast.ImportSpec{
   138  		Name: ast.NewIdent(name),
   139  		Path: ast.NewString(importPath),
   140  	})
   141  
   142  	ident := &ast.Ident{Node: spec} // Name is set later.
   143  	info.importPatch = append(info.importPatch, ident)
   144  
   145  	ident.Name = name
   146  
   147  	return ident
   148  }
   149  
   150  func (c *cursor) Replace(n ast.Node) {
   151  	// panic if the value cannot convert to the original type.
   152  	reflect.ValueOf(n).Convert(reflect.TypeOf(c.typ).Elem())
   153  	if ast.Comments(n) != nil {
   154  		CopyComments(n, c.node)
   155  	}
   156  	if r, ok := n.(recursive); ok {
   157  		n = r.Node
   158  	} else {
   159  		c.replaced = true
   160  	}
   161  	c.node = n
   162  }
   163  
   164  func (c *cursor) InsertAfter(n ast.Node)  { panic("unsupported") }
   165  func (c *cursor) InsertBefore(n ast.Node) { panic("unsupported") }
   166  func (c *cursor) Delete()                 { panic("unsupported") }
   167  
   168  // Apply traverses a syntax tree recursively, starting with root,
   169  // and calling pre and post for each node as described below.
   170  // Apply returns the syntax tree, possibly modified.
   171  //
   172  // If pre is not nil, it is called for each node before the node's
   173  // children are traversed (pre-order). If pre returns false, no
   174  // children are traversed, and post is not called for that node.
   175  //
   176  // If post is not nil, and a prior call of pre didn't return false,
   177  // post is called for each node after its children are traversed
   178  // (post-order). If post returns false, traversal is terminated and
   179  // Apply returns immediately.
   180  //
   181  // Only fields that refer to AST nodes are considered children;
   182  // i.e., token.Pos, Scopes, Objects, and fields of basic types
   183  // (strings, etc.) are ignored.
   184  //
   185  // Children are traversed in the order in which they appear in the
   186  // respective node's struct definition.
   187  //
   188  func Apply(node ast.Node, before, after func(Cursor) bool) ast.Node {
   189  	apply(&applier{before: before, after: after}, nil, &node)
   190  	return node
   191  }
   192  
   193  // A applyVisitor's before method is invoked for each node encountered by Walk.
   194  // If the result applyVisitor w is true, Walk visits each of the children
   195  // of node with the applyVisitor w, followed by a call of w.After.
   196  type applyVisitor interface {
   197  	Before(Cursor) applyVisitor
   198  	After(Cursor) bool
   199  }
   200  
   201  // Helper functions for common node lists. They may be empty.
   202  
   203  func applyExprList(v applyVisitor, parent Cursor, ptr interface{}, list []ast.Expr) {
   204  	c := newCursor(parent, nil, nil)
   205  	for i, x := range list {
   206  		c.index = i
   207  		c.node = x
   208  		c.typ = &list[i]
   209  		applyCursor(v, c)
   210  		if x != c.node {
   211  			list[i] = c.node.(ast.Expr)
   212  		}
   213  	}
   214  }
   215  
   216  type declsCursor struct {
   217  	*cursor
   218  	decls, after, process []ast.Decl
   219  	delete                bool
   220  }
   221  
   222  func (c *declsCursor) InsertAfter(n ast.Node) {
   223  	if r, ok := n.(recursive); ok {
   224  		n = r.Node
   225  		c.process = append(c.process, n.(ast.Decl))
   226  	}
   227  	c.after = append(c.after, n.(ast.Decl))
   228  }
   229  
   230  func (c *declsCursor) InsertBefore(n ast.Node) {
   231  	if r, ok := n.(recursive); ok {
   232  		n = r.Node
   233  		c.process = append(c.process, n.(ast.Decl))
   234  	}
   235  	c.decls = append(c.decls, n.(ast.Decl))
   236  }
   237  
   238  func (c *declsCursor) Delete() { c.delete = true }
   239  
// applyDeclList walks the declarations of a struct or file with v and
// returns the resulting declaration list.
//
// All declarations share one declsCursor, which collects the output. Each
// visited declaration (possibly replaced) is appended unless the visitor
// called Delete. Nodes passed to InsertBefore are appended ahead of it, and
// nodes passed to InsertAfter are appended after it. Inserted nodes wrapped
// by ApplyRecursively are walked right after the declaration during whose
// visit they were inserted. Deleting such a node panics.
//
// If parent is an ast.File, the cursor also carries the file info used by
// Cursor.Import, which inserts import specs into the list being built here.
//
// NOTE(review): the cursor's index is never set here, so Index() reports -1
// for declarations. Also, a Replace made while walking an inserted
// (ApplyRecursively) node is not written back to the output list. Confirm
// whether both are intended.
func applyDeclList(v applyVisitor, parent Cursor, list []ast.Decl) []ast.Decl {
	c := &declsCursor{
		cursor: newCursor(parent, nil, nil),
		decls:  make([]ast.Decl, 0, len(list)),
	}
	// Only the declaration list of a file enables Cursor.Import.
	if file, ok := parent.Node().(*ast.File); ok {
		c.cursor.file = &info{f: file, current: c}
	}
	for i, x := range list {
		c.node = x
		c.typ = &list[i]
		applyCursor(v, c)
		// Keep the (possibly replaced) declaration unless Delete was called.
		if !c.delete {
			c.decls = append(c.decls, c.node.(ast.Decl))
		}
		c.delete = false
		// Walk nodes inserted with ApplyRecursively. Callbacks may queue more
		// such nodes, so the length is re-read on every iteration.
		for i := 0; i < len(c.process); i++ {
			x := c.process[i]
			c.node = x
			c.typ = &c.process[i]
			applyCursor(v, c)
			if c.delete {
				panic("cannot delete a node that was added with InsertBefore or InsertAfter")
			}
		}
		// Nodes inserted with InsertAfter follow the current declaration. The
		// scratch slices are truncated so their storage is reused.
		c.decls = append(c.decls, c.after...)
		c.after = c.after[:0]
		c.process = c.process[:0]
	}

	// TODO: ultimately, programmatically linked nodes have to be resolved
	// at the end.
	// if info := c.cursor.file; info != nil {
	// 	done := map[*ast.ImportSpec]bool{}
	// 	for _, ident := range info.importPatch {
	// 		spec := ident.Node.(*ast.ImportSpec)
	// 		if done[spec] {
	// 			continue
	// 		}
	// 		done[spec] = true

	// 		path, _ := strconv.Unquote(spec.Path)

	// 		ident.Name =
	// 	}
	// }

	return c.decls
}
   289  
   290  func apply(v applyVisitor, parent Cursor, nodePtr interface{}) {
   291  	res := reflect.Indirect(reflect.ValueOf(nodePtr))
   292  	n := res.Interface()
   293  	node := n.(ast.Node)
   294  	c := newCursor(parent, node, nodePtr)
   295  	applyCursor(v, c)
   296  	if node != c.node {
   297  		res.Set(reflect.ValueOf(c.node))
   298  	}
   299  }
   300  
   301  // applyCursor traverses an AST in depth-first order: It starts by calling
   302  // v.Visit(node); node must not be nil. If the visitor w returned by
   303  // v.Visit(node) is not nil, apply is invoked recursively with visitor
   304  // w for each of the non-nil children of node, followed by a call of
   305  // w.Visit(nil).
   306  //
   307  func applyCursor(v applyVisitor, c Cursor) {
   308  	if v = v.Before(c); v == nil {
   309  		return
   310  	}
   311  
   312  	node := c.Node()
   313  
   314  	// TODO: record the comment groups and interleave with the values like for
   315  	// parsing and printing?
   316  	comments := node.Comments()
   317  	for _, cm := range comments {
   318  		apply(v, c, &cm)
   319  	}
   320  
   321  	// apply children
   322  	// (the order of the cases matches the order
   323  	// of the corresponding node types in go)
   324  	switch n := node.(type) {
   325  	// Comments and fields
   326  	case *ast.Comment:
   327  		// nothing to do
   328  
   329  	case *ast.CommentGroup:
   330  		for _, cg := range n.List {
   331  			apply(v, c, &cg)
   332  		}
   333  
   334  	case *ast.Attribute:
   335  		// nothing to do
   336  
   337  	case *ast.Field:
   338  		apply(v, c, &n.Label)
   339  		if n.Value != nil {
   340  			apply(v, c, &n.Value)
   341  		}
   342  		for _, a := range n.Attrs {
   343  			apply(v, c, &a)
   344  		}
   345  
   346  	case *ast.StructLit:
   347  		n.Elts = applyDeclList(v, c, n.Elts)
   348  
   349  	// Expressions
   350  	case *ast.BottomLit, *ast.BadExpr, *ast.Ident, *ast.BasicLit:
   351  		// nothing to do
   352  
   353  	case *ast.Interpolation:
   354  		applyExprList(v, c, &n, n.Elts)
   355  
   356  	case *ast.ListLit:
   357  		applyExprList(v, c, &n, n.Elts)
   358  
   359  	case *ast.Ellipsis:
   360  		if n.Type != nil {
   361  			apply(v, c, &n.Type)
   362  		}
   363  
   364  	case *ast.ParenExpr:
   365  		apply(v, c, &n.X)
   366  
   367  	case *ast.SelectorExpr:
   368  		apply(v, c, &n.X)
   369  		apply(v, c, &n.Sel)
   370  
   371  	case *ast.IndexExpr:
   372  		apply(v, c, &n.X)
   373  		apply(v, c, &n.Index)
   374  
   375  	case *ast.SliceExpr:
   376  		apply(v, c, &n.X)
   377  		if n.Low != nil {
   378  			apply(v, c, &n.Low)
   379  		}
   380  		if n.High != nil {
   381  			apply(v, c, &n.High)
   382  		}
   383  
   384  	case *ast.CallExpr:
   385  		apply(v, c, &n.Fun)
   386  		applyExprList(v, c, &n, n.Args)
   387  
   388  	case *ast.UnaryExpr:
   389  		apply(v, c, &n.X)
   390  
   391  	case *ast.BinaryExpr:
   392  		apply(v, c, &n.X)
   393  		apply(v, c, &n.Y)
   394  
   395  	// Declarations
   396  	case *ast.ImportSpec:
   397  		if n.Name != nil {
   398  			apply(v, c, &n.Name)
   399  		}
   400  		apply(v, c, &n.Path)
   401  
   402  	case *ast.BadDecl:
   403  		// nothing to do
   404  
   405  	case *ast.ImportDecl:
   406  		for _, s := range n.Specs {
   407  			apply(v, c, &s)
   408  		}
   409  
   410  	case *ast.EmbedDecl:
   411  		apply(v, c, &n.Expr)
   412  
   413  	case *ast.LetClause:
   414  		apply(v, c, &n.Ident)
   415  		apply(v, c, &n.Expr)
   416  
   417  	case *ast.Alias:
   418  		apply(v, c, &n.Ident)
   419  		apply(v, c, &n.Expr)
   420  
   421  	case *ast.Comprehension:
   422  		clauses := n.Clauses
   423  		for i := range n.Clauses {
   424  			apply(v, c, &clauses[i])
   425  		}
   426  		apply(v, c, &n.Value)
   427  
   428  	// Files and packages
   429  	case *ast.File:
   430  		n.Decls = applyDeclList(v, c, n.Decls)
   431  
   432  	case *ast.Package:
   433  		apply(v, c, &n.Name)
   434  
   435  	case *ast.ForClause:
   436  		if n.Key != nil {
   437  			apply(v, c, &n.Key)
   438  		}
   439  		apply(v, c, &n.Value)
   440  		apply(v, c, &n.Source)
   441  
   442  	case *ast.IfClause:
   443  		apply(v, c, &n.Condition)
   444  
   445  	default:
   446  		panic(fmt.Sprintf("Walk: unexpected node type %T", n))
   447  	}
   448  
   449  	v.After(c)
   450  }
   451  
   452  type applier struct {
   453  	before func(Cursor) bool
   454  	after  func(Cursor) bool
   455  
   456  	commentStack []commentFrame
   457  	current      commentFrame
   458  }
   459  
   460  type commentFrame struct {
   461  	cg  []*ast.CommentGroup
   462  	pos int8
   463  }
   464  
   465  func (f *applier) Before(c Cursor) applyVisitor {
   466  	node := c.Node()
   467  	if f.before == nil || (f.before(c) && node == c.Node()) {
   468  		f.commentStack = append(f.commentStack, f.current)
   469  		f.current = commentFrame{cg: node.Comments()}
   470  		f.visitComments(c, f.current.pos)
   471  		return f
   472  	}
   473  	return nil
   474  }
   475  
   476  func (f *applier) After(c Cursor) bool {
   477  	f.visitComments(c, 127)
   478  	p := len(f.commentStack) - 1
   479  	f.current = f.commentStack[p]
   480  	f.commentStack = f.commentStack[:p]
   481  	f.current.pos++
   482  	if f.after != nil {
   483  		f.after(c)
   484  	}
   485  	return true
   486  }
   487  
   488  func (f *applier) visitComments(p Cursor, pos int8) {
   489  	c := &f.current
   490  	for i := 0; i < len(c.cg); i++ {
   491  		cg := c.cg[i]
   492  		if cg.Position == pos {
   493  			continue
   494  		}
   495  		cursor := newCursor(p, cg, cg)
   496  		if f.before == nil || (f.before(cursor) && !cursor.replaced) {
   497  			for j, c := range cg.List {
   498  				cursor := newCursor(p, c, &c)
   499  				if f.before == nil || (f.before(cursor) && !cursor.replaced) {
   500  					if f.after != nil {
   501  						f.after(cursor)
   502  					}
   503  				}
   504  				cg.List[j] = cursor.node.(*ast.Comment)
   505  			}
   506  			if f.after != nil {
   507  				f.after(cursor)
   508  			}
   509  		}
   510  		c.cg[i] = cursor.node.(*ast.CommentGroup)
   511  	}
   512  }