github.com/rajeev159/opa@v0.45.0/topdown/copypropagation/unionfind_test.go (about)

     1  // Copyright 2020 The OPA Authors.  All rights reserved.
     2  // Use of this source code is governed by an Apache2
     3  // license that can be found in the LICENSE file.
     4  
     5  package copypropagation
     6  
     7  import (
     8  	"reflect"
     9  	"testing"
    10  
    11  	"github.com/open-policy-agent/opa/ast"
    12  )
    13  
    14  func TestUnionFindRootValue(t *testing.T) {
    15  	tests := []struct {
    16  		name     string
    17  		root     unionFindRoot
    18  		expected ast.Value
    19  	}{
    20  		{
    21  			name:     "var only",
    22  			root:     unionFindRoot{key: ast.Var("foo")},
    23  			expected: ast.Var("foo"),
    24  		},
    25  		{
    26  			name:     "const only",
    27  			root:     unionFindRoot{constant: ast.StringTerm("foo")},
    28  			expected: ast.String("foo"),
    29  		},
    30  		{
    31  			name:     "const and var",
    32  			root:     unionFindRoot{key: ast.Var("foo"), constant: ast.StringTerm("bar")},
    33  			expected: ast.String("bar"),
    34  		},
    35  	}
    36  	for _, tc := range tests {
    37  		t.Run(tc.name, func(t *testing.T) {
    38  			r := &unionFindRoot{
    39  				key:      tc.root.key,
    40  				constant: tc.root.constant,
    41  			}
    42  			if got := r.Value(); tc.expected.Compare(got) != 0 {
    43  				t.Errorf("Value() = %v, expected %v", got, tc.expected)
    44  			}
    45  		})
    46  	}
    47  }
    48  
    49  func TestUnionFindMakeSet(t *testing.T) {
    50  
    51  	uf := newUnionFind(nil)
    52  
    53  	tests := []struct {
    54  		name    string
    55  		v       ast.Value
    56  		result  *unionFindRoot
    57  		parents map[ast.Value]ast.Value
    58  		roots   map[ast.Value]*unionFindRoot
    59  	}{
    60  		{
    61  			name:   "from empty",
    62  			v:      ast.Var("a"),
    63  			result: &unionFindRoot{key: ast.Var("a")},
    64  		},
    65  		{
    66  			name:   "add another var",
    67  			v:      ast.Var("b"),
    68  			result: &unionFindRoot{key: ast.Var("b")},
    69  		},
    70  		{
    71  			name:   "add existing",
    72  			v:      ast.Var("b"),
    73  			result: &unionFindRoot{key: ast.Var("b")},
    74  		},
    75  		{
    76  			name:   "add ref",
    77  			v:      ast.Ref{ast.StringTerm("foo")},
    78  			result: &unionFindRoot{key: ast.Ref{ast.StringTerm("foo")}},
    79  		},
    80  		{
    81  			name:   "add ref existing",
    82  			v:      ast.Ref{ast.StringTerm("foo")},
    83  			result: &unionFindRoot{key: ast.Ref{ast.StringTerm("foo")}},
    84  		},
    85  	}
    86  	for _, tc := range tests {
    87  		t.Run(tc.name, func(t *testing.T) {
    88  			actual := uf.MakeSet(tc.v)
    89  			if !reflect.DeepEqual(actual, tc.result) {
    90  				t.Errorf("MakeSet(%v) = %v, expected %v", tc.v, actual, tc.result)
    91  			}
    92  		})
    93  	}
    94  }
    95  
    96  func TestUnionFindFindEmptyUF(t *testing.T) {
    97  	uf := newUnionFind(noopUnionFindRank)
    98  	actual, found := uf.Find(ast.Var("a"))
    99  	if found || actual != nil {
   100  		t.Error("Expected Find() to return (nil, false)")
   101  	}
   102  }
   103  
   104  func TestUnionFindFindIsParent(t *testing.T) {
   105  	uf := newUnionFind(noopUnionFindRank)
   106  
   107  	uf.MakeSet(ast.Var("a")) // "a" will have a parent "a"
   108  
   109  	actual, found := uf.Find(ast.Var("a"))
   110  
   111  	expected := newUnionFindRoot(ast.Var("a"))
   112  	if !found || actual.Value().Compare(expected.Value()) != 0 {
   113  		t.Errorf("Expected Find() to return (true, %+v)", expected)
   114  	}
   115  }
   116  
   117  func TestUnionFindFindParent(t *testing.T) {
   118  	fooBarRef := ast.Ref{ast.StringTerm("foo"), ast.StringTerm("bar"), ast.VarTerm("x")}
   119  	call := ast.Call{ast.RefTerm(ast.VarTerm("gt")), ast.NumberTerm("1"), ast.VarTerm("x")}
   120  
   121  	uf := newUnionFind(noopUnionFindRank)
   122  	uf.Merge(ast.Var("a"), ast.Var("b"))
   123  	uf.Merge(ast.Var("b"), ast.Var("c"))
   124  	uf.Merge(ast.Var("c"), fooBarRef)
   125  	uf.Merge(fooBarRef, ast.Var("d"))
   126  	uf.Merge(ast.Var("d"), call)
   127  	uf.Merge(call, ast.Var("e"))
   128  
   129  	actual, found := uf.Find(ast.Var("e"))
   130  
   131  	expected := newUnionFindRoot(ast.Var("a"))
   132  	if !found || actual.Value().Compare(expected.Value()) != 0 {
   133  		t.Errorf("Expected Find() to return (true, %+v)", expected)
   134  	}
   135  }
   136  
   137  func TestUnionFindMerge(t *testing.T) {
   138  	uf := newUnionFind(noopUnionFindRank)
   139  
   140  	tests := []struct {
   141  		name    string
   142  		a       ast.Value
   143  		b       ast.Value
   144  		result  *unionFindRoot
   145  		parents map[ast.Value]ast.Value
   146  		roots   map[ast.Value]*unionFindRoot
   147  	}{
   148  		{
   149  			name:   "empty uf",
   150  			a:      ast.Var("a"),
   151  			b:      ast.Var("b"),
   152  			result: newUnionFindRoot(ast.Var("a")),
   153  		},
   154  		{
   155  			name:   "same values",
   156  			a:      ast.Var("a"),
   157  			b:      ast.Var("a"),
   158  			result: newUnionFindRoot(ast.Var("a")),
   159  		},
   160  		{
   161  			name:   "same values higher rank result",
   162  			a:      ast.Var("b"),
   163  			b:      ast.Var("b"),
   164  			result: newUnionFindRoot(ast.Var("a")),
   165  		},
   166  		{
   167  			name:   "transitive",
   168  			a:      ast.Var("b"),
   169  			b:      ast.Var("c"),
   170  			result: newUnionFindRoot(ast.Var("a")),
   171  		},
   172  		{
   173  			name:   "new roots",
   174  			a:      ast.Var("d"),
   175  			b:      ast.Var("e"),
   176  			result: newUnionFindRoot(ast.Var("d")),
   177  		},
   178  		{
   179  			name:   "combine roots",
   180  			a:      ast.Var("a"),
   181  			b:      ast.Var("e"),
   182  			result: newUnionFindRoot(ast.Var("a")),
   183  		},
   184  		{
   185  			name:   "new ref roots",
   186  			a:      ast.Ref{ast.StringTerm("foo"), ast.StringTerm("bar")},
   187  			b:      ast.Var("x"),
   188  			result: newUnionFindRoot(ast.Ref{ast.StringTerm("foo"), ast.StringTerm("bar")}),
   189  		},
   190  		{
   191  			name:   "combine ref roots",
   192  			a:      ast.Var("b"),
   193  			b:      ast.Ref{ast.StringTerm("foo"), ast.StringTerm("bar")},
   194  			result: newUnionFindRoot(ast.Var("a")),
   195  		},
   196  	}
   197  	for _, tc := range tests {
   198  		t.Run(tc.name, func(t *testing.T) {
   199  			actualRoot, canMerge := uf.Merge(tc.a, tc.b)
   200  			if !reflect.DeepEqual(actualRoot, tc.result) || !canMerge {
   201  				t.Errorf("Merge(%v, %v) got = (%v, %v), expected (%v, true)", tc.a, tc.b, actualRoot, canMerge, tc.result)
   202  			}
   203  		})
   204  	}
   205  }
   206  
   207  var noopUnionFindRank = func(a *unionFindRoot, b *unionFindRoot) (*unionFindRoot, *unionFindRoot) {
   208  	return a, b
   209  }