github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/caching/write_cache.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module to write cache for PCollections being computed.

For internal use only; no backward-compatibility guarantees.
"""
# pytype: skip-file

from typing import Tuple

import apache_beam as beam
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import cache_manager as cache
from apache_beam.runners.interactive.caching.cacheable import Cacheable
from apache_beam.runners.interactive.caching.reify import reify_to_cache
from apache_beam.runners.pipeline_context import PipelineContext
from apache_beam.transforms.ptransform import PTransform


class WriteCache:
  """Class that facilitates writing cache for PCollections being computed.
  """
  def __init__(
      self,
      pipeline: beam_runner_api_pb2.Pipeline,
      context: PipelineContext,
      cache_manager: cache.CacheManager,
      cacheable: Cacheable):
    self._pipeline = pipeline
    self._context = context
    self._cache_manager = cache_manager
    self._cacheable = cacheable
    self._key = cacheable.to_key().to_str()

  def write_cache(self) -> None:
    """Writes cache for the cacheable PCollection that is being computed.

    First, it creates a temporary pipeline instance on top of the existing
    component_id_map from self._pipeline's context so that both pipelines
    share the context and have no conflicting component ids.
    Second, it creates a _PCollectionPlaceHolder in the temporary pipeline that
    mimics the attributes of the cacheable PCollection to be written into
    cache. It also marks all components in the current temporary pipeline as
    ignorable when later copying components to self._pipeline.
    Third, it instantiates a _WriteCacheTransform that uses the
    _PCollectionPlaceHolder as the input. This adds a subgraph under top level
    transforms that writes the _PCollectionPlaceHolder into cache.
    Fourth, it copies components of the subgraph from the temporary pipeline to
    self._pipeline, skipping components that are ignored in the temporary
    pipeline and components that are not in the temporary pipeline but are
    present in the component_id_map of self._pipeline.
    Last, it replaces inputs of all transforms that consume the
    _PCollectionPlaceHolder with the cacheable PCollection to be written to
    cache.
    """
    template, write_input_placeholder = self._build_runner_api_template()
    input_placeholder_id = self._context.pcollections.get_id(
        write_input_placeholder.placeholder_pcoll)
    input_id = self._context.pcollections.get_id(self._cacheable.pcoll)

    # Copy the cache writing subgraph from the template to the pipeline proto.
    for pcoll_id in template.components.pcollections:
      if (pcoll_id in self._pipeline.components.pcollections or
          pcoll_id in write_input_placeholder.ignorable_components.pcollections
          ):
        continue
      self._pipeline.components.pcollections[pcoll_id].CopyFrom(
          template.components.pcollections[pcoll_id])
    for coder_id in template.components.coders:
      if (coder_id in self._pipeline.components.coders or
          coder_id in write_input_placeholder.ignorable_components.coders):
        continue
      self._pipeline.components.coders[coder_id].CopyFrom(
          template.components.coders[coder_id])
    for windowing_strategy_id in template.components.windowing_strategies:
      if (windowing_strategy_id in
          self._pipeline.components.windowing_strategies or
          windowing_strategy_id in
          write_input_placeholder.ignorable_components.windowing_strategies):
        continue
      self._pipeline.components.windowing_strategies[
          windowing_strategy_id].CopyFrom(
              template.components.windowing_strategies[windowing_strategy_id])
    template_root_transform_id = template.root_transform_ids[0]
    root_transform_id = self._pipeline.root_transform_ids[0]
    for transform_id in template.components.transforms:
      if (transform_id in self._pipeline.components.transforms or transform_id
          in write_input_placeholder.ignorable_components.transforms):
        continue
      self._pipeline.components.transforms[transform_id].CopyFrom(
          template.components.transforms[transform_id])
    for top_level_transform in template.components.transforms[
        template_root_transform_id].subtransforms:
      if (top_level_transform in
          write_input_placeholder.ignorable_components.transforms):
        continue
      self._pipeline.components.transforms[
          root_transform_id].subtransforms.append(top_level_transform)

    # Replace every input that refers to the placeholder pcoll
    # (input_placeholder_id) with the cacheable pcoll (input_id).
    for transform in self._pipeline.components.transforms.values():
      inputs = transform.inputs
      if input_placeholder_id in inputs.values():
        keys_need_replacement = set()
        for key in inputs:
          if inputs[key] == input_placeholder_id:
            keys_need_replacement.add(key)
        for key in keys_need_replacement:
          inputs[key] = input_id

  def _build_runner_api_template(
      self) -> Tuple[beam_runner_api_pb2.Pipeline, '_PCollectionPlaceHolder']:
    pph = _PCollectionPlaceHolder(self._cacheable.pcoll, self._context)
    transform = _WriteCacheTransform(self._cache_manager, self._key)
    _ = pph.placeholder_pcoll | 'sink_cache_' + self._key >> transform
    return pph.placeholder_pcoll.pipeline.to_runner_api(), pph


class _WriteCacheTransform(PTransform):
  """A composite transform that encapsulates writing cache for PCollections.
  """
  def __init__(self, cache_manager: cache.CacheManager, key: str):
    self._cache_manager = cache_manager
    self._key = key

  def expand(self, pcoll: beam.pvalue.PCollection) -> beam.pvalue.PValue:
    return reify_to_cache(
        pcoll=pcoll, cache_key=self._key, cache_manager=self._cache_manager)


class _PCollectionPlaceHolder:
  """A placeholder as an input to the cache writing transform.
  """
  def __init__(self, pcoll: beam.pvalue.PCollection, context: PipelineContext):
    tmp_pipeline = beam.Pipeline()
    tmp_pipeline.component_id_map = context.component_id_map
    self._input_placeholder = tmp_pipeline | 'CreatePInput' >> beam.Create(
        [], reshuffle=False)
    self._input_placeholder.tag = pcoll.tag
    self._input_placeholder.element_type = pcoll.element_type
    self._input_placeholder.is_bounded = pcoll.is_bounded
    self._input_placeholder._windowing = pcoll.windowing
    self._ignorable_components = tmp_pipeline.to_runner_api().components

  @property
  def placeholder_pcoll(self) -> beam.pvalue.PCollection:
    return self._input_placeholder

  @property
  def ignorable_components(self) -> beam_runner_api_pb2.Components:
    """Subgraph generated by the placeholder that can be ignored in the final
    pipeline proto.
    """
    return self._ignorable_components