github.com/openfga/openfga@v1.5.4-rc1/pkg/middleware/validator/validator.go (about) 1 package validator 2 3 import ( 4 "context" 5 6 grpcvalidator "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" 7 "google.golang.org/grpc" 8 ) 9 10 type ctxKey string 11 12 var ( 13 requestIsValidatedCtxKey = ctxKey("request-validated") 14 ) 15 16 func contextWithRequestIsValidated(ctx context.Context) context.Context { 17 return context.WithValue(ctx, requestIsValidatedCtxKey, true) 18 } 19 20 // RequestIsValidatedFromContext returns true if the provided context object has the flag 21 // indicating that the request has been validated and if its value is set to true. 22 func RequestIsValidatedFromContext(ctx context.Context) bool { 23 validated, ok := ctx.Value(requestIsValidatedCtxKey).(bool) 24 return validated && ok 25 } 26 27 // UnaryServerInterceptor returns a new unary server interceptor that runs request validations 28 // and injects a bool in the context indicating that validation has been run. 29 func UnaryServerInterceptor() grpc.UnaryServerInterceptor { 30 validator := grpcvalidator.UnaryServerInterceptor() 31 32 return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 33 return validator(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) { 34 return handler(contextWithRequestIsValidated(ctx), req) 35 }) 36 } 37 } 38 39 // StreamServerInterceptor returns a new streaming server interceptor that runs request validations 40 // and injects a bool in the context indicating that validation has been run. 41 func StreamServerInterceptor() grpc.StreamServerInterceptor { 42 validator := grpcvalidator.StreamServerInterceptor() 43 44 return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 45 return validator(srv, stream, info, func(srv interface{}, ss grpc.ServerStream) error { 46 return handler(srv, &recvWrapper{ 47 ctx: contextWithRequestIsValidated(stream.Context()), 48 ServerStream: ss, 49 }) 50 }) 51 } 52 } 53 54 type recvWrapper struct { 55 ctx context.Context 56 grpc.ServerStream 57 } 58 59 // Context returns the context associated with the recvWrapper. 60 func (r *recvWrapper) Context() context.Context { 61 return r.ctx 62 }