github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/testfixtures/datastore.go (about)

     1  package testfixtures
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/stretchr/testify/require"
     7  	"google.golang.org/protobuf/types/known/structpb"
     8  
     9  	"github.com/authzed/spicedb/internal/datastore/common"
    10  	"github.com/authzed/spicedb/internal/namespace"
    11  	"github.com/authzed/spicedb/pkg/caveats"
    12  	caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
    13  	"github.com/authzed/spicedb/pkg/datastore"
    14  	ns "github.com/authzed/spicedb/pkg/namespace"
    15  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    16  	"github.com/authzed/spicedb/pkg/schemadsl/compiler"
    17  	"github.com/authzed/spicedb/pkg/schemadsl/input"
    18  	"github.com/authzed/spicedb/pkg/tuple"
    19  	"github.com/authzed/spicedb/pkg/typesystem"
    20  )
    21  
    22  var UserNS = ns.Namespace("user")
    23  
    24  var CaveatDef = ns.MustCaveatDefinition(
    25  	caveats.MustEnvForVariables(map[string]caveattypes.VariableType{
    26  		"secret":         caveattypes.StringType,
    27  		"expectedSecret": caveattypes.StringType,
    28  	}),
    29  	"test",
    30  	"secret == expectedSecret",
    31  )
    32  
    33  var DocumentNS = ns.Namespace(
    34  	"document",
    35  	ns.MustRelation("owner",
    36  		nil,
    37  		ns.AllowedRelation("user", "..."),
    38  	),
    39  	ns.MustRelation("editor",
    40  		nil,
    41  		ns.AllowedRelation("user", "..."),
    42  	),
    43  	ns.MustRelation("viewer",
    44  		nil,
    45  		ns.AllowedRelation("user", "..."),
    46  	),
    47  	ns.MustRelation("viewer_and_editor",
    48  		nil,
    49  		ns.AllowedRelation("user", "..."),
    50  	),
    51  	ns.MustRelation("caveated_viewer",
    52  		nil,
    53  		ns.AllowedRelationWithCaveat("user", "...", ns.AllowedCaveat("test")),
    54  	),
    55  	ns.MustRelation("parent", nil, ns.AllowedRelation("folder", "...")),
    56  	ns.MustRelation("edit",
    57  		ns.Union(
    58  			ns.ComputedUserset("owner"),
    59  			ns.ComputedUserset("editor"),
    60  		),
    61  	),
    62  	ns.MustRelation("view",
    63  		ns.Union(
    64  			ns.ComputedUserset("viewer"),
    65  			ns.ComputedUserset("edit"),
    66  			ns.TupleToUserset("parent", "view"),
    67  		),
    68  	),
    69  	ns.MustRelation("view_and_edit",
    70  		ns.Intersection(
    71  			ns.ComputedUserset("viewer_and_editor"),
    72  			ns.ComputedUserset("edit"),
    73  		),
    74  	),
    75  )
    76  
    77  var FolderNS = ns.Namespace(
    78  	"folder",
    79  	ns.MustRelation("owner",
    80  		nil,
    81  		ns.AllowedRelation("user", "..."),
    82  	),
    83  	ns.MustRelation("editor",
    84  		nil,
    85  		ns.AllowedRelation("user", "..."),
    86  	),
    87  	ns.MustRelation("viewer",
    88  		nil,
    89  		ns.AllowedRelation("user", "..."),
    90  		ns.AllowedRelation("folder", "viewer"),
    91  	),
    92  	ns.MustRelation("parent", nil, ns.AllowedRelation("folder", "...")),
    93  	ns.MustRelation("edit",
    94  		ns.Union(
    95  			ns.ComputedUserset("editor"),
    96  			ns.ComputedUserset("owner"),
    97  		),
    98  	),
    99  	ns.MustRelation("view",
   100  		ns.Union(
   101  			ns.ComputedUserset("viewer"),
   102  			ns.ComputedUserset("edit"),
   103  			ns.TupleToUserset("parent", "view"),
   104  		),
   105  	),
   106  )
   107  
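// StandardTuples defines the standard test relationships, in string form. They are written
// as-is by StandardDatastoreWithData, and with the "test" caveat attached by
// StandardDatastoreWithCaveatedData.
// NOTE: some tests index directly into this slice, so if you're adding a new tuple, add it
// at the *end*.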
   111  var StandardTuples = []string{
   112  	"document:companyplan#parent@folder:company#...",
   113  	"document:masterplan#parent@folder:strategy#...",
   114  	"folder:strategy#parent@folder:company#...",
   115  	"folder:company#owner@user:owner#...",
   116  	"folder:company#viewer@user:legal#...",
   117  	"folder:strategy#owner@user:vp_product#...",
   118  	"document:masterplan#owner@user:product_manager#...",
   119  	"document:masterplan#viewer@user:eng_lead#...",
   120  	"document:masterplan#parent@folder:plans#...",
   121  	"folder:plans#viewer@user:chief_financial_officer#...",
   122  	"folder:auditors#viewer@user:auditor#...",
   123  	"folder:company#viewer@folder:auditors#viewer",
   124  	"document:healthplan#parent@folder:plans#...",
   125  	"folder:isolated#viewer@user:villain#...",
   126  	"document:specialplan#viewer_and_editor@user:multiroleguy#...",
   127  	"document:specialplan#editor@user:multiroleguy#...",
   128  	"document:specialplan#viewer_and_editor@user:missingrolegal#...",
   129  	"document:base64YWZzZGZh-ZHNmZHPwn5iK8J+YivC/fmIrwn5iK==#owner@user:base64YWZzZGZh-ZHNmZHPwn5iK8J+YivC/fmIrwn5iK==#...",
   130  	"document:veryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryverylong#owner@user:veryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryv
eryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryveryverylong#...",
   131  	"document:ownerplan#viewer@user:owner#...",
   132  }
   133  
   134  // EmptyDatastore returns an empty datastore for testing.
   135  func EmptyDatastore(ds datastore.Datastore, require *require.Assertions) (datastore.Datastore, datastore.Revision) {
   136  	rev, err := ds.HeadRevision(context.Background())
   137  	require.NoError(err)
   138  	return ds, rev
   139  }
   140  
   141  // StandardDatastoreWithSchema returns a datastore populated with the standard test definitions.
   142  func StandardDatastoreWithSchema(ds datastore.Datastore, require *require.Assertions) (datastore.Datastore, datastore.Revision) {
   143  	validating := NewValidatingDatastore(ds)
   144  	objectDefs := []*core.NamespaceDefinition{UserNS.CloneVT(), FolderNS.CloneVT(), DocumentNS.CloneVT()}
   145  	return validating, writeDefinitions(validating, require, objectDefs, []*core.CaveatDefinition{CaveatDef})
   146  }
   147  
   148  // StandardDatastoreWithData returns a datastore populated with both the standard test definitions
   149  // and relationships.
   150  func StandardDatastoreWithData(ds datastore.Datastore, require *require.Assertions) (datastore.Datastore, datastore.Revision) {
   151  	ds, _ = StandardDatastoreWithSchema(ds, require)
   152  	ctx := context.Background()
   153  
   154  	tuples := make([]*core.RelationTuple, 0, len(StandardTuples))
   155  	for _, tupleStr := range StandardTuples {
   156  		tpl := tuple.Parse(tupleStr)
   157  		require.NotNil(tpl)
   158  		tuples = append(tuples, tpl)
   159  	}
   160  	revision, err := common.WriteTuples(ctx, ds, core.RelationTupleUpdate_CREATE, tuples...)
   161  	require.NoError(err)
   162  
   163  	return ds, revision
   164  }
   165  
   166  // StandardDatastoreWithCaveatedData returns a datastore populated with both the standard test definitions
   167  // and some caveated relationships.
   168  func StandardDatastoreWithCaveatedData(ds datastore.Datastore, require *require.Assertions) (datastore.Datastore, datastore.Revision) {
   169  	ds, _ = StandardDatastoreWithSchema(ds, require)
   170  	ctx := context.Background()
   171  
   172  	_, err := ds.ReadWriteTx(ctx, func(ctx context.Context, tx datastore.ReadWriteTransaction) error {
   173  		return tx.WriteCaveats(ctx, createTestCaveat(require))
   174  	})
   175  	require.NoError(err)
   176  
   177  	caveatedTpls := make([]*core.RelationTuple, 0, len(StandardTuples))
   178  	for _, tupleStr := range StandardTuples {
   179  		tpl := tuple.Parse(tupleStr)
   180  		require.NotNil(tpl)
   181  		tpl.Caveat = &core.ContextualizedCaveat{
   182  			CaveatName: "test",
   183  			Context:    mustProtoStruct(map[string]any{"expectedSecret": "1234"}),
   184  		}
   185  		caveatedTpls = append(caveatedTpls, tpl)
   186  	}
   187  	revision, err := common.WriteTuples(ctx, ds, core.RelationTupleUpdate_CREATE, caveatedTpls...)
   188  	require.NoError(err)
   189  
   190  	return ds, revision
   191  }
   192  
   193  func createTestCaveat(require *require.Assertions) []*core.CaveatDefinition {
   194  	env, err := caveats.EnvForVariables(map[string]caveattypes.VariableType{
   195  		"secret":         caveattypes.StringType,
   196  		"expectedSecret": caveattypes.StringType,
   197  	})
   198  	require.NoError(err)
   199  
   200  	c, err := caveats.CompileCaveatWithName(env, "secret == expectedSecret", "test")
   201  	require.NoError(err)
   202  
   203  	cBytes, err := c.Serialize()
   204  	require.NoError(err)
   205  
   206  	return []*core.CaveatDefinition{{
   207  		Name:                 "test",
   208  		SerializedExpression: cBytes,
   209  		ParameterTypes:       env.EncodedParametersTypes(),
   210  	}}
   211  }
   212  
   213  // DatastoreFromSchemaAndTestRelationships returns a validating datastore wrapping that specified,
   214  // loaded with the given scehma and relationships.
   215  func DatastoreFromSchemaAndTestRelationships(ds datastore.Datastore, schema string, relationships []*core.RelationTuple, require *require.Assertions) (datastore.Datastore, datastore.Revision) {
   216  	ctx := context.Background()
   217  	validating := NewValidatingDatastore(ds)
   218  
   219  	compiled, err := compiler.Compile(compiler.InputSchema{
   220  		Source:       input.Source("schema"),
   221  		SchemaString: schema,
   222  	}, compiler.AllowUnprefixedObjectType())
   223  	require.NoError(err)
   224  
   225  	_ = writeDefinitions(validating, require, compiled.ObjectDefinitions, compiled.CaveatDefinitions)
   226  
   227  	newRevision, err := validating.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
   228  		mutations := make([]*core.RelationTupleUpdate, 0, len(relationships))
   229  		for _, rel := range relationships {
   230  			mutations = append(mutations, tuple.Create(rel.CloneVT()))
   231  		}
   232  		err = rwt.WriteRelationships(ctx, mutations)
   233  		require.NoError(err)
   234  
   235  		return nil
   236  	})
   237  	require.NoError(err)
   238  
   239  	return validating, newRevision
   240  }
   241  
   242  func writeDefinitions(ds datastore.Datastore, require *require.Assertions, objectDefs []*core.NamespaceDefinition, caveatDefs []*core.CaveatDefinition) datastore.Revision {
   243  	ctx := context.Background()
   244  	newRevision, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
   245  		if len(caveatDefs) > 0 {
   246  			err := rwt.WriteCaveats(ctx, caveatDefs)
   247  			require.NoError(err)
   248  		}
   249  
   250  		for _, nsDef := range objectDefs {
   251  			ts, err := typesystem.NewNamespaceTypeSystem(nsDef,
   252  				typesystem.ResolverForDatastoreReader(rwt).WithPredefinedElements(typesystem.PredefinedElements{
   253  					Namespaces: objectDefs,
   254  					Caveats:    caveatDefs,
   255  				}))
   256  			require.NoError(err)
   257  
   258  			vts, err := ts.Validate(ctx)
   259  			require.NoError(err)
   260  
   261  			aerr := namespace.AnnotateNamespace(vts)
   262  			require.NoError(aerr)
   263  
   264  			err = rwt.WriteNamespaces(ctx, nsDef)
   265  			require.NoError(err)
   266  		}
   267  
   268  		return nil
   269  	})
   270  	require.NoError(err)
   271  	return newRevision
   272  }
   273  
   274  // TupleChecker is a helper type which provides an easy way for collecting relationships/tuples from
   275  // an iterator and verify those found.
   276  type TupleChecker struct {
   277  	Require *require.Assertions
   278  	DS      datastore.Datastore
   279  }
   280  
   281  func (tc TupleChecker) ExactRelationshipIterator(ctx context.Context, tpl *core.RelationTuple, rev datastore.Revision) datastore.RelationshipIterator {
   282  	filter := tuple.MustToFilter(tpl)
   283  	dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(filter)
   284  	tc.Require.NoError(err)
   285  
   286  	iter, err := tc.DS.SnapshotReader(rev).QueryRelationships(ctx, dsFilter)
   287  	tc.Require.NoError(err)
   288  	return iter
   289  }
   290  
   291  func (tc TupleChecker) VerifyIteratorCount(iter datastore.RelationshipIterator, count int) {
   292  	foundCount := 0
   293  	for found := iter.Next(); found != nil; found = iter.Next() {
   294  		foundCount++
   295  	}
   296  	tc.Require.NoError(iter.Err())
   297  	tc.Require.Equal(count, foundCount)
   298  }
   299  
   300  func (tc TupleChecker) VerifyIteratorResults(iter datastore.RelationshipIterator, tpls ...*core.RelationTuple) {
   301  	defer iter.Close()
   302  
   303  	toFind := make(map[string]struct{}, 1024)
   304  
   305  	for _, tpl := range tpls {
   306  		toFind[tuple.MustString(tpl)] = struct{}{}
   307  	}
   308  
   309  	for found := iter.Next(); found != nil; found = iter.Next() {
   310  		tc.Require.NoError(iter.Err())
   311  		foundStr := tuple.MustString(found)
   312  		_, ok := toFind[foundStr]
   313  		tc.Require.True(ok, "found unexpected tuple %s in iterator", foundStr)
   314  		delete(toFind, foundStr)
   315  	}
   316  	tc.Require.NoError(iter.Err())
   317  
   318  	tc.Require.Zero(len(toFind), "did not find some expected tuples: %#v", toFind)
   319  }
   320  
   321  func (tc TupleChecker) VerifyOrderedIteratorResults(iter datastore.RelationshipIterator, tpls ...*core.RelationTuple) {
   322  	for _, tpl := range tpls {
   323  		expectedStr := tuple.MustString(tpl)
   324  
   325  		found := iter.Next()
   326  		tc.Require.NotNil(found, "expected %s, but found no additional results", expectedStr)
   327  
   328  		foundStr := tuple.MustString(found)
   329  		tc.Require.Equal(expectedStr, foundStr)
   330  	}
   331  
   332  	pastLast := iter.Next()
   333  	tc.Require.Nil(pastLast)
   334  	tc.Require.Nil(iter.Err())
   335  }
   336  
   337  func (tc TupleChecker) TupleExists(ctx context.Context, tpl *core.RelationTuple, rev datastore.Revision) {
   338  	iter := tc.ExactRelationshipIterator(ctx, tpl, rev)
   339  	tc.VerifyIteratorResults(iter, tpl)
   340  }
   341  
   342  func (tc TupleChecker) NoTupleExists(ctx context.Context, tpl *core.RelationTuple, rev datastore.Revision) {
   343  	iter := tc.ExactRelationshipIterator(ctx, tpl, rev)
   344  	tc.VerifyIteratorResults(iter)
   345  }
   346  
   347  func mustProtoStruct(in map[string]any) *structpb.Struct {
   348  	out, err := structpb.NewStruct(in)
   349  	if err != nil {
   350  		panic(err)
   351  	}
   352  	return out
   353  }