github.com/openfga/openfga@v1.5.4-rc1/pkg/middleware/validator/validator.go (about)

     1  package validator
     2  
     3  import (
     4  	"context"
     5  
     6  	grpcvalidator "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
     7  	"google.golang.org/grpc"
     8  )
     9  
    10  type ctxKey string
    11  
    12  var (
    13  	requestIsValidatedCtxKey = ctxKey("request-validated")
    14  )
    15  
    16  func contextWithRequestIsValidated(ctx context.Context) context.Context {
    17  	return context.WithValue(ctx, requestIsValidatedCtxKey, true)
    18  }
    19  
    20  // RequestIsValidatedFromContext returns true if the provided context object has the flag
    21  // indicating that the request has been validated and if its value is set to true.
    22  func RequestIsValidatedFromContext(ctx context.Context) bool {
    23  	validated, ok := ctx.Value(requestIsValidatedCtxKey).(bool)
    24  	return validated && ok
    25  }
    26  
    27  // UnaryServerInterceptor returns a new unary server interceptor that runs request validations
    28  // and injects a bool in the context indicating that validation has been run.
    29  func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
    30  	validator := grpcvalidator.UnaryServerInterceptor()
    31  
    32  	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    33  		return validator(ctx, req, info, func(ctx context.Context, req interface{}) (interface{}, error) {
    34  			return handler(contextWithRequestIsValidated(ctx), req)
    35  		})
    36  	}
    37  }
    38  
    39  // StreamServerInterceptor returns a new streaming server interceptor that runs request validations
    40  // and injects a bool in the context indicating that validation has been run.
    41  func StreamServerInterceptor() grpc.StreamServerInterceptor {
    42  	validator := grpcvalidator.StreamServerInterceptor()
    43  
    44  	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    45  		return validator(srv, stream, info, func(srv interface{}, ss grpc.ServerStream) error {
    46  			return handler(srv, &recvWrapper{
    47  				ctx:          contextWithRequestIsValidated(stream.Context()),
    48  				ServerStream: ss,
    49  			})
    50  		})
    51  	}
    52  }
    53  
    54  type recvWrapper struct {
    55  	ctx context.Context
    56  	grpc.ServerStream
    57  }
    58  
    59  // Context returns the context associated with the recvWrapper.
    60  func (r *recvWrapper) Context() context.Context {
    61  	return r.ctx
    62  }