github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/common.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  # cython: profile=True
    18  # cython: language_level=3
    19  
    20  """Worker operations executor.
    21  
    22  For internal use only; no backwards-compatibility guarantees.
    23  """
    24  
    25  # pytype: skip-file
    26  
    27  import sys
    28  import threading
    29  import traceback
    30  from enum import Enum
    31  from typing import TYPE_CHECKING
    32  from typing import Any
    33  from typing import Dict
    34  from typing import Iterable
    35  from typing import List
    36  from typing import Mapping
    37  from typing import Optional
    38  from typing import Tuple
    39  
    40  from apache_beam.coders import TupleCoder
    41  from apache_beam.internal import util
    42  from apache_beam.options.value_provider import RuntimeValueProvider
    43  from apache_beam.portability import common_urns
    44  from apache_beam.pvalue import TaggedOutput
    45  from apache_beam.runners.sdf_utils import NoOpWatermarkEstimatorProvider
    46  from apache_beam.runners.sdf_utils import RestrictionTrackerView
    47  from apache_beam.runners.sdf_utils import SplitResultPrimary
    48  from apache_beam.runners.sdf_utils import SplitResultResidual
    49  from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
    50  from apache_beam.runners.sdf_utils import ThreadsafeWatermarkEstimator
    51  from apache_beam.transforms import DoFn
    52  from apache_beam.transforms import core
    53  from apache_beam.transforms import userstate
    54  from apache_beam.transforms.core import RestrictionProvider
    55  from apache_beam.transforms.core import WatermarkEstimatorProvider
    56  from apache_beam.transforms.window import GlobalWindow
    57  from apache_beam.transforms.window import TimestampedValue
    58  from apache_beam.transforms.window import WindowFn
    59  from apache_beam.typehints import typehints
    60  from apache_beam.typehints.batch import BatchConverter
    61  from apache_beam.utils.counters import Counter
    62  from apache_beam.utils.counters import CounterName
    63  from apache_beam.utils.timestamp import Timestamp
    64  from apache_beam.utils.windowed_value import HomogeneousWindowedBatch
    65  from apache_beam.utils.windowed_value import WindowedBatch
    66  from apache_beam.utils.windowed_value import WindowedValue
    67  
    68  if TYPE_CHECKING:
    69    from apache_beam.transforms import sideinputs
    70    from apache_beam.transforms.core import TimerSpec
    71    from apache_beam.io.iobase import RestrictionProgress
    72  from apache_beam.io.iobase import RestrictionTracker
    73  from apache_beam.io.iobase import WatermarkEstimator
    74  
    75  
    76  class NameContext(object):
    77    """Holds the name information for a step."""
    78    def __init__(self, step_name, transform_id=None):
    79      # type: (str, Optional[str]) -> None
    80  
    81      """Creates a new step NameContext.
    82  
    83      Args:
    84        step_name: The name of the step.
    85      """
    86      self.step_name = step_name
    87      self.transform_id = transform_id
    88  
    89    def __eq__(self, other):
    90      return self.step_name == other.step_name
    91  
    92    def __repr__(self):
    93      return 'NameContext(%s)' % self.__dict__
    94  
    95    def __hash__(self):
    96      return hash(self.step_name)
    97  
    98    def metrics_name(self):
    99      """Returns the step name used for metrics reporting."""
   100      return self.step_name
   101  
   102    def logging_name(self):
   103      """Returns the step name used for logging."""
   104      return self.step_name
   105  
   106  
   107  class Receiver(object):
   108    """For internal use only; no backwards-compatibility guarantees.
   109  
   110    An object that consumes a WindowedValue.
   111  
   112    This class can be efficiently used to pass values between the
   113    SDK and worker harnesses.
   114    """
   115    def receive(self, windowed_value):
   116      # type: (WindowedValue) -> None
   117      raise NotImplementedError
   118  
   119    def receive_batch(self, windowed_batch):
   120      # type: (WindowedBatch) -> None
   121      raise NotImplementedError
   122  
   123    def flush(self):
   124      raise NotImplementedError
   125  
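        # Illustrative sketch (hypothetical, not part of the Beam API): a minimal
        # element-at-a-time Receiver that simply buffers the WindowedValues it is
        # handed, e.g. for tests. Runner harnesses install similar receivers to
        # forward outputs between operations.
        class _BufferingReceiver(Receiver):
          def __init__(self):
            self.received = []  # type: List[WindowedValue]

          def receive(self, windowed_value):
            # type: (WindowedValue) -> None
            self.received.append(windowed_value)

          def flush(self):
            # Nothing is buffered downstream, so flushing is a no-op here.
            pass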
   126  
   127  class MethodWrapper(object):
   128    """For internal use only; no backwards-compatibility guarantees.
   129  
   130    Represents a method that can be invoked by `DoFnInvoker`."""
   131    def __init__(self, obj_to_invoke, method_name):
   132      """
   133      Initializes a ``MethodWrapper``.
   134  
   135      Args:
   136        obj_to_invoke: the object that contains the method. Has to be a `DoFn`,
   137                      `RestrictionProvider` or `WatermarkEstimatorProvider`.
   138        method_name: name of the method as a string.
   139      """
   140  
   141      if not isinstance(obj_to_invoke,
   142                        (DoFn, RestrictionProvider, WatermarkEstimatorProvider)):
   143        raise ValueError(
   144          '\'obj_to_invoke\' has to be a \'DoFn\', \'RestrictionProvider\' or '
   145          '\'WatermarkEstimatorProvider\'. Received %r instead.' % obj_to_invoke)
   146  
   147      self.args, self.defaults = core.get_function_arguments(obj_to_invoke,
   148                                                             method_name)
   149  
   150      # TODO(BEAM-5878) support kwonlyargs on Python 3.
   151      self.method_value = getattr(obj_to_invoke, method_name)
   152      self.method_name = method_name
   153  
   154      self.has_userstate_arguments = False
   155      self.state_args_to_replace = {}  # type: Dict[str, core.StateSpec]
   156      self.timer_args_to_replace = {}  # type: Dict[str, core.TimerSpec]
   157      self.timestamp_arg_name = None  # type: Optional[str]
   158      self.window_arg_name = None  # type: Optional[str]
   159      self.key_arg_name = None  # type: Optional[str]
   160      self.restriction_provider = None
   161      self.restriction_provider_arg_name = None
   162      self.watermark_estimator_provider = None
   163      self.watermark_estimator_provider_arg_name = None
   164      self.dynamic_timer_tag_arg_name = None
   165  
   166      if hasattr(self.method_value, 'unbounded_per_element'):
   167        self.unbounded_per_element = True
   168      else:
   169        self.unbounded_per_element = False
   170  
   171      for kw, v in zip(self.args[-len(self.defaults):], self.defaults):
   172        if isinstance(v, core.DoFn.StateParam):
   173          self.state_args_to_replace[kw] = v.state_spec
   174          self.has_userstate_arguments = True
   175        elif isinstance(v, core.DoFn.TimerParam):
   176          self.timer_args_to_replace[kw] = v.timer_spec
   177          self.has_userstate_arguments = True
   178        elif core.DoFn.TimestampParam == v:
   179          self.timestamp_arg_name = kw
   180        elif core.DoFn.WindowParam == v:
   181          self.window_arg_name = kw
   182        elif core.DoFn.KeyParam == v:
   183          self.key_arg_name = kw
   184        elif isinstance(v, core.DoFn.RestrictionParam):
   185          self.restriction_provider = v.restriction_provider or obj_to_invoke
   186          self.restriction_provider_arg_name = kw
   187        elif isinstance(v, core.DoFn.WatermarkEstimatorParam):
   188          self.watermark_estimator_provider = (
   189              v.watermark_estimator_provider or obj_to_invoke)
   190          self.watermark_estimator_provider_arg_name = kw
   191        elif core.DoFn.DynamicTimerTagParam == v:
   192          self.dynamic_timer_tag_arg_name = kw
   193  
   194      # Create NoOpWatermarkEstimatorProvider if there is no
   195      # WatermarkEstimatorParam provided.
   196      if self.watermark_estimator_provider is None:
   197        self.watermark_estimator_provider = NoOpWatermarkEstimatorProvider()
   198  
   199    def invoke_timer_callback(
   200        self,
   201        user_state_context,
   202        key,
   203        window,
   204        timestamp,
   205        pane_info,
   206        dynamic_timer_tag):
   207      # TODO(ccy): support side inputs.
   208      kwargs = {}
   209      if self.has_userstate_arguments:
   210        for kw, state_spec in self.state_args_to_replace.items():
   211          kwargs[kw] = user_state_context.get_state(state_spec, key, window)
   212        for kw, timer_spec in self.timer_args_to_replace.items():
   213          kwargs[kw] = user_state_context.get_timer(
   214              timer_spec, key, window, timestamp, pane_info)
   215  
   216      if self.timestamp_arg_name:
   217        kwargs[self.timestamp_arg_name] = Timestamp.of(timestamp)
   218      if self.window_arg_name:
   219        kwargs[self.window_arg_name] = window
   220      if self.key_arg_name:
   221        kwargs[self.key_arg_name] = key
   222      if self.dynamic_timer_tag_arg_name:
   223        kwargs[self.dynamic_timer_tag_arg_name] = dynamic_timer_tag
   224  
   225      if kwargs:
   226        return self.method_value(**kwargs)
   227      else:
   228        return self.method_value()
   229  
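        # Illustrative sketch (hypothetical, not part of the Beam API): wrapping a
        # DoFn method records which of its keyword defaults are Beam parameters.
        def _example_method_wrapper():
          class _StampingDoFn(DoFn):
            def process(
                self, element, ts=DoFn.TimestampParam, w=DoFn.WindowParam):
              yield element, ts, w

          wrapper = MethodWrapper(_StampingDoFn(), 'process')
          # The wrapper maps argument names to Beam params, so here
          # wrapper.timestamp_arg_name == 'ts' and wrapper.window_arg_name == 'w'.
          return wrapper.timestamp_arg_name, wrapper.window_arg_name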
   230  
   231  class BatchingPreference(Enum):
   232    DO_NOT_CARE = 1  # This operation can operate on batches or element-at-a-time
   233    # TODO: Should we also store batching parameters here? (time/size preferences)
   234    BATCH_REQUIRED = 2  # This operation can only operate on batches
   235    BATCH_FORBIDDEN = 3  # This operation can only work element-at-a-time
   236    # Other possibilities: BATCH_PREFERRED (with min batch size specified)
   237  
   238    @property
   239    def supports_batches(self) -> bool:
   240      return self in (self.BATCH_REQUIRED, self.DO_NOT_CARE)
   241  
   242    @property
   243    def supports_elements(self) -> bool:
   244      return self in (self.BATCH_FORBIDDEN, self.DO_NOT_CARE)
   245  
   246    @property
   247    def requires_batches(self) -> bool:
   248      return self == self.BATCH_REQUIRED
   249  
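        # Illustrative sketch (hypothetical, not part of the Beam API): the
        # properties above let wiring code ask what shape of input an operation
        # accepts without enumerating the enum members directly.
        def _example_batching_preference():
          assert BatchingPreference.DO_NOT_CARE.supports_batches
          assert BatchingPreference.DO_NOT_CARE.supports_elements
          assert BatchingPreference.BATCH_REQUIRED.requires_batches
          assert not BatchingPreference.BATCH_FORBIDDEN.supports_batches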
   250  
   251  class DoFnSignature(object):
   252    """Represents the signature of a given ``DoFn`` object.
   253  
   254    The signature of a ``DoFn`` provides a view of the properties of a given
   255    ``DoFn``. Among other things, it gives an extensible way of (1) accessing
   256    the structure of the ``DoFn``, including its methods and method parameters,
   257    (2) identifying the features that a given ``DoFn`` supports, for example
   258    whether it is a Splittable ``DoFn``
   259    (https://s.apache.org/splittable-do-fn), and (3) validating a ``DoFn`` based
   260    on the feature set it offers.
   261    """
   262    def __init__(self, do_fn):
   263      # type: (core.DoFn) -> None
   264      # We add a property here for all methods defined by Beam DoFn features.
   265  
   266      assert isinstance(do_fn, core.DoFn)
   267      self.do_fn = do_fn
   268  
   269      self.process_method = MethodWrapper(do_fn, 'process')
   270      self.process_batch_method = MethodWrapper(do_fn, 'process_batch')
   271      self.start_bundle_method = MethodWrapper(do_fn, 'start_bundle')
   272      self.finish_bundle_method = MethodWrapper(do_fn, 'finish_bundle')
   273      self.setup_lifecycle_method = MethodWrapper(do_fn, 'setup')
   274      self.teardown_lifecycle_method = MethodWrapper(do_fn, 'teardown')
   275  
   276      restriction_provider = self.get_restriction_provider()
   277      watermark_estimator_provider = self.get_watermark_estimator_provider()
   278      self.create_watermark_estimator_method = (
   279          MethodWrapper(
   280              watermark_estimator_provider, 'create_watermark_estimator'))
   281      self.initial_restriction_method = (
   282          MethodWrapper(restriction_provider, 'initial_restriction')
   283          if restriction_provider else None)
   284      self.create_tracker_method = (
   285          MethodWrapper(restriction_provider, 'create_tracker')
   286          if restriction_provider else None)
   287      self.split_method = (
   288          MethodWrapper(restriction_provider, 'split')
   289          if restriction_provider else None)
   290  
   291      self._validate()
   292  
   293      # Handle stateful DoFns.
   294      self._is_stateful_dofn = userstate.is_stateful_dofn(do_fn)
   295      self.timer_methods = {}  # type: Dict[TimerSpec, MethodWrapper]
   296      if self._is_stateful_dofn:
   297        # Populate timer firing methods, keyed by TimerSpec.
   298        _, all_timer_specs = userstate.get_dofn_specs(do_fn)
   299        for timer_spec in all_timer_specs:
   300          method = timer_spec._attached_callback
   301          self.timer_methods[timer_spec] = MethodWrapper(do_fn, method.__name__)
   302  
   303    def get_restriction_provider(self):
   304      # type: () -> RestrictionProvider
   305      return self.process_method.restriction_provider
   306  
   307    def get_watermark_estimator_provider(self):
   308      # type: () -> WatermarkEstimatorProvider
   309      return self.process_method.watermark_estimator_provider
   310  
   311    def is_unbounded_per_element(self):
   312      return self.process_method.unbounded_per_element
   313  
   314    def _validate(self):
   315      # type: () -> None
   316      self._validate_process()
   317      self._validate_process_batch()
   318      self._validate_bundle_method(self.start_bundle_method)
   319      self._validate_bundle_method(self.finish_bundle_method)
   320      self._validate_stateful_dofn()
   321  
   322    def _check_duplicate_dofn_params(self, method: MethodWrapper):
   323      param_ids = [
   324          d.param_id for d in method.defaults if isinstance(d, core._DoFnParam)
   325      ]
   326      if len(param_ids) != len(set(param_ids)):
   327        raise ValueError(
   328            'DoFn %r has duplicate %s method parameters: %s.' %
   329            (self.do_fn, method.method_name, param_ids))
   330  
   331    def _validate_process(self):
   332      # type: () -> None
   333  
   334      """Validate that none of the DoFnParameters are repeated in the function
   335      """
   336      self._check_duplicate_dofn_params(self.process_method)
   337  
   338    def _validate_process_batch(self):
   339      # type: () -> None
   340      self._check_duplicate_dofn_params(self.process_batch_method)
   341  
   342      for d in self.process_batch_method.defaults:
   343        if not isinstance(d, core._DoFnParam):
   344          continue
   345  
   346        # Helpful errors for params which will be supported in the future
   347        if d == (core.DoFn.ElementParam):
   348          # We currently assume we can just get the typehint from the first
   349          # parameter. ElementParam breaks this assumption
   350          raise NotImplementedError(
   351              f"DoFn {self.do_fn!r} uses unsupported DoFn param ElementParam.")
   352  
   353        if d in (core.DoFn.KeyParam, core.DoFn.StateParam, core.DoFn.TimerParam):
   354          raise NotImplementedError(
   355              f"DoFn {self.do_fn!r} has unsupported per-key DoFn param {d}. "
   356              "Per-key DoFn params are not yet supported for process_batch "
   357              "(https://github.com/apache/beam/issues/21653).")
   358  
   359        # Fallback to catch anything not explicitly supported
   360        if not d in (core.DoFn.WindowParam,
   361                     core.DoFn.TimestampParam,
   362                     core.DoFn.PaneInfoParam):
   363          raise ValueError(
   364              f"DoFn {self.do_fn!r} has unsupported process_batch "
   365              f"method parameter {d}")
   366  
   367    def _validate_bundle_method(self, method_wrapper):
   368      """Validate that none of the DoFnParameters are used in the function
   369      """
   370      for param in core.DoFn.DoFnProcessParams:
   371        if param in method_wrapper.defaults:
   372          raise ValueError(
   373              'DoFn.process() method-only parameter %s cannot be used in %s.' %
   374              (param, method_wrapper))
   375  
   376    def _validate_stateful_dofn(self):
   377      # type: () -> None
   378      userstate.validate_stateful_dofn(self.do_fn)
   379  
   380    def is_splittable_dofn(self):
   381      # type: () -> bool
   382      return self.get_restriction_provider() is not None
   383  
   384    def get_restriction_coder(self):
   385      # type: () -> Optional[TupleCoder]
   386  
   387      """Get coder for a restriction when processing an SDF. """
   388      if self.is_splittable_dofn():
   389        return TupleCoder([
   390            (self.get_restriction_provider().restriction_coder()),
   391            (self.get_watermark_estimator_provider().estimator_state_coder())
   392        ])
   393      else:
   394        return None
   395  
   396    def is_stateful_dofn(self):
   397      # type: () -> bool
   398      return self._is_stateful_dofn
   399  
   400    def has_timers(self):
   401      # type: () -> bool
   402      _, all_timer_specs = userstate.get_dofn_specs(self.do_fn)
   403      return bool(all_timer_specs)
   404  
   405    def has_bundle_finalization(self):
   406      for sig in (self.start_bundle_method,
   407                  self.process_method,
   408                  self.finish_bundle_method):
   409        for d in sig.defaults:
   410          try:
   411            if d == DoFn.BundleFinalizerParam:
   412              return True
   413          except Exception:  # pylint: disable=broad-except
   414            # Default value might be incomparable.
   415            pass
   416      return False
   417  
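        # Illustrative sketch (hypothetical, not part of the Beam API):
        # constructing a DoFnSignature for a plain DoFn and querying the features
        # it was validated for.
        def _example_dofn_signature():
          class _IdentityDoFn(DoFn):
            def process(self, element):
              yield element

          signature = DoFnSignature(_IdentityDoFn())
          # With no RestrictionParam, state or timers declared, the DoFn is
          # neither splittable nor stateful and has no restriction coder:
          #   signature.is_splittable_dofn() is False
          #   signature.is_stateful_dofn() is False
          #   signature.get_restriction_coder() is None
          return signature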
   418  
   419  class DoFnInvoker(object):
   420    """An abstraction that can be used to execute DoFn methods.
   421  
   422    A DoFnInvoker describes a particular way for invoking methods of a DoFn
   423    represented by a given DoFnSignature."""
   424  
   425    def __init__(self,
   426                 output_handler,  # type: _OutputHandler
   427                 signature  # type: DoFnSignature
   428                ):
   429      # type: (...) -> None
   430  
   431      """
   432      Initializes `DoFnInvoker`
   433  
   434      :param output_handler: an OutputHandler for receiving elements produced
   435                               by invoking functions of the DoFn.
   436      :param signature: a DoFnSignature for the DoFn being invoked
   437      """
   438      self.output_handler = output_handler
   439      self.signature = signature
   440      self.user_state_context = None  # type: Optional[userstate.UserStateContext]
   441      self.bundle_finalizer_param = None  # type: Optional[core._BundleFinalizerParam]
   442  
   443    @staticmethod
   444    def create_invoker(
   445        signature,  # type: DoFnSignature
   446        output_handler,  # type: OutputHandler
   447        context=None,  # type: Optional[DoFnContext]
   448        side_inputs=None,   # type: Optional[List[sideinputs.SideInputMap]]
   449        input_args=None, input_kwargs=None,
   450        process_invocation=True,
   451        user_state_context=None,  # type: Optional[userstate.UserStateContext]
   452        bundle_finalizer_param=None  # type: Optional[core._BundleFinalizerParam]
   453    ):
   454      # type: (...) -> DoFnInvoker
   455  
   456      """ Creates a new DoFnInvoker based on given arguments.
   457  
   458      Args:
   459          output_handler: an OutputHandler for receiving elements produced by
   460                            invoking functions of the DoFn.
   461          signature: a DoFnSignature for the DoFn being invoked.
   462          context: Context to be used when invoking the DoFn (deprecated).
   463          side_inputs: side inputs to be used when invoking the process method.
   464          input_args: arguments to be used when invoking the process method. Some
   465                      of the arguments given here might be placeholders (for
   466                      example for side inputs) that get filled before invoking the
   467                      process method.
   468          input_kwargs: keyword arguments to be used when invoking the process
   469                        method. Some of the keyword arguments given here might be
   470                        placeholders (for example for side inputs) that get filled
   471                        before invoking the process method.
   472          process_invocation: If True, this function may return an invoker that
   473                              performs extra optimizations for invoking process()
   474                              method efficiently.
   475          user_state_context: The UserStateContext instance for the current
   476                              Stateful DoFn.
   477          bundle_finalizer_param: The param that is passed to the process method,
   478                                  which allows a callback to be registered.
   479      """
   480      side_inputs = side_inputs or []
   481      use_per_window_invoker = process_invocation and (
   482          side_inputs or input_args or input_kwargs or
   483          signature.process_method.defaults or
   484          signature.process_batch_method.defaults or signature.is_stateful_dofn())
   485      if not use_per_window_invoker:
   486        return SimpleInvoker(output_handler, signature)
   487      else:
   488        if context is None:
   489          raise TypeError("Must provide context when not using SimpleInvoker")
   490        return PerWindowInvoker(
   491            output_handler,
   492            signature,
   493            context,
   494            side_inputs,
   495            input_args,
   496            input_kwargs,
   497            user_state_context,
   498            bundle_finalizer_param)
   499  
   500    def invoke_process(self,
   501                       windowed_value,  # type: WindowedValue
   502                       restriction=None,
   503                       watermark_estimator_state=None,
   504                       additional_args=None,
   505                       additional_kwargs=None
   506                      ):
   507      # type: (...) -> Iterable[SplitResultResidual]
   508  
   509      """Invokes the DoFn.process() function.
   510  
   511      Args:
   512        windowed_value: a WindowedValue object that gives the element for which
   513                        process() method should be invoked along with the window
   514                        the element belongs to.
   515        restriction: The restriction to use when executing this splittable DoFn.
   516                     Should only be specified for splittable DoFns.
   517        watermark_estimator_state: The watermark estimator state to use when
   518                                   executing this splittable DoFn. Should only
   519                                   be specified for splittable DoFns.
   520        additional_args: additional arguments to be passed to the current
   521                        `DoFn.process()` invocation, usually as side inputs.
   522        additional_kwargs: additional keyword arguments to be passed to the
   523                           current `DoFn.process()` invocation.
   524      """
   525      raise NotImplementedError
   526  
   527    def invoke_process_batch(self,
   528                       windowed_batch,  # type: WindowedBatch
   529                       additional_args=None,
   530                       additional_kwargs=None
   531                      ):
   532      # type: (...) -> None
   533  
   534      """Invokes the DoFn.process_batch() function.
   535  
   536      Args:
   537        windowed_batch: a WindowedBatch object that gives a batch of elements for
   538                        which process_batch() method should be invoked, along with
   539                        the window each element belongs to.
   540        additional_args: additional arguments to be passed to the current
   541                        `DoFn.process_batch()` invocation, usually as side inputs.
   542        additional_kwargs: additional keyword arguments to be passed to the
   543                           current `DoFn.process_batch()` invocation.
   544      """
   545      raise NotImplementedError
   546  
   547    def invoke_setup(self):
   548      # type: () -> None
   549  
   550      """Invokes the DoFn.setup() method
   551      """
   552      self.signature.setup_lifecycle_method.method_value()
   553  
   554    def invoke_start_bundle(self):
   555      # type: () -> None
   556  
   557      """Invokes the DoFn.start_bundle() method.
   558      """
   559      self.output_handler.start_bundle_outputs(
   560          self.signature.start_bundle_method.method_value())
   561  
   562    def invoke_finish_bundle(self):
   563      # type: () -> None
   564  
   565      """Invokes the DoFn.finish_bundle() method.
   566      """
   567      self.output_handler.finish_bundle_outputs(
   568          self.signature.finish_bundle_method.method_value())
   569  
   570    def invoke_teardown(self):
   571      # type: () -> None
   572  
   573      """Invokes the DoFn.teardown() method
   574      """
   575      self.signature.teardown_lifecycle_method.method_value()
   576  
   577    def invoke_user_timer(
   578        self, timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag):
   579      # self.output_handler is Optional, but in practice it won't be None here
   580      self.output_handler.handle_process_outputs(
   581          WindowedValue(None, timestamp, (window, )),
   582          self.signature.timer_methods[timer_spec].invoke_timer_callback(
   583              self.user_state_context,
   584              key,
   585              window,
   586              timestamp,
   587              pane_info,
   588              dynamic_timer_tag))
   589  
   590    def invoke_create_watermark_estimator(self, estimator_state):
   591      return self.signature.create_watermark_estimator_method.method_value(
   592          estimator_state)
   593  
   594    def invoke_split(self, element, restriction):
   595      return self.signature.split_method.method_value(element, restriction)
   596  
   597    def invoke_initial_restriction(self, element):
   598      return self.signature.initial_restriction_method.method_value(element)
   599  
   600    def invoke_create_tracker(self, restriction):
   601      return self.signature.create_tracker_method.method_value(restriction)
   602  
   603  
   604  class SimpleInvoker(DoFnInvoker):
   605    """An invoker that processes elements ignoring windowing information."""
   606  
   607    def __init__(self,
   608                 output_handler,  # type: OutputHandler
   609                 signature  # type: DoFnSignature
   610                ):
   611      # type: (...) -> None
   612      super().__init__(output_handler, signature)
   613      self.process_method = signature.process_method.method_value
   614      self.process_batch_method = signature.process_batch_method.method_value
   615  
   616    def invoke_process(self,
   617                       windowed_value,  # type: WindowedValue
   618                       restriction=None,
   619                       watermark_estimator_state=None,
   620                       additional_args=None,
   621                       additional_kwargs=None
   622                      ):
   623      # type: (...) -> Iterable[SplitResultResidual]
   624      self.output_handler.handle_process_outputs(
   625          windowed_value, self.process_method(windowed_value.value))
   626      return []
   627  
   628    def invoke_process_batch(self,
   629                       windowed_batch,  # type: WindowedBatch
   630                       restriction=None,
   631                       watermark_estimator_state=None,
   632                       additional_args=None,
   633                       additional_kwargs=None
   634                      ):
   635      # type: (...) -> None
   636      self.output_handler.handle_process_batch_outputs(
   637          windowed_batch, self.process_batch_method(windowed_batch.values))
   638  
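        # Illustrative sketch (hypothetical, not part of the Beam API): for a DoFn
        # whose process() takes only the element, create_invoker picks the
        # lightweight SimpleInvoker; windowed side inputs, extra DoFn params, or
        # user state would select PerWindowInvoker instead. The output handler is
        # left to the caller here purely for illustration.
        def _example_create_invoker(output_handler):
          class _PassThroughDoFn(DoFn):
            def process(self, element):
              yield element

          signature = DoFnSignature(_PassThroughDoFn())
          invoker = DoFnInvoker.create_invoker(signature, output_handler)
          # isinstance(invoker, SimpleInvoker) holds for this signature.
          return invoker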
   639  
   640  def _get_arg_placeholders(
   641      method: MethodWrapper,
   642      input_args: Optional[List[Any]],
   643      input_kwargs: Optional[Dict[str, Any]]):
   644    input_args = input_args if input_args else []
   645    input_kwargs = input_kwargs if input_kwargs else {}
   646  
   647    arg_names = method.args
   648    default_arg_values = method.defaults
   649  
   650    # Create placeholder for element parameter of DoFn.process() method.
   651    # Not to be confused with ArgumentPlaceHolder, which may be passed in
   652    # input_args and is a placeholder for side-inputs.
   653    class ArgPlaceholder(object):
   654      def __init__(self, placeholder):
   655        self.placeholder = placeholder
   656  
   657    if all(core.DoFn.ElementParam != arg for arg in default_arg_values):
   658      # TODO(https://github.com/apache/beam/issues/19631): Handle cases in which
   659      #   len(arg_names) == len(default_arg_values).
   660      args_to_pick = len(arg_names) - len(default_arg_values) - 1
   661      # Positional argument values for process(), with placeholders for special
   662      # values such as the element, timestamp, etc.
   663      args_with_placeholders = ([ArgPlaceholder(core.DoFn.ElementParam)] +
   664                                input_args[:args_to_pick])
   665    else:
   666      args_to_pick = len(arg_names) - len(default_arg_values)
   667      args_with_placeholders = input_args[:args_to_pick]
   668  
   669    # Fill the OtherPlaceholders for context, key, window or timestamp
   670    remaining_args_iter = iter(input_args[args_to_pick:])
   671    for a, d in zip(arg_names[-len(default_arg_values):], default_arg_values):
   672      if core.DoFn.ElementParam == d:
   673        args_with_placeholders.append(ArgPlaceholder(d))
   674      elif core.DoFn.KeyParam == d:
   675        args_with_placeholders.append(ArgPlaceholder(d))
   676      elif core.DoFn.WindowParam == d:
   677        args_with_placeholders.append(ArgPlaceholder(d))
   678      elif core.DoFn.TimestampParam == d:
   679        args_with_placeholders.append(ArgPlaceholder(d))
   680      elif core.DoFn.PaneInfoParam == d:
   681        args_with_placeholders.append(ArgPlaceholder(d))
   682      elif core.DoFn.SideInputParam == d:
   683        # If no more args are present then the value must be passed via kwarg
   684        try:
   685          args_with_placeholders.append(next(remaining_args_iter))
   686        except StopIteration:
   687          if a not in input_kwargs:
   688            raise ValueError("Value for sideinput %s not provided" % a)
   689      elif isinstance(d, core.DoFn.StateParam):
   690        args_with_placeholders.append(ArgPlaceholder(d))
   691      elif isinstance(d, core.DoFn.TimerParam):
   692        args_with_placeholders.append(ArgPlaceholder(d))
   693      elif isinstance(d, type) and core.DoFn.BundleFinalizerParam == d:
   694        args_with_placeholders.append(ArgPlaceholder(d))
   695      else:
   696        # If no more args are present then the value must be passed via kwarg
   697        try:
   698          args_with_placeholders.append(next(remaining_args_iter))
   699        except StopIteration:
   700          pass
   701    args_with_placeholders.extend(list(remaining_args_iter))
   702  
   703    # Stash the list of placeholder positions for performance
   704    placeholders = [(i, x.placeholder)
   705                    for (i, x) in enumerate(args_with_placeholders)
   706                    if isinstance(x, ArgPlaceholder)]
   707  
   708    return placeholders, args_with_placeholders, input_kwargs
   709  
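        # Illustrative sketch (hypothetical, not part of the Beam API):
        # placeholder extraction for a process method that also asks for its
        # window. PerWindowInvoker later patches the recorded positions with the
        # real element and window for every value it processes.
        def _example_arg_placeholders():
          class _WindowedDoFn(DoFn):
            def process(self, element, window=DoFn.WindowParam):
              yield element, window

          method = DoFnSignature(_WindowedDoFn()).process_method
          placeholders, args_template, kwargs_template = _get_arg_placeholders(
              method, input_args=None, input_kwargs=None)
          # placeholders is a list of (position, param) pairs, here roughly
          # [(0, DoFn.ElementParam), (1, DoFn.WindowParam)].
          return placeholders, args_template, kwargs_template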
   710  
   711  class PerWindowInvoker(DoFnInvoker):
   712    """An invoker that processes elements considering windowing information."""
   713  
   714    def __init__(self,
   715                 output_handler,  # type: OutputHandler
   716                 signature,  # type: DoFnSignature
   717                 context,  # type: DoFnContext
   718                 side_inputs,  # type: Iterable[sideinputs.SideInputMap]
   719                 input_args,
   720                 input_kwargs,
   721                 user_state_context,  # type: Optional[userstate.UserStateContext]
   722                 bundle_finalizer_param  # type: Optional[core._BundleFinalizerParam]
   723                ):
   724      super().__init__(output_handler, signature)
   725      self.side_inputs = side_inputs
   726      self.context = context
   727      self.process_method = signature.process_method.method_value
   728      default_arg_values = signature.process_method.defaults
   729      self.has_windowed_inputs = (
   730          not all(si.is_globally_windowed() for si in side_inputs) or any(
   731              core.DoFn.WindowParam == arg
   732              for arg in signature.process_method.defaults) or any(
   733                  core.DoFn.WindowParam == arg
   734                  for arg in signature.process_batch_method.defaults) or
   735          signature.is_stateful_dofn())
   736      self.user_state_context = user_state_context
   737      self.is_splittable = signature.is_splittable_dofn()
   738      self.is_key_param_required = any(
   739          core.DoFn.KeyParam == arg for arg in default_arg_values)
   740      self.threadsafe_restriction_tracker = None  # type: Optional[ThreadsafeRestrictionTracker]
   741      self.threadsafe_watermark_estimator = None  # type: Optional[ThreadsafeWatermarkEstimator]
   742      self.current_windowed_value = None  # type: Optional[WindowedValue]
   743      self.bundle_finalizer_param = bundle_finalizer_param
   744      if self.is_splittable:
   745        self.splitting_lock = threading.Lock()
   746        self.current_window_index = None
   747        self.stop_window_index = None
   748  
   749      # Flag to cache additional arguments on the first element if all
   750      # inputs are within the global window.
   751      self.cache_globally_windowed_args = not self.has_windowed_inputs
   752  
   753      # Try to prepare all the arguments that can just be filled in
   754      # without any additional work in the process function.
   755      # Also cache all the placeholders needed in the process function.
   756      (
   757          self.placeholders_for_process,
   758          self.args_for_process,
   759          self.kwargs_for_process) = _get_arg_placeholders(
   760              signature.process_method, input_args, input_kwargs)
   761  
   762      self.process_batch_method = signature.process_batch_method.method_value
   763  
   764      (
   765          self.placeholders_for_process_batch,
   766          self.args_for_process_batch,
   767          self.kwargs_for_process_batch) = _get_arg_placeholders(
   768              signature.process_batch_method, input_args, input_kwargs)
   769  
   770    def invoke_process(self,
   771                       windowed_value,  # type: WindowedValue
   772                       restriction=None,
   773                       watermark_estimator_state=None,
   774                       additional_args=None,
   775                       additional_kwargs=None
   776                      ):
   777      # type: (...) -> Iterable[SplitResultResidual]
   778      if not additional_args:
   779        additional_args = []
   780      if not additional_kwargs:
   781        additional_kwargs = {}
   782  
   783      self.context.set_element(windowed_value)
   784      # Call the process function once per window if there are windowed side
   785      # inputs or if process accesses the window parameter; otherwise call it
   786      # just once, since none of the arguments change.
   787  
   788      residuals = []
   789      if self.is_splittable:
   790        if restriction is None:
   791          # This may be a SDF invoked as an ordinary DoFn on runners that don't
   792          # understand SDF.  See, e.g. BEAM-11472.
   793          # In this case, processing the element is simply processing it against
   794          # the entire initial restriction.
   795          restriction = self.signature.initial_restriction_method.method_value(
   796              windowed_value.value)
   797  
   798        with self.splitting_lock:
   799          self.current_windowed_value = windowed_value
   800          self.restriction = restriction
   801          self.watermark_estimator_state = watermark_estimator_state
   802        try:
   803          if self.has_windowed_inputs and len(windowed_value.windows) > 1:
   804            for i, w in enumerate(windowed_value.windows):
   805              if not self._should_process_window_for_sdf(
   806                  windowed_value, additional_kwargs, i):
   807                break
   808              residual = self._invoke_process_per_window(
   809                  WindowedValue(
   810                      windowed_value.value, windowed_value.timestamp, (w, )),
   811                  additional_args,
   812                  additional_kwargs)
   813              if residual:
   814                residuals.append(residual)
   815          else:
   816            if self._should_process_window_for_sdf(windowed_value,
   817                                                   additional_kwargs):
   818              residual = self._invoke_process_per_window(
   819                  windowed_value, additional_args, additional_kwargs)
   820              if residual:
   821                residuals.append(residual)
   822        finally:
   823          with self.splitting_lock:
   824            self.current_windowed_value = None
   825            self.restriction = None
   826            self.watermark_estimator_state = None
   827            self.current_window_index = None
   828            self.threadsafe_restriction_tracker = None
   829            self.threadsafe_watermark_estimator = None
   830      elif self.has_windowed_inputs and len(windowed_value.windows) != 1:
   831        for w in windowed_value.windows:
   832          self._invoke_process_per_window(
   833              WindowedValue(
   834                  windowed_value.value, windowed_value.timestamp, (w, )),
   835              additional_args,
   836              additional_kwargs)
   837      else:
   838        self._invoke_process_per_window(
   839            windowed_value, additional_args, additional_kwargs)
   840      return residuals
   841  
   842    def invoke_process_batch(self,
   843                       windowed_batch,  # type: WindowedBatch
   844                       additional_args=None,
   845                       additional_kwargs=None
   846                      ):
   847      # type: (...) -> None
   848  
   849      if not additional_args:
   850        additional_args = []
   851      if not additional_kwargs:
   852        additional_kwargs = {}
   853  
   854      assert isinstance(windowed_batch, HomogeneousWindowedBatch)
   855  
   856      if self.has_windowed_inputs and len(windowed_batch.windows) != 1:
   857        for w in windowed_batch.windows:
   858          self._invoke_process_batch_per_window(
   859              HomogeneousWindowedBatch.of(
   860                  windowed_batch.values,
   861                  windowed_batch.timestamp, (w, ),
   862                  windowed_batch.pane_info),
   863              additional_args,
   864              additional_kwargs)
   865      else:
   866        self._invoke_process_batch_per_window(
   867            windowed_batch, additional_args, additional_kwargs)
   868  
   869    def _should_process_window_for_sdf(
   870        self,
   871        windowed_value, # type: WindowedValue
   872        additional_kwargs,
   873        window_index=None, # type: Optional[int]
   874    ):
   875      restriction_tracker = self.invoke_create_tracker(self.restriction)
   876      watermark_estimator = self.invoke_create_watermark_estimator(
   877          self.watermark_estimator_state)
   878      with self.splitting_lock:
   879        if window_index is not None:
   880          self.current_window_index = window_index
   881          if window_index == 0:
   882            self.stop_window_index = len(windowed_value.windows)
   883          if window_index == self.stop_window_index:
   884            return False
   885        self.threadsafe_restriction_tracker = ThreadsafeRestrictionTracker(
   886            restriction_tracker)
   887        self.threadsafe_watermark_estimator = (
   888            ThreadsafeWatermarkEstimator(watermark_estimator))
   889  
   890      restriction_tracker_param = (
   891          self.signature.process_method.restriction_provider_arg_name)
   892      if not restriction_tracker_param:
   893        raise ValueError(
   894            'DoFn is splittable but DoFn does not have a '
   895            'RestrictionTrackerParam defined')
   896      additional_kwargs[restriction_tracker_param] = (
   897          RestrictionTrackerView(self.threadsafe_restriction_tracker))
   898      watermark_param = (
   899          self.signature.process_method.watermark_estimator_provider_arg_name)
   900      # When the watermark_estimator is a NoOpWatermarkEstimator, the system
   901      # will not add watermark_param into the DoFn param list.
   902      if watermark_param is not None:
   903        additional_kwargs[watermark_param] = self.threadsafe_watermark_estimator
   904      return True
   905  
   906    def _invoke_process_per_window(self,
   907                                   windowed_value,  # type: WindowedValue
   908                                   additional_args,
   909                                   additional_kwargs,
   910                                  ):
   911      # type: (...) -> Optional[SplitResultResidual]
   912  
   913      if self.has_windowed_inputs:
   914        assert len(windowed_value.windows) <= 1
   915        window, = windowed_value.windows
   916        side_inputs = [si[window] for si in self.side_inputs]
   917        side_inputs.extend(additional_args)
   918        args_for_process, kwargs_for_process = util.insert_values_in_args(
   919            self.args_for_process, self.kwargs_for_process,
   920            side_inputs)
   921      elif self.cache_globally_windowed_args:
   922      # Attempt to cache the additional args when processing the first
   923      # element if all inputs are globally windowed.
   924        self.cache_globally_windowed_args = False
   925  
   926        # Fill in sideInputs if they are globally windowed
   927        global_window = GlobalWindow()
   928        self.args_for_process, self.kwargs_for_process = (
   929            util.insert_values_in_args(
   930                self.args_for_process, self.kwargs_for_process,
   931                [si[global_window] for si in self.side_inputs]))
   932        args_for_process, kwargs_for_process = (
   933            self.args_for_process, self.kwargs_for_process)
   934      else:
   935        args_for_process, kwargs_for_process = (
   936            self.args_for_process, self.kwargs_for_process)
   937  
   938      # Extract the key in the case of a stateful DoFn. Note that for a
   939      # stateful DoFn we set self.has_windowed_inputs to True during __init__,
   940      # so windows will have been exploded coming into this method, and we can
   941      # rely on the window variable being set above.
   942      if self.user_state_context or self.is_key_param_required:
   943        try:
   944          key, unused_value = windowed_value.value
   945        except (TypeError, ValueError):
   946          raise ValueError((
   947              'Input value to a stateful DoFn or KeyParam must be a KV tuple; '
   948              'instead, got \'%s\'.') % (windowed_value.value, ))
   949  
   950      for i, p in self.placeholders_for_process:
   951        if core.DoFn.ElementParam == p:
   952          args_for_process[i] = windowed_value.value
   953        elif core.DoFn.KeyParam == p:
   954          args_for_process[i] = key
   955        elif core.DoFn.WindowParam == p:
   956          args_for_process[i] = window
   957        elif core.DoFn.TimestampParam == p:
   958          args_for_process[i] = windowed_value.timestamp
   959        elif core.DoFn.PaneInfoParam == p:
   960          args_for_process[i] = windowed_value.pane_info
   961        elif isinstance(p, core.DoFn.StateParam):
   962          assert self.user_state_context is not None
   963          args_for_process[i] = (
   964              self.user_state_context.get_state(p.state_spec, key, window))
   965        elif isinstance(p, core.DoFn.TimerParam):
   966          assert self.user_state_context is not None
   967          args_for_process[i] = (
   968              self.user_state_context.get_timer(
   969                  p.timer_spec,
   970                  key,
   971                  window,
   972                  windowed_value.timestamp,
   973                  windowed_value.pane_info))
   974        elif core.DoFn.BundleFinalizerParam == p:
   975          args_for_process[i] = self.bundle_finalizer_param
   976  
   977      kwargs_for_process = kwargs_for_process or {}
   978  
   979      if additional_kwargs:
   980        kwargs_for_process.update(additional_kwargs)
   981  
   982      self.output_handler.handle_process_outputs(
   983          windowed_value,
   984          self.process_method(*args_for_process, **kwargs_for_process),
   985          self.threadsafe_watermark_estimator)
   986  
   987      if self.is_splittable:
   988        assert self.threadsafe_restriction_tracker is not None
   989        self.threadsafe_restriction_tracker.check_done()
   990        deferred_status = self.threadsafe_restriction_tracker.deferred_status()
   991        if deferred_status:
   992          deferred_restriction, deferred_timestamp = deferred_status
   993          element = windowed_value.value
   994          size = self.signature.get_restriction_provider().restriction_size(
   995              element, deferred_restriction)
   996          if size < 0:
   997            raise ValueError('Expected size >= 0 but received %s.' % size)
   998          current_watermark = (
   999              self.threadsafe_watermark_estimator.current_watermark())
  1000          estimator_state = (
  1001              self.threadsafe_watermark_estimator.get_estimator_state())
  1002          residual_value = ((element, (deferred_restriction, estimator_state)),
  1003                            size)
  1004          return SplitResultResidual(
  1005              residual_value=windowed_value.with_value(residual_value),
  1006              current_watermark=current_watermark,
  1007              deferred_timestamp=deferred_timestamp)
  1008      return None
  1009  
  1010    def _invoke_process_batch_per_window(
  1011        self,
  1012        windowed_batch: WindowedBatch,
  1013        additional_args,
  1014        additional_kwargs,
  1015    ):
  1016      # type: (...) -> Optional[SplitResultResidual]
  1017  
  1018      if self.has_windowed_inputs:
  1019        assert isinstance(windowed_batch, HomogeneousWindowedBatch)
  1020        assert len(windowed_batch.windows) <= 1
  1021  
  1022        window, = windowed_batch.windows
  1023        side_inputs = [si[window] for si in self.side_inputs]
  1024        side_inputs.extend(additional_args)
  1025        (args_for_process_batch,
  1026         kwargs_for_process_batch) = util.insert_values_in_args(
  1027             self.args_for_process_batch,
  1028             self.kwargs_for_process_batch,
  1029             side_inputs)
  1030      elif self.cache_globally_windowed_args:
  1031      # Attempt to cache the additional args when processing the first
  1032      # element if all inputs are globally windowed.
  1033        self.cache_globally_windowed_args = False
  1034  
  1035        # Fill in sideInputs if they are globally windowed
  1036        global_window = GlobalWindow()
  1037        self.args_for_process_batch, self.kwargs_for_process_batch = (
  1038            util.insert_values_in_args(
  1039                self.args_for_process_batch, self.kwargs_for_process_batch,
  1040                [si[global_window] for si in self.side_inputs]))
  1041        args_for_process_batch, kwargs_for_process_batch = (
  1042            self.args_for_process_batch, self.kwargs_for_process_batch)
  1043      else:
  1044        args_for_process_batch, kwargs_for_process_batch = (
  1045            self.args_for_process_batch, self.kwargs_for_process_batch)
  1046  
  1047      for i, p in self.placeholders_for_process_batch:
  1048        if core.DoFn.ElementParam == p:
  1049          args_for_process_batch[i] = windowed_batch.values
  1050        elif core.DoFn.KeyParam == p:
  1051          raise NotImplementedError(
  1052              "https://github.com/apache/beam/issues/21653: "
  1053              "Per-key process_batch")
  1054        elif core.DoFn.WindowParam == p:
  1055          args_for_process_batch[i] = window
  1056        elif core.DoFn.TimestampParam == p:
  1057          args_for_process_batch[i] = windowed_batch.timestamp
  1058        elif core.DoFn.PaneInfoParam == p:
  1059          assert isinstance(windowed_batch, HomogeneousWindowedBatch)
  1060          args_for_process_batch[i] = windowed_batch.pane_info
  1061        elif isinstance(p, core.DoFn.StateParam):
  1062          raise NotImplementedError(
  1063              "https://github.com/apache/beam/issues/21653: "
  1064              "Per-key process_batch")
  1065        elif isinstance(p, core.DoFn.TimerParam):
  1066          raise NotImplementedError(
  1067              "https://github.com/apache/beam/issues/21653: "
  1068              "Per-key process_batch")
  1069  
  1070      kwargs_for_process_batch = kwargs_for_process_batch or {}
  1071  
  1072      self.output_handler.handle_process_batch_outputs(
  1073          windowed_batch,
  1074          self.process_batch_method(
  1075              *args_for_process_batch, **kwargs_for_process_batch),
  1076          self.threadsafe_watermark_estimator)
  1077  
  1078    @staticmethod
  1079    def _try_split(fraction,
  1080        window_index, # type: Optional[int]
  1081        stop_window_index, # type: Optional[int]
  1082        windowed_value, # type: WindowedValue
  1083        restriction,
  1084        watermark_estimator_state,
  1085        restriction_provider, # type: RestrictionProvider
  1086        restriction_tracker, # type: RestrictionTracker
  1087        watermark_estimator, # type: WatermarkEstimator
  1088                   ):
  1089      # type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual], Optional[int]]]
  1090  
  1091      """Try to split, returning primaries, residuals and a new stop index.
  1092  
  1093      For non-window observing splittable DoFns we split the current restriction
  1094      and assign the primary and residual to all the windows.
  1095  
  1096      For window observing splittable DoFns, we:
  1097      1) return a split at a window boundary if the fraction lies outside of the
  1098         current window.
  1099      2) attempt to split the current restriction, if successful then return
  1100         the primary and residual for the current window and an additional
  1101         primary and residual for any fully processed and fully unprocessed
  1102         windows.
  1103      3) fall back to returning a split at the window boundary if possible
  1104  
  1105      Args:
  1106        window_index: the current index of the window being processed or None
  1107                      if the splittable DoFn is not window observing.
  1108        stop_window_index: the current index to stop processing at or None
  1109                           if the splittable DoFn is not window observing.
  1110        windowed_value: the current windowed value
  1111        restriction: the initial restriction when processing was started.
  1112        watermark_estimator_state: the initial watermark estimator state when
  1113                                   processing was started.
  1114        restriction_provider: the DoFn's restriction provider
  1115        restriction_tracker: the current restriction tracker
  1116        watermark_estimator: the current watermark estimator
  1117  
  1118      Returns:
  1119        A tuple containing (primaries, residuals, new_stop_index) or None if
  1120        splitting was not possible. new_stop_index will only be set if the
  1121        splittable DoFn is window observing otherwise it will be None.
  1122      """
  1123      def compute_whole_window_split(to_index, from_index):
  1124        restriction_size = restriction_provider.restriction_size(
  1125            windowed_value.value, restriction)
  1126        if restriction_size < 0:
  1127          raise ValueError(
  1128              'Expected size >= 0 but received %s.' % restriction_size)
  1129        # The primary and residual both share the same value only differing
  1130        # by the set of windows they are in.
  1131        value = ((windowed_value.value, (restriction, watermark_estimator_state)),
  1132                 restriction_size)
  1133        primary_restriction = SplitResultPrimary(
  1134            primary_value=WindowedValue(
  1135                value,
  1136                windowed_value.timestamp,
  1137                windowed_value.windows[:to_index])) if to_index > 0 else None
  1138        # Don't report any updated watermarks for the residual since it has
  1139        # not processed any part of the restriction.
  1140        residual_restriction = SplitResultResidual(
  1141            residual_value=WindowedValue(
  1142                value,
  1143                windowed_value.timestamp,
  1144                windowed_value.windows[from_index:stop_window_index]),
  1145            current_watermark=None,
  1146            deferred_timestamp=None) if from_index < stop_window_index else None
  1147        return (primary_restriction, residual_restriction)
  1148  
  1149      primary_restrictions = []
  1150      residual_restrictions = []
  1151  
  1152      window_observing = window_index is not None
  1153      # If we are processing each window separately and we aren't on the last
  1154      # window then compute whether the split lies within the current window
  1155      # or a future window.
  1156      if window_observing and window_index != stop_window_index - 1:
  1157        progress = restriction_tracker.current_progress()
  1158        if not progress:
  1159          # Assume no work has been completed for the current window if progress
  1160          # is unavailable.
  1161          from apache_beam.io.iobase import RestrictionProgress
  1162          progress = RestrictionProgress(completed=0, remaining=1)
  1163  
  1164        scaled_progress = PerWindowInvoker._scale_progress(
  1165            progress, window_index, stop_window_index)
  1166        # Compute the fraction of the remainder relative to the scaled progress.
  1167        # If the value is greater than or equal to progress.remaining_work then we
  1168        # should split at the closest window boundary.
  1169        fraction_of_remainder = scaled_progress.remaining_work * fraction
  1170        if fraction_of_remainder >= progress.remaining_work:
  1171          # The split point is outside of the current window, so split at the
  1172          # closest window boundary. If rounding would place the boundary at
  1173          # the very end, favor a split that still returns the last window as
  1174          # part of the residual.
  1175          new_stop_window_index = min(
  1176              stop_window_index - 1,
  1177              window_index + max(
  1178                  1,
  1179                  int(
  1180                      round((
  1181                          progress.completed_work +
  1182                          scaled_progress.remaining_work * fraction) /
  1183                            progress.total_work))))
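                # A hypothetical worked example of the rounding above: with
                # window_index=1, stop_window_index=5, per-window total_work=100,
                # completed_work=20 in the current window and scaled
                # remaining_work=380, a fraction of 0.5 gives
                #   round((20 + 380 * 0.5) / 100) = round(2.1) = 2
                # so new_stop_window_index = min(4, 1 + max(1, 2)) = 3.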
  1184          primary, residual = compute_whole_window_split(
  1185              new_stop_window_index, new_stop_window_index)
  1186          assert primary is not None
  1187          assert residual is not None
  1188          return ([primary], [residual], new_stop_window_index)
  1189        else:
  1190          # The split point is within the current window, so rescale the
  1191          # fraction to apply to the current window's remaining work alone.
  1192          new_stop_window_index = window_index + 1
  1193          fraction = fraction_of_remainder / progress.remaining_work
  1194          # Attempt to split below; if we can't, we'll compute a split using
  1195          # only window boundaries.
  1196      else:
  1197        # We aren't splitting within multiple windows so we don't change our
  1198        # stop index.
  1199        new_stop_window_index = stop_window_index
  1200  
  1201      # Temporary workaround for [BEAM-7473]: get current_watermark before the
  1202      # split, in case the watermark is advanced before the split results are
  1203      # retrieved. In the worst case, current_watermark is stale, which is OK.
  1204      current_watermark = watermark_estimator.current_watermark()
  1205      current_estimator_state = watermark_estimator.get_estimator_state()
  1206      split = restriction_tracker.try_split(fraction)
  1207      if split:
  1208        primary, residual = split
  1209        element = windowed_value.value
  1210        primary_size = restriction_provider.restriction_size(
  1211            windowed_value.value, primary)
  1212        if primary_size < 0:
  1213          raise ValueError('Expected size >= 0 but received %s.' % primary_size)
  1214        residual_size = restriction_provider.restriction_size(
  1215            windowed_value.value, residual)
  1216        if residual_size < 0:
  1217          raise ValueError('Expected size >= 0 but received %s.' % residual_size)
  1218        # The primary keeps the watermark estimator state from the original
  1219        # process call, while the residual gets the updated estimator state
  1220        # captured just before the split.
  1221        primary_split_value = ((element, (primary, watermark_estimator_state)),
  1222                               primary_size)
  1223        residual_split_value = ((element, (residual, current_estimator_state)),
  1224                                residual_size)
  1225        windows = (
  1226            windowed_value.windows[window_index],
  1227        ) if window_observing else windowed_value.windows
  1228        primary_restrictions.append(
  1229            SplitResultPrimary(
  1230                primary_value=WindowedValue(
  1231                    primary_split_value, windowed_value.timestamp, windows)))
  1232        residual_restrictions.append(
  1233            SplitResultResidual(
  1234                residual_value=WindowedValue(
  1235                    residual_split_value, windowed_value.timestamp, windows),
  1236                current_watermark=current_watermark,
  1237                deferred_timestamp=None))
  1238  
  1239        if window_observing:
  1240          assert new_stop_window_index == window_index + 1
  1241          primary, residual = compute_whole_window_split(
  1242              window_index, window_index + 1)
  1243          if primary:
  1244            primary_restrictions.append(primary)
  1245          if residual:
  1246            residual_restrictions.append(residual)
  1247        return (
  1248            primary_restrictions, residual_restrictions, new_stop_window_index)
  1249      elif new_stop_window_index and new_stop_window_index != stop_window_index:
  1250        # If we failed to split but have a new stop index then return a split
  1251        # at the window boundary.
  1252        primary, residual = compute_whole_window_split(
  1253            new_stop_window_index, new_stop_window_index)
  1254        assert primary is not None
  1255        assert residual is not None
  1256        return ([primary], [residual], new_stop_window_index)
  1257      else:
  1258        return None
  1259  
  1260    def try_split(self, fraction):
  1261      # type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual]]]
  1262      if not self.is_splittable:
  1263        return None
  1264  
  1265      with self.splitting_lock:
  1266        if not self.threadsafe_restriction_tracker:
  1267          return None
  1268  
  1269        # Under the lock, take local references to the member variables that
  1270        # are reassigned during processing before attempting the split, so
  1271        # that we have a consistent view of all of them.
  1272        result = PerWindowInvoker._try_split(
  1273            fraction,
  1274            self.current_window_index,
  1275            self.stop_window_index,
  1276            self.current_windowed_value,
  1277            self.restriction,
  1278            self.watermark_estimator_state,
  1279            self.signature.get_restriction_provider(),
  1280            self.threadsafe_restriction_tracker,
  1281            self.threadsafe_watermark_estimator)
  1282        if not result:
  1283          return None
  1284  
  1285        primaries, residuals, self.stop_window_index = result
  1286        return (primaries, residuals)
  1287  
  1288    @staticmethod
  1289    def _scale_progress(progress, window_index, stop_window_index):
  1290      # Scale the single-window progress so that it applies across all windows:
  1291      # earlier windows count as completed work, later windows as remaining.
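            # A hypothetical worked example: with total_work=100 per window,
            # completed_work=30, remaining_work=70, window_index=1 and
            # stop_window_index=4, the scaled progress is
            #   completed = 1 * 100 + 30 = 130
            #   remaining = (4 - (1 + 1)) * 100 + 70 = 270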
  1292      completed = window_index * progress.total_work + progress.completed_work
  1293      remaining = (
  1294          stop_window_index -
  1295          (window_index + 1)) * progress.total_work + progress.remaining_work
  1296      from apache_beam.io.iobase import RestrictionProgress
  1297      return RestrictionProgress(completed=completed, remaining=remaining)
  1298  
  1299    def current_element_progress(self):
  1300      # type: () -> Optional[RestrictionProgress]
  1301      if not self.is_splittable:
  1302        return None
  1303  
  1304      with self.splitting_lock:
  1305        current_window_index = self.current_window_index
  1306        stop_window_index = self.stop_window_index
  1307        threadsafe_restriction_tracker = self.threadsafe_restriction_tracker
  1308  
  1309      if not threadsafe_restriction_tracker:
  1310        return None
  1311  
  1312      progress = threadsafe_restriction_tracker.current_progress()
  1313      if not current_window_index or not progress:
  1314        return progress
  1315  
  1316      # stop_window_index should always be set if current_window_index is set;
  1317      # it is an error otherwise.
  1318      assert stop_window_index
  1319      return PerWindowInvoker._scale_progress(
  1320          progress, current_window_index, stop_window_index)
  1321  
  1322  
  1323  class DoFnRunner:
  1324    """For internal use only; no backwards-compatibility guarantees.
  1325  
  1326    A helper class for executing ParDo operations.
  1327    """
  1328  
  1329    def __init__(self,
  1330                 fn,  # type: core.DoFn
  1331                 args,
  1332                 kwargs,
  1333                 side_inputs,  # type: Iterable[sideinputs.SideInputMap]
  1334                 windowing,
  1335                 tagged_receivers,  # type: Mapping[Optional[str], Receiver]
  1336                 step_name=None,  # type: Optional[str]
  1337                 logging_context=None,
  1338                 state=None,
  1339                 scoped_metrics_container=None,
  1340                 operation_name=None,
  1341                 user_state_context=None  # type: Optional[userstate.UserStateContext]
  1342                ):
  1343      """Initializes a DoFnRunner.
  1344  
  1345      Args:
  1346        fn: user DoFn to invoke
  1347        args: positional side input arguments (static and placeholder), if any
  1348        kwargs: keyword side input arguments (static and placeholder), if any
  1349        side_inputs: list of sideinput.SideInputMaps for deferred side inputs
  1350        windowing: windowing properties of the output PCollection(s)
  1351        tagged_receivers: a dict of tag name to Receiver objects
  1352        step_name: the name of this step
  1353        logging_context: DEPRECATED [BEAM-4728]
  1354        state: handle for accessing DoFn state
  1355        scoped_metrics_container: DEPRECATED
  1356        operation_name: The system name assigned by the runner for this operation.
  1357        user_state_context: The UserStateContext instance for the current
  1358                            Stateful DoFn.
  1359      """
  1360      # Need to support multiple iterations.
  1361      side_inputs = list(side_inputs)
  1362  
  1363      self.step_name = step_name
  1364      self.context = DoFnContext(step_name, state=state)
  1365      self.bundle_finalizer_param = DoFn.BundleFinalizerParam()
  1366  
  1367      do_fn_signature = DoFnSignature(fn)
  1368  
  1369      # Optimize for the common case.
  1370      main_receivers = tagged_receivers[None]
  1371  
  1372      # TODO(https://github.com/apache/beam/issues/18886): Remove if block after
  1373      # output counter released.
  1374      if 'outputs_per_element_counter' in RuntimeValueProvider.experiments:
  1375        # TODO(BEAM-3955): Disambiguate step_name from operation_name.
  1376        output_counter_name = (
  1377            CounterName('per-element-output-count', step_name=operation_name))
  1378        per_element_output_counter = state._counter_factory.get_counter(
  1379            output_counter_name, Counter.DATAFLOW_DISTRIBUTION).accumulator
  1380      else:
  1381        per_element_output_counter = None
  1382  
  1383      output_handler = _OutputHandler(
  1384          windowing.windowfn,
  1385          main_receivers,
  1386          tagged_receivers,
  1387          per_element_output_counter,
  1388          getattr(fn, 'output_batch_converter', None),
  1389          getattr(
  1390              do_fn_signature.process_method.method_value,
  1391              '_beam_yields_batches',
  1392              False),
  1393          getattr(
  1394              do_fn_signature.process_batch_method.method_value,
  1395              '_beam_yields_elements',
  1396              False),
  1397      )
  1398  
  1399      if do_fn_signature.is_stateful_dofn() and not user_state_context:
  1400        raise Exception(
  1401            'Requested execution of a stateful DoFn, but no user state context '
  1402            'is available. This likely means that the current runner does not '
  1403            'support the execution of stateful DoFns.')
  1404  
  1405      self.do_fn_invoker = DoFnInvoker.create_invoker(
  1406          do_fn_signature,
  1407          output_handler,
  1408          self.context,
  1409          side_inputs,
  1410          args,
  1411          kwargs,
  1412          user_state_context=user_state_context,
  1413          bundle_finalizer_param=self.bundle_finalizer_param)
  1414  
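          # Note: the runner typically drives these methods in the order
          # setup() once per DoFn instance, then start(), one or more process()
          # or process_batch() calls, and finish() per bundle, with finalize()
          # invoked after the bundle has been committed and teardown() before
          # the instance is discarded.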
  1415    def process(self, windowed_value):
  1416      # type: (WindowedValue) -> Iterable[SplitResultResidual]
  1417      try:
  1418        return self.do_fn_invoker.invoke_process(windowed_value)
  1419      except BaseException as exn:
  1420        self._reraise_augmented(exn)
  1421        return []
  1422  
  1423    def process_batch(self, windowed_batch):
  1424      # type: (WindowedBatch) -> None
  1425      try:
  1426        self.do_fn_invoker.invoke_process_batch(windowed_batch)
  1427      except BaseException as exn:
  1428        self._reraise_augmented(exn)
  1429  
  1430    def process_with_sized_restriction(self, windowed_value):
  1431      # type: (WindowedValue) -> Iterable[SplitResultResidual]
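            # The incoming windowed value has the sized-restriction form
            #   ((element, (restriction, watermark_estimator_state)), size)
            # matching the primary/residual values built by PerWindowInvoker.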
  1432      (element, (restriction, estimator_state)), _ = windowed_value.value
  1433      return self.do_fn_invoker.invoke_process(
  1434          windowed_value.with_value(element),
  1435          restriction=restriction,
  1436          watermark_estimator_state=estimator_state)
  1437  
  1438    def try_split(self, fraction):
  1439      # type: (...) -> Optional[Tuple[Iterable[SplitResultPrimary], Iterable[SplitResultResidual]]]
  1440      assert isinstance(self.do_fn_invoker, PerWindowInvoker)
  1441      return self.do_fn_invoker.try_split(fraction)
  1442  
  1443    def current_element_progress(self):
  1444      # type: () -> Optional[RestrictionProgress]
  1445      assert isinstance(self.do_fn_invoker, PerWindowInvoker)
  1446      return self.do_fn_invoker.current_element_progress()
  1447  
  1448    def process_user_timer(
  1449        self, timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag):
  1450      try:
  1451        self.do_fn_invoker.invoke_user_timer(
  1452            timer_spec, key, window, timestamp, pane_info, dynamic_timer_tag)
  1453      except BaseException as exn:
  1454        self._reraise_augmented(exn)
  1455  
  1456    def _invoke_bundle_method(self, bundle_method):
  1457      try:
  1458        self.context.set_element(None)
  1459        bundle_method()
  1460      except BaseException as exn:
  1461        self._reraise_augmented(exn)
  1462  
  1463    def _invoke_lifecycle_method(self, lifecycle_method):
  1464      try:
  1465        self.context.set_element(None)
  1466        lifecycle_method()
  1467      except BaseException as exn:
  1468        self._reraise_augmented(exn)
  1469  
  1470    def setup(self):
  1471      # type: () -> None
  1472      self._invoke_lifecycle_method(self.do_fn_invoker.invoke_setup)
  1473  
  1474    def start(self):
  1475      # type: () -> None
  1476      self._invoke_bundle_method(self.do_fn_invoker.invoke_start_bundle)
  1477  
  1478    def finish(self):
  1479      # type: () -> None
  1480      self._invoke_bundle_method(self.do_fn_invoker.invoke_finish_bundle)
  1481  
  1482    def teardown(self):
  1483      # type: () -> None
  1484      self._invoke_lifecycle_method(self.do_fn_invoker.invoke_teardown)
  1485  
  1486    def finalize(self):
  1487      # type: () -> None
  1488      self.bundle_finalizer_param.finalize_bundle()
  1489  
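          # For example, with a hypothetical step name 'ParseRecords', a
          # ValueError('bad value') raised during processing is re-raised as
          # ValueError("bad value [while running 'ParseRecords']"); if the
          # exception type cannot be rebuilt with the augmented message, a
          # RuntimeError carrying the original type and message is used instead.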
  1490    def _reraise_augmented(self, exn):
  1491      if getattr(exn, '_tagged_with_step', False) or not self.step_name:
  1492        raise exn
  1493      step_annotation = " [while running '%s']" % self.step_name
  1494      # To emulate exception chaining (not available in Python 2).
  1495      try:
  1496        # Attempt to construct the same kind of exception
  1497        # with an augmented message.
  1498        new_exn = type(exn)(exn.args[0] + step_annotation, *exn.args[1:])
  1499        new_exn._tagged_with_step = True  # Could raise attribute error.
  1500      except:  # pylint: disable=bare-except
  1501        # If anything goes wrong, construct a RuntimeError whose message
  1502        # records the original exception's type and message.
  1503        new_exn = RuntimeError(
  1504            traceback.format_exception_only(type(exn), exn)[-1].strip() +
  1505            step_annotation)
  1506        new_exn._tagged_with_step = True
  1507      _, _, tb = sys.exc_info()
  1508      raise new_exn.with_traceback(tb)
  1509  
  1510  
  1511  class OutputHandler(object):
  1512    def handle_process_outputs(
  1513        self, windowed_input_element, results, watermark_estimator=None):
  1514      # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None
  1515      raise NotImplementedError
  1516  
  1517    def handle_process_batch_outputs(
  1518        self, windowed_input_element, results, watermark_estimator=None):
  1519      # type: (WindowedBatch, Iterable[Any], Optional[WatermarkEstimator]) -> None
  1520      raise NotImplementedError
  1521  
  1522  
  1523  class _OutputHandler(OutputHandler):
  1524    """Processes output produced by DoFn method invocations."""
  1525  
  1526    def __init__(self,
  1527                 window_fn,
  1528                 main_receivers,  # type: Receiver
  1529                 tagged_receivers,  # type: Mapping[Optional[str], Receiver]
  1530                 per_element_output_counter,
  1531                 output_batch_converter, # type: Optional[BatchConverter]
  1532                 process_yields_batches, # type: bool
  1533                 process_batch_yields_elements, # type: bool
  1534                 ):
  1535      """Initializes ``_OutputHandler``.
  1536  
  1537      Args:
  1538        window_fn: a windowing function (WindowFn).
  1539        main_receivers: the main Receiver object.
  1540        tagged_receivers: a dict of tag name to Receiver objects.
  1541        per_element_output_counter: per-element output counter of a work_item;
  1542                                    may be None if the experimental flag is off.
  1543      """
  1544      self.window_fn = window_fn
  1545      self.main_receivers = main_receivers
  1546      self.tagged_receivers = tagged_receivers
  1547      if (per_element_output_counter is not None and
  1548          per_element_output_counter.is_cythonized):
  1549        self.per_element_output_counter = per_element_output_counter
  1550      else:
  1551        self.per_element_output_counter = None
  1552      self.output_batch_converter = output_batch_converter
  1553      self._process_yields_batches = process_yields_batches
  1554      self._process_batch_yields_elements = process_batch_yields_elements
  1555  
  1556    def handle_process_outputs(
  1557        self, windowed_input_element, results, watermark_estimator=None):
  1558      # type: (WindowedValue, Iterable[Any], Optional[WatermarkEstimator]) -> None
  1559  
  1560      """Dispatch the result of process computation to the appropriate receivers.
  1561  
  1562      A value wrapped in a TaggedOutput object will be unwrapped and
  1563      then dispatched to the appropriate indexed output.
  1564      """
  1565      if results is None:
  1566        results = []
  1567  
  1568      # TODO(https://github.com/apache/beam/issues/20404): Verify that the
  1569      #  results object is a valid iterable type if
  1570      #  performance_runtime_type_check is active, without harming performance
  1571      output_element_count = 0
  1572      for result in results:
  1573        tag, result = self._handle_tagged_output(result)
  1574  
  1575        if not self._process_yields_batches:
  1576          # process yields elements
  1577          windowed_value = self._maybe_propagate_windowing_info(
  1578              windowed_input_element, result)
  1579  
  1580          output_element_count += 1
  1581  
  1582          self._write_value_to_tag(tag, windowed_value, watermark_estimator)
  1583        else:  # process yields batches
  1584          self._verify_batch_output(result)
  1585  
  1586          if isinstance(result, WindowedBatch):
  1587            assert isinstance(result, HomogeneousWindowedBatch)
  1588            windowed_batch = result
  1589  
  1590            if (windowed_input_element is not None and
  1591                len(windowed_input_element.windows) != 1):
  1592              windowed_batch.windows *= len(windowed_input_element.windows)
  1593          else:
  1594            windowed_batch = (
  1595                HomogeneousWindowedBatch.from_batch_and_windowed_value(
  1596                    batch=result, windowed_value=windowed_input_element))
  1597  
  1598          output_element_count += self.output_batch_converter.get_length(
  1599              windowed_batch.values)
  1600  
  1601          self._write_batch_to_tag(tag, windowed_batch, watermark_estimator)
  1602  
  1603      # TODO(https://github.com/apache/beam/issues/18886): Remove this if block
  1604      # after the output counter is released. per_element_output_counter is only
  1605      # enabled when the counter is cythonized.
  1606      if self.per_element_output_counter is not None:
  1607        self.per_element_output_counter.add_input(output_element_count)
  1608  
  1609    def handle_process_batch_outputs(
  1610        self, windowed_input_batch, results, watermark_estimator=None):
  1611      # type: (WindowedBatch, Iterable[Any], Optional[WatermarkEstimator]) -> None
  1612  
  1613      """Dispatch the result of process_batch computation to the appropriate
  1614      receivers.
  1615  
  1616      A value wrapped in a TaggedOutput object will be unwrapped and
  1617      then dispatched to the appropriate indexed output.
  1618      """
  1619      if results is None:
  1620        results = []
  1621  
  1622      output_element_count = 0
  1623      for result in results:
  1624        tag, result = self._handle_tagged_output(result)
  1625  
  1626        if not self._process_batch_yields_elements:
  1627          # process_batch yields batches
  1628          assert self.output_batch_converter is not None
  1629  
  1630          self._verify_batch_output(result)
  1631  
  1632          if isinstance(result, WindowedBatch):
  1633            assert isinstance(result, HomogeneousWindowedBatch)
  1634            windowed_batch = result
  1635  
  1636            if (windowed_input_batch is not None and
  1637                len(windowed_input_batch.windows) != 1):
  1638              windowed_batch.windows *= len(windowed_input_batch.windows)
  1639          else:
  1640            windowed_batch = windowed_input_batch.with_values(result)
  1641  
  1642          output_element_count += self.output_batch_converter.get_length(
  1643              windowed_batch.values)
  1644  
  1645          self._write_batch_to_tag(tag, windowed_batch, watermark_estimator)
  1646        else:  # process_batch yields elements
  1647          assert isinstance(windowed_input_batch, HomogeneousWindowedBatch)
  1648  
  1649          windowed_value = self._maybe_propagate_windowing_info(
  1650              windowed_input_batch.as_empty_windowed_value(), result)
  1651  
  1652          output_element_count += 1
  1653  
  1654          self._write_value_to_tag(tag, windowed_value, watermark_estimator)
  1655  
  1656      # TODO(https://github.com/apache/beam/issues/18886): Remove this if block
  1657      # after the output counter is released. per_element_output_counter is only
  1658      # enabled when the counter is cythonized.
  1659      if self.per_element_output_counter is not None:
  1660        self.per_element_output_counter.add_input(output_element_count)
  1661  
  1662    def _maybe_propagate_windowing_info(self, windowed_input_element, result):
  1663      # type: (WindowedValue, Any) -> WindowedValue
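            # For example, a result of TimestampedValue(v, ts) is re-windowed by
            # window_fn at ts, an explicit WindowedValue passes through (with its
            # windows replicated for a multi-window input), and any other value
            # inherits the input element's timestamp and windows via with_value.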
  1664      if isinstance(result, WindowedValue):
  1665        windowed_value = result
  1666        if (windowed_input_element is not None and
  1667            len(windowed_input_element.windows) != 1):
  1668          windowed_value.windows *= len(windowed_input_element.windows)
  1669        return windowed_value
  1670  
  1671      elif isinstance(result, TimestampedValue):
  1672        assign_context = WindowFn.AssignContext(result.timestamp, result.value)
  1673        windowed_value = WindowedValue(
  1674            result.value, result.timestamp, self.window_fn.assign(assign_context))
  1675        if len(windowed_input_element.windows) != 1:
  1676          windowed_value.windows *= len(windowed_input_element.windows)
  1677        return windowed_value
  1678  
  1679      else:
  1680        return windowed_input_element.with_value(result)
  1681  
  1682    def _handle_tagged_output(self, result):
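            # For example, TaggedOutput('late', v) unwraps to ('late', v), while a
            # plain value v yields (None, v) and is later routed to main_receivers.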
  1683      if isinstance(result, TaggedOutput):
  1684        tag = result.tag
  1685        if not isinstance(tag, str):
  1686          raise TypeError('In %s, tag %s is not a string' % (self, tag))
  1687        return tag, result.value
  1688      return None, result
  1689  
  1690    def _write_value_to_tag(self, tag, windowed_value, watermark_estimator):
  1691      if watermark_estimator is not None:
  1692        watermark_estimator.observe_timestamp(windowed_value.timestamp)
  1693  
  1694      if tag is None:
  1695        self.main_receivers.receive(windowed_value)
  1696      else:
  1697        self.tagged_receivers[tag].receive(windowed_value)
  1698  
  1699    def _write_batch_to_tag(self, tag, windowed_batch, watermark_estimator):
  1700      if watermark_estimator is not None:
  1701        for timestamp in windowed_batch.timestamps:
  1702          watermark_estimator.observe_timestamp(timestamp)
  1703  
  1704      if tag is None:
  1705        self.main_receivers.receive_batch(windowed_batch)
  1706      else:
  1707        self.tagged_receivers[tag].receive_batch(windowed_batch)
  1708  
  1709    def _verify_batch_output(self, result):
  1710      if isinstance(result, (WindowedValue, TimestampedValue)):
  1711        raise TypeError(
  1712            f"Received {type(result).__name__} from DoFn that was "
  1713            "expected to produce a batch.")
  1714  
  1715    def start_bundle_outputs(self, results):
  1716      """Validate that start_bundle does not output any elements."""
  1717      if results is None:
  1718        return
  1719      raise RuntimeError(
  1720          'Start Bundle should not output any elements but got %s' % results)
  1721  
  1722    def finish_bundle_outputs(self, results):
  1723      """Dispatch the result of finish_bundle to the appropriate receivers.
  1724  
  1725      A value wrapped in a TaggedOutput object will be unwrapped and
  1726      then dispatched to the appropriate indexed output.
  1727      """
  1728      if results is None:
  1729        return
  1730  
  1731      for result in results:
  1732        tag = None
  1733        if isinstance(result, TaggedOutput):
  1734          tag = result.tag
  1735          if not isinstance(tag, str):
  1736            raise TypeError('In %s, tag %s is not a string' % (self, tag))
  1737          result = result.value
  1738  
  1739        if isinstance(result, WindowedValue):
  1740          windowed_value = result
  1741        else:
  1742          raise RuntimeError('Finish Bundle should only output WindowedValue '
  1743                             'type but got %s' % type(result))
  1744  
  1745        if tag is None:
  1746          self.main_receivers.receive(windowed_value)
  1747        else:
  1748          self.tagged_receivers[tag].receive(windowed_value)
  1749  
  1750  
  1751  class _NoContext(WindowFn.AssignContext):
  1752    """An uninspectable WindowFn.AssignContext."""
  1753    NO_VALUE = object()
  1754  
  1755    def __init__(self, value, timestamp=NO_VALUE):
  1756      self.value = value
  1757      self._timestamp = timestamp
  1758  
  1759    @property
  1760    def timestamp(self):
  1761      if self._timestamp is self.NO_VALUE:
  1762        raise ValueError('No timestamp in this context.')
  1763      else:
  1764        return self._timestamp
  1765  
  1766    @property
  1767    def existing_windows(self):
  1768      raise ValueError('No existing_windows in this context.')
  1769  
  1770  
  1771  class DoFnState(object):
  1772    """For internal use only; no backwards-compatibility guarantees.
  1773  
  1774    Keeps track of state that DoFns want, currently, user counters.
  1775    """
  1776    def __init__(self, counter_factory):
  1777      self.step_name = ''
  1778      self._counter_factory = counter_factory
  1779  
  1780    def counter_for(self, aggregator):
  1781      """Looks up the counter for this aggregator, creating one if necessary."""
  1782      return self._counter_factory.get_aggregator_counter(
  1783          self.step_name, aggregator)
  1784  
  1785  
  1786  # TODO(robertwb): Replace core.DoFnContext with this.
  1787  class DoFnContext(object):
  1788    """For internal use only; no backwards-compatibility guarantees."""
  1789    def __init__(self, label, element=None, state=None):
  1790      self.label = label
  1791      self.state = state
  1792      if element is not None:
  1793        self.set_element(element)
  1794  
  1795    def set_element(self, windowed_value):
  1796      # type: (Optional[WindowedValue]) -> None
  1797      self.windowed_value = windowed_value
  1798  
  1799    @property
  1800    def element(self):
  1801      if self.windowed_value is None:
  1802        raise AttributeError('element not accessible in this context')
  1803      else:
  1804        return self.windowed_value.value
  1805  
  1806    @property
  1807    def timestamp(self):
  1808      if self.windowed_value is None:
  1809        raise AttributeError('timestamp not accessible in this context')
  1810      else:
  1811        return self.windowed_value.timestamp
  1812  
  1813    @property
  1814    def windows(self):
  1815      if self.windowed_value is None:
  1816        raise AttributeError('windows not accessible in this context')
  1817      else:
  1818        return self.windowed_value.windows
  1819  
  1820  
  1821  def group_by_key_input_visitor(deterministic_key_coders=True):
  1822    # Importing here to avoid a circular dependency
  1823    # pylint: disable=wrong-import-order, wrong-import-position
  1824    from apache_beam.pipeline import PipelineVisitor
  1825    from apache_beam.transforms.core import GroupByKey
  1826  
  1827    class GroupByKeyInputVisitor(PipelineVisitor):
  1828      """A visitor that replaces the `Any` element type of a `GroupByKey`'s
  1829      input `PCollection` with a `KV` type.
  1830  
  1831      TODO(BEAM-115): Once Python SDK is compatible with the new Runner API,
  1832      we could directly replace the coder instead of mutating the element type.
  1833      """
  1834      def __init__(self, deterministic_key_coders=True):
  1835        self.deterministic_key_coders = deterministic_key_coders
  1836  
  1837      def enter_composite_transform(self, transform_node):
  1838        self.visit_transform(transform_node)
  1839  
  1840      def visit_transform(self, transform_node):
  1841        if isinstance(transform_node.transform, GroupByKey):
  1842          pcoll = transform_node.inputs[0]
  1843          pcoll.element_type = typehints.coerce_to_kv_type(
  1844              pcoll.element_type, transform_node.full_label)
  1845          pcoll.requires_deterministic_key_coder = (
  1846              self.deterministic_key_coders and transform_node.full_label)
  1847          key_type, value_type = pcoll.element_type.tuple_types
  1848          if transform_node.outputs:
  1849            key = next(iter(transform_node.outputs.keys()))
  1850            transform_node.outputs[key].element_type = typehints.KV[
  1851                key_type, typehints.Iterable[value_type]]
  1852            transform_node.outputs[key].requires_deterministic_key_coder = (
  1853                self.deterministic_key_coders and transform_node.full_label)
  1854  
  1855    return GroupByKeyInputVisitor(deterministic_key_coders)
  1856  
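        # A minimal usage sketch (assuming `pipeline` is an already constructed
        # apache_beam.Pipeline): applying the visitor coerces the element type of
        # every GroupByKey input PCollection to a KV type before the pipeline is
        # translated to its runner API form.
        #
        #   pipeline.visit(group_by_key_input_visitor())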
  1857  
  1858  def validate_pipeline_graph(pipeline_proto):
  1859    """Ensures this is a correctly constructed Beam pipeline.
  1860    """
  1861    def get_coder(pcoll_id):
  1862      return pipeline_proto.components.coders[
  1863          pipeline_proto.components.pcollections[pcoll_id].coder_id]
  1864  
  1865    def validate_transform(transform_id):
  1866      transform_proto = pipeline_proto.components.transforms[transform_id]
  1867  
  1868      # Currently the only validation we perform is that GBK operations have
  1869      # their coders set properly.
  1870      if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
  1871        if len(transform_proto.inputs) != 1:
  1872          raise ValueError("Unexpected number of inputs: %s" % transform_proto)
  1873        if len(transform_proto.outputs) != 1:
  1874          raise ValueError("Unexpected number of outputs: %s" % transform_proto)
  1875        input_coder = get_coder(next(iter(transform_proto.inputs.values())))
  1876        output_coder = get_coder(next(iter(transform_proto.outputs.values())))
  1877        if input_coder.spec.urn != common_urns.coders.KV.urn:
  1878          raise ValueError(
  1879              "Bad coder for input of %s: %s" % (transform_id, input_coder))
  1880        if output_coder.spec.urn != common_urns.coders.KV.urn:
  1881          raise ValueError(
  1882              "Bad coder for output of %s: %s" % (transform_id, output_coder))
  1883        output_values_coder = pipeline_proto.components.coders[
  1884            output_coder.component_coder_ids[1]]
  1885        if (input_coder.component_coder_ids[0] !=
  1886            output_coder.component_coder_ids[0] or
  1887            output_values_coder.spec.urn != common_urns.coders.ITERABLE.urn or
  1888            output_values_coder.component_coder_ids[0] !=
  1889            input_coder.component_coder_ids[1]):
  1890          raise ValueError(
  1891              "Incompatible input coder %s and output coder %s for transform %s" %
  1892              (input_coder, output_coder, transform_id))
  1893  
  1894      for t in transform_proto.subtransforms:
  1895        validate_transform(t)
  1896  
  1897    for t in pipeline_proto.root_transform_ids:
  1898      validate_transform(t)
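
        # A minimal usage sketch (assuming `pipeline` is an apache_beam.Pipeline):
        #
        #   validate_pipeline_graph(pipeline.to_runner_api())
        #
        # This raises ValueError if any GroupByKey in the pipeline proto has an
        # unexpected number of inputs or outputs, or a missing/incompatible KV
        # coder assignment.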