github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/memdb/caveat.go (about) 1 package memdb 2 3 import ( 4 "context" 5 "fmt" 6 7 "github.com/hashicorp/go-memdb" 8 9 "github.com/authzed/spicedb/pkg/datastore" 10 "github.com/authzed/spicedb/pkg/genutil/mapz" 11 core "github.com/authzed/spicedb/pkg/proto/core/v1" 12 ) 13 14 const tableCaveats = "caveats" 15 16 type caveat struct { 17 name string 18 definition []byte 19 revision datastore.Revision 20 } 21 22 func (c *caveat) Unwrap() (*core.CaveatDefinition, error) { 23 definition := core.CaveatDefinition{} 24 err := definition.UnmarshalVT(c.definition) 25 return &definition, err 26 } 27 28 func (r *memdbReader) ReadCaveatByName(_ context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { 29 r.mustLock() 30 defer r.Unlock() 31 32 tx, err := r.txSource() 33 if err != nil { 34 return nil, datastore.NoRevision, err 35 } 36 return r.readUnwrappedCaveatByName(tx, name) 37 } 38 39 func (r *memdbReader) readCaveatByName(tx *memdb.Txn, name string) (*caveat, datastore.Revision, error) { 40 found, err := tx.First(tableCaveats, indexID, name) 41 if err != nil { 42 return nil, datastore.NoRevision, err 43 } 44 if found == nil { 45 return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name) 46 } 47 cvt := found.(*caveat) 48 return cvt, cvt.revision, nil 49 } 50 51 func (r *memdbReader) readUnwrappedCaveatByName(tx *memdb.Txn, name string) (*core.CaveatDefinition, datastore.Revision, error) { 52 c, rev, err := r.readCaveatByName(tx, name) 53 if err != nil { 54 return nil, datastore.NoRevision, err 55 } 56 unwrapped, err := c.Unwrap() 57 if err != nil { 58 return nil, datastore.NoRevision, err 59 } 60 return unwrapped, rev, nil 61 } 62 63 func (r *memdbReader) ListAllCaveats(_ context.Context) ([]datastore.RevisionedCaveat, error) { 64 r.mustLock() 65 defer r.Unlock() 66 67 tx, err := r.txSource() 68 if err != nil { 69 return nil, err 70 } 71 72 var caveats []datastore.RevisionedCaveat 73 it, err := tx.LowerBound(tableCaveats, indexID) 74 if err != nil { 75 return nil, err 76 } 77 78 for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() { 79 rawCaveat := foundRaw.(*caveat) 80 definition, err := rawCaveat.Unwrap() 81 if err != nil { 82 return nil, err 83 } 84 caveats = append(caveats, datastore.RevisionedCaveat{ 85 Definition: definition, 86 LastWrittenRevision: rawCaveat.revision, 87 }) 88 } 89 90 return caveats, nil 91 } 92 93 func (r *memdbReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { 94 allCaveats, err := r.ListAllCaveats(ctx) 95 if err != nil { 96 return nil, err 97 } 98 99 allowedCaveatNames := mapz.NewSet[string]() 100 allowedCaveatNames.Extend(caveatNames) 101 102 toReturn := make([]datastore.RevisionedCaveat, 0, len(caveatNames)) 103 for _, caveat := range allCaveats { 104 if allowedCaveatNames.Has(caveat.Definition.Name) { 105 toReturn = append(toReturn, caveat) 106 } 107 } 108 return toReturn, nil 109 } 110 111 func (rwt *memdbReadWriteTx) WriteCaveats(_ context.Context, caveats []*core.CaveatDefinition) error { 112 rwt.mustLock() 113 defer rwt.Unlock() 114 tx, err := rwt.txSource() 115 if err != nil { 116 return err 117 } 118 return rwt.writeCaveat(tx, caveats) 119 } 120 121 func (rwt *memdbReadWriteTx) writeCaveat(tx *memdb.Txn, caveats []*core.CaveatDefinition) error { 122 caveatNames := mapz.NewSet[string]() 123 for _, coreCaveat := range caveats { 124 if !caveatNames.Add(coreCaveat.Name) { 125 return fmt.Errorf("duplicate caveat %s", coreCaveat.Name) 126 } 127 marshalled, err := coreCaveat.MarshalVT() 128 if err != nil { 129 return err 130 } 131 c := caveat{ 132 name: coreCaveat.Name, 133 definition: marshalled, 134 revision: rwt.newRevision, 135 } 136 if err := tx.Insert(tableCaveats, &c); err != nil { 137 return err 138 } 139 } 140 return nil 141 } 142 143 func (rwt *memdbReadWriteTx) DeleteCaveats(_ context.Context, names []string) error { 144 rwt.mustLock() 145 defer rwt.Unlock() 146 tx, err := rwt.txSource() 147 if err != nil { 148 return err 149 } 150 for _, name := range names { 151 if err := tx.Delete(tableCaveats, caveat{name: name}); err != nil { 152 return err 153 } 154 } 155 return nil 156 }