github.com/rajeev159/opa@v0.45.0/topdown/copypropagation/copypropagation.go (about) 1 // Copyright 2018 The OPA Authors. All rights reserved. 2 // Use of this source code is governed by an Apache2 3 // license that can be found in the LICENSE file. 4 5 package copypropagation 6 7 import ( 8 "sort" 9 10 "github.com/open-policy-agent/opa/ast" 11 ) 12 13 // CopyPropagator implements a simple copy propagation optimization to remove 14 // intermediate variables in partial evaluation results. 15 // 16 // For example, given the query: input.x > 1 where 'input' is unknown, the 17 // compiled query would become input.x = a; a > 1 which would remain in the 18 // partial evaluation result. The CopyPropagator will remove the variable 19 // assignment so that partial evaluation simply outputs input.x > 1. 20 // 21 // In many cases, copy propagation can remove all variables from the result of 22 // partial evaluation which simplifies evaluation for non-OPA consumers. 23 // 24 // In some cases, copy propagation cannot remove all variables. If the output of 25 // a built-in call is subsequently used as a ref head, the output variable must 26 // be kept. For example. sort(input, x); x[0] == 1. In this case, copy 27 // propagation cannot replace x[0] == 1 with sort(input, x)[0] == 1 as this is 28 // not legal. 29 type CopyPropagator struct { 30 livevars ast.VarSet // vars that must be preserved in the resulting query 31 sorted []ast.Var // sorted copy of vars to ensure deterministic result 32 ensureNonEmptyBody bool 33 compiler *ast.Compiler 34 } 35 36 // New returns a new CopyPropagator that optimizes queries while preserving vars 37 // in the livevars set. 38 func New(livevars ast.VarSet) *CopyPropagator { 39 40 sorted := make([]ast.Var, 0, len(livevars)) 41 for v := range livevars { 42 sorted = append(sorted, v) 43 } 44 45 sort.Slice(sorted, func(i, j int) bool { 46 return sorted[i].Compare(sorted[j]) < 0 47 }) 48 49 return &CopyPropagator{livevars: livevars, sorted: sorted} 50 } 51 52 // WithEnsureNonEmptyBody configures p to ensure that results are always non-empty. 53 func (p *CopyPropagator) WithEnsureNonEmptyBody(yes bool) *CopyPropagator { 54 p.ensureNonEmptyBody = yes 55 return p 56 } 57 58 // WithCompiler configures the compiler to read from while processing the query. This 59 // should be the same compiler used to compile the original policy. 60 func (p *CopyPropagator) WithCompiler(c *ast.Compiler) *CopyPropagator { 61 p.compiler = c 62 return p 63 } 64 65 // Apply executes the copy propagation optimization and returns a new query. 66 func (p *CopyPropagator) Apply(query ast.Body) ast.Body { 67 68 result := ast.NewBody() 69 70 uf, ok := makeDisjointSets(p.livevars, query) 71 if !ok { 72 return query 73 } 74 75 // Compute set of vars that appear in the head of refs in the query. If a var 76 // is dereferenced, we can plug it with a constant value, but it is not always 77 // optimal to do so. 78 // TODO: Improve the algorithm for when we should plug constants/calls/etc 79 headvars := ast.NewVarSet() 80 ast.WalkRefs(query, func(x ast.Ref) bool { 81 if v, ok := x[0].Value.(ast.Var); ok { 82 if root, ok := uf.Find(v); ok { 83 root.constant = nil 84 headvars.Add(root.key.(ast.Var)) 85 } else { 86 headvars.Add(v) 87 } 88 } 89 return false 90 }) 91 92 removedEqs := ast.NewValueMap() 93 94 for _, expr := range query { 95 96 pctx := &plugContext{ 97 removedEqs: removedEqs, 98 uf: uf, 99 negated: expr.Negated, 100 headvars: headvars, 101 } 102 103 expr = p.plugBindings(pctx, expr) 104 105 if p.updateBindings(pctx, expr) { 106 result.Append(expr) 107 } 108 } 109 110 // Run post-processing step on the query to ensure that all live vars are bound 111 // in the result. The plugging that happens above substitutes all vars in the 112 // same set with the root. 113 // 114 // This step should run before the next step to prevent unnecessary bindings 115 // from being added to the result. For example: 116 // 117 // - Given the following result: <empty> 118 // - Given the following removed equalities: "x = input.x" and "y = input" 119 // - Given the following liveset: {x} 120 // 121 // If this step were to run AFTER the following step, the output would be: 122 // 123 // x = input.x; y = input 124 // 125 // Even though y = input is not required. 126 for _, v := range p.sorted { 127 if root, ok := uf.Find(v); ok { 128 if root.constant != nil { 129 result.Append(ast.Equality.Expr(ast.NewTerm(v), root.constant)) 130 } else if b := removedEqs.Get(root.key); b != nil { 131 result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(b))) 132 } else if root.key != v { 133 result.Append(ast.Equality.Expr(ast.NewTerm(v), ast.NewTerm(root.key))) 134 } 135 } 136 } 137 138 // Run post-processing step on query to ensure that all killed exprs are 139 // accounted for. There are several cases we look for: 140 // 141 // * If an expr is killed but the binding is never used, the query 142 // must still include the expr. For example, given the query 'input.x = a' and 143 // an empty livevar set, the result must include the ref input.x otherwise the 144 // query could be satisfied without input.x being defined. 145 // 146 // * If an expr is killed that provided safety to vars which are not 147 // otherwise being made safe by the current result. 148 // 149 // For any of these cases we re-add the removed equality expression 150 // to the current result. 151 152 // Invariant: Live vars are bound (above) and reserved vars are implicitly ground. 153 safe := ast.ReservedVars.Copy() 154 safe.Update(p.livevars) 155 safe.Update(ast.OutputVarsFromBody(p.compiler, result, safe)) 156 unsafe := result.Vars(ast.SafetyCheckVisitorParams).Diff(safe) 157 158 for _, b := range sortbindings(removedEqs) { 159 removedEq := ast.Equality.Expr(ast.NewTerm(b.k), ast.NewTerm(b.v)) 160 161 providesSafety := false 162 outputVars := ast.OutputVarsFromExpr(p.compiler, removedEq, safe) 163 diff := unsafe.Diff(outputVars) 164 if len(diff) < len(unsafe) { 165 unsafe = diff 166 providesSafety = true 167 } 168 169 if providesSafety || !containedIn(b.v, result) { 170 result.Append(removedEq) 171 safe.Update(outputVars) 172 } 173 } 174 175 if len(unsafe) > 0 { 176 // NOTE(tsandall): This should be impossible but if it does occur, throw 177 // away the result rather than generating unsafe output. 178 return query 179 } 180 181 if p.ensureNonEmptyBody && len(result) == 0 { 182 result = append(result, ast.NewExpr(ast.BooleanTerm(true))) 183 } 184 185 return result 186 } 187 188 // plugBindings applies the binding list and union-find to x. This process 189 // removes as many variables as possible. 190 func (p *CopyPropagator) plugBindings(pctx *plugContext, expr *ast.Expr) *ast.Expr { 191 192 xform := bindingPlugTransform{ 193 pctx: pctx, 194 } 195 196 // Deep copy the expression as it may be mutated during the transform and 197 // the caller running copy propagation may have references to the 198 // expression. Note, the transform does not contain any error paths and 199 // should never return a non-expression value for the root so consider 200 // errors unreachable. 201 x, err := ast.Transform(xform, expr.Copy()) 202 203 if expr, ok := x.(*ast.Expr); !ok || err != nil { 204 panic("unreachable") 205 } else { 206 return expr 207 } 208 } 209 210 type bindingPlugTransform struct { 211 pctx *plugContext 212 } 213 214 func (t bindingPlugTransform) Transform(x interface{}) (interface{}, error) { 215 switch x := x.(type) { 216 case ast.Var: 217 return t.plugBindingsVar(t.pctx, x), nil 218 case ast.Ref: 219 return t.plugBindingsRef(t.pctx, x), nil 220 default: 221 return x, nil 222 } 223 } 224 225 func (t bindingPlugTransform) plugBindingsVar(pctx *plugContext, v ast.Var) ast.Value { 226 227 var result ast.Value = v 228 229 // Apply union-find to remove redundant variables from input. 230 root, ok := pctx.uf.Find(v) 231 if ok { 232 result = root.Value() 233 } 234 235 // Apply binding list to substitute remaining vars. 236 v, ok = result.(ast.Var) 237 if !ok { 238 return result 239 } 240 b := pctx.removedEqs.Get(v) 241 if b == nil { 242 return result 243 } 244 if pctx.negated && !b.IsGround() { 245 return result 246 } 247 248 if r, ok := b.(ast.Ref); ok && r.OutputVars().Contains(v) { 249 return result 250 } 251 252 return b 253 } 254 255 func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.Ref { 256 257 // Apply union-find to remove redundant variables from input. 258 if root, ok := pctx.uf.Find(v[0].Value); ok { 259 v[0].Value = root.Value() 260 } 261 262 result := v 263 264 // Refs require special handling. If the head of the ref was killed, then 265 // the rest of the ref must be concatenated with the new base. 266 if b := pctx.removedEqs.Get(v[0].Value); b != nil { 267 if !pctx.negated || b.IsGround() { 268 var base ast.Ref 269 switch x := b.(type) { 270 case ast.Ref: 271 base = x 272 default: 273 base = ast.Ref{ast.NewTerm(x)} 274 } 275 result = base.Concat(v[1:]) 276 } 277 } 278 279 return result 280 } 281 282 // updateBindings returns false if the expression can be killed. If the 283 // expression is killed, the binding list is updated to map a var to value. 284 func (p *CopyPropagator) updateBindings(pctx *plugContext, expr *ast.Expr) bool { 285 if pctx.negated || len(expr.With) > 0 { 286 return true 287 } 288 if expr.IsEquality() { 289 a, b := expr.Operand(0), expr.Operand(1) 290 if a.Equal(b) { 291 return false 292 } 293 k, v, keep := p.updateBindingsEq(a, b) 294 if !keep { 295 if v != nil { 296 pctx.removedEqs.Put(k, v) 297 } 298 return false 299 } 300 } else if expr.IsCall() { 301 terms := expr.Terms.([]*ast.Term) 302 if p.compiler.GetArity(expr.Operator()) == len(terms)-2 { // with captured output 303 output := terms[len(terms)-1] 304 if k, ok := output.Value.(ast.Var); ok && !p.livevars.Contains(k) && !pctx.headvars.Contains(k) { 305 pctx.removedEqs.Put(k, ast.CallTerm(terms[:len(terms)-1]...).Value) 306 return false 307 } 308 } 309 } 310 return !isNoop(expr) 311 } 312 313 func (p *CopyPropagator) updateBindingsEq(a, b *ast.Term) (ast.Var, ast.Value, bool) { 314 k, v, keep := p.updateBindingsEqAsymmetric(a, b) 315 if !keep { 316 return k, v, keep 317 } 318 return p.updateBindingsEqAsymmetric(b, a) 319 } 320 321 func (p *CopyPropagator) updateBindingsEqAsymmetric(a, b *ast.Term) (ast.Var, ast.Value, bool) { 322 k, ok := a.Value.(ast.Var) 323 if !ok || p.livevars.Contains(k) { 324 return "", nil, true 325 } 326 327 switch b.Value.(type) { 328 case ast.Ref, ast.Call: 329 return k, b.Value, false 330 } 331 332 return "", nil, true 333 } 334 335 type plugContext struct { 336 removedEqs *ast.ValueMap 337 uf *unionFind 338 headvars ast.VarSet 339 negated bool 340 } 341 342 type binding struct { 343 k ast.Value 344 v ast.Value 345 } 346 347 func containedIn(value ast.Value, x interface{}) bool { 348 var stop bool 349 switch v := value.(type) { 350 case ast.Ref: 351 ast.WalkRefs(x, func(other ast.Ref) bool { 352 if stop || other.HasPrefix(v) { 353 stop = true 354 return stop 355 } 356 return false 357 }) 358 default: 359 ast.WalkTerms(x, func(other *ast.Term) bool { 360 if stop || other.Value.Compare(v) == 0 { 361 stop = true 362 return stop 363 } 364 return false 365 }) 366 } 367 return stop 368 } 369 370 func sortbindings(bindings *ast.ValueMap) []*binding { 371 sorted := make([]*binding, 0, bindings.Len()) 372 bindings.Iter(func(k ast.Value, v ast.Value) bool { 373 sorted = append(sorted, &binding{k, v}) 374 return false 375 }) 376 sort.Slice(sorted, func(i, j int) bool { 377 return sorted[i].k.Compare(sorted[j].k) < 0 378 }) 379 return sorted 380 } 381 382 // makeDisjointSets builds the union-find structure for the query. The structure 383 // is built by processing all of the equality exprs in the query. Sets represent 384 // vars that must be equal to each other. In addition to vars, each set can have 385 // at most one constant. If the query contains expressions that cannot be 386 // satisfied (e.g., because a set has multiple constants) this function returns 387 // false. 388 func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) { 389 uf := newUnionFind(func(r1, r2 *unionFindRoot) (*unionFindRoot, *unionFindRoot) { 390 if v, ok := r1.key.(ast.Var); ok && livevars.Contains(v) { 391 return r1, r2 392 } 393 return r2, r1 394 }) 395 for _, expr := range query { 396 if expr.IsEquality() && !expr.Negated && len(expr.With) == 0 { 397 a, b := expr.Operand(0), expr.Operand(1) 398 varA, ok1 := a.Value.(ast.Var) 399 varB, ok2 := b.Value.(ast.Var) 400 if ok1 && ok2 { 401 if _, ok := uf.Merge(varA, varB); !ok { 402 return nil, false 403 } 404 } else if ok1 && ast.IsConstant(b.Value) { 405 root := uf.MakeSet(varA) 406 if root.constant != nil && !root.constant.Equal(b) { 407 return nil, false 408 } 409 root.constant = b 410 } else if ok2 && ast.IsConstant(a.Value) { 411 root := uf.MakeSet(varB) 412 if root.constant != nil && !root.constant.Equal(a) { 413 return nil, false 414 } 415 root.constant = a 416 } 417 } 418 } 419 420 return uf, true 421 } 422 423 func isNoop(expr *ast.Expr) bool { 424 425 if !expr.IsCall() && !expr.IsEvery() { 426 term := expr.Terms.(*ast.Term) 427 if !ast.IsConstant(term.Value) { 428 return false 429 } 430 return !ast.Boolean(false).Equal(term.Value) 431 } 432 433 // A==A can be ignored 434 if expr.Operator().Equal(ast.Equal.Ref()) { 435 return expr.Operand(0).Equal(expr.Operand(1)) 436 } 437 438 return false 439 }