github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/mysql/caveat.go (about)

     1  package mysql
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  
     9  	"github.com/authzed/spicedb/internal/datastore/common"
    10  	"github.com/authzed/spicedb/internal/datastore/revisions"
    11  	"github.com/authzed/spicedb/pkg/datastore"
    12  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    13  
    14  	sq "github.com/Masterminds/squirrel"
    15  )
    16  
    17  const (
    18  	errDeleteCaveat = "unable to delete caveats: %w"
    19  	errReadCaveat   = "unable to read caveat: %w"
    20  	errListCaveats  = "unable to list caveats: %w"
    21  	errWriteCaveats = "unable to write caveats: %w"
    22  )
    23  
    24  func (mr *mysqlReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) {
    25  	filteredReadCaveat := mr.filterer(mr.ReadCaveatQuery)
    26  	sqlStatement, args, err := filteredReadCaveat.Where(sq.Eq{colName: name}).ToSql()
    27  	if err != nil {
    28  		return nil, datastore.NoRevision, err
    29  	}
    30  
    31  	tx, txCleanup, err := mr.txSource(ctx)
    32  	if err != nil {
    33  		return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err)
    34  	}
    35  	defer common.LogOnError(ctx, txCleanup)
    36  
    37  	var serializedDef []byte
    38  	var txID uint64
    39  	err = tx.QueryRowContext(ctx, sqlStatement, args...).Scan(&serializedDef, &txID)
    40  	if err != nil {
    41  		if errors.Is(err, sql.ErrNoRows) {
    42  			return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name)
    43  		}
    44  		return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err)
    45  	}
    46  	def := core.CaveatDefinition{}
    47  	err = def.UnmarshalVT(serializedDef)
    48  	if err != nil {
    49  		return nil, datastore.NoRevision, fmt.Errorf(errReadCaveat, err)
    50  	}
    51  	return &def, revisions.NewForTransactionID(txID), nil
    52  }
    53  
    54  func (mr *mysqlReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
    55  	if len(caveatNames) == 0 {
    56  		return nil, nil
    57  	}
    58  	return mr.lookupCaveats(ctx, caveatNames)
    59  }
    60  
    61  func (mr *mysqlReader) ListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) {
    62  	return mr.lookupCaveats(ctx, nil)
    63  }
    64  
    65  func (mr *mysqlReader) lookupCaveats(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
    66  	caveatsWithNames := mr.ListCaveatsQuery
    67  	if len(caveatNames) > 0 {
    68  		caveatsWithNames = caveatsWithNames.Where(sq.Eq{colName: caveatNames})
    69  	}
    70  
    71  	filteredListCaveat := mr.filterer(caveatsWithNames)
    72  	listSQL, listArgs, err := filteredListCaveat.ToSql()
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	tx, txCleanup, err := mr.txSource(ctx)
    78  	if err != nil {
    79  		return nil, fmt.Errorf(errListCaveats, err)
    80  	}
    81  	defer common.LogOnError(ctx, txCleanup)
    82  
    83  	rows, err := tx.QueryContext(ctx, listSQL, listArgs...)
    84  	if err != nil {
    85  		return nil, fmt.Errorf(errListCaveats, err)
    86  	}
    87  	defer common.LogOnError(ctx, rows.Close)
    88  
    89  	var caveats []datastore.RevisionedCaveat
    90  	for rows.Next() {
    91  		var defBytes []byte
    92  		var txID uint64
    93  
    94  		err = rows.Scan(&defBytes, &txID)
    95  		if err != nil {
    96  			return nil, fmt.Errorf(errListCaveats, err)
    97  		}
    98  		c := core.CaveatDefinition{}
    99  		err = c.UnmarshalVT(defBytes)
   100  		if err != nil {
   101  			return nil, fmt.Errorf(errListCaveats, err)
   102  		}
   103  		caveats = append(caveats, datastore.RevisionedCaveat{
   104  			Definition:          &c,
   105  			LastWrittenRevision: revisions.NewForTransactionID(txID),
   106  		})
   107  	}
   108  	if rows.Err() != nil {
   109  		return nil, fmt.Errorf(errListCaveats, rows.Err())
   110  	}
   111  
   112  	return caveats, nil
   113  }
   114  
   115  func (rwt *mysqlReadWriteTXN) WriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error {
   116  	if len(caveats) == 0 {
   117  		return nil
   118  	}
   119  	writeQuery := rwt.WriteCaveatQuery
   120  
   121  	caveatNamesToWrite := make([]string, 0, len(caveats))
   122  	for _, newCaveat := range caveats {
   123  		serialized, err := newCaveat.MarshalVT()
   124  		if err != nil {
   125  			return fmt.Errorf("unable to write caveat: %w", err)
   126  		}
   127  
   128  		writeQuery = writeQuery.Values(newCaveat.Name, serialized, rwt.newTxnID)
   129  		caveatNamesToWrite = append(caveatNamesToWrite, newCaveat.Name)
   130  	}
   131  
   132  	err := rwt.deleteCaveatsFromNames(ctx, caveatNamesToWrite)
   133  	if err != nil {
   134  		return fmt.Errorf(errWriteCaveats, err)
   135  	}
   136  
   137  	querySQL, writeArgs, err := writeQuery.ToSql()
   138  	if err != nil {
   139  		return fmt.Errorf(errWriteCaveats, err)
   140  	}
   141  
   142  	_, err = rwt.tx.ExecContext(ctx, querySQL, writeArgs...)
   143  	if err != nil {
   144  		return fmt.Errorf(errWriteCaveats, err)
   145  	}
   146  
   147  	return nil
   148  }
   149  
   150  func (rwt *mysqlReadWriteTXN) DeleteCaveats(ctx context.Context, names []string) error {
   151  	return rwt.deleteCaveatsFromNames(ctx, names)
   152  }
   153  
   154  func (rwt *mysqlReadWriteTXN) deleteCaveatsFromNames(ctx context.Context, names []string) error {
   155  	delSQL, delArgs, err := rwt.DeleteCaveatQuery.
   156  		Set(colDeletedTxn, rwt.newTxnID).
   157  		Where(sq.Eq{colName: names}).
   158  		ToSql()
   159  	if err != nil {
   160  		return fmt.Errorf(errDeleteCaveat, err)
   161  	}
   162  
   163  	_, err = rwt.tx.ExecContext(ctx, delSQL, delArgs...)
   164  	if err != nil {
   165  		return fmt.Errorf(errDeleteCaveat, err)
   166  	}
   167  	return nil
   168  }