github.com/jhump/protoreflect@v1.16.0/desc/protoparse/ast/walk.go (about)

     1  package ast
     2  
     3  // VisitFunc is used to examine a node in the AST when walking the tree.
     4  // It returns true or false as to whether or not the descendants of the
     5  // given node should be visited. If it returns true, the node's children
     6  // will be visisted; if false, they will not. When returning true, it
     7  // can also return a new VisitFunc to use for the children. If it returns
     8  // (true, nil), then the current function will be re-used when visiting
     9  // the children.
    10  //
    11  // See also the Visitor type.
    12  type VisitFunc func(Node) (bool, VisitFunc)
    13  
    14  // Walk conducts a walk of the AST rooted at the given root using the
    15  // given function. It performs a "pre-order traversal", visiting a
    16  // given AST node before it visits that node's descendants.
    17  func Walk(root Node, v VisitFunc) {
    18  	ok, next := v(root)
    19  	if !ok {
    20  		return
    21  	}
    22  	if next != nil {
    23  		v = next
    24  	}
    25  	if comp, ok := root.(CompositeNode); ok {
    26  		for _, child := range comp.Children() {
    27  			Walk(child, v)
    28  		}
    29  	}
    30  }
    31  
    32  // Visitor provides a technique for walking the AST that allows for
    33  // dynamic dispatch, where a particular function is invoked based on
    34  // the runtime type of the argument.
    35  //
    36  // It consists of a number of functions, each of which matches a
    37  // concrete Node type. It also includes functions for sub-interfaces
    38  // of Node and the Node interface itself, to be used as broader
    39  // "catch all" functions.
    40  //
    41  // To use a visitor, provide a function for the node types of
    42  // interest and pass visitor.Visit as the function to a Walk operation.
    43  // When a node is traversed, the corresponding function field of
    44  // the visitor is invoked, if not nil. If the function for a node's
    45  // concrete type is nil/absent but the function for an interface it
    46  // implements is present, that interface visit function will be used
    47  // instead. If no matching function is present, the traversal will
    48  // continue. If a matching function is present, it will be invoked
    49  // and its response determines how the traversal proceeds.
    50  //
    51  // Every visit function returns (bool, *Visitor). If the bool returned
    52  // is false, the visited node's descendants are skipped. Otherwise,
    53  // traversal will continue into the node's children. If the returned
    54  // visitor is nil, the current visitor will continue to be used. But
    55  // if a non-nil visitor is returned, it will be used to visit the
    56  // node's children.
    57  type Visitor struct {
    58  	// VisitFileNode is invoked when visiting a *FileNode in the AST.
    59  	VisitFileNode func(*FileNode) (bool, *Visitor)
    60  	// VisitSyntaxNode is invoked when visiting a *SyntaxNode in the AST.
    61  	VisitSyntaxNode func(*SyntaxNode) (bool, *Visitor)
    62  
    63  	// TODO: add VisitEditionNode
    64  
    65  	// VisitPackageNode is invoked when visiting a *PackageNode in the AST.
    66  	VisitPackageNode func(*PackageNode) (bool, *Visitor)
    67  	// VisitImportNode is invoked when visiting an *ImportNode in the AST.
    68  	VisitImportNode func(*ImportNode) (bool, *Visitor)
    69  	// VisitOptionNode is invoked when visiting an *OptionNode in the AST.
    70  	VisitOptionNode func(*OptionNode) (bool, *Visitor)
    71  	// VisitOptionNameNode is invoked when visiting an *OptionNameNode in the AST.
    72  	VisitOptionNameNode func(*OptionNameNode) (bool, *Visitor)
    73  	// VisitFieldReferenceNode is invoked when visiting a *FieldReferenceNode in the AST.
    74  	VisitFieldReferenceNode func(*FieldReferenceNode) (bool, *Visitor)
    75  	// VisitCompactOptionsNode is invoked when visiting a *CompactOptionsNode in the AST.
    76  	VisitCompactOptionsNode func(*CompactOptionsNode) (bool, *Visitor)
    77  	// VisitMessageNode is invoked when visiting a *MessageNode in the AST.
    78  	VisitMessageNode func(*MessageNode) (bool, *Visitor)
    79  	// VisitExtendNode is invoked when visiting an *ExtendNode in the AST.
    80  	VisitExtendNode func(*ExtendNode) (bool, *Visitor)
    81  	// VisitExtensionRangeNode is invoked when visiting an *ExtensionRangeNode in the AST.
    82  	VisitExtensionRangeNode func(*ExtensionRangeNode) (bool, *Visitor)
    83  	// VisitReservedNode is invoked when visiting a *ReservedNode in the AST.
    84  	VisitReservedNode func(*ReservedNode) (bool, *Visitor)
    85  	// VisitRangeNode is invoked when visiting a *RangeNode in the AST.
    86  	VisitRangeNode func(*RangeNode) (bool, *Visitor)
    87  	// VisitFieldNode is invoked when visiting a *FieldNode in the AST.
    88  	VisitFieldNode func(*FieldNode) (bool, *Visitor)
    89  	// VisitGroupNode is invoked when visiting a *GroupNode in the AST.
    90  	VisitGroupNode func(*GroupNode) (bool, *Visitor)
    91  	// VisitMapFieldNode is invoked when visiting a *MapFieldNode in the AST.
    92  	VisitMapFieldNode func(*MapFieldNode) (bool, *Visitor)
    93  	// VisitMapTypeNode is invoked when visiting a *MapTypeNode in the AST.
    94  	VisitMapTypeNode func(*MapTypeNode) (bool, *Visitor)
    95  	// VisitOneOfNode is invoked when visiting a *OneOfNode in the AST.
    96  	VisitOneOfNode func(*OneOfNode) (bool, *Visitor)
    97  	// VisitEnumNode is invoked when visiting an *EnumNode in the AST.
    98  	VisitEnumNode func(*EnumNode) (bool, *Visitor)
    99  	// VisitEnumValueNode is invoked when visiting an *EnumValueNode in the AST.
   100  	VisitEnumValueNode func(*EnumValueNode) (bool, *Visitor)
   101  	// VisitServiceNode is invoked when visiting a *ServiceNode in the AST.
   102  	VisitServiceNode func(*ServiceNode) (bool, *Visitor)
   103  	// VisitRPCNode is invoked when visiting an *RPCNode in the AST.
   104  	VisitRPCNode func(*RPCNode) (bool, *Visitor)
   105  	// VisitRPCTypeNode is invoked when visiting an *RPCTypeNode in the AST.
   106  	VisitRPCTypeNode func(*RPCTypeNode) (bool, *Visitor)
   107  	// VisitIdentNode is invoked when visiting an *IdentNode in the AST.
   108  	VisitIdentNode func(*IdentNode) (bool, *Visitor)
   109  	// VisitCompoundIdentNode is invoked when visiting a *CompoundIdentNode in the AST.
   110  	VisitCompoundIdentNode func(*CompoundIdentNode) (bool, *Visitor)
   111  	// VisitStringLiteralNode is invoked when visiting a *StringLiteralNode in the AST.
   112  	VisitStringLiteralNode func(*StringLiteralNode) (bool, *Visitor)
   113  	// VisitCompoundStringLiteralNode is invoked when visiting a *CompoundStringLiteralNode in the AST.
   114  	VisitCompoundStringLiteralNode func(*CompoundStringLiteralNode) (bool, *Visitor)
   115  	// VisitUintLiteralNode is invoked when visiting a *UintLiteralNode in the AST.
   116  	VisitUintLiteralNode func(*UintLiteralNode) (bool, *Visitor)
   117  	// VisitPositiveUintLiteralNode is invoked when visiting a *PositiveUintLiteralNode in the AST.
   118  	//
   119  	// Deprecated: this node type will not actually be present in an AST.
   120  	VisitPositiveUintLiteralNode func(*PositiveUintLiteralNode) (bool, *Visitor)
   121  	// VisitNegativeIntLiteralNode is invoked when visiting a *NegativeIntLiteralNode in the AST.
   122  	VisitNegativeIntLiteralNode func(*NegativeIntLiteralNode) (bool, *Visitor)
   123  	// VisitFloatLiteralNode is invoked when visiting a *FloatLiteralNode in the AST.
   124  	VisitFloatLiteralNode func(*FloatLiteralNode) (bool, *Visitor)
   125  	// VisitSpecialFloatLiteralNode is invoked when visiting a *SpecialFloatLiteralNode in the AST.
   126  	VisitSpecialFloatLiteralNode func(*SpecialFloatLiteralNode) (bool, *Visitor)
   127  	// VisitSignedFloatLiteralNode is invoked when visiting a *SignedFloatLiteralNode in the AST.
   128  	VisitSignedFloatLiteralNode func(*SignedFloatLiteralNode) (bool, *Visitor)
   129  	// VisitBoolLiteralNode is invoked when visiting a *BoolLiteralNode in the AST.
   130  	VisitBoolLiteralNode func(*BoolLiteralNode) (bool, *Visitor)
   131  	// VisitArrayLiteralNode is invoked when visiting an *ArrayLiteralNode in the AST.
   132  	VisitArrayLiteralNode func(*ArrayLiteralNode) (bool, *Visitor)
   133  	// VisitMessageLiteralNode is invoked when visiting a *MessageLiteralNode in the AST.
   134  	VisitMessageLiteralNode func(*MessageLiteralNode) (bool, *Visitor)
   135  	// VisitMessageFieldNode is invoked when visiting a *MessageFieldNode in the AST.
   136  	VisitMessageFieldNode func(*MessageFieldNode) (bool, *Visitor)
   137  	// VisitKeywordNode is invoked when visiting a *KeywordNode in the AST.
   138  	VisitKeywordNode func(*KeywordNode) (bool, *Visitor)
   139  	// VisitRuneNode is invoked when visiting a *RuneNode in the AST.
   140  	VisitRuneNode func(*RuneNode) (bool, *Visitor)
   141  	// VisitEmptyDeclNode is invoked when visiting a *EmptyDeclNode in the AST.
   142  	VisitEmptyDeclNode func(*EmptyDeclNode) (bool, *Visitor)
   143  
   144  	// VisitFieldDeclNode is invoked when visiting a FieldDeclNode in the AST.
   145  	// This function is used when no concrete type function is provided. If
   146  	// both this and VisitMessageDeclNode are provided, and a node implements
   147  	// both (such as *GroupNode and *MapFieldNode), this function will be
   148  	// invoked and not the other.
   149  	VisitFieldDeclNode func(FieldDeclNode) (bool, *Visitor)
   150  	// VisitMessageDeclNode is invoked when visiting a MessageDeclNode in the AST.
   151  	// This function is used when no concrete type function is provided.
   152  	VisitMessageDeclNode func(MessageDeclNode) (bool, *Visitor)
   153  
   154  	// VisitIdentValueNode is invoked when visiting an IdentValueNode in the AST.
   155  	// This function is used when no concrete type function is provided.
   156  	VisitIdentValueNode func(IdentValueNode) (bool, *Visitor)
   157  	// VisitStringValueNode is invoked when visiting a StringValueNode in the AST.
   158  	// This function is used when no concrete type function is provided.
   159  	VisitStringValueNode func(StringValueNode) (bool, *Visitor)
   160  	// VisitIntValueNode is invoked when visiting an IntValueNode in the AST.
   161  	// This function is used when no concrete type function is provided. If
   162  	// both this and VisitFloatValueNode are provided, and a node implements
   163  	// both (such as *UintLiteralNode), this function will be invoked and
   164  	// not the other.
   165  	VisitIntValueNode func(IntValueNode) (bool, *Visitor)
   166  	// VisitFloatValueNode is invoked when visiting a FloatValueNode in the AST.
   167  	// This function is used when no concrete type function is provided.
   168  	VisitFloatValueNode func(FloatValueNode) (bool, *Visitor)
   169  	// VisitValueNode is invoked when visiting a ValueNode in the AST. This
   170  	// function is used when no concrete type function is provided and no
   171  	// more specific ValueNode function is provided that matches the node.
   172  	VisitValueNode func(ValueNode) (bool, *Visitor)
   173  
   174  	// VisitTerminalNode is invoked when visiting a TerminalNode in the AST.
   175  	// This function is used when no concrete type function is provided
   176  	// no more specific interface type function is provided.
   177  	VisitTerminalNode func(TerminalNode) (bool, *Visitor)
   178  	// VisitCompositeNode is invoked when visiting a CompositeNode in the AST.
   179  	// This function is used when no concrete type function is provided
   180  	// no more specific interface type function is provided.
   181  	VisitCompositeNode func(CompositeNode) (bool, *Visitor)
   182  	// VisitNode is invoked when visiting a Node in the AST. This
   183  	// function is only used when no other more specific function is
   184  	// provided.
   185  	VisitNode func(Node) (bool, *Visitor)
   186  }
   187  
   188  // Visit provides the Visitor's implementation of VisitFunc, to be
   189  // used with Walk operations.
   190  func (v *Visitor) Visit(n Node) (bool, VisitFunc) {
   191  	var ok, matched bool
   192  	var next *Visitor
   193  	switch n := n.(type) {
   194  	case *FileNode:
   195  		if v.VisitFileNode != nil {
   196  			matched = true
   197  			ok, next = v.VisitFileNode(n)
   198  		}
   199  	case *SyntaxNode:
   200  		if v.VisitSyntaxNode != nil {
   201  			matched = true
   202  			ok, next = v.VisitSyntaxNode(n)
   203  		}
   204  	case *PackageNode:
   205  		if v.VisitPackageNode != nil {
   206  			matched = true
   207  			ok, next = v.VisitPackageNode(n)
   208  		}
   209  	case *ImportNode:
   210  		if v.VisitImportNode != nil {
   211  			matched = true
   212  			ok, next = v.VisitImportNode(n)
   213  		}
   214  	case *OptionNode:
   215  		if v.VisitOptionNode != nil {
   216  			matched = true
   217  			ok, next = v.VisitOptionNode(n)
   218  		}
   219  	case *OptionNameNode:
   220  		if v.VisitOptionNameNode != nil {
   221  			matched = true
   222  			ok, next = v.VisitOptionNameNode(n)
   223  		}
   224  	case *FieldReferenceNode:
   225  		if v.VisitFieldReferenceNode != nil {
   226  			matched = true
   227  			ok, next = v.VisitFieldReferenceNode(n)
   228  		}
   229  	case *CompactOptionsNode:
   230  		if v.VisitCompactOptionsNode != nil {
   231  			matched = true
   232  			ok, next = v.VisitCompactOptionsNode(n)
   233  		}
   234  	case *MessageNode:
   235  		if v.VisitMessageNode != nil {
   236  			matched = true
   237  			ok, next = v.VisitMessageNode(n)
   238  		}
   239  	case *ExtendNode:
   240  		if v.VisitExtendNode != nil {
   241  			matched = true
   242  			ok, next = v.VisitExtendNode(n)
   243  		}
   244  	case *ExtensionRangeNode:
   245  		if v.VisitExtensionRangeNode != nil {
   246  			matched = true
   247  			ok, next = v.VisitExtensionRangeNode(n)
   248  		}
   249  	case *ReservedNode:
   250  		if v.VisitReservedNode != nil {
   251  			matched = true
   252  			ok, next = v.VisitReservedNode(n)
   253  		}
   254  	case *RangeNode:
   255  		if v.VisitRangeNode != nil {
   256  			matched = true
   257  			ok, next = v.VisitRangeNode(n)
   258  		}
   259  	case *FieldNode:
   260  		if v.VisitFieldNode != nil {
   261  			matched = true
   262  			ok, next = v.VisitFieldNode(n)
   263  		}
   264  	case *GroupNode:
   265  		if v.VisitGroupNode != nil {
   266  			matched = true
   267  			ok, next = v.VisitGroupNode(n)
   268  		}
   269  	case *MapFieldNode:
   270  		if v.VisitMapFieldNode != nil {
   271  			matched = true
   272  			ok, next = v.VisitMapFieldNode(n)
   273  		}
   274  	case *MapTypeNode:
   275  		if v.VisitMapTypeNode != nil {
   276  			matched = true
   277  			ok, next = v.VisitMapTypeNode(n)
   278  		}
   279  	case *OneOfNode:
   280  		if v.VisitOneOfNode != nil {
   281  			matched = true
   282  			ok, next = v.VisitOneOfNode(n)
   283  		}
   284  	case *EnumNode:
   285  		if v.VisitEnumNode != nil {
   286  			matched = true
   287  			ok, next = v.VisitEnumNode(n)
   288  		}
   289  	case *EnumValueNode:
   290  		if v.VisitEnumValueNode != nil {
   291  			matched = true
   292  			ok, next = v.VisitEnumValueNode(n)
   293  		}
   294  	case *ServiceNode:
   295  		if v.VisitServiceNode != nil {
   296  			matched = true
   297  			ok, next = v.VisitServiceNode(n)
   298  		}
   299  	case *RPCNode:
   300  		if v.VisitRPCNode != nil {
   301  			matched = true
   302  			ok, next = v.VisitRPCNode(n)
   303  		}
   304  	case *RPCTypeNode:
   305  		if v.VisitRPCTypeNode != nil {
   306  			matched = true
   307  			ok, next = v.VisitRPCTypeNode(n)
   308  		}
   309  	case *IdentNode:
   310  		if v.VisitIdentNode != nil {
   311  			matched = true
   312  			ok, next = v.VisitIdentNode(n)
   313  		}
   314  	case *CompoundIdentNode:
   315  		if v.VisitCompoundIdentNode != nil {
   316  			matched = true
   317  			ok, next = v.VisitCompoundIdentNode(n)
   318  		}
   319  	case *StringLiteralNode:
   320  		if v.VisitStringLiteralNode != nil {
   321  			matched = true
   322  			ok, next = v.VisitStringLiteralNode(n)
   323  		}
   324  	case *CompoundStringLiteralNode:
   325  		if v.VisitCompoundStringLiteralNode != nil {
   326  			matched = true
   327  			ok, next = v.VisitCompoundStringLiteralNode(n)
   328  		}
   329  	case *UintLiteralNode:
   330  		if v.VisitUintLiteralNode != nil {
   331  			matched = true
   332  			ok, next = v.VisitUintLiteralNode(n)
   333  		}
   334  	case *PositiveUintLiteralNode:
   335  		if v.VisitPositiveUintLiteralNode != nil {
   336  			matched = true
   337  			ok, next = v.VisitPositiveUintLiteralNode(n)
   338  		}
   339  	case *NegativeIntLiteralNode:
   340  		if v.VisitNegativeIntLiteralNode != nil {
   341  			matched = true
   342  			ok, next = v.VisitNegativeIntLiteralNode(n)
   343  		}
   344  	case *FloatLiteralNode:
   345  		if v.VisitFloatLiteralNode != nil {
   346  			matched = true
   347  			ok, next = v.VisitFloatLiteralNode(n)
   348  		}
   349  	case *SpecialFloatLiteralNode:
   350  		if v.VisitSpecialFloatLiteralNode != nil {
   351  			matched = true
   352  			ok, next = v.VisitSpecialFloatLiteralNode(n)
   353  		}
   354  	case *SignedFloatLiteralNode:
   355  		if v.VisitSignedFloatLiteralNode != nil {
   356  			matched = true
   357  			ok, next = v.VisitSignedFloatLiteralNode(n)
   358  		}
   359  	case *BoolLiteralNode:
   360  		if v.VisitBoolLiteralNode != nil {
   361  			matched = true
   362  			ok, next = v.VisitBoolLiteralNode(n)
   363  		}
   364  	case *ArrayLiteralNode:
   365  		if v.VisitArrayLiteralNode != nil {
   366  			matched = true
   367  			ok, next = v.VisitArrayLiteralNode(n)
   368  		}
   369  	case *MessageLiteralNode:
   370  		if v.VisitMessageLiteralNode != nil {
   371  			matched = true
   372  			ok, next = v.VisitMessageLiteralNode(n)
   373  		}
   374  	case *MessageFieldNode:
   375  		if v.VisitMessageFieldNode != nil {
   376  			matched = true
   377  			ok, next = v.VisitMessageFieldNode(n)
   378  		}
   379  	case *KeywordNode:
   380  		if v.VisitKeywordNode != nil {
   381  			matched = true
   382  			ok, next = v.VisitKeywordNode(n)
   383  		}
   384  	case *RuneNode:
   385  		if v.VisitRuneNode != nil {
   386  			matched = true
   387  			ok, next = v.VisitRuneNode(n)
   388  		}
   389  	case *EmptyDeclNode:
   390  		if v.VisitEmptyDeclNode != nil {
   391  			matched = true
   392  			ok, next = v.VisitEmptyDeclNode(n)
   393  		}
   394  	}
   395  
   396  	if !matched {
   397  		// Visitor provided no concrete type visit function, so
   398  		// check interface types. We do this in several passes
   399  		// to provide "priority" for matched interfaces for nodes
   400  		// that actually implement more than one interface.
   401  		//
   402  		// For example, StringLiteralNode implements both
   403  		// StringValueNode and ValueNode. Both cases could match
   404  		// so the first case is what would match. So if we want
   405  		// to test against either, they need to be in different
   406  		// switch statements.
   407  		switch n := n.(type) {
   408  		case FieldDeclNode:
   409  			if v.VisitFieldDeclNode != nil {
   410  				matched = true
   411  				ok, next = v.VisitFieldDeclNode(n)
   412  			}
   413  		case IdentValueNode:
   414  			if v.VisitIdentValueNode != nil {
   415  				matched = true
   416  				ok, next = v.VisitIdentValueNode(n)
   417  			}
   418  		case StringValueNode:
   419  			if v.VisitStringValueNode != nil {
   420  				matched = true
   421  				ok, next = v.VisitStringValueNode(n)
   422  			}
   423  		case IntValueNode:
   424  			if v.VisitIntValueNode != nil {
   425  				matched = true
   426  				ok, next = v.VisitIntValueNode(n)
   427  			}
   428  		}
   429  	}
   430  
   431  	if !matched {
   432  		// These two are excluded from the above switch so that
   433  		// if visitor provides both VisitIntValueNode and
   434  		// VisitFloatValueNode, we'll prefer VisitIntValueNode
   435  		// for *UintLiteralNode (which implements both). Similarly,
   436  		// that way we prefer VisitFieldDeclNode over
   437  		// VisitMessageDeclNode when visiting a *GroupNode.
   438  		switch n := n.(type) {
   439  		case FloatValueNode:
   440  			if v.VisitFloatValueNode != nil {
   441  				matched = true
   442  				ok, next = v.VisitFloatValueNode(n)
   443  			}
   444  		case MessageDeclNode:
   445  			if v.VisitMessageDeclNode != nil {
   446  				matched = true
   447  				ok, next = v.VisitMessageDeclNode(n)
   448  			}
   449  		}
   450  	}
   451  
   452  	if !matched {
   453  		switch n := n.(type) {
   454  		case ValueNode:
   455  			if v.VisitValueNode != nil {
   456  				matched = true
   457  				ok, next = v.VisitValueNode(n)
   458  			}
   459  		}
   460  	}
   461  
   462  	if !matched {
   463  		switch n := n.(type) {
   464  		case TerminalNode:
   465  			if v.VisitTerminalNode != nil {
   466  				matched = true
   467  				ok, next = v.VisitTerminalNode(n)
   468  			}
   469  		case CompositeNode:
   470  			if v.VisitCompositeNode != nil {
   471  				matched = true
   472  				ok, next = v.VisitCompositeNode(n)
   473  			}
   474  		}
   475  	}
   476  
   477  	if !matched {
   478  		// finally, fallback to most generic visit function
   479  		if v.VisitNode != nil {
   480  			matched = true
   481  			ok, next = v.VisitNode(n)
   482  		}
   483  	}
   484  
   485  	if !matched {
   486  		// keep descending with the current visitor
   487  		return true, nil
   488  	}
   489  
   490  	if !ok {
   491  		return false, nil
   492  	}
   493  	if next != nil {
   494  		return true, next.Visit
   495  	}
   496  	return true, v.Visit
   497  }