# Retrieved from:
#   github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/pipeline_fragment.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module to build pipeline fragment that produces given PCollections.

For internal use only; no backwards-compatibility guarantees.
"""
import apache_beam as beam
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.testing.test_stream import TestStream


class PipelineFragment(object):
  """A fragment of a pipeline definition.

  A pipeline fragment is built from the original pipeline definition to include
  only PTransforms that are necessary to produce the given PCollections.

  The user pipeline is never modified: it is copied (via its runner API
  representation) into a "runner pipeline", and all pruning happens on that
  copy.
  """
  def __init__(self, pcolls, options=None):
    """Constructor of PipelineFragment.

    Args:
      pcolls: (List[PCollection]) a list of PCollections to build pipeline
        fragment for. All of them must belong to the same user pipeline.
      options: (PipelineOptions) the pipeline options for the implicit
        pipeline run.

    Raises:
      AssertionError: if pcolls is empty, contains a non-PCollection, or mixes
        PCollections from different pipelines.
    """
    assert len(pcolls) > 0, (
        'Need at least 1 PCollection as the target data to build a pipeline '
        'fragment that produces it.')
    for pcoll in pcolls:
      assert isinstance(pcoll, beam.pvalue.PCollection), (
          '{} is not an apache_beam.pvalue.PCollection.'.format(pcoll))
    # No modification to self._user_pipeline is allowed.
    self._user_pipeline = pcolls[0].pipeline
    # These are user PCollections. Do not use them to deduce anything that
    # will be executed by any runner. Instead, use
    # `self._runner_pcolls_to_user_pcolls.keys()` to get copied PCollections.
    self._pcolls = set(pcolls)
    for pcoll in self._pcolls:
      assert pcoll.pipeline is self._user_pipeline, (
          '{} belongs to a different user pipeline than other PCollections '
          'given and cannot be used to build a pipeline fragment that produces '
          'the given PCollections.'.format(pcoll))
    self._options = options

    # A copied pipeline instance for modification without changing the user
    # pipeline instance held by the end user. This instance can be processed
    # into a pipeline fragment that later run by the underlying runner.
    self._runner_pipeline = self._build_runner_pipeline()
    _, self._context = self._runner_pipeline.to_runner_api(return_context=True)
    # Imported locally (as in the original) rather than at module level.
    from apache_beam.runners.interactive import pipeline_instrument as instr
    # Maps str(PCollection) of the runner pipeline to its PCollection id.
    self._runner_pcoll_to_id = instr.pcoll_to_pcoll_id(
        self._runner_pipeline, self._context)
    # Correlate components in the runner pipeline to components in the user
    # pipeline. The target pcolls are the pcolls given and defined in the user
    # pipeline.
    self._id_to_target_pcoll = self._calculate_target_pcoll_ids()
    self._label_to_user_transform = self._calculate_user_transform_labels()
    # Below will give us the 1:1 correlation between
    # PCollections/AppliedPTransforms from the copied runner pipeline and
    # PCollections/AppliedPTransforms from the user pipeline.
    (
        # (Dict[PCollection, PCollection])
        self._runner_pcolls_to_user_pcolls,
        # (Dict[AppliedPTransform, AppliedPTransform])
        self._runner_transforms_to_user_transforms
    ) = self._build_correlation_between_pipelines(
        self._runner_pcoll_to_id,
        self._id_to_target_pcoll,
        self._label_to_user_transform)

    # Below are operated on the runner pipeline.
    (self._necessary_transforms,
     self._necessary_pcollections) = self._mark_necessary_transforms_and_pcolls(
         self._runner_pcolls_to_user_pcolls)
    self._runner_pipeline = self._prune_runner_pipeline_to_fragment(
        self._runner_pipeline, self._necessary_transforms)

  def deduce_fragment(self):
    """Deduce the pipeline fragment as an apache_beam.Pipeline instance.

    The pruned runner pipeline is round-tripped through its runner API
    representation into a fresh Pipeline using the same runner and the
    options given to the constructor. The new pipeline is registered with the
    current interactive environment as derived from the runner pipeline.
    """
    fragment = beam.pipeline.Pipeline.from_runner_api(
        self._runner_pipeline.to_runner_api(),
        self._runner_pipeline.runner,
        self._options)
    ie.current_env().add_derived_pipeline(self._runner_pipeline, fragment)
    return fragment

  def run(self, display_pipeline_graph=False, use_cache=True, blocking=False):
    """Shorthand to run the pipeline fragment.

    The InteractiveRunner's display/cache/blocking flags are overridden for
    this run only and restored afterwards, even if the run raises.

    Args:
      display_pipeline_graph: whether the runner should display the pipeline
        graph (sets the runner's `_skip_display` to the negation).
      use_cache: whether cached PCollections may be reused (sets the runner's
        `_force_compute` to the negation).
      blocking: value assigned to the runner's `_blocking` flag.

    Returns:
      Whatever `Pipeline.run()` returns for the deduced fragment.

    Raises:
      RuntimeError: if the pipeline's runner is not an InteractiveRunner.
    """
    from apache_beam.runners.interactive.interactive_runner import InteractiveRunner
    runner = self._runner_pipeline.runner
    if not isinstance(runner, InteractiveRunner):
      raise RuntimeError(
          'Please specify InteractiveRunner when creating '
          'the Beam pipeline to use this function.')
    # Snapshot the flags before entering `try` so the `finally` clause never
    # refers to unbound names (which would mask the original error).
    preserved_skip_display = runner._skip_display
    preserved_force_compute = runner._force_compute
    preserved_blocking = runner._blocking
    try:
      runner._skip_display = not display_pipeline_graph
      runner._force_compute = not use_cache
      runner._blocking = blocking
      return self.deduce_fragment().run()
    finally:
      runner._skip_display = preserved_skip_display
      runner._force_compute = preserved_force_compute
      runner._blocking = preserved_blocking

  def _build_runner_pipeline(self):
    """Copy the user pipeline (via its runner API proto) into a new Pipeline.

    The copy uses the user pipeline's runner and this fragment's options, and
    is registered with the interactive environment as derived from the user
    pipeline.
    """
    runner_pipeline = beam.pipeline.Pipeline.from_runner_api(
        self._user_pipeline.to_runner_api(),
        self._user_pipeline.runner,
        self._options)
    ie.current_env().add_derived_pipeline(self._user_pipeline, runner_pipeline)
    return runner_pipeline

  def _calculate_target_pcoll_ids(self):
    """Map runner-pipeline PCollection ids to the user's target PCollections.

    User PCollections are matched by `str(pcoll)` against
    `self._runner_pcoll_to_id`. Targets that cannot be resolved are skipped
    instead of being stored under a shared empty-string key; such a key would
    let any unresolved runner PCollection be correlated to an arbitrary
    target later on.

    Returns:
      Dict[str, PCollection] from PCollection id to user target PCollection.
    """
    pcoll_id_to_target_pcoll = {}
    for pcoll in self._pcolls:
      pcoll_id = self._runner_pcoll_to_id.get(str(pcoll), '')
      if pcoll_id:
        pcoll_id_to_target_pcoll[pcoll_id] = pcoll
    return pcoll_id_to_target_pcoll

  def _calculate_user_transform_labels(self):
    """Map full labels to the AppliedPTransforms of the *user* pipeline.

    Both composite and leaf transforms are recorded. The visitor only reads
    the nodes it is given, so the user pipeline is not modified. (Previously
    the runner pipeline was visited here, which made the runner->user
    transform correlation an identity mapping.)

    Returns:
      Dict[str, AppliedPTransform] keyed by `full_label`.
    """
    label_to_user_transform = {}

    class UserTransformVisitor(PipelineVisitor):
      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        if transform_node is not None:
          label_to_user_transform[transform_node.full_label] = transform_node

    v = UserTransformVisitor()
    self._user_pipeline.visit(v)
    return label_to_user_transform

  def _build_correlation_between_pipelines(
      self, runner_pcoll_to_id, id_to_target_pcoll, label_to_user_transform):
    """Correlate runner-pipeline components with user-pipeline components.

    Args:
      runner_pcoll_to_id: Dict[str, str] from str(runner PCollection) to id.
      id_to_target_pcoll: Dict[str, PCollection] from id to user target.
      label_to_user_transform: Dict[str, AppliedPTransform] from full label
        to user transform.

    Returns:
      A tuple (runner_pcolls_to_user_pcolls,
      runner_transforms_to_user_transforms). PCollections are matched by id
      (only target PCollections are included); transforms by full label.
    """
    runner_pcolls_to_user_pcolls = {}
    runner_transforms_to_user_transforms = {}

    class CorrelationVisitor(PipelineVisitor):
      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        self._process_transform(transform_node)
        for in_pcoll in transform_node.inputs:
          self._process_pcoll(in_pcoll)
        for out_pcoll in transform_node.outputs.values():
          self._process_pcoll(out_pcoll)

      def _process_pcoll(self, pcoll):
        pcoll_id = runner_pcoll_to_id.get(str(pcoll), '')
        if pcoll_id in id_to_target_pcoll:
          runner_pcolls_to_user_pcolls[pcoll] = id_to_target_pcoll[pcoll_id]

      def _process_transform(self, transform_node):
        if transform_node.full_label in label_to_user_transform:
          runner_transforms_to_user_transforms[transform_node] = (
              label_to_user_transform[transform_node.full_label])

    v = CorrelationVisitor()
    self._runner_pipeline.visit(v)
    return runner_pcolls_to_user_pcolls, runner_transforms_to_user_transforms

  def _mark_necessary_transforms_and_pcolls(self, runner_pcolls_to_user_pcolls):
    """Find the transforms and PCollections needed to produce the targets.

    Starting from the target runner PCollections, walks each PCollection's
    producer and that producer's ancestors, collecting main inputs, side
    inputs and (for non-root composites) the outputs of their parts. This
    repeats until no new PCollection is discovered.

    Args:
      runner_pcolls_to_user_pcolls: dict whose keys are the target runner
        PCollections.

    Returns:
      A tuple (necessary_transforms, necessary_pcollections), both sets.
    """
    necessary_transforms = set()
    all_inputs = set()
    updated_all_inputs = set(runner_pcolls_to_user_pcolls.keys())
    # Do this until no more new PCollection is recorded.
    while len(updated_all_inputs) != len(all_inputs):
      all_inputs = set(updated_all_inputs)
      for pcoll in all_inputs:
        producer = pcoll.producer
        while producer:
          if producer in necessary_transforms:
            break
          # Mark the AppliedPTransform as necessary.
          necessary_transforms.add(producer)

          # Also mark composites that are not the root transform. If the root
          # transform is added, then all transforms are incorrectly marked as
          # necessary. If composites are not handled, then there will be
          # orphaned PCollections.
          if producer.parent is not None:
            necessary_transforms.update(producer.parts)

            # This will recursively add all the PCollections in this composite.
            for part in producer.parts:
              updated_all_inputs.update(part.outputs.values())

          # Record all necessary input and side input PCollections.
          updated_all_inputs.update(producer.inputs)
          updated_all_inputs.update(
              side_input.pvalue for side_input in producer.side_inputs)
          # Go to its parent AppliedPTransform.
          producer = producer.parent
    return necessary_transforms, all_inputs

  def _prune_runner_pipeline_to_fragment(
      self, runner_pipeline, necessary_transforms):
    """Prune, in place, every transform not in necessary_transforms.

    Composites that `should_skip_pruning` returns True for are left untouched.
    For other composites, unnecessary parts are dropped (parts become a
    tuple). Visited transforms that are not necessary get their parent
    cleared.

    Returns:
      The same (mutated) runner_pipeline instance.
    """
    class PruneVisitor(PipelineVisitor):
      def enter_composite_transform(self, transform_node):
        if should_skip_pruning(transform_node):
          return

        # Filtering in one pass replaces the previous copy-then-remove loop
        # (quadratic in the number of parts); order is preserved.
        transform_node.parts = tuple(
            part for part in transform_node.parts
            if part in necessary_transforms)
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        if transform_node not in necessary_transforms:
          transform_node.parent = None

    v = PruneVisitor()
    runner_pipeline.visit(v)
    return runner_pipeline


def should_skip_pruning(transform: AppliedPTransform):
  """Return True if this composite must not have its parts pruned.

  This is the case for TestStream transforms and for transforms whose full
  label contains '_DataFrame_'.
  """
  return (
      isinstance(transform.transform, TestStream) or
      '_DataFrame_' in transform.full_label)