# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/interactive_environment_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for apache_beam.runners.interactive.interactive_environment."""
# pytype: skip-file

import importlib
import unittest
from unittest.mock import patch

import apache_beam as beam
from apache_beam.runners import runner
from apache_beam.runners.interactive import cache_manager as cache
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.recording_manager import RecordingManager
from apache_beam.runners.interactive.sql.sql_chain import SqlNode
from apache_beam.runners.interactive.testing.mock_env import isolated_env

# The module name is also a variable in module.
_module_name = 'apache_beam.runners.interactive.interactive_environment_test'


@isolated_env
class InteractiveEnvironmentTest(unittest.TestCase):
  """Unit tests for ``InteractiveEnvironment`` as exposed by module ``ie``.

  NOTE(review): ``isolated_env`` presumably installs a fresh environment for
  every test and exposes it as ``self.current_env`` (that attribute is read in
  ``test_cleanup_invoked_when_new_env_replace_not_none_env`` but never assigned
  in this class) -- confirm against ``testing.mock_env``.
  """
  def setUp(self):
    """Creates the pipeline and the instance attribute used by the tests."""
    self._p = beam.Pipeline()
    self._var_in_class_instance = 'a var in class instance'

  def assertVariableWatched(self, variable_name, variable_val):
    """Asserts the ``(name, value)`` pair is found by `_is_variable_watched`."""
    self.assertTrue(self._is_variable_watched(variable_name, variable_val))

  def assertVariableNotWatched(self, variable_name, variable_val):
    """Asserts the ``(name, value)`` pair is not found by the environment."""
    self.assertFalse(self._is_variable_watched(variable_name, variable_val))

  def _is_variable_watched(self, variable_name, variable_val):
    """Returns True if the pair is in any item yielded by ``watching()``."""
    return any((variable_name, variable_val) in watching
               for watching in ie.current_env().watching())

  def _a_function_with_local_watched(self):
    """Defines a local variable and registers ``locals()`` for watching."""
    local_var_watched = 123  # pylint: disable=possibly-unused-variable
    ie.current_env().watch(locals())

  def _a_function_not_watching_local(self):
    """Defines a local variable without registering it for watching."""
    local_var_not_watched = 456  # pylint: disable=unused-variable

  def test_watch_main_by_default(self):
    """``__main__`` is in the watching set without an explicit ``watch``."""
    self.assertIn('__main__', ie.current_env()._watching_set)
    # __main__ module has variable __name__ with value '__main__'
    self.assertVariableWatched('__name__', '__main__')

  def test_watch_a_module_by_name(self):
    """Watching a module by its dotted name exposes its variables."""
    self.assertNotIn(_module_name, ie.current_env()._watching_set)
    self.assertVariableNotWatched('_module_name', _module_name)
    ie.current_env().watch(_module_name)
    self.assertIn(_module_name, ie.current_env()._watching_set)
    self.assertVariableWatched('_module_name', _module_name)

  def test_watch_a_module_by_module_object(self):
    """Watching an imported module object exposes its variables."""
    module = importlib.import_module(_module_name)
    self.assertNotIn(module, ie.current_env()._watching_set)
    self.assertVariableNotWatched('_module_name', _module_name)
    ie.current_env().watch(module)
    self.assertIn(module, ie.current_env()._watching_set)
    self.assertVariableWatched('_module_name', _module_name)

  def test_watch_locals(self):
    """Only locals passed to ``watch`` become watched."""
    self.assertVariableNotWatched('local_var_watched', 123)
    self.assertVariableNotWatched('local_var_not_watched', 456)
    self._a_function_with_local_watched()
    self.assertVariableWatched('local_var_watched', 123)
    self._a_function_not_watching_local()
    self.assertVariableNotWatched('local_var_not_watched', 456)

  def test_watch_class_instance(self):
    """Watching an object exposes its instance attributes."""
    self.assertVariableNotWatched(
        '_var_in_class_instance', self._var_in_class_instance)
    ie.current_env().watch(self)
    self.assertVariableWatched(
        '_var_in_class_instance', self._var_in_class_instance)

  def test_fail_to_set_pipeline_result_key_not_pipeline(self):
    """``set_pipeline_result`` rejects a key that is not a pipeline."""
    class NotPipeline(object):
      pass

    with self.assertRaises(AssertionError) as ctx:
      ie.current_env().set_pipeline_result(
          NotPipeline(), runner.PipelineResult(runner.PipelineState.RUNNING))
    # The message is checked after the ``with`` block: a statement placed
    # behind the raising call inside the block never executes. ``str(...)`` is
    # required because an exception instance does not support ``in``.
    self.assertIn(
        'pipeline must be an instance of apache_beam.Pipeline '
        'or its subclass',
        str(ctx.exception))

  def test_fail_to_set_pipeline_result_value_not_pipeline_result(self):
    """``set_pipeline_result`` rejects a value that is not a result."""
    class NotResult(object):
      pass

    with self.assertRaises(AssertionError) as ctx:
      ie.current_env().set_pipeline_result(self._p, NotResult())
    # See the sibling test above: assert on the message text once the context
    # manager has captured the exception.
    self.assertIn(
        'result must be an instance of '
        'apache_beam.runners.runner.PipelineResult or its '
        'subclass',
        str(ctx.exception))

  def test_set_pipeline_result_successfully(self):
    """Subclasses of ``Pipeline`` / ``PipelineResult`` are accepted."""
    class PipelineSubClass(beam.Pipeline):
      pass

    class PipelineResultSubClass(runner.PipelineResult):
      pass

    pipeline = PipelineSubClass()
    pipeline_result = PipelineResultSubClass(runner.PipelineState.RUNNING)
    ie.current_env().set_pipeline_result(pipeline, pipeline_result)
    self.assertIs(ie.current_env().pipeline_result(pipeline), pipeline_result)

  def test_determine_terminal_state(self):
    """``is_terminated`` reports True only for the terminal states below."""
    for state in (runner.PipelineState.DONE,
                  runner.PipelineState.FAILED,
                  runner.PipelineState.CANCELLED,
                  runner.PipelineState.UPDATED,
                  runner.PipelineState.DRAINED):
      ie.current_env().set_pipeline_result(
          self._p, runner.PipelineResult(state))
      self.assertTrue(ie.current_env().is_terminated(self._p))
    for state in (runner.PipelineState.UNKNOWN,
                  runner.PipelineState.STARTING,
                  runner.PipelineState.STOPPED,
                  runner.PipelineState.RUNNING,
                  runner.PipelineState.DRAINING,
                  runner.PipelineState.PENDING,
                  runner.PipelineState.CANCELLING,
                  runner.PipelineState.UNRECOGNIZED):
      ie.current_env().set_pipeline_result(
          self._p, runner.PipelineResult(state))
      self.assertFalse(ie.current_env().is_terminated(self._p))

  def test_evict_pipeline_result(self):
    """Evicting returns the stored result and removes it."""
    pipeline_result = runner.PipelineResult(runner.PipelineState.DONE)
    ie.current_env().set_pipeline_result(self._p, pipeline_result)
    self.assertIs(
        ie.current_env().evict_pipeline_result(self._p), pipeline_result)
    self.assertIsNone(ie.current_env().pipeline_result(self._p))

  def test_pipeline_result_is_none_when_pipeline_absent(self):
    """An unknown pipeline has no result and counts as terminated."""
    self.assertIsNone(ie.current_env().pipeline_result(self._p))
    self.assertIs(ie.current_env().is_terminated(self._p), True)
    self.assertIsNone(ie.current_env().evict_pipeline_result(self._p))

  def test_cleanup_registered_when_creating_new_env(self):
    """Constructing an environment calls ``atexit.register`` exactly once."""
    with patch('atexit.register') as mocked_atexit:
      _ = ie.InteractiveEnvironment()
      mocked_atexit.assert_called_once()

  def test_cleanup_invoked_when_new_env_replace_not_none_env(self):
    """``new_env`` calls ``cleanup`` once when an environment is already set."""
    ie._interactive_beam_env = self.current_env
    with patch('apache_beam.runners.interactive.interactive_environment'
               '.InteractiveEnvironment.cleanup') as mocked_cleanup:
      ie.new_env()
      mocked_cleanup.assert_called_once()

  def test_cleanup_not_invoked_when_cm_changed_from_none(self):
    """Setting the first cache manager for a key does not call ``cleanup``."""
    env = ie.InteractiveEnvironment()
    with patch('apache_beam.runners.interactive.interactive_environment'
               '.InteractiveEnvironment.cleanup') as mocked_cleanup:
      dummy_pipeline = 'dummy'
      self.assertIsNone(env.get_cache_manager(dummy_pipeline))
      cache_manager = cache.FileBasedCacheManager()
      env.set_cache_manager(cache_manager, dummy_pipeline)
      mocked_cleanup.assert_not_called()
      self.assertIs(env.get_cache_manager(dummy_pipeline), cache_manager)

  def test_cleanup_invoked_when_not_none_cm_changed(self):
    """Replacing an existing cache manager calls ``cleanup`` once."""
    env = ie.InteractiveEnvironment()
    with patch('apache_beam.runners.interactive.interactive_environment'
               '.InteractiveEnvironment.cleanup') as mocked_cleanup:
      dummy_pipeline = 'dummy'
      env.set_cache_manager(cache.FileBasedCacheManager(), dummy_pipeline)
      mocked_cleanup.assert_not_called()
      env.set_cache_manager(cache.FileBasedCacheManager(), dummy_pipeline)
      mocked_cleanup.assert_called_once()

  def test_noop_when_cm_is_not_changed(self):
    """Re-setting the very same cache manager does not call ``cleanup``."""
    cache_manager = cache.FileBasedCacheManager()
    dummy_pipeline = 'dummy'
    env = ie.InteractiveEnvironment()
    with patch('apache_beam.runners.interactive.interactive_environment'
               '.InteractiveEnvironment.cleanup') as mocked_cleanup:
      # Seed the internal mapping directly so that the following
      # ``set_cache_manager`` call sees an unchanged manager.
      env._cache_managers[str(id(dummy_pipeline))] = cache_manager
      mocked_cleanup.assert_not_called()
      env.set_cache_manager(cache_manager, dummy_pipeline)
      mocked_cleanup.assert_not_called()

  def test_get_cache_manager_creates_cache_manager_if_absent(self):
    """``create_if_absent=True`` yields a manager where none existed."""
    env = ie.InteractiveEnvironment()
    dummy_pipeline = beam.Pipeline()
    self.assertIsNone(env.get_cache_manager(dummy_pipeline))
    self.assertIsNotNone(
        env.get_cache_manager(dummy_pipeline, create_if_absent=True))

  def test_track_user_pipeline_cleanup_non_inspectable_pipeline(self):
    """``track_user_pipelines`` calls ``cleanup`` once for the state below.

    State is registered for five watched pipelines plus a cached source
    signature keyed by a plain string instead of a pipeline.
    """
    dummy_pipeline_1 = beam.Pipeline()
    dummy_pipeline_2 = beam.Pipeline()
    dummy_pipeline_3 = beam.Pipeline()
    dummy_pipeline_4 = beam.Pipeline()
    dummy_pcoll = dummy_pipeline_4 | beam.Create([1])
    dummy_pipeline_5 = beam.Pipeline()
    dummy_non_inspectable_pipeline = 'dummy'
    ie.current_env().watch(locals())
    # Imported only after ``watch(locals())``, exactly as in the original
    # statement order.
    from apache_beam.runners.interactive.background_caching_job import (
        BackgroundCachingJob)
    ie.current_env().set_background_caching_job(
        dummy_pipeline_1,
        BackgroundCachingJob(
            runner.PipelineResult(runner.PipelineState.DONE), limiters=[]))
    ie.current_env().set_test_stream_service_controller(dummy_pipeline_2, None)
    ie.current_env().set_cache_manager(
        cache.FileBasedCacheManager(), dummy_pipeline_3)
    ie.current_env().mark_pcollection_computed([dummy_pcoll])
    ie.current_env().set_cached_source_signature(
        dummy_non_inspectable_pipeline, None)
    ie.current_env().set_pipeline_result(
        dummy_pipeline_5, runner.PipelineResult(runner.PipelineState.RUNNING))
    with patch('apache_beam.runners.interactive.interactive_environment'
               '.InteractiveEnvironment.cleanup') as mocked_cleanup:
      ie.current_env().track_user_pipelines()
      mocked_cleanup.assert_called_once()

  def test_evict_pcollections(self):
    """Tests the eviction logic in the InteractiveEnvironment."""

    # Create two PCollection, one that will be evicted and another that won't.
    p_to_evict = beam.Pipeline()
    to_evict = p_to_evict | beam.Create([])

    p_not_evicted = beam.Pipeline()
    not_evicted = p_not_evicted | beam.Create([])

    # Mark the PCollections as computed because the eviction logic only works
    # on computed PCollections.
    ie.current_env().mark_pcollection_computed([to_evict, not_evicted])
    self.assertSetEqual(
        ie.current_env().computed_pcollections, {to_evict, not_evicted})

    # Evict the PCollection and then check that the other PCollection is safe.
    ie.current_env().evict_computed_pcollections(p_to_evict)
    self.assertSetEqual(ie.current_env().computed_pcollections, {not_evicted})

  def test_set_get_recording_manager(self):
    """A recording manager that was set is returned by the getter."""
    p = beam.Pipeline()
    rm = RecordingManager(p)
    ie.current_env().set_recording_manager(rm, p)
    self.assertIs(rm, ie.current_env().get_recording_manager(p))

  def test_recording_manager_create_if_absent(self):
    """``create_if_absent=True`` yields a manager where none existed."""
    p = beam.Pipeline()
    self.assertFalse(ie.current_env().get_recording_manager(p))
    self.assertTrue(
        ie.current_env().get_recording_manager(p, create_if_absent=True))

  def test_evict_recording_manager(self):
    """Currently only checks creation of a recording manager.

    NOTE(review): this body is identical to
    ``test_recording_manager_create_if_absent`` and never evicts anything --
    TODO exercise the environment's eviction API here.
    """
    p = beam.Pipeline()
    self.assertFalse(ie.current_env().get_recording_manager(p))
    self.assertTrue(
        ie.current_env().get_recording_manager(p, create_if_absent=True))

  def test_describe_all_recordings(self):
    """Descriptions are keyed by pipeline and match each manager's own."""
    self.assertFalse(ie.current_env().describe_all_recordings())

    p1 = beam.Pipeline()
    p2 = beam.Pipeline()
    ie.current_env().watch(locals())
    ie.current_env().track_user_pipelines()
    rm1 = ie.current_env().get_recording_manager(p1, create_if_absent=True)
    rm2 = ie.current_env().get_recording_manager(p2, create_if_absent=True)

    description = ie.current_env().describe_all_recordings()
    self.assertTrue(description)

    expected_description = {p1: rm1.describe(), p2: rm2.describe()}
    self.assertDictEqual(description, expected_description)

  def test_get_empty_sql_chain(self):
    """A chain fetched for a new pipeline exists and has no nodes."""
    env = ie.InteractiveEnvironment()
    p = beam.Pipeline()
    chain = env.get_sql_chain(p)
    self.assertIsNotNone(chain)
    self.assertEqual(chain.nodes, {})

  def test_get_sql_chain_with_nodes(self):
    """The same chain object is returned after a node was appended."""
    env = ie.InteractiveEnvironment()
    p = beam.Pipeline()
    chain_with_node = env.get_sql_chain(p).append(
        SqlNode(output_name='name', source=p, query="query"))
    chain_got = env.get_sql_chain(p)
    self.assertIs(chain_with_node, chain_got)

  def test_get_sql_chain_setting_user_pipeline(self):
    """``set_user_pipeline=True`` records the pipeline on the chain."""
    env = ie.InteractiveEnvironment()
    p = beam.Pipeline()
    chain = env.get_sql_chain(p, set_user_pipeline=True)
    self.assertIs(chain.user_pipeline, p)

  def test_get_sql_chain_None_when_setting_multiple_user_pipelines(self):
    """Setting a second user pipeline on one chain raises ``ValueError``."""
    env = ie.InteractiveEnvironment()
    p = beam.Pipeline()
    chain = env.get_sql_chain(p, set_user_pipeline=True)
    p2 = beam.Pipeline()
    # Set the chain for a different pipeline.
    env.sql_chain[p2] = chain
    with self.assertRaises(ValueError):
      env.get_sql_chain(p2, set_user_pipeline=True)

  @patch(
      'apache_beam.runners.interactive.interactive_environment.'
      'assert_bucket_exists',
      return_value=None)
  def test_get_gcs_cache_dir_valid_path(self, mock_assert_bucket_exists):
    """The cache dir is the cache root followed by ``id(pipeline)``.

    ``assert_bucket_exists`` is patched out for the duration of the test.
    """
    env = ie.InteractiveEnvironment()
    p = beam.Pipeline()
    cache_root = 'gs://test-cache-dir/'
    actual_cache_dir = env._get_gcs_cache_dir(p, cache_root)
    expected_cache_dir = 'gs://test-cache-dir/{}'.format(id(p))
    self.assertEqual(actual_cache_dir, expected_cache_dir)

  def test_get_gcs_cache_dir_invalid_path(self):
    """A cache root of just ``gs://`` raises ``ValueError``."""
    env = ie.InteractiveEnvironment()
    p = beam.Pipeline()
    cache_root = 'gs://'
    with self.assertRaises(ValueError):
      env._get_gcs_cache_dir(p, cache_root)


if __name__ == '__main__':
  unittest.main()