github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/dispatch/singleflight/singleflight_test.go (about) 1 package singleflight 2 3 import ( 4 "context" 5 "encoding/hex" 6 "sync" 7 "sync/atomic" 8 "testing" 9 "time" 10 11 "github.com/bits-and-blooms/bloom/v3" 12 "github.com/prometheus/client_golang/prometheus" 13 promclient "github.com/prometheus/client_model/go" 14 "github.com/stretchr/testify/require" 15 16 "github.com/authzed/spicedb/internal/dispatch" 17 "github.com/authzed/spicedb/internal/dispatch/keys" 18 v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" 19 "github.com/authzed/spicedb/pkg/tuple" 20 ) 21 22 const defaultBloomFilterSize = 50 23 24 func TestSingleFlightDispatcher(t *testing.T) { 25 var called atomic.Uint64 26 f := func() { 27 time.Sleep(100 * time.Millisecond) 28 called.Add(1) 29 } 30 disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{}) 31 32 req := &v1.DispatchCheckRequest{ 33 ResourceRelation: tuple.RelationReference("document", "view"), 34 ResourceIds: []string{"foo", "bar"}, 35 Subject: tuple.ObjectAndRelation("user", "tom", "..."), 36 Metadata: &v1.ResolverMeta{ 37 AtRevision: "1234", 38 TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize), 39 }, 40 } 41 42 wg := sync.WaitGroup{} 43 wg.Add(4) 44 go func() { 45 _, _ = disp.DispatchCheck(context.Background(), req.CloneVT()) 46 wg.Done() 47 }() 48 go func() { 49 _, _ = disp.DispatchCheck(context.Background(), req.CloneVT()) 50 wg.Done() 51 }() 52 go func() { 53 _, _ = disp.DispatchCheck(context.Background(), req.CloneVT()) 54 wg.Done() 55 }() 56 go func() { 57 anotherReq := req.CloneVT() 58 anotherReq.ResourceIds = []string{"foo", "baz"} 59 _, _ = disp.DispatchCheck(context.Background(), anotherReq) 60 wg.Done() 61 }() 62 63 wg.Wait() 64 65 require.Equal(t, uint64(2), called.Load(), "should have dispatched %d calls but did %d", uint64(2), called.Load()) 66 } 67 68 func TestSingleFlightDispatcherDetectsLoop(t *testing.T) { 69 singleFlightCount = prometheus.NewCounterVec(singleFlightCountConfig, []string{"method", "shared"}) 70 reg := registerMetricInGatherer(singleFlightCount) 71 72 var called atomic.Uint64 73 f := func() { 74 time.Sleep(100 * time.Millisecond) 75 called.Add(1) 76 } 77 keyHandler := &keys.DirectKeyHandler{} 78 disp := New(mockDispatcher{f: f}, keyHandler) 79 80 req := &v1.DispatchCheckRequest{ 81 ResourceRelation: tuple.RelationReference("document", "view"), 82 ResourceIds: []string{"foo", "bar"}, 83 Subject: tuple.ObjectAndRelation("user", "tom", "..."), 84 Metadata: &v1.ResolverMeta{ 85 AtRevision: "1234", 86 TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize), 87 }, 88 } 89 90 // we simulate the request above being already part of the traversal path, 91 // so that the dispatcher detects a loop and does not singleflight 92 req.Metadata.TraversalBloom = bloomFilterForRequest(t, keyHandler, req) 93 94 wg := sync.WaitGroup{} 95 wg.Add(4) 96 go func() { 97 _, _ = disp.DispatchCheck(context.Background(), req.CloneVT()) 98 wg.Done() 99 }() 100 go func() { 101 _, _ = disp.DispatchCheck(context.Background(), req.CloneVT()) 102 wg.Done() 103 }() 104 go func() { 105 _, _ = disp.DispatchCheck(context.Background(), req.CloneVT()) 106 107 wg.Done() 108 }() 109 go func() { 110 differentReq := req.CloneVT() 111 differentReq.ResourceIds = []string{"foo", "baz"} 112 _, _ = disp.DispatchCheck(context.Background(), differentReq) 113 wg.Done() 114 }() 115 116 wg.Wait() 117 118 require.Equal(t, uint64(4), called.Load(), "should have dispatched %d calls but did %d", uint64(4), called.Load()) 119 assertCounterWithLabel(t, reg, 2, "spicedb_dispatch_single_flight_total", "loop") 120 } 121 122 // this test makes sure that bloom filter information is carried from dispatcher to dispatcher 123 func TestSingleFlightDispatcherDetectsLoopThroughDelegate(t *testing.T) { 124 singleFlightCount = prometheus.NewCounterVec(singleFlightCountConfig, []string{"method", "shared"}) 125 reg := registerMetricInGatherer(singleFlightCount) 126 127 var called atomic.Uint64 128 f := func() { 129 time.Sleep(100 * time.Millisecond) 130 called.Add(1) 131 } 132 keyHandler := &keys.DirectKeyHandler{} 133 // we simulate an actual dispatch-chain loop by nesting 2 singleflight dispatchers 134 disp := New(New(mockDispatcher{f: f}, keyHandler), keyHandler) 135 136 req := &v1.DispatchCheckRequest{ 137 ResourceRelation: tuple.RelationReference("document", "view"), 138 ResourceIds: []string{"foo", "bar"}, 139 Subject: tuple.ObjectAndRelation("user", "tom", "..."), 140 Metadata: &v1.ResolverMeta{ 141 AtRevision: "1234", 142 TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize), 143 }, 144 } 145 146 wg := sync.WaitGroup{} 147 wg.Add(3) 148 go func() { 149 _, _ = disp.DispatchCheck(context.Background(), req.CloneVT()) 150 wg.Done() 151 }() 152 go func() { 153 _, _ = disp.DispatchCheck(context.Background(), req.CloneVT()) 154 wg.Done() 155 }() 156 go func() { 157 _, _ = disp.DispatchCheck(context.Background(), req.CloneVT()) 158 159 wg.Done() 160 }() 161 162 wg.Wait() 163 164 require.Equal(t, uint64(1), called.Load(), "should have dispatched %d calls but did %d", uint64(1), called.Load()) 165 assertCounterWithLabel(t, reg, 2, "spicedb_dispatch_single_flight_total", "loop") 166 } 167 168 func TestSingleFlightDispatcherCancelation(t *testing.T) { 169 var called atomic.Uint64 170 run := make(chan struct{}, 1) 171 f := func() { 172 time.Sleep(100 * time.Millisecond) 173 called.Add(1) 174 run <- struct{}{} 175 } 176 177 req := &v1.DispatchCheckRequest{ 178 ResourceRelation: tuple.RelationReference("document", "view"), 179 ResourceIds: []string{"foo", "bar"}, 180 Subject: tuple.ObjectAndRelation("user", "tom", "..."), 181 Metadata: &v1.ResolverMeta{ 182 AtRevision: "1234", 183 TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize), 184 }, 185 } 186 187 disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{}) 188 wg := sync.WaitGroup{} 189 wg.Add(3) 190 go func() { 191 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) 192 defer cancel() 193 _, err := disp.DispatchCheck(ctx, req.CloneVT()) 194 wg.Done() 195 require.ErrorIs(t, err, context.DeadlineExceeded) 196 }() 197 go func() { 198 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) 199 defer cancel() 200 _, err := disp.DispatchCheck(ctx, req.CloneVT()) 201 wg.Done() 202 require.ErrorIs(t, err, context.DeadlineExceeded) 203 }() 204 go func() { 205 ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50) 206 defer cancel() 207 _, err := disp.DispatchCheck(ctx, req.CloneVT()) 208 wg.Done() 209 require.ErrorIs(t, err, context.DeadlineExceeded) 210 }() 211 212 wg.Wait() 213 <-run 214 require.Equal(t, uint64(1), called.Load()) 215 } 216 217 func TestSingleFlightDispatcherExpand(t *testing.T) { 218 var called atomic.Uint64 219 f := func() { 220 time.Sleep(100 * time.Millisecond) 221 called.Add(1) 222 } 223 disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{}) 224 225 req := &v1.DispatchExpandRequest{ 226 ResourceAndRelation: tuple.ObjectAndRelation("document", "foo", "view"), 227 Metadata: &v1.ResolverMeta{ 228 AtRevision: "1234", 229 TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize), 230 }, 231 } 232 233 wg := sync.WaitGroup{} 234 wg.Add(4) 235 go func() { 236 _, _ = disp.DispatchExpand(context.Background(), req.CloneVT()) 237 wg.Done() 238 }() 239 go func() { 240 _, _ = disp.DispatchExpand(context.Background(), req.CloneVT()) 241 wg.Done() 242 }() 243 go func() { 244 _, _ = disp.DispatchExpand(context.Background(), req.CloneVT()) 245 wg.Done() 246 }() 247 go func() { 248 anotherReq := req.CloneVT() 249 anotherReq.ResourceAndRelation.ObjectId = "baz" 250 _, _ = disp.DispatchExpand(context.Background(), anotherReq) 251 wg.Done() 252 }() 253 254 wg.Wait() 255 256 require.Equal(t, uint64(2), called.Load(), "should have dispatched %d calls but did %d", uint64(2), called.Load()) 257 } 258 259 func TestSingleFlightDispatcherCheckBypassesIfMissingBloomFiler(t *testing.T) { 260 singleFlightCount = prometheus.NewCounterVec(singleFlightCountConfig, []string{"method", "shared"}) 261 reg := registerMetricInGatherer(singleFlightCount) 262 263 var called atomic.Uint64 264 f := func() { 265 called.Add(1) 266 } 267 disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{}) 268 269 req := &v1.DispatchCheckRequest{ 270 ResourceRelation: tuple.RelationReference("document", "view"), 271 ResourceIds: []string{"foo", "bar"}, 272 Subject: tuple.ObjectAndRelation("user", "tom", "..."), 273 Metadata: &v1.ResolverMeta{ 274 AtRevision: "1234", 275 }, 276 } 277 278 _, _ = disp.DispatchCheck(context.Background(), req.CloneVT()) 279 280 require.Equal(t, uint64(1), called.Load(), "should have dispatched %d calls but did %d", uint64(1), called.Load()) 281 assertCounterWithLabel(t, reg, 1, "spicedb_dispatch_single_flight_total", "missing") 282 } 283 284 func TestSingleFlightDispatcherExpandBypassesIfMissingBloomFiler(t *testing.T) { 285 singleFlightCount = prometheus.NewCounterVec(singleFlightCountConfig, []string{"method", "shared"}) 286 reg := registerMetricInGatherer(singleFlightCount) 287 288 var called atomic.Uint64 289 f := func() { 290 called.Add(1) 291 } 292 disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{}) 293 294 req := &v1.DispatchExpandRequest{ 295 ResourceAndRelation: tuple.ObjectAndRelation("document", "foo", "view"), 296 Metadata: &v1.ResolverMeta{ 297 AtRevision: "1234", 298 }, 299 } 300 301 _, _ = disp.DispatchExpand(context.Background(), req.CloneVT()) 302 303 require.Equal(t, uint64(1), called.Load(), "should have dispatched %d calls but did %d", uint64(1), called.Load()) 304 assertCounterWithLabel(t, reg, 1, "spicedb_dispatch_single_flight_total", "missing") 305 } 306 307 func registerMetricInGatherer(collector prometheus.Collector) prometheus.Gatherer { 308 reg := prometheus.NewRegistry() 309 reg.MustRegister(collector) 310 311 return reg 312 } 313 314 func assertCounterWithLabel(t *testing.T, gatherer prometheus.Gatherer, expectedMetricsCount int, metricName, labelName string) { 315 t.Helper() 316 317 metrics, err := gatherer.Gather() 318 require.NoError(t, err) 319 320 var mf *promclient.MetricFamily 321 for _, metric := range metrics { 322 if metric.GetName() == metricName { 323 mf = metric 324 } 325 } 326 327 found := false 328 require.Len(t, mf.GetMetric(), expectedMetricsCount) 329 for _, metric := range mf.GetMetric() { 330 for _, label := range metric.Label { 331 if *label.Value == labelName { 332 found = true 333 } 334 } 335 } 336 337 require.True(t, found, "didn't find counter with label %s", labelName) 338 } 339 340 func bloomFilterForRequest(t *testing.T, keyHandler *keys.DirectKeyHandler, req *v1.DispatchCheckRequest) []byte { 341 t.Helper() 342 343 bloomFilter := bloom.NewWithEstimates(defaultBloomFilterSize, 0.001) 344 key, err := keyHandler.CheckDispatchKey(context.Background(), req) 345 require.NoError(t, err) 346 stringKey := hex.EncodeToString(key) 347 bloomFilter = bloomFilter.AddString(stringKey) 348 binaryBloom, err := bloomFilter.MarshalBinary() 349 require.NoError(t, err) 350 351 return binaryBloom 352 } 353 354 type mockDispatcher struct { 355 f func() 356 } 357 358 func (m mockDispatcher) DispatchCheck(_ context.Context, _ *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) { 359 m.f() 360 return &v1.DispatchCheckResponse{}, nil 361 } 362 363 func (m mockDispatcher) DispatchExpand(_ context.Context, _ *v1.DispatchExpandRequest) (*v1.DispatchExpandResponse, error) { 364 m.f() 365 return &v1.DispatchExpandResponse{}, nil 366 } 367 368 func (m mockDispatcher) DispatchReachableResources(_ *v1.DispatchReachableResourcesRequest, _ dispatch.ReachableResourcesStream) error { 369 return nil 370 } 371 372 func (m mockDispatcher) DispatchLookupResources(_ *v1.DispatchLookupResourcesRequest, _ dispatch.LookupResourcesStream) error { 373 return nil 374 } 375 376 func (m mockDispatcher) DispatchLookupSubjects(_ *v1.DispatchLookupSubjectsRequest, _ dispatch.LookupSubjectsStream) error { 377 return nil 378 } 379 380 func (m mockDispatcher) Close() error { 381 return nil 382 } 383 384 func (m mockDispatcher) ReadyState() dispatch.ReadyState { 385 return dispatch.ReadyState{} 386 }