github.com/rajeev159/opa@v0.45.0/topdown/copypropagation/unionfind_test.go (about) 1 // Copyright 2020 The OPA Authors. All rights reserved. 2 // Use of this source code is governed by an Apache2 3 // license that can be found in the LICENSE file. 4 5 package copypropagation 6 7 import ( 8 "reflect" 9 "testing" 10 11 "github.com/open-policy-agent/opa/ast" 12 ) 13 14 func TestUnionFindRootValue(t *testing.T) { 15 tests := []struct { 16 name string 17 root unionFindRoot 18 expected ast.Value 19 }{ 20 { 21 name: "var only", 22 root: unionFindRoot{key: ast.Var("foo")}, 23 expected: ast.Var("foo"), 24 }, 25 { 26 name: "const only", 27 root: unionFindRoot{constant: ast.StringTerm("foo")}, 28 expected: ast.String("foo"), 29 }, 30 { 31 name: "const and var", 32 root: unionFindRoot{key: ast.Var("foo"), constant: ast.StringTerm("bar")}, 33 expected: ast.String("bar"), 34 }, 35 } 36 for _, tc := range tests { 37 t.Run(tc.name, func(t *testing.T) { 38 r := &unionFindRoot{ 39 key: tc.root.key, 40 constant: tc.root.constant, 41 } 42 if got := r.Value(); tc.expected.Compare(got) != 0 { 43 t.Errorf("Value() = %v, expected %v", got, tc.expected) 44 } 45 }) 46 } 47 } 48 49 func TestUnionFindMakeSet(t *testing.T) { 50 51 uf := newUnionFind(nil) 52 53 tests := []struct { 54 name string 55 v ast.Value 56 result *unionFindRoot 57 parents map[ast.Value]ast.Value 58 roots map[ast.Value]*unionFindRoot 59 }{ 60 { 61 name: "from empty", 62 v: ast.Var("a"), 63 result: &unionFindRoot{key: ast.Var("a")}, 64 }, 65 { 66 name: "add another var", 67 v: ast.Var("b"), 68 result: &unionFindRoot{key: ast.Var("b")}, 69 }, 70 { 71 name: "add existing", 72 v: ast.Var("b"), 73 result: &unionFindRoot{key: ast.Var("b")}, 74 }, 75 { 76 name: "add ref", 77 v: ast.Ref{ast.StringTerm("foo")}, 78 result: &unionFindRoot{key: ast.Ref{ast.StringTerm("foo")}}, 79 }, 80 { 81 name: "add ref existing", 82 v: ast.Ref{ast.StringTerm("foo")}, 83 result: &unionFindRoot{key: ast.Ref{ast.StringTerm("foo")}}, 84 }, 85 } 86 for _, tc := range tests { 87 t.Run(tc.name, func(t *testing.T) { 88 actual := uf.MakeSet(tc.v) 89 if !reflect.DeepEqual(actual, tc.result) { 90 t.Errorf("MakeSet(%v) = %v, expected %v", tc.v, actual, tc.result) 91 } 92 }) 93 } 94 } 95 96 func TestUnionFindFindEmptyUF(t *testing.T) { 97 uf := newUnionFind(noopUnionFindRank) 98 actual, found := uf.Find(ast.Var("a")) 99 if found || actual != nil { 100 t.Error("Expected Find() to return (nil, false)") 101 } 102 } 103 104 func TestUnionFindFindIsParent(t *testing.T) { 105 uf := newUnionFind(noopUnionFindRank) 106 107 uf.MakeSet(ast.Var("a")) // "a" will have a parent "a" 108 109 actual, found := uf.Find(ast.Var("a")) 110 111 expected := newUnionFindRoot(ast.Var("a")) 112 if !found || actual.Value().Compare(expected.Value()) != 0 { 113 t.Errorf("Expected Find() to return (true, %+v)", expected) 114 } 115 } 116 117 func TestUnionFindFindParent(t *testing.T) { 118 fooBarRef := ast.Ref{ast.StringTerm("foo"), ast.StringTerm("bar"), ast.VarTerm("x")} 119 call := ast.Call{ast.RefTerm(ast.VarTerm("gt")), ast.NumberTerm("1"), ast.VarTerm("x")} 120 121 uf := newUnionFind(noopUnionFindRank) 122 uf.Merge(ast.Var("a"), ast.Var("b")) 123 uf.Merge(ast.Var("b"), ast.Var("c")) 124 uf.Merge(ast.Var("c"), fooBarRef) 125 uf.Merge(fooBarRef, ast.Var("d")) 126 uf.Merge(ast.Var("d"), call) 127 uf.Merge(call, ast.Var("e")) 128 129 actual, found := uf.Find(ast.Var("e")) 130 131 expected := newUnionFindRoot(ast.Var("a")) 132 if !found || actual.Value().Compare(expected.Value()) != 0 { 133 t.Errorf("Expected Find() to return (true, %+v)", expected) 134 } 135 } 136 137 func TestUnionFindMerge(t *testing.T) { 138 uf := newUnionFind(noopUnionFindRank) 139 140 tests := []struct { 141 name string 142 a ast.Value 143 b ast.Value 144 result *unionFindRoot 145 parents map[ast.Value]ast.Value 146 roots map[ast.Value]*unionFindRoot 147 }{ 148 { 149 name: "empty uf", 150 a: ast.Var("a"), 151 b: ast.Var("b"), 152 result: newUnionFindRoot(ast.Var("a")), 153 }, 154 { 155 name: "same values", 156 a: ast.Var("a"), 157 b: ast.Var("a"), 158 result: newUnionFindRoot(ast.Var("a")), 159 }, 160 { 161 name: "same values higher rank result", 162 a: ast.Var("b"), 163 b: ast.Var("b"), 164 result: newUnionFindRoot(ast.Var("a")), 165 }, 166 { 167 name: "transitive", 168 a: ast.Var("b"), 169 b: ast.Var("c"), 170 result: newUnionFindRoot(ast.Var("a")), 171 }, 172 { 173 name: "new roots", 174 a: ast.Var("d"), 175 b: ast.Var("e"), 176 result: newUnionFindRoot(ast.Var("d")), 177 }, 178 { 179 name: "combine roots", 180 a: ast.Var("a"), 181 b: ast.Var("e"), 182 result: newUnionFindRoot(ast.Var("a")), 183 }, 184 { 185 name: "new ref roots", 186 a: ast.Ref{ast.StringTerm("foo"), ast.StringTerm("bar")}, 187 b: ast.Var("x"), 188 result: newUnionFindRoot(ast.Ref{ast.StringTerm("foo"), ast.StringTerm("bar")}), 189 }, 190 { 191 name: "combine ref roots", 192 a: ast.Var("b"), 193 b: ast.Ref{ast.StringTerm("foo"), ast.StringTerm("bar")}, 194 result: newUnionFindRoot(ast.Var("a")), 195 }, 196 } 197 for _, tc := range tests { 198 t.Run(tc.name, func(t *testing.T) { 199 actualRoot, canMerge := uf.Merge(tc.a, tc.b) 200 if !reflect.DeepEqual(actualRoot, tc.result) || !canMerge { 201 t.Errorf("Merge(%v, %v) got = (%v, %v), expected (%v, true)", tc.a, tc.b, actualRoot, canMerge, tc.result) 202 } 203 }) 204 } 205 } 206 207 var noopUnionFindRank = func(a *unionFindRoot, b *unionFindRoot) (*unionFindRoot, *unionFindRoot) { 208 return a, b 209 }