# Scraped from:
#   github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/common.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: profile=True
# cython: language_level=3

"""Worker operations executor.

For internal use only; no backwards-compatibility guarantees.
"""

# pytype: skip-file

import sys
import threading
import traceback
from enum import Enum
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple

from apache_beam.coders import TupleCoder
from apache_beam.internal import util
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.portability import common_urns
from apache_beam.pvalue import TaggedOutput
from apache_beam.runners.sdf_utils import NoOpWatermarkEstimatorProvider
from apache_beam.runners.sdf_utils import RestrictionTrackerView
from apache_beam.runners.sdf_utils import SplitResultPrimary
from apache_beam.runners.sdf_utils import SplitResultResidual
from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
from apache_beam.runners.sdf_utils import ThreadsafeWatermarkEstimator
from apache_beam.transforms import DoFn
from apache_beam.transforms import core
from apache_beam.transforms import userstate
from apache_beam.transforms.core import RestrictionProvider
from apache_beam.transforms.core import WatermarkEstimatorProvider
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import TimestampedValue
from apache_beam.transforms.window import WindowFn
from apache_beam.typehints import typehints
from apache_beam.typehints.batch import BatchConverter
from apache_beam.utils.counters import Counter
from apache_beam.utils.counters import CounterName
from apache_beam.utils.timestamp import Timestamp
from apache_beam.utils.windowed_value import HomogeneousWindowedBatch
from apache_beam.utils.windowed_value import WindowedBatch
from apache_beam.utils.windowed_value import WindowedValue

if TYPE_CHECKING:
  from apache_beam.transforms import sideinputs
  from apache_beam.transforms.core import TimerSpec
  from apache_beam.io.iobase import RestrictionProgress
  # Same module as RestrictionProgress (previously mistyped as
  # `apache_beam.iobase`).
  from apache_beam.io.iobase import RestrictionTracker
  from apache_beam.io.iobase import WatermarkEstimator


class NameContext(object):
  """Holds the name information for a step.

  Equality and hashing are based on ``step_name`` only; ``transform_id`` is
  not considered.
  """
  def __init__(self, step_name, transform_id=None):
    # type: (str, Optional[str]) -> None

    """Creates a new step NameContext.

    Args:
      step_name: The name of the step.
      transform_id: Optional id of the transform this step belongs to.
    """
    self.step_name = step_name
    self.transform_id = transform_id

  def __eq__(self, other):
    # Objects without a step_name cannot be compared here; returning
    # NotImplemented lets Python fall back to its default handling instead of
    # raising AttributeError (e.g. for `ctx == None`).
    try:
      other_step_name = other.step_name
    except AttributeError:
      return NotImplemented
    return self.step_name == other_step_name

  def __repr__(self):
    return 'NameContext(%s)' % self.__dict__

  def __hash__(self):
    # Must stay consistent with __eq__, which only looks at step_name.
    return hash(self.step_name)

  def metrics_name(self):
    """Returns the step name used for metrics reporting."""
    return self.step_name

  def logging_name(self):
    """Returns the step name used for logging."""
    return self.step_name


class Receiver(object):
  """For internal use only; no backwards-compatibility guarantees.

  An object that consumes a WindowedValue.

  This class can be efficiently used to pass values between the
  sdk and worker harnesses.
  """
  def receive(self, windowed_value):
    # type: (WindowedValue) -> None

    """Consumes a single WindowedValue. Must be overridden by subclasses."""
    raise NotImplementedError

  def receive_batch(self, windowed_batch):
    # type: (WindowedBatch) -> None

    """Consumes a WindowedBatch. Must be overridden by subclasses."""
    raise NotImplementedError

  def flush(self):
    """Flushes any buffered output. Must be overridden by subclasses."""
    raise NotImplementedError


class MethodWrapper(object):
  """For internal use only; no backwards-compatibility guarantees.

  Represents a method that can be invoked by `DoFnInvoker`."""
  def __init__(self, obj_to_invoke, method_name):
    """
    Initiates a ``MethodWrapper``.

    Inspects the method's trailing default-valued parameters and records which
    of them are special DoFn parameters (state, timers, timestamp, window, key,
    restriction tracker, watermark estimator, dynamic timer tag).

    Args:
      obj_to_invoke: the object that contains the method. Has to be a `DoFn`,
                    a `RestrictionProvider` or a `WatermarkEstimatorProvider`.
      method_name: name of the method as a string.

    Raises:
      ValueError: if ``obj_to_invoke`` is not one of the accepted types.
    """

    if not isinstance(obj_to_invoke,
                      (DoFn, RestrictionProvider, WatermarkEstimatorProvider)):
      # Wrap in a 1-tuple so a tuple-valued argument cannot break formatting.
      raise ValueError(
          '\'obj_to_invoke\' has to be a \'DoFn\', a \'RestrictionProvider\' '
          'or a \'WatermarkEstimatorProvider\'. Received %r instead.' %
          (obj_to_invoke, ))

    self.args, self.defaults = core.get_function_arguments(obj_to_invoke,
                                                           method_name)

    # TODO(BEAM-5878) support kwonlyargs on Python 3.
    self.method_value = getattr(obj_to_invoke, method_name)
    self.method_name = method_name

    # Keyword names of parameters that must be filled in at invocation time.
    self.has_userstate_arguments = False
    self.state_args_to_replace = {}  # type: Dict[str, core.StateSpec]
    self.timer_args_to_replace = {}  # type: Dict[str, core.TimerSpec]
    self.timestamp_arg_name = None  # type: Optional[str]
    self.window_arg_name = None  # type: Optional[str]
    self.key_arg_name = None  # type: Optional[str]
    self.restriction_provider = None
    self.restriction_provider_arg_name = None
    self.watermark_estimator_provider = None
    self.watermark_estimator_provider_arg_name = None
    self.dynamic_timer_tag_arg_name = None

    if hasattr(self.method_value, 'unbounded_per_element'):
      self.unbounded_per_element = True
    else:
      self.unbounded_per_element = False

    # DoFn params can only be declared as default values, which belong to the
    # trailing len(self.defaults) argument names.
    for kw, v in zip(self.args[-len(self.defaults):], self.defaults):
      if isinstance(v, core.DoFn.StateParam):
        self.state_args_to_replace[kw] = v.state_spec
        self.has_userstate_arguments = True
      elif isinstance(v, core.DoFn.TimerParam):
        self.timer_args_to_replace[kw] = v.timer_spec
        self.has_userstate_arguments = True
      elif core.DoFn.TimestampParam == v:
        self.timestamp_arg_name = kw
      elif core.DoFn.WindowParam == v:
        self.window_arg_name = kw
      elif core.DoFn.KeyParam == v:
        self.key_arg_name = kw
      elif isinstance(v, core.DoFn.RestrictionParam):
        self.restriction_provider = v.restriction_provider or obj_to_invoke
        self.restriction_provider_arg_name = kw
      elif isinstance(v, core.DoFn.WatermarkEstimatorParam):
        self.watermark_estimator_provider = (
            v.watermark_estimator_provider or obj_to_invoke)
        self.watermark_estimator_provider_arg_name = kw
      elif core.DoFn.DynamicTimerTagParam == v:
        self.dynamic_timer_tag_arg_name = kw

    # Create NoOpWatermarkEstimatorProvider if there is no
    # WatermarkEstimatorParam provided.
    if self.watermark_estimator_provider is None:
      self.watermark_estimator_provider = NoOpWatermarkEstimatorProvider()

  def invoke_timer_callback(
      self,
      user_state_context,
      key,
      window,
      timestamp,
      pane_info,
      dynamic_timer_tag):
    """Invokes the wrapped method as a timer callback.

    Only the parameters this method declared (state, timers, timestamp,
    window, key, dynamic timer tag) are passed, all as keyword arguments.

    Returns:
      Whatever the wrapped method returns.
    """
    # TODO(ccy): support side inputs.
    kwargs = {}
    if self.has_userstate_arguments:
      for kw, state_spec in self.state_args_to_replace.items():
        kwargs[kw] = user_state_context.get_state(state_spec, key, window)
      for kw, timer_spec in self.timer_args_to_replace.items():
        kwargs[kw] = user_state_context.get_timer(
            timer_spec, key, window, timestamp, pane_info)

    if self.timestamp_arg_name:
      kwargs[self.timestamp_arg_name] = Timestamp.of(timestamp)
    if self.window_arg_name:
      kwargs[self.window_arg_name] = window
    if self.key_arg_name:
      kwargs[self.key_arg_name] = key
    if self.dynamic_timer_tag_arg_name:
      kwargs[self.dynamic_timer_tag_arg_name] = dynamic_timer_tag

    if kwargs:
      return self.method_value(**kwargs)
    else:
      return self.method_value()


class BatchingPreference(Enum):
  """Whether an operation consumes batches, single elements, or either."""
  DO_NOT_CARE = 1  # This operation can operate on batches or element-at-a-time
  # TODO: Should we also store batching parameters here? (time/size preferences)
  BATCH_REQUIRED = 2  # This operation can only operate on batches
  BATCH_FORBIDDEN = 3  # This operation can only work element-at-a-time
  # Other possibilities: BATCH_PREFERRED (with min batch size specified)

  # Members are looked up on the class rather than via `self.MEMBER`, which
  # avoids relying on member-to-member attribute access.
  @property
  def supports_batches(self) -> bool:
    return self in (
        BatchingPreference.BATCH_REQUIRED, BatchingPreference.DO_NOT_CARE)

  @property
  def supports_elements(self) -> bool:
    return self in (
        BatchingPreference.BATCH_FORBIDDEN, BatchingPreference.DO_NOT_CARE)

  @property
  def requires_batches(self) -> bool:
    return self == BatchingPreference.BATCH_REQUIRED


class DoFnSignature(object):
  """Represents the signature of a given ``DoFn`` object.

  Signature of a ``DoFn`` provides a view of the properties of a given ``DoFn``.
  Among other things, this will give an extensible way for (1) accessing the
  structure of the ``DoFn`` including methods and method parameters
  (2) identifying features that a given ``DoFn`` support, for example, whether
  a given ``DoFn`` is a Splittable ``DoFn`` (
  https://s.apache.org/splittable-do-fn) (3) validating a ``DoFn`` based on the
  feature set offered by it.
  """
  def __init__(self, do_fn):
    # type: (core.DoFn) -> None
    # We add a property here for all methods defined by Beam DoFn features.

    assert isinstance(do_fn, core.DoFn)
    self.do_fn = do_fn

    self.process_method = MethodWrapper(do_fn, 'process')
    self.process_batch_method = MethodWrapper(do_fn, 'process_batch')
    self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle')
    self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle')
    self.setup_lifecycle_method = MethodWrapper(do_fn, 'setup')
    self.teardown_lifecycle_method = MethodWrapper(do_fn, 'teardown')

    # SDF-related methods only exist when process() declares a restriction
    # param; the watermark estimator provider always exists (MethodWrapper
    # substitutes a no-op provider).
    restriction_provider = self.get_restriction_provider()
    watermark_estimator_provider = self.get_watermark_estimator_provider()
    self.create_watermark_estimator_method = (
        MethodWrapper(
            watermark_estimator_provider, 'create_watermark_estimator'))
    self.initial_restriction_method = (
        MethodWrapper(restriction_provider, 'initial_restriction')
        if restriction_provider else None)
    self.create_tracker_method = (
        MethodWrapper(restriction_provider, 'create_tracker')
        if restriction_provider else None)
    self.split_method = (
        MethodWrapper(restriction_provider, 'split')
        if restriction_provider else None)

    self._validate()

    # Handle stateful DoFns.
    self._is_stateful_dofn = userstate.is_stateful_dofn(do_fn)
    self.timer_methods = {}  # type: Dict[TimerSpec, MethodWrapper]
    if self._is_stateful_dofn:
      # Populate timer firing methods, keyed by TimerSpec.
      _, all_timer_specs = userstate.get_dofn_specs(do_fn)
      for timer_spec in all_timer_specs:
        method = timer_spec._attached_callback
        self.timer_methods[timer_spec] = MethodWrapper(do_fn, method.__name__)

  def get_restriction_provider(self):
    # type: () -> RestrictionProvider

    """Returns the restriction provider declared by process(), or None."""
    return self.process_method.restriction_provider

  def get_watermark_estimator_provider(self):
    # type: () -> WatermarkEstimatorProvider

    """Returns process()'s watermark estimator provider (never None)."""
    return self.process_method.watermark_estimator_provider

  def is_unbounded_per_element(self):
    """Returns whether process() is marked as unbounded per element."""
    return self.process_method.unbounded_per_element

  def _validate(self):
    # type: () -> None
    self._validate_process()
    self._validate_process_batch()
    self._validate_bundle_method(self.start_bundle_method)
    self._validate_bundle_method(self.finish_bundle_method)
    self._validate_stateful_dofn()

  def _check_duplicate_dofn_params(self, method: MethodWrapper):
    """Raises ValueError if ``method`` declares the same DoFn param twice."""
    param_ids = [
        d.param_id for d in method.defaults if isinstance(d, core._DoFnParam)
    ]
    if len(param_ids) != len(set(param_ids)):
      raise ValueError(
          'DoFn %r has duplicate %s method parameters: %s.' %
          (self.do_fn, method.method_name, param_ids))

  def _validate_process(self):
    # type: () -> None

    """Validate that none of the DoFnParameters are repeated in the function
    """
    self._check_duplicate_dofn_params(self.process_method)

  def _validate_process_batch(self):
    # type: () -> None

    """Validates the DoFn params declared by process_batch().

    Raises:
      ValueError: on duplicate params or on params process_batch() does not
        support.
      NotImplementedError: on ElementParam or per-key params (key, state,
        timers), which are not supported yet.
    """
    self._check_duplicate_dofn_params(self.process_batch_method)

    for d in self.process_batch_method.defaults:
      if not isinstance(d, core._DoFnParam):
        continue

      # Helpful errors for params which will be supported in the future
      if d == (core.DoFn.ElementParam):
        # We currently assume we can just get the typehint from the first
        # parameter. ElementParam breaks this assumption
        raise NotImplementedError(
            f"DoFn {self.do_fn!r} uses unsupported DoFn param ElementParam.")

      # StateParam and TimerParam are classes whose *instances* appear as
      # defaults, so they need an isinstance check; KeyParam is a singleton
      # instance compared by equality.
      if (d == core.DoFn.KeyParam or
          isinstance(d, (core.DoFn.StateParam, core.DoFn.TimerParam))):
        raise NotImplementedError(
            f"DoFn {self.do_fn!r} has unsupported per-key DoFn param {d}. "
            "Per-key DoFn params are not yet supported for process_batch "
            "(https://github.com/apache/beam/issues/21653).")

      # Fallback to catch anything not explicitly supported
      if d not in (core.DoFn.WindowParam,
                   core.DoFn.TimestampParam,
                   core.DoFn.PaneInfoParam):
        raise ValueError(
            f"DoFn {self.do_fn!r} has unsupported process_batch "
            f"method parameter {d}")

  def _validate_bundle_method(self, method_wrapper):
    """Validate that none of the DoFnParameters are used in the function
    """
    for param in core.DoFn.DoFnProcessParams:
      if param in method_wrapper.defaults:
        # Report the method name; MethodWrapper has no readable repr.
        raise ValueError(
            'DoFn.process() method-only parameter %s cannot be used in %s.' %
            (param, method_wrapper.method_name))

  def _validate_stateful_dofn(self):
    # type: () -> None
    userstate.validate_stateful_dofn(self.do_fn)

  def is_splittable_dofn(self):
    # type: () -> bool

    """Returns True iff process() declares a restriction provider."""
    return self.get_restriction_provider() is not None

  def get_restriction_coder(self):
    # type: () -> Optional[TupleCoder]

    """Get coder for a restriction when processing an SDF.

    Returns:
      A TupleCoder of (restriction coder, watermark estimator state coder) for
      splittable DoFns, otherwise None.
    """
    if self.is_splittable_dofn():
      return TupleCoder([
          (self.get_restriction_provider().restriction_coder()),
          (self.get_watermark_estimator_provider().estimator_state_coder())
      ])
    else:
      return None

  def is_stateful_dofn(self):
    # type: () -> bool
    return self._is_stateful_dofn

  def has_timers(self):
    # type: () -> bool

    """Returns True iff the DoFn declares at least one timer spec."""
    _, all_timer_specs = userstate.get_dofn_specs(self.do_fn)
    return bool(all_timer_specs)

  def has_bundle_finalization(self):
    """Returns True iff start_bundle, process or finish_bundle declares a
    BundleFinalizerParam."""
    for sig in (self.start_bundle_method,
                self.process_method,
                self.finish_bundle_method):
      for d in sig.defaults:
        try:
          if d == DoFn.BundleFinalizerParam:
            return True
        except Exception:  # pylint: disable=broad-except
          # Default value might be incomparable.
          pass
    return False


class DoFnInvoker(object):
  """An abstraction that can be used to execute DoFn methods.

  A DoFnInvoker describes a particular way for invoking methods of a DoFn
  represented by a given DoFnSignature."""

  def __init__(self,
               output_handler,  # type: _OutputHandler
               signature  # type: DoFnSignature
              ):
    # type: (...) -> None

    """
    Initializes `DoFnInvoker`

    :param output_handler: an OutputHandler for receiving elements produced
                           by invoking functions of the DoFn.
    :param signature: a DoFnSignature for the DoFn being invoked
    """
    self.output_handler = output_handler
    self.signature = signature
    self.user_state_context = None  # type: Optional[userstate.UserStateContext]
    self.bundle_finalizer_param = None  # type: Optional[core._BundleFinalizerParam]

  @staticmethod
  def create_invoker(
      signature,  # type: DoFnSignature
      output_handler,  # type: OutputHandler
      context=None,  # type: Optional[DoFnContext]
      side_inputs=None,  # type: Optional[List[sideinputs.SideInputMap]]
      input_args=None, input_kwargs=None,
      process_invocation=True,
      user_state_context=None,  # type: Optional[userstate.UserStateContext]
      bundle_finalizer_param=None  # type: Optional[core._BundleFinalizerParam]
  ):
    # type: (...) -> DoFnInvoker

    """ Creates a new DoFnInvoker based on given arguments.

    A SimpleInvoker is returned only when nothing beyond the bare element is
    needed (no side inputs, extra args/kwargs, DoFn params or state);
    otherwise a PerWindowInvoker is returned.

    Args:
        output_handler: an OutputHandler for receiving elements produced by
                        invoking functions of the DoFn.
        signature: a DoFnSignature for the DoFn being invoked.
        context: Context to be used when invoking the DoFn (deprecated).
        side_inputs: side inputs to be used when invoking the process method.
        input_args: arguments to be used when invoking the process method. Some
                    of the arguments given here might be placeholders (for
                    example for side inputs) that get filled before invoking the
                    process method.
        input_kwargs: keyword arguments to be used when invoking the process
                      method. Some of the keyword arguments given here might be
                      placeholders (for example for side inputs) that get filled
                      before invoking the process method.
        process_invocation: If True, this function may return an invoker that
                            performs extra optimizations for invoking process()
                            method efficiently.
        user_state_context: The UserStateContext instance for the current
                            Stateful DoFn.
        bundle_finalizer_param: The param passed to a process method, which
                                allows a callback to be registered.

    Raises:
        TypeError: if a PerWindowInvoker is needed but ``context`` is None.
    """
    side_inputs = side_inputs or []
    use_per_window_invoker = process_invocation and (
        side_inputs or input_args or input_kwargs or
        signature.process_method.defaults or
        signature.process_batch_method.defaults or signature.is_stateful_dofn())
    if not use_per_window_invoker:
      return SimpleInvoker(output_handler, signature)
    else:
      if context is None:
        raise TypeError("Must provide context when not using SimpleInvoker")
      return PerWindowInvoker(
          output_handler,
          signature,
          context,
          side_inputs,
          input_args,
          input_kwargs,
          user_state_context,
          bundle_finalizer_param)

  def invoke_process(self,
                     windowed_value,  # type: WindowedValue
                     restriction=None,
                     watermark_estimator_state=None,
                     additional_args=None,
                     additional_kwargs=None
                    ):
    # type: (...) -> Iterable[SplitResultResidual]

    """Invokes the DoFn.process() function.

    Args:
      windowed_value: a WindowedValue object that gives the element for which
                      process() method should be invoked along with the window
                      the element belongs to.
      restriction: The restriction to use when executing this splittable DoFn.
                   Should only be specified for splittable DoFns.
      watermark_estimator_state: The watermark estimator state to use when
                                 executing this splittable DoFn. Should only
                                 be specified for splittable DoFns.
      additional_args: additional arguments to be passed to the current
                      `DoFn.process()` invocation, usually as side inputs.
      additional_kwargs: additional keyword arguments to be passed to the
                         current `DoFn.process()` invocation.
    """
    raise NotImplementedError

  def invoke_process_batch(self,
                           windowed_batch,  # type: WindowedBatch
                           additional_args=None,
                           additional_kwargs=None
                          ):
    # type: (...) -> None

    """Invokes the DoFn.process_batch() function.

    Args:
      windowed_batch: a WindowedBatch object that gives a batch of elements for
                      which process_batch() method should be invoked, along with
                      the window each element belongs to.
      additional_args: additional arguments to be passed to the current
                      `DoFn.process_batch()` invocation, usually as side inputs.
      additional_kwargs: additional keyword arguments to be passed to the
                         current `DoFn.process_batch()` invocation.
    """
    raise NotImplementedError

  def invoke_setup(self):
    # type: () -> None

    """Invokes the DoFn.setup() method
    """
    self.signature.setup_lifecycle_method.method_value()

  def invoke_start_bundle(self):
    # type: () -> None

    """Invokes the DoFn.start_bundle() method.

    Its return value is forwarded to the output handler.
    """
    self.output_handler.start_bundle_outputs(
        self.signature.start_bundle_method.method_value())

  def invoke_finish_bundle(self):
    # type: () -> None

    """Invokes the DoFn.finish_bundle() method.

    Its return value is forwarded to the output handler.
    """
    self.output_handler.finish_bundle_outputs(
        self.signature.finish_bundle_method.method_value())

  def invoke_teardown(self):
    # type: () -> None

    """Invokes the DoFn.teardown() method
    """
    self.signature.teardown_lifecycle_method.method_value()

  def invoke_user_timer(
      self, timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag):
    """Fires the callback registered for ``timer_spec``.

    Outputs are emitted as if they came from a value-less element at
    ``timestamp`` in ``window``.
    """
    # self.output_handler is Optional, but in practice it won't be None here
    self.output_handler.handle_process_outputs(
        WindowedValue(None, timestamp, (window, )),
        self.signature.timer_methods[timer_spec].invoke_timer_callback(
            self.user_state_context,
            key,
            window,
            timestamp,
            pane_info,
            dynamic_timer_tag))

  def invoke_create_watermark_estimator(self, estimator_state):
    """Calls the provider's create_watermark_estimator(estimator_state)."""
    return self.signature.create_watermark_estimator_method.method_value(
        estimator_state)

  def invoke_split(self, element, restriction):
    """Calls the restriction provider's split(element, restriction)."""
    return self.signature.split_method.method_value(element, restriction)

  def invoke_initial_restriction(self, element):
    """Calls the restriction provider's initial_restriction(element)."""
    return self.signature.initial_restriction_method.method_value(element)

  def invoke_create_tracker(self, restriction):
    """Calls the restriction provider's create_tracker(restriction)."""
    return self.signature.create_tracker_method.method_value(restriction)


class SimpleInvoker(DoFnInvoker):
  """An invoker that processes elements ignoring windowing information.

  Only the bare element value (or batch values) is passed to the DoFn.
  """

  def __init__(self,
               output_handler,  # type: OutputHandler
               signature  # type: DoFnSignature
              ):
    # type: (...) -> None
    super().__init__(output_handler, signature)
    self.process_method = signature.process_method.method_value
    self.process_batch_method = signature.process_batch_method.method_value

  def invoke_process(self,
                     windowed_value,  # type: WindowedValue
                     restriction=None,
                     watermark_estimator_state=None,
                     additional_args=None,
                     additional_kwargs=None
                    ):
    # type: (...) -> Iterable[SplitResultResidual]

    """Calls process(value) and forwards its outputs; never has residuals."""
    self.output_handler.handle_process_outputs(
        windowed_value, self.process_method(windowed_value.value))
    return []

  def invoke_process_batch(self,
                           windowed_batch,  # type: WindowedBatch
                           restriction=None,
                           watermark_estimator_state=None,
                           additional_args=None,
                           additional_kwargs=None
                          ):
    # type: (...) -> None

    """Calls process_batch(values) and forwards its outputs."""
    self.output_handler.handle_process_batch_outputs(
        windowed_batch, self.process_batch_method(windowed_batch.values))


def _get_arg_placeholders(
    method: MethodWrapper,
    input_args: Optional[List[Any]],
    input_kwargs: Optional[Dict[str, Any]]):
  """Builds the positional-argument template for invoking a DoFn method.

  Args:
    method: the MethodWrapper of the DoFn method (e.g. process or
      process_batch).
    input_args: extra positional arguments for the method; may contain
      side-input placeholders.
    input_kwargs: extra keyword arguments for the method.

  Returns:
    ``(placeholders, args_with_placeholders, input_kwargs)``, where
    ``args_with_placeholders`` is the positional-argument list in which DoFn
    params (element, key, window, timestamp, pane info, state, timer, bundle
    finalizer) are represented by placeholder objects, ``placeholders`` is a
    list of ``(index, param)`` pairs locating them, and ``input_kwargs`` is
    the given kwargs (or an empty dict).

  Raises:
    ValueError: if a SideInputParam argument has no positional value left and
      is not present in ``input_kwargs``.
  """
  input_args = input_args if input_args else []
  input_kwargs = input_kwargs if input_kwargs else {}

  arg_names = method.args
  default_arg_values = method.defaults

  # Create placeholder for element parameter of DoFn.process() method.
  # Not to be confused with ArgumentPlaceHolder, which may be passed in
  # input_args and is a placeholder for side-inputs.
  class ArgPlaceholder(object):
    def __init__(self, placeholder):
      self.placeholder = placeholder

  if all(core.DoFn.ElementParam != arg for arg in default_arg_values):
    # TODO(https://github.com/apache/beam/issues/19631): Handle cases in which
    #   len(arg_names) == len(default_arg_values).
    args_to_pick = len(arg_names) - len(default_arg_values) - 1
    # Positional argument values for process(), with placeholders for special
    # values such as the element, timestamp, etc.
    args_with_placeholders = ([ArgPlaceholder(core.DoFn.ElementParam)] +
                              input_args[:args_to_pick])
  else:
    args_to_pick = len(arg_names) - len(default_arg_values)
    args_with_placeholders = input_args[:args_to_pick]

  # Fill the OtherPlaceholders for context, key, window or timestamp
  remaining_args_iter = iter(input_args[args_to_pick:])
  for a, d in zip(arg_names[-len(default_arg_values):], default_arg_values):
    if core.DoFn.ElementParam == d:
      args_with_placeholders.append(ArgPlaceholder(d))
    elif core.DoFn.KeyParam == d:
      args_with_placeholders.append(ArgPlaceholder(d))
    elif core.DoFn.WindowParam == d:
      args_with_placeholders.append(ArgPlaceholder(d))
    elif core.DoFn.TimestampParam == d:
      args_with_placeholders.append(ArgPlaceholder(d))
    elif core.DoFn.PaneInfoParam == d:
      args_with_placeholders.append(ArgPlaceholder(d))
    elif core.DoFn.SideInputParam == d:
      # If no more args are present then the value must be passed via kwarg
      try:
        args_with_placeholders.append(next(remaining_args_iter))
      except StopIteration:
        if a not in input_kwargs:
          raise ValueError("Value for sideinput %s not provided" % a)
    elif isinstance(d, core.DoFn.StateParam):
      args_with_placeholders.append(ArgPlaceholder(d))
    elif isinstance(d, core.DoFn.TimerParam):
      args_with_placeholders.append(ArgPlaceholder(d))
    elif isinstance(d, type) and core.DoFn.BundleFinalizerParam == d:
      args_with_placeholders.append(ArgPlaceholder(d))
    else:
      # If no more args are present then the value must be passed via kwarg
      try:
        args_with_placeholders.append(next(remaining_args_iter))
      except StopIteration:
        pass
  args_with_placeholders.extend(list(remaining_args_iter))

  # Stash the list of placeholder positions for performance
  placeholders = [(i, x.placeholder)
                  for (i, x) in enumerate(args_with_placeholders)
                  if isinstance(x, ArgPlaceholder)]

  # NOTE(review): in this scraped copy the remainder of this return tuple
  # (`args_with_placeholders, input_kwargs`) sits at the start of the next
  # physical line; rejoin the two when that part of the file is reformatted.
  return placeholders,
args_with_placeholders, input_kwargs 709 710 711 class PerWindowInvoker(DoFnInvoker): 712 """An invoker that processes elements considering windowing information.""" 713 714 def __init__(self, 715 output_handler, # type: OutputHandler 716 signature, # type: DoFnSignature 717 context, # type: DoFnContext 718 side_inputs, # type: Iterable[sideinputs.SideInputMap] 719 input_args, 720 input_kwargs, 721 user_state_context, # type: Optional[userstate.UserStateContext] 722 bundle_finalizer_param # type: Optional[core._BundleFinalizerParam] 723 ): 724 super().__init__(output_handler, signature) 725 self.side_inputs = side_inputs 726 self.context = context 727 self.process_method = signature.process_method.method_value 728 default_arg_values = signature.process_method.defaults 729 self.has_windowed_inputs = ( 730 not all(si.is_globally_windowed() for si in side_inputs) or any( 731 core.DoFn.WindowParam == arg 732 for arg in signature.process_method.defaults) or any( 733 core.DoFn.WindowParam == arg 734 for arg in signature.process_batch_method.defaults) or 735 signature.is_stateful_dofn()) 736 self.user_state_context = user_state_context 737 self.is_splittable = signature.is_splittable_dofn() 738 self.is_key_param_required = any( 739 core.DoFn.KeyParam == arg for arg in default_arg_values) 740 self.threadsafe_restriction_tracker = None # type: Optional[ThreadsafeRestrictionTracker] 741 self.threadsafe_watermark_estimator = None # type: Optional[ThreadsafeWatermarkEstimator] 742 self.current_windowed_value = None # type: Optional[WindowedValue] 743 self.bundle_finalizer_param = bundle_finalizer_param 744 if self.is_splittable: 745 self.splitting_lock = threading.Lock() 746 self.current_window_index = None 747 self.stop_window_index = None 748 749 # Flag to cache additional arguments on the first element if all 750 # inputs are within the global window. 
751 self.cache_globally_windowed_args = not self.has_windowed_inputs 752 753 # Try to prepare all the arguments that can just be filled in 754 # without any additional work. in the process function. 755 # Also cache all the placeholders needed in the process function. 756 ( 757 self.placeholders_for_process, 758 self.args_for_process, 759 self.kwargs_for_process) = _get_arg_placeholders( 760 signature.process_method, input_args, input_kwargs) 761 762 self.process_batch_method = signature.process_batch_method.method_value 763 764 ( 765 self.placeholders_for_process_batch, 766 self.args_for_process_batch, 767 self.kwargs_for_process_batch) = _get_arg_placeholders( 768 signature.process_batch_method, input_args, input_kwargs) 769 770 def invoke_process(self, 771 windowed_value, # type: WindowedValue 772 restriction=None, 773 watermark_estimator_state=None, 774 additional_args=None, 775 additional_kwargs=None 776 ): 777 # type: (...) -> Iterable[SplitResultResidual] 778 if not additional_args: 779 additional_args = [] 780 if not additional_kwargs: 781 additional_kwargs = {} 782 783 self.context.set_element(windowed_value) 784 # Call for the process function for each window if has windowed side inputs 785 # or if the process accesses the window parameter. We can just call it once 786 # otherwise as none of the arguments are changing 787 788 residuals = [] 789 if self.is_splittable: 790 if restriction is None: 791 # This may be a SDF invoked as an ordinary DoFn on runners that don't 792 # understand SDF. See, e.g. BEAM-11472. 793 # In this case, processing the element is simply processing it against 794 # the entire initial restriction. 
795 restriction = self.signature.initial_restriction_method.method_value( 796 windowed_value.value) 797 798 with self.splitting_lock: 799 self.current_windowed_value = windowed_value 800 self.restriction = restriction 801 self.watermark_estimator_state = watermark_estimator_state 802 try: 803 if self.has_windowed_inputs and len(windowed_value.windows) > 1: 804 for i, w in enumerate(windowed_value.windows): 805 if not self._should_process_window_for_sdf( 806 windowed_value, additional_kwargs, i): 807 break 808 residual = self._invoke_process_per_window( 809 WindowedValue( 810 windowed_value.value, windowed_value.timestamp, (w, )), 811 additional_args, 812 additional_kwargs) 813 if residual: 814 residuals.append(residual) 815 else: 816 if self._should_process_window_for_sdf(windowed_value, 817 additional_kwargs): 818 residual = self._invoke_process_per_window( 819 windowed_value, additional_args, additional_kwargs) 820 if residual: 821 residuals.append(residual) 822 finally: 823 with self.splitting_lock: 824 self.current_windowed_value = None 825 self.restriction = None 826 self.watermark_estimator_state = None 827 self.current_window_index = None 828 self.threadsafe_restriction_tracker = None 829 self.threadsafe_watermark_estimator = None 830 elif self.has_windowed_inputs and len(windowed_value.windows) != 1: 831 for w in windowed_value.windows: 832 self._invoke_process_per_window( 833 WindowedValue( 834 windowed_value.value, windowed_value.timestamp, (w, )), 835 additional_args, 836 additional_kwargs) 837 else: 838 self._invoke_process_per_window( 839 windowed_value, additional_args, additional_kwargs) 840 return residuals 841 842 def invoke_process_batch(self, 843 windowed_batch, # type: WindowedBatch 844 additional_args=None, 845 additional_kwargs=None 846 ): 847 # type: (...) 
-> None 848 849 if not additional_args: 850 additional_args = [] 851 if not additional_kwargs: 852 additional_kwargs = {} 853 854 assert isinstance(windowed_batch, HomogeneousWindowedBatch) 855 856 if self.has_windowed_inputs and len(windowed_batch.windows) != 1: 857 for w in windowed_batch.windows: 858 self._invoke_process_batch_per_window( 859 HomogeneousWindowedBatch.of( 860 windowed_batch.values, 861 windowed_batch.timestamp, (w, ), 862 windowed_batch.pane_info), 863 additional_args, 864 additional_kwargs) 865 else: 866 self._invoke_process_batch_per_window( 867 windowed_batch, additional_args, additional_kwargs) 868 869 def _should_process_window_for_sdf( 870 self, 871 windowed_value, # type: WindowedValue 872 additional_kwargs, 873 window_index=None, # type: Optional[int] 874 ): 875 restriction_tracker = self.invoke_create_tracker(self.restriction) 876 watermark_estimator = self.invoke_create_watermark_estimator( 877 self.watermark_estimator_state) 878 with self.splitting_lock: 879 if window_index: 880 self.current_window_index = window_index 881 if window_index == 0: 882 self.stop_window_index = len(windowed_value.windows) 883 if window_index == self.stop_window_index: 884 return False 885 self.threadsafe_restriction_tracker = ThreadsafeRestrictionTracker( 886 restriction_tracker) 887 self.threadsafe_watermark_estimator = ( 888 ThreadsafeWatermarkEstimator(watermark_estimator)) 889 890 restriction_tracker_param = ( 891 self.signature.process_method.restriction_provider_arg_name) 892 if not restriction_tracker_param: 893 raise ValueError( 894 'DoFn is splittable but DoFn does not have a ' 895 'RestrictionTrackerParam defined') 896 additional_kwargs[restriction_tracker_param] = ( 897 RestrictionTrackerView(self.threadsafe_restriction_tracker)) 898 watermark_param = ( 899 self.signature.process_method.watermark_estimator_provider_arg_name) 900 # When the watermark_estimator is a NoOpWatermarkEstimator, the system 901 # will not add watermark_param into the DoFn 
param list. 902 if watermark_param is not None: 903 additional_kwargs[watermark_param] = self.threadsafe_watermark_estimator 904 return True 905 906 def _invoke_process_per_window(self, 907 windowed_value, # type: WindowedValue 908 additional_args, 909 additional_kwargs, 910 ): 911 # type: (...) -> Optional[SplitResultResidual] 912 913 if self.has_windowed_inputs: 914 assert len(windowed_value.windows) <= 1 915 window, = windowed_value.windows 916 side_inputs = [si[window] for si in self.side_inputs] 917 side_inputs.extend(additional_args) 918 args_for_process, kwargs_for_process = util.insert_values_in_args( 919 self.args_for_process, self.kwargs_for_process, 920 side_inputs) 921 elif self.cache_globally_windowed_args: 922 # Attempt to cache additional args if all inputs are globally 923 # windowed inputs when processing the first element. 924 self.cache_globally_windowed_args = False 925 926 # Fill in sideInputs if they are globally windowed 927 global_window = GlobalWindow() 928 self.args_for_process, self.kwargs_for_process = ( 929 util.insert_values_in_args( 930 self.args_for_process, self.kwargs_for_process, 931 [si[global_window] for si in self.side_inputs])) 932 args_for_process, kwargs_for_process = ( 933 self.args_for_process, self.kwargs_for_process) 934 else: 935 args_for_process, kwargs_for_process = ( 936 self.args_for_process, self.kwargs_for_process) 937 938 # Extract key in the case of a stateful DoFn. Note that in the case of a 939 # stateful DoFn, we set during __init__ self.has_windowed_inputs to be 940 # True. Therefore, windows will be exploded coming into this method, and 941 # we can rely on the window variable being set above. 
942 if self.user_state_context or self.is_key_param_required: 943 try: 944 key, unused_value = windowed_value.value 945 except (TypeError, ValueError): 946 raise ValueError(( 947 'Input value to a stateful DoFn or KeyParam must be a KV tuple; ' 948 'instead, got \'%s\'.') % (windowed_value.value, )) 949 950 for i, p in self.placeholders_for_process: 951 if core.DoFn.ElementParam == p: 952 args_for_process[i] = windowed_value.value 953 elif core.DoFn.KeyParam == p: 954 args_for_process[i] = key 955 elif core.DoFn.WindowParam == p: 956 args_for_process[i] = window 957 elif core.DoFn.TimestampParam == p: 958 args_for_process[i] = windowed_value.timestamp 959 elif core.DoFn.PaneInfoParam == p: 960 args_for_process[i] = windowed_value.pane_info 961 elif isinstance(p, core.DoFn.StateParam): 962 assert self.user_state_context is not None 963 args_for_process[i] = ( 964 self.user_state_context.get_state(p.state_spec, key, window)) 965 elif isinstance(p, core.DoFn.TimerParam): 966 assert self.user_state_context is not None 967 args_for_process[i] = ( 968 self.user_state_context.get_timer( 969 p.timer_spec, 970 key, 971 window, 972 windowed_value.timestamp, 973 windowed_value.pane_info)) 974 elif core.DoFn.BundleFinalizerParam == p: 975 args_for_process[i] = self.bundle_finalizer_param 976 977 kwargs_for_process = kwargs_for_process or {} 978 979 if additional_kwargs: 980 kwargs_for_process.update(additional_kwargs) 981 982 self.output_handler.handle_process_outputs( 983 windowed_value, 984 self.process_method(*args_for_process, **kwargs_for_process), 985 self.threadsafe_watermark_estimator) 986 987 if self.is_splittable: 988 assert self.threadsafe_restriction_tracker is not None 989 self.threadsafe_restriction_tracker.check_done() 990 deferred_status = self.threadsafe_restriction_tracker.deferred_status() 991 if deferred_status: 992 deferred_restriction, deferred_timestamp = deferred_status 993 element = windowed_value.value 994 size = 
self.signature.get_restriction_provider().restriction_size( 995 element, deferred_restriction) 996 if size < 0: 997 raise ValueError('Expected size >= 0 but received %s.' % size) 998 current_watermark = ( 999 self.threadsafe_watermark_estimator.current_watermark()) 1000 estimator_state = ( 1001 self.threadsafe_watermark_estimator.get_estimator_state()) 1002 residual_value = ((element, (deferred_restriction, estimator_state)), 1003 size) 1004 return SplitResultResidual( 1005 residual_value=windowed_value.with_value(residual_value), 1006 current_watermark=current_watermark, 1007 deferred_timestamp=deferred_timestamp) 1008 return None 1009 1010 def _invoke_process_batch_per_window( 1011 self, 1012 windowed_batch: WindowedBatch, 1013 additional_args, 1014 additional_kwargs, 1015 ): 1016 # type: (...) -> Optional[SplitResultResidual] 1017 1018 if self.has_windowed_inputs: 1019 assert isinstance(windowed_batch, HomogeneousWindowedBatch) 1020 assert len(windowed_batch.windows) <= 1 1021 1022 window, = windowed_batch.windows 1023 side_inputs = [si[window] for si in self.side_inputs] 1024 side_inputs.extend(additional_args) 1025 (args_for_process_batch, 1026 kwargs_for_process_batch) = util.insert_values_in_args( 1027 self.args_for_process_batch, 1028 self.kwargs_for_process_batch, 1029 side_inputs) 1030 elif self.cache_globally_windowed_args: 1031 # Attempt to cache additional args if all inputs are globally 1032 # windowed inputs when processing the first element. 
1033 self.cache_globally_windowed_args = False 1034 1035 # Fill in sideInputs if they are globally windowed 1036 global_window = GlobalWindow() 1037 self.args_for_process_batch, self.kwargs_for_process_batch = ( 1038 util.insert_values_in_args( 1039 self.args_for_process_batch, self.kwargs_for_process_batch, 1040 [si[global_window] for si in self.side_inputs])) 1041 args_for_process_batch, kwargs_for_process_batch = ( 1042 self.args_for_process_batch, self.kwargs_for_process_batch) 1043 else: 1044 args_for_process_batch, kwargs_for_process_batch = ( 1045 self.args_for_process_batch, self.kwargs_for_process_batch) 1046 1047 for i, p in self.placeholders_for_process_batch: 1048 if core.DoFn.ElementParam == p: 1049 args_for_process_batch[i] = windowed_batch.values 1050 elif core.DoFn.KeyParam == p: 1051 raise NotImplementedError( 1052 "https://github.com/apache/beam/issues/21653: " 1053 "Per-key process_batch") 1054 elif core.DoFn.WindowParam == p: 1055 args_for_process_batch[i] = window 1056 elif core.DoFn.TimestampParam == p: 1057 args_for_process_batch[i] = windowed_batch.timestamp 1058 elif core.DoFn.PaneInfoParam == p: 1059 assert isinstance(windowed_batch, HomogeneousWindowedBatch) 1060 args_for_process_batch[i] = windowed_batch.pane_info 1061 elif isinstance(p, core.DoFn.StateParam): 1062 raise NotImplementedError( 1063 "https://github.com/apache/beam/issues/21653: " 1064 "Per-key process_batch") 1065 elif isinstance(p, core.DoFn.TimerParam): 1066 raise NotImplementedError( 1067 "https://github.com/apache/beam/issues/21653: " 1068 "Per-key process_batch") 1069 1070 kwargs_for_process_batch = kwargs_for_process_batch or {} 1071 1072 self.output_handler.handle_process_batch_outputs( 1073 windowed_batch, 1074 self.process_batch_method( 1075 *args_for_process_batch, **kwargs_for_process_batch), 1076 self.threadsafe_watermark_estimator) 1077 1078 @staticmethod 1079 def _try_split(fraction, 1080 window_index, # type: Optional[int] 1081 stop_window_index, # type: 
Optional[int] 1082 windowed_value, # type: WindowedValue 1083 restriction, 1084 watermark_estimator_state, 1085 restriction_provider, # type: RestrictionProvider 1086 restriction_tracker, # type: RestrictionTracker 1087 watermark_estimator, # type: WatermarkEstimator 1088 ): 1089 # type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual], Optional[int]]] 1090 1091 """Try to split returning a primaries, residuals and a new stop index. 1092 1093 For non-window observing splittable DoFns we split the current restriction 1094 and assign the primary and residual to all the windows. 1095 1096 For window observing splittable DoFns, we: 1097 1) return a split at a window boundary if the fraction lies outside of the 1098 current window. 1099 2) attempt to split the current restriction, if successful then return 1100 the primary and residual for the current window and an additional 1101 primary and residual for any fully processed and fully unprocessed 1102 windows. 1103 3) fall back to returning a split at the window boundary if possible 1104 1105 Args: 1106 window_index: the current index of the window being processed or None 1107 if the splittable DoFn is not window observing. 1108 stop_window_index: the current index to stop processing at or None 1109 if the splittable DoFn is not window observing. 1110 windowed_value: the current windowed value 1111 restriction: the initial restriction when processing was started. 1112 watermark_estimator_state: the initial watermark estimator state when 1113 processing was started. 1114 restriction_provider: the DoFn's restriction provider 1115 restriction_tracker: the current restriction tracker 1116 watermark_estimator: the current watermark estimator 1117 1118 Returns: 1119 A tuple containing (primaries, residuals, new_stop_index) or None if 1120 splitting was not possible. new_stop_index will only be set if the 1121 splittable DoFn is window observing otherwise it will be None. 
1122 """ 1123 def compute_whole_window_split(to_index, from_index): 1124 restriction_size = restriction_provider.restriction_size( 1125 windowed_value, restriction) 1126 if restriction_size < 0: 1127 raise ValueError( 1128 'Expected size >= 0 but received %s.' % restriction_size) 1129 # The primary and residual both share the same value only differing 1130 # by the set of windows they are in. 1131 value = ((windowed_value.value, (restriction, watermark_estimator_state)), 1132 restriction_size) 1133 primary_restriction = SplitResultPrimary( 1134 primary_value=WindowedValue( 1135 value, 1136 windowed_value.timestamp, 1137 windowed_value.windows[:to_index])) if to_index > 0 else None 1138 # Don't report any updated watermarks for the residual since they have 1139 # not processed any part of the restriction. 1140 residual_restriction = SplitResultResidual( 1141 residual_value=WindowedValue( 1142 value, 1143 windowed_value.timestamp, 1144 windowed_value.windows[from_index:stop_window_index]), 1145 current_watermark=None, 1146 deferred_timestamp=None) if from_index < stop_window_index else None 1147 return (primary_restriction, residual_restriction) 1148 1149 primary_restrictions = [] 1150 residual_restrictions = [] 1151 1152 window_observing = window_index is not None 1153 # If we are processing each window separately and we aren't on the last 1154 # window then compute whether the split lies within the current window 1155 # or a future window. 1156 if window_observing and window_index != stop_window_index - 1: 1157 progress = restriction_tracker.current_progress() 1158 if not progress: 1159 # Assume no work has been completed for the current window if progress 1160 # is unavailable. 
1161 from apache_beam.io.iobase import RestrictionProgress 1162 progress = RestrictionProgress(completed=0, remaining=1) 1163 1164 scaled_progress = PerWindowInvoker._scale_progress( 1165 progress, window_index, stop_window_index) 1166 # Compute the fraction of the remainder relative to the scaled progress. 1167 # If the value is greater than or equal to progress.remaining_work then we 1168 # should split at the closest window boundary. 1169 fraction_of_remainder = scaled_progress.remaining_work * fraction 1170 if fraction_of_remainder >= progress.remaining_work: 1171 # The fraction is outside of the current window and hence we will 1172 # split at the closest window boundary. Favor a split and return the 1173 # last window if we would have rounded up to the end of the window 1174 # based upon the fraction. 1175 new_stop_window_index = min( 1176 stop_window_index - 1, 1177 window_index + max( 1178 1, 1179 int( 1180 round(( 1181 progress.completed_work + 1182 scaled_progress.remaining_work * fraction) / 1183 progress.total_work)))) 1184 primary, residual = compute_whole_window_split( 1185 new_stop_window_index, new_stop_window_index) 1186 assert primary is not None 1187 assert residual is not None 1188 return ([primary], [residual], new_stop_window_index) 1189 else: 1190 # The fraction is within the current window being processed so compute 1191 # the updated fraction based upon the number of windows being processed. 1192 new_stop_window_index = window_index + 1 1193 fraction = fraction_of_remainder / progress.remaining_work 1194 # Attempt to split below, if we can't then we'll compute a split 1195 # using only window boundaries 1196 else: 1197 # We aren't splitting within multiple windows so we don't change our 1198 # stop index. 1199 new_stop_window_index = stop_window_index 1200 1201 # Temporary workaround for [BEAM-7473]: get current_watermark before 1202 # split, in case watermark gets advanced before getting split results. 
1203 # In worst case, current_watermark is always stale, which is ok. 1204 current_watermark = (watermark_estimator.current_watermark()) 1205 current_estimator_state = (watermark_estimator.get_estimator_state()) 1206 split = restriction_tracker.try_split(fraction) 1207 if split: 1208 primary, residual = split 1209 element = windowed_value.value 1210 primary_size = restriction_provider.restriction_size( 1211 windowed_value.value, primary) 1212 if primary_size < 0: 1213 raise ValueError('Expected size >= 0 but received %s.' % primary_size) 1214 residual_size = restriction_provider.restriction_size( 1215 windowed_value.value, residual) 1216 if residual_size < 0: 1217 raise ValueError('Expected size >= 0 but received %s.' % residual_size) 1218 # We use the watermark estimator state for the original process call 1219 # for the primary and the updated watermark estimator state for the 1220 # residual for the split. 1221 primary_split_value = ((element, (primary, watermark_estimator_state)), 1222 primary_size) 1223 residual_split_value = ((element, (residual, current_estimator_state)), 1224 residual_size) 1225 windows = ( 1226 windowed_value.windows[window_index], 1227 ) if window_observing else windowed_value.windows 1228 primary_restrictions.append( 1229 SplitResultPrimary( 1230 primary_value=WindowedValue( 1231 primary_split_value, windowed_value.timestamp, windows))) 1232 residual_restrictions.append( 1233 SplitResultResidual( 1234 residual_value=WindowedValue( 1235 residual_split_value, windowed_value.timestamp, windows), 1236 current_watermark=current_watermark, 1237 deferred_timestamp=None)) 1238 1239 if window_observing: 1240 assert new_stop_window_index == window_index + 1 1241 primary, residual = compute_whole_window_split( 1242 window_index, window_index + 1) 1243 if primary: 1244 primary_restrictions.append(primary) 1245 if residual: 1246 residual_restrictions.append(residual) 1247 return ( 1248 primary_restrictions, residual_restrictions, 
new_stop_window_index) 1249 elif new_stop_window_index and new_stop_window_index != stop_window_index: 1250 # If we failed to split but have a new stop index then return a split 1251 # at the window boundary. 1252 primary, residual = compute_whole_window_split( 1253 new_stop_window_index, new_stop_window_index) 1254 assert primary is not None 1255 assert residual is not None 1256 return ([primary], [residual], new_stop_window_index) 1257 else: 1258 return None 1259 1260 def try_split(self, fraction): 1261 # type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual]]] 1262 if not self.is_splittable: 1263 return None 1264 1265 with self.splitting_lock: 1266 if not self.threadsafe_restriction_tracker: 1267 return None 1268 1269 # Make a local reference to member variables that change references during 1270 # processing under lock before attempting to split so we have a consistent 1271 # view of all the references. 1272 result = PerWindowInvoker._try_split( 1273 fraction, 1274 self.current_window_index, 1275 self.stop_window_index, 1276 self.current_windowed_value, 1277 self.restriction, 1278 self.watermark_estimator_state, 1279 self.signature.get_restriction_provider(), 1280 self.threadsafe_restriction_tracker, 1281 self.threadsafe_watermark_estimator) 1282 if not result: 1283 return None 1284 1285 residuals, primaries, self.stop_window_index = result 1286 return (residuals, primaries) 1287 1288 @staticmethod 1289 def _scale_progress(progress, window_index, stop_window_index): 1290 # We scale progress based upon the amount of work we will do for one 1291 # window and have it apply for all windows. 
1292 completed = window_index * progress.total_work + progress.completed_work 1293 remaining = ( 1294 stop_window_index - 1295 (window_index + 1)) * progress.total_work + progress.remaining_work 1296 from apache_beam.io.iobase import RestrictionProgress 1297 return RestrictionProgress(completed=completed, remaining=remaining) 1298 1299 def current_element_progress(self): 1300 # type: () -> Optional[RestrictionProgress] 1301 if not self.is_splittable: 1302 return None 1303 1304 with self.splitting_lock: 1305 current_window_index = self.current_window_index 1306 stop_window_index = self.stop_window_index 1307 threadsafe_restriction_tracker = self.threadsafe_restriction_tracker 1308 1309 if not threadsafe_restriction_tracker: 1310 return None 1311 1312 progress = threadsafe_restriction_tracker.current_progress() 1313 if not current_window_index or not progress: 1314 return progress 1315 1316 # stop_window_index should always be set if current_window_index is set, 1317 # it is an error otherwise. 1318 assert stop_window_index 1319 return PerWindowInvoker._scale_progress( 1320 progress, current_window_index, stop_window_index) 1321 1322 1323 class DoFnRunner: 1324 """For internal use only; no backwards-compatibility guarantees. 1325 1326 A helper class for executing ParDo operations. 1327 """ 1328 1329 def __init__(self, 1330 fn, # type: core.DoFn 1331 args, 1332 kwargs, 1333 side_inputs, # type: Iterable[sideinputs.SideInputMap] 1334 windowing, 1335 tagged_receivers, # type: Mapping[Optional[str], Receiver] 1336 step_name=None, # type: Optional[str] 1337 logging_context=None, 1338 state=None, 1339 scoped_metrics_container=None, 1340 operation_name=None, 1341 user_state_context=None # type: Optional[userstate.UserStateContext] 1342 ): 1343 """Initializes a DoFnRunner. 
1344 1345 Args: 1346 fn: user DoFn to invoke 1347 args: positional side input arguments (static and placeholder), if any 1348 kwargs: keyword side input arguments (static and placeholder), if any 1349 side_inputs: list of sideinput.SideInputMaps for deferred side inputs 1350 windowing: windowing properties of the output PCollection(s) 1351 tagged_receivers: a dict of tag name to Receiver objects 1352 step_name: the name of this step 1353 logging_context: DEPRECATED [BEAM-4728] 1354 state: handle for accessing DoFn state 1355 scoped_metrics_container: DEPRECATED 1356 operation_name: The system name assigned by the runner for this operation. 1357 user_state_context: The UserStateContext instance for the current 1358 Stateful DoFn. 1359 """ 1360 # Need to support multiple iterations. 1361 side_inputs = list(side_inputs) 1362 1363 self.step_name = step_name 1364 self.context = DoFnContext(step_name, state=state) 1365 self.bundle_finalizer_param = DoFn.BundleFinalizerParam() 1366 1367 do_fn_signature = DoFnSignature(fn) 1368 1369 # Optimize for the common case. 1370 main_receivers = tagged_receivers[None] 1371 1372 # TODO(https://github.com/apache/beam/issues/18886): Remove if block after 1373 # output counter released. 1374 if 'outputs_per_element_counter' in RuntimeValueProvider.experiments: 1375 # TODO(BEAM-3955): Make step_name and operation_name less confused. 
1376 output_counter_name = ( 1377 CounterName('per-element-output-count', step_name=operation_name)) 1378 per_element_output_counter = state._counter_factory.get_counter( 1379 output_counter_name, Counter.DATAFLOW_DISTRIBUTION).accumulator 1380 else: 1381 per_element_output_counter = None 1382 1383 output_handler = _OutputHandler( 1384 windowing.windowfn, 1385 main_receivers, 1386 tagged_receivers, 1387 per_element_output_counter, 1388 getattr(fn, 'output_batch_converter', None), 1389 getattr( 1390 do_fn_signature.process_method.method_value, 1391 '_beam_yields_batches', 1392 False), 1393 getattr( 1394 do_fn_signature.process_batch_method.method_value, 1395 '_beam_yields_elements', 1396 False), 1397 ) 1398 1399 if do_fn_signature.is_stateful_dofn() and not user_state_context: 1400 raise Exception( 1401 'Requested execution of a stateful DoFn, but no user state context ' 1402 'is available. This likely means that the current runner does not ' 1403 'support the execution of stateful DoFns.') 1404 1405 self.do_fn_invoker = DoFnInvoker.create_invoker( 1406 do_fn_signature, 1407 output_handler, 1408 self.context, 1409 side_inputs, 1410 args, 1411 kwargs, 1412 user_state_context=user_state_context, 1413 bundle_finalizer_param=self.bundle_finalizer_param) 1414 1415 def process(self, windowed_value): 1416 # type: (WindowedValue) -> Iterable[SplitResultResidual] 1417 try: 1418 return self.do_fn_invoker.invoke_process(windowed_value) 1419 except BaseException as exn: 1420 self._reraise_augmented(exn) 1421 return [] 1422 1423 def process_batch(self, windowed_batch): 1424 # type: (WindowedBatch) -> None 1425 try: 1426 self.do_fn_invoker.invoke_process_batch(windowed_batch) 1427 except BaseException as exn: 1428 self._reraise_augmented(exn) 1429 1430 def process_with_sized_restriction(self, windowed_value): 1431 # type: (WindowedValue) -> Iterable[SplitResultResidual] 1432 (element, (restriction, estimator_state)), _ = windowed_value.value 1433 return 
self.do_fn_invoker.invoke_process( 1434 windowed_value.with_value(element), 1435 restriction=restriction, 1436 watermark_estimator_state=estimator_state) 1437 1438 def try_split(self, fraction): 1439 # type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual]]] 1440 assert isinstance(self.do_fn_invoker, PerWindowInvoker) 1441 return self.do_fn_invoker.try_split(fraction) 1442 1443 def current_element_progress(self): 1444 # type: () -> Optional[RestrictionProgress] 1445 assert isinstance(self.do_fn_invoker, PerWindowInvoker) 1446 return self.do_fn_invoker.current_element_progress() 1447 1448 def process_user_timer( 1449 self, timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag): 1450 try: 1451 self.do_fn_invoker.invoke_user_timer( 1452 timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag) 1453 except BaseException as exn: 1454 self._reraise_augmented(exn) 1455 1456 def _invoke_bundle_method(self, bundle_method): 1457 try: 1458 self.context.set_element(None) 1459 bundle_method() 1460 except BaseException as exn: 1461 self._reraise_augmented(exn) 1462 1463 def _invoke_lifecycle_method(self, lifecycle_method): 1464 try: 1465 self.context.set_element(None) 1466 lifecycle_method() 1467 except BaseException as exn: 1468 self._reraise_augmented(exn) 1469 1470 def setup(self): 1471 # type: () -> None 1472 self._invoke_lifecycle_method(self.do_fn_invoker.invoke_setup) 1473 1474 def start(self): 1475 # type: () -> None 1476 self._invoke_bundle_method(self.do_fn_invoker.invoke_start_bundle) 1477 1478 def finish(self): 1479 # type: () -> None 1480 self._invoke_bundle_method(self.do_fn_invoker.invoke_finish_bundle) 1481 1482 def teardown(self): 1483 # type: () -> None 1484 self._invoke_lifecycle_method(self.do_fn_invoker.invoke_teardown) 1485 1486 def finalize(self): 1487 # type: () -> None 1488 self.bundle_finalizer_param.finalize_bundle() 1489 1490 def _reraise_augmented(self, exn): 1491 if getattr(exn, 
'_tagged_with_step', False) or not self.step_name: 1492 raise exn 1493 step_annotation = " [while running '%s']" % self.step_name 1494 # To emulate exception chaining (not available in Python 2). 1495 try: 1496 # Attempt to construct the same kind of exception 1497 # with an augmented message. 1498 new_exn = type(exn)(exn.args[0] + step_annotation, *exn.args[1:]) 1499 new_exn._tagged_with_step = True # Could raise attribute error. 1500 except: # pylint: disable=bare-except 1501 # If anything goes wrong, construct a RuntimeError whose message 1502 # records the original exception's type and message. 1503 new_exn = RuntimeError( 1504 traceback.format_exception_only(type(exn), exn)[-1].strip() + 1505 step_annotation) 1506 new_exn._tagged_with_step = True 1507 _, _, tb = sys.exc_info() 1508 raise new_exn.with_traceback(tb) 1509 1510 1511 class OutputHandler(object): 1512 def handle_process_outputs( 1513 self, windowed_input_element, results, watermark_estimator=None): 1514 # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None 1515 raise NotImplementedError 1516 1517 def handle_process_batch_outputs( 1518 self, windowed_input_element, results, watermark_estimator=None): 1519 # type: (WindowedBatch, Iterable[Any], Optional[WatermarkEstimator]) -> None 1520 raise NotImplementedError 1521 1522 1523 class _OutputHandler(OutputHandler): 1524 """Processes output produced by DoFn method invocations.""" 1525 1526 def __init__(self, 1527 window_fn, 1528 main_receivers, # type: Receiver 1529 tagged_receivers, # type: Mapping[Optional[str], Receiver] 1530 per_element_output_counter, 1531 output_batch_converter, # type: Optional[BatchConverter] 1532 process_yields_batches, # type: bool, 1533 process_batch_yields_elements, # type: bool, 1534 ): 1535 """Initializes ``_OutputHandler``. 1536 1537 Args: 1538 window_fn: a windowing function (WindowFn). 1539 main_receivers: a dict of tag name to Receiver objects. 1540 tagged_receivers: main receiver object. 
1541 per_element_output_counter: per_element_output_counter of one work_item. 1542 could be none if experimental flag turn off 1543 """ 1544 self.window_fn = window_fn 1545 self.main_receivers = main_receivers 1546 self.tagged_receivers = tagged_receivers 1547 if (per_element_output_counter is not None and 1548 per_element_output_counter.is_cythonized): 1549 self.per_element_output_counter = per_element_output_counter 1550 else: 1551 self.per_element_output_counter = None 1552 self.output_batch_converter = output_batch_converter 1553 self._process_yields_batches = process_yields_batches 1554 self._process_batch_yields_elements = process_batch_yields_elements 1555 1556 def handle_process_outputs( 1557 self, windowed_input_element, results, watermark_estimator=None): 1558 # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None 1559 1560 """Dispatch the result of process computation to the appropriate receivers. 1561 1562 A value wrapped in a TaggedOutput object will be unwrapped and 1563 then dispatched to the appropriate indexed output. 
1564 """ 1565 if results is None: 1566 results = [] 1567 1568 # TODO(https://github.com/apache/beam/issues/20404): Verify that the 1569 # results object is a valid iterable type if 1570 # performance_runtime_type_check is active, without harming performance 1571 output_element_count = 0 1572 for result in results: 1573 tag, result = self._handle_tagged_output(result) 1574 1575 if not self._process_yields_batches: 1576 # process yields elements 1577 windowed_value = self._maybe_propagate_windowing_info( 1578 windowed_input_element, result) 1579 1580 output_element_count += 1 1581 1582 self._write_value_to_tag(tag, windowed_value, watermark_estimator) 1583 else: # process yields batches 1584 self._verify_batch_output(result) 1585 1586 if isinstance(result, WindowedBatch): 1587 assert isinstance(result, HomogeneousWindowedBatch) 1588 windowed_batch = result 1589 1590 if (windowed_input_element is not None and 1591 len(windowed_input_element.windows) != 1): 1592 windowed_batch.windows *= len(windowed_input_element.windows) 1593 else: 1594 windowed_batch = ( 1595 HomogeneousWindowedBatch.from_batch_and_windowed_value( 1596 batch=result, windowed_value=windowed_input_element)) 1597 1598 output_element_count += self.output_batch_converter.get_length( 1599 windowed_batch.values) 1600 1601 self._write_batch_to_tag(tag, windowed_batch, watermark_estimator) 1602 1603 # TODO(https://github.com/apache/beam/issues/18886): Remove if block after 1604 # output counter released. Only enable per_element_output_counter when 1605 # counter cythonized 1606 if self.per_element_output_counter is not None: 1607 self.per_element_output_counter.add_input(output_element_count) 1608 1609 def handle_process_batch_outputs( 1610 self, windowed_input_batch, results, watermark_estimator=None): 1611 # type: (WindowedBatch, Iterable[Any], Optional[WatermarkEstimator]) -> None 1612 1613 """Dispatch the result of process_batch computation to the appropriate 1614 receivers. 
1615 1616 A value wrapped in a TaggedOutput object will be unwrapped and 1617 then dispatched to the appropriate indexed output. 1618 """ 1619 if results is None: 1620 results = [] 1621 1622 output_element_count = 0 1623 for result in results: 1624 tag, result = self._handle_tagged_output(result) 1625 1626 if not self._process_batch_yields_elements: 1627 # process_batch yields batches 1628 assert self.output_batch_converter is not None 1629 1630 self._verify_batch_output(result) 1631 1632 if isinstance(result, WindowedBatch): 1633 assert isinstance(result, HomogeneousWindowedBatch) 1634 windowed_batch = result 1635 1636 if (windowed_input_batch is not None and 1637 len(windowed_input_batch.windows) != 1): 1638 windowed_batch.windows *= len(windowed_input_batch.windows) 1639 else: 1640 windowed_batch = windowed_input_batch.with_values(result) 1641 1642 output_element_count += self.output_batch_converter.get_length( 1643 windowed_batch.values) 1644 1645 self._write_batch_to_tag(tag, windowed_batch, watermark_estimator) 1646 else: # process_batch yields elements 1647 assert isinstance(windowed_input_batch, HomogeneousWindowedBatch) 1648 1649 windowed_value = self._maybe_propagate_windowing_info( 1650 windowed_input_batch.as_empty_windowed_value(), result) 1651 1652 output_element_count += 1 1653 1654 self._write_value_to_tag(tag, windowed_value, watermark_estimator) 1655 1656 # TODO(https://github.com/apache/beam/issues/18886): Remove if block after 1657 # output counter released. 
Only enable per_element_output_counter when 1658 # counter cythonized 1659 if self.per_element_output_counter is not None: 1660 self.per_element_output_counter.add_input(output_element_count) 1661 1662 def _maybe_propagate_windowing_info(self, windowed_input_element, result): 1663 # type: (WindowedValue, Any) -> WindowedValue 1664 if isinstance(result, WindowedValue): 1665 windowed_value = result 1666 if (windowed_input_element is not None and 1667 len(windowed_input_element.windows) != 1): 1668 windowed_value.windows *= len(windowed_input_element.windows) 1669 return windowed_value 1670 1671 elif isinstance(result, TimestampedValue): 1672 assign_context = WindowFn.AssignContext(result.timestamp, result.value) 1673 windowed_value = WindowedValue( 1674 result.value, result.timestamp, self.window_fn.assign(assign_context)) 1675 if len(windowed_input_element.windows) != 1: 1676 windowed_value.windows *= len(windowed_input_element.windows) 1677 return windowed_value 1678 1679 else: 1680 return windowed_input_element.with_value(result) 1681 1682 def _handle_tagged_output(self, result): 1683 if isinstance(result, TaggedOutput): 1684 tag = result.tag 1685 if not isinstance(tag, str): 1686 raise TypeError('In %s, tag %s is not a string' % (self, tag)) 1687 return tag, result.value 1688 return None, result 1689 1690 def _write_value_to_tag(self, tag, windowed_value, watermark_estimator): 1691 if watermark_estimator is not None: 1692 watermark_estimator.observe_timestamp(windowed_value.timestamp) 1693 1694 if tag is None: 1695 self.main_receivers.receive(windowed_value) 1696 else: 1697 self.tagged_receivers[tag].receive(windowed_value) 1698 1699 def _write_batch_to_tag(self, tag, windowed_batch, watermark_estimator): 1700 if watermark_estimator is not None: 1701 for timestamp in windowed_batch.timestamps: 1702 watermark_estimator.observe_timestamp(timestamp) 1703 1704 if tag is None: 1705 self.main_receivers.receive_batch(windowed_batch) 1706 else: 1707 
self.tagged_receivers[tag].receive_batch(windowed_batch) 1708 1709 def _verify_batch_output(self, result): 1710 if isinstance(result, (WindowedValue, TimestampedValue)): 1711 raise TypeError( 1712 f"Received {type(result).__name__} from DoFn that was " 1713 "expected to produce a batch.") 1714 1715 def start_bundle_outputs(self, results): 1716 """Validate that start_bundle does not output any elements""" 1717 if results is None: 1718 return 1719 raise RuntimeError( 1720 'Start Bundle should not output any elements but got %s' % results) 1721 1722 def finish_bundle_outputs(self, results): 1723 """Dispatch the result of finish_bundle to the appropriate receivers. 1724 1725 A value wrapped in a TaggedOutput object will be unwrapped and 1726 then dispatched to the appropriate indexed output. 1727 """ 1728 if results is None: 1729 return 1730 1731 for result in results: 1732 tag = None 1733 if isinstance(result, TaggedOutput): 1734 tag = result.tag 1735 if not isinstance(tag, str): 1736 raise TypeError('In %s, tag %s is not a string' % (self, tag)) 1737 result = result.value 1738 1739 if isinstance(result, WindowedValue): 1740 windowed_value = result 1741 else: 1742 raise RuntimeError('Finish Bundle should only output WindowedValue ' +\ 1743 'type but got %s' % type(result)) 1744 1745 if tag is None: 1746 self.main_receivers.receive(windowed_value) 1747 else: 1748 self.tagged_receivers[tag].receive(windowed_value) 1749 1750 1751 class _NoContext(WindowFn.AssignContext): 1752 """An uninspectable WindowFn.AssignContext.""" 1753 NO_VALUE = object() 1754 1755 def __init__(self, value, timestamp=NO_VALUE): 1756 self.value = value 1757 self._timestamp = timestamp 1758 1759 @property 1760 def timestamp(self): 1761 if self._timestamp is self.NO_VALUE: 1762 raise ValueError('No timestamp in this context.') 1763 else: 1764 return self._timestamp 1765 1766 @property 1767 def existing_windows(self): 1768 raise ValueError('No existing_windows in this context.') 1769 1770 1771 class 
DoFnState(object): 1772 """For internal use only; no backwards-compatibility guarantees. 1773 1774 Keeps track of state that DoFns want, currently, user counters. 1775 """ 1776 def __init__(self, counter_factory): 1777 self.step_name = '' 1778 self._counter_factory = counter_factory 1779 1780 def counter_for(self, aggregator): 1781 """Looks up the counter for this aggregator, creating one if necessary.""" 1782 return self._counter_factory.get_aggregator_counter( 1783 self.step_name, aggregator) 1784 1785 1786 # TODO(robertwb): Replace core.DoFnContext with this. 1787 class DoFnContext(object): 1788 """For internal use only; no backwards-compatibility guarantees.""" 1789 def __init__(self, label, element=None, state=None): 1790 self.label = label 1791 self.state = state 1792 if element is not None: 1793 self.set_element(element) 1794 1795 def set_element(self, windowed_value): 1796 # type: (Optional[WindowedValue]) -> None 1797 self.windowed_value = windowed_value 1798 1799 @property 1800 def element(self): 1801 if self.windowed_value is None: 1802 raise AttributeError('element not accessible in this context') 1803 else: 1804 return self.windowed_value.value 1805 1806 @property 1807 def timestamp(self): 1808 if self.windowed_value is None: 1809 raise AttributeError('timestamp not accessible in this context') 1810 else: 1811 return self.windowed_value.timestamp 1812 1813 @property 1814 def windows(self): 1815 if self.windowed_value is None: 1816 raise AttributeError('windows not accessible in this context') 1817 else: 1818 return self.windowed_value.windows 1819 1820 1821 def group_by_key_input_visitor(deterministic_key_coders=True): 1822 # Importing here to avoid a circular dependency 1823 # pylint: disable=wrong-import-order, wrong-import-position 1824 from apache_beam.pipeline import PipelineVisitor 1825 from apache_beam.transforms.core import GroupByKey 1826 1827 class GroupByKeyInputVisitor(PipelineVisitor): 1828 """A visitor that replaces `Any` element type for 
input `PCollection` of 1829 a `GroupByKey` with a `KV` type. 1830 1831 TODO(BEAM-115): Once Python SDK is compatible with the new Runner API, 1832 we could directly replace the coder instead of mutating the element type. 1833 """ 1834 def __init__(self, deterministic_key_coders=True): 1835 self.deterministic_key_coders = deterministic_key_coders 1836 1837 def enter_composite_transform(self, transform_node): 1838 self.visit_transform(transform_node) 1839 1840 def visit_transform(self, transform_node): 1841 if isinstance(transform_node.transform, GroupByKey): 1842 pcoll = transform_node.inputs[0] 1843 pcoll.element_type = typehints.coerce_to_kv_type( 1844 pcoll.element_type, transform_node.full_label) 1845 pcoll.requires_deterministic_key_coder = ( 1846 self.deterministic_key_coders and transform_node.full_label) 1847 key_type, value_type = pcoll.element_type.tuple_types 1848 if transform_node.outputs: 1849 key = next(iter(transform_node.outputs.keys())) 1850 transform_node.outputs[key].element_type = typehints.KV[ 1851 key_type, typehints.Iterable[value_type]] 1852 transform_node.outputs[key].requires_deterministic_key_coder = ( 1853 self.deterministic_key_coders and transform_node.full_label) 1854 1855 return GroupByKeyInputVisitor(deterministic_key_coders) 1856 1857 1858 def validate_pipeline_graph(pipeline_proto): 1859 """Ensures this is a correctly constructed Beam pipeline. 1860 """ 1861 def get_coder(pcoll_id): 1862 return pipeline_proto.components.coders[ 1863 pipeline_proto.components.pcollections[pcoll_id].coder_id] 1864 1865 def validate_transform(transform_id): 1866 transform_proto = pipeline_proto.components.transforms[transform_id] 1867 1868 # Currently the only validation we perform is that GBK operations have 1869 # their coders set properly. 
1870 if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn: 1871 if len(transform_proto.inputs) != 1: 1872 raise ValueError("Unexpected number of inputs: %s" % transform_proto) 1873 if len(transform_proto.outputs) != 1: 1874 raise ValueError("Unexpected number of outputs: %s" % transform_proto) 1875 input_coder = get_coder(next(iter(transform_proto.inputs.values()))) 1876 output_coder = get_coder(next(iter(transform_proto.outputs.values()))) 1877 if input_coder.spec.urn != common_urns.coders.KV.urn: 1878 raise ValueError( 1879 "Bad coder for input of %s: %s" % (transform_id, input_coder)) 1880 if output_coder.spec.urn != common_urns.coders.KV.urn: 1881 raise ValueError( 1882 "Bad coder for output of %s: %s" % (transform_id, output_coder)) 1883 output_values_coder = pipeline_proto.components.coders[ 1884 output_coder.component_coder_ids[1]] 1885 if (input_coder.component_coder_ids[0] != 1886 output_coder.component_coder_ids[0] or 1887 output_values_coder.spec.urn != common_urns.coders.ITERABLE.urn or 1888 output_values_coder.component_coder_ids[0] != 1889 input_coder.component_coder_ids[1]): 1890 raise ValueError( 1891 "Incompatible input coder %s and output coder %s for transform %s" % 1892 (transform_id, input_coder, output_coder)) 1893 1894 for t in transform_proto.subtransforms: 1895 validate_transform(t) 1896 1897 for t in pipeline_proto.root_transform_ids: 1898 validate_transform(t)