#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the direct runners and their pipeline results."""

# pytype: skip-file

import threading
import unittest
from collections import defaultdict

import hamcrest as hc

import apache_beam as beam
from apache_beam.metrics.cells import DistributionData
from apache_beam.metrics.cells import DistributionResult
from apache_beam.metrics.execution import MetricKey
from apache_beam.metrics.execution import MetricResult
from apache_beam.metrics.metric import Metrics
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.metrics.metricbase import MetricName
from apache_beam.pipeline import Pipeline
from apache_beam.runners import DirectRunner
from apache_beam.runners import TestDirectRunner
from apache_beam.runners import create_runner
from apache_beam.runners.direct.evaluation_context import _ExecutionContext
from apache_beam.runners.direct.transform_evaluator import _GroupByKeyOnlyEvaluator
from apache_beam.runners.direct.transform_evaluator import _TransformEvaluator
from apache_beam.testing import test_pipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to


class DirectPipelineResultTest(unittest.TestCase):
  """Tests for results, metrics and runner creation of the direct runners."""
  def test_waiting_on_result_stops_executor_threads(self):
    """No extra threads remain alive once wait_until_finish() has returned."""
    pre_test_threads = {t.ident for t in threading.enumerate()}

    for runner in ['DirectRunner',
                   'BundleBasedDirectRunner',
                   'SwitchingDirectRunner']:
      pipeline = test_pipeline.TestPipeline(runner=runner)
      _ = (pipeline | beam.Create([{'foo': 'bar'}]))
      result = pipeline.run()
      result.wait_until_finish()

      # Any thread that was not alive before the test started is a leak.
      post_test_threads = {t.ident for t in threading.enumerate()}
      new_threads = post_test_threads - pre_test_threads
      self.assertEqual(len(new_threads), 0)

  def test_direct_runner_metrics(self):
    """Counters, distributions and gauges set in a DoFn can be queried."""
    class MyDoFn(beam.DoFn):
      """Passes elements through while updating one metric of each kind."""
      def start_bundle(self):
        count = Metrics.counter(self.__class__, 'bundles')
        count.inc()

      def finish_bundle(self):
        count = Metrics.counter(self.__class__, 'finished_bundles')
        count.inc()

      def process(self, element):
        gauge = Metrics.gauge(self.__class__, 'latest_element')
        gauge.set(element)
        count = Metrics.counter(self.__class__, 'elements')
        count.inc()
        distro = Metrics.distribution(self.__class__, 'element_dist')
        distro.update(element)
        return [element]

    p = Pipeline(DirectRunner())
    pcoll = (
        p | beam.Create([1, 2, 3, 4, 5], reshuffle=False)
        | 'Do' >> beam.ParDo(MyDoFn()))
    assert_that(pcoll, equal_to([1, 2, 3, 4, 5]))
    result = p.run()
    result.wait_until_finish()
    # Only metrics reported by the 'Do' step are inspected below.
    metrics = result.metrics().query(MetricsFilter().with_step('Do'))
    namespace = '{}.{}'.format(MyDoFn.__module__, MyDoFn.__name__)

    hc.assert_that(
        metrics['counters'],
        hc.contains_inanyorder(
            MetricResult(
                MetricKey('Do', MetricName(namespace, 'elements')), 5, 5),
            MetricResult(
                MetricKey('Do', MetricName(namespace, 'bundles')), 1, 1),
            MetricResult(
                MetricKey('Do', MetricName(namespace, 'finished_bundles')),
                1,
                1)))

    hc.assert_that(
        metrics['distributions'],
        hc.contains_inanyorder(
            MetricResult(
                MetricKey('Do', MetricName(namespace, 'element_dist')),
                DistributionResult(DistributionData(15, 5, 1, 5)),
                DistributionResult(DistributionData(15, 5, 1, 5)))))

    gauge_result = metrics['gauges'][0]
    hc.assert_that(
        gauge_result.key,
        hc.equal_to(MetricKey('Do', MetricName(namespace, 'latest_element'))))
    hc.assert_that(gauge_result.committed.value, hc.equal_to(5))
    hc.assert_that(gauge_result.attempted.value, hc.equal_to(5))

  def test_create_runner(self):
    """create_runner() maps a runner name to an instance of that runner."""
    self.assertIsInstance(create_runner('DirectRunner'), DirectRunner)
    self.assertIsInstance(create_runner('TestDirectRunner'), TestDirectRunner)


class BundleBasedRunnerTest(unittest.TestCase):
  """Tests that run pipelines on the BundleBasedDirectRunner."""
  def test_type_hints(self):
    """A pipeline with explicit output type hints runs without raising."""
    with test_pipeline.TestPipeline(runner='BundleBasedDirectRunner') as p:
      _ = (
          p
          | beam.Create([[]]).with_output_types(beam.typehints.List[int])
          | beam.combiners.Count.Globally())

  def test_impulse(self):
    """Impulse produces a single empty bytes element."""
    with test_pipeline.TestPipeline(runner='BundleBasedDirectRunner') as p:
      assert_that(p | beam.Impulse(), equal_to([b'']))


class DirectRunnerRetryTests(unittest.TestCase):
  """Tests for retry behaviour and keyed state of the bundle-based runner."""
  def test_retry_fork_graph(self):
    """Both failing branches of a forked graph are invoked four times each."""
    # TODO(https://github.com/apache/beam/issues/18640): The FnApiRunner
    # does not currently support retries.
    p = beam.Pipeline(runner='BundleBasedDirectRunner')

    # TODO(mariagh): Remove the use of globals from the test.
    global count_b, count_c  # pylint: disable=global-variable-undefined
    count_b, count_c = 0, 0

    def f_b(x):
      global count_b  # pylint: disable=global-variable-undefined
      count_b += 1
      raise Exception('exception in f_b')

    def f_c(x):
      global count_c  # pylint: disable=global-variable-undefined
      count_c += 1
      raise Exception('exception in f_c')

    names = p | 'CreateNodeA' >> beam.Create(['Ann', 'Joe'])

    # The two branches only need to exist in the graph; their outputs are
    # never consumed.
    _ = names | 'SendToB' >> beam.Map(f_b)
    _ = names | 'SendToC' >> beam.Map(f_c)

    with self.assertRaises(Exception):
      p.run().wait_until_finish()
    # unittest assertions (not a bare assert) so the check survives `-O` and
    # a failure reports which counter is wrong.
    self.assertEqual(count_b, 4)
    self.assertEqual(count_c, 4)

  def test_no_partial_writeouts(self):
    """State added in a bundle stays partial and is dropped by reset()."""
    class TestTransformEvaluator(_TransformEvaluator):
      """Minimal evaluator that appends each value to its key's state."""
      def __init__(self):
        # The base initializer is not called; only the execution context
        # used by the methods below is set up.
        self._execution_context = _ExecutionContext(None, {})

      def start_bundle(self):
        self.step_context = self._execution_context.get_step_context()

      def process_element(self, element):
        k, v = element
        state = self.step_context.get_keyed_state(k)
        state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v)

    # Create instance and add key/value, key/value2
    evaluator = TestTransformEvaluator()
    evaluator.start_bundle()
    self.assertIsNone(evaluator.step_context.existing_keyed_state.get('key'))
    self.assertIsNone(evaluator.step_context.partial_keyed_state.get('key'))

    evaluator.process_element(['key', 'value'])
    self.assertEqual(
        evaluator.step_context.existing_keyed_state['key'].state,
        defaultdict(lambda: defaultdict(list)))
    self.assertEqual(
        evaluator.step_context.partial_keyed_state['key'].state,
        {None: {
            'elements': ['value']
        }})

    evaluator.process_element(['key', 'value2'])
    self.assertEqual(
        evaluator.step_context.existing_keyed_state['key'].state,
        defaultdict(lambda: defaultdict(list)))
    self.assertEqual(
        evaluator.step_context.partial_keyed_state['key'].state,
        {None: {
            'elements': ['value', 'value2']
        }})

    # Simulate an exception (redo key/value)
    evaluator._execution_context.reset()
    evaluator.start_bundle()
    evaluator.process_element(['key', 'value'])
    self.assertEqual(
        evaluator.step_context.existing_keyed_state['key'].state,
        defaultdict(lambda: defaultdict(list)))
    self.assertEqual(
        evaluator.step_context.partial_keyed_state['key'].state,
        {None: {
            'elements': ['value']
        }})


if __name__ == '__main__':
  unittest.main()