github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/developmentmembership/trackingsubjectset.go (about) 1 package developmentmembership 2 3 import ( 4 "github.com/authzed/spicedb/internal/datasets" 5 core "github.com/authzed/spicedb/pkg/proto/core/v1" 6 "github.com/authzed/spicedb/pkg/tuple" 7 ) 8 9 // TrackingSubjectSet defines a set that tracks accessible subjects and their associated 10 // relationships. 11 // 12 // NOTE: This is designed solely for the developer API and testing and should *not* be used in any 13 // performance sensitive code. 14 type TrackingSubjectSet struct { 15 setByType map[string]datasets.BaseSubjectSet[FoundSubject] 16 } 17 18 // NewTrackingSubjectSet creates a new TrackingSubjectSet 19 func NewTrackingSubjectSet() *TrackingSubjectSet { 20 tss := &TrackingSubjectSet{ 21 setByType: map[string]datasets.BaseSubjectSet[FoundSubject]{}, 22 } 23 return tss 24 } 25 26 // MustNewTrackingSubjectSetWith creates a new TrackingSubjectSet, and adds the specified 27 // subjects to it. 28 func MustNewTrackingSubjectSetWith(subjects ...FoundSubject) *TrackingSubjectSet { 29 tss := NewTrackingSubjectSet() 30 for _, subject := range subjects { 31 err := tss.Add(subject) 32 if err != nil { 33 panic(err) 34 } 35 } 36 return tss 37 } 38 39 // AddFrom adds the subjects found in the other set to this set. 40 func (tss *TrackingSubjectSet) AddFrom(otherSet *TrackingSubjectSet) error { 41 for key, oss := range otherSet.setByType { 42 err := tss.getSetForKey(key).UnionWithSet(oss) 43 if err != nil { 44 return err 45 } 46 } 47 return nil 48 } 49 50 // MustAddFrom adds the subjects found in the other set to this set. 51 func (tss *TrackingSubjectSet) MustAddFrom(otherSet *TrackingSubjectSet) { 52 err := tss.AddFrom(otherSet) 53 if err != nil { 54 panic(err) 55 } 56 } 57 58 // RemoveFrom removes any subjects found in the other set from this set. 59 func (tss *TrackingSubjectSet) RemoveFrom(otherSet *TrackingSubjectSet) { 60 for key, oss := range otherSet.setByType { 61 tss.getSetForKey(key).SubtractAll(oss) 62 } 63 } 64 65 // MustAdd adds the given subjects to this set. 66 func (tss *TrackingSubjectSet) MustAdd(subjectsAndResources ...FoundSubject) { 67 err := tss.Add(subjectsAndResources...) 68 if err != nil { 69 panic(err) 70 } 71 } 72 73 // Add adds the given subjects to this set. 74 func (tss *TrackingSubjectSet) Add(subjectsAndResources ...FoundSubject) error { 75 for _, fs := range subjectsAndResources { 76 err := tss.getSet(fs).Add(fs) 77 if err != nil { 78 return err 79 } 80 } 81 return nil 82 } 83 84 func (tss *TrackingSubjectSet) getSetForKey(key string) datasets.BaseSubjectSet[FoundSubject] { 85 if existing, ok := tss.setByType[key]; ok { 86 return existing 87 } 88 89 ns, rel := tuple.MustSplitRelRef(key) 90 created := datasets.NewBaseSubjectSet( 91 func(subjectID string, caveatExpression *core.CaveatExpression, excludedSubjects []FoundSubject, sources ...FoundSubject) FoundSubject { 92 fs := NewFoundSubject(&core.DirectSubject{ 93 Subject: &core.ObjectAndRelation{ 94 Namespace: ns, 95 ObjectId: subjectID, 96 Relation: rel, 97 }, 98 CaveatExpression: caveatExpression, 99 }) 100 fs.excludedSubjects = excludedSubjects 101 fs.caveatExpression = caveatExpression 102 for _, source := range sources { 103 if source.relationships != nil { 104 fs.relationships.UpdateFrom(source.relationships) 105 } 106 } 107 return fs 108 }, 109 ) 110 tss.setByType[key] = created 111 return created 112 } 113 114 func (tss *TrackingSubjectSet) getSet(fs FoundSubject) datasets.BaseSubjectSet[FoundSubject] { 115 return tss.getSetForKey(tuple.JoinRelRef(fs.subject.Namespace, fs.subject.Relation)) 116 } 117 118 // Get returns the found subject in the set, if any. 119 func (tss *TrackingSubjectSet) Get(subject *core.ObjectAndRelation) (FoundSubject, bool) { 120 set, ok := tss.setByType[tuple.JoinRelRef(subject.Namespace, subject.Relation)] 121 if !ok { 122 return FoundSubject{}, false 123 } 124 125 return set.Get(subject.ObjectId) 126 } 127 128 // Contains returns true if the set contains the given subject. 129 func (tss *TrackingSubjectSet) Contains(subject *core.ObjectAndRelation) bool { 130 _, ok := tss.Get(subject) 131 return ok 132 } 133 134 // Exclude returns a new set that contains the items in this set minus those in the other set. 135 func (tss *TrackingSubjectSet) Exclude(otherSet *TrackingSubjectSet) *TrackingSubjectSet { 136 newSet := NewTrackingSubjectSet() 137 138 for key, bss := range tss.setByType { 139 cloned := bss.Clone() 140 if oss, ok := otherSet.setByType[key]; ok { 141 cloned.SubtractAll(oss) 142 } 143 144 newSet.setByType[key] = cloned 145 } 146 147 return newSet 148 } 149 150 // MustIntersect returns a new set that contains the items in this set *and* the other set. Note that 151 // if wildcard is found in *both* sets, it will be returned *along* with any concrete subjects found 152 // on the other side of the intersection. 153 func (tss *TrackingSubjectSet) MustIntersect(otherSet *TrackingSubjectSet) *TrackingSubjectSet { 154 updated, err := tss.Intersect(otherSet) 155 if err != nil { 156 panic(err) 157 } 158 return updated 159 } 160 161 // Intersect returns a new set that contains the items in this set *and* the other set. Note that 162 // if wildcard is found in *both* sets, it will be returned *along* with any concrete subjects found 163 // on the other side of the intersection. 164 func (tss *TrackingSubjectSet) Intersect(otherSet *TrackingSubjectSet) (*TrackingSubjectSet, error) { 165 newSet := NewTrackingSubjectSet() 166 167 for key, bss := range tss.setByType { 168 if oss, ok := otherSet.setByType[key]; ok { 169 cloned := bss.Clone() 170 err := cloned.IntersectionDifference(oss) 171 if err != nil { 172 return nil, err 173 } 174 175 newSet.setByType[key] = cloned 176 } 177 } 178 179 return newSet, nil 180 } 181 182 // ApplyParentCaveatExpression applies the given parent caveat expression (if any) to each subject set. 183 func (tss *TrackingSubjectSet) ApplyParentCaveatExpression(parentCaveatExpr *core.CaveatExpression) { 184 if parentCaveatExpr == nil { 185 return 186 } 187 188 for key, bss := range tss.setByType { 189 tss.setByType[key] = bss.WithParentCaveatExpression(parentCaveatExpr) 190 } 191 } 192 193 // removeExact removes the given subject(s) from the set. If the subject is a wildcard, only 194 // the exact matching wildcard will be removed. 195 func (tss *TrackingSubjectSet) removeExact(subjects ...*core.ObjectAndRelation) { 196 for _, subject := range subjects { 197 if set, ok := tss.setByType[tuple.JoinRelRef(subject.Namespace, subject.Relation)]; ok { 198 set.UnsafeRemoveExact(FoundSubject{ 199 subject: subject, 200 }) 201 } 202 } 203 } 204 205 func (tss *TrackingSubjectSet) getSubjects() []string { 206 var subjects []string 207 for _, subjectSet := range tss.setByType { 208 for _, foundSubject := range subjectSet.AsSlice() { 209 subjects = append(subjects, tuple.StringONR(foundSubject.subject)) 210 } 211 } 212 return subjects 213 } 214 215 // ToSlice returns a slice of all subjects found in the set. 216 func (tss *TrackingSubjectSet) ToSlice() []FoundSubject { 217 subjects := []FoundSubject{} 218 for _, bss := range tss.setByType { 219 subjects = append(subjects, bss.AsSlice()...) 220 } 221 222 return subjects 223 } 224 225 // ToFoundSubjects returns the set as a FoundSubjects struct. 226 func (tss *TrackingSubjectSet) ToFoundSubjects() FoundSubjects { 227 return FoundSubjects{tss} 228 } 229 230 // IsEmpty returns true if the tracking subject set is empty. 231 func (tss *TrackingSubjectSet) IsEmpty() bool { 232 for _, bss := range tss.setByType { 233 if !bss.IsEmpty() { 234 return false 235 } 236 } 237 return true 238 }