github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/core.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """Core PTransform subclasses, such as FlatMap, GroupByKey, and Map.""" 19 20 # pytype: skip-file 21 22 import concurrent.futures 23 import copy 24 import inspect 25 import logging 26 import random 27 import sys 28 import time 29 import traceback 30 import types 31 import typing 32 from itertools import dropwhile 33 34 from apache_beam import coders 35 from apache_beam import pvalue 36 from apache_beam import typehints 37 from apache_beam.coders import typecoders 38 from apache_beam.internal import pickler 39 from apache_beam.internal import util 40 from apache_beam.options.pipeline_options import TypeOptions 41 from apache_beam.portability import common_urns 42 from apache_beam.portability import python_urns 43 from apache_beam.portability.api import beam_runner_api_pb2 44 from apache_beam.transforms import ptransform 45 from apache_beam.transforms import userstate 46 from apache_beam.transforms.display import DisplayDataItem 47 from apache_beam.transforms.display import HasDisplayData 48 from apache_beam.transforms.ptransform import PTransform 49 from apache_beam.transforms.ptransform import PTransformWithSideInputs 50 from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX 51 from apache_beam.transforms.sideinputs import get_sideinput_index 52 from apache_beam.transforms.userstate import StateSpec 53 from apache_beam.transforms.userstate import TimerSpec 54 from apache_beam.transforms.window import GlobalWindows 55 from apache_beam.transforms.window import SlidingWindows 56 from apache_beam.transforms.window import TimestampCombiner 57 from apache_beam.transforms.window import TimestampedValue 58 from apache_beam.transforms.window import WindowedValue 59 from apache_beam.transforms.window import WindowFn 60 from apache_beam.typehints import row_type 61 from apache_beam.typehints import trivial_inference 62 from apache_beam.typehints.batch import BatchConverter 63 from apache_beam.typehints.decorators import TypeCheckError 64 from apache_beam.typehints.decorators import WithTypeHints 65 from apache_beam.typehints.decorators import get_signature 66 from apache_beam.typehints.decorators import get_type_hints 67 from apache_beam.typehints.decorators import with_input_types 68 from apache_beam.typehints.decorators import with_output_types 69 from apache_beam.typehints.trivial_inference import element_type 70 from apache_beam.typehints.typehints import TypeConstraint 71 from apache_beam.typehints.typehints import is_consistent_with 72 from apache_beam.typehints.typehints import visit_inner_types 73 from apache_beam.utils import urns 74 from apache_beam.utils.timestamp import Duration 75 76 if typing.TYPE_CHECKING: 77 from google.protobuf import message # pylint: disable=ungrouped-imports 78 from apache_beam.io import iobase 79 from apache_beam.pipeline import Pipeline 80 from apache_beam.runners.pipeline_context import PipelineContext 81 from apache_beam.transforms import create_source 82 from apache_beam.transforms.trigger import AccumulationMode 83 from apache_beam.transforms.trigger import DefaultTrigger 84 from apache_beam.transforms.trigger import TriggerFn 85 86 __all__ = [ 87 'DoFn', 88 'CombineFn', 89 'PartitionFn', 90 'ParDo', 91 'FlatMap', 92 'FlatMapTuple', 93 'Map', 94 'MapTuple', 95 'Filter', 96 'CombineGlobally', 97 'CombinePerKey', 98 'CombineValues', 99 'GroupBy', 100 'GroupByKey', 101 'Select', 102 'Partition', 103 'Windowing', 104 'WindowInto', 105 'Flatten', 106 'Create', 107 'Impulse', 108 'RestrictionProvider', 109 'WatermarkEstimatorProvider', 110 ] 111 112 # Type variables 113 T = typing.TypeVar('T') 114 K = typing.TypeVar('K') 115 V = typing.TypeVar('V') 116 117 _LOGGER = logging.getLogger(__name__) 118 119 120 class DoFnContext(object): 121 """A context available to all methods of DoFn instance.""" 122 pass 123 124 125 class DoFnProcessContext(DoFnContext): 126 """A processing context passed to DoFn process() during execution. 127 128 Most importantly, a DoFn.process method will access context.element 129 to get the element it is supposed to process. 130 131 Attributes: 132 label: label of the ParDo whose element is being processed. 133 element: element being processed 134 (in process method only; always None in start_bundle and finish_bundle) 135 timestamp: timestamp of the element 136 (in process method only; always None in start_bundle and finish_bundle) 137 windows: windows of the element 138 (in process method only; always None in start_bundle and finish_bundle) 139 state: a DoFnState object, which holds the runner's internal state 140 for this element. 141 Not used by the pipeline code. 142 """ 143 def __init__(self, label, element=None, state=None): 144 """Initialize a processing context object with an element and state. 145 146 The element represents one value from a PCollection that will be accessed 147 by a DoFn object during pipeline execution, and state is an arbitrary object 148 where counters and other pipeline state information can be passed in. 149 150 DoFnProcessContext objects are also used as inputs to PartitionFn instances. 151 152 Args: 153 label: label of the PCollection whose element is being processed. 154 element: element of a PCollection being processed using this context. 155 state: a DoFnState object with state to be passed in to the DoFn object. 156 """ 157 self.label = label 158 self.state = state 159 if element is not None: 160 self.set_element(element) 161 162 def set_element(self, windowed_value): 163 if windowed_value is None: 164 # Not currently processing an element. 165 if hasattr(self, 'element'): 166 del self.element 167 del self.timestamp 168 del self.windows 169 else: 170 self.element = windowed_value.value 171 self.timestamp = windowed_value.timestamp 172 self.windows = windowed_value.windows 173 174 175 class ProcessContinuation(object): 176 """An object that may be produced as the last element of a process method 177 invocation. 178 179 If produced, indicates that there is more work to be done for the current 180 input element. 181 """ 182 def __init__(self, resume_delay=0): 183 """Initializes a ProcessContinuation object. 184 185 Args: 186 resume_delay: indicates the minimum time, in seconds, that should elapse 187 before re-invoking process() method for resuming the invocation of the 188 current element. 189 """ 190 self.resume_delay = resume_delay 191 192 @staticmethod 193 def resume(resume_delay=0): 194 """A convenient method that produces a ``ProcessContinuation``. 195 196 Args: 197 resume_delay: delay after which processing current element should be 198 resumed. 199 Returns: a ``ProcessContinuation`` for signalling the runner that current 200 input element has not been fully processed and should be resumed later. 201 """ 202 return ProcessContinuation(resume_delay=resume_delay) 203 204 205 class RestrictionProvider(object): 206 """Provides methods for generating and manipulating restrictions. 207 208 This class should be implemented to support Splittable ``DoFn`` in Python 209 SDK. See https://s.apache.org/splittable-do-fn for more details about 210 Splittable ``DoFn``. 211 212 To denote a ``DoFn`` class to be Splittable ``DoFn``, ``DoFn.process()`` 213 method of that class should have exactly one parameter whose default value is 214 an instance of ``RestrictionParam``. This ``RestrictionParam`` can either be 215 constructed with an explicit ``RestrictionProvider``, or, if no 216 ``RestrictionProvider`` is provided, the ``DoFn`` itself must be a 217 ``RestrictionProvider``. 218 219 The provided ``RestrictionProvider`` instance must provide suitable overrides 220 for the following methods: 221 * create_tracker() 222 * initial_restriction() 223 * restriction_size() 224 225 Optionally, ``RestrictionProvider`` may override default implementations of 226 following methods: 227 * restriction_coder() 228 * split() 229 * split_and_size() 230 * truncate() 231 232 ** Pausing and resuming processing of an element ** 233 234 As the last element produced by the iterator returned by the 235 ``DoFn.process()`` method, a Splittable ``DoFn`` may return an object of type 236 ``ProcessContinuation``. 237 238 If restriction_tracker.defer_remander is called in the ```DoFn.process()``, it 239 means that runner should later re-invoke ``DoFn.process()`` method to resume 240 processing the current element and the manner in which the re-invocation 241 should be performed. 242 243 ** Updating output watermark ** 244 245 ``DoFn.process()`` method of Splittable ``DoFn``s could contain a parameter 246 with default value ``DoFn.WatermarkReporterParam``. If specified this asks the 247 runner to provide a function that can be used to give the runner a 248 (best-effort) lower bound about the timestamps of future output associated 249 with the current element processed by the ``DoFn``. If the ``DoFn`` has 250 multiple outputs, the watermark applies to all of them. Provided function must 251 be invoked with a single parameter of type ``Timestamp`` or as an integer that 252 gives the watermark in number of seconds. 253 """ 254 def create_tracker(self, restriction): 255 # type: (...) -> iobase.RestrictionTracker 256 257 """Produces a new ``RestrictionTracker`` for the given restriction. 258 259 This API is required to be implemented. 260 261 Args: 262 restriction: an object that defines a restriction as identified by a 263 Splittable ``DoFn`` that utilizes the current ``RestrictionProvider``. 264 For example, a tuple that gives a range of positions for a Splittable 265 ``DoFn`` that reads files based on byte positions. 266 Returns: an object of type ``RestrictionTracker``. 267 """ 268 raise NotImplementedError 269 270 def initial_restriction(self, element): 271 """Produces an initial restriction for the given element. 272 273 This API is required to be implemented. 274 """ 275 raise NotImplementedError 276 277 def split(self, element, restriction): 278 """Splits the given element and restriction initially. 279 280 This method enables runners to perform bulk splitting initially allowing for 281 a rapid increase in parallelism. Note that initial split is a different 282 concept from the split during element processing time. Please refer to 283 ``iobase.RestrictionTracker.try_split`` for details about splitting when the 284 current element and restriction are actively being processed. 285 286 Returns an iterator of restrictions. The total set of elements produced by 287 reading input element for each of the returned restrictions should be the 288 same as the total set of elements produced by reading the input element for 289 the input restriction. 290 291 This API is optional if ``split_and_size`` has been implemented. 292 293 If this method is not override, there is no initial splitting happening on 294 each restriction. 295 296 """ 297 yield restriction 298 299 def restriction_coder(self): 300 """Returns a ``Coder`` for restrictions. 301 302 Returned``Coder`` will be used for the restrictions produced by the current 303 ``RestrictionProvider``. 304 305 Returns: 306 an object of type ``Coder``. 307 """ 308 return coders.registry.get_coder(object) 309 310 def restriction_size(self, element, restriction): 311 """Returns the size of a restriction with respect to the given element. 312 313 By default, asks a newly-created restriction tracker for the default size 314 of the restriction. 315 316 The return value must be non-negative. 317 318 Must be thread safe. Will be invoked concurrently during bundle processing 319 due to runner initiated splitting and progress estimation. 320 321 This API is required to be implemented. 322 """ 323 raise NotImplementedError 324 325 def split_and_size(self, element, restriction): 326 """Like split, but also does sizing, returning (restriction, size) pairs. 327 328 For each pair, size must be non-negative. 329 330 This API is optional if ``split`` and ``restriction_size`` have been 331 implemented. 332 """ 333 for part in self.split(element, restriction): 334 yield part, self.restriction_size(element, part) 335 336 def truncate(self, element, restriction): 337 """Truncates the provided restriction into a restriction representing a 338 finite amount of work when the pipeline is 339 `draining <https://docs.google.com/document/d/1NExwHlj-2q2WUGhSO4jTu8XGhDPmm3cllSN8IMmWci8/edit#> for additional details about drain.>_`. # pylint: disable=line-too-long 340 By default, if the restriction is bounded then the restriction will be 341 returned otherwise None will be returned. 342 343 This API is optional and should only be implemented if more granularity is 344 required. 345 346 Return a truncated finite restriction if further processing is required 347 otherwise return None to represent that no further processing of this 348 restriction is required. 349 350 The default behavior when a pipeline is being drained is that bounded 351 restrictions process entirely while unbounded restrictions process till a 352 checkpoint is possible. 353 """ 354 restriction_tracker = self.create_tracker(restriction) 355 if restriction_tracker.is_bounded(): 356 return restriction 357 358 359 def get_function_arguments(obj, func): 360 # type: (...) -> typing.Tuple[typing.List[str], typing.List[typing.Any]] 361 362 """Return the function arguments based on the name provided. If they have 363 a _inspect_function attached to the class then use that otherwise default 364 to the modified version of python inspect library. 365 366 Returns: 367 Same as get_function_args_defaults. 368 """ 369 func_name = '_inspect_%s' % func 370 if hasattr(obj, func_name): 371 f = getattr(obj, func_name) 372 return f() 373 f = getattr(obj, func) 374 return get_function_args_defaults(f) 375 376 377 def get_function_args_defaults(f): 378 # type: (...) -> typing.Tuple[typing.List[str], typing.List[typing.Any]] 379 380 """Returns the function arguments of a given function. 381 382 Returns: 383 (args: List[str], defaults: List[Any]). The first list names the 384 arguments of the method and the second one has the values of the default 385 arguments. This is similar to ``inspect.getfullargspec()``'s results, except 386 it doesn't include bound arguments and may follow function wrappers. 387 """ 388 signature = get_signature(f) 389 parameter = inspect.Parameter 390 # TODO(BEAM-5878) support kwonlyargs on Python 3. 391 _SUPPORTED_ARG_TYPES = [ 392 parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD 393 ] 394 args = [ 395 name for name, 396 p in signature.parameters.items() if p.kind in _SUPPORTED_ARG_TYPES 397 ] 398 defaults = [ 399 p.default for p in signature.parameters.values() 400 if p.kind in _SUPPORTED_ARG_TYPES and p.default is not p.empty 401 ] 402 403 return args, defaults 404 405 406 class WatermarkEstimatorProvider(object): 407 """Provides methods for generating WatermarkEstimator. 408 409 This class should be implemented if wanting to providing output_watermark 410 information within an SDF. 411 412 In order to make an SDF.process() access to the typical WatermarkEstimator, 413 the SDF author should have an argument whose default value is a 414 DoFn.WatermarkEstimatorParam instance. This DoFn.WatermarkEstimatorParam 415 can either be constructed with an explicit WatermarkEstimatorProvider, 416 or, if no WatermarkEstimatorProvider is provided, the DoFn itself must 417 be a WatermarkEstimatorProvider. 418 """ 419 def initial_estimator_state(self, element, restriction): 420 """Returns the initial state of the WatermarkEstimator with given element 421 and restriction. 422 This function is called by the system. 423 """ 424 raise NotImplementedError 425 426 def create_watermark_estimator(self, estimator_state): 427 """Create a new WatermarkEstimator based on the state. The state is 428 typically useful when resuming processing an element. 429 """ 430 raise NotImplementedError 431 432 def estimator_state_coder(self): 433 return coders.registry.get_coder(object) 434 435 436 class _DoFnParam(object): 437 """DoFn parameter.""" 438 def __init__(self, param_id): 439 self.param_id = param_id 440 441 def __eq__(self, other): 442 if type(self) == type(other): 443 return self.param_id == other.param_id 444 return False 445 446 def __hash__(self): 447 return hash(self.param_id) 448 449 def __repr__(self): 450 return self.param_id 451 452 453 class _RestrictionDoFnParam(_DoFnParam): 454 """Restriction Provider DoFn parameter.""" 455 def __init__(self, restriction_provider=None): 456 # type: (typing.Optional[RestrictionProvider]) -> None 457 if (restriction_provider is not None and 458 not isinstance(restriction_provider, RestrictionProvider)): 459 raise ValueError( 460 'DoFn.RestrictionParam expected RestrictionProvider object.') 461 self.restriction_provider = restriction_provider 462 self.param_id = ( 463 'RestrictionParam(%s)' % restriction_provider.__class__.__name__) 464 465 466 class _StateDoFnParam(_DoFnParam): 467 """State DoFn parameter.""" 468 def __init__(self, state_spec): 469 # type: (StateSpec) -> None 470 if not isinstance(state_spec, StateSpec): 471 raise ValueError("DoFn.StateParam expected StateSpec object.") 472 self.state_spec = state_spec 473 self.param_id = 'StateParam(%s)' % state_spec.name 474 475 476 class _TimerDoFnParam(_DoFnParam): 477 """Timer DoFn parameter.""" 478 def __init__(self, timer_spec): 479 # type: (TimerSpec) -> None 480 if not isinstance(timer_spec, TimerSpec): 481 raise ValueError("DoFn.TimerParam expected TimerSpec object.") 482 self.timer_spec = timer_spec 483 self.param_id = 'TimerParam(%s)' % timer_spec.name 484 485 486 class _BundleFinalizerParam(_DoFnParam): 487 """Bundle Finalization DoFn parameter.""" 488 def __init__(self): 489 self._callbacks = [] 490 self.param_id = "FinalizeBundle" 491 492 def register(self, callback): 493 self._callbacks.append(callback) 494 495 # Log errors when calling callback to make sure all callbacks get called 496 # though there are errors. And errors should not fail pipeline. 497 def finalize_bundle(self): 498 for callback in self._callbacks: 499 try: 500 callback() 501 except Exception as e: 502 _LOGGER.warning("Got exception from finalization call: %s", e) 503 504 def has_callbacks(self): 505 # type: () -> bool 506 return len(self._callbacks) > 0 507 508 def reset(self): 509 # type: () -> None 510 del self._callbacks[:] 511 512 513 class _WatermarkEstimatorParam(_DoFnParam): 514 """WatermarkEstimator DoFn parameter.""" 515 def __init__( 516 self, 517 watermark_estimator_provider: typing. 518 Optional[WatermarkEstimatorProvider] = None): 519 if (watermark_estimator_provider is not None and not isinstance( 520 watermark_estimator_provider, WatermarkEstimatorProvider)): 521 raise ValueError( 522 'DoFn.WatermarkEstimatorParam expected' 523 'WatermarkEstimatorProvider object.') 524 self.watermark_estimator_provider = watermark_estimator_provider 525 self.param_id = 'WatermarkEstimatorProvider' 526 527 528 class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn): 529 """A function object used by a transform with custom processing. 530 531 The ParDo transform is such a transform. The ParDo.apply 532 method will take an object of type DoFn and apply it to all elements of a 533 PCollection object. 534 535 In order to have concrete DoFn objects one has to subclass from DoFn and 536 define the desired behavior (start_bundle/finish_bundle and process) or wrap a 537 callable object using the CallableWrapperDoFn class. 538 """ 539 540 # Parameters that can be used in the .process() method. 541 ElementParam = _DoFnParam('ElementParam') 542 SideInputParam = _DoFnParam('SideInputParam') 543 TimestampParam = _DoFnParam('TimestampParam') 544 WindowParam = _DoFnParam('WindowParam') 545 PaneInfoParam = _DoFnParam('PaneInfoParam') 546 WatermarkEstimatorParam = _WatermarkEstimatorParam 547 BundleFinalizerParam = _BundleFinalizerParam 548 KeyParam = _DoFnParam('KeyParam') 549 550 # Parameters to access state and timers. Not restricted to use only in the 551 # .process() method. Usage: DoFn.StateParam(state_spec), 552 # DoFn.TimerParam(timer_spec), DoFn.TimestampParam, DoFn.WindowParam, 553 # DoFn.KeyParam 554 StateParam = _StateDoFnParam 555 TimerParam = _TimerDoFnParam 556 DynamicTimerTagParam = _DoFnParam('DynamicTimerTagParam') 557 558 DoFnProcessParams = [ 559 ElementParam, 560 SideInputParam, 561 TimestampParam, 562 WindowParam, 563 WatermarkEstimatorParam, 564 PaneInfoParam, 565 BundleFinalizerParam, 566 KeyParam, 567 StateParam, 568 TimerParam, 569 ] 570 571 RestrictionParam = _RestrictionDoFnParam 572 573 @staticmethod 574 def from_callable(fn): 575 return CallableWrapperDoFn(fn) 576 577 @staticmethod 578 def unbounded_per_element(): 579 """A decorator on process fn specifying that the fn performs an unbounded 580 amount of work per input element.""" 581 def wrapper(process_fn): 582 process_fn.unbounded_per_element = True 583 return process_fn 584 585 return wrapper 586 587 @staticmethod 588 def yields_elements(fn): 589 """A decorator to apply to ``process_batch`` indicating it yields elements. 590 591 By default ``process_batch`` is assumed to both consume and produce 592 "batches", which are collections of multiple logical Beam elements. This 593 decorator indicates that ``process_batch`` **produces** individual elements 594 at a time. ``process_batch`` is always expected to consume batches. 595 """ 596 if not fn.__name__ in ('process', 'process_batch'): 597 raise TypeError( 598 "@yields_elements must be applied to a process or " 599 f"process_batch method, got {fn!r}.") 600 601 fn._beam_yields_elements = True 602 return fn 603 604 @staticmethod 605 def yields_batches(fn): 606 """A decorator to apply to ``process`` indicating it yields batches. 607 608 By default ``process`` is assumed to both consume and produce 609 individual elements at a time. This decorator indicates that ``process`` 610 **produces** "batches", which are collections of multiple logical Beam 611 elements. 612 """ 613 if not fn.__name__ in ('process', 'process_batch'): 614 raise TypeError( 615 "@yields_elements must be applied to a process or " 616 f"process_batch method, got {fn!r}.") 617 618 fn._beam_yields_batches = True 619 return fn 620 621 def default_label(self): 622 return self.__class__.__name__ 623 624 def process(self, element, *args, **kwargs): 625 """Method to use for processing elements. 626 627 This is invoked by ``DoFnRunner`` for each element of a input 628 ``PCollection``. 629 630 The following parameters can be used as default values on ``process`` 631 arguments to indicate that a DoFn accepts the corresponding parameters. For 632 example, a DoFn might accept the element and its timestamp with the 633 following signature:: 634 635 def process(element=DoFn.ElementParam, timestamp=DoFn.TimestampParam): 636 ... 637 638 The full set of parameters is: 639 640 - ``DoFn.ElementParam``: element to be processed, should not be mutated. 641 - ``DoFn.SideInputParam``: a side input that may be used when processing. 642 - ``DoFn.TimestampParam``: timestamp of the input element. 643 - ``DoFn.WindowParam``: ``Window`` the input element belongs to. 644 - ``DoFn.TimerParam``: a ``userstate.RuntimeTimer`` object defined by the 645 spec of the parameter. 646 - ``DoFn.StateParam``: a ``userstate.RuntimeState`` object defined by the 647 spec of the parameter. 648 - ``DoFn.KeyParam``: key associated with the element. 649 - ``DoFn.RestrictionParam``: an ``iobase.RestrictionTracker`` will be 650 provided here to allow treatment as a Splittable ``DoFn``. The restriction 651 tracker will be derived from the restriction provider in the parameter. 652 - ``DoFn.WatermarkEstimatorParam``: a function that can be used to track 653 output watermark of Splittable ``DoFn`` implementations. 654 655 Args: 656 element: The element to be processed 657 *args: side inputs 658 **kwargs: other keyword arguments. 659 660 Returns: 661 An Iterable of output elements or None. 662 """ 663 raise NotImplementedError 664 665 def process_batch(self, batch, *args, **kwargs): 666 raise NotImplementedError 667 668 def setup(self): 669 """Called to prepare an instance for processing bundles of elements. 670 671 This is a good place to initialize transient in-memory resources, such as 672 network connections. The resources can then be disposed in 673 ``DoFn.teardown``. 674 """ 675 pass 676 677 def start_bundle(self): 678 """Called before a bundle of elements is processed on a worker. 679 680 Elements to be processed are split into bundles and distributed 681 to workers. Before a worker calls process() on the first element 682 of its bundle, it calls this method. 683 """ 684 pass 685 686 def finish_bundle(self): 687 """Called after a bundle of elements is processed on a worker. 688 """ 689 pass 690 691 def teardown(self): 692 """Called to use to clean up this instance before it is discarded. 693 694 A runner will do its best to call this method on any given instance to 695 prevent leaks of transient resources, however, there may be situations where 696 this is impossible (e.g. process crash, hardware failure, etc.) or 697 unnecessary (e.g. the pipeline is shutting down and the process is about to 698 be killed anyway, so all transient resources will be released automatically 699 by the OS). In these cases, the call may not happen. It will also not be 700 retried, because in such situations the DoFn instance no longer exists, so 701 there's no instance to retry it on. 702 703 Thus, all work that depends on input elements, and all externally important 704 side effects, must be performed in ``DoFn.process`` or 705 ``DoFn.finish_bundle``. 706 """ 707 pass 708 709 def get_function_arguments(self, func): 710 return get_function_arguments(self, func) 711 712 def default_type_hints(self): 713 process_type_hints = typehints.decorators.IOTypeHints.from_callable( 714 self.process) or typehints.decorators.IOTypeHints.empty() 715 716 if self._process_yields_batches: 717 # process() produces batches, don't use it's output typehint 718 process_type_hints = process_type_hints.with_output_types_from( 719 typehints.decorators.IOTypeHints.empty()) 720 721 if self._process_batch_yields_elements: 722 # process_batch() produces elements, *do* use it's output typehint 723 724 # First access the typehint 725 process_batch_type_hints = typehints.decorators.IOTypeHints.from_callable( 726 self.process_batch) or typehints.decorators.IOTypeHints.empty() 727 728 # Then we deconflict with the typehint from process, if it exists 729 if (process_batch_type_hints.output_types != 730 typehints.decorators.IOTypeHints.empty().output_types): 731 if (process_type_hints.output_types != 732 typehints.decorators.IOTypeHints.empty().output_types and 733 process_batch_type_hints.output_types != 734 process_type_hints.output_types): 735 raise TypeError( 736 f"DoFn {self!r} yields element from both process and " 737 "process_batch, but they have mismatched output typehints:\n" 738 f" process: {process_type_hints.output_types}\n" 739 f" process_batch: {process_batch_type_hints.output_types}") 740 741 process_type_hints = process_type_hints.with_output_types_from( 742 process_batch_type_hints) 743 744 try: 745 process_type_hints = process_type_hints.strip_iterable() 746 except ValueError as e: 747 raise ValueError('Return value not iterable: %s: %s' % (self, e)) 748 749 # Prefer class decorator type hints for backwards compatibility. 750 return get_type_hints(self.__class__).with_defaults(process_type_hints) 751 752 # TODO(sourabhbajaj): Do we want to remove the responsibility of these from 753 # the DoFn or maybe the runner 754 def infer_output_type(self, input_type): 755 # TODO(https://github.com/apache/beam/issues/19824): Side inputs types. 756 return trivial_inference.element_type( 757 _strip_output_annotations( 758 trivial_inference.infer_return_type(self.process, [input_type]))) 759 760 @property 761 def _process_defined(self) -> bool: 762 # Check if this DoFn's process method has been overridden 763 # Note that we retrieve the __func__ attribute, if it exists, to get the 764 # underlying function from the bound method. 765 # If __func__ doesn't exist, self.process was likely overridden with a free 766 # function, as in CallableWrapperDoFn. 767 return getattr(self.process, '__func__', self.process) != DoFn.process 768 769 @property 770 def _process_batch_defined(self) -> bool: 771 # Check if this DoFn's process_batch method has been overridden 772 # Note that we retrieve the __func__ attribute, if it exists, to get the 773 # underlying function from the bound method. 774 # If __func__ doesn't exist, self.process_batch was likely overridden with 775 # a free function. 776 return getattr( 777 self.process_batch, '__func__', 778 self.process_batch) != DoFn.process_batch 779 780 @property 781 def _can_yield_batches(self) -> bool: 782 return ((self._process_defined and self._process_yields_batches) or ( 783 self._process_batch_defined and 784 not self._process_batch_yields_elements)) 785 786 @property 787 def _process_yields_batches(self) -> bool: 788 return getattr(self.process, '_beam_yields_batches', False) 789 790 @property 791 def _process_batch_yields_elements(self) -> bool: 792 return getattr(self.process_batch, '_beam_yields_elements', False) 793 794 def get_input_batch_type( 795 self, input_element_type 796 ) -> typing.Optional[typing.Union[TypeConstraint, type]]: 797 """Determine the batch type expected as input to process_batch. 798 799 The default implementation of ``get_input_batch_type`` simply observes the 800 input typehint for the first parameter of ``process_batch``. A Batched DoFn 801 may override this method if a dynamic approach is required. 802 803 Args: 804 input_element_type: The **element type** of the input PCollection this 805 DoFn is being applied to. 806 807 Returns: 808 ``None`` if this DoFn cannot accept batches, else a Beam typehint or 809 a native Python typehint. 810 """ 811 if not self._process_batch_defined: 812 return None 813 input_type = list( 814 inspect.signature(self.process_batch).parameters.values())[0].annotation 815 if input_type == inspect.Signature.empty: 816 # TODO(https://github.com/apache/beam/issues/21652): Consider supporting 817 # an alternative (dynamic?) approach for declaring input type 818 raise TypeError( 819 f"Either {self.__class__.__name__}.process_batch() must have a type " 820 f"annotation on its first parameter, or {self.__class__.__name__} " 821 "must override get_input_batch_type.") 822 return input_type 823 824 def _get_input_batch_type_normalized(self, input_element_type): 825 return typehints.native_type_compatibility.convert_to_beam_type( 826 self.get_input_batch_type(input_element_type)) 827 828 def _get_output_batch_type_normalized(self, input_element_type): 829 return typehints.native_type_compatibility.convert_to_beam_type( 830 self.get_output_batch_type(input_element_type)) 831 832 @staticmethod 833 def _get_element_type_from_return_annotation(method, input_type): 834 return_type = inspect.signature(method).return_annotation 835 if return_type == inspect.Signature.empty: 836 # output type not annotated, try to infer it 837 return_type = trivial_inference.infer_return_type(method, [input_type]) 838 839 return_type = typehints.native_type_compatibility.convert_to_beam_type( 840 return_type) 841 if isinstance(return_type, typehints.typehints.IterableTypeConstraint): 842 return return_type.inner_type 843 elif isinstance(return_type, typehints.typehints.IteratorTypeConstraint): 844 return return_type.yielded_type 845 else: 846 raise TypeError( 847 "Expected Iterator in return type annotation for " 848 f"{method!r}, did you mean Iterator[{return_type}]? Note Beam DoFn " 849 "process and process_batch methods are expected to produce " 850 "generators - they should 'yield' rather than 'return'.") 851 852 def get_output_batch_type( 853 self, input_element_type 854 ) -> typing.Optional[typing.Union[TypeConstraint, type]]: 855 """Determine the batch type produced by this DoFn's ``process_batch`` 856 implementation and/or its ``process`` implementation with 857 ``@yields_batch``. 858 859 The default implementation of this method observes the return type 860 annotations on ``process_batch`` and/or ``process``. A Batched DoFn may 861 override this method if a dynamic approach is required. 862 863 Args: 864 input_element_type: The **element type** of the input PCollection this 865 DoFn is being applied to. 866 867 Returns: 868 ``None`` if this DoFn will never yield batches, else a Beam typehint or 869 a native Python typehint. 870 """ 871 output_batch_type = None 872 if self._process_defined and self._process_yields_batches: 873 output_batch_type = self._get_element_type_from_return_annotation( 874 self.process, input_element_type) 875 if self._process_batch_defined and not self._process_batch_yields_elements: 876 process_batch_type = self._get_element_type_from_return_annotation( 877 self.process_batch, 878 self._get_input_batch_type_normalized(input_element_type)) 879 880 # TODO: Consider requiring an inheritance relationship rather than 881 # equality 882 if (output_batch_type is not None and 883 (not process_batch_type == output_batch_type)): 884 raise TypeError( 885 f"DoFn {self!r} yields batches from both process and " 886 "process_batch, but they produce different types:\n" 887 f" process: {output_batch_type}\n" 888 f" process_batch: {process_batch_type!r}") 889 890 output_batch_type = process_batch_type 891 892 return output_batch_type 893 894 def _process_argspec_fn(self): 895 """Returns the Python callable that will eventually be invoked. 896 897 This should ideally be the user-level function that is called with 898 the main and (if any) side inputs, and is used to relate the type 899 hint parameters with the input parameters (e.g., by argument name). 900 """ 901 return self.process 902 903 urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_DOFN) 904 905 906 class CallableWrapperDoFn(DoFn): 907 """For internal use only; no backwards-compatibility guarantees. 908 909 A DoFn (function) object wrapping a callable object. 910 911 The purpose of this class is to conveniently wrap simple functions and use 912 them in transforms. 913 """ 914 def __init__(self, fn, fullargspec=None): 915 """Initializes a CallableWrapperDoFn object wrapping a callable. 916 917 Args: 918 fn: A callable object. 919 920 Raises: 921 TypeError: if fn parameter is not a callable type. 922 """ 923 if not callable(fn): 924 raise TypeError('Expected a callable object instead of: %r' % fn) 925 926 self._fn = fn 927 self._fullargspec = fullargspec 928 if isinstance( 929 fn, (types.BuiltinFunctionType, types.MethodType, types.FunctionType)): 930 self.process = fn 931 else: 932 # For cases such as set / list where fn is callable but not a function 933 self.process = lambda element: fn(element) 934 935 super().__init__() 936 937 def display_data(self): 938 # If the callable has a name, then it's likely a function, and 939 # we show its name. 940 # Otherwise, it might be an instance of a callable class. We 941 # show its class. 942 display_data_value = ( 943 self._fn.__name__ 944 if hasattr(self._fn, '__name__') else self._fn.__class__) 945 return { 946 'fn': DisplayDataItem(display_data_value, label='Transform Function') 947 } 948 949 def __repr__(self): 950 return 'CallableWrapperDoFn(%s)' % self._fn 951 952 def default_type_hints(self): 953 fn_type_hints = typehints.decorators.IOTypeHints.from_callable(self._fn) 954 type_hints = get_type_hints(self._fn).with_defaults(fn_type_hints) 955 # The fn's output type should be iterable. Strip off the outer 956 # container type due to the 'flatten' portion of FlatMap/ParDo. 957 try: 958 type_hints = type_hints.strip_iterable() 959 except ValueError as e: 960 raise TypeCheckError( 961 'Return value not iterable: %s: %s' % 962 (self.display_data()['fn'].value, e)) 963 return type_hints 964 965 def infer_output_type(self, input_type): 966 return trivial_inference.element_type( 967 _strip_output_annotations( 968 trivial_inference.infer_return_type(self._fn, [input_type]))) 969 970 def _process_argspec_fn(self): 971 return getattr(self._fn, '_argspec_fn', self._fn) 972 973 def _inspect_process(self): 974 if self._fullargspec: 975 return self._fullargspec 976 else: 977 return get_function_args_defaults(self._process_argspec_fn()) 978 979 980 class CombineFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn): 981 """A function object used by a Combine transform with custom processing. 982 983 A CombineFn specifies how multiple values in all or part of a PCollection can 984 be merged into a single value---essentially providing the same kind of 985 information as the arguments to the Python "reduce" builtin (except for the 986 input argument, which is an instance of CombineFnProcessContext). The 987 combining process proceeds as follows: 988 989 1. Input values are partitioned into one or more batches. 990 2. For each batch, the setup method is invoked. 991 3. For each batch, the create_accumulator method is invoked to create a fresh 992 initial "accumulator" value representing the combination of zero values. 993 4. For each input value in the batch, the add_input method is invoked to 994 combine more values with the accumulator for that batch. 995 5. The merge_accumulators method is invoked to combine accumulators from 996 separate batches into a single combined output accumulator value, once all 997 of the accumulators have had all the input value in their batches added to 998 them. This operation is invoked repeatedly, until there is only one 999 accumulator value left. 1000 6. The extract_output operation is invoked on the final accumulator to get 1001 the output value. 1002 7. The teardown method is invoked. 1003 1004 Note: If this **CombineFn** is used with a transform that has defaults, 1005 **apply** will be called with an empty list at expansion time to get the 1006 default value. 1007 """ 1008 def default_label(self): 1009 return self.__class__.__name__ 1010 1011 def setup(self, *args, **kwargs): 1012 """Called to prepare an instance for combining. 1013 1014 This method can be useful if there is some state that needs to be loaded 1015 before executing any of the other methods. The resources can then be 1016 disposed of in ``CombineFn.teardown``. 1017 1018 If you are using Dataflow, you need to enable Dataflow Runner V2 1019 before using this feature. 1020 1021 Args: 1022 *args: Additional arguments and side inputs. 1023 **kwargs: Additional arguments and side inputs. 1024 """ 1025 pass 1026 1027 def create_accumulator(self, *args, **kwargs): 1028 """Return a fresh, empty accumulator for the combine operation. 1029 1030 Args: 1031 *args: Additional arguments and side inputs. 1032 **kwargs: Additional arguments and side inputs. 1033 """ 1034 raise NotImplementedError(str(self)) 1035 1036 def add_input(self, mutable_accumulator, element, *args, **kwargs): 1037 """Return result of folding element into accumulator. 1038 1039 CombineFn implementors must override add_input. 1040 1041 Args: 1042 mutable_accumulator: the current accumulator, 1043 may be modified and returned for efficiency 1044 element: the element to add, should not be mutated 1045 *args: Additional arguments and side inputs. 1046 **kwargs: Additional arguments and side inputs. 1047 """ 1048 raise NotImplementedError(str(self)) 1049 1050 def add_inputs(self, mutable_accumulator, elements, *args, **kwargs): 1051 """Returns the result of folding each element in elements into accumulator. 1052 1053 This is provided in case the implementation affords more efficient 1054 bulk addition of elements. The default implementation simply loops 1055 over the inputs invoking add_input for each one. 1056 1057 Args: 1058 mutable_accumulator: the current accumulator, 1059 may be modified and returned for efficiency 1060 elements: the elements to add, should not be mutated 1061 *args: Additional arguments and side inputs. 1062 **kwargs: Additional arguments and side inputs. 1063 """ 1064 for element in elements: 1065 mutable_accumulator =\ 1066 self.add_input(mutable_accumulator, element, *args, **kwargs) 1067 return mutable_accumulator 1068 1069 def merge_accumulators(self, accumulators, *args, **kwargs): 1070 """Returns the result of merging several accumulators 1071 to a single accumulator value. 1072 1073 Args: 1074 accumulators: the accumulators to merge. 1075 Only the first accumulator may be modified and returned for efficiency; 1076 the other accumulators should not be mutated, because they may be 1077 shared with other code and mutating them could lead to incorrect 1078 results or data corruption. 1079 *args: Additional arguments and side inputs. 1080 **kwargs: Additional arguments and side inputs. 1081 """ 1082 raise NotImplementedError(str(self)) 1083 1084 def compact(self, accumulator, *args, **kwargs): 1085 """Optionally returns a more compact represenation of the accumulator. 1086 1087 This is called before an accumulator is sent across the wire, and can 1088 be useful in cases where values are buffered or otherwise lazily 1089 kept unprocessed when added to the accumulator. Should return an 1090 equivalent, though possibly modified, accumulator. 1091 1092 By default returns the accumulator unmodified. 1093 1094 Args: 1095 accumulator: the current accumulator 1096 *args: Additional arguments and side inputs. 1097 **kwargs: Additional arguments and side inputs. 1098 """ 1099 return accumulator 1100 1101 def extract_output(self, accumulator, *args, **kwargs): 1102 """Return result of converting accumulator into the output value. 1103 1104 Args: 1105 accumulator: the final accumulator value computed by this CombineFn 1106 for the entire input key or PCollection. Can be modified for 1107 efficiency. 1108 *args: Additional arguments and side inputs. 1109 **kwargs: Additional arguments and side inputs. 1110 """ 1111 raise NotImplementedError(str(self)) 1112 1113 def teardown(self, *args, **kwargs): 1114 """Called to clean up an instance before it is discarded. 1115 1116 If you are using Dataflow, you need to enable Dataflow Runner V2 1117 before using this feature. 1118 1119 Args: 1120 *args: Additional arguments and side inputs. 1121 **kwargs: Additional arguments and side inputs. 1122 """ 1123 pass 1124 1125 def apply(self, elements, *args, **kwargs): 1126 """Returns result of applying this CombineFn to the input values. 1127 1128 Args: 1129 elements: the set of values to combine. 1130 *args: Additional arguments and side inputs. 1131 **kwargs: Additional arguments and side inputs. 1132 """ 1133 return self.extract_output( 1134 self.add_inputs( 1135 self.create_accumulator(*args, **kwargs), elements, *args, 1136 **kwargs), 1137 *args, 1138 **kwargs) 1139 1140 def for_input_type(self, input_type): 1141 """Returns a specialized implementation of self, if it exists. 1142 1143 Otherwise, returns self. 1144 1145 Args: 1146 input_type: the type of input elements. 1147 """ 1148 return self 1149 1150 @staticmethod 1151 def from_callable(fn): 1152 return CallableWrapperCombineFn(fn) 1153 1154 @staticmethod 1155 def maybe_from_callable(fn, has_side_inputs=True): 1156 # type: (typing.Union[CombineFn, typing.Callable], bool) -> CombineFn 1157 if isinstance(fn, CombineFn): 1158 return fn 1159 elif callable(fn) and not has_side_inputs: 1160 return NoSideInputsCallableWrapperCombineFn(fn) 1161 elif callable(fn): 1162 return CallableWrapperCombineFn(fn) 1163 else: 1164 raise TypeError('Expected a CombineFn or callable, got %r' % fn) 1165 1166 def get_accumulator_coder(self): 1167 return coders.registry.get_coder(object) 1168 1169 urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_COMBINE_FN) 1170 1171 1172 class _ReiterableChain(object): 1173 """Like itertools.chain, but allowing re-iteration.""" 1174 def __init__(self, iterables): 1175 self.iterables = iterables 1176 1177 def __iter__(self): 1178 for iterable in self.iterables: 1179 for item in iterable: 1180 yield item 1181 1182 def __bool__(self): 1183 for iterable in self.iterables: 1184 for _ in iterable: 1185 return True 1186 return False 1187 1188 1189 class CallableWrapperCombineFn(CombineFn): 1190 """For internal use only; no backwards-compatibility guarantees. 1191 1192 A CombineFn (function) object wrapping a callable object. 1193 1194 The purpose of this class is to conveniently wrap simple functions and use 1195 them in Combine transforms. 1196 """ 1197 _DEFAULT_BUFFER_SIZE = 10 1198 1199 def __init__(self, fn, buffer_size=_DEFAULT_BUFFER_SIZE): 1200 """Initializes a CallableFn object wrapping a callable. 1201 1202 Args: 1203 fn: A callable object that reduces elements of an iterable to a single 1204 value (like the builtins sum and max). This callable must be capable of 1205 receiving the kind of values it generates as output in its input, and 1206 for best results, its operation must be commutative and associative. 1207 1208 Raises: 1209 TypeError: if fn parameter is not a callable type. 1210 """ 1211 if not callable(fn): 1212 raise TypeError('Expected a callable object instead of: %r' % fn) 1213 1214 super().__init__() 1215 self._fn = fn 1216 self._buffer_size = buffer_size 1217 1218 def display_data(self): 1219 return {'fn_dd': self._fn} 1220 1221 def __repr__(self): 1222 return "%s(%s)" % (self.__class__.__name__, self._fn) 1223 1224 def create_accumulator(self, *args, **kwargs): 1225 return [] 1226 1227 def add_input(self, accumulator, element, *args, **kwargs): 1228 accumulator.append(element) 1229 if len(accumulator) > self._buffer_size: 1230 accumulator = [self._fn(accumulator, *args, **kwargs)] 1231 return accumulator 1232 1233 def add_inputs(self, accumulator, elements, *args, **kwargs): 1234 accumulator.extend(elements) 1235 if len(accumulator) > self._buffer_size: 1236 accumulator = [self._fn(accumulator, *args, **kwargs)] 1237 return accumulator 1238 1239 def merge_accumulators(self, accumulators, *args, **kwargs): 1240 return [self._fn(_ReiterableChain(accumulators), *args, **kwargs)] 1241 1242 def compact(self, accumulator, *args, **kwargs): 1243 if len(accumulator) <= 1: 1244 return accumulator 1245 else: 1246 return [self._fn(accumulator, *args, **kwargs)] 1247 1248 def extract_output(self, accumulator, *args, **kwargs): 1249 return self._fn(accumulator, *args, **kwargs) 1250 1251 def default_type_hints(self): 1252 fn_type_hints = typehints.decorators.IOTypeHints.from_callable(self._fn) 1253 type_hints = get_type_hints(self._fn).with_defaults(fn_type_hints) 1254 if type_hints.input_types is None: 1255 return type_hints 1256 else: 1257 # fn(Iterable[V]) -> V becomes CombineFn(V) -> V 1258 input_args, input_kwargs = type_hints.input_types 1259 if not input_args: 1260 if len(input_kwargs) == 1: 1261 input_args, input_kwargs = tuple(input_kwargs.values()), {} 1262 else: 1263 raise TypeError('Combiner input type must be specified positionally.') 1264 if not is_consistent_with(input_args[0], 1265 typehints.Iterable[typehints.Any]): 1266 raise TypeCheckError( 1267 'All functions for a Combine PTransform must accept a ' 1268 'single argument compatible with: Iterable[Any]. ' 1269 'Instead a function with input type: %s was received.' % 1270 input_args[0]) 1271 input_args = (element_type(input_args[0]), ) + input_args[1:] 1272 # TODO(robertwb): Assert output type is consistent with input type? 1273 return type_hints.with_input_types(*input_args, **input_kwargs) 1274 1275 def infer_output_type(self, input_type): 1276 return _strip_output_annotations( 1277 trivial_inference.infer_return_type(self._fn, [input_type])) 1278 1279 def for_input_type(self, input_type): 1280 # Avoid circular imports. 1281 from apache_beam.transforms import cy_combiners 1282 if self._fn is any: 1283 return cy_combiners.AnyCombineFn() 1284 elif self._fn is all: 1285 return cy_combiners.AllCombineFn() 1286 else: 1287 known_types = { 1288 (sum, int): cy_combiners.SumInt64Fn(), 1289 (min, int): cy_combiners.MinInt64Fn(), 1290 (max, int): cy_combiners.MaxInt64Fn(), 1291 (sum, float): cy_combiners.SumFloatFn(), 1292 (min, float): cy_combiners.MinFloatFn(), 1293 (max, float): cy_combiners.MaxFloatFn(), 1294 } 1295 return known_types.get((self._fn, input_type), self) 1296 1297 1298 class NoSideInputsCallableWrapperCombineFn(CallableWrapperCombineFn): 1299 """For internal use only; no backwards-compatibility guarantees. 1300 1301 A CombineFn (function) object wrapping a callable object with no side inputs. 1302 1303 This is identical to its parent, but avoids accepting and passing *args 1304 and **kwargs for efficiency as they are known to be empty. 1305 """ 1306 def create_accumulator(self): 1307 return [] 1308 1309 def add_input(self, accumulator, element): 1310 accumulator.append(element) 1311 if len(accumulator) > self._buffer_size: 1312 accumulator = [self._fn(accumulator)] 1313 return accumulator 1314 1315 def add_inputs(self, accumulator, elements): 1316 accumulator.extend(elements) 1317 if len(accumulator) > self._buffer_size: 1318 accumulator = [self._fn(accumulator)] 1319 return accumulator 1320 1321 def merge_accumulators(self, accumulators): 1322 return [self._fn(_ReiterableChain(accumulators))] 1323 1324 def compact(self, accumulator): 1325 if len(accumulator) <= 1: 1326 return accumulator 1327 else: 1328 return [self._fn(accumulator)] 1329 1330 def extract_output(self, accumulator): 1331 return self._fn(accumulator) 1332 1333 1334 class PartitionFn(WithTypeHints): 1335 """A function object used by a Partition transform. 1336 1337 A PartitionFn specifies how individual values in a PCollection will be placed 1338 into separate partitions, indexed by an integer. 1339 """ 1340 def default_label(self): 1341 return self.__class__.__name__ 1342 1343 def partition_for(self, element, num_partitions, *args, **kwargs): 1344 # type: (T, int, *typing.Any, **typing.Any) -> int 1345 1346 """Specify which partition will receive this element. 1347 1348 Args: 1349 element: An element of the input PCollection. 1350 num_partitions: Number of partitions, i.e., output PCollections. 1351 *args: optional parameters and side inputs. 1352 **kwargs: optional parameters and side inputs. 1353 1354 Returns: 1355 An integer in [0, num_partitions). 1356 """ 1357 pass 1358 1359 1360 class CallableWrapperPartitionFn(PartitionFn): 1361 """For internal use only; no backwards-compatibility guarantees. 1362 1363 A PartitionFn object wrapping a callable object. 1364 1365 Instances of this class wrap simple functions for use in Partition operations. 1366 """ 1367 def __init__(self, fn): 1368 """Initializes a PartitionFn object wrapping a callable. 1369 1370 Args: 1371 fn: A callable object, which should accept the following arguments: 1372 element - element to assign to a partition. 1373 num_partitions - number of output partitions. 1374 and may accept additional arguments and side inputs. 1375 1376 Raises: 1377 TypeError: if fn is not a callable type. 1378 """ 1379 if not callable(fn): 1380 raise TypeError('Expected a callable object instead of: %r' % fn) 1381 self._fn = fn 1382 1383 def partition_for(self, element, num_partitions, *args, **kwargs): 1384 # type: (T, int, *typing.Any, **typing.Any) -> int 1385 return self._fn(element, num_partitions, *args, **kwargs) 1386 1387 1388 def _get_function_body_without_inners(func): 1389 source_lines = inspect.getsourcelines(func)[0] 1390 source_lines = dropwhile(lambda x: x.startswith("@"), source_lines) 1391 def_line = next(source_lines).strip() 1392 if def_line.startswith("def ") and def_line.endswith(":"): 1393 first_line = next(source_lines) 1394 indentation = len(first_line) - len(first_line.lstrip()) 1395 final_lines = [first_line[indentation:]] 1396 1397 skip_inner_def = False 1398 if first_line[indentation:].startswith("def "): 1399 skip_inner_def = True 1400 for line in source_lines: 1401 line_indentation = len(line) - len(line.lstrip()) 1402 1403 if line[indentation:].startswith("def "): 1404 skip_inner_def = True 1405 continue 1406 1407 if skip_inner_def and line_indentation == indentation: 1408 skip_inner_def = False 1409 1410 if skip_inner_def and line_indentation > indentation: 1411 continue 1412 final_lines.append(line[indentation:]) 1413 1414 return "".join(final_lines) 1415 else: 1416 return def_line.rsplit(":")[-1].strip() 1417 1418 1419 def _check_fn_use_yield_and_return(fn): 1420 if isinstance(fn, types.BuiltinFunctionType): 1421 return False 1422 try: 1423 source_code = _get_function_body_without_inners(fn) 1424 has_yield = False 1425 has_return = False 1426 for line in source_code.split("\n"): 1427 if line.lstrip().startswith("yield ") or line.lstrip().startswith( 1428 "yield("): 1429 has_yield = True 1430 if line.lstrip().startswith("return ") or line.lstrip().startswith( 1431 "return("): 1432 has_return = True 1433 if has_yield and has_return: 1434 return True 1435 return False 1436 except Exception as e: 1437 _LOGGER.debug(str(e)) 1438 return False 1439 1440 1441 class ParDo(PTransformWithSideInputs): 1442 """A :class:`ParDo` transform. 1443 1444 Processes an input :class:`~apache_beam.pvalue.PCollection` by applying a 1445 :class:`DoFn` to each element and returning the accumulated results into an 1446 output :class:`~apache_beam.pvalue.PCollection`. The type of the elements is 1447 not fixed as long as the :class:`DoFn` can deal with it. In reality the type 1448 is restrained to some extent because the elements sometimes must be persisted 1449 to external storage. See the :meth:`.expand()` method comments for a 1450 detailed description of all possible arguments. 1451 1452 Note that the :class:`DoFn` must return an iterable for each element of the 1453 input :class:`~apache_beam.pvalue.PCollection`. An easy way to do this is to 1454 use the ``yield`` keyword in the process method. 1455 1456 Args: 1457 pcoll (~apache_beam.pvalue.PCollection): 1458 a :class:`~apache_beam.pvalue.PCollection` to be processed. 1459 fn (`typing.Union[DoFn, typing.Callable]`): a :class:`DoFn` object to be 1460 applied to each element of **pcoll** argument, or a Callable. 1461 *args: positional arguments passed to the :class:`DoFn` object. 1462 **kwargs: keyword arguments passed to the :class:`DoFn` object. 1463 1464 Note that the positional and keyword arguments will be processed in order 1465 to detect :class:`~apache_beam.pvalue.PCollection` s that will be computed as 1466 side inputs to the transform. During pipeline execution whenever the 1467 :class:`DoFn` object gets executed (its :meth:`DoFn.process()` method gets 1468 called) the :class:`~apache_beam.pvalue.PCollection` arguments will be 1469 replaced by values from the :class:`~apache_beam.pvalue.PCollection` in the 1470 exact positions where they appear in the argument lists. 1471 """ 1472 def __init__(self, fn, *args, **kwargs): 1473 super().__init__(fn, *args, **kwargs) 1474 # TODO(robertwb): Change all uses of the dofn attribute to use fn instead. 1475 self.dofn = self.fn 1476 self.output_tags = set() # type: typing.Set[str] 1477 1478 if not isinstance(self.fn, DoFn): 1479 raise TypeError('ParDo must be called with a DoFn instance.') 1480 1481 # DoFn.process cannot allow both return and yield 1482 if _check_fn_use_yield_and_return(self.fn.process): 1483 _LOGGER.warning( 1484 'Using yield and return in the process method ' 1485 'of %s can lead to unexpected behavior, see:' 1486 'https://github.com/apache/beam/issues/22969.', 1487 self.fn.__class__) 1488 1489 # Validate the DoFn by creating a DoFnSignature 1490 from apache_beam.runners.common import DoFnSignature 1491 self._signature = DoFnSignature(self.fn) 1492 1493 def with_exception_handling( 1494 self, 1495 main_tag='good', 1496 dead_letter_tag='bad', 1497 *, 1498 exc_class=Exception, 1499 partial=False, 1500 use_subprocess=False, 1501 threshold=1, 1502 threshold_windowing=None, 1503 timeout=None): 1504 """Automatically provides a dead letter output for skipping bad records. 1505 This can allow a pipeline to continue successfully rather than fail or 1506 continuously throw errors on retry when bad elements are encountered. 1507 1508 This returns a tagged output with two PCollections, the first being the 1509 results of successfully processing the input PCollection, and the second 1510 being the set of bad records (those which threw exceptions during 1511 processing) along with information about the errors raised. 1512 1513 For example, one would write:: 1514 1515 good, bad = Map(maybe_error_raising_function).with_exception_handling() 1516 1517 and `good` will be a PCollection of mapped records and `bad` will contain 1518 those that raised exceptions. 1519 1520 1521 Args: 1522 main_tag: tag to be used for the main (good) output of the DoFn, 1523 useful to avoid possible conflicts if this DoFn already produces 1524 multiple outputs. Optional, defaults to 'good'. 1525 dead_letter_tag: tag to be used for the bad records, useful to avoid 1526 possible conflicts if this DoFn already produces multiple outputs. 1527 Optional, defaults to 'bad'. 1528 exc_class: An exception class, or tuple of exception classes, to catch. 1529 Optional, defaults to 'Exception'. 1530 partial: Whether to emit outputs for an element as they're produced 1531 (which could result in partial outputs for a ParDo or FlatMap that 1532 throws an error part way through execution) or buffer all outputs 1533 until successful processing of the entire element. Optional, 1534 defaults to False. 1535 use_subprocess: Whether to execute the DoFn logic in a subprocess. This 1536 allows one to recover from errors that can crash the calling process 1537 (e.g. from an underlying C/C++ library causing a segfault), but is 1538 slower as elements and results must cross a process boundary. Note 1539 that this starts up a long-running process that is used to handle 1540 all the elements (until hard failure, which should be rare) rather 1541 than a new process per element, so the overhead should be minimal 1542 (and can be amortized if there's any per-process or per-bundle 1543 initialization that needs to be done). Optional, defaults to False. 1544 threshold: An upper bound on the ratio of records that can be bad before 1545 aborting the entire pipeline. Optional, defaults to 1.0 (meaning 1546 up to 100% of records can be bad and the pipeline will still succeed). 1547 threshold_windowing: Event-time windowing to use for threshold. Optional, 1548 defaults to the windowing of the input. 1549 timeout: If the element has not finished processing in timeout seconds, 1550 raise a TimeoutError. Defaults to None, meaning no time limit. 1551 """ 1552 args, kwargs = self.raw_side_inputs 1553 return self.label >> _ExceptionHandlingWrapper( 1554 self.fn, 1555 args, 1556 kwargs, 1557 main_tag, 1558 dead_letter_tag, 1559 exc_class, 1560 partial, 1561 use_subprocess, 1562 threshold, 1563 threshold_windowing, 1564 timeout) 1565 1566 def default_type_hints(self): 1567 return self.fn.get_type_hints() 1568 1569 def infer_output_type(self, input_type): 1570 return self.fn.infer_output_type(input_type) 1571 1572 def infer_batch_converters(self, input_element_type): 1573 # TODO: Test this code (in batch_dofn_test) 1574 if self.fn._process_batch_defined: 1575 input_batch_type = self.fn._get_input_batch_type_normalized( 1576 input_element_type) 1577 1578 if input_batch_type is None: 1579 raise TypeError( 1580 "process_batch method on {self.fn!r} does not have " 1581 "an input type annoation") 1582 1583 try: 1584 # Generate a batch converter to convert between the input type and the 1585 # (batch) input type of process_batch 1586 self.fn.input_batch_converter = BatchConverter.from_typehints( 1587 element_type=input_element_type, batch_type=input_batch_type) 1588 except TypeError as e: 1589 raise TypeError( 1590 "Failed to find a BatchConverter for the input types of DoFn " 1591 f"{self.fn!r} (element_type={input_element_type!r}, " 1592 f"batch_type={input_batch_type!r}).") from e 1593 1594 else: 1595 self.fn.input_batch_converter = None 1596 1597 if self.fn._can_yield_batches: 1598 output_batch_type = self.fn._get_output_batch_type_normalized( 1599 input_element_type) 1600 if output_batch_type is None: 1601 # TODO: Mention process method in this error 1602 raise TypeError( 1603 f"process_batch method on {self.fn!r} does not have " 1604 "a return type annoation") 1605 1606 # Generate a batch converter to convert between the output type and the 1607 # (batch) output type of process_batch 1608 output_element_type = self.infer_output_type(input_element_type) 1609 1610 try: 1611 self.fn.output_batch_converter = BatchConverter.from_typehints( 1612 element_type=output_element_type, batch_type=output_batch_type) 1613 except TypeError as e: 1614 raise TypeError( 1615 "Failed to find a BatchConverter for the *output* types of DoFn " 1616 f"{self.fn!r} (element_type={output_element_type!r}, " 1617 f"batch_type={output_batch_type!r}). Maybe you need to override " 1618 "DoFn.infer_output_type to set the output element type?") from e 1619 else: 1620 self.fn.output_batch_converter = None 1621 1622 def make_fn(self, fn, has_side_inputs): 1623 if isinstance(fn, DoFn): 1624 return fn 1625 return CallableWrapperDoFn(fn) 1626 1627 def _process_argspec_fn(self): 1628 return self.fn._process_argspec_fn() 1629 1630 def display_data(self): 1631 return { 1632 'fn': DisplayDataItem(self.fn.__class__, label='Transform Function'), 1633 'fn_dd': self.fn 1634 } 1635 1636 def expand(self, pcoll): 1637 # In the case of a stateful DoFn, warn if the key coder is not 1638 # deterministic. 1639 if self._signature.is_stateful_dofn(): 1640 kv_type_hint = pcoll.element_type 1641 if kv_type_hint and kv_type_hint != typehints.Any: 1642 coder = coders.registry.get_coder(kv_type_hint) 1643 if not coder.is_kv_coder(): 1644 raise ValueError( 1645 'Input elements to the transform %s with stateful DoFn must be ' 1646 'key-value pairs.' % self) 1647 key_coder = coder.key_coder() 1648 else: 1649 key_coder = coders.registry.get_coder(typehints.Any) 1650 1651 if not key_coder.is_deterministic(): 1652 _LOGGER.warning( 1653 'Key coder %s for transform %s with stateful DoFn may not ' 1654 'be deterministic. This may cause incorrect behavior for complex ' 1655 'key types. Consider adding an input type hint for this transform.', 1656 key_coder, 1657 self) 1658 1659 if self._signature.is_unbounded_per_element(): 1660 is_bounded = False 1661 else: 1662 is_bounded = pcoll.is_bounded 1663 1664 self.infer_batch_converters(pcoll.element_type) 1665 1666 return pvalue.PCollection.from_(pcoll, is_bounded=is_bounded) 1667 1668 def with_outputs(self, *tags, main=None, allow_unknown_tags=None): 1669 """Returns a tagged tuple allowing access to the outputs of a 1670 :class:`ParDo`. 1671 1672 The resulting object supports access to the 1673 :class:`~apache_beam.pvalue.PCollection` associated with a tag 1674 (e.g. ``o.tag``, ``o[tag]``) and iterating over the available tags 1675 (e.g. ``for tag in o: ...``). 1676 1677 Args: 1678 *tags: if non-empty, list of valid tags. If a list of valid tags is given, 1679 it will be an error to use an undeclared tag later in the pipeline. 1680 **main_kw: dictionary empty or with one key ``'main'`` defining the tag to 1681 be used for the main output (which will not have a tag associated with 1682 it). 1683 1684 Returns: 1685 ~apache_beam.pvalue.DoOutputsTuple: An object of type 1686 :class:`~apache_beam.pvalue.DoOutputsTuple` that bundles together all 1687 the outputs of a :class:`ParDo` transform and allows accessing the 1688 individual :class:`~apache_beam.pvalue.PCollection` s for each output 1689 using an ``object.tag`` syntax. 1690 1691 Raises: 1692 TypeError: if the **self** object is not a 1693 :class:`~apache_beam.pvalue.PCollection` that is the result of a 1694 :class:`ParDo` transform. 1695 ValueError: if **main_kw** contains any key other than 1696 ``'main'``. 1697 """ 1698 if main in tags: 1699 raise ValueError( 1700 'Main output tag %r must be different from side output tags %r.' % 1701 (main, tags)) 1702 return _MultiParDo(self, tags, main, allow_unknown_tags) 1703 1704 def _do_fn_info(self): 1705 return DoFnInfo.create(self.fn, self.args, self.kwargs) 1706 1707 def _get_key_and_window_coder(self, named_inputs): 1708 if named_inputs is None or not self._signature.is_stateful_dofn(): 1709 return None, None 1710 main_input = list(set(named_inputs.keys()) - set(self.side_inputs))[0] 1711 input_pcoll = named_inputs[main_input] 1712 kv_type_hint = input_pcoll.element_type 1713 if kv_type_hint and kv_type_hint != typehints.Any: 1714 coder = coders.registry.get_coder(kv_type_hint) 1715 if not coder.is_kv_coder(): 1716 raise ValueError( 1717 'Input elements to the transform %s with stateful DoFn must be ' 1718 'key-value pairs.' % self) 1719 key_coder = coder.key_coder() 1720 else: 1721 key_coder = coders.registry.get_coder(typehints.Any) 1722 window_coder = input_pcoll.windowing.windowfn.get_window_coder() 1723 return key_coder, window_coder 1724 1725 # typing: PTransform base class does not accept extra_kwargs 1726 def to_runner_api_parameter(self, context, **extra_kwargs): # type: ignore[override] 1727 # type: (PipelineContext, **typing.Any) -> typing.Tuple[str, message.Message] 1728 assert isinstance(self, ParDo), \ 1729 "expected instance of ParDo, but got %s" % self.__class__ 1730 state_specs, timer_specs = userstate.get_dofn_specs(self.fn) 1731 if state_specs or timer_specs: 1732 context.add_requirement( 1733 common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn) 1734 from apache_beam.runners.common import DoFnSignature 1735 sig = DoFnSignature(self.fn) 1736 is_splittable = sig.is_splittable_dofn() 1737 if is_splittable: 1738 restriction_coder = sig.get_restriction_coder() 1739 # restriction_coder will never be None when is_splittable is True 1740 assert restriction_coder is not None 1741 restriction_coder_id = context.coders.get_id( 1742 restriction_coder) # type: typing.Optional[str] 1743 context.add_requirement( 1744 common_urns.requirements.REQUIRES_SPLITTABLE_DOFN.urn) 1745 else: 1746 restriction_coder_id = None 1747 has_bundle_finalization = sig.has_bundle_finalization() 1748 if has_bundle_finalization: 1749 context.add_requirement( 1750 common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn) 1751 1752 # Get key_coder and window_coder for main_input. 1753 key_coder, window_coder = self._get_key_and_window_coder( 1754 extra_kwargs.get('named_inputs', None)) 1755 return ( 1756 common_urns.primitives.PAR_DO.urn, 1757 beam_runner_api_pb2.ParDoPayload( 1758 do_fn=self._do_fn_info().to_runner_api(context), 1759 requests_finalization=has_bundle_finalization, 1760 restriction_coder_id=restriction_coder_id, 1761 state_specs={ 1762 spec.name: spec.to_runner_api(context) 1763 for spec in state_specs 1764 }, 1765 timer_family_specs={ 1766 spec.name: spec.to_runner_api(context, key_coder, window_coder) 1767 for spec in timer_specs 1768 }, 1769 # It'd be nice to name these according to their actual 1770 # names/positions in the orignal argument list, but such a 1771 # transformation is currently irreversible given how 1772 # remove_objects_from_args and insert_values_in_args 1773 # are currently implemented. 1774 side_inputs={(SIDE_INPUT_PREFIX + '%s') % ix: 1775 si.to_runner_api(context) 1776 for ix, 1777 si in enumerate(self.side_inputs)})) 1778 1779 @staticmethod 1780 @PTransform.register_urn( 1781 common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload) 1782 def from_runner_api_parameter(unused_ptransform, pardo_payload, context): 1783 fn, args, kwargs, si_tags_and_types, windowing = pickler.loads( 1784 DoFnInfo.from_runner_api( 1785 pardo_payload.do_fn, context).serialized_dofn_data()) 1786 if si_tags_and_types: 1787 raise NotImplementedError('explicit side input data') 1788 elif windowing: 1789 raise NotImplementedError('explicit windowing') 1790 result = ParDo(fn, *args, **kwargs) 1791 # This is an ordered list stored as a dict (see the comments in 1792 # to_runner_api_parameter above). 1793 indexed_side_inputs = [( 1794 get_sideinput_index(tag), 1795 pvalue.AsSideInput.from_runner_api(si, context)) for tag, 1796 si in pardo_payload.side_inputs.items()] 1797 result.side_inputs = [si for _, si in sorted(indexed_side_inputs)] 1798 return result 1799 1800 def runner_api_requires_keyed_input(self): 1801 return userstate.is_stateful_dofn(self.fn) 1802 1803 def get_restriction_coder(self): 1804 """Returns `restriction coder if `DoFn` of this `ParDo` is a SDF. 1805 1806 Returns `None` otherwise. 1807 """ 1808 from apache_beam.runners.common import DoFnSignature 1809 return DoFnSignature(self.fn).get_restriction_coder() 1810 1811 def _add_type_constraint_from_consumer(self, full_label, input_type_hints): 1812 if not hasattr(self.fn, '_runtime_output_constraints'): 1813 self.fn._runtime_output_constraints = {} 1814 self.fn._runtime_output_constraints[full_label] = input_type_hints 1815 1816 1817 class _MultiParDo(PTransform): 1818 def __init__(self, do_transform, tags, main_tag, allow_unknown_tags=None): 1819 super().__init__(do_transform.label) 1820 self._do_transform = do_transform 1821 self._tags = tags 1822 self._main_tag = main_tag 1823 self._allow_unknown_tags = allow_unknown_tags 1824 1825 def expand(self, pcoll): 1826 _ = pcoll | self._do_transform 1827 return pvalue.DoOutputsTuple( 1828 pcoll.pipeline, 1829 self._do_transform, 1830 self._tags, 1831 self._main_tag, 1832 self._allow_unknown_tags) 1833 1834 1835 class DoFnInfo(object): 1836 """This class represents the state in the ParDoPayload's function spec, 1837 which is the actual DoFn together with some data required for invoking it. 1838 """ 1839 @staticmethod 1840 def register_stateless_dofn(urn): 1841 def wrapper(cls): 1842 StatelessDoFnInfo.REGISTERED_DOFNS[urn] = cls 1843 cls._stateless_dofn_urn = urn 1844 return cls 1845 1846 return wrapper 1847 1848 @classmethod 1849 def create(cls, fn, args, kwargs): 1850 if hasattr(fn, '_stateless_dofn_urn'): 1851 assert not args and not kwargs 1852 return StatelessDoFnInfo(fn._stateless_dofn_urn) 1853 else: 1854 return PickledDoFnInfo(cls._pickled_do_fn_info(fn, args, kwargs)) 1855 1856 @staticmethod 1857 def from_runner_api(spec, unused_context): 1858 if spec.urn == python_urns.PICKLED_DOFN_INFO: 1859 return PickledDoFnInfo(spec.payload) 1860 elif spec.urn in StatelessDoFnInfo.REGISTERED_DOFNS: 1861 return StatelessDoFnInfo(spec.urn) 1862 else: 1863 raise ValueError('Unexpected DoFn type: %s' % spec.urn) 1864 1865 @staticmethod 1866 def _pickled_do_fn_info(fn, args, kwargs): 1867 # This can be cleaned up once all runners move to portability. 1868 return pickler.dumps((fn, args, kwargs, None, None)) 1869 1870 def serialized_dofn_data(self): 1871 raise NotImplementedError(type(self)) 1872 1873 1874 class PickledDoFnInfo(DoFnInfo): 1875 def __init__(self, serialized_data): 1876 self._serialized_data = serialized_data 1877 1878 def serialized_dofn_data(self): 1879 return self._serialized_data 1880 1881 def to_runner_api(self, unused_context): 1882 return beam_runner_api_pb2.FunctionSpec( 1883 urn=python_urns.PICKLED_DOFN_INFO, payload=self._serialized_data) 1884 1885 1886 class StatelessDoFnInfo(DoFnInfo): 1887 1888 REGISTERED_DOFNS = {} # type: typing.Dict[str, typing.Type[DoFn]] 1889 1890 def __init__(self, urn): 1891 # type: (str) -> None 1892 assert urn in self.REGISTERED_DOFNS 1893 self._urn = urn 1894 1895 def serialized_dofn_data(self): 1896 return self._pickled_do_fn_info(self.REGISTERED_DOFNS[self._urn](), (), {}) 1897 1898 def to_runner_api(self, unused_context): 1899 return beam_runner_api_pb2.FunctionSpec(urn=self._urn) 1900 1901 1902 def FlatMap(fn, *args, **kwargs): # pylint: disable=invalid-name 1903 """:func:`FlatMap` is like :class:`ParDo` except it takes a callable to 1904 specify the transformation. 1905 1906 The callable must return an iterable for each element of the input 1907 :class:`~apache_beam.pvalue.PCollection`. The elements of these iterables will 1908 be flattened into the output :class:`~apache_beam.pvalue.PCollection`. 1909 1910 Args: 1911 fn (callable): a callable object. 1912 *args: positional arguments passed to the transform callable. 1913 **kwargs: keyword arguments passed to the transform callable. 1914 1915 Returns: 1916 ~apache_beam.pvalue.PCollection: 1917 A :class:`~apache_beam.pvalue.PCollection` containing the 1918 :func:`FlatMap` outputs. 1919 1920 Raises: 1921 TypeError: If the **fn** passed as argument is not a callable. 1922 Typical error is to pass a :class:`DoFn` instance which is supported only 1923 for :class:`ParDo`. 1924 """ 1925 label = 'FlatMap(%s)' % ptransform.label_from_callable(fn) 1926 if not callable(fn): 1927 raise TypeError( 1928 'FlatMap can be used only with callable objects. ' 1929 'Received %r instead.' % (fn)) 1930 1931 pardo = ParDo(CallableWrapperDoFn(fn), *args, **kwargs) 1932 pardo.label = label 1933 return pardo 1934 1935 1936 def Map(fn, *args, **kwargs): # pylint: disable=invalid-name 1937 """:func:`Map` is like :func:`FlatMap` except its callable returns only a 1938 single element. 1939 1940 Args: 1941 fn (callable): a callable object. 1942 *args: positional arguments passed to the transform callable. 1943 **kwargs: keyword arguments passed to the transform callable. 1944 1945 Returns: 1946 ~apache_beam.pvalue.PCollection: 1947 A :class:`~apache_beam.pvalue.PCollection` containing the 1948 :func:`Map` outputs. 1949 1950 Raises: 1951 TypeError: If the **fn** passed as argument is not a callable. 1952 Typical error is to pass a :class:`DoFn` instance which is supported only 1953 for :class:`ParDo`. 1954 """ 1955 if not callable(fn): 1956 raise TypeError( 1957 'Map can be used only with callable objects. ' 1958 'Received %r instead.' % (fn)) 1959 from apache_beam.transforms.util import fn_takes_side_inputs 1960 if fn_takes_side_inputs(fn): 1961 wrapper = lambda x, *args, **kwargs: [fn(x, *args, **kwargs)] 1962 else: 1963 wrapper = lambda x: [fn(x)] 1964 1965 label = 'Map(%s)' % ptransform.label_from_callable(fn) 1966 1967 # TODO. What about callable classes? 1968 if hasattr(fn, '__name__'): 1969 wrapper.__name__ = fn.__name__ 1970 1971 # Proxy the type-hint information from the original function to this new 1972 # wrapped function. 1973 type_hints = get_type_hints(fn).with_defaults( 1974 typehints.decorators.IOTypeHints.from_callable(fn)) 1975 if type_hints.input_types is not None: 1976 wrapper = with_input_types( 1977 *type_hints.input_types[0], **type_hints.input_types[1])( 1978 wrapper) 1979 output_hint = type_hints.simple_output_type(label) 1980 if output_hint: 1981 wrapper = with_output_types( 1982 typehints.Iterable[_strip_output_annotations(output_hint)])( 1983 wrapper) 1984 # pylint: disable=protected-access 1985 wrapper._argspec_fn = fn 1986 # pylint: enable=protected-access 1987 1988 pardo = FlatMap(wrapper, *args, **kwargs) 1989 pardo.label = label 1990 return pardo 1991 1992 1993 def MapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name 1994 r""":func:`MapTuple` is like :func:`Map` but expects tuple inputs and 1995 flattens them into multiple input arguments. 1996 1997 beam.MapTuple(lambda a, b, ...: ...) 1998 1999 In other words 2000 2001 beam.MapTuple(fn) 2002 2003 is equivalent to 2004 2005 beam.Map(lambda element, ...: fn(\*element, ...)) 2006 2007 This can be useful when processing a PCollection of tuples 2008 (e.g. key-value pairs). 2009 2010 Args: 2011 fn (callable): a callable object. 2012 *args: positional arguments passed to the transform callable. 2013 **kwargs: keyword arguments passed to the transform callable. 2014 2015 Returns: 2016 ~apache_beam.pvalue.PCollection: 2017 A :class:`~apache_beam.pvalue.PCollection` containing the 2018 :func:`MapTuple` outputs. 2019 2020 Raises: 2021 TypeError: If the **fn** passed as argument is not a callable. 2022 Typical error is to pass a :class:`DoFn` instance which is supported only 2023 for :class:`ParDo`. 2024 """ 2025 if not callable(fn): 2026 raise TypeError( 2027 'MapTuple can be used only with callable objects. ' 2028 'Received %r instead.' % (fn)) 2029 2030 label = 'MapTuple(%s)' % ptransform.label_from_callable(fn) 2031 2032 arg_names, defaults = get_function_args_defaults(fn) 2033 num_defaults = len(defaults) 2034 if num_defaults < len(args) + len(kwargs): 2035 raise TypeError('Side inputs must have defaults for MapTuple.') 2036 2037 if defaults or args or kwargs: 2038 wrapper = lambda x, *args, **kwargs: [fn(*(tuple(x) + args), **kwargs)] 2039 else: 2040 wrapper = lambda x: [fn(*x)] 2041 2042 # Proxy the type-hint information from the original function to this new 2043 # wrapped function. 2044 type_hints = get_type_hints(fn).with_defaults( 2045 typehints.decorators.IOTypeHints.from_callable(fn)) 2046 if type_hints.input_types is not None: 2047 # TODO(BEAM-14052): ignore input hints, as we do not have enough 2048 # information to infer the input type hint of the wrapper function. 2049 pass 2050 output_hint = type_hints.simple_output_type(label) 2051 if output_hint: 2052 wrapper = with_output_types( 2053 typehints.Iterable[_strip_output_annotations(output_hint)])( 2054 wrapper) 2055 2056 # Replace the first (args) component. 2057 modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:] 2058 modified_argspec = (modified_arg_names, defaults) 2059 pardo = ParDo( 2060 CallableWrapperDoFn(wrapper, fullargspec=modified_argspec), 2061 *args, 2062 **kwargs) 2063 pardo.label = label 2064 return pardo 2065 2066 2067 def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name 2068 r""":func:`FlatMapTuple` is like :func:`FlatMap` but expects tuple inputs and 2069 flattens them into multiple input arguments. 2070 2071 beam.FlatMapTuple(lambda a, b, ...: ...) 2072 2073 is equivalent to Python 2 2074 2075 beam.FlatMap(lambda (a, b, ...), ...: ...) 2076 2077 In other words 2078 2079 beam.FlatMapTuple(fn) 2080 2081 is equivalent to 2082 2083 beam.FlatMap(lambda element, ...: fn(\*element, ...)) 2084 2085 This can be useful when processing a PCollection of tuples 2086 (e.g. key-value pairs). 2087 2088 Args: 2089 fn (callable): a callable object. 2090 *args: positional arguments passed to the transform callable. 2091 **kwargs: keyword arguments passed to the transform callable. 2092 2093 Returns: 2094 ~apache_beam.pvalue.PCollection: 2095 A :class:`~apache_beam.pvalue.PCollection` containing the 2096 :func:`FlatMapTuple` outputs. 2097 2098 Raises: 2099 TypeError: If the **fn** passed as argument is not a callable. 2100 Typical error is to pass a :class:`DoFn` instance which is supported only 2101 for :class:`ParDo`. 2102 """ 2103 if not callable(fn): 2104 raise TypeError( 2105 'FlatMapTuple can be used only with callable objects. ' 2106 'Received %r instead.' % (fn)) 2107 2108 label = 'FlatMapTuple(%s)' % ptransform.label_from_callable(fn) 2109 2110 arg_names, defaults = get_function_args_defaults(fn) 2111 num_defaults = len(defaults) 2112 if num_defaults < len(args) + len(kwargs): 2113 raise TypeError('Side inputs must have defaults for FlatMapTuple.') 2114 2115 if defaults or args or kwargs: 2116 wrapper = lambda x, *args, **kwargs: fn(*(tuple(x) + args), **kwargs) 2117 else: 2118 wrapper = lambda x: fn(*x) 2119 2120 # Proxy the type-hint information from the original function to this new 2121 # wrapped function. 2122 type_hints = get_type_hints(fn).with_defaults( 2123 typehints.decorators.IOTypeHints.from_callable(fn)) 2124 if type_hints.input_types is not None: 2125 # TODO(BEAM-14052): ignore input hints, as we do not have enough 2126 # information to infer the input type hint of the wrapper function. 2127 pass 2128 output_hint = type_hints.simple_output_type(label) 2129 if output_hint: 2130 wrapper = with_output_types(_strip_output_annotations(output_hint))(wrapper) 2131 2132 # Replace the first (args) component. 2133 modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:] 2134 modified_argspec = (modified_arg_names, defaults) 2135 pardo = ParDo( 2136 CallableWrapperDoFn(wrapper, fullargspec=modified_argspec), 2137 *args, 2138 **kwargs) 2139 pardo.label = label 2140 return pardo 2141 2142 2143 class _ExceptionHandlingWrapper(ptransform.PTransform): 2144 """Implementation of ParDo.with_exception_handling.""" 2145 def __init__( 2146 self, 2147 fn, 2148 args, 2149 kwargs, 2150 main_tag, 2151 dead_letter_tag, 2152 exc_class, 2153 partial, 2154 use_subprocess, 2155 threshold, 2156 threshold_windowing, 2157 timeout): 2158 if partial and use_subprocess: 2159 raise ValueError('partial and use_subprocess are mutually incompatible.') 2160 self._fn = fn 2161 self._args = args 2162 self._kwargs = kwargs 2163 self._main_tag = main_tag 2164 self._dead_letter_tag = dead_letter_tag 2165 self._exc_class = exc_class 2166 self._partial = partial 2167 self._use_subprocess = use_subprocess 2168 self._threshold = threshold 2169 self._threshold_windowing = threshold_windowing 2170 self._timeout = timeout 2171 2172 def expand(self, pcoll): 2173 if self._use_subprocess: 2174 wrapped_fn = _SubprocessDoFn(self._fn, timeout=self._timeout) 2175 elif self._timeout: 2176 wrapped_fn = _TimeoutDoFn(self._fn, timeout=self._timeout) 2177 else: 2178 wrapped_fn = self._fn 2179 result = pcoll | ParDo( 2180 _ExceptionHandlingWrapperDoFn( 2181 wrapped_fn, self._dead_letter_tag, self._exc_class, self._partial), 2182 *self._args, 2183 **self._kwargs).with_outputs( 2184 self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True) 2185 2186 if self._threshold < 1.0: 2187 2188 class MaybeWindow(ptransform.PTransform): 2189 @staticmethod 2190 def expand(pcoll): 2191 if self._threshold_windowing: 2192 return pcoll | WindowInto(self._threshold_windowing) 2193 else: 2194 return pcoll 2195 2196 input_count_view = pcoll | 'CountTotal' >> ( 2197 MaybeWindow() | Map(lambda _: 1) 2198 | CombineGlobally(sum).as_singleton_view()) 2199 bad_count_pcoll = result[self._dead_letter_tag] | 'CountBad' >> ( 2200 MaybeWindow() | Map(lambda _: 1) 2201 | CombineGlobally(sum).without_defaults()) 2202 2203 def check_threshold(bad, total, threshold, window=DoFn.WindowParam): 2204 if bad > total * threshold: 2205 raise ValueError( 2206 'The number of failing elements within the window %r ' 2207 'exceeded threshold: %s / %s = %s > %s' % 2208 (window, bad, total, bad / total, threshold)) 2209 2210 _ = bad_count_pcoll | Map( 2211 check_threshold, input_count_view, self._threshold) 2212 2213 return result 2214 2215 2216 class _ExceptionHandlingWrapperDoFn(DoFn): 2217 def __init__(self, fn, dead_letter_tag, exc_class, partial): 2218 self._fn = fn 2219 self._dead_letter_tag = dead_letter_tag 2220 self._exc_class = exc_class 2221 self._partial = partial 2222 2223 def __getattribute__(self, name): 2224 if (name.startswith('__') or name in self.__dict__ or 2225 name in _ExceptionHandlingWrapperDoFn.__dict__): 2226 return object.__getattribute__(self, name) 2227 else: 2228 return getattr(self._fn, name) 2229 2230 def process(self, *args, **kwargs): 2231 try: 2232 result = self._fn.process(*args, **kwargs) 2233 if not self._partial: 2234 # Don't emit any results until we know there will be no errors. 2235 result = list(result) 2236 yield from result 2237 except self._exc_class as exn: 2238 yield pvalue.TaggedOutput( 2239 self._dead_letter_tag, 2240 ( 2241 args[0], ( 2242 type(exn), 2243 repr(exn), 2244 traceback.format_exception(*sys.exc_info())))) 2245 2246 2247 class _SubprocessDoFn(DoFn): 2248 """Process method run in a subprocess, turning hard crashes into exceptions. 2249 """ 2250 def __init__(self, fn, timeout=None): 2251 self._fn = fn 2252 self._serialized_fn = pickler.dumps(fn) 2253 self._timeout = timeout 2254 2255 def __getattribute__(self, name): 2256 if (name.startswith('__') or name in self.__dict__ or 2257 name in type(self).__dict__): 2258 return object.__getattribute__(self, name) 2259 else: 2260 return getattr(self._fn, name) 2261 2262 def setup(self): 2263 self._pool = None 2264 2265 def start_bundle(self): 2266 # The pool is initialized lazily, including calls to setup and start_bundle. 2267 # This allows us to continue processing elements after a crash. 2268 pass 2269 2270 def process(self, *args, **kwargs): 2271 return self._call_remote(self._remote_process, *args, **kwargs) 2272 2273 def finish_bundle(self): 2274 self._call_remote(self._remote_finish_bundle) 2275 2276 def teardown(self): 2277 self._call_remote(self._remote_teardown) 2278 self._terminate_pool() 2279 2280 def _call_remote(self, method, *args, **kwargs): 2281 if self._pool is None: 2282 self._pool = concurrent.futures.ProcessPoolExecutor(1) 2283 self._pool.submit(self._remote_init, self._serialized_fn).result() 2284 try: 2285 return self._pool.submit(method, *args, **kwargs).result( 2286 self._timeout if method == self._remote_process else None) 2287 except (concurrent.futures.process.BrokenProcessPool, 2288 TimeoutError, 2289 concurrent.futures._base.TimeoutError): 2290 self._terminate_pool() 2291 raise 2292 2293 def _terminate_pool(self): 2294 """Forcibly terminate the pool, not leaving any live subprocesses.""" 2295 pool = self._pool 2296 self._pool = None 2297 processes = list(pool._processes.values()) 2298 pool.shutdown(wait=False) 2299 for p in processes: 2300 if p.is_alive(): 2301 p.kill() 2302 time.sleep(1) 2303 for p in processes: 2304 if p.is_alive(): 2305 p.terminate() 2306 2307 # These are classmethods to avoid picking the state of self. 2308 # They should only be called in an isolated process, so there's no concern 2309 # about sharing state or thread safety. 2310 2311 @classmethod 2312 def _remote_init(cls, serialized_fn): 2313 cls._serialized_fn = serialized_fn 2314 cls._fn = None 2315 cls._started = False 2316 2317 @classmethod 2318 def _remote_process(cls, *args, **kwargs): 2319 if cls._fn is None: 2320 cls._fn = pickler.loads(cls._serialized_fn) 2321 cls._fn.setup() 2322 if not cls._started: 2323 cls._fn.start_bundle() 2324 cls._started = True 2325 result = cls._fn.process(*args, **kwargs) 2326 if result: 2327 # Don't return generator objects. 2328 result = list(result) 2329 return result 2330 2331 @classmethod 2332 def _remote_finish_bundle(cls): 2333 if cls._started: 2334 cls._started = False 2335 if cls._fn.finish_bundle(): 2336 # This is because we restart and re-initialize the pool if it crashed. 2337 raise RuntimeError( 2338 "Returning elements from _SubprocessDoFn.finish_bundle not safe.") 2339 2340 @classmethod 2341 def _remote_teardown(cls): 2342 if cls._fn: 2343 cls._fn.teardown() 2344 cls._fn = None 2345 2346 2347 class _TimeoutDoFn(DoFn): 2348 """Process method run in a separate thread allowing timeouts. 2349 """ 2350 def __init__(self, fn, timeout=None): 2351 self._fn = fn 2352 self._timeout = timeout 2353 self._pool = None 2354 2355 def __getattribute__(self, name): 2356 if (name.startswith('__') or name in self.__dict__ or 2357 name in type(self).__dict__): 2358 return object.__getattribute__(self, name) 2359 else: 2360 return getattr(self._fn, name) 2361 2362 def process(self, *args, **kwargs): 2363 if self._pool is None: 2364 self._pool = concurrent.futures.ThreadPoolExecutor(10) 2365 # Ensure we iterate over the entire output list in the given amount of time. 2366 try: 2367 return self._pool.submit( 2368 lambda: list(self._fn.process(*args, **kwargs))).result( 2369 self._timeout) 2370 except TimeoutError: 2371 self._pool.shutdown(wait=False) 2372 self._pool = None 2373 raise 2374 2375 def teardown(self): 2376 try: 2377 self._fn.teardown() 2378 finally: 2379 if self._pool is not None: 2380 self._pool.shutdown(wait=False) 2381 self._pool = None 2382 2383 2384 def Filter(fn, *args, **kwargs): # pylint: disable=invalid-name 2385 """:func:`Filter` is a :func:`FlatMap` with its callable filtering out 2386 elements. 2387 2388 Filter accepts a function that keeps elements that return True, and filters 2389 out the remaining elements. 2390 2391 Args: 2392 fn (``Callable[..., bool]``): a callable object. First argument will be an 2393 element. 2394 *args: positional arguments passed to the transform callable. 2395 **kwargs: keyword arguments passed to the transform callable. 2396 2397 Returns: 2398 ~apache_beam.pvalue.PCollection: 2399 A :class:`~apache_beam.pvalue.PCollection` containing the 2400 :func:`Filter` outputs. 2401 2402 Raises: 2403 TypeError: If the **fn** passed as argument is not a callable. 2404 Typical error is to pass a :class:`DoFn` instance which is supported only 2405 for :class:`ParDo`. 2406 """ 2407 if not callable(fn): 2408 raise TypeError( 2409 'Filter can be used only with callable objects. ' 2410 'Received %r instead.' % (fn)) 2411 wrapper = lambda x, *args, **kwargs: [x] if fn(x, *args, **kwargs) else [] 2412 2413 label = 'Filter(%s)' % ptransform.label_from_callable(fn) 2414 2415 # TODO: What about callable classes? 2416 if hasattr(fn, '__name__'): 2417 wrapper.__name__ = fn.__name__ 2418 2419 # Get type hints from this instance or the callable. Do not use output type 2420 # hints from the callable (which should be bool if set). 2421 fn_type_hints = typehints.decorators.IOTypeHints.from_callable(fn) 2422 if fn_type_hints is not None: 2423 fn_type_hints = fn_type_hints.with_output_types() 2424 type_hints = get_type_hints(fn).with_defaults(fn_type_hints) 2425 2426 # Proxy the type-hint information from the function being wrapped, setting the 2427 # output type to be the same as the input type. 2428 if type_hints.input_types is not None: 2429 wrapper = with_input_types( 2430 *type_hints.input_types[0], **type_hints.input_types[1])( 2431 wrapper) 2432 output_hint = type_hints.simple_output_type(label) 2433 if (output_hint is None and get_type_hints(wrapper).input_types and 2434 get_type_hints(wrapper).input_types[0]): 2435 output_hint = get_type_hints(wrapper).input_types[0][0] 2436 if output_hint: 2437 wrapper = with_output_types( 2438 typehints.Iterable[_strip_output_annotations(output_hint)])( 2439 wrapper) 2440 # pylint: disable=protected-access 2441 wrapper._argspec_fn = fn 2442 # pylint: enable=protected-access 2443 2444 pardo = FlatMap(wrapper, *args, **kwargs) 2445 pardo.label = label 2446 return pardo 2447 2448 2449 def _combine_payload(combine_fn, context): 2450 return beam_runner_api_pb2.CombinePayload( 2451 combine_fn=combine_fn.to_runner_api(context), 2452 accumulator_coder_id=context.coders.get_id( 2453 combine_fn.get_accumulator_coder())) 2454 2455 2456 class CombineGlobally(PTransform): 2457 """A :class:`CombineGlobally` transform. 2458 2459 Reduces a :class:`~apache_beam.pvalue.PCollection` to a single value by 2460 progressively applying a :class:`CombineFn` to portions of the 2461 :class:`~apache_beam.pvalue.PCollection` (and to intermediate values created 2462 thereby). See documentation in :class:`CombineFn` for details on the specifics 2463 on how :class:`CombineFn` s are applied. 2464 2465 Args: 2466 pcoll (~apache_beam.pvalue.PCollection): 2467 a :class:`~apache_beam.pvalue.PCollection` to be reduced into a single 2468 value. 2469 fn (callable): a :class:`CombineFn` object that will be called to 2470 progressively reduce the :class:`~apache_beam.pvalue.PCollection` into 2471 single values, or a callable suitable for wrapping by 2472 :class:`~apache_beam.transforms.core.CallableWrapperCombineFn`. 2473 *args: positional arguments passed to the :class:`CombineFn` object. 2474 **kwargs: keyword arguments passed to the :class:`CombineFn` object. 2475 2476 Raises: 2477 TypeError: If the output type of the input 2478 :class:`~apache_beam.pvalue.PCollection` is not compatible 2479 with ``Iterable[A]``. 2480 2481 Returns: 2482 ~apache_beam.pvalue.PCollection: A single-element 2483 :class:`~apache_beam.pvalue.PCollection` containing the main output of 2484 the :class:`CombineGlobally` transform. 2485 2486 Note that the positional and keyword arguments will be processed in order 2487 to detect :class:`~apache_beam.pvalue.PValue` s that will be computed as side 2488 inputs to the transform. 2489 During pipeline execution whenever the :class:`CombineFn` object gets executed 2490 (i.e. any of the :class:`CombineFn` methods get called), the 2491 :class:`~apache_beam.pvalue.PValue` arguments will be replaced by their 2492 actual value in the exact position where they appear in the argument lists. 2493 """ 2494 has_defaults = True 2495 as_view = False 2496 fanout = None # type: typing.Optional[int] 2497 2498 def __init__(self, fn, *args, **kwargs): 2499 if not (isinstance(fn, CombineFn) or callable(fn)): 2500 raise TypeError( 2501 'CombineGlobally can be used only with combineFn objects. ' 2502 'Received %r instead.' % (fn)) 2503 2504 super().__init__() 2505 self.fn = fn 2506 self.args = args 2507 self.kwargs = kwargs 2508 2509 def display_data(self): 2510 return { 2511 'combine_fn': DisplayDataItem( 2512 self.fn.__class__, label='Combine Function'), 2513 'combine_fn_dd': self.fn, 2514 } 2515 2516 def default_label(self): 2517 if self.fanout is None: 2518 return '%s(%s)' % ( 2519 self.__class__.__name__, ptransform.label_from_callable(self.fn)) 2520 else: 2521 return '%s(%s, fanout=%s)' % ( 2522 self.__class__.__name__, 2523 ptransform.label_from_callable(self.fn), 2524 self.fanout) 2525 2526 def _clone(self, **extra_attributes): 2527 clone = copy.copy(self) 2528 clone.__dict__.update(extra_attributes) 2529 return clone 2530 2531 def with_fanout(self, fanout): 2532 return self._clone(fanout=fanout) 2533 2534 def with_defaults(self, has_defaults=True): 2535 return self._clone(has_defaults=has_defaults) 2536 2537 def without_defaults(self): 2538 return self.with_defaults(False) 2539 2540 def as_singleton_view(self): 2541 return self._clone(as_view=True) 2542 2543 def expand(self, pcoll): 2544 def add_input_types(transform): 2545 type_hints = self.get_type_hints() 2546 if type_hints.input_types: 2547 return transform.with_input_types(type_hints.input_types[0][0]) 2548 return transform 2549 2550 combine_fn = CombineFn.maybe_from_callable( 2551 self.fn, has_side_inputs=self.args or self.kwargs) 2552 combine_per_key = CombinePerKey(combine_fn, *self.args, **self.kwargs) 2553 if self.fanout: 2554 combine_per_key = combine_per_key.with_hot_key_fanout(self.fanout) 2555 2556 combined = ( 2557 pcoll 2558 | 'KeyWithVoid' >> add_input_types( 2559 ParDo(_KeyWithNone()).with_output_types( 2560 typehints.KV[None, pcoll.element_type])) 2561 | 'CombinePerKey' >> combine_per_key 2562 | 'UnKey' >> Map(lambda k_v: k_v[1])) 2563 2564 if not self.has_defaults and not self.as_view: 2565 return combined 2566 2567 elif self.as_view: 2568 if self.has_defaults: 2569 try: 2570 combine_fn.setup(*self.args, **self.kwargs) 2571 # This is called in the main program, but cannot be avoided 2572 # in the as_view case as it must be available to all windows. 2573 default_value = combine_fn.apply([], *self.args, **self.kwargs) 2574 finally: 2575 combine_fn.teardown(*self.args, **self.kwargs) 2576 else: 2577 default_value = pvalue.AsSingleton._NO_DEFAULT 2578 return pvalue.AsSingleton(combined, default_value=default_value) 2579 2580 else: 2581 if pcoll.windowing.windowfn != GlobalWindows(): 2582 raise ValueError( 2583 "Default values are not yet supported in CombineGlobally() if the " 2584 "output PCollection is not windowed by GlobalWindows. " 2585 "Instead, use CombineGlobally().without_defaults() to output " 2586 "an empty PCollection if the input PCollection is empty, " 2587 "or CombineGlobally().as_singleton_view() to get the default " 2588 "output of the CombineFn if the input PCollection is empty.") 2589 2590 def typed(transform): 2591 # TODO(robertwb): We should infer this. 2592 if combined.element_type: 2593 return transform.with_output_types(combined.element_type) 2594 return transform 2595 2596 # Capture in closure (avoiding capturing self). 2597 args, kwargs = self.args, self.kwargs 2598 2599 def inject_default(_, combined): 2600 if combined: 2601 assert len(combined) == 1 2602 return combined[0] 2603 else: 2604 try: 2605 combine_fn.setup(*args, **kwargs) 2606 default = combine_fn.apply([], *args, **kwargs) 2607 finally: 2608 combine_fn.teardown(*args, **kwargs) 2609 return default 2610 2611 return ( 2612 pcoll.pipeline 2613 | 'DoOnce' >> Create([None]) 2614 | 'InjectDefault' >> typed( 2615 Map(inject_default, pvalue.AsList(combined)))) 2616 2617 @staticmethod 2618 @PTransform.register_urn( 2619 common_urns.composites.COMBINE_GLOBALLY.urn, 2620 beam_runner_api_pb2.CombinePayload) 2621 def from_runner_api_parameter(unused_ptransform, combine_payload, context): 2622 return CombineGlobally( 2623 CombineFn.from_runner_api(combine_payload.combine_fn, context)) 2624 2625 2626 @DoFnInfo.register_stateless_dofn(python_urns.KEY_WITH_NONE_DOFN) 2627 class _KeyWithNone(DoFn): 2628 def process(self, v): 2629 yield None, v 2630 2631 2632 class CombinePerKey(PTransformWithSideInputs): 2633 """A per-key Combine transform. 2634 2635 Identifies sets of values associated with the same key in the input 2636 PCollection, then applies a CombineFn to condense those sets to single 2637 values. See documentation in CombineFn for details on the specifics on how 2638 CombineFns are applied. 2639 2640 Args: 2641 pcoll: input pcollection. 2642 fn: instance of CombineFn to apply to all values under the same key in 2643 pcoll, or a callable whose signature is ``f(iterable, *args, **kwargs)`` 2644 (e.g., sum, max). 2645 *args: arguments and side inputs, passed directly to the CombineFn. 2646 **kwargs: arguments and side inputs, passed directly to the CombineFn. 2647 2648 Returns: 2649 A PObject holding the result of the combine operation. 2650 """ 2651 def with_hot_key_fanout(self, fanout): 2652 """A per-key combine operation like self but with two levels of aggregation. 2653 2654 If a given key is produced by too many upstream bundles, the final 2655 reduction can become a bottleneck despite partial combining being lifted 2656 pre-GroupByKey. In these cases it can be helpful to perform intermediate 2657 partial aggregations in parallel and then re-group to peform a final 2658 (per-key) combine. This is also useful for high-volume keys in streaming 2659 where combiners are not generally lifted for latency reasons. 2660 2661 Note that a fanout greater than 1 requires the data to be sent through 2662 two GroupByKeys, and a high fanout can also result in more shuffle data 2663 due to less per-bundle combining. Setting the fanout for a key at 1 or less 2664 places values on the "cold key" path that skip the intermediate level of 2665 aggregation. 2666 2667 Args: 2668 fanout: either None, for no fanout, an int, for a constant-degree fanout, 2669 or a callable mapping keys to a key-specific degree of fanout. 2670 2671 Returns: 2672 A per-key combining PTransform with the specified fanout. 2673 """ 2674 from apache_beam.transforms.combiners import curry_combine_fn 2675 if fanout is None: 2676 return self 2677 else: 2678 return _CombinePerKeyWithHotKeyFanout( 2679 curry_combine_fn(self.fn, self.args, self.kwargs), fanout) 2680 2681 def display_data(self): 2682 return { 2683 'combine_fn': DisplayDataItem( 2684 self.fn.__class__, label='Combine Function'), 2685 'combine_fn_dd': self.fn 2686 } 2687 2688 def make_fn(self, fn, has_side_inputs): 2689 self._fn_label = ptransform.label_from_callable(fn) 2690 return CombineFn.maybe_from_callable(fn, has_side_inputs) 2691 2692 def default_label(self): 2693 return '%s(%s)' % (self.__class__.__name__, self._fn_label) 2694 2695 def _process_argspec_fn(self): 2696 return lambda element, *args, **kwargs: None 2697 2698 def expand(self, pcoll): 2699 args, kwargs = util.insert_values_in_args( 2700 self.args, self.kwargs, self.side_inputs) 2701 return pcoll | GroupByKey() | 'Combine' >> CombineValues( 2702 self.fn, *args, **kwargs) 2703 2704 def default_type_hints(self): 2705 result = self.fn.get_type_hints() 2706 k = typehints.TypeVariable('K') 2707 if result.input_types: 2708 args, kwargs = result.input_types 2709 args = (typehints.Tuple[k, args[0]], ) + args[1:] 2710 result = result.with_input_types(*args, **kwargs) 2711 else: 2712 result = result.with_input_types(typehints.Tuple[k, typehints.Any]) 2713 if result.output_types: 2714 main_output_type = result.simple_output_type('') 2715 result = result.with_output_types(typehints.Tuple[k, main_output_type]) 2716 else: 2717 result = result.with_output_types(typehints.Tuple[k, typehints.Any]) 2718 return result 2719 2720 def to_runner_api_parameter( 2721 self, 2722 context, # type: PipelineContext 2723 ): 2724 # type: (...) -> typing.Tuple[str, beam_runner_api_pb2.CombinePayload] 2725 if self.args or self.kwargs: 2726 from apache_beam.transforms.combiners import curry_combine_fn 2727 combine_fn = curry_combine_fn(self.fn, self.args, self.kwargs) 2728 else: 2729 combine_fn = self.fn 2730 return ( 2731 common_urns.composites.COMBINE_PER_KEY.urn, 2732 _combine_payload(combine_fn, context)) 2733 2734 @staticmethod 2735 @PTransform.register_urn( 2736 common_urns.composites.COMBINE_PER_KEY.urn, 2737 beam_runner_api_pb2.CombinePayload) 2738 def from_runner_api_parameter(unused_ptransform, combine_payload, context): 2739 return CombinePerKey( 2740 CombineFn.from_runner_api(combine_payload.combine_fn, context)) 2741 2742 def runner_api_requires_keyed_input(self): 2743 return True 2744 2745 2746 # TODO(robertwb): Rename to CombineGroupedValues? 2747 class CombineValues(PTransformWithSideInputs): 2748 def make_fn(self, fn, has_side_inputs): 2749 return CombineFn.maybe_from_callable(fn, has_side_inputs) 2750 2751 def expand(self, pcoll): 2752 args, kwargs = util.insert_values_in_args( 2753 self.args, self.kwargs, self.side_inputs) 2754 2755 input_type = pcoll.element_type 2756 key_type = None 2757 if input_type is not None: 2758 key_type, _ = input_type.tuple_types 2759 2760 runtime_type_check = ( 2761 pcoll.pipeline._options.view_as(TypeOptions).runtime_type_check) 2762 return pcoll | ParDo( 2763 CombineValuesDoFn(key_type, self.fn, runtime_type_check), 2764 *args, 2765 **kwargs) 2766 2767 def to_runner_api_parameter(self, context): 2768 if self.args or self.kwargs: 2769 from apache_beam.transforms.combiners import curry_combine_fn 2770 combine_fn = curry_combine_fn(self.fn, self.args, self.kwargs) 2771 else: 2772 combine_fn = self.fn 2773 return ( 2774 common_urns.combine_components.COMBINE_GROUPED_VALUES.urn, 2775 _combine_payload(combine_fn, context)) 2776 2777 @staticmethod 2778 @PTransform.register_urn( 2779 common_urns.combine_components.COMBINE_GROUPED_VALUES.urn, 2780 beam_runner_api_pb2.CombinePayload) 2781 def from_runner_api_parameter(unused_ptransform, combine_payload, context): 2782 return CombineValues( 2783 CombineFn.from_runner_api(combine_payload.combine_fn, context)) 2784 2785 2786 class CombineValuesDoFn(DoFn): 2787 """DoFn for performing per-key Combine transforms.""" 2788 2789 def __init__( 2790 self, 2791 input_pcoll_type, 2792 combinefn, # type: CombineFn 2793 runtime_type_check, # type: bool 2794 ): 2795 super().__init__() 2796 self.combinefn = combinefn 2797 self.runtime_type_check = runtime_type_check 2798 2799 def setup(self): 2800 self.combinefn.setup() 2801 2802 def process(self, element, *args, **kwargs): 2803 # Expected elements input to this DoFn are 2-tuples of the form 2804 # (key, iter), with iter an iterable of all the values associated with key 2805 # in the input PCollection. 2806 if self.runtime_type_check: 2807 # Apply the combiner in a single operation rather than artificially 2808 # breaking it up so that output type violations manifest as TypeCheck 2809 # errors rather than type errors. 2810 return [(element[0], self.combinefn.apply(element[1], *args, **kwargs))] 2811 2812 # Add the elements into three accumulators (for testing of merge). 2813 elements = list(element[1]) 2814 accumulators = [] 2815 for k in range(3): 2816 if len(elements) <= k: 2817 break 2818 accumulators.append( 2819 self.combinefn.add_inputs( 2820 self.combinefn.create_accumulator(*args, **kwargs), 2821 elements[k::3], 2822 *args, 2823 **kwargs)) 2824 # Merge the accumulators. 2825 accumulator = self.combinefn.merge_accumulators( 2826 accumulators, *args, **kwargs) 2827 # Convert accumulator to the final result. 2828 return [( 2829 element[0], self.combinefn.extract_output(accumulator, *args, 2830 **kwargs))] 2831 2832 def teardown(self): 2833 self.combinefn.teardown() 2834 2835 def default_type_hints(self): 2836 hints = self.combinefn.get_type_hints() 2837 if hints.input_types: 2838 K = typehints.TypeVariable('K') 2839 args, kwargs = hints.input_types 2840 args = (typehints.Tuple[K, typehints.Iterable[args[0]]], ) + args[1:] 2841 hints = hints.with_input_types(*args, **kwargs) 2842 else: 2843 K = typehints.Any 2844 if hints.output_types: 2845 main_output_type = hints.simple_output_type('') 2846 hints = hints.with_output_types(typehints.Tuple[K, main_output_type]) 2847 return hints 2848 2849 2850 class _CombinePerKeyWithHotKeyFanout(PTransform): 2851 2852 def __init__( 2853 self, 2854 combine_fn, # type: CombineFn 2855 fanout, # type: typing.Union[int, typing.Callable[[typing.Any], int]] 2856 ): 2857 # type: (...) -> None 2858 self._combine_fn = combine_fn 2859 self._fanout_fn = ((lambda key: fanout) 2860 if isinstance(fanout, int) else fanout) 2861 2862 def default_label(self): 2863 return '%s(%s, fanout=%s)' % ( 2864 self.__class__.__name__, 2865 ptransform.label_from_callable(self._combine_fn), 2866 ptransform.label_from_callable(self._fanout_fn)) 2867 2868 def expand(self, pcoll): 2869 2870 from apache_beam.transforms.trigger import AccumulationMode 2871 combine_fn = self._combine_fn 2872 fanout_fn = self._fanout_fn 2873 2874 if isinstance(pcoll.windowing.windowfn, SlidingWindows): 2875 raise ValueError( 2876 'CombinePerKey.with_hot_key_fanout does not yet work properly with ' 2877 'SlidingWindows. See: https://github.com/apache/beam/issues/20528') 2878 2879 class SplitHotCold(DoFn): 2880 def start_bundle(self): 2881 # Spreading a hot key across all possible sub-keys for all bundles 2882 # would defeat the goal of not overwhelming downstream reducers 2883 # (as well as making less efficient use of PGBK combining tables). 2884 # Instead, each bundle independently makes a consistent choice about 2885 # which "shard" of a key to send its intermediate results. 2886 self._nonce = int(random.getrandbits(31)) 2887 2888 def process(self, element): 2889 key, value = element 2890 fanout = fanout_fn(key) 2891 if fanout <= 1: 2892 # Boolean indicates this is not an accumulator. 2893 yield (key, (False, value)) # cold 2894 else: 2895 yield pvalue.TaggedOutput('hot', ((self._nonce % fanout, key), value)) 2896 2897 class PreCombineFn(CombineFn): 2898 @staticmethod 2899 def extract_output(accumulator): 2900 # Boolean indicates this is an accumulator. 2901 return (True, accumulator) 2902 2903 setup = combine_fn.setup 2904 create_accumulator = combine_fn.create_accumulator 2905 add_input = combine_fn.add_input 2906 merge_accumulators = combine_fn.merge_accumulators 2907 compact = combine_fn.compact 2908 teardown = combine_fn.teardown 2909 2910 class PostCombineFn(CombineFn): 2911 @staticmethod 2912 def add_input(accumulator, element): 2913 is_accumulator, value = element 2914 if is_accumulator: 2915 return combine_fn.merge_accumulators([accumulator, value]) 2916 else: 2917 return combine_fn.add_input(accumulator, value) 2918 2919 setup = combine_fn.setup 2920 create_accumulator = combine_fn.create_accumulator 2921 merge_accumulators = combine_fn.merge_accumulators 2922 compact = combine_fn.compact 2923 extract_output = combine_fn.extract_output 2924 teardown = combine_fn.teardown 2925 2926 def StripNonce(nonce_key_value): 2927 (_, key), value = nonce_key_value 2928 return key, value 2929 2930 cold, hot = pcoll | ParDo(SplitHotCold()).with_outputs('hot', main='cold') 2931 cold.element_type = typehints.Any # No multi-output type hints. 2932 precombined_hot = ( 2933 hot 2934 # Avoid double counting that may happen with stacked accumulating mode. 2935 | 'WindowIntoDiscarding' >> WindowInto( 2936 pcoll.windowing, accumulation_mode=AccumulationMode.DISCARDING) 2937 | CombinePerKey(PreCombineFn()) 2938 | Map(StripNonce) 2939 | 'WindowIntoOriginal' >> WindowInto(pcoll.windowing)) 2940 return ((cold, precombined_hot) 2941 | Flatten() 2942 | CombinePerKey(PostCombineFn())) 2943 2944 2945 @typehints.with_input_types(typing.Tuple[K, V]) 2946 @typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]]) 2947 class GroupByKey(PTransform): 2948 """A group by key transform. 2949 2950 Processes an input PCollection consisting of key/value pairs represented as a 2951 tuple pair. The result is a PCollection where values having a common key are 2952 grouped together. For example (a, 1), (b, 2), (a, 3) will result into 2953 (a, [1, 3]), (b, [2]). 2954 2955 The implementation here is used only when run on the local direct runner. 2956 """ 2957 class ReifyWindows(DoFn): 2958 def process( 2959 self, element, window=DoFn.WindowParam, timestamp=DoFn.TimestampParam): 2960 try: 2961 k, v = element 2962 except TypeError: 2963 raise TypeCheckError( 2964 'Input to GroupByKey must be a PCollection with ' 2965 'elements compatible with KV[A, B]') 2966 2967 return [(k, WindowedValue(v, timestamp, [window]))] 2968 2969 def infer_output_type(self, input_type): 2970 key_type, value_type = trivial_inference.key_value_types(input_type) 2971 return typehints.KV[ 2972 key_type, typehints.WindowedValue[value_type]] # type: ignore[misc] 2973 2974 def expand(self, pcoll): 2975 from apache_beam.transforms.trigger import DataLossReason 2976 from apache_beam.transforms.trigger import DefaultTrigger 2977 windowing = pcoll.windowing 2978 trigger = windowing.triggerfn 2979 if not pcoll.is_bounded and isinstance( 2980 windowing.windowfn, GlobalWindows) and isinstance(trigger, 2981 DefaultTrigger): 2982 if pcoll.pipeline.allow_unsafe_triggers: 2983 # TODO(BEAM-9487) Change comment for Beam 2.33 2984 _LOGGER.warning( 2985 '%s: PCollection passed to GroupByKey is unbounded, has a global ' 2986 'window, and uses a default trigger. This is being allowed ' 2987 'because --allow_unsafe_triggers is set, but it may prevent ' 2988 'data from making it through the pipeline.', 2989 self.label) 2990 else: 2991 raise ValueError( 2992 'GroupByKey cannot be applied to an unbounded ' + 2993 'PCollection with global windowing and a default trigger') 2994 2995 unsafe_reason = trigger.may_lose_data(windowing) 2996 if unsafe_reason != DataLossReason.NO_POTENTIAL_LOSS: 2997 reason_msg = str(unsafe_reason).replace('DataLossReason.', '') 2998 if pcoll.pipeline.allow_unsafe_triggers: 2999 _LOGGER.warning( 3000 '%s: Unsafe trigger `%s` detected (reason: %s). This is ' 3001 'being allowed because --allow_unsafe_triggers is set. This could ' 3002 'lead to missing or incomplete groups.', 3003 self.label, 3004 trigger, 3005 reason_msg) 3006 else: 3007 msg = '{}: Unsafe trigger: `{}` may lose data. '.format( 3008 self.label, trigger) 3009 msg += 'Reason: {}. '.format(reason_msg) 3010 msg += 'This can be overriden with the --allow_unsafe_triggers flag.' 3011 raise ValueError(msg) 3012 3013 return pvalue.PCollection.from_(pcoll) 3014 3015 def infer_output_type(self, input_type): 3016 key_type, value_type = (typehints.typehints.coerce_to_kv_type( 3017 input_type).tuple_types) 3018 return typehints.KV[key_type, typehints.Iterable[value_type]] 3019 3020 def to_runner_api_parameter(self, unused_context): 3021 # type: (PipelineContext) -> typing.Tuple[str, None] 3022 return common_urns.primitives.GROUP_BY_KEY.urn, None 3023 3024 @staticmethod 3025 @PTransform.register_urn(common_urns.primitives.GROUP_BY_KEY.urn, None) 3026 def from_runner_api_parameter( 3027 unused_ptransform, unused_payload, unused_context): 3028 return GroupByKey() 3029 3030 def runner_api_requires_keyed_input(self): 3031 return True 3032 3033 3034 def _expr_to_callable(expr, pos): 3035 if isinstance(expr, str): 3036 return lambda x: getattr(x, expr) 3037 elif callable(expr): 3038 return expr 3039 else: 3040 raise TypeError( 3041 'Field expression %r at %s must be a callable or a string.' % 3042 (expr, pos)) 3043 3044 3045 class GroupBy(PTransform): 3046 """Groups a PCollection by one or more expressions, used to derive the key. 3047 3048 `GroupBy(expr)` is roughly equivalent to 3049 3050 beam.Map(lambda v: (expr(v), v)) | beam.GroupByKey() 3051 3052 but provides several conveniences, e.g. 3053 3054 * Several arguments may be provided, as positional or keyword arguments, 3055 resulting in a tuple-like key. For example `GroupBy(a=expr1, b=expr2)` 3056 groups by a key with attributes `a` and `b` computed by applying 3057 `expr1` and `expr2` to each element. 3058 3059 * Strings can be used as a shorthand for accessing an attribute, e.g. 3060 `GroupBy('some_field')` is equivalent to 3061 `GroupBy(lambda v: getattr(v, 'some_field'))`. 3062 3063 The GroupBy operation can be made into an aggregating operation by invoking 3064 its `aggregate_field` method. 3065 """ 3066 3067 def __init__( 3068 self, 3069 *fields, # type: typing.Union[str, typing.Callable] 3070 **kwargs # type: typing.Union[str, typing.Callable] 3071 ): 3072 if len(fields) == 1 and not kwargs: 3073 self._force_tuple_keys = False 3074 name = fields[0] if isinstance(fields[0], str) else 'key' 3075 key_fields = [(name, _expr_to_callable(fields[0], 0))] 3076 else: 3077 self._force_tuple_keys = True 3078 key_fields = [] 3079 for ix, field in enumerate(fields): 3080 name = field if isinstance(field, str) else 'key%d' % ix 3081 key_fields.append((name, _expr_to_callable(field, ix))) 3082 for name, expr in kwargs.items(): 3083 key_fields.append((name, _expr_to_callable(expr, name))) 3084 self._key_fields = key_fields 3085 field_names = tuple(name for name, _ in key_fields) 3086 self._key_type = lambda *values: _dynamic_named_tuple('Key', field_names)( 3087 *values) 3088 3089 def aggregate_field( 3090 self, 3091 field, # type: typing.Union[str, typing.Callable] 3092 combine_fn, # type: typing.Union[typing.Callable, CombineFn] 3093 dest, # type: str 3094 ): 3095 """Returns a grouping operation that also aggregates grouped values. 3096 3097 Args: 3098 field: indicates the field to be aggregated 3099 combine_fn: indicates the aggregation function to be used 3100 dest: indicates the name that will be used for the aggregate in the output 3101 3102 May be called repeatedly to aggregate multiple fields, e.g. 3103 3104 GroupBy('key') 3105 .aggregate_field('some_attr', sum, 'sum_attr') 3106 .aggregate_field(lambda v: ..., MeanCombineFn, 'mean') 3107 """ 3108 return _GroupAndAggregate(self, ()).aggregate_field(field, combine_fn, dest) 3109 3110 def force_tuple_keys(self, value=True): 3111 """Forces the keys to always be tuple-like, even if there is only a single 3112 expression. 3113 """ 3114 res = copy.copy(self) 3115 res._force_tuple_keys = value 3116 return res 3117 3118 def _key_func(self): 3119 if not self._force_tuple_keys and len(self._key_fields) == 1: 3120 return self._key_fields[0][1] 3121 else: 3122 key_type = self._key_type 3123 key_exprs = [expr for _, expr in self._key_fields] 3124 return lambda element: key_type(*(expr(element) for expr in key_exprs)) 3125 3126 def _key_type_hint(self, input_type): 3127 if not self._force_tuple_keys and len(self._key_fields) == 1: 3128 expr = self._key_fields[0][1] 3129 return trivial_inference.infer_return_type(expr, [input_type]) 3130 else: 3131 return row_type.RowTypeConstraint.from_fields([ 3132 (name, trivial_inference.infer_return_type(expr, [input_type])) 3133 for (name, expr) in self._key_fields 3134 ]) 3135 3136 def default_label(self): 3137 return 'GroupBy(%s)' % ', '.join(name for name, _ in self._key_fields) 3138 3139 def expand(self, pcoll): 3140 input_type = pcoll.element_type or typing.Any 3141 return ( 3142 pcoll 3143 | Map(lambda x: (self._key_func()(x), x)).with_output_types( 3144 typehints.Tuple[self._key_type_hint(input_type), input_type]) 3145 | GroupByKey()) 3146 3147 3148 _dynamic_named_tuple_cache = { 3149 } # type: typing.Dict[typing.Tuple[str, typing.Tuple[str, ...]], typing.Type[tuple]] 3150 3151 3152 def _dynamic_named_tuple(type_name, field_names): 3153 # type: (str, typing.Tuple[str, ...]) -> typing.Type[tuple] 3154 cache_key = (type_name, field_names) 3155 result = _dynamic_named_tuple_cache.get(cache_key) 3156 if result is None: 3157 import collections 3158 result = _dynamic_named_tuple_cache[cache_key] = collections.namedtuple( 3159 type_name, field_names) 3160 # typing: can't override a method. also, self type is unknown and can't 3161 # be cast to tuple 3162 result.__reduce__ = lambda self: ( # type: ignore[assignment] 3163 _unpickle_dynamic_named_tuple, (type_name, field_names, tuple(self))) # type: ignore[arg-type] 3164 return result 3165 3166 3167 def _unpickle_dynamic_named_tuple(type_name, field_names, values): 3168 # type: (str, typing.Tuple[str, ...], typing.Iterable[typing.Any]) -> tuple 3169 return _dynamic_named_tuple(type_name, field_names)(*values) 3170 3171 3172 class _GroupAndAggregate(PTransform): 3173 def __init__(self, grouping, aggregations): 3174 self._grouping = grouping 3175 self._aggregations = aggregations 3176 3177 def aggregate_field( 3178 self, 3179 field, # type: typing.Union[str, typing.Callable] 3180 combine_fn, # type: typing.Union[typing.Callable, CombineFn] 3181 dest, # type: str 3182 ): 3183 field = _expr_to_callable(field, 0) 3184 return _GroupAndAggregate( 3185 self._grouping, list(self._aggregations) + [(field, combine_fn, dest)]) 3186 3187 def expand(self, pcoll): 3188 from apache_beam.transforms.combiners import TupleCombineFn 3189 key_func = self._grouping.force_tuple_keys(True)._key_func() 3190 value_exprs = [expr for expr, _, __ in self._aggregations] 3191 value_func = lambda element: [expr(element) for expr in value_exprs] 3192 result_fields = tuple(name 3193 for name, _ in self._grouping._key_fields) + tuple( 3194 dest for _, __, dest in self._aggregations) 3195 key_type_hint = self._grouping.force_tuple_keys(True)._key_type_hint( 3196 pcoll.element_type) 3197 3198 return ( 3199 pcoll 3200 | Map(lambda x: (key_func(x), value_func(x))).with_output_types( 3201 typehints.Tuple[key_type_hint, typing.Any]) 3202 | CombinePerKey( 3203 TupleCombineFn( 3204 *[combine_fn for _, combine_fn, __ in self._aggregations])) 3205 | MapTuple( 3206 lambda key, 3207 value: _dynamic_named_tuple('Result', result_fields) 3208 (*(key + value)))) 3209 3210 3211 class Select(PTransform): 3212 """Converts the elements of a PCollection into a schema'd PCollection of Rows. 3213 3214 `Select(...)` is roughly equivalent to `Map(lambda x: Row(...))` where each 3215 argument (which may be a string or callable) of `ToRow` is applied to `x`. 3216 For example, 3217 3218 pcoll | beam.Select('a', b=lambda x: foo(x)) 3219 3220 is the same as 3221 3222 pcoll | beam.Map(lambda x: beam.Row(a=x.a, b=foo(x))) 3223 """ 3224 3225 def __init__( 3226 self, 3227 *args, # type: typing.Union[str, typing.Callable] 3228 **kwargs # type: typing.Union[str, typing.Callable] 3229 ): 3230 self._fields = [( 3231 expr if isinstance(expr, str) else 'arg%02d' % ix, 3232 _expr_to_callable(expr, ix)) for (ix, expr) in enumerate(args) 3233 ] + [(name, _expr_to_callable(expr, name)) 3234 for (name, expr) in kwargs.items()] 3235 3236 def default_label(self): 3237 return 'ToRows(%s)' % ', '.join(name for name, _ in self._fields) 3238 3239 def expand(self, pcoll): 3240 return pcoll | Map( 3241 lambda x: pvalue.Row(**{name: expr(x) 3242 for name, expr in self._fields})) 3243 3244 def infer_output_type(self, input_type): 3245 return row_type.RowTypeConstraint.from_fields([ 3246 (name, trivial_inference.infer_return_type(expr, [input_type])) 3247 for (name, expr) in self._fields 3248 ]) 3249 3250 3251 class Partition(PTransformWithSideInputs): 3252 """Split a PCollection into several partitions. 3253 3254 Uses the specified PartitionFn to separate an input PCollection into the 3255 specified number of sub-PCollections. 3256 3257 When apply()d, a Partition() PTransform requires the following: 3258 3259 Args: 3260 partitionfn: a PartitionFn, or a callable with the signature described in 3261 CallableWrapperPartitionFn. 3262 n: number of output partitions. 3263 3264 The result of this PTransform is a simple list of the output PCollections 3265 representing each of n partitions, in order. 3266 """ 3267 class ApplyPartitionFnFn(DoFn): 3268 """A DoFn that applies a PartitionFn.""" 3269 def process(self, element, partitionfn, n, *args, **kwargs): 3270 partition = partitionfn.partition_for(element, n, *args, **kwargs) 3271 if not 0 <= partition < n: 3272 raise ValueError( 3273 'PartitionFn specified out-of-bounds partition index: ' 3274 '%d not in [0, %d)' % (partition, n)) 3275 # Each input is directed into the output that corresponds to the 3276 # selected partition. 3277 yield pvalue.TaggedOutput(str(partition), element) 3278 3279 def make_fn(self, fn, has_side_inputs): 3280 return fn if isinstance(fn, PartitionFn) else CallableWrapperPartitionFn(fn) 3281 3282 def expand(self, pcoll): 3283 n = int(self.args[0]) 3284 args, kwargs = util.insert_values_in_args( 3285 self.args, self.kwargs, self.side_inputs) 3286 return pcoll | ParDo(self.ApplyPartitionFnFn(), self.fn, *args, ** 3287 kwargs).with_outputs(*[str(t) for t in range(n)]) 3288 3289 3290 class Windowing(object): 3291 def __init__(self, 3292 windowfn, # type: WindowFn 3293 triggerfn=None, # type: typing.Optional[TriggerFn] 3294 accumulation_mode=None, # type: typing.Optional[beam_runner_api_pb2.AccumulationMode.Enum.ValueType] 3295 timestamp_combiner=None, # type: typing.Optional[beam_runner_api_pb2.OutputTime.Enum.ValueType] 3296 allowed_lateness=0, # type: typing.Union[int, float] 3297 environment_id=None, # type: typing.Optional[str] 3298 ): 3299 """Class representing the window strategy. 3300 3301 Args: 3302 windowfn: Window assign function. 3303 triggerfn: Trigger function. 3304 accumulation_mode: a AccumulationMode, controls what to do with data 3305 when a trigger fires multiple times. 3306 timestamp_combiner: a TimestampCombiner, determines how output 3307 timestamps of grouping operations are assigned. 3308 allowed_lateness: Maximum delay in seconds after end of window 3309 allowed for any late data to be processed without being discarded 3310 directly. 3311 environment_id: Environment where the current window_fn should be 3312 applied in. 3313 """ 3314 global AccumulationMode, DefaultTrigger # pylint: disable=global-variable-not-assigned 3315 # pylint: disable=wrong-import-order, wrong-import-position 3316 from apache_beam.transforms.trigger import AccumulationMode, DefaultTrigger 3317 # pylint: enable=wrong-import-order, wrong-import-position 3318 if triggerfn is None: 3319 triggerfn = DefaultTrigger() 3320 if accumulation_mode is None: 3321 if triggerfn == DefaultTrigger(): 3322 accumulation_mode = AccumulationMode.DISCARDING 3323 else: 3324 raise ValueError( 3325 'accumulation_mode must be provided for non-trivial triggers') 3326 if not windowfn.get_window_coder().is_deterministic(): 3327 raise ValueError( 3328 'window fn (%s) does not have a determanistic coder (%s)' % 3329 (windowfn, windowfn.get_window_coder())) 3330 self.windowfn = windowfn 3331 self.triggerfn = triggerfn 3332 self.accumulation_mode = accumulation_mode 3333 self.allowed_lateness = Duration.of(allowed_lateness) 3334 self.environment_id = environment_id 3335 self.timestamp_combiner = ( 3336 timestamp_combiner or TimestampCombiner.OUTPUT_AT_EOW) 3337 self._is_default = ( 3338 self.windowfn == GlobalWindows() and 3339 self.triggerfn == DefaultTrigger() and 3340 self.accumulation_mode == AccumulationMode.DISCARDING and 3341 self.timestamp_combiner == TimestampCombiner.OUTPUT_AT_EOW and 3342 self.allowed_lateness == 0) 3343 3344 def __repr__(self): 3345 return "Windowing(%s, %s, %s, %s, %s)" % ( 3346 self.windowfn, 3347 self.triggerfn, 3348 self.accumulation_mode, 3349 self.timestamp_combiner, 3350 self.environment_id) 3351 3352 def __eq__(self, other): 3353 if type(self) == type(other): 3354 if self._is_default and other._is_default: 3355 return True 3356 return ( 3357 self.windowfn == other.windowfn and 3358 self.triggerfn == other.triggerfn and 3359 self.accumulation_mode == other.accumulation_mode and 3360 self.timestamp_combiner == other.timestamp_combiner and 3361 self.allowed_lateness == other.allowed_lateness and 3362 self.environment_id == self.environment_id) 3363 return False 3364 3365 def __hash__(self): 3366 return hash(( 3367 self.windowfn, 3368 self.triggerfn, 3369 self.accumulation_mode, 3370 self.allowed_lateness, 3371 self.timestamp_combiner, 3372 self.environment_id)) 3373 3374 def is_default(self): 3375 return self._is_default 3376 3377 def to_runner_api(self, context): 3378 # type: (PipelineContext) -> beam_runner_api_pb2.WindowingStrategy 3379 environment_id = self.environment_id or context.default_environment_id() 3380 return beam_runner_api_pb2.WindowingStrategy( 3381 window_fn=self.windowfn.to_runner_api(context), 3382 # TODO(robertwb): Prohibit implicit multi-level merging. 3383 merge_status=( 3384 beam_runner_api_pb2.MergeStatus.NEEDS_MERGE 3385 if self.windowfn.is_merging() else 3386 beam_runner_api_pb2.MergeStatus.NON_MERGING), 3387 window_coder_id=context.coders.get_id(self.windowfn.get_window_coder()), 3388 trigger=self.triggerfn.to_runner_api(context), 3389 accumulation_mode=self.accumulation_mode, 3390 output_time=self.timestamp_combiner, 3391 # TODO(robertwb): Support EMIT_IF_NONEMPTY 3392 closing_behavior=beam_runner_api_pb2.ClosingBehavior.EMIT_ALWAYS, 3393 on_time_behavior=beam_runner_api_pb2.OnTimeBehavior.FIRE_ALWAYS, 3394 allowed_lateness=self.allowed_lateness.micros // 1000, 3395 environment_id=environment_id) 3396 3397 @staticmethod 3398 def from_runner_api(proto, context): 3399 # pylint: disable=wrong-import-order, wrong-import-position 3400 from apache_beam.transforms.trigger import TriggerFn 3401 return Windowing( 3402 windowfn=WindowFn.from_runner_api(proto.window_fn, context), 3403 triggerfn=TriggerFn.from_runner_api(proto.trigger, context), 3404 accumulation_mode=proto.accumulation_mode, 3405 timestamp_combiner=proto.output_time, 3406 allowed_lateness=Duration(micros=proto.allowed_lateness * 1000), 3407 environment_id=None) 3408 3409 3410 @typehints.with_input_types(T) 3411 @typehints.with_output_types(T) 3412 class WindowInto(ParDo): 3413 """A window transform assigning windows to each element of a PCollection. 3414 3415 Transforms an input PCollection by applying a windowing function to each 3416 element. Each transformed element in the result will be a WindowedValue 3417 element with the same input value and timestamp, with its new set of windows 3418 determined by the windowing function. 3419 """ 3420 class WindowIntoFn(DoFn): 3421 """A DoFn that applies a WindowInto operation.""" 3422 def __init__(self, windowing): 3423 # type: (Windowing) -> None 3424 self.windowing = windowing 3425 3426 def process( 3427 self, element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam): 3428 context = WindowFn.AssignContext( 3429 timestamp, element=element, window=window) 3430 new_windows = self.windowing.windowfn.assign(context) 3431 yield WindowedValue(element, context.timestamp, new_windows) 3432 3433 def __init__( 3434 self, 3435 windowfn, # type: typing.Union[Windowing, WindowFn] 3436 trigger=None, # type: typing.Optional[TriggerFn] 3437 accumulation_mode=None, 3438 timestamp_combiner=None, 3439 allowed_lateness=0): 3440 """Initializes a WindowInto transform. 3441 3442 Args: 3443 windowfn (Windowing, WindowFn): Function to be used for windowing. 3444 trigger: (optional) Trigger used for windowing, or None for default. 3445 accumulation_mode: (optional) Accumulation mode used for windowing, 3446 required for non-trivial triggers. 3447 timestamp_combiner: (optional) Timestamp combniner used for windowing, 3448 or None for default. 3449 """ 3450 if isinstance(windowfn, Windowing): 3451 # Overlay windowing with kwargs. 3452 windowing = windowfn 3453 windowfn = windowing.windowfn 3454 3455 # Use windowing to fill in defaults for the extra arguments. 3456 trigger = trigger or windowing.triggerfn 3457 accumulation_mode = accumulation_mode or windowing.accumulation_mode 3458 timestamp_combiner = timestamp_combiner or windowing.timestamp_combiner 3459 3460 self.windowing = Windowing( 3461 windowfn, 3462 trigger, 3463 accumulation_mode, 3464 timestamp_combiner, 3465 allowed_lateness) 3466 super().__init__(self.WindowIntoFn(self.windowing)) 3467 3468 def get_windowing(self, unused_inputs): 3469 # type: (typing.Any) -> Windowing 3470 return self.windowing 3471 3472 def infer_output_type(self, input_type): 3473 return input_type 3474 3475 def expand(self, pcoll): 3476 input_type = pcoll.element_type 3477 3478 if input_type is not None: 3479 output_type = input_type 3480 self.with_input_types(input_type) 3481 self.with_output_types(output_type) 3482 return super().expand(pcoll) 3483 3484 # typing: PTransform base class does not accept extra_kwargs 3485 def to_runner_api_parameter(self, context, **extra_kwargs): # type: ignore[override] 3486 # type: (PipelineContext, **typing.Any) -> typing.Tuple[str, message.Message] 3487 return ( 3488 common_urns.primitives.ASSIGN_WINDOWS.urn, 3489 self.windowing.to_runner_api(context)) 3490 3491 @staticmethod 3492 def from_runner_api_parameter(unused_ptransform, proto, context): 3493 windowing = Windowing.from_runner_api(proto, context) 3494 return WindowInto( 3495 windowing.windowfn, 3496 trigger=windowing.triggerfn, 3497 accumulation_mode=windowing.accumulation_mode, 3498 timestamp_combiner=windowing.timestamp_combiner) 3499 3500 3501 PTransform.register_urn( 3502 common_urns.primitives.ASSIGN_WINDOWS.urn, 3503 # TODO(robertwb): Update WindowIntoPayload to include the full strategy. 3504 # (Right now only WindowFn is used, but we need this to reconstitute the 3505 # WindowInto transform, and in the future will need it at runtime to 3506 # support meta-data driven triggers.) 3507 # TODO(robertwb): Use a reference rather than embedding? 3508 beam_runner_api_pb2.WindowingStrategy, 3509 WindowInto.from_runner_api_parameter) 3510 3511 # Python's pickling is broken for nested classes. 3512 WindowIntoFn = WindowInto.WindowIntoFn 3513 3514 3515 class Flatten(PTransform): 3516 """Merges several PCollections into a single PCollection. 3517 3518 Copies all elements in 0 or more PCollections into a single output 3519 PCollection. If there are no input PCollections, the resulting PCollection 3520 will be empty (but see also kwargs below). 3521 3522 Args: 3523 **kwargs: Accepts a single named argument "pipeline", which specifies the 3524 pipeline that "owns" this PTransform. Ordinarily Flatten can obtain this 3525 information from one of the input PCollections, but if there are none (or 3526 if there's a chance there may be none), this argument is the only way to 3527 provide pipeline information and should be considered mandatory. 3528 """ 3529 def __init__(self, **kwargs): 3530 super().__init__() 3531 self.pipeline = kwargs.pop( 3532 'pipeline', None) # type: typing.Optional[Pipeline] 3533 if kwargs: 3534 raise ValueError('Unexpected keyword arguments: %s' % list(kwargs)) 3535 3536 def _extract_input_pvalues(self, pvalueish): 3537 try: 3538 pvalueish = tuple(pvalueish) 3539 except TypeError: 3540 raise ValueError( 3541 'Input to Flatten must be an iterable. ' 3542 'Got a value of type %s instead.' % type(pvalueish)) 3543 return pvalueish, pvalueish 3544 3545 def expand(self, pcolls): 3546 for pcoll in pcolls: 3547 self._check_pcollection(pcoll) 3548 is_bounded = all(pcoll.is_bounded for pcoll in pcolls) 3549 return pvalue.PCollection(self.pipeline, is_bounded=is_bounded) 3550 3551 def infer_output_type(self, input_type): 3552 return input_type 3553 3554 def to_runner_api_parameter(self, context): 3555 # type: (PipelineContext) -> typing.Tuple[str, None] 3556 return common_urns.primitives.FLATTEN.urn, None 3557 3558 @staticmethod 3559 def from_runner_api_parameter( 3560 unused_ptransform, unused_parameter, unused_context): 3561 return Flatten() 3562 3563 3564 PTransform.register_urn( 3565 common_urns.primitives.FLATTEN.urn, None, Flatten.from_runner_api_parameter) 3566 3567 3568 class Create(PTransform): 3569 """A transform that creates a PCollection from an iterable.""" 3570 def __init__(self, values, reshuffle=True): 3571 """Initializes a Create transform. 3572 3573 Args: 3574 values: An object of values for the PCollection 3575 """ 3576 super().__init__() 3577 if isinstance(values, (str, bytes)): 3578 raise TypeError( 3579 'PTransform Create: Refusing to treat string as ' 3580 'an iterable. (string=%r)' % values) 3581 elif isinstance(values, dict): 3582 values = values.items() 3583 self.values = tuple(values) 3584 self.reshuffle = reshuffle 3585 self._coder = typecoders.registry.get_coder(self.get_output_type()) 3586 3587 def __getstate__(self): 3588 serialized_values = [self._coder.encode(v) for v in self.values] 3589 return serialized_values, self.reshuffle, self._coder 3590 3591 def __setstate__(self, state): 3592 serialized_values, self.reshuffle, self._coder = state 3593 self.values = [self._coder.decode(v) for v in serialized_values] 3594 3595 def to_runner_api_parameter(self, context): 3596 # type: (PipelineContext) -> typing.Tuple[str, bytes] 3597 # Required as this is identified by type in PTransformOverrides. 3598 # TODO(https://github.com/apache/beam/issues/18713): Use an actual URN 3599 # here. 3600 return self.to_runner_api_pickled(context) 3601 3602 def infer_output_type(self, unused_input_type): 3603 if not self.values: 3604 return typehints.Any 3605 return typehints.Union[[ 3606 trivial_inference.instance_to_type(v) for v in self.values 3607 ]] 3608 3609 def get_output_type(self): 3610 return ( 3611 self.get_type_hints().simple_output_type(self.label) or 3612 self.infer_output_type(None)) 3613 3614 def expand(self, pbegin): 3615 assert isinstance(pbegin, pvalue.PBegin) 3616 serialized_values = [self._coder.encode(v) for v in self.values] 3617 reshuffle = self.reshuffle 3618 3619 # Avoid the "redistributing" reshuffle for 0 and 1 element Creates. 3620 # These special cases are often used in building up more complex 3621 # transforms (e.g. Write). 3622 3623 class MaybeReshuffle(PTransform): 3624 def expand(self, pcoll): 3625 if len(serialized_values) > 1 and reshuffle: 3626 from apache_beam.transforms.util import Reshuffle 3627 return pcoll | Reshuffle() 3628 else: 3629 return pcoll 3630 3631 return ( 3632 pbegin 3633 | Impulse() 3634 | FlatMap(lambda _: serialized_values).with_output_types(bytes) 3635 | MaybeReshuffle().with_output_types(bytes) 3636 | Map(self._coder.decode).with_output_types(self.get_output_type())) 3637 3638 def as_read(self): 3639 from apache_beam.io import iobase 3640 source = self._create_source_from_iterable(self.values, self._coder) 3641 return iobase.Read(source).with_output_types(self.get_output_type()) 3642 3643 def get_windowing(self, unused_inputs): 3644 # type: (typing.Any) -> Windowing 3645 return Windowing(GlobalWindows()) 3646 3647 @staticmethod 3648 def _create_source_from_iterable(values, coder): 3649 return Create._create_source(list(map(coder.encode, values)), coder) 3650 3651 @staticmethod 3652 def _create_source(serialized_values, coder): 3653 # type: (typing.Any, typing.Any) -> create_source._CreateSource 3654 from apache_beam.transforms.create_source import _CreateSource 3655 3656 return _CreateSource(serialized_values, coder) 3657 3658 3659 @typehints.with_output_types(bytes) 3660 class Impulse(PTransform): 3661 """Impulse primitive.""" 3662 def expand(self, pbegin): 3663 if not isinstance(pbegin, pvalue.PBegin): 3664 raise TypeError( 3665 'Input to Impulse transform must be a PBegin but found %s' % pbegin) 3666 return pvalue.PCollection(pbegin.pipeline, element_type=bytes) 3667 3668 def get_windowing(self, inputs): 3669 # type: (typing.Any) -> Windowing 3670 return Windowing(GlobalWindows()) 3671 3672 def infer_output_type(self, unused_input_type): 3673 return bytes 3674 3675 def to_runner_api_parameter(self, unused_context): 3676 # type: (PipelineContext) -> typing.Tuple[str, None] 3677 return common_urns.primitives.IMPULSE.urn, None 3678 3679 @staticmethod 3680 @PTransform.register_urn(common_urns.primitives.IMPULSE.urn, None) 3681 def from_runner_api_parameter( 3682 unused_ptransform, unused_parameter, unused_context): 3683 return Impulse() 3684 3685 3686 def _strip_output_annotations(type_hint): 3687 # TODO(robertwb): These should be parameterized types that the 3688 # type inferencer understands. 3689 # Then we can replace them with the correct element types instead of 3690 # using Any. Refer to typehints.WindowedValue when doing this. 3691 annotations = (TimestampedValue, WindowedValue, pvalue.TaggedOutput) 3692 3693 contains_annotation = False 3694 3695 def visitor(t, unused_args): 3696 if t in annotations: 3697 raise StopIteration 3698 3699 try: 3700 visit_inner_types(type_hint, visitor, []) 3701 except StopIteration: 3702 contains_annotation = True 3703 3704 return typehints.Any if contains_annotation else type_hint