# Source:
# github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/sql/utils_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for utils module."""

# pytype: skip-file

import unittest
from typing import NamedTuple
from typing import Optional
from typing import Union
from unittest.mock import patch

import pytest

import apache_beam as beam
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.sql.utils import DataflowOptionsForm
from apache_beam.runners.interactive.sql.utils import find_pcolls
from apache_beam.runners.interactive.sql.utils import pformat_dict
from apache_beam.runners.interactive.sql.utils import pformat_namedtuple
from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token


# Simple two-field schema used as the fixture for the coder-registration and
# pretty-printing tests below.
class ANamedTuple(NamedTuple):
  a: int
  b: str


# Schema whose only field is an Optional[Union[...]]; used to check how
# pformat_namedtuple renders a typing construct instead of a plain class.
class OptionalUnionType(NamedTuple):
  unnamed: Optional[Union[int, str]]


class UtilsTest(unittest.TestCase):
  """Unit tests for the helper functions in interactive/sql/utils.py."""
  def test_register_coder_for_schema(self):
    # Before registration the registry must not hand back a RowCoder for the
    # schema; after register_coder_for_schema it must.
    self.assertNotIsInstance(
        beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder)
    register_coder_for_schema(ANamedTuple)
    self.assertIsInstance(
        beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder)

  def test_find_pcolls(self):
    # interactive_beam.collect is stubbed out with a no-op so the test does
    # not depend on its real behavior.
    with patch('apache_beam.runners.interactive.interactive_beam.collect',
               lambda _: None):
      # Both names in the mapping appear in the query, so both are expected
      # in the result.
      found = find_pcolls(
          """SELECT * FROM pcoll_1 JOIN pcoll_2
          USING (common_column)""", {
              'pcoll_1': None, 'pcoll_2': None
          })
      self.assertIn('pcoll_1', found)
      self.assertIn('pcoll_2', found)

  def test_replace_single_pcoll_token(self):
    sql = 'SELECT * FROM abc WHERE a=1 AND b=2'
    # A name that does not occur in the query leaves the query untouched.
    replaced_sql = replace_single_pcoll_token(sql, 'wow')
    self.assertEqual(replaced_sql, sql)
    # A name that does occur is swapped for the PCOLLECTION keyword.
    replaced_sql = replace_single_pcoll_token(sql, 'abc')
    self.assertEqual(
        replaced_sql, 'SELECT * FROM PCOLLECTION WHERE a=1 AND b=2')

  def test_pformat_namedtuple(self):
    actual = pformat_namedtuple(ANamedTuple)
    self.assertEqual("ANamedTuple(a: <class 'int'>, b: <class 'str'>)", actual)

  def test_pformat_namedtuple_with_unnamed_fields(self):
    actual = pformat_namedtuple(OptionalUnionType)
    # Parameters of a Union type can be in any order, so accept either
    # rendering.
    possible_expected = (
        'OptionalUnionType(unnamed: typing.Union[int, str, NoneType])',
        'OptionalUnionType(unnamed: typing.Union[str, int, NoneType])')
    self.assertIn(actual, possible_expected)

  def test_pformat_dict(self):
    # Note the expected output drops the quotes around the string value '2'.
    actual = pformat_dict({'a': 1, 'b': '2'})
    self.assertEqual('{\na: 1,\nb: 2\n}', actual)


# The same skip condition is declared twice, once as a unittest decorator and
# once as a pytest marker.
@unittest.skipIf(
    not ie.current_env().is_interactive_ready,
    '[interactive] dependency is not installed.')
@pytest.mark.skipif(
    not ie.current_env().is_interactive_ready,
    reason='[interactive] dependency is not installed.')
class OptionsFormTest(unittest.TestCase):
  """Tests for DataflowOptionsForm; skipped without [interactive] deps."""
  def test_dataflow_options_form(self):
    p = beam.Pipeline()
    pcoll = p | beam.Create([1, 2, 3])
    # google.auth is replaced with a mock so no real credentials are needed;
    # the second element returned by default() is the project id asserted
    # below.
    with patch('google.auth') as ga:
      ga.default = lambda: ['', 'default_project_id']
      df_form = DataflowOptionsForm('pcoll', pcoll)
      df_form.display_for_input()
      # NOTE(review): entries are addressed by position. Judging from the
      # assertions below, entries[2] looks like the cloud storage location
      # and entries[3] the extra-packages input -- confirm against
      # DataflowOptionsForm.
      df_form.entries[2].input.value = 'gs://test-bucket'
      df_form.entries[3].input.value = 'a-pkg'
      options = df_form.to_options()
      cloud_options = options.view_as(GoogleCloudOptions)
      self.assertEqual(cloud_options.project, 'default_project_id')
      self.assertEqual(cloud_options.region, 'us-central1')
      self.assertEqual(
          cloud_options.staging_location, 'gs://test-bucket/staging')
      self.assertEqual(cloud_options.temp_location, 'gs://test-bucket/temp')
      self.assertIsNotNone(options.view_as(SetupOptions).requirements_file)


if __name__ == '__main__':
  unittest.main()