github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/memdb/caveat.go (about)

     1  package memdb
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/hashicorp/go-memdb"
     8  
     9  	"github.com/authzed/spicedb/pkg/datastore"
    10  	"github.com/authzed/spicedb/pkg/genutil/mapz"
    11  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    12  )
    13  
    14  const tableCaveats = "caveats"
    15  
    16  type caveat struct {
    17  	name       string
    18  	definition []byte
    19  	revision   datastore.Revision
    20  }
    21  
    22  func (c *caveat) Unwrap() (*core.CaveatDefinition, error) {
    23  	definition := core.CaveatDefinition{}
    24  	err := definition.UnmarshalVT(c.definition)
    25  	return &definition, err
    26  }
    27  
    28  func (r *memdbReader) ReadCaveatByName(_ context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) {
    29  	r.mustLock()
    30  	defer r.Unlock()
    31  
    32  	tx, err := r.txSource()
    33  	if err != nil {
    34  		return nil, datastore.NoRevision, err
    35  	}
    36  	return r.readUnwrappedCaveatByName(tx, name)
    37  }
    38  
    39  func (r *memdbReader) readCaveatByName(tx *memdb.Txn, name string) (*caveat, datastore.Revision, error) {
    40  	found, err := tx.First(tableCaveats, indexID, name)
    41  	if err != nil {
    42  		return nil, datastore.NoRevision, err
    43  	}
    44  	if found == nil {
    45  		return nil, datastore.NoRevision, datastore.NewCaveatNameNotFoundErr(name)
    46  	}
    47  	cvt := found.(*caveat)
    48  	return cvt, cvt.revision, nil
    49  }
    50  
    51  func (r *memdbReader) readUnwrappedCaveatByName(tx *memdb.Txn, name string) (*core.CaveatDefinition, datastore.Revision, error) {
    52  	c, rev, err := r.readCaveatByName(tx, name)
    53  	if err != nil {
    54  		return nil, datastore.NoRevision, err
    55  	}
    56  	unwrapped, err := c.Unwrap()
    57  	if err != nil {
    58  		return nil, datastore.NoRevision, err
    59  	}
    60  	return unwrapped, rev, nil
    61  }
    62  
    63  func (r *memdbReader) ListAllCaveats(_ context.Context) ([]datastore.RevisionedCaveat, error) {
    64  	r.mustLock()
    65  	defer r.Unlock()
    66  
    67  	tx, err := r.txSource()
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  
    72  	var caveats []datastore.RevisionedCaveat
    73  	it, err := tx.LowerBound(tableCaveats, indexID)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	for foundRaw := it.Next(); foundRaw != nil; foundRaw = it.Next() {
    79  		rawCaveat := foundRaw.(*caveat)
    80  		definition, err := rawCaveat.Unwrap()
    81  		if err != nil {
    82  			return nil, err
    83  		}
    84  		caveats = append(caveats, datastore.RevisionedCaveat{
    85  			Definition:          definition,
    86  			LastWrittenRevision: rawCaveat.revision,
    87  		})
    88  	}
    89  
    90  	return caveats, nil
    91  }
    92  
    93  func (r *memdbReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
    94  	allCaveats, err := r.ListAllCaveats(ctx)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	allowedCaveatNames := mapz.NewSet[string]()
   100  	allowedCaveatNames.Extend(caveatNames)
   101  
   102  	toReturn := make([]datastore.RevisionedCaveat, 0, len(caveatNames))
   103  	for _, caveat := range allCaveats {
   104  		if allowedCaveatNames.Has(caveat.Definition.Name) {
   105  			toReturn = append(toReturn, caveat)
   106  		}
   107  	}
   108  	return toReturn, nil
   109  }
   110  
   111  func (rwt *memdbReadWriteTx) WriteCaveats(_ context.Context, caveats []*core.CaveatDefinition) error {
   112  	rwt.mustLock()
   113  	defer rwt.Unlock()
   114  	tx, err := rwt.txSource()
   115  	if err != nil {
   116  		return err
   117  	}
   118  	return rwt.writeCaveat(tx, caveats)
   119  }
   120  
   121  func (rwt *memdbReadWriteTx) writeCaveat(tx *memdb.Txn, caveats []*core.CaveatDefinition) error {
   122  	caveatNames := mapz.NewSet[string]()
   123  	for _, coreCaveat := range caveats {
   124  		if !caveatNames.Add(coreCaveat.Name) {
   125  			return fmt.Errorf("duplicate caveat %s", coreCaveat.Name)
   126  		}
   127  		marshalled, err := coreCaveat.MarshalVT()
   128  		if err != nil {
   129  			return err
   130  		}
   131  		c := caveat{
   132  			name:       coreCaveat.Name,
   133  			definition: marshalled,
   134  			revision:   rwt.newRevision,
   135  		}
   136  		if err := tx.Insert(tableCaveats, &c); err != nil {
   137  			return err
   138  		}
   139  	}
   140  	return nil
   141  }
   142  
   143  func (rwt *memdbReadWriteTx) DeleteCaveats(_ context.Context, names []string) error {
   144  	rwt.mustLock()
   145  	defer rwt.Unlock()
   146  	tx, err := rwt.txSource()
   147  	if err != nil {
   148  		return err
   149  	}
   150  	for _, name := range names {
   151  		if err := tx.Delete(tableCaveats, caveat{name: name}); err != nil {
   152  			return err
   153  		}
   154  	}
   155  	return nil
   156  }