github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/middleware/streamtimeout/streamtimeout.go (about) 1 package streamtimeout 2 3 import ( 4 "context" 5 "fmt" 6 "time" 7 8 "google.golang.org/grpc" 9 "google.golang.org/grpc/codes" 10 "google.golang.org/grpc/metadata" 11 12 "github.com/authzed/spicedb/pkg/spiceerrors" 13 ) 14 15 // MustStreamServerInterceptor returns a new stream server interceptor that cancels the context 16 // after a timeout if no new data has been received. 17 func MustStreamServerInterceptor(timeout time.Duration) grpc.StreamServerInterceptor { 18 if timeout <= 0 { 19 panic("timeout must be >= 0 for streaming timeout interceptor") 20 } 21 22 return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 23 ctx := stream.Context() 24 withCancel, internalCancelFn := context.WithCancelCause(ctx) 25 timer := time.AfterFunc(timeout, func() { 26 internalCancelFn(spiceerrors.WithCodeAndDetailsAsError(fmt.Errorf("operation took longer than allowed %v to complete", timeout), codes.DeadlineExceeded)) 27 }) 28 wrapper := &sendWrapper{stream, withCancel, timer, timeout} 29 return handler(srv, wrapper) 30 } 31 } 32 33 type sendWrapper struct { 34 grpc.ServerStream 35 36 ctx context.Context 37 timer *time.Timer 38 timeout time.Duration 39 } 40 41 func (s *sendWrapper) Context() context.Context { 42 return s.ctx 43 } 44 45 func (s *sendWrapper) SetTrailer(_ metadata.MD) { 46 s.timer.Stop() 47 } 48 49 func (s *sendWrapper) SendMsg(m any) error { 50 err := s.ServerStream.SendMsg(m) 51 if err != nil { 52 s.timer.Stop() 53 } else { 54 s.timer.Reset(s.timeout) 55 } 56 return err 57 }