github.com/openfga/openfga@v1.5.4-rc1/pkg/middleware/storeid/storeid_test.go (about)

     1  package storeid
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  
     7  	openfgav1 "github.com/openfga/api/proto/openfga/v1"
     8  	"github.com/stretchr/testify/require"
     9  	"google.golang.org/grpc"
    10  )
    11  
    12  func TestUnaryInterceptor(t *testing.T) {
    13  	t.Run("unary_interceptor_with_no_storeID_in_request", func(t *testing.T) {
    14  		interceptor := NewUnaryInterceptor()
    15  
    16  		handler := func(ctx context.Context, req interface{}) (interface{}, error) {
    17  			storeID, ok := StoreIDFromContext(ctx)
    18  			require.True(t, ok)
    19  			require.Empty(t, storeID)
    20  
    21  			return nil, nil
    22  		}
    23  
    24  		_, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{}, handler)
    25  		require.NoError(t, err)
    26  	})
    27  
    28  	t.Run("unary_interceptor_with_storeID_in_request", func(t *testing.T) {
    29  		storeID := "abc"
    30  		interceptor := NewUnaryInterceptor()
    31  
    32  		handler := func(ctx context.Context, req interface{}) (interface{}, error) {
    33  			got, ok := StoreIDFromContext(ctx)
    34  			require.True(t, ok)
    35  			require.Equal(t, storeID, got)
    36  
    37  			return nil, nil
    38  		}
    39  
    40  		_, err := interceptor(context.Background(), &openfgav1.CheckRequest{StoreId: storeID}, &grpc.UnaryServerInfo{}, handler)
    41  		require.NoError(t, err)
    42  	})
    43  }
    44  
    45  type mockServerStream struct {
    46  	grpc.ServerStream
    47  	ctx context.Context
    48  }
    49  
    50  func (s *mockServerStream) Context() context.Context {
    51  	return s.ctx
    52  }
    53  
    54  func (s *mockServerStream) RecvMsg(interface{}) error {
    55  	return nil
    56  }
    57  
    58  func TestStreamingInterceptor(t *testing.T) {
    59  	t.Run("streaming_interceptor_with_no_GetStoreId_in_request", func(t *testing.T) {
    60  		handler := func(srv interface{}, stream grpc.ServerStream) error {
    61  			err := stream.RecvMsg(nil)
    62  			require.NoError(t, err)
    63  
    64  			storeID, ok := StoreIDFromContext(stream.Context())
    65  			require.True(t, ok)
    66  			require.Empty(t, storeID)
    67  
    68  			return nil
    69  		}
    70  
    71  		ss := &mockServerStream{ctx: context.Background()}
    72  		err := NewStreamingInterceptor()(nil, ss, &grpc.StreamServerInfo{}, handler)
    73  		require.NoError(t, err)
    74  	})
    75  
    76  	t.Run("streaming_interceptor_with_GetStoreId_in_request", func(t *testing.T) {
    77  		handler := func(srv interface{}, stream grpc.ServerStream) error {
    78  			err := stream.RecvMsg(&openfgav1.CheckRequest{StoreId: "abc"})
    79  			require.NoError(t, err)
    80  
    81  			got, ok := StoreIDFromContext(stream.Context())
    82  			require.True(t, ok)
    83  			require.Equal(t, "abc", got)
    84  
    85  			return nil
    86  		}
    87  
    88  		ss := &mockServerStream{ctx: context.Background()}
    89  		err := NewStreamingInterceptor()(nil, ss, &grpc.StreamServerInfo{}, handler)
    90  		require.NoError(t, err)
    91  	})
    92  }