github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/pipeline_fragment.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Module to build pipeline fragment that produces given PCollections.
    19  
    20  For internal use only; no backwards-compatibility guarantees.
    21  """
    22  import apache_beam as beam
    23  from apache_beam.pipeline import AppliedPTransform
    24  from apache_beam.pipeline import PipelineVisitor
    25  from apache_beam.runners.interactive import interactive_environment as ie
    26  from apache_beam.testing.test_stream import TestStream
    27  
    28  
    29  class PipelineFragment(object):
    30    """A fragment of a pipeline definition.
    31  
    32    A pipeline fragment is built from the original pipeline definition to include
    33    only PTransforms that are necessary to produce the given PCollections.
    34    """
    35    def __init__(self, pcolls, options=None):
    36      """Constructor of PipelineFragment.
    37  
    38      Args:
    39        pcolls: (List[PCollection]) a list of PCollections to build pipeline
    40            fragment for.
    41        options: (PipelineOptions) the pipeline options for the implicit
    42            pipeline run.
    43      """
    44      assert len(pcolls) > 0, (
    45          'Need at least 1 PCollection as the target data to build a pipeline '
    46          'fragment that produces it.')
    47      for pcoll in pcolls:
    48        assert isinstance(pcoll, beam.pvalue.PCollection), (
    49            '{} is not an apache_beam.pvalue.PCollection.'.format(pcoll))
    50      # No modification to self._user_pipeline is allowed.
    51      self._user_pipeline = pcolls[0].pipeline
    52      # These are user PCollections. Do not use them to deduce anything that
    53      # will be executed by any runner. Instead, use
    54      # `self._runner_pcolls_to_user_pcolls.keys()` to get copied PCollections.
    55      self._pcolls = set(pcolls)
    56      for pcoll in self._pcolls:
    57        assert pcoll.pipeline is self._user_pipeline, (
    58            '{} belongs to a different user pipeline than other PCollections '
    59            'given and cannot be used to build a pipeline fragment that produces '
    60            'the given PCollections.'.format(pcoll))
    61      self._options = options
    62  
    63      # A copied pipeline instance for modification without changing the user
    64      # pipeline instance held by the end user. This instance can be processed
    65      # into a pipeline fragment that later run by the underlying runner.
    66      self._runner_pipeline = self._build_runner_pipeline()
    67      _, self._context = self._runner_pipeline.to_runner_api(return_context=True)
    68      from apache_beam.runners.interactive import pipeline_instrument as instr
    69      self._runner_pcoll_to_id = instr.pcoll_to_pcoll_id(
    70          self._runner_pipeline, self._context)
    71      # Correlate components in the runner pipeline to components in the user
    72      # pipeline. The target pcolls are the pcolls given and defined in the user
    73      # pipeline.
    74      self._id_to_target_pcoll = self._calculate_target_pcoll_ids()
    75      self._label_to_user_transform = self._calculate_user_transform_labels()
    76      # Below will give us the 1:1 correlation between
    77      # PCollections/AppliedPTransforms from the copied runner pipeline and
    78      # PCollections/AppliedPTransforms from the user pipeline.
    79      # (Dict[PCollection, PCollection])
    80      (
    81          self._runner_pcolls_to_user_pcolls,
    82          # (Dict[AppliedPTransform, AppliedPTransform])
    83          self._runner_transforms_to_user_transforms
    84      ) = self._build_correlation_between_pipelines(
    85          self._runner_pcoll_to_id,
    86          self._id_to_target_pcoll,
    87          self._label_to_user_transform)
    88  
    89      # Below are operated on the runner pipeline.
    90      (self._necessary_transforms,
    91       self._necessary_pcollections) = self._mark_necessary_transforms_and_pcolls(
    92           self._runner_pcolls_to_user_pcolls)
    93      self._runner_pipeline = self._prune_runner_pipeline_to_fragment(
    94          self._runner_pipeline, self._necessary_transforms)
    95  
    96    def deduce_fragment(self):
    97      """Deduce the pipeline fragment as an apache_beam.Pipeline instance."""
    98      fragment = beam.pipeline.Pipeline.from_runner_api(
    99          self._runner_pipeline.to_runner_api(),
   100          self._runner_pipeline.runner,
   101          self._options)
   102      ie.current_env().add_derived_pipeline(self._runner_pipeline, fragment)
   103      return fragment
   104  
   105    def run(self, display_pipeline_graph=False, use_cache=True, blocking=False):
   106      """Shorthand to run the pipeline fragment."""
   107      from apache_beam.runners.interactive.interactive_runner import InteractiveRunner
   108      if not isinstance(self._runner_pipeline.runner, InteractiveRunner):
   109        raise RuntimeError(
   110            'Please specify InteractiveRunner when creating '
   111            'the Beam pipeline to use this function.')
   112      try:
   113        preserved_skip_display = self._runner_pipeline.runner._skip_display
   114        preserved_force_compute = self._runner_pipeline.runner._force_compute
   115        preserved_blocking = self._runner_pipeline.runner._blocking
   116        self._runner_pipeline.runner._skip_display = not display_pipeline_graph
   117        self._runner_pipeline.runner._force_compute = not use_cache
   118        self._runner_pipeline.runner._blocking = blocking
   119        return self.deduce_fragment().run()
   120      finally:
   121        self._runner_pipeline.runner._skip_display = preserved_skip_display
   122        self._runner_pipeline.runner._force_compute = preserved_force_compute
   123        self._runner_pipeline.runner._blocking = preserved_blocking
   124  
   125    def _build_runner_pipeline(self):
   126      runner_pipeline = beam.pipeline.Pipeline.from_runner_api(
   127          self._user_pipeline.to_runner_api(),
   128          self._user_pipeline.runner,
   129          self._options)
   130      ie.current_env().add_derived_pipeline(self._user_pipeline, runner_pipeline)
   131      return runner_pipeline
   132  
   133    def _calculate_target_pcoll_ids(self):
   134      pcoll_id_to_target_pcoll = {}
   135      for pcoll in self._pcolls:
   136        pcoll_id_to_target_pcoll[self._runner_pcoll_to_id.get(str(pcoll),
   137                                                              '')] = pcoll
   138      return pcoll_id_to_target_pcoll
   139  
   140    def _calculate_user_transform_labels(self):
   141      label_to_user_transform = {}
   142  
   143      class UserTransformVisitor(PipelineVisitor):
   144        def enter_composite_transform(self, transform_node):
   145          self.visit_transform(transform_node)
   146  
   147        def visit_transform(self, transform_node):
   148          if transform_node is not None:
   149            label_to_user_transform[transform_node.full_label] = transform_node
   150  
   151      v = UserTransformVisitor()
   152      self._runner_pipeline.visit(v)
   153      return label_to_user_transform
   154  
   155    def _build_correlation_between_pipelines(
   156        self, runner_pcoll_to_id, id_to_target_pcoll, label_to_user_transform):
   157      runner_pcolls_to_user_pcolls = {}
   158      runner_transforms_to_user_transforms = {}
   159  
   160      class CorrelationVisitor(PipelineVisitor):
   161        def enter_composite_transform(self, transform_node):
   162          self.visit_transform(transform_node)
   163  
   164        def visit_transform(self, transform_node):
   165          self._process_transform(transform_node)
   166          for in_pcoll in transform_node.inputs:
   167            self._process_pcoll(in_pcoll)
   168          for out_pcoll in transform_node.outputs.values():
   169            self._process_pcoll(out_pcoll)
   170  
   171        def _process_pcoll(self, pcoll):
   172          pcoll_id = runner_pcoll_to_id.get(str(pcoll), '')
   173          if pcoll_id in id_to_target_pcoll:
   174            runner_pcolls_to_user_pcolls[pcoll] = (id_to_target_pcoll[pcoll_id])
   175  
   176        def _process_transform(self, transform_node):
   177          if transform_node.full_label in label_to_user_transform:
   178            runner_transforms_to_user_transforms[transform_node] = (
   179                label_to_user_transform[transform_node.full_label])
   180  
   181      v = CorrelationVisitor()
   182      self._runner_pipeline.visit(v)
   183      return runner_pcolls_to_user_pcolls, runner_transforms_to_user_transforms
   184  
   185    def _mark_necessary_transforms_and_pcolls(self, runner_pcolls_to_user_pcolls):
   186      necessary_transforms = set()
   187      all_inputs = set()
   188      updated_all_inputs = set(runner_pcolls_to_user_pcolls.keys())
   189      # Do this until no more new PCollection is recorded.
   190      while len(updated_all_inputs) != len(all_inputs):
   191        all_inputs = set(updated_all_inputs)
   192        for pcoll in all_inputs:
   193          producer = pcoll.producer
   194          while producer:
   195            if producer in necessary_transforms:
   196              break
   197            # Mark the AppliedPTransform as necessary.
   198            necessary_transforms.add(producer)
   199  
   200            # Also mark composites that are not the root transform. If the root
   201            # transform is added, then all transforms are incorrectly marked as
   202            # necessary. If composites are not handled, then there will be
   203            # orphaned PCollections.
   204            if producer.parent is not None:
   205              necessary_transforms.update(producer.parts)
   206  
   207              # This will recursively add all the PCollections in this composite.
   208              for part in producer.parts:
   209                updated_all_inputs.update(part.outputs.values())
   210  
   211            # Record all necessary input and side input PCollections.
   212            updated_all_inputs.update(producer.inputs)
   213            # pylint: disable=bad-option-value
   214            side_input_pvalues = set(
   215                map(lambda side_input: side_input.pvalue, producer.side_inputs))
   216            updated_all_inputs.update(side_input_pvalues)
   217            # Go to its parent AppliedPTransform.
   218            producer = producer.parent
   219      return necessary_transforms, all_inputs
   220  
   221    def _prune_runner_pipeline_to_fragment(
   222        self, runner_pipeline, necessary_transforms):
   223      class PruneVisitor(PipelineVisitor):
   224        def enter_composite_transform(self, transform_node):
   225          if should_skip_pruning(transform_node):
   226            return
   227  
   228          pruned_parts = list(transform_node.parts)
   229          for part in transform_node.parts:
   230            if part not in necessary_transforms:
   231              pruned_parts.remove(part)
   232          transform_node.parts = tuple(pruned_parts)
   233          self.visit_transform(transform_node)
   234  
   235        def visit_transform(self, transform_node):
   236          if transform_node not in necessary_transforms:
   237            transform_node.parent = None
   238  
   239      v = PruneVisitor()
   240      runner_pipeline.visit(v)
   241      return runner_pipeline
   242  
   243  
   244  def should_skip_pruning(transform: AppliedPTransform):
   245    return (
   246        isinstance(transform.transform, TestStream) or
   247        '_DataFrame_' in transform.full_label)