github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/memdb/readonly.go (about) 1 package memdb 2 3 import ( 4 "context" 5 "fmt" 6 "runtime" 7 "slices" 8 "sort" 9 "strings" 10 11 "github.com/hashicorp/go-memdb" 12 13 "github.com/authzed/spicedb/internal/datastore/common" 14 "github.com/authzed/spicedb/pkg/datastore" 15 "github.com/authzed/spicedb/pkg/datastore/options" 16 core "github.com/authzed/spicedb/pkg/proto/core/v1" 17 "github.com/authzed/spicedb/pkg/spiceerrors" 18 ) 19 20 type txFactory func() (*memdb.Txn, error) 21 22 type memdbReader struct { 23 TryLocker 24 txSource txFactory 25 initErr error 26 } 27 28 // QueryRelationships reads relationships starting from the resource side. 29 func (r *memdbReader) QueryRelationships( 30 _ context.Context, 31 filter datastore.RelationshipsFilter, 32 opts ...options.QueryOptionsOption, 33 ) (datastore.RelationshipIterator, error) { 34 if r.initErr != nil { 35 return nil, r.initErr 36 } 37 38 r.mustLock() 39 defer r.Unlock() 40 41 tx, err := r.txSource() 42 if err != nil { 43 return nil, err 44 } 45 46 queryOpts := options.NewQueryOptionsWithOptions(opts...) 47 48 bestIterator, err := iteratorForFilter(tx, filter) 49 if err != nil { 50 return nil, err 51 } 52 53 if queryOpts.After != nil && queryOpts.Sort == options.Unsorted { 54 return nil, datastore.ErrCursorsWithoutSorting 55 } 56 57 matchingRelationshipsFilterFunc := filterFuncForFilters( 58 filter.OptionalResourceType, 59 filter.OptionalResourceIds, 60 filter.OptionalResourceIDPrefix, 61 filter.OptionalResourceRelation, 62 filter.OptionalSubjectsSelectors, 63 filter.OptionalCaveatName, 64 makeCursorFilterFn(queryOpts.After, queryOpts.Sort), 65 ) 66 filteredIterator := memdb.NewFilterIterator(bestIterator, matchingRelationshipsFilterFunc) 67 68 switch queryOpts.Sort { 69 case options.Unsorted: 70 fallthrough 71 72 case options.ByResource: 73 iter := newMemdbTupleIterator(filteredIterator, queryOpts.Limit, queryOpts.Sort) 74 return iter, nil 75 76 case options.BySubject: 77 return newSubjectSortedIterator(filteredIterator, queryOpts.Limit) 78 79 default: 80 return nil, spiceerrors.MustBugf("unsupported sort order: %v", queryOpts.Sort) 81 } 82 } 83 84 func mustHaveBeenClosed(iter *memdbTupleIterator) { 85 if !iter.closed { 86 panic("Tuple iterator garbage collected before Close() was called") 87 } 88 } 89 90 // ReverseQueryRelationships reads relationships starting from the subject. 91 func (r *memdbReader) ReverseQueryRelationships( 92 _ context.Context, 93 subjectsFilter datastore.SubjectsFilter, 94 opts ...options.ReverseQueryOptionsOption, 95 ) (datastore.RelationshipIterator, error) { 96 if r.initErr != nil { 97 return nil, r.initErr 98 } 99 100 r.mustLock() 101 defer r.Unlock() 102 103 tx, err := r.txSource() 104 if err != nil { 105 return nil, err 106 } 107 108 queryOpts := options.NewReverseQueryOptionsWithOptions(opts...) 109 110 iterator, err := tx.Get( 111 tableRelationship, 112 indexSubjectNamespace, 113 subjectsFilter.SubjectType, 114 ) 115 if err != nil { 116 return nil, err 117 } 118 119 filterObjectType, filterRelation := "", "" 120 if queryOpts.ResRelation != nil { 121 filterObjectType = queryOpts.ResRelation.Namespace 122 filterRelation = queryOpts.ResRelation.Relation 123 } 124 125 matchingRelationshipsFilterFunc := filterFuncForFilters( 126 filterObjectType, 127 nil, 128 "", 129 filterRelation, 130 []datastore.SubjectsSelector{subjectsFilter.AsSelector()}, 131 "", 132 makeCursorFilterFn(queryOpts.AfterForReverse, queryOpts.SortForReverse), 133 ) 134 filteredIterator := memdb.NewFilterIterator(iterator, matchingRelationshipsFilterFunc) 135 136 return newMemdbTupleIterator(filteredIterator, queryOpts.LimitForReverse, queryOpts.SortForReverse), nil 137 } 138 139 // ReadNamespace reads a namespace definition and version and returns it, and the revision at 140 // which it was created or last written, if found. 141 func (r *memdbReader) ReadNamespaceByName(_ context.Context, nsName string) (ns *core.NamespaceDefinition, lastWritten datastore.Revision, err error) { 142 if r.initErr != nil { 143 return nil, datastore.NoRevision, r.initErr 144 } 145 146 r.mustLock() 147 defer r.Unlock() 148 149 tx, err := r.txSource() 150 if err != nil { 151 return nil, datastore.NoRevision, err 152 } 153 154 foundRaw, err := tx.First(tableNamespace, indexID, nsName) 155 if err != nil { 156 return nil, datastore.NoRevision, err 157 } 158 159 if foundRaw == nil { 160 return nil, datastore.NoRevision, datastore.NewNamespaceNotFoundErr(nsName) 161 } 162 163 found := foundRaw.(*namespace) 164 165 loaded := &core.NamespaceDefinition{} 166 if err := loaded.UnmarshalVT(found.configBytes); err != nil { 167 return nil, datastore.NoRevision, err 168 } 169 170 return loaded, found.updated, nil 171 } 172 173 // ListNamespaces lists all namespaces defined. 174 func (r *memdbReader) ListAllNamespaces(_ context.Context) ([]datastore.RevisionedNamespace, error) { 175 if r.initErr != nil { 176 return nil, r.initErr 177 } 178 179 r.mustLock() 180 defer r.Unlock() 181 182 tx, err := r.txSource() 183 if err != nil { 184 return nil, err 185 } 186 187 var nsDefs []datastore.RevisionedNamespace 188 189 it, err := tx.LowerBound(tableNamespace, indexID) 190 if err != nil { 191 return nil, err 192 } 193 194 for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { 195 found := foundRaw.(*namespace) 196 197 loaded := &core.NamespaceDefinition{} 198 if err := loaded.UnmarshalVT(found.configBytes); err != nil { 199 return nil, err 200 } 201 202 nsDefs = append(nsDefs, datastore.RevisionedNamespace{ 203 Definition: loaded, 204 LastWrittenRevision: found.updated, 205 }) 206 } 207 208 return nsDefs, nil 209 } 210 211 func (r *memdbReader) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) { 212 if r.initErr != nil { 213 return nil, r.initErr 214 } 215 216 if len(nsNames) == 0 { 217 return nil, nil 218 } 219 220 r.mustLock() 221 defer r.Unlock() 222 223 tx, err := r.txSource() 224 if err != nil { 225 return nil, err 226 } 227 228 it, err := tx.LowerBound(tableNamespace, indexID) 229 if err != nil { 230 return nil, err 231 } 232 233 nsNameMap := make(map[string]struct{}, len(nsNames)) 234 for _, nsName := range nsNames { 235 nsNameMap[nsName] = struct{}{} 236 } 237 238 nsDefs := make([]datastore.RevisionedNamespace, 0, len(nsNames)) 239 240 for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { 241 found := foundRaw.(*namespace) 242 243 loaded := &core.NamespaceDefinition{} 244 if err := loaded.UnmarshalVT(found.configBytes); err != nil { 245 return nil, err 246 } 247 248 if _, ok := nsNameMap[loaded.Name]; ok { 249 nsDefs = append(nsDefs, datastore.RevisionedNamespace{ 250 Definition: loaded, 251 LastWrittenRevision: found.updated, 252 }) 253 } 254 } 255 256 return nsDefs, nil 257 } 258 259 func (r *memdbReader) mustLock() { 260 if !r.TryLock() { 261 panic("detected concurrent use of ReadWriteTransaction") 262 } 263 } 264 265 func iteratorForFilter(txn *memdb.Txn, filter datastore.RelationshipsFilter) (memdb.ResultIterator, error) { 266 // "_prefix" is a specialized index suffix used by github.com/hashicorp/go-memdb to match on 267 // a prefix of a string. 268 // See: https://github.com/hashicorp/go-memdb/blob/9940d4a14258e3b887bfb4bc6ebc28f65461a01c/txn.go#L531 269 index := indexNamespace + "_prefix" 270 271 var args []any 272 if filter.OptionalResourceType != "" { 273 args = append(args, filter.OptionalResourceType) 274 index = indexNamespace 275 } else { 276 args = append(args, "") 277 } 278 279 if filter.OptionalResourceType != "" && filter.OptionalResourceRelation != "" { 280 args = append(args, filter.OptionalResourceRelation) 281 index = indexNamespaceAndRelation 282 } 283 284 if len(args) == 0 { 285 return nil, spiceerrors.MustBugf("cannot specify an empty filter") 286 } 287 288 iter, err := txn.Get(tableRelationship, index, args...) 289 if err != nil { 290 return nil, fmt.Errorf("unable to get iterator for filter: %w", err) 291 } 292 293 return iter, err 294 } 295 296 func filterFuncForFilters( 297 optionalResourceType string, 298 optionalResourceIds []string, 299 optionalResourceIDPrefix string, 300 optionalRelation string, 301 optionalSubjectsSelectors []datastore.SubjectsSelector, 302 optionalCaveatFilter string, 303 cursorFilter func(*relationship) bool, 304 ) memdb.FilterFunc { 305 return func(tupleRaw interface{}) bool { 306 tuple := tupleRaw.(*relationship) 307 308 switch { 309 case optionalResourceType != "" && optionalResourceType != tuple.namespace: 310 return true 311 case len(optionalResourceIds) > 0 && !slices.Contains(optionalResourceIds, tuple.resourceID): 312 return true 313 case optionalResourceIDPrefix != "" && !strings.HasPrefix(tuple.resourceID, optionalResourceIDPrefix): 314 return true 315 case optionalRelation != "" && optionalRelation != tuple.relation: 316 return true 317 case optionalCaveatFilter != "" && (tuple.caveat == nil || tuple.caveat.caveatName != optionalCaveatFilter): 318 return true 319 } 320 321 applySubjectSelector := func(selector datastore.SubjectsSelector) bool { 322 switch { 323 case len(selector.OptionalSubjectType) > 0 && selector.OptionalSubjectType != tuple.subjectNamespace: 324 return false 325 case len(selector.OptionalSubjectIds) > 0 && !slices.Contains(selector.OptionalSubjectIds, tuple.subjectObjectID): 326 return false 327 } 328 329 if selector.RelationFilter.OnlyNonEllipsisRelations { 330 return tuple.subjectRelation != datastore.Ellipsis 331 } 332 333 relations := make([]string, 0, 2) 334 if selector.RelationFilter.IncludeEllipsisRelation { 335 relations = append(relations, datastore.Ellipsis) 336 } 337 338 if selector.RelationFilter.NonEllipsisRelation != "" { 339 relations = append(relations, selector.RelationFilter.NonEllipsisRelation) 340 } 341 342 return len(relations) == 0 || slices.Contains(relations, tuple.subjectRelation) 343 } 344 345 if len(optionalSubjectsSelectors) > 0 { 346 hasMatchingSelector := false 347 for _, selector := range optionalSubjectsSelectors { 348 if applySubjectSelector(selector) { 349 hasMatchingSelector = true 350 break 351 } 352 } 353 354 if !hasMatchingSelector { 355 return true 356 } 357 } 358 359 return cursorFilter(tuple) 360 } 361 } 362 363 func makeCursorFilterFn(after *core.RelationTuple, order options.SortOrder) func(tpl *relationship) bool { 364 if after != nil { 365 switch order { 366 case options.ByResource: 367 return func(tpl *relationship) bool { 368 return less(tpl.namespace, tpl.resourceID, tpl.relation, after.ResourceAndRelation) || 369 (eq(tpl.namespace, tpl.resourceID, tpl.relation, after.ResourceAndRelation) && 370 (less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) || 371 eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject))) 372 } 373 case options.BySubject: 374 return func(tpl *relationship) bool { 375 return less(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) || 376 (eq(tpl.subjectNamespace, tpl.subjectObjectID, tpl.subjectRelation, after.Subject) && 377 (less(tpl.namespace, tpl.resourceID, tpl.relation, after.ResourceAndRelation) || 378 eq(tpl.namespace, tpl.resourceID, tpl.relation, after.ResourceAndRelation))) 379 } 380 } 381 } 382 return noopCursorFilter 383 } 384 385 func newSubjectSortedIterator(it memdb.ResultIterator, limit *uint64) (datastore.RelationshipIterator, error) { 386 results := make([]*core.RelationTuple, 0) 387 388 // Coalesce all of the results into memory 389 for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { 390 rt, err := foundRaw.(*relationship).RelationTuple() 391 if err != nil { 392 return nil, err 393 } 394 395 results = append(results, rt) 396 } 397 398 // Sort them by subject 399 sort.Slice(results, func(i, j int) bool { 400 lhsRes := results[i].ResourceAndRelation 401 lhsSub := results[i].Subject 402 rhsRes := results[j].ResourceAndRelation 403 rhsSub := results[j].Subject 404 return less(lhsSub.Namespace, lhsSub.ObjectId, lhsSub.Relation, rhsSub) || 405 (eq(lhsSub.Namespace, lhsSub.ObjectId, lhsSub.Relation, rhsSub) && 406 (less(lhsRes.Namespace, lhsRes.ObjectId, lhsRes.Relation, rhsRes))) 407 }) 408 409 // Limit them if requested 410 if limit != nil && uint64(len(results)) > *limit { 411 results = results[0:*limit] 412 } 413 414 return common.NewSliceRelationshipIterator(results, options.BySubject), nil 415 } 416 417 func noopCursorFilter(_ *relationship) bool { 418 return false 419 } 420 421 func less(lhsNamespace, lhsObjectID, lhsRelation string, rhs *core.ObjectAndRelation) bool { 422 return lhsNamespace < rhs.Namespace || 423 (lhsNamespace == rhs.Namespace && lhsObjectID < rhs.ObjectId) || 424 (lhsNamespace == rhs.Namespace && lhsObjectID == rhs.ObjectId && lhsRelation < rhs.Relation) 425 } 426 427 func eq(lhsNamespace, lhsObjectID, lhsRelation string, rhs *core.ObjectAndRelation) bool { 428 return lhsNamespace == rhs.Namespace && lhsObjectID == rhs.ObjectId && lhsRelation == rhs.Relation 429 } 430 431 func newMemdbTupleIterator(it memdb.ResultIterator, limit *uint64, order options.SortOrder) *memdbTupleIterator { 432 iter := &memdbTupleIterator{it: it, limit: limit, order: order} 433 runtime.SetFinalizer(iter, mustHaveBeenClosed) 434 return iter 435 } 436 437 type memdbTupleIterator struct { 438 closed bool 439 it memdb.ResultIterator 440 limit *uint64 441 count uint64 442 err error 443 order options.SortOrder 444 last *core.RelationTuple 445 } 446 447 func (mti *memdbTupleIterator) Next() *core.RelationTuple { 448 if mti.closed { 449 return nil 450 } 451 452 foundRaw := mti.it.Next() 453 if foundRaw == nil { 454 return nil 455 } 456 457 if mti.limit != nil && mti.count >= *mti.limit { 458 return nil 459 } 460 mti.count++ 461 462 rt, err := foundRaw.(*relationship).RelationTuple() 463 if err != nil { 464 mti.err = err 465 return nil 466 } 467 468 mti.last = rt 469 return rt 470 } 471 472 func (mti *memdbTupleIterator) Cursor() (options.Cursor, error) { 473 switch { 474 case mti.closed: 475 return nil, datastore.ErrClosedIterator 476 case mti.order == options.Unsorted: 477 return nil, datastore.ErrCursorsWithoutSorting 478 case mti.last == nil: 479 return nil, datastore.ErrCursorEmpty 480 default: 481 return mti.last, nil 482 } 483 } 484 485 func (mti *memdbTupleIterator) Err() error { 486 return mti.err 487 } 488 489 func (mti *memdbTupleIterator) Close() { 490 mti.closed = true 491 mti.err = datastore.ErrClosedIterator 492 } 493 494 var _ datastore.Reader = &memdbReader{} 495 496 type TryLocker interface { 497 TryLock() bool 498 Unlock() 499 } 500 501 type noopTryLocker struct{} 502 503 func (ntl noopTryLocker) TryLock() bool { 504 return true 505 } 506 507 func (ntl noopTryLocker) Unlock() {} 508 509 var _ TryLocker = noopTryLocker{}