github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/bundle_processor.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """SDK harness for executing Python Fns via the Fn API."""
    19  
    20  # pytype: skip-file
    21  
    22  import base64
    23  import bisect
    24  import collections
    25  import copy
    26  import json
    27  import logging
    28  import random
    29  import threading
    30  from typing import TYPE_CHECKING
    31  from typing import Any
    32  from typing import Callable
    33  from typing import Container
    34  from typing import DefaultDict
    35  from typing import Dict
    36  from typing import FrozenSet
    37  from typing import Iterable
    38  from typing import Iterator
    39  from typing import List
    40  from typing import Mapping
    41  from typing import Optional
    42  from typing import Set
    43  from typing import Tuple
    44  from typing import Type
    45  from typing import TypeVar
    46  from typing import Union
    47  from typing import cast
    48  
    49  from google.protobuf import duration_pb2
    50  from google.protobuf import timestamp_pb2
    51  
    52  import apache_beam as beam
    53  from apache_beam import coders
    54  from apache_beam.coders import WindowedValueCoder
    55  from apache_beam.coders import coder_impl
    56  from apache_beam.internal import pickler
    57  from apache_beam.io import iobase
    58  from apache_beam.metrics import monitoring_infos
    59  from apache_beam.portability import common_urns
    60  from apache_beam.portability import python_urns
    61  from apache_beam.portability.api import beam_fn_api_pb2
    62  from apache_beam.portability.api import beam_runner_api_pb2
    63  from apache_beam.runners import common
    64  from apache_beam.runners import pipeline_context
    65  from apache_beam.runners.worker import data_sampler
    66  from apache_beam.runners.worker import operation_specs
    67  from apache_beam.runners.worker import operations
    68  from apache_beam.runners.worker import statesampler
    69  from apache_beam.runners.worker.data_sampler import OutputSampler
    70  from apache_beam.transforms import TimeDomain
    71  from apache_beam.transforms import core
    72  from apache_beam.transforms import environments
    73  from apache_beam.transforms import sideinputs
    74  from apache_beam.transforms import userstate
    75  from apache_beam.transforms import window
    76  from apache_beam.utils import counters
    77  from apache_beam.utils import proto_utils
    78  from apache_beam.utils import timestamp
    79  from apache_beam.utils.windowed_value import WindowedValue
    80  
    81  if TYPE_CHECKING:
    82    from google.protobuf import message  # pylint: disable=ungrouped-imports
    83    from apache_beam import pvalue
    84    from apache_beam.portability.api import metrics_pb2
    85    from apache_beam.runners.sdf_utils import SplitResultPrimary
    86    from apache_beam.runners.sdf_utils import SplitResultResidual
    87    from apache_beam.runners.worker import data_plane
    88    from apache_beam.runners.worker import sdk_worker
    89    from apache_beam.transforms.core import Windowing
    90    from apache_beam.transforms.window import BoundedWindow
    91    from apache_beam.utils import windowed_value
    92  
    93  T = TypeVar('T')
    94  ConstructorFn = Callable[[
    95      'BeamTransformFactory',
    96      Any,
    97      beam_runner_api_pb2.PTransform,
    98      Union['message.Message', bytes],
    99      Dict[str, List[operations.Operation]]
   100  ],
   101                           operations.Operation]
   102  OperationT = TypeVar('OperationT', bound=operations.Operation)
   103  FnApiUserRuntimeStateTypes = Union['ReadModifyWriteRuntimeState',
   104                                     'CombiningValueRuntimeState',
   105                                     'SynchronousSetRuntimeState',
   106                                     'SynchronousBagRuntimeState']
   107  
   108  DATA_INPUT_URN = 'beam:runner:source:v1'
   109  DATA_OUTPUT_URN = 'beam:runner:sink:v1'
   110  SYNTHETIC_DATA_SAMPLING_URN = 'beam:internal:sampling:v1'
   111  IDENTITY_DOFN_URN = 'beam:dofn:identity:0.1'
   112  # TODO(vikasrk): Fix this once runner sends appropriate common_urns.
   113  OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN = 'beam:dofn:javasdk:0.1'
   114  OLD_DATAFLOW_RUNNER_HARNESS_READ_URN = 'beam:source:java:0.1'
   115  URNS_NEEDING_PCOLLECTIONS = set([
   116      monitoring_infos.ELEMENT_COUNT_URN, monitoring_infos.SAMPLED_BYTE_SIZE_URN
   117  ])
   118  
   119  _LOGGER = logging.getLogger(__name__)
   120  
   121  
   122  class RunnerIOOperation(operations.Operation):
   123    """Common baseclass for runner harness IO operations."""
   124  
   125    def __init__(self,
   126                 name_context,  # type: common.NameContext
   127                 step_name,  # type: Any
   128                 consumers,  # type: Mapping[Any, Iterable[operations.Operation]]
   129                 counter_factory,  # type: counters.CounterFactory
   130                 state_sampler,  # type: statesampler.StateSampler
   131                 windowed_coder,  # type: coders.Coder
   132                 transform_id,  # type: str
   133                 data_channel  # type: data_plane.DataChannel
   134                ):
   135      # type: (...) -> None
   136      super().__init__(name_context, None, counter_factory, state_sampler)
   137      self.windowed_coder = windowed_coder
   138      self.windowed_coder_impl = windowed_coder.get_impl()
   139      # transform_id represents the consumer for the bytes in the data plane for a
   140      # DataInputOperation or a producer of these bytes for a DataOutputOperation.
   141      self.transform_id = transform_id
   142      self.data_channel = data_channel
   143      for _, consumer_ops in consumers.items():
   144        for consumer in consumer_ops:
   145          self.add_receiver(consumer, 0)
   146  
   147  
   148  class DataOutputOperation(RunnerIOOperation):
   149    """A sink-like operation that gathers outputs to be sent back to the runner.
   150    """
   151    def set_output_stream(self, output_stream):
   152      # type: (data_plane.ClosableOutputStream) -> None
   153      self.output_stream = output_stream
   154  
   155    def process(self, windowed_value):
   156      # type: (windowed_value.WindowedValue) -> None
   157      self.windowed_coder_impl.encode_to_stream(
   158          windowed_value, self.output_stream, True)
   159      self.output_stream.maybe_flush()
   160  
   161    def finish(self):
   162      # type: () -> None
   163      super().finish()
   164      self.output_stream.close()
   165  
   166  
   167  class DataInputOperation(RunnerIOOperation):
   168    """A source-like operation that gathers input from the runner."""
   169  
   170    def __init__(self,
   171                 operation_name,  # type: common.NameContext
   172                 step_name,
   173                 consumers,  # type: Mapping[Any, List[operations.Operation]]
   174                 counter_factory,  # type: counters.CounterFactory
   175                 state_sampler,  # type: statesampler.StateSampler
   176                 windowed_coder,  # type: coders.Coder
   177                 transform_id,
   178                 data_channel  # type: data_plane.GrpcClientDataChannel
   179                ):
   180      # type: (...) -> None
   181      super().__init__(
   182          operation_name,
   183          step_name,
   184          consumers,
   185          counter_factory,
   186          state_sampler,
   187          windowed_coder,
   188          transform_id=transform_id,
   189          data_channel=data_channel)
   190  
   191      self.consumer = next(iter(consumers.values()))
   192      self.splitting_lock = threading.Lock()
   193      self.index = -1
   194      self.stop = float('inf')
   195      self.started = False
   196  
   197    def setup(self):
   198      super().setup()
   199      # We must do this manually as we don't have a spec or spec.output_coders.
   200      self.receivers = [
   201          operations.ConsumerSet.create(
   202              counter_factory=self.counter_factory,
   203              step_name=self.name_context.step_name,
   204              output_index=0,
   205              consumers=self.consumer,
   206              coder=self.windowed_coder,
   207              producer_type_hints=self._get_runtime_performance_hints(),
   208              producer_batch_converter=self.get_output_batch_converter())
   209      ]
   210  
   211    def start(self):
   212      # type: () -> None
   213      super().start()
   214      with self.splitting_lock:
   215        self.started = True
   216  
   217    def process(self, windowed_value):
   218      # type: (windowed_value.WindowedValue) -> None
   219      self.output(windowed_value)
   220  
   221    def process_encoded(self, encoded_windowed_values):
   222      # type: (bytes) -> None
   223      input_stream = coder_impl.create_InputStream(encoded_windowed_values)
   224      while input_stream.size() > 0:
   225        with self.splitting_lock:
   226          if self.index == self.stop - 1:
   227            return
   228          self.index += 1
   229        decoded_value = self.windowed_coder_impl.decode_from_stream(
   230            input_stream, True)
   231        self.output(decoded_value)
   232  
   233    def monitoring_infos(self, transform_id, tag_to_pcollection_id):
   234      # type: (str, Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]
   235      all_monitoring_infos = super().monitoring_infos(
   236          transform_id, tag_to_pcollection_id)
   237      read_progress_info = monitoring_infos.int64_counter(
   238          monitoring_infos.DATA_CHANNEL_READ_INDEX,
   239          self.index,
   240          ptransform=transform_id)
   241      all_monitoring_infos[monitoring_infos.to_key(
   242          read_progress_info)] = read_progress_info
   243      return all_monitoring_infos
   244  
   245    # TODO(https://github.com/apache/beam/issues/19737): typing not compatible
   246    # with super type
   247    def try_split(  # type: ignore[override]
   248        self, fraction_of_remainder, total_buffer_size, allowed_split_points):
   249      # type: (...) -> Optional[Tuple[int, Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual], int]]
   250      with self.splitting_lock:
   251        if not self.started:
   252          return None
   253        if self.index == -1:
   254          # We are "finished" with the (non-existent) previous element.
   255          current_element_progress = 1.0
   256        else:
   257          current_element_progress_object = (
   258              self.receivers[0].current_element_progress())
   259          if current_element_progress_object is None:
   260            current_element_progress = 0.5
   261          else:
   262            current_element_progress = (
   263                current_element_progress_object.fraction_completed)
   264        # Now figure out where to split.
   265        split = self._compute_split(
   266            self.index,
   267            current_element_progress,
   268            self.stop,
   269            fraction_of_remainder,
   270            total_buffer_size,
   271            allowed_split_points,
   272            self.receivers[0].try_split)
   273        if split:
   274          self.stop = split[-1]
   275        return split
   276  
   277    @staticmethod
   278    def _compute_split(
   279        index,
   280        current_element_progress,
   281        stop,
   282        fraction_of_remainder,
   283        total_buffer_size,
   284        allowed_split_points=(),
   285        try_split=lambda fraction: None):
   286      def is_valid_split_point(index):
   287        return not allowed_split_points or index in allowed_split_points
   288  
   289      if total_buffer_size < index + 1:
   290        total_buffer_size = index + 1
   291      elif total_buffer_size > stop:
   292        total_buffer_size = stop
   293      # The units here (except for keep_of_element_remainder) are all in
   294      # terms of number of (possibly fractional) elements.
   295      remainder = total_buffer_size - index - current_element_progress
   296      keep = remainder * fraction_of_remainder
   297      if current_element_progress < 1:
   298        keep_of_element_remainder = keep / (1 - current_element_progress)
   299        # If it's less than what's left of the current element,
   300        # try splitting at the current element.
   301        if (keep_of_element_remainder < 1 and is_valid_split_point(index) and
   302            is_valid_split_point(index + 1)):
   303          split = try_split(
   304              keep_of_element_remainder
   305          )  # type: Optional[Tuple[Iterable[operations.SdfSplitResultsPrimary], Iterable[operations.SdfSplitResultsResidual]]]
   306          if split:
   307            element_primaries, element_residuals = split
   308            return index - 1, element_primaries, element_residuals, index + 1
   309      # Otherwise, split at the closest element boundary.
   310      # pylint: disable=bad-option-value
   311      stop_index = index + max(1, int(round(current_element_progress + keep)))
   312      if allowed_split_points and stop_index not in allowed_split_points:
   313        # Choose the closest allowed split point.
   314        allowed_split_points = sorted(allowed_split_points)
   315        closest = bisect.bisect(allowed_split_points, stop_index)
   316        if closest == 0:
   317          stop_index = allowed_split_points[0]
   318        elif closest == len(allowed_split_points):
   319          stop_index = allowed_split_points[-1]
   320        else:
   321          prev = allowed_split_points[closest - 1]
   322          next = allowed_split_points[closest]
   323          if index < prev and stop_index - prev < next - stop_index:
   324            stop_index = prev
   325          else:
   326            stop_index = next
   327      if index < stop_index < stop:
   328        return stop_index - 1, [], [], stop_index
   329      else:
   330        return None
   331  
   332    def finish(self):
   333      # type: () -> None
   334      super().finish()
   335      with self.splitting_lock:
   336        self.index += 1
   337        self.started = False
   338  
   339    def reset(self):
   340      # type: () -> None
   341      with self.splitting_lock:
   342        self.index = -1
   343        self.stop = float('inf')
   344      super().reset()
   345  
   346  
   347  class _StateBackedIterable(object):
   348    def __init__(self,
   349                 state_handler,  # type: sdk_worker.CachingStateHandler
   350                 state_key,  # type: beam_fn_api_pb2.StateKey
   351                 coder_or_impl,  # type: Union[coders.Coder, coder_impl.CoderImpl]
   352                ):
   353      # type: (...) -> None
   354      self._state_handler = state_handler
   355      self._state_key = state_key
   356      if isinstance(coder_or_impl, coders.Coder):
   357        self._coder_impl = coder_or_impl.get_impl()
   358      else:
   359        self._coder_impl = coder_or_impl
   360  
   361    def __iter__(self):
   362      # type: () -> Iterator[Any]
   363      return iter(
   364          self._state_handler.blocking_get(self._state_key, self._coder_impl))
   365  
   366    def __reduce__(self):
   367      return list, (list(self), )
   368  
   369  
# Register _StateBackedIterable as an iterable-like type with
# FastPrimitivesCoderImpl (presumably so that coder encodes its instances the
# same way it encodes other iterables).
coder_impl.FastPrimitivesCoderImpl.register_iterable_like_type(
    _StateBackedIterable)
   372  
   373  
   374  class StateBackedSideInputMap(object):
   375    def __init__(self,
   376                 state_handler,  # type: sdk_worker.CachingStateHandler
   377                 transform_id,  # type: str
   378                 tag,  # type: Optional[str]
   379                 side_input_data,  # type: pvalue.SideInputData
   380                 coder  # type: WindowedValueCoder
   381                ):
   382      # type: (...) -> None
   383      self._state_handler = state_handler
   384      self._transform_id = transform_id
   385      self._tag = tag
   386      self._side_input_data = side_input_data
   387      self._element_coder = coder.wrapped_value_coder
   388      self._target_window_coder = coder.window_coder
   389      # TODO(robertwb): Limit the cache size.
   390      self._cache = {}  # type: Dict[BoundedWindow, Any]
   391  
   392    def __getitem__(self, window):
   393      target_window = self._side_input_data.window_mapping_fn(window)
   394      if target_window not in self._cache:
   395        state_handler = self._state_handler
   396        access_pattern = self._side_input_data.access_pattern
   397  
   398        if access_pattern == common_urns.side_inputs.ITERABLE.urn:
   399          state_key = beam_fn_api_pb2.StateKey(
   400              iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
   401                  transform_id=self._transform_id,
   402                  side_input_id=self._tag,
   403                  window=self._target_window_coder.encode(target_window)))
   404          raw_view = _StateBackedIterable(
   405              state_handler, state_key, self._element_coder)
   406  
   407        elif access_pattern == common_urns.side_inputs.MULTIMAP.urn:
   408          state_key = beam_fn_api_pb2.StateKey(
   409              multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
   410                  transform_id=self._transform_id,
   411                  side_input_id=self._tag,
   412                  window=self._target_window_coder.encode(target_window),
   413                  key=b''))
   414          cache = {}
   415          key_coder_impl = self._element_coder.key_coder().get_impl()
   416          value_coder = self._element_coder.value_coder()
   417  
   418          class MultiMap(object):
   419            def __getitem__(self, key):
   420              if key not in cache:
   421                keyed_state_key = beam_fn_api_pb2.StateKey()
   422                keyed_state_key.CopyFrom(state_key)
   423                keyed_state_key.multimap_side_input.key = (
   424                    key_coder_impl.encode_nested(key))
   425                cache[key] = _StateBackedIterable(
   426                    state_handler, keyed_state_key, value_coder)
   427              return cache[key]
   428  
   429            def __reduce__(self):
   430              # TODO(robertwb): Figure out how to support this.
   431              raise TypeError(common_urns.side_inputs.MULTIMAP.urn)
   432  
   433          raw_view = MultiMap()
   434  
   435        else:
   436          raise ValueError("Unknown access pattern: '%s'" % access_pattern)
   437  
   438        self._cache[target_window] = self._side_input_data.view_fn(raw_view)
   439      return self._cache[target_window]
   440  
   441    def is_globally_windowed(self):
   442      # type: () -> bool
   443      return (
   444          self._side_input_data.window_mapping_fn ==
   445          sideinputs._global_window_mapping_fn)
   446  
   447    def reset(self):
   448      # type: () -> None
   449      # TODO(BEAM-5428): Cross-bundle caching respecting cache tokens.
   450      self._cache = {}
   451  
   452  
   453  class ReadModifyWriteRuntimeState(userstate.ReadModifyWriteRuntimeState):
   454    def __init__(self, underlying_bag_state):
   455      self._underlying_bag_state = underlying_bag_state
   456  
   457    def read(self):  # type: () -> Any
   458      values = list(self._underlying_bag_state.read())
   459      if not values:
   460        return None
   461      return values[0]
   462  
   463    def write(self, value):  # type: (Any) -> None
   464      self.clear()
   465      self._underlying_bag_state.add(value)
   466  
   467    def clear(self):  # type: () -> None
   468      self._underlying_bag_state.clear()
   469  
   470    def commit(self):  # type: () -> None
   471      self._underlying_bag_state.commit()
   472  
   473  
   474  class CombiningValueRuntimeState(userstate.CombiningValueRuntimeState):
   475    def __init__(self, underlying_bag_state, combinefn):
   476      # type: (userstate.AccumulatingRuntimeState, core.CombineFn) -> None
   477      self._combinefn = combinefn
   478      self._combinefn.setup()
   479      self._underlying_bag_state = underlying_bag_state
   480      self._finalized = False
   481  
   482    def _read_accumulator(self, rewrite=True):
   483      merged_accumulator = self._combinefn.merge_accumulators(
   484          self._underlying_bag_state.read())
   485      if rewrite:
   486        self._underlying_bag_state.clear()
   487        self._underlying_bag_state.add(merged_accumulator)
   488      return merged_accumulator
   489  
   490    def read(self):
   491      # type: () -> Iterable[Any]
   492      return self._combinefn.extract_output(self._read_accumulator())
   493  
   494    def add(self, value):
   495      # type: (Any) -> None
   496      # Prefer blind writes, but don't let them grow unboundedly.
   497      # This should be tuned to be much lower, but for now exercise
   498      # both paths well.
   499      if random.random() < 0.5:
   500        accumulator = self._read_accumulator(False)
   501        self._underlying_bag_state.clear()
   502      else:
   503        accumulator = self._combinefn.create_accumulator()
   504      self._underlying_bag_state.add(
   505          self._combinefn.add_input(accumulator, value))
   506  
   507    def clear(self):
   508      # type: () -> None
   509      self._underlying_bag_state.clear()
   510  
   511    def commit(self):
   512      self._underlying_bag_state.commit()
   513  
   514    def finalize(self):
   515      if not self._finalized:
   516        self._combinefn.teardown()
   517        self._finalized = True
   518  
   519  
   520  class _ConcatIterable(object):
   521    """An iterable that is the concatination of two iterables.
   522  
   523    Unlike itertools.chain, this allows reiteration.
   524    """
   525    def __init__(self, first, second):
   526      # type: (Iterable[Any], Iterable[Any]) -> None
   527      self.first = first
   528      self.second = second
   529  
   530    def __iter__(self):
   531      # type: () -> Iterator[Any]
   532      for elem in self.first:
   533        yield elem
   534      for elem in self.second:
   535        yield elem
   536  
   537  
# Register _ConcatIterable as an iterable-like type with FastPrimitivesCoderImpl
# (presumably so that coder encodes its instances like other iterables).
coder_impl.FastPrimitivesCoderImpl.register_iterable_like_type(_ConcatIterable)
   539  
   540  
   541  class SynchronousBagRuntimeState(userstate.BagRuntimeState):
   542  
   543    def __init__(self,
   544                 state_handler,  # type: sdk_worker.CachingStateHandler
   545                 state_key,  # type: beam_fn_api_pb2.StateKey
   546                 value_coder  # type: coders.Coder
   547                ):
   548      # type: (...) -> None
   549      self._state_handler = state_handler
   550      self._state_key = state_key
   551      self._value_coder = value_coder
   552      self._cleared = False
   553      self._added_elements = []  # type: List[Any]
   554  
   555    def read(self):
   556      # type: () -> Iterable[Any]
   557      return _ConcatIterable([] if self._cleared else cast(
   558          'Iterable[Any]',
   559          _StateBackedIterable(
   560              self._state_handler, self._state_key, self._value_coder)),
   561                             self._added_elements)
   562  
   563    def add(self, value):
   564      # type: (Any) -> None
   565      self._added_elements.append(value)
   566  
   567    def clear(self):
   568      # type: () -> None
   569      self._cleared = True
   570      self._added_elements = []
   571  
   572    def commit(self):
   573      # type: () -> None
   574      to_await = None
   575      if self._cleared:
   576        to_await = self._state_handler.clear(self._state_key)
   577      if self._added_elements:
   578        to_await = self._state_handler.extend(
   579            self._state_key, self._value_coder.get_impl(), self._added_elements)
   580      if to_await:
   581        # To commit, we need to wait on the last state request future to complete.
   582        to_await.get()
   583  
   584  
   585  class SynchronousSetRuntimeState(userstate.SetRuntimeState):
   586  
   587    def __init__(self,
   588                 state_handler,  # type: sdk_worker.CachingStateHandler
   589                 state_key,  # type: beam_fn_api_pb2.StateKey
   590                 value_coder  # type: coders.Coder
   591                ):
   592      # type: (...) -> None
   593      self._state_handler = state_handler
   594      self._state_key = state_key
   595      self._value_coder = value_coder
   596      self._cleared = False
   597      self._added_elements = set()  # type: Set[Any]
   598  
   599    def _compact_data(self, rewrite=True):
   600      accumulator = set(
   601          _ConcatIterable(
   602              set() if self._cleared else _StateBackedIterable(
   603                  self._state_handler, self._state_key, self._value_coder),
   604              self._added_elements))
   605  
   606      if rewrite and accumulator:
   607        self._state_handler.clear(self._state_key)
   608        self._state_handler.extend(
   609            self._state_key, self._value_coder.get_impl(), accumulator)
   610  
   611        # Since everthing is already committed so we can safely reinitialize
   612        # added_elements here.
   613        self._added_elements = set()
   614  
   615      return accumulator
   616  
   617    def read(self):
   618      # type: () -> Set[Any]
   619      return self._compact_data(rewrite=False)
   620  
   621    def add(self, value):
   622      # type: (Any) -> None
   623      if self._cleared:
   624        # This is a good time explicitly clear.
   625        self._state_handler.clear(self._state_key)
   626        self._cleared = False
   627  
   628      self._added_elements.add(value)
   629      if random.random() > 0.5:
   630        self._compact_data()
   631  
   632    def clear(self):
   633      # type: () -> None
   634      self._cleared = True
   635      self._added_elements = set()
   636  
   637    def commit(self):
   638      # type: () -> None
   639      to_await = None
   640      if self._cleared:
   641        to_await = self._state_handler.clear(self._state_key)
   642      if self._added_elements:
   643        to_await = self._state_handler.extend(
   644            self._state_key, self._value_coder.get_impl(), self._added_elements)
   645      if to_await:
   646        # To commit, we need to wait on the last state request future to complete.
   647        to_await.get()
   648  
   649  
   650  class OutputTimer(userstate.BaseTimer):
   651    def __init__(self,
   652                 key,
   653                 window,  # type: BoundedWindow
   654                 timestamp,  # type: timestamp.Timestamp
   655                 paneinfo,  # type: windowed_value.PaneInfo
   656                 time_domain, # type: str
   657                 timer_family_id,  # type: str
   658                 timer_coder_impl,  # type: coder_impl.TimerCoderImpl
   659                 output_stream  # type: data_plane.ClosableOutputStream
   660                 ):
   661      self._key = key
   662      self._window = window
   663      self._input_timestamp = timestamp
   664      self._paneinfo = paneinfo
   665      self._time_domain = time_domain
   666      self._timer_family_id = timer_family_id
   667      self._output_stream = output_stream
   668      self._timer_coder_impl = timer_coder_impl
   669  
   670    def set(self, ts: timestamp.TimestampTypes, dynamic_timer_tag='') -> None:
   671      ts = timestamp.Timestamp.of(ts)
   672      timer = userstate.Timer(
   673          user_key=self._key,
   674          dynamic_timer_tag=dynamic_timer_tag,
   675          windows=(self._window, ),
   676          clear_bit=False,
   677          fire_timestamp=ts,
   678          hold_timestamp=ts if TimeDomain.is_event_time(self._time_domain) else
   679          self._input_timestamp,
   680          paneinfo=self._paneinfo)
   681      self._timer_coder_impl.encode_to_stream(timer, self._output_stream, True)
   682      self._output_stream.maybe_flush()
   683  
   684    def clear(self, dynamic_timer_tag='') -> None:
   685      timer = userstate.Timer(
   686          user_key=self._key,
   687          dynamic_timer_tag=dynamic_timer_tag,
   688          windows=(self._window, ),
   689          clear_bit=True,
   690          fire_timestamp=None,
   691          hold_timestamp=None,
   692          paneinfo=None)
   693      self._timer_coder_impl.encode_to_stream(timer, self._output_stream, True)
   694      self._output_stream.maybe_flush()
   695  
   696  
   697  class TimerInfo(object):
   698    """A data class to store information related to a timer."""
   699    def __init__(self, timer_coder_impl, output_stream=None):
   700      self.timer_coder_impl = timer_coder_impl
   701      self.output_stream = output_stream
   702  
   703  
   704  class FnApiUserStateContext(userstate.UserStateContext):
   705    """Interface for state and timers from SDK to Fn API servicer of state.."""
   706  
   707    def __init__(self,
   708                 state_handler,  # type: sdk_worker.CachingStateHandler
   709                 transform_id,  # type: str
   710                 key_coder,  # type: coders.Coder
   711                 window_coder,  # type: coders.Coder
   712                ):
   713      # type: (...) -> None
   714  
   715      """Initialize a ``FnApiUserStateContext``.
   716  
   717      Args:
   718        state_handler: A StateServicer object.
   719        transform_id: The name of the PTransform that this context is associated.
   720        key_coder: Coder for the key type.
   721        window_coder: Coder for the window type.
   722      """
   723      self._state_handler = state_handler
   724      self._transform_id = transform_id
   725      self._key_coder = key_coder
   726      self._window_coder = window_coder
   727      # A mapping of {timer_family_id: TimerInfo}
   728      self._timers_info = {}  # type: Dict[str, TimerInfo]
   729      self._all_states = {}  # type: Dict[tuple, FnApiUserRuntimeStateTypes]
   730  
   731    def add_timer_info(self, timer_family_id, timer_info):
   732      # type: (str, TimerInfo) -> None
   733      self._timers_info[timer_family_id] = timer_info
   734  
   735    def get_timer(
   736        self, timer_spec: userstate.TimerSpec, key, window, timestamp,
   737        pane) -> OutputTimer:
   738      assert self._timers_info[timer_spec.name].output_stream is not None
   739      timer_coder_impl = self._timers_info[timer_spec.name].timer_coder_impl
   740      output_stream = self._timers_info[timer_spec.name].output_stream
   741      return OutputTimer(
   742          key,
   743          window,
   744          timestamp,
   745          pane,
   746          timer_spec.time_domain,
   747          timer_spec.name,
   748          timer_coder_impl,
   749          output_stream)
   750  
   751    def get_state(self, *args):
   752      # type: (*Any) -> FnApiUserRuntimeStateTypes
   753      state_handle = self._all_states.get(args)
   754      if state_handle is None:
   755        state_handle = self._all_states[args] = self._create_state(*args)
   756      return state_handle
   757  
   758    def _create_state(self,
   759                      state_spec,  # type: userstate.StateSpec
   760                      key,
   761                      window  # type: BoundedWindow
   762                     ):
   763      # type: (...) -> FnApiUserRuntimeStateTypes
   764      if isinstance(state_spec,
   765                    (userstate.BagStateSpec,
   766                     userstate.CombiningValueStateSpec,
   767                     userstate.ReadModifyWriteStateSpec)):
   768        bag_state = SynchronousBagRuntimeState(
   769            self._state_handler,
   770            state_key=beam_fn_api_pb2.StateKey(
   771                bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
   772                    transform_id=self._transform_id,
   773                    user_state_id=state_spec.name,
   774                    window=self._window_coder.encode(window),
   775                    # State keys are expected in nested encoding format
   776                    key=self._key_coder.encode_nested(key))),
   777            value_coder=state_spec.coder)
   778        if isinstance(state_spec, userstate.BagStateSpec):
   779          return bag_state
   780        elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
   781          return ReadModifyWriteRuntimeState(bag_state)
   782        else:
   783          return CombiningValueRuntimeState(
   784              bag_state, copy.deepcopy(state_spec.combine_fn))
   785      elif isinstance(state_spec, userstate.SetStateSpec):
   786        return SynchronousSetRuntimeState(
   787            self._state_handler,
   788            state_key=beam_fn_api_pb2.StateKey(
   789                bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
   790                    transform_id=self._transform_id,
   791                    user_state_id=state_spec.name,
   792                    window=self._window_coder.encode(window),
   793                    # State keys are expected in nested encoding format
   794                    key=self._key_coder.encode_nested(key))),
   795            value_coder=state_spec.coder)
   796      else:
   797        raise NotImplementedError(state_spec)
   798  
   799    def commit(self):
   800      # type: () -> None
   801      for state in self._all_states.values():
   802        state.commit()
   803  
   804    def reset(self):
   805      # type: () -> None
   806      for state in self._all_states.values():
   807        state.finalize()
   808      self._all_states = {}
   809  
   810  
   811  def memoize(func):
   812    cache = {}
   813    missing = object()
   814  
   815    def wrapper(*args):
   816      result = cache.get(args, missing)
   817      if result is missing:
   818        result = cache[args] = func(*args)
   819      return result
   820  
   821    return wrapper
   822  
   823  
   824  def only_element(iterable):
   825    # type: (Iterable[T]) -> T
   826    element, = iterable
   827    return element
   828  
   829  
   830  def _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor):
   831    # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None
   832  
   833    runtime_sdk = environments.sdk_base_version_capability()
   834    for t in process_bundle_descriptor.transforms.values():
   835      env = process_bundle_descriptor.environments[t.environment_id]
   836      for c in env.capabilities:
   837        if (c.startswith(environments.SDK_VERSION_CAPABILITY_PREFIX) and
   838            c != runtime_sdk):
   839          raise RuntimeError(
   840              "Pipeline construction environment and pipeline runtime "
   841              "environment are not compatible. If you use a custom "
   842              "container image, check that the Python interpreter minor version "
   843              "and the Apache Beam version in your image match the versions "
   844              "used at pipeline construction time. "
   845              f"Submission environment: {c}. "
   846              f"Runtime environment: {runtime_sdk}.")
   847  
   848    # TODO: Consider warning on mismatches in versions of installed packages.
   849  
   850  
class BundleProcessor(object):
  """A class for processing bundles of elements.

  The execution tree of operations is built once from a
  ``ProcessBundleDescriptor`` at construction time and reused for every bundle
  passed to ``process_bundle``. While a bundle is in flight, ``try_split`` may
  split it; ``splitting_lock`` serializes splits with the end of a bundle.
  """

  def __init__(self,
               process_bundle_descriptor,  # type: beam_fn_api_pb2.ProcessBundleDescriptor
               state_handler,  # type: sdk_worker.CachingStateHandler
               data_channel_factory,  # type: data_plane.DataChannelFactory
               data_sampler=None,  # type: Optional[data_sampler.DataSampler]
              ):
    # type: (...) -> None

    """Initialize a bundle processor.

    Args:
      process_bundle_descriptor (``beam_fn_api_pb2.ProcessBundleDescriptor``):
        a description of the stage that this ``BundleProcessor`` is to execute.
      state_handler (CachingStateHandler): passed to ``BeamTransformFactory``
        when the operations are built.
      data_channel_factory (``data_plane.DataChannelFactory``): used to create
        the timer data channel and, through ``BeamTransformFactory``, the
        channels of the data input/output operations.
      data_sampler (``data_sampler.DataSampler``, optional): when set, a
        synthetic data-sampling transform is added for every PCollection.

    Raises:
      RuntimeError: if the descriptor was created with an SDK version that
        does not match the running one.
    """
    self.process_bundle_descriptor = process_bundle_descriptor
    self.state_handler = state_handler
    self.data_channel_factory = data_channel_factory
    self.data_sampler = data_sampler
    # Id of the bundle being processed, or None when idle. It is cleared under
    # ``splitting_lock`` so that ``try_split`` can ignore requests aimed at a
    # bundle other than the current one.
    self.current_instruction_id = None  # type: Optional[str]

    _verify_descriptor_created_in_a_compatible_env(process_bundle_descriptor)
    # There is no guarantee that the runner only sets
    # timer_api_service_descriptor when there are timers. So this field cannot
    # be used as an indicator of timers.
    if self.process_bundle_descriptor.timer_api_service_descriptor.url:
      self.timer_data_channel = (
          data_channel_factory.create_data_channel_from_url(
              self.process_bundle_descriptor.timer_api_service_descriptor.url))
    else:
      self.timer_data_channel = None

    # A mapping of
    # {(transform_id, timer_family_id): TimerInfo}
    # The mapping is empty when there are no timer_family_specs in the
    # ProcessBundleDescriptor. It is filled in by create_execution_tree.
    self.timers_info = {}  # type: Dict[Tuple[str, str], TimerInfo]

    # TODO(robertwb): Figure out the correct prefix to use for output counters
    # from StateSampler.
    self.counter_factory = counters.CounterFactory()
    self.state_sampler = statesampler.StateSampler(
        'fnapi-step-%s' % self.process_bundle_descriptor.id,
        self.counter_factory)

    # Must run before the execution tree is built so that the synthetic
    # sampling transforms become operations too.
    if self.data_sampler:
      self.add_data_sampling_operations(process_bundle_descriptor)

    self.ops = self.create_execution_tree(self.process_bundle_descriptor)
    # ``self.ops`` is ordered producers-first (see create_execution_tree), so
    # iterating in reverse sets up downstream operations before their inputs.
    for op in reversed(self.ops.values()):
      op.setup()
    self.splitting_lock = threading.Lock()

  def add_data_sampling_operations(self, pbd):
    # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None

    """Adds a DataSamplingOperation to every PCollection.

    ``pbd`` is modified in place: for each PCollection, a transform with URN
    ``SYNTHETIC_DATA_SAMPLING_URN`` is inserted whose payload is the
    FastPrimitivesCoder encoding of (pcoll_id, coder_id) and whose only input
    (tag 'None') is that PCollection.

    Implementation note: the alternative to this is to modify each
    Operation and forward a DataSampler to manually sample when an element is
    processed. This gets messy very quickly and is not future-proof as new
    operation types will need to be updated. This is the cleanest way of adding
    new operations to the final execution tree.
    """
    coder = coders.FastPrimitivesCoder()

    for pcoll_id in pbd.pcollections:
      transform_id = 'synthetic-data-sampling-transform-{}'.format(pcoll_id)
      # Indexing the protobuf map with a new id creates the (empty) transform
      # entry, which is filled in below.
      transform_proto: beam_runner_api_pb2.PTransform = pbd.transforms[
          transform_id]
      transform_proto.unique_name = transform_id
      transform_proto.spec.urn = SYNTHETIC_DATA_SAMPLING_URN

      coder_id = pbd.pcollections[pcoll_id].coder_id
      transform_proto.spec.payload = coder.encode((pcoll_id, coder_id))

      transform_proto.inputs['None'] = pcoll_id

  def create_execution_tree(
      self,
      descriptor  # type: beam_fn_api_pb2.ProcessBundleDescriptor
  ):
    # type: (...) -> collections.OrderedDict[str, operations.DoOperation]

    """Builds one operation per transform in ``descriptor``.

    Also replaces ``self.timers_info`` with the timer families declared in the
    descriptor.

    Returns:
      An OrderedDict of transform_id -> operation in which every transform
      comes before the transforms that consume its outputs as a main (non-side)
      input.
    """
    transform_factory = BeamTransformFactory(
        descriptor,
        self.data_channel_factory,
        self.counter_factory,
        self.state_sampler,
        self.state_handler,
        self.data_sampler)

    self.timers_info = transform_factory.extract_timers_info()

    # Only ParDo transforms declare side inputs; for any other URN this falls
    # through and returns None, which is treated as "not a side input".
    def is_side_input(transform_proto, tag):
      if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
        return tag in proto_utils.parse_Bytes(
            transform_proto.spec.payload,
            beam_runner_api_pb2.ParDoPayload).side_inputs

    # pcoll_id -> ids of the transforms consuming it as a main input.
    pcoll_consumers = collections.defaultdict(
        list)  # type: DefaultDict[str, List[str]]
    for transform_id, transform_proto in descriptor.transforms.items():
      for tag, pcoll_id in transform_proto.inputs.items():
        if not is_side_input(transform_proto, tag):
          pcoll_consumers[pcoll_id].append(transform_id)

    # Builds each operation once (memoized) after recursively building the
    # operations consuming its outputs, grouped by output tag.
    @memoize
    def get_operation(transform_id):
      # type: (str) -> operations.Operation
      transform_consumers = {
          tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]]
          for tag,
          pcoll_id in descriptor.transforms[transform_id].outputs.items()
      }
      return transform_factory.create_operation(
          transform_id, transform_consumers)

    # Operations must be started (hence returned) in order.
    # The height is 1 for a transform without consumers and otherwise 1 more
    # than its highest consumer, so sorting by decreasing height puts every
    # producer before its consumers.
    @memoize
    def topological_height(transform_id):
      # type: (str) -> int
      return 1 + max([0] + [
          topological_height(consumer)
          for pcoll in descriptor.transforms[transform_id].outputs.values()
          for consumer in pcoll_consumers[pcoll]
      ])

    return collections.OrderedDict([(
        transform_id,
        cast(operations.DoOperation,
             get_operation(transform_id))) for transform_id in sorted(
                 descriptor.transforms, key=topological_height, reverse=True)])

  def reset(self):
    # type: () -> None

    """Resets counters, the state sampler and every operation for reuse."""
    self.counter_factory.reset()
    self.state_sampler.reset()
    # Side input caches.
    for op in self.ops.values():
      op.reset()

  def process_bundle(self, instruction_id):
    # type: (str) -> Tuple[List[beam_fn_api_pb2.DelayedBundleApplication], bool]

    """Processes the bundle identified by ``instruction_id``.

    Operations are started downstream-first, all data and timer inputs for the
    bundle are consumed, then operations are finished and timer output streams
    closed.

    Returns:
      A tuple of (the DelayedBundleApplications for residuals recorded in this
      bundle's ExecutionContext, whether any operation requires bundle
      finalization).
    """

    expected_input_ops = []  # type: List[DataInputOperation]

    for op in self.ops.values():
      if isinstance(op, DataOutputOperation):
        # TODO(robertwb): Is there a better way to pass the instruction id to
        # the operation?
        op.set_output_stream(
            op.data_channel.output_stream(instruction_id, op.transform_id))
      elif isinstance(op, DataInputOperation):
        # We must wait until we receive "end of stream" for each of these ops.
        expected_input_ops.append(op)

    try:
      execution_context = ExecutionContext()
      self.current_instruction_id = instruction_id
      self.state_sampler.start()
      # Start all operations.
      for op in reversed(self.ops.values()):
        _LOGGER.debug('start %s', op)
        op.execution_context = execution_context
        op.start()

      # Each data_channel is mapped to a list of expected inputs which includes
      # both data input and timer input. The data input is identified by
      # transform_id. The timer input is identified by
      # (transform_id, timer_family_id).
      data_channels = collections.defaultdict(
          list
      )  # type: DefaultDict[data_plane.DataChannel, List[Union[str, Tuple[str, str]]]]

      # Add expected data inputs for each data channel.
      input_op_by_transform_id = {}
      for input_op in expected_input_ops:
        data_channels[input_op.data_channel].append(input_op.transform_id)
        input_op_by_transform_id[input_op.transform_id] = input_op

      # Update timer_data channel with expected timer inputs.
      if self.timer_data_channel:
        data_channels[self.timer_data_channel].extend(
            list(self.timers_info.keys()))

        # Set up timer output stream for DoOperation.
        for ((transform_id, timer_family_id),
             timer_info) in self.timers_info.items():
          output_stream = self.timer_data_channel.output_timer_stream(
              instruction_id, transform_id, timer_family_id)
          timer_info.output_stream = output_stream
          self.ops[transform_id].add_timer_info(timer_family_id, timer_info)

      # Process data and timer inputs
      for data_channel, expected_inputs in data_channels.items():
        for element in data_channel.input_elements(instruction_id,
                                                   expected_inputs):
          if isinstance(element, beam_fn_api_pb2.Elements.Timers):
            timer_coder_impl = (
                self.timers_info[(
                    element.transform_id,
                    element.timer_family_id)].timer_coder_impl)
            for timer_data in timer_coder_impl.decode_all(element.timers):
              self.ops[element.transform_id].process_timer(
                  element.timer_family_id, timer_data)
          elif isinstance(element, beam_fn_api_pb2.Elements.Data):
            input_op_by_transform_id[element.transform_id].process_encoded(
                element.data)

      # Finish all operations.
      for op in self.ops.values():
        _LOGGER.debug('finish %s', op)
        op.finish()

      # Close every timer output stream
      # NOTE(review): this asserts a stream was assigned above, i.e. that a
      # timer data channel exists whenever timers_info is non-empty.
      for timer_info in self.timers_info.values():
        assert timer_info.output_stream is not None
        timer_info.output_stream.close()

      return ([
          self.delayed_bundle_application(op, residual) for op,
          residual in execution_context.delayed_applications
      ],
              self.requires_finalization())

    finally:
      # Ensure any in-flight split attempts complete.
      with self.splitting_lock:
        self.current_instruction_id = None
      self.state_sampler.stop_if_still_running()

  def finalize_bundle(self):
    # type: () -> beam_fn_api_pb2.FinalizeBundleResponse

    """Calls ``finalize_bundle`` on every operation."""
    for op in self.ops.values():
      op.finalize_bundle()
    return beam_fn_api_pb2.FinalizeBundleResponse()

  def requires_finalization(self):
    # type: () -> bool

    """Returns True if any operation reports that it needs finalization."""
    return any(op.needs_finalization() for op in self.ops.values())

  def try_split(self, bundle_split_request):
    # type: (beam_fn_api_pb2.ProcessBundleSplitRequest) -> beam_fn_api_pb2.ProcessBundleSplitResponse

    """Attempts the requested split of the bundle currently being processed.

    Only data input operations with an entry in ``desired_splits`` are asked
    to split. An empty response is returned if the request targets a bundle
    other than the current one, or if no operation split.
    """
    split_response = beam_fn_api_pb2.ProcessBundleSplitResponse()
    with self.splitting_lock:
      if bundle_split_request.instruction_id != self.current_instruction_id:
        # This may be a delayed split for a former bundle, see BEAM-12475.
        return split_response

      for op in self.ops.values():
        if isinstance(op, DataInputOperation):
          desired_split = bundle_split_request.desired_splits.get(
              op.transform_id)
          if desired_split:
            split = op.try_split(
                desired_split.fraction_of_remainder,
                desired_split.estimated_input_elements,
                desired_split.allowed_split_points)
            if split:
              (
                  primary_end,
                  element_primaries,
                  element_residuals,
                  residual_start,
              ) = split
              for element_primary in element_primaries:
                split_response.primary_roots.add().CopyFrom(
                    self.bundle_application(*element_primary))
              for element_residual in element_residuals:
                split_response.residual_roots.add().CopyFrom(
                    self.delayed_bundle_application(*element_residual))
              split_response.channel_splits.extend([
                  beam_fn_api_pb2.ProcessBundleSplitResponse.ChannelSplit(
                      transform_id=op.transform_id,
                      last_primary_element=primary_end,
                      first_residual_element=residual_start)
              ])

    return split_response

  def delayed_bundle_application(self,
                                 op,  # type: operations.DoOperation
                                 deferred_remainder  # type: SplitResultResidual
                                ):
    # type: (...) -> beam_fn_api_pb2.DelayedBundleApplication

    """Wraps a residual of ``op`` as a ``DelayedBundleApplication``.

    ``deferred_remainder`` is unpacked as (element_and_restriction,
    current_watermark, deferred_timestamp); a truthy ``deferred_timestamp``
    must be a ``timestamp.Duration`` and becomes the requested time delay.
    """
    assert op.input_info is not None
    # TODO(SDF): For non-root nodes, need main_input_coder + residual_coder.
    (element_and_restriction, current_watermark, deferred_timestamp) = (
        deferred_remainder)
    if deferred_timestamp:
      assert isinstance(deferred_timestamp, timestamp.Duration)
      proto_deferred_watermark = proto_utils.from_micros(
          duration_pb2.Duration,
          deferred_timestamp.micros)  # type: Optional[duration_pb2.Duration]
    else:
      proto_deferred_watermark = None
    return beam_fn_api_pb2.DelayedBundleApplication(
        requested_time_delay=proto_deferred_watermark,
        application=self.construct_bundle_application(
            op.input_info, current_watermark, element_and_restriction))

  def bundle_application(self,
                         op,  # type: operations.DoOperation
                         primary  # type: SplitResultPrimary
                        ):
    # type: (...) -> beam_fn_api_pb2.BundleApplication

    """Wraps the primary of a split of ``op`` as a ``BundleApplication``."""
    assert op.input_info is not None
    return self.construct_bundle_application(
        op.input_info, None, primary.primary_value)

  def construct_bundle_application(self,
                                   op_input_info,  # type: operations.OpInputInfo
                                   output_watermark,  # type: Optional[timestamp.Timestamp]
                                   element
                                  ):
    # type: (...) -> beam_fn_api_pb2.BundleApplication

    """Builds a ``BundleApplication`` of ``element`` for ``op_input_info``.

    The element is encoded (nested) with the main input coder. When
    ``output_watermark`` is truthy it is recorded as the watermark of every
    output listed in ``op_input_info``.
    """
    transform_id, main_input_tag, main_input_coder, outputs = op_input_info
    if output_watermark:
      proto_output_watermark = proto_utils.from_micros(
          timestamp_pb2.Timestamp, output_watermark.micros)
      output_watermarks = {
          output: proto_output_watermark
          for output in outputs
      }  # type: Optional[Dict[str, timestamp_pb2.Timestamp]]
    else:
      output_watermarks = None
    return beam_fn_api_pb2.BundleApplication(
        transform_id=transform_id,
        input_id=main_input_tag,
        output_watermarks=output_watermarks,
        element=main_input_coder.get_impl().encode_nested(element))

  def monitoring_infos(self):
    # type: () -> List[metrics_pb2.MonitoringInfo]

    """Returns the list of MonitoringInfos collected processing this bundle."""
    # Construct a new dict first to remove duplicates.
    # Entries returned by later operations overwrite earlier ones with the
    # same key.
    all_monitoring_infos_dict = {}
    for transform_id, op in self.ops.items():
      tag_to_pcollection_id = self.process_bundle_descriptor.transforms[
          transform_id].outputs
      all_monitoring_infos_dict.update(
          op.monitoring_infos(transform_id, dict(tag_to_pcollection_id)))

    return list(all_monitoring_infos_dict.values())

  def shutdown(self):
    # type: () -> None

    """Calls ``teardown`` on every operation."""
    for op in self.ops.values():
      op.teardown()
  1205  
  1206  
  1207  class ExecutionContext(object):
  1208    def __init__(self):
  1209      self.delayed_applications = [
  1210      ]  # type: List[Tuple[operations.DoOperation, common.SplitResultResidual]]
  1211  
  1212  
  1213  class BeamTransformFactory(object):
  1214    """Factory for turning transform_protos into executable operations."""
  1215    def __init__(self,
  1216                 descriptor,  # type: beam_fn_api_pb2.ProcessBundleDescriptor
  1217                 data_channel_factory,  # type: data_plane.DataChannelFactory
  1218                 counter_factory,  # type: counters.CounterFactory
  1219                 state_sampler,  # type: statesampler.StateSampler
  1220                 state_handler,  # type: sdk_worker.CachingStateHandler
  1221                 data_sampler,  # type: Optional[data_sampler.DataSampler]
  1222                ):
  1223      self.descriptor = descriptor
  1224      self.data_channel_factory = data_channel_factory
  1225      self.counter_factory = counter_factory
  1226      self.state_sampler = state_sampler
  1227      self.state_handler = state_handler
  1228      self.context = pipeline_context.PipelineContext(
  1229          descriptor,
  1230          iterable_state_read=lambda token,
  1231          element_coder_impl: _StateBackedIterable(
  1232              state_handler,
  1233              beam_fn_api_pb2.StateKey(
  1234                  runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
  1235              element_coder_impl))
  1236      self.data_sampler = data_sampler
  1237  
  1238    _known_urns = {
  1239    }  # type: Dict[str, Tuple[ConstructorFn, Union[Type[message.Message], Type[bytes], None]]]
  1240  
  1241    @classmethod
  1242    def register_urn(
  1243        cls,
  1244        urn,  # type: str
  1245        parameter_type  # type: Optional[Type[T]]
  1246    ):
  1247      # type: (...) -> Callable[[Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]], Callable[[BeamTransformFactory, str, beam_runner_api_pb2.PTransform, T, Dict[str, List[operations.Operation]]], operations.Operation]]
  1248      def wrapper(func):
  1249        cls._known_urns[urn] = func, parameter_type
  1250        return func
  1251  
  1252      return wrapper
  1253  
  1254    def create_operation(self,
  1255                         transform_id,  # type: str
  1256                         consumers  # type: Dict[str, List[operations.Operation]]
  1257                        ):
  1258      # type: (...) -> operations.Operation
  1259      transform_proto = self.descriptor.transforms[transform_id]
  1260      if not transform_proto.unique_name:
  1261        _LOGGER.debug("No unique name set for transform %s" % transform_id)
  1262        transform_proto.unique_name = transform_id
  1263      creator, parameter_type = self._known_urns[transform_proto.spec.urn]
  1264      payload = proto_utils.parse_Bytes(
  1265          transform_proto.spec.payload, parameter_type)
  1266      return creator(self, transform_id, transform_proto, payload, consumers)
  1267  
  1268    def extract_timers_info(self):
  1269      # type: () -> Dict[Tuple[str, str], TimerInfo]
  1270      timers_info = {}
  1271      for transform_id, transform_proto in self.descriptor.transforms.items():
  1272        if transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
  1273          pardo_payload = proto_utils.parse_Bytes(
  1274              transform_proto.spec.payload, beam_runner_api_pb2.ParDoPayload)
  1275          for (timer_family_id,
  1276               timer_family_spec) in pardo_payload.timer_family_specs.items():
  1277            timer_coder_impl = self.get_coder(
  1278                timer_family_spec.timer_family_coder_id).get_impl()
  1279            # The output_stream should be updated when processing a bundle.
  1280            timers_info[(transform_id, timer_family_id)] = TimerInfo(
  1281                timer_coder_impl=timer_coder_impl)
  1282      return timers_info
  1283  
  1284    def get_coder(self, coder_id):
  1285      # type: (str) -> coders.Coder
  1286      if coder_id not in self.descriptor.coders:
  1287        raise KeyError("No such coder: %s" % coder_id)
  1288      coder_proto = self.descriptor.coders[coder_id]
  1289      if coder_proto.spec.urn:
  1290        return self.context.coders.get_by_id(coder_id)
  1291      else:
  1292        # No URN, assume cloud object encoding json bytes.
  1293        return operation_specs.get_coder_from_spec(
  1294            json.loads(coder_proto.spec.payload.decode('utf-8')))
  1295  
  1296    def get_windowed_coder(self, pcoll_id):
  1297      # type: (str) -> WindowedValueCoder
  1298      coder = self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id)
  1299      # TODO(robertwb): Remove this condition once all runners are consistent.
  1300      if not isinstance(coder, WindowedValueCoder):
  1301        windowing_strategy = self.descriptor.windowing_strategies[
  1302            self.descriptor.pcollections[pcoll_id].windowing_strategy_id]
  1303        return WindowedValueCoder(
  1304            coder, self.get_coder(windowing_strategy.window_coder_id))
  1305      else:
  1306        return coder
  1307  
  1308    def get_output_coders(self, transform_proto):
  1309      # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.Coder]
  1310      return {
  1311          tag: self.get_windowed_coder(pcoll_id)
  1312          for tag,
  1313          pcoll_id in transform_proto.outputs.items()
  1314      }
  1315  
  1316    def get_only_output_coder(self, transform_proto):
  1317      # type: (beam_runner_api_pb2.PTransform) -> coders.Coder
  1318      return only_element(self.get_output_coders(transform_proto).values())
  1319  
  1320    def get_input_coders(self, transform_proto):
  1321      # type: (beam_runner_api_pb2.PTransform) -> Dict[str, coders.WindowedValueCoder]
  1322      return {
  1323          tag: self.get_windowed_coder(pcoll_id)
  1324          for tag,
  1325          pcoll_id in transform_proto.inputs.items()
  1326      }
  1327  
  1328    def get_only_input_coder(self, transform_proto):
  1329      # type: (beam_runner_api_pb2.PTransform) -> coders.Coder
  1330      return only_element(list(self.get_input_coders(transform_proto).values()))
  1331  
  1332    def get_input_windowing(self, transform_proto):
  1333      # type: (beam_runner_api_pb2.PTransform) -> Windowing
  1334      pcoll_id = only_element(transform_proto.inputs.values())
  1335      windowing_strategy_id = self.descriptor.pcollections[
  1336          pcoll_id].windowing_strategy_id
  1337      return self.context.windowing_strategies.get_by_id(windowing_strategy_id)
  1338  
  1339    # TODO(robertwb): Update all operations to take these in the constructor.
  1340    @staticmethod
  1341    def augment_oldstyle_op(
  1342        op,  # type: OperationT
  1343        step_name,  # type: str
  1344        consumers,  # type: Mapping[str, Iterable[operations.Operation]]
  1345        tag_list=None  # type: Optional[List[str]]
  1346    ):
  1347      # type: (...) -> OperationT
  1348      op.step_name = step_name
  1349      for tag, op_consumers in consumers.items():
  1350        for consumer in op_consumers:
  1351          op.add_receiver(consumer, tag_list.index(tag) if tag_list else 0)
  1352      return op
  1353  
  1354  
  1355  @BeamTransformFactory.register_urn(
  1356      DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
  1357  def create_source_runner(
  1358      factory,  # type: BeamTransformFactory
  1359      transform_id,  # type: str
  1360      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1361      grpc_port,  # type: beam_fn_api_pb2.RemoteGrpcPort
  1362      consumers  # type: Dict[str, List[operations.Operation]]
  1363  ):
  1364    # type: (...) -> DataInputOperation
  1365  
  1366    output_coder = factory.get_coder(grpc_port.coder_id)
  1367    return DataInputOperation(
  1368        common.NameContext(transform_proto.unique_name, transform_id),
  1369        transform_proto.unique_name,
  1370        consumers,
  1371        factory.counter_factory,
  1372        factory.state_sampler,
  1373        output_coder,
  1374        transform_id=transform_id,
  1375        data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
  1376  
  1377  
  1378  @BeamTransformFactory.register_urn(
  1379      DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
  1380  def create_sink_runner(
  1381      factory,  # type: BeamTransformFactory
  1382      transform_id,  # type: str
  1383      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1384      grpc_port,  # type: beam_fn_api_pb2.RemoteGrpcPort
  1385      consumers  # type: Dict[str, List[operations.Operation]]
  1386  ):
  1387    # type: (...) -> DataOutputOperation
  1388    output_coder = factory.get_coder(grpc_port.coder_id)
  1389    return DataOutputOperation(
  1390        common.NameContext(transform_proto.unique_name, transform_id),
  1391        transform_proto.unique_name,
  1392        consumers,
  1393        factory.counter_factory,
  1394        factory.state_sampler,
  1395        output_coder,
  1396        transform_id=transform_id,
  1397        data_channel=factory.data_channel_factory.create_data_channel(grpc_port))
  1398  
  1399  
  1400  @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_READ_URN, None)
  1401  def create_source_java(
  1402      factory,  # type: BeamTransformFactory
  1403      transform_id,  # type: str
  1404      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1405      parameter,
  1406      consumers  # type: Dict[str, List[operations.Operation]]
  1407  ):
  1408    # type: (...) -> operations.ReadOperation
  1409    # The Dataflow runner harness strips the base64 encoding.
  1410    source = pickler.loads(base64.b64encode(parameter))
  1411    spec = operation_specs.WorkerRead(
  1412        iobase.SourceBundle(1.0, source, None, None),
  1413        [factory.get_only_output_coder(transform_proto)])
  1414    return factory.augment_oldstyle_op(
  1415        operations.ReadOperation(
  1416            common.NameContext(transform_proto.unique_name, transform_id),
  1417            spec,
  1418            factory.counter_factory,
  1419            factory.state_sampler),
  1420        transform_proto.unique_name,
  1421        consumers)
  1422  
  1423  
  1424  @BeamTransformFactory.register_urn(
  1425      common_urns.deprecated_primitives.READ.urn, beam_runner_api_pb2.ReadPayload)
  1426  def create_deprecated_read(
  1427      factory,  # type: BeamTransformFactory
  1428      transform_id,  # type: str
  1429      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1430      parameter,  # type: beam_runner_api_pb2.ReadPayload
  1431      consumers  # type: Dict[str, List[operations.Operation]]
  1432  ):
  1433    # type: (...) -> operations.ReadOperation
  1434    source = iobase.BoundedSource.from_runner_api(
  1435        parameter.source, factory.context)
  1436    spec = operation_specs.WorkerRead(
  1437        iobase.SourceBundle(1.0, source, None, None),
  1438        [WindowedValueCoder(source.default_output_coder())])
  1439    return factory.augment_oldstyle_op(
  1440        operations.ReadOperation(
  1441            common.NameContext(transform_proto.unique_name, transform_id),
  1442            spec,
  1443            factory.counter_factory,
  1444            factory.state_sampler),
  1445        transform_proto.unique_name,
  1446        consumers)
  1447  
  1448  
  1449  @BeamTransformFactory.register_urn(
  1450      python_urns.IMPULSE_READ_TRANSFORM, beam_runner_api_pb2.ReadPayload)
  1451  def create_read_from_impulse_python(
  1452      factory,  # type: BeamTransformFactory
  1453      transform_id,  # type: str
  1454      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1455      parameter,  # type: beam_runner_api_pb2.ReadPayload
  1456      consumers  # type: Dict[str, List[operations.Operation]]
  1457  ):
  1458    # type: (...) -> operations.ImpulseReadOperation
  1459    return operations.ImpulseReadOperation(
  1460        common.NameContext(transform_proto.unique_name, transform_id),
  1461        factory.counter_factory,
  1462        factory.state_sampler,
  1463        consumers,
  1464        iobase.BoundedSource.from_runner_api(parameter.source, factory.context),
  1465        factory.get_only_output_coder(transform_proto))
  1466  
  1467  
  1468  @BeamTransformFactory.register_urn(OLD_DATAFLOW_RUNNER_HARNESS_PARDO_URN, None)
  1469  def create_dofn_javasdk(
  1470      factory,  # type: BeamTransformFactory
  1471      transform_id,  # type: str
  1472      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1473      serialized_fn,
  1474      consumers  # type: Dict[str, List[operations.Operation]]
  1475  ):
  1476    return _create_pardo_operation(
  1477        factory, transform_id, transform_proto, consumers, serialized_fn)
  1478  
  1479  
  1480  @BeamTransformFactory.register_urn(
  1481      common_urns.sdf_components.PAIR_WITH_RESTRICTION.urn,
  1482      beam_runner_api_pb2.ParDoPayload)
  1483  def create_pair_with_restriction(*args):
  1484    class PairWithRestriction(beam.DoFn):
  1485      def __init__(self, fn, restriction_provider, watermark_estimator_provider):
  1486        self.restriction_provider = restriction_provider
  1487        self.watermark_estimator_provider = watermark_estimator_provider
  1488  
  1489      def process(self, element, *args, **kwargs):
  1490        # TODO(SDF): Do we want to allow mutation of the element?
  1491        # (E.g. it could be nice to shift bulky description to the portion
  1492        # that can be distributed.)
  1493        initial_restriction = self.restriction_provider.initial_restriction(
  1494            element)
  1495        initial_estimator_state = (
  1496            self.watermark_estimator_provider.initial_estimator_state(
  1497                element, initial_restriction))
  1498        yield (element, (initial_restriction, initial_estimator_state))
  1499  
  1500    return _create_sdf_operation(PairWithRestriction, *args)
  1501  
  1502  
  1503  @BeamTransformFactory.register_urn(
  1504      common_urns.sdf_components.SPLIT_AND_SIZE_RESTRICTIONS.urn,
  1505      beam_runner_api_pb2.ParDoPayload)
  1506  def create_split_and_size_restrictions(*args):
  1507    class SplitAndSizeRestrictions(beam.DoFn):
  1508      def __init__(self, fn, restriction_provider, watermark_estimator_provider):
  1509        self.restriction_provider = restriction_provider
  1510        self.watermark_estimator_provider = watermark_estimator_provider
  1511  
  1512      def process(self, element_restriction, *args, **kwargs):
  1513        element, (restriction, _) = element_restriction
  1514        for part, size in self.restriction_provider.split_and_size(
  1515            element, restriction):
  1516          if size < 0:
  1517            raise ValueError('Expected size >= 0 but received %s.' % size)
  1518          estimator_state = (
  1519              self.watermark_estimator_provider.initial_estimator_state(
  1520                  element, part))
  1521          yield ((element, (part, estimator_state)), size)
  1522  
  1523    return _create_sdf_operation(SplitAndSizeRestrictions, *args)
  1524  
  1525  
  1526  @BeamTransformFactory.register_urn(
  1527      common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn,
  1528      beam_runner_api_pb2.ParDoPayload)
  1529  def create_truncate_sized_restriction(*args):
  1530    class TruncateAndSizeRestriction(beam.DoFn):
  1531      def __init__(self, fn, restriction_provider, watermark_estimator_provider):
  1532        self.restriction_provider = restriction_provider
  1533  
  1534      def process(self, element_restriction, *args, **kwargs):
  1535        ((element, (restriction, estimator_state)), _) = element_restriction
  1536        truncated_restriction = self.restriction_provider.truncate(
  1537            element, restriction)
  1538        if truncated_restriction:
  1539          truncated_restriction_size = (
  1540              self.restriction_provider.restriction_size(
  1541                  element, truncated_restriction))
  1542          if truncated_restriction_size < 0:
  1543            raise ValueError(
  1544                'Expected size >= 0 but received %s.' %
  1545                truncated_restriction_size)
  1546          yield ((element, (truncated_restriction, estimator_state)),
  1547                 truncated_restriction_size)
  1548  
  1549    return _create_sdf_operation(
  1550        TruncateAndSizeRestriction,
  1551        *args,
  1552        operation_cls=operations.SdfTruncateSizedRestrictions)
  1553  
  1554  
  1555  @BeamTransformFactory.register_urn(
  1556      common_urns.sdf_components.PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS.urn,
  1557      beam_runner_api_pb2.ParDoPayload)
  1558  def create_process_sized_elements_and_restrictions(
  1559      factory,  # type: BeamTransformFactory
  1560      transform_id,  # type: str
  1561      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1562      parameter,  # type: beam_runner_api_pb2.ParDoPayload
  1563      consumers  # type: Dict[str, List[operations.Operation]]
  1564  ):
  1565    return _create_pardo_operation(
  1566        factory,
  1567        transform_id,
  1568        transform_proto,
  1569        consumers,
  1570        core.DoFnInfo.from_runner_api(parameter.do_fn,
  1571                                      factory.context).serialized_dofn_data(),
  1572        parameter,
  1573        operation_cls=operations.SdfProcessSizedElements)
  1574  
  1575  
  1576  def _create_sdf_operation(
  1577      proxy_dofn,
  1578      factory,
  1579      transform_id,
  1580      transform_proto,
  1581      parameter,
  1582      consumers,
  1583      operation_cls=operations.DoOperation):
  1584  
  1585    dofn_data = pickler.loads(parameter.do_fn.payload)
  1586    dofn = dofn_data[0]
  1587    restriction_provider = common.DoFnSignature(dofn).get_restriction_provider()
  1588    watermark_estimator_provider = (
  1589        common.DoFnSignature(dofn).get_watermark_estimator_provider())
  1590    serialized_fn = pickler.dumps(
  1591        (proxy_dofn(dofn, restriction_provider, watermark_estimator_provider), ) +
  1592        dofn_data[1:])
  1593    return _create_pardo_operation(
  1594        factory,
  1595        transform_id,
  1596        transform_proto,
  1597        consumers,
  1598        serialized_fn,
  1599        parameter,
  1600        operation_cls=operation_cls)
  1601  
  1602  
  1603  @BeamTransformFactory.register_urn(
  1604      common_urns.primitives.PAR_DO.urn, beam_runner_api_pb2.ParDoPayload)
  1605  def create_par_do(
  1606      factory,  # type: BeamTransformFactory
  1607      transform_id,  # type: str
  1608      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1609      parameter,  # type: beam_runner_api_pb2.ParDoPayload
  1610      consumers  # type: Dict[str, List[operations.Operation]]
  1611  ):
  1612    # type: (...) -> operations.DoOperation
  1613    return _create_pardo_operation(
  1614        factory,
  1615        transform_id,
  1616        transform_proto,
  1617        consumers,
  1618        core.DoFnInfo.from_runner_api(parameter.do_fn,
  1619                                      factory.context).serialized_dofn_data(),
  1620        parameter)
  1621  
  1622  
def _create_pardo_operation(
    factory,  # type: BeamTransformFactory
    transform_id,  # type: str
    transform_proto,  # type: beam_runner_api_pb2.PTransform
    consumers,  # type: Dict[str, List[operations.Operation]]
    serialized_fn,
    pardo_proto=None,  # type: Optional[beam_runner_api_pb2.ParDoPayload]
    operation_cls=operations.DoOperation
):
  """Builds an operation of type ``operation_cls`` for a ParDo-like transform.

  Args:
    factory: the BeamTransformFactory supplying coders, the state handler,
      counter factory, state sampler and descriptor/context lookups.
    transform_id: id of the transform being instantiated.
    transform_proto: the PTransform proto; its inputs/outputs select coders
      and the main input PCollection.
    consumers: downstream operations, passed through to
      ``factory.augment_oldstyle_op``.
    serialized_fn: pickled DoFn data tuple whose last element is the
      windowing strategy. If that element is falsy, the windowing strategy of
      the main input PCollection is substituted and the tuple is re-pickled.
    pardo_proto: optional ParDoPayload. When present, it supplies side inputs,
      state/timer specs and the restriction coder id.
    operation_cls: operation class to instantiate; defaults to
      ``operations.DoOperation``.

  Returns:
    The operation returned by ``factory.augment_oldstyle_op``. Its
    ``input_info`` is set when the payload declares a restriction coder.
  """
  if pardo_proto and pardo_proto.side_inputs:
    # One state-backed map per side input, ordered by side input index.
    input_tags_to_coders = factory.get_input_coders(transform_proto)
    tagged_side_inputs = [
        (tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context))
        for tag,
        si in pardo_proto.side_inputs.items()
    ]
    tagged_side_inputs.sort(
        key=lambda tag_si: sideinputs.get_sideinput_index(tag_si[0]))
    side_input_maps = [
        StateBackedSideInputMap(
            factory.state_handler,
            transform_id,
            tag,
            si,
            input_tags_to_coders[tag]) for tag,
        si in tagged_side_inputs
    ]
  else:
    side_input_maps = []

  output_tags = list(transform_proto.outputs.keys())

  dofn_data = pickler.loads(serialized_fn)
  if not dofn_data[-1]:
    # Windowing not set.
    if pardo_proto:
      other_input_tags = set.union(
          set(pardo_proto.side_inputs),
          set(pardo_proto.timer_family_specs))  # type: Container[str]
    else:
      other_input_tags = ()
    # Exactly one non-side, non-timer input must remain (the main input);
    # the single-element unpacking raises ValueError otherwise.
    pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items()
                 if tag not in other_input_tags]
    windowing = factory.context.windowing_strategies.get_by_id(
        factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
    serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing, ))

  if pardo_proto and (pardo_proto.timer_family_specs or pardo_proto.state_specs
                      or pardo_proto.restriction_coder_id):
    # The main input's windowed coder is needed for keyed state/timers and
    # for the SDF input info below. Note that only side inputs (not timer
    # family tags) are excluded here.
    found_input_coder = None
    for tag, pcoll_id in transform_proto.inputs.items():
      if tag in pardo_proto.side_inputs:
        pass
      else:
        # Must be the main input
        assert found_input_coder is None
        main_input_tag = tag
        found_input_coder = factory.get_windowed_coder(pcoll_id)
    assert found_input_coder is not None
    main_input_coder = found_input_coder

    if pardo_proto.timer_family_specs or pardo_proto.state_specs:
      user_state_context = FnApiUserStateContext(
          factory.state_handler,
          transform_id,
          main_input_coder.key_coder(),
          main_input_coder.window_coder
      )  # type: Optional[FnApiUserStateContext]
    else:
      user_state_context = None
  else:
    user_state_context = None

  output_coders = factory.get_output_coders(transform_proto)
  spec = operation_specs.WorkerDoFn(
      serialized_fn=serialized_fn,
      output_tags=output_tags,
      input=None,
      side_inputs=None,  # Fn API uses proto definitions and the Fn State API
      output_coders=[output_coders[tag] for tag in output_tags])

  result = factory.augment_oldstyle_op(
      operation_cls(
          common.NameContext(transform_proto.unique_name, transform_id),
          spec,
          factory.counter_factory,
          factory.state_sampler,
          side_input_maps,
          user_state_context),
      transform_proto.unique_name,
      consumers,
      output_tags)
  if pardo_proto and pardo_proto.restriction_coder_id:
    # main_input_tag / main_input_coder are bound here because a non-empty
    # restriction_coder_id also entered the main-input branch above.
    result.input_info = operations.OpInputInfo(
        transform_id,
        main_input_tag,
        main_input_coder,
        transform_proto.outputs.keys())
  return result
  1723  
  1724  
  1725  def _create_simple_pardo_operation(factory,  # type: BeamTransformFactory
  1726                                     transform_id,
  1727                                     transform_proto,
  1728                                     consumers,
  1729                                     dofn,  # type: beam.DoFn
  1730                                    ):
  1731    serialized_fn = pickler.dumps((dofn, (), {}, [], None))
  1732    return _create_pardo_operation(
  1733        factory, transform_id, transform_proto, consumers, serialized_fn)
  1734  
  1735  
  1736  @BeamTransformFactory.register_urn(
  1737      common_urns.primitives.ASSIGN_WINDOWS.urn,
  1738      beam_runner_api_pb2.WindowingStrategy)
  1739  def create_assign_windows(
  1740      factory,  # type: BeamTransformFactory
  1741      transform_id,  # type: str
  1742      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1743      parameter,  # type: beam_runner_api_pb2.WindowingStrategy
  1744      consumers  # type: Dict[str, List[operations.Operation]]
  1745  ):
  1746    class WindowIntoDoFn(beam.DoFn):
  1747      def __init__(self, windowing):
  1748        self.windowing = windowing
  1749  
  1750      def process(
  1751          self,
  1752          element,
  1753          timestamp=beam.DoFn.TimestampParam,
  1754          window=beam.DoFn.WindowParam):
  1755        new_windows = self.windowing.windowfn.assign(
  1756            WindowFn.AssignContext(timestamp, element=element, window=window))
  1757        yield WindowedValue(element, timestamp, new_windows)
  1758  
  1759    from apache_beam.transforms.core import Windowing
  1760    from apache_beam.transforms.window import WindowFn
  1761    windowing = Windowing.from_runner_api(parameter, factory.context)
  1762    return _create_simple_pardo_operation(
  1763        factory,
  1764        transform_id,
  1765        transform_proto,
  1766        consumers,
  1767        WindowIntoDoFn(windowing))
  1768  
  1769  
  1770  @BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None)
  1771  def create_identity_dofn(
  1772      factory,  # type: BeamTransformFactory
  1773      transform_id,  # type: str
  1774      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1775      parameter,
  1776      consumers  # type: Dict[str, List[operations.Operation]]
  1777  ):
  1778    # type: (...) -> operations.FlattenOperation
  1779    return factory.augment_oldstyle_op(
  1780        operations.FlattenOperation(
  1781            common.NameContext(transform_proto.unique_name, transform_id),
  1782            operation_specs.WorkerFlatten(
  1783                None, [factory.get_only_output_coder(transform_proto)]),
  1784            factory.counter_factory,
  1785            factory.state_sampler),
  1786        transform_proto.unique_name,
  1787        consumers)
  1788  
  1789  
  1790  @BeamTransformFactory.register_urn(
  1791      common_urns.combine_components.COMBINE_PER_KEY_PRECOMBINE.urn,
  1792      beam_runner_api_pb2.CombinePayload)
  1793  def create_combine_per_key_precombine(
  1794      factory,  # type: BeamTransformFactory
  1795      transform_id,  # type: str
  1796      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1797      payload,  # type: beam_runner_api_pb2.CombinePayload
  1798      consumers  # type: Dict[str, List[operations.Operation]]
  1799  ):
  1800    # type: (...) -> operations.PGBKCVOperation
  1801    serialized_combine_fn = pickler.dumps((
  1802        beam.CombineFn.from_runner_api(payload.combine_fn,
  1803                                       factory.context), [], {}))
  1804    return factory.augment_oldstyle_op(
  1805        operations.PGBKCVOperation(
  1806            common.NameContext(transform_proto.unique_name, transform_id),
  1807            operation_specs.WorkerPartialGroupByKey(
  1808                serialized_combine_fn,
  1809                None, [factory.get_only_output_coder(transform_proto)]),
  1810            factory.counter_factory,
  1811            factory.state_sampler,
  1812            factory.get_input_windowing(transform_proto)),
  1813        transform_proto.unique_name,
  1814        consumers)
  1815  
  1816  
  1817  @BeamTransformFactory.register_urn(
  1818      common_urns.combine_components.COMBINE_PER_KEY_MERGE_ACCUMULATORS.urn,
  1819      beam_runner_api_pb2.CombinePayload)
  1820  def create_combbine_per_key_merge_accumulators(
  1821      factory,  # type: BeamTransformFactory
  1822      transform_id,  # type: str
  1823      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1824      payload,  # type: beam_runner_api_pb2.CombinePayload
  1825      consumers  # type: Dict[str, List[operations.Operation]]
  1826  ):
  1827    return _create_combine_phase_operation(
  1828        factory, transform_id, transform_proto, payload, consumers, 'merge')
  1829  
  1830  
  1831  @BeamTransformFactory.register_urn(
  1832      common_urns.combine_components.COMBINE_PER_KEY_EXTRACT_OUTPUTS.urn,
  1833      beam_runner_api_pb2.CombinePayload)
  1834  def create_combine_per_key_extract_outputs(
  1835      factory,  # type: BeamTransformFactory
  1836      transform_id,  # type: str
  1837      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1838      payload,  # type: beam_runner_api_pb2.CombinePayload
  1839      consumers  # type: Dict[str, List[operations.Operation]]
  1840  ):
  1841    return _create_combine_phase_operation(
  1842        factory, transform_id, transform_proto, payload, consumers, 'extract')
  1843  
  1844  
  1845  @BeamTransformFactory.register_urn(
  1846      common_urns.combine_components.COMBINE_PER_KEY_CONVERT_TO_ACCUMULATORS.urn,
  1847      beam_runner_api_pb2.CombinePayload)
  1848  def create_combine_per_key_convert_to_accumulators(
  1849      factory,  # type: BeamTransformFactory
  1850      transform_id,  # type: str
  1851      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1852      payload,  # type: beam_runner_api_pb2.CombinePayload
  1853      consumers  # type: Dict[str, List[operations.Operation]]
  1854  ):
  1855    return _create_combine_phase_operation(
  1856        factory, transform_id, transform_proto, payload, consumers, 'convert')
  1857  
  1858  
  1859  @BeamTransformFactory.register_urn(
  1860      common_urns.combine_components.COMBINE_GROUPED_VALUES.urn,
  1861      beam_runner_api_pb2.CombinePayload)
  1862  def create_combine_grouped_values(
  1863      factory,  # type: BeamTransformFactory
  1864      transform_id,  # type: str
  1865      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1866      payload,  # type: beam_runner_api_pb2.CombinePayload
  1867      consumers  # type: Dict[str, List[operations.Operation]]
  1868  ):
  1869    return _create_combine_phase_operation(
  1870        factory, transform_id, transform_proto, payload, consumers, 'all')
  1871  
  1872  
  1873  def _create_combine_phase_operation(
  1874      factory, transform_id, transform_proto, payload, consumers, phase):
  1875    # type: (...) -> operations.CombineOperation
  1876    serialized_combine_fn = pickler.dumps((
  1877        beam.CombineFn.from_runner_api(payload.combine_fn,
  1878                                       factory.context), [], {}))
  1879    return factory.augment_oldstyle_op(
  1880        operations.CombineOperation(
  1881            common.NameContext(transform_proto.unique_name, transform_id),
  1882            operation_specs.WorkerCombineFn(
  1883                serialized_combine_fn,
  1884                phase,
  1885                None, [factory.get_only_output_coder(transform_proto)]),
  1886            factory.counter_factory,
  1887            factory.state_sampler),
  1888        transform_proto.unique_name,
  1889        consumers)
  1890  
  1891  
  1892  @BeamTransformFactory.register_urn(common_urns.primitives.FLATTEN.urn, None)
  1893  def create_flatten(
  1894      factory,  # type: BeamTransformFactory
  1895      transform_id,  # type: str
  1896      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1897      payload,
  1898      consumers  # type: Dict[str, List[operations.Operation]]
  1899  ):
  1900    # type: (...) -> operations.FlattenOperation
  1901    return factory.augment_oldstyle_op(
  1902        operations.FlattenOperation(
  1903            common.NameContext(transform_proto.unique_name, transform_id),
  1904            operation_specs.WorkerFlatten(
  1905                None, [factory.get_only_output_coder(transform_proto)]),
  1906            factory.counter_factory,
  1907            factory.state_sampler),
  1908        transform_proto.unique_name,
  1909        consumers)
  1910  
  1911  
  1912  @BeamTransformFactory.register_urn(
  1913      common_urns.primitives.MAP_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec)
  1914  def create_map_windows(
  1915      factory,  # type: BeamTransformFactory
  1916      transform_id,  # type: str
  1917      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1918      mapping_fn_spec,  # type: beam_runner_api_pb2.FunctionSpec
  1919      consumers  # type: Dict[str, List[operations.Operation]]
  1920  ):
  1921    assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOW_MAPPING_FN
  1922    window_mapping_fn = pickler.loads(mapping_fn_spec.payload)
  1923  
  1924    class MapWindows(beam.DoFn):
  1925      def process(self, element):
  1926        key, window = element
  1927        return [(key, window_mapping_fn(window))]
  1928  
  1929    return _create_simple_pardo_operation(
  1930        factory, transform_id, transform_proto, consumers, MapWindows())
  1931  
  1932  
  1933  @BeamTransformFactory.register_urn(
  1934      common_urns.primitives.MERGE_WINDOWS.urn, beam_runner_api_pb2.FunctionSpec)
  1935  def create_merge_windows(
  1936      factory,  # type: BeamTransformFactory
  1937      transform_id,  # type: str
  1938      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1939      mapping_fn_spec,  # type: beam_runner_api_pb2.FunctionSpec
  1940      consumers  # type: Dict[str, List[operations.Operation]]
  1941  ):
  1942    assert mapping_fn_spec.urn == python_urns.PICKLED_WINDOWFN
  1943    window_fn = pickler.loads(mapping_fn_spec.payload)
  1944  
  1945    class MergeWindows(beam.DoFn):
  1946      def process(self, element):
  1947        nonce, windows = element
  1948  
  1949        original_windows = set(windows)  # type: Set[window.BoundedWindow]
  1950        merged_windows = collections.defaultdict(
  1951            set
  1952        )  # type: MutableMapping[window.BoundedWindow, Set[window.BoundedWindow]] # noqa: F821
  1953  
  1954        class RecordingMergeContext(window.WindowFn.MergeContext):
  1955          def merge(
  1956              self,
  1957              to_be_merged,  # type: Iterable[window.BoundedWindow]
  1958              merge_result,  # type: window.BoundedWindow
  1959            ):
  1960            originals = merged_windows[merge_result]
  1961            for window in to_be_merged:
  1962              if window in original_windows:
  1963                originals.add(window)
  1964                original_windows.remove(window)
  1965              else:
  1966                originals.update(merged_windows.pop(window))
  1967  
  1968        window_fn.merge(RecordingMergeContext(windows))
  1969        yield nonce, (original_windows, merged_windows.items())
  1970  
  1971    return _create_simple_pardo_operation(
  1972        factory, transform_id, transform_proto, consumers, MergeWindows())
  1973  
  1974  
  1975  @BeamTransformFactory.register_urn(common_urns.primitives.TO_STRING.urn, None)
  1976  def create_to_string_fn(
  1977      factory,  # type: BeamTransformFactory
  1978      transform_id,  # type: str
  1979      transform_proto,  # type: beam_runner_api_pb2.PTransform
  1980      mapping_fn_spec,  # type: beam_runner_api_pb2.FunctionSpec
  1981      consumers  # type: Dict[str, List[operations.Operation]]
  1982  ):
  1983    class ToString(beam.DoFn):
  1984      def process(self, element):
  1985        key, value = element
  1986        return [(key, str(value))]
  1987  
  1988    return _create_simple_pardo_operation(
  1989        factory, transform_id, transform_proto, consumers, ToString())
  1990  
  1991  
  1992  class DataSamplingOperation(operations.Operation):
  1993    """Operation that samples incoming elements."""
  1994  
  1995    def __init__(
  1996        self,
  1997        name_context,  # type: common.NameContext
  1998        counter_factory,  # type: counters.CounterFactory
  1999        state_sampler,  # type: statesampler.StateSampler
  2000        pcoll_id,  # type: str
  2001        sample_coder,  # type: coders.Coder
  2002        data_sampler,  # type: data_sampler.DataSampler
  2003    ):
  2004      # type: (...) -> None
  2005      super().__init__(name_context, None, counter_factory, state_sampler)
  2006      self._coder = sample_coder  # type: coders.Coder
  2007      self._pcoll_id = pcoll_id  # type: str
  2008  
  2009      self._sampler: OutputSampler = data_sampler.sample_output(
  2010          self._pcoll_id, sample_coder)
  2011  
  2012    def process(self, windowed_value):
  2013      # type: (windowed_value.WindowedValue) -> None
  2014      self._sampler.sample(windowed_value)
  2015  
  2016  
  2017  @BeamTransformFactory.register_urn(SYNTHETIC_DATA_SAMPLING_URN, (bytes))
  2018  def create_data_sampling_op(
  2019      factory,  # type: BeamTransformFactory
  2020      transform_id,  # type: str
  2021      transform_proto,  # type: beam_runner_api_pb2.PTransform
  2022      pcoll_and_coder_id,  # type: bytes
  2023      consumers,  # type: Dict[str, List[operations.Operation]]
  2024  ):
  2025    # Creating this operation should only occur when data sampling is enabled.
  2026    data_sampler = factory.data_sampler
  2027    assert data_sampler is not None
  2028  
  2029    coder = coders.FastPrimitivesCoder()
  2030    pcoll_id, coder_id = coder.decode(pcoll_and_coder_id)
  2031    return DataSamplingOperation(
  2032        common.NameContext(transform_proto.unique_name, transform_id),
  2033        factory.counter_factory,
  2034        factory.state_sampler,
  2035        pcoll_id,
  2036        factory.get_coder(coder_id),
  2037        data_sampler,
  2038    )