github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/direct/transform_evaluator.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """An evaluator of a specific application of a transform.""" 19 20 # pytype: skip-file 21 22 import atexit 23 import collections 24 import logging 25 import random 26 import time 27 from collections import abc 28 from typing import TYPE_CHECKING 29 from typing import Any 30 from typing import Dict 31 from typing import List 32 from typing import Tuple 33 from typing import Type 34 35 from apache_beam import coders 36 from apache_beam import io 37 from apache_beam import pvalue 38 from apache_beam.internal import pickler 39 from apache_beam.runners import common 40 from apache_beam.runners.common import DoFnRunner 41 from apache_beam.runners.common import DoFnState 42 from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite # pylint: disable=protected-access 43 from apache_beam.runners.direct.direct_runner import _DirectReadFromPubSub 44 from apache_beam.runners.direct.direct_runner import _GroupByKeyOnly 45 from apache_beam.runners.direct.direct_runner import _StreamingGroupAlsoByWindow 46 from apache_beam.runners.direct.direct_runner import 
_StreamingGroupByKeyOnly 47 from apache_beam.runners.direct.direct_userstate import DirectUserStateContext 48 from apache_beam.runners.direct.sdf_direct_runner import ProcessElements 49 from apache_beam.runners.direct.sdf_direct_runner import ProcessFn 50 from apache_beam.runners.direct.sdf_direct_runner import SDFProcessElementInvoker 51 from apache_beam.runners.direct.test_stream_impl import _TestStream 52 from apache_beam.runners.direct.test_stream_impl import _WatermarkController 53 from apache_beam.runners.direct.util import KeyedWorkItem 54 from apache_beam.runners.direct.util import TransformResult 55 from apache_beam.runners.direct.watermark_manager import WatermarkManager 56 from apache_beam.testing.test_stream import ElementEvent 57 from apache_beam.testing.test_stream import PairWithTiming 58 from apache_beam.testing.test_stream import ProcessingTimeEvent 59 from apache_beam.testing.test_stream import TimingInfo 60 from apache_beam.testing.test_stream import WatermarkEvent 61 from apache_beam.testing.test_stream import WindowedValueHolder 62 from apache_beam.transforms import core 63 from apache_beam.transforms.trigger import InMemoryUnmergedState 64 from apache_beam.transforms.trigger import TimeDomain 65 from apache_beam.transforms.trigger import _CombiningValueStateTag 66 from apache_beam.transforms.trigger import _ListStateTag 67 from apache_beam.transforms.trigger import _ReadModifyWriteStateTag 68 from apache_beam.transforms.trigger import create_trigger_driver 69 from apache_beam.transforms.userstate import get_dofn_specs 70 from apache_beam.transforms.userstate import is_stateful_dofn 71 from apache_beam.transforms.window import GlobalWindows 72 from apache_beam.transforms.window import WindowedValue 73 from apache_beam.typehints.typecheck import TypeCheckError 74 from apache_beam.utils import counters 75 from apache_beam.utils.timestamp import MIN_TIMESTAMP 76 from apache_beam.utils.timestamp import Timestamp 77 78 if TYPE_CHECKING: 79 from 
if TYPE_CHECKING:
  from apache_beam.io.gcp.pubsub import _PubSubSource
  from apache_beam.io.gcp.pubsub import PubsubMessage
  from apache_beam.pipeline import AppliedPTransform
  from apache_beam.runners.direct.evaluation_context import EvaluationContext

_LOGGER = logging.getLogger(__name__)


class TransformEvaluatorRegistry(object):
  """For internal use only; no backwards-compatibility guarantees.

  Creates instances of TransformEvaluator for the application of a transform.
  """

  _test_evaluators_overrides = {
  }  # type: Dict[Type[core.PTransform], Type[_TransformEvaluator]]

  def __init__(self, evaluation_context):
    # type: (EvaluationContext) -> None
    assert evaluation_context
    self._evaluation_context = evaluation_context
    # Maps each directly-supported transform type to its evaluator class.
    self._evaluators = {
        io.Read: _BoundedReadEvaluator,
        _DirectReadFromPubSub: _PubSubReadEvaluator,
        core.Flatten: _FlattenEvaluator,
        core.Impulse: _ImpulseEvaluator,
        core.ParDo: _ParDoEvaluator,
        _GroupByKeyOnly: _GroupByKeyOnlyEvaluator,
        _StreamingGroupByKeyOnly: _StreamingGroupByKeyOnlyEvaluator,
        _StreamingGroupAlsoByWindow: _StreamingGroupAlsoByWindowEvaluator,
        _NativeWrite: _NativeWriteEvaluator,
        _TestStream: _TestStreamEvaluator,
        ProcessElements: _ProcessElementsEvaluator,
        _WatermarkController: _WatermarkControllerEvaluator,
        PairWithTiming: _PairWithTimingEvaluator,
    }  # type: Dict[Type[core.PTransform], Type[_TransformEvaluator]]
    # Tests may inject replacement evaluators for specific transform types.
    self._evaluators.update(self._test_evaluators_overrides)
    self._root_bundle_providers = {
        core.PTransform: DefaultRootBundleProvider,
        _TestStream: _TestStreamRootBundleProvider,
    }

  def get_evaluator(
      self, applied_ptransform, input_committed_bundle, side_inputs):
    """Returns a TransformEvaluator suitable for processing given inputs."""
    assert applied_ptransform
    assert bool(applied_ptransform.side_inputs) == bool(side_inputs)

    # Walk up the class hierarchy so that sub-classes of core transforms
    # resolve to the evaluator registered for their base class.
    evaluator_cls = next(
        (self._evaluators[cls]
         for cls in applied_ptransform.transform.__class__.mro()
         if cls in self._evaluators),
        None)

    if evaluator_cls is None:
      raise NotImplementedError(
          'Execution of [%s] not implemented in runner %s.' %
          (type(applied_ptransform.transform), self))
    return evaluator_cls(
        self._evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)

  def get_root_bundle_provider(self, applied_ptransform):
    """Returns a RootBundleProvider for a root transform application."""
    # As in get_evaluator(), walk the mro to support transform sub-classes.
    provider_cls = next(
        (self._root_bundle_providers[cls]
         for cls in applied_ptransform.transform.__class__.mro()
         if cls in self._root_bundle_providers),
        None)
    if provider_cls is None:
      raise NotImplementedError(
          'Root provider for [%s] not implemented in runner %s' %
          (type(applied_ptransform.transform), self))
    return provider_cls(self._evaluation_context, applied_ptransform)

  def should_execute_serially(self, applied_ptransform):
    """Returns True if this applied_ptransform should run one bundle at a time.

    Some TransformEvaluators use a global state object to keep track of their
    global execution state. For example evaluator for _GroupByKeyOnly uses this
    state as an in memory dictionary to buffer keys.

    Serially executed evaluators will act as syncing point in the graph and
    execution will not move forward until they receive all of their inputs.
    Once they receive all of their input, they will release the combined
    output. Their output may consist of multiple bundles as they may divide
    their output into pieces before releasing.

    Args:
      applied_ptransform: Transform to be used for execution.

    Returns:
      True if executor should execute applied_ptransform serially.
    """
    transform = applied_ptransform.transform
    if isinstance(transform,
                  (_GroupByKeyOnly,
                   _StreamingGroupByKeyOnly,
                   _StreamingGroupAlsoByWindow,
                   _NativeWrite)):
      return True
    # Stateful ParDos buffer per-key state and must also run serially.
    return (isinstance(transform, core.ParDo) and
            is_stateful_dofn(transform.dofn))


class RootBundleProvider(object):
  """Provides bundles for the initial execution of a root transform."""
  def __init__(self, evaluation_context, applied_ptransform):
    self._evaluation_context = evaluation_context
    self._applied_ptransform = applied_ptransform

  def get_root_bundles(self):
    raise NotImplementedError


class DefaultRootBundleProvider(RootBundleProvider):
  """Provides an empty bundle by default for root transforms."""
  def get_root_bundles(self):
    pipeline_begin = pvalue.PBegin(self._applied_ptransform.transform.pipeline)
    empty_bundle = (
        self._evaluation_context.create_empty_committed_bundle(pipeline_begin))
    return [empty_bundle]
class _TestStreamRootBundleProvider(RootBundleProvider):
  """Provides an initial bundle for the TestStream evaluator.

  This bundle is used as the initial state to the TestStream. Each unprocessed
  bundle emitted from the TestStream afterwards is its state: index into the
  stream, and the watermark.
  """
  def get_root_bundles(self):
    test_stream = self._applied_ptransform.transform

    # Events come either from a TestStreamService endpoint, when one is
    # configured, or from the statically declared script of events.
    if test_stream.endpoint:
      event_source = _TestStream.events_from_rpc(
          test_stream.endpoint,
          test_stream.output_tags,
          test_stream.coder,
          self._evaluation_context)
    else:
      event_source = _TestStream.events_from_script(test_stream._events)
    _TestStreamEvaluator.event_stream = event_source

    initial_bundle = self._evaluation_context.create_bundle(
        pvalue.PBegin(self._applied_ptransform.transform.pipeline))
    initial_bundle.add(GlobalWindows.windowed_value(b'', timestamp=MIN_TIMESTAMP))
    initial_bundle.commit(None)
    return [initial_bundle]


class _TransformEvaluator(object):
  """An evaluator of a specific application of a transform."""

  def __init__(self,
               evaluation_context,  # type: EvaluationContext
               applied_ptransform,  # type: AppliedPTransform
               input_committed_bundle,
               side_inputs
              ):
    self._evaluation_context = evaluation_context
    self._applied_ptransform = applied_ptransform
    self._input_committed_bundle = input_committed_bundle
    self._side_inputs = side_inputs
    self._expand_outputs()
    self._execution_context = evaluation_context.get_execution_context(
        applied_ptransform)
    self._step_context = self._execution_context.get_step_context()

  def _expand_outputs(self):
    """Collects every output PCollection, flattening DoOutputsTuples."""
    flattened = []
    for pval in self._applied_ptransform.outputs.values():
      if isinstance(pval, pvalue.DoOutputsTuple):
        flattened.extend(pval)
      else:
        flattened.append(pval)
    self._outputs = frozenset(flattened)

  def _split_list_into_bundles(
      self,
      output_pcollection,
      elements,
      max_element_per_bundle,
      element_size_fn):
    """Splits elements, an iterable, into multiple output bundles.

    Args:
      output_pcollection: PCollection that the elements belong to.
      elements: elements to be chunked into bundles.
      max_element_per_bundle: (approximately) the maximum element per bundle.
        If it is None, only a single bundle will be produced.
      element_size_fn: Function to return the size of a given element.

    Returns:
      List of output uncommitted bundles with at least one bundle.
    """
    bundles = [self._evaluation_context.create_bundle(output_pcollection)]
    current_size = 0
    for element in elements:
      # Start a fresh bundle once the current one reaches the size cap.
      if max_element_per_bundle and current_size >= max_element_per_bundle:
        bundles.append(
            self._evaluation_context.create_bundle(output_pcollection))
        current_size = 0

      bundles[-1].output(element)
      current_size += element_size_fn(element)
    return bundles

  def start_bundle(self):
    """Starts a new bundle."""
    pass

  def process_timer_wrapper(self, timer_firing):
    """Process timer by clearing and then calling process_timer().

    This method is called with any timer firing and clears the delivered
    timer from the keyed state and then calls process_timer(). The default
    process_timer() implementation emits a KeyedWorkItem for the particular
    timer and passes it to process_element(). Evaluator subclasses which
    desire different timer delivery semantics can override process_timer().
    """
    keyed_state = self._step_context.get_keyed_state(timer_firing.encoded_key)
    keyed_state.clear_timer(
        timer_firing.window,
        timer_firing.name,
        timer_firing.time_domain,
        dynamic_timer_tag=timer_firing.dynamic_timer_tag)
    self.process_timer(timer_firing)

  def process_timer(self, timer_firing):
    """Default process_timer() impl. generating KeyedWorkItem element."""
    work_item = KeyedWorkItem(
        timer_firing.encoded_key, timer_firings=[timer_firing])
    self.process_element(GlobalWindows.windowed_value(work_item))

  def process_element(self, element):
    """Processes a new element as part of the current bundle."""
    raise NotImplementedError('%s do not process elements.' % type(self))

  def finish_bundle(self):
    # type: () -> TransformResult

    """Finishes the bundle and produces output."""
    pass
class _BoundedReadEvaluator(_TransformEvaluator):
  """TransformEvaluator for bounded Read transform."""

  # After some benchmarks, 1000 was optimal among {100,1000,10000}
  MAX_ELEMENT_PER_BUNDLE = 1000

  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    assert not side_inputs
    self._source = applied_ptransform.transform.source
    # Propagate the pipeline options so the source can honor them at read time.
    self._source.pipeline_options = evaluation_context.pipeline_options
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)

  def finish_bundle(self):
    """Reads the whole bounded source and returns the resulting bundles."""
    assert len(self._outputs) == 1
    output_pcollection, = self._outputs

    def _bundles_from_reader(reader):
      # Every element is globally windowed; each counts as size 1 when
      # chunking into bundles of MAX_ELEMENT_PER_BUNDLE.
      windowed = [GlobalWindows.windowed_value(e) for e in reader]
      return self._split_list_into_bundles(
          output_pcollection,
          windowed,
          _BoundedReadEvaluator.MAX_ELEMENT_PER_BUNDLE,
          lambda _: 1)

    if isinstance(self._source, io.iobase.BoundedSource):
      # Getting a RangeTracker for the default range of the source and reading
      # the full source using that.
      range_tracker = self._source.get_range_tracker(None, None)
      bundles = _bundles_from_reader(self._source.read(range_tracker))
    else:
      # Legacy native sources expose a context-managed reader instead.
      with self._source.reader() as reader:
        bundles = _bundles_from_reader(reader)

    return TransformResult(self, bundles, [], None, None)
class _WatermarkControllerEvaluator(_TransformEvaluator):
  """TransformEvaluator for the _WatermarkController transform.

  This is used to enable multiple output watermarks for the TestStream.
  """

  # The state tag used to store the watermark.
  WATERMARK_TAG = _ReadModifyWriteStateTag(
      '_WatermarkControllerEvaluator_Watermark_Tag')

  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    assert not side_inputs
    self.transform = applied_ptransform.transform
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)
    self._init_state()

  def _init_state(self):
    """Gets and sets the initial state.

    This is used to keep track of the watermark hold between calls.
    """
    transform_states = self._evaluation_context._transform_keyed_states
    keyed_states = transform_states[self._applied_ptransform]
    if self.WATERMARK_TAG not in keyed_states:
      # First time through: seed the hold at the minimum timestamp.
      fresh_state = InMemoryUnmergedState()
      fresh_state.set_global_state(self.WATERMARK_TAG, MIN_TIMESTAMP)
      keyed_states[self.WATERMARK_TAG] = fresh_state
    self._state = keyed_states[self.WATERMARK_TAG]

  @property
  def _watermark(self):
    return self._state.get_global_state(self.WATERMARK_TAG)

  @_watermark.setter
  def _watermark(self, watermark):
    self._state.set_global_state(self.WATERMARK_TAG, watermark)

  def start_bundle(self):
    self.bundles = []

  def process_element(self, element):
    # In order to keep the order of the elements between the script and what
    # flows through the pipeline the same, emit the elements here.
    event = element.value
    if isinstance(event, WatermarkEvent):
      self._watermark = event.new_watermark
    elif isinstance(event, ElementEvent):
      main_output = next(iter(self._outputs))
      bundle = self._evaluation_context.create_bundle(main_output)
      for tv in event.timestamped_values:
        # Unreify the value into the correct window.
        if isinstance(tv.value, WindowedValueHolder):
          bundle.output(tv.value.windowed_value)
        else:
          bundle.output(
              GlobalWindows.windowed_value(tv.value, timestamp=tv.timestamp))
      self.bundles.append(bundle)

  def finish_bundle(self):
    # The watermark hold we set here is the way we allow the TestStream events
    # to control the output watermark.
    return TransformResult(
        self, self.bundles, [], None, {None: self._watermark})


class _PairWithTimingEvaluator(_TransformEvaluator):
  """TransformEvaluator for the PairWithTiming transform.

  This transform takes an element as an input and outputs
  KV(element, `TimingInfo`). Where the `TimingInfo` contains both the
  processing time timestamp and watermark.
  """
  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    assert not side_inputs
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)

  def start_bundle(self):
    main_output = next(iter(self._outputs))
    self.bundle = self._evaluation_context.create_bundle(main_output)

    # Snapshot the processing time and output watermark once per bundle.
    watermark_manager = self._evaluation_context._watermark_manager
    watermarks = watermark_manager.get_watermarks(self._applied_ptransform)

    now = Timestamp(seconds=watermark_manager._clock.time())
    self.timing_info = TimingInfo(now, watermarks.output_watermark)

  def process_element(self, element):
    paired = WindowedValue((element.value, self.timing_info),
                           element.timestamp,
                           element.windows,
                           element.pane_info)
    self.bundle.output(paired)

  def finish_bundle(self):
    return TransformResult(self, [self.bundle], [], None, {})
class _TestStreamEvaluator(_TransformEvaluator):
  """TransformEvaluator for the TestStream transform.

  This evaluator's responsibility is to retrieve the next event from the
  _TestStream and either: advance the clock, advance the _TestStream watermark,
  or pass the event to the _WatermarkController.

  The _WatermarkController is in charge of emitting the elements to the
  downstream consumers and setting its own output watermark.
  """

  # Class-level iterator of TestStream events, installed by
  # _TestStreamRootBundleProvider before execution starts.
  event_stream = None

  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    assert not side_inputs
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)
    self.test_stream = applied_ptransform.transform
    # Set to True once the event stream is exhausted; stops the heartbeat.
    self.is_done = False

  def start_bundle(self):
    self.bundles = []
    self.watermark = MIN_TIMESTAMP

  def process_element(self, element):
    """Consumes one heartbeat element and dispatches the next stream event."""
    # The watermark of the _TestStream transform itself.
    self.watermark = element.timestamp

    # Set up the correct watermark holds in the Watermark controllers and the
    # TestStream so that the watermarks will not automatically advance to +inf
    # when elements start streaming. This can happen multiple times in the
    # first bundle, but the operations are idempotent and adding state to keep
    # track of this would add unnecessary code complexity.
    events = []
    if self.watermark == MIN_TIMESTAMP:
      for event in self.test_stream._set_up(self.test_stream.output_tags):
        events.append(event)

    # Retrieve the TestStream's event stream and read from it.
    try:
      events.append(next(self.event_stream))
    except StopIteration:
      # Advance the watermarks to +inf to cleanly stop the pipeline.
      self.is_done = True
      events += ([
          e for e in self.test_stream._tear_down(self.test_stream.output_tags)
      ])

    for event in events:
      # We can either have the _TestStream or the _WatermarkController to emit
      # the elements. We chose to emit in the _WatermarkController so that the
      # element is emitted at the correct watermark value.
      if isinstance(event, (ElementEvent, WatermarkEvent)):
        # The WATERMARK_CONTROL_TAG is used to hold the _TestStream's
        # watermark to -inf, then +inf-1, then +inf. This watermark progression
        # is ultimately used to set up the proper holds to allow the
        # _WatermarkControllers to control their own output watermarks.
        if event.tag == _TestStream.WATERMARK_CONTROL_TAG:
          self.watermark = event.new_watermark
        else:
          # Forward the event itself downstream; the _WatermarkController
          # unpacks it and emits the contained elements.
          main_output = list(self._outputs)[0]
          bundle = self._evaluation_context.create_bundle(main_output)
          bundle.output(GlobalWindows.windowed_value(event))
          self.bundles.append(bundle)
      elif isinstance(event, ProcessingTimeEvent):
        # Advance the DirectRunner's test clock by the requested duration.
        self._evaluation_context._watermark_manager._clock.advance_time(
            event.advance_by)
      else:
        raise ValueError('Invalid TestStream event: %s.' % event)

  def finish_bundle(self):
    unprocessed_bundles = []

    # Continue to send its own state to itself via an unprocessed bundle. This
    # acts as a heartbeat, where each element will read the next event from the
    # event stream.
    if not self.is_done:
      unprocessed_bundle = self._evaluation_context.create_bundle(
          pvalue.PBegin(self._applied_ptransform.transform.pipeline))
      unprocessed_bundle.add(
          GlobalWindows.windowed_value(b'', timestamp=self.watermark))
      unprocessed_bundles.append(unprocessed_bundle)

    # Returning the watermark in the dict here is used as a watermark hold.
    return TransformResult(
        self, self.bundles, unprocessed_bundles, None, {None: self.watermark})
class _PubSubReadEvaluator(_TransformEvaluator):
  """TransformEvaluator for PubSub read."""

  # A mapping of transform to subscription path.
  # TODO(https://github.com/apache/beam/issues/19751): Prevents garbage
  # collection of pipeline instances.
  _subscription_cache = {}  # type: Dict[AppliedPTransform, str]

  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    assert not side_inputs
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)

    self.source = self._applied_ptransform.transform._source  # type: _PubSubSource
    if self.source.id_label:
      raise NotImplementedError(
          'DirectRunner: id_label is not supported for PubSub reads')

    # Prefer the job's --project option for subscription ownership; fall back
    # to the project of the source itself.
    sub_project = None
    if hasattr(self._evaluation_context, 'pipeline_options'):
      from apache_beam.options.pipeline_options import GoogleCloudOptions
      sub_project = (
          self._evaluation_context.pipeline_options.view_as(
              GoogleCloudOptions).project)
    if not sub_project:
      sub_project = self.source.project

    self._sub_name = self.get_subscription(
        self._applied_ptransform,
        self.source.project,
        self.source.topic_name,
        sub_project,
        self.source.subscription_name)

  @classmethod
  def get_subscription(
      cls, transform, project, short_topic_name, sub_project, short_sub_name):
    """Returns the fully-qualified subscription path to read from.

    If the source names a subscription, that one is used. Otherwise a
    temporary, randomly-named subscription is created on the topic (cached
    per transform so repeated calls reuse it, and deleted best-effort at
    interpreter exit).
    """
    from google.cloud import pubsub

    if short_sub_name:
      return pubsub.SubscriberClient.subscription_path(project, short_sub_name)

    if transform in cls._subscription_cache:
      return cls._subscription_cache[transform]

    sub_client = pubsub.SubscriberClient()
    sub_name = sub_client.subscription_path(
        sub_project,
        'beam_%d_%x' % (int(time.time()), random.randrange(1 << 32)))
    topic_name = sub_client.topic_path(project, short_topic_name)
    sub_client.create_subscription(name=sub_name, topic=topic_name)
    atexit.register(sub_client.delete_subscription, subscription=sub_name)
    cls._subscription_cache[transform] = sub_name
    return cls._subscription_cache[transform]

  def start_bundle(self):
    pass

  def process_element(self, element):
    pass

  def _read_from_pubsub(self, timestamp_attribute):
    # type: (...) -> List[Tuple[Timestamp, PubsubMessage]]

    """Pulls up to 10 messages, acks them, and pairs each with a timestamp.

    The timestamp is taken from `timestamp_attribute` (interpreted first as
    epoch milliseconds, then as RFC 3339) when present, otherwise from the
    message's publish time.

    Raises:
      ValueError: if a timestamp attribute or publish time cannot be parsed.
    """
    from apache_beam.io.gcp.pubsub import PubsubMessage
    from google.cloud import pubsub

    def _get_element(message):
      parsed_message = PubsubMessage._from_message(message)
      if (timestamp_attribute and
          timestamp_attribute in parsed_message.attributes):
        rfc3339_or_milli = parsed_message.attributes[timestamp_attribute]
        try:
          # Attribute may be epoch milliseconds ...
          timestamp = Timestamp(micros=int(rfc3339_or_milli) * 1000)
        except ValueError:
          try:
            # ... or an RFC 3339 formatted string.
            timestamp = Timestamp.from_rfc3339(rfc3339_or_milli)
          except ValueError as e:
            raise ValueError('Bad timestamp value: %s' % e)
      else:
        if message.publish_time is None:
          raise ValueError('No publish time present in message: %s' % message)
        try:
          timestamp = Timestamp.from_utc_datetime(message.publish_time)
        except ValueError as e:
          # Fix: the message must be %-formatted here; previously the args
          # were passed as extra ValueError arguments (logging-style), which
          # produced an unformatted error message.
          raise ValueError(
              'Bad timestamp value for message %s: %s' % (message, e))

      return timestamp, parsed_message

    # Because of the AutoAck, we are not able to reread messages if this
    # evaluator fails with an exception before emitting a bundle. However,
    # the DirectRunner currently doesn't retry work items anyway, so the
    # pipeline would enter an inconsistent state on any error.
    sub_client = pubsub.SubscriberClient()
    try:
      response = sub_client.pull(
          subscription=self._sub_name, max_messages=10, timeout=30)
      results = [_get_element(rm.message) for rm in response.received_messages]
      ack_ids = [rm.ack_id for rm in response.received_messages]
      if ack_ids:
        sub_client.acknowledge(subscription=self._sub_name, ack_ids=ack_ids)
    finally:
      sub_client.close()

    return results

  def finish_bundle(self):
    # type: () -> TransformResult
    data = self._read_from_pubsub(self.source.timestamp_attribute)
    if data:
      output_pcollection = list(self._outputs)[0]
      bundle = self._evaluation_context.create_bundle(output_pcollection)
      # TODO(ccy): Respect the PubSub source's id_label field.
      for timestamp, message in data:
        if self.source.with_attributes:
          element = message
        else:
          element = message.data
        bundle.output(
            GlobalWindows.windowed_value(element, timestamp=timestamp))
      bundles = [bundle]
    else:
      bundles = []
    assert self._applied_ptransform.transform is not None
    # Reschedule a read by emitting an unprocessed bundle on the input.
    if self._applied_ptransform.inputs:
      input_pvalue = self._applied_ptransform.inputs[0]
    else:
      input_pvalue = pvalue.PBegin(self._applied_ptransform.transform.pipeline)
    unprocessed_bundle = self._evaluation_context.create_bundle(input_pvalue)

    # TODO(udim): Correct value for watermark hold.
    return TransformResult(
        self,
        bundles, [unprocessed_bundle],
        None, {None: Timestamp.of(time.time())})
class _FlattenEvaluator(_TransformEvaluator):
  """TransformEvaluator for Flatten transform."""
  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    assert not side_inputs
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)

  def start_bundle(self):
    # Flatten has exactly one output; all inputs funnel into one bundle.
    assert len(self._outputs) == 1
    output_pcollection, = self._outputs
    self.bundle = self._evaluation_context.create_bundle(output_pcollection)

  def process_element(self, element):
    # Pass every input element straight through to the single output bundle.
    self.bundle.output(element)

  def finish_bundle(self):
    return TransformResult(self, [self.bundle], [], None, None)


class _ImpulseEvaluator(_TransformEvaluator):
  """TransformEvaluator for Impulse transform."""
  def finish_bundle(self):
    # Impulse emits a single empty byte string in the global window.
    assert len(self._outputs) == 1
    output_pcollection, = self._outputs
    bundle = self._evaluation_context.create_bundle(output_pcollection)
    bundle.output(GlobalWindows.windowed_value(b''))
    return TransformResult(self, [bundle], [], None, None)
class _TaggedReceivers(dict):
  """Received ParDo output and redirect to the associated output bundle.

  Maps output tags to bundle receivers; unknown tags fall back to a shared
  no-op receiver via __missing__.
  """
  def __init__(self, evaluation_context):
    self._evaluation_context = evaluation_context
    self._null_receiver = None
    super().__init__()

  class NullReceiver(common.Receiver):
    """Ignores undeclared outputs, default execution mode."""
    def receive(self, element):
      # type: (WindowedValue) -> None
      pass

  class _InMemoryReceiver(common.Receiver):
    """Buffers undeclared outputs to the given dictionary."""
    def __init__(self, target, tag):
      self._target = target
      self._tag = tag

    def receive(self, element):
      # type: (WindowedValue) -> None
      self._target[self._tag].append(element)

  def __missing__(self, key):
    # Lazily create one shared NullReceiver for all undeclared tags.
    if not self._null_receiver:
      self._null_receiver = _TaggedReceivers.NullReceiver()
    return self._null_receiver


class _ParDoEvaluator(_TransformEvaluator):
  """TransformEvaluator for ParDo transform."""

  def __init__(self,
               evaluation_context,  # type: EvaluationContext
               applied_ptransform,  # type: AppliedPTransform
               input_committed_bundle,
               side_inputs,
               perform_dofn_pickle_test=True
              ):
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)
    # This is a workaround for SDF implementation. SDF implementation adds
    # state to the SDF that is not picklable.
    self._perform_dofn_pickle_test = perform_dofn_pickle_test

  def start_bundle(self):
    """Sets up receivers, user state/timers, and the DoFnRunner."""
    transform = self._applied_ptransform.transform

    self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
    for output_tag in self._applied_ptransform.outputs:
      output_pcollection = pvalue.PCollection(None, tag=output_tag)
      output_pcollection.producer = self._applied_ptransform
      self._tagged_receivers[output_tag] = (
          self._evaluation_context.create_bundle(output_pcollection))
      self._tagged_receivers[output_tag].tag = output_tag

    self._counter_factory = counters.CounterFactory()

    # Roundtrip the DoFn through pickling to flush out pickling errors early.
    # TODO(aaltay): Consider storing the serialized form as an optimization.
    dofn = (
        pickler.loads(pickler.dumps(transform.dofn))
        if self._perform_dofn_pickle_test else transform.dofn)

    args = transform.args if hasattr(transform, 'args') else []
    kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}

    self.user_state_context = None
    self.user_timer_map = {}
    if is_stateful_dofn(dofn):
      # Stateful DoFns need a key coder to encode/decode the state keys.
      kv_type_hint = self._applied_ptransform.inputs[0].element_type
      if kv_type_hint and kv_type_hint != Any:
        coder = coders.registry.get_coder(kv_type_hint)
        self.key_coder = coder.key_coder()
      else:
        self.key_coder = coders.registry.get_coder(Any)

      self.user_state_context = DirectUserStateContext(
          self._step_context, dofn, self.key_coder)
      _, all_timer_specs = get_dofn_specs(dofn)
      for timer_spec in all_timer_specs:
        self.user_timer_map['user/%s' % timer_spec.name] = timer_spec

    self.runner = DoFnRunner(
        dofn,
        args,
        kwargs,
        self._side_inputs,
        self._applied_ptransform.inputs[0].windowing,
        tagged_receivers=self._tagged_receivers,
        step_name=self._applied_ptransform.full_label,
        state=DoFnState(self._counter_factory),
        user_state_context=self.user_state_context)
    self.runner.setup()
    self.runner.start()

  def process_timer(self, timer_firing):
    """Delivers a user timer firing to the DoFnRunner."""
    if timer_firing.name not in self.user_timer_map:
      _LOGGER.warning('Unknown timer fired: %s', timer_firing)
      # Fix: previously execution fell through to the map lookup below and
      # raised a KeyError for the very timer just logged as unknown; skip it
      # instead, as the warning intends.
      return
    timer_spec = self.user_timer_map[timer_firing.name]
    self.runner.process_user_timer(
        timer_spec,
        self.key_coder.decode(timer_firing.encoded_key),
        timer_firing.window,
        timer_firing.timestamp,
        # TODO Add paneinfo to timer_firing in DirectRunner
        None,
        timer_firing.dynamic_timer_tag)

  def process_element(self, element):
    self.runner.process(element)

  def finish_bundle(self):
    """Finishes the DoFn, commits user state, and returns output bundles."""
    self.runner.finish()
    self.runner.teardown()
    bundles = list(self._tagged_receivers.values())
    result_counters = self._counter_factory.get_counters()
    if self.user_state_context:
      self.user_state_context.commit()
      self.user_state_context.reset()
    return TransformResult(self, bundles, [], result_counters, None)
self._counter_factory.get_counters() 893 if self.user_state_context: 894 self.user_state_context.commit() 895 self.user_state_context.reset() 896 return TransformResult(self, bundles, [], result_counters, None) 897 898 899 class _GroupByKeyOnlyEvaluator(_TransformEvaluator): 900 """TransformEvaluator for _GroupByKeyOnly transform.""" 901 902 MAX_ELEMENT_PER_BUNDLE = None 903 ELEMENTS_TAG = _ListStateTag('elements') 904 COMPLETION_TAG = _CombiningValueStateTag('completed', any) 905 906 def __init__( 907 self, 908 evaluation_context, 909 applied_ptransform, 910 input_committed_bundle, 911 side_inputs): 912 assert not side_inputs 913 super().__init__( 914 evaluation_context, 915 applied_ptransform, 916 input_committed_bundle, 917 side_inputs) 918 919 def _is_final_bundle(self): 920 return ( 921 self._execution_context.watermarks.input_watermark == 922 WatermarkManager.WATERMARK_POS_INF) 923 924 def start_bundle(self): 925 self.global_state = self._step_context.get_keyed_state(None) 926 927 assert len(self._outputs) == 1 928 self.output_pcollection = list(self._outputs)[0] 929 930 # The output type of a GroupByKey will be Tuple[Any, Any] or more specific. 931 # TODO(https://github.com/apache/beam/issues/18490): Infer coders earlier. 932 kv_type_hint = ( 933 self._applied_ptransform.outputs[None].element_type or 934 self._applied_ptransform.transform.get_type_hints().input_types[0][0]) 935 self.key_coder = coders.registry.get_coder(kv_type_hint.tuple_types[0]) 936 937 def process_timer(self, timer_firing): 938 # We do not need to emit a KeyedWorkItem to process_element(). 
  def process_element(self, element):
    """Buffer one windowed (key, value) pair into per-key state."""
    # Output must not already have been emitted for this step.
    assert not self.global_state.get_state(
        None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG)
    if (isinstance(element, WindowedValue) and
        isinstance(element.value, abc.Iterable) and len(element.value) == 2):
      k, v = element.value
      # Keys are tracked by their encoded bytes so arbitrary key types can
      # be used as state identifiers.
      encoded_k = self.key_coder.encode(k)
      state = self._step_context.get_keyed_state(encoded_k)
      state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v)
    else:
      raise TypeCheckError(
          'Input to _GroupByKeyOnly must be a PCollection of '
          'windowed key-value pairs. Instead received: %r.' % element)

  def finish_bundle(self):
    """Emit all grouped output once the input watermark reaches +inf."""
    if self._is_final_bundle():
      if self.global_state.get_state(None,
                                     _GroupByKeyOnlyEvaluator.COMPLETION_TAG):
        # Ignore empty bundles after emitting output. (This may happen because
        # empty bundles do not affect input watermarks.)
        bundles = []
      else:
        gbk_result = []
        # TODO(ccy): perhaps we can clean this up to not use this
        # internal attribute of the DirectStepContext.
        for encoded_k in self._step_context.existing_keyed_state:
          # Ignore global state.
          if encoded_k is None:
            continue
          k = self.key_coder.decode(encoded_k)
          state = self._step_context.get_keyed_state(encoded_k)
          vs = state.get_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG)
          gbk_result.append(GlobalWindows.windowed_value((k, vs)))

        def len_element_fn(element):
          # Bundle size is measured by the number of grouped values per key.
          _, v = element.value
          return len(v)

        bundles = self._split_list_into_bundles(
            self.output_pcollection,
            gbk_result,
            _GroupByKeyOnlyEvaluator.MAX_ELEMENT_PER_BUNDLE,
            len_element_fn)

      # Mark completion so that later (empty) final bundles emit nothing.
      self.global_state.add_state(
          None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG, True)
      hold = WatermarkManager.WATERMARK_POS_INF
    else:
      bundles = []
      # Hold the output watermark at -inf until the final bundle, and set a
      # watermark timer so this evaluator fires again at end of input.
      hold = WatermarkManager.WATERMARK_NEG_INF
      self.global_state.set_timer(
          None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF)

    return TransformResult(self, bundles, [], None, {None: hold})


class _StreamingGroupByKeyOnlyEvaluator(_TransformEvaluator):
  """TransformEvaluator for _StreamingGroupByKeyOnly transform.

  The _GroupByKeyOnlyEvaluator buffers elements until its input watermark goes
  to infinity, which is suitable for batch mode execution. During streaming
  mode execution, we emit each bundle as it comes to the next transform.
  """

  # None means no limit: each bundle's output is emitted whole.
  MAX_ELEMENT_PER_BUNDLE = None

  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    assert not side_inputs
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)
  def start_bundle(self):
    # Per-bundle buffer: encoded key -> list of values.
    self.gbk_items = collections.defaultdict(list)

    assert len(self._outputs) == 1
    self.output_pcollection = list(self._outputs)[0]

    # The input type of a GroupByKey will be Tuple[Any, Any] or more specific.
    kv_type_hint = self._applied_ptransform.inputs[0].element_type
    key_type_hint = (kv_type_hint.tuple_types[0] if kv_type_hint else Any)
    self.key_coder = coders.registry.get_coder(key_type_hint)

  def process_element(self, element):
    """Buffer one windowed (key, value) pair for this bundle."""
    if (isinstance(element, WindowedValue) and
        isinstance(element.value, collections.abc.Iterable) and
        len(element.value) == 2):
      k, v = element.value
      # Group under the encoded key so arbitrary key types can be used.
      self.gbk_items[self.key_coder.encode(k)].append(v)
    else:
      raise TypeCheckError(
          'Input to _GroupByKeyOnly must be a PCollection of '
          'windowed key-value pairs. Instead received: %r.' % element)

  def finish_bundle(self):
    """Emit one KeyedWorkItem per key seen in this bundle."""
    bundles = []
    bundle = None
    for encoded_k, vs in self.gbk_items.items():
      # All work items share one output bundle, created lazily so an empty
      # input bundle produces no output bundle at all.
      if not bundle:
        bundle = self._evaluation_context.create_bundle(self.output_pcollection)
        bundles.append(bundle)
      kwi = KeyedWorkItem(encoded_k, elements=vs)
      bundle.add(GlobalWindows.windowed_value(kwi))

    return TransformResult(self, bundles, [], None, None)


class _StreamingGroupAlsoByWindowEvaluator(_TransformEvaluator):
  """TransformEvaluator for the _StreamingGroupAlsoByWindow transform.

  This evaluator is only used in streaming mode. In batch mode, the
  GroupAlsoByWindow operation is evaluated as a normal DoFn, as defined
  in transforms/core.py.
  """
  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    assert not side_inputs
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)

  def start_bundle(self):
    assert len(self._outputs) == 1
    self.output_pcollection = list(self._outputs)[0]
    # NOTE(review): reaches into private attributes of the evaluation context
    # to share the watermark manager's clock with the trigger driver.
    self.driver = create_trigger_driver(
        self._applied_ptransform.transform.windowing,
        clock=self._evaluation_context._watermark_manager._clock)
    self.gabw_items = []
    self.keyed_holds = {}

    # The input type (which is the same as the output type) of a
    # GroupAlsoByWindow will be Tuple[Any, Iter[Any]] or more specific.
    kv_type_hint = self._applied_ptransform.outputs[None].element_type
    key_type_hint = (kv_type_hint.tuple_types[0] if kv_type_hint else Any)
    self.key_coder = coders.registry.get_coder(key_type_hint)

  def process_element(self, element):
    """Feed one KeyedWorkItem's timers and elements through the trigger driver."""
    kwi = element.value
    assert isinstance(kwi, KeyedWorkItem), kwi
    encoded_k, timer_firings, vs = (
        kwi.encoded_key, kwi.timer_firings, kwi.elements)
    k = self.key_coder.decode(encoded_k)
    state = self._step_context.get_keyed_state(encoded_k)

    watermarks = self._evaluation_context._watermark_manager.get_watermarks(
        self._applied_ptransform)
    # Fire pending timers first, then process the new elements.
    for timer_firing in timer_firings:
      for wvalue in self.driver.process_timer(timer_firing.window,
                                              timer_firing.name,
                                              timer_firing.time_domain,
                                              timer_firing.timestamp,
                                              state,
                                              watermarks.input_watermark):
        self.gabw_items.append(wvalue.with_value((k, wvalue.value)))
    if vs:
      for wvalue in self.driver.process_elements(state,
                                                 vs,
                                                 watermarks.output_watermark,
                                                 watermarks.input_watermark):
        self.gabw_items.append(wvalue.with_value((k, wvalue.value)))

    # Record the earliest remaining hold for this key's watermark.
    self.keyed_holds[encoded_k] = state.get_earliest_hold()
  def finish_bundle(self):
    """Package all triggered window outputs into a single output bundle."""
    bundles = []
    if self.gabw_items:
      bundle = self._evaluation_context.create_bundle(self.output_pcollection)
      for item in self.gabw_items:
        bundle.add(item)
      bundles.append(bundle)

    return TransformResult(self, bundles, [], None, self.keyed_holds)


class _NativeWriteEvaluator(_TransformEvaluator):
  """TransformEvaluator for _NativeWrite transform."""

  ELEMENTS_TAG = _ListStateTag('elements')

  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    assert not side_inputs
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)

    assert applied_ptransform.transform.sink
    self._sink = applied_ptransform.transform.sink

  @property
  def _is_final_bundle(self):
    # True once the input watermark has reached +inf: no more input.
    return (
        self._execution_context.watermarks.input_watermark ==
        WatermarkManager.WATERMARK_POS_INF)

  @property
  def _has_already_produced_output(self):
    # True once the output watermark has reached +inf: output was written.
    return (
        self._execution_context.watermarks.output_watermark ==
        WatermarkManager.WATERMARK_POS_INF)

  def start_bundle(self):
    self.global_state = self._step_context.get_keyed_state(None)

  def process_timer(self, timer_firing):
    # We do not need to emit a KeyedWorkItem to process_element().
    pass

  def process_element(self, element):
    # Buffer every incoming element in global state until the final bundle.
    self.global_state.add_state(
        None, _NativeWriteEvaluator.ELEMENTS_TAG, element)

  def finish_bundle(self):
    # finish_bundle will append incoming bundles in memory until all the bundles
    # carrying data is processed. This is done to produce only a single output
    # shard (some tests depends on this behavior). It is possible to have
    # incoming empty bundles after the output is produced, these bundles will be
    # ignored and would not generate additional output files.
    # TODO(altay): Do not wait until the last bundle to write in a single shard.
    if self._is_final_bundle:
      elements = self.global_state.get_state(
          None, _NativeWriteEvaluator.ELEMENTS_TAG)
      if self._has_already_produced_output:
        # Ignore empty bundles that arrive after the output is produced.
        assert elements == []
      else:
        self._sink.pipeline_options = self._evaluation_context.pipeline_options
        with self._sink.writer() as writer:
          for v in elements:
            writer.Write(v.value)
      hold = WatermarkManager.WATERMARK_POS_INF
    else:
      # Not final yet: hold the output watermark back and schedule a
      # watermark timer so this evaluator fires again at end of input.
      hold = WatermarkManager.WATERMARK_NEG_INF
      self.global_state.set_timer(
          None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF)

    return TransformResult(self, [], [], None, {None: hold})


class _ProcessElementsEvaluator(_TransformEvaluator):
  """An evaluator for sdf_direct_runner.ProcessElements transform."""

  # Maximum number of elements that will be produced by a Splittable DoFn before
  # a checkpoint is requested by the runner.
  DEFAULT_MAX_NUM_OUTPUTS = None
  # Maximum duration a Splittable DoFn will process an element before a
  # checkpoint is requested by the runner.
  DEFAULT_MAX_DURATION = 1
1226 transform = applied_ptransform.transform 1227 sdf = transform.sdf 1228 self._process_fn = transform.new_process_fn(sdf) 1229 transform.dofn = self._process_fn 1230 1231 assert isinstance(self._process_fn, ProcessFn) 1232 1233 self._process_fn.step_context = self._step_context 1234 1235 process_element_invoker = ( 1236 SDFProcessElementInvoker( 1237 max_num_outputs=self.DEFAULT_MAX_NUM_OUTPUTS, 1238 max_duration=self.DEFAULT_MAX_DURATION)) 1239 self._process_fn.set_process_element_invoker(process_element_invoker) 1240 1241 self._par_do_evaluator = _ParDoEvaluator( 1242 evaluation_context, 1243 applied_ptransform, 1244 input_committed_bundle, 1245 side_inputs, 1246 perform_dofn_pickle_test=False) 1247 self.keyed_holds = {} 1248 1249 def start_bundle(self): 1250 self._par_do_evaluator.start_bundle() 1251 1252 def process_element(self, element): 1253 assert isinstance(element, WindowedValue) 1254 assert len(element.windows) == 1 1255 window = element.windows[0] 1256 if isinstance(element.value, KeyedWorkItem): 1257 key = element.value.encoded_key 1258 else: 1259 # If not a `KeyedWorkItem`, this must be a tuple where key is a randomly 1260 # generated key and the value is a `WindowedValue` that contains an 1261 # `ElementAndRestriction` object. 
1262 assert isinstance(element.value, tuple) 1263 key = element.value[0] 1264 1265 self._par_do_evaluator.process_element(element) 1266 1267 state = self._step_context.get_keyed_state(key) 1268 self.keyed_holds[key] = state.get_state( 1269 window, self._process_fn.watermark_hold_tag) 1270 1271 def finish_bundle(self): 1272 par_do_result = self._par_do_evaluator.finish_bundle() 1273 1274 transform_result = TransformResult( 1275 self, 1276 par_do_result.uncommitted_output_bundles, 1277 par_do_result.unprocessed_bundles, 1278 par_do_result.counters, 1279 par_do_result.keyed_watermark_holds, 1280 par_do_result.undeclared_tag_values) 1281 for key, keyed_hold in self.keyed_holds.items(): 1282 transform_result.keyed_watermark_holds[key] = keyed_hold 1283 return transform_result