github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/interactive_environment_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for apache_beam.runners.interactive.interactive_environment."""
    19  # pytype: skip-file
    20  
    21  import importlib
    22  import unittest
    23  from unittest.mock import patch
    24  
    25  import apache_beam as beam
    26  from apache_beam.runners import runner
    27  from apache_beam.runners.interactive import cache_manager as cache
    28  from apache_beam.runners.interactive import interactive_environment as ie
    29  from apache_beam.runners.interactive.recording_manager import RecordingManager
    30  from apache_beam.runners.interactive.sql.sql_chain import SqlNode
    31  from apache_beam.runners.interactive.testing.mock_env import isolated_env
    32  
# Fully qualified name of this test module. The name is itself stored in a
# module-level variable, so once the module is watched (by name or by module
# object) the pair ('_module_name', _module_name) should show up as watched.
_module_name = 'apache_beam.runners.interactive.interactive_environment_test'
    35  
    36  
    37  @isolated_env
    38  class InteractiveEnvironmentTest(unittest.TestCase):
    39    def setUp(self):
    40      self._p = beam.Pipeline()
    41      self._var_in_class_instance = 'a var in class instance'
    42  
    43    def assertVariableWatched(self, variable_name, variable_val):
    44      self.assertTrue(self._is_variable_watched(variable_name, variable_val))
    45  
    46    def assertVariableNotWatched(self, variable_name, variable_val):
    47      self.assertFalse(self._is_variable_watched(variable_name, variable_val))
    48  
    49    def _is_variable_watched(self, variable_name, variable_val):
    50      return any((variable_name, variable_val) in watching
    51                 for watching in ie.current_env().watching())
    52  
    53    def _a_function_with_local_watched(self):
    54      local_var_watched = 123  # pylint: disable=possibly-unused-variable
    55      ie.current_env().watch(locals())
    56  
    57    def _a_function_not_watching_local(self):
    58      local_var_not_watched = 456  # pylint: disable=unused-variable
    59  
    60    def test_watch_main_by_default(self):
    61      self.assertTrue('__main__' in ie.current_env()._watching_set)
    62      # __main__ module has variable __name__ with value '__main__'
    63      self.assertVariableWatched('__name__', '__main__')
    64  
    65    def test_watch_a_module_by_name(self):
    66      self.assertFalse(_module_name in ie.current_env()._watching_set)
    67      self.assertVariableNotWatched('_module_name', _module_name)
    68      ie.current_env().watch(_module_name)
    69      self.assertTrue(_module_name in ie.current_env()._watching_set)
    70      self.assertVariableWatched('_module_name', _module_name)
    71  
    72    def test_watch_a_module_by_module_object(self):
    73      module = importlib.import_module(_module_name)
    74      self.assertFalse(module in ie.current_env()._watching_set)
    75      self.assertVariableNotWatched('_module_name', _module_name)
    76      ie.current_env().watch(module)
    77      self.assertTrue(module in ie.current_env()._watching_set)
    78      self.assertVariableWatched('_module_name', _module_name)
    79  
    80    def test_watch_locals(self):
    81      self.assertVariableNotWatched('local_var_watched', 123)
    82      self.assertVariableNotWatched('local_var_not_watched', 456)
    83      self._a_function_with_local_watched()
    84      self.assertVariableWatched('local_var_watched', 123)
    85      self._a_function_not_watching_local()
    86      self.assertVariableNotWatched('local_var_not_watched', 456)
    87  
    88    def test_watch_class_instance(self):
    89      self.assertVariableNotWatched(
    90          '_var_in_class_instance', self._var_in_class_instance)
    91      ie.current_env().watch(self)
    92      self.assertVariableWatched(
    93          '_var_in_class_instance', self._var_in_class_instance)
    94  
    95    def test_fail_to_set_pipeline_result_key_not_pipeline(self):
    96      class NotPipeline(object):
    97        pass
    98  
    99      with self.assertRaises(AssertionError) as ctx:
   100        ie.current_env().set_pipeline_result(
   101            NotPipeline(), runner.PipelineResult(runner.PipelineState.RUNNING))
   102        self.assertTrue(
   103            'pipeline must be an instance of apache_beam.Pipeline '
   104            'or its subclass' in ctx.exception)
   105  
   106    def test_fail_to_set_pipeline_result_value_not_pipeline_result(self):
   107      class NotResult(object):
   108        pass
   109  
   110      with self.assertRaises(AssertionError) as ctx:
   111        ie.current_env().set_pipeline_result(self._p, NotResult())
   112        self.assertTrue(
   113            'result must be an instance of '
   114            'apache_beam.runners.runner.PipelineResult or its '
   115            'subclass' in ctx.exception)
   116  
   117    def test_set_pipeline_result_successfully(self):
   118      class PipelineSubClass(beam.Pipeline):
   119        pass
   120  
   121      class PipelineResultSubClass(runner.PipelineResult):
   122        pass
   123  
   124      pipeline = PipelineSubClass()
   125      pipeline_result = PipelineResultSubClass(runner.PipelineState.RUNNING)
   126      ie.current_env().set_pipeline_result(pipeline, pipeline_result)
   127      self.assertIs(ie.current_env().pipeline_result(pipeline), pipeline_result)
   128  
   129    def test_determine_terminal_state(self):
   130      for state in (runner.PipelineState.DONE,
   131                    runner.PipelineState.FAILED,
   132                    runner.PipelineState.CANCELLED,
   133                    runner.PipelineState.UPDATED,
   134                    runner.PipelineState.DRAINED):
   135        ie.current_env().set_pipeline_result(
   136            self._p, runner.PipelineResult(state))
   137        self.assertTrue(ie.current_env().is_terminated(self._p))
   138      for state in (runner.PipelineState.UNKNOWN,
   139                    runner.PipelineState.STARTING,
   140                    runner.PipelineState.STOPPED,
   141                    runner.PipelineState.RUNNING,
   142                    runner.PipelineState.DRAINING,
   143                    runner.PipelineState.PENDING,
   144                    runner.PipelineState.CANCELLING,
   145                    runner.PipelineState.UNRECOGNIZED):
   146        ie.current_env().set_pipeline_result(
   147            self._p, runner.PipelineResult(state))
   148        self.assertFalse(ie.current_env().is_terminated(self._p))
   149  
   150    def test_evict_pipeline_result(self):
   151      pipeline_result = runner.PipelineResult(runner.PipelineState.DONE)
   152      ie.current_env().set_pipeline_result(self._p, pipeline_result)
   153      self.assertIs(
   154          ie.current_env().evict_pipeline_result(self._p), pipeline_result)
   155      self.assertIs(ie.current_env().pipeline_result(self._p), None)
   156  
   157    def test_pipeline_result_is_none_when_pipeline_absent(self):
   158      self.assertIs(ie.current_env().pipeline_result(self._p), None)
   159      self.assertIs(ie.current_env().is_terminated(self._p), True)
   160      self.assertIs(ie.current_env().evict_pipeline_result(self._p), None)
   161  
   162    def test_cleanup_registered_when_creating_new_env(self):
   163      with patch('atexit.register') as mocked_atexit:
   164        _ = ie.InteractiveEnvironment()
   165        mocked_atexit.assert_called_once()
   166  
   167    def test_cleanup_invoked_when_new_env_replace_not_none_env(self):
   168      ie._interactive_beam_env = self.current_env
   169      with patch('apache_beam.runners.interactive.interactive_environment'
   170                 '.InteractiveEnvironment.cleanup') as mocked_cleanup:
   171        ie.new_env()
   172        mocked_cleanup.assert_called_once()
   173  
   174    def test_cleanup_not_invoked_when_cm_changed_from_none(self):
   175      env = ie.InteractiveEnvironment()
   176      with patch('apache_beam.runners.interactive.interactive_environment'
   177                 '.InteractiveEnvironment.cleanup') as mocked_cleanup:
   178        dummy_pipeline = 'dummy'
   179        self.assertIsNone(env.get_cache_manager(dummy_pipeline))
   180        cache_manager = cache.FileBasedCacheManager()
   181        env.set_cache_manager(cache_manager, dummy_pipeline)
   182        mocked_cleanup.assert_not_called()
   183        self.assertIs(env.get_cache_manager(dummy_pipeline), cache_manager)
   184  
   185    def test_cleanup_invoked_when_not_none_cm_changed(self):
   186      env = ie.InteractiveEnvironment()
   187      with patch('apache_beam.runners.interactive.interactive_environment'
   188                 '.InteractiveEnvironment.cleanup') as mocked_cleanup:
   189        dummy_pipeline = 'dummy'
   190        env.set_cache_manager(cache.FileBasedCacheManager(), dummy_pipeline)
   191        mocked_cleanup.assert_not_called()
   192        env.set_cache_manager(cache.FileBasedCacheManager(), dummy_pipeline)
   193        mocked_cleanup.assert_called_once()
   194  
   195    def test_noop_when_cm_is_not_changed(self):
   196      cache_manager = cache.FileBasedCacheManager()
   197      dummy_pipeline = 'dummy'
   198      env = ie.InteractiveEnvironment()
   199      with patch('apache_beam.runners.interactive.interactive_environment'
   200                 '.InteractiveEnvironment.cleanup') as mocked_cleanup:
   201        env._cache_managers[str(id(dummy_pipeline))] = cache_manager
   202        mocked_cleanup.assert_not_called()
   203        env.set_cache_manager(cache_manager, dummy_pipeline)
   204        mocked_cleanup.assert_not_called()
   205  
   206    def test_get_cache_manager_creates_cache_manager_if_absent(self):
   207      env = ie.InteractiveEnvironment()
   208      dummy_pipeline = beam.Pipeline()
   209      self.assertIsNone(env.get_cache_manager(dummy_pipeline))
   210      self.assertIsNotNone(
   211          env.get_cache_manager(dummy_pipeline, create_if_absent=True))
   212  
   213    def test_track_user_pipeline_cleanup_non_inspectable_pipeline(self):
   214      dummy_pipeline_1 = beam.Pipeline()
   215      dummy_pipeline_2 = beam.Pipeline()
   216      dummy_pipeline_3 = beam.Pipeline()
   217      dummy_pipeline_4 = beam.Pipeline()
   218      dummy_pcoll = dummy_pipeline_4 | beam.Create([1])
   219      dummy_pipeline_5 = beam.Pipeline()
   220      dummy_non_inspectable_pipeline = 'dummy'
   221      ie.current_env().watch(locals())
   222      from apache_beam.runners.interactive.background_caching_job import BackgroundCachingJob
   223      ie.current_env().set_background_caching_job(
   224          dummy_pipeline_1,
   225          BackgroundCachingJob(
   226              runner.PipelineResult(runner.PipelineState.DONE), limiters=[]))
   227      ie.current_env().set_test_stream_service_controller(dummy_pipeline_2, None)
   228      ie.current_env().set_cache_manager(
   229          cache.FileBasedCacheManager(), dummy_pipeline_3)
   230      ie.current_env().mark_pcollection_computed([dummy_pcoll])
   231      ie.current_env().set_cached_source_signature(
   232          dummy_non_inspectable_pipeline, None)
   233      ie.current_env().set_pipeline_result(
   234          dummy_pipeline_5, runner.PipelineResult(runner.PipelineState.RUNNING))
   235      with patch('apache_beam.runners.interactive.interactive_environment'
   236                 '.InteractiveEnvironment.cleanup') as mocked_cleanup:
   237        ie.current_env().track_user_pipelines()
   238        mocked_cleanup.assert_called_once()
   239  
   240    def test_evict_pcollections(self):
   241      """Tests the evicton logic in the InteractiveEnvironment."""
   242  
   243      # Create two PCollection, one that will be evicted and another that won't.
   244      p_to_evict = beam.Pipeline()
   245      to_evict = p_to_evict | beam.Create([])
   246  
   247      p_not_evicted = beam.Pipeline()
   248      not_evicted = p_not_evicted | beam.Create([])
   249  
   250      # Mark the PCollections as computed because the eviction logic only works
   251      # on computed PCollections.
   252      ie.current_env().mark_pcollection_computed([to_evict, not_evicted])
   253      self.assertSetEqual(
   254          ie.current_env().computed_pcollections, {to_evict, not_evicted})
   255  
   256      # Evict the PCollection and then check that the other PCollection is safe.
   257      ie.current_env().evict_computed_pcollections(p_to_evict)
   258      self.assertSetEqual(ie.current_env().computed_pcollections, {not_evicted})
   259  
   260    def test_set_get_recording_manager(self):
   261      p = beam.Pipeline()
   262      rm = RecordingManager(p)
   263      ie.current_env().set_recording_manager(rm, p)
   264      self.assertIs(rm, ie.current_env().get_recording_manager(p))
   265  
   266    def test_recording_manager_create_if_absent(self):
   267      p = beam.Pipeline()
   268      self.assertFalse(ie.current_env().get_recording_manager(p))
   269      self.assertTrue(
   270          ie.current_env().get_recording_manager(p, create_if_absent=True))
   271  
   272    def test_evict_recording_manager(self):
   273      p = beam.Pipeline()
   274      self.assertFalse(ie.current_env().get_recording_manager(p))
   275      self.assertTrue(
   276          ie.current_env().get_recording_manager(p, create_if_absent=True))
   277  
   278    def test_describe_all_recordings(self):
   279      self.assertFalse(ie.current_env().describe_all_recordings())
   280  
   281      p1 = beam.Pipeline()
   282      p2 = beam.Pipeline()
   283      ie.current_env().watch(locals())
   284      ie.current_env().track_user_pipelines()
   285      rm1 = ie.current_env().get_recording_manager(p1, create_if_absent=True)
   286      rm2 = ie.current_env().get_recording_manager(p2, create_if_absent=True)
   287  
   288      description = ie.current_env().describe_all_recordings()
   289      self.assertTrue(description)
   290  
   291      expected_description = {p1: rm1.describe(), p2: rm2.describe()}
   292      self.assertDictEqual(description, expected_description)
   293  
   294    def test_get_empty_sql_chain(self):
   295      env = ie.InteractiveEnvironment()
   296      p = beam.Pipeline()
   297      chain = env.get_sql_chain(p)
   298      self.assertIsNotNone(chain)
   299      self.assertEqual(chain.nodes, {})
   300  
   301    def test_get_sql_chain_with_nodes(self):
   302      env = ie.InteractiveEnvironment()
   303      p = beam.Pipeline()
   304      chain_with_node = env.get_sql_chain(p).append(
   305          SqlNode(output_name='name', source=p, query="query"))
   306      chain_got = env.get_sql_chain(p)
   307      self.assertIs(chain_with_node, chain_got)
   308  
   309    def test_get_sql_chain_setting_user_pipeline(self):
   310      env = ie.InteractiveEnvironment()
   311      p = beam.Pipeline()
   312      chain = env.get_sql_chain(p, set_user_pipeline=True)
   313      self.assertIs(chain.user_pipeline, p)
   314  
   315    def test_get_sql_chain_None_when_setting_multiple_user_pipelines(self):
   316      env = ie.InteractiveEnvironment()
   317      p = beam.Pipeline()
   318      chain = env.get_sql_chain(p, set_user_pipeline=True)
   319      p2 = beam.Pipeline()
   320      # Set the chain for a different pipeline.
   321      env.sql_chain[p2] = chain
   322      with self.assertRaises(ValueError):
   323        env.get_sql_chain(p2, set_user_pipeline=True)
   324  
   325    @patch(
   326        'apache_beam.runners.interactive.interactive_environment.'
   327        'assert_bucket_exists',
   328        return_value=None)
   329    def test_get_gcs_cache_dir_valid_path(self, mock_assert_bucket_exists):
   330      env = ie.InteractiveEnvironment()
   331      p = beam.Pipeline()
   332      cache_root = 'gs://test-cache-dir/'
   333      actual_cache_dir = env._get_gcs_cache_dir(p, cache_root)
   334      expected_cache_dir = 'gs://test-cache-dir/{}'.format(id(p))
   335      self.assertEqual(actual_cache_dir, expected_cache_dir)
   336  
   337    def test_get_gcs_cache_dir_invalid_path(self):
   338      env = ie.InteractiveEnvironment()
   339      p = beam.Pipeline()
   340      cache_root = 'gs://'
   341      with self.assertRaises(ValueError):
   342        env._get_gcs_cache_dir(p, cache_root)
   343  
   344  
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()