github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/sql/utils_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for utils module."""
    19  
    20  # pytype: skip-file
    21  
    22  import unittest
    23  from typing import NamedTuple
    24  from typing import Optional
    25  from typing import Union
    26  from unittest.mock import patch
    27  
    28  import pytest
    29  
    30  import apache_beam as beam
    31  from apache_beam.options.pipeline_options import GoogleCloudOptions
    32  from apache_beam.options.pipeline_options import SetupOptions
    33  from apache_beam.runners.interactive import interactive_environment as ie
    34  from apache_beam.runners.interactive.sql.utils import DataflowOptionsForm
    35  from apache_beam.runners.interactive.sql.utils import find_pcolls
    36  from apache_beam.runners.interactive.sql.utils import pformat_dict
    37  from apache_beam.runners.interactive.sql.utils import pformat_namedtuple
    38  from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
    39  from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
    40  
    41  
    42  class ANamedTuple(NamedTuple):
    43    a: int
    44    b: str
    45  
    46  
    47  class OptionalUnionType(NamedTuple):
    48    unnamed: Optional[Union[int, str]]
    49  
    50  
    51  class UtilsTest(unittest.TestCase):
    52    def test_register_coder_for_schema(self):
    53      self.assertNotIsInstance(
    54          beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder)
    55      register_coder_for_schema(ANamedTuple)
    56      self.assertIsInstance(
    57          beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder)
    58  
    59    def test_find_pcolls(self):
    60      with patch('apache_beam.runners.interactive.interactive_beam.collect',
    61                 lambda _: None):
    62        found = find_pcolls(
    63            """SELECT * FROM pcoll_1 JOIN pcoll_2
    64            USING (common_column)""", {
    65                'pcoll_1': None, 'pcoll_2': None
    66            })
    67        self.assertIn('pcoll_1', found)
    68        self.assertIn('pcoll_2', found)
    69  
    70    def test_replace_single_pcoll_token(self):
    71      sql = 'SELECT * FROM abc WHERE a=1 AND b=2'
    72      replaced_sql = replace_single_pcoll_token(sql, 'wow')
    73      self.assertEqual(replaced_sql, sql)
    74      replaced_sql = replace_single_pcoll_token(sql, 'abc')
    75      self.assertEqual(
    76          replaced_sql, 'SELECT * FROM PCOLLECTION WHERE a=1 AND b=2')
    77  
    78    def test_pformat_namedtuple(self):
    79      actual = pformat_namedtuple(ANamedTuple)
    80      self.assertEqual("ANamedTuple(a: <class 'int'>, b: <class 'str'>)", actual)
    81  
    82    def test_pformat_namedtuple_with_unnamed_fields(self):
    83      actual = pformat_namedtuple(OptionalUnionType)
    84      # Parameters of an Union type can be in any order.
    85      possible_expected = (
    86          'OptionalUnionType(unnamed: typing.Union[int, str, NoneType])',
    87          'OptionalUnionType(unnamed: typing.Union[str, int, NoneType])')
    88      self.assertIn(actual, possible_expected)
    89  
    90    def test_pformat_dict(self):
    91      actual = pformat_dict({'a': 1, 'b': '2'})
    92      self.assertEqual('{\na: 1,\nb: 2\n}', actual)
    93  
    94  
    95  @unittest.skipIf(
    96      not ie.current_env().is_interactive_ready,
    97      '[interactive] dependency is not installed.')
    98  @pytest.mark.skipif(
    99      not ie.current_env().is_interactive_ready,
   100      reason='[interactive] dependency is not installed.')
   101  class OptionsFormTest(unittest.TestCase):
   102    def test_dataflow_options_form(self):
   103      p = beam.Pipeline()
   104      pcoll = p | beam.Create([1, 2, 3])
   105      with patch('google.auth') as ga:
   106        ga.default = lambda: ['', 'default_project_id']
   107        df_form = DataflowOptionsForm('pcoll', pcoll)
   108        df_form.display_for_input()
   109        df_form.entries[2].input.value = 'gs://test-bucket'
   110        df_form.entries[3].input.value = 'a-pkg'
   111        options = df_form.to_options()
   112        cloud_options = options.view_as(GoogleCloudOptions)
   113        self.assertEqual(cloud_options.project, 'default_project_id')
   114        self.assertEqual(cloud_options.region, 'us-central1')
   115        self.assertEqual(
   116            cloud_options.staging_location, 'gs://test-bucket/staging')
   117        self.assertEqual(cloud_options.temp_location, 'gs://test-bucket/temp')
   118        self.assertIsNotNone(options.view_as(SetupOptions).requirements_file)
   119  
   120  
# Allow running this module directly as a script in addition to discovery by
# a test runner; unittest.main() collects and runs the TestCases above.
if __name__ == '__main__':
  unittest.main()