github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/dispatch/caching/caching.go (about) 1 package caching 2 3 import ( 4 "context" 5 "fmt" 6 "maps" 7 "sync" 8 "testing" 9 "unsafe" 10 11 "github.com/dustin/go-humanize" 12 "github.com/prometheus/client_golang/prometheus" 13 "github.com/stretchr/testify/require" 14 "go.opentelemetry.io/otel/attribute" 15 "go.opentelemetry.io/otel/trace" 16 17 "github.com/authzed/spicedb/internal/dispatch" 18 "github.com/authzed/spicedb/internal/dispatch/keys" 19 "github.com/authzed/spicedb/pkg/cache" 20 v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1" 21 ) 22 23 const ( 24 errCachingInitialization = "error initializing caching dispatcher: %w" 25 26 prometheusNamespace = "spicedb" 27 ) 28 29 // Dispatcher is a dispatcher with cacheInst-in caching. 30 type Dispatcher struct { 31 d dispatch.Dispatcher 32 c cache.Cache 33 keyHandler keys.Handler 34 35 checkTotalCounter prometheus.Counter 36 checkFromCacheCounter prometheus.Counter 37 reachableResourcesTotalCounter prometheus.Counter 38 reachableResourcesFromCacheCounter prometheus.Counter 39 lookupResourcesTotalCounter prometheus.Counter 40 lookupResourcesFromCacheCounter prometheus.Counter 41 lookupSubjectsTotalCounter prometheus.Counter 42 lookupSubjectsFromCacheCounter prometheus.Counter 43 } 44 45 func DispatchTestCache(t testing.TB) cache.Cache { 46 cache, err := cache.NewCache(&cache.Config{ 47 NumCounters: 1000, 48 MaxCost: 1 * humanize.MiByte, 49 }) 50 require.Nil(t, err) 51 return cache 52 } 53 54 // NewCachingDispatcher creates a new dispatch.Dispatcher which delegates 55 // dispatch requests and caches the responses when possible and desirable. 56 func NewCachingDispatcher(cacheInst cache.Cache, metricsEnabled bool, prometheusSubsystem string, keyHandler keys.Handler) (*Dispatcher, error) { 57 if cacheInst == nil { 58 cacheInst = cache.NoopCache() 59 } 60 61 checkTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{ 62 Namespace: prometheusNamespace, 63 Subsystem: prometheusSubsystem, 64 Name: "check_total", 65 }) 66 checkFromCacheCounter := prometheus.NewCounter(prometheus.CounterOpts{ 67 Namespace: prometheusNamespace, 68 Subsystem: prometheusSubsystem, 69 Name: "check_from_cache_total", 70 }) 71 72 lookupResourcesTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{ 73 Namespace: prometheusNamespace, 74 Subsystem: prometheusSubsystem, 75 Name: "lookup_resources_total", 76 }) 77 lookupResourcesFromCacheCounter := prometheus.NewCounter(prometheus.CounterOpts{ 78 Namespace: prometheusNamespace, 79 Subsystem: prometheusSubsystem, 80 Name: "lookup_resources_from_cache_total", 81 }) 82 83 reachableResourcesTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{ 84 Namespace: prometheusNamespace, 85 Subsystem: prometheusSubsystem, 86 Name: "reachable_resources_total", 87 }) 88 reachableResourcesFromCacheCounter := prometheus.NewCounter(prometheus.CounterOpts{ 89 Namespace: prometheusNamespace, 90 Subsystem: prometheusSubsystem, 91 Name: "reachable_resources_from_cache_total", 92 }) 93 94 lookupSubjectsTotalCounter := prometheus.NewCounter(prometheus.CounterOpts{ 95 Namespace: prometheusNamespace, 96 Subsystem: prometheusSubsystem, 97 Name: "lookup_subjects_total", 98 }) 99 lookupSubjectsFromCacheCounter := prometheus.NewCounter(prometheus.CounterOpts{ 100 Namespace: prometheusNamespace, 101 Subsystem: prometheusSubsystem, 102 Name: "lookup_subjects_from_cache_total", 103 }) 104 105 if metricsEnabled && prometheusSubsystem != "" { 106 err := prometheus.Register(checkTotalCounter) 107 if err != nil { 108 return nil, fmt.Errorf(errCachingInitialization, err) 109 } 110 err = prometheus.Register(checkFromCacheCounter) 111 if err != nil { 112 return nil, fmt.Errorf(errCachingInitialization, err) 113 } 114 err = prometheus.Register(lookupResourcesTotalCounter) 115 if err != nil { 116 return nil, fmt.Errorf(errCachingInitialization, err) 117 } 118 err = prometheus.Register(lookupResourcesFromCacheCounter) 119 if err != nil { 120 return nil, fmt.Errorf(errCachingInitialization, err) 121 } 122 err = prometheus.Register(reachableResourcesTotalCounter) 123 if err != nil { 124 return nil, fmt.Errorf(errCachingInitialization, err) 125 } 126 err = prometheus.Register(reachableResourcesFromCacheCounter) 127 if err != nil { 128 return nil, fmt.Errorf(errCachingInitialization, err) 129 } 130 err = prometheus.Register(lookupSubjectsTotalCounter) 131 if err != nil { 132 return nil, fmt.Errorf(errCachingInitialization, err) 133 } 134 err = prometheus.Register(lookupSubjectsFromCacheCounter) 135 if err != nil { 136 return nil, fmt.Errorf(errCachingInitialization, err) 137 } 138 } 139 140 if keyHandler == nil { 141 keyHandler = &keys.DirectKeyHandler{} 142 } 143 144 return &Dispatcher{ 145 d: fakeDelegate{}, 146 c: cacheInst, 147 keyHandler: keyHandler, 148 checkTotalCounter: checkTotalCounter, 149 checkFromCacheCounter: checkFromCacheCounter, 150 reachableResourcesTotalCounter: reachableResourcesTotalCounter, 151 reachableResourcesFromCacheCounter: reachableResourcesFromCacheCounter, 152 lookupResourcesTotalCounter: lookupResourcesTotalCounter, 153 lookupResourcesFromCacheCounter: lookupResourcesFromCacheCounter, 154 lookupSubjectsTotalCounter: lookupSubjectsTotalCounter, 155 lookupSubjectsFromCacheCounter: lookupSubjectsFromCacheCounter, 156 }, nil 157 } 158 159 // SetDelegate sets the internal delegate to the specific dispatcher instance. 160 func (cd *Dispatcher) SetDelegate(delegate dispatch.Dispatcher) { 161 cd.d = delegate 162 } 163 164 // DispatchCheck implements dispatch.Check interface 165 func (cd *Dispatcher) DispatchCheck(ctx context.Context, req *v1.DispatchCheckRequest) (*v1.DispatchCheckResponse, error) { 166 cd.checkTotalCounter.Inc() 167 168 requestKey, err := cd.keyHandler.CheckCacheKey(ctx, req) 169 if err != nil { 170 return &v1.DispatchCheckResponse{Metadata: &v1.ResponseMeta{}}, err 171 } 172 173 // Disable caching when debugging is enabled. 174 span := trace.SpanFromContext(ctx) 175 if cachedResultRaw, found := cd.c.Get(requestKey); found { 176 var response v1.DispatchCheckResponse 177 if err := response.UnmarshalVT(cachedResultRaw.([]byte)); err != nil { 178 return &v1.DispatchCheckResponse{Metadata: &v1.ResponseMeta{}}, err 179 } 180 181 if req.Metadata.DepthRemaining >= response.Metadata.DepthRequired { 182 cd.checkFromCacheCounter.Inc() 183 // If debugging is requested, add the req and the response to the trace. 184 if req.Debug == v1.DispatchCheckRequest_ENABLE_BASIC_DEBUGGING { 185 response.Metadata.DebugInfo = &v1.DebugInformation{ 186 Check: &v1.CheckDebugTrace{ 187 Request: req, 188 Results: maps.Clone(response.ResultsByResourceId), 189 IsCachedResult: true, 190 }, 191 } 192 } 193 194 span.SetAttributes(attribute.Bool("cached", true)) 195 return &response, nil 196 } 197 } 198 span.SetAttributes(attribute.Bool("cached", false)) 199 computed, err := cd.d.DispatchCheck(ctx, req) 200 201 // We only want to cache the result if there was no error 202 if err == nil { 203 adjustedComputed := computed.CloneVT() 204 adjustedComputed.Metadata.CachedDispatchCount = adjustedComputed.Metadata.DispatchCount 205 adjustedComputed.Metadata.DispatchCount = 0 206 adjustedComputed.Metadata.DebugInfo = nil 207 208 adjustedBytes, err := adjustedComputed.MarshalVT() 209 if err != nil { 210 return &v1.DispatchCheckResponse{Metadata: &v1.ResponseMeta{}}, err 211 } 212 213 cd.c.Set(requestKey, adjustedBytes, sliceSize(adjustedBytes)) 214 } 215 216 // Return both the computed and err in ALL cases: computed contains resolved 217 // metadata even if there was an error. 218 return computed, err 219 } 220 221 // DispatchExpand implements dispatch.Expand interface and does not do any caching yet. 222 func (cd *Dispatcher) DispatchExpand(ctx context.Context, req *v1.DispatchExpandRequest) (*v1.DispatchExpandResponse, error) { 223 resp, err := cd.d.DispatchExpand(ctx, req) 224 return resp, err 225 } 226 227 // DispatchReachableResources implements dispatch.ReachableResources interface. 228 func (cd *Dispatcher) DispatchReachableResources(req *v1.DispatchReachableResourcesRequest, stream dispatch.ReachableResourcesStream) error { 229 cd.reachableResourcesTotalCounter.Inc() 230 231 requestKey, err := cd.keyHandler.ReachableResourcesCacheKey(stream.Context(), req) 232 if err != nil { 233 return err 234 } 235 236 if cachedResultRaw, found := cd.c.Get(requestKey); found { 237 cd.reachableResourcesFromCacheCounter.Inc() 238 for _, slice := range cachedResultRaw.([][]byte) { 239 var response v1.DispatchReachableResourcesResponse 240 if err := response.UnmarshalVT(slice); err != nil { 241 return fmt.Errorf("could not publish cached reachable resources result: %w", err) 242 } 243 if err := stream.Publish(&response); err != nil { 244 return fmt.Errorf("could not publish cached reachable resources result: %w", err) 245 } 246 } 247 248 return nil 249 } 250 251 var ( 252 mu sync.Mutex 253 toCacheResults [][]byte 254 ) 255 wrapped := &dispatch.WrappedDispatchStream[*v1.DispatchReachableResourcesResponse]{ 256 Stream: stream, 257 Ctx: stream.Context(), 258 Processor: func(result *v1.DispatchReachableResourcesResponse) (*v1.DispatchReachableResourcesResponse, bool, error) { 259 adjustedResult := result.CloneVT() 260 adjustedResult.Metadata.CachedDispatchCount = adjustedResult.Metadata.DispatchCount 261 adjustedResult.Metadata.DispatchCount = 0 262 adjustedResult.Metadata.DebugInfo = nil 263 264 adjustedBytes, err := adjustedResult.MarshalVT() 265 if err != nil { 266 return nil, false, err 267 } 268 269 mu.Lock() 270 toCacheResults = append(toCacheResults, adjustedBytes) 271 mu.Unlock() 272 273 return result, true, nil 274 }, 275 } 276 277 if err := cd.d.DispatchReachableResources(req, wrapped); err != nil { 278 return err 279 } 280 281 var size int64 282 for _, slice := range toCacheResults { 283 size += sliceSize(slice) 284 } 285 286 cd.c.Set(requestKey, toCacheResults, size) 287 return nil 288 } 289 290 func sliceSize(xs []byte) int64 { 291 // Slice Header + Slice Contents 292 return int64(int(unsafe.Sizeof(xs)) + len(xs)) 293 } 294 295 // DispatchLookupResources implements dispatch.LookupResources interface. 296 func (cd *Dispatcher) DispatchLookupResources(req *v1.DispatchLookupResourcesRequest, stream dispatch.LookupResourcesStream) error { 297 cd.lookupResourcesTotalCounter.Inc() 298 299 requestKey, err := cd.keyHandler.LookupResourcesCacheKey(stream.Context(), req) 300 if err != nil { 301 return err 302 } 303 304 if cachedResultRaw, found := cd.c.Get(requestKey); found { 305 cd.lookupResourcesFromCacheCounter.Inc() 306 for _, slice := range cachedResultRaw.([][]byte) { 307 var response v1.DispatchLookupResourcesResponse 308 if err := response.UnmarshalVT(slice); err != nil { 309 return err 310 } 311 if err := stream.Publish(&response); err != nil { 312 // don't wrap error with additional context, as it may be a grpc status.Status. 313 // status.FromError() is unable to unwrap status.Status values, and as a consequence 314 // the Dispatcher wouldn't properly propagate the gRPC error code 315 return err 316 } 317 } 318 return nil 319 } 320 321 var ( 322 mu sync.Mutex 323 toCacheResults [][]byte 324 ) 325 wrapped := &dispatch.WrappedDispatchStream[*v1.DispatchLookupResourcesResponse]{ 326 Stream: stream, 327 Ctx: stream.Context(), 328 Processor: func(result *v1.DispatchLookupResourcesResponse) (*v1.DispatchLookupResourcesResponse, bool, error) { 329 adjustedResult := result.CloneVT() 330 adjustedResult.Metadata.CachedDispatchCount = adjustedResult.Metadata.DispatchCount 331 adjustedResult.Metadata.DispatchCount = 0 332 adjustedResult.Metadata.DebugInfo = nil 333 334 adjustedBytes, err := adjustedResult.MarshalVT() 335 if err != nil { 336 return &v1.DispatchLookupResourcesResponse{Metadata: &v1.ResponseMeta{}}, false, err 337 } 338 339 mu.Lock() 340 toCacheResults = append(toCacheResults, adjustedBytes) 341 mu.Unlock() 342 343 return result, true, nil 344 }, 345 } 346 347 if err := cd.d.DispatchLookupResources(req, wrapped); err != nil { 348 return err 349 } 350 351 var size int64 352 for _, slice := range toCacheResults { 353 size += sliceSize(slice) 354 } 355 356 cd.c.Set(requestKey, toCacheResults, size) 357 return nil 358 } 359 360 // DispatchLookupSubjects implements dispatch.LookupSubjects interface. 361 func (cd *Dispatcher) DispatchLookupSubjects(req *v1.DispatchLookupSubjectsRequest, stream dispatch.LookupSubjectsStream) error { 362 cd.lookupSubjectsTotalCounter.Inc() 363 364 requestKey, err := cd.keyHandler.LookupSubjectsCacheKey(stream.Context(), req) 365 if err != nil { 366 return err 367 } 368 369 if cachedResultRaw, found := cd.c.Get(requestKey); found { 370 cd.lookupSubjectsFromCacheCounter.Inc() 371 for _, slice := range cachedResultRaw.([][]byte) { 372 var response v1.DispatchLookupSubjectsResponse 373 if err := response.UnmarshalVT(slice); err != nil { 374 return err 375 } 376 if err := stream.Publish(&response); err != nil { 377 // don't wrap error with additional context, as it may be a grpc status.Status. 378 // status.FromError() is unable to unwrap status.Status values, and as a consequence 379 // the Dispatcher wouldn't properly propagate the gRPC error code 380 return err 381 } 382 } 383 return nil 384 } 385 386 var ( 387 mu sync.Mutex 388 toCacheResults [][]byte 389 ) 390 wrapped := &dispatch.WrappedDispatchStream[*v1.DispatchLookupSubjectsResponse]{ 391 Stream: stream, 392 Ctx: stream.Context(), 393 Processor: func(result *v1.DispatchLookupSubjectsResponse) (*v1.DispatchLookupSubjectsResponse, bool, error) { 394 adjustedResult := result.CloneVT() 395 adjustedResult.Metadata.CachedDispatchCount = adjustedResult.Metadata.DispatchCount 396 adjustedResult.Metadata.DispatchCount = 0 397 adjustedResult.Metadata.DebugInfo = nil 398 399 adjustedBytes, err := adjustedResult.MarshalVT() 400 if err != nil { 401 return &v1.DispatchLookupSubjectsResponse{Metadata: &v1.ResponseMeta{}}, false, err 402 } 403 404 mu.Lock() 405 toCacheResults = append(toCacheResults, adjustedBytes) 406 mu.Unlock() 407 408 return result, true, nil 409 }, 410 } 411 412 if err := cd.d.DispatchLookupSubjects(req, wrapped); err != nil { 413 return err 414 } 415 416 var size int64 417 for _, slice := range toCacheResults { 418 size += sliceSize(slice) 419 } 420 421 cd.c.Set(requestKey, toCacheResults, size) 422 return nil 423 } 424 425 func (cd *Dispatcher) Close() error { 426 prometheus.Unregister(cd.checkTotalCounter) 427 prometheus.Unregister(cd.checkFromCacheCounter) 428 prometheus.Unregister(cd.reachableResourcesTotalCounter) 429 prometheus.Unregister(cd.reachableResourcesFromCacheCounter) 430 prometheus.Unregister(cd.lookupResourcesTotalCounter) 431 prometheus.Unregister(cd.lookupResourcesFromCacheCounter) 432 prometheus.Unregister(cd.lookupSubjectsFromCacheCounter) 433 prometheus.Unregister(cd.lookupSubjectsTotalCounter) 434 if cache := cd.c; cache != nil { 435 cache.Close() 436 } 437 438 return nil 439 } 440 441 func (cd *Dispatcher) ReadyState() dispatch.ReadyState { 442 if cd.c == nil { 443 return dispatch.ReadyState{ 444 IsReady: false, 445 Message: "caching dispatcher is missing cache", 446 } 447 } 448 449 if cd.d == nil { 450 return dispatch.ReadyState{ 451 IsReady: false, 452 Message: "caching dispatcher is missing delegate dispatcher", 453 } 454 } 455 456 return cd.d.ReadyState() 457 } 458 459 // Always verify that we implement the interfaces 460 var _ dispatch.Dispatcher = &Dispatcher{}