github.com/rajeev159/opa@v0.45.0/topdown/save.go (about)

     1  package topdown
     2  
import (
	"container/list"
	"fmt"
	"sort"
	"strings"

	"github.com/open-policy-agent/opa/ast"
)
    10  
    11  // saveSet contains a stack of terms that are considered 'unknown' during
    12  // partial evaluation. Only var and ref terms (rooted at one of the root
    13  // documents) can be added to the save set. Vars added to the save set are
    14  // namespaced by the binding list they are added with. This means the save set
    15  // can be shared across queries.
    16  type saveSet struct {
    17  	instr *Instrumentation
    18  	l     *list.List
    19  }
    20  
    21  func newSaveSet(ts []*ast.Term, b *bindings, instr *Instrumentation) *saveSet {
    22  	ss := &saveSet{
    23  		l:     list.New(),
    24  		instr: instr,
    25  	}
    26  	ss.Push(ts, b)
    27  	return ss
    28  }
    29  
    30  func (ss *saveSet) Push(ts []*ast.Term, b *bindings) {
    31  	ss.l.PushBack(newSaveSetElem(ts, b))
    32  }
    33  
    34  func (ss *saveSet) Pop() {
    35  	ss.l.Remove(ss.l.Back())
    36  }
    37  
    38  // Contains returns true if the term t is contained in the save set. Non-var and
    39  // non-ref terms are never contained. Ref terms are contained if they share a
    40  // prefix with a ref that was added (in either direction).
    41  func (ss *saveSet) Contains(t *ast.Term, b *bindings) bool {
    42  	if ss != nil {
    43  		ss.instr.startTimer(partialOpSaveSetContains)
    44  		ret := ss.contains(t, b)
    45  		ss.instr.stopTimer(partialOpSaveSetContains)
    46  		return ret
    47  	}
    48  	return false
    49  }
    50  
    51  func (ss *saveSet) contains(t *ast.Term, b *bindings) bool {
    52  	for el := ss.l.Back(); el != nil; el = el.Prev() {
    53  		if el.Value.(*saveSetElem).Contains(t, b) {
    54  			return true
    55  		}
    56  	}
    57  	return false
    58  }
    59  
    60  // ContainsRecursive returns true if the term t is or contains a term that is
    61  // contained in the save set. This function will close over the binding list
    62  // when it encounters vars.
    63  func (ss *saveSet) ContainsRecursive(t *ast.Term, b *bindings) bool {
    64  	if ss != nil {
    65  		ss.instr.startTimer(partialOpSaveSetContainsRec)
    66  		ret := ss.containsrec(t, b)
    67  		ss.instr.stopTimer(partialOpSaveSetContainsRec)
    68  		return ret
    69  	}
    70  	return false
    71  }
    72  
    73  func (ss *saveSet) containsrec(t *ast.Term, b *bindings) bool {
    74  	var found bool
    75  	ast.WalkTerms(t, func(x *ast.Term) bool {
    76  		if _, ok := x.Value.(ast.Var); ok {
    77  			x1, b1 := b.apply(x)
    78  			if x1 != x || b1 != b {
    79  				if ss.containsrec(x1, b1) {
    80  					found = true
    81  				}
    82  			} else if ss.contains(x1, b1) {
    83  				found = true
    84  			}
    85  		}
    86  		return found
    87  	})
    88  	return found
    89  }
    90  
    91  func (ss *saveSet) Vars(caller *bindings) ast.VarSet {
    92  	result := ast.NewVarSet()
    93  	for x := ss.l.Front(); x != nil; x = x.Next() {
    94  		elem := x.Value.(*saveSetElem)
    95  		for _, v := range elem.vars {
    96  			if v, ok := elem.b.PlugNamespaced(v, caller).Value.(ast.Var); ok {
    97  				result.Add(v)
    98  			}
    99  		}
   100  	}
   101  	return result
   102  }
   103  
   104  func (ss *saveSet) String() string {
   105  	var buf []string
   106  
   107  	for x := ss.l.Front(); x != nil; x = x.Next() {
   108  		buf = append(buf, x.Value.(*saveSetElem).String())
   109  	}
   110  
   111  	return "(" + strings.Join(buf, " ") + ")"
   112  }
   113  
   114  type saveSetElem struct {
   115  	refs []ast.Ref
   116  	vars []*ast.Term
   117  	b    *bindings
   118  }
   119  
   120  func newSaveSetElem(ts []*ast.Term, b *bindings) *saveSetElem {
   121  
   122  	var refs []ast.Ref
   123  	var vars []*ast.Term
   124  
   125  	for _, t := range ts {
   126  		switch v := t.Value.(type) {
   127  		case ast.Var:
   128  			vars = append(vars, t)
   129  		case ast.Ref:
   130  			refs = append(refs, v)
   131  		default:
   132  			panic("illegal value")
   133  		}
   134  	}
   135  
   136  	return &saveSetElem{
   137  		b:    b,
   138  		vars: vars,
   139  		refs: refs,
   140  	}
   141  }
   142  
   143  func (sse *saveSetElem) Contains(t *ast.Term, b *bindings) bool {
   144  	switch other := t.Value.(type) {
   145  	case ast.Var:
   146  		return sse.containsVar(t, b)
   147  	case ast.Ref:
   148  		for _, ref := range sse.refs {
   149  			if ref.HasPrefix(other) || other.HasPrefix(ref) {
   150  				return true
   151  			}
   152  		}
   153  		return sse.containsVar(other[0], b)
   154  	}
   155  	return false
   156  }
   157  
   158  func (sse *saveSetElem) String() string {
   159  	return fmt.Sprintf("(refs: %v, vars: %v, b: %v)", sse.refs, sse.vars, sse.b)
   160  }
   161  
   162  func (sse *saveSetElem) containsVar(t *ast.Term, b *bindings) bool {
   163  	if b == sse.b {
   164  		for _, v := range sse.vars {
   165  			if v.Equal(t) {
   166  				return true
   167  			}
   168  		}
   169  	}
   170  	return false
   171  }
   172  
   173  // saveStack contains a stack of queries that represent the result of partial
   174  // evaluation. When partial evaluation completes, the top of the stack
   175  // represents a complete, partially evaluated query that can be saved and
   176  // evaluated later.
   177  //
   178  // The result is stored in a stack so that partial evaluation of a query can be
   179  // paused and then resumed in cases where different queries make up the result
   180  // of partial evaluation, such as when a rule with a default clause is
   181  // partially evaluated. In this case, the partially evaluated rule will be
   182  // output in the support module.
   183  type saveStack struct {
   184  	Stack []saveStackQuery
   185  }
   186  
   187  func newSaveStack() *saveStack {
   188  	return &saveStack{
   189  		Stack: []saveStackQuery{
   190  			{},
   191  		},
   192  	}
   193  }
   194  
   195  func (s *saveStack) PushQuery(query saveStackQuery) {
   196  	s.Stack = append(s.Stack, query)
   197  }
   198  
   199  func (s *saveStack) PopQuery() saveStackQuery {
   200  	last := s.Stack[len(s.Stack)-1]
   201  	s.Stack = s.Stack[:len(s.Stack)-1]
   202  	return last
   203  }
   204  
   205  func (s *saveStack) Peek() saveStackQuery {
   206  	return s.Stack[len(s.Stack)-1]
   207  }
   208  
   209  func (s *saveStack) Push(expr *ast.Expr, b1 *bindings, b2 *bindings) {
   210  	idx := len(s.Stack) - 1
   211  	s.Stack[idx] = append(s.Stack[idx], saveStackElem{expr, b1, b2})
   212  }
   213  
   214  func (s *saveStack) Pop() {
   215  	idx := len(s.Stack) - 1
   216  	query := s.Stack[idx]
   217  	s.Stack[idx] = query[:len(query)-1]
   218  }
   219  
   220  type saveStackQuery []saveStackElem
   221  
   222  func (s saveStackQuery) Plug(b *bindings) ast.Body {
   223  	if len(s) == 0 {
   224  		return ast.NewBody(ast.NewExpr(ast.BooleanTerm(true)))
   225  	}
   226  	result := make(ast.Body, len(s))
   227  	for i := range s {
   228  		expr := s[i].Plug(b)
   229  		result.Set(expr, i)
   230  	}
   231  	return result
   232  }
   233  
   234  type saveStackElem struct {
   235  	Expr *ast.Expr
   236  	B1   *bindings
   237  	B2   *bindings
   238  }
   239  
   240  func (e saveStackElem) Plug(caller *bindings) *ast.Expr {
   241  	if e.B1 == nil && e.B2 == nil {
   242  		return e.Expr
   243  	}
   244  	expr := e.Expr.Copy()
   245  	switch terms := expr.Terms.(type) {
   246  	case []*ast.Term:
   247  		if expr.IsEquality() {
   248  			terms[1] = e.B1.PlugNamespaced(terms[1], caller)
   249  			terms[2] = e.B2.PlugNamespaced(terms[2], caller)
   250  		} else {
   251  			for i := 1; i < len(terms); i++ {
   252  				terms[i] = e.B1.PlugNamespaced(terms[i], caller)
   253  			}
   254  		}
   255  	case *ast.Term:
   256  		expr.Terms = e.B1.PlugNamespaced(terms, caller)
   257  	}
   258  	for i := range expr.With {
   259  		expr.With[i].Value = e.B1.PlugNamespaced(expr.With[i].Value, caller)
   260  	}
   261  	return expr
   262  }
   263  
   264  // saveSupport contains additional partially evaluated policies that are part
   265  // of the output of partial evaluation.
   266  //
   267  // The support structure is accumulated as partial evaluation runs and then
   268  // considered complete once partial evaluation finishes (but not before). This
   269  // differs from partially evaluated queries which are considered complete as
   270  // soon as each one finishes.
   271  type saveSupport struct {
   272  	modules map[string]*ast.Module
   273  }
   274  
   275  func newSaveSupport() *saveSupport {
   276  	return &saveSupport{
   277  		modules: map[string]*ast.Module{},
   278  	}
   279  }
   280  
   281  func (s *saveSupport) List() []*ast.Module {
   282  	result := make([]*ast.Module, 0, len(s.modules))
   283  	for _, module := range s.modules {
   284  		result = append(result, module)
   285  	}
   286  	return result
   287  }
   288  
   289  func (s *saveSupport) Exists(path ast.Ref) bool {
   290  	k := path[:len(path)-1].String()
   291  	module, ok := s.modules[k]
   292  	if !ok {
   293  		return false
   294  	}
   295  	name := ast.Var(path[len(path)-1].Value.(ast.String))
   296  	for _, rule := range module.Rules {
   297  		if rule.Head.Name.Equal(name) {
   298  			return true
   299  		}
   300  	}
   301  	return false
   302  }
   303  
   304  func (s *saveSupport) Insert(path ast.Ref, rule *ast.Rule) {
   305  	pkg := path[:len(path)-1]
   306  	k := pkg.String()
   307  	module, ok := s.modules[k]
   308  	if !ok {
   309  		module = &ast.Module{
   310  			Package: &ast.Package{
   311  				Path: pkg,
   312  			},
   313  		}
   314  		s.modules[k] = module
   315  	}
   316  	rule.Module = module
   317  	module.Rules = append(module.Rules, rule)
   318  }
   319  
   320  // saveRequired returns true if the statement x will result in some expressions
   321  // being saved. This check allows the evaluator to evaluate statements
   322  // completely during partial evaluation as long as they do not depend on any
   323  // kind of unknown value or statements that would generate saves.
   324  func saveRequired(c *ast.Compiler, ic *inliningControl, icIgnoreInternal bool, ss *saveSet, b *bindings, x interface{}, rec bool) bool {
   325  
   326  	var found bool
   327  
   328  	vis := ast.NewGenericVisitor(func(node interface{}) bool {
   329  		if found {
   330  			return found
   331  		}
   332  		switch node := node.(type) {
   333  		case *ast.Expr:
   334  			found = len(node.With) > 0 || ignoreExprDuringPartial(node)
   335  		case *ast.Term:
   336  			switch v := node.Value.(type) {
   337  			case ast.Var:
   338  				// Variables only need to be tested in the node from call site
   339  				// because once traversal recurses into a rule existing unknown
   340  				// variables are out-of-scope.
   341  				if !rec && ss.ContainsRecursive(node, b) {
   342  					found = true
   343  				}
   344  			case ast.Ref:
   345  				if ss.Contains(node, b) {
   346  					found = true
   347  				} else if ic.Disabled(v.ConstantPrefix(), icIgnoreInternal) {
   348  					found = true
   349  				} else {
   350  					for _, rule := range c.GetRulesDynamicWithOpts(v, ast.RulesOptions{IncludeHiddenModules: false}) {
   351  						if saveRequired(c, ic, icIgnoreInternal, ss, b, rule, true) {
   352  							found = true
   353  							break
   354  						}
   355  					}
   356  				}
   357  			}
   358  		}
   359  		return found
   360  	})
   361  
   362  	vis.Walk(x)
   363  
   364  	return found
   365  }
   366  
   367  func ignoreExprDuringPartial(expr *ast.Expr) bool {
   368  	if !expr.IsCall() {
   369  		return false
   370  	}
   371  
   372  	bi, ok := ast.BuiltinMap[expr.Operator().String()]
   373  
   374  	return ok && ignoreDuringPartial(bi)
   375  }
   376  
   377  func ignoreDuringPartial(bi *ast.Builtin) bool {
   378  	for _, ignore := range ast.IgnoreDuringPartialEval {
   379  		if bi == ignore {
   380  			return true
   381  		}
   382  	}
   383  	return false
   384  }
   385  
   386  type inliningControl struct {
   387  	shallow bool
   388  	disable []disableInliningFrame
   389  }
   390  
   391  type disableInliningFrame struct {
   392  	internal bool
   393  	refs     []ast.Ref
   394  }
   395  
   396  func (i *inliningControl) PushDisable(refs []ast.Ref, internal bool) {
   397  	if i == nil {
   398  		return
   399  	}
   400  	i.disable = append(i.disable, disableInliningFrame{
   401  		internal: internal,
   402  		refs:     refs,
   403  	})
   404  }
   405  
   406  func (i *inliningControl) PopDisable() {
   407  	if i == nil {
   408  		return
   409  	}
   410  	i.disable = i.disable[:len(i.disable)-1]
   411  }
   412  
   413  func (i *inliningControl) Disabled(ref ast.Ref, ignoreInternal bool) bool {
   414  	if i == nil {
   415  		return false
   416  	}
   417  	for _, frame := range i.disable {
   418  		if !frame.internal || !ignoreInternal {
   419  			for _, other := range frame.refs {
   420  				if other.HasPrefix(ref) || ref.HasPrefix(other) {
   421  					return true
   422  				}
   423  			}
   424  		}
   425  	}
   426  	return false
   427  }