github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/middleware/usagemetrics/usagemetrics_test.go (about)

     1  package usagemetrics
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"testing"
     9  
    10  	"github.com/authzed/authzed-go/pkg/responsemeta"
    11  	"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  	"github.com/stretchr/testify/suite"
    15  	"google.golang.org/grpc"
    16  	"google.golang.org/grpc/metadata"
    17  
    18  	dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
    19  )
    20  
// testServer is the TestService implementation served by the tests in this
// file. It embeds testpb.UnimplementedTestServiceServer and defines its own
// Ping* handlers.
type testServer struct {
	testpb.UnimplementedTestServiceServer
}
    24  
    25  func (t testServer) PingEmpty(ctx context.Context, _ *testpb.PingEmptyRequest) (*testpb.PingEmptyResponse, error) {
    26  	SetInContext(ctx, &dispatch.ResponseMeta{
    27  		DispatchCount:       1,
    28  		CachedDispatchCount: 1,
    29  	})
    30  	return &testpb.PingEmptyResponse{}, nil
    31  }
    32  
    33  func (t testServer) Ping(ctx context.Context, _ *testpb.PingRequest) (*testpb.PingResponse, error) {
    34  	SetInContext(ctx, &dispatch.ResponseMeta{
    35  		DispatchCount:       1,
    36  		CachedDispatchCount: 1,
    37  	})
    38  	return &testpb.PingResponse{Value: ""}, nil
    39  }
    40  
    41  func (t testServer) PingError(ctx context.Context, _ *testpb.PingErrorRequest) (*testpb.PingErrorResponse, error) {
    42  	SetInContext(ctx, &dispatch.ResponseMeta{
    43  		DispatchCount:       1,
    44  		CachedDispatchCount: 1,
    45  	})
    46  	return nil, fmt.Errorf("err")
    47  }
    48  
    49  func (t testServer) PingList(_ *testpb.PingListRequest, server testpb.TestService_PingListServer) error {
    50  	SetInContext(server.Context(), &dispatch.ResponseMeta{
    51  		DispatchCount:       1,
    52  		CachedDispatchCount: 1,
    53  	})
    54  	return nil
    55  }
    56  
    57  func (t testServer) PingStream(stream testpb.TestService_PingStreamServer) error {
    58  	count := 0
    59  	for {
    60  		_, err := stream.Recv()
    61  		if errors.Is(err, io.EOF) {
    62  			break
    63  		} else if err != nil {
    64  			return err
    65  		}
    66  		_ = stream.Send(&testpb.PingStreamResponse{Value: "", Counter: int32(count)})
    67  		count++
    68  	}
    69  	return nil
    70  }
    71  
// metricsMiddlewareTestSuite is the suite run by TestMetricsMiddleware. It
// embeds *testpb.InterceptorTestSuite; its TestTrailers_* methods check the
// dispatch and cached-dispatch counts read from the gRPC response trailers.
type metricsMiddlewareTestSuite struct {
	*testpb.InterceptorTestSuite
}
    75  
    76  func TestMetricsMiddleware(t *testing.T) {
    77  	s := &metricsMiddlewareTestSuite{
    78  		InterceptorTestSuite: &testpb.InterceptorTestSuite{
    79  			TestService: &testServer{},
    80  			ServerOpts: []grpc.ServerOption{
    81  				grpc.UnaryInterceptor(UnaryServerInterceptor()),
    82  				grpc.StreamInterceptor(StreamServerInterceptor()),
    83  			},
    84  			ClientOpts: []grpc.DialOption{},
    85  		},
    86  	}
    87  	suite.Run(t, s)
    88  }
    89  
    90  func (s *metricsMiddlewareTestSuite) TestTrailers_Unary() {
    91  	var trailerMD metadata.MD
    92  	_, err := s.Client.Ping(s.SimpleCtx(), &testpb.PingRequest{Value: "something"}, grpc.Trailer(&trailerMD))
    93  	require.NoError(s.T(), err)
    94  
    95  	dispatchCount, err := responsemeta.GetIntResponseTrailerMetadata(
    96  		trailerMD,
    97  		responsemeta.DispatchedOperationsCount,
    98  	)
    99  	require.NoError(s.T(), err)
   100  	require.Equal(s.T(), 1, dispatchCount)
   101  
   102  	cachedCount, err := responsemeta.GetIntResponseTrailerMetadata(
   103  		trailerMD,
   104  		responsemeta.CachedOperationsCount,
   105  	)
   106  	require.NoError(s.T(), err)
   107  	require.Equal(s.T(), 1, cachedCount)
   108  }
   109  
   110  func (s *metricsMiddlewareTestSuite) TestTrailers_Stream() {
   111  	stream, err := s.Client.PingList(s.SimpleCtx(), &testpb.PingListRequest{Value: "something"})
   112  	require.NoError(s.T(), err)
   113  	for {
   114  		_, err := stream.Recv()
   115  		if errors.Is(err, io.EOF) {
   116  			break
   117  		}
   118  		assert.NoError(s.T(), err, "no error on messages sent occurred")
   119  	}
   120  
   121  	dispatchCount, err := responsemeta.GetIntResponseTrailerMetadata(
   122  		stream.Trailer(),
   123  		responsemeta.DispatchedOperationsCount,
   124  	)
   125  	require.NoError(s.T(), err)
   126  	require.Equal(s.T(), 1, dispatchCount)
   127  
   128  	cachedCount, err := responsemeta.GetIntResponseTrailerMetadata(
   129  		stream.Trailer(),
   130  		responsemeta.CachedOperationsCount,
   131  	)
   132  	require.NoError(s.T(), err)
   133  	require.Equal(s.T(), 1, cachedCount)
   134  }