github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/developmentmembership/trackingsubjectset.go (about)

     1  package developmentmembership
     2  
     3  import (
     4  	"github.com/authzed/spicedb/internal/datasets"
     5  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
     6  	"github.com/authzed/spicedb/pkg/tuple"
     7  )
     8  
// TrackingSubjectSet defines a set that tracks accessible subjects and their associated
// relationships.
//
// Subjects are partitioned by subject type: each subject is filed under the key produced by
// tuple.JoinRelRef(namespace, relation) for that subject.
//
// NOTE: This is designed solely for the developer API and testing and should *not* be used in any
// performance sensitive code.
type TrackingSubjectSet struct {
	// setByType maps a subject type key (as built by tuple.JoinRelRef) to the set of
	// subjects of that type. Entries are typically created on demand by getSetForKey.
	setByType map[string]datasets.BaseSubjectSet[FoundSubject]
}
    17  
    18  // NewTrackingSubjectSet creates a new TrackingSubjectSet
    19  func NewTrackingSubjectSet() *TrackingSubjectSet {
    20  	tss := &TrackingSubjectSet{
    21  		setByType: map[string]datasets.BaseSubjectSet[FoundSubject]{},
    22  	}
    23  	return tss
    24  }
    25  
    26  // MustNewTrackingSubjectSetWith creates a new TrackingSubjectSet, and adds the specified
    27  // subjects to it.
    28  func MustNewTrackingSubjectSetWith(subjects ...FoundSubject) *TrackingSubjectSet {
    29  	tss := NewTrackingSubjectSet()
    30  	for _, subject := range subjects {
    31  		err := tss.Add(subject)
    32  		if err != nil {
    33  			panic(err)
    34  		}
    35  	}
    36  	return tss
    37  }
    38  
    39  // AddFrom adds the subjects found in the other set to this set.
    40  func (tss *TrackingSubjectSet) AddFrom(otherSet *TrackingSubjectSet) error {
    41  	for key, oss := range otherSet.setByType {
    42  		err := tss.getSetForKey(key).UnionWithSet(oss)
    43  		if err != nil {
    44  			return err
    45  		}
    46  	}
    47  	return nil
    48  }
    49  
    50  // MustAddFrom adds the subjects found in the other set to this set.
    51  func (tss *TrackingSubjectSet) MustAddFrom(otherSet *TrackingSubjectSet) {
    52  	err := tss.AddFrom(otherSet)
    53  	if err != nil {
    54  		panic(err)
    55  	}
    56  }
    57  
    58  // RemoveFrom removes any subjects found in the other set from this set.
    59  func (tss *TrackingSubjectSet) RemoveFrom(otherSet *TrackingSubjectSet) {
    60  	for key, oss := range otherSet.setByType {
    61  		tss.getSetForKey(key).SubtractAll(oss)
    62  	}
    63  }
    64  
    65  // MustAdd adds the given subjects to this set.
    66  func (tss *TrackingSubjectSet) MustAdd(subjectsAndResources ...FoundSubject) {
    67  	err := tss.Add(subjectsAndResources...)
    68  	if err != nil {
    69  		panic(err)
    70  	}
    71  }
    72  
    73  // Add adds the given subjects to this set.
    74  func (tss *TrackingSubjectSet) Add(subjectsAndResources ...FoundSubject) error {
    75  	for _, fs := range subjectsAndResources {
    76  		err := tss.getSet(fs).Add(fs)
    77  		if err != nil {
    78  			return err
    79  		}
    80  	}
    81  	return nil
    82  }
    83  
    84  func (tss *TrackingSubjectSet) getSetForKey(key string) datasets.BaseSubjectSet[FoundSubject] {
    85  	if existing, ok := tss.setByType[key]; ok {
    86  		return existing
    87  	}
    88  
    89  	ns, rel := tuple.MustSplitRelRef(key)
    90  	created := datasets.NewBaseSubjectSet(
    91  		func(subjectID string, caveatExpression *core.CaveatExpression, excludedSubjects []FoundSubject, sources ...FoundSubject) FoundSubject {
    92  			fs := NewFoundSubject(&core.DirectSubject{
    93  				Subject: &core.ObjectAndRelation{
    94  					Namespace: ns,
    95  					ObjectId:  subjectID,
    96  					Relation:  rel,
    97  				},
    98  				CaveatExpression: caveatExpression,
    99  			})
   100  			fs.excludedSubjects = excludedSubjects
   101  			fs.caveatExpression = caveatExpression
   102  			for _, source := range sources {
   103  				if source.relationships != nil {
   104  					fs.relationships.UpdateFrom(source.relationships)
   105  				}
   106  			}
   107  			return fs
   108  		},
   109  	)
   110  	tss.setByType[key] = created
   111  	return created
   112  }
   113  
   114  func (tss *TrackingSubjectSet) getSet(fs FoundSubject) datasets.BaseSubjectSet[FoundSubject] {
   115  	return tss.getSetForKey(tuple.JoinRelRef(fs.subject.Namespace, fs.subject.Relation))
   116  }
   117  
   118  // Get returns the found subject in the set, if any.
   119  func (tss *TrackingSubjectSet) Get(subject *core.ObjectAndRelation) (FoundSubject, bool) {
   120  	set, ok := tss.setByType[tuple.JoinRelRef(subject.Namespace, subject.Relation)]
   121  	if !ok {
   122  		return FoundSubject{}, false
   123  	}
   124  
   125  	return set.Get(subject.ObjectId)
   126  }
   127  
   128  // Contains returns true if the set contains the given subject.
   129  func (tss *TrackingSubjectSet) Contains(subject *core.ObjectAndRelation) bool {
   130  	_, ok := tss.Get(subject)
   131  	return ok
   132  }
   133  
   134  // Exclude returns a new set that contains the items in this set minus those in the other set.
   135  func (tss *TrackingSubjectSet) Exclude(otherSet *TrackingSubjectSet) *TrackingSubjectSet {
   136  	newSet := NewTrackingSubjectSet()
   137  
   138  	for key, bss := range tss.setByType {
   139  		cloned := bss.Clone()
   140  		if oss, ok := otherSet.setByType[key]; ok {
   141  			cloned.SubtractAll(oss)
   142  		}
   143  
   144  		newSet.setByType[key] = cloned
   145  	}
   146  
   147  	return newSet
   148  }
   149  
   150  // MustIntersect returns a new set that contains the items in this set *and* the other set. Note that
   151  // if wildcard is found in *both* sets, it will be returned *along* with any concrete subjects found
   152  // on the other side of the intersection.
   153  func (tss *TrackingSubjectSet) MustIntersect(otherSet *TrackingSubjectSet) *TrackingSubjectSet {
   154  	updated, err := tss.Intersect(otherSet)
   155  	if err != nil {
   156  		panic(err)
   157  	}
   158  	return updated
   159  }
   160  
   161  // Intersect returns a new set that contains the items in this set *and* the other set. Note that
   162  // if wildcard is found in *both* sets, it will be returned *along* with any concrete subjects found
   163  // on the other side of the intersection.
   164  func (tss *TrackingSubjectSet) Intersect(otherSet *TrackingSubjectSet) (*TrackingSubjectSet, error) {
   165  	newSet := NewTrackingSubjectSet()
   166  
   167  	for key, bss := range tss.setByType {
   168  		if oss, ok := otherSet.setByType[key]; ok {
   169  			cloned := bss.Clone()
   170  			err := cloned.IntersectionDifference(oss)
   171  			if err != nil {
   172  				return nil, err
   173  			}
   174  
   175  			newSet.setByType[key] = cloned
   176  		}
   177  	}
   178  
   179  	return newSet, nil
   180  }
   181  
   182  // ApplyParentCaveatExpression applies the given parent caveat expression (if any) to each subject set.
   183  func (tss *TrackingSubjectSet) ApplyParentCaveatExpression(parentCaveatExpr *core.CaveatExpression) {
   184  	if parentCaveatExpr == nil {
   185  		return
   186  	}
   187  
   188  	for key, bss := range tss.setByType {
   189  		tss.setByType[key] = bss.WithParentCaveatExpression(parentCaveatExpr)
   190  	}
   191  }
   192  
   193  // removeExact removes the given subject(s) from the set. If the subject is a wildcard, only
   194  // the exact matching wildcard will be removed.
   195  func (tss *TrackingSubjectSet) removeExact(subjects ...*core.ObjectAndRelation) {
   196  	for _, subject := range subjects {
   197  		if set, ok := tss.setByType[tuple.JoinRelRef(subject.Namespace, subject.Relation)]; ok {
   198  			set.UnsafeRemoveExact(FoundSubject{
   199  				subject: subject,
   200  			})
   201  		}
   202  	}
   203  }
   204  
   205  func (tss *TrackingSubjectSet) getSubjects() []string {
   206  	var subjects []string
   207  	for _, subjectSet := range tss.setByType {
   208  		for _, foundSubject := range subjectSet.AsSlice() {
   209  			subjects = append(subjects, tuple.StringONR(foundSubject.subject))
   210  		}
   211  	}
   212  	return subjects
   213  }
   214  
   215  // ToSlice returns a slice of all subjects found in the set.
   216  func (tss *TrackingSubjectSet) ToSlice() []FoundSubject {
   217  	subjects := []FoundSubject{}
   218  	for _, bss := range tss.setByType {
   219  		subjects = append(subjects, bss.AsSlice()...)
   220  	}
   221  
   222  	return subjects
   223  }
   224  
   225  // ToFoundSubjects returns the set as a FoundSubjects struct.
   226  func (tss *TrackingSubjectSet) ToFoundSubjects() FoundSubjects {
   227  	return FoundSubjects{tss}
   228  }
   229  
   230  // IsEmpty returns true if the tracking subject set is empty.
   231  func (tss *TrackingSubjectSet) IsEmpty() bool {
   232  	for _, bss := range tss.setByType {
   233  		if !bss.IsEmpty() {
   234  			return false
   235  		}
   236  	}
   237  	return true
   238  }