github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/testfixtures/validating.go (about)

     1  package testfixtures
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  
     8  	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
     9  
    10  	"github.com/authzed/spicedb/pkg/datastore"
    11  	"github.com/authzed/spicedb/pkg/datastore/options"
    12  	"github.com/authzed/spicedb/pkg/genutil/mapz"
    13  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    14  	"github.com/authzed/spicedb/pkg/tuple"
    15  )
    16  
    17  type validatingDatastore struct {
    18  	datastore.Datastore
    19  }
    20  
    21  // NewValidatingDatastore creates a proxy which runs validation on all call parameters before
    22  // passing the call onward.
    23  func NewValidatingDatastore(delegate datastore.Datastore) datastore.Datastore {
    24  	return validatingDatastore{Datastore: delegate}
    25  }
    26  
    27  func (vd validatingDatastore) SnapshotReader(revision datastore.Revision) datastore.Reader {
    28  	return validatingSnapshotReader{vd.Datastore.SnapshotReader(revision)}
    29  }
    30  
    31  func (vd validatingDatastore) ReadWriteTx(
    32  	ctx context.Context,
    33  	f datastore.TxUserFunc,
    34  	opts ...options.RWTOptionsOption,
    35  ) (datastore.Revision, error) {
    36  	if f == nil {
    37  		return datastore.NoRevision, fmt.Errorf("nil delegate function")
    38  	}
    39  
    40  	return vd.Datastore.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
    41  		txDelegate := validatingReadWriteTransaction{validatingSnapshotReader{rwt}, rwt}
    42  		return f(ctx, txDelegate)
    43  	}, opts...)
    44  }
    45  
    46  func (vd validatingDatastore) Unwrap() datastore.Datastore {
    47  	return vd.Datastore
    48  }
    49  
    50  type validatingSnapshotReader struct {
    51  	delegate datastore.Reader
    52  }
    53  
    54  func (vsr validatingSnapshotReader) ListAllNamespaces(
    55  	ctx context.Context,
    56  ) ([]datastore.RevisionedNamespace, error) {
    57  	read, err := vsr.delegate.ListAllNamespaces(ctx)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	for _, ns := range read {
    63  		err := ns.Definition.Validate()
    64  		if err != nil {
    65  			return nil, err
    66  		}
    67  	}
    68  
    69  	return read, err
    70  }
    71  
    72  func (vsr validatingSnapshotReader) LookupNamespacesWithNames(
    73  	ctx context.Context,
    74  	nsNames []string,
    75  ) ([]datastore.RevisionedNamespace, error) {
    76  	read, err := vsr.delegate.LookupNamespacesWithNames(ctx, nsNames)
    77  	if err != nil {
    78  		return read, err
    79  	}
    80  
    81  	for _, ns := range read {
    82  		err := ns.Definition.Validate()
    83  		if err != nil {
    84  			return nil, err
    85  		}
    86  	}
    87  
    88  	return read, nil
    89  }
    90  
    91  func (vsr validatingSnapshotReader) QueryRelationships(ctx context.Context,
    92  	filter datastore.RelationshipsFilter,
    93  	opts ...options.QueryOptionsOption,
    94  ) (datastore.RelationshipIterator, error) {
    95  	return vsr.delegate.QueryRelationships(ctx, filter, opts...)
    96  }
    97  
    98  func (vsr validatingSnapshotReader) ReadNamespaceByName(
    99  	ctx context.Context,
   100  	nsName string,
   101  ) (*core.NamespaceDefinition, datastore.Revision, error) {
   102  	read, createdAt, err := vsr.delegate.ReadNamespaceByName(ctx, nsName)
   103  	if err != nil {
   104  		return read, createdAt, err
   105  	}
   106  
   107  	err = read.Validate()
   108  	return read, createdAt, err
   109  }
   110  
   111  func (vsr validatingSnapshotReader) ReverseQueryRelationships(ctx context.Context,
   112  	subjectsFilter datastore.SubjectsFilter,
   113  	opts ...options.ReverseQueryOptionsOption,
   114  ) (datastore.RelationshipIterator, error) {
   115  	queryOpts := options.NewReverseQueryOptionsWithOptions(opts...)
   116  	if queryOpts.ResRelation != nil {
   117  		if queryOpts.ResRelation.Namespace == "" {
   118  			return nil, errors.New("resource relation on reverse query missing namespace")
   119  		}
   120  		if queryOpts.ResRelation.Relation == "" {
   121  			return nil, errors.New("resource relation on reverse query missing relation")
   122  		}
   123  	}
   124  
   125  	return vsr.delegate.ReverseQueryRelationships(ctx, subjectsFilter, opts...)
   126  }
   127  
   128  func (vsr validatingSnapshotReader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) {
   129  	read, createdAt, err := vsr.delegate.ReadCaveatByName(ctx, name)
   130  	if err != nil {
   131  		return read, createdAt, err
   132  	}
   133  
   134  	err = read.Validate()
   135  	return read, createdAt, err
   136  }
   137  
   138  func (vsr validatingSnapshotReader) LookupCaveatsWithNames(ctx context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
   139  	read, err := vsr.delegate.LookupCaveatsWithNames(ctx, caveatNames)
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  
   144  	for _, caveat := range read {
   145  		err := caveat.Definition.Validate()
   146  		if err != nil {
   147  			return nil, err
   148  		}
   149  	}
   150  
   151  	return read, err
   152  }
   153  
   154  func (vsr validatingSnapshotReader) ListAllCaveats(ctx context.Context) ([]datastore.RevisionedCaveat, error) {
   155  	read, err := vsr.delegate.ListAllCaveats(ctx)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  
   160  	for _, caveat := range read {
   161  		err := caveat.Definition.Validate()
   162  		if err != nil {
   163  			return nil, err
   164  		}
   165  	}
   166  
   167  	return read, err
   168  }
   169  
   170  type validatingReadWriteTransaction struct {
   171  	validatingSnapshotReader
   172  	delegate datastore.ReadWriteTransaction
   173  }
   174  
   175  func (vrwt validatingReadWriteTransaction) WriteNamespaces(ctx context.Context, newConfigs ...*core.NamespaceDefinition) error {
   176  	for _, newConfig := range newConfigs {
   177  		if err := newConfig.Validate(); err != nil {
   178  			return err
   179  		}
   180  	}
   181  	return vrwt.delegate.WriteNamespaces(ctx, newConfigs...)
   182  }
   183  
   184  func (vrwt validatingReadWriteTransaction) DeleteNamespaces(ctx context.Context, nsNames ...string) error {
   185  	return vrwt.delegate.DeleteNamespaces(ctx, nsNames...)
   186  }
   187  
   188  func (vrwt validatingReadWriteTransaction) WriteRelationships(ctx context.Context, mutations []*core.RelationTupleUpdate) error {
   189  	if err := validateUpdatesToWrite(mutations...); err != nil {
   190  		return err
   191  	}
   192  
   193  	// Ensure there are no duplicate mutations.
   194  	tupleSet := mapz.NewSet[string]()
   195  	for _, mutation := range mutations {
   196  		if err := mutation.Validate(); err != nil {
   197  			return err
   198  		}
   199  
   200  		if !tupleSet.Add(tuple.StringWithoutCaveat(mutation.Tuple)) {
   201  			return fmt.Errorf("found duplicate update for relationship %s", tuple.StringWithoutCaveat(mutation.Tuple))
   202  		}
   203  	}
   204  
   205  	return vrwt.delegate.WriteRelationships(ctx, mutations)
   206  }
   207  
   208  func (vrwt validatingReadWriteTransaction) DeleteRelationships(ctx context.Context, filter *v1.RelationshipFilter, options ...options.DeleteOptionsOption) (bool, error) {
   209  	if err := filter.Validate(); err != nil {
   210  		return false, err
   211  	}
   212  
   213  	return vrwt.delegate.DeleteRelationships(ctx, filter, options...)
   214  }
   215  
   216  func (vrwt validatingReadWriteTransaction) WriteCaveats(ctx context.Context, caveats []*core.CaveatDefinition) error {
   217  	return vrwt.delegate.WriteCaveats(ctx, caveats)
   218  }
   219  
   220  func (vrwt validatingReadWriteTransaction) DeleteCaveats(ctx context.Context, names []string) error {
   221  	return vrwt.delegate.DeleteCaveats(ctx, names)
   222  }
   223  
   224  func (vrwt validatingReadWriteTransaction) BulkLoad(ctx context.Context, source datastore.BulkWriteRelationshipSource) (uint64, error) {
   225  	return vrwt.delegate.BulkLoad(ctx, source)
   226  }
   227  
   228  // validateUpdatesToWrite performs basic validation on relationship updates going into datastores.
   229  func validateUpdatesToWrite(updates ...*core.RelationTupleUpdate) error {
   230  	for _, update := range updates {
   231  		err := tuple.UpdateToRelationshipUpdate(update).HandwrittenValidate()
   232  		if err != nil {
   233  			return err
   234  		}
   235  		if update.Tuple.Subject.Relation == "" {
   236  			return fmt.Errorf("expected ... instead of an empty relation string relation in %v", update.Tuple)
   237  		}
   238  		if update.Tuple.Subject.ObjectId == tuple.PublicWildcard && update.Tuple.Subject.Relation != tuple.Ellipsis {
   239  			return fmt.Errorf(
   240  				"attempt to write a wildcard relationship (`%s`) with a non-empty relation `%v`. Please report this bug",
   241  				tuple.MustString(update.Tuple),
   242  				update.Tuple.Subject.Relation,
   243  			)
   244  		}
   245  	}
   246  
   247  	return nil
   248  }
   249  
   250  var (
   251  	_ datastore.Datastore            = validatingDatastore{}
   252  	_ datastore.Reader               = validatingSnapshotReader{}
   253  	_ datastore.ReadWriteTransaction = validatingReadWriteTransaction{}
   254  )