github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/crdb/caveat.go (about) 1 package crdb 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "time" 8 9 sq "github.com/Masterminds/squirrel" 10 "github.com/jackc/pgx/v5" 11 12 "github.com/authzed/spicedb/internal/datastore/revisions" 13 "github.com/authzed/spicedb/pkg/datastore" 14 core "github.com/authzed/spicedb/pkg/proto/core/v1" 15 ) 16 17 var ( 18 upsertCaveatSuffix = fmt.Sprintf( 19 "ON CONFLICT (%s) DO UPDATE SET %s = excluded.%s", 20 colCaveatName, 21 colCaveatDefinition, 22 colCaveatDefinition, 23 ) 24 writeCaveat = psql.Insert(tableCaveat).Columns(colCaveatName, colCaveatDefinition).Suffix(upsertCaveatSuffix) 25 readCaveat = psql.Select(colCaveatDefinition, colTimestamp) 26 listCaveat = psql.Select(colCaveatName, colCaveatDefinition, colTimestamp).From(tableCaveat).OrderBy(colCaveatName) 27 deleteCaveat = psql.Delete(tableCaveat) 28 ) 29 30 const ( 31 errWriteCaveat = "unable to write new caveat revision: %w" 32 errReadCaveat = "unable to read new caveat `%s`: %w" 33 errListCaveats = "unable to list caveat: %w" 34 errDeleteCaveats = "unable to delete caveats: %w" 35 ) 36 37 func (cr *crdbReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { 38 query := cr.fromBuilder(readCaveat, tableCaveat).Where(sq.Eq{colCaveatName: name}) 39 sql, args, err := query.ToSql() 40 if err != nil { 41 return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err) 42 } 43 44 var definitionBytes []byte 45 var timestamp time.Time 46 47 err = cr.query.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error { 48 return row.Scan(&definitionBytes, ×tamp) 49 }, sql, args...) 50 if err != nil { 51 if errors.Is(err, pgx.ErrNoRows) { 52 err = datastore.NewCaveatNameNotFoundErr(name) 53 } 54 return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err) 55 } 56 57 loaded := &core.CaveatDefinition{} 58 if err := loaded.UnmarshalVT(definitionBytes); err != nil { 59 return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err) 60 } 61 cr.addOverlapKey(name) 62 return loaded, revisions.NewHLCForTime(timestamp), nil 63 } 64 65 func (cr *crdbReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { 66 if len(caveatNames) == 0 { 67 return nil, nil 68 } 69 return cr.lookupCaveats(ctx, caveatNames) 70 } 71 72 func (cr *crdbReader) ListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { 73 return cr.lookupCaveats(ctx, nil) 74 } 75 76 type bytesAndTimestamp struct { 77 bytes []byte 78 timestamp time.Time 79 } 80 81 func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { 82 caveatsWithNames := cr.fromBuilder(listCaveat, tableCaveat) 83 if len(caveatNames) > 0 { 84 caveatsWithNames = caveatsWithNames.Where(sq.Eq{colCaveatName: caveatNames}) 85 } 86 87 sql, args, err := caveatsWithNames.ToSql() 88 if err != nil { 89 return nil, fmt.Errorf(errListCaveats, err) 90 } 91 92 var allDefinitionBytes []bytesAndTimestamp 93 94 err = cr.query.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error { 95 for rows.Next() { 96 var defBytes []byte 97 var name string 98 var timestamp time.Time 99 err = rows.Scan(&name, &defBytes, ×tamp) 100 if err != nil { 101 return fmt.Errorf(errListCaveats, err) 102 } 103 allDefinitionBytes = append(allDefinitionBytes, bytesAndTimestamp{bytes: defBytes, timestamp: timestamp}) 104 cr.addOverlapKey(name) 105 } 106 return nil 107 }, sql, args...) 108 if err != nil { 109 return nil, fmt.Errorf(errListCaveats, err) 110 } 111 112 caveats := make([]datastore.RevisionedCaveat, 0, len(allDefinitionBytes)) 113 for _, bat := range allDefinitionBytes { 114 loaded := &core.CaveatDefinition{} 115 if err := loaded.UnmarshalVT(bat.bytes); err != nil { 116 return nil, fmt.Errorf(errListCaveats, err) 117 } 118 caveats = append(caveats, datastore.RevisionedCaveat{ 119 Definition: loaded, 120 LastWrittenRevision: revisions.NewHLCForTime(bat.timestamp), 121 }) 122 } 123 124 return caveats, nil 125 } 126 127 func (rwt *crdbReadWriteTXN) WriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error { 128 if len(caveats) == 0 { 129 return nil 130 } 131 write := writeCaveat 132 writtenCaveatNames := make([]string, 0, len(caveats)) 133 for _, caveat := range caveats { 134 definitionBytes, err := caveat.MarshalVT() 135 if err != nil { 136 return fmt.Errorf(errWriteCaveat, err) 137 } 138 valuesToWrite := []any{caveat.Name, definitionBytes} 139 write = write.Values(valuesToWrite...) 140 writtenCaveatNames = append(writtenCaveatNames, caveat.Name) 141 } 142 143 // store the new caveat 144 sql, args, err := write.ToSql() 145 if err != nil { 146 return fmt.Errorf(errWriteCaveat, err) 147 } 148 149 for _, val := range writtenCaveatNames { 150 rwt.addOverlapKey(val) 151 } 152 if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil { 153 return fmt.Errorf(errWriteCaveat, err) 154 } 155 return nil 156 } 157 158 func (rwt *crdbReadWriteTXN) DeleteCaveats(ctx context.Context, names []string) error { 159 deleteCaveatClause := deleteCaveat.Where(sq.Eq{colCaveatName: names}) 160 sql, args, err := deleteCaveatClause.ToSql() 161 if err != nil { 162 return fmt.Errorf(errDeleteCaveats, err) 163 } 164 for _, val := range names { 165 rwt.addOverlapKey(val) 166 } 167 if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil { 168 return fmt.Errorf(errDeleteCaveats, err) 169 } 170 return nil 171 }