github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/dispatch/stream.go (about) 1 package dispatch 2 3 import ( 4 "context" 5 "sync" 6 "sync/atomic" 7 8 grpc "google.golang.org/grpc" 9 ) 10 11 // Stream defines the interface generically matching a streaming dispatch response. 12 type Stream[T any] interface { 13 // Publish publishes the result to the stream. 14 Publish(T) error 15 16 // Context returns the context for the stream. 17 Context() context.Context 18 } 19 20 type grpcStream[T any] interface { 21 grpc.ServerStream 22 Send(T) error 23 } 24 25 // WrapGRPCStream wraps a gRPC result stream with a concurrent-safe dispatch stream. This is 26 // necessary because gRPC response streams are *not concurrent safe*. 27 // See: https://groups.google.com/g/grpc-io/c/aI6L6M4fzQ0?pli=1 28 func WrapGRPCStream[R any, S grpcStream[R]](grpcStream S) Stream[R] { 29 return &concurrentSafeStream[R]{ 30 grpcStream: grpcStream, 31 mu: sync.Mutex{}, 32 } 33 } 34 35 type concurrentSafeStream[T any] struct { 36 grpcStream grpcStream[T] 37 mu sync.Mutex 38 } 39 40 func (s *concurrentSafeStream[T]) Context() context.Context { 41 return s.grpcStream.Context() 42 } 43 44 func (s *concurrentSafeStream[T]) Publish(result T) error { 45 s.mu.Lock() 46 defer s.mu.Unlock() 47 return s.grpcStream.Send(result) 48 } 49 50 // NewCollectingDispatchStream creates a new CollectingDispatchStream. 51 func NewCollectingDispatchStream[T any](ctx context.Context) *CollectingDispatchStream[T] { 52 return &CollectingDispatchStream[T]{ 53 ctx: ctx, 54 results: nil, 55 mu: sync.Mutex{}, 56 } 57 } 58 59 // CollectingDispatchStream is a dispatch stream that collects results in memory. 
60 type CollectingDispatchStream[T any] struct { 61 ctx context.Context 62 results []T 63 mu sync.Mutex 64 } 65 66 func (s *CollectingDispatchStream[T]) Context() context.Context { 67 return s.ctx 68 } 69 70 func (s *CollectingDispatchStream[T]) Results() []T { 71 return s.results 72 } 73 74 func (s *CollectingDispatchStream[T]) Publish(result T) error { 75 s.mu.Lock() 76 defer s.mu.Unlock() 77 s.results = append(s.results, result) 78 return nil 79 } 80 81 // WrappedDispatchStream is a dispatch stream that wraps another dispatch stream, and performs 82 // an operation on each result before puppeting back up to the parent stream. 83 type WrappedDispatchStream[T any] struct { 84 Stream Stream[T] 85 Ctx context.Context 86 Processor func(result T) (T, bool, error) 87 } 88 89 func (s *WrappedDispatchStream[T]) Publish(result T) error { 90 if s.Processor == nil { 91 return s.Stream.Publish(result) 92 } 93 94 processed, ok, err := s.Processor(result) 95 if err != nil { 96 return err 97 } 98 if !ok { 99 return nil 100 } 101 102 return s.Stream.Publish(processed) 103 } 104 105 func (s *WrappedDispatchStream[T]) Context() context.Context { 106 return s.Ctx 107 } 108 109 // StreamWithContext returns the given dispatch stream, wrapped to return the given context. 110 func StreamWithContext[T any](context context.Context, stream Stream[T]) Stream[T] { 111 return &WrappedDispatchStream[T]{ 112 Stream: stream, 113 Ctx: context, 114 Processor: nil, 115 } 116 } 117 118 // HandlingDispatchStream is a dispatch stream that executes a handler for each item published. 119 // It uses an internal mutex to ensure it is thread safe. 120 type HandlingDispatchStream[T any] struct { 121 ctx context.Context 122 processor func(result T) error 123 mu sync.Mutex 124 } 125 126 // NewHandlingDispatchStream returns a new handling dispatch stream. 
// NewHandlingDispatchStream returns a new handling dispatch stream that invokes
// processor for every published result. A nil processor makes Publish a no-op
// that returns nil.
func NewHandlingDispatchStream[T any](ctx context.Context, processor func(result T) error) Stream[T] {
	return &HandlingDispatchStream[T]{
		ctx:       ctx,
		processor: processor,
		mu:        sync.Mutex{},
	}
}

// Publish invokes the processor with result while holding the stream's mutex
// and returns the processor's error unchanged.
func (s *HandlingDispatchStream[T]) Publish(result T) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.processor == nil {
		return nil
	}

	return s.processor(result)
}

// Context returns the context the stream was created with.
func (s *HandlingDispatchStream[T]) Context() context.Context {
	return s.ctx
}

// CountingDispatchStream is a dispatch stream that counts the number of items
// successfully published to the wrapped Stream. The counter is atomic, so the
// counting itself is safe for concurrent use; whether concurrent Publish calls
// are safe overall depends on the wrapped Stream.
//
// Construct it with NewCountingDispatchStream: the counter is a pointer that is
// only allocated there, so a zero value has a nil counter.
type CountingDispatchStream[T any] struct {
	Stream Stream[T]
	count  *atomic.Uint64
}

// NewCountingDispatchStream returns a CountingDispatchStream that forwards
// results to wrapped, with its published count starting at zero.
func NewCountingDispatchStream[T any](wrapped Stream[T]) *CountingDispatchStream[T] {
	return &CountingDispatchStream[T]{
		Stream: wrapped,
		count:  &atomic.Uint64{},
	}
}

// PublishedCount returns how many results have been published to the wrapped
// stream without error so far.
func (s *CountingDispatchStream[T]) PublishedCount() uint64 {
	return s.count.Load()
}

// Publish forwards result to the wrapped stream and increments the published
// count only if the wrapped stream accepted it without error.
func (s *CountingDispatchStream[T]) Publish(result T) error {
	if err := s.Stream.Publish(result); err != nil {
		return err
	}

	s.count.Add(1)
	return nil
}

// Context returns the wrapped stream's context.
func (s *CountingDispatchStream[T]) Context() context.Context {
	return s.Stream.Context()
}

// Ensure the streams implement the interface.
var (
	_ Stream[any] = (*CollectingDispatchStream[any])(nil)
	_ Stream[any] = (*WrappedDispatchStream[any])(nil)
	_ Stream[any] = (*HandlingDispatchStream[any])(nil)
	_ Stream[any] = (*CountingDispatchStream[any])(nil)
)