github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/dispatch/caching/caching.go (about)

     1  package caching
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"maps"
     7  	"sync"
     8  	"testing"
     9  	"unsafe"
    10  
    11  	"github.com/dustin/go-humanize"
    12  	"github.com/prometheus/client_golang/prometheus"
    13  	"github.com/stretchr/testify/require"
    14  	"go.opentelemetry.io/otel/attribute"
    15  	"go.opentelemetry.io/otel/trace"
    16  
    17  	"github.com/authzed/spicedb/internal/dispatch"
    18  	"github.com/authzed/spicedb/internal/dispatch/keys"
    19  	"github.com/authzed/spicedb/pkg/cache"
    20  	v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    21  )
    22  
    23  const (
    24  	errCachingInitialization = "error initializing caching dispatcher: %w"
    25  
    26  	prometheusNamespace = "spicedb"
    27  )
    28  
    29  // Dispatcher is a dispatcher with cacheInst-in caching.
    30  type Dispatcher struct {
    31  	d          dispatch.Dispatcher
    32  	c          cache.Cache
    33  	keyHandler keys.Handler
    34  
    35  	checkTotalCounter                  prometheus.Counter
    36  	checkFromCacheCounter              prometheus.Counter
    37  	reachableResourcesTotalCounter     prometheus.Counter
    38  	reachableResourcesFromCacheCounter prometheus.Counter
    39  	lookupResourcesTotalCounter        prometheus.Counter
    40  	lookupResourcesFromCacheCounter    prometheus.Counter
    41  	lookupSubjectsTotalCounter         prometheus.Counter
    42  	lookupSubjectsFromCacheCounter     prometheus.Counter
    43  }
    44  
    45  func DispatchTestCache(t testing.TB) cache.Cache {
    46  	cache, err := cache.NewCache(&cache.Config{
    47  		NumCounters: 1000,
    48  		MaxCost:     1 * humanize.MiByte,
    49  	})
    50  	require.Nil(t, err)
    51  	return cache
    52  }
    53  
    54  // NewCachingDispatcher creates a new dispatch.Dispatcher which delegates
    55  // dispatch requests and caches the responses when possible and desirable.
    56  func NewCachingDispatcher(cacheInst cache.Cache, metricsEnabled bool, prometheusSubsystem string, keyHandler keys.Handler) (*Dispatcher, error) {
    57  	if cacheInst == nil {
    58  		cacheInst = cache.NoopCache()
    59  	}
    60  
    61  	checkTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{
    62  		Namespace: prometheusNamespace,
    63  		Subsystem: prometheusSubsystem,
    64  		Name:      "check_total",
    65  	})
    66  	checkFromCacheCounter := prometheus.NewCounter(prometheus.CounterOpts{
    67  		Namespace: prometheusNamespace,
    68  		Subsystem: prometheusSubsystem,
    69  		Name:      "check_from_cache_total",
    70  	})
    71  
    72  	lookupResourcesTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{
    73  		Namespace: prometheusNamespace,
    74  		Subsystem: prometheusSubsystem,
    75  		Name:      "lookup_resources_total",
    76  	})
    77  	lookupResourcesFromCacheCounter := prometheus.NewCounter(prometheus.CounterOpts{
    78  		Namespace: prometheusNamespace,
    79  		Subsystem: prometheusSubsystem,
    80  		Name:      "lookup_resources_from_cache_total",
    81  	})
    82  
    83  	reachableResourcesTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{
    84  		Namespace: prometheusNamespace,
    85  		Subsystem: prometheusSubsystem,
    86  		Name:      "reachable_resources_total",
    87  	})
    88  	reachableResourcesFromCacheCounter := prometheus.NewCounter(prometheus.CounterOpts{
    89  		Namespace: prometheusNamespace,
    90  		Subsystem: prometheusSubsystem,
    91  		Name:      "reachable_resources_from_cache_total",
    92  	})
    93  
    94  	lookupSubjectsTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{
    95  		Namespace: prometheusNamespace,
    96  		Subsystem: prometheusSubsystem,
    97  		Name:      "lookup_subjects_total",
    98  	})
    99  	lookupSubjectsFromCacheCounter := prometheus.NewCounter(prometheus.CounterOpts{
   100  		Namespace: prometheusNamespace,
   101  		Subsystem: prometheusSubsystem,
   102  		Name:      "lookup_subjects_from_cache_total",
   103  	})
   104  
   105  	if metricsEnabled && prometheusSubsystem != "" {
   106  		err := prometheus.Register(checkTotalCounter)
   107  		if err != nil {
   108  			return nil, fmt.Errorf(errCachingInitialization, err)
   109  		}
   110  		err = prometheus.Register(checkFromCacheCounter)
   111  		if err != nil {
   112  			return nil, fmt.Errorf(errCachingInitialization, err)
   113  		}
   114  		err = prometheus.Register(lookupResourcesTotalCounter)
   115  		if err != nil {
   116  			return nil, fmt.Errorf(errCachingInitialization, err)
   117  		}
   118  		err = prometheus.Register(lookupResourcesFromCacheCounter)
   119  		if err != nil {
   120  			return nil, fmt.Errorf(errCachingInitialization, err)
   121  		}
   122  		err = prometheus.Register(reachableResourcesTotalCounter)
   123  		if err != nil {
   124  			return nil, fmt.Errorf(errCachingInitialization, err)
   125  		}
   126  		err = prometheus.Register(reachableResourcesFromCacheCounter)
   127  		if err != nil {
   128  			return nil, fmt.Errorf(errCachingInitialization, err)
   129  		}
   130  		err = prometheus.Register(lookupSubjectsTotalCounter)
   131  		if err != nil {
   132  			return nil, fmt.Errorf(errCachingInitialization, err)
   133  		}
   134  		err = prometheus.Register(lookupSubjectsFromCacheCounter)
   135  		if err != nil {
   136  			return nil, fmt.Errorf(errCachingInitialization, err)
   137  		}
   138  	}
   139  
   140  	if keyHandler == nil {
   141  		keyHandler = &keys.DirectKeyHandler{}
   142  	}
   143  
   144  	return &Dispatcher{
   145  		d:                                  fakeDelegate{},
   146  		c:                                  cacheInst,
   147  		keyHandler:                         keyHandler,
   148  		checkTotalCounter:                  checkTotalCounter,
   149  		checkFromCacheCounter:              checkFromCacheCounter,
   150  		reachableResourcesTotalCounter:     reachableResourcesTotalCounter,
   151  		reachableResourcesFromCacheCounter: reachableResourcesFromCacheCounter,
   152  		lookupResourcesTotalCounter:        lookupResourcesTotalCounter,
   153  		lookupResourcesFromCacheCounter:    lookupResourcesFromCacheCounter,
   154  		lookupSubjectsTotalCounter:         lookupSubjectsTotalCounter,
   155  		lookupSubjectsFromCacheCounter:     lookupSubjectsFromCacheCounter,
   156  	}, nil
   157  }
   158  
   159  // SetDelegate sets the internal delegate to the specific dispatcher instance.
   160  func (cd *Dispatcher) SetDelegate(delegate dispatch.Dispatcher) {
   161  	cd.d = delegate
   162  }
   163  
// DispatchCheck implements dispatch.Check interface.
//
// Lookup order:
//   - The request's cache key comes from keyHandler.CheckCacheKey.
//   - A cached response is reused only if req.Metadata.DepthRemaining is at
//     least the cached response's DepthRequired.
//   - Otherwise the request is forwarded to the delegate.
//
// After a successful delegate call, a copy of the response is cached. In that
// copy, DispatchCount is moved into CachedDispatchCount and DebugInfo is
// cleared. The span found in ctx is tagged with a boolean "cached" attribute
// saying which path produced the response.
//
// If key computation, unmarshaling of the cached entry, or marshaling of the
// copy fails, the error is returned with an empty response that has empty
// Metadata. On the delegate path, the delegate's response and error are
// returned together.
func (cd *Dispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) {
	cd.checkTotalCounter.Inc()

	requestKey, err := cd.keyHandler.CheckCacheKey(ctx, req)
	if err != nil {
		return &v1.DispatchCheckResponse{Metadata: &v1.ResponseMeta{}}, err
	}

	// Caching stays active even when debugging is requested. A cache hit gets
	// fresh debug info attached below, marked as a cached result.
	span := trace.SpanFromContext(ctx)
	if cachedResultRaw, found := cd.c.Get(requestKey); found {
		var response v1.DispatchCheckResponse
		if err := response.UnmarshalVT(cachedResultRaw.([]byte)); err != nil {
			return &v1.DispatchCheckResponse{Metadata: &v1.ResponseMeta{}}, err
		}

		// Reuse the cached result only if this request still has at least the
		// depth the cached computation required. Otherwise fall through and
		// recompute via the delegate.
		if req.Metadata.DepthRemaining >= response.Metadata.DepthRequired {
			cd.checkFromCacheCounter.Inc()
			// If debugging is requested, attach the request and a copy of the cached
			// results to the response's debug info, flagged as a cached result.
			if req.Debug == v1.DispatchCheckRequest_ENABLE_BASIC_DEBUGGING {
				response.Metadata.DebugInfo = &v1.DebugInformation{
					Check: &v1.CheckDebugTrace{
						Request:        req,
						Results:        maps.Clone(response.ResultsByResourceId),
						IsCachedResult: true,
					},
				}
			}

			span.SetAttributes(attribute.Bool("cached", true))
			return &response, nil
		}
	}
	span.SetAttributes(attribute.Bool("cached", false))
	computed, err := cd.d.DispatchCheck(ctx, req)

	// We only want to cache the result if there was no error
	if err == nil {
		// Cache a clone so the response returned to the caller keeps its original
		// dispatch count and debug info.
		adjustedComputed := computed.CloneVT()
		adjustedComputed.Metadata.CachedDispatchCount = adjustedComputed.Metadata.DispatchCount
		adjustedComputed.Metadata.DispatchCount = 0
		adjustedComputed.Metadata.DebugInfo = nil

		// This err shadows the outer one. A marshal failure is reported in place
		// of the computed response.
		adjustedBytes, err := adjustedComputed.MarshalVT()
		if err != nil {
			return &v1.DispatchCheckResponse{Metadata: &v1.ResponseMeta{}}, err
		}

		cd.c.Set(requestKey, adjustedBytes, sliceSize(adjustedBytes))
	}

	// Return both the computed and err in ALL cases: computed contains resolved
	// metadata even if there was an error.
	return computed, err
}
   220  
   221  // DispatchExpand implements dispatch.Expand interface and does not do any caching yet.
   222  func (cd *Dispatcher) DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest) (*v1.DispatchExpandResponse, error) {
   223  	resp, err := cd.d.DispatchExpand(ctx, req)
   224  	return resp, err
   225  }
   226  
   227  // DispatchReachableResources implements dispatch.ReachableResources interface.
   228  func (cd *Dispatcher) DispatchReachableResources(req *v1.DispatchReachableResourcesRequest, stream dispatch.ReachableResourcesStream) error {
   229  	cd.reachableResourcesTotalCounter.Inc()
   230  
   231  	requestKey, err := cd.keyHandler.ReachableResourcesCacheKey(stream.Context(), req)
   232  	if err != nil {
   233  		return err
   234  	}
   235  
   236  	if cachedResultRaw, found := cd.c.Get(requestKey); found {
   237  		cd.reachableResourcesFromCacheCounter.Inc()
   238  		for _, slice := range cachedResultRaw.([][]byte) {
   239  			var response v1.DispatchReachableResourcesResponse
   240  			if err := response.UnmarshalVT(slice); err != nil {
   241  				return fmt.Errorf("could not publish cached reachable resources result: %w", err)
   242  			}
   243  			if err := stream.Publish(&response); err != nil {
   244  				return fmt.Errorf("could not publish cached reachable resources result: %w", err)
   245  			}
   246  		}
   247  
   248  		return nil
   249  	}
   250  
   251  	var (
   252  		mu             sync.Mutex
   253  		toCacheResults [][]byte
   254  	)
   255  	wrapped := &dispatch.WrappedDispatchStream[*v1.DispatchReachableResourcesResponse]{
   256  		Stream: stream,
   257  		Ctx:    stream.Context(),
   258  		Processor: func(result *v1.DispatchReachableResourcesResponse) (*v1.DispatchReachableResourcesResponse, bool, error) {
   259  			adjustedResult := result.CloneVT()
   260  			adjustedResult.Metadata.CachedDispatchCount = adjustedResult.Metadata.DispatchCount
   261  			adjustedResult.Metadata.DispatchCount = 0
   262  			adjustedResult.Metadata.DebugInfo = nil
   263  
   264  			adjustedBytes, err := adjustedResult.MarshalVT()
   265  			if err != nil {
   266  				return nil, false, err
   267  			}
   268  
   269  			mu.Lock()
   270  			toCacheResults = append(toCacheResults, adjustedBytes)
   271  			mu.Unlock()
   272  
   273  			return result, true, nil
   274  		},
   275  	}
   276  
   277  	if err := cd.d.DispatchReachableResources(req, wrapped); err != nil {
   278  		return err
   279  	}
   280  
   281  	var size int64
   282  	for _, slice := range toCacheResults {
   283  		size += sliceSize(slice)
   284  	}
   285  
   286  	cd.c.Set(requestKey, toCacheResults, size)
   287  	return nil
   288  }
   289  
   290  func sliceSize(xs []byte) int64 {
   291  	// Slice Header + Slice Contents
   292  	return int64(int(unsafe.Sizeof(xs)) + len(xs))
   293  }
   294  
   295  // DispatchLookupResources implements dispatch.LookupResources interface.
   296  func (cd *Dispatcher) DispatchLookupResources(req *v1.DispatchLookupResourcesRequest, stream dispatch.LookupResourcesStream) error {
   297  	cd.lookupResourcesTotalCounter.Inc()
   298  
   299  	requestKey, err := cd.keyHandler.LookupResourcesCacheKey(stream.Context(), req)
   300  	if err != nil {
   301  		return err
   302  	}
   303  
   304  	if cachedResultRaw, found := cd.c.Get(requestKey); found {
   305  		cd.lookupResourcesFromCacheCounter.Inc()
   306  		for _, slice := range cachedResultRaw.([][]byte) {
   307  			var response v1.DispatchLookupResourcesResponse
   308  			if err := response.UnmarshalVT(slice); err != nil {
   309  				return err
   310  			}
   311  			if err := stream.Publish(&response); err != nil {
   312  				// don't wrap error with additional context, as it may be a grpc status.Status.
   313  				// status.FromError() is unable to unwrap status.Status values, and as a consequence
   314  				// the Dispatcher wouldn't properly propagate the gRPC error code
   315  				return err
   316  			}
   317  		}
   318  		return nil
   319  	}
   320  
   321  	var (
   322  		mu             sync.Mutex
   323  		toCacheResults [][]byte
   324  	)
   325  	wrapped := &dispatch.WrappedDispatchStream[*v1.DispatchLookupResourcesResponse]{
   326  		Stream: stream,
   327  		Ctx:    stream.Context(),
   328  		Processor: func(result *v1.DispatchLookupResourcesResponse) (*v1.DispatchLookupResourcesResponse, bool, error) {
   329  			adjustedResult := result.CloneVT()
   330  			adjustedResult.Metadata.CachedDispatchCount = adjustedResult.Metadata.DispatchCount
   331  			adjustedResult.Metadata.DispatchCount = 0
   332  			adjustedResult.Metadata.DebugInfo = nil
   333  
   334  			adjustedBytes, err := adjustedResult.MarshalVT()
   335  			if err != nil {
   336  				return &v1.DispatchLookupResourcesResponse{Metadata: &v1.ResponseMeta{}}, false, err
   337  			}
   338  
   339  			mu.Lock()
   340  			toCacheResults = append(toCacheResults, adjustedBytes)
   341  			mu.Unlock()
   342  
   343  			return result, true, nil
   344  		},
   345  	}
   346  
   347  	if err := cd.d.DispatchLookupResources(req, wrapped); err != nil {
   348  		return err
   349  	}
   350  
   351  	var size int64
   352  	for _, slice := range toCacheResults {
   353  		size += sliceSize(slice)
   354  	}
   355  
   356  	cd.c.Set(requestKey, toCacheResults, size)
   357  	return nil
   358  }
   359  
   360  // DispatchLookupSubjects implements dispatch.LookupSubjects interface.
   361  func (cd *Dispatcher) DispatchLookupSubjects(req *v1.DispatchLookupSubjectsRequest, stream dispatch.LookupSubjectsStream) error {
   362  	cd.lookupSubjectsTotalCounter.Inc()
   363  
   364  	requestKey, err := cd.keyHandler.LookupSubjectsCacheKey(stream.Context(), req)
   365  	if err != nil {
   366  		return err
   367  	}
   368  
   369  	if cachedResultRaw, found := cd.c.Get(requestKey); found {
   370  		cd.lookupSubjectsFromCacheCounter.Inc()
   371  		for _, slice := range cachedResultRaw.([][]byte) {
   372  			var response v1.DispatchLookupSubjectsResponse
   373  			if err := response.UnmarshalVT(slice); err != nil {
   374  				return err
   375  			}
   376  			if err := stream.Publish(&response); err != nil {
   377  				// don't wrap error with additional context, as it may be a grpc status.Status.
   378  				// status.FromError() is unable to unwrap status.Status values, and as a consequence
   379  				// the Dispatcher wouldn't properly propagate the gRPC error code
   380  				return err
   381  			}
   382  		}
   383  		return nil
   384  	}
   385  
   386  	var (
   387  		mu             sync.Mutex
   388  		toCacheResults [][]byte
   389  	)
   390  	wrapped := &dispatch.WrappedDispatchStream[*v1.DispatchLookupSubjectsResponse]{
   391  		Stream: stream,
   392  		Ctx:    stream.Context(),
   393  		Processor: func(result *v1.DispatchLookupSubjectsResponse) (*v1.DispatchLookupSubjectsResponse, bool, error) {
   394  			adjustedResult := result.CloneVT()
   395  			adjustedResult.Metadata.CachedDispatchCount = adjustedResult.Metadata.DispatchCount
   396  			adjustedResult.Metadata.DispatchCount = 0
   397  			adjustedResult.Metadata.DebugInfo = nil
   398  
   399  			adjustedBytes, err := adjustedResult.MarshalVT()
   400  			if err != nil {
   401  				return &v1.DispatchLookupSubjectsResponse{Metadata: &v1.ResponseMeta{}}, false, err
   402  			}
   403  
   404  			mu.Lock()
   405  			toCacheResults = append(toCacheResults, adjustedBytes)
   406  			mu.Unlock()
   407  
   408  			return result, true, nil
   409  		},
   410  	}
   411  
   412  	if err := cd.d.DispatchLookupSubjects(req, wrapped); err != nil {
   413  		return err
   414  	}
   415  
   416  	var size int64
   417  	for _, slice := range toCacheResults {
   418  		size += sliceSize(slice)
   419  	}
   420  
   421  	cd.c.Set(requestKey, toCacheResults, size)
   422  	return nil
   423  }
   424  
   425  func (cd *Dispatcher) Close() error {
   426  	prometheus.Unregister(cd.checkTotalCounter)
   427  	prometheus.Unregister(cd.checkFromCacheCounter)
   428  	prometheus.Unregister(cd.reachableResourcesTotalCounter)
   429  	prometheus.Unregister(cd.reachableResourcesFromCacheCounter)
   430  	prometheus.Unregister(cd.lookupResourcesTotalCounter)
   431  	prometheus.Unregister(cd.lookupResourcesFromCacheCounter)
   432  	prometheus.Unregister(cd.lookupSubjectsFromCacheCounter)
   433  	prometheus.Unregister(cd.lookupSubjectsTotalCounter)
   434  	if cache := cd.c; cache != nil {
   435  		cache.Close()
   436  	}
   437  
   438  	return nil
   439  }
   440  
   441  func (cd *Dispatcher) ReadyState() dispatch.ReadyState {
   442  	if cd.c == nil {
   443  		return dispatch.ReadyState{
   444  			IsReady: false,
   445  			Message: "caching dispatcher is missing cache",
   446  		}
   447  	}
   448  
   449  	if cd.d == nil {
   450  		return dispatch.ReadyState{
   451  			IsReady: false,
   452  			Message: "caching dispatcher is missing delegate dispatcher",
   453  		}
   454  	}
   455  
   456  	return cd.d.ReadyState()
   457  }
   458  
   459  // Always verify that we implement the interfaces
   460  var _ dispatch.Dispatcher = &Dispatcher{}