github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/middleware/usagemetrics/usagemetrics_test.go (about) 1 package usagemetrics 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "testing" 9 10 "github.com/authzed/authzed-go/pkg/responsemeta" 11 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb" 12 "github.com/stretchr/testify/assert" 13 "github.com/stretchr/testify/require" 14 "github.com/stretchr/testify/suite" 15 "google.golang.org/grpc" 16 "google.golang.org/grpc/metadata" 17 18 dispatch "github.com/authzed/spicedb/pkg/proto/dispatch/v1" 19 ) 20 21 type testServer struct { 22 testpb.UnimplementedTestServiceServer 23 } 24 25 func (t testServer) PingEmpty(ctx context.Context, _ *testpb.PingEmptyRequest) (*testpb.PingEmptyResponse, error) { 26 SetInContext(ctx, &dispatch.ResponseMeta{ 27 DispatchCount: 1, 28 CachedDispatchCount: 1, 29 }) 30 return &testpb.PingEmptyResponse{}, nil 31 } 32 33 func (t testServer) Ping(ctx context.Context, _ *testpb.PingRequest) (*testpb.PingResponse, error) { 34 SetInContext(ctx, &dispatch.ResponseMeta{ 35 DispatchCount: 1, 36 CachedDispatchCount: 1, 37 }) 38 return &testpb.PingResponse{Value: ""}, nil 39 } 40 41 func (t testServer) PingError(ctx context.Context, _ *testpb.PingErrorRequest) (*testpb.PingErrorResponse, error) { 42 SetInContext(ctx, &dispatch.ResponseMeta{ 43 DispatchCount: 1, 44 CachedDispatchCount: 1, 45 }) 46 return nil, fmt.Errorf("err") 47 } 48 49 func (t testServer) PingList(_ *testpb.PingListRequest, server testpb.TestService_PingListServer) error { 50 SetInContext(server.Context(), &dispatch.ResponseMeta{ 51 DispatchCount: 1, 52 CachedDispatchCount: 1, 53 }) 54 return nil 55 } 56 57 func (t testServer) PingStream(stream testpb.TestService_PingStreamServer) error { 58 count := 0 59 for { 60 _, err := stream.Recv() 61 if errors.Is(err, io.EOF) { 62 break 63 } else if err != nil { 64 return err 65 } 66 _ = stream.Send(&testpb.PingStreamResponse{Value: "", Counter: int32(count)}) 67 count++ 68 } 69 return nil 70 } 71 72 type metricsMiddlewareTestSuite struct { 73 *testpb.InterceptorTestSuite 74 } 75 76 func TestMetricsMiddleware(t *testing.T) { 77 s := &metricsMiddlewareTestSuite{ 78 InterceptorTestSuite: &testpb.InterceptorTestSuite{ 79 TestService: &testServer{}, 80 ServerOpts: []grpc.ServerOption{ 81 grpc.UnaryInterceptor(UnaryServerInterceptor()), 82 grpc.StreamInterceptor(StreamServerInterceptor()), 83 }, 84 ClientOpts: []grpc.DialOption{}, 85 }, 86 } 87 suite.Run(t, s) 88 } 89 90 func (s *metricsMiddlewareTestSuite) TestTrailers_Unary() { 91 var trailerMD metadata.MD 92 _, err := s.Client.Ping(s.SimpleCtx(), &testpb.PingRequest{Value: "something"}, grpc.Trailer(&trailerMD)) 93 require.NoError(s.T(), err) 94 95 dispatchCount, err := responsemeta.GetIntResponseTrailerMetadata( 96 trailerMD, 97 responsemeta.DispatchedOperationsCount, 98 ) 99 require.NoError(s.T(), err) 100 require.Equal(s.T(), 1, dispatchCount) 101 102 cachedCount, err := responsemeta.GetIntResponseTrailerMetadata( 103 trailerMD, 104 responsemeta.CachedOperationsCount, 105 ) 106 require.NoError(s.T(), err) 107 require.Equal(s.T(), 1, cachedCount) 108 } 109 110 func (s *metricsMiddlewareTestSuite) TestTrailers_Stream() { 111 stream, err := s.Client.PingList(s.SimpleCtx(), &testpb.PingListRequest{Value: "something"}) 112 require.NoError(s.T(), err) 113 for { 114 _, err := stream.Recv() 115 if errors.Is(err, io.EOF) { 116 break 117 } 118 assert.NoError(s.T(), err, "no error on messages sent occurred") 119 } 120 121 dispatchCount, err := responsemeta.GetIntResponseTrailerMetadata( 122 stream.Trailer(), 123 responsemeta.DispatchedOperationsCount, 124 ) 125 require.NoError(s.T(), err) 126 require.Equal(s.T(), 1, dispatchCount) 127 128 cachedCount, err := responsemeta.GetIntResponseTrailerMetadata( 129 stream.Trailer(), 130 responsemeta.CachedOperationsCount, 131 ) 132 require.NoError(s.T(), err) 133 require.Equal(s.T(), 1, cachedCount) 134 }