github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/services/dispatch/v1/acl.go (about)

     1  package dispatch
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"time"
     7  
     8  	"github.com/authzed/spicedb/internal/middleware/streamtimeout"
     9  
    10  	"github.com/authzed/spicedb/internal/middleware"
    11  
    12  	grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
    13  	"google.golang.org/grpc/codes"
    14  	"google.golang.org/grpc/status"
    15  
    16  	"github.com/authzed/spicedb/internal/dispatch"
    17  	"github.com/authzed/spicedb/internal/graph"
    18  	log "github.com/authzed/spicedb/internal/logging"
    19  	"github.com/authzed/spicedb/internal/services/shared"
    20  	dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    21  )
    22  
    // streamAPITimeout is the timeout handed to
    // streamtimeout.MustStreamServerInterceptor for this service's streaming RPCs.
    const streamAPITimeout time.Duration = 45 * time.Second
    24  
    25  type dispatchServer struct {
    26  	dispatchv1.UnimplementedDispatchServiceServer
    27  	shared.WithServiceSpecificInterceptors
    28  
    29  	localDispatch dispatch.Dispatcher
    30  }
    31  
    32  // NewDispatchServer creates a server which can be called for internal dispatch.
    33  func NewDispatchServer(localDispatch dispatch.Dispatcher) dispatchv1.DispatchServiceServer {
    34  	return &dispatchServer{
    35  		localDispatch: localDispatch,
    36  		WithServiceSpecificInterceptors: shared.WithServiceSpecificInterceptors{
    37  			Unary: grpcvalidate.UnaryServerInterceptor(),
    38  			Stream: middleware.ChainStreamServer(
    39  				grpcvalidate.StreamServerInterceptor(),
    40  				streamtimeout.MustStreamServerInterceptor(streamAPITimeout),
    41  			),
    42  		},
    43  	}
    44  }
    45  
    46  func (ds *dispatchServer) DispatchCheck(ctx context.Context, req *dispatchv1.DispatchCheckRequest) (*dispatchv1.DispatchCheckResponse, error) {
    47  	resp, err := ds.localDispatch.DispatchCheck(ctx, req)
    48  	return resp, rewriteGraphError(ctx, err)
    49  }
    50  
    51  func (ds *dispatchServer) DispatchExpand(ctx context.Context, req *dispatchv1.DispatchExpandRequest) (*dispatchv1.DispatchExpandResponse, error) {
    52  	resp, err := ds.localDispatch.DispatchExpand(ctx, req)
    53  	return resp, rewriteGraphError(ctx, err)
    54  }
    55  
    56  func (ds *dispatchServer) DispatchReachableResources(
    57  	req *dispatchv1.DispatchReachableResourcesRequest,
    58  	resp dispatchv1.DispatchService_DispatchReachableResourcesServer,
    59  ) error {
    60  	return ds.localDispatch.DispatchReachableResources(req,
    61  		dispatch.WrapGRPCStream[*dispatchv1.DispatchReachableResourcesResponse](resp))
    62  }
    63  
    64  func (ds *dispatchServer) DispatchLookupResources(
    65  	req *dispatchv1.DispatchLookupResourcesRequest,
    66  	resp dispatchv1.DispatchService_DispatchLookupResourcesServer,
    67  ) error {
    68  	return ds.localDispatch.DispatchLookupResources(req,
    69  		dispatch.WrapGRPCStream[*dispatchv1.DispatchLookupResourcesResponse](resp))
    70  }
    71  
    72  func (ds *dispatchServer) DispatchLookupSubjects(
    73  	req *dispatchv1.DispatchLookupSubjectsRequest,
    74  	resp dispatchv1.DispatchService_DispatchLookupSubjectsServer,
    75  ) error {
    76  	return ds.localDispatch.DispatchLookupSubjects(req,
    77  		dispatch.WrapGRPCStream[*dispatchv1.DispatchLookupSubjectsResponse](resp))
    78  }
    79  
    80  func (ds *dispatchServer) Close() error {
    81  	return nil
    82  }
    83  
    84  func rewriteGraphError(ctx context.Context, err error) error {
    85  	// Check if the error can be directly used.
    86  	if st, ok := status.FromError(err); ok {
    87  		return st.Err()
    88  	}
    89  
    90  	switch {
    91  	case errors.Is(err, context.DeadlineExceeded):
    92  		return status.Errorf(codes.DeadlineExceeded, "%s", err)
    93  	case errors.Is(err, context.Canceled):
    94  		return status.Errorf(codes.Canceled, "%s", err)
    95  	case status.Code(err) == codes.Canceled:
    96  		return err
    97  	case err == nil:
    98  		return nil
    99  
   100  	case errors.As(err, &graph.ErrAlwaysFail{}):
   101  		fallthrough
   102  	default:
   103  		log.Ctx(ctx).Err(err).Msg("unexpected dispatch graph error")
   104  		return err
   105  	}
   106  }