github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/postgres/caveat.go (about)

     1  package postgres
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  
     8  	"github.com/jackc/pgx/v5"
     9  
    10  	"github.com/authzed/spicedb/pkg/datastore"
    11  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    12  
    13  	sq "github.com/Masterminds/squirrel"
    14  )
    15  
    16  var (
    17  	writeCaveat = psql.Insert(tableCaveat).Columns(colCaveatName, colCaveatDefinition)
    18  	listCaveat  = psql.
    19  			Select(colCaveatDefinition, colCreatedXid).
    20  			From(tableCaveat).
    21  			OrderBy(colCaveatName)
    22  	readCaveat = psql.
    23  			Select(colCaveatDefinition, colCreatedXid).
    24  			From(tableCaveat)
    25  	deleteCaveat = psql.Update(tableCaveat).Where(sq.Eq{colDeletedXid: liveDeletedTxnID})
    26  )
    27  
    // Format strings used to wrap errors from the caveat operations in this
// file; each expects exactly one %w operand.
const (
	errWriteCaveats  = "unable to write caveats: %w"
	errDeleteCaveats = "unable to delete caveats: %w"
	errListCaveats   = "unable to list caveats: %w"
	errReadCaveat    = "unable to read caveat: %w"
)
    34  
    35  func (r *pgReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) {
    36  	filteredReadCaveat := r.filterer(readCaveat)
    37  	sql, args, err := filteredReadCaveat.Where(sq.Eq{colCaveatName: name}).ToSql()
    38  	if err != nil {
    39  		return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err)
    40  	}
    41  
    42  	var txID xid8
    43  	var serializedDef []byte
    44  	err = r.query.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
    45  		return row.Scan(&serializedDef, &txID)
    46  	}, sql, args...)
    47  	if err != nil {
    48  		if errors.Is(err, pgx.ErrNoRows) {
    49  			return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name)
    50  		}
    51  		return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err)
    52  	}
    53  	def := core.CaveatDefinition{}
    54  	err = def.UnmarshalVT(serializedDef)
    55  	if err != nil {
    56  		return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err)
    57  	}
    58  
    59  	rev := revisionForVersion(txID)
    60  
    61  	return &def, rev, nil
    62  }
    63  
    64  func (r *pgReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
    65  	if len(caveatNames) == 0 {
    66  		return nil, nil
    67  	}
    68  	return r.lookupCaveats(ctx, caveatNames)
    69  }
    70  
    71  func (r *pgReader) ListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) {
    72  	return r.lookupCaveats(ctx, nil)
    73  }
    74  
    75  func (r *pgReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
    76  	caveatsWithNames := listCaveat
    77  	if len(caveatNames) > 0 {
    78  		caveatsWithNames = caveatsWithNames.Where(sq.Eq{colCaveatName: caveatNames})
    79  	}
    80  
    81  	filteredListCaveat := r.filterer(caveatsWithNames)
    82  	sql, args, err := filteredListCaveat.ToSql()
    83  	if err != nil {
    84  		return nil, fmt.Errorf(errListCaveats, err)
    85  	}
    86  
    87  	var caveats []datastore.RevisionedCaveat
    88  	err = r.query.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error {
    89  		for rows.Next() {
    90  			var version xid8
    91  			var defBytes []byte
    92  			err = rows.Scan(&defBytes, &version)
    93  			if err != nil {
    94  				return fmt.Errorf(errListCaveats, err)
    95  			}
    96  			c := core.CaveatDefinition{}
    97  			err = c.UnmarshalVT(defBytes)
    98  			if err != nil {
    99  				return fmt.Errorf(errListCaveats, err)
   100  			}
   101  
   102  			revision := revisionForVersion(version)
   103  			caveats = append(caveats, datastore.RevisionedCaveat{Definition: &c, LastWrittenRevision: revision})
   104  		}
   105  		return rows.Err()
   106  	}, sql, args...)
   107  	if err != nil {
   108  		return nil, fmt.Errorf(errListCaveats, err)
   109  	}
   110  
   111  	return caveats, nil
   112  }
   113  
   114  func (rwt *pgReadWriteTXN) WriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error {
   115  	if len(caveats) == 0 {
   116  		return nil
   117  	}
   118  	write := writeCaveat
   119  	writtenCaveatNames := make([]string, 0, len(caveats))
   120  	for _, caveat := range caveats {
   121  		definitionBytes, err := caveat.MarshalVT()
   122  		if err != nil {
   123  			return fmt.Errorf(errWriteCaveats, err)
   124  		}
   125  		valuesToWrite := []any{caveat.Name, definitionBytes}
   126  		write = write.Values(valuesToWrite...)
   127  		writtenCaveatNames = append(writtenCaveatNames, caveat.Name)
   128  	}
   129  
   130  	// mark current caveats as deleted
   131  	err := rwt.deleteCaveatsFromNames(ctx, writtenCaveatNames)
   132  	if err != nil {
   133  		return fmt.Errorf(errWriteCaveats, err)
   134  	}
   135  
   136  	// store the new caveat revision
   137  	sql, args, err := write.ToSql()
   138  	if err != nil {
   139  		return fmt.Errorf(errWriteCaveats, err)
   140  	}
   141  	if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil {
   142  		return fmt.Errorf(errWriteCaveats, err)
   143  	}
   144  	return nil
   145  }
   146  
   147  func (rwt *pgReadWriteTXN) DeleteCaveats(ctx context.Context, names []string) error {
   148  	// mark current caveats as deleted
   149  	return rwt.deleteCaveatsFromNames(ctx, names)
   150  }
   151  
   152  func (rwt *pgReadWriteTXN) deleteCaveatsFromNames(ctx context.Context, names []string) error {
   153  	sql, args, err := deleteCaveat.
   154  		Set(colDeletedXid, rwt.newXID).
   155  		Where(sq.Eq{colCaveatName: names}).
   156  		ToSql()
   157  	if err != nil {
   158  		return fmt.Errorf(errDeleteCaveats, err)
   159  	}
   160  
   161  	if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil {
   162  		return fmt.Errorf(errDeleteCaveats, err)
   163  	}
   164  	return nil
   165  }