github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datasets/subjectsetbytype_test.go (about)

     1  package datasets
     2  
     3  import (
     4  	"sort"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/require"
     8  
     9  	"github.com/authzed/spicedb/pkg/genutil/mapz"
    10  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    11  	"github.com/authzed/spicedb/pkg/tuple"
    12  )
    13  
    14  func RR(namespaceName string, relationName string) *core.RelationReference {
    15  	return &core.RelationReference{
    16  		Namespace: namespaceName,
    17  		Relation:  relationName,
    18  	}
    19  }
    20  
    21  func TestSubjectByTypeSet(t *testing.T) {
    22  	assertHasObjectIds := func(s *SubjectByTypeSet, rr *core.RelationReference, expected []string) {
    23  		wasFound := false
    24  		s.ForEachType(func(foundRR *core.RelationReference, subjects SubjectSet) {
    25  			objectIds := make([]string, 0, len(subjects.AsSlice()))
    26  			for _, subject := range subjects.AsSlice() {
    27  				require.Empty(t, subject.GetExcludedSubjects())
    28  				objectIds = append(objectIds, subject.SubjectId)
    29  			}
    30  
    31  			if rr.Namespace == foundRR.Namespace && rr.Relation == foundRR.Relation {
    32  				sort.Strings(objectIds)
    33  				require.Equal(t, expected, objectIds)
    34  				wasFound = true
    35  			}
    36  		})
    37  		require.True(t, wasFound)
    38  
    39  		wasFound = false
    40  		err := s.ForEachTypeUntil(func(foundRR *core.RelationReference, subjects SubjectSet) (bool, error) {
    41  			objectIds := make([]string, 0, len(subjects.AsSlice()))
    42  			for _, subject := range subjects.AsSlice() {
    43  				require.Empty(t, subject.GetExcludedSubjects())
    44  				objectIds = append(objectIds, subject.SubjectId)
    45  			}
    46  
    47  			if rr.Namespace == foundRR.Namespace && rr.Relation == foundRR.Relation {
    48  				sort.Strings(objectIds)
    49  				require.Equal(t, expected, objectIds)
    50  				wasFound = true
    51  				return false, nil
    52  			}
    53  
    54  			return true, nil
    55  		})
    56  		require.True(t, wasFound)
    57  		require.NoError(t, err)
    58  	}
    59  
    60  	set := NewSubjectByTypeSet()
    61  	require.True(t, set.IsEmpty())
    62  
    63  	// Add some concrete subjects.
    64  	err := set.AddConcreteSubject(tuple.ParseONR("document:foo#viewer"))
    65  	require.NoError(t, err)
    66  
    67  	err = set.AddConcreteSubject(tuple.ParseONR("document:bar#viewer"))
    68  	require.NoError(t, err)
    69  
    70  	err = set.AddConcreteSubject(tuple.ParseONR("team:something#member"))
    71  	require.NoError(t, err)
    72  
    73  	err = set.AddConcreteSubject(tuple.ParseONR("team:other#member"))
    74  	require.NoError(t, err)
    75  
    76  	err = set.AddConcreteSubject(tuple.ParseONR("team:other#manager"))
    77  	require.NoError(t, err)
    78  
    79  	// Add a caveated subject.
    80  	err = set.AddSubjectOf(tuple.MustWithCaveat(tuple.MustParse("document:foo#viewer@user:tom"), "first"))
    81  	require.NoError(t, err)
    82  
    83  	require.False(t, set.IsEmpty())
    84  
    85  	// Run for each type over the set
    86  	assertHasObjectIds(set, RR("document", "viewer"), []string{"bar", "foo"})
    87  	assertHasObjectIds(set, RR("team", "member"), []string{"other", "something"})
    88  	assertHasObjectIds(set, RR("team", "manager"), []string{"other"})
    89  	assertHasObjectIds(set, RR("user", "..."), []string{"tom"})
    90  
    91  	// Map
    92  	mapped, err := set.Map(func(rr *core.RelationReference) (*core.RelationReference, error) {
    93  		if rr.Namespace == "document" {
    94  			return RR("doc", rr.Relation), nil
    95  		}
    96  
    97  		return rr, nil
    98  	})
    99  	require.NoError(t, err)
   100  
   101  	assertHasObjectIds(mapped, RR("doc", "viewer"), []string{"bar", "foo"})
   102  	assertHasObjectIds(mapped, RR("team", "member"), []string{"other", "something"})
   103  	assertHasObjectIds(mapped, RR("team", "manager"), []string{"other"})
   104  	assertHasObjectIds(mapped, RR("user", "..."), []string{"tom"})
   105  }
   106  
   107  func TestSubjectSetByTypeWithCaveats(t *testing.T) {
   108  	set := NewSubjectByTypeSet()
   109  	require.True(t, set.IsEmpty())
   110  
   111  	err := set.AddSubjectOf(tuple.MustWithCaveat(tuple.MustParse("document:foo#viewer@user:tom"), "first"))
   112  	require.NoError(t, err)
   113  
   114  	ss, ok := set.SubjectSetForType(&core.RelationReference{
   115  		Namespace: "user",
   116  		Relation:  "...",
   117  	})
   118  	require.True(t, ok)
   119  
   120  	tom, ok := ss.Get("tom")
   121  	require.True(t, ok)
   122  
   123  	require.Equal(t,
   124  		caveatexpr("first"),
   125  		tom.GetCaveatExpression(),
   126  	)
   127  }
   128  
   129  func TestSubjectSetMapOverSameSubjectDifferentRelation(t *testing.T) {
   130  	set := NewSubjectByTypeSet()
   131  	require.True(t, set.IsEmpty())
   132  
   133  	err := set.AddSubjectOf(tuple.MustParse("document:foo#folder@folder:folder1"))
   134  	require.NoError(t, err)
   135  
   136  	err = set.AddSubjectOf(tuple.MustParse("document:foo#folder@folder:folder2#parent"))
   137  	require.NoError(t, err)
   138  
   139  	mapped, err := set.Map(func(rr *core.RelationReference) (*core.RelationReference, error) {
   140  		return &core.RelationReference{
   141  			Namespace: rr.Namespace,
   142  			Relation:  "shared",
   143  		}, nil
   144  	})
   145  	require.NoError(t, err)
   146  
   147  	foundSubjectIDs := mapz.NewSet[string]()
   148  	for _, sub := range mapped.byType["folder#shared"].AsSlice() {
   149  		foundSubjectIDs.Add(sub.SubjectId)
   150  	}
   151  
   152  	require.ElementsMatch(t, []string{"folder1", "folder2"}, foundSubjectIDs.AsSlice())
   153  }