github.com/rajeev159/opa@v0.45.0/topdown/copypropagation/copypropagation.go (about)

     1  // Copyright 2018 The OPA Authors.  All rights reserved.
     2  // Use of this source code is governed by an Apache2
     3  // license that can be found in the LICENSE file.
     4  
     5  package copypropagation
     6  
     7  import (
     8  	"sort"
     9  
    10  	"github.com/open-policy-agent/opa/ast"
    11  )
    12  
// CopyPropagator implements a simple copy propagation optimization to remove
// intermediate variables in partial evaluation results.
//
// For example, given the query: input.x > 1 where 'input' is unknown, the
// compiled query would become input.x = a; a > 1 which would remain in the
// partial evaluation result. The CopyPropagator will remove the variable
// assignment so that partial evaluation simply outputs input.x > 1.
//
// In many cases, copy propagation can remove all variables from the result of
// partial evaluation which simplifies evaluation for non-OPA consumers.
//
// In some cases, copy propagation cannot remove all variables. If the output of
// a built-in call is subsequently used as a ref head, the output variable must
// be kept. For example, sort(input, x); x[0] == 1. In this case, copy
// propagation cannot replace x[0] == 1 with sort(input, x)[0] == 1 as this is
// not legal.
//
// A CopyPropagator is created with New and configured with the With* methods
// before Apply is called.
type CopyPropagator struct {
	livevars           ast.VarSet    // vars that must be preserved in the resulting query
	sorted             []ast.Var     // sorted copy of vars to ensure deterministic result
	ensureNonEmptyBody bool          // when set, Apply emits a single `true` expr instead of an empty body
	compiler           *ast.Compiler // used by Apply to look up function arity and compute output vars
}
    35  
    36  // New returns a new CopyPropagator that optimizes queries while preserving vars
    37  // in the livevars set.
    38  func New(livevars ast.VarSet) *CopyPropagator {
    39  
    40  	sorted := make([]ast.Var, 0, len(livevars))
    41  	for v := range livevars {
    42  		sorted = append(sorted, v)
    43  	}
    44  
    45  	sort.Slice(sorted, func(i, j int) bool {
    46  		return sorted[i].Compare(sorted[j]) < 0
    47  	})
    48  
    49  	return &CopyPropagator{livevars: livevars, sorted: sorted}
    50  }
    51  
// WithEnsureNonEmptyBody configures p to ensure that results are always non-empty.
// When enabled, Apply returns a body holding the single expression `true`
// in place of an empty body. It returns p so that calls can be chained.
func (p *CopyPropagator) WithEnsureNonEmptyBody(yes bool) *CopyPropagator {
	p.ensureNonEmptyBody = yes
	return p
}
    57  
// WithCompiler configures the compiler to read from while processing the query. This
// should be the same compiler used to compile the original policy.
//
// Apply calls methods on the compiler (function arity lookup, output var
// computation), so a compiler should be set before Apply is used. It returns p
// so that calls can be chained.
func (p *CopyPropagator) WithCompiler(c *ast.Compiler) *CopyPropagator {
	p.compiler = c
	return p
}
    64  
    65  // Apply executes the copy propagation optimization and returns a new query.
    66  func (p *CopyPropagator) Apply(query ast.Body) ast.Body {
    67  
    68  	result := ast.NewBody()
    69  
    70  	uf, ok := makeDisjointSets(p.livevars, query)
    71  	if !ok {
    72  		return query
    73  	}
    74  
    75  	// Compute set of vars that appear in the head of refs in the query. If a var
    76  	// is dereferenced, we can plug it with a constant value, but it is not always
    77  	// optimal to do so.
    78  	// TODO: Improve the algorithm for when we should plug constants/calls/etc
    79  	headvars := ast.NewVarSet()
    80  	ast.WalkRefs(query, func(x ast.Ref) bool {
    81  		if v, ok := x[0].Value.(ast.Var); ok {
    82  			if root, ok := uf.Find(v); ok {
    83  				root.constant = nil
    84  				headvars.Add(root.key.(ast.Var))
    85  			} else {
    86  				headvars.Add(v)
    87  			}
    88  		}
    89  		return false
    90  	})
    91  
    92  	removedEqs := ast.NewValueMap()
    93  
    94  	for _, expr := range query {
    95  
    96  		pctx := &plugContext{
    97  			removedEqs: removedEqs,
    98  			uf:         uf,
    99  			negated:    expr.Negated,
   100  			headvars:   headvars,
   101  		}
   102  
   103  		expr = p.plugBindings(pctx, expr)
   104  
   105  		if p.updateBindings(pctx, expr) {
   106  			result.Append(expr)
   107  		}
   108  	}
   109  
   110  	// Run post-processing step on the query to ensure that all live vars are bound
   111  	// in the result. The plugging that happens above substitutes all vars in the
   112  	// same set with the root.
   113  	//
   114  	// This step should run before the next step to prevent unnecessary bindings
   115  	// from being added to the result. For example:
   116  	//
   117  	// - Given the following result: <empty>
   118  	// - Given the following removed equalities: "x = input.x" and "y = input"
   119  	// - Given the following liveset: {x}
   120  	//
   121  	// If this step were to run AFTER the following step, the output would be:
   122  	//
   123  	//	x = input.x; y = input
   124  	//
   125  	// Even though y = input is not required.
   126  	for _, v := range p.sorted {
   127  		if root, ok := uf.Find(v); ok {
   128  			if root.constant != nil {
   129  				result.Append(ast.Equality.Expr(ast.NewTerm(v), root.constant))
   130  			} else if b := removedEqs.Get(root.key); b != nil {
   131  				result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(b)))
   132  			} else if root.key != v {
   133  				result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(root.key)))
   134  			}
   135  		}
   136  	}
   137  
   138  	// Run post-processing step on query to ensure that all killed exprs are
   139  	// accounted for. There are several cases we look for:
   140  	//
   141  	// * If an expr is killed but the binding is never used, the query
   142  	//   must still include the expr. For example, given the query 'input.x = a' and
   143  	//   an empty livevar set, the result must include the ref input.x otherwise the
   144  	//   query could be satisfied without input.x being defined.
   145  	//
   146  	// * If an expr is killed that provided safety to vars which are not
   147  	//   otherwise being made safe by the current result.
   148  	//
   149  	// For any of these cases we re-add the removed equality expression
   150  	// to the current result.
   151  
   152  	// Invariant: Live vars are bound (above) and reserved vars are implicitly ground.
   153  	safe := ast.ReservedVars.Copy()
   154  	safe.Update(p.livevars)
   155  	safe.Update(ast.OutputVarsFromBody(p.compiler, result, safe))
   156  	unsafe := result.Vars(ast.SafetyCheckVisitorParams).Diff(safe)
   157  
   158  	for _, b := range sortbindings(removedEqs) {
   159  		removedEq := ast.Equality.Expr(ast.NewTerm(b.k), ast.NewTerm(b.v))
   160  
   161  		providesSafety := false
   162  		outputVars := ast.OutputVarsFromExpr(p.compiler, removedEq, safe)
   163  		diff := unsafe.Diff(outputVars)
   164  		if len(diff) < len(unsafe) {
   165  			unsafe = diff
   166  			providesSafety = true
   167  		}
   168  
   169  		if providesSafety || !containedIn(b.v, result) {
   170  			result.Append(removedEq)
   171  			safe.Update(outputVars)
   172  		}
   173  	}
   174  
   175  	if len(unsafe) > 0 {
   176  		// NOTE(tsandall): This should be impossible but if it does occur, throw
   177  		// away the result rather than generating unsafe output.
   178  		return query
   179  	}
   180  
   181  	if p.ensureNonEmptyBody && len(result) == 0 {
   182  		result = append(result, ast.NewExpr(ast.BooleanTerm(true)))
   183  	}
   184  
   185  	return result
   186  }
   187  
   188  // plugBindings applies the binding list and union-find to x. This process
   189  // removes as many variables as possible.
   190  func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) *ast.Expr {
   191  
   192  	xform := bindingPlugTransform{
   193  		pctx: pctx,
   194  	}
   195  
   196  	// Deep copy the expression as it may be mutated during the transform and
   197  	// the caller running copy propagation may have references to the
   198  	// expression. Note, the transform does not contain any error paths and
   199  	// should never return a non-expression value for the root so consider
   200  	// errors unreachable.
   201  	x, err := ast.Transform(xform, expr.Copy())
   202  
   203  	if expr, ok := x.(*ast.Expr); !ok || err != nil {
   204  		panic("unreachable")
   205  	} else {
   206  		return expr
   207  	}
   208  }
   209  
// bindingPlugTransform is the transformer that plugBindings hands to
// ast.Transform. It rewrites ast.Var and ast.Ref values using the state held
// in pctx and passes every other value through unchanged.
type bindingPlugTransform struct {
	pctx *plugContext // plugging state for the expression being transformed
}
   213  
   214  func (t bindingPlugTransform) Transform(x interface{}) (interface{}, error) {
   215  	switch x := x.(type) {
   216  	case ast.Var:
   217  		return t.plugBindingsVar(t.pctx, x), nil
   218  	case ast.Ref:
   219  		return t.plugBindingsRef(t.pctx, x), nil
   220  	default:
   221  		return x, nil
   222  	}
   223  }
   224  
   225  func (t bindingPlugTransform) plugBindingsVar(pctx *plugContext, v ast.Var) ast.Value {
   226  
   227  	var result ast.Value = v
   228  
   229  	// Apply union-find to remove redundant variables from input.
   230  	root, ok := pctx.uf.Find(v)
   231  	if ok {
   232  		result = root.Value()
   233  	}
   234  
   235  	// Apply binding list to substitute remaining vars.
   236  	v, ok = result.(ast.Var)
   237  	if !ok {
   238  		return result
   239  	}
   240  	b := pctx.removedEqs.Get(v)
   241  	if b == nil {
   242  		return result
   243  	}
   244  	if pctx.negated && !b.IsGround() {
   245  		return result
   246  	}
   247  
   248  	if r, ok := b.(ast.Ref); ok && r.OutputVars().Contains(v) {
   249  		return result
   250  	}
   251  
   252  	return b
   253  }
   254  
   255  func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.Ref {
   256  
   257  	// Apply union-find to remove redundant variables from input.
   258  	if root, ok := pctx.uf.Find(v[0].Value); ok {
   259  		v[0].Value = root.Value()
   260  	}
   261  
   262  	result := v
   263  
   264  	// Refs require special handling. If the head of the ref was killed, then
   265  	// the rest of the ref must be concatenated with the new base.
   266  	if b := pctx.removedEqs.Get(v[0].Value); b != nil {
   267  		if !pctx.negated || b.IsGround() {
   268  			var base ast.Ref
   269  			switch x := b.(type) {
   270  			case ast.Ref:
   271  				base = x
   272  			default:
   273  				base = ast.Ref{ast.NewTerm(x)}
   274  			}
   275  			result = base.Concat(v[1:])
   276  		}
   277  	}
   278  
   279  	return result
   280  }
   281  
   282  // updateBindings returns false if the expression can be killed. If the
   283  // expression is killed, the binding list is updated to map a var to value.
   284  func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool {
   285  	if pctx.negated || len(expr.With) > 0 {
   286  		return true
   287  	}
   288  	if expr.IsEquality() {
   289  		a, b := expr.Operand(0), expr.Operand(1)
   290  		if a.Equal(b) {
   291  			return false
   292  		}
   293  		k, v, keep := p.updateBindingsEq(a, b)
   294  		if !keep {
   295  			if v != nil {
   296  				pctx.removedEqs.Put(k, v)
   297  			}
   298  			return false
   299  		}
   300  	} else if expr.IsCall() {
   301  		terms := expr.Terms.([]*ast.Term)
   302  		if p.compiler.GetArity(expr.Operator()) == len(terms)-2 { // with captured output
   303  			output := terms[len(terms)-1]
   304  			if k, ok := output.Value.(ast.Var); ok && !p.livevars.Contains(k) && !pctx.headvars.Contains(k) {
   305  				pctx.removedEqs.Put(k, ast.CallTerm(terms[:len(terms)-1]...).Value)
   306  				return false
   307  			}
   308  		}
   309  	}
   310  	return !isNoop(expr)
   311  }
   312  
   313  func (p *CopyPropagator) updateBindingsEq(a, b *ast.Term) (ast.Var, ast.Value, bool) {
   314  	k, v, keep := p.updateBindingsEqAsymmetric(a, b)
   315  	if !keep {
   316  		return k, v, keep
   317  	}
   318  	return p.updateBindingsEqAsymmetric(b, a)
   319  }
   320  
   321  func (p *CopyPropagator) updateBindingsEqAsymmetric(a, b *ast.Term) (ast.Var, ast.Value, bool) {
   322  	k, ok := a.Value.(ast.Var)
   323  	if !ok || p.livevars.Contains(k) {
   324  		return "", nil, true
   325  	}
   326  
   327  	switch b.Value.(type) {
   328  	case ast.Ref, ast.Call:
   329  		return k, b.Value, false
   330  	}
   331  
   332  	return "", nil, true
   333  }
   334  
// plugContext carries the state needed to plug a single expression and to
// record what was learned from it. Apply creates one per expression; the
// removedEqs, uf and headvars values are shared across all expressions of a
// query.
type plugContext struct {
	removedEqs *ast.ValueMap // var -> ref/call value recorded for each killed expression
	uf         *unionFind    // sets of equal vars built by makeDisjointSets
	headvars   ast.VarSet    // vars (or their set roots) appearing as the head of a ref in the query
	negated    bool          // whether the expression being processed is negated
}
   341  
// binding is one key/value entry of a removed-equality map, used by
// sortbindings to iterate over the entries in a stable order.
type binding struct {
	k ast.Value // key of the entry (the bound var)
	v ast.Value // value the key was bound to
}
   346  
   347  func containedIn(value ast.Value, x interface{}) bool {
   348  	var stop bool
   349  	switch v := value.(type) {
   350  	case ast.Ref:
   351  		ast.WalkRefs(x, func(other ast.Ref) bool {
   352  			if stop || other.HasPrefix(v) {
   353  				stop = true
   354  				return stop
   355  			}
   356  			return false
   357  		})
   358  	default:
   359  		ast.WalkTerms(x, func(other *ast.Term) bool {
   360  			if stop || other.Value.Compare(v) == 0 {
   361  				stop = true
   362  				return stop
   363  			}
   364  			return false
   365  		})
   366  	}
   367  	return stop
   368  }
   369  
   370  func sortbindings(bindings *ast.ValueMap) []*binding {
   371  	sorted := make([]*binding, 0, bindings.Len())
   372  	bindings.Iter(func(k ast.Value, v ast.Value) bool {
   373  		sorted = append(sorted, &binding{k, v})
   374  		return false
   375  	})
   376  	sort.Slice(sorted, func(i, j int) bool {
   377  		return sorted[i].k.Compare(sorted[j].k) < 0
   378  	})
   379  	return sorted
   380  }
   381  
   382  // makeDisjointSets builds the union-find structure for the query. The structure
   383  // is built by processing all of the equality exprs in the query. Sets represent
   384  // vars that must be equal to each other. In addition to vars, each set can have
   385  // at most one constant. If the query contains expressions that cannot be
   386  // satisfied (e.g., because a set has multiple constants) this function returns
   387  // false.
   388  func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
   389  	uf := newUnionFind(func(r1, r2 *unionFindRoot) (*unionFindRoot, *unionFindRoot) {
   390  		if v, ok := r1.key.(ast.Var); ok && livevars.Contains(v) {
   391  			return r1, r2
   392  		}
   393  		return r2, r1
   394  	})
   395  	for _, expr := range query {
   396  		if expr.IsEquality() && !expr.Negated && len(expr.With) == 0 {
   397  			a, b := expr.Operand(0), expr.Operand(1)
   398  			varA, ok1 := a.Value.(ast.Var)
   399  			varB, ok2 := b.Value.(ast.Var)
   400  			if ok1 && ok2 {
   401  				if _, ok := uf.Merge(varA, varB); !ok {
   402  					return nil, false
   403  				}
   404  			} else if ok1 && ast.IsConstant(b.Value) {
   405  				root := uf.MakeSet(varA)
   406  				if root.constant != nil && !root.constant.Equal(b) {
   407  					return nil, false
   408  				}
   409  				root.constant = b
   410  			} else if ok2 && ast.IsConstant(a.Value) {
   411  				root := uf.MakeSet(varB)
   412  				if root.constant != nil && !root.constant.Equal(a) {
   413  					return nil, false
   414  				}
   415  				root.constant = a
   416  			}
   417  		}
   418  	}
   419  
   420  	return uf, true
   421  }
   422  
   423  func isNoop(expr *ast.Expr) bool {
   424  
   425  	if !expr.IsCall() && !expr.IsEvery() {
   426  		term := expr.Terms.(*ast.Term)
   427  		if !ast.IsConstant(term.Value) {
   428  			return false
   429  		}
   430  		return !ast.Boolean(false).Equal(term.Value)
   431  	}
   432  
   433  	// A==A can be ignored
   434  	if expr.Operator().Equal(ast.Equal.Ref()) {
   435  		return expr.Operand(0).Equal(expr.Operand(1))
   436  	}
   437  
   438  	return false
   439  }