github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dataflow/dataflow_runner.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """A runner implementation that submits a job for remote execution.
    19  
    20  The runner will create a JSON description of the job graph and then submit it
    21  to the Dataflow Service for remote execution by a worker.
    22  """
    23  # pytype: skip-file
    24  
    25  import base64
    26  import logging
    27  import os
    28  import threading
    29  import time
    30  import traceback
    31  import warnings
    32  from collections import defaultdict
    33  from subprocess import DEVNULL
    34  from typing import TYPE_CHECKING
    35  from typing import List
    36  from urllib.parse import quote
    37  from urllib.parse import quote_from_bytes
    38  from urllib.parse import unquote_to_bytes
    39  
    40  import apache_beam as beam
    41  from apache_beam import coders
    42  from apache_beam import error
    43  from apache_beam.internal import pickler
    44  from apache_beam.internal.gcp import json_value
    45  from apache_beam.options.pipeline_options import DebugOptions
    46  from apache_beam.options.pipeline_options import GoogleCloudOptions
    47  from apache_beam.options.pipeline_options import SetupOptions
    48  from apache_beam.options.pipeline_options import StandardOptions
    49  from apache_beam.options.pipeline_options import TestOptions
    50  from apache_beam.options.pipeline_options import TypeOptions
    51  from apache_beam.options.pipeline_options import WorkerOptions
    52  from apache_beam.portability import common_urns
    53  from apache_beam.portability.api import beam_runner_api_pb2
    54  from apache_beam.pvalue import AsSideInput
    55  from apache_beam.runners.common import DoFnSignature
    56  from apache_beam.runners.common import group_by_key_input_visitor
    57  from apache_beam.runners.dataflow.internal import names
    58  from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
    59  from apache_beam.runners.dataflow.internal.names import PropertyNames
    60  from apache_beam.runners.dataflow.internal.names import TransformNames
    61  from apache_beam.runners.runner import PipelineResult
    62  from apache_beam.runners.runner import PipelineRunner
    63  from apache_beam.runners.runner import PipelineState
    64  from apache_beam.runners.runner import PValueCache
    65  from apache_beam.transforms import window
    66  from apache_beam.transforms.display import DisplayData
    67  from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX
    68  from apache_beam.typehints import typehints
    69  from apache_beam.utils import processes
    70  from apache_beam.utils import proto_utils
    71  from apache_beam.utils.interactive_utils import is_in_notebook
    72  from apache_beam.utils.plugin import BeamPlugin
    73  
    74  if TYPE_CHECKING:
    75    from apache_beam.pipeline import PTransformOverride
    76  
    77  __all__ = ['DataflowRunner']
    78  
    79  _LOGGER = logging.getLogger(__name__)
    80  
    81  BQ_SOURCE_UW_ERROR = (
    82      'The Read(BigQuerySource(...)) transform is not supported with newer stack '
    83      'features (Fn API, Dataflow Runner V2, etc). Please use the transform '
    84      'apache_beam.io.gcp.bigquery.ReadFromBigQuery instead.')
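        # A minimal replacement sketch for the error above (the table name below
        # is hypothetical):
        #
        #   rows = pipeline | 'ReadTable' >> beam.io.ReadFromBigQuery(
        #       table='my-project:my_dataset.my_table')
        #
        # rather than beam.io.Read(beam.io.BigQuerySource(...)).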
    85  
    86  
    87  class DataflowRunner(PipelineRunner):
    88    """A runner that creates job graphs and submits them for remote execution.
    89  
    90    Every execution of the run() method will submit an independent job for
    91    remote execution that consists of the nodes reachable from the passed-in
    92    node argument, or the entire graph if the node is None. The run() method
    93    returns after the service has created the job; it will not wait for the
    94    job to finish if blocking is set to False.
    95    """
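          # A minimal usage sketch; the project, region, and bucket names are
          # hypothetical:
          #
          #   from apache_beam.options.pipeline_options import PipelineOptions
          #   options = PipelineOptions([
          #       '--runner=DataflowRunner',
          #       '--project=my-project',
          #       '--region=us-central1',
          #       '--temp_location=gs://my-bucket/tmp',
          #   ])
          #   with beam.Pipeline(options=options) as p:
          #     p | beam.Create([1, 2, 3]) | beam.Map(print)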
    96  
    97    # A list of PTransformOverride objects to be applied before running a pipeline
    98    # using DataflowRunner.
    99    # Currently this only works for overrides where the input and output types do
   100    # not change.
   101    # For internal SDK use only. This should not be updated by Beam pipeline
   102    # authors.
   103  
   104    # Imported here to avoid circular dependencies.
   105    # TODO: Remove the apache_beam.pipeline dependency in CreatePTransformOverride
   106    from apache_beam.runners.dataflow.ptransform_overrides import CombineValuesPTransformOverride
   107    from apache_beam.runners.dataflow.ptransform_overrides import CreatePTransformOverride
   108    from apache_beam.runners.dataflow.ptransform_overrides import ReadPTransformOverride
   109    from apache_beam.runners.dataflow.ptransform_overrides import NativeReadPTransformOverride
   110  
   111    # These overrides should be applied before the proto representation of the
   112    # graph is created.
   113    _PTRANSFORM_OVERRIDES = [
   114        NativeReadPTransformOverride(),
   115    ]  # type: List[PTransformOverride]
   116  
   117    # These overrides should be applied after the proto representation of the
   118    # graph is created.
   119    _NON_PORTABLE_PTRANSFORM_OVERRIDES = [
   120        CombineValuesPTransformOverride(),
   121        CreatePTransformOverride(),
   122        ReadPTransformOverride(),
   123    ]  # type: List[PTransformOverride]
   124  
   125    def __init__(self, cache=None):
   126      # Cache of CloudWorkflowStep protos generated while the runner
   127      # "executes" a pipeline.
   128      self._cache = cache if cache is not None else PValueCache()
   129      self._unique_step_id = 0
   130      self._default_environment = None
   131  
   132    def is_fnapi_compatible(self):
   133      return False
   134  
   135    def apply(self, transform, input, options):
   136      _check_and_add_missing_options(options)
   137      return super().apply(transform, input, options)
   138  
   139    def _get_unique_step_name(self):
   140      self._unique_step_id += 1
   141      return 's%s' % self._unique_step_id
   142  
   143    @staticmethod
   144    def poll_for_job_completion(
   145        runner, result, duration, state_update_callback=None):
   146      """Polls for the specified job to finish running (successfully or not).
   147  
   148      Updates the result with the new job information before returning.
   149  
   150      Args:
   151        runner: DataflowRunner instance to use for polling job state.
   152        result: DataflowPipelineResult instance used for job information.
   153        duration (int): The time to wait (in milliseconds) for the job to
   154          finish. If it is set to :data:`None`, it will wait indefinitely
   155          until the job is finished.
   156      """
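            # For example (a sketch; the values are hypothetical), wait at most
            # ten minutes for the job while logging state changes:
            #
            #   DataflowRunner.poll_for_job_completion(
            #       runner, result, duration=10 * 60 * 1000,
            #       state_update_callback=_LOGGER.info)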
   157      if result.state == PipelineState.DONE:
   158        return
   159  
   160      last_message_time = None
   161      current_seen_messages = set()
   162  
   163      last_error_rank = float('-inf')
   164      last_error_msg = None
   165      last_job_state = None
   166      # How long to wait after a pipeline failure for the error
   167      # message explaining the failure to show up.
   168      # It typically takes about 30 seconds.
   169      final_countdown_timer_secs = 50.0
   170      sleep_secs = 5.0
   171  
   172      # Try to prioritize the user-level traceback, if any.
   173      def rank_error(msg):
   174        if 'work item was attempted' in msg:
   175          return -1
   176        elif 'Traceback' in msg:
   177          return 1
   178        return 0
   179  
   180      if duration:
   181        start_secs = time.time()
   182        duration_secs = duration // 1000
   183  
   184      job_id = result.job_id()
   185      while True:
   186        response = runner.dataflow_client.get_job(job_id)
   187        # If get() is called very soon after Create(), the response may not
   188        # contain an initialized 'currentState' field.
   189        if response.currentState is not None:
   190          if response.currentState != last_job_state:
   191            if state_update_callback:
   192              state_update_callback(response.currentState)
   193            _LOGGER.info('Job %s is in state %s', job_id, response.currentState)
   194            last_job_state = response.currentState
   195          if str(response.currentState) != 'JOB_STATE_RUNNING':
   196            # Stop checking for new messages on timeout, when an explanatory
   197            # message has been received, on success, or on a terminal job state
   198            # caused by the user that therefore doesn't require explanation.
   199            if (final_countdown_timer_secs <= 0.0 or last_error_msg is not None or
   200                str(response.currentState) == 'JOB_STATE_DONE' or
   201                str(response.currentState) == 'JOB_STATE_CANCELLED' or
   202                str(response.currentState) == 'JOB_STATE_UPDATED' or
   203                str(response.currentState) == 'JOB_STATE_DRAINED'):
   204              break
   205  
   206            # Check that job is in a post-preparation state before starting the
   207            # final countdown.
   208            if (str(response.currentState) not in ('JOB_STATE_PENDING',
   209                                                   'JOB_STATE_QUEUED')):
   210              # The job has failed; ensure we see any final error messages.
   211              sleep_secs = 1.0  # poll faster during the final countdown
   212              final_countdown_timer_secs -= sleep_secs
   213  
   214        time.sleep(sleep_secs)
   215  
   216        # Get all messages since beginning of the job run or since last message.
   217        page_token = None
   218        while True:
   219          messages, page_token = runner.dataflow_client.list_messages(
   220              job_id, page_token=page_token, start_time=last_message_time)
   221          for m in messages:
   222            message = '%s: %s: %s' % (m.time, m.messageImportance, m.messageText)
   223  
   224            if not last_message_time or m.time > last_message_time:
   225              last_message_time = m.time
   226              current_seen_messages = set()
   227  
   228            if message in current_seen_messages:
   229              # Skip the message if it has already been seen at the current
   230              # time. This could be the case since the list_messages API is
   231              # queried starting at last_message_time.
   232              continue
   233            else:
   234              current_seen_messages.add(message)
   235            # Skip empty messages.
   236            if m.messageImportance is None:
   237              continue
   238            _LOGGER.info(message)
   239            if str(m.messageImportance) == 'JOB_MESSAGE_ERROR':
   240              if rank_error(m.messageText) >= last_error_rank:
   241                last_error_rank = rank_error(m.messageText)
   242                last_error_msg = m.messageText
   243          if not page_token:
   244            break
   245  
   246        if duration:
   247          passed_secs = time.time() - start_secs
   248          if passed_secs > duration_secs:
   249            _LOGGER.warning(
   250                'Timing out on waiting for job %s after %d seconds',
   251                job_id,
   252                passed_secs)
   253            break
   254  
   255      result._job = response
   256      runner.last_error_msg = last_error_msg
   257  
   258    @staticmethod
   259    def _only_element(iterable):
   260      # type: (Iterable[T]) -> T # noqa: F821
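            # For example, _only_element([42]) returns 42; any other number of
            # elements makes the unpacking below raise a ValueError.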
   261      element, = iterable
   262      return element
   263  
   264    @staticmethod
   265    def side_input_visitor(is_runner_v2=False, deterministic_key_coders=True):
   266      # Imported here to avoid circular dependencies.
   267      # pylint: disable=wrong-import-order, wrong-import-position
   268      from apache_beam.pipeline import PipelineVisitor
   269      from apache_beam.transforms.core import ParDo
   270  
   271      class SideInputVisitor(PipelineVisitor):
   272        """Ensures an input `PCollection` used as a side input has a `KV` type.
   273  
   274        TODO(BEAM-115): Once Python SDK is compatible with the new Runner API,
   275        we could directly replace the coder instead of mutating the element type.
   276        """
   277        def visit_transform(self, transform_node):
   278          if isinstance(transform_node.transform, ParDo):
   279            new_side_inputs = []
   280            for side_input in transform_node.side_inputs:
   281              access_pattern = side_input._side_input_data().access_pattern
   282              if access_pattern == common_urns.side_inputs.ITERABLE.urn:
   283                # TODO(https://github.com/apache/beam/issues/20043): Stop
   284                # patching up the access pattern to appease Dataflow when
   285                # using the UW and hardcode the output type to be Any since
   286                # the Dataflow JSON and pipeline proto can differ in coders
   287                # which leads to encoding/decoding issues within the runner.
   288                side_input.pvalue.element_type = typehints.Any
   289                new_side_input = _DataflowIterableSideInput(side_input)
   290              elif access_pattern == common_urns.side_inputs.MULTIMAP.urn:
   291                # Ensure the input coder is a KV coder and patch up the
   292                # access pattern to appease Dataflow.
   293                side_input.pvalue.element_type = typehints.coerce_to_kv_type(
   294                    side_input.pvalue.element_type, transform_node.full_label)
   295                side_input.pvalue.requires_deterministic_key_coder = (
   296                    deterministic_key_coders and transform_node.full_label)
   297                new_side_input = _DataflowMultimapSideInput(side_input)
   298              else:
   299                raise ValueError(
   300                    'Unsupported access pattern for %r: %r' %
   301                    (transform_node.full_label, access_pattern))
   302              new_side_inputs.append(new_side_input)
   303            if is_runner_v2:
   304              transform_node.side_inputs = new_side_inputs
   305              transform_node.transform.side_inputs = new_side_inputs
   306  
   307      return SideInputVisitor()
   308  
   309    @staticmethod
   310    def flatten_input_visitor():
   311      # Imported here to avoid circular dependencies.
   312      from apache_beam.pipeline import PipelineVisitor
   313  
   314      class FlattenInputVisitor(PipelineVisitor):
   315        """A visitor that replaces the element type for input ``PCollection``s
   316        of a ``Flatten`` transform with that of the output ``PCollection``.
   317        """
   318        def visit_transform(self, transform_node):
   319          # Imported here to avoid circular dependencies.
   320          # pylint: disable=wrong-import-order, wrong-import-position
   321          from apache_beam import Flatten
   322          if isinstance(transform_node.transform, Flatten):
   323            output_pcoll = DataflowRunner._only_element(
   324                transform_node.outputs.values())
   325            for input_pcoll in transform_node.inputs:
   326              input_pcoll.element_type = output_pcoll.element_type
   327  
   328      return FlattenInputVisitor()
   329  
   330    @staticmethod
   331    def combinefn_visitor():
   332      # Imported here to avoid circular dependencies.
   333      from apache_beam.pipeline import PipelineVisitor
   334      from apache_beam import core
   335  
   336      class CombineFnVisitor(PipelineVisitor):
   337        """Checks if `CombineFn` has non-default setup or teardown methods.
   338        If yes, raises `ValueError`.
   339        """
   340        def visit_transform(self, applied_transform):
   341          transform = applied_transform.transform
   342          if isinstance(transform, core.ParDo) and isinstance(
   343              transform.fn, core.CombineValuesDoFn):
   344            if self._overrides_setup_or_teardown(transform.fn.combinefn):
   345              raise ValueError(
   346                  'CombineFn.setup and CombineFn.teardown are '
   347                  'not supported with non-portable Dataflow '
   348                  'runner. Please use Dataflow Runner V2 instead.')
   349  
   350        @staticmethod
   351        def _overrides_setup_or_teardown(combinefn):
   352          # TODO(https://github.com/apache/beam/issues/18716): provide an
   353          # implementation for this method
   354          return False
   355  
   356      return CombineFnVisitor()
   357  
   358    def _adjust_pipeline_for_dataflow_v2(self, pipeline):
   359      # Dataflow runner requires a KV type for GBK inputs, hence we enforce that
   360      # here.
   361      pipeline.visit(
   362          group_by_key_input_visitor(
   363              not pipeline._options.view_as(
   364                  TypeOptions).allow_non_deterministic_key_coders))
   365  
   366    def _check_for_unsupported_features_on_non_portable_worker(self, pipeline):
   367      pipeline.visit(self.combinefn_visitor())
   368  
   369    def run_pipeline(self, pipeline, options, pipeline_proto=None):
   370      """Remotely executes the entire pipeline or parts reachable from node."""
   371      if _is_runner_v2_disabled(options):
   372        debug_options = options.view_as(DebugOptions)
   373        if not debug_options.lookup_experiment('disable_runner_v2_until_v2.50'):
   374          raise ValueError(
   375              'disable_runner_v2 is deprecated in Beam Python ' +
   376              beam.version.__version__ +
   377              ' and this execution mode will be removed in a future Beam SDK. '
   378              'If needed, please use: '
   379              '"--experiments=disable_runner_v2_until_v2.50".')
   380  
   381      # Label goog-dataflow-notebook if the job is started from a notebook.
   382      if is_in_notebook():
   383        notebook_version = (
   384            'goog-dataflow-notebook=' +
   385            beam.version.__version__.replace('.', '_'))
   386        if options.view_as(GoogleCloudOptions).labels:
   387          options.view_as(GoogleCloudOptions).labels.append(notebook_version)
   388        else:
   389          options.view_as(GoogleCloudOptions).labels = [notebook_version]
   390  
   391      # Import here to avoid adding the dependency for local running scenarios.
   392      try:
   393        # pylint: disable=wrong-import-order, wrong-import-position
   394        from apache_beam.runners.dataflow.internal import apiclient
   395      except ImportError:
   396        raise ImportError(
   397            'Google Cloud Dataflow runner not available, '
   398            'please install apache_beam[gcp]')
   399  
   400      if pipeline_proto or pipeline.contains_external_transforms:
   401        if _is_runner_v2_disabled(options):
   402          raise ValueError(
   403              'This pipeline contains cross language transforms, '
   404              'which requires Runner V2.')
   405        if not _is_runner_v2(options):
   406          _LOGGER.info(
   407              'Automatically enabling Dataflow Runner V2 since the '
   408              'pipeline used cross-language transforms.')
   409          _add_runner_v2_missing_options(options)
   410  
   411      is_runner_v2 = _is_runner_v2(options)
   412      if not is_runner_v2:
   413        self._check_for_unsupported_features_on_non_portable_worker(pipeline)
   414  
   415      # Convert all side inputs into a form acceptable to Dataflow.
   416      if pipeline:
   417        pipeline.visit(
   418            self.side_input_visitor(
   419                _is_runner_v2(options),
   420                deterministic_key_coders=not options.view_as(
   421                    TypeOptions).allow_non_deterministic_key_coders))
   422  
   423        # Perform configured PTransform overrides. Note that this is currently
   424        # done before Runner API serialization, since the new proto needs to
   425        # contain any added PTransforms.
   426        pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES)
   427  
   428        if options.view_as(DebugOptions).lookup_experiment('use_legacy_bq_sink'):
   429          warnings.warn(
   430              "Native sinks no longer implemented; "
   431              "ignoring use_legacy_bq_sink.")
   432  
   433        from apache_beam.runners.dataflow.ptransform_overrides import GroupIntoBatchesWithShardedKeyPTransformOverride
   434        pipeline.replace_all(
   435            [GroupIntoBatchesWithShardedKeyPTransformOverride(self, options)])
   436  
   437      if pipeline_proto:
   438        self.proto_pipeline = pipeline_proto
   439  
   440      else:
   441        from apache_beam.transforms import environments
   442        if options.view_as(SetupOptions).prebuild_sdk_container_engine:
   443          # If prebuild_sdk_container_engine is specified, we will build a new
   444          # SDK container image with dependencies pre-installed and use that
   445          # image instead of the inferred default container image.
   446          self._default_environment = (
   447              environments.DockerEnvironment.from_options(options))
   448          options.view_as(WorkerOptions).sdk_container_image = (
   449              self._default_environment.container_image)
   450        else:
   451          artifacts = environments.python_sdk_dependencies(options)
   452          if artifacts and _is_runner_v2(options):
   453            _LOGGER.info(
   454                "Pipeline has additional dependencies to be installed "
   455                "in SDK worker container, consider using the SDK "
   456                "container image pre-building workflow to avoid "
   457                "repetitive installations. Learn more on "
   458                "https://cloud.google.com/dataflow/docs/guides/"
   459                "using-custom-containers#prebuild")
   460          self._default_environment = (
   461              environments.DockerEnvironment.from_container_image(
   462                  apiclient.get_container_image_from_options(options),
   463                  artifacts=artifacts,
   464                  resource_hints=environments.resource_hints_from_options(
   465                      options)))
   466  
   467        # This must happen before the pipeline proto is constructed, to make
   468        # sure the changes are reflected in the portable job submission path.
   469        self._adjust_pipeline_for_dataflow_v2(pipeline)
   470  
   471        # Snapshot the pipeline in a portable proto.
   472        self.proto_pipeline, self.proto_context = pipeline.to_runner_api(
   473            return_context=True, default_environment=self._default_environment)
   474  
   475      # Optimize the pipeline if it is not streaming and the pre_optimize
   476      # experiment is set.
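            # For reference, a sketch of how the experiment is typically supplied
            # on the command line (the phase name is illustrative):
            #
            #   --experiments=pre_optimize=pack_combiners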
   477      if not options.view_as(StandardOptions).streaming:
   478        pre_optimize = options.view_as(DebugOptions).lookup_experiment(
   479            'pre_optimize', 'default').lower()
   480        from apache_beam.runners.portability.fn_api_runner import translations
   481        if pre_optimize == 'none':
   482          phases = []
   483        elif pre_optimize == 'default' or pre_optimize == 'all':
   484          phases = [translations.pack_combiners, translations.sort_stages]
   485        else:
   486          phases = []
   487          for phase_name in pre_optimize.split(','):
   488            # For now, these are all we allow.
   489            if phase_name in ('pack_combiners', ):
   490              phases.append(getattr(translations, phase_name))
   491            else:
   492              raise ValueError(
   493                  'Unknown or inapplicable phase for pre_optimize: %s' %
   494                  phase_name)
   495          phases.append(translations.sort_stages)
   496  
   497        if phases:
   498          self.proto_pipeline = translations.optimize_pipeline(
   499              self.proto_pipeline,
   500              phases=phases,
   501              known_runner_urns=frozenset(),
   502              partial=True)
   503  
   504      if not is_runner_v2:
   505        # Perform configured PTransform overrides which should not be reflected
   506        # in the proto representation of the graph.
   507        pipeline.replace_all(DataflowRunner._NON_PORTABLE_PTRANSFORM_OVERRIDES)
   508  
   509      # Add setup_options for all the BeamPlugin imports
   510      setup_options = options.view_as(SetupOptions)
   511      plugins = BeamPlugin.get_all_plugin_paths()
   512      if setup_options.beam_plugins is not None:
   513        plugins = list(set(plugins + setup_options.beam_plugins))
   514      setup_options.beam_plugins = plugins
   515  
   516      # Elevate "min_cpu_platform" to a pipeline option, but use the existing
   517      # experiment.
   518      debug_options = options.view_as(DebugOptions)
   519      worker_options = options.view_as(WorkerOptions)
   520      if worker_options.min_cpu_platform:
   521        debug_options.add_experiment(
   522            'min_cpu_platform=' + worker_options.min_cpu_platform)
   523  
   524      self.job = apiclient.Job(options, self.proto_pipeline)
   525  
   526      # TODO: Consider skipping these for all use_portable_job_submission jobs.
   527      if pipeline:
   528        # Dataflow Runner v1 requires the output type of a Flatten to match its
   529        # inputs, hence we enforce that here. Dataflow Runner v2 does not require
   530        # this.
   531        pipeline.visit(self.flatten_input_visitor())
   532  
   533        # Trigger a traversal of all reachable nodes.
   534        self.visit_transforms(pipeline, options)
   535  
   536      test_options = options.view_as(TestOptions)
   537      # If it is a dry run, return without submitting the job.
   538      if test_options.dry_run:
   539        result = PipelineResult(PipelineState.DONE)
   540        result.wait_until_finish = lambda duration=None: None
   541        return result
   542  
   543      # Get a Dataflow API client and set its options
   544      self.dataflow_client = apiclient.DataflowApplicationClient(
   545          options, self.job.root_staging_location)
   546  
   547      # Create the job description and send a request to the service. The result
   548      # can be None if there is no need to send a request to the service (e.g.
   549      # template creation). If a request was sent and failed then the call will
   550      # raise an exception.
   551      result = DataflowPipelineResult(
   552          self.dataflow_client.create_job(self.job), self)
   553  
   554      # TODO(BEAM-4274): Circular import runners-metrics. Requires refactoring.
   555      from apache_beam.runners.dataflow.dataflow_metrics import DataflowMetrics
   556      self._metrics = DataflowMetrics(self.dataflow_client, result, self.job)
   557      result.metric_results = self._metrics
   558      return result
   559  
   560    def _get_typehint_based_encoding(self, typehint, window_coder):
   561      """Returns an encoding based on a typehint object."""
   562      return self._get_cloud_encoding(
   563          self._get_coder(typehint, window_coder=window_coder))
   564  
   565    @staticmethod
   566    def _get_coder(typehint, window_coder):
   567      """Returns a coder based on a typehint object."""
   568      if window_coder:
   569        return coders.WindowedValueCoder(
   570            coders.registry.get_coder(typehint), window_coder=window_coder)
   571      return coders.registry.get_coder(typehint)
   572  
   573    def _get_cloud_encoding(self, coder, unused=None):
   574      """Returns an encoding based on a coder object."""
   575      if not isinstance(coder, coders.Coder):
   576        raise TypeError(
   577            'Coder object must inherit from coders.Coder: %s.' % str(coder))
   578      return coder.as_cloud_object(self.proto_context.coders)
   579  
   580    def _get_side_input_encoding(self, input_encoding):
   581      """Returns an encoding for the output of a view transform.
   582  
   583      Args:
   584        input_encoding: encoding of current transform's input. Side inputs need
   585          this because the service will check that input and output types match.
   586  
   587      Returns:
   588        An encoding that matches the output and input encoding. This is essential
   589        for the View transforms introduced to produce side inputs to a ParDo.
   590      """
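            # For instance (illustrative values), an input encoding of
            # {'@type': 'kind:bytes'} yields:
            #
            #   {
            #       '@type': 'kind:stream',
            #       'component_encodings': [{'@type': 'kind:bytes'}],
            #       'is_stream_like': {'value': True},
            #   }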
   591      return {
   592          '@type': 'kind:stream',
   593          'component_encodings': [input_encoding],
   594          'is_stream_like': {
   595              'value': True
   596          },
   597      }
   598  
   599    def _get_encoded_output_coder(
   600        self, transform_node, window_value=True, output_tag=None):
   601      """Returns the cloud encoding of the coder for the output of a transform."""
   602  
   603      if output_tag in transform_node.outputs:
   604        element_type = transform_node.outputs[output_tag].element_type
   605      elif len(transform_node.outputs) == 1:
   606        output_tag = DataflowRunner._only_element(transform_node.outputs.keys())
   607        # TODO(robertwb): Handle type hints for multi-output transforms.
   608        element_type = transform_node.outputs[output_tag].element_type
   609  
   610      else:
   611        # TODO(silviuc): Remove this branch (and assert) when typehints are
   612        # propagated everywhere. Returning an 'Any' as type hint will trigger
   613        # usage of the fallback coder (i.e., cPickler).
   614        element_type = typehints.Any
   615      if window_value:
   616        # All outputs have the same windowing, so getting the window coder from
   617        # an arbitrary output is fine.
   618        output_tag = next(iter(transform_node.outputs.keys()))
   619        window_coder = (
   620            transform_node.outputs[output_tag].windowing.windowfn.
   621            get_window_coder())
   622      else:
   623        window_coder = None
   624      return self._get_typehint_based_encoding(element_type, window_coder)
   625  
   626    def get_pcoll_with_auto_sharding(self):
   627      if not hasattr(self, '_pcoll_with_auto_sharding'):
   628        return set()
   629      return self._pcoll_with_auto_sharding
   630  
   631    def add_pcoll_with_auto_sharding(self, applied_ptransform):
   632      if not hasattr(self, '_pcoll_with_auto_sharding'):
   633        self.__setattr__('_pcoll_with_auto_sharding', set())
   634      output = DataflowRunner._only_element(applied_ptransform.outputs.keys())
   635      self._pcoll_with_auto_sharding.add(
   636          applied_ptransform.outputs[output]._unique_name())
   637  
   638    def _add_step(self, step_kind, step_label, transform_node, side_tags=()):
   639      """Creates a Step object and adds it to the cache."""
   640      # Import here to avoid adding the dependency for local running scenarios.
   641      # pylint: disable=wrong-import-order, wrong-import-position
   642      from apache_beam.runners.dataflow.internal import apiclient
   643      step = apiclient.Step(step_kind, self._get_unique_step_name())
   644      self.job.proto.steps.append(step.proto)
   645      step.add_property(PropertyNames.USER_NAME, step_label)
   646      # Cache the node/step association for the main output of the transform node.
   647  
   648      # External transforms may not use 'None' as an output tag.
   649      output_tags = ([None] +
   650                     list(side_tags) if None in transform_node.outputs.keys() else
   651                     list(transform_node.outputs.keys()))
   652  
   653      # We have to cache output for all tags since some transforms may produce
   654      # multiple outputs.
   655      for output_tag in output_tags:
   656        self._cache.cache_output(transform_node, output_tag, step)
   657  
   658      # Finally, we add the display data items to the pipeline step.
   659      # If the transform contains no display data then an empty list is added.
   660      step.add_property(
   661          PropertyNames.DISPLAY_DATA,
   662          [
   663              item.get_dict()
   664              for item in DisplayData.create_from(transform_node.transform).items
   665          ])
   666  
   667      if transform_node.resource_hints:
   668        step.add_property(
   669            PropertyNames.RESOURCE_HINTS,
   670            {
   671                hint: quote_from_bytes(value)
   672                for (hint, value) in transform_node.resource_hints.items()
   673            })
   674  
   675      return step
   676  
   677    def _add_singleton_step(
   678        self,
   679        label,
   680        full_label,
   681        tag,
   682        input_step,
   683        windowing_strategy,
   684        access_pattern):
   685      """Creates a CollectionToSingleton step used to handle ParDo side inputs."""
   686      # Import here to avoid adding the dependency for local running scenarios.
   687      from apache_beam.runners.dataflow.internal import apiclient
   688      step = apiclient.Step(TransformNames.COLLECTION_TO_SINGLETON, label)
   689      self.job.proto.steps.append(step.proto)
   690      step.add_property(PropertyNames.USER_NAME, full_label)
   691      step.add_property(
   692          PropertyNames.PARALLEL_INPUT,
   693          {
   694              '@type': 'OutputReference',
   695              PropertyNames.STEP_NAME: input_step.proto.name,
   696              PropertyNames.OUTPUT_NAME: input_step.get_output(tag)
   697          })
   698      step.encoding = self._get_side_input_encoding(input_step.encoding)
   699  
   700      output_info = {
   701          PropertyNames.USER_NAME: '%s.%s' % (full_label, PropertyNames.OUTPUT),
   702          PropertyNames.ENCODING: step.encoding,
   703          PropertyNames.OUTPUT_NAME: PropertyNames.OUT
   704      }
   705      if common_urns.side_inputs.MULTIMAP.urn == access_pattern:
   706        output_info[PropertyNames.USE_INDEXED_FORMAT] = True
   707      step.add_property(PropertyNames.OUTPUT_INFO, [output_info])
   708  
   709      step.add_property(
   710          PropertyNames.WINDOWING_STRATEGY,
   711          self.serialize_windowing_strategy(
   712              windowing_strategy, self._default_environment))
   713      return step
   714  
   715    def run_Impulse(self, transform_node, options):
   716      step = self._add_step(
   717          TransformNames.READ, transform_node.full_label, transform_node)
   718      step.add_property(PropertyNames.FORMAT, 'impulse')
   719      encoded_impulse_element = coders.WindowedValueCoder(
   720          coders.BytesCoder(),
   721          coders.coders.GlobalWindowCoder()).get_impl().encode_nested(
   722              window.GlobalWindows.windowed_value(b''))
   723      if _is_runner_v2(options):
   724        encoded_impulse_as_str = self.byte_array_to_json_string(
   725            encoded_impulse_element)
   726      else:
   727        encoded_impulse_as_str = base64.b64encode(encoded_impulse_element).decode(
   728            'ascii')
   729  
   730      step.add_property(PropertyNames.IMPULSE_ELEMENT, encoded_impulse_as_str)
   731  
   732      step.encoding = self._get_encoded_output_coder(transform_node)
   733      step.add_property(
   734          PropertyNames.OUTPUT_INFO,
   735          [{
   736              PropertyNames.USER_NAME: (
   737                  '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
   738              PropertyNames.ENCODING: step.encoding,
   739              PropertyNames.OUTPUT_NAME: PropertyNames.OUT
   740          }])
   741  
   742    def run_Flatten(self, transform_node, options):
   743      step = self._add_step(
   744          TransformNames.FLATTEN, transform_node.full_label, transform_node)
   745      inputs = []
   746      for one_input in transform_node.inputs:
   747        input_step = self._cache.get_pvalue(one_input)
   748        inputs.append({
   749            '@type': 'OutputReference',
   750            PropertyNames.STEP_NAME: input_step.proto.name,
   751            PropertyNames.OUTPUT_NAME: input_step.get_output(one_input.tag)
   752        })
   753      step.add_property(PropertyNames.INPUTS, inputs)
   754      step.encoding = self._get_encoded_output_coder(transform_node)
   755      step.add_property(
   756          PropertyNames.OUTPUT_INFO,
   757          [{
   758              PropertyNames.USER_NAME: (
   759                  '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
   760              PropertyNames.ENCODING: step.encoding,
   761              PropertyNames.OUTPUT_NAME: PropertyNames.OUT
   762          }])
   763  
   764    # TODO(srohde): Remove this after internal usages have been removed.
   765    def apply_GroupByKey(self, transform, pcoll, options):
   766      return transform.expand(pcoll)
   767  
   768    def _verify_gbk_coders(self, transform, pcoll):
   769      # Infer coder of parent.
   770      #
   771      # TODO(ccy): make Coder inference and checking less specialized and more
   772      # comprehensive.
   773  
   774      parent = pcoll.producer
   775      # Guard against `coder` being unbound when the PCollection has no producer.
   776      coder = parent.transform._infer_output_coder() if parent else None  # pylint: disable=protected-access
   777      if not coder:
   778        coder = self._get_coder(pcoll.element_type or typehints.Any, None)
   779      if not coder.is_kv_coder():
   780        raise ValueError((
   781            'Coder for the GroupByKey operation "%s" is not a '
   782            'key-value coder: %s.') % (transform.label, coder))
   783      # TODO(robertwb): Update the coder itself if it changed.
   784      coders.registry.verify_deterministic(
   785          coder.key_coder(), 'GroupByKey operation "%s"' % transform.label)
   786  
   787    def run_GroupByKey(self, transform_node, options):
   788      input_tag = transform_node.inputs[0].tag
   789      input_step = self._cache.get_pvalue(transform_node.inputs[0])
   790  
   791      # Verify that the GBK's parent has a KV coder.
   792      self._verify_gbk_coders(transform_node.transform, transform_node.inputs[0])
   793  
   794      step = self._add_step(
   795          TransformNames.GROUP, transform_node.full_label, transform_node)
   796      step.add_property(
   797          PropertyNames.PARALLEL_INPUT,
   798          {
   799              '@type': 'OutputReference',
   800              PropertyNames.STEP_NAME: input_step.proto.name,
   801              PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)
   802          })
   803      step.encoding = self._get_encoded_output_coder(transform_node)
   804      step.add_property(
   805          PropertyNames.OUTPUT_INFO,
   806          [{
   807              PropertyNames.USER_NAME: (
   808                  '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
   809              PropertyNames.ENCODING: step.encoding,
   810              PropertyNames.OUTPUT_NAME: PropertyNames.OUT
   811          }])
   812      windowing = transform_node.transform.get_windowing(transform_node.inputs)
   813      step.add_property(
   814          PropertyNames.SERIALIZED_FN,
   815          self.serialize_windowing_strategy(windowing, self._default_environment))
   816  
   817    def run_ExternalTransform(self, transform_node, options):
   818      # Adds a dummy step to the Dataflow job description so that inputs and
   819      # outputs are mapped correctly in the presence of external transforms.
   820      #
   821      # Note that Dataflow Python multi-language pipelines use Portable Job
   822      # Submission by default, hence this step and the rest of the Dataflow step
   823      # definitions defined here are not used by the Dataflow service, but we
   824      # have to maintain the mapping correctly until we can fully drop the
   825      # Dataflow step definitions from the SDK.
   826  
   827      # AppliedTransform node outputs have to be updated to correctly map the
   828      # outputs for external transforms.
   829      transform_node.outputs = ({
   830          output.tag: output
   831          for output in transform_node.outputs.values()
   832      })
   833  
   834      self.run_Impulse(transform_node, options)
   835  
   836    def run_ParDo(self, transform_node, options):
   837      transform = transform_node.transform
   838      input_tag = transform_node.inputs[0].tag
   839      input_step = self._cache.get_pvalue(transform_node.inputs[0])
   840  
   841      # Attach side inputs.
   842      si_dict = {}
   843      si_labels = {}
   844      full_label_counts = defaultdict(int)
   845      lookup_label = lambda side_pval: si_labels[side_pval]
   846      named_inputs = transform_node.named_inputs()
   847      label_renames = {}
   848      for ix, side_pval in enumerate(transform_node.side_inputs):
   849        assert isinstance(side_pval, AsSideInput)
   850        step_name = 'SideInput-' + self._get_unique_step_name()
   851        si_label = ((SIDE_INPUT_PREFIX + '%d-%s') %
   852                    (ix, transform_node.full_label))
   853        old_label = (SIDE_INPUT_PREFIX + '%d') % ix
   854  
   855        label_renames[old_label] = si_label
   856  
   857        assert old_label in named_inputs
   858        pcollection_label = '%s.%s' % (
   859            side_pval.pvalue.producer.full_label.split('/')[-1],
   860            side_pval.pvalue.tag if side_pval.pvalue.tag else 'out')
   861        si_full_label = '%s/%s(%s.%s)' % (
   862            transform_node.full_label,
   863            side_pval.__class__.__name__,
   864            pcollection_label,
   865            full_label_counts[pcollection_label])
   866  
   867        # Count the number of times the same PCollection is a side input
   868        # to the same ParDo.
   869        full_label_counts[pcollection_label] += 1
   870  
   871        self._add_singleton_step(
   872            step_name,
   873            si_full_label,
   874            side_pval.pvalue.tag,
   875            self._cache.get_pvalue(side_pval.pvalue),
   876            side_pval.pvalue.windowing,
   877            side_pval._side_input_data().access_pattern)
   878        si_dict[si_label] = {
   879            '@type': 'OutputReference',
   880            PropertyNames.STEP_NAME: step_name,
   881            PropertyNames.OUTPUT_NAME: PropertyNames.OUT
   882        }
   883        si_labels[side_pval] = si_label
   884  
   885      # Now create the step for the ParDo transform being handled.
   886      transform_name = transform_node.full_label.rsplit('/', 1)[-1]
   887      step = self._add_step(
   888          TransformNames.DO,
   889          transform_node.full_label +
   890          ('/{}'.format(transform_name) if transform_node.side_inputs else ''),
   891          transform_node,
   892          transform_node.transform.output_tags)
   893      transform_proto = self.proto_context.transforms.get_proto(transform_node)
   894      transform_id = self.proto_context.transforms.get_id(transform_node)
   895      is_runner_v2 = _is_runner_v2(options)
   896      # Patch side input ids to be unique across a given pipeline.
   897      if (label_renames and
   898          transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn):
   899        # Patch PTransform proto.
   900        for old, new in label_renames.items():
   901          transform_proto.inputs[new] = transform_proto.inputs[old]
   902          del transform_proto.inputs[old]
   903  
   904        # Patch ParDo proto.
   905        proto_type, _ = beam.PTransform._known_urns[transform_proto.spec.urn]
   906        proto = proto_utils.parse_Bytes(transform_proto.spec.payload, proto_type)
   907        for old, new in label_renames.items():
   908          proto.side_inputs[new].CopyFrom(proto.side_inputs[old])
   909          del proto.side_inputs[old]
   910        transform_proto.spec.payload = proto.SerializeToString()
   911        # We need to update the pipeline proto.
   912        del self.proto_pipeline.components.transforms[transform_id]
   913        (
   914            self.proto_pipeline.components.transforms[transform_id].CopyFrom(
   915                transform_proto))
   916      # The data transmitted in SERIALIZED_FN is different depending on whether
   917      # this is a runner v2 pipeline or not.
   918      if is_runner_v2:
   919        serialized_data = transform_id
   920      else:
   921        serialized_data = pickler.dumps(
   922            self._pardo_fn_data(transform_node, lookup_label))
   923      step.add_property(PropertyNames.SERIALIZED_FN, serialized_data)
   924      # TODO(BEAM-8882): Enable once dataflow service doesn't reject this.
   925      # step.add_property(PropertyNames.PIPELINE_PROTO_TRANSFORM_ID, transform_id)
   926      step.add_property(
   927          PropertyNames.PARALLEL_INPUT,
   928          {
   929              '@type': 'OutputReference',
   930              PropertyNames.STEP_NAME: input_step.proto.name,
   931              PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)
   932          })
   933      # Add side inputs if any.
   934      step.add_property(PropertyNames.NON_PARALLEL_INPUTS, si_dict)
   935  
   936      # Generate description for the outputs. The output names
   937      # will be 'None' for main output and '<tag>' for a tagged output.
   938      outputs = []
   939  
   940      all_output_tags = list(transform_proto.outputs.keys())
   941  
   942      # Some external transforms require output tags to not be modified.
   943      # So we randomly select one of the output tags as the main output and
   944      # leave others as side outputs. Transform execution should not change
   945      # depending on which output tag we choose as the main output here.
   946      # Also, some SDKs do not work correctly if output tags are modified. So for
   947      # external transforms, we leave tags unmodified.
   948      #
   949      # Python SDK uses 'None' as the tag of the main output.
   950      main_output_tag = 'None'
   951  
   952      step.encoding = self._get_encoded_output_coder(
   953          transform_node, output_tag=main_output_tag)
   954  
   955      side_output_tags = set(all_output_tags).difference({main_output_tag})
   956  
   957      # Add the main output to the description.
   958      outputs.append({
   959          PropertyNames.USER_NAME: (
   960              '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
   961          PropertyNames.ENCODING: step.encoding,
   962          PropertyNames.OUTPUT_NAME: main_output_tag
   963      })
   964      for side_tag in side_output_tags:
   965        # The assumption here is that all outputs will have the same typehint
   966        # and coder as the main output. This is certainly the case right now
   967        # but conceivably it could change in the future.
   968        encoding = self._get_encoded_output_coder(
   969            transform_node, output_tag=side_tag)
   970        outputs.append({
   971            PropertyNames.USER_NAME: (
   972                '%s.%s' % (transform_node.full_label, side_tag)),
   973            PropertyNames.ENCODING: encoding,
   974            PropertyNames.OUTPUT_NAME: side_tag
   975        })
   976  
   977      step.add_property(PropertyNames.OUTPUT_INFO, outputs)
   978  
   979      # Add the restriction encoding if we are a splittable DoFn
   980      restriction_coder = transform.get_restriction_coder()
   981      if restriction_coder:
   982        step.add_property(
   983            PropertyNames.RESTRICTION_ENCODING,
   984            self._get_cloud_encoding(restriction_coder))
   985  
   986      if options.view_as(StandardOptions).streaming:
   987        is_stateful_dofn = (DoFnSignature(transform.dofn).is_stateful_dofn())
   988        if is_stateful_dofn:
   989          step.add_property(PropertyNames.USES_KEYED_STATE, 'true')
   990  
   991          # Also checks whether the step allows shardable keyed states.
   992          # TODO(BEAM-11360): remove this when migrated to portable job
   993          #  submission since we only consider supporting the property in runner
   994          #  v2.
   995          for pcoll in transform_node.outputs.values():
   996            if pcoll._unique_name() in self.get_pcoll_with_auto_sharding():
   997              step.add_property(PropertyNames.ALLOWS_SHARDABLE_STATE, 'true')
   998              # Currently we only allow auto-sharding to be enabled through the
   999              # GroupIntoBatches transform. So we also add the following property
  1000              # which GroupIntoBatchesDoFn has, to allow the backend to perform
  1001              # graph optimization.
  1002              step.add_property(PropertyNames.PRESERVES_KEYS, 'true')
  1003              break
  1004  
  1005    @staticmethod
  1006    def _pardo_fn_data(transform_node, get_label):
  1007      transform = transform_node.transform
  1008      si_tags_and_types = [  # pylint: disable=protected-access
  1009          (get_label(side_pval), side_pval.__class__, side_pval._view_options())
  1010          for side_pval in transform_node.side_inputs]
  1011      return (
  1012          transform.fn,
  1013          transform.args,
  1014          transform.kwargs,
  1015          si_tags_and_types,
  1016          transform_node.inputs[0].windowing)
  1017  
  1018    def run_CombineValuesReplacement(self, transform_node, options):
  1019      transform = transform_node.transform.transform
  1020      input_tag = transform_node.inputs[0].tag
  1021      input_step = self._cache.get_pvalue(transform_node.inputs[0])
  1022      step = self._add_step(
  1023          TransformNames.COMBINE, transform_node.full_label, transform_node)
  1024      transform_id = self.proto_context.transforms.get_id(transform_node.parent)
  1025  
  1026      # The data transmitted in SERIALIZED_FN is different depending on whether
  1027      # this is a runner v2 pipeline or not.
  1028      if _is_runner_v2(options):
  1029        # Fnapi pipelines send the transform ID of the CombineValues transform's
  1030        # parent composite because Dataflow expects the ID of a CombinePerKey
  1031        # transform.
  1032        serialized_data = transform_id
  1033      else:
  1034        # Combiner functions do not take deferred side-inputs (i.e. PValues) and
  1035        # therefore the code to handle extra args/kwargs is simpler than for the
  1036        # DoFns of the ParDo transform. The last, empty argument is where
  1037        # side input information would go.
  1038        serialized_data = pickler.dumps(
  1039            (transform.fn, transform.args, transform.kwargs, ()))
  1040      step.add_property(PropertyNames.SERIALIZED_FN, serialized_data)
  1041      # TODO(BEAM-8882): Enable once dataflow service doesn't reject this.
  1042      # step.add_property(PropertyNames.PIPELINE_PROTO_TRANSFORM_ID, transform_id)
  1043      step.add_property(
  1044          PropertyNames.PARALLEL_INPUT,
  1045          {
  1046              '@type': 'OutputReference',
  1047              PropertyNames.STEP_NAME: input_step.proto.name,
  1048              PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)
  1049          })
  1050      # Note that the accumulator must not have a WindowedValue encoding, while
  1051      # the output of this step does in fact have a WindowedValue encoding.
  1052      accumulator_encoding = self._get_cloud_encoding(
  1053          transform.fn.get_accumulator_coder())
  1054      output_encoding = self._get_encoded_output_coder(transform_node)
  1055  
  1056      step.encoding = output_encoding
  1057      step.add_property(PropertyNames.ENCODING, accumulator_encoding)
  1058      # Generate description for the main output 'out'.
  1059      outputs = []
  1060      # Add the main output to the description.
  1061      outputs.append({
  1062          PropertyNames.USER_NAME: (
  1063              '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
  1064          PropertyNames.ENCODING: step.encoding,
  1065          PropertyNames.OUTPUT_NAME: PropertyNames.OUT
  1066      })
  1067      step.add_property(PropertyNames.OUTPUT_INFO, outputs)
  1068  
  1069    def run_Read(self, transform_node, options):
  1070      transform = transform_node.transform
  1071      step = self._add_step(
  1072          TransformNames.READ, transform_node.full_label, transform_node)
  1073      # TODO(mairbek): refactor if-else tree to use registerable functions.
  1074      # Initialize the source specific properties.
  1075  
  1076      standard_options = options.view_as(StandardOptions)
  1077      if not hasattr(transform.source, 'format'):
  1078        # If a format is not set, we assume the source to be a custom source.
  1079        source_dict = {}
  1080  
  1081        source_dict['spec'] = {
  1082            '@type': names.SOURCE_TYPE,
  1083            names.SERIALIZED_SOURCE_KEY: pickler.dumps(transform.source)
  1084        }
  1085  
  1086        try:
  1087          source_dict['metadata'] = {
  1088              'estimated_size_bytes': json_value.get_typed_value_descriptor(
  1089                  transform.source.estimate_size())
  1090          }
  1091        except error.RuntimeValueProviderError:
  1092          # Size estimation is best effort; the size depends on a value provider.
  1093          _LOGGER.info(
  1094              'Could not estimate size of source %r due to ' + \
  1095              'RuntimeValueProviderError', transform.source)
  1096        except Exception:  # pylint: disable=broad-except
  1097          # Size estimation is best effort. So we log the error and continue.
  1098          _LOGGER.info(
  1099              'Could not estimate size of source %r due to an exception: %s',
  1100              transform.source,
  1101              traceback.format_exc())
  1102  
  1103        step.add_property(PropertyNames.SOURCE_STEP_INPUT, source_dict)
  1104      elif transform.source.format == 'pubsub':
  1105        if not standard_options.streaming:
  1106          raise ValueError(
  1107              'Cloud Pub/Sub is currently available for use '
  1108              'only in streaming pipelines.')
  1109        # Only one of topic or subscription should be set.
  1110        if transform.source.full_subscription:
  1111          step.add_property(
  1112              PropertyNames.PUBSUB_SUBSCRIPTION,
  1113              transform.source.full_subscription)
  1114        elif transform.source.full_topic:
  1115          step.add_property(
  1116              PropertyNames.PUBSUB_TOPIC, transform.source.full_topic)
  1117        if transform.source.id_label:
  1118          step.add_property(
  1119              PropertyNames.PUBSUB_ID_LABEL, transform.source.id_label)
  1120        if transform.source.with_attributes:
  1121          # Setting this property signals Dataflow runner to return full
  1122          # PubsubMessages instead of just the data part of the payload.
  1123          step.add_property(PropertyNames.PUBSUB_SERIALIZED_ATTRIBUTES_FN, '')
  1124  
  1125        if transform.source.timestamp_attribute is not None:
  1126          step.add_property(
  1127              PropertyNames.PUBSUB_TIMESTAMP_ATTRIBUTE,
  1128              transform.source.timestamp_attribute)
  1129      else:
  1130        raise ValueError(
  1131            'Source %r has unexpected format %s.' %
  1132            (transform.source, transform.source.format))
  1133  
  1134      if not hasattr(transform.source, 'format'):
  1135        step.add_property(PropertyNames.FORMAT, names.SOURCE_FORMAT)
  1136      else:
  1137        step.add_property(PropertyNames.FORMAT, transform.source.format)
  1138  
  1139      # Wrap coder in WindowedValueCoder: this is necessary as the encoding of a
  1140      # step should be the type of value produced by each step.  Read steps
  1141      # automatically wrap output values in a WindowedValue wrapper, if necessary.
  1142      # This is also necessary for proper encoding for size estimation.
  1143      # Using a GlobalWindowCoder as a placeholder instead of the default
  1144      # PickleCoder because GlobalWindowCoder is a known coder.
  1145      # TODO(robertwb): Query the collection for the windowfn to extract the
  1146      # correct coder.
  1147      coder = coders.WindowedValueCoder(
  1148          coders.registry.get_coder(transform_node.outputs[None].element_type),
  1149          coders.coders.GlobalWindowCoder())
  1150  
  1151      step.encoding = self._get_cloud_encoding(coder)
  1152      step.add_property(
  1153          PropertyNames.OUTPUT_INFO,
  1154          [{
  1155              PropertyNames.USER_NAME: (
  1156                  '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
  1157              PropertyNames.ENCODING: step.encoding,
  1158              PropertyNames.OUTPUT_NAME: PropertyNames.OUT
  1159          }])
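
          # Illustrative sketch (not part of the runner): the 'pubsub' branch above
          # corresponds to reading from Cloud Pub/Sub in a streaming pipeline. The
          # project, region and subscription names below are placeholders.
          #
          #   from apache_beam.options.pipeline_options import PipelineOptions
          #   opts = PipelineOptions(
          #       streaming=True, project='my-project', region='us-central1',
          #       runner='DataflowRunner')
          #   with beam.Pipeline(options=opts) as p:
          #     _ = p | beam.io.ReadFromPubSub(
          #         subscription='projects/my-project/subscriptions/my-sub',
          #         with_attributes=True)
          #
          # Without streaming=True, the 'pubsub' branch above raises ValueError.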
  1160  
  1161    def run__NativeWrite(self, transform_node, options):
  1162      transform = transform_node.transform
  1163      input_tag = transform_node.inputs[0].tag
  1164      input_step = self._cache.get_pvalue(transform_node.inputs[0])
  1165      step = self._add_step(
  1166          TransformNames.WRITE, transform_node.full_label, transform_node)
  1167      # TODO(mairbek): refactor if-else tree to use registerable functions.
  1168      # Initialize the sink specific properties.
  1169      if transform.sink.format == 'pubsub':
  1170        standard_options = options.view_as(StandardOptions)
  1171        if not standard_options.streaming:
  1172          raise ValueError(
  1173              'Cloud Pub/Sub is currently available for use '
  1174              'only in streaming pipelines.')
  1175        step.add_property(PropertyNames.PUBSUB_TOPIC, transform.sink.full_topic)
  1176        if transform.sink.id_label:
  1177          step.add_property(
  1178              PropertyNames.PUBSUB_ID_LABEL, transform.sink.id_label)
  1179        # Setting this property signals the Dataflow runner that the
  1180        # PCollection contains PubsubMessage objects instead of just raw data.
  1181        step.add_property(PropertyNames.PUBSUB_SERIALIZED_ATTRIBUTES_FN, '')
  1182        if transform.sink.timestamp_attribute is not None:
  1183          step.add_property(
  1184              PropertyNames.PUBSUB_TIMESTAMP_ATTRIBUTE,
  1185              transform.sink.timestamp_attribute)
  1186      else:
  1187        raise ValueError(
  1188            'Sink %r has unexpected format %s.' %
  1189            (transform.sink, transform.sink.format))
  1190      step.add_property(PropertyNames.FORMAT, transform.sink.format)
  1191  
  1192      # Wrap coder in WindowedValueCoder: this is necessary for proper encoding
  1193      # for size estimation. Using a GlobalWindowCoder as a placeholder instead
  1194      # of the default PickleCoder because GlobalWindowCoder is a known coder.
  1195      # TODO(robertwb): Query the collection for the windowfn to extract the
  1196      # correct coder.
  1197      coder = coders.WindowedValueCoder(
  1198          transform.sink.coder, coders.coders.GlobalWindowCoder())
  1199      step.encoding = self._get_cloud_encoding(coder)
  1200      step.add_property(PropertyNames.ENCODING, step.encoding)
  1201      step.add_property(
  1202          PropertyNames.PARALLEL_INPUT,
  1203          {
  1204              '@type': 'OutputReference',
  1205              PropertyNames.STEP_NAME: input_step.proto.name,
  1206              PropertyNames.OUTPUT_NAME: input_step.get_output(input_tag)
  1207          })
  1208  
  1209    def run_TestStream(self, transform_node, options):
  1210      from apache_beam.testing.test_stream import ElementEvent
  1211      from apache_beam.testing.test_stream import ProcessingTimeEvent
  1212      from apache_beam.testing.test_stream import WatermarkEvent
  1213      standard_options = options.view_as(StandardOptions)
  1214      if not standard_options.streaming:
  1215        raise ValueError(
  1216            'TestStream is currently available for use '
  1217            'only in streaming pipelines.')
  1218  
  1219      transform = transform_node.transform
  1220      step = self._add_step(
  1221          TransformNames.READ, transform_node.full_label, transform_node)
  1222      step.add_property(
  1223          PropertyNames.SERIALIZED_FN,
  1224          self.proto_context.transforms.get_id(transform_node))
  1225      step.add_property(PropertyNames.FORMAT, 'test_stream')
  1226      test_stream_payload = beam_runner_api_pb2.TestStreamPayload()
  1227      # TestStream source doesn't do any decoding of elements,
  1228      # so we won't set test_stream_payload.coder_id.
  1229      output_coder = transform._infer_output_coder()  # pylint: disable=protected-access
  1230      for event in transform._events:
  1231        new_event = test_stream_payload.events.add()
  1232        if isinstance(event, ElementEvent):
  1233          for tv in event.timestamped_values:
  1234            element = new_event.element_event.elements.add()
  1235            element.encoded_element = output_coder.encode(tv.value)
  1236            element.timestamp = tv.timestamp.micros
  1237        elif isinstance(event, ProcessingTimeEvent):
  1238          new_event.processing_time_event.advance_duration = (
  1239              event.advance_by.micros)
  1240        elif isinstance(event, WatermarkEvent):
  1241          new_event.watermark_event.new_watermark = event.new_watermark.micros
  1242      serialized_payload = self.byte_array_to_json_string(
  1243          test_stream_payload.SerializeToString())
  1244      step.add_property(PropertyNames.SERIALIZED_TEST_STREAM, serialized_payload)
  1245  
  1246      step.encoding = self._get_encoded_output_coder(transform_node)
  1247      step.add_property(
  1248          PropertyNames.OUTPUT_INFO,
  1249          [{
  1250              PropertyNames.USER_NAME: (
  1251                  '%s.%s' % (transform_node.full_label, PropertyNames.OUT)),
  1252              PropertyNames.ENCODING: step.encoding,
  1253              PropertyNames.OUTPUT_NAME: PropertyNames.OUT
  1254          }])
  1255  
  1256    # Mark this method as not a test, or else nose collects it as a test
  1257    # because its name matches the test-name pattern.
  1258    run_TestStream.__test__ = False  # type: ignore[attr-defined]
  1259  
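          # Illustrative sketch (not part of the runner): a TestStream whose events
          # are translated by run_TestStream above might be built as follows
          # (element values and timestamps are arbitrary examples):
          #
          #   from apache_beam.testing.test_stream import TestStream
          #   stream = (TestStream()
          #             .advance_watermark_to(0)
          #             .add_elements(['a', 'b'])
          #             .advance_processing_time(10)
          #             .advance_watermark_to_infinity())
          #   # In a streaming pipeline: pcoll = p | stream
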
  1260    @classmethod
  1261    def serialize_windowing_strategy(cls, windowing, default_environment):
  1262      from apache_beam.runners import pipeline_context
  1263      context = pipeline_context.PipelineContext(
  1264          default_environment=default_environment)
  1265      windowing_proto = windowing.to_runner_api(context)
  1266      return cls.byte_array_to_json_string(
  1267          beam_runner_api_pb2.MessageWithComponents(
  1268              components=context.to_runner_api(),
  1269              windowing_strategy=windowing_proto).SerializeToString())
  1270  
  1271    @classmethod
  1272    def deserialize_windowing_strategy(cls, serialized_data):
  1273      # Imported here to avoid circular dependencies.
  1274      # pylint: disable=wrong-import-order, wrong-import-position
  1275      from apache_beam.runners import pipeline_context
  1276      from apache_beam.transforms.core import Windowing
  1277      proto = beam_runner_api_pb2.MessageWithComponents()
  1278      proto.ParseFromString(cls.json_string_to_byte_array(serialized_data))
  1279      return Windowing.from_runner_api(
  1280          proto.windowing_strategy,
  1281          pipeline_context.PipelineContext(proto.components))
  1282  
  1283    @staticmethod
  1284    def byte_array_to_json_string(raw_bytes):
  1285      """Implements org.apache.beam.sdk.util.StringUtils.byteArrayToJsonString."""
  1286      return quote(raw_bytes)
  1287  
  1288    @staticmethod
  1289    def json_string_to_byte_array(encoded_string):
  1290      """Implements org.apache.beam.sdk.util.StringUtils.jsonStringToByteArray."""
  1291      return unquote_to_bytes(encoded_string)
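
          # Illustrative sketch (not part of the runner): the two helpers above
          # round-trip arbitrary bytes through a percent-encoded ASCII string.
          #
          #   s = DataflowRunner.byte_array_to_json_string(b'\x00abc\xff')
          #   # s == '%00abc%FF'
          #   assert DataflowRunner.json_string_to_byte_array(s) == b'\x00abc\xff'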
  1292  
  1293    def get_default_gcp_region(self):
  1294      """Get a default value for Google Cloud region according to
  1295      https://cloud.google.com/compute/docs/gcloud-compute/#default-properties.
  1296      If no default can be found, returns None.
  1297      """
  1298      environment_region = os.environ.get('CLOUDSDK_COMPUTE_REGION')
  1299      if environment_region:
  1300        _LOGGER.info(
  1301            'Using default GCP region %s from $CLOUDSDK_COMPUTE_REGION',
  1302            environment_region)
  1303        return environment_region
  1304      try:
  1305        cmd = ['gcloud', 'config', 'get-value', 'compute/region']
  1306        raw_output = processes.check_output(cmd, stderr=DEVNULL)
  1307        formatted_output = raw_output.decode('utf-8').strip()
  1308        if formatted_output:
  1309          _LOGGER.info(
  1310              'Using default GCP region %s from `%s`',
  1311              formatted_output,
  1312              ' '.join(cmd))
  1313          return formatted_output
  1314      except RuntimeError:
  1315        pass
  1316      return None
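
          # Illustrative sketch (not part of the runner): the lookup order above is
          # $CLOUDSDK_COMPUTE_REGION, then `gcloud config get-value compute/region`,
          # otherwise None. For example ('us-central1' is a placeholder value):
          #
          #   os.environ['CLOUDSDK_COMPUTE_REGION'] = 'us-central1'
          #   DataflowRunner().get_default_gcp_region()  # -> 'us-central1'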
  1317  
  1318  
  1319  class _DataflowSideInput(beam.pvalue.AsSideInput):
  1320    """Wraps a side input as a dataflow-compatible side input."""
  1321    def _view_options(self):
  1322      return {
  1323          'data': self._data,
  1324      }
  1325  
  1326    def _side_input_data(self):
  1327      return self._data
  1328  
  1329  
  1330  def _add_runner_v2_missing_options(options):
  1331    debug_options = options.view_as(DebugOptions)
  1332    debug_options.add_experiment('beam_fn_api')
  1333    debug_options.add_experiment('use_unified_worker')
  1334    debug_options.add_experiment('use_runner_v2')
  1335    debug_options.add_experiment('use_portable_job_submission')
  1336  
  1337  
  1338  def _check_and_add_missing_options(options):
  1339    # Type: (PipelineOptions) -> None
  1340  
  1341    """Validates and adds missing pipeline options depending on options set.
  1342  
  1343    :param options: PipelineOptions for this pipeline.
  1344    """
  1345    debug_options = options.view_as(DebugOptions)
  1346    dataflow_service_options = options.view_as(
  1347        GoogleCloudOptions).dataflow_service_options or []
  1348    options.view_as(
  1349        GoogleCloudOptions).dataflow_service_options = dataflow_service_options
  1350  
  1351    # Keep 'enable_prime' in sync between the Dataflow service options and the
  1352    # experiments: specifying either one implies the other.
  1353    if 'enable_prime' in dataflow_service_options:
  1354      debug_options.add_experiment('enable_prime')
  1355    elif debug_options.lookup_experiment('enable_prime'):
  1356      dataflow_service_options.append('enable_prime')
  1357  
  1358    # Streaming only supports using runner v2 (aka unified worker).
  1359    # Runner v2 only supports using streaming engine (aka windmill service).
  1360    if options.view_as(StandardOptions).streaming:
  1361      google_cloud_options = options.view_as(GoogleCloudOptions)
  1362      if _is_runner_v2_disabled(options):
  1363        raise ValueError(
  1364            'Disabling Runner V2 no longer supported for streaming pipeline '
  1365            'using Beam Python %s.' % beam.version.__version__)
  1366  
  1367      if (not google_cloud_options.enable_streaming_engine and
  1368          (debug_options.lookup_experiment("enable_windmill_service") or
  1369           debug_options.lookup_experiment("enable_streaming_engine"))):
  1370        raise ValueError(
  1371            """Streaming engine both disabled and enabled:
  1372            --enable_streaming_engine flag is not set, but
  1373            enable_windmill_service and/or enable_streaming_engine experiments
  1374            are present. It is recommended you only set the
  1375            --enable_streaming_engine flag.""")
  1376  
  1377      # If we detected a streaming pipeline, ensure that the streaming-specific
  1378      # options and experiments are set.
  1379      options.view_as(StandardOptions).streaming = True
  1380      google_cloud_options.enable_streaming_engine = True
  1381      debug_options.add_experiment("enable_streaming_engine")
  1382      debug_options.add_experiment("enable_windmill_service")
  1383      _add_runner_v2_missing_options(debug_options)
  1384    elif (debug_options.lookup_experiment('enable_prime') or
  1385          debug_options.lookup_experiment('beam_fn_api') or
  1386          debug_options.lookup_experiment('use_unified_worker') or
  1387          debug_options.lookup_experiment('use_runner_v2') or
  1388          debug_options.lookup_experiment('use_portable_job_submission')):
  1389      if _is_runner_v2_disabled(options):
  1390        raise ValueError(
  1391            """Runner V2 both disabled and enabled: at least one of
  1392            ['enable_prime', 'beam_fn_api', 'use_unified_worker', 'use_runner_v2',
  1393            'use_portable_job_submission'] is set and also one of
  1394            ['disable_runner_v2', 'disable_runner_v2_until_2023',
  1395            'disable_prime_runner_v2'] is set.""")
  1396      _add_runner_v2_missing_options(debug_options)
  1397  
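        # Illustrative sketch (not part of this module): for a streaming pipeline
        # the validation above force-enables Streaming Engine and the Runner V2
        # experiments, roughly:
        #
        #   from apache_beam.options.pipeline_options import PipelineOptions
        #   opts = PipelineOptions(streaming=True)
        #   _check_and_add_missing_options(opts)
        #   # Now opts.view_as(GoogleCloudOptions).enable_streaming_engine is True
        #   # and the experiments include 'enable_streaming_engine',
        #   # 'enable_windmill_service', 'beam_fn_api', 'use_unified_worker',
        #   # 'use_runner_v2' and 'use_portable_job_submission'.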
  1398  
  1399  def _is_runner_v2(options):
  1400    # Type: (PipelineOptions) -> bool
  1401  
  1402    """Returns true if runner v2 is enabled."""
  1403    _check_and_add_missing_options(options)
  1404    return options.view_as(DebugOptions).lookup_experiment(
  1405        'use_runner_v2', default=False)
  1406  
  1407  
  1408  def _is_runner_v2_disabled(options):
  1409    # Type: (PipelineOptions) -> bool
  1410  
  1411    """Returns true if runner v2 is disabled."""
  1412    debug_options = options.view_as(DebugOptions)
  1413    return (
  1414        debug_options.lookup_experiment('disable_runner_v2') or
  1415        debug_options.lookup_experiment('disable_runner_v2_until_2023') or
  1416        debug_options.lookup_experiment('disable_prime_runner_v2'))
  1417  
  1418  
  1419  class _DataflowIterableSideInput(_DataflowSideInput):
  1420    """Wraps an iterable side input as a dataflow-compatible side input."""
  1421    def __init__(self, side_input):
  1422      # pylint: disable=protected-access
  1423      self.pvalue = side_input.pvalue
  1424      side_input_data = side_input._side_input_data()
  1425      assert (
  1426          side_input_data.access_pattern == common_urns.side_inputs.ITERABLE.urn)
  1427      self._data = beam.pvalue.SideInputData(
  1428          common_urns.side_inputs.ITERABLE.urn,
  1429          side_input_data.window_mapping_fn,
  1430          side_input_data.view_fn)
  1431  
  1432  
  1433  class _DataflowMultimapSideInput(_DataflowSideInput):
  1434    """Wraps a multimap side input as a dataflow-compatible side input."""
  1435    def __init__(self, side_input):
  1436      # pylint: disable=protected-access
  1437      self.pvalue = side_input.pvalue
  1438      side_input_data = side_input._side_input_data()
  1439      assert (
  1440          side_input_data.access_pattern == common_urns.side_inputs.MULTIMAP.urn)
  1441      self._data = beam.pvalue.SideInputData(
  1442          common_urns.side_inputs.MULTIMAP.urn,
  1443          side_input_data.window_mapping_fn,
  1444          side_input_data.view_fn)
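
        # Illustrative sketch (not part of this module): which wrapper is used
        # depends on the side input's access pattern asserted above; pcoll below
        # is a placeholder PCollection.
        #
        #   iterable_si = beam.pvalue.AsIter(pcoll)      # ITERABLE access pattern
        #   multimap_si = beam.pvalue.AsMultiMap(pcoll)  # MULTIMAP access pattern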
  1445  
  1446  
  1447  class DataflowPipelineResult(PipelineResult):
  1448    """Represents the state of a pipeline run on the Dataflow service."""
  1449    def __init__(self, job, runner):
  1450      """Initialize a new DataflowPipelineResult instance.
  1451  
  1452      Args:
  1453        job: Job message from the Dataflow API. Could be :data:`None` if a job
  1454          request was not sent to the Dataflow service (e.g. template jobs).
  1455        runner: DataflowRunner instance.
  1456      """
  1457      self._job = job
  1458      self._runner = runner
  1459      self.metric_results = None
  1460  
  1461    def _update_job(self):
  1462      # We need the job id to be able to update job information. There is no need
  1463      # to update the job if we are in a known terminal state.
  1464      if self.has_job and not self.is_in_terminal_state():
  1465        self._job = self._runner.dataflow_client.get_job(self.job_id())
  1466  
  1467    def job_id(self):
  1468      return self._job.id
  1469  
  1470    def metrics(self):
  1471      return self.metric_results
  1472  
  1473    def monitoring_infos(self):
  1474      logging.warning('Monitoring infos not yet supported for Dataflow runner.')
  1475      return []
  1476  
  1477    @property
  1478    def has_job(self):
  1479      return self._job is not None
  1480  
  1481    @staticmethod
  1482    def api_jobstate_to_pipeline_state(api_jobstate):
  1483      values_enum = dataflow_api.Job.CurrentStateValueValuesEnum
  1484  
  1485      # Ordered by the enum values. Values that may be introduced in
  1486      # future versions of Dataflow API are considered UNRECOGNIZED by this SDK.
  1487      api_jobstate_map = defaultdict(
  1488          lambda: PipelineState.UNRECOGNIZED,
  1489          {
  1490              values_enum.JOB_STATE_UNKNOWN: PipelineState.UNKNOWN,
  1491              values_enum.JOB_STATE_STOPPED: PipelineState.STOPPED,
  1492              values_enum.JOB_STATE_RUNNING: PipelineState.RUNNING,
  1493              values_enum.JOB_STATE_DONE: PipelineState.DONE,
  1494              values_enum.JOB_STATE_FAILED: PipelineState.FAILED,
  1495              values_enum.JOB_STATE_CANCELLED: PipelineState.CANCELLED,
  1496              values_enum.JOB_STATE_UPDATED: PipelineState.UPDATED,
  1497              values_enum.JOB_STATE_DRAINING: PipelineState.DRAINING,
  1498              values_enum.JOB_STATE_DRAINED: PipelineState.DRAINED,
  1499              values_enum.JOB_STATE_PENDING: PipelineState.PENDING,
  1500              values_enum.JOB_STATE_CANCELLING: PipelineState.CANCELLING,
  1501              values_enum.JOB_STATE_RESOURCE_CLEANING_UP: (
  1502                  PipelineState.RESOURCE_CLEANING_UP),
  1503          })
  1504  
  1505      return (
  1506          api_jobstate_map[api_jobstate]
  1507          if api_jobstate else PipelineState.UNKNOWN)
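
            # Illustrative sketch (not part of the runner), e.g.:
            #
            #   enum = dataflow_api.Job.CurrentStateValueValuesEnum
            #   DataflowPipelineResult.api_jobstate_to_pipeline_state(
            #       enum.JOB_STATE_DONE)  # -> PipelineState.DONE
            #   DataflowPipelineResult.api_jobstate_to_pipeline_state(None)
            #   # -> PipelineState.UNKNOWN (no state reported yet)
            # Enum values missing from the map (e.g. states added by a newer
            # Dataflow API) fall back to PipelineState.UNRECOGNIZED.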
  1508  
  1509    def _get_job_state(self):
  1510      return self.api_jobstate_to_pipeline_state(self._job.currentState)
  1511  
  1512    @property
  1513    def state(self):
  1514      """Return the current state of the remote job.
  1515  
  1516      Returns:
  1517        A PipelineState object.
  1518      """
  1519      if not self.has_job:
  1520        return PipelineState.UNKNOWN
  1521  
  1522      self._update_job()
  1523  
  1524      return self._get_job_state()
  1525  
  1526    def is_in_terminal_state(self):
  1527      if not self.has_job:
  1528        return True
  1529  
  1530      return PipelineState.is_terminal(self._get_job_state())
  1531  
  1532    def wait_until_finish(self, duration=None):
  1533      if not self.is_in_terminal_state():
  1534        if not self.has_job:
  1535          raise IOError('Failed to get the Dataflow job id.')
  1536        consoleUrl = (
  1537            "Console URL: https://console.cloud.google.com/"
  1538            f"dataflow/jobs/<RegionId>/{self.job_id()}"
  1539            "?project=<ProjectId>")
  1540        thread = threading.Thread(
  1541            target=DataflowRunner.poll_for_job_completion,
  1542            args=(self._runner, self, duration))
  1543  
  1544        # Mark the thread as a daemon thread so a keyboard interrupt on the main
  1545        # thread will terminate everything. This is also the reason we will not
  1546        # use thread.join() to wait for the polling thread.
  1547        thread.daemon = True
  1548        thread.start()
  1549        while thread.is_alive():
  1550          time.sleep(5.0)
  1551  
  1552        # TODO: Merge the termination code in poll_for_job_completion and
  1553        # is_in_terminal_state.
  1554        terminated = self.is_in_terminal_state()
  1555        assert duration or terminated, (
  1556            'Job did not reach a terminal state after waiting indefinitely. '
  1557            '{}'.format(consoleUrl))
  1558  
  1559        if terminated and self.state != PipelineState.DONE:
  1560          # TODO(BEAM-1290): Consider converting this to an error log based on
  1561          # the resolution of the issue.
  1562          _LOGGER.error(consoleUrl)
  1563          raise DataflowRuntimeException(
  1564              'Dataflow pipeline failed. State: %s, Error:\n%s' %
  1565              (self.state, getattr(self._runner, 'last_error_msg', None)),
  1566              self)
  1567      elif PipelineState.is_terminal(
  1568          self.state) and self.state == PipelineState.FAILED and self._runner:
  1569        raise DataflowRuntimeException(
  1570            'Dataflow pipeline failed. State: %s, Error:\n%s' %
  1571            (self.state, getattr(self._runner, 'last_error_msg', None)),
  1572            self)
  1573  
  1574      return self.state
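
          # Illustrative sketch (not part of the runner): typical caller usage,
          # where pipeline is a placeholder beam.Pipeline configured for Dataflow:
          #
          #   result = pipeline.run()
          #   result.wait_until_finish()  # Blocks; raises DataflowRuntimeException
          #                               # if the job terminates in a failed state.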
  1575  
  1576    def cancel(self):
  1577      if not self.has_job:
  1578        raise IOError('Failed to get the Dataflow job id.')
  1579  
  1580      self._update_job()
  1581  
  1582      if self.is_in_terminal_state():
  1583        _LOGGER.warning(
  1584            'Cancel failed because job %s is already terminated in state %s.',
  1585            self.job_id(),
  1586            self.state)
  1587      else:
  1588        if not self._runner.dataflow_client.modify_job_state(
  1589            self.job_id(), 'JOB_STATE_CANCELLED'):
  1590          cancel_failed_message = (
  1591              'Failed to cancel job %s, please go to the Developers Console to '
  1592              'cancel it manually.') % self.job_id()
  1593          _LOGGER.error(cancel_failed_message)
  1594          raise DataflowRuntimeException(cancel_failed_message, self)
  1595  
  1596      return self.state
  1597  
  1598    def __str__(self):
  1599      return '<%s %s %s>' % (self.__class__.__name__, self.job_id(), self.state)
  1600  
  1601    def __repr__(self):
  1602      return '<%s %s at %s>' % (self.__class__.__name__, self._job, hex(id(self)))
  1603  
  1604  
  1605  class DataflowRuntimeException(Exception):
  1606    """Indicates an error has occurred in running this pipeline."""
  1607    def __init__(self, msg, result):
  1608      super().__init__(msg)
  1609      self.result = result