github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/revisions/optimized_test.go (about) 1 package revisions 2 3 import ( 4 "context" 5 "errors" 6 "testing" 7 "time" 8 9 "github.com/benbjohnson/clock" 10 "github.com/samber/lo" 11 "github.com/stretchr/testify/mock" 12 "github.com/stretchr/testify/require" 13 "golang.org/x/sync/errgroup" 14 15 "github.com/authzed/spicedb/pkg/datastore" 16 ) 17 18 type trackingRevisionFunction struct { 19 mock.Mock 20 } 21 22 func (m *trackingRevisionFunction) optimizedRevisionFunc(_ context.Context) (datastore.Revision, time.Duration, error) { 23 args := m.Called() 24 return args.Get(0).(datastore.Revision), args.Get(1).(time.Duration), args.Error(2) 25 } 26 27 var ( 28 one = NewForTransactionID(1) 29 two = NewForTransactionID(2) 30 three = NewForTransactionID(3) 31 ) 32 33 func cand(revs ...datastore.Revision) []datastore.Revision { 34 return revs 35 } 36 37 func TestOptimizedRevisionCache(t *testing.T) { 38 type revisionResponse struct { 39 rev datastore.Revision 40 validFor time.Duration 41 } 42 43 testCases := []struct { 44 name string 45 maxStaleness time.Duration 46 expectedCallResponses []revisionResponse 47 expectedRevisions [][]datastore.Revision 48 }{ 49 { 50 "single request", 51 0, 52 []revisionResponse{ 53 {one, 0}, 54 }, 55 [][]datastore.Revision{cand(one)}, 56 }, 57 { 58 "simple no caching request", 59 0, 60 []revisionResponse{ 61 {one, 0}, 62 {two, 0}, 63 {three, 0}, 64 }, 65 [][]datastore.Revision{cand(one), cand(two), cand(three)}, 66 }, 67 { 68 "simple cached once", 69 0, 70 []revisionResponse{ 71 {one, 7 * time.Millisecond}, 72 {two, 0}, 73 }, 74 [][]datastore.Revision{cand(one), cand(one), cand(two)}, 75 }, 76 { 77 "cached by staleness", 78 7 * time.Millisecond, 79 []revisionResponse{ 80 {one, 0}, 81 {two, 100 * time.Millisecond}, 82 }, 83 [][]datastore.Revision{cand(one), cand(one, two), cand(two), cand(two)}, 84 }, 85 { 86 "cached by staleness and validity", 87 2 * time.Millisecond, 88 []revisionResponse{ 89 {one, 4 * time.Millisecond}, 90 {two, 100 * time.Millisecond}, 91 }, 92 [][]datastore.Revision{cand(one), cand(one, two), cand(two)}, 93 }, 94 { 95 "cached for a while", 96 0, 97 []revisionResponse{ 98 {one, 28 * time.Millisecond}, 99 {two, 0}, 100 }, 101 [][]datastore.Revision{cand(one), cand(one), cand(one), cand(one), cand(one), cand(one), cand(two)}, 102 }, 103 } 104 105 for _, tc := range testCases { 106 tc := tc 107 t.Run(tc.name, func(t *testing.T) { 108 require := require.New(t) 109 110 or := NewCachedOptimizedRevisions(tc.maxStaleness) 111 mockTime := clock.NewMock() 112 or.clockFn = mockTime 113 mock := trackingRevisionFunction{} 114 or.SetOptimizedRevisionFunc(mock.optimizedRevisionFunc) 115 116 for _, callSpec := range tc.expectedCallResponses { 117 mock.On("optimizedRevisionFunc").Return(callSpec.rev, callSpec.validFor, nil).Once() 118 } 119 120 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 121 defer cancel() 122 123 for _, expectedRevSet := range tc.expectedRevisions { 124 awaitingRevisions := make(map[datastore.Revision]struct{}, len(expectedRevSet)) 125 for _, rev := range expectedRevSet { 126 awaitingRevisions[rev] = struct{}{} 127 } 128 129 require.Eventually(func() bool { 130 revision, err := or.OptimizedRevision(ctx) 131 require.NoError(err) 132 printableRevSet := lo.Map(expectedRevSet, func(val datastore.Revision, index int) string { 133 return val.String() 134 }) 135 require.Contains(expectedRevSet, revision, "must return the proper revision, allowed set %#v, received %s", printableRevSet, revision) 136 137 delete(awaitingRevisions, revision) 138 return len(awaitingRevisions) == 0 139 }, 1*time.Second, 1*time.Microsecond) 140 141 mockTime.Add(5 * time.Millisecond) 142 } 143 144 mock.AssertExpectations(t) 145 }) 146 } 147 } 148 149 func TestOptimizedRevisionCacheSingleFlight(t *testing.T) { 150 require := require.New(t) 151 152 or := NewCachedOptimizedRevisions(0) 153 mock := trackingRevisionFunction{} 154 or.SetOptimizedRevisionFunc(mock.optimizedRevisionFunc) 155 156 mock. 157 On("optimizedRevisionFunc"). 158 Return(one, time.Duration(0), nil). 159 After(50 * time.Millisecond). 160 Once() 161 162 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 163 defer cancel() 164 165 g := errgroup.Group{} 166 for i := 0; i < 10; i++ { 167 g.Go(func() error { 168 revision, err := or.OptimizedRevision(ctx) 169 if err != nil { 170 return err 171 } 172 require.True(one.Equal(revision), "must return the proper revision %s != %s", one, revision) 173 return nil 174 }) 175 time.Sleep(1 * time.Millisecond) 176 } 177 178 err := g.Wait() 179 require.NoError(err) 180 181 mock.AssertExpectations(t) 182 } 183 184 func BenchmarkOptimizedRevisions(b *testing.B) { 185 b.SetParallelism(1024) 186 187 quantization := 1 * time.Millisecond 188 or := NewCachedOptimizedRevisions(quantization) 189 190 or.SetOptimizedRevisionFunc(func(ctx context.Context) (datastore.Revision, time.Duration, error) { 191 nowNS := time.Now().UnixNano() 192 validForNS := nowNS % quantization.Nanoseconds() 193 roundedNS := nowNS - validForNS 194 rev := NewForTransactionID(uint64(roundedNS)) 195 return rev, time.Duration(validForNS) * time.Nanosecond, nil 196 }) 197 198 ctx := context.Background() 199 b.RunParallel(func(p *testing.PB) { 200 for p.Next() { 201 if _, err := or.OptimizedRevision(ctx); err != nil { 202 b.FailNow() 203 } 204 } 205 }) 206 } 207 208 func TestSingleFlightError(t *testing.T) { 209 req := require.New(t) 210 211 or := NewCachedOptimizedRevisions(0) 212 mock := trackingRevisionFunction{} 213 or.SetOptimizedRevisionFunc(mock.optimizedRevisionFunc) 214 215 mock. 216 On("optimizedRevisionFunc"). 217 Return(one, time.Duration(0), errors.New("fail")). 218 Once() 219 220 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 221 defer cancel() 222 223 _, err := or.OptimizedRevision(ctx) 224 req.Error(err) 225 mock.AssertExpectations(t) 226 }