github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/graph/cursors.go (about)

     1  package graph
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"strconv"
     7  	"sync"
     8  
     9  	"github.com/authzed/spicedb/internal/dispatch"
    10  	"github.com/authzed/spicedb/internal/taskrunner"
    11  	"github.com/authzed/spicedb/pkg/datastore/options"
    12  	v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    13  	"github.com/authzed/spicedb/pkg/spiceerrors"
    14  	"github.com/authzed/spicedb/pkg/tuple"
    15  )
    16  
    17  // cursorInformation is a struct which holds information about the current incoming cursor (if any)
    18  // and the sections to be added to the *outgoing* partial cursor.
    19  type cursorInformation struct {
    20  	// currentCursor is the current incoming cursor. This may be nil.
    21  	currentCursor *v1.Cursor
    22  
    23  	// outgoingCursorSections are the sections to be added to the outgoing *partial* cursor.
    24  	// It is the responsibility of the *caller* to append together the incoming cursors to form
    25  	// the final cursor.
    26  	//
    27  	// A `section` is a portion of the cursor, representing a section of code that was
    28  	// executed to produce the section of the cursor.
    29  	outgoingCursorSections []string
    30  
    31  	// limits is the limits tracker for the call over which the cursor is being used.
    32  	limits *limitTracker
    33  
    34  	// dispatchCursorVersion is the version of the dispatch to be stored in the cursor.
    35  	dispatchCursorVersion uint32
    36  }
    37  
    38  // newCursorInformation constructs a new cursorInformation struct from the incoming cursor (which
    39  // may be nil)
    40  func newCursorInformation(incomingCursor *v1.Cursor, limits *limitTracker, dispatchCursorVersion uint32) (cursorInformation, error) {
    41  	if incomingCursor != nil && incomingCursor.DispatchVersion != dispatchCursorVersion {
    42  		return cursorInformation{}, NewInvalidCursorErr(dispatchCursorVersion, incomingCursor)
    43  	}
    44  
    45  	if dispatchCursorVersion == 0 {
    46  		return cursorInformation{}, spiceerrors.MustBugf("invalid dispatch cursor version")
    47  	}
    48  
    49  	return cursorInformation{
    50  		currentCursor:          incomingCursor,
    51  		outgoingCursorSections: nil,
    52  		limits:                 limits,
    53  		dispatchCursorVersion:  dispatchCursorVersion,
    54  	}, nil
    55  }
    56  
    57  // responsePartialCursor is the *partial* cursor to return in a response.
    58  func (ci cursorInformation) responsePartialCursor() *v1.Cursor {
    59  	return &v1.Cursor{
    60  		DispatchVersion: ci.dispatchCursorVersion,
    61  		Sections:        ci.outgoingCursorSections,
    62  	}
    63  }
    64  
    65  // withClonedLimits returns the cursor, but with its limits tracker cloned.
    66  func (ci cursorInformation) withClonedLimits() cursorInformation {
    67  	return cursorInformation{
    68  		currentCursor:          ci.currentCursor,
    69  		outgoingCursorSections: ci.outgoingCursorSections,
    70  		limits:                 ci.limits.clone(),
    71  		dispatchCursorVersion:  ci.dispatchCursorVersion,
    72  	}
    73  }
    74  
    75  // headSectionValue returns the string value found at the head of the incoming cursor.
    76  // If the incoming cursor is empty, returns empty.
    77  func (ci cursorInformation) headSectionValue() (string, bool) {
    78  	if ci.currentCursor == nil || len(ci.currentCursor.Sections) < 1 {
    79  		return "", false
    80  	}
    81  
    82  	return ci.currentCursor.Sections[0], true
    83  }
    84  
    85  // integerSectionValue returns the *integer* found  at the head of the incoming cursor.
    86  // If the incoming cursor is empty, returns 0. If the incoming cursor does not start with an
    87  // int value, fails with an error.
    88  func (ci cursorInformation) integerSectionValue() (int, error) {
    89  	valueStr, hasValue := ci.headSectionValue()
    90  	if !hasValue {
    91  		return 0, nil
    92  	}
    93  
    94  	if valueStr == "" {
    95  		return 0, nil
    96  	}
    97  
    98  	return strconv.Atoi(valueStr)
    99  }
   100  
   101  // withOutgoingSection returns cursorInformation updated with the given optional
   102  // value appended to the outgoingCursorSections for the current cursor. If the current
   103  // cursor already begins with any values, those values are replaced.
   104  func (ci cursorInformation) withOutgoingSection(value string) (cursorInformation, error) {
   105  	ocs := make([]string, 0, len(ci.outgoingCursorSections)+1)
   106  	ocs = append(ocs, ci.outgoingCursorSections...)
   107  	ocs = append(ocs, value)
   108  
   109  	if ci.currentCursor != nil && len(ci.currentCursor.Sections) > 0 {
   110  		// If the cursor already has values, replace them with those specified.
   111  		return cursorInformation{
   112  			currentCursor: &v1.Cursor{
   113  				DispatchVersion: ci.dispatchCursorVersion,
   114  				Sections:        ci.currentCursor.Sections[1:],
   115  			},
   116  			outgoingCursorSections: ocs,
   117  			limits:                 ci.limits,
   118  			dispatchCursorVersion:  ci.dispatchCursorVersion,
   119  		}, nil
   120  	}
   121  
   122  	return cursorInformation{
   123  		currentCursor:          nil,
   124  		outgoingCursorSections: ocs,
   125  		limits:                 ci.limits,
   126  		dispatchCursorVersion:  ci.dispatchCursorVersion,
   127  	}, nil
   128  }
   129  
   130  func (ci cursorInformation) clearIncoming() cursorInformation {
   131  	return cursorInformation{
   132  		currentCursor:          nil,
   133  		outgoingCursorSections: ci.outgoingCursorSections,
   134  		limits:                 ci.limits,
   135  		dispatchCursorVersion:  ci.dispatchCursorVersion,
   136  	}
   137  }
   138  
   139  // itemAndPostCursor represents an item and the cursor to be used for all items after it.
   140  type itemAndPostCursor[T any] struct {
   141  	item   T
   142  	cursor options.Cursor
   143  }
   144  
   145  // withDatastoreCursorInCursor executes the given lookup function to retrieve items from the datastore,
   146  // and then executes the handler on each of the produced items *in parallel*, streaming the results
   147  // in the correct order to the parent stream.
   148  func withDatastoreCursorInCursor[T any, Q any](
   149  	ctx context.Context,
   150  	ci cursorInformation,
   151  	parentStream dispatch.Stream[Q],
   152  	concurrencyLimit uint16,
   153  	lookup func(queryCursor options.Cursor) ([]itemAndPostCursor[T], error),
   154  	handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error,
   155  ) error {
   156  	// Retrieve the *datastore* cursor, if one is found at the head of the incoming cursor.
   157  	var datastoreCursor options.Cursor
   158  	datastoreCursorString, _ := ci.headSectionValue()
   159  	if datastoreCursorString != "" {
   160  		datastoreCursor = tuple.MustParse(datastoreCursorString)
   161  	}
   162  
   163  	if ci.limits.hasExhaustedLimit() {
   164  		return nil
   165  	}
   166  
   167  	// Execute the lookup to call the database and find items for processing.
   168  	itemsToBeProcessed, err := lookup(datastoreCursor)
   169  	if err != nil {
   170  		return err
   171  	}
   172  
   173  	if len(itemsToBeProcessed) == 0 {
   174  		return nil
   175  	}
   176  
   177  	itemsToRun := make([]T, 0, len(itemsToBeProcessed))
   178  	for _, itemAndCursor := range itemsToBeProcessed {
   179  		itemsToRun = append(itemsToRun, itemAndCursor.item)
   180  	}
   181  
   182  	getItemCursor := func(taskIndex int) (cursorInformation, error) {
   183  		// Create an updated cursor referencing the current item's cursor, so that any items returned know to resume from this point.
   184  		currentCursor, err := ci.withOutgoingSection(tuple.StringWithoutCaveat(itemsToBeProcessed[taskIndex].cursor))
   185  		if err != nil {
   186  			return currentCursor, err
   187  		}
   188  
   189  		// If not the first iteration, we need to clear incoming sections to ensure the iteration starts at the top
   190  		// of the cursor.
   191  		if taskIndex > 0 {
   192  			currentCursor = currentCursor.clearIncoming()
   193  		}
   194  
   195  		return currentCursor, nil
   196  	}
   197  
   198  	return withInternalParallelizedStreamingIterableInCursor(
   199  		ctx,
   200  		ci,
   201  		itemsToRun,
   202  		parentStream,
   203  		concurrencyLimit,
   204  		getItemCursor,
   205  		handler,
   206  	)
   207  }
   208  
   209  type (
   210  	afterResponseCursor func(nextOffset int) *v1.Cursor
   211  	cursorHandler       func(c cursorInformation) error
   212  )
   213  
   214  // withSubsetInCursor executes the given handler with the offset index found at the beginning of the
   215  // cursor. If the offset is not found, executes with 0. The handler is given the current offset as
   216  // well as a callback to mint the cursor with the next offset.
   217  func withSubsetInCursor(
   218  	ci cursorInformation,
   219  	handler func(currentOffset int, nextCursorWith afterResponseCursor) error,
   220  	next cursorHandler,
   221  ) error {
   222  	if ci.limits.hasExhaustedLimit() {
   223  		return nil
   224  	}
   225  
   226  	afterIndex, err := ci.integerSectionValue()
   227  	if err != nil {
   228  		return err
   229  	}
   230  
   231  	if afterIndex >= 0 {
   232  		var foundCerr error
   233  		err = handler(afterIndex, func(nextOffset int) *v1.Cursor {
   234  			cursor, cerr := ci.withOutgoingSection(strconv.Itoa(nextOffset))
   235  			foundCerr = cerr
   236  			if cerr != nil {
   237  				return nil
   238  			}
   239  
   240  			return cursor.responsePartialCursor()
   241  		})
   242  		if err != nil {
   243  			return err
   244  		}
   245  		if foundCerr != nil {
   246  			return foundCerr
   247  		}
   248  	}
   249  
   250  	if ci.limits.hasExhaustedLimit() {
   251  		return nil
   252  	}
   253  
   254  	// -1 means that the handler has been completed.
   255  	uci, err := ci.withOutgoingSection("-1")
   256  	if err != nil {
   257  		return err
   258  	}
   259  	return next(uci)
   260  }
   261  
   262  // combineCursors combines the given cursors into one resulting cursor.
   263  func combineCursors(cursor *v1.Cursor, toAdd *v1.Cursor) (*v1.Cursor, error) {
   264  	if toAdd == nil || len(toAdd.Sections) == 0 {
   265  		return nil, spiceerrors.MustBugf("supplied toAdd cursor was nil or empty")
   266  	}
   267  
   268  	if cursor == nil || len(cursor.Sections) == 0 {
   269  		return toAdd, nil
   270  	}
   271  
   272  	sections := make([]string, 0, len(cursor.Sections)+len(toAdd.Sections))
   273  	sections = append(sections, cursor.Sections...)
   274  	sections = append(sections, toAdd.Sections...)
   275  
   276  	return &v1.Cursor{
   277  		DispatchVersion: toAdd.DispatchVersion,
   278  		Sections:        sections,
   279  	}, nil
   280  }
   281  
   282  // withParallelizedStreamingIterableInCursor executes the given handler for each item in the items list, skipping any
   283  // items marked as completed at the head of the cursor and injecting a cursor representing the current
   284  // item.
   285  //
   286  // For example, if items contains 3 items, and the cursor returned was within the handler for item
   287  // index #1, then item index #0 will be skipped on subsequent invocation.
   288  //
   289  // The next index is executed in parallel with the current index, with its results stored in a CollectingStream
   290  // until the next iteration.
   291  func withParallelizedStreamingIterableInCursor[T any, Q any](
   292  	ctx context.Context,
   293  	ci cursorInformation,
   294  	items []T,
   295  	parentStream dispatch.Stream[Q],
   296  	concurrencyLimit uint16,
   297  	handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error,
   298  ) error {
   299  	// Check the cursor for a starting index, before which any items will be skipped.
   300  	startingIndex, err := ci.integerSectionValue()
   301  	if err != nil {
   302  		return err
   303  	}
   304  
   305  	if startingIndex < 0 || startingIndex > len(items) {
   306  		return spiceerrors.MustBugf("invalid cursor in withParallelizedStreamingIterableInCursor: found starting index %d for items %v", startingIndex, items)
   307  	}
   308  
   309  	itemsToRun := items[startingIndex:]
   310  	if len(itemsToRun) == 0 {
   311  		return nil
   312  	}
   313  
   314  	getItemCursor := func(taskIndex int) (cursorInformation, error) {
   315  		// Create an updated cursor referencing the current item's index, so that any items returned know to resume from this point.
   316  		currentCursor, err := ci.withOutgoingSection(strconv.Itoa(taskIndex + startingIndex))
   317  		if err != nil {
   318  			return currentCursor, err
   319  		}
   320  
   321  		// If not the first iteration, we need to clear incoming sections to ensure the iteration starts at the top
   322  		// of the cursor.
   323  		if taskIndex > 0 {
   324  			currentCursor = currentCursor.clearIncoming()
   325  		}
   326  
   327  		return currentCursor, nil
   328  	}
   329  
   330  	return withInternalParallelizedStreamingIterableInCursor(
   331  		ctx,
   332  		ci,
   333  		itemsToRun,
   334  		parentStream,
   335  		concurrencyLimit,
   336  		getItemCursor,
   337  		handler,
   338  	)
   339  }
   340  
   341  func withInternalParallelizedStreamingIterableInCursor[T any, Q any](
   342  	ctx context.Context,
   343  	ci cursorInformation,
   344  	itemsToRun []T,
   345  	parentStream dispatch.Stream[Q],
   346  	concurrencyLimit uint16,
   347  	getItemCursor func(taskIndex int) (cursorInformation, error),
   348  	handler func(ctx context.Context, ci cursorInformation, item T, stream dispatch.Stream[Q]) error,
   349  ) error {
   350  	// Queue up each iteration's worth of items to be run by the task runner.
   351  	tr := taskrunner.NewPreloadedTaskRunner(ctx, concurrencyLimit, len(itemsToRun))
   352  	stream, err := newParallelLimitedIndexedStream(ctx, ci, parentStream, len(itemsToRun))
   353  	if err != nil {
   354  		return err
   355  	}
   356  
   357  	// Schedule a task to be invoked for each item to be run.
   358  	for taskIndex, item := range itemsToRun {
   359  		taskIndex := taskIndex
   360  		item := item
   361  		tr.Add(func(ctx context.Context) error {
   362  			stream.lock.Lock()
   363  			if ci.limits.hasExhaustedLimit() {
   364  				stream.lock.Unlock()
   365  				return nil
   366  			}
   367  			stream.lock.Unlock()
   368  
   369  			ici, err := getItemCursor(taskIndex)
   370  			if err != nil {
   371  				return err
   372  			}
   373  
   374  			// Invoke the handler with the current item's index in the outgoing cursor, indicating that
   375  			// subsequent invocations should jump right to this item.
   376  			ictx, istream, icursor := stream.forTaskIndex(ctx, taskIndex, ici)
   377  
   378  			err = handler(ictx, icursor, item, istream)
   379  			if err != nil {
   380  				// If the branch was canceled explicitly by *this* streaming iterable because other branches have fulfilled
   381  				// the configured limit, then we can safely ignore this error.
   382  				if errors.Is(context.Cause(ictx), stream.errCanceledBecauseFulfilled) {
   383  					return nil
   384  				}
   385  				return err
   386  			}
   387  
   388  			return stream.completedTaskIndex(taskIndex)
   389  		})
   390  	}
   391  
   392  	err = tr.StartAndWait()
   393  	if err != nil {
   394  		return err
   395  	}
   396  	return nil
   397  }
   398  
   399  // parallelLimitedIndexedStream is a specialization of a dispatch.Stream that collects results from multiple
   400  // tasks running in parallel, and emits them in the order of the tasks. The first task's results are directly
   401  // emitted to the parent stream, while subsequent tasks' results are emitted in the defined order of the tasks
   402  // to ensure cursors and limits work as expected.
   403  type parallelLimitedIndexedStream[Q any] struct {
   404  	lock sync.Mutex
   405  
   406  	ctx          context.Context
   407  	ci           cursorInformation
   408  	parentStream dispatch.Stream[Q]
   409  
   410  	streamCount                 int
   411  	toPublishTaskIndex          int
   412  	countingStream              *dispatch.CountingDispatchStream[Q]
   413  	childStreams                map[int]*dispatch.CollectingDispatchStream[Q]
   414  	childContextCancels         map[int]func(cause error)
   415  	completedTaskIndexes        map[int]bool
   416  	errCanceledBecauseFulfilled error
   417  }
   418  
   419  func newParallelLimitedIndexedStream[Q any](
   420  	ctx context.Context,
   421  	ci cursorInformation,
   422  	parentStream dispatch.Stream[Q],
   423  	streamCount int,
   424  ) (*parallelLimitedIndexedStream[Q], error) {
   425  	if streamCount <= 0 {
   426  		return nil, spiceerrors.MustBugf("got invalid stream count")
   427  	}
   428  
   429  	return &parallelLimitedIndexedStream[Q]{
   430  		ctx:                  ctx,
   431  		ci:                   ci,
   432  		parentStream:         parentStream,
   433  		countingStream:       nil,
   434  		childStreams:         map[int]*dispatch.CollectingDispatchStream[Q]{},
   435  		childContextCancels:  map[int]func(cause error){},
   436  		completedTaskIndexes: map[int]bool{},
   437  		toPublishTaskIndex:   0,
   438  		streamCount:          streamCount,
   439  
   440  		// NOTE: we mint a new error here to ensure that we only skip cancelations from this very instance.
   441  		errCanceledBecauseFulfilled: errors.New("canceled because other branches fulfilled limit"),
   442  	}, nil
   443  }
   444  
   445  // forTaskIndex returns a new context, stream and cursor for invoking the task at the specific index and publishing its results.
   446  func (ls *parallelLimitedIndexedStream[Q]) forTaskIndex(ctx context.Context, index int, currentCursor cursorInformation) (context.Context, dispatch.Stream[Q], cursorInformation) {
   447  	ls.lock.Lock()
   448  	defer ls.lock.Unlock()
   449  
   450  	// Create a new cursor with cloned limits, because each child task which executes (in parallel) will need its own
   451  	// limit tracking. The overall limit on the original cursor is managed in completedTaskIndex.
   452  	childCI := currentCursor.withClonedLimits()
   453  	childContext, cancelDispatch := branchContext(ctx)
   454  
   455  	ls.childContextCancels[index] = cancelDispatch
   456  
   457  	// If executing for the first index, it can stream directly to the parent stream, but we need to count the number
   458  	// of items streamed to adjust the overall limits.
   459  	if index == 0 {
   460  		countingStream := dispatch.NewCountingDispatchStream(ls.parentStream)
   461  		ls.countingStream = countingStream
   462  		return childContext, countingStream, childCI
   463  	}
   464  
   465  	// Otherwise, create a child stream with an adjusted limits on the cursor. We have to clone the cursor's
   466  	// limits here to ensure that the child's publishing doesn't affect the first branch.
   467  	childStream := dispatch.NewCollectingDispatchStream[Q](childContext)
   468  	ls.childStreams[index] = childStream
   469  
   470  	return childContext, childStream, childCI
   471  }
   472  
   473  // cancelRemainingDispatches cancels the contexts for each dispatched branch, indicating that no additional results
   474  // are necessary.
   475  func (ls *parallelLimitedIndexedStream[Q]) cancelRemainingDispatches() {
   476  	for _, cancel := range ls.childContextCancels {
   477  		cancel(ls.errCanceledBecauseFulfilled)
   478  	}
   479  }
   480  
   481  // completedTaskIndex indicates the the task at the specific index has completed successfully and that its collected
   482  // results should be published to the parent stream, so long as all previous tasks have been completed and published as well.
   483  func (ls *parallelLimitedIndexedStream[Q]) completedTaskIndex(index int) error {
   484  	ls.lock.Lock()
   485  	defer ls.lock.Unlock()
   486  
   487  	// Mark the task as completed, but not yet published.
   488  	ls.completedTaskIndexes[index] = true
   489  
   490  	// If the overall limit has been reached, nothing more to do.
   491  	if ls.ci.limits.hasExhaustedLimit() {
   492  		ls.cancelRemainingDispatches()
   493  		return nil
   494  	}
   495  
   496  	// Otherwise, publish any results from previous completed tasks up, and including, this task. This loop ensures
   497  	// that the collected results for each task are published to the parent stream in the correct order.
   498  	for {
   499  		if !ls.completedTaskIndexes[ls.toPublishTaskIndex] {
   500  			return nil
   501  		}
   502  
   503  		if ls.toPublishTaskIndex == 0 {
   504  			// Remove the already emitted data from the overall limits.
   505  			if err := ls.ci.limits.markAlreadyPublished(uint32(ls.countingStream.PublishedCount())); err != nil {
   506  				return err
   507  			}
   508  
   509  			if ls.ci.limits.hasExhaustedLimit() {
   510  				ls.cancelRemainingDispatches()
   511  			}
   512  		} else {
   513  			// Publish, to the parent stream, the results produced by the task and stored in the child stream.
   514  			childStream := ls.childStreams[ls.toPublishTaskIndex]
   515  			for _, result := range childStream.Results() {
   516  				if !ls.ci.limits.prepareForPublishing() {
   517  					ls.cancelRemainingDispatches()
   518  					return nil
   519  				}
   520  
   521  				err := ls.parentStream.Publish(result)
   522  				if err != nil {
   523  					return err
   524  				}
   525  			}
   526  			ls.childStreams[ls.toPublishTaskIndex] = nil
   527  		}
   528  
   529  		ls.toPublishTaskIndex++
   530  	}
   531  }