github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/operations.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# cython: language_level=3
# cython: profile=True

"""Worker operations executor."""

# pytype: skip-file
# pylint: disable=super-with-arguments

import collections
import logging
import threading
import warnings
from typing import TYPE_CHECKING
from typing import Any
from typing import DefaultDict
from typing import Dict
from typing import FrozenSet
from typing import Hashable
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Tuple

from apache_beam import coders
from apache_beam.internal import pickler
from apache_beam.io import iobase
from apache_beam.metrics import monitoring_infos
from apache_beam.metrics.cells import DistributionData
from apache_beam.metrics.execution import MetricsContainer
from apache_beam.portability.api import metrics_pb2
from apache_beam.runners import common
from apache_beam.runners.common import Receiver
from apache_beam.runners.worker import opcounters
from apache_beam.runners.worker import operation_specs
from apache_beam.runners.worker import sideinputs
from apache_beam.transforms import sideinputs as apache_sideinputs
from apache_beam.transforms import combiners
from apache_beam.transforms import core
from apache_beam.transforms import userstate
from apache_beam.transforms import window
from apache_beam.transforms.combiners import PhasedCombineFnExecutor
from apache_beam.transforms.combiners import curry_combine_fn
from apache_beam.transforms.window import GlobalWindows
from apache_beam.typehints.batch import BatchConverter
from apache_beam.utils.windowed_value import WindowedBatch
from apache_beam.utils.windowed_value import WindowedValue

if TYPE_CHECKING:
  from apache_beam.runners.sdf_utils import SplitResultPrimary
  from apache_beam.runners.sdf_utils import SplitResultResidual
  from apache_beam.runners.worker.bundle_processor import ExecutionContext
  from apache_beam.runners.worker.statesampler import StateSampler
  from apache_beam.transforms.userstate import TimerSpec

# Allow some "pure mode" declarations.
try:
  import cython
except ImportError:

  class FakeCython(object):
    compiled = False

  globals()['cython'] = FakeCython()

_globally_windowed_value = GlobalWindows.windowed_value(None)
_global_window_type = type(_globally_windowed_value.windows[0])

_LOGGER = logging.getLogger(__name__)

SdfSplitResultsPrimary = Tuple['DoOperation', 'SplitResultPrimary']
SdfSplitResultsResidual = Tuple['DoOperation', 'SplitResultResidual']


# TODO(BEAM-9324) Remove these workarounds once upgraded to Cython 3
def _cast_to_operation(value):
  if cython.compiled:
    return cython.cast(Operation, value)
  else:
    return value


# TODO(BEAM-9324) Remove these workarounds once upgraded to Cython 3
def _cast_to_receiver(value):
  if cython.compiled:
    return cython.cast(Receiver, value)
  else:
    return value

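# Example sketch (not part of the original module) of how a producing
# operation's output edge is typically wired. The names `downstream_op` and
# `counter_factory`, and the coder choice, are assumptions for illustration;
# in the worker they come from the bundle descriptor and the harness.
#
#   receiver = ConsumerSet.create(
#       counter_factory,
#       step_name='s1',
#       output_index=0,
#       consumers=[downstream_op],          # List[Operation]
#       coder=coders.FastPrimitivesCoder(),
#       producer_type_hints=None,
#       producer_batch_converter=None)
#   receiver.receive(GlobalWindows.windowed_value('element'))
#
# With a single element-only consumer and no batch converters involved this
# returns a SingletonElementConsumerSet, which forwards each WindowedValue
# straight to downstream_op.process(); in every other case a
# GeneralPurposeConsumerSet is returned and handles element/batch routing.
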
class ConsumerSet(Receiver):
  """A ConsumerSet represents a graph edge between two Operation nodes.

  The ConsumerSet object collects information from the output of the
  Operation at one end of its edge and the input of the Operation at
  the other end.
  ConsumerSets are attached to the outputting Operation.
  """
  @staticmethod
  def create(counter_factory,
             step_name,  # type: str
             output_index,
             consumers,  # type: List[Operation]
             coder,
             producer_type_hints,
             producer_batch_converter,  # type: Optional[BatchConverter]
            ):
    # type: (...) -> ConsumerSet
    if len(consumers) == 1:
      consumer = consumers[0]

      consumer_batch_preference = consumer.get_batching_preference()
      consumer_batch_converter = consumer.get_input_batch_converter()
      if (not consumer_batch_preference.supports_batches and
          producer_batch_converter is None and
          consumer_batch_converter is None):
        return SingletonElementConsumerSet(
            counter_factory,
            step_name,
            output_index,
            consumer,
            coder,
            producer_type_hints)

    return GeneralPurposeConsumerSet(
        counter_factory,
        step_name,
        output_index,
        coder,
        producer_type_hints,
        consumers,
        producer_batch_converter)

  def __init__(self,
               counter_factory,
               step_name,  # type: str
               output_index,
               consumers,
               coder,
               producer_type_hints,
               producer_batch_converter
              ):
    self.opcounter = opcounters.OperationCounters(
        counter_factory,
        step_name,
        coder,
        output_index,
        producer_type_hints=producer_type_hints,
        producer_batch_converter=producer_batch_converter)
    # Used in repr.
    self.step_name = step_name
    self.output_index = output_index
    self.coder = coder
    self.consumers = consumers

  def try_split(self, fraction_of_remainder):
    # type: (...) -> Optional[Any]
    # TODO(SDF): Consider supporting splitting each consumer individually.
    # This would never come up in the existing SDF expansion, but might
    # be useful to support fused SDF nodes.
    # This would require dedicated delivery of the split results to each
    # of the consumers separately.
    return None

  def current_element_progress(self):
    # type: () -> Optional[iobase.RestrictionProgress]

    """Returns the progress of the current element.

    This progress should be an instance of
    apache_beam.io.iobase.RestrictionProgress, or None if progress is unknown.
    """
    # TODO(SDF): Could implement this as a weighted average, if it becomes
    # useful to split on.
    return None

  def update_counters_start(self, windowed_value):
    # type: (WindowedValue) -> None
    self.opcounter.update_from(windowed_value)

  def update_counters_finish(self):
    # type: () -> None
    self.opcounter.update_collect()

  def update_counters_batch(self, windowed_batch):
    # type: (WindowedBatch) -> None
    self.opcounter.update_from_batch(windowed_batch)

  def __repr__(self):
    return '%s[%s.out%s, coder=%s, len(consumers)=%s]' % (
        self.__class__.__name__,
        self.step_name,
        self.output_index,
        self.coder,
        len(self.consumers))


class SingletonElementConsumerSet(ConsumerSet):
  """ConsumerSet representing a single consumer that can only process elements
  (not batches)."""
  def __init__(self,
               counter_factory,
               step_name,
               output_index,
               consumer,  # type: Operation
               coder,
               producer_type_hints
              ):
    super().__init__(
        counter_factory,
        step_name,
        output_index, [consumer],
        coder,
        producer_type_hints,
        None)
    self.consumer = consumer

  def receive(self, windowed_value):
    # type: (WindowedValue) -> None
    self.update_counters_start(windowed_value)
    self.consumer.process(windowed_value)
    self.update_counters_finish()

  def receive_batch(self, windowed_batch):
    raise AssertionError(
        "SingletonElementConsumerSet.receive_batch is not implemented")

  def flush(self):
    # SingletonElementConsumerSet has no buffer to flush
    pass

  def try_split(self, fraction_of_remainder):
    # type: (...) -> Optional[Any]
    return self.consumer.try_split(fraction_of_remainder)

  def current_element_progress(self):
    return self.consumer.current_element_progress()


class GeneralPurposeConsumerSet(ConsumerSet):
  """ConsumerSet implementation that handles all combinations of possible edges.
  """
  MAX_BATCH_SIZE = 4096

  def __init__(self,
               counter_factory,
               step_name,  # type: str
               output_index,
               coder,
               producer_type_hints,
               consumers,  # type: List[Operation]
               producer_batch_converter):
    super().__init__(
        counter_factory,
        step_name,
        output_index,
        consumers,
        coder,
        producer_type_hints,
        producer_batch_converter)

    self.producer_batch_converter = producer_batch_converter

    # Partition consumers into three groups:
    # - consumers that will be passed elements
    # - consumers that will be passed batches (where their input batch type
    #   matches the output of the producer)
    # - consumers that will be passed converted batches
    self.element_consumers: List[Operation] = []
    self.passthrough_batch_consumers: List[Operation] = []
    other_batch_consumers: DefaultDict[
        BatchConverter, List[Operation]] = collections.defaultdict(lambda: [])

    for consumer in consumers:
      if not consumer.get_batching_preference().supports_batches:
        self.element_consumers.append(consumer)
      elif (consumer.get_input_batch_converter() ==
            self.producer_batch_converter):
        self.passthrough_batch_consumers.append(consumer)
      else:
        # Batch consumer with a mismatched batch type
        if consumer.get_batching_preference().supports_elements:
          # Pass it elements if we can
          self.element_consumers.append(consumer)
        else:
          # As a last resort, explode and rebatch
          consumer_batch_converter = consumer.get_input_batch_converter()
          # This consumer supports batches, it must have a batch converter
          assert consumer_batch_converter is not None
          other_batch_consumers[consumer_batch_converter].append(consumer)

    self.other_batch_consumers: Dict[BatchConverter, List[Operation]] = dict(
        other_batch_consumers)

    self.has_batch_consumers = (
        self.passthrough_batch_consumers or self.other_batch_consumers)
    self._batched_elements: List[Any] = []

  def receive(self, windowed_value):
    # type: (WindowedValue) -> None

    self.update_counters_start(windowed_value)

    for consumer in self.element_consumers:
      _cast_to_operation(consumer).process(windowed_value)

    # TODO: Do this branching when constructing ConsumerSet
    if self.has_batch_consumers:
      self._batched_elements.append(windowed_value)
      if len(self._batched_elements) >= self.MAX_BATCH_SIZE:
        self.flush()

    # TODO(https://github.com/apache/beam/issues/21655): Properly estimate
    # sizes in the batch-consumer only case, this undercounts large iterables
    self.update_counters_finish()

  def receive_batch(self, windowed_batch):
    if self.element_consumers:
      for wv in windowed_batch.as_windowed_values(
          self.producer_batch_converter.explode_batch):
        for consumer in self.element_consumers:
          _cast_to_operation(consumer).process(wv)

    for consumer in self.passthrough_batch_consumers:
      _cast_to_operation(consumer).process_batch(windowed_batch)

    for (consumer_batch_converter,
         consumers) in self.other_batch_consumers.items():
      # Explode and rebatch into the new batch type (ouch!)
      # TODO: Register direct conversions for equivalent batch types

      for consumer in consumers:
        warnings.warn(
            f"Input to operation {consumer} must be rebatched from type "
            f"{self.producer_batch_converter.batch_type!r} to "
            f"{consumer_batch_converter.batch_type!r}.\n"
            "This is very inefficient, consider re-structuring your pipeline "
            "or adding a DoFn to directly convert between these types.",
            InefficientExecutionWarning)
        _cast_to_operation(consumer).process_batch(
            windowed_batch.with_values(
                consumer_batch_converter.produce_batch(
                    self.producer_batch_converter.explode_batch(
                        windowed_batch.values))))

    self.update_counters_batch(windowed_batch)

  def flush(self):
    if not self.has_batch_consumers or not self._batched_elements:
      return

    for batch_converter, consumers in self.other_batch_consumers.items():
      for windowed_batch in WindowedBatch.from_windowed_values(
          self._batched_elements, produce_fn=batch_converter.produce_batch):
        for consumer in consumers:
          _cast_to_operation(consumer).process_batch(windowed_batch)

    for consumer in self.passthrough_batch_consumers:
      for windowed_batch in WindowedBatch.from_windowed_values(
          self._batched_elements,
          produce_fn=self.producer_batch_converter.produce_batch):
        _cast_to_operation(consumer).process_batch(windowed_batch)

    self._batched_elements = []


class Operation(object):
  """An operation representing the live version of a work item specification.

  An operation can have one or more outputs and for each output it can have
  one or more receiver operations that will take that as input.
  """

  def __init__(self,
               name_context,  # type: common.NameContext
               spec,
               counter_factory,
               state_sampler  # type: StateSampler
              ):
    """Initializes a worker operation instance.

    Args:
      name_context: A NameContext instance, with the name information for this
        operation.
      spec: An operation_specs.Worker* instance.
      counter_factory: The CounterFactory to use for our counters.
      state_sampler: The StateSampler for the current operation.
    """
    assert isinstance(name_context, common.NameContext)
    self.name_context = name_context

    self.spec = spec
    self.counter_factory = counter_factory
    self.execution_context = None  # type: Optional[ExecutionContext]
    self.consumers = collections.defaultdict(
        list)  # type: DefaultDict[int, List[Operation]]

    # These are overwritten in the legacy harness.
    self.metrics_container = MetricsContainer(self.name_context.metrics_name())

    self.state_sampler = state_sampler
    self.scoped_start_state = self.state_sampler.scoped_state(
        self.name_context, 'start', metrics_container=self.metrics_container)
    self.scoped_process_state = self.state_sampler.scoped_state(
        self.name_context, 'process', metrics_container=self.metrics_container)
    self.scoped_finish_state = self.state_sampler.scoped_state(
        self.name_context, 'finish', metrics_container=self.metrics_container)
    # TODO(ccy): the '-abort' state can be added when the abort is supported in
    # Operations.
    self.receivers = []  # type: List[ConsumerSet]
    # Legacy workers cannot call setup() until after setting additional state
    # on the operation.
    self.setup_done = False
    self.step_name = None  # type: Optional[str]

  def setup(self):
    # type: () -> None

    """Set up operation.

    This must be called before any other methods of the operation."""
    with self.scoped_start_state:
      self.debug_logging_enabled = logging.getLogger().isEnabledFor(
          logging.DEBUG)
      # Everything except WorkerSideInputSource, which is not a
      # top-level operation, should have output_coders
      #TODO(pabloem): Define better what step name is used here.
      if getattr(self.spec, 'output_coders', None):
        self.receivers = [
            ConsumerSet.create(
                self.counter_factory,
                self.name_context.logging_name(),
                i,
                self.consumers[i],
                coder,
                self._get_runtime_performance_hints(),
                self.get_output_batch_converter(),
            ) for i,
            coder in enumerate(self.spec.output_coders)
        ]
      self.setup_done = True

  def start(self):
    # type: () -> None

    """Start operation."""
    if not self.setup_done:
      # For legacy workers.
      self.setup()

  def get_batching_preference(self):
    # By default operations don't support batching, require Receiver to unbatch
    return common.BatchingPreference.BATCH_FORBIDDEN

  def get_input_batch_converter(self) -> Optional[BatchConverter]:
    """Returns a batch type converter if this operation can accept a batch,
    otherwise None."""
    return None

  def get_output_batch_converter(self) -> Optional[BatchConverter]:
    """Returns a batch type converter if this operation can produce a batch,
    otherwise None."""
    return None

  def process(self, o):
    # type: (WindowedValue) -> None

    """Process element in operation."""
    pass

  def process_batch(self, batch: WindowedBatch):
    pass

  def finalize_bundle(self):
    # type: () -> None
    pass

  def needs_finalization(self):
    return False

  def try_split(self, fraction_of_remainder):
    # type: (...) -> Optional[Any]
    return None

  def current_element_progress(self):
    return None

  def finish(self):
    # type: () -> None

    """Finish operation."""
    for receiver in self.receivers:
      _cast_to_receiver(receiver).flush()

  def teardown(self):
    # type: () -> None

    """Tear down operation.

    No other methods of this operation should be called after this."""
    pass

  def reset(self):
    # type: () -> None
    self.metrics_container.reset()

  def output(self, windowed_value, output_index=0):
    # type: (WindowedValue, int) -> None
    _cast_to_receiver(self.receivers[output_index]).receive(windowed_value)

  def add_receiver(self, operation, output_index=0):
    # type: (Operation, int) -> None

    """Adds a receiver operation for the specified output."""
    self.consumers[output_index].append(operation)

  def monitoring_infos(self, transform_id, tag_to_pcollection_id):
    # type: (str, Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]

    """Returns the list of MonitoringInfos collected by this operation."""
    all_monitoring_infos = self.execution_time_monitoring_infos(transform_id)
    all_monitoring_infos.update(
        self.pcollection_count_monitoring_infos(tag_to_pcollection_id))
    all_monitoring_infos.update(self.user_monitoring_infos(transform_id))
    return all_monitoring_infos

  def pcollection_count_monitoring_infos(self, tag_to_pcollection_id):
    # type: (Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]

    """Returns the element count MonitoringInfo collected by this operation."""

    # Skip producing monitoring infos if there is more than one receiver
    # since there is no way to provide a mapping from tag to pcollection id
    # within Operation.
    if len(self.receivers) != 1 or len(tag_to_pcollection_id) != 1:
      return {}

    all_monitoring_infos = {}
    pcollection_id = next(iter(tag_to_pcollection_id.values()))
    receiver = self.receivers[0]
    elem_count_mi = monitoring_infos.int64_counter(
        monitoring_infos.ELEMENT_COUNT_URN,
        receiver.opcounter.element_counter.value(),
        pcollection=pcollection_id,
    )

    (unused_mean, sum, count, min, max) = (
        receiver.opcounter.mean_byte_counter.value())

    sampled_byte_count = monitoring_infos.int64_distribution(
        monitoring_infos.SAMPLED_BYTE_SIZE_URN,
        DistributionData(sum, count, min, max),
        pcollection=pcollection_id,
    )
    all_monitoring_infos[monitoring_infos.to_key(elem_count_mi)] = elem_count_mi
    all_monitoring_infos[monitoring_infos.to_key(
        sampled_byte_count)] = sampled_byte_count

    return all_monitoring_infos

  def user_monitoring_infos(self, transform_id):
    # type: (str) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]

    """Returns the user MonitoringInfos collected by this operation."""
    return self.metrics_container.to_runner_api_monitoring_infos(transform_id)
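  # Worked example (illustrative numbers) for the sampled byte-size
  # distribution built in pcollection_count_monitoring_infos() above: if the
  # sampled byte counter has observed 4 elements of 10, 20, 30 and 40 bytes,
  # mean_byte_counter.value() yields (25, 100, 4, 10, 40), and the
  # SAMPLED_BYTE_SIZE_URN info carries DistributionData(sum=100, count=4,
  # min=10, max=40); the mean is derivable by the runner and is not shipped.
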
  def execution_time_monitoring_infos(self, transform_id):
    # type: (str) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]
    total_time_spent_msecs = (
        self.scoped_start_state.sampled_msecs_int() +
        self.scoped_process_state.sampled_msecs_int() +
        self.scoped_finish_state.sampled_msecs_int())
    mis = [
        monitoring_infos.int64_counter(
            monitoring_infos.START_BUNDLE_MSECS_URN,
            self.scoped_start_state.sampled_msecs_int(),
            ptransform=transform_id),
        monitoring_infos.int64_counter(
            monitoring_infos.PROCESS_BUNDLE_MSECS_URN,
            self.scoped_process_state.sampled_msecs_int(),
            ptransform=transform_id),
        monitoring_infos.int64_counter(
            monitoring_infos.FINISH_BUNDLE_MSECS_URN,
            self.scoped_finish_state.sampled_msecs_int(),
            ptransform=transform_id),
        monitoring_infos.int64_counter(
            monitoring_infos.TOTAL_MSECS_URN,
            total_time_spent_msecs,
            ptransform=transform_id),
    ]
    return {monitoring_infos.to_key(mi): mi for mi in mis}

  def __str__(self):
    """Generates a useful string for this object.

    Compactly displays interesting fields. In particular, pickled
    fields are not displayed. Note that we collapse the fields of the
    contained Worker* object into this object, since there is a 1-1
    mapping between Operation and operation_specs.Worker*.

    Returns:
      Compact string representing this object.
    """
    return self.str_internal()

  def str_internal(self, is_recursive=False):
    """Internal helper for __str__ that supports recursion.

    When recursing on receivers, keep the output short.
    Args:
      is_recursive: whether to omit some details, particularly receivers.
    Returns:
      Compact string representing this object.
    """
    printable_name = self.__class__.__name__
    if hasattr(self, 'step_name'):
      printable_name += ' %s' % self.name_context.logging_name()
      if is_recursive:
        # If we have a step name, stop here, no more detail needed.
        return '<%s>' % printable_name

    if self.spec is None:
      printable_fields = []
    else:
      printable_fields = operation_specs.worker_printable_fields(self.spec)

    if not is_recursive and getattr(self, 'receivers', []):
      printable_fields.append(
          'receivers=[%s]' %
          ', '.join([str(receiver) for receiver in self.receivers]))

    return '<%s %s>' % (printable_name, ', '.join(printable_fields))

  def _get_runtime_performance_hints(self):
    # type: () -> Optional[Dict[Optional[str], Tuple[Optional[str], Any]]]

    """Returns any type hints required for performance runtime
    type-checking."""
    return None


class ReadOperation(Operation):
  def start(self):
    with self.scoped_start_state:
      super(ReadOperation, self).start()
      range_tracker = self.spec.source.source.get_range_tracker(
          self.spec.source.start_position, self.spec.source.stop_position)
      for value in self.spec.source.source.read(range_tracker):
        if isinstance(value, WindowedValue):
          windowed_value = value
        else:
          windowed_value = _globally_windowed_value.with_value(value)
        self.output(windowed_value)


class ImpulseReadOperation(Operation):
  def __init__(
      self,
      name_context,  # type: common.NameContext
      counter_factory,
      state_sampler,  # type: StateSampler
      consumers,  # type: Mapping[Any, List[Operation]]
      source,  # type: iobase.BoundedSource
      output_coder):
    super(ImpulseReadOperation,
          self).__init__(name_context, None, counter_factory, state_sampler)
    self.source = source

    self.receivers = [
        ConsumerSet.create(
            self.counter_factory,
            self.name_context.step_name,
            0,
            next(iter(consumers.values())),
            output_coder,
            self._get_runtime_performance_hints(),
            self.get_output_batch_converter())
    ]

  def process(self, unused_impulse):
    # type: (WindowedValue) -> None
    with self.scoped_process_state:
      range_tracker = self.source.get_range_tracker(None, None)
      for value in self.source.read(range_tracker):
        if isinstance(value, WindowedValue):
          windowed_value = value
        else:
          windowed_value = _globally_windowed_value.with_value(value)
        self.output(windowed_value)
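# Both read operations above apply the same convention: a value read from a
# BoundedSource is emitted as-is when it is already a WindowedValue, and is
# otherwise wrapped into the global window via
# _globally_windowed_value.with_value(value). A hedged sketch of the
# equivalent wrapping (the `raw_record` name is illustrative only):
#
#   wv = GlobalWindows.windowed_value(None).with_value(raw_record)
#   assert isinstance(wv.windows[0], _global_window_type)
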
class InMemoryWriteOperation(Operation):
  """A write operation that will write to an in-memory sink."""
  def process(self, o):
    # type: (WindowedValue) -> None
    with self.scoped_process_state:
      if self.debug_logging_enabled:
        _LOGGER.debug('Processing [%s] in %s', o, self)
      self.spec.output_buffer.append(
          o if self.spec.write_windowed_values else o.value)


class _TaggedReceivers(dict):
  def __init__(self, counter_factory, step_name):
    self._counter_factory = counter_factory
    self._step_name = step_name

  def __missing__(self, tag):
    self[tag] = receiver = ConsumerSet.create(
        self._counter_factory, self._step_name, tag, [], None, None, None)
    return receiver

  def total_output_bytes(self):
    # type: () -> int
    total = 0
    for receiver in self.values():
      elements = receiver.opcounter.element_counter.value()
      if elements > 0:
        mean = (receiver.opcounter.mean_byte_counter.value())[0]
        total += elements * mean
    return total


OpInputInfo = NamedTuple(
    'OpInputInfo',
    [
        ('transform_id', str),
        ('main_input_tag', str),
        ('main_input_coder', coders.WindowedValueCoder),
        ('outputs', Iterable[str]),
    ])


class DoOperation(Operation):
  """A Do operation that will execute a custom DoFn for each input element."""

  def __init__(self,
               name,  # type: common.NameContext
               spec,  # operation_specs.WorkerDoFn  # need to fix this type
               counter_factory,
               sampler,
               side_input_maps=None,
               user_state_context=None
              ):
    super(DoOperation, self).__init__(name, spec, counter_factory, sampler)
    self.side_input_maps = side_input_maps
    self.user_state_context = user_state_context
    self.tagged_receivers = None  # type: Optional[_TaggedReceivers]
    # A mapping of timer tags to the input "PCollections" they come in on.
    self.input_info = None  # type: Optional[OpInputInfo]

    # See fn_data in dataflow_runner.py
    # TODO: Store all the items from spec?
    self.fn, _, _, _, _ = (pickler.loads(self.spec.serialized_fn))

  def _read_side_inputs(self, tags_and_types):
    # type: (...) -> Iterator[apache_sideinputs.SideInputMap]

    """Generator reading side inputs in the order prescribed by tags_and_types.

    Args:
      tags_and_types: List of tuples (tag, type). Each side input has a string
        tag that is specified in the worker instruction. The type is actually
        a boolean which is True for singleton input (read just first value)
        and False for collection input (read all values).

    Yields:
      With each iteration it yields the result of reading an entire side source
      either in singleton or collection mode according to the tags_and_types
      argument.
    """
    # Only call this on the old path where side_input_maps was not
    # provided directly.
    assert self.side_input_maps is None

    # We will read the side inputs in the order prescribed by the
    # tags_and_types argument because this is exactly the order needed to
    # replace the ArgumentPlaceholder objects in the args/kwargs of the DoFn
    # getting the side inputs.
    #
    # Note that for each tag there could be several read operations in the
    # specification. This can happen for instance if the source has been
    # sharded into several files.
    for i, (side_tag, view_class, view_options) in enumerate(tags_and_types):
      sources = []
      # Using the side_tag in the lambda below will trigger a pylint warning.
      # However in this case it is fine because the lambda is used right away
      # while the variable has the value assigned by the current iteration of
      # the for loop.
      # pylint: disable=cell-var-from-loop
      for si in filter(lambda o: o.tag == side_tag, self.spec.side_inputs):
        if not isinstance(si, operation_specs.WorkerSideInputSource):
          raise NotImplementedError('Unknown side input type: %r' % si)
        sources.append(si.source)
      si_counter = opcounters.SideInputReadCounter(
          self.counter_factory,
          self.state_sampler,
          declaring_step=self.name_context.step_name,
          # Inputs are 1-indexed, so we add 1 to i in the side input id
          input_index=i + 1)
      element_counter = opcounters.OperationCounters(
          self.counter_factory,
          self.name_context.step_name,
          view_options['coder'],
          i,
          suffix='side-input')
      iterator_fn = sideinputs.get_iterator_fn_for_sources(
          sources, read_counter=si_counter, element_counter=element_counter)
      yield apache_sideinputs.SideInputMap(
          view_class, view_options, sideinputs.EmulatedIterable(iterator_fn))

  def setup(self):
    # type: () -> None
    with self.scoped_start_state:
      super(DoOperation, self).setup()

      # See fn_data in dataflow_runner.py
      fn, args, kwargs, tags_and_types, window_fn = (
          pickler.loads(self.spec.serialized_fn))

      state = common.DoFnState(self.counter_factory)
      state.step_name = self.name_context.logging_name()

      # Tag to output index map used to dispatch the output values emitted
      # by the DoFn function to the appropriate receivers. The main output is
      # either the only output or the output tagged with 'None' and is
      # associated with its corresponding index.
      self.tagged_receivers = _TaggedReceivers(
          self.counter_factory, self.name_context.logging_name())

      if len(self.spec.output_tags) == 1:
        self.tagged_receivers[None] = self.receivers[0]
        self.tagged_receivers[self.spec.output_tags[0]] = self.receivers[0]
      else:
        for index, tag in enumerate(self.spec.output_tags):
          self.tagged_receivers[tag] = self.receivers[index]
          if tag == 'None':
            self.tagged_receivers[None] = self.receivers[index]

      if self.user_state_context:
        self.timer_specs = {
            spec.name: spec
            for spec in userstate.get_dofn_specs(fn)[1]
        }  # type: Dict[str, TimerSpec]

      if self.side_input_maps is None:
        if tags_and_types:
          self.side_input_maps = list(self._read_side_inputs(tags_and_types))
        else:
          self.side_input_maps = []

      self.dofn_runner = common.DoFnRunner(
          fn,
          args,
          kwargs,
          self.side_input_maps,
          window_fn,
          tagged_receivers=self.tagged_receivers,
          step_name=self.name_context.logging_name(),
          state=state,
          user_state_context=self.user_state_context,
          operation_name=self.name_context.metrics_name())
      self.dofn_runner.setup()

  def start(self):
    # type: () -> None
    with self.scoped_start_state:
      super(DoOperation, self).start()
      self.dofn_runner.start()

  def get_batching_preference(self):
    if self.fn._process_batch_defined:
      if self.fn._process_defined:
        return common.BatchingPreference.DO_NOT_CARE
      else:
        return common.BatchingPreference.BATCH_REQUIRED
    else:
      return common.BatchingPreference.BATCH_FORBIDDEN

  def get_input_batch_converter(self) -> Optional[BatchConverter]:
    return getattr(self.fn, 'input_batch_converter', None)

  def get_output_batch_converter(self) -> Optional[BatchConverter]:
    return getattr(self.fn, 'output_batch_converter', None)
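  # The three getters above derive the batching contract purely from which
  # methods the user's DoFn defines: only process() -> BATCH_FORBIDDEN, only
  # process_batch() -> BATCH_REQUIRED, both -> DO_NOT_CARE. A hedged sketch of
  # a DoFn that would report DO_NOT_CARE (the class name and the numpy typing
  # are illustrative assumptions, not taken from this module):
  #
  #   class MultiplyDoFn(beam.DoFn):
  #     def process(self, element: np.int64):
  #       yield element * 2
  #
  #     def process_batch(self, batch: np.ndarray):
  #       yield batch * 2
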
  def process(self, o):
    # type: (WindowedValue) -> None
    with self.scoped_process_state:
      delayed_applications = self.dofn_runner.process(o)
      if delayed_applications:
        assert self.execution_context is not None
        for delayed_application in delayed_applications:
          self.execution_context.delayed_applications.append(
              (self, delayed_application))

  def process_batch(self, windowed_batch: WindowedBatch) -> None:
    self.dofn_runner.process_batch(windowed_batch)

  def finalize_bundle(self):
    # type: () -> None
    self.dofn_runner.finalize()

  def needs_finalization(self):
    # type: () -> bool
    return self.dofn_runner.bundle_finalizer_param.has_callbacks()

  def add_timer_info(self, timer_family_id, timer_info):
    self.user_state_context.add_timer_info(timer_family_id, timer_info)

  def process_timer(self, tag, timer_data):
    timer_spec = self.timer_specs[tag]
    self.dofn_runner.process_user_timer(
        timer_spec,
        timer_data.user_key,
        timer_data.windows[0],
        timer_data.fire_timestamp,
        timer_data.paneinfo,
        timer_data.dynamic_timer_tag)

  def finish(self):
    # type: () -> None
    super(DoOperation, self).finish()
    with self.scoped_finish_state:
      self.dofn_runner.finish()
      if self.user_state_context:
        self.user_state_context.commit()

  def teardown(self):
    # type: () -> None
    with self.scoped_finish_state:
      self.dofn_runner.teardown()
      if self.user_state_context:
        self.user_state_context.reset()

  def reset(self):
    # type: () -> None
    super(DoOperation, self).reset()
    for side_input_map in self.side_input_maps:
      side_input_map.reset()
    if self.user_state_context:
      self.user_state_context.reset()
    self.dofn_runner.bundle_finalizer_param.reset()

  def pcollection_count_monitoring_infos(self, tag_to_pcollection_id):
    # type: (Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]

    """Returns the element count MonitoringInfo collected by this operation."""
    infos = super(
        DoOperation,
        self).pcollection_count_monitoring_infos(tag_to_pcollection_id)

    if self.tagged_receivers:
      for tag, receiver in self.tagged_receivers.items():
        if str(tag) not in tag_to_pcollection_id:
          continue
        pcollection_id = tag_to_pcollection_id[str(tag)]
        mi = monitoring_infos.int64_counter(
            monitoring_infos.ELEMENT_COUNT_URN,
            receiver.opcounter.element_counter.value(),
            pcollection=pcollection_id)
        infos[monitoring_infos.to_key(mi)] = mi
        (unused_mean, sum, count, min, max) = (
            receiver.opcounter.mean_byte_counter.value())
        sampled_byte_count = monitoring_infos.int64_distribution(
            monitoring_infos.SAMPLED_BYTE_SIZE_URN,
            DistributionData(sum, count, min, max),
            pcollection=pcollection_id)
        infos[monitoring_infos.to_key(sampled_byte_count)] = sampled_byte_count
    return infos

  def _get_runtime_performance_hints(self):
    fns = pickler.loads(self.spec.serialized_fn)
    if fns and hasattr(fns[0], '_runtime_output_constraints'):
      return fns[0]._runtime_output_constraints

    return {}
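# The two SDF operations below return splits in the shape declared by
# SdfSplitResultsPrimary / SdfSplitResultsResidual at the top of this module:
# a successful try_split(fraction) yields a pair of lists of
# (DoOperation, SplitResult*) tuples, e.g. ([(op, primary)], [(op, residual)]),
# so the bundle processor knows which operation each primary and residual
# belongs to; an unsuccessful split yields None.
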
class SdfTruncateSizedRestrictions(DoOperation):
  def __init__(self, *args, **kwargs):
    super(SdfTruncateSizedRestrictions, self).__init__(*args, **kwargs)

  def current_element_progress(self):
    # type: () -> Optional[iobase.RestrictionProgress]
    return self.receivers[0].current_element_progress()

  def try_split(
      self, fraction_of_remainder
  ):  # type: (...) -> Optional[Tuple[Iterable[SdfSplitResultsPrimary], Iterable[SdfSplitResultsResidual]]]
    return self.receivers[0].try_split(fraction_of_remainder)


class SdfProcessSizedElements(DoOperation):
  def __init__(self, *args, **kwargs):
    super(SdfProcessSizedElements, self).__init__(*args, **kwargs)
    self.lock = threading.RLock()
    self.element_start_output_bytes = None  # type: Optional[int]

  def process(self, o):
    # type: (WindowedValue) -> None
    assert self.tagged_receivers is not None
    with self.scoped_process_state:
      try:
        with self.lock:
          self.element_start_output_bytes = \
              self.tagged_receivers.total_output_bytes()
          for receiver in self.tagged_receivers.values():
            receiver.opcounter.restart_sampling()
        # Actually processing the element can be expensive; do it without
        # the lock.
        delayed_applications = self.dofn_runner.process_with_sized_restriction(
            o)
        if delayed_applications:
          assert self.execution_context is not None
          for delayed_application in delayed_applications:
            self.execution_context.delayed_applications.append(
                (self, delayed_application))
      finally:
        with self.lock:
          self.element_start_output_bytes = None

  def try_split(self, fraction_of_remainder):
    # type: (...) -> Optional[Tuple[Iterable[SdfSplitResultsPrimary], Iterable[SdfSplitResultsResidual]]]
    split = self.dofn_runner.try_split(fraction_of_remainder)
    if split:
      primaries, residuals = split
      return [(self, primary) for primary in primaries
              ], [(self, residual) for residual in residuals]
    return None

  def current_element_progress(self):
    # type: () -> Optional[iobase.RestrictionProgress]
    with self.lock:
      if self.element_start_output_bytes is not None:
        progress = self.dofn_runner.current_element_progress()
        if progress is not None:
          assert self.tagged_receivers is not None
          return progress.with_completed(
              self.tagged_receivers.total_output_bytes() -
              self.element_start_output_bytes)
      return None
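  # Worked example for the accounting above (numbers are illustrative): if
  # tagged_receivers.total_output_bytes() was 1000 when the element started
  # and is 1600 now, then 600 bytes of output are attributed to the current
  # element, and progress.with_completed(600) re-expresses the tracker's
  # reported fraction in those completed units; monitoring_infos() below then
  # encodes the result into the WORK_COMPLETED / WORK_REMAINING payloads.
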
  def monitoring_infos(self, transform_id, tag_to_pcollection_id):
    # type: (str, Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]

    def encode_progress(value):
      # type: (float) -> bytes
      coder = coders.IterableCoder(coders.FloatCoder())
      return coder.encode([value])

    with self.lock:
      infos = super(SdfProcessSizedElements,
                    self).monitoring_infos(transform_id, tag_to_pcollection_id)
      current_element_progress = self.current_element_progress()
      if current_element_progress:
        if current_element_progress.completed_work:
          completed = current_element_progress.completed_work
          remaining = current_element_progress.remaining_work
        else:
          completed = current_element_progress.fraction_completed
          remaining = current_element_progress.fraction_remaining
        assert completed is not None
        assert remaining is not None
        completed_mi = metrics_pb2.MonitoringInfo(
            urn=monitoring_infos.WORK_COMPLETED_URN,
            type=monitoring_infos.PROGRESS_TYPE,
            labels=monitoring_infos.create_labels(ptransform=transform_id),
            payload=encode_progress(completed))
        remaining_mi = metrics_pb2.MonitoringInfo(
            urn=monitoring_infos.WORK_REMAINING_URN,
            type=monitoring_infos.PROGRESS_TYPE,
            labels=monitoring_infos.create_labels(ptransform=transform_id),
            payload=encode_progress(remaining))
        infos[monitoring_infos.to_key(completed_mi)] = completed_mi
        infos[monitoring_infos.to_key(remaining_mi)] = remaining_mi
    return infos


class CombineOperation(Operation):
  """A Combine operation executing a CombineFn for each input element."""
  def __init__(self, name_context, spec, counter_factory, state_sampler):
    super(CombineOperation,
          self).__init__(name_context, spec, counter_factory, state_sampler)
    # Combiners do not accept deferred side-inputs (the ignored fourth argument)
    # and therefore the code to handle the extra args/kwargs is simpler than for
    # the DoFn's of ParDo.
    fn, args, kwargs = pickler.loads(self.spec.serialized_fn)[:3]
    self.phased_combine_fn = (
        PhasedCombineFnExecutor(self.spec.phase, fn, args, kwargs))

  def setup(self):
    # type: () -> None
    with self.scoped_start_state:
      _LOGGER.debug('Setup called for %s', self)
      super(CombineOperation, self).setup()
      self.phased_combine_fn.combine_fn.setup()

  def process(self, o):
    # type: (WindowedValue) -> None
    with self.scoped_process_state:
      if self.debug_logging_enabled:
        _LOGGER.debug('Processing [%s] in %s', o, self)
      key, values = o.value
      self.output(o.with_value((key, self.phased_combine_fn.apply(values))))

  def finish(self):
    # type: () -> None
    _LOGGER.debug('Finishing %s', self)
    super(CombineOperation, self).finish()

  def teardown(self):
    # type: () -> None
    with self.scoped_finish_state:
      _LOGGER.debug('Teardown called for %s', self)
      super(CombineOperation, self).teardown()
      self.phased_combine_fn.combine_fn.teardown()


def create_pgbk_op(step_name, spec, counter_factory, state_sampler):
  if spec.combine_fn:
    return PGBKCVOperation(step_name, spec, counter_factory, state_sampler)
  else:
    return PGBKOperation(step_name, spec, counter_factory, state_sampler)


class PGBKOperation(Operation):
  """Partial group-by-key operation.

  This takes (windowed) input (key, value) tuples and outputs
  (key, [value]) tuples, performing a best effort group-by-key for
  values in this bundle, memory permitting.
  """
  def __init__(self, name_context, spec, counter_factory, state_sampler):
    super(PGBKOperation,
          self).__init__(name_context, spec, counter_factory, state_sampler)
    assert not self.spec.combine_fn
    self.table = collections.defaultdict(list)
    self.size = 0
    # TODO(robertwb) Make this configurable.
    self.max_size = 10 * 1000

  def process(self, o):
    # type: (WindowedValue) -> None
    with self.scoped_process_state:
      # TODO(robertwb): Structural (hashable) values.
      key = o.value[0], tuple(o.windows)
      self.table[key].append(o)
      self.size += 1
      if self.size > self.max_size:
        self.flush(9 * self.max_size // 10)

  def finish(self):
    # type: () -> None
    self.flush(0)
    super().finish()

  def flush(self, target):
    # type: (int) -> None
    limit = self.size - target
    for ix, (kw, vs) in enumerate(list(self.table.items())):
      if ix >= limit:
        break
      del self.table[kw]
      key, windows = kw
      output_value = [v.value[1] for v in vs]
      windowed_value = WindowedValue((key, output_value),
                                     vs[0].timestamp,
                                     windows)
      self.output(windowed_value)


class PGBKCVOperation(Operation):
  def __init__(
      self, name_context, spec, counter_factory, state_sampler, windowing=None):
    super(PGBKCVOperation,
          self).__init__(name_context, spec, counter_factory, state_sampler)
    # Combiners do not accept deferred side-inputs (the ignored fourth
    # argument) and therefore the code to handle the extra args/kwargs is
    # simpler than for the DoFn's of ParDo.
    fn, args, kwargs = pickler.loads(self.spec.combine_fn)[:3]
    self.combine_fn = curry_combine_fn(fn, args, kwargs)
    self.combine_fn_add_input = self.combine_fn.add_input
    if self.combine_fn.compact.__func__ is core.CombineFn.compact:
      self.combine_fn_compact = None
    else:
      self.combine_fn_compact = self.combine_fn.compact
    if windowing:
      self.is_default_windowing = windowing.is_default()
      tsc_type = windowing.timestamp_combiner
      self.timestamp_combiner = (
          None if tsc_type == window.TimestampCombiner.OUTPUT_AT_EOW else
          window.TimestampCombiner.get_impl(tsc_type, windowing.windowfn))
    else:
      self.is_default_windowing = False  # unknown
      self.timestamp_combiner = None
    # Optimization for the (known tiny accumulator, often wide keyspace)
    # combine functions.
    # TODO(b/36567833): Bound by in-memory size rather than key count.
    self.max_keys = (
        1000 * 1000 if
        isinstance(fn, (combiners.CountCombineFn, combiners.MeanCombineFn)) or
        # TODO(b/36597732): Replace this 'or' part by adding the 'cy' optimized
        # combiners to the short list above.
        (
            isinstance(fn, core.CallableWrapperCombineFn) and
            fn._fn in (min, max, sum)) else 100 * 1000)  # pylint: disable=protected-access
    self.key_count = 0
    self.table = {}

  def setup(self):
    # type: () -> None
    with self.scoped_start_state:
      _LOGGER.debug('Setup called for %s', self)
      super(PGBKCVOperation, self).setup()
      self.combine_fn.setup()

  def process(self, wkv):
    # type: (WindowedValue) -> None
    with self.scoped_process_state:
      key, value = wkv.value
      # pylint: disable=unidiomatic-typecheck
      # Optimization for the global window case.
      if self.is_default_windowing:
        wkey = key  # type: Hashable
      else:
        wkey = tuple(wkv.windows), key
      entry = self.table.get(wkey, None)
      if entry is None:
        if self.key_count >= self.max_keys:
          target = self.key_count * 9 // 10
          old_wkeys = []
          # TODO(robertwb): Use an LRU cache?
          for old_wkey, old_wvalue in self.table.items():
            old_wkeys.append(old_wkey)  # Can't mutate while iterating.
            self.output_key(old_wkey, old_wvalue[0], old_wvalue[1])
            self.key_count -= 1
            if self.key_count <= target:
              break
          for old_wkey in reversed(old_wkeys):
            del self.table[old_wkey]
        self.key_count += 1
        # We save the accumulator as a one element list so we can efficiently
        # mutate when new values are added without searching the cache again.
        entry = self.table[wkey] = [self.combine_fn.create_accumulator(), None]
        if not self.is_default_windowing:
          # Conditional as the timestamp attribute is lazily initialized.
          entry[1] = wkv.timestamp
      entry[0] = self.combine_fn_add_input(entry[0], value)
      if not self.is_default_windowing and self.timestamp_combiner:
        entry[1] = self.timestamp_combiner.combine(entry[1], wkv.timestamp)

  def finish(self):
    # type: () -> None
    for wkey, value in self.table.items():
      self.output_key(wkey, value[0], value[1])
    self.table = {}
    self.key_count = 0

  def teardown(self):
    # type: () -> None
    with self.scoped_finish_state:
      _LOGGER.debug('Teardown called for %s', self)
      super(PGBKCVOperation, self).teardown()
      self.combine_fn.teardown()

  def output_key(self, wkey, accumulator, timestamp):
    if self.combine_fn_compact is None:
      value = accumulator
    else:
      value = self.combine_fn_compact(accumulator)

    if self.is_default_windowing:
      self.output(_globally_windowed_value.with_value((wkey, value)))
    else:
      windows, key = wkey
      if self.timestamp_combiner is None:
        timestamp = windows[0].max_timestamp()
      self.output(WindowedValue((key, value), timestamp, windows))


class FlattenOperation(Operation):
  """Flatten operation.

  Receives one or more producer operations, outputs just one list
  with all the items.
  """
  def process(self, o):
    # type: (WindowedValue) -> None
    with self.scoped_process_state:
      if self.debug_logging_enabled:
        _LOGGER.debug('Processing [%s] in %s', o, self)
      self.output(o)


def create_operation(
    name_context,
    spec,
    counter_factory,
    step_name=None,
    state_sampler=None,
    test_shuffle_source=None,
    test_shuffle_sink=None,
    is_streaming=False):
  # type: (...) -> Operation

  """Create Operation object for given operation specification."""

  # TODO(pabloem): Document arguments to this function call.
  if not isinstance(name_context, common.NameContext):
    name_context = common.NameContext(step_name=name_context)

  if isinstance(spec, operation_specs.WorkerRead):
    if isinstance(spec.source, iobase.SourceBundle):
      op = ReadOperation(
          name_context, spec, counter_factory, state_sampler)  # type: Operation
    else:
      from dataflow_worker.native_operations import NativeReadOperation
      op = NativeReadOperation(
          name_context, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerWrite):
    from dataflow_worker.native_operations import NativeWriteOperation
    op = NativeWriteOperation(
        name_context, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerCombineFn):
    op = CombineOperation(name_context, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerPartialGroupByKey):
    op = create_pgbk_op(name_context, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerDoFn):
    op = DoOperation(name_context, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerGroupingShuffleRead):
    from dataflow_worker.shuffle_operations import GroupedShuffleReadOperation
    op = GroupedShuffleReadOperation(
        name_context,
        spec,
        counter_factory,
        state_sampler,
        shuffle_source=test_shuffle_source)
  elif isinstance(spec, operation_specs.WorkerUngroupedShuffleRead):
    from dataflow_worker.shuffle_operations import UngroupedShuffleReadOperation
    op = UngroupedShuffleReadOperation(
        name_context,
        spec,
        counter_factory,
        state_sampler,
        shuffle_source=test_shuffle_source)
  elif isinstance(spec, operation_specs.WorkerInMemoryWrite):
    op = InMemoryWriteOperation(
        name_context, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerShuffleWrite):
    from dataflow_worker.shuffle_operations import ShuffleWriteOperation
    op = ShuffleWriteOperation(
        name_context,
        spec,
        counter_factory,
        state_sampler,
        shuffle_sink=test_shuffle_sink)
  elif isinstance(spec, operation_specs.WorkerFlatten):
    op = FlattenOperation(name_context, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerMergeWindows):
    from dataflow_worker.shuffle_operations import BatchGroupAlsoByWindowsOperation
    from dataflow_worker.shuffle_operations import StreamingGroupAlsoByWindowsOperation
    if is_streaming:
      op = StreamingGroupAlsoByWindowsOperation(
          name_context, spec, counter_factory, state_sampler)
    else:
      op = BatchGroupAlsoByWindowsOperation(
          name_context, spec, counter_factory, state_sampler)
  elif isinstance(spec, operation_specs.WorkerReifyTimestampAndWindows):
    from dataflow_worker.shuffle_operations import ReifyTimestampAndWindowsOperation
    op = ReifyTimestampAndWindowsOperation(
        name_context, spec, counter_factory, state_sampler)
  else:
    raise TypeError(
        'Expected an instance of operation_specs.Worker* class '
        'instead of %s' % (spec, ))
  return op


class SimpleMapTaskExecutor(object):
  """An executor for map tasks.

  Stores progress of the read operation that is the first operation of a map
  task.
  """
  def __init__(
      self,
      map_task,
      counter_factory,
      state_sampler,
      test_shuffle_source=None,
      test_shuffle_sink=None):
    """Initializes SimpleMapTaskExecutor.

    Args:
      map_task: The map task we are to run. The map task contains a list of
        operations, and aligned lists for step_names, original_names,
        system_names of pipeline steps.
      counter_factory: The CounterFactory instance for the work item.
      state_sampler: The StateSampler tracking the execution step.
      test_shuffle_source: Used during tests for dependency injection into
        shuffle read operation objects.
      test_shuffle_sink: Used during tests for dependency injection into
        shuffle write operation objects.
    """

    self._map_task = map_task
    self._counter_factory = counter_factory
    self._ops = []  # type: List[Operation]
    self._state_sampler = state_sampler
    self._test_shuffle_source = test_shuffle_source
    self._test_shuffle_sink = test_shuffle_sink

  def operations(self):
    # type: () -> List[Operation]
    return self._ops[:]

  def execute(self):
    # type: () -> None

    """Executes all the operation_specs.Worker* instructions in a map task.

    We update the map_task with the execution status, expressed as counters.

    Raises:
      RuntimeError: if we find more than one read instruction in task spec.
      TypeError: if the spec parameter is not an instance of the recognized
        operation_specs.Worker* classes.
    """

    # operations is a list of operation_specs.Worker* instances.
    # The order of the elements is important because the inputs use
    # list indexes as references.
    for name_context, spec in zip(self._map_task.name_contexts,
                                  self._map_task.operations):
      # This is used for logging and assigning names to counters.
      op = create_operation(
          name_context,
          spec,
          self._counter_factory,
          None,
          self._state_sampler,
          test_shuffle_source=self._test_shuffle_source,
          test_shuffle_sink=self._test_shuffle_sink)
      self._ops.append(op)

      # Add receiver operations to the appropriate producers.
      if hasattr(op.spec, 'input'):
        producer, output_index = op.spec.input
        self._ops[producer].add_receiver(op, output_index)
      # Flatten has 'inputs', not 'input'
      if hasattr(op.spec, 'inputs'):
        for producer, output_index in op.spec.inputs:
          self._ops[producer].add_receiver(op, output_index)

    for ix, op in reversed(list(enumerate(self._ops))):
      _LOGGER.debug('Starting op %d %s', ix, op)
      op.start()
    for op in self._ops:
      op.finish()


class InefficientExecutionWarning(RuntimeWarning):
  """Warning to indicate an inefficiency in a Beam pipeline."""


# Don't ignore InefficientExecutionWarning, but only log them once
warnings.simplefilter('once', InefficientExecutionWarning)
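
# The 'once' filter above surfaces each distinct rebatching warning a single
# time per process. A pipeline author who wants these treated as hard errors
# while testing can tighten the filter themselves; a minimal sketch using only
# the standard library:
#
#   import warnings
#   from apache_beam.runners.worker.operations import InefficientExecutionWarning
#   warnings.simplefilter('error', InefficientExecutionWarning)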