github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dataflow/internal/apiclient_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the apiclient module."""

# pytype: skip-file

import itertools
import json
import logging
import os
import sys
import unittest

import mock

from apache_beam.io.filesystems import FileSystems
from apache_beam.metrics.cells import DistributionData
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import Pipeline
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.dataflow.internal import names
from apache_beam.runners.dataflow.internal.clients import dataflow
from apache_beam.transforms import Create
from apache_beam.transforms import DataflowDistributionCounter
from apache_beam.transforms import DoFn
from apache_beam.transforms import ParDo
from apache_beam.transforms.environments import DockerEnvironment
from apache_beam.utils import proto_utils

# Protect against environments where apitools library is not available.
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
  from apache_beam.runners.dataflow.internal import apiclient
except ImportError:
  apiclient = None  # type: ignore
# pylint: enable=wrong-import-order, wrong-import-position

FAKE_PIPELINE_URL = "gs://invalid-bucket/anywhere"
_LOGGER = logging.getLogger(__name__)


@unittest.skipIf(apiclient is None, 'GCP dependencies are not installed')
class UtilTest(unittest.TestCase):
  @unittest.skip("Enable once BEAM-1080 is fixed.")
  def test_create_application_client(self):
    pipeline_options = PipelineOptions()
    apiclient.DataflowApplicationClient(pipeline_options)

  def test_pipeline_url(self):
    pipeline_options = PipelineOptions([
        '--subnetwork',
        '/regions/MY/subnetworks/SUBNETWORK',
        '--temp_location',
        'gs://any-location/temp'
    ])
    env = apiclient.Environment(
        [],
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)

    recovered_options = None
    for additionalProperty in env.proto.sdkPipelineOptions.additionalProperties:
      if additionalProperty.key == 'options':
        recovered_options = additionalProperty.value
        break
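    # Note: this is a for/else; the else branch runs only when the loop
    # finishes without hitting the break, i.e. no 'options' key was found.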
    else:
      self.fail(
          'No pipeline options found in %s' % env.proto.sdkPipelineOptions)

    pipeline_url = None
    for prop in recovered_options.object_value.properties:
      if prop.key == 'pipelineUrl':
        pipeline_url = prop.value
        break
    else:
      self.fail('No pipeline_url found in %s' % recovered_options)

    self.assertEqual(pipeline_url.string_value, FAKE_PIPELINE_URL)

  def test_set_network(self):
    pipeline_options = PipelineOptions([
        '--network',
        'anetworkname',
        '--temp_location',
        'gs://any-location/temp'
    ])
    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(env.proto.workerPools[0].network, 'anetworkname')

  def test_set_subnetwork(self):
    pipeline_options = PipelineOptions([
        '--subnetwork',
        '/regions/MY/subnetworks/SUBNETWORK',
        '--temp_location',
        'gs://any-location/temp'
    ])

    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].subnetwork,
        '/regions/MY/subnetworks/SUBNETWORK')

  def test_flexrs_blank(self):
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp'])

    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(env.proto.flexResourceSchedulingGoal, None)

  def test_flexrs_cost(self):
    pipeline_options = PipelineOptions([
        '--flexrs_goal',
        'COST_OPTIMIZED',
        '--temp_location',
        'gs://any-location/temp'
    ])

    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.flexResourceSchedulingGoal,
        (
            dataflow.Environment.FlexResourceSchedulingGoalValueValuesEnum.
            FLEXRS_COST_OPTIMIZED))

  def test_flexrs_speed(self):
    pipeline_options = PipelineOptions([
        '--flexrs_goal',
        'SPEED_OPTIMIZED',
        '--temp_location',
        'gs://any-location/temp'
    ])

    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.flexResourceSchedulingGoal,
        (
            dataflow.Environment.FlexResourceSchedulingGoalValueValuesEnum.
            FLEXRS_SPEED_OPTIMIZED))

  def _verify_sdk_harness_container_images_get_set(self, pipeline_options):
    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    test_environment = DockerEnvironment(container_image='test_default_image')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=test_environment)

    dummy_env = beam_runner_api_pb2.Environment(
        urn=common_urns.environments.DOCKER.urn,
        payload=(
            beam_runner_api_pb2.DockerPayload(
                container_image='dummy_image')).SerializeToString())
    dummy_env.capabilities.append(
        common_urns.protocols.MULTI_CORE_BUNDLE_PROCESSING.urn)
    proto_pipeline.components.environments['dummy_env_id'].CopyFrom(dummy_env)

    dummy_transform = beam_runner_api_pb2.PTransform(
        environment_id='dummy_env_id')
    proto_pipeline.components.transforms['dummy_transform_id'].CopyFrom(
        dummy_transform)

    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL,
        proto_pipeline)
    worker_pool = env.proto.workerPools[0]

    self.assertEqual(2, len(worker_pool.sdkHarnessContainerImages))
    # Only one of the environments is missing MULTI_CORE_BUNDLE_PROCESSING.
    self.assertEqual(
        1,
        sum(
            c.useSingleCorePerContainer
            for c in worker_pool.sdkHarnessContainerImages))

    env_and_image = [(item.environmentId, item.containerImage)
                     for item in worker_pool.sdkHarnessContainerImages]
    self.assertIn(('dummy_env_id', 'dummy_image'), env_and_image)
    self.assertIn((mock.ANY, 'test_default_image'), env_and_image)

  def test_sdk_harness_container_images_get_set_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--experiments=use_runner_v2',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_sdk_harness_container_images_get_set(pipeline_options)

  def test_sdk_harness_container_images_get_set_prime(self):
    pipeline_options = PipelineOptions([
        '--dataflow_service_options=enable_prime',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_sdk_harness_container_images_get_set(pipeline_options)

  def _verify_sdk_harness_container_image_overrides(self, pipeline_options):
    test_environment = DockerEnvironment(
        container_image='dummy_container_image')
    proto_pipeline, _ = Pipeline().to_runner_api(
        return_context=True, default_environment=test_environment)

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline,
        {
            '.*dummy.*': 'new_dummy_container_image',
            '.*notfound.*': 'new_dummy_container_image_2'
        },
        pipeline_options)

    self.assertEqual(1, len(proto_pipeline.components.environments))
    env = list(proto_pipeline.components.environments.values())[0]

    docker_payload = proto_utils.parse_Bytes(
        env.payload, beam_runner_api_pb2.DockerPayload)

    # Container image should be overridden by the given override.
    self.assertEqual(
        docker_payload.container_image, 'new_dummy_container_image')

  def test_sdk_harness_container_image_overrides_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--experiments=use_runner_v2',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_sdk_harness_container_image_overrides(pipeline_options)

  def test_sdk_harness_container_image_overrides_prime(self):
    pipeline_options = PipelineOptions([
        '--dataflow_service_options=enable_prime',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_sdk_harness_container_image_overrides(pipeline_options)

  def _verify_dataflow_container_image_override(self, pipeline_options):
    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    dummy_env = DockerEnvironment(
        container_image='apache/beam_dummy_name:dummy_tag')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=dummy_env)

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline, {}, pipeline_options)

    found_override = False
    for env in proto_pipeline.components.environments.values():
      docker_payload = proto_utils.parse_Bytes(
          env.payload, beam_runner_api_pb2.DockerPayload)
      if docker_payload.container_image.startswith(
          names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY):
        found_override = True

    self.assertTrue(found_override)

  def test_dataflow_container_image_override_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--experiments=use_runner_v2',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_dataflow_container_image_override(pipeline_options)

  def test_dataflow_container_image_override_prime(self):
    pipeline_options = PipelineOptions([
        '--dataflow_service_options=enable_prime',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_dataflow_container_image_override(pipeline_options)

  def _verify_non_apache_container_not_overridden(self, pipeline_options):
    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    dummy_env = DockerEnvironment(
        container_image='other_org/dummy_name:dummy_tag')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=dummy_env)

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline, {}, pipeline_options)

    self.assertEqual(2, len(proto_pipeline.components.environments))

    found_override = False
    for env in proto_pipeline.components.environments.values():
      docker_payload = proto_utils.parse_Bytes(
          env.payload, beam_runner_api_pb2.DockerPayload)
      if docker_payload.container_image.startswith(
          names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY):
        found_override = True

    self.assertFalse(found_override)

  def test_non_apache_container_not_overridden_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--experiments=use_runner_v2',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_non_apache_container_not_overridden(pipeline_options)

  def test_non_apache_container_not_overridden_prime(self):
    pipeline_options = PipelineOptions([
        '--dataflow_service_options=enable_prime',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_non_apache_container_not_overridden(pipeline_options)

  def _verify_pipeline_sdk_not_overridden(self, pipeline_options):
    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    dummy_env = DockerEnvironment(
        container_image='dummy_prefix/dummy_name:dummy_tag')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=dummy_env)

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline, {}, pipeline_options)

    self.assertEqual(2, len(proto_pipeline.components.environments))

    found_override = False
    for env in proto_pipeline.components.environments.values():
      docker_payload = proto_utils.parse_Bytes(
          env.payload, beam_runner_api_pb2.DockerPayload)
      if docker_payload.container_image.startswith(
          names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY):
        found_override = True

    self.assertFalse(found_override)

  def test_pipeline_sdk_not_overridden_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--experiments=use_runner_v2',
        '--temp_location',
        'gs://any-location/temp',
        '--sdk_container_image=dummy_prefix/dummy_name:dummy_tag'
    ])

    self._verify_pipeline_sdk_not_overridden(pipeline_options)

  def test_pipeline_sdk_not_overridden_prime(self):
    pipeline_options = PipelineOptions([
        '--dataflow_service_options=enable_prime',
        '--temp_location',
        'gs://any-location/temp',
        '--sdk_container_image=dummy_prefix/dummy_name:dummy_tag'
    ])

    self._verify_pipeline_sdk_not_overridden(pipeline_options)

  def test_invalid_default_job_name(self):
    # Regexp for job names in Dataflow.
    regexp = '^[a-z]([-a-z0-9]{0,61}[a-z0-9])?$'

    job_name = apiclient.Job._build_default_job_name('invalid.-_user_n*/ame')
    self.assertRegex(job_name, regexp)

    job_name = apiclient.Job._build_default_job_name(
        'invalid-extremely-long.username_that_shouldbeshortened_or_is_invalid')
    self.assertRegex(job_name, regexp)

  def test_default_job_name(self):
    job_name = apiclient.Job.default_job_name(None)
    regexp = 'beamapp-.*-[0-9]{10}-[0-9]{6}-[a-z0-9]{8}$'
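    # Per the pattern above, generated names embed the user name, a 10-digit
    # and a 6-digit timestamp block, and an 8-character random suffix.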
    self.assertRegex(job_name, regexp)

  def test_split_int(self):
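    # to_split_int packs an integer into a SplitInt64 message holding two
    # 32-bit halves, e.g. 12345 -> (lowBits=12345, highBits=0) and
    # 12345 << 32 -> (lowBits=0, highBits=12345), as asserted below.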
    number = 12345
    split_number = apiclient.to_split_int(number)
    self.assertEqual((split_number.lowBits, split_number.highBits), (number, 0))
    shift_number = number << 32
    split_number = apiclient.to_split_int(shift_number)
    self.assertEqual((split_number.lowBits, split_number.highBits), (0, number))

  def test_translate_distribution_using_accumulator(self):
    metric_update = dataflow.CounterUpdate()
    accumulator = mock.Mock()
    accumulator.min = 1
    accumulator.max = 15
    accumulator.sum = 16
    accumulator.count = 2
    apiclient.translate_distribution(accumulator, metric_update)
    self.assertEqual(metric_update.distribution.min.lowBits, accumulator.min)
    self.assertEqual(metric_update.distribution.max.lowBits, accumulator.max)
    self.assertEqual(metric_update.distribution.sum.lowBits, accumulator.sum)
    self.assertEqual(
        metric_update.distribution.count.lowBits, accumulator.count)

  def test_translate_distribution_using_distribution_data(self):
    metric_update = dataflow.CounterUpdate()
    distribution_update = DistributionData(16, 2, 1, 15)
    apiclient.translate_distribution(distribution_update, metric_update)
    self.assertEqual(
        metric_update.distribution.min.lowBits, distribution_update.min)
    self.assertEqual(
        metric_update.distribution.max.lowBits, distribution_update.max)
    self.assertEqual(
        metric_update.distribution.sum.lowBits, distribution_update.sum)
    self.assertEqual(
        metric_update.distribution.count.lowBits, distribution_update.count)

  def test_translate_distribution_using_dataflow_distribution_counter(self):
    counter_update = DataflowDistributionCounter()
    counter_update.add_input(1)
    counter_update.add_input(3)
    metric_proto = dataflow.CounterUpdate()
    apiclient.translate_distribution(counter_update, metric_proto)
    histogram = mock.Mock(firstBucketOffset=None, bucketCounts=None)
    counter_update.translate_to_histogram(histogram)
    self.assertEqual(metric_proto.distribution.min.lowBits, counter_update.min)
    self.assertEqual(metric_proto.distribution.max.lowBits, counter_update.max)
    self.assertEqual(metric_proto.distribution.sum.lowBits, counter_update.sum)
    self.assertEqual(
        metric_proto.distribution.count.lowBits, counter_update.count)
    self.assertEqual(
        metric_proto.distribution.histogram.bucketCounts,
        histogram.bucketCounts)
    self.assertEqual(
        metric_proto.distribution.histogram.firstBucketOffset,
        histogram.firstBucketOffset)

  def test_translate_means(self):
    metric_update = dataflow.CounterUpdate()
    accumulator = mock.Mock()
    accumulator.sum = 16
    accumulator.count = 2
    apiclient.MetricUpdateTranslators.translate_scalar_mean_int(
        accumulator, metric_update)
    self.assertEqual(metric_update.integerMean.sum.lowBits, accumulator.sum)
    self.assertEqual(metric_update.integerMean.count.lowBits, accumulator.count)

    accumulator.sum = 16.0
    accumulator.count = 2
    apiclient.MetricUpdateTranslators.translate_scalar_mean_float(
        accumulator, metric_update)
    self.assertEqual(metric_update.floatingPointMean.sum, accumulator.sum)
    self.assertEqual(
        metric_update.floatingPointMean.count.lowBits, accumulator.count)

  def test_translate_means_using_distribution_accumulator(self):
    # This is the special case for MeanByteCount, which is reported over the
    # FnAPI as a Beam distribution but sent to the service as a MetricUpdate
    # IntegerMean.
    metric_update = dataflow.CounterUpdate()
    accumulator = mock.Mock()
    accumulator.min = 7
    accumulator.max = 9
    accumulator.sum = 16
    accumulator.count = 2
    apiclient.MetricUpdateTranslators.translate_scalar_mean_int(
        accumulator, metric_update)
    self.assertEqual(metric_update.integerMean.sum.lowBits, accumulator.sum)
    self.assertEqual(metric_update.integerMean.count.lowBits, accumulator.count)

    accumulator.sum = 16.0
    accumulator.count = 2
    apiclient.MetricUpdateTranslators.translate_scalar_mean_float(
        accumulator, metric_update)
    self.assertEqual(metric_update.floatingPointMean.sum, accumulator.sum)
    self.assertEqual(
        metric_update.floatingPointMean.count.lowBits, accumulator.count)

  def test_default_ip_configuration(self):
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp'])
    env = apiclient.Environment([],
                                pipeline_options,
                                '2.0.0',
                                FAKE_PIPELINE_URL)
    self.assertEqual(env.proto.workerPools[0].ipConfiguration, None)

  def test_public_ip_configuration(self):
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp', '--use_public_ips'])
    env = apiclient.Environment([],
                                pipeline_options,
                                '2.0.0',
                                FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].ipConfiguration,
        dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PUBLIC)

  def test_private_ip_configuration(self):
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp', '--no_use_public_ips'])
    env = apiclient.Environment([],
                                pipeline_options,
                                '2.0.0',
                                FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].ipConfiguration,
        dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PRIVATE)

  def test_number_of_worker_harness_threads(self):
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://any-location/temp',
        '--number_of_worker_harness_threads',
        '2'
    ])
    env = apiclient.Environment([],
                                pipeline_options,
                                '2.0.0',
                                FAKE_PIPELINE_URL)
    self.assertEqual(env.proto.workerPools[0].numThreadsPerWorker, 2)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_harness_override_absent_with_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://any-location/temp',
        '--streaming',
        '--experiments=use_runner_v2'
    ])
    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    if env.proto.experiments:
      for experiment in env.proto.experiments:
        self.assertNotIn('runner_harness_container_image=', experiment)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_custom_harness_override_present_with_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://any-location/temp',
        '--streaming',
        '--experiments=runner_harness_container_image=fake_image',
        '--experiments=use_runner_v2',
    ])
    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        1,
        len([
            x for x in env.proto.experiments
            if x.startswith('runner_harness_container_image=')
        ]))
    self.assertIn(
        'runner_harness_container_image=fake_image', env.proto.experiments)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0.dev')
  def test_pinned_worker_harness_image_tag_used_in_dev_sdk(self):
    # streaming, fnapi pipeline.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp', '--streaming'])
    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY +
            '/beam_python%d.%d_sdk:%s' % (
                sys.version_info[0],
                sys.version_info[1],
                names.BEAM_FNAPI_CONTAINER_VERSION)))

    # batch, legacy pipeline.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp'])
    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + '/python%d%d:%s' % (
                sys.version_info[0],
                sys.version_info[1],
                names.BEAM_CONTAINER_VERSION)))

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_worker_harness_image_tag_matches_released_sdk_version(self):
    # streaming, fnapi pipeline.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp', '--streaming'])
    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY +
            '/beam_python%d.%d_sdk:2.2.0' %
            (sys.version_info[0], sys.version_info[1])))

    # batch, legacy pipeline.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp'])
    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + '/python%d%d:2.2.0' %
            (sys.version_info[0], sys.version_info[1])))

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0.rc1')
  def test_worker_harness_image_tag_matches_base_sdk_version_of_an_rc(self):
    # streaming, fnapi pipeline.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp', '--streaming'])
    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY +
            '/beam_python%d.%d_sdk:2.2.0' %
            (sys.version_info[0], sys.version_info[1])))

    # batch, legacy pipeline.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp'])
    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + '/python%d%d:2.2.0' %
            (sys.version_info[0], sys.version_info[1])))

  def test_worker_harness_override_takes_precedence_over_sdk_defaults(self):
    # streaming, fnapi pipeline.
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://any-location/temp',
        '--streaming',
        '--sdk_container_image=some:image'
    ])
    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage, 'some:image')
    # batch, legacy pipeline.
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://any-location/temp',
        '--sdk_container_image=some:image'
    ])
    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage, 'some:image')

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.Job.'
      'job_id_for_name',
      return_value='test_id')
  def test_transform_name_mapping(self, mock_job):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--update',
        '--transform_name_mapping',
        '{"from":"to"}'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertIsNotNone(job.proto.transformNameMapping)

  def test_created_from_snapshot_id(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--create_from_snapshot',
        'test_snapshot_id'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertEqual('test_snapshot_id', job.proto.createdFromSnapshotId)

  def test_labels(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertIsNone(job.proto.labels)

    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--label',
        'key1=value1',
        '--label',
        'key2',
        '--label',
        'key3=value3',
        '--labels',
        'key4=value4',
        '--labels',
        'key5'
    ])
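    # Labels may be given with or without a value; bare keys ('key2', 'key5')
    # map to the empty string, as the assertions below verify.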
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertEqual(5, len(job.proto.labels.additionalProperties))
    self.assertEqual('key1', job.proto.labels.additionalProperties[0].key)
    self.assertEqual('value1', job.proto.labels.additionalProperties[0].value)
    self.assertEqual('key2', job.proto.labels.additionalProperties[1].key)
    self.assertEqual('', job.proto.labels.additionalProperties[1].value)
    self.assertEqual('key3', job.proto.labels.additionalProperties[2].key)
    self.assertEqual('value3', job.proto.labels.additionalProperties[2].value)
    self.assertEqual('key4', job.proto.labels.additionalProperties[3].key)
    self.assertEqual('value4', job.proto.labels.additionalProperties[3].value)
    self.assertEqual('key5', job.proto.labels.additionalProperties[4].key)
    self.assertEqual('', job.proto.labels.additionalProperties[4].value)

    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--labels',
        '{ "name": "wrench", "mass": "1_3kg", "count": "3" }'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertEqual(3, len(job.proto.labels.additionalProperties))
    self.assertEqual('name', job.proto.labels.additionalProperties[0].key)
    self.assertEqual('wrench', job.proto.labels.additionalProperties[0].value)
    self.assertEqual('mass', job.proto.labels.additionalProperties[1].key)
    self.assertEqual('1_3kg', job.proto.labels.additionalProperties[1].value)
    self.assertEqual('count', job.proto.labels.additionalProperties[2].key)
    self.assertEqual('3', job.proto.labels.additionalProperties[2].value)

  def test_experiment_use_multiple_sdk_containers(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'beam_fn_api'
    ])
    environment = apiclient.Environment([],
                                        pipeline_options,
                                        1,
                                        FAKE_PIPELINE_URL)
    self.assertIn('use_multiple_sdk_containers', environment.proto.experiments)

    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'beam_fn_api',
        '--experiments',
        'use_multiple_sdk_containers'
    ])
    environment = apiclient.Environment([],
                                        pipeline_options,
                                        1,
                                        FAKE_PIPELINE_URL)
    self.assertIn('use_multiple_sdk_containers', environment.proto.experiments)

    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'beam_fn_api',
        '--experiments',
        'no_use_multiple_sdk_containers'
    ])
    environment = apiclient.Environment([],
                                        pipeline_options,
                                        1,
                                        FAKE_PIPELINE_URL)
    self.assertNotIn(
        'use_multiple_sdk_containers', environment.proto.experiments)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (3, 8))
  def test_get_python_sdk_name(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'beam_fn_api',
        '--experiments',
        'use_multiple_sdk_containers'
    ])
    environment = apiclient.Environment([],
                                        pipeline_options,
                                        1,
                                        FAKE_PIPELINE_URL)
    self.assertEqual(
        'Apache Beam Python 3.8 SDK', environment._get_python_sdk_name())

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (2, 7))
  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_interpreter_version_check_fails_py27(self):
    pipeline_options = PipelineOptions([])
    self.assertRaises(
        Exception,
        apiclient._verify_interpreter_version_is_supported,
        pipeline_options)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (3, 0, 0))
  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0.dev')
  def test_interpreter_version_check_passes_on_dev_sdks(self):
    pipeline_options = PipelineOptions([])
    apiclient._verify_interpreter_version_is_supported(pipeline_options)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (3, 0, 0))
  def test_interpreter_version_check_passes_with_experiment(self):
    pipeline_options = PipelineOptions(
        ["--experiment=use_unsupported_python_version"])
    apiclient._verify_interpreter_version_is_supported(pipeline_options)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (3, 8, 2))
  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_interpreter_version_check_passes_py38(self):
    pipeline_options = PipelineOptions([])
    apiclient._verify_interpreter_version_is_supported(pipeline_options)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (3, 12, 0))
  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_interpreter_version_check_fails_on_not_yet_supported_version(self):
    pipeline_options = PipelineOptions([])
    self.assertRaises(
        Exception,
        apiclient._verify_interpreter_version_is_supported,
        pipeline_options)

  def test_get_response_encoding(self):
    encoding = apiclient.get_response_encoding()

    self.assertEqual(encoding, 'utf8')

  def test_graph_is_uploaded(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'beam_fn_api',
        '--experiments',
        'upload_graph'
    ])
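    # With the upload_graph experiment, the job graph is staged to GCS as a
    # separate dataflow_graph.json file rather than embedded in the job
    # description, as the stage_file assertion below checks.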
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    pipeline_options.view_as(GoogleCloudOptions).no_auth = True
    client = apiclient.DataflowApplicationClient(pipeline_options)
    with mock.patch.object(client, 'stage_file', side_effect=None):
      with mock.patch.object(client, 'create_job_description',
                             side_effect=None):
        with mock.patch.object(client,
                               'submit_job_description',
                               side_effect=None):
          client.create_job(job)
          client.stage_file.assert_called_once_with(
              mock.ANY, "dataflow_graph.json", mock.ANY)
          client.create_job_description.assert_called_once()

  def test_create_job_returns_existing_job(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertTrue(job.proto.clientRequestId)  # asserts non-empty string
    pipeline_options.view_as(GoogleCloudOptions).no_auth = True
    client = apiclient.DataflowApplicationClient(pipeline_options)

    response = dataflow.Job()
    # different clientRequestId from `job`
    response.clientRequestId = "20210821081910123456-1234"
    response.name = 'test_job_name'
    response.id = '2021-08-19_21_18_43-9756917246311111021'

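    # A Create response carrying a different clientRequestId means the service
    # matched an already-running job with this name, so create_job is expected
    # to raise rather than return the existing job.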
    with mock.patch.object(client._client.projects_locations_jobs,
                           'Create',
                           side_effect=[response]):
      with mock.patch.object(client, 'create_job_description',
                             side_effect=None):
        with self.assertRaises(
            apiclient.DataflowJobAlreadyExistsError) as context:
          client.create_job(job)

        self.assertEqual(
            str(context.exception),
            'There is already active job named %s with id: %s. If you want to '
            'submit a second job, try again by setting a different name using '
            '--job_name.' % ('test_job_name', response.id))

  def test_update_job_returns_existing_job(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--region',
        'us-central1',
        '--update',
    ])
    replace_job_id = '2021-08-21_00_00_01-6081497447916622336'
    with mock.patch('apache_beam.runners.dataflow.internal.apiclient.Job.'
                    'job_id_for_name',
                    return_value=replace_job_id) as job_id_for_name_mock:
      job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    job_id_for_name_mock.assert_called_once()

    self.assertTrue(job.proto.clientRequestId)  # asserts non-empty string

    pipeline_options.view_as(GoogleCloudOptions).no_auth = True
    client = apiclient.DataflowApplicationClient(pipeline_options)

    response = dataflow.Job()
    # different clientRequestId from `job`
    response.clientRequestId = "20210821083254123456-1234"
    response.name = 'test_job_name'
    response.id = '2021-08-19_21_29_07-5725551945600207770'

    with mock.patch.object(client, 'create_job_description', side_effect=None):
      with mock.patch.object(client._client.projects_locations_jobs,
                             'Create',
                             side_effect=[response]):

        with self.assertRaises(
            apiclient.DataflowJobAlreadyExistsError) as context:
          client.create_job(job)

      self.assertEqual(
          str(context.exception),
          'The job named %s with id: %s has already been updated into job '
          'id: %s and cannot be updated again.' %
          ('test_job_name', replace_job_id, response.id))

  def test_template_file_generation_with_upload_graph(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'upload_graph',
        '--template_location',
        'gs://test-location/template'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    job.proto.steps.append(dataflow.Step(name='test_step_name'))

    pipeline_options.view_as(GoogleCloudOptions).no_auth = True
    client = apiclient.DataflowApplicationClient(pipeline_options)
    with mock.patch.object(client, 'stage_file', side_effect=None):
      with mock.patch.object(client, 'create_job_description',
                             side_effect=None):
        with mock.patch.object(client,
                               'submit_job_description',
                               side_effect=None):
          client.create_job(job)

          client.stage_file.assert_has_calls([
              mock.call(mock.ANY, 'dataflow_graph.json', mock.ANY),
              mock.call(mock.ANY, 'template', mock.ANY)
          ])
          client.create_job_description.assert_called_once()
          # template is generated, but job should not be submitted to the
          # service.
          client.submit_job_description.assert_not_called()

          template_filename = client.stage_file.call_args_list[-1][0][1]
          self.assertIn('template', template_filename)
          template_content = client.stage_file.call_args_list[-1][0][2].read(
          ).decode('utf-8')
          template_obj = json.loads(template_content)
          self.assertFalse(template_obj.get('steps'))
          self.assertTrue(template_obj['stepsLocation'])

  1142    def test_stage_resources(self):
  1143      pipeline_options = PipelineOptions([
  1144          '--temp_location',
  1145          'gs://test-location/temp',
  1146          '--staging_location',
  1147          'gs://test-location/staging',
  1148          '--no_auth'
  1149      ])
  1150      pipeline = beam_runner_api_pb2.Pipeline(
  1151          components=beam_runner_api_pb2.Components(
  1152              environments={
  1153                  'env1': beam_runner_api_pb2.Environment(
  1154                      dependencies=[
  1155                          beam_runner_api_pb2.ArtifactInformation(
  1156                              type_urn=common_urns.artifact_types.FILE.urn,
  1157                              type_payload=beam_runner_api_pb2.
  1158                              ArtifactFilePayload(
  1159                                  path='/tmp/foo1').SerializeToString(),
  1160                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1161                              role_payload=beam_runner_api_pb2.
  1162                              ArtifactStagingToRolePayload(
  1163                                  staged_name='foo1').SerializeToString()),
  1164                          beam_runner_api_pb2.ArtifactInformation(
  1165                              type_urn=common_urns.artifact_types.FILE.urn,
  1166                              type_payload=beam_runner_api_pb2.
  1167                              ArtifactFilePayload(
  1168                                  path='/tmp/bar1').SerializeToString(),
  1169                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1170                              role_payload=beam_runner_api_pb2.
  1171                              ArtifactStagingToRolePayload(
  1172                                  staged_name='bar1').SerializeToString()),
  1173                          beam_runner_api_pb2.ArtifactInformation(
  1174                              type_urn=common_urns.artifact_types.FILE.urn,
  1175                              type_payload=beam_runner_api_pb2.
  1176                              ArtifactFilePayload(
  1177                                  path='/tmp/baz').SerializeToString(),
  1178                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1179                              role_payload=beam_runner_api_pb2.
  1180                              ArtifactStagingToRolePayload(
  1181                                  staged_name='baz1').SerializeToString()),
  1182                          beam_runner_api_pb2.ArtifactInformation(
  1183                              type_urn=common_urns.artifact_types.FILE.urn,
  1184                              type_payload=beam_runner_api_pb2.
  1185                              ArtifactFilePayload(
  1186                                  path='/tmp/renamed1',
  1187                                  sha256='abcdefg').SerializeToString(),
  1188                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1189                              role_payload=beam_runner_api_pb2.
  1190                              ArtifactStagingToRolePayload(
  1191                                  staged_name='renamed1').SerializeToString())
  1192                      ]),
  1193                  'env2': beam_runner_api_pb2.Environment(
  1194                      dependencies=[
  1195                          beam_runner_api_pb2.ArtifactInformation(
  1196                              type_urn=common_urns.artifact_types.FILE.urn,
  1197                              type_payload=beam_runner_api_pb2.
  1198                              ArtifactFilePayload(
  1199                                  path='/tmp/foo2').SerializeToString(),
  1200                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1201                              role_payload=beam_runner_api_pb2.
  1202                              ArtifactStagingToRolePayload(
  1203                                  staged_name='foo2').SerializeToString()),
  1204                          beam_runner_api_pb2.ArtifactInformation(
  1205                              type_urn=common_urns.artifact_types.FILE.urn,
  1206                              type_payload=beam_runner_api_pb2.
  1207                              ArtifactFilePayload(
  1208                                  path='/tmp/bar2').SerializeToString(),
  1209                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1210                              role_payload=beam_runner_api_pb2.
  1211                              ArtifactStagingToRolePayload(
  1212                                  staged_name='bar2').SerializeToString()),
  1213                          beam_runner_api_pb2.ArtifactInformation(
  1214                              type_urn=common_urns.artifact_types.FILE.urn,
  1215                              type_payload=beam_runner_api_pb2.
  1216                              ArtifactFilePayload(
  1217                                  path='/tmp/baz').SerializeToString(),
  1218                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1219                              role_payload=beam_runner_api_pb2.
  1220                              ArtifactStagingToRolePayload(
  1221                                  staged_name='baz2').SerializeToString()),
  1222                          beam_runner_api_pb2.ArtifactInformation(
  1223                              type_urn=common_urns.artifact_types.FILE.urn,
  1224                              type_payload=beam_runner_api_pb2.
  1225                              ArtifactFilePayload(
  1226                                  path='/tmp/renamed2',
  1227                                  sha256='abcdefg').SerializeToString(),
  1228                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1229                              role_payload=beam_runner_api_pb2.
  1230                              ArtifactStagingToRolePayload(
  1231                                  staged_name='renamed2').SerializeToString())
  1232                      ])
  1233              }))
  1234      client = apiclient.DataflowApplicationClient(pipeline_options)
  1235      with mock.patch.object(apiclient._LegacyDataflowStager,
  1236                             'stage_job_resources') as mock_stager:
  1237        client._stage_resources(pipeline, pipeline_options)
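            # Identical artifacts are staged only once: env2's /tmp/baz and
            # /tmp/renamed2 (whose sha256 matches /tmp/renamed1) are
            # deduplicated, so only six of the eight declared artifacts reach
            # the stager.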
  1238      mock_stager.assert_called_once_with(
  1239          [('/tmp/foo1', 'foo1', ''), ('/tmp/bar1', 'bar1', ''),
  1240           ('/tmp/baz', 'baz1', ''), ('/tmp/renamed1', 'renamed1', 'abcdefg'),
  1241           ('/tmp/foo2', 'foo2', ''), ('/tmp/bar2', 'bar2', '')],
  1242          staging_location='gs://test-location/staging')
  1243  
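            # After staging, each FILE dependency is rewritten in place as a
            # URL artifact pointing at its staged location; the deduplicated
            # entries in env2 now point at env1's staged copies (baz1 and
            # renamed1).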
  1244      pipeline_expected = beam_runner_api_pb2.Pipeline(
  1245          components=beam_runner_api_pb2.Components(
  1246              environments={
  1247                  'env1': beam_runner_api_pb2.Environment(
  1248                      dependencies=[
  1249                          beam_runner_api_pb2.ArtifactInformation(
  1250                              type_urn=common_urns.artifact_types.URL.urn,
  1251                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1252                                  url='gs://test-location/staging/foo1'
  1253                              ).SerializeToString(),
  1254                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1255                              role_payload=beam_runner_api_pb2.
  1256                              ArtifactStagingToRolePayload(
  1257                                  staged_name='foo1').SerializeToString()),
  1258                          beam_runner_api_pb2.ArtifactInformation(
  1259                              type_urn=common_urns.artifact_types.URL.urn,
  1260                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1261                                  url='gs://test-location/staging/bar1').
  1262                              SerializeToString(),
  1263                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1264                              role_payload=beam_runner_api_pb2.
  1265                              ArtifactStagingToRolePayload(
  1266                                  staged_name='bar1').SerializeToString()),
  1267                          beam_runner_api_pb2.ArtifactInformation(
  1268                              type_urn=common_urns.artifact_types.URL.urn,
  1269                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1270                                  url='gs://test-location/staging/baz1').
  1271                              SerializeToString(),
  1272                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1273                              role_payload=beam_runner_api_pb2.
  1274                              ArtifactStagingToRolePayload(
  1275                                  staged_name='baz1').SerializeToString()),
  1276                          beam_runner_api_pb2.ArtifactInformation(
  1277                              type_urn=common_urns.artifact_types.URL.urn,
  1278                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1279                                  url='gs://test-location/staging/renamed1',
  1280                                  sha256='abcdefg').SerializeToString(),
  1281                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1282                              role_payload=beam_runner_api_pb2.
  1283                              ArtifactStagingToRolePayload(
  1284                                  staged_name='renamed1').SerializeToString())
  1285                      ]),
  1286                  'env2': beam_runner_api_pb2.Environment(
  1287                      dependencies=[
  1288                          beam_runner_api_pb2.ArtifactInformation(
  1289                              type_urn=common_urns.artifact_types.URL.urn,
  1290                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1291                                  url='gs://test-location/staging/foo2').
  1292                              SerializeToString(),
  1293                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1294                              role_payload=beam_runner_api_pb2.
  1295                              ArtifactStagingToRolePayload(
  1296                                  staged_name='foo2').SerializeToString()),
  1297                          beam_runner_api_pb2.ArtifactInformation(
  1298                              type_urn=common_urns.artifact_types.URL.urn,
  1299                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1300                                  url='gs://test-location/staging/bar2').
  1301                              SerializeToString(),
  1302                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1303                              role_payload=beam_runner_api_pb2.
  1304                              ArtifactStagingToRolePayload(
  1305                                  staged_name='bar2').SerializeToString()),
  1306                          beam_runner_api_pb2.ArtifactInformation(
  1307                              type_urn=common_urns.artifact_types.URL.urn,
  1308                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1309                                  url='gs://test-location/staging/baz1').
  1310                              SerializeToString(),
  1311                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1312                              role_payload=beam_runner_api_pb2.
  1313                              ArtifactStagingToRolePayload(
  1314                                  staged_name='baz1').SerializeToString()),
  1315                          beam_runner_api_pb2.ArtifactInformation(
  1316                              type_urn=common_urns.artifact_types.URL.urn,
  1317                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1318                                  url='gs://test-location/staging/renamed1',
  1319                                  sha256='abcdefg').SerializeToString(),
  1320                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1321                              role_payload=beam_runner_api_pb2.
  1322                              ArtifactStagingToRolePayload(
  1323                                  staged_name='renamed1').SerializeToString())
  1324                      ])
  1325              }))
  1326      self.assertEqual(pipeline, pipeline_expected)
  1327  
  1328    def test_set_dataflow_service_option(self):
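            # Tests that --dataflow_service_option values are propagated to
            # the serviceOptions field of the environment proto.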
  1329      pipeline_options = PipelineOptions([
  1330          '--dataflow_service_option',
  1331          'whizz=bang',
  1332          '--temp_location',
  1333          'gs://any-location/temp'
  1334      ])
  1335      env = apiclient.Environment(
  1336          [],  # packages
  1337          pipeline_options,
  1338          '2.0.0',  # any environment version
  1339          FAKE_PIPELINE_URL)
  1340      self.assertEqual(env.proto.serviceOptions, ['whizz=bang'])
  1341  
  1342    def test_enable_hot_key_logging(self):
  1343      # Tests that enable_hot_key_logging is not set by default.
  1344      pipeline_options = PipelineOptions(
  1345          ['--temp_location', 'gs://any-location/temp'])
  1346      env = apiclient.Environment(
  1347          [],  # packages
  1348          pipeline_options,
  1349          '2.0.0',  # any environment version
  1350          FAKE_PIPELINE_URL)
  1351      self.assertIsNone(env.proto.debugOptions)
  1352  
  1353      # Now test that it is set when given.
  1354      pipeline_options = PipelineOptions([
  1355          '--enable_hot_key_logging', '--temp_location', 'gs://any-location/temp'
  1356      ])
  1357      env = apiclient.Environment(
  1358          [],  # packages
  1359          pipeline_options,
  1360          '2.0.0',  # any environment version
  1361          FAKE_PIPELINE_URL)
  1362      self.assertEqual(
  1363          env.proto.debugOptions, dataflow.DebugOptions(enableHotKeyLogging=True))
  1364  
  1365    def _mock_uncached_copy(self, staging_root, src, sha256, dst_name=None):
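            # Expected call sequence when an artifact is absent from the GCS
            # cache: check the content-addressed cache path
            # (<staging_root>/<cache prefix>/<sha256[:2]>/<sha256>), upload
            # the local file there, then copy it from the cache into the
            # staging directory.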
  1366      sha_prefix = sha256[0:2]
  1367      gcs_cache_path = FileSystems.join(
  1368          staging_root,
  1369          apiclient.DataflowApplicationClient._GCS_CACHE_PREFIX,
  1370          sha_prefix,
  1371          sha256)
  1372  
  1373      if not dst_name:
  1374        _, dst_name = os.path.split(src)
  1375      return [
  1376          mock.call.gcs_exists(gcs_cache_path),
  1377          mock.call.gcs_upload(src, gcs_cache_path),
  1378          mock.call.gcs_gcs_copy(
  1379              source_file_names=[gcs_cache_path],
  1380              destination_file_names=[FileSystems.join(staging_root, dst_name)])
  1381      ]
  1382  
  1383    def _mock_cached_copy(self, staging_root, src, sha256, dst_name=None):
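            # A cache hit expects the same call sequence minus the upload.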
  1384      uncached = self._mock_uncached_copy(staging_root, src, sha256, dst_name)
  1385      uncached.pop(1)
  1386      return uncached
  1387  
  1388    def test_stage_artifacts_with_caching(self):
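            # Tests staging with --enable_artifact_caching: artifacts whose
            # sha256 already exists under the GCS cache prefix are copied from
            # the cache instead of being re-uploaded.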
  1389      pipeline_options = PipelineOptions([
  1390          '--temp_location',
  1391          'gs://test-location/temp',
  1392          '--staging_location',
  1393          'gs://test-location/staging',
  1394          '--no_auth',
  1395          '--enable_artifact_caching'
  1396      ])
  1397      pipeline = beam_runner_api_pb2.Pipeline(
  1398          components=beam_runner_api_pb2.Components(
  1399              environments={
  1400                  'env1': beam_runner_api_pb2.Environment(
  1401                      dependencies=[
  1402                          beam_runner_api_pb2.ArtifactInformation(
  1403                              type_urn=common_urns.artifact_types.FILE.urn,
  1404                              type_payload=beam_runner_api_pb2.
  1405                              ArtifactFilePayload(
  1406                                  path='/tmp/foo1',
  1407                                  sha256='abcd').SerializeToString(),
  1408                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1409                              role_payload=beam_runner_api_pb2.
  1410                              ArtifactStagingToRolePayload(
  1411                                  staged_name='foo1').SerializeToString()),
  1412                          beam_runner_api_pb2.ArtifactInformation(
  1413                              type_urn=common_urns.artifact_types.FILE.urn,
  1414                              type_payload=beam_runner_api_pb2.
  1415                              ArtifactFilePayload(
  1416                                  path='/tmp/bar1',
  1417                                  sha256='defg').SerializeToString(),
  1418                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1419                              role_payload=beam_runner_api_pb2.
  1420                              ArtifactStagingToRolePayload(
  1421                                  staged_name='bar1').SerializeToString()),
  1422                          beam_runner_api_pb2.ArtifactInformation(
  1423                              type_urn=common_urns.artifact_types.FILE.urn,
  1424                              type_payload=beam_runner_api_pb2.
  1425                              ArtifactFilePayload(path='/tmp/baz', sha256='hijk'
  1426                                                  ).SerializeToString(),
  1427                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1428                              role_payload=beam_runner_api_pb2.
  1429                              ArtifactStagingToRolePayload(
  1430                                  staged_name='baz1').SerializeToString()),
  1431                          beam_runner_api_pb2.ArtifactInformation(
  1432                              type_urn=common_urns.artifact_types.FILE.urn,
  1433                              type_payload=beam_runner_api_pb2.
  1434                              ArtifactFilePayload(
  1435                                  path='/tmp/renamed1',
  1436                                  sha256='abcdefg').SerializeToString(),
  1437                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1438                              role_payload=beam_runner_api_pb2.
  1439                              ArtifactStagingToRolePayload(
  1440                                  staged_name='renamed1').SerializeToString())
  1441                      ]),
  1442                  'env2': beam_runner_api_pb2.Environment(
  1443                      dependencies=[
  1444                          beam_runner_api_pb2.ArtifactInformation(
  1445                              type_urn=common_urns.artifact_types.FILE.urn,
  1446                              type_payload=beam_runner_api_pb2.
  1447                              ArtifactFilePayload(
  1448                                  path='/tmp/foo2',
  1449                                  sha256='lmno').SerializeToString(),
  1450                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1451                              role_payload=beam_runner_api_pb2.
  1452                              ArtifactStagingToRolePayload(
  1453                                  staged_name='foo2').SerializeToString()),
  1454                          beam_runner_api_pb2.ArtifactInformation(
  1455                              type_urn=common_urns.artifact_types.FILE.urn,
  1456                              type_payload=beam_runner_api_pb2.
  1457                              ArtifactFilePayload(
  1458                                  path='/tmp/bar2',
  1459                                  sha256='pqrs').SerializeToString(),
  1460                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1461                              role_payload=beam_runner_api_pb2.
  1462                              ArtifactStagingToRolePayload(
  1463                                  staged_name='bar2').SerializeToString()),
  1464                          beam_runner_api_pb2.ArtifactInformation(
  1465                              type_urn=common_urns.artifact_types.FILE.urn,
  1466                              type_payload=beam_runner_api_pb2.
  1467                              ArtifactFilePayload(path='/tmp/baz', sha256='tuv'
  1468                                                  ).SerializeToString(),
  1469                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1470                              role_payload=beam_runner_api_pb2.
  1471                              ArtifactStagingToRolePayload(
  1472                                  staged_name='baz2').SerializeToString()),
  1473                          beam_runner_api_pb2.ArtifactInformation(
  1474                              type_urn=common_urns.artifact_types.FILE.urn,
  1475                              type_payload=beam_runner_api_pb2.
  1476                              ArtifactFilePayload(
  1477                                  path='/tmp/renamed2',
  1478                                  sha256='abcdefg').SerializeToString(),
  1479                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1480                              role_payload=beam_runner_api_pb2.
  1481                              ArtifactStagingToRolePayload(
  1482                                  staged_name='renamed2').SerializeToString())
  1483                      ])
  1484              }))
  1485      client = apiclient.DataflowApplicationClient(pipeline_options)
  1486      staging_root = 'gs://test-location/staging'
  1487  
  1488      # Every other artifact already exists in the GCS cache.
  1489      n = [0]
  1490  
  1491      def exists_return_value(*args):
  1492        n[0] += 1
  1493        return n[0] % 2 == 0
  1494  
  1495      with mock.patch.object(FileSystems,
  1496                             'exists',
  1497                             side_effect=exists_return_value) as mock_gcs_exists:
  1498        with mock.patch.object(apiclient.DataflowApplicationClient,
  1499                               '_uncached_gcs_file_copy') as mock_gcs_copy:
  1500          with mock.patch.object(FileSystems, 'copy') as mock_gcs_gcs_copy:
  1501  
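                  # Attach all three mocks to a single manager so that their
                  # relative call order can be asserted below.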
  1502            manager = mock.Mock()
  1503            manager.attach_mock(mock_gcs_exists, 'gcs_exists')
  1504            manager.attach_mock(mock_gcs_copy, 'gcs_upload')
  1505            manager.attach_mock(mock_gcs_gcs_copy, 'gcs_gcs_copy')
  1506  
  1507            client._stage_resources(pipeline, pipeline_options)
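                  # foo1, /tmp/baz, and foo2 miss the alternating cache while
                  # bar1, renamed1, and bar2 hit it; env2's /tmp/baz and
                  # /tmp/renamed2 are deduplicated and never re-staged.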
  1508            expected_calls = list(
  1509                itertools.chain.from_iterable([
  1510                    self._mock_uncached_copy(staging_root, '/tmp/foo1', 'abcd'),
  1511                    self._mock_cached_copy(staging_root, '/tmp/bar1', 'defg'),
  1512                    self._mock_uncached_copy(
  1513                        staging_root, '/tmp/baz', 'hijk', 'baz1'),
  1514                    self._mock_cached_copy(
  1515                        staging_root, '/tmp/renamed1', 'abcdefg'),
  1516                    self._mock_uncached_copy(staging_root, '/tmp/foo2', 'lmno'),
  1517                    self._mock_cached_copy(staging_root, '/tmp/bar2', 'pqrs'),
  1518                ]))
  1519            self.assertEqual(manager.mock_calls, expected_calls)
  1520  
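            # As in the previous test, every dependency is rewritten as a URL
            # artifact at its staged location, now carrying its sha256.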
  1521      pipeline_expected = beam_runner_api_pb2.Pipeline(
  1522          components=beam_runner_api_pb2.Components(
  1523              environments={
  1524                  'env1': beam_runner_api_pb2.Environment(
  1525                      dependencies=[
  1526                          beam_runner_api_pb2.ArtifactInformation(
  1527                              type_urn=common_urns.artifact_types.URL.urn,
  1528                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1529                                  url='gs://test-location/staging/foo1',
  1530                                  sha256='abcd').SerializeToString(),
  1531                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1532                              role_payload=beam_runner_api_pb2.
  1533                              ArtifactStagingToRolePayload(
  1534                                  staged_name='foo1').SerializeToString()),
  1535                          beam_runner_api_pb2.ArtifactInformation(
  1536                              type_urn=common_urns.artifact_types.URL.urn,
  1537                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1538                                  url='gs://test-location/staging/bar1',
  1539                                  sha256='defg').SerializeToString(),
  1540                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1541                              role_payload=beam_runner_api_pb2.
  1542                              ArtifactStagingToRolePayload(
  1543                                  staged_name='bar1').SerializeToString()),
  1544                          beam_runner_api_pb2.ArtifactInformation(
  1545                              type_urn=common_urns.artifact_types.URL.urn,
  1546                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1547                                  url='gs://test-location/staging/baz1',
  1548                                  sha256='hijk').SerializeToString(),
  1549                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1550                              role_payload=beam_runner_api_pb2.
  1551                              ArtifactStagingToRolePayload(
  1552                                  staged_name='baz1').SerializeToString()),
  1553                          beam_runner_api_pb2.ArtifactInformation(
  1554                              type_urn=common_urns.artifact_types.URL.urn,
  1555                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1556                                  url='gs://test-location/staging/renamed1',
  1557                                  sha256='abcdefg').SerializeToString(),
  1558                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1559                              role_payload=beam_runner_api_pb2.
  1560                              ArtifactStagingToRolePayload(
  1561                                  staged_name='renamed1').SerializeToString())
  1562                      ]),
  1563                  'env2': beam_runner_api_pb2.Environment(
  1564                      dependencies=[
  1565                          beam_runner_api_pb2.ArtifactInformation(
  1566                              type_urn=common_urns.artifact_types.URL.urn,
  1567                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1568                                  url='gs://test-location/staging/foo2',
  1569                                  sha256='lmno').SerializeToString(),
  1570                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1571                              role_payload=beam_runner_api_pb2.
  1572                              ArtifactStagingToRolePayload(
  1573                                  staged_name='foo2').SerializeToString()),
  1574                          beam_runner_api_pb2.ArtifactInformation(
  1575                              type_urn=common_urns.artifact_types.URL.urn,
  1576                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1577                                  url='gs://test-location/staging/bar2',
  1578                                  sha256='pqrs').SerializeToString(),
  1579                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1580                              role_payload=beam_runner_api_pb2.
  1581                              ArtifactStagingToRolePayload(
  1582                                  staged_name='bar2').SerializeToString()),
  1583                          beam_runner_api_pb2.ArtifactInformation(
  1584                              type_urn=common_urns.artifact_types.URL.urn,
  1585                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1586                                  url='gs://test-location/staging/baz1',
  1587                                  sha256='tuv').SerializeToString(),
  1588                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1589                              role_payload=beam_runner_api_pb2.
  1590                              ArtifactStagingToRolePayload(
  1591                                  staged_name='baz1').SerializeToString()),
  1592                          beam_runner_api_pb2.ArtifactInformation(
  1593                              type_urn=common_urns.artifact_types.URL.urn,
  1594                              type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
  1595                                  url='gs://test-location/staging/renamed1',
  1596                                  sha256='abcdefg').SerializeToString(),
  1597                              role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1598                              role_payload=beam_runner_api_pb2.
  1599                              ArtifactStagingToRolePayload(
  1600                                  staged_name='renamed1').SerializeToString())
  1601                      ])
  1602              }))
  1603      self.assertEqual(pipeline, pipeline_expected)
  1604  
  1605  
  1606  if __name__ == '__main__':
  1607    unittest.main()