github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dataflow/internal/apiclient.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """For internal use only. No backwards compatibility guarantees.
    19  
    20  Dataflow client utility functions."""
    21  
    22  # pytype: skip-file
    23  # To regenerate the client:
    24  # pip install google-apitools[cli]
    25  # gen_client --discovery_url=cloudbuild.v1 --overwrite \
    26  #  --outdir=apache_beam/runners/dataflow/internal/clients/cloudbuild \
    27  #  --root_package=. client
    28  
    29  import ast
    30  import codecs
    31  from functools import partial
    32  import getpass
    33  import hashlib
    34  import io
    35  import json
    36  import logging
    37  import os
    38  import random
    39  import string
    40  import re
    41  import sys
    42  import time
    43  import warnings
    44  from copy import copy
    45  from datetime import datetime
    46  
    47  import pkg_resources
    48  
    49  from apitools.base.py import encoding
    50  from apitools.base.py import exceptions
    51  
    52  from apache_beam import version as beam_version
    53  from apache_beam.internal.gcp.auth import get_service_credentials
    54  from apache_beam.internal.gcp.json_value import to_json_value
    55  from apache_beam.internal.http_client import get_new_http
    56  from apache_beam.io.filesystems import FileSystems
    57  from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
    58  from apache_beam.io.gcp.internal.clients import storage
    59  from apache_beam.options.pipeline_options import DebugOptions
    60  from apache_beam.options.pipeline_options import GoogleCloudOptions
    61  from apache_beam.options.pipeline_options import StandardOptions
    62  from apache_beam.options.pipeline_options import WorkerOptions
    63  from apache_beam.portability import common_urns
    64  from apache_beam.portability.api import beam_runner_api_pb2
    65  from apache_beam.runners.common import validate_pipeline_graph
    66  from apache_beam.runners.dataflow.internal import names
    67  from apache_beam.runners.dataflow.internal.clients import dataflow
    68  from apache_beam.runners.dataflow.internal.names import PropertyNames
    69  from apache_beam.runners.internal import names as shared_names
    70  from apache_beam.runners.portability.stager import Stager
    71  from apache_beam.transforms import DataflowDistributionCounter
    72  from apache_beam.transforms import cy_combiners
    73  from apache_beam.transforms.display import DisplayData
    74  from apache_beam.transforms.environments import is_apache_beam_container
    75  from apache_beam.utils import retry
    76  from apache_beam.utils import proto_utils
    77  
    78  # Environment version information. It is passed to the service during a
    79  # job submission and is used by the service to establish what features
    80  # are expected by the workers.
    81  _LEGACY_ENVIRONMENT_MAJOR_VERSION = '8'
    82  _FNAPI_ENVIRONMENT_MAJOR_VERSION = '8'
    83  
    84  _LOGGER = logging.getLogger(__name__)
    85  
    86  _PYTHON_VERSIONS_SUPPORTED_BY_DATAFLOW = ['3.7', '3.8', '3.9', '3.10', '3.11']
    87  
    88  
    89  class Step(object):
    90    """Wrapper for a dataflow Step protobuf."""
    91    def __init__(self, step_kind, step_name, additional_properties=None):
    92      self.step_kind = step_kind
    93      self.step_name = step_name
    94      self.proto = dataflow.Step(kind=step_kind, name=step_name)
    95      self.proto.properties = {}
    96      self._additional_properties = []
    97  
    98      if additional_properties is not None:
    99        for (n, v, t) in additional_properties:
   100          self.add_property(n, v, t)
   101  
   102    def add_property(self, name, value, with_type=False):
   103      self._additional_properties.append((name, value, with_type))
   104      self.proto.properties.additionalProperties.append(
   105          dataflow.Step.PropertiesValue.AdditionalProperty(
   106              key=name, value=to_json_value(value, with_type=with_type)))
   107  
   108    def _get_outputs(self):
   109      """Returns a list of all output labels for a step."""
   110      outputs = []
   111      for p in self.proto.properties.additionalProperties:
   112        if p.key == PropertyNames.OUTPUT_INFO:
   113          for entry in p.value.array_value.entries:
   114            for entry_prop in entry.object_value.properties:
   115              if entry_prop.key == PropertyNames.OUTPUT_NAME:
   116                outputs.append(entry_prop.value.string_value)
   117      return outputs
   118  
   119    def __reduce__(self):
   120      """Reduce hook for pickling the Step class more easily."""
   121      return (Step, (self.step_kind, self.step_name, self._additional_properties))
   122  
   123    def get_output(self, tag=None):
   124      """Returns tag if it is one of the outputs, or the first output if tag is None.
   125  
   126      Args:
   127        tag: tag of the output as a string or None if we want to get the
   128          name of the first output.
   129  
   130      Returns:
   131        The name of the output associated with the tag or the first output
   132        if tag was None.
   133  
   134      Raises:
   135        ValueError: if the tag does not exist within outputs.
   136      """
   137      outputs = self._get_outputs()
   138      if tag is None or len(outputs) == 1:
   139        return outputs[0]
   140      else:
   141        if tag not in outputs:
   142          raise ValueError('Cannot find named output: %s in %s.' % (tag, outputs))
   143        return tag
   144  
   145  
   146  class Environment(object):
   147    """Wrapper for a dataflow Environment protobuf."""
   148    def __init__(
   149        self,
   150        packages,
   151        options,
   152        environment_version,
   153        proto_pipeline_staged_url,
   154        proto_pipeline=None):
   155      from apache_beam.runners.dataflow.dataflow_runner import _is_runner_v2
   156      self.standard_options = options.view_as(StandardOptions)
   157      self.google_cloud_options = options.view_as(GoogleCloudOptions)
   158      self.worker_options = options.view_as(WorkerOptions)
   159      self.debug_options = options.view_as(DebugOptions)
   160      self.pipeline_url = proto_pipeline_staged_url
   161      self.proto = dataflow.Environment()
   162      self.proto.clusterManagerApiService = GoogleCloudOptions.COMPUTE_API_SERVICE
   163      self.proto.dataset = '{}/cloud_dataflow'.format(
   164          GoogleCloudOptions.BIGQUERY_API_SERVICE)
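            # The replace() below rewrites the 'gs://<bucket>/<path>' temp
            # location into '<STORAGE_API_SERVICE>/<bucket>/<path>', the form the
            # service expects in tempStoragePrefix.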
   165      self.proto.tempStoragePrefix = (
   166          self.google_cloud_options.temp_location.replace(
   167              'gs:/', GoogleCloudOptions.STORAGE_API_SERVICE))
   168      if self.worker_options.worker_region:
   169        self.proto.workerRegion = self.worker_options.worker_region
   170      if self.worker_options.worker_zone:
   171        self.proto.workerZone = self.worker_options.worker_zone
   172      # User agent information.
   173      self.proto.userAgent = dataflow.Environment.UserAgentValue()
   174      self.local = 'localhost' in self.google_cloud_options.dataflow_endpoint
   175      self._proto_pipeline = proto_pipeline
   176  
   177      if self.google_cloud_options.service_account_email:
   178        self.proto.serviceAccountEmail = (
   179            self.google_cloud_options.service_account_email)
   180      if self.google_cloud_options.dataflow_kms_key:
   181        self.proto.serviceKmsKeyName = self.google_cloud_options.dataflow_kms_key
   182  
   183      self.proto.userAgent.additionalProperties.extend([
   184          dataflow.Environment.UserAgentValue.AdditionalProperty(
   185              key='name', value=to_json_value(self._get_python_sdk_name())),
   186          dataflow.Environment.UserAgentValue.AdditionalProperty(
   187              key='version', value=to_json_value(beam_version.__version__))
   188      ])
   189      # Version information.
   190      self.proto.version = dataflow.Environment.VersionValue()
   191      _verify_interpreter_version_is_supported(options)
   192      if self.standard_options.streaming:
   193        job_type = 'FNAPI_STREAMING'
   194      else:
   195        if _is_runner_v2(options):
   196          job_type = 'FNAPI_BATCH'
   197        else:
   198          job_type = 'PYTHON_BATCH'
   199      self.proto.version.additionalProperties.extend([
   200          dataflow.Environment.VersionValue.AdditionalProperty(
   201              key='job_type', value=to_json_value(job_type)),
   202          dataflow.Environment.VersionValue.AdditionalProperty(
   203              key='major', value=to_json_value(environment_version))
   204      ])
   205      # TODO: Use enumerated type instead of strings for job types.
   206      if job_type.startswith('FNAPI_'):
   207        self.debug_options.experiments = self.debug_options.experiments or []
   208  
   209        debug_options_experiments = self.debug_options.experiments
   210        # Add the use_multiple_sdk_containers flag if it is not already
   211        # present; do not add it if 'no_use_multiple_sdk_containers' is present.
   212        # TODO: Clean up use_multiple_sdk_containers once Python SDK versions
   213        # up to 2.4 are deprecated.
   214        if ('use_multiple_sdk_containers' not in debug_options_experiments and
   215            'no_use_multiple_sdk_containers' not in debug_options_experiments):
   216          debug_options_experiments.append('use_multiple_sdk_containers')
   217      # FlexRS
   218      if self.google_cloud_options.flexrs_goal == 'COST_OPTIMIZED':
   219        self.proto.flexResourceSchedulingGoal = (
   220            dataflow.Environment.FlexResourceSchedulingGoalValueValuesEnum.
   221            FLEXRS_COST_OPTIMIZED)
   222      elif self.google_cloud_options.flexrs_goal == 'SPEED_OPTIMIZED':
   223        self.proto.flexResourceSchedulingGoal = (
   224            dataflow.Environment.FlexResourceSchedulingGoalValueValuesEnum.
   225            FLEXRS_SPEED_OPTIMIZED)
   226      # Experiments
   227      if self.debug_options.experiments:
   228        for experiment in self.debug_options.experiments:
   229          self.proto.experiments.append(experiment)
   230      # Worker pool(s) information.
   231      package_descriptors = []
   232      for package in packages:
   233        package_descriptors.append(
   234            dataflow.Package(
   235                location='%s/%s' % (
   236                    self.google_cloud_options.staging_location.replace(
   237                        'gs:/', GoogleCloudOptions.STORAGE_API_SERVICE),
   238                    package),
   239                name=package))
   240  
   241      pool = dataflow.WorkerPool(
   242          kind='local' if self.local else 'harness',
   243          packages=package_descriptors,
   244          taskrunnerSettings=dataflow.TaskRunnerSettings(
   245              parallelWorkerSettings=dataflow.WorkerSettings(
   246                  baseUrl=GoogleCloudOptions.DATAFLOW_ENDPOINT,
   247                  servicePath=self.google_cloud_options.dataflow_endpoint)))
   248  
   249      pool.autoscalingSettings = dataflow.AutoscalingSettings()
   250      # Set worker pool options received through command line.
   251      if self.worker_options.num_workers:
   252        pool.numWorkers = self.worker_options.num_workers
   253      if self.worker_options.max_num_workers:
   254        pool.autoscalingSettings.maxNumWorkers = (
   255            self.worker_options.max_num_workers)
   256      if self.worker_options.autoscaling_algorithm:
   257        values_enum = dataflow.AutoscalingSettings.AlgorithmValueValuesEnum
   258        pool.autoscalingSettings.algorithm = {
   259            'NONE': values_enum.AUTOSCALING_ALGORITHM_NONE,
   260            'THROUGHPUT_BASED': values_enum.AUTOSCALING_ALGORITHM_BASIC,
   261        }.get(self.worker_options.autoscaling_algorithm)
   262      if self.worker_options.machine_type:
   263        pool.machineType = self.worker_options.machine_type
   264      if self.worker_options.disk_size_gb:
   265        pool.diskSizeGb = self.worker_options.disk_size_gb
   266      if self.worker_options.disk_type:
   267        pool.diskType = self.worker_options.disk_type
   268      if self.worker_options.zone:
   269        pool.zone = self.worker_options.zone
   270      if self.worker_options.network:
   271        pool.network = self.worker_options.network
   272      if self.worker_options.subnetwork:
   273        pool.subnetwork = self.worker_options.subnetwork
   274  
   275      # Setting worker pool sdk_harness_container_images option for supported
   276      # Dataflow workers.
   277      environments_to_use = self._get_environments_from_transforms()
   278  
   279      # Adding container images for other SDKs that may be needed for
   280      # cross-language pipelines.
   281      for id, environment in environments_to_use:
   282        if environment.urn != common_urns.environments.DOCKER.urn:
   283          raise Exception(
   284              'Dataflow can only execute pipeline steps in Docker environments.'
   285              ' Received %r.' % environment)
   286        environment_payload = proto_utils.parse_Bytes(
   287            environment.payload, beam_runner_api_pb2.DockerPayload)
   288        container_image_url = environment_payload.container_image
   289  
   290        container_image = dataflow.SdkHarnessContainerImage()
   291        container_image.containerImage = container_image_url
   292        container_image.useSingleCorePerContainer = (
   293            common_urns.protocols.MULTI_CORE_BUNDLE_PROCESSING.urn not in
   294            environment.capabilities)
   295        container_image.environmentId = id
   296        for capability in environment.capabilities:
   297          container_image.capabilities.append(capability)
   298        pool.sdkHarnessContainerImages.append(container_image)
   299  
   300      if not _is_runner_v2(options) or not pool.sdkHarnessContainerImages:
   301        pool.workerHarnessContainerImage = (
   302            get_container_image_from_options(options))
   303      elif len(pool.sdkHarnessContainerImages) == 1:
   304        # Dataflow expects a value here when there is only one environment.
   305        pool.workerHarnessContainerImage = (
   306            pool.sdkHarnessContainerImages[0].containerImage)
   307  
   308      if self.debug_options.number_of_worker_harness_threads:
   309        pool.numThreadsPerWorker = (
   310            self.debug_options.number_of_worker_harness_threads)
   311      if self.worker_options.use_public_ips is not None:
   312        if self.worker_options.use_public_ips:
   313          pool.ipConfiguration = (
   314              dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PUBLIC)
   315        else:
   316          pool.ipConfiguration = (
   317              dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PRIVATE
   318          )
   319  
   320      if self.standard_options.streaming:
   321        # Use separate data disk for streaming.
   322        disk = dataflow.Disk()
   323        if self.local:
   324          disk.diskType = 'local'
   325        if self.worker_options.disk_type:
   326          disk.diskType = self.worker_options.disk_type
   327        pool.dataDisks.append(disk)
   328      self.proto.workerPools.append(pool)
   329  
   330      sdk_pipeline_options = options.get_all_options(retain_unknown_options=True)
   331      if sdk_pipeline_options:
   332        self.proto.sdkPipelineOptions = (
   333            dataflow.Environment.SdkPipelineOptionsValue())
   334  
   335        options_dict = {
   336            k: v
   337            for k, v in sdk_pipeline_options.items() if v is not None
   338        }
   339        options_dict["pipelineUrl"] = proto_pipeline_staged_url
   340        # Don't pass impersonate_service_account through to the harness.
   341        # Although impersonation is used to start the job, the workers should
   342        # not try to modify their credentials.
   343        options_dict.pop('impersonate_service_account', None)
   344        self.proto.sdkPipelineOptions.additionalProperties.append(
   345            dataflow.Environment.SdkPipelineOptionsValue.AdditionalProperty(
   346                key='options', value=to_json_value(options_dict)))
   347  
   348        dd = DisplayData.create_from_options(options)
   349        items = [item.get_dict() for item in dd.items]
   350        self.proto.sdkPipelineOptions.additionalProperties.append(
   351            dataflow.Environment.SdkPipelineOptionsValue.AdditionalProperty(
   352                key='display_data', value=to_json_value(items)))
   353  
   354      if self.google_cloud_options.dataflow_service_options:
   355        for option in self.google_cloud_options.dataflow_service_options:
   356          self.proto.serviceOptions.append(option)
   357  
   358      if self.google_cloud_options.enable_hot_key_logging:
   359        self.proto.debugOptions = dataflow.DebugOptions(enableHotKeyLogging=True)
   360  
   361    def _get_environments_from_transforms(self):
   362      if not self._proto_pipeline:
   363        return []
   364  
   365      environment_ids = set(
   366          transform.environment_id
   367          for transform in self._proto_pipeline.components.transforms.values()
   368          if transform.environment_id)
   369  
   370      return [(id, self._proto_pipeline.components.environments[id])
   371              for id in environment_ids]
   372  
   373    def _get_python_sdk_name(self):
   374      python_version = '%d.%d' % (sys.version_info[0], sys.version_info[1])
   375      return 'Apache Beam Python %s SDK' % python_version
   376  
   377  
   378  class Job(object):
   379    """Wrapper for a dataflow Job protobuf."""
   380    def __str__(self):
   381      def encode_shortstrings(input_buffer, errors='strict'):
   382        """Encoder (from Unicode) that suppresses long base64 strings."""
   383        original_len = len(input_buffer)
   384        if original_len > 150:
   385          if self.base64_str_re.match(input_buffer):
   386            input_buffer = '<string of %d bytes>' % original_len
   387            input_buffer = input_buffer.encode('ascii', errors=errors)
   388          else:
   389            matched = self.coder_str_re.match(input_buffer)
   390            if matched:
   391              input_buffer = '%s<string of %d bytes>' % (
   392                  matched.group(1), matched.end(2) - matched.start(2))
   393              input_buffer = input_buffer.encode('ascii', errors=errors)
   394        return input_buffer, original_len
   395  
   396      def decode_shortstrings(input_buffer, errors='strict'):
   397        """Decoder (to Unicode) that suppresses long base64 strings."""
   398        shortened, length = encode_shortstrings(input_buffer, errors)
   399        return str(shortened), length
   400  
   401      def shortstrings_registerer(encoding_name):
   402        if encoding_name == 'shortstrings':
   403          return codecs.CodecInfo(
   404              name='shortstrings',
   405              encode=encode_shortstrings,
   406              decode=decode_shortstrings)
   407        return None
   408  
   409      codecs.register(shortstrings_registerer)
   410  
   411      # Use json "dump string" method to get readable formatting;
   412      # further modify it to not output too-long strings, aimed at the
   413      # 10,000+ character hex-encoded "serialized_fn" values.
   414      return json.dumps(
   415          json.loads(encoding.MessageToJson(self.proto)),
   416          indent=2,
   417          sort_keys=True)
   418  
   419    @staticmethod
   420    def _build_default_job_name(user_name):
   421      """Generates a default name for a job.
   422  
   423      user_name is lowercased, and any characters outside of [-a-z0-9]
   424      are removed. If necessary, the user_name is truncated to shorten
   425      the job name to 63 characters."""
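            # Illustrative example: for user 'alice' this could produce a name
            # like 'beamapp-alice-0412103000-123456-a1b2c3d4' (date/time,
            # microseconds, then 8 random characters).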
   426      user_name = re.sub('[^-a-z0-9]', '', user_name.lower())
   427      date_component = datetime.utcnow().strftime('%m%d%H%M%S-%f')
   428      app_user_name = 'beamapp-{}'.format(user_name)
   429      # Append 8 random lowercase alphanumeric characters to avoid collisions.
   430      random_component = ''.join(
   431          random.choices(string.ascii_lowercase + string.digits, k=8))
   432      job_name = '{}-{}-{}'.format(
   433          app_user_name, date_component, random_component)
   434      if len(job_name) > 63:
   435        job_name = '{}-{}-{}'.format(
   436            app_user_name[:-(len(job_name) - 63)],
   437            date_component,
   438            random_component)
   439      return job_name
   440  
   441    @staticmethod
   442    def default_job_name(job_name):
   443      if job_name is None:
   444        job_name = Job._build_default_job_name(getpass.getuser())
   445      return job_name
   446  
   447    def __init__(self, options, proto_pipeline):
   448      self.options = options
   449      validate_pipeline_graph(proto_pipeline)
   450      self.proto_pipeline = proto_pipeline
   451      self.google_cloud_options = options.view_as(GoogleCloudOptions)
   452      if not self.google_cloud_options.job_name:
   453        self.google_cloud_options.job_name = self.default_job_name(
   454            self.google_cloud_options.job_name)
   455  
   456      required_google_cloud_options = ['project', 'job_name', 'temp_location']
   457      missing = [
   458          option for option in required_google_cloud_options
   459          if not getattr(self.google_cloud_options, option)
   460      ]
   461      if missing:
   462        raise ValueError(
   463            'Missing required configuration parameters: %s' % missing)
   464  
   465      if not self.google_cloud_options.staging_location:
   466        _LOGGER.info(
   467            'Defaulting to the temp_location as staging_location: %s',
   468            self.google_cloud_options.temp_location)
   469        (
   470            self.google_cloud_options.staging_location
   471        ) = self.google_cloud_options.temp_location
   472  
   473      self.root_staging_location = self.google_cloud_options.staging_location
   474  
   475      # Make the staging and temp locations job-name and time specific. This is
   476      # needed to avoid clashes between job submissions using the same staging
   477      # area or between team members using the same job names. This is not
   478      # entirely foolproof since two submissions with the same name can happen
   479      # at exactly the same time, but the window is extremely small given that
   480      # time.time() has at least microsecond granularity. We add the suffix only
   481      # for GCS staging locations where the potential for such clashes is high.
   482      if self.google_cloud_options.staging_location.startswith('gs://'):
   483        path_suffix = '%s.%f' % (self.google_cloud_options.job_name, time.time())
   484        self.google_cloud_options.staging_location = FileSystems.join(
   485            self.google_cloud_options.staging_location, path_suffix)
   486        self.google_cloud_options.temp_location = FileSystems.join(
   487            self.google_cloud_options.temp_location, path_suffix)
   488  
   489      self.proto = dataflow.Job(name=self.google_cloud_options.job_name)
   490      if self.options.view_as(StandardOptions).streaming:
   491        self.proto.type = dataflow.Job.TypeValueValuesEnum.JOB_TYPE_STREAMING
   492      else:
   493        self.proto.type = dataflow.Job.TypeValueValuesEnum.JOB_TYPE_BATCH
   494      if self.google_cloud_options.update:
   495        self.proto.replaceJobId = self.job_id_for_name(self.proto.name)
   496        if self.google_cloud_options.transform_name_mapping:
   497          self.proto.transformNameMapping = (
   498              dataflow.Job.TransformNameMappingValue())
   499          for _, (key, value) in enumerate(
   500              self.google_cloud_options.transform_name_mapping.items()):
   501            self.proto.transformNameMapping.additionalProperties.append(
   502                dataflow.Job.TransformNameMappingValue.AdditionalProperty(
   503                    key=key, value=value))
   504      if self.google_cloud_options.create_from_snapshot:
   505        self.proto.createdFromSnapshotId = (
   506            self.google_cloud_options.create_from_snapshot)
   507      # Labels.
   508      if self.google_cloud_options.labels:
   509        self.proto.labels = dataflow.Job.LabelsValue()
   510        labels = self.google_cloud_options.labels
   511        for label in labels:
   512          if '{' in label:
   513            label = ast.literal_eval(label)
   514            for key, value in label.items():
   515              self.proto.labels.additionalProperties.append(
   516                  dataflow.Job.LabelsValue.AdditionalProperty(
   517                      key=key, value=value))
   518          else:
   519            parts = label.split('=', 1)
   520            key = parts[0]
   521            value = parts[1] if len(parts) > 1 else ''
   522            self.proto.labels.additionalProperties.append(
   523                dataflow.Job.LabelsValue.AdditionalProperty(key=key, value=value))
   524  
   525      # Client Request ID
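            # The Create response carries a clientRequestId;
            # submit_job_description() compares it against this value to detect
            # that a job with the same name already exists.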
   526      self.proto.clientRequestId = '{}-{}'.format(
   527          datetime.utcnow().strftime('%Y%m%d%H%M%S%f'),
   528          random.randrange(9000) + 1000)
   529  
   530      self.base64_str_re = re.compile(r'^[A-Za-z0-9+/]*=*$')
   531      self.coder_str_re = re.compile(r'^([A-Za-z]+\$)([A-Za-z0-9+/]*=*)$')
   532  
   533    def job_id_for_name(self, job_name):
   534      return DataflowApplicationClient(
   535          self.google_cloud_options).job_id_for_name(job_name)
   536  
   537    def json(self):
   538      return encoding.MessageToJson(self.proto)
   539  
   540    def __reduce__(self):
   541      """Reduce hook for pickling the Job class more easily."""
   542      return (Job, (self.options, ))
   543  
   544  
   545  class DataflowApplicationClient(object):
   546    """A Dataflow API client used by application code to create and query jobs."""
   547    _HASH_CHUNK_SIZE = 1024 * 8
   548    _GCS_CACHE_PREFIX = "artifact_cache"
   549    def __init__(self, options, root_staging_location=None):
   550      """Initializes a Dataflow API client object."""
   551      self.standard_options = options.view_as(StandardOptions)
   552      self.google_cloud_options = options.view_as(GoogleCloudOptions)
   553      self._enable_caching = self.google_cloud_options.enable_artifact_caching
   554      self._root_staging_location = (
   555          root_staging_location or self.google_cloud_options.staging_location)
   556  
   557      from apache_beam.runners.dataflow.dataflow_runner import _is_runner_v2
   558      if _is_runner_v2(options):
   559        self.environment_version = _FNAPI_ENVIRONMENT_MAJOR_VERSION
   560      else:
   561        self.environment_version = _LEGACY_ENVIRONMENT_MAJOR_VERSION
   562  
   563      if self.google_cloud_options.no_auth:
   564        credentials = None
   565      else:
   566        credentials = get_service_credentials(options)
   567  
   568      http_client = get_new_http()
   569      self._client = dataflow.DataflowV1b3(
   570          url=self.google_cloud_options.dataflow_endpoint,
   571          credentials=credentials,
   572          get_credentials=(not self.google_cloud_options.no_auth),
   573          http=http_client,
   574          response_encoding=get_response_encoding())
   575      self._storage_client = storage.StorageV1(
   576          url='https://www.googleapis.com/storage/v1',
   577          credentials=credentials,
   578          get_credentials=(not self.google_cloud_options.no_auth),
   579          http=http_client,
   580          response_encoding=get_response_encoding())
   581      self._sdk_image_overrides = self._get_sdk_image_overrides(options)
   582  
   583    def _get_sdk_image_overrides(self, pipeline_options):
   584      worker_options = pipeline_options.view_as(WorkerOptions)
   585      sdk_overrides = worker_options.sdk_harness_container_image_overrides
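            # Each override is expected in the form '<regex>,<replacement>'; only
            # the first comma is treated as the separator.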
   586      return (
   587          dict(s.split(',', 1) for s in sdk_overrides) if sdk_overrides else {})
   588  
   589    @staticmethod
   590    def _compute_sha256(file):
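            # Hash the file in _HASH_CHUNK_SIZE-byte chunks so that large
            # artifacts are not read into memory all at once.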
   591      hasher = hashlib.sha256()
   592      with open(file, 'rb') as f:
   593        for chunk in iter(partial(f.read,
   594                                  DataflowApplicationClient._HASH_CHUNK_SIZE),
   595                          b""):
   596          hasher.update(chunk)
   597      return hasher.hexdigest()
   598  
   599    def _cached_location(self, sha256):
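            # Artifacts are cached content-addressed under
            # <root_staging_location>/artifact_cache/<first 2 hex chars>/<sha256>.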
   600      sha_prefix = sha256[0:2]
   601      return FileSystems.join(
   602          self._root_staging_location,
   603          DataflowApplicationClient._GCS_CACHE_PREFIX,
   604          sha_prefix,
   605          sha256)
   606  
   607    def _gcs_file_copy(self, from_path, to_path, sha256):
   608      if self._enable_caching and sha256:
   609        self._cached_gcs_file_copy(from_path, to_path, sha256)
   610      else:
   611        self._uncached_gcs_file_copy(from_path, to_path)
   612  
   613    def _cached_gcs_file_copy(self, from_path, to_path, sha256):
   614      cached_path = self._cached_location(sha256)
   615      if FileSystems.exists(cached_path):
   616        _LOGGER.info(
   617            'Skipping upload of %s because it already exists at %s',
   618            to_path,
   619            cached_path)
   620      else:
   621        self._uncached_gcs_file_copy(from_path, cached_path)
   622  
   623      FileSystems.copy(
   624          source_file_names=[cached_path], destination_file_names=[to_path])
   625      _LOGGER.info('Copied cached artifact from %s to %s', from_path, to_path)
   626  
   627    @retry.with_exponential_backoff(
   628        retry_filter=retry.retry_on_server_errors_and_timeout_filter)
   629    def _uncached_gcs_file_copy(self, from_path, to_path):
   630      to_folder, to_name = os.path.split(to_path)
   631      total_size = os.path.getsize(from_path)
   632      with open(from_path, 'rb') as f:
   633        self.stage_file(to_folder, to_name, f, total_size=total_size)
   634  
   635    def _stage_resources(self, pipeline, options):
   636      google_cloud_options = options.view_as(GoogleCloudOptions)
   637      if google_cloud_options.staging_location is None:
   638        raise RuntimeError('The --staging_location option must be specified.')
   639      if google_cloud_options.temp_location is None:
   640        raise RuntimeError('The --temp_location option must be specified.')
   641  
   642      resources = []
   643      staged_paths = {}
   644      staged_hashes = {}
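            # De-duplicate artifacts by content hash and by local path so each
            # file is staged at most once across all environments.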
   645      for _, env in sorted(pipeline.components.environments.items(),
   646                           key=lambda kv: kv[0]):
   647        for dep in env.dependencies:
   648          if dep.type_urn != common_urns.artifact_types.FILE.urn:
   649            raise RuntimeError('unsupported artifact type %s' % dep.type_urn)
   650          type_payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
   651              dep.type_payload)
   652  
   653          if dep.role_urn == common_urns.artifact_roles.STAGING_TO.urn:
   654            remote_name = (
   655                beam_runner_api_pb2.ArtifactStagingToRolePayload.FromString(
   656                    dep.role_payload)).staged_name
   657            is_staged_role = True
   658          else:
   659            remote_name = os.path.basename(type_payload.path)
   660            is_staged_role = False
   661  
   662          if self._enable_caching and not type_payload.sha256:
   663            type_payload.sha256 = self._compute_sha256(type_payload.path)
   664  
   665          if type_payload.sha256 and type_payload.sha256 in staged_hashes:
   666            _LOGGER.info(
   667                'Found duplicated artifact sha256: %s (%s)',
   668                type_payload.path,
   669                type_payload.sha256)
   670            remote_name = staged_hashes[type_payload.sha256]
   671            if is_staged_role:
   672              # We should not be overriding this, as dep.role_payload.staged_name
   673              # refers to the desired name on the worker, whereas staged_name
   674              # refers to its placement in a distributed filesystem.
   675              # TODO(heejong): Clean this up.
   676              dep.role_payload = beam_runner_api_pb2.ArtifactStagingToRolePayload(
   677                  staged_name=remote_name).SerializeToString()
   678          elif type_payload.path and type_payload.path in staged_paths:
   679            _LOGGER.info(
   680                'Found duplicated artifact path: %s (%s)',
   681                type_payload.path,
   682                type_payload.sha256)
   683            remote_name = staged_paths[type_payload.path]
   684            if is_staged_role:
   685              # We should not be overriding this, as dep.role_payload.staged_name
   686              # refers to the desired name on the worker, whereas staged_name
   687              # refers to its placement in a distributed filesystem.
   688              # TODO(heejong): Clean this up.
   689              dep.role_payload = beam_runner_api_pb2.ArtifactStagingToRolePayload(
   690                  staged_name=remote_name).SerializeToString()
   691          else:
   692            resources.append(
   693                (type_payload.path, remote_name, type_payload.sha256))
   694            staged_paths[type_payload.path] = remote_name
   695            staged_hashes[type_payload.sha256] = remote_name
   696  
   697          if FileSystems.get_scheme(
   698              google_cloud_options.staging_location) == GCSFileSystem.scheme():
   699            dep.type_urn = common_urns.artifact_types.URL.urn
   700            dep.type_payload = beam_runner_api_pb2.ArtifactUrlPayload(
   701                url=FileSystems.join(
   702                    google_cloud_options.staging_location, remote_name),
   703                sha256=type_payload.sha256).SerializeToString()
   704          else:
   705            dep.type_payload = beam_runner_api_pb2.ArtifactFilePayload(
   706                path=FileSystems.join(
   707                    google_cloud_options.staging_location, remote_name),
   708                sha256=type_payload.sha256).SerializeToString()
   709  
   710      resource_stager = _LegacyDataflowStager(self)
   711      staged_resources = resource_stager.stage_job_resources(
   712          resources, staging_location=google_cloud_options.staging_location)
   713      return staged_resources
   714  
   715    def stage_file(
   716        self,
   717        gcs_or_local_path,
   718        file_name,
   719        stream,
   720        mime_type='application/octet-stream',
   721        total_size=None):
   722      """Stages a file at a GCS or local path with stream-supplied contents."""
   723      if not gcs_or_local_path.startswith('gs://'):
   724        local_path = FileSystems.join(gcs_or_local_path, file_name)
   725        _LOGGER.info('Staging file locally to %s', local_path)
   726        with open(local_path, 'wb') as f:
   727          f.write(stream.read())
   728        return
   729      gcs_location = FileSystems.join(gcs_or_local_path, file_name)
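            # Split 'gs://<bucket>/<object>' into bucket and object name; [5:]
            # strips the 'gs://' prefix.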
   730      bucket, name = gcs_location[5:].split('/', 1)
   731  
   732      request = storage.StorageObjectsInsertRequest(bucket=bucket, name=name)
   733      start_time = time.time()
   734      _LOGGER.info('Starting GCS upload to %s...', gcs_location)
   735      upload = storage.Upload(stream, mime_type, total_size)
   736      try:
   737        response = self._storage_client.objects.Insert(request, upload=upload)
   738      except exceptions.HttpError as e:
   739        reportable_errors = {
   740            403: 'access denied',
   741            404: 'bucket not found',
   742        }
   743        if e.status_code in reportable_errors:
   744          raise IOError((
   745              'Could not upload to GCS path %s: %s. Please verify '
   746              'that credentials are valid and that you have write '
   747              'access to the specified path.') %
   748                        (gcs_or_local_path, reportable_errors[e.status_code]))
   749        raise
   750      _LOGGER.info(
   751          'Completed GCS upload to %s in %s seconds.',
   752          gcs_location,
   753          int(time.time() - start_time))
   754      return response
   755  
   756    @retry.no_retries  # Using no_retries marks this as an integration point.
   757    def create_job(self, job):
   758      """Creates job description. May stage and/or submit for remote execution."""
   759      self.create_job_description(job)
   760  
   761      # Stage and submit the job when necessary
   762      dataflow_job_file = job.options.view_as(DebugOptions).dataflow_job_file
   763      template_location = (
   764          job.options.view_as(GoogleCloudOptions).template_location)
   765  
   766      if job.options.view_as(DebugOptions).lookup_experiment('upload_graph'):
   767        self.stage_file(
   768            job.options.view_as(GoogleCloudOptions).staging_location,
   769            "dataflow_graph.json",
   770            io.BytesIO(job.json().encode('utf-8')))
   771        del job.proto.steps[:]
   772        job.proto.stepsLocation = FileSystems.join(
   773            job.options.view_as(GoogleCloudOptions).staging_location,
   774            "dataflow_graph.json")
   775  
   776      # Template file generation should be placed immediately before the
   777      # conditional API call.
   778      job_location = template_location or dataflow_job_file
   779      if job_location:
   780        gcs_or_local_path = os.path.dirname(job_location)
   781        file_name = os.path.basename(job_location)
   782        self.stage_file(
   783            gcs_or_local_path, file_name, io.BytesIO(job.json().encode('utf-8')))
   784  
   785      if not template_location:
   786        return self.submit_job_description(job)
   787  
   788      _LOGGER.info(
   789          'A template was just created at location %s', template_location)
   790      return None
   791  
   792    @staticmethod
   793    def _update_container_image_for_dataflow(beam_container_image_url):
   794      # By default Dataflow pipelines use containers hosted in Dataflow GCR
   795      # instead of Docker Hub.
   796      image_suffix = beam_container_image_url.rsplit('/', 1)[1]
   797      return names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + '/' + image_suffix
   798  
   799    @staticmethod
   800    def _apply_sdk_environment_overrides(
   801        proto_pipeline, sdk_overrides, pipeline_options):
   802      # Updates container image URLs for Dataflow.
   803      # For a given container image URL:
   804      # * If a matching override has been provided, that override is used.
   805      # * For improved performance, Apache Beam container images for other
   806      #   SDKs that are not explicitly overridden are updated to use copies
   807      #   hosted in GCR instead of being downloaded directly from
   808      #   Docker Hub.
   809  
   810      current_sdk_container_image = get_container_image_from_options(
   811          pipeline_options)
   812  
   813      for environment in proto_pipeline.components.environments.values():
   814        docker_payload = proto_utils.parse_Bytes(
   815            environment.payload, beam_runner_api_pb2.DockerPayload)
   816        overridden = False
   817        new_container_image = docker_payload.container_image
   818        for pattern, override in sdk_overrides.items():
   819          new_container_image = re.sub(pattern, override, new_container_image)
   820          if new_container_image != docker_payload.container_image:
   821            overridden = True
   822  
   823        # Container of the current (Python) SDK is overridden separately, hence
   824        # not updated here.
   825        if (is_apache_beam_container(new_container_image) and not overridden and
   826            new_container_image != current_sdk_container_image):
   827          new_container_image = (
   828              DataflowApplicationClient._update_container_image_for_dataflow(
   829                  docker_payload.container_image))
   830  
   831        if not new_container_image:
   832          raise ValueError(
   833              'SDK Docker container image has to be a non-empty string')
   834  
   835        new_payload = copy(docker_payload)
   836        new_payload.container_image = new_container_image
   837        environment.payload = new_payload.SerializeToString()
   838  
   839    def create_job_description(self, job):
   840      """Creates a job described by the workflow proto."""
   841      DataflowApplicationClient._apply_sdk_environment_overrides(
   842          job.proto_pipeline, self._sdk_image_overrides, job.options)
   843  
   844      # Stage other resources for the SDK harness
   845      resources = self._stage_resources(job.proto_pipeline, job.options)
   846  
   847      # Stage proto pipeline.
   848      self.stage_file(
   849          job.google_cloud_options.staging_location,
   850          shared_names.STAGED_PIPELINE_FILENAME,
   851          io.BytesIO(job.proto_pipeline.SerializeToString()))
   852  
   853      job.proto.environment = Environment(
   854          proto_pipeline_staged_url=FileSystems.join(
   855              job.google_cloud_options.staging_location,
   856              shared_names.STAGED_PIPELINE_FILENAME),
   857          packages=resources,
   858          options=job.options,
   859          environment_version=self.environment_version,
   860          proto_pipeline=job.proto_pipeline).proto
   861      _LOGGER.debug('JOB: %s', job)
   862  
   863    @retry.with_exponential_backoff(num_retries=3, initial_delay_secs=3)
   864    def get_job_metrics(self, job_id):
   865      request = dataflow.DataflowProjectsLocationsJobsGetMetricsRequest()
   866      request.jobId = job_id
   867      request.location = self.google_cloud_options.region
   868      request.projectId = self.google_cloud_options.project
   869      try:
   870        response = self._client.projects_locations_jobs.GetMetrics(request)
   871      except exceptions.BadStatusCodeError as e:
   872        _LOGGER.error(
   873            'HTTP status %d. Unable to query metrics', e.response.status)
   874        raise
   875      return response
   876  
   877    @retry.with_exponential_backoff(num_retries=3)
   878    def submit_job_description(self, job):
   879      """Creates and executes a job request."""
   880      request = dataflow.DataflowProjectsLocationsJobsCreateRequest()
   881      request.projectId = self.google_cloud_options.project
   882      request.location = self.google_cloud_options.region
   883      request.job = job.proto
   884  
   885      try:
   886        response = self._client.projects_locations_jobs.Create(request)
   887      except exceptions.BadStatusCodeError as e:
   888        _LOGGER.error(
   889            'HTTP status %d trying to create job'
   890            ' at dataflow service endpoint %s',
   891            e.response.status,
   892            self.google_cloud_options.dataflow_endpoint)
   893        _LOGGER.fatal('details of server error: %s', e)
   894        raise
   895  
   896      if response.clientRequestId and \
   897          response.clientRequestId != job.proto.clientRequestId:
   898        if self.google_cloud_options.update:
   899          raise DataflowJobAlreadyExistsError(
   900              "The job named %s with id: %s has already been updated into job "
   901              "id: %s and cannot be updated again." %
   902              (response.name, job.proto.replaceJobId, response.id))
   903        else:
   904          raise DataflowJobAlreadyExistsError(
   905            'There is already an active job named %s with id: %s. If you want '
   906            'to submit a second job, try again by setting a different name '
   907            'using --job_name.' % (response.name, response.id))
   908  
   909      _LOGGER.info('Create job: %s', response)
   910      # The response is a Job proto with the id for the new job.
   911      _LOGGER.info('Created job with id: [%s]', response.id)
   912      _LOGGER.info('Submitted job: %s', response.id)
   913      _LOGGER.info(
   914          'To access the Dataflow monitoring console, please navigate to '
   915          'https://console.cloud.google.com/dataflow/jobs/%s/%s?project=%s',
   916          self.google_cloud_options.region,
   917          response.id,
   918          self.google_cloud_options.project)
   919  
   920      return response
   921  
   922    @retry.with_exponential_backoff()  # Using retry defaults from utils/retry.py
   923    def modify_job_state(self, job_id, new_state):
   924      """Modify the run state of the job.
   925  
   926      Args:
   927        job_id: The id of the job.
   928        new_state: A string representing the new desired state. It could be set to
   929        either 'JOB_STATE_DONE', 'JOB_STATE_CANCELLED' or 'JOB_STATE_DRAINING'.
   930  
   931      Returns:
   932        True if the job was modified successfully.
   933      """
   934      if new_state == 'JOB_STATE_DONE':
   935        new_state = dataflow.Job.RequestedStateValueValuesEnum.JOB_STATE_DONE
   936      elif new_state == 'JOB_STATE_CANCELLED':
   937        new_state = dataflow.Job.RequestedStateValueValuesEnum.JOB_STATE_CANCELLED
   938      elif new_state == 'JOB_STATE_DRAINING':
   939        new_state = dataflow.Job.RequestedStateValueValuesEnum.JOB_STATE_DRAINING
   940      else:
   941        # Other states could only be set by the service.
   942        return False
   943  
   944      request = dataflow.DataflowProjectsLocationsJobsUpdateRequest()
   945      request.jobId = job_id
   946      request.projectId = self.google_cloud_options.project
   947      request.location = self.google_cloud_options.region
   948      request.job = dataflow.Job(requestedState=new_state)
   949  
   950      self._client.projects_locations_jobs.Update(request)
   951      return True
   952  
   953    @retry.with_exponential_backoff(
   954        retry_filter=retry.retry_on_server_errors_and_notfound_filter)
   955    def get_job(self, job_id):
   956      """Gets the job status for a submitted job.
   957  
   958      Args:
   959        job_id: A string representing the job_id for the workflow as returned
   960          by the create_job() request.
   961  
   962      Returns:
   963        A Job proto. See below for interesting fields.
   964  
   965      The Job proto returned from a get_job() request contains some interesting
   966      fields:
   967        currentState: An object representing the current state of the job. The
   968          string representation of the object (str() result) has the following
   969          possible values: JOB_STATE_UNKNOWN, JOB_STATE_STOPPED,
   970          JOB_STATE_RUNNING, JOB_STATE_DONE, JOB_STATE_FAILED,
   971          JOB_STATE_CANCELLED.
   972        createTime: UTC time when the job was created
   973          (e.g. '2015-03-10T00:01:53.074Z')
   974        currentStateTime: UTC time for the current state of the job.
   975      """
   976      request = dataflow.DataflowProjectsLocationsJobsGetRequest()
   977      request.jobId = job_id
   978      request.projectId = self.google_cloud_options.project
   979      request.location = self.google_cloud_options.region
   980      response = self._client.projects_locations_jobs.Get(request)
   981      return response
   982  
   983    @retry.with_exponential_backoff(
   984        retry_filter=retry.retry_on_server_errors_and_notfound_filter)
   985    def list_messages(
   986        self,
   987        job_id,
   988        start_time=None,
   989        end_time=None,
   990        page_token=None,
   991        minimum_importance=None):
   992      """List messages associated with the execution of a job.
   993  
   994      Args:
   995        job_id: A string representing the job_id for the workflow as returned
   996          by the create_job() request.
   997        start_time: If specified, only messages generated after the start time
   998          will be returned, otherwise all messages since job started will be
   999          returned. The value is a string representing UTC time
  1000          (e.g., '2015-08-18T21:03:50.644Z')
  1001        end_time: If specified, only messages generated before the end time
  1002          will be returned, otherwise all messages up to current time will be
  1003          returned. The value is a string representing UTC time
  1004          (e.g., '2015-08-18T21:03:50.644Z')
  1005        page_token: A string to be used as next page token if the list call
  1006          returned paginated results.
  1007        minimum_importance: Filter for messages based on importance. The possible
  1008          string values in increasing order of importance are: JOB_MESSAGE_DEBUG,
  1009          JOB_MESSAGE_DETAILED, JOB_MESSAGE_BASIC, JOB_MESSAGE_WARNING,
  1010          JOB_MESSAGE_ERROR. For example, a filter set on warning will allow only
  1011          warnings and errors and exclude all others.
  1012  
  1013      Returns:
  1014        A tuple consisting of a list of JobMessage instances and a
  1015        next page token string.
  1016  
  1017      Raises:
  1018        RuntimeError: if an unexpected value for the message_importance argument
  1019          is used.
  1020  
  1021      The JobMessage objects returned by the call contain the following fields:
  1022        id: A unique string identifier for the message.
  1023        time: A string representing the UTC time of the message
  1024          (e.g., '2015-08-18T21:03:50.644Z')
  1025        messageImportance: An enumeration value for the message importance. The
  1026        value, if converted to string, will have the following possible values:
  1027          JOB_MESSAGE_DEBUG, JOB_MESSAGE_DETAILED, JOB_MESSAGE_BASIC,
  1028          JOB_MESSAGE_WARNING, JOB_MESSAGE_ERROR.
  1029      messageText: A message string.
  1030      """
  1031      request = dataflow.DataflowProjectsLocationsJobsMessagesListRequest(
  1032          jobId=job_id,
  1033          location=self.google_cloud_options.region,
  1034          projectId=self.google_cloud_options.project)
  1035      if page_token is not None:
  1036        request.pageToken = page_token
  1037      if start_time is not None:
  1038        request.startTime = start_time
  1039      if end_time is not None:
  1040        request.endTime = end_time
  1041      if minimum_importance is not None:
  1042        if minimum_importance == 'JOB_MESSAGE_DEBUG':
  1043          request.minimumImportance = (
  1044              dataflow.DataflowProjectsLocationsJobsMessagesListRequest.
  1045              MinimumImportanceValueValuesEnum.JOB_MESSAGE_DEBUG)
  1046        elif minimum_importance == 'JOB_MESSAGE_DETAILED':
  1047          request.minimumImportance = (
  1048              dataflow.DataflowProjectsLocationsJobsMessagesListRequest.
  1049              MinimumImportanceValueValuesEnum.JOB_MESSAGE_DETAILED)
  1050        elif minimum_importance == 'JOB_MESSAGE_BASIC':
  1051          request.minimumImportance = (
  1052              dataflow.DataflowProjectsLocationsJobsMessagesListRequest.
  1053              MinimumImportanceValueValuesEnum.JOB_MESSAGE_BASIC)
  1054        elif minimum_importance == 'JOB_MESSAGE_WARNING':
  1055          request.minimumImportance = (
  1056              dataflow.DataflowProjectsLocationsJobsMessagesListRequest.
  1057              MinimumImportanceValueValuesEnum.JOB_MESSAGE_WARNING)
  1058        elif minimum_importance == 'JOB_MESSAGE_ERROR':
  1059          request.minimumImportance = (
  1060              dataflow.DataflowProjectsLocationsJobsMessagesListRequest.
  1061              MinimumImportanceValueValuesEnum.JOB_MESSAGE_ERROR)
  1062        else:
  1063          raise RuntimeError(
  1064              'Unexpected value for minimum_importance argument: %r' %
  1065              minimum_importance)
  1066      response = self._client.projects_locations_jobs_messages.List(request)
  1067      return response.jobMessages, response.nextPageToken
  1068  
  1069    def job_id_for_name(self, job_name):
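            # Page through the project's jobs and return the id of the RUNNING or
            # DRAINING job with the given name; raise ValueError if none is found.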
  1070      token = None
  1071      while True:
  1072        request = dataflow.DataflowProjectsLocationsJobsListRequest(
  1073            projectId=self.google_cloud_options.project,
  1074            location=self.google_cloud_options.region,
  1075            pageToken=token)
  1076        response = self._client.projects_locations_jobs.List(request)
  1077        for job in response.jobs:
  1078          if (job.name == job_name and job.currentState in [
  1079              dataflow.Job.CurrentStateValueValuesEnum.JOB_STATE_RUNNING,
  1080              dataflow.Job.CurrentStateValueValuesEnum.JOB_STATE_DRAINING
  1081          ]):
  1082            return job.id
  1083        token = response.nextPageToken
  1084        if token is None:
  1085          raise ValueError("No running job found with name '%s'" % job_name)
  1086  
  1087  
  1088  class MetricUpdateTranslators(object):
  1089    """Translators between accumulators and dataflow metric updates."""
  1090    @staticmethod
  1091    def translate_boolean(accumulator, metric_update_proto):
  1092      metric_update_proto.boolean = accumulator.value
  1093  
  1094    @staticmethod
  1095    def translate_scalar_mean_int(accumulator, metric_update_proto):
  1096      if accumulator.count:
  1097        metric_update_proto.integerMean = dataflow.IntegerMean()
  1098        metric_update_proto.integerMean.sum = to_split_int(accumulator.sum)
  1099        metric_update_proto.integerMean.count = to_split_int(accumulator.count)
  1100      else:
  1101        metric_update_proto.nameAndKind.kind = None
  1102  
  1103    @staticmethod
  1104    def translate_scalar_mean_float(accumulator, metric_update_proto):
  1105      if accumulator.count:
  1106        metric_update_proto.floatingPointMean = dataflow.FloatingPointMean()
  1107        metric_update_proto.floatingPointMean.sum = accumulator.sum
  1108        metric_update_proto.floatingPointMean.count = to_split_int(
  1109            accumulator.count)
  1110      else:
  1111        metric_update_proto.nameAndKind.kind = None
  1112  
  1113    @staticmethod
  1114    def translate_scalar_counter_int(accumulator, metric_update_proto):
  1115      metric_update_proto.integer = to_split_int(accumulator.value)
  1116  
  1117    @staticmethod
  1118    def translate_scalar_counter_float(accumulator, metric_update_proto):
  1119      metric_update_proto.floatingPoint = accumulator.value
  1120  
  1121  
  1122  class _LegacyDataflowStager(Stager):
  1123    def __init__(self, dataflow_application_client):
  1124      super().__init__()
  1125      self._dataflow_application_client = dataflow_application_client
  1126  
  1127    def stage_artifact(self, local_path_to_artifact, artifact_name, sha256):
  1128      self._dataflow_application_client._gcs_file_copy(
  1129          local_path_to_artifact, artifact_name, sha256)
  1130  
  1131    def commit_manifest(self):
  1132      pass
  1133  
  1134    @staticmethod
  1135    def get_sdk_package_name():
  1136      """For internal use only; no backwards-compatibility guarantees.
  1137  
  1138      Returns the PyPI package name to be staged to Google Cloud Dataflow.
  1139      """
  1140      return shared_names.BEAM_PACKAGE_NAME
  1141  
  1142  
  1143  class DataflowJobAlreadyExistsError(retry.PermanentException):
  1144    """A non-retryable error: a job with the given name already exists."""
  1145    # Inherits retry.PermanentException to avoid retry in
  1146    # DataflowApplicationClient.submit_job_description
  1147    pass
  1148  
  1149  
  1150  def to_split_int(n):
  1151    res = dataflow.SplitInt64()
  1152    res.lowBits = n & 0xffffffff
  1153    res.highBits = n >> 32
  1154    return res
  1155  
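        # Illustrative sketch: to_split_int packs an arbitrary-precision Python
        # int into the 32-bit low/high halves of the SplitInt64 message expected
        # by the Dataflow API.
        #
        #   >>> v = to_split_int((1 << 40) + 7)
        #   >>> (v.lowBits, v.highBits)
        #   (7, 256)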
  1156  
  1157  # TODO: Used in legacy batch worker. Move under MetricUpdateTranslators
  1158  # after Runner V2 transition.
  1159  def translate_distribution(distribution_update, metric_update_proto):
  1160    """Translate metrics DistributionUpdate to dataflow distribution update.
  1161  
  1162    Args:
  1163      distribution_update: Instance of DistributionData,
  1164        DistributionInt64Accumulator or DataflowDistributionCounter.
  1165      metric_update_proto: The metric update proto to populate.
  1166    """
  1167    dist_update_proto = dataflow.DistributionUpdate()
  1168    dist_update_proto.min = to_split_int(distribution_update.min)
  1169    dist_update_proto.max = to_split_int(distribution_update.max)
  1170    dist_update_proto.count = to_split_int(distribution_update.count)
  1171    dist_update_proto.sum = to_split_int(distribution_update.sum)
  1172    # DataflowDistributionCounter needs to translate histogram
  1173    if isinstance(distribution_update, DataflowDistributionCounter):
  1174      dist_update_proto.histogram = dataflow.Histogram()
  1175      distribution_update.translate_to_histogram(dist_update_proto.histogram)
  1176    metric_update_proto.distribution = dist_update_proto
  1177  
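        # Illustrative sketch (_FakeDistribution is a hypothetical stand-in for
        # DistributionData, and CounterUpdate is assumed to be the proto being
        # populated): any object exposing min/max/count/sum can be translated;
        # a DataflowDistributionCounter additionally gets its per-bucket
        # histogram copied into the proto.
        #
        #   >>> class _FakeDistribution(object):
        #   ...   min, max, count, sum = 1, 9, 3, 15
        #   >>> update = dataflow.CounterUpdate()
        #   >>> translate_distribution(_FakeDistribution(), update)
        #   >>> update.distribution.sum.lowBits
        #   15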
  1178  
  1179  # TODO: Used in legacy batch worker. Delete after Runner V2 transition.
  1180  def translate_value(value, metric_update_proto):
  1181    metric_update_proto.integer = to_split_int(value)
  1182  
  1183  
  1184  def _get_container_image_tag():
  1185    base_version = pkg_resources.parse_version(
  1186        beam_version.__version__).base_version
  1187    if base_version != beam_version.__version__:
  1188      warnings.warn(
  1189          "A non-standard version of the Beam SDK was detected: %s. "
  1190          "The Dataflow runner will use container image tag %s. "
  1191          "This use case is not supported." %
  1192          (beam_version.__version__, base_version))
  1193    return base_version
  1194  
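        # Illustrative sketch (version strings are hypothetical): for a released
        # SDK the base version equals the full version, so no warning is issued;
        # pre-releases and similar fall back to their base version.
        #
        #   >>> pkg_resources.parse_version('2.48.0rc1').base_version
        #   '2.48.0'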
  1195  
  1196  def get_container_image_from_options(pipeline_options):
  1197    """For internal use only; no backwards-compatibility guarantees.
  1198  
  1199    Args:
  1200      pipeline_options (PipelineOptions): A container for pipeline options.
  1201
  1202    Returns:
  1203      str: Container image for remote execution.
  1204    """
  1205    from apache_beam.runners.dataflow.dataflow_runner import _is_runner_v2
  1206    worker_options = pipeline_options.view_as(WorkerOptions)
  1207    if worker_options.sdk_container_image:
  1208      return worker_options.sdk_container_image
  1209  
  1210    is_runner_v2 = _is_runner_v2(pipeline_options)
  1211  
  1212    # Legacy and runner v2 images live in different repositories.
  1213    # Default to the legacy format; override below for runner v2.
  1214    container_repo = names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY
  1215    image_name = '{repository}/python{major}{minor}'.format(
  1216        repository=container_repo,
  1217        major=sys.version_info[0],
  1218        minor=sys.version_info[1])
  1219  
  1220    if is_runner_v2:
  1221      image_name = '{repository}/beam_python{major}.{minor}_sdk'.format(
  1222          repository=container_repo,
  1223          major=sys.version_info[0],
  1224          minor=sys.version_info[1])
  1225  
  1226    image_tag = _get_required_container_version(is_runner_v2)
  1227    return image_name + ':' + image_tag
  1228  
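        # Illustrative sketch (image names are schematic, not pinned to a real
        # registry path): with no --sdk_container_image override, a runner v2
        # pipeline on Python 3.8 resolves to an image of the form
        #
        #   <DATAFLOW_CONTAINER_IMAGE_REPOSITORY>/beam_python3.8_sdk:<tag>
        #
        # while a legacy pipeline resolves to
        #
        #   <DATAFLOW_CONTAINER_IMAGE_REPOSITORY>/python38:<tag>
        #
        # where <tag> comes from _get_required_container_version below.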
  1229  
  1230  def _get_required_container_version(is_runner_v2):
  1231    """For internal use only; no backwards-compatibility guarantees.
  1232  
  1233    Args:
  1234      is_runner_v2 (bool): True if and only if the pipeline is using runner v2.
  1235
  1236    Returns:
  1237      str: The tag of worker container images in GCR that corresponds to the
  1238        current version of the SDK.
  1239    """
  1240    if 'dev' in beam_version.__version__:
  1241      if is_runner_v2:
  1242        return names.BEAM_FNAPI_CONTAINER_VERSION
  1243      else:
  1244        return names.BEAM_CONTAINER_VERSION
  1245    else:
  1246      return _get_container_image_tag()
  1247  
  1248  
  1249  def get_response_encoding():
  1250    """Encoding used to decode HTTP responses from Google APIs."""
  1251    return 'utf8'
  1252  
  1253  
  1254  def _verify_interpreter_version_is_supported(pipeline_options):
  1255    if ('%s.%s' %
  1256        (sys.version_info[0],
  1257         sys.version_info[1]) in _PYTHON_VERSIONS_SUPPORTED_BY_DATAFLOW):
  1258      return
  1259  
  1260    if 'dev' in beam_version.__version__:
  1261      return
  1262  
  1263    debug_options = pipeline_options.view_as(DebugOptions)
  1264    if (debug_options.experiments and
  1265        'use_unsupported_python_version' in debug_options.experiments):
  1266      return
  1267  
  1268    raise Exception(
  1269        'The Dataflow runner currently supports Python versions %s; got %s.\n'
  1270        'To ignore this requirement and start a job '
  1271        'using an unsupported version of the Python interpreter, pass the '
  1272        '--experiment use_unsupported_python_version pipeline option.' %
  1273        (_PYTHON_VERSIONS_SUPPORTED_BY_DATAFLOW, sys.version))
  1274  
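        # Illustrative sketch (only the experiment flag is shown; PipelineOptions
        # comes from apache_beam.options.pipeline_options): the escape hatch
        # referenced in the error message above is an ordinary experiment flag.
        #
        #   >>> options = PipelineOptions(
        #   ...     ['--experiment=use_unsupported_python_version'])
        #   >>> _verify_interpreter_version_is_supported(options)  # returns None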
  1275  
  1276  # To enable a counter on the service, add it to this dictionary.
  1277  # This is required for the legacy python dataflow runner, as portability
  1278  # does not communicate to the service via python code, but instead via
  1279  # a runner harness (in C++ or Java).
  1280  # TODO(https://github.com/apache/beam/issues/19433): Remove this antipattern;
  1281  # until then, legacy dataflow python pipelines will break whenever a new
  1282  # cy_combiner type is used.
  1283  structured_counter_translations = {
  1284      cy_combiners.CountCombineFn: (
  1285          dataflow.CounterMetadata.KindValueValuesEnum.SUM,
  1286          MetricUpdateTranslators.translate_scalar_counter_int),
  1287      cy_combiners.SumInt64Fn: (
  1288          dataflow.CounterMetadata.KindValueValuesEnum.SUM,
  1289          MetricUpdateTranslators.translate_scalar_counter_int),
  1290      cy_combiners.MinInt64Fn: (
  1291          dataflow.CounterMetadata.KindValueValuesEnum.MIN,
  1292          MetricUpdateTranslators.translate_scalar_counter_int),
  1293      cy_combiners.MaxInt64Fn: (
  1294          dataflow.CounterMetadata.KindValueValuesEnum.MAX,
  1295          MetricUpdateTranslators.translate_scalar_counter_int),
  1296      cy_combiners.MeanInt64Fn: (
  1297          dataflow.CounterMetadata.KindValueValuesEnum.MEAN,
  1298          MetricUpdateTranslators.translate_scalar_mean_int),
  1299      cy_combiners.SumFloatFn: (
  1300          dataflow.CounterMetadata.KindValueValuesEnum.SUM,
  1301          MetricUpdateTranslators.translate_scalar_counter_float),
  1302      cy_combiners.MinFloatFn: (
  1303          dataflow.CounterMetadata.KindValueValuesEnum.MIN,
  1304          MetricUpdateTranslators.translate_scalar_counter_float),
  1305      cy_combiners.MaxFloatFn: (
  1306          dataflow.CounterMetadata.KindValueValuesEnum.MAX,
  1307          MetricUpdateTranslators.translate_scalar_counter_float),
  1308      cy_combiners.MeanFloatFn: (
  1309          dataflow.CounterMetadata.KindValueValuesEnum.MEAN,
  1310          MetricUpdateTranslators.translate_scalar_mean_float),
  1311      cy_combiners.AllCombineFn: (
  1312          dataflow.CounterMetadata.KindValueValuesEnum.AND,
  1313          MetricUpdateTranslators.translate_boolean),
  1314      cy_combiners.AnyCombineFn: (
  1315          dataflow.CounterMetadata.KindValueValuesEnum.OR,
  1316          MetricUpdateTranslators.translate_boolean),
  1317      cy_combiners.DataflowDistributionCounterFn: (
  1318          dataflow.CounterMetadata.KindValueValuesEnum.DISTRIBUTION,
  1319          translate_distribution),
  1320      cy_combiners.DistributionInt64Fn: (
  1321          dataflow.CounterMetadata.KindValueValuesEnum.DISTRIBUTION,
  1322          translate_distribution),
  1323  }
  1324  
  1325  counter_translations = {
  1326      cy_combiners.CountCombineFn: (
  1327          dataflow.NameAndKind.KindValueValuesEnum.SUM,
  1328          MetricUpdateTranslators.translate_scalar_counter_int),
  1329      cy_combiners.SumInt64Fn: (
  1330          dataflow.NameAndKind.KindValueValuesEnum.SUM,
  1331          MetricUpdateTranslators.translate_scalar_counter_int),
  1332      cy_combiners.MinInt64Fn: (
  1333          dataflow.NameAndKind.KindValueValuesEnum.MIN,
  1334          MetricUpdateTranslators.translate_scalar_counter_int),
  1335      cy_combiners.MaxInt64Fn: (
  1336          dataflow.NameAndKind.KindValueValuesEnum.MAX,
  1337          MetricUpdateTranslators.translate_scalar_counter_int),
  1338      cy_combiners.MeanInt64Fn: (
  1339          dataflow.NameAndKind.KindValueValuesEnum.MEAN,
  1340          MetricUpdateTranslators.translate_scalar_mean_int),
  1341      cy_combiners.SumFloatFn: (
  1342          dataflow.NameAndKind.KindValueValuesEnum.SUM,
  1343          MetricUpdateTranslators.translate_scalar_counter_float),
  1344      cy_combiners.MinFloatFn: (
  1345          dataflow.NameAndKind.KindValueValuesEnum.MIN,
  1346          MetricUpdateTranslators.translate_scalar_counter_float),
  1347      cy_combiners.MaxFloatFn: (
  1348          dataflow.NameAndKind.KindValueValuesEnum.MAX,
  1349          MetricUpdateTranslators.translate_scalar_counter_float),
  1350      cy_combiners.MeanFloatFn: (
  1351          dataflow.NameAndKind.KindValueValuesEnum.MEAN,
  1352          MetricUpdateTranslators.translate_scalar_mean_float),
  1353      cy_combiners.AllCombineFn: (
  1354          dataflow.NameAndKind.KindValueValuesEnum.AND,
  1355          MetricUpdateTranslators.translate_boolean),
  1356      cy_combiners.AnyCombineFn: (
  1357          dataflow.NameAndKind.KindValueValuesEnum.OR,
  1358          MetricUpdateTranslators.translate_boolean),
  1359      cy_combiners.DataflowDistributionCounterFn: (
  1360          dataflow.NameAndKind.KindValueValuesEnum.DISTRIBUTION,
  1361          translate_distribution),
  1362      cy_combiners.DistributionInt64Fn: (
  1363          dataflow.NameAndKind.KindValueValuesEnum.DISTRIBUTION,
  1364          translate_distribution),
  1365  }
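
        # Illustrative sketch (the SumInt64Fn lookup is just an example): both
        # tables map an accumulator class to a (counter kind, translator) pair;
        # the legacy worker looks up the pair and applies the translator to the
        # counter update it is building.
        #
        #   >>> kind, translator = counter_translations[cy_combiners.SumInt64Fn]
        #   >>> kind == dataflow.NameAndKind.KindValueValuesEnum.SUM
        #   True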