github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/augmented_pipeline.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Module to augment interactive flavor into the given pipeline.
    19  
    20  For internal use only; no backward-compatibility guarantees.
    21  """
    22  # pytype: skip-file
    23  
    24  import copy
    25  from typing import Dict
    26  from typing import Optional
    27  from typing import Set
    28  
    29  import apache_beam as beam
    30  from apache_beam.portability.api import beam_runner_api_pb2
    31  from apache_beam.runners.interactive import interactive_environment as ie
    32  from apache_beam.runners.interactive import background_caching_job
    33  from apache_beam.runners.interactive.caching.cacheable import Cacheable
    34  from apache_beam.runners.interactive.caching.read_cache import ReadCache
    35  from apache_beam.runners.interactive.caching.write_cache import WriteCache
    36  
    37  
    38  class AugmentedPipeline:
    39    """A pipeline with augmented interactive flavor that caches intermediate
    40    PCollections defined by the user, reads computed PCollections as source and
    41    prunes unnecessary pipeline parts for fast computation.
    42    """
    43    def __init__(
    44        self,
    45        user_pipeline: beam.Pipeline,
    46        pcolls: Optional[Set[beam.pvalue.PCollection]] = None):
    47      """
    48      Initializes a pipelilne for augmenting interactive flavor.
    49  
    50      Args:
    51        user_pipeline: a beam.Pipeline instance defined by the user.
    52        pcolls: cacheable pcolls to be computed/retrieved. If the set is
    53          empty, all intermediate pcolls assigned to variables are applicable.
    54      """
    55      assert not pcolls or all(pcoll.pipeline is user_pipeline for pcoll in
    56        pcolls), 'All %s need to belong to %s' % (pcolls, user_pipeline)
    57      self._user_pipeline = user_pipeline
    58      self._pcolls = pcolls
    59      self._cache_manager = ie.current_env().get_cache_manager(
    60          self._user_pipeline, create_if_absent=True)
    61      if background_caching_job.has_source_to_cache(self._user_pipeline):
    62        self._cache_manager = ie.current_env().get_cache_manager(
    63            self._user_pipeline)
    64      _, self._context = self._user_pipeline.to_runner_api(return_context=True)
    65      self._context.component_id_map = copy.copy(
    66          self._user_pipeline.component_id_map)
    67      self._cacheables = self.cacheables()
    68  
    69    @property
    70    def augmented_pipeline(self) -> beam_runner_api_pb2.Pipeline:
    71      return self.augment()
    72  
    73    # TODO(https://github.com/apache/beam/issues/20526): Support generating a
    74    # background recording job that contains unbound source recording transforms
    75    # only.
    76    @property
    77    def background_recording_pipeline(self) -> beam_runner_api_pb2.Pipeline:
    78      raise NotImplementedError
    79  
    80    def cacheables(self) -> Dict[beam.pvalue.PCollection, Cacheable]:
    81      """Finds all the cacheable intermediate PCollections in the pipeline with
    82      their metadata.
    83      """
    84      c = {}
    85      for watching in ie.current_env().watching():
    86        for key, val in watching:
    87          if (isinstance(val, beam.pvalue.PCollection) and
    88              val.pipeline is self._user_pipeline and
    89              (not self._pcolls or val in self._pcolls)):
    90            c[val] = Cacheable(
    91                var=key,
    92                pcoll=val,
    93                version=str(id(val)),
    94                producer_version=str(id(val.producer)))
    95      return c
    96  
    97    def augment(self) -> beam_runner_api_pb2.Pipeline:
    98      """Augments the pipeline with cache. Always calculates a new result.
    99  
   100      For a cacheable PCollection, if cache exists, read cache; else, write cache.
   101      """
   102      pipeline = self._user_pipeline.to_runner_api()
   103  
   104      # Find pcolls eligible for reading or writing cache.
   105      readcache_pcolls = set()
   106      for pcoll, cacheable in self._cacheables.items():
   107        key = repr(cacheable.to_key())
   108        if (self._cache_manager.exists('full', key) and
   109            pcoll in ie.current_env().computed_pcollections):
   110          readcache_pcolls.add(pcoll)
   111      writecache_pcolls = set(
   112          self._cacheables.keys()).difference(readcache_pcolls)
   113  
   114      # Wire in additional transforms to read cache and write cache.
   115      for readcache_pcoll in readcache_pcolls:
   116        ReadCache(
   117            pipeline,
   118            self._context,
   119            self._cache_manager,
   120            self._cacheables[readcache_pcoll]).read_cache()
   121      for writecache_pcoll in writecache_pcolls:
   122        WriteCache(
   123            pipeline,
   124            self._context,
   125            self._cache_manager,
   126            self._cacheables[writecache_pcoll]).write_cache()
   127      # TODO(https://github.com/apache/beam/issues/20526): Support streaming, add
   128      # pruning logic, and integrate pipeline fragment logic.
   129      return pipeline