github.com/jhump/protoreflect@v1.16.0/desc/protoparse/ast/walk.go (about) 1 package ast 2 3 // VisitFunc is used to examine a node in the AST when walking the tree. 4 // It returns true or false as to whether or not the descendants of the 5 // given node should be visited. If it returns true, the node's children 6 // will be visisted; if false, they will not. When returning true, it 7 // can also return a new VisitFunc to use for the children. If it returns 8 // (true, nil), then the current function will be re-used when visiting 9 // the children. 10 // 11 // See also the Visitor type. 12 type VisitFunc func(Node) (bool, VisitFunc) 13 14 // Walk conducts a walk of the AST rooted at the given root using the 15 // given function. It performs a "pre-order traversal", visiting a 16 // given AST node before it visits that node's descendants. 17 func Walk(root Node, v VisitFunc) { 18 ok, next := v(root) 19 if !ok { 20 return 21 } 22 if next != nil { 23 v = next 24 } 25 if comp, ok := root.(CompositeNode); ok { 26 for _, child := range comp.Children() { 27 Walk(child, v) 28 } 29 } 30 } 31 32 // Visitor provides a technique for walking the AST that allows for 33 // dynamic dispatch, where a particular function is invoked based on 34 // the runtime type of the argument. 35 // 36 // It consists of a number of functions, each of which matches a 37 // concrete Node type. It also includes functions for sub-interfaces 38 // of Node and the Node interface itself, to be used as broader 39 // "catch all" functions. 40 // 41 // To use a visitor, provide a function for the node types of 42 // interest and pass visitor.Visit as the function to a Walk operation. 43 // When a node is traversed, the corresponding function field of 44 // the visitor is invoked, if not nil. If the function for a node's 45 // concrete type is nil/absent but the function for an interface it 46 // implements is present, that interface visit function will be used 47 // instead. If no matching function is present, the traversal will 48 // continue. If a matching function is present, it will be invoked 49 // and its response determines how the traversal proceeds. 50 // 51 // Every visit function returns (bool, *Visitor). If the bool returned 52 // is false, the visited node's descendants are skipped. Otherwise, 53 // traversal will continue into the node's children. If the returned 54 // visitor is nil, the current visitor will continue to be used. But 55 // if a non-nil visitor is returned, it will be used to visit the 56 // node's children. 57 type Visitor struct { 58 // VisitFileNode is invoked when visiting a *FileNode in the AST. 59 VisitFileNode func(*FileNode) (bool, *Visitor) 60 // VisitSyntaxNode is invoked when visiting a *SyntaxNode in the AST. 61 VisitSyntaxNode func(*SyntaxNode) (bool, *Visitor) 62 63 // TODO: add VisitEditionNode 64 65 // VisitPackageNode is invoked when visiting a *PackageNode in the AST. 66 VisitPackageNode func(*PackageNode) (bool, *Visitor) 67 // VisitImportNode is invoked when visiting an *ImportNode in the AST. 68 VisitImportNode func(*ImportNode) (bool, *Visitor) 69 // VisitOptionNode is invoked when visiting an *OptionNode in the AST. 70 VisitOptionNode func(*OptionNode) (bool, *Visitor) 71 // VisitOptionNameNode is invoked when visiting an *OptionNameNode in the AST. 72 VisitOptionNameNode func(*OptionNameNode) (bool, *Visitor) 73 // VisitFieldReferenceNode is invoked when visiting a *FieldReferenceNode in the AST. 74 VisitFieldReferenceNode func(*FieldReferenceNode) (bool, *Visitor) 75 // VisitCompactOptionsNode is invoked when visiting a *CompactOptionsNode in the AST. 76 VisitCompactOptionsNode func(*CompactOptionsNode) (bool, *Visitor) 77 // VisitMessageNode is invoked when visiting a *MessageNode in the AST. 78 VisitMessageNode func(*MessageNode) (bool, *Visitor) 79 // VisitExtendNode is invoked when visiting an *ExtendNode in the AST. 80 VisitExtendNode func(*ExtendNode) (bool, *Visitor) 81 // VisitExtensionRangeNode is invoked when visiting an *ExtensionRangeNode in the AST. 82 VisitExtensionRangeNode func(*ExtensionRangeNode) (bool, *Visitor) 83 // VisitReservedNode is invoked when visiting a *ReservedNode in the AST. 84 VisitReservedNode func(*ReservedNode) (bool, *Visitor) 85 // VisitRangeNode is invoked when visiting a *RangeNode in the AST. 86 VisitRangeNode func(*RangeNode) (bool, *Visitor) 87 // VisitFieldNode is invoked when visiting a *FieldNode in the AST. 88 VisitFieldNode func(*FieldNode) (bool, *Visitor) 89 // VisitGroupNode is invoked when visiting a *GroupNode in the AST. 90 VisitGroupNode func(*GroupNode) (bool, *Visitor) 91 // VisitMapFieldNode is invoked when visiting a *MapFieldNode in the AST. 92 VisitMapFieldNode func(*MapFieldNode) (bool, *Visitor) 93 // VisitMapTypeNode is invoked when visiting a *MapTypeNode in the AST. 94 VisitMapTypeNode func(*MapTypeNode) (bool, *Visitor) 95 // VisitOneOfNode is invoked when visiting a *OneOfNode in the AST. 96 VisitOneOfNode func(*OneOfNode) (bool, *Visitor) 97 // VisitEnumNode is invoked when visiting an *EnumNode in the AST. 98 VisitEnumNode func(*EnumNode) (bool, *Visitor) 99 // VisitEnumValueNode is invoked when visiting an *EnumValueNode in the AST. 100 VisitEnumValueNode func(*EnumValueNode) (bool, *Visitor) 101 // VisitServiceNode is invoked when visiting a *ServiceNode in the AST. 102 VisitServiceNode func(*ServiceNode) (bool, *Visitor) 103 // VisitRPCNode is invoked when visiting an *RPCNode in the AST. 104 VisitRPCNode func(*RPCNode) (bool, *Visitor) 105 // VisitRPCTypeNode is invoked when visiting an *RPCTypeNode in the AST. 106 VisitRPCTypeNode func(*RPCTypeNode) (bool, *Visitor) 107 // VisitIdentNode is invoked when visiting an *IdentNode in the AST. 108 VisitIdentNode func(*IdentNode) (bool, *Visitor) 109 // VisitCompoundIdentNode is invoked when visiting a *CompoundIdentNode in the AST. 110 VisitCompoundIdentNode func(*CompoundIdentNode) (bool, *Visitor) 111 // VisitStringLiteralNode is invoked when visiting a *StringLiteralNode in the AST. 112 VisitStringLiteralNode func(*StringLiteralNode) (bool, *Visitor) 113 // VisitCompoundStringLiteralNode is invoked when visiting a *CompoundStringLiteralNode in the AST. 114 VisitCompoundStringLiteralNode func(*CompoundStringLiteralNode) (bool, *Visitor) 115 // VisitUintLiteralNode is invoked when visiting a *UintLiteralNode in the AST. 116 VisitUintLiteralNode func(*UintLiteralNode) (bool, *Visitor) 117 // VisitPositiveUintLiteralNode is invoked when visiting a *PositiveUintLiteralNode in the AST. 118 // 119 // Deprecated: this node type will not actually be present in an AST. 120 VisitPositiveUintLiteralNode func(*PositiveUintLiteralNode) (bool, *Visitor) 121 // VisitNegativeIntLiteralNode is invoked when visiting a *NegativeIntLiteralNode in the AST. 122 VisitNegativeIntLiteralNode func(*NegativeIntLiteralNode) (bool, *Visitor) 123 // VisitFloatLiteralNode is invoked when visiting a *FloatLiteralNode in the AST. 124 VisitFloatLiteralNode func(*FloatLiteralNode) (bool, *Visitor) 125 // VisitSpecialFloatLiteralNode is invoked when visiting a *SpecialFloatLiteralNode in the AST. 126 VisitSpecialFloatLiteralNode func(*SpecialFloatLiteralNode) (bool, *Visitor) 127 // VisitSignedFloatLiteralNode is invoked when visiting a *SignedFloatLiteralNode in the AST. 128 VisitSignedFloatLiteralNode func(*SignedFloatLiteralNode) (bool, *Visitor) 129 // VisitBoolLiteralNode is invoked when visiting a *BoolLiteralNode in the AST. 130 VisitBoolLiteralNode func(*BoolLiteralNode) (bool, *Visitor) 131 // VisitArrayLiteralNode is invoked when visiting an *ArrayLiteralNode in the AST. 132 VisitArrayLiteralNode func(*ArrayLiteralNode) (bool, *Visitor) 133 // VisitMessageLiteralNode is invoked when visiting a *MessageLiteralNode in the AST. 134 VisitMessageLiteralNode func(*MessageLiteralNode) (bool, *Visitor) 135 // VisitMessageFieldNode is invoked when visiting a *MessageFieldNode in the AST. 136 VisitMessageFieldNode func(*MessageFieldNode) (bool, *Visitor) 137 // VisitKeywordNode is invoked when visiting a *KeywordNode in the AST. 138 VisitKeywordNode func(*KeywordNode) (bool, *Visitor) 139 // VisitRuneNode is invoked when visiting a *RuneNode in the AST. 140 VisitRuneNode func(*RuneNode) (bool, *Visitor) 141 // VisitEmptyDeclNode is invoked when visiting a *EmptyDeclNode in the AST. 142 VisitEmptyDeclNode func(*EmptyDeclNode) (bool, *Visitor) 143 144 // VisitFieldDeclNode is invoked when visiting a FieldDeclNode in the AST. 145 // This function is used when no concrete type function is provided. If 146 // both this and VisitMessageDeclNode are provided, and a node implements 147 // both (such as *GroupNode and *MapFieldNode), this function will be 148 // invoked and not the other. 149 VisitFieldDeclNode func(FieldDeclNode) (bool, *Visitor) 150 // VisitMessageDeclNode is invoked when visiting a MessageDeclNode in the AST. 151 // This function is used when no concrete type function is provided. 152 VisitMessageDeclNode func(MessageDeclNode) (bool, *Visitor) 153 154 // VisitIdentValueNode is invoked when visiting an IdentValueNode in the AST. 155 // This function is used when no concrete type function is provided. 156 VisitIdentValueNode func(IdentValueNode) (bool, *Visitor) 157 // VisitStringValueNode is invoked when visiting a StringValueNode in the AST. 158 // This function is used when no concrete type function is provided. 159 VisitStringValueNode func(StringValueNode) (bool, *Visitor) 160 // VisitIntValueNode is invoked when visiting an IntValueNode in the AST. 161 // This function is used when no concrete type function is provided. If 162 // both this and VisitFloatValueNode are provided, and a node implements 163 // both (such as *UintLiteralNode), this function will be invoked and 164 // not the other. 165 VisitIntValueNode func(IntValueNode) (bool, *Visitor) 166 // VisitFloatValueNode is invoked when visiting a FloatValueNode in the AST. 167 // This function is used when no concrete type function is provided. 168 VisitFloatValueNode func(FloatValueNode) (bool, *Visitor) 169 // VisitValueNode is invoked when visiting a ValueNode in the AST. This 170 // function is used when no concrete type function is provided and no 171 // more specific ValueNode function is provided that matches the node. 172 VisitValueNode func(ValueNode) (bool, *Visitor) 173 174 // VisitTerminalNode is invoked when visiting a TerminalNode in the AST. 175 // This function is used when no concrete type function is provided 176 // no more specific interface type function is provided. 177 VisitTerminalNode func(TerminalNode) (bool, *Visitor) 178 // VisitCompositeNode is invoked when visiting a CompositeNode in the AST. 179 // This function is used when no concrete type function is provided 180 // no more specific interface type function is provided. 181 VisitCompositeNode func(CompositeNode) (bool, *Visitor) 182 // VisitNode is invoked when visiting a Node in the AST. This 183 // function is only used when no other more specific function is 184 // provided. 185 VisitNode func(Node) (bool, *Visitor) 186 } 187 188 // Visit provides the Visitor's implementation of VisitFunc, to be 189 // used with Walk operations. 190 func (v *Visitor) Visit(n Node) (bool, VisitFunc) { 191 var ok, matched bool 192 var next *Visitor 193 switch n := n.(type) { 194 case *FileNode: 195 if v.VisitFileNode != nil { 196 matched = true 197 ok, next = v.VisitFileNode(n) 198 } 199 case *SyntaxNode: 200 if v.VisitSyntaxNode != nil { 201 matched = true 202 ok, next = v.VisitSyntaxNode(n) 203 } 204 case *PackageNode: 205 if v.VisitPackageNode != nil { 206 matched = true 207 ok, next = v.VisitPackageNode(n) 208 } 209 case *ImportNode: 210 if v.VisitImportNode != nil { 211 matched = true 212 ok, next = v.VisitImportNode(n) 213 } 214 case *OptionNode: 215 if v.VisitOptionNode != nil { 216 matched = true 217 ok, next = v.VisitOptionNode(n) 218 } 219 case *OptionNameNode: 220 if v.VisitOptionNameNode != nil { 221 matched = true 222 ok, next = v.VisitOptionNameNode(n) 223 } 224 case *FieldReferenceNode: 225 if v.VisitFieldReferenceNode != nil { 226 matched = true 227 ok, next = v.VisitFieldReferenceNode(n) 228 } 229 case *CompactOptionsNode: 230 if v.VisitCompactOptionsNode != nil { 231 matched = true 232 ok, next = v.VisitCompactOptionsNode(n) 233 } 234 case *MessageNode: 235 if v.VisitMessageNode != nil { 236 matched = true 237 ok, next = v.VisitMessageNode(n) 238 } 239 case *ExtendNode: 240 if v.VisitExtendNode != nil { 241 matched = true 242 ok, next = v.VisitExtendNode(n) 243 } 244 case *ExtensionRangeNode: 245 if v.VisitExtensionRangeNode != nil { 246 matched = true 247 ok, next = v.VisitExtensionRangeNode(n) 248 } 249 case *ReservedNode: 250 if v.VisitReservedNode != nil { 251 matched = true 252 ok, next = v.VisitReservedNode(n) 253 } 254 case *RangeNode: 255 if v.VisitRangeNode != nil { 256 matched = true 257 ok, next = v.VisitRangeNode(n) 258 } 259 case *FieldNode: 260 if v.VisitFieldNode != nil { 261 matched = true 262 ok, next = v.VisitFieldNode(n) 263 } 264 case *GroupNode: 265 if v.VisitGroupNode != nil { 266 matched = true 267 ok, next = v.VisitGroupNode(n) 268 } 269 case *MapFieldNode: 270 if v.VisitMapFieldNode != nil { 271 matched = true 272 ok, next = v.VisitMapFieldNode(n) 273 } 274 case *MapTypeNode: 275 if v.VisitMapTypeNode != nil { 276 matched = true 277 ok, next = v.VisitMapTypeNode(n) 278 } 279 case *OneOfNode: 280 if v.VisitOneOfNode != nil { 281 matched = true 282 ok, next = v.VisitOneOfNode(n) 283 } 284 case *EnumNode: 285 if v.VisitEnumNode != nil { 286 matched = true 287 ok, next = v.VisitEnumNode(n) 288 } 289 case *EnumValueNode: 290 if v.VisitEnumValueNode != nil { 291 matched = true 292 ok, next = v.VisitEnumValueNode(n) 293 } 294 case *ServiceNode: 295 if v.VisitServiceNode != nil { 296 matched = true 297 ok, next = v.VisitServiceNode(n) 298 } 299 case *RPCNode: 300 if v.VisitRPCNode != nil { 301 matched = true 302 ok, next = v.VisitRPCNode(n) 303 } 304 case *RPCTypeNode: 305 if v.VisitRPCTypeNode != nil { 306 matched = true 307 ok, next = v.VisitRPCTypeNode(n) 308 } 309 case *IdentNode: 310 if v.VisitIdentNode != nil { 311 matched = true 312 ok, next = v.VisitIdentNode(n) 313 } 314 case *CompoundIdentNode: 315 if v.VisitCompoundIdentNode != nil { 316 matched = true 317 ok, next = v.VisitCompoundIdentNode(n) 318 } 319 case *StringLiteralNode: 320 if v.VisitStringLiteralNode != nil { 321 matched = true 322 ok, next = v.VisitStringLiteralNode(n) 323 } 324 case *CompoundStringLiteralNode: 325 if v.VisitCompoundStringLiteralNode != nil { 326 matched = true 327 ok, next = v.VisitCompoundStringLiteralNode(n) 328 } 329 case *UintLiteralNode: 330 if v.VisitUintLiteralNode != nil { 331 matched = true 332 ok, next = v.VisitUintLiteralNode(n) 333 } 334 case *PositiveUintLiteralNode: 335 if v.VisitPositiveUintLiteralNode != nil { 336 matched = true 337 ok, next = v.VisitPositiveUintLiteralNode(n) 338 } 339 case *NegativeIntLiteralNode: 340 if v.VisitNegativeIntLiteralNode != nil { 341 matched = true 342 ok, next = v.VisitNegativeIntLiteralNode(n) 343 } 344 case *FloatLiteralNode: 345 if v.VisitFloatLiteralNode != nil { 346 matched = true 347 ok, next = v.VisitFloatLiteralNode(n) 348 } 349 case *SpecialFloatLiteralNode: 350 if v.VisitSpecialFloatLiteralNode != nil { 351 matched = true 352 ok, next = v.VisitSpecialFloatLiteralNode(n) 353 } 354 case *SignedFloatLiteralNode: 355 if v.VisitSignedFloatLiteralNode != nil { 356 matched = true 357 ok, next = v.VisitSignedFloatLiteralNode(n) 358 } 359 case *BoolLiteralNode: 360 if v.VisitBoolLiteralNode != nil { 361 matched = true 362 ok, next = v.VisitBoolLiteralNode(n) 363 } 364 case *ArrayLiteralNode: 365 if v.VisitArrayLiteralNode != nil { 366 matched = true 367 ok, next = v.VisitArrayLiteralNode(n) 368 } 369 case *MessageLiteralNode: 370 if v.VisitMessageLiteralNode != nil { 371 matched = true 372 ok, next = v.VisitMessageLiteralNode(n) 373 } 374 case *MessageFieldNode: 375 if v.VisitMessageFieldNode != nil { 376 matched = true 377 ok, next = v.VisitMessageFieldNode(n) 378 } 379 case *KeywordNode: 380 if v.VisitKeywordNode != nil { 381 matched = true 382 ok, next = v.VisitKeywordNode(n) 383 } 384 case *RuneNode: 385 if v.VisitRuneNode != nil { 386 matched = true 387 ok, next = v.VisitRuneNode(n) 388 } 389 case *EmptyDeclNode: 390 if v.VisitEmptyDeclNode != nil { 391 matched = true 392 ok, next = v.VisitEmptyDeclNode(n) 393 } 394 } 395 396 if !matched { 397 // Visitor provided no concrete type visit function, so 398 // check interface types. We do this in several passes 399 // to provide "priority" for matched interfaces for nodes 400 // that actually implement more than one interface. 401 // 402 // For example, StringLiteralNode implements both 403 // StringValueNode and ValueNode. Both cases could match 404 // so the first case is what would match. So if we want 405 // to test against either, they need to be in different 406 // switch statements. 407 switch n := n.(type) { 408 case FieldDeclNode: 409 if v.VisitFieldDeclNode != nil { 410 matched = true 411 ok, next = v.VisitFieldDeclNode(n) 412 } 413 case IdentValueNode: 414 if v.VisitIdentValueNode != nil { 415 matched = true 416 ok, next = v.VisitIdentValueNode(n) 417 } 418 case StringValueNode: 419 if v.VisitStringValueNode != nil { 420 matched = true 421 ok, next = v.VisitStringValueNode(n) 422 } 423 case IntValueNode: 424 if v.VisitIntValueNode != nil { 425 matched = true 426 ok, next = v.VisitIntValueNode(n) 427 } 428 } 429 } 430 431 if !matched { 432 // These two are excluded from the above switch so that 433 // if visitor provides both VisitIntValueNode and 434 // VisitFloatValueNode, we'll prefer VisitIntValueNode 435 // for *UintLiteralNode (which implements both). Similarly, 436 // that way we prefer VisitFieldDeclNode over 437 // VisitMessageDeclNode when visiting a *GroupNode. 438 switch n := n.(type) { 439 case FloatValueNode: 440 if v.VisitFloatValueNode != nil { 441 matched = true 442 ok, next = v.VisitFloatValueNode(n) 443 } 444 case MessageDeclNode: 445 if v.VisitMessageDeclNode != nil { 446 matched = true 447 ok, next = v.VisitMessageDeclNode(n) 448 } 449 } 450 } 451 452 if !matched { 453 switch n := n.(type) { 454 case ValueNode: 455 if v.VisitValueNode != nil { 456 matched = true 457 ok, next = v.VisitValueNode(n) 458 } 459 } 460 } 461 462 if !matched { 463 switch n := n.(type) { 464 case TerminalNode: 465 if v.VisitTerminalNode != nil { 466 matched = true 467 ok, next = v.VisitTerminalNode(n) 468 } 469 case CompositeNode: 470 if v.VisitCompositeNode != nil { 471 matched = true 472 ok, next = v.VisitCompositeNode(n) 473 } 474 } 475 } 476 477 if !matched { 478 // finally, fallback to most generic visit function 479 if v.VisitNode != nil { 480 matched = true 481 ok, next = v.VisitNode(n) 482 } 483 } 484 485 if !matched { 486 // keep descending with the current visitor 487 return true, nil 488 } 489 490 if !ok { 491 return false, nil 492 } 493 if next != nil { 494 return true, next.Visit 495 } 496 return true, v.Visit 497 }