github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/direct/transform_evaluator.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """An evaluator of a specific application of a transform."""
    19  
    20  # pytype: skip-file
    21  
    22  import atexit
    23  import collections
    24  import logging
    25  import random
    26  import time
    27  from collections import abc
    28  from typing import TYPE_CHECKING
    29  from typing import Any
    30  from typing import Dict
    31  from typing import List
    32  from typing import Tuple
    33  from typing import Type
    34  
    35  from apache_beam import coders
    36  from apache_beam import io
    37  from apache_beam import pvalue
    38  from apache_beam.internal import pickler
    39  from apache_beam.runners import common
    40  from apache_beam.runners.common import DoFnRunner
    41  from apache_beam.runners.common import DoFnState
    42  from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite  # pylint: disable=protected-access
    43  from apache_beam.runners.direct.direct_runner import _DirectReadFromPubSub
    44  from apache_beam.runners.direct.direct_runner import _GroupByKeyOnly
    45  from apache_beam.runners.direct.direct_runner import _StreamingGroupAlsoByWindow
    46  from apache_beam.runners.direct.direct_runner import _StreamingGroupByKeyOnly
    47  from apache_beam.runners.direct.direct_userstate import DirectUserStateContext
    48  from apache_beam.runners.direct.sdf_direct_runner import ProcessElements
    49  from apache_beam.runners.direct.sdf_direct_runner import ProcessFn
    50  from apache_beam.runners.direct.sdf_direct_runner import SDFProcessElementInvoker
    51  from apache_beam.runners.direct.test_stream_impl import _TestStream
    52  from apache_beam.runners.direct.test_stream_impl import _WatermarkController
    53  from apache_beam.runners.direct.util import KeyedWorkItem
    54  from apache_beam.runners.direct.util import TransformResult
    55  from apache_beam.runners.direct.watermark_manager import WatermarkManager
    56  from apache_beam.testing.test_stream import ElementEvent
    57  from apache_beam.testing.test_stream import PairWithTiming
    58  from apache_beam.testing.test_stream import ProcessingTimeEvent
    59  from apache_beam.testing.test_stream import TimingInfo
    60  from apache_beam.testing.test_stream import WatermarkEvent
    61  from apache_beam.testing.test_stream import WindowedValueHolder
    62  from apache_beam.transforms import core
    63  from apache_beam.transforms.trigger import InMemoryUnmergedState
    64  from apache_beam.transforms.trigger import TimeDomain
    65  from apache_beam.transforms.trigger import _CombiningValueStateTag
    66  from apache_beam.transforms.trigger import _ListStateTag
    67  from apache_beam.transforms.trigger import _ReadModifyWriteStateTag
    68  from apache_beam.transforms.trigger import create_trigger_driver
    69  from apache_beam.transforms.userstate import get_dofn_specs
    70  from apache_beam.transforms.userstate import is_stateful_dofn
    71  from apache_beam.transforms.window import GlobalWindows
    72  from apache_beam.transforms.window import WindowedValue
    73  from apache_beam.typehints.typecheck import TypeCheckError
    74  from apache_beam.utils import counters
    75  from apache_beam.utils.timestamp import MIN_TIMESTAMP
    76  from apache_beam.utils.timestamp import Timestamp
    77  
    78  if TYPE_CHECKING:
    79    from apache_beam.io.gcp.pubsub import _PubSubSource
    80    from apache_beam.io.gcp.pubsub import PubsubMessage
    81    from apache_beam.pipeline import AppliedPTransform
    82    from apache_beam.runners.direct.evaluation_context import EvaluationContext
    83  
    84  _LOGGER = logging.getLogger(__name__)
    85  
    86  
    87  class TransformEvaluatorRegistry(object):
    88    """For internal use only; no backwards-compatibility guarantees.
    89  
    90    Creates instances of TransformEvaluator for the application of a transform.
    91    """
    92  
    93    _test_evaluators_overrides = {
    94    }  # type: Dict[Type[core.PTransform], Type[_TransformEvaluator]]
    95  
    96    def __init__(self, evaluation_context):
    97      # type: (EvaluationContext) -> None
    98      assert evaluation_context
    99      self._evaluation_context = evaluation_context
   100      self._evaluators = {
   101          io.Read: _BoundedReadEvaluator,
   102          _DirectReadFromPubSub: _PubSubReadEvaluator,
   103          core.Flatten: _FlattenEvaluator,
   104          core.Impulse: _ImpulseEvaluator,
   105          core.ParDo: _ParDoEvaluator,
   106          _GroupByKeyOnly: _GroupByKeyOnlyEvaluator,
   107          _StreamingGroupByKeyOnly: _StreamingGroupByKeyOnlyEvaluator,
   108          _StreamingGroupAlsoByWindow: _StreamingGroupAlsoByWindowEvaluator,
   109          _NativeWrite: _NativeWriteEvaluator,
   110          _TestStream: _TestStreamEvaluator,
   111          ProcessElements: _ProcessElementsEvaluator,
   112          _WatermarkController: _WatermarkControllerEvaluator,
   113          PairWithTiming: _PairWithTimingEvaluator,
   114      }  # type: Dict[Type[core.PTransform], Type[_TransformEvaluator]]
   115      self._evaluators.update(self._test_evaluators_overrides)
   116      self._root_bundle_providers = {
   117          core.PTransform: DefaultRootBundleProvider,
   118          _TestStream: _TestStreamRootBundleProvider,
   119      }
   120  
   121    def get_evaluator(
   122        self, applied_ptransform, input_committed_bundle, side_inputs):
   123      """Returns a TransformEvaluator suitable for processing given inputs."""
   124      assert applied_ptransform
   125      assert bool(applied_ptransform.side_inputs) == bool(side_inputs)
   126  
   127      # Walk up the class hierarchy to find an evaluable type. This is necessary
   128      # for supporting sub-classes of core transforms.
   129      for cls in applied_ptransform.transform.__class__.mro():
   130        evaluator = self._evaluators.get(cls)
   131        if evaluator:
   132          break
   133  
   134      if not evaluator:
   135        raise NotImplementedError(
   136            'Execution of [%s] not implemented in runner %s.' %
   137            (type(applied_ptransform.transform), self))
   138      return evaluator(
   139          self._evaluation_context,
   140          applied_ptransform,
   141          input_committed_bundle,
   142          side_inputs)
   143  
   144    def get_root_bundle_provider(self, applied_ptransform):
   145      provider_cls = None
   146      for cls in applied_ptransform.transform.__class__.mro():
   147        provider_cls = self._root_bundle_providers.get(cls)
   148        if provider_cls:
   149          break
   150      if not provider_cls:
   151        raise NotImplementedError(
   152            'Root provider for [%s] not implemented in runner %s' %
   153            (type(applied_ptransform.transform), self))
   154      return provider_cls(self._evaluation_context, applied_ptransform)
   155  
   156    def should_execute_serially(self, applied_ptransform):
   157      """Returns True if this applied_ptransform should run one bundle at a time.
   158  
   159      Some TransformEvaluators use a global state object to keep track of their
   160      global execution state. For example evaluator for _GroupByKeyOnly uses this
   161      state as an in memory dictionary to buffer keys.
   162  
   163      Serially executed evaluators will act as syncing point in the graph and
   164      execution will not move forward until they receive all of their inputs. Once
   165      they receive all of their input, they will release the combined output.
   166      Their output may consist of multiple bundles as they may divide their output
   167      into pieces before releasing.
   168  
   169      Args:
   170        applied_ptransform: Transform to be used for execution.
   171  
   172      Returns:
   173        True if executor should execute applied_ptransform serially.
   174      """
   175      if isinstance(applied_ptransform.transform,
   176                    (_GroupByKeyOnly,
   177                     _StreamingGroupByKeyOnly,
   178                     _StreamingGroupAlsoByWindow,
   179                     _NativeWrite)):
   180        return True
   181      elif (isinstance(applied_ptransform.transform, core.ParDo) and
   182            is_stateful_dofn(applied_ptransform.transform.dofn)):
   183        return True
   184      return False
   185  
   186  
   187  class RootBundleProvider(object):
   188    """Provides bundles for the initial execution of a root transform."""
   189    def __init__(self, evaluation_context, applied_ptransform):
   190      self._evaluation_context = evaluation_context
   191      self._applied_ptransform = applied_ptransform
   192  
   193    def get_root_bundles(self):
   194      raise NotImplementedError
   195  
   196  
   197  class DefaultRootBundleProvider(RootBundleProvider):
   198    """Provides an empty bundle by default for root transforms."""
   199    def get_root_bundles(self):
   200      input_node = pvalue.PBegin(self._applied_ptransform.transform.pipeline)
   201      empty_bundle = (
   202          self._evaluation_context.create_empty_committed_bundle(input_node))
   203      return [empty_bundle]
   204  
   205  
   206  class _TestStreamRootBundleProvider(RootBundleProvider):
   207    """Provides an initial bundle for the TestStream evaluator.
   208  
   209    This bundle is used as the initial state to the TestStream. Each unprocessed
   210    bundle emitted from the TestStream afterwards is its state: index into the
   211    stream, and the watermark.
   212    """
   213    def get_root_bundles(self):
   214      test_stream = self._applied_ptransform.transform
   215  
   216      # If there was an endpoint defined then get the events from the
   217      # TestStreamService.
   218      if test_stream.endpoint:
   219        _TestStreamEvaluator.event_stream = _TestStream.events_from_rpc(
   220            test_stream.endpoint,
   221            test_stream.output_tags,
   222            test_stream.coder,
   223            self._evaluation_context)
   224      else:
   225        _TestStreamEvaluator.event_stream = (
   226            _TestStream.events_from_script(test_stream._events))
   227  
   228      bundle = self._evaluation_context.create_bundle(
   229          pvalue.PBegin(self._applied_ptransform.transform.pipeline))
   230      bundle.add(GlobalWindows.windowed_value(b'', timestamp=MIN_TIMESTAMP))
   231      bundle.commit(None)
   232      return [bundle]
   233  
   234  
   235  class _TransformEvaluator(object):
   236    """An evaluator of a specific application of a transform."""
   237  
   238    def __init__(self,
   239                 evaluation_context, # type: EvaluationContext
   240                 applied_ptransform,  # type: AppliedPTransform
   241                 input_committed_bundle,
   242                 side_inputs
   243                ):
   244      self._evaluation_context = evaluation_context
   245      self._applied_ptransform = applied_ptransform
   246      self._input_committed_bundle = input_committed_bundle
   247      self._side_inputs = side_inputs
   248      self._expand_outputs()
   249      self._execution_context = evaluation_context.get_execution_context(
   250          applied_ptransform)
   251      self._step_context = self._execution_context.get_step_context()
   252  
   253    def _expand_outputs(self):
   254      outputs = set()
   255      for pval in self._applied_ptransform.outputs.values():
   256        if isinstance(pval, pvalue.DoOutputsTuple):
   257          pvals = (v for v in pval)
   258        else:
   259          pvals = (pval, )
   260        for v in pvals:
   261          outputs.add(v)
   262      self._outputs = frozenset(outputs)
   263  
   264    def _split_list_into_bundles(
   265        self,
   266        output_pcollection,
   267        elements,
   268        max_element_per_bundle,
   269        element_size_fn):
   270      """Splits elements, an iterable, into multiple output bundles.
   271  
   272      Args:
   273        output_pcollection: PCollection that the elements belong to.
   274        elements: elements to be chunked into bundles.
   275        max_element_per_bundle: (approximately) the maximum element per bundle.
   276          If it is None, only a single bundle will be produced.
   277        element_size_fn: Function to return the size of a given element.
   278  
   279      Returns:
   280        List of output uncommitted bundles with at least one bundle.
   281      """
   282      bundle = self._evaluation_context.create_bundle(output_pcollection)
   283      bundle_size = 0
   284      bundles = [bundle]
   285      for element in elements:
   286        if max_element_per_bundle and bundle_size >= max_element_per_bundle:
   287          bundle = self._evaluation_context.create_bundle(output_pcollection)
   288          bundle_size = 0
   289          bundles.append(bundle)
   290  
   291        bundle.output(element)
   292        bundle_size += element_size_fn(element)
   293      return bundles
   294  
   295    def start_bundle(self):
   296      """Starts a new bundle."""
   297      pass
   298  
   299    def process_timer_wrapper(self, timer_firing):
   300      """Process timer by clearing and then calling process_timer().
   301  
   302      This method is called with any timer firing and clears the delivered
   303      timer from the keyed state and then calls process_timer().  The default
   304      process_timer() implementation emits a KeyedWorkItem for the particular
   305      timer and passes it to process_element().  Evaluator subclasses which
   306      desire different timer delivery semantics can override process_timer().
   307      """
   308      state = self._step_context.get_keyed_state(timer_firing.encoded_key)
   309      state.clear_timer(
   310          timer_firing.window,
   311          timer_firing.name,
   312          timer_firing.time_domain,
   313          dynamic_timer_tag=timer_firing.dynamic_timer_tag)
   314      self.process_timer(timer_firing)
   315  
   316    def process_timer(self, timer_firing):
   317      """Default process_timer() impl. generating KeyedWorkItem element."""
   318      self.process_element(
   319          GlobalWindows.windowed_value(
   320              KeyedWorkItem(
   321                  timer_firing.encoded_key, timer_firings=[timer_firing])))
   322  
   323    def process_element(self, element):
   324      """Processes a new element as part of the current bundle."""
   325      raise NotImplementedError('%s do not process elements.' % type(self))
   326  
   327    def finish_bundle(self):
   328      # type: () -> TransformResult
   329  
   330      """Finishes the bundle and produces output."""
   331      pass
   332  
   333  
   334  class _BoundedReadEvaluator(_TransformEvaluator):
   335    """TransformEvaluator for bounded Read transform."""
   336  
   337    # After some benchmarks, 1000 was optimal among {100,1000,10000}
   338    MAX_ELEMENT_PER_BUNDLE = 1000
   339  
   340    def __init__(
   341        self,
   342        evaluation_context,
   343        applied_ptransform,
   344        input_committed_bundle,
   345        side_inputs):
   346      assert not side_inputs
   347      self._source = applied_ptransform.transform.source
   348      self._source.pipeline_options = evaluation_context.pipeline_options
   349      super().__init__(
   350          evaluation_context,
   351          applied_ptransform,
   352          input_committed_bundle,
   353          side_inputs)
   354  
   355    def finish_bundle(self):
   356      assert len(self._outputs) == 1
   357      output_pcollection = list(self._outputs)[0]
   358  
   359      def _read_values_to_bundles(reader):
   360        read_result = [GlobalWindows.windowed_value(e) for e in reader]
   361        return self._split_list_into_bundles(
   362            output_pcollection,
   363            read_result,
   364            _BoundedReadEvaluator.MAX_ELEMENT_PER_BUNDLE,
   365            lambda _: 1)
   366  
   367      if isinstance(self._source, io.iobase.BoundedSource):
   368        # Getting a RangeTracker for the default range of the source and reading
   369        # the full source using that.
   370        range_tracker = self._source.get_range_tracker(None, None)
   371        reader = self._source.read(range_tracker)
   372        bundles = _read_values_to_bundles(reader)
   373      else:
   374        with self._source.reader() as reader:
   375          bundles = _read_values_to_bundles(reader)
   376  
   377      return TransformResult(self, bundles, [], None, None)
   378  
   379  
   380  class _WatermarkControllerEvaluator(_TransformEvaluator):
   381    """TransformEvaluator for the _WatermarkController transform.
   382  
   383    This is used to enable multiple output watermarks for the TestStream.
   384    """
   385  
   386    # The state tag used to store the watermark.
   387    WATERMARK_TAG = _ReadModifyWriteStateTag(
   388        '_WatermarkControllerEvaluator_Watermark_Tag')
   389  
   390    def __init__(
   391        self,
   392        evaluation_context,
   393        applied_ptransform,
   394        input_committed_bundle,
   395        side_inputs):
   396      assert not side_inputs
   397      self.transform = applied_ptransform.transform
   398      super().__init__(
   399          evaluation_context,
   400          applied_ptransform,
   401          input_committed_bundle,
   402          side_inputs)
   403      self._init_state()
   404  
   405    def _init_state(self):
   406      """Gets and sets the initial state.
   407  
   408      This is used to keep track of the watermark hold between calls.
   409      """
   410      transform_states = self._evaluation_context._transform_keyed_states
   411      state = transform_states[self._applied_ptransform]
   412      if self.WATERMARK_TAG not in state:
   413        watermark_state = InMemoryUnmergedState()
   414        watermark_state.set_global_state(self.WATERMARK_TAG, MIN_TIMESTAMP)
   415        state[self.WATERMARK_TAG] = watermark_state
   416      self._state = state[self.WATERMARK_TAG]
   417  
   418    @property
   419    def _watermark(self):
   420      return self._state.get_global_state(self.WATERMARK_TAG)
   421  
   422    @_watermark.setter
   423    def _watermark(self, watermark):
   424      self._state.set_global_state(self.WATERMARK_TAG, watermark)
   425  
   426    def start_bundle(self):
   427      self.bundles = []
   428  
   429    def process_element(self, element):
   430      # In order to keep the order of the elements between the script and what
   431      # flows through the pipeline the same, emit the elements here.
   432      event = element.value
   433      if isinstance(event, WatermarkEvent):
   434        self._watermark = event.new_watermark
   435      elif isinstance(event, ElementEvent):
   436        main_output = list(self._outputs)[0]
   437        bundle = self._evaluation_context.create_bundle(main_output)
   438        for tv in event.timestamped_values:
   439          # Unreify the value into the correct window.
   440          if isinstance(tv.value, WindowedValueHolder):
   441            bundle.output(tv.value.windowed_value)
   442          else:
   443            bundle.output(
   444                GlobalWindows.windowed_value(tv.value, timestamp=tv.timestamp))
   445        self.bundles.append(bundle)
   446  
   447    def finish_bundle(self):
   448      # The watermark hold we set here is the way we allow the TestStream events
   449      # to control the output watermark.
   450      return TransformResult(
   451          self, self.bundles, [], None, {None: self._watermark})
   452  
   453  
   454  class _PairWithTimingEvaluator(_TransformEvaluator):
   455    """TransformEvaluator for the PairWithTiming transform.
   456  
   457    This transform takes an element as an input and outputs
   458    KV(element, `TimingInfo`). Where the `TimingInfo` contains both the
   459    processing time timestamp and watermark.
   460    """
   461    def __init__(
   462        self,
   463        evaluation_context,
   464        applied_ptransform,
   465        input_committed_bundle,
   466        side_inputs):
   467      assert not side_inputs
   468      super().__init__(
   469          evaluation_context,
   470          applied_ptransform,
   471          input_committed_bundle,
   472          side_inputs)
   473  
   474    def start_bundle(self):
   475      main_output = list(self._outputs)[0]
   476      self.bundle = self._evaluation_context.create_bundle(main_output)
   477  
   478      watermark_manager = self._evaluation_context._watermark_manager
   479      watermarks = watermark_manager.get_watermarks(self._applied_ptransform)
   480  
   481      output_watermark = watermarks.output_watermark
   482      now = Timestamp(seconds=watermark_manager._clock.time())
   483      self.timing_info = TimingInfo(now, output_watermark)
   484  
   485    def process_element(self, element):
   486      result = WindowedValue((element.value, self.timing_info),
   487                             element.timestamp,
   488                             element.windows,
   489                             element.pane_info)
   490      self.bundle.output(result)
   491  
   492    def finish_bundle(self):
   493      return TransformResult(self, [self.bundle], [], None, {})
   494  
   495  
class _TestStreamEvaluator(_TransformEvaluator):
  """TransformEvaluator for the TestStream transform.

  This evaluator's responsibility is to retrieve the next event from the
  _TestStream and either: advance the clock, advance the _TestStream watermark,
  or pass the event to the _WatermarkController.

  The _WatermarkController is in charge of emitting the elements to the
  downstream consumers and setting its own output watermark.
  """

  # Class-level iterator over the TestStream's events, installed by
  # _TestStreamRootBundleProvider.get_root_bundles before execution starts.
  event_stream = None

  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    assert not side_inputs
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)
    # The _TestStream transform being evaluated and a flag that is flipped
    # once the event stream is exhausted (stops the heartbeat in
    # finish_bundle).
    self.test_stream = applied_ptransform.transform
    self.is_done = False

  def start_bundle(self):
    """Resets per-bundle output bundles and this transform's watermark."""
    self.bundles = []
    self.watermark = MIN_TIMESTAMP

  def process_element(self, element):
    """Consumes one heartbeat element and processes the next stream event(s).

    The element's timestamp carries the _TestStream's own watermark between
    bundles (see finish_bundle, which re-emits it as an unprocessed bundle).
    """
    # The watermark of the _TestStream transform itself.
    self.watermark = element.timestamp

    # Set up the correct watermark holds in the Watermark controllers and the
    # TestStream so that the watermarks will not automatically advance to +inf
    # when elements start streaming. This can happen multiple times in the first
    # bundle, but the operations are idempotent and adding state to keep track
    # of this would add unnecessary code complexity.
    events = []
    if self.watermark == MIN_TIMESTAMP:
      for event in self.test_stream._set_up(self.test_stream.output_tags):
        events.append(event)

    # Retrieve the TestStream's event stream and read from it.
    try:
      events.append(next(self.event_stream))
    except StopIteration:
      # Advance the watermarks to +inf to cleanly stop the pipeline.
      self.is_done = True
      events += ([
          e for e in self.test_stream._tear_down(self.test_stream.output_tags)
      ])

    for event in events:
      # We can either have the _TestStream or the _WatermarkController to emit
      # the elements. We chose to emit in the _WatermarkController so that the
      # element is emitted at the correct watermark value.
      if isinstance(event, (ElementEvent, WatermarkEvent)):
        # The WATERMARK_CONTROL_TAG is used to hold the _TestStream's
        # watermark to -inf, then +inf-1, then +inf. This watermark progression
        # is ultimately used to set up the proper holds to allow the
        # _WatermarkControllers to control their own output watermarks.
        if event.tag == _TestStream.WATERMARK_CONTROL_TAG:
          self.watermark = event.new_watermark
        else:
          # Forward the event itself (wrapped in the global window) to the
          # _WatermarkController on the main output.
          main_output = list(self._outputs)[0]
          bundle = self._evaluation_context.create_bundle(main_output)
          bundle.output(GlobalWindows.windowed_value(event))
          self.bundles.append(bundle)
      elif isinstance(event, ProcessingTimeEvent):
        # Advance the runner's test clock by the requested duration.
        self._evaluation_context._watermark_manager._clock.advance_time(
            event.advance_by)
      else:
        raise ValueError('Invalid TestStream event: %s.' % event)

  def finish_bundle(self):
    """Emits output bundles plus a heartbeat bundle while events remain."""
    unprocessed_bundles = []

    # Continue to send its own state to itself via an unprocessed bundle. This
    # acts as a heartbeat, where each element will read the next event from the
    # event stream.
    if not self.is_done:
      unprocessed_bundle = self._evaluation_context.create_bundle(
          pvalue.PBegin(self._applied_ptransform.transform.pipeline))
      unprocessed_bundle.add(
          GlobalWindows.windowed_value(b'', timestamp=self.watermark))
      unprocessed_bundles.append(unprocessed_bundle)

    # Returning the watermark in the dict here is used as a watermark hold.
    return TransformResult(
        self, self.bundles, unprocessed_bundles, None, {None: self.watermark})
   590  
   591  
   592  class _PubSubReadEvaluator(_TransformEvaluator):
   593    """TransformEvaluator for PubSub read."""
   594  
   595    # A mapping of transform to _PubSubSubscriptionWrapper.
   596    # TODO(https://github.com/apache/beam/issues/19751): Prevents garbage
   597    # collection of pipeline instances.
   598    _subscription_cache = {}  # type: Dict[AppliedPTransform, str]
   599  
   600    def __init__(
   601        self,
   602        evaluation_context,
   603        applied_ptransform,
   604        input_committed_bundle,
   605        side_inputs):
   606      assert not side_inputs
   607      super().__init__(
   608          evaluation_context,
   609          applied_ptransform,
   610          input_committed_bundle,
   611          side_inputs)
   612  
   613      self.source = self._applied_ptransform.transform._source  # type: _PubSubSource
   614      if self.source.id_label:
   615        raise NotImplementedError(
   616            'DirectRunner: id_label is not supported for PubSub reads')
   617  
   618      sub_project = None
   619      if hasattr(self._evaluation_context, 'pipeline_options'):
   620        from apache_beam.options.pipeline_options import GoogleCloudOptions
   621        sub_project = (
   622            self._evaluation_context.pipeline_options.view_as(
   623                GoogleCloudOptions).project)
   624      if not sub_project:
   625        sub_project = self.source.project
   626  
   627      self._sub_name = self.get_subscription(
   628          self._applied_ptransform,
   629          self.source.project,
   630          self.source.topic_name,
   631          sub_project,
   632          self.source.subscription_name)
   633  
   634    @classmethod
   635    def get_subscription(
   636        cls, transform, project, short_topic_name, sub_project, short_sub_name):
   637      from google.cloud import pubsub
   638  
   639      if short_sub_name:
   640        return pubsub.SubscriberClient.subscription_path(project, short_sub_name)
   641  
   642      if transform in cls._subscription_cache:
   643        return cls._subscription_cache[transform]
   644  
   645      sub_client = pubsub.SubscriberClient()
   646      sub_name = sub_client.subscription_path(
   647          sub_project,
   648          'beam_%d_%x' % (int(time.time()), random.randrange(1 << 32)))
   649      topic_name = sub_client.topic_path(project, short_topic_name)
   650      sub_client.create_subscription(name=sub_name, topic=topic_name)
   651      atexit.register(sub_client.delete_subscription, subscription=sub_name)
   652      cls._subscription_cache[transform] = sub_name
   653      return cls._subscription_cache[transform]
   654  
   655    def start_bundle(self):
   656      pass
   657  
   658    def process_element(self, element):
   659      pass
   660  
   661    def _read_from_pubsub(self, timestamp_attribute):
   662      # type: (...) -> List[Tuple[Timestamp, PubsubMessage]]
   663      from apache_beam.io.gcp.pubsub import PubsubMessage
   664      from google.cloud import pubsub
   665  
   666      def _get_element(message):
   667        parsed_message = PubsubMessage._from_message(message)
   668        if (timestamp_attribute and
   669            timestamp_attribute in parsed_message.attributes):
   670          rfc3339_or_milli = parsed_message.attributes[timestamp_attribute]
   671          try:
   672            timestamp = Timestamp(micros=int(rfc3339_or_milli) * 1000)
   673          except ValueError:
   674            try:
   675              timestamp = Timestamp.from_rfc3339(rfc3339_or_milli)
   676            except ValueError as e:
   677              raise ValueError('Bad timestamp value: %s' % e)
   678        else:
   679          if message.publish_time is None:
   680            raise ValueError('No publish time present in message: %s' % message)
   681          try:
   682            timestamp = Timestamp.from_utc_datetime(message.publish_time)
   683          except ValueError as e:
   684            raise ValueError('Bad timestamp value for message %s: %s', message, e)
   685  
   686        return timestamp, parsed_message
   687  
   688      # Because of the AutoAck, we are not able to reread messages if this
   689      # evaluator fails with an exception before emitting a bundle. However,
   690      # the DirectRunner currently doesn't retry work items anyway, so the
   691      # pipeline would enter an inconsistent state on any error.
   692      sub_client = pubsub.SubscriberClient()
   693      try:
   694        response = sub_client.pull(
   695            subscription=self._sub_name, max_messages=10, timeout=30)
   696        results = [_get_element(rm.message) for rm in response.received_messages]
   697        ack_ids = [rm.ack_id for rm in response.received_messages]
   698        if ack_ids:
   699          sub_client.acknowledge(subscription=self._sub_name, ack_ids=ack_ids)
   700      finally:
   701        sub_client.close()
   702  
   703      return results
   704  
   705    def finish_bundle(self):
   706      # type: () -> TransformResult
   707      data = self._read_from_pubsub(self.source.timestamp_attribute)
   708      if data:
   709        output_pcollection = list(self._outputs)[0]
   710        bundle = self._evaluation_context.create_bundle(output_pcollection)
   711        # TODO(ccy): Respect the PubSub source's id_label field.
   712        for timestamp, message in data:
   713          if self.source.with_attributes:
   714            element = message
   715          else:
   716            element = message.data
   717          bundle.output(
   718              GlobalWindows.windowed_value(element, timestamp=timestamp))
   719        bundles = [bundle]
   720      else:
   721        bundles = []
   722      assert self._applied_ptransform.transform is not None
   723      if self._applied_ptransform.inputs:
   724        input_pvalue = self._applied_ptransform.inputs[0]
   725      else:
   726        input_pvalue = pvalue.PBegin(self._applied_ptransform.transform.pipeline)
   727      unprocessed_bundle = self._evaluation_context.create_bundle(input_pvalue)
   728  
   729      # TODO(udim): Correct value for watermark hold.
   730      return TransformResult(
   731          self,
   732          bundles, [unprocessed_bundle],
   733          None, {None: Timestamp.of(time.time())})
   734  
   735  
   736  class _FlattenEvaluator(_TransformEvaluator):
   737    """TransformEvaluator for Flatten transform."""
   738    def __init__(
   739        self,
   740        evaluation_context,
   741        applied_ptransform,
   742        input_committed_bundle,
   743        side_inputs):
   744      assert not side_inputs
   745      super().__init__(
   746          evaluation_context,
   747          applied_ptransform,
   748          input_committed_bundle,
   749          side_inputs)
   750  
   751    def start_bundle(self):
   752      assert len(self._outputs) == 1
   753      output_pcollection = list(self._outputs)[0]
   754      self.bundle = self._evaluation_context.create_bundle(output_pcollection)
   755  
   756    def process_element(self, element):
   757      self.bundle.output(element)
   758  
   759    def finish_bundle(self):
   760      bundles = [self.bundle]
   761      return TransformResult(self, bundles, [], None, None)
   762  
   763  
   764  class _ImpulseEvaluator(_TransformEvaluator):
   765    """TransformEvaluator for Impulse transform."""
   766    def finish_bundle(self):
   767      assert len(self._outputs) == 1
   768      output_pcollection = list(self._outputs)[0]
   769      bundle = self._evaluation_context.create_bundle(output_pcollection)
   770      bundle.output(GlobalWindows.windowed_value(b''))
   771      return TransformResult(self, [bundle], [], None, None)
   772  
   773  
   774  class _TaggedReceivers(dict):
   775    """Received ParDo output and redirect to the associated output bundle."""
   776    def __init__(self, evaluation_context):
   777      self._evaluation_context = evaluation_context
   778      self._null_receiver = None
   779      super().__init__()
   780  
   781    class NullReceiver(common.Receiver):
   782      """Ignores undeclared outputs, default execution mode."""
   783      def receive(self, element):
   784        # type: (WindowedValue) -> None
   785        pass
   786  
   787    class _InMemoryReceiver(common.Receiver):
   788      """Buffers undeclared outputs to the given dictionary."""
   789      def __init__(self, target, tag):
   790        self._target = target
   791        self._tag = tag
   792  
   793      def receive(self, element):
   794        # type: (WindowedValue) -> None
   795        self._target[self._tag].append(element)
   796  
   797    def __missing__(self, key):
   798      if not self._null_receiver:
   799        self._null_receiver = _TaggedReceivers.NullReceiver()
   800      return self._null_receiver
   801  
   802  
   803  class _ParDoEvaluator(_TransformEvaluator):
   804    """TransformEvaluator for ParDo transform."""
   805  
   806    def __init__(self,
   807                 evaluation_context, # type: EvaluationContext
   808                 applied_ptransform,  # type: AppliedPTransform
   809                 input_committed_bundle,
   810                 side_inputs,
   811                 perform_dofn_pickle_test=True
   812                ):
   813      super().__init__(
   814          evaluation_context,
   815          applied_ptransform,
   816          input_committed_bundle,
   817          side_inputs)
   818      # This is a workaround for SDF implementation. SDF implementation adds state
   819      # to the SDF that is not picklable.
   820      self._perform_dofn_pickle_test = perform_dofn_pickle_test
   821  
   822    def start_bundle(self):
   823      transform = self._applied_ptransform.transform
   824  
   825      self._tagged_receivers = _TaggedReceivers(self._evaluation_context)
   826      for output_tag in self._applied_ptransform.outputs:
   827        output_pcollection = pvalue.PCollection(None, tag=output_tag)
   828        output_pcollection.producer = self._applied_ptransform
   829        self._tagged_receivers[output_tag] = (
   830            self._evaluation_context.create_bundle(output_pcollection))
   831        self._tagged_receivers[output_tag].tag = output_tag
   832  
   833      self._counter_factory = counters.CounterFactory()
   834  
   835      # TODO(aaltay): Consider storing the serialized form as an optimization.
   836      dofn = (
   837          pickler.loads(pickler.dumps(transform.dofn))
   838          if self._perform_dofn_pickle_test else transform.dofn)
   839  
   840      args = transform.args if hasattr(transform, 'args') else []
   841      kwargs = transform.kwargs if hasattr(transform, 'kwargs') else {}
   842  
   843      self.user_state_context = None
   844      self.user_timer_map = {}
   845      if is_stateful_dofn(dofn):
   846        kv_type_hint = self._applied_ptransform.inputs[0].element_type
   847        if kv_type_hint and kv_type_hint != Any:
   848          coder = coders.registry.get_coder(kv_type_hint)
   849          self.key_coder = coder.key_coder()
   850        else:
   851          self.key_coder = coders.registry.get_coder(Any)
   852  
   853        self.user_state_context = DirectUserStateContext(
   854            self._step_context, dofn, self.key_coder)
   855        _, all_timer_specs = get_dofn_specs(dofn)
   856        for timer_spec in all_timer_specs:
   857          self.user_timer_map['user/%s' % timer_spec.name] = timer_spec
   858  
   859      self.runner = DoFnRunner(
   860          dofn,
   861          args,
   862          kwargs,
   863          self._side_inputs,
   864          self._applied_ptransform.inputs[0].windowing,
   865          tagged_receivers=self._tagged_receivers,
   866          step_name=self._applied_ptransform.full_label,
   867          state=DoFnState(self._counter_factory),
   868          user_state_context=self.user_state_context)
   869      self.runner.setup()
   870      self.runner.start()
   871  
   872    def process_timer(self, timer_firing):
   873      if timer_firing.name not in self.user_timer_map:
   874        _LOGGER.warning('Unknown timer fired: %s', timer_firing)
   875      timer_spec = self.user_timer_map[timer_firing.name]
   876      self.runner.process_user_timer(
   877          timer_spec,
   878          self.key_coder.decode(timer_firing.encoded_key),
   879          timer_firing.window,
   880          timer_firing.timestamp,
   881          # TODO Add paneinfo to timer_firing in DirectRunner
   882          None,
   883          timer_firing.dynamic_timer_tag)
   884  
   885    def process_element(self, element):
   886      self.runner.process(element)
   887  
   888    def finish_bundle(self):
   889      self.runner.finish()
   890      self.runner.teardown()
   891      bundles = list(self._tagged_receivers.values())
   892      result_counters = self._counter_factory.get_counters()
   893      if self.user_state_context:
   894        self.user_state_context.commit()
   895        self.user_state_context.reset()
   896      return TransformResult(self, bundles, [], result_counters, None)
   897  
   898  
class _GroupByKeyOnlyEvaluator(_TransformEvaluator):
  """TransformEvaluator for _GroupByKeyOnly transform."""

  # No cap: all grouped results may go into a single bundle.
  MAX_ELEMENT_PER_BUNDLE = None
  # Per-key list of buffered values, accumulated until the final bundle.
  ELEMENTS_TAG = _ListStateTag('elements')
  # Global flag recording that grouped output has already been emitted.
  COMPLETION_TAG = _CombiningValueStateTag('completed', any)

  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    # GroupByKey takes no side inputs.
    assert not side_inputs
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)

  def _is_final_bundle(self):
    # The input watermark only reaches +inf once all input has been consumed.
    return (
        self._execution_context.watermarks.input_watermark ==
        WatermarkManager.WATERMARK_POS_INF)

  def start_bundle(self):
    """Caches global state, the output PCollection, and the key coder."""
    self.global_state = self._step_context.get_keyed_state(None)

    assert len(self._outputs) == 1
    self.output_pcollection = list(self._outputs)[0]

    # The output type of a GroupByKey will be Tuple[Any, Any] or more specific.
    # TODO(https://github.com/apache/beam/issues/18490): Infer coders earlier.
    kv_type_hint = (
        self._applied_ptransform.outputs[None].element_type or
        self._applied_ptransform.transform.get_type_hints().input_types[0][0])
    self.key_coder = coders.registry.get_coder(kv_type_hint.tuple_types[0])

  def process_timer(self, timer_firing):
    # We do not need to emit a KeyedWorkItem to process_element().
    pass

  def process_element(self, element):
    """Buffers one windowed (key, value) pair into per-key state."""
    # No elements may arrive after the grouped output was already emitted.
    assert not self.global_state.get_state(
        None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG)
    if (isinstance(element, WindowedValue) and
        isinstance(element.value, abc.Iterable) and len(element.value) == 2):
      k, v = element.value
      # Keys are stored by their encoded form; decoded back in finish_bundle.
      encoded_k = self.key_coder.encode(k)
      state = self._step_context.get_keyed_state(encoded_k)
      state.add_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG, v)
    else:
      raise TypeCheckError(
          'Input to _GroupByKeyOnly must be a PCollection of '
          'windowed key-value pairs. Instead received: %r.' % element)

  def finish_bundle(self):
    """Emits all grouped output once the input watermark reaches +inf."""
    if self._is_final_bundle():
      if self.global_state.get_state(None,
                                     _GroupByKeyOnlyEvaluator.COMPLETION_TAG):
        # Ignore empty bundles after emitting output. (This may happen because
        # empty bundles do not affect input watermarks.)
        bundles = []
      else:
        gbk_result = []
        # TODO(ccy): perhaps we can clean this up to not use this
        # internal attribute of the DirectStepContext.
        for encoded_k in self._step_context.existing_keyed_state:
          # Ignore global state.
          if encoded_k is None:
            continue
          k = self.key_coder.decode(encoded_k)
          state = self._step_context.get_keyed_state(encoded_k)
          vs = state.get_state(None, _GroupByKeyOnlyEvaluator.ELEMENTS_TAG)
          gbk_result.append(GlobalWindows.windowed_value((k, vs)))

        def len_element_fn(element):
          # Bundle-splitting weight: number of values grouped under the key.
          _, v = element.value
          return len(v)

        bundles = self._split_list_into_bundles(
            self.output_pcollection,
            gbk_result,
            _GroupByKeyOnlyEvaluator.MAX_ELEMENT_PER_BUNDLE,
            len_element_fn)

      self.global_state.add_state(
          None, _GroupByKeyOnlyEvaluator.COMPLETION_TAG, True)
      hold = WatermarkManager.WATERMARK_POS_INF
    else:
      # Input not exhausted yet: hold the watermark at -inf and set a timer
      # so this evaluator runs again when the input watermark reaches +inf.
      bundles = []
      hold = WatermarkManager.WATERMARK_NEG_INF
      self.global_state.set_timer(
          None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF)

    return TransformResult(self, bundles, [], None, {None: hold})
   995  
   996  
   997  class _StreamingGroupByKeyOnlyEvaluator(_TransformEvaluator):
   998    """TransformEvaluator for _StreamingGroupByKeyOnly transform.
   999  
  1000    The _GroupByKeyOnlyEvaluator buffers elements until its input watermark goes
  1001    to infinity, which is suitable for batch mode execution. During streaming
  1002    mode execution, we emit each bundle as it comes to the next transform.
  1003    """
  1004  
  1005    MAX_ELEMENT_PER_BUNDLE = None
  1006  
  1007    def __init__(
  1008        self,
  1009        evaluation_context,
  1010        applied_ptransform,
  1011        input_committed_bundle,
  1012        side_inputs):
  1013      assert not side_inputs
  1014      super().__init__(
  1015          evaluation_context,
  1016          applied_ptransform,
  1017          input_committed_bundle,
  1018          side_inputs)
  1019  
  1020    def start_bundle(self):
  1021      self.gbk_items = collections.defaultdict(list)
  1022  
  1023      assert len(self._outputs) == 1
  1024      self.output_pcollection = list(self._outputs)[0]
  1025  
  1026      # The input type of a GroupByKey will be Tuple[Any, Any] or more specific.
  1027      kv_type_hint = self._applied_ptransform.inputs[0].element_type
  1028      key_type_hint = (kv_type_hint.tuple_types[0] if kv_type_hint else Any)
  1029      self.key_coder = coders.registry.get_coder(key_type_hint)
  1030  
  1031    def process_element(self, element):
  1032      if (isinstance(element, WindowedValue) and
  1033          isinstance(element.value, collections.abc.Iterable) and
  1034          len(element.value) == 2):
  1035        k, v = element.value
  1036        self.gbk_items[self.key_coder.encode(k)].append(v)
  1037      else:
  1038        raise TypeCheckError(
  1039            'Input to _GroupByKeyOnly must be a PCollection of '
  1040            'windowed key-value pairs. Instead received: %r.' % element)
  1041  
  1042    def finish_bundle(self):
  1043      bundles = []
  1044      bundle = None
  1045      for encoded_k, vs in self.gbk_items.items():
  1046        if not bundle:
  1047          bundle = self._evaluation_context.create_bundle(self.output_pcollection)
  1048          bundles.append(bundle)
  1049        kwi = KeyedWorkItem(encoded_k, elements=vs)
  1050        bundle.add(GlobalWindows.windowed_value(kwi))
  1051  
  1052      return TransformResult(self, bundles, [], None, None)
  1053  
  1054  
class _StreamingGroupAlsoByWindowEvaluator(_TransformEvaluator):
  """TransformEvaluator for the _StreamingGroupAlsoByWindow transform.

  This evaluator is only used in streaming mode.  In batch mode, the
  GroupAlsoByWindow operation is evaluated as a normal DoFn, as defined
  in transforms/core.py.
  """
  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    # GroupAlsoByWindow takes no side inputs.
    assert not side_inputs
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)

  def start_bundle(self):
    """Sets up the trigger driver, per-bundle buffers, and the key coder."""
    assert len(self._outputs) == 1
    self.output_pcollection = list(self._outputs)[0]
    # NOTE(review): reaches into private attributes of the evaluation context
    # to share the watermark manager's clock with the trigger driver.
    self.driver = create_trigger_driver(
        self._applied_ptransform.transform.windowing,
        clock=self._evaluation_context._watermark_manager._clock)
    # Windowed (key, value) outputs produced during this bundle.
    self.gabw_items = []
    # encoded_key -> earliest watermark hold read from that key's state.
    self.keyed_holds = {}

    # The input type (which is the same as the output type) of a
    # GroupAlsoByWindow will be Tuple[Any, Iter[Any]] or more specific.
    kv_type_hint = self._applied_ptransform.outputs[None].element_type
    key_type_hint = (kv_type_hint.tuple_types[0] if kv_type_hint else Any)
    self.key_coder = coders.registry.get_coder(key_type_hint)

  def process_element(self, element):
    """Runs one KeyedWorkItem's timers and elements through the driver."""
    kwi = element.value
    assert isinstance(kwi, KeyedWorkItem), kwi
    encoded_k, timer_firings, vs = (
        kwi.encoded_key, kwi.timer_firings, kwi.elements)
    k = self.key_coder.decode(encoded_k)
    state = self._step_context.get_keyed_state(encoded_k)

    watermarks = self._evaluation_context._watermark_manager.get_watermarks(
        self._applied_ptransform)
    # Fire due timers first; each firing may trigger window output.
    for timer_firing in timer_firings:
      for wvalue in self.driver.process_timer(timer_firing.window,
                                              timer_firing.name,
                                              timer_firing.time_domain,
                                              timer_firing.timestamp,
                                              state,
                                              watermarks.input_watermark):
        self.gabw_items.append(wvalue.with_value((k, wvalue.value)))
    # Then feed the buffered values for this key into the trigger driver.
    if vs:
      for wvalue in self.driver.process_elements(state,
                                                 vs,
                                                 watermarks.output_watermark,
                                                 watermarks.input_watermark):
        self.gabw_items.append(wvalue.with_value((k, wvalue.value)))

    # Record the earliest remaining hold so the runner can keep the output
    # watermark back for this key.
    self.keyed_holds[encoded_k] = state.get_earliest_hold()

  def finish_bundle(self):
    """Packages any triggered window output into a single bundle."""
    bundles = []
    if self.gabw_items:
      bundle = self._evaluation_context.create_bundle(self.output_pcollection)
      for item in self.gabw_items:
        bundle.add(item)
      bundles.append(bundle)

    return TransformResult(self, bundles, [], None, self.keyed_holds)
  1126  
  1127  
  1128  class _NativeWriteEvaluator(_TransformEvaluator):
  1129    """TransformEvaluator for _NativeWrite transform."""
  1130  
  1131    ELEMENTS_TAG = _ListStateTag('elements')
  1132  
  1133    def __init__(
  1134        self,
  1135        evaluation_context,
  1136        applied_ptransform,
  1137        input_committed_bundle,
  1138        side_inputs):
  1139      assert not side_inputs
  1140      super().__init__(
  1141          evaluation_context,
  1142          applied_ptransform,
  1143          input_committed_bundle,
  1144          side_inputs)
  1145  
  1146      assert applied_ptransform.transform.sink
  1147      self._sink = applied_ptransform.transform.sink
  1148  
  1149    @property
  1150    def _is_final_bundle(self):
  1151      return (
  1152          self._execution_context.watermarks.input_watermark ==
  1153          WatermarkManager.WATERMARK_POS_INF)
  1154  
  1155    @property
  1156    def _has_already_produced_output(self):
  1157      return (
  1158          self._execution_context.watermarks.output_watermark ==
  1159          WatermarkManager.WATERMARK_POS_INF)
  1160  
  1161    def start_bundle(self):
  1162      self.global_state = self._step_context.get_keyed_state(None)
  1163  
  1164    def process_timer(self, timer_firing):
  1165      # We do not need to emit a KeyedWorkItem to process_element().
  1166      pass
  1167  
  1168    def process_element(self, element):
  1169      self.global_state.add_state(
  1170          None, _NativeWriteEvaluator.ELEMENTS_TAG, element)
  1171  
  1172    def finish_bundle(self):
  1173      # finish_bundle will append incoming bundles in memory until all the bundles
  1174      # carrying data is processed. This is done to produce only a single output
  1175      # shard (some tests depends on this behavior). It is possible to have
  1176      # incoming empty bundles after the output is produced, these bundles will be
  1177      # ignored and would not generate additional output files.
  1178      # TODO(altay): Do not wait until the last bundle to write in a single shard.
  1179      if self._is_final_bundle:
  1180        elements = self.global_state.get_state(
  1181            None, _NativeWriteEvaluator.ELEMENTS_TAG)
  1182        if self._has_already_produced_output:
  1183          # Ignore empty bundles that arrive after the output is produced.
  1184          assert elements == []
  1185        else:
  1186          self._sink.pipeline_options = self._evaluation_context.pipeline_options
  1187          with self._sink.writer() as writer:
  1188            for v in elements:
  1189              writer.Write(v.value)
  1190        hold = WatermarkManager.WATERMARK_POS_INF
  1191      else:
  1192        hold = WatermarkManager.WATERMARK_NEG_INF
  1193        self.global_state.set_timer(
  1194            None, '', TimeDomain.WATERMARK, WatermarkManager.WATERMARK_POS_INF)
  1195  
  1196      return TransformResult(self, [], [], None, {None: hold})
  1197  
  1198  
class _ProcessElementsEvaluator(_TransformEvaluator):
  """An evaluator for sdf_direct_runner.ProcessElements transform."""

  # Maximum number of elements that will be produced by a Splittable DoFn before
  # a checkpoint is requested by the runner. (None means no output-count cap.)
  DEFAULT_MAX_NUM_OUTPUTS = None
  # Maximum duration a Splittable DoFn will process an element before a
  # checkpoint is requested by the runner.
  DEFAULT_MAX_DURATION = 1

  def __init__(
      self,
      evaluation_context,
      applied_ptransform,
      input_committed_bundle,
      side_inputs):
    super().__init__(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs)

    process_elements_transform = applied_ptransform.transform
    assert isinstance(process_elements_transform, ProcessElements)

    # Replacing the do_fn of the transform with a wrapper do_fn that performs
    # SDF magic.
    transform = applied_ptransform.transform
    sdf = transform.sdf
    self._process_fn = transform.new_process_fn(sdf)
    transform.dofn = self._process_fn

    assert isinstance(self._process_fn, ProcessFn)

    self._process_fn.step_context = self._step_context

    # The invoker enforces the checkpointing limits above on each SDF
    # element invocation.
    process_element_invoker = (
        SDFProcessElementInvoker(
            max_num_outputs=self.DEFAULT_MAX_NUM_OUTPUTS,
            max_duration=self.DEFAULT_MAX_DURATION))
    self._process_fn.set_process_element_invoker(process_element_invoker)

    # Element processing is delegated to a ParDo evaluator. The pickle test
    # is disabled because the wrapped ProcessFn carries unpicklable state.
    self._par_do_evaluator = _ParDoEvaluator(
        evaluation_context,
        applied_ptransform,
        input_committed_bundle,
        side_inputs,
        perform_dofn_pickle_test=False)
    # key -> watermark hold read back from keyed state after processing.
    self.keyed_holds = {}

  def start_bundle(self):
    self._par_do_evaluator.start_bundle()

  def process_element(self, element):
    """Delegates to the ParDo evaluator, then records the key's hold."""
    assert isinstance(element, WindowedValue)
    assert len(element.windows) == 1
    window = element.windows[0]
    if isinstance(element.value, KeyedWorkItem):
      key = element.value.encoded_key
    else:
      # If not a `KeyedWorkItem`, this must be a tuple where key is a randomly
      # generated key and the value is a `WindowedValue` that contains an
      # `ElementAndRestriction` object.
      assert isinstance(element.value, tuple)
      key = element.value[0]

    self._par_do_evaluator.process_element(element)

    state = self._step_context.get_keyed_state(key)
    self.keyed_holds[key] = state.get_state(
        window, self._process_fn.watermark_hold_tag)

  def finish_bundle(self):
    """Returns the ParDo result, augmented with per-key watermark holds."""
    par_do_result = self._par_do_evaluator.finish_bundle()

    transform_result = TransformResult(
        self,
        par_do_result.uncommitted_output_bundles,
        par_do_result.unprocessed_bundles,
        par_do_result.counters,
        par_do_result.keyed_watermark_holds,
        par_do_result.undeclared_tag_values)
    for key, keyed_hold in self.keyed_holds.items():
      transform_result.keyed_watermark_holds[key] = keyed_hold
    return transform_result