github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/developmentmembership/membership.go (about)

     1  package developmentmembership
     2  
     3  import (
     4  	"fmt"
     5  
     6  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
     7  
     8  	"github.com/authzed/spicedb/pkg/spiceerrors"
     9  	"github.com/authzed/spicedb/pkg/tuple"
    10  )
    11  
    12  // Set represents the set of membership for one or more ONRs, based on expansion
    13  // trees.
    14  type Set struct {
    15  	// objectsAndRelations is a map from an ONR (as a string) to the subjects found for that ONR.
    16  	objectsAndRelations map[string]FoundSubjects
    17  }
    18  
    19  // SubjectsByONR returns a map from ONR (as a string) to the FoundSubjects for that ONR.
    20  func (ms *Set) SubjectsByONR() map[string]FoundSubjects {
    21  	return ms.objectsAndRelations
    22  }
    23  
    24  // NewMembershipSet constructs a new membership set.
    25  //
    26  // NOTE: This is designed solely for the developer API and should *not* be used in any performance
    27  // sensitive code.
    28  func NewMembershipSet() *Set {
    29  	return &Set{
    30  		objectsAndRelations: map[string]FoundSubjects{},
    31  	}
    32  }
    33  
    34  // AddExpansion adds the expansion of an ONR to the membership set. Returns false if the ONR was already added.
    35  //
    36  // NOTE: The expansion tree *should* be the fully recursive expansion.
    37  func (ms *Set) AddExpansion(onr *core.ObjectAndRelation, expansion *core.RelationTupleTreeNode) (FoundSubjects, bool, error) {
    38  	onrString := tuple.StringONR(onr)
    39  	existing, ok := ms.objectsAndRelations[onrString]
    40  	if ok {
    41  		return existing, false, nil
    42  	}
    43  
    44  	tss, err := populateFoundSubjects(onr, expansion)
    45  	if err != nil {
    46  		return FoundSubjects{}, false, err
    47  	}
    48  
    49  	fs := tss.ToFoundSubjects()
    50  	ms.objectsAndRelations[onrString] = fs
    51  	return fs, true, nil
    52  }
    53  
    54  // AccessibleExpansionSubjects returns a TrackingSubjectSet representing the set of accessible subjects in the expansion.
    55  func AccessibleExpansionSubjects(treeNode *core.RelationTupleTreeNode) (*TrackingSubjectSet, error) {
    56  	return populateFoundSubjects(treeNode.Expanded, treeNode)
    57  }
    58  
    59  func populateFoundSubjects(rootONR *core.ObjectAndRelation, treeNode *core.RelationTupleTreeNode) (*TrackingSubjectSet, error) {
    60  	resource := rootONR
    61  	if treeNode.Expanded != nil {
    62  		resource = treeNode.Expanded
    63  	}
    64  
    65  	switch typed := treeNode.NodeType.(type) {
    66  	case *core.RelationTupleTreeNode_IntermediateNode:
    67  		switch typed.IntermediateNode.Operation {
    68  		case core.SetOperationUserset_UNION:
    69  			toReturn := NewTrackingSubjectSet()
    70  			for _, child := range typed.IntermediateNode.ChildNodes {
    71  				tss, err := populateFoundSubjects(resource, child)
    72  				if err != nil {
    73  					return nil, err
    74  				}
    75  
    76  				err = toReturn.AddFrom(tss)
    77  				if err != nil {
    78  					return nil, err
    79  				}
    80  			}
    81  
    82  			toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression)
    83  			return toReturn, nil
    84  
    85  		case core.SetOperationUserset_INTERSECTION:
    86  			if len(typed.IntermediateNode.ChildNodes) == 0 {
    87  				return nil, fmt.Errorf("found intersection with no children")
    88  			}
    89  
    90  			firstChildSet, err := populateFoundSubjects(rootONR, typed.IntermediateNode.ChildNodes[0])
    91  			if err != nil {
    92  				return nil, err
    93  			}
    94  
    95  			toReturn := NewTrackingSubjectSet()
    96  			err = toReturn.AddFrom(firstChildSet)
    97  			if err != nil {
    98  				return nil, err
    99  			}
   100  
   101  			for _, child := range typed.IntermediateNode.ChildNodes[1:] {
   102  				childSet, err := populateFoundSubjects(rootONR, child)
   103  				if err != nil {
   104  					return nil, err
   105  				}
   106  
   107  				updated, err := toReturn.Intersect(childSet)
   108  				if err != nil {
   109  					return nil, err
   110  				}
   111  
   112  				toReturn = updated
   113  			}
   114  
   115  			toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression)
   116  			return toReturn, nil
   117  
   118  		case core.SetOperationUserset_EXCLUSION:
   119  			if len(typed.IntermediateNode.ChildNodes) == 0 {
   120  				return nil, fmt.Errorf("found exclusion with no children")
   121  			}
   122  
   123  			firstChildSet, err := populateFoundSubjects(rootONR, typed.IntermediateNode.ChildNodes[0])
   124  			if err != nil {
   125  				return nil, err
   126  			}
   127  
   128  			toReturn := NewTrackingSubjectSet()
   129  			err = toReturn.AddFrom(firstChildSet)
   130  			if err != nil {
   131  				return nil, err
   132  			}
   133  
   134  			for _, child := range typed.IntermediateNode.ChildNodes[1:] {
   135  				childSet, err := populateFoundSubjects(rootONR, child)
   136  				if err != nil {
   137  					return nil, err
   138  				}
   139  				toReturn = toReturn.Exclude(childSet)
   140  			}
   141  
   142  			toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression)
   143  			return toReturn, nil
   144  
   145  		default:
   146  			return nil, spiceerrors.MustBugf("unknown expand operation")
   147  		}
   148  
   149  	case *core.RelationTupleTreeNode_LeafNode:
   150  		toReturn := NewTrackingSubjectSet()
   151  		for _, subject := range typed.LeafNode.Subjects {
   152  			fs := NewFoundSubject(subject)
   153  			err := toReturn.Add(fs)
   154  			if err != nil {
   155  				return nil, err
   156  			}
   157  
   158  			fs.relationships.Add(resource)
   159  		}
   160  
   161  		toReturn.ApplyParentCaveatExpression(treeNode.CaveatExpression)
   162  		return toReturn, nil
   163  
   164  	default:
   165  		return nil, spiceerrors.MustBugf("unknown TreeNode type")
   166  	}
   167  }