github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/middleware/consistency/consistency.go (about)

     1  package consistency
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"strings"
     8  
     9  	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
    10  	"github.com/prometheus/client_golang/prometheus"
    11  	"github.com/prometheus/client_golang/prometheus/promauto"
    12  	"google.golang.org/grpc"
    13  	"google.golang.org/grpc/codes"
    14  	"google.golang.org/grpc/status"
    15  
    16  	log "github.com/authzed/spicedb/internal/logging"
    17  	datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
    18  	"github.com/authzed/spicedb/internal/services/shared"
    19  	"github.com/authzed/spicedb/pkg/cursor"
    20  	"github.com/authzed/spicedb/pkg/datastore"
    21  	"github.com/authzed/spicedb/pkg/zedtoken"
    22  )
    23  
    24  var ConsistentyCounter = promauto.NewCounterVec(prometheus.CounterOpts{
    25  	Namespace: "spicedb",
    26  	Subsystem: "middleware",
    27  	Name:      "consistency_assigned_total",
    28  	Help:      "Count of the consistencies used per request",
    29  }, []string{"method", "source"})
    30  
    31  type hasConsistency interface{ GetConsistency() *v1.Consistency }
    32  
    33  type hasOptionalCursor interface{ GetOptionalCursor() *v1.Cursor }
    34  
    35  type ctxKeyType struct{}
    36  
    37  var revisionKey ctxKeyType = struct{}{}
    38  
    39  var errInvalidZedToken = errors.New("invalid revision requested")
    40  
    41  type revisionHandle struct {
    42  	revision datastore.Revision
    43  }
    44  
    45  // ContextWithHandle adds a placeholder to a context that will later be
    46  // filled by the revision
    47  func ContextWithHandle(ctx context.Context) context.Context {
    48  	return context.WithValue(ctx, revisionKey, &revisionHandle{})
    49  }
    50  
    51  // RevisionFromContext reads the selected revision out of a context.Context, computes a zedtoken
    52  // from it, and returns an error if it has not been set on the context.
    53  func RevisionFromContext(ctx context.Context) (datastore.Revision, *v1.ZedToken, error) {
    54  	if c := ctx.Value(revisionKey); c != nil {
    55  		handle := c.(*revisionHandle)
    56  		rev := handle.revision
    57  		if rev != nil {
    58  			return rev, zedtoken.MustNewFromRevision(rev), nil
    59  		}
    60  	}
    61  
    62  	return nil, nil, fmt.Errorf("consistency middleware did not inject revision")
    63  }
    64  
    65  // AddRevisionToContext adds a revision to the given context, based on the consistency block found
    66  // in the given request (if applicable).
    67  func AddRevisionToContext(ctx context.Context, req interface{}, ds datastore.Datastore) error {
    68  	switch req := req.(type) {
    69  	case hasConsistency:
    70  		return addRevisionToContextFromConsistency(ctx, req, ds)
    71  	default:
    72  		return nil
    73  	}
    74  }
    75  
    76  // addRevisionToContextFromConsistency adds a revision to the given context, based on the consistency block found
    77  // in the given request (if applicable).
    78  func addRevisionToContextFromConsistency(ctx context.Context, req hasConsistency, ds datastore.Datastore) error {
    79  	handle := ctx.Value(revisionKey)
    80  	if handle == nil {
    81  		return nil
    82  	}
    83  
    84  	var revision datastore.Revision
    85  	consistency := req.GetConsistency()
    86  
    87  	withOptionalCursor, hasOptionalCursor := req.(hasOptionalCursor)
    88  
    89  	switch {
    90  	case hasOptionalCursor && withOptionalCursor.GetOptionalCursor() != nil:
    91  		// Always use the revision encoded in the cursor.
    92  		ConsistentyCounter.WithLabelValues("snapshot", "cursor").Inc()
    93  
    94  		requestedRev, err := cursor.DecodeToDispatchRevision(withOptionalCursor.GetOptionalCursor(), ds)
    95  		if err != nil {
    96  			return rewriteDatastoreError(ctx, err)
    97  		}
    98  
    99  		err = ds.CheckRevision(ctx, requestedRev)
   100  		if err != nil {
   101  			return rewriteDatastoreError(ctx, err)
   102  		}
   103  
   104  		revision = requestedRev
   105  
   106  	case consistency == nil || consistency.GetMinimizeLatency():
   107  		// Minimize Latency: Use the datastore's current revision, whatever it may be.
   108  		source := "request"
   109  		if consistency == nil {
   110  			source = "server"
   111  		}
   112  		ConsistentyCounter.WithLabelValues("minlatency", source).Inc()
   113  
   114  		databaseRev, err := ds.OptimizedRevision(ctx)
   115  		if err != nil {
   116  			return rewriteDatastoreError(ctx, err)
   117  		}
   118  		revision = databaseRev
   119  
   120  	case consistency.GetFullyConsistent():
   121  		// Fully Consistent: Use the datastore's synchronized revision.
   122  		ConsistentyCounter.WithLabelValues("full", "request").Inc()
   123  
   124  		databaseRev, err := ds.HeadRevision(ctx)
   125  		if err != nil {
   126  			return rewriteDatastoreError(ctx, err)
   127  		}
   128  		revision = databaseRev
   129  
   130  	case consistency.GetAtLeastAsFresh() != nil:
   131  		// At least as fresh as: Pick one of the datastore's revision and that specified, which
   132  		// ever is later.
   133  		picked, pickedRequest, err := pickBestRevision(ctx, consistency.GetAtLeastAsFresh(), ds)
   134  		if err != nil {
   135  			return rewriteDatastoreError(ctx, err)
   136  		}
   137  
   138  		source := "server"
   139  		if pickedRequest {
   140  			source = "request"
   141  		}
   142  		ConsistentyCounter.WithLabelValues("atleast", source).Inc()
   143  
   144  		revision = picked
   145  
   146  	case consistency.GetAtExactSnapshot() != nil:
   147  		// Exact snapshot: Use the revision as encoded in the zed token.
   148  		ConsistentyCounter.WithLabelValues("snapshot", "request").Inc()
   149  
   150  		requestedRev, err := zedtoken.DecodeRevision(consistency.GetAtExactSnapshot(), ds)
   151  		if err != nil {
   152  			return errInvalidZedToken
   153  		}
   154  
   155  		err = ds.CheckRevision(ctx, requestedRev)
   156  		if err != nil {
   157  			return rewriteDatastoreError(ctx, err)
   158  		}
   159  
   160  		revision = requestedRev
   161  
   162  	default:
   163  		return fmt.Errorf("missing handling of consistency case in %v", consistency)
   164  	}
   165  
   166  	handle.(*revisionHandle).revision = revision
   167  	return nil
   168  }
   169  
   170  var bypassServiceWhitelist = map[string]struct{}{
   171  	"/grpc.reflection.v1alpha.ServerReflection/": {},
   172  	"/grpc.reflection.v1.ServerReflection/":      {},
   173  	"/grpc.health.v1.Health/":                    {},
   174  }
   175  
   176  // UnaryServerInterceptor returns a new unary server interceptor that performs per-request exchange of
   177  // the specified consistency configuration for the revision at which to perform the request.
   178  func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
   179  	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
   180  		for bypass := range bypassServiceWhitelist {
   181  			if strings.HasPrefix(info.FullMethod, bypass) {
   182  				return handler(ctx, req)
   183  			}
   184  		}
   185  		ds := datastoremw.MustFromContext(ctx)
   186  		newCtx := ContextWithHandle(ctx)
   187  		if err := AddRevisionToContext(newCtx, req, ds); err != nil {
   188  			return nil, err
   189  		}
   190  
   191  		return handler(newCtx, req)
   192  	}
   193  }
   194  
   195  // StreamServerInterceptor returns a new stream server interceptor that performs per-request exchange of
   196  // the specified consistency configuration for the revision at which to perform the request.
   197  func StreamServerInterceptor() grpc.StreamServerInterceptor {
   198  	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   199  		for bypass := range bypassServiceWhitelist {
   200  			if strings.HasPrefix(info.FullMethod, bypass) {
   201  				return handler(srv, stream)
   202  			}
   203  		}
   204  		wrapper := &recvWrapper{stream, ContextWithHandle(stream.Context())}
   205  		return handler(srv, wrapper)
   206  	}
   207  }
   208  
   209  type recvWrapper struct {
   210  	grpc.ServerStream
   211  	ctx context.Context
   212  }
   213  
   214  func (s *recvWrapper) Context() context.Context { return s.ctx }
   215  
   216  func (s *recvWrapper) RecvMsg(m interface{}) error {
   217  	if err := s.ServerStream.RecvMsg(m); err != nil {
   218  		return err
   219  	}
   220  	ds := datastoremw.MustFromContext(s.ctx)
   221  
   222  	return AddRevisionToContext(s.ctx, m, ds)
   223  }
   224  
   225  // pickBestRevision compares the provided ZedToken with the optimized revision of the datastore, and returns the most
   226  // recent one. The boolean return value will be true if the provided ZedToken is the most recent, false otherwise.
   227  func pickBestRevision(ctx context.Context, requested *v1.ZedToken, ds datastore.Datastore) (datastore.Revision, bool, error) {
   228  	// Calculate a revision as we see fit
   229  	databaseRev, err := ds.OptimizedRevision(ctx)
   230  	if err != nil {
   231  		return datastore.NoRevision, false, err
   232  	}
   233  
   234  	if requested != nil {
   235  		requestedRev, err := zedtoken.DecodeRevision(requested, ds)
   236  		if err != nil {
   237  			return datastore.NoRevision, false, errInvalidZedToken
   238  		}
   239  
   240  		if databaseRev.GreaterThan(requestedRev) {
   241  			return databaseRev, false, nil
   242  		}
   243  
   244  		return requestedRev, true, nil
   245  	}
   246  
   247  	return databaseRev, false, nil
   248  }
   249  
   250  func rewriteDatastoreError(ctx context.Context, err error) error {
   251  	// Check if the error can be directly used.
   252  	if _, ok := status.FromError(err); ok {
   253  		return err
   254  	}
   255  
   256  	switch {
   257  	case errors.As(err, &datastore.ErrInvalidRevision{}):
   258  		return status.Errorf(codes.OutOfRange, "invalid revision: %s", err)
   259  
   260  	case errors.As(err, &datastore.ErrReadOnly{}):
   261  		return shared.ErrServiceReadOnly
   262  
   263  	default:
   264  		log.Ctx(ctx).Err(err).Msg("unexpected consistency middleware error")
   265  		return err
   266  	}
   267  }