github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/core.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Core PTransform subclasses, such as FlatMap, GroupByKey, and Map."""
    19  
    20  # pytype: skip-file
    21  
    22  import concurrent.futures
    23  import copy
    24  import inspect
    25  import logging
    26  import random
    27  import sys
    28  import time
    29  import traceback
    30  import types
    31  import typing
    32  from itertools import dropwhile
    33  
    34  from apache_beam import coders
    35  from apache_beam import pvalue
    36  from apache_beam import typehints
    37  from apache_beam.coders import typecoders
    38  from apache_beam.internal import pickler
    39  from apache_beam.internal import util
    40  from apache_beam.options.pipeline_options import TypeOptions
    41  from apache_beam.portability import common_urns
    42  from apache_beam.portability import python_urns
    43  from apache_beam.portability.api import beam_runner_api_pb2
    44  from apache_beam.transforms import ptransform
    45  from apache_beam.transforms import userstate
    46  from apache_beam.transforms.display import DisplayDataItem
    47  from apache_beam.transforms.display import HasDisplayData
    48  from apache_beam.transforms.ptransform import PTransform
    49  from apache_beam.transforms.ptransform import PTransformWithSideInputs
    50  from apache_beam.transforms.sideinputs import SIDE_INPUT_PREFIX
    51  from apache_beam.transforms.sideinputs import get_sideinput_index
    52  from apache_beam.transforms.userstate import StateSpec
    53  from apache_beam.transforms.userstate import TimerSpec
    54  from apache_beam.transforms.window import GlobalWindows
    55  from apache_beam.transforms.window import SlidingWindows
    56  from apache_beam.transforms.window import TimestampCombiner
    57  from apache_beam.transforms.window import TimestampedValue
    58  from apache_beam.transforms.window import WindowedValue
    59  from apache_beam.transforms.window import WindowFn
    60  from apache_beam.typehints import row_type
    61  from apache_beam.typehints import trivial_inference
    62  from apache_beam.typehints.batch import BatchConverter
    63  from apache_beam.typehints.decorators import TypeCheckError
    64  from apache_beam.typehints.decorators import WithTypeHints
    65  from apache_beam.typehints.decorators import get_signature
    66  from apache_beam.typehints.decorators import get_type_hints
    67  from apache_beam.typehints.decorators import with_input_types
    68  from apache_beam.typehints.decorators import with_output_types
    69  from apache_beam.typehints.trivial_inference import element_type
    70  from apache_beam.typehints.typehints import TypeConstraint
    71  from apache_beam.typehints.typehints import is_consistent_with
    72  from apache_beam.typehints.typehints import visit_inner_types
    73  from apache_beam.utils import urns
    74  from apache_beam.utils.timestamp import Duration
    75  
    76  if typing.TYPE_CHECKING:
    77    from google.protobuf import message  # pylint: disable=ungrouped-imports
    78    from apache_beam.io import iobase
    79    from apache_beam.pipeline import Pipeline
    80    from apache_beam.runners.pipeline_context import PipelineContext
    81    from apache_beam.transforms import create_source
    82    from apache_beam.transforms.trigger import AccumulationMode
    83    from apache_beam.transforms.trigger import DefaultTrigger
    84    from apache_beam.transforms.trigger import TriggerFn
    85  
    86  __all__ = [
    87      'DoFn',
    88      'CombineFn',
    89      'PartitionFn',
    90      'ParDo',
    91      'FlatMap',
    92      'FlatMapTuple',
    93      'Map',
    94      'MapTuple',
    95      'Filter',
    96      'CombineGlobally',
    97      'CombinePerKey',
    98      'CombineValues',
    99      'GroupBy',
   100      'GroupByKey',
   101      'Select',
   102      'Partition',
   103      'Windowing',
   104      'WindowInto',
   105      'Flatten',
   106      'Create',
   107      'Impulse',
   108      'RestrictionProvider',
   109      'WatermarkEstimatorProvider',
   110  ]
   111  
   112  # Type variables
   113  T = typing.TypeVar('T')
   114  K = typing.TypeVar('K')
   115  V = typing.TypeVar('V')
   116  
   117  _LOGGER = logging.getLogger(__name__)
   118  
   119  
   120  class DoFnContext(object):
   121    """A context available to all methods of DoFn instance."""
   122    pass
   123  
   124  
   125  class DoFnProcessContext(DoFnContext):
   126    """A processing context passed to DoFn process() during execution.
   127  
   128    Most importantly, a DoFn.process method will access context.element
   129    to get the element it is supposed to process.
   130  
   131    Attributes:
   132      label: label of the ParDo whose element is being processed.
   133      element: element being processed
   134        (in process method only; always None in start_bundle and finish_bundle)
   135      timestamp: timestamp of the element
   136        (in process method only; always None in start_bundle and finish_bundle)
   137      windows: windows of the element
   138        (in process method only; always None in start_bundle and finish_bundle)
   139      state: a DoFnState object, which holds the runner's internal state
   140        for this element.
   141        Not used by the pipeline code.
   142    """
   143    def __init__(self, label, element=None, state=None):
   144      """Initialize a processing context object with an element and state.
   145  
   146      The element represents one value from a PCollection that will be accessed
   147      by a DoFn object during pipeline execution, and state is an arbitrary object
   148      where counters and other pipeline state information can be passed in.
   149  
   150      DoFnProcessContext objects are also used as inputs to PartitionFn instances.
   151  
   152      Args:
   153        label: label of the PCollection whose element is being processed.
   154        element: element of a PCollection being processed using this context.
   155        state: a DoFnState object with state to be passed in to the DoFn object.
   156      """
   157      self.label = label
   158      self.state = state
   159      if element is not None:
   160        self.set_element(element)
   161  
   162    def set_element(self, windowed_value):
   163      if windowed_value is None:
   164        # Not currently processing an element.
   165        if hasattr(self, 'element'):
   166          del self.element
   167          del self.timestamp
   168          del self.windows
   169      else:
   170        self.element = windowed_value.value
   171        self.timestamp = windowed_value.timestamp
   172        self.windows = windowed_value.windows
   173  
   174  
   175  class ProcessContinuation(object):
   176    """An object that may be produced as the last element of a process method
   177      invocation.
   178  
   179    If produced, indicates that there is more work to be done for the current
   180    input element.
   181    """
   182    def __init__(self, resume_delay=0):
   183      """Initializes a ProcessContinuation object.
   184  
   185      Args:
    186        resume_delay: indicates the minimum time, in seconds, that should elapse
    187          before the process() method is re-invoked to resume processing the
    188          current element.
   189      """
   190      self.resume_delay = resume_delay
   191  
   192    @staticmethod
   193    def resume(resume_delay=0):
    194      """A convenience method that produces a ``ProcessContinuation``.
    195  
    196      Args:
    197        resume_delay: delay after which processing of the current element should
    198          be resumed.
    199      Returns: a ``ProcessContinuation`` signalling the runner that the current
    200        input element has not been fully processed and should be resumed later.
   201      """
   202      return ProcessContinuation(resume_delay=resume_delay)
   203  
   204  
   205  class RestrictionProvider(object):
   206    """Provides methods for generating and manipulating restrictions.
   207  
   208    This class should be implemented to support Splittable ``DoFn`` in Python
   209    SDK. See https://s.apache.org/splittable-do-fn for more details about
    210    This class should be implemented to support Splittable ``DoFn`` in the
    211    Python SDK. See https://s.apache.org/splittable-do-fn for more details about
    212    Splittable ``DoFn``.
    213  
    214    To mark a ``DoFn`` class as a Splittable ``DoFn``, the ``DoFn.process()``
    215    method of that class should have exactly one parameter whose default value is
    216    an instance of ``RestrictionParam``. This ``RestrictionParam`` can either be
    217    constructed with an explicit ``RestrictionProvider``, or, if no
   218  
   219    The provided ``RestrictionProvider`` instance must provide suitable overrides
   220    for the following methods:
   221    * create_tracker()
   222    * initial_restriction()
   223    * restriction_size()
   224  
   225    Optionally, ``RestrictionProvider`` may override default implementations of
   226    following methods:
   227    * restriction_coder()
   228    * split()
   229    * split_and_size()
   230    * truncate()
   231  
   232    ** Pausing and resuming processing of an element **
   233  
   234    As the last element produced by the iterator returned by the
   235    ``DoFn.process()`` method, a Splittable ``DoFn`` may return an object of type
   236    ``ProcessContinuation``.
   237  
    238    If ``restriction_tracker.defer_remainder`` is called in ``DoFn.process()``, it
    239    means that the runner should later re-invoke the ``DoFn.process()`` method to
    240    resume processing the current element, and conveys the manner in which that
    241    re-invocation should be performed.
   242  
   243    ** Updating output watermark **
   244  
    245    The ``DoFn.process()`` method of a Splittable ``DoFn`` may contain a parameter
    246    with default value ``DoFn.WatermarkEstimatorParam``. If specified, this asks
    247    the runner to provide a function that can be used to give the runner a
    248    (best-effort) lower bound on the timestamps of future output associated
    249    with the current element processed by the ``DoFn``. If the ``DoFn`` has
    250    multiple outputs, the watermark applies to all of them. The provided function
    251    must be invoked with a single parameter of type ``Timestamp`` or an integer
    252    that gives the watermark in number of seconds.
   253    """
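           # Example (illustrative sketch, not part of the original source). A minimal
           # splittable DoFn pairing a RestrictionProvider with DoFn.RestrictionParam,
           # assuming OffsetRange and OffsetRestrictionTracker from
           # apache_beam.io.restriction_trackers:
           #
           #   class CountProvider(RestrictionProvider):
           #     def initial_restriction(self, element):
           #       return OffsetRange(0, element)
           #
           #     def create_tracker(self, restriction):
           #       return OffsetRestrictionTracker(restriction)
           #
           #     def restriction_size(self, element, restriction):
           #       return restriction.size()
           #
           #   class CountFn(DoFn):
           #     def process(
           #         self, element, tracker=DoFn.RestrictionParam(CountProvider())):
           #       position = tracker.current_restriction().start
           #       while tracker.try_claim(position):
           #         yield position
           #         position += 1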
   254    def create_tracker(self, restriction):
   255      # type: (...) -> iobase.RestrictionTracker
   256  
   257      """Produces a new ``RestrictionTracker`` for the given restriction.
   258  
   259      This API is required to be implemented.
   260  
   261      Args:
   262        restriction: an object that defines a restriction as identified by a
   263          Splittable ``DoFn`` that utilizes the current ``RestrictionProvider``.
   264          For example, a tuple that gives a range of positions for a Splittable
   265          ``DoFn`` that reads files based on byte positions.
   266      Returns: an object of type ``RestrictionTracker``.
   267      """
   268      raise NotImplementedError
   269  
   270    def initial_restriction(self, element):
   271      """Produces an initial restriction for the given element.
   272  
   273      This API is required to be implemented.
   274      """
   275      raise NotImplementedError
   276  
   277    def split(self, element, restriction):
   278      """Splits the given element and restriction initially.
   279  
   280      This method enables runners to perform bulk splitting initially allowing for
   281      a rapid increase in parallelism. Note that initial split is a different
   282      concept from the split during element processing time. Please refer to
   283      ``iobase.RestrictionTracker.try_split`` for details about splitting when the
   284      current element and restriction are actively being processed.
   285  
    286      Returns an iterator of restrictions. The total set of elements produced by
    287      reading the input element for each of the returned restrictions should be
    288      the same as the total set of elements produced by reading the input element
    289      for the input restriction.
    290  
    291      This API is optional if ``split_and_size`` has been implemented.
    292  
    293      If this method is not overridden, no initial splitting happens on each
    294      restriction.
   295  
   296      """
   297      yield restriction
   298  
   299    def restriction_coder(self):
   300      """Returns a ``Coder`` for restrictions.
   301  
    302      The returned ``Coder`` will be used for the restrictions produced by the
    303      current ``RestrictionProvider``.
   304  
   305      Returns:
   306        an object of type ``Coder``.
   307      """
   308      return coders.registry.get_coder(object)
   309  
   310    def restriction_size(self, element, restriction):
   311      """Returns the size of a restriction with respect to the given element.
   312  
   313      By default, asks a newly-created restriction tracker for the default size
   314      of the restriction.
   315  
   316      The return value must be non-negative.
   317  
   318      Must be thread safe. Will be invoked concurrently during bundle processing
   319      due to runner initiated splitting and progress estimation.
   320  
   321      This API is required to be implemented.
   322      """
   323      raise NotImplementedError
   324  
   325    def split_and_size(self, element, restriction):
   326      """Like split, but also does sizing, returning (restriction, size) pairs.
   327  
   328      For each pair, size must be non-negative.
   329  
   330      This API is optional if ``split`` and ``restriction_size`` have been
   331      implemented.
   332      """
   333      for part in self.split(element, restriction):
   334        yield part, self.restriction_size(element, part)
   335  
   336    def truncate(self, element, restriction):
   337      """Truncates the provided restriction into a restriction representing a
   338      finite amount of work when the pipeline is
   339      `draining <https://docs.google.com/document/d/1NExwHlj-2q2WUGhSO4jTu8XGhDPmm3cllSN8IMmWci8/edit#> for additional details about drain.>_`.  # pylint: disable=line-too-long
    340      `draining <https://docs.google.com/document/d/1NExwHlj-2q2WUGhSO4jTu8XGhDPmm3cllSN8IMmWci8/edit#>`_.  # pylint: disable=line-too-long
    341      By default, if the restriction is bounded then the restriction will be
    342      returned; otherwise None will be returned.
   343      This API is optional and should only be implemented if more granularity is
   344      required.
   345  
   346      Return a truncated finite restriction if further processing is required
   347      otherwise return None to represent that no further processing of this
   348      restriction is required.
   349  
   350      The default behavior when a pipeline is being drained is that bounded
   351      restrictions process entirely while unbounded restrictions process till a
   352      checkpoint is possible.
   353      """
   354      restriction_tracker = self.create_tracker(restriction)
   355      if restriction_tracker.is_bounded():
   356        return restriction
   357  
   358  
   359  def get_function_arguments(obj, func):
   360    # type: (...) -> typing.Tuple[typing.List[str], typing.List[typing.Any]]
   361  
    362    """Return the function arguments based on the name provided. If the object
    363    has an _inspect_<func> method attached to its class, use that; otherwise
    364    default to the modified version of the Python inspect library.
   365  
   366    Returns:
   367      Same as get_function_args_defaults.
   368    """
   369    func_name = '_inspect_%s' % func
   370    if hasattr(obj, func_name):
   371      f = getattr(obj, func_name)
   372      return f()
   373    f = getattr(obj, func)
   374    return get_function_args_defaults(f)
   375  
   376  
   377  def get_function_args_defaults(f):
   378    # type: (...) -> typing.Tuple[typing.List[str], typing.List[typing.Any]]
   379  
   380    """Returns the function arguments of a given function.
   381  
   382    Returns:
   383      (args: List[str], defaults: List[Any]). The first list names the
   384      arguments of the method and the second one has the values of the default
   385      arguments. This is similar to ``inspect.getfullargspec()``'s results, except
   386      it doesn't include bound arguments and may follow function wrappers.
   387    """
   388    signature = get_signature(f)
   389    parameter = inspect.Parameter
   390    # TODO(BEAM-5878) support kwonlyargs on Python 3.
   391    _SUPPORTED_ARG_TYPES = [
   392        parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD
   393    ]
   394    args = [
   395        name for name,
   396        p in signature.parameters.items() if p.kind in _SUPPORTED_ARG_TYPES
   397    ]
   398    defaults = [
   399        p.default for p in signature.parameters.values()
   400        if p.kind in _SUPPORTED_ARG_TYPES and p.default is not p.empty
   401    ]
   402  
   403    return args, defaults
   404  
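         # Example (illustrative, not part of the original source): for a function
         # such as
         #
         #   def fn(a, b=1, *args, c=2, **kwargs): ...
         #
         # get_function_args_defaults(fn) returns (['a', 'b'], [1]); keyword-only and
         # variadic parameters are excluded, per _SUPPORTED_ARG_TYPES above.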
   405  
   406  class WatermarkEstimatorProvider(object):
   407    """Provides methods for generating WatermarkEstimator.
   408  
    409    This class should be implemented if you want to provide output_watermark
    410    information within an SDF.
    411  
    412    In order to give SDF.process() access to a WatermarkEstimator, the
    413    SDF author should declare an argument whose default value is a
    414    DoFn.WatermarkEstimatorParam instance.  This DoFn.WatermarkEstimatorParam
   415    can either be constructed with an explicit WatermarkEstimatorProvider,
   416    or, if no WatermarkEstimatorProvider is provided, the DoFn itself must
   417    be a WatermarkEstimatorProvider.
   418    """
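           # Example (illustrative sketch, not part of the original source), assuming
           # ManualWatermarkEstimator from apache_beam.io.watermark_estimators:
           #
           #   class MyProvider(WatermarkEstimatorProvider):
           #     def initial_estimator_state(self, element, restriction):
           #       return None  # no watermark has been reported yet
           #
           #     def create_watermark_estimator(self, estimator_state):
           #       return ManualWatermarkEstimator(estimator_state)
           #
           #   class MySplittableDoFn(DoFn):
           #     def process(
           #         self,
           #         element,
           #         estimator=DoFn.WatermarkEstimatorParam(MyProvider())):
           #       # compute_output_timestamp() is hypothetical; it stands in for
           #       # whatever logic derives a lower bound on future output timestamps.
           #       estimator.set_watermark(compute_output_timestamp(element))
           #       ...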
   419    def initial_estimator_state(self, element, restriction):
    420      """Returns the initial state of the WatermarkEstimator for the given element
   421      and restriction.
   422      This function is called by the system.
   423      """
   424      raise NotImplementedError
   425  
   426    def create_watermark_estimator(self, estimator_state):
   427      """Create a new WatermarkEstimator based on the state. The state is
   428      typically useful when resuming processing an element.
   429      """
   430      raise NotImplementedError
   431  
   432    def estimator_state_coder(self):
   433      return coders.registry.get_coder(object)
   434  
   435  
   436  class _DoFnParam(object):
   437    """DoFn parameter."""
   438    def __init__(self, param_id):
   439      self.param_id = param_id
   440  
   441    def __eq__(self, other):
   442      if type(self) == type(other):
   443        return self.param_id == other.param_id
   444      return False
   445  
   446    def __hash__(self):
   447      return hash(self.param_id)
   448  
   449    def __repr__(self):
   450      return self.param_id
   451  
   452  
   453  class _RestrictionDoFnParam(_DoFnParam):
   454    """Restriction Provider DoFn parameter."""
   455    def __init__(self, restriction_provider=None):
   456      # type: (typing.Optional[RestrictionProvider]) -> None
   457      if (restriction_provider is not None and
   458          not isinstance(restriction_provider, RestrictionProvider)):
   459        raise ValueError(
   460            'DoFn.RestrictionParam expected RestrictionProvider object.')
   461      self.restriction_provider = restriction_provider
   462      self.param_id = (
   463          'RestrictionParam(%s)' % restriction_provider.__class__.__name__)
   464  
   465  
   466  class _StateDoFnParam(_DoFnParam):
   467    """State DoFn parameter."""
   468    def __init__(self, state_spec):
   469      # type: (StateSpec) -> None
   470      if not isinstance(state_spec, StateSpec):
   471        raise ValueError("DoFn.StateParam expected StateSpec object.")
   472      self.state_spec = state_spec
   473      self.param_id = 'StateParam(%s)' % state_spec.name
   474  
   475  
   476  class _TimerDoFnParam(_DoFnParam):
   477    """Timer DoFn parameter."""
   478    def __init__(self, timer_spec):
   479      # type: (TimerSpec) -> None
   480      if not isinstance(timer_spec, TimerSpec):
   481        raise ValueError("DoFn.TimerParam expected TimerSpec object.")
   482      self.timer_spec = timer_spec
   483      self.param_id = 'TimerParam(%s)' % timer_spec.name
   484  
   485  
   486  class _BundleFinalizerParam(_DoFnParam):
   487    """Bundle Finalization DoFn parameter."""
   488    def __init__(self):
   489      self._callbacks = []
   490      self.param_id = "FinalizeBundle"
   491  
   492    def register(self, callback):
   493      self._callbacks.append(callback)
   494  
    495    # Log errors raised while calling callbacks so that all callbacks get called
    496    # even if some raise. These errors should not fail the pipeline.
   497    def finalize_bundle(self):
   498      for callback in self._callbacks:
   499        try:
   500          callback()
   501        except Exception as e:
   502          _LOGGER.warning("Got exception from finalization call: %s", e)
   503  
   504    def has_callbacks(self):
   505      # type: () -> bool
   506      return len(self._callbacks) > 0
   507  
   508    def reset(self):
   509      # type: () -> None
   510      del self._callbacks[:]
   511  
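         # Example (illustrative, not part of the original source). A DoFn can request
         # a bundle finalizer via DoFn.BundleFinalizerParam and register callbacks to
         # run after the bundle's output has been durably committed:
         #
         #   class AckOnFinalizeDoFn(DoFn):
         #     def process(self, element, finalizer=DoFn.BundleFinalizerParam):
         #       # acknowledge() is hypothetical, e.g. an ack to an external queue.
         #       finalizer.register(lambda: acknowledge(element))
         #       yield element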
   512  
   513  class _WatermarkEstimatorParam(_DoFnParam):
   514    """WatermarkEstimator DoFn parameter."""
   515    def __init__(
   516        self,
   517        watermark_estimator_provider: typing.
   518        Optional[WatermarkEstimatorProvider] = None):
   519      if (watermark_estimator_provider is not None and not isinstance(
   520          watermark_estimator_provider, WatermarkEstimatorProvider)):
   521        raise ValueError(
   522            'DoFn.WatermarkEstimatorParam expected'
   523            'WatermarkEstimatorProvider object.')
   524      self.watermark_estimator_provider = watermark_estimator_provider
   525      self.param_id = 'WatermarkEstimatorProvider'
   526  
   527  
   528  class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   529    """A function object used by a transform with custom processing.
   530  
   531    The ParDo transform is such a transform. The ParDo.apply
   532    method will take an object of type DoFn and apply it to all elements of a
   533    PCollection object.
   534  
    535    To create a concrete DoFn object, subclass DoFn and define the desired
    536    behavior (start_bundle/finish_bundle and process), or wrap a callable object
    537    using the CallableWrapperDoFn class.
   538    """
   539  
   540    # Parameters that can be used in the .process() method.
   541    ElementParam = _DoFnParam('ElementParam')
   542    SideInputParam = _DoFnParam('SideInputParam')
   543    TimestampParam = _DoFnParam('TimestampParam')
   544    WindowParam = _DoFnParam('WindowParam')
   545    PaneInfoParam = _DoFnParam('PaneInfoParam')
   546    WatermarkEstimatorParam = _WatermarkEstimatorParam
   547    BundleFinalizerParam = _BundleFinalizerParam
   548    KeyParam = _DoFnParam('KeyParam')
   549  
   550    # Parameters to access state and timers.  Not restricted to use only in the
   551    # .process() method. Usage: DoFn.StateParam(state_spec),
   552    # DoFn.TimerParam(timer_spec), DoFn.TimestampParam, DoFn.WindowParam,
   553    # DoFn.KeyParam
   554    StateParam = _StateDoFnParam
   555    TimerParam = _TimerDoFnParam
   556    DynamicTimerTagParam = _DoFnParam('DynamicTimerTagParam')
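           # Example (illustrative sketch, not part of the original source). A stateful
           # DoFn that buffers integer values per key using a BagStateSpec:
           #
           #   class BufferDoFn(DoFn):
           #     BUFFER = userstate.BagStateSpec('buffer', coders.VarIntCoder())
           #
           #     def process(self, element, buffer=DoFn.StateParam(BUFFER)):
           #       key, value = element
           #       buffer.add(value)
           #       buffered = list(buffer.read())
           #       if len(buffered) >= 10:
           #         yield key, buffered
           #         buffer.clear()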
   557  
   558    DoFnProcessParams = [
   559        ElementParam,
   560        SideInputParam,
   561        TimestampParam,
   562        WindowParam,
   563        WatermarkEstimatorParam,
   564        PaneInfoParam,
   565        BundleFinalizerParam,
   566        KeyParam,
   567        StateParam,
   568        TimerParam,
   569    ]
   570  
   571    RestrictionParam = _RestrictionDoFnParam
   572  
   573    @staticmethod
   574    def from_callable(fn):
   575      return CallableWrapperDoFn(fn)
   576  
   577    @staticmethod
   578    def unbounded_per_element():
   579      """A decorator on process fn specifying that the fn performs an unbounded
   580      amount of work per input element."""
   581      def wrapper(process_fn):
   582        process_fn.unbounded_per_element = True
   583        return process_fn
   584  
   585      return wrapper
   586  
   587    @staticmethod
   588    def yields_elements(fn):
   589      """A decorator to apply to ``process_batch`` indicating it yields elements.
   590  
   591      By default ``process_batch`` is assumed to both consume and produce
   592      "batches", which are collections of multiple logical Beam elements. This
   593      decorator indicates that ``process_batch`` **produces** individual elements
   594      at a time. ``process_batch`` is always expected to consume batches.
   595      """
   596      if not fn.__name__ in ('process', 'process_batch'):
   597        raise TypeError(
   598            "@yields_elements must be applied to a process or "
   599            f"process_batch method, got {fn!r}.")
   600  
   601      fn._beam_yields_elements = True
   602      return fn
   603  
   604    @staticmethod
   605    def yields_batches(fn):
   606      """A decorator to apply to ``process`` indicating it yields batches.
   607  
   608      By default ``process`` is assumed to both consume and produce
   609      individual elements at a time. This decorator indicates that ``process``
   610      **produces** "batches", which are collections of multiple logical Beam
   611      elements.
   612      """
   613      if not fn.__name__ in ('process', 'process_batch'):
   614        raise TypeError(
    615            "@yields_batches must be applied to a process or "
   616            f"process_batch method, got {fn!r}.")
   617  
   618      fn._beam_yields_batches = True
   619      return fn
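           # Example (illustrative sketch, not part of the original source). A batched
           # DoFn that consumes and produces numpy arrays, assuming numpy is imported
           # as np and Iterator comes from typing:
           #
           #   class MultiplyByTwo(DoFn):
           #     def process_batch(
           #         self, batch: np.ndarray, *args, **kwargs) -> Iterator[np.ndarray]:
           #       yield batch * 2
           #
           #     # Declares the element-wise output type corresponding to the batches.
           #     def infer_output_type(self, input_type):
           #       return input_type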
   620  
   621    def default_label(self):
   622      return self.__class__.__name__
   623  
   624    def process(self, element, *args, **kwargs):
   625      """Method to use for processing elements.
   626  
    627      This is invoked by ``DoFnRunner`` for each element of an input
   628      ``PCollection``.
   629  
   630      The following parameters can be used as default values on ``process``
   631      arguments to indicate that a DoFn accepts the corresponding parameters. For
   632      example, a DoFn might accept the element and its timestamp with the
   633      following signature::
   634  
   635        def process(element=DoFn.ElementParam, timestamp=DoFn.TimestampParam):
   636          ...
   637  
   638      The full set of parameters is:
   639  
   640      - ``DoFn.ElementParam``: element to be processed, should not be mutated.
   641      - ``DoFn.SideInputParam``: a side input that may be used when processing.
   642      - ``DoFn.TimestampParam``: timestamp of the input element.
   643      - ``DoFn.WindowParam``: ``Window`` the input element belongs to.
   644      - ``DoFn.TimerParam``: a ``userstate.RuntimeTimer`` object defined by the
   645        spec of the parameter.
   646      - ``DoFn.StateParam``: a ``userstate.RuntimeState`` object defined by the
   647        spec of the parameter.
   648      - ``DoFn.KeyParam``: key associated with the element.
   649      - ``DoFn.RestrictionParam``: an ``iobase.RestrictionTracker`` will be
   650        provided here to allow treatment as a Splittable ``DoFn``. The restriction
   651        tracker will be derived from the restriction provider in the parameter.
   652      - ``DoFn.WatermarkEstimatorParam``: a function that can be used to track
   653        output watermark of Splittable ``DoFn`` implementations.
   654  
   655      Args:
   656        element: The element to be processed
   657        *args: side inputs
   658        **kwargs: other keyword arguments.
   659  
   660      Returns:
   661        An Iterable of output elements or None.
   662      """
   663      raise NotImplementedError
   664  
   665    def process_batch(self, batch, *args, **kwargs):
   666      raise NotImplementedError
   667  
   668    def setup(self):
   669      """Called to prepare an instance for processing bundles of elements.
   670  
   671      This is a good place to initialize transient in-memory resources, such as
   672      network connections. The resources can then be disposed in
   673      ``DoFn.teardown``.
   674      """
   675      pass
   676  
   677    def start_bundle(self):
   678      """Called before a bundle of elements is processed on a worker.
   679  
   680      Elements to be processed are split into bundles and distributed
   681      to workers. Before a worker calls process() on the first element
   682      of its bundle, it calls this method.
   683      """
   684      pass
   685  
   686    def finish_bundle(self):
   687      """Called after a bundle of elements is processed on a worker.
   688      """
   689      pass
   690  
   691    def teardown(self):
    692      """Called to clean up this instance before it is discarded.
   693  
   694      A runner will do its best to call this method on any given instance to
    695      prevent leaks of transient resources; however, there may be situations where
   696      this is impossible (e.g. process crash, hardware failure, etc.) or
   697      unnecessary (e.g. the pipeline is shutting down and the process is about to
   698      be killed anyway, so all transient resources will be released automatically
   699      by the OS). In these cases, the call may not happen. It will also not be
   700      retried, because in such situations the DoFn instance no longer exists, so
   701      there's no instance to retry it on.
   702  
   703      Thus, all work that depends on input elements, and all externally important
   704      side effects, must be performed in ``DoFn.process`` or
   705      ``DoFn.finish_bundle``.
   706      """
   707      pass
   708  
   709    def get_function_arguments(self, func):
   710      return get_function_arguments(self, func)
   711  
   712    def default_type_hints(self):
   713      process_type_hints = typehints.decorators.IOTypeHints.from_callable(
   714          self.process) or typehints.decorators.IOTypeHints.empty()
   715  
   716      if self._process_yields_batches:
    717        # process() produces batches, don't use its output typehint
   718        process_type_hints = process_type_hints.with_output_types_from(
   719            typehints.decorators.IOTypeHints.empty())
   720  
   721      if self._process_batch_yields_elements:
    722        # process_batch() produces elements, *do* use its output typehint
   723  
   724        # First access the typehint
   725        process_batch_type_hints = typehints.decorators.IOTypeHints.from_callable(
   726            self.process_batch) or typehints.decorators.IOTypeHints.empty()
   727  
   728        # Then we deconflict with the typehint from process, if it exists
   729        if (process_batch_type_hints.output_types !=
   730            typehints.decorators.IOTypeHints.empty().output_types):
   731          if (process_type_hints.output_types !=
   732              typehints.decorators.IOTypeHints.empty().output_types and
   733              process_batch_type_hints.output_types !=
   734              process_type_hints.output_types):
   735            raise TypeError(
   736                f"DoFn {self!r} yields element from both process and "
   737                "process_batch, but they have mismatched output typehints:\n"
   738                f" process: {process_type_hints.output_types}\n"
   739                f" process_batch: {process_batch_type_hints.output_types}")
   740  
   741          process_type_hints = process_type_hints.with_output_types_from(
   742              process_batch_type_hints)
   743  
   744      try:
   745        process_type_hints = process_type_hints.strip_iterable()
   746      except ValueError as e:
   747        raise ValueError('Return value not iterable: %s: %s' % (self, e))
   748  
   749      # Prefer class decorator type hints for backwards compatibility.
   750      return get_type_hints(self.__class__).with_defaults(process_type_hints)
   751  
   752    # TODO(sourabhbajaj): Do we want to remove the responsibility of these from
   753    # the DoFn or maybe the runner
   754    def infer_output_type(self, input_type):
   755      # TODO(https://github.com/apache/beam/issues/19824): Side inputs types.
   756      return trivial_inference.element_type(
   757          _strip_output_annotations(
   758              trivial_inference.infer_return_type(self.process, [input_type])))
   759  
   760    @property
   761    def _process_defined(self) -> bool:
   762      # Check if this DoFn's process method has been overridden
   763      # Note that we retrieve the __func__ attribute, if it exists, to get the
   764      # underlying function from the bound method.
   765      # If __func__ doesn't exist, self.process was likely overridden with a free
   766      # function, as in CallableWrapperDoFn.
   767      return getattr(self.process, '__func__', self.process) != DoFn.process
   768  
   769    @property
   770    def _process_batch_defined(self) -> bool:
   771      # Check if this DoFn's process_batch method has been overridden
   772      # Note that we retrieve the __func__ attribute, if it exists, to get the
   773      # underlying function from the bound method.
   774      # If __func__ doesn't exist, self.process_batch was likely overridden with
   775      # a free function.
   776      return getattr(
   777          self.process_batch, '__func__',
   778          self.process_batch) != DoFn.process_batch
   779  
   780    @property
   781    def _can_yield_batches(self) -> bool:
   782      return ((self._process_defined and self._process_yields_batches) or (
   783          self._process_batch_defined and
   784          not self._process_batch_yields_elements))
   785  
   786    @property
   787    def _process_yields_batches(self) -> bool:
   788      return getattr(self.process, '_beam_yields_batches', False)
   789  
   790    @property
   791    def _process_batch_yields_elements(self) -> bool:
   792      return getattr(self.process_batch, '_beam_yields_elements', False)
   793  
   794    def get_input_batch_type(
   795        self, input_element_type
   796    ) -> typing.Optional[typing.Union[TypeConstraint, type]]:
   797      """Determine the batch type expected as input to process_batch.
   798  
   799      The default implementation of ``get_input_batch_type`` simply observes the
   800      input typehint for the first parameter of ``process_batch``. A Batched DoFn
   801      may override this method if a dynamic approach is required.
   802  
   803      Args:
   804        input_element_type: The **element type** of the input PCollection this
   805          DoFn is being applied to.
   806  
   807      Returns:
   808        ``None`` if this DoFn cannot accept batches, else a Beam typehint or
   809        a native Python typehint.
   810      """
   811      if not self._process_batch_defined:
   812        return None
   813      input_type = list(
   814          inspect.signature(self.process_batch).parameters.values())[0].annotation
   815      if input_type == inspect.Signature.empty:
   816        # TODO(https://github.com/apache/beam/issues/21652): Consider supporting
   817        # an alternative (dynamic?) approach for declaring input type
   818        raise TypeError(
   819            f"Either {self.__class__.__name__}.process_batch() must have a type "
   820            f"annotation on its first parameter, or {self.__class__.__name__} "
   821            "must override get_input_batch_type.")
   822      return input_type
   823  
   824    def _get_input_batch_type_normalized(self, input_element_type):
   825      return typehints.native_type_compatibility.convert_to_beam_type(
   826          self.get_input_batch_type(input_element_type))
   827  
   828    def _get_output_batch_type_normalized(self, input_element_type):
   829      return typehints.native_type_compatibility.convert_to_beam_type(
   830          self.get_output_batch_type(input_element_type))
   831  
   832    @staticmethod
   833    def _get_element_type_from_return_annotation(method, input_type):
   834      return_type = inspect.signature(method).return_annotation
   835      if return_type == inspect.Signature.empty:
   836        # output type not annotated, try to infer it
   837        return_type = trivial_inference.infer_return_type(method, [input_type])
   838  
   839      return_type = typehints.native_type_compatibility.convert_to_beam_type(
   840          return_type)
   841      if isinstance(return_type, typehints.typehints.IterableTypeConstraint):
   842        return return_type.inner_type
   843      elif isinstance(return_type, typehints.typehints.IteratorTypeConstraint):
   844        return return_type.yielded_type
   845      else:
   846        raise TypeError(
   847            "Expected Iterator in return type annotation for "
   848            f"{method!r}, did you mean Iterator[{return_type}]? Note Beam DoFn "
   849            "process and process_batch methods are expected to produce "
   850            "generators - they should 'yield' rather than 'return'.")
   851  
   852    def get_output_batch_type(
   853        self, input_element_type
   854    ) -> typing.Optional[typing.Union[TypeConstraint, type]]:
   855      """Determine the batch type produced by this DoFn's ``process_batch``
   856      implementation and/or its ``process`` implementation with
    857      ``@yields_batches``.
   858  
   859      The default implementation of this method observes the return type
   860      annotations on ``process_batch`` and/or ``process``.  A Batched DoFn may
   861      override this method if a dynamic approach is required.
   862  
   863      Args:
   864        input_element_type: The **element type** of the input PCollection this
   865          DoFn is being applied to.
   866  
   867      Returns:
   868        ``None`` if this DoFn will never yield batches, else a Beam typehint or
   869        a native Python typehint.
   870      """
   871      output_batch_type = None
   872      if self._process_defined and self._process_yields_batches:
   873        output_batch_type = self._get_element_type_from_return_annotation(
   874            self.process, input_element_type)
   875      if self._process_batch_defined and not self._process_batch_yields_elements:
   876        process_batch_type = self._get_element_type_from_return_annotation(
   877            self.process_batch,
   878            self._get_input_batch_type_normalized(input_element_type))
   879  
   880        # TODO: Consider requiring an inheritance relationship rather than
   881        # equality
   882        if (output_batch_type is not None and
   883            (not process_batch_type == output_batch_type)):
   884          raise TypeError(
   885              f"DoFn {self!r} yields batches from both process and "
   886              "process_batch, but they produce different types:\n"
   887              f" process: {output_batch_type}\n"
   888              f" process_batch: {process_batch_type!r}")
   889  
   890        output_batch_type = process_batch_type
   891  
   892      return output_batch_type
   893  
   894    def _process_argspec_fn(self):
   895      """Returns the Python callable that will eventually be invoked.
   896  
   897      This should ideally be the user-level function that is called with
   898      the main and (if any) side inputs, and is used to relate the type
   899      hint parameters with the input parameters (e.g., by argument name).
   900      """
   901      return self.process
   902  
   903    urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_DOFN)
   904  
   905  
   906  class CallableWrapperDoFn(DoFn):
   907    """For internal use only; no backwards-compatibility guarantees.
   908  
   909    A DoFn (function) object wrapping a callable object.
   910  
   911    The purpose of this class is to conveniently wrap simple functions and use
   912    them in transforms.
   913    """
   914    def __init__(self, fn, fullargspec=None):
   915      """Initializes a CallableWrapperDoFn object wrapping a callable.
   916  
   917      Args:
   918        fn: A callable object.
   919  
   920      Raises:
   921        TypeError: if fn parameter is not a callable type.
   922      """
   923      if not callable(fn):
   924        raise TypeError('Expected a callable object instead of: %r' % fn)
   925  
   926      self._fn = fn
   927      self._fullargspec = fullargspec
   928      if isinstance(
   929          fn, (types.BuiltinFunctionType, types.MethodType, types.FunctionType)):
   930        self.process = fn
   931      else:
   932        # For cases such as set / list where fn is callable but not a function
   933        self.process = lambda element: fn(element)
   934  
   935      super().__init__()
   936  
   937    def display_data(self):
   938      # If the callable has a name, then it's likely a function, and
   939      # we show its name.
   940      # Otherwise, it might be an instance of a callable class. We
   941      # show its class.
   942      display_data_value = (
   943          self._fn.__name__
   944          if hasattr(self._fn, '__name__') else self._fn.__class__)
   945      return {
   946          'fn': DisplayDataItem(display_data_value, label='Transform Function')
   947      }
   948  
   949    def __repr__(self):
   950      return 'CallableWrapperDoFn(%s)' % self._fn
   951  
   952    def default_type_hints(self):
   953      fn_type_hints = typehints.decorators.IOTypeHints.from_callable(self._fn)
   954      type_hints = get_type_hints(self._fn).with_defaults(fn_type_hints)
   955      # The fn's output type should be iterable. Strip off the outer
   956      # container type due to the 'flatten' portion of FlatMap/ParDo.
   957      try:
   958        type_hints = type_hints.strip_iterable()
   959      except ValueError as e:
   960        raise TypeCheckError(
   961            'Return value not iterable: %s: %s' %
   962            (self.display_data()['fn'].value, e))
   963      return type_hints
   964  
   965    def infer_output_type(self, input_type):
   966      return trivial_inference.element_type(
   967          _strip_output_annotations(
   968              trivial_inference.infer_return_type(self._fn, [input_type])))
   969  
   970    def _process_argspec_fn(self):
   971      return getattr(self._fn, '_argspec_fn', self._fn)
   972  
   973    def _inspect_process(self):
   974      if self._fullargspec:
   975        return self._fullargspec
   976      else:
   977        return get_function_args_defaults(self._process_argspec_fn())
   978  
   979  
   980  class CombineFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn):
   981    """A function object used by a Combine transform with custom processing.
   982  
   983    A CombineFn specifies how multiple values in all or part of a PCollection can
   984    be merged into a single value---essentially providing the same kind of
   985    information as the arguments to the Python "reduce" builtin (except for the
    986    information as the arguments to Python's "functools.reduce" (except for the
   987    combining process proceeds as follows:
   988  
   989    1. Input values are partitioned into one or more batches.
   990    2. For each batch, the setup method is invoked.
   991    3. For each batch, the create_accumulator method is invoked to create a fresh
   992       initial "accumulator" value representing the combination of zero values.
   993    4. For each input value in the batch, the add_input method is invoked to
   994       combine more values with the accumulator for that batch.
   995    5. The merge_accumulators method is invoked to combine accumulators from
   996       separate batches into a single combined output accumulator value, once all
    997       of the accumulators have had all the input values in their batches added to
   998       them. This operation is invoked repeatedly, until there is only one
   999       accumulator value left.
  1000    6. The extract_output operation is invoked on the final accumulator to get
  1001       the output value.
  1002    7. The teardown method is invoked.
  1003  
  1004    Note: If this **CombineFn** is used with a transform that has defaults,
  1005    **apply** will be called with an empty list at expansion time to get the
  1006    default value.
  1007    """
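           # Example (illustrative, not part of the original source). A CombineFn that
           # computes a mean, exercising steps 3-6 above:
           #
           #   class AverageFn(CombineFn):
           #     def create_accumulator(self):
           #       return (0.0, 0)  # (running sum, count)
           #
           #     def add_input(self, accumulator, element):
           #       total, count = accumulator
           #       return total + element, count + 1
           #
           #     def merge_accumulators(self, accumulators):
           #       totals, counts = zip(*accumulators)
           #       return sum(totals), sum(counts)
           #
           #     def extract_output(self, accumulator):
           #       total, count = accumulator
           #       return total / count if count else float('NaN')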
  1008    def default_label(self):
  1009      return self.__class__.__name__
  1010  
  1011    def setup(self, *args, **kwargs):
  1012      """Called to prepare an instance for combining.
  1013  
  1014      This method can be useful if there is some state that needs to be loaded
  1015      before executing any of the other methods. The resources can then be
  1016      disposed of in ``CombineFn.teardown``.
  1017  
  1018      If you are using Dataflow, you need to enable Dataflow Runner V2
  1019      before using this feature.
  1020  
  1021      Args:
  1022        *args: Additional arguments and side inputs.
  1023        **kwargs: Additional arguments and side inputs.
  1024      """
  1025      pass
  1026  
  1027    def create_accumulator(self, *args, **kwargs):
  1028      """Return a fresh, empty accumulator for the combine operation.
  1029  
  1030      Args:
  1031        *args: Additional arguments and side inputs.
  1032        **kwargs: Additional arguments and side inputs.
  1033      """
  1034      raise NotImplementedError(str(self))
  1035  
  1036    def add_input(self, mutable_accumulator, element, *args, **kwargs):
  1037      """Return result of folding element into accumulator.
  1038  
  1039      CombineFn implementors must override add_input.
  1040  
  1041      Args:
  1042        mutable_accumulator: the current accumulator,
  1043          may be modified and returned for efficiency
  1044        element: the element to add, should not be mutated
  1045        *args: Additional arguments and side inputs.
  1046        **kwargs: Additional arguments and side inputs.
  1047      """
  1048      raise NotImplementedError(str(self))
  1049  
  1050    def add_inputs(self, mutable_accumulator, elements, *args, **kwargs):
  1051      """Returns the result of folding each element in elements into accumulator.
  1052  
  1053      This is provided in case the implementation affords more efficient
  1054      bulk addition of elements. The default implementation simply loops
  1055      over the inputs invoking add_input for each one.
  1056  
  1057      Args:
  1058        mutable_accumulator: the current accumulator,
  1059          may be modified and returned for efficiency
  1060        elements: the elements to add, should not be mutated
  1061        *args: Additional arguments and side inputs.
  1062        **kwargs: Additional arguments and side inputs.
  1063      """
  1064      for element in elements:
  1065        mutable_accumulator =\
  1066          self.add_input(mutable_accumulator, element, *args, **kwargs)
  1067      return mutable_accumulator
  1068  
  1069    def merge_accumulators(self, accumulators, *args, **kwargs):
  1070      """Returns the result of merging several accumulators
  1071      to a single accumulator value.
  1072  
  1073      Args:
  1074        accumulators: the accumulators to merge.
  1075          Only the first accumulator may be modified and returned for efficiency;
  1076          the other accumulators should not be mutated, because they may be
  1077          shared with other code and mutating them could lead to incorrect
  1078          results or data corruption.
  1079        *args: Additional arguments and side inputs.
  1080        **kwargs: Additional arguments and side inputs.
  1081      """
  1082      raise NotImplementedError(str(self))
  1083  
  1084    def compact(self, accumulator, *args, **kwargs):
   1085      """Optionally returns a more compact representation of the accumulator.
  1086  
  1087      This is called before an accumulator is sent across the wire, and can
  1088      be useful in cases where values are buffered or otherwise lazily
  1089      kept unprocessed when added to the accumulator.  Should return an
  1090      equivalent, though possibly modified, accumulator.
  1091  
  1092      By default returns the accumulator unmodified.
  1093  
  1094      Args:
  1095        accumulator: the current accumulator
  1096        *args: Additional arguments and side inputs.
  1097        **kwargs: Additional arguments and side inputs.
  1098      """
  1099      return accumulator
  1100  
  1101    def extract_output(self, accumulator, *args, **kwargs):
  1102      """Return result of converting accumulator into the output value.
  1103  
  1104      Args:
  1105        accumulator: the final accumulator value computed by this CombineFn
  1106          for the entire input key or PCollection. Can be modified for
  1107          efficiency.
  1108        *args: Additional arguments and side inputs.
  1109        **kwargs: Additional arguments and side inputs.
  1110      """
  1111      raise NotImplementedError(str(self))
  1112  
  1113    def teardown(self, *args, **kwargs):
  1114      """Called to clean up an instance before it is discarded.
  1115  
  1116      If you are using Dataflow, you need to enable Dataflow Runner V2
  1117      before using this feature.
  1118  
  1119      Args:
  1120        *args: Additional arguments and side inputs.
  1121        **kwargs: Additional arguments and side inputs.
  1122      """
  1123      pass
  1124  
  1125    def apply(self, elements, *args, **kwargs):
  1126      """Returns result of applying this CombineFn to the input values.
  1127  
  1128      Args:
  1129        elements: the set of values to combine.
  1130        *args: Additional arguments and side inputs.
  1131        **kwargs: Additional arguments and side inputs.
  1132      """
  1133      return self.extract_output(
  1134          self.add_inputs(
  1135              self.create_accumulator(*args, **kwargs), elements, *args,
  1136              **kwargs),
  1137          *args,
  1138          **kwargs)
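           # For example (illustrative, not part of the original source),
           # CombineFn.from_callable(sum).apply([1, 2, 3]) folds the inputs into a
           # fresh accumulator and extracts 6, without involving a runner.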
  1139  
  1140    def for_input_type(self, input_type):
  1141      """Returns a specialized implementation of self, if it exists.
  1142  
  1143      Otherwise, returns self.
  1144  
  1145      Args:
  1146        input_type: the type of input elements.
  1147      """
  1148      return self
  1149  
  1150    @staticmethod
  1151    def from_callable(fn):
  1152      return CallableWrapperCombineFn(fn)
  1153  
  1154    @staticmethod
  1155    def maybe_from_callable(fn, has_side_inputs=True):
  1156      # type: (typing.Union[CombineFn, typing.Callable], bool) -> CombineFn
  1157      if isinstance(fn, CombineFn):
  1158        return fn
  1159      elif callable(fn) and not has_side_inputs:
  1160        return NoSideInputsCallableWrapperCombineFn(fn)
  1161      elif callable(fn):
  1162        return CallableWrapperCombineFn(fn)
  1163      else:
  1164        raise TypeError('Expected a CombineFn or callable, got %r' % fn)
  1165  
  1166    def get_accumulator_coder(self):
  1167      return coders.registry.get_coder(object)
  1168  
  1169    urns.RunnerApiFn.register_pickle_urn(python_urns.PICKLED_COMBINE_FN)
  1170  
  1171  
  1172  class _ReiterableChain(object):
  1173    """Like itertools.chain, but allowing re-iteration."""
  1174    def __init__(self, iterables):
  1175      self.iterables = iterables
  1176  
  1177    def __iter__(self):
  1178      for iterable in self.iterables:
  1179        for item in iterable:
  1180          yield item
  1181  
  1182    def __bool__(self):
  1183      for iterable in self.iterables:
  1184        for _ in iterable:
  1185          return True
  1186      return False
  1187  
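         # Example (illustrative, not part of the original source). Unlike
         # itertools.chain, a _ReiterableChain over re-iterable inputs can be
         # traversed more than once:
         #
         #   chained = _ReiterableChain([[1, 2], [3]])
         #   assert list(chained) == [1, 2, 3]
         #   assert list(chained) == [1, 2, 3]  # a chain object would be exhausted
         #   assert not _ReiterableChain([[], []])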
  1188  
  1189  class CallableWrapperCombineFn(CombineFn):
  1190    """For internal use only; no backwards-compatibility guarantees.
  1191  
  1192    A CombineFn (function) object wrapping a callable object.
  1193  
  1194    The purpose of this class is to conveniently wrap simple functions and use
  1195    them in Combine transforms.
  1196    """
  1197    _DEFAULT_BUFFER_SIZE = 10
  1198  
  1199    def __init__(self, fn, buffer_size=_DEFAULT_BUFFER_SIZE):
   1200      """Initializes a CallableWrapperCombineFn object wrapping a callable.
  1201  
  1202      Args:
  1203        fn: A callable object that reduces elements of an iterable to a single
  1204          value (like the builtins sum and max). This callable must be capable of
  1205          receiving the kind of values it generates as output in its input, and
  1206          for best results, its operation must be commutative and associative.
  1207  
  1208      Raises:
  1209        TypeError: if fn parameter is not a callable type.
  1210      """
  1211      if not callable(fn):
  1212        raise TypeError('Expected a callable object instead of: %r' % fn)
  1213  
  1214      super().__init__()
  1215      self._fn = fn
  1216      self._buffer_size = buffer_size
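           # Example (illustrative, not part of the original source). Wrapping a plain
           # reducer such as max buffers elements and collapses the buffer with the
           # callable whenever it grows past buffer_size:
           #
           #   combine_fn = CallableWrapperCombineFn(max)
           #   acc = combine_fn.add_inputs(combine_fn.create_accumulator(), range(25))
           #   assert combine_fn.extract_output(acc) == 24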
  1217  
  1218    def display_data(self):
  1219      return {'fn_dd': self._fn}
  1220  
  1221    def __repr__(self):
  1222      return "%s(%s)" % (self.__class__.__name__, self._fn)
  1223  
  1224    def create_accumulator(self, *args, **kwargs):
  1225      return []
  1226  
  1227    def add_input(self, accumulator, element, *args, **kwargs):
  1228      accumulator.append(element)
  1229      if len(accumulator) > self._buffer_size:
  1230        accumulator = [self._fn(accumulator, *args, **kwargs)]
  1231      return accumulator
  1232  
  1233    def add_inputs(self, accumulator, elements, *args, **kwargs):
  1234      accumulator.extend(elements)
  1235      if len(accumulator) > self._buffer_size:
  1236        accumulator = [self._fn(accumulator, *args, **kwargs)]
  1237      return accumulator
  1238  
  1239    def merge_accumulators(self, accumulators, *args, **kwargs):
  1240      return [self._fn(_ReiterableChain(accumulators), *args, **kwargs)]
  1241  
  1242    def compact(self, accumulator, *args, **kwargs):
  1243      if len(accumulator) <= 1:
  1244        return accumulator
  1245      else:
  1246        return [self._fn(accumulator, *args, **kwargs)]
  1247  
  1248    def extract_output(self, accumulator, *args, **kwargs):
  1249      return self._fn(accumulator, *args, **kwargs)
  1250  
  1251    def default_type_hints(self):
  1252      fn_type_hints = typehints.decorators.IOTypeHints.from_callable(self._fn)
  1253      type_hints = get_type_hints(self._fn).with_defaults(fn_type_hints)
  1254      if type_hints.input_types is None:
  1255        return type_hints
  1256      else:
  1257        # fn(Iterable[V]) -> V becomes CombineFn(V) -> V
  1258        input_args, input_kwargs = type_hints.input_types
  1259        if not input_args:
  1260          if len(input_kwargs) == 1:
  1261            input_args, input_kwargs = tuple(input_kwargs.values()), {}
  1262          else:
  1263            raise TypeError('Combiner input type must be specified positionally.')
  1264        if not is_consistent_with(input_args[0],
  1265                                  typehints.Iterable[typehints.Any]):
  1266          raise TypeCheckError(
  1267              'All functions for a Combine PTransform must accept a '
  1268              'single argument compatible with: Iterable[Any]. '
  1269              'Instead a function with input type: %s was received.' %
  1270              input_args[0])
  1271        input_args = (element_type(input_args[0]), ) + input_args[1:]
  1272        # TODO(robertwb): Assert output type is consistent with input type?
  1273        return type_hints.with_input_types(*input_args, **input_kwargs)
  1274  
  1275    def infer_output_type(self, input_type):
  1276      return _strip_output_annotations(
  1277          trivial_inference.infer_return_type(self._fn, [input_type]))
  1278  
  1279    def for_input_type(self, input_type):
  1280      # Avoid circular imports.
  1281      from apache_beam.transforms import cy_combiners
  1282      if self._fn is any:
  1283        return cy_combiners.AnyCombineFn()
  1284      elif self._fn is all:
  1285        return cy_combiners.AllCombineFn()
  1286      else:
  1287        known_types = {
  1288            (sum, int): cy_combiners.SumInt64Fn(),
  1289            (min, int): cy_combiners.MinInt64Fn(),
  1290            (max, int): cy_combiners.MaxInt64Fn(),
  1291            (sum, float): cy_combiners.SumFloatFn(),
  1292            (min, float): cy_combiners.MinFloatFn(),
  1293            (max, float): cy_combiners.MaxFloatFn(),
  1294        }
  1295      return known_types.get((self._fn, input_type), self)
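        # A minimal usage sketch (illustration only): a plain Python callable passed
        # to a Combine transform is wrapped by this class (or its NoSideInputs
        # variant below). The pipeline and data here are hypothetical.
        #
        #   import apache_beam as beam
        #
        #   with beam.Pipeline() as p:
        #     largest = (
        #         p
        #         | beam.Create([3, 1, 4, 1, 5])
        #         | beam.CombineGlobally(max))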
  1296  
  1297  
  1298  class NoSideInputsCallableWrapperCombineFn(CallableWrapperCombineFn):
  1299    """For internal use only; no backwards-compatibility guarantees.
  1300  
  1301    A CombineFn (function) object wrapping a callable object with no side inputs.
  1302  
  1303    This is identical to its parent, but avoids accepting and passing *args
  1304    and **kwargs for efficiency as they are known to be empty.
  1305    """
  1306    def create_accumulator(self):
  1307      return []
  1308  
  1309    def add_input(self, accumulator, element):
  1310      accumulator.append(element)
  1311      if len(accumulator) > self._buffer_size:
  1312        accumulator = [self._fn(accumulator)]
  1313      return accumulator
  1314  
  1315    def add_inputs(self, accumulator, elements):
  1316      accumulator.extend(elements)
  1317      if len(accumulator) > self._buffer_size:
  1318        accumulator = [self._fn(accumulator)]
  1319      return accumulator
  1320  
  1321    def merge_accumulators(self, accumulators):
  1322      return [self._fn(_ReiterableChain(accumulators))]
  1323  
  1324    def compact(self, accumulator):
  1325      if len(accumulator) <= 1:
  1326        return accumulator
  1327      else:
  1328        return [self._fn(accumulator)]
  1329  
  1330    def extract_output(self, accumulator):
  1331      return self._fn(accumulator)
  1332  
  1333  
  1334  class PartitionFn(WithTypeHints):
  1335    """A function object used by a Partition transform.
  1336  
  1337    A PartitionFn specifies how individual values in a PCollection will be placed
  1338    into separate partitions, indexed by an integer.
  1339    """
  1340    def default_label(self):
  1341      return self.__class__.__name__
  1342  
  1343    def partition_for(self, element, num_partitions, *args, **kwargs):
  1344      # type: (T, int, *typing.Any, **typing.Any) -> int
  1345  
  1346      """Specify which partition will receive this element.
  1347  
  1348      Args:
  1349        element: An element of the input PCollection.
  1350        num_partitions: Number of partitions, i.e., output PCollections.
  1351        *args: optional parameters and side inputs.
  1352        **kwargs: optional parameters and side inputs.
  1353  
  1354      Returns:
  1355        An integer in [0, num_partitions).
  1356      """
  1357      pass
  1358  
  1359  
  1360  class CallableWrapperPartitionFn(PartitionFn):
  1361    """For internal use only; no backwards-compatibility guarantees.
  1362  
  1363    A PartitionFn object wrapping a callable object.
  1364  
  1365    Instances of this class wrap simple functions for use in Partition operations.
  1366    """
  1367    def __init__(self, fn):
  1368      """Initializes a PartitionFn object wrapping a callable.
  1369  
  1370      Args:
  1371        fn: A callable object, which should accept the following arguments:
  1372              element - element to assign to a partition.
  1373              num_partitions - number of output partitions.
  1374            and may accept additional arguments and side inputs.
  1375  
  1376      Raises:
  1377        TypeError: if fn is not a callable type.
  1378      """
  1379      if not callable(fn):
  1380        raise TypeError('Expected a callable object instead of: %r' % fn)
  1381      self._fn = fn
  1382  
  1383    def partition_for(self, element, num_partitions, *args, **kwargs):
  1384      # type: (T, int, *typing.Any, **typing.Any) -> int
  1385      return self._fn(element, num_partitions, *args, **kwargs)
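        # A minimal usage sketch (illustration only): a plain callable passed to
        # beam.Partition is wrapped by this class; it receives the element and the
        # number of partitions and must return an integer in [0, num_partitions).
        # The pipeline and data here are hypothetical.
        #
        #   import apache_beam as beam
        #
        #   with beam.Pipeline() as p:
        #     evens, odds = (
        #         p
        #         | beam.Create([1, 2, 3, 4, 5])
        #         | beam.Partition(lambda n, num_partitions: n % num_partitions, 2))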
  1386  
  1387  
  1388  def _get_function_body_without_inners(func):
  1389    source_lines = inspect.getsourcelines(func)[0]
  1390    source_lines = dropwhile(lambda x: x.startswith("@"), source_lines)
  1391    def_line = next(source_lines).strip()
  1392    if def_line.startswith("def ") and def_line.endswith(":"):
  1393      first_line = next(source_lines)
  1394      indentation = len(first_line) - len(first_line.lstrip())
  1395      final_lines = [first_line[indentation:]]
  1396  
  1397      skip_inner_def = False
  1398      if first_line[indentation:].startswith("def "):
  1399        skip_inner_def = True
  1400      for line in source_lines:
  1401        line_indentation = len(line) - len(line.lstrip())
  1402  
  1403        if line[indentation:].startswith("def "):
  1404          skip_inner_def = True
  1405          continue
  1406  
  1407        if skip_inner_def and line_indentation == indentation:
  1408          skip_inner_def = False
  1409  
  1410        if skip_inner_def and line_indentation > indentation:
  1411          continue
  1412        final_lines.append(line[indentation:])
  1413  
  1414      return "".join(final_lines)
  1415    else:
  1416      return def_line.rsplit(":")[-1].strip()
  1417  
  1418  
  1419  def _check_fn_use_yield_and_return(fn):
  1420    if isinstance(fn, types.BuiltinFunctionType):
  1421      return False
  1422    try:
  1423      source_code = _get_function_body_without_inners(fn)
  1424      has_yield = False
  1425      has_return = False
  1426      for line in source_code.split("\n"):
  1427        if line.lstrip().startswith("yield ") or line.lstrip().startswith(
  1428            "yield("):
  1429          has_yield = True
  1430        if line.lstrip().startswith("return ") or line.lstrip().startswith(
  1431            "return("):
  1432          has_return = True
  1433        if has_yield and has_return:
  1434          return True
  1435      return False
  1436    except Exception as e:
  1437      _LOGGER.debug(str(e))
  1438      return False
  1439  
  1440  
  1441  class ParDo(PTransformWithSideInputs):
  1442    """A :class:`ParDo` transform.
  1443  
  1444    Processes an input :class:`~apache_beam.pvalue.PCollection` by applying a
  1445    :class:`DoFn` to each element and returning the accumulated results into an
  1446    output :class:`~apache_beam.pvalue.PCollection`. The type of the elements is
  1447    not fixed as long as the :class:`DoFn` can deal with it. In reality the type
  1448    is constrained to some extent because the elements sometimes must be persisted
  1449    to external storage. See the :meth:`.expand()` method comments for a
  1450    detailed description of all possible arguments.
  1451  
  1452    Note that the :class:`DoFn` must return an iterable for each element of the
  1453    input :class:`~apache_beam.pvalue.PCollection`. An easy way to do this is to
  1454    use the ``yield`` keyword in the process method.
  1455  
  1456    Args:
  1457      pcoll (~apache_beam.pvalue.PCollection):
  1458        a :class:`~apache_beam.pvalue.PCollection` to be processed.
  1459      fn (`typing.Union[DoFn, typing.Callable]`): a :class:`DoFn` object to be
  1460        applied to each element of **pcoll** argument, or a Callable.
  1461      *args: positional arguments passed to the :class:`DoFn` object.
  1462      **kwargs:  keyword arguments passed to the :class:`DoFn` object.
  1463  
  1464    Note that the positional and keyword arguments will be processed in order
  1465    to detect :class:`~apache_beam.pvalue.PCollection` s that will be computed as
  1466    side inputs to the transform. During pipeline execution whenever the
  1467    :class:`DoFn` object gets executed (its :meth:`DoFn.process()` method gets
  1468    called) the :class:`~apache_beam.pvalue.PCollection` arguments will be
  1469    replaced by values from the :class:`~apache_beam.pvalue.PCollection` in the
  1470    exact positions where they appear in the argument lists.
  1471    """
  1472    def __init__(self, fn, *args, **kwargs):
  1473      super().__init__(fn, *args, **kwargs)
  1474      # TODO(robertwb): Change all uses of the dofn attribute to use fn instead.
  1475      self.dofn = self.fn
  1476      self.output_tags = set()  # type: typing.Set[str]
  1477  
  1478      if not isinstance(self.fn, DoFn):
  1479        raise TypeError('ParDo must be called with a DoFn instance.')
  1480  
  1481      # DoFn.process should not mix return and yield; warn if it does.
  1482      if _check_fn_use_yield_and_return(self.fn.process):
  1483        _LOGGER.warning(
  1484            'Using yield and return in the process method '
  1485          'of %s can lead to unexpected behavior, see: '
  1486            'https://github.com/apache/beam/issues/22969.',
  1487            self.fn.__class__)
  1488  
  1489      # Validate the DoFn by creating a DoFnSignature
  1490      from apache_beam.runners.common import DoFnSignature
  1491      self._signature = DoFnSignature(self.fn)
  1492  
  1493    def with_exception_handling(
  1494        self,
  1495        main_tag='good',
  1496        dead_letter_tag='bad',
  1497        *,
  1498        exc_class=Exception,
  1499        partial=False,
  1500        use_subprocess=False,
  1501        threshold=1,
  1502        threshold_windowing=None,
  1503        timeout=None):
  1504      """Automatically provides a dead letter output for skipping bad records.
  1505      This can allow a pipeline to continue successfully rather than fail or
  1506      continuously throw errors on retry when bad elements are encountered.
  1507  
  1508      This returns a tagged output with two PCollections, the first being the
  1509      results of successfully processing the input PCollection, and the second
  1510      being the set of bad records (those which threw exceptions during
  1511      processing) along with information about the errors raised.
  1512  
  1513      For example, one would write::
  1514  
  1515          good, bad = Map(maybe_error_raising_function).with_exception_handling()
  1516  
  1517      and `good` will be a PCollection of mapped records and `bad` will contain
  1518      those that raised exceptions.
  1519  
  1520  
  1521      Args:
  1522        main_tag: tag to be used for the main (good) output of the DoFn,
  1523            useful to avoid possible conflicts if this DoFn already produces
  1524            multiple outputs.  Optional, defaults to 'good'.
  1525        dead_letter_tag: tag to be used for the bad records, useful to avoid
  1526            possible conflicts if this DoFn already produces multiple outputs.
  1527            Optional, defaults to 'bad'.
  1528        exc_class: An exception class, or tuple of exception classes, to catch.
  1529            Optional, defaults to 'Exception'.
  1530        partial: Whether to emit outputs for an element as they're produced
  1531            (which could result in partial outputs for a ParDo or FlatMap that
  1532            throws an error part way through execution) or buffer all outputs
  1533            until successful processing of the entire element. Optional,
  1534            defaults to False.
  1535        use_subprocess: Whether to execute the DoFn logic in a subprocess. This
  1536            allows one to recover from errors that can crash the calling process
  1537            (e.g. from an underlying C/C++ library causing a segfault), but is
  1538            slower as elements and results must cross a process boundary.  Note
  1539            that this starts up a long-running process that is used to handle
  1540            all the elements (until hard failure, which should be rare) rather
  1541            than a new process per element, so the overhead should be minimal
  1542            (and can be amortized if there's any per-process or per-bundle
  1543            initialization that needs to be done). Optional, defaults to False.
  1544        threshold: An upper bound on the ratio of records that can be bad before
  1545            aborting the entire pipeline. Optional, defaults to 1.0 (meaning
  1546            up to 100% of records can be bad and the pipeline will still succeed).
  1547        threshold_windowing: Event-time windowing to use for threshold. Optional,
  1548            defaults to the windowing of the input.
  1549        timeout: If the element has not finished processing in timeout seconds,
  1550            raise a TimeoutError.  Defaults to None, meaning no time limit.
  1551      """
  1552      args, kwargs = self.raw_side_inputs
  1553      return self.label >> _ExceptionHandlingWrapper(
  1554          self.fn,
  1555          args,
  1556          kwargs,
  1557          main_tag,
  1558          dead_letter_tag,
  1559          exc_class,
  1560          partial,
  1561          use_subprocess,
  1562          threshold,
  1563          threshold_windowing,
  1564          timeout)
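          # A minimal usage sketch for the method above (illustration only): the
          # pipeline and failing input are hypothetical. Each bad record is emitted
          # on the dead letter output as an (element, error information) pair.
          #
          #   import apache_beam as beam
          #
          #   with beam.Pipeline() as p:
          #     results = (
          #         p
          #         | beam.Create(['1', '2', 'not a number'])
          #         | beam.Map(int).with_exception_handling())
          #     parsed = results.good
          #     failures = results.bad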
  1565  
  1566    def default_type_hints(self):
  1567      return self.fn.get_type_hints()
  1568  
  1569    def infer_output_type(self, input_type):
  1570      return self.fn.infer_output_type(input_type)
  1571  
  1572    def infer_batch_converters(self, input_element_type):
  1573      # TODO: Test this code (in batch_dofn_test)
  1574      if self.fn._process_batch_defined:
  1575        input_batch_type = self.fn._get_input_batch_type_normalized(
  1576            input_element_type)
  1577  
  1578        if input_batch_type is None:
  1579          raise TypeError(
  1580              f"process_batch method on {self.fn!r} does not have "
  1581              "an input type annotation")
  1582  
  1583        try:
  1584          # Generate a batch converter to convert between the input type and the
  1585          # (batch) input type of process_batch
  1586          self.fn.input_batch_converter = BatchConverter.from_typehints(
  1587              element_type=input_element_type, batch_type=input_batch_type)
  1588        except TypeError as e:
  1589          raise TypeError(
  1590              "Failed to find a BatchConverter for the input types of DoFn "
  1591              f"{self.fn!r} (element_type={input_element_type!r}, "
  1592              f"batch_type={input_batch_type!r}).") from e
  1593  
  1594      else:
  1595        self.fn.input_batch_converter = None
  1596  
  1597      if self.fn._can_yield_batches:
  1598        output_batch_type = self.fn._get_output_batch_type_normalized(
  1599            input_element_type)
  1600        if output_batch_type is None:
  1601          # TODO: Mention process method in this error
  1602          raise TypeError(
  1603              f"process_batch method on {self.fn!r} does not have "
  1604              "a return type annotation")
  1605  
  1606        # Generate a batch converter to convert between the output type and the
  1607        # (batch) output type of process_batch
  1608        output_element_type = self.infer_output_type(input_element_type)
  1609  
  1610        try:
  1611          self.fn.output_batch_converter = BatchConverter.from_typehints(
  1612              element_type=output_element_type, batch_type=output_batch_type)
  1613        except TypeError as e:
  1614          raise TypeError(
  1615              "Failed to find a BatchConverter for the *output* types of DoFn "
  1616              f"{self.fn!r} (element_type={output_element_type!r}, "
  1617              f"batch_type={output_batch_type!r}). Maybe you need to override "
  1618              "DoFn.infer_output_type to set the output element type?") from e
  1619      else:
  1620        self.fn.output_batch_converter = None
  1621  
  1622    def make_fn(self, fn, has_side_inputs):
  1623      if isinstance(fn, DoFn):
  1624        return fn
  1625      return CallableWrapperDoFn(fn)
  1626  
  1627    def _process_argspec_fn(self):
  1628      return self.fn._process_argspec_fn()
  1629  
  1630    def display_data(self):
  1631      return {
  1632          'fn': DisplayDataItem(self.fn.__class__, label='Transform Function'),
  1633          'fn_dd': self.fn
  1634      }
  1635  
  1636    def expand(self, pcoll):
  1637      # In the case of a stateful DoFn, warn if the key coder is not
  1638      # deterministic.
  1639      if self._signature.is_stateful_dofn():
  1640        kv_type_hint = pcoll.element_type
  1641        if kv_type_hint and kv_type_hint != typehints.Any:
  1642          coder = coders.registry.get_coder(kv_type_hint)
  1643          if not coder.is_kv_coder():
  1644            raise ValueError(
  1645                'Input elements to the transform %s with stateful DoFn must be '
  1646                'key-value pairs.' % self)
  1647          key_coder = coder.key_coder()
  1648        else:
  1649          key_coder = coders.registry.get_coder(typehints.Any)
  1650  
  1651        if not key_coder.is_deterministic():
  1652          _LOGGER.warning(
  1653              'Key coder %s for transform %s with stateful DoFn may not '
  1654              'be deterministic. This may cause incorrect behavior for complex '
  1655              'key types. Consider adding an input type hint for this transform.',
  1656              key_coder,
  1657              self)
  1658  
  1659      if self._signature.is_unbounded_per_element():
  1660        is_bounded = False
  1661      else:
  1662        is_bounded = pcoll.is_bounded
  1663  
  1664      self.infer_batch_converters(pcoll.element_type)
  1665  
  1666      return pvalue.PCollection.from_(pcoll, is_bounded=is_bounded)
  1667  
  1668    def with_outputs(self, *tags, main=None, allow_unknown_tags=None):
  1669      """Returns a tagged tuple allowing access to the outputs of a
  1670      :class:`ParDo`.
  1671  
  1672      The resulting object supports access to the
  1673      :class:`~apache_beam.pvalue.PCollection` associated with a tag
  1674      (e.g. ``o.tag``, ``o[tag]``) and iterating over the available tags
  1675      (e.g. ``for tag in o: ...``).
  1676  
  1677      Args:
  1678        *tags: if non-empty, list of valid tags. If a list of valid tags is given,
  1679          it will be an error to use an undeclared tag later in the pipeline.
  1680        main: the tag to be used for the main output (which will not have a tag
  1681          associated with it).
  1682        allow_unknown_tags: whether to allow output tags not listed in **tags**.
  1683  
  1684      Returns:
  1685        ~apache_beam.pvalue.DoOutputsTuple: An object of type
  1686        :class:`~apache_beam.pvalue.DoOutputsTuple` that bundles together all
  1687        the outputs of a :class:`ParDo` transform and allows accessing the
  1688        individual :class:`~apache_beam.pvalue.PCollection` s for each output
  1689        using an ``object.tag`` syntax.
  1690  
  1691      Raises:
  1692        ValueError: if the tag given for the main output via **main** is also
  1693          listed among the side output **tags**.
  1694      """
  1698      if main in tags:
  1699        raise ValueError(
  1700            'Main output tag %r must be different from side output tags %r.' %
  1701            (main, tags))
  1702      return _MultiParDo(self, tags, main, allow_unknown_tags)
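          # A minimal usage sketch for with_outputs (illustration only): SplitOddEven
          # and the pipeline below are hypothetical.
          #
          #   import apache_beam as beam
          #
          #   class SplitOddEven(beam.DoFn):
          #     def process(self, n):
          #       if n % 2 == 0:
          #         yield n  # main output
          #       else:
          #         yield beam.pvalue.TaggedOutput('odd', n)
          #
          #   with beam.Pipeline() as p:
          #     results = (
          #         p
          #         | beam.Create([1, 2, 3, 4])
          #         | beam.ParDo(SplitOddEven()).with_outputs('odd', main='even'))
          #     evens, odds = results.even, results.odd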
  1703  
  1704    def _do_fn_info(self):
  1705      return DoFnInfo.create(self.fn, self.args, self.kwargs)
  1706  
  1707    def _get_key_and_window_coder(self, named_inputs):
  1708      if named_inputs is None or not self._signature.is_stateful_dofn():
  1709        return None, None
  1710      main_input = list(set(named_inputs.keys()) - set(self.side_inputs))[0]
  1711      input_pcoll = named_inputs[main_input]
  1712      kv_type_hint = input_pcoll.element_type
  1713      if kv_type_hint and kv_type_hint != typehints.Any:
  1714        coder = coders.registry.get_coder(kv_type_hint)
  1715        if not coder.is_kv_coder():
  1716          raise ValueError(
  1717              'Input elements to the transform %s with stateful DoFn must be '
  1718              'key-value pairs.' % self)
  1719        key_coder = coder.key_coder()
  1720      else:
  1721        key_coder = coders.registry.get_coder(typehints.Any)
  1722      window_coder = input_pcoll.windowing.windowfn.get_window_coder()
  1723      return key_coder, window_coder
  1724  
  1725    # typing: PTransform base class does not accept extra_kwargs
  1726    def to_runner_api_parameter(self, context, **extra_kwargs):  # type: ignore[override]
  1727      # type: (PipelineContext, **typing.Any) -> typing.Tuple[str, message.Message]
  1728      assert isinstance(self, ParDo), \
  1729          "expected instance of ParDo, but got %s" % self.__class__
  1730      state_specs, timer_specs = userstate.get_dofn_specs(self.fn)
  1731      if state_specs or timer_specs:
  1732        context.add_requirement(
  1733            common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn)
  1734      from apache_beam.runners.common import DoFnSignature
  1735      sig = DoFnSignature(self.fn)
  1736      is_splittable = sig.is_splittable_dofn()
  1737      if is_splittable:
  1738        restriction_coder = sig.get_restriction_coder()
  1739        # restriction_coder will never be None when is_splittable is True
  1740        assert restriction_coder is not None
  1741        restriction_coder_id = context.coders.get_id(
  1742            restriction_coder)  # type: typing.Optional[str]
  1743        context.add_requirement(
  1744            common_urns.requirements.REQUIRES_SPLITTABLE_DOFN.urn)
  1745      else:
  1746        restriction_coder_id = None
  1747      has_bundle_finalization = sig.has_bundle_finalization()
  1748      if has_bundle_finalization:
  1749        context.add_requirement(
  1750            common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn)
  1751  
  1752      # Get key_coder and window_coder for main_input.
  1753      key_coder, window_coder = self._get_key_and_window_coder(
  1754          extra_kwargs.get('named_inputs', None))
  1755      return (
  1756          common_urns.primitives.PAR_DO.urn,
  1757          beam_runner_api_pb2.ParDoPayload(
  1758              do_fn=self._do_fn_info().to_runner_api(context),
  1759              requests_finalization=has_bundle_finalization,
  1760              restriction_coder_id=restriction_coder_id,
  1761              state_specs={
  1762                  spec.name: spec.to_runner_api(context)
  1763                  for spec in state_specs
  1764              },
  1765              timer_family_specs={
  1766                  spec.name: spec.to_runner_api(context, key_coder, window_coder)
  1767                  for spec in timer_specs
  1768              },
  1769              # It'd be nice to name these according to their actual
  1770              # names/positions in the original argument list, but such a
  1771              # transformation is currently irreversible given how
  1772              # remove_objects_from_args and insert_values_in_args
  1773              # are currently implemented.
  1774              side_inputs={(SIDE_INPUT_PREFIX + '%s') % ix:
  1775                           si.to_runner_api(context)
  1776                           for ix,
  1777                           si in enumerate(self.side_inputs)}))
  1778  
  1779    @staticmethod
  1780    @PTransform.register_urn(
  1781        common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload)
  1782    def from_runner_api_parameter(unused_ptransform, pardo_payload, context):
  1783      fn, args, kwargs, si_tags_and_types, windowing = pickler.loads(
  1784          DoFnInfo.from_runner_api(
  1785              pardo_payload.do_fn, context).serialized_dofn_data())
  1786      if si_tags_and_types:
  1787        raise NotImplementedError('explicit side input data')
  1788      elif windowing:
  1789        raise NotImplementedError('explicit windowing')
  1790      result = ParDo(fn, *args, **kwargs)
  1791      # This is an ordered list stored as a dict (see the comments in
  1792      # to_runner_api_parameter above).
  1793      indexed_side_inputs = [(
  1794          get_sideinput_index(tag),
  1795          pvalue.AsSideInput.from_runner_api(si, context)) for tag,
  1796                             si in pardo_payload.side_inputs.items()]
  1797      result.side_inputs = [si for _, si in sorted(indexed_side_inputs)]
  1798      return result
  1799  
  1800    def runner_api_requires_keyed_input(self):
  1801      return userstate.is_stateful_dofn(self.fn)
  1802  
  1803    def get_restriction_coder(self):
  1804      """Returns the restriction coder if the `DoFn` of this `ParDo` is an SDF.
  1805  
  1806      Returns `None` otherwise.
  1807      """
  1808      from apache_beam.runners.common import DoFnSignature
  1809      return DoFnSignature(self.fn).get_restriction_coder()
  1810  
  1811    def _add_type_constraint_from_consumer(self, full_label, input_type_hints):
  1812      if not hasattr(self.fn, '_runtime_output_constraints'):
  1813        self.fn._runtime_output_constraints = {}
  1814      self.fn._runtime_output_constraints[full_label] = input_type_hints
  1815  
  1816  
  1817  class _MultiParDo(PTransform):
  1818    def __init__(self, do_transform, tags, main_tag, allow_unknown_tags=None):
  1819      super().__init__(do_transform.label)
  1820      self._do_transform = do_transform
  1821      self._tags = tags
  1822      self._main_tag = main_tag
  1823      self._allow_unknown_tags = allow_unknown_tags
  1824  
  1825    def expand(self, pcoll):
  1826      _ = pcoll | self._do_transform
  1827      return pvalue.DoOutputsTuple(
  1828          pcoll.pipeline,
  1829          self._do_transform,
  1830          self._tags,
  1831          self._main_tag,
  1832          self._allow_unknown_tags)
  1833  
  1834  
  1835  class DoFnInfo(object):
  1836    """This class represents the state in the ParDoPayload's function spec,
  1837    which is the actual DoFn together with some data required for invoking it.
  1838    """
  1839    @staticmethod
  1840    def register_stateless_dofn(urn):
  1841      def wrapper(cls):
  1842        StatelessDoFnInfo.REGISTERED_DOFNS[urn] = cls
  1843        cls._stateless_dofn_urn = urn
  1844        return cls
  1845  
  1846      return wrapper
  1847  
  1848    @classmethod
  1849    def create(cls, fn, args, kwargs):
  1850      if hasattr(fn, '_stateless_dofn_urn'):
  1851        assert not args and not kwargs
  1852        return StatelessDoFnInfo(fn._stateless_dofn_urn)
  1853      else:
  1854        return PickledDoFnInfo(cls._pickled_do_fn_info(fn, args, kwargs))
  1855  
  1856    @staticmethod
  1857    def from_runner_api(spec, unused_context):
  1858      if spec.urn == python_urns.PICKLED_DOFN_INFO:
  1859        return PickledDoFnInfo(spec.payload)
  1860      elif spec.urn in StatelessDoFnInfo.REGISTERED_DOFNS:
  1861        return StatelessDoFnInfo(spec.urn)
  1862      else:
  1863        raise ValueError('Unexpected DoFn type: %s' % spec.urn)
  1864  
  1865    @staticmethod
  1866    def _pickled_do_fn_info(fn, args, kwargs):
  1867      # This can be cleaned up once all runners move to portability.
  1868      return pickler.dumps((fn, args, kwargs, None, None))
  1869  
  1870    def serialized_dofn_data(self):
  1871      raise NotImplementedError(type(self))
  1872  
  1873  
  1874  class PickledDoFnInfo(DoFnInfo):
  1875    def __init__(self, serialized_data):
  1876      self._serialized_data = serialized_data
  1877  
  1878    def serialized_dofn_data(self):
  1879      return self._serialized_data
  1880  
  1881    def to_runner_api(self, unused_context):
  1882      return beam_runner_api_pb2.FunctionSpec(
  1883          urn=python_urns.PICKLED_DOFN_INFO, payload=self._serialized_data)
  1884  
  1885  
  1886  class StatelessDoFnInfo(DoFnInfo):
  1887  
  1888    REGISTERED_DOFNS = {}  # type: typing.Dict[str, typing.Type[DoFn]]
  1889  
  1890    def __init__(self, urn):
  1891      # type: (str) -> None
  1892      assert urn in self.REGISTERED_DOFNS
  1893      self._urn = urn
  1894  
  1895    def serialized_dofn_data(self):
  1896      return self._pickled_do_fn_info(self.REGISTERED_DOFNS[self._urn](), (), {})
  1897  
  1898    def to_runner_api(self, unused_context):
  1899      return beam_runner_api_pb2.FunctionSpec(urn=self._urn)
  1900  
  1901  
  1902  def FlatMap(fn, *args, **kwargs):  # pylint: disable=invalid-name
  1903    """:func:`FlatMap` is like :class:`ParDo` except it takes a callable to
  1904    specify the transformation.
  1905  
  1906    The callable must return an iterable for each element of the input
  1907    :class:`~apache_beam.pvalue.PCollection`. The elements of these iterables will
  1908    be flattened into the output :class:`~apache_beam.pvalue.PCollection`.
  1909  
  1910    Args:
  1911      fn (callable): a callable object.
  1912      *args: positional arguments passed to the transform callable.
  1913      **kwargs: keyword arguments passed to the transform callable.
  1914  
  1915    Returns:
  1916      ~apache_beam.pvalue.PCollection:
  1917      A :class:`~apache_beam.pvalue.PCollection` containing the
  1918      :func:`FlatMap` outputs.
  1919  
  1920    Raises:
  1921      TypeError: If the **fn** passed as argument is not a callable.
  1922        Typical error is to pass a :class:`DoFn` instance which is supported only
  1923        for :class:`ParDo`.
  1924    """
  1925    label = 'FlatMap(%s)' % ptransform.label_from_callable(fn)
  1926    if not callable(fn):
  1927      raise TypeError(
  1928          'FlatMap can be used only with callable objects. '
  1929          'Received %r instead.' % (fn))
  1930  
  1931    pardo = ParDo(CallableWrapperDoFn(fn), *args, **kwargs)
  1932    pardo.label = label
  1933    return pardo
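        # A minimal usage sketch (illustration only): the callable returns an
        # iterable per element and the results are flattened into a single output
        # PCollection. The pipeline and data here are hypothetical.
        #
        #   import apache_beam as beam
        #
        #   with beam.Pipeline() as p:
        #     words = (
        #         p
        #         | beam.Create(['the quick brown fox', 'jumps over'])
        #         | beam.FlatMap(lambda line: line.split()))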
  1934  
  1935  
  1936  def Map(fn, *args, **kwargs):  # pylint: disable=invalid-name
  1937    """:func:`Map` is like :func:`FlatMap` except its callable returns only a
  1938    single element.
  1939  
  1940    Args:
  1941      fn (callable): a callable object.
  1942      *args: positional arguments passed to the transform callable.
  1943      **kwargs: keyword arguments passed to the transform callable.
  1944  
  1945    Returns:
  1946      ~apache_beam.pvalue.PCollection:
  1947      A :class:`~apache_beam.pvalue.PCollection` containing the
  1948      :func:`Map` outputs.
  1949  
  1950    Raises:
  1951      TypeError: If the **fn** passed as argument is not a callable.
  1952        Typical error is to pass a :class:`DoFn` instance which is supported only
  1953        for :class:`ParDo`.
  1954    """
  1955    if not callable(fn):
  1956      raise TypeError(
  1957          'Map can be used only with callable objects. '
  1958          'Received %r instead.' % (fn))
  1959    from apache_beam.transforms.util import fn_takes_side_inputs
  1960    if fn_takes_side_inputs(fn):
  1961      wrapper = lambda x, *args, **kwargs: [fn(x, *args, **kwargs)]
  1962    else:
  1963      wrapper = lambda x: [fn(x)]
  1964  
  1965    label = 'Map(%s)' % ptransform.label_from_callable(fn)
  1966  
  1967    # TODO: What about callable classes?
  1968    if hasattr(fn, '__name__'):
  1969      wrapper.__name__ = fn.__name__
  1970  
  1971    # Proxy the type-hint information from the original function to this new
  1972    # wrapped function.
  1973    type_hints = get_type_hints(fn).with_defaults(
  1974        typehints.decorators.IOTypeHints.from_callable(fn))
  1975    if type_hints.input_types is not None:
  1976      wrapper = with_input_types(
  1977          *type_hints.input_types[0], **type_hints.input_types[1])(
  1978              wrapper)
  1979    output_hint = type_hints.simple_output_type(label)
  1980    if output_hint:
  1981      wrapper = with_output_types(
  1982          typehints.Iterable[_strip_output_annotations(output_hint)])(
  1983              wrapper)
  1984    # pylint: disable=protected-access
  1985    wrapper._argspec_fn = fn
  1986    # pylint: enable=protected-access
  1987  
  1988    pardo = FlatMap(wrapper, *args, **kwargs)
  1989    pardo.label = label
  1990    return pardo
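        # A minimal usage sketch (illustration only): Map applies a one-to-one
        # callable to each element. The pipeline and data here are hypothetical.
        #
        #   import apache_beam as beam
        #
        #   with beam.Pipeline() as p:
        #     lengths = (
        #         p
        #         | beam.Create(['a', 'bb', 'ccc'])
        #         | beam.Map(len))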
  1991  
  1992  
  1993  def MapTuple(fn, *args, **kwargs):  # pylint: disable=invalid-name
  1994    r""":func:`MapTuple` is like :func:`Map` but expects tuple inputs and
  1995    flattens them into multiple input arguments.
  1996  
  1997        beam.MapTuple(lambda a, b, ...: ...)
  1998  
  1999    In other words
  2000  
  2001        beam.MapTuple(fn)
  2002  
  2003    is equivalent to
  2004  
  2005        beam.Map(lambda element, ...: fn(\*element, ...))
  2006  
  2007    This can be useful when processing a PCollection of tuples
  2008    (e.g. key-value pairs).
  2009  
  2010    Args:
  2011      fn (callable): a callable object.
  2012      *args: positional arguments passed to the transform callable.
  2013      **kwargs: keyword arguments passed to the transform callable.
  2014  
  2015    Returns:
  2016      ~apache_beam.pvalue.PCollection:
  2017      A :class:`~apache_beam.pvalue.PCollection` containing the
  2018      :func:`MapTuple` outputs.
  2019  
  2020    Raises:
  2021      TypeError: If the **fn** passed as argument is not a callable.
  2022        Typical error is to pass a :class:`DoFn` instance which is supported only
  2023        for :class:`ParDo`.
  2024    """
  2025    if not callable(fn):
  2026      raise TypeError(
  2027          'MapTuple can be used only with callable objects. '
  2028          'Received %r instead.' % (fn))
  2029  
  2030    label = 'MapTuple(%s)' % ptransform.label_from_callable(fn)
  2031  
  2032    arg_names, defaults = get_function_args_defaults(fn)
  2033    num_defaults = len(defaults)
  2034    if num_defaults < len(args) + len(kwargs):
  2035      raise TypeError('Side inputs must have defaults for MapTuple.')
  2036  
  2037    if defaults or args or kwargs:
  2038      wrapper = lambda x, *args, **kwargs: [fn(*(tuple(x) + args), **kwargs)]
  2039    else:
  2040      wrapper = lambda x: [fn(*x)]
  2041  
  2042    # Proxy the type-hint information from the original function to this new
  2043    # wrapped function.
  2044    type_hints = get_type_hints(fn).with_defaults(
  2045        typehints.decorators.IOTypeHints.from_callable(fn))
  2046    if type_hints.input_types is not None:
  2047      # TODO(BEAM-14052): ignore input hints, as we do not have enough
  2048      # information to infer the input type hint of the wrapper function.
  2049      pass
  2050    output_hint = type_hints.simple_output_type(label)
  2051    if output_hint:
  2052      wrapper = with_output_types(
  2053          typehints.Iterable[_strip_output_annotations(output_hint)])(
  2054              wrapper)
  2055  
  2056    # Replace the first (args) component.
  2057    modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:]
  2058    modified_argspec = (modified_arg_names, defaults)
  2059    pardo = ParDo(
  2060        CallableWrapperDoFn(wrapper, fullargspec=modified_argspec),
  2061        *args,
  2062        **kwargs)
  2063    pardo.label = label
  2064    return pardo
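        # A minimal usage sketch (illustration only): each input tuple is unpacked
        # into the callable's arguments. The pipeline and data here are hypothetical.
        #
        #   import apache_beam as beam
        #
        #   with beam.Pipeline() as p:
        #     formatted = (
        #         p
        #         | beam.Create([('a', 1), ('b', 2)])
        #         | beam.MapTuple(lambda k, v: '%s: %d' % (k, v)))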
  2065  
  2066  
  2067  def FlatMapTuple(fn, *args, **kwargs):  # pylint: disable=invalid-name
  2068    r""":func:`FlatMapTuple` is like :func:`FlatMap` but expects tuple inputs and
  2069    flattens them into multiple input arguments.
  2070  
  2071        beam.FlatMapTuple(lambda a, b, ...: ...)
  2072  
  2073    is equivalent to the Python 2 form
  2074  
  2075        beam.FlatMap(lambda (a, b, ...), ...: ...)
  2076  
  2077    In other words
  2078  
  2079        beam.FlatMapTuple(fn)
  2080  
  2081    is equivalent to
  2082  
  2083        beam.FlatMap(lambda element, ...: fn(\*element, ...))
  2084  
  2085    This can be useful when processing a PCollection of tuples
  2086    (e.g. key-value pairs).
  2087  
  2088    Args:
  2089      fn (callable): a callable object.
  2090      *args: positional arguments passed to the transform callable.
  2091      **kwargs: keyword arguments passed to the transform callable.
  2092  
  2093    Returns:
  2094      ~apache_beam.pvalue.PCollection:
  2095      A :class:`~apache_beam.pvalue.PCollection` containing the
  2096      :func:`FlatMapTuple` outputs.
  2097  
  2098    Raises:
  2099      TypeError: If the **fn** passed as argument is not a callable.
  2100        Typical error is to pass a :class:`DoFn` instance which is supported only
  2101        for :class:`ParDo`.
  2102    """
  2103    if not callable(fn):
  2104      raise TypeError(
  2105          'FlatMapTuple can be used only with callable objects. '
  2106          'Received %r instead.' % (fn))
  2107  
  2108    label = 'FlatMapTuple(%s)' % ptransform.label_from_callable(fn)
  2109  
  2110    arg_names, defaults = get_function_args_defaults(fn)
  2111    num_defaults = len(defaults)
  2112    if num_defaults < len(args) + len(kwargs):
  2113      raise TypeError('Side inputs must have defaults for FlatMapTuple.')
  2114  
  2115    if defaults or args or kwargs:
  2116      wrapper = lambda x, *args, **kwargs: fn(*(tuple(x) + args), **kwargs)
  2117    else:
  2118      wrapper = lambda x: fn(*x)
  2119  
  2120    # Proxy the type-hint information from the original function to this new
  2121    # wrapped function.
  2122    type_hints = get_type_hints(fn).with_defaults(
  2123        typehints.decorators.IOTypeHints.from_callable(fn))
  2124    if type_hints.input_types is not None:
  2125      # TODO(BEAM-14052): ignore input hints, as we do not have enough
  2126      # information to infer the input type hint of the wrapper function.
  2127      pass
  2128    output_hint = type_hints.simple_output_type(label)
  2129    if output_hint:
  2130      wrapper = with_output_types(_strip_output_annotations(output_hint))(wrapper)
  2131  
  2132    # Replace the first (args) component.
  2133    modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:]
  2134    modified_argspec = (modified_arg_names, defaults)
  2135    pardo = ParDo(
  2136        CallableWrapperDoFn(wrapper, fullargspec=modified_argspec),
  2137        *args,
  2138        **kwargs)
  2139    pardo.label = label
  2140    return pardo
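        # A minimal usage sketch (illustration only): like MapTuple, but the callable
        # returns an iterable whose elements are flattened into the output. The
        # pipeline and data here are hypothetical.
        #
        #   import apache_beam as beam
        #
        #   with beam.Pipeline() as p:
        #     pairs = (
        #         p
        #         | beam.Create([('a', [1, 2]), ('b', [3])])
        #         | beam.FlatMapTuple(lambda k, vs: ((k, v) for v in vs)))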
  2141  
  2142  
  2143  class _ExceptionHandlingWrapper(ptransform.PTransform):
  2144    """Implementation of ParDo.with_exception_handling."""
  2145    def __init__(
  2146        self,
  2147        fn,
  2148        args,
  2149        kwargs,
  2150        main_tag,
  2151        dead_letter_tag,
  2152        exc_class,
  2153        partial,
  2154        use_subprocess,
  2155        threshold,
  2156        threshold_windowing,
  2157        timeout):
  2158      if partial and use_subprocess:
  2159        raise ValueError('partial and use_subprocess are mutually incompatible.')
  2160      self._fn = fn
  2161      self._args = args
  2162      self._kwargs = kwargs
  2163      self._main_tag = main_tag
  2164      self._dead_letter_tag = dead_letter_tag
  2165      self._exc_class = exc_class
  2166      self._partial = partial
  2167      self._use_subprocess = use_subprocess
  2168      self._threshold = threshold
  2169      self._threshold_windowing = threshold_windowing
  2170      self._timeout = timeout
  2171  
  2172    def expand(self, pcoll):
  2173      if self._use_subprocess:
  2174        wrapped_fn = _SubprocessDoFn(self._fn, timeout=self._timeout)
  2175      elif self._timeout:
  2176        wrapped_fn = _TimeoutDoFn(self._fn, timeout=self._timeout)
  2177      else:
  2178        wrapped_fn = self._fn
  2179      result = pcoll | ParDo(
  2180          _ExceptionHandlingWrapperDoFn(
  2181              wrapped_fn, self._dead_letter_tag, self._exc_class, self._partial),
  2182          *self._args,
  2183          **self._kwargs).with_outputs(
  2184              self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
  2185  
  2186      if self._threshold < 1.0:
  2187  
  2188        class MaybeWindow(ptransform.PTransform):
  2189          @staticmethod
  2190          def expand(pcoll):
  2191            if self._threshold_windowing:
  2192              return pcoll | WindowInto(self._threshold_windowing)
  2193            else:
  2194              return pcoll
  2195  
  2196        input_count_view = pcoll | 'CountTotal' >> (
  2197            MaybeWindow() | Map(lambda _: 1)
  2198            | CombineGlobally(sum).as_singleton_view())
  2199        bad_count_pcoll = result[self._dead_letter_tag] | 'CountBad' >> (
  2200            MaybeWindow() | Map(lambda _: 1)
  2201            | CombineGlobally(sum).without_defaults())
  2202  
  2203        def check_threshold(bad, total, threshold, window=DoFn.WindowParam):
  2204          if bad > total * threshold:
  2205            raise ValueError(
  2206                'The number of failing elements within the window %r '
  2207                'exceeded threshold: %s / %s = %s > %s' %
  2208                (window, bad, total, bad / total, threshold))
  2209  
  2210        _ = bad_count_pcoll | Map(
  2211            check_threshold, input_count_view, self._threshold)
  2212  
  2213      return result
  2214  
  2215  
  2216  class _ExceptionHandlingWrapperDoFn(DoFn):
  2217    def __init__(self, fn, dead_letter_tag, exc_class, partial):
  2218      self._fn = fn
  2219      self._dead_letter_tag = dead_letter_tag
  2220      self._exc_class = exc_class
  2221      self._partial = partial
  2222  
  2223    def __getattribute__(self, name):
  2224      if (name.startswith('__') or name in self.__dict__ or
  2225          name in _ExceptionHandlingWrapperDoFn.__dict__):
  2226        return object.__getattribute__(self, name)
  2227      else:
  2228        return getattr(self._fn, name)
  2229  
  2230    def process(self, *args, **kwargs):
  2231      try:
  2232        result = self._fn.process(*args, **kwargs)
  2233        if not self._partial:
  2234          # Don't emit any results until we know there will be no errors.
  2235          result = list(result)
  2236        yield from result
  2237      except self._exc_class as exn:
  2238        yield pvalue.TaggedOutput(
  2239            self._dead_letter_tag,
  2240            (
  2241                args[0], (
  2242                    type(exn),
  2243                    repr(exn),
  2244                    traceback.format_exception(*sys.exc_info()))))
  2245  
  2246  
  2247  class _SubprocessDoFn(DoFn):
  2248    """Process method run in a subprocess, turning hard crashes into exceptions.
  2249    """
  2250    def __init__(self, fn, timeout=None):
  2251      self._fn = fn
  2252      self._serialized_fn = pickler.dumps(fn)
  2253      self._timeout = timeout
  2254  
  2255    def __getattribute__(self, name):
  2256      if (name.startswith('__') or name in self.__dict__ or
  2257          name in type(self).__dict__):
  2258        return object.__getattribute__(self, name)
  2259      else:
  2260        return getattr(self._fn, name)
  2261  
  2262    def setup(self):
  2263      self._pool = None
  2264  
  2265    def start_bundle(self):
  2266      # The pool is initialized lazily, including calls to setup and start_bundle.
  2267      # This allows us to continue processing elements after a crash.
  2268      pass
  2269  
  2270    def process(self, *args, **kwargs):
  2271      return self._call_remote(self._remote_process, *args, **kwargs)
  2272  
  2273    def finish_bundle(self):
  2274      self._call_remote(self._remote_finish_bundle)
  2275  
  2276    def teardown(self):
  2277      self._call_remote(self._remote_teardown)
  2278      self._terminate_pool()
  2279  
  2280    def _call_remote(self, method, *args, **kwargs):
  2281      if self._pool is None:
  2282        self._pool = concurrent.futures.ProcessPoolExecutor(1)
  2283        self._pool.submit(self._remote_init, self._serialized_fn).result()
  2284      try:
  2285        return self._pool.submit(method, *args, **kwargs).result(
  2286            self._timeout if method == self._remote_process else None)
  2287      except (concurrent.futures.process.BrokenProcessPool,
  2288              TimeoutError,
  2289              concurrent.futures._base.TimeoutError):
  2290        self._terminate_pool()
  2291        raise
  2292  
  2293    def _terminate_pool(self):
  2294      """Forcibly terminate the pool, not leaving any live subprocesses."""
  2295      pool = self._pool
  2296      self._pool = None
  2297      processes = list(pool._processes.values())
  2298      pool.shutdown(wait=False)
  2299      for p in processes:
  2300        if p.is_alive():
  2301          p.kill()
  2302      time.sleep(1)
  2303      for p in processes:
  2304        if p.is_alive():
  2305          p.terminate()
  2306  
  2307    # These are classmethods to avoid pickling the state of self.
  2308    # They should only be called in an isolated process, so there's no concern
  2309    # about sharing state or thread safety.
  2310  
  2311    @classmethod
  2312    def _remote_init(cls, serialized_fn):
  2313      cls._serialized_fn = serialized_fn
  2314      cls._fn = None
  2315      cls._started = False
  2316  
  2317    @classmethod
  2318    def _remote_process(cls, *args, **kwargs):
  2319      if cls._fn is None:
  2320        cls._fn = pickler.loads(cls._serialized_fn)
  2321        cls._fn.setup()
  2322      if not cls._started:
  2323        cls._fn.start_bundle()
  2324        cls._started = True
  2325      result = cls._fn.process(*args, **kwargs)
  2326      if result:
  2327        # Don't return generator objects.
  2328        result = list(result)
  2329      return result
  2330  
  2331    @classmethod
  2332    def _remote_finish_bundle(cls):
  2333      if cls._started:
  2334        cls._started = False
  2335        if cls._fn.finish_bundle():
  2336          # This is because we restart and re-initialize the pool if it crashed.
  2337          raise RuntimeError(
  2338              "Returning elements from _SubprocessDoFn.finish_bundle not safe.")
  2339  
  2340    @classmethod
  2341    def _remote_teardown(cls):
  2342      if cls._fn:
  2343        cls._fn.teardown()
  2344      cls._fn = None
  2345  
  2346  
  2347  class _TimeoutDoFn(DoFn):
  2348    """Process method run in a separate thread allowing timeouts.
  2349    """
  2350    def __init__(self, fn, timeout=None):
  2351      self._fn = fn
  2352      self._timeout = timeout
  2353      self._pool = None
  2354  
  2355    def __getattribute__(self, name):
  2356      if (name.startswith('__') or name in self.__dict__ or
  2357          name in type(self).__dict__):
  2358        return object.__getattribute__(self, name)
  2359      else:
  2360        return getattr(self._fn, name)
  2361  
  2362    def process(self, *args, **kwargs):
  2363      if self._pool is None:
  2364        self._pool = concurrent.futures.ThreadPoolExecutor(10)
  2365      # Ensure we iterate over the entire output list in the given amount of time.
  2366      try:
  2367        return self._pool.submit(
  2368            lambda: list(self._fn.process(*args, **kwargs))).result(
  2369                self._timeout)
  2370      except TimeoutError:
  2371        self._pool.shutdown(wait=False)
  2372        self._pool = None
  2373        raise
  2374  
  2375    def teardown(self):
  2376      try:
  2377        self._fn.teardown()
  2378      finally:
  2379        if self._pool is not None:
  2380          self._pool.shutdown(wait=False)
  2381          self._pool = None
  2382  
  2383  
  2384  def Filter(fn, *args, **kwargs):  # pylint: disable=invalid-name
  2385    """:func:`Filter` is a :func:`FlatMap` with its callable filtering out
  2386    elements.
  2387  
  2388    Filter accepts a function that is applied to each element; elements for
  2389    which it returns True are kept, and the remaining elements are filtered out.
  2390  
  2391    Args:
  2392      fn (``Callable[..., bool]``): a callable object. First argument will be an
  2393        element.
  2394      *args: positional arguments passed to the transform callable.
  2395      **kwargs: keyword arguments passed to the transform callable.
  2396  
  2397    Returns:
  2398      ~apache_beam.pvalue.PCollection:
  2399      A :class:`~apache_beam.pvalue.PCollection` containing the
  2400      :func:`Filter` outputs.
  2401  
  2402    Raises:
  2403      TypeError: If the **fn** passed as argument is not a callable.
  2404        Typical error is to pass a :class:`DoFn` instance which is supported only
  2405        for :class:`ParDo`.
  2406    """
  2407    if not callable(fn):
  2408      raise TypeError(
  2409          'Filter can be used only with callable objects. '
  2410          'Received %r instead.' % (fn))
  2411    wrapper = lambda x, *args, **kwargs: [x] if fn(x, *args, **kwargs) else []
  2412  
  2413    label = 'Filter(%s)' % ptransform.label_from_callable(fn)
  2414  
  2415    # TODO: What about callable classes?
  2416    if hasattr(fn, '__name__'):
  2417      wrapper.__name__ = fn.__name__
  2418  
  2419    # Get type hints from this instance or the callable. Do not use output type
  2420    # hints from the callable (which should be bool if set).
  2421    fn_type_hints = typehints.decorators.IOTypeHints.from_callable(fn)
  2422    if fn_type_hints is not None:
  2423      fn_type_hints = fn_type_hints.with_output_types()
  2424    type_hints = get_type_hints(fn).with_defaults(fn_type_hints)
  2425  
  2426    # Proxy the type-hint information from the function being wrapped, setting the
  2427    # output type to be the same as the input type.
  2428    if type_hints.input_types is not None:
  2429      wrapper = with_input_types(
  2430          *type_hints.input_types[0], **type_hints.input_types[1])(
  2431              wrapper)
  2432    output_hint = type_hints.simple_output_type(label)
  2433    if (output_hint is None and get_type_hints(wrapper).input_types and
  2434        get_type_hints(wrapper).input_types[0]):
  2435      output_hint = get_type_hints(wrapper).input_types[0][0]
  2436    if output_hint:
  2437      wrapper = with_output_types(
  2438          typehints.Iterable[_strip_output_annotations(output_hint)])(
  2439              wrapper)
  2440    # pylint: disable=protected-access
  2441    wrapper._argspec_fn = fn
  2442    # pylint: enable=protected-access
  2443  
  2444    pardo = FlatMap(wrapper, *args, **kwargs)
  2445    pardo.label = label
  2446    return pardo
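        # A minimal usage sketch (illustration only): elements for which the
        # predicate returns True are kept. The pipeline and data here are
        # hypothetical.
        #
        #   import apache_beam as beam
        #
        #   with beam.Pipeline() as p:
        #     evens = (
        #         p
        #         | beam.Create([1, 2, 3, 4, 5])
        #         | beam.Filter(lambda n: n % 2 == 0))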
  2447  
  2448  
  2449  def _combine_payload(combine_fn, context):
  2450    return beam_runner_api_pb2.CombinePayload(
  2451        combine_fn=combine_fn.to_runner_api(context),
  2452        accumulator_coder_id=context.coders.get_id(
  2453            combine_fn.get_accumulator_coder()))
  2454  
  2455  
  2456  class CombineGlobally(PTransform):
  2457    """A :class:`CombineGlobally` transform.
  2458  
  2459    Reduces a :class:`~apache_beam.pvalue.PCollection` to a single value by
  2460    progressively applying a :class:`CombineFn` to portions of the
  2461    :class:`~apache_beam.pvalue.PCollection` (and to intermediate values created
  2462    thereby). See documentation in :class:`CombineFn` for details on the specifics
  2463    thereby). See the documentation in :class:`CombineFn` for details on how
  2464    :class:`CombineFn` s are applied.
  2465    Args:
  2466      pcoll (~apache_beam.pvalue.PCollection):
  2467        a :class:`~apache_beam.pvalue.PCollection` to be reduced into a single
  2468        value.
  2469      fn (callable): a :class:`CombineFn` object that will be called to
  2470        progressively reduce the :class:`~apache_beam.pvalue.PCollection` into
  2471        single values, or a callable suitable for wrapping by
  2472        :class:`~apache_beam.transforms.core.CallableWrapperCombineFn`.
  2473      *args: positional arguments passed to the :class:`CombineFn` object.
  2474      **kwargs: keyword arguments passed to the :class:`CombineFn` object.
  2475  
  2476    Raises:
  2477      TypeError: If the output type of the input
  2478        :class:`~apache_beam.pvalue.PCollection` is not compatible
  2479        with ``Iterable[A]``.
  2480  
  2481    Returns:
  2482      ~apache_beam.pvalue.PCollection: A single-element
  2483      :class:`~apache_beam.pvalue.PCollection` containing the main output of
  2484      the :class:`CombineGlobally` transform.
  2485  
  2486    Note that the positional and keyword arguments will be processed in order
  2487    to detect :class:`~apache_beam.pvalue.PValue` s that will be computed as side
  2488    inputs to the transform.
  2489    During pipeline execution whenever the :class:`CombineFn` object gets executed
  2490    (i.e. any of the :class:`CombineFn` methods get called), the
  2491    :class:`~apache_beam.pvalue.PValue` arguments will be replaced by their
  2492    actual value in the exact position where they appear in the argument lists.
  2493    """
  2494    has_defaults = True
  2495    as_view = False
  2496    fanout = None  # type: typing.Optional[int]
  2497  
  2498    def __init__(self, fn, *args, **kwargs):
  2499      if not (isinstance(fn, CombineFn) or callable(fn)):
  2500        raise TypeError(
  2501            'CombineGlobally can be used only with CombineFn objects or callables. '
  2502            'Received %r instead.' % (fn))
  2503  
  2504      super().__init__()
  2505      self.fn = fn
  2506      self.args = args
  2507      self.kwargs = kwargs
  2508  
  2509    def display_data(self):
  2510      return {
  2511          'combine_fn': DisplayDataItem(
  2512              self.fn.__class__, label='Combine Function'),
  2513          'combine_fn_dd': self.fn,
  2514      }
  2515  
  2516    def default_label(self):
  2517      if self.fanout is None:
  2518        return '%s(%s)' % (
  2519            self.__class__.__name__, ptransform.label_from_callable(self.fn))
  2520      else:
  2521        return '%s(%s, fanout=%s)' % (
  2522            self.__class__.__name__,
  2523            ptransform.label_from_callable(self.fn),
  2524            self.fanout)
  2525  
  2526    def _clone(self, **extra_attributes):
  2527      clone = copy.copy(self)
  2528      clone.__dict__.update(extra_attributes)
  2529      return clone
  2530  
  2531    def with_fanout(self, fanout):
  2532      return self._clone(fanout=fanout)
  2533  
  2534    def with_defaults(self, has_defaults=True):
  2535      return self._clone(has_defaults=has_defaults)
  2536  
  2537    def without_defaults(self):
  2538      return self.with_defaults(False)
  2539  
  2540    def as_singleton_view(self):
  2541      return self._clone(as_view=True)
  2542  
  2543    def expand(self, pcoll):
  2544      def add_input_types(transform):
  2545        type_hints = self.get_type_hints()
  2546        if type_hints.input_types:
  2547          return transform.with_input_types(type_hints.input_types[0][0])
  2548        return transform
  2549  
  2550      combine_fn = CombineFn.maybe_from_callable(
  2551          self.fn, has_side_inputs=self.args or self.kwargs)
  2552      combine_per_key = CombinePerKey(combine_fn, *self.args, **self.kwargs)
  2553      if self.fanout:
  2554        combine_per_key = combine_per_key.with_hot_key_fanout(self.fanout)
  2555  
  2556      combined = (
  2557          pcoll
  2558          | 'KeyWithVoid' >> add_input_types(
  2559              ParDo(_KeyWithNone()).with_output_types(
  2560                  typehints.KV[None, pcoll.element_type]))
  2561          | 'CombinePerKey' >> combine_per_key
  2562          | 'UnKey' >> Map(lambda k_v: k_v[1]))
  2563  
  2564      if not self.has_defaults and not self.as_view:
  2565        return combined
  2566  
  2567      elif self.as_view:
  2568        if self.has_defaults:
  2569          try:
  2570            combine_fn.setup(*self.args, **self.kwargs)
  2571            # This is called in the main program, but cannot be avoided
  2572            # in the as_view case as it must be available to all windows.
  2573            default_value = combine_fn.apply([], *self.args, **self.kwargs)
  2574          finally:
  2575            combine_fn.teardown(*self.args, **self.kwargs)
  2576        else:
  2577          default_value = pvalue.AsSingleton._NO_DEFAULT
  2578        return pvalue.AsSingleton(combined, default_value=default_value)
  2579  
  2580      else:
  2581        if pcoll.windowing.windowfn != GlobalWindows():
  2582          raise ValueError(
  2583              "Default values are not yet supported in CombineGlobally() if the "
  2584              "output  PCollection is not windowed by GlobalWindows. "
  2585              "Instead, use CombineGlobally().without_defaults() to output "
  2586              "an empty PCollection if the input PCollection is empty, "
  2587              "or CombineGlobally().as_singleton_view() to get the default "
  2588              "output of the CombineFn if the input PCollection is empty.")
  2589  
  2590        def typed(transform):
  2591          # TODO(robertwb): We should infer this.
  2592          if combined.element_type:
  2593            return transform.with_output_types(combined.element_type)
  2594          return transform
  2595  
  2596        # Capture in closure (avoiding capturing self).
  2597        args, kwargs = self.args, self.kwargs
  2598  
  2599        def inject_default(_, combined):
  2600          if combined:
  2601            assert len(combined) == 1
  2602            return combined[0]
  2603          else:
  2604            try:
  2605              combine_fn.setup(*args, **kwargs)
  2606              default = combine_fn.apply([], *args, **kwargs)
  2607            finally:
  2608              combine_fn.teardown(*args, **kwargs)
  2609            return default
  2610  
  2611        return (
  2612            pcoll.pipeline
  2613            | 'DoOnce' >> Create([None])
  2614            | 'InjectDefault' >> typed(
  2615                Map(inject_default, pvalue.AsList(combined))))
  2616  
  2617    @staticmethod
  2618    @PTransform.register_urn(
  2619        common_urns.composites.COMBINE_GLOBALLY.urn,
  2620        beam_runner_api_pb2.CombinePayload)
  2621    def from_runner_api_parameter(unused_ptransform, combine_payload, context):
  2622      return CombineGlobally(
  2623          CombineFn.from_runner_api(combine_payload.combine_fn, context))
  2624  
  2625  
  2626  @DoFnInfo.register_stateless_dofn(python_urns.KEY_WITH_NONE_DOFN)
  2627  class _KeyWithNone(DoFn):
  2628    def process(self, v):
  2629      yield None, v
  2630  
  2631  
  2632  class CombinePerKey(PTransformWithSideInputs):
  2633    """A per-key Combine transform.
  2634  
  2635    Identifies sets of values associated with the same key in the input
  2636    PCollection, then applies a CombineFn to condense those sets to single
  2637    values. See documentation in CombineFn for details on the specifics on how
  2638    CombineFns are applied.
  2639  
  2640    Args:
  2641      pcoll: input pcollection.
  2642      fn: instance of CombineFn to apply to all values under the same key in
  2643        pcoll, or a callable whose signature is ``f(iterable, *args, **kwargs)``
  2644        (e.g., sum, max).
  2645      *args: arguments and side inputs, passed directly to the CombineFn.
  2646      **kwargs: arguments and side inputs, passed directly to the CombineFn.
  2647  
  2648    Returns:
  2649      A PCollection holding the result of the combine operation.
  2650    """
  2651    def with_hot_key_fanout(self, fanout):
  2652      """A per-key combine operation like self but with two levels of aggregation.
  2653  
  2654      If a given key is produced by too many upstream bundles, the final
  2655      reduction can become a bottleneck despite partial combining being lifted
  2656      pre-GroupByKey.  In these cases it can be helpful to perform intermediate
  2657      partial aggregations in parallel and then re-group to peform a final
  2658      (per-key) combine.  This is also useful for high-volume keys in streaming
  2659      where combiners are not generally lifted for latency reasons.
  2660  
  2661      Note that a fanout greater than 1 requires the data to be sent through
  2662      two GroupByKeys, and a high fanout can also result in more shuffle data
  2663      due to less per-bundle combining. Setting the fanout for a key at 1 or less
  2664      places values on the "cold key" path, which skips the intermediate level of
  2665      aggregation.
  2666  
  2667      Args:
  2668        fanout: either None, for no fanout, an int, for a constant-degree fanout,
  2669            or a callable mapping keys to a key-specific degree of fanout.
  2670  
  2671      Returns:
  2672        A per-key combining PTransform with the specified fanout.
  2673      """
  2674      from apache_beam.transforms.combiners import curry_combine_fn
  2675      if fanout is None:
  2676        return self
  2677      else:
  2678        return _CombinePerKeyWithHotKeyFanout(
  2679            curry_combine_fn(self.fn, self.args, self.kwargs), fanout)
  2680  
  2681    def display_data(self):
  2682      return {
  2683          'combine_fn': DisplayDataItem(
  2684              self.fn.__class__, label='Combine Function'),
  2685          'combine_fn_dd': self.fn
  2686      }
  2687  
  2688    def make_fn(self, fn, has_side_inputs):
  2689      self._fn_label = ptransform.label_from_callable(fn)
  2690      return CombineFn.maybe_from_callable(fn, has_side_inputs)
  2691  
  2692    def default_label(self):
  2693      return '%s(%s)' % (self.__class__.__name__, self._fn_label)
  2694  
  2695    def _process_argspec_fn(self):
  2696      return lambda element, *args, **kwargs: None
  2697  
  2698    def expand(self, pcoll):
  2699      args, kwargs = util.insert_values_in_args(
  2700          self.args, self.kwargs, self.side_inputs)
  2701      return pcoll | GroupByKey() | 'Combine' >> CombineValues(
  2702          self.fn, *args, **kwargs)
  2703  
  2704    def default_type_hints(self):
  2705      result = self.fn.get_type_hints()
  2706      k = typehints.TypeVariable('K')
  2707      if result.input_types:
  2708        args, kwargs = result.input_types
  2709        args = (typehints.Tuple[k, args[0]], ) + args[1:]
  2710        result = result.with_input_types(*args, **kwargs)
  2711      else:
  2712        result = result.with_input_types(typehints.Tuple[k, typehints.Any])
  2713      if result.output_types:
  2714        main_output_type = result.simple_output_type('')
  2715        result = result.with_output_types(typehints.Tuple[k, main_output_type])
  2716      else:
  2717        result = result.with_output_types(typehints.Tuple[k, typehints.Any])
  2718      return result
  2719  
  2720    def to_runner_api_parameter(
  2721        self,
  2722        context,  # type: PipelineContext
  2723    ):
  2724      # type: (...) -> typing.Tuple[str, beam_runner_api_pb2.CombinePayload]
  2725      if self.args or self.kwargs:
  2726        from apache_beam.transforms.combiners import curry_combine_fn
  2727        combine_fn = curry_combine_fn(self.fn, self.args, self.kwargs)
  2728      else:
  2729        combine_fn = self.fn
  2730      return (
  2731          common_urns.composites.COMBINE_PER_KEY.urn,
  2732          _combine_payload(combine_fn, context))
  2733  
  2734    @staticmethod
  2735    @PTransform.register_urn(
  2736        common_urns.composites.COMBINE_PER_KEY.urn,
  2737        beam_runner_api_pb2.CombinePayload)
  2738    def from_runner_api_parameter(unused_ptransform, combine_payload, context):
  2739      return CombinePerKey(
  2740          CombineFn.from_runner_api(combine_payload.combine_fn, context))
  2741  
  2742    def runner_api_requires_keyed_input(self):
  2743      return True
  2744  
  2745  
  2746  # TODO(robertwb): Rename to CombineGroupedValues?
  2747  class CombineValues(PTransformWithSideInputs):
  2748    def make_fn(self, fn, has_side_inputs):
  2749      return CombineFn.maybe_from_callable(fn, has_side_inputs)
  2750  
  2751    def expand(self, pcoll):
  2752      args, kwargs = util.insert_values_in_args(
  2753          self.args, self.kwargs, self.side_inputs)
  2754  
  2755      input_type = pcoll.element_type
  2756      key_type = None
  2757      if input_type is not None:
  2758        key_type, _ = input_type.tuple_types
  2759  
  2760      runtime_type_check = (
  2761          pcoll.pipeline._options.view_as(TypeOptions).runtime_type_check)
  2762      return pcoll | ParDo(
  2763          CombineValuesDoFn(key_type, self.fn, runtime_type_check),
  2764          *args,
  2765          **kwargs)
  2766  
  2767    def to_runner_api_parameter(self, context):
  2768      if self.args or self.kwargs:
  2769        from apache_beam.transforms.combiners import curry_combine_fn
  2770        combine_fn = curry_combine_fn(self.fn, self.args, self.kwargs)
  2771      else:
  2772        combine_fn = self.fn
  2773      return (
  2774          common_urns.combine_components.COMBINE_GROUPED_VALUES.urn,
  2775          _combine_payload(combine_fn, context))
  2776  
  2777    @staticmethod
  2778    @PTransform.register_urn(
  2779        common_urns.combine_components.COMBINE_GROUPED_VALUES.urn,
  2780        beam_runner_api_pb2.CombinePayload)
  2781    def from_runner_api_parameter(unused_ptransform, combine_payload, context):
  2782      return CombineValues(
  2783          CombineFn.from_runner_api(combine_payload.combine_fn, context))
  2784  
  2785  
  2786  class CombineValuesDoFn(DoFn):
  2787    """DoFn for performing per-key Combine transforms."""
  2788  
  2789    def __init__(
  2790        self,
  2791        input_pcoll_type,
  2792        combinefn,  # type: CombineFn
  2793        runtime_type_check,  # type: bool
  2794    ):
  2795      super().__init__()
  2796      self.combinefn = combinefn
  2797      self.runtime_type_check = runtime_type_check
  2798  
  2799    def setup(self):
  2800      self.combinefn.setup()
  2801  
  2802    def process(self, element, *args, **kwargs):
  2803      # Expected elements input to this DoFn are 2-tuples of the form
  2804      # (key, iter), with iter an iterable of all the values associated with key
  2805      # in the input PCollection.
  2806      if self.runtime_type_check:
  2807        # Apply the combiner in a single operation rather than artificially
  2808        # breaking it up so that output type violations manifest as TypeCheck
  2809        # errors rather than type errors.
  2810        return [(element[0], self.combinefn.apply(element[1], *args, **kwargs))]
  2811  
  2812      # Add the elements into three accumulators (for testing of merge).
  2813      elements = list(element[1])
  2814      accumulators = []
  2815      for k in range(3):
  2816        if len(elements) <= k:
  2817          break
  2818        accumulators.append(
  2819            self.combinefn.add_inputs(
  2820                self.combinefn.create_accumulator(*args, **kwargs),
  2821                elements[k::3],
  2822                *args,
  2823                **kwargs))
  2824      # Merge the accumulators.
  2825      accumulator = self.combinefn.merge_accumulators(
  2826          accumulators, *args, **kwargs)
  2827      # Convert accumulator to the final result.
  2828      return [(
  2829          element[0], self.combinefn.extract_output(accumulator, *args,
  2830                                                    **kwargs))]
  2831  
  2832    def teardown(self):
  2833      self.combinefn.teardown()
  2834  
  2835    def default_type_hints(self):
  2836      hints = self.combinefn.get_type_hints()
  2837      if hints.input_types:
  2838        K = typehints.TypeVariable('K')
  2839        args, kwargs = hints.input_types
  2840        args = (typehints.Tuple[K, typehints.Iterable[args[0]]], ) + args[1:]
  2841        hints = hints.with_input_types(*args, **kwargs)
  2842      else:
  2843        K = typehints.Any
  2844      if hints.output_types:
  2845        main_output_type = hints.simple_output_type('')
  2846        hints = hints.with_output_types(typehints.Tuple[K, main_output_type])
  2847      return hints
  2848  
  2849  
  2850  class _CombinePerKeyWithHotKeyFanout(PTransform):
  2851  
  2852    def __init__(
  2853        self,
  2854        combine_fn,  # type: CombineFn
  2855        fanout,  # type: typing.Union[int, typing.Callable[[typing.Any], int]]
  2856    ):
  2857      # type: (...) -> None
  2858      self._combine_fn = combine_fn
  2859      self._fanout_fn = ((lambda key: fanout)
  2860                         if isinstance(fanout, int) else fanout)
  2861  
  2862    def default_label(self):
  2863      return '%s(%s, fanout=%s)' % (
  2864          self.__class__.__name__,
  2865          ptransform.label_from_callable(self._combine_fn),
  2866          ptransform.label_from_callable(self._fanout_fn))
  2867  
  2868    def expand(self, pcoll):
  2869  
  2870      from apache_beam.transforms.trigger import AccumulationMode
  2871      combine_fn = self._combine_fn
  2872      fanout_fn = self._fanout_fn
  2873  
  2874      if isinstance(pcoll.windowing.windowfn, SlidingWindows):
  2875        raise ValueError(
  2876            'CombinePerKey.with_hot_key_fanout does not yet work properly with '
  2877            'SlidingWindows. See: https://github.com/apache/beam/issues/20528')
  2878  
  2879      class SplitHotCold(DoFn):
  2880        def start_bundle(self):
  2881          # Spreading a hot key across all possible sub-keys for all bundles
  2882          # would defeat the goal of not overwhelming downstream reducers
  2883          # (as well as making less efficient use of PGBK combining tables).
  2884          # Instead, each bundle independently makes a consistent choice about
  2885          # which "shard" of a key to send its intermediate results.
  2886          self._nonce = int(random.getrandbits(31))
  2887  
  2888        def process(self, element):
  2889          key, value = element
  2890          fanout = fanout_fn(key)
  2891          if fanout <= 1:
  2892            # Boolean indicates this is not an accumulator.
  2893            yield (key, (False, value))  # cold
  2894          else:
  2895            yield pvalue.TaggedOutput('hot', ((self._nonce % fanout, key), value))
  2896  
  2897      class PreCombineFn(CombineFn):
  2898        @staticmethod
  2899        def extract_output(accumulator):
  2900          # Boolean indicates this is an accumulator.
  2901          return (True, accumulator)
  2902  
  2903        setup = combine_fn.setup
  2904        create_accumulator = combine_fn.create_accumulator
  2905        add_input = combine_fn.add_input
  2906        merge_accumulators = combine_fn.merge_accumulators
  2907        compact = combine_fn.compact
  2908        teardown = combine_fn.teardown
  2909  
  2910      class PostCombineFn(CombineFn):
  2911        @staticmethod
  2912        def add_input(accumulator, element):
  2913          is_accumulator, value = element
  2914          if is_accumulator:
  2915            return combine_fn.merge_accumulators([accumulator, value])
  2916          else:
  2917            return combine_fn.add_input(accumulator, value)
  2918  
  2919        setup = combine_fn.setup
  2920        create_accumulator = combine_fn.create_accumulator
  2921        merge_accumulators = combine_fn.merge_accumulators
  2922        compact = combine_fn.compact
  2923        extract_output = combine_fn.extract_output
  2924        teardown = combine_fn.teardown
  2925  
  2926      def StripNonce(nonce_key_value):
  2927        (_, key), value = nonce_key_value
  2928        return key, value
  2929  
  2930      cold, hot = pcoll | ParDo(SplitHotCold()).with_outputs('hot', main='cold')
  2931      cold.element_type = typehints.Any  # No multi-output type hints.
  2932      precombined_hot = (
  2933          hot
  2934          # Avoid double counting that may happen with stacked accumulating mode.
  2935          | 'WindowIntoDiscarding' >> WindowInto(
  2936              pcoll.windowing, accumulation_mode=AccumulationMode.DISCARDING)
  2937          | CombinePerKey(PreCombineFn())
  2938          | Map(StripNonce)
  2939          | 'WindowIntoOriginal' >> WindowInto(pcoll.windowing))
  2940      return ((cold, precombined_hot)
  2941              | Flatten()
  2942              | CombinePerKey(PostCombineFn()))
  2943  
  2944  
  2945  @typehints.with_input_types(typing.Tuple[K, V])
  2946  @typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
  2947  class GroupByKey(PTransform):
  2948    """A group by key transform.
  2949  
  2950    Processes an input PCollection consisting of key/value pairs represented as a
  2951    tuple pair. The result is a PCollection where values having a common key are
  2952    grouped together.  For example (a, 1), (b, 2), (a, 3) will result into
  2953    (a, [1, 3]), (b, [2]).
  2954  
  2955    The implementation here is used only when run on the local direct runner.
  2956    """
  2957    class ReifyWindows(DoFn):
  2958      def process(
  2959          self, element, window=DoFn.WindowParam, timestamp=DoFn.TimestampParam):
  2960        try:
  2961          k, v = element
  2962        except TypeError:
  2963          raise TypeCheckError(
  2964              'Input to GroupByKey must be a PCollection with '
  2965              'elements compatible with KV[A, B]')
  2966  
  2967        return [(k, WindowedValue(v, timestamp, [window]))]
  2968  
  2969      def infer_output_type(self, input_type):
  2970        key_type, value_type = trivial_inference.key_value_types(input_type)
  2971        return typehints.KV[
  2972            key_type, typehints.WindowedValue[value_type]]  # type: ignore[misc]
  2973  
  2974    def expand(self, pcoll):
  2975      from apache_beam.transforms.trigger import DataLossReason
  2976      from apache_beam.transforms.trigger import DefaultTrigger
  2977      windowing = pcoll.windowing
  2978      trigger = windowing.triggerfn
  2979      if not pcoll.is_bounded and isinstance(
  2980          windowing.windowfn, GlobalWindows) and isinstance(trigger,
  2981                                                            DefaultTrigger):
  2982        if pcoll.pipeline.allow_unsafe_triggers:
  2983          # TODO(BEAM-9487) Change comment for Beam 2.33
  2984          _LOGGER.warning(
  2985              '%s: PCollection passed to GroupByKey is unbounded, has a global '
  2986              'window, and uses a default trigger. This is being allowed '
  2987              'because --allow_unsafe_triggers is set, but it may prevent '
  2988              'data from making it through the pipeline.',
  2989              self.label)
  2990        else:
  2991          raise ValueError(
  2992              'GroupByKey cannot be applied to an unbounded ' +
  2993              'PCollection with global windowing and a default trigger')
  2994  
  2995      unsafe_reason = trigger.may_lose_data(windowing)
  2996      if unsafe_reason != DataLossReason.NO_POTENTIAL_LOSS:
  2997        reason_msg = str(unsafe_reason).replace('DataLossReason.', '')
  2998        if pcoll.pipeline.allow_unsafe_triggers:
  2999          _LOGGER.warning(
  3000              '%s: Unsafe trigger `%s` detected (reason: %s). This is '
  3001              'being allowed because --allow_unsafe_triggers is set. This could '
  3002              'lead to missing or incomplete groups.',
  3003              self.label,
  3004              trigger,
  3005              reason_msg)
  3006        else:
  3007          msg = '{}: Unsafe trigger: `{}` may lose data. '.format(
  3008              self.label, trigger)
  3009          msg += 'Reason: {}. '.format(reason_msg)
  3010          msg += 'This can be overridden with the --allow_unsafe_triggers flag.'
  3011          raise ValueError(msg)
  3012  
  3013      return pvalue.PCollection.from_(pcoll)
  3014  
  3015    def infer_output_type(self, input_type):
  3016      key_type, value_type = (typehints.typehints.coerce_to_kv_type(
  3017          input_type).tuple_types)
  3018      return typehints.KV[key_type, typehints.Iterable[value_type]]
  3019  
  3020    def to_runner_api_parameter(self, unused_context):
  3021      # type: (PipelineContext) -> typing.Tuple[str, None]
  3022      return common_urns.primitives.GROUP_BY_KEY.urn, None
  3023  
  3024    @staticmethod
  3025    @PTransform.register_urn(common_urns.primitives.GROUP_BY_KEY.urn, None)
  3026    def from_runner_api_parameter(
  3027        unused_ptransform, unused_payload, unused_context):
  3028      return GroupByKey()
  3029  
  3030    def runner_api_requires_keyed_input(self):
  3031      return True
  3032  
  3033  
  3034  def _expr_to_callable(expr, pos):
  3035    if isinstance(expr, str):
  3036      return lambda x: getattr(x, expr)
  3037    elif callable(expr):
  3038      return expr
  3039    else:
  3040      raise TypeError(
  3041          'Field expression %r at %s must be a callable or a string.' %
  3042          (expr, pos))
  3043  
  3044  
  3045  class GroupBy(PTransform):
  3046    """Groups a PCollection by one or more expressions, used to derive the key.
  3047  
  3048    `GroupBy(expr)` is roughly equivalent to
  3049  
  3050        beam.Map(lambda v: (expr(v), v)) | beam.GroupByKey()
  3051  
  3052    but provides several conveniences, e.g.
  3053  
  3054        * Several arguments may be provided, as positional or keyword arguments,
  3055          resulting in a tuple-like key. For example `GroupBy(a=expr1, b=expr2)`
  3056          groups by a key with attributes `a` and `b` computed by applying
  3057          `expr1` and `expr2` to each element.
  3058  
  3059        * Strings can be used as a shorthand for accessing an attribute, e.g.
  3060          `GroupBy('some_field')` is equivalent to
  3061          `GroupBy(lambda v: getattr(v, 'some_field'))`.
  3062  
  3063    The GroupBy operation can be made into an aggregating operation by invoking
  3064    its `aggregate_field` method.
  3065    """
  3066  
  3067    def __init__(
  3068        self,
  3069        *fields,  # type: typing.Union[str, typing.Callable]
  3070        **kwargs  # type: typing.Union[str, typing.Callable]
  3071    ):
  3072      if len(fields) == 1 and not kwargs:
  3073        self._force_tuple_keys = False
  3074        name = fields[0] if isinstance(fields[0], str) else 'key'
  3075        key_fields = [(name, _expr_to_callable(fields[0], 0))]
  3076      else:
  3077        self._force_tuple_keys = True
  3078        key_fields = []
  3079        for ix, field in enumerate(fields):
  3080          name = field if isinstance(field, str) else 'key%d' % ix
  3081          key_fields.append((name, _expr_to_callable(field, ix)))
  3082        for name, expr in kwargs.items():
  3083          key_fields.append((name, _expr_to_callable(expr, name)))
  3084      self._key_fields = key_fields
  3085      field_names = tuple(name for name, _ in key_fields)
  3086      self._key_type = lambda *values: _dynamic_named_tuple('Key', field_names)(
  3087          *values)
  3088  
  3089    def aggregate_field(
  3090        self,
  3091        field,  # type: typing.Union[str, typing.Callable]
  3092        combine_fn,  # type: typing.Union[typing.Callable, CombineFn]
  3093        dest,  # type: str
  3094    ):
  3095      """Returns a grouping operation that also aggregates grouped values.
  3096  
  3097      Args:
  3098        field: indicates the field to be aggregated
  3099        combine_fn: indicates the aggregation function to be used
  3100        dest: indicates the name that will be used for the aggregate in the output
  3101  
  3102      May be called repeatedly to aggregate multiple fields, e.g.
  3103  
  3104          GroupBy('key')
  3105              .aggregate_field('some_attr', sum, 'sum_attr')
  3106              .aggregate_field(lambda v: ..., MeanCombineFn(), 'mean')
  3107      """
  3108      return _GroupAndAggregate(self, ()).aggregate_field(field, combine_fn, dest)
  3109  
  3110    def force_tuple_keys(self, value=True):
  3111      """Forces the keys to always be tuple-like, even if there is only a single
  3112      expression.
  3113      """
  3114      res = copy.copy(self)
  3115      res._force_tuple_keys = value
  3116      return res
  3117  
  3118    def _key_func(self):
  3119      if not self._force_tuple_keys and len(self._key_fields) == 1:
  3120        return self._key_fields[0][1]
  3121      else:
  3122        key_type = self._key_type
  3123        key_exprs = [expr for _, expr in self._key_fields]
  3124        return lambda element: key_type(*(expr(element) for expr in key_exprs))
  3125  
  3126    def _key_type_hint(self, input_type):
  3127      if not self._force_tuple_keys and len(self._key_fields) == 1:
  3128        expr = self._key_fields[0][1]
  3129        return trivial_inference.infer_return_type(expr, [input_type])
  3130      else:
  3131        return row_type.RowTypeConstraint.from_fields([
  3132            (name, trivial_inference.infer_return_type(expr, [input_type]))
  3133            for (name, expr) in self._key_fields
  3134        ])
  3135  
  3136    def default_label(self):
  3137      return 'GroupBy(%s)' % ', '.join(name for name, _ in self._key_fields)
  3138  
  3139    def expand(self, pcoll):
  3140      input_type = pcoll.element_type or typing.Any
  3141      return (
  3142          pcoll
  3143          | Map(lambda x: (self._key_func()(x), x)).with_output_types(
  3144              typehints.Tuple[self._key_type_hint(input_type), input_type])
  3145          | GroupByKey())
  3146  
  3147  
  3148  _dynamic_named_tuple_cache = {
  3149  }  # type: typing.Dict[typing.Tuple[str, typing.Tuple[str, ...]], typing.Type[tuple]]
  3150  
  3151  
  3152  def _dynamic_named_tuple(type_name, field_names):
  3153    # type: (str, typing.Tuple[str, ...]) -> typing.Type[tuple]
  3154    cache_key = (type_name, field_names)
  3155    result = _dynamic_named_tuple_cache.get(cache_key)
  3156    if result is None:
  3157      import collections
  3158      result = _dynamic_named_tuple_cache[cache_key] = collections.namedtuple(
  3159          type_name, field_names)
  3160      # typing: can't override a method. also, self type is unknown and can't
  3161      # be cast to tuple
  3162      result.__reduce__ = lambda self: (  # type: ignore[assignment]
  3163          _unpickle_dynamic_named_tuple, (type_name, field_names, tuple(self)))  # type: ignore[arg-type]
  3164    return result
  3165  
  3166  
  3167  def _unpickle_dynamic_named_tuple(type_name, field_names, values):
  3168    # type: (str, typing.Tuple[str, ...], typing.Iterable[typing.Any]) -> tuple
  3169    return _dynamic_named_tuple(type_name, field_names)(*values)
  3170  
  3171  
  3172  class _GroupAndAggregate(PTransform):
  3173    def __init__(self, grouping, aggregations):
  3174      self._grouping = grouping
  3175      self._aggregations = aggregations
  3176  
  3177    def aggregate_field(
  3178        self,
  3179        field,  # type: typing.Union[str, typing.Callable]
  3180        combine_fn,  # type: typing.Union[typing.Callable, CombineFn]
  3181        dest,  # type: str
  3182    ):
  3183      field = _expr_to_callable(field, 0)
  3184      return _GroupAndAggregate(
  3185          self._grouping, list(self._aggregations) + [(field, combine_fn, dest)])
  3186  
  3187    def expand(self, pcoll):
  3188      from apache_beam.transforms.combiners import TupleCombineFn
  3189      key_func = self._grouping.force_tuple_keys(True)._key_func()
  3190      value_exprs = [expr for expr, _, __ in self._aggregations]
  3191      value_func = lambda element: [expr(element) for expr in value_exprs]
  3192      result_fields = tuple(name
  3193                            for name, _ in self._grouping._key_fields) + tuple(
  3194                                dest for _, __, dest in self._aggregations)
  3195      key_type_hint = self._grouping.force_tuple_keys(True)._key_type_hint(
  3196          pcoll.element_type)
  3197  
  3198      return (
  3199          pcoll
  3200          | Map(lambda x: (key_func(x), value_func(x))).with_output_types(
  3201              typehints.Tuple[key_type_hint, typing.Any])
  3202          | CombinePerKey(
  3203              TupleCombineFn(
  3204                  *[combine_fn for _, combine_fn, __ in self._aggregations]))
  3205          | MapTuple(
  3206              lambda key,
  3207              value: _dynamic_named_tuple('Result', result_fields)
  3208              (*(key + value))))
  3209  
  3210  
  3211  class Select(PTransform):
  3212    """Converts the elements of a PCollection into a schema'd PCollection of Rows.
  3213  
  3214    `Select(...)` is roughly equivalent to `Map(lambda x: Row(...))` where each
  3215    argument (which may be a string or callable) of `Select` is applied to `x`.
  3216    For example,
  3217  
  3218        pcoll | beam.Select('a', b=lambda x: foo(x))
  3219  
  3220    is the same as
  3221  
  3222        pcoll | beam.Map(lambda x: beam.Row(a=x.a, b=foo(x)))
  3223    """
  3224  
  3225    def __init__(
  3226        self,
  3227        *args,  # type: typing.Union[str, typing.Callable]
  3228        **kwargs  # type: typing.Union[str, typing.Callable]
  3229    ):
  3230      self._fields = [(
  3231          expr if isinstance(expr, str) else 'arg%02d' % ix,
  3232          _expr_to_callable(expr, ix)) for (ix, expr) in enumerate(args)
  3233                      ] + [(name, _expr_to_callable(expr, name))
  3234                           for (name, expr) in kwargs.items()]
  3235  
  3236    def default_label(self):
  3237      return 'ToRows(%s)' % ', '.join(name for name, _ in self._fields)
  3238  
  3239    def expand(self, pcoll):
  3240      return pcoll | Map(
  3241          lambda x: pvalue.Row(**{name: expr(x)
  3242                                  for name, expr in self._fields}))
  3243  
  3244    def infer_output_type(self, input_type):
  3245      return row_type.RowTypeConstraint.from_fields([
  3246          (name, trivial_inference.infer_return_type(expr, [input_type]))
  3247          for (name, expr) in self._fields
  3248      ])
  3249  
  3250  
  3251  class Partition(PTransformWithSideInputs):
  3252    """Split a PCollection into several partitions.
  3253  
  3254    Uses the specified PartitionFn to separate an input PCollection into the
  3255    specified number of sub-PCollections.
  3256  
  3257    When applied, a Partition() PTransform requires the following:
  3258  
  3259    Args:
  3260      partitionfn: a PartitionFn, or a callable with the signature described in
  3261        CallableWrapperPartitionFn.
  3262      n: number of output partitions.
  3263  
  3264    The result of this PTransform is a simple list of the output PCollections
  3265    representing each of n partitions, in order.
  3266    """
  3267    class ApplyPartitionFnFn(DoFn):
  3268      """A DoFn that applies a PartitionFn."""
  3269      def process(self, element, partitionfn, n, *args, **kwargs):
  3270        partition = partitionfn.partition_for(element, n, *args, **kwargs)
  3271        if not 0 <= partition < n:
  3272          raise ValueError(
  3273              'PartitionFn specified out-of-bounds partition index: '
  3274              '%d not in [0, %d)' % (partition, n))
  3275        # Each input is directed into the output that corresponds to the
  3276        # selected partition.
  3277        yield pvalue.TaggedOutput(str(partition), element)
  3278  
  3279    def make_fn(self, fn, has_side_inputs):
  3280      return fn if isinstance(fn, PartitionFn) else CallableWrapperPartitionFn(fn)
  3281  
  3282    def expand(self, pcoll):
  3283      n = int(self.args[0])
  3284      args, kwargs = util.insert_values_in_args(
  3285          self.args, self.kwargs, self.side_inputs)
  3286      return pcoll | ParDo(self.ApplyPartitionFnFn(), self.fn, *args, **
  3287                           kwargs).with_outputs(*[str(t) for t in range(n)])
  3288  
  3289  
  3290  class Windowing(object):
  3291    def __init__(self,
  3292                 windowfn,  # type: WindowFn
  3293                 triggerfn=None,  # type: typing.Optional[TriggerFn]
  3294                 accumulation_mode=None,  # type: typing.Optional[beam_runner_api_pb2.AccumulationMode.Enum.ValueType]
  3295                 timestamp_combiner=None,  # type: typing.Optional[beam_runner_api_pb2.OutputTime.Enum.ValueType]
  3296                 allowed_lateness=0, # type: typing.Union[int, float]
  3297                 environment_id=None, # type: typing.Optional[str]
  3298                 ):
  3299      """Class representing the window strategy.
  3300  
  3301      Args:
  3302        windowfn: Window assign function.
  3303        triggerfn: Trigger function.
  3304        accumulation_mode: an AccumulationMode value that controls what to do
  3305          with data when a trigger fires multiple times.
  3306        timestamp_combiner: a TimestampCombiner that determines how output
  3307          timestamps of grouping operations are assigned.
  3308        allowed_lateness: Maximum delay, in seconds, after the end of a window
  3309          during which late data may still be processed rather than being
  3310          discarded.
  3311        environment_id: Environment in which the current window_fn should be
  3312          applied.
  3313      """
  3314      global AccumulationMode, DefaultTrigger  # pylint: disable=global-variable-not-assigned
  3315      # pylint: disable=wrong-import-order, wrong-import-position
  3316      from apache_beam.transforms.trigger import AccumulationMode, DefaultTrigger
  3317      # pylint: enable=wrong-import-order, wrong-import-position
  3318      if triggerfn is None:
  3319        triggerfn = DefaultTrigger()
  3320      if accumulation_mode is None:
  3321        if triggerfn == DefaultTrigger():
  3322          accumulation_mode = AccumulationMode.DISCARDING
  3323        else:
  3324          raise ValueError(
  3325              'accumulation_mode must be provided for non-trivial triggers')
  3326      if not windowfn.get_window_coder().is_deterministic():
  3327        raise ValueError(
  3328            'window fn (%s) does not have a deterministic coder (%s)' %
  3329            (windowfn, windowfn.get_window_coder()))
  3330      self.windowfn = windowfn
  3331      self.triggerfn = triggerfn
  3332      self.accumulation_mode = accumulation_mode
  3333      self.allowed_lateness = Duration.of(allowed_lateness)
  3334      self.environment_id = environment_id
  3335      self.timestamp_combiner = (
  3336          timestamp_combiner or TimestampCombiner.OUTPUT_AT_EOW)
  3337      self._is_default = (
  3338          self.windowfn == GlobalWindows() and
  3339          self.triggerfn == DefaultTrigger() and
  3340          self.accumulation_mode == AccumulationMode.DISCARDING and
  3341          self.timestamp_combiner == TimestampCombiner.OUTPUT_AT_EOW and
  3342          self.allowed_lateness == 0)
  3343  
  3344    def __repr__(self):
  3345      return "Windowing(%s, %s, %s, %s, %s)" % (
  3346          self.windowfn,
  3347          self.triggerfn,
  3348          self.accumulation_mode,
  3349          self.timestamp_combiner,
  3350          self.environment_id)
  3351  
  3352    def __eq__(self, other):
  3353      if type(self) == type(other):
  3354        if self._is_default and other._is_default:
  3355          return True
  3356        return (
  3357            self.windowfn == other.windowfn and
  3358            self.triggerfn == other.triggerfn and
  3359            self.accumulation_mode == other.accumulation_mode and
  3360            self.timestamp_combiner == other.timestamp_combiner and
  3361            self.allowed_lateness == other.allowed_lateness and
  3362            self.environment_id == other.environment_id)
  3363      return False
  3364  
  3365    def __hash__(self):
  3366      return hash((
  3367          self.windowfn,
  3368          self.triggerfn,
  3369          self.accumulation_mode,
  3370          self.allowed_lateness,
  3371          self.timestamp_combiner,
  3372          self.environment_id))
  3373  
  3374    def is_default(self):
  3375      return self._is_default
  3376  
  3377    def to_runner_api(self, context):
  3378      # type: (PipelineContext) -> beam_runner_api_pb2.WindowingStrategy
  3379      environment_id = self.environment_id or context.default_environment_id()
  3380      return beam_runner_api_pb2.WindowingStrategy(
  3381          window_fn=self.windowfn.to_runner_api(context),
  3382          # TODO(robertwb): Prohibit implicit multi-level merging.
  3383          merge_status=(
  3384              beam_runner_api_pb2.MergeStatus.NEEDS_MERGE
  3385              if self.windowfn.is_merging() else
  3386              beam_runner_api_pb2.MergeStatus.NON_MERGING),
  3387          window_coder_id=context.coders.get_id(self.windowfn.get_window_coder()),
  3388          trigger=self.triggerfn.to_runner_api(context),
  3389          accumulation_mode=self.accumulation_mode,
  3390          output_time=self.timestamp_combiner,
  3391          # TODO(robertwb): Support EMIT_IF_NONEMPTY
  3392          closing_behavior=beam_runner_api_pb2.ClosingBehavior.EMIT_ALWAYS,
  3393          on_time_behavior=beam_runner_api_pb2.OnTimeBehavior.FIRE_ALWAYS,
  3394          allowed_lateness=self.allowed_lateness.micros // 1000,
  3395          environment_id=environment_id)
  3396  
  3397    @staticmethod
  3398    def from_runner_api(proto, context):
  3399      # pylint: disable=wrong-import-order, wrong-import-position
  3400      from apache_beam.transforms.trigger import TriggerFn
  3401      return Windowing(
  3402          windowfn=WindowFn.from_runner_api(proto.window_fn, context),
  3403          triggerfn=TriggerFn.from_runner_api(proto.trigger, context),
  3404          accumulation_mode=proto.accumulation_mode,
  3405          timestamp_combiner=proto.output_time,
  3406          allowed_lateness=Duration(micros=proto.allowed_lateness * 1000),
  3407          environment_id=None)
  3408  
  3409  
  3410  @typehints.with_input_types(T)
  3411  @typehints.with_output_types(T)
  3412  class WindowInto(ParDo):
  3413    """A window transform assigning windows to each element of a PCollection.
  3414  
  3415    Transforms an input PCollection by applying a windowing function to each
  3416    element.  Each transformed element in the result will be a WindowedValue
  3417    element with the same input value and timestamp, with its new set of windows
  3418    determined by the windowing function.
  3419    """
  3420    class WindowIntoFn(DoFn):
  3421      """A DoFn that applies a WindowInto operation."""
  3422      def __init__(self, windowing):
  3423        # type: (Windowing) -> None
  3424        self.windowing = windowing
  3425  
  3426      def process(
  3427          self, element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
  3428        context = WindowFn.AssignContext(
  3429            timestamp, element=element, window=window)
  3430        new_windows = self.windowing.windowfn.assign(context)
  3431        yield WindowedValue(element, context.timestamp, new_windows)
  3432  
  3433    def __init__(
  3434        self,
  3435        windowfn,  # type: typing.Union[Windowing, WindowFn]
  3436        trigger=None,  # type: typing.Optional[TriggerFn]
  3437        accumulation_mode=None,
  3438        timestamp_combiner=None,
  3439        allowed_lateness=0):
  3440      """Initializes a WindowInto transform.
  3441  
  3442      Args:
  3443        windowfn (Windowing, WindowFn): Function to be used for windowing.
  3444        trigger: (optional) Trigger used for windowing, or None for default.
  3445        accumulation_mode: (optional) Accumulation mode used for windowing,
  3446            required for non-trivial triggers.
  3447        timestamp_combiner: (optional) Timestamp combiner used for windowing,
  3448            or None for default.
              allowed_lateness: (optional) Maximum delay in seconds after the end
                  of a window during which late data may still be processed.
  3449      """
  3450      if isinstance(windowfn, Windowing):
  3451        # Overlay windowing with kwargs.
  3452        windowing = windowfn
  3453        windowfn = windowing.windowfn
  3454  
  3455        # Use windowing to fill in defaults for the extra arguments.
  3456        trigger = trigger or windowing.triggerfn
  3457        accumulation_mode = accumulation_mode or windowing.accumulation_mode
  3458        timestamp_combiner = timestamp_combiner or windowing.timestamp_combiner
  3459  
  3460      self.windowing = Windowing(
  3461          windowfn,
  3462          trigger,
  3463          accumulation_mode,
  3464          timestamp_combiner,
  3465          allowed_lateness)
  3466      super().__init__(self.WindowIntoFn(self.windowing))
  3467  
  3468    def get_windowing(self, unused_inputs):
  3469      # type: (typing.Any) -> Windowing
  3470      return self.windowing
  3471  
  3472    def infer_output_type(self, input_type):
  3473      return input_type
  3474  
  3475    def expand(self, pcoll):
  3476      input_type = pcoll.element_type
  3477  
  3478      if input_type is not None:
  3479        output_type = input_type
  3480        self.with_input_types(input_type)
  3481        self.with_output_types(output_type)
  3482      return super().expand(pcoll)
  3483  
  3484    # typing: PTransform base class does not accept extra_kwargs
  3485    def to_runner_api_parameter(self, context, **extra_kwargs):  # type: ignore[override]
  3486      # type: (PipelineContext, **typing.Any) -> typing.Tuple[str, message.Message]
  3487      return (
  3488          common_urns.primitives.ASSIGN_WINDOWS.urn,
  3489          self.windowing.to_runner_api(context))
  3490  
  3491    @staticmethod
  3492    def from_runner_api_parameter(unused_ptransform, proto, context):
  3493      windowing = Windowing.from_runner_api(proto, context)
  3494      return WindowInto(
  3495          windowing.windowfn,
  3496          trigger=windowing.triggerfn,
  3497          accumulation_mode=windowing.accumulation_mode,
  3498          timestamp_combiner=windowing.timestamp_combiner)
  3499  
  3500  
  3501  PTransform.register_urn(
  3502      common_urns.primitives.ASSIGN_WINDOWS.urn,
  3503      # TODO(robertwb): Update WindowIntoPayload to include the full strategy.
  3504      # (Right now only WindowFn is used, but we need this to reconstitute the
  3505      # WindowInto transform, and in the future will need it at runtime to
  3506      # support meta-data driven triggers.)
  3507      # TODO(robertwb): Use a reference rather than embedding?
  3508      beam_runner_api_pb2.WindowingStrategy,
  3509      WindowInto.from_runner_api_parameter)
  3510  
  3511  # Python's pickling is broken for nested classes.
  3512  WindowIntoFn = WindowInto.WindowIntoFn
  3513  
  3514  
  3515  class Flatten(PTransform):
  3516    """Merges several PCollections into a single PCollection.
  3517  
  3518    Copies all elements in 0 or more PCollections into a single output
  3519    PCollection. If there are no input PCollections, the resulting PCollection
  3520    will be empty (but see also kwargs below).
  3521  
  3522    Args:
  3523      **kwargs: Accepts a single named argument "pipeline", which specifies the
  3524        pipeline that "owns" this PTransform. Ordinarily Flatten can obtain this
  3525        information from one of the input PCollections, but if there are none (or
  3526        if there's a chance there may be none), this argument is the only way to
  3527        provide pipeline information and should be considered mandatory.
  3528    """
  3529    def __init__(self, **kwargs):
  3530      super().__init__()
  3531      self.pipeline = kwargs.pop(
  3532          'pipeline', None)  # type: typing.Optional[Pipeline]
  3533      if kwargs:
  3534        raise ValueError('Unexpected keyword arguments: %s' % list(kwargs))
  3535  
  3536    def _extract_input_pvalues(self, pvalueish):
  3537      try:
  3538        pvalueish = tuple(pvalueish)
  3539      except TypeError:
  3540        raise ValueError(
  3541            'Input to Flatten must be an iterable. '
  3542            'Got a value of type %s instead.' % type(pvalueish))
  3543      return pvalueish, pvalueish
  3544  
  3545    def expand(self, pcolls):
  3546      for pcoll in pcolls:
  3547        self._check_pcollection(pcoll)
  3548      is_bounded = all(pcoll.is_bounded for pcoll in pcolls)
  3549      return pvalue.PCollection(self.pipeline, is_bounded=is_bounded)
  3550  
  3551    def infer_output_type(self, input_type):
  3552      return input_type
  3553  
  3554    def to_runner_api_parameter(self, context):
  3555      # type: (PipelineContext) -> typing.Tuple[str, None]
  3556      return common_urns.primitives.FLATTEN.urn, None
  3557  
  3558    @staticmethod
  3559    def from_runner_api_parameter(
  3560        unused_ptransform, unused_parameter, unused_context):
  3561      return Flatten()
  3562  
  3563  
  3564  PTransform.register_urn(
  3565      common_urns.primitives.FLATTEN.urn, None, Flatten.from_runner_api_parameter)
  3566  
  3567  
  3568  class Create(PTransform):
  3569    """A transform that creates a PCollection from an iterable."""
  3570    def __init__(self, values, reshuffle=True):
  3571      """Initializes a Create transform.
  3572  
  3573      Args:
  3574        values: An iterable of values (or a dict, whose items will be used)
                  from which to create the PCollection.
              reshuffle: (optional) Whether to redistribute the created elements
                  across workers via a Reshuffle step (default True).
  3575      """
  3576      super().__init__()
  3577      if isinstance(values, (str, bytes)):
  3578        raise TypeError(
  3579            'PTransform Create: Refusing to treat string as '
  3580            'an iterable. (string=%r)' % values)
  3581      elif isinstance(values, dict):
  3582        values = values.items()
  3583      self.values = tuple(values)
  3584      self.reshuffle = reshuffle
  3585      self._coder = typecoders.registry.get_coder(self.get_output_type())
  3586  
  3587    def __getstate__(self):
  3588      serialized_values = [self._coder.encode(v) for v in self.values]
  3589      return serialized_values, self.reshuffle, self._coder
  3590  
  3591    def __setstate__(self, state):
  3592      serialized_values, self.reshuffle, self._coder = state
  3593      self.values = [self._coder.decode(v) for v in serialized_values]
  3594  
  3595    def to_runner_api_parameter(self, context):
  3596      # type: (PipelineContext) -> typing.Tuple[str, bytes]
  3597      # Required as this is identified by type in PTransformOverrides.
  3598      # TODO(https://github.com/apache/beam/issues/18713): Use an actual URN
  3599      # here.
  3600      return self.to_runner_api_pickled(context)
  3601  
  3602    def infer_output_type(self, unused_input_type):
  3603      if not self.values:
  3604        return typehints.Any
  3605      return typehints.Union[[
  3606          trivial_inference.instance_to_type(v) for v in self.values
  3607      ]]
  3608  
  3609    def get_output_type(self):
  3610      return (
  3611          self.get_type_hints().simple_output_type(self.label) or
  3612          self.infer_output_type(None))
  3613  
  3614    def expand(self, pbegin):
  3615      assert isinstance(pbegin, pvalue.PBegin)
  3616      serialized_values = [self._coder.encode(v) for v in self.values]
  3617      reshuffle = self.reshuffle
  3618  
  3619      # Avoid the "redistributing" reshuffle for 0 and 1 element Creates.
  3620      # These special cases are often used in building up more complex
  3621      # transforms (e.g. Write).
  3622  
  3623      class MaybeReshuffle(PTransform):
  3624        def expand(self, pcoll):
  3625          if len(serialized_values) > 1 and reshuffle:
  3626            from apache_beam.transforms.util import Reshuffle
  3627            return pcoll | Reshuffle()
  3628          else:
  3629            return pcoll
  3630  
  3631      return (
  3632          pbegin
  3633          | Impulse()
  3634          | FlatMap(lambda _: serialized_values).with_output_types(bytes)
  3635          | MaybeReshuffle().with_output_types(bytes)
  3636          | Map(self._coder.decode).with_output_types(self.get_output_type()))
  3637  
  3638    def as_read(self):
  3639      from apache_beam.io import iobase
  3640      source = self._create_source_from_iterable(self.values, self._coder)
  3641      return iobase.Read(source).with_output_types(self.get_output_type())
  3642  
  3643    def get_windowing(self, unused_inputs):
  3644      # type: (typing.Any) -> Windowing
  3645      return Windowing(GlobalWindows())
  3646  
  3647    @staticmethod
  3648    def _create_source_from_iterable(values, coder):
  3649      return Create._create_source(list(map(coder.encode, values)), coder)
  3650  
  3651    @staticmethod
  3652    def _create_source(serialized_values, coder):
  3653      # type: (typing.Any, typing.Any) -> create_source._CreateSource
  3654      from apache_beam.transforms.create_source import _CreateSource
  3655  
  3656      return _CreateSource(serialized_values, coder)
  3657  
  3658  
  3659  @typehints.with_output_types(bytes)
  3660  class Impulse(PTransform):
  3661    """Impulse primitive."""
  3662    def expand(self, pbegin):
  3663      if not isinstance(pbegin, pvalue.PBegin):
  3664        raise TypeError(
  3665            'Input to Impulse transform must be a PBegin but found %s' % pbegin)
  3666      return pvalue.PCollection(pbegin.pipeline, element_type=bytes)
  3667  
  3668    def get_windowing(self, inputs):
  3669      # type: (typing.Any) -> Windowing
  3670      return Windowing(GlobalWindows())
  3671  
  3672    def infer_output_type(self, unused_input_type):
  3673      return bytes
  3674  
  3675    def to_runner_api_parameter(self, unused_context):
  3676      # type: (PipelineContext) -> typing.Tuple[str, None]
  3677      return common_urns.primitives.IMPULSE.urn, None
  3678  
  3679    @staticmethod
  3680    @PTransform.register_urn(common_urns.primitives.IMPULSE.urn, None)
  3681    def from_runner_api_parameter(
  3682        unused_ptransform, unused_parameter, unused_context):
  3683      return Impulse()
  3684  
  3685  
  3686  def _strip_output_annotations(type_hint):
  3687    # TODO(robertwb): These should be parameterized types that the
  3688    # type inferencer understands.
  3689    # Then we can replace them with the correct element types instead of
  3690    # using Any. Refer to typehints.WindowedValue when doing this.
  3691    annotations = (TimestampedValue, WindowedValue, pvalue.TaggedOutput)
  3692  
  3693    contains_annotation = False
  3694  
  3695    def visitor(t, unused_args):
  3696      if t in annotations:
  3697        raise StopIteration
  3698  
  3699    try:
  3700      visit_inner_types(type_hint, visitor, [])
  3701    except StopIteration:
  3702      contains_annotation = True
  3703  
  3704    return typehints.Any if contains_annotation else type_hint