github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/mysql/caveat.go (about) 1 package mysql 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 9 "github.com/authzed/spicedb/internal/datastore/common" 10 "github.com/authzed/spicedb/internal/datastore/revisions" 11 "github.com/authzed/spicedb/pkg/datastore" 12 core "github.com/authzed/spicedb/pkg/proto/core/v1" 13 14 sq "github.com/Masterminds/squirrel" 15 ) 16 17 const ( 18 errDeleteCaveat = "unable to delete caveats: %w" 19 errReadCaveat = "unable to read caveat: %w" 20 errListCaveats = "unable to list caveats: %w" 21 errWriteCaveats = "unable to write caveats: %w" 22 ) 23 24 func (mr *mysqlReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) { 25 filteredReadCaveat := mr.filterer(mr.ReadCaveatQuery) 26 sqlStatement, args, err := filteredReadCaveat.Where(sq.Eq{colName: name}).ToSql() 27 if err != nil { 28 return nil, datastore.NoRevision, err 29 } 30 31 tx, txCleanup, err := mr.txSource(ctx) 32 if err != nil { 33 return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err) 34 } 35 defer common.LogOnError(ctx, txCleanup) 36 37 var serializedDef []byte 38 var txID uint64 39 err = tx.QueryRowContext(ctx, sqlStatement, args...).Scan(&serializedDef, &txID) 40 if err != nil { 41 if errors.Is(err, sql.ErrNoRows) { 42 return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name) 43 } 44 return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err) 45 } 46 def := core.CaveatDefinition{} 47 err = def.UnmarshalVT(serializedDef) 48 if err != nil { 49 return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err) 50 } 51 return &def, revisions.NewForTransactionID(txID), nil 52 } 53 54 func (mr *mysqlReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { 55 if len(caveatNames) == 0 { 56 return nil, nil 57 } 58 return mr.lookupCaveats(ctx, caveatNames) 59 } 60 61 func (mr *mysqlReader) ListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) { 62 return mr.lookupCaveats(ctx, nil) 63 } 64 65 func (mr *mysqlReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) { 66 caveatsWithNames := mr.ListCaveatsQuery 67 if len(caveatNames) > 0 { 68 caveatsWithNames = caveatsWithNames.Where(sq.Eq{colName: caveatNames}) 69 } 70 71 filteredListCaveat := mr.filterer(caveatsWithNames) 72 listSQL, listArgs, err := filteredListCaveat.ToSql() 73 if err != nil { 74 return nil, err 75 } 76 77 tx, txCleanup, err := mr.txSource(ctx) 78 if err != nil { 79 return nil, fmt.Errorf(errListCaveats, err) 80 } 81 defer common.LogOnError(ctx, txCleanup) 82 83 rows, err := tx.QueryContext(ctx, listSQL, listArgs...) 84 if err != nil { 85 return nil, fmt.Errorf(errListCaveats, err) 86 } 87 defer common.LogOnError(ctx, rows.Close) 88 89 var caveats []datastore.RevisionedCaveat 90 for rows.Next() { 91 var defBytes []byte 92 var txID uint64 93 94 err = rows.Scan(&defBytes, &txID) 95 if err != nil { 96 return nil, fmt.Errorf(errListCaveats, err) 97 } 98 c := core.CaveatDefinition{} 99 err = c.UnmarshalVT(defBytes) 100 if err != nil { 101 return nil, fmt.Errorf(errListCaveats, err) 102 } 103 caveats = append(caveats, datastore.RevisionedCaveat{ 104 Definition: &c, 105 LastWrittenRevision: revisions.NewForTransactionID(txID), 106 }) 107 } 108 if rows.Err() != nil { 109 return nil, fmt.Errorf(errListCaveats, rows.Err()) 110 } 111 112 return caveats, nil 113 } 114 115 func (rwt *mysqlReadWriteTXN) WriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error { 116 if len(caveats) == 0 { 117 return nil 118 } 119 writeQuery := rwt.WriteCaveatQuery 120 121 caveatNamesToWrite := make([]string, 0, len(caveats)) 122 for _, newCaveat := range caveats { 123 serialized, err := newCaveat.MarshalVT() 124 if err != nil { 125 return fmt.Errorf("unable to write caveat: %w", err) 126 } 127 128 writeQuery = writeQuery.Values(newCaveat.Name, serialized, rwt.newTxnID) 129 caveatNamesToWrite = append(caveatNamesToWrite, newCaveat.Name) 130 } 131 132 err := rwt.deleteCaveatsFromNames(ctx, caveatNamesToWrite) 133 if err != nil { 134 return fmt.Errorf(errWriteCaveats, err) 135 } 136 137 querySQL, writeArgs, err := writeQuery.ToSql() 138 if err != nil { 139 return fmt.Errorf(errWriteCaveats, err) 140 } 141 142 _, err = rwt.tx.ExecContext(ctx, querySQL, writeArgs...) 143 if err != nil { 144 return fmt.Errorf(errWriteCaveats, err) 145 } 146 147 return nil 148 } 149 150 func (rwt *mysqlReadWriteTXN) DeleteCaveats(ctx context.Context, names []string) error { 151 return rwt.deleteCaveatsFromNames(ctx, names) 152 } 153 154 func (rwt *mysqlReadWriteTXN) deleteCaveatsFromNames(ctx context.Context, names []string) error { 155 delSQL, delArgs, err := rwt.DeleteCaveatQuery. 156 Set(colDeletedTxn, rwt.newTxnID). 157 Where(sq.Eq{colName: names}). 158 ToSql() 159 if err != nil { 160 return fmt.Errorf(errDeleteCaveat, err) 161 } 162 163 _, err = rwt.tx.ExecContext(ctx, delSQL, delArgs...) 164 if err != nil { 165 return fmt.Errorf(errDeleteCaveat, err) 166 } 167 return nil 168 }