github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/middleware/dispatcher/dispatcher.go (about) 1 package dispatcher 2 3 import ( 4 "context" 5 6 middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" 7 "google.golang.org/grpc" 8 9 "github.com/authzed/spicedb/internal/dispatch" 10 ) 11 12 type ctxKeyType struct{} 13 14 var dispatcherKey ctxKeyType = struct{}{} 15 16 type dispatchHandle struct { 17 dispatcher dispatch.Dispatcher 18 } 19 20 // ContextWithHandle adds a placeholder to a context that will later be 21 // filled by the dispatcher 22 func ContextWithHandle(ctx context.Context) context.Context { 23 return context.WithValue(ctx, dispatcherKey, &dispatchHandle{}) 24 } 25 26 // FromContext reads the selected dispatcher out of a context.Context 27 // and returns nil if it does not exist. 28 func FromContext(ctx context.Context) dispatch.Dispatcher { 29 if c := ctx.Value(dispatcherKey); c != nil { 30 handle := c.(*dispatchHandle) 31 return handle.dispatcher 32 } 33 return nil 34 } 35 36 // MustFromContext reads the selected dispatcher out of a context.Context, computes a zedtoken 37 // from it, and panics if it has not been set on the context. 38 func MustFromContext(ctx context.Context) dispatch.Dispatcher { 39 dispatcher := FromContext(ctx) 40 if dispatcher == nil { 41 panic("dispatcher middleware did not inject dispatcher") 42 } 43 44 return dispatcher 45 } 46 47 // SetInContext adds a dispatcher to the given context 48 func SetInContext(ctx context.Context, dispatcher dispatch.Dispatcher) error { 49 handle := ctx.Value(dispatcherKey) 50 if handle == nil { 51 return nil 52 } 53 handle.(*dispatchHandle).dispatcher = dispatcher 54 return nil 55 } 56 57 // UnaryServerInterceptor returns a new unary server interceptor that adds the 58 // dispatcher to the context 59 func UnaryServerInterceptor(dispatcher dispatch.Dispatcher) grpc.UnaryServerInterceptor { 60 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 61 newCtx := ContextWithHandle(ctx) 62 if err := SetInContext(newCtx, dispatcher); err != nil { 63 return nil, err 64 } 65 66 return handler(newCtx, req) 67 } 68 } 69 70 // StreamServerInterceptor returns a new stream server interceptor that adds the 71 // dispatcher to the context 72 func StreamServerInterceptor(dispatcher dispatch.Dispatcher) grpc.StreamServerInterceptor { 73 return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 74 wrapped := middleware.WrapServerStream(stream) 75 wrapped.WrappedContext = ContextWithHandle(wrapped.WrappedContext) 76 if err := SetInContext(wrapped.WrappedContext, dispatcher); err != nil { 77 return err 78 } 79 return handler(srv, wrapped) 80 } 81 }