github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/typesystem/reachabilitygraph.go (about)

     1  package typesystem
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sort"
     7  	"strconv"
     8  	"sync"
     9  
    10  	"github.com/cespare/xxhash/v2"
    11  	"golang.org/x/exp/maps"
    12  
    13  	"github.com/authzed/spicedb/pkg/genutil/mapz"
    14  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    15  	"github.com/authzed/spicedb/pkg/spiceerrors"
    16  	"github.com/authzed/spicedb/pkg/tuple"
    17  )
    18  
    19  // ReachabilityGraph is a helper struct that provides an easy way to determine all entrypoints
    20  // for a subject of a particular type into a schema, for the purpose of walking from the subject
    21  // to a specific resource relation.
    22  type ReachabilityGraph struct {
    23  	ts                          *TypeSystem
    24  	cachedGraphs                sync.Map
    25  	hasOptimizedEntrypointCache sync.Map
    26  }
    27  
    28  // ReachabilityEntrypoint is an entrypoint into the reachability graph for a subject of particular
    29  // type.
    30  type ReachabilityEntrypoint struct {
    31  	re             *core.ReachabilityEntrypoint
    32  	parentRelation *core.RelationReference
    33  }
    34  
    35  // Hash returns a hash representing the data in the entrypoint, for comparison to other entrypoints.
    36  // This is ONLY stable within a single version of SpiceDB and should NEVER be stored for later
    37  // comparison outside of the process.
    38  func (re ReachabilityEntrypoint) Hash() (uint64, error) {
    39  	size := re.re.SizeVT()
    40  	if re.parentRelation != nil {
    41  		size += re.parentRelation.SizeVT()
    42  	}
    43  
    44  	hashData := make([]byte, 0, size)
    45  
    46  	data, err := re.re.MarshalVT()
    47  	if err != nil {
    48  		return 0, err
    49  	}
    50  
    51  	hashData = append(hashData, data...)
    52  
    53  	if re.parentRelation != nil {
    54  		data, err := re.parentRelation.MarshalVT()
    55  		if err != nil {
    56  			return 0, err
    57  		}
    58  
    59  		hashData = append(hashData, data...)
    60  	}
    61  
    62  	return xxhash.Sum64(hashData), nil
    63  }
    64  
    65  // EntrypointKind is the kind of the entrypoint.
    66  func (re ReachabilityEntrypoint) EntrypointKind() core.ReachabilityEntrypoint_ReachabilityEntrypointKind {
    67  	return re.re.Kind
    68  }
    69  
    70  // TuplesetRelation returns the tupleset relation of the TTU, if a TUPLESET_TO_USERSET_ENTRYPOINT.
    71  func (re ReachabilityEntrypoint) TuplesetRelation() (string, error) {
    72  	if re.EntrypointKind() != core.ReachabilityEntrypoint_TUPLESET_TO_USERSET_ENTRYPOINT {
    73  		return "", fmt.Errorf("cannot call TupleToUserset for kind %v", re.EntrypointKind())
    74  	}
    75  
    76  	return re.re.TuplesetRelation, nil
    77  }
    78  
    79  // DirectRelation is the relation that this entrypoint represents, if a RELATION_ENTRYPOINT.
    80  func (re ReachabilityEntrypoint) DirectRelation() (*core.RelationReference, error) {
    81  	if re.EntrypointKind() != core.ReachabilityEntrypoint_RELATION_ENTRYPOINT {
    82  		return nil, fmt.Errorf("cannot call DirectRelation for kind %v", re.EntrypointKind())
    83  	}
    84  
    85  	return re.re.TargetRelation, nil
    86  }
    87  
    88  // ContainingRelationOrPermission is the relation or permission containing this entrypoint.
    89  func (re ReachabilityEntrypoint) ContainingRelationOrPermission() *core.RelationReference {
    90  	return re.parentRelation
    91  }
    92  
    93  // IsDirectResult returns whether the entrypoint, when evaluated, becomes a direct result of
    94  // the parent relation/permission. A direct result only exists if the entrypoint is not contained
    95  // under an intersection or exclusion, which makes the entrypoint's object merely conditionally
    96  // reachable.
    97  func (re ReachabilityEntrypoint) IsDirectResult() bool {
    98  	return re.re.ResultStatus == core.ReachabilityEntrypoint_DIRECT_OPERATION_RESULT
    99  }
   100  
   101  func (re ReachabilityEntrypoint) String() string {
   102  	return re.MustDebugString()
   103  }
   104  
   105  func (re ReachabilityEntrypoint) MustDebugString() string {
   106  	switch re.EntrypointKind() {
   107  	case core.ReachabilityEntrypoint_RELATION_ENTRYPOINT:
   108  		return fmt.Sprintf("relation-entrypoint: %s#%s", re.re.TargetRelation.Namespace, re.re.TargetRelation.Relation)
   109  
   110  	case core.ReachabilityEntrypoint_TUPLESET_TO_USERSET_ENTRYPOINT:
   111  		return fmt.Sprintf("ttu-entrypoint: %s#%s | %s | %s#%s", re.parentRelation.Namespace, re.parentRelation.Relation, re.re.TuplesetRelation, re.re.TargetRelation.Namespace, re.re.TargetRelation.Relation)
   112  
   113  	case core.ReachabilityEntrypoint_COMPUTED_USERSET_ENTRYPOINT:
   114  		return fmt.Sprintf("computed-entrypoint: %s#%s", re.re.TargetRelation.Namespace, re.re.TargetRelation.Relation)
   115  
   116  	default:
   117  		panic("unknown relation entrypoint kind")
   118  	}
   119  }
   120  
   121  // ReachabilityGraphFor returns a reachability graph for the given namespace.
   122  func ReachabilityGraphFor(ts *ValidatedNamespaceTypeSystem) *ReachabilityGraph {
   123  	return &ReachabilityGraph{ts.TypeSystem, sync.Map{}, sync.Map{}}
   124  }
   125  
   126  // RelationsEncounteredForResource returns all relations that are encountered when walking outward from a resource+relation.
   127  func (rg *ReachabilityGraph) RelationsEncounteredForResource(
   128  	ctx context.Context,
   129  	resourceType *core.RelationReference,
   130  ) ([]*core.RelationReference, error) {
   131  	_, relationNames, err := rg.computeEntrypoints(ctx, resourceType, nil /* include all entrypoints */, reachabilityFull, entrypointLookupFindAll)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	relationRefs := make([]*core.RelationReference, 0, len(relationNames))
   137  	for _, relationName := range relationNames {
   138  		namespace, relation := tuple.MustSplitRelRef(relationName)
   139  		relationRefs = append(relationRefs, &core.RelationReference{
   140  			Namespace: namespace,
   141  			Relation:  relation,
   142  		})
   143  	}
   144  	return relationRefs, nil
   145  }
   146  
   147  // RelationsEncounteredForSubject returns all relations that are encountered when walking outward from a subject+relation.
   148  func (rg *ReachabilityGraph) RelationsEncounteredForSubject(
   149  	ctx context.Context,
   150  	allDefinitions []*core.NamespaceDefinition,
   151  	startingSubjectType *core.RelationReference,
   152  ) ([]*core.RelationReference, error) {
   153  	if startingSubjectType.Namespace != rg.ts.nsDef.Name {
   154  		return nil, spiceerrors.MustBugf("gave mismatching namespace name for subject type to reachability graph")
   155  	}
   156  
   157  	allRelationNames := mapz.NewSet[string]()
   158  
   159  	subjectTypesToCheck := []*core.RelationReference{startingSubjectType}
   160  
   161  	// TODO(jschorr): optimize this to not require walking over all types recursively.
   162  	added := mapz.NewSet[string]()
   163  	for {
   164  		if len(subjectTypesToCheck) == 0 {
   165  			break
   166  		}
   167  
   168  		collected := &[]ReachabilityEntrypoint{}
   169  		for _, nsDef := range allDefinitions {
   170  			nts, err := rg.ts.TypeSystemForNamespace(ctx, nsDef.Name)
   171  			if err != nil {
   172  				return nil, err
   173  			}
   174  
   175  			nrg := ReachabilityGraphFor(&ValidatedNamespaceTypeSystem{nts})
   176  
   177  			for _, relation := range nsDef.Relation {
   178  				for _, subjectType := range subjectTypesToCheck {
   179  					if subjectType.Namespace == nsDef.Name && subjectType.Relation == relation.Name {
   180  						continue
   181  					}
   182  
   183  					encounteredRelations := map[string]struct{}{}
   184  					err := nrg.collectEntrypoints(ctx, &core.RelationReference{
   185  						Namespace: nsDef.Name,
   186  						Relation:  relation.Name,
   187  					}, subjectType, collected, encounteredRelations, reachabilityFull, entrypointLookupFindAll)
   188  					if err != nil {
   189  						return nil, err
   190  					}
   191  				}
   192  			}
   193  		}
   194  
   195  		subjectTypesToCheck = make([]*core.RelationReference, 0, len(*collected))
   196  
   197  		for _, entrypoint := range *collected {
   198  			st := tuple.JoinRelRef(entrypoint.re.TargetRelation.Namespace, entrypoint.re.TargetRelation.Relation)
   199  			if !added.Add(st) {
   200  				continue
   201  			}
   202  
   203  			allRelationNames.Add(st)
   204  			subjectTypesToCheck = append(subjectTypesToCheck, entrypoint.re.TargetRelation)
   205  		}
   206  	}
   207  
   208  	relationRefs := make([]*core.RelationReference, 0, allRelationNames.Len())
   209  	for _, relationName := range allRelationNames.AsSlice() {
   210  		namespace, relation := tuple.MustSplitRelRef(relationName)
   211  		relationRefs = append(relationRefs, &core.RelationReference{
   212  			Namespace: namespace,
   213  			Relation:  relation,
   214  		})
   215  	}
   216  	return relationRefs, nil
   217  }
   218  
   219  // AllEntrypointsForSubjectToResource returns the entrypoints into the reachability graph, starting
   220  // at the given subject type and walking to the given resource type.
   221  func (rg *ReachabilityGraph) AllEntrypointsForSubjectToResource(
   222  	ctx context.Context,
   223  	subjectType *core.RelationReference,
   224  	resourceType *core.RelationReference,
   225  ) ([]ReachabilityEntrypoint, error) {
   226  	entrypoints, _, err := rg.computeEntrypoints(ctx, resourceType, subjectType, reachabilityFull, entrypointLookupFindAll)
   227  	return entrypoints, err
   228  }
   229  
   230  // OptimizedEntrypointsForSubjectToResource returns the *optimized* set of entrypoints into the
   231  // reachability graph, starting at the given subject type and walking to the given resource type.
   232  //
   233  // The optimized set will skip branches on intersections and exclusions in an attempt to minimize
   234  // the number of entrypoints.
   235  func (rg *ReachabilityGraph) OptimizedEntrypointsForSubjectToResource(
   236  	ctx context.Context,
   237  	subjectType *core.RelationReference,
   238  	resourceType *core.RelationReference,
   239  ) ([]ReachabilityEntrypoint, error) {
   240  	entrypoints, _, err := rg.computeEntrypoints(ctx, resourceType, subjectType, reachabilityOptimized, entrypointLookupFindAll)
   241  	return entrypoints, err
   242  }
   243  
   244  // HasOptimizedEntrypointsForSubjectToResource returns whether there exists any *optimized*
   245  // entrypoints into the reachability graph, starting at the given subject type and walking
   246  // to the given resource type.
   247  //
   248  // The optimized set will skip branches on intersections and exclusions in an attempt to minimize
   249  // the number of entrypoints.
   250  func (rg *ReachabilityGraph) HasOptimizedEntrypointsForSubjectToResource(
   251  	ctx context.Context,
   252  	subjectType *core.RelationReference,
   253  	resourceType *core.RelationReference,
   254  ) (bool, error) {
   255  	cacheKey := tuple.StringRR(subjectType) + "=>" + tuple.StringRR(resourceType)
   256  	if result, ok := rg.hasOptimizedEntrypointCache.Load(cacheKey); ok {
   257  		return result.(bool), nil
   258  	}
   259  
   260  	// TODO(jzelinskie): measure to see if it's worth singleflighting this
   261  	found, _, err := rg.computeEntrypoints(ctx, resourceType, subjectType, reachabilityOptimized, entrypointLookupFindOne)
   262  	if err != nil {
   263  		return false, err
   264  	}
   265  
   266  	result := len(found) > 0
   267  	rg.hasOptimizedEntrypointCache.Store(cacheKey, result)
   268  	return result, nil
   269  }
   270  
   // entrypointLookupOption controls how much of the graph collectEntrypoints walks.
type entrypointLookupOption int

const (
	// entrypointLookupFindAll walks the whole graph and collects every matching entrypoint.
	entrypointLookupFindAll entrypointLookupOption = 0

	// entrypointLookupFindOne lets the walk return at the next checkpoint once at least one
	// entrypoint has been collected.
	entrypointLookupFindOne entrypointLookupOption = 1
)
   277  
   278  func (rg *ReachabilityGraph) computeEntrypoints(
   279  	ctx context.Context,
   280  	resourceType *core.RelationReference,
   281  	optionalSubjectType *core.RelationReference,
   282  	reachabilityOption reachabilityOption,
   283  	entrypointLookupOption entrypointLookupOption,
   284  ) ([]ReachabilityEntrypoint, []string, error) {
   285  	if resourceType.Namespace != rg.ts.nsDef.Name {
   286  		return nil, nil, fmt.Errorf("gave mismatching namespace name for resource type to reachability graph")
   287  	}
   288  
   289  	collected := &[]ReachabilityEntrypoint{}
   290  	encounteredRelations := map[string]struct{}{}
   291  	err := rg.collectEntrypoints(ctx, resourceType, optionalSubjectType, collected, encounteredRelations, reachabilityOption, entrypointLookupOption)
   292  	if err != nil {
   293  		return nil, maps.Keys(encounteredRelations), err
   294  	}
   295  
   296  	collectedEntrypoints := *collected
   297  
   298  	// Deduplicate any entrypoints found. An example that can cause a duplicate is a relation which references
   299  	// the same subject type multiple times due to caveats:
   300  	//
   301  	// relation somerel: user | user with somecaveat
   302  	//
   303  	// This will produce two entrypoints (one per user reference), but as entrypoints themselves are not caveated,
   304  	// one is spurious.
   305  	entrypointMap := make(map[uint64]ReachabilityEntrypoint, len(collectedEntrypoints))
   306  	uniqueEntrypoints := make([]ReachabilityEntrypoint, 0, len(collectedEntrypoints))
   307  	for _, entrypoint := range collectedEntrypoints {
   308  		hash, err := entrypoint.Hash()
   309  		if err != nil {
   310  			return nil, maps.Keys(encounteredRelations), err
   311  		}
   312  
   313  		if _, ok := entrypointMap[hash]; !ok {
   314  			entrypointMap[hash] = entrypoint
   315  			uniqueEntrypoints = append(uniqueEntrypoints, entrypoint)
   316  		}
   317  	}
   318  
   319  	return uniqueEntrypoints, maps.Keys(encounteredRelations), nil
   320  }
   321  
   322  func (rg *ReachabilityGraph) getOrBuildGraph(ctx context.Context, resourceType *core.RelationReference, reachabilityOption reachabilityOption) (*core.ReachabilityGraph, error) {
   323  	// Check the cache.
   324  	// TODO(jschorr): Move this to a global cache.
   325  	cacheKey := tuple.StringRR(resourceType) + "-" + strconv.Itoa(int(reachabilityOption))
   326  	if cached, ok := rg.cachedGraphs.Load(cacheKey); ok {
   327  		return cached.(*core.ReachabilityGraph), nil
   328  	}
   329  
   330  	// Load the type system for the target resource relation.
   331  	namespace, err := rg.ts.resolver.LookupNamespace(ctx, resourceType.Namespace)
   332  	if err != nil {
   333  		return nil, err
   334  	}
   335  
   336  	rts, err := NewNamespaceTypeSystem(namespace, rg.ts.resolver)
   337  	if err != nil {
   338  		return nil, err
   339  	}
   340  
   341  	rrg, err := computeReachability(ctx, rts, resourceType.Relation, reachabilityOption)
   342  	if err != nil {
   343  		return nil, err
   344  	}
   345  
   346  	rg.cachedGraphs.Store(cacheKey, rrg)
   347  	return rrg, err
   348  }
   349  
   350  func (rg *ReachabilityGraph) collectEntrypoints(
   351  	ctx context.Context,
   352  	resourceType *core.RelationReference,
   353  	optionalSubjectType *core.RelationReference,
   354  	collected *[]ReachabilityEntrypoint,
   355  	encounteredRelations map[string]struct{},
   356  	reachabilityOption reachabilityOption,
   357  	entrypointLookupOption entrypointLookupOption,
   358  ) error {
   359  	// Ensure that we only process each relation once.
   360  	key := tuple.JoinRelRef(resourceType.Namespace, resourceType.Relation)
   361  	if _, ok := encounteredRelations[key]; ok {
   362  		return nil
   363  	}
   364  
   365  	encounteredRelations[key] = struct{}{}
   366  
   367  	rrg, err := rg.getOrBuildGraph(ctx, resourceType, reachabilityOption)
   368  	if err != nil {
   369  		return err
   370  	}
   371  
   372  	if optionalSubjectType != nil {
   373  		// Add subject type entrypoints.
   374  		subjectTypeEntrypoints, ok := rrg.EntrypointsBySubjectType[optionalSubjectType.Namespace]
   375  		if ok {
   376  			addEntrypoints(subjectTypeEntrypoints, resourceType, collected, encounteredRelations)
   377  		}
   378  
   379  		if entrypointLookupOption == entrypointLookupFindOne && len(*collected) > 0 {
   380  			return nil
   381  		}
   382  
   383  		// Add subject relation entrypoints.
   384  		subjectRelationEntrypoints, ok := rrg.EntrypointsBySubjectRelation[tuple.JoinRelRef(optionalSubjectType.Namespace, optionalSubjectType.Relation)]
   385  		if ok {
   386  			addEntrypoints(subjectRelationEntrypoints, resourceType, collected, encounteredRelations)
   387  		}
   388  
   389  		if entrypointLookupOption == entrypointLookupFindOne && len(*collected) > 0 {
   390  			return nil
   391  		}
   392  	} else {
   393  		// Add all entrypoints.
   394  		for _, entrypoints := range rrg.EntrypointsBySubjectType {
   395  			addEntrypoints(entrypoints, resourceType, collected, encounteredRelations)
   396  		}
   397  
   398  		for _, entrypoints := range rrg.EntrypointsBySubjectRelation {
   399  			addEntrypoints(entrypoints, resourceType, collected, encounteredRelations)
   400  		}
   401  	}
   402  
   403  	// Sort the keys to ensure a stable graph is produced.
   404  	keys := maps.Keys(rrg.EntrypointsBySubjectRelation)
   405  	sort.Strings(keys)
   406  
   407  	// Recursively collect over any reachability graphs for subjects with non-ellipsis relations.
   408  	for _, entrypointSetKey := range keys {
   409  		entrypointSet := rrg.EntrypointsBySubjectRelation[entrypointSetKey]
   410  		if entrypointSet.SubjectRelation != nil && entrypointSet.SubjectRelation.Relation != tuple.Ellipsis {
   411  			err := rg.collectEntrypoints(ctx, entrypointSet.SubjectRelation, optionalSubjectType, collected, encounteredRelations, reachabilityOption, entrypointLookupOption)
   412  			if err != nil {
   413  				return err
   414  			}
   415  
   416  			if entrypointLookupOption == entrypointLookupFindOne && len(*collected) > 0 {
   417  				return nil
   418  			}
   419  		}
   420  	}
   421  
   422  	return nil
   423  }
   424  
   425  func addEntrypoints(entrypoints *core.ReachabilityEntrypoints, parentRelation *core.RelationReference, collected *[]ReachabilityEntrypoint, encounteredRelations map[string]struct{}) {
   426  	for _, entrypoint := range entrypoints.Entrypoints {
   427  		if entrypoint.TuplesetRelation != "" {
   428  			key := tuple.JoinRelRef(entrypoint.TargetRelation.Namespace, entrypoint.TuplesetRelation)
   429  			encounteredRelations[key] = struct{}{}
   430  		}
   431  
   432  		*collected = append(*collected, ReachabilityEntrypoint{entrypoint, parentRelation})
   433  	}
   434  }