github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/ptransform.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """PTransform and descendants.
    19  
    20  A PTransform is an object describing (not executing) a computation. The actual
     21  execution semantics for a transform are captured by a runner object. A transform
    22  object always belongs to a pipeline object.
    23  
    24  A PTransform derived class needs to define the expand() method that describes
    25  how one or more PValues are created by the transform.
    26  
    27  The module defines a few standard transforms: FlatMap (parallel do),
    28  GroupByKey (group by key), etc. Note that the expand() methods for these
    29  classes contain code that will add nodes to the processing graph associated
    30  with a pipeline.
    31  
    32  As support for the FlatMap transform, the module also defines a DoFn
     33  class and a wrapper class that allows lambda functions to be used as
    34  FlatMap processing functions.
    35  """
    36  
    37  # pytype: skip-file
    38  
    39  import copy
    40  import itertools
    41  import logging
    42  import operator
    43  import os
    44  import sys
    45  import threading
    46  from functools import reduce
    47  from functools import wraps
    48  from typing import TYPE_CHECKING
    49  from typing import Any
    50  from typing import Callable
    51  from typing import Dict
    52  from typing import Generic
    53  from typing import List
    54  from typing import Mapping
    55  from typing import Optional
    56  from typing import Sequence
    57  from typing import Tuple
    58  from typing import Type
    59  from typing import TypeVar
    60  from typing import Union
    61  from typing import overload
    62  
    63  from google.protobuf import message
    64  
    65  from apache_beam import error
    66  from apache_beam import pvalue
    67  from apache_beam.internal import pickler
    68  from apache_beam.internal import util
    69  from apache_beam.portability import python_urns
    70  from apache_beam.pvalue import DoOutputsTuple
    71  from apache_beam.transforms import resources
    72  from apache_beam.transforms.display import DisplayDataItem
    73  from apache_beam.transforms.display import HasDisplayData
    74  from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX
    75  from apache_beam.typehints import native_type_compatibility
    76  from apache_beam.typehints import typehints
    77  from apache_beam.typehints.decorators import IOTypeHints
    78  from apache_beam.typehints.decorators import TypeCheckError
    79  from apache_beam.typehints.decorators import WithTypeHints
    80  from apache_beam.typehints.decorators import get_signature
    81  from apache_beam.typehints.decorators import get_type_hints
    82  from apache_beam.typehints.decorators import getcallargs_forhints
    83  from apache_beam.typehints.trivial_inference import instance_to_type
    84  from apache_beam.typehints.typehints import validate_composite_type_param
    85  from apache_beam.utils import proto_utils
    86  
    87  if TYPE_CHECKING:
    88    from apache_beam import coders
    89    from apache_beam.pipeline import Pipeline
    90    from apache_beam.runners.pipeline_context import PipelineContext
    91    from apache_beam.transforms.core import Windowing
    92    from apache_beam.portability.api import beam_runner_api_pb2
    93  
    94  __all__ = [
    95      'PTransform',
    96      'ptransform_fn',
    97      'label_from_callable',
    98  ]
    99  
   100  _LOGGER = logging.getLogger(__name__)
   101  
   102  T = TypeVar('T')
   103  InputT = TypeVar('InputT')
   104  OutputT = TypeVar('OutputT')
   105  PTransformT = TypeVar('PTransformT', bound='PTransform')
   106  ConstructorFn = Callable[
   107      ['beam_runner_api_pb2.PTransform', Optional[Any], 'PipelineContext'], Any]
   108  ptransform_fn_typehints_enabled = False
   109  
   110  
   111  class _PValueishTransform(object):
   112    """Visitor for PValueish objects.
   113  
    114    A PValueish is a PValue, or a list, tuple, or dict of PValueish objects.
   115  
    116    This visits a PValueish, constructing a (possibly mutated) copy.
   117    """
   118    def visit_nested(self, node, *args):
   119      if isinstance(node, (tuple, list)):
   120        args = [self.visit(x, *args) for x in node]
   121        if isinstance(node, tuple) and hasattr(node.__class__, '_make'):
   122          # namedtuples require unpacked arguments in their constructor
   123          return node.__class__(*args)
   124        else:
   125          return node.__class__(args)
   126      elif isinstance(node, dict):
   127        return node.__class__(
   128            {key: self.visit(value, *args)
   129             for (key, value) in node.items()})
   130      else:
   131        return node
   132  
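         # Illustrative sketch (not part of the original module): a visitor
         # built on _PValueishTransform that uppercases string leaves, with
         # visit_nested preserving the nesting, much as _SetInputPValues
         # below swaps out leaves by id(). _UpperLeaves is hypothetical:
         #
         #   class _UpperLeaves(_PValueishTransform):
         #     def visit(self, node):
         #       if isinstance(node, str):
         #         return node.upper()
         #       return self.visit_nested(node)
         #
         #   _UpperLeaves().visit({'a': ('x', 'y')})  # -> {'a': ('X', 'Y')}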
   133  
   134  class _SetInputPValues(_PValueishTransform):
   135    def visit(self, node, replacements):
   136      if id(node) in replacements:
   137        return replacements[id(node)]
   138      else:
   139        return self.visit_nested(node, replacements)
   140  
   141  
   142  # Caches to allow for materialization of values when executing a pipeline
   143  # in-process, in eager mode.  This cache allows the same _MaterializedResult
   144  # object to be accessed and used despite Runner API round-trip serialization.
   145  _pipeline_materialization_cache = {
   146  }  # type: Dict[Tuple[int, int], Dict[int, _MaterializedResult]]
   147  _pipeline_materialization_lock = threading.Lock()
   148  
   149  
   150  def _allocate_materialized_pipeline(pipeline):
   151    # type: (Pipeline) -> None
   152    pid = os.getpid()
   153    with _pipeline_materialization_lock:
   154      pipeline_id = id(pipeline)
   155      _pipeline_materialization_cache[(pid, pipeline_id)] = {}
   156  
   157  
   158  def _allocate_materialized_result(pipeline):
   159    # type: (Pipeline) -> _MaterializedResult
   160    pid = os.getpid()
   161    with _pipeline_materialization_lock:
   162      pipeline_id = id(pipeline)
   163      if (pid, pipeline_id) not in _pipeline_materialization_cache:
   164        raise ValueError(
   165            'Materialized pipeline is not allocated for result '
   166            'cache.')
   167      result_id = len(_pipeline_materialization_cache[(pid, pipeline_id)])
   168      result = _MaterializedResult(pipeline_id, result_id)
   169      _pipeline_materialization_cache[(pid, pipeline_id)][result_id] = result
   170      return result
   171  
   172  
   173  def _get_materialized_result(pipeline_id, result_id):
   174    # type: (int, int) -> _MaterializedResult
   175    pid = os.getpid()
   176    with _pipeline_materialization_lock:
   177      if (pid, pipeline_id) not in _pipeline_materialization_cache:
   178        raise Exception(
   179            'Materialization in out-of-process and remote runners is not yet '
   180            'supported.')
   181      return _pipeline_materialization_cache[(pid, pipeline_id)][result_id]
   182  
   183  
   184  def _release_materialized_pipeline(pipeline):
   185    # type: (Pipeline) -> None
   186    pid = os.getpid()
   187    with _pipeline_materialization_lock:
   188      pipeline_id = id(pipeline)
   189      del _pipeline_materialization_cache[(pid, pipeline_id)]
   190  
   191  
   192  class _MaterializedResult(object):
   193    def __init__(self, pipeline_id, result_id):
   194      # type: (int, int) -> None
   195      self._pipeline_id = pipeline_id
   196      self._result_id = result_id
   197      self.elements = []  # type: List[Any]
   198  
   199    def __reduce__(self):
    200      # When unpickled (during Runner API round-trip serialization), get the
   201      # _MaterializedResult object from the cache so that values are written
   202      # to the original _MaterializedResult when run in eager mode.
   203      return (_get_materialized_result, (self._pipeline_id, self._result_id))
   204  
   205  
   206  class _MaterializedDoOutputsTuple(pvalue.DoOutputsTuple):
   207    def __init__(self, deferred, results_by_tag):
   208      super().__init__(None, None, deferred._tags, deferred._main_tag)
   209      self._deferred = deferred
   210      self._results_by_tag = results_by_tag
   211  
   212    def __getitem__(self, tag):
   213      if tag not in self._results_by_tag:
   214        raise KeyError(
   215            'Tag %r is not a defined output tag of %s.' % (tag, self._deferred))
   216      return self._results_by_tag[tag].elements
   217  
   218  
   219  class _AddMaterializationTransforms(_PValueishTransform):
   220    def _materialize_transform(self, pipeline):
   221      result = _allocate_materialized_result(pipeline)
   222  
   223      # Need to define _MaterializeValuesDoFn here to avoid circular
   224      # dependencies.
   225      from apache_beam import DoFn
   226      from apache_beam import ParDo
   227  
   228      class _MaterializeValuesDoFn(DoFn):
   229        def process(self, element):
   230          result.elements.append(element)
   231  
   232      materialization_label = '_MaterializeValues%d' % result._result_id
   233      return (materialization_label >> ParDo(_MaterializeValuesDoFn()), result)
   234  
   235    def visit(self, node):
   236      if isinstance(node, pvalue.PValue):
   237        transform, result = self._materialize_transform(node.pipeline)
   238        node | transform
   239        return result
   240      elif isinstance(node, pvalue.DoOutputsTuple):
   241        results_by_tag = {}
   242        for tag in itertools.chain([node._main_tag], node._tags):
   243          results_by_tag[tag] = self.visit(node[tag])
   244        return _MaterializedDoOutputsTuple(node, results_by_tag)
   245      else:
   246        return self.visit_nested(node)
   247  
   248  
   249  class _FinalizeMaterialization(_PValueishTransform):
   250    def visit(self, node):
   251      if isinstance(node, _MaterializedResult):
   252        return node.elements
   253      elif isinstance(node, _MaterializedDoOutputsTuple):
   254        return node
   255      else:
   256        return self.visit_nested(node)
   257  
   258  
   259  def get_named_nested_pvalues(pvalueish, as_inputs=False):
   260    if isinstance(pvalueish, tuple):
   261      # Check to see if it's a named tuple.
   262      fields = getattr(pvalueish, '_fields', None)
   263      if fields and len(fields) == len(pvalueish):
   264        tagged_values = zip(fields, pvalueish)
   265      else:
   266        tagged_values = enumerate(pvalueish)
   267    elif isinstance(pvalueish, list):
   268      if as_inputs:
    269        # The whole list is treated as a single value for eager evaluation.
   270        yield None, pvalueish
   271        return
   272      tagged_values = enumerate(pvalueish)
   273    elif isinstance(pvalueish, dict):
   274      tagged_values = pvalueish.items()
   275    else:
   276      if as_inputs or isinstance(pvalueish,
   277                                 (pvalue.PValue, pvalue.DoOutputsTuple)):
   278        yield None, pvalueish
   279      return
   280  
   281    for tag, subvalue in tagged_values:
   282      for subtag, subsubvalue in get_named_nested_pvalues(
   283          subvalue, as_inputs=as_inputs):
   284        if subtag is None:
   285          yield tag, subsubvalue
   286        else:
   287          yield '%s.%s' % (tag, subtag), subsubvalue
   288  
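         # Illustrative behavior sketch (not part of the original module):
         # nested tags are joined with dots, and with as_inputs=True plain
         # (non-PValue) leaves are yielded as well:
         #
         #   list(get_named_nested_pvalues({'a': (1, 2)}, as_inputs=True))
         #   # -> [('a.0', 1), ('a.1', 2)]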
   289  
   290  class _ZipPValues(object):
   291    """Pairs each PValue in a pvalueish with a value in a parallel out sibling.
   292  
   293    Sibling should have the same nested structure as pvalueish.  Leaves in
   294    sibling are expanded across nested pvalueish lists, tuples, and dicts.
   295    For example
   296  
    297      _ZipPValues().visit({'a': pc1, 'b': (pc2, pc3)},
    298                          {'a': 'A', 'b': 'B'})
   299  
   300    will return
   301  
   302        [('a', pc1, 'A'), ('b', pc2, 'B'), ('b', pc3, 'B')]
   303    """
   304    def visit(self, pvalueish, sibling, pairs=None, context=None):
   305      if pairs is None:
   306        pairs = []
   307        self.visit(pvalueish, sibling, pairs, context)
   308        return pairs
   309      elif isinstance(pvalueish, (pvalue.PValue, pvalue.DoOutputsTuple)):
   310        pairs.append((context, pvalueish, sibling))
   311      elif isinstance(pvalueish, (list, tuple)):
   312        self.visit_sequence(pvalueish, sibling, pairs, context)
   313      elif isinstance(pvalueish, dict):
   314        self.visit_dict(pvalueish, sibling, pairs, context)
   315  
   316    def visit_sequence(self, pvalueish, sibling, pairs, context):
   317      if isinstance(sibling, (list, tuple)):
   318        for ix, (p, s) in enumerate(zip(pvalueish,
   319                                        list(sibling) + [None] * len(pvalueish))):
   320          self.visit(p, s, pairs, 'position %s' % ix)
   321      else:
   322        for p in pvalueish:
   323          self.visit(p, sibling, pairs, context)
   324  
   325    def visit_dict(self, pvalueish, sibling, pairs, context):
   326      if isinstance(sibling, dict):
   327        for key, p in pvalueish.items():
   328          self.visit(p, sibling.get(key), pairs, key)
   329      else:
   330        for p in pvalueish.values():
   331          self.visit(p, sibling, pairs, context)
   332  
   333  
   334  class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]):
   335    """A transform object used to modify one or more PCollections.
   336  
   337    Subclasses must define an expand() method that will be used when the transform
    338    is applied to some arguments. The typical usage pattern is:
   339  
   340      input | CustomTransform(...)
   341  
   342    The expand() method of the CustomTransform object passed in will be called
   343    with input as an argument.
   344    """
   345    # By default, transforms don't have any side inputs.
   346    side_inputs = ()  # type: Sequence[pvalue.AsSideInput]
   347  
   348    # Used for nullary transforms.
   349    pipeline = None  # type: Optional[Pipeline]
   350  
   351    # Default is unset.
   352    _user_label = None  # type: Optional[str]
   353  
   354    def __init__(self, label=None):
   355      # type: (Optional[str]) -> None
   356      super().__init__()
   357      self.label = label  # type: ignore # https://github.com/python/mypy/issues/3004
   358  
   359    @property
   360    def label(self):
   361      # type: () -> str
   362      return self._user_label or self.default_label()
   363  
   364    @label.setter
   365    def label(self, value):
   366      # type: (Optional[str]) -> None
   367      self._user_label = value
   368  
   369    def default_label(self):
   370      # type: () -> str
   371      return self.__class__.__name__
   372  
   373    def annotations(self) -> Dict[str, Union[bytes, str, message.Message]]:
   374      return {}
   375  
   376    def default_type_hints(self):
   377      fn_type_hints = IOTypeHints.from_callable(self.expand)
   378      if fn_type_hints is not None:
   379        fn_type_hints = fn_type_hints.strip_pcoll()
   380  
   381      # Prefer class decorator type hints for backwards compatibility.
   382      return get_type_hints(self.__class__).with_defaults(fn_type_hints)
   383  
   384    def with_input_types(self, input_type_hint):
   385      """Annotates the input type of a :class:`PTransform` with a type-hint.
   386  
   387      Args:
   388        input_type_hint (type): An instance of an allowed built-in type, a custom
   389          class, or an instance of a
   390          :class:`~apache_beam.typehints.typehints.TypeConstraint`.
   391  
   392      Raises:
   393        TypeError: If **input_type_hint** is not a valid type-hint.
   394          See
   395          :obj:`apache_beam.typehints.typehints.validate_composite_type_param()`
   396          for further details.
   397  
   398      Returns:
   399        PTransform: A reference to the instance of this particular
   400        :class:`PTransform` object. This allows chaining type-hinting related
   401        methods.
   402      """
   403      input_type_hint = native_type_compatibility.convert_to_beam_type(
   404          input_type_hint)
   405      validate_composite_type_param(
   406          input_type_hint, 'Type hints for a PTransform')
   407      return super().with_input_types(input_type_hint)
   408  
   409    def with_output_types(self, type_hint):
   410      """Annotates the output type of a :class:`PTransform` with a type-hint.
   411  
   412      Args:
   413        type_hint (type): An instance of an allowed built-in type, a custom class,
   414          or a :class:`~apache_beam.typehints.typehints.TypeConstraint`.
   415  
   416      Raises:
   417        TypeError: If **type_hint** is not a valid type-hint. See
   418          :obj:`~apache_beam.typehints.typehints.validate_composite_type_param()`
   419          for further details.
   420  
   421      Returns:
   422        PTransform: A reference to the instance of this particular
   423        :class:`PTransform` object. This allows chaining type-hinting related
   424        methods.
   425      """
   426      type_hint = native_type_compatibility.convert_to_beam_type(type_hint)
   427      validate_composite_type_param(type_hint, 'Type hints for a PTransform')
   428      return super().with_output_types(type_hint)
   429  
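           # Illustrative usage sketch (not part of the original module):
           # both type-hint methods return self, so they chain fluently at
           # construction time; MyTransform is a hypothetical subclass:
           #
           #   pcoll | MyTransform().with_input_types(int).with_output_types(str)
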
   430    def with_resource_hints(self, **kwargs):  # type: (...) -> PTransform
   431      """Adds resource hints to the :class:`PTransform`.
   432  
   433      Resource hints allow users to express constraints on the environment where
   434      the transform should be executed.  Interpretation of the resource hints is
    435    defined by Beam runners. Runners may ignore unsupported hints.
   436  
   437      Args:
   438        **kwargs: key-value pairs describing hints and their values.
   439  
   440      Raises:
   441        ValueError: if provided hints are unknown to the SDK. See
   442          :mod:`apache_beam.transforms.resources` for a list of known hints.
   443  
   444      Returns:
   445        PTransform: A reference to the instance of this particular
   446        :class:`PTransform` object.
   447      """
   448      self.get_resource_hints().update(resources.parse_resource_hints(kwargs))
   449      return self
   450  
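           # Illustrative usage sketch (not part of the original module);
           # 'min_ram' is one of the hints parsed by
           # apache_beam.transforms.resources, and fn/pcoll are assumed:
           #
           #   pcoll | beam.ParDo(fn).with_resource_hints(min_ram='4GiB')
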
   451    def get_resource_hints(self):
   452      # type: () -> Dict[str, bytes]
   453      if '_resource_hints' not in self.__dict__:
   454        # PTransform subclasses don't always call super(), so prefer lazy
   455        # initialization. By default, transforms don't have any resource hints.
   456        self._resource_hints = {}  # type: Dict[str, bytes]
   457      return self._resource_hints
   458  
   459    def type_check_inputs(self, pvalueish):
   460      self.type_check_inputs_or_outputs(pvalueish, 'input')
   461  
   462    def infer_output_type(self, unused_input_type):
   463      return self.get_type_hints().simple_output_type(self.label) or typehints.Any
   464  
   465    def type_check_outputs(self, pvalueish):
   466      self.type_check_inputs_or_outputs(pvalueish, 'output')
   467  
   468    def type_check_inputs_or_outputs(self, pvalueish, input_or_output):
   469      type_hints = self.get_type_hints()
   470      hints = getattr(type_hints, input_or_output + '_types')
   471      if hints is None or not any(hints):
   472        return
   473      arg_hints, kwarg_hints = hints
   474      if arg_hints and kwarg_hints:
   475        raise TypeCheckError(
   476            'PTransform cannot have both positional and keyword type hints '
   477            'without overriding %s._type_check_%s()' %
   478            (self.__class__, input_or_output))
   479      root_hint = (
   480          arg_hints[0] if len(arg_hints) == 1 else arg_hints or kwarg_hints)
   481      for context, pvalue_, hint in _ZipPValues().visit(pvalueish, root_hint):
   482        if isinstance(pvalue_, DoOutputsTuple):
   483          continue
   484        if pvalue_.element_type is None:
   485          # TODO(robertwb): It's a bug that we ever get here. (typecheck)
   486          continue
   487        if hint and not typehints.is_consistent_with(pvalue_.element_type, hint):
   488          at_context = ' %s %s' % (input_or_output, context) if context else ''
   489          raise TypeCheckError(
   490              '{type} type hint violation at {label}{context}: expected {hint}, '
   491              'got {actual_type}\nFull type hint:\n{debug_str}'.format(
   492                  type=input_or_output.title(),
   493                  label=self.label,
   494                  context=at_context,
   495                  hint=hint,
   496                  actual_type=pvalue_.element_type,
   497                  debug_str=type_hints.debug_str()))
   498  
   499    def _infer_output_coder(self, input_type=None, input_coder=None):
   500      # type: (...) -> Optional[coders.Coder]
   501  
   502      """Returns the output coder to use for output of this transform.
   503  
   504      The Coder returned here should not be wrapped in a WindowedValueCoder
   505      wrapper.
   506  
   507      Args:
   508        input_type: An instance of an allowed built-in type, a custom class, or a
   509          typehints.TypeConstraint for the input type, or None if not available.
   510        input_coder: Coder object for encoding input to this PTransform, or None
   511          if not available.
   512  
   513      Returns:
   514        Coder object for encoding output of this PTransform or None if unknown.
   515      """
   516      # TODO(ccy): further refine this API.
   517      return None
   518  
   519    def _clone(self, new_label):
   520      """Clones the current transform instance under a new label."""
   521      transform = copy.copy(self)
   522      transform.label = new_label
   523      return transform
   524  
   525    def expand(self, input_or_inputs: InputT) -> OutputT:
   526      raise NotImplementedError
   527  
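           # A minimal subclass sketch (not part of the original module):
           # expand() typically composes other transforms on its input;
           # beam.Map is assumed to be available via `import apache_beam
           # as beam`:
           #
           #   class TimesTen(PTransform):
           #     def expand(self, pcoll):
           #       return pcoll | beam.Map(lambda x: x * 10)
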
   528    def __str__(self):
   529      return '<%s>' % self._str_internal()
   530  
   531    def __repr__(self):
   532      return '<%s at %s>' % (self._str_internal(), hex(id(self)))
   533  
   534    def _str_internal(self):
   535      return '%s(PTransform)%s%s%s' % (
   536          self.__class__.__name__,
   537          ' label=[%s]' % self.label if
   538          (hasattr(self, 'label') and self.label) else '',
   539          ' inputs=%s' % str(self.inputs) if
   540          (hasattr(self, 'inputs') and self.inputs) else '',
   541          ' side_inputs=%s' % str(self.side_inputs) if self.side_inputs else '')
   542  
   543    def _check_pcollection(self, pcoll):
   544      # type: (pvalue.PCollection) -> None
   545      if not isinstance(pcoll, pvalue.PCollection):
   546        raise error.TransformError('Expecting a PCollection argument.')
   547      if not pcoll.pipeline:
   548        raise error.TransformError('PCollection not part of a pipeline.')
   549  
   550    def get_windowing(self, inputs):
   551      # type: (Any) -> Windowing
   552  
   553      """Returns the window function to be associated with transform's output.
   554  
   555      By default most transforms just return the windowing function associated
   556      with the input PCollection (or the first input if several).
   557      """
   558      if inputs:
   559        return inputs[0].windowing
   560      else:
   561        from apache_beam.transforms.core import Windowing
   562        from apache_beam.transforms.window import GlobalWindows
   563        # TODO(robertwb): Return something compatible with every windowing?
   564        return Windowing(GlobalWindows())
   565  
   566    def __rrshift__(self, label):
   567      return _NamedPTransform(self, label)
   568  
   569    def __or__(self, right):
   570      """Used to compose PTransforms, e.g., ptransform1 | ptransform2."""
   571      if isinstance(right, PTransform):
   572        return _ChainedPTransform(self, right)
   573      return NotImplemented
   574  
   575    def __ror__(self, left, label=None):
   576      """Used to apply this PTransform to non-PValues, e.g., a tuple."""
   577      pvalueish, pvalues = self._extract_input_pvalues(left)
   578      if isinstance(pvalues, dict):
   579        pvalues = tuple(pvalues.values())
   580      pipelines = [v.pipeline for v in pvalues if isinstance(v, pvalue.PValue)]
   581      if pvalues and not pipelines:
   582        deferred = False
   583        # pylint: disable=wrong-import-order, wrong-import-position
   584        from apache_beam import pipeline
   585        from apache_beam.options.pipeline_options import PipelineOptions
   586        # pylint: enable=wrong-import-order, wrong-import-position
   587        p = pipeline.Pipeline('DirectRunner', PipelineOptions(sys.argv))
   588      else:
   589        if not pipelines:
   590          if self.pipeline is not None:
   591            p = self.pipeline
   592          else:
   593            raise ValueError(
   594                '"%s" requires a pipeline to be specified '
   595                'as there are no deferred inputs.' % self.label)
   596        else:
   597          p = self.pipeline or pipelines[0]
   598          for pp in pipelines:
   599            if p != pp:
   600              raise ValueError(
   601                  'Mixing values in different pipelines is not allowed.'
   602                  '\n{%r} != {%r}' % (p, pp))
   603        deferred = not getattr(p.runner, 'is_eager', False)
   604      # pylint: disable=wrong-import-order, wrong-import-position
   605      from apache_beam.transforms.core import Create
   606      # pylint: enable=wrong-import-order, wrong-import-position
   607      replacements = {
   608          id(v): p | 'CreatePInput%s' % ix >> Create(v, reshuffle=False)
   609          for (ix, v) in enumerate(pvalues)
   610          if not isinstance(v, pvalue.PValue) and v is not None
   611      }
   612      pvalueish = _SetInputPValues().visit(pvalueish, replacements)
   613      self.pipeline = p
   614      result = p.apply(self, pvalueish, label)
   615      if deferred:
   616        return result
   617      _allocate_materialized_pipeline(p)
   618      materialized_result = _AddMaterializationTransforms().visit(result)
   619      p.run().wait_until_finish()
   620      _release_materialized_pipeline(p)
   621      return _FinalizeMaterialization().visit(materialized_result)
   622  
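           # Illustrative sketch of the eager path above (not part of the
           # original module): applying a transform to plain Python values
           # builds a throwaway DirectRunner pipeline, runs it to completion,
           # and returns the materialized elements:
           #
           #   import apache_beam as beam
           #   [1, 2, 3] | beam.Map(lambda x: x * 10)  # -> [10, 20, 30]
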
   623    def _extract_input_pvalues(self, pvalueish):
   624      """Extract all the pvalues contained in the input pvalueish.
   625  
    626      Returns pvalueish as well as the flattened inputs, since the input may
    627      have to be copied because inspection may be destructive.
   628  
   629      By default, recursively extracts tuple components and dict values.
   630  
    631      Generally only needs to be overridden for multi-input PTransforms.
   632      """
   633      # pylint: disable=wrong-import-order
   634      from apache_beam import pipeline
   635      # pylint: enable=wrong-import-order
   636      if isinstance(pvalueish, pipeline.Pipeline):
   637        pvalueish = pvalue.PBegin(pvalueish)
   638  
   639      return pvalueish, {
   640          str(tag): value
   641          for (tag, value) in get_named_nested_pvalues(
   642              pvalueish, as_inputs=True)
   643      }
   644  
   645    def _pvaluish_from_dict(self, input_dict):
   646      if len(input_dict) == 1:
   647        return next(iter(input_dict.values()))
   648      else:
   649        return input_dict
   650  
   651    def _named_inputs(self, main_inputs, side_inputs):
   652      # type: (Mapping[str, pvalue.PValue], Sequence[Any]) -> Dict[str, pvalue.PValue]
   653  
   654      """Returns the dictionary of named inputs (including side inputs) as they
   655      should be named in the beam proto.
   656      """
   657      main_inputs = {
   658          tag: input
   659          for (tag, input) in main_inputs.items()
   660          if isinstance(input, pvalue.PCollection)
   661      }
   662      named_side_inputs = {(SIDE_INPUT_PREFIX + '%s') % ix: si.pvalue
   663                           for (ix, si) in enumerate(side_inputs)}
   664      return dict(main_inputs, **named_side_inputs)
   665  
   666    def _named_outputs(self, outputs):
   667      # type: (Dict[object, pvalue.PCollection]) -> Dict[str, pvalue.PCollection]
   668  
   669      """Returns the dictionary of named outputs as they should be named in the
   670      beam proto.
   671      """
   672      # TODO(BEAM-1833): Push names up into the sdk construction.
   673      return {
   674          str(tag): output
   675          for (tag, output) in outputs.items()
   676          if isinstance(output, pvalue.PCollection)
   677      }
   678  
   679    _known_urns = {}  # type: Dict[str, Tuple[Optional[type], ConstructorFn]]
   680  
   681    @classmethod
   682    @overload
   683    def register_urn(
   684        cls,
   685        urn,  # type: str
   686        parameter_type,  # type: Type[T]
   687    ):
   688      # type: (...) -> Callable[[Union[type, Callable[[beam_runner_api_pb2.PTransform, T, PipelineContext], Any]]], Callable[[T, PipelineContext], Any]]
   689      pass
   690  
   691    @classmethod
   692    @overload
   693    def register_urn(
   694        cls,
   695        urn,  # type: str
   696        parameter_type,  # type: None
   697    ):
   698      # type: (...) -> Callable[[Union[type, Callable[[beam_runner_api_pb2.PTransform, bytes, PipelineContext], Any]]], Callable[[bytes, PipelineContext], Any]]
   699      pass
   700  
   701    @classmethod
   702    @overload
   703    def register_urn(cls,
   704                     urn,  # type: str
   705                     parameter_type,  # type: Type[T]
   706                     constructor  # type: Callable[[beam_runner_api_pb2.PTransform, T, PipelineContext], Any]
   707                    ):
   708      # type: (...) -> None
   709      pass
   710  
   711    @classmethod
   712    @overload
   713    def register_urn(cls,
   714                     urn,  # type: str
   715                     parameter_type,  # type: None
   716                     constructor  # type: Callable[[beam_runner_api_pb2.PTransform, bytes, PipelineContext], Any]
   717                    ):
   718      # type: (...) -> None
   719      pass
   720  
   721    @classmethod
   722    def register_urn(cls, urn, parameter_type, constructor=None):
   723      def register(constructor):
   724        if isinstance(constructor, type):
   725          constructor.from_runner_api_parameter = register(
   726              constructor.from_runner_api_parameter)
   727        else:
   728          cls._known_urns[urn] = parameter_type, constructor
   729        return constructor
   730  
   731      if constructor:
   732        # Used as a statement.
   733        register(constructor)
   734      else:
   735        # Used as a decorator.
   736        return register
   737  
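           # Illustrative decorator usage (a sketch, not part of the original
           # module); the URN string and MyTransform are hypothetical:
           #
           #   @PTransform.register_urn('example:my_transform:v1', None)
           #   def _my_transform_from_proto(unused_ptransform, payload,
           #                                unused_context):
           #     return MyTransform.from_payload(payload)
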
   738    def to_runner_api(self, context, has_parts=False, **extra_kwargs):
   739      # type: (PipelineContext, bool, Any) -> beam_runner_api_pb2.FunctionSpec
   740      from apache_beam.portability.api import beam_runner_api_pb2
   741      # typing: only ParDo supports extra_kwargs
   742      urn, typed_param = self.to_runner_api_parameter(context, **extra_kwargs)  # type: ignore[call-arg]
   743      if urn == python_urns.GENERIC_COMPOSITE_TRANSFORM and not has_parts:
   744        # TODO(https://github.com/apache/beam/issues/18713): Remove this fallback.
   745        urn, typed_param = self.to_runner_api_pickled(context)
   746      return beam_runner_api_pb2.FunctionSpec(
   747          urn=urn,
   748          payload=typed_param.SerializeToString() if isinstance(
   749              typed_param, message.Message) else typed_param.encode('utf-8')
   750          if isinstance(typed_param, str) else typed_param)
   751  
   752    @classmethod
   753    def from_runner_api(cls,
   754                        proto,  # type: Optional[beam_runner_api_pb2.PTransform]
   755                        context  # type: PipelineContext
   756                       ):
   757      # type: (...) -> Optional[PTransform]
   758      if proto is None or proto.spec is None or not proto.spec.urn:
   759        return None
   760      parameter_type, constructor = cls._known_urns[proto.spec.urn]
   761  
   762      return constructor(
   763          proto,
   764          proto_utils.parse_Bytes(proto.spec.payload, parameter_type),
   765          context)
   766  
   767    def to_runner_api_parameter(
   768        self,
   769        unused_context  # type: PipelineContext
   770    ):
   771      # type: (...) -> Tuple[str, Optional[Union[message.Message, bytes, str]]]
   772      # The payload here is just to ease debugging.
   773      return (
   774          python_urns.GENERIC_COMPOSITE_TRANSFORM,
   775          getattr(self, '_fn_api_payload', str(self)))
   776  
   777    def to_runner_api_pickled(self, unused_context):
   778      # type: (PipelineContext) -> Tuple[str, bytes]
   779      return (python_urns.PICKLED_TRANSFORM, pickler.dumps(self))
   780  
   781    def runner_api_requires_keyed_input(self):
   782      return False
   783  
   784    def _add_type_constraint_from_consumer(self, full_label, input_type_hints):
   785      # type: (str, Tuple[str, Any]) -> None
   786  
   787      """Adds a consumer transform's input type hints to our output type
    788      constraints, which are used during performance runtime type-checking.
   789      """
   790      pass
   791  
   792  
   793  @PTransform.register_urn(python_urns.GENERIC_COMPOSITE_TRANSFORM, None)
   794  def _create_transform(unused_ptransform, payload, unused_context):
   795    empty_transform = PTransform()
   796    empty_transform._fn_api_payload = payload
   797    return empty_transform
   798  
   799  
   800  @PTransform.register_urn(python_urns.PICKLED_TRANSFORM, None)
   801  def _unpickle_transform(unused_ptransform, pickled_bytes, unused_context):
   802    return pickler.loads(pickled_bytes)
   803  
   804  
   805  class _ChainedPTransform(PTransform):
   806    def __init__(self, *parts):
   807      # type: (*PTransform) -> None
   808      super().__init__(label=self._chain_label(parts))
   809      self._parts = parts
   810  
   811    def _chain_label(self, parts):
   812      return '|'.join(p.label for p in parts)
   813  
   814    def __or__(self, right):
   815      if isinstance(right, PTransform):
   816        # Create a flat list rather than a nested tree of composite
   817        # transforms for better monitoring, etc.
   818        return _ChainedPTransform(*(self._parts + (right, )))
   819      return NotImplemented
   820  
   821    def expand(self, pval):
   822      return reduce(operator.or_, self._parts, pval)
   823  
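         # Illustrative sketch (not part of the original module): composing
         # transforms with `|` before application yields a _ChainedPTransform
         # whose label joins the part labels with '|'; beam and pcoll are
         # assumed:
         #
         #   composite = beam.Map(lambda x: x + 1) | beam.Filter(lambda x: x > 1)
         #   result = pcoll | composite  # same as pcoll | Map(...) | Filter(...)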
   824  
   825  class PTransformWithSideInputs(PTransform):
   826    """A superclass for any :class:`PTransform` (e.g.
   827    :func:`~apache_beam.transforms.core.FlatMap` or
   828    :class:`~apache_beam.transforms.core.CombineFn`)
   829    invoking user code.
   830  
   831    :class:`PTransform` s like :func:`~apache_beam.transforms.core.FlatMap`
   832    invoke user-supplied code in some kind of package (e.g. a
   833    :class:`~apache_beam.transforms.core.DoFn`) and optionally provide arguments
   834    and side inputs to that code. This internal-use-only class contains common
   835    functionality for :class:`PTransform` s that fit this model.
   836    """
   837    def __init__(self, fn, *args, **kwargs):
   838      # type: (WithTypeHints, *Any, **Any) -> None
   839      if isinstance(fn, type) and issubclass(fn, WithTypeHints):
   840        # Don't treat Fn class objects as callables.
   841        raise ValueError('Use %s() not %s.' % (fn.__name__, fn.__name__))
   842      self.fn = self.make_fn(fn, bool(args or kwargs))
    843      # Now that we can figure out the label, initialize the super-class.
   844      super().__init__()
   845  
   846      if (any(isinstance(v, pvalue.PCollection) for v in args) or
   847          any(isinstance(v, pvalue.PCollection) for v in kwargs.values())):
   848        raise error.SideInputError(
   849            'PCollection used directly as side input argument. Specify '
   850            'AsIter(pcollection) or AsSingleton(pcollection) to indicate how the '
   851            'PCollection is to be used.')
   852      self.args, self.kwargs, self.side_inputs = util.remove_objects_from_args(
   853          args, kwargs, pvalue.AsSideInput)
   854      self.raw_side_inputs = args, kwargs
   855  
   856      # Prevent name collisions with fns of the form '<function <lambda> at ...>'
   857      self._cached_fn = self.fn
   858  
   859      # Ensure fn and side inputs are picklable for remote execution.
   860      try:
   861        self.fn = pickler.loads(pickler.dumps(self.fn))
   862      except RuntimeError as e:
   863        raise RuntimeError('Unable to pickle fn %s: %s' % (self.fn, e))
   864  
   865      self.args = pickler.loads(pickler.dumps(self.args))
   866      self.kwargs = pickler.loads(pickler.dumps(self.kwargs))
   867  
   868      # For type hints, because loads(dumps(class)) != class.
   869      self.fn = self._cached_fn
   870  
   871    def with_input_types(
   872        self, input_type_hint, *side_inputs_arg_hints, **side_input_kwarg_hints):
   873      """Annotates the types of main inputs and side inputs for the PTransform.
   874  
   875      Args:
   876        input_type_hint: An instance of an allowed built-in type, a custom class,
   877          or an instance of a typehints.TypeConstraint.
    878        *side_inputs_arg_hints: A variable length argument composed
    879          of an allowed built-in type, a custom class, or a
    880          typehints.TypeConstraint.
    881        **side_input_kwarg_hints: A dictionary argument composed
    882          of an allowed built-in type, a custom class, or a
    883          typehints.TypeConstraint.
   884  
   885      Example of annotating the types of side-inputs::
   886  
   887        FlatMap().with_input_types(int, int, bool)
   888  
   889      Raises:
    890      :class:`TypeError`: If any of the provided type-hints are not valid.
   891          See
   892          :func:`~apache_beam.typehints.typehints.validate_composite_type_param`
   893          for further details.
   894  
   895      Returns:
   896        :class:`PTransform`: A reference to the instance of this particular
   897        :class:`PTransform` object. This allows chaining type-hinting related
   898        methods.
   899      """
   900      super().with_input_types(input_type_hint)
   901  
   902      side_inputs_arg_hints = native_type_compatibility.convert_to_beam_types(
   903          side_inputs_arg_hints)
   904      side_input_kwarg_hints = native_type_compatibility.convert_to_beam_types(
   905          side_input_kwarg_hints)
   906  
   907      for si in side_inputs_arg_hints:
   908        validate_composite_type_param(si, 'Type hints for a PTransform')
   909      for si in side_input_kwarg_hints.values():
   910        validate_composite_type_param(si, 'Type hints for a PTransform')
   911  
   912      self.side_inputs_types = side_inputs_arg_hints
   913      return WithTypeHints.with_input_types(
   914          self, input_type_hint, *side_inputs_arg_hints, **side_input_kwarg_hints)
   915  
   916    def type_check_inputs(self, pvalueish):
   917      type_hints = self.get_type_hints()
   918      input_types = type_hints.input_types
   919      if input_types:
   920        args, kwargs = self.raw_side_inputs
   921  
   922        def element_type(side_input):
   923          if isinstance(side_input, pvalue.AsSideInput):
   924            return side_input.element_type
   925          return instance_to_type(side_input)
   926  
   927        arg_types = [pvalueish.element_type] + [element_type(v) for v in args]
   928        kwargs_types = {k: element_type(v) for (k, v) in kwargs.items()}
   929        argspec_fn = self._process_argspec_fn()
   930        bindings = getcallargs_forhints(argspec_fn, *arg_types, **kwargs_types)
   931        hints = getcallargs_forhints(
   932            argspec_fn, *input_types[0], **input_types[1])
   933        for arg, hint in hints.items():
   934          if arg.startswith('__unknown__'):
   935            continue
   936          if hint is None:
   937            continue
   938          if not typehints.is_consistent_with(bindings.get(arg, typehints.Any),
   939                                              hint):
   940            raise TypeCheckError(
   941                'Type hint violation for \'{label}\': requires {hint} but got '
   942                '{actual_type} for {arg}\nFull type hint:\n{debug_str}'.format(
   943                    label=self.label,
   944                    hint=hint,
   945                    actual_type=bindings[arg],
   946                    arg=arg,
   947                    debug_str=type_hints.debug_str()))
   948  
   949    def _process_argspec_fn(self):
   950      """Returns an argspec of the function actually consuming the data.
   951      """
   952      raise NotImplementedError
   953  
   954    def make_fn(self, fn, has_side_inputs):
    955      # TODO(silviuc): Add comment describing that this is meant to be overridden
   956      # by methods detecting callables and wrapping them in DoFns.
   957      return fn
   958  
   959    def default_label(self):
   960      return '%s(%s)' % (self.__class__.__name__, self.fn.default_label())
   961  
   962  
   963  class _PTransformFnPTransform(PTransform):
   964    """A class wrapper for a function-based transform."""
   965    def __init__(self, fn, *args, **kwargs):
   966      super().__init__()
   967      self._fn = fn
   968      self._args = args
   969      self._kwargs = kwargs
   970  
   971    def display_data(self):
   972      res = {
   973          'fn': (
   974              self._fn.__name__
   975              if hasattr(self._fn, '__name__') else self._fn.__class__),
   976          'args': DisplayDataItem(str(self._args)).drop_if_default('()'),
   977          'kwargs': DisplayDataItem(str(self._kwargs)).drop_if_default('{}')
   978      }
   979      return res
   980  
   981    def expand(self, pcoll):
   982      # Since the PTransform will be implemented entirely as a function
   983      # (once called), we need to pass through any type-hinting information that
   984      # may have been annotated via the .with_input_types() and
   985      # .with_output_types() methods.
   986      kwargs = dict(self._kwargs)
   987      args = tuple(self._args)
   988  
   989      # TODO(BEAM-5878) Support keyword-only arguments.
   990      try:
   991        if 'type_hints' in get_signature(self._fn).parameters:
   992          args = (self.get_type_hints(), ) + args
   993      except TypeError:
   994        # Might not be a function.
   995        pass
   996      return self._fn(pcoll, *args, **kwargs)
   997  
   998    def default_label(self):
   999      if self._args:
  1000        return '%s(%s)' % (
  1001            label_from_callable(self._fn), label_from_callable(self._args[0]))
  1002      return label_from_callable(self._fn)
  1003  
  1004  
  1005  def ptransform_fn(fn):
  1006    # type: (Callable) -> Callable[..., _PTransformFnPTransform]
  1007  
  1008    """A decorator for a function-based PTransform.
  1009  
  1010    Args:
  1011      fn: A function implementing a custom PTransform.
  1012  
  1013    Returns:
   1014      A _PTransformFnPTransform instance wrapping the function-based PTransform.
  1015  
  1016    This wrapper provides an alternative, simpler way to define a PTransform.
  1017    The standard method is to subclass from PTransform and override the expand()
  1018    method. An equivalent effect can be obtained by defining a function that
  1019    accepts an input PCollection and additional optional arguments and returns a
  1020    resulting PCollection. For example::
  1021  
  1022      @ptransform_fn
  1023      @beam.typehints.with_input_types(..)
  1024      @beam.typehints.with_output_types(..)
  1025      def CustomMapper(pcoll, mapfn):
  1026        return pcoll | ParDo(mapfn)
  1027  
  1028    The equivalent approach using PTransform subclassing::
  1029  
  1030      @beam.typehints.with_input_types(..)
  1031      @beam.typehints.with_output_types(..)
  1032      class CustomMapper(PTransform):
  1033  
  1034        def __init__(self, mapfn):
  1035          super().__init__()
  1036          self.mapfn = mapfn
  1037  
  1038        def expand(self, pcoll):
  1039          return pcoll | ParDo(self.mapfn)
  1040  
  1041    With either method the custom PTransform can be used in pipelines as if
  1042    it were one of the "native" PTransforms::
  1043  
  1044      result_pcoll = input_pcoll | 'Label' >> CustomMapper(somefn)
  1045  
  1046    Note that for both solutions the underlying implementation of the pipe
  1047    operator (i.e., `|`) will inject the pcoll argument in its proper place
  1048    (first argument if no label was specified and second argument otherwise).
  1049  
  1050    Type hint support needs to be enabled via the
  1051    --type_check_additional=ptransform_fn flag in Beam 2.
  1052    If CustomMapper is a Cython function, you can still specify input and output
  1053    types provided the decorators appear before @ptransform_fn.
  1054    """
  1055    # TODO(robertwb): Consider removing staticmethod to allow for self parameter.
  1056    @wraps(fn)
  1057    def callable_ptransform_factory(*args, **kwargs):
  1058      res = _PTransformFnPTransform(fn, *args, **kwargs)
  1059      if ptransform_fn_typehints_enabled:
   1060        # Apply type hints declared before or after the ptransform_fn
   1061        # decorator, falling back on PTransform defaults.
   1062        # If the @with_{input,output}_types decorator comes before ptransform_fn,
   1063        # the type hints get applied to this function. If it comes after, they
   1064        # will get applied to fn, and @wraps will copy the _type_hints attribute to
  1065        # this function.
  1066        type_hints = get_type_hints(callable_ptransform_factory)
  1067        res._set_type_hints(type_hints.with_defaults(res.get_type_hints()))
  1068        _LOGGER.debug(
  1069            'type hints for %s: %s', res.default_label(), res.get_type_hints())
  1070      return res
  1071  
  1072    return callable_ptransform_factory
  1073  
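         # Illustrative sketch (not part of the original module): per the
         # docstring above, type hints on @ptransform_fn transforms are only
         # checked when the pipeline enables the additional type-check flag:
         #
         #   python my_pipeline.py --type_check_additional=ptransform_fn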
  1074  
  1075  def label_from_callable(fn):
  1076    if hasattr(fn, 'default_label'):
  1077      return fn.default_label()
  1078    elif hasattr(fn, '__name__'):
  1079      if fn.__name__ == '<lambda>':
  1080        return '<lambda at %s:%s>' % (
  1081            os.path.basename(fn.__code__.co_filename), fn.__code__.co_firstlineno)
  1082      return fn.__name__
  1083    return str(fn)
  1084  
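         # Illustrative behavior sketch (not part of the original module);
         # the file name and line number shown are hypothetical:
         #
         #   label_from_callable(len)          # -> 'len'
         #   label_from_callable(lambda x: x)  # -> '<lambda at my_module.py:7>'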
  1085  
  1086  class _NamedPTransform(PTransform):
  1087    def __init__(self, transform, label):
  1088      super().__init__(label)
  1089      self.transform = transform
  1090  
  1091    def __ror__(self, pvalueish, _unused=None):
  1092      return self.transform.__ror__(pvalueish, self.label)
  1093  
  1094    def expand(self, pvalue):
  1095      raise RuntimeError("Should never be expanded directly.")
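

         # Illustrative sketch (not part of the original module):
         # `'Label' >> transform` resolves via PTransform.__rrshift__ to a
         # _NamedPTransform, so the label is applied when the wrapped
         # transform is finally applied to a PCollection; beam and pcoll
         # are assumed:
         #
         #   result = pcoll | 'TimesTen' >> beam.Map(lambda x: x * 10)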