github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/crdb/caveat.go (about)

     1  package crdb
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"time"
     8  
     9  	sq "github.com/Masterminds/squirrel"
    10  	"github.com/jackc/pgx/v5"
    11  
    12  	"github.com/authzed/spicedb/internal/datastore/revisions"
    13  	"github.com/authzed/spicedb/pkg/datastore"
    14  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    15  )
    16  
    17  var (
    18  	upsertCaveatSuffix = fmt.Sprintf(
    19  		"ON CONFLICT (%s) DO UPDATE SET %s = excluded.%s",
    20  		colCaveatName,
    21  		colCaveatDefinition,
    22  		colCaveatDefinition,
    23  	)
    24  	writeCaveat  = psql.Insert(tableCaveat).Columns(colCaveatName, colCaveatDefinition).Suffix(upsertCaveatSuffix)
    25  	readCaveat   = psql.Select(colCaveatDefinition, colTimestamp)
    26  	listCaveat   = psql.Select(colCaveatName, colCaveatDefinition, colTimestamp).From(tableCaveat).OrderBy(colCaveatName)
    27  	deleteCaveat = psql.Delete(tableCaveat)
    28  )
    29  
    30  const (
    31  	errWriteCaveat   = "unable to write new caveat revision: %w"
    32  	errReadCaveat    = "unable to read new caveat `%s`: %w"
    33  	errListCaveats   = "unable to list caveat: %w"
    34  	errDeleteCaveats = "unable to delete caveats: %w"
    35  )
    36  
    37  func (cr *crdbReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) {
    38  	query := cr.fromBuilder(readCaveat, tableCaveat).Where(sq.Eq{colCaveatName: name})
    39  	sql, args, err := query.ToSql()
    40  	if err != nil {
    41  		return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err)
    42  	}
    43  
    44  	var definitionBytes []byte
    45  	var timestamp time.Time
    46  
    47  	err = cr.query.QueryRowFunc(ctx, func(ctx context.Context, row pgx.Row) error {
    48  		return row.Scan(&definitionBytes, &timestamp)
    49  	}, sql, args...)
    50  	if err != nil {
    51  		if errors.Is(err, pgx.ErrNoRows) {
    52  			err = datastore.NewCaveatNameNotFoundErr(name)
    53  		}
    54  		return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err)
    55  	}
    56  
    57  	loaded := &core.CaveatDefinition{}
    58  	if err := loaded.UnmarshalVT(definitionBytes); err != nil {
    59  		return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, name, err)
    60  	}
    61  	cr.addOverlapKey(name)
    62  	return loaded, revisions.NewHLCForTime(timestamp), nil
    63  }
    64  
    65  func (cr *crdbReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
    66  	if len(caveatNames) == 0 {
    67  		return nil, nil
    68  	}
    69  	return cr.lookupCaveats(ctx, caveatNames)
    70  }
    71  
    72  func (cr *crdbReader) ListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) {
    73  	return cr.lookupCaveats(ctx, nil)
    74  }
    75  
    76  type bytesAndTimestamp struct {
    77  	bytes     []byte
    78  	timestamp time.Time
    79  }
    80  
    81  func (cr *crdbReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
    82  	caveatsWithNames := cr.fromBuilder(listCaveat, tableCaveat)
    83  	if len(caveatNames) > 0 {
    84  		caveatsWithNames = caveatsWithNames.Where(sq.Eq{colCaveatName: caveatNames})
    85  	}
    86  
    87  	sql, args, err := caveatsWithNames.ToSql()
    88  	if err != nil {
    89  		return nil, fmt.Errorf(errListCaveats, err)
    90  	}
    91  
    92  	var allDefinitionBytes []bytesAndTimestamp
    93  
    94  	err = cr.query.QueryFunc(ctx, func(ctx context.Context, rows pgx.Rows) error {
    95  		for rows.Next() {
    96  			var defBytes []byte
    97  			var name string
    98  			var timestamp time.Time
    99  			err = rows.Scan(&name, &defBytes, &timestamp)
   100  			if err != nil {
   101  				return fmt.Errorf(errListCaveats, err)
   102  			}
   103  			allDefinitionBytes = append(allDefinitionBytes, bytesAndTimestamp{bytes: defBytes, timestamp: timestamp})
   104  			cr.addOverlapKey(name)
   105  		}
   106  		return nil
   107  	}, sql, args...)
   108  	if err != nil {
   109  		return nil, fmt.Errorf(errListCaveats, err)
   110  	}
   111  
   112  	caveats := make([]datastore.RevisionedCaveat, 0, len(allDefinitionBytes))
   113  	for _, bat := range allDefinitionBytes {
   114  		loaded := &core.CaveatDefinition{}
   115  		if err := loaded.UnmarshalVT(bat.bytes); err != nil {
   116  			return nil, fmt.Errorf(errListCaveats, err)
   117  		}
   118  		caveats = append(caveats, datastore.RevisionedCaveat{
   119  			Definition:          loaded,
   120  			LastWrittenRevision: revisions.NewHLCForTime(bat.timestamp),
   121  		})
   122  	}
   123  
   124  	return caveats, nil
   125  }
   126  
   127  func (rwt *crdbReadWriteTXN) WriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error {
   128  	if len(caveats) == 0 {
   129  		return nil
   130  	}
   131  	write := writeCaveat
   132  	writtenCaveatNames := make([]string, 0, len(caveats))
   133  	for _, caveat := range caveats {
   134  		definitionBytes, err := caveat.MarshalVT()
   135  		if err != nil {
   136  			return fmt.Errorf(errWriteCaveat, err)
   137  		}
   138  		valuesToWrite := []any{caveat.Name, definitionBytes}
   139  		write = write.Values(valuesToWrite...)
   140  		writtenCaveatNames = append(writtenCaveatNames, caveat.Name)
   141  	}
   142  
   143  	// store the new caveat
   144  	sql, args, err := write.ToSql()
   145  	if err != nil {
   146  		return fmt.Errorf(errWriteCaveat, err)
   147  	}
   148  
   149  	for _, val := range writtenCaveatNames {
   150  		rwt.addOverlapKey(val)
   151  	}
   152  	if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil {
   153  		return fmt.Errorf(errWriteCaveat, err)
   154  	}
   155  	return nil
   156  }
   157  
   158  func (rwt *crdbReadWriteTXN) DeleteCaveats(ctx context.Context, names []string) error {
   159  	deleteCaveatClause := deleteCaveat.Where(sq.Eq{colCaveatName: names})
   160  	sql, args, err := deleteCaveatClause.ToSql()
   161  	if err != nil {
   162  		return fmt.Errorf(errDeleteCaveats, err)
   163  	}
   164  	for _, val := range names {
   165  		rwt.addOverlapKey(val)
   166  	}
   167  	if _, err := rwt.tx.Exec(ctx, sql, args...); err != nil {
   168  		return fmt.Errorf(errDeleteCaveats, err)
   169  	}
   170  	return nil
   171  }