github.com/rajeev159/opa@v0.45.0/topdown/copypropagation/unionfind.go (about)

     1  // Copyright 2020 The OPA Authors.  All rights reserved.
     2  // Use of this source code is governed by an Apache2
     3  // license that can be found in the LICENSE file.
     4  
     5  package copypropagation
     6  
     7  import (
     8  	"fmt"
     9  
    10  	"github.com/open-policy-agent/opa/ast"
    11  	"github.com/open-policy-agent/opa/util"
    12  )
    13  
    14  type rankFunc func(*unionFindRoot, *unionFindRoot) (*unionFindRoot, *unionFindRoot)
    15  
    16  type unionFind struct {
    17  	roots   *util.HashMap
    18  	parents *ast.ValueMap
    19  	rank    rankFunc
    20  }
    21  
    22  func newUnionFind(rank rankFunc) *unionFind {
    23  	return &unionFind{
    24  		roots: util.NewHashMap(func(a util.T, b util.T) bool {
    25  			return a.(ast.Value).Compare(b.(ast.Value)) == 0
    26  		}, func(v util.T) int {
    27  			return v.(ast.Value).Hash()
    28  		}),
    29  		parents: ast.NewValueMap(),
    30  		rank:    rank,
    31  	}
    32  }
    33  
    34  func (uf *unionFind) MakeSet(v ast.Value) *unionFindRoot {
    35  
    36  	root, ok := uf.Find(v)
    37  	if ok {
    38  		return root
    39  	}
    40  
    41  	root = newUnionFindRoot(v)
    42  	uf.parents.Put(v, v)
    43  	uf.roots.Put(v, root)
    44  	return root
    45  }
    46  
    47  func (uf *unionFind) Find(v ast.Value) (*unionFindRoot, bool) {
    48  
    49  	parent := uf.parents.Get(v)
    50  	if parent == nil {
    51  		return nil, false
    52  	}
    53  
    54  	if parent.Compare(v) == 0 {
    55  		r, ok := uf.roots.Get(v)
    56  		return r.(*unionFindRoot), ok
    57  	}
    58  
    59  	return uf.Find(parent)
    60  }
    61  
    62  func (uf *unionFind) Merge(a, b ast.Value) (*unionFindRoot, bool) {
    63  
    64  	r1 := uf.MakeSet(a)
    65  	r2 := uf.MakeSet(b)
    66  
    67  	if r1 != r2 {
    68  
    69  		r1, r2 = uf.rank(r1, r2)
    70  
    71  		uf.parents.Put(r2.key, r1.key)
    72  		uf.roots.Delete(r2.key)
    73  
    74  		// Sets can have at most one constant value associated with them. When
    75  		// unioning, we must preserve this invariant. If a set has two constants,
    76  		// there will be no way to prove the query.
    77  		if r1.constant != nil && r2.constant != nil && !r1.constant.Equal(r2.constant) {
    78  			return nil, false
    79  		} else if r1.constant == nil {
    80  			r1.constant = r2.constant
    81  		}
    82  	}
    83  
    84  	return r1, true
    85  }
    86  
    87  func (uf *unionFind) String() string {
    88  	o := struct {
    89  		Roots   map[string]interface{}
    90  		Parents map[string]ast.Value
    91  	}{
    92  		map[string]interface{}{},
    93  		map[string]ast.Value{},
    94  	}
    95  
    96  	uf.roots.Iter(func(k util.T, v util.T) bool {
    97  		o.Roots[k.(ast.Value).String()] = struct {
    98  			Constant *ast.Term
    99  			Key      ast.Value
   100  		}{
   101  			v.(*unionFindRoot).constant,
   102  			v.(*unionFindRoot).key,
   103  		}
   104  		return true
   105  	})
   106  
   107  	uf.parents.Iter(func(k ast.Value, v ast.Value) bool {
   108  		o.Parents[k.String()] = v
   109  		return true
   110  	})
   111  
   112  	return string(util.MustMarshalJSON(o))
   113  }
   114  
   115  type unionFindRoot struct {
   116  	key      ast.Value
   117  	constant *ast.Term
   118  }
   119  
   120  func newUnionFindRoot(key ast.Value) *unionFindRoot {
   121  	return &unionFindRoot{
   122  		key: key,
   123  	}
   124  }
   125  
   126  func (r *unionFindRoot) Value() ast.Value {
   127  	if r.constant != nil {
   128  		return r.constant.Value
   129  	}
   130  	return r.key
   131  }
   132  
   133  func (r *unionFindRoot) String() string {
   134  	return fmt.Sprintf("{key: %s, constant: %s", r.key, r.constant)
   135  }