github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/dispatch/singleflight/singleflight_test.go (about)

     1  package singleflight
     2  
     3  import (
     4  	"context"
     5  	"encoding/hex"
     6  	"sync"
     7  	"sync/atomic"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/bits-and-blooms/bloom/v3"
    12  	"github.com/prometheus/client_golang/prometheus"
    13  	promclient "github.com/prometheus/client_model/go"
    14  	"github.com/stretchr/testify/require"
    15  
    16  	"github.com/authzed/spicedb/internal/dispatch"
    17  	"github.com/authzed/spicedb/internal/dispatch/keys"
    18  	v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    19  	"github.com/authzed/spicedb/pkg/tuple"
    20  )
    21  
    22  const defaultBloomFilterSize = 50
    23  
    24  func TestSingleFlightDispatcher(t *testing.T) {
    25  	var called atomic.Uint64
    26  	f := func() {
    27  		time.Sleep(100 * time.Millisecond)
    28  		called.Add(1)
    29  	}
    30  	disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{})
    31  
    32  	req := &v1.DispatchCheckRequest{
    33  		ResourceRelation: tuple.RelationReference("document", "view"),
    34  		ResourceIds:      []string{"foo", "bar"},
    35  		Subject:          tuple.ObjectAndRelation("user", "tom", "..."),
    36  		Metadata: &v1.ResolverMeta{
    37  			AtRevision:     "1234",
    38  			TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize),
    39  		},
    40  	}
    41  
    42  	wg := sync.WaitGroup{}
    43  	wg.Add(4)
    44  	go func() {
    45  		_, _ = disp.DispatchCheck(context.Background(), req.CloneVT())
    46  		wg.Done()
    47  	}()
    48  	go func() {
    49  		_, _ = disp.DispatchCheck(context.Background(), req.CloneVT())
    50  		wg.Done()
    51  	}()
    52  	go func() {
    53  		_, _ = disp.DispatchCheck(context.Background(), req.CloneVT())
    54  		wg.Done()
    55  	}()
    56  	go func() {
    57  		anotherReq := req.CloneVT()
    58  		anotherReq.ResourceIds = []string{"foo", "baz"}
    59  		_, _ = disp.DispatchCheck(context.Background(), anotherReq)
    60  		wg.Done()
    61  	}()
    62  
    63  	wg.Wait()
    64  
    65  	require.Equal(t, uint64(2), called.Load(), "should have dispatched %d calls but did %d", uint64(2), called.Load())
    66  }
    67  
    68  func TestSingleFlightDispatcherDetectsLoop(t *testing.T) {
    69  	singleFlightCount = prometheus.NewCounterVec(singleFlightCountConfig, []string{"method", "shared"})
    70  	reg := registerMetricInGatherer(singleFlightCount)
    71  
    72  	var called atomic.Uint64
    73  	f := func() {
    74  		time.Sleep(100 * time.Millisecond)
    75  		called.Add(1)
    76  	}
    77  	keyHandler := &keys.DirectKeyHandler{}
    78  	disp := New(mockDispatcher{f: f}, keyHandler)
    79  
    80  	req := &v1.DispatchCheckRequest{
    81  		ResourceRelation: tuple.RelationReference("document", "view"),
    82  		ResourceIds:      []string{"foo", "bar"},
    83  		Subject:          tuple.ObjectAndRelation("user", "tom", "..."),
    84  		Metadata: &v1.ResolverMeta{
    85  			AtRevision:     "1234",
    86  			TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize),
    87  		},
    88  	}
    89  
    90  	// we simulate the request above being already part of the traversal path,
    91  	// so that the dispatcher detects a loop and does not singleflight
    92  	req.Metadata.TraversalBloom = bloomFilterForRequest(t, keyHandler, req)
    93  
    94  	wg := sync.WaitGroup{}
    95  	wg.Add(4)
    96  	go func() {
    97  		_, _ = disp.DispatchCheck(context.Background(), req.CloneVT())
    98  		wg.Done()
    99  	}()
   100  	go func() {
   101  		_, _ = disp.DispatchCheck(context.Background(), req.CloneVT())
   102  		wg.Done()
   103  	}()
   104  	go func() {
   105  		_, _ = disp.DispatchCheck(context.Background(), req.CloneVT())
   106  
   107  		wg.Done()
   108  	}()
   109  	go func() {
   110  		differentReq := req.CloneVT()
   111  		differentReq.ResourceIds = []string{"foo", "baz"}
   112  		_, _ = disp.DispatchCheck(context.Background(), differentReq)
   113  		wg.Done()
   114  	}()
   115  
   116  	wg.Wait()
   117  
   118  	require.Equal(t, uint64(4), called.Load(), "should have dispatched %d calls but did %d", uint64(4), called.Load())
   119  	assertCounterWithLabel(t, reg, 2, "spicedb_dispatch_single_flight_total", "loop")
   120  }
   121  
   122  // this test makes sure that bloom filter information is carried from dispatcher to dispatcher
   123  func TestSingleFlightDispatcherDetectsLoopThroughDelegate(t *testing.T) {
   124  	singleFlightCount = prometheus.NewCounterVec(singleFlightCountConfig, []string{"method", "shared"})
   125  	reg := registerMetricInGatherer(singleFlightCount)
   126  
   127  	var called atomic.Uint64
   128  	f := func() {
   129  		time.Sleep(100 * time.Millisecond)
   130  		called.Add(1)
   131  	}
   132  	keyHandler := &keys.DirectKeyHandler{}
   133  	// we simulate an actual dispatch-chain loop by nesting 2 singleflight dispatchers
   134  	disp := New(New(mockDispatcher{f: f}, keyHandler), keyHandler)
   135  
   136  	req := &v1.DispatchCheckRequest{
   137  		ResourceRelation: tuple.RelationReference("document", "view"),
   138  		ResourceIds:      []string{"foo", "bar"},
   139  		Subject:          tuple.ObjectAndRelation("user", "tom", "..."),
   140  		Metadata: &v1.ResolverMeta{
   141  			AtRevision:     "1234",
   142  			TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize),
   143  		},
   144  	}
   145  
   146  	wg := sync.WaitGroup{}
   147  	wg.Add(3)
   148  	go func() {
   149  		_, _ = disp.DispatchCheck(context.Background(), req.CloneVT())
   150  		wg.Done()
   151  	}()
   152  	go func() {
   153  		_, _ = disp.DispatchCheck(context.Background(), req.CloneVT())
   154  		wg.Done()
   155  	}()
   156  	go func() {
   157  		_, _ = disp.DispatchCheck(context.Background(), req.CloneVT())
   158  
   159  		wg.Done()
   160  	}()
   161  
   162  	wg.Wait()
   163  
   164  	require.Equal(t, uint64(1), called.Load(), "should have dispatched %d calls but did %d", uint64(1), called.Load())
   165  	assertCounterWithLabel(t, reg, 2, "spicedb_dispatch_single_flight_total", "loop")
   166  }
   167  
   168  func TestSingleFlightDispatcherCancelation(t *testing.T) {
   169  	var called atomic.Uint64
   170  	run := make(chan struct{}, 1)
   171  	f := func() {
   172  		time.Sleep(100 * time.Millisecond)
   173  		called.Add(1)
   174  		run <- struct{}{}
   175  	}
   176  
   177  	req := &v1.DispatchCheckRequest{
   178  		ResourceRelation: tuple.RelationReference("document", "view"),
   179  		ResourceIds:      []string{"foo", "bar"},
   180  		Subject:          tuple.ObjectAndRelation("user", "tom", "..."),
   181  		Metadata: &v1.ResolverMeta{
   182  			AtRevision:     "1234",
   183  			TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize),
   184  		},
   185  	}
   186  
   187  	disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{})
   188  	wg := sync.WaitGroup{}
   189  	wg.Add(3)
   190  	go func() {
   191  		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
   192  		defer cancel()
   193  		_, err := disp.DispatchCheck(ctx, req.CloneVT())
   194  		wg.Done()
   195  		require.ErrorIs(t, err, context.DeadlineExceeded)
   196  	}()
   197  	go func() {
   198  		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
   199  		defer cancel()
   200  		_, err := disp.DispatchCheck(ctx, req.CloneVT())
   201  		wg.Done()
   202  		require.ErrorIs(t, err, context.DeadlineExceeded)
   203  	}()
   204  	go func() {
   205  		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
   206  		defer cancel()
   207  		_, err := disp.DispatchCheck(ctx, req.CloneVT())
   208  		wg.Done()
   209  		require.ErrorIs(t, err, context.DeadlineExceeded)
   210  	}()
   211  
   212  	wg.Wait()
   213  	<-run
   214  	require.Equal(t, uint64(1), called.Load())
   215  }
   216  
   217  func TestSingleFlightDispatcherExpand(t *testing.T) {
   218  	var called atomic.Uint64
   219  	f := func() {
   220  		time.Sleep(100 * time.Millisecond)
   221  		called.Add(1)
   222  	}
   223  	disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{})
   224  
   225  	req := &v1.DispatchExpandRequest{
   226  		ResourceAndRelation: tuple.ObjectAndRelation("document", "foo", "view"),
   227  		Metadata: &v1.ResolverMeta{
   228  			AtRevision:     "1234",
   229  			TraversalBloom: v1.MustNewTraversalBloomFilter(defaultBloomFilterSize),
   230  		},
   231  	}
   232  
   233  	wg := sync.WaitGroup{}
   234  	wg.Add(4)
   235  	go func() {
   236  		_, _ = disp.DispatchExpand(context.Background(), req.CloneVT())
   237  		wg.Done()
   238  	}()
   239  	go func() {
   240  		_, _ = disp.DispatchExpand(context.Background(), req.CloneVT())
   241  		wg.Done()
   242  	}()
   243  	go func() {
   244  		_, _ = disp.DispatchExpand(context.Background(), req.CloneVT())
   245  		wg.Done()
   246  	}()
   247  	go func() {
   248  		anotherReq := req.CloneVT()
   249  		anotherReq.ResourceAndRelation.ObjectId = "baz"
   250  		_, _ = disp.DispatchExpand(context.Background(), anotherReq)
   251  		wg.Done()
   252  	}()
   253  
   254  	wg.Wait()
   255  
   256  	require.Equal(t, uint64(2), called.Load(), "should have dispatched %d calls but did %d", uint64(2), called.Load())
   257  }
   258  
   259  func TestSingleFlightDispatcherCheckBypassesIfMissingBloomFiler(t *testing.T) {
   260  	singleFlightCount = prometheus.NewCounterVec(singleFlightCountConfig, []string{"method", "shared"})
   261  	reg := registerMetricInGatherer(singleFlightCount)
   262  
   263  	var called atomic.Uint64
   264  	f := func() {
   265  		called.Add(1)
   266  	}
   267  	disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{})
   268  
   269  	req := &v1.DispatchCheckRequest{
   270  		ResourceRelation: tuple.RelationReference("document", "view"),
   271  		ResourceIds:      []string{"foo", "bar"},
   272  		Subject:          tuple.ObjectAndRelation("user", "tom", "..."),
   273  		Metadata: &v1.ResolverMeta{
   274  			AtRevision: "1234",
   275  		},
   276  	}
   277  
   278  	_, _ = disp.DispatchCheck(context.Background(), req.CloneVT())
   279  
   280  	require.Equal(t, uint64(1), called.Load(), "should have dispatched %d calls but did %d", uint64(1), called.Load())
   281  	assertCounterWithLabel(t, reg, 1, "spicedb_dispatch_single_flight_total", "missing")
   282  }
   283  
   284  func TestSingleFlightDispatcherExpandBypassesIfMissingBloomFiler(t *testing.T) {
   285  	singleFlightCount = prometheus.NewCounterVec(singleFlightCountConfig, []string{"method", "shared"})
   286  	reg := registerMetricInGatherer(singleFlightCount)
   287  
   288  	var called atomic.Uint64
   289  	f := func() {
   290  		called.Add(1)
   291  	}
   292  	disp := New(mockDispatcher{f: f}, &keys.DirectKeyHandler{})
   293  
   294  	req := &v1.DispatchExpandRequest{
   295  		ResourceAndRelation: tuple.ObjectAndRelation("document", "foo", "view"),
   296  		Metadata: &v1.ResolverMeta{
   297  			AtRevision: "1234",
   298  		},
   299  	}
   300  
   301  	_, _ = disp.DispatchExpand(context.Background(), req.CloneVT())
   302  
   303  	require.Equal(t, uint64(1), called.Load(), "should have dispatched %d calls but did %d", uint64(1), called.Load())
   304  	assertCounterWithLabel(t, reg, 1, "spicedb_dispatch_single_flight_total", "missing")
   305  }
   306  
   307  func registerMetricInGatherer(collector prometheus.Collector) prometheus.Gatherer {
   308  	reg := prometheus.NewRegistry()
   309  	reg.MustRegister(collector)
   310  
   311  	return reg
   312  }
   313  
   314  func assertCounterWithLabel(t *testing.T, gatherer prometheus.Gatherer, expectedMetricsCount int, metricName, labelName string) {
   315  	t.Helper()
   316  
   317  	metrics, err := gatherer.Gather()
   318  	require.NoError(t, err)
   319  
   320  	var mf *promclient.MetricFamily
   321  	for _, metric := range metrics {
   322  		if metric.GetName() == metricName {
   323  			mf = metric
   324  		}
   325  	}
   326  
   327  	found := false
   328  	require.Len(t, mf.GetMetric(), expectedMetricsCount)
   329  	for _, metric := range mf.GetMetric() {
   330  		for _, label := range metric.Label {
   331  			if *label.Value == labelName {
   332  				found = true
   333  			}
   334  		}
   335  	}
   336  
   337  	require.True(t, found, "didn't find counter with label %s", labelName)
   338  }
   339  
   340  func bloomFilterForRequest(t *testing.T, keyHandler *keys.DirectKeyHandler, req *v1.DispatchCheckRequest) []byte {
   341  	t.Helper()
   342  
   343  	bloomFilter := bloom.NewWithEstimates(defaultBloomFilterSize, 0.001)
   344  	key, err := keyHandler.CheckDispatchKey(context.Background(), req)
   345  	require.NoError(t, err)
   346  	stringKey := hex.EncodeToString(key)
   347  	bloomFilter = bloomFilter.AddString(stringKey)
   348  	binaryBloom, err := bloomFilter.MarshalBinary()
   349  	require.NoError(t, err)
   350  
   351  	return binaryBloom
   352  }
   353  
   354  type mockDispatcher struct {
   355  	f func()
   356  }
   357  
   358  func (m mockDispatcher) DispatchCheck(_ context.Context, _ *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) {
   359  	m.f()
   360  	return &v1.DispatchCheckResponse{}, nil
   361  }
   362  
   363  func (m mockDispatcher) DispatchExpand(_ context.Context, _ *v1.DispatchExpandRequest) (*v1.DispatchExpandResponse, error) {
   364  	m.f()
   365  	return &v1.DispatchExpandResponse{}, nil
   366  }
   367  
   368  func (m mockDispatcher) DispatchReachableResources(_ *v1.DispatchReachableResourcesRequest, _ dispatch.ReachableResourcesStream) error {
   369  	return nil
   370  }
   371  
   372  func (m mockDispatcher) DispatchLookupResources(_ *v1.DispatchLookupResourcesRequest, _ dispatch.LookupResourcesStream) error {
   373  	return nil
   374  }
   375  
   376  func (m mockDispatcher) DispatchLookupSubjects(_ *v1.DispatchLookupSubjectsRequest, _ dispatch.LookupSubjectsStream) error {
   377  	return nil
   378  }
   379  
   380  func (m mockDispatcher) Close() error {
   381  	return nil
   382  }
   383  
   384  func (m mockDispatcher) ReadyState() dispatch.ReadyState {
   385  	return dispatch.ReadyState{}
   386  }