github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/direct/direct_runner_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import threading
    21  import unittest
    22  from collections import defaultdict
    23  
    24  import hamcrest as hc
    25  
    26  import apache_beam as beam
    27  from apache_beam.metrics.cells import DistributionData
    28  from apache_beam.metrics.cells import DistributionResult
    29  from apache_beam.metrics.execution import MetricKey
    30  from apache_beam.metrics.execution import MetricResult
    31  from apache_beam.metrics.metric import Metrics
    32  from apache_beam.metrics.metric import MetricsFilter
    33  from apache_beam.metrics.metricbase import MetricName
    34  from apache_beam.pipeline import Pipeline
    35  from apache_beam.runners import DirectRunner
    36  from apache_beam.runners import TestDirectRunner
    37  from apache_beam.runners import create_runner
    38  from apache_beam.runners.direct.evaluation_context import _ExecutionContext
    39  from apache_beam.runners.direct.transform_evaluator import _GroupByKeyOnlyEvaluator
    40  from apache_beam.runners.direct.transform_evaluator import _TransformEvaluator
    41  from apache_beam.testing import test_pipeline
    42  from apache_beam.testing.util import assert_that
    43  from apache_beam.testing.util import equal_to
    44  
    45  
    46  class DirectPipelineResultTest(unittest.TestCase):
    47    def test_waiting_on_result_stops_executor_threads(self):
    48      pre_test_threads = set(t.ident for t in threading.enumerate())
    49  
    50      for runner in ['DirectRunner',
    51                     'BundleBasedDirectRunner',
    52                     'SwitchingDirectRunner']:
    53        pipeline = test_pipeline.TestPipeline(runner=runner)
    54        _ = (pipeline | beam.Create([{'foo': 'bar'}]))
    55        result = pipeline.run()
    56        result.wait_until_finish()
    57  
    58        post_test_threads = set(t.ident for t in threading.enumerate())
    59        new_threads = post_test_threads - pre_test_threads
    60        self.assertEqual(len(new_threads), 0)
    61  
    62    def test_direct_runner_metrics(self):
    63      class MyDoFn(beam.DoFn):
    64        def start_bundle(self):
    65          count = Metrics.counter(self.__class__, 'bundles')
    66          count.inc()
    67  
    68        def finish_bundle(self):
    69          count = Metrics.counter(self.__class__, 'finished_bundles')
    70          count.inc()
    71  
    72        def process(self, element):
    73          gauge = Metrics.gauge(self.__class__, 'latest_element')
    74          gauge.set(element)
    75          count = Metrics.counter(self.__class__, 'elements')
    76          count.inc()
    77          distro = Metrics.distribution(self.__class__, 'element_dist')
    78          distro.update(element)
    79          return [element]
    80  
    81      p = Pipeline(DirectRunner())
    82      pcoll = (
    83          p | beam.Create([1, 2, 3, 4, 5], reshuffle=False)
    84          | 'Do' >> beam.ParDo(MyDoFn()))
    85      assert_that(pcoll, equal_to([1, 2, 3, 4, 5]))
    86      result = p.run()
    87      result.wait_until_finish()
    88      metrics = result.metrics().query(MetricsFilter().with_step('Do'))
    89      namespace = '{}.{}'.format(MyDoFn.__module__, MyDoFn.__name__)
    90  
    91      hc.assert_that(
    92          metrics['counters'],
    93          hc.contains_inanyorder(
    94              MetricResult(
    95                  MetricKey('Do', MetricName(namespace, 'elements')), 5, 5),
    96              MetricResult(
    97                  MetricKey('Do', MetricName(namespace, 'bundles')), 1, 1),
    98              MetricResult(
    99                  MetricKey('Do', MetricName(namespace, 'finished_bundles')),
   100                  1,
   101                  1)))
   102  
   103      hc.assert_that(
   104          metrics['distributions'],
   105          hc.contains_inanyorder(
   106              MetricResult(
   107                  MetricKey('Do', MetricName(namespace, 'element_dist')),
   108                  DistributionResult(DistributionData(15, 5, 1, 5)),
   109                  DistributionResult(DistributionData(15, 5, 1, 5)))))
   110  
   111      gauge_result = metrics['gauges'][0]
   112      hc.assert_that(
   113          gauge_result.key,
   114          hc.equal_to(MetricKey('Do', MetricName(namespace, 'latest_element'))))
   115      hc.assert_that(gauge_result.committed.value, hc.equal_to(5))
   116      hc.assert_that(gauge_result.attempted.value, hc.equal_to(5))
   117  
   118    def test_create_runner(self):
   119      self.assertTrue(isinstance(create_runner('DirectRunner'), DirectRunner))
   120      self.assertTrue(
   121          isinstance(create_runner('TestDirectRunner'), TestDirectRunner))
   122  
   123  
   124  class BundleBasedRunnerTest(unittest.TestCase):
   125    def test_type_hints(self):
   126      with test_pipeline.TestPipeline(runner='BundleBasedDirectRunner') as p:
   127        _ = (
   128            p
   129            | beam.Create([[]]).with_output_types(beam.typehints.List[int])
   130            | beam.combiners.Count.Globally())
   131  
   132    def test_impulse(self):
   133      with test_pipeline.TestPipeline(runner='BundleBasedDirectRunner') as p:
   134        assert_that(p | beam.Impulse(), equal_to([b'']))
   135  
   136  
   137  class DirectRunnerRetryTests(unittest.TestCase):
   138    def test_retry_fork_graph(self):
   139      # TODO(https://github.com/apache/beam/issues/18640): The FnApiRunner
   140      # currently does not currently support retries.
   141      p = beam.Pipeline(runner='BundleBasedDirectRunner')
   142  
   143      # TODO(mariagh): Remove the use of globals from the test.
   144      global count_b, count_c  # pylint: disable=global-variable-undefined
   145      count_b, count_c = 0, 0
   146  
   147      def f_b(x):
   148        global count_b  # pylint: disable=global-variable-undefined
   149        count_b += 1
   150        raise Exception('exception in f_b')
   151  
   152      def f_c(x):
   153        global count_c  # pylint: disable=global-variable-undefined
   154        count_c += 1
   155        raise Exception('exception in f_c')
   156  
   157      names = p | 'CreateNodeA' >> beam.Create(['Ann', 'Joe'])
   158  
   159      fork_b = names | 'SendToB' >> beam.Map(f_b)  # pylint: disable=unused-variable
   160      fork_c = names | 'SendToC' >> beam.Map(f_c)  # pylint: disable=unused-variable
   161  
   162      with self.assertRaises(Exception):
   163        p.run().wait_until_finish()
   164      assert count_b == count_c == 4
   165  
   166    def test_no_partial_writeouts(self):
   167      class TestTransformEvaluator(_TransformEvaluator):
   168        def __init__(self):
   169          self._execution_context = _ExecutionContext(None, {})
   170  
   171        def start_bundle(self):
   172          self.step_context = self._execution_context.get_step_context()
   173  
   174        def process_element(self, element):
   175          k, v = element
   176          state = self.step_context.get_keyed_state(k)
   177          state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v)
   178  
   179      # Create instance and add key/value, key/value2
   180      evaluator = TestTransformEvaluator()
   181      evaluator.start_bundle()
   182      self.assertIsNone(evaluator.step_context.existing_keyed_state.get('key'))
   183      self.assertIsNone(evaluator.step_context.partial_keyed_state.get('key'))
   184  
   185      evaluator.process_element(['key', 'value'])
   186      self.assertEqual(
   187          evaluator.step_context.existing_keyed_state['key'].state,
   188          defaultdict(lambda: defaultdict(list)))
   189      self.assertEqual(
   190          evaluator.step_context.partial_keyed_state['key'].state,
   191          {None: {
   192              'elements': ['value']
   193          }})
   194  
   195      evaluator.process_element(['key', 'value2'])
   196      self.assertEqual(
   197          evaluator.step_context.existing_keyed_state['key'].state,
   198          defaultdict(lambda: defaultdict(list)))
   199      self.assertEqual(
   200          evaluator.step_context.partial_keyed_state['key'].state,
   201          {None: {
   202              'elements': ['value', 'value2']
   203          }})
   204  
   205      # Simulate an exception (redo key/value)
   206      evaluator._execution_context.reset()
   207      evaluator.start_bundle()
   208      evaluator.process_element(['key', 'value'])
   209      self.assertEqual(
   210          evaluator.step_context.existing_keyed_state['key'].state,
   211          defaultdict(lambda: defaultdict(list)))
   212      self.assertEqual(
   213          evaluator.step_context.partial_keyed_state['key'].state,
   214          {None: {
   215              'elements': ['value']
   216          }})
   217  
   218  
if __name__ == '__main__':
  # Allow running this test module directly with the stdlib unittest runner.
  unittest.main()