github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/proxy/schemacaching/standardcaching_test.go (about)

     1  package schemacaching
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/require"
    12  	"golang.org/x/sync/errgroup"
    13  
    14  	"github.com/authzed/spicedb/internal/datastore/memdb"
    15  	"github.com/authzed/spicedb/internal/datastore/proxy/proxy_test"
    16  	"github.com/authzed/spicedb/internal/datastore/revisions"
    17  	"github.com/authzed/spicedb/pkg/caveats"
    18  	caveattypes "github.com/authzed/spicedb/pkg/caveats/types"
    19  	"github.com/authzed/spicedb/pkg/datastore"
    20  	"github.com/authzed/spicedb/pkg/datastore/options"
    21  	"github.com/authzed/spicedb/pkg/genutil/mapz"
    22  	ns "github.com/authzed/spicedb/pkg/namespace"
    23  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    24  	"github.com/authzed/spicedb/pkg/testutil"
    25  )
    26  
    27  var (
    28  	old  = revisions.NewForTransactionID(0)
    29  	zero = revisions.NewForTransactionID(1)
    30  	one  = revisions.NewForTransactionID(2)
    31  	two  = revisions.NewForTransactionID(3)
    32  
    33  	nilOpts []options.RWTOptionsOption
    34  )
    35  
    36  const (
    37  	nsA = "namespace_a"
    38  	nsB = "namespace_b"
    39  )
    40  
    41  // TestNilUnmarshal asserts that if we get a nil NamespaceDefinition from a
    42  // datastore implementation, the process of inserting it into the cache and
    43  // back does not break anything.
    44  func TestNilUnmarshal(t *testing.T) {
    45  	nsDef := (*core.NamespaceDefinition)(nil)
    46  	marshalled, err := nsDef.MarshalVT()
    47  	require.Nil(t, err)
    48  
    49  	var newDef *core.NamespaceDefinition
    50  	err = nsDef.UnmarshalVT(marshalled)
    51  	require.Nil(t, err)
    52  	require.Equal(t, nsDef, newDef)
    53  }
    54  
    55  type testerDef struct {
    56  	name                   string
    57  	readSingleFunctionName string
    58  	readSingleFunc         func(ctx context.Context, reader datastore.Reader, name string) (datastore.SchemaDefinition, datastore.Revision, error)
    59  
    60  	lookupFunctionName string
    61  	lookupFunc         func(ctx context.Context, reader datastore.Reader, names []string) ([]datastore.SchemaDefinition, error)
    62  
    63  	notFoundErr error
    64  
    65  	writeFunctionName string
    66  	writeFunc         func(rwt datastore.ReadWriteTransaction, def datastore.SchemaDefinition) error
    67  
    68  	createDef      func(name string) datastore.SchemaDefinition
    69  	wrap           func(def datastore.SchemaDefinition) any
    70  	wrapRevisioned func(def datastore.SchemaDefinition) any
    71  }
    72  
    73  var testers = []testerDef{
    74  	{
    75  		"namespace",
    76  
    77  		"ReadNamespaceByName",
    78  		func(ctx context.Context, reader datastore.Reader, name string) (datastore.SchemaDefinition, datastore.Revision, error) {
    79  			return reader.ReadNamespaceByName(ctx, name)
    80  		},
    81  
    82  		"LookupNamespacesWithNames",
    83  		func(ctx context.Context, reader datastore.Reader, names []string) ([]datastore.SchemaDefinition, error) {
    84  			defs, err := reader.LookupNamespacesWithNames(ctx, names)
    85  			if err != nil {
    86  				return nil, err
    87  			}
    88  			schemaDefs := []datastore.SchemaDefinition{}
    89  			for _, def := range defs {
    90  				schemaDefs = append(schemaDefs, def.Definition)
    91  			}
    92  			return schemaDefs, nil
    93  		},
    94  
    95  		datastore.ErrNamespaceNotFound{},
    96  
    97  		"WriteNamespaces",
    98  		func(rwt datastore.ReadWriteTransaction, def datastore.SchemaDefinition) error {
    99  			return rwt.WriteNamespaces(context.Background(), def.(*core.NamespaceDefinition))
   100  		},
   101  
   102  		func(name string) datastore.SchemaDefinition { return &core.NamespaceDefinition{Name: name} },
   103  		func(def datastore.SchemaDefinition) any {
   104  			return []*core.NamespaceDefinition{def.(*core.NamespaceDefinition)}
   105  		},
   106  		func(def datastore.SchemaDefinition) any {
   107  			return []datastore.RevisionedNamespace{{Definition: def.(*core.NamespaceDefinition)}}
   108  		},
   109  	},
   110  	{
   111  		"caveat",
   112  		"ReadCaveatByName",
   113  		func(ctx context.Context, reader datastore.Reader, name string) (datastore.SchemaDefinition, datastore.Revision, error) {
   114  			return reader.ReadCaveatByName(ctx, name)
   115  		},
   116  
   117  		"LookupCaveatsWithNames",
   118  		func(ctx context.Context, reader datastore.Reader, names []string) ([]datastore.SchemaDefinition, error) {
   119  			defs, err := reader.LookupCaveatsWithNames(ctx, names)
   120  			if err != nil {
   121  				return nil, err
   122  			}
   123  			schemaDefs := []datastore.SchemaDefinition{}
   124  			for _, def := range defs {
   125  				schemaDefs = append(schemaDefs, def.Definition)
   126  			}
   127  			return schemaDefs, nil
   128  		},
   129  
   130  		datastore.ErrCaveatNameNotFound{},
   131  
   132  		"WriteCaveats",
   133  		func(rwt datastore.ReadWriteTransaction, def datastore.SchemaDefinition) error {
   134  			return rwt.WriteCaveats(context.Background(), []*core.CaveatDefinition{def.(*core.CaveatDefinition)})
   135  		},
   136  
   137  		func(name string) datastore.SchemaDefinition { return &core.CaveatDefinition{Name: name} },
   138  		func(def datastore.SchemaDefinition) any {
   139  			return []*core.CaveatDefinition{def.(*core.CaveatDefinition)}
   140  		},
   141  		func(def datastore.SchemaDefinition) any {
   142  			return []datastore.RevisionedCaveat{{Definition: def.(*core.CaveatDefinition)}}
   143  		},
   144  	},
   145  }
   146  
   147  func TestSnapshotCaching(t *testing.T) {
   148  	for _, tester := range testers {
   149  		tester := tester
   150  		t.Run(tester.name, func(t *testing.T) {
   151  			dsMock := &proxy_test.MockDatastore{}
   152  
   153  			oneReader := &proxy_test.MockReader{}
   154  			dsMock.On("SnapshotReader", one).Return(oneReader)
   155  			oneReader.On(tester.readSingleFunctionName, nsA).Return(nil, old, nil).Once()
   156  			oneReader.On(tester.readSingleFunctionName, nsB).Return(nil, zero, nil).Once()
   157  
   158  			twoReader := &proxy_test.MockReader{}
   159  			dsMock.On("SnapshotReader", two).Return(twoReader)
   160  			twoReader.On(tester.readSingleFunctionName, nsA).Return(nil, zero, nil).Once()
   161  			twoReader.On(tester.readSingleFunctionName, nsB).Return(nil, one, nil).Once()
   162  
   163  			require := require.New(t)
   164  			ds := NewCachingDatastoreProxy(dsMock, DatastoreProxyTestCache(t), 1*time.Hour, JustInTimeCaching, 100*time.Millisecond)
   165  
   166  			_, updatedOneA, err := tester.readSingleFunc(context.Background(), ds.SnapshotReader(one), nsA)
   167  			require.NoError(err)
   168  			require.True(old.Equal(updatedOneA))
   169  
   170  			_, updatedOneAAgain, err := tester.readSingleFunc(context.Background(), ds.SnapshotReader(one), nsA)
   171  			require.NoError(err)
   172  			require.True(old.Equal(updatedOneAAgain))
   173  
   174  			_, updatedOneB, err := tester.readSingleFunc(context.Background(), ds.SnapshotReader(one), nsB)
   175  			require.NoError(err)
   176  			require.True(zero.Equal(updatedOneB))
   177  
   178  			_, updatedOneBAgain, err := tester.readSingleFunc(context.Background(), ds.SnapshotReader(one), nsB)
   179  			require.NoError(err)
   180  			require.True(zero.Equal(updatedOneBAgain))
   181  
   182  			_, updatedTwoA, err := tester.readSingleFunc(context.Background(), ds.SnapshotReader(two), nsA)
   183  			require.NoError(err)
   184  			require.True(zero.Equal(updatedTwoA))
   185  
   186  			_, updatedTwoAAgain, err := tester.readSingleFunc(context.Background(), ds.SnapshotReader(two), nsA)
   187  			require.NoError(err)
   188  			require.True(zero.Equal(updatedTwoAAgain))
   189  
   190  			_, updatedTwoB, err := tester.readSingleFunc(context.Background(), ds.SnapshotReader(two), nsB)
   191  			require.NoError(err)
   192  			require.True(one.Equal(updatedTwoB))
   193  
   194  			_, updatedTwoBAgain, err := tester.readSingleFunc(context.Background(), ds.SnapshotReader(two), nsB)
   195  			require.NoError(err)
   196  			require.True(one.Equal(updatedTwoBAgain))
   197  
   198  			dsMock.AssertExpectations(t)
   199  			oneReader.AssertExpectations(t)
   200  			twoReader.AssertExpectations(t)
   201  		})
   202  	}
   203  }
   204  
   205  func TestRWTCaching(t *testing.T) {
   206  	for _, tester := range testers {
   207  		tester := tester
   208  		t.Run(tester.name, func(t *testing.T) {
   209  			dsMock := &proxy_test.MockDatastore{}
   210  			rwtMock := &proxy_test.MockReadWriteTransaction{}
   211  
   212  			require := require.New(t)
   213  
   214  			dsMock.On("ReadWriteTx", nilOpts).Return(rwtMock, one, nil).Once()
   215  			rwtMock.On(tester.readSingleFunctionName, nsA).Return(nil, zero, nil).Once()
   216  
   217  			ctx := context.Background()
   218  
   219  			ds := NewCachingDatastoreProxy(dsMock, nil, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond)
   220  
   221  			rev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
   222  				_, updatedA, err := tester.readSingleFunc(ctx, rwt, nsA)
   223  				require.NoError(err)
   224  				require.True(zero.Equal(updatedA))
   225  
   226  				// This will not call out the mock RWT again, the mock will panic if it does.
   227  				_, updatedA, err = tester.readSingleFunc(ctx, rwt, nsA)
   228  				require.NoError(err)
   229  				require.True(zero.Equal(updatedA))
   230  
   231  				return nil
   232  			})
   233  			require.True(one.Equal(rev))
   234  			require.NoError(err)
   235  
   236  			dsMock.AssertExpectations(t)
   237  			rwtMock.AssertExpectations(t)
   238  		})
   239  	}
   240  }
   241  
   242  func TestRWTCacheWithWrites(t *testing.T) {
   243  	for _, tester := range testers {
   244  		tester := tester
   245  		t.Run(tester.name, func(t *testing.T) {
   246  			dsMock := &proxy_test.MockDatastore{}
   247  			rwtMock := &proxy_test.MockReadWriteTransaction{}
   248  
   249  			require := require.New(t)
   250  
   251  			dsMock.On("ReadWriteTx", nilOpts).Return(rwtMock, one, nil).Once()
   252  			rwtMock.On(tester.readSingleFunctionName, nsA).Return(nil, zero, tester.notFoundErr).Once()
   253  
   254  			ctx := context.Background()
   255  
   256  			ds := NewCachingDatastoreProxy(dsMock, nil, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond)
   257  
   258  			rev, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
   259  				// Cache the 404
   260  				_, _, err := tester.readSingleFunc(ctx, rwt, nsA)
   261  				require.Error(err, tester.notFoundErr)
   262  
   263  				// This will not call out the mock RWT again, the mock will panic if it does.
   264  				_, _, err = tester.readSingleFunc(ctx, rwt, nsA)
   265  				require.Error(err, tester.notFoundErr)
   266  
   267  				// Write nsA
   268  				def := tester.createDef(nsA)
   269  				rwtMock.On(tester.writeFunctionName, tester.wrap(def)).Return(nil).Once()
   270  				require.NoError(tester.writeFunc(rwt, def))
   271  
   272  				// Call Read* on nsA and we should flow through to the mock
   273  				rwtMock.On(tester.readSingleFunctionName, nsA).Return(def, zero, nil).Once()
   274  				def, updatedA, err := tester.readSingleFunc(ctx, rwt, nsA)
   275  
   276  				require.True(updatedA.Equal(zero))
   277  				require.NotNil(def)
   278  				require.NoError(err)
   279  
   280  				return nil
   281  			})
   282  			require.True(one.Equal(rev))
   283  			require.NoError(err)
   284  
   285  			dsMock.AssertExpectations(t)
   286  			rwtMock.AssertExpectations(t)
   287  		})
   288  	}
   289  }
   290  
   291  func TestSingleFlight(t *testing.T) {
   292  	for _, tester := range testers {
   293  		tester := tester
   294  		t.Run(tester.name, func(t *testing.T) {
   295  			dsMock := &proxy_test.MockDatastore{}
   296  
   297  			oneReader := &proxy_test.MockReader{}
   298  			dsMock.On("SnapshotReader", one).Return(oneReader)
   299  			oneReader.
   300  				On(tester.readSingleFunctionName, nsA).
   301  				WaitUntil(time.After(50*time.Millisecond)).
   302  				Return(nil, old, nil).
   303  				Once()
   304  
   305  			require := require.New(t)
   306  
   307  			ds := NewCachingDatastoreProxy(dsMock, nil, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond)
   308  
   309  			readNamespace := func() error {
   310  				_, updatedAt, err := tester.readSingleFunc(context.Background(), ds.SnapshotReader(one), nsA)
   311  				require.NoError(err)
   312  				require.True(old.Equal(updatedAt))
   313  				return err
   314  			}
   315  
   316  			g := errgroup.Group{}
   317  			g.Go(readNamespace)
   318  			g.Go(readNamespace)
   319  
   320  			require.NoError(g.Wait())
   321  
   322  			dsMock.AssertExpectations(t)
   323  			oneReader.AssertExpectations(t)
   324  		})
   325  	}
   326  }
   327  
   328  func TestSnapshotCachingRealDatastore(t *testing.T) {
   329  	tcs := []struct {
   330  		name          string
   331  		nsDef         *core.NamespaceDefinition
   332  		namespaceName string
   333  		caveatDef     *core.CaveatDefinition
   334  		caveatName    string
   335  	}{
   336  		{
   337  			"missing",
   338  			nil,
   339  			"somenamespace",
   340  			nil,
   341  			"somecaveat",
   342  		},
   343  		{
   344  			"defined",
   345  			ns.Namespace(
   346  				"document",
   347  				ns.MustRelation("owner",
   348  					nil,
   349  					ns.AllowedRelation("user", "..."),
   350  				),
   351  				ns.MustRelation("editor",
   352  					nil,
   353  					ns.AllowedRelation("user", "..."),
   354  				),
   355  			),
   356  			"document",
   357  			ns.MustCaveatDefinition(caveats.MustEnvForVariables(
   358  				map[string]caveattypes.VariableType{
   359  					"somevar": caveattypes.IntType,
   360  				},
   361  			), "somecaveat", "somevar < 42"),
   362  			"somecaveat",
   363  		},
   364  	}
   365  
   366  	for _, tc := range tcs {
   367  		tc := tc
   368  		t.Run(tc.name, func(t *testing.T) {
   369  			rawDS, err := memdb.NewMemdbDatastore(0, 0, memdb.DisableGC)
   370  			require.NoError(t, err)
   371  
   372  			ctx := context.Background()
   373  			ds := NewCachingDatastoreProxy(rawDS, nil, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond)
   374  
   375  			if tc.nsDef != nil {
   376  				_, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
   377  					err := rwt.WriteNamespaces(ctx, tc.nsDef)
   378  					if err != nil {
   379  						return err
   380  					}
   381  
   382  					return rwt.WriteCaveats(ctx, []*core.CaveatDefinition{tc.caveatDef})
   383  				})
   384  				require.NoError(t, err)
   385  			}
   386  
   387  			headRev, err := ds.HeadRevision(ctx)
   388  			require.NoError(t, err)
   389  
   390  			reader := ds.SnapshotReader(headRev)
   391  			ns, _, _ := reader.ReadNamespaceByName(ctx, tc.namespaceName)
   392  			testutil.RequireProtoEqual(t, tc.nsDef, ns, "found different namespaces")
   393  
   394  			ns2, _, _ := reader.ReadNamespaceByName(ctx, tc.namespaceName)
   395  			testutil.RequireProtoEqual(t, tc.nsDef, ns2, "found different namespaces")
   396  
   397  			c1, _, _ := reader.ReadCaveatByName(ctx, tc.caveatName)
   398  			testutil.RequireProtoEqual(t, tc.caveatDef, c1, "found different caveats")
   399  
   400  			c2, _, _ := reader.ReadCaveatByName(ctx, tc.caveatName)
   401  			testutil.RequireProtoEqual(t, tc.caveatDef, c2, "found different caveats")
   402  		})
   403  	}
   404  }
   405  
   406  type reader struct {
   407  	proxy_test.MockReader
   408  }
   409  
   410  func (r *reader) ReadNamespaceByName(ctx context.Context, namespace string) (ns *core.NamespaceDefinition, lastWritten datastore.Revision, err error) {
   411  	time.Sleep(10 * time.Millisecond)
   412  	if errors.Is(ctx.Err(), context.Canceled) {
   413  		return nil, old, fmt.Errorf("error")
   414  	}
   415  	return &core.NamespaceDefinition{Name: namespace}, old, nil
   416  }
   417  
   418  func (r *reader) ReadCaveatByName(ctx context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) {
   419  	time.Sleep(10 * time.Millisecond)
   420  	if errors.Is(ctx.Err(), context.Canceled) {
   421  		return nil, old, fmt.Errorf("error")
   422  	}
   423  	return &core.CaveatDefinition{Name: name}, old, nil
   424  }
   425  
   426  func TestSingleFlightCancelled(t *testing.T) {
   427  	for _, tester := range testers {
   428  		tester := tester
   429  		t.Run(tester.name, func(t *testing.T) {
   430  			dsMock := &proxy_test.MockDatastore{}
   431  			ctx1, cancel1 := context.WithCancel(context.Background())
   432  			ctx2, cancel2 := context.WithCancel(context.Background())
   433  			defer cancel2()
   434  			defer cancel1()
   435  
   436  			dsMock.On("SnapshotReader", one).Return(&reader{MockReader: proxy_test.MockReader{}})
   437  
   438  			ds := NewCachingDatastoreProxy(dsMock, nil, 1*time.Hour, JustInTimeCaching, 100*time.Millisecond)
   439  
   440  			g := sync.WaitGroup{}
   441  			var d2 datastore.SchemaDefinition
   442  			g.Add(2)
   443  			go func() {
   444  				_, _, _ = tester.readSingleFunc(ctx1, ds.SnapshotReader(one), nsA)
   445  				g.Done()
   446  			}()
   447  			go func() {
   448  				time.Sleep(5 * time.Millisecond)
   449  				d2, _, _ = tester.readSingleFunc(ctx2, ds.SnapshotReader(one), nsA)
   450  				g.Done()
   451  			}()
   452  			cancel1()
   453  
   454  			g.Wait()
   455  			require.NotNil(t, d2)
   456  			require.Equal(t, nsA, d2.GetName())
   457  
   458  			dsMock.AssertExpectations(t)
   459  		})
   460  	}
   461  }
   462  
   463  func TestMixedCaching(t *testing.T) {
   464  	for _, tester := range testers {
   465  		tester := tester
   466  		t.Run(tester.name, func(t *testing.T) {
   467  			dsMock := &proxy_test.MockDatastore{}
   468  
   469  			defA := tester.createDef(nsA)
   470  			defB := tester.createDef(nsB)
   471  
   472  			reader := &proxy_test.MockReader{}
   473  			reader.On(tester.readSingleFunctionName, nsA).Return(defA, old, nil).Once()
   474  			reader.On(tester.lookupFunctionName, []string{nsB}).Return(tester.wrapRevisioned(defB), nil).Once()
   475  
   476  			dsMock.On("SnapshotReader", one).Return(reader)
   477  
   478  			require := require.New(t)
   479  			ds := NewCachingDatastoreProxy(dsMock, DatastoreProxyTestCache(t), 1*time.Hour, JustInTimeCaching, 100*time.Millisecond)
   480  
   481  			dsReader := ds.SnapshotReader(one)
   482  
   483  			// Lookup name A
   484  			_, _, err := tester.readSingleFunc(context.Background(), dsReader, nsA)
   485  			require.NoError(err)
   486  
   487  			// Lookup A and B, which should only lookup B and use A from cache.
   488  			found, err := tester.lookupFunc(context.Background(), dsReader, []string{nsA, nsB})
   489  			require.NoError(err)
   490  			require.Equal(2, len(found))
   491  
   492  			names := mapz.NewSet[string]()
   493  			for _, d := range found {
   494  				names.Add(d.GetName())
   495  			}
   496  
   497  			require.True(names.Has(nsA))
   498  			require.True(names.Has(nsB))
   499  
   500  			// Lookup A and B, which should use both from cache.
   501  			foundAgain, err := tester.lookupFunc(context.Background(), dsReader, []string{nsA, nsB})
   502  			require.NoError(err)
   503  			require.Equal(2, len(foundAgain))
   504  
   505  			namesAgain := mapz.NewSet[string]()
   506  			for _, d := range foundAgain {
   507  				namesAgain.Add(d.GetName())
   508  			}
   509  
   510  			require.True(namesAgain.Has(nsA))
   511  			require.True(namesAgain.Has(nsB))
   512  
   513  			dsMock.AssertExpectations(t)
   514  			reader.AssertExpectations(t)
   515  		})
   516  	}
   517  }