github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/memdb/memdb_test.go (about)

     1  package memdb
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync/atomic"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/require"
    12  	"golang.org/x/sync/errgroup"
    13  
    14  	"github.com/authzed/spicedb/pkg/datastore"
    15  	"github.com/authzed/spicedb/pkg/datastore/options"
    16  	test "github.com/authzed/spicedb/pkg/datastore/test"
    17  	ns "github.com/authzed/spicedb/pkg/namespace"
    18  	corev1 "github.com/authzed/spicedb/pkg/proto/core/v1"
    19  	"github.com/authzed/spicedb/pkg/tuple"
    20  )
    21  
    22  type memDBTest struct{}
    23  
    24  func (mdbt memDBTest) New(revisionQuantization, _, gcWindow time.Duration, watchBufferLength uint16) (datastore.Datastore, error) {
    25  	return NewMemdbDatastore(watchBufferLength, revisionQuantization, gcWindow)
    26  }
    27  
    28  func TestMemdbDatastore(t *testing.T) {
    29  	test.All(t, memDBTest{})
    30  }
    31  
    32  func TestConcurrentWritePanic(t *testing.T) {
    33  	require := require.New(t)
    34  
    35  	ds, err := NewMemdbDatastore(0, 1*time.Hour, 1*time.Hour)
    36  	require.NoError(err)
    37  
    38  	ctx := context.Background()
    39  	recoverErr := errors.New("panic")
    40  
    41  	// Make the namespace very large to increase the likelihood of overlapping
    42  	relationList := make([]*corev1.Relation, 0, 1000)
    43  	for i := 0; i < 1000; i++ {
    44  		relationList = append(relationList, ns.MustRelation(fmt.Sprintf("reader%d", i), nil))
    45  	}
    46  
    47  	numPanics := uint64(0)
    48  	require.Eventually(func() bool {
    49  		_, err = ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
    50  			g := errgroup.Group{}
    51  			g.Go(func() (err error) {
    52  				defer func() {
    53  					if rec := recover(); rec != nil {
    54  						atomic.AddUint64(&numPanics, 1)
    55  						err = recoverErr
    56  					}
    57  				}()
    58  
    59  				return rwt.WriteNamespaces(ctx, ns.Namespace(
    60  					"resource",
    61  					relationList...,
    62  				))
    63  			})
    64  
    65  			g.Go(func() (err error) {
    66  				defer func() {
    67  					if rec := recover(); rec != nil {
    68  						atomic.AddUint64(&numPanics, 1)
    69  						err = recoverErr
    70  					}
    71  				}()
    72  
    73  				return rwt.WriteNamespaces(ctx, ns.Namespace("user", relationList...))
    74  			})
    75  
    76  			return g.Wait()
    77  		})
    78  		return numPanics > 0
    79  	}, 3*time.Second, 10*time.Millisecond)
    80  	require.ErrorIs(err, recoverErr)
    81  }
    82  
    83  func TestConcurrentWriteRelsError(t *testing.T) {
    84  	require := require.New(t)
    85  
    86  	ds, err := NewMemdbDatastore(0, 1*time.Hour, 1*time.Hour)
    87  	require.NoError(err)
    88  
    89  	ctx := context.Background()
    90  
    91  	// Kick off a number of writes to ensure at least one hits an error.
    92  	g := errgroup.Group{}
    93  
    94  	for i := 0; i < 50; i++ {
    95  		i := i
    96  		g.Go(func() error {
    97  			_, err := ds.ReadWriteTx(ctx, func(ctx context.Context, rwt datastore.ReadWriteTransaction) error {
    98  				updates := []*corev1.RelationTupleUpdate{}
    99  				for j := 0; j < 500; j++ {
   100  					updates = append(updates, &corev1.RelationTupleUpdate{
   101  						Operation: corev1.RelationTupleUpdate_TOUCH,
   102  						Tuple:     tuple.MustParse(fmt.Sprintf("document:doc-%d-%d#viewer@user:tom", i, j)),
   103  					})
   104  				}
   105  
   106  				return rwt.WriteRelationships(ctx, updates)
   107  			}, options.WithDisableRetries(true))
   108  			return err
   109  		})
   110  	}
   111  
   112  	werr := g.Wait()
   113  	require.Error(werr)
   114  	require.ErrorContains(werr, "serialization max retries exceeded")
   115  }