github.com/rajeev159/opa@v0.45.0/ast/visit.go (about) 1 // Copyright 2016 The OPA Authors. All rights reserved. 2 // Use of this source code is governed by an Apache2 3 // license that can be found in the LICENSE file. 4 5 package ast 6 7 // Visitor defines the interface for iterating AST elements. The Visit function 8 // can return a Visitor w which will be used to visit the children of the AST 9 // element v. If the Visit function returns nil, the children will not be 10 // visited. This is deprecated. 11 type Visitor interface { 12 Visit(v interface{}) (w Visitor) 13 } 14 15 // BeforeAndAfterVisitor wraps Visitor to provide hooks for being called before 16 // and after the AST has been visited. This is deprecated. 17 type BeforeAndAfterVisitor interface { 18 Visitor 19 Before(x interface{}) 20 After(x interface{}) 21 } 22 23 // Walk iterates the AST by calling the Visit function on the Visitor 24 // v for x before recursing. This is deprecated. 25 func Walk(v Visitor, x interface{}) { 26 if bav, ok := v.(BeforeAndAfterVisitor); !ok { 27 walk(v, x) 28 } else { 29 bav.Before(x) 30 defer bav.After(x) 31 walk(bav, x) 32 } 33 } 34 35 // WalkBeforeAndAfter iterates the AST by calling the Visit function on the 36 // Visitor v for x before recursing. This is deprecated. 37 func WalkBeforeAndAfter(v BeforeAndAfterVisitor, x interface{}) { 38 Walk(v, x) 39 } 40 41 func walk(v Visitor, x interface{}) { 42 w := v.Visit(x) 43 if w == nil { 44 return 45 } 46 switch x := x.(type) { 47 case *Module: 48 Walk(w, x.Package) 49 for _, i := range x.Imports { 50 Walk(w, i) 51 } 52 for _, r := range x.Rules { 53 Walk(w, r) 54 } 55 for _, a := range x.Annotations { 56 Walk(w, a) 57 } 58 for _, c := range x.Comments { 59 Walk(w, c) 60 } 61 case *Package: 62 Walk(w, x.Path) 63 case *Import: 64 Walk(w, x.Path) 65 Walk(w, x.Alias) 66 case *Rule: 67 Walk(w, x.Head) 68 Walk(w, x.Body) 69 if x.Else != nil { 70 Walk(w, x.Else) 71 } 72 case *Head: 73 Walk(w, x.Name) 74 Walk(w, x.Args) 75 if x.Key != nil { 76 Walk(w, x.Key) 77 } 78 if x.Value != nil { 79 Walk(w, x.Value) 80 } 81 case Body: 82 for _, e := range x { 83 Walk(w, e) 84 } 85 case Args: 86 for _, t := range x { 87 Walk(w, t) 88 } 89 case *Expr: 90 switch ts := x.Terms.(type) { 91 case *Term, *SomeDecl, *Every: 92 Walk(w, ts) 93 case []*Term: 94 for _, t := range ts { 95 Walk(w, t) 96 } 97 } 98 for i := range x.With { 99 Walk(w, x.With[i]) 100 } 101 case *With: 102 Walk(w, x.Target) 103 Walk(w, x.Value) 104 case *Term: 105 Walk(w, x.Value) 106 case Ref: 107 for _, t := range x { 108 Walk(w, t) 109 } 110 case *object: 111 x.Foreach(func(k, vv *Term) { 112 Walk(w, k) 113 Walk(w, vv) 114 }) 115 case *Array: 116 x.Foreach(func(t *Term) { 117 Walk(w, t) 118 }) 119 case Set: 120 x.Foreach(func(t *Term) { 121 Walk(w, t) 122 }) 123 case *ArrayComprehension: 124 Walk(w, x.Term) 125 Walk(w, x.Body) 126 case *ObjectComprehension: 127 Walk(w, x.Key) 128 Walk(w, x.Value) 129 Walk(w, x.Body) 130 case *SetComprehension: 131 Walk(w, x.Term) 132 Walk(w, x.Body) 133 case Call: 134 for _, t := range x { 135 Walk(w, t) 136 } 137 case *Every: 138 if x.Key != nil { 139 Walk(w, x.Key) 140 } 141 Walk(w, x.Value) 142 Walk(w, x.Domain) 143 Walk(w, x.Body) 144 } 145 } 146 147 // WalkVars calls the function f on all vars under x. If the function f 148 // returns true, AST nodes under the last node will not be visited. 149 func WalkVars(x interface{}, f func(Var) bool) { 150 vis := &GenericVisitor{func(x interface{}) bool { 151 if v, ok := x.(Var); ok { 152 return f(v) 153 } 154 return false 155 }} 156 vis.Walk(x) 157 } 158 159 // WalkClosures calls the function f on all closures under x. If the function f 160 // returns true, AST nodes under the last node will not be visited. 161 func WalkClosures(x interface{}, f func(interface{}) bool) { 162 vis := &GenericVisitor{func(x interface{}) bool { 163 switch x := x.(type) { 164 case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every: 165 return f(x) 166 } 167 return false 168 }} 169 vis.Walk(x) 170 } 171 172 // WalkRefs calls the function f on all references under x. If the function f 173 // returns true, AST nodes under the last node will not be visited. 174 func WalkRefs(x interface{}, f func(Ref) bool) { 175 vis := &GenericVisitor{func(x interface{}) bool { 176 if r, ok := x.(Ref); ok { 177 return f(r) 178 } 179 return false 180 }} 181 vis.Walk(x) 182 } 183 184 // WalkTerms calls the function f on all terms under x. If the function f 185 // returns true, AST nodes under the last node will not be visited. 186 func WalkTerms(x interface{}, f func(*Term) bool) { 187 vis := &GenericVisitor{func(x interface{}) bool { 188 if term, ok := x.(*Term); ok { 189 return f(term) 190 } 191 return false 192 }} 193 vis.Walk(x) 194 } 195 196 // WalkWiths calls the function f on all with modifiers under x. If the function f 197 // returns true, AST nodes under the last node will not be visited. 198 func WalkWiths(x interface{}, f func(*With) bool) { 199 vis := &GenericVisitor{func(x interface{}) bool { 200 if w, ok := x.(*With); ok { 201 return f(w) 202 } 203 return false 204 }} 205 vis.Walk(x) 206 } 207 208 // WalkExprs calls the function f on all expressions under x. If the function f 209 // returns true, AST nodes under the last node will not be visited. 210 func WalkExprs(x interface{}, f func(*Expr) bool) { 211 vis := &GenericVisitor{func(x interface{}) bool { 212 if r, ok := x.(*Expr); ok { 213 return f(r) 214 } 215 return false 216 }} 217 vis.Walk(x) 218 } 219 220 // WalkBodies calls the function f on all bodies under x. If the function f 221 // returns true, AST nodes under the last node will not be visited. 222 func WalkBodies(x interface{}, f func(Body) bool) { 223 vis := &GenericVisitor{func(x interface{}) bool { 224 if b, ok := x.(Body); ok { 225 return f(b) 226 } 227 return false 228 }} 229 vis.Walk(x) 230 } 231 232 // WalkRules calls the function f on all rules under x. If the function f 233 // returns true, AST nodes under the last node will not be visited. 234 func WalkRules(x interface{}, f func(*Rule) bool) { 235 vis := &GenericVisitor{func(x interface{}) bool { 236 if r, ok := x.(*Rule); ok { 237 stop := f(r) 238 // NOTE(tsandall): since rules cannot be embedded inside of queries 239 // we can stop early if there is no else block. 240 if stop || r.Else == nil { 241 return true 242 } 243 } 244 return false 245 }} 246 vis.Walk(x) 247 } 248 249 // WalkNodes calls the function f on all nodes under x. If the function f 250 // returns true, AST nodes under the last node will not be visited. 251 func WalkNodes(x interface{}, f func(Node) bool) { 252 vis := &GenericVisitor{func(x interface{}) bool { 253 if n, ok := x.(Node); ok { 254 return f(n) 255 } 256 return false 257 }} 258 vis.Walk(x) 259 } 260 261 // GenericVisitor provides a utility to walk over AST nodes using a 262 // closure. If the closure returns true, the visitor will not walk 263 // over AST nodes under x. 264 type GenericVisitor struct { 265 f func(x interface{}) bool 266 } 267 268 // NewGenericVisitor returns a new GenericVisitor that will invoke the function 269 // f on AST nodes. 270 func NewGenericVisitor(f func(x interface{}) bool) *GenericVisitor { 271 return &GenericVisitor{f} 272 } 273 274 // Walk iterates the AST by calling the function f on the 275 // GenericVisitor before recursing. Contrary to the generic Walk, this 276 // does not require allocating the visitor from heap. 277 func (vis *GenericVisitor) Walk(x interface{}) { 278 if vis.f(x) { 279 return 280 } 281 282 switch x := x.(type) { 283 case *Module: 284 vis.Walk(x.Package) 285 for _, i := range x.Imports { 286 vis.Walk(i) 287 } 288 for _, r := range x.Rules { 289 vis.Walk(r) 290 } 291 for _, a := range x.Annotations { 292 vis.Walk(a) 293 } 294 for _, c := range x.Comments { 295 vis.Walk(c) 296 } 297 case *Package: 298 vis.Walk(x.Path) 299 case *Import: 300 vis.Walk(x.Path) 301 vis.Walk(x.Alias) 302 case *Rule: 303 vis.Walk(x.Head) 304 vis.Walk(x.Body) 305 if x.Else != nil { 306 vis.Walk(x.Else) 307 } 308 case *Head: 309 vis.Walk(x.Name) 310 vis.Walk(x.Args) 311 if x.Key != nil { 312 vis.Walk(x.Key) 313 } 314 if x.Value != nil { 315 vis.Walk(x.Value) 316 } 317 case Body: 318 for _, e := range x { 319 vis.Walk(e) 320 } 321 case Args: 322 for _, t := range x { 323 vis.Walk(t) 324 } 325 case *Expr: 326 switch ts := x.Terms.(type) { 327 case *Term, *SomeDecl, *Every: 328 vis.Walk(ts) 329 case []*Term: 330 for _, t := range ts { 331 vis.Walk(t) 332 } 333 } 334 for i := range x.With { 335 vis.Walk(x.With[i]) 336 } 337 case *With: 338 vis.Walk(x.Target) 339 vis.Walk(x.Value) 340 case *Term: 341 vis.Walk(x.Value) 342 case Ref: 343 for _, t := range x { 344 vis.Walk(t) 345 } 346 case *object: 347 x.Foreach(func(k, v *Term) { 348 vis.Walk(k) 349 vis.Walk(x.Get(k)) 350 }) 351 case *Array: 352 x.Foreach(func(t *Term) { 353 vis.Walk(t) 354 }) 355 case Set: 356 for _, t := range x.Slice() { 357 vis.Walk(t) 358 } 359 case *ArrayComprehension: 360 vis.Walk(x.Term) 361 vis.Walk(x.Body) 362 case *ObjectComprehension: 363 vis.Walk(x.Key) 364 vis.Walk(x.Value) 365 vis.Walk(x.Body) 366 case *SetComprehension: 367 vis.Walk(x.Term) 368 vis.Walk(x.Body) 369 case Call: 370 for _, t := range x { 371 vis.Walk(t) 372 } 373 case *Every: 374 if x.Key != nil { 375 vis.Walk(x.Key) 376 } 377 vis.Walk(x.Value) 378 vis.Walk(x.Domain) 379 vis.Walk(x.Body) 380 } 381 } 382 383 // BeforeAfterVisitor provides a utility to walk over AST nodes using 384 // closures. If the before closure returns true, the visitor will not 385 // walk over AST nodes under x. The after closure is invoked always 386 // after visiting a node. 387 type BeforeAfterVisitor struct { 388 before func(x interface{}) bool 389 after func(x interface{}) 390 } 391 392 // NewBeforeAfterVisitor returns a new BeforeAndAfterVisitor that 393 // will invoke the functions before and after AST nodes. 394 func NewBeforeAfterVisitor(before func(x interface{}) bool, after func(x interface{})) *BeforeAfterVisitor { 395 return &BeforeAfterVisitor{before, after} 396 } 397 398 // Walk iterates the AST by calling the functions on the 399 // BeforeAndAfterVisitor before and after recursing. Contrary to the 400 // generic Walk, this does not require allocating the visitor from 401 // heap. 402 func (vis *BeforeAfterVisitor) Walk(x interface{}) { 403 defer vis.after(x) 404 if vis.before(x) { 405 return 406 } 407 408 switch x := x.(type) { 409 case *Module: 410 vis.Walk(x.Package) 411 for _, i := range x.Imports { 412 vis.Walk(i) 413 } 414 for _, r := range x.Rules { 415 vis.Walk(r) 416 } 417 for _, a := range x.Annotations { 418 vis.Walk(a) 419 } 420 for _, c := range x.Comments { 421 vis.Walk(c) 422 } 423 case *Package: 424 vis.Walk(x.Path) 425 case *Import: 426 vis.Walk(x.Path) 427 vis.Walk(x.Alias) 428 case *Rule: 429 vis.Walk(x.Head) 430 vis.Walk(x.Body) 431 if x.Else != nil { 432 vis.Walk(x.Else) 433 } 434 case *Head: 435 vis.Walk(x.Name) 436 vis.Walk(x.Args) 437 if x.Key != nil { 438 vis.Walk(x.Key) 439 } 440 if x.Value != nil { 441 vis.Walk(x.Value) 442 } 443 case Body: 444 for _, e := range x { 445 vis.Walk(e) 446 } 447 case Args: 448 for _, t := range x { 449 vis.Walk(t) 450 } 451 case *Expr: 452 switch ts := x.Terms.(type) { 453 case *Term, *SomeDecl, *Every: 454 vis.Walk(ts) 455 case []*Term: 456 for _, t := range ts { 457 vis.Walk(t) 458 } 459 } 460 for i := range x.With { 461 vis.Walk(x.With[i]) 462 } 463 case *With: 464 vis.Walk(x.Target) 465 vis.Walk(x.Value) 466 case *Term: 467 vis.Walk(x.Value) 468 case Ref: 469 for _, t := range x { 470 vis.Walk(t) 471 } 472 case *object: 473 x.Foreach(func(k, v *Term) { 474 vis.Walk(k) 475 vis.Walk(x.Get(k)) 476 }) 477 case *Array: 478 x.Foreach(func(t *Term) { 479 vis.Walk(t) 480 }) 481 case Set: 482 for _, t := range x.Slice() { 483 vis.Walk(t) 484 } 485 case *ArrayComprehension: 486 vis.Walk(x.Term) 487 vis.Walk(x.Body) 488 case *ObjectComprehension: 489 vis.Walk(x.Key) 490 vis.Walk(x.Value) 491 vis.Walk(x.Body) 492 case *SetComprehension: 493 vis.Walk(x.Term) 494 vis.Walk(x.Body) 495 case Call: 496 for _, t := range x { 497 vis.Walk(t) 498 } 499 case *Every: 500 if x.Key != nil { 501 vis.Walk(x.Key) 502 } 503 vis.Walk(x.Value) 504 vis.Walk(x.Domain) 505 vis.Walk(x.Body) 506 } 507 } 508 509 // VarVisitor walks AST nodes under a given node and collects all encountered 510 // variables. The collected variables can be controlled by specifying 511 // VarVisitorParams when creating the visitor. 512 type VarVisitor struct { 513 params VarVisitorParams 514 vars VarSet 515 } 516 517 // VarVisitorParams contains settings for a VarVisitor. 518 type VarVisitorParams struct { 519 SkipRefHead bool 520 SkipRefCallHead bool 521 SkipObjectKeys bool 522 SkipClosures bool 523 SkipWithTarget bool 524 SkipSets bool 525 } 526 527 // NewVarVisitor returns a new VarVisitor object. 528 func NewVarVisitor() *VarVisitor { 529 return &VarVisitor{ 530 vars: NewVarSet(), 531 } 532 } 533 534 // WithParams sets the parameters in params on vis. 535 func (vis *VarVisitor) WithParams(params VarVisitorParams) *VarVisitor { 536 vis.params = params 537 return vis 538 } 539 540 // Vars returns a VarSet that contains collected vars. 541 func (vis *VarVisitor) Vars() VarSet { 542 return vis.vars 543 } 544 545 // visit determines if the VarVisitor will recurse into x: if it returns `true`, 546 // the visitor will _skip_ that branch of the AST 547 func (vis *VarVisitor) visit(v interface{}) bool { 548 if vis.params.SkipObjectKeys { 549 if o, ok := v.(Object); ok { 550 o.Foreach(func(k, v *Term) { 551 vis.Walk(v) 552 }) 553 return true 554 } 555 } 556 if vis.params.SkipRefHead { 557 if r, ok := v.(Ref); ok { 558 for _, t := range r[1:] { 559 vis.Walk(t) 560 } 561 return true 562 } 563 } 564 if vis.params.SkipClosures { 565 switch v := v.(type) { 566 case *ArrayComprehension, *ObjectComprehension, *SetComprehension: 567 return true 568 case *Expr: 569 if ev, ok := v.Terms.(*Every); ok { 570 vis.Walk(ev.Domain) 571 // We're _not_ walking ev.Body -- that's the closure here 572 return true 573 } 574 } 575 } 576 if vis.params.SkipWithTarget { 577 if v, ok := v.(*With); ok { 578 vis.Walk(v.Value) 579 return true 580 } 581 } 582 if vis.params.SkipSets { 583 if _, ok := v.(Set); ok { 584 return true 585 } 586 } 587 if vis.params.SkipRefCallHead { 588 switch v := v.(type) { 589 case *Expr: 590 if terms, ok := v.Terms.([]*Term); ok { 591 for _, t := range terms[0].Value.(Ref)[1:] { 592 vis.Walk(t) 593 } 594 for i := 1; i < len(terms); i++ { 595 vis.Walk(terms[i]) 596 } 597 for _, w := range v.With { 598 vis.Walk(w) 599 } 600 return true 601 } 602 case Call: 603 operator := v[0].Value.(Ref) 604 for i := 1; i < len(operator); i++ { 605 vis.Walk(operator[i]) 606 } 607 for i := 1; i < len(v); i++ { 608 vis.Walk(v[i]) 609 } 610 return true 611 case *With: 612 if ref, ok := v.Target.Value.(Ref); ok { 613 for _, t := range ref[1:] { 614 vis.Walk(t) 615 } 616 } 617 if ref, ok := v.Value.Value.(Ref); ok { 618 for _, t := range ref[1:] { 619 vis.Walk(t) 620 } 621 } else { 622 vis.Walk(v.Value) 623 } 624 return true 625 } 626 } 627 if v, ok := v.(Var); ok { 628 vis.vars.Add(v) 629 } 630 return false 631 } 632 633 // Walk iterates the AST by calling the function f on the 634 // GenericVisitor before recursing. Contrary to the generic Walk, this 635 // does not require allocating the visitor from heap. 636 func (vis *VarVisitor) Walk(x interface{}) { 637 if vis.visit(x) { 638 return 639 } 640 641 switch x := x.(type) { 642 case *Module: 643 vis.Walk(x.Package) 644 for _, i := range x.Imports { 645 vis.Walk(i) 646 } 647 for _, r := range x.Rules { 648 vis.Walk(r) 649 } 650 for _, c := range x.Comments { 651 vis.Walk(c) 652 } 653 case *Package: 654 vis.Walk(x.Path) 655 case *Import: 656 vis.Walk(x.Path) 657 vis.Walk(x.Alias) 658 case *Rule: 659 vis.Walk(x.Head) 660 vis.Walk(x.Body) 661 if x.Else != nil { 662 vis.Walk(x.Else) 663 } 664 case *Head: 665 vis.Walk(x.Name) 666 vis.Walk(x.Args) 667 if x.Key != nil { 668 vis.Walk(x.Key) 669 } 670 if x.Value != nil { 671 vis.Walk(x.Value) 672 } 673 case Body: 674 for _, e := range x { 675 vis.Walk(e) 676 } 677 case Args: 678 for _, t := range x { 679 vis.Walk(t) 680 } 681 case *Expr: 682 switch ts := x.Terms.(type) { 683 case *Term, *SomeDecl, *Every: 684 vis.Walk(ts) 685 case []*Term: 686 for _, t := range ts { 687 vis.Walk(t) 688 } 689 } 690 for i := range x.With { 691 vis.Walk(x.With[i]) 692 } 693 case *With: 694 vis.Walk(x.Target) 695 vis.Walk(x.Value) 696 case *Term: 697 vis.Walk(x.Value) 698 case Ref: 699 for _, t := range x { 700 vis.Walk(t) 701 } 702 case *object: 703 x.Foreach(func(k, v *Term) { 704 vis.Walk(k) 705 vis.Walk(x.Get(k)) 706 }) 707 case *Array: 708 x.Foreach(func(t *Term) { 709 vis.Walk(t) 710 }) 711 case Set: 712 for _, t := range x.Slice() { 713 vis.Walk(t) 714 } 715 case *ArrayComprehension: 716 vis.Walk(x.Term) 717 vis.Walk(x.Body) 718 case *ObjectComprehension: 719 vis.Walk(x.Key) 720 vis.Walk(x.Value) 721 vis.Walk(x.Body) 722 case *SetComprehension: 723 vis.Walk(x.Term) 724 vis.Walk(x.Body) 725 case Call: 726 for _, t := range x { 727 vis.Walk(t) 728 } 729 case *Every: 730 if x.Key != nil { 731 vis.Walk(x.Key) 732 } 733 vis.Walk(x.Value) 734 vis.Walk(x.Domain) 735 vis.Walk(x.Body) 736 } 737 }