github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/graph/expand.go (about)

     1  package graph
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  
     8  	"github.com/authzed/spicedb/internal/caveats"
     9  
    10  	"github.com/authzed/spicedb/internal/dispatch"
    11  	log "github.com/authzed/spicedb/internal/logging"
    12  	datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
    13  	"github.com/authzed/spicedb/internal/namespace"
    14  	"github.com/authzed/spicedb/pkg/datastore"
    15  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    16  	v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    17  	"github.com/authzed/spicedb/pkg/spiceerrors"
    18  )
    19  
    20  // NewConcurrentExpander creates an instance of ConcurrentExpander
    21  func NewConcurrentExpander(d dispatch.Expand) *ConcurrentExpander {
    22  	return &ConcurrentExpander{d: d}
    23  }
    24  
    25  // ConcurrentExpander exposes a method to perform Expand requests, and delegates subproblems to the
    26  // provided dispatch.Expand instance.
    27  type ConcurrentExpander struct {
    28  	d dispatch.Expand
    29  }
    30  
    31  // ValidatedExpandRequest represents a request after it has been validated and parsed for internal
    32  // consumption.
    33  type ValidatedExpandRequest struct {
    34  	*v1.DispatchExpandRequest
    35  	Revision datastore.Revision
    36  }
    37  
    38  // Expand performs an expand request with the provided request and context.
    39  func (ce *ConcurrentExpander) Expand(ctx context.Context, req ValidatedExpandRequest, relation *core.Relation) (*v1.DispatchExpandResponse, error) {
    40  	log.Ctx(ctx).Trace().Object("expand", req).Send()
    41  
    42  	var directFunc ReduceableExpandFunc
    43  	if relation.UsersetRewrite == nil {
    44  		directFunc = ce.expandDirect(ctx, req)
    45  	} else {
    46  		directFunc = ce.expandUsersetRewrite(ctx, req, relation.UsersetRewrite)
    47  	}
    48  
    49  	resolved := expandOne(ctx, directFunc)
    50  	resolved.Resp.Metadata = addCallToResponseMetadata(resolved.Resp.Metadata)
    51  	return resolved.Resp, resolved.Err
    52  }
    53  
    54  func (ce *ConcurrentExpander) expandDirect(
    55  	ctx context.Context,
    56  	req ValidatedExpandRequest,
    57  ) ReduceableExpandFunc {
    58  	log.Ctx(ctx).Trace().Object("direct", req).Send()
    59  	return func(ctx context.Context, resultChan chan<- ExpandResult) {
    60  		ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision)
    61  		it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{
    62  			OptionalResourceType:     req.ResourceAndRelation.Namespace,
    63  			OptionalResourceIds:      []string{req.ResourceAndRelation.ObjectId},
    64  			OptionalResourceRelation: req.ResourceAndRelation.Relation,
    65  		})
    66  		if err != nil {
    67  			resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata)
    68  			return
    69  		}
    70  		defer it.Close()
    71  
    72  		var foundNonTerminalUsersets []*core.DirectSubject
    73  		var foundTerminalUsersets []*core.DirectSubject
    74  		for tpl := it.Next(); tpl != nil; tpl = it.Next() {
    75  			if it.Err() != nil {
    76  				resultChan <- expandResultError(NewExpansionFailureErr(it.Err()), emptyMetadata)
    77  				return
    78  			}
    79  
    80  			ds := &core.DirectSubject{
    81  				Subject:          tpl.Subject,
    82  				CaveatExpression: caveats.CaveatAsExpr(tpl.Caveat),
    83  			}
    84  			if tpl.Subject.Relation == Ellipsis {
    85  				foundTerminalUsersets = append(foundTerminalUsersets, ds)
    86  			} else {
    87  				foundNonTerminalUsersets = append(foundNonTerminalUsersets, ds)
    88  			}
    89  		}
    90  		it.Close()
    91  
    92  		// If only shallow expansion was required, or there are no non-terminal subjects found,
    93  		// nothing more to do.
    94  		if req.ExpansionMode == v1.DispatchExpandRequest_SHALLOW || len(foundNonTerminalUsersets) == 0 {
    95  			resultChan <- expandResult(
    96  				&core.RelationTupleTreeNode{
    97  					NodeType: &core.RelationTupleTreeNode_LeafNode{
    98  						LeafNode: &core.DirectSubjects{
    99  							Subjects: append(foundTerminalUsersets, foundNonTerminalUsersets...),
   100  						},
   101  					},
   102  					Expanded: req.ResourceAndRelation,
   103  				},
   104  				emptyMetadata,
   105  			)
   106  			return
   107  		}
   108  
   109  		// Otherwise, recursively issue expansion and collect the results from that, plus the
   110  		// found terminals together.
   111  		var requestsToDispatch []ReduceableExpandFunc
   112  		for _, nonTerminalUser := range foundNonTerminalUsersets {
   113  			toDispatch := ce.dispatch(ValidatedExpandRequest{
   114  				&v1.DispatchExpandRequest{
   115  					ResourceAndRelation: nonTerminalUser.Subject,
   116  					Metadata:            decrementDepth(req.Metadata),
   117  					ExpansionMode:       req.ExpansionMode,
   118  				},
   119  				req.Revision,
   120  			})
   121  
   122  			requestsToDispatch = append(requestsToDispatch, decorateWithCaveatIfNecessary(toDispatch, nonTerminalUser.CaveatExpression))
   123  		}
   124  
   125  		result := expandAny(ctx, req.ResourceAndRelation, requestsToDispatch)
   126  		if result.Err != nil {
   127  			resultChan <- result
   128  			return
   129  		}
   130  
   131  		unionNode := result.Resp.TreeNode.GetIntermediateNode()
   132  		unionNode.ChildNodes = append(unionNode.ChildNodes, &core.RelationTupleTreeNode{
   133  			NodeType: &core.RelationTupleTreeNode_LeafNode{
   134  				LeafNode: &core.DirectSubjects{
   135  					Subjects: append(foundTerminalUsersets, foundNonTerminalUsersets...),
   136  				},
   137  			},
   138  			Expanded: req.ResourceAndRelation,
   139  		})
   140  		resultChan <- result
   141  	}
   142  }
   143  
   144  func decorateWithCaveatIfNecessary(toDispatch ReduceableExpandFunc, caveatExpr *core.CaveatExpression) ReduceableExpandFunc {
   145  	// If no caveat expression, simply return the func unmodified.
   146  	if caveatExpr == nil {
   147  		return toDispatch
   148  	}
   149  
   150  	// Otherwise return a wrapped function that expands the underlying func to be dispatched, and then decorates
   151  	// the resulting node with the caveat expression.
   152  	//
   153  	// TODO(jschorr): This will generate a lot of function closures, so we should change Expand to avoid them
   154  	// like we did in Check.
   155  	return func(ctx context.Context, resultChan chan<- ExpandResult) {
   156  		result := expandOne(ctx, toDispatch)
   157  		if result.Err != nil {
   158  			resultChan <- result
   159  			return
   160  		}
   161  
   162  		result.Resp.TreeNode.CaveatExpression = caveatExpr
   163  		resultChan <- result
   164  	}
   165  }
   166  
   167  func (ce *ConcurrentExpander) expandUsersetRewrite(ctx context.Context, req ValidatedExpandRequest, usr *core.UsersetRewrite) ReduceableExpandFunc {
   168  	switch rw := usr.RewriteOperation.(type) {
   169  	case *core.UsersetRewrite_Union:
   170  		log.Ctx(ctx).Trace().Msg("union")
   171  		return ce.expandSetOperation(ctx, req, rw.Union, expandAny)
   172  	case *core.UsersetRewrite_Intersection:
   173  		log.Ctx(ctx).Trace().Msg("intersection")
   174  		return ce.expandSetOperation(ctx, req, rw.Intersection, expandAll)
   175  	case *core.UsersetRewrite_Exclusion:
   176  		log.Ctx(ctx).Trace().Msg("exclusion")
   177  		return ce.expandSetOperation(ctx, req, rw.Exclusion, expandDifference)
   178  	default:
   179  		return alwaysFailExpand
   180  	}
   181  }
   182  
   183  func (ce *ConcurrentExpander) expandSetOperation(ctx context.Context, req ValidatedExpandRequest, so *core.SetOperation, reducer ExpandReducer) ReduceableExpandFunc {
   184  	var requests []ReduceableExpandFunc
   185  	for _, childOneof := range so.Child {
   186  		switch child := childOneof.ChildType.(type) {
   187  		case *core.SetOperation_Child_XThis:
   188  			return expandError(errors.New("use of _this is unsupported; please rewrite your schema"))
   189  		case *core.SetOperation_Child_ComputedUserset:
   190  			requests = append(requests, ce.expandComputedUserset(ctx, req, child.ComputedUserset, nil))
   191  		case *core.SetOperation_Child_UsersetRewrite:
   192  			requests = append(requests, ce.expandUsersetRewrite(ctx, req, child.UsersetRewrite))
   193  		case *core.SetOperation_Child_TupleToUserset:
   194  			requests = append(requests, ce.expandTupleToUserset(ctx, req, child.TupleToUserset))
   195  		case *core.SetOperation_Child_XNil:
   196  			requests = append(requests, emptyExpansion(req.ResourceAndRelation))
   197  		default:
   198  			return expandError(fmt.Errorf("unknown set operation child `%T` in expand", child))
   199  		}
   200  	}
   201  	return func(ctx context.Context, resultChan chan<- ExpandResult) {
   202  		resultChan <- reducer(ctx, req.ResourceAndRelation, requests)
   203  	}
   204  }
   205  
   206  func (ce *ConcurrentExpander) dispatch(req ValidatedExpandRequest) ReduceableExpandFunc {
   207  	return func(ctx context.Context, resultChan chan<- ExpandResult) {
   208  		log.Ctx(ctx).Trace().Object("dispatchExpand", req).Send()
   209  		result, err := ce.d.DispatchExpand(ctx, req.DispatchExpandRequest)
   210  		resultChan <- ExpandResult{result, err}
   211  	}
   212  }
   213  
   214  func (ce *ConcurrentExpander) expandComputedUserset(ctx context.Context, req ValidatedExpandRequest, cu *core.ComputedUserset, tpl *core.RelationTuple) ReduceableExpandFunc {
   215  	log.Ctx(ctx).Trace().Str("relation", cu.Relation).Msg("computed userset")
   216  	var start *core.ObjectAndRelation
   217  	if cu.Object == core.ComputedUserset_TUPLE_USERSET_OBJECT {
   218  		if tpl == nil {
   219  			return expandError(spiceerrors.MustBugf("computed userset for tupleset without tuple"))
   220  		}
   221  
   222  		start = tpl.Subject
   223  	} else if cu.Object == core.ComputedUserset_TUPLE_OBJECT {
   224  		if tpl != nil {
   225  			start = tpl.ResourceAndRelation
   226  		} else {
   227  			start = req.ResourceAndRelation
   228  		}
   229  	}
   230  
   231  	// Check if the target relation exists. If not, return nothing.
   232  	ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision)
   233  	err := namespace.CheckNamespaceAndRelation(ctx, start.Namespace, cu.Relation, true, ds)
   234  	if err != nil {
   235  		if errors.As(err, &namespace.ErrRelationNotFound{}) {
   236  			return emptyExpansion(req.ResourceAndRelation)
   237  		}
   238  
   239  		return expandError(err)
   240  	}
   241  
   242  	return ce.dispatch(ValidatedExpandRequest{
   243  		&v1.DispatchExpandRequest{
   244  			ResourceAndRelation: &core.ObjectAndRelation{
   245  				Namespace: start.Namespace,
   246  				ObjectId:  start.ObjectId,
   247  				Relation:  cu.Relation,
   248  			},
   249  			Metadata:      decrementDepth(req.Metadata),
   250  			ExpansionMode: req.ExpansionMode,
   251  		},
   252  		req.Revision,
   253  	})
   254  }
   255  
   256  func (ce *ConcurrentExpander) expandTupleToUserset(_ context.Context, req ValidatedExpandRequest, ttu *core.TupleToUserset) ReduceableExpandFunc {
   257  	return func(ctx context.Context, resultChan chan<- ExpandResult) {
   258  		ds := datastoremw.MustFromContext(ctx).SnapshotReader(req.Revision)
   259  		it, err := ds.QueryRelationships(ctx, datastore.RelationshipsFilter{
   260  			OptionalResourceType:     req.ResourceAndRelation.Namespace,
   261  			OptionalResourceIds:      []string{req.ResourceAndRelation.ObjectId},
   262  			OptionalResourceRelation: ttu.Tupleset.Relation,
   263  		})
   264  		if err != nil {
   265  			resultChan <- expandResultError(NewExpansionFailureErr(err), emptyMetadata)
   266  			return
   267  		}
   268  		defer it.Close()
   269  
   270  		var requestsToDispatch []ReduceableExpandFunc
   271  		for tpl := it.Next(); tpl != nil; tpl = it.Next() {
   272  			if it.Err() != nil {
   273  				resultChan <- expandResultError(NewExpansionFailureErr(it.Err()), emptyMetadata)
   274  				return
   275  			}
   276  
   277  			toDispatch := ce.expandComputedUserset(ctx, req, ttu.ComputedUserset, tpl)
   278  			requestsToDispatch = append(requestsToDispatch, decorateWithCaveatIfNecessary(toDispatch, caveats.CaveatAsExpr(tpl.Caveat)))
   279  		}
   280  		it.Close()
   281  
   282  		resultChan <- expandAny(ctx, req.ResourceAndRelation, requestsToDispatch)
   283  	}
   284  }
   285  
   286  func setResult(
   287  	op core.SetOperationUserset_Operation,
   288  	start *core.ObjectAndRelation,
   289  	children []*core.RelationTupleTreeNode,
   290  	metadata *v1.ResponseMeta,
   291  ) ExpandResult {
   292  	return expandResult(
   293  		&core.RelationTupleTreeNode{
   294  			NodeType: &core.RelationTupleTreeNode_IntermediateNode{
   295  				IntermediateNode: &core.SetOperationUserset{
   296  					Operation:  op,
   297  					ChildNodes: children,
   298  				},
   299  			},
   300  			Expanded: start,
   301  		},
   302  		metadata,
   303  	)
   304  }
   305  
   306  func expandSetOperation(
   307  	ctx context.Context,
   308  	start *core.ObjectAndRelation,
   309  	requests []ReduceableExpandFunc,
   310  	op core.SetOperationUserset_Operation,
   311  ) ExpandResult {
   312  	children := make([]*core.RelationTupleTreeNode, 0, len(requests))
   313  
   314  	if len(requests) == 0 {
   315  		return setResult(op, start, children, emptyMetadata)
   316  	}
   317  
   318  	childCtx, cancelFn := context.WithCancel(ctx)
   319  	defer cancelFn()
   320  
   321  	resultChans := make([]chan ExpandResult, 0, len(requests))
   322  	for _, req := range requests {
   323  		resultChan := make(chan ExpandResult, 1)
   324  		resultChans = append(resultChans, resultChan)
   325  		go req(childCtx, resultChan)
   326  	}
   327  
   328  	responseMetadata := emptyMetadata
   329  	for _, resultChan := range resultChans {
   330  		select {
   331  		case result := <-resultChan:
   332  			responseMetadata = combineResponseMetadata(responseMetadata, result.Resp.Metadata)
   333  			if result.Err != nil {
   334  				return expandResultError(result.Err, responseMetadata)
   335  			}
   336  			children = append(children, result.Resp.TreeNode)
   337  		case <-ctx.Done():
   338  			return expandResultError(context.Canceled, responseMetadata)
   339  		}
   340  	}
   341  
   342  	return setResult(op, start, children, responseMetadata)
   343  }
   344  
   345  // emptyExpansion returns an empty expansion.
   346  func emptyExpansion(start *core.ObjectAndRelation) ReduceableExpandFunc {
   347  	return func(ctx context.Context, resultChan chan<- ExpandResult) {
   348  		resultChan <- expandResult(&core.RelationTupleTreeNode{
   349  			NodeType: &core.RelationTupleTreeNode_LeafNode{
   350  				LeafNode: &core.DirectSubjects{},
   351  			},
   352  			Expanded: start,
   353  		}, emptyMetadata)
   354  	}
   355  }
   356  
   357  // expandError returns the error.
   358  func expandError(err error) ReduceableExpandFunc {
   359  	return func(ctx context.Context, resultChan chan<- ExpandResult) {
   360  		resultChan <- expandResultError(err, emptyMetadata)
   361  	}
   362  }
   363  
   364  // expandAll returns a tree with all of the children and an intersection node type.
   365  func expandAll(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult {
   366  	return expandSetOperation(ctx, start, requests, core.SetOperationUserset_INTERSECTION)
   367  }
   368  
   369  // expandAny returns a tree with all of the children and a union node type.
   370  func expandAny(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult {
   371  	return expandSetOperation(ctx, start, requests, core.SetOperationUserset_UNION)
   372  }
   373  
   374  // expandDifference returns a tree with all of the children and an exclusion node type.
   375  func expandDifference(ctx context.Context, start *core.ObjectAndRelation, requests []ReduceableExpandFunc) ExpandResult {
   376  	return expandSetOperation(ctx, start, requests, core.SetOperationUserset_EXCLUSION)
   377  }
   378  
   379  // expandOne waits for exactly one response
   380  func expandOne(ctx context.Context, request ReduceableExpandFunc) ExpandResult {
   381  	resultChan := make(chan ExpandResult, 1)
   382  	go request(ctx, resultChan)
   383  
   384  	select {
   385  	case result := <-resultChan:
   386  		if result.Err != nil {
   387  			return result
   388  		}
   389  		return result
   390  	case <-ctx.Done():
   391  		return expandResultError(context.Canceled, emptyMetadata)
   392  	}
   393  }
   394  
   395  var errAlwaysFailExpand = errors.New("always fail")
   396  
   397  func alwaysFailExpand(_ context.Context, resultChan chan<- ExpandResult) {
   398  	resultChan <- expandResultError(errAlwaysFailExpand, emptyMetadata)
   399  }
   400  
   401  func expandResult(treeNode *core.RelationTupleTreeNode, subProblemMetadata *v1.ResponseMeta) ExpandResult {
   402  	return ExpandResult{
   403  		&v1.DispatchExpandResponse{
   404  			Metadata: ensureMetadata(subProblemMetadata),
   405  			TreeNode: treeNode,
   406  		},
   407  		nil,
   408  	}
   409  }
   410  
   411  func expandResultError(err error, subProblemMetadata *v1.ResponseMeta) ExpandResult {
   412  	return ExpandResult{
   413  		&v1.DispatchExpandResponse{
   414  			Metadata: ensureMetadata(subProblemMetadata),
   415  		},
   416  		err,
   417  	}
   418  }