github.com/rajeev159/opa@v0.45.0/ast/visit.go (about)

     1  // Copyright 2016 The OPA Authors.  All rights reserved.
     2  // Use of this source code is governed by an Apache2
     3  // license that can be found in the LICENSE file.
     4  
     5  package ast
     6  
     7  // Visitor defines the interface for iterating AST elements. The Visit function
     8  // can return a Visitor w which will be used to visit the children of the AST
     9  // element v. If the Visit function returns nil, the children will not be
    10  // visited. This is deprecated.
    11  type Visitor interface {
    12  	Visit(v interface{}) (w Visitor)
    13  }
    14  
    15  // BeforeAndAfterVisitor wraps Visitor to provide hooks for being called before
    16  // and after the AST has been visited. This is deprecated.
    17  type BeforeAndAfterVisitor interface {
    18  	Visitor
    19  	Before(x interface{})
    20  	After(x interface{})
    21  }
    22  
    23  // Walk iterates the AST by calling the Visit function on the Visitor
    24  // v for x before recursing. This is deprecated.
    25  func Walk(v Visitor, x interface{}) {
    26  	if bav, ok := v.(BeforeAndAfterVisitor); !ok {
    27  		walk(v, x)
    28  	} else {
    29  		bav.Before(x)
    30  		defer bav.After(x)
    31  		walk(bav, x)
    32  	}
    33  }
    34  
    35  // WalkBeforeAndAfter iterates the AST by calling the Visit function on the
    36  // Visitor v for x before recursing. This is deprecated.
    37  func WalkBeforeAndAfter(v BeforeAndAfterVisitor, x interface{}) {
    38  	Walk(v, x)
    39  }
    40  
    41  func walk(v Visitor, x interface{}) {
    42  	w := v.Visit(x)
    43  	if w == nil {
    44  		return
    45  	}
    46  	switch x := x.(type) {
    47  	case *Module:
    48  		Walk(w, x.Package)
    49  		for _, i := range x.Imports {
    50  			Walk(w, i)
    51  		}
    52  		for _, r := range x.Rules {
    53  			Walk(w, r)
    54  		}
    55  		for _, a := range x.Annotations {
    56  			Walk(w, a)
    57  		}
    58  		for _, c := range x.Comments {
    59  			Walk(w, c)
    60  		}
    61  	case *Package:
    62  		Walk(w, x.Path)
    63  	case *Import:
    64  		Walk(w, x.Path)
    65  		Walk(w, x.Alias)
    66  	case *Rule:
    67  		Walk(w, x.Head)
    68  		Walk(w, x.Body)
    69  		if x.Else != nil {
    70  			Walk(w, x.Else)
    71  		}
    72  	case *Head:
    73  		Walk(w, x.Name)
    74  		Walk(w, x.Args)
    75  		if x.Key != nil {
    76  			Walk(w, x.Key)
    77  		}
    78  		if x.Value != nil {
    79  			Walk(w, x.Value)
    80  		}
    81  	case Body:
    82  		for _, e := range x {
    83  			Walk(w, e)
    84  		}
    85  	case Args:
    86  		for _, t := range x {
    87  			Walk(w, t)
    88  		}
    89  	case *Expr:
    90  		switch ts := x.Terms.(type) {
    91  		case *Term, *SomeDecl, *Every:
    92  			Walk(w, ts)
    93  		case []*Term:
    94  			for _, t := range ts {
    95  				Walk(w, t)
    96  			}
    97  		}
    98  		for i := range x.With {
    99  			Walk(w, x.With[i])
   100  		}
   101  	case *With:
   102  		Walk(w, x.Target)
   103  		Walk(w, x.Value)
   104  	case *Term:
   105  		Walk(w, x.Value)
   106  	case Ref:
   107  		for _, t := range x {
   108  			Walk(w, t)
   109  		}
   110  	case *object:
   111  		x.Foreach(func(k, vv *Term) {
   112  			Walk(w, k)
   113  			Walk(w, vv)
   114  		})
   115  	case *Array:
   116  		x.Foreach(func(t *Term) {
   117  			Walk(w, t)
   118  		})
   119  	case Set:
   120  		x.Foreach(func(t *Term) {
   121  			Walk(w, t)
   122  		})
   123  	case *ArrayComprehension:
   124  		Walk(w, x.Term)
   125  		Walk(w, x.Body)
   126  	case *ObjectComprehension:
   127  		Walk(w, x.Key)
   128  		Walk(w, x.Value)
   129  		Walk(w, x.Body)
   130  	case *SetComprehension:
   131  		Walk(w, x.Term)
   132  		Walk(w, x.Body)
   133  	case Call:
   134  		for _, t := range x {
   135  			Walk(w, t)
   136  		}
   137  	case *Every:
   138  		if x.Key != nil {
   139  			Walk(w, x.Key)
   140  		}
   141  		Walk(w, x.Value)
   142  		Walk(w, x.Domain)
   143  		Walk(w, x.Body)
   144  	}
   145  }
   146  
   147  // WalkVars calls the function f on all vars under x. If the function f
   148  // returns true, AST nodes under the last node will not be visited.
   149  func WalkVars(x interface{}, f func(Var) bool) {
   150  	vis := &GenericVisitor{func(x interface{}) bool {
   151  		if v, ok := x.(Var); ok {
   152  			return f(v)
   153  		}
   154  		return false
   155  	}}
   156  	vis.Walk(x)
   157  }
   158  
   159  // WalkClosures calls the function f on all closures under x. If the function f
   160  // returns true, AST nodes under the last node will not be visited.
   161  func WalkClosures(x interface{}, f func(interface{}) bool) {
   162  	vis := &GenericVisitor{func(x interface{}) bool {
   163  		switch x := x.(type) {
   164  		case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every:
   165  			return f(x)
   166  		}
   167  		return false
   168  	}}
   169  	vis.Walk(x)
   170  }
   171  
   172  // WalkRefs calls the function f on all references under x. If the function f
   173  // returns true, AST nodes under the last node will not be visited.
   174  func WalkRefs(x interface{}, f func(Ref) bool) {
   175  	vis := &GenericVisitor{func(x interface{}) bool {
   176  		if r, ok := x.(Ref); ok {
   177  			return f(r)
   178  		}
   179  		return false
   180  	}}
   181  	vis.Walk(x)
   182  }
   183  
   184  // WalkTerms calls the function f on all terms under x. If the function f
   185  // returns true, AST nodes under the last node will not be visited.
   186  func WalkTerms(x interface{}, f func(*Term) bool) {
   187  	vis := &GenericVisitor{func(x interface{}) bool {
   188  		if term, ok := x.(*Term); ok {
   189  			return f(term)
   190  		}
   191  		return false
   192  	}}
   193  	vis.Walk(x)
   194  }
   195  
   196  // WalkWiths calls the function f on all with modifiers under x. If the function f
   197  // returns true, AST nodes under the last node will not be visited.
   198  func WalkWiths(x interface{}, f func(*With) bool) {
   199  	vis := &GenericVisitor{func(x interface{}) bool {
   200  		if w, ok := x.(*With); ok {
   201  			return f(w)
   202  		}
   203  		return false
   204  	}}
   205  	vis.Walk(x)
   206  }
   207  
   208  // WalkExprs calls the function f on all expressions under x. If the function f
   209  // returns true, AST nodes under the last node will not be visited.
   210  func WalkExprs(x interface{}, f func(*Expr) bool) {
   211  	vis := &GenericVisitor{func(x interface{}) bool {
   212  		if r, ok := x.(*Expr); ok {
   213  			return f(r)
   214  		}
   215  		return false
   216  	}}
   217  	vis.Walk(x)
   218  }
   219  
   220  // WalkBodies calls the function f on all bodies under x. If the function f
   221  // returns true, AST nodes under the last node will not be visited.
   222  func WalkBodies(x interface{}, f func(Body) bool) {
   223  	vis := &GenericVisitor{func(x interface{}) bool {
   224  		if b, ok := x.(Body); ok {
   225  			return f(b)
   226  		}
   227  		return false
   228  	}}
   229  	vis.Walk(x)
   230  }
   231  
   232  // WalkRules calls the function f on all rules under x. If the function f
   233  // returns true, AST nodes under the last node will not be visited.
   234  func WalkRules(x interface{}, f func(*Rule) bool) {
   235  	vis := &GenericVisitor{func(x interface{}) bool {
   236  		if r, ok := x.(*Rule); ok {
   237  			stop := f(r)
   238  			// NOTE(tsandall): since rules cannot be embedded inside of queries
   239  			// we can stop early if there is no else block.
   240  			if stop || r.Else == nil {
   241  				return true
   242  			}
   243  		}
   244  		return false
   245  	}}
   246  	vis.Walk(x)
   247  }
   248  
   249  // WalkNodes calls the function f on all nodes under x. If the function f
   250  // returns true, AST nodes under the last node will not be visited.
   251  func WalkNodes(x interface{}, f func(Node) bool) {
   252  	vis := &GenericVisitor{func(x interface{}) bool {
   253  		if n, ok := x.(Node); ok {
   254  			return f(n)
   255  		}
   256  		return false
   257  	}}
   258  	vis.Walk(x)
   259  }
   260  
// GenericVisitor provides a utility to walk over AST nodes using a
// closure. If the closure returns true for a node, the visitor will not
// walk over the AST nodes under that node.
type GenericVisitor struct {
	f func(x interface{}) bool // called on every node reached by Walk; true prunes its children
}
   267  
   268  // NewGenericVisitor returns a new GenericVisitor that will invoke the function
   269  // f on AST nodes.
   270  func NewGenericVisitor(f func(x interface{}) bool) *GenericVisitor {
   271  	return &GenericVisitor{f}
   272  }
   273  
   274  // Walk iterates the AST by calling the function f on the
   275  // GenericVisitor before recursing. Contrary to the generic Walk, this
   276  // does not require allocating the visitor from heap.
   277  func (vis *GenericVisitor) Walk(x interface{}) {
   278  	if vis.f(x) {
   279  		return
   280  	}
   281  
   282  	switch x := x.(type) {
   283  	case *Module:
   284  		vis.Walk(x.Package)
   285  		for _, i := range x.Imports {
   286  			vis.Walk(i)
   287  		}
   288  		for _, r := range x.Rules {
   289  			vis.Walk(r)
   290  		}
   291  		for _, a := range x.Annotations {
   292  			vis.Walk(a)
   293  		}
   294  		for _, c := range x.Comments {
   295  			vis.Walk(c)
   296  		}
   297  	case *Package:
   298  		vis.Walk(x.Path)
   299  	case *Import:
   300  		vis.Walk(x.Path)
   301  		vis.Walk(x.Alias)
   302  	case *Rule:
   303  		vis.Walk(x.Head)
   304  		vis.Walk(x.Body)
   305  		if x.Else != nil {
   306  			vis.Walk(x.Else)
   307  		}
   308  	case *Head:
   309  		vis.Walk(x.Name)
   310  		vis.Walk(x.Args)
   311  		if x.Key != nil {
   312  			vis.Walk(x.Key)
   313  		}
   314  		if x.Value != nil {
   315  			vis.Walk(x.Value)
   316  		}
   317  	case Body:
   318  		for _, e := range x {
   319  			vis.Walk(e)
   320  		}
   321  	case Args:
   322  		for _, t := range x {
   323  			vis.Walk(t)
   324  		}
   325  	case *Expr:
   326  		switch ts := x.Terms.(type) {
   327  		case *Term, *SomeDecl, *Every:
   328  			vis.Walk(ts)
   329  		case []*Term:
   330  			for _, t := range ts {
   331  				vis.Walk(t)
   332  			}
   333  		}
   334  		for i := range x.With {
   335  			vis.Walk(x.With[i])
   336  		}
   337  	case *With:
   338  		vis.Walk(x.Target)
   339  		vis.Walk(x.Value)
   340  	case *Term:
   341  		vis.Walk(x.Value)
   342  	case Ref:
   343  		for _, t := range x {
   344  			vis.Walk(t)
   345  		}
   346  	case *object:
   347  		x.Foreach(func(k, v *Term) {
   348  			vis.Walk(k)
   349  			vis.Walk(x.Get(k))
   350  		})
   351  	case *Array:
   352  		x.Foreach(func(t *Term) {
   353  			vis.Walk(t)
   354  		})
   355  	case Set:
   356  		for _, t := range x.Slice() {
   357  			vis.Walk(t)
   358  		}
   359  	case *ArrayComprehension:
   360  		vis.Walk(x.Term)
   361  		vis.Walk(x.Body)
   362  	case *ObjectComprehension:
   363  		vis.Walk(x.Key)
   364  		vis.Walk(x.Value)
   365  		vis.Walk(x.Body)
   366  	case *SetComprehension:
   367  		vis.Walk(x.Term)
   368  		vis.Walk(x.Body)
   369  	case Call:
   370  		for _, t := range x {
   371  			vis.Walk(t)
   372  		}
   373  	case *Every:
   374  		if x.Key != nil {
   375  			vis.Walk(x.Key)
   376  		}
   377  		vis.Walk(x.Value)
   378  		vis.Walk(x.Domain)
   379  		vis.Walk(x.Body)
   380  	}
   381  }
   382  
// BeforeAfterVisitor provides a utility to walk over AST nodes using
// closures. If the before closure returns true for a node, the visitor will
// not walk over the AST nodes under that node. The after closure is always
// invoked after visiting a node, even when before returned true.
type BeforeAfterVisitor struct {
	before func(x interface{}) bool // pre-order hook; true prunes the node's children
	after  func(x interface{})      // post-order hook; deferred by Walk, so it always runs
}
   391  
   392  // NewBeforeAfterVisitor returns a new BeforeAndAfterVisitor that
   393  // will invoke the functions before and after AST nodes.
   394  func NewBeforeAfterVisitor(before func(x interface{}) bool, after func(x interface{})) *BeforeAfterVisitor {
   395  	return &BeforeAfterVisitor{before, after}
   396  }
   397  
   398  // Walk iterates the AST by calling the functions on the
   399  // BeforeAndAfterVisitor before and after recursing. Contrary to the
   400  // generic Walk, this does not require allocating the visitor from
   401  // heap.
   402  func (vis *BeforeAfterVisitor) Walk(x interface{}) {
   403  	defer vis.after(x)
   404  	if vis.before(x) {
   405  		return
   406  	}
   407  
   408  	switch x := x.(type) {
   409  	case *Module:
   410  		vis.Walk(x.Package)
   411  		for _, i := range x.Imports {
   412  			vis.Walk(i)
   413  		}
   414  		for _, r := range x.Rules {
   415  			vis.Walk(r)
   416  		}
   417  		for _, a := range x.Annotations {
   418  			vis.Walk(a)
   419  		}
   420  		for _, c := range x.Comments {
   421  			vis.Walk(c)
   422  		}
   423  	case *Package:
   424  		vis.Walk(x.Path)
   425  	case *Import:
   426  		vis.Walk(x.Path)
   427  		vis.Walk(x.Alias)
   428  	case *Rule:
   429  		vis.Walk(x.Head)
   430  		vis.Walk(x.Body)
   431  		if x.Else != nil {
   432  			vis.Walk(x.Else)
   433  		}
   434  	case *Head:
   435  		vis.Walk(x.Name)
   436  		vis.Walk(x.Args)
   437  		if x.Key != nil {
   438  			vis.Walk(x.Key)
   439  		}
   440  		if x.Value != nil {
   441  			vis.Walk(x.Value)
   442  		}
   443  	case Body:
   444  		for _, e := range x {
   445  			vis.Walk(e)
   446  		}
   447  	case Args:
   448  		for _, t := range x {
   449  			vis.Walk(t)
   450  		}
   451  	case *Expr:
   452  		switch ts := x.Terms.(type) {
   453  		case *Term, *SomeDecl, *Every:
   454  			vis.Walk(ts)
   455  		case []*Term:
   456  			for _, t := range ts {
   457  				vis.Walk(t)
   458  			}
   459  		}
   460  		for i := range x.With {
   461  			vis.Walk(x.With[i])
   462  		}
   463  	case *With:
   464  		vis.Walk(x.Target)
   465  		vis.Walk(x.Value)
   466  	case *Term:
   467  		vis.Walk(x.Value)
   468  	case Ref:
   469  		for _, t := range x {
   470  			vis.Walk(t)
   471  		}
   472  	case *object:
   473  		x.Foreach(func(k, v *Term) {
   474  			vis.Walk(k)
   475  			vis.Walk(x.Get(k))
   476  		})
   477  	case *Array:
   478  		x.Foreach(func(t *Term) {
   479  			vis.Walk(t)
   480  		})
   481  	case Set:
   482  		for _, t := range x.Slice() {
   483  			vis.Walk(t)
   484  		}
   485  	case *ArrayComprehension:
   486  		vis.Walk(x.Term)
   487  		vis.Walk(x.Body)
   488  	case *ObjectComprehension:
   489  		vis.Walk(x.Key)
   490  		vis.Walk(x.Value)
   491  		vis.Walk(x.Body)
   492  	case *SetComprehension:
   493  		vis.Walk(x.Term)
   494  		vis.Walk(x.Body)
   495  	case Call:
   496  		for _, t := range x {
   497  			vis.Walk(t)
   498  		}
   499  	case *Every:
   500  		if x.Key != nil {
   501  			vis.Walk(x.Key)
   502  		}
   503  		vis.Walk(x.Value)
   504  		vis.Walk(x.Domain)
   505  		vis.Walk(x.Body)
   506  	}
   507  }
   508  
// VarVisitor walks AST nodes under a given node and collects all encountered
// variables. Which parts of the AST are walked, and therefore which
// variables are collected, is controlled by the VarVisitorParams passed to
// WithParams.
type VarVisitor struct {
	params VarVisitorParams // Skip* settings consulted by visit
	vars   VarSet           // vars collected so far; returned by Vars
}
   516  
// VarVisitorParams contains settings for a VarVisitor. Each flag prunes part
// of the AST from the walk, so variables found only there are not collected.
// The zero value (the default for NewVarVisitor) walks everything. When
// several flags match the same node, the first check in VarVisitor.visit
// wins.
type VarVisitorParams struct {
	SkipRefHead     bool // skip the first term of every Ref; visit the rest
	SkipRefCallHead bool // skip the head of call operator refs and of ref-valued With targets/values
	SkipObjectKeys  bool // visit object values but not object keys
	SkipClosures    bool // skip comprehensions; for `every` expressions visit only the domain
	SkipWithTarget  bool // visit With values but not With targets
	SkipSets        bool // skip sets entirely
}
   526  
   527  // NewVarVisitor returns a new VarVisitor object.
   528  func NewVarVisitor() *VarVisitor {
   529  	return &VarVisitor{
   530  		vars: NewVarSet(),
   531  	}
   532  }
   533  
// WithParams sets the parameters in params on vis and returns vis, so the
// call can be chained after NewVarVisitor. Any previously set params are
// replaced as a whole, not merged.
func (vis *VarVisitor) WithParams(params VarVisitorParams) *VarVisitor {
	vis.params = params
	return vis
}
   539  
// Vars returns a VarSet that contains collected vars.
// NOTE(review): this returns the visitor's own set, not a copy. If VarSet is
// a reference type (presumably a map; confirm), vars added by later Walk
// calls also show up in a previously returned set.
func (vis *VarVisitor) Vars() VarSet {
	return vis.vars
}
   544  
// visit is the pre-order hook of VarVisitor.Walk. If v is a Var, it is added
// to the collected set.
//
// visit returns true when Walk must NOT recurse into v's children. That
// happens either because the whole branch is skipped, or because visit has
// already walked (via vis.Walk) the subset of children that should be
// visited. It returns false to let Walk descend into v normally.
//
// The Skip* checks run in the order below, and the first one that matches v
// returns. So when several params apply to the same node, only the first is
// used. For example, for a *With, SkipWithTarget takes precedence over
// SkipRefCallHead. The partial walks go through vis.Walk, so all params still
// apply to the children that are visited.
func (vis *VarVisitor) visit(v interface{}) bool {
	if vis.params.SkipObjectKeys {
		// Visit object values only; keys are never walked. The callback's v
		// shadows the outer v and is the entry's value.
		if o, ok := v.(Object); ok {
			o.Foreach(func(k, v *Term) {
				vis.Walk(v)
			})
			return true
		}
	}
	if vis.params.SkipRefHead {
		// Skip the first term of the ref and visit the remaining terms.
		// NOTE(review): r[1:] panics on an empty Ref; presumably refs are
		// never empty here. Confirm.
		if r, ok := v.(Ref); ok {
			for _, t := range r[1:] {
				vis.Walk(t)
			}
			return true
		}
	}
	if vis.params.SkipClosures {
		switch v := v.(type) {
		case *ArrayComprehension, *ObjectComprehension, *SetComprehension:
			// Comprehensions are skipped entirely.
			return true
		case *Expr:
			if ev, ok := v.Terms.(*Every); ok {
				vis.Walk(ev.Domain)
				// We're _not_ walking ev.Body -- that's the closure here
				// NOTE(review): ev.Key, ev.Value and the expression's With
				// modifiers are not walked on this path either. Confirm that
				// is intended.
				return true
			}
		}
	}
	if vis.params.SkipWithTarget {
		// Visit only the value of the with modifier, not its target.
		if v, ok := v.(*With); ok {
			vis.Walk(v.Value)
			return true
		}
	}
	if vis.params.SkipSets {
		// Sets are skipped entirely.
		if _, ok := v.(Set); ok {
			return true
		}
	}
	if vis.params.SkipRefCallHead {
		switch v := v.(type) {
		case *Expr:
			// Call expression: walk the operator ref without its head, then
			// the operands, then the with modifiers. Expressions whose Terms
			// is not a []*Term fall through to the default handling.
			// NOTE(review): assumes terms is non-empty and terms[0] holds a
			// Ref; otherwise the index or type assertion panics. Confirm.
			if terms, ok := v.Terms.([]*Term); ok {
				for _, t := range terms[0].Value.(Ref)[1:] {
					vis.Walk(t)
				}
				for i := 1; i < len(terms); i++ {
					vis.Walk(terms[i])
				}
				for _, w := range v.With {
					vis.Walk(w)
				}
				return true
			}
		case Call:
			// Call term: skip the head of the operator ref, then walk the
			// arguments. NOTE(review): also assumes v[0] holds a Ref.
			operator := v[0].Value.(Ref)
			for i := 1; i < len(operator); i++ {
				vis.Walk(operator[i])
			}
			for i := 1; i < len(v); i++ {
				vis.Walk(v[i])
			}
			return true
		case *With:
			// Target: only the tail of a ref target is walked; a non-ref
			// target is not walked at all. Value: the tail if it is a ref,
			// otherwise the whole value.
			if ref, ok := v.Target.Value.(Ref); ok {
				for _, t := range ref[1:] {
					vis.Walk(t)
				}
			}
			if ref, ok := v.Value.Value.(Ref); ok {
				for _, t := range ref[1:] {
					vis.Walk(t)
				}
			} else {
				vis.Walk(v.Value)
			}
			return true
		}
	}
	if v, ok := v.(Var); ok {
		vis.vars.Add(v)
	}
	return false
}
   632  
   633  // Walk iterates the AST by calling the function f on the
   634  // GenericVisitor before recursing. Contrary to the generic Walk, this
   635  // does not require allocating the visitor from heap.
   636  func (vis *VarVisitor) Walk(x interface{}) {
   637  	if vis.visit(x) {
   638  		return
   639  	}
   640  
   641  	switch x := x.(type) {
   642  	case *Module:
   643  		vis.Walk(x.Package)
   644  		for _, i := range x.Imports {
   645  			vis.Walk(i)
   646  		}
   647  		for _, r := range x.Rules {
   648  			vis.Walk(r)
   649  		}
   650  		for _, c := range x.Comments {
   651  			vis.Walk(c)
   652  		}
   653  	case *Package:
   654  		vis.Walk(x.Path)
   655  	case *Import:
   656  		vis.Walk(x.Path)
   657  		vis.Walk(x.Alias)
   658  	case *Rule:
   659  		vis.Walk(x.Head)
   660  		vis.Walk(x.Body)
   661  		if x.Else != nil {
   662  			vis.Walk(x.Else)
   663  		}
   664  	case *Head:
   665  		vis.Walk(x.Name)
   666  		vis.Walk(x.Args)
   667  		if x.Key != nil {
   668  			vis.Walk(x.Key)
   669  		}
   670  		if x.Value != nil {
   671  			vis.Walk(x.Value)
   672  		}
   673  	case Body:
   674  		for _, e := range x {
   675  			vis.Walk(e)
   676  		}
   677  	case Args:
   678  		for _, t := range x {
   679  			vis.Walk(t)
   680  		}
   681  	case *Expr:
   682  		switch ts := x.Terms.(type) {
   683  		case *Term, *SomeDecl, *Every:
   684  			vis.Walk(ts)
   685  		case []*Term:
   686  			for _, t := range ts {
   687  				vis.Walk(t)
   688  			}
   689  		}
   690  		for i := range x.With {
   691  			vis.Walk(x.With[i])
   692  		}
   693  	case *With:
   694  		vis.Walk(x.Target)
   695  		vis.Walk(x.Value)
   696  	case *Term:
   697  		vis.Walk(x.Value)
   698  	case Ref:
   699  		for _, t := range x {
   700  			vis.Walk(t)
   701  		}
   702  	case *object:
   703  		x.Foreach(func(k, v *Term) {
   704  			vis.Walk(k)
   705  			vis.Walk(x.Get(k))
   706  		})
   707  	case *Array:
   708  		x.Foreach(func(t *Term) {
   709  			vis.Walk(t)
   710  		})
   711  	case Set:
   712  		for _, t := range x.Slice() {
   713  			vis.Walk(t)
   714  		}
   715  	case *ArrayComprehension:
   716  		vis.Walk(x.Term)
   717  		vis.Walk(x.Body)
   718  	case *ObjectComprehension:
   719  		vis.Walk(x.Key)
   720  		vis.Walk(x.Value)
   721  		vis.Walk(x.Body)
   722  	case *SetComprehension:
   723  		vis.Walk(x.Term)
   724  		vis.Walk(x.Body)
   725  	case Call:
   726  		for _, t := range x {
   727  			vis.Walk(t)
   728  		}
   729  	case *Every:
   730  		if x.Key != nil {
   731  			vis.Walk(x.Key)
   732  		}
   733  		vis.Walk(x.Value)
   734  		vis.Walk(x.Domain)
   735  		vis.Walk(x.Body)
   736  	}
   737  }