github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/middleware/handwrittenvalidation/handwrittenvalidation.go (about)

     1  package handwrittenvalidation
     2  
     3  import (
     4  	"context"
     5  
     6  	"google.golang.org/grpc"
     7  	"google.golang.org/grpc/codes"
     8  	"google.golang.org/grpc/status"
     9  )
    10  
    11  type handwrittenValidator interface {
    12  	HandwrittenValidate() error
    13  }
    14  
    15  // UnaryServerInterceptor returns a new unary server interceptor that runs the handwritten validation
    16  // on the incoming request, if any.
    17  func UnaryServerInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    18  	validator, ok := req.(handwrittenValidator)
    19  	if ok {
    20  		err := validator.HandwrittenValidate()
    21  		if err != nil {
    22  			return nil, status.Errorf(codes.InvalidArgument, "%s", err)
    23  		}
    24  	}
    25  
    26  	return handler(ctx, req)
    27  }
    28  
    29  // StreamServerInterceptor returns a new stream server interceptor that runs the handwritten validation
    30  // on the incoming request messages, if any.
    31  func StreamServerInterceptor(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    32  	wrapper := &recvWrapper{stream}
    33  	return handler(srv, wrapper)
    34  }
    35  
    36  type recvWrapper struct {
    37  	grpc.ServerStream
    38  }
    39  
    40  func (s *recvWrapper) RecvMsg(m interface{}) error {
    41  	if err := s.ServerStream.RecvMsg(m); err != nil {
    42  		return err
    43  	}
    44  
    45  	validator, ok := m.(handwrittenValidator)
    46  	if ok {
    47  		err := validator.HandwrittenValidate()
    48  		if err != nil {
    49  			return err
    50  		}
    51  	}
    52  
    53  	return nil
    54  }