github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/spanner/caveat.go (about)

     1  package spanner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"time"
     7  
     8  	"cloud.google.com/go/spanner"
     9  	"google.golang.org/grpc/codes"
    10  
    11  	"github.com/authzed/spicedb/internal/datastore/common"
    12  	"github.com/authzed/spicedb/internal/datastore/revisions"
    13  	"github.com/authzed/spicedb/pkg/datastore"
    14  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    15  )
    16  
    17  func (sr spannerReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) {
    18  	caveatKey := spanner.Key{name}
    19  	row, err := sr.txSource().ReadRow(ctx, tableCaveat, caveatKey, []string{colCaveatDefinition, colCaveatTS})
    20  	if err != nil {
    21  		if spanner.ErrCode(err) == codes.NotFound {
    22  			return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name)
    23  		}
    24  		return nil, datastore.NoRevision, fmt.Errorf(errUnableToReadCaveat, err)
    25  	}
    26  	var serialized []byte
    27  	var updated time.Time
    28  	if err := row.Columns(&serialized, &updated); err != nil {
    29  		return nil, datastore.NoRevision, fmt.Errorf(errUnableToReadCaveat, err)
    30  	}
    31  
    32  	loaded := &core.CaveatDefinition{}
    33  	if err := loaded.UnmarshalVT(serialized); err != nil {
    34  		return nil, datastore.NoRevision, err
    35  	}
    36  	return loaded, revisions.NewForTime(updated), nil
    37  }
    38  
    39  func (sr spannerReader) ListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) {
    40  	return sr.listCaveats(ctx, nil)
    41  }
    42  
    43  func (sr spannerReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
    44  	if len(caveatNames) == 0 {
    45  		return nil, nil
    46  	}
    47  	return sr.listCaveats(ctx, caveatNames)
    48  }
    49  
    50  func (sr spannerReader) listCaveats(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
    51  	keyset := spanner.AllKeys()
    52  	if len(caveatNames) > 0 {
    53  		keys := make([]spanner.Key, 0, len(caveatNames))
    54  		for _, n := range caveatNames {
    55  			keys = append(keys, spanner.Key{n})
    56  		}
    57  		keyset = spanner.KeySetFromKeys(keys...)
    58  	}
    59  	iter := sr.txSource().Read(
    60  		ctx,
    61  		tableCaveat,
    62  		keyset,
    63  		[]string{colCaveatDefinition, colCaveatTS},
    64  	)
    65  	defer iter.Stop()
    66  
    67  	var caveats []datastore.RevisionedCaveat
    68  	if err := iter.Do(func(row *spanner.Row) error {
    69  		var serialized []byte
    70  		var updated time.Time
    71  		if err := row.Columns(&serialized, &updated); err != nil {
    72  			return err
    73  		}
    74  
    75  		loaded := &core.CaveatDefinition{}
    76  		if err := loaded.UnmarshalVT(serialized); err != nil {
    77  			return err
    78  		}
    79  		caveats = append(caveats, datastore.RevisionedCaveat{
    80  			Definition:          loaded,
    81  			LastWrittenRevision: revisions.NewForTime(updated),
    82  		})
    83  
    84  		return nil
    85  	}); err != nil {
    86  		return nil, fmt.Errorf(errUnableToListCaveats, err)
    87  	}
    88  
    89  	return caveats, nil
    90  }
    91  
    92  func (rwt spannerReadWriteTXN) WriteCaveats(_ context.Context, caveats []*core.CaveatDefinition) error {
    93  	names := map[string]struct{}{}
    94  	mutations := make([]*spanner.Mutation, 0, len(caveats))
    95  	for _, caveat := range caveats {
    96  		if _, ok := names[caveat.Name]; ok {
    97  			return fmt.Errorf(errUnableToWriteCaveat, fmt.Errorf("duplicate caveats in input: %s", caveat.Name))
    98  		}
    99  		names[caveat.Name] = struct{}{}
   100  		serialized, err := caveat.MarshalVT()
   101  		if err != nil {
   102  			return fmt.Errorf(errUnableToWriteCaveat, err)
   103  		}
   104  
   105  		mutations = append(mutations, spanner.InsertOrUpdate(
   106  			tableCaveat,
   107  			[]string{colName, colCaveatDefinition, colCaveatTS},
   108  			[]any{caveat.Name, serialized, spanner.CommitTimestamp},
   109  		))
   110  	}
   111  
   112  	return rwt.spannerRWT.BufferWrite(mutations)
   113  }
   114  
   115  func (rwt spannerReadWriteTXN) DeleteCaveats(_ context.Context, names []string) error {
   116  	keys := make([]spanner.Key, 0, len(names))
   117  	for _, n := range names {
   118  		keys = append(keys, spanner.Key{n})
   119  	}
   120  	err := rwt.spannerRWT.BufferWrite([]*spanner.Mutation{
   121  		spanner.Delete(tableCaveat, spanner.KeySetFromKeys(keys...)),
   122  	})
   123  	if err != nil {
   124  		return fmt.Errorf(errUnableToDeleteCaveat, err)
   125  	}
   126  
   127  	return err
   128  }
   129  
   130  func ContextualizedCaveatFrom(name spanner.NullString, context spanner.NullJSON) (*core.ContextualizedCaveat, error) {
   131  	if name.Valid && name.StringVal != "" {
   132  		var cctx map[string]any
   133  		if context.Valid && context.Value != nil {
   134  			cctx = context.Value.(map[string]any)
   135  		}
   136  		return common.ContextualizedCaveatFrom(name.StringVal, cctx)
   137  	}
   138  	return nil, nil
   139  }