github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/cmd/compile/syntax/walk.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // This file implements syntax tree walking.
     6  
     7  package syntax
     8  
     9  import "fmt"
    10  
    11  // Inspect traverses an AST in pre-order: it starts by calling f(root);
    12  // root must not be nil. If f returns true, Inspect invokes f recursively
    13  // for each of the non-nil children of root, followed by a call of f(nil).
    14  //
    15  // See Walk for caveats about shared nodes.
    16  func Inspect(root Node, f func(Node) bool) {
    17  	Walk(root, inspector(f))
    18  }
    19  
    20  type inspector func(Node) bool
    21  
    22  func (v inspector) Visit(node Node) Visitor {
    23  	if v(node) {
    24  		return v
    25  	}
    26  	return nil
    27  }
    28  
    29  // Walk traverses an AST in pre-order: It starts by calling
    30  // v.Visit(node); node must not be nil. If the visitor w returned by
    31  // v.Visit(node) is not nil, Walk is invoked recursively with visitor
    32  // w for each of the non-nil children of node, followed by a call of
    33  // w.Visit(nil).
    34  //
    35  // Some nodes may be shared among multiple parent nodes (e.g., types in
    36  // field lists such as type T in "a, b, c T"). Such shared nodes are
    37  // walked multiple times.
    38  // TODO(gri) Revisit this design. It may make sense to walk those nodes
    39  // only once. A place where this matters is types2.TestResolveIdents.
    40  func Walk(root Node, v Visitor) {
    41  	walker{v}.node(root)
    42  }
    43  
    44  // A Visitor's Visit method is invoked for each node encountered by Walk.
    45  // If the result visitor w is not nil, Walk visits each of the children
    46  // of node with the visitor w, followed by a call of w.Visit(nil).
    47  type Visitor interface {
    48  	Visit(node Node) (w Visitor)
    49  }
    50  
    51  type walker struct {
    52  	v Visitor
    53  }
    54  
// node walks the subtree rooted at n in pre-order.
//
// It first calls w.v.Visit(n). If that returns nil, the subtree is pruned:
// no children are visited and no closing Visit(nil) call is made. Otherwise
// the returned visitor is used to walk the children of n, in the order the
// switch below lists them. It is then called with Visit(nil) to signal that
// n's subtree is complete.
//
// node panics if n is nil, and it also panics on any node type the switch
// does not handle. Optional children are therefore nil-checked before
// recursing.
func (w walker) node(n Node) {
	if n == nil {
		panic("nil node")
	}

	// w is a copy (value receiver), so replacing the visitor here affects
	// only n's children and the closing Visit(nil) below. The caller's
	// walker keeps its own visitor.
	w.v = w.v.Visit(n)
	if w.v == nil {
		return
	}

	switch n := n.(type) {
	// packages
	case *File:
		w.node(n.PkgName)
		w.declList(n.DeclList)

	// declarations
	case *ImportDecl:
		if n.LocalPkgName != nil {
			w.node(n.LocalPkgName)
		}
		w.node(n.Path)

	case *ConstDecl:
		w.nameList(n.NameList)
		if n.Type != nil {
			w.node(n.Type)
		}
		if n.Values != nil {
			w.node(n.Values)
		}

	case *TypeDecl:
		w.node(n.Name)
		w.fieldList(n.TParamList)
		w.node(n.Type)

	case *VarDecl:
		w.nameList(n.NameList)
		if n.Type != nil {
			w.node(n.Type)
		}
		if n.Values != nil {
			w.node(n.Values)
		}

	case *FuncDecl:
		if n.Recv != nil {
			w.node(n.Recv)
		}
		w.node(n.Name)
		w.fieldList(n.TParamList)
		w.node(n.Type)
		if n.Body != nil {
			w.node(n.Body)
		}

	// expressions
	case *BadExpr: // nothing to do
	case *Name: // nothing to do
	case *BasicLit: // nothing to do

	case *CompositeLit:
		if n.Type != nil {
			w.node(n.Type)
		}
		w.exprList(n.ElemList)

	case *KeyValueExpr:
		w.node(n.Key)
		w.node(n.Value)

	case *FuncLit:
		w.node(n.Type)
		w.node(n.Body)

	case *ParenExpr:
		w.node(n.X)

	case *SelectorExpr:
		w.node(n.X)
		w.node(n.Sel)

	case *IndexExpr:
		w.node(n.X)
		w.node(n.Index)

	case *SliceExpr:
		w.node(n.X)
		// Index entries may be nil (presumably omitted slice bounds);
		// only the non-nil ones are walked.
		for _, x := range n.Index {
			if x != nil {
				w.node(x)
			}
		}

	case *AssertExpr:
		w.node(n.X)
		w.node(n.Type)

	case *TypeSwitchGuard:
		if n.Lhs != nil {
			w.node(n.Lhs)
		}
		w.node(n.X)

	case *Operation:
		w.node(n.X)
		// Y may be nil (presumably for unary operations).
		if n.Y != nil {
			w.node(n.Y)
		}

	case *CallExpr:
		w.node(n.Fun)
		w.exprList(n.ArgList)

	case *ListExpr:
		w.exprList(n.ElemList)

	// types
	case *ArrayType:
		if n.Len != nil {
			w.node(n.Len)
		}
		w.node(n.Elem)

	case *SliceType:
		w.node(n.Elem)

	case *DotsType:
		w.node(n.Elem)

	case *StructType:
		w.fieldList(n.FieldList)
		// TagList entries may be nil; only tags that are present are walked.
		// All tags are walked after all fields, not interleaved with them.
		for _, t := range n.TagList {
			if t != nil {
				w.node(t)
			}
		}

	case *Field:
		// Name may be nil (presumably embedded fields or unnamed
		// parameters).
		if n.Name != nil {
			w.node(n.Name)
		}
		w.node(n.Type)

	case *InterfaceType:
		w.fieldList(n.MethodList)

	case *FuncType:
		w.fieldList(n.ParamList)
		w.fieldList(n.ResultList)

	case *MapType:
		w.node(n.Key)
		w.node(n.Value)

	case *ChanType:
		w.node(n.Elem)

	// statements
	case *EmptyStmt: // nothing to do

	case *LabeledStmt:
		w.node(n.Label)
		w.node(n.Stmt)

	case *BlockStmt:
		w.stmtList(n.List)

	case *ExprStmt:
		w.node(n.X)

	case *SendStmt:
		w.node(n.Chan)
		w.node(n.Value)

	case *DeclStmt:
		w.declList(n.DeclList)

	case *AssignStmt:
		w.node(n.Lhs)
		// Rhs may be nil (presumably for x++ / x--).
		if n.Rhs != nil {
			w.node(n.Rhs)
		}

	case *BranchStmt:
		if n.Label != nil {
			w.node(n.Label)
		}
		// Target points to nodes elsewhere in the syntax tree, so it is
		// deliberately not walked here.

	case *CallStmt:
		w.node(n.Call)

	case *ReturnStmt:
		if n.Results != nil {
			w.node(n.Results)
		}

	case *IfStmt:
		if n.Init != nil {
			w.node(n.Init)
		}
		w.node(n.Cond)
		w.node(n.Then)
		if n.Else != nil {
			w.node(n.Else)
		}

	case *ForStmt:
		if n.Init != nil {
			w.node(n.Init)
		}
		if n.Cond != nil {
			w.node(n.Cond)
		}
		if n.Post != nil {
			w.node(n.Post)
		}
		w.node(n.Body)

	case *SwitchStmt:
		if n.Init != nil {
			w.node(n.Init)
		}
		if n.Tag != nil {
			w.node(n.Tag)
		}
		for _, s := range n.Body {
			w.node(s)
		}

	case *SelectStmt:
		for _, s := range n.Body {
			w.node(s)
		}

	// helper nodes
	case *RangeClause:
		if n.Lhs != nil {
			w.node(n.Lhs)
		}
		w.node(n.X)

	case *CaseClause:
		if n.Cases != nil {
			w.node(n.Cases)
		}
		w.stmtList(n.Body)

	case *CommClause:
		if n.Comm != nil {
			w.node(n.Comm)
		}
		w.stmtList(n.Body)

	default:
		// Any Node type not listed above is treated as an internal error.
		panic(fmt.Sprintf("internal error: unknown node type %T", n))
	}

	// Tell the visitor returned for n that n's subtree is finished.
	w.v.Visit(nil)
}
   317  
   318  func (w walker) declList(list []Decl) {
   319  	for _, n := range list {
   320  		w.node(n)
   321  	}
   322  }
   323  
   324  func (w walker) exprList(list []Expr) {
   325  	for _, n := range list {
   326  		w.node(n)
   327  	}
   328  }
   329  
   330  func (w walker) stmtList(list []Stmt) {
   331  	for _, n := range list {
   332  		w.node(n)
   333  	}
   334  }
   335  
   336  func (w walker) nameList(list []*Name) {
   337  	for _, n := range list {
   338  		w.node(n)
   339  	}
   340  }
   341  
   342  func (w walker) fieldList(list []*Field) {
   343  	for _, n := range list {
   344  		w.node(n)
   345  	}
   346  }