github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/middleware/streamtimeout/streamtimeout_test.go (about) 1 package streamtimeout 2 3 import ( 4 "context" 5 "fmt" 6 "testing" 7 "time" 8 9 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb" 10 "github.com/stretchr/testify/require" 11 "github.com/stretchr/testify/suite" 12 "google.golang.org/grpc" 13 ) 14 15 type testServer struct { 16 testpb.UnimplementedTestServiceServer 17 } 18 19 func (t testServer) PingEmpty(_ context.Context, _ *testpb.PingEmptyRequest) (*testpb.PingEmptyResponse, error) { 20 return &testpb.PingEmptyResponse{}, nil 21 } 22 23 func (t testServer) Ping(_ context.Context, _ *testpb.PingRequest) (*testpb.PingResponse, error) { 24 return &testpb.PingResponse{Value: ""}, nil 25 } 26 27 func (t testServer) PingError(_ context.Context, _ *testpb.PingErrorRequest) (*testpb.PingErrorResponse, error) { 28 return nil, fmt.Errorf("err") 29 } 30 31 func (t testServer) PingList(_ *testpb.PingListRequest, server testpb.TestService_PingListServer) error { 32 var counter int32 33 for { 34 // Produce ping responses until the context is canceled. 35 select { 36 case <-server.Context().Done(): 37 return server.Context().Err() 38 39 default: 40 counter++ 41 err := server.Send(&testpb.PingListResponse{Counter: counter}) 42 if err != nil { 43 return err 44 } 45 time.Sleep(time.Duration(counter*10) * time.Millisecond) 46 } 47 } 48 } 49 50 func (t testServer) PingStream(_ testpb.TestService_PingStreamServer) error { 51 return fmt.Errorf("unused") 52 } 53 54 type testSuite struct { 55 *testpb.InterceptorTestSuite 56 } 57 58 func TestStreamTimeoutMiddleware(t *testing.T) { 59 s := &testSuite{ 60 InterceptorTestSuite: &testpb.InterceptorTestSuite{ 61 TestService: &testServer{}, 62 ServerOpts: []grpc.ServerOption{ 63 grpc.StreamInterceptor(MustStreamServerInterceptor(50 * time.Millisecond)), 64 }, 65 ClientOpts: []grpc.DialOption{}, 66 }, 67 } 68 suite.Run(t, s) 69 } 70 71 func (s *testSuite) TestStreamTimeout() { 72 stream, err := s.Client.PingList(s.SimpleCtx(), &testpb.PingListRequest{Value: "something"}) 73 require.NoError(s.T(), err) 74 75 var maxCounter int32 76 77 for { 78 // Ensure if we get an error, it is because the context was canceled. 79 resp, err := stream.Recv() 80 if err != nil { 81 require.ErrorContains(s.T(), err, "context canceled") 82 return 83 } 84 85 // Ensure that we produced a *maximum* of 6 responses (timeout is 50ms and each response 86 // should take 10ms * counter). This ensures that we timed out (roughly) when expected. 87 maxCounter = resp.Counter 88 require.LessOrEqual(s.T(), maxCounter, int32(6), "stream was not properly canceled: %d", maxCounter) 89 } 90 }