github.com/openfga/openfga@v1.5.4-rc1/pkg/middleware/storeid/storeid_test.go (about) 1 package storeid 2 3 import ( 4 "context" 5 "testing" 6 7 openfgav1 "github.com/openfga/api/proto/openfga/v1" 8 "github.com/stretchr/testify/require" 9 "google.golang.org/grpc" 10 ) 11 12 func TestUnaryInterceptor(t *testing.T) { 13 t.Run("unary_interceptor_with_no_storeID_in_request", func(t *testing.T) { 14 interceptor := NewUnaryInterceptor() 15 16 handler := func(ctx context.Context, req interface{}) (interface{}, error) { 17 storeID, ok := StoreIDFromContext(ctx) 18 require.True(t, ok) 19 require.Empty(t, storeID) 20 21 return nil, nil 22 } 23 24 _, err := interceptor(context.Background(), nil, &grpc.UnaryServerInfo{}, handler) 25 require.NoError(t, err) 26 }) 27 28 t.Run("unary_interceptor_with_storeID_in_request", func(t *testing.T) { 29 storeID := "abc" 30 interceptor := NewUnaryInterceptor() 31 32 handler := func(ctx context.Context, req interface{}) (interface{}, error) { 33 got, ok := StoreIDFromContext(ctx) 34 require.True(t, ok) 35 require.Equal(t, storeID, got) 36 37 return nil, nil 38 } 39 40 _, err := interceptor(context.Background(), &openfgav1.CheckRequest{StoreId: storeID}, &grpc.UnaryServerInfo{}, handler) 41 require.NoError(t, err) 42 }) 43 } 44 45 type mockServerStream struct { 46 grpc.ServerStream 47 ctx context.Context 48 } 49 50 func (s *mockServerStream) Context() context.Context { 51 return s.ctx 52 } 53 54 func (s *mockServerStream) RecvMsg(interface{}) error { 55 return nil 56 } 57 58 func TestStreamingInterceptor(t *testing.T) { 59 t.Run("streaming_interceptor_with_no_GetStoreId_in_request", func(t *testing.T) { 60 handler := func(srv interface{}, stream grpc.ServerStream) error { 61 err := stream.RecvMsg(nil) 62 require.NoError(t, err) 63 64 storeID, ok := StoreIDFromContext(stream.Context()) 65 require.True(t, ok) 66 require.Empty(t, storeID) 67 68 return nil 69 } 70 71 ss := &mockServerStream{ctx: context.Background()} 72 err := NewStreamingInterceptor()(nil, ss, &grpc.StreamServerInfo{}, handler) 73 require.NoError(t, err) 74 }) 75 76 t.Run("streaming_interceptor_with_GetStoreId_in_request", func(t *testing.T) { 77 handler := func(srv interface{}, stream grpc.ServerStream) error { 78 err := stream.RecvMsg(&openfgav1.CheckRequest{StoreId: "abc"}) 79 require.NoError(t, err) 80 81 got, ok := StoreIDFromContext(stream.Context()) 82 require.True(t, ok) 83 require.Equal(t, "abc", got) 84 85 return nil 86 } 87 88 ss := &mockServerStream{ctx: context.Background()} 89 err := NewStreamingInterceptor()(nil, ss, &grpc.StreamServerInfo{}, handler) 90 require.NoError(t, err) 91 }) 92 }