github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/revisions/optimized_test.go (about)

     1  package revisions
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/benbjohnson/clock"
    10  	"github.com/samber/lo"
    11  	"github.com/stretchr/testify/mock"
    12  	"github.com/stretchr/testify/require"
    13  	"golang.org/x/sync/errgroup"
    14  
    15  	"github.com/authzed/spicedb/pkg/datastore"
    16  )
    17  
    18  type trackingRevisionFunction struct {
    19  	mock.Mock
    20  }
    21  
    22  func (m *trackingRevisionFunction) optimizedRevisionFunc(_ context.Context) (datastore.Revision, time.Duration, error) {
    23  	args := m.Called()
    24  	return args.Get(0).(datastore.Revision), args.Get(1).(time.Duration), args.Error(2)
    25  }
    26  
    27  var (
    28  	one   = NewForTransactionID(1)
    29  	two   = NewForTransactionID(2)
    30  	three = NewForTransactionID(3)
    31  )
    32  
    33  func cand(revs ...datastore.Revision) []datastore.Revision {
    34  	return revs
    35  }
    36  
    37  func TestOptimizedRevisionCache(t *testing.T) {
    38  	type revisionResponse struct {
    39  		rev      datastore.Revision
    40  		validFor time.Duration
    41  	}
    42  
    43  	testCases := []struct {
    44  		name                  string
    45  		maxStaleness          time.Duration
    46  		expectedCallResponses []revisionResponse
    47  		expectedRevisions     [][]datastore.Revision
    48  	}{
    49  		{
    50  			"single request",
    51  			0,
    52  			[]revisionResponse{
    53  				{one, 0},
    54  			},
    55  			[][]datastore.Revision{cand(one)},
    56  		},
    57  		{
    58  			"simple no caching request",
    59  			0,
    60  			[]revisionResponse{
    61  				{one, 0},
    62  				{two, 0},
    63  				{three, 0},
    64  			},
    65  			[][]datastore.Revision{cand(one), cand(two), cand(three)},
    66  		},
    67  		{
    68  			"simple cached once",
    69  			0,
    70  			[]revisionResponse{
    71  				{one, 7 * time.Millisecond},
    72  				{two, 0},
    73  			},
    74  			[][]datastore.Revision{cand(one), cand(one), cand(two)},
    75  		},
    76  		{
    77  			"cached by staleness",
    78  			7 * time.Millisecond,
    79  			[]revisionResponse{
    80  				{one, 0},
    81  				{two, 100 * time.Millisecond},
    82  			},
    83  			[][]datastore.Revision{cand(one), cand(one, two), cand(two), cand(two)},
    84  		},
    85  		{
    86  			"cached by staleness and validity",
    87  			2 * time.Millisecond,
    88  			[]revisionResponse{
    89  				{one, 4 * time.Millisecond},
    90  				{two, 100 * time.Millisecond},
    91  			},
    92  			[][]datastore.Revision{cand(one), cand(one, two), cand(two)},
    93  		},
    94  		{
    95  			"cached for a while",
    96  			0,
    97  			[]revisionResponse{
    98  				{one, 28 * time.Millisecond},
    99  				{two, 0},
   100  			},
   101  			[][]datastore.Revision{cand(one), cand(one), cand(one), cand(one), cand(one), cand(one), cand(two)},
   102  		},
   103  	}
   104  
   105  	for _, tc := range testCases {
   106  		tc := tc
   107  		t.Run(tc.name, func(t *testing.T) {
   108  			require := require.New(t)
   109  
   110  			or := NewCachedOptimizedRevisions(tc.maxStaleness)
   111  			mockTime := clock.NewMock()
   112  			or.clockFn = mockTime
   113  			mock := trackingRevisionFunction{}
   114  			or.SetOptimizedRevisionFunc(mock.optimizedRevisionFunc)
   115  
   116  			for _, callSpec := range tc.expectedCallResponses {
   117  				mock.On("optimizedRevisionFunc").Return(callSpec.rev, callSpec.validFor, nil).Once()
   118  			}
   119  
   120  			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   121  			defer cancel()
   122  
   123  			for _, expectedRevSet := range tc.expectedRevisions {
   124  				awaitingRevisions := make(map[datastore.Revision]struct{}, len(expectedRevSet))
   125  				for _, rev := range expectedRevSet {
   126  					awaitingRevisions[rev] = struct{}{}
   127  				}
   128  
   129  				require.Eventually(func() bool {
   130  					revision, err := or.OptimizedRevision(ctx)
   131  					require.NoError(err)
   132  					printableRevSet := lo.Map(expectedRevSet, func(val datastore.Revision, index int) string {
   133  						return val.String()
   134  					})
   135  					require.Contains(expectedRevSet, revision, "must return the proper revision, allowed set %#v, received %s", printableRevSet, revision)
   136  
   137  					delete(awaitingRevisions, revision)
   138  					return len(awaitingRevisions) == 0
   139  				}, 1*time.Second, 1*time.Microsecond)
   140  
   141  				mockTime.Add(5 * time.Millisecond)
   142  			}
   143  
   144  			mock.AssertExpectations(t)
   145  		})
   146  	}
   147  }
   148  
   149  func TestOptimizedRevisionCacheSingleFlight(t *testing.T) {
   150  	require := require.New(t)
   151  
   152  	or := NewCachedOptimizedRevisions(0)
   153  	mock := trackingRevisionFunction{}
   154  	or.SetOptimizedRevisionFunc(mock.optimizedRevisionFunc)
   155  
   156  	mock.
   157  		On("optimizedRevisionFunc").
   158  		Return(one, time.Duration(0), nil).
   159  		After(50 * time.Millisecond).
   160  		Once()
   161  
   162  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   163  	defer cancel()
   164  
   165  	g := errgroup.Group{}
   166  	for i := 0; i < 10; i++ {
   167  		g.Go(func() error {
   168  			revision, err := or.OptimizedRevision(ctx)
   169  			if err != nil {
   170  				return err
   171  			}
   172  			require.True(one.Equal(revision), "must return the proper revision %s != %s", one, revision)
   173  			return nil
   174  		})
   175  		time.Sleep(1 * time.Millisecond)
   176  	}
   177  
   178  	err := g.Wait()
   179  	require.NoError(err)
   180  
   181  	mock.AssertExpectations(t)
   182  }
   183  
   184  func BenchmarkOptimizedRevisions(b *testing.B) {
   185  	b.SetParallelism(1024)
   186  
   187  	quantization := 1 * time.Millisecond
   188  	or := NewCachedOptimizedRevisions(quantization)
   189  
   190  	or.SetOptimizedRevisionFunc(func(ctx context.Context) (datastore.Revision, time.Duration, error) {
   191  		nowNS := time.Now().UnixNano()
   192  		validForNS := nowNS % quantization.Nanoseconds()
   193  		roundedNS := nowNS - validForNS
   194  		rev := NewForTransactionID(uint64(roundedNS))
   195  		return rev, time.Duration(validForNS) * time.Nanosecond, nil
   196  	})
   197  
   198  	ctx := context.Background()
   199  	b.RunParallel(func(p *testing.PB) {
   200  		for p.Next() {
   201  			if _, err := or.OptimizedRevision(ctx); err != nil {
   202  				b.FailNow()
   203  			}
   204  		}
   205  	})
   206  }
   207  
   208  func TestSingleFlightError(t *testing.T) {
   209  	req := require.New(t)
   210  
   211  	or := NewCachedOptimizedRevisions(0)
   212  	mock := trackingRevisionFunction{}
   213  	or.SetOptimizedRevisionFunc(mock.optimizedRevisionFunc)
   214  
   215  	mock.
   216  		On("optimizedRevisionFunc").
   217  		Return(one, time.Duration(0), errors.New("fail")).
   218  		Once()
   219  
   220  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   221  	defer cancel()
   222  
   223  	_, err := or.OptimizedRevision(ctx)
   224  	req.Error(err)
   225  	mock.AssertExpectations(t)
   226  }