#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""ValidatesRunner tests for CombineFn lifecycle and bundle methods."""

# pytype: skip-file

import unittest
from functools import wraps

import pytest
from parameterized import parameterized_class

from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.runners.direct import direct_runner
from apache_beam.runners.portability import fn_api_runner
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.transforms.combinefn_lifecycle_pipeline import CallSequenceEnforcingCombineFn
from apache_beam.transforms.combinefn_lifecycle_pipeline import run_combine
from apache_beam.transforms.combinefn_lifecycle_pipeline import run_pardo


def _is_dataflow_runner(options):
  """Return True if the given pipeline options select a Dataflow runner.

  ``StandardOptions.runner`` may be ``None`` (no runner configured) or a
  runner object rather than a runner name. Both cases are handled here
  instead of letting ``'DataflowRunner' in runner`` raise ``TypeError``.
  """
  runner = options.view_as(StandardOptions).runner
  if runner is None:
    return False
  runner_name = runner if isinstance(runner, str) else type(runner).__name__
  return 'DataflowRunner' in runner_name


def skip_unless_v2(fn):
  """Decorate a test method so it is skipped on Dataflow without Runner V2.

  The wrapped method's ``self`` must expose a ``pipeline`` attribute. If that
  pipeline's options select a Dataflow runner and the ``use_runner_v2``
  experiment is not enabled, the test is skipped, because CombineFn.setup
  and CombineFn.teardown are not supported in that configuration. Otherwise
  the test runs normally.
  """
  @wraps(fn)
  def wrapped(self, *args, **kwargs):
    options = self.pipeline.get_pipeline_options()
    experiments = options.view_as(DebugOptions).experiments or []

    if _is_dataflow_runner(options) and 'use_runner_v2' not in experiments:
      # skipTest raises unittest.SkipTest, so fn is never reached here.
      self.skipTest(
          'CombineFn.setup and CombineFn.teardown are not supported. '
          'Please use Dataflow Runner V2.')
    return fn(self, *args, **kwargs)

  return wrapped


@pytest.mark.it_validatesrunner
class CombineFnLifecycleTest(unittest.TestCase):
  """Integration tests running the CombineFn lifecycle pipelines."""
  def setUp(self):
    self.pipeline = TestPipeline(is_integration_test=True)

  @skip_unless_v2
  def test_combine(self):
    run_combine(self.pipeline)

  @skip_unless_v2
  def test_non_liftable_combine(self):
    run_combine(self.pipeline, lift_combiners=False)

  @skip_unless_v2
  def test_combining_value_state(self):
    # Skipped on Dataflow altogether; see the linked issue.
    if _is_dataflow_runner(self.pipeline.get_pipeline_options()):
      self.skipTest('https://github.com/apache/beam/issues/20722')
    run_pardo(self.pipeline)


@parameterized_class([
    {'runner': direct_runner.BundleBasedDirectRunner},
    {'runner': fn_api_runner.FnApiRunner},
])  # yapf: disable
class LocalCombineFnLifecycleTest(unittest.TestCase):
  """Runs the lifecycle pipelines on local runners and checks teardown.

  ``parameterized_class`` generates one subclass per entry, each with a
  ``runner`` class attribute that is instantiated for every pipeline.
  """
  def tearDown(self):
    # Reset the shared registry so instances do not leak between tests.
    CallSequenceEnforcingCombineFn.instances.clear()

  def test_combine(self):
    run_combine(TestPipeline(runner=self.runner()))
    self._assert_teardown_called()

  def test_non_liftable_combine(self):
    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
    run_combine(
        TestPipeline(runner=self.runner(), options=test_options),
        lift_combiners=False)
    self._assert_teardown_called()

  def test_combining_value_state(self):
    run_pardo(TestPipeline(runner=self.runner()))
    self._assert_teardown_called()

  def _assert_teardown_called(self):
    """Ensures that teardown has been invoked for all CombineFns."""
    for instance in CallSequenceEnforcingCombineFn.instances:
      self.assertTrue(
          instance._teardown_called,
          'teardown was not called on %r' % (instance, ))


if __name__ == '__main__':
  unittest.main()