github.com/openfga/openfga@v1.5.4-rc1/pkg/server/commands/list_objects.go (about)

     1  package commands
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"math"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/hashicorp/go-multierror"
    13  	openfgav1 "github.com/openfga/api/proto/openfga/v1"
    14  	"github.com/prometheus/client_golang/prometheus"
    15  	"github.com/prometheus/client_golang/prometheus/promauto"
    16  	"google.golang.org/protobuf/types/known/structpb"
    17  
    18  	"github.com/openfga/openfga/internal/build"
    19  	"github.com/openfga/openfga/internal/condition"
    20  	"github.com/openfga/openfga/internal/graph"
    21  	serverconfig "github.com/openfga/openfga/internal/server/config"
    22  	"github.com/openfga/openfga/internal/validation"
    23  	"github.com/openfga/openfga/pkg/logger"
    24  	"github.com/openfga/openfga/pkg/server/commands/reverseexpand"
    25  	serverErrors "github.com/openfga/openfga/pkg/server/errors"
    26  	"github.com/openfga/openfga/pkg/storage"
    27  	"github.com/openfga/openfga/pkg/storage/storagewrappers"
    28  	"github.com/openfga/openfga/pkg/tuple"
    29  	"github.com/openfga/openfga/pkg/typesystem"
    30  )
    31  
    32  const streamedBufferSize = 100
    33  
    34  var (
    35  	furtherEvalRequiredCounter = promauto.NewCounter(prometheus.CounterOpts{
    36  		Namespace: build.ProjectName,
    37  		Name:      "list_objects_further_eval_required_count",
    38  		Help:      "Number of objects in a ListObjects call that needed to issue a Check call to determine a final result",
    39  	})
    40  
    41  	noFurtherEvalRequiredCounter = promauto.NewCounter(prometheus.CounterOpts{
    42  		Namespace: build.ProjectName,
    43  		Name:      "list_objects_no_further_eval_required_count",
    44  		Help:      "Number of objects in a ListObjects call that needed to issue a Check call to determine a final result",
    45  	})
    46  )
    47  
// ListObjectsQuery resolves ListObjects requests. It runs reverse expansion
// over the tuples in datastore and, for candidates that need further
// evaluation, dispatches Check calls through checkResolver. Build it with
// NewListObjectsQuery, which fills in defaults from serverconfig.
//
// Fields:
//   - datastore: tuple reader; NewListObjectsQuery wraps it with
//     storagewrappers.NewBoundedConcurrencyTupleReader using maxConcurrentReads.
//   - logger: handed to reverse expansion (a no-op logger by default).
//   - listObjectsDeadline: timeout applied by Execute and ExecuteStreamed; 0 disables it.
//   - listObjectsMaxResults: result cap used by Execute; 0 means no cap.
//     ExecuteStreamed ignores it.
//   - resolveNodeLimit: passed to reverse expansion and to the request
//     metadata of every Check call.
//   - resolveNodeBreadthLimit: passed to reverse expansion and used by
//     evaluate as the number of Check calls allowed to run at once.
//   - maxConcurrentReads: concurrency bound given to the datastore wrapper.
//   - checkResolver: resolves candidates that require further evaluation.
type ListObjectsQuery struct {
	datastore               storage.RelationshipTupleReader
	logger                  logger.Logger
	listObjectsDeadline     time.Duration
	listObjectsMaxResults   uint32
	resolveNodeLimit        uint32
	resolveNodeBreadthLimit uint32
	maxConcurrentReads      uint32

	checkResolver graph.CheckResolver
}
    59  
    60  type ListObjectsResolutionMetadata struct {
    61  	// The total number of database reads from reverse_expand and Check (if any) to complete the ListObjects request
    62  	DatastoreQueryCount *uint32
    63  
    64  	// The total number of dispatches aggregated from reverse_expand and check resolutions (if any) to complete the ListObjects request
    65  	DispatchCount *uint32
    66  }
    67  
    68  func NewListObjectsResolutionMetadata() *ListObjectsResolutionMetadata {
    69  	return &ListObjectsResolutionMetadata{
    70  		DatastoreQueryCount: new(uint32),
    71  		DispatchCount:       new(uint32),
    72  	}
    73  }
    74  
    75  type ListObjectsResponse struct {
    76  	Objects            []string
    77  	ResolutionMetadata ListObjectsResolutionMetadata
    78  }
    79  
    80  type ListObjectsQueryOption func(d *ListObjectsQuery)
    81  
    82  func WithListObjectsDeadline(deadline time.Duration) ListObjectsQueryOption {
    83  	return func(d *ListObjectsQuery) {
    84  		d.listObjectsDeadline = deadline
    85  	}
    86  }
    87  
    88  func WithListObjectsMaxResults(max uint32) ListObjectsQueryOption {
    89  	return func(d *ListObjectsQuery) {
    90  		d.listObjectsMaxResults = max
    91  	}
    92  }
    93  
    94  // WithResolveNodeLimit see server.WithResolveNodeLimit.
    95  func WithResolveNodeLimit(limit uint32) ListObjectsQueryOption {
    96  	return func(d *ListObjectsQuery) {
    97  		d.resolveNodeLimit = limit
    98  	}
    99  }
   100  
   101  // WithResolveNodeBreadthLimit see server.WithResolveNodeBreadthLimit.
   102  func WithResolveNodeBreadthLimit(limit uint32) ListObjectsQueryOption {
   103  	return func(d *ListObjectsQuery) {
   104  		d.resolveNodeBreadthLimit = limit
   105  	}
   106  }
   107  
   108  func WithLogger(l logger.Logger) ListObjectsQueryOption {
   109  	return func(d *ListObjectsQuery) {
   110  		d.logger = l
   111  	}
   112  }
   113  
   114  // WithMaxConcurrentReads see server.WithMaxConcurrentReadsForListObjects.
   115  func WithMaxConcurrentReads(limit uint32) ListObjectsQueryOption {
   116  	return func(d *ListObjectsQuery) {
   117  		d.maxConcurrentReads = limit
   118  	}
   119  }
   120  
   121  func NewListObjectsQuery(
   122  	ds storage.RelationshipTupleReader,
   123  	checkResolver graph.CheckResolver,
   124  	opts ...ListObjectsQueryOption,
   125  ) (*ListObjectsQuery, error) {
   126  	if ds == nil {
   127  		return nil, fmt.Errorf("the provided datastore parameter 'ds' must be non-nil")
   128  	}
   129  
   130  	if checkResolver == nil {
   131  		return nil, fmt.Errorf("the provided CheckResolver parameter 'checkResolver' must be non-nil")
   132  	}
   133  
   134  	query := &ListObjectsQuery{
   135  		datastore:               ds,
   136  		logger:                  logger.NewNoopLogger(),
   137  		listObjectsDeadline:     serverconfig.DefaultListObjectsDeadline,
   138  		listObjectsMaxResults:   serverconfig.DefaultListObjectsMaxResults,
   139  		resolveNodeLimit:        serverconfig.DefaultResolveNodeLimit,
   140  		resolveNodeBreadthLimit: serverconfig.DefaultResolveNodeBreadthLimit,
   141  		maxConcurrentReads:      serverconfig.DefaultMaxConcurrentReadsForListObjects,
   142  		checkResolver:           checkResolver,
   143  	}
   144  
   145  	for _, opt := range opts {
   146  		opt(query)
   147  	}
   148  
   149  	query.datastore = storagewrappers.NewBoundedConcurrencyTupleReader(query.datastore, query.maxConcurrentReads)
   150  
   151  	return query, nil
   152  }
   153  
// ListObjectsResult is one item sent on the results channel by evaluate (and
// by trySendObject). Each item produced in this file carries either an
// ObjectID (an object that is part of the answer) or an Err (a failure
// encountered while resolving), never both.
type ListObjectsResult struct {
	ObjectID string
	Err      error
}
   158  
   159  // listObjectsRequest captures the RPC request definition interface for the ListObjects API.
   160  // The unary and streaming RPC definitions implement this interface, and so it can be used
   161  // interchangeably for a canonical representation between the two.
   162  type listObjectsRequest interface {
   163  	GetStoreId() string
   164  	GetAuthorizationModelId() string
   165  	GetType() string
   166  	GetRelation() string
   167  	GetUser() string
   168  	GetContextualTuples() *openfgav1.ContextualTupleKeys
   169  	GetContext() *structpb.Struct
   170  }
   171  
// evaluate fires off evaluation of the ListObjects query. It delegates to
// reverse expansion (reverseexpand.NewReverseExpandQuery(...).Execute) and
// resolves the results yielded from it. Any result that reverse expansion
// marks as requiring further evaluation is dispatched to q.checkResolver to
// resolve the residual outcome.
//
// The request is validated synchronously: schema version, contextual tuples,
// target type/relation and user. If validation fails, evaluate returns the
// error WITHOUT starting any goroutine. In that case resultsChan is neither
// written to nor closed.
//
// When validation succeeds, evaluate returns nil immediately and the work runs
// in a background goroutine. That goroutine sends object IDs and errors on
// resultsChan and then closes resultsChan once it stops. It stops when any of
// these happens:
//   - reverse expansion has closed its results channel;
//   - ctx is done;
//   - a reverse-expansion result arrives after maxResults objects were already
//     found (maxResults == 0 means no limit);
//   - reverse expansion reported an error.
//
// Before closing, it waits for all dispatched Check goroutines to finish. All
// sends on resultsChan block, so the caller must keep receiving until the
// channel is closed; otherwise the background goroutines cannot finish.
//
// The counters in resolutionMetadata are incremented with atomic.AddUint32
// from several goroutines.
func (q *ListObjectsQuery) evaluate(
	ctx context.Context,
	req listObjectsRequest,
	resultsChan chan<- ListObjectsResult,
	maxResults uint32,
	resolutionMetadata *ListObjectsResolutionMetadata,
) error {
	targetObjectType := req.GetType()
	targetRelation := req.GetRelation()

	// The caller is required to have put the typesystem on ctx; its absence is
	// treated as a programming error.
	typesys, ok := typesystem.TypesystemFromContext(ctx)
	if !ok {
		panic("typesystem missing in context")
	}

	if !typesystem.IsSchemaVersionSupported(typesys.GetSchemaVersion()) {
		return serverErrors.ValidationError(typesystem.ErrInvalidSchemaVersion)
	}

	for _, ctxTuple := range req.GetContextualTuples().GetTupleKeys() {
		if err := validation.ValidateTuple(typesys, ctxTuple); err != nil {
			return serverErrors.HandleTupleValidateError(err)
		}
	}

	// Map "unknown type" / "unknown relation" onto their dedicated API errors.
	_, err := typesys.GetRelation(targetObjectType, targetRelation)
	if err != nil {
		if errors.Is(err, typesystem.ErrObjectTypeUndefined) {
			return serverErrors.TypeNotFound(targetObjectType)
		}

		if errors.Is(err, typesystem.ErrRelationUndefined) {
			return serverErrors.RelationNotFound(targetRelation, targetObjectType, nil)
		}

		return serverErrors.HandleError("", err)
	}

	if err := validation.ValidateUser(typesys, req.GetUser()); err != nil {
		return serverErrors.ValidationError(fmt.Errorf("invalid 'user' value: %s", err))
	}

	// handler runs in its own goroutine and owns resultsChan from here on: it
	// is the only place that closes it.
	handler := func() {
		// Translate the request's user string into the reverse-expansion user
		// reference. The later assignments take precedence: a userset
		// ("obj#rel") beats a typed wildcard, which beats a plain object.
		userObj, userRel := tuple.SplitObjectRelation(req.GetUser())
		userObjType, userObjID := tuple.SplitObject(userObj)

		var sourceUserRef reverseexpand.IsUserRef
		sourceUserRef = &reverseexpand.UserRefObject{
			Object: &openfgav1.Object{
				Type: userObjType,
				Id:   userObjID,
			},
		}

		if tuple.IsTypedWildcard(userObj) {
			sourceUserRef = &reverseexpand.UserRefTypedWildcard{Type: tuple.GetType(userObj)}
		}

		if userRel != "" {
			sourceUserRef = &reverseexpand.UserRefObjectRelation{
				ObjectRelation: &openfgav1.ObjectRelation{
					Object:   userObj,
					Relation: userRel,
				},
			}
		}

		reverseExpandResultsChan := make(chan *reverseexpand.ReverseExpandResult, 1)
		// objectsFound is shared with the Check goroutines below; trySendObject
		// uses it to enforce maxResults across all of them.
		objectsFound := atomic.Uint32{}

		// Reads see the stored tuples combined with the request's contextual tuples.
		ds := storagewrappers.NewCombinedTupleReader(
			q.datastore,
			req.GetContextualTuples().GetTupleKeys(),
		)

		reverseExpandQuery := reverseexpand.NewReverseExpandQuery(
			ds,
			typesys,
			reverseexpand.WithResolveNodeLimit(q.resolveNodeLimit),
			reverseexpand.WithResolveNodeBreadthLimit(q.resolveNodeBreadthLimit),
			reverseexpand.WithLogger(q.logger),
		)

		// cancelCtx is passed only to reverse expansion, so cancel() below
		// stops reverse expansion once the consumer loop exits.
		cancelCtx, cancel := context.WithCancel(ctx)

		wg := sync.WaitGroup{}

		// Buffered so the reverse-expansion goroutine can report its error
		// even if the consumer loop has already exited.
		errChan := make(chan error, 1)

		reverseExpandResolutionMetadata := reverseexpand.NewResolutionMetadata()

		wg.Add(1)
		go func() {
			defer wg.Done()

			err := reverseExpandQuery.Execute(cancelCtx, &reverseexpand.ReverseExpandRequest{
				StoreID:          req.GetStoreId(),
				ObjectType:       targetObjectType,
				Relation:         targetRelation,
				User:             sourceUserRef,
				ContextualTuples: req.GetContextualTuples().GetTupleKeys(),
				Context:          req.GetContext(),
			}, reverseExpandResultsChan, reverseExpandResolutionMetadata)
			if err != nil {
				errChan <- err
			}
			// Fold reverse-expansion statistics into the request totals.
			atomic.AddUint32(resolutionMetadata.DatastoreQueryCount, *reverseExpandResolutionMetadata.DatastoreQueryCount)
			atomic.AddUint32(resolutionMetadata.DispatchCount, *reverseExpandResolutionMetadata.DispatchCount)
		}()

		// The first line assigns to evaluate's ctx parameter, which the closure
		// captures. The second line declares a new ctx local that shadows it for
		// the rest of handler.
		// NOTE(review): the Check calls below use this ctx, derived from the
		// caller's ctx rather than cancelCtx. So cancel() after the loop does
		// not cancel in-flight Check calls, and wg.Wait() waits for them to
		// complete normally.
		ctx = typesystem.ContextWithTypesystem(ctx, typesys)
		ctx := storage.ContextWithRelationshipTupleReader(ctx, ds)

		// Semaphore bounding how many Check calls run at once.
		// NOTE(review): goroutines are spawned before they acquire a slot, so
		// the number of parked goroutines is not bounded. If
		// resolveNodeBreadthLimit were 0 the channel would be unbuffered and
		// every acquire below would block forever.
		concurrencyLimiterCh := make(chan struct{}, q.resolveNodeBreadthLimit)

	ConsumerReadLoop:
		for {
			// NOTE(review): if reverse expansion both sends on errChan and
			// closes reverseExpandResultsChan, select may pick the closed
			// channel first and exit without reporting the error. Confirm the
			// ordering inside reverseexpand.
			select {
			case <-ctx.Done():
				break ConsumerReadLoop
			case res, channelOpen := <-reverseExpandResultsChan:
				if !channelOpen {
					break ConsumerReadLoop
				}

				// Stop consuming once the result limit has been reached.
				if !(maxResults == 0) && objectsFound.Load() >= maxResults {
					break ConsumerReadLoop
				}

				// Result is already final; no Check needed.
				if res.ResultStatus == reverseexpand.NoFurtherEvalStatus {
					noFurtherEvalRequiredCounter.Inc()
					trySendObject(res.Object, &objectsFound, maxResults, resultsChan)
					continue
				}

				furtherEvalRequiredCounter.Inc()

				// Confirm the candidate with a Check call in its own goroutine.
				wg.Add(1)
				go func(res *reverseexpand.ReverseExpandResult) {
					defer func() {
						<-concurrencyLimiterCh
						wg.Done()
					}()

					concurrencyLimiterCh <- struct{}{}
					checkRequestMetadata := graph.NewCheckRequestMetadata(q.resolveNodeLimit)

					resp, err := q.checkResolver.ResolveCheck(ctx, &graph.ResolveCheckRequest{
						StoreID:              req.GetStoreId(),
						AuthorizationModelID: req.GetAuthorizationModelId(),
						TupleKey:             tuple.NewTupleKey(res.Object, req.GetRelation(), req.GetUser()),
						ContextualTuples:     req.GetContextualTuples().GetTupleKeys(),
						Context:              req.GetContext(),
						RequestMetadata:      checkRequestMetadata,
					})
					if err != nil {
						// Depth exhaustion is surfaced as the dedicated API error.
						if errors.Is(err, graph.ErrResolutionDepthExceeded) {
							resultsChan <- ListObjectsResult{Err: serverErrors.AuthorizationModelResolutionTooComplex}
							return
						}

						resultsChan <- ListObjectsResult{Err: err}
						return
					}
					atomic.AddUint32(resolutionMetadata.DatastoreQueryCount, resp.GetResolutionMetadata().DatastoreQueryCount)
					atomic.AddUint32(resolutionMetadata.DispatchCount, checkRequestMetadata.DispatchCounter.Load())

					if resp.Allowed {
						trySendObject(res.Object, &objectsFound, maxResults, resultsChan)
					}
				}(res)

			case err := <-errChan:
				// Reverse expansion failed: report the error and stop consuming.
				if errors.Is(err, graph.ErrResolutionDepthExceeded) {
					err = serverErrors.AuthorizationModelResolutionTooComplex
				}

				resultsChan <- ListObjectsResult{Err: err}
				break ConsumerReadLoop
			}
		}

		// Stop reverse expansion, wait for it and every Check goroutine (all of
		// which may still send on resultsChan), and only then close the channel.
		cancel()
		wg.Wait()
		close(resultsChan)
	}

	go handler()

	return nil
}
   371  
   372  func trySendObject(object string, objectsFound *atomic.Uint32, maxResults uint32, resultsChan chan<- ListObjectsResult) {
   373  	if !(maxResults == 0) {
   374  		if objectsFound.Add(1) > maxResults {
   375  			return
   376  		}
   377  	}
   378  	resultsChan <- ListObjectsResult{ObjectID: object}
   379  }
   380  
   381  // Execute the ListObjectsQuery, returning a list of object IDs up to a maximum of q.listObjectsMaxResults
   382  // or until q.listObjectsDeadline is hit, whichever happens first.
   383  func (q *ListObjectsQuery) Execute(
   384  	ctx context.Context,
   385  	req *openfgav1.ListObjectsRequest,
   386  ) (*ListObjectsResponse, error) {
   387  	resultsChan := make(chan ListObjectsResult, 1)
   388  	maxResults := q.listObjectsMaxResults
   389  	if maxResults > 0 {
   390  		resultsChan = make(chan ListObjectsResult, maxResults)
   391  	}
   392  
   393  	timeoutCtx := ctx
   394  	if q.listObjectsDeadline != 0 {
   395  		var cancel context.CancelFunc
   396  		timeoutCtx, cancel = context.WithTimeout(ctx, q.listObjectsDeadline)
   397  		defer cancel()
   398  	}
   399  
   400  	resolutionMetadata := NewListObjectsResolutionMetadata()
   401  
   402  	err := q.evaluate(timeoutCtx, req, resultsChan, maxResults, resolutionMetadata)
   403  	if err != nil {
   404  		return nil, err
   405  	}
   406  
   407  	objects := make([]string, 0)
   408  
   409  	var errs *multierror.Error
   410  
   411  	for result := range resultsChan {
   412  		if result.Err != nil {
   413  			if errors.Is(result.Err, serverErrors.AuthorizationModelResolutionTooComplex) {
   414  				return nil, result.Err
   415  			}
   416  
   417  			if errors.Is(result.Err, condition.ErrEvaluationFailed) {
   418  				errs = multierror.Append(errs, result.Err)
   419  				continue
   420  			}
   421  
   422  			if errors.Is(result.Err, context.Canceled) || errors.Is(result.Err, context.DeadlineExceeded) {
   423  				continue
   424  			}
   425  
   426  			return nil, serverErrors.HandleError("", result.Err)
   427  		}
   428  
   429  		objects = append(objects, result.ObjectID)
   430  	}
   431  
   432  	if len(objects) < int(maxResults) && errs.ErrorOrNil() != nil {
   433  		return nil, errs
   434  	}
   435  
   436  	return &ListObjectsResponse{
   437  		Objects:            objects,
   438  		ResolutionMetadata: *resolutionMetadata,
   439  	}, nil
   440  }
   441  
   442  // ExecuteStreamed executes the ListObjectsQuery, returning a stream of object IDs.
   443  // It ignores the value of q.listObjectsMaxResults and returns all available results
   444  // until q.listObjectsDeadline is hit.
   445  func (q *ListObjectsQuery) ExecuteStreamed(ctx context.Context, req *openfgav1.StreamedListObjectsRequest, srv openfgav1.OpenFGAService_StreamedListObjectsServer) (*ListObjectsResolutionMetadata, error) {
   446  	maxResults := uint32(math.MaxUint32)
   447  	// make a buffered channel so that writer goroutines aren't blocked when attempting to send a result
   448  	resultsChan := make(chan ListObjectsResult, streamedBufferSize)
   449  
   450  	timeoutCtx := ctx
   451  	if q.listObjectsDeadline != 0 {
   452  		var cancel context.CancelFunc
   453  		timeoutCtx, cancel = context.WithTimeout(ctx, q.listObjectsDeadline)
   454  		defer cancel()
   455  	}
   456  
   457  	resolutionMetadata := NewListObjectsResolutionMetadata()
   458  
   459  	err := q.evaluate(timeoutCtx, req, resultsChan, maxResults, resolutionMetadata)
   460  	if err != nil {
   461  		return nil, err
   462  	}
   463  
   464  	for result := range resultsChan {
   465  		if result.Err != nil {
   466  			if errors.Is(result.Err, serverErrors.AuthorizationModelResolutionTooComplex) {
   467  				return nil, result.Err
   468  			}
   469  
   470  			if errors.Is(result.Err, condition.ErrEvaluationFailed) {
   471  				return nil, serverErrors.ValidationError(result.Err)
   472  			}
   473  
   474  			return nil, serverErrors.HandleError("", result.Err)
   475  		}
   476  
   477  		if err := srv.Send(&openfgav1.StreamedListObjectsResponse{
   478  			Object: result.ObjectID,
   479  		}); err != nil {
   480  			return nil, serverErrors.HandleError("", err)
   481  		}
   482  	}
   483  
   484  	return resolutionMetadata, nil
   485  }