github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/memdb/readonly.go (about)

     1  package memdb
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"runtime"
     7  	"slices"
     8  	"sort"
     9  	"strings"
    10  
    11  	"github.com/hashicorp/go-memdb"
    12  
    13  	"github.com/authzed/spicedb/internal/datastore/common"
    14  	"github.com/authzed/spicedb/pkg/datastore"
    15  	"github.com/authzed/spicedb/pkg/datastore/options"
    16  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    17  	"github.com/authzed/spicedb/pkg/spiceerrors"
    18  )
    19  
// txFactory opens the memdb transaction that a memdbReader read method works
// against. Each read method in this file calls it once, after acquiring the
// reader's lock, and returns its error unchanged if it fails.
type txFactory func() (*memdb.Txn, error)
    21  
// memdbReader implements datastore.Reader on top of a go-memdb database.
//
// Fields:
//   - TryLocker: acquired without blocking by every read method through
//     mustLock, which panics if the lock is already held (concurrent use).
//   - txSource: supplies the memdb transaction each read runs against.
//   - initErr: when non-nil, returned by every read method before any locking
//     or reading takes place.
type memdbReader struct {
	TryLocker
	txSource txFactory
	initErr  error
}
    27  
    28  // QueryRelationships reads relationships starting from the resource side.
    29  func (r *memdbReader) QueryRelationships(
    30  	_ context.Context,
    31  	filter datastore.RelationshipsFilter,
    32  	opts ...options.QueryOptionsOption,
    33  ) (datastore.RelationshipIterator, error) {
    34  	if r.initErr != nil {
    35  		return nil, r.initErr
    36  	}
    37  
    38  	r.mustLock()
    39  	defer r.Unlock()
    40  
    41  	tx, err := r.txSource()
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  
    46  	queryOpts := options.NewQueryOptionsWithOptions(opts...)
    47  
    48  	bestIterator, err := iteratorForFilter(tx, filter)
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	if queryOpts.After != nil && queryOpts.Sort == options.Unsorted {
    54  		return nil, datastore.ErrCursorsWithoutSorting
    55  	}
    56  
    57  	matchingRelationshipsFilterFunc := filterFuncForFilters(
    58  		filter.OptionalResourceType,
    59  		filter.OptionalResourceIds,
    60  		filter.OptionalResourceIDPrefix,
    61  		filter.OptionalResourceRelation,
    62  		filter.OptionalSubjectsSelectors,
    63  		filter.OptionalCaveatName,
    64  		makeCursorFilterFn(queryOpts.After, queryOpts.Sort),
    65  	)
    66  	filteredIterator := memdb.NewFilterIterator(bestIterator, matchingRelationshipsFilterFunc)
    67  
    68  	switch queryOpts.Sort {
    69  	case options.Unsorted:
    70  		fallthrough
    71  
    72  	case options.ByResource:
    73  		iter := newMemdbTupleIterator(filteredIterator, queryOpts.Limit, queryOpts.Sort)
    74  		return iter, nil
    75  
    76  	case options.BySubject:
    77  		return newSubjectSortedIterator(filteredIterator, queryOpts.Limit)
    78  
    79  	default:
    80  		return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.Sort)
    81  	}
    82  }
    83  
    84  func mustHaveBeenClosed(iter *memdbTupleIterator) {
    85  	if !iter.closed {
    86  		panic("Tuple iterator garbage collected before Close() was called")
    87  	}
    88  }
    89  
    90  // ReverseQueryRelationships reads relationships starting from the subject.
    91  func (r *memdbReader) ReverseQueryRelationships(
    92  	_ context.Context,
    93  	subjectsFilter datastore.SubjectsFilter,
    94  	opts ...options.ReverseQueryOptionsOption,
    95  ) (datastore.RelationshipIterator, error) {
    96  	if r.initErr != nil {
    97  		return nil, r.initErr
    98  	}
    99  
   100  	r.mustLock()
   101  	defer r.Unlock()
   102  
   103  	tx, err := r.txSource()
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	queryOpts := options.NewReverseQueryOptionsWithOptions(opts...)
   109  
   110  	iterator, err := tx.Get(
   111  		tableRelationship,
   112  		indexSubjectNamespace,
   113  		subjectsFilter.SubjectType,
   114  	)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	filterObjectType, filterRelation := "", ""
   120  	if queryOpts.ResRelation != nil {
   121  		filterObjectType = queryOpts.ResRelation.Namespace
   122  		filterRelation = queryOpts.ResRelation.Relation
   123  	}
   124  
   125  	matchingRelationshipsFilterFunc := filterFuncForFilters(
   126  		filterObjectType,
   127  		nil,
   128  		"",
   129  		filterRelation,
   130  		[]datastore.SubjectsSelector{subjectsFilter.AsSelector()},
   131  		"",
   132  		makeCursorFilterFn(queryOpts.AfterForReverse, queryOpts.SortForReverse),
   133  	)
   134  	filteredIterator := memdb.NewFilterIterator(iterator, matchingRelationshipsFilterFunc)
   135  
   136  	return newMemdbTupleIterator(filteredIterator, queryOpts.LimitForReverse, queryOpts.SortForReverse), nil
   137  }
   138  
   139  // ReadNamespace reads a namespace definition and version and returns it, and the revision at
   140  // which it was created or last written, if found.
   141  func (r *memdbReader) ReadNamespaceByName(_ context.Context, nsName string) (ns *core.NamespaceDefinition, lastWritten datastore.Revision, err error) {
   142  	if r.initErr != nil {
   143  		return nil, datastore.NoRevision, r.initErr
   144  	}
   145  
   146  	r.mustLock()
   147  	defer r.Unlock()
   148  
   149  	tx, err := r.txSource()
   150  	if err != nil {
   151  		return nil, datastore.NoRevision, err
   152  	}
   153  
   154  	foundRaw, err := tx.First(tableNamespace, indexID, nsName)
   155  	if err != nil {
   156  		return nil, datastore.NoRevision, err
   157  	}
   158  
   159  	if foundRaw == nil {
   160  		return nil, datastore.NoRevision, datastore.NewNamespaceNotFoundErr(nsName)
   161  	}
   162  
   163  	found := foundRaw.(*namespace)
   164  
   165  	loaded := &core.NamespaceDefinition{}
   166  	if err := loaded.UnmarshalVT(found.configBytes); err != nil {
   167  		return nil, datastore.NoRevision, err
   168  	}
   169  
   170  	return loaded, found.updated, nil
   171  }
   172  
   173  // ListNamespaces lists all namespaces defined.
   174  func (r *memdbReader) ListAllNamespaces(_ context.Context) ([]datastore.RevisionedNamespace, error) {
   175  	if r.initErr != nil {
   176  		return nil, r.initErr
   177  	}
   178  
   179  	r.mustLock()
   180  	defer r.Unlock()
   181  
   182  	tx, err := r.txSource()
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  
   187  	var nsDefs []datastore.RevisionedNamespace
   188  
   189  	it, err := tx.LowerBound(tableNamespace, indexID)
   190  	if err != nil {
   191  		return nil, err
   192  	}
   193  
   194  	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
   195  		found := foundRaw.(*namespace)
   196  
   197  		loaded := &core.NamespaceDefinition{}
   198  		if err := loaded.UnmarshalVT(found.configBytes); err != nil {
   199  			return nil, err
   200  		}
   201  
   202  		nsDefs = append(nsDefs, datastore.RevisionedNamespace{
   203  			Definition:          loaded,
   204  			LastWrittenRevision: found.updated,
   205  		})
   206  	}
   207  
   208  	return nsDefs, nil
   209  }
   210  
   211  func (r *memdbReader) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) {
   212  	if r.initErr != nil {
   213  		return nil, r.initErr
   214  	}
   215  
   216  	if len(nsNames) == 0 {
   217  		return nil, nil
   218  	}
   219  
   220  	r.mustLock()
   221  	defer r.Unlock()
   222  
   223  	tx, err := r.txSource()
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  
   228  	it, err := tx.LowerBound(tableNamespace, indexID)
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  
   233  	nsNameMap := make(map[string]struct{}, len(nsNames))
   234  	for _, nsName := range nsNames {
   235  		nsNameMap[nsName] = struct{}{}
   236  	}
   237  
   238  	nsDefs := make([]datastore.RevisionedNamespace, 0, len(nsNames))
   239  
   240  	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
   241  		found := foundRaw.(*namespace)
   242  
   243  		loaded := &core.NamespaceDefinition{}
   244  		if err := loaded.UnmarshalVT(found.configBytes); err != nil {
   245  			return nil, err
   246  		}
   247  
   248  		if _, ok := nsNameMap[loaded.Name]; ok {
   249  			nsDefs = append(nsDefs, datastore.RevisionedNamespace{
   250  				Definition:          loaded,
   251  				LastWrittenRevision: found.updated,
   252  			})
   253  		}
   254  	}
   255  
   256  	return nsDefs, nil
   257  }
   258  
   259  func (r *memdbReader) mustLock() {
   260  	if !r.TryLock() {
   261  		panic("detected concurrent use of ReadWriteTransaction")
   262  	}
   263  }
   264  
   265  func iteratorForFilter(txn *memdb.Txn, filter datastore.RelationshipsFilter) (memdb.ResultIterator, error) {
   266  	// "_prefix" is a specialized index suffix used by github.com/hashicorp/go-memdb to match on
   267  	// a prefix of a string.
   268  	// See: https://github.com/hashicorp/go-memdb/blob/9940d4a14258e3b887bfb4bc6ebc28f65461a01c/txn.go#L531
   269  	index := indexNamespace + "_prefix"
   270  
   271  	var args []any
   272  	if filter.OptionalResourceType != "" {
   273  		args = append(args, filter.OptionalResourceType)
   274  		index = indexNamespace
   275  	} else {
   276  		args = append(args, "")
   277  	}
   278  
   279  	if filter.OptionalResourceType != "" && filter.OptionalResourceRelation != "" {
   280  		args = append(args, filter.OptionalResourceRelation)
   281  		index = indexNamespaceAndRelation
   282  	}
   283  
   284  	if len(args) == 0 {
   285  		return nil, spiceerrors.MustBugf("cannot specify an empty filter")
   286  	}
   287  
   288  	iter, err := txn.Get(tableRelationship, index, args...)
   289  	if err != nil {
   290  		return nil, fmt.Errorf("unable to get iterator for filter: %w", err)
   291  	}
   292  
   293  	return iter, err
   294  }
   295  
   296  func filterFuncForFilters(
   297  	optionalResourceType string,
   298  	optionalResourceIds []string,
   299  	optionalResourceIDPrefix string,
   300  	optionalRelation string,
   301  	optionalSubjectsSelectors []datastore.SubjectsSelector,
   302  	optionalCaveatFilter string,
   303  	cursorFilter func(*relationship) bool,
   304  ) memdb.FilterFunc {
   305  	return func(tupleRaw interface{}) bool {
   306  		tuple := tupleRaw.(*relationship)
   307  
   308  		switch {
   309  		case optionalResourceType != "" && optionalResourceType != tuple.namespace:
   310  			return true
   311  		case len(optionalResourceIds) > 0 && !slices.Contains(optionalResourceIds, tuple.resourceID):
   312  			return true
   313  		case optionalResourceIDPrefix != "" && !strings.HasPrefix(tuple.resourceID, optionalResourceIDPrefix):
   314  			return true
   315  		case optionalRelation != "" && optionalRelation != tuple.relation:
   316  			return true
   317  		case optionalCaveatFilter != "" && (tuple.caveat == nil || tuple.caveat.caveatName != optionalCaveatFilter):
   318  			return true
   319  		}
   320  
   321  		applySubjectSelector := func(selector datastore.SubjectsSelector) bool {
   322  			switch {
   323  			case len(selector.OptionalSubjectType) > 0 && selector.OptionalSubjectType != tuple.subjectNamespace:
   324  				return false
   325  			case len(selector.OptionalSubjectIds) > 0 && !slices.Contains(selector.OptionalSubjectIds, tuple.subjectObjectID):
   326  				return false
   327  			}
   328  
   329  			if selector.RelationFilter.OnlyNonEllipsisRelations {
   330  				return tuple.subjectRelation != datastore.Ellipsis
   331  			}
   332  
   333  			relations := make([]string, 0, 2)
   334  			if selector.RelationFilter.IncludeEllipsisRelation {
   335  				relations = append(relations, datastore.Ellipsis)
   336  			}
   337  
   338  			if selector.RelationFilter.NonEllipsisRelation != "" {
   339  				relations = append(relations, selector.RelationFilter.NonEllipsisRelation)
   340  			}
   341  
   342  			return len(relations) == 0 || slices.Contains(relations, tuple.subjectRelation)
   343  		}
   344  
   345  		if len(optionalSubjectsSelectors) > 0 {
   346  			hasMatchingSelector := false
   347  			for _, selector := range optionalSubjectsSelectors {
   348  				if applySubjectSelector(selector) {
   349  					hasMatchingSelector = true
   350  					break
   351  				}
   352  			}
   353  
   354  			if !hasMatchingSelector {
   355  				return true
   356  			}
   357  		}
   358  
   359  		return cursorFilter(tuple)
   360  	}
   361  }
   362  
   363  func makeCursorFilterFn(after *core.RelationTuple, order options.SortOrder) func(tpl *relationship) bool {
   364  	if after != nil {
   365  		switch order {
   366  		case options.ByResource:
   367  			return func(tpl *relationship) bool {
   368  				return less(tpl.namespace, tpl.resourceID, tpl.relation, after.ResourceAndRelation) ||
   369  					(eq(tpl.namespace, tpl.resourceID, tpl.relation, after.ResourceAndRelation) &&
   370  						(less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) ||
   371  							eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject)))
   372  			}
   373  		case options.BySubject:
   374  			return func(tpl *relationship) bool {
   375  				return less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) ||
   376  					(eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) &&
   377  						(less(tpl.namespace, tpl.resourceID, tpl.relation, after.ResourceAndRelation) ||
   378  							eq(tpl.namespace, tpl.resourceID, tpl.relation, after.ResourceAndRelation)))
   379  			}
   380  		}
   381  	}
   382  	return noopCursorFilter
   383  }
   384  
   385  func newSubjectSortedIterator(it memdb.ResultIterator, limit *uint64) (datastore.RelationshipIterator, error) {
   386  	results := make([]*core.RelationTuple, 0)
   387  
   388  	// Coalesce all of the results into memory
   389  	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
   390  		rt, err := foundRaw.(*relationship).RelationTuple()
   391  		if err != nil {
   392  			return nil, err
   393  		}
   394  
   395  		results = append(results, rt)
   396  	}
   397  
   398  	// Sort them by subject
   399  	sort.Slice(results, func(i, j int) bool {
   400  		lhsRes := results[i].ResourceAndRelation
   401  		lhsSub := results[i].Subject
   402  		rhsRes := results[j].ResourceAndRelation
   403  		rhsSub := results[j].Subject
   404  		return less(lhsSub.Namespace, lhsSub.ObjectId, lhsSub.Relation, rhsSub) ||
   405  			(eq(lhsSub.Namespace, lhsSub.ObjectId, lhsSub.Relation, rhsSub) &&
   406  				(less(lhsRes.Namespace, lhsRes.ObjectId, lhsRes.Relation, rhsRes)))
   407  	})
   408  
   409  	// Limit them if requested
   410  	if limit != nil && uint64(len(results)) > *limit {
   411  		results = results[0:*limit]
   412  	}
   413  
   414  	return common.NewSliceRelationshipIterator(results, options.BySubject), nil
   415  }
   416  
   417  func noopCursorFilter(_ *relationship) bool {
   418  	return false
   419  }
   420  
   421  func less(lhsNamespace, lhsObjectID, lhsRelation string, rhs *core.ObjectAndRelation) bool {
   422  	return lhsNamespace < rhs.Namespace ||
   423  		(lhsNamespace == rhs.Namespace && lhsObjectID < rhs.ObjectId) ||
   424  		(lhsNamespace == rhs.Namespace && lhsObjectID == rhs.ObjectId && lhsRelation < rhs.Relation)
   425  }
   426  
   427  func eq(lhsNamespace, lhsObjectID, lhsRelation string, rhs *core.ObjectAndRelation) bool {
   428  	return lhsNamespace == rhs.Namespace && lhsObjectID == rhs.ObjectId && lhsRelation == rhs.Relation
   429  }
   430  
   431  func newMemdbTupleIterator(it memdb.ResultIterator, limit *uint64, order options.SortOrder) *memdbTupleIterator {
   432  	iter := &memdbTupleIterator{it: it, limit: limit, order: order}
   433  	runtime.SetFinalizer(iter, mustHaveBeenClosed)
   434  	return iter
   435  }
   436  
// memdbTupleIterator adapts a memdb.ResultIterator of *relationship rows into a
// datastore.RelationshipIterator.
//
// Fields:
//   - closed: set by Close; Next then returns nil and Cursor errors.
//   - it: the underlying (already filtered) memdb iterator.
//   - limit: optional cap on the number of relationships Next returns.
//   - count: number of rows Next has taken from it so far (checked against limit).
//   - err: conversion error recorded by Next, or datastore.ErrClosedIterator after Close.
//   - order: sort order the rows are assumed to follow; Unsorted disables Cursor.
//   - last: the most recently returned relationship, handed out by Cursor.
type memdbTupleIterator struct {
	closed bool
	it     memdb.ResultIterator
	limit  *uint64
	count  uint64
	err    error
	order  options.SortOrder
	last   *core.RelationTuple
}
   446  
   447  func (mti *memdbTupleIterator) Next() *core.RelationTuple {
   448  	if mti.closed {
   449  		return nil
   450  	}
   451  
   452  	foundRaw := mti.it.Next()
   453  	if foundRaw == nil {
   454  		return nil
   455  	}
   456  
   457  	if mti.limit != nil && mti.count >= *mti.limit {
   458  		return nil
   459  	}
   460  	mti.count++
   461  
   462  	rt, err := foundRaw.(*relationship).RelationTuple()
   463  	if err != nil {
   464  		mti.err = err
   465  		return nil
   466  	}
   467  
   468  	mti.last = rt
   469  	return rt
   470  }
   471  
   472  func (mti *memdbTupleIterator) Cursor() (options.Cursor, error) {
   473  	switch {
   474  	case mti.closed:
   475  		return nil, datastore.ErrClosedIterator
   476  	case mti.order == options.Unsorted:
   477  		return nil, datastore.ErrCursorsWithoutSorting
   478  	case mti.last == nil:
   479  		return nil, datastore.ErrCursorEmpty
   480  	default:
   481  		return mti.last, nil
   482  	}
   483  }
   484  
   485  func (mti *memdbTupleIterator) Err() error {
   486  	return mti.err
   487  }
   488  
   489  func (mti *memdbTupleIterator) Close() {
   490  	mti.closed = true
   491  	mti.err = datastore.ErrClosedIterator
   492  }
   493  
   494  var _ datastore.Reader = &memdbReader{}
   495  
   496  type TryLocker interface {
   497  	TryLock() bool
   498  	Unlock()
   499  }
   500  
   501  type noopTryLocker struct{}
   502  
   503  func (ntl noopTryLocker) TryLock() bool {
   504  	return true
   505  }
   506  
   507  func (ntl noopTryLocker) Unlock() {}
   508  
   509  var _ TryLocker = noopTryLocker{}