github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/developmentmembership/trackingsubjectset_test.go (about)

     1  package developmentmembership
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/require"
     7  
     8  	"github.com/authzed/spicedb/pkg/genutil/mapz"
     9  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    10  	"github.com/authzed/spicedb/pkg/tuple"
    11  )
    12  
    13  func set(subjects ...*core.DirectSubject) *TrackingSubjectSet {
    14  	newSet := NewTrackingSubjectSet()
    15  	for _, subject := range subjects {
    16  		newSet.MustAdd(NewFoundSubject(subject))
    17  	}
    18  	return newSet
    19  }
    20  
    21  func union(firstSet *TrackingSubjectSet, sets ...*TrackingSubjectSet) *TrackingSubjectSet {
    22  	current := firstSet
    23  	for _, set := range sets {
    24  		current.MustAddFrom(set)
    25  	}
    26  
    27  	return current
    28  }
    29  
    30  func intersect(firstSet *TrackingSubjectSet, sets ...*TrackingSubjectSet) *TrackingSubjectSet {
    31  	current := firstSet
    32  	for _, set := range sets {
    33  		current = current.MustIntersect(set)
    34  	}
    35  	return current
    36  }
    37  
    38  func subtract(firstSet *TrackingSubjectSet, sets ...*TrackingSubjectSet) *TrackingSubjectSet {
    39  	current := firstSet
    40  	for _, set := range sets {
    41  		current = current.Exclude(set)
    42  	}
    43  	return current
    44  }
    45  
    46  func fs(subjectType string, subjectID string, subjectRel string, excludedSubjectIDs ...string) FoundSubject {
    47  	excludedSubjects := make([]FoundSubject, 0, len(excludedSubjectIDs))
    48  	for _, excludedSubjectID := range excludedSubjectIDs {
    49  		excludedSubjects = append(excludedSubjects, FoundSubject{subject: ONR(subjectType, excludedSubjectID, subjectRel)})
    50  	}
    51  
    52  	return FoundSubject{
    53  		subject:          ONR(subjectType, subjectID, subjectRel),
    54  		excludedSubjects: excludedSubjects,
    55  		relationships:    tuple.NewONRSet(),
    56  	}
    57  }
    58  
    59  func TestTrackingSubjectSet(t *testing.T) {
    60  	testCases := []struct {
    61  		name     string
    62  		set      *TrackingSubjectSet
    63  		expected []FoundSubject
    64  	}{
    65  		{
    66  			"simple set",
    67  			set(DS("user", "user1", "...")),
    68  			[]FoundSubject{fs("user", "user1", "...")},
    69  		},
    70  		{
    71  			"simple union",
    72  			union(
    73  				set(DS("user", "user1", "...")),
    74  				set(DS("user", "user2", "...")),
    75  				set(DS("user", "user3", "...")),
    76  			),
    77  			[]FoundSubject{
    78  				fs("user", "user1", "..."),
    79  				fs("user", "user2", "..."),
    80  				fs("user", "user3", "..."),
    81  			},
    82  		},
    83  		{
    84  			"simple intersection",
    85  			intersect(
    86  				set(
    87  					(DS("user", "user1", "...")),
    88  					(DS("user", "user2", "...")),
    89  				),
    90  				set(
    91  					(DS("user", "user2", "...")),
    92  					(DS("user", "user3", "...")),
    93  				),
    94  				set(
    95  					(DS("user", "user2", "...")),
    96  					(DS("user", "user4", "...")),
    97  				),
    98  			),
    99  			[]FoundSubject{fs("user", "user2", "...")},
   100  		},
   101  		{
   102  			"empty intersection",
   103  			intersect(
   104  				set(
   105  					(DS("user", "user1", "...")),
   106  					(DS("user", "user2", "...")),
   107  				),
   108  				set(
   109  					(DS("user", "user3", "...")),
   110  					(DS("user", "user4", "...")),
   111  				),
   112  			),
   113  			[]FoundSubject{},
   114  		},
   115  		{
   116  			"simple exclusion",
   117  			subtract(
   118  				set(
   119  					(DS("user", "user1", "...")),
   120  					(DS("user", "user2", "...")),
   121  				),
   122  				set(DS("user", "user2", "...")),
   123  				set(DS("user", "user3", "...")),
   124  			),
   125  			[]FoundSubject{fs("user", "user1", "...")},
   126  		},
   127  		{
   128  			"empty exclusion",
   129  			subtract(
   130  				set(
   131  					(DS("user", "user1", "...")),
   132  					(DS("user", "user2", "...")),
   133  				),
   134  				set(DS("user", "user1", "...")),
   135  				set(DS("user", "user2", "...")),
   136  			),
   137  			[]FoundSubject{},
   138  		},
   139  		{
   140  			"wildcard left side union",
   141  			union(
   142  				set(
   143  					(DS("user", "*", "...")),
   144  				),
   145  				set(DS("user", "user1", "...")),
   146  			),
   147  			[]FoundSubject{
   148  				fs("user", "*", "..."),
   149  				fs("user", "user1", "..."),
   150  			},
   151  		},
   152  		{
   153  			"wildcard right side union",
   154  			union(
   155  				set(DS("user", "user1", "...")),
   156  				set(
   157  					(DS("user", "*", "...")),
   158  				),
   159  			),
   160  			[]FoundSubject{
   161  				fs("user", "*", "..."),
   162  				fs("user", "user1", "..."),
   163  			},
   164  		},
   165  		{
   166  			"wildcard left side exclusion",
   167  			subtract(
   168  				set(
   169  					(DS("user", "*", "...")),
   170  					(DS("user", "user2", "...")),
   171  				),
   172  				set(DS("user", "user1", "...")),
   173  			),
   174  			[]FoundSubject{
   175  				fs("user", "*", "...", "user1"),
   176  				fs("user", "user2", "..."),
   177  			},
   178  		},
   179  		{
   180  			"wildcard right side exclusion",
   181  			subtract(
   182  				set(
   183  					(DS("user", "user2", "...")),
   184  				),
   185  				set(DS("user", "*", "...")),
   186  			),
   187  			[]FoundSubject{},
   188  		},
   189  		{
   190  			"wildcard right side concrete exclusion",
   191  			subtract(
   192  				set(
   193  					(DS("user", "*", "...")),
   194  				),
   195  				set(DS("user", "user1", "...")),
   196  			),
   197  			[]FoundSubject{
   198  				fs("user", "*", "...", "user1"),
   199  			},
   200  		},
   201  		{
   202  			"wildcard both sides exclusion",
   203  			subtract(
   204  				set(
   205  					(DS("user", "user2", "...")),
   206  					(DS("user", "*", "...")),
   207  				),
   208  				set(DS("user", "*", "...")),
   209  			),
   210  			[]FoundSubject{},
   211  		},
   212  		{
   213  			"wildcard left side intersection",
   214  			intersect(
   215  				set(
   216  					(DS("user", "*", "...")),
   217  					(DS("user", "user2", "...")),
   218  				),
   219  				set(DS("user", "user1", "...")),
   220  			),
   221  			[]FoundSubject{
   222  				fs("user", "user1", "..."),
   223  			},
   224  		},
   225  		{
   226  			"wildcard right side intersection",
   227  			intersect(
   228  				set(DS("user", "user1", "...")),
   229  				set(
   230  					(DS("user", "*", "...")),
   231  					(DS("user", "user2", "...")),
   232  				),
   233  			),
   234  			[]FoundSubject{
   235  				fs("user", "user1", "..."),
   236  			},
   237  		},
   238  		{
   239  			"wildcard both sides intersection",
   240  			intersect(
   241  				set(
   242  					(DS("user", "*", "...")),
   243  					(DS("user", "user1", "..."))),
   244  				set(
   245  					(DS("user", "*", "...")),
   246  					(DS("user", "user2", "...")),
   247  				),
   248  			),
   249  			[]FoundSubject{
   250  				fs("user", "*", "..."),
   251  				fs("user", "user1", "..."),
   252  				fs("user", "user2", "..."),
   253  			},
   254  		},
   255  		{
   256  			"wildcard with exclusions union",
   257  			union(
   258  				MustNewTrackingSubjectSetWith(fs("user", "*", "...", "user1")),
   259  				MustNewTrackingSubjectSetWith(fs("user", "*", "...", "user2")),
   260  			),
   261  			[]FoundSubject{
   262  				fs("user", "*", "..."),
   263  			},
   264  		},
   265  		{
   266  			"wildcard with exclusions intersection",
   267  			intersect(
   268  				MustNewTrackingSubjectSetWith(fs("user", "*", "...", "user1")),
   269  				MustNewTrackingSubjectSetWith(fs("user", "*", "...", "user2")),
   270  			),
   271  			[]FoundSubject{
   272  				fs("user", "*", "...", "user1", "user2"),
   273  			},
   274  		},
   275  		{
   276  			"wildcard with exclusions over subtraction",
   277  			subtract(
   278  				MustNewTrackingSubjectSetWith(
   279  					fs("user", "*", "...", "user1"),
   280  				),
   281  				MustNewTrackingSubjectSetWith(fs("user", "*", "...", "user2")),
   282  			),
   283  			[]FoundSubject{
   284  				fs("user", "user2", "..."),
   285  			},
   286  		},
   287  		{
   288  			"wildcard with exclusions excluded user added",
   289  			subtract(
   290  				MustNewTrackingSubjectSetWith(
   291  					fs("user", "*", "...", "user1"),
   292  				),
   293  				MustNewTrackingSubjectSetWith(fs("user", "user2", "...")),
   294  			),
   295  			[]FoundSubject{
   296  				fs("user", "*", "...", "user1", "user2"),
   297  			},
   298  		},
   299  		{
   300  			"wildcard multiple exclusions",
   301  			subtract(
   302  				MustNewTrackingSubjectSetWith(
   303  					fs("user", "*", "...", "user1"),
   304  				),
   305  				MustNewTrackingSubjectSetWith(fs("user", "user2", "...")),
   306  				MustNewTrackingSubjectSetWith(fs("user", "user3", "...")),
   307  			),
   308  			[]FoundSubject{
   309  				fs("user", "*", "...", "user1", "user2", "user3"),
   310  			},
   311  		},
   312  		{
   313  			"intersection of exclusions",
   314  			intersect(
   315  				MustNewTrackingSubjectSetWith(
   316  					fs("user", "*", "...", "user1"),
   317  				),
   318  				MustNewTrackingSubjectSetWith(
   319  					fs("user", "*", "...", "user2"),
   320  				),
   321  			),
   322  			[]FoundSubject{
   323  				fs("user", "*", "...", "user1", "user2"),
   324  			},
   325  		},
   326  	}
   327  
   328  	for _, tc := range testCases {
   329  		tc := tc
   330  		t.Run(tc.name, func(t *testing.T) {
   331  			require := require.New(t)
   332  			for _, fs := range tc.expected {
   333  				_, isWildcard := fs.WildcardType()
   334  				if isWildcard {
   335  					found, ok := tc.set.Get(fs.subject)
   336  					require.True(ok, "missing expected subject %s", fs.subject)
   337  
   338  					expectedExcluded := mapz.NewSet[string](fs.excludedSubjectStrings()...)
   339  					foundExcluded := mapz.NewSet[string](found.excludedSubjectStrings()...)
   340  					require.Len(expectedExcluded.Subtract(foundExcluded).AsSlice(), 0, "mismatch on excluded subjects on %s: expected: %s, found: %s", fs.subject, expectedExcluded, foundExcluded)
   341  					require.Len(foundExcluded.Subtract(expectedExcluded).AsSlice(), 0, "mismatch on excluded subjects on %s: expected: %s, found: %s", fs.subject, expectedExcluded, foundExcluded)
   342  				} else {
   343  					require.True(tc.set.Contains(fs.subject), "missing expected subject %s", fs.subject)
   344  				}
   345  				tc.set.removeExact(fs.subject)
   346  			}
   347  
   348  			require.True(tc.set.IsEmpty(), "Found remaining: %v", tc.set.getSubjects())
   349  		})
   350  	}
   351  }
   352  
   353  func TestTrackingSubjectSetResourceTracking(t *testing.T) {
   354  	tss := NewTrackingSubjectSet()
   355  	tss.MustAdd(NewFoundSubject(DS("user", "tom", "..."), ONR("resource", "foo", "viewer")))
   356  	tss.MustAdd(NewFoundSubject(DS("user", "tom", "..."), ONR("resource", "bar", "viewer")))
   357  
   358  	found, ok := tss.Get(ONR("user", "tom", "..."))
   359  	require.True(t, ok)
   360  	require.Equal(t, 2, len(found.Relationships()))
   361  
   362  	sss := NewTrackingSubjectSet()
   363  	sss.MustAdd(NewFoundSubject(DS("user", "tom", "..."), ONR("resource", "baz", "viewer")))
   364  
   365  	intersection, err := tss.Intersect(sss)
   366  	require.NoError(t, err)
   367  
   368  	found, ok = intersection.Get(ONR("user", "tom", "..."))
   369  	require.True(t, ok)
   370  	require.Equal(t, 3, len(found.Relationships()))
   371  }
   372  
   373  func TestTrackingSubjectSetResourceTrackingWithWildcard(t *testing.T) {
   374  	tss := NewTrackingSubjectSet()
   375  	tss.MustAdd(NewFoundSubject(DS("user", "tom", "..."), ONR("resource", "foo", "viewer")))
   376  
   377  	sss := NewTrackingSubjectSet()
   378  	sss.MustAdd(NewFoundSubject(DS("user", "*", "..."), ONR("resource", "baz", "viewer")))
   379  
   380  	intersection, err := tss.Intersect(sss)
   381  	require.NoError(t, err)
   382  
   383  	found, ok := intersection.Get(ONR("user", "tom", "..."))
   384  	require.True(t, ok)
   385  	require.Equal(t, 1, len(found.Relationships()))
   386  }