github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/pipeline_instrument.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

    18  """Module to instrument interactivity to the given pipeline.
    19  
    20  For internal use only; no backwards-compatibility guarantees.
    21  This module accesses current interactive environment and analyzes given pipeline
    22  to transform original pipeline into a one-shot pipeline with interactivity.
    23  """
    24  # pytype: skip-file
    25  
import logging
from typing import Dict

import apache_beam as beam
from apache_beam.pipeline import PipelineVisitor
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive import pipeline_fragment as pf
from apache_beam.runners.interactive import background_caching_job
from apache_beam.runners.interactive import utils
from apache_beam.runners.interactive.caching.cacheable import Cacheable
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.runners.interactive.caching.reify import WRITE_CACHE
from apache_beam.runners.interactive.caching.reify import reify_to_cache
from apache_beam.runners.interactive.caching.reify import unreify_from_cache
from apache_beam.testing import test_stream

_LOGGER = logging.getLogger(__name__)


class PipelineInstrument(object):
  """A pipeline instrument for a pipeline run by the interactive runner.

  This module should never depend on the underlying runner that the
  interactive runner delegates to. It instruments the original pipeline
  instance directly by appending or replacing transforms with the help of the
  cache. It provides interfaces to recover the states of the original
  pipeline. It's the interactive runner's responsibility to coordinate
  supported underlying runners to run the instrumented pipeline and recover
  the original pipeline states if needed.
  """
  def __init__(self, pipeline, options=None):
    self._pipeline = pipeline

    self._user_pipeline = ie.current_env().user_pipeline(pipeline)
    if not self._user_pipeline:
      self._user_pipeline = pipeline
    self._cache_manager = ie.current_env().get_cache_manager(
        self._user_pipeline, create_if_absent=True)
    # Check if the user-defined pipeline contains any source to cache.
    # If so, during the check, the cache manager is converted into a
    # streaming cache manager, so re-assign it.
    if background_caching_job.has_source_to_cache(self._user_pipeline):
      self._cache_manager = ie.current_env().get_cache_manager(
          self._user_pipeline)

    self._background_caching_pipeline = beam.pipeline.Pipeline.from_runner_api(
        pipeline.to_runner_api(), pipeline.runner, options)
    ie.current_env().add_derived_pipeline(
        self._pipeline, self._background_caching_pipeline)

    # Snapshot of the original pipeline information.
    (self._original_pipeline_proto,
     context) = self._pipeline.to_runner_api(return_context=True)

    # All compute-once-against-original-pipeline fields.
    self._unbounded_sources = utils.unbounded_sources(
        self._background_caching_pipeline)
    self._pcoll_to_pcoll_id = pcoll_to_pcoll_id(self._pipeline, context)

    # A Dict[str, Cacheable] from a PCollection id to a Cacheable that belongs
    # to the analyzed pipeline.
    self._cacheables = self.find_cacheables()

    # A Dict[str, PCollection] from a cache key to the PCollection read from
    # cache. If an entry exists, the caller should reuse that PCollection. If
    # not, the caller should create a new read transform and track the
    # PCollection read from cache.
    self._cached_pcoll_read = {}

    # A dict from PCollections in the runner pipeline instance to their
    # corresponding PCollections in the user pipeline instance. Populated
    # after preprocess().
    self._runner_pcoll_to_user_pcoll = {}
    self._pruned_pipeline_proto = None

    # Refers to the target pcolls output by the instrumented write-cache
    # transforms; used by the pruning logic as supplemental targets to build
    # the pipeline fragment up from.
    self._extended_targets = set()

    # Refers to pcolls used as inputs but replaced by the outputs of the
    # instrumented read-cache transforms; used by the pruning logic as targets
    # that no longer need to be produced during pipeline runs.
    self._ignored_targets = set()

    # Set of PCollections that are written to cache.
    self.cached_pcolls = set()

  def instrumented_pipeline_proto(self):
    """Always returns a new instance of the portable instrumented proto."""
    targets = set(self._runner_pcoll_to_user_pcoll.keys())
    targets.update(self._extended_targets)
    targets = targets.difference(self._ignored_targets)
    if len(targets) > 0:
      # Prunes upstream transforms that don't contribute to the targets the
      # instrumented pipeline run cares about.
      return pf.PipelineFragment(
          list(targets)).deduce_fragment().to_runner_api()
    return self._pipeline.to_runner_api()

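  # Illustrative usage (not part of the original source; `runner_pipeline` is
  # an assumed name): a caller that built this instrument for a runner
  # pipeline could fetch the pruned proto like so:
  #
  #   instrument = build_pipeline_instrument(runner_pipeline)
  #   proto = instrument.instrumented_pipeline_proto()
  #   # `proto` is a beam_runner_api_pb2.Pipeline limited to transforms that
  #   # contribute to the cacheable and extended targets, with targets already
  #   # replaced by cache reads excluded.
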
  def _required_components(
      self,
      pipeline_proto,
      required_transforms_ids,
      visited,
      follow_outputs=False,
      follow_inputs=False):
    """Returns the components and subcomponents of the given transforms.

    This method returns required components such as transforms and PCollections
    related to the given transforms and to all of their subtransforms. This
    method accomplishes this recursively.
    """
    if not required_transforms_ids:
      return ({}, {})

    transforms = pipeline_proto.components.transforms
    pcollections = pipeline_proto.components.pcollections

    # Cache the transforms that will be copied into the new pipeline proto.
    required_transforms = {k: transforms[k] for k in required_transforms_ids}

    # Cache all the output PCollections of the transforms.
    pcollection_ids = [
        pc for t in required_transforms.values() for pc in t.outputs.values()
    ]
    required_pcollections = {
        pc_id: pcollections[pc_id]
        for pc_id in pcollection_ids
    }

    subtransforms = {}
    subpcollections = {}

    # Recursively go through all the subtransforms and add their components.
    for transform_id, transform in required_transforms.items():
      if transform_id in pipeline_proto.root_transform_ids:
        continue
      (t, pc) = self._required_components(
          pipeline_proto,
          transform.subtransforms,
          visited,
          follow_outputs=False,
          follow_inputs=False)
      subtransforms.update(t)
      subpcollections.update(pc)

    if follow_outputs:
      outputs = [
          pc_id for t in required_transforms.values()
          for pc_id in t.outputs.values()
      ]
      visited_copy = visited.copy()
      consuming_transforms = {
          t_id: t
          for t_id, t in transforms.items()
          if set(outputs).intersection(set(t.inputs.values()))
      }
      consuming_transforms = set(consuming_transforms.keys())
      visited.update(consuming_transforms)
      consuming_transforms = consuming_transforms - visited_copy
      (t, pc) = self._required_components(
          pipeline_proto,
          list(consuming_transforms),
          visited,
          follow_outputs,
          follow_inputs)
      subtransforms.update(t)
      subpcollections.update(pc)

    if follow_inputs:
      inputs = [
          pc_id for t in required_transforms.values()
          for pc_id in t.inputs.values()
      ]
      producing_transforms = {
          t_id: t
          for t_id, t in transforms.items()
          if set(inputs).intersection(set(t.outputs.values()))
      }
      (t, pc) = self._required_components(
          pipeline_proto,
          list(producing_transforms.keys()),
          visited,
          follow_outputs,
          follow_inputs)
      subtransforms.update(t)
      subpcollections.update(pc)

    # Now that we have all the components and their subcomponents, return the
    # complete collection.
    required_transforms.update(subtransforms)
    required_pcollections.update(subpcollections)

    return (required_transforms, required_pcollections)

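  # Illustrative note (not part of the original source): with follow_outputs
  # and follow_inputs enabled, the recursion closes over the graph in both
  # directions. For a hypothetical linear pipeline Read -> Square -> Write,
  # starting from Square's transform id:
  #
  #   t, p = self._required_components(
  #       proto, [square_id], set(), follow_outputs=True, follow_inputs=True)
  #   # t also picks up Write (it consumes Square's output) and Read (it
  #   # produces Square's input); p holds their PCollections. The `visited`
  #   # set keeps the output-following branch from re-expanding transforms
  #   # that were already scheduled.
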
  def prune_subgraph_for(self, pipeline, required_transform_ids):
    # Create the pipeline_proto to read all the components from. It will later
    # create a new pipeline proto from the cut out components.
    pipeline_proto, context = pipeline.to_runner_api(return_context=True)

    # Get all the root transforms. The caching transforms will be subtransforms
    # of one of these roots.
    roots = [root for root in pipeline_proto.root_transform_ids]

    (t, p) = self._required_components(
        pipeline_proto,
        roots + required_transform_ids,
        set(),
        follow_outputs=True,
        follow_inputs=True)

    def set_proto_map(proto_map, new_value):
      proto_map.clear()
      for key, value in new_value.items():
        proto_map[key].CopyFrom(value)

    # Copy the transforms into the new pipeline.
    pipeline_to_execute = beam_runner_api_pb2.Pipeline()
    pipeline_to_execute.root_transform_ids[:] = roots
    set_proto_map(pipeline_to_execute.components.transforms, t)
    set_proto_map(pipeline_to_execute.components.pcollections, p)
    set_proto_map(
        pipeline_to_execute.components.coders, context.to_runner_api().coders)
    set_proto_map(
        pipeline_to_execute.components.windowing_strategies,
        context.to_runner_api().windowing_strategies)

    # Cut out all subtransforms in the root that aren't the required transforms.
    for root_id in roots:
      root = pipeline_to_execute.components.transforms[root_id]
      root.subtransforms[:] = [
          transform_id for transform_id in root.subtransforms
          if transform_id in pipeline_to_execute.components.transforms
      ]

    return pipeline_to_execute

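  # Side note (not part of the original source): the nested set_proto_map
  # helper above exists because message-typed protobuf map fields reject
  # direct item assignment; entries must be merged in with CopyFrom:
  #
  #   components = beam_runner_api_pb2.Pipeline().components
  #   components.transforms[t_id] = transform_proto       # not allowed
  #   components.transforms[t_id].CopyFrom(transform_proto)  # works
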
  def background_caching_pipeline_proto(self):
    """Returns the background caching pipeline.

    This method creates a background caching pipeline by: adding writes to
    cache from each unbounded source (done in the instrument method), and
    cutting out all components (transforms, PCollections, coders, windowing
    strategies) that are not the unbounded sources or writes to cache (or
    subtransforms thereof).
    """
    # Create the pipeline_proto to read all the components from. It will later
    # create a new pipeline proto from the cut out components.
    pipeline_proto, context = self._background_caching_pipeline.to_runner_api(
        return_context=True)

    # Get all the sources we want to cache.
    sources = utils.unbounded_sources(self._background_caching_pipeline)

    # Get all the root transforms. The caching transforms will be subtransforms
    # of one of these roots.
    roots = [root for root in pipeline_proto.root_transform_ids]

    # Get the transform IDs of the caching transforms. These caching operations
    # are added to the _background_caching_pipeline in the instrument() method.
    # It's added there so that multiple calls to this method won't add multiple
    # caching operations (idempotent).
    transforms = pipeline_proto.components.transforms
    caching_transform_ids = [
        t_id for root in roots for t_id in transforms[root].subtransforms
        if WRITE_CACHE in t_id
    ]

    # Get the IDs of the unbounded sources.
    required_transform_labels = [src.full_label for src in sources]
    unbounded_source_ids = [
        k for k, v in transforms.items()
        if v.unique_name in required_transform_labels
    ]

    # The required transforms are the transforms that we want to cut out of
    # the pipeline_proto and insert into a new pipeline to return.
    required_transform_ids = (
        roots + caching_transform_ids + unbounded_source_ids)
    (t, p) = self._required_components(
        pipeline_proto, required_transform_ids, set())

    def set_proto_map(proto_map, new_value):
      proto_map.clear()
      for key, value in new_value.items():
        proto_map[key].CopyFrom(value)

    # Copy the transforms into the new pipeline.
    pipeline_to_execute = beam_runner_api_pb2.Pipeline()
    pipeline_to_execute.root_transform_ids[:] = roots
    set_proto_map(pipeline_to_execute.components.transforms, t)
    set_proto_map(pipeline_to_execute.components.pcollections, p)
    set_proto_map(
        pipeline_to_execute.components.coders, context.to_runner_api().coders)
    set_proto_map(
        pipeline_to_execute.components.windowing_strategies,
        context.to_runner_api().windowing_strategies)

    # Cut out all subtransforms in the root that aren't the required transforms.
    for root_id in roots:
      root = pipeline_to_execute.components.transforms[root_id]
      root.subtransforms[:] = [
          transform_id for transform_id in root.subtransforms
          if transform_id in pipeline_to_execute.components.transforms
      ]

    return pipeline_to_execute

  @property
  def cacheables(self) -> Dict[str, Cacheable]:
    """Returns the Cacheables by PCollection ids.

    If you're already working with user defined pipelines and PCollections,
    do not build a PipelineInstrument just to get the cacheables. Instead,
    use apache_beam.runners.interactive.utils.cacheables.
    """
    return self._cacheables

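  # Illustrative alternative (not part of the original source): when you hold
  # the user-defined pipeline already, the module-level helper mentioned in
  # the docstring above can be used directly; `user_pipeline` is assumed:
  #
  #   from apache_beam.runners.interactive import utils
  #   cacheables = utils.cacheables()  # all watched Cacheables
  #   mine = {k: c for k, c in cacheables.items()
  #           if c.pcoll.pipeline is user_pipeline}
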
  @property
  def has_unbounded_sources(self):
    """Returns whether the pipeline has any recordable sources."""
    return len(self._unbounded_sources) > 0

  @property
  def original_pipeline_proto(self):
    """Returns a snapshot of the pipeline proto before instrumentation."""
    return self._original_pipeline_proto

  @property
  def user_pipeline(self):
    """Returns a reference to the pipeline instance defined by the user. If a
    pipeline has no cacheable PCollection and the user pipeline cannot be
    found, returns None, indicating there is nothing to be cached in the user
    pipeline.

    The pipeline given for instrumenting and mutated in this class is not
    necessarily the pipeline instance defined by the user. From the watched
    scopes, this class figures out what the user pipeline instance is.
    This metadata can be used for tracking pipeline results.
    """
    return self._user_pipeline

  @property
  def runner_pcoll_to_user_pcoll(self):
    """Returns cacheable PCollections correlated from instances in the runner
    pipeline to instances in the user pipeline."""
    return self._runner_pcoll_to_user_pcoll

  def find_cacheables(self) -> Dict[str, Cacheable]:
    """Finds PCollections that need to be cached for the analyzed pipeline.

    There might be multiple pipelines defined and watched; this only finds
    cacheables belonging to the analyzed pipeline.
    """
    result = {}
    cacheables = utils.cacheables()
    for _, cacheable in cacheables.items():
      if cacheable.pcoll.pipeline is not self._user_pipeline:
        # Ignore all cacheables from other pipelines.
        continue
      pcoll_id = self.pcoll_id(cacheable.pcoll)
      if not pcoll_id:
        _LOGGER.debug(
            'Unable to retrieve PCollection id for %s. Ignored.',
            cacheable.pcoll)
        continue
      result[pcoll_id] = cacheable
    return result

  def instrument(self):
    """Instruments the original pipeline with cache.

    For a cacheable output PCollection, if the cache for the key doesn't
    exist, do _write_cache(); for a cacheable input PCollection, if the cache
    for the key exists, do _read_cache(). No instrumentation happens in any
    other situation.

    Modifies:
      self._pipeline
    """
    cacheable_inputs = set()
    all_inputs = set()
    all_outputs = set()
    unbounded_source_pcolls = set()

    class InstrumentVisitor(PipelineVisitor):
      """Visitor that utilizes the cache to instrument the pipeline."""
      def __init__(self, pin):
        self._pin = pin

      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        if isinstance(transform_node.transform,
                      tuple(ie.current_env().options.recordable_sources)):
          unbounded_source_pcolls.update(transform_node.outputs.values())
        cacheable_inputs.update(self._pin._cacheable_inputs(transform_node))
        ins, outs = self._pin._all_inputs_outputs(transform_node)
        all_inputs.update(ins)
        all_outputs.update(outs)

    v = InstrumentVisitor(self)
    self._pipeline.visit(v)
    # Every output PCollection that is never used as an input PCollection is
    # considered a side effect of the pipeline run and should be included.
    self._extended_targets.update(all_outputs.difference(all_inputs))
    # Add the unbounded source PCollections to the cacheable inputs. This
    # allows for the caching of unbounded sources without a variable reference.
    cacheable_inputs.update(unbounded_source_pcolls)

    # Create ReadCache transforms.
    for cacheable_input in cacheable_inputs:
      self._read_cache(
          self._pipeline,
          cacheable_input,
          cacheable_input in unbounded_source_pcolls)
    # Replace/wire inputs w/ cached PCollections from ReadCache transforms.
    self._replace_with_cached_inputs(self._pipeline)

    # Write cache for all cacheables.
    for _, cacheable in self._cacheables.items():
      self._write_cache(
          self._pipeline, cacheable.pcoll, ignore_unbounded_reads=True)

    # Instrument the background caching pipeline if we can.
    if self.has_unbounded_sources:
      for source in self._unbounded_sources:
        self._write_cache(
            self._background_caching_pipeline,
            source.outputs[None],
            output_as_extended_target=False,
            is_capture=True)

      class TestStreamVisitor(PipelineVisitor):
        def __init__(self):
          self.test_stream = None

        def enter_composite_transform(self, transform_node):
          self.visit_transform(transform_node)

        def visit_transform(self, transform_node):
          if (self.test_stream is None and
              isinstance(transform_node.transform, test_stream.TestStream)):
            self.test_stream = transform_node.full_label

      v = TestStreamVisitor()
      self._pipeline.visit(v)
      pipeline_proto = self._pipeline.to_runner_api(return_context=False)
      test_stream_id = ''
      for t_id, t in pipeline_proto.components.transforms.items():
        if t.unique_name == v.test_stream:
          test_stream_id = t_id
          break
      self._pruned_pipeline_proto = self.prune_subgraph_for(
          self._pipeline, [test_stream_id])
      pruned_pipeline = beam.Pipeline.from_runner_api(
          proto=self._pruned_pipeline_proto,
          runner=self._pipeline.runner,
          options=self._pipeline._options)
      ie.current_env().add_derived_pipeline(self._pipeline, pruned_pipeline)
      self._pipeline = pruned_pipeline

  def preprocess(self):
    """Pre-processes the pipeline.

    Since the pipeline instance in the class might not be the same instance
    defined in the user code, the pre-process will figure out the relationship
    of cacheable PCollections between these 2 instances by replacing 'pcoll'
    fields in the cacheable dictionary with ones from the running instance.
    """
    class PreprocessVisitor(PipelineVisitor):
      def __init__(self, pin):
        self._pin = pin

      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        for in_pcoll in transform_node.inputs:
          self._process(in_pcoll)
        for out_pcoll in transform_node.outputs.values():
          self._process(out_pcoll)

      def _process(self, pcoll):
        pcoll_id = self._pin._pcoll_to_pcoll_id.get(str(pcoll), '')
        if pcoll_id in self._pin._cacheables:
          pcoll_id = self._pin.pcoll_id(pcoll)
          user_pcoll = self._pin._cacheables[pcoll_id].pcoll
          if (pcoll_id in self._pin._cacheables and user_pcoll != pcoll):
            self._pin._runner_pcoll_to_user_pcoll[pcoll] = user_pcoll
            self._pin._cacheables[pcoll_id].pcoll = pcoll

    v = PreprocessVisitor(self)
    self._pipeline.visit(v)

  def _write_cache(
      self,
      pipeline,
      pcoll,
      output_as_extended_target=True,
      ignore_unbounded_reads=False,
      is_capture=False):
    """Caches a cacheable PCollection.

    For the given PCollection, appends a sub-transform that materializes the
    PCollection through a sink into the cache implementation. The cache write
    is not immediate: it happens when the runner runs the transformed
    pipeline, so the cache is not usable within the same run. This function
    always writes the cache for the given PCollection as long as the
    PCollection belongs to the pipeline being instrumented and the keyed cache
    is absent.

    Modifies:
      pipeline
    """
    # Makes sure the pcoll belongs to the pipeline being instrumented.
    if pcoll.pipeline is not pipeline:
      return

    # Ignore the unbounded reads from recordable sources as these will be
    # pruned out using the PipelineFragment later on.
    if ignore_unbounded_reads:
      ignore = False
      producer = pcoll.producer
      while producer:
        if isinstance(producer.transform,
                      tuple(ie.current_env().options.recordable_sources)):
          ignore = True
          break
        producer = producer.parent
      if ignore:
        self._ignored_targets.add(pcoll)
        return

    # The keyed cache is always valid within this instrumentation.
    key = self.cache_key(pcoll)
    # Only need to write when the cache with the expected key doesn't exist.
    if not self._cache_manager.exists('full', key):
      self.cached_pcolls.add(self.runner_pcoll_to_user_pcoll.get(pcoll, pcoll))
      # Read the windowing information and cache it along with the element.
      # This caches the arguments to a WindowedValue object because Python has
      # logic that detects if a DoFn returns a WindowedValue. When it detects
      # one, it puts the element into the correct window and then emits the
      # value to downstream transforms.
      extended_target = reify_to_cache(
          pcoll=pcoll,
          cache_key=key,
          cache_manager=self._cache_manager,
          is_capture=is_capture)
      if output_as_extended_target:
        self._extended_targets.add(extended_target)

  def _read_cache(self, pipeline, pcoll, is_unbounded_source_output):
    """Reads a cached pvalue.

    When this is a no-op, the pipeline executes the transform as it is and
    caches nothing from the transform for the next run.

    Modifies:
      pipeline
    """
    # Makes sure the pcoll belongs to the pipeline being instrumented.
    if pcoll.pipeline is not pipeline:
      return
    # The keyed cache is always valid within this instrumentation.
    key = self.cache_key(pcoll)
    # Can only read from cache when the cache with the expected key exists and
    # its computation has been completed.
    is_cached = self._cache_manager.exists('full', key)
    is_computed = (
        pcoll in self._runner_pcoll_to_user_pcoll and
        self._runner_pcoll_to_user_pcoll[pcoll] in
        ie.current_env().computed_pcollections)
    if ((is_cached and is_computed) or is_unbounded_source_output):
      if key not in self._cached_pcoll_read:
        # Mutates the pipeline with a cache read transform attached
        # to the root of the pipeline.

        # To put the cached value into the correct window, simply return a
        # WindowedValue constructed from the element.
        pcoll_from_cache = unreify_from_cache(
            pipeline=pipeline, cache_key=key, cache_manager=self._cache_manager)
        self._cached_pcoll_read[key] = pcoll_from_cache
    # else: NOOP when the cache doesn't exist; just compute the original graph.

  def _replace_with_cached_inputs(self, pipeline):
    """Replaces PCollection inputs in the pipeline with cache if possible.

    For any input PCollection, find out whether there is a valid cache. If so,
    replace the input of the AppliedPTransform with the output of the
    AppliedPTransform that sources the pvalue from the cache. If there is no
    valid cache, noop.
    """

    # Find all cached unbounded PCollections.

    # If the pipeline has unbounded sources, then we want to force all cache
    # reads to go through the TestStream (even if they are bounded sources).
    if self.has_unbounded_sources:

      class CacheableUnboundedPCollectionVisitor(PipelineVisitor):
        def __init__(self, pin):
          self._pin = pin
          self.unbounded_pcolls = set()

        def enter_composite_transform(self, transform_node):
          self.visit_transform(transform_node)

        def visit_transform(self, transform_node):
          if transform_node.outputs:
            for output_pcoll in transform_node.outputs.values():
              key = self._pin.cache_key(output_pcoll)
              if key in self._pin._cached_pcoll_read:
                self.unbounded_pcolls.add(key)

          if transform_node.inputs:
            for input_pcoll in transform_node.inputs:
              key = self._pin.cache_key(input_pcoll)
              if key in self._pin._cached_pcoll_read:
                self.unbounded_pcolls.add(key)

      v = CacheableUnboundedPCollectionVisitor(self)
      pipeline.visit(v)

      # The set of keys from the cached unbounded PCollections will be used as
      # the output tags for the TestStream. This is to remember which cache key
      # is associated with which PCollection.
      output_tags = v.unbounded_pcolls

      # Take the PCollections that will be read from the TestStream and insert
      # them back into the dictionary of cached PCollections. The next step
      # will replace the downstream consumers of the non-cached PCollections
      # with these PCollections.
      if output_tags:
        output_pcolls = pipeline | test_stream.TestStream(
            output_tags=output_tags, coder=self._cache_manager._default_pcoder)
        for tag, pcoll in output_pcolls.items():
          self._cached_pcoll_read[tag] = pcoll

    class ReadCacheWireVisitor(PipelineVisitor):
      """Visitor that wires cache reads as inputs to replace the corresponding
      original input PCollections in the pipeline.
      """
      def __init__(self, pin):
        """Initializes with a PipelineInstrument."""
        self._pin = pin

      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        if transform_node.inputs:
          main_inputs = dict(transform_node.main_inputs)
          for tag, input_pcoll in main_inputs.items():
            key = self._pin.cache_key(input_pcoll)

            # Replace the input pcollection with the cached pcollection (if it
            # has been cached).
            if key in self._pin._cached_pcoll_read:
              # Ignore this pcoll in the final pruned instrumented pipeline.
              self._pin._ignored_targets.add(input_pcoll)
              main_inputs[tag] = self._pin._cached_pcoll_read[key]
          # Update the transform with its new inputs.
          transform_node.main_inputs = main_inputs

    v = ReadCacheWireVisitor(self)
    pipeline.visit(v)

  def _cacheable_inputs(self, transform):
    inputs = set()
    for in_pcoll in transform.inputs:
      if self.pcoll_id(in_pcoll) in self._cacheables:
        inputs.add(in_pcoll)
    return inputs

  def _all_inputs_outputs(self, transform):
    inputs = set()
    outputs = set()
    for in_pcoll in transform.inputs:
      inputs.add(in_pcoll)
    for _, out_pcoll in transform.outputs.items():
      outputs.add(out_pcoll)
    return inputs, outputs

  def pcoll_id(self, pcoll):
    """Gets the PCollection id of the given pcoll.

    Returns '' if not found.
    """
    return self._pcoll_to_pcoll_id.get(str(pcoll), '')

  def cache_key(self, pcoll):
    """Gets the identifier of a cacheable PCollection in cache.

    If the pcoll is not cacheable, returns ''.
    This is only needed in the pipeline instrument when the origin of the given
    pcoll is unknown (whether it's from the user pipeline or a runner
    pipeline). If a pcoll is from the user pipeline, always use
    CacheKey.from_pcoll to build the key.
    The key is what the pcoll would use as its identifier if it's materialized
    in cache. It doesn't mean that such a cache necessarily exists already.
    Also, the pcoll can come from the original user-defined pipeline object or
    an equivalent pcoll from a transformed copy of the original pipeline.
    """
    cacheable = self._cacheables.get(self.pcoll_id(pcoll), None)
    if cacheable:
      if cacheable.pcoll in self.runner_pcoll_to_user_pcoll:
        user_pcoll = self.runner_pcoll_to_user_pcoll[cacheable.pcoll]
      else:
        user_pcoll = cacheable.pcoll
      return CacheKey.from_pcoll(cacheable.var, user_pcoll).to_str()
    return ''

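
# Illustrative sketch (not part of the original source): for a pcoll known to
# come from the user pipeline, build the key directly as the cache_key
# docstring recommends. `squares` is an assumed watched variable name:
#
#   from apache_beam.runners.interactive.caching.cacheable import CacheKey
#   key = CacheKey.from_pcoll('squares', squares).to_str()
#   # This matches what cache_key() derives once a runner pcoll has been
#   # correlated back to its user pcoll.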

def build_pipeline_instrument(pipeline, options=None):
  """Creates a PipelineInstrument for a pipeline and its options with cache.

  Throughout the process, the returned PipelineInstrument snapshots the given
  pipeline and then mutates it. It's invoked by interactive components such as
  the InteractiveRunner, and the given pipeline should be an implicitly
  created runner pipeline rather than a pipeline instance defined by the user.

  This is shorthand for doing 3 steps: 1) compute once the metadata of the
  given runner pipeline and everything watched from user pipelines; 2)
  associate info between the runner pipeline and its corresponding user
  pipeline, eliminating data from other user pipelines if there are any; 3)
  mutate the runner pipeline to apply interactivity.
  """
  pi = PipelineInstrument(pipeline, options)
  pi.preprocess()
  pi.instrument()  # Instruments the pipeline only once.
  return pi

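# Illustrative usage (not part of the original source): a minimal sketch of
# how an interactive runner might drive the shorthand. `user_pipeline` and
# `options` are assumed to exist, with `user_pipeline` already watched:
#
#   runner_pipeline = beam.pipeline.Pipeline.from_runner_api(
#       user_pipeline.to_runner_api(), user_pipeline.runner, options)
#   instrument = build_pipeline_instrument(runner_pipeline, options)
#   proto = instrument.instrumented_pipeline_proto()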

def pcoll_to_pcoll_id(pipeline, original_context):
  """Returns a dict mapping PCollection strings to PCollection IDs.

  Using a PipelineVisitor to iterate over every node in the pipeline,
  records the mapping from PCollections to PCollection IDs. This mapping
  will be used to query cached PCollections.

  Returns:
    (dict from str to str) a dict mapping str(pcoll) to pcoll_id.
  """
  class PCollVisitor(PipelineVisitor):
    """A visitor that records input and output values to be replaced.

    Input and output values that should be updated are recorded in the maps
    input_replacements and output_replacements respectively.

    We cannot update input and output values while visiting, since that
    results in validation errors.
    """
    def __init__(self):
      self.pcoll_to_pcoll_id = {}

    def enter_composite_transform(self, transform_node):
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      for pcoll in transform_node.outputs.values():
        self.pcoll_to_pcoll_id[str(pcoll)] = (
            original_context.pcollections.get_id(pcoll))

  v = PCollVisitor()
  pipeline.visit(v)
  return v.pcoll_to_pcoll_id
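
# Illustrative usage (not part of the original source): pair the pipeline with
# the PipelineContext returned by its own to_runner_api() call, since the ids
# only exist relative to that context:
#
#   proto, context = pipeline.to_runner_api(return_context=True)
#   mapping = pcoll_to_pcoll_id(pipeline, context)
#   # mapping[str(pcoll)] is pcoll's id within proto.components.pcollections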