github.com/openfga/openfga@v1.5.4-rc1/pkg/server/commands/reverseexpand/reverse_expand.go (about)

     1  // Package reverseexpand contains the code that handles the ReverseExpand API
     2  package reverseexpand
     3  
     4  import (
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"sync"
     9  	"sync/atomic"
    10  
    11  	"github.com/hashicorp/go-multierror"
    12  	openfgav1 "github.com/openfga/api/proto/openfga/v1"
    13  	"github.com/sourcegraph/conc/pool"
    14  	"go.opentelemetry.io/otel"
    15  	"go.opentelemetry.io/otel/attribute"
    16  	"go.opentelemetry.io/otel/trace"
    17  	"google.golang.org/protobuf/types/known/structpb"
    18  
    19  	"github.com/openfga/openfga/internal/condition"
    20  	"github.com/openfga/openfga/internal/condition/eval"
    21  	"github.com/openfga/openfga/internal/graph"
    22  	serverconfig "github.com/openfga/openfga/internal/server/config"
    23  	"github.com/openfga/openfga/internal/validation"
    24  	"github.com/openfga/openfga/pkg/logger"
    25  	"github.com/openfga/openfga/pkg/storage"
    26  	"github.com/openfga/openfga/pkg/storage/storagewrappers"
    27  	"github.com/openfga/openfga/pkg/telemetry"
    28  	"github.com/openfga/openfga/pkg/tuple"
    29  	"github.com/openfga/openfga/pkg/typesystem"
    30  )
    31  
// tracer is the package-level OpenTelemetry tracer from which every span in
// this file is started.
var tracer = otel.Tracer("openfga/pkg/server/commands/reverse_expand")
    33  
// ReverseExpandRequest describes one reverse expansion: starting from User,
// find objects of type ObjectType that are reachable through Relation.
type ReverseExpandRequest struct {
	// StoreID is forwarded to the datastore when tuples are read.
	StoreID          string
	// ObjectType and Relation together form the target of the expansion.
	ObjectType       string
	Relation         string
	// User is the source the expansion starts from.
	User             IsUserRef
	// ContextualTuples are handed to the combined tuple reader alongside the datastore.
	ContextualTuples []*openfgav1.TupleKey
	// Context is passed to tuple condition evaluation.
	Context          *structpb.Struct

	// edge is the relationship edge through which this request was reached.
	// It is unexported, so requests built outside this package leave it nil;
	// the recursive requests built in this file always set it.
	edge *graph.RelationshipEdge
}
    44  
// IsUserRef is implemented by the user reference kinds that a reverse
// expansion can start from. The unexported isUserRef method keeps the set of
// implementations closed to this package.
type IsUserRef interface {
	// isUserRef is a marker method with no behavior.
	isUserRef()
	// GetObjectType returns the object type of the referenced user.
	GetObjectType() string
	// String returns the string form of the reference.
	String() string
}
    50  
    51  type UserRefObject struct {
    52  	Object *openfgav1.Object
    53  }
    54  
    55  var _ IsUserRef = (*UserRefObject)(nil)
    56  
    57  func (u *UserRefObject) isUserRef() {}
    58  
    59  func (u *UserRefObject) GetObjectType() string {
    60  	return u.Object.GetType()
    61  }
    62  
    63  func (u *UserRefObject) String() string {
    64  	return tuple.BuildObject(u.Object.GetType(), u.Object.GetId())
    65  }
    66  
    67  type UserRefTypedWildcard struct {
    68  	Type string
    69  }
    70  
    71  var _ IsUserRef = (*UserRefTypedWildcard)(nil)
    72  
    73  func (*UserRefTypedWildcard) isUserRef() {}
    74  
    75  func (u *UserRefTypedWildcard) GetObjectType() string {
    76  	return u.Type
    77  }
    78  
    79  func (u *UserRefTypedWildcard) String() string {
    80  	return tuple.TypedPublicWildcard(u.Type)
    81  }
    82  
    83  type UserRefObjectRelation struct {
    84  	ObjectRelation *openfgav1.ObjectRelation
    85  	Condition      *openfgav1.RelationshipCondition
    86  }
    87  
    88  func (*UserRefObjectRelation) isUserRef() {}
    89  
    90  func (u *UserRefObjectRelation) GetObjectType() string {
    91  	return tuple.GetType(u.ObjectRelation.GetObject())
    92  }
    93  
    94  func (u *UserRefObjectRelation) String() string {
    95  	return tuple.ToObjectRelationString(
    96  		u.ObjectRelation.GetObject(),
    97  		u.ObjectRelation.GetRelation(),
    98  	)
    99  }
   100  
// UserRef wraps a single IsUserRef value.
type UserRef struct {

	// Types that are assignable to Ref
	//  *UserRefObject
	//  *UserRefTypedWildcard
	//  *UserRefObjectRelation
	Ref IsUserRef
}
   109  
// ReverseExpandQuery holds the dependencies, limits and de-duplication state
// used while running a reverse expansion.
//
// The two sync.Map fields are created once in NewReverseExpandQuery and are
// never reset in this file, so their contents persist for the lifetime of the
// query value.
type ReverseExpandQuery struct {
	// logger defaults to a no-op logger; override with WithLogger.
	logger                  logger.Logger
	// datastore is the reader tuples are fetched from.
	datastore               storage.RelationshipTupleReader
	// typesystem is the model used to build the relationship graph and to
	// validate and evaluate tuples.
	typesystem              *typesystem.TypeSystem
	// resolveNodeLimit caps the resolution depth tracked on the context.
	resolveNodeLimit        uint32
	// resolveNodeBreadthLimit caps the number of goroutines in each worker pool.
	resolveNodeBreadthLimit uint32

	// visitedUsersetsMap prevents visiting the same userset through the same edge twice
	visitedUsersetsMap *sync.Map
	// candidateObjectsMap prevents returning the same object twice
	candidateObjectsMap *sync.Map
}
   122  
// ReverseExpandQueryOption is a functional option that NewReverseExpandQuery
// applies to the query after its defaults have been set.
type ReverseExpandQueryOption func(d *ReverseExpandQuery)
   124  
   125  func WithResolveNodeLimit(limit uint32) ReverseExpandQueryOption {
   126  	return func(d *ReverseExpandQuery) {
   127  		d.resolveNodeLimit = limit
   128  	}
   129  }
   130  
   131  func WithResolveNodeBreadthLimit(limit uint32) ReverseExpandQueryOption {
   132  	return func(d *ReverseExpandQuery) {
   133  		d.resolveNodeBreadthLimit = limit
   134  	}
   135  }
   136  
   137  func NewReverseExpandQuery(ds storage.RelationshipTupleReader, ts *typesystem.TypeSystem, opts ...ReverseExpandQueryOption) *ReverseExpandQuery {
   138  	query := &ReverseExpandQuery{
   139  		logger:                  logger.NewNoopLogger(),
   140  		datastore:               ds,
   141  		typesystem:              ts,
   142  		resolveNodeLimit:        serverconfig.DefaultResolveNodeLimit,
   143  		resolveNodeBreadthLimit: serverconfig.DefaultResolveNodeBreadthLimit,
   144  		candidateObjectsMap:     new(sync.Map),
   145  		visitedUsersetsMap:      new(sync.Map),
   146  	}
   147  
   148  	for _, opt := range opts {
   149  		opt(query)
   150  	}
   151  
   152  	return query
   153  }
   154  
// ConditionalResultStatus is attached to every ReverseExpandResult and says
// whether the candidate object requires further evaluation.
type ConditionalResultStatus int

const (
	// RequiresFurtherEvalStatus marks a candidate that was reached through at
	// least one edge involving an intersection or exclusion.
	RequiresFurtherEvalStatus ConditionalResultStatus = iota
	// NoFurtherEvalStatus marks every other candidate.
	NoFurtherEvalStatus
)
   161  
// ReverseExpandResult is one candidate object emitted on the result channel.
type ReverseExpandResult struct {
	// Object is the candidate object string that was sent.
	Object       string
	// ResultStatus says whether the candidate requires further evaluation.
	ResultStatus ConditionalResultStatus
}
   166  
   167  type ResolutionMetadata struct {
   168  	DatastoreQueryCount *uint32
   169  
   170  	// The number of times we are expanding from each node to find set of objects
   171  	DispatchCount *uint32
   172  }
   173  
   174  func NewResolutionMetadata() *ResolutionMetadata {
   175  	return &ResolutionMetadata{
   176  		DatastoreQueryCount: new(uint32),
   177  		DispatchCount:       new(uint32),
   178  	}
   179  }
   180  
   181  func WithLogger(logger logger.Logger) ReverseExpandQueryOption {
   182  	return func(d *ReverseExpandQuery) {
   183  		d.logger = logger
   184  	}
   185  }
   186  
   187  // Execute yields all the objects of the provided objectType that the
   188  // given user possibly has, a specific relation with and sends those
   189  // objects to resultChan. It MUST guarantee no duplicate objects sent.
   190  //
   191  // This function respects context timeouts and cancellations. If an
   192  // error is encountered (e.g. context timeout) before resolving all
   193  // objects, then the provided channel will NOT be closed, and it will
   194  // return the error.
   195  //
   196  // If no errors occur, then Execute will yield all of the objects on
   197  // the provided channel and then close the channel to signal that it
   198  // is done.
   199  func (c *ReverseExpandQuery) Execute(
   200  	ctx context.Context,
   201  	req *ReverseExpandRequest,
   202  	resultChan chan<- *ReverseExpandResult,
   203  	resolutionMetadata *ResolutionMetadata,
   204  ) error {
   205  	err := c.execute(ctx, req, resultChan, false, resolutionMetadata)
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	close(resultChan)
   211  	return nil
   212  }
   213  
// execute performs one step of the reverse expansion and recurses until no
// more edges can be followed.
//
// It returns ctx.Err() right away when ctx is already done, and
// graph.ErrResolutionDepthExceeded when the depth stored on ctx has reached
// c.resolveNodeLimit. When req.User is a userset that was already visited
// through req.edge, it returns nil without doing any work. When the userset
// has the target type and relation, its object is offered as a candidate.
//
// It then asks the relationship graph for the pruned edges between the target
// and the source and follows each one: direct and tuple-to-userset edges are
// handed to a worker pool bounded by c.resolveNodeBreadthLimit, while computed
// userset edges are followed inline by recursing. It panics on any other edge
// type.
//
// intersectionOrExclusionInPreviousEdges records whether any edge followed so
// far involves an intersection or exclusion; it is OR-ed with each edge's own
// flag before being passed further down.
//
// Errors from the inline recursion and from the pool are collected, recorded
// on the span and returned together.
func (c *ReverseExpandQuery) execute(
	ctx context.Context,
	req *ReverseExpandRequest,
	resultChan chan<- *ReverseExpandResult,
	intersectionOrExclusionInPreviousEdges bool,
	resolutionMetadata *ResolutionMetadata,
) error {
	if ctx.Err() != nil {
		return ctx.Err()
	}

	ctx, span := tracer.Start(ctx, "reverseExpand.Execute", trace.WithAttributes(
		attribute.String("target_type", req.ObjectType),
		attribute.String("target_relation", req.Relation),
		attribute.String("source", req.User.String()),
	))
	defer span.End()

	if req.edge != nil {
		span.SetAttributes(attribute.String("edge", req.edge.String()))
	}

	// Track the recursion depth on the context: the first call stores 0, and
	// every nested call stores depth+1 unless the limit has been reached.
	depth, ok := graph.ResolutionDepthFromContext(ctx)
	if !ok {
		ctx = graph.ContextWithResolutionDepth(ctx, 0)
	} else {
		if depth >= c.resolveNodeLimit {
			return graph.ErrResolutionDepthExceeded
		}

		ctx = graph.ContextWithResolutionDepth(ctx, depth+1)
	}

	var sourceUserRef *openfgav1.RelationReference
	var sourceUserType, sourceUserObj string

	// e.g. 'user:bob'
	if val, ok := req.User.(*UserRefObject); ok {
		sourceUserType = val.Object.GetType()
		sourceUserObj = tuple.BuildObject(sourceUserType, val.Object.GetId())
		sourceUserRef = typesystem.DirectRelationReference(sourceUserType, "")
	}

	// e.g. 'user:*'
	// Note that sourceUserObj stays empty for a typed wildcard.
	if val, ok := req.User.(*UserRefTypedWildcard); ok {
		sourceUserType = val.Type
		sourceUserRef = typesystem.WildcardRelationReference(sourceUserType)
	}

	// e.g. 'group:eng#member'
	if val, ok := req.User.(*UserRefObjectRelation); ok {
		sourceUserType = tuple.GetType(val.ObjectRelation.GetObject())
		sourceUserObj = val.ObjectRelation.GetObject()
		sourceUserRef = typesystem.DirectRelationReference(sourceUserType, val.ObjectRelation.GetRelation())

		if req.edge != nil {
			// The visited key combines the userset's object with the edge it
			// was reached through.
			key := fmt.Sprintf("%s#%s", sourceUserObj, req.edge.String())
			if _, loaded := c.visitedUsersetsMap.LoadOrStore(key, struct{}{}); loaded {
				// we've already visited this userset through this edge, exit to avoid an infinite cycle
				return nil
			}
		}

		sourceUserRel := val.ObjectRelation.GetRelation()

		// ReverseExpand(type=document, rel=viewer, user=document:1#viewer) will return "document:1"
		if sourceUserType == req.ObjectType && sourceUserRel == req.Relation {
			if err := c.trySendCandidate(ctx, intersectionOrExclusionInPreviousEdges, sourceUserObj, resultChan); err != nil {
				return err
			}
		}
	}

	targetObjRef := typesystem.DirectRelationReference(req.ObjectType, req.Relation)

	g := graph.New(c.typesystem)

	edges, err := g.GetPrunedRelationshipEdges(targetObjRef, sourceUserRef)
	if err != nil {
		return err
	}

	// From here on the local name pool shadows the imported pool package.
	pool := pool.New().WithContext(ctx)
	pool.WithCancelOnError()
	pool.WithFirstError()
	pool.WithMaxGoroutines(int(c.resolveNodeBreadthLimit))
	var errs *multierror.Error

LoopOnEdges:
	for _, edge := range edges {
		// Per-iteration copies, so the closures below capture this edge and
		// this flag rather than variables shared across iterations.
		innerLoopEdge := edge
		intersectionOrExclusionInPreviousEdges := intersectionOrExclusionInPreviousEdges || innerLoopEdge.TargetReferenceInvolvesIntersectionOrExclusion
		r := &ReverseExpandRequest{
			StoreID:          req.StoreID,
			ObjectType:       req.ObjectType,
			Relation:         req.Relation,
			User:             req.User,
			ContextualTuples: req.ContextualTuples,
			Context:          req.Context,
			edge:             innerLoopEdge,
		}
		switch innerLoopEdge.Type {
		case graph.DirectEdge:
			pool.Go(func(ctx context.Context) error {
				return c.reverseExpandDirect(ctx, r, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
			})
		case graph.ComputedUsersetEdge:
			// follow the computed_userset edge, no new goroutine needed since it's not I/O intensive
			r.User = &UserRefObjectRelation{
				ObjectRelation: &openfgav1.ObjectRelation{
					Object:   sourceUserObj,
					Relation: innerLoopEdge.TargetReference.GetRelation(),
				},
			}
			atomic.AddUint32(resolutionMetadata.DispatchCount, 1)
			err = c.execute(ctx, r, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
			if err != nil {
				// Stop scheduling further edges; goroutines already started
				// are still awaited by pool.Wait below.
				errs = multierror.Append(errs, err)
				break LoopOnEdges
			}
		case graph.TupleToUsersetEdge:
			pool.Go(func(ctx context.Context) error {
				return c.reverseExpandTupleToUserset(ctx, r, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
			})
		default:
			panic("unsupported edge type")
		}
	}

	err = pool.Wait()
	if err != nil {
		errs = multierror.Append(errs, err)
	}
	if errs.ErrorOrNil() != nil {
		telemetry.TraceError(span, errs.ErrorOrNil())
		return errs.ErrorOrNil()
	}

	return nil
}
   354  
   355  func (c *ReverseExpandQuery) reverseExpandTupleToUserset(
   356  	ctx context.Context,
   357  	req *ReverseExpandRequest,
   358  	resultChan chan<- *ReverseExpandResult,
   359  	intersectionOrExclusionInPreviousEdges bool,
   360  	resolutionMetadata *ResolutionMetadata,
   361  ) error {
   362  	ctx, span := tracer.Start(ctx, "reverseExpandTupleToUserset", trace.WithAttributes(
   363  		attribute.String("edge", req.edge.String()),
   364  		attribute.String("source.user", req.User.String()),
   365  	))
   366  	var err error
   367  	defer func() {
   368  		if err != nil {
   369  			telemetry.TraceError(span, err)
   370  		}
   371  		span.End()
   372  	}()
   373  
   374  	err = c.readTuplesAndExecute(ctx, req, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
   375  	return err
   376  }
   377  
   378  func (c *ReverseExpandQuery) reverseExpandDirect(
   379  	ctx context.Context,
   380  	req *ReverseExpandRequest,
   381  	resultChan chan<- *ReverseExpandResult,
   382  	intersectionOrExclusionInPreviousEdges bool,
   383  	resolutionMetadata *ResolutionMetadata,
   384  ) error {
   385  	ctx, span := tracer.Start(ctx, "reverseExpandDirect", trace.WithAttributes(
   386  		attribute.String("edge", req.edge.String()),
   387  		attribute.String("source.user", req.User.String()),
   388  	))
   389  	var err error
   390  	defer func() {
   391  		if err != nil {
   392  			telemetry.TraceError(span, err)
   393  		}
   394  		span.End()
   395  	}()
   396  
   397  	err = c.readTuplesAndExecute(ctx, req, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
   398  	return err
   399  }
   400  
// readTuplesAndExecute reads the tuples that connect req.User to the target
// of req.edge and dispatches a new execute call for each object found.
//
// The read filter depends on the edge type:
//   - direct edge: the relation is the edge's target relation. The user
//     filter holds the concrete user or userset, plus the typed wildcard for
//     the user's type when the target reference is publicly assignable.
//   - tuple-to-userset edge: the relation is the edge's tupleset relation and
//     the user filter is the userset's object. The source must be a userset.
//
// It panics on any other edge type, and on a tuple-to-userset edge whose
// source is not a userset.
//
// Each tuple returned by the iterator has its condition evaluated against
// req.Context. Tuples whose condition is not met are skipped, and an error is
// recorded when context parameters are missing. Every remaining tuple is
// turned into a userset and expanded on a worker pool bounded by
// c.resolveNodeBreadthLimit.
//
// It returns ctx.Err() when ctx is already done, and returns errors from the
// publicly-assignable check or from the read directly. Errors from iteration,
// condition evaluation and the dispatched calls are aggregated and returned
// together.
func (c *ReverseExpandQuery) readTuplesAndExecute(
	ctx context.Context,
	req *ReverseExpandRequest,
	resultChan chan<- *ReverseExpandResult,
	intersectionOrExclusionInPreviousEdges bool,
	resolutionMetadata *ResolutionMetadata,
) error {
	if ctx.Err() != nil {
		return ctx.Err()
	}

	ctx, span := tracer.Start(ctx, "readTuplesAndExecute")
	defer span.End()

	var userFilter []*openfgav1.ObjectRelation
	var relationFilter string

	switch req.edge.Type {
	case graph.DirectEdge:
		relationFilter = req.edge.TargetReference.GetRelation()
		targetUserObjectType := req.User.GetObjectType()

		publiclyAssignable, err := c.typesystem.IsPubliclyAssignable(req.edge.TargetReference, targetUserObjectType)
		if err != nil {
			return err
		}

		if publiclyAssignable {
			// e.g. 'user:*'
			userFilter = append(userFilter, &openfgav1.ObjectRelation{
				Object: tuple.TypedPublicWildcard(targetUserObjectType),
			})
		}

		// e.g. 'user:bob'
		if val, ok := req.User.(*UserRefObject); ok {
			userFilter = append(userFilter, &openfgav1.ObjectRelation{
				Object: tuple.BuildObject(val.Object.GetType(), val.Object.GetId()),
			})
		}

		// e.g. 'group:eng#member'
		if val, ok := req.User.(*UserRefObjectRelation); ok {
			userFilter = append(userFilter, val.ObjectRelation)
		}
	case graph.TupleToUsersetEdge:
		relationFilter = req.edge.TuplesetRelation
		// a TTU edge can only have a userset as a source node
		// e.g. 'group:eng#member'
		if val, ok := req.User.(*UserRefObjectRelation); ok {
			// Only the object is used here; the userset's relation is left out of the filter.
			userFilter = append(userFilter, &openfgav1.ObjectRelation{
				Object: val.ObjectRelation.GetObject(),
			})
		} else {
			panic("unexpected source for reverse expansion of tuple to userset")
		}
	default:
		panic("unsupported edge type")
	}

	// The reader wraps the datastore together with the request's contextual tuples.
	combinedTupleReader := storagewrappers.NewCombinedTupleReader(c.datastore, req.ContextualTuples)

	// find all tuples of the form req.edge.TargetReference.Type:...#relationFilter@userFilter
	iter, err := combinedTupleReader.ReadStartingWithUser(ctx, req.StoreID, storage.ReadStartingWithUserFilter{
		ObjectType: req.edge.TargetReference.GetType(),
		Relation:   relationFilter,
		UserFilter: userFilter,
	})
	// The query is counted whether or not the read succeeded.
	atomic.AddUint32(resolutionMetadata.DatastoreQueryCount, 1)
	if err != nil {
		return err
	}

	// filter out invalid tuples yielded by the database iterator
	filteredIter := storage.NewFilteredTupleKeyIterator(
		storage.NewTupleKeyIteratorFromTupleIterator(iter),
		validation.FilterInvalidTuples(c.typesystem),
	)
	defer filteredIter.Stop()

	// From here on the local name pool shadows the imported pool package.
	pool := pool.New().WithContext(ctx)
	pool.WithCancelOnError()
	pool.WithFirstError()
	pool.WithMaxGoroutines(int(c.resolveNodeBreadthLimit))

	var errs *multierror.Error

LoopOnIterator:
	for {
		tk, err := filteredIter.Next(ctx)
		if err != nil {
			if errors.Is(err, storage.ErrIteratorDone) {
				// Normal end of iteration, not an error.
				break
			}
			errs = multierror.Append(errs, err)
			break LoopOnIterator
		}

		condEvalResult, err := eval.EvaluateTupleCondition(ctx, tk, c.typesystem, req.Context)
		if err != nil {
			// Record the evaluation error and move on to the next tuple.
			errs = multierror.Append(errs, err)
			continue
		}

		if !condEvalResult.ConditionMet {
			// A condition that is simply not met skips the tuple silently;
			// missing context parameters are additionally reported as an error.
			if len(condEvalResult.MissingParameters) > 0 {
				errs = multierror.Append(errs, condition.NewEvaluationError(
					tk.GetCondition().GetName(),
					fmt.Errorf("tuple '%s' is missing context parameters '%v'",
						tuple.TupleKeyToString(tk),
						condEvalResult.MissingParameters),
				))
			}

			continue
		}

		foundObject := tk.GetObject()
		var newRelation string

		// The next userset keeps the tuple's own relation for a direct edge,
		// and takes the edge's target relation for a tuple-to-userset edge.
		switch req.edge.Type {
		case graph.DirectEdge:
			newRelation = tk.GetRelation()
		case graph.TupleToUsersetEdge:
			newRelation = req.edge.TargetReference.GetRelation()
		default:
			panic("unsupported edge type")
		}

		// tk, foundObject and newRelation are declared inside the loop body,
		// so each closure captures the values of its own iteration.
		pool.Go(func(ctx context.Context) error {
			atomic.AddUint32(resolutionMetadata.DispatchCount, 1)
			return c.execute(ctx, &ReverseExpandRequest{
				StoreID:    req.StoreID,
				ObjectType: req.ObjectType,
				Relation:   req.Relation,
				User: &UserRefObjectRelation{
					ObjectRelation: &openfgav1.ObjectRelation{
						Object:   foundObject,
						Relation: newRelation,
					},
					Condition: tk.GetCondition(),
				},
				ContextualTuples: req.ContextualTuples,
				Context:          req.Context,
				edge:             req.edge,
			}, resultChan, intersectionOrExclusionInPreviousEdges, resolutionMetadata)
		})
	}

	// Wait for every dispatched call and fold its error in with the ones
	// collected while iterating.
	errs = multierror.Append(errs, pool.Wait())
	if errs.ErrorOrNil() != nil {
		telemetry.TraceError(span, errs.ErrorOrNil())
		return errs
	}

	return nil
}
   558  
   559  func (c *ReverseExpandQuery) trySendCandidate(ctx context.Context, intersectionOrExclusionInPreviousEdges bool, candidateObject string, candidateChan chan<- *ReverseExpandResult) error {
   560  	_, span := tracer.Start(ctx, "trySendCandidate", trace.WithAttributes(
   561  		attribute.String("object", candidateObject),
   562  		attribute.Bool("sent", false),
   563  	))
   564  	defer span.End()
   565  
   566  	if _, ok := c.candidateObjectsMap.LoadOrStore(candidateObject, struct{}{}); !ok {
   567  		resultStatus := NoFurtherEvalStatus
   568  		if intersectionOrExclusionInPreviousEdges {
   569  			span.SetAttributes(attribute.Bool("requires_further_eval", true))
   570  			resultStatus = RequiresFurtherEvalStatus
   571  		}
   572  
   573  		select {
   574  		case <-ctx.Done():
   575  			return ctx.Err()
   576  		case candidateChan <- &ReverseExpandResult{
   577  			Object:       candidateObject,
   578  			ResultStatus: resultStatus,
   579  		}:
   580  			span.SetAttributes(attribute.Bool("sent", true))
   581  		}
   582  	}
   583  
   584  	return nil
   585  }