github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/middleware/dispatcher/dispatcher.go (about)

     1  package dispatcher
     2  
     3  import (
     4  	"context"
     5  
     6  	middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
     7  	"google.golang.org/grpc"
     8  
     9  	"github.com/authzed/spicedb/internal/dispatch"
    10  )
    11  
    12  type ctxKeyType struct{}
    13  
    14  var dispatcherKey ctxKeyType = struct{}{}
    15  
    16  type dispatchHandle struct {
    17  	dispatcher dispatch.Dispatcher
    18  }
    19  
    20  // ContextWithHandle adds a placeholder to a context that will later be
    21  // filled by the dispatcher
    22  func ContextWithHandle(ctx context.Context) context.Context {
    23  	return context.WithValue(ctx, dispatcherKey, &dispatchHandle{})
    24  }
    25  
    26  // FromContext reads the selected dispatcher out of a context.Context
    27  // and returns nil if it does not exist.
    28  func FromContext(ctx context.Context) dispatch.Dispatcher {
    29  	if c := ctx.Value(dispatcherKey); c != nil {
    30  		handle := c.(*dispatchHandle)
    31  		return handle.dispatcher
    32  	}
    33  	return nil
    34  }
    35  
    36  // MustFromContext reads the selected dispatcher out of a context.Context, computes a zedtoken
    37  // from it, and panics if it has not been set on the context.
    38  func MustFromContext(ctx context.Context) dispatch.Dispatcher {
    39  	dispatcher := FromContext(ctx)
    40  	if dispatcher == nil {
    41  		panic("dispatcher middleware did not inject dispatcher")
    42  	}
    43  
    44  	return dispatcher
    45  }
    46  
    47  // SetInContext adds a dispatcher to the given context
    48  func SetInContext(ctx context.Context, dispatcher dispatch.Dispatcher) error {
    49  	handle := ctx.Value(dispatcherKey)
    50  	if handle == nil {
    51  		return nil
    52  	}
    53  	handle.(*dispatchHandle).dispatcher = dispatcher
    54  	return nil
    55  }
    56  
    57  // UnaryServerInterceptor returns a new unary server interceptor that adds the
    58  // dispatcher to the context
    59  func UnaryServerInterceptor(dispatcher dispatch.Dispatcher) grpc.UnaryServerInterceptor {
    60  	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    61  		newCtx := ContextWithHandle(ctx)
    62  		if err := SetInContext(newCtx, dispatcher); err != nil {
    63  			return nil, err
    64  		}
    65  
    66  		return handler(newCtx, req)
    67  	}
    68  }
    69  
    70  // StreamServerInterceptor returns a new stream server interceptor that adds the
    71  // dispatcher to the context
    72  func StreamServerInterceptor(dispatcher dispatch.Dispatcher) grpc.StreamServerInterceptor {
    73  	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    74  		wrapped := middleware.WrapServerStream(stream)
    75  		wrapped.WrappedContext = ContextWithHandle(wrapped.WrappedContext)
    76  		if err := SetInContext(wrapped.WrappedContext, dispatcher); err != nil {
    77  			return err
    78  		}
    79  		return handler(srv, wrapped)
    80  	}
    81  }