github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/operations.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # cython: language_level=3
    19  # cython: profile=True
    20  
    21  """Worker operations executor."""
    22  
    23  # pytype: skip-file
    24  # pylint: disable=super-with-arguments
    25  
    26  import collections
    27  import logging
    28  import threading
    29  import warnings
    30  from typing import TYPE_CHECKING
    31  from typing import Any
    32  from typing import DefaultDict
    33  from typing import Dict
    34  from typing import FrozenSet
    35  from typing import Hashable
    36  from typing import Iterable
    37  from typing import Iterator
    38  from typing import List
    39  from typing import Mapping
    40  from typing import NamedTuple
    41  from typing import Optional
    42  from typing import Tuple
    43  
    44  from apache_beam import coders
    45  from apache_beam.internal import pickler
    46  from apache_beam.io import iobase
    47  from apache_beam.metrics import monitoring_infos
    48  from apache_beam.metrics.cells import DistributionData
    49  from apache_beam.metrics.execution import MetricsContainer
    50  from apache_beam.portability.api import metrics_pb2
    51  from apache_beam.runners import common
    52  from apache_beam.runners.common import Receiver
    53  from apache_beam.runners.worker import opcounters
    54  from apache_beam.runners.worker import operation_specs
    55  from apache_beam.runners.worker import sideinputs
    56  from apache_beam.transforms import sideinputs as apache_sideinputs
    57  from apache_beam.transforms import combiners
    58  from apache_beam.transforms import core
    59  from apache_beam.transforms import userstate
    60  from apache_beam.transforms import window
    61  from apache_beam.transforms.combiners import PhasedCombineFnExecutor
    62  from apache_beam.transforms.combiners import curry_combine_fn
    63  from apache_beam.transforms.window import GlobalWindows
    64  from apache_beam.typehints.batch import BatchConverter
    65  from apache_beam.utils.windowed_value import WindowedBatch
    66  from apache_beam.utils.windowed_value import WindowedValue
    67  
    68  if TYPE_CHECKING:
    69    from apache_beam.runners.sdf_utils import SplitResultPrimary
    70    from apache_beam.runners.sdf_utils import SplitResultResidual
    71    from apache_beam.runners.worker.bundle_processor import ExecutionContext
    72    from apache_beam.runners.worker.statesampler import StateSampler
    73    from apache_beam.transforms.userstate import TimerSpec
    74  
    75  # Allow some "pure mode" declarations.
    76  try:
    77    import cython
    78  except ImportError:
    79  
    80    class FakeCython(object):
    81      compiled = False
    82  
    83    globals()['cython'] = FakeCython()
    84  
    85  _globally_windowed_value = GlobalWindows.windowed_value(None)
    86  _global_window_type = type(_globally_windowed_value.windows[0])
    87  
    88  _LOGGER = logging.getLogger(__name__)
    89  
    90  SdfSplitResultsPrimary = Tuple['DoOperation', 'SplitResultPrimary']
    91  SdfSplitResultsResidual = Tuple['DoOperation', 'SplitResultResidual']
    92  
    93  
    94  # TODO(BEAM-9324) Remove these workarounds once upgraded to Cython 3
    95  def _cast_to_operation(value):
    96    if cython.compiled:
    97      return cython.cast(Operation, value)
    98    else:
    99      return value
   100  
   101  
   102  # TODO(BEAM-9324) Remove these workarounds once upgraded to Cython 3
   103  def _cast_to_receiver(value):
   104    if cython.compiled:
   105      return cython.cast(Receiver, value)
   106    else:
   107      return value
   108  
   109  
   110  class ConsumerSet(Receiver):
   111    """A ConsumerSet represents a graph edge between two Operation nodes.
   112  
   113    The ConsumerSet object collects information from the output of the
   114    Operation at one end of its edge and the input of the Operation at
   115    the other end.
   116    ConsumerSets are attached to the outputting Operation.
   117    """
   118    @staticmethod
   119    def create(counter_factory,
   120               step_name,  # type: str
   121               output_index,
   122               consumers,  # type: List[Operation]
   123               coder,
   124               producer_type_hints,
   125               producer_batch_converter, # type: Optional[BatchConverter]
   126               ):
   127      # type: (...) -> ConsumerSet
   128      if len(consumers) == 1:
   129        consumer = consumers[0]
   130  
   131        consumer_batch_preference = consumer.get_batching_preference()
   132        consumer_batch_converter = consumer.get_input_batch_converter()
   133        if (not consumer_batch_preference.supports_batches and
   134            producer_batch_converter is None and
   135            consumer_batch_converter is None):
   136          return SingletonElementConsumerSet(
   137              counter_factory,
   138              step_name,
   139              output_index,
   140              consumer,
   141              coder,
   142              producer_type_hints)
   143  
   144      return GeneralPurposeConsumerSet(
   145          counter_factory,
   146          step_name,
   147          output_index,
   148          coder,
   149          producer_type_hints,
   150          consumers,
   151          producer_batch_converter)
   152  
   153    def __init__(self,
   154                 counter_factory,
   155                 step_name,  # type: str
   156                 output_index,
   157                 consumers,
   158                 coder,
   159                 producer_type_hints,
   160                 producer_batch_converter
   161                 ):
   162      self.opcounter = opcounters.OperationCounters(
   163          counter_factory,
   164          step_name,
   165          coder,
   166          output_index,
   167          producer_type_hints=producer_type_hints,
   168          producer_batch_converter=producer_batch_converter)
   169      # Used in repr.
   170      self.step_name = step_name
   171      self.output_index = output_index
   172      self.coder = coder
   173      self.consumers = consumers
   174  
   175    def try_split(self, fraction_of_remainder):
   176      # type: (...) -> Optional[Any]
   177      # TODO(SDF): Consider supporting splitting each consumer individually.
   178      # This would never come up in the existing SDF expansion, but might
   179      # be useful to support fused SDF nodes.
   180      # This would require dedicated delivery of the split results to each
   181      # of the consumers separately.
   182      return None
   183  
   184    def current_element_progress(self):
   185      # type: () -> Optional[iobase.RestrictionProgress]
   186  
   187      """Returns the progress of the current element.
   188  
   189      This progress should be an instance of
   190      apache_beam.io.iobase.RestrictionProgress, or None if progress is unknown.
   191      """
   192      # TODO(SDF): Could implement this as a weighted average, if it becomes
   193      # useful to split on.
   194      return None
   195  
   196    def update_counters_start(self, windowed_value):
   197      # type: (WindowedValue) -> None
   198      self.opcounter.update_from(windowed_value)
   199  
   200    def update_counters_finish(self):
   201      # type: () -> None
   202      self.opcounter.update_collect()
   203  
   204    def update_counters_batch(self, windowed_batch):
   205      # type: (WindowedBatch) -> None
   206      self.opcounter.update_from_batch(windowed_batch)
   207  
   208    def __repr__(self):
   209      return '%s[%s.out%s, coder=%s, len(consumers)=%s]' % (
   210          self.__class__.__name__,
   211          self.step_name,
   212          self.output_index,
   213          self.coder,
   214          len(self.consumers))
   215  
   216  
   217  class SingletonElementConsumerSet(ConsumerSet):
   218    """ConsumerSet representing a single consumer that can only process elements
   219    (not batches)."""
   220    def __init__(self,
   221                 counter_factory,
   222                 step_name,
   223                 output_index,
   224                 consumer,  # type: Operation
   225                 coder,
   226                 producer_type_hints
   227                 ):
   228      super().__init__(
   229          counter_factory,
   230          step_name,
   231          output_index, [consumer],
   232          coder,
   233          producer_type_hints,
   234          None)
   235      self.consumer = consumer
   236  
   237    def receive(self, windowed_value):
   238      # type: (WindowedValue) -> None
   239      self.update_counters_start(windowed_value)
   240      self.consumer.process(windowed_value)
   241      self.update_counters_finish()
   242  
   243    def receive_batch(self, windowed_batch):
   244      raise AssertionError(
   245          "SingletonElementConsumerSet.receive_batch is not implemented")
   246  
   247    def flush(self):
   248      # SingletonElementConsumerSet has no buffer to flush
   249      pass
   250  
   251    def try_split(self, fraction_of_remainder):
   252      # type: (...) -> Optional[Any]
   253      return self.consumer.try_split(fraction_of_remainder)
   254  
   255    def current_element_progress(self):
   256      return self.consumer.current_element_progress()
   257  
   258  
   259  class GeneralPurposeConsumerSet(ConsumerSet):
   260    """ConsumerSet implementation that handles all combinations of possible edges.
   261    """
   262    MAX_BATCH_SIZE = 4096
   263  
   264    def __init__(self,
   265                 counter_factory,
   266                 step_name,  # type: str
   267                 output_index,
   268                 coder,
   269                 producer_type_hints,
   270                 consumers,  # type: List[Operation]
   271                 producer_batch_converter):
   272      super().__init__(
   273          counter_factory,
   274          step_name,
   275          output_index,
   276          consumers,
   277          coder,
   278          producer_type_hints,
   279          producer_batch_converter)
   280  
   281      self.producer_batch_converter = producer_batch_converter
   282  
   283      # Partition consumers into three groups:
   284      # - consumers that will be passed elements
   285      # - consumers that will be passed batches (where their input batch type
   286      #   matches the output of the producer)
   287      # - consumers that will be passed converted batches
   288      self.element_consumers: List[Operation] = []
   289      self.passthrough_batch_consumers: List[Operation] = []
   290      other_batch_consumers: DefaultDict[
   291          BatchConverter, List[Operation]] = collections.defaultdict(lambda: [])
   292  
   293      for consumer in consumers:
   294        if not consumer.get_batching_preference().supports_batches:
   295          self.element_consumers.append(consumer)
   296        elif (consumer.get_input_batch_converter() ==
   297              self.producer_batch_converter):
   298          self.passthrough_batch_consumers.append(consumer)
   299        else:
   300          # Batch consumer with a mismatched batch type
   301          if consumer.get_batching_preference().supports_elements:
   302            # Pass it elements if we can
   303            self.element_consumers.append(consumer)
   304          else:
   305            # As a last resort, explode and rebatch
   306            consumer_batch_converter = consumer.get_input_batch_converter()
   307            # This consumer supports batches, so it must have a batch converter
   308            assert consumer_batch_converter is not None
   309            other_batch_consumers[consumer_batch_converter].append(consumer)
   310  
   311      self.other_batch_consumers: Dict[BatchConverter, List[Operation]] = dict(
   312          other_batch_consumers)
   313  
   314      self.has_batch_consumers = (
   315          self.passthrough_batch_consumers or self.other_batch_consumers)
   316      self._batched_elements: List[Any] = []
   317  
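          # Dispatch summary (descriptive note, not from the original source):
          # receive() sends each element straight to element_consumers and, if
          # there are any batch consumers, buffers it until MAX_BATCH_SIZE
          # triggers flush(); receive_batch() passes batches through to
          # matching consumers, explodes them for element consumers, and
          # rebatches for consumers whose batch type differs.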
   318    def receive(self, windowed_value):
   319      # type: (WindowedValue) -> None
   320  
   321      self.update_counters_start(windowed_value)
   322  
   323      for consumer in self.element_consumers:
   324        _cast_to_operation(consumer).process(windowed_value)
   325  
   326      # TODO: Do this branching when constructing the ConsumerSet
   327      if self.has_batch_consumers:
   328        self._batched_elements.append(windowed_value)
   329        if len(self._batched_elements) >= self.MAX_BATCH_SIZE:
   330          self.flush()
   331  
   332      # TODO(https://github.com/apache/beam/issues/21655): Properly estimate
   333      # sizes in the batch-consumer-only case; this undercounts large iterables.
   334      self.update_counters_finish()
   335  
   336    def receive_batch(self, windowed_batch):
   337      if self.element_consumers:
   338        for wv in windowed_batch.as_windowed_values(
   339            self.producer_batch_converter.explode_batch):
   340          for consumer in self.element_consumers:
   341            _cast_to_operation(consumer).process(wv)
   342  
   343      for consumer in self.passthrough_batch_consumers:
   344        _cast_to_operation(consumer).process_batch(windowed_batch)
   345  
   346      for (consumer_batch_converter,
   347           consumers) in self.other_batch_consumers.items():
   348        # Explode and rebatch into the new batch type (ouch!)
   349        # TODO: Register direct conversions for equivalent batch types
   350  
   351        for consumer in consumers:
   352          warnings.warn(
   353              f"Input to operation {consumer} must be rebatched from type "
   354              f"{self.producer_batch_converter.batch_type!r} to "
   355              f"{consumer_batch_converter.batch_type!r}.\n"
   356              "This is very inefficient, consider re-structuring your pipeline "
   357              "or adding a DoFn to directly convert between these types.",
   358              InefficientExecutionWarning)
   359          _cast_to_operation(consumer).process_batch(
   360              windowed_batch.with_values(
   361                  consumer_batch_converter.produce_batch(
   362                      self.producer_batch_converter.explode_batch(
   363                          windowed_batch.values))))
   364  
   365      self.update_counters_batch(windowed_batch)
   366  
   367    def flush(self):
   368      if not self.has_batch_consumers or not self._batched_elements:
   369        return
   370  
   371      for batch_converter, consumers in self.other_batch_consumers.items():
   372        for windowed_batch in WindowedBatch.from_windowed_values(
   373            self._batched_elements, produce_fn=batch_converter.produce_batch):
   374          for consumer in consumers:
   375            _cast_to_operation(consumer).process_batch(windowed_batch)
   376  
   377      for consumer in self.passthrough_batch_consumers:
   378        for windowed_batch in WindowedBatch.from_windowed_values(
   379            self._batched_elements,
   380            produce_fn=self.producer_batch_converter.produce_batch):
   381          _cast_to_operation(consumer).process_batch(windowed_batch)
   382  
   383      self._batched_elements = []
   384  
   385  
   386  class Operation(object):
   387    """An operation representing the live version of a work item specification.
   388  
   389    An operation can have one or more outputs and for each output it can have
   390    one or more receiver operations that will take that as input.
   391    """
   392  
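          # Typical lifecycle, as driven by the worker harness (illustrative
          # sketch only):
          #
          #   op.setup()      # builds receivers from the spec's output coders
          #   op.start()      # per-bundle start (calls setup() if needed)
          #   op.process(wv)  # invoked for each incoming WindowedValue
          #   op.finish()     # flushes all receivers
          #   op.teardown()   # no other methods may be called afterwards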
   393    def __init__(self,
   394                 name_context,  # type: common.NameContext
   395                 spec,
   396                 counter_factory,
   397                 state_sampler  # type: StateSampler
   398                ):
   399      """Initializes a worker operation instance.
   400  
   401      Args:
   402        name_context: A NameContext instance, with the name information for this
   403          operation.
   404        spec: An operation_specs.Worker* instance.
   405        counter_factory: The CounterFactory to use for our counters.
   406        state_sampler: The StateSampler for the current operation.
   407      """
   408      assert isinstance(name_context, common.NameContext)
   409      self.name_context = name_context
   410  
   411      self.spec = spec
   412      self.counter_factory = counter_factory
   413      self.execution_context = None  # type: Optional[ExecutionContext]
   414      self.consumers = collections.defaultdict(
   415          list)  # type: DefaultDict[int, List[Operation]]
   416  
   417      # These are overwritten in the legacy harness.
   418      self.metrics_container = MetricsContainer(self.name_context.metrics_name())
   419  
   420      self.state_sampler = state_sampler
   421      self.scoped_start_state = self.state_sampler.scoped_state(
   422          self.name_context, 'start', metrics_container=self.metrics_container)
   423      self.scoped_process_state = self.state_sampler.scoped_state(
   424          self.name_context, 'process', metrics_container=self.metrics_container)
   425      self.scoped_finish_state = self.state_sampler.scoped_state(
   426          self.name_context, 'finish', metrics_container=self.metrics_container)
   427      # TODO(ccy): the '-abort' state can be added when abort is supported in
   428      # Operations.
   429      self.receivers = []  # type: List[ConsumerSet]
   430      # Legacy workers cannot call setup() until after setting additional state
   431      # on the operation.
   432      self.setup_done = False
   433      self.step_name = None  # type: Optional[str]
   434  
   435    def setup(self):
   436      # type: () -> None
   437  
   438      """Set up operation.
   439  
   440      This must be called before any other methods of the operation."""
   441      with self.scoped_start_state:
   442        self.debug_logging_enabled = logging.getLogger().isEnabledFor(
   443            logging.DEBUG)
   444        # Everything except WorkerSideInputSource, which is not a
   445        # top-level operation, should have output_coders.
   446        # TODO(pabloem): Define more precisely which step name is used here.
   447        if getattr(self.spec, 'output_coders', None):
   448          self.receivers = [
   449              ConsumerSet.create(
   450                  self.counter_factory,
   451                  self.name_context.logging_name(),
   452                  i,
   453                  self.consumers[i],
   454                  coder,
   455                  self._get_runtime_performance_hints(),
   456                  self.get_output_batch_converter(),
   457              ) for i,
   458              coder in enumerate(self.spec.output_coders)
   459          ]
   460      self.setup_done = True
   461  
   462    def start(self):
   463      # type: () -> None
   464  
   465      """Start operation."""
   466      if not self.setup_done:
   467        # For legacy workers.
   468        self.setup()
   469  
   470    def get_batching_preference(self):
   471      # By default operations don't support batching; the Receiver must unbatch.
   472      return common.BatchingPreference.BATCH_FORBIDDEN
   473  
   474    def get_input_batch_converter(self) -> Optional[BatchConverter]:
   475      """Returns a batch type converter if this operation can accept a batch,
   476      otherwise None."""
   477      return None
   478  
   479    def get_output_batch_converter(self) -> Optional[BatchConverter]:
   480      """Returns a batch type converter if this operation can produce a batch,
   481      otherwise None."""
   482      return None
   483  
   484    def process(self, o):
   485      # type: (WindowedValue) -> None
   486  
   487      """Process element in operation."""
   488      pass
   489  
   490    def process_batch(self, batch: WindowedBatch):
   491      pass
   492  
   493    def finalize_bundle(self):
   494      # type: () -> None
   495      pass
   496  
   497    def needs_finalization(self):
   498      return False
   499  
   500    def try_split(self, fraction_of_remainder):
   501      # type: (...) -> Optional[Any]
   502      return None
   503  
   504    def current_element_progress(self):
   505      return None
   506  
   507    def finish(self):
   508      # type: () -> None
   509  
   510      """Finish operation."""
   511      for receiver in self.receivers:
   512        _cast_to_receiver(receiver).flush()
   513  
   514    def teardown(self):
   515      # type: () -> None
   516  
   517      """Tear down operation.
   518  
   519      No other methods of this operation should be called after this."""
   520      pass
   521  
   522    def reset(self):
   523      # type: () -> None
   524      self.metrics_container.reset()
   525  
   526    def output(self, windowed_value, output_index=0):
   527      # type: (WindowedValue, int) -> None
   528      _cast_to_receiver(self.receivers[output_index]).receive(windowed_value)
   529  
   530    def add_receiver(self, operation, output_index=0):
   531      # type: (Operation, int) -> None
   532  
   533      """Adds a receiver operation for the specified output."""
   534      self.consumers[output_index].append(operation)
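            # Illustrative wiring (hypothetical operations): in the harnesses,
            # consumers are attached before setup(), which turns self.consumers
            # into ConsumerSets:
            #
            #   producer.add_receiver(consumer_op, output_index=0)
            #   producer.setup()
            #   producer.start()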
   535  
   536    def monitoring_infos(self, transform_id, tag_to_pcollection_id):
   537      # type: (str, Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]
   538  
   539      """Returns the list of MonitoringInfos collected by this operation."""
   540      all_monitoring_infos = self.execution_time_monitoring_infos(transform_id)
   541      all_monitoring_infos.update(
   542          self.pcollection_count_monitoring_infos(tag_to_pcollection_id))
   543      all_monitoring_infos.update(self.user_monitoring_infos(transform_id))
   544      return all_monitoring_infos
   545  
   546    def pcollection_count_monitoring_infos(self, tag_to_pcollection_id):
   547      # type: (Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]
   548  
   549      """Returns the element count MonitoringInfo collected by this operation."""
   550  
   551      # Skip producing monitoring infos if there is more than one receiver
   552      # since there is no way to provide a mapping from tag to pcollection id
   553      # within Operation.
   554      if len(self.receivers) != 1 or len(tag_to_pcollection_id) != 1:
   555        return {}
   556  
   557      all_monitoring_infos = {}
   558      pcollection_id = next(iter(tag_to_pcollection_id.values()))
   559      receiver = self.receivers[0]
   560      elem_count_mi = monitoring_infos.int64_counter(
   561          monitoring_infos.ELEMENT_COUNT_URN,
   562          receiver.opcounter.element_counter.value(),
   563          pcollection=pcollection_id,
   564      )
   565  
   566      (unused_mean, sum, count, min, max) = (
   567          receiver.opcounter.mean_byte_counter.value())
   568  
   569      sampled_byte_count = monitoring_infos.int64_distribution(
   570          monitoring_infos.SAMPLED_BYTE_SIZE_URN,
   571          DistributionData(sum, count, min, max),
   572          pcollection=pcollection_id,
   573      )
   574      all_monitoring_infos[monitoring_infos.to_key(elem_count_mi)] = elem_count_mi
   575      all_monitoring_infos[monitoring_infos.to_key(
   576          sampled_byte_count)] = sampled_byte_count
   577  
   578      return all_monitoring_infos
   579  
   580    def user_monitoring_infos(self, transform_id):
   581      # type: (str) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]
   582  
   583      """Returns the user MonitoringInfos collected by this operation."""
   584      return self.metrics_container.to_runner_api_monitoring_infos(transform_id)
   585  
   586    def execution_time_monitoring_infos(self, transform_id):
   587      # type: (str) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]
   588      total_time_spent_msecs = (
   589          self.scoped_start_state.sampled_msecs_int() +
   590          self.scoped_process_state.sampled_msecs_int() +
   591          self.scoped_finish_state.sampled_msecs_int())
   592      mis = [
   593          monitoring_infos.int64_counter(
   594              monitoring_infos.START_BUNDLE_MSECS_URN,
   595              self.scoped_start_state.sampled_msecs_int(),
   596              ptransform=transform_id),
   597          monitoring_infos.int64_counter(
   598              monitoring_infos.PROCESS_BUNDLE_MSECS_URN,
   599              self.scoped_process_state.sampled_msecs_int(),
   600              ptransform=transform_id),
   601          monitoring_infos.int64_counter(
   602              monitoring_infos.FINISH_BUNDLE_MSECS_URN,
   603              self.scoped_finish_state.sampled_msecs_int(),
   604              ptransform=transform_id),
   605          monitoring_infos.int64_counter(
   606              monitoring_infos.TOTAL_MSECS_URN,
   607              total_time_spent_msecs,
   608              ptransform=transform_id),
   609      ]
   610      return {monitoring_infos.to_key(mi): mi for mi in mis}
   611  
   612    def __str__(self):
   613      """Generates a useful string for this object.
   614  
   615      Compactly displays interesting fields.  In particular, pickled
   616      fields are not displayed.  Note that we collapse the fields of the
   617      contained Worker* object into this object, since there is a 1-1
   618      mapping between Operation and operation_specs.Worker*.
   619  
   620      Returns:
   621        Compact string representing this object.
   622      """
   623      return self.str_internal()
   624  
   625    def str_internal(self, is_recursive=False):
   626      """Internal helper for __str__ that supports recursion.
   627  
   628      When recursing on receivers, keep the output short.
   629      Args:
   630        is_recursive: whether to omit some details, particularly receivers.
   631      Returns:
   632        Compact string representing this object.
   633      """
   634      printable_name = self.__class__.__name__
   635      if hasattr(self, 'step_name'):
   636        printable_name += ' %s' % self.name_context.logging_name()
   637        if is_recursive:
   638          # If we have a step name, stop here, no more detail needed.
   639          return '<%s>' % printable_name
   640  
   641      if self.spec is None:
   642        printable_fields = []
   643      else:
   644        printable_fields = operation_specs.worker_printable_fields(self.spec)
   645  
   646      if not is_recursive and getattr(self, 'receivers', []):
   647        printable_fields.append(
   648            'receivers=[%s]' %
   649            ', '.join([str(receiver) for receiver in self.receivers]))
   650  
   651      return '<%s %s>' % (printable_name, ', '.join(printable_fields))
   652  
   653    def _get_runtime_performance_hints(self):
   654      # type: () -> Optional[Dict[Optional[str], Tuple[Optional[str], Any]]]
   655  
   656      """Returns any type hints required for performance runtime
   657      type-checking."""
   658      return None
   659  
   660  
   661  class ReadOperation(Operation):
   662    def start(self):
   663      with self.scoped_start_state:
   664        super(ReadOperation, self).start()
   665        range_tracker = self.spec.source.source.get_range_tracker(
   666            self.spec.source.start_position, self.spec.source.stop_position)
   667        for value in self.spec.source.source.read(range_tracker):
   668          if isinstance(value, WindowedValue):
   669            windowed_value = value
   670          else:
   671            windowed_value = _globally_windowed_value.with_value(value)
   672          self.output(windowed_value)
   673  
   674  
   675  class ImpulseReadOperation(Operation):
   676    def __init__(
   677        self,
   678        name_context,  # type: common.NameContext
   679        counter_factory,
   680        state_sampler,  # type: StateSampler
   681        consumers,  # type: Mapping[Any, List[Operation]]
   682        source,  # type: iobase.BoundedSource
   683        output_coder):
   684      super(ImpulseReadOperation,
   685            self).__init__(name_context, None, counter_factory, state_sampler)
   686      self.source = source
   687  
   688      self.receivers = [
   689          ConsumerSet.create(
   690              self.counter_factory,
   691              self.name_context.step_name,
   692              0,
   693              next(iter(consumers.values())),
   694              output_coder,
   695              self._get_runtime_performance_hints(),
   696              self.get_output_batch_converter())
   697      ]
   698  
   699    def process(self, unused_impulse):
   700      # type: (WindowedValue) -> None
   701      with self.scoped_process_state:
   702        range_tracker = self.source.get_range_tracker(None, None)
   703        for value in self.source.read(range_tracker):
   704          if isinstance(value, WindowedValue):
   705            windowed_value = value
   706          else:
   707            windowed_value = _globally_windowed_value.with_value(value)
   708          self.output(windowed_value)
   709  
   710  
   711  class InMemoryWriteOperation(Operation):
   712    """A write operation that will write to an in-memory sink."""
   713    def process(self, o):
   714      # type: (WindowedValue) -> None
   715      with self.scoped_process_state:
   716        if self.debug_logging_enabled:
   717          _LOGGER.debug('Processing [%s] in %s', o, self)
   718        self.spec.output_buffer.append(
   719            o if self.spec.write_windowed_values else o.value)
   720  
   721  
   722  class _TaggedReceivers(dict):
   723    def __init__(self, counter_factory, step_name):
   724      self._counter_factory = counter_factory
   725      self._step_name = step_name
   726  
   727    def __missing__(self, tag):
   728      self[tag] = receiver = ConsumerSet.create(
   729          self._counter_factory, self._step_name, tag, [], None, None, None)
   730      return receiver
   731  
   732    def total_output_bytes(self):
   733      # type: () -> int
   734      total = 0
   735      for receiver in self.values():
   736        elements = receiver.opcounter.element_counter.value()
   737        if elements > 0:
   738          mean = (receiver.opcounter.mean_byte_counter.value())[0]
   739          total += elements * mean
   740      return total
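            # Note (descriptive, not in the original source): this is an
            # estimate, computed per receiver as element count multiplied by
            # the sampled mean element size.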
   741  
   742  
   743  OpInputInfo = NamedTuple(
   744      'OpInputInfo',
   745      [
   746          ('transform_id', str),
   747          ('main_input_tag', str),
   748          ('main_input_coder', coders.WindowedValueCoder),
   749          ('outputs', Iterable[str]),
   750      ])
   751  
   752  
   753  class DoOperation(Operation):
   754    """A Do operation that will execute a custom DoFn for each input element."""
   755  
   756    def __init__(self,
   757                 name,  # type: common.NameContext
   758                 spec,  # operation_specs.WorkerDoFn  # need to fix this type
   759                 counter_factory,
   760                 sampler,
   761                 side_input_maps=None,
   762                 user_state_context=None
   763                ):
   764      super(DoOperation, self).__init__(name, spec, counter_factory, sampler)
   765      self.side_input_maps = side_input_maps
   766      self.user_state_context = user_state_context
   767      self.tagged_receivers = None  # type: Optional[_TaggedReceivers]
   768      # Information about this operation's main input (see OpInputInfo).
   769      self.input_info = None  # type: Optional[OpInputInfo]
   770  
   771      # See fn_data in dataflow_runner.py
   772      # TODO: Store all the items from spec?
   773      self.fn, _, _, _, _ = (pickler.loads(self.spec.serialized_fn))
   774  
   775    def _read_side_inputs(self, tags_and_types):
   776      # type: (...) -> Iterator[apache_sideinputs.SideInputMap]
   777  
   778      """Generator reading side inputs in the order prescribed by tags_and_types.
   779  
   780      Args:
   781        tags_and_types: List of tuples (tag, view_class, view_options). Each
   782          side input has a string tag that is specified in the worker
   783          instruction; view_class and view_options describe how the side input
   784          is viewed (singleton or iterable) and carry its coder.
   785  
   786      Yields:
   787        With each iteration it yields the result of reading an entire side source
   788        either in singleton or collection mode according to the tags_and_types
   789        argument.
   790      """
   791      # Only call this on the old path where side_input_maps was not
   792      # provided directly.
   793      assert self.side_input_maps is None
   794  
   795      # We will read the side inputs in the order prescribed by the
   796      # tags_and_types argument because this is exactly the order needed to
   797      # replace the ArgumentPlaceholder objects in the args/kwargs of the DoFn
   798      # getting the side inputs.
   799      #
   800      # Note that for each tag there could be several read operations in the
   801      # specification. This can happen for instance if the source has been
   802      # sharded into several files.
   803      for i, (side_tag, view_class, view_options) in enumerate(tags_and_types):
   804        sources = []
   805        # Using the side_tag in the lambda below will trigger a pylint warning.
   806        # However, in this case it is fine because the lambda is used right away
   807        # while the variable has the value assigned by the current iteration of
   808        # the for loop.
   809        # pylint: disable=cell-var-from-loop
   810        for si in filter(lambda o: o.tag == side_tag, self.spec.side_inputs):
   811          if not isinstance(si, operation_specs.WorkerSideInputSource):
   812            raise NotImplementedError('Unknown side input type: %r' % si)
   813          sources.append(si.source)
   814        si_counter = opcounters.SideInputReadCounter(
   815            self.counter_factory,
   816            self.state_sampler,
   817            declaring_step=self.name_context.step_name,
   818            # Inputs are 1-indexed, so we add 1 to i in the side input id
   819            input_index=i + 1)
   820        element_counter = opcounters.OperationCounters(
   821            self.counter_factory,
   822            self.name_context.step_name,
   823            view_options['coder'],
   824            i,
   825            suffix='side-input')
   826        iterator_fn = sideinputs.get_iterator_fn_for_sources(
   827            sources, read_counter=si_counter, element_counter=element_counter)
   828        yield apache_sideinputs.SideInputMap(
   829            view_class, view_options, sideinputs.EmulatedIterable(iterator_fn))
   830  
   831    def setup(self):
   832      # type: () -> None
   833      with self.scoped_start_state:
   834        super(DoOperation, self).setup()
   835  
   836        # See fn_data in dataflow_runner.py
   837        fn, args, kwargs, tags_and_types, window_fn = (
   838            pickler.loads(self.spec.serialized_fn))
   839  
   840        state = common.DoFnState(self.counter_factory)
   841        state.step_name = self.name_context.logging_name()
   842  
   843        # Tag to output index map used to dispatch the output values emitted
   844        # by the DoFn function to the appropriate receivers. The main output is
   845        # either the only output or the output tagged with 'None' and is
   846        # associated with its corresponding index.
   847        self.tagged_receivers = _TaggedReceivers(
   848            self.counter_factory, self.name_context.logging_name())
   849  
   850        if len(self.spec.output_tags) == 1:
   851          self.tagged_receivers[None] = self.receivers[0]
   852          self.tagged_receivers[self.spec.output_tags[0]] = self.receivers[0]
   853        else:
   854          for index, tag in enumerate(self.spec.output_tags):
   855            self.tagged_receivers[tag] = self.receivers[index]
   856            if tag == 'None':
   857              self.tagged_receivers[None] = self.receivers[index]
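              # Illustrative mapping (hypothetical tags, not from the source):
              # with output_tags == ['None', 'side'], receivers[0] serves both
              # the None key and the 'None' tag, and receivers[1] serves
              # 'side'.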
   858  
   859        if self.user_state_context:
   860          self.timer_specs = {
   861              spec.name: spec
   862              for spec in userstate.get_dofn_specs(fn)[1]
   863          }  # type: Dict[str, TimerSpec]
   864  
   865        if self.side_input_maps is None:
   866          if tags_and_types:
   867            self.side_input_maps = list(self._read_side_inputs(tags_and_types))
   868          else:
   869            self.side_input_maps = []
   870  
   871        self.dofn_runner = common.DoFnRunner(
   872            fn,
   873            args,
   874            kwargs,
   875            self.side_input_maps,
   876            window_fn,
   877            tagged_receivers=self.tagged_receivers,
   878            step_name=self.name_context.logging_name(),
   879            state=state,
   880            user_state_context=self.user_state_context,
   881            operation_name=self.name_context.metrics_name())
   882        self.dofn_runner.setup()
   883  
   884    def start(self):
   885      # type: () -> None
   886      with self.scoped_start_state:
   887        super(DoOperation, self).start()
   888        self.dofn_runner.start()
   889  
   890    def get_batching_preference(self):
   891      if self.fn._process_batch_defined:
   892        if self.fn._process_defined:
   893          return common.BatchingPreference.DO_NOT_CARE
   894        else:
   895          return common.BatchingPreference.BATCH_REQUIRED
   896      else:
   897        return common.BatchingPreference.BATCH_FORBIDDEN
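            # Summary of the branches above (descriptive note): a DoFn that
            # defines only process_batch() requires batches, one that defines
            # both process() and process_batch() accepts either, and an
            # element-only DoFn forbids batches.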
   898  
   899    def get_input_batch_converter(self) -> Optional[BatchConverter]:
   900      return getattr(self.fn, 'input_batch_converter', None)
   901  
   902    def get_output_batch_converter(self) -> Optional[BatchConverter]:
   903      return getattr(self.fn, 'output_batch_converter', None)
   904  
   905    def process(self, o):
   906      # type: (WindowedValue) -> None
   907      with self.scoped_process_state:
   908        delayed_applications = self.dofn_runner.process(o)
   909        if delayed_applications:
   910          assert self.execution_context is not None
   911          for delayed_application in delayed_applications:
   912            self.execution_context.delayed_applications.append(
   913                (self, delayed_application))
   914  
   915    def process_batch(self, windowed_batch: WindowedBatch) -> None:
   916      self.dofn_runner.process_batch(windowed_batch)
   917  
   918    def finalize_bundle(self):
   919      # type: () -> None
   920      self.dofn_runner.finalize()
   921  
   922    def needs_finalization(self):
   923      # type: () -> bool
   924      return self.dofn_runner.bundle_finalizer_param.has_callbacks()
   925  
   926    def add_timer_info(self, timer_family_id, timer_info):
   927      self.user_state_context.add_timer_info(timer_family_id, timer_info)
   928  
   929    def process_timer(self, tag, timer_data):
   930      timer_spec = self.timer_specs[tag]
   931      self.dofn_runner.process_user_timer(
   932          timer_spec,
   933          timer_data.user_key,
   934          timer_data.windows[0],
   935          timer_data.fire_timestamp,
   936          timer_data.paneinfo,
   937          timer_data.dynamic_timer_tag)
   938  
   939    def finish(self):
   940      # type: () -> None
   941      super(DoOperation, self).finish()
   942      with self.scoped_finish_state:
   943        self.dofn_runner.finish()
   944        if self.user_state_context:
   945          self.user_state_context.commit()
   946  
   947    def teardown(self):
   948      # type: () -> None
   949      with self.scoped_finish_state:
   950        self.dofn_runner.teardown()
   951      if self.user_state_context:
   952        self.user_state_context.reset()
   953  
   954    def reset(self):
   955      # type: () -> None
   956      super(DoOperation, self).reset()
   957      for side_input_map in self.side_input_maps:
   958        side_input_map.reset()
   959      if self.user_state_context:
   960        self.user_state_context.reset()
   961      self.dofn_runner.bundle_finalizer_param.reset()
   962  
   963    def pcollection_count_monitoring_infos(self, tag_to_pcollection_id):
   964      # type: (Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]
   965  
   966      """Returns the element count MonitoringInfo collected by this operation."""
   967      infos = super(
   968          DoOperation,
   969          self).pcollection_count_monitoring_infos(tag_to_pcollection_id)
   970  
   971      if self.tagged_receivers:
   972        for tag, receiver in self.tagged_receivers.items():
   973          if str(tag) not in tag_to_pcollection_id:
   974            continue
   975          pcollection_id = tag_to_pcollection_id[str(tag)]
   976          mi = monitoring_infos.int64_counter(
   977              monitoring_infos.ELEMENT_COUNT_URN,
   978              receiver.opcounter.element_counter.value(),
   979              pcollection=pcollection_id)
   980          infos[monitoring_infos.to_key(mi)] = mi
   981          (unused_mean, sum, count, min, max) = (
   982              receiver.opcounter.mean_byte_counter.value())
   983          sampled_byte_count = monitoring_infos.int64_distribution(
   984              monitoring_infos.SAMPLED_BYTE_SIZE_URN,
   985              DistributionData(sum, count, min, max),
   986              pcollection=pcollection_id)
   987          infos[monitoring_infos.to_key(sampled_byte_count)] = sampled_byte_count
   988      return infos
   989  
   990    def _get_runtime_performance_hints(self):
   991      fns = pickler.loads(self.spec.serialized_fn)
   992      if fns and hasattr(fns[0], '_runtime_output_constraints'):
   993        return fns[0]._runtime_output_constraints
   994  
   995      return {}
   996  
   997  
   998  class SdfTruncateSizedRestrictions(DoOperation):
   999    def __init__(self, *args, **kwargs):
  1000      super(SdfTruncateSizedRestrictions, self).__init__(*args, **kwargs)
  1001  
  1002    def current_element_progress(self):
  1003      # type: () -> Optional[iobase.RestrictionProgress]
  1004      return self.receivers[0].current_element_progress()
  1005  
  1006    def try_split(
  1007        self, fraction_of_remainder
  1008    ):  # type: (...) -> Optional[Tuple[Iterable[SdfSplitResultsPrimary], Iterable[SdfSplitResultsResidual]]]
  1009      return self.receivers[0].try_split(fraction_of_remainder)
  1010  
  1011  
  1012  class SdfProcessSizedElements(DoOperation):
  1013    def __init__(self, *args, **kwargs):
  1014      super(SdfProcessSizedElements, self).__init__(*args, **kwargs)
  1015      self.lock = threading.RLock()
  1016      self.element_start_output_bytes = None  # type: Optional[int]
  1017  
  1018    def process(self, o):
  1019      # type: (WindowedValue) -> None
  1020      assert self.tagged_receivers is not None
  1021      with self.scoped_process_state:
  1022        try:
  1023          with self.lock:
  1024            self.element_start_output_bytes = \
  1025              self.tagged_receivers.total_output_bytes()
  1026            for receiver in self.tagged_receivers.values():
  1027              receiver.opcounter.restart_sampling()
  1028          # Actually processing the element can be expensive; do it without
  1029          # the lock.
  1030          delayed_applications = self.dofn_runner.process_with_sized_restriction(
  1031              o)
  1032          if delayed_applications:
  1033            assert self.execution_context is not None
  1034            for delayed_application in delayed_applications:
  1035              self.execution_context.delayed_applications.append(
  1036                  (self, delayed_application))
  1037        finally:
  1038          with self.lock:
  1039            self.element_start_output_bytes = None
  1040  
  1041    def try_split(self, fraction_of_remainder):
  1042      # type: (...) -> Optional[Tuple[Iterable[SdfSplitResultsPrimary], Iterable[SdfSplitResultsResidual]]]
  1043      split = self.dofn_runner.try_split(fraction_of_remainder)
  1044      if split:
  1045        primaries, residuals = split
  1046        return [(self, primary) for primary in primaries
  1047                ], [(self, residual) for residual in residuals]
  1048      return None
  1049  
  1050    def current_element_progress(self):
  1051      # type: () -> Optional[iobase.RestrictionProgress]
  1052      with self.lock:
  1053        if self.element_start_output_bytes is not None:
  1054          progress = self.dofn_runner.current_element_progress()
  1055          if progress is not None:
  1056            assert self.tagged_receivers is not None
  1057            return progress.with_completed(
  1058                self.tagged_receivers.total_output_bytes() -
  1059                self.element_start_output_bytes)
  1060        return None
  1061  
  1062    def monitoring_infos(self, transform_id, tag_to_pcollection_id):
  1063      # type: (str, Dict[str, str]) -> Dict[FrozenSet, metrics_pb2.MonitoringInfo]
  1064  
  1065      def encode_progress(value):
  1066        # type: (float) -> bytes
  1067        coder = coders.IterableCoder(coders.FloatCoder())
  1068        return coder.encode([value])
  1069  
  1070      with self.lock:
  1071        infos = super(SdfProcessSizedElements,
  1072                      self).monitoring_infos(transform_id, tag_to_pcollection_id)
  1073        current_element_progress = self.current_element_progress()
  1074        if current_element_progress:
  1075          if current_element_progress.completed_work:
  1076            completed = current_element_progress.completed_work
  1077            remaining = current_element_progress.remaining_work
  1078          else:
  1079            completed = current_element_progress.fraction_completed
  1080            remaining = current_element_progress.fraction_remaining
  1081          assert completed is not None
  1082          assert remaining is not None
  1083          completed_mi = metrics_pb2.MonitoringInfo(
  1084              urn=monitoring_infos.WORK_COMPLETED_URN,
  1085              type=monitoring_infos.PROGRESS_TYPE,
  1086              labels=monitoring_infos.create_labels(ptransform=transform_id),
  1087              payload=encode_progress(completed))
  1088          remaining_mi = metrics_pb2.MonitoringInfo(
  1089              urn=monitoring_infos.WORK_REMAINING_URN,
  1090              type=monitoring_infos.PROGRESS_TYPE,
  1091              labels=monitoring_infos.create_labels(ptransform=transform_id),
  1092              payload=encode_progress(remaining))
  1093          infos[monitoring_infos.to_key(completed_mi)] = completed_mi
  1094          infos[monitoring_infos.to_key(remaining_mi)] = remaining_mi
  1095      return infos
  1096  
  1097  
  1098  class CombineOperation(Operation):
  1099    """A Combine operation executing a CombineFn for each input element."""
  1100    def __init__(self, name_context, spec, counter_factory, state_sampler):
  1101      super(CombineOperation,
  1102            self).__init__(name_context, spec, counter_factory, state_sampler)
  1103      # Combiners do not accept deferred side-inputs (the ignored fourth argument)
  1104      # and therefore the code to handle the extra args/kwargs is simpler than for
  1105      # the DoFns of ParDo.
  1106      fn, args, kwargs = pickler.loads(self.spec.serialized_fn)[:3]
  1107      self.phased_combine_fn = (
  1108          PhasedCombineFnExecutor(self.spec.phase, fn, args, kwargs))
  1109  
  1110    def setup(self):
  1111      # type: () -> None
  1112      with self.scoped_start_state:
  1113        _LOGGER.debug('Setup called for %s', self)
  1114        super(CombineOperation, self).setup()
  1115        self.phased_combine_fn.combine_fn.setup()
  1116  
  1117    def process(self, o):
  1118      # type: (WindowedValue) -> None
  1119      with self.scoped_process_state:
  1120        if self.debug_logging_enabled:
  1121          _LOGGER.debug('Processing [%s] in %s', o, self)
  1122        key, values = o.value
  1123        self.output(o.with_value((key, self.phased_combine_fn.apply(values))))
  1124  
  1125    def finish(self):
  1126      # type: () -> None
  1127      _LOGGER.debug('Finishing %s', self)
  1128      super(CombineOperation, self).finish()
  1129  
  1130    def teardown(self):
  1131      # type: () -> None
  1132      with self.scoped_finish_state:
  1133        _LOGGER.debug('Teardown called for %s', self)
  1134        super(CombineOperation, self).teardown()
  1135        self.phased_combine_fn.combine_fn.teardown()
  1136  
  1137  
  1138  def create_pgbk_op(step_name, spec, counter_factory, state_sampler):
  1139    if spec.combine_fn:
  1140      return PGBKCVOperation(step_name, spec, counter_factory, state_sampler)
  1141    else:
  1142      return PGBKOperation(step_name, spec, counter_factory, state_sampler)
  1143  
  1144  
  1145  class PGBKOperation(Operation):
  1146    """Partial group-by-key operation.
  1147  
  1148    This takes (windowed) input (key, value) tuples and outputs
  1149    (key, [value]) tuples, performing a best effort group-by-key for
  1150    values in this bundle, memory permitting.
  1151    """
  1152    def __init__(self, name_context, spec, counter_factory, state_sampler):
  1153      super(PGBKOperation,
  1154            self).__init__(name_context, spec, counter_factory, state_sampler)
  1155      assert not self.spec.combine_fn
  1156      self.table = collections.defaultdict(list)
  1157      self.size = 0
  1158      # TODO(robertwb) Make this configurable.
  1159      self.max_size = 10 * 1000
  1160  
  1161    def process(self, o):
  1162      # type: (WindowedValue) -> None
  1163      with self.scoped_process_state:
  1164        # TODO(robertwb): Structural (hashable) values.
  1165        key = o.value[0], tuple(o.windows)
  1166        self.table[key].append(o)
  1167        self.size += 1
  1168        if self.size > self.max_size:
  1169          self.flush(9 * self.max_size // 10)
  1170  
  1171    def finish(self):
  1172      # type: () -> None
  1173      self.flush(0)
  1174      super().finish()
  1175  
  1176    def flush(self, target):
  1177      # type: (int) -> None
  1178      limit = self.size - target
  1179      for ix, (kw, vs) in enumerate(list(self.table.items())):
  1180        if ix >= limit:
  1181          break
  1182        del self.table[kw]
  1183        key, windows = kw
  1184        output_value = [v.value[1] for v in vs]
  1185        windowed_value = WindowedValue((key, output_value),
  1186                                       vs[0].timestamp,
  1187                                       windows)
  1188        self.output(windowed_value)
  1189  
  1190  
  1191  class PGBKCVOperation(Operation):
  1192    def __init__(
  1193        self, name_context, spec, counter_factory, state_sampler, windowing=None):
  1194      super(PGBKCVOperation,
  1195            self).__init__(name_context, spec, counter_factory, state_sampler)
  1196      # Combiners do not accept deferred side-inputs (the ignored fourth
  1197      # argument) and therefore the code to handle the extra args/kwargs is
  1198      # simpler than for the DoFns of ParDo.
  1199      fn, args, kwargs = pickler.loads(self.spec.combine_fn)[:3]
  1200      self.combine_fn = curry_combine_fn(fn, args, kwargs)
  1201      self.combine_fn_add_input = self.combine_fn.add_input
  1202      if self.combine_fn.compact.__func__ is core.CombineFn.compact:
  1203        self.combine_fn_compact = None
  1204      else:
  1205        self.combine_fn_compact = self.combine_fn.compact
  1206      if windowing:
  1207        self.is_default_windowing = windowing.is_default()
  1208        tsc_type = windowing.timestamp_combiner
  1209        self.timestamp_combiner = (
  1210            None if tsc_type == window.TimestampCombiner.OUTPUT_AT_EOW else
  1211            window.TimestampCombiner.get_impl(tsc_type, windowing.windowfn))
  1212      else:
  1213        self.is_default_windowing = False  # unknown
  1214        self.timestamp_combiner = None
  1215      # Optimization for the (known tiny accumulator, often wide keyspace)
  1216      # combine functions.
  1217      # TODO(b/36567833): Bound by in-memory size rather than key count.
  1218      self.max_keys = (
  1219          1000 * 1000 if
  1220          isinstance(fn, (combiners.CountCombineFn, combiners.MeanCombineFn)) or
  1221          # TODO(b/36597732): Replace this 'or' part by adding the 'cy' optimized
  1222          # combiners to the short list above.
  1223          (
  1224              isinstance(fn, core.CallableWrapperCombineFn) and
  1225              fn._fn in (min, max, sum)) else 100 * 1000)  # pylint: disable=protected-access
  1226      self.key_count = 0
  1227      self.table = {}
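            # Illustrative example (hypothetical, assuming CountCombineFn):
            # within one bundle,
            #   ('a', 1), ('a', 2), ('b', 3)  ->  ('a', 2), ('b', 1)
            # i.e. each key maps to its partial accumulator, which downstream
            # merge/extract phases finish.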
  1228  
  1229    def setup(self):
  1230      # type: () -> None
  1231      with self.scoped_start_state:
  1232        _LOGGER.debug('Setup called for %s', self)
  1233        super(PGBKCVOperation, self).setup()
  1234        self.combine_fn.setup()
  1235  
  1236    def process(self, wkv):
  1237      # type: (WindowedValue) -> None
  1238      with self.scoped_process_state:
  1239        key, value = wkv.value
  1240        # pylint: disable=unidiomatic-typecheck
  1241        # Optimization for the global window case.
  1242        if self.is_default_windowing:
  1243          wkey = key  # type: Hashable
  1244        else:
  1245          wkey = tuple(wkv.windows), key
  1246        entry = self.table.get(wkey, None)
  1247        if entry is None:
  1248          if self.key_count >= self.max_keys:
  1249            target = self.key_count * 9 // 10
  1250            old_wkeys = []
  1251            # TODO(robertwb): Use an LRU cache?
  1252            for old_wkey, old_wvalue in self.table.items():
  1253              old_wkeys.append(old_wkey)  # Can't mutate while iterating.
  1254              self.output_key(old_wkey, old_wvalue[0], old_wvalue[1])
  1255              self.key_count -= 1
  1256              if self.key_count <= target:
  1257                break
  1258            for old_wkey in reversed(old_wkeys):
  1259              del self.table[old_wkey]
  1260          self.key_count += 1
  1261          # We keep the accumulator (and timestamp) in a mutable list so we can
  1262          # update it in place when new values arrive, without another lookup.
  1263          entry = self.table[wkey] = [self.combine_fn.create_accumulator(), None]
  1264          if not self.is_default_windowing:
  1265            # Conditional as the timestamp attribute is lazily initialized.
  1266            entry[1] = wkv.timestamp
  1267        entry[0] = self.combine_fn_add_input(entry[0], value)
  1268        if not self.is_default_windowing and self.timestamp_combiner:
  1269          entry[1] = self.timestamp_combiner.combine(entry[1], wkv.timestamp)
  1270  
  1271    def finish(self):
  1272      # type: () -> None
  1273      for wkey, value in self.table.items():
  1274        self.output_key(wkey, value[0], value[1])
  1275      self.table = {}
  1276      self.key_count = 0
  1277  
  1278    def teardown(self):
  1279      # type: () -> None
  1280      with self.scoped_finish_state:
  1281        _LOGGER.debug('Teardown called for %s', self)
  1282        super(PGBKCVOperation, self).teardown()
  1283        self.combine_fn.teardown()
  1284  
  1285    def output_key(self, wkey, accumulator, timestamp):
  1286      if self.combine_fn_compact is None:
  1287        value = accumulator
  1288      else:
  1289        value = self.combine_fn_compact(accumulator)
  1290  
  1291      if self.is_default_windowing:
  1292        self.output(_globally_windowed_value.with_value((wkey, value)))
  1293      else:
  1294        windows, key = wkey
  1295        if self.timestamp_combiner is None:
  1296          timestamp = windows[0].max_timestamp()
  1297        self.output(WindowedValue((key, value), timestamp, windows))
  1298  
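        # A standalone sketch (not part of this module, all names illustrative) of
        # the bounded per-key combining table that PGBKCVOperation maintains above:
        # values are folded into one accumulator per (windowed) key, and once the
        # number of cached keys reaches the cap (larger for tiny accumulators such
        # as count/mean), roughly 10% of the entries are flushed downstream before
        # a new key is admitted. The accumulator here is just a running sum.

        def bounded_partial_combine(kvs, emit, max_keys=4):
          table = {}  # key -> running accumulator
          for key, value in kvs:
            if key not in table:
              if len(table) >= max_keys:
                # Mirrors `target = self.key_count * 9 // 10` above: flush
                # entries until only ~90% of the cap remains cached.
                target = len(table) * 9 // 10
                for old_key in list(table):
                  emit((old_key, table.pop(old_key)))
                  if len(table) <= target:
                    break
              table[key] = 0
            table[key] += value
          # finish(): flush everything still cached.
          for key, accumulator in table.items():
            emit((key, accumulator))

        # Repeated keys collapse to a single partial sum per key before the shuffle.
        out = []
        bounded_partial_combine(
            [('a', 1), ('b', 2), ('a', 3), ('c', 4), ('b', 5)], out.append)
        # out is now [('a', 4), ('b', 7), ('c', 4)]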
  1299  
  1300  class FlattenOperation(Operation):
  1301    """Flatten operation.
  1302  
  1303    Receives elements from one or more producer operations and forwards them
  1304    all to a single output.
  1305    """
  1306    def process(self, o):
  1307      # type: (WindowedValue) -> None
  1308      with self.scoped_process_state:
  1309        if self.debug_logging_enabled:
  1310          _LOGGER.debug('Processing [%s] in %s', o, self)
  1311        self.output(o)
  1312  
  1313  
  1314  def create_operation(
  1315      name_context,
  1316      spec,
  1317      counter_factory,
  1318      step_name=None,
  1319      state_sampler=None,
  1320      test_shuffle_source=None,
  1321      test_shuffle_sink=None,
  1322      is_streaming=False):
  1323    # type: (...) -> Operation
  1324  
  1325    """Create Operation object for given operation specification."""
  1326  
  1327    # TODO(pabloem): Document arguments to this function call.
  1328    if not isinstance(name_context, common.NameContext):
  1329      name_context = common.NameContext(step_name=name_context)
  1330  
  1331    if isinstance(spec, operation_specs.WorkerRead):
  1332      if isinstance(spec.source, iobase.SourceBundle):
  1333        op = ReadOperation(
  1334            name_context, spec, counter_factory, state_sampler)  # type: Operation
  1335      else:
  1336        from dataflow_worker.native_operations import NativeReadOperation
  1337        op = NativeReadOperation(
  1338            name_context, spec, counter_factory, state_sampler)
  1339    elif isinstance(spec, operation_specs.WorkerWrite):
  1340      from dataflow_worker.native_operations import NativeWriteOperation
  1341      op = NativeWriteOperation(
  1342          name_context, spec, counter_factory, state_sampler)
  1343    elif isinstance(spec, operation_specs.WorkerCombineFn):
  1344      op = CombineOperation(name_context, spec, counter_factory, state_sampler)
  1345    elif isinstance(spec, operation_specs.WorkerPartialGroupByKey):
  1346      op = create_pgbk_op(name_context, spec, counter_factory, state_sampler)
  1347    elif isinstance(spec, operation_specs.WorkerDoFn):
  1348      op = DoOperation(name_context, spec, counter_factory, state_sampler)
  1349    elif isinstance(spec, operation_specs.WorkerGroupingShuffleRead):
  1350      from dataflow_worker.shuffle_operations import GroupedShuffleReadOperation
  1351      op = GroupedShuffleReadOperation(
  1352          name_context,
  1353          spec,
  1354          counter_factory,
  1355          state_sampler,
  1356          shuffle_source=test_shuffle_source)
  1357    elif isinstance(spec, operation_specs.WorkerUngroupedShuffleRead):
  1358      from dataflow_worker.shuffle_operations import UngroupedShuffleReadOperation
  1359      op = UngroupedShuffleReadOperation(
  1360          name_context,
  1361          spec,
  1362          counter_factory,
  1363          state_sampler,
  1364          shuffle_source=test_shuffle_source)
  1365    elif isinstance(spec, operation_specs.WorkerInMemoryWrite):
  1366      op = InMemoryWriteOperation(
  1367          name_context, spec, counter_factory, state_sampler)
  1368    elif isinstance(spec, operation_specs.WorkerShuffleWrite):
  1369      from dataflow_worker.shuffle_operations import ShuffleWriteOperation
  1370      op = ShuffleWriteOperation(
  1371          name_context,
  1372          spec,
  1373          counter_factory,
  1374          state_sampler,
  1375          shuffle_sink=test_shuffle_sink)
  1376    elif isinstance(spec, operation_specs.WorkerFlatten):
  1377      op = FlattenOperation(name_context, spec, counter_factory, state_sampler)
  1378    elif isinstance(spec, operation_specs.WorkerMergeWindows):
  1379      from dataflow_worker.shuffle_operations import BatchGroupAlsoByWindowsOperation
  1380      from dataflow_worker.shuffle_operations import StreamingGroupAlsoByWindowsOperation
  1381      if is_streaming:
  1382        op = StreamingGroupAlsoByWindowsOperation(
  1383            name_context, spec, counter_factory, state_sampler)
  1384      else:
  1385        op = BatchGroupAlsoByWindowsOperation(
  1386            name_context, spec, counter_factory, state_sampler)
  1387    elif isinstance(spec, operation_specs.WorkerReifyTimestampAndWindows):
  1388      from dataflow_worker.shuffle_operations import ReifyTimestampAndWindowsOperation
  1389      op = ReifyTimestampAndWindowsOperation(
  1390          name_context, spec, counter_factory, state_sampler)
  1391    else:
  1392      raise TypeError(
  1393          'Expected an instance of operation_specs.Worker* class '
  1394          'instead of %s' % (spec, ))
  1395    return op
  1396  
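        # For orientation only: the isinstance dispatch above amounts to a lookup
        # from spec type to operation class. A condensed, hypothetical reading of
        # the simple branches (the real function also lazily imports the
        # dataflow_worker operations and threads test shuffle sources/sinks and
        # the is_streaming flag through several branches):

        _SPEC_TO_OP = {
            operation_specs.WorkerCombineFn: CombineOperation,
            operation_specs.WorkerDoFn: DoOperation,
            operation_specs.WorkerFlatten: FlattenOperation,
            operation_specs.WorkerInMemoryWrite: InMemoryWriteOperation,
        }

        def create_operation_simplified(
            name_context, spec, counter_factory, state_sampler):
          op_cls = _SPEC_TO_OP.get(type(spec))
          if op_cls is None:
            raise TypeError(
                'Expected an instance of operation_specs.Worker* class '
                'instead of %s' % (spec, ))
          return op_cls(name_context, spec, counter_factory, state_sampler)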
  1397  
  1398  class SimpleMapTaskExecutor(object):
  1399    """An executor for map tasks.
  1400  
  1401    Stores progress of the read operation that is the first operation of a map
  1402    task.
  1403    """
  1404    def __init__(
  1405        self,
  1406        map_task,
  1407        counter_factory,
  1408        state_sampler,
  1409        test_shuffle_source=None,
  1410        test_shuffle_sink=None):
  1411      """Initializes SimpleMapTaskExecutor.
  1412  
  1413      Args:
  1414        map_task: The map task to run. The map task contains a list of
  1415          operations, and aligned lists of step_names, original_names, and
  1416          system_names for the pipeline steps.
  1417        counter_factory: The CounterFactory instance for the work item.
  1418        state_sampler: The StateSampler tracking the execution step.
  1419        test_shuffle_source: Used during tests for dependency injection into
  1420          shuffle read operation objects.
  1421        test_shuffle_sink: Used during tests for dependency injection into
  1422          shuffle write operation objects.
  1423      """
  1424  
  1425      self._map_task = map_task
  1426      self._counter_factory = counter_factory
  1427      self._ops = []  # type: List[Operation]
  1428      self._state_sampler = state_sampler
  1429      self._test_shuffle_source = test_shuffle_source
  1430      self._test_shuffle_sink = test_shuffle_sink
  1431  
  1432    def operations(self):
  1433      # type: () -> List[Operation]
  1434      return self._ops[:]
  1435  
  1436    def execute(self):
  1437      # type: () -> None
  1438  
  1439      """Executes all the operation_specs.Worker* instructions in a map task.
  1440  
  1441      We update the map_task with the execution status, expressed as counters.
  1442  
  1443      Raises:
  1444        RuntimeError: if we find more than one read instruction in the task spec.
  1445        TypeError: if the spec parameter is not an instance of the recognized
  1446          operation_specs.Worker* classes.
  1447      """
  1448  
  1449      # operations is a list of operation_specs.Worker* instances.
  1450      # The order of the elements is important because the inputs use
  1451      # list indexes as references.
  1452      for name_context, spec in zip(self._map_task.name_contexts,
  1453                                    self._map_task.operations):
  1454      # The name_context is used for logging and assigning names to counters.
  1455        op = create_operation(
  1456            name_context,
  1457            spec,
  1458            self._counter_factory,
  1459            None,
  1460            self._state_sampler,
  1461            test_shuffle_source=self._test_shuffle_source,
  1462            test_shuffle_sink=self._test_shuffle_sink)
  1463        self._ops.append(op)
  1464  
  1465        # Add receiver operations to the appropriate producers.
  1466        if hasattr(op.spec, 'input'):
  1467          producer, output_index = op.spec.input
  1468          self._ops[producer].add_receiver(op, output_index)
  1469        # Flatten has 'inputs', not 'input'
  1470        if hasattr(op.spec, 'inputs'):
  1471          for producer, output_index in op.spec.inputs:
  1472            self._ops[producer].add_receiver(op, output_index)
  1473  
  1474      for ix, op in reversed(list(enumerate(self._ops))):
  1475        _LOGGER.debug('Starting op %d %s', ix, op)
  1476        op.start()
  1477      for op in self._ops:
  1478        op.finish()
  1479  
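        # A standalone sketch (hypothetical names) of the wiring and ordering used
        # by execute() above: every spec refers to its producer by *index* into the
        # operation list, receivers are registered on the producer side, start()
        # runs in reverse order so downstream consumers are ready before the read
        # operation starts pushing elements, and finish() runs in forward order.

        class _SketchOp(object):
          def __init__(self, name, producer_indexes=()):
            self.name = name
            self.producer_indexes = producer_indexes
            self.receivers = []

          def add_receiver(self, consumer):
            self.receivers.append(consumer)

          def start(self):
            print('start', self.name)

          def finish(self):
            print('finish', self.name)

        sketch_ops = [
            _SketchOp('read'), _SketchOp('do', (0, )), _SketchOp('write', (1, ))
        ]
        for consumer in sketch_ops:
          for producer_index in consumer.producer_indexes:
            sketch_ops[producer_index].add_receiver(consumer)
        for op in reversed(sketch_ops):
          op.start()  # start write, then do, then read
        for op in sketch_ops:
          op.finish()  # finish read, then do, then write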
  1480  
  1481  class InefficientExecutionWarning(RuntimeWarning):
  1482    """Warning to indicate an inefficiency in a Beam pipeline."""
  1483  
  1484  
  1485  # Don't ignore InefficientExecutionWarning, but only report each one once.
  1486  warnings.simplefilter('once', InefficientExecutionWarning)
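
        # Illustrative usage (not part of this module): with the 'once' filter
        # installed above, repeated warnings of this category that carry the same
        # message are reported a single time. The message text is hypothetical.

        def _warn_about_fallback():
          warnings.warn(
              'Falling back to a slow element-at-a-time execution path',
              InefficientExecutionWarning)

        _warn_about_fallback()
        _warn_about_fallback()  # suppressed by the 'once' simplefilter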