github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/services/v1/watch.go (about)

     1  package v1
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"time"
     7  
     8  	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
     9  	grpcvalidate "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
    10  	"google.golang.org/grpc/codes"
    11  	"google.golang.org/grpc/status"
    12  
    13  	datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
    14  	"github.com/authzed/spicedb/internal/middleware/usagemetrics"
    15  	"github.com/authzed/spicedb/internal/services/shared"
    16  	"github.com/authzed/spicedb/pkg/datastore"
    17  	"github.com/authzed/spicedb/pkg/genutil/mapz"
    18  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    19  	dispatchv1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    20  	"github.com/authzed/spicedb/pkg/tuple"
    21  	"github.com/authzed/spicedb/pkg/zedtoken"
    22  )
    23  
    24  type watchServer struct {
    25  	v1.UnimplementedWatchServiceServer
    26  	shared.WithStreamServiceSpecificInterceptor
    27  
    28  	heartbeatDuration time.Duration
    29  }
    30  
    31  // NewWatchServer creates an instance of the watch server.
    32  func NewWatchServer(heartbeatDuration time.Duration) v1.WatchServiceServer {
    33  	s := &watchServer{
    34  		WithStreamServiceSpecificInterceptor: shared.WithStreamServiceSpecificInterceptor{
    35  			Stream: grpcvalidate.StreamServerInterceptor(),
    36  		},
    37  		heartbeatDuration: heartbeatDuration,
    38  	}
    39  	return s
    40  }
    41  
    42  func (ws *watchServer) Watch(req *v1.WatchRequest, stream v1.WatchService_WatchServer) error {
    43  	if len(req.GetOptionalObjectTypes()) > 0 && len(req.OptionalRelationshipFilters) > 0 {
    44  		return status.Errorf(codes.InvalidArgument, "cannot specify both object types and relationship filters")
    45  	}
    46  
    47  	objectTypes := mapz.NewSet[string](req.GetOptionalObjectTypes()...)
    48  	filters := make([]datastore.RelationshipsFilter, 0, len(req.OptionalRelationshipFilters))
    49  
    50  	ctx := stream.Context()
    51  	ds := datastoremw.MustFromContext(ctx)
    52  
    53  	var afterRevision datastore.Revision
    54  	if req.OptionalStartCursor != nil && req.OptionalStartCursor.Token != "" {
    55  		decodedRevision, err := zedtoken.DecodeRevision(req.OptionalStartCursor, ds)
    56  		if err != nil {
    57  			return status.Errorf(codes.InvalidArgument, "failed to decode start revision: %s", err)
    58  		}
    59  
    60  		afterRevision = decodedRevision
    61  	} else {
    62  		var err error
    63  		afterRevision, err = ds.OptimizedRevision(ctx)
    64  		if err != nil {
    65  			return status.Errorf(codes.Unavailable, "failed to start watch: %s", err)
    66  		}
    67  	}
    68  
    69  	reader := ds.SnapshotReader(afterRevision)
    70  
    71  	for _, filter := range req.OptionalRelationshipFilters {
    72  		if err := validateRelationshipsFilter(stream.Context(), filter, reader); err != nil {
    73  			return ws.rewriteError(ctx, err)
    74  		}
    75  
    76  		dsFilter, err := datastore.RelationshipsFilterFromPublicFilter(filter)
    77  		if err != nil {
    78  			return status.Errorf(codes.InvalidArgument, "failed to parse relationship filter: %s", err)
    79  		}
    80  
    81  		filters = append(filters, dsFilter)
    82  	}
    83  
    84  	usagemetrics.SetInContext(ctx, &dispatchv1.ResponseMeta{
    85  		DispatchCount: 1,
    86  	})
    87  
    88  	updates, errchan := ds.Watch(ctx, afterRevision, datastore.WatchOptions{
    89  		Content:            datastore.WatchRelationships,
    90  		CheckpointInterval: ws.heartbeatDuration,
    91  	})
    92  	for {
    93  		select {
    94  		case update, ok := <-updates:
    95  			if ok {
    96  				filtered := filterUpdates(objectTypes, filters, update.RelationshipChanges)
    97  				if len(filtered) > 0 {
    98  					if err := stream.Send(&v1.WatchResponse{
    99  						Updates:        filtered,
   100  						ChangesThrough: zedtoken.MustNewFromRevision(update.Revision),
   101  					}); err != nil {
   102  						return status.Errorf(codes.Canceled, "watch canceled by user: %s", err)
   103  					}
   104  				}
   105  			}
   106  		case err := <-errchan:
   107  			switch {
   108  			case errors.As(err, &datastore.ErrWatchCanceled{}):
   109  				return status.Errorf(codes.Canceled, "watch canceled by user: %s", err)
   110  			case errors.As(err, &datastore.ErrWatchDisconnected{}):
   111  				return status.Errorf(codes.ResourceExhausted, "watch disconnected: %s", err)
   112  			default:
   113  				return status.Errorf(codes.Internal, "watch error: %s", err)
   114  			}
   115  		}
   116  	}
   117  }
   118  
   119  func (ws *watchServer) rewriteError(ctx context.Context, err error) error {
   120  	return shared.RewriteError(ctx, err, &shared.ConfigForErrors{})
   121  }
   122  
   123  func filterUpdates(objectTypes *mapz.Set[string], filters []datastore.RelationshipsFilter, candidates []*core.RelationTupleUpdate) []*v1.RelationshipUpdate {
   124  	updates := tuple.UpdatesToRelationshipUpdates(candidates)
   125  
   126  	if objectTypes.IsEmpty() && len(filters) == 0 {
   127  		return updates
   128  	}
   129  
   130  	filtered := make([]*v1.RelationshipUpdate, 0, len(updates))
   131  	for _, update := range updates {
   132  		objectType := update.GetRelationship().GetResource().GetObjectType()
   133  		if !objectTypes.IsEmpty() && !objectTypes.Has(objectType) {
   134  			continue
   135  		}
   136  
   137  		if len(filters) > 0 {
   138  			// If there are filters, we need to check if the update matches any of them.
   139  			matched := false
   140  			for _, filter := range filters {
   141  				// TODO(jschorr): Maybe we should add TestRelationship to avoid the conversion?
   142  				if filter.Test(tuple.MustFromRelationship(update.GetRelationship())) {
   143  					matched = true
   144  					break
   145  				}
   146  			}
   147  
   148  			if !matched {
   149  				continue
   150  			}
   151  		}
   152  
   153  		filtered = append(filtered, update)
   154  	}
   155  
   156  	return filtered
   157  }