// Source: github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/go/ast/astutil/rewrite.go

// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package astutil

import (
	"fmt"
	"go/ast"
	"reflect"
	"sort"

	"golang.org/x/tools/internal/typeparams"
)

// An ApplyFunc is invoked by Apply for each node n, even if n is nil,
// before and/or after the node's children, using a Cursor describing
// the current node and providing operations on it.
//
// The return value of ApplyFunc controls the syntax tree traversal.
// See Apply for details.
//
// Apply reuses a single Cursor value for every node it visits, so the
// *Cursor passed to an ApplyFunc should not be retained after the call
// returns; it will later describe a different node.
type ApplyFunc func(*Cursor) bool

// Apply traverses a syntax tree recursively, starting with root,
// and calling pre and post for each node as described below.
// Apply returns the syntax tree, possibly modified.
//
// If pre is not nil, it is called for each node before the node's
// children are traversed (pre-order). If pre returns false, no
// children are traversed, and post is not called for that node.
//
// If post is not nil, and a prior call of pre didn't return false,
// post is called for each node after its children are traversed
// (post-order). If post returns false, traversal is terminated and
// Apply returns immediately.
//
// Only fields that refer to AST nodes are considered children;
// i.e., token.Pos, Scopes, Objects, and fields of basic types
// (strings, etc.) are ignored.
//
// Children are traversed in the order in which they appear in the
// respective node's struct definition. A package's files are
// traversed in sorted (byte-wise lexical) order of their file names.
//
// Panics raised by pre or post propagate to Apply's caller; only
// Apply's own internal termination signal is recovered.
44 func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) { 45 parent := &struct{ ast.Node }{root} 46 defer func() { 47 if r := recover(); r != nil && r != abort { 48 panic(r) 49 } 50 result = parent.Node 51 }() 52 a := &application{pre: pre, post: post} 53 a.apply(parent, "Node", nil, root) 54 return 55 } 56 57 var abort = new(int) // singleton, to signal termination of Apply 58 59 // A Cursor describes a node encountered during Apply. 60 // Information about the node and its parent is available 61 // from the Node, Parent, Name, and Index methods. 62 // 63 // If p is a variable of type and value of the current parent node 64 // c.Parent(), and f is the field identifier with name c.Name(), 65 // the following invariants hold: 66 // 67 // p.f == c.Node() if c.Index() < 0 68 // p.f[c.Index()] == c.Node() if c.Index() >= 0 69 // 70 // The methods Replace, Delete, InsertBefore, and InsertAfter 71 // can be used to change the AST without disrupting Apply. 72 type Cursor struct { 73 parent ast.Node 74 name string 75 iter *iterator // valid if non-nil 76 node ast.Node 77 } 78 79 // Node returns the current Node. 80 func (c *Cursor) Node() ast.Node { return c.node } 81 82 // Parent returns the parent of the current Node. 83 func (c *Cursor) Parent() ast.Node { return c.parent } 84 85 // Name returns the name of the parent Node field that contains the current Node. 86 // If the parent is a *ast.Package and the current Node is a *ast.File, Name returns 87 // the filename for the current Node. 88 func (c *Cursor) Name() string { return c.name } 89 90 // Index reports the index >= 0 of the current Node in the slice of Nodes that 91 // contains it, or a value < 0 if the current Node is not part of a slice. 92 // The index of the current node changes if InsertBefore is called while 93 // processing the current node. 
94 func (c *Cursor) Index() int { 95 if c.iter != nil { 96 return c.iter.index 97 } 98 return -1 99 } 100 101 // field returns the current node's parent field value. 102 func (c *Cursor) field() reflect.Value { 103 return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name) 104 } 105 106 // Replace replaces the current Node with n. 107 // The replacement node is not walked by Apply. 108 func (c *Cursor) Replace(n ast.Node) { 109 if _, ok := c.node.(*ast.File); ok { 110 file, ok := n.(*ast.File) 111 if !ok { 112 panic("attempt to replace *ast.File with non-*ast.File") 113 } 114 c.parent.(*ast.Package).Files[c.name] = file 115 return 116 } 117 118 v := c.field() 119 if i := c.Index(); i >= 0 { 120 v = v.Index(i) 121 } 122 v.Set(reflect.ValueOf(n)) 123 } 124 125 // Delete deletes the current Node from its containing slice. 126 // If the current Node is not part of a slice, Delete panics. 127 // As a special case, if the current node is a package file, 128 // Delete removes it from the package's Files map. 129 func (c *Cursor) Delete() { 130 if _, ok := c.node.(*ast.File); ok { 131 delete(c.parent.(*ast.Package).Files, c.name) 132 return 133 } 134 135 i := c.Index() 136 if i < 0 { 137 panic("Delete node not contained in slice") 138 } 139 v := c.field() 140 l := v.Len() 141 reflect.Copy(v.Slice(i, l), v.Slice(i+1, l)) 142 v.Index(l - 1).Set(reflect.Zero(v.Type().Elem())) 143 v.SetLen(l - 1) 144 c.iter.step-- 145 } 146 147 // InsertAfter inserts n after the current Node in its containing slice. 148 // If the current Node is not part of a slice, InsertAfter panics. 149 // Apply does not walk n. 
func (c *Cursor) InsertAfter(n ast.Node) {
	i := c.Index()
	if i < 0 {
		panic("InsertAfter node not contained in slice")
	}
	v := c.field()
	// Grow the slice by one zero element, then shift everything after the
	// current node one slot to the right to open slot i+1 for n.
	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
	l := v.Len()
	reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
	v.Index(i + 1).Set(reflect.ValueOf(n))
	// The list walker advances by step after this visit; bumping it makes
	// the walk jump over n, so n is not visited.
	c.iter.step++
}

// InsertBefore inserts n before the current Node in its containing slice.
// If the current Node is not part of a slice, InsertBefore panics.
// Apply will not walk n.
func (c *Cursor) InsertBefore(n ast.Node) {
	i := c.Index()
	if i < 0 {
		panic("InsertBefore node not contained in slice")
	}
	v := c.field()
	// Grow the slice by one, then shift the current node and its
	// successors one slot right to open slot i for n.
	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
	l := v.Len()
	reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
	v.Index(i).Set(reflect.ValueOf(n))
	// The current node now lives at i+1; keep Index() pointing at it so
	// the walker resumes with the node that originally followed it.
	c.iter.index++
}

// application carries all the shared data so we can pass it around cheaply.
type application struct {
	pre, post ApplyFunc // callbacks passed to Apply; either may be nil
	cursor    Cursor    // single cursor reused for every visit; apply saves and restores it
	iter      iterator  // single iterator reused by applyList; saved and restored per list
}

// apply visits node n, which is held in field name of parent (at position
// iter.index of that field's slice when iter is non-nil). It calls pre,
// walks n's children in go/ast struct-field order unless pre returned
// false, and then calls post. If post returns false it panics with abort,
// which Apply recovers to end the traversal.
func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
	// convert typed nil into untyped nil
	if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
		n = nil
	}

	// avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
	saved := a.cursor
	a.cursor.parent = parent
	a.cursor.name = name
	a.cursor.iter = iter
	a.cursor.node = n

	// pre vetoed this node: skip its children and post, and hand the
	// caller back its own cursor state.
	if a.pre != nil && !a.pre(&a.cursor) {
		a.cursor = saved
		return
	}

	// walk children
	// (the order of the cases matches the order of the corresponding node types in go/ast)
	// The switch inspects the local n captured before pre ran, so a node
	// installed by Cursor.Replace during pre is not walked.
	switch n := n.(type) {
	case nil:
		// nothing to do

	// Comments and fields
	case *ast.Comment:
		// nothing to do

	case *ast.CommentGroup:
		// NOTE(review): typed nils were already converted to nil above, so
		// this guard appears redundant; kept as-is.
		if n != nil {
			a.applyList(n, "List")
		}

	case *ast.Field:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.FieldList:
		a.applyList(n, "List")

	// Expressions
	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
		// nothing to do

	case *ast.Ellipsis:
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.FuncLit:
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CompositeLit:
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Elts")

	case *ast.ParenExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.SelectorExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Sel", nil, n.Sel)

	case *ast.IndexExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Index", nil, n.Index)

	case *typeparams.IndexListExpr:
		a.apply(n, "X", nil, n.X)
		a.applyList(n, "Indices")

	case *ast.SliceExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Low", nil, n.Low)
		a.apply(n, "High", nil, n.High)
		a.apply(n, "Max", nil, n.Max)

	case *ast.TypeAssertExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Type", nil, n.Type)

	case *ast.CallExpr:
		a.apply(n, "Fun", nil, n.Fun)
		a.applyList(n, "Args")

	case *ast.StarExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.UnaryExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.BinaryExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Y", nil, n.Y)

	case *ast.KeyValueExpr:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	// Types
	case *ast.ArrayType:
		a.apply(n, "Len", nil, n.Len)
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.StructType:
		a.apply(n, "Fields", nil, n.Fields)

	case *ast.FuncType:
		// Type parameters are reached through the typeparams helper and
		// are only visited when present.
		if tparams := typeparams.ForFuncType(n); tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Params", nil, n.Params)
		a.apply(n, "Results", nil, n.Results)

	case *ast.InterfaceType:
		a.apply(n, "Methods", nil, n.Methods)

	case *ast.MapType:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	case *ast.ChanType:
		a.apply(n, "Value", nil, n.Value)

	// Statements
	case *ast.BadStmt:
		// nothing to do

	case *ast.DeclStmt:
		a.apply(n, "Decl", nil, n.Decl)

	case *ast.EmptyStmt:
		// nothing to do

	case *ast.LabeledStmt:
		a.apply(n, "Label", nil, n.Label)
		a.apply(n, "Stmt", nil, n.Stmt)

	case *ast.ExprStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.SendStmt:
		a.apply(n, "Chan", nil, n.Chan)
		a.apply(n, "Value", nil, n.Value)

	case *ast.IncDecStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.AssignStmt:
		a.applyList(n, "Lhs")
		a.applyList(n, "Rhs")

	case *ast.GoStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.DeferStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.ReturnStmt:
		a.applyList(n, "Results")

	case *ast.BranchStmt:
		a.apply(n, "Label", nil, n.Label)

	case *ast.BlockStmt:
		a.applyList(n, "List")

	case *ast.IfStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Body", nil, n.Body)
		a.apply(n, "Else", nil, n.Else)

	case *ast.CaseClause:
		a.applyList(n, "List")
		a.applyList(n, "Body")

	case *ast.SwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Body", nil, n.Body)

	case *ast.TypeSwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Assign", nil, n.Assign)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CommClause:
		a.apply(n, "Comm", nil, n.Comm)
		a.applyList(n, "Body")

	case *ast.SelectStmt:
		a.apply(n, "Body", nil, n.Body)

	case *ast.ForStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Post", nil, n.Post)
		a.apply(n, "Body", nil, n.Body)

	case *ast.RangeStmt:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Body", nil, n.Body)

	// Declarations
	case *ast.ImportSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Path", nil, n.Path)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.ValueSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Values")
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.TypeSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		if tparams := typeparams.ForTypeSpec(n); tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.BadDecl:
		// nothing to do

	case *ast.GenDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Specs")

	case *ast.FuncDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Recv", nil, n.Recv)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	// Files and packages
	case *ast.File:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.applyList(n, "Decls")
		// Don't walk n.Comments; they have either been walked already if
		// they are Doc comments, or they can be easily walked explicitly.

	case *ast.Package:
		// collect and sort names for reproducible behavior
		// (files are passed with the file name as the cursor's field name
		// and no iterator, so they report Index() < 0)
		var names []string
		for name := range n.Files {
			names = append(names, name)
		}
		sort.Strings(names)
		for _, name := range names {
			a.apply(n, name, nil, n.Files[name])
		}

	default:
		panic(fmt.Sprintf("Apply: unexpected node type %T", n))
	}

	// Each child visit restored a.cursor on return, so post sees this
	// node's cursor. A false result unwinds the whole walk via abort.
	if a.post != nil && !a.post(&a.cursor) {
		panic(abort)
	}

	// Restore the caller's cursor so the parent's own post call sees it.
	a.cursor = saved
}

// An iterator controls iteration over a slice of nodes.
462 type iterator struct { 463 index, step int 464 } 465 466 func (a *application) applyList(parent ast.Node, name string) { 467 // avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead 468 saved := a.iter 469 a.iter.index = 0 470 for { 471 // must reload parent.name each time, since cursor modifications might change it 472 v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name) 473 if a.iter.index >= v.Len() { 474 break 475 } 476 477 // element x may be nil in a bad AST - be cautious 478 var x ast.Node 479 if e := v.Index(a.iter.index); e.IsValid() { 480 x = e.Interface().(ast.Node) 481 } 482 483 a.iter.step = 1 484 a.apply(parent, name, &a.iter, x) 485 a.iter.index += a.iter.step 486 } 487 a.iter = saved 488 }