github.com/hirochachacha/plua@v0.0.0-20170217012138-c82f520cc725/compiler/ast/walk.go (about)

     1  package ast
     2  
     3  type visit func(node Node) (skip bool)
     4  
     5  func Find(node Node, typ Type) (ret Node) {
     6  	Walk(node, func(node Node) bool {
     7  		if node.Type() == typ {
     8  			ret = node
     9  
    10  			return true
    11  		}
    12  		return false
    13  	})
    14  
    15  	return
    16  }
    17  
    18  func Walk(node Node, visit func(node Node) (skip bool)) {
    19  	walk(node, visit)
    20  }
    21  
    22  func walkComment(comment *Comment, visit visit) bool {
    23  	if comment == nil {
    24  		return false
    25  	}
    26  
    27  	return visit(comment)
    28  }
    29  
    30  func walkName(name *Name, visit visit) bool {
    31  	if name == nil {
    32  		return false
    33  	}
    34  
    35  	return visit(name)
    36  }
    37  
    38  func walkParamList(params *ParamList, visit visit) bool {
    39  	if params == nil {
    40  		return false
    41  	}
    42  
    43  	if visit(params) {
    44  		return true
    45  	}
    46  
    47  	for _, name := range params.List {
    48  		if walkName(name, visit) {
    49  			return true
    50  		}
    51  	}
    52  	return false
    53  }
    54  
    55  func walkBlock(block *Block, visit visit) bool {
    56  	if block == nil {
    57  		return false
    58  	}
    59  
    60  	if visit(block) {
    61  		return true
    62  	}
    63  
    64  	for _, e := range block.List {
    65  		if walkStmt(e, visit) {
    66  			return true
    67  		}
    68  	}
    69  	return false
    70  }
    71  
    72  func walkFuncBody(fbody *FuncBody, visit visit) bool {
    73  	if fbody == nil {
    74  		return false
    75  	}
    76  
    77  	if visit(fbody) {
    78  		return true
    79  	}
    80  
    81  	if walkParamList(fbody.Params, visit) {
    82  		return true
    83  	}
    84  	return walkBlock(fbody.Body, visit)
    85  }
    86  
    87  func walkExpr(expr Expr, visit visit) bool {
    88  	if expr == nil {
    89  		return false
    90  	}
    91  
    92  	if visit(expr) {
    93  		return true
    94  	}
    95  
    96  	switch expr := expr.(type) {
    97  	case *BadExpr:
    98  	case *Name:
    99  	case *Vararg:
   100  	case *BasicLit:
   101  	case *FuncLit:
   102  		return walkFuncBody(expr.Body, visit)
   103  	case *TableLit:
   104  		for _, e := range expr.Fields {
   105  			if walkExpr(e, visit) {
   106  				return true
   107  			}
   108  		}
   109  	case *ParenExpr:
   110  		return walkExpr(expr.X, visit)
   111  	case *SelectorExpr:
   112  		if walk(expr.X, visit) {
   113  			return true
   114  		}
   115  		return walkName(expr.Sel, visit)
   116  	case *IndexExpr:
   117  		if walkExpr(expr.X, visit) {
   118  			return true
   119  		}
   120  		return walkExpr(expr.Index, visit)
   121  	case *CallExpr:
   122  		if walkExpr(expr.X, visit) {
   123  			return true
   124  		}
   125  		if walkName(expr.Name, visit) {
   126  			return true
   127  		}
   128  		for _, arg := range expr.Args {
   129  			if walkExpr(arg, visit) {
   130  				return true
   131  			}
   132  		}
   133  	case *UnaryExpr:
   134  		return walkExpr(expr.X, visit)
   135  	case *BinaryExpr:
   136  		if walkExpr(expr.X, visit) {
   137  			return true
   138  		}
   139  		return walkExpr(expr.Y, visit)
   140  	case *KeyValueExpr:
   141  		if walkExpr(expr.Key, visit) {
   142  			return true
   143  		}
   144  		return walkExpr(expr.Value, visit)
   145  	default:
   146  		panic("unreachable")
   147  	}
   148  
   149  	return false
   150  }
   151  
   152  func walkStmt(stmt Stmt, visit visit) bool {
   153  	if stmt == nil {
   154  		return false
   155  	}
   156  
   157  	if visit(stmt) {
   158  		return true
   159  	}
   160  
   161  	switch stmt := stmt.(type) {
   162  	case *BadStmt:
   163  	case *EmptyStmt:
   164  	case *LocalAssignStmt:
   165  		for _, name := range stmt.LHS {
   166  			if walkName(name, visit) {
   167  				return true
   168  			}
   169  		}
   170  		for _, e := range stmt.RHS {
   171  			if walkExpr(e, visit) {
   172  				return true
   173  			}
   174  		}
   175  	case *LocalFuncStmt:
   176  		if walkName(stmt.Name, visit) {
   177  			return true
   178  		}
   179  		return walkFuncBody(stmt.Body, visit)
   180  	case *FuncStmt:
   181  		for _, name := range stmt.PathList {
   182  			if walkName(name, visit) {
   183  				return true
   184  			}
   185  		}
   186  		if walkName(stmt.Name, visit) {
   187  			return true
   188  		}
   189  		return walkFuncBody(stmt.Body, visit)
   190  	case *LabelStmt:
   191  		return walkName(stmt.Name, visit)
   192  	case *ExprStmt:
   193  		x := stmt.X
   194  
   195  		if x == nil {
   196  			return false
   197  		}
   198  
   199  		if walkExpr(x.X, visit) {
   200  			return true
   201  		}
   202  		if walkName(x.Name, visit) {
   203  			return true
   204  		}
   205  		for _, arg := range x.Args {
   206  			if walkExpr(arg, visit) {
   207  				return true
   208  			}
   209  		}
   210  	case *AssignStmt:
   211  		for _, e := range stmt.LHS {
   212  			if walkExpr(e, visit) {
   213  				return true
   214  			}
   215  		}
   216  		for _, e := range stmt.RHS {
   217  			if walkExpr(e, visit) {
   218  				return true
   219  			}
   220  		}
   221  	case *GotoStmt:
   222  		return walkName(stmt.Label, visit)
   223  	case *BreakStmt:
   224  	case *IfStmt:
   225  		if walkExpr(stmt.Cond, visit) {
   226  			return true
   227  		}
   228  		if walkBlock(stmt.Body, visit) {
   229  			return true
   230  		}
   231  		for _, s := range stmt.ElseIfList {
   232  			if walkExpr(s.Cond, visit) {
   233  				return true
   234  			}
   235  			if walkBlock(s.Body, visit) {
   236  				return true
   237  			}
   238  		}
   239  		return walkBlock(stmt.ElseBody, visit)
   240  	case *DoStmt:
   241  		return walkBlock(stmt.Body, visit)
   242  	case *WhileStmt:
   243  		if walkExpr(stmt.Cond, visit) {
   244  			return true
   245  		}
   246  		return walkBlock(stmt.Body, visit)
   247  	case *RepeatStmt:
   248  		if walkExpr(stmt.Cond, visit) {
   249  			return true
   250  		}
   251  		return walkBlock(stmt.Body, visit)
   252  	case *ReturnStmt:
   253  		for _, ret := range stmt.Results {
   254  			if walkExpr(ret, visit) {
   255  				return true
   256  			}
   257  		}
   258  	case *ForStmt:
   259  		if walkName(stmt.Name, visit) {
   260  			return true
   261  		}
   262  		if walkExpr(stmt.Start, visit) {
   263  			return true
   264  		}
   265  		if walkExpr(stmt.Finish, visit) {
   266  			return true
   267  		}
   268  		if walkExpr(stmt.Step, visit) {
   269  			return true
   270  		}
   271  		return walkBlock(stmt.Body, visit)
   272  	case *ForEachStmt:
   273  		for _, name := range stmt.Names {
   274  			if walkName(name, visit) {
   275  				return true
   276  			}
   277  		}
   278  		for _, e := range stmt.Exprs {
   279  			if walkExpr(e, visit) {
   280  				return true
   281  			}
   282  		}
   283  		return walkBlock(stmt.Body, visit)
   284  	default:
   285  		panic("unreachable")
   286  	}
   287  
   288  	return false
   289  }
   290  
   291  func walk(node Node, visit visit) (skip bool) {
   292  	if node == nil {
   293  		return false
   294  	}
   295  
   296  	if visit(node) {
   297  		return true
   298  	}
   299  
   300  	switch node := node.(type) {
   301  	case *Comment:
   302  	case *CommentGroup:
   303  		for _, c := range node.List {
   304  			if walkComment(c, visit) {
   305  				return true
   306  			}
   307  		}
   308  
   309  	case *ParamList:
   310  		for _, name := range node.List {
   311  			if walkName(name, visit) {
   312  				return true
   313  			}
   314  		}
   315  
   316  	case *BadExpr:
   317  	case *Name:
   318  	case *Vararg:
   319  	case *BasicLit:
   320  	case *FuncLit:
   321  		return walkFuncBody(node.Body, visit)
   322  	case *TableLit:
   323  		for _, e := range node.Fields {
   324  			if walkExpr(e, visit) {
   325  				return true
   326  			}
   327  		}
   328  	case *ParenExpr:
   329  		return walkExpr(node.X, visit)
   330  	case *SelectorExpr:
   331  		if walk(node.X, visit) {
   332  			return true
   333  		}
   334  		return walkName(node.Sel, visit)
   335  	case *IndexExpr:
   336  		if walkExpr(node.X, visit) {
   337  			return true
   338  		}
   339  		return walkExpr(node.Index, visit)
   340  	case *CallExpr:
   341  		if walkExpr(node.X, visit) {
   342  			return true
   343  		}
   344  		if walkName(node.Name, visit) {
   345  			return true
   346  		}
   347  		for _, arg := range node.Args {
   348  			if walkExpr(arg, visit) {
   349  				return true
   350  			}
   351  		}
   352  	case *UnaryExpr:
   353  		return walkExpr(node.X, visit)
   354  	case *BinaryExpr:
   355  		if walkExpr(node.X, visit) {
   356  			return true
   357  		}
   358  		return walkExpr(node.Y, visit)
   359  	case *KeyValueExpr:
   360  		if walkExpr(node.Key, visit) {
   361  			return true
   362  		}
   363  		return walkExpr(node.Value, visit)
   364  
   365  	case *BadStmt:
   366  	case *EmptyStmt:
   367  	case *LocalAssignStmt:
   368  		for _, name := range node.LHS {
   369  			if walkName(name, visit) {
   370  				return true
   371  			}
   372  		}
   373  		for _, e := range node.RHS {
   374  			if walkExpr(e, visit) {
   375  				return true
   376  			}
   377  		}
   378  	case *LocalFuncStmt:
   379  		if walkName(node.Name, visit) {
   380  			return true
   381  		}
   382  		return walkFuncBody(node.Body, visit)
   383  	case *FuncStmt:
   384  		for _, name := range node.PathList {
   385  			if walkName(name, visit) {
   386  				return true
   387  			}
   388  		}
   389  		if walkName(node.Name, visit) {
   390  			return true
   391  		}
   392  		return walkFuncBody(node.Body, visit)
   393  	case *LabelStmt:
   394  		return walkName(node.Name, visit)
   395  	case *ExprStmt:
   396  		x := node.X
   397  
   398  		if x == nil {
   399  			return false
   400  		}
   401  
   402  		if walkExpr(x.X, visit) {
   403  			return true
   404  		}
   405  		if walkName(x.Name, visit) {
   406  			return true
   407  		}
   408  		for _, arg := range x.Args {
   409  			if walkExpr(arg, visit) {
   410  				return true
   411  			}
   412  		}
   413  	case *AssignStmt:
   414  		for _, e := range node.LHS {
   415  			if walkExpr(e, visit) {
   416  				return true
   417  			}
   418  		}
   419  		for _, e := range node.RHS {
   420  			if walkExpr(e, visit) {
   421  				return true
   422  			}
   423  		}
   424  	case *GotoStmt:
   425  		return walkName(node.Label, visit)
   426  	case *BreakStmt:
   427  	case *IfStmt:
   428  		if walkExpr(node.Cond, visit) {
   429  			return true
   430  		}
   431  		if walkBlock(node.Body, visit) {
   432  			return true
   433  		}
   434  		for _, n := range node.ElseIfList {
   435  			if walkExpr(n.Cond, visit) {
   436  				return true
   437  			}
   438  			if walkBlock(n.Body, visit) {
   439  				return true
   440  			}
   441  		}
   442  		return walkBlock(node.ElseBody, visit)
   443  	case *DoStmt:
   444  		return walkBlock(node.Body, visit)
   445  	case *WhileStmt:
   446  		if walkExpr(node.Cond, visit) {
   447  			return true
   448  		}
   449  		return walkBlock(node.Body, visit)
   450  	case *RepeatStmt:
   451  		if walkExpr(node.Cond, visit) {
   452  			return true
   453  		}
   454  		return walkBlock(node.Body, visit)
   455  	case *ReturnStmt:
   456  		for _, ret := range node.Results {
   457  			if walkExpr(ret, visit) {
   458  				return true
   459  			}
   460  		}
   461  	case *ForStmt:
   462  		if walkName(node.Name, visit) {
   463  			return true
   464  		}
   465  		if walkExpr(node.Start, visit) {
   466  			return true
   467  		}
   468  		if walkExpr(node.Finish, visit) {
   469  			return true
   470  		}
   471  		if walkExpr(node.Step, visit) {
   472  			return true
   473  		}
   474  		return walkBlock(node.Body, visit)
   475  	case *ForEachStmt:
   476  		for _, name := range node.Names {
   477  			if walkName(name, visit) {
   478  				return true
   479  			}
   480  		}
   481  		for _, e := range node.Exprs {
   482  			if walkExpr(e, visit) {
   483  				return true
   484  			}
   485  		}
   486  		return walkBlock(node.Body, visit)
   487  
   488  	case *File:
   489  		for _, e := range node.Chunk {
   490  			if walkStmt(e, visit) {
   491  				return true
   492  			}
   493  		}
   494  	case *Block:
   495  		for _, e := range node.List {
   496  			if walkStmt(e, visit) {
   497  				return true
   498  			}
   499  		}
   500  		return false
   501  	case *FuncBody:
   502  		if walkParamList(node.Params, visit) {
   503  			return true
   504  		}
   505  		return walkBlock(node.Body, visit)
   506  	default:
   507  		panic("unreachable")
   508  	}
   509  
   510  	return false
   511  }