github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/middleware/streamtimeout/streamtimeout_test.go (about)

     1  package streamtimeout
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
    10  	"github.com/stretchr/testify/require"
    11  	"github.com/stretchr/testify/suite"
    12  	"google.golang.org/grpc"
    13  )
    14  
    15  type testServer struct {
    16  	testpb.UnimplementedTestServiceServer
    17  }
    18  
    19  func (t testServer) PingEmpty(_ context.Context, _ *testpb.PingEmptyRequest) (*testpb.PingEmptyResponse, error) {
    20  	return &testpb.PingEmptyResponse{}, nil
    21  }
    22  
    23  func (t testServer) Ping(_ context.Context, _ *testpb.PingRequest) (*testpb.PingResponse, error) {
    24  	return &testpb.PingResponse{Value: ""}, nil
    25  }
    26  
    27  func (t testServer) PingError(_ context.Context, _ *testpb.PingErrorRequest) (*testpb.PingErrorResponse, error) {
    28  	return nil, fmt.Errorf("err")
    29  }
    30  
    31  func (t testServer) PingList(_ *testpb.PingListRequest, server testpb.TestService_PingListServer) error {
    32  	var counter int32
    33  	for {
    34  		// Produce ping responses until the context is canceled.
    35  		select {
    36  		case <-server.Context().Done():
    37  			return server.Context().Err()
    38  
    39  		default:
    40  			counter++
    41  			err := server.Send(&testpb.PingListResponse{Counter: counter})
    42  			if err != nil {
    43  				return err
    44  			}
    45  			time.Sleep(time.Duration(counter*10) * time.Millisecond)
    46  		}
    47  	}
    48  }
    49  
    50  func (t testServer) PingStream(_ testpb.TestService_PingStreamServer) error {
    51  	return fmt.Errorf("unused")
    52  }
    53  
    54  type testSuite struct {
    55  	*testpb.InterceptorTestSuite
    56  }
    57  
    58  func TestStreamTimeoutMiddleware(t *testing.T) {
    59  	s := &testSuite{
    60  		InterceptorTestSuite: &testpb.InterceptorTestSuite{
    61  			TestService: &testServer{},
    62  			ServerOpts: []grpc.ServerOption{
    63  				grpc.StreamInterceptor(MustStreamServerInterceptor(50 * time.Millisecond)),
    64  			},
    65  			ClientOpts: []grpc.DialOption{},
    66  		},
    67  	}
    68  	suite.Run(t, s)
    69  }
    70  
    71  func (s *testSuite) TestStreamTimeout() {
    72  	stream, err := s.Client.PingList(s.SimpleCtx(), &testpb.PingListRequest{Value: "something"})
    73  	require.NoError(s.T(), err)
    74  
    75  	var maxCounter int32
    76  
    77  	for {
    78  		// Ensure if we get an error, it is because the context was canceled.
    79  		resp, err := stream.Recv()
    80  		if err != nil {
    81  			require.ErrorContains(s.T(), err, "context canceled")
    82  			return
    83  		}
    84  
    85  		// Ensure that we produced a *maximum* of 6 responses (timeout is 50ms and each response
    86  		// should take 10ms * counter). This ensures that we timed out (roughly) when expected.
    87  		maxCounter = resp.Counter
    88  		require.LessOrEqual(s.T(), maxCounter, int32(6), "stream was not properly canceled: %d", maxCounter)
    89  	}
    90  }