github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/cmd/server/middleware.go (about) 1 package server 2 3 import ( 4 "context" 5 "fmt" 6 7 "github.com/authzed/spicedb/pkg/genutil/mapz" 8 "github.com/authzed/spicedb/pkg/spiceerrors" 9 10 middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2" 11 "google.golang.org/grpc" 12 ) 13 14 type middlewareTypes interface { 15 grpc.UnaryServerInterceptor | grpc.StreamServerInterceptor 16 } 17 18 // MiddlewareChain describes an ordered sequence of middlewares that can be modified 19 // with one or more MiddlewareModification. This struct is used to facilitate the 20 // creation and modification of gRPC middleware chains 21 type MiddlewareChain[T middlewareTypes] struct { 22 chain []ReferenceableMiddleware[T] 23 } 24 25 // NewMiddlewareChain creates a new middleware chain given zero or more named middlewares. 26 // An error will be returned in case validation of the NamedMiddlewares fail. 27 func NewMiddlewareChain[T middlewareTypes](mw ...ReferenceableMiddleware[T]) (MiddlewareChain[T], error) { 28 if err := validate(mw); err != nil { 29 return MiddlewareChain[T]{}, err 30 } 31 return MiddlewareChain[T]{chain: mw}, nil 32 } 33 34 // MiddlewareModification describes an operation to modify a MiddlewareChain 35 type MiddlewareModification[T middlewareTypes] struct { 36 // DependencyMiddlewareName is used to define with respect to which middleware an operation is performed. 37 // Dependency is not required for ReplaceAll operation 38 DependencyMiddlewareName string 39 40 // Operation describes the type of operation to be performed 41 Operation MiddlewareOperation 42 43 // Middlewares are the named middlewares that will be part of this modification 44 Middlewares []ReferenceableMiddleware[T] 45 } 46 47 func (mm MiddlewareModification[T]) validate() error { 48 if mm.Operation != OperationReplaceAllUnsafe && mm.DependencyMiddlewareName == "" { 49 return fmt.Errorf("cannot perform middleware modification without a dependency: %v", mm) 50 } 51 return validate(mm.Middlewares) 52 } 53 54 func validate[T middlewareTypes](mws []ReferenceableMiddleware[T]) error { 55 names := mapz.NewSet[string]() 56 for _, mw := range mws { 57 if mw.Name == "" { 58 return fmt.Errorf("unnamed middleware found: %v", mw) 59 } 60 if !names.Add(mw.Name) { 61 return fmt.Errorf("found middleware with duplicate names in middleware modification: %s", mw.Name) 62 } 63 } 64 return nil 65 } 66 67 // ReferenceableMiddleware represents a middleware in a MiddlewareChain. Middlewares can 68 // be referenced by name in MiddlewareModification, for example "append after middleware abc". 69 // Internal middlewares can also be referenced for operations like append or prepend, but cannot 70 // be referenced for replace operations. Middlewares must always be named. 71 type ReferenceableMiddleware[T middlewareTypes] struct { 72 Name string 73 Internal bool 74 Middleware T 75 } 76 77 // MiddlewareOperation describes the type of operation that will be performed in a MiddlewareModification 78 type MiddlewareOperation int 79 80 const ( 81 // OperationPrepend adds the middlewares right before the referenced dependency 82 OperationPrepend MiddlewareOperation = iota 83 84 // OperationReplace substitutes the referenced dependency with the middlewares of a modification. 85 // If replaced with an empty modification, this acts like a deletion 86 OperationReplace 87 88 // OperationAppend adds the middlewares right after the referenced dependency 89 OperationAppend 90 91 // OperationReplaceAllUnsafe replaces all middlewares in a chain with those in the modification 92 // this operation is only meant to be used in tests. 93 OperationReplaceAllUnsafe 94 ) 95 96 // Names returns the names of the middlewares in a chain 97 func (mc *MiddlewareChain[T]) Names() *mapz.Set[string] { 98 names := mapz.NewSet[string]() 99 for _, mw := range mc.chain { 100 names.Insert(mw.Name) 101 } 102 return names 103 } 104 105 // ToGRPCInterceptors generates slices of gRPC interceptors ready to be installed in a server 106 func (mc *MiddlewareChain[T]) ToGRPCInterceptors() []T { 107 interceptors := make([]T, 0, len(mc.chain)) 108 for _, mw := range mc.chain { 109 interceptors = append(interceptors, mw.Middleware) 110 } 111 return interceptors 112 } 113 114 func (mc *MiddlewareChain[T]) prepend(mod MiddlewareModification[T]) error { 115 if err := mc.validate(mod); err != nil { 116 return err 117 } 118 119 newChain := make([]ReferenceableMiddleware[T], 0, len(mc.chain)) 120 for _, mw := range mc.chain { 121 if mw.Name == mod.DependencyMiddlewareName { 122 newChain = append(newChain, mod.Middlewares...) 123 } 124 newChain = append(newChain, mw) 125 } 126 mc.chain = newChain 127 return nil 128 } 129 130 func (mc *MiddlewareChain[T]) replace(mod MiddlewareModification[T]) error { 131 if err := mc.validate(mod); err != nil { 132 return err 133 } 134 newChain := make([]ReferenceableMiddleware[T], 0, len(mc.chain)) 135 for _, mw := range mc.chain { 136 if mw.Name == mod.DependencyMiddlewareName { 137 newChain = append(newChain, mod.Middlewares...) 138 } else { 139 newChain = append(newChain, mw) 140 } 141 } 142 mc.chain = newChain 143 return nil 144 } 145 146 func (mc *MiddlewareChain[T]) append(mod MiddlewareModification[T]) error { 147 if err := mc.validate(mod); err != nil { 148 return err 149 } 150 151 newChain := make([]ReferenceableMiddleware[T], 0, len(mc.chain)) 152 for _, mw := range mc.chain { 153 newChain = append(newChain, mw) 154 if mw.Name == mod.DependencyMiddlewareName { 155 newChain = append(newChain, mod.Middlewares...) 156 } 157 } 158 mc.chain = newChain 159 return nil 160 } 161 162 func (mc *MiddlewareChain[T]) replaceAll(mod MiddlewareModification[T]) error { 163 if err := mod.validate(); err != nil { 164 return err 165 } 166 mc.chain = mod.Middlewares 167 return nil 168 } 169 170 func (mc *MiddlewareChain[T]) validate(mod MiddlewareModification[T]) error { 171 if err := mod.validate(); err != nil { 172 return err 173 } 174 175 // prevent referencing non-existing middlewares 176 existingNames := mc.Names() 177 if !existingNames.Has(mod.DependencyMiddlewareName) { 178 return fmt.Errorf("referenced dependency does not exist on chain: %s", mod.DependencyMiddlewareName) 179 } 180 181 // prevent appending/prepending a duplicate middleware 182 for _, mw := range mod.Middlewares { 183 if existingNames.Has(mw.Name) && mod.DependencyMiddlewareName == mw.Name && mod.Operation != OperationReplace { 184 return fmt.Errorf("modification will cause a duplicate in chain: %s", mw.Name) 185 } 186 } 187 188 // prevent replacing an internal middleware 189 for _, mw := range mc.chain { 190 if mw.Internal && mw.Name == mod.DependencyMiddlewareName && mod.Operation == OperationReplace { 191 return fmt.Errorf("modification attempts to replace an internal middleware: %s", mw.Name) 192 } 193 } 194 return nil 195 } 196 197 func (mc *MiddlewareChain[T]) modify(modifications ...MiddlewareModification[T]) error { 198 for _, mod := range modifications { 199 switch mod.Operation { 200 case OperationPrepend: 201 if err := mc.prepend(mod); err != nil { 202 return err 203 } 204 case OperationReplace: 205 if err := mc.replace(mod); err != nil { 206 return err 207 } 208 case OperationReplaceAllUnsafe: 209 if err := mc.replaceAll(mod); err != nil { 210 return err 211 } 212 case OperationAppend: 213 if err := mc.append(mod); err != nil { 214 return err 215 } 216 } 217 } 218 return nil 219 } 220 221 type streamOrderAssertion struct { 222 grpc.ServerStream 223 name string 224 alreadyExecuted string 225 notExecuted string 226 } 227 228 func (o streamOrderAssertion) RecvMsg(m any) error { 229 if err := mustHaveExecuted(o.Context(), streamExecuted{}, o.alreadyExecuted); err != nil { 230 return err 231 } 232 233 if err := mustHaveNotExecuted(o.Context(), streamExecuted{}, o.notExecuted); err != nil { 234 return err 235 } 236 237 mustMarkAsExecuted(o.Context(), streamExecuted{}, o.name) 238 err := o.ServerStream.RecvMsg(m) 239 return err 240 } 241 242 func (o streamOrderAssertion) SendMsg(m any) error { 243 return o.ServerStream.SendMsg(m) 244 } 245 246 func NewStreamMiddleware() *StreamOrderEnforcerBuilder { 247 return &StreamOrderEnforcerBuilder{} 248 } 249 250 type StreamOrderEnforcerBuilder struct { 251 name string 252 streamInterceptor grpc.StreamServerInterceptor 253 internal bool 254 interceptorExecuted string 255 interceptorNotExecuted string 256 streamWrapperExecuted string 257 streamWrapperNotExecuted string 258 } 259 260 func (soeb *StreamOrderEnforcerBuilder) WithName(name string) *StreamOrderEnforcerBuilder { 261 soeb.name = name 262 return soeb 263 } 264 265 func (soeb *StreamOrderEnforcerBuilder) WithInterceptor(interceptor grpc.StreamServerInterceptor) *StreamOrderEnforcerBuilder { 266 soeb.streamInterceptor = interceptor 267 return soeb 268 } 269 270 func (soeb *StreamOrderEnforcerBuilder) WithInternal(internal bool) *StreamOrderEnforcerBuilder { 271 soeb.internal = internal 272 return soeb 273 } 274 275 func (soeb *StreamOrderEnforcerBuilder) EnsureWrapperAlreadyExecuted(name string) *StreamOrderEnforcerBuilder { 276 soeb.streamWrapperExecuted = name 277 return soeb 278 } 279 280 func (soeb *StreamOrderEnforcerBuilder) EnsureWrapperNotExecuted(name string) *StreamOrderEnforcerBuilder { 281 soeb.streamWrapperNotExecuted = name 282 return soeb 283 } 284 285 func (soeb *StreamOrderEnforcerBuilder) EnsureInterceptorAlreadyExecuted(name string) *StreamOrderEnforcerBuilder { 286 soeb.interceptorExecuted = name 287 return soeb 288 } 289 290 func (soeb *StreamOrderEnforcerBuilder) EnsureInterceptorNotExecuted(name string) *StreamOrderEnforcerBuilder { 291 soeb.interceptorNotExecuted = name 292 return soeb 293 } 294 295 func (soeb *StreamOrderEnforcerBuilder) Done() ReferenceableMiddleware[grpc.StreamServerInterceptor] { 296 if !spiceerrors.IsInTests() { 297 return ReferenceableMiddleware[grpc.StreamServerInterceptor]{ 298 Name: soeb.name, 299 Internal: soeb.internal, 300 Middleware: soeb.streamInterceptor, 301 } 302 } 303 304 return ReferenceableMiddleware[grpc.StreamServerInterceptor]{ 305 Name: soeb.name, 306 Internal: soeb.internal, 307 Middleware: func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 308 wss := middleware.WrapServerStream(ss) 309 if wss.WrappedContext.Value(streamExecuted{}) == nil { 310 handle := executedHandle{executed: make(map[string]struct{}, 0)} 311 wss.WrappedContext = context.WithValue(wss.WrappedContext, streamExecuted{}, &handle) 312 } 313 if wss.WrappedContext.Value(interceptorsExecuted{}) == nil { 314 handle := executedHandle{executed: make(map[string]struct{}, 0)} 315 wss.WrappedContext = context.WithValue(wss.WrappedContext, interceptorsExecuted{}, &handle) 316 } 317 318 if err := mustHaveExecuted(wss.WrappedContext, interceptorsExecuted{}, soeb.interceptorExecuted); err != nil { 319 return err 320 } 321 322 if err := mustHaveNotExecuted(wss.WrappedContext, interceptorsExecuted{}, soeb.interceptorNotExecuted); err != nil { 323 return err 324 } 325 326 mustMarkAsExecuted(wss.WrappedContext, interceptorsExecuted{}, soeb.name) 327 328 wrappedStream := streamOrderAssertion{ 329 ServerStream: wss, 330 name: soeb.name, 331 alreadyExecuted: soeb.streamWrapperExecuted, 332 notExecuted: soeb.streamWrapperNotExecuted, 333 } 334 return soeb.streamInterceptor(srv, wrappedStream, info, handler) 335 }, 336 } 337 } 338 339 func NewUnaryMiddleware() *UnaryOrderEnforcerBuilder { 340 return &UnaryOrderEnforcerBuilder{} 341 } 342 343 type UnaryOrderEnforcerBuilder struct { 344 name string 345 interceptor grpc.UnaryServerInterceptor 346 internal bool 347 alreadyExecuted string 348 notExecuted string 349 } 350 351 func (soeb *UnaryOrderEnforcerBuilder) WithName(name string) *UnaryOrderEnforcerBuilder { 352 soeb.name = name 353 return soeb 354 } 355 356 func (soeb *UnaryOrderEnforcerBuilder) WithInterceptor(interceptor grpc.UnaryServerInterceptor) *UnaryOrderEnforcerBuilder { 357 soeb.interceptor = interceptor 358 return soeb 359 } 360 361 func (soeb *UnaryOrderEnforcerBuilder) WithInternal(internal bool) *UnaryOrderEnforcerBuilder { 362 soeb.internal = internal 363 return soeb 364 } 365 366 func (soeb *UnaryOrderEnforcerBuilder) EnsureAlreadyExecuted(name string) *UnaryOrderEnforcerBuilder { 367 soeb.alreadyExecuted = name 368 return soeb 369 } 370 371 func (soeb *UnaryOrderEnforcerBuilder) EnsureNotExecuted(name string) *UnaryOrderEnforcerBuilder { 372 soeb.notExecuted = name 373 return soeb 374 } 375 376 func (soeb *UnaryOrderEnforcerBuilder) Done() ReferenceableMiddleware[grpc.UnaryServerInterceptor] { 377 if !spiceerrors.IsInTests() { 378 return ReferenceableMiddleware[grpc.UnaryServerInterceptor]{ 379 Name: soeb.name, 380 Internal: soeb.internal, 381 Middleware: soeb.interceptor, 382 } 383 } 384 385 return ReferenceableMiddleware[grpc.UnaryServerInterceptor]{ 386 Name: soeb.name, 387 Internal: soeb.internal, 388 Middleware: func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { 389 if ctx.Value(interceptorsExecuted{}) == nil { 390 handle := executedHandle{executed: make(map[string]struct{}, 0)} 391 ctx = context.WithValue(ctx, interceptorsExecuted{}, &handle) 392 } 393 394 if err := mustHaveExecuted(ctx, interceptorsExecuted{}, soeb.alreadyExecuted); err != nil { 395 return nil, err 396 } 397 398 if err := mustHaveNotExecuted(ctx, interceptorsExecuted{}, soeb.notExecuted); err != nil { 399 return nil, err 400 } 401 402 mustMarkAsExecuted(ctx, interceptorsExecuted{}, soeb.name) 403 return soeb.interceptor(ctx, req, info, handler) 404 }, 405 } 406 } 407 408 func mustHaveNotExecuted(ctx context.Context, handleKey any, notExecuted string) error { 409 if notExecuted == "" { 410 return nil 411 } 412 413 val := ctx.Value(handleKey) 414 if val == nil { 415 return fmt.Errorf("interception order validation bookkeeping not present in context") 416 } 417 418 handle := val.(*executedHandle) 419 if _, ok := handle.executed[notExecuted]; ok { 420 return fmt.Errorf("expected interceptor %s to be not already executed", notExecuted) 421 } 422 423 return nil 424 } 425 426 func mustHaveExecuted(ctx context.Context, handleKey any, expectedExecuted string) error { 427 if expectedExecuted == "" { 428 return nil 429 } 430 431 val := ctx.Value(handleKey) 432 if val == nil { 433 return spiceerrors.MustBugf("interception order validation bookkeeping not present in context") 434 } 435 436 handle := val.(*executedHandle) 437 if _, ok := handle.executed[expectedExecuted]; ok { 438 return nil 439 } 440 441 return fmt.Errorf("expected interceptor %s to be already executed", expectedExecuted) 442 } 443 444 func mustMarkAsExecuted(ctx context.Context, handleKey any, name string) { 445 val := ctx.Value(handleKey) 446 if val == nil { 447 panic("handle should exist") 448 } 449 450 handle := val.(*executedHandle) 451 handle.executed[name] = struct{}{} 452 } 453 454 type executedHandle struct { 455 executed map[string]struct{} 456 } 457 458 type interceptorsExecuted struct{} 459 460 type streamExecuted struct{}