github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/dispatch/stream.go (about)

     1  package dispatch
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"sync/atomic"
     7  
     8  	grpc "google.golang.org/grpc"
     9  )
    10  
    11  // Stream defines the interface generically matching a streaming dispatch response.
    12  type Stream[T any] interface {
    13  	// Publish publishes the result to the stream.
    14  	Publish(T) error
    15  
    16  	// Context returns the context for the stream.
    17  	Context() context.Context
    18  }
    19  
    20  type grpcStream[T any] interface {
    21  	grpc.ServerStream
    22  	Send(T) error
    23  }
    24  
    25  // WrapGRPCStream wraps a gRPC result stream with a concurrent-safe dispatch stream. This is
    26  // necessary because gRPC response streams are *not concurrent safe*.
    27  // See: https://groups.google.com/g/grpc-io/c/aI6L6M4fzQ0?pli=1
    28  func WrapGRPCStream[R any, S grpcStream[R]](grpcStream S) Stream[R] {
    29  	return &concurrentSafeStream[R]{
    30  		grpcStream: grpcStream,
    31  		mu:         sync.Mutex{},
    32  	}
    33  }
    34  
    35  type concurrentSafeStream[T any] struct {
    36  	grpcStream grpcStream[T]
    37  	mu         sync.Mutex
    38  }
    39  
    40  func (s *concurrentSafeStream[T]) Context() context.Context {
    41  	return s.grpcStream.Context()
    42  }
    43  
    44  func (s *concurrentSafeStream[T]) Publish(result T) error {
    45  	s.mu.Lock()
    46  	defer s.mu.Unlock()
    47  	return s.grpcStream.Send(result)
    48  }
    49  
    50  // NewCollectingDispatchStream creates a new CollectingDispatchStream.
    51  func NewCollectingDispatchStream[T any](ctx context.Context) *CollectingDispatchStream[T] {
    52  	return &CollectingDispatchStream[T]{
    53  		ctx:     ctx,
    54  		results: nil,
    55  		mu:      sync.Mutex{},
    56  	}
    57  }
    58  
    59  // CollectingDispatchStream is a dispatch stream that collects results in memory.
    60  type CollectingDispatchStream[T any] struct {
    61  	ctx     context.Context
    62  	results []T
    63  	mu      sync.Mutex
    64  }
    65  
    66  func (s *CollectingDispatchStream[T]) Context() context.Context {
    67  	return s.ctx
    68  }
    69  
    70  func (s *CollectingDispatchStream[T]) Results() []T {
    71  	return s.results
    72  }
    73  
    74  func (s *CollectingDispatchStream[T]) Publish(result T) error {
    75  	s.mu.Lock()
    76  	defer s.mu.Unlock()
    77  	s.results = append(s.results, result)
    78  	return nil
    79  }
    80  
    81  // WrappedDispatchStream is a dispatch stream that wraps another dispatch stream, and performs
    82  // an operation on each result before puppeting back up to the parent stream.
    83  type WrappedDispatchStream[T any] struct {
    84  	Stream    Stream[T]
    85  	Ctx       context.Context
    86  	Processor func(result T) (T, bool, error)
    87  }
    88  
    89  func (s *WrappedDispatchStream[T]) Publish(result T) error {
    90  	if s.Processor == nil {
    91  		return s.Stream.Publish(result)
    92  	}
    93  
    94  	processed, ok, err := s.Processor(result)
    95  	if err != nil {
    96  		return err
    97  	}
    98  	if !ok {
    99  		return nil
   100  	}
   101  
   102  	return s.Stream.Publish(processed)
   103  }
   104  
   105  func (s *WrappedDispatchStream[T]) Context() context.Context {
   106  	return s.Ctx
   107  }
   108  
   109  // StreamWithContext returns the given dispatch stream, wrapped to return the given context.
   110  func StreamWithContext[T any](context context.Context, stream Stream[T]) Stream[T] {
   111  	return &WrappedDispatchStream[T]{
   112  		Stream:    stream,
   113  		Ctx:       context,
   114  		Processor: nil,
   115  	}
   116  }
   117  
   118  // HandlingDispatchStream is a dispatch stream that executes a handler for each item published.
   119  // It uses an internal mutex to ensure it is thread safe.
   120  type HandlingDispatchStream[T any] struct {
   121  	ctx       context.Context
   122  	processor func(result T) error
   123  	mu        sync.Mutex
   124  }
   125  
   126  // NewHandlingDispatchStream returns a new handling dispatch stream.
   127  func NewHandlingDispatchStream[T any](ctx context.Context, processor func(result T) error) Stream[T] {
   128  	return &HandlingDispatchStream[T]{
   129  		ctx:       ctx,
   130  		processor: processor,
   131  		mu:        sync.Mutex{},
   132  	}
   133  }
   134  
   135  func (s *HandlingDispatchStream[T]) Publish(result T) error {
   136  	s.mu.Lock()
   137  	defer s.mu.Unlock()
   138  
   139  	if s.processor == nil {
   140  		return nil
   141  	}
   142  
   143  	return s.processor(result)
   144  }
   145  
   146  func (s *HandlingDispatchStream[T]) Context() context.Context {
   147  	return s.ctx
   148  }
   149  
   150  // CountingDispatchStream is a dispatch stream that counts the number of items published.
   151  // It uses an internal atomic int to ensure it is thread safe.
   152  type CountingDispatchStream[T any] struct {
   153  	Stream Stream[T]
   154  	count  *atomic.Uint64
   155  }
   156  
   157  func NewCountingDispatchStream[T any](wrapped Stream[T]) *CountingDispatchStream[T] {
   158  	return &CountingDispatchStream[T]{
   159  		Stream: wrapped,
   160  		count:  &atomic.Uint64{},
   161  	}
   162  }
   163  
   164  func (s *CountingDispatchStream[T]) PublishedCount() uint64 {
   165  	return s.count.Load()
   166  }
   167  
   168  func (s *CountingDispatchStream[T]) Publish(result T) error {
   169  	err := s.Stream.Publish(result)
   170  	if err != nil {
   171  		return err
   172  	}
   173  
   174  	s.count.Add(1)
   175  	return nil
   176  }
   177  
   178  func (s *CountingDispatchStream[T]) Context() context.Context {
   179  	return s.Stream.Context()
   180  }
   181  
   182  // Ensure the streams implement the interface.
   183  var (
   184  	_ Stream[any] = &CollectingDispatchStream[any]{}
   185  	_ Stream[any] = &WrappedDispatchStream[any]{}
   186  	_ Stream[any] = &CountingDispatchStream[any]{}
   187  )