github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/ptransform.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """PTransform and descendants. 19 20 A PTransform is an object describing (not executing) a computation. The actual 21 execution semantics for a transform is captured by a runner object. A transform 22 object always belongs to a pipeline object. 23 24 A PTransform derived class needs to define the expand() method that describes 25 how one or more PValues are created by the transform. 26 27 The module defines a few standard transforms: FlatMap (parallel do), 28 GroupByKey (group by key), etc. Note that the expand() methods for these 29 classes contain code that will add nodes to the processing graph associated 30 with a pipeline. 31 32 As support for the FlatMap transform, the module also defines a DoFn 33 class and wrapper class that allows lambda functions to be used as 34 FlatMap processing functions. 35 """ 36 37 # pytype: skip-file 38 39 import copy 40 import itertools 41 import logging 42 import operator 43 import os 44 import sys 45 import threading 46 from functools import reduce 47 from functools import wraps 48 from typing import TYPE_CHECKING 49 from typing import Any 50 from typing import Callable 51 from typing import Dict 52 from typing import Generic 53 from typing import List 54 from typing import Mapping 55 from typing import Optional 56 from typing import Sequence 57 from typing import Tuple 58 from typing import Type 59 from typing import TypeVar 60 from typing import Union 61 from typing import overload 62 63 from google.protobuf import message 64 65 from apache_beam import error 66 from apache_beam import pvalue 67 from apache_beam.internal import pickler 68 from apache_beam.internal import util 69 from apache_beam.portability import python_urns 70 from apache_beam.pvalue import DoOutputsTuple 71 from apache_beam.transforms import resources 72 from apache_beam.transforms.display import DisplayDataItem 73 from apache_beam.transforms.display import HasDisplayData 74 from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX 75 from apache_beam.typehints import native_type_compatibility 76 from apache_beam.typehints import typehints 77 from apache_beam.typehints.decorators import IOTypeHints 78 from apache_beam.typehints.decorators import TypeCheckError 79 from apache_beam.typehints.decorators import WithTypeHints 80 from apache_beam.typehints.decorators import get_signature 81 from apache_beam.typehints.decorators import get_type_hints 82 from apache_beam.typehints.decorators import getcallargs_forhints 83 from apache_beam.typehints.trivial_inference import instance_to_type 84 from apache_beam.typehints.typehints import validate_composite_type_param 85 from apache_beam.utils import proto_utils 86 87 if TYPE_CHECKING: 88 from apache_beam import coders 89 from apache_beam.pipeline import Pipeline 90 from apache_beam.runners.pipeline_context import PipelineContext 91 from apache_beam.transforms.core import Windowing 92 from apache_beam.portability.api import beam_runner_api_pb2 93 94 __all__ = [ 95 'PTransform', 96 'ptransform_fn', 97 'label_from_callable', 98 ] 99 100 _LOGGER = logging.getLogger(__name__) 101 102 T = TypeVar('T') 103 InputT = TypeVar('InputT') 104 OutputT = TypeVar('OutputT') 105 PTransformT = TypeVar('PTransformT', bound='PTransform') 106 ConstructorFn = Callable[ 107 ['beam_runner_api_pb2.PTransform', Optional[Any], 'PipelineContext'], Any] 108 ptransform_fn_typehints_enabled = False 109 110 111 class _PValueishTransform(object): 112 """Visitor for PValueish objects. 113 114 A PValueish is a PValue, or list, tuple, dict of PValuesish objects. 115 116 This visits a PValueish, contstructing a (possibly mutated) copy. 117 """ 118 def visit_nested(self, node, *args): 119 if isinstance(node, (tuple, list)): 120 args = [self.visit(x, *args) for x in node] 121 if isinstance(node, tuple) and hasattr(node.__class__, '_make'): 122 # namedtuples require unpacked arguments in their constructor 123 return node.__class__(*args) 124 else: 125 return node.__class__(args) 126 elif isinstance(node, dict): 127 return node.__class__( 128 {key: self.visit(value, *args) 129 for (key, value) in node.items()}) 130 else: 131 return node 132 133 134 class _SetInputPValues(_PValueishTransform): 135 def visit(self, node, replacements): 136 if id(node) in replacements: 137 return replacements[id(node)] 138 else: 139 return self.visit_nested(node, replacements) 140 141 142 # Caches to allow for materialization of values when executing a pipeline 143 # in-process, in eager mode. This cache allows the same _MaterializedResult 144 # object to be accessed and used despite Runner API round-trip serialization. 145 _pipeline_materialization_cache = { 146 } # type: Dict[Tuple[int, int], Dict[int, _MaterializedResult]] 147 _pipeline_materialization_lock = threading.Lock() 148 149 150 def _allocate_materialized_pipeline(pipeline): 151 # type: (Pipeline) -> None 152 pid = os.getpid() 153 with _pipeline_materialization_lock: 154 pipeline_id = id(pipeline) 155 _pipeline_materialization_cache[(pid, pipeline_id)] = {} 156 157 158 def _allocate_materialized_result(pipeline): 159 # type: (Pipeline) -> _MaterializedResult 160 pid = os.getpid() 161 with _pipeline_materialization_lock: 162 pipeline_id = id(pipeline) 163 if (pid, pipeline_id) not in _pipeline_materialization_cache: 164 raise ValueError( 165 'Materialized pipeline is not allocated for result ' 166 'cache.') 167 result_id = len(_pipeline_materialization_cache[(pid, pipeline_id)]) 168 result = _MaterializedResult(pipeline_id, result_id) 169 _pipeline_materialization_cache[(pid, pipeline_id)][result_id] = result 170 return result 171 172 173 def _get_materialized_result(pipeline_id, result_id): 174 # type: (int, int) -> _MaterializedResult 175 pid = os.getpid() 176 with _pipeline_materialization_lock: 177 if (pid, pipeline_id) not in _pipeline_materialization_cache: 178 raise Exception( 179 'Materialization in out-of-process and remote runners is not yet ' 180 'supported.') 181 return _pipeline_materialization_cache[(pid, pipeline_id)][result_id] 182 183 184 def _release_materialized_pipeline(pipeline): 185 # type: (Pipeline) -> None 186 pid = os.getpid() 187 with _pipeline_materialization_lock: 188 pipeline_id = id(pipeline) 189 del _pipeline_materialization_cache[(pid, pipeline_id)] 190 191 192 class _MaterializedResult(object): 193 def __init__(self, pipeline_id, result_id): 194 # type: (int, int) -> None 195 self._pipeline_id = pipeline_id 196 self._result_id = result_id 197 self.elements = [] # type: List[Any] 198 199 def __reduce__(self): 200 # When unpickled (during Runner API roundtrip serailization), get the 201 # _MaterializedResult object from the cache so that values are written 202 # to the original _MaterializedResult when run in eager mode. 203 return (_get_materialized_result, (self._pipeline_id, self._result_id)) 204 205 206 class _MaterializedDoOutputsTuple(pvalue.DoOutputsTuple): 207 def __init__(self, deferred, results_by_tag): 208 super().__init__(None, None, deferred._tags, deferred._main_tag) 209 self._deferred = deferred 210 self._results_by_tag = results_by_tag 211 212 def __getitem__(self, tag): 213 if tag not in self._results_by_tag: 214 raise KeyError( 215 'Tag %r is not a defined output tag of %s.' % (tag, self._deferred)) 216 return self._results_by_tag[tag].elements 217 218 219 class _AddMaterializationTransforms(_PValueishTransform): 220 def _materialize_transform(self, pipeline): 221 result = _allocate_materialized_result(pipeline) 222 223 # Need to define _MaterializeValuesDoFn here to avoid circular 224 # dependencies. 225 from apache_beam import DoFn 226 from apache_beam import ParDo 227 228 class _MaterializeValuesDoFn(DoFn): 229 def process(self, element): 230 result.elements.append(element) 231 232 materialization_label = '_MaterializeValues%d' % result._result_id 233 return (materialization_label >> ParDo(_MaterializeValuesDoFn()), result) 234 235 def visit(self, node): 236 if isinstance(node, pvalue.PValue): 237 transform, result = self._materialize_transform(node.pipeline) 238 node | transform 239 return result 240 elif isinstance(node, pvalue.DoOutputsTuple): 241 results_by_tag = {} 242 for tag in itertools.chain([node._main_tag], node._tags): 243 results_by_tag[tag] = self.visit(node[tag]) 244 return _MaterializedDoOutputsTuple(node, results_by_tag) 245 else: 246 return self.visit_nested(node) 247 248 249 class _FinalizeMaterialization(_PValueishTransform): 250 def visit(self, node): 251 if isinstance(node, _MaterializedResult): 252 return node.elements 253 elif isinstance(node, _MaterializedDoOutputsTuple): 254 return node 255 else: 256 return self.visit_nested(node) 257 258 259 def get_named_nested_pvalues(pvalueish, as_inputs=False): 260 if isinstance(pvalueish, tuple): 261 # Check to see if it's a named tuple. 262 fields = getattr(pvalueish, '_fields', None) 263 if fields and len(fields) == len(pvalueish): 264 tagged_values = zip(fields, pvalueish) 265 else: 266 tagged_values = enumerate(pvalueish) 267 elif isinstance(pvalueish, list): 268 if as_inputs: 269 # Full list treated as a list of value for eager evaluation. 270 yield None, pvalueish 271 return 272 tagged_values = enumerate(pvalueish) 273 elif isinstance(pvalueish, dict): 274 tagged_values = pvalueish.items() 275 else: 276 if as_inputs or isinstance(pvalueish, 277 (pvalue.PValue, pvalue.DoOutputsTuple)): 278 yield None, pvalueish 279 return 280 281 for tag, subvalue in tagged_values: 282 for subtag, subsubvalue in get_named_nested_pvalues( 283 subvalue, as_inputs=as_inputs): 284 if subtag is None: 285 yield tag, subsubvalue 286 else: 287 yield '%s.%s' % (tag, subtag), subsubvalue 288 289 290 class _ZipPValues(object): 291 """Pairs each PValue in a pvalueish with a value in a parallel out sibling. 292 293 Sibling should have the same nested structure as pvalueish. Leaves in 294 sibling are expanded across nested pvalueish lists, tuples, and dicts. 295 For example 296 297 ZipPValues().visit({'a': pc1, 'b': (pc2, pc3)}, 298 {'a': 'A', 'b', 'B'}) 299 300 will return 301 302 [('a', pc1, 'A'), ('b', pc2, 'B'), ('b', pc3, 'B')] 303 """ 304 def visit(self, pvalueish, sibling, pairs=None, context=None): 305 if pairs is None: 306 pairs = [] 307 self.visit(pvalueish, sibling, pairs, context) 308 return pairs 309 elif isinstance(pvalueish, (pvalue.PValue, pvalue.DoOutputsTuple)): 310 pairs.append((context, pvalueish, sibling)) 311 elif isinstance(pvalueish, (list, tuple)): 312 self.visit_sequence(pvalueish, sibling, pairs, context) 313 elif isinstance(pvalueish, dict): 314 self.visit_dict(pvalueish, sibling, pairs, context) 315 316 def visit_sequence(self, pvalueish, sibling, pairs, context): 317 if isinstance(sibling, (list, tuple)): 318 for ix, (p, s) in enumerate(zip(pvalueish, 319 list(sibling) + [None] * len(pvalueish))): 320 self.visit(p, s, pairs, 'position %s' % ix) 321 else: 322 for p in pvalueish: 323 self.visit(p, sibling, pairs, context) 324 325 def visit_dict(self, pvalueish, sibling, pairs, context): 326 if isinstance(sibling, dict): 327 for key, p in pvalueish.items(): 328 self.visit(p, sibling.get(key), pairs, key) 329 else: 330 for p in pvalueish.values(): 331 self.visit(p, sibling, pairs, context) 332 333 334 class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]): 335 """A transform object used to modify one or more PCollections. 336 337 Subclasses must define an expand() method that will be used when the transform 338 is applied to some arguments. Typical usage pattern will be: 339 340 input | CustomTransform(...) 341 342 The expand() method of the CustomTransform object passed in will be called 343 with input as an argument. 344 """ 345 # By default, transforms don't have any side inputs. 346 side_inputs = () # type: Sequence[pvalue.AsSideInput] 347 348 # Used for nullary transforms. 349 pipeline = None # type: Optional[Pipeline] 350 351 # Default is unset. 352 _user_label = None # type: Optional[str] 353 354 def __init__(self, label=None): 355 # type: (Optional[str]) -> None 356 super().__init__() 357 self.label = label # type: ignore # https://github.com/python/mypy/issues/3004 358 359 @property 360 def label(self): 361 # type: () -> str 362 return self._user_label or self.default_label() 363 364 @label.setter 365 def label(self, value): 366 # type: (Optional[str]) -> None 367 self._user_label = value 368 369 def default_label(self): 370 # type: () -> str 371 return self.__class__.__name__ 372 373 def annotations(self) -> Dict[str, Union[bytes, str, message.Message]]: 374 return {} 375 376 def default_type_hints(self): 377 fn_type_hints = IOTypeHints.from_callable(self.expand) 378 if fn_type_hints is not None: 379 fn_type_hints = fn_type_hints.strip_pcoll() 380 381 # Prefer class decorator type hints for backwards compatibility. 382 return get_type_hints(self.__class__).with_defaults(fn_type_hints) 383 384 def with_input_types(self, input_type_hint): 385 """Annotates the input type of a :class:`PTransform` with a type-hint. 386 387 Args: 388 input_type_hint (type): An instance of an allowed built-in type, a custom 389 class, or an instance of a 390 :class:`~apache_beam.typehints.typehints.TypeConstraint`. 391 392 Raises: 393 TypeError: If **input_type_hint** is not a valid type-hint. 394 See 395 :obj:`apache_beam.typehints.typehints.validate_composite_type_param()` 396 for further details. 397 398 Returns: 399 PTransform: A reference to the instance of this particular 400 :class:`PTransform` object. This allows chaining type-hinting related 401 methods. 402 """ 403 input_type_hint = native_type_compatibility.convert_to_beam_type( 404 input_type_hint) 405 validate_composite_type_param( 406 input_type_hint, 'Type hints for a PTransform') 407 return super().with_input_types(input_type_hint) 408 409 def with_output_types(self, type_hint): 410 """Annotates the output type of a :class:`PTransform` with a type-hint. 411 412 Args: 413 type_hint (type): An instance of an allowed built-in type, a custom class, 414 or a :class:`~apache_beam.typehints.typehints.TypeConstraint`. 415 416 Raises: 417 TypeError: If **type_hint** is not a valid type-hint. See 418 :obj:`~apache_beam.typehints.typehints.validate_composite_type_param()` 419 for further details. 420 421 Returns: 422 PTransform: A reference to the instance of this particular 423 :class:`PTransform` object. This allows chaining type-hinting related 424 methods. 425 """ 426 type_hint = native_type_compatibility.convert_to_beam_type(type_hint) 427 validate_composite_type_param(type_hint, 'Type hints for a PTransform') 428 return super().with_output_types(type_hint) 429 430 def with_resource_hints(self, **kwargs): # type: (...) -> PTransform 431 """Adds resource hints to the :class:`PTransform`. 432 433 Resource hints allow users to express constraints on the environment where 434 the transform should be executed. Interpretation of the resource hints is 435 defined by Beam Runners. Runners may ignore the unsupported hints. 436 437 Args: 438 **kwargs: key-value pairs describing hints and their values. 439 440 Raises: 441 ValueError: if provided hints are unknown to the SDK. See 442 :mod:`apache_beam.transforms.resources` for a list of known hints. 443 444 Returns: 445 PTransform: A reference to the instance of this particular 446 :class:`PTransform` object. 447 """ 448 self.get_resource_hints().update(resources.parse_resource_hints(kwargs)) 449 return self 450 451 def get_resource_hints(self): 452 # type: () -> Dict[str, bytes] 453 if '_resource_hints' not in self.__dict__: 454 # PTransform subclasses don't always call super(), so prefer lazy 455 # initialization. By default, transforms don't have any resource hints. 456 self._resource_hints = {} # type: Dict[str, bytes] 457 return self._resource_hints 458 459 def type_check_inputs(self, pvalueish): 460 self.type_check_inputs_or_outputs(pvalueish, 'input') 461 462 def infer_output_type(self, unused_input_type): 463 return self.get_type_hints().simple_output_type(self.label) or typehints.Any 464 465 def type_check_outputs(self, pvalueish): 466 self.type_check_inputs_or_outputs(pvalueish, 'output') 467 468 def type_check_inputs_or_outputs(self, pvalueish, input_or_output): 469 type_hints = self.get_type_hints() 470 hints = getattr(type_hints, input_or_output + '_types') 471 if hints is None or not any(hints): 472 return 473 arg_hints, kwarg_hints = hints 474 if arg_hints and kwarg_hints: 475 raise TypeCheckError( 476 'PTransform cannot have both positional and keyword type hints ' 477 'without overriding %s._type_check_%s()' % 478 (self.__class__, input_or_output)) 479 root_hint = ( 480 arg_hints[0] if len(arg_hints) == 1 else arg_hints or kwarg_hints) 481 for context, pvalue_, hint in _ZipPValues().visit(pvalueish, root_hint): 482 if isinstance(pvalue_, DoOutputsTuple): 483 continue 484 if pvalue_.element_type is None: 485 # TODO(robertwb): It's a bug that we ever get here. (typecheck) 486 continue 487 if hint and not typehints.is_consistent_with(pvalue_.element_type, hint): 488 at_context = ' %s %s' % (input_or_output, context) if context else '' 489 raise TypeCheckError( 490 '{type} type hint violation at {label}{context}: expected {hint}, ' 491 'got {actual_type}\nFull type hint:\n{debug_str}'.format( 492 type=input_or_output.title(), 493 label=self.label, 494 context=at_context, 495 hint=hint, 496 actual_type=pvalue_.element_type, 497 debug_str=type_hints.debug_str())) 498 499 def _infer_output_coder(self, input_type=None, input_coder=None): 500 # type: (...) -> Optional[coders.Coder] 501 502 """Returns the output coder to use for output of this transform. 503 504 The Coder returned here should not be wrapped in a WindowedValueCoder 505 wrapper. 506 507 Args: 508 input_type: An instance of an allowed built-in type, a custom class, or a 509 typehints.TypeConstraint for the input type, or None if not available. 510 input_coder: Coder object for encoding input to this PTransform, or None 511 if not available. 512 513 Returns: 514 Coder object for encoding output of this PTransform or None if unknown. 515 """ 516 # TODO(ccy): further refine this API. 517 return None 518 519 def _clone(self, new_label): 520 """Clones the current transform instance under a new label.""" 521 transform = copy.copy(self) 522 transform.label = new_label 523 return transform 524 525 def expand(self, input_or_inputs: InputT) -> OutputT: 526 raise NotImplementedError 527 528 def __str__(self): 529 return '<%s>' % self._str_internal() 530 531 def __repr__(self): 532 return '<%s at %s>' % (self._str_internal(), hex(id(self))) 533 534 def _str_internal(self): 535 return '%s(PTransform)%s%s%s' % ( 536 self.__class__.__name__, 537 ' label=[%s]' % self.label if 538 (hasattr(self, 'label') and self.label) else '', 539 ' inputs=%s' % str(self.inputs) if 540 (hasattr(self, 'inputs') and self.inputs) else '', 541 ' side_inputs=%s' % str(self.side_inputs) if self.side_inputs else '') 542 543 def _check_pcollection(self, pcoll): 544 # type: (pvalue.PCollection) -> None 545 if not isinstance(pcoll, pvalue.PCollection): 546 raise error.TransformError('Expecting a PCollection argument.') 547 if not pcoll.pipeline: 548 raise error.TransformError('PCollection not part of a pipeline.') 549 550 def get_windowing(self, inputs): 551 # type: (Any) -> Windowing 552 553 """Returns the window function to be associated with transform's output. 554 555 By default most transforms just return the windowing function associated 556 with the input PCollection (or the first input if several). 557 """ 558 if inputs: 559 return inputs[0].windowing 560 else: 561 from apache_beam.transforms.core import Windowing 562 from apache_beam.transforms.window import GlobalWindows 563 # TODO(robertwb): Return something compatible with every windowing? 564 return Windowing(GlobalWindows()) 565 566 def __rrshift__(self, label): 567 return _NamedPTransform(self, label) 568 569 def __or__(self, right): 570 """Used to compose PTransforms, e.g., ptransform1 | ptransform2.""" 571 if isinstance(right, PTransform): 572 return _ChainedPTransform(self, right) 573 return NotImplemented 574 575 def __ror__(self, left, label=None): 576 """Used to apply this PTransform to non-PValues, e.g., a tuple.""" 577 pvalueish, pvalues = self._extract_input_pvalues(left) 578 if isinstance(pvalues, dict): 579 pvalues = tuple(pvalues.values()) 580 pipelines = [v.pipeline for v in pvalues if isinstance(v, pvalue.PValue)] 581 if pvalues and not pipelines: 582 deferred = False 583 # pylint: disable=wrong-import-order, wrong-import-position 584 from apache_beam import pipeline 585 from apache_beam.options.pipeline_options import PipelineOptions 586 # pylint: enable=wrong-import-order, wrong-import-position 587 p = pipeline.Pipeline('DirectRunner', PipelineOptions(sys.argv)) 588 else: 589 if not pipelines: 590 if self.pipeline is not None: 591 p = self.pipeline 592 else: 593 raise ValueError( 594 '"%s" requires a pipeline to be specified ' 595 'as there are no deferred inputs.' % self.label) 596 else: 597 p = self.pipeline or pipelines[0] 598 for pp in pipelines: 599 if p != pp: 600 raise ValueError( 601 'Mixing values in different pipelines is not allowed.' 602 '\n{%r} != {%r}' % (p, pp)) 603 deferred = not getattr(p.runner, 'is_eager', False) 604 # pylint: disable=wrong-import-order, wrong-import-position 605 from apache_beam.transforms.core import Create 606 # pylint: enable=wrong-import-order, wrong-import-position 607 replacements = { 608 id(v): p | 'CreatePInput%s' % ix >> Create(v, reshuffle=False) 609 for (ix, v) in enumerate(pvalues) 610 if not isinstance(v, pvalue.PValue) and v is not None 611 } 612 pvalueish = _SetInputPValues().visit(pvalueish, replacements) 613 self.pipeline = p 614 result = p.apply(self, pvalueish, label) 615 if deferred: 616 return result 617 _allocate_materialized_pipeline(p) 618 materialized_result = _AddMaterializationTransforms().visit(result) 619 p.run().wait_until_finish() 620 _release_materialized_pipeline(p) 621 return _FinalizeMaterialization().visit(materialized_result) 622 623 def _extract_input_pvalues(self, pvalueish): 624 """Extract all the pvalues contained in the input pvalueish. 625 626 Returns pvalueish as well as the flat inputs list as the input may have to 627 be copied as inspection may be destructive. 628 629 By default, recursively extracts tuple components and dict values. 630 631 Generally only needs to be overriden for multi-input PTransforms. 632 """ 633 # pylint: disable=wrong-import-order 634 from apache_beam import pipeline 635 # pylint: enable=wrong-import-order 636 if isinstance(pvalueish, pipeline.Pipeline): 637 pvalueish = pvalue.PBegin(pvalueish) 638 639 return pvalueish, { 640 str(tag): value 641 for (tag, value) in get_named_nested_pvalues( 642 pvalueish, as_inputs=True) 643 } 644 645 def _pvaluish_from_dict(self, input_dict): 646 if len(input_dict) == 1: 647 return next(iter(input_dict.values())) 648 else: 649 return input_dict 650 651 def _named_inputs(self, main_inputs, side_inputs): 652 # type: (Mapping[str, pvalue.PValue], Sequence[Any]) -> Dict[str, pvalue.PValue] 653 654 """Returns the dictionary of named inputs (including side inputs) as they 655 should be named in the beam proto. 656 """ 657 main_inputs = { 658 tag: input 659 for (tag, input) in main_inputs.items() 660 if isinstance(input, pvalue.PCollection) 661 } 662 named_side_inputs = {(SIDE_INPUT_PREFIX + '%s') % ix: si.pvalue 663 for (ix, si) in enumerate(side_inputs)} 664 return dict(main_inputs, **named_side_inputs) 665 666 def _named_outputs(self, outputs): 667 # type: (Dict[object, pvalue.PCollection]) -> Dict[str, pvalue.PCollection] 668 669 """Returns the dictionary of named outputs as they should be named in the 670 beam proto. 671 """ 672 # TODO(BEAM-1833): Push names up into the sdk construction. 673 return { 674 str(tag): output 675 for (tag, output) in outputs.items() 676 if isinstance(output, pvalue.PCollection) 677 } 678 679 _known_urns = {} # type: Dict[str, Tuple[Optional[type], ConstructorFn]] 680 681 @classmethod 682 @overload 683 def register_urn( 684 cls, 685 urn, # type: str 686 parameter_type, # type: Type[T] 687 ): 688 # type: (...) -> Callable[[Union[type, Callable[[beam_runner_api_pb2.PTransform, T, PipelineContext], Any]]], Callable[[T, PipelineContext], Any]] 689 pass 690 691 @classmethod 692 @overload 693 def register_urn( 694 cls, 695 urn, # type: str 696 parameter_type, # type: None 697 ): 698 # type: (...) -> Callable[[Union[type, Callable[[beam_runner_api_pb2.PTransform, bytes, PipelineContext], Any]]], Callable[[bytes, PipelineContext], Any]] 699 pass 700 701 @classmethod 702 @overload 703 def register_urn(cls, 704 urn, # type: str 705 parameter_type, # type: Type[T] 706 constructor # type: Callable[[beam_runner_api_pb2.PTransform, T, PipelineContext], Any] 707 ): 708 # type: (...) -> None 709 pass 710 711 @classmethod 712 @overload 713 def register_urn(cls, 714 urn, # type: str 715 parameter_type, # type: None 716 constructor # type: Callable[[beam_runner_api_pb2.PTransform, bytes, PipelineContext], Any] 717 ): 718 # type: (...) -> None 719 pass 720 721 @classmethod 722 def register_urn(cls, urn, parameter_type, constructor=None): 723 def register(constructor): 724 if isinstance(constructor, type): 725 constructor.from_runner_api_parameter = register( 726 constructor.from_runner_api_parameter) 727 else: 728 cls._known_urns[urn] = parameter_type, constructor 729 return constructor 730 731 if constructor: 732 # Used as a statement. 733 register(constructor) 734 else: 735 # Used as a decorator. 736 return register 737 738 def to_runner_api(self, context, has_parts=False, **extra_kwargs): 739 # type: (PipelineContext, bool, Any) -> beam_runner_api_pb2.FunctionSpec 740 from apache_beam.portability.api import beam_runner_api_pb2 741 # typing: only ParDo supports extra_kwargs 742 urn, typed_param = self.to_runner_api_parameter(context, **extra_kwargs) # type: ignore[call-arg] 743 if urn == python_urns.GENERIC_COMPOSITE_TRANSFORM and not has_parts: 744 # TODO(https://github.com/apache/beam/issues/18713): Remove this fallback. 745 urn, typed_param = self.to_runner_api_pickled(context) 746 return beam_runner_api_pb2.FunctionSpec( 747 urn=urn, 748 payload=typed_param.SerializeToString() if isinstance( 749 typed_param, message.Message) else typed_param.encode('utf-8') 750 if isinstance(typed_param, str) else typed_param) 751 752 @classmethod 753 def from_runner_api(cls, 754 proto, # type: Optional[beam_runner_api_pb2.PTransform] 755 context # type: PipelineContext 756 ): 757 # type: (...) -> Optional[PTransform] 758 if proto is None or proto.spec is None or not proto.spec.urn: 759 return None 760 parameter_type, constructor = cls._known_urns[proto.spec.urn] 761 762 return constructor( 763 proto, 764 proto_utils.parse_Bytes(proto.spec.payload, parameter_type), 765 context) 766 767 def to_runner_api_parameter( 768 self, 769 unused_context # type: PipelineContext 770 ): 771 # type: (...) -> Tuple[str, Optional[Union[message.Message, bytes, str]]] 772 # The payload here is just to ease debugging. 773 return ( 774 python_urns.GENERIC_COMPOSITE_TRANSFORM, 775 getattr(self, '_fn_api_payload', str(self))) 776 777 def to_runner_api_pickled(self, unused_context): 778 # type: (PipelineContext) -> Tuple[str, bytes] 779 return (python_urns.PICKLED_TRANSFORM, pickler.dumps(self)) 780 781 def runner_api_requires_keyed_input(self): 782 return False 783 784 def _add_type_constraint_from_consumer(self, full_label, input_type_hints): 785 # type: (str, Tuple[str, Any]) -> None 786 787 """Adds a consumer transform's input type hints to our output type 788 constraints, which is used during performance runtime type-checking. 789 """ 790 pass 791 792 793 @PTransform.register_urn(python_urns.GENERIC_COMPOSITE_TRANSFORM, None) 794 def _create_transform(unused_ptransform, payload, unused_context): 795 empty_transform = PTransform() 796 empty_transform._fn_api_payload = payload 797 return empty_transform 798 799 800 @PTransform.register_urn(python_urns.PICKLED_TRANSFORM, None) 801 def _unpickle_transform(unused_ptransform, pickled_bytes, unused_context): 802 return pickler.loads(pickled_bytes) 803 804 805 class _ChainedPTransform(PTransform): 806 def __init__(self, *parts): 807 # type: (*PTransform) -> None 808 super().__init__(label=self._chain_label(parts)) 809 self._parts = parts 810 811 def _chain_label(self, parts): 812 return '|'.join(p.label for p in parts) 813 814 def __or__(self, right): 815 if isinstance(right, PTransform): 816 # Create a flat list rather than a nested tree of composite 817 # transforms for better monitoring, etc. 818 return _ChainedPTransform(*(self._parts + (right, ))) 819 return NotImplemented 820 821 def expand(self, pval): 822 return reduce(operator.or_, self._parts, pval) 823 824 825 class PTransformWithSideInputs(PTransform): 826 """A superclass for any :class:`PTransform` (e.g. 827 :func:`~apache_beam.transforms.core.FlatMap` or 828 :class:`~apache_beam.transforms.core.CombineFn`) 829 invoking user code. 830 831 :class:`PTransform` s like :func:`~apache_beam.transforms.core.FlatMap` 832 invoke user-supplied code in some kind of package (e.g. a 833 :class:`~apache_beam.transforms.core.DoFn`) and optionally provide arguments 834 and side inputs to that code. This internal-use-only class contains common 835 functionality for :class:`PTransform` s that fit this model. 836 """ 837 def __init__(self, fn, *args, **kwargs): 838 # type: (WithTypeHints, *Any, **Any) -> None 839 if isinstance(fn, type) and issubclass(fn, WithTypeHints): 840 # Don't treat Fn class objects as callables. 841 raise ValueError('Use %s() not %s.' % (fn.__name__, fn.__name__)) 842 self.fn = self.make_fn(fn, bool(args or kwargs)) 843 # Now that we figure out the label, initialize the super-class. 844 super().__init__() 845 846 if (any(isinstance(v, pvalue.PCollection) for v in args) or 847 any(isinstance(v, pvalue.PCollection) for v in kwargs.values())): 848 raise error.SideInputError( 849 'PCollection used directly as side input argument. Specify ' 850 'AsIter(pcollection) or AsSingleton(pcollection) to indicate how the ' 851 'PCollection is to be used.') 852 self.args, self.kwargs, self.side_inputs = util.remove_objects_from_args( 853 args, kwargs, pvalue.AsSideInput) 854 self.raw_side_inputs = args, kwargs 855 856 # Prevent name collisions with fns of the form '<function <lambda> at ...>' 857 self._cached_fn = self.fn 858 859 # Ensure fn and side inputs are picklable for remote execution. 860 try: 861 self.fn = pickler.loads(pickler.dumps(self.fn)) 862 except RuntimeError as e: 863 raise RuntimeError('Unable to pickle fn %s: %s' % (self.fn, e)) 864 865 self.args = pickler.loads(pickler.dumps(self.args)) 866 self.kwargs = pickler.loads(pickler.dumps(self.kwargs)) 867 868 # For type hints, because loads(dumps(class)) != class. 869 self.fn = self._cached_fn 870 871 def with_input_types( 872 self, input_type_hint, *side_inputs_arg_hints, **side_input_kwarg_hints): 873 """Annotates the types of main inputs and side inputs for the PTransform. 874 875 Args: 876 input_type_hint: An instance of an allowed built-in type, a custom class, 877 or an instance of a typehints.TypeConstraint. 878 *side_inputs_arg_hints: A variable length argument composed of 879 of an allowed built-in type, a custom class, or a 880 typehints.TypeConstraint. 881 **side_input_kwarg_hints: A dictionary argument composed of 882 of an allowed built-in type, a custom class, or a 883 typehints.TypeConstraint. 884 885 Example of annotating the types of side-inputs:: 886 887 FlatMap().with_input_types(int, int, bool) 888 889 Raises: 890 :class:`TypeError`: If **type_hint** is not a valid type-hint. 891 See 892 :func:`~apache_beam.typehints.typehints.validate_composite_type_param` 893 for further details. 894 895 Returns: 896 :class:`PTransform`: A reference to the instance of this particular 897 :class:`PTransform` object. This allows chaining type-hinting related 898 methods. 899 """ 900 super().with_input_types(input_type_hint) 901 902 side_inputs_arg_hints = native_type_compatibility.convert_to_beam_types( 903 side_inputs_arg_hints) 904 side_input_kwarg_hints = native_type_compatibility.convert_to_beam_types( 905 side_input_kwarg_hints) 906 907 for si in side_inputs_arg_hints: 908 validate_composite_type_param(si, 'Type hints for a PTransform') 909 for si in side_input_kwarg_hints.values(): 910 validate_composite_type_param(si, 'Type hints for a PTransform') 911 912 self.side_inputs_types = side_inputs_arg_hints 913 return WithTypeHints.with_input_types( 914 self, input_type_hint, *side_inputs_arg_hints, **side_input_kwarg_hints) 915 916 def type_check_inputs(self, pvalueish): 917 type_hints = self.get_type_hints() 918 input_types = type_hints.input_types 919 if input_types: 920 args, kwargs = self.raw_side_inputs 921 922 def element_type(side_input): 923 if isinstance(side_input, pvalue.AsSideInput): 924 return side_input.element_type 925 return instance_to_type(side_input) 926 927 arg_types = [pvalueish.element_type] + [element_type(v) for v in args] 928 kwargs_types = {k: element_type(v) for (k, v) in kwargs.items()} 929 argspec_fn = self._process_argspec_fn() 930 bindings = getcallargs_forhints(argspec_fn, *arg_types, **kwargs_types) 931 hints = getcallargs_forhints( 932 argspec_fn, *input_types[0], **input_types[1]) 933 for arg, hint in hints.items(): 934 if arg.startswith('__unknown__'): 935 continue 936 if hint is None: 937 continue 938 if not typehints.is_consistent_with(bindings.get(arg, typehints.Any), 939 hint): 940 raise TypeCheckError( 941 'Type hint violation for \'{label}\': requires {hint} but got ' 942 '{actual_type} for {arg}\nFull type hint:\n{debug_str}'.format( 943 label=self.label, 944 hint=hint, 945 actual_type=bindings[arg], 946 arg=arg, 947 debug_str=type_hints.debug_str())) 948 949 def _process_argspec_fn(self): 950 """Returns an argspec of the function actually consuming the data. 951 """ 952 raise NotImplementedError 953 954 def make_fn(self, fn, has_side_inputs): 955 # TODO(silviuc): Add comment describing that this is meant to be overriden 956 # by methods detecting callables and wrapping them in DoFns. 957 return fn 958 959 def default_label(self): 960 return '%s(%s)' % (self.__class__.__name__, self.fn.default_label()) 961 962 963 class _PTransformFnPTransform(PTransform): 964 """A class wrapper for a function-based transform.""" 965 def __init__(self, fn, *args, **kwargs): 966 super().__init__() 967 self._fn = fn 968 self._args = args 969 self._kwargs = kwargs 970 971 def display_data(self): 972 res = { 973 'fn': ( 974 self._fn.__name__ 975 if hasattr(self._fn, '__name__') else self._fn.__class__), 976 'args': DisplayDataItem(str(self._args)).drop_if_default('()'), 977 'kwargs': DisplayDataItem(str(self._kwargs)).drop_if_default('{}') 978 } 979 return res 980 981 def expand(self, pcoll): 982 # Since the PTransform will be implemented entirely as a function 983 # (once called), we need to pass through any type-hinting information that 984 # may have been annotated via the .with_input_types() and 985 # .with_output_types() methods. 986 kwargs = dict(self._kwargs) 987 args = tuple(self._args) 988 989 # TODO(BEAM-5878) Support keyword-only arguments. 990 try: 991 if 'type_hints' in get_signature(self._fn).parameters: 992 args = (self.get_type_hints(), ) + args 993 except TypeError: 994 # Might not be a function. 995 pass 996 return self._fn(pcoll, *args, **kwargs) 997 998 def default_label(self): 999 if self._args: 1000 return '%s(%s)' % ( 1001 label_from_callable(self._fn), label_from_callable(self._args[0])) 1002 return label_from_callable(self._fn) 1003 1004 1005 def ptransform_fn(fn): 1006 # type: (Callable) -> Callable[..., _PTransformFnPTransform] 1007 1008 """A decorator for a function-based PTransform. 1009 1010 Args: 1011 fn: A function implementing a custom PTransform. 1012 1013 Returns: 1014 A CallablePTransform instance wrapping the function-based PTransform. 1015 1016 This wrapper provides an alternative, simpler way to define a PTransform. 1017 The standard method is to subclass from PTransform and override the expand() 1018 method. An equivalent effect can be obtained by defining a function that 1019 accepts an input PCollection and additional optional arguments and returns a 1020 resulting PCollection. For example:: 1021 1022 @ptransform_fn 1023 @beam.typehints.with_input_types(..) 1024 @beam.typehints.with_output_types(..) 1025 def CustomMapper(pcoll, mapfn): 1026 return pcoll | ParDo(mapfn) 1027 1028 The equivalent approach using PTransform subclassing:: 1029 1030 @beam.typehints.with_input_types(..) 1031 @beam.typehints.with_output_types(..) 1032 class CustomMapper(PTransform): 1033 1034 def __init__(self, mapfn): 1035 super().__init__() 1036 self.mapfn = mapfn 1037 1038 def expand(self, pcoll): 1039 return pcoll | ParDo(self.mapfn) 1040 1041 With either method the custom PTransform can be used in pipelines as if 1042 it were one of the "native" PTransforms:: 1043 1044 result_pcoll = input_pcoll | 'Label' >> CustomMapper(somefn) 1045 1046 Note that for both solutions the underlying implementation of the pipe 1047 operator (i.e., `|`) will inject the pcoll argument in its proper place 1048 (first argument if no label was specified and second argument otherwise). 1049 1050 Type hint support needs to be enabled via the 1051 --type_check_additional=ptransform_fn flag in Beam 2. 1052 If CustomMapper is a Cython function, you can still specify input and output 1053 types provided the decorators appear before @ptransform_fn. 1054 """ 1055 # TODO(robertwb): Consider removing staticmethod to allow for self parameter. 1056 @wraps(fn) 1057 def callable_ptransform_factory(*args, **kwargs): 1058 res = _PTransformFnPTransform(fn, *args, **kwargs) 1059 if ptransform_fn_typehints_enabled: 1060 # Apply type hints applied before or after the ptransform_fn decorator, 1061 # falling back on PTransform defaults. 1062 # If the @with_{input,output}_types decorator comes before ptransform_fn, 1063 # the type hints get applied to this function. If it comes after they will 1064 # get applied to fn, and @wraps will copy the _type_hints attribute to 1065 # this function. 1066 type_hints = get_type_hints(callable_ptransform_factory) 1067 res._set_type_hints(type_hints.with_defaults(res.get_type_hints())) 1068 _LOGGER.debug( 1069 'type hints for %s: %s', res.default_label(), res.get_type_hints()) 1070 return res 1071 1072 return callable_ptransform_factory 1073 1074 1075 def label_from_callable(fn): 1076 if hasattr(fn, 'default_label'): 1077 return fn.default_label() 1078 elif hasattr(fn, '__name__'): 1079 if fn.__name__ == '<lambda>': 1080 return '<lambda at %s:%s>' % ( 1081 os.path.basename(fn.__code__.co_filename), fn.__code__.co_firstlineno) 1082 return fn.__name__ 1083 return str(fn) 1084 1085 1086 class _NamedPTransform(PTransform): 1087 def __init__(self, transform, label): 1088 super().__init__(label) 1089 self.transform = transform 1090 1091 def __ror__(self, pvalueish, _unused=None): 1092 return self.transform.__ror__(pvalueish, self.label) 1093 1094 def expand(self, pvalue): 1095 raise RuntimeError("Should never be expanded directly.")