github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dataflow/dataflow_runner_test.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for the DataflowRunner class."""
    19  
    20  # pytype: skip-file
    21  
    22  import json
    23  import unittest
    24  from datetime import datetime
    25  from itertools import product
    26  
    27  import mock
    28  from parameterized import param
    29  from parameterized import parameterized
    30  
    31  import apache_beam as beam
    32  import apache_beam.transforms as ptransform
    33  from apache_beam.options.pipeline_options import DebugOptions
    34  from apache_beam.options.pipeline_options import GoogleCloudOptions
    35  from apache_beam.options.pipeline_options import PipelineOptions
    36  from apache_beam.pipeline import AppliedPTransform
    37  from apache_beam.pipeline import Pipeline
    38  from apache_beam.portability import common_urns
    39  from apache_beam.portability import python_urns
    40  from apache_beam.portability.api import beam_runner_api_pb2
    41  from apache_beam.pvalue import PCollection
    42  from apache_beam.runners import DataflowRunner
    43  from apache_beam.runners import TestDataflowRunner
    44  from apache_beam.runners import common
    45  from apache_beam.runners import create_runner
    46  from apache_beam.runners.dataflow.dataflow_runner import DataflowPipelineResult
    47  from apache_beam.runners.dataflow.dataflow_runner import DataflowRuntimeException
    48  from apache_beam.runners.dataflow.dataflow_runner import PropertyNames
    49  from apache_beam.runners.dataflow.dataflow_runner import _is_runner_v2
    50  from apache_beam.runners.dataflow.dataflow_runner import _is_runner_v2_disabled
    51  from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
    52  from apache_beam.runners.runner import PipelineState
    53  from apache_beam.testing.extra_assertions import ExtraAssertionsMixin
    54  from apache_beam.testing.test_pipeline import TestPipeline
    55  from apache_beam.transforms import combiners
    56  from apache_beam.transforms import environments
    57  from apache_beam.transforms import window
    58  from apache_beam.transforms.core import Windowing
    59  from apache_beam.transforms.display import DisplayDataItem
    60  from apache_beam.typehints import typehints
    61  
    62  # Protect against environments where the apitools library is not available.
    63  # pylint: disable=wrong-import-order, wrong-import-position
    64  try:
    65    from apache_beam.runners.dataflow.internal import apiclient
    66  except ImportError:
    67    apiclient = None  # type: ignore
    68  # pylint: enable=wrong-import-order, wrong-import-position
    69  
    70  
    71  # SpecialParDo and SpecialDoFn are used in test_remote_runner_display_data.
    72  # Due to https://github.com/apache/beam/issues/19848, these need to be declared
    73  # outside of the test method.
    74  # TODO: Should not subclass ParDo. Switch to PTransform as soon as
    75  # composite transforms support display data.
    76  class SpecialParDo(beam.ParDo):
    77    def __init__(self, fn, now):
    78      super().__init__(fn)
    79      self.fn = fn
    80      self.now = now
    81  
    82    # Expose the wrapped DoFn, this class, and the timestamp as display data.
    83    def display_data(self):
    84      return {
    85          'asubcomponent': self.fn, 'a_class': SpecialParDo, 'a_time': self.now
    86      }
    87  
    88  
    89  class SpecialDoFn(beam.DoFn):
    90    def display_data(self):
    91      return {'dofn_value': 42}
    92  
    93    def process(self):
    94      pass
    95  
    96  
    97  @unittest.skipIf(apiclient is None, 'GCP dependencies are not installed')
    98  class DataflowRunnerTest(unittest.TestCase, ExtraAssertionsMixin):
    99    def setUp(self):
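            # Options shared by the tests below. --dry_run=True builds the job but
            # skips submission, so tests can inspect the resulting job graph locally.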
   100      self.default_properties = [
   101          '--dataflow_endpoint=ignored',
   102          '--job_name=test-job',
   103          '--project=test-project',
   104          '--staging_location=ignored',
   105          '--temp_location=/dev/null',
   106          '--no_auth',
   107          '--dry_run=True',
   108          '--sdk_location=container'
   109      ]
   110  
   111    @mock.patch('time.sleep', return_value=None)
   112    def test_wait_until_finish(self, patched_time_sleep):
   113      values_enum = dataflow_api.Job.CurrentStateValueValuesEnum
   114  
   115      class MockDataflowRunner(object):
   116        def __init__(self, states):
   117          self.dataflow_client = mock.MagicMock()
   118          self.job = mock.MagicMock()
   119          self.job.currentState = values_enum.JOB_STATE_UNKNOWN
   120          self._states = states
   121          self._next_state_index = 0
   122  
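                # Each call to get_job advances to the next mocked state and holds at
                # the last one. mock.DEFAULT makes the MagicMock fall back to its
                # configured return_value (self.job).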
   123          def get_job_side_effect(*args, **kwargs):
   124            self.job.currentState = self._states[self._next_state_index]
   125            if self._next_state_index < (len(self._states) - 1):
   126              self._next_state_index += 1
   127            return mock.DEFAULT
   128  
   129          self.dataflow_client.get_job = mock.MagicMock(
   130              return_value=self.job, side_effect=get_job_side_effect)
   131          self.dataflow_client.list_messages = mock.MagicMock(
   132              return_value=([], None))
   133  
   134      with self.assertRaisesRegex(DataflowRuntimeException,
   135                                  'Dataflow pipeline failed. State: FAILED'):
   136        failed_runner = MockDataflowRunner([values_enum.JOB_STATE_FAILED])
   137        failed_result = DataflowPipelineResult(failed_runner.job, failed_runner)
   138        failed_result.wait_until_finish()
   139  
   140      # Check that a second call still triggers the exception.
   141      with self.assertRaisesRegex(DataflowRuntimeException,
   142                                  'Dataflow pipeline failed. State: FAILED'):
   143        failed_result.wait_until_finish()
   144  
   145      succeeded_runner = MockDataflowRunner([values_enum.JOB_STATE_DONE])
   146      succeeded_result = DataflowPipelineResult(
   147          succeeded_runner.job, succeeded_runner)
   148      result = succeeded_result.wait_until_finish()
   149      self.assertEqual(result, PipelineState.DONE)
   150  
   151      # The mocked time values contain duplicates because some logging
   152      # implementations also call time.time().
   153      with mock.patch('time.time', mock.MagicMock(side_effect=[1, 1, 2, 2, 3])):
   154        duration_succeeded_runner = MockDataflowRunner(
   155            [values_enum.JOB_STATE_RUNNING, values_enum.JOB_STATE_DONE])
   156        duration_succeeded_result = DataflowPipelineResult(
   157            duration_succeeded_runner.job, duration_succeeded_runner)
   158        result = duration_succeeded_result.wait_until_finish(5000)
   159        self.assertEqual(result, PipelineState.DONE)
   160  
   161      with mock.patch('time.time', mock.MagicMock(side_effect=[1, 9, 9, 20, 20])):
   162        duration_timedout_runner = MockDataflowRunner(
   163            [values_enum.JOB_STATE_RUNNING])
   164        duration_timedout_result = DataflowPipelineResult(
   165            duration_timedout_runner.job, duration_timedout_runner)
   166        result = duration_timedout_result.wait_until_finish(5000)
   167        self.assertEqual(result, PipelineState.RUNNING)
   168  
   169      with mock.patch('time.time', mock.MagicMock(side_effect=[1, 1, 2, 2, 3])):
   170        with self.assertRaisesRegex(DataflowRuntimeException,
   171                                    'Dataflow pipeline failed. State: CANCELLED'):
   172          duration_failed_runner = MockDataflowRunner(
   173              [values_enum.JOB_STATE_CANCELLED])
   174          duration_failed_result = DataflowPipelineResult(
   175              duration_failed_runner.job, duration_failed_runner)
   176          duration_failed_result.wait_until_finish(5000)
   177  
   178    @mock.patch('time.sleep', return_value=None)
   179    def test_cancel(self, patched_time_sleep):
   180      values_enum = dataflow_api.Job.CurrentStateValueValuesEnum
   181  
   182      class MockDataflowRunner(object):
   183        def __init__(self, state, cancel_result):
   184          self.dataflow_client = mock.MagicMock()
   185          self.job = mock.MagicMock()
   186          self.job.currentState = state
   187  
   188          self.dataflow_client.get_job = mock.MagicMock(return_value=self.job)
   189          self.dataflow_client.modify_job_state = mock.MagicMock(
   190              return_value=cancel_result)
   191          self.dataflow_client.list_messages = mock.MagicMock(
   192              return_value=([], None))
   193  
   194      with self.assertRaisesRegex(DataflowRuntimeException,
   195                                  'Failed to cancel job'):
   196        failed_runner = MockDataflowRunner(values_enum.JOB_STATE_RUNNING, False)
   197        failed_result = DataflowPipelineResult(failed_runner.job, failed_runner)
   198        failed_result.cancel()
   199  
   200      succeeded_runner = MockDataflowRunner(values_enum.JOB_STATE_RUNNING, True)
   201      succeeded_result = DataflowPipelineResult(
   202          succeeded_runner.job, succeeded_runner)
   203      succeeded_result.cancel()
   204  
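            # Cancelling a job already in a terminal state should not raise, even
            # though modify_job_state reports failure.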
   205      terminal_runner = MockDataflowRunner(values_enum.JOB_STATE_DONE, False)
   206      terminal_result = DataflowPipelineResult(
   207          terminal_runner.job, terminal_runner)
   208      terminal_result.cancel()
   209  
   210    def test_create_runner(self):
   211      self.assertTrue(isinstance(create_runner('DataflowRunner'), DataflowRunner))
   212      self.assertTrue(
   213          isinstance(create_runner('TestDataflowRunner'), TestDataflowRunner))
   214  
   215    def test_environment_override_translation_legacy_worker_harness_image(self):
   216      self.default_properties.append('--experiments=beam_fn_api')
   217      self.default_properties.append('--worker_harness_container_image=LEGACY')
   218      remote_runner = DataflowRunner()
   219      with Pipeline(remote_runner,
   220                    options=PipelineOptions(self.default_properties)) as p:
   221        (  # pylint: disable=expression-not-assigned
   222            p | ptransform.Create([1, 2, 3])
   223            | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
   224            | ptransform.GroupByKey())
   225      self.assertEqual(
   226          list(remote_runner.proto_pipeline.components.environments.values()),
   227          [
   228              beam_runner_api_pb2.Environment(
   229                  urn=common_urns.environments.DOCKER.urn,
   230                  payload=beam_runner_api_pb2.DockerPayload(
   231                      container_image='LEGACY').SerializeToString(),
   232                  capabilities=environments.python_sdk_docker_capabilities())
   233          ])
   234  
   235    def test_environment_override_translation_sdk_container_image(self):
   236      self.default_properties.append('--experiments=beam_fn_api')
   237      self.default_properties.append('--sdk_container_image=FOO')
   238      remote_runner = DataflowRunner()
   239      with Pipeline(remote_runner,
   240                    options=PipelineOptions(self.default_properties)) as p:
   241        (  # pylint: disable=expression-not-assigned
   242            p | ptransform.Create([1, 2, 3])
   243            | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
   244            | ptransform.GroupByKey())
   245      self.assertEqual(
   246          list(remote_runner.proto_pipeline.components.environments.values()),
   247          [
   248              beam_runner_api_pb2.Environment(
   249                  urn=common_urns.environments.DOCKER.urn,
   250                  payload=beam_runner_api_pb2.DockerPayload(
   251                      container_image='FOO').SerializeToString(),
   252                  capabilities=environments.python_sdk_docker_capabilities())
   253          ])
   254  
   255    def test_remote_runner_translation(self):
   256      remote_runner = DataflowRunner()
   257      with Pipeline(remote_runner,
   258                    options=PipelineOptions(self.default_properties)) as p:
   259  
   260        (  # pylint: disable=expression-not-assigned
   261            p | ptransform.Create([1, 2, 3])
   262            | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
   263            | ptransform.GroupByKey())
   264  
   265    def test_remote_runner_display_data(self):
   266      remote_runner = DataflowRunner()
   267      p = Pipeline(
   268          remote_runner, options=PipelineOptions(self.default_properties))
   269  
   270      now = datetime.now()
   271      # pylint: disable=expression-not-assigned
   272      (
   273          p | ptransform.Create([1, 2, 3, 4, 5])
   274          | 'Do' >> SpecialParDo(SpecialDoFn(), now))
   275  
   276      # TODO(https://github.com/apache/beam/issues/18012): Enable the runner API
   277      # for this test.
   278      p.run(test_runner_api=False)
   279      job_dict = json.loads(str(remote_runner.job))
   280      steps = [
   281          step for step in job_dict['steps']
   282          if len(step['properties'].get('display_data', [])) > 0
   283      ]
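            # steps[1] is the 'Do' (SpecialParDo) step, the second step reporting
            # display data.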
   284      step = steps[1]
   285      disp_data = step['properties']['display_data']
   286      nspace = SpecialParDo.__module__ + '.'
   287      expected_data = [{
   288          'type': 'TIMESTAMP',
   289          'namespace': nspace + 'SpecialParDo',
   290          'value': DisplayDataItem._format_value(now, 'TIMESTAMP'),
   291          'key': 'a_time'
   292      },
   293                       {
   294                           'type': 'STRING',
   295                           'namespace': nspace + 'SpecialParDo',
   296                           'value': nspace + 'SpecialParDo',
   297                           'key': 'a_class',
   298                           'shortValue': 'SpecialParDo'
   299                       },
   300                       {
   301                           'type': 'INTEGER',
   302                           'namespace': nspace + 'SpecialDoFn',
   303                           'value': 42,
   304                           'key': 'dofn_value'
   305                       }]
   306      self.assertUnhashableCountEqual(disp_data, expected_data)
   307  
   308    def test_group_by_key_input_visitor_with_valid_inputs(self):
   309      p = TestPipeline()
   310      pcoll1 = PCollection(p)
   311      pcoll2 = PCollection(p)
   312      pcoll3 = PCollection(p)
   313  
   314      pcoll1.element_type = None
   315      pcoll2.element_type = typehints.Any
   316      pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any]
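            # The visitor should rewrite each input's element type to KV[Any, Any].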
   317      for pcoll in [pcoll1, pcoll2, pcoll3]:
   318        applied = AppliedPTransform(
   319            None, beam.GroupByKey(), "label", {'pcoll': pcoll})
   320        applied.outputs[None] = PCollection(None)
   321        common.group_by_key_input_visitor().visit_transform(applied)
   322        self.assertEqual(
   323            pcoll.element_type, typehints.KV[typehints.Any, typehints.Any])
   324  
   325    def test_group_by_key_input_visitor_with_invalid_inputs(self):
   326      p = TestPipeline()
   327      pcoll1 = PCollection(p)
   328      pcoll2 = PCollection(p)
   329  
   330      pcoll1.element_type = str
   331      pcoll2.element_type = typehints.Set
   332      err_msg = (
   333          r"Input to 'label' must be compatible with KV\[Any, Any\]. "
   334          "Found .*")
   335      for pcoll in [pcoll1, pcoll2]:
   336        with self.assertRaisesRegex(ValueError, err_msg):
   337          common.group_by_key_input_visitor().visit_transform(
   338              AppliedPTransform(None, beam.GroupByKey(), "label", {'in': pcoll}))
   339  
   340    def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
   341      p = TestPipeline()
   342      pcoll = PCollection(p)
   343      for transform in [beam.Flatten(), beam.Map(lambda x: x)]:
   344        pcoll.element_type = typehints.Any
   345        common.group_by_key_input_visitor().visit_transform(
   346            AppliedPTransform(None, transform, "label", {'in': pcoll}))
   347        self.assertEqual(pcoll.element_type, typehints.Any)
   348  
   349    def test_flatten_input_with_visitor_with_single_input(self):
   350      self._test_flatten_input_visitor(typehints.KV[int, int], typehints.Any, 1)
   351  
   352    def test_flatten_input_with_visitor_with_multiple_inputs(self):
   353      self._test_flatten_input_visitor(
   354          typehints.KV[int, typehints.Any], typehints.Any, 5)
   355  
   356    def _test_flatten_input_visitor(self, input_type, output_type, num_inputs):
   357      p = TestPipeline()
   358      inputs = {}
   359      for ix in range(num_inputs):
   360        input_pcoll = PCollection(p)
   361        input_pcoll.element_type = input_type
   362        inputs[str(ix)] = input_pcoll
   363      output_pcoll = PCollection(p)
   364      output_pcoll.element_type = output_type
   365  
   366      flatten = AppliedPTransform(None, beam.Flatten(), "label", inputs)
   367      flatten.add_output(output_pcoll, None)
   368      DataflowRunner.flatten_input_visitor().visit_transform(flatten)
   369      for ix in range(num_inputs):
   370        self.assertEqual(inputs[str(ix)].element_type, output_type)
   371  
   372    def test_gbk_then_flatten_input_visitor(self):
   373      p = TestPipeline(
   374          runner=DataflowRunner(),
   375          options=PipelineOptions(self.default_properties))
   376      none_str_pc = p | 'c1' >> beam.Create({None: 'a'})
   377      none_int_pc = p | 'c2' >> beam.Create({None: 3})
   378      flat = (none_str_pc, none_int_pc) | beam.Flatten()
   379      _ = flat | beam.GroupByKey()
   380  
   381      # This may change if type inference changes, but we assert it here
   382      # to make sure the check below is not vacuous.
   383      self.assertNotIsInstance(flat.element_type, typehints.TupleConstraint)
   384  
   385      p.visit(common.group_by_key_input_visitor())
   386      p.visit(DataflowRunner.flatten_input_visitor())
   387  
   388      # The Dataflow runner requires GroupByKey inputs to be tuples *and*
   389      # Flatten inputs to match their outputs' element types. Assert both hold.
   390      self.assertIsInstance(flat.element_type, typehints.TupleConstraint)
   391      self.assertEqual(flat.element_type, none_str_pc.element_type)
   392      self.assertEqual(flat.element_type, none_int_pc.element_type)
   393  
   394    def test_serialize_windowing_strategy(self):
   395      # This just tests the basic path; more complete tests
   396      # are in window_test.py.
   397      strategy = Windowing(window.FixedWindows(10))
   398      self.assertEqual(
   399          strategy,
   400          DataflowRunner.deserialize_windowing_strategy(
   401              DataflowRunner.serialize_windowing_strategy(strategy, None)))
   402  
   403    def test_side_input_visitor(self):
   404      p = TestPipeline()
   405      pc = p | beam.Create([])
   406  
   407      transform = beam.Map(
   408          lambda x,
   409          y,
   410          z: (x, y, z),
   411          beam.pvalue.AsSingleton(pc),
   412          beam.pvalue.AsMultiMap(pc))
   413      applied_transform = AppliedPTransform(None, transform, "label", {'pc': pc})
   414      DataflowRunner.side_input_visitor(
   415          is_runner_v2=True).visit_transform(applied_transform)
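            # With Runner v2, AsSingleton side inputs are converted to the ITERABLE
            # access pattern, while AsMultiMap keeps the MULTIMAP pattern.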
   416      self.assertEqual(2, len(applied_transform.side_inputs))
   417      self.assertEqual(
   418          common_urns.side_inputs.ITERABLE.urn,
   419          applied_transform.side_inputs[0]._side_input_data().access_pattern)
   420      self.assertEqual(
   421          common_urns.side_inputs.MULTIMAP.urn,
   422          applied_transform.side_inputs[1]._side_input_data().access_pattern)
   423  
   424    def test_min_cpu_platform_flag_is_propagated_to_experiments(self):
   425      remote_runner = DataflowRunner()
   426      self.default_properties.append('--min_cpu_platform=Intel Haswell')
   427  
   428      with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p:
   429        p | ptransform.Create([1])  # pylint: disable=expression-not-assigned
   430      self.assertIn(
   431          'min_cpu_platform=Intel Haswell',
   432          remote_runner.job.options.view_as(DebugOptions).experiments)
   433  
   434    def test_streaming_engine_flag_adds_windmill_experiments(self):
   435      remote_runner = DataflowRunner()
   436      self.default_properties.append('--streaming')
   437      self.default_properties.append('--enable_streaming_engine')
   438      self.default_properties.append('--experiment=some_other_experiment')
   439  
   440      with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p:
   441        p | ptransform.Create([1])  # pylint: disable=expression-not-assigned
   442  
   443      experiments_for_job = (
   444          remote_runner.job.options.view_as(DebugOptions).experiments)
   445      self.assertIn('enable_streaming_engine', experiments_for_job)
   446      self.assertIn('enable_windmill_service', experiments_for_job)
   447      self.assertIn('some_other_experiment', experiments_for_job)
   448  
   449    def test_upload_graph_experiment(self):
   450      remote_runner = DataflowRunner()
   451      self.default_properties.append('--experiment=upload_graph')
   452  
   453      with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p:
   454        p | ptransform.Create([1])  # pylint: disable=expression-not-assigned
   455  
   456      experiments_for_job = (
   457          remote_runner.job.options.view_as(DebugOptions).experiments)
   458      self.assertIn('upload_graph', experiments_for_job)
   459  
   460    def test_use_fastavro_experiment_is_not_added_when_use_avro_is_present(self):
   461      remote_runner = DataflowRunner()
   462      self.default_properties.append('--experiment=use_avro')
   463  
   464      with Pipeline(remote_runner, PipelineOptions(self.default_properties)) as p:
   465        p | ptransform.Create([1])  # pylint: disable=expression-not-assigned
   466  
   467      debug_options = remote_runner.job.options.view_as(DebugOptions)
   468  
   469      self.assertFalse(debug_options.lookup_experiment('use_fastavro', False))
   470  
   471    @mock.patch('os.environ.get', return_value=None)
   472    @mock.patch('apache_beam.utils.processes.check_output', return_value=b'')
   473    def test_get_default_gcp_region_no_default_returns_none(
   474        self, patched_environ, patched_processes):
   475      runner = DataflowRunner()
   476      result = runner.get_default_gcp_region()
   477      self.assertIsNone(result)
   478  
   479    @mock.patch('os.environ.get', return_value='some-region1')
   480    @mock.patch('apache_beam.utils.processes.check_output', return_value=b'')
   481    def test_get_default_gcp_region_from_environ(
   482        self, patched_environ, patched_processes):
   483      runner = DataflowRunner()
   484      result = runner.get_default_gcp_region()
   485      self.assertEqual(result, 'some-region1')
   486  
   487    @mock.patch('os.environ.get', return_value=None)
   488    @mock.patch(
   489        'apache_beam.utils.processes.check_output',
   490        return_value=b'some-region2\n')
   491    def test_get_default_gcp_region_from_gcloud(
   492        self, patched_environ, patched_processes):
   493      runner = DataflowRunner()
   494      result = runner.get_default_gcp_region()
   495      self.assertEqual(result, 'some-region2')
   496  
   497    @mock.patch('os.environ.get', return_value=None)
   498    @mock.patch(
   499        'apache_beam.utils.processes.check_output',
   500        side_effect=RuntimeError('Executable gcloud not found'))
   501    def test_get_default_gcp_region_ignores_error(
   502        self, patched_environ, patched_processes):
   503      runner = DataflowRunner()
   504      result = runner.get_default_gcp_region()
   505      self.assertIsNone(result)
   506  
   507    def test_combine_values_translation(self):
   508      runner = DataflowRunner()
   509  
   510      with beam.Pipeline(runner=runner,
   511                         options=PipelineOptions(self.default_properties)) as p:
   512        (  # pylint: disable=expression-not-assigned
   513            p
   514            | beam.Create([('a', [1, 2]), ('b', [3, 4])])
   515            | beam.CombineValues(lambda v, _: sum(v)))
   516  
   517      job_dict = json.loads(str(runner.job))
   518      self.assertIn(
   519          'CombineValues', set(step['kind'] for step in job_dict['steps']))
   520  
   521    def _find_step(self, job, step_name):
   522      job_dict = json.loads(str(job))
   523      maybe_step = [
   524          s for s in job_dict['steps']
   525          if s['properties']['user_name'] == step_name
   526      ]
   527      self.assertTrue(maybe_step, 'Could not find step {}'.format(step_name))
   528      return maybe_step[0]
   529  
   530    def expect_correct_override(self, job, step_name, step_kind):
   531      """Asserts that a transform was correctly overridden."""
   532  
   533      # If the typing information isn't being forwarded correctly, the component
   534      # encodings here will be incorrect.
   535      expected_output_info = [{
   536          "encoding": {
   537              "@type": "kind:windowed_value",
   538              "component_encodings": [{
   539                  "@type": "kind:bytes"
   540              }, {
   541                  "@type": "kind:global_window"
   542              }],
   543              "is_wrapper": True
   544          },
   545          "output_name": "out",
   546          "user_name": step_name + ".out"
   547      }]
   548  
   549      step = self._find_step(job, step_name)
   550      self.assertEqual(step['kind'], step_kind)
   551  
   552      # The display data here is forwarded because the replacement transform
   553      # is a subclass of iobase.Read.
   554      self.assertGreater(len(step['properties']['display_data']), 0)
   555      self.assertEqual(step['properties']['output_info'], expected_output_info)
   556  
   557    def test_read_create_translation(self):
   558      runner = DataflowRunner()
   559  
   560      with beam.Pipeline(runner=runner,
   561                         options=PipelineOptions(self.default_properties)) as p:
   562        # pylint: disable=expression-not-assigned
   563        p | beam.Create([b'a', b'b', b'c'])
   564  
   565      self.expect_correct_override(runner.job, 'Create/Read', 'ParallelRead')
   566  
   567    def test_read_pubsub_translation(self):
   568      runner = DataflowRunner()
   569  
   570      self.default_properties.append("--streaming")
   571  
   572      with beam.Pipeline(runner=runner,
   573                         options=PipelineOptions(self.default_properties)) as p:
   574        # pylint: disable=expression-not-assigned
   575        p | beam.io.ReadFromPubSub(topic='projects/project/topics/topic')
   576  
   577      self.expect_correct_override(
   578          runner.job, 'ReadFromPubSub/Read', 'ParallelRead')
   579  
   580    def test_gbk_translation(self):
   581      runner = DataflowRunner()
   582      with beam.Pipeline(runner=runner,
   583                         options=PipelineOptions(self.default_properties)) as p:
   584        # pylint: disable=expression-not-assigned
   585        p | beam.Create([(1, 2)]) | beam.GroupByKey()
   586  
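            # The GroupByKey output should be encoded as a windowed
            # KV<varint, stream<varint>> value.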
   587      expected_output_info = [{
   588          "encoding": {
   589              "@type": "kind:windowed_value",
   590              "component_encodings": [{
   591                  "@type": "kind:pair",
   592                  "component_encodings": [{
   593                      "@type": "kind:varint"
   594                  },
   595                  {
   596                      "@type": "kind:stream",
   597                      "component_encodings": [{
   598                          "@type": "kind:varint"
   599                      }],
   600                      "is_stream_like": True
   601                  }],
   602                  "is_pair_like": True
   603              }, {
   604                  "@type": "kind:global_window"
   605              }],
   606              "is_wrapper": True
   607          },
   608          "output_name": "out",
   609          "user_name": "GroupByKey.out"
   610      }]  # yapf: disable
   611  
   612      gbk_step = self._find_step(runner.job, 'GroupByKey')
   613      self.assertEqual(gbk_step['kind'], 'GroupByKey')
   614      self.assertEqual(
   615          gbk_step['properties']['output_info'], expected_output_info)
   616  
   617    @unittest.skip(
   618        'https://github.com/apache/beam/issues/18716: enable once '
   619        'CombineFnVisitor is fixed')
   620    def test_unsupported_combinefn_detection(self):
   621      class CombinerWithNonDefaultSetupTeardown(combiners.CountCombineFn):
   622        def setup(self, *args, **kwargs):
   623          pass
   624  
   625        def teardown(self, *args, **kwargs):
   626          pass
   627  
   628      runner = DataflowRunner()
   629      with self.assertRaisesRegex(ValueError,
   630                                  'CombineFn.setup and CombineFn.'
   631                                  'teardown are not supported'):
   632        with beam.Pipeline(runner=runner,
   633                           options=PipelineOptions(self.default_properties)) as p:
   634          _ = (
   635              p | beam.Create([1])
   636              | beam.CombineGlobally(CombinerWithNonDefaultSetupTeardown()))
   637  
   638      try:
   639        with beam.Pipeline(runner=runner,
   640                           options=PipelineOptions(self.default_properties)) as p:
   641          _ = (
   642              p | beam.Create([1])
   643              | beam.CombineGlobally(
   644                  combiners.SingleInputTupleCombineFn(
   645                      combiners.CountCombineFn(), combiners.CountCombineFn())))
   646      except ValueError:
   647        self.fail('ValueError raised unexpectedly')
   648  
   649    def _run_group_into_batches_and_get_step_properties(
   650        self, with_sharded_key, additional_properties):
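            # Runs a streaming GroupIntoBatches pipeline (optionally with a sharded
            # key) and returns the properties of the generated Dataflow step.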
   651      self.default_properties.append('--streaming')
   652      for prop in additional_properties:
   653        self.default_properties.append(prop)
   654  
   655      runner = DataflowRunner()
   656      with beam.Pipeline(runner=runner,
   657                         options=PipelineOptions(self.default_properties)) as p:
   658        # pylint: disable=expression-not-assigned
   659        input = p | beam.Create([('a', 1), ('a', 1), ('b', 3), ('b', 4)])
   660        if with_sharded_key:
   661          (
   662              input | beam.GroupIntoBatches.WithShardedKey(2)
   663              | beam.Map(lambda key_values: (key_values[0].key, key_values[1])))
   664          step_name = (
   665              'WithShardedKey/GroupIntoBatches/ParDo(_GroupIntoBatchesDoFn)')
   666        else:
   667          input | beam.GroupIntoBatches(2)
   668          step_name = 'GroupIntoBatches/ParDo(_GroupIntoBatchesDoFn)'
   669  
   670      return self._find_step(runner.job, step_name)['properties']
   671  
   672    def test_group_into_batches_translation(self):
   673      properties = self._run_group_into_batches_and_get_step_properties(
   674          True, ['--enable_streaming_engine', '--experiments=use_runner_v2'])
   675      self.assertEqual(properties[PropertyNames.USES_KEYED_STATE], 'true')
   676      self.assertEqual(properties[PropertyNames.ALLOWS_SHARDABLE_STATE], 'true')
   677      self.assertEqual(properties[PropertyNames.PRESERVES_KEYS], 'true')
   678  
   679    def test_group_into_batches_translation_non_sharded(self):
   680      properties = self._run_group_into_batches_and_get_step_properties(
   681          False, ['--enable_streaming_engine', '--experiments=use_runner_v2'])
   682      self.assertEqual(properties[PropertyNames.USES_KEYED_STATE], 'true')
   683      self.assertNotIn(PropertyNames.ALLOWS_SHARDABLE_STATE, properties)
   684      self.assertNotIn(PropertyNames.PRESERVES_KEYS, properties)
   685  
   686    def test_pack_combiners(self):
   687      class PackableCombines(beam.PTransform):
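              # The APPLY_COMBINER_PACKING annotation opts this composite in to
              # combiner packing.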
   688        def annotations(self):
   689          return {python_urns.APPLY_COMBINER_PACKING: b''}
   690  
   691        def expand(self, pcoll):
   692          _ = pcoll | 'PackableMin' >> beam.CombineGlobally(min)
   693          _ = pcoll | 'PackableMax' >> beam.CombineGlobally(max)
   694  
   695      runner = DataflowRunner()
   696      with beam.Pipeline(runner=runner,
   697                         options=PipelineOptions(self.default_properties)) as p:
   698        _ = p | beam.Create([10, 20, 30]) | PackableCombines()
   699  
   700      unpacked_minimum_step_name = (
   701          'PackableCombines/PackableMin/CombinePerKey/Combine')
   702      unpacked_maximum_step_name = (
   703          'PackableCombines/PackableMax/CombinePerKey/Combine')
   704      packed_step_name = (
   705          'PackableCombines/Packed[PackableMin_CombinePerKey, '
   706          'PackableMax_CombinePerKey]/Pack')
   707      transform_names = set(
   708          transform.unique_name
   709          for transform in runner.proto_pipeline.components.transforms.values())
   710      self.assertNotIn(unpacked_minimum_step_name, transform_names)
   711      self.assertNotIn(unpacked_maximum_step_name, transform_names)
   712      self.assertIn(packed_step_name, transform_names)
   713  
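          # Covers every combination of a Runner V2 enabling experiment and a
          # disabling experiment; False means the experiment is not set.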
   714    @parameterized.expand([
   715        param(memory_hint='min_ram'),
   716        param(memory_hint='minRam'),
   717    ])
   718    def test_resource_hints_translation(self, memory_hint):
   719      runner = DataflowRunner()
   720      self.default_properties.append('--resource_hint=accelerator=some_gpu')
   721      self.default_properties.append(f'--resource_hint={memory_hint}=20GB')
   722      with beam.Pipeline(runner=runner,
   723                         options=PipelineOptions(self.default_properties)) as p:
   724        # pylint: disable=expression-not-assigned
   725        (
   726            p
   727            | beam.Create([1])
   728            | 'MapWithHints' >> beam.Map(lambda x: x + 1).with_resource_hints(
   729                min_ram='10GB',
   730                accelerator='type:nvidia-tesla-k80;count:1;install-nvidia-drivers'
   731            ))
   732  
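            # Expect the larger of the two min_ram hints (20GB, in bytes) and the
            # transform-level accelerator hint (percent-encoded).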
   733      step = self._find_step(runner.job, 'MapWithHints')
   734      self.assertEqual(
   735          step['properties']['resource_hints'],
   736          {
   737              'beam:resources:min_ram_bytes:v1': '20000000000',
   738              'beam:resources:accelerator:v1': \
   739                  'type%3Anvidia-tesla-k80%3Bcount%3A1%3Binstall-nvidia-drivers'
   740          })
   741  
   742    @parameterized.expand([
   743        (
   744            "%s_%s" % (enable_option, disable_option),
   745            enable_option,
   746            disable_option)
   747        for (enable_option,
   748             disable_option) in product([
   749                 False,
   750                 'enable_prime',
   751                 'beam_fn_api',
   752                 'use_unified_worker',
   753                 'use_runner_v2',
   754                 'use_portable_job_submission'
   755             ],
   756                                        [
   757                                            False,
   758                                            'disable_runner_v2',
   759                                            'disable_runner_v2_until_2023',
   760                                            'disable_prime_runner_v2'
   761                                        ])
   762    ])
   763    def test_batch_is_runner_v2(self, name, enable_option, disable_option):
   764      options = PipelineOptions(
   765          (['--experiments=%s' % enable_option] if enable_option else []) +
   766          (['--experiments=%s' % disable_option] if disable_option else []))
   767      if (enable_option and disable_option):
   768        with self.assertRaisesRegex(ValueError,
   769                                    'Runner V2 both disabled and enabled'):
   770          _is_runner_v2(options)
   771      elif enable_option:
   772        self.assertTrue(_is_runner_v2(options))
   773        self.assertFalse(_is_runner_v2_disabled(options))
   774        for expected in ['beam_fn_api',
   775                         'use_unified_worker',
   776                         'use_runner_v2',
   777                         'use_portable_job_submission']:
   778          self.assertTrue(
   779              options.view_as(DebugOptions).lookup_experiment(expected, False))
   780        if enable_option == 'enable_prime':
   781          self.assertIn(
   782              'enable_prime',
   783              options.view_as(GoogleCloudOptions).dataflow_service_options)
   784      elif disable_option:
   785        self.assertFalse(_is_runner_v2(options))
   786        self.assertTrue(_is_runner_v2_disabled(options))
   787      else:
   788        self.assertFalse(_is_runner_v2(options))
   789  
   790    @parameterized.expand([
   791        (
   792            "%s_%s" % (enable_option, disable_option),
   793            enable_option,
   794            disable_option)
   795        for (enable_option,
   796             disable_option) in product([
   797                 False,
   798                 'enable_prime',
   799                 'beam_fn_api',
   800                 'use_unified_worker',
   801                 'use_runner_v2',
   802                 'use_portable_job_submission'
   803             ],
   804                                        [
   805                                            False,
   806                                            'disable_runner_v2',
   807                                            'disable_runner_v2_until_2023',
   808                                            'disable_prime_runner_v2'
   809                                        ])
   810    ])
   811    def test_streaming_is_runner_v2(self, name, enable_option, disable_option):
   812      options = PipelineOptions(
   813          ['--streaming'] +
   814          (['--experiments=%s' % enable_option] if enable_option else []) +
   815          (['--experiments=%s' % disable_option] if disable_option else []))
   816      if disable_option:
   817        with self.assertRaisesRegex(
   818            ValueError,
   819            'Disabling Runner V2 no longer supported for streaming pipeline'):
   820          _is_runner_v2(options)
   821      else:
   822        self.assertTrue(_is_runner_v2(options))
   823        for expected in ['beam_fn_api',
   824                         'use_unified_worker',
   825                         'use_runner_v2',
   826                         'use_portable_job_submission',
   827                         'enable_windmill_service',
   828                         'enable_streaming_engine']:
   829          self.assertTrue(
   830              options.view_as(DebugOptions).lookup_experiment(expected, False))
   831        if enable_option == 'enable_prime':
   832          self.assertIn(
   833              'enable_prime',
   834              options.view_as(GoogleCloudOptions).dataflow_service_options)
   835  
   836    def test_dataflow_service_options_enable_prime_sets_runner_v2(self):
   837      options = PipelineOptions(['--dataflow_service_options=enable_prime'])
   838      self.assertTrue(_is_runner_v2(options))
   839      for expected in ['beam_fn_api',
   840                       'use_unified_worker',
   841                       'use_runner_v2',
   842                       'use_portable_job_submission']:
   843        self.assertTrue(
   844            options.view_as(DebugOptions).lookup_experiment(expected, False))
   845  
   846      options = PipelineOptions(
   847          ['--streaming', '--dataflow_service_options=enable_prime'])
   848      self.assertTrue(_is_runner_v2(options))
   849      for expected in ['beam_fn_api',
   850                       'use_unified_worker',
   851                       'use_runner_v2',
   852                       'use_portable_job_submission',
   853                       'enable_windmill_service',
   854                       'enable_streaming_engine']:
   855        self.assertTrue(
   856            options.view_as(DebugOptions).lookup_experiment(expected, False))
   857  
   858  
   859  if __name__ == '__main__':
   860    unittest.main()