github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/caching/write_cache.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Module to write cache for PCollections being computed.
    19  
    20  For internal use only; no backward-compatibility guarantees.
    21  """
    22  # pytype: skip-file
    23  
    24  from typing import Tuple
    25  
    26  import apache_beam as beam
    27  from apache_beam.portability.api import beam_runner_api_pb2
    28  from apache_beam.runners.interactive import cache_manager as cache
    29  from apache_beam.runners.interactive.caching.cacheable import Cacheable
    30  from apache_beam.runners.interactive.caching.reify import reify_to_cache
    31  from apache_beam.runners.pipeline_context import PipelineContext
    32  from apache_beam.transforms.ptransform import PTransform
    33  
    34  
    35  class WriteCache:
    36    """Class that facilitates writing cache for PCollections being computed.
    37    """
    38    def __init__(
    39        self,
    40        pipeline: beam_runner_api_pb2.Pipeline,
    41        context: PipelineContext,
    42        cache_manager: cache.CacheManager,
    43        cacheable: Cacheable):
    44      self._pipeline = pipeline
    45      self._context = context
    46      self._cache_manager = cache_manager
    47      self._cacheable = cacheable
    48      self._key = cacheable.to_key().to_str()
    49  
    50    def write_cache(self) -> None:
    51      """Writes cache for the cacheable PCollection that is being computed.
    52  
    53      First, it creates a temporary pipeline instance on top of the existing
    54      component_id_map from self._pipeline's context so that both pipelines
    55      share the context and have no conflict component ids.
    56      Second, it creates a _PCollectionPlaceHolder in the temporary pipeline that
    57      mimics the attributes of the cacheable PCollection to be written into cache.
    58      It also marks all components in the current temporary pipeline as
    59      ignorable when later copying components to self._pipeline.
    60      Third, it instantiates a _WriteCacheTransform that uses the
    61      _PCollectionPlaceHolder as the input. This adds a subgraph under top level
    62      transforms that writes the _PCollectionPlaceHolder into cache.
    63      Fourth, it copies components of the subgraph from the temporary pipeline to
    64      self._pipeline, skipping components that are ignored in the temporary
    65      pipeline and components that are not in the temporary pipeline but presents
    66      in the component_id_map of self._pipeline.
    67      Last, it replaces inputs of all transforms that consume the
    68      _PCollectionPlaceHolder with the cacheable PCollection to be written to
    69      cache.
    70      """
    71      template, write_input_placeholder = self._build_runner_api_template()
    72      input_placeholder_id = self._context.pcollections.get_id(
    73          write_input_placeholder.placeholder_pcoll)
    74      input_id = self._context.pcollections.get_id(self._cacheable.pcoll)
    75  
    76      # Copy cache writing subgraph from the template to the pipeline proto.
    77      for pcoll_id in template.components.pcollections:
    78        if (pcoll_id in self._pipeline.components.pcollections or
    79            pcoll_id in write_input_placeholder.ignorable_components.pcollections
    80            ):
    81          continue
    82        self._pipeline.components.pcollections[pcoll_id].CopyFrom(
    83            template.components.pcollections[pcoll_id])
    84      for coder_id in template.components.coders:
    85        if (coder_id in self._pipeline.components.coders or
    86            coder_id in write_input_placeholder.ignorable_components.coders):
    87          continue
    88        self._pipeline.components.coders[coder_id].CopyFrom(
    89            template.components.coders[coder_id])
    90      for windowing_strategy_id in template.components.windowing_strategies:
    91        if (windowing_strategy_id in
    92            self._pipeline.components.windowing_strategies or
    93            windowing_strategy_id in
    94            write_input_placeholder.ignorable_components.windowing_strategies):
    95          continue
    96        self._pipeline.components.windowing_strategies[
    97            windowing_strategy_id].CopyFrom(
    98                template.components.windowing_strategies[windowing_strategy_id])
    99      template_root_transform_id = template.root_transform_ids[0]
   100      root_transform_id = self._pipeline.root_transform_ids[0]
   101      for transform_id in template.components.transforms:
   102        if (transform_id in self._pipeline.components.transforms or transform_id
   103            in write_input_placeholder.ignorable_components.transforms):
   104          continue
   105        self._pipeline.components.transforms[transform_id].CopyFrom(
   106            template.components.transforms[transform_id])
   107      for top_level_transform in template.components.transforms[
   108          template_root_transform_id].subtransforms:
   109        if (top_level_transform in
   110            write_input_placeholder.ignorable_components.transforms):
   111          continue
   112        self._pipeline.components.transforms[
   113            root_transform_id].subtransforms.append(top_level_transform)
   114  
   115      # Replace all the input pcoll of input_placeholder_id from cache writing
   116      # with cacheable pcoll of input_id.
   117      for transform in self._pipeline.components.transforms.values():
   118        inputs = transform.inputs
   119        if input_placeholder_id in inputs.values():
   120          keys_need_replacement = set()
   121          for key in inputs:
   122            if inputs[key] == input_placeholder_id:
   123              keys_need_replacement.add(key)
   124          for key in keys_need_replacement:
   125            inputs[key] = input_id
   126  
   127    def _build_runner_api_template(
   128        self) -> Tuple[beam_runner_api_pb2.Pipeline, '_PCollectionPlaceHolder']:
   129      pph = _PCollectionPlaceHolder(self._cacheable.pcoll, self._context)
   130      transform = _WriteCacheTransform(self._cache_manager, self._key)
   131      _ = pph.placeholder_pcoll | 'sink_cache_' + self._key >> transform
   132      return pph.placeholder_pcoll.pipeline.to_runner_api(), pph
   133  
   134  
   135  class _WriteCacheTransform(PTransform):
   136    """A composite transform encapsulates writing cache for PCollections.
   137    """
   138    def __init__(self, cache_manager: cache.CacheManager, key: str):
   139      self._cache_manager = cache_manager
   140      self._key = key
   141  
   142    def expand(self, pcoll: beam.pvalue.PCollection) -> beam.pvalue.PValue:
   143      return reify_to_cache(
   144          pcoll=pcoll, cache_key=self._key, cache_manager=self._cache_manager)
   145  
   146  
   147  class _PCollectionPlaceHolder:
   148    """A placeholder as an input to the cache writing transform.
   149    """
   150    def __init__(self, pcoll: beam.pvalue.PCollection, context: PipelineContext):
   151      tmp_pipeline = beam.Pipeline()
   152      tmp_pipeline.component_id_map = context.component_id_map
   153      self._input_placeholder = tmp_pipeline | 'CreatePInput' >> beam.Create(
   154          [], reshuffle=False)
   155      self._input_placeholder.tag = pcoll.tag
   156      self._input_placeholder.element_type = pcoll.element_type
   157      self._input_placeholder.is_bounded = pcoll.is_bounded
   158      self._input_placeholder._windowing = pcoll.windowing
   159      self._ignorable_components = tmp_pipeline.to_runner_api().components
   160  
   161    @property
   162    def placeholder_pcoll(self) -> beam.pvalue.PCollection:
   163      return self._input_placeholder
   164  
   165    @property
   166    def ignorable_components(self) -> beam_runner_api_pb2.Components:
   167      """Subgraph generated by the placeholder that can be ignored in the final
   168      pipeline proto.
   169      """
   170      return self._ignorable_components