github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/pipeline_instrument.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module to instrument interactivity into the given pipeline.

For internal use only; no backwards-compatibility guarantees.
This module accesses the current interactive environment and analyzes the
given pipeline to transform the original pipeline into a one-shot pipeline
with interactivity.
"""
# pytype: skip-file

import logging
from typing import Dict

import apache_beam as beam
from apache_beam.pipeline import PipelineVisitor
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive import pipeline_fragment as pf
from apache_beam.runners.interactive import background_caching_job
from apache_beam.runners.interactive import utils
from apache_beam.runners.interactive.caching.cacheable import Cacheable
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.runners.interactive.caching.reify import WRITE_CACHE
from apache_beam.runners.interactive.caching.reify import reify_to_cache
from apache_beam.runners.interactive.caching.reify import unreify_from_cache
from apache_beam.testing import test_stream

_LOGGER = logging.getLogger(__name__)


class PipelineInstrument(object):
  """A pipeline instrument for a pipeline to be executed by the interactive
  runner.

  This module should never depend on the underlying runner that the
  interactive runner delegates to. It instruments the original pipeline
  instance directly, appending or replacing transforms with the help of the
  cache, and provides interfaces to recover the states of the original
  pipeline. It's the interactive runner's responsibility to coordinate
  supported underlying runners to run the instrumented pipeline and to
  recover the original pipeline states if needed.
  """
  def __init__(self, pipeline, options=None):
    self._pipeline = pipeline

    self._user_pipeline = ie.current_env().user_pipeline(pipeline)
    if not self._user_pipeline:
      self._user_pipeline = pipeline
    self._cache_manager = ie.current_env().get_cache_manager(
        self._user_pipeline, create_if_absent=True)
    # Check if the user-defined pipeline contains any source to cache.
    # If so, the check itself converts the cache manager into a streaming
    # cache manager, so it is re-assigned below.
    if background_caching_job.has_source_to_cache(self._user_pipeline):
      self._cache_manager = ie.current_env().get_cache_manager(
          self._user_pipeline)

    self._background_caching_pipeline = beam.pipeline.Pipeline.from_runner_api(
        pipeline.to_runner_api(), pipeline.runner, options)
    ie.current_env().add_derived_pipeline(
        self._pipeline, self._background_caching_pipeline)

    # Snapshot of the original pipeline information.
    (self._original_pipeline_proto,
     context) = self._pipeline.to_runner_api(return_context=True)

    # All compute-once-against-original-pipeline fields.
    self._unbounded_sources = utils.unbounded_sources(
        self._background_caching_pipeline)
    self._pcoll_to_pcoll_id = pcoll_to_pcoll_id(self._pipeline, context)

    # A Dict[str, Cacheable] from a PCollection id to a Cacheable that belongs
    # to the analyzed pipeline.
    self._cacheables = self.find_cacheables()

    # A dict from cache key to the PCollection that is read from cache.
    # If an entry exists, the caller should reuse the PCollection read. If
    # not, the caller should create a new transform and track the PCollection
    # read from cache. (Dict[str, AppliedPTransform]).
    self._cached_pcoll_read = {}

    # A dict from PCollections in the runner pipeline instance to their
    # corresponding PCollections in the user pipeline instance. Populated
    # after preprocess().
    self._runner_pcoll_to_user_pcoll = {}
    self._pruned_pipeline_proto = None

    # Refers to the target pcolls output by instrumented write-cache
    # transforms, used by the pruning logic as supplemental targets to build
    # the pipeline fragment up from.
    self._extended_targets = set()

    # Refers to the pcolls used as inputs but replaced by outputs of
    # instrumented read-cache transforms, used by the pruning logic as targets
    # that no longer need to be produced during pipeline runs.
    self._ignored_targets = set()

    # Set of PCollections that are written to cache.
    self.cached_pcolls = set()

  def instrumented_pipeline_proto(self):
    """Always returns a new instance of the portable instrumented proto."""
    targets = set(self._runner_pcoll_to_user_pcoll.keys())
    targets.update(self._extended_targets)
    targets = targets.difference(self._ignored_targets)
    if len(targets) > 0:
      # Prunes upstream transforms that don't contribute to the targets the
      # instrumented pipeline run cares about.
      return pf.PipelineFragment(
          list(targets)).deduce_fragment().to_runner_api()
    return self._pipeline.to_runner_api()

  def _required_components(
      self,
      pipeline_proto,
      required_transforms_ids,
      visited,
      follow_outputs=False,
      follow_inputs=False):
    """Returns the components and subcomponents of the given transforms.

    This method returns required components, such as transforms and
    PCollections, related to the given transforms and to all of their
    subtransforms. It accomplishes this recursively.
    """
    if not required_transforms_ids:
      return ({}, {})

    transforms = pipeline_proto.components.transforms
    pcollections = pipeline_proto.components.pcollections

    # Cache the transforms that will be copied into the new pipeline proto.
    required_transforms = {k: transforms[k] for k in required_transforms_ids}
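    # Traversal sketch: starting from required_transforms_ids, this method
    # descends into subtransforms; with follow_outputs=True it also walks
    # forward through transforms that consume any of the outputs, and with
    # follow_inputs=True backward through transforms that produce any of the
    # inputs. The shared 'visited' set keeps the forward walk from recursing
    # into the same consumer twice.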
    # Cache all the output PCollections of the transforms.
    pcollection_ids = [
        pc for t in required_transforms.values() for pc in t.outputs.values()
    ]
    required_pcollections = {
        pc_id: pcollections[pc_id]
        for pc_id in pcollection_ids
    }

    subtransforms = {}
    subpcollections = {}

    # Recursively go through all the subtransforms and add their components.
    for transform_id, transform in required_transforms.items():
      if transform_id in pipeline_proto.root_transform_ids:
        continue
      (t, pc) = self._required_components(
          pipeline_proto,
          transform.subtransforms,
          visited,
          follow_outputs=False,
          follow_inputs=False)
      subtransforms.update(t)
      subpcollections.update(pc)

    if follow_outputs:
      outputs = [
          pc_id for t in required_transforms.values()
          for pc_id in t.outputs.values()
      ]
      visited_copy = visited.copy()
      consuming_transforms = {
          t_id: t
          for t_id, t in transforms.items()
          if set(outputs).intersection(set(t.inputs.values()))
      }
      consuming_transforms = set(consuming_transforms.keys())
      visited.update(consuming_transforms)
      consuming_transforms = consuming_transforms - visited_copy
      (t, pc) = self._required_components(
          pipeline_proto,
          list(consuming_transforms),
          visited,
          follow_outputs,
          follow_inputs)
      subtransforms.update(t)
      subpcollections.update(pc)

    if follow_inputs:
      inputs = [
          pc_id for t in required_transforms.values()
          for pc_id in t.inputs.values()
      ]
      producing_transforms = {
          t_id: t
          for t_id, t in transforms.items()
          if set(inputs).intersection(set(t.outputs.values()))
      }
      (t, pc) = self._required_components(
          pipeline_proto,
          list(producing_transforms.keys()),
          visited,
          follow_outputs,
          follow_inputs)
      subtransforms.update(t)
      subpcollections.update(pc)

    # Now that we have all the components and their subcomponents, return the
    # complete collection.
    required_transforms.update(subtransforms)
    required_pcollections.update(subpcollections)

    return (required_transforms, required_pcollections)

  def prune_subgraph_for(self, pipeline, required_transform_ids):
    # Create the pipeline_proto to read all the components from. A new
    # pipeline proto will later be created from the cut-out components.
    pipeline_proto, context = pipeline.to_runner_api(return_context=True)

    # Get all the root transforms. The caching transforms will be
    # subtransforms of one of these roots.
    roots = [root for root in pipeline_proto.root_transform_ids]

    (t, p) = self._required_components(
        pipeline_proto,
        roots + required_transform_ids,
        set(),
        follow_outputs=True,
        follow_inputs=True)

    def set_proto_map(proto_map, new_value):
      proto_map.clear()
      for key, value in new_value.items():
        proto_map[key].CopyFrom(value)

    # Copy the transforms into the new pipeline.
    pipeline_to_execute = beam_runner_api_pb2.Pipeline()
    pipeline_to_execute.root_transform_ids[:] = roots
    set_proto_map(pipeline_to_execute.components.transforms, t)
    set_proto_map(pipeline_to_execute.components.pcollections, p)
    set_proto_map(
        pipeline_to_execute.components.coders, context.to_runner_api().coders)
    set_proto_map(
        pipeline_to_execute.components.windowing_strategies,
        context.to_runner_api().windowing_strategies)

    # Cut out all subtransforms in the root that aren't the required
    # transforms.
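    # Otherwise each root would keep referencing subtransform ids that no
    # longer exist in the pruned components map, leaving the proto internally
    # inconsistent.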
    for root_id in roots:
      root = pipeline_to_execute.components.transforms[root_id]
      root.subtransforms[:] = [
          transform_id for transform_id in root.subtransforms
          if transform_id in pipeline_to_execute.components.transforms
      ]

    return pipeline_to_execute

  def background_caching_pipeline_proto(self):
    """Returns the background caching pipeline.

    This method creates a background caching pipeline by: adding writes to
    cache from each unbounded source (done in the instrument method), and
    cutting out all components (transforms, PCollections, coders, windowing
    strategies) that are not the unbounded sources or writes to cache (or
    subtransforms thereof).
    """
    # Create the pipeline_proto to read all the components from. A new
    # pipeline proto will later be created from the cut-out components.
    pipeline_proto, context = self._background_caching_pipeline.to_runner_api(
        return_context=True)

    # Get all the sources we want to cache.
    sources = utils.unbounded_sources(self._background_caching_pipeline)

    # Get all the root transforms. The caching transforms will be
    # subtransforms of one of these roots.
    roots = [root for root in pipeline_proto.root_transform_ids]

    # Get the transform IDs of the caching transforms. These caching
    # operations are added to the _background_caching_pipeline in the
    # instrument() method. They are added there so that multiple calls to this
    # method won't add multiple caching operations (idempotent).
    transforms = pipeline_proto.components.transforms
    caching_transform_ids = [
        t_id for root in roots for t_id in transforms[root].subtransforms
        if WRITE_CACHE in t_id
    ]

    # Get the IDs of the unbounded sources.
    required_transform_labels = [src.full_label for src in sources]
    unbounded_source_ids = [
        k for k, v in transforms.items()
        if v.unique_name in required_transform_labels
    ]

    # The required transforms are the transforms that we want to cut out of
    # the pipeline_proto and insert into a new pipeline to return.
    required_transform_ids = (
        roots + caching_transform_ids + unbounded_source_ids)
    (t, p) = self._required_components(
        pipeline_proto, required_transform_ids, set())

    def set_proto_map(proto_map, new_value):
      proto_map.clear()
      for key, value in new_value.items():
        proto_map[key].CopyFrom(value)

    # Copy the transforms into the new pipeline.
    pipeline_to_execute = beam_runner_api_pb2.Pipeline()
    pipeline_to_execute.root_transform_ids[:] = roots
    set_proto_map(pipeline_to_execute.components.transforms, t)
    set_proto_map(pipeline_to_execute.components.pcollections, p)
    set_proto_map(
        pipeline_to_execute.components.coders, context.to_runner_api().coders)
    set_proto_map(
        pipeline_to_execute.components.windowing_strategies,
        context.to_runner_api().windowing_strategies)

    # Cut out all subtransforms in the root that aren't the required
    # transforms.
    for root_id in roots:
      root = pipeline_to_execute.components.transforms[root_id]
      root.subtransforms[:] = [
          transform_id for transform_id in root.subtransforms
          if transform_id in pipeline_to_execute.components.transforms
      ]

    return pipeline_to_execute
  @property
  def cacheables(self) -> Dict[str, Cacheable]:
    """Returns the Cacheables by PCollection ids.

    If you're already working with user-defined pipelines and PCollections,
    do not build a PipelineInstrument just to get the cacheables. Instead,
    use apache_beam.runners.interactive.utils.cacheables.
    """
    return self._cacheables

  @property
  def has_unbounded_sources(self):
    """Returns whether the pipeline has any recordable sources."""
    return len(self._unbounded_sources) > 0

  @property
  def original_pipeline_proto(self):
    """Returns a snapshot of the pipeline proto before instrumentation."""
    return self._original_pipeline_proto

  @property
  def user_pipeline(self):
    """Returns a reference to the pipeline instance defined by the user. If a
    pipeline has no cacheable PCollection and the user pipeline cannot be
    found, returns None, indicating there is nothing to be cached in the user
    pipeline.

    The pipeline given for instrumenting and mutated in this class is not
    necessarily the pipeline instance defined by the user. From the watched
    scopes, this class figures out what the user pipeline instance is.
    This metadata can be used for tracking pipeline results.
    """
    return self._user_pipeline

  @property
  def runner_pcoll_to_user_pcoll(self):
    """Returns cacheable PCollections correlated from instances in the runner
    pipeline to instances in the user pipeline."""
    return self._runner_pcoll_to_user_pcoll

  def find_cacheables(self) -> Dict[str, Cacheable]:
    """Finds PCollections that need to be cached for the analyzed pipeline.

    There might be multiple pipelines defined and watched; this only finds
    cacheables belonging to the analyzed pipeline.
    """
    result = {}
    cacheables = utils.cacheables()
    for _, cacheable in cacheables.items():
      if cacheable.pcoll.pipeline is not self._user_pipeline:
        # Ignore all cacheables from other pipelines.
        continue
      pcoll_id = self.pcoll_id(cacheable.pcoll)
      if not pcoll_id:
        _LOGGER.debug(
            'Unable to retrieve PCollection id for %s. Ignored.',
            cacheable.pcoll)
        continue
      result[self.pcoll_id(cacheable.pcoll)] = cacheable
    return result

  def instrument(self):
    """Instruments the original pipeline with cache.

    For a cacheable output PCollection, if the cache for the key doesn't
    exist, do _write_cache(); for a cacheable input PCollection, if the cache
    for the key exists, do _read_cache(). No instrumenting happens in any
    other situation.

    Modifies:
      self._pipeline
    """
    cacheable_inputs = set()
    all_inputs = set()
    all_outputs = set()
    unbounded_source_pcolls = set()

    class InstrumentVisitor(PipelineVisitor):
      """Visitor that utilizes the cache to instrument the pipeline."""
      def __init__(self, pin):
        self._pin = pin

      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        if isinstance(transform_node.transform,
                      tuple(ie.current_env().options.recordable_sources)):
          unbounded_source_pcolls.update(transform_node.outputs.values())
        cacheable_inputs.update(self._pin._cacheable_inputs(transform_node))
        ins, outs = self._pin._all_inputs_outputs(transform_node)
        all_inputs.update(ins)
        all_outputs.update(outs)

    v = InstrumentVisitor(self)
    self._pipeline.visit(v)
    # Every output PCollection that is never used as an input PCollection is
    # considered a side effect of the pipeline run and should be included.
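    # For example, the output of a terminal `beam.Map(print)`-style transform
    # (hypothetical) that nothing else consumes appears only in all_outputs,
    # so the difference below keeps it as a fragment target.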
    self._extended_targets.update(all_outputs.difference(all_inputs))
    # Add the unbounded source PCollections to the cacheable inputs. This
    # allows for the caching of unbounded sources without a variable
    # reference.
    cacheable_inputs.update(unbounded_source_pcolls)

    # Create ReadCache transforms.
    for cacheable_input in cacheable_inputs:
      self._read_cache(
          self._pipeline,
          cacheable_input,
          cacheable_input in unbounded_source_pcolls)
    # Replace/wire inputs with cached PCollections from ReadCache transforms.
    self._replace_with_cached_inputs(self._pipeline)

    # Write cache for all cacheables.
    for _, cacheable in self._cacheables.items():
      self._write_cache(
          self._pipeline, cacheable.pcoll, ignore_unbounded_reads=True)

    # Instrument the background caching pipeline if we can.
    if self.has_unbounded_sources:
      for source in self._unbounded_sources:
        self._write_cache(
            self._background_caching_pipeline,
            source.outputs[None],
            output_as_extended_target=False,
            is_capture=True)

      class TestStreamVisitor(PipelineVisitor):
        def __init__(self):
          self.test_stream = None

        def enter_composite_transform(self, transform_node):
          self.visit_transform(transform_node)

        def visit_transform(self, transform_node):
          if (self.test_stream is None and
              isinstance(transform_node.transform, test_stream.TestStream)):
            self.test_stream = transform_node.full_label

      v = TestStreamVisitor()
      self._pipeline.visit(v)
      pipeline_proto = self._pipeline.to_runner_api(return_context=False)
      test_stream_id = ''
      for t_id, t in pipeline_proto.components.transforms.items():
        if t.unique_name == v.test_stream:
          test_stream_id = t_id
          break
      self._pruned_pipeline_proto = self.prune_subgraph_for(
          self._pipeline, [test_stream_id])
      pruned_pipeline = beam.Pipeline.from_runner_api(
          proto=self._pruned_pipeline_proto,
          runner=self._pipeline.runner,
          options=self._pipeline._options)
      ie.current_env().add_derived_pipeline(self._pipeline, pruned_pipeline)
      self._pipeline = pruned_pipeline
498 """ 499 class PreprocessVisitor(PipelineVisitor): 500 def __init__(self, pin): 501 self._pin = pin 502 503 def enter_composite_transform(self, transform_node): 504 self.visit_transform(transform_node) 505 506 def visit_transform(self, transform_node): 507 for in_pcoll in transform_node.inputs: 508 self._process(in_pcoll) 509 for out_pcoll in transform_node.outputs.values(): 510 self._process(out_pcoll) 511 512 def _process(self, pcoll): 513 pcoll_id = self._pin._pcoll_to_pcoll_id.get(str(pcoll), '') 514 if pcoll_id in self._pin._cacheables: 515 pcoll_id = self._pin.pcoll_id(pcoll) 516 user_pcoll = self._pin._cacheables[pcoll_id].pcoll 517 if (pcoll_id in self._pin._cacheables and user_pcoll != pcoll): 518 self._pin._runner_pcoll_to_user_pcoll[pcoll] = user_pcoll 519 self._pin._cacheables[pcoll_id].pcoll = pcoll 520 521 v = PreprocessVisitor(self) 522 self._pipeline.visit(v) 523 524 def _write_cache( 525 self, 526 pipeline, 527 pcoll, 528 output_as_extended_target=True, 529 ignore_unbounded_reads=False, 530 is_capture=False): 531 """Caches a cacheable PCollection. 532 533 For the given PCollection, by appending sub transform part that materialize 534 the PCollection through sink into cache implementation. The cache write is 535 not immediate. It happens when the runner runs the transformed pipeline 536 and thus not usable for this run as intended. This function always writes 537 the cache for the given PCollection as long as the PCollection belongs to 538 the pipeline being instrumented and the keyed cache is absent. 539 540 Modifies: 541 pipeline 542 """ 543 # Makes sure the pcoll belongs to the pipeline being instrumented. 544 if pcoll.pipeline is not pipeline: 545 return 546 547 # Ignore the unbounded reads from recordable sources as these will be pruned 548 # out using the PipelineFragment later on. 549 if ignore_unbounded_reads: 550 ignore = False 551 producer = pcoll.producer 552 while producer: 553 if isinstance(producer.transform, 554 tuple(ie.current_env().options.recordable_sources)): 555 ignore = True 556 break 557 producer = producer.parent 558 if ignore: 559 self._ignored_targets.add(pcoll) 560 return 561 562 # The keyed cache is always valid within this instrumentation. 563 key = self.cache_key(pcoll) 564 # Only need to write when the cache with expected key doesn't exist. 565 if not self._cache_manager.exists('full', key): 566 self.cached_pcolls.add(self.runner_pcoll_to_user_pcoll.get(pcoll, pcoll)) 567 # Read the windowing information and cache it along with the element. This 568 # caches the arguments to a WindowedValue object because Python has logic 569 # that detects if a DoFn returns a WindowedValue. When it detecs one, it 570 # puts the element into the correct window then emits the value to 571 # downstream transforms. 572 extended_target = reify_to_cache( 573 pcoll=pcoll, 574 cache_key=key, 575 cache_manager=self._cache_manager, 576 is_capture=is_capture) 577 if output_as_extended_target: 578 self._extended_targets.add(extended_target) 579 580 def _read_cache(self, pipeline, pcoll, is_unbounded_source_output): 581 """Reads a cached pvalue. 582 583 A noop will cause the pipeline to execute the transform as 584 it is and cache nothing from this transform for next run. 585 586 Modifies: 587 pipeline 588 """ 589 # Makes sure the pcoll belongs to the pipeline being instrumented. 590 if pcoll.pipeline is not pipeline: 591 return 592 # The keyed cache is always valid within this instrumentation. 
    key = self.cache_key(pcoll)
    # Can only read from cache when the cache with the expected key exists and
    # its computation has been completed.
    is_cached = self._cache_manager.exists('full', key)
    is_computed = (
        pcoll in self._runner_pcoll_to_user_pcoll and
        self._runner_pcoll_to_user_pcoll[pcoll] in
        ie.current_env().computed_pcollections)
    if ((is_cached and is_computed) or is_unbounded_source_output):
      if key not in self._cached_pcoll_read:
        # Mutates the pipeline with a cache read transform attached
        # to the root of the pipeline.

        # To put the cached value into the correct window, simply return a
        # WindowedValue constructed from the element.
        pcoll_from_cache = unreify_from_cache(
            pipeline=pipeline,
            cache_key=key,
            cache_manager=self._cache_manager)
        self._cached_pcoll_read[key] = pcoll_from_cache
    # else: NOOP when the cache doesn't exist; just compute the original
    # graph.

  def _replace_with_cached_inputs(self, pipeline):
    """Replaces PCollection inputs in the pipeline with cache if possible.

    For any input PCollection, find out whether there is a valid cache. If so,
    replace the input of the AppliedPTransform with the output of the
    AppliedPTransform that sources the pvalue from the cache. If there is no
    valid cache, noop.
    """

    # Find all cached unbounded PCollections.

    # If the pipeline has unbounded sources, then we want to force all cache
    # reads to go through the TestStream (even if they are bounded sources).
    if self.has_unbounded_sources:

      class CacheableUnboundedPCollectionVisitor(PipelineVisitor):
        def __init__(self, pin):
          self._pin = pin
          self.unbounded_pcolls = set()

        def enter_composite_transform(self, transform_node):
          self.visit_transform(transform_node)

        def visit_transform(self, transform_node):
          if transform_node.outputs:
            for output_pcoll in transform_node.outputs.values():
              key = self._pin.cache_key(output_pcoll)
              if key in self._pin._cached_pcoll_read:
                self.unbounded_pcolls.add(key)

          if transform_node.inputs:
            for input_pcoll in transform_node.inputs:
              key = self._pin.cache_key(input_pcoll)
              if key in self._pin._cached_pcoll_read:
                self.unbounded_pcolls.add(key)

      v = CacheableUnboundedPCollectionVisitor(self)
      pipeline.visit(v)

      # The set of keys from the cached unbounded PCollections will be used as
      # the output tags for the TestStream. This is to remember which cache
      # key is associated with which PCollection.
      output_tags = v.unbounded_pcolls

      # Take the PCollections that will be read from the TestStream and insert
      # them back into the dictionary of cached PCollections. The next step
      # will replace the downstream consumers of the non-cached PCollections
      # with these PCollections.
      if output_tags:
        output_pcolls = pipeline | test_stream.TestStream(
            output_tags=output_tags, coder=self._cache_manager._default_pcoder)
        for tag, pcoll in output_pcolls.items():
          self._cached_pcoll_read[tag] = pcoll
670 """ 671 def __init__(self, pin): 672 """Initializes with a PipelineInstrument.""" 673 self._pin = pin 674 675 def enter_composite_transform(self, transform_node): 676 self.visit_transform(transform_node) 677 678 def visit_transform(self, transform_node): 679 if transform_node.inputs: 680 main_inputs = dict(transform_node.main_inputs) 681 for tag, input_pcoll in main_inputs.items(): 682 key = self._pin.cache_key(input_pcoll) 683 684 # Replace the input pcollection with the cached pcollection (if it 685 # has been cached). 686 if key in self._pin._cached_pcoll_read: 687 # Ignore this pcoll in the final pruned instrumented pipeline. 688 self._pin._ignored_targets.add(input_pcoll) 689 main_inputs[tag] = self._pin._cached_pcoll_read[key] 690 # Update the transform with its new inputs. 691 transform_node.main_inputs = main_inputs 692 693 v = ReadCacheWireVisitor(self) 694 pipeline.visit(v) 695 696 def _cacheable_inputs(self, transform): 697 inputs = set() 698 for in_pcoll in transform.inputs: 699 if self.pcoll_id(in_pcoll) in self._cacheables: 700 inputs.add(in_pcoll) 701 return inputs 702 703 def _all_inputs_outputs(self, transform): 704 inputs = set() 705 outputs = set() 706 for in_pcoll in transform.inputs: 707 inputs.add(in_pcoll) 708 for _, out_pcoll in transform.outputs.items(): 709 outputs.add(out_pcoll) 710 return inputs, outputs 711 712 def pcoll_id(self, pcoll): 713 """Gets the PCollection id of the given pcoll. 714 715 Returns '' if not found. 716 """ 717 return self._pcoll_to_pcoll_id.get(str(pcoll), '') 718 719 def cache_key(self, pcoll): 720 """Gets the identifier of a cacheable PCollection in cache. 721 722 If the pcoll is not a cacheable, return ''. 723 This is only needed in pipeline instrument when the origin of given pcoll 724 is unknown (whether it's from the user pipeline or a runner pipeline). If 725 a pcoll is from the user pipeline, always use CacheKey.from_pcoll to build 726 the key. 727 The key is what the pcoll would use as identifier if it's materialized in 728 cache. It doesn't mean that there would definitely be such cache already. 729 Also, the pcoll can come from the original user defined pipeline object or 730 an equivalent pcoll from a transformed copy of the original pipeline. 731 """ 732 cacheable = self._cacheables.get(self.pcoll_id(pcoll), None) 733 if cacheable: 734 if cacheable.pcoll in self.runner_pcoll_to_user_pcoll: 735 user_pcoll = self.runner_pcoll_to_user_pcoll[cacheable.pcoll] 736 else: 737 user_pcoll = cacheable.pcoll 738 return CacheKey.from_pcoll(cacheable.var, user_pcoll).to_str() 739 return '' 740 741 742 def build_pipeline_instrument(pipeline, options=None): 743 """Creates PipelineInstrument for a pipeline and its options with cache. 744 745 Throughout the process, the returned PipelineInstrument snapshots the given 746 pipeline and then mutates the pipeline. It's invoked by interactive components 747 such as the InteractiveRunner and the given pipeline should be implicitly 748 created runner pipelines instead of pipeline instances defined by the user. 749 750 This is the shorthand for doing 3 steps: 1) compute once for metadata of the 751 given runner pipeline and everything watched from user pipelines; 2) associate 752 info between the runner pipeline and its corresponding user pipeline, 753 eliminate data from other user pipelines if there are any; 3) mutate the 754 runner pipeline to apply interactivity. 755 """ 756 pi = PipelineInstrument(pipeline, options) 757 pi.preprocess() 758 pi.instrument() # Instruments the pipeline only once. 
  return pi


def pcoll_to_pcoll_id(pipeline, original_context):
  """Returns a dict mapping PCollection strings to PCollection IDs.

  Using a PipelineVisitor to iterate over every node in the pipeline, records
  the mapping from PCollections to PCollection IDs. This mapping will be used
  to query cached PCollections.

  Returns:
    (dict from str to str) a dict mapping str(pcoll) to pcoll_id.
  """
  class PCollVisitor(PipelineVisitor):
    """A visitor that records input and output values to be replaced.

    Input and output values that should be updated are recorded in the maps
    input_replacements and output_replacements respectively.

    We cannot update input and output values while visiting since that
    results in validation errors.
    """
    def __init__(self):
      self.pcoll_to_pcoll_id = {}

    def enter_composite_transform(self, transform_node):
      self.visit_transform(transform_node)

    def visit_transform(self, transform_node):
      for pcoll in transform_node.outputs.values():
        self.pcoll_to_pcoll_id[str(pcoll)] = (
            original_context.pcollections.get_id(pcoll))

  v = PCollVisitor()
  pipeline.visit(v)
  return v.pcoll_to_pcoll_id
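
# A minimal usage sketch (hypothetical variable names; the interactive runner
# drives these steps internally, so user code normally never calls
# build_pipeline_instrument directly):
#
#   import apache_beam as beam
#   from apache_beam.runners.interactive import interactive_beam as ib
#   from apache_beam.runners.interactive.interactive_runner import (
#       InteractiveRunner)
#
#   p = beam.Pipeline(InteractiveRunner())
#   pcoll = p | beam.Create([1, 2, 3])
#   ib.watch(locals())  # Makes 'pcoll' discoverable, hence cacheable.
#
#   # The runner hands a derived copy of the user pipeline to this module:
#   instr = build_pipeline_instrument(p)
#   proto = instr.instrumented_pipeline_proto()  # portable proto to execute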