github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/cmd/server/middleware.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/authzed/spicedb/pkg/genutil/mapz"
     8  	"github.com/authzed/spicedb/pkg/spiceerrors"
     9  
    10  	middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
    11  	"google.golang.org/grpc"
    12  )
    13  
    14  type middlewareTypes interface {
    15  	grpc.UnaryServerInterceptor | grpc.StreamServerInterceptor
    16  }
    17  
    18  // MiddlewareChain describes an ordered sequence of middlewares that can be modified
    19  // with one or more MiddlewareModification. This struct is used to facilitate the
    20  // creation and modification of gRPC middleware chains
    21  type MiddlewareChain[T middlewareTypes] struct {
    22  	chain []ReferenceableMiddleware[T]
    23  }
    24  
    25  // NewMiddlewareChain creates a new middleware chain given zero or more named middlewares.
    26  // An error will be returned in case validation of the NamedMiddlewares fail.
    27  func NewMiddlewareChain[T middlewareTypes](mw ...ReferenceableMiddleware[T]) (MiddlewareChain[T], error) {
    28  	if err := validate(mw); err != nil {
    29  		return MiddlewareChain[T]{}, err
    30  	}
    31  	return MiddlewareChain[T]{chain: mw}, nil
    32  }
    33  
    34  // MiddlewareModification describes an operation to modify a MiddlewareChain
    35  type MiddlewareModification[T middlewareTypes] struct {
    36  	// DependencyMiddlewareName is used to define with respect to which middleware an operation is performed.
    37  	// Dependency is not required for ReplaceAll operation
    38  	DependencyMiddlewareName string
    39  
    40  	// Operation describes the type of operation to be performed
    41  	Operation MiddlewareOperation
    42  
    43  	// Middlewares are the named middlewares that will be part of this modification
    44  	Middlewares []ReferenceableMiddleware[T]
    45  }
    46  
    47  func (mm MiddlewareModification[T]) validate() error {
    48  	if mm.Operation != OperationReplaceAllUnsafe && mm.DependencyMiddlewareName == "" {
    49  		return fmt.Errorf("cannot perform middleware modification without a dependency: %v", mm)
    50  	}
    51  	return validate(mm.Middlewares)
    52  }
    53  
    54  func validate[T middlewareTypes](mws []ReferenceableMiddleware[T]) error {
    55  	names := mapz.NewSet[string]()
    56  	for _, mw := range mws {
    57  		if mw.Name == "" {
    58  			return fmt.Errorf("unnamed middleware found: %v", mw)
    59  		}
    60  		if !names.Add(mw.Name) {
    61  			return fmt.Errorf("found middleware with duplicate names in middleware modification: %s", mw.Name)
    62  		}
    63  	}
    64  	return nil
    65  }
    66  
    67  // ReferenceableMiddleware represents a middleware in a MiddlewareChain. Middlewares can
    68  // be referenced by name in MiddlewareModification, for example "append after middleware abc".
    69  // Internal middlewares can also be referenced for operations like append or prepend, but cannot
    70  // be referenced for replace operations. Middlewares must always be named.
    71  type ReferenceableMiddleware[T middlewareTypes] struct {
    72  	Name       string
    73  	Internal   bool
    74  	Middleware T
    75  }
    76  
    77  // MiddlewareOperation describes the type of operation that will be performed in a MiddlewareModification
    78  type MiddlewareOperation int
    79  
    80  const (
    81  	// OperationPrepend adds the middlewares right before the referenced dependency
    82  	OperationPrepend MiddlewareOperation = iota
    83  
    84  	// OperationReplace substitutes the referenced dependency with the middlewares of a modification.
    85  	// If replaced with an empty modification, this acts like a deletion
    86  	OperationReplace
    87  
    88  	// OperationAppend adds the middlewares right after the referenced dependency
    89  	OperationAppend
    90  
    91  	// OperationReplaceAllUnsafe replaces all middlewares in a chain with those in the modification
    92  	// this operation is only meant to be used in tests.
    93  	OperationReplaceAllUnsafe
    94  )
    95  
    96  // Names returns the names of the middlewares in a chain
    97  func (mc *MiddlewareChain[T]) Names() *mapz.Set[string] {
    98  	names := mapz.NewSet[string]()
    99  	for _, mw := range mc.chain {
   100  		names.Insert(mw.Name)
   101  	}
   102  	return names
   103  }
   104  
   105  // ToGRPCInterceptors generates slices of gRPC interceptors ready to be installed in a server
   106  func (mc *MiddlewareChain[T]) ToGRPCInterceptors() []T {
   107  	interceptors := make([]T, 0, len(mc.chain))
   108  	for _, mw := range mc.chain {
   109  		interceptors = append(interceptors, mw.Middleware)
   110  	}
   111  	return interceptors
   112  }
   113  
   114  func (mc *MiddlewareChain[T]) prepend(mod MiddlewareModification[T]) error {
   115  	if err := mc.validate(mod); err != nil {
   116  		return err
   117  	}
   118  
   119  	newChain := make([]ReferenceableMiddleware[T], 0, len(mc.chain))
   120  	for _, mw := range mc.chain {
   121  		if mw.Name == mod.DependencyMiddlewareName {
   122  			newChain = append(newChain, mod.Middlewares...)
   123  		}
   124  		newChain = append(newChain, mw)
   125  	}
   126  	mc.chain = newChain
   127  	return nil
   128  }
   129  
   130  func (mc *MiddlewareChain[T]) replace(mod MiddlewareModification[T]) error {
   131  	if err := mc.validate(mod); err != nil {
   132  		return err
   133  	}
   134  	newChain := make([]ReferenceableMiddleware[T], 0, len(mc.chain))
   135  	for _, mw := range mc.chain {
   136  		if mw.Name == mod.DependencyMiddlewareName {
   137  			newChain = append(newChain, mod.Middlewares...)
   138  		} else {
   139  			newChain = append(newChain, mw)
   140  		}
   141  	}
   142  	mc.chain = newChain
   143  	return nil
   144  }
   145  
   146  func (mc *MiddlewareChain[T]) append(mod MiddlewareModification[T]) error {
   147  	if err := mc.validate(mod); err != nil {
   148  		return err
   149  	}
   150  
   151  	newChain := make([]ReferenceableMiddleware[T], 0, len(mc.chain))
   152  	for _, mw := range mc.chain {
   153  		newChain = append(newChain, mw)
   154  		if mw.Name == mod.DependencyMiddlewareName {
   155  			newChain = append(newChain, mod.Middlewares...)
   156  		}
   157  	}
   158  	mc.chain = newChain
   159  	return nil
   160  }
   161  
   162  func (mc *MiddlewareChain[T]) replaceAll(mod MiddlewareModification[T]) error {
   163  	if err := mod.validate(); err != nil {
   164  		return err
   165  	}
   166  	mc.chain = mod.Middlewares
   167  	return nil
   168  }
   169  
   170  func (mc *MiddlewareChain[T]) validate(mod MiddlewareModification[T]) error {
   171  	if err := mod.validate(); err != nil {
   172  		return err
   173  	}
   174  
   175  	// prevent referencing non-existing middlewares
   176  	existingNames := mc.Names()
   177  	if !existingNames.Has(mod.DependencyMiddlewareName) {
   178  		return fmt.Errorf("referenced dependency does not exist on chain: %s", mod.DependencyMiddlewareName)
   179  	}
   180  
   181  	// prevent appending/prepending a duplicate middleware
   182  	for _, mw := range mod.Middlewares {
   183  		if existingNames.Has(mw.Name) && mod.DependencyMiddlewareName == mw.Name && mod.Operation != OperationReplace {
   184  			return fmt.Errorf("modification will cause a duplicate in chain: %s", mw.Name)
   185  		}
   186  	}
   187  
   188  	// prevent replacing an internal middleware
   189  	for _, mw := range mc.chain {
   190  		if mw.Internal && mw.Name == mod.DependencyMiddlewareName && mod.Operation == OperationReplace {
   191  			return fmt.Errorf("modification attempts to replace an internal middleware: %s", mw.Name)
   192  		}
   193  	}
   194  	return nil
   195  }
   196  
   197  func (mc *MiddlewareChain[T]) modify(modifications ...MiddlewareModification[T]) error {
   198  	for _, mod := range modifications {
   199  		switch mod.Operation {
   200  		case OperationPrepend:
   201  			if err := mc.prepend(mod); err != nil {
   202  				return err
   203  			}
   204  		case OperationReplace:
   205  			if err := mc.replace(mod); err != nil {
   206  				return err
   207  			}
   208  		case OperationReplaceAllUnsafe:
   209  			if err := mc.replaceAll(mod); err != nil {
   210  				return err
   211  			}
   212  		case OperationAppend:
   213  			if err := mc.append(mod); err != nil {
   214  				return err
   215  			}
   216  		}
   217  	}
   218  	return nil
   219  }
   220  
   221  type streamOrderAssertion struct {
   222  	grpc.ServerStream
   223  	name            string
   224  	alreadyExecuted string
   225  	notExecuted     string
   226  }
   227  
   228  func (o streamOrderAssertion) RecvMsg(m any) error {
   229  	if err := mustHaveExecuted(o.Context(), streamExecuted{}, o.alreadyExecuted); err != nil {
   230  		return err
   231  	}
   232  
   233  	if err := mustHaveNotExecuted(o.Context(), streamExecuted{}, o.notExecuted); err != nil {
   234  		return err
   235  	}
   236  
   237  	mustMarkAsExecuted(o.Context(), streamExecuted{}, o.name)
   238  	err := o.ServerStream.RecvMsg(m)
   239  	return err
   240  }
   241  
   242  func (o streamOrderAssertion) SendMsg(m any) error {
   243  	return o.ServerStream.SendMsg(m)
   244  }
   245  
   246  func NewStreamMiddleware() *StreamOrderEnforcerBuilder {
   247  	return &StreamOrderEnforcerBuilder{}
   248  }
   249  
   250  type StreamOrderEnforcerBuilder struct {
   251  	name                     string
   252  	streamInterceptor        grpc.StreamServerInterceptor
   253  	internal                 bool
   254  	interceptorExecuted      string
   255  	interceptorNotExecuted   string
   256  	streamWrapperExecuted    string
   257  	streamWrapperNotExecuted string
   258  }
   259  
   260  func (soeb *StreamOrderEnforcerBuilder) WithName(name string) *StreamOrderEnforcerBuilder {
   261  	soeb.name = name
   262  	return soeb
   263  }
   264  
   265  func (soeb *StreamOrderEnforcerBuilder) WithInterceptor(interceptor grpc.StreamServerInterceptor) *StreamOrderEnforcerBuilder {
   266  	soeb.streamInterceptor = interceptor
   267  	return soeb
   268  }
   269  
   270  func (soeb *StreamOrderEnforcerBuilder) WithInternal(internal bool) *StreamOrderEnforcerBuilder {
   271  	soeb.internal = internal
   272  	return soeb
   273  }
   274  
   275  func (soeb *StreamOrderEnforcerBuilder) EnsureWrapperAlreadyExecuted(name string) *StreamOrderEnforcerBuilder {
   276  	soeb.streamWrapperExecuted = name
   277  	return soeb
   278  }
   279  
   280  func (soeb *StreamOrderEnforcerBuilder) EnsureWrapperNotExecuted(name string) *StreamOrderEnforcerBuilder {
   281  	soeb.streamWrapperNotExecuted = name
   282  	return soeb
   283  }
   284  
   285  func (soeb *StreamOrderEnforcerBuilder) EnsureInterceptorAlreadyExecuted(name string) *StreamOrderEnforcerBuilder {
   286  	soeb.interceptorExecuted = name
   287  	return soeb
   288  }
   289  
   290  func (soeb *StreamOrderEnforcerBuilder) EnsureInterceptorNotExecuted(name string) *StreamOrderEnforcerBuilder {
   291  	soeb.interceptorNotExecuted = name
   292  	return soeb
   293  }
   294  
   295  func (soeb *StreamOrderEnforcerBuilder) Done() ReferenceableMiddleware[grpc.StreamServerInterceptor] {
   296  	if !spiceerrors.IsInTests() {
   297  		return ReferenceableMiddleware[grpc.StreamServerInterceptor]{
   298  			Name:       soeb.name,
   299  			Internal:   soeb.internal,
   300  			Middleware: soeb.streamInterceptor,
   301  		}
   302  	}
   303  
   304  	return ReferenceableMiddleware[grpc.StreamServerInterceptor]{
   305  		Name:     soeb.name,
   306  		Internal: soeb.internal,
   307  		Middleware: func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   308  			wss := middleware.WrapServerStream(ss)
   309  			if wss.WrappedContext.Value(streamExecuted{}) == nil {
   310  				handle := executedHandle{executed: make(map[string]struct{}, 0)}
   311  				wss.WrappedContext = context.WithValue(wss.WrappedContext, streamExecuted{}, &handle)
   312  			}
   313  			if wss.WrappedContext.Value(interceptorsExecuted{}) == nil {
   314  				handle := executedHandle{executed: make(map[string]struct{}, 0)}
   315  				wss.WrappedContext = context.WithValue(wss.WrappedContext, interceptorsExecuted{}, &handle)
   316  			}
   317  
   318  			if err := mustHaveExecuted(wss.WrappedContext, interceptorsExecuted{}, soeb.interceptorExecuted); err != nil {
   319  				return err
   320  			}
   321  
   322  			if err := mustHaveNotExecuted(wss.WrappedContext, interceptorsExecuted{}, soeb.interceptorNotExecuted); err != nil {
   323  				return err
   324  			}
   325  
   326  			mustMarkAsExecuted(wss.WrappedContext, interceptorsExecuted{}, soeb.name)
   327  
   328  			wrappedStream := streamOrderAssertion{
   329  				ServerStream:    wss,
   330  				name:            soeb.name,
   331  				alreadyExecuted: soeb.streamWrapperExecuted,
   332  				notExecuted:     soeb.streamWrapperNotExecuted,
   333  			}
   334  			return soeb.streamInterceptor(srv, wrappedStream, info, handler)
   335  		},
   336  	}
   337  }
   338  
   339  func NewUnaryMiddleware() *UnaryOrderEnforcerBuilder {
   340  	return &UnaryOrderEnforcerBuilder{}
   341  }
   342  
   343  type UnaryOrderEnforcerBuilder struct {
   344  	name            string
   345  	interceptor     grpc.UnaryServerInterceptor
   346  	internal        bool
   347  	alreadyExecuted string
   348  	notExecuted     string
   349  }
   350  
   351  func (soeb *UnaryOrderEnforcerBuilder) WithName(name string) *UnaryOrderEnforcerBuilder {
   352  	soeb.name = name
   353  	return soeb
   354  }
   355  
   356  func (soeb *UnaryOrderEnforcerBuilder) WithInterceptor(interceptor grpc.UnaryServerInterceptor) *UnaryOrderEnforcerBuilder {
   357  	soeb.interceptor = interceptor
   358  	return soeb
   359  }
   360  
   361  func (soeb *UnaryOrderEnforcerBuilder) WithInternal(internal bool) *UnaryOrderEnforcerBuilder {
   362  	soeb.internal = internal
   363  	return soeb
   364  }
   365  
   366  func (soeb *UnaryOrderEnforcerBuilder) EnsureAlreadyExecuted(name string) *UnaryOrderEnforcerBuilder {
   367  	soeb.alreadyExecuted = name
   368  	return soeb
   369  }
   370  
   371  func (soeb *UnaryOrderEnforcerBuilder) EnsureNotExecuted(name string) *UnaryOrderEnforcerBuilder {
   372  	soeb.notExecuted = name
   373  	return soeb
   374  }
   375  
   376  func (soeb *UnaryOrderEnforcerBuilder) Done() ReferenceableMiddleware[grpc.UnaryServerInterceptor] {
   377  	if !spiceerrors.IsInTests() {
   378  		return ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   379  			Name:       soeb.name,
   380  			Internal:   soeb.internal,
   381  			Middleware: soeb.interceptor,
   382  		}
   383  	}
   384  
   385  	return ReferenceableMiddleware[grpc.UnaryServerInterceptor]{
   386  		Name:     soeb.name,
   387  		Internal: soeb.internal,
   388  		Middleware: func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
   389  			if ctx.Value(interceptorsExecuted{}) == nil {
   390  				handle := executedHandle{executed: make(map[string]struct{}, 0)}
   391  				ctx = context.WithValue(ctx, interceptorsExecuted{}, &handle)
   392  			}
   393  
   394  			if err := mustHaveExecuted(ctx, interceptorsExecuted{}, soeb.alreadyExecuted); err != nil {
   395  				return nil, err
   396  			}
   397  
   398  			if err := mustHaveNotExecuted(ctx, interceptorsExecuted{}, soeb.notExecuted); err != nil {
   399  				return nil, err
   400  			}
   401  
   402  			mustMarkAsExecuted(ctx, interceptorsExecuted{}, soeb.name)
   403  			return soeb.interceptor(ctx, req, info, handler)
   404  		},
   405  	}
   406  }
   407  
   408  func mustHaveNotExecuted(ctx context.Context, handleKey any, notExecuted string) error {
   409  	if notExecuted == "" {
   410  		return nil
   411  	}
   412  
   413  	val := ctx.Value(handleKey)
   414  	if val == nil {
   415  		return fmt.Errorf("interception order validation bookkeeping not present in context")
   416  	}
   417  
   418  	handle := val.(*executedHandle)
   419  	if _, ok := handle.executed[notExecuted]; ok {
   420  		return fmt.Errorf("expected interceptor %s to be not already executed", notExecuted)
   421  	}
   422  
   423  	return nil
   424  }
   425  
   426  func mustHaveExecuted(ctx context.Context, handleKey any, expectedExecuted string) error {
   427  	if expectedExecuted == "" {
   428  		return nil
   429  	}
   430  
   431  	val := ctx.Value(handleKey)
   432  	if val == nil {
   433  		return spiceerrors.MustBugf("interception order validation bookkeeping not present in context")
   434  	}
   435  
   436  	handle := val.(*executedHandle)
   437  	if _, ok := handle.executed[expectedExecuted]; ok {
   438  		return nil
   439  	}
   440  
   441  	return fmt.Errorf("expected interceptor %s to be already executed", expectedExecuted)
   442  }
   443  
   444  func mustMarkAsExecuted(ctx context.Context, handleKey any, name string) {
   445  	val := ctx.Value(handleKey)
   446  	if val == nil {
   447  		panic("handle should exist")
   448  	}
   449  
   450  	handle := val.(*executedHandle)
   451  	handle.executed[name] = struct{}{}
   452  }
   453  
   454  type executedHandle struct {
   455  	executed map[string]struct{}
   456  }
   457  
   458  type interceptorsExecuted struct{}
   459  
   460  type streamExecuted struct{}